package service

import (
	"crypto/md5"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"hash"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"
)

// downloadProgressInterval is the minimum time between two intermediate
// progress callbacks emitted while the body is being transferred.
const downloadProgressInterval = 300 * time.Millisecond

// getRemoteFileSize returns the size of the resource at rawURL, as reported by
// the Content-Length header of a HEAD request (10 s timeout).
//
// It returns an error when the request fails, when the server answers with a
// non-2xx status (an error page's Content-Length is not the package size), or
// when no positive Content-Length is reported.
func getRemoteFileSize(rawURL string) (int64, error) {
	client := &http.Client{Timeout: 10 * time.Second}
	req, err := http.NewRequest(http.MethodHead, rawURL, nil)
	if err != nil {
		return 0, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return 0, fmt.Errorf("无法获取文件大小, 状态码: %d", resp.StatusCode)
	}
	if resp.ContentLength > 0 {
		return resp.ContentLength, nil
	}
	return 0, fmt.Errorf("无法获取文件大小")
}

// normalizeProgress clamps a progress percentage to the range [0, 100].
func normalizeProgress(progress float64) float64 {
	if progress < 0 {
		return 0
	}
	if progress > 100 {
		return 100
	}
	return progress
}

// DownloadProgress is the progress callback used by DownloadUpdate. It receives
// the progress percentage (0-100), the transfer speed in bytes per second, the
// number of bytes on disk so far, and the total size in bytes (-1 when the
// total size is unknown).
type DownloadProgress func(progress float64, speed float64, downloaded int64, total int64)

// DownloadProgressInfo describes download progress in a JSON-serializable form.
type DownloadProgressInfo struct {
	Progress   float64         `json:"progress"`         // Progress percentage.
	Speed      float64         `json:"speed"`            // Download speed in bytes/second.
	Downloaded int64           `json:"downloaded"`       // Bytes downloaded so far.
	Total      int64           `json:"total"`            // Total size in bytes.
	Error      string          `json:"error,omitempty"`  // Error message, if any.
	Result     *DownloadResult `json:"result,omitempty"` // Download result, once available.
}

// DownloadResult describes a completed download.
type DownloadResult struct {
	FilePath   string `json:"file_path"`
	FileSize   int64  `json:"file_size"`
	MD5Hash    string `json:"md5_hash,omitempty"`
	SHA256Hash string `json:"sha256_hash,omitempty"`
}

