// Package dbclient contains database client helpers. This file implements a
// MySQL client built on GORM and the underlying database/sql pool.
package dbclient

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	mysqldriver "github.com/go-sql-driver/mysql"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// MySQLClient is a MySQL client holding a GORM handle, the database/sql pool
// behind it, and the configuration it was created from.
type MySQLClient struct {
	db     *gorm.DB
	sqlDB  *sql.DB
	config *MySQLConfig
}

// MySQLConfig holds the connection settings for a MySQL server.
type MySQLConfig struct {
	Host     string
	Port     int
	Username string
	Password string
	Database string
}

// NewMySQLClient opens a connection pool for the given configuration, pings
// the server once, and configures pool limits.
//
// It returns an error when config is nil, when the connection cannot be
// opened, or when the ping fails. The pool is closed again on every failure
// path that occurs after it was obtained.
func NewMySQLClient(config *MySQLConfig) (*MySQLClient, error) {
	if config == nil {
		return nil, fmt.Errorf("MySQL 配置不能为空")
	}

	// Build the DSN.
	mysqlConfig := mysqldriver.Config{
		User:   config.Username,
		Passwd: config.Password,
		Net:    "tcp",
		Addr:   fmt.Sprintf("%s:%d", config.Host, config.Port),
		DBName: config.Database,
		Params: map[string]string{
			"charset":   "utf8mb4",
			"parseTime": "True",
			"loc":       "Local",
			// NOTE(review): multiStatements lets a single call run several
			// statements, which widens the impact of any SQL injection in
			// callers. Kept because callers rely on multi-statement execution.
			"multiStatements": "true",
		},
		AllowNativePasswords: true,
		Timeout:              5 * time.Second,
	}
	dsn := mysqlConfig.FormatDSN()

	// GORM is configured with a silent logger.
	gormConfig := &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	}

	db, err := gorm.Open(mysql.Open(dsn), gormConfig)
	if err != nil {
		return nil, fmt.Errorf("连接 MySQL 失败: %w", err)
	}

	// Obtain the underlying *sql.DB.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("获取数据库实例失败: %w", err)
	}

	// Verify the connection; release the pool if the server is unreachable.
	if err := sqlDB.Ping(); err != nil {
		// The ping error is the one worth reporting; a Close error adds nothing.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("MySQL 连接测试失败: %w", err)
	}

	// Connection pool limits.
	sqlDB.SetMaxOpenConns(10)
	sqlDB.SetMaxIdleConns(2)
	sqlDB.SetConnMaxLifetime(5 * time.Minute)

	return &MySQLClient{
		db:     db,
		sqlDB:  sqlDB,
		config: config,
	}, nil
}

// TestMySQLConnection reports whether a client can be created (and therefore
// pinged) with the given settings. The temporary client is closed before
// returning.
func TestMySQLConnection(host string, port int, username, password, database string) error {
	config := &MySQLConfig{
		Host:     host,
		Port:     port,
		Username: username,
		Password: password,
		Database: database,
	}

	client, err := NewMySQLClient(config)
	if err != nil {
		return err
	}
	// Best-effort cleanup: the connectivity check already succeeded.
	defer client.Close()

	return nil
}

// Close closes the underlying connection pool, if there is one.
func (c *MySQLClient) Close() error {
	if c.sqlDB != nil {
		return c.sqlDB.Close()
	}
	return nil
}

// QueryResult is the result of a query: the rows as column-name keyed maps,
// plus the column names in result-set order.
type QueryResult struct {
	Data    []map[string]interface{}
	Columns []string
}

// contextOrBackground returns ctx, or context.Background() when ctx is nil.
// Earlier versions of this client ignored ctx, so nil was tolerated.
func contextOrBackground(ctx context.Context) context.Context {
	if ctx == nil {
		return context.Background()
	}
	return ctx
}

