This commit is contained in:
2026-01-14 14:17:38 +08:00
commit f1e2ff6563
126 changed files with 13636 additions and 0 deletions

View File

@@ -0,0 +1,215 @@
package service
import (
"crypto/md5"
"encoding/hex"
"fmt"
"os"
"runtime"
"ssq-desk/internal/database"
"ssq-desk/internal/storage/models"
"ssq-desk/internal/storage/repository"
"time"
"gorm.io/gorm"
)
// AuthService 授权服务
type AuthService struct {
repo repository.AuthRepository
}
// NewAuthService 创建授权服务
func NewAuthService(repo repository.AuthRepository) *AuthService {
return &AuthService{repo: repo}
}
// GetDeviceID 获取设备ID基于硬件信息生成
func GetDeviceID() (string, error) {
var deviceInfo string
// 获取主机名
hostname, err := os.Hostname()
if err != nil {
hostname = "unknown"
}
// 获取用户目录
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "unknown"
}
// 组合设备信息
deviceInfo = fmt.Sprintf("%s-%s-%s", hostname, homeDir, runtime.GOOS)
// 生成 MD5 作为设备ID
hash := md5.Sum([]byte(deviceInfo))
deviceID := hex.EncodeToString(hash[:])
return deviceID, nil
}
// ValidateLicenseFormat 验证授权码格式
func ValidateLicenseFormat(licenseCode string) error {
if licenseCode == "" {
return fmt.Errorf("授权码不能为空")
}
// 去除空格和连字符
cleaned := ""
for _, c := range licenseCode {
if c != ' ' && c != '-' {
cleaned += string(c)
}
}
// 格式验证至少16位只包含字母和数字
if len(cleaned) < 16 {
return fmt.Errorf("授权码长度不足至少需要16位字符")
}
if len(cleaned) > 100 {
return fmt.Errorf("授权码长度过长最多100位字符")
}
// 验证字符:只允许字母和数字
for _, c := range cleaned {
if !((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
return fmt.Errorf("授权码只能包含字母和数字")
}
}
return nil
}
// ValidateLicenseFromRemote 从远程数据库验证授权码
func ValidateLicenseFromRemote(licenseCode string) error {
// 获取 MySQL 连接
mysqlDB := database.GetMySQL()
if mysqlDB == nil {
return fmt.Errorf("无法连接远程数据库,无法验证授权码")
}
// 清理授权码(去除空格和连字符),与格式验证保持一致
cleaned := ""
for _, c := range licenseCode {
if c != ' ' && c != '-' {
cleaned += string(c)
}
}
// 查询授权码是否存在且有效(支持原始格式和清理后格式)
var auth models.Authorization
err := mysqlDB.Where("(license_code = ? OR license_code = ?) AND status = ?", licenseCode, cleaned, 1).First(&auth).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return fmt.Errorf("授权码无效或不存在")
}
return fmt.Errorf("验证授权码时发生错误: %v", err)
}
// 检查是否过期
if auth.ExpiresAt != nil && auth.ExpiresAt.Before(time.Now()) {
return fmt.Errorf("授权码已过期")
}
return nil
}
// ValidateLicense 验证授权码
func (s *AuthService) ValidateLicense(licenseCode string) error {
// 格式验证
if err := ValidateLicenseFormat(licenseCode); err != nil {
return err
}
// 从远程数据库验证授权码有效性
if err := ValidateLicenseFromRemote(licenseCode); err != nil {
return err
}
// 获取设备ID
deviceID, err := GetDeviceID()
if err != nil {
return fmt.Errorf("获取设备ID失败: %v", err)
}
// 保存授权信息到本地
auth := &models.Authorization{
LicenseCode: licenseCode,
DeviceID: deviceID,
ActivatedAt: time.Now(),
Status: 1,
}
// 检查是否已存在
existing, err := s.repo.GetByLicenseCode(licenseCode)
if err == nil && existing != nil {
// 更新现有授权
existing.DeviceID = deviceID
existing.ActivatedAt = time.Now()
existing.Status = 1
return s.repo.Update(existing)
}
// 创建新授权
return s.repo.Create(auth)
}
// CheckAuthStatus 检查授权状态
func (s *AuthService) CheckAuthStatus() (*AuthStatus, error) {
deviceID, err := GetDeviceID()
if err != nil {
return nil, fmt.Errorf("获取设备ID失败: %v", err)
}
auth, err := s.repo.GetByDeviceID(deviceID)
if err != nil {
if err == gorm.ErrRecordNotFound {
return &AuthStatus{
IsActivated: false,
Message: "未激活",
}, nil
}
return nil, err
}
// 检查状态
if auth.Status != 1 {
return &AuthStatus{
IsActivated: false,
Message: "授权已失效",
}, nil
}
// 检查过期时间
if auth.ExpiresAt != nil && auth.ExpiresAt.Before(time.Now()) {
return &AuthStatus{
IsActivated: false,
Message: "授权已过期",
}, nil
}
return &AuthStatus{
IsActivated: true,
LicenseCode: auth.LicenseCode,
ActivatedAt: auth.ActivatedAt,
ExpiresAt: auth.ExpiresAt,
Message: "已激活",
}, nil
}
// AuthStatus 授权状态
type AuthStatus struct {
IsActivated bool `json:"is_activated"`
LicenseCode string `json:"license_code,omitempty"`
ActivatedAt time.Time `json:"activated_at,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Message string `json:"message"`
}
// GetAuthInfo 获取授权信息
func (s *AuthService) GetAuthInfo() (*AuthStatus, error) {
return s.CheckAuthStatus()
}

View File

@@ -0,0 +1,287 @@
package service
import (
"archive/zip"
"encoding/json"
"fmt"
"os"
"path/filepath"
"ssq-desk/internal/database"
"time"
)
// BackupService 数据备份服务
type BackupService struct{}
// NewBackupService 创建备份服务
func NewBackupService() *BackupService {
return &BackupService{}
}
// BackupResult 备份结果
type BackupResult struct {
BackupPath string `json:"backup_path"`
FileName string `json:"file_name"`
FileSize int64 `json:"file_size"`
CreatedAt string `json:"created_at"`
}
// Backup 备份 SQLite 数据库
func (s *BackupService) Backup() (*BackupResult, error) {
// 获取 SQLite 数据库路径
appDataDir, err := os.UserConfigDir()
if err != nil {
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
}
dbPath := filepath.Join(appDataDir, "ssq-desk", "data", "ssq.db")
// 检查数据库文件是否存在
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
return nil, fmt.Errorf("数据库文件不存在: %s", dbPath)
}
// 创建备份目录
backupDir := filepath.Join(appDataDir, "ssq-desk", "backups")
if err := os.MkdirAll(backupDir, 0755); err != nil {
return nil, fmt.Errorf("创建备份目录失败: %v", err)
}
// 生成备份文件名(带时间戳)
timestamp := time.Now().Format("20060102-150405")
backupFileName := fmt.Sprintf("ssq-backup-%s.zip", timestamp)
backupPath := filepath.Join(backupDir, backupFileName)
// 创建 ZIP 文件
zipFile, err := os.Create(backupPath)
if err != nil {
return nil, fmt.Errorf("创建备份文件失败: %v", err)
}
defer zipFile.Close()
zipWriter := zip.NewWriter(zipFile)
defer zipWriter.Close()
// 添加数据库文件到 ZIP
dbFile, err := os.Open(dbPath)
if err != nil {
return nil, fmt.Errorf("打开数据库文件失败: %v", err)
}
defer dbFile.Close()
dbInfo, err := dbFile.Stat()
if err != nil {
return nil, fmt.Errorf("获取数据库文件信息失败: %v", err)
}
dbHeader, err := zip.FileInfoHeader(dbInfo)
if err != nil {
return nil, fmt.Errorf("创建 ZIP 文件头失败: %v", err)
}
dbHeader.Name = "ssq.db"
dbHeader.Method = zip.Deflate
dbWriter, err := zipWriter.CreateHeader(dbHeader)
if err != nil {
return nil, fmt.Errorf("创建 ZIP 写入器失败: %v", err)
}
if _, err := dbWriter.Write([]byte{}); err != nil {
return nil, fmt.Errorf("写入 ZIP 文件失败: %v", err)
}
// 重新写入数据库内容
if _, err := dbFile.Seek(0, 0); err != nil {
return nil, fmt.Errorf("重置文件指针失败: %v", err)
}
buffer := make([]byte, 1024*1024) // 1MB buffer
for {
n, err := dbFile.Read(buffer)
if n > 0 {
if _, err := dbWriter.Write(buffer[:n]); err != nil {
return nil, fmt.Errorf("写入数据库内容失败: %v", err)
}
}
if err != nil {
break
}
}
// 添加元数据文件
metaData := map[string]interface{}{
"backup_time": time.Now().Format("2006-01-02 15:04:05"),
"version": "1.0",
}
metaWriter, err := zipWriter.Create("metadata.json")
if err != nil {
return nil, fmt.Errorf("创建元数据文件失败: %v", err)
}
metaJSON, err := json.Marshal(metaData)
if err != nil {
return nil, fmt.Errorf("序列化元数据失败: %v", err)
}
if _, err := metaWriter.Write(metaJSON); err != nil {
return nil, fmt.Errorf("写入元数据失败: %v", err)
}
// 获取备份文件大小
fileInfo, err := zipFile.Stat()
if err != nil {
return nil, fmt.Errorf("获取备份文件信息失败: %v", err)
}
return &BackupResult{
BackupPath: backupPath,
FileName: backupFileName,
FileSize: fileInfo.Size(),
CreatedAt: time.Now().Format("2006-01-02 15:04:05"),
}, nil
}
// Restore 恢复数据
func (s *BackupService) Restore(backupPath string) error {
// 检查备份文件是否存在
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
return fmt.Errorf("备份文件不存在: %s", backupPath)
}
// 打开 ZIP 文件
zipReader, err := zip.OpenReader(backupPath)
if err != nil {
return fmt.Errorf("打开备份文件失败: %v", err)
}
defer zipReader.Close()
// 获取 SQLite 数据库路径
appDataDir, err := os.UserConfigDir()
if err != nil {
return fmt.Errorf("获取用户配置目录失败: %v", err)
}
dataDir := filepath.Join(appDataDir, "ssq-desk", "data")
if err := os.MkdirAll(dataDir, 0755); err != nil {
return fmt.Errorf("创建数据目录失败: %v", err)
}
dbPath := filepath.Join(dataDir, "ssq.db")
// 备份当前数据库(如果存在)
if _, err := os.Stat(dbPath); err == nil {
backupName := fmt.Sprintf("ssq.db.bak.%s", time.Now().Format("20060102-150405"))
backupPath := filepath.Join(dataDir, backupName)
if err := copyFile(dbPath, backupPath); err != nil {
return fmt.Errorf("备份当前数据库失败: %v", err)
}
}
// 查找数据库文件
var dbFile *zip.File
for _, file := range zipReader.File {
if file.Name == "ssq.db" {
dbFile = file
break
}
}
if dbFile == nil {
return fmt.Errorf("备份文件中未找到数据库文件")
}
// 解压数据库文件
rc, err := dbFile.Open()
if err != nil {
return fmt.Errorf("打开数据库文件失败: %v", err)
}
defer rc.Close()
// 创建新的数据库文件
newDBFile, err := os.Create(dbPath)
if err != nil {
return fmt.Errorf("创建数据库文件失败: %v", err)
}
defer newDBFile.Close()
// 复制数据
buffer := make([]byte, 1024*1024) // 1MB buffer
for {
n, err := rc.Read(buffer)
if n > 0 {
if _, err := newDBFile.Write(buffer[:n]); err != nil {
return fmt.Errorf("写入数据库文件失败: %v", err)
}
}
if err != nil {
break
}
}
// 重新初始化数据库连接
_, err = database.InitSQLite()
if err != nil {
return fmt.Errorf("重新初始化数据库失败: %v", err)
}
return nil
}
// ListBackups 列出所有备份文件
func (s *BackupService) ListBackups() ([]BackupResult, error) {
appDataDir, err := os.UserConfigDir()
if err != nil {
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
}
backupDir := filepath.Join(appDataDir, "ssq-desk", "backups")
// 检查备份目录是否存在
if _, err := os.Stat(backupDir); os.IsNotExist(err) {
return []BackupResult{}, nil
}
files, err := os.ReadDir(backupDir)
if err != nil {
return nil, fmt.Errorf("读取备份目录失败: %v", err)
}
var backups []BackupResult
for _, file := range files {
if filepath.Ext(file.Name()) == ".zip" {
filePath := filepath.Join(backupDir, file.Name())
fileInfo, err := file.Info()
if err != nil {
continue
}
backups = append(backups, BackupResult{
BackupPath: filePath,
FileName: file.Name(),
FileSize: fileInfo.Size(),
CreatedAt: fileInfo.ModTime().Format("2006-01-02 15:04:05"),
})
}
}
return backups, nil
}
// copyFile 复制文件
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = destFile.ReadFrom(sourceFile)
return err
}

View File

@@ -0,0 +1,281 @@
package service
import (
"archive/zip"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"ssq-desk/internal/storage/repository"
"time"
)
// PackageService 离线数据包服务
type PackageService struct {
sqliteRepo repository.SsqRepository
}
// NewPackageService 创建数据包服务
func NewPackageService() (*PackageService, error) {
repo, err := repository.NewSQLiteSsqRepository()
if err != nil {
return nil, err
}
return &PackageService{
sqliteRepo: repo,
}, nil
}
// PackageInfo 数据包信息
type PackageInfo struct {
Version string `json:"version"` // 数据包版本
TotalCount int `json:"total_count"` // 数据总数
LatestIssue string `json:"latest_issue"` // 最新期号
PackageSize int64 `json:"package_size"` // 包大小
DownloadURL string `json:"download_url"` // 下载地址
ReleaseDate string `json:"release_date"` // 发布日期
CheckSum string `json:"checksum"` // 校验和
}
// PackageDownloadResult 数据包下载结果
type PackageDownloadResult struct {
FilePath string `json:"file_path"`
FileSize int64 `json:"file_size"`
Duration string `json:"duration"` // 下载耗时
}
// ImportResult 导入结果
type ImportResult struct {
ImportedCount int `json:"imported_count"` // 导入数量
UpdatedCount int `json:"updated_count"` // 更新数量
ErrorCount int `json:"error_count"` // 错误数量
Duration string `json:"duration"` // 导入耗时
}
// DownloadPackage 下载数据包
func (s *PackageService) DownloadPackage(downloadURL string, progressCallback func(int64, int64)) (*PackageDownloadResult, error) {
startTime := time.Now()
// 创建下载目录
appDataDir, err := os.UserConfigDir()
if err != nil {
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
}
downloadDir := filepath.Join(appDataDir, "ssq-desk", "packages")
if err := os.MkdirAll(downloadDir, 0755); err != nil {
return nil, fmt.Errorf("创建下载目录失败: %v", err)
}
// 生成文件名
filename := filepath.Base(downloadURL)
if filename == "" || filename == "." {
filename = fmt.Sprintf("ssq-data-%s.zip", time.Now().Format("20060102-150405"))
}
filePath := filepath.Join(downloadDir, filename)
// 创建文件
out, err := os.Create(filePath)
if err != nil {
return nil, fmt.Errorf("创建文件失败: %v", err)
}
defer out.Close()
// 发起 HTTP 请求
resp, err := http.Get(downloadURL)
if err != nil {
return nil, fmt.Errorf("下载失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("下载失败: HTTP %d", resp.StatusCode)
}
// 获取文件大小
totalSize := resp.ContentLength
// 复制数据并显示进度
var written int64
buffer := make([]byte, 32*1024) // 32KB buffer
for {
nr, er := resp.Body.Read(buffer)
if nr > 0 {
nw, ew := out.Write(buffer[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = fmt.Errorf("无效写入结果")
}
}
written += int64(nw)
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
// 调用进度回调
if progressCallback != nil {
progressCallback(written, totalSize)
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
if err != nil {
os.Remove(filePath)
return nil, fmt.Errorf("写入文件失败: %v", err)
}
// 获取文件信息
fileInfo, err := out.Stat()
if err != nil {
return nil, fmt.Errorf("获取文件信息失败: %v", err)
}
duration := time.Since(startTime).String()
return &PackageDownloadResult{
FilePath: filePath,
FileSize: fileInfo.Size(),
Duration: duration,
}, nil
}
// ImportPackage 导入数据包
func (s *PackageService) ImportPackage(packagePath string) (*ImportResult, error) {
startTime := time.Now()
result := &ImportResult{}
// 检查文件是否存在
if _, err := os.Stat(packagePath); os.IsNotExist(err) {
return nil, fmt.Errorf("数据包文件不存在: %s", packagePath)
}
// 打开 ZIP 文件
zipReader, err := zip.OpenReader(packagePath)
if err != nil {
return nil, fmt.Errorf("打开数据包失败: %v", err)
}
defer zipReader.Close()
// 查找数据文件JSON 格式)
var dataFile *zip.File
for _, file := range zipReader.File {
if filepath.Ext(file.Name) == ".json" {
dataFile = file
break
}
}
if dataFile == nil {
return nil, fmt.Errorf("数据包中未找到 JSON 数据文件")
}
// 读取数据文件
rc, err := dataFile.Open()
if err != nil {
return nil, fmt.Errorf("打开数据文件失败: %v", err)
}
defer rc.Close()
// 解析 JSON
var histories []map[string]interface{}
decoder := json.NewDecoder(rc)
if err := decoder.Decode(&histories); err != nil {
return nil, fmt.Errorf("解析数据文件失败: %v", err)
}
// 导入数据
// TODO: 实际导入逻辑需要根据数据包格式实现
// 这里只是框架代码,需要将 JSON 数据转换为 SsqHistory 并插入数据库
for range histories {
// 转换为 SsqHistory 结构(这里需要根据实际数据格式转换)
// 简化处理:直接创建记录
// 实际应用中需要根据数据包格式解析
result.ImportedCount++
}
// TODO: 实际导入逻辑需要根据数据包格式实现
// 这里只是框架代码
result.Duration = time.Since(startTime).String()
return result, nil
}
// CheckPackageUpdate 检查数据包更新
func (s *PackageService) CheckPackageUpdate(remoteURL string) (*PackageInfo, error) {
// 发起 HTTP 请求获取数据包信息
resp, err := http.Get(remoteURL)
if err != nil {
return nil, fmt.Errorf("获取数据包信息失败: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取数据包信息失败: HTTP %d", resp.StatusCode)
}
// 解析 JSON
var info PackageInfo
decoder := json.NewDecoder(resp.Body)
if err := decoder.Decode(&info); err != nil {
return nil, fmt.Errorf("解析数据包信息失败: %v", err)
}
// 获取本地最新期号
localLatestIssue, err := s.sqliteRepo.GetLatestIssue()
if err != nil {
return nil, fmt.Errorf("获取本地最新期号失败: %v", err)
}
// 比较是否需要更新
// 如果远程最新期号大于本地,则需要更新
if info.LatestIssue > localLatestIssue {
return &info, nil
}
return nil, nil // 不需要更新
}
// ListLocalPackages 列出本地数据包
func (s *PackageService) ListLocalPackages() ([]string, error) {
appDataDir, err := os.UserConfigDir()
if err != nil {
return nil, fmt.Errorf("获取用户配置目录失败: %v", err)
}
packageDir := filepath.Join(appDataDir, "ssq-desk", "packages")
// 检查目录是否存在
if _, err := os.Stat(packageDir); os.IsNotExist(err) {
return []string{}, nil
}
files, err := os.ReadDir(packageDir)
if err != nil {
return nil, fmt.Errorf("读取数据包目录失败: %v", err)
}
var packages []string
for _, file := range files {
if filepath.Ext(file.Name()) == ".zip" {
packages = append(packages, filepath.Join(packageDir, file.Name()))
}
}
return packages, nil
}

View File

@@ -0,0 +1,162 @@
package service
import (
"fmt"
"ssq-desk/internal/storage/models"
"ssq-desk/internal/storage/repository"
)
// QueryService 查询服务
type QueryService struct {
repo repository.SsqRepository
}
// NewQueryService 创建查询服务
func NewQueryService(repo repository.SsqRepository) *QueryService {
return &QueryService{repo: repo}
}
// QueryRequest 查询请求
type QueryRequest struct {
RedBalls []int `json:"red_balls"` // 红球列表最多6个
BlueBall int `json:"blue_ball"` // 蓝球0表示不限制
BlueBallRange []int `json:"blue_ball_range"` // 蓝球筛选范围
}
// QueryResult 查询结果
type QueryResult struct {
Total int64 `json:"total"` // 总记录数
Summary []MatchSummary `json:"summary"` // 分类统计
Details []models.SsqHistory `json:"details"` // 详细记录
}
// MatchSummary 匹配统计
type MatchSummary struct {
Type string `json:"type"` // 匹配类型:如 "6红1蓝"
Count int `json:"count"` // 匹配数量
Histories []models.SsqHistory `json:"histories"` // 匹配的记录
}
// Query 执行查询
func (s *QueryService) Query(req QueryRequest) (*QueryResult, error) {
// 验证输入
if err := s.validateRequest(req); err != nil {
return nil, err
}
// 查询数据
histories, err := s.repo.FindByRedAndBlue(req.RedBalls, req.BlueBall, req.BlueBallRange)
if err != nil {
return nil, fmt.Errorf("查询失败: %v", err)
}
// 处理结果:分类统计
summary := s.calculateSummary(histories, req.RedBalls, req.BlueBall)
return &QueryResult{
Total: int64(len(histories)),
Summary: summary,
Details: histories,
}, nil
}
// validateRequest 验证请求参数
func (s *QueryService) validateRequest(req QueryRequest) error {
// 验证红球
if len(req.RedBalls) > 6 {
return fmt.Errorf("红球数量不能超过6个")
}
for _, ball := range req.RedBalls {
if ball < 1 || ball > 33 {
return fmt.Errorf("红球号码必须在1-33之间: %d", ball)
}
}
// 验证蓝球
if req.BlueBall > 0 {
if req.BlueBall < 1 || req.BlueBall > 16 {
return fmt.Errorf("蓝球号码必须在1-16之间: %d", req.BlueBall)
}
}
// 验证蓝球范围
for _, ball := range req.BlueBallRange {
if ball < 1 || ball > 16 {
return fmt.Errorf("蓝球筛选范围必须在1-16之间: %d", ball)
}
}
return nil
}
// calculateSummary 计算分类统计
func (s *QueryService) calculateSummary(histories []models.SsqHistory, redBalls []int, blueBall int) []MatchSummary {
if len(redBalls) == 0 {
return []MatchSummary{}
}
// 创建红球集合,用于快速查找
redBallSet := make(map[int]bool)
for _, ball := range redBalls {
redBallSet[ball] = true
}
// 初始化统计
typeCounts := make(map[string][]models.SsqHistory)
// 遍历每条记录,计算匹配度
for _, history := range histories {
// 统计匹配的红球数量
matchedRedCount := 0
if redBallSet[history.RedBall1] {
matchedRedCount++
}
if redBallSet[history.RedBall2] {
matchedRedCount++
}
if redBallSet[history.RedBall3] {
matchedRedCount++
}
if redBallSet[history.RedBall4] {
matchedRedCount++
}
if redBallSet[history.RedBall5] {
matchedRedCount++
}
if redBallSet[history.RedBall6] {
matchedRedCount++
}
// 判断蓝球是否匹配
blueMatched := false
if blueBall > 0 {
blueMatched = history.BlueBall == blueBall
} else {
blueMatched = true // 未指定蓝球时,视为匹配
}
// 生成类型键
typeKey := fmt.Sprintf("%d红", matchedRedCount)
if blueMatched {
typeKey += "1蓝"
}
typeCounts[typeKey] = append(typeCounts[typeKey], history)
}
// 转换为结果格式,按匹配度排序
summary := make([]MatchSummary, 0)
types := []string{"6红1蓝", "6红", "5红1蓝", "5红", "4红1蓝", "4红", "3红1蓝", "3红", "2红1蓝", "2红", "1红1蓝", "1红", "0红1蓝", "0红"}
for _, t := range types {
if histories, ok := typeCounts[t]; ok {
summary = append(summary, MatchSummary{
Type: t,
Count: len(histories),
Histories: histories,
})
}
}
return summary
}

View File

@@ -0,0 +1,148 @@
package service
import (
"fmt"
"ssq-desk/internal/storage/models"
"ssq-desk/internal/storage/repository"
)
// SyncService 数据同步服务
type SyncService struct {
mysqlRepo repository.SsqRepository
sqliteRepo repository.SsqRepository
}
// NewSyncService 创建同步服务
func NewSyncService(mysqlRepo, sqliteRepo repository.SsqRepository) *SyncService {
return &SyncService{
mysqlRepo: mysqlRepo,
sqliteRepo: sqliteRepo,
}
}
// SyncResult 同步结果
type SyncResult struct {
TotalCount int `json:"total_count"` // 远程数据总数
SyncedCount int `json:"synced_count"` // 已同步数量
NewCount int `json:"new_count"` // 新增数量
UpdatedCount int `json:"updated_count"` // 更新数量
ErrorCount int `json:"error_count"` // 错误数量
LatestIssue string `json:"latest_issue"` // 最新期号
}
// Sync 执行数据同步(增量同步)
func (s *SyncService) Sync() (*SyncResult, error) {
result := &SyncResult{}
// 获取本地最新期号
localLatestIssue, err := s.sqliteRepo.GetLatestIssue()
if err != nil {
return nil, fmt.Errorf("获取本地最新期号失败: %v", err)
}
// 获取远程所有数据
remoteHistories, err := s.mysqlRepo.FindAll()
if err != nil {
return nil, fmt.Errorf("获取远程数据失败: %v", err)
}
result.TotalCount = len(remoteHistories)
if len(remoteHistories) == 0 {
return result, nil
}
// 确定最新期号
if len(remoteHistories) > 0 {
result.LatestIssue = remoteHistories[0].IssueNumber
}
// 增量同步:只同步本地没有的数据
if localLatestIssue == "" {
// 首次同步,全量同步
for _, history := range remoteHistories {
if err := s.sqliteRepo.Create(&history); err != nil {
result.ErrorCount++
continue
}
result.NewCount++
result.SyncedCount++
}
} else {
// 增量同步:同步期号大于本地最新期号的数据
for _, history := range remoteHistories {
// 如果期号小于等于本地最新期号,跳过
if history.IssueNumber <= localLatestIssue {
continue
}
// 检查本地是否已存在(基于期号)
localHistory, err := s.sqliteRepo.FindByIssue(history.IssueNumber)
if err == nil && localHistory != nil {
// 已存在,检查是否需要更新
if s.needUpdate(localHistory, &history) {
// 更新逻辑(目前使用创建,如需更新可扩展)
result.UpdatedCount++
}
continue
}
// 新数据,插入
if err := s.sqliteRepo.Create(&history); err != nil {
result.ErrorCount++
continue
}
result.NewCount++
result.SyncedCount++
}
}
return result, nil
}
// needUpdate 判断是否需要更新
func (s *SyncService) needUpdate(local, remote *models.SsqHistory) bool {
// 简单比较:如果任何字段不同,则认为需要更新
return local.RedBall1 != remote.RedBall1 ||
local.RedBall2 != remote.RedBall2 ||
local.RedBall3 != remote.RedBall3 ||
local.RedBall4 != remote.RedBall4 ||
local.RedBall5 != remote.RedBall5 ||
local.RedBall6 != remote.RedBall6 ||
local.BlueBall != remote.BlueBall
}
// GetSyncStatus 获取同步状态
func (s *SyncService) GetSyncStatus() (map[string]interface{}, error) {
// 获取本地统计
localCount, err := s.sqliteRepo.Count()
if err != nil {
return nil, err
}
localLatestIssue, err := s.sqliteRepo.GetLatestIssue()
if err != nil {
return nil, err
}
// 获取远程统计
remoteCount, err := s.mysqlRepo.Count()
if err != nil {
// 远程连接失败不影响本地状态
remoteCount = 0
}
remoteLatestIssue := ""
remoteHistories, err := s.mysqlRepo.FindAll()
if err == nil && len(remoteHistories) > 0 {
remoteLatestIssue = remoteHistories[0].IssueNumber
}
return map[string]interface{}{
"local_count": localCount,
"local_latest_issue": localLatestIssue,
"remote_count": remoteCount,
"remote_latest_issue": remoteLatestIssue,
"need_sync": remoteLatestIssue > localLatestIssue,
}, nil
}

View File

@@ -0,0 +1,138 @@
package service
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"time"
)
// UpdateConfig 更新配置
type UpdateConfig struct {
CurrentVersion string `json:"current_version"`
LastCheckTime time.Time `json:"last_check_time"`
AutoCheckEnabled bool `json:"auto_check_enabled"`
CheckIntervalMinutes int `json:"check_interval_minutes"` // 检查间隔(分钟)
CheckURL string `json:"check_url,omitempty"` // 版本检查接口 URL
}
// GetUpdateConfigPath 获取更新配置文件路径
func GetUpdateConfigPath() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("获取用户目录失败: %v", err)
}
configDir := filepath.Join(homeDir, ".ssq-desk")
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", fmt.Errorf("创建配置目录失败: %v", err)
}
return filepath.Join(configDir, "update_config.json"), nil
}
// LoadUpdateConfig 加载更新配置
func LoadUpdateConfig() (*UpdateConfig, error) {
configPath, err := GetUpdateConfigPath()
if err != nil {
return nil, err
}
// 如果文件不存在,返回默认配置
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return &UpdateConfig{
CurrentVersion: GetCurrentVersion(),
LastCheckTime: time.Time{},
AutoCheckEnabled: true,
CheckIntervalMinutes: 1, // 默认1分钟检查一次
CheckURL: "https://img.1216.top/ssq/last-version.json", // 默认版本检查地址
}, nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %v", err)
}
// 先解析为 map 以支持兼容旧配置
var configMap map[string]interface{}
if err := json.Unmarshal(data, &configMap); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %v", err)
}
var config UpdateConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %v", err)
}
// 兼容旧配置:如果存在 check_interval_days转换为分钟
if config.CheckIntervalMinutes == 0 {
if days, ok := configMap["check_interval_days"].(float64); ok && days > 0 {
config.CheckIntervalMinutes = int(days * 24 * 60) // 转换为分钟
} else {
config.CheckIntervalMinutes = 1 // 默认1分钟
}
}
// 获取最新版本号
latestVersion := GetCurrentVersion()
// 如果当前版本为空或与最新版本不一致,更新为最新版本
if config.CurrentVersion == "" || config.CurrentVersion != latestVersion {
if config.CurrentVersion != "" {
log.Printf("[配置] 配置中的版本号 (%s) 与最新版本号 (%s) 不一致,更新配置", config.CurrentVersion, latestVersion)
}
config.CurrentVersion = latestVersion
// 注意:这里不自动保存,避免频繁写入文件,由调用方决定是否保存
}
// 如果检查地址为空,使用默认地址
if config.CheckURL == "" {
config.CheckURL = "https://img.1216.top/ssq/last-version.json"
}
return &config, nil
}
// SaveUpdateConfig 保存更新配置
func SaveUpdateConfig(config *UpdateConfig) error {
configPath, err := GetUpdateConfigPath()
if err != nil {
return err
}
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return fmt.Errorf("序列化配置失败: %v", err)
}
if err := os.WriteFile(configPath, data, 0644); err != nil {
return fmt.Errorf("写入配置文件失败: %v", err)
}
return nil
}
// ShouldCheckUpdate 判断是否应该检查更新
func (c *UpdateConfig) ShouldCheckUpdate() bool {
if !c.AutoCheckEnabled {
return false
}
// 如果从未检查过,应该检查
if c.LastCheckTime.IsZero() {
return true
}
// 检查是否超过间隔分钟数
minutesSinceLastCheck := time.Since(c.LastCheckTime).Minutes()
return minutesSinceLastCheck >= float64(c.CheckIntervalMinutes)
}
// UpdateLastCheckTime 更新最后检查时间
func (c *UpdateConfig) UpdateLastCheckTime() error {
c.LastCheckTime = time.Now()
return SaveUpdateConfig(c)
}

View File

@@ -0,0 +1,332 @@
package service
import (
"crypto/md5"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
)
// getRemoteFileSize 通过HEAD请求获取远程文件大小
func getRemoteFileSize(url string) (int64, error) {
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return 0, err
}
resp, err := client.Do(req)
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.ContentLength > 0 {
return resp.ContentLength, nil
}
return 0, fmt.Errorf("无法获取文件大小")
}
// normalizeProgress 标准化进度值到0-100之间
func normalizeProgress(progress float64) float64 {
if progress < 0 {
return 0
}
if progress > 100 {
return 100
}
return progress
}
// DownloadProgress 下载进度回调函数类型
type DownloadProgress func(progress float64, speed float64, downloaded int64, total int64)
// DownloadProgressInfo 下载进度信息
type DownloadProgressInfo struct {
Progress float64 `json:"progress"` // 进度百分比
Speed float64 `json:"speed"` // 下载速度(字节/秒)
Downloaded int64 `json:"downloaded"` // 已下载字节数
Total int64 `json:"total"` // 总字节数
Error string `json:"error,omitempty"` // 错误信息
Result *DownloadResult `json:"result,omitempty"` // 下载结果
}
// DownloadResult 下载结果
type DownloadResult struct {
FilePath string `json:"file_path"`
FileSize int64 `json:"file_size"`
MD5Hash string `json:"md5_hash,omitempty"`
SHA256Hash string `json:"sha256_hash,omitempty"`
}
// DownloadUpdate 下载更新包
func DownloadUpdate(downloadURL string, progressCallback DownloadProgress) (*DownloadResult, error) {
// 获取下载目录
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("获取用户目录失败: %v", err)
}
downloadDir := filepath.Join(homeDir, ".ssq-desk", "downloads")
if err := os.MkdirAll(downloadDir, 0755); err != nil {
return nil, fmt.Errorf("创建下载目录失败: %v", err)
}
// 从 URL 提取文件名
filename := filepath.Base(downloadURL)
if filename == "" || filename == "." {
filename = fmt.Sprintf("update-%d.exe", time.Now().Unix())
}
filePath := filepath.Join(downloadDir, filename)
// 检查文件是否已存在
var downloadedSize int64 = 0
if fileInfo, err := os.Stat(filePath); err == nil {
downloadedSize = fileInfo.Size()
// 检查是否已完整下载
if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && downloadedSize == remoteSize {
if progressCallback != nil {
progressCallback(100.0, 0, downloadedSize, remoteSize)
time.Sleep(50 * time.Millisecond)
}
md5Hash, sha256Hash, err := calculateFileHashes(filePath)
if err != nil {
return nil, fmt.Errorf("计算文件哈希失败: %v", err)
}
return &DownloadResult{
FilePath: filePath,
FileSize: downloadedSize,
MD5Hash: md5Hash,
SHA256Hash: sha256Hash,
}, nil
}
}
// 打开文件(支持断点续传)
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("创建文件失败: %v", err)
}
defer file.Close()
// 创建 HTTP 请求
req, err := http.NewRequest("GET", downloadURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %v", err)
}
// 如果已下载部分,设置 Range 头(断点续传)
if downloadedSize > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloadedSize))
}
// 发送请求
client := &http.Client{Timeout: 30 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("下载请求失败: %v", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
file.Close()
if fileInfo, err := os.Stat(filePath); err == nil {
if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && fileInfo.Size() == remoteSize {
if progressCallback != nil {
progressCallback(100.0, 0, fileInfo.Size(), remoteSize)
}
md5Hash, sha256Hash, err := calculateFileHashes(filePath)
if err != nil {
return nil, fmt.Errorf("计算文件哈希失败: %v", err)
}
return &DownloadResult{
FilePath: filePath,
FileSize: fileInfo.Size(),
MD5Hash: md5Hash,
SHA256Hash: sha256Hash,
}, nil
}
}
return nil, fmt.Errorf("服务器返回 416 错误,且文件可能不完整")
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode)
}
// 获取文件总大小
contentLength := resp.ContentLength
gotTotalFromRange := false
// 优先从 Content-Range 头获取总大小
if rangeHeader := resp.Header.Get("Content-Range"); rangeHeader != "" {
var start, end, total int64
if n, _ := fmt.Sscanf(rangeHeader, "bytes %d-%d/%d", &start, &end, &total); n == 3 && total > 0 {
contentLength = total
gotTotalFromRange = true
}
}
// 如果未获取到,尝试通过 HEAD 请求获取
if contentLength <= 0 && !gotTotalFromRange {
if remoteSize, err := getRemoteFileSize(downloadURL); err == nil {
contentLength = remoteSize
}
}
// 断点续传时,如果未从 Content-Range 获取,需要加上已下载部分
if resp.StatusCode == http.StatusPartialContent && downloadedSize > 0 && contentLength > 0 && !gotTotalFromRange {
if contentLength < downloadedSize {
contentLength += downloadedSize
}
}
// 发送初始进度事件
if progressCallback != nil {
time.Sleep(50 * time.Millisecond)
if contentLength > 0 && downloadedSize > 0 {
progress := normalizeProgress(float64(downloadedSize) / float64(contentLength) * 100)
progressCallback(progress, 0, downloadedSize, contentLength)
} else {
progressCallback(0, 0, downloadedSize, -1)
}
time.Sleep(20 * time.Millisecond)
}
// 下载文件
buffer := make([]byte, 32*1024) // 32KB 缓冲区
var totalDownloaded int64 = downloadedSize
startTime := time.Now()
lastProgressTime := startTime
lastProgressSize := totalDownloaded
for {
n, err := resp.Body.Read(buffer)
if n > 0 {
written, writeErr := file.Write(buffer[:n])
if writeErr != nil {
return nil, fmt.Errorf("写入文件失败: %v", writeErr)
}
totalDownloaded += int64(written)
// 计算进度和速度
now := time.Now()
elapsed := now.Sub(lastProgressTime).Seconds()
// 每 0.3 秒更新一次进度
if elapsed >= 0.3 {
progress := float64(0)
if contentLength > 0 {
progress = normalizeProgress(float64(totalDownloaded) / float64(contentLength) * 100)
}
speed := float64(0)
if elapsed > 0 {
speed = float64(totalDownloaded-lastProgressSize) / elapsed
}
if progressCallback != nil {
progressCallback(progress, speed, totalDownloaded, contentLength)
}
lastProgressTime = now
lastProgressSize = totalDownloaded
}
}
if err == io.EOF {
break
}
if err != nil {
return nil, fmt.Errorf("读取数据失败: %v", err)
}
}
// 最后一次进度更新
if progressCallback != nil {
if contentLength > 0 {
progressCallback(100.0, 0, totalDownloaded, contentLength)
} else {
progressCallback(100.0, 0, totalDownloaded, totalDownloaded)
}
}
file.Close()
md5Hash, sha256Hash, err := calculateFileHashes(filePath)
if err != nil {
return nil, fmt.Errorf("计算文件哈希失败: %v", err)
}
fileInfo, err := os.Stat(filePath)
if err != nil {
return nil, fmt.Errorf("获取文件信息失败: %v", err)
}
return &DownloadResult{
FilePath: filePath,
FileSize: fileInfo.Size(),
MD5Hash: md5Hash,
SHA256Hash: sha256Hash,
}, nil
}
// calculateFileHashes 计算文件的 MD5 和 SHA256 哈希值
func calculateFileHashes(filePath string) (string, string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", "", err
}
defer file.Close()
md5Hash := md5.New()
sha256Hash := sha256.New()
// 使用 MultiWriter 同时计算两个哈希
writer := io.MultiWriter(md5Hash, sha256Hash)
if _, err := io.Copy(writer, file); err != nil {
return "", "", err
}
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
return md5Sum, sha256Sum, nil
}
// VerifyFileHash 验证文件哈希值
func VerifyFileHash(filePath string, expectedHash string, hashType string) (bool, error) {
file, err := os.Open(filePath)
if err != nil {
return false, err
}
defer file.Close()
var hash []byte
var calculatedHash string
switch hashType {
case "md5":
md5Hash := md5.New()
if _, err := io.Copy(md5Hash, file); err != nil {
return false, err
}
hash = md5Hash.Sum(nil)
calculatedHash = hex.EncodeToString(hash)
case "sha256":
sha256Hash := sha256.New()
if _, err := io.Copy(sha256Hash, file); err != nil {
return false, err
}
hash = sha256Hash.Sum(nil)
calculatedHash = hex.EncodeToString(hash)
default:
return false, fmt.Errorf("不支持的哈希类型: %s", hashType)
}
return calculatedHash == expectedHash, nil
}

View File

@@ -0,0 +1,328 @@
package service
import (
"archive/zip"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"time"
)
// InstallResult 安装结果
type InstallResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
// InstallUpdate 安装更新包
func InstallUpdate(installerPath string, autoRestart bool) (*InstallResult, error) {
return InstallUpdateWithHash(installerPath, autoRestart, "", "")
}
// InstallUpdateWithHash 安装更新包(带哈希验证)
func InstallUpdateWithHash(installerPath string, autoRestart bool, expectedHash string, hashType string) (*InstallResult, error) {
if _, err := os.Stat(installerPath); os.IsNotExist(err) {
return nil, fmt.Errorf("安装文件不存在: %s", installerPath)
}
// 哈希验证
if expectedHash != "" && hashType != "" {
valid, err := VerifyFileHash(installerPath, expectedHash, hashType)
if err != nil {
return &InstallResult{Success: false, Message: "文件验证失败: " + err.Error()}, nil
}
if !valid {
return &InstallResult{Success: false, Message: "文件哈希值不匹配,文件可能已损坏或被篡改"}, nil
}
}
// 备份
backupPath, err := BackupApplication()
if err != nil {
return &InstallResult{Success: false, Message: fmt.Sprintf("备份失败: %v", err)}, nil
}
// 安装
ext := filepath.Ext(installerPath)
switch ext {
case ".exe":
if runtime.GOOS != "windows" {
return &InstallResult{Success: false, Message: "当前系统不是 Windows无法安装 .exe 文件"}, nil
}
err = installExe(installerPath)
case ".zip":
err = installZip(installerPath)
default:
return nil, fmt.Errorf("不支持的安装包格式: %s", ext)
}
// 处理安装结果
if err != nil {
// 安装失败,尝试回滚
if backupPath != "" {
_ = rollbackFromBackup(backupPath)
}
return &InstallResult{Success: false, Message: fmt.Sprintf("安装失败: %v", err)}, nil
}
// 自动重启
if autoRestart {
go func() {
time.Sleep(2 * time.Second)
restartApplication()
}()
}
return &InstallResult{Success: true, Message: "安装成功"}, nil
}
// getExecutablePath 获取当前可执行文件路径
func getExecutablePath() (string, error) {
execPath, err := os.Executable()
if err != nil {
return "", fmt.Errorf("获取可执行文件路径失败: %v", err)
}
return execPath, nil
}
// installExe 安装 exe 文件
func installExe(exePath string) error {
execPath, err := getExecutablePath()
if err != nil {
return err
}
return replaceExecutableFile(exePath, execPath)
}
// installZip 安装 ZIP 压缩包
func installZip(zipPath string) error {
zipReader, err := zip.OpenReader(zipPath)
if err != nil {
return fmt.Errorf("打开 ZIP 文件失败: %v", err)
}
defer zipReader.Close()
execPath, err := getExecutablePath()
if err != nil {
return err
}
execDir := filepath.Dir(execPath)
execName := filepath.Base(execPath)
// 解压到临时目录
tempDir := filepath.Join(execDir, ".update-temp")
if err := os.MkdirAll(tempDir, 0755); err != nil {
return fmt.Errorf("创建临时目录失败: %v", err)
}
defer os.RemoveAll(tempDir)
// 解压文件
for _, file := range zipReader.File {
filePath := filepath.Join(tempDir, file.Name)
if file.FileInfo().IsDir() {
os.MkdirAll(filePath, file.FileInfo().Mode())
continue
}
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
return fmt.Errorf("创建目录失败: %v", err)
}
rc, err := file.Open()
if err != nil {
return fmt.Errorf("打开 ZIP 文件项失败: %v", err)
}
targetFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.FileInfo().Mode())
if err != nil {
rc.Close()
return fmt.Errorf("创建目标文件失败: %v", err)
}
if _, err := io.Copy(targetFile, rc); err != nil {
targetFile.Close()
rc.Close()
return fmt.Errorf("复制文件失败: %v", err)
}
targetFile.Close()
rc.Close()
}
// 查找新的可执行文件
newExecPath := filepath.Join(tempDir, execName)
if _, err := os.Stat(newExecPath); os.IsNotExist(err) {
files, _ := os.ReadDir(tempDir)
for _, f := range files {
if !f.IsDir() {
newExecPath = filepath.Join(tempDir, f.Name())
break
}
}
}
// 替换文件(使用与 installExe 相同的逻辑)
return replaceExecutableFile(newExecPath, execPath)
}
// replaceExecutableFile 替换可执行文件Windows 和 Unix 通用逻辑)
func replaceExecutableFile(newFilePath, execPath string) error {
if runtime.GOOS == "windows" {
return replaceExecutableFileWindows(newFilePath, execPath)
}
// Unix-like: 直接替换
if err := copyFile(newFilePath, execPath); err != nil {
return fmt.Errorf("复制文件失败: %v", err)
}
return os.Chmod(execPath, 0755)
}
// replaceExecutableFileWindows Windows 平台替换可执行文件
func replaceExecutableFileWindows(newFilePath, execPath string) error {
oldExecPath := execPath + ".old"
newExecPathTemp := execPath + ".new"
// 清理旧文件
os.Remove(oldExecPath)
os.Remove(newExecPathTemp)
// 复制新文件到临时位置
if err := copyFile(newFilePath, newExecPathTemp); err != nil {
return fmt.Errorf("复制新文件失败: %v", err)
}
// 尝试重命名当前文件(如果失败,说明文件正在使用)
if err := os.Rename(execPath, oldExecPath); err != nil {
return nil
}
// 替换文件
if err := os.Rename(newExecPathTemp, execPath); err != nil {
os.Rename(oldExecPath, execPath) // 恢复
return fmt.Errorf("替换文件失败: %v", err)
}
// 延迟删除旧文件
go func() {
time.Sleep(10 * time.Second)
os.Remove(oldExecPath)
}()
return nil
}
// restartApplication 重启应用
func restartApplication() {
execPath, err := getExecutablePath()
if err != nil {
return
}
currentPID := os.Getpid()
if runtime.GOOS != "windows" {
return
}
replacePendingFile(execPath)
// 创建并执行重启脚本
tempDir := os.TempDir()
batFile := filepath.Join(tempDir, fmt.Sprintf("restart_%d.bat", currentPID))
execDir := filepath.Dir(execPath)
batContent := fmt.Sprintf(`@echo off
cd /d "%s"
start "" "%s"
timeout /t 3 /nobreak >nul
taskkill /PID %d /F >nul 2>&1
del "%%~f0"
`, execDir, execPath, currentPID)
if err := os.WriteFile(batFile, []byte(batContent), 0644); err != nil {
fallbackRestart(execPath)
return
}
cmd := exec.Command("cmd", "/C", batFile)
cmd.Stdout = nil
cmd.Stderr = nil
if err := cmd.Start(); err != nil {
fallbackRestart(execPath)
return
}
os.Exit(0)
}
// fallbackRestart 降级重启方案
func fallbackRestart(execPath string) {
exec.Command("cmd", "/C", "start", "", execPath).Start()
time.Sleep(2 * time.Second)
os.Exit(0)
}
// replacePendingFile 替换待替换的文件(.new -> 可执行文件)
func replacePendingFile(execPath string) error {
newExecPathTemp := execPath + ".new"
if _, err := os.Stat(newExecPathTemp); os.IsNotExist(err) {
return nil // 没有待替换文件
}
oldExecPath := execPath + ".old"
os.Remove(oldExecPath)
if err := os.Rename(newExecPathTemp, execPath); err != nil {
return fmt.Errorf("文件替换失败: %v", err)
}
return nil
}
// rollbackFromBackup 从备份恢复
func rollbackFromBackup(backupPath string) error {
execPath, err := getExecutablePath()
if err != nil {
return err
}
return copyFile(backupPath, execPath)
}
// BackupApplication 备份当前应用
func BackupApplication() (string, error) {
execPath, err := getExecutablePath()
if err != nil {
return "", err
}
homeDir, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("获取用户目录失败: %v", err)
}
backupDir := filepath.Join(homeDir, ".ssq-desk", "backups")
if err := os.MkdirAll(backupDir, 0755); err != nil {
return "", fmt.Errorf("创建备份目录失败: %v", err)
}
timestamp := time.Now().Format("20060102-150405")
backupFileName := fmt.Sprintf("ssq-desk-backup-%s%s", timestamp, filepath.Ext(execPath))
backupPath := filepath.Join(backupDir, backupFileName)
sourceFile, err := os.Open(execPath)
if err != nil {
return "", fmt.Errorf("打开源文件失败: %v", err)
}
defer sourceFile.Close()
backupFile, err := os.Create(backupPath)
if err != nil {
return "", fmt.Errorf("创建备份文件失败: %v", err)
}
defer backupFile.Close()
if _, err := backupFile.ReadFrom(sourceFile); err != nil {
return "", fmt.Errorf("复制文件失败: %v", err)
}
return backupPath, nil
}

View File

@@ -0,0 +1,168 @@
package service
import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
)
// UpdateService 更新服务
type UpdateService struct {
checkURL string // 版本检查接口 URL
}
// NewUpdateService 创建更新服务
func NewUpdateService(checkURL string) *UpdateService {
return &UpdateService{
checkURL: checkURL,
}
}
// RemoteVersionInfo 远程版本信息
type RemoteVersionInfo struct {
Version string `json:"version"`
DownloadURL string `json:"download_url"`
Changelog string `json:"changelog"`
ForceUpdate bool `json:"force_update"`
ReleaseDate string `json:"release_date"`
}
// CheckUpdate 检查更新
func (s *UpdateService) CheckUpdate() (*UpdateCheckResult, error) {
log.Printf("[更新检查] 开始检查更新,检查地址: %s", s.checkURL)
// 加载配置
config, err := LoadUpdateConfig()
if err != nil {
log.Printf("[更新检查] 加载配置失败: %v", err)
return nil, fmt.Errorf("加载配置失败: %v", err)
}
// 获取当前版本(优先使用服务获取的最新版本号,而不是配置中可能过期的版本号)
currentVersionStr := GetCurrentVersion()
if currentVersionStr == "" {
// 如果服务获取失败,回退到配置中的版本号
currentVersionStr = config.CurrentVersion
log.Printf("[更新检查] 使用配置中的版本号: %s", currentVersionStr)
} else {
log.Printf("[更新检查] 使用服务获取的版本号: %s", currentVersionStr)
// 如果配置中的版本号不一致,更新配置
if config.CurrentVersion != currentVersionStr {
log.Printf("[更新检查] 配置中的版本号 (%s) 与当前版本号 (%s) 不一致,更新配置", config.CurrentVersion, currentVersionStr)
config.CurrentVersion = currentVersionStr
if err := SaveUpdateConfig(config); err != nil {
log.Printf("[更新检查] 更新配置失败: %v", err)
}
}
}
currentVersion, err := ParseVersion(currentVersionStr)
if err != nil {
log.Printf("[更新检查] 解析当前版本失败: %v", err)
return nil, fmt.Errorf("解析当前版本失败: %v", err)
}
// 请求远程版本信息
remoteInfo, err := s.fetchRemoteVersionInfo()
if err != nil {
log.Printf("[更新检查] 获取远程版本信息失败: %v", err)
return nil, fmt.Errorf("获取远程版本信息失败: %v", err)
}
log.Printf("[更新检查] 远程版本信息: 版本=%s, 下载地址=%s, 强制更新=%v",
remoteInfo.Version, remoteInfo.DownloadURL, remoteInfo.ForceUpdate)
// 解析远程版本号
remoteVersion, err := ParseVersion(remoteInfo.Version)
if err != nil {
log.Printf("[更新检查] 解析远程版本号失败: %v", err)
return nil, fmt.Errorf("解析远程版本号失败: %v", err)
}
// 比较版本
hasUpdate := remoteVersion.IsNewerThan(currentVersion)
log.Printf("[更新检查] 版本比较: 当前=%s, 远程=%s, 有更新=%v",
currentVersion.String(), remoteVersion.String(), hasUpdate)
// 更新最后检查时间
config.UpdateLastCheckTime()
result := &UpdateCheckResult{
HasUpdate: hasUpdate,
CurrentVersion: currentVersionStr,
LatestVersion: remoteInfo.Version,
DownloadURL: remoteInfo.DownloadURL,
Changelog: remoteInfo.Changelog,
ForceUpdate: remoteInfo.ForceUpdate,
ReleaseDate: remoteInfo.ReleaseDate,
}
log.Printf("[更新检查] 检查完成: 有更新=%v", hasUpdate)
return result, nil
}
// fetchRemoteVersionInfo 获取远程版本信息
func (s *UpdateService) fetchRemoteVersionInfo() (*RemoteVersionInfo, error) {
if s.checkURL == "" {
log.Printf("[远程版本] 版本检查 URL 未配置")
return nil, fmt.Errorf("版本检查 URL 未配置,请先设置检查地址")
}
log.Printf("[远程版本] 请求远程版本信息: %s", s.checkURL)
// 创建 HTTP 客户端,设置超时
client := &http.Client{
Timeout: 10 * time.Second,
}
// 发送请求
resp, err := client.Get(s.checkURL)
if err != nil {
log.Printf("[远程版本] 网络请求失败: %v", err)
return nil, fmt.Errorf("网络请求失败: %v", err)
}
defer resp.Body.Close()
log.Printf("[远程版本] HTTP 响应状态码: %d", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode)
}
// 读取响应
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("[远程版本] 读取响应失败: %v", err)
return nil, fmt.Errorf("读取响应失败: %v", err)
}
log.Printf("[远程版本] 响应内容长度: %d 字节", len(body))
// 解析 JSON
var remoteInfo RemoteVersionInfo
if err := json.Unmarshal(body, &remoteInfo); err != nil {
log.Printf("[远程版本] 解析 JSON 失败: %v, 响应内容: %s", err, string(body))
return nil, fmt.Errorf("解析响应失败: %v", err)
}
if remoteInfo.Version == "" {
log.Printf("[远程版本] 远程版本信息不完整,响应内容: %s", string(body))
return nil, fmt.Errorf("远程版本信息不完整")
}
log.Printf("[远程版本] 成功获取远程版本信息: %+v", remoteInfo)
return &remoteInfo, nil
}
// UpdateCheckResult 更新检查结果
type UpdateCheckResult struct {
HasUpdate bool `json:"has_update"`
CurrentVersion string `json:"current_version"`
LatestVersion string `json:"latest_version"`
DownloadURL string `json:"download_url"`
Changelog string `json:"changelog"`
ForceUpdate bool `json:"force_update"`
ReleaseDate string `json:"release_date"`
}

159
internal/service/version.go Normal file
View File

@@ -0,0 +1,159 @@
package service
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strconv"
"strings"
)
// ==================== 常量定义 ====================
// AppVersion 应用版本号(发布时直接修改此处)
const AppVersion = "0.1.1"
// ==================== 类型定义 ====================
// Version 版本号结构
type Version struct {
Major int
Minor int
Patch int
}
// WailsConfig Wails 配置文件结构
type WailsConfig struct {
Version string `json:"version"`
}
// ==================== 版本号解析和比较 ====================
// ParseVersion 解析版本号字符串(支持 v1.0.0 或 1.0.0 格式)
func ParseVersion(versionStr string) (*Version, error) {
versionStr = strings.TrimPrefix(versionStr, "v")
parts := strings.Split(versionStr, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("版本号格式错误,应为 x.y.z 格式")
}
major, err := strconv.Atoi(parts[0])
if err != nil {
return nil, fmt.Errorf("主版本号解析失败: %v", err)
}
minor, err := strconv.Atoi(parts[1])
if err != nil {
return nil, fmt.Errorf("次版本号解析失败: %v", err)
}
patch, err := strconv.Atoi(parts[2])
if err != nil {
return nil, fmt.Errorf("修订号解析失败: %v", err)
}
return &Version{Major: major, Minor: minor, Patch: patch}, nil
}
// String 返回版本号字符串格式v1.0.0
func (v *Version) String() string {
return fmt.Sprintf("v%d.%d.%d", v.Major, v.Minor, v.Patch)
}
// Compare 比较版本号
// 返回值:-1 表示当前版本小于目标版本0 表示相等1 表示大于
func (v *Version) Compare(other *Version) int {
if v.Major != other.Major {
if v.Major < other.Major {
return -1
}
return 1
}
if v.Minor != other.Minor {
if v.Minor < other.Minor {
return -1
}
return 1
}
if v.Patch != other.Patch {
if v.Patch < other.Patch {
return -1
}
return 1
}
return 0
}
// IsNewerThan 判断是否比目标版本新
func (v *Version) IsNewerThan(other *Version) bool {
return v.Compare(other) > 0
}
// IsOlderThan 判断是否比目标版本旧
func (v *Version) IsOlderThan(other *Version) bool {
return v.Compare(other) < 0
}
// ==================== 版本号获取 ====================
// GetCurrentVersion 获取当前版本号
// 优先级:硬编码版本号 > wails.json开发模式> 默认值
func GetCurrentVersion() string {
if AppVersion != "" {
log.Printf("[版本] 使用硬编码版本号: %s", AppVersion)
return AppVersion
}
version := getVersionFromWailsJSON()
if version != "" {
log.Printf("[版本] 从 wails.json 获取版本号: %s", version)
return version
}
log.Printf("[版本] 使用默认版本号: 0.0.1")
return "0.0.1"
}
// ==================== 配置文件读取 ====================
// getVersionFromWailsJSON 从 wails.json 读取版本号(仅开发模式使用)
func getVersionFromWailsJSON() string {
wd, err := os.Getwd()
if err != nil {
return ""
}
// 尝试当前目录
if version := readVersionFromFile(filepath.Join(wd, "wails.json")); version != "" {
return version
}
// 尝试父目录
if version := readVersionFromFile(filepath.Join(filepath.Dir(wd), "wails.json")); version != "" {
return version
}
return ""
}
// readVersionFromFile 从指定文件读取版本号
func readVersionFromFile(filePath string) string {
data, err := os.ReadFile(filePath)
if err != nil {
log.Printf("[版本] 读取文件失败: %s, 错误: %v", filePath, err)
return ""
}
var config WailsConfig
if err := json.Unmarshal(data, &config); err != nil {
log.Printf("[版本] 解析 JSON 失败: %s, 错误: %v", filePath, err)
return ""
}
if config.Version != "" {
log.Printf("[版本] 从文件读取版本号: %s -> %s", filePath, config.Version)
}
return config.Version
}