// DownloadUpdate downloads the update package at downloadURL into
// <home>/.ssq-desk/downloads and returns its path, size and MD5/SHA-256 hashes.
//
// The local file name is the last path segment of the URL (query string and
// fragment are ignored; see downloadFileName). A previously downloaded file is
// handled as follows:
//   - If its size equals the size reported by a HEAD request, it is reused
//     without downloading anything.
//   - If it is larger than the remote file, it cannot be a prefix of it, so the
//     download restarts from zero.
//   - Otherwise the download resumes with a Range request. If the server
//     ignores the Range header and answers 200, the file is rewritten from
//     scratch instead of the full body being appended to the partial data.
//
// progressCallback may be nil. When set, it is called once before the
// transfer, at most every downloadProgressInterval during it, and once at the
// end with progress 100.
func DownloadUpdate(downloadURL string, progressCallback DownloadProgress) (*DownloadResult, error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, fmt.Errorf("获取用户目录失败: %w", err)
	}
	downloadDir := filepath.Join(homeDir, ".ssq-desk", "downloads")
	if err := os.MkdirAll(downloadDir, 0755); err != nil {
		return nil, fmt.Errorf("创建下载目录失败: %w", err)
	}
	filePath := filepath.Join(downloadDir, downloadFileName(downloadURL))

	// Inspect what an earlier attempt left on disk.
	var downloadedSize int64
	if fileInfo, err := os.Stat(filePath); err == nil {
		downloadedSize = fileInfo.Size()
		if remoteSize, err := getRemoteFileSize(downloadURL); err == nil {
			switch {
			case downloadedSize == remoteSize:
				// Already complete.
				if progressCallback != nil {
					progressCallback(100.0, 0, downloadedSize, remoteSize)
					time.Sleep(50 * time.Millisecond)
				}
				return downloadResultFor(filePath)
			case downloadedSize > remoteSize:
				// Resuming would only yield 416; start over (the file is
				// truncated when it is opened below).
				downloadedSize = 0
			}
		}
	}

	req, err := http.NewRequest(http.MethodGet, downloadURL, nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	if downloadedSize > 0 {
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloadedSize))
	}

	client := &http.Client{Timeout: 30 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("下载请求失败: %w", err)
	}
	defer resp.Body.Close()

	total := int64(-1)
	switch resp.StatusCode {
	case http.StatusRequestedRangeNotSatisfiable:
		return resultAfterRangeNotSatisfiable(filePath, downloadURL, progressCallback)
	case http.StatusOK:
		// The body is the whole file, whether or not Range was sent, so any
		// partial data on disk must be discarded rather than appended to.
		downloadedSize = 0
	case http.StatusPartialContent:
		if header := resp.Header.Get("Content-Range"); header != "" {
			if start, size, ok := parseContentRange(header); ok {
				// Appending a range that does not start right after the local
				// data would corrupt the file.
				if start != downloadedSize {
					return nil, fmt.Errorf("服务器返回的数据起点 %d 与本地已下载大小 %d 不一致", start, downloadedSize)
				}
				total = size
			}
		}
	default:
		return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode)
	}

	// Without a complete length from Content-Range, the body length counts
	// only the bytes after downloadedSize; the HEAD size is the last resort.
	if total <= 0 && resp.ContentLength > 0 {
		total = downloadedSize + resp.ContentLength
	}
	if total <= 0 {
		if remoteSize, err := getRemoteFileSize(downloadURL); err == nil {
			total = remoteSize
		}
	}

	flags := os.O_CREATE | os.O_WRONLY | os.O_TRUNC
	if downloadedSize > 0 {
		flags = os.O_CREATE | os.O_WRONLY | os.O_APPEND
	}
	file, err := os.OpenFile(filePath, flags, 0644)
	if err != nil {
		return nil, fmt.Errorf("创建文件失败: %w", err)
	}
	defer file.Close()

	if progressCallback != nil {
		// The pauses around the first event are kept from the previous
		// implementation; presumably they give the consumer time to set up —
		// TODO confirm they are still needed.
		time.Sleep(50 * time.Millisecond)
		if total > 0 && downloadedSize > 0 {
			progress := normalizeProgress(float64(downloadedSize) / float64(total) * 100)
			progressCallback(progress, 0, downloadedSize, total)
		} else {
			progressCallback(0, 0, downloadedSize, total)
		}
		time.Sleep(20 * time.Millisecond)
	}

	done, err := copyWithProgress(file, resp.Body, downloadedSize, total, progressCallback)
	if err != nil {
		return nil, err
	}

	if progressCallback != nil {
		if total > 0 {
			progressCallback(100.0, 0, done, total)
		} else {
			progressCallback(100.0, 0, done, done)
		}
	}

	// Close explicitly so that a failed final flush is reported; the deferred
	// Close then becomes a harmless no-op.
	if err := file.Close(); err != nil {
		return nil, fmt.Errorf("关闭文件失败: %w", err)
	}
	return downloadResultFor(filePath)
}

