新增: ssh-proxy 文件传输 (并行 SFTP)

基于 russh-sftp 实现多通道并行文件传输,支持 upload/download,
默认 4 通道并行,可通过 -c 参数动态调整并发数。
传输失败时自动清理不完整的远程/本地文件。
This commit is contained in:
2026-03-30 09:52:32 +08:00
parent 780f683706
commit f77ed2572a
5 changed files with 476 additions and 2 deletions

View File

@@ -16,6 +16,7 @@ anyhow = "1"
clap = { version = "4", features = ["derive"] }
ureq = "2" # CLI HTTP client (2.x has simpler API)
time = { version = "0.3", features = ["formatting", "macros"] }
russh-sftp = "2.0.8"
[profile.release]
opt-level = "z"

View File

@@ -32,6 +32,28 @@ struct ServersResponse {
servers: Vec<ServerInfo>,
}
#[derive(Debug, Serialize)]
struct UploadRequest {
server: String,
#[serde(rename = "localPath")]
local_path: String,
#[serde(rename = "remotePath")]
remote_path: String,
#[serde(rename = "concurrency")]
concurrency: usize,
}
#[derive(Debug, Deserialize)]
struct TransferResponse {
success: bool,
size: u64,
#[serde(rename = "sizeFormatted")]
size_formatted: String,
#[serde(rename = "durationMs")]
duration_ms: u64,
path: String,
}
#[derive(Debug, Deserialize)]
struct ServerInfo {
name: String,
@@ -209,4 +231,70 @@ impl Cli {
Err(_) => Ok(false),
}
}
/// 上传文件到远程服务器
pub fn upload(&self, server: &str, local_path: &str, remote_path: &str, concurrency: usize) -> Result<()> {
let body = serde_json::to_string(&UploadRequest {
server: server.to_string(),
local_path: local_path.to_string(),
remote_path: remote_path.to_string(),
concurrency,
})?;
let response = ureq::post(&format!("{}/upload", self.server))
.set("Content-Type", "application/json")
.send_string(&body);
let resp_body = match response {
Ok(r) => r.into_string()?,
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
Err(e) => bail!("HTTP error: {}", e),
};
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&resp_body) {
eprintln!("Error: {}", err.error);
if let Some(usage) = err.usage {
eprintln!("\n{}", usage);
}
return Ok(());
}
let result: TransferResponse = serde_json::from_str(&resp_body)?;
println!("Uploaded: {} -> {}", local_path, remote_path);
println!(" Size: {}, Time: {}ms", result.size_formatted, result.duration_ms);
Ok(())
}
/// 从远程服务器下载文件
pub fn download(&self, server: &str, remote_path: &str, local_path: &str, concurrency: usize) -> Result<()> {
let body = serde_json::to_string(&serde_json::json!({
"server": server,
"remotePath": remote_path,
"localPath": local_path,
"concurrency": concurrency,
}))?;
let response = ureq::post(&format!("{}/download", self.server))
.set("Content-Type", "application/json")
.send_string(&body);
let resp_body = match response {
Ok(r) => r.into_string()?,
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
Err(e) => bail!("HTTP error: {}", e),
};
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&resp_body) {
eprintln!("Error: {}", err.error);
if let Some(usage) = err.usage {
eprintln!("\n{}", usage);
}
return Ok(());
}
let result: TransferResponse = serde_json::from_str(&resp_body)?;
println!("Downloaded: {} -> {}", remote_path, local_path);
println!(" Size: {}, Time: {}ms", result.size_formatted, result.duration_ms);
Ok(())
}
}

View File

