commit 11203f036fc8a1daf29ad124d0a7780305b8be02 Author: 绝尘 <237809796@qq.com> Date: Thu Mar 19 14:03:12 2026 +0800 新增: MySQL/SSH 代理工具 - mysql-proxy: MySQL HTTP 代理,连接池复用 - ssh-proxy: SSH HTTP 代理,会话复用 - mysql-cli: 轻量级 MySQL CLI 工具 功能特性: - 延迟初始化,启动快 - CLI 和 HTTP API 双模式 - 请求日志支持 - 错误友好提示 - JSON 极简输出格式 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..34b47d7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +# 编译输出 +target/ + +# 日志 +logs/ +*.log + +# IDE +.idea/ +.vscode/ + +# 配置文件(含敏感信息) +mysql-proxy/mysql-proxy.toml +ssh-proxy/ssh-proxy.toml diff --git a/INCUBATOR.md b/INCUBATOR.md new file mode 100644 index 0000000..f4b9f18 --- /dev/null +++ b/INCUBATOR.md @@ -0,0 +1,100 @@ +# 代理工具孵化记录 + +> 状态:孵化中 | 更新:2026-03-19 + +--- + +## 工具列表 + +| 工具 | 端口 | 状态 | 项目目录 | +|------|------|------|----------| +| mysql-proxy | 3307 | 可用 | `mysql-proxy` | +| ssh-proxy | 3308 | 可用 | `ssh-proxy` | + +--- + +## 新增功能 + +### 2026-03-19 + +1. **请求日志** (P0) + - 设置环境变量启用:`MYSQL_PROXY_LOG` / `SSH_PROXY_LOG` + - 日志格式:JSON,每行一条记录 + +2. **ssh-proxy CLI JSON 输出** (P1) + ```bash + ssh-proxy exec -n flux_dev -c "docker ps" -F json + ``` + +3. **ssh-proxy 动态添加服务器** (P2) + ```bash + ssh-proxy add-server -n myserver -H 192.168.1.100 -u root -p password + # 或 API + curl -X POST http://127.0.0.1:3308/servers/add \ + -H "Content-Type: application/json" \ + -d '{"name":"myserver","host":"192.168.1.100","user":"root","password":"secret"}' + ``` + +4. **API 错误友好提示** + - 错误时返回 `usage` 字段,引导正确使用 + +### 2026-03-18 + +- mysql-proxy 动态添加连接 API +- ssh-proxy Windows 路径修复 +- ssh-proxy 读取错误日志修复 + +--- + +## 实测性能 + +| 操作 | 代理 | 直连 | 提升 | +|------|------|------|------| +| mysql 首次查询 | ~150ms | ~500ms | 3x | +| mysql 复用会话 | ~50ms | ~500ms | 10x | +| ssh 首次执行 | ~900ms | ~1500ms | 1.7x | +| ssh 复用会话 | ~100-200ms | ~1500ms | 7-15x | + +--- + +## AI 推荐使用方式 + +```bash +# MySQL 查询 (AI 友好) +curl -X POST http://127.0.0.1:3307/query \ + -H "Content-Type: application/json" \ + -d '{"conn":"flux_dev","sql":"SELECT VERSION()"}' + +# SSH 执行 (AI 友好) +curl -X POST http://127.0.0.1:3308/exec \ + -H "Content-Type: application/json" \ + -d '{"server":"flux_dev","command":"docker ps"}' +``` + +**优势**: +- JSON 格式,易于解析 +- 会话复用,响应快 +- 错误提示友好 + +--- + +## 已知问题 + +### ssh-proxy + +1. **ssh2 不支持 ed25519 密钥** + - 错误: `[Session(-19)] Callback returned error` + - 解决: 使用 PEM 格式 RSA 密钥 + ```bash + ssh-keygen -t rsa -b 2048 -f ~/.ssh/id_rsa_pem -m PEM -N "" + ``` + +### mysql-proxy + +- 暂无 + +--- + +## 待优化 + +*由 AI 在实际使用过程中根据遇到的问题填写* diff --git a/INDEX.md b/INDEX.md new file mode 100644 index 0000000..00b85a6 --- /dev/null +++ b/INDEX.md @@ -0,0 +1,25 @@ +# Rust Work 索引 + +> 更新:2026-03-18 + +## 项目列表 + +| 项目 | 端口 | 说明 | +|------|------|------| +| [mysql-proxy](mysql-proxy) | 3307 | MySQL 会话复用代理 | +| [ssh-proxy](ssh-proxy) | 3308 | SSH 会话复用代理 | +| [mysql-cli](mysql-cli) | - | MySQL CLI 工具 (开发中) | + +## 孵化记录 + +- [INCUBATOR.md](INCUBATOR.md) - 问题/性能/建议 + +## 快速启动 + +```bash +# mysql-proxy +cd mysql-proxy && ./target/release/mysql-proxy.exe + +# ssh-proxy +cd ssh-proxy && ./target/release/ssh-proxy.exe +``` diff --git a/mysql-cli/Cargo.toml b/mysql-cli/Cargo.toml new file mode 100644 index 0000000..89c51ab --- /dev/null +++ b/mysql-cli/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "mysql-cli" +version = "0.1.0" +edition = "2021" +description = "A lightweight MySQL command-line tool written in Rust" + +[dependencies] +# MySQL driver +mysql = "25" + +# Command line argument parsing +clap = { version = "4", features = ["derive"] } + +# REPL with history and completion +rustyline = { version = "14", features = ["derive"] } + +# Serialization +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# CSV output +csv = "1" + +# Error handling +anyhow = "1" +thiserror = "1" + +# Check if stdin is a tty +atty = "0.2" + +[profile.release] +opt-level = "z" +lto = true +strip = true diff --git a/mysql-cli/README.md b/mysql-cli/README.md new file mode 100644 index 0000000..05ad9cc --- /dev/null +++ b/mysql-cli/README.md @@ -0,0 +1,148 @@ +# MySQL CLI (Rust) + +> 💡 **快速开始**:运行 `mysql --help` 查看完整使用文档 + +一个轻量级的 Rust 语言 MySQL 命令行工具,兼容 MySQL 官方客户端用法,可以在未安装 MySQL 客户端的系统中执行 SQL 语句。 + +## 功能特性 + +- ✅ **交互式模式**:支持完整的 REPL 交互式 Shell +- ✅ **批量执行**:支持从 SQL 文件执行多条语句 +- ✅ **多种输出格式**:支持表格、CSV、JSON、垂直格式 +- ✅ **完整 SQL 支持**:支持 SELECT、INSERT、UPDATE、DELETE 等所有 SQL 语句 +- ✅ **特殊命令**:支持 USE、SHOW、DESCRIBE、EXPLAIN 等 MySQL 命令 +- ✅ **紧凑体积**:编译后仅 2.5MB,无需安装 MySQL 客户端 + +## 编译 + +```bash +cargo build --release +``` + +## 使用方法 + +### 基本语法 + +```bash +./mysql.exe -h -P -u -p -D -e "" +``` + +### 参数说明 + +| 参数 | 说明 | 默认值 | +|------|------|--------| +| -h | MySQL 主机地址 | 127.0.0.1 | +| -P | MySQL 端口 | 3306 | +| -u | 用户名 | root | +| -p | 密码 | (空) | +| -D | 数据库名 | (空) | +| -e | 要执行的 SQL 语句 | (可选) | +| -f | 从 SQL 文件执行 | (可选) | +| -F | 输出格式:table, csv, json, vertical | table | +| -o | 输出到文件 | stdout | +| -i | 强制进入交互模式 | (可选) | + +**注意**:不指定 `-e` 或 `-f` 时,需要使用 `-i` 参数进入交互模式 + +### 使用示例 + +#### 1. 交互式模式 + +进入交互式 Shell: +```bash +# 基本连接并进入交互模式 +mysql -p123456 -i + +# 指定主机和用户 +mysql -h127.0.0.1 -uroot -p123456 -i +``` + +交互式命令: +- `use ` - 切换数据库 +- `status` - 显示连接状态 +- `source ` - 执行 SQL 文件 +- `format ` - 设置输出格式(table, csv, json, vertical) +- `output ` - 设置输出文件 +- `exit / quit` - 退出程序 + +#### 2. 直接执行 SQL + +```bash +# 查询数据 +mysql -h127.0.0.1 -uroot -p123456 -Dmydb -e "SELECT * FROM users LIMIT 10" + +# 显示表结构 +mysql -p123456 -Dmydb -e "DESCRIBE users" + +# 显示所有数据库 +mysql -p123456 -e "SHOW DATABASES" + +# 插入数据 +mysql -p123456 -Dmydb -e "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')" + +# 更新数据 +mysql -p123456 -Dmydb -e "UPDATE users SET email='new@email.com' WHERE id=1" + +# 删除数据 +mysql -p123456 -Dmydb -e "DELETE FROM users WHERE id=1" +``` + +#### 3. 从 SQL 文件执行 + +```bash +# 执行 SQL 文件 +mysql -p123456 -f script.sql + +# 执行并输出到 CSV +mysql -p123456 -f script.sql -F csv -o result.csv +``` + +#### 4. 不同输出格式 + +```bash +# CSV 格式 +mysql -p123456 -e "SELECT * FROM users" -F csv + +# JSON 格式 +mysql -p123456 -e "SELECT * FROM users" -F json + +# 垂直格式(适合字段较多的记录) +mysql -p123456 -e "SELECT * FROM users" -F vertical +``` + +## 项目结构 + +``` +mysql-cli/ +├── Cargo.toml # Rust 项目配置 +├── README.md # 项目文档 +└── src/ + ├── main.rs # 主程序入口 + ├── config.rs # 配置管理 + ├── db.rs # 数据库连接 + ├── executor.rs # SQL 执行引擎 + └── repl.rs # 交互式 REPL +``` + +## 依赖 + +- Rust 1.70+ +- mysql (25.x) +- clap (4.x) - 命令行参数解析 +- rustyline (14.x) - 交互式 REPL +- serde / serde_json - JSON 序列化 +- csv - CSV 输出 + +## 与 Go 版本对比 + +| 功能 | Rust 版本 | Go 版本 | +|------|-----------|---------| +| 交互式模式 | ✅ | ✅ | +| SQL 文件执行 | ✅ | ✅ | +| 输出格式 | 表格/CSV/JSON/垂直 | 表格/CSV/JSON/垂直 | +| 二进制大小 | ~2.5MB | ~7.3MB | +| 内存占用 | 更低 | 较高 | + +## License + +MIT diff --git a/mysql-cli/src/config.rs b/mysql-cli/src/config.rs new file mode 100644 index 0000000..9e1dd36 --- /dev/null +++ b/mysql-cli/src/config.rs @@ -0,0 +1,115 @@ +use std::fmt; + +/// 数据库连接配置 +#[derive(Debug, Clone)] +pub struct Config { + pub host: String, + pub port: u16, + pub username: String, + pub password: String, + pub database: String, +} + +impl Default for Config { + fn default() -> Self { + Self { + host: "127.0.0.1".to_string(), + port: 3306, + username: "root".to_string(), + password: String::new(), + database: String::new(), + } + } +} + +impl Config { + /// 构建 MySQL DSN + pub fn build_dsn(&self) -> String { + format!( + "mysql://{}:{}@{}:{}/{}", + self.username, + urlencoding(&self.password), + self.host, + self.port, + self.database + ) + } + + /// 构建不带数据库的 DSN + pub fn build_dsn_without_db(&self) -> String { + format!( + "mysql://{}:{}@{}:{}/", + self.username, + urlencoding(&self.password), + self.host, + self.port + ) + } +} + +impl fmt::Display for Config { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}@{}:{}", + self.username, self.host, self.port + ) + } +} + +/// URL 编码密码中的特殊字符 +fn urlencoding(s: &str) -> String { + let mut result = String::new(); + for c in s.chars() { + match c { + '@' => result.push_str("%40"), + ':' => result.push_str("%3A"), + '/' => result.push_str("%2F"), + '?' => result.push_str("%3F"), + '#' => result.push_str("%23"), + '[' => result.push_str("%5B"), + ']' => result.push_str("%5D"), + '!' => result.push_str("%21"), + '$' => result.push_str("%24"), + '&' => result.push_str("%26"), + '\'' => result.push_str("%27"), + '(' => result.push_str("%28"), + ')' => result.push_str("%29"), + '*' => result.push_str("%2A"), + '+' => result.push_str("%2B"), + ',' => result.push_str("%2C"), + ';' => result.push_str("%3B"), + '=' => result.push_str("%3D"), + '%' => result.push_str("%25"), + ' ' => result.push_str("%20"), + _ => result.push(c), + } + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_dsn() { + let config = Config { + host: "localhost".to_string(), + port: 3306, + username: "root".to_string(), + password: "123456".to_string(), + database: "test".to_string(), + }; + assert_eq!( + config.build_dsn(), + "mysql://root:123456@localhost:3306/test" + ); + } + + #[test] + fn test_urlencoding() { + assert_eq!(urlencoding("a@b:c"), "a%40b%3Ac"); + assert_eq!(urlencoding("p@ss!word"), "p%40ss%21word"); + } +} diff --git a/mysql-cli/src/db.rs b/mysql-cli/src/db.rs new file mode 100644 index 0000000..211f340 --- /dev/null +++ b/mysql-cli/src/db.rs @@ -0,0 +1,60 @@ +use anyhow::Result; +use mysql::{Pool, PooledConn, Opts, prelude::*}; + +use crate::config::Config; + +/// 连接数据库 +pub fn connect(cfg: &Config) -> Result { + let dsn = if cfg.database.is_empty() { + cfg.build_dsn_without_db() + } else { + cfg.build_dsn() + }; + + let opts: Opts = Opts::from_url(&dsn)?; + let pool = Pool::new(opts)?; + + // 测试连接 + let mut conn = pool.get_conn()?; + conn.query::("SELECT 1")?; + + Ok(pool) +} + +/// 获取连接 +pub fn get_conn(pool: &Pool) -> Result { + Ok(pool.get_conn()?) +} + +/// 获取服务器版本 +pub fn get_server_version(conn: &mut PooledConn) -> String { + conn.query_first::("SELECT VERSION()") + .ok() + .flatten() + .unwrap_or_else(|| "Unknown".to_string()) +} + +/// 获取当前数据库 +pub fn get_current_database(conn: &mut PooledConn) -> Option { + conn.query_first::("SELECT DATABASE()").ok().flatten() +} + +/// 获取所有数据库 +pub fn get_databases(conn: &mut PooledConn) -> Vec { + conn.query::("SHOW DATABASES") + .unwrap_or_default() +} + +/// 获取所有表名 +pub fn get_tables(conn: &mut PooledConn) -> Vec { + conn.query::("SHOW TABLES") + .unwrap_or_default() +} + +/// 获取表的列信息 +pub fn get_columns(conn: &mut PooledConn, table: &str) -> Vec { + let sql = format!("SHOW COLUMNS FROM `{}`", table); + conn.query_map(sql, |row: mysql::Row| { + row.get::(0).unwrap_or_default() + }).unwrap_or_default() +} diff --git a/mysql-cli/src/executor.rs b/mysql-cli/src/executor.rs new file mode 100644 index 0000000..1cc0139 --- /dev/null +++ b/mysql-cli/src/executor.rs @@ -0,0 +1,348 @@ +use anyhow::{Result, bail}; +use mysql::{Pool, PooledConn, prelude::*, Row, Value}; +use std::fs::File; +use std::io::{self, Write, BufWriter}; + +/// 输出格式 +#[derive(Debug, Clone, Copy, Default)] +pub enum OutputFormat { + #[default] + Table, + Csv, + Json, + Vertical, +} + +impl std::str::FromStr for OutputFormat { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "table" => Ok(Self::Table), + "csv" => Ok(Self::Csv), + "json" => Ok(Self::Json), + "vertical" => Ok(Self::Vertical), + _ => bail!("Invalid format: {}. Valid formats: table, csv, json, vertical", s), + } + } +} + +/// SQL 执行器 +pub struct Executor { + pool: Pool, + format: OutputFormat, + output: Option, +} + +impl Executor { + pub fn new(pool: Pool, format: OutputFormat, output: Option) -> Self { + Self { pool, format, output } + } + + /// 执行 SQL(支持多条语句) + pub fn execute(&self, query: &str) -> Result<()> { + let queries = split_queries(query); + let mut conn = self.pool.get_conn()?; + + for (i, q) in queries.iter().enumerate() { + if q.trim().is_empty() { + continue; + } + + if i > 0 { + println!(); + } + + self.execute_one(&mut conn, q)?; + } + + Ok(()) + } + + /// 执行单条 SQL + fn execute_one(&self, conn: &mut PooledConn, query: &str) -> Result<()> { + let query = query.trim(); + if query.is_empty() { + return Ok(()); + } + + let upper_query = query.to_uppercase(); + let is_query = upper_query.starts_with("SELECT") + || upper_query.starts_with("SHOW") + || upper_query.starts_with("DESCRIBE") + || upper_query.starts_with("DESC ") + || upper_query.starts_with("EXPLAIN") + || upper_query.starts_with("WITH"); + + if is_query { + self.execute_query(conn, query) + } else { + self.execute_exec(conn, query) + } + } + + /// 执行查询语句 + fn execute_query(&self, conn: &mut PooledConn, query: &str) -> Result<()> { + let result = conn.query_iter(query)?; + + let mut column_names: Vec = Vec::new(); + let mut data: Vec>> = Vec::new(); + + // 获取数据 + for row_result in result { + let row = row_result?; + // 从第一行获取列名 + if column_names.is_empty() { + column_names = row.columns() + .iter() + .map(|c| c.name_str().to_string()) + .collect(); + } + let values = row_to_strings(&row); + data.push(values); + } + + // 输出结果 + self.output_result(&column_names, &data)?; + + // 显示行数 + if !data.is_empty() { + println!(); + println!("{} rows in set", data.len()); + } + + Ok(()) + } + + /// 执行非查询语句 + fn execute_exec(&self, conn: &mut PooledConn, query: &str) -> Result<()> { + let result = conn.query_iter(query)?; + let upper_query = query.to_uppercase(); + + if upper_query.starts_with("INSERT") + || upper_query.starts_with("UPDATE") + || upper_query.starts_with("DELETE") + || upper_query.starts_with("REPLACE") + { + let affected = result.affected_rows(); + if upper_query.starts_with("INSERT") { + let last_id = result.last_insert_id().unwrap_or(0); + println!("Query OK, {} row affected, last insert ID: {}", affected, last_id); + } else { + println!("Query OK, {} row affected", affected); + } + } else { + println!("Query OK"); + } + + Ok(()) + } + + /// 根据格式输出结果 + fn output_result(&self, columns: &[String], data: &[Vec>]) -> Result<()> { + let mut output: Box = match &self.output { + Some(path) => Box::new(BufWriter::new(File::create(path)?)), + None => Box::new(BufWriter::new(io::stdout())), + }; + + match self.format { + OutputFormat::Table => self.output_table(&mut output, columns, data), + OutputFormat::Csv => self.output_csv(&mut output, columns, data), + OutputFormat::Json => self.output_json(&mut output, columns, data), + OutputFormat::Vertical => self.output_vertical(&mut output, columns, data), + } + } + + /// 表格格式输出 + fn output_table(&self, output: &mut dyn Write, columns: &[String], data: &[Vec>]) -> Result<()> { + // 计算列宽 + let mut widths: Vec = columns.iter().map(|c| c.len()).collect(); + for row in data { + for (i, val) in row.iter().enumerate() { + let len = val.as_ref().map_or(4, |v| v.len()); + if len > widths[i] { + widths[i] = len; + } + } + } + + // 打印表头 + print_table_row(output, columns, &widths)?; + print_table_separator(output, &widths)?; + + // 打印数据行 + for row in data { + let str_row: Vec = row.iter().map(|v| v.clone().unwrap_or_else(|| "NULL".to_string())).collect(); + print_table_row(output, &str_row, &widths)?; + } + + Ok(()) + } + + /// CSV 格式输出 + fn output_csv(&self, output: &mut dyn Write, columns: &[String], data: &[Vec>]) -> Result<()> { + let mut writer = csv::Writer::from_writer(output); + + writer.write_record(columns)?; + for row in data { + let str_row: Vec = row.iter().map(|v| v.clone().unwrap_or_default()).collect(); + writer.write_record(&str_row)?; + } + + writer.flush()?; + Ok(()) + } + + /// JSON 格式输出 + fn output_json(&self, output: &mut dyn Write, columns: &[String], data: &[Vec>]) -> Result<()> { + let result: Vec> = data + .iter() + .map(|row| { + columns + .iter() + .zip(row.iter()) + .map(|(col, val)| { + let v = match val { + Some(s) => serde_json::Value::String(s.clone()), + None => serde_json::Value::Null, + }; + (col.clone(), v) + }) + .collect() + }) + .collect(); + + let json = serde_json::to_string_pretty(&result)?; + writeln!(output, "{}", json)?; + Ok(()) + } + + /// 垂直格式输出 + fn output_vertical(&self, output: &mut dyn Write, columns: &[String], data: &[Vec>]) -> Result<()> { + let max_len = columns.iter().map(|c| c.len()).max().unwrap_or(0); + + for (row_idx, row) in data.iter().enumerate() { + if row_idx > 0 { + writeln!(output)?; + } + for (i, val) in row.iter().enumerate() { + let v = val.as_ref().map(|s| s.as_str()).unwrap_or("NULL"); + writeln!(output, "{:width$}: {}", columns[i], v, width = max_len)?; + } + } + + Ok(()) + } + + /// 从文件执行 SQL + pub fn execute_from_file(&self, filename: &str) -> Result<()> { + let content = std::fs::read_to_string(filename)?; + self.execute(&content) + } +} + +/// 将 Row 转换为字符串向量 +fn row_to_strings(row: &Row) -> Vec> { + row.columns() + .iter() + .enumerate() + .map(|(i, _col)| { + match row.get::(i) { + Some(value) => value_to_string(&value), + None => None, + } + }) + .collect() +} + +/// 将 Value 转换为字符串 +fn value_to_string(value: &Value) -> Option { + match value { + Value::NULL => None, + Value::Bytes(bytes) => String::from_utf8(bytes.clone()).ok(), + Value::Int(i) => Some(i.to_string()), + Value::UInt(u) => Some(u.to_string()), + Value::Float(f) => Some(f.to_string()), + Value::Double(d) => Some(d.to_string()), + Value::Date(year, month, day, hour, min, sec, micro) => { + match (hour, min, sec, micro) { + (0, 0, 0, 0) => Some(format!("{:04}-{:02}-{:02}", year, month, day)), + _ => Some(format!("{:04}-{:02}-{:02} {:02}:{:02}:{:02}", year, month, day, hour, min, sec)), + } + } + Value::Time(neg, days, hours, minutes, seconds, microseconds) => { + let sign = if *neg { "-" } else { "" }; + Some(format!("{}{} {}:{:02}:{:02}.{:06}", sign, days, hours, minutes, seconds, microseconds)) + } + } +} + +/// 分割多个 SQL 语句 +fn split_queries(query: &str) -> Vec { + let mut queries = Vec::new(); + let mut current = String::new(); + let mut in_quotes = false; + let mut quote_char = ' '; + let chars: Vec = query.chars().collect(); + + for i in 0..chars.len() { + let c = chars[i]; + + match c { + '\'' | '"' | '`' if !in_quotes => { + in_quotes = true; + quote_char = c; + current.push(c); + } + '\'' | '"' | '`' if in_quotes && c == quote_char => { + // 检查是否是转义 + let is_escaped = i > 0 && chars[i - 1] == '\\'; + if !is_escaped { + in_quotes = false; + } + current.push(c); + } + ';' if !in_quotes => { + queries.push(current.trim().to_string()); + current.clear(); + } + _ => { + current.push(c); + } + } + } + + if !current.trim().is_empty() { + queries.push(current.trim().to_string()); + } + + queries +} + +/// 打印表格行 +fn print_table_row(output: &mut dyn Write, data: &[String], widths: &[usize]) -> Result<()> { + for (i, val) in data.iter().enumerate() { + if i > 0 { + write!(output, " | ")?; + } + write!(output, "{: Result<()> { + for (i, w) in widths.iter().enumerate() { + if i > 0 { + write!(output, "-+-")?; + } else { + write!(output, "+")?; + } + for _ in 0..*w { + write!(output, "-")?; + } + } + writeln!(output, "+")?; + Ok(()) +} diff --git a/mysql-cli/src/main.rs b/mysql-cli/src/main.rs new file mode 100644 index 0000000..703e89b --- /dev/null +++ b/mysql-cli/src/main.rs @@ -0,0 +1,132 @@ +use anyhow::Result; +use clap::Parser; +use std::io::{self, Read}; + +mod config; +mod db; +mod executor; +mod repl; + +use config::Config; +use executor::{Executor, OutputFormat}; + +/// MySQL CLI - A lightweight MySQL command-line tool +#[derive(Parser, Debug)] +#[command(name = "mysql")] +#[command(about = "A lightweight MySQL command-line tool", long_about = None)] +struct Args { + /// MySQL host + #[arg(short = 'h', long, default_value = "127.0.0.1")] + host: String, + + /// MySQL port + #[arg(short = 'P', long, default_value_t = 3306)] + port: u16, + + /// MySQL username + #[arg(short = 'u', long, default_value = "root")] + username: String, + + /// MySQL password + #[arg(short = 'p', long, default_value = "")] + password: String, + + /// Database name + #[arg(short = 'D', long, default_value = "")] + database: String, + + /// Execute SQL statement + #[arg(short = 'e', long)] + execute: Option, + + /// Execute SQL from file + #[arg(short = 'f', long)] + file: Option, + + /// Output format: table, csv, json, vertical + #[arg(short = 'F', long, default_value = "table")] + format: String, + + /// Output file (default: stdout) + #[arg(short = 'o', long)] + output: Option, + + /// Interactive mode (REPL) + #[arg(short = 'i', long)] + interactive: bool, +} + +fn main() -> Result<()> { + let args = Args::parse(); + + let config = Config { + host: args.host, + port: args.port, + username: args.username, + password: args.password, + database: args.database, + }; + + // 连接数据库 + let pool = match db::connect(&config) { + Ok(p) => p, + Err(e) => { + eprintln!("Error connecting to database: {}", e); + std::process::exit(1); + } + }; + + let format: OutputFormat = args.format.parse().unwrap_or_default(); + let executor = Executor::new(pool.clone(), format, args.output.clone()); + + // 检测 stdin 是否被重定向 + let is_piped = !atty::is(atty::Stream::Stdin); + + // 确定运行模式 + if let Some(file) = &args.file { + // 从文件执行 + if let Err(e) = executor.execute_from_file(file) { + eprintln!("Error executing file: {}", e); + std::process::exit(1); + } + } else if let Some(query) = &args.execute { + // 执行 SQL + if let Err(e) = executor.execute(query) { + eprintln!("Error executing SQL: {}", e); + std::process::exit(1); + } + } else if is_piped { + // stdin 被重定向,读取所有输入并执行 + let mut content = String::new(); + if let Err(e) = io::stdin().read_to_string(&mut content) { + eprintln!("Error reading stdin: {}", e); + std::process::exit(1); + } + if !content.trim().is_empty() { + if let Err(e) = executor.execute(&content) { + eprintln!("Error executing SQL: {}", e); + std::process::exit(1); + } + } + } else if args.interactive { + // 交互模式 + if let Err(e) = repl::run_repl(pool, config) { + eprintln!("Error: {}", e); + std::process::exit(1); + } + } else { + // 没有指定任何操作,显示使用提示 + eprintln!("ERROR: Missing required argument."); + eprintln!(); + eprintln!("Usage:"); + eprintln!(" mysql -h -P -u -p -D -e \"\""); + eprintln!(" mysql -i # Interactive mode"); + eprintln!(" mysql -f # Execute SQL file"); + eprintln!(" echo \"SELECT 1\" | mysql # Execute from stdin"); + eprintln!(); + eprintln!("Run 'mysql --help' for detailed documentation."); + std::process::exit(1); + } + + Ok(()) +} diff --git a/mysql-cli/src/repl.rs b/mysql-cli/src/repl.rs new file mode 100644 index 0000000..5739eec --- /dev/null +++ b/mysql-cli/src/repl.rs @@ -0,0 +1,297 @@ +use anyhow::Result; +use mysql::{Pool, prelude::*}; +use rustyline::error::ReadlineError; +use rustyline::highlight::Highlighter; +use rustyline::hint::Hinter; +use rustyline::history::DefaultHistory; +use rustyline::{Completer, CompletionType, Config, Context, EditMode, Editor, Helper, Validator}; +use std::borrow::Cow; +use std::time::Instant; + +use crate::config::Config as DbConfig; +use crate::db; +use crate::executor::{Executor, OutputFormat}; + +/// REPL 状态 +pub struct ReplState { + config: DbConfig, + format: OutputFormat, + output: Option, +} + +impl ReplState { + pub fn new(config: DbConfig) -> Self { + Self { + config, + format: OutputFormat::Table, + output: None, + } + } + + fn prompt(&self) -> String { + let db = if self.config.database.is_empty() { + "(none)" + } else { + &self.config.database + }; + format!("mysql [{}]> ", db) + } +} + +/// 智能补全器 +#[derive(Completer, Helper, Validator)] +pub struct SqlCompleter { + sql_keywords: Vec<&'static str>, +} + +impl SqlCompleter { + pub fn new() -> Self { + Self { + sql_keywords: vec![ + "SELECT", "INSERT", "UPDATE", "DELETE", "DROP", + "CREATE", "ALTER", "SHOW", "DESCRIBE", "DESC", "USE", + "FROM", "WHERE", "ORDER", "GROUP", "LIMIT", "OFFSET", + "JOIN", "LEFT", "RIGHT", "INNER", "ON", "AS", + "AND", "OR", "NOT", "IN", "LIKE", "BETWEEN", + "BY", "TABLE", "DATABASE", "INDEX", "VIEW", + "SET", "VALUES", "INTO", "DISTINCT", "COUNT", + "SUM", "AVG", "MAX", "MIN", "NULL", "IS", + ], + } + } +} + +impl Hinter for SqlCompleter { + type Hint = String; + + fn hint(&self, line: &str, pos: usize, _ctx: &Context<'_>) -> Option { + if line.is_empty() || pos < line.len() { + return None; + } + + let line_upper = line.to_uppercase(); + for keyword in &self.sql_keywords { + if keyword.starts_with(&line_upper) && keyword.len() > line.len() { + return Some(keyword[line.len()..].to_string()); + } + } + None + } +} + +impl Highlighter for SqlCompleter { + fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { + Cow::Borrowed(line) + } + + fn highlight_char(&self, _line: &str, _pos: usize, _forced: bool) -> bool { + false + } +} + +/// 运行交互式 REPL +pub fn run_repl(pool: Pool, config: DbConfig) -> Result<()> { + let mut state = ReplState::new(config); + let executor = Executor::new(pool.clone(), state.format, state.output.clone()); + + // 打印欢迎信息 + print_welcome(&pool, &state.config)?; + + // 初始化 rustyline + let rl_config = Config::builder() + .history_ignore_space(true) + .completion_type(CompletionType::List) + .edit_mode(EditMode::Emacs) + .build(); + + let mut rl: Editor = Editor::with_config(rl_config)?; + rl.set_helper(Some(SqlCompleter::new())); + + loop { + let prompt = state.prompt(); + let readline = rl.readline(&prompt); + + match readline { + Ok(line) => { + let line = line.trim(); + if line.is_empty() { + continue; + } + + // 添加到历史记录 + let _ = rl.add_history_entry(line); + + // 处理命令 + if let Err(e) = handle_command(&pool, &mut state, &executor, line) { + eprintln!("Error: {}", e); + } + } + Err(ReadlineError::Interrupted) => { + println!("^C"); + continue; + } + Err(ReadlineError::Eof) => { + println!("Bye"); + break; + } + Err(err) => { + eprintln!("Error: {}", err); + break; + } + } + } + + Ok(()) +} + +/// 打印欢迎信息 +fn print_welcome(pool: &Pool, config: &DbConfig) -> Result<()> { + println!(); + println!("Welcome to MySQL CLI (Rust). Commands end with ;"); + + let mut conn = pool.get_conn()?; + let version = db::get_server_version(&mut conn); + println!("Server version: {}", version); + println!("Connected to: {}", config); + println!(); + println!("Commands:"); + println!(" use - Switch database"); + println!(" status - Show connection status"); + println!(" source - Execute SQL from file"); + println!(" format - Set output format (table, csv, json, vertical)"); + println!(" output - Set output file"); + println!(" exit/quit - Exit the program"); + println!(); + + Ok(()) +} + +/// 处理命令 +fn handle_command(pool: &Pool, state: &mut ReplState, executor: &Executor, line: &str) -> Result<()> { + let lower_line = line.to_lowercase(); + + // 处理特殊命令 + match lower_line.as_str() { + "exit" | "quit" => { + println!("Bye"); + std::process::exit(0); + } + "status" => { + return show_status(pool, state); + } + _ => {} + } + + if lower_line.starts_with("use ") { + return use_database(pool, state, &line[4..]); + } + + if lower_line.starts_with("source ") { + return source_file(executor, &line[7..]); + } + + if lower_line.starts_with("format ") { + return set_format(state, &line[7..]); + } + + if lower_line.starts_with("output ") { + return set_output(state, &line[7..]); + } + + // 执行 SQL + execute_sql(pool, state, line) +} + +/// 显示状态 +fn show_status(pool: &Pool, state: &ReplState) -> Result<()> { + println!(); + println!("Connection Status:"); + println!(" Host: {}:{}", state.config.host, state.config.port); + println!(" User: {}", state.config.username); + println!(" Database: {}", state.config.database); + + let mut conn = pool.get_conn()?; + let version = db::get_server_version(&mut conn); + println!(" Server Version: {}", version); + + if let Some(current_db) = db::get_current_database(&mut conn) { + println!(" Current Database: {}", current_db); + } + + println!(); + Ok(()) +} + +/// 切换数据库 +fn use_database(pool: &Pool, state: &mut ReplState, db_name: &str) -> Result<()> { + let db_name = db_name.trim().trim_end_matches(';').trim(); + + if db_name.is_empty() { + anyhow::bail!("database name is required"); + } + + let mut conn = pool.get_conn()?; + conn.query_iter(format!("USE `{}`", db_name))?; + + state.config.database = db_name.to_string(); + println!("Database changed"); + Ok(()) +} + +/// 从文件执行 SQL +fn source_file(executor: &Executor, filename: &str) -> Result<()> { + let filename = filename.trim().trim_end_matches(';').trim(); + + if filename.is_empty() { + anyhow::bail!("filename is required"); + } + + executor.execute_from_file(filename)?; + println!(); + Ok(()) +} + +/// 设置输出格式 +fn set_format(state: &mut ReplState, format: &str) -> Result<()> { + let format = format.trim().trim_end_matches(';').trim(); + state.format = format.parse()?; + println!("Output format set to {}", format); + println!(); + Ok(()) +} + +/// 设置输出文件 +fn set_output(state: &mut ReplState, output: &str) -> Result<()> { + let output = output.trim().trim_end_matches(';').trim(); + + if output.is_empty() { + state.output = None; + println!("Output reset to stdout"); + } else { + state.output = Some(output.to_string()); + println!("Output set to {}", output); + } + println!(); + Ok(()) +} + +/// 执行 SQL(支持多行) +fn execute_sql(pool: &Pool, state: &ReplState, line: &str) -> Result<()> { + let full_query = line.to_string(); + + let start = Instant::now(); + let executor = Executor::new(pool.clone(), state.format, state.output.clone()); + executor.execute(&full_query)?; + + let elapsed = start.elapsed(); + if elapsed.as_millis() > 100 { + println!(); + println!("({:.2} sec)", elapsed.as_secs_f64()); + } else { + println!(); + println!("({:.3} sec)", elapsed.as_secs_f64()); + } + + println!(); + Ok(()) +} diff --git a/mysql-proxy/Cargo.toml b/mysql-proxy/Cargo.toml new file mode 100644 index 0000000..a7f7fed --- /dev/null +++ b/mysql-proxy/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "mysql-proxy" +version = "0.1.0" +edition = "2021" +description = "MySQL HTTP proxy with connection pooling" + +[dependencies] +mysql = "25" +tokio = { version = "1", features = ["full"] } +axum = "0.7" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +toml = "0.8" +anyhow = "1" +clap = { version = "4", features = ["derive"] } +ureq = "2" # CLI HTTP client (2.x has simpler API) + +[profile.release] +opt-level = "z" +lto = true +strip = true diff --git a/mysql-proxy/src/cli.rs b/mysql-proxy/src/cli.rs new file mode 100644 index 0000000..9413161 --- /dev/null +++ b/mysql-proxy/src/cli.rs @@ -0,0 +1,287 @@ +use anyhow::{Result, bail}; +use serde::{Deserialize, Serialize}; +use std::time::Instant; + +const DEFAULT_SERVER: &str = "http://127.0.0.1:3307"; + +#[derive(Debug, Serialize)] +struct QueryRequest { + conn: String, + sql: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct QueryResponse { + columns: Vec, + rows: Vec>>, + #[serde(rename = "rowCount")] + row_count: usize, + #[serde(rename = "durationMs")] + duration_ms: u64, +} + +#[derive(Debug, Deserialize)] +struct ErrorResponse { + error: String, +} + +#[derive(Debug, Deserialize)] +struct ConnectionsResponse { + connections: Vec, +} + +#[derive(Debug, Deserialize)] +struct ConnectionInfo { + name: String, + database: String, + host: String, + status: String, +} + +/// CLI 客户端 +pub struct Cli { + server: String, +} + +impl Cli { + pub fn new(server: Option) -> Self { + Self { + server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()), + } + } + + /// 执行查询并输出结果 + pub fn query(&self, conn: &str, sql: &str, format: Option<&str>) -> Result<()> { + let start = Instant::now(); + + let body = serde_json::to_string(&QueryRequest { + conn: conn.to_string(), + sql: sql.to_string(), + })?; + + let response = ureq::post(&format!("{}/query", self.server)) + .set("Content-Type", "application/json") + .send_string(&body); + + // ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中 + let body = match response { + Ok(r) => r.into_string()?, + Err(ureq::Error::Status(_, resp)) => resp.into_string()?, + Err(e) => bail!("HTTP error: {}", e), + }; + + let total_ms = start.elapsed().as_millis(); + + // 解析响应 + if let Ok(err) = serde_json::from_str::(&body) { + eprintln!("Error: {}", err.error); + return Ok(()); + } + + let result: QueryResponse = serde_json::from_str(&body)?; + + // 输出结果 + match format { + Some("json") => self.print_json(&result), + Some("csv") => self.print_csv(&result), + Some("vertical") | Some("vert") => self.print_vertical(&result), + _ => self.print_table(&result), + } + + println!("\n{} rows in set ({}ms db, {}ms total)", + result.row_count, result.duration_ms, total_ms); + + Ok(()) + } + + /// 执行语句 (INSERT/UPDATE/DELETE) + pub fn execute(&self, conn: &str, sql: &str) -> Result<()> { + let start = Instant::now(); + + let body = serde_json::to_string(&QueryRequest { + conn: conn.to_string(), + sql: sql.to_string(), + })?; + + let response = ureq::post(&format!("{}/execute", self.server)) + .set("Content-Type", "application/json") + .send_string(&body); + + // ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中 + let body = match response { + Ok(r) => r.into_string()?, + Err(ureq::Error::Status(_, resp)) => resp.into_string()?, + Err(e) => bail!("HTTP error: {}", e), + }; + + let total_ms = start.elapsed().as_millis(); + + if let Ok(err) = serde_json::from_str::(&body) { + eprintln!("Error: {}", err.error); + return Ok(()); + } + + #[derive(Deserialize)] + struct ExecuteResponse { + #[serde(rename = "affectedRows")] + affected_rows: u64, + #[serde(rename = "lastInsertId")] + last_insert_id: u64, + } + + let result: ExecuteResponse = serde_json::from_str(&body)?; + println!("Query OK, {} rows affected ({}ms total)", result.affected_rows, total_ms); + + if result.last_insert_id > 0 { + println!("Last insert ID: {}", result.last_insert_id); + } + + Ok(()) + } + + /// 列出连接 + pub fn list_connections(&self) -> Result<()> { + let response = ureq::get(&format!("{}/connections", self.server)) + .call()?; + + let body = response.into_string()?; + + if let Ok(err) = serde_json::from_str::(&body) { + bail!("{}", err.error); + } + + let result: ConnectionsResponse = serde_json::from_str(&body)?; + + println!("Connections:"); + println!("{:<15} {:<20} {:<40} {:<10}", "Name", "Database", "Host", "Status"); + println!("{}", "-".repeat(85)); + + for conn in result.connections { + println!("{:<15} {:<20} {:<40} {:<10}", + conn.name, conn.database, conn.host, conn.status); + } + + Ok(()) + } + + /// 检查代理是否运行 + pub fn check_server(&self) -> Result { + match ureq::get(&format!("{}/health", self.server)).call() { + Ok(_) => Ok(true), + Err(_) => Ok(false), + } + } + + // 输出格式化方法 + fn print_table(&self, result: &QueryResponse) { + if result.columns.is_empty() { + return; + } + + // 计算列宽 + let mut widths: Vec = result.columns.iter().map(|c| c.len()).collect(); + + for row in &result.rows { + for (i, val) in row.iter().enumerate() { + let len = val.as_ref().map(|s| s.len()).unwrap_or(4); // NULL + widths[i] = widths[i].max(len); + } + } + + // 打印表头 + let header: String = result.columns.iter() + .enumerate() + .map(|(i, c)| format!(" {:width$} ", c, width = widths[i])) + .collect::>() + .join("|"); + + println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::>().join("+")); + println!("|{}|", header); + println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::>().join("+")); + + // 打印数据行 + for row in &result.rows { + let line: String = row.iter() + .enumerate() + .map(|(i, val)| { + match val { + Some(s) => format!(" {:width$} ", s, width = widths[i]), + None => format!(" {:width$} ", "NULL", width = widths[i]), + } + }) + .collect::>() + .join("|"); + println!("|{}|", line); + } + + println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::>().join("+")); + } + + fn print_json(&self, result: &QueryResponse) { + // 极简格式:对象数组 + let rows: Vec = result.rows.iter().map(|row| { + let mut obj = serde_json::Map::new(); + for (i, col) in result.columns.iter().enumerate() { + let val = row.get(i) + .and_then(|v| v.as_ref()) + .map(|s| { + // 尝试解析为数字 + if let Ok(n) = s.parse::() { + serde_json::Value::Number(n.into()) + } else if let Ok(n) = s.parse::() { + serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0))) + } else { + serde_json::Value::String(s.clone()) + } + }) + .unwrap_or(serde_json::Value::Null); + obj.insert(col.clone(), val); + } + serde_json::Value::Object(obj) + }).collect(); + + println!("{}", serde_json::to_string(&rows).unwrap_or_default()); + } + + fn print_csv(&self, result: &QueryResponse) { + // 打印表头 + println!("{}", result.columns.join(",")); + + // 打印数据 + for row in &result.rows { + let line: String = row.iter() + .map(|val| { + match val { + Some(s) => { + if s.contains(',') || s.contains('"') || s.contains('\n') { + format!("\"{}\"", s.replace('"', "\"\"")) + } else { + s.clone() + } + } + None => "".to_string(), + } + }) + .collect::>() + .join(","); + println!("{}", line); + } + } + + fn print_vertical(&self, result: &QueryResponse) { + for (row_idx, row) in result.rows.iter().enumerate() { + if row_idx > 0 { + println!(); + } + println!("*************************** {}. row ***************************", row_idx + 1); + + for (col_idx, col) in result.columns.iter().enumerate() { + let val = row.get(col_idx) + .and_then(|v| v.as_ref()) + .map(|s| s.as_str()) + .unwrap_or("NULL"); + println!("{:>20}: {}", col, val); + } + } + } +} \ No newline at end of file diff --git a/mysql-proxy/src/config.rs b/mysql-proxy/src/config.rs new file mode 100644 index 0000000..d1c4fd3 --- /dev/null +++ b/mysql-proxy/src/config.rs @@ -0,0 +1,156 @@ +use anyhow::{Result, bail}; +use serde::Deserialize; +use std::fs; + +#[derive(Debug, Deserialize, Clone)] +pub struct ServerConfig { + pub port: u16, + #[serde(default = "default_host")] + pub host: String, +} + +fn default_host() -> String { + "127.0.0.1".to_string() +} + +/// 连接池配置 +#[derive(Debug, Deserialize, Clone)] +pub struct PoolConfig { + /// 默认最大连接数 + #[serde(default = "default_max_connections")] + pub default_max_connections: usize, + /// 空闲超时 (秒) + #[serde(default = "default_idle_timeout")] + pub idle_timeout_secs: u64, + /// 检查间隔 (秒) + #[serde(default = "default_check_interval")] + pub check_interval_secs: u64, +} + +fn default_max_connections() -> usize { 5 } +fn default_idle_timeout() -> u64 { 300 } +fn default_check_interval() -> u64 { 60 } + +impl Default for PoolConfig { + fn default() -> Self { + Self { + default_max_connections: default_max_connections(), + idle_timeout_secs: default_idle_timeout(), + check_interval_secs: default_check_interval(), + } + } +} + +#[derive(Debug, Deserialize, Clone)] +pub struct ConnectionConfig { + pub name: String, + pub host: String, + pub port: u16, + pub user: String, + pub password: String, + pub database: String, + /// 最大连接数 (可选,覆盖默认值) + #[serde(default)] + pub max_connections: Option, +} + +#[derive(Debug, Deserialize)] +pub struct Config { + #[serde(default = "default_server")] + pub server: ServerConfig, + #[serde(default)] + pub pool: PoolConfig, + pub connections: Vec, +} + +fn default_server() -> ServerConfig { + ServerConfig { + port: 3307, + host: "127.0.0.1".to_string(), + } +} + +impl Config { + /// 从文件加载配置 + pub fn from_file(path: &str) -> Result { + let content = fs::read_to_string(path)?; + let config: Config = toml::from_str(&content)?; + config.validate()?; + Ok(config) + } + + /// 从默认路径加载配置 + pub fn load() -> Result { + // 尝试多个配置文件路径 + let paths = [ + "mysql-proxy.toml", + "./config/mysql-proxy.toml", + &format!("{}/.mysql-proxy.toml", std::env::var("HOME").unwrap_or_default()), + ]; + + for path in &paths { + if std::path::Path::new(path).exists() { + return Self::from_file(path); + } + } + + bail!("Config file not found. Create mysql-proxy.toml in current directory or use --config flag") + } + + /// 验证配置 + fn validate(&self) -> Result<()> { + let mut names = std::collections::HashSet::new(); + for conn in &self.connections { + if conn.name.is_empty() { + bail!("Connection name cannot be empty"); + } + if names.contains(&conn.name) { + bail!("Duplicate connection name: {}", conn.name); + } + names.insert(conn.name.clone()); + } + Ok(()) + } +} + +impl ConnectionConfig { + /// 构建 DSN + pub fn build_dsn(&self) -> String { + let password = urlencoding(&self.password); + format!( + "mysql://{}:{}@{}:{}/{}", + self.user, password, self.host, self.port, self.database + ) + } +} + +/// URL 编码密码中的特殊字符 +fn urlencoding(s: &str) -> String { + let mut result = String::new(); + for c in s.chars() { + match c { + '@' => result.push_str("%40"), + ':' => result.push_str("%3A"), + '/' => result.push_str("%2F"), + '?' => result.push_str("%3F"), + '#' => result.push_str("%23"), + '[' => result.push_str("%5B"), + ']' => result.push_str("%5D"), + '!' => result.push_str("%21"), + '$' => result.push_str("%24"), + '&' => result.push_str("%26"), + '\'' => result.push_str("%27"), + '(' => result.push_str("%28"), + ')' => result.push_str("%29"), + '*' => result.push_str("%2A"), + '+' => result.push_str("%2B"), + ',' => result.push_str("%2C"), + ';' => result.push_str("%3B"), + '=' => result.push_str("%3D"), + '%' => result.push_str("%25"), + ' ' => result.push_str("%20"), + _ => result.push(c), + } + } + result +} diff --git a/mysql-proxy/src/db.rs b/mysql-proxy/src/db.rs new file mode 100644 index 0000000..eb62cbc --- /dev/null +++ b/mysql-proxy/src/db.rs @@ -0,0 +1,173 @@ +use anyhow::{Result, bail}; +use mysql::{Pool, PooledConn, Opts,OptsBuilder, prelude::*}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use crate::config::{ConnectionConfig, PoolConfig}; + +/// 连接池状态 +struct PoolState { + pool: Arc, + last_used: Instant, +} + +/// 连接池管理器 (延迟初始化) +pub struct ConnectionManager { + pools: Mutex>, + configs: Mutex>, + pool_config: PoolConfig, +} + +impl ConnectionManager { + /// 创建连接管理器 (不初始化连接) + pub fn new(configs: &[ConnectionConfig], pool_config: PoolConfig) -> Result { + let mut config_map = HashMap::new(); + + for cfg in configs { + config_map.insert(cfg.name.clone(), cfg.clone()); + println!(" Registered: {} ({}/{})", cfg.name, cfg.host, cfg.database); + } + + println!("\n{} connection(s) configured (lazy init)", config_map.len()); + println!("Pool config: max_connections={}, idle_timeout={}s", + pool_config.default_max_connections, pool_config.idle_timeout_secs); + + Ok(Self { + pools: Mutex::new(HashMap::new()), + configs: Mutex::new(config_map), + pool_config, + }) + } + + /// 获取或创建连接 + pub fn get_conn(&self, name: &str) -> Result { + // 1. 先尝试从已有池中获取 + { + let mut pools = self.pools.lock().unwrap(); + if let Some(state) = pools.get_mut(name) { + state.last_used = Instant::now(); + return Ok(state.pool.get_conn()?); + } + } + + // 2. 没有则创建新池 + let cfg = self.configs.lock().unwrap().get(name) + .ok_or_else(|| anyhow::anyhow!("Connection '{}' not found", name))? + .clone(); + + println!("[LazyInit] Creating connection pool for: {}", name); + let pool = self.create_pool(&cfg)?; + let arc_pool = Arc::new(pool); + + // 3. 保存并返回连接 + { + let mut pools = self.pools.lock().unwrap(); + pools.insert(name.to_string(), PoolState { + pool: arc_pool.clone(), + last_used: Instant::now(), + }); + } + + Ok(arc_pool.get_conn()?) + } + + /// 创建连接池 + fn create_pool(&self, cfg: &ConnectionConfig) -> Result { + let dsn = cfg.build_dsn(); + let opts: Opts = Opts::from_url(&dsn)?; + + // 获取最大连接数 (单个配置 > 全局默认) + let max_conn = cfg.max_connections.unwrap_or(self.pool_config.default_max_connections); + + // 构建带连接池参数的选项 + let pool_opts = OptsBuilder::from_opts(opts) + .pool_opts(mysql::PoolOpts::new() + .with_constraints(mysql::PoolConstraints::new(1, max_conn).unwrap())); + + let pool = Pool::new(pool_opts)?; + + // 测试连接 + let mut conn = pool.get_conn()?; + conn.query::("SELECT 1")?; + + println!("[LazyInit] ✓ Connected: {} (max_conn={})", cfg.name, max_conn); + Ok(pool) + } + + /// 获取所有连接信息 + pub fn list_connections(&self) -> Vec { + let pools = self.pools.lock().unwrap(); + let configs = self.configs.lock().unwrap(); + + configs.iter().map(|(name, cfg)| { + let status = if pools.contains_key(name) { "connected" } else { "pending" }; + ConnectionInfo { + name: name.clone(), + database: cfg.database.clone(), + host: cfg.host.clone(), + status: status.to_string(), + } + }).collect() + } + + /// 清理空闲连接池 + pub fn cleanup_idle(&self) { + let mut pools = self.pools.lock().unwrap(); + let now = Instant::now(); + let idle_timeout = Duration::from_secs(self.pool_config.idle_timeout_secs); + + pools.retain(|name, state| { + let elapsed = now.duration_since(state.last_used); + if elapsed > idle_timeout { + println!("[Cleanup] Removing idle pool: {} (idle {}s)", name, elapsed.as_secs()); + return false; + } + true + }); + } + + /// 检查连接是否健康 + pub fn health_check(&self, name: &str) -> bool { + let pools = self.pools.lock().unwrap(); + if let Some(state) = pools.get(name) { + return state.pool.get_conn().ok() + .and_then(|mut c| c.query::("SELECT 1").ok()) + .is_some(); + } + false + } + + /// 动态添加连接配置 (临时,重启后消失) + pub fn add_connection(&self, cfg: ConnectionConfig) -> Result<()> { + let name = cfg.name.clone(); + let mut configs = self.configs.lock().unwrap(); + if configs.contains_key(&name) { + bail!("Connection '{}' already exists", name); + } + + println!("[Dynamic] Adding: {} ({}/{})", name, cfg.host, cfg.database); + + // 测试连接 + let pool = self.create_pool(&cfg)?; + let arc_pool = Arc::new(pool); + + // 保存 + self.pools.lock().unwrap().insert(name.clone(), PoolState { + pool: arc_pool, + last_used: Instant::now(), + }); + configs.insert(name.clone(), cfg); + + println!("[Dynamic] ✓ Added: {}", name); + Ok(()) + } +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct ConnectionInfo { + pub name: String, + pub database: String, + pub host: String, + pub status: String, +} diff --git a/mysql-proxy/src/handler.rs b/mysql-proxy/src/handler.rs new file mode 100644 index 0000000..964b77c --- /dev/null +++ b/mysql-proxy/src/handler.rs @@ -0,0 +1,320 @@ +use axum::{ + extract::State, + http::StatusCode, + Json, +}; +use mysql::{prelude::*, Value, Row as MysqlRow}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::time::Instant; + +use crate::db::ConnectionManager; +use crate::logger::{LogEntry, RequestLogger}; + +/// 应用状态 +pub type AppState = Arc<(Arc, Arc)>; + +// ============== 请求结构 ============== + +#[derive(Debug, Deserialize)] +pub struct QueryRequest { + /// 连接名称 + pub conn: String, + /// SQL 语句 + pub sql: String, + /// 输出格式: json, table, csv, vertical + #[serde(default = "default_format")] + pub format: String, +} + +fn default_format() -> String { + "json".to_string() +} + +#[derive(Debug, Deserialize)] +pub struct ExecuteRequest { + pub conn: String, + pub sql: String, +} + +#[derive(Debug, Deserialize)] +pub struct AddConnectionRequest { + pub name: String, + pub host: String, + pub port: u16, + pub user: String, + pub password: String, + pub database: String, +} + +// ============== 响应结构 ============== + +#[derive(Debug, Serialize)] +pub struct QueryResponse { + pub columns: Vec, + pub rows: Vec>>, + #[serde(rename = "rowCount")] + pub row_count: usize, + #[serde(rename = "durationMs")] + pub duration_ms: u64, +} + +#[derive(Debug, Serialize)] +pub struct ExecuteResponse { + #[serde(rename = "affectedRows")] + pub affected_rows: u64, + #[serde(rename = "lastInsertId")] + pub last_insert_id: u64, + #[serde(rename = "durationMs")] + pub duration_ms: u64, +} + +#[derive(Debug, Serialize)] +pub struct ConnectionsResponse { + pub connections: Vec, +} + +#[derive(Debug, Serialize)] +pub struct HealthResponse { + pub status: String, + pub connections: usize, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: String, + #[serde(rename = "usage")] + pub usage: Option, +} + +impl ErrorResponse { + pub fn new(msg: &str) -> Self { + Self { error: msg.to_string(), usage: None } + } + + pub fn with_usage(mut self, usage: &str) -> Self { + self.usage = Some(usage.to_string()); + self + } +} + +// ============== 处理器 ============== + +/// 查询处理器 +pub async fn query( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let (manager, logger) = &*state; + let start = Instant::now(); + + // 获取连接 + let mut conn = manager.get_conn(&req.conn) + .map_err(|e| { + let usage = "Usage: POST /query {\"conn\": \"connection_name\", \"sql\": \"SELECT ...\"}\n" + .to_string() + &format!("Available connections: {}", list_conn_names(&manager)); + error_response_with_usage(&e.to_string(), &usage) + })?; + + // 判断是否是查询语句 + let sql_upper = req.sql.trim().to_uppercase(); + let is_query = sql_upper.starts_with("SELECT") + || sql_upper.starts_with("SHOW") + || sql_upper.starts_with("DESCRIBE") + || sql_upper.starts_with("DESC ") + || sql_upper.starts_with("EXPLAIN") + || sql_upper.starts_with("WITH"); + + if !is_query { + let err = error_response_with_usage( + "Not a SELECT query. Use /execute for INSERT/UPDATE/DELETE", + "Usage:\n POST /query {\"conn\": \"name\", \"sql\": \"SELECT ...\"}\n POST /execute {\"conn\": \"name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}" + ); + return Err(err); + } + + // 执行查询 + let result = conn.query_iter(&req.sql) + .map_err(|e| { + let usage = format!("SQL Error: {}\n\nUsage: POST /query {{\"conn\": \"name\", \"sql\": \"SELECT ...\"}}", e); + error_response_with_usage(&e.to_string(), &usage) + })?; + + let mut columns: Vec = Vec::new(); + let mut data: Vec>> = Vec::new(); + + // 获取数据并从第一行提取列名 + for row_result in result { + let row = row_result.map_err(|e| error_response(&e.to_string()))?; + + // 从第一行获取列名 + if columns.is_empty() { + columns = row.columns() + .iter() + .map(|c| c.name_str().to_string()) + .collect(); + } + + let values = row_to_strings(&row, columns.len()); + data.push(values); + } + + let duration = start.elapsed(); + let row_count = data.len(); + + // 记录日志 + logger.log(&LogEntry::new("/query", "http") + .with_conn(&req.conn) + .with_sql(&req.sql) + .with_duration(duration.as_millis() as u64) + .with_rows(row_count)); + + Ok(Json(QueryResponse { + columns, + rows: data, + row_count, + duration_ms: duration.as_millis() as u64, + })) +} + +/// 执行处理器 (INSERT/UPDATE/DELETE) +pub async fn execute( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let (manager, logger) = &*state; + let start = Instant::now(); + + // 获取连接 + let mut conn = manager.get_conn(&req.conn) + .map_err(|e| { + let usage = "Usage: POST /execute {\"conn\": \"connection_name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}\n" + .to_string() + &format!("Available connections: {}", list_conn_names(&manager)); + error_response_with_usage(&e.to_string(), &usage) + })?; + + // 执行 + let result = conn.query_iter(&req.sql) + .map_err(|e| error_response(&e.to_string()))?; + + let duration = start.elapsed(); + let affected = result.affected_rows(); + let last_id = result.last_insert_id().unwrap_or(0); + + // 记录日志 + logger.log(&LogEntry::new("/execute", "http") + .with_conn(&req.conn) + .with_sql(&req.sql) + .with_duration(duration.as_millis() as u64)); + + Ok(Json(ExecuteResponse { + affected_rows: affected, + last_insert_id: last_id, + duration_ms: duration.as_millis() as u64, + })) +} + +/// 获取连接列表 +pub async fn connections( + State(state): State, +) -> Json { + let (manager, _) = &*state; + Json(ConnectionsResponse { + connections: manager.list_connections(), + }) +} + +/// 健康检查 +pub async fn health( + State(state): State, +) -> Json { + let (manager, _) = &*state; + Json(HealthResponse { + status: "ok".to_string(), + connections: manager.list_connections().len(), + }) +} + +/// 动态添加连接 +pub async fn add_connection( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let (manager, logger) = &*state; + use crate::config::ConnectionConfig; + + let cfg = ConnectionConfig { + name: req.name.clone(), + host: req.host, + port: req.port, + user: req.user, + password: req.password, + database: req.database, + max_connections: None, + }; + + manager.add_connection(cfg) + .map_err(|e| error_response(&e.to_string()))?; + + // 记录日志 + logger.log(&LogEntry::new("/connections/add", "http") + .with_conn(&req.name) + .with_duration(0)); + + Ok(Json(serde_json::json!({ + "success": true, + "message": format!("Connection '{}' added (temporary, will be lost on restart)", req.name) + }))) +} + +// ============== 辅助函数 ============== + +fn error_response(msg: &str) -> (StatusCode, Json) { + (StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg))) +} + +fn error_response_with_usage(msg: &str, usage: &str) -> (StatusCode, Json) { + (StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg).with_usage(usage))) +} + +fn list_conn_names(manager: &ConnectionManager) -> String { + manager.list_connections() + .iter() + .map(|c| c.name.clone()) + .collect::>() + .join(", ") +} + +/// 将 MySQL Row 转换为字符串向量 +fn row_to_strings(row: &MysqlRow, col_count: usize) -> Vec> { + (0..col_count) + .map(|i| { + match row.get::(i) { + Some(value) => value_to_string(&value), + None => None, + } + }) + .collect() +} + +/// 将 Value 转换为字符串 +fn value_to_string(value: &Value) -> Option { + match value { + Value::NULL => None, + Value::Bytes(bytes) => String::from_utf8(bytes.clone()).ok(), + Value::Int(i) => Some(i.to_string()), + Value::UInt(u) => Some(u.to_string()), + Value::Float(f) => Some(f.to_string()), + Value::Double(d) => Some(d.to_string()), + Value::Date(year, month, day, hour, min, sec, micro) => { + match (hour, min, sec, micro) { + (0, 0, 0, 0) => Some(format!("{:04}-{:02}-{:02}", year, month, day)), + _ => Some(format!("{:04}-{:02}-{:02} {:02}:{:02}:{:02}", year, month, day, hour, min, sec)), + } + } + Value::Time(neg, days, hours, minutes, seconds, microseconds) => { + let sign = if *neg { "-" } else { "" }; + Some(format!("{}{} {}:{:02}:{:02}.{:06}", sign, days, hours, minutes, seconds, microseconds)) + } + } +} diff --git a/mysql-proxy/src/logger.rs b/mysql-proxy/src/logger.rs new file mode 100644 index 0000000..f39207e --- /dev/null +++ b/mysql-proxy/src/logger.rs @@ -0,0 +1,228 @@ +use serde::{Deserialize, Serialize}; +use std::fs::{File, OpenOptions}; +use std::io::Write; +use std::path::PathBuf; +use std::sync::Mutex; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// 日志记录器 +pub struct RequestLogger { + log_file: Mutex>, + enabled: bool, +} + +impl RequestLogger { + pub fn new(log_path: Option<&str>) -> Self { + let (log_file, enabled) = if let Some(path) = log_path { + let path = expand_path(path); + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + match OpenOptions::new() + .create(true) + .append(true) + .open(&path) + { + Ok(file) => (Some(file), true), + Err(e) => { + eprintln!("[Logger] Failed to open log file {}: {}", path.display(), e); + (None, false) + } + } + } else { + (None, false) + }; + + Self { + log_file: Mutex::new(log_file), + enabled, + } + } + + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// 记录请求日志 + pub fn log(&self, entry: &LogEntry) { + if !self.enabled { + return; + } + + let json = match serde_json::to_string(entry) { + Ok(j) => j, + Err(e) => { + eprintln!("[Logger] Failed to serialize log: {}", e); + return; + } + }; + + if let Ok(mut file) = self.log_file.lock() { + if let Some(ref mut f) = *file { + let _ = writeln!(f, "{}", json); + } + } + } +} + +/// 日志条目 +#[derive(Debug, Serialize, Deserialize)] +pub struct LogEntry { + pub timestamp: String, + pub level: String, + #[serde(rename = "type")] + pub log_type: String, + pub client: String, + pub endpoint: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub conn: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub server: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sql: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub command: Option, + #[serde(rename = "durationMs")] + pub duration_ms: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub rows: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "exitCode")] + pub exit_code: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl LogEntry { + pub fn new(endpoint: &str, client: &str) -> Self { + Self { + timestamp: current_timestamp(), + level: "INFO".to_string(), + log_type: "request".to_string(), + client: client.to_string(), + endpoint: endpoint.to_string(), + conn: None, + server: None, + sql: None, + command: None, + duration_ms: 0, + rows: None, + exit_code: None, + error: None, + } + } + + pub fn with_conn(mut self, conn: &str) -> Self { + self.conn = Some(conn.to_string()); + self + } + + pub fn with_server(mut self, server: &str) -> Self { + self.server = Some(server.to_string()); + self + } + + pub fn with_sql(mut self, sql: &str) -> Self { + // 截断过长的 SQL + self.sql = Some(truncate_string(sql, 1000)); + self + } + + pub fn with_command(mut self, command: &str) -> Self { + // 截断过长的命令 + self.command = Some(truncate_string(command, 500)); + self + } + + pub fn with_duration(mut self, ms: u64) -> Self { + self.duration_ms = ms; + self + } + + pub fn with_rows(mut self, rows: usize) -> Self { + self.rows = Some(rows); + self + } + + pub fn with_exit_code(mut self, code: i32) -> Self { + self.exit_code = Some(code); + self + } + + pub fn with_error(mut self, error: &str) -> Self { + self.level = "ERROR".to_string(); + self.error = Some(error.to_string()); + self + } +} + +fn current_timestamp() -> String { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default(); + let secs = now.as_secs(); + let datetime = chrono_timestamp(secs); + format!("{}Z", datetime) +} + +fn chrono_timestamp(secs: u64) -> String { + let days = secs / 86400; + let remaining = secs % 86400; + let hours = remaining / 3600; + let minutes = (remaining % 3600) / 60; + let seconds = remaining % 60; + + // 从 1970-01-01 开始计算日期 + let mut year = 1970; + let mut days_left = days; + + loop { + let days_in_year = if is_leap_year(year) { 366 } else { 365 }; + if days_left < days_in_year { + break; + } + days_left -= days_in_year; + year += 1; + } + + let month_days = if is_leap_year(year) { + [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + } else { + [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + }; + + let mut month = 1; + for &days_in_month in &month_days { + if days_left < days_in_month { + break; + } + days_left -= days_in_month; + month += 1; + } + let day = days_left + 1; + + format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}", year, month, day, hours, minutes, seconds) +} + +fn is_leap_year(year: u64) -> bool { + (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0) +} + +fn truncate_string(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + s.to_string() + } else { + format!("{}... (truncated)", &s[..max_len]) + } +} + +fn expand_path(path: &str) -> PathBuf { + if path.starts_with('~') { + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_default(); + PathBuf::from(path.replacen('~', &home, 1)) + } else { + PathBuf::from(path) + } +} diff --git a/mysql-proxy/src/main.rs b/mysql-proxy/src/main.rs new file mode 100644 index 0000000..885988a --- /dev/null +++ b/mysql-proxy/src/main.rs @@ -0,0 +1,212 @@ +mod cli; +mod config; +mod db; +mod handler; +mod logger; + +use axum::{ + routing::{get, post}, + Router, +}; +use clap::{Parser, Subcommand}; +use std::sync::Arc; + +#[derive(Parser, Debug)] +#[command(name = "mysql-proxy")] +#[command(about = "MySQL HTTP proxy with connection pooling")] +struct Args { + #[command(subcommand)] + command: Option, + + /// Config file path + #[arg(short, long, default_value = "mysql-proxy.toml", global = true)] + config: String, + + /// Server URL (for CLI mode) + #[arg(short = 'S', long, default_value = "http://127.0.0.1:3307", global = true)] + server: String, +} + +#[derive(Debug, Subcommand)] +enum Commands { + /// Start HTTP server (default) + Server { + /// Port to listen on + #[arg(short = 'P', long)] + port: Option, + + /// Host to bind + #[arg(short = 'H', long)] + host: Option, + }, + + /// Execute SQL via proxy (for AI clients) + Cli { + /// Connection name + #[arg(short, long, default_value = "flux_dev")] + conn: String, + + /// SQL to execute (与 mysql 官方一致) + #[arg(short = 'e', long)] + sql: Option, + + /// Output format: table, json, csv, vertical + #[arg(short = 'F', long, default_value = "table")] + format: String, + + /// Execute INSERT/UPDATE/DELETE + #[arg(short = 'x', long)] + execute: bool, + }, + + /// List available connections + Connections, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + match args.command { + Some(Commands::Server { port, host }) => { + run_server(&args.config, port, host).await + } + Some(Commands::Cli { conn, sql, format, execute }) => { + run_cli(&args.server, &conn, sql.as_deref(), &format, execute) + } + Some(Commands::Connections) => { + list_connections(&args.server) + } + None => { + run_server(&args.config, None, None).await + } + } +} + +/// 启动 HTTP 服务器 +async fn run_server(config_path: &str, port: Option, host: Option) -> anyhow::Result<()> { + println!("MySQL HTTP Proxy v0.1.0\n"); + + // 加载配置 + let mut config = config::Config::from_file(config_path)?; + + // 命令行参数覆盖配置文件 + if let Some(port) = port { + config.server.port = port; + } + if let Some(host) = host { + config.server.host = host; + } + + // 初始化日志 + let log_path = std::env::var("MYSQL_PROXY_LOG").ok(); + let logger = Arc::new(logger::RequestLogger::new(log_path.as_deref())); + if logger.is_enabled() { + println!("Request logging: enabled"); + } + + // 初始化连接池 + println!("Initializing connection pools...\n"); + let manager = Arc::new(db::ConnectionManager::new(&config.connections, config.pool.clone())?); + + // 启动后台清理任务 + let manager_clone = manager.clone(); + let check_interval = config.pool.check_interval_secs; + tokio::spawn(async move { + let mut interval = tokio::time::interval( + std::time::Duration::from_secs(check_interval) + ); + loop { + interval.tick().await; + manager_clone.cleanup_idle(); + } + }); + + // 构建路由 + let app = Router::new() + .route("/query", post(handler::query)) + .route("/execute", post(handler::execute)) + .route("/connections", get(handler::connections)) + .route("/connections/add", post(handler::add_connection)) + .route("/health", get(handler::health)) + .with_state(Arc::new((manager, logger))); + + // 启动服务器 + let addr = format!("{}:{}", config.server.host, config.server.port); + let listener = tokio::net::TcpListener::bind(&addr).await?; + + println!("\nServer started at http://{}", addr); + println!("\nAPI Endpoints:"); + println!(" POST /query - Execute SELECT queries"); + println!(" POST /execute - Execute INSERT/UPDATE/DELETE"); + println!(" GET /connections - List all connections"); + println!(" POST /connections/add - Add connection (temporary)"); + println!(" GET /health - Health check"); + println!("\nCLI Usage:"); + println!(" mysql-proxy cli -c flux_dev -e \"SELECT 1\""); + println!(" mysql-proxy connections"); + + axum::serve(listener, app).await?; + + Ok(()) +} + +/// CLI 模式 +fn run_cli(server: &str, conn: &str, sql: Option<&str>, format: &str, execute: bool) -> anyhow::Result<()> { + let cli = cli::Cli::new(Some(server.to_string())); + + // 检查服务器是否运行 + if !cli.check_server()? { + eprintln!(); + eprintln!("=== mysql-proxy 未运行,请使用降级方案 ==="); + eprintln!(); + eprintln!("启动代理: mysql-proxy"); + eprintln!(); + eprintln!("或直接使用 mysql 命令:"); + eprintln!(" mysql -h -u -p -D -e \"\""); + eprintln!(); + eprintln!("配置文件中的连接信息见: mysql-proxy.toml"); + anyhow::bail!("Proxy server not running at {}", server); + } + + // 如果没有提供 SQL,从 stdin 读取 + let sql = match sql { + Some(s) => s.to_string(), + None => { + use std::io::{self, BufRead}; + let stdin = io::stdin(); + let mut lines = Vec::new(); + for line in stdin.lock().lines() { + lines.push(line?); + } + lines.join(" ") + } + }; + + if sql.trim().is_empty() { + anyhow::bail!("No SQL provided. Use -e \"SQL\" or pipe SQL via stdin"); + } + + if execute { + cli.execute(conn, &sql)?; + } else { + cli.query(conn, &sql, Some(format))?; + } + + Ok(()) +} + +/// 列出连接 +fn list_connections(server: &str) -> anyhow::Result<()> { + let cli = cli::Cli::new(Some(server.to_string())); + + if !cli.check_server()? { + eprintln!(); + eprintln!("=== mysql-proxy 未运行 ==="); + eprintln!("启动代理: mysql-proxy"); + eprintln!("配置文件: mysql-proxy.toml"); + anyhow::bail!("Proxy server not running at {}", server); + } + + cli.list_connections() +} diff --git a/ssh-proxy/Cargo.toml b/ssh-proxy/Cargo.toml new file mode 100644 index 0000000..d912f19 --- /dev/null +++ b/ssh-proxy/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "ssh-proxy" +version = "0.1.0" +edition = "2021" +description = "SSH HTTP proxy with session pooling" + +[dependencies] +ssh2 = "0.9" +tokio = { version = "1", features = ["full"] } +axum = "0.7" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +toml = "0.8" +anyhow = "1" +clap = { version = "4", features = ["derive"] } +ureq = "2" # CLI HTTP client (2.x has simpler API) + +[profile.release] +opt-level = "z" +lto = true +strip = true \ No newline at end of file diff --git a/ssh-proxy/src/cli.rs b/ssh-proxy/src/cli.rs new file mode 100644 index 0000000..ef9f3be --- /dev/null +++ b/ssh-proxy/src/cli.rs @@ -0,0 +1,218 @@ +use anyhow::{Result, bail}; +use serde::{Deserialize, Serialize}; +use std::time::Instant; + +const DEFAULT_SERVER: &str = "http://127.0.0.1:3308"; + +#[derive(Debug, Serialize)] +struct ExecRequest { + server: String, + command: String, +} + +#[derive(Debug, Deserialize)] +struct ExecResponse { + stdout: String, + stderr: String, + #[serde(rename = "exitCode")] + exit_code: i32, + #[serde(rename = "durationMs")] + duration_ms: u64, +} + +#[derive(Debug, Deserialize)] +struct ErrorResponse { + error: String, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ServersResponse { + servers: Vec, +} + +#[derive(Debug, Deserialize)] +struct ServerInfo { + name: String, + host: String, + port: u16, + user: String, + status: String, +} + +/// CLI 客户端 +pub struct Cli { + server: String, +} + +impl Cli { + pub fn new(server: Option) -> Self { + Self { + server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()), + } + } + + /// 执行远程命令 + pub fn exec(&self, server: &str, command: &str, format: &str) -> Result<()> { + let start = Instant::now(); + + let body = serde_json::to_string(&ExecRequest { + server: server.to_string(), + command: command.to_string(), + })?; + + let response = ureq::post(&format!("{}/exec", self.server)) + .set("Content-Type", "application/json") + .send_string(&body); + + // ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中 + let body = match response { + Ok(r) => r.into_string()?, + Err(ureq::Error::Status(_, resp)) => resp.into_string()?, + Err(e) => bail!("HTTP error: {}", e), + }; + + if let Ok(err) = serde_json::from_str::(&body) { + eprintln!("Error: {}", err.error); + if let Some(usage) = err.usage { + eprintln!("\n{}", usage); + } + return Ok(()); + } + + let result: ExecResponse = serde_json::from_str(&body)?; + + // 根据格式输出 + if format == "json" { + self.print_json(&result, start.elapsed().as_millis() as u64); + } else { + self.print_text(&result, start.elapsed().as_millis() as u64); + } + + Ok(()) + } + + fn print_text(&self, result: &ExecResponse, total_ms: u64) { + if !result.stdout.is_empty() { + print!("{}", result.stdout); + } + if !result.stderr.is_empty() { + eprint!("{}", result.stderr); + } + + eprintln!("\n--- exit: {}, {}ms db, {}ms total ---", + result.exit_code, result.duration_ms, total_ms); + } + + fn print_json(&self, result: &ExecResponse, total_ms: u64) { + #[derive(Serialize)] + struct Output { + #[serde(rename = "exitCode")] + exit_code: i32, + stdout: String, + stderr: String, + #[serde(rename = "durationMs")] + duration_ms: u64, + #[serde(rename = "totalMs")] + total_ms: u64, + } + + let output = Output { + exit_code: result.exit_code, + stdout: result.stdout.clone(), + stderr: result.stderr.clone(), + duration_ms: result.duration_ms, + total_ms, + }; + + match serde_json::to_string(&output) { + Ok(json) => println!("{}", json), + Err(e) => eprintln!("JSON error: {}", e), + } + } + + /// 列出服务器 + pub fn list_servers(&self) -> Result<()> { + let response = ureq::get(&format!("{}/servers", self.server)) + .call()?; + + let body = response.into_string()?; + + if let Ok(err) = serde_json::from_str::(&body) { + anyhow::bail!("{}", err.error); + } + + let result: ServersResponse = serde_json::from_str(&body)?; + + println!("Servers:"); + println!("{:<15} {:<30} {:<8} {:<10} {:<10}", "Name", "Host", "Port", "User", "Status"); + println!("{}", "-".repeat(73)); + + for srv in result.servers { + println!("{:<15} {:<30} {:<8} {:<10} {:<10}", + srv.name, srv.host, srv.port, srv.user, srv.status); + } + + Ok(()) + } + + /// 动态添加服务器 + pub fn add_server( + &self, + name: String, + host: String, + port: u16, + user: String, + password: Option, + private_key: Option, + ) -> Result<()> { + #[derive(Serialize)] + struct AddRequest { + name: String, + host: String, + port: u16, + user: String, + password: Option, + private_key: Option, + } + + let body = serde_json::to_string(&AddRequest { + name, + host, + port, + user, + password, + private_key, + })?; + + let response = ureq::post(&format!("{}/servers/add", self.server)) + .set("Content-Type", "application/json") + .send_string(&body)?; + + let body = response.into_string()?; + + if let Ok(err) = serde_json::from_str::(&body) { + anyhow::bail!("{}", err.error); + } + + #[derive(Deserialize)] + struct AddResponse { + success: bool, + message: String, + } + + let result: AddResponse = serde_json::from_str(&body)?; + println!("{}", result.message); + + Ok(()) + } + + /// 检查代理是否运行 + pub fn check_server(&self) -> Result { + match ureq::get(&format!("{}/health", self.server)).call() { + Ok(_) => Ok(true), + Err(_) => Ok(false), + } + } +} diff --git a/ssh-proxy/src/config.rs b/ssh-proxy/src/config.rs new file mode 100644 index 0000000..d89b4dc --- /dev/null +++ b/ssh-proxy/src/config.rs @@ -0,0 +1,112 @@ +use anyhow::{Result, bail}; +use serde::Deserialize; +use std::collections::HashSet; +use std::fs; +use std::path::PathBuf; + +#[derive(Debug, Deserialize, Clone)] +pub struct ServerConfig { + pub port: u16, + #[serde(default = "default_host")] + pub host: String, +} + +fn default_host() -> String { + "127.0.0.1".to_string() +} + +#[derive(Debug, Deserialize, Clone)] +pub struct PoolConfig { + #[serde(default = "default_idle_timeout")] + pub idle_timeout_secs: u64, + #[serde(default = "default_check_interval")] + pub check_interval_secs: u64, +} + +fn default_idle_timeout() -> u64 { 300 } +fn default_check_interval() -> u64 { 60 } + +impl Default for PoolConfig { + fn default() -> Self { + Self { + idle_timeout_secs: default_idle_timeout(), + check_interval_secs: default_check_interval(), + } + } +} + +#[derive(Debug, Deserialize, Clone)] +pub struct SshServerConfig { + pub name: String, + pub host: String, + #[serde(default = "default_ssh_port")] + pub port: u16, + pub user: String, + /// 私钥路径 (优先) + pub private_key: Option, + /// 密码 (备选) + pub password: Option, +} + +fn default_ssh_port() -> u16 { 22 } + +impl SshServerConfig { + /// 获取私钥路径 + pub fn get_private_key_path(&self) -> Option { + self.private_key.as_ref().map(|p| { + // 支持 ~ 路径展开 (Unix 和 Windows) + let path = if p.starts_with('~') { + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_default(); + p.replacen('~', &home, 1) + } else { + p.clone() + }; + PathBuf::from(path) + }) + } +} + +#[derive(Debug, Deserialize)] +pub struct Config { + #[serde(default = "default_server")] + pub server: ServerConfig, + #[serde(default)] + pub pool: PoolConfig, + pub servers: Vec, +} + +fn default_server() -> ServerConfig { + ServerConfig { + port: 3308, + host: "127.0.0.1".to_string(), + } +} + +impl Config { + pub fn from_file(path: &str) -> Result { + let content = fs::read_to_string(path)?; + let config: Config = toml::from_str(&content)?; + config.validate()?; + Ok(config) + } + + fn validate(&self) -> Result<()> { + let mut names = HashSet::new(); + for srv in &self.servers { + if srv.name.is_empty() { + bail!("Server name cannot be empty"); + } + if names.contains(&srv.name) { + bail!("Duplicate server name: {}", srv.name); + } + names.insert(srv.name.clone()); + + if srv.private_key.is_none() && srv.password.is_none() { + bail!("Server '{}' needs either private_key or password", srv.name); + } + } + Ok(()) + } +} \ No newline at end of file diff --git a/ssh-proxy/src/handler.rs b/ssh-proxy/src/handler.rs new file mode 100644 index 0000000..ad09a7b --- /dev/null +++ b/ssh-proxy/src/handler.rs @@ -0,0 +1,183 @@ +use axum::{extract::State, http::StatusCode, Json}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::time::Instant; + +use crate::session::SessionManager; +use crate::logger::{LogEntry, RequestLogger}; + +pub type AppState = Arc<(Arc, Arc)>; + +#[derive(Debug, Deserialize)] +pub struct ExecRequest { + pub server: String, + pub command: String, +} + +#[derive(Debug, Deserialize)] +pub struct AddServerRequest { + pub name: String, + pub host: String, + #[serde(default = "default_ssh_port")] + pub port: u16, + #[serde(default = "default_user")] + pub user: String, + pub password: Option, + pub private_key: Option, +} + +fn default_ssh_port() -> u16 { 22 } +fn default_user() -> String { "root".to_string() } + +#[derive(Debug, Serialize)] +pub struct ExecResponse { + pub stdout: String, + pub stderr: String, + #[serde(rename = "exitCode")] + pub exit_code: i32, + #[serde(rename = "durationMs")] + pub duration_ms: u64, +} + +#[derive(Debug, Serialize)] +pub struct ServersResponse { + pub servers: Vec, +} + +#[derive(Debug, Serialize)] +pub struct HealthResponse { + pub status: String, + pub servers: usize, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: String, + #[serde(rename = "usage")] + pub usage: Option, +} + +impl ErrorResponse { + pub fn new(msg: &str) -> Self { + Self { error: msg.to_string(), usage: None } + } + + pub fn with_usage(mut self, usage: &str) -> Self { + self.usage = Some(usage.to_string()); + self + } +} + +pub async fn exec( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let start = Instant::now(); + let (manager, logger) = &*state; + + let manager_clone = manager.clone(); + let server = req.server.clone(); + let command = req.command.clone(); + + let result = tokio::task::spawn_blocking(move || { + manager_clone.exec(&server, &command) + }) + .await + .map_err(|e| { + error_response_with_usage(&e.to_string(), "Internal error occurred") + })? + .map_err(|e| { + let usage = format!( + "Usage: POST /exec {{\"server\": \"server_name\", \"command\": \"your command\"}}\n\nAvailable servers: {}", + list_server_names(&manager) + ); + error_response_with_usage(&e.to_string(), &usage) + })?; + + let duration_ms = start.elapsed().as_millis() as u64; + + // 记录日志 + logger.log(&LogEntry::new("/exec", "http") + .with_server(&req.server) + .with_command(&req.command) + .with_duration(duration_ms) + .with_exit_code(result.exit_code)); + + Ok(Json(ExecResponse { + stdout: result.stdout, + stderr: result.stderr, + exit_code: result.exit_code, + duration_ms, + })) +} + +pub async fn servers( + State(state): State, +) -> Json { + let (manager, _) = &*state; + Json(ServersResponse { servers: manager.list_servers() }) +} + +pub async fn add_server( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let (manager, logger) = &*state; + use crate::config::SshServerConfig; + + // 验证必须有密码或私钥 + if req.password.is_none() && req.private_key.is_none() { + return Err(error_response_with_usage( + "Server requires either password or private_key", + "Usage: POST /servers/add {\"name\": \"myserver\", \"host\": \"192.168.1.100\", \"user\": \"root\", \"password\": \"secret\"}\n or: {\"name\": \"myserver\", \"host\": \"...\", \"private_key\": \"~/.ssh/id_rsa\"}" + )); + } + + let cfg = SshServerConfig { + name: req.name.clone(), + host: req.host, + port: req.port, + user: req.user, + password: req.password, + private_key: req.private_key, + }; + + manager.add_server(cfg) + .map_err(|e| error_response(&e.to_string()))?; + + // 记录日志 + logger.log(&LogEntry::new("/servers/add", "http") + .with_server(&req.name) + .with_duration(0)); + + Ok(Json(serde_json::json!({ + "success": true, + "message": format!("Server '{}' added (temporary, will be lost on restart)", req.name) + }))) +} + +pub async fn health( + State(state): State, +) -> Json { + let (manager, _) = &*state; + Json(HealthResponse { + status: "ok".to_string(), + servers: manager.list_servers().len(), + }) +} + +fn error_response(msg: &str) -> (StatusCode, Json) { + (StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg))) +} + +fn error_response_with_usage(msg: &str, usage: &str) -> (StatusCode, Json) { + (StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg).with_usage(usage))) +} + +fn list_server_names(manager: &SessionManager) -> String { + manager.list_servers() + .iter() + .map(|s| s.name.clone()) + .collect::>() + .join(", ") +} diff --git a/ssh-proxy/src/logger.rs b/ssh-proxy/src/logger.rs new file mode 100644 index 0000000..f39207e --- /dev/null +++ b/ssh-proxy/src/logger.rs @@ -0,0 +1,228 @@ +use serde::{Deserialize, Serialize}; +use std::fs::{File, OpenOptions}; +use std::io::Write; +use std::path::PathBuf; +use std::sync::Mutex; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// 日志记录器 +pub struct RequestLogger { + log_file: Mutex>, + enabled: bool, +} + +impl RequestLogger { + pub fn new(log_path: Option<&str>) -> Self { + let (log_file, enabled) = if let Some(path) = log_path { + let path = expand_path(path); + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + match OpenOptions::new() + .create(true) + .append(true) + .open(&path) + { + Ok(file) => (Some(file), true), + Err(e) => { + eprintln!("[Logger] Failed to open log file {}: {}", path.display(), e); + (None, false) + } + } + } else { + (None, false) + }; + + Self { + log_file: Mutex::new(log_file), + enabled, + } + } + + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// 记录请求日志 + pub fn log(&self, entry: &LogEntry) { + if !self.enabled { + return; + } + + let json = match serde_json::to_string(entry) { + Ok(j) => j, + Err(e) => { + eprintln!("[Logger] Failed to serialize log: {}", e); + return; + } + }; + + if let Ok(mut file) = self.log_file.lock() { + if let Some(ref mut f) = *file { + let _ = writeln!(f, "{}", json); + } + } + } +} + +/// 日志条目 +#[derive(Debug, Serialize, Deserialize)] +pub struct LogEntry { + pub timestamp: String, + pub level: String, + #[serde(rename = "type")] + pub log_type: String, + pub client: String, + pub endpoint: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub conn: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub server: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sql: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub command: Option, + #[serde(rename = "durationMs")] + pub duration_ms: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub rows: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "exitCode")] + pub exit_code: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl LogEntry { + pub fn new(endpoint: &str, client: &str) -> Self { + Self { + timestamp: current_timestamp(), + level: "INFO".to_string(), + log_type: "request".to_string(), + client: client.to_string(), + endpoint: endpoint.to_string(), + conn: None, + server: None, + sql: None, + command: None, + duration_ms: 0, + rows: None, + exit_code: None, + error: None, + } + } + + pub fn with_conn(mut self, conn: &str) -> Self { + self.conn = Some(conn.to_string()); + self + } + + pub fn with_server(mut self, server: &str) -> Self { + self.server = Some(server.to_string()); + self + } + + pub fn with_sql(mut self, sql: &str) -> Self { + // 截断过长的 SQL + self.sql = Some(truncate_string(sql, 1000)); + self + } + + pub fn with_command(mut self, command: &str) -> Self { + // 截断过长的命令 + self.command = Some(truncate_string(command, 500)); + self + } + + pub fn with_duration(mut self, ms: u64) -> Self { + self.duration_ms = ms; + self + } + + pub fn with_rows(mut self, rows: usize) -> Self { + self.rows = Some(rows); + self + } + + pub fn with_exit_code(mut self, code: i32) -> Self { + self.exit_code = Some(code); + self + } + + pub fn with_error(mut self, error: &str) -> Self { + self.level = "ERROR".to_string(); + self.error = Some(error.to_string()); + self + } +} + +fn current_timestamp() -> String { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default(); + let secs = now.as_secs(); + let datetime = chrono_timestamp(secs); + format!("{}Z", datetime) +} + +fn chrono_timestamp(secs: u64) -> String { + let days = secs / 86400; + let remaining = secs % 86400; + let hours = remaining / 3600; + let minutes = (remaining % 3600) / 60; + let seconds = remaining % 60; + + // 从 1970-01-01 开始计算日期 + let mut year = 1970; + let mut days_left = days; + + loop { + let days_in_year = if is_leap_year(year) { 366 } else { 365 }; + if days_left < days_in_year { + break; + } + days_left -= days_in_year; + year += 1; + } + + let month_days = if is_leap_year(year) { + [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + } else { + [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] + }; + + let mut month = 1; + for &days_in_month in &month_days { + if days_left < days_in_month { + break; + } + days_left -= days_in_month; + month += 1; + } + let day = days_left + 1; + + format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}", year, month, day, hours, minutes, seconds) +} + +fn is_leap_year(year: u64) -> bool { + (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0) +} + +fn truncate_string(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + s.to_string() + } else { + format!("{}... (truncated)", &s[..max_len]) + } +} + +fn expand_path(path: &str) -> PathBuf { + if path.starts_with('~') { + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_default(); + PathBuf::from(path.replacen('~', &home, 1)) + } else { + PathBuf::from(path) + } +} diff --git a/ssh-proxy/src/main.rs b/ssh-proxy/src/main.rs new file mode 100644 index 0000000..34cfb0c --- /dev/null +++ b/ssh-proxy/src/main.rs @@ -0,0 +1,170 @@ +mod cli; +mod config; +mod handler; +mod logger; +mod session; + +use axum::{routing::{get, post}, Router}; +use clap::{Parser, Subcommand}; +use std::sync::Arc; + +#[derive(Parser, Debug)] +#[command(name = "ssh-proxy")] +#[command(about = "SSH HTTP proxy with session pooling")] +struct Args { + #[command(subcommand)] + command: Option, + #[arg(long, default_value = "ssh-proxy.toml", global = true)] + config: String, + #[arg(short = 'S', long, default_value = "http://127.0.0.1:3308", global = true)] + server: String, +} + +#[derive(Debug, Subcommand)] +enum Commands { + Server { + #[arg(short = 'P', long)] + port: Option, + #[arg(short = 'H', long)] + host: Option, + }, + Exec { + #[arg(short = 'n', long)] + name: String, + #[arg(short, long)] + command: String, + /// Output format: text, json + #[arg(short = 'F', long, default_value = "text")] + format: String, + }, + Servers, + /// Add server dynamically (temporary) + AddServer { + #[arg(short = 'n', long)] + name: String, + #[arg(short = 'H', long)] + host: String, + #[arg(short = 'P', long, default_value = "22")] + port: u16, + #[arg(short = 'u', long, default_value = "root")] + user: String, + #[arg(short = 'p', long)] + password: Option, + #[arg(short = 'k', long)] + private_key: Option, + }, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + match args.command { + Some(Commands::Server { port, host }) => run_server(&args.config, port, host).await, + Some(Commands::Exec { name, command, format }) => run_cli_exec(&args.server, &name, &command, &format), + Some(Commands::Servers) => run_cli_servers(&args.server), + Some(Commands::AddServer { name, host, port, user, password, private_key }) => { + run_cli_add_server(&args.server, name, host, port, user, password, private_key) + } + None => run_server(&args.config, None, None).await, + } +} + +async fn run_server(config_path: &str, port: Option, host: Option) -> anyhow::Result<()> { + println!("SSH HTTP Proxy v0.1.0\n"); + let mut config = config::Config::from_file(config_path)?; + if let Some(p) = port { config.server.port = p; } + if let Some(h) = host { config.server.host = h; } + + // 初始化日志 + let log_path = std::env::var("SSH_PROXY_LOG").ok(); + let logger = Arc::new(logger::RequestLogger::new(log_path.as_deref())); + if logger.is_enabled() { + println!("Request logging: enabled"); + } + + println!("Initializing SSH sessions...\n"); + let manager = Arc::new(session::SessionManager::new(&config.servers)?); + + let manager_clone = manager.clone(); + let idle_timeout = config.pool.idle_timeout_secs; + let check_interval = config.pool.check_interval_secs; + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(check_interval)); + loop { + interval.tick().await; + manager_clone.cleanup_idle(idle_timeout); + } + }); + + let app = Router::new() + .route("/exec", post(handler::exec)) + .route("/servers", get(handler::servers)) + .route("/servers/add", post(handler::add_server)) + .route("/health", get(handler::health)) + .with_state(Arc::new((manager, logger))); + + let addr = format!("{}:{}", config.server.host, config.server.port); + let listener = tokio::net::TcpListener::bind(&addr).await?; + + println!("\nServer started at http://{}", addr); + println!("\nAPI Endpoints:"); + println!(" POST /exec - Execute remote command"); + println!(" POST /servers/add - Add server (temporary)"); + println!(" GET /servers - List all servers"); + println!(" GET /health - Health check"); + println!("\nCLI Usage:"); + println!(" ssh-proxy exec -n flux_dev -c \"docker ps\""); + println!(" ssh-proxy exec -n flux_dev -c \"docker ps\" -F json"); + println!(" ssh-proxy servers"); + + axum::serve(listener, app).await?; + Ok(()) +} + +fn run_cli_exec(server_url: &str, server: &str, command: &str, format: &str) -> anyhow::Result<()> { + let cli = cli::Cli::new(Some(server_url.to_string())); + if !cli.check_server()? { + eprintln!(); + eprintln!("=== ssh-proxy 未运行,请使用降级方案 ==="); + eprintln!(); + eprintln!("启动代理: ssh-proxy"); + eprintln!(); + eprintln!("或直接使用 ssh 命令:"); + eprintln!(" ssh @ \"\""); + eprintln!(); + eprintln!("配置文件中的服务器信息见: ssh-proxy.toml"); + anyhow::bail!("Proxy server not running at {}", server_url); + } + cli.exec(server, command, format) +} + +fn run_cli_servers(server_url: &str) -> anyhow::Result<()> { + let cli = cli::Cli::new(Some(server_url.to_string())); + if !cli.check_server()? { + eprintln!(); + eprintln!("=== ssh-proxy 未运行 ==="); + eprintln!("启动代理: ssh-proxy"); + eprintln!("配置文件: ssh-proxy.toml"); + anyhow::bail!("Proxy server not running at {}", server_url); + } + cli.list_servers() +} + +fn run_cli_add_server( + server_url: &str, + name: String, + host: String, + port: u16, + user: String, + password: Option, + private_key: Option, +) -> anyhow::Result<()> { + let cli = cli::Cli::new(Some(server_url.to_string())); + if !cli.check_server()? { + eprintln!(); + eprintln!("=== ssh-proxy 未运行 ==="); + eprintln!("启动代理: ssh-proxy"); + anyhow::bail!("Proxy server not running at {}", server_url); + } + cli.add_server(name, host, port, user, password, private_key) +} \ No newline at end of file diff --git a/ssh-proxy/src/session.rs b/ssh-proxy/src/session.rs new file mode 100644 index 0000000..a8e55e5 --- /dev/null +++ b/ssh-proxy/src/session.rs @@ -0,0 +1,192 @@ +use anyhow::{Result, bail}; +use ssh2::Session; +use std::collections::HashMap; +use std::net::TcpStream; +use std::sync::Mutex; +use std::time::{Duration, Instant}; +use std::io::Read; +use std::path::PathBuf; + +use crate::config::SshServerConfig; + +struct SessionState { + session: Session, + last_used: Instant, +} + +pub struct SessionManager { + sessions: Mutex>, + configs: Mutex>, +} + +impl SessionManager { + pub fn new(configs: &[SshServerConfig]) -> Result { + let mut config_map = HashMap::new(); + for cfg in configs { + config_map.insert(cfg.name.clone(), cfg.clone()); + println!(" Registered: {} ({}:{})", cfg.name, cfg.host, cfg.port); + } + println!("\n{} server(s) configured (lazy init)", config_map.len()); + Ok(Self { + sessions: Mutex::new(HashMap::new()), + configs: Mutex::new(config_map), + }) + } + + fn get_or_create_session(&self, name: &str) -> Result { + { + let mut sessions = self.sessions.lock().unwrap(); + if let Some(state) = sessions.get_mut(name) { + if state.session.authenticated() { + state.last_used = Instant::now(); + return Ok(state.session.clone()); + } else { + println!("[Session] Session expired, reconnecting: {}", name); + sessions.remove(name); + } + } + } + + let cfg = self.configs.lock().unwrap() + .get(name) + .ok_or_else(|| anyhow::anyhow!("Server '{}' not found", name))? + .clone(); + + println!("[LazyInit] Connecting to: {}", name); + let session = self.create_session(&cfg)?; + + { + let mut sessions = self.sessions.lock().unwrap(); + sessions.insert(name.to_string(), SessionState { + session: session.clone(), + last_used: Instant::now(), + }); + } + + println!("[LazyInit] Connected: {}", name); + Ok(session) + } + + fn create_session(&self, cfg: &SshServerConfig) -> Result { + let addr = format!("{}:{}", cfg.host, cfg.port); + let tcp = TcpStream::connect(&addr) + .map_err(|e| anyhow::anyhow!("Failed to connect to {}: {}", addr, e))?; + + let mut session = Session::new()?; + session.set_tcp_stream(tcp); + session.handshake()?; + + if let Some(key_path) = cfg.get_private_key_path() { + let pubkey_path: PathBuf = key_path.with_extension("pub"); + session.userauth_pubkey_file(&cfg.user, Some(&pubkey_path), &key_path, None)?; + } else if let Some(ref password) = cfg.password { + session.userauth_password(&cfg.user, password)?; + } else { + bail!("No authentication method configured"); + } + + if !session.authenticated() { + bail!("SSH authentication failed"); + } + + Ok(session) + } + + pub fn exec(&self, name: &str, command: &str) -> Result { + let start = Instant::now(); + let session = self.get_or_create_session(name)?; + + let mut channel = session.channel_session()?; + channel.exec(command)?; + + let mut stdout = String::new(); + let mut stderr = String::new(); + if let Err(e) = channel.read_to_string(&mut stdout) { + eprintln!("[SSH] stdout read error: {}", e); + } + if let Err(e) = channel.stderr().read_to_string(&mut stderr) { + eprintln!("[SSH] stderr read error: {}", e); + } + + channel.wait_close()?; + let exit_code = channel.exit_status()?; + + Ok(ExecResult { + stdout, + stderr, + exit_code, + duration_ms: start.elapsed().as_millis() as u64, + }) + } + + pub fn list_servers(&self) -> Vec { + let sessions = self.sessions.lock().unwrap(); + let configs = self.configs.lock().unwrap(); + configs.iter().map(|(name, cfg)| { + let status = if sessions.contains_key(name) { "connected" } else { "pending" }; + ServerInfo { + name: name.clone(), + host: cfg.host.clone(), + port: cfg.port, + user: cfg.user.clone(), + status: status.to_string(), + } + }).collect() + } + + /// 动态添加服务器 (临时,重启后消失) + pub fn add_server(&self, cfg: SshServerConfig) -> Result<()> { + let name = cfg.name.clone(); + let mut configs = self.configs.lock().unwrap(); + if configs.contains_key(&name) { + anyhow::bail!("Server '{}' already exists", name); + } + + println!("[Dynamic] Adding: {} ({}:{})", name, cfg.host, cfg.port); + + // 测试连接 + let session = self.create_session(&cfg)?; + { + let mut sessions = self.sessions.lock().unwrap(); + sessions.insert(name.clone(), SessionState { + session, + last_used: Instant::now(), + }); + } + + configs.insert(name.clone(), cfg); + println!("[Dynamic] ✓ Added: {}", name); + Ok(()) + } + + pub fn cleanup_idle(&self, timeout_secs: u64) { + let mut sessions = self.sessions.lock().unwrap(); + let now = Instant::now(); + sessions.retain(|name, state| { + let elapsed = now.duration_since(state.last_used); + if elapsed > Duration::from_secs(timeout_secs) { + println!("[Cleanup] Removing idle session: {} (idle {}s)", name, elapsed.as_secs()); + false + } else { + true + } + }); + } +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct ExecResult { + pub stdout: String, + pub stderr: String, + pub exit_code: i32, + pub duration_ms: u64, +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct ServerInfo { + pub name: String, + pub host: String, + pub port: u16, + pub user: String, + pub status: String, +} \ No newline at end of file