// quoteIdent quotes a MySQL identifier with backticks, doubling any backtick
// inside the name so the identifier cannot terminate the quoting early.
func quoteIdent(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// withDatabase runs fn on a single dedicated connection after switching that
// connection to database (or to the configured database when database is
// empty). Pinning one connection guarantees that the USE statement and the
// statements issued by fn run on the same session; with a pooled handle they
// may be dispatched to different connections.
//
// NOTE(review): the connection goes back to the pool with the selected schema
// still active, so statements that rely on the default schema may observe it.
func (c *MySQLClient) withDatabase(ctx context.Context, database string, fn func(ctx context.Context, conn *sql.Conn) error) error {
	ctx = contextOrBackground(ctx)

	dbName := database
	if dbName == "" {
		dbName = c.config.Database
	}

	conn, err := c.sqlDB.Conn(ctx)
	if err != nil {
		return fmt.Errorf("获取数据库连接失败: %w", err)
	}
	defer conn.Close()

	if dbName != "" {
		if _, err := conn.ExecContext(ctx, "USE "+quoteIdent(dbName)); err != nil {
			return fmt.Errorf("切换数据库失败: %w", err)
		}
	}
	return fn(ctx, conn)
}

// scanRows reads every remaining row into a map keyed by column name.
// []byte values are converted to string; NULL stays nil. It returns a nil
// slice when there are no rows, and reports iteration errors from rows.Err.
func scanRows(rows *sql.Rows, columns []string) ([]map[string]interface{}, error) {
	var results []map[string]interface{}

	// The scan buffers are reused across rows: every value is copied into the
	// row map before the next Scan overwrites the buffer.
	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, fmt.Errorf("扫描数据失败: %w", err)
		}

		row := make(map[string]interface{}, len(columns))
		for i, col := range columns {
			if b, ok := values[i].([]byte); ok {
				row[col] = string(b)
			} else {
				row[col] = values[i]
			}
		}
		results = append(results, row)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("读取数据时发生错误: %w", err)
	}
	return results, nil
}

