新增: MySQL/SSH 代理工具

- mysql-proxy: MySQL HTTP 代理,连接池复用
- ssh-proxy: SSH HTTP 代理,会话复用
- mysql-cli: 轻量级 MySQL CLI 工具

功能特性:
- 延迟初始化,启动快
- CLI 和 HTTP API 双模式
- 请求日志支持
- 错误友好提示
- JSON 极简输出格式
This commit is contained in:
2026-03-19 14:03:12 +08:00
commit 11203f036f
24 changed files with 3794 additions and 0 deletions

21
mysql-proxy/Cargo.toml Normal file
View File

@@ -0,0 +1,21 @@
[package]
name = "mysql-proxy"
version = "0.1.0"
edition = "2021"
description = "MySQL HTTP proxy with connection pooling"
[dependencies]
mysql = "25"
tokio = { version = "1", features = ["full"] }
axum = "0.7"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
anyhow = "1"
clap = { version = "4", features = ["derive"] }
ureq = "2" # CLI HTTP client (2.x has simpler API)
[profile.release]
opt-level = "z"
lto = true
strip = true

287
mysql-proxy/src/cli.rs Normal file
View File

@@ -0,0 +1,287 @@
use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};
use std::time::Instant;
const DEFAULT_SERVER: &str = "http://127.0.0.1:3307";
#[derive(Debug, Serialize)]
struct QueryRequest {
conn: String,
sql: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct QueryResponse {
columns: Vec<String>,
rows: Vec<Vec<Option<String>>>,
#[serde(rename = "rowCount")]
row_count: usize,
#[serde(rename = "durationMs")]
duration_ms: u64,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: String,
}
#[derive(Debug, Deserialize)]
struct ConnectionsResponse {
connections: Vec<ConnectionInfo>,
}
#[derive(Debug, Deserialize)]
struct ConnectionInfo {
name: String,
database: String,
host: String,
status: String,
}
/// CLI 客户端
pub struct Cli {
server: String,
}
impl Cli {
pub fn new(server: Option<String>) -> Self {
Self {
server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()),
}
}
/// 执行查询并输出结果
pub fn query(&self, conn: &str, sql: &str, format: Option<&str>) -> Result<()> {
let start = Instant::now();
let body = serde_json::to_string(&QueryRequest {
conn: conn.to_string(),
sql: sql.to_string(),
})?;
let response = ureq::post(&format!("{}/query", self.server))
.set("Content-Type", "application/json")
.send_string(&body);
// ureq 2.x: 非 2xx 会返回 Err但响应体在 error 中
let body = match response {
Ok(r) => r.into_string()?,
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
Err(e) => bail!("HTTP error: {}", e),
};
let total_ms = start.elapsed().as_millis();
// 解析响应
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
eprintln!("Error: {}", err.error);
return Ok(());
}
let result: QueryResponse = serde_json::from_str(&body)?;
// 输出结果
match format {
Some("json") => self.print_json(&result),
Some("csv") => self.print_csv(&result),
Some("vertical") | Some("vert") => self.print_vertical(&result),
_ => self.print_table(&result),
}
println!("\n{} rows in set ({}ms db, {}ms total)",
result.row_count, result.duration_ms, total_ms);
Ok(())
}
/// 执行语句 (INSERT/UPDATE/DELETE)
pub fn execute(&self, conn: &str, sql: &str) -> Result<()> {
let start = Instant::now();
let body = serde_json::to_string(&QueryRequest {
conn: conn.to_string(),
sql: sql.to_string(),
})?;
let response = ureq::post(&format!("{}/execute", self.server))
.set("Content-Type", "application/json")
.send_string(&body);
// ureq 2.x: 非 2xx 会返回 Err但响应体在 error 中
let body = match response {
Ok(r) => r.into_string()?,
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
Err(e) => bail!("HTTP error: {}", e),
};
let total_ms = start.elapsed().as_millis();
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
eprintln!("Error: {}", err.error);
return Ok(());
}
#[derive(Deserialize)]
struct ExecuteResponse {
#[serde(rename = "affectedRows")]
affected_rows: u64,
#[serde(rename = "lastInsertId")]
last_insert_id: u64,
}
let result: ExecuteResponse = serde_json::from_str(&body)?;
println!("Query OK, {} rows affected ({}ms total)", result.affected_rows, total_ms);
if result.last_insert_id > 0 {
println!("Last insert ID: {}", result.last_insert_id);
}
Ok(())
}
/// 列出连接
pub fn list_connections(&self) -> Result<()> {
let response = ureq::get(&format!("{}/connections", self.server))
.call()?;
let body = response.into_string()?;
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
bail!("{}", err.error);
}
let result: ConnectionsResponse = serde_json::from_str(&body)?;
println!("Connections:");
println!("{:<15} {:<20} {:<40} {:<10}", "Name", "Database", "Host", "Status");
println!("{}", "-".repeat(85));
for conn in result.connections {
println!("{:<15} {:<20} {:<40} {:<10}",
conn.name, conn.database, conn.host, conn.status);
}
Ok(())
}
/// 检查代理是否运行
pub fn check_server(&self) -> Result<bool> {
match ureq::get(&format!("{}/health", self.server)).call() {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
// 输出格式化方法
fn print_table(&self, result: &QueryResponse) {
if result.columns.is_empty() {
return;
}
// 计算列宽
let mut widths: Vec<usize> = result.columns.iter().map(|c| c.len()).collect();
for row in &result.rows {
for (i, val) in row.iter().enumerate() {
let len = val.as_ref().map(|s| s.len()).unwrap_or(4); // NULL
widths[i] = widths[i].max(len);
}
}
// 打印表头
let header: String = result.columns.iter()
.enumerate()
.map(|(i, c)| format!(" {:width$} ", c, width = widths[i]))
.collect::<Vec<_>>()
.join("|");
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
println!("|{}|", header);
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
// 打印数据行
for row in &result.rows {
let line: String = row.iter()
.enumerate()
.map(|(i, val)| {
match val {
Some(s) => format!(" {:width$} ", s, width = widths[i]),
None => format!(" {:width$} ", "NULL", width = widths[i]),
}
})
.collect::<Vec<_>>()
.join("|");
println!("|{}|", line);
}
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
}
fn print_json(&self, result: &QueryResponse) {
// 极简格式:对象数组
let rows: Vec<serde_json::Value> = result.rows.iter().map(|row| {
let mut obj = serde_json::Map::new();
for (i, col) in result.columns.iter().enumerate() {
let val = row.get(i)
.and_then(|v| v.as_ref())
.map(|s| {
// 尝试解析为数字
if let Ok(n) = s.parse::<i64>() {
serde_json::Value::Number(n.into())
} else if let Ok(n) = s.parse::<f64>() {
serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)))
} else {
serde_json::Value::String(s.clone())
}
})
.unwrap_or(serde_json::Value::Null);
obj.insert(col.clone(), val);
}
serde_json::Value::Object(obj)
}).collect();
println!("{}", serde_json::to_string(&rows).unwrap_or_default());
}
fn print_csv(&self, result: &QueryResponse) {
// 打印表头
println!("{}", result.columns.join(","));
// 打印数据
for row in &result.rows {
let line: String = row.iter()
.map(|val| {
match val {
Some(s) => {
if s.contains(',') || s.contains('"') || s.contains('\n') {
format!("\"{}\"", s.replace('"', "\"\""))
} else {
s.clone()
}
}
None => "".to_string(),
}
})
.collect::<Vec<_>>()
.join(",");
println!("{}", line);
}
}
fn print_vertical(&self, result: &QueryResponse) {
for (row_idx, row) in result.rows.iter().enumerate() {
if row_idx > 0 {
println!();
}
println!("*************************** {}. row ***************************", row_idx + 1);
for (col_idx, col) in result.columns.iter().enumerate() {
let val = row.get(col_idx)
.and_then(|v| v.as_ref())
.map(|s| s.as_str())
.unwrap_or("NULL");
println!("{:>20}: {}", col, val);
}
}
}
}

