From c3c45cc37f0a2f7ae0916c5a1c7a996e2234d883 Mon Sep 17 00:00:00 2001 From: actiontech-zihan Date: Wed, 20 May 2026 06:35:39 +0000 Subject: [PATCH] refactor(audit): enhance SQL audit middleware to utilize session context from streamExecute requests and improve error handling --- .../service/sql_workbench_service.go | 156 +++++++++++++----- 1 file changed, 111 insertions(+), 45 deletions(-) diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index ce7dc994..300ead12 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" "sync" "time" @@ -1062,16 +1063,17 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar // 注意:解析仅服务于审核辅助路径,解析失败不应直接阻塞用户的 SQL 执行; // 否则一旦中间件辅助能力出错(如 sid 解码失败),用户连查询都跑不了。 // 真正的「未启用审核 / 审核失败」等强策略仍由后续分支按既有 fail-closed 处理。 - sql, datasourceID, err := sqlWorkbenchService.parseStreamExecuteRequest(bodyBytes) + sql, sidInfo, err := sqlWorkbenchService.parseStreamExecuteRequest(bodyBytes) if err != nil { sqlWorkbenchService.log.Warnf("failed to parse streamExecute request, skipping audit: %v", err) return next(c) } - if sql == "" || datasourceID == "" { + if sql == "" || sidInfo == nil || sidInfo.datasourceID == "" { sqlWorkbenchService.log.Warnf("SQL or datasource ID is empty, skipping audit") return next(c) } + datasourceID := sidInfo.datasourceID // 获取当前用户 ID dmsUserId, err := sqlWorkbenchService.getDMSUserIdFromRequest(c) @@ -1099,14 +1101,24 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar return errors.New(locale.Bundle.LocalizeMsgByCtx(c.Request().Context(), locale.SqlWorkbenchAuditGetDBServiceErr)) } - // 检查是否启用 SQL 审核 + // 未开启 SQL 审核时直接放行,由 ODC 执行 SQL if !sqlWorkbenchService.isEnableSQLAudit(dbService) { sqlWorkbenchService.log.Debugf("SQL audit is not enabled for DBService: %s", dmsDBServiceID) - return errors.New(locale.Bundle.LocalizeMsgByCtx(c.Request().Context(), locale.SqlWorkbenchAuditNotEnabledErr)) + return next(c) + } + + schemaName := sidInfo.schemaName + if schemaName == "" && sidInfo.dbID > 0 { + resolved, resolveErr := sqlWorkbenchService.getODCDatabaseName(c.Request().Context(), c, sidInfo.dbID) + if resolveErr != nil { + sqlWorkbenchService.log.Warnf("failed to resolve schema from ODC database id %d: %v", sidInfo.dbID, resolveErr) + } else { + schemaName = resolved + } } // 调用 SQLE 审核接口 - auditResult, err := sqlWorkbenchService.callSQLEAudit(c.Request().Context(), sql, dbService) + auditResult, err := sqlWorkbenchService.callSQLEAudit(c.Request().Context(), sql, dbService, schemaName) if err != nil { sqlWorkbenchService.log.Errorf("call SQLE audit failed: %v", err) return errors.New(locale.Bundle.LocalizeMsgByCtx(c.Request().Context(), locale.SqlWorkbenchAuditCallSQLEErr)) @@ -1118,75 +1130,129 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar } } -// parseStreamExecuteRequest 解析 streamExecute 请求体,提取 SQL 和 datasource ID -func (sqlWorkbenchService *SqlWorkbenchService) parseStreamExecuteRequest(bodyBytes []byte) (sql string, datasourceID string, err error) { +// streamExecuteSidInfo 从 ODC streamExecute 请求的 sid 中解析出的会话上下文 +type streamExecuteSidInfo struct { + datasourceID string + schemaName string + dbID int +} + +// parseStreamExecuteRequest 解析 streamExecute 请求体,提取 SQL 与会话 sid 信息 +func (sqlWorkbenchService *SqlWorkbenchService) parseStreamExecuteRequest(bodyBytes []byte) (sql string, sidInfo *streamExecuteSidInfo, err error) { var requestBody map[string]interface{} if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - return "", "", fmt.Errorf("failed to unmarshal request body: %v", err) + return "", nil, fmt.Errorf("failed to unmarshal request body: %v", err) } - // 从 sql 字段获取 SQL if sqlVal, ok := requestBody["sql"]; ok { if sqlStr, ok := sqlVal.(string); ok { sql = sqlStr } } - // 从 sid 字段解析 datasource ID - // sid 格式: sid:{base64编码的JSON}:d:dms - // base64 JSON 包含: {"dbId":623,"dsId":28,"from":"192.168.21.47","logicalSession":false,"realId":"ee9b8ab276"} if sidVal, ok := requestBody["sid"]; ok { if sidStr, ok := sidVal.(string); ok { - dsId, parseErr := sqlWorkbenchService.parseSidToDatasourceID(sidStr) - if parseErr != nil { - sqlWorkbenchService.log.Debugf("Failed to parse sid to datasource ID: %v", parseErr) - } else { - datasourceID = dsId + sidInfo, err = sqlWorkbenchService.parseStreamExecuteSid(sidStr) + if err != nil { + sqlWorkbenchService.log.Debugf("failed to parse streamExecute sid: %v", err) } } } - return sql, datasourceID, nil + return sql, sidInfo, nil } -// parseSidToDatasourceID 从 sid 字符串中解析出 datasource ID -// sid 格式: sid:{base64编码的JSON}:d:dms -func (sqlWorkbenchService *SqlWorkbenchService) parseSidToDatasourceID(sid string) (string, error) { - // 检查 sid 格式: sid:...:d:dms +// parseStreamExecuteSid 解析 ODC sid(与 pathUtil.generateDatabaseSid 对齐): +// - sid:{sessionId}:d:{dbName} — 选库后执行 SQL,dbName 为 encodeURIComponent 编码 +// - sid:{sessionId}:did:{databaseId} — 按库 ID 建会话 +// - sid:{base64SessionJSON} — 无库后缀,JSON 含 dsId/dbId +// 其中 sessionId 常为 ODC DefaultConnectSessionIdGenerator 生成的 base64 JSON。 +func (sqlWorkbenchService *SqlWorkbenchService) parseStreamExecuteSid(sid string) (*streamExecuteSidInfo, error) { if !strings.HasPrefix(sid, "sid:") { - return "", fmt.Errorf("invalid sid format, missing 'sid:' prefix") + return nil, fmt.Errorf("invalid sid format, missing 'sid:' prefix") } + rest := strings.TrimPrefix(sid, "sid:") + info := &streamExecuteSidInfo{} - // 移除 "sid:" 前缀 - sid = strings.TrimPrefix(sid, "sid:") + // 必须先匹配 :did:,因其包含 :d: 子串 + if didMarker := strings.LastIndex(rest, ":did:"); didMarker != -1 { + dbPart := rest[didMarker+5:] + prefix := rest[:didMarker] + dbID, err := strconv.Atoi(dbPart) + if err != nil { + return nil, fmt.Errorf("invalid database id in sibd: %v", err) + } + info.dbID = dbID + if err := sqlWorkbenchService.fillStreamExecuteSidFromBase64(prefix, info); err != nil { + sqlWorkbenchService.log.Debugf("sid prefix is not base64 session JSON, skip: %v", err) + } + return info, nil + } + if dMarker := strings.LastIndex(rest, ":d:"); dMarker != -1 { + dbPart := rest[dMarker+3:] + prefix := rest[:dMarker] + schemaName, err := url.QueryUnescape(dbPart) + if err != nil { + return nil, fmt.Errorf("failed to decode database name from sid: %v", err) + } + info.schemaName = schemaName + if err := sqlWorkbenchService.fillStreamExecuteSidFromBase64(prefix, info); err != nil { + sqlWorkbenchService.log.Debugf("sid prefix is not base64 session JSON, skip: %v", err) + } + return info, nil + } - // 查找最后一个 ":d" 后缀并移除从 ":d" 开始的所有字符 - if idx := strings.LastIndex(sid, ":d"); idx != -1 { - sid = sid[:idx] + if err := sqlWorkbenchService.fillStreamExecuteSidFromBase64(rest, info); err != nil { + return nil, err } + return info, nil +} - // ODC 服务端使用 Base64.getUrlEncoder() 生成 sessionId(URL-safe,包含 '-'/'_'), - // 这里必须用 URLEncoding 解码,否则遇到 '-'/'_' 会报 illegal base64 data。 - decodedBytes, err := base64.URLEncoding.DecodeString(sid) +func (sqlWorkbenchService *SqlWorkbenchService) fillStreamExecuteSidFromBase64(encoded string, info *streamExecuteSidInfo) error { + decodedBytes, err := base64.URLEncoding.DecodeString(encoded) if err != nil { - return "", fmt.Errorf("failed to decode base64 sid: %v", err) + return fmt.Errorf("failed to decode base64 sid: %v", err) } - - // 解析 JSON var sidData struct { - DbId int `json:"dbId"` - DsId int `json:"dsId"` - From string `json:"from"` - LogicalSession bool `json:"logicalSession"` - RealId string `json:"realId"` + DbId int `json:"dbId"` + DsId int `json:"dsId"` } - if err := json.Unmarshal(decodedBytes, &sidData); err != nil { - return "", fmt.Errorf("failed to unmarshal sid JSON: %v", err) + return fmt.Errorf("failed to unmarshal sid JSON: %v", err) } + info.datasourceID = fmt.Sprintf("%d", sidData.DsId) + info.dbID = sidData.DbId + return nil +} - // 返回 dsId 作为字符串 - return fmt.Sprintf("%d", sidData.DsId), nil +// getODCDatabaseName 通过 ODC MetaDB 中的 databaseId 查询库名,作为 SQLE 审核的 schema 上下文 +func (sqlWorkbenchService *SqlWorkbenchService) getODCDatabaseName(ctx context.Context, c echo.Context, dbID int) (string, error) { + if dbID <= 0 { + return "", nil + } + if sqlWorkbenchService.cfg == nil || sqlWorkbenchService.cfg.Host == "" || sqlWorkbenchService.cfg.Port == "" { + return "", fmt.Errorf("sql workbench is not configured") + } + odcURL := fmt.Sprintf("http://%s:%s/api/v2/database/databases/%d", sqlWorkbenchService.cfg.Host, sqlWorkbenchService.cfg.Port, dbID) + header := map[string]string{} + var cookieParts []string + for _, cookie := range c.Request().Cookies() { + if cookie.Name == ODCSessionCookieName || cookie.Name == ODCXsrfTokenCookieName { + cookieParts = append(cookieParts, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value)) + } + } + if len(cookieParts) > 0 { + header["Cookie"] = strings.Join(cookieParts, "; ") + } + reply := struct { + Data struct { + Name string `json:"name"` + } `json:"data"` + }{} + if err := pkgHttp.Get(ctx, odcURL, header, nil, &reply); err != nil { + return "", err + } + return reply.Data.Name, nil } // getDMSUserIdFromRequest 从请求中获取 DMS 用户 ID @@ -1256,7 +1322,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) isEnableSQLAudit(dbService *biz. } // callSQLEAudit 调用 SQLE 直接审核接口 -func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService) (*cloudbeaver.AuditSQLReply, error) { +func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService, schemaName string) (*cloudbeaver.AuditSQLReply, error) { // 获取 SQLE 服务地址 target, err := sqlWorkbenchService.proxyTargetRepo.GetProxyTargetByName(ctx, _const.SqleComponentName) if err != nil { @@ -1265,7 +1331,6 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex sqleAddr := fmt.Sprintf("%s/v2/sql_audit", target.URL.String()) - // 构建审核请求 auditReq := cloudbeaver.AuditSQLReq{ InstanceType: dbService.DBType, SQLContent: sql, @@ -1273,6 +1338,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex ProjectId: dbService.ProjectUID, InstanceName: dbService.Name, RuleTemplateName: dbService.SQLEConfig.SQLQueryConfig.RuleTemplateName, + SchemaName: schemaName, } // 设置请求头