Private
Public Access
1
0
Files
u-desk/internal/sftp/service.go

745 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sftp
import (
"encoding/base64"
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"u-desk/internal/filesystem"
"u-desk/internal/storage"
sftpclient "github.com/pkg/sftp"
)
var (
// sftpResRegex captures the quoted value of src, href, data and poster
// attributes (capture group 1 is the value between the quotes).
sftpResRegex = regexp.MustCompile(`(?:src|href|data|poster)=["']([^"']+)["']`)
// sftpCssUrlRe captures the target of CSS url(...) references, with or
// without surrounding quotes (capture group 1 is the target).
sftpCssUrlRe = regexp.MustCompile(`url\(\s*["']?([^"')]+)["']?\s*\)`)
)
// Service is the SFTP file-operation service.
type Service struct {
// manager is the connection manager used to look up clients by connection ID.
manager *Manager
}
// NewService creates an SFTP service instance backed by the manager that
// GetManager returns.
func NewService() *Service {
	svc := new(Service)
	svc.manager = GetManager()
	return svc
}
// GetManager returns the underlying connection manager (intended for callers
// in the App layer).
func (s *Service) GetManager() *Manager {
return s.manager
}
// ConnID builds the connection identifier for a host/port pair, formatted as
// "<host>:<port>".
func ConnID(host string, port int) string {
	const separator = ":"
	return host + separator + fmt.Sprint(port)
}
// --- 核心文件操作 ---
// ListDir reads the remote directory dirPath over the connection identified by
// connID and returns one map per entry with the keys "name", "path", "is_dir",
// "size" and "mod_time" (formatted as "2006-01-02 15:04:05").
func (s *Service) ListDir(connID string, dirPath string) ([]map[string]interface{}, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	var listing []fs.FileInfo
	readDir := func(sc *sftpclient.Client) error {
		items, readErr := sc.ReadDir(dirPath)
		listing = items
		return readErr
	}
	if err := client.WithRetry(readDir); err != nil {
		return nil, fmt.Errorf("读取目录失败: %w", err)
	}

	rows := make([]map[string]interface{}, len(listing))
	for i, entry := range listing {
		entryName := entry.Name()
		// Remote paths are slash-separated, so join with path, not filepath.
		entryPath := toUnixPath(path.Join(dirPath, entryName))
		rows[i] = map[string]interface{}{
			"name":     entryName,
			"path":     entryPath,
			"is_dir":   entry.IsDir(),
			"size":     entry.Size(),
			"mod_time": entry.ModTime().Format("2006-01-02 15:04:05"),
		}
	}
	return rows, nil
}
// ReadFile returns the content of the remote file filePath as a string.
//
// Files larger than 10 MiB are rejected. The limit is checked up front via
// Stat and enforced again while reading, so a file that grows between the two
// steps cannot be pulled into memory without bound.
func (s *Service) ReadFile(connID string, filePath string) (string, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return "", err
	}

	// Size cap (aligned with the 10MB limit of the local-mode ReadFile).
	const maxSize int64 = 10 << 20
	tooLarge := func(size int64) error {
		return fmt.Errorf("文件过大 (%s),超过 %d 限制", filesystem.FormatBytes(size), maxSize)
	}

	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Stat(filePath)
		if e != nil {
			return e
		}
		if fi.Size() > maxSize {
			return tooLarge(fi.Size())
		}
		return nil
	})
	if err != nil {
		return "", err
	}

	var data []byte
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.Open(filePath)
		if e != nil {
			return e
		}
		defer f.Close()
		// Read at most one byte past the limit: enough to detect an oversized
		// file without buffering all of it.
		data, e = io.ReadAll(io.LimitReader(f, maxSize+1))
		return e
	})
	if err != nil {
		return "", fmt.Errorf("读取文件失败: %w", err)
	}
	if int64(len(data)) > maxSize {
		return "", tooLarge(int64(len(data)))
	}
	return filesystem.BytesToString(data), nil
}
// WriteFile creates (or truncates) the remote file filePath and writes content
// to it.
//
// The error from closing the remote handle is returned as well, so a write
// that only fails at close time is not reported as success.
func (s *Service) WriteFile(connID string, filePath string, content string) error {
	c, err := s.getClient(connID)
	if err != nil {
		return err
	}
	return c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.Create(filePath)
		if e != nil {
			return fmt.Errorf("创建文件失败: %w", e)
		}
		if _, e = f.Write([]byte(content)); e != nil {
			// The write error is the one worth reporting; close is best-effort here.
			f.Close()
			return e
		}
		return f.Close()
	})
}
// WriteBase64File decodes base64Content (standard encoding) and writes the
// resulting bytes to the remote file filePath (used for pasted images and
// similar binary content).
//
// The error from closing the remote handle is returned as well, so a write
// that only fails at close time is not reported as success.
func (s *Service) WriteBase64File(connID string, filePath string, base64Content string) error {
	c, err := s.getClient(connID)
	if err != nil {
		return err
	}
	data, err := base64.StdEncoding.DecodeString(base64Content)
	if err != nil {
		return fmt.Errorf("base64 解码失败: %w", err)
	}
	return c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.Create(filePath)
		if e != nil {
			return fmt.Errorf("创建文件失败: %w", e)
		}
		if _, e = f.Write(data); e != nil {
			// The write error is the one worth reporting; close is best-effort here.
			f.Close()
			return e
		}
		return f.Close()
	})
}
// GetFileInfo stats the remote path filePath and returns its metadata as a map
// with the keys "name", "path", "size", "size_str", "is_dir", "mod_time"
// (formatted as "2006-01-02 15:04:05") and "mode".
func (s *Service) GetFileInfo(connID string, filePath string) (map[string]interface{}, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	var stat fs.FileInfo
	statErr := client.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Stat(filePath)
		stat = fi
		return e
	})
	if statErr != nil {
		return nil, fmt.Errorf("获取文件信息失败: %w", statErr)
	}

	byteCount := stat.Size()
	details := make(map[string]interface{}, 7)
	details["name"] = stat.Name()
	details["path"] = toUnixPath(filePath)
	details["size"] = byteCount
	details["size_str"] = filesystem.FormatBytes(byteCount)
	details["is_dir"] = stat.IsDir()
	details["mod_time"] = stat.ModTime().Format("2006-01-02 15:04:05")
	details["mode"] = stat.Mode().String()
	return details, nil
}
// CreateDir creates the remote directory dirPath (including missing parents,
// via MkdirAll) and returns a result describing it, flagged as a directory.
func (s *Service) CreateDir(connID string, dirPath string) (*filesystem.FileOperationResult, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	mkdir := func(sc *sftpclient.Client) error {
		return sc.MkdirAll(dirPath)
	}
	if mkErr := client.WithRetry(mkdir); mkErr != nil {
		return nil, fmt.Errorf("创建目录失败: %w", mkErr)
	}

	// The metadata lookup is best-effort: its error is ignored and a nil map
	// simply yields zero-valued result fields.
	meta, _ := s.GetFileInfo(connID, dirPath)
	return toFileOperationResult(meta, true), nil
}
// CreateFile creates the remote file filePath by opening it with Create and
// closing it immediately, then returns a result describing it.
func (s *Service) CreateFile(connID string, filePath string) (*filesystem.FileOperationResult, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	touch := func(sc *sftpclient.Client) error {
		handle, openErr := sc.Create(filePath)
		if openErr != nil {
			return openErr
		}
		return handle.Close()
	}
	if touchErr := client.WithRetry(touch); touchErr != nil {
		return nil, fmt.Errorf("创建文件失败: %w", touchErr)
	}

	// The metadata lookup is best-effort: its error is ignored and a nil map
	// simply yields zero-valued result fields.
	meta, _ := s.GetFileInfo(connID, filePath)
	return toFileOperationResult(meta, false), nil
}
// DeletePath removes the remote path filePath and returns a result describing
// what was deleted (Deleted is set to true).
//
// The path is inspected with Lstat, not Stat, so a symbolic link is removed as
// a link: the directory it may point to is never recursively deleted, and a
// dangling link can still be removed. Real directories are removed
// recursively. The result's IsDir reflects what was actually deleted.
func (s *Service) DeletePath(connID string, filePath string) (*filesystem.FileOperationResult, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	// Best-effort metadata snapshot taken before the entry disappears.
	infoMap, _ := s.GetFileInfo(connID, filePath)

	var isDir bool
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Lstat(filePath)
		if e != nil {
			return e
		}
		isDir = fi.IsDir()
		if isDir {
			// Recursively delete the directory.
			return sc.RemoveAll(filePath)
		}
		return sc.Remove(filePath)
	})
	if err != nil {
		return nil, fmt.Errorf("删除失败: %w", err)
	}

	result := toFileOperationResult(infoMap, isDir)
	if result.Path == "" {
		// The snapshot was unavailable (e.g. a dangling symlink); still tell
		// the caller which path was removed.
		result.Path = toUnixPath(filePath)
		result.Name = path.Base(result.Path)
	}
	result.Deleted = true
	return result, nil
}
// RenamePath renames the remote path oldPath to newPath and returns a result
// describing the entry at its new location, with OldPath set to oldPath.
//
// IsDir is taken from the metadata of the renamed entry, so renamed
// directories are reported as directories.
func (s *Service) RenamePath(connID string, oldPath, newPath string) (*filesystem.FileOperationResult, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		return sc.Rename(oldPath, newPath)
	})
	if err != nil {
		return nil, fmt.Errorf("重命名失败: %w", err)
	}

	// Best-effort metadata lookup; a nil map reads as zero values below.
	infoMap, _ := s.GetFileInfo(connID, newPath)
	isDir, _ := infoMap["is_dir"].(bool)
	result := toFileOperationResult(infoMap, isDir)
	if result.Path == "" {
		// The lookup failed after a successful rename; still report the new path.
		result.Path = toUnixPath(newPath)
		result.Name = path.Base(result.Path)
	}
	result.OldPath = oldPath
	return result, nil
}
// DownloadToTemp downloads a remote file into the local temp directory through
// storage.DownloadToTempCached.
//
// The file's size and modification time are looked up first and passed along
// for the cache lookup; if the lookup fails, zero values are passed instead.
func (s *Service) DownloadToTemp(connID string, remotePath string) (string, error) {
	var (
		size  int64
		mtime string
	)
	if meta, metaErr := s.GetFileInfo(connID, remotePath); metaErr == nil {
		// A failed type assertion leaves the zero value in place.
		size, _ = meta["size"].(int64)
		mtime, _ = meta["mod_time"].(string)
	}

	fetch := func() (string, error) {
		return s.downloadToTempDirect(connID, remotePath)
	}
	return storage.DownloadToTempCached("sftp", connID, remotePath, size, mtime, fetch)
}
// downloadToTempDirect performs the actual download (no caching) of remotePath
// into a newly created file in the OS temp directory and returns its local path.
//
// Files larger than 50 MiB are rejected. The error from closing the local file
// is checked, so a copy whose final flush failed is removed rather than being
// returned as a valid download.
func (s *Service) downloadToTempDirect(connID string, remotePath string) (string, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return "", err
	}

	// Preview size cap of 50MB, more lenient than the edit-mode limit.
	const maxPreviewSize int64 = 50 << 20
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Stat(remotePath)
		if e != nil {
			return e
		}
		if fi.Size() > maxPreviewSize {
			return fmt.Errorf("预览文件过大: %s", filesystem.FormatBytes(fi.Size()))
		}
		return nil
	})
	if err != nil {
		return "", err
	}

	tmpDir := os.TempDir()
	tmpFile, e := os.CreateTemp(tmpDir, "udesk-sftp-*-"+filepath.Base(remotePath))
	if e != nil {
		return "", fmt.Errorf("创建临时文件失败: %w", e)
	}
	// Only the reserved name is needed; the file is reopened for each attempt.
	localPath := tmpFile.Name()
	tmpFile.Close()

	err = c.WithRetry(func(sc *sftpclient.Client) error {
		src, e := sc.Open(remotePath)
		if e != nil {
			return e
		}
		defer src.Close()
		// os.Create truncates, so a retry starts from an empty file.
		dst, e := os.Create(localPath)
		if e != nil {
			return e
		}
		if _, e = io.Copy(dst, src); e != nil {
			dst.Close()
			return e
		}
		return dst.Close()
	})
	if err != nil {
		os.Remove(localPath)
		return "", fmt.Errorf("下载文件失败: %w", err)
	}
	return localPath, nil
}
// DownloadToTempCached is the cached SFTP download for callers that already
// know the file's size and modification time; both are forwarded to
// storage.DownloadToTempCached together with the direct-download callback.
func (s *Service) DownloadToTempCached(connID, remotePath string, fileSize int64, modTime string) (string, error) {
	fetch := func() (string, error) {
		return s.downloadToTempDirect(connID, remotePath)
	}
	return storage.DownloadToTempCached("sftp", connID, remotePath, fileSize, modTime, fetch)
}
// DownloadSiteForPreview downloads an HTML file and the site resources it
// references into a new local temp directory and returns the local path of the
// HTML file.
//
// Resource references come from the remote HTML and are treated as untrusted:
// any reference (or remote path) whose local destination would fall outside
// the temp directory, for example through "../" segments, is rejected, so
// remote content can never be written elsewhere on the local disk.
//
// Resource downloads are best-effort; only a failure to create the temp
// directory or to download the HTML itself is returned as an error.
func (s *Service) DownloadSiteForPreview(connID string, remotePath string) (string, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return "", err
	}

	// 1. Create the temp directory.
	tmpDir, err := os.MkdirTemp("", "udesk-sftp-site-*")
	if err != nil {
		return "", fmt.Errorf("创建临时目录失败: %w", err)
	}

	// 2. Determine the remote directory of the HTML file and mirror it locally.
	keyDir := path.Dir(remotePath)
	if keyDir == "." {
		keyDir = ""
	}
	// filepath.Join ignores empty elements, so an empty keyDir places the HTML
	// directly inside tmpDir.
	htmlLocalPath := filepath.Join(tmpDir, filepath.FromSlash(keyDir), path.Base(remotePath))
	if !sftpLocalPathWithin(tmpDir, htmlLocalPath) {
		os.RemoveAll(tmpDir)
		return "", fmt.Errorf("非法的远程路径: %s", remotePath)
	}
	if keyDir != "" {
		if err := os.MkdirAll(filepath.Dir(htmlLocalPath), 0755); err != nil {
			os.RemoveAll(tmpDir)
			return "", fmt.Errorf("创建目录失败: %w", err)
		}
	}

	// 3. Download the HTML.
	if err := s.sftpDownloadFile(c, remotePath, htmlLocalPath); err != nil {
		os.RemoveAll(tmpDir)
		return "", fmt.Errorf("下载 HTML 失败: %w", err)
	}

	// 4. Parse the HTML for resource references. If the local copy cannot be
	// read back, the HTML alone is still returned.
	htmlContent, err := os.ReadFile(htmlLocalPath)
	if err != nil {
		return htmlLocalPath, nil
	}
	resources := sftpExtractResources(string(htmlContent))

	// 5. Download statically referenced resources (sniffing the site root for
	// absolute references).
	htmlRemoteDir := keyDir
	if htmlRemoteDir == "/" {
		htmlRemoteDir = ""
	}
	htmlLocalDir := filepath.Dir(htmlLocalPath)
	var siteRoot string
	var discoveredDirs []string
	seenDir := make(map[string]bool)
	recordDir := func(remoteKey string) {
		dir := path.Dir(remoteKey)
		if !seenDir[dir] {
			seenDir[dir] = true
			discoveredDirs = append(discoveredDirs, dir)
		}
	}
	for _, resPath := range resources {
		if sftpShouldSkip(resPath) {
			continue
		}
		isAbsolute := strings.HasPrefix(resPath, "/")
		cleanPath := strings.TrimPrefix(resPath, "/")
		cleanPath = strings.TrimPrefix(cleanPath, "./")
		if cleanPath == "" {
			continue
		}
		localPath := filepath.Join(htmlLocalDir, filepath.FromSlash(cleanPath))
		if !sftpLocalPathWithin(tmpDir, localPath) {
			// The reference would escape the preview directory; never write there.
			continue
		}
		if isAbsolute {
			resolvedKey := sftpResolveAndDownload(s, c, htmlRemoteDir, cleanPath, localPath, &siteRoot)
			if resolvedKey != "" {
				recordDir(resolvedKey)
			}
		} else {
			remoteKey := path.Join(htmlRemoteDir, cleanPath)
			if s.sftpTryDownload(c, remoteKey, localPath) {
				recordDir(remoteKey)
			}
		}
	}

	// 6. Fetch the remaining files of the discovered directories (covers
	// dynamically loaded chunks and similar).
	for _, dir := range discoveredDirs {
		sftpSupplementDir(s, c, dir, tmpDir, siteRoot)
	}
	return htmlLocalPath, nil
}

