@@ -23,13 +23,17 @@ package workloadbasedlearning
2323
2424import (
2525 "context"
26+ "encoding/json"
2627 "strings"
2728 "time"
2829
30+ "github.com/pingcap/tidb/pkg/infoschema"
2931 "github.com/pingcap/tidb/pkg/kv"
32+ "github.com/pingcap/tidb/pkg/parser/ast"
3033 "github.com/pingcap/tidb/pkg/sessionctx"
3134 "github.com/pingcap/tidb/pkg/util"
3235 "github.com/pingcap/tidb/pkg/util/logutil"
36+ "github.com/pingcap/tidb/pkg/util/sqlescape"
3337 "go.uber.org/zap"
3438)
3539
@@ -67,15 +71,15 @@ func NewWorkloadBasedLearningHandle(pool util.SessionPool) *Handle {
6771//
6872// 4. Calculate table cost for each table, table cost = table scan time / total scan time + table mem usage / total mem usage
6973// 5. Save all table cost metrics[per table](scan time, table cost, etc) to table "mysql.workload_values"
70- func (handle * Handle ) HandleReadTableCost () {
74+ func (handle * Handle ) HandleReadTableCost (infoSchema infoschema. InfoSchema ) {
7175 // step1: abstract middle table cost metrics from every record in statement_summary
7276 middleMetrics , startTime , endTime := handle .analyzeBasedOnStatementStats ()
7377 if len (middleMetrics ) == 0 {
7478 return
7579 }
7680 // step2: group by tablename, sum(table-scan-time), sum(table-mem-usage), sum(read-frequency)
7781 // step3: calculate the total scan time and total memory usage
78- tableNameToMetrics := make (map [string ]* ReadTableCostMetrics )
82+ tableNameToMetrics := make (map [ast. CIStr ]* ReadTableCostMetrics )
7983 totalScanTime := 0.0
8084 totalMemUsage := 0.0
8185 for _ , middleMetric := range middleMetrics {
@@ -98,7 +102,7 @@ func (handle *Handle) HandleReadTableCost() {
98102 metric .tableCost = metric .tableScanTime / totalScanTime + metric .tableMemUsage / totalMemUsage
99103 }
100104 // step5: save the table cost metrics to table "mysql.workload_values"
101- handle .saveReadTableCostMetrics (tableNameToMetrics , startTime , endTime )
105+ handle .saveReadTableCostMetrics (tableNameToMetrics , startTime , endTime , infoSchema )
102106}
103107
104108func (handle * Handle ) analyzeBasedOnStatementSummary () []* ReadTableCostMetrics {
@@ -115,17 +119,20 @@ func (handle *Handle) analyzeBasedOnStatementStats() ([]*ReadTableCostMetrics, t
115119 return nil , time .Now (), time .Now ()
116120}
117121
122+ // TODO save the workload job info such as start end time into workload_jobs table
118123// table cost metrics, workload-based start and end time, version,
119- func (handle * Handle ) saveReadTableCostMetrics (metrics map [string ]* ReadTableCostMetrics , startTime , endTime time.Time ) {
124+ func (handle * Handle ) saveReadTableCostMetrics (metrics map [ast.CIStr ]* ReadTableCostMetrics ,
125+ startTime , endTime time.Time , infoSchema infoschema.InfoSchema ) {
120126 // step1: create a new session, context, txn for saving table cost metrics
121- // TODO enable the plan cache
122127 se , err := handle .sysSessionPool .Get ()
123128 if err != nil {
124129 logutil .BgLogger ().Warn ("get system session failed when saving table cost metrics" , zap .Error (err ))
125130 return
126131 }
127132 defer handle .sysSessionPool .Put (se )
128133 sctx := se .(sessionctx.Context )
134+ // enable plan cache
135+ sctx .GetSessionVars ().EnableNonPreparedPlanCache = true
129136 txn , err := sctx .Txn (true )
130137 if err != nil {
131138 logutil .BgLogger ().Warn ("get txn failed when saving table cost metrics" , zap .Error (err ))
@@ -136,47 +143,61 @@ func (handle *Handle) saveReadTableCostMetrics(metrics map[string]*ReadTableCost
136143
137144 // step2: insert new version table cost metrics by batch using one common txn and context
138145 version := txn .StartTS ()
139- // build insert stringBuilder by batch(1000 tables)
146+ // build insert sql by batch(1000 tables)
140147 i := 0
141- stringBuilder := new (strings.Builder )
142- stringBuilder .WriteString ("insert into mysql.workload_values (version, category, type, table_id, value) values " )
143- for tableName , metric := range metrics {
144- stringBuilder .WriteString ("(" )
145- stringBuilder .WriteString (version )
146- stringBuilder .WriteString (", " )
147- stringBuilder .WriteString (feedbackCategory )
148- stringBuilder .WriteString (", " )
149- stringBuilder .WriteString (tableCostType )
150- stringBuilder .WriteString (", " )
151- // TODO get the table id by table name
152- tableId := 0
153- stringBuilder .WriteString (tableId )
154- stringBuilder .WriteString ("', " )
155- // TODO build the value and start end time to json
156- stringBuilder .WriteString (metric )
157- stringBuilder .WriteString (")" )
148+ sql := new (strings.Builder )
149+ sqlescape .MustFormatSQL (sql , "insert into mysql.workload_values (version, category, type, table_id, value) values " )
150+ for _ , metric := range metrics {
151+ tbl , err := infoSchema .TableByName (ctx , metric .dbName , metric .tableName )
152+ if err != nil {
153+ logutil .BgLogger ().Warn ("Failed to save this table cost metrics due to table id not found in info schema" ,
154+ zap .String ("db_name" , metric .dbName .String ()),
155+ zap .String ("table_name" , metric .tableName .String ()),
156+ zap .Float64 ("table_scan_time" , metric .tableScanTime ),
157+ zap .Float64 ("table_mem_usage" , metric .tableMemUsage ),
158+ zap .Int64 ("read_frequency" , metric .readFrequency ),
159+ zap .Float64 ("table_cost" , metric .tableCost ),
160+ zap .Error (err ))
161+ continue
162+ }
163+ metricBytes , err := json .Marshal (metric )
164+ if err != nil {
165+ logutil .BgLogger ().Warn ("Marshal table cost metrics failed" ,
166+ zap .String ("db_name" , metric .dbName .String ()),
167+ zap .String ("table_name" , metric .tableName .String ()),
168+ zap .Float64 ("table_scan_time" , metric .tableScanTime ),
169+ zap .Float64 ("table_mem_usage" , metric .tableMemUsage ),
170+ zap .Int64 ("read_frequency" , metric .readFrequency ),
171+ zap .Float64 ("table_cost" , metric .tableCost ),
172+ zap .Error (err ))
173+ continue
174+ }
175+ sqlescape .MustFormatSQL (sql , "(%?, %?, %?, %?, %?)" ,
176+ version , feedbackCategory , tableCostType , tbl .Meta ().ID , json .RawMessage (metricBytes ))
177+ // TODO check the txn record limit
158178 if i % batchInsertSize == batchInsertSize - 1 {
159- _ , _ , err := exec .ExecRestrictedSQL (ctx , nil , stringBuilder .String ())
179+ _ , _ , err := exec .ExecRestrictedSQL (ctx , nil , sql .String ())
160180 if err != nil {
161181 logutil .BgLogger ().Warn ("insert new version table cost metrics failed" , zap .Error (err ))
162182 return
163183 }
164- stringBuilder .Reset ()
165- stringBuilder .WriteString ("insert into mysql.workload_values (version, category, type, table_id, value) values " )
184+ sql .Reset ()
185+ sql .WriteString ("insert into mysql.workload_values (version, category, type, table_id, value) values " )
166186 } else {
167- stringBuilder .WriteString (", " )
187+ sql .WriteString (", " )
168188 }
169189 i ++
170190 }
171191 // insert the last batch
172- if stringBuilder .Len () != 0 {
192+ if sql .Len () != 0 {
173193 // remove the tail comma
174- sql := stringBuilder .String ()[:stringBuilder .Len ()- 2 ]
194+ sql := sql .String ()[:sql .Len ()- 2 ]
175195 _ , _ , err := exec .ExecRestrictedSQL (ctx , nil , sql )
176196 if err != nil {
177197 logutil .BgLogger ().Warn ("insert new version table cost metrics failed" , zap .Error (err ))
178198 return
179199 }
180200 }
181-
201+ // step3: commit the txn, finish the save
202+ sctx .CommitTxn (ctx )
182203}