新增: MySQL/SSH 代理工具

- mysql-proxy: MySQL HTTP 代理,连接池复用
- ssh-proxy: SSH HTTP 代理,会话复用
- mysql-cli: 轻量级 MySQL CLI 工具

功能特性:
- 延迟初始化,启动快
- CLI 和 HTTP API 双模式
- 请求日志支持
- 错误友好提示
- JSON 极简输出格式
This commit is contained in:
2026-03-19 14:03:12 +08:00
commit 11203f036f
24 changed files with 3794 additions and 0 deletions

21
ssh-proxy/Cargo.toml Normal file
View File

@@ -0,0 +1,21 @@
[package]
name = "ssh-proxy"
version = "0.1.0"
edition = "2021"
description = "SSH HTTP proxy with session pooling"
[dependencies]
ssh2 = "0.9"
tokio = { version = "1", features = ["full"] }
axum = "0.7"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
anyhow = "1"
clap = { version = "4", features = ["derive"] }
ureq = "2" # CLI HTTP client (2.x has simpler API)
[profile.release]
opt-level = "z"
lto = true
strip = true

218
ssh-proxy/src/cli.rs Normal file
View File

@@ -0,0 +1,218 @@
use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};
use std::time::Instant;
const DEFAULT_SERVER: &str = "http://127.0.0.1:3308";
#[derive(Debug, Serialize)]
struct ExecRequest {
server: String,
command: String,
}
#[derive(Debug, Deserialize)]
struct ExecResponse {
stdout: String,
stderr: String,
#[serde(rename = "exitCode")]
exit_code: i32,
#[serde(rename = "durationMs")]
duration_ms: u64,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: String,
#[serde(default)]
usage: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ServersResponse {
servers: Vec<ServerInfo>,
}
#[derive(Debug, Deserialize)]
struct ServerInfo {
name: String,
host: String,
port: u16,
user: String,
status: String,
}
/// CLI 客户端
pub struct Cli {
server: String,
}
impl Cli {
pub fn new(server: Option<String>) -> Self {
Self {
server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()),
}
}
/// 执行远程命令
pub fn exec(&self, server: &str, command: &str, format: &str) -> Result<()> {
let start = Instant::now();
let body = serde_json::to_string(&ExecRequest {
server: server.to_string(),
command: command.to_string(),
})?;
let response = ureq::post(&format!("{}/exec", self.server))
.set("Content-Type", "application/json")
.send_string(&body);
// ureq 2.x: 非 2xx 会返回 Err但响应体在 error 中
let body = match response {
Ok(r) => r.into_string()?,
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
Err(e) => bail!("HTTP error: {}", e),
};
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
eprintln!("Error: {}", err.error);
if let Some(usage) = err.usage {
eprintln!("\n{}", usage);
}
return Ok(());
}
let result: ExecResponse = serde_json::from_str(&body)?;
// 根据格式输出
if format == "json" {
self.print_json(&result, start.elapsed().as_millis() as u64);
} else {
self.print_text(&result, start.elapsed().as_millis() as u64);
}
Ok(())
}
fn print_text(&self, result: &ExecResponse, total_ms: u64) {
if !result.stdout.is_empty() {
print!("{}", result.stdout);
}
if !result.stderr.is_empty() {
eprint!("{}", result.stderr);
}
eprintln!("\n--- exit: {}, {}ms db, {}ms total ---",
result.exit_code, result.duration_ms, total_ms);
}
fn print_json(&self, result: &ExecResponse, total_ms: u64) {
#[derive(Serialize)]
struct Output {
#[serde(rename = "exitCode")]
exit_code: i32,
stdout: String,
stderr: String,
#[serde(rename = "durationMs")]
duration_ms: u64,
#[serde(rename = "totalMs")]
total_ms: u64,
}
let output = Output {
exit_code: result.exit_code,
stdout: result.stdout.clone(),
stderr: result.stderr.clone(),
duration_ms: result.duration_ms,
total_ms,
};
match serde_json::to_string(&output) {
Ok(json) => println!("{}", json),
Err(e) => eprintln!("JSON error: {}", e),
}
}
/// 列出服务器
pub fn list_servers(&self) -> Result<()> {
let response = ureq::get(&format!("{}/servers", self.server))
.call()?;
let body = response.into_string()?;
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
anyhow::bail!("{}", err.error);
}
let result: ServersResponse = serde_json::from_str(&body)?;
println!("Servers:");
println!("{:<15} {:<30} {:<8} {:<10} {:<10}", "Name", "Host", "Port", "User", "Status");
println!("{}", "-".repeat(73));
for srv in result.servers {
println!("{:<15} {:<30} {:<8} {:<10} {:<10}",
srv.name, srv.host, srv.port, srv.user, srv.status);
}
Ok(())
}
/// 动态添加服务器
pub fn add_server(
&self,
name: String,
host: String,
port: u16,
user: String,
password: Option<String>,
private_key: Option<String>,
) -> Result<()> {
#[derive(Serialize)]
struct AddRequest {
name: String,
host: String,
port: u16,
user: String,
password: Option<String>,
private_key: Option<String>,
}
let body = serde_json::to_string(&AddRequest {
name,
host,
port,
user,
password,
private_key,
})?;
let response = ureq::post(&format!("{}/servers/add", self.server))
.set("Content-Type", "application/json")
.send_string(&body)?;
let body = response.into_string()?;
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
anyhow::bail!("{}", err.error);
}
#[derive(Deserialize)]
struct AddResponse {
success: bool,
message: String,
}
let result: AddResponse = serde_json::from_str(&body)?;
println!("{}", result.message);
Ok(())
}
/// 检查代理是否运行
pub fn check_server(&self) -> Result<bool> {
match ureq::get(&format!("{}/health", self.server)).call() {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
}

112
ssh-proxy/src/config.rs Normal file
View File

@@ -0,0 +1,112 @@
use anyhow::{Result, bail};
use serde::Deserialize;
use std::collections::HashSet;
use std::fs;
use std::path::PathBuf;
#[derive(Debug, Deserialize, Clone)]
pub struct ServerConfig {
pub port: u16,
#[serde(default = "default_host")]
pub host: String,
}
fn default_host() -> String {
"127.0.0.1".to_string()
}
#[derive(Debug, Deserialize, Clone)]
pub struct PoolConfig {
#[serde(default = "default_idle_timeout")]
pub idle_timeout_secs: u64,
#[serde(default = "default_check_interval")]
pub check_interval_secs: u64,
}
fn default_idle_timeout() -> u64 { 300 }
fn default_check_interval() -> u64 { 60 }
impl Default for PoolConfig {
fn default() -> Self {
Self {
idle_timeout_secs: default_idle_timeout(),
check_interval_secs: default_check_interval(),
}
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct SshServerConfig {
pub name: String,
pub host: String,
#[serde(default = "default_ssh_port")]
pub port: u16,
pub user: String,
/// 私钥路径 (优先)
pub private_key: Option<String>,
/// 密码 (备选)
pub password: Option<String>,
}
fn default_ssh_port() -> u16 { 22 }
impl SshServerConfig {
/// 获取私钥路径
pub fn get_private_key_path(&self) -> Option<PathBuf> {
self.private_key.as_ref().map(|p| {
// 支持 ~ 路径展开 (Unix 和 Windows)
let path = if p.starts_with('~') {
let home = std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))
.unwrap_or_default();
p.replacen('~', &home, 1)
} else {
p.clone()
};
PathBuf::from(path)
})
}
}
#[derive(Debug, Deserialize)]
pub struct Config {
#[serde(default = "default_server")]
pub server: ServerConfig,
#[serde(default)]
pub pool: PoolConfig,
pub servers: Vec<SshServerConfig>,
}
fn default_server() -> ServerConfig {
ServerConfig {
port: 3308,
host: "127.0.0.1".to_string(),
}
}
impl Config {
pub fn from_file(path: &str) -> Result<Self> {
let content = fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
config.validate()?;
Ok(config)
}
fn validate(&self) -> Result<()> {
let mut names = HashSet::new();
for srv in &self.servers {
if srv.name.is_empty() {
bail!("Server name cannot be empty");
}
if names.contains(&srv.name) {
bail!("Duplicate server name: {}", srv.name);
}
names.insert(srv.name.clone());
if srv.private_key.is_none() && srv.password.is_none() {
bail!("Server '{}' needs either private_key or password", srv.name);
}
}
Ok(())
}
}

183
ssh-proxy/src/handler.rs Normal file
View File

@@ -0,0 +1,183 @@
use axum::{extract::State, http::StatusCode, Json};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use crate::session::SessionManager;
use crate::logger::{LogEntry, RequestLogger};
pub type AppState = Arc<(Arc<SessionManager>, Arc<RequestLogger>)>;
#[derive(Debug, Deserialize)]
pub struct ExecRequest {
pub server: String,
pub command: String,
}
#[derive(Debug, Deserialize)]
pub struct AddServerRequest {
pub name: String,
pub host: String,
#[serde(default = "default_ssh_port")]
pub port: u16,
#[serde(default = "default_user")]
pub user: String,
pub password: Option<String>,
pub private_key: Option<String>,
}
fn default_ssh_port() -> u16 { 22 }
fn default_user() -> String { "root".to_string() }
#[derive(Debug, Serialize)]
pub struct ExecResponse {
pub stdout: String,
pub stderr: String,
#[serde(rename = "exitCode")]
pub exit_code: i32,
#[serde(rename = "durationMs")]
pub duration_ms: u64,
}
#[derive(Debug, Serialize)]
pub struct ServersResponse {
pub servers: Vec<crate::session::ServerInfo>,
}
#[derive(Debug, Serialize)]
pub struct HealthResponse {
pub status: String,
pub servers: usize,
}
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
pub error: String,
#[serde(rename = "usage")]
pub usage: Option<String>,
}
impl ErrorResponse {
pub fn new(msg: &str) -> Self {
Self { error: msg.to_string(), usage: None }
}
pub fn with_usage(mut self, usage: &str) -> Self {
self.usage = Some(usage.to_string());
self
}
}
pub async fn exec(
State(state): State<AppState>,
Json(req): Json<ExecRequest>,
) -> Result<Json<ExecResponse>, (StatusCode, Json<ErrorResponse>)> {
let start = Instant::now();
let (manager, logger) = &*state;
let manager_clone = manager.clone();
let server = req.server.clone();
let command = req.command.clone();
let result = tokio::task::spawn_blocking(move || {
manager_clone.exec(&server, &command)
})
.await
.map_err(|e| {
error_response_with_usage(&e.to_string(), "Internal error occurred")
})?
.map_err(|e| {
let usage = format!(
"Usage: POST /exec {{\"server\": \"server_name\", \"command\": \"your command\"}}\n\nAvailable servers: {}",
list_server_names(&manager)
);
error_response_with_usage(&e.to_string(), &usage)
})?;
let duration_ms = start.elapsed().as_millis() as u64;
// 记录日志
logger.log(&LogEntry::new("/exec", "http")
.with_server(&req.server)
.with_command(&req.command)
.with_duration(duration_ms)
.with_exit_code(result.exit_code));
Ok(Json(ExecResponse {
stdout: result.stdout,
stderr: result.stderr,
exit_code: result.exit_code,
duration_ms,
}))
}
pub async fn servers(
State(state): State<AppState>,
) -> Json<ServersResponse> {
let (manager, _) = &*state;
Json(ServersResponse { servers: manager.list_servers() })
}
pub async fn add_server(
State(state): State<AppState>,
Json(req): Json<AddServerRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
let (manager, logger) = &*state;
use crate::config::SshServerConfig;
// 验证必须有密码或私钥
if req.password.is_none() && req.private_key.is_none() {
return Err(error_response_with_usage(
"Server requires either password or private_key",
"Usage: POST /servers/add {\"name\": \"myserver\", \"host\": \"192.168.1.100\", \"user\": \"root\", \"password\": \"secret\"}\n or: {\"name\": \"myserver\", \"host\": \"...\", \"private_key\": \"~/.ssh/id_rsa\"}"
));
}
let cfg = SshServerConfig {
name: req.name.clone(),
host: req.host,
port: req.port,
user: req.user,
password: req.password,
private_key: req.private_key,
};
manager.add_server(cfg)
.map_err(|e| error_response(&e.to_string()))?;
// 记录日志
logger.log(&LogEntry::new("/servers/add", "http")
.with_server(&req.name)
.with_duration(0));
Ok(Json(serde_json::json!({
"success": true,
"message": format!("Server '{}' added (temporary, will be lost on restart)", req.name)
})))
}
pub async fn health(
State(state): State<AppState>,
) -> Json<HealthResponse> {
let (manager, _) = &*state;
Json(HealthResponse {
status: "ok".to_string(),
servers: manager.list_servers().len(),
})
}
fn error_response(msg: &str) -> (StatusCode, Json<ErrorResponse>) {
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg)))
}
fn error_response_with_usage(msg: &str, usage: &str) -> (StatusCode, Json<ErrorResponse>) {
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg).with_usage(usage)))
}
fn list_server_names(manager: &SessionManager) -> String {
manager.list_servers()
.iter()
.map(|s| s.name.clone())
.collect::<Vec<_>>()
.join(", ")
}