// ExecuteQuery runs a query and returns its rows and column order.
//
// database is optional: when non-empty it is selected first, otherwise the
// configured database is used. The SQL is executed as given; any LIMIT/OFFSET
// is expected to have been added by the caller. Data and Columns are never
// nil on success.
func (c *MySQLClient) ExecuteQuery(ctx context.Context, sqlStr string, database string) (*QueryResult, error) {
	var result *QueryResult

	err := c.withDatabase(ctx, database, func(ctx context.Context, conn *sql.Conn) error {
		rows, err := conn.QueryContext(ctx, sqlStr)
		if err != nil {
			return fmt.Errorf("执行查询失败: %w", err)
		}
		defer rows.Close()

		columns, err := rows.Columns()
		if err != nil {
			return fmt.Errorf("获取列名失败: %w", err)
		}

		// A statement without a result set yields empty (non-nil) slices.
		if len(columns) == 0 {
			result = &QueryResult{
				Data:    []map[string]interface{}{},
				Columns: []string{},
			}
			return nil
		}

		data, err := scanRows(rows, columns)
		if err != nil {
			return err
		}
		if data == nil {
			// Keep the zero-row case consistent with the zero-column case.
			data = []map[string]interface{}{}
		}

		result = &QueryResult{
			Data:    data,
			Columns: columns,
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}

// ExecuteUpdate runs a data-modifying statement (INSERT/UPDATE/DELETE) and
// returns the number of affected rows.
//
// database is optional: when non-empty it is selected first, otherwise the
// configured database is used.
func (c *MySQLClient) ExecuteUpdate(ctx context.Context, sqlStr string, database string) (int64, error) {
	var affected int64

	err := c.withDatabase(ctx, database, func(ctx context.Context, conn *sql.Conn) error {
		res, err := conn.ExecContext(ctx, sqlStr)
		if err != nil {
			return fmt.Errorf("执行更新失败: %w", err)
		}
		n, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("执行更新失败: %w", err)
		}
		affected = n
		return nil
	})
	if err != nil {
		return 0, err
	}
	return affected, nil
}

// ListDatabases returns the database names reported by SHOW DATABASES.
func (c *MySQLClient) ListDatabases(ctx context.Context) ([]string, error) {
	var databases []string
	err := c.db.WithContext(contextOrBackground(ctx)).Raw("SHOW DATABASES").Scan(&databases).Error
	return databases, err
}

// ListTables returns the table names of database, or of the connection's
// current database when database is empty.
func (c *MySQLClient) ListTables(ctx context.Context, database string) ([]string, error) {
	query := "SHOW TABLES"
	if database != "" {
		query = "SHOW TABLES FROM " + quoteIdent(database)
	}

	var tables []string
	err := c.db.WithContext(contextOrBackground(ctx)).Raw(query).Scan(&tables).Error
	return tables, err
}

// GetTableStructure returns one map per column of the table, as reported by
// SHOW FULL COLUMNS (which includes the column comment). Every returned row
// has a "Comment" key; it is set to "" when the server did not return one.
func (c *MySQLClient) GetTableStructure(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
	target := quoteIdent(tableName)
	if database != "" {
		target = quoteIdent(database) + "." + target
	}
	query := "SHOW FULL COLUMNS FROM " + target

	rows, err := c.db.WithContext(contextOrBackground(ctx)).Raw(query).Rows()
	if err != nil {
		return nil, fmt.Errorf("获取表结构失败: %w", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("获取列名失败: %w", err)
	}

	results, err := scanRows(rows, columns)
	if err != nil {
		return nil, err
	}

	for _, row := range results {
		if _, ok := row["Comment"]; !ok {
			row["Comment"] = ""
		}
	}
	return results, nil
}

// GetIndexes returns the rows reported by SHOW INDEX for the table. MySQL
// reports one row per indexed column, so a composite index spans several rows
// sharing the same Key_name.
func (c *MySQLClient) GetIndexes(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
	target := quoteIdent(tableName)
	if database != "" {
		target = quoteIdent(database) + "." + target
	}
	query := "SHOW INDEX FROM " + target

	rows, err := c.db.WithContext(contextOrBackground(ctx)).Raw(query).Rows()
	if err != nil {
		return nil, fmt.Errorf("获取索引列表失败: %w", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("获取列名失败: %w", err)
	}

	return scanRows(rows, columns)
}

// PreviewTableStructure computes the ALTER TABLE statements needed to turn the
// current table structure into the requested one, without executing them.
//
// structure may carry "columns" and "indexes", each a []interface{} of
// map[string]interface{} rows using the key names of SHOW FULL COLUMNS and
// SHOW INDEX; entries of any other type are ignored.
func (c *MySQLClient) PreviewTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
	currentColumns, err := c.GetTableStructure(ctx, database, tableName)
	if err != nil {
		return nil, fmt.Errorf("获取当前表结构失败: %w", err)
	}

	currentIndexes, err := c.GetIndexes(ctx, database, tableName)
	if err != nil {
		return nil, fmt.Errorf("获取当前索引失败: %w", err)
	}

	// Decode the requested structure.
	var newColumns []map[string]interface{}
	var newIndexes []map[string]interface{}

	if cols, ok := structure["columns"].([]interface{}); ok {
		for _, col := range cols {
			if colMap, ok := col.(map[string]interface{}); ok {
				newColumns = append(newColumns, colMap)
			}
		}
	}

	if idxs, ok := structure["indexes"].([]interface{}); ok {
		for _, idx := range idxs {
			if idxMap, ok := idx.(map[string]interface{}); ok {
				newIndexes = append(newIndexes, idxMap)
			}
		}
	}

	// Column changes first, then index changes.
	var alterStatements []string
	alterStatements = append(alterStatements, c.buildColumnAlterStatements(tableName, currentColumns, newColumns)...)
	alterStatements = append(alterStatements, c.buildIndexAlterStatements(tableName, currentIndexes, newIndexes)...)

	return alterStatements, nil
}

// UpdateTableStructure generates the ALTER TABLE statements for the requested
// structure and executes them in order on one connection.
//
// The generated statements are returned even when execution fails, so the
// caller can show what was attempted. Execution stops at the first failing
// statement; statements that already ran are not undone here.
func (c *MySQLClient) UpdateTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
	alterStatements, err := c.PreviewTableStructure(ctx, database, tableName, structure)
	if err != nil {
		return nil, err
	}
	if len(alterStatements) == 0 {
		return alterStatements, nil
	}

	err = c.withDatabase(ctx, database, func(ctx context.Context, conn *sql.Conn) error {
		for _, stmt := range alterStatements {
			if _, err := conn.ExecContext(ctx, stmt); err != nil {
				return fmt.Errorf("执行 ALTER TABLE 失败: %w, SQL: %s", err, stmt)
			}
		}
		return nil
	})
	return alterStatements, err
}

// indexOfField returns the position of name in order, or -1 when absent.
func indexOfField(order []string, name string) int {
	for i, candidate := range order {
		if candidate == name {
			return i
		}
	}
	return -1
}

// buildColumnAlterStatements builds the ALTER TABLE statements for column
// changes: renames first, then additions/modifications/reordering in the new
// column order, then drops. Statement order is deterministic and follows the
// current and new column orders.
func (c *MySQLClient) buildColumnAlterStatements(tableName string, currentColumns, newColumns []map[string]interface{}) []string {
	var statements []string
	table := quoteIdent(tableName)

	// Current columns: lookup by name plus original order.
	currentFieldMap := make(map[string]map[string]interface{}, len(currentColumns))
	currentFieldOrder := make([]string, 0, len(currentColumns))
	for _, col := range currentColumns {
		if field, ok := col["Field"].(string); ok {
			currentFieldMap[field] = col
			currentFieldOrder = append(currentFieldOrder, field)
		}
	}

	// New columns: entries without a usable Field are dropped here, so that
	// newFieldOrder and newFieldDefs stay index-aligned with each other.
	newFieldMap := make(map[string]bool, len(newColumns))
	newFieldOrder := make([]string, 0, len(newColumns))
	newFieldDefs := make([]map[string]interface{}, 0, len(newColumns))
	newColumnsMap := make(map[string]map[string]interface{}, len(newColumns))
	for _, col := range newColumns {
		if field, ok := col["Field"].(string); ok && field != "" {
			newFieldMap[field] = true
			newFieldOrder = append(newFieldOrder, field)
			newFieldDefs = append(newFieldDefs, col)
			newColumnsMap[field] = col
		}
	}

	// Rename detection.
	renameMap := make(map[string]string)   // old name -> new name
	renamedFrom := make(map[string]string) // new name -> old name
	processedNewFields := make(map[string]bool)
	recordRename := func(oldName, newName string) {
		renameMap[oldName] = newName
		renamedFrom[newName] = oldName
		processedNewFields[newName] = true
	}

	// Step 1: positional match. A vanished name whose slot now holds an
	// unknown name of the same type is treated as a rename.
	for oldIndex, oldFieldName := range currentFieldOrder {
		if newFieldMap[oldFieldName] {
			continue // name still present
		}
		if oldIndex >= len(newFieldOrder) {
			continue
		}

		newFieldName := newFieldOrder[oldIndex]
		if _, existsInCurrent := currentFieldMap[newFieldName]; existsInCurrent || processedNewFields[newFieldName] {
			continue
		}

		oldType := getStringValue(currentFieldMap[oldFieldName]["Type"])
		newType := getStringValue(newColumnsMap[newFieldName]["Type"])
		if oldType == newType {
			recordRename(oldFieldName, newFieldName)
		}
	}

	// Step 2: property match for names still unmatched. A rename is assumed
	// only when exactly one unknown new column has identical properties.
	for _, oldFieldName := range currentFieldOrder {
		if newFieldMap[oldFieldName] || renameMap[oldFieldName] != "" {
			continue
		}
		oldCol := currentFieldMap[oldFieldName]

		matched := ""
		ambiguous := false
		for _, newFieldName := range newFieldOrder {
			if processedNewFields[newFieldName] || newFieldName == matched {
				continue
			}
			if _, existsInCurrent := currentFieldMap[newFieldName]; existsInCurrent {
				continue
			}
			if !c.isColumnPropertiesEqual(oldCol, newColumnsMap[newFieldName]) {
				continue
			}
			if matched != "" {
				ambiguous = true // several candidates: cannot decide
				break
			}
			matched = newFieldName
		}

		if matched != "" && !ambiguous {
			recordRename(oldFieldName, matched)
		}
	}

	// Renames, in current column order.
	for _, oldName := range currentFieldOrder {
		if newName := renameMap[oldName]; newName != "" {
			stmt := fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", table, quoteIdent(oldName), quoteIdent(newName))
			statements = append(statements, stmt)
		}
	}

	// Additions, modifications and repositioning, in new column order. The
	// index i is a position in newFieldOrder, which is what the position and
	// order helpers expect.
	for i, field := range newFieldOrder {
		newCol := newFieldDefs[i]

		if oldName, renamed := renamedFrom[field]; renamed {
			// A renamed column still needs MODIFY when its properties or its
			// position changed.
			needsModify := c.isColumnChanged(currentFieldMap[oldName], newCol)
			oldIndex := indexOfField(currentFieldOrder, oldName)
			needsReorder := oldIndex != -1 && oldIndex != i

			if needsModify || needsReorder {
				if stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i); stmt != "" {
					statements = append(statements, stmt)
				}
			}
			continue
		}

		if currentCol, exists := currentFieldMap[field]; exists {
			needsModify := c.isColumnChanged(currentCol, newCol)
			needsReorder := c.isColumnOrderChanged(currentFieldOrder, newFieldOrder, field, i)

			if needsModify || needsReorder {
				if stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i); stmt != "" {
					statements = append(statements, stmt)
				}
			}
			continue
		}

		if stmt := c.buildAddColumnStatement(tableName, newCol, newFieldOrder, i); stmt != "" {
			statements = append(statements, stmt)
		}
	}

	// Drops: columns that are neither kept nor renamed, in current order.
	for _, field := range currentFieldOrder {
		if !newFieldMap[field] && renameMap[field] == "" {
			stmt := fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", table, quoteIdent(field))
			statements = append(statements, stmt)
		}
	}

	return statements
}

