Private
Public Access
1
0
Files
u-desk/internal/sftp/service.go
绝尘 ee4b1f5ac1 修复:审查发现的高优先问题(竞态/初始化/碰撞)
- app.go: profileSvc移入App struct,用a.mu保护
- sqlite.go: InitFast加sync.Once防并发双重初始化
- client.go: Manager.Connect加sync.Mutex防竞态泄漏SSH
- service.go: 临时文件用os.CreateTemp防时间戳碰撞
- connection-manager: 密码缺失时不再塞入假WailsTransport
2026-05-04 15:40:04 +08:00

504 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sftp
import (
"encoding/base64"
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"u-desk/internal/filesystem"
sftpclient "github.com/pkg/sftp"
)
// Service exposes remote file operations over SFTP. File operations identify
// the target connection by a connID and resolve it through the manager.
type Service struct {
	// manager is the connection registry used to look up clients by connID.
	manager *Manager
}
// NewService builds a Service wired to the connection manager returned by the
// package-level GetManager function.
func NewService() *Service {
	svc := &Service{}
	svc.manager = GetManager()
	return svc
}
// GetManager exposes the Service's underlying connection Manager so the App
// layer can call it directly.
func (s *Service) GetManager() *Manager {
	m := s.manager
	return m
}
// ConnID builds the connection identifier "host:port" for a configured host
// and port.
//
// NOTE(review): an IPv6 literal is not bracketed ("::1" + 22 -> "::1:22");
// net.JoinHostPort would disambiguate but would also change existing IDs.
func ConnID(host string, port int) string {
	return host + ":" + strconv.Itoa(port)
}
// --- Core file operations ---