228
ssh-proxy/src/logger.rs Normal file
View File

@@ -0,0 +1,228 @@
use serde::{Deserialize, Serialize};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::PathBuf;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
/// 日志记录器
pub struct RequestLogger {
log_file: Mutex<Option<File>>,
enabled: bool,
}
impl RequestLogger {
pub fn new(log_path: Option<&str>) -> Self {
let (log_file, enabled) = if let Some(path) = log_path {
let path = expand_path(path);
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
match OpenOptions::new()
.create(true)
.append(true)
.open(&path)
{
Ok(file) => (Some(file), true),
Err(e) => {
eprintln!("[Logger] Failed to open log file {}: {}", path.display(), e);
(None, false)
}
}
} else {
(None, false)
};
Self {
log_file: Mutex::new(log_file),
enabled,
}
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
/// 记录请求日志
pub fn log(&self, entry: &LogEntry) {
if !self.enabled {
return;
}
let json = match serde_json::to_string(entry) {
Ok(j) => j,
Err(e) => {
eprintln!("[Logger] Failed to serialize log: {}", e);
return;
}
};
if let Ok(mut file) = self.log_file.lock() {
if let Some(ref mut f) = *file {
let _ = writeln!(f, "{}", json);
}
}
}
}
/// 日志条目
#[derive(Debug, Serialize, Deserialize)]
pub struct LogEntry {
pub timestamp: String,
pub level: String,
#[serde(rename = "type")]
pub log_type: String,
pub client: String,
pub endpoint: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub conn: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub server: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sql: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub command: Option<String>,
#[serde(rename = "durationMs")]
pub duration_ms: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub rows: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "exitCode")]
pub exit_code: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
impl LogEntry {
pub fn new(endpoint: &str, client: &str) -> Self {
Self {
timestamp: current_timestamp(),
level: "INFO".to_string(),
log_type: "request".to_string(),
client: client.to_string(),
endpoint: endpoint.to_string(),
conn: None,
server: None,
sql: None,
command: None,
duration_ms: 0,
rows: None,
exit_code: None,
error: None,
}
}
pub fn with_conn(mut self, conn: &str) -> Self {
self.conn = Some(conn.to_string());
self
}
pub fn with_server(mut self, server: &str) -> Self {
self.server = Some(server.to_string());
self
}
pub fn with_sql(mut self, sql: &str) -> Self {
// 截断过长的 SQL
self.sql = Some(truncate_string(sql, 1000));
self
}
pub fn with_command(mut self, command: &str) -> Self {
// 截断过长的命令
self.command = Some(truncate_string(command, 500));
self
}
pub fn with_duration(mut self, ms: u64) -> Self {
self.duration_ms = ms;
self
}
pub fn with_rows(mut self, rows: usize) -> Self {
self.rows = Some(rows);
self
}
pub fn with_exit_code(mut self, code: i32) -> Self {
self.exit_code = Some(code);
self
}
pub fn with_error(mut self, error: &str) -> Self {
self.level = "ERROR".to_string();
self.error = Some(error.to_string());
self
}
}
fn current_timestamp() -> String {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default();
let secs = now.as_secs();
let datetime = chrono_timestamp(secs);
format!("{}Z", datetime)
}
fn chrono_timestamp(secs: u64) -> String {
let days = secs / 86400;
let remaining = secs % 86400;
let hours = remaining / 3600;
let minutes = (remaining % 3600) / 60;
let seconds = remaining % 60;
// 从 1970-01-01 开始计算日期
let mut year = 1970;
let mut days_left = days;
loop {
let days_in_year = if is_leap_year(year) { 366 } else { 365 };
if days_left < days_in_year {
break;
}
days_left -= days_in_year;
year += 1;
}
let month_days = if is_leap_year(year) {
[31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
} else {
[31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
};
let mut month = 1;
for &days_in_month in &month_days {
if days_left < days_in_month {
break;
}
days_left -= days_in_month;
month += 1;
}
let day = days_left + 1;
format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}", year, month, day, hours, minutes, seconds)
}
fn is_leap_year(year: u64) -> bool {
(year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
}
fn truncate_string(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
format!("{}... (truncated)", &s[..max_len])
}
}
fn expand_path(path: &str) -> PathBuf {
if path.starts_with('~') {
let home = std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))
.unwrap_or_default();
PathBuf::from(path.replacen('~', &home, 1))
} else {
PathBuf::from(path)
}
}

