package storage

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"u-desk/internal/crypto"
	"u-desk/internal/dbclient"
	"u-desk/internal/storage/models"

	"gorm.io/gorm"
)

// ConnectionService manages persisted database connection configurations
// (models.DbConnection): saving, listing, deleting, testing them, and
// listing the databases reachable through them.
type ConnectionService struct {
	db *gorm.DB
}

// NewConnectionService creates a ConnectionService using the handle returned
// by GetDB. If GetDB returns nil, it tries to initialize the database via
// Init and returns the wrapped error if that fails.
func NewConnectionService() (*ConnectionService, error) {
	db := GetDB()
	if db == nil {
		// No handle yet: try to (re)initialize the database.
		var err error
		db, err = Init()
		if err != nil {
			return nil, fmt.Errorf("数据库初始化失败: %w", err)
		}
	}
	return &ConnectionService{db: db}, nil
}

// SaveConnection creates or updates a connection configuration.
//
// Name, Type and Host are required, and Name must not be used by any other
// connection. When conn.ID > 0 the existing row is updated; an empty
// Password keeps the stored password, otherwise the new password is encrypted
// with crypto.EncryptPassword before being written. When conn.ID == 0 a new
// row is inserted; a Password is mandatory and conn.Password is replaced in
// place by its encrypted form before the insert.
func (s *ConnectionService) SaveConnection(conn *models.DbConnection) error {
	if conn == nil {
		return errors.New("连接配置不能为空")
	}
	if conn.Name == "" {
		return errors.New("连接名称不能为空")
	}
	if conn.Type == "" {
		return errors.New("数据库类型不能为空")
	}
	if conn.Host == "" {
		return errors.New("主机地址不能为空")
	}

	// Reject duplicate names, excluding the record currently being edited.
	var count int64
	query := s.db.Model(&models.DbConnection{}).Where("name = ?", conn.Name)
	if conn.ID > 0 {
		query = query.Where("id != ?", conn.ID)
	}
	if err := query.Count(&count).Error; err != nil {
		// Without this check a failed lookup would leave count == 0 and
		// silently allow a duplicate name.
		return fmt.Errorf("检查连接名称失败: %w", err)
	}
	if count > 0 {
		return errors.New("连接名称已存在")
	}

	if conn.ID > 0 {
		// Update mode: write only the editable columns.
		updateData := map[string]interface{}{
			"name":              conn.Name,
			"type":              conn.Type,
			"host":              conn.Host,
			"port":              conn.Port,
			"username":          conn.Username,
			"database":          conn.Database,
			"options":           conn.Options,
			"visible_databases": conn.VisibleDatabases,
		}
		// An empty password means "keep the stored one", so the column is
		// only touched when a new password was supplied.
		if conn.Password != "" {
			encrypted, err := crypto.EncryptPassword(conn.Password)
			if err != nil {
				return fmt.Errorf("密码加密失败: %w", err)
			}
			updateData["password"] = encrypted
		}
		return s.db.Model(&models.DbConnection{}).Where("id = ?", conn.ID).Updates(updateData).Error
	}

	// Create mode: a password is mandatory.
	if conn.Password == "" {
		return errors.New("新增连接时密码不能为空")
	}
	encrypted, err := crypto.EncryptPassword(conn.Password)
	if err != nil {
		return fmt.Errorf("密码加密失败: %w", err)
	}
	conn.Password = encrypted
	return s.db.Create(conn).Error
}

// ListConnections returns all saved connections, newest (by created_at) first.
// On error the returned slice is nil.
func (s *ConnectionService) ListConnections() ([]models.DbConnection, error) {
	var connections []models.DbConnection
	if err := s.db.Order("created_at DESC").Find(&connections).Error; err != nil {
		return nil, err
	}
	return connections, nil
}

// GetConnection returns the connection with the given primary key, or the
// error reported by gorm (e.g. gorm.ErrRecordNotFound).
func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error) {
	var conn models.DbConnection
	if err := s.db.First(&conn, id).Error; err != nil {
		return nil, err
	}
	return &conn, nil
}

// DeleteConnection deletes the connection with the given id together with its
// SqlResultHistory and SqlTab rows in a single transaction. Deleting a
// connection that does not exist is treated as success. Once the transaction
// has committed, the connection's entry in the dbclient pool is closed.
func (s *ConnectionService) DeleteConnection(id uint) error {
	var conn models.DbConnection
	if err := s.db.First(&conn, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil // Nothing to delete.
		}
		return fmt.Errorf("查询连接失败: %w", err)
	}

	err := s.db.Transaction(func(tx *gorm.DB) error {
		// Returning any error here rolls back the whole delete instead of
		// committing a partial cleanup.
		if err := tx.Where("connection_id = ?", id).Delete(&models.SqlResultHistory{}).Error; err != nil {
			return fmt.Errorf("清理SQL结果历史失败: %w", err)
		}
		if err := tx.Where("connection_id = ?", id).Delete(&models.SqlTab{}).Error; err != nil {
			return fmt.Errorf("清理SQL标签页失败: %w", err)
		}
		return tx.Delete(&conn).Error
	})
	if err != nil {
		return err
	}

	// Close pooled clients only after the delete is durable, so a failed
	// commit does not tear down a connection that still exists.
	dbclient.GetPool().CloseConnection(id, conn.Type)
	return nil
}