// ListDir lists the remote directory dirPath on connection connID.
//
// Each entry is a map with keys "name", "path" (dirPath joined with the entry
// name, backslashes converted to '/'), "is_dir", "size" and "mod_time"
// (layout "2006-01-02 15:04:05"). On success the returned slice is non-nil,
// even for an empty directory.
func (s *Service) ListDir(connID string, dirPath string) ([]map[string]interface{}, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	var infos []fs.FileInfo
	readDir := func(sc *sftpclient.Client) error {
		var readErr error
		infos, readErr = sc.ReadDir(dirPath)
		return readErr
	}
	if err = client.WithRetry(readDir); err != nil {
		return nil, fmt.Errorf("读取目录失败: %w", err)
	}

	listing := make([]map[string]interface{}, 0, len(infos))
	for i := range infos {
		fi := infos[i]
		name := fi.Name()
		listing = append(listing, map[string]interface{}{
			"name":     name,
			"path":     toUnixPath(path.Join(dirPath, name)),
			"is_dir":   fi.IsDir(),
			"size":     fi.Size(),
			"mod_time": fi.ModTime().Format("2006-01-02 15:04:05"),
		})
	}
	return listing, nil
}
// ReadFile returns the contents of the remote file filePath on connection
// connID as a string.
//
// Files larger than 10 MiB are rejected, matching the local-mode ReadFile
// limit. The size is checked up front with Stat and enforced again while
// reading, so a file that grows between the two round trips cannot push an
// unbounded amount of data into memory.
func (s *Service) ReadFile(connID string, filePath string) (string, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return "", err
	}

	// Size cap aligned with the local-mode ReadFile (10 MiB).
	const maxSize int64 = 10 << 20

	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Stat(filePath)
		if e != nil {
			return e
		}
		if fi.Size() > maxSize {
			return fmt.Errorf("文件过大 (%s),超过 %s 限制", filesystem.FormatBytes(fi.Size()), filesystem.FormatBytes(maxSize))
		}
		return nil
	})
	if err != nil {
		return "", err
	}

	var data []byte
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.Open(filePath)
		if e != nil {
			return e
		}
		defer f.Close()
		// Read at most one byte past the cap: enough to detect overflow
		// without buffering the whole file if it grew after the Stat.
		data, e = io.ReadAll(io.LimitReader(f, maxSize+1))
		if e != nil {
			return e
		}
		if int64(len(data)) > maxSize {
			return fmt.Errorf("文件过大,超过 %s 限制", filesystem.FormatBytes(maxSize))
		}
		return nil
	})
	if err != nil {
		return "", fmt.Errorf("读取文件失败: %w", err)
	}
	return string(data), nil
}
// WriteFile writes content to the remote file filePath on connection connID,
// creating the file or replacing its existing contents.
//
// The error from closing the remote file is returned instead of being
// discarded by a deferred Close. A failing close means the write is not known
// to have completed.
func (s *Service) WriteFile(connID string, filePath string, content string) error {
	c, err := s.getClient(connID)
	if err != nil {
		return err
	}
	return c.WithRetry(func(sc *sftpclient.Client) error {
		// sc.Create truncates, so a retried attempt rewrites from scratch.
		f, e := sc.Create(filePath)
		if e != nil {
			return fmt.Errorf("创建文件失败: %w", e)
		}
		if _, e = f.Write([]byte(content)); e != nil {
			_ = f.Close() // the write error is the one worth reporting
			return e
		}
		return f.Close()
	})
}
// WriteBase64File decodes base64Content (standard padded alphabet,
// base64.StdEncoding) and writes the bytes to the remote file filePath,
// creating it or replacing its contents. Used e.g. for pasting images.
//
// NOTE(review): a "data:...;base64," URL prefix is not stripped here and would
// make decoding fail — confirm callers send only the raw payload.
//
// As in WriteFile, the remote Close error is returned rather than dropped.
func (s *Service) WriteBase64File(connID string, filePath string, base64Content string) error {
	c, err := s.getClient(connID)
	if err != nil {
		return err
	}
	data, err := base64.StdEncoding.DecodeString(base64Content)
	if err != nil {
		return fmt.Errorf("base64 解码失败: %w", err)
	}
	return c.WithRetry(func(sc *sftpclient.Client) error {
		// sc.Create truncates, so a retried attempt rewrites from scratch.
		f, e := sc.Create(filePath)
		if e != nil {
			return fmt.Errorf("创建文件失败: %w", e)
		}
		if _, e = f.Write(data); e != nil {
			_ = f.Close() // the write error is the one worth reporting
			return e
		}
		return f.Close()
	})
}
// GetFileInfo stats the remote path filePath on connection connID.
//
// The returned map carries "name", "path" (backslashes converted to '/'),
// "size", "size_str" (human-readable size), "is_dir", "mod_time" (layout
// "2006-01-02 15:04:05") and "mode" (the FileMode string form).
func (s *Service) GetFileInfo(connID string, filePath string) (map[string]interface{}, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	var fi fs.FileInfo
	statErr := client.WithRetry(func(sc *sftpclient.Client) error {
		st, e := sc.Stat(filePath)
		fi = st
		return e
	})
	if statErr != nil {
		return nil, fmt.Errorf("获取文件信息失败: %w", statErr)
	}

	size := fi.Size()
	return map[string]interface{}{
		"name":     fi.Name(),
		"path":     toUnixPath(filePath),
		"size":     size,
		"size_str": filesystem.FormatBytes(size),
		"is_dir":   fi.IsDir(),
		"mod_time": fi.ModTime().Format("2006-01-02 15:04:05"),
		"mode":     fi.Mode().String(),
	}, nil
}
// CreateDir creates dirPath on the remote host, including any missing parent
// directories (MkdirAll), and describes the new directory.
//
// Fetching metadata afterwards is best-effort. If that stat fails, the result
// still reports IsDir=true but its name/path/size fields are zero-valued.
func (s *Service) CreateDir(connID string, dirPath string) (*filesystem.FileOperationResult, error) {
	client, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	mkdir := func(sc *sftpclient.Client) error { return sc.MkdirAll(dirPath) }
	if err := client.WithRetry(mkdir); err != nil {
		return nil, fmt.Errorf("创建目录失败: %w", err)
	}

	// Best-effort metadata; a stat failure is deliberately ignored.
	info, _ := s.GetFileInfo(connID, dirPath)
	return toFileOperationResult(info, true), nil
}
// CreateFile creates a new, empty remote file at filePath and describes it.
//
// The file is opened with O_CREATE|O_EXCL, so the request fails if filePath
// already exists. A plain sc.Create would instead truncate the existing file
// and silently destroy its contents.
//
// NOTE(review): if a first attempt succeeded server-side but the reply was
// lost and WithRetry retries, the retry fails with "exists" even though the
// empty file was created.
//
// Metadata in the result is best-effort: a failed follow-up stat leaves the
// name/path/size fields zero-valued.
func (s *Service) CreateFile(connID string, filePath string) (*filesystem.FileOperationResult, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		f, e := sc.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL)
		if e != nil {
			return e
		}
		return f.Close()
	})
	if err != nil {
		return nil, fmt.Errorf("创建文件失败: %w", err)
	}
	// Best-effort metadata; a stat failure is deliberately ignored.
	infoMap, _ := s.GetFileInfo(connID, filePath)
	return toFileOperationResult(infoMap, false), nil
}
// DeletePath removes filePath on the remote host (recursively if it is a
// directory) and returns the metadata captured just before deletion, with
// Deleted set.
//
// The path is examined with Lstat rather than Stat. A symbolic link, even one
// pointing at a directory, is therefore unlinked with Remove instead of being
// handed to RemoveAll, which could otherwise descend into the link's target.
// The result's IsDir reflects that same Lstat. Previously IsDir was always
// false, even for directories.
func (s *Service) DeletePath(connID string, filePath string) (*filesystem.FileOperationResult, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	// Snapshot metadata while the path still exists. This is best-effort: a
	// failed stat only leaves the result's name/path/size fields empty.
	infoMap, _ := s.GetFileInfo(connID, filePath)

	var isDir bool
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Lstat(filePath)
		if e != nil {
			return e
		}
		isDir = fi.IsDir()
		if isDir {
			// Recursively delete the directory tree.
			return sc.RemoveAll(filePath)
		}
		return sc.Remove(filePath)
	})
	if err != nil {
		return nil, fmt.Errorf("删除失败: %w", err)
	}

	result := toFileOperationResult(infoMap, isDir)
	result.Deleted = true
	return result, nil
}
// RenamePath renames oldPath to newPath on the remote host and describes the
// entry at its new location, with OldPath set to the original path.
//
// IsDir is taken from the post-rename stat, so renaming a directory reports
// IsDir=true. Previously it was hard-coded to false. Metadata is best-effort:
// if the stat fails, the fields (including IsDir) are zero-valued.
func (s *Service) RenamePath(connID string, oldPath, newPath string) (*filesystem.FileOperationResult, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}
	err = c.WithRetry(func(sc *sftpclient.Client) error {
		return sc.Rename(oldPath, newPath)
	})
	if err != nil {
		return nil, fmt.Errorf("重命名失败: %w", err)
	}

	// Best-effort metadata; a nil map yields zero values below.
	infoMap, _ := s.GetFileInfo(connID, newPath)
	isDir, _ := infoMap["is_dir"].(bool)
	result := toFileOperationResult(infoMap, isDir)
	result.OldPath = oldPath
	return result, nil
}
// DownloadToTemp copies the remote file remotePath into a fresh local
// temporary file (used for previews) and returns the local path.
//
// Files larger than 50 MiB are rejected. The cap is checked via Stat and
// enforced again during the copy, in case the file grew in between. The temp
// file name starts with "udesk-sftp-preview-" so CleanupTempFiles can find it.
// On any download failure the partially written temp file is removed.
func (s *Service) DownloadToTemp(connID string, remotePath string) (string, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return "", err
	}

	// Preview size cap: 50 MiB, looser than the edit-mode limit.
	const maxPreviewSize int64 = 50 << 20

	err = c.WithRetry(func(sc *sftpclient.Client) error {
		fi, e := sc.Stat(remotePath)
		if e != nil {
			return e
		}
		if fi.Size() > maxPreviewSize {
			return fmt.Errorf("预览文件过大: %s", filesystem.FormatBytes(fi.Size()))
		}
		return nil
	})
	if err != nil {
		return "", err
	}

	// The "udesk-sftp-preview-" prefix must match what CleanupTempFiles scans
	// for. CreateTemp substitutes a random string for the '*'.
	tmpFile, err := os.CreateTemp(os.TempDir(), "udesk-sftp-preview-*-"+filepath.Base(remotePath))
	if err != nil {
		return "", fmt.Errorf("创建临时文件失败: %w", err)
	}
	localPath := tmpFile.Name()
	_ = tmpFile.Close() // reopened (and truncated) by each attempt below

	err = c.WithRetry(func(sc *sftpclient.Client) error {
		src, e := sc.Open(remotePath)
		if e != nil {
			return e
		}
		defer src.Close()

		// os.Create truncates, so a retry never appends to a partial copy.
		dst, e := os.Create(localPath)
		if e != nil {
			return e
		}
		// Copy at most one byte past the cap to detect a file that grew.
		n, copyErr := io.Copy(dst, io.LimitReader(src, maxPreviewSize+1))
		closeErr := dst.Close()
		if copyErr != nil {
			return copyErr
		}
		if closeErr != nil {
			return closeErr
		}
		if n > maxPreviewSize {
			return fmt.Errorf("预览文件过大: 超过 %s", filesystem.FormatBytes(maxPreviewSize))
		}
		return nil
	})
	if err != nil {
		_ = os.Remove(localPath) // do not leave a partial preview behind
		return "", fmt.Errorf("下载文件失败: %w", err)
	}
	return localPath, nil
}
// GetCommonPaths returns well-known remote directories for the connection:
// "home", "tmp" ("/tmp") and "root" ("/"). It never returns an error.
//
// The home directory is guessed from the configured username. It is "/root"
// when the user is root, empty, or the connection is unknown; otherwise it is
// "/home/<username>".
// NOTE(review): this is a naming heuristic; accounts with a non-standard home
// directory get a wrong "home" entry.
func (s *Service) GetCommonPaths(connID string) (map[string]string, error) {
	user := "root"
	if client := s.manager.GetClient(connID); client != nil {
		// Read the config under the client's mutex.
		client.mu.Lock()
		user = client.config.Username
		client.mu.Unlock()
	}

	home := "/home/" + user
	if user == "" || user == "root" {
		home = "/root"
	}

	return map[string]string{"home": home, "tmp": "/tmp", "root": "/"}, nil
}
// CleanupTempFiles removes leftover SFTP preview files from the OS temp
// directory.
//
// It removes every regular entry whose name starts with "udesk-sftp-". This
// covers the "udesk-sftp-preview-" prefix as well as files created from the
// "udesk-sftp-*-<name>" os.CreateTemp template, which the previous
// "udesk-sftp-preview-" check never matched. Directories are skipped.
// Cleanup is best-effort: directory-read and removal errors are ignored.
func CleanupTempFiles() {
	const prefix = "udesk-sftp-"

	tmpDir := os.TempDir()
	entries, err := os.ReadDir(tmpDir)
	if err != nil {
		return // opportunistic cleanup; nothing sensible to report
	}
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasPrefix(entry.Name(), prefix) {
			continue
		}
		_ = os.Remove(filepath.Join(tmpDir, entry.Name()))
	}
}
// GetSystemInfo gathers disk, memory and CPU figures from the remote host by
// running shell commands through the connection's RunCommand.
//
// The commands run concurrently and share a single deadline of cmdTimeout.
// Whatever output has arrived by then is parsed. Previously the timer was
// re-armed on every receive, so the worst case was three times the timeout.
// Keys written to the map, each only when its output parses:
//   - "cpu_usage": percentage string, e.g. "12%"
//   - "disk_total", "disk_used" (uint64 bytes) and "disk_usage" for "/"
//   - "mem_total", "mem_used" (uint64 bytes) and "mem_usage"
//
// An error is returned only when some command failed or timed out AND nothing
// could be parsed at all.
func (s *Service) GetSystemInfo(connID string) (map[string]interface{}, error) {
	c, err := s.getClient(connID)
	if err != nil {
		return nil, err
	}

	type cmdResult struct {
		key    string // "df" | "mem" | "cpu"
		output string
		err    error
	}
	commands := []struct{ key, cmd string }{
		// -P keeps each filesystem on one line (long device names would
		// otherwise wrap and shift the fields); -B1 reports bytes.
		{"df", "df -P -B1 / 2>/dev/null | tail -1"},
		// Prefer free(1); otherwise fetch exactly the /proc/meminfo keys
		// parseProcMeminfo uses (head -2 used to drop MemAvailable/Buffers/Cached).
		{"mem", "free -b 2>/dev/null | grep -E '^Mem:' || grep -E '^(MemTotal|MemFree|MemAvailable|Buffers|Cached):' /proc/meminfo 2>/dev/null"},
		// NOTE(review): $2 of top's "Cpu(s)" line is the user-space ("us")
		// share, not total utilisation — confirm this is the intended metric.
		{"cpu", "top -bn1 | grep 'Cpu(s)' | awk '{print $2}' | cut -d'%' -f1"},
	}
	const cmdTimeout = 6 * time.Second

	// Buffered to len(commands): a command finishing after the deadline can
	// still deliver its result and let its goroutine exit.
	// NOTE(review): RunCommand takes no cancellation, so a command that never
	// returns still pins its goroutine; the deadline only bounds our wait.
	results := make(chan cmdResult, len(commands))
	for _, command := range commands {
		go func(key, cmd string) {
			out, e := c.RunCommand(cmd)
			results <- cmdResult{key, out, e}
		}(command.key, command.cmd)
	}

	outputs := make(map[string]string, len(commands))
	hasErr := false
	deadline := time.NewTimer(cmdTimeout)
	defer deadline.Stop()
collect:
	for pending := len(commands); pending > 0; pending-- {
		select {
		case r := <-results:
			outputs[r.key] = r.output
			if r.err != nil {
				hasErr = true
			}
		case <-deadline.C:
			hasErr = true
			break collect
		}
	}

	info := make(map[string]interface{})

	// CPU: a bare number such as "3.1".
	if usage, e := strconv.ParseFloat(strings.TrimSpace(outputs["cpu"]), 64); e == nil && usage >= 0 {
		info["cpu_usage"] = fmt.Sprintf("%.0f%%", usage)
	}

	// Disk: df data line "Filesystem 1-blocks Used Available Capacity Mounted-on",
	// with sizes in bytes because of -B1.
	if fields := strings.Fields(outputs["df"]); len(fields) >= 5 {
		var diskTotal, diskUsed uint64
		if v, e := strconv.ParseUint(fields[1], 10, 64); e == nil {
			diskTotal = v
			info["disk_total"] = v
		}
		if v, e := strconv.ParseUint(fields[2], 10, 64); e == nil {
			diskUsed = v
			info["disk_used"] = v
		}
		if diskTotal > 0 {
			info["disk_usage"] = fmt.Sprintf("%.0f%%", float64(diskUsed)/float64(diskTotal)*100)
		}
	}

	// Memory: either free's "Mem: total used free ..." line (bytes, via -b)
	// or /proc/meminfo lines (kB), told apart by the "MemTotal:" key.
	memOut := strings.TrimSpace(outputs["mem"])
	if strings.Contains(memOut, "MemTotal:") {
		parseProcMeminfo(memOut, info)
	} else if fields := strings.Fields(memOut); len(fields) >= 3 {
		var memTotal, memUsed uint64
		if v, e := strconv.ParseUint(fields[1], 10, 64); e == nil {
			memTotal = v
			info["mem_total"] = v
		}
		if v, e := strconv.ParseUint(fields[2], 10, 64); e == nil {
			memUsed = v
			info["mem_used"] = v
		}
		if memTotal > 0 {
			info["mem_usage"] = fmt.Sprintf("%.0f%%", float64(memUsed)/float64(memTotal)*100)
		}
	}

	if hasErr && len(info) == 0 {
		return nil, fmt.Errorf("采集远程系统信息失败")
	}
	return info, nil
}
// parseProcMeminfo parses /proc/meminfo-style lines ("Key:   value kB") from
// output and stores into info:
//   - "mem_total": total memory in bytes (uint64)
//   - "mem_used": total minus available, in bytes (uint64)
//   - "mem_usage": used/total as a rounded percentage string, e.g. "42%"
//
// Available memory is MemAvailable when present (newer kernels), otherwise
// MemFree + Buffers + Cached. If MemTotal is missing or zero, nothing is
// written. Used is clamped at zero, so an available figure larger than total
// cannot wrap the unsigned subtraction into a huge value.
func parseProcMeminfo(output string, info map[string]interface{}) {
	memKB := make(map[string]uint64)
	for _, line := range strings.Split(output, "\n") {
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		if val, err := strconv.ParseUint(fields[1], 10, 64); err == nil {
			memKB[strings.TrimSuffix(fields[0], ":")] = val
		}
	}

	totalKB := memKB["MemTotal"]
	if totalKB == 0 {
		return // no usable total; leave info untouched
	}
	availKB, ok := memKB["MemAvailable"]
	if !ok {
		availKB = memKB["MemFree"] + memKB["Buffers"] + memKB["Cached"]
	}

	total := totalKB * 1024 // kB -> bytes
	var used uint64
	if availKB < totalKB {
		used = (totalKB - availKB) * 1024
	}
	info["mem_total"] = total
	info["mem_used"] = used
	info["mem_usage"] = fmt.Sprintf("%.0f%%", float64(used)/float64(total)*100)
}
// --- Internal helpers ---