170
ssh-proxy/src/main.rs Normal file
View File

@@ -0,0 +1,170 @@
mod cli;
mod config;
mod handler;
mod logger;
mod session;
use axum::{routing::{get, post}, Router};
use clap::{Parser, Subcommand};
use std::sync::Arc;
#[derive(Parser, Debug)]
#[command(name = "ssh-proxy")]
#[command(about = "SSH HTTP proxy with session pooling")]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[arg(long, default_value = "ssh-proxy.toml", global = true)]
config: String,
#[arg(short = 'S', long, default_value = "http://127.0.0.1:3308", global = true)]
server: String,
}
#[derive(Debug, Subcommand)]
enum Commands {
Server {
#[arg(short = 'P', long)]
port: Option<u16>,
#[arg(short = 'H', long)]
host: Option<String>,
},
Exec {
#[arg(short = 'n', long)]
name: String,
#[arg(short, long)]
command: String,
/// Output format: text, json
#[arg(short = 'F', long, default_value = "text")]
format: String,
},
Servers,
/// Add server dynamically (temporary)
AddServer {
#[arg(short = 'n', long)]
name: String,
#[arg(short = 'H', long)]
host: String,
#[arg(short = 'P', long, default_value = "22")]
port: u16,
#[arg(short = 'u', long, default_value = "root")]
user: String,
#[arg(short = 'p', long)]
password: Option<String>,
#[arg(short = 'k', long)]
private_key: Option<String>,
},
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
match args.command {
Some(Commands::Server { port, host }) => run_server(&args.config, port, host).await,
Some(Commands::Exec { name, command, format }) => run_cli_exec(&args.server, &name, &command, &format),
Some(Commands::Servers) => run_cli_servers(&args.server),
Some(Commands::AddServer { name, host, port, user, password, private_key }) => {
run_cli_add_server(&args.server, name, host, port, user, password, private_key)
}
None => run_server(&args.config, None, None).await,
}
}
async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>) -> anyhow::Result<()> {
println!("SSH HTTP Proxy v0.1.0\n");
let mut config = config::Config::from_file(config_path)?;
if let Some(p) = port { config.server.port = p; }
if let Some(h) = host { config.server.host = h; }
// 初始化日志
let log_path = std::env::var("SSH_PROXY_LOG").ok();
let logger = Arc::new(logger::RequestLogger::new(log_path.as_deref()));
if logger.is_enabled() {
println!("Request logging: enabled");
}
println!("Initializing SSH sessions...\n");
let manager = Arc::new(session::SessionManager::new(&config.servers)?);
let manager_clone = manager.clone();
let idle_timeout = config.pool.idle_timeout_secs;
let check_interval = config.pool.check_interval_secs;
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(check_interval));
loop {
interval.tick().await;
manager_clone.cleanup_idle(idle_timeout);
}
});
let app = Router::new()
.route("/exec", post(handler::exec))
.route("/servers", get(handler::servers))
.route("/servers/add", post(handler::add_server))
.route("/health", get(handler::health))
.with_state(Arc::new((manager, logger)));
let addr = format!("{}:{}", config.server.host, config.server.port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
println!("\nServer started at http://{}", addr);
println!("\nAPI Endpoints:");
println!(" POST /exec - Execute remote command");
println!(" POST /servers/add - Add server (temporary)");
println!(" GET /servers - List all servers");
println!(" GET /health - Health check");
println!("\nCLI Usage:");
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\"");
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\" -F json");
println!(" ssh-proxy servers");
axum::serve(listener, app).await?;
Ok(())
}
fn run_cli_exec(server_url: &str, server: &str, command: &str, format: &str) -> anyhow::Result<()> {
let cli = cli::Cli::new(Some(server_url.to_string()));
if !cli.check_server()? {
eprintln!();
eprintln!("=== ssh-proxy 未运行,请使用降级方案 ===");
eprintln!();
eprintln!("启动代理: ssh-proxy");
eprintln!();
eprintln!("或直接使用 ssh 命令:");
eprintln!(" ssh <user>@<host> \"<command>\"");
eprintln!();
eprintln!("配置文件中的服务器信息见: ssh-proxy.toml");
anyhow::bail!("Proxy server not running at {}", server_url);
}
cli.exec(server, command, format)
}
fn run_cli_servers(server_url: &str) -> anyhow::Result<()> {
let cli = cli::Cli::new(Some(server_url.to_string()));
if !cli.check_server()? {
eprintln!();
eprintln!("=== ssh-proxy 未运行 ===");
eprintln!("启动代理: ssh-proxy");
eprintln!("配置文件: ssh-proxy.toml");
anyhow::bail!("Proxy server not running at {}", server_url);
}
cli.list_servers()
}
fn run_cli_add_server(
server_url: &str,
name: String,
host: String,
port: u16,
user: String,
password: Option<String>,
private_key: Option<String>,
) -> anyhow::Result<()> {
let cli = cli::Cli::new(Some(server_url.to_string()));
if !cli.check_server()? {
eprintln!();
eprintln!("=== ssh-proxy 未运行 ===");
eprintln!("启动代理: ssh-proxy");
anyhow::bail!("Proxy server not running at {}", server_url);
}
cli.add_server(name, host, port, user, password, private_key)
}

