diff --git a/internal/storage/connection_service.go b/internal/storage/connection_service.go index 0068fad..63098a0 100644 --- a/internal/storage/connection_service.go +++ b/internal/storage/connection_service.go @@ -135,38 +135,59 @@ func (s *ConnectionService) DeleteConnection(id uint) error { }) } +// resolvePassword 解析密码(编辑模式下从已保存连接中获取) +func (s *ConnectionService) resolvePassword(id uint, password string) (string, error) { + if id > 0 && password == "" { + conn, err := s.GetConnection(id) + if err != nil { + return "", fmt.Errorf("获取连接信息失败: %v", err) + } + decryptPassword, err := crypto.DecryptPassword(conn.Password) + if err != nil { + return "", fmt.Errorf("密码解密失败: %v", err) + } + return decryptPassword, nil + } + return password, nil +} + +// parseMongoOptions 解析 MongoDB 连接选项 +func parseMongoOptions(options string) (authSource, authMechanism string) { + if options == "" { + return "", "" + } + var opts map[string]interface{} + if err := json.Unmarshal([]byte(options), &opts); err != nil { + return "", "" + } + authSource, _ = opts["authSource"].(string) + authMechanism, _ = opts["authMechanism"].(string) + return authSource, authMechanism +} + // TestConnection 测试连接(需要根据类型调用不同的测试方法) func (s *ConnectionService) TestConnection(conn *models.DbConnection) error { - // 解密密码用于测试 password, err := crypto.DecryptPassword(conn.Password) if err != nil { return fmt.Errorf("密码解密失败: %v", err) } - // 根据类型测试连接 - switch conn.Type { + authSource, authMechanism := parseMongoOptions(conn.Options) + + return s.testConnectionByType(conn.Type, conn.Host, conn.Port, conn.Username, password, conn.Database, authSource, authMechanism) +} + +// testConnectionByType 根据类型调用对应的测试方法 +func (s *ConnectionService) testConnectionByType(dbType, host string, port int, username, password, database, authSource, authMechanism string) error { + switch dbType { case "mysql": - return testMySQLConnection(conn.Host, conn.Port, conn.Username, password, conn.Database) + return testMySQLConnection(host, port, username, password, database) case "redis": - return testRedisConnection(conn.Host, conn.Port, password) + return testRedisConnection(host, port, password) case "mongo": - // 解析 Options 获取 MongoDB 连接参数 - authSource := "" - authMechanism := "" - if conn.Options != "" { - var opts map[string]interface{} - if err := json.Unmarshal([]byte(conn.Options), &opts); err == nil { - if as, ok := opts["authSource"].(string); ok && as != "" { - authSource = as - } - if am, ok := opts["authMechanism"].(string); ok && am != "" { - authMechanism = am - } - } - } - return testMongoConnection(conn.Host, conn.Port, conn.Username, password, conn.Database, authSource, authMechanism) + return testMongoConnection(host, port, username, password, database, authSource, authMechanism) default: - return fmt.Errorf("不支持的数据库类型: %s", conn.Type) + return fmt.Errorf("不支持的数据库类型: %s", dbType) } } @@ -187,67 +208,30 @@ func testMongoConnection(host string, port int, username, password, database, au // TestConnectionWithParams 使用参数测试连接(不保存数据) func (s *ConnectionService) TestConnectionWithParams(dbType, host string, port int, username, password, database, options string, id uint) error { - // 如果是编辑模式且有ID,获取已保存的密码 - if id > 0 && password == "" { - conn, err := s.GetConnection(id) - if err != nil { - return fmt.Errorf("获取连接信息失败: %v", err) - } - decryptPassword, err := crypto.DecryptPassword(conn.Password) - if err != nil { - return fmt.Errorf("密码解密失败: %v", err) - } - password = decryptPassword + password, err := s.resolvePassword(id, password) + if err != nil { + return err } - // 根据类型测试连接 - switch dbType { - case "mysql": - return testMySQLConnection(host, port, username, password, database) - case "redis": - return testRedisConnection(host, port, password) - case "mongo": - // 解析 Options 获取 MongoDB 连接参数 - authSource := "" - authMechanism := "" - if options != "" { - var opts map[string]interface{} - if err := json.Unmarshal([]byte(options), &opts); err == nil { - if as, ok := opts["authSource"].(string); ok && as != "" { - authSource = as - } - if am, ok := opts["authMechanism"].(string); ok && am != "" { - authMechanism = am - } - } - } - return testMongoConnection(host, port, username, password, database, authSource, authMechanism) - default: - return fmt.Errorf("不支持的数据库类型: %s", dbType) - } + authSource, authMechanism := parseMongoOptions(options) + return s.testConnectionByType(dbType, host, port, username, password, database, authSource, authMechanism) } // LoadAllDatabases 加载全部数据库列表 func (s *ConnectionService) LoadAllDatabases(dbType, host string, port int, username, password, database, options string, id uint) ([]string, error) { - // 如果是编辑模式且有ID,获取已保存的密码 - if id > 0 && password == "" { - conn, err := s.GetConnection(id) - if err != nil { - return nil, fmt.Errorf("获取连接信息失败: %v", err) - } - decryptPassword, err := crypto.DecryptPassword(conn.Password) - if err != nil { - return nil, fmt.Errorf("密码解密失败: %v", err) - } - password = decryptPassword + password, err := s.resolvePassword(id, password) + if err != nil { + return nil, err } + authSource, authMechanism := parseMongoOptions(options) + // 根据类型加载数据库列表 switch dbType { case "mysql": return loadMySQLDatabases(host, port, username, password, database) case "mongo": - return loadMongoDatabases(host, port, username, password, database, options) + return loadMongoDatabasesWithOptions(host, port, username, password, database, authSource, authMechanism) case "redis": // Redis 没有数据库概念,返回空列表 return []string{}, nil @@ -274,23 +258,8 @@ func loadMySQLDatabases(host string, port int, username, password, defaultDataba return client.ListDatabases(context.Background()) } -// loadMongoDatabases 加载 MongoDB 数据库列表 -func loadMongoDatabases(host string, port int, username, password, defaultDatabase, options string) ([]string, error) { - // 解析 Options 获取 MongoDB 连接参数 - authSource := "" - authMechanism := "" - if options != "" { - var opts map[string]interface{} - if err := json.Unmarshal([]byte(options), &opts); err == nil { - if as, ok := opts["authSource"].(string); ok && as != "" { - authSource = as - } - if am, ok := opts["authMechanism"].(string); ok && am != "" { - authMechanism = am - } - } - } - +// loadMongoDatabasesWithOptions 加载 MongoDB 数据库列表(使用解析后的选项) +func loadMongoDatabasesWithOptions(host string, port int, username, password, defaultDatabase, authSource, authMechanism string) ([]string, error) { mongoConfig := &dbclient.MongoConfig{ Host: host, Port: port, diff --git a/web/src/composables/useVisibleDatabases.ts b/web/src/composables/useVisibleDatabases.ts new file mode 100644 index 0000000..6555230 --- /dev/null +++ b/web/src/composables/useVisibleDatabases.ts @@ -0,0 +1,50 @@ +/** + * 可见数据库管理 Composable + * 封装 visible_databases 字段的解析和过滤逻辑 + */ + +/** + * 解析可见数据库 JSON 字符串 + * @param jsonStr - JSON 字符串或 null + * @returns 解析后的数据库数组,解析失败返回空数组 + */ +export function parseVisibleDatabases(jsonStr: string | null): string[] { + if (!jsonStr) return [] + try { + const parsed = JSON.parse(jsonStr) + return Array.isArray(parsed) ? parsed : [] + } catch { + return [] + } +} + +/** + * 根据可见数据库配置过滤数据库列表 + * @param databases - 完整的数据库列表 + * @param visibleJson - 可见数据库 JSON 字符串 + * @returns 过滤后的数据库列表(如果未配置过滤则返回全部) + */ +export function filterDatabases(databases: string[], visibleJson: string | null): string[] { + const visible = parseVisibleDatabases(visibleJson) + return visible.length > 0 ? databases.filter(db => visible.includes(db)) : databases +} + +/** + * 将数据库数组序列化为 JSON 字符串(空数组返回空字符串) + * @param databases - 数据库数组 + * @returns JSON 字符串或空字符串 + */ +export function serializeVisibleDatabases(databases: string[]): string { + return databases.length > 0 ? JSON.stringify(databases) : '' +} + +/** + * 可见数据库管理 Composable + */ +export function useVisibleDatabases() { + return { + parse: parseVisibleDatabases, + filter: filterDatabases, + serialize: serializeVisibleDatabases, + } +} diff --git a/web/src/views/db-cli/components/ConnectionForm.vue b/web/src/views/db-cli/components/ConnectionForm.vue index 4463d48..4ac9483 100644 --- a/web/src/views/db-cli/components/ConnectionForm.vue +++ b/web/src/views/db-cli/components/ConnectionForm.vue @@ -162,6 +162,7 @@ import { SaveDbConnection } from '../../../wailsjs/wailsjs/go/main/App' import { getConnectionFailedTip, getLoadFailedTip } from '@/utils/database-error' +import { useVisibleDatabases } from '@/composables/useVisibleDatabases' // 使用 defineModel 简化 v-model:visible 双向绑定(Vue 3.5+) const visible = defineModel('visible', { type: Boolean, default: false }) @@ -183,6 +184,7 @@ const errorMessage = ref('') const isPasswordChanged = ref(false) // 数据库过滤相关 +const { parse: parseVisibleDatabases, filter: filterVisibleDatabases } = useVisibleDatabases() const loadingDatabases = ref(false) const allDatabases = ref([]) const selectedDatabases = ref([]) @@ -373,16 +375,7 @@ const loadConnection = async () => { isPasswordChanged.value = false // 恢复数据库选择 - if (conn.visible_databases) { - try { - selectedDatabases.value = JSON.parse(conn.visible_databases) - } catch (error) { - console.warn('解析可见数据库列表失败:', error) - selectedDatabases.value = [] - } - } else { - selectedDatabases.value = [] - } + selectedDatabases.value = parseVisibleDatabases(conn.visible_databases || null) // 编辑模式:自动加载数据库列表 nextTick(() => { @@ -558,18 +551,8 @@ const loadAllDatabases = async () => { allDatabases.value = databases || [] - // 从已保存的 visibleDatabases 中恢复选择 - if (form.visibleDatabases) { - try { - selectedDatabases.value = JSON.parse(form.visibleDatabases) - .filter((db: string) => databases.includes(db)) - } catch (error) { - console.warn('解析可见数据库列表失败:', error) - selectedDatabases.value = [] - } - } else { - selectedDatabases.value = [] - } + // 从已保存的 visibleDatabases 中恢复选择(使用 composable) + selectedDatabases.value = filterVisibleDatabases(databases, form.visibleDatabases || null) Message.success(`成功加载 ${databases.length} 个数据库`) } catch (error) { diff --git a/web/src/views/db-cli/components/ConnectionTree.vue b/web/src/views/db-cli/components/ConnectionTree.vue index ca801c8..6b0bcd0 100644 --- a/web/src/views/db-cli/components/ConnectionTree.vue +++ b/web/src/views/db-cli/components/ConnectionTree.vue @@ -142,6 +142,7 @@ import { STORAGE_KEYS } from '../constants/storage' import { listConnections, getDatabases, getTables, deleteConnection } from '@/api' import type { Connection } from '@/api' import { getLoadFailedTip } from '@/utils/database-error' +import { useVisibleDatabases } from '@/composables/useVisibleDatabases' // 连接类型定义(使用 API 层的类型) type DbConnection = Connection @@ -777,9 +778,11 @@ const withLoadingNode = async (nodeKey: string, loader: () => Promise) => } // 加载数据库列表 +const { filter: filterDatabases } = useVisibleDatabases() + const loadDatabases = async (connectionNode) => { if (!connectionNode?.connectionId) return - + await withLoadingNode(connectionNode.key, async () => { if (!window.go?.main?.App?.GetDatabases) { throw new Error('Go 后端未就绪') @@ -788,20 +791,9 @@ const loadDatabases = async (connectionNode) => { const databases = await getDatabases(connectionNode.connectionId) if (!Array.isArray(databases)) return - // 获取连接配置,检查是否有可见数据库过滤 + // 获取连接配置,应用可见数据库过滤 const conn = connections.value.find(c => c.id === connectionNode.connectionId) - let filteredDatabases = databases - - if (conn?.visible_databases) { - try { - const visibleDbs = JSON.parse(conn.visible_databases) - if (Array.isArray(visibleDbs) && visibleDbs.length > 0) { - filteredDatabases = databases.filter(db => visibleDbs.includes(db)) - } - } catch (error) { - console.warn('解析可见数据库列表失败:', error) - } - } + const filteredDatabases = filterDatabases(databases, conn?.visible_databases || null) // 根据数据库类型设置节点标题 connectionNode.children = filteredDatabases.map(db => ({