Private
Public Access
1
0
Files
u-desk/internal/sftp/client.go
绝尘 ee4b1f5ac1 修复:审查发现的高优先问题(竞态/初始化/碰撞)
- app.go: profileSvc移入App struct,用a.mu保护
- sqlite.go: InitFast加sync.Once防并发双重初始化
- client.go: Manager.Connect加sync.Mutex防竞态泄漏SSH
- service.go: 临时文件用os.CreateTemp防时间戳碰撞
- connection-manager: 密码缺失时不再塞入假WailsTransport
2026-05-04 15:40:04 +08:00

265 lines
5.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sftp
import (
"fmt"
"net"
"os"
"sync"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// Client wraps a single SFTP session together with the SSH connection it
// runs over (one connection per Client).
type Client struct {
	config    *Config      // connection parameters; newClient stores them and reconnect reuses them
	client    *sftp.Client // SFTP session; nil after closeLocked
	sshClient *ssh.Client  // underlying SSH connection; nil after closeLocked
	mu        sync.Mutex   // guards client and sshClient
}
// Manager is the global SFTP connection pool; entries are keyed by "host:port".
//
// NOTE(review): the key ignores the username and credentials, so Connect calls
// for different users on the same host:port share (and reuse) one Client.
type Manager struct {
	clients sync.Map   // map[string]*Client, keyed by "host:port"
	mu      sync.Mutex // held by Connect across its lookup/replace/create sequence
}
// globalManager is the process-wide connection pool handed out by GetManager.
var globalManager = new(Manager)

// GetManager returns the process-wide SFTP connection manager; every call
// yields the same *Manager.
func GetManager() *Manager { return globalManager }
// Connect returns the pooled client for config.Host:config.Port, dialing a new
// one when the pool has no entry or the pooled one fails its health check
// (the stale client is closed and dropped first).
//
// m.mu serialises the whole lookup/replace/create sequence so two concurrent
// callers for the same address cannot both dial and leak one SSH connection.
// The lock is held across the health check and the dial, so a slow host
// delays Connect calls for every address.
func (m *Manager) Connect(config *Config) (*Client, error) {
	poolKey := fmt.Sprintf("%s:%d", config.Host, config.Port)

	m.mu.Lock()
	defer m.mu.Unlock()

	if cached, found := m.clients.Load(poolKey); found {
		pooled := cached.(*Client)
		if pooled.IsHealthy() {
			return pooled, nil
		}
		// Stale entry: release its resources before replacing it.
		pooled.Close()
		m.clients.Delete(poolKey)
	}

	fresh, err := newClient(config)
	if err != nil {
		return nil, err
	}
	m.clients.Store(poolKey, fresh)
	return fresh, nil
}
// GetClient looks up a pooled client by its pool key, which is "host:port"
// (the format Connect stores under). It neither dials nor health-checks; it
// returns nil when there is no entry for connID.
func (m *Manager) GetClient(connID string) *Client {
	entry, found := m.clients.Load(connID)
	if !found {
		return nil
	}
	return entry.(*Client)
}
// Disconnect removes the pooled client for host:port and closes it. Unknown
// addresses are ignored.
func (m *Manager) Disconnect(host string, port int) {
	removed, existed := m.clients.LoadAndDelete(fmt.Sprintf("%s:%d", host, port))
	if !existed {
		return
	}
	removed.(*Client).Close()
}
// Shutdown closes every pooled client and removes it from the pool. It does
// not take m.mu, so a Connect running concurrently may still add an entry.
func (m *Manager) Shutdown() {
	closeAndForget := func(poolKey, entry any) bool {
		entry.(*Client).Close()
		m.clients.Delete(poolKey)
		return true // keep iterating over the remaining entries
	}
	m.clients.Range(closeAndForget)
}
// --- 内部 ---
// newClient dials config.Host:config.Port, performs the SSH handshake and
// opens an SFTP session over the resulting connection.
//
// Authentication:
//   - When KeyPath is set, only public-key auth is offered. The key is
//     decrypted with KeyPassphrase if one is given.
//   - Otherwise, Password is offered through both "password" and
//     "keyboard-interactive" auth; every keyboard-interactive prompt is
//     answered with the password.
//   - Supplying neither is an error.
//
// config.Timeout bounds the TCP dial and, via a connection deadline, the SSH
// handshake and SFTP initialisation. Zero means no timeout.
//
// Every error is a *ConnectionError whose Op names the failing stage:
// "auth", "dial", "handshake" or "sftp_init".
func newClient(config *Config) (*Client, error) {
	sshConfig := &ssh.ClientConfig{
		Config: ssh.Config{
			// Modern curves first; the group14 DH variants are presumably kept
			// for compatibility with older servers.
			KeyExchanges: []string{
				"curve25519-sha256", "curve25519-sha256@libssh.org",
				"ecdh-sha2-nistp256", "ecdh-sha2-nistp384",
				"diffie-hellman-group14-sha256", "diffie-hellman-group14-sha1",
			},
		},
		User: config.Username,
		// NOTE(security): the server host key is not verified, so a network
		// attacker can impersonate the server and capture credentials.
		// Consider a known_hosts-based callback.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         config.Timeout,
	}

	switch {
	case config.KeyPath != "":
		key, err := os.ReadFile(config.KeyPath)
		if err != nil {
			return nil, &ConnectionError{Op: "auth", Err: fmt.Errorf("读取密钥文件失败: %w", err)}
		}
		var signer ssh.Signer
		if config.KeyPassphrase != "" {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(config.KeyPassphrase))
		} else {
			signer, err = ssh.ParsePrivateKey(key)
		}
		if err != nil {
			return nil, &ConnectionError{Op: "auth", Err: fmt.Errorf("解析密钥失败: %w", err)}
		}
		sshConfig.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
	case config.Password != "":
		pw := config.Password
		sshConfig.Auth = []ssh.AuthMethod{
			ssh.Password(pw),
			// Some servers only accept keyboard-interactive; answer every
			// prompt with the password.
			ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
				answers := make([]string, len(questions))
				for i := range questions {
					answers[i] = pw
				}
				return answers, nil
			}),
		}
	default:
		return nil, &ConnectionError{Op: "auth", Err: fmt.Errorf("必须提供密码或密钥文件")}
	}

	// JoinHostPort brackets IPv6 literals ("[::1]:22"); "%s:%d" would yield
	// "::1:22", which the dialer rejects.
	addr := net.JoinHostPort(config.Host, fmt.Sprintf("%d", config.Port))
	conn, err := net.DialTimeout("tcp", addr, config.Timeout)
	if err != nil {
		return nil, &ConnectionError{Op: "dial", Err: err}
	}

	// The dial timeout does not cover the SSH handshake or the SFTP version
	// exchange. A server that accepts TCP and then goes silent would block
	// here forever, so bound that phase with a deadline on the connection.
	if config.Timeout > 0 {
		if err := conn.SetDeadline(time.Now().Add(config.Timeout)); err != nil {
			conn.Close()
			return nil, &ConnectionError{Op: "dial", Err: err}
		}
	}

	clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, sshConfig)
	if err != nil {
		conn.Close()
		return nil, &ConnectionError{Op: "handshake", Err: err}
	}
	sshClient := ssh.NewClient(clientConn, chans, reqs)
	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		sshClient.Close() // also closes conn
		return nil, &ConnectionError{Op: "sftp_init", Err: err}
	}

	// The session is long-lived and may sit idle, so lift the setup deadline.
	// Per net.Conn, this also applies to reads already pending in the SSH
	// transport's reader.
	if config.Timeout > 0 {
		if err := conn.SetDeadline(time.Time{}); err != nil {
			sftpClient.Close()
			sshClient.Close()
			return nil, &ConnectionError{Op: "sftp_init", Err: err}
		}
	}

	return &Client{
		config:    config,
		client:    sftpClient,
		sshClient: sshClient,
	}, nil
}
// IsHealthy 检查连接是否健康(先取引用再解锁,避免持锁做 I/O
func (c *Client) IsHealthy() bool {
c.mu.Lock()
client := c.client
c.mu.Unlock()
if client == nil {
return false
}
_, err := client.Stat("/")
return err == nil
}
// WithRetry runs fn against the current *sftp.Client, reconnecting and
// retrying when fn fails with a network error (as judged by isNetworkError).
//
// fn is attempted at most maxRetries (3) times: once, plus up to 2 retries.
// Before each retry the method sleeps (4s before the 2nd attempt, 6s before
// the 3rd) and then reconnects, unless another goroutine has already replaced
// the connection that failed. A non-network error from fn is returned
// immediately. If every attempt fails, the last error is wrapped and returned.
func (c *Client) WithRetry(fn func(*sftp.Client) error) error {
	const maxRetries = 3
	var (
		lastErr error
		// failed is the session that most recently produced a network error.
		// It lets us detect that a concurrent caller already reconnected.
		failed *sftp.Client
	)
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			time.Sleep(time.Duration((attempt+1)*2) * time.Second)

			c.mu.Lock()
			current := c.client
			c.mu.Unlock()
			// Reconnect only if the connection is still the broken one (or
			// gone). Without this, concurrent callers hitting the same outage
			// each tear down the connection the previous one just rebuilt.
			if current == nil || current == failed {
				if err := c.reconnect(); err != nil {
					lastErr = err
					continue
				}
			}
		}

		c.mu.Lock()
		client := c.client
		c.mu.Unlock()
		if client == nil {
			lastErr = fmt.Errorf("SFTP 客户端未初始化")
			continue
		}

		err := fn(client)
		if err == nil {
			return nil
		}
		if !isNetworkError(err) {
			return err
		}
		lastErr = err
		failed = client
	}
	// maxRetries counts attempts; the first attempt is not a retry.
	return fmt.Errorf("操作失败(已重试 %d 次): %w", maxRetries-1, lastErr)
}
// reconnect dials a brand-new connection from the stored config and swaps it
// in, closing whatever session and SSH connection the client held.
//
// The dial happens before c.mu is taken, so other methods are not blocked on
// network I/O; only the swap runs under the lock. If dialing fails, the
// existing (possibly broken) connection is left untouched.
//
// NOTE(review): nothing here checks whether Close was called, so a retry on a
// closed Client re-opens a connection.
func (c *Client) reconnect() error {
	replacement, dialErr := newClient(c.config)
	if dialErr != nil {
		return dialErr
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	c.closeLocked()
	c.client, c.sshClient = replacement.client, replacement.sshClient
	return nil
}
// Close shuts down the SFTP session and the SSH connection under c.mu.
// Repeated calls are harmless: the fields are nil after the first call.
// Note that reconnect (used by WithRetry) can re-establish a connection on a
// Client that has been closed.
func (c *Client) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.closeLocked()
}
// closeLocked releases the SFTP session first and then the SSH connection
// beneath it, leaving both fields nil. Close errors are ignored on purpose,
// because this is best-effort teardown. The caller must hold c.mu.
func (c *Client) closeLocked() {
	if session := c.client; session != nil {
		_ = session.Close()
		c.client = nil
	}
	if conn := c.sshClient; conn != nil {
		_ = conn.Close()
		c.sshClient = nil
	}
}
// RunCommand 通过 SSH Session 执行远程命令,返回 stdout
func (c *Client) RunCommand(cmd string) (string, error) {
c.mu.Lock()
sshClient := c.sshClient
c.mu.Unlock()
if sshClient == nil {
return "", fmt.Errorf("SSH 客户端未初始化")
}
session, err := sshClient.NewSession()
if err != nil {
return "", fmt.Errorf("创建 SSH 会话失败: %w", err)
}
defer session.Close()
out, err := session.CombinedOutput(cmd)
if err != nil {
return "", fmt.Errorf("执行命令失败 [%s]: %w", cmd, err)
}
return string(out), nil
}
// RawClient exposes the underlying *sftp.Client for advanced use. The value
// is a snapshot taken under c.mu. It is nil once the client is closed, and a
// later reconnect may close and replace it, so callers should not cache it.
func (c *Client) RawClient() *sftp.Client {
	c.mu.Lock()
	raw := c.client
	c.mu.Unlock()
	return raw
}