188 lines
4.9 KiB
Go
188 lines
4.9 KiB
Go
package storage
|
|
|
|
import (
	"crypto/sha256"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"time"

	"u-desk/internal/storage/models"

	"gorm.io/gorm"
)
|
|
|
|
const downloadCacheTTL = 24 * time.Hour
|
|
|
|
// cacheTempDir 确定性临时目录
|
|
var cacheTempDir = filepath.Join(os.TempDir(), "u-desk-cache")
|
|
|
|
// GetCachedPath looks up the download-cache entry for a remote file identified
// by transport, connection ID, remote path, size and modification time, and
// returns its local path together with true when the entry is usable.
//
// An entry is usable only if its local file still exists on disk and it is not
// older than downloadCacheTTL. If the file has disappeared, the stale row is
// deleted. If the entry has expired, both its file and its row are removed.
// A nil database or any lookup error is reported as a miss ("", false).
func GetCachedPath(transport, connID, remotePath string, fileSize int64, modTime string) (string, bool) {
	db := GetDB()
	if db == nil {
		return "", false
	}

	var cached models.DownloadCache
	lookup := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?",
		transport, connID, remotePath, fileSize, modTime)
	if lookup.First(&cached).Error != nil {
		return "", false
	}

	_, statErr := os.Stat(cached.LocalPath)
	switch {
	case statErr != nil:
		// The cached file vanished from disk: drop the orphaned row.
		db.Delete(&cached)
		return "", false
	case time.Since(cached.DownloadedAt) > downloadCacheTTL:
		// Entry is too old: remove the file and its row.
		os.Remove(cached.LocalPath)
		db.Delete(&cached)
		return "", false
	default:
		return cached.LocalPath, true
	}
}
|
|
|
|
// SaveCache records that the remote file identified by (transport, connID,
// remotePath, fileSize, modTime) is available locally at localPath.
//
// If a matching row already exists, its local_path and downloaded_at columns
// are refreshed. Otherwise a new row is inserted with DownloadedAt set to now.
// Caching is best-effort: database failures (lookup, insert or update) are
// logged and otherwise ignored, so callers never fail because the cache could
// not be written.
func SaveCache(transport, connID, remotePath string, fileSize int64, modTime, localPath string) {
	db := GetDB()
	if db == nil {
		return
	}

	now := time.Now()
	var existing models.DownloadCache
	err := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?",
		transport, connID, remotePath, fileSize, modTime).First(&existing).Error

	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		err = db.Create(&models.DownloadCache{
			Transport:    transport,
			ConnID:       connID,
			RemotePath:   remotePath,
			FileSize:     fileSize,
			ModTime:      modTime,
			LocalPath:    localPath,
			DownloadedAt: now,
		}).Error
	case err == nil:
		err = db.Model(&existing).Updates(map[string]any{
			"local_path":    localPath,
			"downloaded_at": now,
		}).Error
	}

	// Any remaining error is a failed lookup or a failed write. Keep the
	// best-effort contract, but make the failure visible instead of silent.
	if err != nil {
		fmt.Printf("[下载缓存] 保存缓存记录失败 %q: %v\n", remotePath, err)
	}
}
|
|
|
|
// CleanupExpiredCache purges download-cache entries older than
// downloadCacheTTL. For each expired entry it removes the local file (best
// effort, since the file may already be gone) and then deletes the row.
// It logs how many rows were actually deleted, and logs any query or delete
// failure instead of silently ignoring it.
func CleanupExpiredCache() {
	db := GetDB()
	if db == nil {
		return
	}

	cutoff := time.Now().Add(-downloadCacheTTL)
	var expired []models.DownloadCache
	if err := db.Where("downloaded_at < ?", cutoff).Find(&expired).Error; err != nil {
		fmt.Printf("[下载缓存] 查询过期记录失败: %v\n", err)
		return
	}

	removed := 0
	for i := range expired {
		entry := &expired[i]
		// A missing file is expected here, so the Remove error is deliberately ignored.
		os.Remove(entry.LocalPath)
		if err := db.Delete(entry).Error; err != nil {
			fmt.Printf("[下载缓存] 删除记录失败 %q: %v\n", entry.LocalPath, err)
			continue
		}
		removed++
	}

	if removed > 0 {
		fmt.Printf("[下载缓存] 清理 %d 条过期记录\n", removed)
	}
}
|
|
|
|
// DownloadToTempCached returns a local copy of a remote file, reusing a cached
// download when one is available.
//
// On a cache hit the cached path is returned and downloadFn is not called.
// On a miss, downloadFn is invoked. Its result is presumably a freshly created
// temporary file, which this function may move or delete (NOTE(review):
// confirm with callers). That file is moved to the deterministic location from
// deterministicCachePath. A plain rename is tried first; if it fails (e.g.
// across volumes), the file is copied and the original removed. If the move
// cannot be done, the path from downloadFn is used as-is. In every case the
// final path is recorded with SaveCache and returned. Only errors from
// downloadFn are returned; later failures merely select the fallback path.
func DownloadToTempCached(transport, connID, remotePath string, fileSize int64, modTime string, downloadFn func() (string, error)) (string, error) {
	if hitPath, ok := GetCachedPath(transport, connID, remotePath, fileSize, modTime); ok {
		return hitPath, nil
	}

	downloaded, err := downloadFn()
	if err != nil {
		return "", err
	}

	// relocate moves the fresh download to target and reports whether it succeeded.
	relocate := func(target string) bool {
		if os.MkdirAll(filepath.Dir(target), 0755) != nil {
			return false
		}
		if os.Rename(downloaded, target) == nil {
			return true
		}
		// Rename can fail across volumes; fall back to copy + delete.
		if copyFile(downloaded, target) != nil {
			return false
		}
		os.Remove(downloaded)
		return true
	}

	finalPath := downloaded
	if target, pathErr := deterministicCachePath(transport, connID, remotePath, fileSize, modTime); pathErr == nil && relocate(target) {
		finalPath = target
	}

	SaveCache(transport, connID, remotePath, fileSize, modTime, finalPath)
	return finalPath, nil
}
|
|
|
|
// deterministicCachePath 根据文件信息生成确定性的缓存路径
|
|
func deterministicCachePath(transport, connID, remotePath string, fileSize int64, modTime string) (string, error) {
|
|
h := sha256.New()
|
|
h.Write([]byte(fmt.Sprintf("%s:%s:%s:%d:%s", transport, connID, remotePath, fileSize, modTime)))
|
|
hash := fmt.Sprintf("%x", h.Sum(nil))[:16]
|
|
|
|
baseName := filepath.Base(remotePath)
|
|
if baseName == "" || baseName == "." || baseName == "/" {
|
|
baseName = "file"
|
|
}
|
|
|
|
// 截断过长的文件名
|
|
if len(baseName) > 64 {
|
|
ext := filepath.Ext(baseName)
|
|
maxName := 64 - len(ext)
|
|
if maxName <= 0 {
|
|
maxName = 1
|
|
ext = ext[:63]
|
|
}
|
|
baseName = baseName[:maxName] + ext
|
|
}
|
|
|
|
fileName := fmt.Sprintf("%s_%s", hash, baseName)
|
|
return filepath.Join(cacheTempDir, fileName), nil
|
|
}
|
|
|
|
// copyFile copies the contents of src into dst, creating dst or truncating it
// if it already exists.
//
// If anything fails after dst has been created, dst is closed and then removed
// so no partial file is left behind. Closing first matters because an open
// file typically cannot be deleted on Windows. The Close error of dst is also
// checked, since some filesystems report write failures only when the file is
// closed.
func copyFile(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.Create(dst)
	if err != nil {
		return err
	}

	if _, err := out.ReadFrom(in); err != nil {
		out.Close()
		os.Remove(dst)
		return err
	}
	if err := out.Close(); err != nil {
		os.Remove(dst)
		return err
	}
	return nil
}
|