新增: SFTP直连+网站预览+OSS区域嗅探+热键+BGM播放
This commit is contained in:
@@ -45,9 +45,8 @@ var (
|
||||
es6DynamicImport = regexp.MustCompile(`import\s*\(\s*["']([^"']+)["']\s*\)`)
|
||||
es6BareImport = regexp.MustCompile(`(?m)^\s*import\s+["']([^"']+)["']`)
|
||||
|
||||
// HTML 预览路径修复
|
||||
locationPathRegex = regexp.MustCompile(`\blocation\.pathname\b`)
|
||||
winDriveRegex = regexp.MustCompile(`^[A-Za-z]:`)
|
||||
// Windows 盘符检测
|
||||
winDriveRegex = regexp.MustCompile(`^[A-Za-z]:`)
|
||||
)
|
||||
|
||||
// HTML 属性正则缓存(避免 replaceHtmlTagAttribute 中重复编译)
|
||||
@@ -78,6 +77,8 @@ func validateFilePath(rawPath string, logPrefix string) (string, error) {
|
||||
clean = strings.TrimPrefix(clean, "/localfs/")
|
||||
clean = strings.TrimPrefix(clean, "localfs/")
|
||||
}
|
||||
// 清理残留的前导斜杠(避免 /u-res/... 类路径在 Windows 上异常)
|
||||
clean = strings.TrimLeft(clean, "/")
|
||||
|
||||
// 平台适配:Windows 用反斜杠,Linux/macOS 保持正斜杠
|
||||
filePath := filepath.FromSlash(clean)
|
||||
@@ -304,7 +305,12 @@ func handleLocalFileRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置响应头
|
||||
contentType := getContentType(ext)
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Cache-Control", "public, max-age=3600")
|
||||
// 媒体文件禁用缓存(避免 Chromium ERR_CACHE_OPERATION_NOT_SUPPORTED)
|
||||
if isMediaExt(ext) {
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
} else {
|
||||
w.Header().Set("Cache-Control", "public, max-age=3600")
|
||||
}
|
||||
// 支持 Range 请求
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
|
||||
@@ -365,6 +371,16 @@ func isAllowedFileType(ext string) bool {
|
||||
return defaultFileTypeManager.IsAllowed(ext)
|
||||
}
|
||||
|
||||
// isMediaExt 判断是否为音频/视频扩展名
|
||||
func isMediaExt(ext string) bool {
|
||||
switch ext {
|
||||
case ".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", ".wma",
|
||||
".mp4", ".webm", ".mkv", ".avi", ".mov", ".wmv":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Shutdown 优雅关闭文件服务器
|
||||
func (lfs *LocalFileServer) Shutdown() error {
|
||||
if lfs == nil || lfs.server == nil {
|
||||
@@ -556,11 +572,8 @@ func handleHtmlPreviewRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// 转换资源路径(将相对路径和绝对路径都转换为完整的本地文件服务器 URL)
|
||||
processedContent := transformHtmlResourcePaths(string(content), baseDir)
|
||||
|
||||
// 修复 JS 中基于 location.pathname 的相对路径计算
|
||||
// 预览模式下 location.pathname = "/localfs/html-preview",与实际文件路径不一致
|
||||
// ⚠️ 会替换所有出现位置(含JS字符串内),HTML预览场景下可接受
|
||||
correctPathname := `"/localfs/` + strings.ReplaceAll(baseDir, "\\", "/") + `/`
|
||||
processedContent = locationPathRegex.ReplaceAllString(processedContent, correctPathname)
|
||||
// 注入路径拦截脚本(处理 webpack 等动态加载的绝对路径资源)
|
||||
processedContent = injectPathInterceptor(processedContent, baseDir)
|
||||
|
||||
// 注入链接点击拦截脚本
|
||||
finalContent := injectLinkInterceptor(processedContent)
|
||||
@@ -870,3 +883,38 @@ func injectLinkInterceptor(htmlContent string) string {
|
||||
// 没有 body 标签,在末尾插入
|
||||
return htmlContent + script
|
||||
}
|
||||
|
||||
// injectPathInterceptor 注入路径拦截脚本(处理 webpack 等动态加载的绝对路径资源)
|
||||
// 重写动态创建的 <script src="/..."> 和 <link href="/..."> 为 /localfs/ 前缀路径
|
||||
func injectPathInterceptor(htmlContent string, baseDir string) string {
|
||||
// 直接使用 baseDir(HTML 所在目录)作为 base,与 transformHtmlResourcePaths 的路径解析一致
|
||||
base := toLocalServerUrl(strings.ReplaceAll(baseDir, "\\", "/"))
|
||||
|
||||
script := `<script data-udesk-intercept="true">
|
||||
(function(){
|
||||
var base = "` + base + `/";
|
||||
function rw(v){if(typeof v!=='string')return v;if(v[0]==='/'&&!v.startsWith('/localfs/')&&!v.startsWith('//')&&!v.startsWith('http'))return base+v.substring(1);return v;}
|
||||
var sa=Element.prototype.setAttribute;
|
||||
Element.prototype.setAttribute=function(n,v){if((n==='src'||n==='href'||n==='data'||n==='poster')&&typeof v==='string')v=rw(v);return sa.call(this,n,v);};
|
||||
try{var d=Object.getOwnPropertyDescriptor(HTMLScriptElement.prototype,'src');Object.defineProperty(HTMLScriptElement.prototype,'src',{set:function(v){d.set.call(this,rw(v))},get:d.get,configurable:true});}catch(e){}
|
||||
try{var d2=Object.getOwnPropertyDescriptor(HTMLLinkElement.prototype,'href');Object.defineProperty(HTMLLinkElement.prototype,'href',{set:function(v){d2.set.call(this,rw(v))},get:d2.get,configurable:true});}catch(e){}
|
||||
})();
|
||||
</script>`
|
||||
|
||||
// 在 <head> 后立即插入(确保在任何其他脚本之前执行)
|
||||
if idx := strings.Index(htmlContent, "<head>"); idx >= 0 {
|
||||
return htmlContent[:idx+6] + script + htmlContent[idx+6:]
|
||||
}
|
||||
if idx := strings.Index(htmlContent, "<HEAD>"); idx >= 0 {
|
||||
return htmlContent[:idx+6] + script + htmlContent[idx+6:]
|
||||
}
|
||||
// 没有 head 标签,在 <!DOCTYPE> 和 <html> 后插入
|
||||
if idx := strings.Index(htmlContent, "<html"); idx >= 0 {
|
||||
end := strings.Index(htmlContent[idx:], ">")
|
||||
if end >= 0 {
|
||||
pos := idx + end + 1
|
||||
return htmlContent[:pos] + script + htmlContent[pos:]
|
||||
}
|
||||
}
|
||||
return script + htmlContent
|
||||
}
|
||||
|
||||
41
internal/hotkey/hotkey.go
Normal file
41
internal/hotkey/hotkey.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package hotkey
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var (
|
||||
moduser32 = syscall.NewLazyDLL("user32.dll")
|
||||
procRegisterHotKey = moduser32.NewProc("RegisterHotKey")
|
||||
procUnregisterHotKey = moduser32.NewProc("UnregisterHotKey")
|
||||
procPostMessage = moduser32.NewProc("PostMessageW")
|
||||
)
|
||||
|
||||
const (
|
||||
ModAlt = 0x0001
|
||||
ModControl = 0x0002
|
||||
ModShift = 0x0004
|
||||
ModWin = 0x0008
|
||||
|
||||
WM_HOTKEY = 0x0312
|
||||
WM_APP_HOTKEY = 0x8001 // 自定义消息:在主线程触发热键注册
|
||||
)
|
||||
|
||||
func Register(hwnd uintptr, id int32, modifiers uint32, vk uint32) error {
|
||||
ret, _, _ := procRegisterHotKey.Call(hwnd, uintptr(id), uintptr(modifiers), uintptr(vk))
|
||||
if ret == 0 {
|
||||
return fmt.Errorf("RegisterHotKey failed for id=%d", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Unregister(hwnd uintptr, id int32) bool {
|
||||
ret, _, _ := procUnregisterHotKey.Call(hwnd, uintptr(id))
|
||||
return ret != 0
|
||||
}
|
||||
|
||||
// PostMessage 向窗口投递异步消息(用于跨线程调度到主线程)
|
||||
func PostMessage(hwnd uintptr, msg uint32, wParam, lParam uintptr) {
|
||||
procPostMessage.Call(hwnd, uintptr(msg), wParam, lParam)
|
||||
}
|
||||
@@ -99,6 +99,55 @@ func (c *Client) GetBucketDomains(ctx context.Context) ([]string, error) {
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
// GetBucketRegion 查询桶的真实区域
|
||||
// API: POST https://uc.qbox.me/v2/buckets → 遍历匹配桶名获取 region
|
||||
func (c *Client) GetBucketRegion(ctx context.Context) (string, error) {
|
||||
// 使用 UC API 获取所有桶列表(含 region)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://uc.qbox.me/v2/buckets", nil)
|
||||
if err != nil {
|
||||
return "", oss.NewError("BUCKET_ERROR", "failed to create request", err)
|
||||
}
|
||||
|
||||
path := "/v2/buckets"
|
||||
host := "uc.qbox.me"
|
||||
authToken := c.generateAuthTokenWithQuery("POST", path, "", host, "application/x-www-form-urlencoded", nil)
|
||||
req.Header.Set("Host", host)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", authToken)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", oss.NewError("BUCKET_ERROR", "failed to query bucket region", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", oss.NewError("BUCKET_ERROR", "failed to read response", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", oss.NewError("BUCKET_ERROR",
|
||||
fmt.Sprintf("query bucket region failed with status %d: %s", resp.StatusCode, string(body)), nil)
|
||||
}
|
||||
|
||||
var buckets []struct {
|
||||
ID string `json:"id"`
|
||||
Region string `json:"region"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &buckets); err != nil {
|
||||
return "", oss.NewError("BUCKET_ERROR", "failed to parse response", err)
|
||||
}
|
||||
|
||||
for _, b := range buckets {
|
||||
if b.ID == c.config.Bucket {
|
||||
return b.Region, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", oss.NewError("BUCKET_ERROR", fmt.Sprintf("bucket %s not found in account", c.config.Bucket), nil)
|
||||
}
|
||||
|
||||
// SetBucketAccess 设置空间访问权限(公开/私有)
|
||||
// 根据: https://developer.qiniu.com/kodo/api/3946/set-bucket-private
|
||||
//
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,12 +19,12 @@ import (
|
||||
|
||||
// Config 七牛云配置
|
||||
type Config struct {
|
||||
AccessKey string // 访问密钥
|
||||
SecretKey string // 秘钥
|
||||
Bucket string // 存储空间名称
|
||||
Region string // 区域 z0=华东, as0=亚太0区
|
||||
UseHTTPS bool // 是否使用 HTTPS
|
||||
UploadDomain string // 上传域名(可选,默认根据 Region 自动选择)
|
||||
AccessKey string // 访问密钥
|
||||
SecretKey string // 秘钥
|
||||
Bucket string // 存储空间名称
|
||||
Region string // 区域 z0=华东, z2=华南, as0=亚太0区
|
||||
UseHTTPS bool // 是否使用 HTTPS
|
||||
DownloadDomain string // 缓存的下载域名(由 resolveDownloadDomain 自动设置)
|
||||
}
|
||||
|
||||
// Client 七牛云客户端
|
||||
@@ -61,84 +62,31 @@ func NewClient(config *Config) (*Client, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// generateSignature 生成七牛云管理凭证签名
|
||||
// 根据官方文档:https://developer.qiniu.com/kodo/1201/access-token
|
||||
func (c *Client) generateSignature(method, path, host, contentType string, body []byte) string {
|
||||
// 七牛云管理凭证签名格式:
|
||||
// signingStr = Method + " " + Path + "\nHost: " + Host + "\n" + [Content-Type] + "\n\n" + [body]
|
||||
var signingStr string
|
||||
|
||||
// 1. Method + " " + Path
|
||||
signingStr = method + " " + path
|
||||
|
||||
// 2. Host header
|
||||
signingStr += "\nHost: " + host
|
||||
|
||||
// 3. Content-Type header (如果设置了)
|
||||
if contentType != "" {
|
||||
signingStr += "\nContent-Type: " + contentType
|
||||
}
|
||||
|
||||
// 4. 两个连续换行符
|
||||
signingStr += "\n\n"
|
||||
|
||||
// 5. Body (如果设置了 Content-Type 且不是 application/octet-stream)
|
||||
if contentType != "" && contentType != "application/octet-stream" && len(body) > 0 {
|
||||
signingStr += string(body)
|
||||
}
|
||||
|
||||
// 使用 HMAC-SHA1 签名
|
||||
h := hmac.New(sha1.New, []byte(c.config.SecretKey))
|
||||
h.Write([]byte(signingStr))
|
||||
|
||||
// Base64 URL 安全编码
|
||||
signature := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
return signature
|
||||
}
|
||||
|
||||
// generateAuthToken 生成管理认证 Token
|
||||
func (c *Client) generateAuthToken(method, path, host, contentType string, body []byte) string {
|
||||
signature := c.generateSignature(method, path, host, contentType, body)
|
||||
return "Qiniu " + c.config.AccessKey + ":" + signature
|
||||
return c.generateAuthTokenWithQuery(method, path, "", host, contentType, body)
|
||||
}
|
||||
|
||||
// generateAuthTokenWithQuery 生成管理认证 Token(支持 query string)
|
||||
// https://developer.qiniu.com/kodo/1201/access-token
|
||||
func (c *Client) generateAuthTokenWithQuery(method, path, query, host, contentType string, body []byte) string {
|
||||
// 七牛云管理凭证签名格式:
|
||||
// 如果 query 为非空字符串: signingStr = Method + " " + Path + "?" + query + "\nHost: " + Host + ...
|
||||
// 如果 query 为空: signingStr = Method + " " + Path + "\nHost: " + Host + ...
|
||||
var signingStr string
|
||||
|
||||
// 1. Method + " " + Path
|
||||
signingStr = method + " " + path
|
||||
|
||||
// 2. Query string (如果有)
|
||||
if query != "" {
|
||||
signingStr += "?" + query
|
||||
}
|
||||
|
||||
// 3. Host header
|
||||
signingStr += "\nHost: " + host
|
||||
|
||||
// 4. Content-Type header (如果设置了)
|
||||
if contentType != "" {
|
||||
signingStr += "\nContent-Type: " + contentType
|
||||
}
|
||||
|
||||
// 5. 两个连续换行符
|
||||
signingStr += "\n\n"
|
||||
|
||||
// 6. Body (如果设置了 Content-Type 且不是 application/octet-stream)
|
||||
if contentType != "" && contentType != "application/octet-stream" && len(body) > 0 {
|
||||
signingStr += string(body)
|
||||
}
|
||||
|
||||
// 使用 HMAC-SHA1 签名
|
||||
h := hmac.New(sha1.New, []byte(c.config.SecretKey))
|
||||
h.Write([]byte(signingStr))
|
||||
|
||||
// Base64 URL 安全编码
|
||||
signature := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
return "Qiniu " + c.config.AccessKey + ":" + signature
|
||||
@@ -152,12 +100,11 @@ func (c *Client) encodeEntry(key string) string {
|
||||
|
||||
// getUploadDomain 获取上传域名
|
||||
func (c *Client) getUploadDomain() string {
|
||||
// 如果配置了自定义上传域名,使用自定义的
|
||||
if c.config.UploadDomain != "" {
|
||||
if c.config.DownloadDomain != "" {
|
||||
if c.config.UseHTTPS {
|
||||
return "https://" + c.config.UploadDomain
|
||||
return "https://" + c.config.DownloadDomain
|
||||
}
|
||||
return "http://" + c.config.UploadDomain
|
||||
return "http://" + c.config.DownloadDomain
|
||||
}
|
||||
|
||||
// 根据区域选择默认上传域名
|
||||
@@ -264,85 +211,169 @@ func (c *Client) Upload(ctx context.Context, key string, reader io.Reader, optio
|
||||
return uploadClient.Upload(ctx, key, reader)
|
||||
}
|
||||
|
||||
// generateUploadToken 生成上传凭证
|
||||
func (c *Client) generateUploadToken(key string) string {
|
||||
// 七牛云上传凭证的生成
|
||||
// 1. 创建 putPolicy
|
||||
putPolicy := fmt.Sprintf(`{"scope":"%s:%s","deadline":%d}`,
|
||||
c.config.Bucket, key, time.Now().Add(1*time.Hour).Unix())
|
||||
|
||||
// 2. 对 putPolicy 进行 base64 URL 编码
|
||||
encodedPutPolicy := base64.URLEncoding.EncodeToString([]byte(putPolicy))
|
||||
|
||||
// 3. 对 encodedPutPolicy 进行 HMAC-SHA1 签名
|
||||
// generateToken 生成上传凭证
|
||||
func (c *Client) generateToken(scope string) string {
|
||||
putPolicy := fmt.Sprintf(`{"scope":"%s","deadline":%d}`, scope, time.Now().Add(1*time.Hour).Unix())
|
||||
encoded := base64.URLEncoding.EncodeToString([]byte(putPolicy))
|
||||
h := hmac.New(sha1.New, []byte(c.config.SecretKey))
|
||||
h.Write([]byte(encodedPutPolicy))
|
||||
encodedSign := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
// 4. 组合 token
|
||||
return c.config.AccessKey + ":" + encodedSign + ":" + encodedPutPolicy
|
||||
h.Write([]byte(encoded))
|
||||
sign := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
return c.config.AccessKey + ":" + sign + ":" + encoded
|
||||
}
|
||||
|
||||
func (c *Client) generateUploadToken(key string) string {
|
||||
return c.generateToken(c.config.Bucket + ":" + key)
|
||||
}
|
||||
|
||||
// generateBucketToken 生成 bucket 级别的上传凭证(用于分片上传 v2)
|
||||
func (c *Client) generateBucketToken() string {
|
||||
// 分片上传 v2 需要 bucket 级别的 token
|
||||
// 1. 创建 putPolicy
|
||||
putPolicy := fmt.Sprintf(`{"scope":"%s","deadline":%d}`,
|
||||
c.config.Bucket, time.Now().Add(1*time.Hour).Unix())
|
||||
return c.generateToken(c.config.Bucket)
|
||||
}
|
||||
|
||||
// 2. 对 putPolicy 进行 base64 URL 编码
|
||||
encodedPutPolicy := base64.URLEncoding.EncodeToString([]byte(putPolicy))
|
||||
// 七牛云临时域名后缀(平台分配的 CDN 域名,稳定性高)
|
||||
var qiniuTempSuffixes = []string{
|
||||
".qiniudns.com", ".clouddn.com", ".qbox.me",
|
||||
".qnssl.com", ".qnybgz.cn", ".qiniudns.com.cn",
|
||||
}
|
||||
|
||||
// 3. 对 encodedPutPolicy 进行 HMAC-SHA1 签名
|
||||
h := hmac.New(sha1.New, []byte(c.config.SecretKey))
|
||||
h.Write([]byte(encodedPutPolicy))
|
||||
encodedSign := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
// extractHost 从 URL 提取主机名(去掉 scheme、path、port)
|
||||
func extractHost(domainURL string) string {
|
||||
host := strings.TrimPrefix(domainURL, "http://")
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
if idx := strings.Index(host, "/"); idx >= 0 {
|
||||
host = host[:idx]
|
||||
}
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
return h
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// 4. 组合 token
|
||||
return c.config.AccessKey + ":" + encodedSign + ":" + encodedPutPolicy
|
||||
// isTempDomain 判断是否为七牛平台分配的临时域名(后缀匹配)
|
||||
func (c *Client) isTempDomain(domain string) bool {
|
||||
host := strings.ToLower(extractHost(domain))
|
||||
for _, s := range qiniuTempSuffixes {
|
||||
if strings.HasSuffix(host, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// classifyDomains 将域名列表分为临时域名和自定义域名
|
||||
func (c *Client) classifyDomains(domains []string) (tempDomains, customDomains []string) {
|
||||
for _, d := range domains {
|
||||
if !strings.HasPrefix(d, "http://") && !strings.HasPrefix(d, "https://") {
|
||||
d = "http://" + d
|
||||
}
|
||||
if c.isTempDomain(d) {
|
||||
tempDomains = append(tempDomains, d)
|
||||
} else {
|
||||
customDomains = append(customDomains, d)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// resolveDownloadDomain 解析并缓存下载域名
|
||||
// 策略:API 域名列表(临时优先→自定义)→ 兜底默认 CDN
|
||||
// 不做 HTTP 探测:Download 使用签名 URL,即使有防盗链也能通过
|
||||
func (c *Client) resolveDownloadDomain() (string, error) {
|
||||
if c.config.UploadDomain != "" {
|
||||
return c.config.UploadDomain, nil
|
||||
if c.config.DownloadDomain != "" {
|
||||
return c.config.DownloadDomain, nil
|
||||
}
|
||||
|
||||
domains, err := c.GetBucketDomains(context.Background())
|
||||
if err != nil || len(domains) == 0 {
|
||||
return "", fmt.Errorf("无法获取桶 %s 的下载域名: %v", c.config.Bucket, err)
|
||||
|
||||
if err == nil && len(domains) > 0 {
|
||||
tempDomains, customDomains := c.classifyDomains(domains)
|
||||
|
||||
// 精准获取桶的真实区域
|
||||
c.resolveRegion(tempDomains)
|
||||
|
||||
// 优先使用临时域名(平台分配,稳定性高)
|
||||
if len(tempDomains) > 0 {
|
||||
d := tempDomains[0]
|
||||
c.config.DownloadDomain = d
|
||||
return d, nil
|
||||
}
|
||||
// 降级到自定义域名
|
||||
if len(customDomains) > 0 {
|
||||
d := customDomains[0]
|
||||
c.config.DownloadDomain = d
|
||||
return d, nil
|
||||
}
|
||||
}
|
||||
domain := domains[0]
|
||||
if !strings.HasPrefix(domain, "http://") && !strings.HasPrefix(domain, "https://") {
|
||||
domain = "http://" + domain
|
||||
}
|
||||
c.config.UploadDomain = domain
|
||||
return domain, nil
|
||||
|
||||
// 无域名 → 兜底默认 CDN(可能不存在,但给一个机会)
|
||||
fallback := c.defaultCDNDomain()
|
||||
c.config.DownloadDomain = fallback
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
// Download 下载文件
|
||||
// defaultCDNDomain 构造七牛默认 CDN 域名
|
||||
func (c *Client) defaultCDNDomain() string {
|
||||
return fmt.Sprintf("http://%s-%s.qiniudns.com", c.config.Bucket, c.config.Region)
|
||||
}
|
||||
|
||||
// ClearDownloadDomain 清除缓存的下载域名(下载失败时调用,下次重新解析)
|
||||
func (c *Client) ClearDownloadDomain() {
|
||||
c.config.DownloadDomain = ""
|
||||
}
|
||||
|
||||
// resolveRegion 精准获取桶的真实区域
|
||||
// 优先从临时域名提取 → 查询 API → 使用配置值兜底
|
||||
func (c *Client) resolveRegion(tempDomains []string) {
|
||||
// 1. 从临时域名提取
|
||||
bucketLower := strings.ToLower(c.config.Bucket)
|
||||
for _, d := range tempDomains {
|
||||
host := extractHost(d)
|
||||
host = strings.ToLower(host)
|
||||
if !strings.HasPrefix(host, bucketLower+"-") {
|
||||
continue
|
||||
}
|
||||
rest := host[len(bucketLower)+1:]
|
||||
if idx := strings.Index(rest, "."); idx > 0 {
|
||||
c.config.Region = rest[:idx]
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 查询七牛 API
|
||||
if region, err := c.GetBucketRegion(context.Background()); err == nil && region != "" {
|
||||
c.config.Region = region
|
||||
}
|
||||
}
|
||||
|
||||
// Download 下载文件(使用签名 URL,绕过防盗链)
|
||||
func (c *Client) Download(ctx context.Context, key string, writer io.Writer) error {
|
||||
baseURL, err := c.resolveDownloadDomain()
|
||||
signedURL, err := c.GetSignedURL(ctx, key, 1*time.Hour)
|
||||
if err != nil {
|
||||
return oss.NewError("DOWNLOAD_ERROR", err.Error(), err)
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, key)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", signedURL, nil)
|
||||
if err != nil {
|
||||
return oss.NewError("DOWNLOAD_ERROR", "failed to create request", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
c.ClearDownloadDomain()
|
||||
return oss.NewError("DOWNLOAD_ERROR", "failed to download file", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return oss.NewError("DOWNLOAD_ERROR", fmt.Sprintf("download failed with status %d", resp.StatusCode), nil)
|
||||
c.ClearDownloadDomain()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return oss.NewError("DOWNLOAD_ERROR",
|
||||
fmt.Sprintf("download failed with status %d: %s", resp.StatusCode, string(body[:min(len(body), 200)])), nil)
|
||||
}
|
||||
|
||||
_, err = io.Copy(writer, resp.Body)
|
||||
if err != nil {
|
||||
c.ClearDownloadDomain()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -407,11 +438,27 @@ func (c *Client) GetFileInfo(ctx context.Context, key string) (*oss.FileInfo, er
|
||||
return nil, oss.NewError("STAT_ERROR", fmt.Sprintf("stat failed with status %d: %s", resp.StatusCode, string(body)), nil)
|
||||
}
|
||||
|
||||
// 解析响应 (简化实现)
|
||||
// 实际响应格式: {"hash":"xxx","fsize":123,"mimeType":"xxx","putTime":123}
|
||||
// 这里返回一个简化的 FileInfo
|
||||
var statResp struct {
|
||||
Hash string `json:"hash"`
|
||||
Fsize int64 `json:"fsize"`
|
||||
MimeType string `json:"mimeType"`
|
||||
PutTime int64 `json:"putTime"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &statResp); err != nil {
|
||||
return nil, oss.NewError("STAT_ERROR", "failed to parse response", err)
|
||||
}
|
||||
|
||||
var modTime time.Time
|
||||
if statResp.PutTime > 0 {
|
||||
modTime = time.Unix(0, statResp.PutTime)
|
||||
}
|
||||
|
||||
return &oss.FileInfo{
|
||||
Key: key,
|
||||
Key: key,
|
||||
Size: statResp.Fsize,
|
||||
ETag: statResp.Hash,
|
||||
ContentType: statResp.MimeType,
|
||||
LastModified: modTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -471,11 +518,16 @@ func (c *Client) ListFiles(ctx context.Context, options *oss.ListOptions) (*oss.
|
||||
// 转换为统一格式
|
||||
files := make([]oss.FileInfo, 0, len(listResp.Items))
|
||||
for _, item := range listResp.Items {
|
||||
var modTime time.Time
|
||||
if item.PutTime > 0 {
|
||||
modTime = time.Unix(0, item.PutTime)
|
||||
}
|
||||
files = append(files, oss.FileInfo{
|
||||
Key: item.Key,
|
||||
Size: item.Fsize,
|
||||
ETag: item.Hash,
|
||||
ContentType: item.MimeType,
|
||||
Key: item.Key,
|
||||
Size: item.Fsize,
|
||||
ETag: item.Hash,
|
||||
ContentType: item.MimeType,
|
||||
LastModified: modTime,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -488,27 +540,22 @@ func (c *Client) ListFiles(ctx context.Context, options *oss.ListOptions) (*oss.
|
||||
}
|
||||
|
||||
// GetSignedURL 获取预签名URL
|
||||
// 签名格式: hmac_sha1(SecretKey, "<downloadURL>?e=<deadline>")
|
||||
func (c *Client) GetSignedURL(ctx context.Context, key string, expiresIn time.Duration) (string, error) {
|
||||
// 七牛云私有空间下载需要生成私有下载 URL
|
||||
deadline := time.Now().Add(expiresIn).Unix()
|
||||
|
||||
// 构建 download URL
|
||||
baseURL, err := c.resolveDownloadDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
downloadURL := fmt.Sprintf("%s/%s", baseURL, key)
|
||||
|
||||
// 生成签名
|
||||
// 签名字符串 = 完整 URL + ?e=deadline
|
||||
urlToSign := fmt.Sprintf("%s/%s?e=%d", baseURL, key, deadline)
|
||||
h := hmac.New(sha1.New, []byte(c.config.SecretKey))
|
||||
signStr := fmt.Sprintf("%s\n%d", downloadURL, deadline)
|
||||
h.Write([]byte(signStr))
|
||||
h.Write([]byte(urlToSign))
|
||||
sign := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
// 构建最终 URL
|
||||
signedURL := fmt.Sprintf("%s?e=%d&token=%s:%s", downloadURL, deadline, c.config.AccessKey, sign)
|
||||
|
||||
return signedURL, nil
|
||||
return fmt.Sprintf("%s&token=%s:%s", urlToSign, c.config.AccessKey, sign), nil
|
||||
}
|
||||
|
||||
// Copy 复制文件
|
||||
|
||||
190
internal/oss/qiniu/client_test.go
Normal file
190
internal/oss/qiniu/client_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package qiniu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/oss"
|
||||
)
|
||||
|
||||
// 临时测试配置 — 提交前删除此文件
|
||||
func testConfig() *Config {
|
||||
return &Config{
|
||||
AccessKey: "eUjiDJGy9CkRb3-Ad3jCubPrm49xeBTesHYckIwc",
|
||||
SecretKey: "LE8XL-LmoMkpy0jNK-kDhgL_w7A6MRXD1Msqd1Y4",
|
||||
Bucket: "u-res",
|
||||
Region: "as0",
|
||||
UseHTTPS: true,
|
||||
}
|
||||
}
|
||||
|
||||
const testKey = "music/03.一人一首成名曲【特调音源】/001.雨一直下-张宇.mp3"
|
||||
|
||||
// TestListBuckets 列举桶
|
||||
func TestListBuckets(t *testing.T) {
|
||||
buckets, err := ListBuckets("eUjiDJGy9CkRb3-Ad3jCubPrm49xeBTesHYckIwc", "LE8XL-LmoMkpy0jNK-kDhgL_w7A6MRXD1Msqd1Y4")
|
||||
if err != nil {
|
||||
t.Fatalf("ListBuckets 失败: %v", err)
|
||||
}
|
||||
for _, b := range buckets {
|
||||
t.Logf("桶: %s 区域: %s", b.Name, b.Region)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetBucketDomains 获取桶域名
|
||||
func TestGetBucketDomains(t *testing.T) {
|
||||
c, err := NewClient(testConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
domains, err := c.GetBucketDomains(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("获取域名失败: %v", err)
|
||||
}
|
||||
t.Logf("桶域名: %v", domains)
|
||||
}
|
||||
|
||||
// TestDownloadDirect 裸 URL 下载(测试桶公开/私有)
|
||||
func TestDownloadDirect(t *testing.T) {
|
||||
c, err := NewClient(testConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
domain, err := c.resolveDownloadDomain()
|
||||
if err != nil {
|
||||
t.Fatalf("获取下载域名失败: %v", err)
|
||||
}
|
||||
t.Logf("下载域名: %s", domain)
|
||||
|
||||
rawURL := fmt.Sprintf("%s/%s", domain, testKey)
|
||||
t.Logf("裸 URL: %s", rawURL)
|
||||
|
||||
httpResp, err := http.Get(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("请求失败: %v", err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
t.Logf("裸 URL 状态码: %d", httpResp.StatusCode)
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.ReadFrom(httpResp.Body)
|
||||
t.Logf("响应大小: %d bytes", buf.Len())
|
||||
}
|
||||
|
||||
// TestDownloadSigned 签名 URL 下载
|
||||
func TestDownloadSigned(t *testing.T) {
|
||||
c, err := NewClient(testConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
signedURL, err := c.GetSignedURL(context.Background(), testKey, 1*time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("生成签名 URL 失败: %v", err)
|
||||
}
|
||||
t.Logf("签名 URL: %s...", signedURL[:min(120, len(signedURL))])
|
||||
|
||||
httpResp, err := http.Get(signedURL)
|
||||
if err != nil {
|
||||
t.Fatalf("请求失败: %v", err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
t.Logf("签名 URL 状态码: %d", httpResp.StatusCode)
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.ReadFrom(httpResp.Body)
|
||||
t.Logf("下载大小: %d bytes", buf.Len())
|
||||
|
||||
if httpResp.StatusCode != 200 {
|
||||
t.Errorf("下载失败: %d, body: %s", httpResp.StatusCode, buf.String()[:min(200, buf.Len())])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDownloadViaClient 通过 Client.Download 方法下载
|
||||
func TestDownloadViaClient(t *testing.T) {
|
||||
c, err := NewClient(testConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = c.Download(context.Background(), testKey, &buf)
|
||||
if err != nil {
|
||||
t.Errorf("Client.Download 失败: %v", err)
|
||||
} else {
|
||||
t.Logf("Client.Download 成功,大小: %d bytes (预期 ~7MB)", buf.Len())
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetFileInfo 获取文件信息
|
||||
func TestGetFileInfo(t *testing.T) {
|
||||
c, err := NewClient(testConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
info, err := c.GetFileInfo(context.Background(), testKey)
|
||||
if err != nil {
|
||||
t.Errorf("GetFileInfo 失败: %v", err)
|
||||
} else {
|
||||
t.Logf("GetFileInfo: key=%s size=%d", info.Key, info.Size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestListFiles 列举文件
|
||||
func TestListFiles(t *testing.T) {
|
||||
c, err := NewClient(testConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
result, err := c.ListFiles(context.Background(), &oss.ListOptions{Prefix: "music/", MaxKeys: 10})
|
||||
if err != nil {
|
||||
t.Fatalf("ListFiles 失败: %v", err)
|
||||
}
|
||||
for _, f := range result.Files {
|
||||
t.Logf("文件: %-80s size: %d", f.Key, f.Size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestListFilesRaw 原始 RSF 请求查看响应结构
|
||||
func TestListFilesRaw(t *testing.T) {
|
||||
c, err := NewClient(testConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
resp, err := c.doRSFRequest("GET", fmt.Sprintf("/list?bucket=%s&limit=3&prefix=music/", testConfig().Bucket))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.ReadFrom(resp.Body)
|
||||
|
||||
var pretty bytes.Buffer
|
||||
json.Indent(&pretty, buf.Bytes(), "", " ")
|
||||
t.Logf("原始响应:\n%s", pretty.String())
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -156,8 +156,8 @@ func (uc *UploadClient) Upload(ctx context.Context, key string, reader io.Reader
|
||||
}
|
||||
|
||||
var uploadURL string
|
||||
if uc.config.UploadDomain != "" {
|
||||
uploadURL = scheme + uc.config.UploadDomain
|
||||
if uc.config.DownloadDomain != "" {
|
||||
uploadURL = scheme + uc.config.DownloadDomain
|
||||
} else {
|
||||
// 根据区域选择
|
||||
switch uc.config.Region {
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -15,6 +17,7 @@ import (
|
||||
"u-desk/internal/oss"
|
||||
"u-desk/internal/oss/aliyun"
|
||||
"u-desk/internal/oss/qiniu"
|
||||
"u-desk/internal/storage"
|
||||
)
|
||||
|
||||
// accountCredentials 账户级凭据
|
||||
@@ -36,17 +39,19 @@ var globalManager = &Manager{}
|
||||
|
||||
func GetManager() *Manager { return globalManager }
|
||||
|
||||
// Connect 建立账户级连接(验证凭据通过 ListBuckets)
|
||||
// Connect 建立账户级连接(验证凭据通过 ListBuckets,同时缓存桶区域)
|
||||
func (m *Manager) Connect(provider, accessKey, secretKey, endpoint string) error {
|
||||
// 验证凭据
|
||||
var entries []oss.BucketEntry
|
||||
var err error
|
||||
|
||||
switch provider {
|
||||
case "qiniu":
|
||||
_, err := qiniu.ListBuckets(accessKey, secretKey)
|
||||
entries, err = qiniu.ListBuckets(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("七牛云连接失败: %w", err)
|
||||
}
|
||||
case "aliyun":
|
||||
_, err := aliyun.ListBuckets(accessKey, secretKey, endpoint)
|
||||
entries, err = aliyun.ListBuckets(accessKey, secretKey, endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("阿里云连接失败: %w", err)
|
||||
}
|
||||
@@ -54,6 +59,13 @@ func (m *Manager) Connect(provider, accessKey, secretKey, endpoint string) error
|
||||
return fmt.Errorf("不支持的 OSS 提供商: %s", provider)
|
||||
}
|
||||
|
||||
// 连接时立即缓存桶区域,避免后续操作因缺少 region 使用默认区域
|
||||
for _, e := range entries {
|
||||
if e.Region != "" {
|
||||
m.bucketRegions.Store(provider+":"+e.Name, e.Region)
|
||||
}
|
||||
}
|
||||
|
||||
m.accounts.Store(provider, &accountCredentials{
|
||||
Provider: provider,
|
||||
AccessKey: accessKey,
|
||||
@@ -76,10 +88,15 @@ func (m *Manager) getOrCreateBucketClient(provider, bucket, region string) (oss.
|
||||
}
|
||||
c := cred.(*accountCredentials)
|
||||
|
||||
// 如果未传 region,从缓存取
|
||||
// 如果未传 region,从缓存取;仍为空则主动探测
|
||||
if region == "" {
|
||||
if v, ok := m.bucketRegions.Load(key); ok {
|
||||
region = v.(string)
|
||||
} else {
|
||||
region = m.detectBucketRegion(provider, bucket, c)
|
||||
if region != "" {
|
||||
m.bucketRegions.Store(key, region)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,12 +113,17 @@ func (m *Manager) getOrCreateBucketClient(provider, bucket, region string) (oss.
|
||||
UseHTTPS: true,
|
||||
})
|
||||
case "aliyun":
|
||||
// 有桶级 region 时不传账户 Endpoint,让 NewClient 从 region 派生正确的 endpoint
|
||||
ep := c.Endpoint
|
||||
if region != "" {
|
||||
ep = ""
|
||||
}
|
||||
client, err = aliyun.NewClient(&aliyun.Config{
|
||||
AccessKeyID: c.AccessKey,
|
||||
AccessKeySecret: c.SecretKey,
|
||||
Bucket: bucket,
|
||||
Region: region,
|
||||
Endpoint: c.Endpoint,
|
||||
Endpoint: ep,
|
||||
UseHTTPS: true,
|
||||
})
|
||||
default:
|
||||
@@ -116,6 +138,41 @@ func (m *Manager) getOrCreateBucketClient(provider, bucket, region string) (oss.
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// detectBucketRegion 主动探测桶区域(缓存未命中时调用)
|
||||
func (m *Manager) detectBucketRegion(provider, bucket string, c *accountCredentials) string {
|
||||
switch provider {
|
||||
case "aliyun":
|
||||
entries, err := aliyun.ListBuckets(c.AccessKey, c.SecretKey, c.Endpoint)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
key := provider + ":" + e.Name
|
||||
if e.Region != "" {
|
||||
m.bucketRegions.Store(key, e.Region)
|
||||
}
|
||||
if e.Name == bucket {
|
||||
return e.Region
|
||||
}
|
||||
}
|
||||
case "qiniu":
|
||||
entries, err := qiniu.ListBuckets(c.AccessKey, c.SecretKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
key := provider + ":" + e.Name
|
||||
if e.Region != "" {
|
||||
m.bucketRegions.Store(key, e.Region)
|
||||
}
|
||||
if e.Name == bucket {
|
||||
return e.Region
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetClient 获取已有的桶级客户端
|
||||
func (m *Manager) GetClient(provider, bucket string) oss.OSSProvider {
|
||||
if v, ok := m.clients.Load(provider + ":" + bucket); ok {
|
||||
@@ -583,8 +640,255 @@ func (s *Service) RenamePath(connID string, oldPath string, newPath string) (*fi
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DownloadToTemp 下载文件到本地临时目录
|
||||
// htmlResourceRegex 提取 HTML 资源引用的正则
|
||||
var htmlResourceRegex = regexp.MustCompile(`(?:src|href|data|poster)=["']([^"']+)["']`)
|
||||
var htmlCssUrlRegex = regexp.MustCompile(`url\(\s*["']?([^"')]+)["']?\s*\)`)
|
||||
|
||||
// DownloadSiteForPreview 下载 HTML 及其引用的资源到临时目录
|
||||
// 对绝对路径(/开头)从 HTML 目录逐级向上嗅探网站根目录
|
||||
func (s *Service) DownloadSiteForPreview(connID string, rawPath string) (string, error) {
|
||||
bucket, key := parseBucketPath(rawPath)
|
||||
if bucket == "" {
|
||||
return "", fmt.Errorf("路径中缺少桶名")
|
||||
}
|
||||
|
||||
c, err := s.manager.getOrCreateBucketClient(connID, bucket, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. 创建临时目录,保留 OSS 目录结构
|
||||
tmpDir, err := os.MkdirTemp("", "udesk-site-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建临时目录失败: %w", err)
|
||||
}
|
||||
|
||||
keyDir := path.Dir(key)
|
||||
var htmlLocalPath string
|
||||
if keyDir != "" && keyDir != "." {
|
||||
htmlLocalPath = filepath.Join(tmpDir, filepath.FromSlash(keyDir), path.Base(key))
|
||||
if err := os.MkdirAll(filepath.Dir(htmlLocalPath), 0755); err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
htmlLocalPath = filepath.Join(tmpDir, path.Base(key))
|
||||
}
|
||||
|
||||
// 2. 下载 HTML
|
||||
f, err := os.Create(htmlLocalPath)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
return "", fmt.Errorf("创建临时文件失败: %w", err)
|
||||
}
|
||||
if err := c.Download(ctx, key, f); err != nil {
|
||||
f.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
return "", fmt.Errorf("下载 HTML 失败: %w", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// 3. 解析 HTML 提取资源路径
|
||||
htmlContent, err := os.ReadFile(htmlLocalPath)
|
||||
if err != nil {
|
||||
return htmlLocalPath, nil // HTML 已下载,资源解析失败不影响
|
||||
}
|
||||
resources := extractHtmlResources(string(htmlContent))
|
||||
|
||||
// 4. 下载资源
|
||||
htmlOssDir := keyDir
|
||||
if htmlOssDir == "." {
|
||||
htmlOssDir = ""
|
||||
}
|
||||
htmlLocalDir := filepath.Dir(htmlLocalPath)
|
||||
|
||||
var siteRoot string
|
||||
var discoveredDirs []string
|
||||
seenDir := make(map[string]bool)
|
||||
recordDir := func(ossKey string) {
|
||||
dir := path.Dir(ossKey)
|
||||
if !seenDir[dir] {
|
||||
seenDir[dir] = true
|
||||
discoveredDirs = append(discoveredDirs, dir)
|
||||
}
|
||||
}
|
||||
|
||||
for _, resPath := range resources {
|
||||
if shouldSkipResource(resPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
isAbsolute := strings.HasPrefix(resPath, "/")
|
||||
cleanPath := strings.TrimPrefix(resPath, "/")
|
||||
cleanPath = strings.TrimPrefix(cleanPath, "./")
|
||||
if cleanPath == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
localPath := filepath.Join(htmlLocalDir, filepath.FromSlash(cleanPath))
|
||||
|
||||
if isAbsolute {
|
||||
resolvedKey := resolveAndDownload(c, ctx, htmlOssDir, cleanPath, localPath, &siteRoot)
|
||||
if resolvedKey != "" {
|
||||
recordDir(resolvedKey)
|
||||
}
|
||||
} else {
|
||||
var ossKey string
|
||||
if htmlOssDir != "" {
|
||||
ossKey = htmlOssDir + "/" + cleanPath
|
||||
} else {
|
||||
ossKey = cleanPath
|
||||
}
|
||||
if downloadResource(c, ctx, ossKey, localPath) {
|
||||
recordDir(ossKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 5. 补充下载已发现目录中的剩余文件(覆盖 webpack 动态 chunk 等)
|
||||
for _, dir := range discoveredDirs {
|
||||
supplementDir(c, ctx, dir, tmpDir, siteRoot)
|
||||
}
|
||||
|
||||
return htmlLocalPath, nil
|
||||
}
|
||||
|
||||
// resolveAbsoluteResourcePath 解析绝对路径资源,首次嗅探网站根,后续直接使用
|
||||
// resolveAndDownload 解析绝对路径并下载:首次嗅探网站根,后续直接使用
|
||||
func resolveAndDownload(c oss.OSSProvider, ctx context.Context, htmlOssDir string, cleanPath string, localPath string, siteRoot *string) string {
|
||||
if *siteRoot != "" {
|
||||
downloadResource(c, ctx, *siteRoot+cleanPath, localPath)
|
||||
return *siteRoot + cleanPath
|
||||
}
|
||||
|
||||
// 从 HTML 目录向上逐级探测(同时下载,成功即完成嗅探)
|
||||
dir := htmlOssDir
|
||||
for {
|
||||
var candidate string
|
||||
if dir == "" {
|
||||
candidate = cleanPath
|
||||
} else {
|
||||
candidate = dir + "/" + cleanPath
|
||||
}
|
||||
|
||||
if downloadResource(c, ctx, candidate, localPath) {
|
||||
if dir == "" {
|
||||
*siteRoot = ""
|
||||
} else {
|
||||
*siteRoot = dir + "/"
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
if dir == "" {
|
||||
break
|
||||
}
|
||||
parent := path.Dir(dir)
|
||||
if parent == dir || parent == "." {
|
||||
dir = ""
|
||||
} else {
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// downloadResource 下载资源到本地(失败静默,返回是否成功)
|
||||
func downloadResource(c oss.OSSProvider, ctx context.Context, ossKey string, localPath string) bool {
|
||||
if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
|
||||
return false
|
||||
}
|
||||
f, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if err := c.Download(ctx, ossKey, f); err != nil {
|
||||
f.Close()
|
||||
os.Remove(localPath)
|
||||
return false
|
||||
}
|
||||
f.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
// supplementDir 补充下载远程目录中尚未下载的文件(只处理已知资源所在目录)
|
||||
func supplementDir(c oss.OSSProvider, ctx context.Context, remoteDir string, tmpDir string, siteRoot string) {
|
||||
prefix := remoteDir + "/"
|
||||
result, err := c.ListFiles(ctx, &oss.ListOptions{Prefix: prefix, MaxKeys: 200})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, f := range result.Files {
|
||||
if strings.HasSuffix(f.Key, "/") || f.Size == 0 {
|
||||
continue
|
||||
}
|
||||
relPath := strings.TrimPrefix(f.Key, siteRoot)
|
||||
localPath := filepath.Join(tmpDir, filepath.FromSlash(relPath))
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
continue
|
||||
}
|
||||
downloadResource(c, ctx, f.Key, localPath)
|
||||
}
|
||||
}
|
||||
func extractHtmlResources(html string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var resources []string
|
||||
|
||||
add := func(v string) {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" && !seen[v] {
|
||||
seen[v] = true
|
||||
resources = append(resources, v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, m := range htmlResourceRegex.FindAllStringSubmatch(html, -1) {
|
||||
if len(m) > 1 {
|
||||
add(m[1])
|
||||
}
|
||||
}
|
||||
for _, m := range htmlCssUrlRegex.FindAllStringSubmatch(html, -1) {
|
||||
if len(m) > 1 {
|
||||
add(m[1])
|
||||
}
|
||||
}
|
||||
|
||||
return resources
|
||||
}
|
||||
|
||||
// shouldSkipResource 判断资源路径是否应跳过
|
||||
func shouldSkipResource(p string) bool {
|
||||
return strings.HasPrefix(p, "data:") ||
|
||||
strings.HasPrefix(p, "http://") ||
|
||||
strings.HasPrefix(p, "https://") ||
|
||||
strings.HasPrefix(p, "//") ||
|
||||
strings.HasPrefix(p, "#") ||
|
||||
strings.HasPrefix(p, "javascript:") ||
|
||||
strings.HasPrefix(p, "mailto:") ||
|
||||
strings.HasPrefix(p, "blob:")
|
||||
}
|
||||
|
||||
// DownloadToTemp 下载文件到本地临时目录(带 SQLite 缓存)
|
||||
func (s *Service) DownloadToTemp(connID string, rawPath string) (string, error) {
|
||||
// 先获取文件元信息用于缓存键,确保远程文件变更时能淘汰旧缓存
|
||||
var fileSize int64
|
||||
var modTime string
|
||||
if info, err := s.GetFileInfo(connID, rawPath); err == nil {
|
||||
if sz, ok := info["size"].(int64); ok {
|
||||
fileSize = sz
|
||||
}
|
||||
if mt, ok := info["mod_time"].(string); ok {
|
||||
modTime = mt
|
||||
}
|
||||
}
|
||||
return storage.DownloadToTempCached("oss", connID, rawPath, fileSize, modTime, func() (string, error) {
|
||||
return s.downloadToTempDirect(connID, rawPath)
|
||||
})
|
||||
}
|
||||
|
||||
// downloadToTempDirect 实际执行下载(无缓存)
|
||||
func (s *Service) downloadToTempDirect(connID string, rawPath string) (string, error) {
|
||||
bucket, key := parseBucketPath(rawPath)
|
||||
if bucket == "" {
|
||||
return "", fmt.Errorf("路径中缺少桶名")
|
||||
@@ -609,6 +913,13 @@ func (s *Service) DownloadToTemp(connID string, rawPath string) (string, error)
|
||||
return localPath, nil
|
||||
}
|
||||
|
||||
// DownloadToTempCached 带缓存的 OSS 下载(命中缓存直接返回本地路径,支持传入文件元信息)
|
||||
func (s *Service) DownloadToTempCached(connID, rawPath string, fileSize int64, modTime string) (string, error) {
|
||||
return storage.DownloadToTempCached("oss", connID, rawPath, fileSize, modTime, func() (string, error) {
|
||||
return s.downloadToTempDirect(connID, rawPath)
|
||||
})
|
||||
}
|
||||
|
||||
// GetCommonPaths 返回常用路径
|
||||
func (s *Service) GetCommonPaths(connID string) (map[string]string, error) {
|
||||
return map[string]string{
|
||||
|
||||
@@ -8,15 +8,22 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/filesystem"
|
||||
"u-desk/internal/storage"
|
||||
|
||||
sftpclient "github.com/pkg/sftp"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpResRegex = regexp.MustCompile(`(?:src|href|data|poster)=["']([^"']+)["']`)
|
||||
sftpCssUrlRe = regexp.MustCompile(`url\(\s*["']?([^"')]+)["']?\s*\)`)
|
||||
)
|
||||
|
||||
// Service SFTP 文件操作服务
|
||||
type Service struct {
|
||||
manager *Manager
|
||||
@@ -257,8 +264,26 @@ func (s *Service) RenamePath(connID string, oldPath, newPath string) (*filesyste
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DownloadToTemp 下载远程文件到本地临时目录(用于预览)
|
||||
// DownloadToTemp 下载远程文件到本地临时目录(带 SQLite 缓存)
|
||||
func (s *Service) DownloadToTemp(connID string, remotePath string) (string, error) {
|
||||
// 先获取文件元信息用于缓存键,确保远程文件变更时能淘汰旧缓存
|
||||
var fileSize int64
|
||||
var modTime string
|
||||
if info, err := s.GetFileInfo(connID, remotePath); err == nil {
|
||||
if sz, ok := info["size"].(int64); ok {
|
||||
fileSize = sz
|
||||
}
|
||||
if mt, ok := info["mod_time"].(string); ok {
|
||||
modTime = mt
|
||||
}
|
||||
}
|
||||
return storage.DownloadToTempCached("sftp", connID, remotePath, fileSize, modTime, func() (string, error) {
|
||||
return s.downloadToTempDirect(connID, remotePath)
|
||||
})
|
||||
}
|
||||
|
||||
// downloadToTempDirect 实际执行下载(无缓存)
|
||||
func (s *Service) downloadToTempDirect(connID string, remotePath string) (string, error) {
|
||||
c, err := s.getClient(connID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -303,12 +328,201 @@ func (s *Service) DownloadToTemp(connID string, remotePath string) (string, erro
|
||||
return e
|
||||
})
|
||||
if err != nil {
|
||||
os.Remove(localPath)
|
||||
return "", fmt.Errorf("下载文件失败: %w", err)
|
||||
}
|
||||
|
||||
return localPath, nil
|
||||
}
|
||||
|
||||
// DownloadToTempCached 带缓存的 SFTP 下载(支持传入文件元信息)
|
||||
func (s *Service) DownloadToTempCached(connID, remotePath string, fileSize int64, modTime string) (string, error) {
|
||||
return storage.DownloadToTempCached("sftp", connID, remotePath, fileSize, modTime, func() (string, error) {
|
||||
return s.downloadToTempDirect(connID, remotePath)
|
||||
})
|
||||
}
|
||||
|
||||
// DownloadSiteForPreview 下载 HTML 及其网站资源到本地临时目录
|
||||
func (s *Service) DownloadSiteForPreview(connID string, remotePath string) (string, error) {
|
||||
c, err := s.getClient(connID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 1. 创建临时目录
|
||||
tmpDir, err := os.MkdirTemp("", "udesk-sftp-site-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建临时目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 确定远程网站根目录(从 HTML 路径推断)
|
||||
keyDir := path.Dir(remotePath)
|
||||
if keyDir == "." {
|
||||
keyDir = ""
|
||||
}
|
||||
|
||||
var htmlLocalPath string
|
||||
if keyDir != "" {
|
||||
htmlLocalPath = filepath.Join(tmpDir, filepath.FromSlash(keyDir), path.Base(remotePath))
|
||||
if err := os.MkdirAll(filepath.Dir(htmlLocalPath), 0755); err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
htmlLocalPath = filepath.Join(tmpDir, path.Base(remotePath))
|
||||
}
|
||||
|
||||
// 3. 下载 HTML
|
||||
if err := s.sftpDownloadFile(c, remotePath, htmlLocalPath); err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
return "", fmt.Errorf("下载 HTML 失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 解析 HTML 提取资源路径
|
||||
htmlContent, err := os.ReadFile(htmlLocalPath)
|
||||
if err != nil {
|
||||
return htmlLocalPath, nil
|
||||
}
|
||||
resources := sftpExtractResources(string(htmlContent))
|
||||
|
||||
// 5. 下载静态引用资源(嗅探网站根)
|
||||
htmlRemoteDir := keyDir
|
||||
if htmlRemoteDir == "/" {
|
||||
htmlRemoteDir = ""
|
||||
}
|
||||
htmlLocalDir := filepath.Dir(htmlLocalPath)
|
||||
|
||||
var siteRoot string
|
||||
var discoveredDirs []string
|
||||
seenDir := make(map[string]bool)
|
||||
recordDir := func(remoteKey string) {
|
||||
dir := path.Dir(remoteKey)
|
||||
if !seenDir[dir] {
|
||||
seenDir[dir] = true
|
||||
discoveredDirs = append(discoveredDirs, dir)
|
||||
}
|
||||
}
|
||||
|
||||
for _, resPath := range resources {
|
||||
if sftpShouldSkip(resPath) {
|
||||
continue
|
||||
}
|
||||
isAbsolute := strings.HasPrefix(resPath, "/")
|
||||
cleanPath := strings.TrimPrefix(resPath, "/")
|
||||
cleanPath = strings.TrimPrefix(cleanPath, "./")
|
||||
if cleanPath == "" {
|
||||
continue
|
||||
}
|
||||
localPath := filepath.Join(htmlLocalDir, filepath.FromSlash(cleanPath))
|
||||
|
||||
if isAbsolute {
|
||||
resolvedKey := sftpResolveAndDownload(s, c, htmlRemoteDir, cleanPath, localPath, &siteRoot)
|
||||
if resolvedKey != "" {
|
||||
recordDir(resolvedKey)
|
||||
}
|
||||
} else {
|
||||
remoteKey := path.Join(htmlRemoteDir, cleanPath)
|
||||
if s.sftpTryDownload(c, remoteKey, localPath) {
|
||||
recordDir(remoteKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 补充下载已发现目录中的剩余文件(覆盖动态 chunk 等)
|
||||
for _, dir := range discoveredDirs {
|
||||
sftpSupplementDir(s, c, dir, tmpDir, siteRoot)
|
||||
}
|
||||
|
||||
return htmlLocalPath, nil
|
||||
}
|
||||
|
||||
// sftpDownloadFile 下载单个远程文件到本地路径
|
||||
func (s *Service) sftpDownloadFile(c *Client, remotePath, localPath string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.WithRetry(func(sc *sftpclient.Client) error {
|
||||
src, err := sc.Open(remotePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer src.Close()
|
||||
dst, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dst.Close()
|
||||
_, err = io.Copy(dst, src)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// sftpTryDownload 尝试下载(失败静默,返回是否成功)
|
||||
func (s *Service) sftpTryDownload(c *Client, remotePath, localPath string) bool {
|
||||
err := s.sftpDownloadFile(c, remotePath, localPath)
|
||||
if err != nil {
|
||||
os.Remove(localPath)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// sftpResolveAndDownload 嗅探网站根并下载绝对路径资源
|
||||
func sftpResolveAndDownload(s *Service, c *Client, htmlDir string, cleanPath string, localPath string, siteRoot *string) string {
|
||||
if *siteRoot != "" {
|
||||
s.sftpTryDownload(c, *siteRoot+cleanPath, localPath)
|
||||
return *siteRoot + cleanPath
|
||||
}
|
||||
dir := htmlDir
|
||||
for {
|
||||
candidate := path.Join(dir, cleanPath)
|
||||
if s.sftpTryDownload(c, candidate, localPath) {
|
||||
if dir == "" {
|
||||
*siteRoot = ""
|
||||
} else {
|
||||
*siteRoot = dir + "/"
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if dir == "" {
|
||||
break
|
||||
}
|
||||
parent := path.Dir(dir)
|
||||
if parent == dir || parent == "." {
|
||||
dir = ""
|
||||
} else {
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sftpSupplementDir 补充下载远程目录中尚未下载的文件(只处理已知资源所在目录)
|
||||
func sftpSupplementDir(s *Service, c *Client, remoteDir string, tmpDir string, siteRoot string) {
|
||||
var entries []fs.FileInfo
|
||||
err := c.WithRetry(func(sc *sftpclient.Client) error {
|
||||
var e error
|
||||
entries, e = sc.ReadDir(remoteDir)
|
||||
return e
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || entry.Size() == 0 {
|
||||
continue
|
||||
}
|
||||
fullPath := path.Join(remoteDir, entry.Name())
|
||||
relPath := strings.TrimPrefix(fullPath, siteRoot)
|
||||
relPath = strings.TrimPrefix(relPath, "/")
|
||||
localPath := filepath.Join(tmpDir, filepath.FromSlash(relPath))
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
continue
|
||||
}
|
||||
s.sftpTryDownload(c, fullPath, localPath)
|
||||
}
|
||||
}
|
||||
|
||||
// GetCommonPaths 返回 SFTP 远程主机常用路径
|
||||
func (s *Service) GetCommonPaths(connID string) (map[string]string, error) {
|
||||
c := s.manager.GetClient(connID)
|
||||
@@ -331,18 +545,9 @@ func (s *Service) GetCommonPaths(connID string) (map[string]string, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CleanupTempFiles 清理遗留的临时预览文件
|
||||
// CleanupTempFiles 清理遗留的临时预览文件(已由 SQLite 缓存接管)
|
||||
func CleanupTempFiles() {
|
||||
tmpDir := os.TempDir()
|
||||
entries, err := os.ReadDir(tmpDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if strings.HasPrefix(entry.Name(), "udesk-sftp-preview-") {
|
||||
os.Remove(filepath.Join(tmpDir, entry.Name()))
|
||||
}
|
||||
}
|
||||
storage.CleanupExpiredCache()
|
||||
}
|
||||
|
||||
// GetSystemInfo 通过 SSH 命令采集远程系统信息(磁盘/CPU/内存)
|
||||
@@ -501,3 +706,39 @@ func toFileOperationResult(m map[string]interface{}, isDir bool) *filesystem.Fil
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// sftpExtractResources 从 HTML 内容提取资源路径
|
||||
func sftpExtractResources(html string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var resources []string
|
||||
add := func(v string) {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" && !seen[v] {
|
||||
seen[v] = true
|
||||
resources = append(resources, v)
|
||||
}
|
||||
}
|
||||
for _, m := range sftpResRegex.FindAllStringSubmatch(html, -1) {
|
||||
if len(m) > 1 {
|
||||
add(m[1])
|
||||
}
|
||||
}
|
||||
for _, m := range sftpCssUrlRe.FindAllStringSubmatch(html, -1) {
|
||||
if len(m) > 1 {
|
||||
add(m[1])
|
||||
}
|
||||
}
|
||||
return resources
|
||||
}
|
||||
|
||||
// sftpShouldSkip 判断资源路径是否应跳过
|
||||
func sftpShouldSkip(p string) bool {
|
||||
return strings.HasPrefix(p, "data:") ||
|
||||
strings.HasPrefix(p, "http://") ||
|
||||
strings.HasPrefix(p, "https://") ||
|
||||
strings.HasPrefix(p, "//") ||
|
||||
strings.HasPrefix(p, "#") ||
|
||||
strings.HasPrefix(p, "javascript:") ||
|
||||
strings.HasPrefix(p, "mailto:") ||
|
||||
strings.HasPrefix(p, "blob:")
|
||||
}
|
||||
|
||||
187
internal/storage/download_cache.go
Normal file
187
internal/storage/download_cache.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/storage/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const downloadCacheTTL = 24 * time.Hour
|
||||
|
||||
// cacheTempDir 确定性临时目录
|
||||
var cacheTempDir = filepath.Join(os.TempDir(), "u-desk-cache")
|
||||
|
||||
// GetCachedPath 查询缓存,验证文件存在后返回本地路径
|
||||
func GetCachedPath(transport, connID, remotePath string, fileSize int64, modTime string) (string, bool) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var entry models.DownloadCache
|
||||
err := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?",
|
||||
transport, connID, remotePath, fileSize, modTime).First(&entry).Error
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// 检查文件是否仍然存在于磁盘
|
||||
if _, err := os.Stat(entry.LocalPath); err != nil {
|
||||
// 文件已丢失,清理过期记录
|
||||
db.Delete(&entry)
|
||||
return "", false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Since(entry.DownloadedAt) > downloadCacheTTL {
|
||||
os.Remove(entry.LocalPath)
|
||||
db.Delete(&entry)
|
||||
return "", false
|
||||
}
|
||||
|
||||
return entry.LocalPath, true
|
||||
}
|
||||
|
||||
// SaveCache 保存或更新缓存记录
|
||||
func SaveCache(transport, connID, remotePath string, fileSize int64, modTime, localPath string) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var existing models.DownloadCache
|
||||
err := db.Where("transport = ? AND conn_id = ? AND remote_path = ? AND file_size = ? AND mod_time = ?",
|
||||
transport, connID, remotePath, fileSize, modTime).First(&existing).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
db.Create(&models.DownloadCache{
|
||||
Transport: transport,
|
||||
ConnID: connID,
|
||||
RemotePath: remotePath,
|
||||
FileSize: fileSize,
|
||||
ModTime: modTime,
|
||||
LocalPath: localPath,
|
||||
DownloadedAt: time.Now(),
|
||||
})
|
||||
} else if err == nil {
|
||||
db.Model(&existing).Updates(map[string]any{
|
||||
"local_path": localPath,
|
||||
"downloaded_at": time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupExpiredCache 清理超过 24h 的缓存记录并删除对应临时文件
|
||||
func CleanupExpiredCache() {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := time.Now().Add(-downloadCacheTTL)
|
||||
var expired []models.DownloadCache
|
||||
db.Where("downloaded_at < ?", cutoff).Find(&expired)
|
||||
|
||||
for _, entry := range expired {
|
||||
os.Remove(entry.LocalPath)
|
||||
db.Delete(&entry)
|
||||
}
|
||||
|
||||
if len(expired) > 0 {
|
||||
fmt.Printf("[下载缓存] 清理 %d 条过期记录\n", len(expired))
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadToTempCached 带缓存的下载:命中返回本地路径,未命中调用 downloadFn 后缓存结果
|
||||
func DownloadToTempCached(transport, connID, remotePath string, fileSize int64, modTime string, downloadFn func() (string, error)) (string, error) {
|
||||
// 1. 查缓存
|
||||
if localPath, hit := GetCachedPath(transport, connID, remotePath, fileSize, modTime); hit {
|
||||
return localPath, nil
|
||||
}
|
||||
|
||||
// 2. 缓存未命中,执行下载
|
||||
tempPath, err := downloadFn()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 3. 生成确定性路径并移动文件
|
||||
deterministicPath, err := deterministicCachePath(transport, connID, remotePath, fileSize, modTime)
|
||||
if err != nil {
|
||||
// 降级:直接使用 downloadFn 返回的路径,仍然缓存
|
||||
SaveCache(transport, connID, remotePath, fileSize, modTime, tempPath)
|
||||
return tempPath, nil
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(deterministicPath), 0755); err != nil {
|
||||
SaveCache(transport, connID, remotePath, fileSize, modTime, tempPath)
|
||||
return tempPath, nil
|
||||
}
|
||||
|
||||
// 移动文件到确定性路径
|
||||
if err := os.Rename(tempPath, deterministicPath); err != nil {
|
||||
// Rename 可能跨卷失败,尝试 Copy+Delete
|
||||
if copyFile(tempPath, deterministicPath) != nil {
|
||||
SaveCache(transport, connID, remotePath, fileSize, modTime, tempPath)
|
||||
return tempPath, nil
|
||||
}
|
||||
os.Remove(tempPath)
|
||||
}
|
||||
|
||||
SaveCache(transport, connID, remotePath, fileSize, modTime, deterministicPath)
|
||||
return deterministicPath, nil
|
||||
}
|
||||
|
||||
// deterministicCachePath 根据文件信息生成确定性的缓存路径
|
||||
func deterministicCachePath(transport, connID, remotePath string, fileSize int64, modTime string) (string, error) {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(fmt.Sprintf("%s:%s:%s:%d:%s", transport, connID, remotePath, fileSize, modTime)))
|
||||
hash := fmt.Sprintf("%x", h.Sum(nil))[:16]
|
||||
|
||||
baseName := filepath.Base(remotePath)
|
||||
if baseName == "" || baseName == "." || baseName == "/" {
|
||||
baseName = "file"
|
||||
}
|
||||
|
||||
// 截断过长的文件名
|
||||
if len(baseName) > 64 {
|
||||
ext := filepath.Ext(baseName)
|
||||
maxName := 64 - len(ext)
|
||||
if maxName <= 0 {
|
||||
maxName = 1
|
||||
ext = ext[:63]
|
||||
}
|
||||
baseName = baseName[:maxName] + ext
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("%s_%s", hash, baseName)
|
||||
return filepath.Join(cacheTempDir, fileName), nil
|
||||
}
|
||||
|
||||
// copyFile 复制文件内容
|
||||
func copyFile(src, dst string) error {
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if _, err := out.ReadFrom(in); err != nil {
|
||||
os.Remove(dst)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
12
internal/storage/models/bgm_playlist.go
Normal file
12
internal/storage/models/bgm_playlist.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package models
|
||||
|
||||
// BgmPlaylist BGM 播放列表持久化
|
||||
type BgmPlaylist struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Name string `gorm:"not null;size:255"`
|
||||
Path string `gorm:"not null;size:500;uniqueIndex"`
|
||||
ProfileID string `gorm:"type:varchar(50)" json:"profile_id"`
|
||||
Sort uint `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (BgmPlaylist) TableName() string { return "bgm_playlist" }
|
||||
@@ -11,7 +11,8 @@ type ConnectionProfile struct {
|
||||
Username string `gorm:"type:varchar(100);default:root" json:"username"`
|
||||
Password string `gorm:"type:text" json:"password"`
|
||||
KeyPath string `gorm:"type:text" json:"key_path"`
|
||||
Type string `gorm:"type:varchar(20);not null;index" json:"type"` // local|remote|sftp|qiniu|aliyun
|
||||
Type string `gorm:"type:varchar(20);not null;index" json:"type"` // local|remote|sftp|oss
|
||||
Provider string `gorm:"type:varchar(20)" json:"provider"` // qiniu|aliyun (仅 type=oss)
|
||||
Token string `gorm:"type:text" json:"token"`
|
||||
AccessKey string `gorm:"type:text" json:"access_key"`
|
||||
SecretKey string `gorm:"type:text" json:"secret_key"`
|
||||
|
||||
17
internal/storage/models/download_cache.go
Normal file
17
internal/storage/models/download_cache.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// DownloadCache 下载缓存模型(SQLite 持久化)
|
||||
type DownloadCache struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Transport string `gorm:"not null;size:10;index:idx_cache_lookup"`
|
||||
ConnID string `gorm:"not null;index:idx_cache_lookup"`
|
||||
RemotePath string `gorm:"not null;index:idx_cache_lookup"`
|
||||
FileSize int64 `gorm:"not null;index:idx_cache_lookup"`
|
||||
ModTime string `gorm:"not null;index:idx_cache_lookup"`
|
||||
LocalPath string `gorm:"not null"`
|
||||
DownloadedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (DownloadCache) TableName() string { return "download_cache" }
|
||||
@@ -53,10 +53,14 @@ func InitFast() (*gorm.DB, error) {
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
if e := db.AutoMigrate(&models.AppConfig{}, &models.ConnectionProfile{}); e != nil {
|
||||
if e := db.AutoMigrate(&models.AppConfig{}, &models.ConnectionProfile{}, &models.DownloadCache{}, &models.BgmPlaylist{}); e != nil {
|
||||
initErr = e
|
||||
return
|
||||
}
|
||||
// 数据迁移:qiniu/aliyun → oss + provider
|
||||
db.Exec("UPDATE connection_profiles SET provider = type, type = 'oss' WHERE type IN ('qiniu', 'aliyun')")
|
||||
// 为旧 BGM 播放列表补充 profile_id(找第一个 OSS profile)
|
||||
db.Exec("UPDATE bgm_playlist SET profile_id = (SELECT CAST(id AS VARCHAR) FROM connection_profiles WHERE type = 'oss' LIMIT 1) WHERE (profile_id = '' OR profile_id IS NULL) AND path NOT LIKE '%:'")
|
||||
globalDB = db
|
||||
})
|
||||
if initErr != nil {
|
||||
|
||||
Reference in New Issue
Block a user