// indexColumns splits a Column_name value into its column names. A single
// name yields one element; a comma-separated list yields several. Names are
// trimmed, empty parts are skipped, and repeated names are kept once.
func indexColumns(v interface{}) []string {
	var cols []string
	seen := make(map[string]bool)
	for _, part := range strings.Split(getStringValue(v), ",") {
		name := strings.TrimSpace(part)
		if name == "" || seen[name] {
			continue
		}
		seen[name] = true
		cols = append(cols, name)
	}
	return cols
}

// mergeIndexRows collapses SHOW INDEX style rows (one row per indexed column)
// into one row per index, joining the column names with ",". PRIMARY and rows
// without a string Key_name are skipped. It returns the index names in first
// appearance order together with the merged rows. Non_unique is taken from the
// first row of each index.
//
// NOTE(review): columns are joined in row order; this assumes rows arrive in
// Seq_in_index order, as SHOW INDEX returns them — confirm for client input.
func mergeIndexRows(rows []map[string]interface{}) ([]string, map[string]map[string]interface{}) {
	order := make([]string, 0, len(rows))
	merged := make(map[string]map[string]interface{}, len(rows))

	for _, row := range rows {
		keyName, ok := row["Key_name"].(string)
		if !ok || keyName == "" || keyName == "PRIMARY" {
			continue
		}
		column := getStringValue(row["Column_name"])

		entry, seen := merged[keyName]
		if !seen {
			merged[keyName] = map[string]interface{}{
				"Key_name":    keyName,
				"Column_name": column,
				"Non_unique":  row["Non_unique"],
			}
			order = append(order, keyName)
			continue
		}

		if column == "" {
			continue
		}
		if prev := getStringValue(entry["Column_name"]); prev != "" {
			entry["Column_name"] = prev + "," + column
		} else {
			entry["Column_name"] = column
		}
	}
	return order, merged
}

