From f77ed2572af792e91c9d21ffd54b8dbf6ab21e6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=9D=E5=B0=98?= <237809796@qq.com> Date: Mon, 30 Mar 2026 09:52:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E:=20ssh-proxy=20=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E4=BC=A0=E8=BE=93=20(=E5=B9=B6=E8=A1=8C=20SFTP)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 基于 russh-sftp 实现多通道并行文件传输,支持 upload/download, 默认 4 通道并行,可通过 -c 参数动态调整并发数。 传输失败时自动清理不完整的远程/本地文件。 --- ssh-proxy/Cargo.toml | 1 + ssh-proxy/src/cli.rs | 88 ++++++++++++++++ ssh-proxy/src/handler.rs | 125 ++++++++++++++++++++++- ssh-proxy/src/main.rs | 52 +++++++++- ssh-proxy/src/session.rs | 212 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 476 insertions(+), 2 deletions(-) diff --git a/ssh-proxy/Cargo.toml b/ssh-proxy/Cargo.toml index cf51fba..d2bc97f 100644 --- a/ssh-proxy/Cargo.toml +++ b/ssh-proxy/Cargo.toml @@ -16,6 +16,7 @@ anyhow = "1" clap = { version = "4", features = ["derive"] } ureq = "2" # CLI HTTP client (2.x has simpler API) time = { version = "0.3", features = ["formatting", "macros"] } +russh-sftp = "2.0.8" [profile.release] opt-level = "z" diff --git a/ssh-proxy/src/cli.rs b/ssh-proxy/src/cli.rs index f60d8b9..3530265 100644 --- a/ssh-proxy/src/cli.rs +++ b/ssh-proxy/src/cli.rs @@ -32,6 +32,28 @@ struct ServersResponse { servers: Vec, } +#[derive(Debug, Serialize)] +struct UploadRequest { + server: String, + #[serde(rename = "localPath")] + local_path: String, + #[serde(rename = "remotePath")] + remote_path: String, + #[serde(rename = "concurrency")] + concurrency: usize, +} + +#[derive(Debug, Deserialize)] +struct TransferResponse { + success: bool, + size: u64, + #[serde(rename = "sizeFormatted")] + size_formatted: String, + #[serde(rename = "durationMs")] + duration_ms: u64, + path: String, +} + #[derive(Debug, Deserialize)] struct ServerInfo { name: String, @@ -209,4 +231,70 @@ impl Cli { Err(_) => Ok(false), } } + + /// 上传文件到远程服务器 + pub fn upload(&self, server: &str, local_path: &str, remote_path: &str, concurrency: usize) -> Result<()> { + let body = serde_json::to_string(&UploadRequest { + server: server.to_string(), + local_path: local_path.to_string(), + remote_path: remote_path.to_string(), + concurrency, + })?; + + let response = ureq::post(&format!("{}/upload", self.server)) + .set("Content-Type", "application/json") + .send_string(&body); + + let resp_body = match response { + Ok(r) => r.into_string()?, + Err(ureq::Error::Status(_, resp)) => resp.into_string()?, + Err(e) => bail!("HTTP error: {}", e), + }; + + if let Ok(err) = serde_json::from_str::(&resp_body) { + eprintln!("Error: {}", err.error); + if let Some(usage) = err.usage { + eprintln!("\n{}", usage); + } + return Ok(()); + } + + let result: TransferResponse = serde_json::from_str(&resp_body)?; + println!("Uploaded: {} -> {}", local_path, remote_path); + println!(" Size: {}, Time: {}ms", result.size_formatted, result.duration_ms); + Ok(()) + } + + /// 从远程服务器下载文件 + pub fn download(&self, server: &str, remote_path: &str, local_path: &str, concurrency: usize) -> Result<()> { + let body = serde_json::to_string(&serde_json::json!({ + "server": server, + "remotePath": remote_path, + "localPath": local_path, + "concurrency": concurrency, + }))?; + + let response = ureq::post(&format!("{}/download", self.server)) + .set("Content-Type", "application/json") + .send_string(&body); + + let resp_body = match response { + Ok(r) => r.into_string()?, + Err(ureq::Error::Status(_, resp)) => resp.into_string()?, + Err(e) => bail!("HTTP error: {}", e), + }; + + if let Ok(err) = serde_json::from_str::(&resp_body) { + eprintln!("Error: {}", err.error); + if let Some(usage) = err.usage { + eprintln!("\n{}", usage); + } + return Ok(()); + } + + let result: TransferResponse = serde_json::from_str(&resp_body)?; + println!("Downloaded: {} -> {}", remote_path, local_path); + println!(" Size: {}, Time: {}ms", result.size_formatted, result.duration_ms); + Ok(()) + } } diff --git a/ssh-proxy/src/handler.rs b/ssh-proxy/src/handler.rs index e135325..dd0b77e 100644 --- a/ssh-proxy/src/handler.rs +++ b/ssh-proxy/src/handler.rs @@ -81,6 +81,7 @@ pub async fn exec( let result = manager.exec(&req.server, &req.command).await .map_err(|e| { + logger.log(&LogEntry::new("/exec", "http").with_server(&req.server).with_command(&req.command).with_error(&e.to_string())); let usage = format!( "Usage: POST /exec {{\"server\": \"server_name\", \"command\": \"your command\"}}\n\nAvailable servers: {}", server_names @@ -136,7 +137,10 @@ pub async fn add_server( }; manager.add_server(cfg).await - .map_err(|e| error_response(&e.to_string()))?; + .map_err(|e| { + logger.log(&LogEntry::new("/servers/add", "http").with_server(&req.name).with_error(&e.to_string())); + error_response(&e.to_string()) + })?; // 记录日志 logger.log(&LogEntry::new("/servers/add", "http") @@ -149,6 +153,125 @@ pub async fn add_server( }))) } +#[derive(Debug, Serialize)] +pub struct TransferResponse { + pub success: bool, + pub size: u64, + #[serde(rename = "sizeFormatted")] + pub size_formatted: String, + #[serde(rename = "durationMs")] + pub duration_ms: u64, + pub path: String, +} + +#[derive(Debug, Deserialize)] +pub struct UploadRequest { + pub server: String, + #[serde(rename = "localPath")] + pub local_path: String, + #[serde(rename = "remotePath")] + pub remote_path: String, + #[serde(default = "default_concurrency")] + pub concurrency: usize, +} + +#[derive(Debug, Deserialize)] +pub struct DownloadRequest { + pub server: String, + #[serde(rename = "remotePath")] + pub remote_path: String, + #[serde(rename = "localPath")] + pub local_path: String, + #[serde(default = "default_concurrency")] + pub concurrency: usize, +} + +fn default_concurrency() -> usize { 4 } + +pub async fn upload( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let (manager, logger) = &*state; + let server_names = list_server_names(manager); + + let result = manager.upload(&req.server, &req.local_path, &req.remote_path, req.concurrency).await + .map_err(|e| { + logger.log(&LogEntry::new("/upload", "http").with_server(&req.server).with_error(&e.to_string())); + let usage = format!( + "Usage: POST /upload {{\"server\": \"name\", \"localPath\": \"./file.txt\", \"remotePath\": \"/remote/file.txt\"}}\n\nAvailable servers: {}", + server_names + ); + error_response_with_usage(&e.to_string(), &usage) + })?; + + let duration_ms = result.duration_ms; + let size = result.size; + let size_formatted = format_size(size); + + logger.log(&LogEntry::new("/upload", "http") + .with_server(&req.server) + .with_duration(duration_ms)); + + Ok(Json(TransferResponse { + success: true, + size, + size_formatted, + duration_ms, + path: result.path, + })) +} + +pub async fn download( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let (manager, logger) = &*state; + let server_names = list_server_names(manager); + + let result = manager.download(&req.server, &req.remote_path, &req.local_path, req.concurrency).await + .map_err(|e| { + logger.log(&LogEntry::new("/download", "http").with_server(&req.server).with_error(&e.to_string())); + let usage = format!( + "Usage: POST /download {{\"server\": \"name\", \"remotePath\": \"/remote/file.txt\", \"localPath\": \"./file.txt\"}}\n\nAvailable servers: {}", + server_names + ); + error_response_with_usage(&e.to_string(), &usage) + })?; + + let duration_ms = result.duration_ms; + let size = result.size; + let size_formatted = format_size(size); + + logger.log(&LogEntry::new("/download", "http") + .with_server(&req.server) + .with_duration(duration_ms)); + + Ok(Json(TransferResponse { + success: true, + size, + size_formatted, + duration_ms, + path: result.path, + })) +} + +fn format_size(bytes: u64) -> String { + const KB: u64 = 1024; + const MB: u64 = 1024 * KB; + const GB: u64 = 1024 * MB; + + if bytes >= GB { + format!("{:.2} GB", bytes as f64 / GB as f64) + } else if bytes >= MB { + format!("{:.2} MB", bytes as f64 / MB as f64) + } else if bytes >= KB { + format!("{:.2} KB", bytes as f64 / KB as f64) + } else { + format!("{} B", bytes) + } +} + pub async fn health( State(state): State, ) -> Json { diff --git a/ssh-proxy/src/main.rs b/ssh-proxy/src/main.rs index 024148d..ca4be6c 100644 --- a/ssh-proxy/src/main.rs +++ b/ssh-proxy/src/main.rs @@ -38,6 +38,34 @@ enum Commands { format: String, }, Servers, + /// Upload file to remote server (SFTP) + Upload { + #[arg(short = 'n', long)] + name: String, + /// Local file path (source) + #[arg(short = 's', long)] + src: String, + /// Remote file path (destination) + #[arg(short = 'd', long)] + dst: String, + /// Parallel channels (default: 4) + #[arg(short = 'c', long, default_value = "4")] + concurrency: usize, + }, + /// Download file from remote server (SFTP) + Download { + #[arg(short = 'n', long)] + name: String, + /// Remote file path (source) + #[arg(short = 's', long)] + src: String, + /// Local file path (destination) + #[arg(short = 'd', long)] + dst: String, + /// Parallel channels (default: 4) + #[arg(short = 'c', long, default_value = "4")] + concurrency: usize, + }, /// Add server dynamically (temporary) AddServer { #[arg(short = 'n', long)] @@ -62,6 +90,8 @@ async fn main() -> anyhow::Result<()> { Some(Commands::Server { port, host }) => run_server(&args.config, port, host).await, Some(Commands::Exec { name, command, format }) => run_cli_exec(&args.server, &name, &command, &format), Some(Commands::Servers) => run_cli_servers(&args.server), + Some(Commands::Upload { name, src, dst, concurrency }) => run_cli_transfer(&args.server, "upload", &name, &src, &dst, concurrency), + Some(Commands::Download { name, src, dst, concurrency }) => run_cli_transfer(&args.server, "download", &name, &src, &dst, concurrency), Some(Commands::AddServer { name, host, port, user, password, private_key }) => { run_cli_add_server(&args.server, name, host, port, user, password, private_key) } @@ -100,6 +130,8 @@ async fn run_server(config_path: &str, port: Option, host: Option) .route("/exec", post(handler::exec)) .route("/servers", get(handler::servers)) .route("/servers/add", post(handler::add_server)) + .route("/upload", post(handler::upload)) + .route("/download", post(handler::download)) .route("/health", get(handler::health)) .with_state(Arc::new((manager, logger))); @@ -109,12 +141,15 @@ async fn run_server(config_path: &str, port: Option, host: Option) println!("\nServer started at http://{}", addr); println!("\nAPI Endpoints:"); println!(" POST /exec - Execute remote command"); + println!(" POST /upload - Upload file (SFTP)"); + println!(" POST /download - Download file (SFTP)"); println!(" POST /servers/add - Add server (temporary)"); println!(" GET /servers - List all servers"); println!(" GET /health - Health check"); println!("\nCLI Usage:"); println!(" ssh-proxy exec -n flux_dev -c \"docker ps\""); - println!(" ssh-proxy exec -n flux_dev -c \"docker ps\" -F json"); + println!(" ssh-proxy upload -n flux_dev -s ./local.txt -d /remote/path.txt"); + println!(" ssh-proxy download -n flux_dev -s /remote/path.txt -d ./local.txt"); println!(" ssh-proxy servers"); axum::serve(listener, app).await?; @@ -150,6 +185,21 @@ fn run_cli_servers(server_url: &str) -> anyhow::Result<()> { cli.list_servers() } +fn run_cli_transfer(server_url: &str, action: &str, server: &str, src: &str, dst: &str, concurrency: usize) -> anyhow::Result<()> { + let cli = cli::Cli::new(Some(server_url.to_string())); + if !cli.check_server()? { + eprintln!(); + eprintln!("=== ssh-proxy 未运行 ==="); + eprintln!("启动代理: ssh-proxy"); + anyhow::bail!("Proxy server not running at {}", server_url); + } + match action { + "upload" => cli.upload(server, src, dst, concurrency), + "download" => cli.download(server, src, dst, concurrency), + _ => anyhow::bail!("Unknown action: {}", action), + } +} + fn run_cli_add_server( server_url: &str, name: String, diff --git a/ssh-proxy/src/session.rs b/ssh-proxy/src/session.rs index b1fe0f3..49d0591 100644 --- a/ssh-proxy/src/session.rs +++ b/ssh-proxy/src/session.rs @@ -4,8 +4,11 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use std::path::Path; use tokio::sync::Mutex; +use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use russh::client::{self, Config}; use russh::keys::{load_secret_key, PrivateKeyWithHashAlg, ssh_key}; +use russh_sftp::client::SftpSession; +use russh_sftp::protocol::OpenFlags; use crate::config::SshServerConfig; @@ -177,6 +180,208 @@ impl SessionManager { }) } + /// 上传文件到远程服务器 (并行多通道) + pub async fn upload(&self, name: &str, local_path: &str, remote_path: &str, concurrency: usize) -> Result { + let start = Instant::now(); + let session = self.get_or_create_session(name).await?; + + // 读取本地文件大小 + let metadata = tokio::fs::metadata(local_path).await + .map_err(|e| anyhow::anyhow!("Cannot stat local file '{}': {}", local_path, e))?; + let file_size = metadata.len(); + let concurrency = concurrency.min(file_size as usize).max(1); + let chunk_size = (file_size as usize + concurrency - 1) / concurrency; + + // 先用第一个 channel 创建远程文件 (CREATE | TRUNCATE),获取大小后并行写入 + // 实际上每个 channel 都可以 CREATE | WRITE,SFTP 保证并发安全 + let mut handles = Vec::new(); + + for i in 0..concurrency { + let offset = (i * chunk_size) as u64; + let len = std::cmp::min(chunk_size as u64, file_size - offset) as usize; + if len == 0 { break; } + + let session = session.clone(); + let local = local_path.to_string(); + let remote = remote_path.to_string(); + + handles.push(tokio::spawn(async move { + // 每个任务独立开 SFTP channel + let channel = session.channel_open_session().await + .map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?; + channel.request_subsystem(true, "sftp").await + .map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?; + + let sftp = SftpSession::new(channel.into_stream()).await + .map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?; + + // 打开远程文件 (WRITE | CREATE,不 TRUNCATE,第一个 chunk 创建后其他追加) + let flags = if i == 0 { + OpenFlags::WRITE | OpenFlags::CREATE | OpenFlags::TRUNCATE + } else { + OpenFlags::WRITE | OpenFlags::CREATE + }; + let mut remote_file = sftp.open_with_flags(&remote, flags).await + .map_err(|e| anyhow::anyhow!("Cannot open remote file '{}': {}", remote, e))?; + + remote_file.seek(std::io::SeekFrom::Start(offset)).await + .map_err(|e| anyhow::anyhow!("Seek failed: {}", e))?; + + // 读取本地文件 chunk 并写入 + let mut local_file = tokio::fs::File::open(&local).await?; + local_file.seek(std::io::SeekFrom::Start(offset)).await?; + + let mut buf = vec![0u8; len]; + local_file.read_exact(&mut buf).await?; + + remote_file.write_all(&buf).await + .map_err(|e| anyhow::anyhow!("Write failed: {}", e))?; + + remote_file.shutdown().await.ok(); + Ok::<_, anyhow::Error>(len as u64) + })); + } + + // 等待所有任务完成 + let mut total_written = 0u64; + for handle in handles { + match handle.await { + Ok(Ok(written)) => total_written += written, + Ok(Err(e)) => { + // 上传失败,尝试清理远程文件 + self.remove_remote_file(name, remote_path).await.ok(); + return Err(e); + } + Err(e) => { + self.remove_remote_file(name, remote_path).await.ok(); + return Err(anyhow::anyhow!("Task failed: {}", e)); + } + } + } + + Ok(TransferResult { + size: total_written, + duration_ms: start.elapsed().as_millis() as u64, + path: remote_path.to_string(), + }) + } + + /// 从远程服务器下载文件 (并行多通道) + pub async fn download(&self, name: &str, remote_path: &str, local_path: &str, concurrency: usize) -> Result { + let start = Instant::now(); + let session = self.get_or_create_session(name).await?; + + // 先获取远程文件大小 + let channel = session.channel_open_session().await + .map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?; + channel.request_subsystem(true, "sftp").await + .map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?; + + let sftp = SftpSession::new(channel.into_stream()).await + .map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?; + + let file_size = sftp.metadata(remote_path).await + .map_err(|e| anyhow::anyhow!("Cannot stat remote file '{}': {}", remote_path, e))? + .len(); + + sftp.close().await.ok(); + + let concurrency = concurrency.min(file_size as usize).max(1); + let chunk_size = (file_size as usize + concurrency - 1) / concurrency; + + // 创建本地文件并预分配大小 + let local_file = tokio::fs::File::create(local_path).await + .map_err(|e| anyhow::anyhow!("Cannot create local file '{}': {}", local_path, e))?; + local_file.set_len(file_size).await + .map_err(|e| anyhow::anyhow!("Cannot set file size: {}", e))?; + drop(local_file); // 关闭以便并行写入 + + let mut handles = Vec::new(); + + for i in 0..concurrency { + let offset = (i * chunk_size) as u64; + let len = std::cmp::min(chunk_size as u64, file_size - offset) as usize; + if len == 0 { break; } + + let session = session.clone(); + let remote = remote_path.to_string(); + let local = local_path.to_string(); + + handles.push(tokio::spawn(async move { + let channel = session.channel_open_session().await + .map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?; + channel.request_subsystem(true, "sftp").await + .map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?; + + let sftp = SftpSession::new(channel.into_stream()).await + .map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?; + + let mut remote_file = sftp.open(&remote).await + .map_err(|e| anyhow::anyhow!("Cannot open remote file '{}': {}", remote, e))?; + + remote_file.seek(std::io::SeekFrom::Start(offset)).await + .map_err(|e| anyhow::anyhow!("Seek failed: {}", e))?; + + let mut buf = vec![0u8; len]; + remote_file.read_exact(&mut buf).await + .map_err(|e| anyhow::anyhow!("Read failed: {}", e))?; + + // 写入本地文件对应 offset + let mut local_file = tokio::fs::OpenOptions::new() + .write(true) + .open(&local).await?; + local_file.seek(std::io::SeekFrom::Start(offset)).await?; + local_file.write_all(&buf).await + .map_err(|e| anyhow::anyhow!("Local write failed: {}", e))?; + local_file.flush().await?; + + Ok::<_, anyhow::Error>(len as u64) + })); + } + + // 等待所有任务完成 + let mut total_read = 0u64; + for handle in handles { + match handle.await { + Ok(Ok(read)) => total_read += read, + Ok(Err(e)) => { + // 下载失败,删除不完整的本地文件 + tokio::fs::remove_file(local_path).await.ok(); + return Err(e); + } + Err(e) => { + tokio::fs::remove_file(local_path).await.ok(); + return Err(anyhow::anyhow!("Task failed: {}", e)); + } + } + } + + Ok(TransferResult { + size: total_read, + duration_ms: start.elapsed().as_millis() as u64, + path: local_path.to_string(), + }) + } + + /// 删除远程文件 (传输失败时清理) + async fn remove_remote_file(&self, name: &str, remote_path: &str) -> Result<()> { + let session = self.get_or_create_session(name).await?; + + let channel = session.channel_open_session().await + .map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?; + channel.request_subsystem(true, "sftp").await + .map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?; + + let sftp = SftpSession::new(channel.into_stream()).await + .map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?; + + sftp.remove_file(remote_path).await + .map_err(|e| anyhow::anyhow!("Failed to remove remote file: {}", e))?; + + sftp.close().await.ok(); + Ok(()) + } + pub fn list_servers(&self) -> Vec { // 使用 try_lock 避免阻塞,如果锁不可用则返回 pending let sessions = self.sessions.try_lock(); @@ -268,3 +473,10 @@ pub struct ServerInfo { pub user: String, pub status: String, } + +#[derive(Debug, Clone, serde::Serialize)] +pub struct TransferResult { + pub size: u64, + pub duration_ms: u64, + pub path: String, +}