156
mysql-proxy/src/config.rs Normal file
View File

@@ -0,0 +1,156 @@
use anyhow::{Result, bail};
use serde::Deserialize;
use std::fs;
#[derive(Debug, Deserialize, Clone)]
pub struct ServerConfig {
pub port: u16,
#[serde(default = "default_host")]
pub host: String,
}
fn default_host() -> String {
"127.0.0.1".to_string()
}
/// 连接池配置
#[derive(Debug, Deserialize, Clone)]
pub struct PoolConfig {
/// 默认最大连接数
#[serde(default = "default_max_connections")]
pub default_max_connections: usize,
/// 空闲超时 (秒)
#[serde(default = "default_idle_timeout")]
pub idle_timeout_secs: u64,
/// 检查间隔 (秒)
#[serde(default = "default_check_interval")]
pub check_interval_secs: u64,
}
fn default_max_connections() -> usize { 5 }
fn default_idle_timeout() -> u64 { 300 }
fn default_check_interval() -> u64 { 60 }
impl Default for PoolConfig {
fn default() -> Self {
Self {
default_max_connections: default_max_connections(),
idle_timeout_secs: default_idle_timeout(),
check_interval_secs: default_check_interval(),
}
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct ConnectionConfig {
pub name: String,
pub host: String,
pub port: u16,
pub user: String,
pub password: String,
pub database: String,
/// 最大连接数 (可选,覆盖默认值)
#[serde(default)]
pub max_connections: Option<usize>,
}
#[derive(Debug, Deserialize)]
pub struct Config {
#[serde(default = "default_server")]
pub server: ServerConfig,
#[serde(default)]
pub pool: PoolConfig,
pub connections: Vec<ConnectionConfig>,
}
fn default_server() -> ServerConfig {
ServerConfig {
port: 3307,
host: "127.0.0.1".to_string(),
}
}
impl Config {
/// 从文件加载配置
pub fn from_file(path: &str) -> Result<Self> {
let content = fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
config.validate()?;
Ok(config)
}
/// 从默认路径加载配置
pub fn load() -> Result<Self> {
// 尝试多个配置文件路径
let paths = [
"mysql-proxy.toml",
"./config/mysql-proxy.toml",
&format!("{}/.mysql-proxy.toml", std::env::var("HOME").unwrap_or_default()),
];
for path in &paths {
if std::path::Path::new(path).exists() {
return Self::from_file(path);
}
}
bail!("Config file not found. Create mysql-proxy.toml in current directory or use --config flag")
}
/// 验证配置
fn validate(&self) -> Result<()> {
let mut names = std::collections::HashSet::new();
for conn in &self.connections {
if conn.name.is_empty() {
bail!("Connection name cannot be empty");
}
if names.contains(&conn.name) {
bail!("Duplicate connection name: {}", conn.name);
}
names.insert(conn.name.clone());
}
Ok(())
}
}
impl ConnectionConfig {
/// 构建 DSN
pub fn build_dsn(&self) -> String {
let password = urlencoding(&self.password);
format!(
"mysql://{}:{}@{}:{}/{}",
self.user, password, self.host, self.port, self.database
)
}
}
/// URL 编码密码中的特殊字符
fn urlencoding(s: &str) -> String {
let mut result = String::new();
for c in s.chars() {
match c {
'@' => result.push_str("%40"),
':' => result.push_str("%3A"),
'/' => result.push_str("%2F"),
'?' => result.push_str("%3F"),
'#' => result.push_str("%23"),
'[' => result.push_str("%5B"),
']' => result.push_str("%5D"),
'!' => result.push_str("%21"),
'$' => result.push_str("%24"),
'&' => result.push_str("%26"),
'\'' => result.push_str("%27"),
'(' => result.push_str("%28"),
')' => result.push_str("%29"),
'*' => result.push_str("%2A"),
'+' => result.push_str("%2B"),
',' => result.push_str("%2C"),
';' => result.push_str("%3B"),
'=' => result.push_str("%3D"),
'%' => result.push_str("%25"),
' ' => result.push_str("%20"),
_ => result.push(c),
}
}
result
}

173
mysql-proxy/src/db.rs Normal file
View File

@@ -0,0 +1,173 @@
use anyhow::{Result, bail};
use mysql::{Pool, PooledConn, Opts,OptsBuilder, prelude::*};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::config::{ConnectionConfig, PoolConfig};
/// 连接池状态
struct PoolState {
pool: Arc<Pool>,
last_used: Instant,
}
/// 连接池管理器 (延迟初始化)
pub struct ConnectionManager {
pools: Mutex<HashMap<String, PoolState>>,
configs: Mutex<HashMap<String, ConnectionConfig>>,
pool_config: PoolConfig,
}
impl ConnectionManager {
/// 创建连接管理器 (不初始化连接)
pub fn new(configs: &[ConnectionConfig], pool_config: PoolConfig) -> Result<Self> {
let mut config_map = HashMap::new();
for cfg in configs {
config_map.insert(cfg.name.clone(), cfg.clone());
println!(" Registered: {} ({}/{})", cfg.name, cfg.host, cfg.database);
}
println!("\n{} connection(s) configured (lazy init)", config_map.len());
println!("Pool config: max_connections={}, idle_timeout={}s",
pool_config.default_max_connections, pool_config.idle_timeout_secs);
Ok(Self {
pools: Mutex::new(HashMap::new()),
configs: Mutex::new(config_map),
pool_config,
})
}
/// 获取或创建连接
pub fn get_conn(&self, name: &str) -> Result<PooledConn> {
// 1. 先尝试从已有池中获取
{
let mut pools = self.pools.lock().unwrap();
if let Some(state) = pools.get_mut(name) {
state.last_used = Instant::now();
return Ok(state.pool.get_conn()?);
}
}
// 2. 没有则创建新池
let cfg = self.configs.lock().unwrap().get(name)
.ok_or_else(|| anyhow::anyhow!("Connection '{}' not found", name))?
.clone();
println!("[LazyInit] Creating connection pool for: {}", name);
let pool = self.create_pool(&cfg)?;
let arc_pool = Arc::new(pool);
// 3. 保存并返回连接
{
let mut pools = self.pools.lock().unwrap();
pools.insert(name.to_string(), PoolState {
pool: arc_pool.clone(),
last_used: Instant::now(),
});
}
Ok(arc_pool.get_conn()?)
}
/// 创建连接池
fn create_pool(&self, cfg: &ConnectionConfig) -> Result<Pool> {
let dsn = cfg.build_dsn();
let opts: Opts = Opts::from_url(&dsn)?;
// 获取最大连接数 (单个配置 > 全局默认)
let max_conn = cfg.max_connections.unwrap_or(self.pool_config.default_max_connections);
// 构建带连接池参数的选项
let pool_opts = OptsBuilder::from_opts(opts)
.pool_opts(mysql::PoolOpts::new()
.with_constraints(mysql::PoolConstraints::new(1, max_conn).unwrap()));
let pool = Pool::new(pool_opts)?;
// 测试连接
let mut conn = pool.get_conn()?;
conn.query::<String, _>("SELECT 1")?;
println!("[LazyInit] ✓ Connected: {} (max_conn={})", cfg.name, max_conn);
Ok(pool)
}
/// 获取所有连接信息
pub fn list_connections(&self) -> Vec<ConnectionInfo> {
let pools = self.pools.lock().unwrap();
let configs = self.configs.lock().unwrap();
configs.iter().map(|(name, cfg)| {
let status = if pools.contains_key(name) { "connected" } else { "pending" };
ConnectionInfo {
name: name.clone(),
database: cfg.database.clone(),
host: cfg.host.clone(),
status: status.to_string(),
}
}).collect()
}
/// 清理空闲连接池
pub fn cleanup_idle(&self) {
let mut pools = self.pools.lock().unwrap();
let now = Instant::now();
let idle_timeout = Duration::from_secs(self.pool_config.idle_timeout_secs);
pools.retain(|name, state| {
let elapsed = now.duration_since(state.last_used);
if elapsed > idle_timeout {
println!("[Cleanup] Removing idle pool: {} (idle {}s)", name, elapsed.as_secs());
return false;
}
true
});
}
/// 检查连接是否健康
pub fn health_check(&self, name: &str) -> bool {
let pools = self.pools.lock().unwrap();
if let Some(state) = pools.get(name) {
return state.pool.get_conn().ok()
.and_then(|mut c| c.query::<String, _>("SELECT 1").ok())
.is_some();
}
false
}
/// 动态添加连接配置 (临时,重启后消失)
pub fn add_connection(&self, cfg: ConnectionConfig) -> Result<()> {
let name = cfg.name.clone();
let mut configs = self.configs.lock().unwrap();
if configs.contains_key(&name) {
bail!("Connection '{}' already exists", name);
}
println!("[Dynamic] Adding: {} ({}/{})", name, cfg.host, cfg.database);
// 测试连接
let pool = self.create_pool(&cfg)?;
let arc_pool = Arc::new(pool);
// 保存
self.pools.lock().unwrap().insert(name.clone(), PoolState {
pool: arc_pool,
last_used: Instant::now(),
});
configs.insert(name.clone(), cfg);
println!("[Dynamic] ✓ Added: {}", name);
Ok(())
}
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct ConnectionInfo {
pub name: String,
pub database: String,
pub host: String,
pub status: String,
}

320
mysql-proxy/src/handler.rs Normal file
View File

@@ -0,0 +1,320 @@
use axum::{
extract::State,
http::StatusCode,
Json,
};
use mysql::{prelude::*, Value, Row as MysqlRow};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use crate::db::ConnectionManager;
use crate::logger::{LogEntry, RequestLogger};
/// 应用状态
pub type AppState = Arc<(Arc<ConnectionManager>, Arc<RequestLogger>)>;
// ============== 请求结构 ==============
#[derive(Debug, Deserialize)]
pub struct QueryRequest {
/// 连接名称
pub conn: String,
/// SQL 语句
pub sql: String,
/// 输出格式: json, table, csv, vertical
#[serde(default = "default_format")]
pub format: String,
}
fn default_format() -> String {
"json".to_string()
}
#[derive(Debug, Deserialize)]
pub struct ExecuteRequest {
pub conn: String,
pub sql: String,
}
#[derive(Debug, Deserialize)]
pub struct AddConnectionRequest {
pub name: String,
pub host: String,
pub port: u16,
pub user: String,
pub password: String,
pub database: String,
}
// ============== 响应结构 ==============
#[derive(Debug, Serialize)]
pub struct QueryResponse {
pub columns: Vec<String>,
pub rows: Vec<Vec<Option<String>>>,
#[serde(rename = "rowCount")]
pub row_count: usize,
#[serde(rename = "durationMs")]
pub duration_ms: u64,
}
#[derive(Debug, Serialize)]
pub struct ExecuteResponse {
#[serde(rename = "affectedRows")]
pub affected_rows: u64,
#[serde(rename = "lastInsertId")]
pub last_insert_id: u64,
#[serde(rename = "durationMs")]
pub duration_ms: u64,
}
#[derive(Debug, Serialize)]
pub struct ConnectionsResponse {
pub connections: Vec<crate::db::ConnectionInfo>,
}
#[derive(Debug, Serialize)]
pub struct HealthResponse {
pub status: String,
pub connections: usize,
}
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
pub error: String,
#[serde(rename = "usage")]
pub usage: Option<String>,
}
impl ErrorResponse {
pub fn new(msg: &str) -> Self {
Self { error: msg.to_string(), usage: None }
}
pub fn with_usage(mut self, usage: &str) -> Self {
self.usage = Some(usage.to_string());
self
}
}
// ============== 处理器 ==============
/// 查询处理器
pub async fn query(
State(state): State<AppState>,
Json(req): Json<QueryRequest>,
) -> Result<Json<QueryResponse>, (StatusCode, Json<ErrorResponse>)> {
let (manager, logger) = &*state;
let start = Instant::now();
// 获取连接
let mut conn = manager.get_conn(&req.conn)
.map_err(|e| {
let usage = "Usage: POST /query {\"conn\": \"connection_name\", \"sql\": \"SELECT ...\"}\n"
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
error_response_with_usage(&e.to_string(), &usage)
})?;
// 判断是否是查询语句
let sql_upper = req.sql.trim().to_uppercase();
let is_query = sql_upper.starts_with("SELECT")
|| sql_upper.starts_with("SHOW")
|| sql_upper.starts_with("DESCRIBE")
|| sql_upper.starts_with("DESC ")
|| sql_upper.starts_with("EXPLAIN")
|| sql_upper.starts_with("WITH");
if !is_query {
let err = error_response_with_usage(
"Not a SELECT query. Use /execute for INSERT/UPDATE/DELETE",
"Usage:\n POST /query {\"conn\": \"name\", \"sql\": \"SELECT ...\"}\n POST /execute {\"conn\": \"name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}"
);
return Err(err);
}
// 执行查询
let result = conn.query_iter(&req.sql)
.map_err(|e| {
let usage = format!("SQL Error: {}\n\nUsage: POST /query {{\"conn\": \"name\", \"sql\": \"SELECT ...\"}}", e);
error_response_with_usage(&e.to_string(), &usage)
})?;
let mut columns: Vec<String> = Vec::new();
let mut data: Vec<Vec<Option<String>>> = Vec::new();
// 获取数据并从第一行提取列名
for row_result in result {
let row = row_result.map_err(|e| error_response(&e.to_string()))?;
// 从第一行获取列名
if columns.is_empty() {
columns = row.columns()
.iter()
.map(|c| c.name_str().to_string())
.collect();
}
let values = row_to_strings(&row, columns.len());
data.push(values);
}
let duration = start.elapsed();
let row_count = data.len();
// 记录日志
logger.log(&LogEntry::new("/query", "http")
.with_conn(&req.conn)
.with_sql(&req.sql)
.with_duration(duration.as_millis() as u64)
.with_rows(row_count));
Ok(Json(QueryResponse {
columns,
rows: data,
row_count,
duration_ms: duration.as_millis() as u64,
}))
}
/// 执行处理器 (INSERT/UPDATE/DELETE)
pub async fn execute(
State(state): State<AppState>,
Json(req): Json<ExecuteRequest>,
) -> Result<Json<ExecuteResponse>, (StatusCode, Json<ErrorResponse>)> {
let (manager, logger) = &*state;
let start = Instant::now();
// 获取连接
let mut conn = manager.get_conn(&req.conn)
.map_err(|e| {
let usage = "Usage: POST /execute {\"conn\": \"connection_name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}\n"
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
error_response_with_usage(&e.to_string(), &usage)
})?;
// 执行
let result = conn.query_iter(&req.sql)
.map_err(|e| error_response(&e.to_string()))?;
let duration = start.elapsed();
let affected = result.affected_rows();
let last_id = result.last_insert_id().unwrap_or(0);
// 记录日志
logger.log(&LogEntry::new("/execute", "http")
.with_conn(&req.conn)
.with_sql(&req.sql)
.with_duration(duration.as_millis() as u64));
Ok(Json(ExecuteResponse {
affected_rows: affected,
last_insert_id: last_id,
duration_ms: duration.as_millis() as u64,
}))
}
/// 获取连接列表
pub async fn connections(
State(state): State<AppState>,
) -> Json<ConnectionsResponse> {
let (manager, _) = &*state;
Json(ConnectionsResponse {
connections: manager.list_connections(),
})
}
/// 健康检查
pub async fn health(
State(state): State<AppState>,
) -> Json<HealthResponse> {
let (manager, _) = &*state;
Json(HealthResponse {
status: "ok".to_string(),
connections: manager.list_connections().len(),
})
}
/// 动态添加连接
pub async fn add_connection(
State(state): State<AppState>,
Json(req): Json<AddConnectionRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
let (manager, logger) = &*state;
use crate::config::ConnectionConfig;
let cfg = ConnectionConfig {
name: req.name.clone(),
host: req.host,
port: req.port,
user: req.user,
password: req.password,
database: req.database,
max_connections: None,
};
manager.add_connection(cfg)
.map_err(|e| error_response(&e.to_string()))?;
// 记录日志
logger.log(&LogEntry::new("/connections/add", "http")
.with_conn(&req.name)
.with_duration(0));
Ok(Json(serde_json::json!({
"success": true,
"message": format!("Connection '{}' added (temporary, will be lost on restart)", req.name)
})))
}
// ============== 辅助函数 ==============
fn error_response(msg: &str) -> (StatusCode, Json<ErrorResponse>) {
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg)))
}
fn error_response_with_usage(msg: &str, usage: &str) -> (StatusCode, Json<ErrorResponse>) {
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg).with_usage(usage)))
}
fn list_conn_names(manager: &ConnectionManager) -> String {
manager.list_connections()
.iter()
.map(|c| c.name.clone())
.collect::<Vec<_>>()
.join(", ")
}
/// 将 MySQL Row 转换为字符串向量
fn row_to_strings(row: &MysqlRow, col_count: usize) -> Vec<Option<String>> {
(0..col_count)
.map(|i| {
match row.get::<Value, usize>(i) {
Some(value) => value_to_string(&value),
None => None,
}
})
.collect()
}
/// 将 Value 转换为字符串
fn value_to_string(value: &Value) -> Option<String> {
match value {
Value::NULL => None,
Value::Bytes(bytes) => String::from_utf8(bytes.clone()).ok(),
Value::Int(i) => Some(i.to_string()),
Value::UInt(u) => Some(u.to_string()),
Value::Float(f) => Some(f.to_string()),
Value::Double(d) => Some(d.to_string()),
Value::Date(year, month, day, hour, min, sec, micro) => {
match (hour, min, sec, micro) {
(0, 0, 0, 0) => Some(format!("{:04}-{:02}-{:02}", year, month, day)),
_ => Some(format!("{:04}-{:02}-{:02} {:02}:{:02}:{:02}", year, month, day, hour, min, sec)),
}
}
Value::Time(neg, days, hours, minutes, seconds, microseconds) => {
let sign = if *neg { "-" } else { "" };
Some(format!("{}{} {}:{:02}:{:02}.{:06}", sign, days, hours, minutes, seconds, microseconds))
}
}
}