// downloadFileName derives the local file name from the last path segment of
// downloadURL, ignoring any query string or fragment. It falls back to
// "update-<unix time>.exe" when the URL cannot be parsed or has no usable last
// segment (empty, ".", "..", or one containing a path separator), so the name
// can never point outside the download directory.
func downloadFileName(downloadURL string) string {
	var name string
	if u, err := url.Parse(downloadURL); err == nil {
		name = filepath.Base(u.Path)
	}
	if name == "" || name == "." || name == ".." || strings.ContainsAny(name, `/\`) {
		return fmt.Sprintf("update-%d.exe", time.Now().Unix())
	}
	return name
}

// parseContentRange parses a "bytes start-end/total" Content-Range value.
// total is -1 when the server reports an unknown complete length ("*");
// ok is false when the value cannot be parsed.
func parseContentRange(header string) (start, total int64, ok bool) {
	var end, size int64
	n, _ := fmt.Sscanf(header, "bytes %d-%d/%d", &start, &end, &size)
	if n < 2 {
		return 0, -1, false
	}
	if n == 3 && size > 0 {
		return start, size, true
	}
	return start, -1, true
}

// copyWithProgress streams src into dst through a 32 KiB buffer. offset is the
// number of bytes already on disk before this transfer, and total is the
// expected final size (<= 0 if unknown). If cb is non-nil, it is called at most
// every downloadProgressInterval with the current progress and the speed
// measured since the previous report. It returns the number of bytes on disk
// afterwards (offset plus the bytes copied).
func copyWithProgress(dst io.Writer, src io.Reader, offset, total int64, cb DownloadProgress) (int64, error) {
	buf := make([]byte, 32*1024)
	done := offset
	lastTime := time.Now()
	lastDone := done

	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			written, err := dst.Write(buf[:n])
			if err != nil {
				return done, fmt.Errorf("写入文件失败: %w", err)
			}
			done += int64(written)

			if now := time.Now(); now.Sub(lastTime) >= downloadProgressInterval {
				if cb != nil {
					progress := 0.0
					if total > 0 {
						progress = normalizeProgress(float64(done) / float64(total) * 100)
					}
					speed := float64(done-lastDone) / now.Sub(lastTime).Seconds()
					cb(progress, speed, done, total)
				}
				lastTime, lastDone = now, done
			}
		}
		if readErr == io.EOF {
			return done, nil
		}
		if readErr != nil {
			return done, fmt.Errorf("读取数据失败: %w", readErr)
		}
	}
}

// resultAfterRangeNotSatisfiable handles a 416 reply to the download request.
// A resume request that starts at the local file size gets 416 when the local
// file is already complete. This is accepted only if the local size matches the
// size reported by a HEAD request; anything else is an error.
func resultAfterRangeNotSatisfiable(filePath, downloadURL string, cb DownloadProgress) (*DownloadResult, error) {
	if fileInfo, err := os.Stat(filePath); err == nil {
		if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && fileInfo.Size() == remoteSize {
			if cb != nil {
				cb(100.0, 0, fileInfo.Size(), remoteSize)
			}
			return downloadResultFor(filePath)
		}
	}
	return nil, fmt.Errorf("服务器返回 416 错误,且文件可能不完整")
}

// downloadResultFor hashes the finished file at filePath and returns its
// DownloadResult (path, size on disk, MD5 and SHA-256 hex digests).
func downloadResultFor(filePath string) (*DownloadResult, error) {
	md5Hash, sha256Hash, err := calculateFileHashes(filePath)
	if err != nil {
		return nil, fmt.Errorf("计算文件哈希失败: %w", err)
	}
	fileInfo, err := os.Stat(filePath)
	if err != nil {
		return nil, fmt.Errorf("获取文件信息失败: %w", err)
	}
	return &DownloadResult{
		FilePath:   filePath,
		FileSize:   fileInfo.Size(),
		MD5Hash:    md5Hash,
		SHA256Hash: sha256Hash,
	}, nil
}

// calculateFileHashes returns the lowercase hex MD5 and SHA-256 digests of the
// file at filePath, computed in a single pass over its contents.
func calculateFileHashes(filePath string) (string, string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return "", "", err
	}
	defer file.Close()

	md5Hash := md5.New()
	sha256Hash := sha256.New()
	// MultiWriter feeds both hashes from one read of the file.
	if _, err := io.Copy(io.MultiWriter(md5Hash, sha256Hash), file); err != nil {
		return "", "", err
	}
	return hex.EncodeToString(md5Hash.Sum(nil)), hex.EncodeToString(sha256Hash.Sum(nil)), nil
}

// VerifyFileHash reports whether the file at filePath has the digest
// expectedHash.
//
// hashType selects the algorithm, "md5" or "sha256", matched case-insensitively.
// expectedHash is compared as hex, ignoring letter case and surrounding
// whitespace. An unsupported hashType or an I/O failure returns an error.
func VerifyFileHash(filePath string, expectedHash string, hashType string) (bool, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return false, err
	}
	defer file.Close()

	var h hash.Hash
	switch strings.ToLower(hashType) {
	case "md5":
		h = md5.New()
	case "sha256":
		h = sha256.New()
	default:
		return false, fmt.Errorf("不支持的哈希类型: %s", hashType)
	}
	if _, err := io.Copy(h, file); err != nil {
		return false, err
	}
	calculatedHash := hex.EncodeToString(h.Sum(nil))
	return strings.EqualFold(calculatedHash, strings.TrimSpace(expectedHash)), nil
}