192
ssh-proxy/src/session.rs Normal file
View File

@@ -0,0 +1,192 @@
use anyhow::{Result, bail};
use ssh2::Session;
use std::collections::HashMap;
use std::net::TcpStream;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use std::io::Read;
use std::path::PathBuf;
use crate::config::SshServerConfig;
struct SessionState {
session: Session,
last_used: Instant,
}
pub struct SessionManager {
sessions: Mutex<HashMap<String, SessionState>>,
configs: Mutex<HashMap<String, SshServerConfig>>,
}
impl SessionManager {
pub fn new(configs: &[SshServerConfig]) -> Result<Self> {
let mut config_map = HashMap::new();
for cfg in configs {
config_map.insert(cfg.name.clone(), cfg.clone());
println!(" Registered: {} ({}:{})", cfg.name, cfg.host, cfg.port);
}
println!("\n{} server(s) configured (lazy init)", config_map.len());
Ok(Self {
sessions: Mutex::new(HashMap::new()),
configs: Mutex::new(config_map),
})
}
fn get_or_create_session(&self, name: &str) -> Result<Session> {
{
let mut sessions = self.sessions.lock().unwrap();
if let Some(state) = sessions.get_mut(name) {
if state.session.authenticated() {
state.last_used = Instant::now();
return Ok(state.session.clone());
} else {
println!("[Session] Session expired, reconnecting: {}", name);
sessions.remove(name);
}
}
}
let cfg = self.configs.lock().unwrap()
.get(name)
.ok_or_else(|| anyhow::anyhow!("Server '{}' not found", name))?
.clone();
println!("[LazyInit] Connecting to: {}", name);
let session = self.create_session(&cfg)?;
{
let mut sessions = self.sessions.lock().unwrap();
sessions.insert(name.to_string(), SessionState {
session: session.clone(),
last_used: Instant::now(),
});
}
println!("[LazyInit] Connected: {}", name);
Ok(session)
}
fn create_session(&self, cfg: &SshServerConfig) -> Result<Session> {
let addr = format!("{}:{}", cfg.host, cfg.port);
let tcp = TcpStream::connect(&addr)
.map_err(|e| anyhow::anyhow!("Failed to connect to {}: {}", addr, e))?;
let mut session = Session::new()?;
session.set_tcp_stream(tcp);
session.handshake()?;
if let Some(key_path) = cfg.get_private_key_path() {
let pubkey_path: PathBuf = key_path.with_extension("pub");
session.userauth_pubkey_file(&cfg.user, Some(&pubkey_path), &key_path, None)?;
} else if let Some(ref password) = cfg.password {
session.userauth_password(&cfg.user, password)?;
} else {
bail!("No authentication method configured");
}
if !session.authenticated() {
bail!("SSH authentication failed");
}
Ok(session)
}
pub fn exec(&self, name: &str, command: &str) -> Result<ExecResult> {
let start = Instant::now();
let session = self.get_or_create_session(name)?;
let mut channel = session.channel_session()?;
channel.exec(command)?;
let mut stdout = String::new();
let mut stderr = String::new();
if let Err(e) = channel.read_to_string(&mut stdout) {
eprintln!("[SSH] stdout read error: {}", e);
}
if let Err(e) = channel.stderr().read_to_string(&mut stderr) {
eprintln!("[SSH] stderr read error: {}", e);
}
channel.wait_close()?;
let exit_code = channel.exit_status()?;
Ok(ExecResult {
stdout,
stderr,
exit_code,
duration_ms: start.elapsed().as_millis() as u64,
})
}
pub fn list_servers(&self) -> Vec<ServerInfo> {
let sessions = self.sessions.lock().unwrap();
let configs = self.configs.lock().unwrap();
configs.iter().map(|(name, cfg)| {
let status = if sessions.contains_key(name) { "connected" } else { "pending" };
ServerInfo {
name: name.clone(),
host: cfg.host.clone(),
port: cfg.port,
user: cfg.user.clone(),
status: status.to_string(),
}
}).collect()
}
/// 动态添加服务器 (临时,重启后消失)
pub fn add_server(&self, cfg: SshServerConfig) -> Result<()> {
let name = cfg.name.clone();
let mut configs = self.configs.lock().unwrap();
if configs.contains_key(&name) {
anyhow::bail!("Server '{}' already exists", name);
}
println!("[Dynamic] Adding: {} ({}:{})", name, cfg.host, cfg.port);
// 测试连接
let session = self.create_session(&cfg)?;
{
let mut sessions = self.sessions.lock().unwrap();
sessions.insert(name.clone(), SessionState {
session,
last_used: Instant::now(),
});
}
configs.insert(name.clone(), cfg);
println!("[Dynamic] ✓ Added: {}", name);
Ok(())
}
pub fn cleanup_idle(&self, timeout_secs: u64) {
let mut sessions = self.sessions.lock().unwrap();
let now = Instant::now();
sessions.retain(|name, state| {
let elapsed = now.duration_since(state.last_used);
if elapsed > Duration::from_secs(timeout_secs) {
println!("[Cleanup] Removing idle session: {} (idle {}s)", name, elapsed.as_secs());
false
} else {
true
}
});
}
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct ExecResult {
pub stdout: String,
pub stderr: String,
pub exit_code: i32,
pub duration_ms: u64,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct ServerInfo {
pub name: String,
pub host: String,
pub port: u16,
pub user: String,
pub status: String,
}