Private
Public Access
1
0
Files
u-desk/internal/sftp/service.go
绝尘 6bee55b96f 新增:SFTP直连+连接池+autoConnect+文件服务器端口自动回退
- SFTP模块:连接/断开/文件CRUD/系统信息采集/base64二进制写入
- 连接池:多服务器同时在线,瞬间切换profile
- autoConnect:启动时自动连接所有非本地服务器
- 端口自动回退:listenWithFallback消除TOCTOU,解决端口冲突崩溃
- 文件服务器URL集中管理:file-server.ts消除8+处硬编码端口
- Sidebar设置面板:添加服务器/自动连接/自动刷新开关
- 修复:validateFilePath越界panic、正则预编译
- 修复:注释准确性(RemoveAll/端口8073/动态端口文档)
2026-05-04 15:33:19 +08:00

500 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sftp
import (
"encoding/base64"
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"u-desk/internal/filesystem"
sftpclient "github.com/pkg/sftp"
)
// Service exposes SFTP-backed file operations. Every method looks up the
// remote client for a connection ID through the wrapped Manager.
type Service struct {
	manager *Manager
}

// NewService builds a Service bound to the Manager returned by the
// package-level GetManager function (presumably a shared instance — not
// visible in this file).
func NewService() *Service {
	svc := new(Service)
	svc.manager = GetManager()
	return svc
}

