diff --git a/pkg/domain/BUILD.bazel b/pkg/domain/BUILD.bazel index f3b14f2778db9..d0a6522d87622 100644 --- a/pkg/domain/BUILD.bazel +++ b/pkg/domain/BUILD.bazel @@ -96,7 +96,7 @@ go_library( "//pkg/util/sqlexec", "//pkg/util/sqlkiller", "//pkg/util/syncutil", - "//pkg/workloadbasedlearning", + "//pkg/workloadlearning", "@com_github_burntsushi_toml//:toml", "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", diff --git a/pkg/domain/domain.go b/pkg/domain/domain.go index fe270830223a2..af8288b53c101 100644 --- a/pkg/domain/domain.go +++ b/pkg/domain/domain.go @@ -99,7 +99,7 @@ import ( "github.com/pingcap/tidb/pkg/util/size" "github.com/pingcap/tidb/pkg/util/sqlkiller" "github.com/pingcap/tidb/pkg/util/syncutil" - "github.com/pingcap/tidb/pkg/workloadbasedlearning" + "github.com/pingcap/tidb/pkg/workloadlearning" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/txnkv/transaction" pd "github.com/tikv/pd/client" @@ -3394,7 +3394,7 @@ func (do *Domain) planCacheEvictTrigger() { // SetupWorkloadBasedLearningWorker sets up all of the workload based learning workers. func (do *Domain) SetupWorkloadBasedLearningWorker() { - wbLearningHandle := workloadbasedlearning.NewWorkloadBasedLearningHandle() + wbLearningHandle := workloadlearning.NewWorkloadLearningHandle(do.sysSessionPool) // Start the workload based learning worker to analyze the read workload by statement_summary. do.wg.Run( func() { @@ -3406,7 +3406,7 @@ func (do *Domain) SetupWorkloadBasedLearningWorker() { } // readTableCostWorker is a background worker that periodically analyze the read path table cost by statement_summary. -func (do *Domain) readTableCostWorker(wbLearningHandle *workloadbasedlearning.Handle) { +func (do *Domain) readTableCostWorker(wbLearningHandle *workloadlearning.Handle) { // Recover the panic and log the error when worker exit. defer util.Recover(metrics.LabelDomain, "readTableCostWorker", nil, false) readTableCostTicker := time.NewTicker(vardef.WorkloadBasedLearningInterval.Load()) @@ -3418,7 +3418,7 @@ func (do *Domain) readTableCostWorker(wbLearningHandle *workloadbasedlearning.Ha select { case <-readTableCostTicker.C: if vardef.EnableWorkloadBasedLearning.Load() && do.statsOwner.IsOwner() { - wbLearningHandle.HandleReadTableCost() + wbLearningHandle.HandleReadTableCost(do.InfoSchema()) } case <-do.exit: return diff --git a/pkg/executor/importer/importer_testkit_test.go b/pkg/executor/importer/importer_testkit_test.go index a97c0e93fbae9..5093416f5ad87 100644 --- a/pkg/executor/importer/importer_testkit_test.go +++ b/pkg/executor/importer/importer_testkit_test.go @@ -345,7 +345,7 @@ func TestProcessChunkWith(t *testing.T) { require.NoError(t, err) checksumMap := checksum.GetInnerChecksums() require.Len(t, checksumMap, 1) - require.Equal(t, verify.MakeKVChecksum(111, 3, 14585065391351463171), *checksumMap[verify.DataKVGroupID]) + require.Equal(t, verify.MakeKVChecksum(111, 3, 17951921359894607752), *checksumMap[verify.DataKVGroupID]) }) } diff --git a/pkg/executor/infoschema_cluster_table_test.go b/pkg/executor/infoschema_cluster_table_test.go index 568a381762197..e189fc53253fb 100644 --- a/pkg/executor/infoschema_cluster_table_test.go +++ b/pkg/executor/infoschema_cluster_table_test.go @@ -406,7 +406,7 @@ func TestTableStorageStats(t *testing.T) { "test 2", )) rows := tk.MustQuery("select TABLE_NAME from information_schema.TABLE_STORAGE_STATS where TABLE_SCHEMA = 'mysql';").Rows() - result := 58 + result := 59 require.Len(t, rows, result) // More tests about the privileges. diff --git a/pkg/executor/infoschema_reader_test.go b/pkg/executor/infoschema_reader_test.go index 47cc9d558bf75..4c993a6ac70fb 100644 --- a/pkg/executor/infoschema_reader_test.go +++ b/pkg/executor/infoschema_reader_test.go @@ -920,22 +920,22 @@ func TestInfoSchemaDDLJobs(t *testing.T) { tk2 := testkit.NewTestKit(t, store) tk2.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE table_name = "t1";`).Check(testkit.RowsWithSep("|", - "131|add index|public|124|129|t1|synced", - "130|create table|public|124|129|t1|synced", - "117|add index|public|110|115|t1|synced", - "116|create table|public|110|115|t1|synced", + "133|add index|public|126|131|t1|synced", + "132|create table|public|126|131|t1|synced", + "119|add index|public|112|117|t1|synced", + "118|create table|public|112|117|t1|synced", )) tk2.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE db_name = "d1" and JOB_TYPE LIKE "add index%%";`).Check(testkit.RowsWithSep("|", - "137|add index|public|124|135|t3|synced", - "134|add index|public|124|132|t2|synced", - "131|add index|public|124|129|t1|synced", - "128|add index|public|124|126|t0|synced", + "139|add index|public|126|137|t3|synced", + "136|add index|public|126|134|t2|synced", + "133|add index|public|126|131|t1|synced", + "130|add index|public|126|128|t0|synced", )) tk2.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE db_name = "d0" and table_name = "t3";`).Check(testkit.RowsWithSep("|", - "123|add index|public|110|121|t3|synced", - "122|create table|public|110|121|t3|synced", + "125|add index|public|112|123|t3|synced", + "124|create table|public|112|123|t3|synced", )) tk2.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE state = "running";`).Check(testkit.Rows()) @@ -946,15 +946,15 @@ func TestInfoSchemaDDLJobs(t *testing.T) { if job.SchemaState == model.StateWriteOnly && loaded.CompareAndSwap(false, true) { tk2.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE table_name = "t0" and state = "running";`).Check(testkit.RowsWithSep("|", - "138 add index write only 110 112 t0 running", + "140 add index write only 112 114 t0 running", )) tk2.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE db_name = "d0" and state = "running";`).Check(testkit.RowsWithSep("|", - "138 add index write only 110 112 t0 running", + "140 add index write only 112 114 t0 running", )) tk2.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE state = "running";`).Check(testkit.RowsWithSep("|", - "138 add index write only 110 112 t0 running", + "140 add index write only 112 114 t0 running", )) } }) @@ -970,8 +970,8 @@ func TestInfoSchemaDDLJobs(t *testing.T) { tk.MustExec("create table test2.t1(id int)") tk.MustQuery(`SELECT JOB_ID, JOB_TYPE, SCHEMA_STATE, SCHEMA_ID, TABLE_ID, table_name, STATE FROM information_schema.ddl_jobs WHERE db_name = "test2" and table_name = "t1"`).Check(testkit.RowsWithSep("|", - "147|create table|public|144|146|t1|synced", - "142|create table|public|139|141|t1|synced", + "149|create table|public|146|148|t1|synced", + "144|create table|public|141|143|t1|synced", )) // Test explain output, since the output may change in future. diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index 8a23ce27198f2..c6d53c4ee974e 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -1752,18 +1752,18 @@ func TestTiDBEncodeKey(t *testing.T) { err := tk.QueryToErr("select tidb_encode_record_key('test', 't1', 0);") require.ErrorContains(t, err, "doesn't exist") tk.MustQuery("select tidb_encode_record_key('test', 't', 1);"). - Check(testkit.Rows("74800000000000006e5f728000000000000001")) + Check(testkit.Rows("7480000000000000705f728000000000000001")) tk.MustExec("alter table t add index i(b);") err = tk.QueryToErr("select tidb_encode_index_key('test', 't', 'i1', 1);") require.ErrorContains(t, err, "index not found") tk.MustQuery("select tidb_encode_index_key('test', 't', 'i', 1, 1);"). - Check(testkit.Rows("74800000000000006e5f698000000000000001038000000000000001038000000000000001")) + Check(testkit.Rows("7480000000000000705f698000000000000001038000000000000001038000000000000001")) tk.MustExec("create table t1 (a int primary key, b int) partition by hash(a) partitions 4;") tk.MustExec("insert into t1 values (1, 1);") - tk.MustQuery("select tidb_encode_record_key('test', 't1(p1)', 1);").Check(testkit.Rows("7480000000000000735f728000000000000001")) - rs := tk.MustQuery("select tidb_mvcc_info('74800000000000006f5f728000000000000001');") + tk.MustQuery("select tidb_encode_record_key('test', 't1(p1)', 1);").Check(testkit.Rows("7480000000000000755f728000000000000001")) + rs := tk.MustQuery("select tidb_mvcc_info('74800000000000007f5f728000000000000001');") mvccInfo := rs.Rows()[0][0].(string) require.NotEqual(t, mvccInfo, `{"info":{}}`) @@ -1772,14 +1772,14 @@ func TestTiDBEncodeKey(t *testing.T) { tk2 := testkit.NewTestKit(t, store) err = tk2.Session().Auth(&auth.UserIdentity{Username: "alice", Hostname: "localhost"}, nil, nil, nil) require.NoError(t, err) - err = tk2.QueryToErr("select tidb_mvcc_info('74800000000000006f5f728000000000000001');") + err = tk2.QueryToErr("select tidb_mvcc_info('74800000000000007f5f728000000000000001');") require.ErrorContains(t, err, "Access denied") err = tk2.QueryToErr("select tidb_encode_record_key('test', 't1(p1)', 1);") require.ErrorContains(t, err, "SELECT command denied") err = tk2.QueryToErr("select tidb_encode_index_key('test', 't', 'i1', 1);") require.ErrorContains(t, err, "SELECT command denied") tk.MustExec("grant select on test.t1 to 'alice'@'%';") - tk2.MustQuery("select tidb_encode_record_key('test', 't1(p1)', 1);").Check(testkit.Rows("7480000000000000735f728000000000000001")) + tk2.MustQuery("select tidb_encode_record_key('test', 't1(p1)', 1);").Check(testkit.Rows("7480000000000000755f728000000000000001")) } func TestIssue9710(t *testing.T) { diff --git a/pkg/kv/option.go b/pkg/kv/option.go index 43567235704f1..faeb53585ec88 100644 --- a/pkg/kv/option.go +++ b/pkg/kv/option.go @@ -197,6 +197,8 @@ const ( InternalTxnStats = "stats" // InternalTxnBindInfo is the type of bind info txn. InternalTxnBindInfo = InternalTxnOthers + // InternalTxnWorkloadLearning is the type of workload-based learning txn. + InternalTxnWorkloadLearning = "WorkloadLearning" // InternalTxnSysVar is the type of sys var txn. InternalTxnSysVar = InternalTxnOthers // InternalTxnAdmin is the type of admin operations. diff --git a/pkg/session/bootstrap.go b/pkg/session/bootstrap.go index 6e37280bd4e51..3c6577d9c9a0f 100644 --- a/pkg/session/bootstrap.go +++ b/pkg/session/bootstrap.go @@ -759,7 +759,7 @@ const ( extra json, -- for the cloud env to save more info like RU, cost_saving, ... index idx_create(created_at), index idx_update(updated_at), - unique index idx(schema_name, table_name, index_columns))` + unique index idx(schema_name, table_name, index_columns));` // CreateKernelOptionsTable is a table to store kernel options for tidb. CreateKernelOptionsTable = `CREATE TABLE IF NOT EXISTS mysql.tidb_kernel_options ( @@ -769,7 +769,18 @@ const ( updated_at datetime, status varchar(128), description text, - primary key(module, name))` + primary key(module, name));` + + // CreateTiDBWorkloadValuesTable is a table to store workload-based learning values for tidb. + CreateTiDBWorkloadValuesTable = `CREATE TABLE IF NOT EXISTS mysql.tidb_workload_values ( + id bigint(20) NOT NULL AUTO_INCREMENT PRIMARY KEY, + version bigint(20) NOT NULL, + category varchar(64) NOT NULL, + type varchar(64) NOT NULL, + table_id bigint(20) NOT NULL, + value json NOT NULL, + index idx_version_category_type (version, category, type), + index idx_table_id (table_id));` ) // CreateTimers is a table to store all timers for tidb @@ -1247,6 +1258,7 @@ const ( // version 242 // insert `cluster_id` into the `mysql.tidb` table. + // Add workload-based learning system tables version242 = 242 ) @@ -3345,8 +3357,8 @@ func upgradeToVer242(s sessiontypes.Session, ver int64) { if ver >= version242 { return } - writeClusterID(s) + mustExecute(s, CreateTiDBWorkloadValuesTable) } // initGlobalVariableIfNotExists initialize a global variable with specific val if it does not exist. @@ -3503,6 +3515,8 @@ func doDDLWorks(s sessiontypes.Session) { mustExecute(s, CreateIndexAdvisorTable) // create mysql.tidb_kernel_options mustExecute(s, CreateKernelOptionsTable) + // create mysql.tidb_workload_values + mustExecute(s, CreateTiDBWorkloadValuesTable) } // doBootstrapSQLFile executes SQL commands in a file as the last stage of bootstrap. diff --git a/pkg/session/bootstrap_test.go b/pkg/session/bootstrap_test.go index 06aabd8234c25..6060382b93034 100644 --- a/pkg/session/bootstrap_test.go +++ b/pkg/session/bootstrap_test.go @@ -240,6 +240,8 @@ func TestBootstrapWithError(t *testing.T) { // Check tidb_ttl_table_status table MustExec(t, se, "SELECT * from mysql.tidb_ttl_table_status") + // Check mysql.tidb_workload_values table + MustExec(t, se, "SELECT * from mysql.tidb_workload_values") } func TestDDLTableCreateBackfillTable(t *testing.T) { diff --git a/pkg/statistics/handle/autoanalyze/priorityqueue/dynamic_partitioned_table_analysis_job_test.go b/pkg/statistics/handle/autoanalyze/priorityqueue/dynamic_partitioned_table_analysis_job_test.go index 95f519184ce36..9d62c93f3afe8 100644 --- a/pkg/statistics/handle/autoanalyze/priorityqueue/dynamic_partitioned_table_analysis_job_test.go +++ b/pkg/statistics/handle/autoanalyze/priorityqueue/dynamic_partitioned_table_analysis_job_test.go @@ -133,12 +133,14 @@ func TestValidateAndPrepareForDynamicPartitionedTable(t *testing.T) { tk.MustExec("create table example_table (a int, b int, index idx(a)) partition by range (a) (partition p0 values less than (2), partition p1 values less than (4))") tableInfo, err := dom.InfoSchema().TableByName(context.Background(), ast.NewCIStr("example_schema"), ast.NewCIStr("example_table")) require.NoError(t, err) + partitionInfo := tableInfo.Meta().GetPartitionInfo() + require.NotNil(t, partitionInfo) job := &priorityqueue.DynamicPartitionedTableAnalysisJob{ SchemaName: "example_schema", GlobalTableID: tableInfo.Meta().ID, PartitionIDs: map[int64]struct{}{ - 113: {}, - 114: {}, + partitionInfo.Definitions[0].ID: {}, + partitionInfo.Definitions[1].ID: {}, }, Weight: 2, } diff --git a/pkg/statistics/handle/handletest/initstats/load_stats_test.go b/pkg/statistics/handle/handletest/initstats/load_stats_test.go index 81d7689a8fe1d..5de1994aeeab9 100644 --- a/pkg/statistics/handle/handletest/initstats/load_stats_test.go +++ b/pkg/statistics/handle/handletest/initstats/load_stats_test.go @@ -100,7 +100,7 @@ func testConcurrentlyInitStats(t *testing.T) { require.False(t, col.IsAllEvicted()) } } - require.Equal(t, int64(126), handle.GetMaxTidRecordForTest()) + require.Equal(t, int64(128), handle.GetMaxTidRecordForTest()) } func TestDropTableBeforeConcurrentlyInitStats(t *testing.T) { diff --git a/pkg/workloadbasedlearning/BUILD.bazel b/pkg/workloadbasedlearning/BUILD.bazel deleted file mode 100644 index a18b0f3427405..0000000000000 --- a/pkg/workloadbasedlearning/BUILD.bazel +++ /dev/null @@ -1,11 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -go_library( - name = "workloadbasedlearning", - srcs = [ - "handle.go", - "metrics.go", - ], - importpath = "github.com/pingcap/tidb/pkg/workloadbasedlearning", - visibility = ["//visibility:public"], -) diff --git a/pkg/workloadbasedlearning/handle.go b/pkg/workloadbasedlearning/handle.go deleted file mode 100644 index 54dff3f74e5b9..0000000000000 --- a/pkg/workloadbasedlearning/handle.go +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package workloadbasedlearning implements the Workload-Based Learning Optimizer. -// The Workload-Based Learning Optimizer introduces a new module in TiDB that leverages captured workload history to -// enhance the database query optimizer. -// By learning from historical data, this module helps the optimizer make smarter decisions, such as identify hot and cold tables, -// analyze resource consumption, etc. -// The workload analysis results can be used to directly suggest a better path, -// or to indirectly influence the cost model and stats so that the optimizer can select the best plan more intelligently and adaptively. -package workloadbasedlearning - -// Handle The entry point for all workload-based learning related tasks -type Handle struct { -} - -// NewWorkloadBasedLearningHandle Create a new WorkloadBasedLearningHandle -// WorkloadBasedLearningHandle is Singleton pattern -func NewWorkloadBasedLearningHandle() *Handle { - return &Handle{} -} - -// HandleReadTableCost Start a new round of analysis of all historical read queries. -// According to abstracted table cost metrics, calculate the percentage of read scan time and memory usage for each table. -// The result will be saved to the table "mysql.workload_values". -// Dataflow -// 1. Abstract middle table cost metrics(scan time, memory usage, read frequency) -// from every record in statement_summary/statement_stats -// -// 2,3. Group by tablename, get the total scan time, total memory usage, and every table scan time, memory usage, -// -// read frequency -// -// 4. Calculate table cost for each table, table cost = table scan time / total scan time + table mem usage / total mem usage -// 5. Save all table cost metrics[per table](scan time, table cost, etc) to table "mysql.workload_values" -func (handle *Handle) HandleReadTableCost() { - // step1: abstract middle table cost metrics from every record in statement_summary - middleMetrics := handle.analyzeBasedOnStatementSummary() - if len(middleMetrics) == 0 { - return - } - // step2: group by tablename, sum(table-scan-time), sum(table-mem-usage), sum(read-frequency) - // step3: calculate the total scan time and total memory usage - tableNameToMetrics := make(map[string]*ReadTableCostMetrics) - totalScanTime := 0.0 - totalMemUsage := 0.0 - for _, middleMetric := range middleMetrics { - metric, ok := tableNameToMetrics[middleMetric.tableName] - if !ok { - tableNameToMetrics[middleMetric.tableName] = middleMetric - } else { - metric.tableScanTime += middleMetric.tableScanTime * float64(middleMetric.readFrequency) - metric.tableMemUsage += middleMetric.tableMemUsage * float64(middleMetric.readFrequency) - metric.readFrequency += middleMetric.readFrequency - } - totalScanTime += middleMetric.tableScanTime - totalMemUsage += middleMetric.tableMemUsage - } - if totalScanTime == 0 || totalMemUsage == 0 { - return - } - // step4: calculate the percentage of scan time and memory usage for each table - for _, metric := range tableNameToMetrics { - metric.tableCost = metric.tableScanTime/totalScanTime + metric.tableMemUsage/totalMemUsage - } - // TODO step5: save the table cost metrics to table "mysql.workload_values" -} - -func (handle *Handle) analyzeBasedOnStatementSummary() []*ReadTableCostMetrics { - // step1: get all record from statement_summary - // step2: abstract table cost metrics from each record - return nil -} - -// TODO -func (handle *Handle) analyzeBasedOnStatementStats() []*ReadTableCostMetrics { - // step1: get all record from statement_stats - // step2: abstract table cost metrics from each record - return nil -} diff --git a/pkg/workloadlearning/BUILD.bazel b/pkg/workloadlearning/BUILD.bazel new file mode 100644 index 0000000000000..0413811722eba --- /dev/null +++ b/pkg/workloadlearning/BUILD.bazel @@ -0,0 +1,35 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "workloadlearning", + srcs = [ + "handle.go", + "metrics.go", + ], + importpath = "github.com/pingcap/tidb/pkg/workloadlearning", + visibility = ["//visibility:public"], + deps = [ + "//pkg/infoschema", + "//pkg/kv", + "//pkg/parser/ast", + "//pkg/sessionctx", + "//pkg/sessiontxn", + "//pkg/util", + "//pkg/util/logutil", + "//pkg/util/sqlescape", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "workloadlearning_test", + timeout = "short", + srcs = ["handle_test.go"], + flaky = True, + deps = [ + ":workloadlearning", + "//pkg/parser/ast", + "//pkg/testkit", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/workloadlearning/handle.go b/pkg/workloadlearning/handle.go new file mode 100644 index 0000000000000..b589b4196ccfe --- /dev/null +++ b/pkg/workloadlearning/handle.go @@ -0,0 +1,214 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package workloadlearning implements the Workload-Based Learning Optimizer. +// The Workload-Based Learning Optimizer introduces a new module in TiDB that leverages captured workload history to +// enhance the database query optimizer. +// By learning from historical data, this module helps the optimizer make smarter decisions, such as identify hot and cold tables, +// analyze resource consumption, etc. +// The workload analysis results can be used to directly suggest a better path, +// or to indirectly influence the cost model and stats so that the optimizer can select the best plan more intelligently and adaptively. +package workloadlearning + +import ( + "context" + "encoding/json" + "strings" + "time" + + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "go.uber.org/zap" +) + +const batchInsertSize = 1000 +const ( + // The category of workload-based learning + feedbackCategory = "Feedback" +) +const ( + // The type of workload-based learning + tableCostType = "TableCost" +) + +// Handle The entry point for all workload-based learning related tasks +type Handle struct { + sysSessionPool util.SessionPool +} + +// NewWorkloadLearningHandle Create a new WorkloadLearningHandle +// WorkloadLearningHandle is Singleton pattern +func NewWorkloadLearningHandle(pool util.SessionPool) *Handle { + return &Handle{pool} +} + +// HandleReadTableCost Start a new round of analysis of all historical read queries. +// According to abstracted table cost metrics, calculate the percentage of read scan time and memory usage for each table. +// The result will be saved to the table "mysql.tidb_workload_values". +// Dataflow +// 1. Abstract middle table cost metrics(scan time, memory usage, read frequency) +// from every record in statement_summary/statement_stats +// +// 2,3. Group by tablename, get the total scan time, total memory usage, and every table scan time, memory usage, +// +// read frequency +// +// 4. Calculate table cost for each table, table cost = table scan time / total scan time + table mem usage / total mem usage +// 5. Save all table cost metrics[per table](scan time, table cost, etc) to table "mysql.tidb_workload_values" +func (handle *Handle) HandleReadTableCost(infoSchema infoschema.InfoSchema) { + // step1: abstract middle table cost metrics from every record in statement_summary + middleMetrics, startTime, endTime := handle.analyzeBasedOnStatementStats() + if len(middleMetrics) == 0 { + return + } + // step2: group by tablename, sum(table-scan-time), sum(table-mem-usage), sum(read-frequency) + // step3: calculate the total scan time and total memory usage + tableNameToMetrics := make(map[ast.CIStr]*ReadTableCostMetrics) + totalScanTime := 0.0 + totalMemUsage := 0.0 + for _, middleMetric := range middleMetrics { + metric, ok := tableNameToMetrics[middleMetric.TableName] + if !ok { + tableNameToMetrics[middleMetric.TableName] = middleMetric + } else { + metric.TableScanTime += middleMetric.TableScanTime * float64(middleMetric.ReadFrequency) + metric.TableMemUsage += middleMetric.TableMemUsage * float64(middleMetric.ReadFrequency) + metric.ReadFrequency += middleMetric.ReadFrequency + } + totalScanTime += middleMetric.TableScanTime + totalMemUsage += middleMetric.TableMemUsage + } + if totalScanTime == 0 || totalMemUsage == 0 { + return + } + // step4: calculate the percentage of scan time and memory usage for each table + for _, metric := range tableNameToMetrics { + metric.TableCost = metric.TableScanTime/totalScanTime + metric.TableMemUsage/totalMemUsage + } + // step5: save the table cost metrics to table "mysql.tidb_workload_values" + handle.SaveReadTableCostMetrics(tableNameToMetrics, startTime, endTime, infoSchema) +} + +func (handle *Handle) analyzeBasedOnStatementSummary() []*ReadTableCostMetrics { + // step1: get all record from statement_summary + // step2: abstract table cost metrics from each record + return nil +} + +// TODO +func (handle *Handle) analyzeBasedOnStatementStats() ([]*ReadTableCostMetrics, time.Time, time.Time) { + // step1: get all record from statement_stats + // step2: abstract table cost metrics from each record + // TODO change the mock value + return nil, time.Now(), time.Now() +} + +// SaveReadTableCostMetrics table cost metrics, workload-based start and end time, version, +func (handle *Handle) SaveReadTableCostMetrics(metrics map[ast.CIStr]*ReadTableCostMetrics, + startTime, endTime time.Time, infoSchema infoschema.InfoSchema) { + // TODO save the workload job info such as start end time into workload_jobs table + // step1: create a new session, context, txn for saving table cost metrics + se, err := handle.sysSessionPool.Get() + if err != nil { + logutil.BgLogger().Warn("get system session failed when saving table cost metrics", zap.Error(err)) + return + } + // TODO to destroy the error session instead of put it back to the pool + defer handle.sysSessionPool.Put(se) + sctx := se.(sessionctx.Context) + exec := sctx.GetRestrictedSQLExecutor() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnWorkloadLearning) + // begin a new txn + err = sessiontxn.NewTxn(context.Background(), sctx) + if err != nil { + logutil.BgLogger().Warn("get txn failed when saving table cost metrics", zap.Error(err)) + return + } + txn, err := sctx.Txn(true) + if err != nil { + logutil.BgLogger().Warn("failed to get txn when saving table cost metrics", zap.Error(err)) + return + } + // enable plan cache + sctx.GetSessionVars().EnableNonPreparedPlanCache = true + + // step2: insert new version table cost metrics by batch using one common txn and context + version := txn.StartTS() + // build insert sql by batch(1000 tables) + i := 0 + sql := new(strings.Builder) + sqlescape.MustFormatSQL(sql, "insert into mysql.tidb_workload_values (version, category, type, table_id, value) values ") + for _, metric := range metrics { + tbl, err := infoSchema.TableByName(ctx, metric.DbName, metric.TableName) + if err != nil { + logutil.BgLogger().Warn("failed to save this table cost metrics due to table id not found in info schema", + zap.String("db_name", metric.DbName.String()), + zap.String("table_name", metric.TableName.String()), + zap.Float64("table_scan_time", metric.TableScanTime), + zap.Float64("table_mem_usage", metric.TableMemUsage), + zap.Int64("read_frequency", metric.ReadFrequency), + zap.Float64("table_cost", metric.TableCost), + zap.Error(err)) + continue + } + metricBytes, err := json.Marshal(metric) + if err != nil { + logutil.BgLogger().Warn("marshal table cost metrics failed", + zap.String("db_name", metric.DbName.String()), + zap.String("table_name", metric.TableName.String()), + zap.Float64("table_scan_time", metric.TableScanTime), + zap.Float64("table_mem_usage", metric.TableMemUsage), + zap.Int64("read_frequency", metric.ReadFrequency), + zap.Float64("table_cost", metric.TableCost), + zap.Error(err)) + continue + } + sqlescape.MustFormatSQL(sql, "(%?, %?, %?, %?, %?)", + version, feedbackCategory, tableCostType, tbl.Meta().ID, json.RawMessage(metricBytes)) + // TODO check the txn record limit + if i%batchInsertSize == batchInsertSize-1 { + _, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + logutil.BgLogger().Warn("insert new version table cost metrics failed", zap.Error(err)) + return + } + sql.Reset() + sql.WriteString("insert into mysql.tidb_workload_values (version, category, type, table_id, value) values ") + } else { + sql.WriteString(", ") + } + i++ + } + // insert the last batch + if sql.Len() != 0 { + // remove the tail comma + sql := sql.String()[:sql.Len()-2] + _, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) + if err != nil { + logutil.BgLogger().Warn("insert new version table cost metrics failed", zap.Error(err)) + return + } + } + // step3: commit the txn, finish the save + err = txn.Commit(context.Background()) + if err != nil { + logutil.BgLogger().Warn("commit txn failed when saving table cost metrics", zap.Error(err)) + } +} diff --git a/pkg/workloadlearning/handle_test.go b/pkg/workloadlearning/handle_test.go new file mode 100644 index 0000000000000..8aa93f3c1c029 --- /dev/null +++ b/pkg/workloadlearning/handle_test.go @@ -0,0 +1,50 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package workloadlearning_test + +import ( + "testing" + "time" + + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/workloadlearning" + "github.com/stretchr/testify/require" +) + +func TestSaveReadTableCostMetrics(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec("create table test (a int, b int, index idx(a))") + // mock a table cost metrics + readTableCostMetrics := &workloadlearning.ReadTableCostMetrics{ + DbName: ast.CIStr{O: "test", L: "test"}, + TableName: ast.CIStr{O: "test", L: "test"}, + TableScanTime: 10.0, + TableMemUsage: 10.0, + ReadFrequency: 10, + TableCost: 1.0, + } + tableCostMetrics := map[ast.CIStr]*workloadlearning.ReadTableCostMetrics{ + {O: "test", L: "test"}: readTableCostMetrics, + } + handle := workloadlearning.NewWorkloadLearningHandle(dom.SysSessionPool()) + handle.SaveReadTableCostMetrics(tableCostMetrics, time.Now(), time.Now(), dom.InfoSchema()) + + // check the result + result := tk.MustQuery("select * from mysql.tidb_workload_values").Rows() + require.Equal(t, 1, len(result)) +} diff --git a/pkg/workloadbasedlearning/metrics.go b/pkg/workloadlearning/metrics.go similarity index 67% rename from pkg/workloadbasedlearning/metrics.go rename to pkg/workloadlearning/metrics.go index e707d52f6f466..3d769a59f7c09 100644 --- a/pkg/workloadbasedlearning/metrics.go +++ b/pkg/workloadlearning/metrics.go @@ -12,19 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package workloadbasedlearning +package workloadlearning + +import "github.com/pingcap/tidb/pkg/parser/ast" // ReadTableCostMetrics is used to indicate the intermediate status and results analyzed through read workload // for function "HandleReadTableCost". type ReadTableCostMetrics struct { - tableName string - // tableScanTime[t] = sum(scan-time * readFrequency) of all records in statement_summary where table-name = t - tableScanTime float64 - // tableMemUsage[t] = sum(mem-usage * readFrequency) of all records in statement_summary where table-name = t - tableMemUsage float64 - // readFrequency[t] = sum(read-frequency) of all records in statement_summary where table-name = t - readFrequency int64 - // tableCost[t] = tableScanTime[t] / totalScanTime + tableMemUsage[t] / totalMemUsage + DbName ast.CIStr + TableName ast.CIStr + // TableScanTime[t] = sum(scan-time * readFrequency) of all records in statement_summary where table-name = t + TableScanTime float64 + // TableMemUsage[t] = sum(mem-usage * readFrequency) of all records in statement_summary where table-name = t + TableMemUsage float64 + // ReadFrequency[t] = sum(read-frequency) of all records in statement_summary where table-name = t + ReadFrequency int64 + // TableCost[t] = TableScanTime[t] / totalScanTime + TableMemUsage[t] / totalMemUsage // range between 0 ~ 2 - tableCost float64 + TableCost float64 }