diff --git a/app.go b/app.go index 3559e4c..3c38738 100644 --- a/app.go +++ b/app.go @@ -9,6 +9,7 @@ import ( "go-desk/internal/storage" "go-desk/internal/system" "os" + "strings" "github.com/wailsapp/wails/v2/pkg/runtime" ) @@ -20,6 +21,7 @@ type App struct { connectionAPI *api.ConnectionAPI sqlAPI *api.SqlAPI tabAPI *api.TabAPI + updateAPI *api.UpdateAPI } // NewApp 创建新的应用实例 @@ -27,8 +29,8 @@ func NewApp() *App { return &App{} } -// startup 应用启动时调用 -func (a *App) startup(ctx context.Context) { +// Startup 应用启动时调用 +func (a *App) Startup(ctx context.Context) { a.ctx = ctx // 初始化 SQLite 本地存储(核心依赖,必须成功) @@ -52,17 +54,15 @@ func (a *App) startup(ctx context.Context) { if err := a.initAPIs(); err != nil { panic(fmt.Sprintf("API 初始化失败,应用无法启动: %v", err)) } + + // 设置 updateAPI 的上下文 + if a.updateAPI != nil { + a.updateAPI.SetContext(ctx) + } } // QueryUsers 查询用户列表 func (a *App) QueryUsers(keyword string, status int, role int, organid int, page int, pageSize int, sortField string, sortOrder string) (map[string]interface{}, error) { - if a.db == nil { - return map[string]interface{}{ - "rows": []interface{}{}, - "total": 0, - }, nil - } - return a.db.QueryUsers(keyword, status, role, organid, page, pageSize, sortField, sortOrder) } @@ -125,24 +125,13 @@ func (a *App) GetFileInfo(path string) (map[string]interface{}, error) { func (a *App) GetEnvVars() (map[string]string, error) { envVars := make(map[string]string) for _, env := range os.Environ() { - parts := splitEnv(env) - if len(parts) == 2 { - envVars[parts[0]] = parts[1] + if key, value, found := strings.Cut(env, "="); found { + envVars[key] = value } } return envVars, nil } -// splitEnv 分割环境变量字符串(key=value) -func splitEnv(env string) []string { - for i := 0; i < len(env); i++ { - if env[i] == '=' { - return []string{env[:i], env[i+1:]} - } - } - return []string{env} -} - // ========== 数据库连接管理接口 ========== // initAPIs 初始化所有API(在startup中调用) @@ -157,6 +146,10 @@ func (a *App) initAPIs() error { return err } a.tabAPI, err = api.NewTabAPI() + if err != nil { + return err + } + a.updateAPI, err = api.NewUpdateAPI("https://img.1216.top/go-desk/last-version.json") return err } @@ -267,3 +260,46 @@ func (a *App) SaveSqlTabs(tabs []map[string]interface{}) error { func (a *App) ListSqlTabs() ([]map[string]interface{}, error) { return a.tabAPI.ListSqlTabs() } + +// ========== 版本更新管理接口 ========== + +// CheckUpdate 检查更新 +func (a *App) CheckUpdate() (map[string]interface{}, error) { + return a.updateAPI.CheckUpdate() +} + +// GetCurrentVersion 获取当前版本号 +func (a *App) GetCurrentVersion() (map[string]interface{}, error) { + return a.updateAPI.GetCurrentVersion() +} + +// GetUpdateConfig 获取更新配置 +func (a *App) GetUpdateConfig() (map[string]interface{}, error) { + return a.updateAPI.GetUpdateConfig() +} + +// SetUpdateConfig 设置更新配置 +func (a *App) SetUpdateConfig(autoCheckEnabled bool, checkIntervalMinutes int, checkURL string) (map[string]interface{}, error) { + return a.updateAPI.SetUpdateConfig(autoCheckEnabled, checkIntervalMinutes, checkURL) +} + +// DownloadUpdate 下载更新包 +func (a *App) DownloadUpdate(downloadURL string) (map[string]interface{}, error) { + return a.updateAPI.DownloadUpdate(downloadURL) +} + +// InstallUpdate 安装更新包 +func (a *App) InstallUpdate(installerPath string, autoRestart bool) (map[string]interface{}, error) { + return a.updateAPI.InstallUpdate(installerPath, autoRestart) +} + +// InstallUpdateWithHash 安装更新包(带哈希验证) +func (a *App) InstallUpdateWithHash(installerPath string, autoRestart bool, expectedHash string, hashType string) (map[string]interface{}, error) { + return a.updateAPI.InstallUpdateWithHash(installerPath, autoRestart, expectedHash, hashType) +} + +// VerifyUpdateFile 验证更新文件哈希值 +func (a *App) VerifyUpdateFile(filePath string, expectedHash string, hashType string) (map[string]interface{}, error) { + return a.updateAPI.VerifyUpdateFile(filePath, expectedHash, hashType) +} + diff --git a/internal/api/update_api.go b/internal/api/update_api.go new file mode 100644 index 0000000..0188aa7 --- /dev/null +++ b/internal/api/update_api.go @@ -0,0 +1,185 @@ +package api + +import ( + "context" + "encoding/json" + "go-desk/internal/service" + "time" + + "github.com/wailsapp/wails/v2/pkg/runtime" +) + +// UpdateAPI 版本更新 API +type UpdateAPI struct { + updateService *service.UpdateService + ctx context.Context +} + +// NewUpdateAPI 创建版本更新 API +func NewUpdateAPI(checkURL string) (*UpdateAPI, error) { + return &UpdateAPI{ + updateService: service.NewUpdateService(checkURL), + }, nil +} + +// SetContext 设置上下文(用于事件推送) +func (api *UpdateAPI) SetContext(ctx context.Context) { + api.ctx = ctx +} + +// successResponse 构造成功响应 +func successResponse(data interface{}) map[string]interface{} { + return map[string]interface{}{"success": true, "data": data} +} + +// errorResponse 构造错误响应 +func errorResponse(message string) map[string]interface{} { + return map[string]interface{}{"success": false, "message": message} +} + +// CheckUpdate 检查更新 +func (api *UpdateAPI) CheckUpdate() (map[string]interface{}, error) { + result, err := api.updateService.CheckUpdate() + if err != nil { + return errorResponse(err.Error()), nil + } + + return successResponse(result), nil +} + +// GetCurrentVersion 获取当前版本号 +func (api *UpdateAPI) GetCurrentVersion() (map[string]interface{}, error) { + version := service.GetCurrentVersion() + + // 同步配置中的版本号 + if config, err := service.LoadUpdateConfig(); err == nil && config.CurrentVersion != version { + config.CurrentVersion = version + service.SaveUpdateConfig(config) + } + + return successResponse(map[string]interface{}{ + "version": version, + }), nil +} + +// GetUpdateConfig 获取更新配置 +func (api *UpdateAPI) GetUpdateConfig() (map[string]interface{}, error) { + config, err := service.LoadUpdateConfig() + if err != nil { + return errorResponse(err.Error()), nil + } + + // 同步最新版本号 + latestVersion := service.GetCurrentVersion() + if config.CurrentVersion != latestVersion { + config.CurrentVersion = latestVersion + service.SaveUpdateConfig(config) + } + + return successResponse(map[string]interface{}{ + "current_version": config.CurrentVersion, + "last_check_time": config.LastCheckTime.Format("2006-01-02 15:04:05"), + "auto_check_enabled": config.AutoCheckEnabled, + "check_interval_minutes": config.CheckIntervalMinutes, + "check_url": config.CheckURL, + }), nil +} + +// SetUpdateConfig 设置更新配置 +func (api *UpdateAPI) SetUpdateConfig(autoCheckEnabled bool, checkIntervalMinutes int, checkURL string) (map[string]interface{}, error) { + config, err := service.LoadUpdateConfig() + if err != nil { + return errorResponse(err.Error()), nil + } + + config.AutoCheckEnabled = autoCheckEnabled + config.CheckIntervalMinutes = checkIntervalMinutes + if checkURL != "" { + config.CheckURL = checkURL + api.updateService = service.NewUpdateService(checkURL) + } + + if err := service.SaveUpdateConfig(config); err != nil { + return errorResponse(err.Error()), nil + } + + return successResponse(map[string]interface{}{ + "message": "配置保存成功", + }), nil +} + +// DownloadUpdate 下载更新包(异步,通过事件推送进度) +func (api *UpdateAPI) DownloadUpdate(downloadURL string) (map[string]interface{}, error) { + if downloadURL == "" { + return errorResponse("下载地址不能为空"), nil + } + + go func() { + progressCallback := func(progress float64, speed float64, downloaded int64, total int64) { + progressInfo := map[string]interface{}{ + "progress": progress, + "speed": speed, + "downloaded": downloaded, + "total": total, + } + progressJSON, _ := json.Marshal(progressInfo) + runtime.EventsEmit(api.ctx, "download-progress", string(progressJSON)) + } + + time.Sleep(100 * time.Millisecond) + result, err := service.DownloadUpdate(downloadURL, progressCallback) + + if err != nil { + errorInfo := map[string]interface{}{"error": err.Error()} + errorJSON, _ := json.Marshal(errorInfo) + runtime.EventsEmit(api.ctx, "download-complete", string(errorJSON)) + } else { + resultInfo := map[string]interface{}{ + "success": true, + "file_path": result.FilePath, + "file_size": result.FileSize, + } + resultJSON, _ := json.Marshal(resultInfo) + runtime.EventsEmit(api.ctx, "download-complete", string(resultJSON)) + } + }() + + return successResponse(map[string]interface{}{ + "message": "下载已开始", + }), nil +} + +// InstallUpdate 安装更新包 +func (api *UpdateAPI) InstallUpdate(installerPath string, autoRestart bool) (map[string]interface{}, error) { + return api.InstallUpdateWithHash(installerPath, autoRestart, "", "") +} + +// InstallUpdateWithHash 安装更新包(带哈希验证) +func (api *UpdateAPI) InstallUpdateWithHash(installerPath string, autoRestart bool, expectedHash string, hashType string) (map[string]interface{}, error) { + if installerPath == "" { + return errorResponse("安装文件路径不能为空"), nil + } + + result, err := service.InstallUpdateWithHash(installerPath, autoRestart, expectedHash, hashType) + if err != nil { + return errorResponse(err.Error()), nil + } + + return successResponse(result), nil +} + +// VerifyUpdateFile 验证更新文件哈希值 +func (api *UpdateAPI) VerifyUpdateFile(filePath string, expectedHash string, hashType string) (map[string]interface{}, error) { + if filePath == "" { + return errorResponse("文件路径不能为空"), nil + } + + valid, err := service.VerifyFileHash(filePath, expectedHash, hashType) + if err != nil { + return errorResponse(err.Error()), nil + } + + return successResponse(map[string]interface{}{ + "valid": valid, + }), nil +} diff --git a/internal/service/update.go b/internal/service/update.go new file mode 100644 index 0000000..8073c4d --- /dev/null +++ b/internal/service/update.go @@ -0,0 +1,429 @@ +package service + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "time" +) + +// ==================== 类型定义 ==================== + +// UpdateService 更新服务 +type UpdateService struct { + checkURL string // 版本检查接口 URL +} + +// RemoteVersionInfo 远程版本信息 +type RemoteVersionInfo struct { + Version string `json:"version"` + DownloadURL string `json:"download_url"` + Changelog string `json:"changelog"` + ForceUpdate bool `json:"force_update"` + ReleaseDate string `json:"release_date"` +} + +// UpdateCheckResult 更新检查结果 +type UpdateCheckResult struct { + HasUpdate bool `json:"has_update"` + CurrentVersion string `json:"current_version"` + LatestVersion string `json:"latest_version"` + DownloadURL string `json:"download_url"` + Changelog string `json:"changelog"` + ForceUpdate bool `json:"force_update"` + ReleaseDate string `json:"release_date"` +} + +// InstallResult 安装结果 +type InstallResult struct { + Success bool `json:"success"` + Message string `json:"message"` +} + +// ==================== 更新服务 ==================== + +// NewUpdateService 创建更新服务 +func NewUpdateService(checkURL string) *UpdateService { + return &UpdateService{ + checkURL: checkURL, + } +} + +// CheckUpdate 检查更新 +func (s *UpdateService) CheckUpdate() (*UpdateCheckResult, error) { + log.Printf("[更新检查] 开始检查更新,检查地址: %s", s.checkURL) + + config, err := LoadUpdateConfig() + if err != nil { + return nil, fmt.Errorf("加载配置失败: %v", err) + } + + // 同步版本号 + currentVersionStr, err := s.syncConfigVersion(config) + if err != nil { + return nil, err + } + + currentVersion, err := ParseVersion(currentVersionStr) + if err != nil { + return nil, fmt.Errorf("解析当前版本失败: %v", err) + } + + // 请求远程版本信息 + remoteInfo, err := s.fetchRemoteVersionInfo() + if err != nil { + return nil, fmt.Errorf("获取远程版本信息失败: %v", err) + } + + log.Printf("[更新检查] 远程版本信息: 版本=%s, 下载地址=%s, 强制更新=%v", + remoteInfo.Version, remoteInfo.DownloadURL, remoteInfo.ForceUpdate) + + // 解析远程版本号 + remoteVersion, err := ParseVersion(remoteInfo.Version) + if err != nil { + return nil, fmt.Errorf("解析远程版本号失败: %v", err) + } + + // 比较版本 + hasUpdate := remoteVersion.IsNewerThan(currentVersion) + log.Printf("[更新检查] 版本比较: 当前=%s, 远程=%s, 有更新=%v", + currentVersion.String(), remoteVersion.String(), hasUpdate) + + // 更新最后检查时间 + config.UpdateLastCheckTime() + + result := &UpdateCheckResult{ + HasUpdate: hasUpdate, + CurrentVersion: currentVersionStr, + LatestVersion: remoteInfo.Version, + DownloadURL: remoteInfo.DownloadURL, + Changelog: remoteInfo.Changelog, + ForceUpdate: remoteInfo.ForceUpdate, + ReleaseDate: remoteInfo.ReleaseDate, + } + + log.Printf("[更新检查] 检查完成: 有更新=%v", hasUpdate) + return result, nil +} + +// syncConfigVersion 同步配置中的版本号 +func (s *UpdateService) syncConfigVersion(config *UpdateConfig) (string, error) { + currentVersionStr := GetCurrentVersion() + if currentVersionStr == "" { + currentVersionStr = config.CurrentVersion + log.Printf("[更新检查] 使用配置中的版本号: %s", currentVersionStr) + } else if config.CurrentVersion != currentVersionStr { + log.Printf("[更新检查] 配置中的版本号 (%s) 与当前版本号 (%s) 不一致,更新配置", + config.CurrentVersion, currentVersionStr) + config.CurrentVersion = currentVersionStr + if err := SaveUpdateConfig(config); err != nil { + log.Printf("[更新检查] 更新配置失败: %v", err) + } + } + return currentVersionStr, nil +} + +// fetchRemoteVersionInfo 获取远程版本信息 +func (s *UpdateService) fetchRemoteVersionInfo() (*RemoteVersionInfo, error) { + if s.checkURL == "" { + log.Printf("[远程版本] 版本检查 URL 未配置") + return nil, fmt.Errorf("版本检查 URL 未配置,请先设置检查地址") + } + + log.Printf("[远程版本] 请求远程版本信息: %s", s.checkURL) + + // 创建 HTTP 客户端,设置超时 + client := &http.Client{ + Timeout: 10 * time.Second, + } + + // 发送请求 + resp, err := client.Get(s.checkURL) + if err != nil { + log.Printf("[远程版本] 网络请求失败: %v", err) + return nil, fmt.Errorf("网络请求失败: %v", err) + } + defer resp.Body.Close() + + log.Printf("[远程版本] HTTP 响应状态码: %d", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode) + } + + // 读取响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("[远程版本] 读取响应失败: %v", err) + return nil, fmt.Errorf("读取响应失败: %v", err) + } + + log.Printf("[远程版本] 响应内容长度: %d 字节", len(body)) + + // 解析 JSON + var remoteInfo RemoteVersionInfo + if err := json.Unmarshal(body, &remoteInfo); err != nil { + log.Printf("[远程版本] 解析 JSON 失败: %v, 响应内容: %s", err, string(body)) + return nil, fmt.Errorf("解析响应失败: %v", err) + } + + if remoteInfo.Version == "" { + log.Printf("[远程版本] 远程版本信息不完整,响应内容: %s", string(body)) + return nil, fmt.Errorf("远程版本信息不完整") + } + + log.Printf("[远程版本] 成功获取远程版本信息: %+v", remoteInfo) + return &remoteInfo, nil +} + +// ==================== 安装更新 ==================== + +// InstallUpdate 安装更新包 +func InstallUpdate(installerPath string, autoRestart bool) (*InstallResult, error) { + return InstallUpdateWithHash(installerPath, autoRestart, "", "") +} + +// InstallUpdateWithHash 安装更新包(带哈希验证) +func InstallUpdateWithHash(installerPath string, autoRestart bool, expectedHash string, hashType string) (*InstallResult, error) { + if _, err := os.Stat(installerPath); os.IsNotExist(err) { + return nil, fmt.Errorf("安装文件不存在: %s", installerPath) + } + + // 哈希验证 + if expectedHash != "" && hashType != "" { + valid, err := VerifyFileHash(installerPath, expectedHash, hashType) + if err != nil { + return &InstallResult{Success: false, Message: "文件验证失败: " + err.Error()}, nil + } + if !valid { + return &InstallResult{Success: false, Message: "文件哈希值不匹配,文件可能已损坏或被篡改"}, nil + } + } + + // 备份 + backupPath, err := BackupApplication() + if err != nil { + return &InstallResult{Success: false, Message: fmt.Sprintf("备份失败: %v", err)}, nil + } + + // 安装 + ext := filepath.Ext(installerPath) + switch ext { + case ".exe": + if runtime.GOOS != "windows" { + return &InstallResult{Success: false, Message: "当前系统不是 Windows,无法安装 .exe 文件"}, nil + } + err = installExe(installerPath) + case ".zip": + err = installZip(installerPath) + default: + return nil, fmt.Errorf("不支持的安装包格式: %s", ext) + } + + // 处理安装结果 + if err != nil { + // 安装失败,尝试回滚 + if backupPath != "" { + _ = rollbackFromBackup(backupPath) + } + return &InstallResult{Success: false, Message: fmt.Sprintf("安装失败: %v", err)}, nil + } + + // 自动重启 + if autoRestart { + go func() { + time.Sleep(2 * time.Second) + restartApplication() + }() + } + + return &InstallResult{Success: true, Message: "安装成功"}, nil +} + +// ==================== 安装相关辅助函数 ==================== + +// installExe 安装 exe 文件 +func installExe(exePath string) error { + execPath, err := os.Executable() + if err != nil { + return err + } + return replaceExecutableFile(exePath, execPath) +} + +// installZip 安装 ZIP 压缩包 +func installZip(zipPath string) error { + // 这里需要导入 archive/zip 包 + return fmt.Errorf("ZIP 安装暂未实现") +} + +// replaceExecutableFile 替换可执行文件(Windows 和 Unix 通用逻辑) +func replaceExecutableFile(newFilePath, execPath string) error { + if runtime.GOOS == "windows" { + return replaceExecutableFileWindows(newFilePath, execPath) + } + + // Unix-like: 直接替换 + if err := copyFile(newFilePath, execPath); err != nil { + return fmt.Errorf("复制文件失败: %v", err) + } + return os.Chmod(execPath, 0755) +} + +// replaceExecutableFileWindows Windows 平台替换可执行文件 +func replaceExecutableFileWindows(newFilePath, execPath string) error { + oldExecPath := execPath + ".old" + newExecPathTemp := execPath + ".new" + + // 清理旧文件 + os.Remove(oldExecPath) + os.Remove(newExecPathTemp) + + // 复制新文件到临时位置 + if err := copyFile(newFilePath, newExecPathTemp); err != nil { + return fmt.Errorf("复制新文件失败: %v", err) + } + + // 尝试重命名当前文件(如果失败,说明文件正在使用) + if err := os.Rename(execPath, oldExecPath); err != nil { + return nil + } + + // 替换文件 + if err := os.Rename(newExecPathTemp, execPath); err != nil { + os.Rename(oldExecPath, execPath) // 恢复 + return fmt.Errorf("替换文件失败: %v", err) + } + + // 延迟删除旧文件 + go func() { + time.Sleep(10 * time.Second) + os.Remove(oldExecPath) + }() + return nil +} + +// restartApplication 重启应用 +func restartApplication() { + execPath, err := os.Executable() + if err != nil { + return + } + + currentPID := os.Getpid() + if runtime.GOOS != "windows" { + return + } + + replacePendingFile(execPath) + + // 创建并执行重启脚本 + tempDir := os.TempDir() + batFile := filepath.Join(tempDir, fmt.Sprintf("restart_%d.bat", currentPID)) + execDir := filepath.Dir(execPath) + + batContent := fmt.Sprintf(`@echo off +cd /d "%s" +start "" "%s" +timeout /t 3 /nobreak >nul +taskkill /PID %d /F >nul 2>&1 +del "%%~f0" +`, execDir, execPath, currentPID) + + if err := os.WriteFile(batFile, []byte(batContent), 0644); err != nil { + fallbackRestart(execPath) + return + } + + cmd := exec.Command("cmd", "/C", batFile) + cmd.Stdout = nil + cmd.Stderr = nil + if err := cmd.Start(); err != nil { + fallbackRestart(execPath) + return + } + + os.Exit(0) +} + +// fallbackRestart 降级重启方案 +func fallbackRestart(execPath string) { + exec.Command("cmd", "/C", "start", "", execPath).Start() + time.Sleep(2 * time.Second) + os.Exit(0) +} + +// replacePendingFile 替换待替换的文件(.new -> 可执行文件) +func replacePendingFile(execPath string) error { + newExecPathTemp := execPath + ".new" + if _, err := os.Stat(newExecPathTemp); os.IsNotExist(err) { + return nil // 没有待替换文件 + } + + oldExecPath := execPath + ".old" + os.Remove(oldExecPath) + if err := os.Rename(newExecPathTemp, execPath); err != nil { + return fmt.Errorf("文件替换失败: %v", err) + } + return nil +} + +// rollbackFromBackup 从备份恢复 +func rollbackFromBackup(backupPath string) error { + execPath, err := os.Executable() + if err != nil { + return err + } + return copyFile(backupPath, execPath) +} + +// BackupApplication 备份当前应用 +func BackupApplication() (string, error) { + execPath, err := os.Executable() + if err != nil { + return "", err + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("获取用户目录失败: %v", err) + } + + backupDir := filepath.Join(homeDir, ".go-desk", "backups") + if err := os.MkdirAll(backupDir, 0755); err != nil { + return "", fmt.Errorf("创建备份目录失败: %v", err) + } + + timestamp := time.Now().Format("20060102-150405") + backupFileName := fmt.Sprintf("go-desk-backup-%s%s", timestamp, filepath.Ext(execPath)) + backupPath := filepath.Join(backupDir, backupFileName) + + if err := copyFile(execPath, backupPath); err != nil { + return "", fmt.Errorf("复制文件失败: %v", err) + } + + return backupPath, nil +} + +// copyFile 复制文件 +func copyFile(src, dst string) error { + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + destFile, err := os.Create(dst) + if err != nil { + return err + } + defer destFile.Close() + + _, err = destFile.ReadFrom(sourceFile) + return err +} diff --git a/internal/service/update_config.go b/internal/service/update_config.go new file mode 100644 index 0000000..c7ee3d7 --- /dev/null +++ b/internal/service/update_config.go @@ -0,0 +1,133 @@ +package service + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "time" +) + +// UpdateConfig 更新配置 +type UpdateConfig struct { + CurrentVersion string `json:"current_version"` + LastCheckTime time.Time `json:"last_check_time"` + AutoCheckEnabled bool `json:"auto_check_enabled"` + CheckIntervalMinutes int `json:"check_interval_minutes"` // 检查间隔(分钟) + CheckURL string `json:"check_url,omitempty"` // 版本检查接口 URL +} + +// GetUpdateConfigPath 获取更新配置文件路径 +func GetUpdateConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("获取用户目录失败: %v", err) + } + + configDir := filepath.Join(homeDir, ".go-desk") + if err := os.MkdirAll(configDir, 0755); err != nil { + return "", fmt.Errorf("创建配置目录失败: %v", err) + } + + return filepath.Join(configDir, "update_config.json"), nil +} + +// LoadUpdateConfig 加载更新配置 +func LoadUpdateConfig() (*UpdateConfig, error) { + configPath, err := GetUpdateConfigPath() + if err != nil { + return nil, err + } + + // 如果文件不存在,返回默认配置 + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return &UpdateConfig{ + CurrentVersion: GetCurrentVersion(), + LastCheckTime: time.Time{}, + AutoCheckEnabled: true, + CheckIntervalMinutes: 1, + CheckURL: "https://img.1216.top/go-desk/last-version.json", + }, nil + } + + data, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("读取配置文件失败: %v", err) + } + + var config UpdateConfig + if err := json.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("解析配置文件失败: %v", err) + } + + // 兼容旧配置:如果 check_interval_minutes 为 0,尝试从旧字段转换 + if config.CheckIntervalMinutes == 0 { + var configMap map[string]interface{} + if json.Unmarshal(data, &configMap) == nil { + if days, ok := configMap["check_interval_days"].(float64); ok && days > 0 { + config.CheckIntervalMinutes = int(days * 24 * 60) + } + } + if config.CheckIntervalMinutes == 0 { + config.CheckIntervalMinutes = 1 + } + } + + // 同步最新版本号 + latestVersion := GetCurrentVersion() + if config.CurrentVersion == "" || config.CurrentVersion != latestVersion { + if config.CurrentVersion != "" { + log.Printf("[配置] 配置中的版本号 (%s) 与最新版本号 (%s) 不一致", config.CurrentVersion, latestVersion) + } + config.CurrentVersion = latestVersion + } + + // 使用默认检查地址 + if config.CheckURL == "" { + config.CheckURL = "https://img.1216.top/go-desk/last-version.json" + } + + return &config, nil +} + +// SaveUpdateConfig 保存更新配置 +func SaveUpdateConfig(config *UpdateConfig) error { + configPath, err := GetUpdateConfigPath() + if err != nil { + return err + } + + data, err := json.MarshalIndent(config, "", " ") + if err != nil { + return fmt.Errorf("序列化配置失败: %v", err) + } + + if err := os.WriteFile(configPath, data, 0644); err != nil { + return fmt.Errorf("写入配置文件失败: %v", err) + } + + return nil +} + +// ShouldCheckUpdate 判断是否应该检查更新 +func (c *UpdateConfig) ShouldCheckUpdate() bool { + if !c.AutoCheckEnabled { + return false + } + + // 如果从未检查过,应该检查 + if c.LastCheckTime.IsZero() { + return true + } + + // 检查是否超过间隔分钟数 + minutesSinceLastCheck := time.Since(c.LastCheckTime).Minutes() + return minutesSinceLastCheck >= float64(c.CheckIntervalMinutes) +} + +// UpdateLastCheckTime 更新最后检查时间 +func (c *UpdateConfig) UpdateLastCheckTime() error { + c.LastCheckTime = time.Now() + return SaveUpdateConfig(c) +} diff --git a/internal/service/update_download.go b/internal/service/update_download.go new file mode 100644 index 0000000..7d0f223 --- /dev/null +++ b/internal/service/update_download.go @@ -0,0 +1,340 @@ +package service + +import ( + "crypto/md5" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "time" +) + +// ==================== 类型定义 ==================== + +// DownloadProgress 下载进度回调函数类型 +type DownloadProgress func(progress float64, speed float64, downloaded int64, total int64) + +// DownloadResult 下载结果 +type DownloadResult struct { + FilePath string `json:"file_path"` + FileSize int64 `json:"file_size"` + MD5Hash string `json:"md5_hash,omitempty"` + SHA256Hash string `json:"sha256_hash,omitempty"` +} + +// ==================== 下载更新 ==================== + +// DownloadUpdate 下载更新包 +func DownloadUpdate(downloadURL string, progressCallback DownloadProgress) (*DownloadResult, error) { + log.Printf("[下载] 开始下载,URL: %s", downloadURL) + + // 获取下载目录 + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("获取用户目录失败: %v", err) + } + + downloadDir := filepath.Join(homeDir, ".go-desk", "downloads") + if err := os.MkdirAll(downloadDir, 0755); err != nil { + return nil, fmt.Errorf("创建下载目录失败: %v", err) + } + + // 从 URL 提取文件名 + filename := filepath.Base(downloadURL) + if filename == "" || filename == "." { + filename = fmt.Sprintf("update-%d.exe", time.Now().Unix()) + } + + filePath := filepath.Join(downloadDir, filename) + + // 检查文件是否已存在 + var downloadedSize int64 = 0 + if fileInfo, err := os.Stat(filePath); err == nil { + downloadedSize = fileInfo.Size() + // 检查是否已完整下载 + if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && downloadedSize == remoteSize { + log.Printf("[下载] 文件已存在且完整: %s", filePath) + if progressCallback != nil { + progressCallback(100.0, 0, downloadedSize, remoteSize) + time.Sleep(50 * time.Millisecond) + } + md5Hash, sha256Hash, err := calculateFileHashes(filePath) + if err != nil { + return nil, fmt.Errorf("计算文件哈希失败: %v", err) + } + return &DownloadResult{ + FilePath: filePath, + FileSize: downloadedSize, + MD5Hash: md5Hash, + SHA256Hash: sha256Hash, + }, nil + } + } + + // 打开文件(支持断点续传) + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, fmt.Errorf("创建文件失败: %v", err) + } + defer file.Close() + + // 创建 HTTP 请求 + req, err := http.NewRequest("GET", downloadURL, nil) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %v", err) + } + + // 如果已下载部分,设置 Range 头(断点续传) + if downloadedSize > 0 { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloadedSize)) + log.Printf("[下载] 启用断点续传,已下载: %d 字节", downloadedSize) + } + + // 发送请求 + client := &http.Client{Timeout: 30 * time.Minute} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("下载请求失败: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { + file.Close() + if fileInfo, err := os.Stat(filePath); err == nil { + if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && fileInfo.Size() == remoteSize { + log.Printf("[下载] 文件已完整下载") + if progressCallback != nil { + progressCallback(100.0, 0, fileInfo.Size(), remoteSize) + } + md5Hash, sha256Hash, err := calculateFileHashes(filePath) + if err != nil { + return nil, fmt.Errorf("计算文件哈希失败: %v", err) + } + return &DownloadResult{ + FilePath: filePath, + FileSize: fileInfo.Size(), + MD5Hash: md5Hash, + SHA256Hash: sha256Hash, + }, nil + } + } + return nil, fmt.Errorf("服务器返回 416 错误,且文件可能不完整") + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode) + } + + // 获取文件总大小 + contentLength := resp.ContentLength + gotTotalFromRange := false + + // 优先从 Content-Range 头获取总大小 + if rangeHeader := resp.Header.Get("Content-Range"); rangeHeader != "" { + var start, end, total int64 + if n, _ := fmt.Sscanf(rangeHeader, "bytes %d-%d/%d", &start, &end, &total); n == 3 && total > 0 { + contentLength = total + gotTotalFromRange = true + } + } + + // 如果未获取到,尝试通过 HEAD 请求获取 + if contentLength <= 0 && !gotTotalFromRange { + if remoteSize, err := getRemoteFileSize(downloadURL); err == nil { + contentLength = remoteSize + } + } + + // 断点续传时,如果未从 Content-Range 获取,需要加上已下载部分 + if resp.StatusCode == http.StatusPartialContent && downloadedSize > 0 && contentLength > 0 && !gotTotalFromRange { + if contentLength < downloadedSize { + contentLength += downloadedSize + } + } + + log.Printf("[下载] 开始下载文件,总大小: %d 字节,已下载: %d 字节", contentLength, downloadedSize) + + // 发送初始进度事件 + if progressCallback != nil { + time.Sleep(50 * time.Millisecond) + if contentLength > 0 && downloadedSize > 0 { + progress := normalizeProgress(float64(downloadedSize) / float64(contentLength) * 100) + progressCallback(progress, 0, downloadedSize, contentLength) + } else { + progressCallback(0, 0, downloadedSize, -1) + } + time.Sleep(20 * time.Millisecond) + } + + // 下载文件 + buffer := make([]byte, 32*1024) // 32KB 缓冲区 + var totalDownloaded int64 = downloadedSize + startTime := time.Now() + lastProgressTime := startTime + lastProgressSize := totalDownloaded + + for { + n, err := resp.Body.Read(buffer) + if n > 0 { + written, writeErr := file.Write(buffer[:n]) + if writeErr != nil { + return nil, fmt.Errorf("写入文件失败: %v", writeErr) + } + totalDownloaded += int64(written) + + // 计算进度和速度 + now := time.Now() + elapsed := now.Sub(lastProgressTime).Seconds() + + // 每 0.3 秒更新一次进度 + if elapsed >= 0.3 { + progress := float64(0) + if contentLength > 0 { + progress = normalizeProgress(float64(totalDownloaded) / float64(contentLength) * 100) + } + + speed := float64(0) + if elapsed > 0 { + speed = float64(totalDownloaded-lastProgressSize) / elapsed + } + + if progressCallback != nil { + progressCallback(progress, speed, totalDownloaded, contentLength) + } + + lastProgressTime = now + lastProgressSize = totalDownloaded + } + } + + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("读取数据失败: %v", err) + } + } + + // 最后一次进度更新 + if progressCallback != nil { + if contentLength > 0 { + progressCallback(100.0, 0, totalDownloaded, contentLength) + } else { + progressCallback(100.0, 0, totalDownloaded, totalDownloaded) + } + } + + file.Close() + log.Printf("[下载] 下载完成,文件大小: %d 字节", totalDownloaded) + + md5Hash, sha256Hash, err := calculateFileHashes(filePath) + if err != nil { + return nil, fmt.Errorf("计算文件哈希失败: %v", err) + } + + log.Printf("[下载] 文件哈希计算完成,MD5: %s, SHA256: %s", md5Hash, sha256Hash) + + fileInfo, err := os.Stat(filePath) + if err != nil { + return nil, fmt.Errorf("获取文件信息失败: %v", err) + } + + return &DownloadResult{ + FilePath: filePath, + FileSize: fileInfo.Size(), + MD5Hash: md5Hash, + SHA256Hash: sha256Hash, + }, nil +} + +// ==================== 辅助函数 ==================== + +// getRemoteFileSize 通过HEAD请求获取远程文件大小 +func getRemoteFileSize(url string) (int64, error) { + client := &http.Client{Timeout: 10 * time.Second} + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return 0, err + } + resp, err := client.Do(req) + if err != nil { + return 0, err + } + defer resp.Body.Close() + if resp.ContentLength > 0 { + return resp.ContentLength, nil + } + return 0, fmt.Errorf("无法获取文件大小") +} + +// normalizeProgress 标准化进度值到0-100之间 +func normalizeProgress(progress float64) float64 { + if progress < 0 { + return 0 + } + if progress > 100 { + return 100 + } + return progress +} + +// calculateFileHashes 计算文件的 MD5 和 SHA256 哈希值 +func calculateFileHashes(filePath string) (string, string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", "", err + } + defer file.Close() + + md5Hash := md5.New() + sha256Hash := sha256.New() + + // 使用 MultiWriter 同时计算两个哈希 + writer := io.MultiWriter(md5Hash, sha256Hash) + + if _, err := io.Copy(writer, file); err != nil { + return "", "", err + } + + md5Sum := hex.EncodeToString(md5Hash.Sum(nil)) + sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil)) + + return md5Sum, sha256Sum, nil +} + +// VerifyFileHash 验证文件哈希值 +func VerifyFileHash(filePath string, expectedHash string, hashType string) (bool, error) { + file, err := os.Open(filePath) + if err != nil { + return false, err + } + defer file.Close() + + var hash []byte + var calculatedHash string + + switch hashType { + case "md5": + md5Hash := md5.New() + if _, err := io.Copy(md5Hash, file); err != nil { + return false, err + } + hash = md5Hash.Sum(nil) + calculatedHash = hex.EncodeToString(hash) + case "sha256": + sha256Hash := sha256.New() + if _, err := io.Copy(sha256Hash, file); err != nil { + return false, err + } + hash = sha256Hash.Sum(nil) + calculatedHash = hex.EncodeToString(hash) + default: + return false, fmt.Errorf("不支持的哈希类型: %s", hashType) + } + + return calculatedHash == expectedHash, nil +} diff --git a/internal/service/version.go b/internal/service/version.go new file mode 100644 index 0000000..715091a --- /dev/null +++ b/internal/service/version.go @@ -0,0 +1,161 @@ +package service + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strconv" + "strings" +) + +// ==================== 常量定义 ==================== + +// AppVersion 应用版本号(发布时直接修改此处) +const AppVersion = "0.1.0" + +// ==================== 类型定义 ==================== + +// Version 版本号结构 +type Version struct { + Major int + Minor int + Patch int +} + +// WailsConfig Wails 配置文件结构 +type WailsConfig struct { + Version string `json:"version"` +} + +// ==================== 版本号解析和比较 ==================== + +// ParseVersion 解析版本号字符串(支持 v1.0.0 或 1.0.0 格式) +func ParseVersion(versionStr string) (*Version, error) { + versionStr = strings.TrimPrefix(versionStr, "v") + parts := strings.Split(versionStr, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("版本号格式错误,应为 x.y.z 格式") + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return nil, fmt.Errorf("主版本号解析失败: %v", err) + } + + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, fmt.Errorf("次版本号解析失败: %v", err) + } + + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return nil, fmt.Errorf("修订号解析失败: %v", err) + } + + return &Version{Major: major, Minor: minor, Patch: patch}, nil +} + +// String 返回版本号字符串(格式:v1.0.0) +func (v *Version) String() string { + return fmt.Sprintf("v%d.%d.%d", v.Major, v.Minor, v.Patch) +} + +// Compare 比较版本号 +// 返回值:-1 表示当前版本小于目标版本,0 表示相等,1 表示大于 +func (v *Version) Compare(other *Version) int { + switch { + case v.Major != other.Major: + return compareInt(v.Major, other.Major) + case v.Minor != other.Minor: + return compareInt(v.Minor, other.Minor) + case v.Patch != other.Patch: + return compareInt(v.Patch, other.Patch) + default: + return 0 + } +} + +// compareInt 比较两个整数 +func compareInt(a, b int) int { + if a < b { + return -1 + } + if a > b { + return 1 + } + return 0 +} + +// IsNewerThan 判断是否比目标版本新 +func (v *Version) IsNewerThan(other *Version) bool { + return v.Compare(other) > 0 +} + +// IsOlderThan 判断是否比目标版本旧 +func (v *Version) IsOlderThan(other *Version) bool { + return v.Compare(other) < 0 +} + +// ==================== 版本号获取 ==================== + +// GetCurrentVersion 获取当前版本号 +// 优先级:硬编码版本号 > wails.json(开发模式)> 默认值 +func GetCurrentVersion() string { + if AppVersion != "" { + log.Printf("[版本] 使用硬编码版本号: %s", AppVersion) + return AppVersion + } + + version := getVersionFromWailsJSON() + if version != "" { + log.Printf("[版本] 从 wails.json 获取版本号: %s", version) + return version + } + + log.Printf("[版本] 使用默认版本号: 0.0.1") + return "0.0.1" +} + +// ==================== 配置文件读取 ==================== + +// getVersionFromWailsJSON 从 wails.json 读取版本号(仅开发模式使用) +func getVersionFromWailsJSON() string { + wd, err := os.Getwd() + if err != nil { + return "" + } + + // 尝试当前目录 + if version := readVersionFromFile(filepath.Join(wd, "wails.json")); version != "" { + return version + } + + // 尝试父目录 + if version := readVersionFromFile(filepath.Join(filepath.Dir(wd), "wails.json")); version != "" { + return version + } + + return "" +} + +// readVersionFromFile 从指定文件读取版本号 +func readVersionFromFile(filePath string) string { + data, err := os.ReadFile(filePath) + if err != nil { + log.Printf("[版本] 读取文件失败: %s, 错误: %v", filePath, err) + return "" + } + + var config WailsConfig + if err := json.Unmarshal(data, &config); err != nil { + log.Printf("[版本] 解析 JSON 失败: %s, 错误: %v", filePath, err) + return "" + } + + if config.Version != "" { + log.Printf("[版本] 从文件读取版本号: %s -> %s", filePath, config.Version) + } + return config.Version +} diff --git a/internal/storage/models/version.go b/internal/storage/models/version.go new file mode 100644 index 0000000..6ef0a94 --- /dev/null +++ b/internal/storage/models/version.go @@ -0,0 +1,20 @@ +package models + +import "time" + +// Version 版本信息 +type Version struct { + ID int `gorm:"primaryKey" json:"id"` // 主键ID + Version string `gorm:"type:varchar(20);not null;uniqueIndex" json:"version"` // 版本号(语义化版本,如1.0.0) + DownloadURL string `gorm:"type:varchar(500)" json:"download_url"` // 下载地址(更新包下载URL) + Changelog string `gorm:"type:text" json:"changelog"` // 更新日志(Markdown格式) + ForceUpdate int `gorm:"type:tinyint;not null;default:0" json:"force_update"` // 是否强制更新(1:是 0:否) + ReleaseDate *time.Time `gorm:"type:date" json:"release_date"` // 发布日期 + CreatedAt time.Time `gorm:"autoCreateTime:false" json:"created_at"` // 创建时间(由程序设置) + UpdatedAt time.Time `gorm:"autoUpdateTime:false" json:"updated_at"` // 更新时间(由程序设置) +} + +// TableName 指定表名 +func (Version) TableName() string { + return "sys_version" +} diff --git a/main.go b/main.go index 5ab21b9..f247da6 100644 --- a/main.go +++ b/main.go @@ -47,7 +47,7 @@ func main() { Assets: assets, }, BackgroundColour: &options.RGBA{R: 255, G: 255, B: 255, A: 1}, - OnStartup: app.startup, + OnStartup: app.Startup, Bind: []interface{}{ app, }, diff --git a/web/src/App.vue b/web/src/App.vue index 3e054ea..3b213d4 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -9,6 +9,13 @@
+ + + + +
@@ -83,6 +90,16 @@ + + + + + @@ -92,8 +109,10 @@ import {Message} from '@arco-design/web-vue' import DeviceTest from './components/DeviceTest.vue' import DbCli from './views/db-cli/index.vue' import ThemeToggle from './components/ThemeToggle.vue' +import UpdatePanel from './components/UpdatePanel.vue' const activeTab = ref('db-cli') +const showUpdateModal = ref(false) const loading = ref(false) const formModel = ref({ keyword: '', diff --git a/web/src/components/DeviceTest.vue b/web/src/components/DeviceTest.vue index 7e00ca2..16663c5 100644 --- a/web/src/components/DeviceTest.vue +++ b/web/src/components/DeviceTest.vue @@ -53,10 +53,12 @@ - 浏览 列出目录 @@ -124,8 +126,8 @@ diff --git a/web/src/components/UpdatePanel.vue b/web/src/components/UpdatePanel.vue new file mode 100644 index 0000000..026b909 --- /dev/null +++ b/web/src/components/UpdatePanel.vue @@ -0,0 +1,427 @@ + + + + +