- app.go: profileSvc移入App struct,用a.mu保护 - sqlite.go: InitFast加sync.Once防并发双重初始化 - client.go: Manager.Connect加sync.Mutex防竞态泄漏SSH - service.go: 临时文件用os.CreateTemp防时间戳碰撞 - connection-manager: 密码缺失时不再塞入假WailsTransport
265 lines
5.9 KiB
Go
265 lines
5.9 KiB
Go
package sftp
|
||
|
||
import (
|
||
"fmt"
|
||
"net"
|
||
"os"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/pkg/sftp"
|
||
"golang.org/x/crypto/ssh"
|
||
)
|
||
|
||
// Client SFTP 客户端封装(单连接)
|
||
type Client struct {
|
||
config *Config
|
||
client *sftp.Client
|
||
sshClient *ssh.Client
|
||
mu sync.Mutex
|
||
}
|
||
|
||
// Manager is an SFTP connection pool keyed by "host:port" (see Connect).
type Manager struct {
	// clients maps a "host:port" key to its *Client.
	clients sync.Map

	// mu is held by Connect for its entire duration, serialising the
	// lookup / health-check / dial / store sequence.
	mu sync.Mutex
}
|
||
|
||
var globalManager = &Manager{}
|
||
|
||
// GetManager 获取全局连接管理器
|
||
func GetManager() *Manager {
|
||
return globalManager
|
||
}
|
||
|
||
// Connect 创建或复用 SFTP 连接
|
||
func (m *Manager) Connect(config *Config) (*Client, error) {
|
||
key := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
if existing, ok := m.clients.Load(key); ok {
|
||
c := existing.(*Client)
|
||
if c.IsHealthy() {
|
||
return c, nil
|
||
}
|
||
c.Close()
|
||
m.clients.Delete(key)
|
||
}
|
||
|
||
c, err := newClient(config)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
m.clients.Store(key, c)
|
||
return c, nil
|
||
}
|
||
|
||
// GetClient 获取已有连接(不复用也不新建)
|
||
func (m *Manager) GetClient(connID string) *Client {
|
||
if val, ok := m.clients.Load(connID); ok {
|
||
return val.(*Client)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Disconnect 关闭并移除指定连接
|
||
func (m *Manager) Disconnect(host string, port int) {
|
||
key := fmt.Sprintf("%s:%d", host, port)
|
||
if val, ok := m.clients.LoadAndDelete(key); ok {
|
||
val.(*Client).Close()
|
||
}
|
||
}
|
||
|
||
// Shutdown 关闭所有连接
|
||
func (m *Manager) Shutdown() {
|
||
m.clients.Range(func(key, value any) bool {
|
||
value.(*Client).Close()
|
||
m.clients.Delete(key)
|
||
return true
|
||
})
|
||
}
|
||
|
||
// --- 内部 ---
|
||
|
||
func newClient(config *Config) (*Client, error) {
|
||
sshConfig := &ssh.ClientConfig{
|
||
Config: ssh.Config{
|
||
KeyExchanges: []string{
|
||
"curve25519-sha256", "curve25519-sha256@libssh.org",
|
||
"ecdh-sha2-nistp256", "ecdh-sha2-nistp384",
|
||
"diffie-hellman-group14-sha256", "diffie-hellman-group14-sha1",
|
||
},
|
||
},
|
||
User: config.Username,
|
||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||
Timeout: config.Timeout,
|
||
}
|
||
|
||
// 认证方式选择
|
||
if config.KeyPath != "" {
|
||
key, err := os.ReadFile(config.KeyPath)
|
||
if err != nil {
|
||
return nil, &ConnectionError{Op: "auth", Err: fmt.Errorf("读取密钥文件失败: %w", err)}
|
||
}
|
||
var signer ssh.Signer
|
||
if config.KeyPassphrase != "" {
|
||
signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(config.KeyPassphrase))
|
||
} else {
|
||
signer, err = ssh.ParsePrivateKey(key)
|
||
}
|
||
if err != nil {
|
||
return nil, &ConnectionError{Op: "auth", Err: fmt.Errorf("解析密钥失败: %w", err)}
|
||
}
|
||
sshConfig.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
|
||
} else if config.Password != "" {
|
||
pw := config.Password
|
||
sshConfig.Auth = []ssh.AuthMethod{
|
||
ssh.Password(pw),
|
||
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||
answers := make([]string, len(questions))
|
||
for i := range questions {
|
||
answers[i] = pw
|
||
}
|
||
return answers, nil
|
||
}),
|
||
}
|
||
} else {
|
||
return nil, &ConnectionError{Op: "auth", Err: fmt.Errorf("必须提供密码或密钥文件")}
|
||
}
|
||
|
||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||
sshConn, err := net.DialTimeout("tcp", addr, config.Timeout)
|
||
if err != nil {
|
||
return nil, &ConnectionError{Op: "dial", Err: err}
|
||
}
|
||
|
||
sshConnConn, chans, reqs, err := ssh.NewClientConn(sshConn, addr, sshConfig)
|
||
if err != nil {
|
||
sshConn.Close()
|
||
return nil, &ConnectionError{Op: "handshake", Err: err}
|
||
}
|
||
|
||
sshClient := ssh.NewClient(sshConnConn, chans, reqs)
|
||
sftpClient, err := sftp.NewClient(sshClient)
|
||
if err != nil {
|
||
sshClient.Close()
|
||
return nil, &ConnectionError{Op: "sftp_init", Err: err}
|
||
}
|
||
|
||
return &Client{
|
||
config: config,
|
||
client: sftpClient,
|
||
sshClient: sshClient,
|
||
}, nil
|
||
}
|
||
|
||
// IsHealthy 检查连接是否健康(先取引用再解锁,避免持锁做 I/O)
|
||
func (c *Client) IsHealthy() bool {
|
||
c.mu.Lock()
|
||
client := c.client
|
||
c.mu.Unlock()
|
||
if client == nil {
|
||
return false
|
||
}
|
||
_, err := client.Stat("/")
|
||
return err == nil
|
||
}
|
||
|
||
// WithRetry 带重试的操作执行(自动处理断线重连)
|
||
func (c *Client) WithRetry(fn func(*sftp.Client) error) error {
|
||
const maxRetries = 3
|
||
var lastErr error
|
||
|
||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||
if attempt > 0 {
|
||
time.Sleep(time.Duration((attempt+1)*2) * time.Second)
|
||
if reconnectErr := c.reconnect(); reconnectErr != nil {
|
||
lastErr = reconnectErr
|
||
continue
|
||
}
|
||
}
|
||
|
||
c.mu.Lock()
|
||
client := c.client
|
||
c.mu.Unlock()
|
||
|
||
if client == nil {
|
||
lastErr = fmt.Errorf("SFTP 客户端未初始化")
|
||
continue
|
||
}
|
||
|
||
if err := fn(client); err != nil {
|
||
if isNetworkError(err) {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
return fmt.Errorf("操作失败(已重试 %d 次): %w", maxRetries, lastErr)
|
||
}
|
||
|
||
func (c *Client) reconnect() error {
|
||
nc, err := newClient(c.config)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
c.closeLocked()
|
||
c.client = nc.client
|
||
c.sshClient = nc.sshClient
|
||
return nil
|
||
}
|
||
|
||
func (c *Client) Close() {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
c.closeLocked()
|
||
}
|
||
|
||
func (c *Client) closeLocked() {
|
||
if c.client != nil {
|
||
c.client.Close()
|
||
c.client = nil
|
||
}
|
||
if c.sshClient != nil {
|
||
c.sshClient.Close()
|
||
c.sshClient = nil
|
||
}
|
||
}
|
||
|
||
// RunCommand 通过 SSH Session 执行远程命令,返回 stdout
|
||
func (c *Client) RunCommand(cmd string) (string, error) {
|
||
c.mu.Lock()
|
||
sshClient := c.sshClient
|
||
c.mu.Unlock()
|
||
|
||
if sshClient == nil {
|
||
return "", fmt.Errorf("SSH 客户端未初始化")
|
||
}
|
||
|
||
session, err := sshClient.NewSession()
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建 SSH 会话失败: %w", err)
|
||
}
|
||
defer session.Close()
|
||
|
||
out, err := session.CombinedOutput(cmd)
|
||
if err != nil {
|
||
return "", fmt.Errorf("执行命令失败 [%s]: %w", cmd, err)
|
||
}
|
||
return string(out), nil
|
||
}
|
||
|
||
// RawClient 获取底层 sftp.Client(高级用法)
|
||
func (c *Client) RawClient() *sftp.Client {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
return c.client
|
||
}
|