Private
Public Access
1
0
Files
u-desk/internal/storage/download_cache.go

188 lines
4.9 KiB
Go

package storage
import (
"crypto/sha256"
"fmt"
"os"
"path/filepath"
"time"
"u-desk/internal/storage/models"
"gorm.io/gorm"
)
const (
	// downloadCacheTTL is how long a cache record stays valid after its
	// downloaded_at timestamp; older records are treated as expired.
	downloadCacheTTL = time.Hour * 24
)

var (
	// cacheTempDir is the fixed directory under the OS temp dir where cached
	// downloads are stored at deterministic file names.
	cacheTempDir = filepath.Join(os.TempDir(), "u-desk-cache")
)
// GetCachedPath looks up a cached download for the remote file identified by
// (transport, connID, remotePath, fileSize, modTime).
//
// It returns the local path and true only when a matching record exists, the
// recorded file is still present on disk, and the record is not older than
// downloadCacheTTL. As a side effect, a record whose file has disappeared is
// deleted, and an expired record is deleted together with its local file.
// When no database is available it reports a miss.
func GetCachedPath(transport, connID, remotePath string, fileSize int64, modTime string) (string, bool) {
	db := GetDB()
	if db == nil {
		return "", false
	}

	var cached models.DownloadCache
	lookup := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?",
		transport, connID, remotePath, fileSize, modTime)
	if lookup.First(&cached).Error != nil {
		return "", false
	}

	_, statErr := os.Stat(cached.LocalPath)
	switch {
	case statErr != nil:
		// The file is gone from disk: drop the dangling record.
		db.Delete(&cached)
		return "", false
	case time.Since(cached.DownloadedAt) > downloadCacheTTL:
		// Too old: remove both the file and its record.
		os.Remove(cached.LocalPath)
		db.Delete(&cached)
		return "", false
	}
	return cached.LocalPath, true
}
// SaveCache records that the remote file identified by (transport, connID,
// remotePath, fileSize, modTime) is available locally at localPath.
//
// If a matching record already exists, its local_path and downloaded_at are
// refreshed; otherwise a new record is inserted with downloaded_at = now.
// The call is best-effort and returns nothing. It is a no-op when no database
// is available, and database errors are logged rather than silently dropped.
// That includes lookup failures other than "record not found".
func SaveCache(transport, connID, remotePath string, fileSize int64, modTime, localPath string) {
	db := GetDB()
	if db == nil {
		return
	}

	var existing models.DownloadCache
	err := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?",
		transport, connID, remotePath, fileSize, modTime).First(&existing).Error

	switch {
	case err == gorm.ErrRecordNotFound:
		createErr := db.Create(&models.DownloadCache{
			Transport:    transport,
			ConnID:       connID,
			RemotePath:   remotePath,
			FileSize:     fileSize,
			ModTime:      modTime,
			LocalPath:    localPath,
			DownloadedAt: time.Now(),
		}).Error
		if createErr != nil {
			fmt.Printf("[下载缓存] 创建缓存记录失败: %v\n", createErr)
		}
	case err != nil:
		// Previously this case fell through silently and nothing was saved.
		fmt.Printf("[下载缓存] 查询缓存记录失败: %v\n", err)
	default:
		updateErr := db.Model(&existing).Updates(map[string]any{
			"local_path":    localPath,
			"downloaded_at": time.Now(),
		}).Error
		if updateErr != nil {
			fmt.Printf("[下载缓存] 更新缓存记录失败: %v\n", updateErr)
		}
	}
}
// CleanupExpiredCache removes cache records whose downloaded_at is older than
// downloadCacheTTL, deleting each record's local file first.
//
// The sweep is best-effort:
//   - A failed query is logged and aborts the sweep.
//   - If a file cannot be removed for a reason other than "does not exist",
//     its record is kept so a later sweep can retry. Dropping the record would
//     orphan the file.
//   - If a record cannot be deleted, the failure is logged and the sweep moves on.
//
// The summary line counts only records that were actually deleted.
func CleanupExpiredCache() {
	db := GetDB()
	if db == nil {
		return
	}

	cutoff := time.Now().Add(-downloadCacheTTL)
	var expired []models.DownloadCache
	if err := db.Where("downloaded_at < ?", cutoff).Find(&expired).Error; err != nil {
		fmt.Printf("[下载缓存] 查询过期记录失败: %v\n", err)
		return
	}

	removed := 0
	for i := range expired {
		entry := &expired[i]
		// An already-missing file is fine; any other failure means the file is
		// still on disk, so keep the record that points at it.
		if err := os.Remove(entry.LocalPath); err != nil && !os.IsNotExist(err) {
			fmt.Printf("[下载缓存] 删除缓存文件失败 %s: %v\n", entry.LocalPath, err)
			continue
		}
		if err := db.Delete(entry).Error; err != nil {
			fmt.Printf("[下载缓存] 删除缓存记录失败: %v\n", err)
			continue
		}
		removed++
	}

	if removed > 0 {
		fmt.Printf("[下载缓存] 清理 %d 条过期记录\n", removed)
	}
}
// DownloadToTempCached returns a local copy of a remote file, reusing a cached
// copy when one matches (transport, connID, remotePath, fileSize, modTime).
//
// On a cache miss it calls downloadFn, which returns the path of a freshly
// downloaded local file. It then tries to move that file to the deterministic
// cache location: rename first, then copy + delete if the rename fails (for
// example across volumes). If the target path cannot be produced, its
// directory cannot be created, or the copy fails, the path returned by
// downloadFn is cached and returned instead. The only error ever returned is
// the one from downloadFn.
func DownloadToTempCached(transport, connID, remotePath string, fileSize int64, modTime string, downloadFn func() (string, error)) (string, error) {
	if cached, ok := GetCachedPath(transport, connID, remotePath, fileSize, modTime); ok {
		return cached, nil
	}

	downloaded, err := downloadFn()
	if err != nil {
		return "", err
	}

	// remember stores p as the cached location of this remote file and hands it back.
	remember := func(p string) (string, error) {
		SaveCache(transport, connID, remotePath, fileSize, modTime, p)
		return p, nil
	}

	target, err := deterministicCachePath(transport, connID, remotePath, fileSize, modTime)
	if err != nil {
		// Fallback: keep the file where downloadFn left it.
		return remember(downloaded)
	}
	if os.MkdirAll(filepath.Dir(target), 0755) != nil {
		return remember(downloaded)
	}
	if os.Rename(downloaded, target) == nil {
		return remember(target)
	}
	// Rename can fail across volumes; fall back to copy, then delete the source.
	if copyFile(downloaded, target) != nil {
		return remember(downloaded)
	}
	os.Remove(downloaded)
	return remember(target)
}
// deterministicCachePath builds the local cache path for a remote file.
//
// The file name is "<hash>_<base>":
//   - <hash> is the first 16 hex digits of SHA-256 over
//     "transport:connID:remotePath:fileSize:modTime".
//   - <base> is the remote base name made safe as a local file name.
//     Characters Windows forbids in file names (and ASCII control characters)
//     become '_'. Names longer than 64 bytes are shortened on a UTF-8 rune
//     boundary, keeping the extension when it fits.
//
// The file lives directly in cacheTempDir. The returned error is currently
// always nil.
func deterministicCachePath(transport, connID, remotePath string, fileSize int64, modTime string) (string, error) {
	const maxBaseName = 64

	sum := sha256.Sum256([]byte(fmt.Sprintf("%s:%s:%s:%d:%s", transport, connID, remotePath, fileSize, modTime)))
	hash := fmt.Sprintf("%x", sum[:8]) // 8 bytes -> 16 hex digits

	baseName := filepath.Base(remotePath)
	// filepath.Base returns "." for "" and the separator for a root path
	// ("\" on Windows, "/" elsewhere).
	if baseName == "" || baseName == "." || baseName == "/" || baseName == string(filepath.Separator) {
		baseName = "file"
	}
	baseName = sanitizeCacheFileName(baseName)

	if len(baseName) > maxBaseName {
		ext := filepath.Ext(baseName)
		if len(ext) >= maxBaseName {
			// The extension alone exhausts the budget: treat the whole name as a stem.
			baseName = truncateUTF8Prefix(baseName, maxBaseName)
		} else {
			stem := baseName[:len(baseName)-len(ext)]
			baseName = truncateUTF8Prefix(stem, maxBaseName-len(ext)) + ext
		}
	}

	return filepath.Join(cacheTempDir, hash+"_"+baseName), nil
}