// getClient looks connID up in the manager and returns its *Client, or an
// error when no client is registered under that ID.
func (s *Service) getClient(connID string) (*Client, error) {
	if c := s.manager.GetClient(connID); c != nil {
		return c, nil
	}
	return nil, fmt.Errorf("SFTP 连接不存在: %s", connID)
}
// toUnixPath rewrites every backslash in p as a forward slash. Unlike
// filepath.ToSlash, it does this on every OS, not only where '\' is the path
// separator.
func toUnixPath(p string) string {
	return strings.Replace(p, `\`, "/", -1)
}
// toFileOperationResult converts a GetFileInfo-style map into a
// filesystem.FileOperationResult.
//
// Missing keys, values of an unexpected type, or a nil map all yield zero
// values. IsDir always comes from the isDir argument, not from the map's
// "is_dir" entry. SizeStr is recomputed from the size.
func toFileOperationResult(m map[string]interface{}, isDir bool) *filesystem.FileOperationResult {
	str := func(key string) string {
		v, _ := m[key].(string)
		return v
	}
	size, _ := m["size"].(int64)

	res := &filesystem.FileOperationResult{
		Name:    str("name"),
		Path:    str("path"),
		ModTime: str("mod_time"),
		Mode:    str("mode"),
		Size:    size,
		IsDir:   isDir,
	}
	res.SizeStr = filesystem.FormatBytes(size)
	return res
}