重构:移除数据库客户端模块 v0.4.0(-17,885行,专注文件管理)
- 删除全部 MySQL/Redis/MongoDB 客户端代码(dbclient/api/service/storage) - 清理 4 个驱动依赖(mysql/redis/mongo/gorm-mysql),构建体积 -10MB - 前端移除 db-cli 整个目录(40 文件)+ 7 个 API/工具文件 - 版本号升级至 v0.4.0,顶部 Tab 仅保留文件管理
This commit is contained in:
@@ -1,128 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"u-desk/internal/service"
|
||||
"u-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// ConnectionAPI 连接管理API
|
||||
type ConnectionAPI struct {
|
||||
connService *service.ConnectionService
|
||||
}
|
||||
|
||||
// NewConnectionAPI 创建连接管理API
|
||||
func NewConnectionAPI() (*ConnectionAPI, error) {
|
||||
connService, err := service.NewConnectionService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ConnectionAPI{connService}, nil
|
||||
}
|
||||
|
||||
// SaveConnectionRequest 保存连接请求结构体
|
||||
type SaveConnectionRequest struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
Options string `json:"options"`
|
||||
VisibleDatabases string `json:"visible_databases"`
|
||||
}
|
||||
|
||||
// SaveDbConnection 保存数据库连接配置
|
||||
func (api *ConnectionAPI) SaveDbConnection(req SaveConnectionRequest) error {
|
||||
conn := &models.DbConnection{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Database: req.Database,
|
||||
Options: req.Options,
|
||||
VisibleDatabases: req.VisibleDatabases,
|
||||
}
|
||||
return api.connService.SaveConnection(conn)
|
||||
}
|
||||
|
||||
// ListDbConnections 获取连接列表
|
||||
func (api *ConnectionAPI) ListDbConnections() ([]map[string]interface{}, error) {
|
||||
connections, err := api.connService.ListConnections()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, len(connections))
|
||||
timeFormat := "2006-01-02 15:04:05"
|
||||
for i, conn := range connections {
|
||||
result[i] = map[string]interface{}{
|
||||
"id": conn.ID,
|
||||
"name": conn.Name,
|
||||
"type": conn.Type,
|
||||
"host": conn.Host,
|
||||
"port": conn.Port,
|
||||
"username": conn.Username,
|
||||
"database": conn.Database,
|
||||
"options": conn.Options,
|
||||
"visible_databases": conn.VisibleDatabases,
|
||||
"created_at": conn.CreatedAt.Format(timeFormat),
|
||||
"updated_at": conn.UpdatedAt.Format(timeFormat),
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (api *ConnectionAPI) DeleteDbConnection(id uint) error {
|
||||
return api.connService.DeleteConnection(id)
|
||||
}
|
||||
|
||||
func (api *ConnectionAPI) TestDbConnection(id uint) error {
|
||||
return api.connService.TestConnection(id)
|
||||
}
|
||||
|
||||
// TestConnectionRequest 测试连接请求结构体(不保存数据)
|
||||
type TestConnectionRequest struct {
|
||||
ID uint `json:"id"` // 编辑模式下的连接ID(用于获取已保存的密码)
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
Options string `json:"options"`
|
||||
}
|
||||
|
||||
// TestDbConnectionWithParams 测试数据库连接(直接传入参数,不保存数据)
|
||||
func (api *ConnectionAPI) TestDbConnectionWithParams(req TestConnectionRequest) error {
|
||||
return api.connService.TestConnectionWithParams(
|
||||
req.Type, req.Host, req.Port,
|
||||
req.Username, req.Password, req.Database,
|
||||
req.Options, req.ID,
|
||||
)
|
||||
}
|
||||
|
||||
// LoadAllDatabasesRequest 加载全部数据库请求结构体
|
||||
type LoadAllDatabasesRequest struct {
|
||||
ID uint `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
Options string `json:"options"`
|
||||
}
|
||||
|
||||
// LoadAllDatabases 加载全部数据库列表
|
||||
func (api *ConnectionAPI) LoadAllDatabases(req LoadAllDatabasesRequest) ([]string, error) {
|
||||
return api.connService.LoadAllDatabases(
|
||||
req.Type, req.Host, req.Port,
|
||||
req.Username, req.Password, req.Database,
|
||||
req.Options, req.ID,
|
||||
)
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"u-desk/internal/service"
|
||||
"u-desk/internal/storage/models"
|
||||
"u-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
type SqlAPI struct {
|
||||
sqlService *service.SqlExecService
|
||||
resultRepo repository.ResultRepository
|
||||
}
|
||||
|
||||
func NewSqlAPI() (*SqlAPI, error) {
|
||||
sqlService, err := service.NewSqlExecService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resultRepo, err := repository.NewResultRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SqlAPI{sqlService, resultRepo}, nil
|
||||
}
|
||||
|
||||
// ExecuteSQL 执行SQL语句
|
||||
// 注意:SQL 语句应该已经包含分页信息(LIMIT 和 OFFSET),由客户端添加
|
||||
func (api *SqlAPI) ExecuteSQL(connectionID uint, sqlStr string, database string) (map[string]interface{}, error) {
|
||||
result, err := api.sqlService.ExecuteSQL(connectionID, sqlStr, database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"type": result.Type,
|
||||
"data": result.Data,
|
||||
"rowsAffected": result.RowsAffected,
|
||||
"executionTime": result.ExecutionTime,
|
||||
}
|
||||
// 如果是查询,添加列顺序信息
|
||||
if result.Type == "query" && len(result.Columns) > 0 {
|
||||
response["columns"] = result.Columns
|
||||
}
|
||||
|
||||
// 自动保存结果到历史记录(异步执行)
|
||||
go func() {
|
||||
api.resultRepo.Save(connectionID, database, sqlStr, result.Type, result.Data, result.Columns, result.RowsAffected, result.ExecutionTime)
|
||||
}()
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetDatabases(connectionID uint) ([]string, error) {
|
||||
return api.sqlService.GetDatabases(connectionID)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetTables(connectionID uint, database string) ([]string, error) {
|
||||
return api.sqlService.GetTables(connectionID, database)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetTableStructure(connectionID uint, database, tableName string) (map[string]interface{}, error) {
|
||||
return api.sqlService.GetTableStructure(connectionID, database, tableName)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetIndexes(connectionID uint, database, tableName string) ([]map[string]interface{}, error) {
|
||||
return api.sqlService.GetIndexes(connectionID, database, tableName)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) PreviewTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
return api.sqlService.PreviewTableStructure(connectionID, database, tableName, structure)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) UpdateTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
return api.sqlService.UpdateTableStructure(connectionID, database, tableName, structure)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) SaveResult(connectionID uint, database, sql string, resultType string, data interface{}, columns []string, rowsAffected int, executionTime int64) (map[string]interface{}, error) {
|
||||
history, err := api.resultRepo.Save(connectionID, database, sql, resultType, data, columns, rowsAffected, executionTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return historyToMap(history), nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetResultHistory(connectionID *uint, keyword string, limit, offset int) (map[string]interface{}, error) {
|
||||
histories, total, err := api.resultRepo.Search(connectionID, keyword, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
items := make([]map[string]interface{}, len(histories))
|
||||
for i, h := range histories {
|
||||
items[i] = historyToMap(&h)
|
||||
}
|
||||
|
||||
return map[string]interface{}{"items": items, "total": total}, nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetResultHistoryByID(id uint) (map[string]interface{}, error) {
|
||||
history, err := api.resultRepo.FindByID(id)
|
||||
if err != nil || history == nil {
|
||||
return nil, err
|
||||
}
|
||||
return historyToMap(history), nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) DeleteResultHistory(id uint) error {
|
||||
return api.resultRepo.Delete(id)
|
||||
}
|
||||
|
||||
func historyToMap(history *models.SqlResultHistory) map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
"id": history.ID,
|
||||
"connection_id": history.ConnectionID,
|
||||
"database": history.Database,
|
||||
"sql": history.Sql,
|
||||
"type": history.Type,
|
||||
"rows_affected": history.RowsAffected,
|
||||
"execution_time": history.ExecutionTime,
|
||||
"created_at": history.CreatedAt,
|
||||
}
|
||||
|
||||
if history.Data != "" {
|
||||
var data interface{}
|
||||
json.Unmarshal([]byte(history.Data), &data)
|
||||
result["data"] = data
|
||||
}
|
||||
|
||||
if history.Columns != "" {
|
||||
var columns []string
|
||||
json.Unmarshal([]byte(history.Columns), &columns)
|
||||
result["columns"] = columns
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"u-desk/internal/service"
|
||||
"u-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// TabAPI 标签页API
|
||||
type TabAPI struct {
|
||||
tabService *service.TabService
|
||||
}
|
||||
|
||||
// NewTabAPI 创建标签页API
|
||||
func NewTabAPI() (*TabAPI, error) {
|
||||
tabService, err := service.NewTabService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TabAPI{tabService: tabService}, nil
|
||||
}
|
||||
|
||||
// SaveSqlTabs 保存SQL标签页列表(接收 map 格式,转换为模型)
|
||||
func (api *TabAPI) SaveSqlTabs(tabs []map[string]interface{}) error {
|
||||
sqlTabs := make([]models.SqlTab, len(tabs))
|
||||
for idx, tabData := range tabs {
|
||||
tab := models.SqlTab{
|
||||
Order: idx,
|
||||
}
|
||||
|
||||
// 处理 ID
|
||||
if id, ok := tabData["id"].(float64); ok && id > 0 {
|
||||
tab.ID = uint(id)
|
||||
}
|
||||
|
||||
// 处理标题
|
||||
if title, ok := tabData["title"].(string); ok {
|
||||
tab.Title = title
|
||||
} else {
|
||||
tab.Title = fmt.Sprintf("查询 %d", idx+1)
|
||||
}
|
||||
|
||||
// 处理内容
|
||||
if content, ok := tabData["content"].(string); ok {
|
||||
tab.Content = content
|
||||
}
|
||||
|
||||
// 处理连接ID
|
||||
if connId, ok := tabData["connectionId"].(float64); ok && connId > 0 {
|
||||
connID := uint(connId)
|
||||
tab.ConnectionID = &connID
|
||||
}
|
||||
|
||||
sqlTabs[idx] = tab
|
||||
}
|
||||
return api.tabService.SaveTabs(sqlTabs)
|
||||
}
|
||||
|
||||
// ListSqlTabs 获取SQL标签页列表(返回 map 格式)
|
||||
func (api *TabAPI) ListSqlTabs() ([]map[string]interface{}, error) {
|
||||
tabs, err := api.tabService.ListTabs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, len(tabs))
|
||||
for i, tab := range tabs {
|
||||
result[i] = map[string]interface{}{
|
||||
"id": tab.ID,
|
||||
"title": tab.Title,
|
||||
"content": tab.Content,
|
||||
"connectionId": tab.ConnectionID,
|
||||
"order": tab.Order,
|
||||
"createdAt": tab.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
"updatedAt": tab.UpdatedAt.Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package common
|
||||
|
||||
// Default visible tabs configuration
|
||||
const (
|
||||
// TabDatabase 数据库管理 Tab
|
||||
TabDatabase = "db-cli"
|
||||
// TabFileSystem 文件系统 Tab
|
||||
TabFileSystem = "file-system"
|
||||
// TabDevice 设备测试 Tab
|
||||
@@ -11,4 +9,4 @@ const (
|
||||
)
|
||||
|
||||
// DefaultVisibleTabs 默认可见的 Tabs
|
||||
var DefaultVisibleTabs = []string{TabDatabase, TabFileSystem, TabDevice}
|
||||
var DefaultVisibleTabs = []string{TabFileSystem, TabDevice}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package common
|
||||
|
||||
import "time"
|
||||
|
||||
// 数据库操作超时配置
|
||||
const (
|
||||
TimeoutPing = 2 * time.Second // 连接测试超时
|
||||
TimeoutConnect = 5 * time.Second // 初始连接超时
|
||||
TimeoutFastQuery = 10 * time.Second // 元数据查询超时
|
||||
TimeoutQuery = 30 * time.Second // 普通查询超时
|
||||
TimeoutLongOp = 60 * time.Second // 长时间操作超时
|
||||
)
|
||||
@@ -1,175 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// 旧版硬编码密钥(用于兼容迁移已有加密数据)
|
||||
var legacyKey = []byte("go-desk-db-cli-key-32bytes123456")
|
||||
|
||||
var (
|
||||
encryptionKey []byte
|
||||
keyOnce sync.Once
|
||||
keyInitErr error
|
||||
)
|
||||
|
||||
// getKey 获取或创建机器唯一密钥
|
||||
// 首次启动时生成并持久化到用户配置目录,后续直接读取
|
||||
func getKey() ([]byte, error) {
|
||||
keyOnce.Do(func() {
|
||||
keyFile, err := getKeyFilePath()
|
||||
if err != nil {
|
||||
keyInitErr = fmt.Errorf("获取密钥路径失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试读取已有密钥
|
||||
if data, err := os.ReadFile(keyFile); err == nil && len(data) == 32 {
|
||||
encryptionKey = data
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新密钥
|
||||
newKey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, newKey); err != nil {
|
||||
keyInitErr = fmt.Errorf("生成密钥失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 持久化密钥
|
||||
dir := filepath.Dir(keyFile)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
keyInitErr = fmt.Errorf("创建密钥目录失败: %v", err)
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(keyFile, newKey, 0600); err != nil {
|
||||
keyInitErr = fmt.Errorf("保存密钥失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
encryptionKey = newKey
|
||||
})
|
||||
|
||||
return encryptionKey, keyInitErr
|
||||
}
|
||||
|
||||
// getKeyFilePath 返回密钥文件路径
|
||||
func getKeyFilePath() (string, error) {
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(configDir, "u-desk", ".aes-key"), nil
|
||||
}
|
||||
|
||||
// DecryptPasswordV2 使用指定密钥解密(用于密钥迁移)
|
||||
func DecryptPasswordV2(encryptedPassword string, key []byte) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(encryptedPassword) < 10 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解码失败: %v", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建解密器失败: %v", err)
|
||||
}
|
||||
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 GCM 失败: %v", err)
|
||||
}
|
||||
|
||||
nonceSize := aesGCM.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return "", fmt.Errorf("密文长度不足")
|
||||
}
|
||||
|
||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
|
||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解密失败: %v", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// EncryptPassword 加密密码
|
||||
func EncryptPassword(password string) (string, error) {
|
||||
if password == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
key, err := getKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取加密密钥失败: %v", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建加密器失败: %v", err)
|
||||
}
|
||||
|
||||
// 使用 GCM 模式
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 GCM 失败: %v", err)
|
||||
}
|
||||
|
||||
// 生成随机 nonce
|
||||
nonce := make([]byte, aesGCM.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("生成 nonce 失败: %v", err)
|
||||
}
|
||||
|
||||
// 加密
|
||||
ciphertext := aesGCM.Seal(nonce, nonce, []byte(password), nil)
|
||||
|
||||
// Base64 编码
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptPassword 解密密码(自动回退旧密钥兼容旧数据)
|
||||
func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(encryptedPassword) < 10 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
key, err := getKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取解密密钥失败: %v", err)
|
||||
}
|
||||
|
||||
// 先用新密钥尝试解密
|
||||
result, err := DecryptPasswordV2(encryptedPassword, key)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 新密钥失败,尝试旧密钥(兼容已迁移的旧数据)
|
||||
result, err = DecryptPasswordV2(encryptedPassword, legacyKey)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 两种密钥都失败
|
||||
return "", fmt.Errorf("解密失败: %v", err)
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"u-desk/internal/model"
|
||||
"time"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotConnected = errors.New("数据库未连接")
|
||||
)
|
||||
|
||||
// DB 数据库连接封装
|
||||
type DB struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
var globalDB *DB
|
||||
|
||||
// Init 初始化数据库连接
|
||||
func Init() (*DB, error) {
|
||||
if globalDB != nil {
|
||||
return globalDB, nil
|
||||
}
|
||||
|
||||
// 数据库配置 - 测试服 lab_dev
|
||||
// 测试机外网IP: 39.99.243.191
|
||||
// 使用 mysqldriver.Config 结构体构建 DSN,自动处理密码中的特殊字符
|
||||
config := mysqldriver.Config{
|
||||
User: "root",
|
||||
Passwd: "123456",
|
||||
Net: "tcp",
|
||||
Addr: "127.0.0.1:3306",
|
||||
DBName: "lab_dev",
|
||||
Params: map[string]string{"charset": "utf8mb4", "parseTime": "True", "loc": "Local"},
|
||||
AllowNativePasswords: true,
|
||||
}
|
||||
dsn := config.FormatDSN()
|
||||
|
||||
// GORM 配置
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
}
|
||||
|
||||
db, err := gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库连接失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取底层 sql.DB 设置连接池参数
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
sqlDB.SetMaxOpenConns(25)
|
||||
sqlDB.SetMaxIdleConns(5)
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(300) * time.Second)
|
||||
|
||||
globalDB = &DB{db: db}
|
||||
return globalDB, nil
|
||||
}
|
||||
|
||||
// QueryUsers 查询用户列表
|
||||
func (d *DB) QueryUsers(keyword string, status int, role int, organid int, page int, pageSize int, sortField string, sortOrder string) (map[string]interface{}, error) {
|
||||
if d.db == nil {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
|
||||
query := d.db.Model(&model.MemberInfo{})
|
||||
|
||||
// 关键字搜索(姓名、账号、电话)
|
||||
if keyword != "" {
|
||||
query = query.Where("membername LIKE ? OR account LIKE ? OR contactphone LIKE ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status > 0 {
|
||||
query = query.Where("status = ?", status)
|
||||
} else {
|
||||
// 默认过滤删除状态
|
||||
query = query.Where("status != ?", 3)
|
||||
}
|
||||
|
||||
// 角色筛选(需要关联查询,暂时简化)
|
||||
if role > 0 {
|
||||
// TODO: 关联 sys_member_role 表查询
|
||||
}
|
||||
|
||||
// 机构筛选
|
||||
if organid > 0 {
|
||||
query = query.Where("organid = ?", organid)
|
||||
}
|
||||
|
||||
// 排序
|
||||
if sortField != "" {
|
||||
if sortOrder == "descend" || sortOrder == "desc" {
|
||||
query = query.Order(sortField + " DESC")
|
||||
} else {
|
||||
query = query.Order(sortField + " ASC")
|
||||
}
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("createtime DESC")
|
||||
}
|
||||
|
||||
// 总数
|
||||
var total int64
|
||||
query.Count(&total)
|
||||
|
||||
// 分页
|
||||
offset := (page - 1) * pageSize
|
||||
var users []model.MemberInfo
|
||||
if err := query.Offset(offset).Limit(pageSize).Find(&users).Error; err != nil {
|
||||
return nil, fmt.Errorf("查询用户失败: %v", err)
|
||||
}
|
||||
|
||||
// 返回结果
|
||||
result := map[string]interface{}{
|
||||
"rows": users,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,479 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// QueryCache 查询缓存
|
||||
type QueryCache struct {
|
||||
items map[string]*CachedQuery
|
||||
size int
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// 智能缓存策略
|
||||
hitRate float64 // 缓存命中率
|
||||
hitCount int64 // 命中次数
|
||||
missCount int64 // 未命中次数
|
||||
evictionCount int64 // 驱逐次数
|
||||
hotQueries map[string]bool // 热点查询标记
|
||||
cooldowns map[string]time.Time // 冷却时间(避免频繁驱逐)
|
||||
|
||||
// 内存限制
|
||||
maxMemoryBytes int64 // 缓存最大内存(字节),默认 100MB
|
||||
usedMemory int64 // 当前估算内存使用量
|
||||
}
|
||||
|
||||
// NewQueryCache 创建新的查询缓存
|
||||
func NewQueryCache(size int, ttl time.Duration) *QueryCache {
|
||||
cache := &QueryCache{
|
||||
items: make(map[string]*CachedQuery),
|
||||
size: size,
|
||||
ttl: ttl,
|
||||
stopCh: make(chan struct{}),
|
||||
hitRate: 0.0,
|
||||
hitCount: 0,
|
||||
missCount: 0,
|
||||
evictionCount: 0,
|
||||
hotQueries: make(map[string]bool),
|
||||
cooldowns: make(map[string]time.Time),
|
||||
maxMemoryBytes: 100 * 1024 * 1024, // 默认 100MB
|
||||
}
|
||||
|
||||
// 启动清理协程
|
||||
cache.StartCleanup()
|
||||
|
||||
// 启动统计协程
|
||||
cache.StartStatsCollection()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// Get 从缓存中获取查询结果
|
||||
func (c *QueryCache) Get(params QueryParams) (*CachedQuery, error) {
|
||||
key := c.generateKey(params)
|
||||
|
||||
c.mu.RLock()
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
c.missCount++
|
||||
_, inCooldown := c.cooldowns[key]
|
||||
if inCooldown && time.Now().Before(c.cooldowns[key]) {
|
||||
c.mu.RUnlock()
|
||||
return nil, ErrCacheCooldown
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
return nil, ErrCacheNotFound
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(item.ExpiryTime) {
|
||||
if c.isHotQuery(key) {
|
||||
c.mu.RUnlock()
|
||||
c.mu.Lock()
|
||||
item.ExpiryTime = time.Now().Add(c.ttl)
|
||||
c.hitCount++
|
||||
c.markAsHot(key)
|
||||
c.mu.Unlock()
|
||||
return item, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.evictionCount++
|
||||
c.missCount++
|
||||
c.mu.Unlock()
|
||||
return nil, ErrCacheExpired
|
||||
}
|
||||
|
||||
// 命中
|
||||
c.hitCount++
|
||||
needsMark := !c.hotQueries[key]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if needsMark {
|
||||
c.mu.Lock()
|
||||
c.markAsHot(key)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// Set 将查询结果存入缓存
|
||||
func (c *QueryCache) Set(params QueryParams, item *CachedQuery) {
|
||||
key := c.generateKey(params)
|
||||
|
||||
// 估算条目内存大小
|
||||
itemSize := c.estimateSize(params, item)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// 更新统计
|
||||
c.recordQueryAttempt(key)
|
||||
|
||||
// 如果超过内存限制,执行驱逐直到有空间
|
||||
for c.usedMemory+itemSize > c.maxMemoryBytes && len(c.items) > 0 {
|
||||
c.smartEvict(key)
|
||||
}
|
||||
|
||||
// 如果条目数已满,执行智能驱逐
|
||||
if len(c.items) >= c.size {
|
||||
c.smartEvict(key)
|
||||
}
|
||||
|
||||
// 如果已有旧条目,先减去旧的大小
|
||||
if old, exists := c.items[key]; exists {
|
||||
c.usedMemory -= c.estimateItemSize(old)
|
||||
}
|
||||
|
||||
c.items[key] = item
|
||||
c.usedMemory += itemSize
|
||||
|
||||
// 标记为热点查询
|
||||
c.markAsHot(key)
|
||||
}
|
||||
|
||||
// smartEvict 智能驱逐策略
|
||||
func (c *QueryCache) smartEvict(newKey string) {
|
||||
if len(c.items) == 0 {
|
||||
return
|
||||
}
|
||||
// LRU + LFU 混合策略
|
||||
var evictKey string
|
||||
var worstScore float64 = -1
|
||||
|
||||
for key, item := range c.items {
|
||||
if key == newKey {
|
||||
continue
|
||||
}
|
||||
|
||||
score := c.calculateEvictionScore(key, item)
|
||||
if score > worstScore {
|
||||
worstScore = score
|
||||
evictKey = key
|
||||
}
|
||||
}
|
||||
|
||||
if evictKey != "" {
|
||||
if evicted, exists := c.items[evictKey]; exists {
|
||||
c.usedMemory -= c.estimateItemSize(evicted)
|
||||
}
|
||||
c.cooldowns[evictKey] = time.Now().Add(1 * time.Minute)
|
||||
delete(c.items, evictKey)
|
||||
c.evictionCount++
|
||||
}
|
||||
}
|
||||
|
||||
// calculateEvictionScore 计算驱逐分数(越低越适合保留)
|
||||
func (c *QueryCache) calculateEvictionScore(key string, item *CachedQuery) float64 {
|
||||
now := time.Now()
|
||||
|
||||
// 基础分数
|
||||
score := 1.0
|
||||
|
||||
// 热点查询加分(优先保留)
|
||||
if c.isHotQuery(key) {
|
||||
score -= 0.5
|
||||
}
|
||||
|
||||
// 接近过期的加分(优先驱逐即将过期的)
|
||||
if item.ExpiryTime.Sub(now) < c.ttl/2 {
|
||||
score += 0.3
|
||||
}
|
||||
|
||||
// 最近使用的加分(优先保留最近使用的)
|
||||
if !item.LastUsed.IsZero() {
|
||||
recency := now.Sub(item.LastUsed)
|
||||
if recency < 5*time.Minute {
|
||||
score -= 0.2
|
||||
}
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// isHotQuery 检查是否为热点查询
|
||||
func (c *QueryCache) isHotQuery(key string) bool {
|
||||
return c.hotQueries[key]
|
||||
}
|
||||
|
||||
// markAsHot 标记为热点查询
|
||||
func (c *QueryCache) markAsHot(key string) {
|
||||
c.hotQueries[key] = true
|
||||
}
|
||||
|
||||
// cleanupHotMarkers 清理热点标记
|
||||
func (c *QueryCache) cleanupHotMarkers() {
|
||||
now := time.Now()
|
||||
for key := range c.hotQueries {
|
||||
// 清理超过10分钟未使用的热点标记
|
||||
if item, exists := c.items[key]; exists {
|
||||
if now.Sub(item.LastUsed) > 10*time.Minute {
|
||||
delete(c.hotQueries, key)
|
||||
}
|
||||
} else {
|
||||
delete(c.hotQueries, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordQueryAttempt 记录查询尝试
|
||||
func (c *QueryCache) recordQueryAttempt(key string) {
|
||||
// 更新命中率
|
||||
c.updateHitRate()
|
||||
|
||||
// 更新最后使用时间
|
||||
if item, exists := c.items[key]; exists {
|
||||
item.LastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// updateHitRate 更新命中率
|
||||
func (c *QueryCache) updateHitRate() {
|
||||
total := c.hitCount + c.missCount
|
||||
if total > 0 {
|
||||
c.hitRate = float64(c.hitCount) / float64(total)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete 从缓存中删除指定查询
|
||||
func (c *QueryCache) Delete(params QueryParams) {
|
||||
key := c.generateKey(params)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if item, exists := c.items[key]; exists {
|
||||
c.usedMemory -= c.estimateItemSize(item)
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear 清空整个缓存
|
||||
func (c *QueryCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.items = make(map[string]*CachedQuery)
|
||||
c.usedMemory = 0
|
||||
}
|
||||
|
||||
// Size 获取缓存大小
|
||||
func (c *QueryCache) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.items)
|
||||
}
|
||||
|
||||
// CleanupExpired 清理过期的缓存条目
|
||||
func (c *QueryCache) CleanupExpired() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiryTime) {
|
||||
c.usedMemory -= c.estimateItemSize(item)
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keys 获取缓存中所有的键
|
||||
func (c *QueryCache) Keys() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
keys := make([]string, 0, len(c.items))
|
||||
for key := range c.items {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Stats 获取缓存统计信息
|
||||
func (c *QueryCache) Stats() CacheStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
expired := 0
|
||||
active := 0
|
||||
|
||||
for _, item := range c.items {
|
||||
if now.After(item.ExpiryTime) {
|
||||
expired++
|
||||
} else {
|
||||
active++
|
||||
}
|
||||
}
|
||||
|
||||
return CacheStats{
|
||||
TotalItems: len(c.items),
|
||||
ActiveItems: active,
|
||||
ExpiredItems: expired,
|
||||
Size: c.size,
|
||||
TTL: c.ttl,
|
||||
HitRate: c.hitRate,
|
||||
HitCount: c.hitCount,
|
||||
MissCount: c.missCount,
|
||||
EvictionCount: c.evictionCount,
|
||||
HotQueries: len(c.hotQueries),
|
||||
}
|
||||
}
|
||||
|
||||
// generateKey 生成缓存键
|
||||
func (c *QueryCache) generateKey(params QueryParams) string {
|
||||
key := fmt.Sprintf("%s|%s|%d|%d|%s|%s|%s|%v",
|
||||
params.SQL, params.Database, params.Limit, params.Offset,
|
||||
params.Table, params.Where, params.SortBy, params.IsReadOnly)
|
||||
h := sha256.Sum256([]byte(key))
|
||||
return fmt.Sprintf("%x", h)
|
||||
}
|
||||
|
||||
// evictOldest 删除最老的缓存条目
|
||||
func (c *QueryCache) evictOldest() {
|
||||
var oldestKey string
|
||||
var oldestTime time.Time
|
||||
|
||||
for key, item := range c.items {
|
||||
if oldestKey == "" || item.CreatedAt.Before(oldestTime) {
|
||||
oldestKey = key
|
||||
oldestTime = item.CreatedAt
|
||||
}
|
||||
}
|
||||
|
||||
if oldestKey != "" {
|
||||
delete(c.items, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
// StartCleanup 启动清理协程
|
||||
func (c *QueryCache) StartCleanup() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(c.ttl / 2) // 每 TTL/2 时间检查一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.CleanupExpired()
|
||||
c.cleanupCooldowns() // 清理冷却时间
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StartStatsCollection 启动统计收集协程
|
||||
func (c *QueryCache) StartStatsCollection() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute) // 每分钟收集一次统计
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.updateHitRate()
|
||||
c.cleanupHotMarkers()
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupCooldowns 清理冷却时间
|
||||
func (c *QueryCache) cleanupCooldowns() {
|
||||
now := time.Now()
|
||||
for key, cooldown := range c.cooldowns {
|
||||
if now.After(cooldown) {
|
||||
delete(c.cooldowns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止缓存清理
|
||||
func (c *QueryCache) Stop() {
|
||||
close(c.stopCh)
|
||||
c.wg.Wait()
|
||||
}
|
||||
|
||||
// CacheStats 缓存统计信息
|
||||
type CacheStats struct {
|
||||
TotalItems int
|
||||
ActiveItems int
|
||||
ExpiredItems int
|
||||
Size int
|
||||
TTL time.Duration
|
||||
HitRate float64
|
||||
HitCount int64
|
||||
MissCount int64
|
||||
EvictionCount int64
|
||||
HotQueries int
|
||||
}
|
||||
|
||||
// 缓存错误定义
|
||||
var (
|
||||
ErrCacheNotFound = &CacheError{Message: "缓存未找到"}
|
||||
ErrCacheExpired = &CacheError{Message: "缓存已过期"}
|
||||
ErrCacheCooldown = &CacheError{Message: "查询在冷却中"}
|
||||
)
|
||||
|
||||
// CacheError 缓存错误
|
||||
type CacheError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *CacheError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// estimateSize 估算缓存条目的内存大小(字节)
|
||||
func (c *QueryCache) estimateSize(params QueryParams, item *CachedQuery) int64 {
|
||||
size := int64(len(params.SQL) + len(params.Database) + len(params.Table) +
|
||||
len(params.Where) + len(params.SortBy))
|
||||
if item != nil && item.Result != nil {
|
||||
size += c.estimateItemSize(item)
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// estimateItemSize 估算 CachedQuery 的内存大小
|
||||
func (c *QueryCache) estimateItemSize(item *CachedQuery) int64 {
|
||||
if item == nil || item.Result == nil {
|
||||
return 128 // 基础结构体大小
|
||||
}
|
||||
size := int64(128) // CachedQuery 结构体基础大小
|
||||
for _, row := range item.Result.Data {
|
||||
for _, v := range row {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
size += int64(len(val))
|
||||
case []byte:
|
||||
size += int64(len(val))
|
||||
case nil:
|
||||
// 无额外开销
|
||||
default:
|
||||
size += 64 // 其他类型的估算值
|
||||
}
|
||||
}
|
||||
}
|
||||
size += int64(len(item.Result.Columns)) * 64 // 列名估算
|
||||
return size
|
||||
}
|
||||
@@ -1,825 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"u-desk/internal/common"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
// MongoClient MongoDB 客户端
|
||||
type MongoClient struct {
|
||||
client *mongo.Client
|
||||
database *mongo.Database
|
||||
config *MongoConfig
|
||||
}
|
||||
|
||||
// MongoConfig MongoDB 配置
|
||||
type MongoConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
AuthSource string // 认证数据库,默认为 "admin"
|
||||
AuthMechanism string // 认证机制,如 "SCRAM-SHA-1", "SCRAM-SHA-256" 等
|
||||
}
|
||||
|
||||
// NewMongoClient 创建 MongoDB 客户端
|
||||
func NewMongoClient(config *MongoConfig) (*MongoClient, error) {
|
||||
// 确定认证数据库,默认为 admin
|
||||
authSource := config.AuthSource
|
||||
if authSource == "" {
|
||||
authSource = "admin"
|
||||
}
|
||||
|
||||
// 如果指定了认证机制,直接使用;否则尝试自动检测
|
||||
authMechanisms := []string{}
|
||||
if config.AuthMechanism != "" {
|
||||
// 用户明确指定了认证机制,只使用该机制
|
||||
authMechanisms = []string{config.AuthMechanism}
|
||||
} else {
|
||||
// 未指定时,先尝试 SCRAM-SHA-256(更安全),失败则尝试 SCRAM-SHA-1
|
||||
authMechanisms = []string{"SCRAM-SHA-256", "SCRAM-SHA-1"}
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, authMechanism := range authMechanisms {
|
||||
client, err := tryConnectMongo(config, authSource, authMechanism)
|
||||
if err == nil {
|
||||
return client, nil
|
||||
}
|
||||
lastErr = err
|
||||
// 如果明确指定了认证机制,失败后不再尝试其他机制
|
||||
if config.AuthMechanism != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 所有认证机制都失败
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("MongoDB 连接测试失败: %v", lastErr)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("MongoDB 连接失败: 未知错误")
|
||||
}
|
||||
|
||||
// tryConnectMongo 尝试使用指定的认证机制连接 MongoDB
|
||||
func tryConnectMongo(config *MongoConfig, authSource, authMechanism string) (*MongoClient, error) {
|
||||
// 构建连接 URI
|
||||
var uri string
|
||||
|
||||
if config.Username != "" && config.Password != "" {
|
||||
// 使用 url.UserPassword 正确转义用户名和密码中的特殊字符
|
||||
// 这会正确处理 @、:、/ 等特殊字符
|
||||
userInfo := url.UserPassword(config.Username, config.Password)
|
||||
|
||||
// 构建基础 URI
|
||||
uri = fmt.Sprintf("mongodb://%s@%s:%d", userInfo.String(), config.Host, config.Port)
|
||||
|
||||
// 添加数据库和认证源参数
|
||||
params := url.Values{}
|
||||
params.Set("authSource", authSource)
|
||||
|
||||
// 添加认证机制参数
|
||||
if authMechanism != "" {
|
||||
params.Set("authMechanism", authMechanism)
|
||||
}
|
||||
|
||||
// 如果有业务数据库,添加到路径中
|
||||
if config.Database != "" {
|
||||
uri = fmt.Sprintf("%s/%s?%s", uri, config.Database, params.Encode())
|
||||
} else {
|
||||
// MongoDB URI 要求查询参数前必须有 /,即使没有数据库名
|
||||
uri = fmt.Sprintf("%s/?%s", uri, params.Encode())
|
||||
}
|
||||
} else if config.Database != "" {
|
||||
// 没有认证信息时,数据库部分用于指定默认数据库
|
||||
uri = fmt.Sprintf("mongodb://%s:%d/%s", config.Host, config.Port, config.Database)
|
||||
} else {
|
||||
uri = fmt.Sprintf("mongodb://%s:%d", config.Host, config.Port)
|
||||
}
|
||||
|
||||
// 客户端选项
|
||||
clientOptions := options.Client().
|
||||
ApplyURI(uri).
|
||||
SetConnectTimeout(common.TimeoutConnect).
|
||||
SetServerSelectionTimeout(common.TimeoutConnect)
|
||||
|
||||
// 创建客户端 (v2: 移除了 context 参数)
|
||||
client, err := mongo.Connect(clientOptions)
|
||||
|
||||
// 创建 context 用于其他操作
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutConnect)
|
||||
defer cancel()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 MongoDB 失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := client.Ping(ctx, nil); err != nil {
|
||||
client.Disconnect(ctx)
|
||||
return nil, fmt.Errorf("MongoDB 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
var database *mongo.Database
|
||||
if config.Database != "" {
|
||||
database = client.Database(config.Database)
|
||||
}
|
||||
|
||||
return &MongoClient{
|
||||
client: client,
|
||||
database: database,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestMongoConnection 测试连接
|
||||
func TestMongoConnection(host string, port int, username, password, database string) error {
|
||||
return TestMongoConnectionWithAuthSource(host, port, username, password, database, "")
|
||||
}
|
||||
|
||||
// TestMongoConnectionWithAuthSource 测试连接(支持指定认证数据库)
|
||||
func TestMongoConnectionWithAuthSource(host string, port int, username, password, database, authSource string) error {
|
||||
return TestMongoConnectionWithOptions(host, port, username, password, database, authSource, "")
|
||||
}
|
||||
|
||||
// TestMongoConnectionWithOptions 测试连接(支持指定认证数据库和认证机制)
|
||||
func TestMongoConnectionWithOptions(host string, port int, username, password, database, authSource, authMechanism string) error {
|
||||
config := &MongoConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
AuthSource: authSource,
|
||||
AuthMechanism: authMechanism,
|
||||
}
|
||||
client, err := NewMongoClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *MongoClient) Close() error {
|
||||
if c.client != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutConnect)
|
||||
defer cancel()
|
||||
return c.client.Disconnect(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表
|
||||
func (c *MongoClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
databases, err := c.client.ListDatabaseNames(ctx, bson.M{})
|
||||
return databases, err
|
||||
}
|
||||
|
||||
// ListCollections 获取集合列表
|
||||
func (c *MongoClient) ListCollections(ctx context.Context, database string) ([]string, error) {
|
||||
db := c.client.Database(database)
|
||||
collections, err := db.ListCollectionNames(ctx, bson.M{})
|
||||
return collections, err
|
||||
}
|
||||
|
||||
// GetCollectionStructure 获取集合结构
|
||||
func (c *MongoClient) GetCollectionStructure(ctx context.Context, database, collectionName string) (map[string]interface{}, error) {
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"database": database,
|
||||
"collection": collectionName,
|
||||
"sampleDocs": []map[string]interface{}{},
|
||||
"fieldStats": map[string]int{},
|
||||
"indexes": []map[string]interface{}{},
|
||||
"documentCount": int64(0),
|
||||
}
|
||||
|
||||
// 获取文档示例(最多 5 个)
|
||||
opts := options.Find().SetLimit(5)
|
||||
cursor, err := coll.Find(ctx, bson.M{}, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文档示例失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var docs []bson.M
|
||||
if err = cursor.All(ctx, &docs); err != nil {
|
||||
return nil, fmt.Errorf("解析文档失败: %v", err)
|
||||
}
|
||||
|
||||
// 转换为 map
|
||||
sampleDocs := make([]map[string]interface{}, 0, len(docs))
|
||||
for _, doc := range docs {
|
||||
docMap := make(map[string]interface{})
|
||||
for k, v := range doc {
|
||||
docMap[k] = v
|
||||
}
|
||||
sampleDocs = append(sampleDocs, docMap)
|
||||
}
|
||||
result["sampleDocs"] = sampleDocs
|
||||
|
||||
// 字段统计:使用 $sample 聚合管道随机采样10个文档进行统计
|
||||
// 这样可以获得更准确的字段分布,同时保持良好性能
|
||||
// 使用异步方式执行,避免阻塞主流程
|
||||
sampleSize := 10
|
||||
pipeline := []bson.M{
|
||||
{"$sample": bson.M{"size": sampleSize}},
|
||||
{"$project": bson.M{"keys": bson.M{"$objectToArray": "$$ROOT"}}},
|
||||
{"$unwind": "$keys"},
|
||||
{"$group": bson.M{
|
||||
"_id": "$keys.k",
|
||||
"count": bson.M{"$sum": 1},
|
||||
}},
|
||||
{"$sort": bson.M{"count": -1}}, // 按出现次数降序排序
|
||||
}
|
||||
|
||||
sampleCursor, err := coll.Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
// 如果采样失败,回退到基于文档示例的统计
|
||||
fieldCount := make(map[string]int)
|
||||
for _, doc := range docs {
|
||||
for key := range doc {
|
||||
fieldCount[key]++
|
||||
}
|
||||
}
|
||||
result["fieldStats"] = fieldCount
|
||||
result["fieldStatsSampleSize"] = len(docs) // 记录实际采样数量
|
||||
result["fieldStatsMethod"] = "sample-docs" // 标记统计方式
|
||||
} else {
|
||||
defer sampleCursor.Close(ctx)
|
||||
fieldCount := make(map[string]int)
|
||||
for sampleCursor.Next(ctx) {
|
||||
var statResult bson.M
|
||||
if err := sampleCursor.Decode(&statResult); err != nil {
|
||||
continue
|
||||
}
|
||||
fieldName, ok := statResult["_id"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var count int
|
||||
switch v := statResult["count"].(type) {
|
||||
case int32:
|
||||
count = int(v)
|
||||
case int64:
|
||||
count = int(v)
|
||||
case int:
|
||||
count = v
|
||||
case float64:
|
||||
count = int(v)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
fieldCount[fieldName] = count
|
||||
}
|
||||
result["fieldStats"] = fieldCount
|
||||
result["fieldStatsSampleSize"] = sampleSize // 记录采样数量
|
||||
result["fieldStatsMethod"] = "sample-aggregate" // 标记统计方式
|
||||
}
|
||||
|
||||
// 文档总数(使用估算值,性能更好)
|
||||
// 对于大数据集,estimatedDocumentCount 比 CountDocuments 快得多
|
||||
// 如果需要精确值,可以使用 CountDocuments,但性能较差
|
||||
count, err := coll.EstimatedDocumentCount(ctx)
|
||||
if err != nil {
|
||||
// 如果估算失败,尝试精确计数(可能较慢)
|
||||
count, err = coll.CountDocuments(ctx, bson.M{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文档数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
result["documentCount"] = count
|
||||
|
||||
// 索引信息
|
||||
indexCursor, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
// 索引查询失败不影响主流程
|
||||
result["indexes"] = []map[string]interface{}{}
|
||||
} else {
|
||||
var indexes []map[string]interface{}
|
||||
for indexCursor.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := indexCursor.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
indexes = append(indexes, map[string]interface{}{
|
||||
"name": indexSpec["name"],
|
||||
"unique": indexSpec["unique"],
|
||||
"keys": indexSpec["key"],
|
||||
})
|
||||
}
|
||||
indexCursor.Close(ctx)
|
||||
result["indexes"] = indexes
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteQuery 执行查询
|
||||
func (c *MongoClient) ExecuteQuery(ctx context.Context, database, collection string, filter bson.M, limit int64) ([]map[string]interface{}, error) {
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collection)
|
||||
|
||||
opts := options.Find().SetLimit(limit)
|
||||
cursor, err := coll.Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var results []map[string]interface{}
|
||||
if err := cursor.All(ctx, &results); err != nil {
|
||||
return nil, fmt.Errorf("读取结果失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// CountDocuments 获取文档数量
|
||||
func (c *MongoClient) CountDocuments(ctx context.Context, database, collection string, filter bson.M) (int64, error) {
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collection)
|
||||
return coll.CountDocuments(ctx, filter)
|
||||
}
|
||||
|
||||
// ExecuteCommand 执行 MongoDB 命令
|
||||
// command 可以是 JSON 格式的字符串,格式:{"op": "find", "database": "test", "collection": "users", "filter": {}, "limit": 100}
|
||||
// 支持的操作:find, count, insertOne, insertMany, updateOne, updateMany, deleteOne, deleteMany
|
||||
func (c *MongoClient) ExecuteCommand(ctx context.Context, database string, command map[string]interface{}) (interface{}, error) {
|
||||
op, ok := command["op"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("命令中缺少 'op' 字段或格式错误")
|
||||
}
|
||||
|
||||
collectionName, ok := command["collection"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("命令中缺少 'collection' 字段或格式错误")
|
||||
}
|
||||
|
||||
// 如果没有指定数据库,使用配置中的默认数据库
|
||||
if database == "" {
|
||||
if c.config != nil && c.config.Database != "" {
|
||||
database = c.config.Database
|
||||
} else {
|
||||
return nil, fmt.Errorf("需要指定数据库名称")
|
||||
}
|
||||
}
|
||||
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collectionName)
|
||||
|
||||
switch op {
|
||||
case "find":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
}
|
||||
|
||||
limit := int64(100)
|
||||
if l, ok := command["limit"]; ok {
|
||||
if limitVal, ok := l.(float64); ok {
|
||||
limit = int64(limitVal)
|
||||
} else if limitVal, ok := l.(int64); ok {
|
||||
limit = limitVal
|
||||
}
|
||||
}
|
||||
|
||||
opts := options.Find().SetLimit(limit)
|
||||
cursor, err := coll.Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var results []map[string]interface{}
|
||||
if err := cursor.All(ctx, &results); err != nil {
|
||||
return nil, fmt.Errorf("读取结果失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
|
||||
case "count":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
}
|
||||
|
||||
count, err := coll.CountDocuments(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("统计失败: %v", err)
|
||||
}
|
||||
return count, nil
|
||||
|
||||
case "insertOne":
|
||||
document, ok := command["document"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("insertOne 操作需要 'document' 字段")
|
||||
}
|
||||
|
||||
doc := bson.M{}
|
||||
if docMap, ok := document.(map[string]interface{}); ok {
|
||||
doc = bson.M(docMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("document 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.InsertOne(ctx, doc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("插入失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"insertedId": result.InsertedID,
|
||||
}, nil
|
||||
|
||||
case "insertMany":
|
||||
documents, ok := command["documents"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("insertMany 操作需要 'documents' 字段")
|
||||
}
|
||||
|
||||
docs := []interface{}{}
|
||||
if docsSlice, ok := documents.([]interface{}); ok {
|
||||
for _, d := range docsSlice {
|
||||
if docMap, ok := d.(map[string]interface{}); ok {
|
||||
docs = append(docs, bson.M(docMap))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("documents 必须是数组格式")
|
||||
}
|
||||
|
||||
result, err := coll.InsertMany(ctx, docs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量插入失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"insertedIds": result.InsertedIDs,
|
||||
"insertedCount": len(result.InsertedIDs),
|
||||
}, nil
|
||||
|
||||
case "updateOne":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("updateOne 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
update, ok := command["update"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("updateOne 操作需要 'update' 字段")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{}
|
||||
if updateMap, ok := update.(map[string]interface{}); ok {
|
||||
updateDoc = bson.M(updateMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("update 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.UpdateOne(ctx, filter, bson.M{"$set": updateDoc})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"matchedCount": result.MatchedCount,
|
||||
"modifiedCount": result.ModifiedCount,
|
||||
}, nil
|
||||
|
||||
case "updateMany":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("updateMany 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
update, ok := command["update"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("updateMany 操作需要 'update' 字段")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{}
|
||||
if updateMap, ok := update.(map[string]interface{}); ok {
|
||||
updateDoc = bson.M(updateMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("update 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.UpdateMany(ctx, filter, bson.M{"$set": updateDoc})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量更新失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"matchedCount": result.MatchedCount,
|
||||
"modifiedCount": result.ModifiedCount,
|
||||
}, nil
|
||||
|
||||
case "deleteOne":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("deleteOne 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
result, err := coll.DeleteOne(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("删除失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"deletedCount": result.DeletedCount,
|
||||
}, nil
|
||||
|
||||
case "deleteMany":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("deleteMany 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
result, err := coll.DeleteMany(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量删除失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"deletedCount": result.DeletedCount,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的操作: %s,支持的操作: find, count, insertOne, insertMany, updateOne, updateMany, deleteOne, deleteMany", op)
|
||||
}
|
||||
}
|
||||
|
||||
// PreviewCollectionIndexes 预览集合索引变更,只生成命令列表不执行
|
||||
func (c *MongoClient) PreviewCollectionIndexes(ctx context.Context, database, collectionName string, structure map[string]interface{}) ([]string, error) {
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
var commands []string
|
||||
|
||||
// 获取当前索引
|
||||
currentIndexes, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
defer currentIndexes.Close(ctx)
|
||||
|
||||
// 解析新的索引数据
|
||||
var newIndexes []map[string]interface{}
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建当前索引名映射
|
||||
currentIndexMap := make(map[string]bool)
|
||||
for currentIndexes.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := currentIndexes.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
if name, ok := indexSpec["name"].(string); ok && name != "_id_" {
|
||||
currentIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新索引名映射
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if name, ok := idx["name"].(string); ok && name != "" && name != "_id_" {
|
||||
newIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for name := range currentIndexMap {
|
||||
if !newIndexMap[name] {
|
||||
cmd := fmt.Sprintf("db.%s.dropIndex(\"%s\")", collectionName, name)
|
||||
commands = append(commands, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加或更新索引
|
||||
for _, idx := range newIndexes {
|
||||
name, _ := idx["name"].(string)
|
||||
if name == "" || name == "_id_" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引键
|
||||
keys := bson.D{}
|
||||
if keysData, ok := idx["keys"].(map[string]interface{}); ok {
|
||||
for k, v := range keysData {
|
||||
var order int
|
||||
if vFloat, ok := v.(float64); ok {
|
||||
order = int(vFloat)
|
||||
} else if vInt, ok := v.(int); ok {
|
||||
order = vInt
|
||||
} else {
|
||||
order = 1 // 默认升序
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k, Value: order})
|
||||
}
|
||||
} else if columnName, ok := idx["Column_name"].(string); ok && columnName != "" {
|
||||
// 兼容 MySQL 格式的索引数据
|
||||
keys = append(keys, bson.E{Key: columnName, Value: 1})
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引选项,并跟踪 unique 状态(v2: IndexOptionsBuilder 无 Unique 字段可读)
|
||||
indexOptions := options.Index()
|
||||
indexOptions.SetName(name)
|
||||
|
||||
isUnique := false
|
||||
if unique, ok := idx["unique"].(bool); ok && unique {
|
||||
indexOptions.SetUnique(true)
|
||||
isUnique = true
|
||||
} else if nonUnique, ok := idx["Non_unique"].(float64); ok && nonUnique == 0 {
|
||||
indexOptions.SetUnique(true)
|
||||
isUnique = true
|
||||
}
|
||||
|
||||
// 如果索引已存在,先删除再创建
|
||||
if currentIndexMap[name] {
|
||||
dropCmd := fmt.Sprintf("db.%s.dropIndex(\"%s\")", collectionName, name)
|
||||
commands = append(commands, dropCmd)
|
||||
}
|
||||
|
||||
// 构建命令字符串(MongoDB shell 格式)
|
||||
keysStr := "{"
|
||||
for i, key := range keys {
|
||||
if i > 0 {
|
||||
keysStr += ", "
|
||||
}
|
||||
keysStr += fmt.Sprintf("%s: %d", key.Key, key.Value)
|
||||
}
|
||||
keysStr += "}"
|
||||
|
||||
optionsStr := "{name: \"" + name + "\""
|
||||
if isUnique {
|
||||
optionsStr += ", unique: true"
|
||||
}
|
||||
optionsStr += "}"
|
||||
|
||||
cmd := fmt.Sprintf("db.%s.createIndex(%s, %s)", collectionName, keysStr, optionsStr)
|
||||
commands = append(commands, cmd)
|
||||
}
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
|
||||
// UpdateCollectionIndexes 更新集合索引,返回执行的命令列表
|
||||
func (c *MongoClient) UpdateCollectionIndexes(ctx context.Context, database, collectionName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 先预览生成命令列表
|
||||
commands, err := c.PreviewCollectionIndexes(ctx, database, collectionName, structure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
|
||||
// 获取当前索引
|
||||
currentIndexes, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
defer currentIndexes.Close(ctx)
|
||||
|
||||
// 解析新的索引数据
|
||||
var newIndexes []map[string]interface{}
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建当前索引名映射
|
||||
currentIndexMap := make(map[string]bool)
|
||||
for currentIndexes.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := currentIndexes.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
if name, ok := indexSpec["name"].(string); ok && name != "_id_" {
|
||||
currentIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新索引名映射
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if name, ok := idx["name"].(string); ok && name != "" && name != "_id_" {
|
||||
newIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for name := range currentIndexMap {
|
||||
if !newIndexMap[name] {
|
||||
// v2: DropOne 只返回 error,不再返回 bson.Raw
|
||||
err := coll.Indexes().DropOne(ctx, name)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("删除索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加或更新索引
|
||||
for _, idx := range newIndexes {
|
||||
name, _ := idx["name"].(string)
|
||||
if name == "" || name == "_id_" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引键
|
||||
keys := bson.D{}
|
||||
if keysData, ok := idx["keys"].(map[string]interface{}); ok {
|
||||
for k, v := range keysData {
|
||||
var order int
|
||||
if vFloat, ok := v.(float64); ok {
|
||||
order = int(vFloat)
|
||||
} else if vInt, ok := v.(int); ok {
|
||||
order = vInt
|
||||
} else {
|
||||
order = 1 // 默认升序
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k, Value: order})
|
||||
}
|
||||
} else if columnName, ok := idx["Column_name"].(string); ok && columnName != "" {
|
||||
// 兼容 MySQL 格式的索引数据
|
||||
keys = append(keys, bson.E{Key: columnName, Value: 1})
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引选项
|
||||
indexOptions := options.Index()
|
||||
indexOptions.SetName(name)
|
||||
|
||||
if unique, ok := idx["unique"].(bool); ok && unique {
|
||||
indexOptions.SetUnique(true)
|
||||
} else if nonUnique, ok := idx["Non_unique"].(float64); ok && nonUnique == 0 {
|
||||
indexOptions.SetUnique(true)
|
||||
}
|
||||
|
||||
// 创建索引
|
||||
indexModel := mongo.IndexModel{
|
||||
Keys: keys,
|
||||
Options: indexOptions,
|
||||
}
|
||||
|
||||
// 如果索引已存在,先删除再创建
|
||||
if currentIndexMap[name] {
|
||||
// v2: DropOne 只返回 error,不再返回 bson.Raw
|
||||
err := coll.Indexes().DropOne(ctx, name)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("删除旧索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := coll.Indexes().CreateOne(ctx, indexModel)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("创建索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
@@ -1,875 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// MySQLClient MySQL 客户端
|
||||
type MySQLClient struct {
|
||||
db *gorm.DB
|
||||
sqlDB *sql.DB
|
||||
config *MySQLConfig
|
||||
}
|
||||
|
||||
// MySQLConfig MySQL 配置
|
||||
type MySQLConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
}
|
||||
|
||||
// NewMySQLClient 创建 MySQL 客户端
|
||||
func NewMySQLClient(config *MySQLConfig) (*MySQLClient, error) {
|
||||
// 构建 DSN
|
||||
mysqlConfig := mysqldriver.Config{
|
||||
User: config.Username,
|
||||
Passwd: config.Password,
|
||||
Net: "tcp",
|
||||
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
|
||||
DBName: config.Database,
|
||||
Params: map[string]string{
|
||||
"charset": "utf8mb4",
|
||||
"parseTime": "True",
|
||||
"loc": "Local",
|
||||
"multiStatements": "true", // 支持多条SQL语句执行
|
||||
},
|
||||
AllowNativePasswords: true,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
dsn := mysqlConfig.FormatDSN()
|
||||
|
||||
// GORM 配置
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
// 打开连接
|
||||
db, err := gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 MySQL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取底层 sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("MySQL 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
sqlDB.SetMaxIdleConns(2)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
return &MySQLClient{
|
||||
db: db,
|
||||
sqlDB: sqlDB,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestConnection 测试连接
|
||||
func TestMySQLConnection(host string, port int, username, password, database string) error {
|
||||
config := &MySQLConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
}
|
||||
client, err := NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *MySQLClient) Close() error {
|
||||
if c.sqlDB != nil {
|
||||
return c.sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryResult 查询结果,包含数据和列顺序
|
||||
type QueryResult struct {
|
||||
Data []map[string]interface{}
|
||||
Columns []string
|
||||
}
|
||||
|
||||
// ExecuteQuery 执行查询 SQL
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
// 注意:SQL 语句应该已经包含 LIMIT 和 OFFSET(由客户端添加)
|
||||
func (c *MySQLClient) ExecuteQuery(ctx context.Context, sqlStr string, database string) (*QueryResult, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return nil, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.Raw(sqlStr).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("执行查询失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 检查 rows 错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("查询结果错误: %v", err)
|
||||
}
|
||||
|
||||
// 获取列名
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有列,返回空数组
|
||||
if len(columns) == 0 {
|
||||
return &QueryResult{
|
||||
Data: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
// 创建值数组和指针数组
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
// 扫描行数据
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 构建结果 map,按照列顺序构建
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
// 处理 nil 值
|
||||
if val == nil {
|
||||
row[col] = nil
|
||||
} else if b, ok := val.([]byte); ok {
|
||||
// 处理 []byte 类型
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
// 检查迭代过程中的错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("读取数据时发生错误: %v", err)
|
||||
}
|
||||
|
||||
return &QueryResult{
|
||||
Data: results,
|
||||
Columns: columns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecuteUpdate 执行更新 SQL(INSERT/UPDATE/DELETE)
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
func (c *MySQLClient) ExecuteUpdate(ctx context.Context, sqlStr string, database string) (int64, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return 0, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Exec(sqlStr)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("执行更新失败: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表
|
||||
func (c *MySQLClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
var databases []string
|
||||
err := c.db.Raw("SHOW DATABASES").Scan(&databases).Error
|
||||
return databases, err
|
||||
}
|
||||
|
||||
// ListTables 获取表列表
|
||||
func (c *MySQLClient) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
var tables []string
|
||||
query := "SHOW TABLES"
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW TABLES FROM `%s`", database)
|
||||
}
|
||||
err := c.db.Raw(query).Scan(&tables).Error
|
||||
return tables, err
|
||||
}
|
||||
|
||||
// GetTableStructure 获取表结构
|
||||
func (c *MySQLClient) GetTableStructure(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
// 使用 SHOW FULL COLUMNS 来获取包含 comment 的完整字段信息
|
||||
query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName)
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", database, tableName)
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取表结构失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else if val == nil {
|
||||
row[col] = nil
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 Comment 字段存在(SHOW FULL COLUMNS 返回的字段名是 Comment)
|
||||
if _, ok := row["Comment"]; !ok {
|
||||
row["Comment"] = ""
|
||||
}
|
||||
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetIndexes 获取索引列表
|
||||
func (c *MySQLClient) GetIndexes(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
query := "SHOW INDEX FROM "
|
||||
if database != "" {
|
||||
query += fmt.Sprintf("`%s`.", database)
|
||||
}
|
||||
query += fmt.Sprintf("`%s`", tableName)
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取索引列表失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// PreviewTableStructure 预览表结构变更,只生成 SQL 语句不执行
|
||||
func (c *MySQLClient) PreviewTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 获取当前表结构
|
||||
currentColumns, err := c.GetTableStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前表结构失败: %v", err)
|
||||
}
|
||||
|
||||
currentIndexes, err := c.GetIndexes(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析新的结构数据
|
||||
var newColumns []map[string]interface{}
|
||||
var newIndexes []map[string]interface{}
|
||||
|
||||
if cols, ok := structure["columns"].([]interface{}); ok {
|
||||
for _, col := range cols {
|
||||
if colMap, ok := col.(map[string]interface{}); ok {
|
||||
newColumns = append(newColumns, colMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ALTER TABLE 语句
|
||||
var alterStatements []string
|
||||
|
||||
// 处理字段变更
|
||||
alterStatements = append(alterStatements, c.buildColumnAlterStatements(tableName, currentColumns, newColumns)...)
|
||||
|
||||
// 处理索引变更
|
||||
alterStatements = append(alterStatements, c.buildIndexAlterStatements(tableName, currentIndexes, newIndexes)...)
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// UpdateTableStructure 更新表结构,返回生成的 SQL 语句列表
|
||||
func (c *MySQLClient) UpdateTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 先预览生成 SQL 语句
|
||||
alterStatements, err := c.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行所有 ALTER TABLE 语句
|
||||
if len(alterStatements) > 0 {
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, stmt := range alterStatements {
|
||||
if err := db.Exec(stmt).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("执行 ALTER TABLE 失败: %v, SQL: %s", err, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// buildColumnAlterStatements 构建字段变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildColumnAlterStatements(tableName string, currentColumns, newColumns []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建字段名映射和顺序映射
|
||||
currentFieldMap := make(map[string]map[string]interface{})
|
||||
currentFieldOrder := make([]string, 0, len(currentColumns))
|
||||
for _, col := range currentColumns {
|
||||
if field, ok := col["Field"].(string); ok {
|
||||
currentFieldMap[field] = col
|
||||
currentFieldOrder = append(currentFieldOrder, field)
|
||||
}
|
||||
}
|
||||
|
||||
newFieldMap := make(map[string]bool)
|
||||
newFieldOrder := make([]string, 0, len(newColumns))
|
||||
newColumnsMap := make(map[string]map[string]interface{})
|
||||
for _, col := range newColumns {
|
||||
if field, ok := col["Field"].(string); ok && field != "" {
|
||||
newFieldMap[field] = true
|
||||
newFieldOrder = append(newFieldOrder, field)
|
||||
newColumnsMap[field] = col
|
||||
}
|
||||
}
|
||||
|
||||
// 检测字段重命名:优先使用位置匹配,如果位置相同但字段名不同,认为是重命名
|
||||
renameMap := make(map[string]string) // oldName -> newName
|
||||
processedNewFields := make(map[string]bool)
|
||||
|
||||
// 第一步:使用位置匹配检测重命名(最可靠)
|
||||
for oldIndex, oldFieldName := range currentFieldOrder {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
|
||||
// 检查新字段列表中相同位置是否有字段
|
||||
if oldIndex < len(newFieldOrder) {
|
||||
newFieldName := newFieldOrder[oldIndex]
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent && !processedNewFields[newFieldName] {
|
||||
// 新字段不在当前字段列表中,且位置相同,很可能是重命名
|
||||
// 进一步验证:检查类型是否相同(类型相同更可能是重命名)
|
||||
oldCol := currentFieldMap[oldFieldName]
|
||||
newCol := newColumnsMap[newFieldName]
|
||||
oldType := getStringValue(oldCol["Type"])
|
||||
newType := getStringValue(newCol["Type"])
|
||||
|
||||
// 如果类型相同,认为是重命名
|
||||
if oldType == newType {
|
||||
renameMap[oldFieldName] = newFieldName
|
||||
processedNewFields[newFieldName] = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二步:对于未匹配的字段,使用属性匹配(兼容旧逻辑)
|
||||
for oldFieldName, oldCol := range currentFieldMap {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
if renameMap[oldFieldName] != "" {
|
||||
continue // 已经通过位置匹配识别为重命名
|
||||
}
|
||||
|
||||
// 查找属性完全匹配的新字段
|
||||
var matchedNewField string
|
||||
for newFieldName, newCol := range newColumnsMap {
|
||||
if processedNewFields[newFieldName] {
|
||||
continue // 已经被匹配过了
|
||||
}
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent {
|
||||
// 这是一个新增字段,检查属性是否匹配
|
||||
if c.isColumnPropertiesEqual(oldCol, newCol) {
|
||||
if matchedNewField == "" {
|
||||
matchedNewField = newFieldName
|
||||
} else {
|
||||
// 有多个匹配,无法确定,不认为是重命名
|
||||
matchedNewField = ""
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到唯一匹配,认为是重命名
|
||||
if matchedNewField != "" {
|
||||
renameMap[oldFieldName] = matchedNewField
|
||||
processedNewFields[matchedNewField] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理字段重命名
|
||||
for oldName, newName := range renameMap {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` RENAME COLUMN `%s` TO `%s`", tableName, oldName, newName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
// 处理字段添加、修改和位置调整(排除已重命名的字段)
|
||||
for i, newCol := range newColumns {
|
||||
field, _ := newCol["Field"].(string)
|
||||
if field == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否是重命名的字段
|
||||
isRenamed := false
|
||||
var oldName string
|
||||
for old, new := range renameMap {
|
||||
if new == field {
|
||||
isRenamed = true
|
||||
oldName = old
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isRenamed {
|
||||
// 重命名的字段:如果属性有变化,需要 MODIFY COLUMN
|
||||
oldCol := currentFieldMap[oldName]
|
||||
needsModify := c.isColumnChanged(oldCol, newCol)
|
||||
|
||||
// 检查顺序变化:使用旧字段名在 currentOrder 中查找位置,与新位置比较
|
||||
oldIndex := -1
|
||||
for idx, name := range currentFieldOrder {
|
||||
if name == oldName {
|
||||
oldIndex = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
needsReorder := (oldIndex != -1 && oldIndex != i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
// 重命名后需要修改属性或位置
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if currentCol, exists := currentFieldMap[field]; exists {
|
||||
// 修改现有字段
|
||||
needsModify := c.isColumnChanged(currentCol, newCol)
|
||||
needsReorder := c.isColumnOrderChanged(currentFieldOrder, newFieldOrder, field, i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新字段(排除重命名的字段)
|
||||
stmt := c.buildAddColumnStatement(tableName, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的字段(排除已重命名的字段)
|
||||
for field := range currentFieldMap {
|
||||
if !newFieldMap[field] && renameMap[field] == "" {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s`", tableName, field)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// buildIndexAlterStatements 构建索引变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildIndexAlterStatements(tableName string, currentIndexes, newIndexes []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建索引名映射
|
||||
currentIndexMap := make(map[string]map[string]interface{})
|
||||
for _, idx := range currentIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "PRIMARY" {
|
||||
currentIndexMap[keyName] = idx
|
||||
}
|
||||
}
|
||||
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "" && keyName != "PRIMARY" {
|
||||
newIndexMap[keyName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理索引变更
|
||||
for _, newIdx := range newIndexes {
|
||||
keyName, _ := newIdx["Key_name"].(string)
|
||||
if keyName == "" || keyName == "PRIMARY" {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentIdx, exists := currentIndexMap[keyName]; exists {
|
||||
// 修改现有索引
|
||||
if c.isIndexChanged(currentIdx, newIdx) {
|
||||
dropStmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
addStmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if addStmt != "" {
|
||||
statements = append(statements, dropStmt)
|
||||
statements = append(statements, addStmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新索引
|
||||
stmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for keyName := range currentIndexMap {
|
||||
if !newIndexMap[keyName] {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// isColumnChanged 检查字段是否发生变化(不包括字段名)
|
||||
func (c *MySQLClient) isColumnChanged(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isColumnPropertiesEqual 检查字段属性是否完全相等(不包括字段名)
|
||||
func (c *MySQLClient) isColumnPropertiesEqual(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Key", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isColumnOrderChanged 检查字段顺序是否发生变化
|
||||
func (c *MySQLClient) isColumnOrderChanged(currentOrder, newOrder []string, fieldName string, newIndex int) bool {
|
||||
// 查找字段在当前顺序中的位置
|
||||
currentIndex := -1
|
||||
for i, name := range currentOrder {
|
||||
if name == fieldName {
|
||||
currentIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果字段不存在于当前顺序中(新字段),不需要检查顺序
|
||||
if currentIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果索引相同,检查前面的字段是否相同
|
||||
if newIndex == currentIndex {
|
||||
// 检查前面的字段集合是否相同
|
||||
if newIndex > 0 {
|
||||
currentPrevFields := make(map[string]bool)
|
||||
for i := 0; i < currentIndex; i++ {
|
||||
currentPrevFields[currentOrder[i]] = true
|
||||
}
|
||||
|
||||
newPrevFields := make(map[string]bool)
|
||||
for i := 0; i < newIndex; i++ {
|
||||
newPrevFields[newOrder[i]] = true
|
||||
}
|
||||
|
||||
// 如果前面的字段集合不同,说明顺序变了
|
||||
if len(currentPrevFields) != len(newPrevFields) {
|
||||
return true
|
||||
}
|
||||
for f := range currentPrevFields {
|
||||
if !newPrevFields[f] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 索引不同,说明顺序变了
|
||||
return true
|
||||
}
|
||||
|
||||
// isIndexChanged 检查索引是否发生变化
|
||||
func (c *MySQLClient) isIndexChanged(oldIdx, newIdx map[string]interface{}) bool {
|
||||
oldCol := getStringValue(oldIdx["Column_name"])
|
||||
newCol := getStringValue(newIdx["Column_name"])
|
||||
if oldCol != newCol {
|
||||
return true
|
||||
}
|
||||
|
||||
oldUnique := getIntValue(oldIdx["Non_unique"])
|
||||
newUnique := getIntValue(newIdx["Non_unique"])
|
||||
return oldUnique != newUnique
|
||||
}
|
||||
|
||||
// buildAddColumnStatement 构建添加字段的语句
|
||||
func (c *MySQLClient) buildAddColumnStatement(tableName string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
field := getStringValue(col["Field"])
|
||||
if field == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildModifyColumnStatement 构建修改字段的语句
|
||||
func (c *MySQLClient) buildModifyColumnStatement(tableName, field string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` MODIFY COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildColumnPosition 构建字段位置子句(AFTER 或 FIRST)
|
||||
func (c *MySQLClient) buildColumnPosition(fieldOrder []string, index int) string {
|
||||
if index < 0 || index >= len(fieldOrder) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
// 第一个字段使用 FIRST
|
||||
return " FIRST"
|
||||
}
|
||||
|
||||
// 其他字段使用 AFTER 前一个字段
|
||||
prevField := fieldOrder[index-1]
|
||||
return fmt.Sprintf(" AFTER `%s`", prevField)
|
||||
}
|
||||
|
||||
// buildColumnDefinition 构建字段定义
|
||||
func (c *MySQLClient) buildColumnDefinition(col map[string]interface{}) string {
|
||||
field := getStringValue(col["Field"])
|
||||
colType := getStringValue(col["Type"])
|
||||
null := getStringValue(col["Null"])
|
||||
defaultVal := col["Default"]
|
||||
extra := getStringValue(col["Extra"])
|
||||
comment := getStringValue(col["Comment"])
|
||||
|
||||
def := fmt.Sprintf("`%s` %s", field, colType)
|
||||
|
||||
if null == "NO" {
|
||||
def += " NOT NULL"
|
||||
}
|
||||
|
||||
if defaultVal != nil {
|
||||
if defaultStr, ok := defaultVal.(string); ok {
|
||||
if defaultStr == "" {
|
||||
// 空字符串表示默认值为空字符串
|
||||
def += " DEFAULT ''"
|
||||
} else if defaultStr != "NULL" {
|
||||
// 转义单引号
|
||||
escapedDefault := strings.ReplaceAll(defaultStr, "'", "''")
|
||||
def += fmt.Sprintf(" DEFAULT '%s'", escapedDefault)
|
||||
}
|
||||
// 如果 defaultStr == "NULL",不添加 DEFAULT 子句(允许 NULL)
|
||||
} else {
|
||||
// 非字符串类型的默认值
|
||||
def += fmt.Sprintf(" DEFAULT %v", defaultVal)
|
||||
}
|
||||
}
|
||||
|
||||
if extra != "" {
|
||||
def += " " + extra
|
||||
}
|
||||
|
||||
if comment != "" {
|
||||
// 转义单引号
|
||||
escapedComment := strings.ReplaceAll(comment, "'", "''")
|
||||
def += fmt.Sprintf(" COMMENT '%s'", escapedComment)
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// buildAddIndexStatement 构建添加索引的语句
|
||||
func (c *MySQLClient) buildAddIndexStatement(tableName string, idx map[string]interface{}) string {
|
||||
keyName := getStringValue(idx["Key_name"])
|
||||
columnName := getStringValue(idx["Column_name"])
|
||||
nonUnique := getIntValue(idx["Non_unique"])
|
||||
|
||||
if keyName == "" || columnName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
indexType := "INDEX"
|
||||
if nonUnique == 0 {
|
||||
indexType = "UNIQUE INDEX"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD %s `%s` (`%s`)", tableName, indexType, keyName, columnName)
|
||||
}
|
||||
|
||||
// getStringValue 安全获取字符串值
|
||||
func getStringValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// getIntValue 安全获取整数值
|
||||
func getIntValue(v interface{}) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var i int
|
||||
fmt.Sscanf(val, "%d", &i)
|
||||
return i
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -1,393 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/common"
|
||||
"u-desk/internal/crypto"
|
||||
"u-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// ConnectionPool 连接池管理器
|
||||
type ConnectionPool struct {
|
||||
mysqlClients map[uint]*MySQLClient
|
||||
redisClients map[uint]*RedisClient
|
||||
mongoClients map[uint]*MongoClient
|
||||
|
||||
// 新增:MySQL 真连接池
|
||||
mysqlPool *MySQLConnectionPool
|
||||
|
||||
// 查询优化器
|
||||
queryOptimizer *QueryOptimizer
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
globalPool *ConnectionPool
|
||||
poolOnce sync.Once
|
||||
)
|
||||
|
||||
// GetPool 获取全局连接池实例
|
||||
func GetPool() *ConnectionPool {
|
||||
poolOnce.Do(func() {
|
||||
// 创建 MySQL 连接池
|
||||
poolConfig := DefaultPoolConfig()
|
||||
|
||||
mysqlPool := NewMySQLConnectionPool(poolConfig)
|
||||
// 启动维护协程
|
||||
mysqlPool.StartMaintenance()
|
||||
|
||||
// 创建查询优化器
|
||||
queryOptimizer := NewQueryOptimizer(nil)
|
||||
|
||||
globalPool = &ConnectionPool{
|
||||
mysqlClients: make(map[uint]*MySQLClient),
|
||||
redisClients: make(map[uint]*RedisClient),
|
||||
mongoClients: make(map[uint]*MongoClient),
|
||||
mysqlPool: mysqlPool,
|
||||
queryOptimizer: queryOptimizer,
|
||||
}
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// PooledClient 带释放语义的客户端包装
|
||||
type PooledClient struct {
|
||||
Client *MySQLClient
|
||||
entry *MySQLPoolEntry
|
||||
pool *MySQLConnectionPool
|
||||
fromPool bool
|
||||
}
|
||||
|
||||
// Release 释放连接回连接池
|
||||
func (pc *PooledClient) Release() {
|
||||
if pc.fromPool && pc.pool != nil && pc.entry != nil {
|
||||
pc.pool.Release(pc.entry)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMySQLClient 获取或创建 MySQL 客户端(使用连接池)
|
||||
func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) *PooledClient {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 尝试从连接池获取连接
|
||||
if p.mysqlPool != nil {
|
||||
entry, err := p.mysqlPool.Acquire(conn)
|
||||
if err == nil {
|
||||
return &PooledClient{Client: entry.Client, entry: entry, pool: p.mysqlPool, fromPool: true}
|
||||
}
|
||||
p.logPoolError("Acquire failed", err)
|
||||
}
|
||||
|
||||
// 降级到原有逻辑
|
||||
client, err := p.getMySQLClientLegacy(conn)
|
||||
if err != nil {
|
||||
return &PooledClient{Client: nil, fromPool: false}
|
||||
}
|
||||
return &PooledClient{Client: client, fromPool: false}
|
||||
}
|
||||
|
||||
// logPoolError 记录连接池错误
|
||||
func (p *ConnectionPool) logPoolError(operation string, err error) {
|
||||
if p.queryOptimizer != nil {
|
||||
// 通过查询优化器记录错误
|
||||
p.queryOptimizer.RecordPoolError(operation, err)
|
||||
}
|
||||
}
|
||||
|
||||
// getMySQLClientLegacy 原有的 MySQL 客户端获取逻辑(向后兼容)
|
||||
func (p *ConnectionPool) getMySQLClientLegacy(conn *models.DbConnection) (*MySQLClient, error) {
|
||||
// 检查是否已存在
|
||||
if client, ok := p.mysqlClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
if err := client.sqlDB.Ping(); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.mysqlClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &MySQLConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: password, // 如果密码为空,MySQL会尝试无密码连接
|
||||
Database: conn.Database,
|
||||
}
|
||||
|
||||
client, err := NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.mysqlClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetMySQLPoolStats 获取 MySQL 连接池统计信息
|
||||
func (p *ConnectionPool) GetMySQLPoolStats() *PoolStats {
|
||||
if p.mysqlPool != nil {
|
||||
stats := p.mysqlPool.Stats()
|
||||
return &stats
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OptimizeQuery 优化查询执行
|
||||
func (p *ConnectionPool) OptimizeQuery(ctx context.Context, conn *models.DbConnection, sqlStr string, database string) (*QueryResult, time.Duration, error) {
|
||||
pc := p.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, 0, fmt.Errorf("获取 MySQL 连接失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
// 使用查询优化器
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.OptimizeQuery(ctx, pc.Client, sqlStr, database)
|
||||
}
|
||||
|
||||
// 降级到普通查询
|
||||
startTime := time.Now()
|
||||
result, err := pc.Client.ExecuteQuery(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
return result, duration, err
|
||||
}
|
||||
|
||||
// ExecuteOptimizedUpdate 执行优化的更新操作
|
||||
func (p *ConnectionPool) ExecuteOptimizedUpdate(ctx context.Context, conn *models.DbConnection, sqlStr string, database string) (int64, time.Duration, error) {
|
||||
pc := p.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return 0, 0, fmt.Errorf("获取 MySQL 连接失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
// 使用查询优化器
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.ExecuteOptimizedUpdate(ctx, pc.Client, sqlStr, database)
|
||||
}
|
||||
|
||||
// 降级到普通更新
|
||||
startTime := time.Now()
|
||||
result, err := pc.Client.ExecuteUpdate(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
return result, duration, err
|
||||
}
|
||||
|
||||
// GetQueryStats 获取查询统计信息
|
||||
func (p *ConnectionPool) GetQueryStats() QueryStats {
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GetQueryStats()
|
||||
}
|
||||
return QueryStats{}
|
||||
}
|
||||
|
||||
// GetSlowQueries 获取慢查询记录
|
||||
func (p *ConnectionPool) GetSlowQueries(limit int) []SlowQuery {
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GetSlowQueries(limit)
|
||||
}
|
||||
return []SlowQuery{}
|
||||
}
|
||||
|
||||
// GetIndexSuggestions 获取索引建议
|
||||
func (p *ConnectionPool) GetIndexSuggestions(table string) []IndexSuggestion {
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GetIndexSuggestions(table)
|
||||
}
|
||||
return []IndexSuggestion{}
|
||||
}
|
||||
|
||||
// GenerateIndexSuggestions 为表生成索引建议
|
||||
func (p *ConnectionPool) GenerateIndexSuggestions(ctx context.Context, conn *models.DbConnection, database, table string) error {
|
||||
pc := p.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return fmt.Errorf("获取 MySQL 连接失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
// 使用查询优化器
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GenerateIndexSuggestions(ctx, pc.Client, database, table)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearQueryCache 清空查询缓存
|
||||
func (p *ConnectionPool) ClearQueryCache() {
|
||||
if p.queryOptimizer != nil {
|
||||
p.queryOptimizer.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRedisClient 获取或创建 Redis 客户端
|
||||
func (p *ConnectionPool) GetRedisClient(conn *models.DbConnection) (*RedisClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.redisClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutPing)
|
||||
defer cancel()
|
||||
if err := client.client.Ping(ctx).Err(); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.redisClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析 Redis DB 编号(从 Database 字段,默认为 0)
|
||||
dbNum := 0
|
||||
if conn.Database != "" {
|
||||
// 尝试解析 Database 字段为数字
|
||||
_, err := fmt.Sscanf(conn.Database, "%d", &dbNum)
|
||||
if err != nil {
|
||||
// 如果解析失败,使用默认值 0
|
||||
dbNum = 0
|
||||
}
|
||||
// 限制 DB 编号在 0-15 之间
|
||||
if dbNum < 0 || dbNum > 15 {
|
||||
dbNum = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &RedisConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Password: password,
|
||||
DB: dbNum,
|
||||
}
|
||||
|
||||
client, err := NewRedisClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.redisClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetMongoClient 获取或创建 MongoDB 客户端
|
||||
func (p *ConnectionPool) GetMongoClient(conn *models.DbConnection) (*MongoClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.mongoClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutPing)
|
||||
defer cancel()
|
||||
if err := client.client.Ping(ctx, nil); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.mongoClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if conn.Options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(conn.Options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &MongoConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: password,
|
||||
Database: conn.Database,
|
||||
AuthSource: authSource,
|
||||
AuthMechanism: authMechanism,
|
||||
}
|
||||
|
||||
client, err := NewMongoClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.mongoClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// CloseConnection 关闭指定连接
|
||||
func (p *ConnectionPool) CloseConnection(connID uint, dbType string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
if client, ok := p.mysqlClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.mysqlClients, connID)
|
||||
}
|
||||
case "redis":
|
||||
if client, ok := p.redisClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.redisClients, connID)
|
||||
}
|
||||
case "mongo":
|
||||
if client, ok := p.mongoClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.mongoClients, connID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloseAll 关闭所有连接
|
||||
func (p *ConnectionPool) CloseAll() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for _, client := range p.mysqlClients {
|
||||
client.Close()
|
||||
}
|
||||
for _, client := range p.redisClients {
|
||||
client.Close()
|
||||
}
|
||||
for _, client := range p.mongoClients {
|
||||
client.Close()
|
||||
}
|
||||
|
||||
p.mysqlClients = make(map[uint]*MySQLClient)
|
||||
p.redisClients = make(map[uint]*RedisClient)
|
||||
p.mongoClients = make(map[uint]*MongoClient)
|
||||
}
|
||||
@@ -1,679 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/crypto"
|
||||
"u-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// PoolConfig 连接池配置
|
||||
type PoolConfig struct {
|
||||
// 最大打开连接数(硬上限)
|
||||
MaxOpenConns int
|
||||
// 最大空闲连接数(超过此数量的空闲连接会被关闭)
|
||||
MaxIdleConns int
|
||||
// 连接最大生命周期(超过此时间的连接会被关闭)
|
||||
ConnMaxLifetime time.Duration
|
||||
// 连接最大空闲时间(超过此时间未使用的连接会被关闭)
|
||||
ConnMaxIdleTime time.Duration
|
||||
// 最小空闲连接数(保持此数量的空闲连接以快速响应)
|
||||
MinIdleConns int
|
||||
// 连接超时时间(建立连接的最长时间)
|
||||
ConnTimeout time.Duration
|
||||
// 健康检查间隔(定期 Ping 连接检查有效性)
|
||||
HealthCheckInterval time.Duration
|
||||
// 是否启用连接预热(启动时建立最小连接)
|
||||
EnableWarmup bool
|
||||
// 是否启用慢连接日志(记录建立时间超过阈值的连接)
|
||||
EnableSlowConnLog bool
|
||||
// 慢连接阈值(超过此时间记录为慢连接)
|
||||
SlowConnThreshold time.Duration
|
||||
// 连接池最大容量(防止资源耗尽)
|
||||
MaxPoolCapacity int
|
||||
|
||||
// 动态连接池配置
|
||||
EnableDynamicScaling bool // 是否启用动态连接池调整
|
||||
DynamicScaleFactor float64 // 动态调整因子(0.5-2.0)
|
||||
ScaleUpThreshold float64 // 扩容阈值(0-1.0,当使用率超过此值时扩容)
|
||||
ScaleDownThreshold float64 // 缩容阈值(0-1.0,当使用率低于此值时缩容)
|
||||
MinScaleUpInterval time.Duration // 最小扩容间隔(防止频繁调整)
|
||||
MinScaleDownInterval time.Duration // 最小缩容间隔
|
||||
MaxIdleTimeForScale time.Duration // 用于动态调整的最大空闲时间
|
||||
}
|
||||
|
||||
// DefaultPoolConfig 返回默认连接池配置
|
||||
func DefaultPoolConfig() *PoolConfig {
|
||||
return &PoolConfig{
|
||||
MaxOpenConns: 50, // 最大50个连接(提高并发)
|
||||
MaxIdleConns: 20, // 最大20个空闲(提高响应速度)
|
||||
ConnMaxLifetime: 60 * time.Minute, // 连接最长60分钟(延长连接生命周期)
|
||||
ConnMaxIdleTime: 15 * time.Minute, // 空闲15分钟关闭(更长的空闲时间)
|
||||
MinIdleConns: 5, // 保持5个最小空闲(更好的响应性能)
|
||||
ConnTimeout: 3 * time.Second, // 连接超时3秒(更快失败)
|
||||
HealthCheckInterval: 20 * time.Second, // 20秒健康检查一次(更频繁的健康检查)
|
||||
EnableWarmup: true, // 启用预热
|
||||
EnableSlowConnLog: true, // 启用慢连接日志
|
||||
SlowConnThreshold: 200 * time.Millisecond, // 超过200ms算慢连接(更严格的性能要求)
|
||||
MaxPoolCapacity: 100, // 连接池最大容量(支持更高并发)
|
||||
|
||||
// 动态连接池配置(更智能的调整策略)
|
||||
EnableDynamicScaling: true, // 启用动态调整
|
||||
DynamicScaleFactor: 1.8, // 调整因子1.8倍(更激进的扩容)
|
||||
ScaleUpThreshold: 0.7, // 使用率超过70%扩容(更早扩容)
|
||||
ScaleDownThreshold: 0.4, // 使用率低于40%缩容(避免频繁调整)
|
||||
MinScaleUpInterval: 1 * time.Minute, // 最小扩容间隔1分钟(更快的响应)
|
||||
MinScaleDownInterval: 3 * time.Minute, // 最小缩容间隔3分钟(稳定缩容)
|
||||
MaxIdleTimeForScale: 20 * time.Minute, // 用于调整的最大空闲时间
|
||||
}
|
||||
}
|
||||
|
||||
// MySQLPoolEntry MySQL 连接池条目
|
||||
type MySQLPoolEntry struct {
|
||||
Client *MySQLClient
|
||||
LastUsed time.Time
|
||||
CreatedAt time.Time
|
||||
InUse bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// AcquireResult 连接获取结果
|
||||
type AcquireResult struct {
|
||||
Entry *MySQLPoolEntry
|
||||
Err error
|
||||
}
|
||||
|
||||
// ReleaseResult 连接释放结果
|
||||
type ReleaseResult struct {
|
||||
Success bool
|
||||
Err error
|
||||
}
|
||||
|
||||
// Stats 连接池统计信息
|
||||
type PoolStats struct {
|
||||
TotalConns int // 总连接数
|
||||
ActiveConns int // 使用中的连接数
|
||||
IdleConns int // 空闲连接数
|
||||
WaitCount int64 // 等待连接的次数
|
||||
WaitDuration time.Duration // 总等待时间
|
||||
SlowConnCount int64 // 慢连接数量
|
||||
}
|
||||
|
||||
// MySQLConnectionPool MySQL 连接池(真正的连接池)
|
||||
type MySQLConnectionPool struct {
|
||||
config *PoolConfig
|
||||
configHash string // 配置哈希,用于检测配置变更
|
||||
mu sync.RWMutex
|
||||
entries []*MySQLPoolEntry // 连接池条目
|
||||
connMap map[uint]*MySQLClient // 连接ID -> 客户端映射(兼容现有代码)
|
||||
stats PoolStats
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// 动态调整相关
|
||||
lastScaleUpTime time.Time // 上次扩容时间
|
||||
lastScaleDownTime time.Time // 上次缩容时间
|
||||
currentTargetSize int // 当前目标连接数
|
||||
usageHistory []float64 // 使用率历史记录(用于智能调整)
|
||||
adaptiveWeights map[uint]float64 // 连接权重(基于性能表现)
|
||||
}
|
||||
|
||||
// NewMySQLConnectionPool 创建新的 MySQL 连接池
|
||||
func NewMySQLConnectionPool(config *PoolConfig) *MySQLConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultPoolConfig()
|
||||
}
|
||||
|
||||
pool := &MySQLConnectionPool{
|
||||
config: config,
|
||||
entries: make([]*MySQLPoolEntry, 0, config.MaxPoolCapacity),
|
||||
connMap: make(map[uint]*MySQLClient),
|
||||
stopCh: make(chan struct{}),
|
||||
currentTargetSize: config.MinIdleConns,
|
||||
usageHistory: make([]float64, 0, 100), // 保留最近100个使用率记录
|
||||
adaptiveWeights: make(map[uint]float64),
|
||||
}
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// Acquire 获取一个连接(阻塞等待直到有可用连接)
|
||||
func (p *MySQLConnectionPool) Acquire(conn *models.DbConnection) (*MySQLPoolEntry, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// 尝试获取最优连接(启用动态调整时)
|
||||
if p.config.EnableDynamicScaling {
|
||||
if entry, err := p.getOptimalConnection(); err == nil {
|
||||
p.updateWaitStats(startTime)
|
||||
return entry, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 降级到标准逻辑 - 查找空闲连接
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
entry.InUse = true
|
||||
entry.LastUsed = time.Now()
|
||||
entry.mu.Unlock()
|
||||
|
||||
// 更新统计
|
||||
p.updateWaitStats(startTime)
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
// 没有可用连接,创建新连接
|
||||
if len(p.entries) >= p.config.MaxOpenConns {
|
||||
// 已达到最大连接数,等待
|
||||
return p.waitForAvailableConnection(conn)
|
||||
}
|
||||
|
||||
// 创建新连接(使用传入的连接配置)
|
||||
newEntry, err := p.createNewEntry(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建连接失败: %v", err)
|
||||
}
|
||||
|
||||
p.entries = append(p.entries, newEntry)
|
||||
p.updateStats()
|
||||
p.updateWaitStats(startTime)
|
||||
|
||||
return newEntry, nil
|
||||
}
|
||||
|
||||
// Release 释放连接回池中
|
||||
func (p *MySQLConnectionPool) Release(entry *MySQLPoolEntry) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
entry.mu.Lock()
|
||||
entry.InUse = false
|
||||
entry.LastUsed = time.Now()
|
||||
entry.mu.Unlock()
|
||||
|
||||
p.updateStats()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接池
|
||||
func (p *MySQLConnectionPool) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 发送停止信号
|
||||
close(p.stopCh)
|
||||
|
||||
// 等待所有 goroutine 完成
|
||||
p.wg.Wait()
|
||||
|
||||
// 关闭所有连接
|
||||
var lastErr error
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if err := entry.Client.Close(); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
entry.InUse = false
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
p.entries = make([]*MySQLPoolEntry, 0, p.config.MaxPoolCapacity)
|
||||
p.connMap = make(map[uint]*MySQLClient)
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Stats 获取连接池统计信息
|
||||
func (p *MySQLConnectionPool) Stats() PoolStats {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.stats
|
||||
}
|
||||
|
||||
// cleanupIdleConnections 清理空闲连接
|
||||
func (p *MySQLConnectionPool) cleanupIdleConnections() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
keepEntries := make([]*MySQLPoolEntry, 0, len(p.entries))
|
||||
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
isIdle := !entry.InUse
|
||||
idleDuration := now.Sub(entry.LastUsed)
|
||||
entry.mu.Unlock()
|
||||
|
||||
// 保留条件:正在使用 或 空闲时间未超过阈值 或 数量少于最小空闲数
|
||||
keep := !isIdle ||
|
||||
idleDuration < p.config.ConnMaxIdleTime ||
|
||||
len(keepEntries) < p.config.MinIdleConns
|
||||
|
||||
if keep {
|
||||
keepEntries = append(keepEntries, entry)
|
||||
} else {
|
||||
// 关闭连接
|
||||
entry.Client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
p.entries = keepEntries
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// healthCheck 健康检查(增强版本)
|
||||
func (p *MySQLConnectionPool) healthCheck() {
|
||||
p.enhancedHealthCheck()
|
||||
}
|
||||
|
||||
// StartMaintenance 启动维护协程(清理和健康检查)
|
||||
func (p *MySQLConnectionPool) StartMaintenance() {
|
||||
p.wg.Add(1)
|
||||
go func() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// 健康检查Ticker
|
||||
healthTicker := time.NewTicker(p.config.HealthCheckInterval)
|
||||
defer healthTicker.Stop()
|
||||
|
||||
// 动态调整Ticker(较短间隔)
|
||||
scaleTicker := time.NewTicker(1 * time.Minute)
|
||||
defer scaleTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-healthTicker.C:
|
||||
// 清理空闲连接
|
||||
p.cleanupIdleConnections()
|
||||
// 健康检查
|
||||
p.healthCheck()
|
||||
|
||||
case <-scaleTicker.C:
|
||||
// 动态连接池调整
|
||||
if p.config.EnableDynamicScaling {
|
||||
p.adaptiveScaling()
|
||||
}
|
||||
|
||||
case <-p.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// createNewEntry 创建新的连接池条目
|
||||
func (p *MySQLConnectionPool) createNewEntry(conn *models.DbConnection) (*MySQLPoolEntry, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
client, err := createMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
|
||||
// 慢连接日志
|
||||
if p.config.EnableSlowConnLog && elapsed > p.config.SlowConnThreshold {
|
||||
// 记录慢连接
|
||||
p.mu.Lock()
|
||||
p.stats.SlowConnCount++
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
entry := &MySQLPoolEntry{
|
||||
Client: client,
|
||||
LastUsed: time.Now(),
|
||||
CreatedAt: startTime,
|
||||
InUse: true,
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// waitForAvailableConnection 等待可用连接并获取它
|
||||
func (p *MySQLConnectionPool) waitForAvailableConnection(conn *models.DbConnection) (*MySQLPoolEntry, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ErrPoolExhausted
|
||||
case <-ticker.C:
|
||||
p.mu.Lock()
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
entry.InUse = true
|
||||
entry.LastUsed = time.Now()
|
||||
entry.mu.Unlock()
|
||||
p.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateWaitStats 更新等待统计(调用方必须持有 p.mu)
|
||||
func (p *MySQLConnectionPool) updateWaitStats(startTime time.Time) {
|
||||
p.stats.WaitCount++
|
||||
p.stats.WaitDuration += time.Since(startTime)
|
||||
}
|
||||
|
||||
// updateStats 更新连接池统计
|
||||
func (p *MySQLConnectionPool) updateStats() {
|
||||
total := len(p.entries)
|
||||
active := 0
|
||||
idle := 0
|
||||
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if entry.InUse {
|
||||
active++
|
||||
} else {
|
||||
idle++
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
p.stats.TotalConns = total
|
||||
p.stats.ActiveConns = active
|
||||
p.stats.IdleConns = idle
|
||||
}
|
||||
|
||||
// adaptiveScaling 自适应连接池调整
|
||||
func (p *MySQLConnectionPool) adaptiveScaling() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 计算当前使用率
|
||||
if len(p.entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
usageRate := float64(p.stats.ActiveConns) / float64(len(p.entries))
|
||||
|
||||
// 记录使用率历史
|
||||
p.usageHistory = append(p.usageHistory, usageRate)
|
||||
if len(p.usageHistory) > 100 {
|
||||
p.usageHistory = p.usageHistory[1:]
|
||||
}
|
||||
|
||||
// 检查是否需要调整
|
||||
now := time.Now()
|
||||
|
||||
// 扩容逻辑
|
||||
if usageRate >= p.config.ScaleUpThreshold {
|
||||
if now.Sub(p.lastScaleUpTime) >= p.config.MinScaleUpInterval {
|
||||
p.scaleUp()
|
||||
p.lastScaleUpTime = now
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 缩容逻辑
|
||||
if usageRate <= p.config.ScaleDownThreshold && len(p.entries) > p.config.MinIdleConns {
|
||||
if now.Sub(p.lastScaleDownTime) >= p.config.MinScaleDownInterval {
|
||||
p.scaleDown()
|
||||
p.lastScaleDownTime = now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scaleUp 扩容
|
||||
func (p *MySQLConnectionPool) scaleUp() {
|
||||
// scaleUp 仅更新目标大小,实际连接在 Acquire 时按需创建
|
||||
// 移除了创建无效虚拟连接的逻辑
|
||||
currentSize := len(p.entries)
|
||||
scaleFactor := p.config.DynamicScaleFactor
|
||||
|
||||
newSize := int(float64(currentSize) * scaleFactor)
|
||||
newSize = min(newSize, p.config.MaxOpenConns)
|
||||
newSize = max(newSize, currentSize+1)
|
||||
|
||||
p.currentTargetSize = newSize
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// scaleDown 缩容
|
||||
func (p *MySQLConnectionPool) scaleDown() {
|
||||
// 计算新目标大小
|
||||
currentSize := len(p.entries)
|
||||
scaleFactor := 1.0 / p.config.DynamicScaleFactor
|
||||
|
||||
newSize := int(float64(currentSize) * scaleFactor)
|
||||
newSize = max(newSize, p.config.MinIdleConns)
|
||||
newSize = min(newSize, currentSize-1) // 至少减少1个连接
|
||||
|
||||
if newSize < currentSize {
|
||||
// 关闭多余的空闲连接
|
||||
p.closeIdleConnections(currentSize - newSize)
|
||||
p.currentTargetSize = newSize
|
||||
p.updateStats()
|
||||
}
|
||||
}
|
||||
|
||||
// closeIdleConnections 关闭指定数量的空闲连接
|
||||
func (p *MySQLConnectionPool) closeIdleConnections(count int) {
|
||||
// 收集空闲连接
|
||||
idleEntries := make([]*MySQLPoolEntry, 0)
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
idleEntries = append(idleEntries, entry)
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
// 关闭指定数量的空闲连接
|
||||
closedEntries := make(map[*MySQLPoolEntry]bool)
|
||||
for i := 0; i < min(count, len(idleEntries)); i++ {
|
||||
entry := idleEntries[i]
|
||||
entry.mu.Lock()
|
||||
entry.Client.Close()
|
||||
entry.mu.Unlock()
|
||||
closedEntries[entry] = true
|
||||
}
|
||||
|
||||
// 重新构建连接池
|
||||
remainingEntries := make([]*MySQLPoolEntry, 0, len(p.entries))
|
||||
for _, entry := range p.entries {
|
||||
if closedEntries[entry] {
|
||||
continue // 跳过已关闭的连接
|
||||
}
|
||||
remainingEntries = append(remainingEntries, entry)
|
||||
}
|
||||
|
||||
p.entries = remainingEntries
|
||||
}
|
||||
|
||||
// enhancedHealthCheck 增强的健康检查
|
||||
func (p *MySQLConnectionPool) enhancedHealthCheck() {
|
||||
p.mu.RLock()
|
||||
entriesCopy := make([]*MySQLPoolEntry, len(p.entries))
|
||||
copy(entriesCopy, p.entries)
|
||||
p.mu.RUnlock()
|
||||
|
||||
var healthyEntries []*MySQLPoolEntry
|
||||
var performanceWeights []float64
|
||||
|
||||
for _, entry := range entriesCopy {
|
||||
entry.mu.Lock()
|
||||
isIdle := !entry.InUse
|
||||
|
||||
// 测试连接有效性
|
||||
isHealthy := true
|
||||
startTime := time.Now()
|
||||
|
||||
if isIdle {
|
||||
// 空闲连接:简单Ping测试
|
||||
if err := entry.Client.sqlDB.Ping(); err != nil {
|
||||
isHealthy = false
|
||||
// 关闭失效连接
|
||||
entry.Client.Close()
|
||||
}
|
||||
} else {
|
||||
// 使用中的连接:快速测试(避免影响正常查询)
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := entry.Client.sqlDB.PingContext(ctx); err != nil {
|
||||
isHealthy = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 计算连接性能权重
|
||||
if isHealthy {
|
||||
healthyEntries = append(healthyEntries, entry)
|
||||
|
||||
// 基于连接性能计算权重
|
||||
responseTime := time.Since(startTime).Microseconds()
|
||||
weight := 1.0 / max(float64(responseTime)/1000.0, 1.0) // 转换为毫秒,避免除零
|
||||
|
||||
performanceWeights = append(performanceWeights, weight)
|
||||
} else {
|
||||
// 不健康的连接
|
||||
if isIdle {
|
||||
entry.Client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
// 更新连接池
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.entries = healthyEntries
|
||||
|
||||
// 更新自适应权重
|
||||
if len(healthyEntries) > 0 {
|
||||
for i := range healthyEntries {
|
||||
if i < len(performanceWeights) {
|
||||
p.adaptiveWeights[uint(i)] = performanceWeights[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// warmUp 连接池预热
|
||||
func (p *MySQLConnectionPool) warmUp() {
|
||||
if !p.config.EnableWarmup {
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
currentIdle := 0
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
currentIdle++
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
targetIdle := p.config.MinIdleConns
|
||||
needed := targetIdle - currentIdle
|
||||
|
||||
// warmUp 仅记录目标大小,不在无连接配置的情况下创建无效虚拟连接
|
||||
// 实际连接在 Acquire 时按需创建
|
||||
_ = needed
|
||||
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// getOptimalConnection 获取最优连接(基于性能权重)
|
||||
// 注意:调用方必须已持有 p.mu
|
||||
func (p *MySQLConnectionPool) getOptimalConnection() (*MySQLPoolEntry, error) {
|
||||
var bestEntry *MySQLPoolEntry
|
||||
var bestWeight float64
|
||||
|
||||
for i, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
weight := 1.0 // 默认权重
|
||||
if w, ok := p.adaptiveWeights[uint(i)]; ok {
|
||||
weight = w
|
||||
}
|
||||
|
||||
if bestEntry == nil || weight > bestWeight {
|
||||
bestEntry = entry
|
||||
bestWeight = weight
|
||||
}
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
if bestEntry == nil {
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
bestEntry.InUse = true
|
||||
bestEntry.LastUsed = time.Now()
|
||||
return bestEntry, nil
|
||||
}
|
||||
|
||||
// createMySQLClient 创建 MySQL 客户端的辅助函数
|
||||
func createMySQLClient(conn *models.DbConnection) (*MySQLClient, error) {
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
config := &MySQLConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: password,
|
||||
Database: conn.Database,
|
||||
}
|
||||
|
||||
return NewMySQLClient(config)
|
||||
}
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrPoolExhausted = &PoolError{Message: "连接池已耗尽"}
|
||||
ErrPoolClosed = &PoolError{Message: "连接池已关闭"}
|
||||
)
|
||||
|
||||
// PoolError 连接池错误
|
||||
type PoolError struct {
|
||||
Message string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *PoolError) Error() string {
|
||||
if e.Err != nil {
|
||||
return e.Message + ": " + e.Err.Error()
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
@@ -1,762 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
reLimitOffset = regexp.MustCompile(`limit\s+(\d+)(?:\s*,\s*(\d+))?`)
|
||||
reFromTable = regexp.MustCompile(`(?i)from\s+([^\s,]+)`)
|
||||
reWhereClause = regexp.MustCompile(`(?i)where\s+(.*?)(?:\s+order\s+by|\s+limit|\s+group\s+by|$)`)
|
||||
reOrderBy = regexp.MustCompile(`(?i)order\s+by\s+(.*?)(?:\s+limit|$)`)
|
||||
reBatchOperation = regexp.MustCompile(`(?i)^\s*(INSERT|UPDATE|DELETE).*VALUES\s*\(`)
|
||||
)
|
||||
|
||||
// CachedQuery 缓存查询结果
|
||||
type CachedQuery struct {
|
||||
Result *QueryResult
|
||||
ExpiryTime time.Time
|
||||
CreatedAt time.Time
|
||||
QueryHash string
|
||||
QueryParams QueryParams
|
||||
LastUsed time.Time // 最后使用时间(用于LRU策略)
|
||||
AccessCount int64 // 访问次数(用于LFU策略)
|
||||
}
|
||||
|
||||
// QueryParams 查询参数(用于缓存键生成)
|
||||
type QueryParams struct {
|
||||
SQL string
|
||||
Database string
|
||||
Limit int
|
||||
Offset int
|
||||
Table string
|
||||
Where string
|
||||
SortBy string
|
||||
IsReadOnly bool
|
||||
}
|
||||
|
||||
// QueryStats 查询统计信息
|
||||
type QueryStats struct {
|
||||
TotalQueries int64
|
||||
CachedQueries int64
|
||||
SlowQueries int64
|
||||
TotalDuration time.Duration
|
||||
AverageDuration time.Duration
|
||||
CacheHitRate float64
|
||||
LastCacheUpdate time.Time
|
||||
}
|
||||
|
||||
// SlowQuery 慢查询记录
|
||||
type SlowQuery struct {
|
||||
Query string
|
||||
Database string
|
||||
Duration time.Duration
|
||||
Timestamp time.Time
|
||||
Params QueryParams
|
||||
Table string
|
||||
IndexUsed string
|
||||
RowsAffected int64
|
||||
Error error
|
||||
}
|
||||
|
||||
// IndexSuggestion 索引建议
|
||||
type IndexSuggestion struct {
|
||||
Table string
|
||||
Columns []string
|
||||
IndexType string // "normal", "unique", "fulltext"
|
||||
Priority string // "high", "medium", "low"
|
||||
Query string
|
||||
Justification string
|
||||
CanBeApplied bool
|
||||
}
|
||||
|
||||
// QueryOptimizer 查询优化器
|
||||
type QueryOptimizer struct {
|
||||
cache *QueryCache
|
||||
stats *QueryStats
|
||||
slowQueries []SlowQuery
|
||||
indexSuggestions []IndexSuggestion
|
||||
mu sync.RWMutex
|
||||
config *OptimizerConfig
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// OptimizerConfig 查询优化器配置
|
||||
type OptimizerConfig struct {
|
||||
// 缓存配置
|
||||
CacheSize int // 最大缓存条目数
|
||||
CacheTTL time.Duration // 缓存过期时间
|
||||
EnableCache bool // 是否启用缓存
|
||||
|
||||
// 慢查询配置
|
||||
SlowQueryThreshold time.Duration // 慢查询阈值
|
||||
EnableSlowLog bool // 是否启用慢查询日志
|
||||
MaxSlowLogs int // 最大慢查询记录数
|
||||
|
||||
// 索引建议配置
|
||||
EnableIndexSuggestions bool // 是否启用索引建议
|
||||
MaxSuggestions int // 最大索引建议数
|
||||
|
||||
// 查询分析配置
|
||||
EnableQueryAnalysis bool // 是否启用查询分析
|
||||
MaxAnalysisDepth int // 查询分析深度
|
||||
}
|
||||
|
||||
// DefaultOptimizerConfig 返回默认的查询优化器配置
|
||||
func DefaultOptimizerConfig() *OptimizerConfig {
|
||||
return &OptimizerConfig{
|
||||
CacheSize: 1000, // 最多缓存1000个查询
|
||||
CacheTTL: 30 * time.Minute, // 缓存30分钟
|
||||
EnableCache: true, // 启用缓存
|
||||
SlowQueryThreshold: 100 * time.Millisecond, // 100ms以上为慢查询
|
||||
EnableSlowLog: true, // 启用慢查询日志
|
||||
MaxSlowLogs: 1000, // 最多记录1000条慢查询
|
||||
EnableIndexSuggestions: true, // 启用索引建议
|
||||
MaxSuggestions: 100, // 最多100个索引建议
|
||||
EnableQueryAnalysis: true, // 启用查询分析
|
||||
MaxAnalysisDepth: 3, // 分析深度3
|
||||
}
|
||||
}
|
||||
|
||||
// NewQueryOptimizer 创建新的查询优化器
|
||||
func NewQueryOptimizer(config *OptimizerConfig) *QueryOptimizer {
|
||||
if config == nil {
|
||||
config = DefaultOptimizerConfig()
|
||||
}
|
||||
|
||||
optimizer := &QueryOptimizer{
|
||||
cache: NewQueryCache(config.CacheSize, config.CacheTTL),
|
||||
stats: &QueryStats{},
|
||||
config: config,
|
||||
stopCh: make(chan struct{}),
|
||||
slowQueries: make([]SlowQuery, 0),
|
||||
indexSuggestions: make([]IndexSuggestion, 0),
|
||||
}
|
||||
|
||||
// 启动维护协程
|
||||
optimizer.StartMaintenance()
|
||||
|
||||
return optimizer
|
||||
}
|
||||
|
||||
// OptimizeQuery 优化查询执行
|
||||
func (o *QueryOptimizer) OptimizeQuery(ctx context.Context, client *MySQLClient, sqlStr string, database string) (*QueryResult, time.Duration, error) {
|
||||
startTime := time.Now()
|
||||
queryParams := o.parseQueryParams(sqlStr, database)
|
||||
|
||||
// 检查缓存
|
||||
if o.config.EnableCache && queryParams.IsReadOnly {
|
||||
cached, err := o.cache.Get(queryParams)
|
||||
if err == nil && cached != nil {
|
||||
o.recordCacheHit()
|
||||
return cached.Result, time.Since(startTime), nil
|
||||
}
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
result, err := client.ExecuteQuery(ctx, sqlStr, database)
|
||||
if err != nil {
|
||||
duration := time.Since(startTime)
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, result, err)
|
||||
return nil, duration, err
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// 检查是否为慢查询
|
||||
if duration > o.config.SlowQueryThreshold {
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, result, err)
|
||||
}
|
||||
|
||||
// 缓存只读查询结果
|
||||
if o.config.EnableCache && queryParams.IsReadOnly && err == nil {
|
||||
cachedResult := &CachedQuery{
|
||||
Result: result,
|
||||
ExpiryTime: time.Now().Add(o.config.CacheTTL),
|
||||
CreatedAt: time.Now(),
|
||||
QueryHash: o.generateQueryHash(queryParams),
|
||||
QueryParams: queryParams,
|
||||
LastUsed: time.Now(),
|
||||
AccessCount: 1,
|
||||
}
|
||||
o.cache.Set(queryParams, cachedResult)
|
||||
}
|
||||
|
||||
o.recordQuery(duration)
|
||||
return result, duration, err
|
||||
}
|
||||
|
||||
// ExecuteOptimizedUpdate 执行优化的更新操作
|
||||
func (o *QueryOptimizer) ExecuteOptimizedUpdate(ctx context.Context, client *MySQLClient, sqlStr string, database string) (int64, time.Duration, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 分析更新查询
|
||||
queryParams := o.parseQueryParams(sqlStr, database)
|
||||
|
||||
// 检查是否为批量操作
|
||||
if o.isBatchOperation(sqlStr) {
|
||||
// 优化批量操作
|
||||
rowsAffected, duration, err := o.optimizeBatchUpdate(ctx, client, sqlStr, database)
|
||||
if err != nil {
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, nil, err)
|
||||
return 0, duration, err
|
||||
}
|
||||
|
||||
o.recordQuery(duration)
|
||||
return rowsAffected, duration, nil
|
||||
}
|
||||
|
||||
// 执行普通更新
|
||||
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
|
||||
if duration > o.config.SlowQueryThreshold {
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, nil, err)
|
||||
}
|
||||
|
||||
o.recordQuery(duration)
|
||||
return rowsAffected, duration, err
|
||||
}
|
||||
|
||||
// GetIndexSuggestions 获取索引建议
|
||||
func (o *QueryOptimizer) GetIndexSuggestions(table string) []IndexSuggestion {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
var suggestions []IndexSuggestion
|
||||
for _, suggestion := range o.indexSuggestions {
|
||||
if suggestion.Table == table || table == "" {
|
||||
suggestions = append(suggestions, suggestion)
|
||||
}
|
||||
}
|
||||
return suggestions
|
||||
}
|
||||
|
||||
// GenerateIndexSuggestions 为表生成索引建议
|
||||
func (o *QueryOptimizer) GenerateIndexSuggestions(ctx context.Context, client *MySQLClient, database, table string) error {
|
||||
// 获取表的慢查询记录
|
||||
tableSlowQueries := o.getTableSlowQueries(database, table)
|
||||
|
||||
// 分析查询模式
|
||||
for _, slowQuery := range tableSlowQueries {
|
||||
suggestions := o.analyzeQueryForIndexes(slowQuery.Query, table)
|
||||
o.mu.Lock()
|
||||
o.indexSuggestions = append(o.indexSuggestions, suggestions...)
|
||||
|
||||
// 限制建议数量
|
||||
if len(o.indexSuggestions) > o.config.MaxSuggestions {
|
||||
o.indexSuggestions = o.indexSuggestions[:o.config.MaxSuggestions]
|
||||
}
|
||||
o.mu.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetQueryStats 获取查询统计信息
|
||||
func (o *QueryOptimizer) GetQueryStats() QueryStats {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
return *o.stats
|
||||
}
|
||||
|
||||
// GetSlowQueries 获取慢查询记录
|
||||
func (o *QueryOptimizer) GetSlowQueries(limit int) []SlowQuery {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
if limit <= 0 || limit > len(o.slowQueries) {
|
||||
limit = len(o.slowQueries)
|
||||
}
|
||||
|
||||
return o.slowQueries[:limit]
|
||||
}
|
||||
|
||||
// ClearCache 清空缓存
|
||||
func (o *QueryOptimizer) ClearCache() {
|
||||
o.cache.Clear()
|
||||
}
|
||||
|
||||
// Stop 停止优化器
|
||||
func (o *QueryOptimizer) Stop() {
|
||||
close(o.stopCh)
|
||||
o.wg.Wait()
|
||||
}
|
||||
|
||||
// parseQueryParams 解析查询参数
|
||||
func (o *QueryOptimizer) parseQueryParams(sqlStr, database string) QueryParams {
|
||||
params := QueryParams{
|
||||
SQL: sqlStr,
|
||||
Database: database,
|
||||
}
|
||||
|
||||
// 解析LIMIT和OFFSET
|
||||
limit, offset := o.parseLimitOffset(sqlStr)
|
||||
params.Limit = limit
|
||||
params.Offset = offset
|
||||
|
||||
// 解析表名
|
||||
tables := o.parseTables(sqlStr)
|
||||
if len(tables) > 0 {
|
||||
params.Table = tables[0]
|
||||
}
|
||||
|
||||
// 解析WHERE条件
|
||||
where := o.parseWhereCondition(sqlStr)
|
||||
params.Where = where
|
||||
|
||||
// 解析排序
|
||||
sort := o.parseSortOrder(sqlStr)
|
||||
params.SortBy = sort
|
||||
|
||||
// 判断是否为只读查询
|
||||
params.IsReadOnly = o.isReadOnlyQuery(sqlStr)
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// parseLimitOffset 解析LIMIT和OFFSET
|
||||
func (o *QueryOptimizer) parseLimitOffset(sqlStr string) (limit, offset int) {
|
||||
sqlStr = strings.ToLower(sqlStr)
|
||||
|
||||
matches := reLimitOffset.FindStringSubmatch(sqlStr)
|
||||
|
||||
if len(matches) > 1 {
|
||||
fmt.Sscanf(matches[1], "%d", &limit)
|
||||
if len(matches) > 2 && matches[2] != "" {
|
||||
fmt.Sscanf(matches[2], "%d", &offset)
|
||||
}
|
||||
}
|
||||
|
||||
// MySQL LIMIT offset, count: matches[1]=offset, matches[2]=count
|
||||
if len(matches) > 2 && matches[2] != "" {
|
||||
offset, limit = limit, offset
|
||||
}
|
||||
|
||||
return limit, offset
|
||||
}
|
||||
|
||||
// parseTables 解析查询中的表名
|
||||
func (o *QueryOptimizer) parseTables(sqlStr string) []string {
|
||||
// 简单实现:解析FROM和JOIN中的表名
|
||||
tables := make([]string, 0)
|
||||
|
||||
fromMatches := reFromTable.FindAllStringSubmatch(sqlStr, -1)
|
||||
|
||||
for _, match := range fromMatches {
|
||||
if len(match) > 1 {
|
||||
tableName := strings.Trim(match[1], "`\"'[]")
|
||||
tables = append(tables, tableName)
|
||||
}
|
||||
}
|
||||
|
||||
return tables
|
||||
}
|
||||
|
||||
// parseWhereCondition 解析WHERE条件
|
||||
func (o *QueryOptimizer) parseWhereCondition(sqlStr string) string {
|
||||
matches := reWhereClause.FindStringSubmatch(sqlStr)
|
||||
|
||||
if len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseSortOrder 解析排序条件
|
||||
func (o *QueryOptimizer) parseSortOrder(sqlStr string) string {
|
||||
matches := reOrderBy.FindStringSubmatch(sqlStr)
|
||||
|
||||
if len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isReadOnlyQuery 判断是否为只读查询
|
||||
func (o *QueryOptimizer) isReadOnlyQuery(sqlStr string) bool {
|
||||
sqlStr = strings.ToUpper(strings.TrimSpace(sqlStr))
|
||||
|
||||
// SELECT只读查询
|
||||
if strings.HasPrefix(sqlStr, "SELECT") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 支持的只读查询类型
|
||||
readOnlyQueries := []string{
|
||||
"SHOW", "DESCRIBE", "DESC", "EXPLAIN",
|
||||
"WITH", "UNION", "INTERSECT", "EXCEPT",
|
||||
}
|
||||
|
||||
for _, query := range readOnlyQueries {
|
||||
if strings.HasPrefix(sqlStr, query) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isBatchOperation 判断是否为批量操作
|
||||
func (o *QueryOptimizer) isBatchOperation(sqlStr string) bool {
|
||||
return reBatchOperation.MatchString(sqlStr)
|
||||
}
|
||||
|
||||
// generateQueryHash 生成查询哈希
|
||||
func (o *QueryOptimizer) generateQueryHash(params QueryParams) string {
|
||||
hashData := fmt.Sprintf("%s|%s|%d|%d|%s|%s|%s|%v",
|
||||
params.SQL, params.Database, params.Limit, params.Offset,
|
||||
params.Table, params.Where, params.SortBy, params.IsReadOnly)
|
||||
|
||||
h := sha256.Sum256([]byte(hashData))
|
||||
return fmt.Sprintf("%x", h)
|
||||
}
|
||||
|
||||
// recordQuery 记录查询统计
|
||||
func (o *QueryOptimizer) recordQuery(duration time.Duration) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.stats.TotalQueries++
|
||||
o.stats.TotalDuration += duration
|
||||
o.stats.AverageDuration = time.Duration(int64(float64(o.stats.TotalDuration) / float64(o.stats.TotalQueries)))
|
||||
|
||||
now := time.Now()
|
||||
if o.stats.LastCacheUpdate.IsZero() || now.Sub(o.stats.LastCacheUpdate) > 5*time.Minute {
|
||||
// 更新缓存命中率
|
||||
total := o.stats.TotalQueries
|
||||
hit := o.stats.CachedQueries
|
||||
o.stats.CacheHitRate = float64(hit) / float64(total) * 100
|
||||
o.stats.LastCacheUpdate = now
|
||||
}
|
||||
}
|
||||
|
||||
// recordCacheHit 记录缓存命中
|
||||
func (o *QueryOptimizer) recordCacheHit() {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.stats.CachedQueries++
|
||||
}
|
||||
|
||||
// recordSlowQuery 记录慢查询
|
||||
func (o *QueryOptimizer) recordSlowQuery(query, database string, duration time.Duration, params QueryParams, result *QueryResult, err error) {
|
||||
if !o.config.EnableSlowLog {
|
||||
return
|
||||
}
|
||||
|
||||
slowQuery := SlowQuery{
|
||||
Query: query,
|
||||
Database: database,
|
||||
Duration: duration,
|
||||
Timestamp: time.Now(),
|
||||
Params: params,
|
||||
Table: params.Table,
|
||||
IndexUsed: o.extractIndexUsed(query),
|
||||
RowsAffected: o.extractRowsAffected(result),
|
||||
Error: err,
|
||||
}
|
||||
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.slowQueries = append(o.slowQueries, slowQuery)
|
||||
|
||||
// 限制慢查询记录数量
|
||||
if len(o.slowQueries) > o.config.MaxSlowLogs {
|
||||
o.slowQueries = o.slowQueries[1:]
|
||||
}
|
||||
|
||||
o.stats.SlowQueries++
|
||||
}
|
||||
|
||||
// extractIndexUsed 提取使用的索引
|
||||
func (o *QueryOptimizer) extractIndexUsed(query string) string {
|
||||
// 简单实现:从EXPLAIN结果中提取索引信息
|
||||
// 实际项目中应该执行EXPLAIN语句分析
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// extractRowsAffected 提取影响的行数
|
||||
func (o *QueryOptimizer) extractRowsAffected(result *QueryResult) int64 {
|
||||
if result != nil && len(result.Data) > 0 {
|
||||
if rows, ok := result.Data[0]["rows_affected"].(int64); ok {
|
||||
return rows
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// analyzeQuery 分析查询性能
|
||||
func (o *QueryOptimizer) analyzeQuery(query, database string, result *QueryResult, duration time.Duration) {
|
||||
// 这里可以实现更复杂的查询分析逻辑
|
||||
// 比如分析查询计划、检测N+1查询问题等
|
||||
|
||||
// 简单实现:记录查询到统计信息中
|
||||
_ = query
|
||||
_ = database
|
||||
_ = result
|
||||
_ = duration
|
||||
}
|
||||
|
||||
// analyzeQueryForIndexes 分析查询为索引建议
|
||||
func (o *QueryOptimizer) analyzeQueryForIndexes(query, table string) []IndexSuggestion {
|
||||
var suggestions []IndexSuggestion
|
||||
|
||||
// 解析查询中的WHERE条件
|
||||
where := o.parseWhereCondition(query)
|
||||
if where != "" {
|
||||
// 提取WHERE条件中的列
|
||||
columns := o.extractColumnsFromWhere(where)
|
||||
|
||||
if len(columns) > 0 {
|
||||
// 创建索引建议
|
||||
suggestion := IndexSuggestion{
|
||||
Table: table,
|
||||
Columns: columns,
|
||||
IndexType: "normal",
|
||||
Priority: "medium",
|
||||
Query: query,
|
||||
Justification: fmt.Sprintf("查询经常使用WHERE条件 %s", where),
|
||||
CanBeApplied: true,
|
||||
}
|
||||
suggestions = append(suggestions, suggestion)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析ORDER BY条件
|
||||
order := o.parseSortOrder(query)
|
||||
if order != "" {
|
||||
// 提取排序的列
|
||||
columns := o.extractColumnsFromOrder(order)
|
||||
|
||||
if len(columns) > 0 {
|
||||
// 创建排序索引建议
|
||||
suggestion := IndexSuggestion{
|
||||
Table: table,
|
||||
Columns: columns,
|
||||
IndexType: "normal",
|
||||
Priority: "low",
|
||||
Query: query,
|
||||
Justification: fmt.Sprintf("查询经常使用ORDER BY %s", order),
|
||||
CanBeApplied: true,
|
||||
}
|
||||
suggestions = append(suggestions, suggestion)
|
||||
}
|
||||
}
|
||||
|
||||
return suggestions
|
||||
}
|
||||
|
||||
// extractColumnsFromWhere 从WHERE条件中提取列名
|
||||
func (o *QueryOptimizer) extractColumnsFromWhere(where string) []string {
|
||||
// 简单实现:提取WHERE条件中的列名
|
||||
columns := make([]string, 0)
|
||||
|
||||
// 这里可以实现更复杂的列名解析逻辑
|
||||
// 目前只做简单处理
|
||||
words := strings.Fields(where)
|
||||
for _, word := range words {
|
||||
// 去除运算符和引号
|
||||
if !strings.Contains(word, "=") &&
|
||||
!strings.Contains(word, ">") &&
|
||||
!strings.Contains(word, "<") &&
|
||||
!strings.Contains(word, "!=") &&
|
||||
!strings.HasPrefix(word, "'") &&
|
||||
!strings.HasPrefix(word, "\"") {
|
||||
columns = append(columns, strings.Trim(word, " `\"'[]"))
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// extractColumnsFromOrder 从ORDER BY条件中提取列名
|
||||
func (o *QueryOptimizer) extractColumnsFromOrder(order string) []string {
|
||||
// 简单实现:提取ORDER BY中的列名
|
||||
columns := strings.Split(order, ",")
|
||||
for i, col := range columns {
|
||||
columns[i] = strings.TrimSpace(strings.Split(col, " ")[0])
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// getTableSlowQueries 获取表的慢查询记录
|
||||
func (o *QueryOptimizer) getTableSlowQueries(database, table string) []SlowQuery {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
var tableQueries []SlowQuery
|
||||
for _, query := range o.slowQueries {
|
||||
if (database == "" || query.Database == database) &&
|
||||
(table == "" || query.Table == table) {
|
||||
tableQueries = append(tableQueries, query)
|
||||
}
|
||||
}
|
||||
return tableQueries
|
||||
}
|
||||
|
||||
// optimizeBatchUpdate 优化批量更新操作
|
||||
func (o *QueryOptimizer) optimizeBatchUpdate(ctx context.Context, client *MySQLClient, sqlStr string, database string) (int64, time.Duration, error) {
|
||||
// 简单实现:执行原始查询
|
||||
// 实际项目中可以实现批量操作优化
|
||||
startTime := time.Now()
|
||||
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
return rowsAffected, duration, err
|
||||
}
|
||||
|
||||
// StartMaintenance 启动维护协程
|
||||
func (o *QueryOptimizer) StartMaintenance() {
|
||||
o.wg.Add(1)
|
||||
go func() {
|
||||
defer o.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 清理过期的缓存
|
||||
o.cache.CleanupExpired()
|
||||
|
||||
// 分析慢查询生成新的索引建议
|
||||
o.analyzeSlowQueriesForSuggestions()
|
||||
|
||||
case <-o.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// RecordPoolError 记录连接池错误
|
||||
func (o *QueryOptimizer) RecordPoolError(operation string, err error) {
|
||||
if !o.config.EnableSlowLog || err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
poolError := SlowQuery{
|
||||
Query: operation,
|
||||
Database: "pool",
|
||||
Duration: 0,
|
||||
Timestamp: time.Now(),
|
||||
Params: QueryParams{SQL: operation},
|
||||
Table: "connection_pool",
|
||||
IndexUsed: "N/A",
|
||||
RowsAffected: 0,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.slowQueries = append(o.slowQueries, poolError)
|
||||
|
||||
// 限制慢查询记录数量
|
||||
if len(o.slowQueries) > o.config.MaxSlowLogs {
|
||||
o.slowQueries = o.slowQueries[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// analyzeSlowQueriesForSuggestions 分析慢查询生成索引建议
|
||||
func (o *QueryOptimizer) analyzeSlowQueriesForSuggestions() {
|
||||
// 这里可以实现更复杂的慢查询分析逻辑
|
||||
// 比如分析查询模式、统计索引使用情况等
|
||||
|
||||
// 分析慢查询模式
|
||||
o.analyzeSlowQueryPatterns()
|
||||
}
|
||||
|
||||
// analyzeSlowQueryPatterns 分析慢查询模式
|
||||
func (o *QueryOptimizer) analyzeSlowQueryPatterns() {
|
||||
o.mu.RLock()
|
||||
queryTypes := make(map[string]int)
|
||||
tableQueries := make(map[string]int)
|
||||
|
||||
for _, query := range o.slowQueries {
|
||||
queryType := o.detectQueryType(query.Query)
|
||||
queryTypes[queryType]++
|
||||
|
||||
if query.Table != "" {
|
||||
tableQueries[query.Table]++
|
||||
}
|
||||
}
|
||||
o.mu.RUnlock()
|
||||
|
||||
// 根据统计结果生成智能建议(在锁外执行,避免死锁)
|
||||
o.generateSmartSuggestions(queryTypes, tableQueries)
|
||||
}
|
||||
|
||||
// detectQueryType 检测查询类型
|
||||
func (o *QueryOptimizer) detectQueryType(sqlStr string) string {
|
||||
sqlStr = strings.ToUpper(strings.TrimSpace(sqlStr))
|
||||
|
||||
if strings.HasPrefix(sqlStr, "SELECT") {
|
||||
if strings.Contains(sqlStr, "JOIN") {
|
||||
return "SELECT_JOIN"
|
||||
} else if strings.Contains(sqlStr, "GROUP BY") {
|
||||
return "SELECT_GROUP"
|
||||
} else {
|
||||
return "SELECT_SIMPLE"
|
||||
}
|
||||
} else if strings.HasPrefix(sqlStr, "INSERT") {
|
||||
return "INSERT"
|
||||
} else if strings.HasPrefix(sqlStr, "UPDATE") {
|
||||
return "UPDATE"
|
||||
} else if strings.HasPrefix(sqlStr, "DELETE") {
|
||||
return "DELETE"
|
||||
}
|
||||
|
||||
return "OTHER"
|
||||
}
|
||||
|
||||
// generateSmartSuggestions 生成智能建议
|
||||
func (o *QueryOptimizer) generateSmartSuggestions(queryTypes map[string]int, tableQueries map[string]int) {
|
||||
// 分析频繁执行的查询类型
|
||||
var mostFrequentType string
|
||||
var maxCount int
|
||||
|
||||
for queryType, count := range queryTypes {
|
||||
if count > maxCount {
|
||||
maxCount = count
|
||||
mostFrequentType = queryType
|
||||
}
|
||||
}
|
||||
|
||||
// 生成针对性的索引建议
|
||||
switch mostFrequentType {
|
||||
case "SELECT_JOIN":
|
||||
// 为JOIN查询建议复合索引
|
||||
o.generateJoinSuggestions()
|
||||
case "SELECT_GROUP":
|
||||
// 为GROUP BY查询建议索引
|
||||
o.generateGroupSuggestions()
|
||||
case "INSERT":
|
||||
// 为批量插入建议优化
|
||||
o.generateInsertSuggestions()
|
||||
}
|
||||
}
|
||||
|
||||
// generateJoinSuggestions 生成JOIN查询建议
|
||||
func (o *QueryOptimizer) generateJoinSuggestions() {
|
||||
}
|
||||
|
||||
// generateGroupSuggestions 生成GROUP BY查询建议
|
||||
func (o *QueryOptimizer) generateGroupSuggestions() {
|
||||
}
|
||||
|
||||
// generateInsertSuggestions 生成批量插入建议
|
||||
func (o *QueryOptimizer) generateInsertSuggestions() {
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/common"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisClient Redis 客户端
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
config *RedisConfig
|
||||
}
|
||||
|
||||
// RedisConfig Redis 配置
|
||||
type RedisConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int // 数据库编号,默认 0
|
||||
}
|
||||
|
||||
// NewRedisClient 创建 Redis 客户端
|
||||
func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: config.Password,
|
||||
DB: config.DB,
|
||||
DialTimeout: common.TimeoutConnect,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
})
|
||||
|
||||
// 测试连接
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutConnect)
|
||||
defer cancel()
|
||||
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("Redis 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
return &RedisClient{
|
||||
client: rdb,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestRedisConnection 测试连接
|
||||
func TestRedisConnection(host string, port int, password string) error {
|
||||
config := &RedisConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
DB: 0,
|
||||
}
|
||||
client, err := NewRedisClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewRedisClientByDB 根据参数创建指定 DB 的 Redis 客户端(用于多 DB 场景)
|
||||
func NewRedisClientByDB(host string, port int, password string, dbNum int) (*RedisClient, error) {
|
||||
config := &RedisConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
DB: dbNum,
|
||||
}
|
||||
return NewRedisClient(config)
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *RedisClient) Close() error {
|
||||
if c.client != nil {
|
||||
return c.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteCommand 执行 Redis 命令
|
||||
func (c *RedisClient) ExecuteCommand(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
|
||||
return c.client.Do(ctx, append([]interface{}{cmd}, args...)...).Result()
|
||||
}
|
||||
|
||||
// GetKeys 获取 Key 列表(支持 pattern,使用 SCAN 代替 KEYS 以提高性能)
|
||||
func (c *RedisClient) GetKeys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
|
||||
var keys []string
|
||||
var cursor uint64
|
||||
const count = 100 // 每次扫描的数量
|
||||
|
||||
for {
|
||||
var err error
|
||||
var batch []string
|
||||
batch, cursor, err = c.client.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys = append(keys, batch...)
|
||||
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetKeyType 获取 Key 类型
|
||||
func (c *RedisClient) GetKeyType(ctx context.Context, key string) (string, error) {
|
||||
return c.client.Type(ctx, key).Result()
|
||||
}
|
||||
|
||||
// GetKeyValue 获取 Key 值
|
||||
func (c *RedisClient) GetKeyValue(ctx context.Context, key string) (interface{}, error) {
|
||||
keyType, err := c.GetKeyType(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch keyType {
|
||||
case "string":
|
||||
return c.client.Get(ctx, key).Result()
|
||||
case "list":
|
||||
return c.client.LRange(ctx, key, 0, -1).Result()
|
||||
case "set":
|
||||
return c.client.SMembers(ctx, key).Result()
|
||||
case "zset":
|
||||
// 对于有序集合,返回带分数的结果
|
||||
zMembers, err := c.client.ZRangeWithScores(ctx, key, 0, -1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为 map 格式,便于展示
|
||||
result := make([]map[string]interface{}, len(zMembers))
|
||||
for i, member := range zMembers {
|
||||
result[i] = map[string]interface{}{
|
||||
"member": member.Member,
|
||||
"score": member.Score,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
case "hash":
|
||||
return c.client.HGetAll(ctx, key).Result()
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的类型: %s", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTTL 获取 Key 的 TTL
|
||||
func (c *RedisClient) GetTTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
return c.client.TTL(ctx, key).Result()
|
||||
}
|
||||
|
||||
// GetKeyInfo 获取 Key 详细信息
|
||||
func (c *RedisClient) GetKeyInfo(ctx context.Context, key string) (map[string]interface{}, error) {
|
||||
info := map[string]interface{}{
|
||||
"key": key,
|
||||
"type": "",
|
||||
"value": nil,
|
||||
"ttl": 0,
|
||||
}
|
||||
|
||||
// 获取 Key 类型
|
||||
keyType, err := c.GetKeyType(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Key 类型失败: %v", err)
|
||||
}
|
||||
info["type"] = keyType
|
||||
|
||||
// 获取 TTL
|
||||
ttl, err := c.GetTTL(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 TTL 失败: %v", err)
|
||||
}
|
||||
info["ttl"] = ttl.Seconds()
|
||||
|
||||
// 获取 Key 值(限制大小,避免过大)
|
||||
value, err := c.GetKeyValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Key 值失败: %v", err)
|
||||
}
|
||||
info["value"] = formatValuePreview(value)
|
||||
|
||||
// 获取 Key 长度(使用 STRLEN、HLEN、SCARD、ZCARD)
|
||||
var keyLength int64
|
||||
switch keyType {
|
||||
case "string":
|
||||
keyLength, err = c.client.StrLen(ctx, key).Result()
|
||||
case "list":
|
||||
keyLength, err = c.client.LLen(ctx, key).Result()
|
||||
case "set":
|
||||
keyLength, err = c.client.SCard(ctx, key).Result()
|
||||
case "zset":
|
||||
keyLength, err = c.client.ZCard(ctx, key).Result()
|
||||
case "hash":
|
||||
keyLength, err = c.client.HLen(ctx, key).Result()
|
||||
}
|
||||
if err == nil {
|
||||
info["length"] = keyLength
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// formatValuePreview 格式化值预览(限制长度)
|
||||
func formatValuePreview(value interface{}) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
const maxPreviewLength = 200
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
if len(valueStr) > maxPreviewLength {
|
||||
valueStr = valueStr[:maxPreviewLength] + "..."
|
||||
}
|
||||
|
||||
return valueStr
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表(Redis 使用 DB number)
|
||||
// Redis 没有传统数据库概念,这里返回空数组
|
||||
func (c *RedisClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
// Redis 可以使用 DB number 来隔离数据
|
||||
// 这里可以返回当前配置的 DB 或者所有可用的 DB
|
||||
// 为简单起见,返回空数组,让用户直接操作 Key
|
||||
return []string{}, nil
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisPipeline Redis Pipeline 操作
|
||||
type RedisPipeline struct {
|
||||
client *RedisClient
|
||||
commands []RedisCommand
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// RedisCommand Redis 命令结构
|
||||
type RedisCommand struct {
|
||||
Command string
|
||||
Args []interface{}
|
||||
Result interface{}
|
||||
Error error
|
||||
}
|
||||
|
||||
// NewRedisPipeline 创建新的 Redis Pipeline
|
||||
func (r *RedisClient) NewPipeline(ctx context.Context) *RedisPipeline {
|
||||
return &RedisPipeline{
|
||||
client: r,
|
||||
commands: make([]RedisCommand, 0),
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// AddCommand 添加命令到 Pipeline
|
||||
func (p *RedisPipeline) AddCommand(command string, args ...interface{}) {
|
||||
p.commands = append(p.commands, RedisCommand{
|
||||
Command: command,
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute 使用 go-redis 原生 Pipeline 执行所有命令
|
||||
func (p *RedisPipeline) Execute() ([]interface{}, error) {
|
||||
if len(p.commands) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pipe := p.client.client.Pipeline()
|
||||
|
||||
cmds := make([]*redis.Cmd, len(p.commands))
|
||||
for i, c := range p.commands {
|
||||
cmds[i] = pipe.Do(p.ctx, append([]interface{}{c.Command}, c.Args...)...)
|
||||
}
|
||||
|
||||
// 一次性发送所有命令
|
||||
results := make([]interface{}, len(p.commands))
|
||||
cmdResults, err := pipe.Exec(p.ctx)
|
||||
if err != nil && err != redis.Nil {
|
||||
log.Printf("[RedisPipeline] Exec 错误: %v", err)
|
||||
}
|
||||
|
||||
for i, cmd := range cmds {
|
||||
result, cmdErr := cmd.Result()
|
||||
results[i] = result
|
||||
p.commands[i].Result = result
|
||||
p.commands[i].Error = cmdErr
|
||||
}
|
||||
|
||||
// 如果 Exec 返回了命令结果(部分 Redis 版本),使用它们
|
||||
for i, cr := range cmdResults {
|
||||
if cr.Err() != nil && cr.Err() != redis.Nil {
|
||||
p.commands[i].Error = cr.Err()
|
||||
if i < len(results) {
|
||||
results[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = results // 已经通过 cmds 获取
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetCommands 获取 Pipeline 中的命令列表
|
||||
func (p *RedisPipeline) GetCommands() []RedisCommand {
|
||||
return p.commands
|
||||
}
|
||||
|
||||
// Len 获取 Pipeline 中的命令数量
|
||||
func (p *RedisPipeline) Len() int {
|
||||
return len(p.commands)
|
||||
}
|
||||
|
||||
// Clear 清空 Pipeline
|
||||
func (p *RedisPipeline) Clear() {
|
||||
p.commands = make([]RedisCommand, 0)
|
||||
}
|
||||
|
||||
// RedisTransaction Redis 事务支持
|
||||
type RedisTransaction struct {
|
||||
client *RedisClient
|
||||
watch []string
|
||||
cmds []RedisCommand
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRedisTransaction 创建新的 Redis 事务
|
||||
func (r *RedisClient) NewTransaction(ctx context.Context, watch ...string) *RedisTransaction {
|
||||
return &RedisTransaction{
|
||||
client: r,
|
||||
watch: watch,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// AddCommand 添加命令到事务
|
||||
func (tx *RedisTransaction) AddCommand(command string, args ...interface{}) {
|
||||
tx.cmds = append(tx.cmds, RedisCommand{
|
||||
Command: command,
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Exec 使用 go-redis Watch + TxPipeline 执行事务(MULTI/EXEC)
|
||||
func (tx *RedisTransaction) Exec() ([]interface{}, error) {
|
||||
pipe := tx.client.client.TxPipeline()
|
||||
|
||||
// 添加所有命令
|
||||
cmds := make([]*redis.Cmd, len(tx.cmds))
|
||||
for i, c := range tx.cmds {
|
||||
cmds[i] = pipe.Do(tx.ctx, append([]interface{}{c.Command}, c.Args...)...)
|
||||
}
|
||||
|
||||
// TxPipeline 自动发送 MULTI/EXEC
|
||||
results := make([]interface{}, len(tx.cmds))
|
||||
_, err := pipe.Exec(tx.ctx)
|
||||
|
||||
for i, cmd := range cmds {
|
||||
result, cmdErr := cmd.Result()
|
||||
results[i] = result
|
||||
tx.cmds[i].Result = result
|
||||
tx.cmds[i].Error = cmdErr
|
||||
}
|
||||
|
||||
if err != nil && err != redis.Nil {
|
||||
return results, fmt.Errorf("事务执行失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package model
|
||||
|
||||
// MemberInfo 用户信息表
|
||||
type MemberInfo struct {
|
||||
Memberid int `gorm:"primaryKey;column:memberid;type:int;comment:用户ID" json:"memberid"`
|
||||
Membername string `gorm:"column:membername;type:varchar(100);comment:姓名" json:"membername"`
|
||||
Account string `gorm:"column:account;type:varchar(100);comment:账号" json:"account"`
|
||||
Password string `gorm:"column:password;type:varchar(100);comment:密码" json:"-"`
|
||||
Contactphone string `gorm:"column:contactphone;type:varchar(50);comment:联系电话" json:"contactphone"`
|
||||
Organid int `gorm:"column:organid;type:int;comment:所属机构ID" json:"organid"`
|
||||
Createtime string `gorm:"column:createtime;type:varchar(50);comment:创建时间" json:"createtime"`
|
||||
Updatetime string `gorm:"column:updatetime;type:varchar(50);comment:修改时间" json:"updatetime"`
|
||||
Role int16 `gorm:"column:role;type:smallint;comment:角色类别" json:"role"`
|
||||
Status int16 `gorm:"column:status;type:smallint;comment:状态 1正常 2停用 3删除" json:"status"`
|
||||
Calluserid string `gorm:"column:calluserid;type:varchar(100);comment:坐席用户ID" json:"calluserid"`
|
||||
Remainingexport int `gorm:"column:remainingexport;type:int;comment:本月剩余导出次数" json:"remainingexport"`
|
||||
|
||||
// 虚拟字段(关联查询)
|
||||
Organname string `gorm:"-" json:"organname"` // 机构名称
|
||||
Rolename string `gorm:"-" json:"rolename"` // 角色名称
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (MemberInfo) TableName() string {
|
||||
return "member_info"
|
||||
}
|
||||
@@ -42,11 +42,10 @@ type TabConfig struct {
|
||||
var defaultTabConfig = TabConfig{
|
||||
AvailableTabs: []TabDefinition{
|
||||
{Key: "file-system", Title: "文件管理", Enabled: true},
|
||||
{Key: "db-cli", Title: "数据库", Enabled: true},
|
||||
{Key: "markdown-editor", Title: "Markdown", Enabled: true},
|
||||
{Key: "version", Title: "版本历史", Enabled: true},
|
||||
},
|
||||
VisibleTabs: []string{"file-system", "db-cli", "markdown-editor", "version"},
|
||||
VisibleTabs: []string{"file-system", "markdown-editor", "version"},
|
||||
DefaultTab: "file-system",
|
||||
}
|
||||
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"u-desk/internal/crypto"
|
||||
"u-desk/internal/dbclient"
|
||||
"u-desk/internal/storage/models"
|
||||
"u-desk/internal/storage/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ConnectionService 连接管理服务
|
||||
type ConnectionService struct {
|
||||
repo repository.ConnectionRepository
|
||||
}
|
||||
|
||||
// NewConnectionService 创建连接服务
|
||||
func NewConnectionService() (*ConnectionService, error) {
|
||||
repo, err := repository.NewConnectionRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建连接仓库失败: %v", err)
|
||||
}
|
||||
return &ConnectionService{repo: repo}, nil
|
||||
}
|
||||
|
||||
// SaveConnection 保存连接配置
|
||||
func (s *ConnectionService) SaveConnection(conn *models.DbConnection) error {
|
||||
// 验证
|
||||
if conn.Name == "" {
|
||||
return fmt.Errorf("连接名称不能为空")
|
||||
}
|
||||
if conn.Type == "" {
|
||||
return fmt.Errorf("数据库类型不能为空")
|
||||
}
|
||||
if conn.Host == "" {
|
||||
return fmt.Errorf("主机地址不能为空")
|
||||
}
|
||||
|
||||
// 检查名称是否重复
|
||||
existing, err := s.repo.FindByName(conn.Name, conn.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查连接名称失败: %v", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return fmt.Errorf("连接名称已存在")
|
||||
}
|
||||
|
||||
// 处理密码
|
||||
if conn.ID > 0 {
|
||||
if conn.Password == "" {
|
||||
// 更新模式:保留原密码
|
||||
conn.Password, err = s.getPassword(conn.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 加密新密码
|
||||
conn.Password, err = crypto.EncryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 新增模式:加密密码
|
||||
conn.Password, err = crypto.EncryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.repo.Save(conn)
|
||||
}
|
||||
|
||||
// getPassword 获取原始密码
|
||||
func (s *ConnectionService) getPassword(id uint) (string, error) {
|
||||
existing, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取原连接配置失败: %v", err)
|
||||
}
|
||||
return existing.Password, nil
|
||||
}
|
||||
|
||||
// ListConnections 获取连接列表
|
||||
func (s *ConnectionService) ListConnections() ([]models.DbConnection, error) {
|
||||
return s.repo.FindAll()
|
||||
}
|
||||
|
||||
// GetConnection 获取连接详情
|
||||
func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error) {
|
||||
return s.repo.FindByID(id)
|
||||
}
|
||||
|
||||
// DeleteConnection 删除连接配置(含关联数据和连接池清理)
|
||||
func (s *ConnectionService) DeleteConnection(id uint) error {
|
||||
conn, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil // 连接不存在视为成功
|
||||
}
|
||||
return fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 关闭连接池中的连接
|
||||
dbclient.GetPool().CloseConnection(id, conn.Type)
|
||||
|
||||
// 删除连接记录
|
||||
return s.repo.Delete(id)
|
||||
}
|
||||
|
||||
// TestConnection 测试连接(通过已保存的连接ID)
|
||||
func (s *ConnectionService) TestConnection(id uint) error {
|
||||
conn, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 解密密码用于测试
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 根据类型测试连接
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
return dbclient.TestMySQLConnection(conn.Host, conn.Port, conn.Username, password, conn.Database)
|
||||
case "redis":
|
||||
return dbclient.TestRedisConnection(conn.Host, conn.Port, password)
|
||||
case "mongo":
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if conn.Options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(conn.Options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
return dbclient.TestMongoConnectionWithOptions(conn.Host, conn.Port, conn.Username, password, conn.Database, authSource, authMechanism)
|
||||
default:
|
||||
return fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionWithParams 测试连接(直接传入参数,不保存数据)
|
||||
func (s *ConnectionService) TestConnectionWithParams(connType, host string, port int, username, password, database, options string, existingId uint) error {
|
||||
// 验证必填项
|
||||
if connType == "" {
|
||||
return fmt.Errorf("数据库类型不能为空")
|
||||
}
|
||||
if host == "" {
|
||||
return fmt.Errorf("主机地址不能为空")
|
||||
}
|
||||
|
||||
// 如果是编辑模式且密码为空,尝试获取已保存的密码
|
||||
actualPassword := password
|
||||
if existingId > 0 && password == "" {
|
||||
conn, err := s.repo.FindByID(existingId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取原连接配置失败: %v", err)
|
||||
}
|
||||
// 解密原密码
|
||||
actualPassword, err = crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 根据类型测试连接
|
||||
switch connType {
|
||||
case "mysql":
|
||||
return dbclient.TestMySQLConnection(host, port, username, actualPassword, database)
|
||||
case "redis":
|
||||
return dbclient.TestRedisConnection(host, port, actualPassword)
|
||||
case "mongo":
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
return dbclient.TestMongoConnectionWithOptions(host, port, username, actualPassword, database, authSource, authMechanism)
|
||||
default:
|
||||
return fmt.Errorf("不支持的数据库类型: %s", connType)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAllDatabases 加载全部数据库列表
|
||||
func (s *ConnectionService) LoadAllDatabases(dbType, host string, port int, username, password, database, options string, existingId uint) ([]string, error) {
|
||||
// 如果是编辑模式且密码为空,尝试获取已保存的密码
|
||||
actualPassword := password
|
||||
if existingId > 0 && password == "" {
|
||||
conn, err := s.repo.FindByID(existingId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取原连接配置失败: %v", err)
|
||||
}
|
||||
actualPassword, err = crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 MongoDB 选项
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(options), &opts); err == nil {
|
||||
authSource, _ = opts["authSource"].(string)
|
||||
authMechanism, _ = opts["authMechanism"].(string)
|
||||
}
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
return loadDatabasesForMySQL(host, port, username, actualPassword, database)
|
||||
case "mongo":
|
||||
return loadDatabasesForMongo(host, port, username, actualPassword, database, authSource, authMechanism)
|
||||
case "redis":
|
||||
return []string{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", dbType)
|
||||
}
|
||||
}
|
||||
|
||||
func loadDatabasesForMySQL(host string, port int, username, password, defaultDatabase string) ([]string, error) {
|
||||
config := &dbclient.MySQLConfig{
|
||||
Host: host, Port: port, Username: username,
|
||||
Password: password, Database: defaultDatabase,
|
||||
}
|
||||
client, err := dbclient.NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
return client.ListDatabases(context.Background())
|
||||
}
|
||||
|
||||
func loadDatabasesForMongo(host string, port int, username, password, defaultDatabase, authSource, authMechanism string) ([]string, error) {
|
||||
config := &dbclient.MongoConfig{
|
||||
Host: host, Port: port, Username: username,
|
||||
Password: password, Database: defaultDatabase,
|
||||
AuthSource: authSource, AuthMechanism: authMechanism,
|
||||
}
|
||||
client, err := dbclient.NewMongoClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
return client.ListDatabases(context.Background())
|
||||
}
|
||||
@@ -1,475 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/common"
|
||||
"u-desk/internal/dbclient"
|
||||
"u-desk/internal/storage/models"
|
||||
"u-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// SqlExecService SQL执行服务
|
||||
type SqlExecService struct {
|
||||
connRepo repository.ConnectionRepository
|
||||
pool *dbclient.ConnectionPool
|
||||
}
|
||||
|
||||
// NewSqlExecService 创建SQL执行服务
|
||||
func NewSqlExecService() (*SqlExecService, error) {
|
||||
connRepo, err := repository.NewConnectionRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SqlExecService{
|
||||
connRepo: connRepo,
|
||||
pool: dbclient.GetPool(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SqlResult SQL执行结果
|
||||
type SqlResult struct {
|
||||
Type string `json:"type"` // query/update/command
|
||||
Data interface{} `json:"data"` // 查询结果数据
|
||||
Columns []string `json:"columns"` // 列顺序(仅查询时有效)
|
||||
RowsAffected int `json:"rowsAffected"` // 影响行数
|
||||
ExecutionTime int64 `json:"executionTime"` // 执行时间(毫秒)
|
||||
}
|
||||
|
||||
// ExecuteSQL 执行SQL语句
|
||||
// 注意:SQL 语句应该已经包含分页信息(LIMIT 和 OFFSET),由客户端添加
|
||||
func (s *SqlExecService) ExecuteSQL(connectionID uint, sqlStr string, database string) (*SqlResult, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
return s.executeMySQL(ctx, conn, sqlStr, database, startTime)
|
||||
case "redis":
|
||||
return s.executeRedis(ctx, conn, sqlStr, startTime)
|
||||
case "mongo":
|
||||
return s.executeMongo(ctx, conn, sqlStr, database, startTime)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// executeMySQL 执行MySQL SQL
|
||||
func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
sqlStr = strings.TrimSpace(sqlStr)
|
||||
sqlUpper := strings.ToUpper(sqlStr)
|
||||
|
||||
// 获取数据库参数
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = conn.Database
|
||||
}
|
||||
|
||||
result := &SqlResult{
|
||||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// 判断是查询还是更新
|
||||
if strings.HasPrefix(sqlUpper, "SELECT") || strings.HasPrefix(sqlUpper, "SHOW") ||
|
||||
strings.HasPrefix(sqlUpper, "DESCRIBE") || strings.HasPrefix(sqlUpper, "DESC") ||
|
||||
strings.HasPrefix(sqlUpper, "EXPLAIN") {
|
||||
// 查询语句
|
||||
queryResult, err := pc.Client.ExecuteQuery(ctx, sqlStr, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Type = "query"
|
||||
result.Data = queryResult.Data
|
||||
result.Columns = queryResult.Columns
|
||||
result.RowsAffected = len(queryResult.Data)
|
||||
} else {
|
||||
// 更新语句
|
||||
rowsAffected, err := pc.Client.ExecuteUpdate(ctx, sqlStr, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Type = "update"
|
||||
result.RowsAffected = int(rowsAffected)
|
||||
result.Data = nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// executeRedis 执行Redis命令
|
||||
func (s *SqlExecService) executeRedis(ctx context.Context, conn *models.DbConnection, sqlStr string, startTime time.Time) (*SqlResult, error) {
|
||||
client, err := s.pool.GetRedisClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析Redis命令
|
||||
parts := parseRedisCommand(sqlStr)
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("Redis 命令不能为空")
|
||||
}
|
||||
|
||||
cmd := strings.ToUpper(parts[0])
|
||||
args := make([]interface{}, 0)
|
||||
for i := 1; i < len(parts); i++ {
|
||||
args = append(args, parts[i])
|
||||
}
|
||||
|
||||
data, err := client.ExecuteCommand(ctx, cmd, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SqlResult{
|
||||
Type: "command",
|
||||
Data: data,
|
||||
RowsAffected: 1,
|
||||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// executeMongo 执行MongoDB命令
|
||||
func (s *SqlExecService) executeMongo(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析MongoDB命令(JSON格式)
|
||||
var command map[string]interface{}
|
||||
sqlStr = strings.TrimSpace(sqlStr)
|
||||
if err := json.Unmarshal([]byte(sqlStr), &command); err != nil {
|
||||
return nil, fmt.Errorf("MongoDB 命令必须是有效的 JSON 格式: %v", err)
|
||||
}
|
||||
|
||||
// 确定数据库
|
||||
dbName := conn.Database
|
||||
if db, ok := command["database"].(string); ok && db != "" {
|
||||
dbName = db
|
||||
}
|
||||
if database != "" {
|
||||
dbName = database
|
||||
}
|
||||
if dbName == "" {
|
||||
return nil, fmt.Errorf("需要指定数据库名称")
|
||||
}
|
||||
|
||||
// 执行命令
|
||||
data, err := client.ExecuteCommand(ctx, dbName, command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SqlResult{
|
||||
Type: "command",
|
||||
Data: data,
|
||||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// 根据操作类型确定影响行数
|
||||
if op, ok := command["op"].(string); ok {
|
||||
switch op {
|
||||
case "find":
|
||||
if results, ok := data.([]map[string]interface{}); ok {
|
||||
result.RowsAffected = len(results)
|
||||
}
|
||||
case "count":
|
||||
if count, ok := data.(int64); ok {
|
||||
result.RowsAffected = int(count)
|
||||
}
|
||||
case "insertOne", "deleteOne":
|
||||
result.RowsAffected = 1
|
||||
case "insertMany":
|
||||
if resultMap, ok := data.(map[string]interface{}); ok {
|
||||
if count, ok := resultMap["insertedCount"].(int); ok {
|
||||
result.RowsAffected = count
|
||||
}
|
||||
}
|
||||
default:
|
||||
result.RowsAffected = 0
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetDatabases 获取数据库列表
|
||||
func (s *SqlExecService) GetDatabases(connectionID uint) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
return pc.Client.ListDatabases(ctx)
|
||||
case "redis":
|
||||
databases := make([]string, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
databases[i] = fmt.Sprintf("%d", i)
|
||||
}
|
||||
return databases, nil
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.ListDatabases(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTables 获取表列表(MySQL/MongoDB)或Key列表(Redis)
|
||||
func (s *SqlExecService) GetTables(connectionID uint, database string) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
return pc.Client.ListTables(ctx, database)
|
||||
case "redis":
|
||||
client, err := s.pool.GetRedisClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||||
}
|
||||
return client.GetKeys(ctx, database)
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.ListCollections(ctx, database)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// parseRedisCommand 解析Redis命令
|
||||
func parseRedisCommand(cmd string) []string {
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
if cmd == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
inQuotes := false
|
||||
quoteChar := byte(0)
|
||||
|
||||
for i := 0; i < len(cmd); i++ {
|
||||
char := cmd[i]
|
||||
if !inQuotes {
|
||||
if char == '"' || char == '\'' {
|
||||
inQuotes = true
|
||||
quoteChar = char
|
||||
} else if char == ' ' || char == '\t' {
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
} else {
|
||||
current.WriteByte(char)
|
||||
}
|
||||
} else {
|
||||
if char == quoteChar {
|
||||
inQuotes = false
|
||||
quoteChar = byte(0)
|
||||
} else {
|
||||
current.WriteByte(char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// GetTableStructure 获取表结构
|
||||
func (s *SqlExecService) GetTableStructure(connectionID uint, database, tableName string) (map[string]interface{}, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
structure, err := pc.Client.GetTableStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "mysql",
|
||||
"database": database,
|
||||
"table": tableName,
|
||||
"columns": structure,
|
||||
}, nil
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
structure, err := client.GetCollectionStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "mongo",
|
||||
"database": database,
|
||||
"collection": tableName,
|
||||
"structure": structure,
|
||||
}, nil
|
||||
|
||||
case "redis":
|
||||
client, err := s.pool.GetRedisClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||||
}
|
||||
info, err := client.GetKeyInfo(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "redis",
|
||||
"key": tableName,
|
||||
"info": info,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// GetIndexes 获取索引列表
|
||||
func (s *SqlExecService) GetIndexes(connectionID uint, database, tableName string) ([]map[string]interface{}, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
return pc.Client.GetIndexes(ctx, database, tableName)
|
||||
|
||||
case "mongo", "redis":
|
||||
return []map[string]interface{}{}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// PreviewTableStructure 预览表结构变更
|
||||
func (s *SqlExecService) PreviewTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
return pc.Client.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.PreviewCollectionIndexes(ctx, database, tableName, structure)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTableStructure 更新表结构
|
||||
func (s *SqlExecService) UpdateTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
return pc.Client.UpdateTableStructure(ctx, database, tableName, structure)
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.UpdateCollectionIndexes(ctx, database, tableName, structure)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"u-desk/internal/storage/models"
|
||||
"u-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// TabService 标签页管理服务
|
||||
type TabService struct {
|
||||
repo repository.TabRepository
|
||||
}
|
||||
|
||||
// NewTabService 创建标签页服务
|
||||
func NewTabService() (*TabService, error) {
|
||||
repo, err := repository.NewTabRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建标签页仓库失败: %v", err)
|
||||
}
|
||||
return &TabService{repo: repo}, nil
|
||||
}
|
||||
|
||||
// SaveTabs 保存标签页列表
|
||||
func (s *TabService) SaveTabs(tabs []models.SqlTab) error {
|
||||
return s.repo.SaveAll(tabs)
|
||||
}
|
||||
|
||||
// ListTabs 获取标签页列表
|
||||
func (s *TabService) ListTabs() ([]models.SqlTab, error) {
|
||||
return s.repo.FindAll()
|
||||
}
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// ==================== 常量定义 ====================
|
||||
|
||||
// AppVersion 应用版本号(发布时直接修改此处)
|
||||
const AppVersion = "0.3.3"
|
||||
const AppVersion = "0.4.0"
|
||||
|
||||
// 版本号缓存
|
||||
var (
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DbConnection 数据库连接配置
|
||||
type DbConnection struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(100);not null" json:"name"` // 连接名称
|
||||
Type string `gorm:"type:varchar(20);not null" json:"type"` // 数据库类型: mysql/redis/mongo
|
||||
Host string `gorm:"type:varchar(255);not null" json:"host"` // 主机地址
|
||||
Port int `gorm:"not null" json:"port"` // 端口
|
||||
Username string `gorm:"type:varchar(100)" json:"username"` // 用户名
|
||||
Password string `gorm:"type:varchar(500)" json:"-"` // 密码(加密存储,不返回)
|
||||
Database string `gorm:"type:varchar(100)" json:"database"` // 数据库名(MySQL/MongoDB)
|
||||
Options string `gorm:"type:text" json:"options"` // 额外选项(JSON格式)
|
||||
VisibleDatabases string `gorm:"type:text" json:"visible_databases"` // 可见数据库列表(JSON数组,为空则全部可见)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (DbConnection) TableName() string {
|
||||
return "db_connection"
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SqlResultHistory SQL 执行结果历史
|
||||
type SqlResultHistory struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
ConnectionID uint `gorm:"index;not null" json:"connection_id"` // 连接ID
|
||||
Database string `gorm:"type:varchar(100)" json:"database"` // 数据库名
|
||||
Sql string `gorm:"type:text;not null" json:"sql"` // SQL语句
|
||||
Type string `gorm:"type:varchar(20);not null" json:"type"` // 结果类型: query/update/command
|
||||
Data string `gorm:"type:text" json:"data"` // 结果数据(JSON)
|
||||
Columns string `gorm:"type:text" json:"columns"` // 列信息(JSON)
|
||||
RowsAffected int `gorm:"default:0" json:"rows_affected"` // 影响行数
|
||||
ExecutionTime int64 `gorm:"default:0" json:"execution_time"` // 执行时间(毫秒)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SqlResultHistory) TableName() string {
|
||||
return "sql_result_history"
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SqlTab SQL 编辑器标签页
|
||||
type SqlTab struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Title string `gorm:"type:varchar(100);not null" json:"title"` // 标签页标题
|
||||
Content string `gorm:"type:text" json:"content"` // SQL 内容
|
||||
ConnectionID *uint `gorm:"index" json:"connection_id"` // 关联的连接ID(可为空)
|
||||
Order int `gorm:"default:0" json:"order"` // 排序顺序
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SqlTab) TableName() string {
|
||||
return "sql_tab"
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"u-desk/internal/storage"
|
||||
"u-desk/internal/storage/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConnectionRepository interface {
|
||||
Save(conn *models.DbConnection) error
|
||||
FindAll() ([]models.DbConnection, error)
|
||||
FindByID(id uint) (*models.DbConnection, error)
|
||||
Delete(id uint) error
|
||||
FindByName(name string, excludeID uint) (*models.DbConnection, error)
|
||||
}
|
||||
|
||||
type connectionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewConnectionRepository() (ConnectionRepository, error) {
|
||||
db := storage.GetDB()
|
||||
if db == nil {
|
||||
var err error
|
||||
db, err = storage.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &connectionRepository{db}, nil
|
||||
}
|
||||
|
||||
func (r *connectionRepository) Save(conn *models.DbConnection) error {
|
||||
if conn.ID > 0 {
|
||||
return r.db.Model(&models.DbConnection{}).Where("id = ?", conn.ID).Updates(conn).Error
|
||||
}
|
||||
return r.db.Create(conn).Error
|
||||
}
|
||||
|
||||
func (r *connectionRepository) FindAll() ([]models.DbConnection, error) {
|
||||
var connections []models.DbConnection
|
||||
return connections, r.db.Order("created_at DESC").Find(&connections).Error
|
||||
}
|
||||
|
||||
func (r *connectionRepository) FindByID(id uint) (*models.DbConnection, error) {
|
||||
var conn models.DbConnection
|
||||
err := r.db.First(&conn, id).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &conn, err
|
||||
}
|
||||
|
||||
func (r *connectionRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&models.DbConnection{}, id).Error
|
||||
}
|
||||
|
||||
func (r *connectionRepository) FindByName(name string, excludeID uint) (*models.DbConnection, error) {
|
||||
var conn models.DbConnection
|
||||
query := r.db.Where("name = ?", name)
|
||||
if excludeID > 0 {
|
||||
query = query.Where("id != ?", excludeID)
|
||||
}
|
||||
err := query.First(&conn).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &conn, err
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"u-desk/internal/storage"
|
||||
"u-desk/internal/storage/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ResultRepository interface {
|
||||
Save(connectionID uint, database, sql string, resultType string, data interface{}, columns []string, rowsAffected int, executionTime int64) (*models.SqlResultHistory, error)
|
||||
FindByID(id uint) (*models.SqlResultHistory, error)
|
||||
Search(connectionID *uint, keyword string, limit, offset int) ([]models.SqlResultHistory, int64, error)
|
||||
Delete(id uint) error
|
||||
}
|
||||
|
||||
type resultRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewResultRepository() (ResultRepository, error) {
|
||||
db := storage.GetDB()
|
||||
if db == nil {
|
||||
var err error
|
||||
db, err = storage.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &resultRepository{db}, nil
|
||||
}
|
||||
|
||||
func (r *resultRepository) Save(connectionID uint, database, sql string, resultType string, data interface{}, columns []string, rowsAffected int, executionTime int64) (*models.SqlResultHistory, error) {
|
||||
dataJSON, _ := json.Marshal(data)
|
||||
columnsJSON, _ := json.Marshal(columns)
|
||||
|
||||
history := &models.SqlResultHistory{
|
||||
ConnectionID: connectionID,
|
||||
Database: database,
|
||||
Sql: sql,
|
||||
Type: resultType,
|
||||
Data: string(dataJSON),
|
||||
Columns: string(columnsJSON),
|
||||
RowsAffected: rowsAffected,
|
||||
ExecutionTime: executionTime,
|
||||
}
|
||||
|
||||
return history, r.db.Create(history).Error
|
||||
}
|
||||
|
||||
func (r *resultRepository) FindByID(id uint) (*models.SqlResultHistory, error) {
|
||||
var history models.SqlResultHistory
|
||||
err := r.db.First(&history, id).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &history, err
|
||||
}
|
||||
|
||||
func (r *resultRepository) Search(connectionID *uint, keyword string, limit, offset int) ([]models.SqlResultHistory, int64, error) {
|
||||
query := r.db.Model(&models.SqlResultHistory{})
|
||||
|
||||
if connectionID != nil {
|
||||
query = query.Where("connection_id = ?", *connectionID)
|
||||
}
|
||||
if keyword != "" {
|
||||
query = query.Where("sql LIKE ? OR database LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var histories []models.SqlResultHistory
|
||||
query = query.Order("created_at DESC")
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
return histories, total, query.Find(&histories).Error
|
||||
}
|
||||
|
||||
func (r *resultRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&models.SqlResultHistory{}, id).Error
|
||||
}
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"u-desk/internal/storage"
|
||||
"u-desk/internal/storage/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TabRepository interface {
|
||||
SaveAll(tabs []models.SqlTab) error
|
||||
FindAll() ([]models.SqlTab, error)
|
||||
Delete(id uint) error
|
||||
}
|
||||
|
||||
type tabRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTabRepository() (TabRepository, error) {
|
||||
db := storage.GetDB()
|
||||
if db == nil {
|
||||
var err error
|
||||
db, err = storage.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &tabRepository{db}, nil
|
||||
}
|
||||
|
||||
func (r *tabRepository) SaveAll(tabs []models.SqlTab) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("1=1").Delete(&models.SqlTab{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tabs) > 0 {
|
||||
return tx.Create(&tabs).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *tabRepository) FindAll() ([]models.SqlTab, error) {
|
||||
var tabs []models.SqlTab
|
||||
return tabs, r.db.Order("`order` ASC, created_at ASC").Find(&tabs).Error
|
||||
}
|
||||
|
||||
func (r *tabRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&models.SqlTab{}, id).Error
|
||||
}
|
||||
|
||||
@@ -62,9 +62,6 @@ func InitFast() (*gorm.DB, error) {
|
||||
// AutoMigrate 在启动时执行,但只在表结构不存在时创建
|
||||
// SQLite 的 AutoMigrate 很快,不会造成明显延迟
|
||||
if err := db.AutoMigrate(
|
||||
&models.DbConnection{},
|
||||
&models.SqlTab{},
|
||||
&models.SqlResultHistory{},
|
||||
&models.AppConfig{},
|
||||
); err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user