新增:连接管理、数据查询等功能
This commit is contained in:
818
internal/dbclient/mongo.go
Normal file
818
internal/dbclient/mongo.go
Normal file
@@ -0,0 +1,818 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
// MongoClient MongoDB 客户端
|
||||
type MongoClient struct {
|
||||
client *mongo.Client
|
||||
database *mongo.Database
|
||||
config *MongoConfig
|
||||
}
|
||||
|
||||
// MongoConfig MongoDB 配置
|
||||
type MongoConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
AuthSource string // 认证数据库,默认为 "admin"
|
||||
AuthMechanism string // 认证机制,如 "SCRAM-SHA-1", "SCRAM-SHA-256" 等
|
||||
}
|
||||
|
||||
// NewMongoClient 创建 MongoDB 客户端
|
||||
func NewMongoClient(config *MongoConfig) (*MongoClient, error) {
|
||||
// 确定认证数据库,默认为 admin
|
||||
authSource := config.AuthSource
|
||||
if authSource == "" {
|
||||
authSource = "admin"
|
||||
}
|
||||
|
||||
// 如果指定了认证机制,直接使用;否则尝试自动检测
|
||||
authMechanisms := []string{}
|
||||
if config.AuthMechanism != "" {
|
||||
// 用户明确指定了认证机制,只使用该机制
|
||||
authMechanisms = []string{config.AuthMechanism}
|
||||
} else {
|
||||
// 未指定时,先尝试 SCRAM-SHA-256(更安全),失败则尝试 SCRAM-SHA-1
|
||||
authMechanisms = []string{"SCRAM-SHA-256", "SCRAM-SHA-1"}
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, authMechanism := range authMechanisms {
|
||||
client, err := tryConnectMongo(config, authSource, authMechanism)
|
||||
if err == nil {
|
||||
return client, nil
|
||||
}
|
||||
lastErr = err
|
||||
// 如果明确指定了认证机制,失败后不再尝试其他机制
|
||||
if config.AuthMechanism != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 所有认证机制都失败
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("MongoDB 连接测试失败: %v", lastErr)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("MongoDB 连接失败: 未知错误")
|
||||
}
|
||||
|
||||
// tryConnectMongo 尝试使用指定的认证机制连接 MongoDB
|
||||
func tryConnectMongo(config *MongoConfig, authSource, authMechanism string) (*MongoClient, error) {
|
||||
// 构建连接 URI
|
||||
var uri string
|
||||
|
||||
if config.Username != "" && config.Password != "" {
|
||||
// 使用 url.UserPassword 正确转义用户名和密码中的特殊字符
|
||||
// 这会正确处理 @、:、/ 等特殊字符
|
||||
userInfo := url.UserPassword(config.Username, config.Password)
|
||||
|
||||
// 构建基础 URI
|
||||
uri = fmt.Sprintf("mongodb://%s@%s:%d", userInfo.String(), config.Host, config.Port)
|
||||
|
||||
// 添加数据库和认证源参数
|
||||
params := url.Values{}
|
||||
params.Set("authSource", authSource)
|
||||
|
||||
// 添加认证机制参数
|
||||
if authMechanism != "" {
|
||||
params.Set("authMechanism", authMechanism)
|
||||
}
|
||||
|
||||
// 如果有业务数据库,添加到路径中
|
||||
if config.Database != "" {
|
||||
uri = fmt.Sprintf("%s/%s?%s", uri, config.Database, params.Encode())
|
||||
} else {
|
||||
// MongoDB URI 要求查询参数前必须有 /,即使没有数据库名
|
||||
uri = fmt.Sprintf("%s/?%s", uri, params.Encode())
|
||||
}
|
||||
} else if config.Database != "" {
|
||||
// 没有认证信息时,数据库部分用于指定默认数据库
|
||||
uri = fmt.Sprintf("mongodb://%s:%d/%s", config.Host, config.Port, config.Database)
|
||||
} else {
|
||||
uri = fmt.Sprintf("mongodb://%s:%d", config.Host, config.Port)
|
||||
}
|
||||
|
||||
// 客户端选项
|
||||
clientOptions := options.Client().
|
||||
ApplyURI(uri).
|
||||
SetConnectTimeout(5 * time.Second).
|
||||
SetServerSelectionTimeout(5 * time.Second)
|
||||
|
||||
// 创建客户端
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 MongoDB 失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := client.Ping(ctx, nil); err != nil {
|
||||
client.Disconnect(ctx)
|
||||
return nil, fmt.Errorf("MongoDB 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
var database *mongo.Database
|
||||
if config.Database != "" {
|
||||
database = client.Database(config.Database)
|
||||
}
|
||||
|
||||
return &MongoClient{
|
||||
client: client,
|
||||
database: database,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestMongoConnection 测试连接
|
||||
func TestMongoConnection(host string, port int, username, password, database string) error {
|
||||
return TestMongoConnectionWithAuthSource(host, port, username, password, database, "")
|
||||
}
|
||||
|
||||
// TestMongoConnectionWithAuthSource 测试连接(支持指定认证数据库)
|
||||
func TestMongoConnectionWithAuthSource(host string, port int, username, password, database, authSource string) error {
|
||||
return TestMongoConnectionWithOptions(host, port, username, password, database, authSource, "")
|
||||
}
|
||||
|
||||
// TestMongoConnectionWithOptions 测试连接(支持指定认证数据库和认证机制)
|
||||
func TestMongoConnectionWithOptions(host string, port int, username, password, database, authSource, authMechanism string) error {
|
||||
config := &MongoConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
AuthSource: authSource,
|
||||
AuthMechanism: authMechanism,
|
||||
}
|
||||
client, err := NewMongoClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *MongoClient) Close() error {
|
||||
if c.client != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return c.client.Disconnect(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表
|
||||
func (c *MongoClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
databases, err := c.client.ListDatabaseNames(ctx, bson.M{})
|
||||
return databases, err
|
||||
}
|
||||
|
||||
// ListCollections 获取集合列表
|
||||
func (c *MongoClient) ListCollections(ctx context.Context, database string) ([]string, error) {
|
||||
db := c.client.Database(database)
|
||||
collections, err := db.ListCollectionNames(ctx, bson.M{})
|
||||
return collections, err
|
||||
}
|
||||
|
||||
// GetCollectionStructure 获取集合结构
|
||||
func (c *MongoClient) GetCollectionStructure(ctx context.Context, database, collectionName string) (map[string]interface{}, error) {
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"database": database,
|
||||
"collection": collectionName,
|
||||
"sampleDocs": []map[string]interface{}{},
|
||||
"fieldStats": map[string]int{},
|
||||
"indexes": []map[string]interface{}{},
|
||||
"documentCount": int64(0),
|
||||
}
|
||||
|
||||
// 获取文档示例(最多 5 个)
|
||||
opts := options.Find().SetLimit(5)
|
||||
cursor, err := coll.Find(ctx, bson.M{}, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文档示例失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var docs []bson.M
|
||||
if err = cursor.All(ctx, &docs); err != nil {
|
||||
return nil, fmt.Errorf("解析文档失败: %v", err)
|
||||
}
|
||||
|
||||
// 转换为 map
|
||||
sampleDocs := make([]map[string]interface{}, 0, len(docs))
|
||||
for _, doc := range docs {
|
||||
docMap := make(map[string]interface{})
|
||||
for k, v := range doc {
|
||||
docMap[k] = v
|
||||
}
|
||||
sampleDocs = append(sampleDocs, docMap)
|
||||
}
|
||||
result["sampleDocs"] = sampleDocs
|
||||
|
||||
// 字段统计:使用 $sample 聚合管道随机采样10个文档进行统计
|
||||
// 这样可以获得更准确的字段分布,同时保持良好性能
|
||||
// 使用异步方式执行,避免阻塞主流程
|
||||
sampleSize := 10
|
||||
pipeline := []bson.M{
|
||||
{"$sample": bson.M{"size": sampleSize}},
|
||||
{"$project": bson.M{"keys": bson.M{"$objectToArray": "$$ROOT"}}},
|
||||
{"$unwind": "$keys"},
|
||||
{"$group": bson.M{
|
||||
"_id": "$keys.k",
|
||||
"count": bson.M{"$sum": 1},
|
||||
}},
|
||||
{"$sort": bson.M{"count": -1}}, // 按出现次数降序排序
|
||||
}
|
||||
|
||||
sampleCursor, err := coll.Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
// 如果采样失败,回退到基于文档示例的统计
|
||||
fieldCount := make(map[string]int)
|
||||
for _, doc := range docs {
|
||||
for key := range doc {
|
||||
fieldCount[key]++
|
||||
}
|
||||
}
|
||||
result["fieldStats"] = fieldCount
|
||||
result["fieldStatsSampleSize"] = len(docs) // 记录实际采样数量
|
||||
result["fieldStatsMethod"] = "sample-docs" // 标记统计方式
|
||||
} else {
|
||||
defer sampleCursor.Close(ctx)
|
||||
fieldCount := make(map[string]int)
|
||||
for sampleCursor.Next(ctx) {
|
||||
var statResult bson.M
|
||||
if err := sampleCursor.Decode(&statResult); err != nil {
|
||||
continue
|
||||
}
|
||||
fieldName, ok := statResult["_id"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var count int
|
||||
switch v := statResult["count"].(type) {
|
||||
case int32:
|
||||
count = int(v)
|
||||
case int64:
|
||||
count = int(v)
|
||||
case int:
|
||||
count = v
|
||||
case float64:
|
||||
count = int(v)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
fieldCount[fieldName] = count
|
||||
}
|
||||
result["fieldStats"] = fieldCount
|
||||
result["fieldStatsSampleSize"] = sampleSize // 记录采样数量
|
||||
result["fieldStatsMethod"] = "sample-aggregate" // 标记统计方式
|
||||
}
|
||||
|
||||
// 文档总数(使用估算值,性能更好)
|
||||
// 对于大数据集,estimatedDocumentCount 比 CountDocuments 快得多
|
||||
// 如果需要精确值,可以使用 CountDocuments,但性能较差
|
||||
count, err := coll.EstimatedDocumentCount(ctx)
|
||||
if err != nil {
|
||||
// 如果估算失败,尝试精确计数(可能较慢)
|
||||
count, err = coll.CountDocuments(ctx, bson.M{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文档数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
result["documentCount"] = count
|
||||
|
||||
// 索引信息
|
||||
indexCursor, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
// 索引查询失败不影响主流程
|
||||
result["indexes"] = []map[string]interface{}{}
|
||||
} else {
|
||||
var indexes []map[string]interface{}
|
||||
for indexCursor.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := indexCursor.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
indexes = append(indexes, map[string]interface{}{
|
||||
"name": indexSpec["name"],
|
||||
"unique": indexSpec["unique"],
|
||||
"keys": indexSpec["key"],
|
||||
})
|
||||
}
|
||||
indexCursor.Close(ctx)
|
||||
result["indexes"] = indexes
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteQuery 执行查询
|
||||
func (c *MongoClient) ExecuteQuery(ctx context.Context, database, collection string, filter bson.M, limit int64) ([]map[string]interface{}, error) {
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collection)
|
||||
|
||||
opts := options.Find().SetLimit(limit)
|
||||
cursor, err := coll.Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var results []map[string]interface{}
|
||||
if err := cursor.All(ctx, &results); err != nil {
|
||||
return nil, fmt.Errorf("读取结果失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// CountDocuments 获取文档数量
|
||||
func (c *MongoClient) CountDocuments(ctx context.Context, database, collection string, filter bson.M) (int64, error) {
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collection)
|
||||
return coll.CountDocuments(ctx, filter)
|
||||
}
|
||||
|
||||
// ExecuteCommand 执行 MongoDB 命令
|
||||
// command 可以是 JSON 格式的字符串,格式:{"op": "find", "database": "test", "collection": "users", "filter": {}, "limit": 100}
|
||||
// 支持的操作:find, count, insertOne, insertMany, updateOne, updateMany, deleteOne, deleteMany
|
||||
func (c *MongoClient) ExecuteCommand(ctx context.Context, database string, command map[string]interface{}) (interface{}, error) {
|
||||
op, ok := command["op"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("命令中缺少 'op' 字段或格式错误")
|
||||
}
|
||||
|
||||
collectionName, ok := command["collection"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("命令中缺少 'collection' 字段或格式错误")
|
||||
}
|
||||
|
||||
// 如果没有指定数据库,使用配置中的默认数据库
|
||||
if database == "" {
|
||||
if c.config != nil && c.config.Database != "" {
|
||||
database = c.config.Database
|
||||
} else {
|
||||
return nil, fmt.Errorf("需要指定数据库名称")
|
||||
}
|
||||
}
|
||||
|
||||
db := c.client.Database(database)
|
||||
coll := db.Collection(collectionName)
|
||||
|
||||
switch op {
|
||||
case "find":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
}
|
||||
|
||||
limit := int64(100)
|
||||
if l, ok := command["limit"]; ok {
|
||||
if limitVal, ok := l.(float64); ok {
|
||||
limit = int64(limitVal)
|
||||
} else if limitVal, ok := l.(int64); ok {
|
||||
limit = limitVal
|
||||
}
|
||||
}
|
||||
|
||||
opts := options.Find().SetLimit(limit)
|
||||
cursor, err := coll.Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询失败: %v", err)
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var results []map[string]interface{}
|
||||
if err := cursor.All(ctx, &results); err != nil {
|
||||
return nil, fmt.Errorf("读取结果失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
|
||||
case "count":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
}
|
||||
|
||||
count, err := coll.CountDocuments(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("统计失败: %v", err)
|
||||
}
|
||||
return count, nil
|
||||
|
||||
case "insertOne":
|
||||
document, ok := command["document"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("insertOne 操作需要 'document' 字段")
|
||||
}
|
||||
|
||||
doc := bson.M{}
|
||||
if docMap, ok := document.(map[string]interface{}); ok {
|
||||
doc = bson.M(docMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("document 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.InsertOne(ctx, doc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("插入失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"insertedId": result.InsertedID,
|
||||
}, nil
|
||||
|
||||
case "insertMany":
|
||||
documents, ok := command["documents"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("insertMany 操作需要 'documents' 字段")
|
||||
}
|
||||
|
||||
docs := []interface{}{}
|
||||
if docsSlice, ok := documents.([]interface{}); ok {
|
||||
for _, d := range docsSlice {
|
||||
if docMap, ok := d.(map[string]interface{}); ok {
|
||||
docs = append(docs, bson.M(docMap))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("documents 必须是数组格式")
|
||||
}
|
||||
|
||||
result, err := coll.InsertMany(ctx, docs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量插入失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"insertedIds": result.InsertedIDs,
|
||||
"insertedCount": len(result.InsertedIDs),
|
||||
}, nil
|
||||
|
||||
case "updateOne":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("updateOne 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
update, ok := command["update"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("updateOne 操作需要 'update' 字段")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{}
|
||||
if updateMap, ok := update.(map[string]interface{}); ok {
|
||||
updateDoc = bson.M(updateMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("update 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.UpdateOne(ctx, filter, bson.M{"$set": updateDoc})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"matchedCount": result.MatchedCount,
|
||||
"modifiedCount": result.ModifiedCount,
|
||||
}, nil
|
||||
|
||||
case "updateMany":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("updateMany 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
update, ok := command["update"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("updateMany 操作需要 'update' 字段")
|
||||
}
|
||||
|
||||
updateDoc := bson.M{}
|
||||
if updateMap, ok := update.(map[string]interface{}); ok {
|
||||
updateDoc = bson.M(updateMap)
|
||||
} else {
|
||||
return nil, fmt.Errorf("update 必须是对象格式")
|
||||
}
|
||||
|
||||
result, err := coll.UpdateMany(ctx, filter, bson.M{"$set": updateDoc})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量更新失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"matchedCount": result.MatchedCount,
|
||||
"modifiedCount": result.ModifiedCount,
|
||||
}, nil
|
||||
|
||||
case "deleteOne":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("deleteOne 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
result, err := coll.DeleteOne(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("删除失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"deletedCount": result.DeletedCount,
|
||||
}, nil
|
||||
|
||||
case "deleteMany":
|
||||
filter := bson.M{}
|
||||
if f, ok := command["filter"]; ok {
|
||||
if filterMap, ok := f.(map[string]interface{}); ok {
|
||||
filter = bson.M(filterMap)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("deleteMany 操作需要 'filter' 字段")
|
||||
}
|
||||
|
||||
result, err := coll.DeleteMany(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量删除失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"deletedCount": result.DeletedCount,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的操作: %s,支持的操作: find, count, insertOne, insertMany, updateOne, updateMany, deleteOne, deleteMany", op)
|
||||
}
|
||||
}
|
||||
|
||||
// PreviewCollectionIndexes 预览集合索引变更,只生成命令列表不执行
|
||||
func (c *MongoClient) PreviewCollectionIndexes(ctx context.Context, database, collectionName string, structure map[string]interface{}) ([]string, error) {
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
var commands []string
|
||||
|
||||
// 获取当前索引
|
||||
currentIndexes, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
defer currentIndexes.Close(ctx)
|
||||
|
||||
// 解析新的索引数据
|
||||
var newIndexes []map[string]interface{}
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建当前索引名映射
|
||||
currentIndexMap := make(map[string]bool)
|
||||
for currentIndexes.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := currentIndexes.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
if name, ok := indexSpec["name"].(string); ok && name != "_id_" {
|
||||
currentIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新索引名映射
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if name, ok := idx["name"].(string); ok && name != "" && name != "_id_" {
|
||||
newIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for name := range currentIndexMap {
|
||||
if !newIndexMap[name] {
|
||||
cmd := fmt.Sprintf("db.%s.dropIndex(\"%s\")", collectionName, name)
|
||||
commands = append(commands, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加或更新索引
|
||||
for _, idx := range newIndexes {
|
||||
name, _ := idx["name"].(string)
|
||||
if name == "" || name == "_id_" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引键
|
||||
keys := bson.D{}
|
||||
if keysData, ok := idx["keys"].(map[string]interface{}); ok {
|
||||
for k, v := range keysData {
|
||||
var order int
|
||||
if vFloat, ok := v.(float64); ok {
|
||||
order = int(vFloat)
|
||||
} else if vInt, ok := v.(int); ok {
|
||||
order = vInt
|
||||
} else {
|
||||
order = 1 // 默认升序
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k, Value: order})
|
||||
}
|
||||
} else if columnName, ok := idx["Column_name"].(string); ok && columnName != "" {
|
||||
// 兼容 MySQL 格式的索引数据
|
||||
keys = append(keys, bson.E{Key: columnName, Value: 1})
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引选项
|
||||
indexOptions := options.Index()
|
||||
indexOptions.SetName(name)
|
||||
|
||||
if unique, ok := idx["unique"].(bool); ok && unique {
|
||||
indexOptions.SetUnique(true)
|
||||
} else if nonUnique, ok := idx["Non_unique"].(float64); ok && nonUnique == 0 {
|
||||
indexOptions.SetUnique(true)
|
||||
}
|
||||
|
||||
// 如果索引已存在,先删除再创建
|
||||
if currentIndexMap[name] {
|
||||
dropCmd := fmt.Sprintf("db.%s.dropIndex(\"%s\")", collectionName, name)
|
||||
commands = append(commands, dropCmd)
|
||||
}
|
||||
|
||||
// 构建命令字符串(MongoDB shell 格式)
|
||||
keysStr := "{"
|
||||
for i, key := range keys {
|
||||
if i > 0 {
|
||||
keysStr += ", "
|
||||
}
|
||||
keysStr += fmt.Sprintf("%s: %d", key.Key, key.Value)
|
||||
}
|
||||
keysStr += "}"
|
||||
|
||||
optionsStr := "{name: \"" + name + "\""
|
||||
if indexOptions.Unique != nil && *indexOptions.Unique {
|
||||
optionsStr += ", unique: true"
|
||||
}
|
||||
optionsStr += "}"
|
||||
|
||||
cmd := fmt.Sprintf("db.%s.createIndex(%s, %s)", collectionName, keysStr, optionsStr)
|
||||
commands = append(commands, cmd)
|
||||
}
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
|
||||
// UpdateCollectionIndexes 更新集合索引,返回执行的命令列表
|
||||
func (c *MongoClient) UpdateCollectionIndexes(ctx context.Context, database, collectionName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 先预览生成命令列表
|
||||
commands, err := c.PreviewCollectionIndexes(ctx, database, collectionName, structure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
coll := c.client.Database(database).Collection(collectionName)
|
||||
|
||||
// 获取当前索引
|
||||
currentIndexes, err := coll.Indexes().List(ctx)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
defer currentIndexes.Close(ctx)
|
||||
|
||||
// 解析新的索引数据
|
||||
var newIndexes []map[string]interface{}
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建当前索引名映射
|
||||
currentIndexMap := make(map[string]bool)
|
||||
for currentIndexes.Next(ctx) {
|
||||
var indexSpec bson.M
|
||||
if err := currentIndexes.Decode(&indexSpec); err != nil {
|
||||
continue
|
||||
}
|
||||
if name, ok := indexSpec["name"].(string); ok && name != "_id_" {
|
||||
currentIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新索引名映射
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if name, ok := idx["name"].(string); ok && name != "" && name != "_id_" {
|
||||
newIndexMap[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for name := range currentIndexMap {
|
||||
if !newIndexMap[name] {
|
||||
_, err := coll.Indexes().DropOne(ctx, name)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("删除索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加或更新索引
|
||||
for _, idx := range newIndexes {
|
||||
name, _ := idx["name"].(string)
|
||||
if name == "" || name == "_id_" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引键
|
||||
keys := bson.D{}
|
||||
if keysData, ok := idx["keys"].(map[string]interface{}); ok {
|
||||
for k, v := range keysData {
|
||||
var order int
|
||||
if vFloat, ok := v.(float64); ok {
|
||||
order = int(vFloat)
|
||||
} else if vInt, ok := v.(int); ok {
|
||||
order = vInt
|
||||
} else {
|
||||
order = 1 // 默认升序
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k, Value: order})
|
||||
}
|
||||
} else if columnName, ok := idx["Column_name"].(string); ok && columnName != "" {
|
||||
// 兼容 MySQL 格式的索引数据
|
||||
keys = append(keys, bson.E{Key: columnName, Value: 1})
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 构建索引选项
|
||||
indexOptions := options.Index()
|
||||
indexOptions.SetName(name)
|
||||
|
||||
if unique, ok := idx["unique"].(bool); ok && unique {
|
||||
indexOptions.SetUnique(true)
|
||||
} else if nonUnique, ok := idx["Non_unique"].(float64); ok && nonUnique == 0 {
|
||||
indexOptions.SetUnique(true)
|
||||
}
|
||||
|
||||
// 创建索引
|
||||
indexModel := mongo.IndexModel{
|
||||
Keys: keys,
|
||||
Options: indexOptions,
|
||||
}
|
||||
|
||||
// 如果索引已存在,先删除再创建
|
||||
if currentIndexMap[name] {
|
||||
_, err := coll.Indexes().DropOne(ctx, name)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("删除旧索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := coll.Indexes().CreateOne(ctx, indexModel)
|
||||
if err != nil {
|
||||
return commands, fmt.Errorf("创建索引失败: %v, 索引名: %s", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
return commands, nil
|
||||
}
|
||||
875
internal/dbclient/mysql.go
Normal file
875
internal/dbclient/mysql.go
Normal file
@@ -0,0 +1,875 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// MySQLClient MySQL 客户端
|
||||
type MySQLClient struct {
|
||||
db *gorm.DB
|
||||
sqlDB *sql.DB
|
||||
config *MySQLConfig
|
||||
}
|
||||
|
||||
// MySQLConfig MySQL 配置
|
||||
type MySQLConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
}
|
||||
|
||||
// NewMySQLClient 创建 MySQL 客户端
|
||||
func NewMySQLClient(config *MySQLConfig) (*MySQLClient, error) {
|
||||
// 构建 DSN
|
||||
mysqlConfig := mysqldriver.Config{
|
||||
User: config.Username,
|
||||
Passwd: config.Password,
|
||||
Net: "tcp",
|
||||
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
|
||||
DBName: config.Database,
|
||||
Params: map[string]string{
|
||||
"charset": "utf8mb4",
|
||||
"parseTime": "True",
|
||||
"loc": "Local",
|
||||
"multiStatements": "true", // 支持多条SQL语句执行
|
||||
},
|
||||
AllowNativePasswords: true,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
dsn := mysqlConfig.FormatDSN()
|
||||
|
||||
// GORM 配置
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
// 打开连接
|
||||
db, err := gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 MySQL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取底层 sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("MySQL 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
sqlDB.SetMaxIdleConns(2)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
return &MySQLClient{
|
||||
db: db,
|
||||
sqlDB: sqlDB,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestConnection 测试连接
|
||||
func TestMySQLConnection(host string, port int, username, password, database string) error {
|
||||
config := &MySQLConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
}
|
||||
client, err := NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *MySQLClient) Close() error {
|
||||
if c.sqlDB != nil {
|
||||
return c.sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryResult 查询结果,包含数据和列顺序
|
||||
type QueryResult struct {
|
||||
Data []map[string]interface{}
|
||||
Columns []string
|
||||
}
|
||||
|
||||
// ExecuteQuery 执行查询 SQL
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
// 注意:SQL 语句应该已经包含 LIMIT 和 OFFSET(由客户端添加)
|
||||
func (c *MySQLClient) ExecuteQuery(ctx context.Context, sqlStr string, database string) (*QueryResult, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return nil, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.Raw(sqlStr).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("执行查询失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 检查 rows 错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("查询结果错误: %v", err)
|
||||
}
|
||||
|
||||
// 获取列名
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有列,返回空数组
|
||||
if len(columns) == 0 {
|
||||
return &QueryResult{
|
||||
Data: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
// 创建值数组和指针数组
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
// 扫描行数据
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 构建结果 map,按照列顺序构建
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
// 处理 nil 值
|
||||
if val == nil {
|
||||
row[col] = nil
|
||||
} else if b, ok := val.([]byte); ok {
|
||||
// 处理 []byte 类型
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
// 检查迭代过程中的错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("读取数据时发生错误: %v", err)
|
||||
}
|
||||
|
||||
return &QueryResult{
|
||||
Data: results,
|
||||
Columns: columns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecuteUpdate 执行更新 SQL(INSERT/UPDATE/DELETE)
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
func (c *MySQLClient) ExecuteUpdate(ctx context.Context, sqlStr string, database string) (int64, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return 0, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Exec(sqlStr)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("执行更新失败: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表
|
||||
func (c *MySQLClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
var databases []string
|
||||
err := c.db.Raw("SHOW DATABASES").Scan(&databases).Error
|
||||
return databases, err
|
||||
}
|
||||
|
||||
// ListTables 获取表列表
|
||||
func (c *MySQLClient) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
var tables []string
|
||||
query := "SHOW TABLES"
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW TABLES FROM `%s`", database)
|
||||
}
|
||||
err := c.db.Raw(query).Scan(&tables).Error
|
||||
return tables, err
|
||||
}
|
||||
|
||||
// GetTableStructure 获取表结构
|
||||
func (c *MySQLClient) GetTableStructure(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
// 使用 SHOW FULL COLUMNS 来获取包含 comment 的完整字段信息
|
||||
query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName)
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", database, tableName)
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取表结构失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else if val == nil {
|
||||
row[col] = nil
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 Comment 字段存在(SHOW FULL COLUMNS 返回的字段名是 Comment)
|
||||
if _, ok := row["Comment"]; !ok {
|
||||
row["Comment"] = ""
|
||||
}
|
||||
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetIndexes 获取索引列表
|
||||
func (c *MySQLClient) GetIndexes(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
query := "SHOW INDEX FROM "
|
||||
if database != "" {
|
||||
query += fmt.Sprintf("`%s`.", database)
|
||||
}
|
||||
query += fmt.Sprintf("`%s`", tableName)
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取索引列表失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// PreviewTableStructure 预览表结构变更,只生成 SQL 语句不执行
|
||||
func (c *MySQLClient) PreviewTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 获取当前表结构
|
||||
currentColumns, err := c.GetTableStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前表结构失败: %v", err)
|
||||
}
|
||||
|
||||
currentIndexes, err := c.GetIndexes(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析新的结构数据
|
||||
var newColumns []map[string]interface{}
|
||||
var newIndexes []map[string]interface{}
|
||||
|
||||
if cols, ok := structure["columns"].([]interface{}); ok {
|
||||
for _, col := range cols {
|
||||
if colMap, ok := col.(map[string]interface{}); ok {
|
||||
newColumns = append(newColumns, colMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ALTER TABLE 语句
|
||||
var alterStatements []string
|
||||
|
||||
// 处理字段变更
|
||||
alterStatements = append(alterStatements, c.buildColumnAlterStatements(tableName, currentColumns, newColumns)...)
|
||||
|
||||
// 处理索引变更
|
||||
alterStatements = append(alterStatements, c.buildIndexAlterStatements(tableName, currentIndexes, newIndexes)...)
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// UpdateTableStructure 更新表结构,返回生成的 SQL 语句列表
|
||||
func (c *MySQLClient) UpdateTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 先预览生成 SQL 语句
|
||||
alterStatements, err := c.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行所有 ALTER TABLE 语句
|
||||
if len(alterStatements) > 0 {
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, stmt := range alterStatements {
|
||||
if err := db.Exec(stmt).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("执行 ALTER TABLE 失败: %v, SQL: %s", err, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// buildColumnAlterStatements 构建字段变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildColumnAlterStatements(tableName string, currentColumns, newColumns []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建字段名映射和顺序映射
|
||||
currentFieldMap := make(map[string]map[string]interface{})
|
||||
currentFieldOrder := make([]string, 0, len(currentColumns))
|
||||
for _, col := range currentColumns {
|
||||
if field, ok := col["Field"].(string); ok {
|
||||
currentFieldMap[field] = col
|
||||
currentFieldOrder = append(currentFieldOrder, field)
|
||||
}
|
||||
}
|
||||
|
||||
newFieldMap := make(map[string]bool)
|
||||
newFieldOrder := make([]string, 0, len(newColumns))
|
||||
newColumnsMap := make(map[string]map[string]interface{})
|
||||
for _, col := range newColumns {
|
||||
if field, ok := col["Field"].(string); ok && field != "" {
|
||||
newFieldMap[field] = true
|
||||
newFieldOrder = append(newFieldOrder, field)
|
||||
newColumnsMap[field] = col
|
||||
}
|
||||
}
|
||||
|
||||
// 检测字段重命名:优先使用位置匹配,如果位置相同但字段名不同,认为是重命名
|
||||
renameMap := make(map[string]string) // oldName -> newName
|
||||
processedNewFields := make(map[string]bool)
|
||||
|
||||
// 第一步:使用位置匹配检测重命名(最可靠)
|
||||
for oldIndex, oldFieldName := range currentFieldOrder {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
|
||||
// 检查新字段列表中相同位置是否有字段
|
||||
if oldIndex < len(newFieldOrder) {
|
||||
newFieldName := newFieldOrder[oldIndex]
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent && !processedNewFields[newFieldName] {
|
||||
// 新字段不在当前字段列表中,且位置相同,很可能是重命名
|
||||
// 进一步验证:检查类型是否相同(类型相同更可能是重命名)
|
||||
oldCol := currentFieldMap[oldFieldName]
|
||||
newCol := newColumnsMap[newFieldName]
|
||||
oldType := getStringValue(oldCol["Type"])
|
||||
newType := getStringValue(newCol["Type"])
|
||||
|
||||
// 如果类型相同,认为是重命名
|
||||
if oldType == newType {
|
||||
renameMap[oldFieldName] = newFieldName
|
||||
processedNewFields[newFieldName] = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二步:对于未匹配的字段,使用属性匹配(兼容旧逻辑)
|
||||
for oldFieldName, oldCol := range currentFieldMap {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
if renameMap[oldFieldName] != "" {
|
||||
continue // 已经通过位置匹配识别为重命名
|
||||
}
|
||||
|
||||
// 查找属性完全匹配的新字段
|
||||
var matchedNewField string
|
||||
for newFieldName, newCol := range newColumnsMap {
|
||||
if processedNewFields[newFieldName] {
|
||||
continue // 已经被匹配过了
|
||||
}
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent {
|
||||
// 这是一个新增字段,检查属性是否匹配
|
||||
if c.isColumnPropertiesEqual(oldCol, newCol) {
|
||||
if matchedNewField == "" {
|
||||
matchedNewField = newFieldName
|
||||
} else {
|
||||
// 有多个匹配,无法确定,不认为是重命名
|
||||
matchedNewField = ""
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到唯一匹配,认为是重命名
|
||||
if matchedNewField != "" {
|
||||
renameMap[oldFieldName] = matchedNewField
|
||||
processedNewFields[matchedNewField] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理字段重命名
|
||||
for oldName, newName := range renameMap {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` RENAME COLUMN `%s` TO `%s`", tableName, oldName, newName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
// 处理字段添加、修改和位置调整(排除已重命名的字段)
|
||||
for i, newCol := range newColumns {
|
||||
field, _ := newCol["Field"].(string)
|
||||
if field == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否是重命名的字段
|
||||
isRenamed := false
|
||||
var oldName string
|
||||
for old, new := range renameMap {
|
||||
if new == field {
|
||||
isRenamed = true
|
||||
oldName = old
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isRenamed {
|
||||
// 重命名的字段:如果属性有变化,需要 MODIFY COLUMN
|
||||
oldCol := currentFieldMap[oldName]
|
||||
needsModify := c.isColumnChanged(oldCol, newCol)
|
||||
|
||||
// 检查顺序变化:使用旧字段名在 currentOrder 中查找位置,与新位置比较
|
||||
oldIndex := -1
|
||||
for idx, name := range currentFieldOrder {
|
||||
if name == oldName {
|
||||
oldIndex = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
needsReorder := (oldIndex != -1 && oldIndex != i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
// 重命名后需要修改属性或位置
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if currentCol, exists := currentFieldMap[field]; exists {
|
||||
// 修改现有字段
|
||||
needsModify := c.isColumnChanged(currentCol, newCol)
|
||||
needsReorder := c.isColumnOrderChanged(currentFieldOrder, newFieldOrder, field, i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新字段(排除重命名的字段)
|
||||
stmt := c.buildAddColumnStatement(tableName, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的字段(排除已重命名的字段)
|
||||
for field := range currentFieldMap {
|
||||
if !newFieldMap[field] && renameMap[field] == "" {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s`", tableName, field)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// buildIndexAlterStatements 构建索引变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildIndexAlterStatements(tableName string, currentIndexes, newIndexes []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建索引名映射
|
||||
currentIndexMap := make(map[string]map[string]interface{})
|
||||
for _, idx := range currentIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "PRIMARY" {
|
||||
currentIndexMap[keyName] = idx
|
||||
}
|
||||
}
|
||||
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "" && keyName != "PRIMARY" {
|
||||
newIndexMap[keyName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理索引变更
|
||||
for _, newIdx := range newIndexes {
|
||||
keyName, _ := newIdx["Key_name"].(string)
|
||||
if keyName == "" || keyName == "PRIMARY" {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentIdx, exists := currentIndexMap[keyName]; exists {
|
||||
// 修改现有索引
|
||||
if c.isIndexChanged(currentIdx, newIdx) {
|
||||
dropStmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
addStmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if addStmt != "" {
|
||||
statements = append(statements, dropStmt)
|
||||
statements = append(statements, addStmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新索引
|
||||
stmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for keyName := range currentIndexMap {
|
||||
if !newIndexMap[keyName] {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// isColumnChanged 检查字段是否发生变化(不包括字段名)
|
||||
func (c *MySQLClient) isColumnChanged(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isColumnPropertiesEqual 检查字段属性是否完全相等(不包括字段名)
|
||||
func (c *MySQLClient) isColumnPropertiesEqual(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Key", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isColumnOrderChanged 检查字段顺序是否发生变化
|
||||
func (c *MySQLClient) isColumnOrderChanged(currentOrder, newOrder []string, fieldName string, newIndex int) bool {
|
||||
// 查找字段在当前顺序中的位置
|
||||
currentIndex := -1
|
||||
for i, name := range currentOrder {
|
||||
if name == fieldName {
|
||||
currentIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果字段不存在于当前顺序中(新字段),不需要检查顺序
|
||||
if currentIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果索引相同,检查前面的字段是否相同
|
||||
if newIndex == currentIndex {
|
||||
// 检查前面的字段集合是否相同
|
||||
if newIndex > 0 {
|
||||
currentPrevFields := make(map[string]bool)
|
||||
for i := 0; i < currentIndex; i++ {
|
||||
currentPrevFields[currentOrder[i]] = true
|
||||
}
|
||||
|
||||
newPrevFields := make(map[string]bool)
|
||||
for i := 0; i < newIndex; i++ {
|
||||
newPrevFields[newOrder[i]] = true
|
||||
}
|
||||
|
||||
// 如果前面的字段集合不同,说明顺序变了
|
||||
if len(currentPrevFields) != len(newPrevFields) {
|
||||
return true
|
||||
}
|
||||
for f := range currentPrevFields {
|
||||
if !newPrevFields[f] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 索引不同,说明顺序变了
|
||||
return true
|
||||
}
|
||||
|
||||
// isIndexChanged 检查索引是否发生变化
|
||||
func (c *MySQLClient) isIndexChanged(oldIdx, newIdx map[string]interface{}) bool {
|
||||
oldCol := getStringValue(oldIdx["Column_name"])
|
||||
newCol := getStringValue(newIdx["Column_name"])
|
||||
if oldCol != newCol {
|
||||
return true
|
||||
}
|
||||
|
||||
oldUnique := getIntValue(oldIdx["Non_unique"])
|
||||
newUnique := getIntValue(newIdx["Non_unique"])
|
||||
return oldUnique != newUnique
|
||||
}
|
||||
|
||||
// buildAddColumnStatement 构建添加字段的语句
|
||||
func (c *MySQLClient) buildAddColumnStatement(tableName string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
field := getStringValue(col["Field"])
|
||||
if field == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildModifyColumnStatement 构建修改字段的语句
|
||||
func (c *MySQLClient) buildModifyColumnStatement(tableName, field string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` MODIFY COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildColumnPosition 构建字段位置子句(AFTER 或 FIRST)
|
||||
func (c *MySQLClient) buildColumnPosition(fieldOrder []string, index int) string {
|
||||
if index < 0 || index >= len(fieldOrder) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
// 第一个字段使用 FIRST
|
||||
return " FIRST"
|
||||
}
|
||||
|
||||
// 其他字段使用 AFTER 前一个字段
|
||||
prevField := fieldOrder[index-1]
|
||||
return fmt.Sprintf(" AFTER `%s`", prevField)
|
||||
}
|
||||
|
||||
// buildColumnDefinition 构建字段定义
|
||||
func (c *MySQLClient) buildColumnDefinition(col map[string]interface{}) string {
|
||||
field := getStringValue(col["Field"])
|
||||
colType := getStringValue(col["Type"])
|
||||
null := getStringValue(col["Null"])
|
||||
defaultVal := col["Default"]
|
||||
extra := getStringValue(col["Extra"])
|
||||
comment := getStringValue(col["Comment"])
|
||||
|
||||
def := fmt.Sprintf("`%s` %s", field, colType)
|
||||
|
||||
if null == "NO" {
|
||||
def += " NOT NULL"
|
||||
}
|
||||
|
||||
if defaultVal != nil {
|
||||
if defaultStr, ok := defaultVal.(string); ok {
|
||||
if defaultStr == "" {
|
||||
// 空字符串表示默认值为空字符串
|
||||
def += " DEFAULT ''"
|
||||
} else if defaultStr != "NULL" {
|
||||
// 转义单引号
|
||||
escapedDefault := strings.ReplaceAll(defaultStr, "'", "''")
|
||||
def += fmt.Sprintf(" DEFAULT '%s'", escapedDefault)
|
||||
}
|
||||
// 如果 defaultStr == "NULL",不添加 DEFAULT 子句(允许 NULL)
|
||||
} else {
|
||||
// 非字符串类型的默认值
|
||||
def += fmt.Sprintf(" DEFAULT %v", defaultVal)
|
||||
}
|
||||
}
|
||||
|
||||
if extra != "" {
|
||||
def += " " + extra
|
||||
}
|
||||
|
||||
if comment != "" {
|
||||
// 转义单引号
|
||||
escapedComment := strings.ReplaceAll(comment, "'", "''")
|
||||
def += fmt.Sprintf(" COMMENT '%s'", escapedComment)
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// buildAddIndexStatement 构建添加索引的语句
|
||||
func (c *MySQLClient) buildAddIndexStatement(tableName string, idx map[string]interface{}) string {
|
||||
keyName := getStringValue(idx["Key_name"])
|
||||
columnName := getStringValue(idx["Column_name"])
|
||||
nonUnique := getIntValue(idx["Non_unique"])
|
||||
|
||||
if keyName == "" || columnName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
indexType := "INDEX"
|
||||
if nonUnique == 0 {
|
||||
indexType = "UNIQUE INDEX"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD %s `%s` (`%s`)", tableName, indexType, keyName, columnName)
|
||||
}
|
||||
|
||||
// getStringValue 安全获取字符串值
|
||||
func getStringValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// getIntValue 安全获取整数值
|
||||
func getIntValue(v interface{}) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var i int
|
||||
fmt.Sscanf(val, "%d", &i)
|
||||
return i
|
||||
}
|
||||
return 0
|
||||
}
|
||||
236
internal/dbclient/pool.go
Normal file
236
internal/dbclient/pool.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go-desk/internal/crypto"
|
||||
"go-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// ConnectionPool 连接池管理器
|
||||
type ConnectionPool struct {
|
||||
mysqlClients map[uint]*MySQLClient
|
||||
redisClients map[uint]*RedisClient
|
||||
mongoClients map[uint]*MongoClient
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
globalPool *ConnectionPool
|
||||
poolOnce sync.Once
|
||||
)
|
||||
|
||||
// GetPool 获取全局连接池实例
|
||||
func GetPool() *ConnectionPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = &ConnectionPool{
|
||||
mysqlClients: make(map[uint]*MySQLClient),
|
||||
redisClients: make(map[uint]*RedisClient),
|
||||
mongoClients: make(map[uint]*MongoClient),
|
||||
}
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// GetMySQLClient 获取或创建 MySQL 客户端
|
||||
func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) (*MySQLClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.mysqlClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
if err := client.sqlDB.Ping(); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.mysqlClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &MySQLConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: password, // 如果密码为空,MySQL会尝试无密码连接
|
||||
Database: conn.Database,
|
||||
}
|
||||
|
||||
client, err := NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.mysqlClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetRedisClient 获取或创建 Redis 客户端
|
||||
func (p *ConnectionPool) GetRedisClient(conn *models.DbConnection) (*RedisClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.redisClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := client.client.Ping(ctx).Err(); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.redisClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析 Redis DB 编号(从 Database 字段,默认为 0)
|
||||
dbNum := 0
|
||||
if conn.Database != "" {
|
||||
// 尝试解析 Database 字段为数字
|
||||
_, err := fmt.Sscanf(conn.Database, "%d", &dbNum)
|
||||
if err != nil {
|
||||
// 如果解析失败,使用默认值 0
|
||||
dbNum = 0
|
||||
}
|
||||
// 限制 DB 编号在 0-15 之间
|
||||
if dbNum < 0 || dbNum > 15 {
|
||||
dbNum = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &RedisConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Password: password,
|
||||
DB: dbNum,
|
||||
}
|
||||
|
||||
client, err := NewRedisClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.redisClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetMongoClient 获取或创建 MongoDB 客户端
|
||||
func (p *ConnectionPool) GetMongoClient(conn *models.DbConnection) (*MongoClient, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if client, ok := p.mongoClients[conn.ID]; ok {
|
||||
// 测试连接是否有效
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := client.client.Ping(ctx, nil); err == nil {
|
||||
return client, nil
|
||||
}
|
||||
// 连接已断开,移除并重新创建
|
||||
client.Close()
|
||||
delete(p.mongoClients, conn.ID)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
password, err := crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析 Options 获取 MongoDB 连接参数
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if conn.Options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(conn.Options), &opts); err == nil {
|
||||
if as, ok := opts["authSource"].(string); ok && as != "" {
|
||||
authSource = as
|
||||
}
|
||||
if am, ok := opts["authMechanism"].(string); ok && am != "" {
|
||||
authMechanism = am
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新客户端
|
||||
config := &MongoConfig{
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: password,
|
||||
Database: conn.Database,
|
||||
AuthSource: authSource,
|
||||
AuthMechanism: authMechanism,
|
||||
}
|
||||
|
||||
client, err := NewMongoClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.mongoClients[conn.ID] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// CloseConnection 关闭指定连接
|
||||
func (p *ConnectionPool) CloseConnection(connID uint, dbType string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
if client, ok := p.mysqlClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.mysqlClients, connID)
|
||||
}
|
||||
case "redis":
|
||||
if client, ok := p.redisClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.redisClients, connID)
|
||||
}
|
||||
case "mongo":
|
||||
if client, ok := p.mongoClients[connID]; ok {
|
||||
client.Close()
|
||||
delete(p.mongoClients, connID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloseAll 关闭所有连接
|
||||
func (p *ConnectionPool) CloseAll() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for _, client := range p.mysqlClients {
|
||||
client.Close()
|
||||
}
|
||||
for _, client := range p.redisClients {
|
||||
client.Close()
|
||||
}
|
||||
for _, client := range p.mongoClients {
|
||||
client.Close()
|
||||
}
|
||||
|
||||
p.mysqlClients = make(map[uint]*MySQLClient)
|
||||
p.redisClients = make(map[uint]*RedisClient)
|
||||
p.mongoClients = make(map[uint]*MongoClient)
|
||||
}
|
||||
239
internal/dbclient/redis.go
Normal file
239
internal/dbclient/redis.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisClient Redis 客户端
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
config *RedisConfig
|
||||
}
|
||||
|
||||
// RedisConfig Redis 配置
|
||||
type RedisConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int // 数据库编号,默认 0
|
||||
}
|
||||
|
||||
// NewRedisClient 创建 Redis 客户端
|
||||
func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: config.Password,
|
||||
DB: config.DB,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
})
|
||||
|
||||
// 测试连接
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("Redis 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
return &RedisClient{
|
||||
client: rdb,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestRedisConnection 测试连接
|
||||
func TestRedisConnection(host string, port int, password string) error {
|
||||
config := &RedisConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
DB: 0,
|
||||
}
|
||||
client, err := NewRedisClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewRedisClientByDB 根据参数创建指定 DB 的 Redis 客户端(用于多 DB 场景)
|
||||
func NewRedisClientByDB(host string, port int, password string, dbNum int) (*RedisClient, error) {
|
||||
config := &RedisConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
DB: dbNum,
|
||||
}
|
||||
return NewRedisClient(config)
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *RedisClient) Close() error {
|
||||
if c.client != nil {
|
||||
return c.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteCommand 执行 Redis 命令
|
||||
func (c *RedisClient) ExecuteCommand(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
|
||||
return c.client.Do(ctx, append([]interface{}{cmd}, args...)...).Result()
|
||||
}
|
||||
|
||||
// GetKeys 获取 Key 列表(支持 pattern,使用 SCAN 代替 KEYS 以提高性能)
|
||||
func (c *RedisClient) GetKeys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if pattern == "" {
|
||||
pattern = "*"
|
||||
}
|
||||
|
||||
var keys []string
|
||||
var cursor uint64
|
||||
const count = 100 // 每次扫描的数量
|
||||
|
||||
for {
|
||||
var err error
|
||||
var batch []string
|
||||
batch, cursor, err = c.client.Scan(ctx, cursor, pattern, count).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys = append(keys, batch...)
|
||||
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetKeyType 获取 Key 类型
|
||||
func (c *RedisClient) GetKeyType(ctx context.Context, key string) (string, error) {
|
||||
return c.client.Type(ctx, key).Result()
|
||||
}
|
||||
|
||||
// GetKeyValue 获取 Key 值
|
||||
func (c *RedisClient) GetKeyValue(ctx context.Context, key string) (interface{}, error) {
|
||||
keyType, err := c.GetKeyType(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch keyType {
|
||||
case "string":
|
||||
return c.client.Get(ctx, key).Result()
|
||||
case "list":
|
||||
return c.client.LRange(ctx, key, 0, -1).Result()
|
||||
case "set":
|
||||
return c.client.SMembers(ctx, key).Result()
|
||||
case "zset":
|
||||
// 对于有序集合,返回带分数的结果
|
||||
zMembers, err := c.client.ZRangeWithScores(ctx, key, 0, -1).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为 map 格式,便于展示
|
||||
result := make([]map[string]interface{}, len(zMembers))
|
||||
for i, member := range zMembers {
|
||||
result[i] = map[string]interface{}{
|
||||
"member": member.Member,
|
||||
"score": member.Score,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
case "hash":
|
||||
return c.client.HGetAll(ctx, key).Result()
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的类型: %s", keyType)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTTL 获取 Key 的 TTL
|
||||
func (c *RedisClient) GetTTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
return c.client.TTL(ctx, key).Result()
|
||||
}
|
||||
|
||||
// GetKeyInfo 获取 Key 详细信息
|
||||
func (c *RedisClient) GetKeyInfo(ctx context.Context, key string) (map[string]interface{}, error) {
|
||||
info := map[string]interface{}{
|
||||
"key": key,
|
||||
"type": "",
|
||||
"value": nil,
|
||||
"ttl": 0,
|
||||
}
|
||||
|
||||
// 获取 Key 类型
|
||||
keyType, err := c.GetKeyType(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Key 类型失败: %v", err)
|
||||
}
|
||||
info["type"] = keyType
|
||||
|
||||
// 获取 TTL
|
||||
ttl, err := c.GetTTL(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 TTL 失败: %v", err)
|
||||
}
|
||||
info["ttl"] = ttl.Seconds()
|
||||
|
||||
// 获取 Key 值(限制大小,避免过大)
|
||||
value, err := c.GetKeyValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 Key 值失败: %v", err)
|
||||
}
|
||||
info["value"] = formatValuePreview(value)
|
||||
|
||||
// 获取 Key 长度(使用 STRLEN、HLEN、SCARD、ZCARD)
|
||||
var keyLength int64
|
||||
switch keyType {
|
||||
case "string":
|
||||
keyLength, err = c.client.StrLen(ctx, key).Result()
|
||||
case "list":
|
||||
keyLength, err = c.client.LLen(ctx, key).Result()
|
||||
case "set":
|
||||
keyLength, err = c.client.SCard(ctx, key).Result()
|
||||
case "zset":
|
||||
keyLength, err = c.client.ZCard(ctx, key).Result()
|
||||
case "hash":
|
||||
keyLength, err = c.client.HLen(ctx, key).Result()
|
||||
}
|
||||
if err == nil {
|
||||
info["length"] = keyLength
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// formatValuePreview 格式化值预览(限制长度)
|
||||
func formatValuePreview(value interface{}) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
const maxPreviewLength = 200
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
if len(valueStr) > maxPreviewLength {
|
||||
valueStr = valueStr[:maxPreviewLength] + "..."
|
||||
}
|
||||
|
||||
return valueStr
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表(Redis 使用 DB number)
|
||||
// Redis 没有传统数据库概念,这里返回空数组
|
||||
func (c *RedisClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
// Redis 可以使用 DB number 来隔离数据
|
||||
// 这里可以返回当前配置的 DB 或者所有可用的 DB
|
||||
// 为简单起见,返回空数组,让用户直接操作 Key
|
||||
return []string{}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user