.
This commit is contained in:
87
internal/api/auth_api.go
Normal file
87
internal/api/auth_api.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ssq-desk/internal/service"
|
||||
"ssq-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// AuthAPI 授权码 API
|
||||
type AuthAPI struct {
|
||||
authService *service.AuthService
|
||||
}
|
||||
|
||||
// NewAuthAPI 创建授权码 API
|
||||
func NewAuthAPI() (*AuthAPI, error) {
|
||||
repo, err := repository.NewSQLiteAuthRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authService := service.NewAuthService(repo)
|
||||
|
||||
return &AuthAPI{
|
||||
authService: authService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ActivateLicense 激活授权码
|
||||
func (api *AuthAPI) ActivateLicense(licenseCode string) (map[string]interface{}, error) {
|
||||
if api.authService == nil {
|
||||
newAPI, err := NewAuthAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.authService = newAPI.authService
|
||||
}
|
||||
|
||||
err := api.authService.ValidateLicense(licenseCode)
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
status, err := api.authService.CheckAuthStatus()
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "激活成功",
|
||||
"data": status,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAuthStatus 获取授权状态
|
||||
func (api *AuthAPI) GetAuthStatus() (map[string]interface{}, error) {
|
||||
if api.authService == nil {
|
||||
newAPI, err := NewAuthAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.authService = newAPI.authService
|
||||
}
|
||||
|
||||
status, err := api.authService.CheckAuthStatus()
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"data": status,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDeviceID 获取设备ID
|
||||
func (api *AuthAPI) GetDeviceID() (string, error) {
|
||||
return service.GetDeviceID()
|
||||
}
|
||||
67
internal/api/backup_api.go
Normal file
67
internal/api/backup_api.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ssq-desk/internal/service"
|
||||
)
|
||||
|
||||
// BackupAPI 数据备份 API
|
||||
type BackupAPI struct {
|
||||
backupService *service.BackupService
|
||||
}
|
||||
|
||||
// NewBackupAPI 创建数据备份 API
|
||||
func NewBackupAPI() *BackupAPI {
|
||||
return &BackupAPI{
|
||||
backupService: service.NewBackupService(),
|
||||
}
|
||||
}
|
||||
|
||||
// Backup 备份数据
|
||||
func (api *BackupAPI) Backup() (map[string]interface{}, error) {
|
||||
result, err := api.backupService.Backup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"backup_path": result.BackupPath,
|
||||
"file_name": result.FileName,
|
||||
"file_size": result.FileSize,
|
||||
"created_at": result.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Restore 恢复数据
|
||||
func (api *BackupAPI) Restore(backupPath string) (map[string]interface{}, error) {
|
||||
if err := api.backupService.Restore(backupPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "数据恢复成功",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListBackups 列出所有备份
|
||||
func (api *BackupAPI) ListBackups() (map[string]interface{}, error) {
|
||||
backups, err := api.backupService.ListBackups()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backupList := make([]map[string]interface{}, len(backups))
|
||||
for i, backup := range backups {
|
||||
backupList[i] = map[string]interface{}{
|
||||
"backup_path": backup.BackupPath,
|
||||
"file_name": backup.FileName,
|
||||
"file_size": backup.FileSize,
|
||||
"created_at": backup.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"backups": backupList,
|
||||
"count": len(backupList),
|
||||
}, nil
|
||||
}
|
||||
121
internal/api/package_api.go
Normal file
121
internal/api/package_api.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ssq-desk/internal/service"
|
||||
)
|
||||
|
||||
// PackageAPI 离线数据包 API
|
||||
type PackageAPI struct {
|
||||
packageService *service.PackageService
|
||||
}
|
||||
|
||||
// NewPackageAPI 创建数据包 API
|
||||
func NewPackageAPI() (*PackageAPI, error) {
|
||||
packageService, err := service.NewPackageService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PackageAPI{
|
||||
packageService: packageService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DownloadPackage 下载数据包
|
||||
func (api *PackageAPI) DownloadPackage(downloadURL string) (map[string]interface{}, error) {
|
||||
if api.packageService == nil {
|
||||
newAPI, err := NewPackageAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.packageService = newAPI.packageService
|
||||
}
|
||||
|
||||
result, err := api.packageService.DownloadPackage(downloadURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"file_path": result.FilePath,
|
||||
"file_size": result.FileSize,
|
||||
"duration": result.Duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImportPackage 导入数据包
|
||||
func (api *PackageAPI) ImportPackage(packagePath string) (map[string]interface{}, error) {
|
||||
if api.packageService == nil {
|
||||
newAPI, err := NewPackageAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.packageService = newAPI.packageService
|
||||
}
|
||||
|
||||
result, err := api.packageService.ImportPackage(packagePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"imported_count": result.ImportedCount,
|
||||
"updated_count": result.UpdatedCount,
|
||||
"error_count": result.ErrorCount,
|
||||
"duration": result.Duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckPackageUpdate 检查数据包更新
|
||||
func (api *PackageAPI) CheckPackageUpdate(remoteURL string) (map[string]interface{}, error) {
|
||||
if api.packageService == nil {
|
||||
newAPI, err := NewPackageAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.packageService = newAPI.packageService
|
||||
}
|
||||
|
||||
info, err := api.packageService.CheckPackageUpdate(remoteURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
return map[string]interface{}{
|
||||
"need_update": false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"need_update": true,
|
||||
"version": info.Version,
|
||||
"total_count": info.TotalCount,
|
||||
"latest_issue": info.LatestIssue,
|
||||
"package_size": info.PackageSize,
|
||||
"download_url": info.DownloadURL,
|
||||
"release_date": info.ReleaseDate,
|
||||
"checksum": info.CheckSum,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListLocalPackages 列出本地数据包
|
||||
func (api *PackageAPI) ListLocalPackages() (map[string]interface{}, error) {
|
||||
if api.packageService == nil {
|
||||
newAPI, err := NewPackageAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.packageService = newAPI.packageService
|
||||
}
|
||||
|
||||
packages, err := api.packageService.ListLocalPackages()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"packages": packages,
|
||||
"count": len(packages),
|
||||
}, nil
|
||||
}
|
||||
55
internal/api/ssq_api.go
Normal file
55
internal/api/ssq_api.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ssq-desk/internal/service"
|
||||
"ssq-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// SsqAPI 双色球查询 API
|
||||
type SsqAPI struct {
|
||||
queryService *service.QueryService
|
||||
}
|
||||
|
||||
// NewSsqAPI 创建双色球查询 API
|
||||
func NewSsqAPI() (*SsqAPI, error) {
|
||||
// 优先使用 MySQL(数据源)
|
||||
repo, err := repository.NewMySQLSsqRepository()
|
||||
if err != nil {
|
||||
// MySQL 连接失败,降级使用本地 SQLite
|
||||
repo, err = repository.NewSQLiteSsqRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
queryService := service.NewQueryService(repo)
|
||||
|
||||
return &SsqAPI{
|
||||
queryService: queryService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryRequest 查询请求参数
|
||||
type QueryRequest struct {
|
||||
RedBalls []int `json:"red_balls"` // 红球列表
|
||||
BlueBall int `json:"blue_ball"` // 蓝球(0表示不限制)
|
||||
BlueBallRange []int `json:"blue_ball_range"` // 蓝球筛选范围
|
||||
}
|
||||
|
||||
// QueryHistory 查询历史数据
|
||||
func (api *SsqAPI) QueryHistory(req QueryRequest) (*service.QueryResult, error) {
|
||||
if api.queryService == nil {
|
||||
// 重新初始化
|
||||
newAPI, err := NewSsqAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.queryService = newAPI.queryService
|
||||
}
|
||||
|
||||
return api.queryService.Query(service.QueryRequest{
|
||||
RedBalls: req.RedBalls,
|
||||
BlueBall: req.BlueBall,
|
||||
BlueBallRange: req.BlueBallRange,
|
||||
})
|
||||
}
|
||||
69
internal/api/sync_api.go
Normal file
69
internal/api/sync_api.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"ssq-desk/internal/service"
|
||||
"ssq-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// SyncAPI 数据同步 API
|
||||
type SyncAPI struct {
|
||||
syncService *service.SyncService
|
||||
}
|
||||
|
||||
// NewSyncAPI 创建数据同步 API
|
||||
func NewSyncAPI() (*SyncAPI, error) {
|
||||
// 获取 MySQL 和 SQLite Repository
|
||||
mysqlRepo, err := repository.NewMySQLSsqRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqliteRepo, err := repository.NewSQLiteSsqRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncService := service.NewSyncService(mysqlRepo, sqliteRepo)
|
||||
|
||||
return &SyncAPI{
|
||||
syncService: syncService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Sync 执行数据同步
|
||||
func (api *SyncAPI) Sync() (map[string]interface{}, error) {
|
||||
if api.syncService == nil {
|
||||
newAPI, err := NewSyncAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.syncService = newAPI.syncService
|
||||
}
|
||||
|
||||
result, err := api.syncService.Sync()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_count": result.TotalCount,
|
||||
"synced_count": result.SyncedCount,
|
||||
"new_count": result.NewCount,
|
||||
"updated_count": result.UpdatedCount,
|
||||
"error_count": result.ErrorCount,
|
||||
"latest_issue": result.LatestIssue,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSyncStatus 获取同步状态
|
||||
func (api *SyncAPI) GetSyncStatus() (map[string]interface{}, error) {
|
||||
if api.syncService == nil {
|
||||
newAPI, err := NewSyncAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
api.syncService = newAPI.syncService
|
||||
}
|
||||
|
||||
return api.syncService.GetSyncStatus()
|
||||
}
|
||||
254
internal/api/update_api.go
Normal file
254
internal/api/update_api.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"ssq-desk/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
|
||||
// UpdateAPI 版本更新 API
|
||||
type UpdateAPI struct {
|
||||
updateService *service.UpdateService
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewUpdateAPI 创建版本更新 API
|
||||
func NewUpdateAPI(checkURL string) (*UpdateAPI, error) {
|
||||
updateService := service.NewUpdateService(checkURL)
|
||||
|
||||
return &UpdateAPI{
|
||||
updateService: updateService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetContext 设置上下文(用于事件推送)
|
||||
func (api *UpdateAPI) SetContext(ctx context.Context) {
|
||||
api.ctx = ctx
|
||||
}
|
||||
|
||||
// CheckUpdate 检查更新
|
||||
func (api *UpdateAPI) CheckUpdate() (map[string]interface{}, error) {
|
||||
if api.updateService == nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": "更新服务未初始化",
|
||||
}, nil
|
||||
}
|
||||
|
||||
result, err := api.updateService.CheckUpdate()
|
||||
if err != nil {
|
||||
errorMsg := err.Error()
|
||||
if errorMsg == "" {
|
||||
errorMsg = "未知错误"
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": errorMsg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": "检查更新返回结果为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetCurrentVersion 获取当前版本号
|
||||
func (api *UpdateAPI) GetCurrentVersion() (map[string]interface{}, error) {
|
||||
version := service.GetCurrentVersion()
|
||||
|
||||
// 更新配置中的版本号
|
||||
if config, err := service.LoadUpdateConfig(); err == nil && config.CurrentVersion != version {
|
||||
config.CurrentVersion = version
|
||||
service.SaveUpdateConfig(config)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"data": map[string]interface{}{
|
||||
"version": version,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUpdateConfig 获取更新配置
|
||||
func (api *UpdateAPI) GetUpdateConfig() (map[string]interface{}, error) {
|
||||
config, err := service.LoadUpdateConfig()
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 确保版本号是最新的
|
||||
latestVersion := service.GetCurrentVersion()
|
||||
if config.CurrentVersion != latestVersion {
|
||||
config.CurrentVersion = latestVersion
|
||||
service.SaveUpdateConfig(config)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"data": map[string]interface{}{
|
||||
"current_version": config.CurrentVersion,
|
||||
"last_check_time": config.LastCheckTime.Format("2006-01-02 15:04:05"),
|
||||
"auto_check_enabled": config.AutoCheckEnabled,
|
||||
"check_interval_minutes": config.CheckIntervalMinutes,
|
||||
"check_url": config.CheckURL,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetUpdateConfig 设置更新配置
|
||||
func (api *UpdateAPI) SetUpdateConfig(autoCheckEnabled bool, checkIntervalMinutes int, checkURL string) (map[string]interface{}, error) {
|
||||
config, err := service.LoadUpdateConfig()
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
config.AutoCheckEnabled = autoCheckEnabled
|
||||
config.CheckIntervalMinutes = checkIntervalMinutes
|
||||
if checkURL != "" {
|
||||
config.CheckURL = checkURL
|
||||
// 如果 URL 改变,需要重新创建服务
|
||||
api.updateService = service.NewUpdateService(checkURL)
|
||||
}
|
||||
|
||||
if err := service.SaveUpdateConfig(config); err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "配置保存成功",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DownloadUpdate 下载更新包(异步,通过事件推送进度)
|
||||
func (api *UpdateAPI) DownloadUpdate(downloadURL string) (map[string]interface{}, error) {
|
||||
if downloadURL == "" {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": "下载地址不能为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
progressCallback := func(progress float64, speed float64, downloaded int64, total int64) {
|
||||
if api.ctx == nil {
|
||||
return
|
||||
}
|
||||
// 确保进度值在 0-100 之间
|
||||
if progress < 0 {
|
||||
progress = 0
|
||||
} else if progress > 100 {
|
||||
progress = 100
|
||||
}
|
||||
|
||||
progressInfo := map[string]interface{}{
|
||||
"progress": progress,
|
||||
"speed": speed,
|
||||
"downloaded": downloaded,
|
||||
"total": total,
|
||||
}
|
||||
progressJSON, _ := json.Marshal(progressInfo)
|
||||
runtime.EventsEmit(api.ctx, "download-progress", string(progressJSON))
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
result, err := service.DownloadUpdate(downloadURL, progressCallback)
|
||||
|
||||
if api.ctx != nil {
|
||||
if err != nil {
|
||||
errorInfo := map[string]interface{}{"error": err.Error()}
|
||||
errorJSON, _ := json.Marshal(errorInfo)
|
||||
runtime.EventsEmit(api.ctx, "download-complete", string(errorJSON))
|
||||
} else {
|
||||
resultInfo := map[string]interface{}{
|
||||
"success": true,
|
||||
"file_path": result.FilePath,
|
||||
"file_size": result.FileSize,
|
||||
}
|
||||
resultJSON, _ := json.Marshal(resultInfo)
|
||||
runtime.EventsEmit(api.ctx, "download-complete", string(resultJSON))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "下载已开始",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InstallUpdate 安装更新包
|
||||
func (api *UpdateAPI) InstallUpdate(installerPath string, autoRestart bool) (map[string]interface{}, error) {
|
||||
return api.InstallUpdateWithHash(installerPath, autoRestart, "", "")
|
||||
}
|
||||
|
||||
// InstallUpdateWithHash 安装更新包(带哈希验证)
|
||||
func (api *UpdateAPI) InstallUpdateWithHash(installerPath string, autoRestart bool, expectedHash string, hashType string) (map[string]interface{}, error) {
|
||||
if installerPath == "" {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": "安装文件路径不能为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
result, err := service.InstallUpdateWithHash(installerPath, autoRestart, expectedHash, hashType)
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": result.Success,
|
||||
"message": result.Message,
|
||||
"data": result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyUpdateFile 验证更新文件哈希值
|
||||
func (api *UpdateAPI) VerifyUpdateFile(filePath string, expectedHash string, hashType string) (map[string]interface{}, error) {
|
||||
if filePath == "" {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": "文件路径不能为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
valid, err := service.VerifyFileHash(filePath, expectedHash, hashType)
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"data": map[string]interface{}{
|
||||
"valid": valid,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
91
internal/database/mysql.go
Normal file
91
internal/database/mysql.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ssq-desk/internal/storage/models"
|
||||
"sync"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
mysqlDB *gorm.DB
|
||||
mysqlOnce sync.Once
|
||||
)
|
||||
|
||||
// MySQLConfig MySQL 连接配置
|
||||
type MySQLConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
}
|
||||
|
||||
// GetMySQLConfig 获取 MySQL 配置(从配置文件或环境变量)
|
||||
func GetMySQLConfig() *MySQLConfig {
|
||||
return &MySQLConfig{
|
||||
Host: "39.99.243.191",
|
||||
Port: 3306,
|
||||
User: "u_ssq",
|
||||
Password: "u_ssq@260106",
|
||||
Database: "ssq_dev", // 需要根据实际情况修改数据库名
|
||||
}
|
||||
}
|
||||
|
||||
// InitMySQL 初始化 MySQL 连接
|
||||
func InitMySQL() (*gorm.DB, error) {
|
||||
var err error
|
||||
mysqlOnce.Do(func() {
|
||||
config := GetMySQLConfig()
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
config.User, config.Password, config.Host, config.Port, config.Database)
|
||||
|
||||
mysqlDB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
sqlDB, err2 := mysqlDB.DB()
|
||||
if err2 != nil {
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
|
||||
if err2 = sqlDB.Ping(); err2 != nil {
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
|
||||
// 自动迁移表结构
|
||||
err2 = mysqlDB.AutoMigrate(
|
||||
&models.SsqHistory{},
|
||||
&models.Authorization{},
|
||||
&models.Version{},
|
||||
)
|
||||
if err2 != nil {
|
||||
err = fmt.Errorf("MySQL 表迁移失败: %v", err2)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("MySQL 连接初始化失败: %v", err)
|
||||
}
|
||||
|
||||
return mysqlDB, nil
|
||||
}
|
||||
|
||||
// GetMySQL 获取 MySQL 连接实例
|
||||
func GetMySQL() *gorm.DB {
|
||||
if mysqlDB == nil {
|
||||
db, err := InitMySQL()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return db
|
||||
}
|
||||
return mysqlDB
|
||||
}
|
||||
102
internal/database/sqlite.go
Normal file
102
internal/database/sqlite.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"ssq-desk/internal/storage/models"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
_ "modernc.org/sqlite" // 使用纯 Go 的 SQLite 驱动(不需要 CGO)
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteDB *gorm.DB
|
||||
sqliteOnce sync.Once
|
||||
)
|
||||
|
||||
// InitSQLite 初始化 SQLite 连接
|
||||
func InitSQLite() (*gorm.DB, error) {
|
||||
var err error
|
||||
sqliteOnce.Do(func() {
|
||||
// 获取应用数据目录
|
||||
homeDir, err2 := os.UserHomeDir()
|
||||
if err2 != nil {
|
||||
err = fmt.Errorf("获取用户目录失败: %v", err2)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建数据目录
|
||||
dataDir := filepath.Join(homeDir, ".ssq-desk")
|
||||
if err2 := os.MkdirAll(dataDir, 0755); err2 != nil {
|
||||
err = fmt.Errorf("创建数据目录失败: %v", err2)
|
||||
return
|
||||
}
|
||||
|
||||
// SQLite 数据库文件路径
|
||||
dbPath := filepath.Join(dataDir, "ssq.db")
|
||||
|
||||
// 直接使用 database/sql 打开连接,确保使用 modernc.org/sqlite(纯 Go,不需要 CGO)
|
||||
sqlDB, err2 := sql.Open("sqlite", dbPath)
|
||||
if err2 != nil {
|
||||
err = fmt.Errorf("SQLite 打开连接失败: %v", err2)
|
||||
sqliteDB = nil
|
||||
return
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err2 = sqlDB.Ping(); err2 != nil {
|
||||
err = fmt.Errorf("SQLite 连接测试失败: %v", err2)
|
||||
sqlDB.Close()
|
||||
sqliteDB = nil
|
||||
return
|
||||
}
|
||||
|
||||
// 使用已打开的 database/sql 连接创建 GORM 实例
|
||||
// 使用 sqlite.Dialector 并指定连接
|
||||
sqliteDB, err2 = gorm.Open(sqlite.Dialector{Conn: sqlDB}, &gorm.Config{})
|
||||
if err2 != nil {
|
||||
err = fmt.Errorf("SQLite GORM 初始化失败: %v", err2)
|
||||
sqlDB.Close()
|
||||
sqliteDB = nil
|
||||
return
|
||||
}
|
||||
|
||||
// 自动迁移表结构(如果表已存在但结构不对,AutoMigrate 会尝试修改)
|
||||
// 如果表结构完全不匹配,可能需要手动删除旧表
|
||||
err2 = sqliteDB.AutoMigrate(
|
||||
&models.SsqHistory{},
|
||||
&models.Authorization{},
|
||||
&models.Version{},
|
||||
)
|
||||
if err2 != nil {
|
||||
err = fmt.Errorf("SQLite 表迁移失败: %v", err2)
|
||||
sqliteDB = nil
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sqliteDB, nil
|
||||
}
|
||||
|
||||
// GetSQLite 获取 SQLite 连接实例
|
||||
// 如果连接未初始化或初始化失败,返回 nil
|
||||
func GetSQLite() *gorm.DB {
|
||||
if sqliteDB == nil {
|
||||
db, err := InitSQLite()
|
||||
if err != nil {
|
||||
// 初始化失败,返回 nil
|
||||
return nil
|
||||
}
|
||||
sqliteDB = db
|
||||
}
|
||||
return sqliteDB
|
||||
}
|
||||
62
internal/module/auth_module.go
Normal file
62
internal/module/auth_module.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package module
|
||||
|
||||
import (
|
||||
"context"
|
||||
"ssq-desk/internal/api"
|
||||
)
|
||||
|
||||
// AuthModule 授权码模块
|
||||
type AuthModule struct {
|
||||
BaseModule
|
||||
authAPI *api.AuthAPI
|
||||
}
|
||||
|
||||
// NewAuthModule 创建授权码模块
|
||||
func NewAuthModule() (*AuthModule, error) {
|
||||
// 延迟初始化,等到 Init() 方法调用时再创建 API(此时数据库已初始化)
|
||||
return &AuthModule{
|
||||
BaseModule: BaseModule{
|
||||
name: "auth",
|
||||
api: nil, // 延迟初始化
|
||||
},
|
||||
authAPI: nil, // 延迟初始化
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Init 初始化模块
|
||||
func (m *AuthModule) Init(ctx context.Context) error {
|
||||
if m.authAPI == nil {
|
||||
authAPI, err := api.NewAuthAPI()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.authAPI = authAPI
|
||||
m.api = authAPI
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动模块(检查授权状态)
|
||||
func (m *AuthModule) Start(ctx context.Context) error {
|
||||
if m.authAPI == nil {
|
||||
return m.Init(ctx)
|
||||
}
|
||||
|
||||
status, err := m.authAPI.GetAuthStatus()
|
||||
if err == nil {
|
||||
if data, ok := status["data"].(map[string]interface{}); ok && data != nil {
|
||||
if isActivated, ok := data["is_activated"].(bool); ok && !isActivated {
|
||||
println("授权未激活,部分功能可能受限")
|
||||
} else {
|
||||
println("授权验证通过")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthAPI 返回 Auth API(类型安全)
|
||||
func (m *AuthModule) AuthAPI() *api.AuthAPI {
|
||||
return m.authAPI
|
||||
}
|
||||
119
internal/module/manager.go
Normal file
119
internal/module/manager.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package module
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Manager 模块管理器
|
||||
type Manager struct {
|
||||
modules map[string]Module
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager 创建模块管理器
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
modules: make(map[string]Module),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册模块
|
||||
func (m *Manager) Register(module Module) error {
|
||||
if module == nil {
|
||||
return fmt.Errorf("模块不能为空")
|
||||
}
|
||||
|
||||
name := module.Name()
|
||||
if name == "" {
|
||||
return fmt.Errorf("模块名称不能为空")
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.modules[name]; exists {
|
||||
return fmt.Errorf("模块 %s 已存在", name)
|
||||
}
|
||||
|
||||
m.modules[name] = module
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取模块
|
||||
func (m *Manager) Get(name string) (Module, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
module, exists := m.modules[name]
|
||||
return module, exists
|
||||
}
|
||||
|
||||
// GetAll 获取所有模块
|
||||
func (m *Manager) GetAll() []Module {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
modules := make([]Module, 0, len(m.modules))
|
||||
for _, module := range m.modules {
|
||||
modules = append(modules, module)
|
||||
}
|
||||
return modules
|
||||
}
|
||||
|
||||
// InitAll 初始化所有模块
|
||||
func (m *Manager) InitAll(ctx context.Context) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for name, module := range m.modules {
|
||||
if err := module.Init(ctx); err != nil {
|
||||
return fmt.Errorf("初始化模块 %s 失败: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartAll 启动所有模块
|
||||
func (m *Manager) StartAll(ctx context.Context) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for name, module := range m.modules {
|
||||
if err := module.Start(ctx); err != nil {
|
||||
return fmt.Errorf("启动模块 %s 失败: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopAll 停止所有模块
|
||||
func (m *Manager) StopAll(ctx context.Context) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for name, module := range m.modules {
|
||||
if err := module.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("停止模块 %s 失败: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAPIs 获取所有模块的 API
|
||||
func (m *Manager) GetAPIs() []interface{} {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
apis := make([]interface{}, 0, len(m.modules))
|
||||
for _, module := range m.modules {
|
||||
if api := module.GetAPI(); api != nil {
|
||||
apis = append(apis, api)
|
||||
}
|
||||
}
|
||||
return apis
|
||||
}
|
||||
52
internal/module/module.go
Normal file
52
internal/module/module.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package module
|
||||
|
||||
import "context"
|
||||
|
||||
// Module 模块接口
|
||||
type Module interface {
|
||||
// Name 返回模块名称
|
||||
Name() string
|
||||
|
||||
// Init 初始化模块
|
||||
Init(ctx context.Context) error
|
||||
|
||||
// Start 启动模块
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop 停止模块
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// GetAPI 获取模块的 API 接口(供前端调用)
|
||||
GetAPI() interface{}
|
||||
}
|
||||
|
||||
// BaseModule 基础模块实现
|
||||
type BaseModule struct {
|
||||
name string
|
||||
api interface{}
|
||||
}
|
||||
|
||||
// Name 返回模块名称
|
||||
func (m *BaseModule) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
// GetAPI 获取模块的 API 接口
|
||||
func (m *BaseModule) GetAPI() interface{} {
|
||||
return m.api
|
||||
}
|
||||
|
||||
// Init 初始化模块(默认实现)
|
||||
func (m *BaseModule) Init(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动模块(默认实现)
|
||||
func (m *BaseModule) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止模块(默认实现)
|
||||
func (m *BaseModule) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
42
internal/module/ssq_module.go
Normal file
42
internal/module/ssq_module.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package module
|
||||
|
||||
import (
|
||||
"context"
|
||||
"ssq-desk/internal/api"
|
||||
)
|
||||
|
||||
// SsqModule 双色球查询模块
|
||||
type SsqModule struct {
|
||||
BaseModule
|
||||
ssqAPI *api.SsqAPI
|
||||
}
|
||||
|
||||
// NewSsqModule 创建双色球查询模块
|
||||
func NewSsqModule() (*SsqModule, error) {
|
||||
// 延迟初始化,等到 Init() 方法调用时再创建 API(此时数据库已初始化)
|
||||
return &SsqModule{
|
||||
BaseModule: BaseModule{
|
||||
name: "ssq",
|
||||
api: nil, // 延迟初始化
|
||||
},
|
||||
ssqAPI: nil, // 延迟初始化
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Init 初始化模块
|
||||
func (m *SsqModule) Init(ctx context.Context) error {
|
||||
if m.ssqAPI == nil {
|
||||
ssqAPI, err := api.NewSsqAPI()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ssqAPI = ssqAPI
|
||||
m.api = ssqAPI
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SsqAPI 返回 SSQ API(类型安全)
|
||||
func (m *SsqModule) SsqAPI() *api.SsqAPI {
|
||||
return m.ssqAPI
|
||||
}
|
||||
94
internal/module/update_module.go
Normal file
94
internal/module/update_module.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package module
|
||||
|
||||
import (
|
||||
"context"
|
||||
"ssq-desk/internal/api"
|
||||
"ssq-desk/internal/service"
|
||||
)
|
||||
|
||||
// UpdateModule 版本更新模块
|
||||
type UpdateModule struct {
|
||||
BaseModule
|
||||
updateAPI *api.UpdateAPI
|
||||
checkURL string // 版本检查接口 URL
|
||||
}
|
||||
|
||||
// NewUpdateModule 创建版本更新模块
|
||||
func NewUpdateModule() (*UpdateModule, error) {
|
||||
// 从配置文件读取检查 URL
|
||||
config, err := service.LoadUpdateConfig()
|
||||
if err != nil {
|
||||
// 配置加载失败,使用默认值
|
||||
config = &service.UpdateConfig{}
|
||||
}
|
||||
|
||||
checkURL := config.CheckURL
|
||||
if checkURL == "" {
|
||||
// 如果配置中没有,使用默认地址
|
||||
checkURL = "https://img.1216.top/ssq/last-version.json"
|
||||
}
|
||||
|
||||
updateAPI, err := api.NewUpdateAPI(checkURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UpdateModule{
|
||||
BaseModule: BaseModule{
|
||||
name: "update",
|
||||
api: updateAPI,
|
||||
},
|
||||
updateAPI: updateAPI,
|
||||
checkURL: checkURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Init 初始化模块
|
||||
func (m *UpdateModule) Init(ctx context.Context) error {
|
||||
if m.updateAPI == nil {
|
||||
updateAPI, err := api.NewUpdateAPI(m.checkURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.updateAPI = updateAPI
|
||||
m.api = updateAPI
|
||||
}
|
||||
// 设置 context 以便推送事件
|
||||
if m.updateAPI != nil {
|
||||
m.updateAPI.SetContext(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动模块(检查更新配置,决定是否自动检查)
|
||||
func (m *UpdateModule) Start(ctx context.Context) error {
|
||||
if m.updateAPI == nil {
|
||||
return m.Init(ctx)
|
||||
}
|
||||
|
||||
// 加载配置
|
||||
config, err := service.LoadUpdateConfig()
|
||||
if err != nil {
|
||||
// 配置加载失败不影响启动,只记录日志
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果启用了自动检查且满足检查条件,则检查更新
|
||||
if config.ShouldCheckUpdate() && config.CheckURL != "" {
|
||||
// 异步检查更新,不阻塞启动流程
|
||||
go func() {
|
||||
_, err := m.updateAPI.CheckUpdate()
|
||||
if err == nil {
|
||||
// 更新最后检查时间
|
||||
config.UpdateLastCheckTime()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAPI 返回 Update API(类型安全)
|
||||
func (m *UpdateModule) UpdateAPI() *api.UpdateAPI {
|
||||
return m.updateAPI
|
||||
}
|
||||
215
internal/service/auth_service.go
Normal file
215
internal/service/auth_service.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"ssq-desk/internal/database"
|
||||
"ssq-desk/internal/storage/models"
|
||||
"ssq-desk/internal/storage/repository"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AuthService 授权服务
|
||||
type AuthService struct {
|
||||
repo repository.AuthRepository
|
||||
}
|
||||
|
||||
// NewAuthService 创建授权服务
|
||||
func NewAuthService(repo repository.AuthRepository) *AuthService {
|
||||
return &AuthService{repo: repo}
|
||||
}
|
||||
|
||||
// GetDeviceID 获取设备ID(基于硬件信息生成)
|
||||
func GetDeviceID() (string, error) {
|
||||
var deviceInfo string
|
||||
|
||||
// 获取主机名
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "unknown"
|
||||
}
|
||||
|
||||
// 获取用户目录
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "unknown"
|
||||
}
|
||||
|
||||
// 组合设备信息
|
||||
deviceInfo = fmt.Sprintf("%s-%s-%s", hostname, homeDir, runtime.GOOS)
|
||||
|
||||
// 生成 MD5 作为设备ID
|
||||
hash := md5.Sum([]byte(deviceInfo))
|
||||
deviceID := hex.EncodeToString(hash[:])
|
||||
|
||||
return deviceID, nil
|
||||
}
|
||||
|
||||
// ValidateLicenseFormat 验证授权码格式
|
||||
func ValidateLicenseFormat(licenseCode string) error {
|
||||
if licenseCode == "" {
|
||||
return fmt.Errorf("授权码不能为空")
|
||||
}
|
||||
|
||||
// 去除空格和连字符
|
||||
cleaned := ""
|
||||
for _, c := range licenseCode {
|
||||
if c != ' ' && c != '-' {
|
||||
cleaned += string(c)
|
||||
}
|
||||
}
|
||||
|
||||
// 格式验证:至少16位,只包含字母和数字
|
||||
if len(cleaned) < 16 {
|
||||
return fmt.Errorf("授权码长度不足,至少需要16位字符")
|
||||
}
|
||||
|
||||
if len(cleaned) > 100 {
|
||||
return fmt.Errorf("授权码长度过长,最多100位字符")
|
||||
}
|
||||
|
||||
// 验证字符:只允许字母和数字
|
||||
for _, c := range cleaned {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
|
||||
return fmt.Errorf("授权码只能包含字母和数字")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateLicenseFromRemote 从远程数据库验证授权码
|
||||
func ValidateLicenseFromRemote(licenseCode string) error {
|
||||
// 获取 MySQL 连接
|
||||
mysqlDB := database.GetMySQL()
|
||||
if mysqlDB == nil {
|
||||
return fmt.Errorf("无法连接远程数据库,无法验证授权码")
|
||||
}
|
||||
|
||||
// 清理授权码(去除空格和连字符),与格式验证保持一致
|
||||
cleaned := ""
|
||||
for _, c := range licenseCode {
|
||||
if c != ' ' && c != '-' {
|
||||
cleaned += string(c)
|
||||
}
|
||||
}
|
||||
|
||||
// 查询授权码是否存在且有效(支持原始格式和清理后格式)
|
||||
var auth models.Authorization
|
||||
err := mysqlDB.Where("(license_code = ? OR license_code = ?) AND status = ?", licenseCode, cleaned, 1).First(&auth).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return fmt.Errorf("授权码无效或不存在")
|
||||
}
|
||||
return fmt.Errorf("验证授权码时发生错误: %v", err)
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if auth.ExpiresAt != nil && auth.ExpiresAt.Before(time.Now()) {
|
||||
return fmt.Errorf("授权码已过期")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateLicense 验证授权码
|
||||
func (s *AuthService) ValidateLicense(licenseCode string) error {
|
||||
// 格式验证
|
||||
if err := ValidateLicenseFormat(licenseCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 从远程数据库验证授权码有效性
|
||||
if err := ValidateLicenseFromRemote(licenseCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取设备ID
|
||||
deviceID, err := GetDeviceID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取设备ID失败: %v", err)
|
||||
}
|
||||
|
||||
// 保存授权信息到本地
|
||||
auth := &models.Authorization{
|
||||
LicenseCode: licenseCode,
|
||||
DeviceID: deviceID,
|
||||
ActivatedAt: time.Now(),
|
||||
Status: 1,
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
existing, err := s.repo.GetByLicenseCode(licenseCode)
|
||||
if err == nil && existing != nil {
|
||||
// 更新现有授权
|
||||
existing.DeviceID = deviceID
|
||||
existing.ActivatedAt = time.Now()
|
||||
existing.Status = 1
|
||||
return s.repo.Update(existing)
|
||||
}
|
||||
|
||||
// 创建新授权
|
||||
return s.repo.Create(auth)
|
||||
}
|
||||
|
||||
// CheckAuthStatus 检查授权状态
|
||||
func (s *AuthService) CheckAuthStatus() (*AuthStatus, error) {
|
||||
deviceID, err := GetDeviceID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取设备ID失败: %v", err)
|
||||
}
|
||||
|
||||
auth, err := s.repo.GetByDeviceID(deviceID)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return &AuthStatus{
|
||||
IsActivated: false,
|
||||
Message: "未激活",
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
if auth.Status != 1 {
|
||||
return &AuthStatus{
|
||||
IsActivated: false,
|
||||
Message: "授权已失效",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 检查过期时间
|
||||
if auth.ExpiresAt != nil && auth.ExpiresAt.Before(time.Now()) {
|
||||
return &AuthStatus{
|
||||
IsActivated: false,
|
||||
Message: "授权已过期",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &AuthStatus{
|
||||
IsActivated: true,
|
||||
LicenseCode: auth.LicenseCode,
|
||||
ActivatedAt: auth.ActivatedAt,
|
||||
ExpiresAt: auth.ExpiresAt,
|
||||
Message: "已激活",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthStatus 授权状态
|
||||
type AuthStatus struct {
|
||||
IsActivated bool `json:"is_activated"`
|
||||
LicenseCode string `json:"license_code,omitempty"`
|
||||
ActivatedAt time.Time `json:"activated_at,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// GetAuthInfo 获取授权信息
|
||||
func (s *AuthService) GetAuthInfo() (*AuthStatus, error) {
|
||||
return s.CheckAuthStatus()
|
||||
}
|
||||
287
internal/service/backup_service.go
Normal file
287
internal/service/backup_service.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"ssq-desk/internal/database"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BackupService 数据备份服务
|
||||
type BackupService struct{}
|
||||
|
||||
// NewBackupService 创建备份服务
|
||||
func NewBackupService() *BackupService {
|
||||
return &BackupService{}
|
||||
}
|
||||
|
||||
// BackupResult 备份结果
|
||||
type BackupResult struct {
|
||||
BackupPath string `json:"backup_path"`
|
||||
FileName string `json:"file_name"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// Backup 备份 SQLite 数据库
|
||||
func (s *BackupService) Backup() (*BackupResult, error) {
|
||||
// 获取 SQLite 数据库路径
|
||||
appDataDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(appDataDir, "ssq-desk", "data", "ssq.db")
|
||||
|
||||
// 检查数据库文件是否存在
|
||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("数据库文件不存在: %s", dbPath)
|
||||
}
|
||||
|
||||
// 创建备份目录
|
||||
backupDir := filepath.Join(appDataDir, "ssq-desk", "backups")
|
||||
if err := os.MkdirAll(backupDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建备份目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 生成备份文件名(带时间戳)
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
backupFileName := fmt.Sprintf("ssq-backup-%s.zip", timestamp)
|
||||
backupPath := filepath.Join(backupDir, backupFileName)
|
||||
|
||||
// 创建 ZIP 文件
|
||||
zipFile, err := os.Create(backupPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建备份文件失败: %v", err)
|
||||
}
|
||||
defer zipFile.Close()
|
||||
|
||||
zipWriter := zip.NewWriter(zipFile)
|
||||
defer zipWriter.Close()
|
||||
|
||||
// 添加数据库文件到 ZIP
|
||||
dbFile, err := os.Open(dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库文件失败: %v", err)
|
||||
}
|
||||
defer dbFile.Close()
|
||||
|
||||
dbInfo, err := dbFile.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库文件信息失败: %v", err)
|
||||
}
|
||||
|
||||
dbHeader, err := zip.FileInfoHeader(dbInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 ZIP 文件头失败: %v", err)
|
||||
}
|
||||
dbHeader.Name = "ssq.db"
|
||||
dbHeader.Method = zip.Deflate
|
||||
|
||||
dbWriter, err := zipWriter.CreateHeader(dbHeader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 ZIP 写入器失败: %v", err)
|
||||
}
|
||||
|
||||
if _, err := dbWriter.Write([]byte{}); err != nil {
|
||||
return nil, fmt.Errorf("写入 ZIP 文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 重新写入数据库内容
|
||||
if _, err := dbFile.Seek(0, 0); err != nil {
|
||||
return nil, fmt.Errorf("重置文件指针失败: %v", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, 1024*1024) // 1MB buffer
|
||||
for {
|
||||
n, err := dbFile.Read(buffer)
|
||||
if n > 0 {
|
||||
if _, err := dbWriter.Write(buffer[:n]); err != nil {
|
||||
return nil, fmt.Errorf("写入数据库内容失败: %v", err)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 添加元数据文件
|
||||
metaData := map[string]interface{}{
|
||||
"backup_time": time.Now().Format("2006-01-02 15:04:05"),
|
||||
"version": "1.0",
|
||||
}
|
||||
|
||||
metaWriter, err := zipWriter.Create("metadata.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建元数据文件失败: %v", err)
|
||||
}
|
||||
|
||||
metaJSON, err := json.Marshal(metaData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化元数据失败: %v", err)
|
||||
}
|
||||
|
||||
if _, err := metaWriter.Write(metaJSON); err != nil {
|
||||
return nil, fmt.Errorf("写入元数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取备份文件大小
|
||||
fileInfo, err := zipFile.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取备份文件信息失败: %v", err)
|
||||
}
|
||||
|
||||
return &BackupResult{
|
||||
BackupPath: backupPath,
|
||||
FileName: backupFileName,
|
||||
FileSize: fileInfo.Size(),
|
||||
CreatedAt: time.Now().Format("2006-01-02 15:04:05"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Restore 恢复数据
|
||||
func (s *BackupService) Restore(backupPath string) error {
|
||||
// 检查备份文件是否存在
|
||||
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("备份文件不存在: %s", backupPath)
|
||||
}
|
||||
|
||||
// 打开 ZIP 文件
|
||||
zipReader, err := zip.OpenReader(backupPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开备份文件失败: %v", err)
|
||||
}
|
||||
defer zipReader.Close()
|
||||
|
||||
// 获取 SQLite 数据库路径
|
||||
appDataDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取用户配置目录失败: %v", err)
|
||||
}
|
||||
|
||||
dataDir := filepath.Join(appDataDir, "ssq-desk", "data")
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建数据目录失败: %v", err)
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(dataDir, "ssq.db")
|
||||
|
||||
// 备份当前数据库(如果存在)
|
||||
if _, err := os.Stat(dbPath); err == nil {
|
||||
backupName := fmt.Sprintf("ssq.db.bak.%s", time.Now().Format("20060102-150405"))
|
||||
backupPath := filepath.Join(dataDir, backupName)
|
||||
if err := copyFile(dbPath, backupPath); err != nil {
|
||||
return fmt.Errorf("备份当前数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 查找数据库文件
|
||||
var dbFile *zip.File
|
||||
for _, file := range zipReader.File {
|
||||
if file.Name == "ssq.db" {
|
||||
dbFile = file
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if dbFile == nil {
|
||||
return fmt.Errorf("备份文件中未找到数据库文件")
|
||||
}
|
||||
|
||||
// 解压数据库文件
|
||||
rc, err := dbFile.Open()
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库文件失败: %v", err)
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
// 创建新的数据库文件
|
||||
newDBFile, err := os.Create(dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建数据库文件失败: %v", err)
|
||||
}
|
||||
defer newDBFile.Close()
|
||||
|
||||
// 复制数据
|
||||
buffer := make([]byte, 1024*1024) // 1MB buffer
|
||||
for {
|
||||
n, err := rc.Read(buffer)
|
||||
if n > 0 {
|
||||
if _, err := newDBFile.Write(buffer[:n]); err != nil {
|
||||
return fmt.Errorf("写入数据库文件失败: %v", err)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 重新初始化数据库连接
|
||||
_, err = database.InitSQLite()
|
||||
if err != nil {
|
||||
return fmt.Errorf("重新初始化数据库失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBackups 列出所有备份文件
|
||||
func (s *BackupService) ListBackups() ([]BackupResult, error) {
|
||||
appDataDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
|
||||
}
|
||||
|
||||
backupDir := filepath.Join(appDataDir, "ssq-desk", "backups")
|
||||
|
||||
// 检查备份目录是否存在
|
||||
if _, err := os.Stat(backupDir); os.IsNotExist(err) {
|
||||
return []BackupResult{}, nil
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(backupDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取备份目录失败: %v", err)
|
||||
}
|
||||
|
||||
var backups []BackupResult
|
||||
for _, file := range files {
|
||||
if filepath.Ext(file.Name()) == ".zip" {
|
||||
filePath := filepath.Join(backupDir, file.Name())
|
||||
fileInfo, err := file.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
backups = append(backups, BackupResult{
|
||||
BackupPath: filePath,
|
||||
FileName: file.Name(),
|
||||
FileSize: fileInfo.Size(),
|
||||
CreatedAt: fileInfo.ModTime().Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
// copyFile 复制文件
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = destFile.ReadFrom(sourceFile)
|
||||
return err
|
||||
}
|
||||
281
internal/service/package_service.go
Normal file
281
internal/service/package_service.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"ssq-desk/internal/storage/repository"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PackageService 离线数据包服务
|
||||
type PackageService struct {
|
||||
sqliteRepo repository.SsqRepository
|
||||
}
|
||||
|
||||
// NewPackageService 创建数据包服务
|
||||
func NewPackageService() (*PackageService, error) {
|
||||
repo, err := repository.NewSQLiteSsqRepository()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PackageService{
|
||||
sqliteRepo: repo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PackageInfo 数据包信息
|
||||
type PackageInfo struct {
|
||||
Version string `json:"version"` // 数据包版本
|
||||
TotalCount int `json:"total_count"` // 数据总数
|
||||
LatestIssue string `json:"latest_issue"` // 最新期号
|
||||
PackageSize int64 `json:"package_size"` // 包大小
|
||||
DownloadURL string `json:"download_url"` // 下载地址
|
||||
ReleaseDate string `json:"release_date"` // 发布日期
|
||||
CheckSum string `json:"checksum"` // 校验和
|
||||
}
|
||||
|
||||
// PackageDownloadResult 数据包下载结果
|
||||
type PackageDownloadResult struct {
|
||||
FilePath string `json:"file_path"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
Duration string `json:"duration"` // 下载耗时
|
||||
}
|
||||
|
||||
// ImportResult 导入结果
|
||||
type ImportResult struct {
|
||||
ImportedCount int `json:"imported_count"` // 导入数量
|
||||
UpdatedCount int `json:"updated_count"` // 更新数量
|
||||
ErrorCount int `json:"error_count"` // 错误数量
|
||||
Duration string `json:"duration"` // 导入耗时
|
||||
}
|
||||
|
||||
// DownloadPackage 下载数据包
|
||||
func (s *PackageService) DownloadPackage(downloadURL string, progressCallback func(int64, int64)) (*PackageDownloadResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 创建下载目录
|
||||
appDataDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
|
||||
}
|
||||
|
||||
downloadDir := filepath.Join(appDataDir, "ssq-desk", "packages")
|
||||
if err := os.MkdirAll(downloadDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建下载目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 生成文件名
|
||||
filename := filepath.Base(downloadURL)
|
||||
if filename == "" || filename == "." {
|
||||
filename = fmt.Sprintf("ssq-data-%s.zip", time.Now().Format("20060102-150405"))
|
||||
}
|
||||
filePath := filepath.Join(downloadDir, filename)
|
||||
|
||||
// 创建文件
|
||||
out, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建文件失败: %v", err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// 发起 HTTP 请求
|
||||
resp, err := http.Get(downloadURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("下载失败: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 获取文件大小
|
||||
totalSize := resp.ContentLength
|
||||
|
||||
// 复制数据并显示进度
|
||||
var written int64
|
||||
buffer := make([]byte, 32*1024) // 32KB buffer
|
||||
|
||||
for {
|
||||
nr, er := resp.Body.Read(buffer)
|
||||
if nr > 0 {
|
||||
nw, ew := out.Write(buffer[0:nr])
|
||||
if nw < 0 || nr < nw {
|
||||
nw = 0
|
||||
if ew == nil {
|
||||
ew = fmt.Errorf("无效写入结果")
|
||||
}
|
||||
}
|
||||
written += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
if nr != nw {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
|
||||
// 调用进度回调
|
||||
if progressCallback != nil {
|
||||
progressCallback(written, totalSize)
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
os.Remove(filePath)
|
||||
return nil, fmt.Errorf("写入文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取文件信息
|
||||
fileInfo, err := out.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文件信息失败: %v", err)
|
||||
}
|
||||
|
||||
duration := time.Since(startTime).String()
|
||||
|
||||
return &PackageDownloadResult{
|
||||
FilePath: filePath,
|
||||
FileSize: fileInfo.Size(),
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImportPackage 导入数据包
|
||||
func (s *PackageService) ImportPackage(packagePath string) (*ImportResult, error) {
|
||||
startTime := time.Now()
|
||||
result := &ImportResult{}
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(packagePath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("数据包文件不存在: %s", packagePath)
|
||||
}
|
||||
|
||||
// 打开 ZIP 文件
|
||||
zipReader, err := zip.OpenReader(packagePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据包失败: %v", err)
|
||||
}
|
||||
defer zipReader.Close()
|
||||
|
||||
// 查找数据文件(JSON 格式)
|
||||
var dataFile *zip.File
|
||||
for _, file := range zipReader.File {
|
||||
if filepath.Ext(file.Name) == ".json" {
|
||||
dataFile = file
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if dataFile == nil {
|
||||
return nil, fmt.Errorf("数据包中未找到 JSON 数据文件")
|
||||
}
|
||||
|
||||
// 读取数据文件
|
||||
rc, err := dataFile.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据文件失败: %v", err)
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
// 解析 JSON
|
||||
var histories []map[string]interface{}
|
||||
decoder := json.NewDecoder(rc)
|
||||
if err := decoder.Decode(&histories); err != nil {
|
||||
return nil, fmt.Errorf("解析数据文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 导入数据
|
||||
// TODO: 实际导入逻辑需要根据数据包格式实现
|
||||
// 这里只是框架代码,需要将 JSON 数据转换为 SsqHistory 并插入数据库
|
||||
for range histories {
|
||||
// 转换为 SsqHistory 结构(这里需要根据实际数据格式转换)
|
||||
// 简化处理:直接创建记录
|
||||
// 实际应用中需要根据数据包格式解析
|
||||
result.ImportedCount++
|
||||
}
|
||||
|
||||
// TODO: 实际导入逻辑需要根据数据包格式实现
|
||||
// 这里只是框架代码
|
||||
|
||||
result.Duration = time.Since(startTime).String()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CheckPackageUpdate 检查数据包更新
|
||||
func (s *PackageService) CheckPackageUpdate(remoteURL string) (*PackageInfo, error) {
|
||||
// 发起 HTTP 请求获取数据包信息
|
||||
resp, err := http.Get(remoteURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据包信息失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("获取数据包信息失败: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析 JSON
|
||||
var info PackageInfo
|
||||
decoder := json.NewDecoder(resp.Body)
|
||||
if err := decoder.Decode(&info); err != nil {
|
||||
return nil, fmt.Errorf("解析数据包信息失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取本地最新期号
|
||||
localLatestIssue, err := s.sqliteRepo.GetLatestIssue()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取本地最新期号失败: %v", err)
|
||||
}
|
||||
|
||||
// 比较是否需要更新
|
||||
// 如果远程最新期号大于本地,则需要更新
|
||||
if info.LatestIssue > localLatestIssue {
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
return nil, nil // 不需要更新
|
||||
}
|
||||
|
||||
// ListLocalPackages 列出本地数据包
|
||||
func (s *PackageService) ListLocalPackages() ([]string, error) {
|
||||
appDataDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
|
||||
}
|
||||
|
||||
packageDir := filepath.Join(appDataDir, "ssq-desk", "packages")
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(packageDir); os.IsNotExist(err) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(packageDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取数据包目录失败: %v", err)
|
||||
}
|
||||
|
||||
var packages []string
|
||||
for _, file := range files {
|
||||
if filepath.Ext(file.Name()) == ".zip" {
|
||||
packages = append(packages, filepath.Join(packageDir, file.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
return packages, nil
|
||||
}
|
||||
162
internal/service/query_service.go
Normal file
162
internal/service/query_service.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ssq-desk/internal/storage/models"
|
||||
"ssq-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// QueryService 查询服务
|
||||
type QueryService struct {
|
||||
repo repository.SsqRepository
|
||||
}
|
||||
|
||||
// NewQueryService 创建查询服务
|
||||
func NewQueryService(repo repository.SsqRepository) *QueryService {
|
||||
return &QueryService{repo: repo}
|
||||
}
|
||||
|
||||
// QueryRequest 查询请求
|
||||
type QueryRequest struct {
|
||||
RedBalls []int `json:"red_balls"` // 红球列表(最多6个)
|
||||
BlueBall int `json:"blue_ball"` // 蓝球(0表示不限制)
|
||||
BlueBallRange []int `json:"blue_ball_range"` // 蓝球筛选范围
|
||||
}
|
||||
|
||||
// QueryResult 查询结果
|
||||
type QueryResult struct {
|
||||
Total int64 `json:"total"` // 总记录数
|
||||
Summary []MatchSummary `json:"summary"` // 分类统计
|
||||
Details []models.SsqHistory `json:"details"` // 详细记录
|
||||
}
|
||||
|
||||
// MatchSummary 匹配统计
|
||||
type MatchSummary struct {
|
||||
Type string `json:"type"` // 匹配类型:如 "6红1蓝"
|
||||
Count int `json:"count"` // 匹配数量
|
||||
Histories []models.SsqHistory `json:"histories"` // 匹配的记录
|
||||
}
|
||||
|
||||
// Query 执行查询
|
||||
func (s *QueryService) Query(req QueryRequest) (*QueryResult, error) {
|
||||
// 验证输入
|
||||
if err := s.validateRequest(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 查询数据
|
||||
histories, err := s.repo.FindByRedAndBlue(req.RedBalls, req.BlueBall, req.BlueBallRange)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询失败: %v", err)
|
||||
}
|
||||
|
||||
// 处理结果:分类统计
|
||||
summary := s.calculateSummary(histories, req.RedBalls, req.BlueBall)
|
||||
|
||||
return &QueryResult{
|
||||
Total: int64(len(histories)),
|
||||
Summary: summary,
|
||||
Details: histories,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validateRequest 验证请求参数
|
||||
func (s *QueryService) validateRequest(req QueryRequest) error {
|
||||
// 验证红球
|
||||
if len(req.RedBalls) > 6 {
|
||||
return fmt.Errorf("红球数量不能超过6个")
|
||||
}
|
||||
for _, ball := range req.RedBalls {
|
||||
if ball < 1 || ball > 33 {
|
||||
return fmt.Errorf("红球号码必须在1-33之间: %d", ball)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证蓝球
|
||||
if req.BlueBall > 0 {
|
||||
if req.BlueBall < 1 || req.BlueBall > 16 {
|
||||
return fmt.Errorf("蓝球号码必须在1-16之间: %d", req.BlueBall)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证蓝球范围
|
||||
for _, ball := range req.BlueBallRange {
|
||||
if ball < 1 || ball > 16 {
|
||||
return fmt.Errorf("蓝球筛选范围必须在1-16之间: %d", ball)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateSummary 计算分类统计
|
||||
func (s *QueryService) calculateSummary(histories []models.SsqHistory, redBalls []int, blueBall int) []MatchSummary {
|
||||
if len(redBalls) == 0 {
|
||||
return []MatchSummary{}
|
||||
}
|
||||
|
||||
// 创建红球集合,用于快速查找
|
||||
redBallSet := make(map[int]bool)
|
||||
for _, ball := range redBalls {
|
||||
redBallSet[ball] = true
|
||||
}
|
||||
|
||||
// 初始化统计
|
||||
typeCounts := make(map[string][]models.SsqHistory)
|
||||
|
||||
// 遍历每条记录,计算匹配度
|
||||
for _, history := range histories {
|
||||
// 统计匹配的红球数量
|
||||
matchedRedCount := 0
|
||||
if redBallSet[history.RedBall1] {
|
||||
matchedRedCount++
|
||||
}
|
||||
if redBallSet[history.RedBall2] {
|
||||
matchedRedCount++
|
||||
}
|
||||
if redBallSet[history.RedBall3] {
|
||||
matchedRedCount++
|
||||
}
|
||||
if redBallSet[history.RedBall4] {
|
||||
matchedRedCount++
|
||||
}
|
||||
if redBallSet[history.RedBall5] {
|
||||
matchedRedCount++
|
||||
}
|
||||
if redBallSet[history.RedBall6] {
|
||||
matchedRedCount++
|
||||
}
|
||||
|
||||
// 判断蓝球是否匹配
|
||||
blueMatched := false
|
||||
if blueBall > 0 {
|
||||
blueMatched = history.BlueBall == blueBall
|
||||
} else {
|
||||
blueMatched = true // 未指定蓝球时,视为匹配
|
||||
}
|
||||
|
||||
// 生成类型键
|
||||
typeKey := fmt.Sprintf("%d红", matchedRedCount)
|
||||
if blueMatched {
|
||||
typeKey += "1蓝"
|
||||
}
|
||||
|
||||
typeCounts[typeKey] = append(typeCounts[typeKey], history)
|
||||
}
|
||||
|
||||
// 转换为结果格式,按匹配度排序
|
||||
summary := make([]MatchSummary, 0)
|
||||
types := []string{"6红1蓝", "6红", "5红1蓝", "5红", "4红1蓝", "4红", "3红1蓝", "3红", "2红1蓝", "2红", "1红1蓝", "1红", "0红1蓝", "0红"}
|
||||
|
||||
for _, t := range types {
|
||||
if histories, ok := typeCounts[t]; ok {
|
||||
summary = append(summary, MatchSummary{
|
||||
Type: t,
|
||||
Count: len(histories),
|
||||
Histories: histories,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
148
internal/service/sync_service.go
Normal file
148
internal/service/sync_service.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ssq-desk/internal/storage/models"
|
||||
"ssq-desk/internal/storage/repository"
|
||||
)
|
||||
|
||||
// SyncService 数据同步服务
|
||||
type SyncService struct {
|
||||
mysqlRepo repository.SsqRepository
|
||||
sqliteRepo repository.SsqRepository
|
||||
}
|
||||
|
||||
// NewSyncService 创建同步服务
|
||||
func NewSyncService(mysqlRepo, sqliteRepo repository.SsqRepository) *SyncService {
|
||||
return &SyncService{
|
||||
mysqlRepo: mysqlRepo,
|
||||
sqliteRepo: sqliteRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// SyncResult 同步结果
|
||||
type SyncResult struct {
|
||||
TotalCount int `json:"total_count"` // 远程数据总数
|
||||
SyncedCount int `json:"synced_count"` // 已同步数量
|
||||
NewCount int `json:"new_count"` // 新增数量
|
||||
UpdatedCount int `json:"updated_count"` // 更新数量
|
||||
ErrorCount int `json:"error_count"` // 错误数量
|
||||
LatestIssue string `json:"latest_issue"` // 最新期号
|
||||
}
|
||||
|
||||
// Sync 执行数据同步(增量同步)
|
||||
func (s *SyncService) Sync() (*SyncResult, error) {
|
||||
result := &SyncResult{}
|
||||
|
||||
// 获取本地最新期号
|
||||
localLatestIssue, err := s.sqliteRepo.GetLatestIssue()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取本地最新期号失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取远程所有数据
|
||||
remoteHistories, err := s.mysqlRepo.FindAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取远程数据失败: %v", err)
|
||||
}
|
||||
|
||||
result.TotalCount = len(remoteHistories)
|
||||
|
||||
if len(remoteHistories) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 确定最新期号
|
||||
if len(remoteHistories) > 0 {
|
||||
result.LatestIssue = remoteHistories[0].IssueNumber
|
||||
}
|
||||
|
||||
// 增量同步:只同步本地没有的数据
|
||||
if localLatestIssue == "" {
|
||||
// 首次同步,全量同步
|
||||
for _, history := range remoteHistories {
|
||||
if err := s.sqliteRepo.Create(&history); err != nil {
|
||||
result.ErrorCount++
|
||||
continue
|
||||
}
|
||||
result.NewCount++
|
||||
result.SyncedCount++
|
||||
}
|
||||
} else {
|
||||
// 增量同步:同步期号大于本地最新期号的数据
|
||||
for _, history := range remoteHistories {
|
||||
// 如果期号小于等于本地最新期号,跳过
|
||||
if history.IssueNumber <= localLatestIssue {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查本地是否已存在(基于期号)
|
||||
localHistory, err := s.sqliteRepo.FindByIssue(history.IssueNumber)
|
||||
if err == nil && localHistory != nil {
|
||||
// 已存在,检查是否需要更新
|
||||
if s.needUpdate(localHistory, &history) {
|
||||
// 更新逻辑(目前使用创建,如需更新可扩展)
|
||||
result.UpdatedCount++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 新数据,插入
|
||||
if err := s.sqliteRepo.Create(&history); err != nil {
|
||||
result.ErrorCount++
|
||||
continue
|
||||
}
|
||||
result.NewCount++
|
||||
result.SyncedCount++
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// needUpdate 判断是否需要更新
|
||||
func (s *SyncService) needUpdate(local, remote *models.SsqHistory) bool {
|
||||
// 简单比较:如果任何字段不同,则认为需要更新
|
||||
return local.RedBall1 != remote.RedBall1 ||
|
||||
local.RedBall2 != remote.RedBall2 ||
|
||||
local.RedBall3 != remote.RedBall3 ||
|
||||
local.RedBall4 != remote.RedBall4 ||
|
||||
local.RedBall5 != remote.RedBall5 ||
|
||||
local.RedBall6 != remote.RedBall6 ||
|
||||
local.BlueBall != remote.BlueBall
|
||||
}
|
||||
|
||||
// GetSyncStatus 获取同步状态
|
||||
func (s *SyncService) GetSyncStatus() (map[string]interface{}, error) {
|
||||
// 获取本地统计
|
||||
localCount, err := s.sqliteRepo.Count()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localLatestIssue, err := s.sqliteRepo.GetLatestIssue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取远程统计
|
||||
remoteCount, err := s.mysqlRepo.Count()
|
||||
if err != nil {
|
||||
// 远程连接失败不影响本地状态
|
||||
remoteCount = 0
|
||||
}
|
||||
|
||||
remoteLatestIssue := ""
|
||||
remoteHistories, err := s.mysqlRepo.FindAll()
|
||||
if err == nil && len(remoteHistories) > 0 {
|
||||
remoteLatestIssue = remoteHistories[0].IssueNumber
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"local_count": localCount,
|
||||
"local_latest_issue": localLatestIssue,
|
||||
"remote_count": remoteCount,
|
||||
"remote_latest_issue": remoteLatestIssue,
|
||||
"need_sync": remoteLatestIssue > localLatestIssue,
|
||||
}, nil
|
||||
}
|
||||
138
internal/service/update_config.go
Normal file
138
internal/service/update_config.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpdateConfig 更新配置
|
||||
type UpdateConfig struct {
|
||||
CurrentVersion string `json:"current_version"`
|
||||
LastCheckTime time.Time `json:"last_check_time"`
|
||||
AutoCheckEnabled bool `json:"auto_check_enabled"`
|
||||
CheckIntervalMinutes int `json:"check_interval_minutes"` // 检查间隔(分钟)
|
||||
CheckURL string `json:"check_url,omitempty"` // 版本检查接口 URL
|
||||
}
|
||||
|
||||
// GetUpdateConfigPath 获取更新配置文件路径
|
||||
func GetUpdateConfigPath() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取用户目录失败: %v", err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".ssq-desk")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("创建配置目录失败: %v", err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, "update_config.json"), nil
|
||||
}
|
||||
|
||||
// LoadUpdateConfig 加载更新配置
|
||||
func LoadUpdateConfig() (*UpdateConfig, error) {
|
||||
configPath, err := GetUpdateConfigPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果文件不存在,返回默认配置
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return &UpdateConfig{
|
||||
CurrentVersion: GetCurrentVersion(),
|
||||
LastCheckTime: time.Time{},
|
||||
AutoCheckEnabled: true,
|
||||
CheckIntervalMinutes: 1, // 默认1分钟检查一次
|
||||
CheckURL: "https://img.1216.top/ssq/last-version.json", // 默认版本检查地址
|
||||
}, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 先解析为 map 以支持兼容旧配置
|
||||
var configMap map[string]interface{}
|
||||
if err := json.Unmarshal(data, &configMap); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %v", err)
|
||||
}
|
||||
|
||||
var config UpdateConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 兼容旧配置:如果存在 check_interval_days,转换为分钟
|
||||
if config.CheckIntervalMinutes == 0 {
|
||||
if days, ok := configMap["check_interval_days"].(float64); ok && days > 0 {
|
||||
config.CheckIntervalMinutes = int(days * 24 * 60) // 转换为分钟
|
||||
} else {
|
||||
config.CheckIntervalMinutes = 1 // 默认1分钟
|
||||
}
|
||||
}
|
||||
|
||||
// 获取最新版本号
|
||||
latestVersion := GetCurrentVersion()
|
||||
|
||||
// 如果当前版本为空或与最新版本不一致,更新为最新版本
|
||||
if config.CurrentVersion == "" || config.CurrentVersion != latestVersion {
|
||||
if config.CurrentVersion != "" {
|
||||
log.Printf("[配置] 配置中的版本号 (%s) 与最新版本号 (%s) 不一致,更新配置", config.CurrentVersion, latestVersion)
|
||||
}
|
||||
config.CurrentVersion = latestVersion
|
||||
// 注意:这里不自动保存,避免频繁写入文件,由调用方决定是否保存
|
||||
}
|
||||
|
||||
// 如果检查地址为空,使用默认地址
|
||||
if config.CheckURL == "" {
|
||||
config.CheckURL = "https://img.1216.top/ssq/last-version.json"
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// SaveUpdateConfig 保存更新配置
|
||||
func SaveUpdateConfig(config *UpdateConfig) error {
|
||||
configPath, err := GetUpdateConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化配置失败: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("写入配置文件失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShouldCheckUpdate 判断是否应该检查更新
|
||||
func (c *UpdateConfig) ShouldCheckUpdate() bool {
|
||||
if !c.AutoCheckEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果从未检查过,应该检查
|
||||
if c.LastCheckTime.IsZero() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否超过间隔分钟数
|
||||
minutesSinceLastCheck := time.Since(c.LastCheckTime).Minutes()
|
||||
return minutesSinceLastCheck >= float64(c.CheckIntervalMinutes)
|
||||
}
|
||||
|
||||
// UpdateLastCheckTime 更新最后检查时间
|
||||
func (c *UpdateConfig) UpdateLastCheckTime() error {
|
||||
c.LastCheckTime = time.Now()
|
||||
return SaveUpdateConfig(c)
|
||||
}
|
||||
332
internal/service/update_download.go
Normal file
332
internal/service/update_download.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// getRemoteFileSize 通过HEAD请求获取远程文件大小
|
||||
func getRemoteFileSize(url string) (int64, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
req, err := http.NewRequest("HEAD", url, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.ContentLength > 0 {
|
||||
return resp.ContentLength, nil
|
||||
}
|
||||
return 0, fmt.Errorf("无法获取文件大小")
|
||||
}
|
||||
|
||||
// normalizeProgress 标准化进度值到0-100之间
|
||||
func normalizeProgress(progress float64) float64 {
|
||||
if progress < 0 {
|
||||
return 0
|
||||
}
|
||||
if progress > 100 {
|
||||
return 100
|
||||
}
|
||||
return progress
|
||||
}
|
||||
|
||||
// DownloadProgress 下载进度回调函数类型
|
||||
type DownloadProgress func(progress float64, speed float64, downloaded int64, total int64)
|
||||
|
||||
// DownloadProgressInfo 下载进度信息
|
||||
type DownloadProgressInfo struct {
|
||||
Progress float64 `json:"progress"` // 进度百分比
|
||||
Speed float64 `json:"speed"` // 下载速度(字节/秒)
|
||||
Downloaded int64 `json:"downloaded"` // 已下载字节数
|
||||
Total int64 `json:"total"` // 总字节数
|
||||
Error string `json:"error,omitempty"` // 错误信息
|
||||
Result *DownloadResult `json:"result,omitempty"` // 下载结果
|
||||
}
|
||||
|
||||
// DownloadResult 下载结果
|
||||
type DownloadResult struct {
|
||||
FilePath string `json:"file_path"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
MD5Hash string `json:"md5_hash,omitempty"`
|
||||
SHA256Hash string `json:"sha256_hash,omitempty"`
|
||||
}
|
||||
|
||||
// DownloadUpdate 下载更新包
|
||||
func DownloadUpdate(downloadURL string, progressCallback DownloadProgress) (*DownloadResult, error) {
|
||||
// 获取下载目录
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取用户目录失败: %v", err)
|
||||
}
|
||||
|
||||
downloadDir := filepath.Join(homeDir, ".ssq-desk", "downloads")
|
||||
if err := os.MkdirAll(downloadDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建下载目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 从 URL 提取文件名
|
||||
filename := filepath.Base(downloadURL)
|
||||
if filename == "" || filename == "." {
|
||||
filename = fmt.Sprintf("update-%d.exe", time.Now().Unix())
|
||||
}
|
||||
|
||||
filePath := filepath.Join(downloadDir, filename)
|
||||
|
||||
// 检查文件是否已存在
|
||||
var downloadedSize int64 = 0
|
||||
if fileInfo, err := os.Stat(filePath); err == nil {
|
||||
downloadedSize = fileInfo.Size()
|
||||
// 检查是否已完整下载
|
||||
if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && downloadedSize == remoteSize {
|
||||
if progressCallback != nil {
|
||||
progressCallback(100.0, 0, downloadedSize, remoteSize)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
md5Hash, sha256Hash, err := calculateFileHashes(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("计算文件哈希失败: %v", err)
|
||||
}
|
||||
return &DownloadResult{
|
||||
FilePath: filePath,
|
||||
FileSize: downloadedSize,
|
||||
MD5Hash: md5Hash,
|
||||
SHA256Hash: sha256Hash,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 打开文件(支持断点续传)
|
||||
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建文件失败: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 创建 HTTP 请求
|
||||
req, err := http.NewRequest("GET", downloadURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果已下载部分,设置 Range 头(断点续传)
|
||||
if downloadedSize > 0 {
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloadedSize))
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{Timeout: 30 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
||||
file.Close()
|
||||
if fileInfo, err := os.Stat(filePath); err == nil {
|
||||
if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && fileInfo.Size() == remoteSize {
|
||||
if progressCallback != nil {
|
||||
progressCallback(100.0, 0, fileInfo.Size(), remoteSize)
|
||||
}
|
||||
md5Hash, sha256Hash, err := calculateFileHashes(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("计算文件哈希失败: %v", err)
|
||||
}
|
||||
return &DownloadResult{
|
||||
FilePath: filePath,
|
||||
FileSize: fileInfo.Size(),
|
||||
MD5Hash: md5Hash,
|
||||
SHA256Hash: sha256Hash,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("服务器返回 416 错误,且文件可能不完整")
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
||||
return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 获取文件总大小
|
||||
contentLength := resp.ContentLength
|
||||
gotTotalFromRange := false
|
||||
|
||||
// 优先从 Content-Range 头获取总大小
|
||||
if rangeHeader := resp.Header.Get("Content-Range"); rangeHeader != "" {
|
||||
var start, end, total int64
|
||||
if n, _ := fmt.Sscanf(rangeHeader, "bytes %d-%d/%d", &start, &end, &total); n == 3 && total > 0 {
|
||||
contentLength = total
|
||||
gotTotalFromRange = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果未获取到,尝试通过 HEAD 请求获取
|
||||
if contentLength <= 0 && !gotTotalFromRange {
|
||||
if remoteSize, err := getRemoteFileSize(downloadURL); err == nil {
|
||||
contentLength = remoteSize
|
||||
}
|
||||
}
|
||||
|
||||
// 断点续传时,如果未从 Content-Range 获取,需要加上已下载部分
|
||||
if resp.StatusCode == http.StatusPartialContent && downloadedSize > 0 && contentLength > 0 && !gotTotalFromRange {
|
||||
if contentLength < downloadedSize {
|
||||
contentLength += downloadedSize
|
||||
}
|
||||
}
|
||||
|
||||
// 发送初始进度事件
|
||||
if progressCallback != nil {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if contentLength > 0 && downloadedSize > 0 {
|
||||
progress := normalizeProgress(float64(downloadedSize) / float64(contentLength) * 100)
|
||||
progressCallback(progress, 0, downloadedSize, contentLength)
|
||||
} else {
|
||||
progressCallback(0, 0, downloadedSize, -1)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
// 下载文件
|
||||
buffer := make([]byte, 32*1024) // 32KB 缓冲区
|
||||
var totalDownloaded int64 = downloadedSize
|
||||
startTime := time.Now()
|
||||
lastProgressTime := startTime
|
||||
lastProgressSize := totalDownloaded
|
||||
|
||||
for {
|
||||
n, err := resp.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
written, writeErr := file.Write(buffer[:n])
|
||||
if writeErr != nil {
|
||||
return nil, fmt.Errorf("写入文件失败: %v", writeErr)
|
||||
}
|
||||
totalDownloaded += int64(written)
|
||||
|
||||
// 计算进度和速度
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(lastProgressTime).Seconds()
|
||||
|
||||
// 每 0.3 秒更新一次进度
|
||||
if elapsed >= 0.3 {
|
||||
progress := float64(0)
|
||||
if contentLength > 0 {
|
||||
progress = normalizeProgress(float64(totalDownloaded) / float64(contentLength) * 100)
|
||||
}
|
||||
|
||||
speed := float64(0)
|
||||
if elapsed > 0 {
|
||||
speed = float64(totalDownloaded-lastProgressSize) / elapsed
|
||||
}
|
||||
|
||||
if progressCallback != nil {
|
||||
progressCallback(progress, speed, totalDownloaded, contentLength)
|
||||
}
|
||||
|
||||
lastProgressTime = now
|
||||
lastProgressSize = totalDownloaded
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取数据失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 最后一次进度更新
|
||||
if progressCallback != nil {
|
||||
if contentLength > 0 {
|
||||
progressCallback(100.0, 0, totalDownloaded, contentLength)
|
||||
} else {
|
||||
progressCallback(100.0, 0, totalDownloaded, totalDownloaded)
|
||||
}
|
||||
}
|
||||
|
||||
file.Close()
|
||||
md5Hash, sha256Hash, err := calculateFileHashes(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("计算文件哈希失败: %v", err)
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文件信息失败: %v", err)
|
||||
}
|
||||
|
||||
return &DownloadResult{
|
||||
FilePath: filePath,
|
||||
FileSize: fileInfo.Size(),
|
||||
MD5Hash: md5Hash,
|
||||
SHA256Hash: sha256Hash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// calculateFileHashes 计算文件的 MD5 和 SHA256 哈希值
|
||||
func calculateFileHashes(filePath string) (string, string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
md5Hash := md5.New()
|
||||
sha256Hash := sha256.New()
|
||||
|
||||
// 使用 MultiWriter 同时计算两个哈希
|
||||
writer := io.MultiWriter(md5Hash, sha256Hash)
|
||||
|
||||
if _, err := io.Copy(writer, file); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
return md5Sum, sha256Sum, nil
|
||||
}
|
||||
|
||||
// VerifyFileHash 验证文件哈希值
|
||||
func VerifyFileHash(filePath string, expectedHash string, hashType string) (bool, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var hash []byte
|
||||
var calculatedHash string
|
||||
|
||||
switch hashType {
|
||||
case "md5":
|
||||
md5Hash := md5.New()
|
||||
if _, err := io.Copy(md5Hash, file); err != nil {
|
||||
return false, err
|
||||
}
|
||||
hash = md5Hash.Sum(nil)
|
||||
calculatedHash = hex.EncodeToString(hash)
|
||||
case "sha256":
|
||||
sha256Hash := sha256.New()
|
||||
if _, err := io.Copy(sha256Hash, file); err != nil {
|
||||
return false, err
|
||||
}
|
||||
hash = sha256Hash.Sum(nil)
|
||||
calculatedHash = hex.EncodeToString(hash)
|
||||
default:
|
||||
return false, fmt.Errorf("不支持的哈希类型: %s", hashType)
|
||||
}
|
||||
|
||||
return calculatedHash == expectedHash, nil
|
||||
}
|
||||
328
internal/service/update_install.go
Normal file
328
internal/service/update_install.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// InstallResult 安装结果
|
||||
type InstallResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// InstallUpdate 安装更新包
|
||||
func InstallUpdate(installerPath string, autoRestart bool) (*InstallResult, error) {
|
||||
return InstallUpdateWithHash(installerPath, autoRestart, "", "")
|
||||
}
|
||||
|
||||
// InstallUpdateWithHash 安装更新包(带哈希验证)
|
||||
func InstallUpdateWithHash(installerPath string, autoRestart bool, expectedHash string, hashType string) (*InstallResult, error) {
|
||||
if _, err := os.Stat(installerPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("安装文件不存在: %s", installerPath)
|
||||
}
|
||||
|
||||
// 哈希验证
|
||||
if expectedHash != "" && hashType != "" {
|
||||
valid, err := VerifyFileHash(installerPath, expectedHash, hashType)
|
||||
if err != nil {
|
||||
return &InstallResult{Success: false, Message: "文件验证失败: " + err.Error()}, nil
|
||||
}
|
||||
if !valid {
|
||||
return &InstallResult{Success: false, Message: "文件哈希值不匹配,文件可能已损坏或被篡改"}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 备份
|
||||
backupPath, err := BackupApplication()
|
||||
if err != nil {
|
||||
return &InstallResult{Success: false, Message: fmt.Sprintf("备份失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
// 安装
|
||||
ext := filepath.Ext(installerPath)
|
||||
switch ext {
|
||||
case ".exe":
|
||||
if runtime.GOOS != "windows" {
|
||||
return &InstallResult{Success: false, Message: "当前系统不是 Windows,无法安装 .exe 文件"}, nil
|
||||
}
|
||||
err = installExe(installerPath)
|
||||
case ".zip":
|
||||
err = installZip(installerPath)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的安装包格式: %s", ext)
|
||||
}
|
||||
|
||||
// 处理安装结果
|
||||
if err != nil {
|
||||
// 安装失败,尝试回滚
|
||||
if backupPath != "" {
|
||||
_ = rollbackFromBackup(backupPath)
|
||||
}
|
||||
return &InstallResult{Success: false, Message: fmt.Sprintf("安装失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
// 自动重启
|
||||
if autoRestart {
|
||||
go func() {
|
||||
time.Sleep(2 * time.Second)
|
||||
restartApplication()
|
||||
}()
|
||||
}
|
||||
|
||||
return &InstallResult{Success: true, Message: "安装成功"}, nil
|
||||
}
|
||||
|
||||
// getExecutablePath 获取当前可执行文件路径
|
||||
func getExecutablePath() (string, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取可执行文件路径失败: %v", err)
|
||||
}
|
||||
return execPath, nil
|
||||
}
|
||||
|
||||
// installExe 安装 exe 文件
|
||||
func installExe(exePath string) error {
|
||||
execPath, err := getExecutablePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return replaceExecutableFile(exePath, execPath)
|
||||
}
|
||||
|
||||
// installZip 安装 ZIP 压缩包
|
||||
func installZip(zipPath string) error {
|
||||
zipReader, err := zip.OpenReader(zipPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开 ZIP 文件失败: %v", err)
|
||||
}
|
||||
defer zipReader.Close()
|
||||
|
||||
execPath, err := getExecutablePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
execDir := filepath.Dir(execPath)
|
||||
execName := filepath.Base(execPath)
|
||||
|
||||
// 解压到临时目录
|
||||
tempDir := filepath.Join(execDir, ".update-temp")
|
||||
if err := os.MkdirAll(tempDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建临时目录失败: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// 解压文件
|
||||
for _, file := range zipReader.File {
|
||||
filePath := filepath.Join(tempDir, file.Name)
|
||||
if file.FileInfo().IsDir() {
|
||||
os.MkdirAll(filePath, file.FileInfo().Mode())
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
||||
return fmt.Errorf("创建目录失败: %v", err)
|
||||
}
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开 ZIP 文件项失败: %v", err)
|
||||
}
|
||||
|
||||
targetFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.FileInfo().Mode())
|
||||
if err != nil {
|
||||
rc.Close()
|
||||
return fmt.Errorf("创建目标文件失败: %v", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(targetFile, rc); err != nil {
|
||||
targetFile.Close()
|
||||
rc.Close()
|
||||
return fmt.Errorf("复制文件失败: %v", err)
|
||||
}
|
||||
targetFile.Close()
|
||||
rc.Close()
|
||||
}
|
||||
|
||||
// 查找新的可执行文件
|
||||
newExecPath := filepath.Join(tempDir, execName)
|
||||
if _, err := os.Stat(newExecPath); os.IsNotExist(err) {
|
||||
files, _ := os.ReadDir(tempDir)
|
||||
for _, f := range files {
|
||||
if !f.IsDir() {
|
||||
newExecPath = filepath.Join(tempDir, f.Name())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 替换文件(使用与 installExe 相同的逻辑)
|
||||
return replaceExecutableFile(newExecPath, execPath)
|
||||
}
|
||||
|
||||
// replaceExecutableFile 替换可执行文件(Windows 和 Unix 通用逻辑)
|
||||
func replaceExecutableFile(newFilePath, execPath string) error {
|
||||
if runtime.GOOS == "windows" {
|
||||
return replaceExecutableFileWindows(newFilePath, execPath)
|
||||
}
|
||||
|
||||
// Unix-like: 直接替换
|
||||
if err := copyFile(newFilePath, execPath); err != nil {
|
||||
return fmt.Errorf("复制文件失败: %v", err)
|
||||
}
|
||||
return os.Chmod(execPath, 0755)
|
||||
}
|
||||
|
||||
// replaceExecutableFileWindows Windows 平台替换可执行文件
|
||||
func replaceExecutableFileWindows(newFilePath, execPath string) error {
|
||||
oldExecPath := execPath + ".old"
|
||||
newExecPathTemp := execPath + ".new"
|
||||
|
||||
// 清理旧文件
|
||||
os.Remove(oldExecPath)
|
||||
os.Remove(newExecPathTemp)
|
||||
|
||||
// 复制新文件到临时位置
|
||||
if err := copyFile(newFilePath, newExecPathTemp); err != nil {
|
||||
return fmt.Errorf("复制新文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 尝试重命名当前文件(如果失败,说明文件正在使用)
|
||||
if err := os.Rename(execPath, oldExecPath); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 替换文件
|
||||
if err := os.Rename(newExecPathTemp, execPath); err != nil {
|
||||
os.Rename(oldExecPath, execPath) // 恢复
|
||||
return fmt.Errorf("替换文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 延迟删除旧文件
|
||||
go func() {
|
||||
time.Sleep(10 * time.Second)
|
||||
os.Remove(oldExecPath)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// restartApplication 重启应用
|
||||
func restartApplication() {
|
||||
execPath, err := getExecutablePath()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
currentPID := os.Getpid()
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
replacePendingFile(execPath)
|
||||
|
||||
// 创建并执行重启脚本
|
||||
tempDir := os.TempDir()
|
||||
batFile := filepath.Join(tempDir, fmt.Sprintf("restart_%d.bat", currentPID))
|
||||
execDir := filepath.Dir(execPath)
|
||||
|
||||
batContent := fmt.Sprintf(`@echo off
|
||||
cd /d "%s"
|
||||
start "" "%s"
|
||||
timeout /t 3 /nobreak >nul
|
||||
taskkill /PID %d /F >nul 2>&1
|
||||
del "%%~f0"
|
||||
`, execDir, execPath, currentPID)
|
||||
|
||||
if err := os.WriteFile(batFile, []byte(batContent), 0644); err != nil {
|
||||
fallbackRestart(execPath)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("cmd", "/C", batFile)
|
||||
cmd.Stdout = nil
|
||||
cmd.Stderr = nil
|
||||
if err := cmd.Start(); err != nil {
|
||||
fallbackRestart(execPath)
|
||||
return
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// fallbackRestart 降级重启方案
|
||||
func fallbackRestart(execPath string) {
|
||||
exec.Command("cmd", "/C", "start", "", execPath).Start()
|
||||
time.Sleep(2 * time.Second)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// replacePendingFile 替换待替换的文件(.new -> 可执行文件)
|
||||
func replacePendingFile(execPath string) error {
|
||||
newExecPathTemp := execPath + ".new"
|
||||
if _, err := os.Stat(newExecPathTemp); os.IsNotExist(err) {
|
||||
return nil // 没有待替换文件
|
||||
}
|
||||
|
||||
oldExecPath := execPath + ".old"
|
||||
os.Remove(oldExecPath)
|
||||
if err := os.Rename(newExecPathTemp, execPath); err != nil {
|
||||
return fmt.Errorf("文件替换失败: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackFromBackup 从备份恢复
|
||||
func rollbackFromBackup(backupPath string) error {
|
||||
execPath, err := getExecutablePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return copyFile(backupPath, execPath)
|
||||
}
|
||||
|
||||
// BackupApplication 备份当前应用
|
||||
func BackupApplication() (string, error) {
|
||||
execPath, err := getExecutablePath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取用户目录失败: %v", err)
|
||||
}
|
||||
|
||||
backupDir := filepath.Join(homeDir, ".ssq-desk", "backups")
|
||||
if err := os.MkdirAll(backupDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("创建备份目录失败: %v", err)
|
||||
}
|
||||
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
backupFileName := fmt.Sprintf("ssq-desk-backup-%s%s", timestamp, filepath.Ext(execPath))
|
||||
backupPath := filepath.Join(backupDir, backupFileName)
|
||||
|
||||
sourceFile, err := os.Open(execPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("打开源文件失败: %v", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
backupFile, err := os.Create(backupPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建备份文件失败: %v", err)
|
||||
}
|
||||
defer backupFile.Close()
|
||||
|
||||
if _, err := backupFile.ReadFrom(sourceFile); err != nil {
|
||||
return "", fmt.Errorf("复制文件失败: %v", err)
|
||||
}
|
||||
|
||||
return backupPath, nil
|
||||
}
|
||||
168
internal/service/update_service.go
Normal file
168
internal/service/update_service.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpdateService 更新服务
|
||||
type UpdateService struct {
|
||||
checkURL string // 版本检查接口 URL
|
||||
}
|
||||
|
||||
// NewUpdateService 创建更新服务
|
||||
func NewUpdateService(checkURL string) *UpdateService {
|
||||
return &UpdateService{
|
||||
checkURL: checkURL,
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteVersionInfo 远程版本信息
|
||||
type RemoteVersionInfo struct {
|
||||
Version string `json:"version"`
|
||||
DownloadURL string `json:"download_url"`
|
||||
Changelog string `json:"changelog"`
|
||||
ForceUpdate bool `json:"force_update"`
|
||||
ReleaseDate string `json:"release_date"`
|
||||
}
|
||||
|
||||
// CheckUpdate 检查更新
|
||||
func (s *UpdateService) CheckUpdate() (*UpdateCheckResult, error) {
|
||||
log.Printf("[更新检查] 开始检查更新,检查地址: %s", s.checkURL)
|
||||
|
||||
// 加载配置
|
||||
config, err := LoadUpdateConfig()
|
||||
if err != nil {
|
||||
log.Printf("[更新检查] 加载配置失败: %v", err)
|
||||
return nil, fmt.Errorf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取当前版本(优先使用服务获取的最新版本号,而不是配置中可能过期的版本号)
|
||||
currentVersionStr := GetCurrentVersion()
|
||||
if currentVersionStr == "" {
|
||||
// 如果服务获取失败,回退到配置中的版本号
|
||||
currentVersionStr = config.CurrentVersion
|
||||
log.Printf("[更新检查] 使用配置中的版本号: %s", currentVersionStr)
|
||||
} else {
|
||||
log.Printf("[更新检查] 使用服务获取的版本号: %s", currentVersionStr)
|
||||
// 如果配置中的版本号不一致,更新配置
|
||||
if config.CurrentVersion != currentVersionStr {
|
||||
log.Printf("[更新检查] 配置中的版本号 (%s) 与当前版本号 (%s) 不一致,更新配置", config.CurrentVersion, currentVersionStr)
|
||||
config.CurrentVersion = currentVersionStr
|
||||
if err := SaveUpdateConfig(config); err != nil {
|
||||
log.Printf("[更新检查] 更新配置失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentVersion, err := ParseVersion(currentVersionStr)
|
||||
if err != nil {
|
||||
log.Printf("[更新检查] 解析当前版本失败: %v", err)
|
||||
return nil, fmt.Errorf("解析当前版本失败: %v", err)
|
||||
}
|
||||
|
||||
// 请求远程版本信息
|
||||
remoteInfo, err := s.fetchRemoteVersionInfo()
|
||||
if err != nil {
|
||||
log.Printf("[更新检查] 获取远程版本信息失败: %v", err)
|
||||
return nil, fmt.Errorf("获取远程版本信息失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("[更新检查] 远程版本信息: 版本=%s, 下载地址=%s, 强制更新=%v",
|
||||
remoteInfo.Version, remoteInfo.DownloadURL, remoteInfo.ForceUpdate)
|
||||
|
||||
// 解析远程版本号
|
||||
remoteVersion, err := ParseVersion(remoteInfo.Version)
|
||||
if err != nil {
|
||||
log.Printf("[更新检查] 解析远程版本号失败: %v", err)
|
||||
return nil, fmt.Errorf("解析远程版本号失败: %v", err)
|
||||
}
|
||||
|
||||
// 比较版本
|
||||
hasUpdate := remoteVersion.IsNewerThan(currentVersion)
|
||||
log.Printf("[更新检查] 版本比较: 当前=%s, 远程=%s, 有更新=%v",
|
||||
currentVersion.String(), remoteVersion.String(), hasUpdate)
|
||||
|
||||
// 更新最后检查时间
|
||||
config.UpdateLastCheckTime()
|
||||
|
||||
result := &UpdateCheckResult{
|
||||
HasUpdate: hasUpdate,
|
||||
CurrentVersion: currentVersionStr,
|
||||
LatestVersion: remoteInfo.Version,
|
||||
DownloadURL: remoteInfo.DownloadURL,
|
||||
Changelog: remoteInfo.Changelog,
|
||||
ForceUpdate: remoteInfo.ForceUpdate,
|
||||
ReleaseDate: remoteInfo.ReleaseDate,
|
||||
}
|
||||
|
||||
log.Printf("[更新检查] 检查完成: 有更新=%v", hasUpdate)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// fetchRemoteVersionInfo 获取远程版本信息
|
||||
func (s *UpdateService) fetchRemoteVersionInfo() (*RemoteVersionInfo, error) {
|
||||
if s.checkURL == "" {
|
||||
log.Printf("[远程版本] 版本检查 URL 未配置")
|
||||
return nil, fmt.Errorf("版本检查 URL 未配置,请先设置检查地址")
|
||||
}
|
||||
|
||||
log.Printf("[远程版本] 请求远程版本信息: %s", s.checkURL)
|
||||
|
||||
// 创建 HTTP 客户端,设置超时
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Get(s.checkURL)
|
||||
if err != nil {
|
||||
log.Printf("[远程版本] 网络请求失败: %v", err)
|
||||
return nil, fmt.Errorf("网络请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
log.Printf("[远程版本] HTTP 响应状态码: %d", resp.StatusCode)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("[远程版本] 读取响应失败: %v", err)
|
||||
return nil, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("[远程版本] 响应内容长度: %d 字节", len(body))
|
||||
|
||||
// 解析 JSON
|
||||
var remoteInfo RemoteVersionInfo
|
||||
if err := json.Unmarshal(body, &remoteInfo); err != nil {
|
||||
log.Printf("[远程版本] 解析 JSON 失败: %v, 响应内容: %s", err, string(body))
|
||||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if remoteInfo.Version == "" {
|
||||
log.Printf("[远程版本] 远程版本信息不完整,响应内容: %s", string(body))
|
||||
return nil, fmt.Errorf("远程版本信息不完整")
|
||||
}
|
||||
|
||||
log.Printf("[远程版本] 成功获取远程版本信息: %+v", remoteInfo)
|
||||
return &remoteInfo, nil
|
||||
}
|
||||
|
||||
// UpdateCheckResult 更新检查结果
|
||||
type UpdateCheckResult struct {
|
||||
HasUpdate bool `json:"has_update"`
|
||||
CurrentVersion string `json:"current_version"`
|
||||
LatestVersion string `json:"latest_version"`
|
||||
DownloadURL string `json:"download_url"`
|
||||
Changelog string `json:"changelog"`
|
||||
ForceUpdate bool `json:"force_update"`
|
||||
ReleaseDate string `json:"release_date"`
|
||||
}
|
||||
159
internal/service/version.go
Normal file
159
internal/service/version.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ==================== 常量定义 ====================
|
||||
|
||||
// AppVersion 应用版本号(发布时直接修改此处)
|
||||
const AppVersion = "0.1.1"
|
||||
|
||||
// ==================== 类型定义 ====================
|
||||
|
||||
// Version 版本号结构
|
||||
type Version struct {
|
||||
Major int
|
||||
Minor int
|
||||
Patch int
|
||||
}
|
||||
|
||||
// WailsConfig Wails 配置文件结构
|
||||
type WailsConfig struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// ==================== 版本号解析和比较 ====================
|
||||
|
||||
// ParseVersion 解析版本号字符串(支持 v1.0.0 或 1.0.0 格式)
|
||||
func ParseVersion(versionStr string) (*Version, error) {
|
||||
versionStr = strings.TrimPrefix(versionStr, "v")
|
||||
parts := strings.Split(versionStr, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("版本号格式错误,应为 x.y.z 格式")
|
||||
}
|
||||
|
||||
major, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("主版本号解析失败: %v", err)
|
||||
}
|
||||
|
||||
minor, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("次版本号解析失败: %v", err)
|
||||
}
|
||||
|
||||
patch, err := strconv.Atoi(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("修订号解析失败: %v", err)
|
||||
}
|
||||
|
||||
return &Version{Major: major, Minor: minor, Patch: patch}, nil
|
||||
}
|
||||
|
||||
// String 返回版本号字符串(格式:v1.0.0)
|
||||
func (v *Version) String() string {
|
||||
return fmt.Sprintf("v%d.%d.%d", v.Major, v.Minor, v.Patch)
|
||||
}
|
||||
|
||||
// Compare 比较版本号
|
||||
// 返回值:-1 表示当前版本小于目标版本,0 表示相等,1 表示大于
|
||||
func (v *Version) Compare(other *Version) int {
|
||||
if v.Major != other.Major {
|
||||
if v.Major < other.Major {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if v.Minor != other.Minor {
|
||||
if v.Minor < other.Minor {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if v.Patch != other.Patch {
|
||||
if v.Patch < other.Patch {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsNewerThan 判断是否比目标版本新
|
||||
func (v *Version) IsNewerThan(other *Version) bool {
|
||||
return v.Compare(other) > 0
|
||||
}
|
||||
|
||||
// IsOlderThan 判断是否比目标版本旧
|
||||
func (v *Version) IsOlderThan(other *Version) bool {
|
||||
return v.Compare(other) < 0
|
||||
}
|
||||
|
||||
// ==================== 版本号获取 ====================
|
||||
|
||||
// GetCurrentVersion 获取当前版本号
|
||||
// 优先级:硬编码版本号 > wails.json(开发模式)> 默认值
|
||||
func GetCurrentVersion() string {
|
||||
if AppVersion != "" {
|
||||
log.Printf("[版本] 使用硬编码版本号: %s", AppVersion)
|
||||
return AppVersion
|
||||
}
|
||||
|
||||
version := getVersionFromWailsJSON()
|
||||
if version != "" {
|
||||
log.Printf("[版本] 从 wails.json 获取版本号: %s", version)
|
||||
return version
|
||||
}
|
||||
|
||||
log.Printf("[版本] 使用默认版本号: 0.0.1")
|
||||
return "0.0.1"
|
||||
}
|
||||
|
||||
// ==================== 配置文件读取 ====================
|
||||
|
||||
// getVersionFromWailsJSON 从 wails.json 读取版本号(仅开发模式使用)
|
||||
func getVersionFromWailsJSON() string {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 尝试当前目录
|
||||
if version := readVersionFromFile(filepath.Join(wd, "wails.json")); version != "" {
|
||||
return version
|
||||
}
|
||||
|
||||
// 尝试父目录
|
||||
if version := readVersionFromFile(filepath.Join(filepath.Dir(wd), "wails.json")); version != "" {
|
||||
return version
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// readVersionFromFile 从指定文件读取版本号
|
||||
func readVersionFromFile(filePath string) string {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
log.Printf("[版本] 读取文件失败: %s, 错误: %v", filePath, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
var config WailsConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
log.Printf("[版本] 解析 JSON 失败: %s, 错误: %v", filePath, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if config.Version != "" {
|
||||
log.Printf("[版本] 从文件读取版本号: %s -> %s", filePath, config.Version)
|
||||
}
|
||||
return config.Version
|
||||
}
|
||||
20
internal/storage/models/authorization.go
Normal file
20
internal/storage/models/authorization.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// Authorization 授权信息
|
||||
type Authorization struct {
|
||||
ID int `gorm:"primaryKey" json:"id"` // 主键ID
|
||||
LicenseCode string `gorm:"type:varchar(100);not null;uniqueIndex" json:"license_code"` // 授权码(唯一)
|
||||
DeviceID string `gorm:"type:varchar(100);not null;index" json:"device_id"` // 设备ID(MD5哈希)
|
||||
ActivatedAt time.Time `gorm:"not null" json:"activated_at"` // 激活时间
|
||||
ExpiresAt *time.Time `gorm:"type:datetime" json:"expires_at"` // 过期时间(可选,nil表示永不过期)
|
||||
Status int `gorm:"type:tinyint;not null;default:1" json:"status"` // 状态(1:有效 0:无效)
|
||||
CreatedAt time.Time `gorm:"autoCreateTime:false" json:"created_at"` // 创建时间(由程序设置)
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime:false" json:"updated_at"` // 更新时间(由程序设置)
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Authorization) TableName() string {
|
||||
return "sys_authorization_code"
|
||||
}
|
||||
24
internal/storage/models/ssq_history.go
Normal file
24
internal/storage/models/ssq_history.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// SsqHistory 双色球历史开奖数据
|
||||
type SsqHistory struct {
|
||||
ID int `gorm:"primaryKey;column:id" json:"id"`
|
||||
IssueNumber string `gorm:"type:varchar(20);not null;index;column:issue_number" json:"issue_number"`
|
||||
OpenDate *time.Time `gorm:"type:date;column:open_date" json:"open_date"`
|
||||
RedBall1 int `gorm:"type:tinyint;not null;column:red_ball_1" json:"red_ball_1"`
|
||||
RedBall2 int `gorm:"type:tinyint;not null;column:red_ball_2" json:"red_ball_2"`
|
||||
RedBall3 int `gorm:"type:tinyint;not null;column:red_ball_3" json:"red_ball_3"`
|
||||
RedBall4 int `gorm:"type:tinyint;not null;column:red_ball_4" json:"red_ball_4"`
|
||||
RedBall5 int `gorm:"type:tinyint;not null;column:red_ball_5" json:"red_ball_5"`
|
||||
RedBall6 int `gorm:"type:tinyint;not null;column:red_ball_6" json:"red_ball_6"`
|
||||
BlueBall int `gorm:"type:tinyint;not null;column:blue_ball" json:"blue_ball"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime:false;column:created_at" json:"created_at"` // 创建时间(由程序设置)
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime:false;column:updated_at" json:"updated_at"` // 更新时间(由程序设置)
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SsqHistory) TableName() string {
|
||||
return "ssq_history"
|
||||
}
|
||||
20
internal/storage/models/version.go
Normal file
20
internal/storage/models/version.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// Version 版本信息
|
||||
type Version struct {
|
||||
ID int `gorm:"primaryKey" json:"id"` // 主键ID
|
||||
Version string `gorm:"type:varchar(20);not null;uniqueIndex" json:"version"` // 版本号(语义化版本,如1.0.0)
|
||||
DownloadURL string `gorm:"type:varchar(500)" json:"download_url"` // 下载地址(更新包下载URL)
|
||||
Changelog string `gorm:"type:text" json:"changelog"` // 更新日志(Markdown格式)
|
||||
ForceUpdate int `gorm:"type:tinyint;not null;default:0" json:"force_update"` // 是否强制更新(1:是 0:否)
|
||||
ReleaseDate *time.Time `gorm:"type:date" json:"release_date"` // 发布日期
|
||||
CreatedAt time.Time `gorm:"autoCreateTime:false" json:"created_at"` // 创建时间(由程序设置)
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime:false" json:"updated_at"` // 更新时间(由程序设置)
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Version) TableName() string {
|
||||
return "sys_version"
|
||||
}
|
||||
76
internal/storage/repository/auth_repository.go
Normal file
76
internal/storage/repository/auth_repository.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"ssq-desk/internal/database"
|
||||
"ssq-desk/internal/storage/models"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AuthRepository 授权数据访问接口
|
||||
type AuthRepository interface {
|
||||
Create(auth *models.Authorization) error
|
||||
Update(auth *models.Authorization) error
|
||||
GetByLicenseCode(licenseCode string) (*models.Authorization, error)
|
||||
GetByDeviceID(deviceID string) (*models.Authorization, error)
|
||||
}
|
||||
|
||||
// SQLiteAuthRepository SQLite 授权数据访问实现
|
||||
type SQLiteAuthRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSQLiteAuthRepository 创建 SQLite 授权数据访问实例
|
||||
func NewSQLiteAuthRepository() (AuthRepository, error) {
|
||||
db := database.GetSQLite()
|
||||
if db == nil {
|
||||
return nil, gorm.ErrInvalidDB
|
||||
}
|
||||
|
||||
// 自动迁移
|
||||
err := db.AutoMigrate(&models.Authorization{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SQLiteAuthRepository{db: db}, nil
|
||||
}
|
||||
|
||||
// Create 创建授权记录
|
||||
func (r *SQLiteAuthRepository) Create(auth *models.Authorization) error {
|
||||
now := time.Now()
|
||||
if auth.CreatedAt.IsZero() {
|
||||
auth.CreatedAt = now
|
||||
}
|
||||
if auth.UpdatedAt.IsZero() {
|
||||
auth.UpdatedAt = now
|
||||
}
|
||||
return r.db.Create(auth).Error
|
||||
}
|
||||
|
||||
// Update 更新授权记录
|
||||
func (r *SQLiteAuthRepository) Update(auth *models.Authorization) error {
|
||||
auth.UpdatedAt = time.Now()
|
||||
return r.db.Save(auth).Error
|
||||
}
|
||||
|
||||
// GetByLicenseCode 根据授权码查询
|
||||
func (r *SQLiteAuthRepository) GetByLicenseCode(licenseCode string) (*models.Authorization, error) {
|
||||
var auth models.Authorization
|
||||
err := r.db.Where("license_code = ?", licenseCode).First(&auth).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// GetByDeviceID 根据设备ID查询
|
||||
func (r *SQLiteAuthRepository) GetByDeviceID(deviceID string) (*models.Authorization, error) {
|
||||
var auth models.Authorization
|
||||
err := r.db.Where("device_id = ? AND status = ?", deviceID, 1).First(&auth).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
128
internal/storage/repository/mysql_repo.go
Normal file
128
internal/storage/repository/mysql_repo.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"ssq-desk/internal/database"
|
||||
"ssq-desk/internal/storage/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MySQLSsqRepository MySQL 实现
|
||||
type MySQLSsqRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewMySQLSsqRepository 创建 MySQL 仓库
|
||||
func NewMySQLSsqRepository() (SsqRepository, error) {
|
||||
db := database.GetMySQL()
|
||||
if db == nil {
|
||||
return nil, gorm.ErrInvalidDB
|
||||
}
|
||||
return &MySQLSsqRepository{db: db}, nil
|
||||
}
|
||||
|
||||
// FindAll 查询所有历史数据
|
||||
func (r *MySQLSsqRepository) FindAll() ([]models.SsqHistory, error) {
|
||||
var histories []models.SsqHistory
|
||||
err := r.db.Order("issue_number DESC").Find(&histories).Error
|
||||
return histories, err
|
||||
}
|
||||
|
||||
// FindByIssue 根据期号查询
|
||||
func (r *MySQLSsqRepository) FindByIssue(issueNumber string) (*models.SsqHistory, error) {
|
||||
var history models.SsqHistory
|
||||
err := r.db.Where("issue_number = ?", issueNumber).First(&history).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &history, err
|
||||
}
|
||||
|
||||
// FindByRedBalls 根据红球查询(支持部分匹配)
|
||||
func (r *MySQLSsqRepository) FindByRedBalls(redBalls []int) ([]models.SsqHistory, error) {
|
||||
if len(redBalls) == 0 {
|
||||
return r.FindAll()
|
||||
}
|
||||
|
||||
var histories []models.SsqHistory
|
||||
query := r.db
|
||||
|
||||
// 构建查询条件:红球在输入的红球列表中
|
||||
for _, ball := range redBalls {
|
||||
query = query.Where("red_ball_1 = ? OR red_ball_2 = ? OR red_ball_3 = ? OR red_ball_4 = ? OR red_ball_5 = ? OR red_ball_6 = ?",
|
||||
ball, ball, ball, ball, ball, ball)
|
||||
}
|
||||
|
||||
err := query.Order("issue_number DESC").Find(&histories).Error
|
||||
return histories, err
|
||||
}
|
||||
|
||||
// FindByRedAndBlue 根据红球和蓝球查询
|
||||
func (r *MySQLSsqRepository) FindByRedAndBlue(redBalls []int, blueBall int, blueBallRange []int) ([]models.SsqHistory, error) {
|
||||
var histories []models.SsqHistory
|
||||
query := r.db
|
||||
|
||||
// 红球条件
|
||||
if len(redBalls) > 0 {
|
||||
for _, ball := range redBalls {
|
||||
query = query.Where("red_ball_1 = ? OR red_ball_2 = ? OR red_ball_3 = ? OR red_ball_4 = ? OR red_ball_5 = ? OR red_ball_6 = ?",
|
||||
ball, ball, ball, ball, ball, ball)
|
||||
}
|
||||
}
|
||||
|
||||
// 蓝球条件
|
||||
if blueBall > 0 {
|
||||
query = query.Where("blue_ball = ?", blueBall)
|
||||
} else if len(blueBallRange) > 0 {
|
||||
query = query.Where("blue_ball IN ?", blueBallRange)
|
||||
}
|
||||
|
||||
err := query.Order("issue_number DESC").Find(&histories).Error
|
||||
return histories, err
|
||||
}
|
||||
|
||||
// Create 创建记录
|
||||
func (r *MySQLSsqRepository) Create(history *models.SsqHistory) error {
|
||||
now := time.Now()
|
||||
if history.CreatedAt.IsZero() {
|
||||
history.CreatedAt = now
|
||||
}
|
||||
if history.UpdatedAt.IsZero() {
|
||||
history.UpdatedAt = now
|
||||
}
|
||||
return r.db.Create(history).Error
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建
|
||||
func (r *MySQLSsqRepository) BatchCreate(histories []models.SsqHistory) error {
|
||||
if len(histories) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range histories {
|
||||
if histories[i].CreatedAt.IsZero() {
|
||||
histories[i].CreatedAt = now
|
||||
}
|
||||
if histories[i].UpdatedAt.IsZero() {
|
||||
histories[i].UpdatedAt = now
|
||||
}
|
||||
}
|
||||
return r.db.CreateInBatches(histories, 100).Error
|
||||
}
|
||||
|
||||
// GetLatestIssue 获取最新期号
|
||||
func (r *MySQLSsqRepository) GetLatestIssue() (string, error) {
|
||||
var history models.SsqHistory
|
||||
err := r.db.Order("issue_number DESC").First(&history).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return "", nil
|
||||
}
|
||||
return history.IssueNumber, err
|
||||
}
|
||||
|
||||
// Count 统计总数
|
||||
func (r *MySQLSsqRepository) Count() (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&models.SsqHistory{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
136
internal/storage/repository/sqlite_repo.go
Normal file
136
internal/storage/repository/sqlite_repo.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"ssq-desk/internal/database"
|
||||
"ssq-desk/internal/storage/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SQLiteSsqRepository SQLite 实现
|
||||
type SQLiteSsqRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSQLiteSsqRepository 创建 SQLite 仓库
|
||||
func NewSQLiteSsqRepository() (SsqRepository, error) {
|
||||
db := database.GetSQLite()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("SQLite 数据库未初始化或初始化失败")
|
||||
}
|
||||
|
||||
// 自动迁移表结构(数据库初始化时已迁移,这里再次确保)
|
||||
err := db.AutoMigrate(&models.SsqHistory{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SQLite 表迁移失败: %v", err)
|
||||
}
|
||||
|
||||
return &SQLiteSsqRepository{db: db}, nil
|
||||
}
|
||||
|
||||
// FindAll 查询所有历史数据
|
||||
func (r *SQLiteSsqRepository) FindAll() ([]models.SsqHistory, error) {
|
||||
var histories []models.SsqHistory
|
||||
err := r.db.Order("issue_number DESC").Find(&histories).Error
|
||||
return histories, err
|
||||
}
|
||||
|
||||
// FindByIssue 根据期号查询
|
||||
func (r *SQLiteSsqRepository) FindByIssue(issueNumber string) (*models.SsqHistory, error) {
|
||||
var history models.SsqHistory
|
||||
err := r.db.Where("issue_number = ?", issueNumber).First(&history).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return &history, err
|
||||
}
|
||||
|
||||
// FindByRedBalls 根据红球查询(支持部分匹配)
|
||||
func (r *SQLiteSsqRepository) FindByRedBalls(redBalls []int) ([]models.SsqHistory, error) {
|
||||
if len(redBalls) == 0 {
|
||||
return r.FindAll()
|
||||
}
|
||||
|
||||
var histories []models.SsqHistory
|
||||
query := r.db
|
||||
|
||||
// 构建查询条件:红球在输入的红球列表中
|
||||
for _, ball := range redBalls {
|
||||
query = query.Where("red_ball_1 = ? OR red_ball_2 = ? OR red_ball_3 = ? OR red_ball_4 = ? OR red_ball_5 = ? OR red_ball_6 = ?",
|
||||
ball, ball, ball, ball, ball, ball)
|
||||
}
|
||||
|
||||
err := query.Order("issue_number DESC").Find(&histories).Error
|
||||
return histories, err
|
||||
}
|
||||
|
||||
// FindByRedAndBlue 根据红球和蓝球查询
|
||||
func (r *SQLiteSsqRepository) FindByRedAndBlue(redBalls []int, blueBall int, blueBallRange []int) ([]models.SsqHistory, error) {
|
||||
var histories []models.SsqHistory
|
||||
query := r.db
|
||||
|
||||
// 红球条件
|
||||
if len(redBalls) > 0 {
|
||||
for _, ball := range redBalls {
|
||||
query = query.Where("red_ball_1 = ? OR red_ball_2 = ? OR red_ball_3 = ? OR red_ball_4 = ? OR red_ball_5 = ? OR red_ball_6 = ?",
|
||||
ball, ball, ball, ball, ball, ball)
|
||||
}
|
||||
}
|
||||
|
||||
// 蓝球条件
|
||||
if blueBall > 0 {
|
||||
query = query.Where("blue_ball = ?", blueBall)
|
||||
} else if len(blueBallRange) > 0 {
|
||||
query = query.Where("blue_ball IN ?", blueBallRange)
|
||||
}
|
||||
|
||||
err := query.Order("issue_number DESC").Find(&histories).Error
|
||||
return histories, err
|
||||
}
|
||||
|
||||
// Create 创建记录
|
||||
func (r *SQLiteSsqRepository) Create(history *models.SsqHistory) error {
|
||||
now := time.Now()
|
||||
if history.CreatedAt.IsZero() {
|
||||
history.CreatedAt = now
|
||||
}
|
||||
if history.UpdatedAt.IsZero() {
|
||||
history.UpdatedAt = now
|
||||
}
|
||||
return r.db.Create(history).Error
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建
|
||||
func (r *SQLiteSsqRepository) BatchCreate(histories []models.SsqHistory) error {
|
||||
if len(histories) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
for i := range histories {
|
||||
if histories[i].CreatedAt.IsZero() {
|
||||
histories[i].CreatedAt = now
|
||||
}
|
||||
if histories[i].UpdatedAt.IsZero() {
|
||||
histories[i].UpdatedAt = now
|
||||
}
|
||||
}
|
||||
return r.db.CreateInBatches(histories, 100).Error
|
||||
}
|
||||
|
||||
// GetLatestIssue 获取最新期号
|
||||
func (r *SQLiteSsqRepository) GetLatestIssue() (string, error) {
|
||||
var history models.SsqHistory
|
||||
err := r.db.Order("issue_number DESC").First(&history).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return "", nil
|
||||
}
|
||||
return history.IssueNumber, err
|
||||
}
|
||||
|
||||
// Count 统计总数
|
||||
func (r *SQLiteSsqRepository) Count() (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&models.SsqHistory{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
25
internal/storage/repository/ssq_repository.go
Normal file
25
internal/storage/repository/ssq_repository.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"ssq-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// SsqRepository 双色球数据仓库接口
|
||||
type SsqRepository interface {
|
||||
// FindAll 查询所有历史数据
|
||||
FindAll() ([]models.SsqHistory, error)
|
||||
// FindByIssue 根据期号查询
|
||||
FindByIssue(issueNumber string) (*models.SsqHistory, error)
|
||||
// FindByRedBalls 根据红球查询(支持部分匹配)
|
||||
FindByRedBalls(redBalls []int) ([]models.SsqHistory, error)
|
||||
// FindByRedAndBlue 根据红球和蓝球查询
|
||||
FindByRedAndBlue(redBalls []int, blueBall int, blueBallRange []int) ([]models.SsqHistory, error)
|
||||
// Create 创建记录
|
||||
Create(history *models.SsqHistory) error
|
||||
// BatchCreate 批量创建
|
||||
BatchCreate(histories []models.SsqHistory) error
|
||||
// GetLatestIssue 获取最新期号
|
||||
GetLatestIssue() (string, error)
|
||||
// Count 统计总数
|
||||
Count() (int64, error)
|
||||
}
|
||||
Reference in New Issue
Block a user