Private
Public Access
1
0

重构: 死代码清理 + 拷贝优化 + 滚动条修复

This commit is contained in:
2026-04-11 23:36:08 +08:00
parent 7dbd57a8b6
commit 756028af0f
39 changed files with 185 additions and 1308 deletions

View File

@@ -12,6 +12,3 @@ const (
// DefaultVisibleTabs 默认可见的 Tabs
var DefaultVisibleTabs = []string{TabDatabase, TabFileSystem, TabDevice}
// DefaultTab 默认打开的 Tab
const DefaultTab = TabDatabase

View File

@@ -61,12 +61,3 @@ func IsWindows() bool {
return runtime.GOOS == "windows"
}
// IsMac 判断是否为Mac系统
func IsMac() bool {
return runtime.GOOS == "darwin"
}
// IsLinux 判断是否为Linux系统
func IsLinux() bool {
return runtime.GOOS == "linux"
}

View File

@@ -2,7 +2,6 @@ package filesystem
import (
"context"
"encoding/base64"
"fmt"
"log"
"net/http"
@@ -303,52 +302,6 @@ func getContentType(ext string) string {
return defaultFileTypeManager.GetMIMEType(ext)
}
// ReadFileAsBase64 读取文件并返回 base64 编码的字符串
// 用于读取从 ZIP 提取的临时图片文件
func ReadFileAsBase64(filePath string) (string, error) {
log.Printf("[ReadFileAsBase64] 读取文件: %s", filePath)
if !isSafePath(filePath) {
return "", fmt.Errorf("路径不安全")
}
// 检查文件是否存在
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
return "", fmt.Errorf("文件不存在: %s", filePath)
}
return "", fmt.Errorf("无法访问文件: %v", err)
}
log.Printf("[ReadFileAsBase64] 文件大小: %d bytes", fileInfo.Size())
// 读取文件
data, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("读取文件失败: %v", err)
}
// 编码为 base64
encoded := base64.StdEncoding.EncodeToString(data)
log.Printf("[ReadFileAsBase64] 编码成功: 原始=%d, base64=%d", len(data), len(encoded))
// 获取文件扩展名并确定 MIME 类型
ext := strings.ToLower(filepath.Ext(filePath))
mimeType := getContentType(ext)
// 返回 data URI 格式: data:image/png;base64,iVBORw0KG...
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded), nil
}
// HandleLocalFile 处理 /localfs/ 路由的 HTTP 请求
// 前端可以请求 http://localhost:18765/localfs/C:/path/to/image.jpg
// 注意:此函数与 ServeHTTP 功能重复,建议统一使用 ServeHTTP
func HandleLocalFile(w http.ResponseWriter, r *http.Request) {
handler := NewLocalFileHandler()
handler.ServeHTTP(w, r)
}
// isAllowedFileType 检查文件类型是否在白名单中
func isAllowedFileType(ext string) bool {
return defaultFileTypeManager.IsAllowed(ext)

View File

@@ -220,37 +220,6 @@ func (a *AuditLogger) Close() error {
return a.logFile.Close()
}
// RotateLog 日志轮转(每天创建新文件)
func (a *AuditLogger) RotateLog() error {
a.mu.Lock()
defer a.mu.Unlock()
// 刷新缓冲区
if err := a.flush(); err != nil {
return err
}
// 关闭当前文件
if err := a.logFile.Close(); err != nil {
return err
}
// 生成新的日志文件名
timestamp := time.Now().Format("2006-01-02")
logPath := filepath.Join(filepath.Dir(a.logPath), fmt.Sprintf("audit_%s.log", timestamp))
// 打开新文件
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return err
}
a.logFile = logFile
a.logPath = logPath
return nil
}
// GetRecentLogs 获取最近的审计日志
func GetRecentLogs(logDir string, limit int) ([]AuditLogEntry, error) {
// 读取今天的日志文件
@@ -309,22 +278,8 @@ func parseLines(text string) []string {
var globalAuditLogger *AuditLogger
var auditLoggerOnce sync.Once
// InitAuditLogger 初始化全局审计日志记录器
func InitAuditLogger(logDir string) error {
var err error
globalAuditLogger, err = NewAuditLogger(logDir)
return err
}
// GetAuditLogger 获取全局审计日志记录器
func GetAuditLogger() *AuditLogger {
return globalAuditLogger
}
// CloseAuditLogger 关闭全局审计日志记录器
func CloseAuditLogger() error {
if globalAuditLogger != nil {
return globalAuditLogger.Close()
}
return nil
}

View File

@@ -13,22 +13,13 @@ const (
// HTTP 文件服务大小限制
MaxHTTPFileSize = 500 * 1024 * 1024 // 500MB - HTTP 访问文件最大大小
// 删除操作限制
MaxDeleteSizeGB = 1 * 1024 * 1024 * 1024 // 1GB - 单个文件删除大小限制
MaxDeleteDirSizeGB = 1 * 1024 * 1024 * 1024 // 1GB - 目录删除大小限制
)
// 时间相关常量
const (
// 审计日志
AuditFlushInterval = 5 * time.Second // 审计日志刷新间隔
AuditLogBufferSize = 100 // 审计日志缓冲区大小
// 回收站
RecycleBinRetentionDays = 30 // 回收站文件保留天数(天)
RecycleBinRetentionPeriod = 30 * 24 * time.Hour // 回收站文件保留期
// 临时文件
TempFileCleanupAge = 24 * time.Hour // 临时文件清理周期
TempFileDir = "u-desk-zip" // 临时文件目录名
@@ -36,7 +27,6 @@ const (
// 数量限制常量
const (
MaxDirectoryDepth = 15 // 最大目录深度
MaxFileCount = 1000 // 最大文件数量(目录)
)
@@ -48,15 +38,9 @@ const (
// 随机字符串相关常量
const (
RandomStringCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
RandomStringDefaultLength = 6 // 回收站文件名随机后缀长度
)
// 文件路径相关常量
const (
WindowsDriveLength = 2 // Windows 盘符长度 (C:)
)
// 路径遍历检测字符串
const (
PathTraversalPattern = ".." // 路径遍历特征字符串
@@ -69,17 +53,5 @@ const (
FileTypeAudio = "audio"
FileTypeDocument = "document"
FileTypeText = "text"
FileTypeArchive = "archive"
FileTypeApplication = "application"
)
// 安全相关常量
const (
// ZIP 安全
MinValidZipSize = 22 // ZIP 文件最小有效大小(文件头)
ZipFileHeaderSignature = 0x504B // "PK" - ZIP 文件头签名
// 文件锁
LockCheckMaxRetries = 3 // 文件锁检查最大重试次数
LockCheckRetryInterval = 100 * time.Millisecond // 文件锁检查重试间隔
)

View File

@@ -6,130 +6,6 @@ import (
"runtime"
)
// ErrorCode 错误码类型
type ErrorCode string
const (
// 通用错误
ErrCodeGeneral ErrorCode = "GENERAL_ERROR"
ErrCodeInvalid ErrorCode = "INVALID_ARGUMENT"
ErrCodeNotFound ErrorCode = "NOT_FOUND"
ErrCodePermission ErrorCode = "PERMISSION_DENIED"
ErrCodeIO ErrorCode = "IO_ERROR"
// 路径相关错误
ErrCodePathTraversal ErrorCode = "PATH_TRAVERSAL"
ErrCodeInvalidPath ErrorCode = "INVALID_PATH"
ErrCodeSensitivePath ErrorCode = "SENSITIVE_PATH"
// 文件操作错误
ErrCodeFileNotFound ErrorCode = "FILE_NOT_FOUND"
ErrCodeFileExists ErrorCode = "FILE_EXISTS"
ErrCodeDirectoryNotEmpty ErrorCode = "DIRECTORY_NOT_EMPTY"
// 安全相关错误
ErrCodeSecurityViolation ErrorCode = "SECURITY_VIOLATION"
ErrCodeSizeLimit ErrorCode = "SIZE_LIMIT_EXCEEDED"
ErrCodeFileLocked ErrorCode = "FILE_LOCKED"
// ZIP相关错误
ErrCodeZipInvalid ErrorCode = "ZIP_INVALID"
ErrCodeZipBomb ErrorCode = "ZIP_BOMB"
ErrCodeZipExtract ErrorCode = "ZIP_EXTRACT_FAILED"
)
// FileError 文件系统专用错误类型
// 包含详细的错误上下文信息,便于调试和用户提示
type FileNotFoundError struct {
Path string
Err error
}
func (e *FileNotFoundError) Error() string {
return fmt.Sprintf("文件不存在: %s", e.Path)
}
func (e *FileNotFoundError) Unwrap() error {
return e.Err
}
// PathValidationError 路径验证错误
type PathValidationError struct {
Path string
Reason string
IsSensitive bool
}
func (e *PathValidationError) Error() string {
return fmt.Sprintf("路径验证失败: %s - %s", e.Path, e.Reason)
}
// SecurityViolationError 安全违规错误
type SecurityViolationError struct {
Path string
Violation string
Suggestion string
}
func (e *SecurityViolationError) Error() string {
msg := fmt.Sprintf("安全违规: %s - %s", e.Path, e.Violation)
if e.Suggestion != "" {
msg += fmt.Sprintf("\n建议: %s", e.Suggestion)
}
return msg
}
// SizeLimitError 大小限制错误
type SizeLimitError struct {
Path string
ActualSize int64
MaxSize int64
SizeType string // "file" or "directory"
}
func (e *SizeLimitError) Error() string {
return fmt.Sprintf("%s大小超限: %s (实际: %.2f GB, 限制: %.2f GB)",
e.SizeType, e.Path,
float64(e.ActualSize)/(1024*1024*1024),
float64(e.MaxSize)/(1024*1024*1024),
)
}
// FileLockedError 文件锁定错误
type FileLockedError struct {
Path string
ProcessInfo string
}
func (e *FileLockedError) Error() string {
msg := fmt.Sprintf("文件被占用: %s", e.Path)
if e.ProcessInfo != "" {
msg += fmt.Sprintf("\n占用程序: %s", e.ProcessInfo)
}
return msg
}
// WrapError 错误包装函数
// 添加上下文信息到错误中
func WrapError(operation string, path string, err error) error {
return fmt.Errorf("%s 失败: %s - %w", operation, path, err)
}
// WrapErrorf 格式化错误包装
func WrapErrorf(format string, args ...interface{}) error {
return fmt.Errorf(format, args...)
}
// GetStackTrace 获取堆栈跟踪(用于调试)
func GetStackTrace(skip int) string {
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
if n > 0 {
return string(buf[:n])
}
return ""
}
// DeleteRestrictionWarning 删除限制警告
// 用于在删除受限文件时提供详细的警告信息
type DeleteRestrictionWarning struct {
@@ -141,3 +17,13 @@ type DeleteRestrictionWarning struct {
func (w *DeleteRestrictionWarning) Error() string {
return fmt.Sprintf("删除限制警告: %s\n%s", w.Path, w.Details)
}
// GetStackTrace 获取堆栈跟踪(用于调试)
func GetStackTrace(skip int) string {
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
if n > 0 {
return string(buf[:n])
}
return ""
}

View File

@@ -4,14 +4,6 @@ import (
"fmt"
"os"
"syscall"
"time"
)
// Windows API 锁相关函数和常量
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetLastError = modkernel32.NewProc("GetLastError")
procGetProcessId = modkernel32.NewProc("GetProcessId")
)
// FileLockChecker 文件锁检查器
@@ -102,37 +94,6 @@ func (c *FileLockChecker) getProcessInfo(path string) (string, error) {
return "文件正被其他程序使用", nil
}
// CheckFileWithRetry 带重试的文件锁检查
func (c *FileLockChecker) CheckFileWithRetry(path string, maxRetries int, retryInterval time.Duration) error {
for i := 0; i < maxRetries; i++ {
locked, processInfo, err := c.IsFileLocked(path)
if err != nil && !locked {
// 非锁相关的错误,直接返回
return err
}
if !locked {
// 文件未被锁定,可以操作
return nil
}
// 文件被锁定
if i < maxRetries-1 {
// 还有重试机会,等待后重试
time.Sleep(retryInterval)
continue
}
// 最后一次重试失败,返回错误
if processInfo != "" {
return fmt.Errorf("文件被占用: %s", processInfo)
}
return fmt.Errorf("文件被其他程序占用,请关闭相关程序后重试")
}
return fmt.Errorf("文件检查超时")
}
// SafeDeleteWithLockCheck 带锁检查的安全删除
func (c *FileLockChecker) SafeDeleteWithLockCheck(path string) error {
// 检查文件是否被锁定
@@ -158,20 +119,6 @@ const (
ERROR_SHARING_VIOLATION = 32 // syscall.Errno(32)
)
// BY_HANDLE_FILE_INFORMATION 文件信息结构体
type BY_HANDLE_FILE_INFORMATION struct {
FileAttributes uint32
CreationTime syscall.Filetime
LastAccessTime syscall.Filetime
LastWriteTime syscall.Filetime
VolumeSerialNumber uint32
FileSizeHigh uint32
FileSizeLow uint32
NumberOfLinks uint32
FileIndexHigh uint32
FileIndexLow uint32
}
// contains 检查字符串是否包含子串(不区分大小写)
func contains(str, substr string) bool {
return len(str) >= len(substr) && (str == substr || len(substr) == 0 ||
@@ -203,18 +150,3 @@ func containsIgnoreCase(str, substr string) bool {
return false
}
// 全局文件锁检查器
var globalLockChecker *FileLockChecker
// InitFileLockChecker 初始化全局文件锁检查器
func InitFileLockChecker() {
globalLockChecker = NewFileLockChecker()
}
// GetFileLockChecker 获取全局文件锁检查器
func GetFileLockChecker() *FileLockChecker {
if globalLockChecker == nil {
globalLockChecker = NewFileLockChecker()
}
return globalLockChecker
}

View File

@@ -70,11 +70,6 @@ func (l *Logger) Info(format string, args ...interface{}) {
l.log(LogLevelInfo, "INFO", format, args...)
}
// Warn 记录警告日志
func (l *Logger) Warn(format string, args ...interface{}) {
l.log(LogLevelWarn, "WARN", format, args...)
}
// Error 记录错误日志
func (l *Logger) Error(format string, args ...interface{}) {
l.log(LogLevelError, "ERROR", format, args...)
@@ -141,37 +136,10 @@ func LogError(operation string, path string, err error) {
var (
globalLogger *Logger
loggerOnce sync.Once
)
// InitLogger 初始化全局日志记录器
func InitLogger(logDir string, minLevel LogLevel) error {
var initErr error
loggerOnce.Do(func() {
timestamp := time.Now().Format("2006-01-02")
logPath := filepath.Join(logDir, fmt.Sprintf("filesystem_%s.log", timestamp))
logger, err := NewLogger(logPath, minLevel)
if err != nil {
initErr = err
return
}
globalLogger = logger
log.Printf("[日志系统] 已启动,日志文件: %s", logPath)
})
return initErr
}
// GetGlobalLogger 获取全局日志记录器
func GetGlobalLogger() *Logger {
return globalLogger
}
// CloseLogger 关闭全局日志记录器
func CloseLogger() error {
if globalLogger != nil {
return globalLogger.Close()
}
return nil
}

View File

@@ -376,16 +376,6 @@ func generateRandomString(length int) string {
// 全局回收站实例
var globalRecycleBin *RecycleBin
// InitRecycleBin 初始化全局回收站
func InitRecycleBin(binPath string) error {
bin, err := NewRecycleBin(binPath)
if err != nil {
return err
}
globalRecycleBin = bin
return nil
}
// GetRecycleBin 获取全局回收站实例
func GetRecycleBin() *RecycleBin {
return globalRecycleBin

View File

@@ -119,11 +119,6 @@ func (s *FileSystemService) initRecycleBin() error {
// ========== 核心文件操作 ==========
// Read 读取文件内容(实现 FileService 接口)
func (s *FileSystemService) Read(path string) (string, error) {
return s.ReadFile(path)
}
// ReadFile 读取文件内容(限制最大 10MB
func (s *FileSystemService) ReadFile(path string) (string, error) {
// 路径验证
@@ -151,10 +146,6 @@ func (s *FileSystemService) ReadFile(path string) (string, error) {
}
// Write 写入文件内容(实现 FileService 接口)
func (s *FileSystemService) Write(path, content string) error {
return s.WriteFile(path, content)
}
// writeFile 内部写入实现(路径验证+大小检查+写入+日志)
func (s *FileSystemService) writeFileWithLog(path string, data []byte) error {
if err := s.validatePath(path); err != nil {
@@ -192,31 +183,6 @@ func (s *FileSystemService) SaveBase64File(path, base64Content string) error {
return s.writeFileWithLog(path, data)
}
// List 列出目录内容(实现 FileService 接口)
func (s *FileSystemService) List(path string) ([]map[string]interface{}, error) {
return s.ListDir(path)
}
// Open 打开文件(实现 FileService 接口)
func (s *FileSystemService) Open(path string) error {
// 使用系统默认程序打开文件
var cmd *exec.Cmd
switch runtime.GOOS {
case "windows":
cmd = exec.Command("cmd", "/c", "start", "", path)
case "darwin":
cmd = exec.Command("open", path)
default:
cmd = exec.Command("xdg-open", path)
}
return cmd.Start()
}
// Delete 删除文件或目录(实现 FileService 接口)
func (s *FileSystemService) Delete(path string) (*FileOperationResult, error) {
return s.DeletePathWithContext(context.Background(), path)
}
// DeletePath 删除文件或目录
func (s *FileSystemService) DeletePath(path string) (*FileOperationResult, error) {
return s.DeletePathWithContext(context.Background(), path)
@@ -430,11 +396,6 @@ func (s *FileSystemService) CreateFile(path string) (*FileOperationResult, error
}, nil
}
// GetInfo 获取文件信息(实现 FileService 接口)
func (s *FileSystemService) GetInfo(path string) (map[string]interface{}, error) {
return s.GetFileInfo(path)
}
// GetFileInfo 获取文件信息
func (s *FileSystemService) GetFileInfo(path string) (map[string]interface{}, error) {
if err := s.validatePath(path); err != nil {
@@ -519,31 +480,16 @@ func (s *FileSystemService) RenamePath(oldPath, newPath string) (*FileOperationR
// ========== ZIP操作接口 ==========
// ListZip 列出ZIP文件内容
func (s *FileSystemService) ListZip(zipPath string) ([]map[string]interface{}, error) {
return ListZipContents(zipPath)
}
// ListZipContents 列出ZIP文件内容别名保持向后兼容
func (s *FileSystemService) ListZipContents(zipPath string) ([]map[string]interface{}, error) {
return ListZipContents(zipPath)
}
// ExtractZipFile 从ZIP提取文件内容
func (s *FileSystemService) ExtractZipFile(zipPath, filePath string) (string, error) {
return ExtractFileFromZip(zipPath, filePath)
}
// ExtractFileFromZip 从ZIP提取文件内容别名保持向后兼容
func (s *FileSystemService) ExtractFileFromZip(zipPath, filePath string) (string, error) {
return ExtractFileFromZip(zipPath, filePath)
}
// ExtractZipFileToTemp 从ZIP提取文件到临时目录
func (s *FileSystemService) ExtractZipFileToTemp(zipPath, filePath string) (string, error) {
return ExtractFileFromZipToTemp(zipPath, filePath)
}
// ExtractFileFromZipToTemp 从ZIP提取文件到临时目录别名保持向后兼容
func (s *FileSystemService) ExtractFileFromZipToTemp(zipPath, filePath string) (string, error) {
return ExtractFileFromZipToTemp(zipPath, filePath)
@@ -564,7 +510,9 @@ func getCurrentTimestamp() time.Time {
// isInRecycleBin 检查路径是否在回收站中
func isInRecycleBin(path string) bool {
recycleBinPath := filepath.Join(common.GetUserDataDir(), "recycle_bin")
return filepath.HasPrefix(filepath.Clean(path), filepath.Clean(recycleBinPath))
cleanPath := filepath.Clean(path)
cleanBinPath := filepath.Clean(recycleBinPath)
return len(cleanPath) >= len(cleanBinPath) && cleanPath[:len(cleanBinPath)] == cleanBinPath
}
// ========== 辅助方法 ==========
@@ -787,16 +735,3 @@ func GetGlobalService() (*FileSystemService, error) {
return globalService, initErr
}
// InitGlobalFileSystem 初始化全局文件系统(兼容旧代码)
func InitGlobalFileSystem() error {
_, err := GetGlobalService()
return err
}
// CloseGlobalFileSystem 关闭全局文件系统
func CloseGlobalFileSystem(ctx context.Context) error {
if globalService != nil {
return globalService.Close(ctx)
}
return nil
}

View File

@@ -346,46 +346,4 @@ func GetZipFileInfo(zipPath, filePath string) (map[string]interface{}, error) {
return result.(map[string]interface{}), nil
}
// validateZipFileBasic 验证ZIP文件的基本信息提取自ListZipContents
func validateZipFileBasic(zipPath string) error {
if err := validateZipPath(zipPath); err != nil {
return err
}
fileInfo, err := os.Stat(zipPath)
if err != nil {
return fmt.Errorf("无法访问文件: %v", err)
}
if fileInfo.Size() < MinValidZipSize {
return fmt.Errorf("文件太小 (%d bytes)", fileInfo.Size())
}
if fileInfo.Size() > MaxZipSize {
return fmt.Errorf("ZIP文件过大 (%d bytes)", fileInfo.Size())
}
return checkZipFileHeader(zipPath)
}
// checkZipFileHeader 检查ZIP文件头签名
func checkZipFileHeader(zipPath string) error {
file, err := os.Open(zipPath)
if err != nil {
return fmt.Errorf("无法打开文件: %v", err)
}
defer file.Close()
header := make([]byte, 4)
n, err := file.Read(header)
if err != nil || n != 4 {
return fmt.Errorf("无法读取文件头")
}
if header[0] != 0x50 || header[1] != 0x4B {
return fmt.Errorf("不是有效的 ZIP 文件")
}
return nil
}

View File

@@ -65,25 +65,6 @@ func isMatchFile(file *zip.File, targetPath string) bool {
filepath.Clean(file.Name) == filepath.Clean(targetPath)
}
// openZipFileInReader 在ZIP reader中打开指定文件
// 用于读取文件内容的辅助函数
func openZipFileInReader(reader *zip.ReadCloser, filePath string) (io.ReadCloser, *zip.File, error) {
for _, file := range reader.File {
if isMatchFile(file, filePath) {
if file.Mode().IsDir() {
return nil, nil, fmt.Errorf("不能读取目录")
}
rc, err := file.Open()
if err != nil {
return nil, nil, fmt.Errorf("打开 zip 中的文件失败: %v", err)
}
return rc, file, nil
}
}
return nil, nil, fmt.Errorf("文件在 zip 中不存在: %s", filePath)
}
// readAllFromFile 从文件读取所有内容
// 辅助函数,避免重复的 io.ReadAll 调用
func readAllFromFile(rc io.ReadCloser) ([]byte, error) {

View File

@@ -29,8 +29,3 @@ func (s *TabService) SaveTabs(tabs []models.SqlTab) error {
func (s *TabService) ListTabs() ([]models.SqlTab, error) {
return s.repo.FindAll()
}
// DeleteTab 删除标签页
func (s *TabService) DeleteTab(id uint) error {
return s.repo.Delete(id)
}

View File

@@ -102,22 +102,6 @@ func SaveUpdateConfig(config *UpdateConfig) error {
return nil
}
// ShouldCheckUpdate 判断是否应该检查更新
func (c *UpdateConfig) ShouldCheckUpdate() bool {
if !c.AutoCheckEnabled {
return false
}
// 如果从未检查过,应该检查
if c.LastCheckTime.IsZero() {
return true
}
// 检查是否超过间隔分钟数
minutesSinceLastCheck := time.Since(c.LastCheckTime).Minutes()
return minutesSinceLastCheck >= float64(c.CheckIntervalMinutes)
}
// UpdateLastCheckTime 更新最后检查时间
func (c *UpdateConfig) UpdateLastCheckTime() error {
c.LastCheckTime = time.Now()

View File

@@ -64,11 +64,6 @@ func ParseVersion(versionStr string) (*Version, error) {
return &Version{Major: major, Minor: minor, Patch: patch}, nil
}
// String 返回版本号字符串格式v1.0.0
func (v *Version) String() string {
return fmt.Sprintf("v%d.%d.%d", v.Major, v.Minor, v.Patch)
}
// Compare 比较版本号
// 返回值:-1 表示当前版本小于目标版本0 表示相等1 表示大于
func (v *Version) Compare(other *Version) int {
@@ -100,11 +95,6 @@ func (v *Version) IsNewerThan(other *Version) bool {
return v.Compare(other) > 0
}
// IsOlderThan 判断是否比目标版本旧
func (v *Version) IsOlderThan(other *Version) bool {
return v.Compare(other) < 0
}
// ==================== 版本号获取 ====================
// GetCurrentVersion 获取当前版本号(带缓存)

View File

@@ -1,279 +0,0 @@
package storage
import (
"context"
"encoding/json"
"fmt"
"u-desk/internal/crypto"
"u-desk/internal/dbclient"
"u-desk/internal/storage/models"
"gorm.io/gorm"
)
// ConnectionService 连接管理服务
type ConnectionService struct {
db *gorm.DB
}
// NewConnectionService 创建连接服务
func NewConnectionService() (*ConnectionService, error) {
db := GetDB()
if db == nil {
// 尝试重新初始化
var err error
db, err = Init()
if err != nil {
return nil, fmt.Errorf("数据库初始化失败: %v", err)
}
}
return &ConnectionService{db: db}, nil
}
// SaveConnection 保存连接配置
func (s *ConnectionService) SaveConnection(conn *models.DbConnection) error {
if conn.Name == "" {
return fmt.Errorf("连接名称不能为空")
}
if conn.Type == "" {
return fmt.Errorf("数据库类型不能为空")
}
if conn.Host == "" {
return fmt.Errorf("主机地址不能为空")
}
// 检查名称是否重复(排除当前记录)
var count int64
query := s.db.Model(&models.DbConnection{}).Where("name = ?", conn.Name)
if conn.ID > 0 {
query = query.Where("id != ?", conn.ID)
}
query.Count(&count)
if count > 0 {
return fmt.Errorf("连接名称已存在")
}
if conn.ID > 0 {
// 更新模式
updateData := map[string]interface{}{
"name": conn.Name,
"type": conn.Type,
"host": conn.Host,
"port": conn.Port,
"username": conn.Username,
"database": conn.Database,
"options": conn.Options,
"visible_databases": conn.VisibleDatabases,
}
// 如果提供了新密码,加密后更新
if conn.Password != "" {
encrypted, err := crypto.EncryptPassword(conn.Password)
if err != nil {
return fmt.Errorf("密码加密失败: %v", err)
}
updateData["password"] = encrypted
}
// 如果密码为空,不更新密码字段(保留原密码)
return s.db.Model(&models.DbConnection{}).Where("id = ?", conn.ID).Updates(updateData).Error
}
// 新增模式 - 必须提供密码
if conn.Password == "" {
return fmt.Errorf("新增连接时密码不能为空")
}
// 加密密码
encrypted, err := crypto.EncryptPassword(conn.Password)
if err != nil {
return fmt.Errorf("密码加密失败: %v", err)
}
conn.Password = encrypted
return s.db.Create(conn).Error
}
// ListConnections 获取连接列表
func (s *ConnectionService) ListConnections() ([]models.DbConnection, error) {
var connections []models.DbConnection
err := s.db.Order("created_at DESC").Find(&connections).Error
return connections, err
}
// GetConnection 获取连接详情
func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error) {
var conn models.DbConnection
err := s.db.First(&conn, id).Error
if err != nil {
return nil, err
}
return &conn, nil
}
// DeleteConnection 删除连接配置
func (s *ConnectionService) DeleteConnection(id uint) error {
var conn models.DbConnection
if err := s.db.First(&conn, id).Error; err != nil {
return nil // 连接不存在视为成功
}
// 使用事务删除
return s.db.Transaction(func(tx *gorm.DB) error {
// 清理关联数据
tx.Where("connection_id = ?", id).Delete(&models.SqlResultHistory{})
tx.Where("connection_id = ?", id).Delete(&models.SqlTab{})
// 删除连接
if err := tx.Delete(&conn).Error; err != nil {
return err
}
// 关闭连接池
dbclient.GetPool().CloseConnection(id, conn.Type)
return nil
})
}
// resolvePassword 解析密码(编辑模式下从已保存连接中获取)
func (s *ConnectionService) resolvePassword(id uint, password string) (string, error) {
if id > 0 && password == "" {
conn, err := s.GetConnection(id)
if err != nil {
return "", fmt.Errorf("获取连接信息失败: %v", err)
}
decryptPassword, err := crypto.DecryptPassword(conn.Password)
if err != nil {
return "", fmt.Errorf("密码解密失败: %v", err)
}
return decryptPassword, nil
}
return password, nil
}
// parseMongoOptions 解析 MongoDB 连接选项
func parseMongoOptions(options string) (authSource, authMechanism string) {
if options == "" {
return "", ""
}
var opts map[string]interface{}
if err := json.Unmarshal([]byte(options), &opts); err != nil {
return "", ""
}
authSource, _ = opts["authSource"].(string)
authMechanism, _ = opts["authMechanism"].(string)
return authSource, authMechanism
}
// TestConnection 测试连接(需要根据类型调用不同的测试方法)
func (s *ConnectionService) TestConnection(conn *models.DbConnection) error {
password, err := crypto.DecryptPassword(conn.Password)
if err != nil {
return fmt.Errorf("密码解密失败: %v", err)
}
authSource, authMechanism := parseMongoOptions(conn.Options)
return s.testConnectionByType(conn.Type, conn.Host, conn.Port, conn.Username, password, conn.Database, authSource, authMechanism)
}
// testConnectionByType 根据类型调用对应的测试方法
func (s *ConnectionService) testConnectionByType(dbType, host string, port int, username, password, database, authSource, authMechanism string) error {
switch dbType {
case "mysql":
return testMySQLConnection(host, port, username, password, database)
case "redis":
return testRedisConnection(host, port, password)
case "mongo":
return testMongoConnection(host, port, username, password, database, authSource, authMechanism)
default:
return fmt.Errorf("不支持的数据库类型: %s", dbType)
}
}
// testMySQLConnection 测试 MySQL 连接
func testMySQLConnection(host string, port int, username, password, database string) error {
return dbclient.TestMySQLConnection(host, port, username, password, database)
}
// testRedisConnection 测试 Redis 连接
func testRedisConnection(host string, port int, password string) error {
return dbclient.TestRedisConnection(host, port, password)
}
// testMongoConnection 测试 MongoDB 连接
func testMongoConnection(host string, port int, username, password, database, authSource, authMechanism string) error {
return dbclient.TestMongoConnectionWithOptions(host, port, username, password, database, authSource, authMechanism)
}
// TestConnectionWithParams 使用参数测试连接(不保存数据)
func (s *ConnectionService) TestConnectionWithParams(dbType, host string, port int, username, password, database, options string, id uint) error {
password, err := s.resolvePassword(id, password)
if err != nil {
return err
}
authSource, authMechanism := parseMongoOptions(options)
return s.testConnectionByType(dbType, host, port, username, password, database, authSource, authMechanism)
}
// LoadAllDatabases 加载全部数据库列表
func (s *ConnectionService) LoadAllDatabases(dbType, host string, port int, username, password, database, options string, id uint) ([]string, error) {
password, err := s.resolvePassword(id, password)
if err != nil {
return nil, err
}
authSource, authMechanism := parseMongoOptions(options)
// 根据类型加载数据库列表
switch dbType {
case "mysql":
return loadMySQLDatabases(host, port, username, password, database)
case "mongo":
return loadMongoDatabasesWithOptions(host, port, username, password, database, authSource, authMechanism)
case "redis":
// Redis 没有数据库概念,返回空列表
return []string{}, nil
default:
return nil, fmt.Errorf("不支持的数据库类型: %s", dbType)
}
}
// loadMySQLDatabases 加载 MySQL 数据库列表
func loadMySQLDatabases(host string, port int, username, password, defaultDatabase string) ([]string, error) {
config := &dbclient.MySQLConfig{
Host: host,
Port: port,
Username: username,
Password: password,
Database: defaultDatabase,
}
client, err := dbclient.NewMySQLClient(config)
if err != nil {
return nil, err
}
defer client.Close()
return client.ListDatabases(context.Background())
}
// loadMongoDatabasesWithOptions 加载 MongoDB 数据库列表(使用解析后的选项)
func loadMongoDatabasesWithOptions(host string, port int, username, password, defaultDatabase, authSource, authMechanism string) ([]string, error) {
mongoConfig := &dbclient.MongoConfig{
Host: host,
Port: port,
Username: username,
Password: password,
Database: defaultDatabase,
AuthSource: authSource,
AuthMechanism: authMechanism,
}
client, err := dbclient.NewMongoClient(mongoConfig)
if err != nil {
return nil, err
}
defer client.Close()
return client.ListDatabases(context.Background())
}

View File

@@ -1,20 +0,0 @@
package models
import (
"time"
)
// SqlFile SQL 文件记录
type SqlFile struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"type:varchar(200);not null" json:"name"` // 文件名
Path string `gorm:"type:varchar(500);not null;uniqueIndex" json:"path"` // 文件路径
Content string `gorm:"type:text" json:"content"` // 文件内容
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TableName 指定表名
func (SqlFile) TableName() string {
return "sql_file"
}

View File

@@ -1,20 +0,0 @@
package models
import "time"
// Version 版本信息
type Version struct {
ID int `gorm:"primaryKey" json:"id"` // 主键ID
Version string `gorm:"type:varchar(20);not null;uniqueIndex" json:"version"` // 版本号语义化版本如1.0.0
DownloadURL string `gorm:"type:varchar(500)" json:"download_url"` // 下载地址更新包下载URL
Changelog string `gorm:"type:text" json:"changelog"` // 更新日志Markdown格式
ForceUpdate int `gorm:"type:tinyint;not null;default:0" json:"force_update"` // 是否强制更新1:是 0:否)
ReleaseDate *time.Time `gorm:"type:date" json:"release_date"` // 发布日期
CreatedAt time.Time `gorm:"autoCreateTime:false" json:"created_at"` // 创建时间(由程序设置)
UpdatedAt time.Time `gorm:"autoUpdateTime:false" json:"updated_at"` // 更新时间(由程序设置)
}
// TableName 指定表名
func (Version) TableName() string {
return "sys_version"
}

View File

@@ -5,17 +5,13 @@ import (
"u-desk/internal/storage"
"u-desk/internal/storage/models"
"gorm.io/gorm"
"time"
)
type ResultRepository interface {
Save(connectionID uint, database, sql string, resultType string, data interface{}, columns []string, rowsAffected int, executionTime int64) (*models.SqlResultHistory, error)
FindByID(id uint) (*models.SqlResultHistory, error)
FindByConnection(connectionID uint, limit int) ([]models.SqlResultHistory, error)
Search(connectionID *uint, keyword string, limit, offset int) ([]models.SqlResultHistory, int64, error)
Delete(id uint) error
DeleteByConnection(connectionID uint) error
DeleteOld(keepDays int) error
}
type resultRepository struct {
@@ -61,15 +57,6 @@ func (r *resultRepository) FindByID(id uint) (*models.SqlResultHistory, error) {
return &history, err
}
func (r *resultRepository) FindByConnection(connectionID uint, limit int) ([]models.SqlResultHistory, error) {
var histories []models.SqlResultHistory
query := r.db.Where("connection_id = ?", connectionID).Order("created_at DESC")
if limit > 0 {
query = query.Limit(limit)
}
return histories, query.Find(&histories).Error
}
func (r *resultRepository) Search(connectionID *uint, keyword string, limit, offset int) ([]models.SqlResultHistory, int64, error) {
query := r.db.Model(&models.SqlResultHistory{})
@@ -101,10 +88,3 @@ func (r *resultRepository) Delete(id uint) error {
return r.db.Delete(&models.SqlResultHistory{}, id).Error
}
func (r *resultRepository) DeleteByConnection(connectionID uint) error {
return r.db.Where("connection_id = ?", connectionID).Delete(&models.SqlResultHistory{}).Error
}
func (r *resultRepository) DeleteOld(keepDays int) error {
return r.db.Where("created_at < ?", time.Now().AddDate(0, 0, -keepDays)).Delete(&models.SqlResultHistory{}).Error
}

View File

@@ -10,7 +10,6 @@ type TabRepository interface {
SaveAll(tabs []models.SqlTab) error
FindAll() ([]models.SqlTab, error)
Delete(id uint) error
DeleteAll() error
}
type tabRepository struct {
@@ -50,6 +49,3 @@ func (r *tabRepository) Delete(id uint) error {
return r.db.Delete(&models.SqlTab{}, id).Error
}
func (r *tabRepository) DeleteAll() error {
return r.db.Where("1=1").Delete(&models.SqlTab{}).Error
}