@@ -81,6 +81,7 @@ pub async fn exec(
let result = manager.exec(&req.server, &req.command).await
.map_err(|e| {
logger.log(&LogEntry::new("/exec", "http").with_server(&req.server).with_command(&req.command).with_error(&e.to_string()));
let usage = format!(
"Usage: POST /exec {{\"server\": \"server_name\", \"command\": \"your command\"}}\n\nAvailable servers: {}",
server_names
@@ -136,7 +137,10 @@ pub async fn add_server(
};
manager.add_server(cfg).await
.map_err(|e| error_response(&e.to_string()))?;
.map_err(|e| {
logger.log(&LogEntry::new("/servers/add", "http").with_server(&req.name).with_error(&e.to_string()));
error_response(&e.to_string())
})?;
// 记录日志
logger.log(&LogEntry::new("/servers/add", "http")
@@ -149,6 +153,125 @@ pub async fn add_server(
})))
}
#[derive(Debug, Serialize)]
pub struct TransferResponse {
pub success: bool,
pub size: u64,
#[serde(rename = "sizeFormatted")]
pub size_formatted: String,
#[serde(rename = "durationMs")]
pub duration_ms: u64,
pub path: String,
}
#[derive(Debug, Deserialize)]
pub struct UploadRequest {
pub server: String,
#[serde(rename = "localPath")]
pub local_path: String,
#[serde(rename = "remotePath")]
pub remote_path: String,
#[serde(default = "default_concurrency")]
pub concurrency: usize,
}
#[derive(Debug, Deserialize)]
pub struct DownloadRequest {
pub server: String,
#[serde(rename = "remotePath")]
pub remote_path: String,
#[serde(rename = "localPath")]
pub local_path: String,
#[serde(default = "default_concurrency")]
pub concurrency: usize,
}
fn default_concurrency() -> usize { 4 }
pub async fn upload(
State(state): State<AppState>,
Json(req): Json<UploadRequest>,
) -> Result<Json<TransferResponse>, (StatusCode, Json<ErrorResponse>)> {
let (manager, logger) = &*state;
let server_names = list_server_names(manager);
let result = manager.upload(&req.server, &req.local_path, &req.remote_path, req.concurrency).await
.map_err(|e| {
logger.log(&LogEntry::new("/upload", "http").with_server(&req.server).with_error(&e.to_string()));
let usage = format!(
"Usage: POST /upload {{\"server\": \"name\", \"localPath\": \"./file.txt\", \"remotePath\": \"/remote/file.txt\"}}\n\nAvailable servers: {}",
server_names
);
error_response_with_usage(&e.to_string(), &usage)
})?;
let duration_ms = result.duration_ms;
let size = result.size;
let size_formatted = format_size(size);
logger.log(&LogEntry::new("/upload", "http")
.with_server(&req.server)
.with_duration(duration_ms));
Ok(Json(TransferResponse {
success: true,
size,
size_formatted,
duration_ms,
path: result.path,
}))
}
pub async fn download(
State(state): State<AppState>,
Json(req): Json<DownloadRequest>,
) -> Result<Json<TransferResponse>, (StatusCode, Json<ErrorResponse>)> {
let (manager, logger) = &*state;
let server_names = list_server_names(manager);
let result = manager.download(&req.server, &req.remote_path, &req.local_path, req.concurrency).await
.map_err(|e| {
logger.log(&LogEntry::new("/download", "http").with_server(&req.server).with_error(&e.to_string()));
let usage = format!(
"Usage: POST /download {{\"server\": \"name\", \"remotePath\": \"/remote/file.txt\", \"localPath\": \"./file.txt\"}}\n\nAvailable servers: {}",
server_names
);
error_response_with_usage(&e.to_string(), &usage)
})?;
let duration_ms = result.duration_ms;
let size = result.size;
let size_formatted = format_size(size);
logger.log(&LogEntry::new("/download", "http")
.with_server(&req.server)
.with_duration(duration_ms));
Ok(Json(TransferResponse {
success: true,
size,
size_formatted,
duration_ms,
path: result.path,
}))
}
fn format_size(bytes: u64) -> String {
const KB: u64 = 1024;
const MB: u64 = 1024 * KB;
const GB: u64 = 1024 * MB;
if bytes >= GB {
format!("{:.2} GB", bytes as f64 / GB as f64)
} else if bytes >= MB {
format!("{:.2} MB", bytes as f64 / MB as f64)
} else if bytes >= KB {
format!("{:.2} KB", bytes as f64 / KB as f64)
} else {
format!("{} B", bytes)
}
}
pub async fn health(
State(state): State<AppState>,
) -> Json<HealthResponse> {

View File

@@ -38,6 +38,34 @@ enum Commands {
format: String,
},
Servers,
/// Upload file to remote server (SFTP)
Upload {
#[arg(short = 'n', long)]
name: String,
/// Local file path (source)
#[arg(short = 's', long)]
src: String,
/// Remote file path (destination)
#[arg(short = 'd', long)]
dst: String,
/// Parallel channels (default: 4)
#[arg(short = 'c', long, default_value = "4")]
concurrency: usize,
},
/// Download file from remote server (SFTP)
Download {
#[arg(short = 'n', long)]
name: String,
/// Remote file path (source)
#[arg(short = 's', long)]
src: String,
/// Local file path (destination)
#[arg(short = 'd', long)]
dst: String,
/// Parallel channels (default: 4)
#[arg(short = 'c', long, default_value = "4")]
concurrency: usize,
},
/// Add server dynamically (temporary)
AddServer {
#[arg(short = 'n', long)]
@@ -62,6 +90,8 @@ async fn main() -> anyhow::Result<()> {
Some(Commands::Server { port, host }) => run_server(&args.config, port, host).await,
Some(Commands::Exec { name, command, format }) => run_cli_exec(&args.server, &name, &command, &format),
Some(Commands::Servers) => run_cli_servers(&args.server),
Some(Commands::Upload { name, src, dst, concurrency }) => run_cli_transfer(&args.server, "upload", &name, &src, &dst, concurrency),
Some(Commands::Download { name, src, dst, concurrency }) => run_cli_transfer(&args.server, "download", &name, &src, &dst, concurrency),
Some(Commands::AddServer { name, host, port, user, password, private_key }) => {
run_cli_add_server(&args.server, name, host, port, user, password, private_key)
}
@@ -100,6 +130,8 @@ async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>)
.route("/exec", post(handler::exec))
.route("/servers", get(handler::servers))
.route("/servers/add", post(handler::add_server))
.route("/upload", post(handler::upload))
.route("/download", post(handler::download))
.route("/health", get(handler::health))
.with_state(Arc::new((manager, logger)));
@@ -109,12 +141,15 @@ async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>)
println!("\nServer started at http://{}", addr);
println!("\nAPI Endpoints:");
println!(" POST /exec - Execute remote command");
println!(" POST /upload - Upload file (SFTP)");
println!(" POST /download - Download file (SFTP)");
println!(" POST /servers/add - Add server (temporary)");
println!(" GET /servers - List all servers");
println!(" GET /health - Health check");
println!("\nCLI Usage:");
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\"");
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\" -F json");
println!(" ssh-proxy upload -n flux_dev -s ./local.txt -d /remote/path.txt");
println!(" ssh-proxy download -n flux_dev -s /remote/path.txt -d ./local.txt");
println!(" ssh-proxy servers");
axum::serve(listener, app).await?;
@@ -150,6 +185,21 @@ fn run_cli_servers(server_url: &str) -> anyhow::Result<()> {
cli.list_servers()
}
fn run_cli_transfer(server_url: &str, action: &str, server: &str, src: &str, dst: &str, concurrency: usize) -> anyhow::Result<()> {
let cli = cli::Cli::new(Some(server_url.to_string()));
if !cli.check_server()? {
eprintln!();
eprintln!("=== ssh-proxy 未运行 ===");
eprintln!("启动代理: ssh-proxy");
anyhow::bail!("Proxy server not running at {}", server_url);
}
match action {
"upload" => cli.upload(server, src, dst, concurrency),
"download" => cli.download(server, src, dst, concurrency),
_ => anyhow::bail!("Unknown action: {}", action),
}
}
fn run_cli_add_server(
server_url: &str,
name: String,

View File

@@ -4,8 +4,11 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use std::path::Path;
use tokio::sync::Mutex;
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
use russh::client::{self, Config};
use russh::keys::{load_secret_key, PrivateKeyWithHashAlg, ssh_key};
use russh_sftp::client::SftpSession;
use russh_sftp::protocol::OpenFlags;
use crate::config::SshServerConfig;
@@ -177,6 +180,208 @@ impl SessionManager {
})
}
/// 上传文件到远程服务器 (并行多通道)
pub async fn upload(&self, name: &str, local_path: &str, remote_path: &str, concurrency: usize) -> Result<TransferResult> {
let start = Instant::now();
let session = self.get_or_create_session(name).await?;
// 读取本地文件大小
let metadata = tokio::fs::metadata(local_path).await
.map_err(|e| anyhow::anyhow!("Cannot stat local file '{}': {}", local_path, e))?;
let file_size = metadata.len();
let concurrency = concurrency.min(file_size as usize).max(1);
let chunk_size = (file_size as usize + concurrency - 1) / concurrency;
// 先用第一个 channel 创建远程文件 (CREATE | TRUNCATE),获取大小后并行写入
// 实际上每个 channel 都可以 CREATE | WRITESFTP 保证并发安全
let mut handles = Vec::new();
for i in 0..concurrency {
let offset = (i * chunk_size) as u64;
let len = std::cmp::min(chunk_size as u64, file_size - offset) as usize;
if len == 0 { break; }
let session = session.clone();
let local = local_path.to_string();
let remote = remote_path.to_string();
handles.push(tokio::spawn(async move {
// 每个任务独立开 SFTP channel
let channel = session.channel_open_session().await
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
channel.request_subsystem(true, "sftp").await
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
let sftp = SftpSession::new(channel.into_stream()).await
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
// 打开远程文件 (WRITE | CREATE不 TRUNCATE第一个 chunk 创建后其他追加)
let flags = if i == 0 {
OpenFlags::WRITE | OpenFlags::CREATE | OpenFlags::TRUNCATE
} else {
OpenFlags::WRITE | OpenFlags::CREATE
};
let mut remote_file = sftp.open_with_flags(&remote, flags).await
.map_err(|e| anyhow::anyhow!("Cannot open remote file '{}': {}", remote, e))?;
remote_file.seek(std::io::SeekFrom::Start(offset)).await
.map_err(|e| anyhow::anyhow!("Seek failed: {}", e))?;
// 读取本地文件 chunk 并写入
let mut local_file = tokio::fs::File::open(&local).await?;
local_file.seek(std::io::SeekFrom::Start(offset)).await?;
let mut buf = vec![0u8; len];
local_file.read_exact(&mut buf).await?;
remote_file.write_all(&buf).await
.map_err(|e| anyhow::anyhow!("Write failed: {}", e))?;
remote_file.shutdown().await.ok();
Ok::<_, anyhow::Error>(len as u64)
}));
}
// 等待所有任务完成
let mut total_written = 0u64;
for handle in handles {
match handle.await {
Ok(Ok(written)) => total_written += written,
Ok(Err(e)) => {
// 上传失败,尝试清理远程文件
self.remove_remote_file(name, remote_path).await.ok();
return Err(e);
}
Err(e) => {
self.remove_remote_file(name, remote_path).await.ok();
return Err(anyhow::anyhow!("Task failed: {}", e));
}
}
}
Ok(TransferResult {
size: total_written,
duration_ms: start.elapsed().as_millis() as u64,
path: remote_path.to_string(),
})
}
/// 从远程服务器下载文件 (并行多通道)
pub async fn download(&self, name: &str, remote_path: &str, local_path: &str, concurrency: usize) -> Result<TransferResult> {
let start = Instant::now();
let session = self.get_or_create_session(name).await?;
// 先获取远程文件大小
let channel = session.channel_open_session().await
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
channel.request_subsystem(true, "sftp").await
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
let sftp = SftpSession::new(channel.into_stream()).await
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
let file_size = sftp.metadata(remote_path).await
.map_err(|e| anyhow::anyhow!("Cannot stat remote file '{}': {}", remote_path, e))?
.len();
sftp.close().await.ok();
let concurrency = concurrency.min(file_size as usize).max(1);
let chunk_size = (file_size as usize + concurrency - 1) / concurrency;
// 创建本地文件并预分配大小
let local_file = tokio::fs::File::create(local_path).await
.map_err(|e| anyhow::anyhow!("Cannot create local file '{}': {}", local_path, e))?;
local_file.set_len(file_size).await
.map_err(|e| anyhow::anyhow!("Cannot set file size: {}", e))?;
drop(local_file); // 关闭以便并行写入
let mut handles = Vec::new();
for i in 0..concurrency {
let offset = (i * chunk_size) as u64;
let len = std::cmp::min(chunk_size as u64, file_size - offset) as usize;
if len == 0 { break; }
let session = session.clone();
let remote = remote_path.to_string();
let local = local_path.to_string();
handles.push(tokio::spawn(async move {
let channel = session.channel_open_session().await
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
channel.request_subsystem(true, "sftp").await
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
let sftp = SftpSession::new(channel.into_stream()).await
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
let mut remote_file = sftp.open(&remote).await
.map_err(|e| anyhow::anyhow!("Cannot open remote file '{}': {}", remote, e))?;
remote_file.seek(std::io::SeekFrom::Start(offset)).await
.map_err(|e| anyhow::anyhow!("Seek failed: {}", e))?;
let mut buf = vec![0u8; len];
remote_file.read_exact(&mut buf).await
.map_err(|e| anyhow::anyhow!("Read failed: {}", e))?;
// 写入本地文件对应 offset
let mut local_file = tokio::fs::OpenOptions::new()
.write(true)
.open(&local).await?;
local_file.seek(std::io::SeekFrom::Start(offset)).await?;
local_file.write_all(&buf).await
.map_err(|e| anyhow::anyhow!("Local write failed: {}", e))?;
local_file.flush().await?;
Ok::<_, anyhow::Error>(len as u64)
}));
}
// 等待所有任务完成
let mut total_read = 0u64;
for handle in handles {
match handle.await {
Ok(Ok(read)) => total_read += read,
Ok(Err(e)) => {
// 下载失败,删除不完整的本地文件
tokio::fs::remove_file(local_path).await.ok();
return Err(e);
}
Err(e) => {
tokio::fs::remove_file(local_path).await.ok();
return Err(anyhow::anyhow!("Task failed: {}", e));
}
}
}
Ok(TransferResult {
size: total_read,
duration_ms: start.elapsed().as_millis() as u64,
path: local_path.to_string(),
})
}
/// 删除远程文件 (传输失败时清理)
async fn remove_remote_file(&self, name: &str, remote_path: &str) -> Result<()> {
let session = self.get_or_create_session(name).await?;
let channel = session.channel_open_session().await
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
channel.request_subsystem(true, "sftp").await
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
let sftp = SftpSession::new(channel.into_stream()).await
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
sftp.remove_file(remote_path).await
.map_err(|e| anyhow::anyhow!("Failed to remove remote file: {}", e))?;
sftp.close().await.ok();
Ok(())
}
pub fn list_servers(&self) -> Vec<ServerInfo> {
// 使用 try_lock 避免阻塞,如果锁不可用则返回 pending
let sessions = self.sessions.try_lock();
@@ -268,3 +473,10 @@ pub struct ServerInfo {
pub user: String,
pub status: String,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct TransferResult {
pub size: u64,
pub duration_ms: u64,
pub path: String,
}