Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 111 additions & 45 deletions internal/sql_workbench/service/sql_workbench_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -1062,16 +1063,17 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar
// 注意:解析仅服务于审核辅助路径,解析失败不应直接阻塞用户的 SQL 执行;
// 否则一旦中间件辅助能力出错(如 sid 解码失败),用户连查询都跑不了。
// 真正的「未启用审核 / 审核失败」等强策略仍由后续分支按既有 fail-closed 处理。
sql, datasourceID, err := sqlWorkbenchService.parseStreamExecuteRequest(bodyBytes)
sql, sidInfo, err := sqlWorkbenchService.parseStreamExecuteRequest(bodyBytes)
if err != nil {
sqlWorkbenchService.log.Warnf("failed to parse streamExecute request, skipping audit: %v", err)
return next(c)
}

if sql == "" || datasourceID == "" {
if sql == "" || sidInfo == nil || sidInfo.datasourceID == "" {
sqlWorkbenchService.log.Warnf("SQL or datasource ID is empty, skipping audit")
return next(c)
}
datasourceID := sidInfo.datasourceID

// 获取当前用户 ID
dmsUserId, err := sqlWorkbenchService.getDMSUserIdFromRequest(c)
Expand Down Expand Up @@ -1099,14 +1101,24 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar
return errors.New(locale.Bundle.LocalizeMsgByCtx(c.Request().Context(), locale.SqlWorkbenchAuditGetDBServiceErr))
}

// 检查是否启用 SQL 审核
// 未开启 SQL 审核时直接放行,由 ODC 执行 SQL
if !sqlWorkbenchService.isEnableSQLAudit(dbService) {
sqlWorkbenchService.log.Debugf("SQL audit is not enabled for DBService: %s", dmsDBServiceID)
return errors.New(locale.Bundle.LocalizeMsgByCtx(c.Request().Context(), locale.SqlWorkbenchAuditNotEnabledErr))
return next(c)
}

schemaName := sidInfo.schemaName
if schemaName == "" && sidInfo.dbID > 0 {
resolved, resolveErr := sqlWorkbenchService.getODCDatabaseName(c.Request().Context(), c, sidInfo.dbID)
if resolveErr != nil {
sqlWorkbenchService.log.Warnf("failed to resolve schema from ODC database id %d: %v", sidInfo.dbID, resolveErr)
} else {
schemaName = resolved
}
}

// 调用 SQLE 审核接口
auditResult, err := sqlWorkbenchService.callSQLEAudit(c.Request().Context(), sql, dbService)
auditResult, err := sqlWorkbenchService.callSQLEAudit(c.Request().Context(), sql, dbService, schemaName)
if err != nil {
sqlWorkbenchService.log.Errorf("call SQLE audit failed: %v", err)
return errors.New(locale.Bundle.LocalizeMsgByCtx(c.Request().Context(), locale.SqlWorkbenchAuditCallSQLEErr))
Expand All @@ -1118,75 +1130,129 @@ func (sqlWorkbenchService *SqlWorkbenchService) AuditMiddleware() echo.Middlewar
}
}

