- Markdown 编辑器:实时预览、PDF 导出、独立查看器 - 数据库优化:动态连接池、查询缓存、Redis Pipeline - 窗口置顶功能 - 文件系统增强:右键菜单、编辑器集成、收藏夹重构 - 安全修复:XSS 防护、路径穿越、HTML 注入 - 代码质量:正则预编译、缓存锁优化、死代码清理
476 lines
13 KiB
Go
476 lines
13 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"u-desk/internal/common"
|
||
"u-desk/internal/dbclient"
|
||
"u-desk/internal/storage/models"
|
||
"u-desk/internal/storage/repository"
|
||
)
|
||
|
||
// SqlExecService executes ad-hoc SQL / commands and metadata queries
// (databases, tables, structure, indexes) against saved database connections.
//
// It looks up connection settings through connRepo and obtains driver
// clients for MySQL, Redis and MongoDB from pool.
type SqlExecService struct {
	// connRepo resolves a connection ID to its stored configuration.
	connRepo repository.ConnectionRepository
	// pool hands out MySQL / Redis / MongoDB clients for a connection.
	pool *dbclient.ConnectionPool
}
|
||
|
||
// NewSqlExecService 创建SQL执行服务
|
||
func NewSqlExecService() (*SqlExecService, error) {
|
||
connRepo, err := repository.NewConnectionRepository()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &SqlExecService{
|
||
connRepo: connRepo,
|
||
pool: dbclient.GetPool(),
|
||
}, nil
|
||
}
|
||
|
||
// SqlResult is the result of executing one SQL statement or database command.
type SqlResult struct {
	Type          string      `json:"type"`          // "query", "update" or "command"
	Data          interface{} `json:"data"`          // result payload (rows for queries; nil for MySQL updates)
	Columns       []string    `json:"columns"`       // column order; only set for MySQL queries
	RowsAffected  int         `json:"rowsAffected"`  // affected/returned row count
	ExecutionTime int64       `json:"executionTime"` // execution time in milliseconds
}
|
||
|
||
// ExecuteSQL 执行SQL语句
|
||
// 注意:SQL 语句应该已经包含分页信息(LIMIT 和 OFFSET),由客户端添加
|
||
func (s *SqlExecService) ExecuteSQL(connectionID uint, sqlStr string, database string) (*SqlResult, error) {
|
||
conn, err := s.connRepo.FindByID(connectionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||
}
|
||
|
||
startTime := time.Now()
|
||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
|
||
defer cancel()
|
||
|
||
switch conn.Type {
|
||
case "mysql":
|
||
return s.executeMySQL(ctx, conn, sqlStr, database, startTime)
|
||
case "redis":
|
||
return s.executeRedis(ctx, conn, sqlStr, startTime)
|
||
case "mongo":
|
||
return s.executeMongo(ctx, conn, sqlStr, database, startTime)
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||
}
|
||
}
|
||
|
||
// executeMySQL 执行MySQL SQL
|
||
func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
|
||
pc := s.pool.GetMySQLClient(conn)
|
||
if pc.Client == nil {
|
||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||
}
|
||
defer pc.Release()
|
||
|
||
sqlStr = strings.TrimSpace(sqlStr)
|
||
sqlUpper := strings.ToUpper(sqlStr)
|
||
|
||
// 获取数据库参数
|
||
dbName := database
|
||
if dbName == "" {
|
||
dbName = conn.Database
|
||
}
|
||
|
||
result := &SqlResult{
|
||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||
}
|
||
|
||
// 判断是查询还是更新
|
||
if strings.HasPrefix(sqlUpper, "SELECT") || strings.HasPrefix(sqlUpper, "SHOW") ||
|
||
strings.HasPrefix(sqlUpper, "DESCRIBE") || strings.HasPrefix(sqlUpper, "DESC") ||
|
||
strings.HasPrefix(sqlUpper, "EXPLAIN") {
|
||
// 查询语句
|
||
queryResult, err := pc.Client.ExecuteQuery(ctx, sqlStr, dbName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
result.Type = "query"
|
||
result.Data = queryResult.Data
|
||
result.Columns = queryResult.Columns
|
||
result.RowsAffected = len(queryResult.Data)
|
||
} else {
|
||
// 更新语句
|
||
rowsAffected, err := pc.Client.ExecuteUpdate(ctx, sqlStr, dbName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
result.Type = "update"
|
||
result.RowsAffected = int(rowsAffected)
|
||
result.Data = nil
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// executeRedis 执行Redis命令
|
||
func (s *SqlExecService) executeRedis(ctx context.Context, conn *models.DbConnection, sqlStr string, startTime time.Time) (*SqlResult, error) {
|
||
client, err := s.pool.GetRedisClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||
}
|
||
|
||
// 解析Redis命令
|
||
parts := parseRedisCommand(sqlStr)
|
||
if len(parts) == 0 {
|
||
return nil, fmt.Errorf("Redis 命令不能为空")
|
||
}
|
||
|
||
cmd := strings.ToUpper(parts[0])
|
||
args := make([]interface{}, 0)
|
||
for i := 1; i < len(parts); i++ {
|
||
args = append(args, parts[i])
|
||
}
|
||
|
||
data, err := client.ExecuteCommand(ctx, cmd, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &SqlResult{
|
||
Type: "command",
|
||
Data: data,
|
||
RowsAffected: 1,
|
||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||
}, nil
|
||
}
|
||
|
||
// executeMongo 执行MongoDB命令
|
||
func (s *SqlExecService) executeMongo(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
|
||
client, err := s.pool.GetMongoClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||
}
|
||
|
||
// 解析MongoDB命令(JSON格式)
|
||
var command map[string]interface{}
|
||
sqlStr = strings.TrimSpace(sqlStr)
|
||
if err := json.Unmarshal([]byte(sqlStr), &command); err != nil {
|
||
return nil, fmt.Errorf("MongoDB 命令必须是有效的 JSON 格式: %v", err)
|
||
}
|
||
|
||
// 确定数据库
|
||
dbName := conn.Database
|
||
if db, ok := command["database"].(string); ok && db != "" {
|
||
dbName = db
|
||
}
|
||
if database != "" {
|
||
dbName = database
|
||
}
|
||
if dbName == "" {
|
||
return nil, fmt.Errorf("需要指定数据库名称")
|
||
}
|
||
|
||
// 执行命令
|
||
data, err := client.ExecuteCommand(ctx, dbName, command)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
result := &SqlResult{
|
||
Type: "command",
|
||
Data: data,
|
||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||
}
|
||
|
||
// 根据操作类型确定影响行数
|
||
if op, ok := command["op"].(string); ok {
|
||
switch op {
|
||
case "find":
|
||
if results, ok := data.([]map[string]interface{}); ok {
|
||
result.RowsAffected = len(results)
|
||
}
|
||
case "count":
|
||
if count, ok := data.(int64); ok {
|
||
result.RowsAffected = int(count)
|
||
}
|
||
case "insertOne", "deleteOne":
|
||
result.RowsAffected = 1
|
||
case "insertMany":
|
||
if resultMap, ok := data.(map[string]interface{}); ok {
|
||
if count, ok := resultMap["insertedCount"].(int); ok {
|
||
result.RowsAffected = count
|
||
}
|
||
}
|
||
default:
|
||
result.RowsAffected = 0
|
||
}
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// GetDatabases 获取数据库列表
|
||
func (s *SqlExecService) GetDatabases(connectionID uint) ([]string, error) {
|
||
conn, err := s.connRepo.FindByID(connectionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
|
||
defer cancel()
|
||
|
||
switch conn.Type {
|
||
case "mysql":
|
||
pc := s.pool.GetMySQLClient(conn)
|
||
if pc.Client == nil {
|
||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||
}
|
||
defer pc.Release()
|
||
return pc.Client.ListDatabases(ctx)
|
||
case "redis":
|
||
databases := make([]string, 16)
|
||
for i := 0; i < 16; i++ {
|
||
databases[i] = fmt.Sprintf("%d", i)
|
||
}
|
||
return databases, nil
|
||
case "mongo":
|
||
client, err := s.pool.GetMongoClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||
}
|
||
return client.ListDatabases(ctx)
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||
}
|
||
}
|
||
|
||
// GetTables 获取表列表(MySQL/MongoDB)或Key列表(Redis)
|
||
func (s *SqlExecService) GetTables(connectionID uint, database string) ([]string, error) {
|
||
conn, err := s.connRepo.FindByID(connectionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
|
||
defer cancel()
|
||
|
||
switch conn.Type {
|
||
case "mysql":
|
||
pc := s.pool.GetMySQLClient(conn)
|
||
if pc.Client == nil {
|
||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||
}
|
||
defer pc.Release()
|
||
return pc.Client.ListTables(ctx, database)
|
||
case "redis":
|
||
client, err := s.pool.GetRedisClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||
}
|
||
return client.GetKeys(ctx, database)
|
||
case "mongo":
|
||
client, err := s.pool.GetMongoClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||
}
|
||
return client.ListCollections(ctx, database)
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||
}
|
||
}
|
||
|
||
// parseRedisCommand splits a redis-cli style command line into its arguments.
//
// Rules:
//   - Arguments are separated by runs of ASCII whitespace (space, tab, CR,
//     LF, vertical tab, form feed).
//   - Double- or single-quoted sections may contain whitespace; the quotes
//     are stripped. Quoted and unquoted text that touch each other form a
//     single argument (a"b c"d -> "ab cd").
//   - An empty quoted section ("" or '') yields an empty-string argument, so
//     a command such as SET key "" keeps its value.
//   - Inside quotes, a backslash escapes the active quote character or a
//     backslash; any other backslash is kept literally. Outside quotes,
//     backslashes are always literal.
//   - An unterminated quote is tolerated: the rest of the input belongs to
//     the last argument.
//
// It returns an empty slice when cmd is empty or whitespace-only.
func parseRedisCommand(cmd string) []string {
	cmd = strings.TrimSpace(cmd)
	if cmd == "" {
		return []string{}
	}

	var (
		parts     []string
		current   strings.Builder
		inToken   bool // the current argument has started, even if still empty (e.g. after `""`)
		quoteChar byte // active quote character, or 0 when outside quotes
	)

	// flush emits the current argument, if one has started.
	flush := func() {
		if inToken {
			parts = append(parts, current.String())
			current.Reset()
			inToken = false
		}
	}

	for i := 0; i < len(cmd); i++ {
		c := cmd[i]

		if quoteChar != 0 {
			switch {
			case c == '\\' && i+1 < len(cmd) && (cmd[i+1] == quoteChar || cmd[i+1] == '\\'):
				// Escaped quote or backslash: keep the escaped byte, skip the backslash.
				i++
				current.WriteByte(cmd[i])
			case c == quoteChar:
				quoteChar = 0
			default:
				current.WriteByte(c)
			}
			continue
		}

		switch c {
		case '"', '\'':
			quoteChar = c
			inToken = true
		case ' ', '\t', '\n', '\r', '\v', '\f':
			flush()
		default:
			current.WriteByte(c)
			inToken = true
		}
	}
	flush()

	return parts
}
|
||
|
||
// GetTableStructure 获取表结构
|
||
func (s *SqlExecService) GetTableStructure(connectionID uint, database, tableName string) (map[string]interface{}, error) {
|
||
conn, err := s.connRepo.FindByID(connectionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
|
||
defer cancel()
|
||
|
||
switch conn.Type {
|
||
case "mysql":
|
||
pc := s.pool.GetMySQLClient(conn)
|
||
if pc.Client == nil {
|
||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||
}
|
||
defer pc.Release()
|
||
structure, err := pc.Client.GetTableStructure(ctx, database, tableName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return map[string]interface{}{
|
||
"type": "mysql",
|
||
"database": database,
|
||
"table": tableName,
|
||
"columns": structure,
|
||
}, nil
|
||
|
||
case "mongo":
|
||
client, err := s.pool.GetMongoClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||
}
|
||
structure, err := client.GetCollectionStructure(ctx, database, tableName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return map[string]interface{}{
|
||
"type": "mongo",
|
||
"database": database,
|
||
"collection": tableName,
|
||
"structure": structure,
|
||
}, nil
|
||
|
||
case "redis":
|
||
client, err := s.pool.GetRedisClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||
}
|
||
info, err := client.GetKeyInfo(ctx, tableName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return map[string]interface{}{
|
||
"type": "redis",
|
||
"key": tableName,
|
||
"info": info,
|
||
}, nil
|
||
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||
}
|
||
}
|
||
|
||
// GetIndexes 获取索引列表
|
||
func (s *SqlExecService) GetIndexes(connectionID uint, database, tableName string) ([]map[string]interface{}, error) {
|
||
conn, err := s.connRepo.FindByID(connectionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
|
||
defer cancel()
|
||
|
||
switch conn.Type {
|
||
case "mysql":
|
||
pc := s.pool.GetMySQLClient(conn)
|
||
if pc.Client == nil {
|
||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||
}
|
||
defer pc.Release()
|
||
return pc.Client.GetIndexes(ctx, database, tableName)
|
||
|
||
case "mongo", "redis":
|
||
return []map[string]interface{}{}, nil
|
||
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||
}
|
||
}
|
||
|
||
// PreviewTableStructure 预览表结构变更
|
||
func (s *SqlExecService) PreviewTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||
conn, err := s.connRepo.FindByID(connectionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
|
||
defer cancel()
|
||
|
||
switch conn.Type {
|
||
case "mysql":
|
||
pc := s.pool.GetMySQLClient(conn)
|
||
if pc.Client == nil {
|
||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||
}
|
||
defer pc.Release()
|
||
return pc.Client.PreviewTableStructure(ctx, database, tableName, structure)
|
||
|
||
case "mongo":
|
||
client, err := s.pool.GetMongoClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||
}
|
||
return client.PreviewCollectionIndexes(ctx, database, tableName, structure)
|
||
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||
}
|
||
}
|
||
|
||
// UpdateTableStructure 更新表结构
|
||
func (s *SqlExecService) UpdateTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||
conn, err := s.connRepo.FindByID(connectionID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
|
||
defer cancel()
|
||
|
||
switch conn.Type {
|
||
case "mysql":
|
||
pc := s.pool.GetMySQLClient(conn)
|
||
if pc.Client == nil {
|
||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||
}
|
||
defer pc.Release()
|
||
return pc.Client.UpdateTableStructure(ctx, database, tableName, structure)
|
||
|
||
case "mongo":
|
||
client, err := s.pool.GetMongoClient(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||
}
|
||
return client.UpdateCollectionIndexes(ctx, database, tableName, structure)
|
||
|
||
default:
|
||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||
}
|
||
}
|