// buildIndexAlterStatements builds the ALTER TABLE statements for index
// changes. Rows that share a Key_name are treated as one (composite) index.
// Changed indexes are dropped and re-added, new ones are added, and indexes
// missing from newIndexes are dropped. The PRIMARY key is never touched.
//
// Only the column list and uniqueness are preserved when an index is
// re-created; other attributes present in the rows are not read.
func (c *MySQLClient) buildIndexAlterStatements(tableName string, currentIndexes, newIndexes []map[string]interface{}) []string {
	var statements []string
	table := quoteIdent(tableName)

	currentOrder, currentIndexMap := mergeIndexRows(currentIndexes)
	newOrder, newIndexMap := mergeIndexRows(newIndexes)

	// Changed and new indexes, in requested order.
	for _, keyName := range newOrder {
		newIdx := newIndexMap[keyName]

		if currentIdx, exists := currentIndexMap[keyName]; exists {
			if !c.isIndexChanged(currentIdx, newIdx) {
				continue
			}
			// Only drop when the replacement can actually be built.
			if addStmt := c.buildAddIndexStatement(tableName, newIdx); addStmt != "" {
				dropStmt := fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", table, quoteIdent(keyName))
				statements = append(statements, dropStmt, addStmt)
			}
			continue
		}

		if stmt := c.buildAddIndexStatement(tableName, newIdx); stmt != "" {
			statements = append(statements, stmt)
		}
	}

	// Indexes that no longer exist in the requested structure.
	for _, keyName := range currentOrder {
		if _, keep := newIndexMap[keyName]; !keep {
			stmt := fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", table, quoteIdent(keyName))
			statements = append(statements, stmt)
		}
	}

	return statements
}

