Private
Public Access
1
0
Files
u-desk/internal/service/update_download.go
绝尘 eb2cbad17b 优化:代码质量提升,修复重复逻辑和语法高亮支持
- 简化计算属性,删除重复代码
- 优化文件扩展名获取逻辑
- 新增文件工具函数库 fileHelpers.js
- 增强 CodeEditor 语法高亮(支持 30+ 语言)
- 修复 Office 文档文件服务器访问权限
- 添加特殊文件名支持(Dockerfile、Makefile 等)
2026-01-30 02:29:51 +08:00

342 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
	"crypto/md5"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"hash"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"u-desk/internal/common"
)
// ==================== Type definitions ====================

type (
	// DownloadProgress is the callback signature used to report download
	// progress: the completion percentage, the current transfer speed, the
	// number of bytes downloaded so far and the total size in bytes.
	DownloadProgress func(progress, speed float64, downloaded, total int64)

	// DownloadResult describes a finished download: where the file was
	// written, how large it is, and its hex-encoded MD5 / SHA-256 digests.
	DownloadResult struct {
		FilePath   string `json:"file_path"`
		FileSize   int64  `json:"file_size"`
		MD5Hash    string `json:"md5_hash,omitempty"`
		SHA256Hash string `json:"sha256_hash,omitempty"`
	}
)
// ==================== Download update ====================

