新增:连接管理、数据查询等功能
This commit is contained in:
875
internal/dbclient/mysql.go
Normal file
875
internal/dbclient/mysql.go
Normal file
@@ -0,0 +1,875 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mysqldriver "github.com/go-sql-driver/mysql"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// MySQLClient MySQL 客户端
|
||||
type MySQLClient struct {
|
||||
db *gorm.DB
|
||||
sqlDB *sql.DB
|
||||
config *MySQLConfig
|
||||
}
|
||||
|
||||
// MySQLConfig MySQL 配置
|
||||
type MySQLConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
}
|
||||
|
||||
// NewMySQLClient 创建 MySQL 客户端
|
||||
func NewMySQLClient(config *MySQLConfig) (*MySQLClient, error) {
|
||||
// 构建 DSN
|
||||
mysqlConfig := mysqldriver.Config{
|
||||
User: config.Username,
|
||||
Passwd: config.Password,
|
||||
Net: "tcp",
|
||||
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
|
||||
DBName: config.Database,
|
||||
Params: map[string]string{
|
||||
"charset": "utf8mb4",
|
||||
"parseTime": "True",
|
||||
"loc": "Local",
|
||||
"multiStatements": "true", // 支持多条SQL语句执行
|
||||
},
|
||||
AllowNativePasswords: true,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
dsn := mysqlConfig.FormatDSN()
|
||||
|
||||
// GORM 配置
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
// 打开连接
|
||||
db, err := gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 MySQL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取底层 sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("MySQL 连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
sqlDB.SetMaxOpenConns(10)
|
||||
sqlDB.SetMaxIdleConns(2)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
return &MySQLClient{
|
||||
db: db,
|
||||
sqlDB: sqlDB,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestConnection 测试连接
|
||||
func TestMySQLConnection(host string, port int, username, password, database string) error {
|
||||
config := &MySQLConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
}
|
||||
client, err := NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *MySQLClient) Close() error {
|
||||
if c.sqlDB != nil {
|
||||
return c.sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryResult 查询结果,包含数据和列顺序
|
||||
type QueryResult struct {
|
||||
Data []map[string]interface{}
|
||||
Columns []string
|
||||
}
|
||||
|
||||
// ExecuteQuery 执行查询 SQL
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
// 注意:SQL 语句应该已经包含 LIMIT 和 OFFSET(由客户端添加)
|
||||
func (c *MySQLClient) ExecuteQuery(ctx context.Context, sqlStr string, database string) (*QueryResult, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return nil, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.Raw(sqlStr).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("执行查询失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 检查 rows 错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("查询结果错误: %v", err)
|
||||
}
|
||||
|
||||
// 获取列名
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有列,返回空数组
|
||||
if len(columns) == 0 {
|
||||
return &QueryResult{
|
||||
Data: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
// 创建值数组和指针数组
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
// 扫描行数据
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 构建结果 map,按照列顺序构建
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
// 处理 nil 值
|
||||
if val == nil {
|
||||
row[col] = nil
|
||||
} else if b, ok := val.([]byte); ok {
|
||||
// 处理 []byte 类型
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
// 检查迭代过程中的错误
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("读取数据时发生错误: %v", err)
|
||||
}
|
||||
|
||||
return &QueryResult{
|
||||
Data: results,
|
||||
Columns: columns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecuteUpdate 执行更新 SQL(INSERT/UPDATE/DELETE)
|
||||
// database 参数可选,如果提供且不为空则优先使用,否则使用配置中的数据库
|
||||
func (c *MySQLClient) ExecuteUpdate(ctx context.Context, sqlStr string, database string) (int64, error) {
|
||||
// 确定要使用的数据库
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
// 使用 Session 创建独立的数据库会话,避免影响其他查询
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
|
||||
// 如果指定了数据库,先切换到该数据库
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return 0, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Exec(sqlStr)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("执行更新失败: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// ListDatabases 获取数据库列表
|
||||
func (c *MySQLClient) ListDatabases(ctx context.Context) ([]string, error) {
|
||||
var databases []string
|
||||
err := c.db.Raw("SHOW DATABASES").Scan(&databases).Error
|
||||
return databases, err
|
||||
}
|
||||
|
||||
// ListTables 获取表列表
|
||||
func (c *MySQLClient) ListTables(ctx context.Context, database string) ([]string, error) {
|
||||
var tables []string
|
||||
query := "SHOW TABLES"
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW TABLES FROM `%s`", database)
|
||||
}
|
||||
err := c.db.Raw(query).Scan(&tables).Error
|
||||
return tables, err
|
||||
}
|
||||
|
||||
// GetTableStructure 获取表结构
|
||||
func (c *MySQLClient) GetTableStructure(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
// 使用 SHOW FULL COLUMNS 来获取包含 comment 的完整字段信息
|
||||
query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName)
|
||||
if database != "" {
|
||||
query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", database, tableName)
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取表结构失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else if val == nil {
|
||||
row[col] = nil
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 Comment 字段存在(SHOW FULL COLUMNS 返回的字段名是 Comment)
|
||||
if _, ok := row["Comment"]; !ok {
|
||||
row["Comment"] = ""
|
||||
}
|
||||
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetIndexes 获取索引列表
|
||||
func (c *MySQLClient) GetIndexes(ctx context.Context, database, tableName string) ([]map[string]interface{}, error) {
|
||||
query := "SHOW INDEX FROM "
|
||||
if database != "" {
|
||||
query += fmt.Sprintf("`%s`.", database)
|
||||
}
|
||||
query += fmt.Sprintf("`%s`", tableName)
|
||||
|
||||
rows, err := c.db.Raw(query).Rows()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取索引列表失败: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取列名失败: %v", err)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("扫描数据失败: %v", err)
|
||||
}
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
val := values[i]
|
||||
if b, ok := val.([]byte); ok {
|
||||
row[col] = string(b)
|
||||
} else {
|
||||
row[col] = val
|
||||
}
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// PreviewTableStructure 预览表结构变更,只生成 SQL 语句不执行
|
||||
func (c *MySQLClient) PreviewTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 获取当前表结构
|
||||
currentColumns, err := c.GetTableStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前表结构失败: %v", err)
|
||||
}
|
||||
|
||||
currentIndexes, err := c.GetIndexes(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取当前索引失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析新的结构数据
|
||||
var newColumns []map[string]interface{}
|
||||
var newIndexes []map[string]interface{}
|
||||
|
||||
if cols, ok := structure["columns"].([]interface{}); ok {
|
||||
for _, col := range cols {
|
||||
if colMap, ok := col.(map[string]interface{}); ok {
|
||||
newColumns = append(newColumns, colMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if idxs, ok := structure["indexes"].([]interface{}); ok {
|
||||
for _, idx := range idxs {
|
||||
if idxMap, ok := idx.(map[string]interface{}); ok {
|
||||
newIndexes = append(newIndexes, idxMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ALTER TABLE 语句
|
||||
var alterStatements []string
|
||||
|
||||
// 处理字段变更
|
||||
alterStatements = append(alterStatements, c.buildColumnAlterStatements(tableName, currentColumns, newColumns)...)
|
||||
|
||||
// 处理索引变更
|
||||
alterStatements = append(alterStatements, c.buildIndexAlterStatements(tableName, currentIndexes, newIndexes)...)
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// UpdateTableStructure 更新表结构,返回生成的 SQL 语句列表
|
||||
func (c *MySQLClient) UpdateTableStructure(ctx context.Context, database, tableName string, structure map[string]interface{}) ([]string, error) {
|
||||
// 先预览生成 SQL 语句
|
||||
alterStatements, err := c.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行所有 ALTER TABLE 语句
|
||||
if len(alterStatements) > 0 {
|
||||
dbName := database
|
||||
if dbName == "" {
|
||||
dbName = c.config.Database
|
||||
}
|
||||
|
||||
db := c.db.Session(&gorm.Session{})
|
||||
if dbName != "" {
|
||||
if err := db.Exec(fmt.Sprintf("USE `%s`", dbName)).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("切换数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, stmt := range alterStatements {
|
||||
if err := db.Exec(stmt).Error; err != nil {
|
||||
return alterStatements, fmt.Errorf("执行 ALTER TABLE 失败: %v, SQL: %s", err, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return alterStatements, nil
|
||||
}
|
||||
|
||||
// buildColumnAlterStatements 构建字段变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildColumnAlterStatements(tableName string, currentColumns, newColumns []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建字段名映射和顺序映射
|
||||
currentFieldMap := make(map[string]map[string]interface{})
|
||||
currentFieldOrder := make([]string, 0, len(currentColumns))
|
||||
for _, col := range currentColumns {
|
||||
if field, ok := col["Field"].(string); ok {
|
||||
currentFieldMap[field] = col
|
||||
currentFieldOrder = append(currentFieldOrder, field)
|
||||
}
|
||||
}
|
||||
|
||||
newFieldMap := make(map[string]bool)
|
||||
newFieldOrder := make([]string, 0, len(newColumns))
|
||||
newColumnsMap := make(map[string]map[string]interface{})
|
||||
for _, col := range newColumns {
|
||||
if field, ok := col["Field"].(string); ok && field != "" {
|
||||
newFieldMap[field] = true
|
||||
newFieldOrder = append(newFieldOrder, field)
|
||||
newColumnsMap[field] = col
|
||||
}
|
||||
}
|
||||
|
||||
// 检测字段重命名:优先使用位置匹配,如果位置相同但字段名不同,认为是重命名
|
||||
renameMap := make(map[string]string) // oldName -> newName
|
||||
processedNewFields := make(map[string]bool)
|
||||
|
||||
// 第一步:使用位置匹配检测重命名(最可靠)
|
||||
for oldIndex, oldFieldName := range currentFieldOrder {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
|
||||
// 检查新字段列表中相同位置是否有字段
|
||||
if oldIndex < len(newFieldOrder) {
|
||||
newFieldName := newFieldOrder[oldIndex]
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent && !processedNewFields[newFieldName] {
|
||||
// 新字段不在当前字段列表中,且位置相同,很可能是重命名
|
||||
// 进一步验证:检查类型是否相同(类型相同更可能是重命名)
|
||||
oldCol := currentFieldMap[oldFieldName]
|
||||
newCol := newColumnsMap[newFieldName]
|
||||
oldType := getStringValue(oldCol["Type"])
|
||||
newType := getStringValue(newCol["Type"])
|
||||
|
||||
// 如果类型相同,认为是重命名
|
||||
if oldType == newType {
|
||||
renameMap[oldFieldName] = newFieldName
|
||||
processedNewFields[newFieldName] = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二步:对于未匹配的字段,使用属性匹配(兼容旧逻辑)
|
||||
for oldFieldName, oldCol := range currentFieldMap {
|
||||
if newFieldMap[oldFieldName] {
|
||||
continue // 字段名未改变,跳过
|
||||
}
|
||||
if renameMap[oldFieldName] != "" {
|
||||
continue // 已经通过位置匹配识别为重命名
|
||||
}
|
||||
|
||||
// 查找属性完全匹配的新字段
|
||||
var matchedNewField string
|
||||
for newFieldName, newCol := range newColumnsMap {
|
||||
if processedNewFields[newFieldName] {
|
||||
continue // 已经被匹配过了
|
||||
}
|
||||
_, existsInCurrent := currentFieldMap[newFieldName]
|
||||
if !existsInCurrent {
|
||||
// 这是一个新增字段,检查属性是否匹配
|
||||
if c.isColumnPropertiesEqual(oldCol, newCol) {
|
||||
if matchedNewField == "" {
|
||||
matchedNewField = newFieldName
|
||||
} else {
|
||||
// 有多个匹配,无法确定,不认为是重命名
|
||||
matchedNewField = ""
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到唯一匹配,认为是重命名
|
||||
if matchedNewField != "" {
|
||||
renameMap[oldFieldName] = matchedNewField
|
||||
processedNewFields[matchedNewField] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理字段重命名
|
||||
for oldName, newName := range renameMap {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` RENAME COLUMN `%s` TO `%s`", tableName, oldName, newName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
// 处理字段添加、修改和位置调整(排除已重命名的字段)
|
||||
for i, newCol := range newColumns {
|
||||
field, _ := newCol["Field"].(string)
|
||||
if field == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否是重命名的字段
|
||||
isRenamed := false
|
||||
var oldName string
|
||||
for old, new := range renameMap {
|
||||
if new == field {
|
||||
isRenamed = true
|
||||
oldName = old
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isRenamed {
|
||||
// 重命名的字段:如果属性有变化,需要 MODIFY COLUMN
|
||||
oldCol := currentFieldMap[oldName]
|
||||
needsModify := c.isColumnChanged(oldCol, newCol)
|
||||
|
||||
// 检查顺序变化:使用旧字段名在 currentOrder 中查找位置,与新位置比较
|
||||
oldIndex := -1
|
||||
for idx, name := range currentFieldOrder {
|
||||
if name == oldName {
|
||||
oldIndex = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
needsReorder := (oldIndex != -1 && oldIndex != i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
// 重命名后需要修改属性或位置
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if currentCol, exists := currentFieldMap[field]; exists {
|
||||
// 修改现有字段
|
||||
needsModify := c.isColumnChanged(currentCol, newCol)
|
||||
needsReorder := c.isColumnOrderChanged(currentFieldOrder, newFieldOrder, field, i)
|
||||
|
||||
if needsModify || needsReorder {
|
||||
stmt := c.buildModifyColumnStatement(tableName, field, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新字段(排除重命名的字段)
|
||||
stmt := c.buildAddColumnStatement(tableName, newCol, newFieldOrder, i)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的字段(排除已重命名的字段)
|
||||
for field := range currentFieldMap {
|
||||
if !newFieldMap[field] && renameMap[field] == "" {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s`", tableName, field)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// buildIndexAlterStatements 构建索引变更的 ALTER TABLE 语句
|
||||
func (c *MySQLClient) buildIndexAlterStatements(tableName string, currentIndexes, newIndexes []map[string]interface{}) []string {
|
||||
var statements []string
|
||||
|
||||
// 创建索引名映射
|
||||
currentIndexMap := make(map[string]map[string]interface{})
|
||||
for _, idx := range currentIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "PRIMARY" {
|
||||
currentIndexMap[keyName] = idx
|
||||
}
|
||||
}
|
||||
|
||||
newIndexMap := make(map[string]bool)
|
||||
for _, idx := range newIndexes {
|
||||
if keyName, ok := idx["Key_name"].(string); ok && keyName != "" && keyName != "PRIMARY" {
|
||||
newIndexMap[keyName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理索引变更
|
||||
for _, newIdx := range newIndexes {
|
||||
keyName, _ := newIdx["Key_name"].(string)
|
||||
if keyName == "" || keyName == "PRIMARY" {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentIdx, exists := currentIndexMap[keyName]; exists {
|
||||
// 修改现有索引
|
||||
if c.isIndexChanged(currentIdx, newIdx) {
|
||||
dropStmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
addStmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if addStmt != "" {
|
||||
statements = append(statements, dropStmt)
|
||||
statements = append(statements, addStmt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 添加新索引
|
||||
stmt := c.buildAddIndexStatement(tableName, newIdx)
|
||||
if stmt != "" {
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除不存在的索引
|
||||
for keyName := range currentIndexMap {
|
||||
if !newIndexMap[keyName] {
|
||||
stmt := fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`", tableName, keyName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// isColumnChanged 检查字段是否发生变化(不包括字段名)
|
||||
func (c *MySQLClient) isColumnChanged(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isColumnPropertiesEqual 检查字段属性是否完全相等(不包括字段名)
|
||||
func (c *MySQLClient) isColumnPropertiesEqual(oldCol, newCol map[string]interface{}) bool {
|
||||
fields := []string{"Type", "Null", "Default", "Extra", "Key", "Comment"}
|
||||
for _, field := range fields {
|
||||
oldVal := getStringValue(oldCol[field])
|
||||
newVal := getStringValue(newCol[field])
|
||||
if oldVal != newVal {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isColumnOrderChanged 检查字段顺序是否发生变化
|
||||
func (c *MySQLClient) isColumnOrderChanged(currentOrder, newOrder []string, fieldName string, newIndex int) bool {
|
||||
// 查找字段在当前顺序中的位置
|
||||
currentIndex := -1
|
||||
for i, name := range currentOrder {
|
||||
if name == fieldName {
|
||||
currentIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果字段不存在于当前顺序中(新字段),不需要检查顺序
|
||||
if currentIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果索引相同,检查前面的字段是否相同
|
||||
if newIndex == currentIndex {
|
||||
// 检查前面的字段集合是否相同
|
||||
if newIndex > 0 {
|
||||
currentPrevFields := make(map[string]bool)
|
||||
for i := 0; i < currentIndex; i++ {
|
||||
currentPrevFields[currentOrder[i]] = true
|
||||
}
|
||||
|
||||
newPrevFields := make(map[string]bool)
|
||||
for i := 0; i < newIndex; i++ {
|
||||
newPrevFields[newOrder[i]] = true
|
||||
}
|
||||
|
||||
// 如果前面的字段集合不同,说明顺序变了
|
||||
if len(currentPrevFields) != len(newPrevFields) {
|
||||
return true
|
||||
}
|
||||
for f := range currentPrevFields {
|
||||
if !newPrevFields[f] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 索引不同,说明顺序变了
|
||||
return true
|
||||
}
|
||||
|
||||
// isIndexChanged 检查索引是否发生变化
|
||||
func (c *MySQLClient) isIndexChanged(oldIdx, newIdx map[string]interface{}) bool {
|
||||
oldCol := getStringValue(oldIdx["Column_name"])
|
||||
newCol := getStringValue(newIdx["Column_name"])
|
||||
if oldCol != newCol {
|
||||
return true
|
||||
}
|
||||
|
||||
oldUnique := getIntValue(oldIdx["Non_unique"])
|
||||
newUnique := getIntValue(newIdx["Non_unique"])
|
||||
return oldUnique != newUnique
|
||||
}
|
||||
|
||||
// buildAddColumnStatement 构建添加字段的语句
|
||||
func (c *MySQLClient) buildAddColumnStatement(tableName string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
field := getStringValue(col["Field"])
|
||||
if field == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildModifyColumnStatement 构建修改字段的语句
|
||||
func (c *MySQLClient) buildModifyColumnStatement(tableName, field string, col map[string]interface{}, fieldOrder []string, index int) string {
|
||||
colDef := c.buildColumnDefinition(col)
|
||||
|
||||
// 确定字段位置
|
||||
position := c.buildColumnPosition(fieldOrder, index)
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` MODIFY COLUMN %s%s", tableName, colDef, position)
|
||||
}
|
||||
|
||||
// buildColumnPosition 构建字段位置子句(AFTER 或 FIRST)
|
||||
func (c *MySQLClient) buildColumnPosition(fieldOrder []string, index int) string {
|
||||
if index < 0 || index >= len(fieldOrder) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
// 第一个字段使用 FIRST
|
||||
return " FIRST"
|
||||
}
|
||||
|
||||
// 其他字段使用 AFTER 前一个字段
|
||||
prevField := fieldOrder[index-1]
|
||||
return fmt.Sprintf(" AFTER `%s`", prevField)
|
||||
}
|
||||
|
||||
// buildColumnDefinition 构建字段定义
|
||||
func (c *MySQLClient) buildColumnDefinition(col map[string]interface{}) string {
|
||||
field := getStringValue(col["Field"])
|
||||
colType := getStringValue(col["Type"])
|
||||
null := getStringValue(col["Null"])
|
||||
defaultVal := col["Default"]
|
||||
extra := getStringValue(col["Extra"])
|
||||
comment := getStringValue(col["Comment"])
|
||||
|
||||
def := fmt.Sprintf("`%s` %s", field, colType)
|
||||
|
||||
if null == "NO" {
|
||||
def += " NOT NULL"
|
||||
}
|
||||
|
||||
if defaultVal != nil {
|
||||
if defaultStr, ok := defaultVal.(string); ok {
|
||||
if defaultStr == "" {
|
||||
// 空字符串表示默认值为空字符串
|
||||
def += " DEFAULT ''"
|
||||
} else if defaultStr != "NULL" {
|
||||
// 转义单引号
|
||||
escapedDefault := strings.ReplaceAll(defaultStr, "'", "''")
|
||||
def += fmt.Sprintf(" DEFAULT '%s'", escapedDefault)
|
||||
}
|
||||
// 如果 defaultStr == "NULL",不添加 DEFAULT 子句(允许 NULL)
|
||||
} else {
|
||||
// 非字符串类型的默认值
|
||||
def += fmt.Sprintf(" DEFAULT %v", defaultVal)
|
||||
}
|
||||
}
|
||||
|
||||
if extra != "" {
|
||||
def += " " + extra
|
||||
}
|
||||
|
||||
if comment != "" {
|
||||
// 转义单引号
|
||||
escapedComment := strings.ReplaceAll(comment, "'", "''")
|
||||
def += fmt.Sprintf(" COMMENT '%s'", escapedComment)
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// buildAddIndexStatement 构建添加索引的语句
|
||||
func (c *MySQLClient) buildAddIndexStatement(tableName string, idx map[string]interface{}) string {
|
||||
keyName := getStringValue(idx["Key_name"])
|
||||
columnName := getStringValue(idx["Column_name"])
|
||||
nonUnique := getIntValue(idx["Non_unique"])
|
||||
|
||||
if keyName == "" || columnName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
indexType := "INDEX"
|
||||
if nonUnique == 0 {
|
||||
indexType = "UNIQUE INDEX"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE `%s` ADD %s `%s` (`%s`)", tableName, indexType, keyName, columnName)
|
||||
}
|
||||
|
||||
// getStringValue 安全获取字符串值
|
||||
func getStringValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// getIntValue 安全获取整数值
|
||||
func getIntValue(v interface{}) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
case float64:
|
||||
return int(val)
|
||||
case string:
|
||||
var i int
|
||||
fmt.Sscanf(val, "%d", &i)
|
||||
return i
|
||||
}
|
||||
return 0
|
||||
}
|
||||
Reference in New Issue
Block a user