// isColumnChanged reports whether any of Type, Null, Default, Extra or Comment
// differs between the two column rows. The column name is not compared.
func (c *MySQLClient) isColumnChanged(oldCol, newCol map[string]interface{}) bool {
	fields := []string{"Type", "Null", "Default", "Extra", "Comment"}
	for _, field := range fields {
		if getStringValue(oldCol[field]) != getStringValue(newCol[field]) {
			return true
		}
	}
	return false
}

// isColumnPropertiesEqual reports whether Type, Null, Default, Extra, Key and
// Comment are all equal between the two column rows. The column name is not
// compared.
func (c *MySQLClient) isColumnPropertiesEqual(oldCol, newCol map[string]interface{}) bool {
	fields := []string{"Type", "Null", "Default", "Extra", "Key", "Comment"}
	for _, field := range fields {
		if getStringValue(oldCol[field]) != getStringValue(newCol[field]) {
			return false
		}
	}
	return true
}

// isColumnOrderChanged reports whether fieldName sits at a different place in
// newOrder (at newIndex) than in currentOrder. A column absent from
// currentOrder is never considered moved. When the index is unchanged, the
// column still counts as moved if the set of columns before it differs.
func (c *MySQLClient) isColumnOrderChanged(currentOrder, newOrder []string, fieldName string, newIndex int) bool {
	currentIndex := indexOfField(currentOrder, fieldName)
	if currentIndex == -1 {
		return false // new column: nothing to compare against
	}
	if newIndex != currentIndex {
		return true
	}
	if newIndex <= 0 {
		return false
	}

	// Same index: compare the sets of preceding columns.
	currentPrevFields := make(map[string]bool, currentIndex)
	for _, name := range currentOrder[:currentIndex] {
		currentPrevFields[name] = true
	}

	// Never read past the end of newOrder, whatever index the caller passed.
	limit := newIndex
	if limit > len(newOrder) {
		limit = len(newOrder)
	}
	newPrevFields := make(map[string]bool, limit)
	for _, name := range newOrder[:limit] {
		newPrevFields[name] = true
	}

	if len(currentPrevFields) != len(newPrevFields) {
		return true
	}
	for name := range currentPrevFields {
		if !newPrevFields[name] {
			return true
		}
	}
	return false
}