// GetManager exposes the underlying connection Manager so the App layer can
// drive it directly.
func (s *Service) GetManager() *Manager {
	return s.manager
}
// ConnID derives the connection identifier for a server profile in the form
// "host:port". Note that IPv6 hosts are not bracketed.
func ConnID(host string, port int) string {
	return fmt.Sprint(host, ":", port)
}
// --- 核心文件操作 ---
// ListDir lists the entries of a remote directory.
//
// Each entry is a map with keys: name, path (dirPath joined with the entry
// name, slash-separated), is_dir, size (int64 bytes) and mod_time formatted
// as "2006-01-02 15:04:05". An empty directory yields an empty, non-nil slice.
func (s *Service) ListDir(connID string, dirPath string) ([]map[string]interface{}, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	var infos []fs.FileInfo
	readErr := client.WithRetry(func(sc *sftpclient.Client) error {
		list, e := sc.ReadDir(dirPath)
		infos = list
		return e
	})
	if readErr != nil {
		return nil, fmt.Errorf("读取目录失败: %w", readErr)
	}

	out := make([]map[string]interface{}, len(infos))
	for i, fi := range infos {
		name := fi.Name()
		out[i] = map[string]interface{}{
			"name":     name,
			"path":     toUnixPath(path.Join(dirPath, name)),
			"is_dir":   fi.IsDir(),
			"size":     fi.Size(),
			"mod_time": fi.ModTime().Format("2006-01-02 15:04:05"),
		}
	}
	return out, nil
}
// ReadFile reads a remote file and returns its contents as a string.
//
// Files larger than 10 MiB are rejected, matching the local-mode ReadFile
// limit. The size is checked up front via Stat and enforced again while
// reading, so a file that grows between the Stat and the read can never push
// more than maxSize+1 bytes into memory.
func (s *Service) ReadFile(connID string, filePath string) (string, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return "", err
	}
	// Size cap, aligned with the local-mode ReadFile 10MB limit.
	const maxSize int64 = 10 << 20
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Stat(filePath)
		if e != nil {
			return e
		}
		if fi.Size() > maxSize {
			return fmt.Errorf("文件过大 (%s),超过 %d 限制", filesystem.FormatBytes(fi.Size()), maxSize)
		}
		return nil
	})
	if err != nil {
		return "", err
	}
	var data []byte
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.Open(filePath)
		if e != nil {
			return e
		}
		// Read-only handle: a Close error cannot invalidate data already read.
		defer f.Close()
		// Read at most maxSize+1 bytes. The extra byte reveals growth past the
		// limit without buffering the rest of the file.
		data, e = io.ReadAll(io.LimitReader(f, maxSize+1))
		return e
	})
	if err != nil {
		return "", fmt.Errorf("读取文件失败: %w", err)
	}
	if int64(len(data)) > maxSize {
		return "", fmt.Errorf("文件过大,超过 %d 限制", maxSize)
	}
	return string(data), nil
}
// WriteFile writes content to a remote file. It creates the file, or truncates
// any existing contents (sc.Create semantics).
//
// The error from Close is returned rather than dropped. For a remote file, a
// failed close may mean the data was not fully written.
func (s *Service) WriteFile(connID string, filePath string, content string) error {
	c, err := s.getClient(connID)
	if err != nil {
		return err
	}
	return c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.Create(filePath)
		if e != nil {
			return fmt.Errorf("创建文件失败: %w", e)
		}
		if _, e = f.Write([]byte(content)); e != nil {
			// The write error is the one worth reporting. A Close failure on
			// this path would only describe the same broken transfer.
			_ = f.Close()
			return e
		}
		return f.Close()
	})
}
// WriteBase64File decodes base64Content (standard, padded alphabet) and writes
// the raw bytes to a remote file, creating or truncating it. It is used for
// binary payloads such as pasted images.
//
// Malformed base64 is rejected before the remote file is opened, so bad input
// never clobbers an existing file. The Close error is returned because, for a
// remote file, a failed close may mean the data was not fully written.
func (s *Service) WriteBase64File(connID string, filePath string, base64Content string) error {
	c, err := s.getClient(connID)
	if err != nil {
		return err
	}
	data, err := base64.StdEncoding.DecodeString(base64Content)
	if err != nil {
		return fmt.Errorf("base64 解码失败: %w", err)
	}
	return c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.Create(filePath)
		if e != nil {
			return fmt.Errorf("创建文件失败: %w", e)
		}
		if _, e = f.Write(data); e != nil {
			// Report the write failure. A Close error here would only repeat it.
			_ = f.Close()
			return e
		}
		return f.Close()
	})
}
// GetFileInfo stats a remote path and returns its metadata as a map.
//
// The map has these keys:
//   - name
//   - path (slash-separated)
//   - size (int64 bytes)
//   - size_str (human-readable)
//   - is_dir
//   - mod_time ("2006-01-02 15:04:05")
//   - mode (FileMode string form)
func (s *Service) GetFileInfo(connID string, filePath string) (map[string]interface{}, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	var fi fs.FileInfo
	statErr := client.WithRetry(func(sc *sftpclient.Client) (e error) {
		fi, e = sc.Stat(filePath)
		return e
	})
	if statErr != nil {
		return nil, fmt.Errorf("获取文件信息失败: %w", statErr)
	}

	size := fi.Size()
	return map[string]interface{}{
		"name":     fi.Name(),
		"path":     toUnixPath(filePath),
		"size":     size,
		"size_str": filesystem.FormatBytes(size),
		"is_dir":   fi.IsDir(),
		"mod_time": fi.ModTime().Format("2006-01-02 15:04:05"),
		"mode":     fi.Mode().String(),
	}, nil
}
// CreateDir creates a remote directory, including any missing parents
// (MkdirAll), and describes the result.
//
// The follow-up stat is best-effort. If it fails, the result is built from a
// nil map and its metadata fields stay at their zero values.
func (s *Service) CreateDir(connID string, dirPath string) (*filesystem.FileOperationResult, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	mkdir := func(sc *sftpclient.Client) error { return sc.MkdirAll(dirPath) }
	if mkErr := client.WithRetry(mkdir); mkErr != nil {
		return nil, fmt.Errorf("创建目录失败: %w", mkErr)
	}

	meta, _ := s.GetFileInfo(connID, dirPath)
	return toFileOperationResult(meta, true), nil
}
// CreateFile ensures an (initially empty) remote file exists at filePath and
// describes it.
//
// The file is opened with O_WRONLY|O_CREATE and without O_TRUNC. Calling this
// on an existing file therefore leaves its contents intact, whereas sc.Create
// would have truncated it. This also makes the operation idempotent if
// WithRetry re-runs the closure.
//
// The follow-up stat is best-effort. On failure, the result's metadata fields
// stay at their zero values.
func (s *Service) CreateFile(connID string, filePath string) (*filesystem.FileOperationResult, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.OpenFile(filePath, os.O_WRONLY|os.O_CREATE)
		if e != nil {
			return e
		}
		return f.Close()
	})
	if err != nil {
		return nil, fmt.Errorf("创建文件失败: %w", err)
	}
	infoMap, _ := s.GetFileInfo(connID, filePath)
	return toFileOperationResult(infoMap, false), nil
}
// DeletePath removes a remote file, or a directory together with all of its
// contents.
//
// Metadata is captured before deletion so the result can describe the removed
// entry; a failed pre-stat is ignored. Deleted is set on success.
func (s *Service) DeletePath(connID string, filePath string) (*filesystem.FileOperationResult, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}
	before, _ := s.GetFileInfo(connID, filePath)

	remove := func(sc *sftpclient.Client) error {
		fi, statErr := sc.Stat(filePath)
		switch {
		case statErr != nil:
			return statErr
		case fi.IsDir():
			return sc.RemoveAll(filePath)
		default:
			return sc.Remove(filePath)
		}
	}
	if err = client.WithRetry(remove); err != nil {
		return nil, fmt.Errorf("删除失败: %w", err)
	}

	res := toFileOperationResult(before, false)
	res.Deleted = true
	return res, nil
}
// RenamePath renames (moves) oldPath to newPath on the remote host. It
// describes the entry at its new location and records the old path in OldPath.
// The post-rename stat is best-effort.
func (s *Service) RenamePath(connID string, oldPath, newPath string) (*filesystem.FileOperationResult, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	if err = client.WithRetry(func(sc *sftpclient.Client) error {
		return sc.Rename(oldPath, newPath)
	}); err != nil {
		return nil, fmt.Errorf("重命名失败: %w", err)
	}

	after, _ := s.GetFileInfo(connID, newPath)
	res := toFileOperationResult(after, false)
	res.OldPath = oldPath
	return res, nil
}
// DownloadToTemp copies a remote file into the local temp directory for
// preview and returns the local path.
//
// Files larger than 50 MiB are rejected, a looser cap than ReadFile's editor
// limit. The local file is named "udesk-sftp-preview-<random>-<base name>":
//   - the prefix lets CleanupTempFiles find it later;
//   - os.CreateTemp guarantees a unique, exclusively created name;
//   - the remote base name (and so its extension) is kept for previewers.
//
// If the copy fails, the partial local file is removed.
func (s *Service) DownloadToTemp(connID string, remotePath string) (string, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return "", err
	}
	// Preview size cap: 50MB, looser than the editing-mode limit.
	const maxPreviewSize int64 = 50 << 20
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Stat(remotePath)
		if e != nil {
			return e
		}
		if fi.Size() > maxPreviewSize {
			return fmt.Errorf("预览文件过大: %s", filesystem.FormatBytes(fi.Size()))
		}
		return nil
	})
	if err != nil {
		return "", err
	}

	// os.CreateTemp rejects patterns containing path separators, and treats
	// the last '*' as the random placeholder. Neutralise both in the base name.
	base := strings.NewReplacer("/", "_", `\`, "_", "*", "_").Replace(filepath.Base(remotePath))
	dst, err := os.CreateTemp(os.TempDir(), "udesk-sftp-preview-*-"+base)
	if err != nil {
		return "", fmt.Errorf("下载文件失败: %w", err)
	}
	localPath := dst.Name()

	err = c.WithRetry(func(sc *sftpclient.Client) error {
		src, e := sc.Open(remotePath)
		if e != nil {
			return e
		}
		defer src.Close()
		// A retry must not append to bytes left by a previous failed attempt.
		if e = dst.Truncate(0); e != nil {
			return e
		}
		if _, e = dst.Seek(0, io.SeekStart); e != nil {
			return e
		}
		_, e = io.Copy(dst, src)
		return e
	})
	closeErr := dst.Close()
	if err == nil {
		err = closeErr
	}
	if err != nil {
		// Best-effort cleanup of the partial copy; the copy error is what matters.
		_ = os.Remove(localPath)
		return "", fmt.Errorf("下载文件失败: %w", err)
	}
	return localPath, nil
}
// GetCommonPaths returns commonly used remote paths under the keys "home",
// "tmp" and "root". It never returns an error.
//
// For "home", it first asks the SFTP server for its working directory
// (sc.Getwd). Servers typically start a session in the login user's home.
// If there is no client or that call fails or returns "", it falls back to a
// heuristic: "/root" for root or an unknown user, "/home/<username>" otherwise.
func (s *Service) GetCommonPaths(connID string) (map[string]string, error) {
	c := s.manager.GetClient(connID)
	username := "root"
	home := ""
	if c != nil {
		c.mu.Lock()
		username = c.config.Username
		c.mu.Unlock()
		// Best-effort: any failure simply falls through to the heuristic below.
		_ = c.WithRetry(func(sc *sftpclient.Client) error {
			wd, e := sc.Getwd()
			if e != nil {
				return e
			}
			home = wd
			return nil
		})
	}
	if home == "" {
		home = "/root"
		if username != "root" && username != "" {
			home = fmt.Sprintf("/home/%s", username)
		}
	}
	return map[string]string{
		"home": home,
		"tmp":  "/tmp",
		"root": "/",
	}, nil
}
// CleanupTempFiles deletes leftover preview files from the OS temp directory,
// identified by the "udesk-sftp-preview-" name prefix.
//
// It is best-effort: an unreadable temp directory or a failed removal is
// silently ignored.
func CleanupTempFiles() {
	const previewPrefix = "udesk-sftp-preview-"
	dir := os.TempDir()
	items, err := os.ReadDir(dir)
	if err != nil {
		return
	}
	for _, it := range items {
		if !strings.HasPrefix(it.Name(), previewPrefix) {
			continue
		}
		// Best-effort: entries that are in use or already gone are skipped.
		_ = os.Remove(filepath.Join(dir, it.Name()))
	}
}
// GetSystemInfo collects disk, memory and CPU usage from the remote host by
// running three shell commands through the client's RunCommand.
//
// The commands run concurrently and share ONE 6-second deadline for the whole
// collection. A timer per awaited result would let a hung host stall the call
// for up to 3x that. Output that arrived before the deadline is parsed. An
// error is returned only when some command failed or timed out AND nothing at
// all could be parsed.
//
// Possible keys:
//   - cpu_usage, disk_usage, mem_usage: "NN%" strings
//   - disk_total, disk_used, mem_total, mem_used: uint64 byte counts
//
// Keys whose source output was missing or unparsable are absent.
func (s *Service) GetSystemInfo(connID string) (map[string]interface{}, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}
	type cmdResult struct {
		key    string // "df" | "mem" | "cpu"
		output string
		err    error
	}
	// Buffered for all three senders, so a goroutine that finishes after the
	// deadline can still deliver its result and exit instead of blocking.
	// NOTE(review): RunCommand takes no context, so a hung command is not
	// cancelled here — its goroutine lives until RunCommand returns.
	results := make(chan cmdResult, 3)
	runCmd := func(key, cmd string) {
		out, e := c.RunCommand(cmd)
		results <- cmdResult{key, out, e}
	}
	go runCmd("df", "df -B1 / 2>/dev/null | tail -1")
	// The /proc/meminfo fallback greps exactly the fields parseProcMeminfo uses.
	// MemAvailable is not among the first two lines, so `head -2` would miss it.
	go runCmd("mem", "free -b 2>/dev/null | grep -E '^Mem:' || grep -E '^(MemTotal|MemFree|MemAvailable|Buffers|Cached):' /proc/meminfo 2>/dev/null")
	// NOTE(review): on procps-ng top, $2 of the Cpu(s) line is the user ("us")
	// share only, not total utilisation — confirm that is intended.
	go runCmd("cpu", "top -bn1 | grep 'Cpu(s)' | awk '{print $2}' | cut -d'%' -f1")

	deadline := time.NewTimer(6 * time.Second)
	defer deadline.Stop()

	var dfOut, memOut, cpuOut string
	hasErr := false
collect:
	for pending := 3; pending > 0; pending-- {
		select {
		case r := <-results:
			switch r.key {
			case "df":
				dfOut = r.output
			case "mem":
				memOut = r.output
			case "cpu":
				cpuOut = r.output
			}
			if r.err != nil {
				hasErr = true
			}
		case <-deadline.C:
			hasErr = true
			break collect
		}
	}

	info := make(map[string]interface{})
	sysInfoParseCPU(cpuOut, info)
	sysInfoParseDisk(dfOut, info)
	sysInfoParseMem(memOut, info)
	if hasErr && len(info) == 0 {
		return nil, fmt.Errorf("采集远程系统信息失败")
	}
	return info, nil
}

// sysInfoParseCPU stores cpu_usage ("NN%") when out is a single non-negative number.
func sysInfoParseCPU(out string, info map[string]interface{}) {
	usage, err := strconv.ParseFloat(strings.TrimSpace(out), 64)
	if err == nil && usage >= 0 {
		info["cpu_usage"] = fmt.Sprintf("%.0f%%", usage)
	}
}

// sysInfoParseDisk parses one `df -B1` data line, whose columns are
// "<fs> <total> <used> <avail> <use%> <mount>" with sizes in bytes.
//
// Columns are indexed from the end. When df wraps a long device name onto its
// own line, `tail -1` yields only the five trailing columns, and this still
// maps total/used correctly.
func sysInfoParseDisk(out string, info map[string]interface{}) {
	fields := strings.Fields(out)
	n := len(fields)
	if n < 5 {
		return
	}
	var total, used uint64
	if v, err := strconv.ParseUint(fields[n-5], 10, 64); err == nil {
		total = v
		info["disk_total"] = v
	}
	if v, err := strconv.ParseUint(fields[n-4], 10, 64); err == nil {
		used = v
		info["disk_used"] = v
	}
	if total > 0 {
		info["disk_usage"] = fmt.Sprintf("%.0f%%", float64(used)/float64(total)*100)
	}
}

// sysInfoParseMem handles both forms of the mem command's output. It accepts
// /proc/meminfo lines (delegated to parseProcMeminfo) or a `free -b` row of
// the form "Mem: <total> <used> ...", with sizes in bytes.
func sysInfoParseMem(out string, info map[string]interface{}) {
	out = strings.TrimSpace(out)
	if strings.Contains(out, "MemTotal:") {
		parseProcMeminfo(out, info)
		return
	}
	fields := strings.Fields(out)
	if len(fields) < 3 {
		return
	}
	var total, used uint64
	if v, err := strconv.ParseUint(fields[1], 10, 64); err == nil {
		total = v
		info["mem_total"] = v
	}
	if v, err := strconv.ParseUint(fields[2], 10, 64); err == nil {
		used = v
		info["mem_used"] = v
	}
	if total > 0 {
		info["mem_usage"] = fmt.Sprintf("%.0f%%", float64(used)/float64(total)*100)
	}
}
// parseProcMeminfo parses /proc/meminfo-style lines ("Key:   <value> kB").
// It stores mem_total and mem_used (uint64 bytes) and mem_usage ("NN%") in info.
//
// Available memory is MemAvailable when present (newer kernels). Otherwise it
// is approximated as MemFree + Buffers + Cached.
//
// Nothing is stored when MemTotal is missing or zero. Available memory is
// clamped to the total, so mem_used can never wrap around to a huge unsigned
// value.
func parseProcMeminfo(output string, info map[string]interface{}) {
	kb := make(map[string]uint64)
	for _, line := range strings.Split(output, "\n") {
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		if val, err := strconv.ParseUint(fields[1], 10, 64); err == nil {
			kb[strings.TrimSuffix(fields[0], ":")] = val
		}
	}

	totalKB := kb["MemTotal"]
	if totalKB == 0 {
		return
	}
	availKB, ok := kb["MemAvailable"]
	if !ok {
		availKB = kb["MemFree"] + kb["Buffers"] + kb["Cached"]
	}
	if availKB > totalKB {
		availKB = totalKB
	}

	total := totalKB * 1024 // kB -> bytes
	used := (totalKB - availKB) * 1024
	info["mem_total"] = total
	info["mem_used"] = used
	info["mem_usage"] = fmt.Sprintf("%.0f%%", float64(used)/float64(total)*100)
}
// --- 内部辅助 ---
// getClient resolves connID to the client registered with the manager. It
// returns an error naming the connection when none is registered.
func (s *Service) getClient(connID string) (*Client, error) {
	if c := s.manager.GetClient(connID); c != nil {
		return c, nil
	}
	return nil, fmt.Errorf("SFTP 连接不存在: %s", connID)
}
// unixPathReplacer rewrites every backslash as a forward slash. A Replacer is
// safe for concurrent use, so one shared instance suffices.
var unixPathReplacer = strings.NewReplacer(`\`, "/")

// toUnixPath normalises a path to forward slashes by replacing every
// backslash byte. No other cleaning is done.
func toUnixPath(p string) string {
	return unixPathReplacer.Replace(p)
}
// toFileOperationResult converts a GetFileInfo-style map into a
// FileOperationResult.
//
// Missing or mistyped keys yield zero values. This includes a nil map, which
// callers pass when their follow-up stat failed.
//
// isDir is a fallback. When the map carries is_dir=true (e.g. after renaming or
// deleting a directory), the result reports a directory even if the caller
// passed false.
func toFileOperationResult(m map[string]interface{}, isDir bool) *filesystem.FileOperationResult {
	name, _ := m["name"].(string)
	p, _ := m["path"].(string)
	size, _ := m["size"].(int64)
	modTime, _ := m["mod_time"].(string)
	mode, _ := m["mode"].(string)
	if statIsDir, _ := m["is_dir"].(bool); statIsDir {
		isDir = true
	}
	return &filesystem.FileOperationResult{
		Path:    p,
		Name:    name,
		Size:    size,
		SizeStr: filesystem.FormatBytes(size),
		IsDir:   isDir,
		ModTime: modTime,
		Mode:    mode,
	}
}