228
mysql-proxy/src/logger.rs Normal file
View File

@@ -0,0 +1,228 @@
use serde::{Deserialize, Serialize};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::PathBuf;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
/// 日志记录器
pub struct RequestLogger {
log_file: Mutex<Option<File>>,
enabled: bool,
}
impl RequestLogger {
pub fn new(log_path: Option<&str>) -> Self {
let (log_file, enabled) = if let Some(path) = log_path {
let path = expand_path(path);
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
match OpenOptions::new()
.create(true)
.append(true)
.open(&path)
{
Ok(file) => (Some(file), true),
Err(e) => {
eprintln!("[Logger] Failed to open log file {}: {}", path.display(), e);
(None, false)
}
}
} else {
(None, false)
};
Self {
log_file: Mutex::new(log_file),
enabled,
}
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
/// 记录请求日志
pub fn log(&self, entry: &LogEntry) {
if !self.enabled {
return;
}
let json = match serde_json::to_string(entry) {
Ok(j) => j,
Err(e) => {
eprintln!("[Logger] Failed to serialize log: {}", e);
return;
}
};
if let Ok(mut file) = self.log_file.lock() {
if let Some(ref mut f) = *file {
let _ = writeln!(f, "{}", json);
}
}
}
}
/// 日志条目
#[derive(Debug, Serialize, Deserialize)]
pub struct LogEntry {
pub timestamp: String,
pub level: String,
#[serde(rename = "type")]
pub log_type: String,
pub client: String,
pub endpoint: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub conn: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub server: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sql: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub command: Option<String>,
#[serde(rename = "durationMs")]
pub duration_ms: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub rows: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "exitCode")]
pub exit_code: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
impl LogEntry {
pub fn new(endpoint: &str, client: &str) -> Self {
Self {
timestamp: current_timestamp(),
level: "INFO".to_string(),
log_type: "request".to_string(),
client: client.to_string(),
endpoint: endpoint.to_string(),
conn: None,
server: None,
sql: None,
command: None,
duration_ms: 0,
rows: None,
exit_code: None,
error: None,
}
}
pub fn with_conn(mut self, conn: &str) -> Self {
self.conn = Some(conn.to_string());
self
}
pub fn with_server(mut self, server: &str) -> Self {
self.server = Some(server.to_string());
self
}
pub fn with_sql(mut self, sql: &str) -> Self {
// 截断过长的 SQL
self.sql = Some(truncate_string(sql, 1000));
self
}
pub fn with_command(mut self, command: &str) -> Self {
// 截断过长的命令
self.command = Some(truncate_string(command, 500));
self
}
pub fn with_duration(mut self, ms: u64) -> Self {
self.duration_ms = ms;
self
}
pub fn with_rows(mut self, rows: usize) -> Self {
self.rows = Some(rows);
self
}
pub fn with_exit_code(mut self, code: i32) -> Self {
self.exit_code = Some(code);
self
}
pub fn with_error(mut self, error: &str) -> Self {
self.level = "ERROR".to_string();
self.error = Some(error.to_string());
self
}
}
fn current_timestamp() -> String {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default();
let secs = now.as_secs();
let datetime = chrono_timestamp(secs);
format!("{}Z", datetime)
}
fn chrono_timestamp(secs: u64) -> String {
let days = secs / 86400;
let remaining = secs % 86400;
let hours = remaining / 3600;
let minutes = (remaining % 3600) / 60;
let seconds = remaining % 60;
// 从 1970-01-01 开始计算日期
let mut year = 1970;
let mut days_left = days;
loop {
let days_in_year = if is_leap_year(year) { 366 } else { 365 };
if days_left < days_in_year {
break;
}
days_left -= days_in_year;
year += 1;
}
let month_days = if is_leap_year(year) {
[31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
} else {
[31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
};
let mut month = 1;
for &days_in_month in &month_days {
if days_left < days_in_month {
break;
}
days_left -= days_in_month;
month += 1;
}
let day = days_left + 1;
format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}", year, month, day, hours, minutes, seconds)
}
fn is_leap_year(year: u64) -> bool {
(year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
}
fn truncate_string(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
format!("{}... (truncated)", &s[..max_len])
}
}
fn expand_path(path: &str) -> PathBuf {
if path.starts_with('~') {
let home = std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))
.unwrap_or_default();
PathBuf::from(path.replacen('~', &home, 1))
} else {
PathBuf::from(path)
}
}