// isIndexChanged reports whether the column list or the Non_unique flag
// differs between two index rows. Column lists are compared after splitting
// on commas and trimming, so "a, b" equals "a,b".
func (c *MySQLClient) isIndexChanged(oldIdx, newIdx map[string]interface{}) bool {
	oldCols := strings.Join(indexColumns(oldIdx["Column_name"]), ",")
	newCols := strings.Join(indexColumns(newIdx["Column_name"]), ",")
	if oldCols != newCols {
		return true
	}
	return getIntValue(oldIdx["Non_unique"]) != getIntValue(newIdx["Non_unique"])
}

// buildAddColumnStatement builds an ALTER TABLE ... ADD COLUMN statement for
// col, positioned according to index within fieldOrder. It returns "" when
// col has no Field.
func (c *MySQLClient) buildAddColumnStatement(tableName string, col map[string]interface{}, fieldOrder []string, index int) string {
	if getStringValue(col["Field"]) == "" {
		return ""
	}

	colDef := c.buildColumnDefinition(col)
	position := c.buildColumnPosition(fieldOrder, index)

	return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s%s", quoteIdent(tableName), colDef, position)
}

// buildModifyColumnStatement builds an ALTER TABLE ... MODIFY COLUMN statement
// for col, positioned according to index within fieldOrder. The column name
// is taken from col["Field"]; the field parameter is not used.
func (c *MySQLClient) buildModifyColumnStatement(tableName, field string, col map[string]interface{}, fieldOrder []string, index int) string {
	colDef := c.buildColumnDefinition(col)
	position := c.buildColumnPosition(fieldOrder, index)

	return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s%s", quoteIdent(tableName), colDef, position)
}

// buildColumnPosition returns the position clause for the column at index in
// fieldOrder: " FIRST" for index 0, " AFTER <previous column>" otherwise, and
// "" when index is out of range.
func (c *MySQLClient) buildColumnPosition(fieldOrder []string, index int) string {
	if index < 0 || index >= len(fieldOrder) {
		return ""
	}
	if index == 0 {
		return " FIRST"
	}
	return " AFTER " + quoteIdent(fieldOrder[index-1])
}

// isTemporalColumnType reports whether colType is a TIMESTAMP or DATETIME
// type (with or without a fractional-seconds precision).
func isTemporalColumnType(colType string) bool {
	t := strings.ToLower(strings.TrimSpace(colType))
	return strings.HasPrefix(t, "timestamp") || strings.HasPrefix(t, "datetime")
}

// isCurrentTimestampDefault reports whether s is exactly CURRENT_TIMESTAMP,
// CURRENT_TIMESTAMP() or CURRENT_TIMESTAMP(n) with n in 0..6, ignoring case
// and surrounding whitespace. The strict match keeps arbitrary text from
// being emitted unquoted.
func isCurrentTimestampDefault(s string) bool {
	u := strings.ToUpper(strings.TrimSpace(s))
	if !strings.HasPrefix(u, "CURRENT_TIMESTAMP") {
		return false
	}
	rest := strings.TrimPrefix(u, "CURRENT_TIMESTAMP")
	switch len(rest) {
	case 0:
		return true
	case 2:
		return rest == "()"
	case 3:
		return rest[0] == '(' && rest[1] >= '0' && rest[1] <= '6' && rest[2] == ')'
	}
	return false
}

// normalizeColumnExtra removes the DEFAULT_GENERATED marker from an Extra
// value. SHOW COLUMNS reports that marker for expression defaults, but it is
// not valid inside a column definition. Remaining words are joined with
// single spaces.
func normalizeColumnExtra(extra string) string {
	words := strings.Fields(extra)
	kept := words[:0]
	for _, w := range words {
		if !strings.EqualFold(w, "DEFAULT_GENERATED") {
			kept = append(kept, w)
		}
	}
	return strings.Join(kept, " ")
}

