Private
Public Access
1
0
Files
u-desk/internal/storage/connection_service.go
绝尘 4a1f0213df 重构:消除代码重复,提升可维护性
后端优化:
- 新增 resolvePassword 函数,消除密码获取重复逻辑
- 新增 parseMongoOptions 函数,消除 Options 解析重复
- 新增 testConnectionByType 统一连接测试调用
- 重构 loadMongoDatabasesWithOptions 接收解析后参数
- 删除重复代码 37 行

前端优化:
- 新增 useVisibleDatabases composable
- 统一 visible_databases 解析和过滤逻辑
- 简化错误处理,移除 try-catch 包装
- 删除重复代码 22 行

代码质量:
- 消除 6 处重复代码块
- 新增 5 个可复用函数
- 提升代码可维护性和可测试性
2026-03-31 11:49:25 +08:00

280 lines
8.3 KiB
Go

package storage
import (
"context"
"encoding/json"
"fmt"
"u-desk/internal/crypto"
"u-desk/internal/dbclient"
"u-desk/internal/storage/models"
"gorm.io/gorm"
)
// ConnectionService manages saved database connection configurations
// (save, list, fetch, delete) and can test a connection or list the
// databases reachable through it. All persistence goes through db.
type ConnectionService struct {
	db *gorm.DB // GORM handle used for every storage operation of the service
}
// NewConnectionService builds a ConnectionService on top of the shared
// database handle returned by GetDB. If no handle is available yet, Init is
// called to (re)initialize storage; an Init failure is returned wrapped so
// callers can still inspect the underlying cause with errors.Is / errors.As.
func NewConnectionService() (*ConnectionService, error) {
	db := GetDB()
	if db == nil {
		// The shared handle has not been set up yet: try to initialize it now.
		var err error
		db, err = Init()
		if err != nil {
			return nil, fmt.Errorf("数据库初始化失败: %w", err)
		}
	}
	return &ConnectionService{db: db}, nil
}
// SaveConnection creates or updates a saved connection configuration.
//
// Name, Type and Host are required. Name must be unique among saved
// connections; when editing, the record's own row is excluded from that check.
//
// When conn.ID > 0 the existing row is updated with the editable fields. An
// empty conn.Password keeps the stored password, and a non-empty one is
// encrypted and replaces it. When conn.ID == 0 a new row is inserted and a
// password is mandatory. In that case conn.Password is overwritten in place
// with its encrypted form before the insert.
func (s *ConnectionService) SaveConnection(conn *models.DbConnection) error {
	if conn == nil {
		return fmt.Errorf("连接配置不能为空")
	}
	if conn.Name == "" {
		return fmt.Errorf("连接名称不能为空")
	}
	if conn.Type == "" {
		return fmt.Errorf("数据库类型不能为空")
	}
	if conn.Host == "" {
		return fmt.Errorf("主机地址不能为空")
	}

	// Reject duplicate names, ignoring the row being edited.
	var count int64
	query := s.db.Model(&models.DbConnection{}).Where("name = ?", conn.Name)
	if conn.ID > 0 {
		query = query.Where("id != ?", conn.ID)
	}
	// A failed count must not be mistaken for "no duplicates".
	if err := query.Count(&count).Error; err != nil {
		return fmt.Errorf("检查连接名称失败: %w", err)
	}
	if count > 0 {
		return fmt.Errorf("连接名称已存在")
	}

	if conn.ID > 0 {
		// Update mode: a map is used so that zero values (e.g. cleared
		// options) are written too.
		updateData := map[string]interface{}{
			"name":              conn.Name,
			"type":              conn.Type,
			"host":              conn.Host,
			"port":              conn.Port,
			"username":          conn.Username,
			"database":          conn.Database,
			"options":           conn.Options,
			"visible_databases": conn.VisibleDatabases,
		}
		// Only replace the stored password when a new one was supplied.
		if conn.Password != "" {
			encrypted, err := crypto.EncryptPassword(conn.Password)
			if err != nil {
				return fmt.Errorf("密码加密失败: %w", err)
			}
			updateData["password"] = encrypted
		}
		return s.db.Model(&models.DbConnection{}).Where("id = ?", conn.ID).Updates(updateData).Error
	}

	// Create mode: a password is required for a brand-new connection.
	if conn.Password == "" {
		return fmt.Errorf("新增连接时密码不能为空")
	}
	encrypted, err := crypto.EncryptPassword(conn.Password)
	if err != nil {
		return fmt.Errorf("密码加密失败: %w", err)
	}
	conn.Password = encrypted
	return s.db.Create(conn).Error
}
// ListConnections returns all saved connections ordered by created_at,
// newest first. Whatever rows were loaded are returned together with the
// query error, if any.
func (s *ConnectionService) ListConnections() ([]models.DbConnection, error) {
	var list []models.DbConnection
	result := s.db.Order("created_at DESC").Find(&list)
	return list, result.Error
}
// GetConnection loads the saved connection with the given primary key.
// On any lookup failure, including a missing record, it returns nil and the
// GORM error unchanged.
func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error) {
	found := new(models.DbConnection)
	if err := s.db.First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// DeleteConnection removes a saved connection together with its SQL result
// history and SQL tabs (rows whose connection_id matches id). It then closes
// any pooled client for that connection.
//
// Deleting an id that does not exist is treated as success. Other lookup
// errors, and any failure while deleting related rows or the connection
// itself, are returned; the transaction is rolled back in that case and the
// pool is left untouched.
func (s *ConnectionService) DeleteConnection(id uint) error {
	var conn models.DbConnection
	// Find + RowsAffected distinguishes "not found" from real DB errors
	// without relying on gorm.ErrRecordNotFound.
	result := s.db.Where("id = ?", id).Limit(1).Find(&conn)
	if result.Error != nil {
		return fmt.Errorf("获取连接信息失败: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		// Nothing to delete: treat as already deleted.
		return nil
	}

	err := s.db.Transaction(func(tx *gorm.DB) error {
		// Remove dependent rows first so no orphans remain.
		if err := tx.Where("connection_id = ?", id).Delete(&models.SqlResultHistory{}).Error; err != nil {
			return fmt.Errorf("清理查询历史失败: %w", err)
		}
		if err := tx.Where("connection_id = ?", id).Delete(&models.SqlTab{}).Error; err != nil {
			return fmt.Errorf("清理SQL标签页失败: %w", err)
		}
		return tx.Delete(&conn).Error
	})
	if err != nil {
		return err
	}

	// Close pooled clients only after the delete has committed, so a
	// rolled-back delete does not leave a saved connection without its pool.
	dbclient.GetPool().CloseConnection(id, conn.Type)
	return nil
}
// resolvePassword returns the plaintext password to use for a connection
// attempt.
//
// When editing an existing connection (id > 0) and the caller left password
// empty, the stored (encrypted) password is loaded via GetConnection and
// decrypted. Otherwise password is returned unchanged. Lookup and decryption
// failures are wrapped with %w so the cause stays inspectable.
func (s *ConnectionService) resolvePassword(id uint, password string) (string, error) {
	if id > 0 && password == "" {
		conn, err := s.GetConnection(id)
		if err != nil {
			return "", fmt.Errorf("获取连接信息失败: %w", err)
		}
		decryptPassword, err := crypto.DecryptPassword(conn.Password)
		if err != nil {
			return "", fmt.Errorf("密码解密失败: %w", err)
		}
		return decryptPassword, nil
	}
	return password, nil
}
// parseMongoOptions extracts the MongoDB "authSource" and "authMechanism"
// settings from a JSON object string.
//
// Parsing is best-effort. An empty string or invalid JSON yields two empty
// strings. A key that is missing or not a JSON string yields "" for that
// value only.
func parseMongoOptions(options string) (authSource, authMechanism string) {
	if len(options) == 0 {
		return "", ""
	}

	var fields map[string]interface{}
	if decodeErr := json.Unmarshal([]byte(options), &fields); decodeErr != nil {
		// Malformed options are deliberately ignored rather than reported.
		return "", ""
	}

	source, _ := fields["authSource"].(string)
	mechanism, _ := fields["authMechanism"].(string)
	return source, mechanism
}
// TestConnection checks that a saved connection is reachable.
//
// conn.Password is decrypted before use, so it must hold the encrypted value
// (as written by SaveConnection). MongoDB auth options are parsed from
// conn.Options and handed to testConnectionByType, whose "mongo" branch is
// the only one that uses them. A decryption failure is wrapped with %w.
func (s *ConnectionService) TestConnection(conn *models.DbConnection) error {
	password, err := crypto.DecryptPassword(conn.Password)
	if err != nil {
		return fmt.Errorf("密码解密失败: %w", err)
	}
	authSource, authMechanism := parseMongoOptions(conn.Options)
	return s.testConnectionByType(conn.Type, conn.Host, conn.Port, conn.Username, password, conn.Database, authSource, authMechanism)
}
// testConnectionByType dispatches a connectivity test to the checker for
// dbType ("mysql", "redis" or "mongo"). Redis is tested with host, port and
// password only; authSource and authMechanism are used only for MongoDB.
// Any other dbType yields an "unsupported type" error.
func (s *ConnectionService) testConnectionByType(dbType, host string, port int, username, password, database, authSource, authMechanism string) error {
	switch dbType {
	case "mongo":
		return testMongoConnection(host, port, username, password, database, authSource, authMechanism)
	case "redis":
		return testRedisConnection(host, port, password)
	case "mysql":
		return testMySQLConnection(host, port, username, password, database)
	}
	return fmt.Errorf("不支持的数据库类型: %s", dbType)
}
// testMySQLConnection checks MySQL connectivity with the given address and
// credentials. It delegates to dbclient.TestMySQLConnection and returns that
// function's error unchanged.
func testMySQLConnection(host string, port int, username, password, database string) error {
	result := dbclient.TestMySQLConnection(host, port, username, password, database)
	return result
}
// testRedisConnection checks Redis connectivity with the given address and
// password (no username is passed). It delegates to
// dbclient.TestRedisConnection and returns that function's error unchanged.
func testRedisConnection(host string, port int, password string) error {
	result := dbclient.TestRedisConnection(host, port, password)
	return result
}
// testMongoConnection checks MongoDB connectivity with the given address,
// credentials and optional auth settings. It delegates to
// dbclient.TestMongoConnectionWithOptions and returns that function's error
// unchanged.
func testMongoConnection(host string, port int, username, password, database, authSource, authMechanism string) error {
	result := dbclient.TestMongoConnectionWithOptions(host, port, username, password, database, authSource, authMechanism)
	return result
}
// TestConnectionWithParams tests a connection described by raw parameters
// without saving anything.
//
// When id refers to an existing connection and password is empty, the stored
// password is used instead (see resolvePassword). options is a JSON string
// from which MongoDB auth settings are read.
func (s *ConnectionService) TestConnectionWithParams(dbType, host string, port int, username, password, database, options string, id uint) error {
	effectivePassword, err := s.resolvePassword(id, password)
	if err != nil {
		return err
	}
	source, mechanism := parseMongoOptions(options)
	return s.testConnectionByType(dbType, host, port, username, effectivePassword, database, source, mechanism)
}
// LoadAllDatabases lists every database visible through the described
// connection.
//
// The password is resolved first; an empty password with id > 0 falls back
// to the stored one. MySQL and MongoDB are queried live. Redis always yields
// an empty, non-nil list. Any other dbType is rejected with an error.
func (s *ConnectionService) LoadAllDatabases(dbType, host string, port int, username, password, database, options string, id uint) ([]string, error) {
	plainPassword, err := s.resolvePassword(id, password)
	if err != nil {
		return nil, err
	}

	switch dbType {
	case "redis":
		// Redis has no database catalogue to list.
		return []string{}, nil
	case "mongo":
		source, mechanism := parseMongoOptions(options)
		return loadMongoDatabasesWithOptions(host, port, username, plainPassword, database, source, mechanism)
	case "mysql":
		return loadMySQLDatabases(host, port, username, plainPassword, database)
	}
	return nil, fmt.Errorf("不支持的数据库类型: %s", dbType)
}
// loadMySQLDatabases opens a temporary MySQL client with the given settings,
// lists its databases and closes the client before returning. Client
// construction errors are returned as-is.
func loadMySQLDatabases(host string, port int, username, password, defaultDatabase string) ([]string, error) {
	client, err := dbclient.NewMySQLClient(&dbclient.MySQLConfig{
		Host:     host,
		Port:     port,
		Username: username,
		Password: password,
		Database: defaultDatabase,
	})
	if err != nil {
		return nil, err
	}
	defer client.Close()

	ctx := context.Background()
	return client.ListDatabases(ctx)
}
// loadMongoDatabasesWithOptions opens a temporary MongoDB client using the
// given settings and the already-parsed authSource / authMechanism. It lists
// the server's databases and closes the client before returning. Client
// construction errors are returned as-is.
func loadMongoDatabasesWithOptions(host string, port int, username, password, defaultDatabase, authSource, authMechanism string) ([]string, error) {
	client, err := dbclient.NewMongoClient(&dbclient.MongoConfig{
		Host:          host,
		Port:          port,
		Username:      username,
		Password:      password,
		Database:      defaultDatabase,
		AuthSource:    authSource,
		AuthMechanism: authMechanism,
	})
	if err != nil {
		return nil, err
	}
	defer client.Close()

	ctx := context.Background()
	return client.ListDatabases(ctx)
}