Private
Public Access
1
0
Files
u-desk/internal/service/sql_exec_service.go
绝尘 e5dbe89a6f 新增:Markdown编辑器/数据库优化/安全修复
- Markdown 编辑器:实时预览、PDF 导出、独立查看器
- 数据库优化:动态连接池、查询缓存、Redis Pipeline
- 窗口置顶功能
- 文件系统增强:右键菜单、编辑器集成、收藏夹重构
- 安全修复:XSS 防护、路径穿越、HTML 注入
- 代码质量:正则预编译、缓存锁优化、死代码清理
2026-03-31 11:49:25 +08:00

476 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"u-desk/internal/common"
"u-desk/internal/dbclient"
"u-desk/internal/storage/models"
"u-desk/internal/storage/repository"
)
// SqlExecService executes SQL statements, Redis commands and MongoDB JSON
// commands, and serves schema/metadata lookups, against saved database
// connections (SQL execution service).
type SqlExecService struct {
// connRepo resolves saved connection configurations by ID (FindByID).
connRepo repository.ConnectionRepository
// pool hands out MySQL, Redis and MongoDB clients for a connection config.
pool *dbclient.ConnectionPool
}
// NewSqlExecService builds a SqlExecService wired to a newly created
// connection repository and to the pool returned by dbclient.GetPool().
// It fails only if the connection repository cannot be created, in which
// case that error is returned unchanged.
func NewSqlExecService() (*SqlExecService, error) {
	repo, repoErr := repository.NewConnectionRepository()
	if repoErr != nil {
		return nil, repoErr
	}

	svc := &SqlExecService{connRepo: repo}
	svc.pool = dbclient.GetPool()
	return svc, nil
}
// SqlResult is the uniform result returned by ExecuteSQL for every backend
// (SQL execution result).
type SqlResult struct {
Type string `json:"type"` // "query" (MySQL read), "update" (MySQL write) or "command" (Redis/MongoDB)
Data interface{} `json:"data"` // returned rows or command reply; nil for MySQL updates
Columns []string `json:"columns"` // column order; only populated for MySQL queries
RowsAffected int `json:"rowsAffected"` // affected rows for updates, returned rows for queries, best-effort for commands
ExecutionTime int64 `json:"executionTime"` // elapsed time in milliseconds since the request started
}
// ExecuteSQL runs sqlStr on the connection identified by connectionID and
// dispatches on the connection type ("mysql", "redis" or "mongo").
//
// Pagination is the caller's job: sqlStr is expected to already carry any
// LIMIT/OFFSET clause. database selects the target database for MySQL and
// MongoDB; it is not passed to the Redis executor. The whole execution is
// bounded by common.TimeoutQuery. Unknown connection types are rejected.
func (s *SqlExecService) ExecuteSQL(connectionID uint, sqlStr string, database string) (*SqlResult, error) {
	conn, lookupErr := s.connRepo.FindByID(connectionID)
	if lookupErr != nil {
		return nil, fmt.Errorf("获取连接配置失败: %v", lookupErr)
	}

	// Timing starts after the config lookup, so it covers only execution.
	began := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
	defer cancel()

	switch conn.Type {
	case "redis":
		return s.executeRedis(ctx, conn, sqlStr, began)
	case "mongo":
		return s.executeMongo(ctx, conn, sqlStr, database, began)
	case "mysql":
		return s.executeMySQL(ctx, conn, sqlStr, database, began)
	}
	return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
}
// executeMySQL runs sqlStr on a MySQL client borrowed from the pool.
//
// After trimming whitespace, a statement whose upper-cased text starts with
// SELECT, SHOW, DESC (which also covers DESCRIBE) or EXPLAIN is run as a
// query. Its rows and column order are returned with Type "query", and
// RowsAffected is the number of returned rows. Any other statement is run
// as an update and only the affected-row count is returned, with Type
// "update".
//
// database overrides conn.Database when non-empty. ExecutionTime is measured
// from startTime up to the end of statement execution.
func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
	pc := s.pool.GetMySQLClient(conn)
	if pc.Client == nil {
		return nil, fmt.Errorf("获取 MySQL 客户端失败")
	}
	defer pc.Release()

	sqlStr = strings.TrimSpace(sqlStr)
	sqlUpper := strings.ToUpper(sqlStr)

	// An explicitly selected database wins over the connection's default.
	dbName := database
	if dbName == "" {
		dbName = conn.Database
	}

	isQuery := false
	for _, prefix := range []string{"SELECT", "SHOW", "DESC", "EXPLAIN"} {
		if strings.HasPrefix(sqlUpper, prefix) {
			isQuery = true
			break
		}
	}

	result := &SqlResult{}
	if isQuery {
		queryResult, err := pc.Client.ExecuteQuery(ctx, sqlStr, dbName)
		if err != nil {
			return nil, err
		}
		result.Type = "query"
		result.Data = queryResult.Data
		result.Columns = queryResult.Columns
		result.RowsAffected = len(queryResult.Data)
	} else {
		rowsAffected, err := pc.Client.ExecuteUpdate(ctx, sqlStr, dbName)
		if err != nil {
			return nil, err
		}
		result.Type = "update"
		result.RowsAffected = int(rowsAffected)
	}

	// Stamp the duration only after the statement has run; stamping it when
	// the result is created would measure client acquisition, not execution.
	result.ExecutionTime = time.Since(startTime).Milliseconds()
	return result, nil
}
// executeRedis parses sqlStr as a redis-cli style command line and runs it
// on a pooled Redis client.
//
// The command name is upper-cased and the remaining tokens are passed as
// arguments unchanged. The reply is returned with Type "command" and
// RowsAffected fixed at 1.
//
// NOTE(review): no database argument reaches this function, so the command
// runs against whatever DB the pooled client uses. Confirm this matches the
// database passed to GetKeys in GetTables.
func (s *SqlExecService) executeRedis(ctx context.Context, conn *models.DbConnection, sqlStr string, startTime time.Time) (*SqlResult, error) {
	client, err := s.pool.GetRedisClient(conn)
	if err != nil {
		return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
	}

	tokens := parseRedisCommand(sqlStr)
	if len(tokens) == 0 {
		return nil, fmt.Errorf("Redis 命令不能为空")
	}

	name := strings.ToUpper(tokens[0])
	rest := make([]interface{}, 0, len(tokens)-1)
	for _, tok := range tokens[1:] {
		rest = append(rest, tok)
	}

	reply, execErr := client.ExecuteCommand(ctx, name, rest...)
	if execErr != nil {
		return nil, execErr
	}

	out := &SqlResult{Type: "command", Data: reply, RowsAffected: 1}
	out.ExecutionTime = time.Since(startTime).Milliseconds()
	return out, nil
}
// executeMongo decodes sqlStr as a JSON object and runs it through a pooled
// MongoDB client.
//
// The target database is resolved in increasing order of precedence:
// conn.Database, then the command's "database" field, then the database
// argument. An error is returned when all three are empty. A JSON null
// command is rejected. RowsAffected is derived from the command's "op"
// field via mongoRowsAffected and stays 0 when "op" is missing.
func (s *SqlExecService) executeMongo(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
	client, err := s.pool.GetMongoClient(conn)
	if err != nil {
		return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
	}

	// The command must be a JSON object.
	var command map[string]interface{}
	sqlStr = strings.TrimSpace(sqlStr)
	if err := json.Unmarshal([]byte(sqlStr), &command); err != nil {
		return nil, fmt.Errorf("MongoDB 命令必须是有效的 JSON 格式: %v", err)
	}
	// `null` decodes without error into a nil map; reject it here instead of
	// forwarding an empty command to the client.
	if command == nil {
		return nil, fmt.Errorf("MongoDB 命令不能为空")
	}

	dbName := conn.Database
	if db, ok := command["database"].(string); ok && db != "" {
		dbName = db
	}
	if database != "" {
		dbName = database
	}
	if dbName == "" {
		return nil, fmt.Errorf("需要指定数据库名称")
	}

	data, err := client.ExecuteCommand(ctx, dbName, command)
	if err != nil {
		return nil, err
	}

	result := &SqlResult{
		Type:          "command",
		Data:          data,
		ExecutionTime: time.Since(startTime).Milliseconds(),
	}
	if op, ok := command["op"].(string); ok {
		result.RowsAffected = mongoRowsAffected(op, data)
	}
	return result, nil
}