// DownloadUpdate downloads the update package at downloadURL into the
// "downloads" directory under common.GetUserDataDir() and returns the local
// file path, its size and its MD5/SHA-256 hex digests.
//
// Resume behaviour: an existing file with the same name is treated as a
// partial download and a "Range: bytes=<size>-" request is sent. If the
// existing file already matches the size reported by a HEAD request it is
// returned without downloading again. If the server ignores the Range header
// and answers 200 OK, the partial file is truncated and the download restarts
// from byte 0, because appending a full body to a partial file would corrupt it.
//
// progressCallback may be nil. When set it receives the progress percentage
// (clamped by normalizeProgress), the speed in bytes per second, the bytes
// written so far and the total size; the initial event reports a total of -1
// when the total size is unknown.
func DownloadUpdate(downloadURL string, progressCallback DownloadProgress) (*DownloadResult, error) {
	log.Printf("[下载] 开始下载URL: %s", downloadURL)

	downloadDir := filepath.Join(common.GetUserDataDir(), "downloads")
	if err := os.MkdirAll(downloadDir, 0755); err != nil {
		return nil, fmt.Errorf("创建下载目录失败: %w", err)
	}
	filePath := filepath.Join(downloadDir, updateFileNameFromURL(downloadURL))

	// An existing file is either a finished download or a partial one to resume.
	var downloadedSize int64
	if fileInfo, err := os.Stat(filePath); err == nil {
		downloadedSize = fileInfo.Size()
		if remoteSize, err := getRemoteFileSize(downloadURL); err == nil && downloadedSize == remoteSize {
			log.Printf("[下载] 文件已存在且完整: %s", filePath)
			if progressCallback != nil {
				progressCallback(100.0, 0, downloadedSize, remoteSize)
				// NOTE(review): presumably gives the UI time to consume the event; kept as-is.
				time.Sleep(50 * time.Millisecond)
			}
			return completedDownloadResult(filePath)
		}
	}

	// O_APPEND so that a resumed download continues after the bytes already on disk.
	file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		return nil, fmt.Errorf("创建文件失败: %w", err)
	}
	defer file.Close() // safety net; the success path closes explicitly and checks the error

	req, err := http.NewRequest(http.MethodGet, downloadURL, nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	if downloadedSize > 0 {
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloadedSize))
		log.Printf("[下载] 启用断点续传,已下载: %d 字节", downloadedSize)
	}

	client := &http.Client{Timeout: 30 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("下载请求失败: %w", err)
	}
	defer resp.Body.Close()

	// 416: the requested Range starts at or beyond the end of the remote file.
	if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
		file.Close()
		if fileInfo, err := os.Stat(filePath); err == nil {
			if remoteSize, err := getRemoteFileSize(downloadURL); err == nil {
				if fileInfo.Size() == remoteSize {
					log.Printf("[下载] 文件已完整下载")
					if progressCallback != nil {
						progressCallback(100.0, 0, fileInfo.Size(), remoteSize)
					}
					return completedDownloadResult(filePath)
				}
				// A local file larger than the remote one is stale, not partial.
				// Remove it so the next attempt starts from scratch instead of
				// hitting 416 on every retry.
				if fileInfo.Size() > remoteSize {
					if rmErr := os.Remove(filePath); rmErr != nil {
						log.Printf("[下载] 删除过期文件失败: %v", rmErr)
					}
				}
			}
		}
		return nil, fmt.Errorf("服务器返回 416 错误,且文件可能不完整")
	}
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
		return nil, fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode)
	}

	// A 200 reply to a Range request carries the whole file: discard the
	// partial data instead of appending the full body to it.
	if resp.StatusCode == http.StatusOK && downloadedSize > 0 {
		log.Printf("[下载] 服务器不支持断点续传,从头重新下载")
		file.Close()
		// Reopen with O_TRUNC rather than calling Truncate: on Windows Go opens
		// O_APPEND files with append-only access, which may not permit truncation.
		file, err = os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
		if err != nil {
			return nil, fmt.Errorf("创建文件失败: %w", err)
		}
		defer file.Close()
		downloadedSize = 0
	}

	contentLength := resolveDownloadTotal(resp, downloadURL, downloadedSize)
	log.Printf("[下载] 开始下载文件,总大小: %d 字节,已下载: %d 字节", contentLength, downloadedSize)

	// Initial progress event (the sleeps are kept from the original pacing).
	if progressCallback != nil {
		time.Sleep(50 * time.Millisecond)
		if contentLength > 0 && downloadedSize > 0 {
			progress := normalizeProgress(float64(downloadedSize) / float64(contentLength) * 100)
			progressCallback(progress, 0, downloadedSize, contentLength)
		} else {
			progressCallback(0, 0, downloadedSize, -1)
		}
		time.Sleep(20 * time.Millisecond)
	}

	buffer := make([]byte, 32*1024) // 32 KiB read buffer
	totalDownloaded := downloadedSize
	lastProgressTime := time.Now()
	lastProgressSize := totalDownloaded
	for {
		n, readErr := resp.Body.Read(buffer)
		if n > 0 {
			written, writeErr := file.Write(buffer[:n])
			if writeErr != nil {
				return nil, fmt.Errorf("写入文件失败: %w", writeErr)
			}
			totalDownloaded += int64(written)

			// Throttle progress events to at most one every 0.3 seconds.
			now := time.Now()
			if elapsed := now.Sub(lastProgressTime).Seconds(); elapsed >= 0.3 {
				progress := float64(0)
				if contentLength > 0 {
					progress = normalizeProgress(float64(totalDownloaded) / float64(contentLength) * 100)
				}
				speed := float64(totalDownloaded-lastProgressSize) / elapsed
				if progressCallback != nil {
					progressCallback(progress, speed, totalDownloaded, contentLength)
				}
				lastProgressTime = now
				lastProgressSize = totalDownloaded
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return nil, fmt.Errorf("读取数据失败: %w", readErr)
		}
	}

	// Final progress event; fall back to the byte count when the total is unknown.
	if progressCallback != nil {
		total := contentLength
		if total <= 0 {
			total = totalDownloaded
		}
		progressCallback(100.0, 0, totalDownloaded, total)
	}

	// Close explicitly so a failed flush/close is reported instead of the
	// file being hashed as if it were complete.
	if err := file.Close(); err != nil {
		return nil, fmt.Errorf("写入文件失败: %w", err)
	}
	log.Printf("[下载] 下载完成,文件大小: %d 字节", totalDownloaded)

	result, err := completedDownloadResult(filePath)
	if err != nil {
		return nil, err
	}
	log.Printf("[下载] 文件哈希计算完成MD5: %s, SHA256: %s", result.MD5Hash, result.SHA256Hash)
	return result, nil
}

