新增: ssh-proxy 文件传输 (并行 SFTP)
基于 russh-sftp 实现多通道并行文件传输,支持 upload/download, 默认 4 通道并行,可通过 -c 参数动态调整并发数。 传输失败时自动清理不完整的远程/本地文件。
This commit is contained in:
@@ -16,6 +16,7 @@ anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
ureq = "2" # CLI HTTP client (2.x has simpler API)
|
||||
time = { version = "0.3", features = ["formatting", "macros"] }
|
||||
russh-sftp = "2.0.8"
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z"
|
||||
|
||||
@@ -32,6 +32,28 @@ struct ServersResponse {
|
||||
servers: Vec<ServerInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct UploadRequest {
|
||||
server: String,
|
||||
#[serde(rename = "localPath")]
|
||||
local_path: String,
|
||||
#[serde(rename = "remotePath")]
|
||||
remote_path: String,
|
||||
#[serde(rename = "concurrency")]
|
||||
concurrency: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TransferResponse {
|
||||
success: bool,
|
||||
size: u64,
|
||||
#[serde(rename = "sizeFormatted")]
|
||||
size_formatted: String,
|
||||
#[serde(rename = "durationMs")]
|
||||
duration_ms: u64,
|
||||
path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ServerInfo {
|
||||
name: String,
|
||||
@@ -209,4 +231,70 @@ impl Cli {
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// 上传文件到远程服务器
|
||||
pub fn upload(&self, server: &str, local_path: &str, remote_path: &str, concurrency: usize) -> Result<()> {
|
||||
let body = serde_json::to_string(&UploadRequest {
|
||||
server: server.to_string(),
|
||||
local_path: local_path.to_string(),
|
||||
remote_path: remote_path.to_string(),
|
||||
concurrency,
|
||||
})?;
|
||||
|
||||
let response = ureq::post(&format!("{}/upload", self.server))
|
||||
.set("Content-Type", "application/json")
|
||||
.send_string(&body);
|
||||
|
||||
let resp_body = match response {
|
||||
Ok(r) => r.into_string()?,
|
||||
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
|
||||
Err(e) => bail!("HTTP error: {}", e),
|
||||
};
|
||||
|
||||
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&resp_body) {
|
||||
eprintln!("Error: {}", err.error);
|
||||
if let Some(usage) = err.usage {
|
||||
eprintln!("\n{}", usage);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let result: TransferResponse = serde_json::from_str(&resp_body)?;
|
||||
println!("Uploaded: {} -> {}", local_path, remote_path);
|
||||
println!(" Size: {}, Time: {}ms", result.size_formatted, result.duration_ms);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 从远程服务器下载文件
|
||||
pub fn download(&self, server: &str, remote_path: &str, local_path: &str, concurrency: usize) -> Result<()> {
|
||||
let body = serde_json::to_string(&serde_json::json!({
|
||||
"server": server,
|
||||
"remotePath": remote_path,
|
||||
"localPath": local_path,
|
||||
"concurrency": concurrency,
|
||||
}))?;
|
||||
|
||||
let response = ureq::post(&format!("{}/download", self.server))
|
||||
.set("Content-Type", "application/json")
|
||||
.send_string(&body);
|
||||
|
||||
let resp_body = match response {
|
||||
Ok(r) => r.into_string()?,
|
||||
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
|
||||
Err(e) => bail!("HTTP error: {}", e),
|
||||
};
|
||||
|
||||
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&resp_body) {
|
||||
eprintln!("Error: {}", err.error);
|
||||
if let Some(usage) = err.usage {
|
||||
eprintln!("\n{}", usage);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let result: TransferResponse = serde_json::from_str(&resp_body)?;
|
||||
println!("Downloaded: {} -> {}", remote_path, local_path);
|
||||
println!(" Size: {}, Time: {}ms", result.size_formatted, result.duration_ms);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ pub async fn exec(
|
||||
|
||||
let result = manager.exec(&req.server, &req.command).await
|
||||
.map_err(|e| {
|
||||
logger.log(&LogEntry::new("/exec", "http").with_server(&req.server).with_command(&req.command).with_error(&e.to_string()));
|
||||
let usage = format!(
|
||||
"Usage: POST /exec {{\"server\": \"server_name\", \"command\": \"your command\"}}\n\nAvailable servers: {}",
|
||||
server_names
|
||||
@@ -136,7 +137,10 @@ pub async fn add_server(
|
||||
};
|
||||
|
||||
manager.add_server(cfg).await
|
||||
.map_err(|e| error_response(&e.to_string()))?;
|
||||
.map_err(|e| {
|
||||
logger.log(&LogEntry::new("/servers/add", "http").with_server(&req.name).with_error(&e.to_string()));
|
||||
error_response(&e.to_string())
|
||||
})?;
|
||||
|
||||
// 记录日志
|
||||
logger.log(&LogEntry::new("/servers/add", "http")
|
||||
@@ -149,6 +153,125 @@ pub async fn add_server(
|
||||
})))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct TransferResponse {
|
||||
pub success: bool,
|
||||
pub size: u64,
|
||||
#[serde(rename = "sizeFormatted")]
|
||||
pub size_formatted: String,
|
||||
#[serde(rename = "durationMs")]
|
||||
pub duration_ms: u64,
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UploadRequest {
|
||||
pub server: String,
|
||||
#[serde(rename = "localPath")]
|
||||
pub local_path: String,
|
||||
#[serde(rename = "remotePath")]
|
||||
pub remote_path: String,
|
||||
#[serde(default = "default_concurrency")]
|
||||
pub concurrency: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DownloadRequest {
|
||||
pub server: String,
|
||||
#[serde(rename = "remotePath")]
|
||||
pub remote_path: String,
|
||||
#[serde(rename = "localPath")]
|
||||
pub local_path: String,
|
||||
#[serde(default = "default_concurrency")]
|
||||
pub concurrency: usize,
|
||||
}
|
||||
|
||||
fn default_concurrency() -> usize { 4 }
|
||||
|
||||
pub async fn upload(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<UploadRequest>,
|
||||
) -> Result<Json<TransferResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let (manager, logger) = &*state;
|
||||
let server_names = list_server_names(manager);
|
||||
|
||||
let result = manager.upload(&req.server, &req.local_path, &req.remote_path, req.concurrency).await
|
||||
.map_err(|e| {
|
||||
logger.log(&LogEntry::new("/upload", "http").with_server(&req.server).with_error(&e.to_string()));
|
||||
let usage = format!(
|
||||
"Usage: POST /upload {{\"server\": \"name\", \"localPath\": \"./file.txt\", \"remotePath\": \"/remote/file.txt\"}}\n\nAvailable servers: {}",
|
||||
server_names
|
||||
);
|
||||
error_response_with_usage(&e.to_string(), &usage)
|
||||
})?;
|
||||
|
||||
let duration_ms = result.duration_ms;
|
||||
let size = result.size;
|
||||
let size_formatted = format_size(size);
|
||||
|
||||
logger.log(&LogEntry::new("/upload", "http")
|
||||
.with_server(&req.server)
|
||||
.with_duration(duration_ms));
|
||||
|
||||
Ok(Json(TransferResponse {
|
||||
success: true,
|
||||
size,
|
||||
size_formatted,
|
||||
duration_ms,
|
||||
path: result.path,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn download(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<DownloadRequest>,
|
||||
) -> Result<Json<TransferResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let (manager, logger) = &*state;
|
||||
let server_names = list_server_names(manager);
|
||||
|
||||
let result = manager.download(&req.server, &req.remote_path, &req.local_path, req.concurrency).await
|
||||
.map_err(|e| {
|
||||
logger.log(&LogEntry::new("/download", "http").with_server(&req.server).with_error(&e.to_string()));
|
||||
let usage = format!(
|
||||
"Usage: POST /download {{\"server\": \"name\", \"remotePath\": \"/remote/file.txt\", \"localPath\": \"./file.txt\"}}\n\nAvailable servers: {}",
|
||||
server_names
|
||||
);
|
||||
error_response_with_usage(&e.to_string(), &usage)
|
||||
})?;
|
||||
|
||||
let duration_ms = result.duration_ms;
|
||||
let size = result.size;
|
||||
let size_formatted = format_size(size);
|
||||
|
||||
logger.log(&LogEntry::new("/download", "http")
|
||||
.with_server(&req.server)
|
||||
.with_duration(duration_ms));
|
||||
|
||||
Ok(Json(TransferResponse {
|
||||
success: true,
|
||||
size,
|
||||
size_formatted,
|
||||
duration_ms,
|
||||
path: result.path,
|
||||
}))
|
||||
}
|
||||
|
||||
fn format_size(bytes: u64) -> String {
|
||||
const KB: u64 = 1024;
|
||||
const MB: u64 = 1024 * KB;
|
||||
const GB: u64 = 1024 * MB;
|
||||
|
||||
if bytes >= GB {
|
||||
format!("{:.2} GB", bytes as f64 / GB as f64)
|
||||
} else if bytes >= MB {
|
||||
format!("{:.2} MB", bytes as f64 / MB as f64)
|
||||
} else if bytes >= KB {
|
||||
format!("{:.2} KB", bytes as f64 / KB as f64)
|
||||
} else {
|
||||
format!("{} B", bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn health(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<HealthResponse> {
|
||||
|
||||
@@ -38,6 +38,34 @@ enum Commands {
|
||||
format: String,
|
||||
},
|
||||
Servers,
|
||||
/// Upload file to remote server (SFTP)
|
||||
Upload {
|
||||
#[arg(short = 'n', long)]
|
||||
name: String,
|
||||
/// Local file path (source)
|
||||
#[arg(short = 's', long)]
|
||||
src: String,
|
||||
/// Remote file path (destination)
|
||||
#[arg(short = 'd', long)]
|
||||
dst: String,
|
||||
/// Parallel channels (default: 4)
|
||||
#[arg(short = 'c', long, default_value = "4")]
|
||||
concurrency: usize,
|
||||
},
|
||||
/// Download file from remote server (SFTP)
|
||||
Download {
|
||||
#[arg(short = 'n', long)]
|
||||
name: String,
|
||||
/// Remote file path (source)
|
||||
#[arg(short = 's', long)]
|
||||
src: String,
|
||||
/// Local file path (destination)
|
||||
#[arg(short = 'd', long)]
|
||||
dst: String,
|
||||
/// Parallel channels (default: 4)
|
||||
#[arg(short = 'c', long, default_value = "4")]
|
||||
concurrency: usize,
|
||||
},
|
||||
/// Add server dynamically (temporary)
|
||||
AddServer {
|
||||
#[arg(short = 'n', long)]
|
||||
@@ -62,6 +90,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
Some(Commands::Server { port, host }) => run_server(&args.config, port, host).await,
|
||||
Some(Commands::Exec { name, command, format }) => run_cli_exec(&args.server, &name, &command, &format),
|
||||
Some(Commands::Servers) => run_cli_servers(&args.server),
|
||||
Some(Commands::Upload { name, src, dst, concurrency }) => run_cli_transfer(&args.server, "upload", &name, &src, &dst, concurrency),
|
||||
Some(Commands::Download { name, src, dst, concurrency }) => run_cli_transfer(&args.server, "download", &name, &src, &dst, concurrency),
|
||||
Some(Commands::AddServer { name, host, port, user, password, private_key }) => {
|
||||
run_cli_add_server(&args.server, name, host, port, user, password, private_key)
|
||||
}
|
||||
@@ -100,6 +130,8 @@ async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>)
|
||||
.route("/exec", post(handler::exec))
|
||||
.route("/servers", get(handler::servers))
|
||||
.route("/servers/add", post(handler::add_server))
|
||||
.route("/upload", post(handler::upload))
|
||||
.route("/download", post(handler::download))
|
||||
.route("/health", get(handler::health))
|
||||
.with_state(Arc::new((manager, logger)));
|
||||
|
||||
@@ -109,12 +141,15 @@ async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>)
|
||||
println!("\nServer started at http://{}", addr);
|
||||
println!("\nAPI Endpoints:");
|
||||
println!(" POST /exec - Execute remote command");
|
||||
println!(" POST /upload - Upload file (SFTP)");
|
||||
println!(" POST /download - Download file (SFTP)");
|
||||
println!(" POST /servers/add - Add server (temporary)");
|
||||
println!(" GET /servers - List all servers");
|
||||
println!(" GET /health - Health check");
|
||||
println!("\nCLI Usage:");
|
||||
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\"");
|
||||
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\" -F json");
|
||||
println!(" ssh-proxy upload -n flux_dev -s ./local.txt -d /remote/path.txt");
|
||||
println!(" ssh-proxy download -n flux_dev -s /remote/path.txt -d ./local.txt");
|
||||
println!(" ssh-proxy servers");
|
||||
|
||||
axum::serve(listener, app).await?;
|
||||
@@ -150,6 +185,21 @@ fn run_cli_servers(server_url: &str) -> anyhow::Result<()> {
|
||||
cli.list_servers()
|
||||
}
|
||||
|
||||
fn run_cli_transfer(server_url: &str, action: &str, server: &str, src: &str, dst: &str, concurrency: usize) -> anyhow::Result<()> {
|
||||
let cli = cli::Cli::new(Some(server_url.to_string()));
|
||||
if !cli.check_server()? {
|
||||
eprintln!();
|
||||
eprintln!("=== ssh-proxy 未运行 ===");
|
||||
eprintln!("启动代理: ssh-proxy");
|
||||
anyhow::bail!("Proxy server not running at {}", server_url);
|
||||
}
|
||||
match action {
|
||||
"upload" => cli.upload(server, src, dst, concurrency),
|
||||
"download" => cli.download(server, src, dst, concurrency),
|
||||
_ => anyhow::bail!("Unknown action: {}", action),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_cli_add_server(
|
||||
server_url: &str,
|
||||
name: String,
|
||||
|
||||
@@ -4,8 +4,11 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::path::Path;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
|
||||
use russh::client::{self, Config};
|
||||
use russh::keys::{load_secret_key, PrivateKeyWithHashAlg, ssh_key};
|
||||
use russh_sftp::client::SftpSession;
|
||||
use russh_sftp::protocol::OpenFlags;
|
||||
|
||||
use crate::config::SshServerConfig;
|
||||
|
||||
@@ -177,6 +180,208 @@ impl SessionManager {
|
||||
})
|
||||
}
|
||||
|
||||
/// 上传文件到远程服务器 (并行多通道)
|
||||
pub async fn upload(&self, name: &str, local_path: &str, remote_path: &str, concurrency: usize) -> Result<TransferResult> {
|
||||
let start = Instant::now();
|
||||
let session = self.get_or_create_session(name).await?;
|
||||
|
||||
// 读取本地文件大小
|
||||
let metadata = tokio::fs::metadata(local_path).await
|
||||
.map_err(|e| anyhow::anyhow!("Cannot stat local file '{}': {}", local_path, e))?;
|
||||
let file_size = metadata.len();
|
||||
let concurrency = concurrency.min(file_size as usize).max(1);
|
||||
let chunk_size = (file_size as usize + concurrency - 1) / concurrency;
|
||||
|
||||
// 先用第一个 channel 创建远程文件 (CREATE | TRUNCATE),获取大小后并行写入
|
||||
// 实际上每个 channel 都可以 CREATE | WRITE,SFTP 保证并发安全
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for i in 0..concurrency {
|
||||
let offset = (i * chunk_size) as u64;
|
||||
let len = std::cmp::min(chunk_size as u64, file_size - offset) as usize;
|
||||
if len == 0 { break; }
|
||||
|
||||
let session = session.clone();
|
||||
let local = local_path.to_string();
|
||||
let remote = remote_path.to_string();
|
||||
|
||||
handles.push(tokio::spawn(async move {
|
||||
// 每个任务独立开 SFTP channel
|
||||
let channel = session.channel_open_session().await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
|
||||
channel.request_subsystem(true, "sftp").await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
|
||||
|
||||
let sftp = SftpSession::new(channel.into_stream()).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
|
||||
|
||||
// 打开远程文件 (WRITE | CREATE,不 TRUNCATE,第一个 chunk 创建后其他追加)
|
||||
let flags = if i == 0 {
|
||||
OpenFlags::WRITE | OpenFlags::CREATE | OpenFlags::TRUNCATE
|
||||
} else {
|
||||
OpenFlags::WRITE | OpenFlags::CREATE
|
||||
};
|
||||
let mut remote_file = sftp.open_with_flags(&remote, flags).await
|
||||
.map_err(|e| anyhow::anyhow!("Cannot open remote file '{}': {}", remote, e))?;
|
||||
|
||||
remote_file.seek(std::io::SeekFrom::Start(offset)).await
|
||||
.map_err(|e| anyhow::anyhow!("Seek failed: {}", e))?;
|
||||
|
||||
// 读取本地文件 chunk 并写入
|
||||
let mut local_file = tokio::fs::File::open(&local).await?;
|
||||
local_file.seek(std::io::SeekFrom::Start(offset)).await?;
|
||||
|
||||
let mut buf = vec![0u8; len];
|
||||
local_file.read_exact(&mut buf).await?;
|
||||
|
||||
remote_file.write_all(&buf).await
|
||||
.map_err(|e| anyhow::anyhow!("Write failed: {}", e))?;
|
||||
|
||||
remote_file.shutdown().await.ok();
|
||||
Ok::<_, anyhow::Error>(len as u64)
|
||||
}));
|
||||
}
|
||||
|
||||
// 等待所有任务完成
|
||||
let mut total_written = 0u64;
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(Ok(written)) => total_written += written,
|
||||
Ok(Err(e)) => {
|
||||
// 上传失败,尝试清理远程文件
|
||||
self.remove_remote_file(name, remote_path).await.ok();
|
||||
return Err(e);
|
||||
}
|
||||
Err(e) => {
|
||||
self.remove_remote_file(name, remote_path).await.ok();
|
||||
return Err(anyhow::anyhow!("Task failed: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TransferResult {
|
||||
size: total_written,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
path: remote_path.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 从远程服务器下载文件 (并行多通道)
|
||||
pub async fn download(&self, name: &str, remote_path: &str, local_path: &str, concurrency: usize) -> Result<TransferResult> {
|
||||
let start = Instant::now();
|
||||
let session = self.get_or_create_session(name).await?;
|
||||
|
||||
// 先获取远程文件大小
|
||||
let channel = session.channel_open_session().await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
|
||||
channel.request_subsystem(true, "sftp").await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
|
||||
|
||||
let sftp = SftpSession::new(channel.into_stream()).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
|
||||
|
||||
let file_size = sftp.metadata(remote_path).await
|
||||
.map_err(|e| anyhow::anyhow!("Cannot stat remote file '{}': {}", remote_path, e))?
|
||||
.len();
|
||||
|
||||
sftp.close().await.ok();
|
||||
|
||||
let concurrency = concurrency.min(file_size as usize).max(1);
|
||||
let chunk_size = (file_size as usize + concurrency - 1) / concurrency;
|
||||
|
||||
// 创建本地文件并预分配大小
|
||||
let local_file = tokio::fs::File::create(local_path).await
|
||||
.map_err(|e| anyhow::anyhow!("Cannot create local file '{}': {}", local_path, e))?;
|
||||
local_file.set_len(file_size).await
|
||||
.map_err(|e| anyhow::anyhow!("Cannot set file size: {}", e))?;
|
||||
drop(local_file); // 关闭以便并行写入
|
||||
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for i in 0..concurrency {
|
||||
let offset = (i * chunk_size) as u64;
|
||||
let len = std::cmp::min(chunk_size as u64, file_size - offset) as usize;
|
||||
if len == 0 { break; }
|
||||
|
||||
let session = session.clone();
|
||||
let remote = remote_path.to_string();
|
||||
let local = local_path.to_string();
|
||||
|
||||
handles.push(tokio::spawn(async move {
|
||||
let channel = session.channel_open_session().await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
|
||||
channel.request_subsystem(true, "sftp").await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
|
||||
|
||||
let sftp = SftpSession::new(channel.into_stream()).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
|
||||
|
||||
let mut remote_file = sftp.open(&remote).await
|
||||
.map_err(|e| anyhow::anyhow!("Cannot open remote file '{}': {}", remote, e))?;
|
||||
|
||||
remote_file.seek(std::io::SeekFrom::Start(offset)).await
|
||||
.map_err(|e| anyhow::anyhow!("Seek failed: {}", e))?;
|
||||
|
||||
let mut buf = vec![0u8; len];
|
||||
remote_file.read_exact(&mut buf).await
|
||||
.map_err(|e| anyhow::anyhow!("Read failed: {}", e))?;
|
||||
|
||||
// 写入本地文件对应 offset
|
||||
let mut local_file = tokio::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.open(&local).await?;
|
||||
local_file.seek(std::io::SeekFrom::Start(offset)).await?;
|
||||
local_file.write_all(&buf).await
|
||||
.map_err(|e| anyhow::anyhow!("Local write failed: {}", e))?;
|
||||
local_file.flush().await?;
|
||||
|
||||
Ok::<_, anyhow::Error>(len as u64)
|
||||
}));
|
||||
}
|
||||
|
||||
// 等待所有任务完成
|
||||
let mut total_read = 0u64;
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(Ok(read)) => total_read += read,
|
||||
Ok(Err(e)) => {
|
||||
// 下载失败,删除不完整的本地文件
|
||||
tokio::fs::remove_file(local_path).await.ok();
|
||||
return Err(e);
|
||||
}
|
||||
Err(e) => {
|
||||
tokio::fs::remove_file(local_path).await.ok();
|
||||
return Err(anyhow::anyhow!("Task failed: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TransferResult {
|
||||
size: total_read,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
path: local_path.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 删除远程文件 (传输失败时清理)
|
||||
async fn remove_remote_file(&self, name: &str, remote_path: &str) -> Result<()> {
|
||||
let session = self.get_or_create_session(name).await?;
|
||||
|
||||
let channel = session.channel_open_session().await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to open channel: {}", e))?;
|
||||
channel.request_subsystem(true, "sftp").await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to request SFTP subsystem: {}", e))?;
|
||||
|
||||
let sftp = SftpSession::new(channel.into_stream()).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to init SFTP session: {}", e))?;
|
||||
|
||||
sftp.remove_file(remote_path).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to remove remote file: {}", e))?;
|
||||
|
||||
sftp.close().await.ok();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn list_servers(&self) -> Vec<ServerInfo> {
|
||||
// 使用 try_lock 避免阻塞,如果锁不可用则返回 pending
|
||||
let sessions = self.sessions.try_lock();
|
||||
@@ -268,3 +473,10 @@ pub struct ServerInfo {
|
||||
pub user: String,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct TransferResult {
|
||||
pub size: u64,
|
||||
pub duration_ms: u64,
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user