212
mysql-proxy/src/main.rs Normal file
View File

@@ -0,0 +1,212 @@
mod cli;
mod config;
mod db;
mod handler;
mod logger;
use axum::{
routing::{get, post},
Router,
};
use clap::{Parser, Subcommand};
use std::sync::Arc;
#[derive(Parser, Debug)]
#[command(name = "mysql-proxy")]
#[command(about = "MySQL HTTP proxy with connection pooling")]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
/// Config file path
#[arg(short, long, default_value = "mysql-proxy.toml", global = true)]
config: String,
/// Server URL (for CLI mode)
#[arg(short = 'S', long, default_value = "http://127.0.0.1:3307", global = true)]
server: String,
}
#[derive(Debug, Subcommand)]
enum Commands {
/// Start HTTP server (default)
Server {
/// Port to listen on
#[arg(short = 'P', long)]
port: Option<u16>,
/// Host to bind
#[arg(short = 'H', long)]
host: Option<String>,
},
/// Execute SQL via proxy (for AI clients)
Cli {
/// Connection name
#[arg(short, long, default_value = "flux_dev")]
conn: String,
/// SQL to execute (与 mysql 官方一致)
#[arg(short = 'e', long)]
sql: Option<String>,
/// Output format: table, json, csv, vertical
#[arg(short = 'F', long, default_value = "table")]
format: String,
/// Execute INSERT/UPDATE/DELETE
#[arg(short = 'x', long)]
execute: bool,
},
/// List available connections
Connections,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
match args.command {
Some(Commands::Server { port, host }) => {
run_server(&args.config, port, host).await
}
Some(Commands::Cli { conn, sql, format, execute }) => {
run_cli(&args.server, &conn, sql.as_deref(), &format, execute)
}
Some(Commands::Connections) => {
list_connections(&args.server)
}
None => {
run_server(&args.config, None, None).await
}
}
}
/// 启动 HTTP 服务器
async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>) -> anyhow::Result<()> {
println!("MySQL HTTP Proxy v0.1.0\n");
// 加载配置
let mut config = config::Config::from_file(config_path)?;
// 命令行参数覆盖配置文件
if let Some(port) = port {
config.server.port = port;
}
if let Some(host) = host {
config.server.host = host;
}
// 初始化日志
let log_path = std::env::var("MYSQL_PROXY_LOG").ok();
let logger = Arc::new(logger::RequestLogger::new(log_path.as_deref()));
if logger.is_enabled() {
println!("Request logging: enabled");
}
// 初始化连接池
println!("Initializing connection pools...\n");
let manager = Arc::new(db::ConnectionManager::new(&config.connections, config.pool.clone())?);
// 启动后台清理任务
let manager_clone = manager.clone();
let check_interval = config.pool.check_interval_secs;
tokio::spawn(async move {
let mut interval = tokio::time::interval(
std::time::Duration::from_secs(check_interval)
);
loop {
interval.tick().await;
manager_clone.cleanup_idle();
}
});
// 构建路由
let app = Router::new()
.route("/query", post(handler::query))
.route("/execute", post(handler::execute))
.route("/connections", get(handler::connections))
.route("/connections/add", post(handler::add_connection))
.route("/health", get(handler::health))
.with_state(Arc::new((manager, logger)));
// 启动服务器
let addr = format!("{}:{}", config.server.host, config.server.port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
println!("\nServer started at http://{}", addr);
println!("\nAPI Endpoints:");
println!(" POST /query - Execute SELECT queries");
println!(" POST /execute - Execute INSERT/UPDATE/DELETE");
println!(" GET /connections - List all connections");
println!(" POST /connections/add - Add connection (temporary)");
println!(" GET /health - Health check");
println!("\nCLI Usage:");
println!(" mysql-proxy cli -c flux_dev -e \"SELECT 1\"");
println!(" mysql-proxy connections");
axum::serve(listener, app).await?;
Ok(())
}
/// CLI 模式
fn run_cli(server: &str, conn: &str, sql: Option<&str>, format: &str, execute: bool) -> anyhow::Result<()> {
let cli = cli::Cli::new(Some(server.to_string()));
// 检查服务器是否运行
if !cli.check_server()? {
eprintln!();
eprintln!("=== mysql-proxy 未运行,请使用降级方案 ===");
eprintln!();
eprintln!("启动代理: mysql-proxy");
eprintln!();
eprintln!("或直接使用 mysql 命令:");
eprintln!(" mysql -h<host> -u<user> -p<password> -D<database> -e \"<SQL>\"");
eprintln!();
eprintln!("配置文件中的连接信息见: mysql-proxy.toml");
anyhow::bail!("Proxy server not running at {}", server);
}
// 如果没有提供 SQL从 stdin 读取
let sql = match sql {
Some(s) => s.to_string(),
None => {
use std::io::{self, BufRead};
let stdin = io::stdin();
let mut lines = Vec::new();
for line in stdin.lock().lines() {
lines.push(line?);
}
lines.join(" ")
}
};
if sql.trim().is_empty() {
anyhow::bail!("No SQL provided. Use -e \"SQL\" or pipe SQL via stdin");
}
if execute {
cli.execute(conn, &sql)?;
} else {
cli.query(conn, &sql, Some(format))?;
}
Ok(())
}
/// 列出连接
fn list_connections(server: &str) -> anyhow::Result<()> {
let cli = cli::Cli::new(Some(server.to_string()));
if !cli.check_server()? {
eprintln!();
eprintln!("=== mysql-proxy 未运行 ===");
eprintln!("启动代理: mysql-proxy");
eprintln!("配置文件: mysql-proxy.toml");
anyhow::bail!("Proxy server not running at {}", server);
}
cli.list_connections()
}