// updateFileNameFromURL derives the local file name from the last segment of
// the URL path (ignoring query string and fragment), falling back to
// "update-<unix time>.exe" when there is no usable segment.
func updateFileNameFromURL(rawURL string) string {
	name := ""
	if u, err := url.Parse(rawURL); err == nil {
		name = filepath.Base(u.Path)
	}
	switch name {
	case "", ".", "..", "/", `\`:
		return fmt.Sprintf("update-%d.exe", time.Now().Unix())
	}
	return name
}

// resolveDownloadTotal returns the full size of the remote file for resp, or
// a value <= 0 when it cannot be determined. alreadyOnDisk is the resume offset.
func resolveDownloadTotal(resp *http.Response, downloadURL string, alreadyOnDisk int64) int64 {
	// "Content-Range: bytes start-end/total" carries the full size on a 206.
	if rangeHeader := resp.Header.Get("Content-Range"); rangeHeader != "" {
		var start, end, total int64
		if n, _ := fmt.Sscanf(rangeHeader, "bytes %d-%d/%d", &start, &end, &total); n == 3 && total > 0 {
			return total
		}
	}
	if resp.ContentLength > 0 {
		if resp.StatusCode == http.StatusPartialContent {
			// On a 206, Content-Length only covers the bytes still to come.
			return resp.ContentLength + alreadyOnDisk
		}
		return resp.ContentLength
	}
	// Last resort: HEAD reports the full size.
	if remoteSize, err := getRemoteFileSize(downloadURL); err == nil {
		return remoteSize
	}
	return resp.ContentLength
}

// completedDownloadResult hashes the finished file at filePath and builds its DownloadResult.
func completedDownloadResult(filePath string) (*DownloadResult, error) {
	md5Hash, sha256Hash, err := calculateFileHashes(filePath)
	if err != nil {
		return nil, fmt.Errorf("计算文件哈希失败: %w", err)
	}
	fileInfo, err := os.Stat(filePath)
	if err != nil {
		return nil, fmt.Errorf("获取文件信息失败: %w", err)
	}
	return &DownloadResult{
		FilePath:   filePath,
		FileSize:   fileInfo.Size(),
		MD5Hash:    md5Hash,
		SHA256Hash: sha256Hash,
	}, nil
}
// ==================== Helper functions ====================

// getRemoteFileSize issues a HEAD request for rawURL and returns the size
// advertised in the response's Content-Length header.
//
// Only 2xx responses are trusted: a 404/5xx error page also carries a
// Content-Length, and reporting that as the file size would make callers
// compare a local file against the size of an error page. An error is also
// returned when the size is missing or zero.
func getRemoteFileSize(rawURL string) (int64, error) {
	client := &http.Client{Timeout: 10 * time.Second}
	req, err := http.NewRequest(http.MethodHead, rawURL, nil)
	if err != nil {
		return 0, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return 0, fmt.Errorf("HEAD 请求返回错误状态码: %d", resp.StatusCode)
	}
	if resp.ContentLength > 0 {
		return resp.ContentLength, nil
	}
	return 0, fmt.Errorf("无法获取文件大小")
}
// normalizeProgress clamps a percentage into the closed range [0, 100].
// Values already inside the range (and NaN, which fails both comparisons)
// are returned unchanged.
func normalizeProgress(progress float64) float64 {
	switch {
	case progress < 0:
		return 0
	case progress > 100:
		return 100
	default:
		return progress
	}
}
// calculateHash returns the lower-case hex digest of the file at filePath.
//
// hashType selects the algorithm: "md5" or "sha256", matched
// case-insensitively (so "MD5" and "SHA256" also work). Any other value
// yields an error before the file is opened. Errors from opening or reading
// the file are returned unwrapped.
func calculateHash(filePath string, hashType string) (string, error) {
	// Named hasher (not "hash") so the imported hash package is not shadowed.
	var hasher hash.Hash
	switch strings.ToLower(hashType) {
	case "md5":
		hasher = md5.New()
	case "sha256":
		hasher = sha256.New()
	default:
		return "", fmt.Errorf("不支持的哈希类型: %s", hashType)
	}

	file, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer file.Close()

	if _, err := io.Copy(hasher, file); err != nil {
		return "", err
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// calculateFileHashes returns the MD5 and SHA-256 digests of the file at
// filePath as lower-case hex strings, in that order. Both hashers are fed
// through a single io.MultiWriter, so the file is read only once. Open and
// read errors are returned unwrapped, with empty digests.
func calculateFileHashes(filePath string) (string, string, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", "", err
	}
	defer f.Close()

	var (
		md5Digest    = md5.New()
		sha256Digest = sha256.New()
	)
	if _, err = io.Copy(io.MultiWriter(md5Digest, sha256Digest), f); err != nil {
		return "", "", err
	}
	return hex.EncodeToString(md5Digest.Sum(nil)), hex.EncodeToString(sha256Digest.Sum(nil)), nil
}
// VerifyFileHash reports whether the file at filePath has the digest
// expectedHash, using the algorithm named by hashType (passed through to
// calculateHash).
//
// calculateHash produces lower-case hex, while published digests are often
// upper-case or carry a trailing newline, so the comparison ignores
// surrounding whitespace in expectedHash and hex letter case. Errors from
// hashing are returned with a false result.
func VerifyFileHash(filePath string, expectedHash string, hashType string) (bool, error) {
	calculatedHash, err := calculateHash(filePath, hashType)
	if err != nil {
		return false, err
	}
	return strings.EqualFold(calculatedHash, strings.TrimSpace(expectedHash)), nil
}