// sftpLocalPathWithin reports whether target, once cleaned, is base itself or
// a path below base.
func sftpLocalPathWithin(base, target string) bool {
	rel, err := filepath.Rel(base, target)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// sftpDownloadFile downloads the single remote file remotePath to localPath,
// creating the local parent directories as needed.
//
// The error from closing the local file is returned, so a copy whose final
// flush failed is not reported as success.
func (s *Service) sftpDownloadFile(c *Client, remotePath, localPath string) error {
	if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
		return err
	}
	return c.WithRetry(func(sc *sftpclient.Client) error {
		src, err := sc.Open(remotePath)
		if err != nil {
			return err
		}
		defer src.Close()
		// os.Create truncates, so a retry starts from an empty file.
		dst, err := os.Create(localPath)
		if err != nil {
			return err
		}
		if _, err = io.Copy(dst, src); err != nil {
			dst.Close()
			return err
		}
		return dst.Close()
	})
}
// sftpTryDownload attempts a download and reports whether it succeeded.
// Failures are silent; whatever was written to localPath is removed.
func (s *Service) sftpTryDownload(c *Client, remotePath, localPath string) bool {
	if err := s.sftpDownloadFile(c, remotePath, localPath); err == nil {
		return true
	}
	os.Remove(localPath)
	return false
}
// sftpResolveAndDownload downloads a resource referenced by an absolute path,
// sniffing the remote site root when it is not known yet.
//
// If *siteRoot is already set, the resource is fetched from
// *siteRoot+cleanPath and that key is returned whether or not the download
// succeeded. Otherwise the function walks from htmlDir up through its parents
// (ending with the empty directory), trying each as the root. The first
// directory under which the resource downloads becomes *siteRoot, stored with
// exactly one trailing slash, and the matching remote key is returned. An
// empty string is returned when no candidate worked.
func sftpResolveAndDownload(s *Service, c *Client, htmlDir string, cleanPath string, localPath string, siteRoot *string) string {
	if *siteRoot != "" {
		s.sftpTryDownload(c, *siteRoot+cleanPath, localPath)
		return *siteRoot + cleanPath
	}
	dir := htmlDir
	for {
		candidate := path.Join(dir, cleanPath)
		if s.sftpTryDownload(c, candidate, localPath) {
			switch {
			case dir == "":
				*siteRoot = ""
			case strings.HasSuffix(dir, "/"):
				// dir is "/" (or already slash-terminated): appending another
				// slash would produce "//".
				*siteRoot = dir
			default:
				*siteRoot = dir + "/"
			}
			return candidate
		}
		if dir == "" {
			break
		}
		parent := path.Dir(dir)
		if parent == dir || parent == "." {
			// Top reached; make one last attempt with the empty directory.
			dir = ""
		} else {
			dir = parent
		}
	}
	return ""
}
// sftpSupplementDir downloads the files of remoteDir that are not present
// locally yet (it is only called for directories known to hold resources).
//
// Subdirectories and empty files are skipped, and a failure to list the
// directory is silently ignored. Entry names come from the remote server, so
// any entry whose local destination would fall outside tmpDir is skipped
// rather than written.
func sftpSupplementDir(s *Service, c *Client, remoteDir string, tmpDir string, siteRoot string) {
	var entries []fs.FileInfo
	err := c.WithRetry(func(sc *sftpclient.Client) error {
		var e error
		entries, e = sc.ReadDir(remoteDir)
		return e
	})
	if err != nil {
		return
	}
	for _, entry := range entries {
		if entry.IsDir() || entry.Size() == 0 {
			continue
		}
		fullPath := path.Join(remoteDir, entry.Name())
		relPath := strings.TrimPrefix(fullPath, siteRoot)
		relPath = strings.TrimPrefix(relPath, "/")
		localPath := filepath.Join(tmpDir, filepath.FromSlash(relPath))

		// Refuse destinations outside the preview directory.
		rel, relErr := filepath.Rel(tmpDir, localPath)
		if relErr != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			continue
		}
		// Already downloaded by an earlier step.
		if _, err := os.Stat(localPath); err == nil {
			continue
		}
		s.sftpTryDownload(c, fullPath, localPath)
	}
}
// GetCommonPaths returns common paths on the SFTP remote host: "home", "tmp"
// and "root".
//
// The home directory is derived from the connection's configured username:
// "/root" for root, for an empty name, or when no client exists for connID;
// "/home/<username>" otherwise. The returned error is always nil.
func (s *Service) GetCommonPaths(connID string) (map[string]string, error) {
	username := "root"
	if client := s.manager.GetClient(connID); client != nil {
		client.mu.Lock()
		username = client.config.Username
		client.mu.Unlock()
	}

	var home string
	switch username {
	case "", "root":
		home = "/root"
	default:
		home = fmt.Sprintf("/home/%s", username)
	}

	paths := map[string]string{
		"root": "/",
		"tmp":  "/tmp",
		"home": home,
	}
	return paths, nil
}
// CleanupTempFiles cleans up leftover temporary preview files. The work is
// delegated to storage.CleanupExpiredCache (cleanup has been taken over by the
// SQLite cache).
func CleanupTempFiles() {
storage.CleanupExpiredCache()
}
// GetSystemInfo collects remote system information (disk / CPU / memory) by
// running shell commands over the connection.
//
// The three commands run concurrently and share a single deadline of
// cmdTimeout, so the call waits at most that long overall. The returned map
// may contain "cpu_usage", "disk_total", "disk_used", "disk_usage",
// "mem_total", "mem_used" and "mem_usage"; keys whose source could not be
// obtained or parsed are omitted. An error is returned only when something
// failed and no value at all could be collected.
func (s *Service) GetSystemInfo(connID string) (map[string]interface{}, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	type cmdResult struct {
		key    string // "df" | "mem" | "cpu"
		output string
		err    error
	}

	// Buffered to the number of commands, so a goroutine that finishes after
	// the deadline can still send its result and exit instead of blocking.
	results := make(chan cmdResult, 3)
	cmdTimeout := 6 * time.Second
	runCmd := func(key, cmd string) {
		out, e := c.RunCommand(cmd)
		results <- cmdResult{key, out, e}
	}
	go runCmd("df", "df -B1 / 2>/dev/null | tail -1")
	go runCmd("mem", "free -b 2>/dev/null | grep -E '^Mem:' || cat /proc/meminfo 2>/dev/null | head -2")
	go runCmd("cpu", "top -bn1 | grep 'Cpu(s)' | awk '{print $2}' | cut -d'%' -f1")

	// One timer for the whole collection: a fresh timeout per loop iteration
	// would allow up to three times cmdTimeout in total.
	timer := time.NewTimer(cmdTimeout)
	defer timer.Stop()

	var dfOut, memOut, cpuOut string
	hasErr := false
collect:
	for i := 0; i < 3; i++ {
		select {
		case r := <-results:
			switch r.key {
			case "df":
				dfOut = r.output
			case "mem":
				memOut = r.output
			case "cpu":
				cpuOut = r.output
			}
			if r.err != nil {
				hasErr = true
			}
		case <-timer.C:
			hasErr = true
			break collect
		}
	}

	info := make(map[string]interface{})

	// CPU usage: a single floating-point percentage.
	cpuOut = strings.TrimSpace(cpuOut)
	if usage, err := strconv.ParseFloat(cpuOut, 64); err == nil && usage >= 0 {
		info["cpu_usage"] = fmt.Sprintf("%.0f%%", usage)
	}

	// Disk: last line of `df -B1 /` → Filesystem 1B-blocks Used Available Use% Mounted on.
	dfOut = strings.TrimSpace(dfOut)
	if dfFields := strings.Fields(dfOut); len(dfFields) >= 5 {
		var diskTotal, diskUsed uint64
		if v, err := strconv.ParseUint(dfFields[1], 10, 64); err == nil {
			diskTotal = v
			info["disk_total"] = v
		}
		if v, err := strconv.ParseUint(dfFields[2], 10, 64); err == nil {
			diskUsed = v
			info["disk_used"] = v
		}
		if diskTotal > 0 {
			info["disk_usage"] = fmt.Sprintf("%.0f%%", float64(diskUsed)/float64(diskTotal)*100)
		}
	}

	// Memory: either the "Mem:" line of `free -b` or the head of /proc/meminfo.
	memOut = strings.TrimSpace(memOut)
	if strings.Contains(memOut, "MemTotal:") {
		parseProcMeminfo(memOut, info)
	} else if fields := strings.Fields(memOut); len(fields) >= 3 {
		var memTotal, memUsed uint64
		if v, err := strconv.ParseUint(fields[1], 10, 64); err == nil {
			memTotal = v
			info["mem_total"] = v
		}
		if v, err := strconv.ParseUint(fields[2], 10, 64); err == nil {
			memUsed = v
			info["mem_used"] = v
		}
		if memTotal > 0 {
			info["mem_usage"] = fmt.Sprintf("%.0f%%", float64(memUsed)/float64(memTotal)*100)
		}
	}

	if hasErr && len(info) == 0 {
		return nil, fmt.Errorf("采集远程系统信息失败")
	}
	return info, nil
}
// parseProcMeminfo parses /proc/meminfo-style output ("Key: value kB" lines)
// and stores "mem_total" and "mem_used" (both in bytes) in info, plus
// "mem_usage" (a rounded percentage string) when the total is non-zero.
//
// Free memory is MemAvailable when present, otherwise
// MemFree + Buffers + Cached. If the free amount exceeds the total (missing or
// inconsistent input), used memory is clamped to 0 instead of wrapping around
// the unsigned range.
func parseProcMeminfo(output string, info map[string]interface{}) {
	memMap := make(map[string]uint64)
	for _, line := range strings.Split(output, "\n") {
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		val, err := strconv.ParseUint(fields[1], 10, 64)
		if err != nil {
			continue
		}
		memMap[strings.TrimSuffix(fields[0], ":")] = val
	}

	total := memMap["MemTotal"] * 1024 // kB → bytes

	// Available memory ≈ MemAvailable (newer kernels) or MemFree + Buffers + Cached.
	free := memMap["MemFree"]
	if avail, ok := memMap["MemAvailable"]; ok {
		free = avail
	} else {
		free += memMap["Buffers"] + memMap["Cached"]
	}
	freeBytes := free * 1024

	// Guard the unsigned subtraction against underflow.
	var used uint64
	if freeBytes < total {
		used = total - freeBytes
	}

	info["mem_total"] = total
	info["mem_used"] = used
	if total > 0 {
		info["mem_usage"] = fmt.Sprintf("%.0f%%", float64(used)/float64(total)*100)
	}
}
// --- 内部辅助 ---
// getClient looks up the client registered for connID and returns an error
// when the manager has none.
func (s *Service) getClient(connID string) (*Client, error) {
	if client := s.manager.GetClient(connID); client != nil {
		return client, nil
	}
	return nil, fmt.Errorf("SFTP 连接不存在: %s", connID)
}
// toUnixPath converts every backslash in p to a forward slash.
func toUnixPath(p string) string {
	const (
		windowsSep = `\`
		unixSep    = "/"
	)
	if !strings.Contains(p, windowsSep) {
		return p
	}
	return strings.Join(strings.Split(p, windowsSep), unixSep)
}
// toFileOperationResult converts a metadata map (keys "name", "path", "size",
// "mod_time", "mode") into a FileOperationResult. Missing or wrongly typed
// entries become zero values, so a nil map is accepted. IsDir is taken from
// the isDir argument, not from the map.
func toFileOperationResult(m map[string]interface{}, isDir bool) *filesystem.FileOperationResult {
	text := func(key string) string {
		v, _ := m[key].(string)
		return v
	}

	var byteCount int64
	if v, ok := m["size"].(int64); ok {
		byteCount = v
	}

	res := new(filesystem.FileOperationResult)
	res.Path = text("path")
	res.Name = text("name")
	res.Size = byteCount
	res.SizeStr = filesystem.FormatBytes(byteCount)
	res.IsDir = isDir
	res.ModTime = text("mod_time")
	res.Mode = text("mode")
	return res
}
// sftpExtractResources extracts resource paths from HTML content: first the
// attribute values matched by sftpResRegex, then the CSS url(...) targets
// matched by sftpCssUrlRe. Values are whitespace-trimmed, empty values are
// dropped and duplicates are removed, keeping first-seen order.
func sftpExtractResources(html string) []string {
	var (
		ordered []string
		known   = make(map[string]struct{})
	)
	collect := func(matches [][]string) {
		for _, groups := range matches {
			if len(groups) < 2 {
				continue
			}
			ref := strings.TrimSpace(groups[1])
			if ref == "" {
				continue
			}
			if _, dup := known[ref]; dup {
				continue
			}
			known[ref] = struct{}{}
			ordered = append(ordered, ref)
		}
	}
	collect(sftpResRegex.FindAllStringSubmatch(html, -1))
	collect(sftpCssUrlRe.FindAllStringSubmatch(html, -1))
	return ordered
}
// sftpShouldSkip reports whether a resource path should be skipped because it
// starts with one of the prefixes listed below (inline data, external or
// protocol-relative URLs, fragments and non-file schemes).
func sftpShouldSkip(p string) bool {
	skipPrefixes := [...]string{
		"data:",
		"http://",
		"https://",
		"//",
		"#",
		"javascript:",
		"mailto:",
		"blob:",
	}
	for _, prefix := range skipPrefixes {
		if strings.HasPrefix(p, prefix) {
			return true
		}
	}
	return false
}