// parseStreamExecuteRequest 解析 streamExecute 请求体,提取 SQL 和 datasource ID
func (sqlWorkbenchService *SqlWorkbenchService) parseStreamExecuteRequest(bodyBytes []byte) (sql string, datasourceID string, err error) {
// streamExecuteSidInfo 从 ODC streamExecute 请求的 sid 中解析出的会话上下文
type streamExecuteSidInfo struct {
datasourceID string
schemaName string
dbID int
}

// parseStreamExecuteRequest 解析 streamExecute 请求体,提取 SQL 与会话 sid 信息
func (sqlWorkbenchService *SqlWorkbenchService) parseStreamExecuteRequest(bodyBytes []byte) (sql string, sidInfo *streamExecuteSidInfo, err error) {
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
return "", "", fmt.Errorf("failed to unmarshal request body: %v", err)
return "", nil, fmt.Errorf("failed to unmarshal request body: %v", err)
}

// 从 sql 字段获取 SQL
if sqlVal, ok := requestBody["sql"]; ok {
if sqlStr, ok := sqlVal.(string); ok {
sql = sqlStr
}
}

// 从 sid 字段解析 datasource ID
// sid 格式: sid:{base64编码的JSON}:d:dms
// base64 JSON 包含: {"dbId":623,"dsId":28,"from":"192.168.21.47","logicalSession":false,"realId":"ee9b8ab276"}
if sidVal, ok := requestBody["sid"]; ok {
if sidStr, ok := sidVal.(string); ok {
dsId, parseErr := sqlWorkbenchService.parseSidToDatasourceID(sidStr)
if parseErr != nil {
sqlWorkbenchService.log.Debugf("Failed to parse sid to datasource ID: %v", parseErr)
} else {
datasourceID = dsId
sidInfo, err = sqlWorkbenchService.parseStreamExecuteSid(sidStr)
if err != nil {
sqlWorkbenchService.log.Debugf("failed to parse streamExecute sid: %v", err)
}
}
}

return sql, datasourceID, nil
return sql, sidInfo, nil
}

// parseSidToDatasourceID 从 sid 字符串中解析出 datasource ID
// sid 格式: sid:{base64编码的JSON}:d:dms
func (sqlWorkbenchService *SqlWorkbenchService) parseSidToDatasourceID(sid string) (string, error) {
// 检查 sid 格式: sid:...:d:dms
// parseStreamExecuteSid 解析 ODC sid(与 pathUtil.generateDatabaseSid 对齐):
// - sid:{sessionId}:d:{dbName} — 选库后执行 SQL,dbName 为 encodeURIComponent 编码
// - sid:{sessionId}:did:{databaseId} — 按库 ID 建会话
// - sid:{base64SessionJSON} — 无库后缀,JSON 含 dsId/dbId
// 其中 sessionId 常为 ODC DefaultConnectSessionIdGenerator 生成的 base64 JSON。
func (sqlWorkbenchService *SqlWorkbenchService) parseStreamExecuteSid(sid string) (*streamExecuteSidInfo, error) {
if !strings.HasPrefix(sid, "sid:") {
return "", fmt.Errorf("invalid sid format, missing 'sid:' prefix")
return nil, fmt.Errorf("invalid sid format, missing 'sid:' prefix")
}
rest := strings.TrimPrefix(sid, "sid:")
info := &streamExecuteSidInfo{}

// 移除 "sid:" 前缀
sid = strings.TrimPrefix(sid, "sid:")
// 必须先匹配 :did:,因其包含 :d: 子串
if didMarker := strings.LastIndex(rest, ":did:"); didMarker != -1 {
dbPart := rest[didMarker+5:]
prefix := rest[:didMarker]
dbID, err := strconv.Atoi(dbPart)
if err != nil {
return nil, fmt.Errorf("invalid database id in sibd: %v", err)
}
info.dbID = dbID
if err := sqlWorkbenchService.fillStreamExecuteSidFromBase64(prefix, info); err != nil {
sqlWorkbenchService.log.Debugf("sid prefix is not base64 session JSON, skip: %v", err)
}
return info, nil
}
if dMarker := strings.LastIndex(rest, ":d:"); dMarker != -1 {
dbPart := rest[dMarker+3:]
prefix := rest[:dMarker]
schemaName, err := url.QueryUnescape(dbPart)
if err != nil {
return nil, fmt.Errorf("failed to decode database name from sid: %v", err)
}
info.schemaName = schemaName
if err := sqlWorkbenchService.fillStreamExecuteSidFromBase64(prefix, info); err != nil {
sqlWorkbenchService.log.Debugf("sid prefix is not base64 session JSON, skip: %v", err)
}
return info, nil
}

// 查找最后一个 ":d" 后缀并移除从 ":d" 开始的所有字符
if idx := strings.LastIndex(sid, ":d"); idx != -1 {
sid = sid[:idx]
if err := sqlWorkbenchService.fillStreamExecuteSidFromBase64(rest, info); err != nil {
return nil, err
}
return info, nil
}

// ODC 服务端使用 Base64.getUrlEncoder() 生成 sessionId(URL-safe,包含 '-'/'_'),
// 这里必须用 URLEncoding 解码,否则遇到 '-'/'_' 会报 illegal base64 data。
decodedBytes, err := base64.URLEncoding.DecodeString(sid)
func (sqlWorkbenchService *SqlWorkbenchService) fillStreamExecuteSidFromBase64(encoded string, info *streamExecuteSidInfo) error {
decodedBytes, err := base64.URLEncoding.DecodeString(encoded)
if err != nil {
return "", fmt.Errorf("failed to decode base64 sid: %v", err)
return fmt.Errorf("failed to decode base64 sid: %v", err)
}

// 解析 JSON
var sidData struct {
DbId int `json:"dbId"`
DsId int `json:"dsId"`
From string `json:"from"`
LogicalSession bool `json:"logicalSession"`
RealId string `json:"realId"`
DbId int `json:"dbId"`
DsId int `json:"dsId"`
}

if err := json.Unmarshal(decodedBytes, &sidData); err != nil {
return "", fmt.Errorf("failed to unmarshal sid JSON: %v", err)
return fmt.Errorf("failed to unmarshal sid JSON: %v", err)
}
info.datasourceID = fmt.Sprintf("%d", sidData.DsId)
info.dbID = sidData.DbId
return nil
}

// 返回 dsId 作为字符串
return fmt.Sprintf("%d", sidData.DsId), nil
// getODCDatabaseName 通过 ODC MetaDB 中的 databaseId 查询库名,作为 SQLE 审核的 schema 上下文
func (sqlWorkbenchService *SqlWorkbenchService) getODCDatabaseName(ctx context.Context, c echo.Context, dbID int) (string, error) {
if dbID <= 0 {
return "", nil
}
if sqlWorkbenchService.cfg == nil || sqlWorkbenchService.cfg.Host == "" || sqlWorkbenchService.cfg.Port == "" {
return "", fmt.Errorf("sql workbench is not configured")
}
odcURL := fmt.Sprintf("http://%s:%s/api/v2/database/databases/%d", sqlWorkbenchService.cfg.Host, sqlWorkbenchService.cfg.Port, dbID)
header := map[string]string{}
var cookieParts []string
for _, cookie := range c.Request().Cookies() {
if cookie.Name == ODCSessionCookieName || cookie.Name == ODCXsrfTokenCookieName {
cookieParts = append(cookieParts, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
}
}
if len(cookieParts) > 0 {
header["Cookie"] = strings.Join(cookieParts, "; ")
}
reply := struct {
Data struct {
Name string `json:"name"`
} `json:"data"`
}{}
if err := pkgHttp.Get(ctx, odcURL, header, nil, &reply); err != nil {
return "", err
}
return reply.Data.Name, nil
}

// getDMSUserIdFromRequest 从请求中获取 DMS 用户 ID
Expand Down Expand Up @@ -1256,7 +1322,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) isEnableSQLAudit(dbService *biz.
}

// callSQLEAudit 调用 SQLE 直接审核接口
func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService) (*cloudbeaver.AuditSQLReply, error) {
func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService, schemaName string) (*cloudbeaver.AuditSQLReply, error) {
// 获取 SQLE 服务地址
target, err := sqlWorkbenchService.proxyTargetRepo.GetProxyTargetByName(ctx, _const.SqleComponentName)
if err != nil {
Expand All @@ -1265,14 +1331,14 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex

sqleAddr := fmt.Sprintf("%s/v2/sql_audit", target.URL.String())

// 构建审核请求
auditReq := cloudbeaver.AuditSQLReq{
InstanceType: dbService.DBType,
SQLContent: sql,
SQLType: "sql",
ProjectId: dbService.ProjectUID,
InstanceName: dbService.Name,
RuleTemplateName: dbService.SQLEConfig.SQLQueryConfig.RuleTemplateName,
SchemaName: schemaName,
}

// 设置请求头
Expand Down