package storage import ( "crypto/sha256" "fmt" "os" "path/filepath" "time" "u-desk/internal/storage/models" "gorm.io/gorm" ) const downloadCacheTTL = 24 * time.Hour // cacheTempDir 确定性临时目录 var cacheTempDir = filepath.Join(os.TempDir(), "u-desk-cache") // GetCachedPath 查询缓存,验证文件存在后返回本地路径 func GetCachedPath(transport, connID, remotePath string, fileSize int64, modTime string) (string, bool) { db := GetDB() if db == nil { return "", false } var entry models.DownloadCache err := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?", transport, connID, remotePath, fileSize, modTime).First(&entry).Error if err != nil { return "", false } // 检查文件是否仍然存在于磁盘 if _, err := os.Stat(entry.LocalPath); err != nil { // 文件已丢失,清理过期记录 db.Delete(&entry) return "", false } // 检查是否过期 if time.Since(entry.DownloadedAt) > downloadCacheTTL { os.Remove(entry.LocalPath) db.Delete(&entry) return "", false } return entry.LocalPath, true } // SaveCache 保存或更新缓存记录 func SaveCache(transport, connID, remotePath string, fileSize int64, modTime, localPath string) { db := GetDB() if db == nil { return } var existing models.DownloadCache err := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?", transport, connID, remotePath, fileSize, modTime).First(&existing).Error if err == gorm.ErrRecordNotFound { db.Create(&models.DownloadCache{ Transport: transport, ConnID: connID, RemotePath: remotePath, FileSize: fileSize, ModTime: modTime, LocalPath: localPath, DownloadedAt: time.Now(), }) } else if err == nil { db.Model(&existing).Updates(map[string]any{ "local_path": localPath, "downloaded_at": time.Now(), }) } } // CleanupExpiredCache 清理超过 24h 的缓存记录并删除对应临时文件 func CleanupExpiredCache() { db := GetDB() if db == nil { return } cutoff := time.Now().Add(-downloadCacheTTL) var expired []models.DownloadCache db.Where("downloaded_at < ?", cutoff).Find(&expired) for _, entry := range expired { os.Remove(entry.LocalPath) db.Delete(&entry) } if len(expired) > 0 { fmt.Printf("[下载缓存] 清理 %d 条过期记录\n", len(expired)) } } // DownloadToTempCached 带缓存的下载:命中返回本地路径,未命中调用 downloadFn 后缓存结果 func DownloadToTempCached(transport, connID, remotePath string, fileSize int64, modTime string, downloadFn func() (string, error)) (string, error) { // 1. 查缓存 if localPath, hit := GetCachedPath(transport, connID, remotePath, fileSize, modTime); hit { return localPath, nil } // 2. 缓存未命中,执行下载 tempPath, err := downloadFn() if err != nil { return "", err } // 3. 生成确定性路径并移动文件 deterministicPath, err := deterministicCachePath(transport, connID, remotePath, fileSize, modTime) if err != nil { // 降级:直接使用 downloadFn 返回的路径,仍然缓存 SaveCache(transport, connID, remotePath, fileSize, modTime, tempPath) return tempPath, nil } // 确保目录存在 if err := os.MkdirAll(filepath.Dir(deterministicPath), 0755); err != nil { SaveCache(transport, connID, remotePath, fileSize, modTime, tempPath) return tempPath, nil } // 移动文件到确定性路径 if err := os.Rename(tempPath, deterministicPath); err != nil { // Rename 可能跨卷失败,尝试 Copy+Delete if copyFile(tempPath, deterministicPath) != nil { SaveCache(transport, connID, remotePath, fileSize, modTime, tempPath) return tempPath, nil } os.Remove(tempPath) } SaveCache(transport, connID, remotePath, fileSize, modTime, deterministicPath) return deterministicPath, nil } // deterministicCachePath 根据文件信息生成确定性的缓存路径 func deterministicCachePath(transport, connID, remotePath string, fileSize int64, modTime string) (string, error) { h := sha256.New() h.Write([]byte(fmt.Sprintf("%s:%s:%s:%d:%s", transport, connID, remotePath, fileSize, modTime))) hash := fmt.Sprintf("%x", h.Sum(nil))[:16] baseName := filepath.Base(remotePath) if baseName == "" || baseName == "." || baseName == "/" { baseName = "file" } // 截断过长的文件名 if len(baseName) > 64 { ext := filepath.Ext(baseName) maxName := 64 - len(ext) if maxName <= 0 { maxName = 1 ext = ext[:63] } baseName = baseName[:maxName] + ext } fileName := fmt.Sprintf("%s_%s", hash, baseName) return filepath.Join(cacheTempDir, fileName), nil } // copyFile 复制文件内容 func copyFile(src, dst string) error { in, err := os.Open(src) if err != nil { return err } defer in.Close() out, err := os.Create(dst) if err != nil { return err } defer out.Close() if _, err := out.ReadFrom(in); err != nil { os.Remove(dst) return err } return nil }