// sanitizeCacheFileName replaces characters that are not allowed in Windows
// file names (and ASCII control characters) with '_'. Invalid UTF-8 input
// becomes U+FFFD, so the result is always valid UTF-8.
func sanitizeCacheFileName(name string) string {
	runes := []rune(name)
	for i, r := range runes {
		switch r {
		case '<', '>', ':', '"', '/', '\\', '|', '?', '*':
			runes[i] = '_'
		default:
			if r < 0x20 {
				runes[i] = '_'
			}
		}
	}
	return string(runes)
}

// truncateUTF8Prefix returns the longest prefix of s that is at most n bytes
// long and does not end in the middle of a multi-byte UTF-8 sequence.
func truncateUTF8Prefix(s string, n int) string {
	if len(s) <= n {
		return s
	}
	// Step back over continuation bytes (10xxxxxx) to land on a rune boundary.
	for n > 0 && s[n]&0xC0 == 0x80 {
		n--
	}
	return s[:n]
}
// copyFile copies the contents of src into dst, creating or truncating dst.
//
// dst is closed explicitly and its Close error is returned. close(2) can
// report write errors that were not surfaced earlier, so ignoring it could
// accept a truncated copy. If the copy or the close fails, dst is removed so
// no partial file is left behind. The file is closed before removal because
// Windows cannot delete a file that is still open.
func copyFile(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.Create(dst)
	if err != nil {
		return err
	}

	_, copyErr := out.ReadFrom(in)
	closeErr := out.Close()
	if copyErr != nil {
		os.Remove(dst) // best-effort cleanup of the partial copy
		return copyErr
	}
	if closeErr != nil {
		os.Remove(dst) // best-effort cleanup of the possibly incomplete copy
		return closeErr
	}
	return nil
}