新增:连接管理、数据查询等功能
This commit is contained in:
109
internal/api/connection_api.go
Normal file
109
internal/api/connection_api.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"go-desk/internal/service"
|
||||
"go-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// ConnectionAPI 连接管理API
|
||||
type ConnectionAPI struct {
|
||||
connService *service.ConnectionService
|
||||
}
|
||||
|
||||
// NewConnectionAPI 创建连接管理API
|
||||
func NewConnectionAPI() (*ConnectionAPI, error) {
|
||||
connService, err := service.NewConnectionService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ConnectionAPI{connService}, nil
|
||||
}
|
||||
|
||||
// SaveConnectionRequest 保存连接请求结构体
|
||||
type SaveConnectionRequest struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
Options string `json:"options"`
|
||||
}
|
||||
|
||||
// SaveDbConnection 保存数据库连接配置
|
||||
func (api *ConnectionAPI) SaveDbConnection(req SaveConnectionRequest) error {
|
||||
conn := &models.DbConnection{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Database: req.Database,
|
||||
Options: req.Options,
|
||||
}
|
||||
return api.connService.SaveConnection(conn)
|
||||
}
|
||||
|
||||
// ListDbConnections 获取连接列表
|
||||
func (api *ConnectionAPI) ListDbConnections() ([]map[string]interface{}, error) {
|
||||
connections, err := api.connService.ListConnections()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, len(connections))
|
||||
timeFormat := "2006-01-02 15:04:05"
|
||||
for i, conn := range connections {
|
||||
result[i] = map[string]interface{}{
|
||||
"id": conn.ID,
|
||||
"name": conn.Name,
|
||||
"type": conn.Type,
|
||||
"host": conn.Host,
|
||||
"port": conn.Port,
|
||||
"username": conn.Username,
|
||||
"database": conn.Database,
|
||||
"options": conn.Options,
|
||||
"created_at": conn.CreatedAt.Format(timeFormat),
|
||||
"updated_at": conn.UpdatedAt.Format(timeFormat),
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (api *ConnectionAPI) DeleteDbConnection(id uint) error {
|
||||
return api.connService.DeleteConnection(id)
|
||||
}
|
||||
|
||||
func (api *ConnectionAPI) TestDbConnection(id uint) error {
|
||||
return api.connService.TestConnection(id)
|
||||
}
|
||||
|
||||
// TestConnectionRequest 测试连接请求结构体(不保存数据)
|
||||
type TestConnectionRequest struct {
|
||||
ID uint `json:"id"` // 编辑模式下的连接ID(用于获取已保存的密码)
|
||||
Type string `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
Options string `json:"options"`
|
||||
}
|
||||
|
||||
// TestDbConnectionWithParams 测试数据库连接(直接传入参数,不保存数据)
|
||||
func (api *ConnectionAPI) TestDbConnectionWithParams(req TestConnectionRequest) error {
|
||||
return api.connService.TestConnectionWithParams(
|
||||
req.Type,
|
||||
req.Host,
|
||||
req.Port,
|
||||
req.Username,
|
||||
req.Password,
|
||||
req.Database,
|
||||
req.Options,
|
||||
req.ID,
|
||||
)
|
||||
}
|
||||
137
internal/api/sql_api.go
Normal file
137
internal/api/sql_api.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"go-desk/internal/service"
|
||||
"go-desk/internal/storage/models"
|
||||
"go-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
type SqlAPI struct {
|
||||
sqlService *service.SqlExecService
|
||||
resultRepo repository.ResultRepository
|
||||
}
|
||||
|
||||
func NewSqlAPI() (*SqlAPI, error) {
|
||||
sqlService, err := service.NewSqlExecService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resultRepo, err := repository.NewResultRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SqlAPI{sqlService, resultRepo}, nil
|
||||
}
|
||||
|
||||
// ExecuteSQL 执行SQL语句
|
||||
// 注意:SQL 语句应该已经包含分页信息(LIMIT 和 OFFSET),由客户端添加
|
||||
func (api *SqlAPI) ExecuteSQL(connectionID uint, sqlStr string, database string) (map[string]interface{}, error) {
|
||||
result, err := api.sqlService.ExecuteSQL(connectionID, sqlStr, database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"type": result.Type,
|
||||
"data": result.Data,
|
||||
"rowsAffected": result.RowsAffected,
|
||||
"executionTime": result.ExecutionTime,
|
||||
}
|
||||
// 如果是查询,添加列顺序信息
|
||||
if result.Type == "query" && len(result.Columns) > 0 {
|
||||
response["columns"] = result.Columns
|
||||
}
|
||||
|
||||
// 自动保存结果到历史记录(异步执行)
|
||||
go func() {
|
||||
api.resultRepo.Save(connectionID, database, sqlStr, result.Type, result.Data, result.Columns, result.RowsAffected, result.ExecutionTime)
|
||||
}()
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetDatabases(connectionID uint) ([]string, error) {
|
||||
return api.sqlService.GetDatabases(connectionID)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetTables(connectionID uint, database string) ([]string, error) {
|
||||
return api.sqlService.GetTables(connectionID, database)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetTableStructure(connectionID uint, database, tableName string) (map[string]interface{}, error) {
|
||||
return api.sqlService.GetTableStructure(connectionID, database, tableName)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetIndexes(connectionID uint, database, tableName string) ([]map[string]interface{}, error) {
|
||||
return api.sqlService.GetIndexes(connectionID, database, tableName)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) PreviewTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
return api.sqlService.PreviewTableStructure(connectionID, database, tableName, structure)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) UpdateTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
return api.sqlService.UpdateTableStructure(connectionID, database, tableName, structure)
|
||||
}
|
||||
|
||||
func (api *SqlAPI) SaveResult(connectionID uint, database, sql string, resultType string, data interface{}, columns []string, rowsAffected int, executionTime int64) (map[string]interface{}, error) {
|
||||
history, err := api.resultRepo.Save(connectionID, database, sql, resultType, data, columns, rowsAffected, executionTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return historyToMap(history), nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetResultHistory(connectionID *uint, keyword string, limit, offset int) (map[string]interface{}, error) {
|
||||
histories, total, err := api.resultRepo.Search(connectionID, keyword, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
items := make([]map[string]interface{}, len(histories))
|
||||
for i, h := range histories {
|
||||
items[i] = historyToMap(&h)
|
||||
}
|
||||
|
||||
return map[string]interface{}{"items": items, "total": total}, nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) GetResultHistoryByID(id uint) (map[string]interface{}, error) {
|
||||
history, err := api.resultRepo.FindByID(id)
|
||||
if err != nil || history == nil {
|
||||
return nil, err
|
||||
}
|
||||
return historyToMap(history), nil
|
||||
}
|
||||
|
||||
func (api *SqlAPI) DeleteResultHistory(id uint) error {
|
||||
return api.resultRepo.Delete(id)
|
||||
}
|
||||
|
||||
func historyToMap(history *models.SqlResultHistory) map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
"id": history.ID,
|
||||
"connection_id": history.ConnectionID,
|
||||
"database": history.Database,
|
||||
"sql": history.Sql,
|
||||
"type": history.Type,
|
||||
"rows_affected": history.RowsAffected,
|
||||
"execution_time": history.ExecutionTime,
|
||||
"created_at": history.CreatedAt,
|
||||
}
|
||||
|
||||
if history.Data != "" {
|
||||
var data interface{}
|
||||
json.Unmarshal([]byte(history.Data), &data)
|
||||
result["data"] = data
|
||||
}
|
||||
|
||||
if history.Columns != "" {
|
||||
var columns []string
|
||||
json.Unmarshal([]byte(history.Columns), &columns)
|
||||
result["columns"] = columns
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
79
internal/api/tab_api.go
Normal file
79
internal/api/tab_api.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go-desk/internal/service"
|
||||
"go-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// TabAPI 标签页API
|
||||
type TabAPI struct {
|
||||
tabService *service.TabService
|
||||
}
|
||||
|
||||
// NewTabAPI 创建标签页API
|
||||
func NewTabAPI() (*TabAPI, error) {
|
||||
tabService, err := service.NewTabService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TabAPI{tabService: tabService}, nil
|
||||
}
|
||||
|
||||
// SaveSqlTabs 保存SQL标签页列表(接收 map 格式,转换为模型)
|
||||
func (api *TabAPI) SaveSqlTabs(tabs []map[string]interface{}) error {
|
||||
sqlTabs := make([]models.SqlTab, len(tabs))
|
||||
for idx, tabData := range tabs {
|
||||
tab := models.SqlTab{
|
||||
Order: idx,
|
||||
}
|
||||
|
||||
// 处理 ID
|
||||
if id, ok := tabData["id"].(float64); ok && id > 0 {
|
||||
tab.ID = uint(id)
|
||||
}
|
||||
|
||||
// 处理标题
|
||||
if title, ok := tabData["title"].(string); ok {
|
||||
tab.Title = title
|
||||
} else {
|
||||
tab.Title = fmt.Sprintf("查询 %d", idx+1)
|
||||
}
|
||||
|
||||
// 处理内容
|
||||
if content, ok := tabData["content"].(string); ok {
|
||||
tab.Content = content
|
||||
}
|
||||
|
||||
// 处理连接ID
|
||||
if connId, ok := tabData["connectionId"].(float64); ok && connId > 0 {
|
||||
connID := uint(connId)
|
||||
tab.ConnectionID = &connID
|
||||
}
|
||||
|
||||
sqlTabs[idx] = tab
|
||||
}
|
||||
return api.tabService.SaveTabs(sqlTabs)
|
||||
}
|
||||
|
||||
// ListSqlTabs 获取SQL标签页列表(返回 map 格式)
|
||||
func (api *TabAPI) ListSqlTabs() ([]map[string]interface{}, error) {
|
||||
tabs, err := api.tabService.ListTabs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, len(tabs))
|
||||
for i, tab := range tabs {
|
||||
result[i] = map[string]interface{}{
|
||||
"id": tab.ID,
|
||||
"title": tab.Title,
|
||||
"content": tab.Content,
|
||||
"connectionId": tab.ConnectionID,
|
||||
"order": tab.Order,
|
||||
"createdAt": tab.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
"updatedAt": tab.UpdatedAt.Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
99
internal/crypto/aes.go
Normal file
99
internal/crypto/aes.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
// 默认密钥(实际应用中应该从配置文件或环境变量读取)
|
||||
// AES-256 需要 32 字节密钥
|
||||
// "go-desk-db-cli-key-32bytes123456" = 32 bytes
|
||||
defaultKey = []byte("go-desk-db-cli-key-32bytes123456") // 32 bytes for AES-256
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 验证密钥长度
|
||||
if len(defaultKey) != 32 {
|
||||
panic(fmt.Sprintf("AES-256 密钥长度必须为 32 字节,当前为 %d 字节", len(defaultKey)))
|
||||
}
|
||||
}
|
||||
|
||||
// EncryptPassword 加密密码
|
||||
func EncryptPassword(password string) (string, error) {
|
||||
if password == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(defaultKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建加密器失败: %v", err)
|
||||
}
|
||||
|
||||
// 使用 GCM 模式
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 GCM 失败: %v", err)
|
||||
}
|
||||
|
||||
// 生成随机 nonce
|
||||
nonce := make([]byte, aesGCM.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("生成 nonce 失败: %v", err)
|
||||
}
|
||||
|
||||
// 加密
|
||||
ciphertext := aesGCM.Seal(nonce, nonce, []byte(password), nil)
|
||||
|
||||
// Base64 编码
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptPassword 解密密码
|
||||
func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// 如果加密字符串为空或格式不正确,返回空字符串
|
||||
if len(encryptedPassword) < 10 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Base64 解码
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解码失败: %v", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(defaultKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建解密器失败: %v", err)
|
||||
}
|
||||
|
||||
// 使用 GCM 模式
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 GCM 失败: %v", err)
|
||||
}
|
||||
|
||||
// 提取 nonce
|
||||
nonceSize := aesGCM.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return "", fmt.Errorf("密文长度不足")
|
||||
}
|
||||
|
||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
|
||||
// 解密
|
||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解密失败: %v", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
818
internal/dbclient/mongo.go
Normal file
818
internal/dbclient/mongo.go
Normal file
@@ -0,0 +1,818 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
// MongoClient MongoDB 客户端
|
||||
type MongoClient struct {
|
||||
client *mongo.Client
|
||||
database *mongo.Database
|
||||
config *MongoConfig
|
||||
}
|
||||
|
||||
// MongoConfig MongoDB 配置
|
||||
type MongoConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
AuthSource string // 认证数据库,默认为 "admin"
|
||||
AuthMechanism string // 认证机制,如 "SCRAM-SHA-1", "SCRAM-SHA-256" 等
|
||||
}
|
||||
|
||||
// NewMongoClient 创建 MongoDB 客户端
|
||||
func NewMongoClient(config *MongoConfig) (*MongoClient, error) {
|
||||
// 确定认证数据库,默认为 admin
|
||||
authSource := config.AuthSource
|
||||
if authSource == "" {
|
||||
authSource = "admin"
|
||||
}
|
||||
|
||||
// 如果指定了认证机制,直接使用;否则尝试自动检测
|
||||
authMechanisms := []string{}
|
||||
if config.AuthMechanism != "" {
|
||||
// 用户明确指定了认证机制,只使用该机制
|
||||
authMechanisms = []string{config.AuthMechanism}
|
||||
} else {
|
||||
// 未指定时,先尝试 SCRAM-SHA-256(更安全),失败则尝试 SCRAM-SHA-1
|
||||
authMechanisms = []string{"SCRAM-SHA-256", "SCRAM-SHA-1"}
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, authMechanism := range authMechanisms {
|
||||
client, err := tryConnectMongo(config, authSource, authMechanism)
|
||||
if err == nil {
|
||||
return client, nil
|
||||
}
|
||||
lastErr = err
|
||||
// 如果明确指定了认证机制,失败后不再尝试其他机制
|
||||
if config.AuthMechanism != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 所有认证机制都失败
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("MongoDB 连接测试失败: %v", lastErr)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("MongoDB 连接失败: 未知错误")
|
||||
}
|
||||
|
||||
// tryConnectMongo 尝试使用指定的认证机制连接 MongoDB
|
||||
func tryConnectMongo(config *MongoConfig, authSource, authMechanism string) (*MongoClient, error) {
|
||||
// 构建连接 URI
|
||||
var uri string
|
||||
|
||||
if config.Username != "" && config.Password != "" {
|
||||
// 使用 url.UserPassword 正确转义用户名和密码中的特殊字符
|
||||
// 这会正确处理 @、:、/ 等特殊字符
|
||||
userInfo := url.UserPassword(config.Username, config.Password)
|
||||
|
||||
// 构建基础 URI
|
||||
uri = fmt.Sprintf("mongodb://%s@%s:%d", userInfo.String(), config.Host, config.Port)
|
||||
|
||||
// 添加数据库和认证源参数
|
||||
params := url.Values{}
|
||||
params.Set("authSource", authSource)
|
||||
|
||||
// 添加认证机制参数
|
||||
if authMechanism != "" {
|
||||
params.Set("authMechanism", authMechanism)
|
||||
}
|
||||
|
||||
// 如果有业务数据库,添加到路径中
|
||||
if config.Database != "" {
|
||||
uri = fmt.Sprintf("%s/%s?%s", uri, config.Database, params.Encode())
|
||||
} else {
|
||||
// MongoDB URI 要求查询参数前必须有 /,即使没有数据库名
|
||||
uri = fmt.Sprintf("%s/?%s", uri, params.Encode())
|
||||
}
|
||||
} else if config.Database != "" {
|
||||
// 没有认证信息时,数据库部分用于指定默认数据库
|
||||
uri = fmt.Sprintf("mongodb://%s:%d/%s", config.Host, config.Port, config.Database)
|
||||
} else {
|
||||
uri = fmt.Sprintf("mongodb://%s:%d", config.Host, config.Port)
|
||||
}
|
||||
|
||||
// 客户端选项
|
||||
clientOptions := options.Client().
|
||||
ApplyURI(uri).
|
||||
SetConnectTimeout(5 * time.Second).
|
||||
SetServerSelectionTimeout(5 * time.Second)
|
||||
|
||||
// 创建客户端
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 MongoDB 失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := client.Ping(ctx, nil); err != nil {
|
||||
client.Disconnect(ctx)
|
||||
return nil, fmt.Errorf("MongoDB 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
var database *mongo.Database
|
||||
if config.Database != "" {
|
||||
database = client.Database(config.Database)
|
||||
}
|
||||
|
||||
return &MongoClient{
|
||||
client: client,
|
||||
database: database,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestMongoConnection 测试连接
|
||||
func TestMongoConnection(host string, port int, username, password, database string) error {
|
||||
return TestMongoConnectionWithAuthSource(host, port, username, password, database, "")
|
||||
}
|
||||
|
||||
// TestMongoConnectionWithAuthSource 测试连接(支持指定认证数据库)
|
||||
func TestMongoConnectionWithAuthSource(host string, port int, username, password, database, authSource string) error {
|
||||
return TestMongoConnectionWithOptions(host, port, username, password, database, authSource, "")
|
||||
}
|
||||
|
||||
// TestMongoConnectionWithOptions 测试连接(支持指定认证数据库和认证机制)
|
||||
func TestMongoConnectionWithOptions(host string, port int, username, password, database, authSource, authMechanism string) error {
|
||||
config := &MongoConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
AuthSource: authSource,
|
||||
AuthMechanism: authMechanism,
|
||||
}
|
||||
client, err := NewMongoClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *MongoClient) Close() error {
|
||||
if c.client != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return c.client.Disconnect(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表
|
||||
func (c *MongoClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
databases, err := c.client.ListDatabaseNames(ctx, bson.M{})
|
||||
return databases, err
|
||||
}
|
||||
|
||||
// ListCollections 获取集合列表
|
||||
func (c *MongoClient) ListCollections(ctx context.Context, database string) ([]string, error) {
|
||||
db := c.client.Database(database)
|
||||
collections, err := db.ListCollectionNames(ctx, bson.M{})
|
||||
return collections, err
|
||||
}
|
||||
|
||||
// GetCollectionStructure 获取集合结构
|
||||
func (c *MongoClient) GetCollectionStructure(ctx context.Context, database, collectionName string) (map[string]interface{}, error) {
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"database": database,
|
||||
"collection": collectionName,
|
||||
"sampleDocs": []map[string]interface{}{},
|
||||
"fieldStats": map[string]int{},
|
||||
"indexes": []map[string]interface{}{},
|
||||
"documentCount": int64(0),
|
||||
}
|
||||
|
||||
// 获取文档示例(最多 5 个)
|
||||
opts := options.Find().SetLimit(5)
|
||||
cursor, err := coll.Find(ctx, bson.M{}, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文档示例失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var docs []bson.M
|
||||
if err = cursor.All(ctx, &docs); err != nil {
|
||||
return nil, fmt.Errorf("解析文档失败: %v", err)
|
||||
}
|
||||
|
||||
// 转换为 map
|
||||
sampleDocs := make([]map[string]interface{}, 0, len(docs))
|
||||
for _, doc := range docs {
|
||||
docMap := make(map[string]interface{})
|
||||
for k, v := range doc {
|
||||
docMap[k] = v
|
||||
}
|
||||
sampleDocs = append(sampleDocs, docMap)
|
||||
}
|
||||
result["sampleDocs"] = sampleDocs
|
||||
|
||||
// 字段统计:使用 $sample 聚合管道随机采样10个文档进行统计
|
||||
// 这样可以获得更准确的字段分布,同时保持良好性能
|
||||
// 使用异步方式执行,避免阻塞主流程
|
||||
sampleSize := 10
|
||||
pipeline := []bson.M{
|
||||
{"$sample": bson.M{"size": sampleSize}},
|
||||
{"$project": bson.M{"keys": bson.M{"$objectToArray": "$$ROOT"}}},
|
||||
{"$unwind": "$keys"},
|
||||
{"$group": bson.M{
|
||||
"_id": "$keys.k",
|
||||
"count": bson.M{"$sum": 1},
|
||||
}},
|
||||
{"$sort": bson.M{"count": -1}}, // 按出现次数降序排序
|
||||
}
|
||||
|
||||
sampleCursor, err := coll.Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
// 如果采样失败,回退到基于文档示例的统计
|
||||
fieldCount := make(map[string]int)
|
||||
for _, doc := range docs {
|
||||
for key := range doc {
|
||||
fieldCount[key]++
|
||||
}
|
||||
}
|
||||
result["fieldStats"] = fieldCount
|
||||
result["fieldStatsSampleSize"] = len(docs) // 记录实际采样数量
|
||||
result["fieldStatsMethod"] = "sample-docs" // 标记统计方式
|
||||
} else {
|
||||
defer sampleCursor.Close(ctx)
|
||||
fieldCount := make(map[string]int)
|
||||
for sampleCursor.Next(ctx) {
|
||||
var statResult bson.M
|
||||
if err := sampleCursor.Decode(&statResult); err != nil {
|
||||
continue
|
||||
}
|
||||
fieldName, ok := statResult["_id"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var count int
|
||||
switch v := statResult["count"].(type) {
|
||||
case int32:
|
||||
count = int(v)
|
||||
case int64:
|
||||
count = int(v)
|
||||
case int:
|
||||
count = v
|
||||
case float64:
|
||||
count = int(v)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
fieldCount[fieldName] = count
|
||||
}
|
||||
result["fieldStats"] = fieldCount
|
||||
result["fieldStatsSampleSize"] = sampleSize // 记录采样数量
|
||||
result["fieldStatsMethod"] = "sample-aggregate" // 标记统计方式
|
||||
}
|
||||
|
||||
// 文档总数(使用估算值,性能更好)
|
||||
// 对于大数据集,estimatedDocumentCount 比 CountDocuments 快得多
|
||||
// 如果需要精确值,可以使用 CountDocuments,但性能较差
|
||||
count, err := coll.EstimatedDocumentCount(ctx)
|
||||
if err != nil {
|
||||
// 如果估算失败,尝试精确计数(可能较慢)
|
||||
count, err = coll.CountDocuments(ctx, bson.M{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文档数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
result["documentCount"] = count
|
||||
|
||||
// 索引信息
|
||||
indexCursor, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
// 索引查询失败不影响主流程
|
||||
result["indexes"] = []map[string]interface{}{}
|
||||
} else {
|
||||
var indexes []map[string]interface{}
|
||||
for indexCursor.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := indexCursor.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
indexes = append(indexes, map[string]interface{}{
|
||||
"name": indexSpec["name"],
|
||||
"unique": indexSpec["unique"],
|
||||
"keys": indexSpec["key"],
|
||||
})
|
||||
}
|
||||
indexCursor.Close(ctx)
|
||||
result["indexes"] = indexes
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteQuery 执行查询
|
||||
func (c *MongoClient) ExecuteQuery(ctx context.Context, database, collection string, filter bson.M, limit int64) ([]map[string]interface{}, error) {
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collection)
|
||||
|
||||
opts := options.Find().SetLimit(limit)
|
||||
cursor, err := coll.Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var results []map[string]interface{}
|
||||
if err := cursor.All(ctx, &results); err != nil {
|
||||
return nil, fmt.Errorf("读取结果失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// CountDocuments 获取文档数量
|
||||
func (c *MongoClient) CountDocuments(ctx context.Context, database, collection string, filter bson.M) (int64, error) {
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collection)
|
||||
return coll.CountDocuments(ctx, filter)
|
||||
}
|
||||
|
||||
// ExecuteCommand 执行 MongoDB 命令
|
||||
// command 可以是 JSON 格式的字符串,格式:{"op": "find", "database": "test", "collection": "users", "filter": {}, "limit": 100}
|
||||
// 支持的操作:find, count, insertOne, insertMany, updateOne, updateMany, deleteOne, deleteMany
|
||||
func (c *MongoClient) ExecuteCommand(ctx context.Context, database string, command map[string]interface{}) (interface{}, error) {
|
||||
op, ok := command["op"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("命令中缺少 'op' 字段或格式错误")
|
||||
}
|
||||
|
||||
collectionName, ok := command["collection"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("命令中缺少 'collection' 字段或格式错误")
|
||||
}
|
||||
|
||||
// 如果没有指定数据库,使用配置中的默认数据库
|
||||
if database == "" {
|
||||
if c.config != nil && c.config.Database != "" {
|
||||
database = c.config.Database
|
||||
} else {
|
||||
return nil, fmt.Errorf("需要指定数据库名称")
|
||||
}
|
||||
}
|
||||
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collectionName)
|
||||
|
||||
switch op {
|
||||
case "find":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
}
|
||||
|
||||
limit := int64(100)
|
||||
if l, ok := command["limit"]; ok {
|
||||
if limitVal, ok := l.(float64); ok {
|
||||
limit = int64(limitVal)
|
||||
} else if limitVal, ok := l.(int64); ok {
|
||||
limit = limitVal
|
||||
}
|
||||
}
|
||||
|
||||
opts := options.Find().SetLimit(limit)
|
||||
cursor, err := coll.Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var results []map[string]interface{}
|
||||
if err := cursor.All(ctx, &results); err != nil {
|
||||
return nil, fmt.Errorf("读取结果失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
|
||||
case "count":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
}
|
||||
|
||||
count, err := coll.CountDocuments(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("统计失败: %v", err)
|
||||
}
|
||||
return count, nil
|
||||
|
||||
case "insertOne":
|
||||
document, ok := command["document"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("insertOne 操作需要 'document' 字段")
|
||||
}
|
||||
|
||||
doc := bson.M{}
|
||||
if docMap, ok := document.(map[string]interface{}); ok {
|
||||
doc = bson.M(docMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("document 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.InsertOne(ctx, doc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("插入失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"insertedId": result.InsertedID,
|
||||
}, nil
|
||||
|
||||
case "insertMany":
|
||||
documents, ok := command["documents"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("insertMany 操作需要 'documents' 字段")
|
||||
}
|
||||
|
||||
docs := []interface{}{}
|
||||
if docsSlice, ok := documents.([]interface{}); ok {
|
||||
for _, d := range docsSlice {
|
||||
if docMap, ok := d.(map[string]interface{}); ok {
|
||||
docs = append(docs, bson.M(docMap))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("documents 必须是数组格式")
|
||||
}
|
||||
|
||||
result, err := coll.InsertMany(ctx, docs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量插入失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"insertedIds": result.InsertedIDs,
|
||||
"insertedCount": len(result.InsertedIDs),
|
||||
}, nil
|
||||
|
||||
case "updateOne":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("updateOne 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
update, ok := command["update"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("updateOne 操作需要 'update' 字段")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{}
|
||||
if updateMap, ok := update.(map[string]interface{}); ok {
|
||||
updateDoc = bson.M(updateMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("update 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.UpdateOne(ctx, filter, bson.M{"$set": updateDoc})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"matchedCount": result.MatchedCount,
|
||||
"modifiedCount": result.ModifiedCount,
|
||||
}, nil
|
||||
|
||||
case "updateMany":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("updateMany 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
update, ok := command["update"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("updateMany 操作需要 'update' 字段")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{}
|
||||
if updateMap, ok := update.(map[string]interface{}); ok {
|
||||
updateDoc = bson.M(updateMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("update 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.UpdateMany(ctx, filter, bson.M{"$set": updateDoc})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量更新失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"matchedCount": result.MatchedCount,
|
||||
"modifiedCount": result.ModifiedCount,
|
||||
}, nil
|
||||
|
||||
case "deleteOne":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("deleteOne 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
result, err := coll.DeleteOne(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("删除失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"deletedCount": result.DeletedCount,
|
||||
}, nil
|
||||
|
||||
case "deleteMany":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("deleteMany 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
result, err := coll.DeleteMany(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量删除失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"deletedCount": result.DeletedCount,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的操作: %s,支持的操作: find, count, insertOne, insertMany, updateOne, updateMany, deleteOne, deleteMany", op)
|
||||
}
|
||||
}
|
||||
|
||||
// PreviewCollectionIndexes 预览集合索引变更,只生成命令列表不执行
|
||||
func (c *MongoClient) PreviewCollectionIndexes(ctx context.Context, database, collectionName string, structure map[string]interface{}) ([]string, error) {
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
var commands []string
|
||||
|
||||
// 获取当前索引
|
||||
currentIndexes, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
defer currentIndexes.Close(ctx)
|
||||
|
||||
// 解析新的索引数据
|
||||
var newIndexes []map[string]interface{}
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建当前索引名映射
|
||||
currentIndexMap := make(map[string]bool)
|
||||
for currentIndexes.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := currentIndexes.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
if name, ok := indexSpec["name"].(string); ok && name != "_id_" {
|
||||
currentIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新索引名映射
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if name, ok := idx["name"].(string); ok && name != "" && name != "_id_" {
|
||||
newIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for name := range currentIndexMap {
|
||||
if !newIndexMap[name] {
|
||||
cmd := fmt.Sprintf("db.%s.dropIndex(\"%s\")", collectionName, name)
|
||||
commands = append(commands, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加或更新索引
|
||||
for _, idx := range newIndexes {
|
||||
name, _ := idx["name"].(string)
|
||||
if name == "" || name == "_id_" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引键
|
||||
keys := bson.D{}
|
||||
if keysData, ok := idx["keys"].(map[string]interface{}); ok {
|
||||
for k, v := range keysData {
|
||||
var order int
|
||||
if vFloat, ok := v.(float64); ok {
|
||||
order = int(vFloat)
|
||||
} else if vInt, ok := v.(int); ok {
|
||||
order = vInt
|
||||
} else {
|
||||
order = 1 // 默认升序
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k, Value: order})
|
||||
}
|
||||
} else if columnName, ok := idx["Column_name"].(string); ok && columnName != "" {
|
||||
// 兼容 MySQL 格式的索引数据
|
||||
keys = append(keys, bson.E{Key: columnName, Value: 1})
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引选项
|
||||
indexOptions := options.Index()
|
||||
indexOptions.SetName(name)
|
||||
|
||||
if unique, ok := idx["unique"].(bool); ok && unique {
|
||||
indexOptions.SetUnique(true)
|
||||
} else if nonUnique, ok := idx["Non_unique"].(float64); ok && nonUnique == 0 {
|
||||
indexOptions.SetUnique(true)
|
||||
}
|
||||
|
||||
// 如果索引已存在,先删除再创建
|
||||
if currentIndexMap[name] {
|
||||
dropCmd := fmt.Sprintf("db.%s.dropIndex(\"%s\")", collectionName, name)
|
||||
commands = append(commands, dropCmd)
|
||||
}
|
||||
|
||||
// 构建命令字符串(MongoDB shell 格式)
|
||||
keysStr := "{"
|
||||
for i, key := range keys {
|
||||
if i > 0 {
|
||||
keysStr += ", "
|
||||
}
|
||||
keysStr += fmt.Sprintf("%s: %d", key.Key, key.Value)
|
||||
}
|
||||
keysStr += "}"
|
||||
|
||||
optionsStr := "{name: \"" + name + "\""
|
||||
if indexOptions.Unique != nil && *indexOptions.Unique {
|
||||
optionsStr += ", unique: true"
|
||||
}
|
||||
optionsStr += "}"
|
||||
|
||||
cmd := fmt.Sprintf("db.%s.createIndex(%s, %s)", collectionName, keysStr, optionsStr)
|
||||
commands = append(commands, cmd)
|
||||
}
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
|
||||
// UpdateCollectionIndexes 更新集合索引,返回执行的命令列表
|
||||
func (c *MongoClient) UpdateCollectionIndexes(ctx context.Context, database, collectionName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 先预览生成命令列表
|
||||
commands, err := c.PreviewCollectionIndexes(ctx, database, collectionName, structure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
|
||||
// 获取当前索引
|
||||
currentIndexes, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
defer currentIndexes.Close(ctx)
|
||||
|
||||
// 解析新的索引数据
|
||||
var newIndexes []map[string]interface{}
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建当前索引名映射
|
||||
currentIndexMap := make(map[string]bool)
|
||||
for currentIndexes.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := currentIndexes.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
if name, ok := indexSpec["name"].(string); ok && name != "_id_" {
|
||||
currentIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新索引名映射
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if name, ok := idx["name"].(string); ok && name != "" && name != "_id_" {
|
||||
newIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for name := range currentIndexMap {
|
||||
if !newIndexMap[name] {
|
||||
_, err := coll.Indexes().DropOne(ctx, name)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("删除索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加或更新索引
|
||||
for _, idx := range newIndexes {
|
||||
name, _ := idx["name"].(string)
|
||||
if name == "" || name == "_id_" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引键
|
||||
keys := bson.D{}
|
||||
if keysData, ok := idx["keys"].(map[string]interface{}); ok {
|
||||
for k, v := range keysData {
|
||||
var order int
|
||||
if vFloat, ok := v.(float64); ok {
|
||||
order = int(vFloat)
|
||||
} else if vInt, ok := v.(int); ok {
|
||||
order = vInt
|
||||
} else {
|
||||
order = 1 // 默认升序
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k, Value: order})
|
||||
}
|
||||
} else if columnName, ok := idx["Column_name"].(string); ok && columnName != "" {
|
||||
// 兼容 MySQL 格式的索引数据
|
||||
keys = append(keys, bson.E{Key: columnName, Value: 1})
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引选项
|
||||
indexOptions := options.Index()
|
||||
indexOptions.SetName(name)
|
||||
|
||||
if unique, ok := idx["unique"].(bool); ok && unique {
|
||||
indexOptions.SetUnique(true)
|
||||
} else if nonUnique, ok := idx["Non_unique"].(float64); ok && nonUnique == 0 {
|
||||
indexOptions.SetUnique(true)
|
||||
}
|
||||
|
||||
// 创建索引
|
||||
indexModel := mongo.IndexModel{
|
||||
Keys: keys,
|
||||
Options: indexOptions,
|
||||
}
|
||||
|
||||
// 如果索引已存在,先删除再创建
|
||||
if currentIndexMap[name] {
|
||||
_, err := coll.Indexes().DropOne(ctx, name)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("删除旧索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := coll.Indexes().CreateOne(ctx, indexModel)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("创建索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
875
internal/dbclient/mysql.go
Normal file
875
internal/dbclient/mysql.go
Normal file
@@ -0,0 +1,875 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// MySQLClient MySQL 客户端
|
||||
type MySQLClient struct {
|
||||
db *gorm.DB
|
||||
sqlDB *sql.DB
|
||||
config *MySQLConfig
|
||||
}
|
||||
|
||||
// MySQLConfig MySQL 配置
|
||||
type MySQLConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
}
|
||||
|
||||
// NewMySQLClient 创建 MySQL 客户端
|
||||
func NewMySQLClient(config *MySQLConfig) (*MySQLClient, error) {
|
||||
// 构建 DSN
|
||||
mysqlConfig := mysqldriver.Config{
|
||||
User: config.Username,
|
||||
Passwd: config.Password,
|
||||
Net: "tcp",
|
||||
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
|
||||
DBName: config.Database,
|
||||
Params: map[string]string{
|
||||
"charset": "utf8mb4",
|
||||
"parseTime": "True",
|
||||
"loc": "Local",
|
||||
"multiStatements": "true", // 支持多条SQL语句执行
|
||||
},
|
||||
AllowNativePasswords: true,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
dsn := mysqlConfig.FormatDSN()
|
||||
|
||||
// GORM 配置
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
// 打开连接
|
||||
db, err := gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 MySQL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取底层 sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("MySQL 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
sqlDB.SetMaxIdleConns(2)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
return &MySQLClient{
|
||||
db: db,
|
||||
sqlDB: sqlDB,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestConnection 测试连接
|
||||
func TestMySQLConnection(host string, port int, username, password, database string) error {
|
||||
config := &MySQLConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
}
|
||||
client, err := NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *MySQLClient) Close() error {
|
||||
if c.sqlDB != nil {
|
||||
return c.sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryResult 查询结果,包含数据和列顺序
|
||||
type QueryResult struct {
|
||||
Data []map[string]interface{}
|
||||
Columns []string
|
||||
}
|
||||
|
||||
// ExecuteQuery 执行查询 SQL
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
// 注意:SQL 语句应该已经包含 LIMIT 和 OFFSET(由客户端添加)
|
||||
func (c *MySQLClient) ExecuteQuery(ctx context.Context, sqlStr string, database string) (*QueryResult, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return nil, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.Raw(sqlStr).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("执行查询失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 检查 rows 错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("查询结果错误: %v", err)
|
||||
}
|
||||
|
||||
// 获取列名
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有列,返回空数组
|
||||
if len(columns) == 0 {
|
||||
return &QueryResult{
|
||||
Data: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
// 创建值数组和指针数组
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
// 扫描行数据
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 构建结果 map,按照列顺序构建
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
// 处理 nil 值
|
||||
if val == nil {
|
||||
row[col] = nil
|
||||
} else if b, ok := val.([]byte); ok {
|
||||
// 处理 []byte 类型
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
// 检查迭代过程中的错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("读取数据时发生错误: %v", err)
|
||||
}
|
||||
|
||||
return &QueryResult{
|
||||
Data: results,
|
||||
Columns: columns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecuteUpdate 执行更新 SQL(INSERT/UPDATE/DELETE)
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
func (c *MySQLClient) ExecuteUpdate(ctx context.Context, sqlStr string, database string) (int64, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return 0, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Exec(sqlStr)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("执行更新失败: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表
|
||||
func (c *MySQLClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
var databases []string
|
||||
err := c.db.Raw("SHOW DATABASES").Scan(&databases).Error
|
||||
return databases, err
|
||||
}
|
||||
|
||||
// ListTables 获取表列表
|
||||
func (c *MySQLClient) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
var tables []string
|
||||
query := "SHOW TABLES"
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW TABLES FROM `%s`", database)
|
||||
}
|
||||
err := c.db.Raw(query).Scan(&tables).Error
|
||||
return tables, err
|
||||
}
|
||||
|
||||
// GetTableStructure 获取表结构
|
||||
func (c *MySQLClient) GetTableStructure(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
// 使用 SHOW FULL COLUMNS 来获取包含 comment 的完整字段信息
|
||||
query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName)
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", database, tableName)
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取表结构失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else if val == nil {
|
||||
row[col] = nil
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 Comment 字段存在(SHOW FULL COLUMNS 返回的字段名是 Comment)
|
||||
if _, ok := row["Comment"]; !ok {
|
||||
row["Comment"] = ""
|
||||
}
|
||||
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetIndexes 获取索引列表
|
||||
func (c *MySQLClient) GetIndexes(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
query := "SHOW INDEX FROM "
|
||||
if database != "" {
|
||||
query += fmt.Sprintf("`%s`.", database)
|
||||
}
|
||||
query += fmt.Sprintf("`%s`", tableName)
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取索引列表失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// PreviewTableStructure 预览表结构变更,只生成 SQL 语句不执行
|
||||
func (c *MySQLClient) PreviewTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 获取当前表结构
|
||||
currentColumns, err := c.GetTableStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前表结构失败: %v", err)
|
||||
}
|
||||
|
||||
currentIndexes, err := c.GetIndexes(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析新的结构数据
|
||||
var newColumns []map[string]interface{}
|
||||
var newIndexes []map[string]interface{}
|
||||
|
||||
if cols, ok := structure["columns"].([]interface{}); ok {
|
||||
for _, col := range cols {
|
||||
if colMap, ok := col.(map[string]interface{}); ok {
|
||||
newColumns = append(newColumns, colMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ALTER TABLE 语句
|
||||
var alterStatements []string
|
||||
|
||||
// 处理字段变更
|
||||
alterStatements = append(alterStatements, c.buildColumnAlterStatements(tableName, currentColumns, newColumns)...)
|
||||
|
||||
// 处理索引变更
|
||||
alterStatements = append(alterStatements, c.buildIndexAlterStatements(tableName, currentIndexes, newIndexes)...)
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// UpdateTableStructure 更新表结构,返回生成的 SQL 语句列表
|
||||
func (c *MySQLClient) UpdateTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 先预览生成 SQL 语句
|
||||
alterStatements, err := c.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行所有 ALTER TABLE 语句
|
||||
if len(alterStatements) > 0 {
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, stmt := range alterStatements {
|
||||
if err := db.Exec(stmt).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("执行 ALTER TABLE 失败: %v, SQL: %s", err, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// buildColumnAlterStatements 构建字段变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildColumnAlterStatements(tableName string, currentColumns, newColumns []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建字段名映射和顺序映射
|
||||
currentFieldMap := make(map[string]map[string]interface{})
|
||||
currentFieldOrder := make([]string, 0, len(currentColumns))
|
||||
for _, col := range currentColumns {
|
||||
if field, ok := col["Field"].(string); ok {
|
||||
currentFieldMap[field] = col
|
||||
currentFieldOrder = append(currentFieldOrder, field)
|
||||
}
|
||||
}
|
||||
|
||||
newFieldMap := make(map[string]bool)
|
||||
newFieldOrder := make([]string, 0, len(newColumns))
|
||||
newColumnsMap := make(map[string]map[string]interface{})
|
||||
for _, col := range newColumns {
|
||||
if field, ok := col["Field"].(string); ok && field != "" {
|
||||
newFieldMap[field] = true
|
||||
newFieldOrder = append(newFieldOrder, field)
|
||||
newColumnsMap[field] = col
|
||||
}
|
||||
}
|
||||
|
||||
// 检测字段重命名:优先使用位置匹配,如果位置相同但字段名不同,认为是重命名
|
||||
renameMap := make(map[string]string) // oldName -> newName
|
||||
processedNewFields := make(map[string]bool)
|
||||
|
||||
// 第一步:使用位置匹配检测重命名(最可靠)
|
||||
for oldIndex, oldFieldName := range currentFieldOrder {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
|
||||
// 检查新字段列表中相同位置是否有字段
|
||||
if oldIndex < len(newFieldOrder) {
|
||||
newFieldName := newFieldOrder[oldIndex]
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent && !processedNewFields[newFieldName] {
|
||||
// 新字段不在当前字段列表中,且位置相同,很可能是重命名
|
||||
// 进一步验证:检查类型是否相同(类型相同更可能是重命名)
|
||||
oldCol := currentFieldMap[oldFieldName]
|
||||
newCol := newColumnsMap[newFieldName]
|
||||
oldType := getStringValue(oldCol["Type"])
|
||||
newType := getStringValue(newCol["Type"])
|
||||
|
||||
// 如果类型相同,认为是重命名
|
||||
if oldType == newType {
|
||||
renameMap[oldFieldName] = newFieldName
|
||||
processedNewFields[newFieldName] = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二步:对于未匹配的字段,使用属性匹配(兼容旧逻辑)
|
||||
for oldFieldName, oldCol := range currentFieldMap {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
if renameMap[oldFieldName] != "" {
|
||||
continue // 已经通过位置匹配识别为重命名
|
||||
}
|
||||
|
||||
// 查找属性完全匹配的新字段
|
||||
var matchedNewField string
|
||||
for newFieldName, newCol := range newColumnsMap {
|
||||
if processedNewFields[newFieldName] {
|
||||
continue // 已经被匹配过了
|
||||
}
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent {
|
||||
// 这是一个新增字段,检查属性是否匹配
|
||||
if c.isColumnPropertiesEqual(oldCol, newCol) {
|
||||
if matchedNewField == "" {
|
||||
matchedNewField = newFieldName
|
||||
} else {
|
||||
// 有多个匹配,无法确定,不认为是重命名
|
||||
matchedNewField = ""
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到唯一匹配,认为是重命名
|
||||
if matchedNewField != "" {
|
||||
renameMap[oldFieldName] = matchedNewField
|
||||
processedNewFields[matchedNewField] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理字段重命名
|
||||
for oldName, newName := range renameMap {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` RENAME COLUMN `%s` TO `%s`", tableName, oldName, newName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
// 处理字段添加、修改和位置调整(排除已重命名的字段)
|
||||
for i, newCol := range newColumns {
|
||||
field, _ := newCol["Field"].(string)
|
||||
if field == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否是重命名的字段
|
||||
isRenamed := false
|
||||
var oldName string
|
||||
for old, new := range renameMap {
|
||||
if new == field {
|
||||
isRenamed = true
|
||||
oldName = old
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isRenamed {
|
||||
// 重命名的字段:如果属性有变化,需要 MODIFY COLUMN
|
||||
oldCol := currentFieldMap[oldName]
|
||||
needsModify := c.isColumnChanged(oldCol, newCol)
|
||||
|
||||
// 检查顺序变化:使用旧字段名在 currentOrder 中查找位置,与新位置比较
|
||||
oldIndex := -1
|
||||
for idx, name := range currentFieldOrder {
|
||||
if name == oldName {
|
||||
oldIndex = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
needsReorder := (oldIndex != -1 && oldIndex != i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
// 重命名后需要修改属性或位置
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if currentCol, exists := currentFieldMap[field]; exists {
|
||||
// 修改现有字段
|
||||
needsModify := c.isColumnChanged(currentCol, newCol)
|
||||
needsReorder := c.isColumnOrderChanged(currentFieldOrder, newFieldOrder, field, i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新字段(排除重命名的字段)
|
||||
stmt := c.buildAddColumnStatement(tableName, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的字段(排除已重命名的字段)
|
||||
for field := range currentFieldMap {
|
||||
if !newFieldMap[field] && renameMap[field] == "" {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s`", tableName, field)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// buildIndexAlterStatements 构建索引变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildIndexAlterStatements(tableName string, currentIndexes, newIndexes []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建索引名映射
|
||||
currentIndexMap := make(map[string]map[string]interface{})
|
||||
for _, idx := range currentIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "PRIMARY" {
|
||||
currentIndexMap[keyName] = idx
|
||||
}
|
||||
}
|
||||
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "" && keyName != "PRIMARY" {
|
||||
newIndexMap[keyName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理索引变更
|
||||
for _, newIdx := range newIndexes {
|
||||
keyName, _ := newIdx["Key_name"].(string)
|
||||
if keyName == "" || keyName == "PRIMARY" {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentIdx, exists := currentIndexMap[keyName]; exists {
|
||||
// 修改现有索引
|
||||
if c.isIndexChanged(currentIdx, newIdx) {
|
||||
dropStmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
addStmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if addStmt != "" {
|
||||
statements = append(statements, dropStmt)
|
||||
statements = append(statements, addStmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新索引
|
||||
stmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for keyName := range currentIndexMap {
|
||||
if !newIndexMap[keyName] {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// isColumnChanged 检查字段是否发生变化(不包括字段名)
|
||||
func (c *MySQLClient) isColumnChanged(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isColumnPropertiesEqual 检查字段属性是否完全相等(不包括字段名)
|
||||
func (c *MySQLClient) isColumnPropertiesEqual(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Key", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isColumnOrderChanged 检查字段顺序是否发生变化
|
||||
func (c *MySQLClient) isColumnOrderChanged(currentOrder, newOrder []string, fieldName string, newIndex int) bool {
|
||||
// 查找字段在当前顺序中的位置
|
||||
currentIndex := -1
|
||||
for i, name := range currentOrder {
|
||||
if name == fieldName {
|
||||
currentIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果字段不存在于当前顺序中(新字段),不需要检查顺序
|
||||
if currentIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果索引相同,检查前面的字段是否相同
|
||||
if newIndex == currentIndex {
|
||||
// 检查前面的字段集合是否相同
|
||||
if newIndex > 0 {
|
||||
currentPrevFields := make(map[string]bool)
|
||||
for i := 0; i < currentIndex; i++ {
|
||||
currentPrevFields[currentOrder[i]] = true
|
||||
}
|
||||
|
||||
newPrevFields := make(map[string]bool)
|
||||
for i := 0; i < newIndex; i++ {
|
||||
newPrevFields[newOrder[i]] = true
|
||||
}
|
||||
|
||||
// 如果前面的字段集合不同,说明顺序变了
|
||||
if len(currentPrevFields) != len(newPrevFields) {
|
||||
return true
|
||||
}
|
||||
for f := range currentPrevFields {
|
||||
if !newPrevFields[f] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 索引不同,说明顺序变了
|
||||
return true
|
||||
}
|
||||
|
||||
// isIndexChanged 检查索引是否发生变化
|
||||
func (c *MySQLClient) isIndexChanged(oldIdx, newIdx map[string]interface{}) bool {
|
||||
oldCol := getStringValue(oldIdx["Column_name"])
|
||||
newCol := getStringValue(newIdx["Column_name"])
|
||||
if oldCol != newCol {
|
||||
return true
|
||||
}
|
||||
|
||||
oldUnique := getIntValue(oldIdx["Non_unique"])
|
||||
newUnique := getIntValue(newIdx["Non_unique"])
|
||||
return oldUnique != newUnique
|
||||
}
|
||||
|
||||
// buildAddColumnStatement 构建添加字段的语句
|
||||
func (c *MySQLClient) buildAddColumnStatement(tableName string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
field := getStringValue(col["Field"])
|
||||
if field == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildModifyColumnStatement 构建修改字段的语句
|
||||
func (c *MySQLClient) buildModifyColumnStatement(tableName, field string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` MODIFY COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildColumnPosition 构建字段位置子句(AFTER 或 FIRST)
|
||||
func (c *MySQLClient) buildColumnPosition(fieldOrder []string, index int) string {
|
||||
if index < 0 || index >= len(fieldOrder) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
// 第一个字段使用 FIRST
|
||||
return " FIRST"
|
||||
}
|
||||
|
||||
// 其他字段使用 AFTER 前一个字段
|
||||
prevField := fieldOrder[index-1]
|
||||
return fmt.Sprintf(" AFTER `%s`", prevField)
|
||||
}
|
||||
|
||||
// buildColumnDefinition 构建字段定义
|
||||
func (c *MySQLClient) buildColumnDefinition(col map[string]interface{}) string {
|
||||
field := getStringValue(col["Field"])
|
||||
colType := getStringValue(col["Type"])
|
||||
null := getStringValue(col["Null"])
|
||||
defaultVal := col["Default"]
|
||||
extra := getStringValue(col["Extra"])
|
||||
comment := getStringValue(col["Comment"])
|
||||
|
||||
def := fmt.Sprintf("`%s` %s", field, colType)
|
||||
|
||||
if null == "NO" {
|
||||
def += " NOT NULL"
|
||||
}
|
||||
|
||||
if defaultVal != nil {
|
||||
if defaultStr, ok := defaultVal.(string); ok {
|
||||
if defaultStr == "" {
|
||||
// 空字符串表示默认值为空字符串
|
||||
def += " DEFAULT ''"
|
||||
} else if defaultStr != "NULL" {
|
||||
// 转义单引号
|
||||
escapedDefault := strings.ReplaceAll(defaultStr, "'", "''")
|
||||
def += fmt.Sprintf(" DEFAULT '%s'", escapedDefault)
|
||||
}
|
||||
// 如果 defaultStr == "NULL",不添加 DEFAULT 子句(允许 NULL)
|
||||
} else {
|
||||
// 非字符串类型的默认值
|
||||
def += fmt.Sprintf(" DEFAULT %v", defaultVal)
|
||||
}
|
||||
}
|
||||
|
||||
if extra != "" {
|
||||
def += " " + extra
|
||||
}
|
||||
|
||||
if comment != "" {
|
||||
// 转义单引号
|
||||
escapedComment := strings.ReplaceAll(comment, "'", "''")
|
||||
def += fmt.Sprintf(" COMMENT '%s'", escapedComment)
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// buildAddIndexStatement 构建添加索引的语句
|
||||
func (c *MySQLClient) buildAddIndexStatement(tableName string, idx map[string]interface{}) string {
|
||||
keyName := getStringValue(idx["Key_name"])
|
||||
columnName := getStringValue(idx["Column_name"])
|
||||
nonUnique := getIntValue(idx["Non_unique"])
|
||||
|
||||
if keyName == "" || columnName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
indexType := "INDEX"
|
||||
if nonUnique == 0 {
|
||||
indexType = "UNIQUE INDEX"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD %s `%s` (`%s`)", tableName, indexType, keyName, columnName)
|
||||
}
|
||||
|
||||
// getStringValue 安全获取字符串值
|
||||
func getStringValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// getIntValue 安全获取整数值
|
||||
func getIntValue(v interface{}) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var i int
|
||||
fmt.Sscanf(val, "%d", &i)
|
||||
return i
|
||||
}
|
||||
return 0
|
||||
}
|
||||
236
internal/dbclient/pool.go
Normal file
236
internal/dbclient/pool.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go-desk/internal/crypto"
|
||||
"go-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// ConnectionPool 连接池管理器
|
||||
type ConnectionPool struct {
|
||||
mysqlClients map[uint]*MySQLClient
|
||||
redisClients map[uint]*RedisClient
|
||||
mongoClients map[uint]*MongoClient
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
globalPool *ConnectionPool
|
||||
poolOnce sync.Once
|
||||
)
|
||||
|
||||
// GetPool 获取全局连接池实例
|
||||
func GetPool() *ConnectionPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = &ConnectionPool{
|
||||
mysqlClients: make(map[uint]*MySQLClient),
|
||||
redisClients: make(map[uint]*RedisClient),
|
||||
mongoClients: make(map[uint]*MongoClient),
|
||||
}
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// GetMySQLClient 获取或创建 MySQL 客户端
|
||||
func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) (*MySQLClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.mysqlClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
if err := client.sqlDB.Ping(); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.mysqlClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &MySQLConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: password, // 如果密码为空,MySQL会尝试无密码连接
|
||||
Database: conn.Database,
|
||||
}
|
||||
|
||||
client, err := NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.mysqlClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetRedisClient 获取或创建 Redis 客户端
|
||||
func (p *ConnectionPool) GetRedisClient(conn *models.DbConnection) (*RedisClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.redisClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := client.client.Ping(ctx).Err(); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.redisClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析 Redis DB 编号(从 Database 字段,默认为 0)
|
||||
dbNum := 0
|
||||
if conn.Database != "" {
|
||||
// 尝试解析 Database 字段为数字
|
||||
_, err := fmt.Sscanf(conn.Database, "%d", &dbNum)
|
||||
if err != nil {
|
||||
// 如果解析失败,使用默认值 0
|
||||
dbNum = 0
|
||||
}
|
||||
// 限制 DB 编号在 0-15 之间
|
||||
if dbNum < 0 || dbNum > 15 {
|
||||
dbNum = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &RedisConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Password: password,
|
||||
DB: dbNum,
|
||||
}
|
||||
|
||||
client, err := NewRedisClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.redisClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetMongoClient 获取或创建 MongoDB 客户端
|
||||
func (p *ConnectionPool) GetMongoClient(conn *models.DbConnection) (*MongoClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.mongoClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := client.client.Ping(ctx, nil); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.mongoClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if conn.Options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(conn.Options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &MongoConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: password,
|
||||
Database: conn.Database,
|
||||
AuthSource: authSource,
|
||||
AuthMechanism: authMechanism,
|
||||
}
|
||||
|
||||
client, err := NewMongoClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.mongoClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// CloseConnection 关闭指定连接
|
||||
func (p *ConnectionPool) CloseConnection(connID uint, dbType string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
if client, ok := p.mysqlClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.mysqlClients, connID)
|
||||
}
|
||||
case "redis":
|
||||
if client, ok := p.redisClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.redisClients, connID)
|
||||
}
|
||||
case "mongo":
|
||||
if client, ok := p.mongoClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.mongoClients, connID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloseAll 关闭所有连接
|
||||
func (p *ConnectionPool) CloseAll() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for _, client := range p.mysqlClients {
|
||||
client.Close()
|
||||
}
|
||||
for _, client := range p.redisClients {
|
||||
client.Close()
|
||||
}
|
||||
for _, client := range p.mongoClients {
|
||||
client.Close()
|
||||
}
|
||||
|
||||
p.mysqlClients = make(map[uint]*MySQLClient)
|
||||
p.redisClients = make(map[uint]*RedisClient)
|
||||
p.mongoClients = make(map[uint]*MongoClient)
|
||||
}
|
||||
239
internal/dbclient/redis.go
Normal file
239
internal/dbclient/redis.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisClient Redis 客户端
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
config *RedisConfig
|
||||
}
|
||||
|
||||
// RedisConfig Redis 配置
|
||||
type RedisConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int // 数据库编号,默认 0
|
||||
}
|
||||
|
||||
// NewRedisClient 创建 Redis 客户端
|
||||
func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: config.Password,
|
||||
DB: config.DB,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
})
|
||||
|
||||
// 测试连接
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("Redis 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
return &RedisClient{
|
||||
client: rdb,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestRedisConnection 测试连接
|
||||
func TestRedisConnection(host string, port int, password string) error {
|
||||
config := &RedisConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
DB: 0,
|
||||
}
|
||||
client, err := NewRedisClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewRedisClientByDB 根据参数创建指定 DB 的 Redis 客户端(用于多 DB 场景)
|
||||
func NewRedisClientByDB(host string, port int, password string, dbNum int) (*RedisClient, error) {
|
||||
config := &RedisConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
DB: dbNum,
|
||||
}
|
||||
return NewRedisClient(config)
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *RedisClient) Close() error {
|
||||
if c.client != nil {
|
||||
return c.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteCommand 执行 Redis 命令
|
||||
func (c *RedisClient) ExecuteCommand(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
|
||||
return c.client.Do(ctx, append([]interface{}{cmd}, args...)...).Result()
|
||||
}
|
||||
|
||||
// GetKeys 获取 Key 列表(支持 pattern,使用 SCAN 代替 KEYS 以提高性能)
|
||||
func (c *RedisClient) GetKeys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
|
||||
var keys []string
|
||||
var cursor uint64
|
||||
const count = 100 // 每次扫描的数量
|
||||
|
||||
for {
|
||||
var err error
|
||||
var batch []string
|
||||
batch, cursor, err = c.client.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys = append(keys, batch...)
|
||||
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetKeyType 获取 Key 类型
|
||||
func (c *RedisClient) GetKeyType(ctx context.Context, key string) (string, error) {
|
||||
return c.client.Type(ctx, key).Result()
|
||||
}
|
||||
|
||||
// GetKeyValue 获取 Key 值
|
||||
func (c *RedisClient) GetKeyValue(ctx context.Context, key string) (interface{}, error) {
|
||||
keyType, err := c.GetKeyType(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch keyType {
|
||||
case "string":
|
||||
return c.client.Get(ctx, key).Result()
|
||||
case "list":
|
||||
return c.client.LRange(ctx, key, 0, -1).Result()
|
||||
case "set":
|
||||
return c.client.SMembers(ctx, key).Result()
|
||||
case "zset":
|
||||
// 对于有序集合,返回带分数的结果
|
||||
zMembers, err := c.client.ZRangeWithScores(ctx, key, 0, -1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为 map 格式,便于展示
|
||||
result := make([]map[string]interface{}, len(zMembers))
|
||||
for i, member := range zMembers {
|
||||
result[i] = map[string]interface{}{
|
||||
"member": member.Member,
|
||||
"score": member.Score,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
case "hash":
|
||||
return c.client.HGetAll(ctx, key).Result()
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的类型: %s", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTTL 获取 Key 的 TTL
|
||||
func (c *RedisClient) GetTTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
return c.client.TTL(ctx, key).Result()
|
||||
}
|
||||
|
||||
// GetKeyInfo 获取 Key 详细信息
|
||||
func (c *RedisClient) GetKeyInfo(ctx context.Context, key string) (map[string]interface{}, error) {
|
||||
info := map[string]interface{}{
|
||||
"key": key,
|
||||
"type": "",
|
||||
"value": nil,
|
||||
"ttl": 0,
|
||||
}
|
||||
|
||||
// 获取 Key 类型
|
||||
keyType, err := c.GetKeyType(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Key 类型失败: %v", err)
|
||||
}
|
||||
info["type"] = keyType
|
||||
|
||||
// 获取 TTL
|
||||
ttl, err := c.GetTTL(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 TTL 失败: %v", err)
|
||||
}
|
||||
info["ttl"] = ttl.Seconds()
|
||||
|
||||
// 获取 Key 值(限制大小,避免过大)
|
||||
value, err := c.GetKeyValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Key 值失败: %v", err)
|
||||
}
|
||||
info["value"] = formatValuePreview(value)
|
||||
|
||||
// 获取 Key 长度(使用 STRLEN、HLEN、SCARD、ZCARD)
|
||||
var keyLength int64
|
||||
switch keyType {
|
||||
case "string":
|
||||
keyLength, err = c.client.StrLen(ctx, key).Result()
|
||||
case "list":
|
||||
keyLength, err = c.client.LLen(ctx, key).Result()
|
||||
case "set":
|
||||
keyLength, err = c.client.SCard(ctx, key).Result()
|
||||
case "zset":
|
||||
keyLength, err = c.client.ZCard(ctx, key).Result()
|
||||
case "hash":
|
||||
keyLength, err = c.client.HLen(ctx, key).Result()
|
||||
}
|
||||
if err == nil {
|
||||
info["length"] = keyLength
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// formatValuePreview 格式化值预览(限制长度)
|
||||
func formatValuePreview(value interface{}) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
const maxPreviewLength = 200
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
if len(valueStr) > maxPreviewLength {
|
||||
valueStr = valueStr[:maxPreviewLength] + "..."
|
||||
}
|
||||
|
||||
return valueStr
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表(Redis 使用 DB number)
|
||||
// Redis 没有传统数据库概念,这里返回空数组
|
||||
func (c *RedisClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
// Redis 可以使用 DB number 来隔离数据
|
||||
// 这里可以返回当前配置的 DB 或者所有可用的 DB
|
||||
// 为简单起见,返回空数组,让用户直接操作 Key
|
||||
return []string{}, nil
|
||||
}
|
||||
187
internal/service/connection_service.go
Normal file
187
internal/service/connection_service.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"go-desk/internal/crypto"
|
||||
"go-desk/internal/dbclient"
|
||||
"go-desk/internal/storage/models"
|
||||
"go-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// ConnectionService 连接管理服务
|
||||
type ConnectionService struct {
|
||||
repo repository.ConnectionRepository
|
||||
}
|
||||
|
||||
// NewConnectionService 创建连接服务
|
||||
func NewConnectionService() (*ConnectionService, error) {
|
||||
repo, err := repository.NewConnectionRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建连接仓库失败: %v", err)
|
||||
}
|
||||
return &ConnectionService{repo: repo}, nil
|
||||
}
|
||||
|
||||
// SaveConnection 保存连接配置
|
||||
func (s *ConnectionService) SaveConnection(conn *models.DbConnection) error {
|
||||
// 验证
|
||||
if conn.Name == "" {
|
||||
return fmt.Errorf("连接名称不能为空")
|
||||
}
|
||||
if conn.Type == "" {
|
||||
return fmt.Errorf("数据库类型不能为空")
|
||||
}
|
||||
if conn.Host == "" {
|
||||
return fmt.Errorf("主机地址不能为空")
|
||||
}
|
||||
|
||||
// 检查名称是否重复
|
||||
existing, err := s.repo.FindByName(conn.Name, conn.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查连接名称失败: %v", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return fmt.Errorf("连接名称已存在")
|
||||
}
|
||||
|
||||
// 处理密码
|
||||
if conn.ID > 0 {
|
||||
if conn.Password == "" {
|
||||
// 更新模式:保留原密码
|
||||
conn.Password, err = s.getPassword(conn.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 加密新密码
|
||||
conn.Password, err = crypto.EncryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 新增模式:加密密码
|
||||
conn.Password, err = crypto.EncryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.repo.Save(conn)
|
||||
}
|
||||
|
||||
// getPassword 获取原始密码
|
||||
func (s *ConnectionService) getPassword(id uint) (string, error) {
|
||||
existing, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取原连接配置失败: %v", err)
|
||||
}
|
||||
return existing.Password, nil
|
||||
}
|
||||
|
||||
// ListConnections 获取连接列表
|
||||
func (s *ConnectionService) ListConnections() ([]models.DbConnection, error) {
|
||||
return s.repo.FindAll()
|
||||
}
|
||||
|
||||
// GetConnection 获取连接详情
|
||||
func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error) {
|
||||
return s.repo.FindByID(id)
|
||||
}
|
||||
|
||||
// DeleteConnection 删除连接配置
|
||||
func (s *ConnectionService) DeleteConnection(id uint) error {
|
||||
return s.repo.Delete(id)
|
||||
}
|
||||
|
||||
// TestConnection 测试连接(通过已保存的连接ID)
|
||||
func (s *ConnectionService) TestConnection(id uint) error {
|
||||
conn, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 解密密码用于测试
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 根据类型测试连接
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
return dbclient.TestMySQLConnection(conn.Host, conn.Port, conn.Username, password, conn.Database)
|
||||
case "redis":
|
||||
return dbclient.TestRedisConnection(conn.Host, conn.Port, password)
|
||||
case "mongo":
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if conn.Options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(conn.Options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
return dbclient.TestMongoConnectionWithOptions(conn.Host, conn.Port, conn.Username, password, conn.Database, authSource, authMechanism)
|
||||
default:
|
||||
return fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionWithParams 测试连接(直接传入参数,不保存数据)
|
||||
func (s *ConnectionService) TestConnectionWithParams(connType, host string, port int, username, password, database, options string, existingId uint) error {
|
||||
// 验证必填项
|
||||
if connType == "" {
|
||||
return fmt.Errorf("数据库类型不能为空")
|
||||
}
|
||||
if host == "" {
|
||||
return fmt.Errorf("主机地址不能为空")
|
||||
}
|
||||
|
||||
// 如果是编辑模式且密码为空,尝试获取已保存的密码
|
||||
actualPassword := password
|
||||
if existingId > 0 && password == "" {
|
||||
conn, err := s.repo.FindByID(existingId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取原连接配置失败: %v", err)
|
||||
}
|
||||
// 解密原密码
|
||||
actualPassword, err = crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 根据类型测试连接
|
||||
switch connType {
|
||||
case "mysql":
|
||||
return dbclient.TestMySQLConnection(host, port, username, actualPassword, database)
|
||||
case "redis":
|
||||
return dbclient.TestRedisConnection(host, port, actualPassword)
|
||||
case "mongo":
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
return dbclient.TestMongoConnectionWithOptions(host, port, username, actualPassword, database, authSource, authMechanism)
|
||||
default:
|
||||
return fmt.Errorf("不支持的数据库类型: %s", connType)
|
||||
}
|
||||
}
|
||||
467
internal/service/sql_exec_service.go
Normal file
467
internal/service/sql_exec_service.go
Normal file
@@ -0,0 +1,467 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go-desk/internal/dbclient"
|
||||
"go-desk/internal/storage/models"
|
||||
"go-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// SqlExecService SQL执行服务
|
||||
type SqlExecService struct {
|
||||
connRepo repository.ConnectionRepository
|
||||
pool *dbclient.ConnectionPool
|
||||
}
|
||||
|
||||
// NewSqlExecService 创建SQL执行服务
|
||||
func NewSqlExecService() (*SqlExecService, error) {
|
||||
connRepo, err := repository.NewConnectionRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SqlExecService{
|
||||
connRepo: connRepo,
|
||||
pool: dbclient.GetPool(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SqlResult SQL执行结果
|
||||
type SqlResult struct {
|
||||
Type string `json:"type"` // query/update/command
|
||||
Data interface{} `json:"data"` // 查询结果数据
|
||||
Columns []string `json:"columns"` // 列顺序(仅查询时有效)
|
||||
RowsAffected int `json:"rowsAffected"` // 影响行数
|
||||
ExecutionTime int64 `json:"executionTime"` // 执行时间(毫秒)
|
||||
}
|
||||
|
||||
// ExecuteSQL 执行SQL语句
|
||||
// 注意:SQL 语句应该已经包含分页信息(LIMIT 和 OFFSET),由客户端添加
|
||||
func (s *SqlExecService) ExecuteSQL(connectionID uint, sqlStr string, database string) (*SqlResult, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
return s.executeMySQL(ctx, conn, sqlStr, database, startTime)
|
||||
case "redis":
|
||||
return s.executeRedis(ctx, conn, sqlStr, startTime)
|
||||
case "mongo":
|
||||
return s.executeMongo(ctx, conn, sqlStr, database, startTime)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// executeMySQL 执行MySQL SQL
|
||||
func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
}
|
||||
|
||||
sqlStr = strings.TrimSpace(sqlStr)
|
||||
sqlUpper := strings.ToUpper(sqlStr)
|
||||
|
||||
// 获取数据库参数
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = conn.Database
|
||||
}
|
||||
|
||||
result := &SqlResult{
|
||||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// 判断是查询还是更新
|
||||
if strings.HasPrefix(sqlUpper, "SELECT") || strings.HasPrefix(sqlUpper, "SHOW") ||
|
||||
strings.HasPrefix(sqlUpper, "DESCRIBE") || strings.HasPrefix(sqlUpper, "DESC") ||
|
||||
strings.HasPrefix(sqlUpper, "EXPLAIN") {
|
||||
// 查询语句
|
||||
queryResult, err := client.ExecuteQuery(ctx, sqlStr, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Type = "query"
|
||||
result.Data = queryResult.Data
|
||||
result.Columns = queryResult.Columns
|
||||
result.RowsAffected = len(queryResult.Data)
|
||||
} else {
|
||||
// 更新语句
|
||||
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Type = "update"
|
||||
result.RowsAffected = int(rowsAffected)
|
||||
result.Data = nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// executeRedis 执行Redis命令
|
||||
func (s *SqlExecService) executeRedis(ctx context.Context, conn *models.DbConnection, sqlStr string, startTime time.Time) (*SqlResult, error) {
|
||||
client, err := s.pool.GetRedisClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析Redis命令
|
||||
parts := parseRedisCommand(sqlStr)
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("Redis 命令不能为空")
|
||||
}
|
||||
|
||||
cmd := strings.ToUpper(parts[0])
|
||||
args := make([]interface{}, 0)
|
||||
for i := 1; i < len(parts); i++ {
|
||||
args = append(args, parts[i])
|
||||
}
|
||||
|
||||
data, err := client.ExecuteCommand(ctx, cmd, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SqlResult{
|
||||
Type: "command",
|
||||
Data: data,
|
||||
RowsAffected: 1,
|
||||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// executeMongo 执行MongoDB命令
|
||||
func (s *SqlExecService) executeMongo(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析MongoDB命令(JSON格式)
|
||||
var command map[string]interface{}
|
||||
sqlStr = strings.TrimSpace(sqlStr)
|
||||
if err := json.Unmarshal([]byte(sqlStr), &command); err != nil {
|
||||
return nil, fmt.Errorf("MongoDB 命令必须是有效的 JSON 格式: %v", err)
|
||||
}
|
||||
|
||||
// 确定数据库
|
||||
dbName := conn.Database
|
||||
if db, ok := command["database"].(string); ok && db != "" {
|
||||
dbName = db
|
||||
}
|
||||
if database != "" {
|
||||
dbName = database
|
||||
}
|
||||
if dbName == "" {
|
||||
return nil, fmt.Errorf("需要指定数据库名称")
|
||||
}
|
||||
|
||||
// 执行命令
|
||||
data, err := client.ExecuteCommand(ctx, dbName, command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SqlResult{
|
||||
Type: "command",
|
||||
Data: data,
|
||||
ExecutionTime: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// 根据操作类型确定影响行数
|
||||
if op, ok := command["op"].(string); ok {
|
||||
switch op {
|
||||
case "find":
|
||||
if results, ok := data.([]map[string]interface{}); ok {
|
||||
result.RowsAffected = len(results)
|
||||
}
|
||||
case "count":
|
||||
if count, ok := data.(int64); ok {
|
||||
result.RowsAffected = int(count)
|
||||
}
|
||||
case "insertOne", "deleteOne":
|
||||
result.RowsAffected = 1
|
||||
case "insertMany":
|
||||
if resultMap, ok := data.(map[string]interface{}); ok {
|
||||
if count, ok := resultMap["insertedCount"].(int); ok {
|
||||
result.RowsAffected = count
|
||||
}
|
||||
}
|
||||
default:
|
||||
result.RowsAffected = 0
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetDatabases 获取数据库列表
|
||||
func (s *SqlExecService) GetDatabases(connectionID uint) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
}
|
||||
return client.ListDatabases(ctx)
|
||||
case "redis":
|
||||
databases := make([]string, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
databases[i] = fmt.Sprintf("%d", i)
|
||||
}
|
||||
return databases, nil
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.ListDatabases(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTables 获取表列表(MySQL/MongoDB)或Key列表(Redis)
|
||||
func (s *SqlExecService) GetTables(connectionID uint, database string) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
}
|
||||
return client.ListTables(ctx, database)
|
||||
case "redis":
|
||||
client, err := s.pool.GetRedisClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||||
}
|
||||
return client.GetKeys(ctx, database)
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.ListCollections(ctx, database)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// parseRedisCommand 解析Redis命令
|
||||
func parseRedisCommand(cmd string) []string {
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
if cmd == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
inQuotes := false
|
||||
quoteChar := byte(0)
|
||||
|
||||
for i := 0; i < len(cmd); i++ {
|
||||
char := cmd[i]
|
||||
if !inQuotes {
|
||||
if char == '"' || char == '\'' {
|
||||
inQuotes = true
|
||||
quoteChar = char
|
||||
} else if char == ' ' || char == '\t' {
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
} else {
|
||||
current.WriteByte(char)
|
||||
}
|
||||
} else {
|
||||
if char == quoteChar {
|
||||
inQuotes = false
|
||||
quoteChar = 0
|
||||
} else {
|
||||
current.WriteByte(char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// GetTableStructure 获取表结构
|
||||
func (s *SqlExecService) GetTableStructure(connectionID uint, database, tableName string) (map[string]interface{}, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
}
|
||||
structure, err := client.GetTableStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "mysql",
|
||||
"database": database,
|
||||
"table": tableName,
|
||||
"columns": structure,
|
||||
}, nil
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
structure, err := client.GetCollectionStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "mongo",
|
||||
"database": database,
|
||||
"collection": tableName,
|
||||
"structure": structure,
|
||||
}, nil
|
||||
|
||||
case "redis":
|
||||
client, err := s.pool.GetRedisClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Redis 客户端失败: %v", err)
|
||||
}
|
||||
info, err := client.GetKeyInfo(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "redis",
|
||||
"key": tableName,
|
||||
"info": info,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// GetIndexes 获取索引列表
|
||||
func (s *SqlExecService) GetIndexes(connectionID uint, database, tableName string) ([]map[string]interface{}, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
}
|
||||
return client.GetIndexes(ctx, database, tableName)
|
||||
|
||||
case "mongo", "redis":
|
||||
return []map[string]interface{}{}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// PreviewTableStructure 预览表结构变更
|
||||
func (s *SqlExecService) PreviewTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
}
|
||||
return client.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.PreviewCollectionIndexes(ctx, database, tableName, structure)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTableStructure 更新表结构
|
||||
func (s *SqlExecService) UpdateTableStructure(connectionID uint, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
conn, err := s.connRepo.FindByID(connectionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
}
|
||||
return client.UpdateTableStructure(ctx, database, tableName, structure)
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MongoDB 客户端失败: %v", err)
|
||||
}
|
||||
return client.UpdateCollectionIndexes(ctx, database, tableName, structure)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
36
internal/service/tab_service.go
Normal file
36
internal/service/tab_service.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go-desk/internal/storage/models"
|
||||
"go-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// TabService 标签页管理服务
|
||||
type TabService struct {
|
||||
repo repository.TabRepository
|
||||
}
|
||||
|
||||
// NewTabService 创建标签页服务
|
||||
func NewTabService() (*TabService, error) {
|
||||
repo, err := repository.NewTabRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建标签页仓库失败: %v", err)
|
||||
}
|
||||
return &TabService{repo: repo}, nil
|
||||
}
|
||||
|
||||
// SaveTabs 保存标签页列表
|
||||
func (s *TabService) SaveTabs(tabs []models.SqlTab) error {
|
||||
return s.repo.SaveAll(tabs)
|
||||
}
|
||||
|
||||
// ListTabs 获取标签页列表
|
||||
func (s *TabService) ListTabs() ([]models.SqlTab, error) {
|
||||
return s.repo.FindAll()
|
||||
}
|
||||
|
||||
// DeleteTab 删除标签页
|
||||
func (s *TabService) DeleteTab(id uint) error {
|
||||
return s.repo.Delete(id)
|
||||
}
|
||||
165
internal/storage/connection_service.go
Normal file
165
internal/storage/connection_service.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"go-desk/internal/crypto"
|
||||
"go-desk/internal/dbclient"
|
||||
"go-desk/internal/storage/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ConnectionService 连接管理服务
|
||||
type ConnectionService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewConnectionService 创建连接服务
|
||||
func NewConnectionService() (*ConnectionService, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
// 尝试重新初始化
|
||||
var err error
|
||||
db, err = Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("数据库初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
return &ConnectionService{db: db}, nil
|
||||
}
|
||||
|
||||
// SaveConnection 保存连接配置
|
||||
func (s *ConnectionService) SaveConnection(conn *models.DbConnection) error {
|
||||
if conn.Name == "" {
|
||||
return fmt.Errorf("连接名称不能为空")
|
||||
}
|
||||
if conn.Type == "" {
|
||||
return fmt.Errorf("数据库类型不能为空")
|
||||
}
|
||||
if conn.Host == "" {
|
||||
return fmt.Errorf("主机地址不能为空")
|
||||
}
|
||||
|
||||
// 检查名称是否重复(排除当前记录)
|
||||
var count int64
|
||||
query := s.db.Model(&models.DbConnection{}).Where("name = ?", conn.Name)
|
||||
if conn.ID > 0 {
|
||||
query = query.Where("id != ?", conn.ID)
|
||||
}
|
||||
query.Count(&count)
|
||||
if count > 0 {
|
||||
return fmt.Errorf("连接名称已存在")
|
||||
}
|
||||
|
||||
if conn.ID > 0 {
|
||||
// 更新模式
|
||||
updateData := map[string]interface{}{
|
||||
"name": conn.Name,
|
||||
"type": conn.Type,
|
||||
"host": conn.Host,
|
||||
"port": conn.Port,
|
||||
"username": conn.Username,
|
||||
"database": conn.Database,
|
||||
"options": conn.Options,
|
||||
}
|
||||
|
||||
// 如果提供了新密码,加密后更新
|
||||
if conn.Password != "" {
|
||||
encrypted, err := crypto.EncryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %v", err)
|
||||
}
|
||||
updateData["password"] = encrypted
|
||||
}
|
||||
// 如果密码为空,不更新密码字段(保留原密码)
|
||||
|
||||
return s.db.Model(&models.DbConnection{}).Where("id = ?", conn.ID).Updates(updateData).Error
|
||||
}
|
||||
|
||||
// 新增模式 - 必须提供密码
|
||||
if conn.Password == "" {
|
||||
return fmt.Errorf("新增连接时密码不能为空")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
encrypted, err := crypto.EncryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %v", err)
|
||||
}
|
||||
conn.Password = encrypted
|
||||
|
||||
return s.db.Create(conn).Error
|
||||
}
|
||||
|
||||
// ListConnections 获取连接列表
|
||||
func (s *ConnectionService) ListConnections() ([]models.DbConnection, error) {
|
||||
var connections []models.DbConnection
|
||||
err := s.db.Order("created_at DESC").Find(&connections).Error
|
||||
return connections, err
|
||||
}
|
||||
|
||||
// GetConnection 获取连接详情
|
||||
func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error) {
|
||||
var conn models.DbConnection
|
||||
err := s.db.First(&conn, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conn, nil
|
||||
}
|
||||
|
||||
// DeleteConnection 删除连接配置
|
||||
func (s *ConnectionService) DeleteConnection(id uint) error {
|
||||
return s.db.Delete(&models.DbConnection{}, id).Error
|
||||
}
|
||||
|
||||
// TestConnection 测试连接(需要根据类型调用不同的测试方法)
|
||||
func (s *ConnectionService) TestConnection(conn *models.DbConnection) error {
|
||||
// 解密密码用于测试
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 根据类型测试连接
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
return testMySQLConnection(conn.Host, conn.Port, conn.Username, password, conn.Database)
|
||||
case "redis":
|
||||
return testRedisConnection(conn.Host, conn.Port, password)
|
||||
case "mongo":
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if conn.Options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(conn.Options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
return testMongoConnection(conn.Host, conn.Port, conn.Username, password, conn.Database, authSource, authMechanism)
|
||||
default:
|
||||
return fmt.Errorf("不支持的数据库类型: %s", conn.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// testMySQLConnection 测试 MySQL 连接
|
||||
func testMySQLConnection(host string, port int, username, password, database string) error {
|
||||
return dbclient.TestMySQLConnection(host, port, username, password, database)
|
||||
}
|
||||
|
||||
// testRedisConnection 测试 Redis 连接
|
||||
func testRedisConnection(host string, port int, password string) error {
|
||||
return dbclient.TestRedisConnection(host, port, password)
|
||||
}
|
||||
|
||||
// testMongoConnection 测试 MongoDB 连接
|
||||
func testMongoConnection(host string, port int, username, password, database, authSource, authMechanism string) error {
|
||||
return dbclient.TestMongoConnectionWithOptions(host, port, username, password, database, authSource, authMechanism)
|
||||
}
|
||||
25
internal/storage/models/connection.go
Normal file
25
internal/storage/models/connection.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DbConnection 数据库连接配置
|
||||
type DbConnection struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(100);not null" json:"name"` // 连接名称
|
||||
Type string `gorm:"type:varchar(20);not null" json:"type"` // 数据库类型: mysql/redis/mongo
|
||||
Host string `gorm:"type:varchar(255);not null" json:"host"` // 主机地址
|
||||
Port int `gorm:"not null" json:"port"` // 端口
|
||||
Username string `gorm:"type:varchar(100)" json:"username"` // 用户名
|
||||
Password string `gorm:"type:varchar(500)" json:"-"` // 密码(加密存储,不返回)
|
||||
Database string `gorm:"type:varchar(100)" json:"database"` // 数据库名(MySQL/MongoDB)
|
||||
Options string `gorm:"type:text" json:"options"` // 额外选项(JSON格式)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (DbConnection) TableName() string {
|
||||
return "db_connection"
|
||||
}
|
||||
20
internal/storage/models/file.go
Normal file
20
internal/storage/models/file.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SqlFile SQL 文件记录
|
||||
type SqlFile struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(200);not null" json:"name"` // 文件名
|
||||
Path string `gorm:"type:varchar(500);not null;uniqueIndex" json:"path"` // 文件路径
|
||||
Content string `gorm:"type:text" json:"content"` // 文件内容
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SqlFile) TableName() string {
|
||||
return "sql_file"
|
||||
}
|
||||
24
internal/storage/models/sql_result_history.go
Normal file
24
internal/storage/models/sql_result_history.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SqlResultHistory SQL 执行结果历史
|
||||
type SqlResultHistory struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
ConnectionID uint `gorm:"index;not null" json:"connection_id"` // 连接ID
|
||||
Database string `gorm:"type:varchar(100)" json:"database"` // 数据库名
|
||||
Sql string `gorm:"type:text;not null" json:"sql"` // SQL语句
|
||||
Type string `gorm:"type:varchar(20);not null" json:"type"` // 结果类型: query/update/command
|
||||
Data string `gorm:"type:text" json:"data"` // 结果数据(JSON)
|
||||
Columns string `gorm:"type:text" json:"columns"` // 列信息(JSON)
|
||||
RowsAffected int `gorm:"default:0" json:"rows_affected"` // 影响行数
|
||||
ExecutionTime int64 `gorm:"default:0" json:"execution_time"` // 执行时间(毫秒)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SqlResultHistory) TableName() string {
|
||||
return "sql_result_history"
|
||||
}
|
||||
21
internal/storage/models/sql_tab.go
Normal file
21
internal/storage/models/sql_tab.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SqlTab SQL 编辑器标签页
|
||||
type SqlTab struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Title string `gorm:"type:varchar(100);not null" json:"title"` // 标签页标题
|
||||
Content string `gorm:"type:text" json:"content"` // SQL 内容
|
||||
ConnectionID *uint `gorm:"index" json:"connection_id"` // 关联的连接ID(可为空)
|
||||
Order int `gorm:"default:0" json:"order"` // 排序顺序
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SqlTab) TableName() string {
|
||||
return "sql_tab"
|
||||
}
|
||||
70
internal/storage/repository/connection_repo.go
Normal file
70
internal/storage/repository/connection_repo.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"go-desk/internal/storage"
|
||||
"go-desk/internal/storage/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConnectionRepository interface {
|
||||
Save(conn *models.DbConnection) error
|
||||
FindAll() ([]models.DbConnection, error)
|
||||
FindByID(id uint) (*models.DbConnection, error)
|
||||
Delete(id uint) error
|
||||
FindByName(name string, excludeID uint) (*models.DbConnection, error)
|
||||
}
|
||||
|
||||
type connectionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewConnectionRepository() (ConnectionRepository, error) {
|
||||
db := storage.GetDB()
|
||||
if db == nil {
|
||||
var err error
|
||||
db, err = storage.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &connectionRepository{db}, nil
|
||||
}
|
||||
|
||||
func (r *connectionRepository) Save(conn *models.DbConnection) error {
|
||||
if conn.ID > 0 {
|
||||
return r.db.Model(&models.DbConnection{}).Where("id = ?", conn.ID).Updates(conn).Error
|
||||
}
|
||||
return r.db.Create(conn).Error
|
||||
}
|
||||
|
||||
func (r *connectionRepository) FindAll() ([]models.DbConnection, error) {
|
||||
var connections []models.DbConnection
|
||||
return connections, r.db.Order("created_at DESC").Find(&connections).Error
|
||||
}
|
||||
|
||||
func (r *connectionRepository) FindByID(id uint) (*models.DbConnection, error) {
|
||||
var conn models.DbConnection
|
||||
err := r.db.First(&conn, id).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &conn, err
|
||||
}
|
||||
|
||||
func (r *connectionRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&models.DbConnection{}, id).Error
|
||||
}
|
||||
|
||||
func (r *connectionRepository) FindByName(name string, excludeID uint) (*models.DbConnection, error) {
|
||||
var conn models.DbConnection
|
||||
query := r.db.Where("name = ?", name)
|
||||
if excludeID > 0 {
|
||||
query = query.Where("id != ?", excludeID)
|
||||
}
|
||||
err := query.First(&conn).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &conn, err
|
||||
}
|
||||
110
internal/storage/repository/result_repo.go
Normal file
110
internal/storage/repository/result_repo.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"go-desk/internal/storage"
|
||||
"go-desk/internal/storage/models"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ResultRepository interface {
|
||||
Save(connectionID uint, database, sql string, resultType string, data interface{}, columns []string, rowsAffected int, executionTime int64) (*models.SqlResultHistory, error)
|
||||
FindByID(id uint) (*models.SqlResultHistory, error)
|
||||
FindByConnection(connectionID uint, limit int) ([]models.SqlResultHistory, error)
|
||||
Search(connectionID *uint, keyword string, limit, offset int) ([]models.SqlResultHistory, int64, error)
|
||||
Delete(id uint) error
|
||||
DeleteByConnection(connectionID uint) error
|
||||
DeleteOld(keepDays int) error
|
||||
}
|
||||
|
||||
type resultRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewResultRepository() (ResultRepository, error) {
|
||||
db := storage.GetDB()
|
||||
if db == nil {
|
||||
var err error
|
||||
db, err = storage.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &resultRepository{db}, nil
|
||||
}
|
||||
|
||||
func (r *resultRepository) Save(connectionID uint, database, sql string, resultType string, data interface{}, columns []string, rowsAffected int, executionTime int64) (*models.SqlResultHistory, error) {
|
||||
dataJSON, _ := json.Marshal(data)
|
||||
columnsJSON, _ := json.Marshal(columns)
|
||||
|
||||
history := &models.SqlResultHistory{
|
||||
ConnectionID: connectionID,
|
||||
Database: database,
|
||||
Sql: sql,
|
||||
Type: resultType,
|
||||
Data: string(dataJSON),
|
||||
Columns: string(columnsJSON),
|
||||
RowsAffected: rowsAffected,
|
||||
ExecutionTime: executionTime,
|
||||
}
|
||||
|
||||
return history, r.db.Create(history).Error
|
||||
}
|
||||
|
||||
func (r *resultRepository) FindByID(id uint) (*models.SqlResultHistory, error) {
|
||||
var history models.SqlResultHistory
|
||||
err := r.db.First(&history, id).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &history, err
|
||||
}
|
||||
|
||||
func (r *resultRepository) FindByConnection(connectionID uint, limit int) ([]models.SqlResultHistory, error) {
|
||||
var histories []models.SqlResultHistory
|
||||
query := r.db.Where("connection_id = ?", connectionID).Order("created_at DESC")
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
return histories, query.Find(&histories).Error
|
||||
}
|
||||
|
||||
func (r *resultRepository) Search(connectionID *uint, keyword string, limit, offset int) ([]models.SqlResultHistory, int64, error) {
|
||||
query := r.db.Model(&models.SqlResultHistory{})
|
||||
|
||||
if connectionID != nil {
|
||||
query = query.Where("connection_id = ?", *connectionID)
|
||||
}
|
||||
if keyword != "" {
|
||||
query = query.Where("sql LIKE ? OR database LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var histories []models.SqlResultHistory
|
||||
query = query.Order("created_at DESC")
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
return histories, total, query.Find(&histories).Error
|
||||
}
|
||||
|
||||
func (r *resultRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&models.SqlResultHistory{}, id).Error
|
||||
}
|
||||
|
||||
func (r *resultRepository) DeleteByConnection(connectionID uint) error {
|
||||
return r.db.Where("connection_id = ?", connectionID).Delete(&models.SqlResultHistory{}).Error
|
||||
}
|
||||
|
||||
func (r *resultRepository) DeleteOld(keepDays int) error {
|
||||
return r.db.Where("created_at < ?", time.Now().AddDate(0, 0, -keepDays)).Delete(&models.SqlResultHistory{}).Error
|
||||
}
|
||||
55
internal/storage/repository/tab_repo.go
Normal file
55
internal/storage/repository/tab_repo.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"go-desk/internal/storage"
|
||||
"go-desk/internal/storage/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TabRepository interface {
|
||||
SaveAll(tabs []models.SqlTab) error
|
||||
FindAll() ([]models.SqlTab, error)
|
||||
Delete(id uint) error
|
||||
DeleteAll() error
|
||||
}
|
||||
|
||||
type tabRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTabRepository() (TabRepository, error) {
|
||||
db := storage.GetDB()
|
||||
if db == nil {
|
||||
var err error
|
||||
db, err = storage.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &tabRepository{db}, nil
|
||||
}
|
||||
|
||||
func (r *tabRepository) SaveAll(tabs []models.SqlTab) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("1=1").Delete(&models.SqlTab{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tabs) > 0 {
|
||||
return tx.Create(&tabs).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *tabRepository) FindAll() ([]models.SqlTab, error) {
|
||||
var tabs []models.SqlTab
|
||||
return tabs, r.db.Order("`order` ASC, created_at ASC").Find(&tabs).Error
|
||||
}
|
||||
|
||||
func (r *tabRepository) Delete(id uint) error {
|
||||
return r.db.Delete(&models.SqlTab{}, id).Error
|
||||
}
|
||||
|
||||
func (r *tabRepository) DeleteAll() error {
|
||||
return r.db.Where("1=1").Delete(&models.SqlTab{}).Error
|
||||
}
|
||||
57
internal/storage/sqlite.go
Normal file
57
internal/storage/sqlite.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"go-desk/internal/storage/models"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var globalDB *gorm.DB
|
||||
|
||||
func Init() (*gorm.DB, error) {
|
||||
if globalDB != nil {
|
||||
return globalDB, nil
|
||||
}
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataDir := filepath.Join(homeDir, ".go-desk")
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
|
||||
dbPath := filepath.Join(dataDir, "db-cli.db")
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
if err := db.AutoMigrate(
|
||||
&models.DbConnection{},
|
||||
&models.SqlTab{},
|
||||
&models.SqlResultHistory{},
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
globalDB = db
|
||||
return globalDB, nil
|
||||
}
|
||||
|
||||
func GetDB() *gorm.DB {
|
||||
return globalDB
|
||||
}
|
||||
Reference in New Issue
Block a user