// buildColumnDefinition renders a column definition (name, type, NOT NULL,
// DEFAULT, extra attributes, COMMENT) from a SHOW FULL COLUMNS style row.
//
// Default handling: nil or the string "NULL" emits no DEFAULT clause; "" emits
// DEFAULT ”; CURRENT_TIMESTAMP on a TIMESTAMP/DATETIME column is emitted as
// an expression; any other string is emitted as a quoted literal with single
// quotes doubled; non-string values are formatted with %v.
func (c *MySQLClient) buildColumnDefinition(col map[string]interface{}) string {
	field := getStringValue(col["Field"])
	colType := getStringValue(col["Type"])
	null := getStringValue(col["Null"])
	defaultVal := col["Default"]
	extra := normalizeColumnExtra(getStringValue(col["Extra"]))
	comment := getStringValue(col["Comment"])

	def := quoteIdent(field) + " " + colType

	if null == "NO" {
		def += " NOT NULL"
	}

	switch d := defaultVal.(type) {
	case nil:
		// No DEFAULT clause.
	case string:
		switch {
		case d == "":
			def += " DEFAULT ''"
		case d == "NULL":
			// No DEFAULT clause: the column simply allows NULL.
		case isTemporalColumnType(colType) && isCurrentTimestampDefault(d):
			// An expression, not a string literal: quoting it would yield an
			// invalid default for a temporal column.
			def += " DEFAULT " + strings.TrimSpace(d)
		default:
			escapedDefault := strings.ReplaceAll(d, "'", "''")
			def += fmt.Sprintf(" DEFAULT '%s'", escapedDefault)
		}
	default:
		def += fmt.Sprintf(" DEFAULT %v", defaultVal)
	}

	if extra != "" {
		def += " " + extra
	}

	if comment != "" {
		escapedComment := strings.ReplaceAll(comment, "'", "''")
		def += fmt.Sprintf(" COMMENT '%s'", escapedComment)
	}

	return def
}

// buildAddIndexStatement builds an ALTER TABLE ... ADD [UNIQUE] INDEX
// statement from an index row. Column_name may hold one column or a
// comma-separated list for a composite index. A Non_unique value of 0
// (including a missing value) produces a UNIQUE index. It returns "" when the
// index name or the column list is empty.
func (c *MySQLClient) buildAddIndexStatement(tableName string, idx map[string]interface{}) string {
	keyName := getStringValue(idx["Key_name"])
	columns := indexColumns(idx["Column_name"])
	if keyName == "" || len(columns) == 0 {
		return ""
	}

	indexType := "INDEX"
	if getIntValue(idx["Non_unique"]) == 0 {
		indexType = "UNIQUE INDEX"
	}

	quoted := make([]string, len(columns))
	for i, name := range columns {
		quoted[i] = quoteIdent(name)
	}

	return fmt.Sprintf("ALTER TABLE %s ADD %s %s (%s)", quoteIdent(tableName), indexType, quoteIdent(keyName), strings.Join(quoted, ", "))
}

// getStringValue converts v to a string: nil becomes "", a string is returned
// unchanged, anything else is formatted with %v.
func getStringValue(v interface{}) string {
	if v == nil {
		return ""
	}
	if s, ok := v.(string); ok {
		return s
	}
	return fmt.Sprintf("%v", v)
}

// getIntValue converts v to an int. Integer and float kinds are converted
// directly; strings and []byte are parsed as a leading decimal integer; nil
// and unsupported types yield 0.
func getIntValue(v interface{}) int {
	parse := func(s string) int {
		var n int
		// A parse failure deliberately leaves n at 0, the documented fallback.
		_, _ = fmt.Sscanf(s, "%d", &n)
		return n
	}

	switch val := v.(type) {
	case nil:
		return 0
	case int:
		return val
	case int32:
		return int(val)
	case int64:
		return int(val)
	case uint64:
		return int(val)
	case float64:
		return int(val)
	case string:
		return parse(val)
	case []byte:
		return parse(string(val))
	}
	return 0
}