// TestConnection decrypts the stored password of conn and runs a
// connectivity test appropriate for conn.Type ("mysql", "redis" or "mongo").
func (s *ConnectionService) TestConnection(conn *models.DbConnection) error {
	if conn == nil {
		return errors.New("连接配置不能为空")
	}
	password, err := crypto.DecryptPassword(conn.Password)
	if err != nil {
		return fmt.Errorf("密码解密失败: %w", err)
	}
	return testConnectionByType(conn.Type, conn.Host, conn.Port, conn.Username, password, conn.Database, conn.Options)
}

// testConnectionByType dispatches a connectivity test for dbType ("mysql",
// "redis" or "mongo") to the matching dbclient helper. password must already
// be plain text. options is only consulted for "mongo".
func testConnectionByType(dbType, host string, port int, username, password, database, options string) error {
	switch dbType {
	case "mysql":
		return testMySQLConnection(host, port, username, password, database)
	case "redis":
		return testRedisConnection(host, port, password)
	case "mongo":
		authSource, authMechanism := parseMongoAuthOptions(options)
		return testMongoConnection(host, port, username, password, database, authSource, authMechanism)
	default:
		return fmt.Errorf("不支持的数据库类型: %s", dbType)
	}
}

// parseMongoAuthOptions extracts the optional "authSource" and
// "authMechanism" string values from the JSON object in options. Parsing is
// best-effort: empty input, invalid JSON, missing keys and non-string values
// all yield "" for the corresponding result.
func parseMongoAuthOptions(options string) (authSource, authMechanism string) {
	if options == "" {
		return "", ""
	}
	var opts map[string]interface{}
	if err := json.Unmarshal([]byte(options), &opts); err != nil {
		// Malformed options are deliberately ignored; both values stay empty.
		return "", ""
	}
	authSource, _ = opts["authSource"].(string)
	authMechanism, _ = opts["authMechanism"].(string)
	return authSource, authMechanism
}

// testMySQLConnection tests a MySQL connection via dbclient.
func testMySQLConnection(host string, port int, username, password, database string) error {
	return dbclient.TestMySQLConnection(host, port, username, password, database)
}

// testRedisConnection tests a Redis connection via dbclient.
func testRedisConnection(host string, port int, password string) error {
	return dbclient.TestRedisConnection(host, port, password)
}

// testMongoConnection tests a MongoDB connection via dbclient, passing the
// optional auth source and mechanism through.
func testMongoConnection(host string, port int, username, password, database, authSource, authMechanism string) error {
	return dbclient.TestMongoConnectionWithOptions(host, port, username, password, database, authSource, authMechanism)
}

// resolvePassword returns password unchanged unless it is empty and id > 0;
// in that case (editing a saved connection without retyping its password)
// the stored password of connection id is loaded and decrypted.
func (s *ConnectionService) resolvePassword(id uint, password string) (string, error) {
	if id == 0 || password != "" {
		return password, nil
	}
	conn, err := s.GetConnection(id)
	if err != nil {
		return "", fmt.Errorf("获取连接信息失败: %w", err)
	}
	plain, err := crypto.DecryptPassword(conn.Password)
	if err != nil {
		return "", fmt.Errorf("密码解密失败: %w", err)
	}
	return plain, nil
}

// TestConnectionWithParams tests a connection described by raw parameters
// without saving anything. If id > 0 and password is empty, the saved
// password of connection id is used instead.
func (s *ConnectionService) TestConnectionWithParams(dbType, host string, port int, username, password, database, options string, id uint) error {
	password, err := s.resolvePassword(id, password)
	if err != nil {
		return err
	}
	return testConnectionByType(dbType, host, port, username, password, database, options)
}

// LoadAllDatabases lists the databases visible through the described
// connection. If id > 0 and password is empty, the saved password of
// connection id is used. For "redis" an empty, non-nil slice is returned.
func (s *ConnectionService) LoadAllDatabases(dbType, host string, port int, username, password, database, options string, id uint) ([]string, error) {
	password, err := s.resolvePassword(id, password)
	if err != nil {
		return nil, err
	}

	switch dbType {
	case "mysql":
		return loadMySQLDatabases(host, port, username, password, database)
	case "mongo":
		return loadMongoDatabases(host, port, username, password, database, options)
	case "redis":
		// Database listing is not supported for Redis; return an empty list.
		return []string{}, nil
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", dbType)
	}
}

// loadMySQLDatabases opens a temporary MySQL client, lists its databases and
// closes the client again.
func loadMySQLDatabases(host string, port int, username, password, defaultDatabase string) ([]string, error) {
	config := &dbclient.MySQLConfig{
		Host:     host,
		Port:     port,
		Username: username,
		Password: password,
		Database: defaultDatabase,
	}
	client, err := dbclient.NewMySQLClient(config)
	if err != nil {
		return nil, err
	}
	defer client.Close()
	return client.ListDatabases(context.Background())
}

// loadMongoDatabases opens a temporary MongoDB client (using the auth options
// parsed from options), lists its databases and closes the client again.
func loadMongoDatabases(host string, port int, username, password, defaultDatabase, options string) ([]string, error) {
	authSource, authMechanism := parseMongoAuthOptions(options)
	mongoConfig := &dbclient.MongoConfig{
		Host:          host,
		Port:          port,
		Username:      username,
		Password:      password,
		Database:      defaultDatabase,
		AuthSource:    authSource,
		AuthMechanism: authMechanism,
	}
	client, err := dbclient.NewMongoClient(mongoConfig)
	if err != nil {
		return nil, err
	}
	defer client.Close()
	return client.ListDatabases(context.Background())
}