// mongoRowsAffected derives a row count from a MongoDB op name and the value
// the client returned. Unknown ops, and result shapes it does not recognize,
// yield 0.
func mongoRowsAffected(op string, data interface{}) int {
	switch op {
	case "find":
		switch rows := data.(type) {
		case []map[string]interface{}:
			return len(rows)
		case []interface{}:
			return len(rows)
		}
	case "count":
		if n, ok := mongoNumberToInt(data); ok {
			return n
		}
	case "insertOne", "deleteOne":
		// Reported as 1 without inspecting the result.
		// NOTE(review): a deleteOne that matched nothing still reports 1.
		return 1
	case "insertMany":
		if m, ok := data.(map[string]interface{}); ok {
			if n, ok := mongoNumberToInt(m["insertedCount"]); ok {
				return n
			}
		}
	}
	return 0
}

// mongoNumberToInt converts the numeric types a driver or JSON layer may
// produce (int, int32, int64, float64) to int. ok is false for anything else,
// including nil.
func mongoNumberToInt(v interface{}) (int, bool) {
	switch n := v.(type) {
	case int:
		return n, true
	case int32:
		return int(n), true
	case int64:
		return int(n), true
	case float64:
		return int(n), true
	default:
		return 0, false
	}
}
// GetDatabases lists the databases available through the given connection.
//
// MySQL and MongoDB query the server, bounded by common.TimeoutFastQuery.
// Redis returns the fixed names "0" through "15" without contacting the
// server (presumably mirroring Redis's default of 16 logical databases).
// Unknown connection types are rejected.
func (s *SqlExecService) GetDatabases(connectionID uint) ([]string, error) {
	conn, findErr := s.connRepo.FindByID(connectionID)
	if findErr != nil {
		return nil, fmt.Errorf("获取连接配置失败: %v", findErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
	defer cancel()

	switch conn.Type {
	case "redis":
		const redisDBCount = 16
		dbs := make([]string, 0, redisDBCount)
		for n := 0; n < redisDBCount; n++ {
			dbs = append(dbs, fmt.Sprintf("%d", n))
		}
		return dbs, nil
	case "mysql":
		pc := s.pool.GetMySQLClient(conn)
		if pc.Client == nil {
			return nil, fmt.Errorf("获取 MySQL 客户端失败")
		}
		defer pc.Release()
		return pc.Client.ListDatabases(ctx)
	case "mongo":
		mc, mcErr := s.pool.GetMongoClient(conn)
		if mcErr != nil {
			return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", mcErr)
		}
		return mc.ListDatabases(ctx)
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
	}
}
// GetTables lists the children of database on the given connection: table
// names for MySQL, collection names for MongoDB, and key names for Redis.
// database is forwarded to the Redis client's GetKeys as-is.
// The lookup is bounded by common.TimeoutFastQuery. Unknown connection
// types are rejected.
func (s *SqlExecService) GetTables(connectionID uint, database string) ([]string, error) {
	conn, findErr := s.connRepo.FindByID(connectionID)
	if findErr != nil {
		return nil, fmt.Errorf("获取连接配置失败: %v", findErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
	defer cancel()

	switch kind := conn.Type; kind {
	case "mysql":
		pc := s.pool.GetMySQLClient(conn)
		if pc.Client == nil {
			return nil, fmt.Errorf("获取 MySQL 客户端失败")
		}
		defer pc.Release()
		return pc.Client.ListTables(ctx, database)
	case "redis":
		rc, rcErr := s.pool.GetRedisClient(conn)
		if rcErr != nil {
			return nil, fmt.Errorf("获取 Redis 客户端失败: %v", rcErr)
		}
		return rc.GetKeys(ctx, database)
	case "mongo":
		mc, mcErr := s.pool.GetMongoClient(conn)
		if mcErr != nil {
			return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", mcErr)
		}
		return mc.ListCollections(ctx, database)
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", kind)
	}
}
// parseRedisCommand splits a redis-cli style command line into arguments.
//
// Rules:
//   - Arguments are separated by runs of spaces, tabs, CR or LF.
//   - A single- or double-quoted section is taken literally (no escape
//     sequences) and may contain separators.
//   - A quoted section is concatenated with adjacent unquoted characters of
//     the same argument.
//   - An empty quoted string ("" or '') produces an empty argument instead
//     of being dropped, so values like `SET key ""` survive.
//   - An unterminated quote runs to the end of the input.
//   - Blank input returns an empty, non-nil slice.
func parseRedisCommand(cmd string) []string {
	cmd = strings.TrimSpace(cmd)
	if cmd == "" {
		return []string{}
	}

	var parts []string
	var current strings.Builder
	// inToken tracks whether an argument has started. It is needed because a
	// quoted empty string starts an argument without writing any bytes, so
	// current.Len() alone cannot tell "no argument" from "empty argument".
	inToken := false
	var quoteChar byte // 0 when outside quotes

	for i := 0; i < len(cmd); i++ {
		c := cmd[i]
		if quoteChar != 0 {
			if c == quoteChar {
				quoteChar = 0
			} else {
				current.WriteByte(c)
			}
			continue
		}
		switch c {
		case '"', '\'':
			quoteChar = c
			inToken = true
		case ' ', '\t', '\r', '\n':
			if inToken {
				parts = append(parts, current.String())
				current.Reset()
				inToken = false
			}
		default:
			current.WriteByte(c)
			inToken = true
		}
	}
	if inToken {
		parts = append(parts, current.String())
	}
	return parts
}
// GetTableStructure describes tableName on the given connection as a
// map tagged with the backend "type":
//   - MySQL: "database", "table", "columns" (the client's table structure).
//   - MongoDB: "database", "collection", "structure" (the client's
//     collection structure).
//   - Redis: tableName is treated as a key; returns "key" and "info" (the
//     client's key info). database is not used.
//
// The lookup is bounded by common.TimeoutQuery. Unknown connection types
// are rejected.
func (s *SqlExecService) GetTableStructure(connectionID uint, database, tableName string) (map[string]interface{}, error) {
	conn, findErr := s.connRepo.FindByID(connectionID)
	if findErr != nil {
		return nil, fmt.Errorf("获取连接配置失败: %v", findErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
	defer cancel()

	switch conn.Type {
	case "mysql":
		pc := s.pool.GetMySQLClient(conn)
		if pc.Client == nil {
			return nil, fmt.Errorf("获取 MySQL 客户端失败")
		}
		defer pc.Release()
		cols, colErr := pc.Client.GetTableStructure(ctx, database, tableName)
		if colErr != nil {
			return nil, colErr
		}
		out := map[string]interface{}{"type": "mysql", "database": database, "table": tableName}
		out["columns"] = cols
		return out, nil
	case "mongo":
		mc, mcErr := s.pool.GetMongoClient(conn)
		if mcErr != nil {
			return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", mcErr)
		}
		shape, shapeErr := mc.GetCollectionStructure(ctx, database, tableName)
		if shapeErr != nil {
			return nil, shapeErr
		}
		out := map[string]interface{}{"type": "mongo", "database": database, "collection": tableName}
		out["structure"] = shape
		return out, nil
	case "redis":
		rc, rcErr := s.pool.GetRedisClient(conn)
		if rcErr != nil {
			return nil, fmt.Errorf("获取 Redis 客户端失败: %v", rcErr)
		}
		keyInfo, infoErr := rc.GetKeyInfo(ctx, tableName)
		if infoErr != nil {
			return nil, infoErr
		}
		out := map[string]interface{}{"type": "redis", "key": tableName}
		out["info"] = keyInfo
		return out, nil
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
	}
}
// GetIndexes lists the indexes of tableName in database.
//
// Only MySQL is backed by a real lookup, bounded by common.TimeoutQuery.
// MongoDB and Redis connections get an empty, non-nil list, which encodes
// as [] rather than null. Unknown connection types are rejected.
func (s *SqlExecService) GetIndexes(connectionID uint, database, tableName string) ([]map[string]interface{}, error) {
	conn, findErr := s.connRepo.FindByID(connectionID)
	if findErr != nil {
		return nil, fmt.Errorf("获取连接配置失败: %v", findErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
	defer cancel()

	switch conn.Type {
	case "mongo", "redis":
		return []map[string]interface{}{}, nil
	case "mysql":
		pc := s.pool.GetMySQLClient(conn)
		if pc.Client == nil {
			return nil, fmt.Errorf("获取 MySQL 客户端失败")
		}
		defer pc.Release()
		return pc.Client.GetIndexes(ctx, database, tableName)
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
	}
}
// PreviewTableStructure asks the backend to preview changing tableName in
// database to match structure. It returns the client's []string, which is
// presumably the statements/operations that would run (confirm in dbclient).
//
// MySQL delegates to the client's PreviewTableStructure; MongoDB delegates
// to PreviewCollectionIndexes. Any other type (Redis included) is rejected.
// The call is bounded by common.TimeoutLongOp.
func (s *SqlExecService) PreviewTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
	conn, findErr := s.connRepo.FindByID(connectionID)
	if findErr != nil {
		return nil, fmt.Errorf("获取连接配置失败: %v", findErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
	defer cancel()

	switch kind := conn.Type; kind {
	case "mongo":
		mc, mcErr := s.pool.GetMongoClient(conn)
		if mcErr != nil {
			return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", mcErr)
		}
		return mc.PreviewCollectionIndexes(ctx, database, tableName, structure)
	case "mysql":
		pc := s.pool.GetMySQLClient(conn)
		if pc.Client == nil {
			return nil, fmt.Errorf("获取 MySQL 客户端失败")
		}
		defer pc.Release()
		return pc.Client.PreviewTableStructure(ctx, database, tableName, structure)
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", kind)
	}
}
// UpdateTableStructure applies structure to tableName in database. It
// returns the client's []string, presumably the statements/operations that
// were executed (confirm in dbclient).
//
// MySQL delegates to the client's UpdateTableStructure; MongoDB delegates
// to UpdateCollectionIndexes. Any other type (Redis included) is rejected.
// The call is bounded by common.TimeoutLongOp.
func (s *SqlExecService) UpdateTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
	conn, findErr := s.connRepo.FindByID(connectionID)
	if findErr != nil {
		return nil, fmt.Errorf("获取连接配置失败: %v", findErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
	defer cancel()

	switch kind := conn.Type; kind {
	case "mongo":
		mc, mcErr := s.pool.GetMongoClient(conn)
		if mcErr != nil {
			return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", mcErr)
		}
		return mc.UpdateCollectionIndexes(ctx, database, tableName, structure)
	case "mysql":
		pc := s.pool.GetMySQLClient(conn)
		if pc.Client == nil {
			return nil, fmt.Errorf("获取 MySQL 客户端失败")
		}
		defer pc.Release()
		return pc.Client.UpdateTableStructure(ctx, database, tableName, structure)
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", kind)
	}
}