新增: MySQL/SSH 代理工具
- mysql-proxy: MySQL HTTP 代理,连接池复用 - ssh-proxy: SSH HTTP 代理,会话复用 - mysql-cli: 轻量级 MySQL CLI 工具 功能特性: - 延迟初始化,启动快 - CLI 和 HTTP API 双模式 - 请求日志支持 - 错误友好提示 - JSON 极简输出格式
This commit is contained in:
21
mysql-proxy/Cargo.toml
Normal file
21
mysql-proxy/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "mysql-proxy"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "MySQL HTTP proxy with connection pooling"
|
||||
|
||||
[dependencies]
|
||||
mysql = "25"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
axum = "0.7"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
ureq = "2" # CLI HTTP client (2.x has simpler API)
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z"
|
||||
lto = true
|
||||
strip = true
|
||||
287
mysql-proxy/src/cli.rs
Normal file
287
mysql-proxy/src/cli.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
use anyhow::{Result, bail};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Instant;
|
||||
|
||||
const DEFAULT_SERVER: &str = "http://127.0.0.1:3307";
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct QueryRequest {
|
||||
conn: String,
|
||||
sql: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct QueryResponse {
|
||||
columns: Vec<String>,
|
||||
rows: Vec<Vec<Option<String>>>,
|
||||
#[serde(rename = "rowCount")]
|
||||
row_count: usize,
|
||||
#[serde(rename = "durationMs")]
|
||||
duration_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ConnectionsResponse {
|
||||
connections: Vec<ConnectionInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ConnectionInfo {
|
||||
name: String,
|
||||
database: String,
|
||||
host: String,
|
||||
status: String,
|
||||
}
|
||||
|
||||
/// CLI 客户端
|
||||
pub struct Cli {
|
||||
server: String,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn new(server: Option<String>) -> Self {
|
||||
Self {
|
||||
server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// 执行查询并输出结果
|
||||
pub fn query(&self, conn: &str, sql: &str, format: Option<&str>) -> Result<()> {
|
||||
let start = Instant::now();
|
||||
|
||||
let body = serde_json::to_string(&QueryRequest {
|
||||
conn: conn.to_string(),
|
||||
sql: sql.to_string(),
|
||||
})?;
|
||||
|
||||
let response = ureq::post(&format!("{}/query", self.server))
|
||||
.set("Content-Type", "application/json")
|
||||
.send_string(&body);
|
||||
|
||||
// ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中
|
||||
let body = match response {
|
||||
Ok(r) => r.into_string()?,
|
||||
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
|
||||
Err(e) => bail!("HTTP error: {}", e),
|
||||
};
|
||||
|
||||
let total_ms = start.elapsed().as_millis();
|
||||
|
||||
// 解析响应
|
||||
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||
eprintln!("Error: {}", err.error);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let result: QueryResponse = serde_json::from_str(&body)?;
|
||||
|
||||
// 输出结果
|
||||
match format {
|
||||
Some("json") => self.print_json(&result),
|
||||
Some("csv") => self.print_csv(&result),
|
||||
Some("vertical") | Some("vert") => self.print_vertical(&result),
|
||||
_ => self.print_table(&result),
|
||||
}
|
||||
|
||||
println!("\n{} rows in set ({}ms db, {}ms total)",
|
||||
result.row_count, result.duration_ms, total_ms);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 执行语句 (INSERT/UPDATE/DELETE)
|
||||
pub fn execute(&self, conn: &str, sql: &str) -> Result<()> {
|
||||
let start = Instant::now();
|
||||
|
||||
let body = serde_json::to_string(&QueryRequest {
|
||||
conn: conn.to_string(),
|
||||
sql: sql.to_string(),
|
||||
})?;
|
||||
|
||||
let response = ureq::post(&format!("{}/execute", self.server))
|
||||
.set("Content-Type", "application/json")
|
||||
.send_string(&body);
|
||||
|
||||
// ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中
|
||||
let body = match response {
|
||||
Ok(r) => r.into_string()?,
|
||||
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
|
||||
Err(e) => bail!("HTTP error: {}", e),
|
||||
};
|
||||
|
||||
let total_ms = start.elapsed().as_millis();
|
||||
|
||||
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||
eprintln!("Error: {}", err.error);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExecuteResponse {
|
||||
#[serde(rename = "affectedRows")]
|
||||
affected_rows: u64,
|
||||
#[serde(rename = "lastInsertId")]
|
||||
last_insert_id: u64,
|
||||
}
|
||||
|
||||
let result: ExecuteResponse = serde_json::from_str(&body)?;
|
||||
println!("Query OK, {} rows affected ({}ms total)", result.affected_rows, total_ms);
|
||||
|
||||
if result.last_insert_id > 0 {
|
||||
println!("Last insert ID: {}", result.last_insert_id);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 列出连接
|
||||
pub fn list_connections(&self) -> Result<()> {
|
||||
let response = ureq::get(&format!("{}/connections", self.server))
|
||||
.call()?;
|
||||
|
||||
let body = response.into_string()?;
|
||||
|
||||
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||
bail!("{}", err.error);
|
||||
}
|
||||
|
||||
let result: ConnectionsResponse = serde_json::from_str(&body)?;
|
||||
|
||||
println!("Connections:");
|
||||
println!("{:<15} {:<20} {:<40} {:<10}", "Name", "Database", "Host", "Status");
|
||||
println!("{}", "-".repeat(85));
|
||||
|
||||
for conn in result.connections {
|
||||
println!("{:<15} {:<20} {:<40} {:<10}",
|
||||
conn.name, conn.database, conn.host, conn.status);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查代理是否运行
|
||||
pub fn check_server(&self) -> Result<bool> {
|
||||
match ureq::get(&format!("{}/health", self.server)).call() {
|
||||
Ok(_) => Ok(true),
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
// 输出格式化方法
|
||||
fn print_table(&self, result: &QueryResponse) {
|
||||
if result.columns.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// 计算列宽
|
||||
let mut widths: Vec<usize> = result.columns.iter().map(|c| c.len()).collect();
|
||||
|
||||
for row in &result.rows {
|
||||
for (i, val) in row.iter().enumerate() {
|
||||
let len = val.as_ref().map(|s| s.len()).unwrap_or(4); // NULL
|
||||
widths[i] = widths[i].max(len);
|
||||
}
|
||||
}
|
||||
|
||||
// 打印表头
|
||||
let header: String = result.columns.iter()
|
||||
.enumerate()
|
||||
.map(|(i, c)| format!(" {:width$} ", c, width = widths[i]))
|
||||
.collect::<Vec<_>>()
|
||||
.join("|");
|
||||
|
||||
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
|
||||
println!("|{}|", header);
|
||||
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
|
||||
|
||||
// 打印数据行
|
||||
for row in &result.rows {
|
||||
let line: String = row.iter()
|
||||
.enumerate()
|
||||
.map(|(i, val)| {
|
||||
match val {
|
||||
Some(s) => format!(" {:width$} ", s, width = widths[i]),
|
||||
None => format!(" {:width$} ", "NULL", width = widths[i]),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("|");
|
||||
println!("|{}|", line);
|
||||
}
|
||||
|
||||
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
|
||||
}
|
||||
|
||||
fn print_json(&self, result: &QueryResponse) {
|
||||
// 极简格式:对象数组
|
||||
let rows: Vec<serde_json::Value> = result.rows.iter().map(|row| {
|
||||
let mut obj = serde_json::Map::new();
|
||||
for (i, col) in result.columns.iter().enumerate() {
|
||||
let val = row.get(i)
|
||||
.and_then(|v| v.as_ref())
|
||||
.map(|s| {
|
||||
// 尝试解析为数字
|
||||
if let Ok(n) = s.parse::<i64>() {
|
||||
serde_json::Value::Number(n.into())
|
||||
} else if let Ok(n) = s.parse::<f64>() {
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)))
|
||||
} else {
|
||||
serde_json::Value::String(s.clone())
|
||||
}
|
||||
})
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
obj.insert(col.clone(), val);
|
||||
}
|
||||
serde_json::Value::Object(obj)
|
||||
}).collect();
|
||||
|
||||
println!("{}", serde_json::to_string(&rows).unwrap_or_default());
|
||||
}
|
||||
|
||||
fn print_csv(&self, result: &QueryResponse) {
|
||||
// 打印表头
|
||||
println!("{}", result.columns.join(","));
|
||||
|
||||
// 打印数据
|
||||
for row in &result.rows {
|
||||
let line: String = row.iter()
|
||||
.map(|val| {
|
||||
match val {
|
||||
Some(s) => {
|
||||
if s.contains(',') || s.contains('"') || s.contains('\n') {
|
||||
format!("\"{}\"", s.replace('"', "\"\""))
|
||||
} else {
|
||||
s.clone()
|
||||
}
|
||||
}
|
||||
None => "".to_string(),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
println!("{}", line);
|
||||
}
|
||||
}
|
||||
|
||||
fn print_vertical(&self, result: &QueryResponse) {
|
||||
for (row_idx, row) in result.rows.iter().enumerate() {
|
||||
if row_idx > 0 {
|
||||
println!();
|
||||
}
|
||||
println!("*************************** {}. row ***************************", row_idx + 1);
|
||||
|
||||
for (col_idx, col) in result.columns.iter().enumerate() {
|
||||
let val = row.get(col_idx)
|
||||
.and_then(|v| v.as_ref())
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("NULL");
|
||||
println!("{:>20}: {}", col, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
156
mysql-proxy/src/config.rs
Normal file
156
mysql-proxy/src/config.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use anyhow::{Result, bail};
|
||||
use serde::Deserialize;
|
||||
use std::fs;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
pub port: u16,
|
||||
#[serde(default = "default_host")]
|
||||
pub host: String,
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
/// 连接池配置
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct PoolConfig {
|
||||
/// 默认最大连接数
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub default_max_connections: usize,
|
||||
/// 空闲超时 (秒)
|
||||
#[serde(default = "default_idle_timeout")]
|
||||
pub idle_timeout_secs: u64,
|
||||
/// 检查间隔 (秒)
|
||||
#[serde(default = "default_check_interval")]
|
||||
pub check_interval_secs: u64,
|
||||
}
|
||||
|
||||
fn default_max_connections() -> usize { 5 }
|
||||
fn default_idle_timeout() -> u64 { 300 }
|
||||
fn default_check_interval() -> u64 { 60 }
|
||||
|
||||
impl Default for PoolConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_max_connections: default_max_connections(),
|
||||
idle_timeout_secs: default_idle_timeout(),
|
||||
check_interval_secs: default_check_interval(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct ConnectionConfig {
|
||||
pub name: String,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub user: String,
|
||||
pub password: String,
|
||||
pub database: String,
|
||||
/// 最大连接数 (可选,覆盖默认值)
|
||||
#[serde(default)]
|
||||
pub max_connections: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Config {
|
||||
#[serde(default = "default_server")]
|
||||
pub server: ServerConfig,
|
||||
#[serde(default)]
|
||||
pub pool: PoolConfig,
|
||||
pub connections: Vec<ConnectionConfig>,
|
||||
}
|
||||
|
||||
fn default_server() -> ServerConfig {
|
||||
ServerConfig {
|
||||
port: 3307,
|
||||
host: "127.0.0.1".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// 从文件加载配置
|
||||
pub fn from_file(path: &str) -> Result<Self> {
|
||||
let content = fs::read_to_string(path)?;
|
||||
let config: Config = toml::from_str(&content)?;
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// 从默认路径加载配置
|
||||
pub fn load() -> Result<Self> {
|
||||
// 尝试多个配置文件路径
|
||||
let paths = [
|
||||
"mysql-proxy.toml",
|
||||
"./config/mysql-proxy.toml",
|
||||
&format!("{}/.mysql-proxy.toml", std::env::var("HOME").unwrap_or_default()),
|
||||
];
|
||||
|
||||
for path in &paths {
|
||||
if std::path::Path::new(path).exists() {
|
||||
return Self::from_file(path);
|
||||
}
|
||||
}
|
||||
|
||||
bail!("Config file not found. Create mysql-proxy.toml in current directory or use --config flag")
|
||||
}
|
||||
|
||||
/// 验证配置
|
||||
fn validate(&self) -> Result<()> {
|
||||
let mut names = std::collections::HashSet::new();
|
||||
for conn in &self.connections {
|
||||
if conn.name.is_empty() {
|
||||
bail!("Connection name cannot be empty");
|
||||
}
|
||||
if names.contains(&conn.name) {
|
||||
bail!("Duplicate connection name: {}", conn.name);
|
||||
}
|
||||
names.insert(conn.name.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionConfig {
|
||||
/// 构建 DSN
|
||||
pub fn build_dsn(&self) -> String {
|
||||
let password = urlencoding(&self.password);
|
||||
format!(
|
||||
"mysql://{}:{}@{}:{}/{}",
|
||||
self.user, password, self.host, self.port, self.database
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// URL 编码密码中的特殊字符
|
||||
fn urlencoding(s: &str) -> String {
|
||||
let mut result = String::new();
|
||||
for c in s.chars() {
|
||||
match c {
|
||||
'@' => result.push_str("%40"),
|
||||
':' => result.push_str("%3A"),
|
||||
'/' => result.push_str("%2F"),
|
||||
'?' => result.push_str("%3F"),
|
||||
'#' => result.push_str("%23"),
|
||||
'[' => result.push_str("%5B"),
|
||||
']' => result.push_str("%5D"),
|
||||
'!' => result.push_str("%21"),
|
||||
'$' => result.push_str("%24"),
|
||||
'&' => result.push_str("%26"),
|
||||
'\'' => result.push_str("%27"),
|
||||
'(' => result.push_str("%28"),
|
||||
')' => result.push_str("%29"),
|
||||
'*' => result.push_str("%2A"),
|
||||
'+' => result.push_str("%2B"),
|
||||
',' => result.push_str("%2C"),
|
||||
';' => result.push_str("%3B"),
|
||||
'=' => result.push_str("%3D"),
|
||||
'%' => result.push_str("%25"),
|
||||
' ' => result.push_str("%20"),
|
||||
_ => result.push(c),
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
173
mysql-proxy/src/db.rs
Normal file
173
mysql-proxy/src/db.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use anyhow::{Result, bail};
|
||||
use mysql::{Pool, PooledConn, Opts,OptsBuilder, prelude::*};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::config::{ConnectionConfig, PoolConfig};
|
||||
|
||||
/// 连接池状态
|
||||
struct PoolState {
|
||||
pool: Arc<Pool>,
|
||||
last_used: Instant,
|
||||
}
|
||||
|
||||
/// 连接池管理器 (延迟初始化)
|
||||
pub struct ConnectionManager {
|
||||
pools: Mutex<HashMap<String, PoolState>>,
|
||||
configs: Mutex<HashMap<String, ConnectionConfig>>,
|
||||
pool_config: PoolConfig,
|
||||
}
|
||||
|
||||
impl ConnectionManager {
|
||||
/// 创建连接管理器 (不初始化连接)
|
||||
pub fn new(configs: &[ConnectionConfig], pool_config: PoolConfig) -> Result<Self> {
|
||||
let mut config_map = HashMap::new();
|
||||
|
||||
for cfg in configs {
|
||||
config_map.insert(cfg.name.clone(), cfg.clone());
|
||||
println!(" Registered: {} ({}/{})", cfg.name, cfg.host, cfg.database);
|
||||
}
|
||||
|
||||
println!("\n{} connection(s) configured (lazy init)", config_map.len());
|
||||
println!("Pool config: max_connections={}, idle_timeout={}s",
|
||||
pool_config.default_max_connections, pool_config.idle_timeout_secs);
|
||||
|
||||
Ok(Self {
|
||||
pools: Mutex::new(HashMap::new()),
|
||||
configs: Mutex::new(config_map),
|
||||
pool_config,
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取或创建连接
|
||||
pub fn get_conn(&self, name: &str) -> Result<PooledConn> {
|
||||
// 1. 先尝试从已有池中获取
|
||||
{
|
||||
let mut pools = self.pools.lock().unwrap();
|
||||
if let Some(state) = pools.get_mut(name) {
|
||||
state.last_used = Instant::now();
|
||||
return Ok(state.pool.get_conn()?);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 没有则创建新池
|
||||
let cfg = self.configs.lock().unwrap().get(name)
|
||||
.ok_or_else(|| anyhow::anyhow!("Connection '{}' not found", name))?
|
||||
.clone();
|
||||
|
||||
println!("[LazyInit] Creating connection pool for: {}", name);
|
||||
let pool = self.create_pool(&cfg)?;
|
||||
let arc_pool = Arc::new(pool);
|
||||
|
||||
// 3. 保存并返回连接
|
||||
{
|
||||
let mut pools = self.pools.lock().unwrap();
|
||||
pools.insert(name.to_string(), PoolState {
|
||||
pool: arc_pool.clone(),
|
||||
last_used: Instant::now(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(arc_pool.get_conn()?)
|
||||
}
|
||||
|
||||
/// 创建连接池
|
||||
fn create_pool(&self, cfg: &ConnectionConfig) -> Result<Pool> {
|
||||
let dsn = cfg.build_dsn();
|
||||
let opts: Opts = Opts::from_url(&dsn)?;
|
||||
|
||||
// 获取最大连接数 (单个配置 > 全局默认)
|
||||
let max_conn = cfg.max_connections.unwrap_or(self.pool_config.default_max_connections);
|
||||
|
||||
// 构建带连接池参数的选项
|
||||
let pool_opts = OptsBuilder::from_opts(opts)
|
||||
.pool_opts(mysql::PoolOpts::new()
|
||||
.with_constraints(mysql::PoolConstraints::new(1, max_conn).unwrap()));
|
||||
|
||||
let pool = Pool::new(pool_opts)?;
|
||||
|
||||
// 测试连接
|
||||
let mut conn = pool.get_conn()?;
|
||||
conn.query::<String, _>("SELECT 1")?;
|
||||
|
||||
println!("[LazyInit] ✓ Connected: {} (max_conn={})", cfg.name, max_conn);
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// 获取所有连接信息
|
||||
pub fn list_connections(&self) -> Vec<ConnectionInfo> {
|
||||
let pools = self.pools.lock().unwrap();
|
||||
let configs = self.configs.lock().unwrap();
|
||||
|
||||
configs.iter().map(|(name, cfg)| {
|
||||
let status = if pools.contains_key(name) { "connected" } else { "pending" };
|
||||
ConnectionInfo {
|
||||
name: name.clone(),
|
||||
database: cfg.database.clone(),
|
||||
host: cfg.host.clone(),
|
||||
status: status.to_string(),
|
||||
}
|
||||
}).collect()
|
||||
}
|
||||
|
||||
/// 清理空闲连接池
|
||||
pub fn cleanup_idle(&self) {
|
||||
let mut pools = self.pools.lock().unwrap();
|
||||
let now = Instant::now();
|
||||
let idle_timeout = Duration::from_secs(self.pool_config.idle_timeout_secs);
|
||||
|
||||
pools.retain(|name, state| {
|
||||
let elapsed = now.duration_since(state.last_used);
|
||||
if elapsed > idle_timeout {
|
||||
println!("[Cleanup] Removing idle pool: {} (idle {}s)", name, elapsed.as_secs());
|
||||
return false;
|
||||
}
|
||||
true
|
||||
});
|
||||
}
|
||||
|
||||
/// 检查连接是否健康
|
||||
pub fn health_check(&self, name: &str) -> bool {
|
||||
let pools = self.pools.lock().unwrap();
|
||||
if let Some(state) = pools.get(name) {
|
||||
return state.pool.get_conn().ok()
|
||||
.and_then(|mut c| c.query::<String, _>("SELECT 1").ok())
|
||||
.is_some();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// 动态添加连接配置 (临时,重启后消失)
|
||||
pub fn add_connection(&self, cfg: ConnectionConfig) -> Result<()> {
|
||||
let name = cfg.name.clone();
|
||||
let mut configs = self.configs.lock().unwrap();
|
||||
if configs.contains_key(&name) {
|
||||
bail!("Connection '{}' already exists", name);
|
||||
}
|
||||
|
||||
println!("[Dynamic] Adding: {} ({}/{})", name, cfg.host, cfg.database);
|
||||
|
||||
// 测试连接
|
||||
let pool = self.create_pool(&cfg)?;
|
||||
let arc_pool = Arc::new(pool);
|
||||
|
||||
// 保存
|
||||
self.pools.lock().unwrap().insert(name.clone(), PoolState {
|
||||
pool: arc_pool,
|
||||
last_used: Instant::now(),
|
||||
});
|
||||
configs.insert(name.clone(), cfg);
|
||||
|
||||
println!("[Dynamic] ✓ Added: {}", name);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct ConnectionInfo {
|
||||
pub name: String,
|
||||
pub database: String,
|
||||
pub host: String,
|
||||
pub status: String,
|
||||
}
|
||||
320
mysql-proxy/src/handler.rs
Normal file
320
mysql-proxy/src/handler.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use mysql::{prelude::*, Value, Row as MysqlRow};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::db::ConnectionManager;
|
||||
use crate::logger::{LogEntry, RequestLogger};
|
||||
|
||||
/// 应用状态
|
||||
pub type AppState = Arc<(Arc<ConnectionManager>, Arc<RequestLogger>)>;
|
||||
|
||||
// ============== 请求结构 ==============
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct QueryRequest {
|
||||
/// 连接名称
|
||||
pub conn: String,
|
||||
/// SQL 语句
|
||||
pub sql: String,
|
||||
/// 输出格式: json, table, csv, vertical
|
||||
#[serde(default = "default_format")]
|
||||
pub format: String,
|
||||
}
|
||||
|
||||
fn default_format() -> String {
|
||||
"json".to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ExecuteRequest {
|
||||
pub conn: String,
|
||||
pub sql: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AddConnectionRequest {
|
||||
pub name: String,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub user: String,
|
||||
pub password: String,
|
||||
pub database: String,
|
||||
}
|
||||
|
||||
// ============== 响应结构 ==============
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct QueryResponse {
|
||||
pub columns: Vec<String>,
|
||||
pub rows: Vec<Vec<Option<String>>>,
|
||||
#[serde(rename = "rowCount")]
|
||||
pub row_count: usize,
|
||||
#[serde(rename = "durationMs")]
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecuteResponse {
|
||||
#[serde(rename = "affectedRows")]
|
||||
pub affected_rows: u64,
|
||||
#[serde(rename = "lastInsertId")]
|
||||
pub last_insert_id: u64,
|
||||
#[serde(rename = "durationMs")]
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ConnectionsResponse {
|
||||
pub connections: Vec<crate::db::ConnectionInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct HealthResponse {
|
||||
pub status: String,
|
||||
pub connections: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub error: String,
|
||||
#[serde(rename = "usage")]
|
||||
pub usage: Option<String>,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
pub fn new(msg: &str) -> Self {
|
||||
Self { error: msg.to_string(), usage: None }
|
||||
}
|
||||
|
||||
pub fn with_usage(mut self, usage: &str) -> Self {
|
||||
self.usage = Some(usage.to_string());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// ============== 处理器 ==============
|
||||
|
||||
/// 查询处理器
|
||||
pub async fn query(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<QueryRequest>,
|
||||
) -> Result<Json<QueryResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let (manager, logger) = &*state;
|
||||
let start = Instant::now();
|
||||
|
||||
// 获取连接
|
||||
let mut conn = manager.get_conn(&req.conn)
|
||||
.map_err(|e| {
|
||||
let usage = "Usage: POST /query {\"conn\": \"connection_name\", \"sql\": \"SELECT ...\"}\n"
|
||||
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
|
||||
error_response_with_usage(&e.to_string(), &usage)
|
||||
})?;
|
||||
|
||||
// 判断是否是查询语句
|
||||
let sql_upper = req.sql.trim().to_uppercase();
|
||||
let is_query = sql_upper.starts_with("SELECT")
|
||||
|| sql_upper.starts_with("SHOW")
|
||||
|| sql_upper.starts_with("DESCRIBE")
|
||||
|| sql_upper.starts_with("DESC ")
|
||||
|| sql_upper.starts_with("EXPLAIN")
|
||||
|| sql_upper.starts_with("WITH");
|
||||
|
||||
if !is_query {
|
||||
let err = error_response_with_usage(
|
||||
"Not a SELECT query. Use /execute for INSERT/UPDATE/DELETE",
|
||||
"Usage:\n POST /query {\"conn\": \"name\", \"sql\": \"SELECT ...\"}\n POST /execute {\"conn\": \"name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}"
|
||||
);
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
let result = conn.query_iter(&req.sql)
|
||||
.map_err(|e| {
|
||||
let usage = format!("SQL Error: {}\n\nUsage: POST /query {{\"conn\": \"name\", \"sql\": \"SELECT ...\"}}", e);
|
||||
error_response_with_usage(&e.to_string(), &usage)
|
||||
})?;
|
||||
|
||||
let mut columns: Vec<String> = Vec::new();
|
||||
let mut data: Vec<Vec<Option<String>>> = Vec::new();
|
||||
|
||||
// 获取数据并从第一行提取列名
|
||||
for row_result in result {
|
||||
let row = row_result.map_err(|e| error_response(&e.to_string()))?;
|
||||
|
||||
// 从第一行获取列名
|
||||
if columns.is_empty() {
|
||||
columns = row.columns()
|
||||
.iter()
|
||||
.map(|c| c.name_str().to_string())
|
||||
.collect();
|
||||
}
|
||||
|
||||
let values = row_to_strings(&row, columns.len());
|
||||
data.push(values);
|
||||
}
|
||||
|
||||
let duration = start.elapsed();
|
||||
let row_count = data.len();
|
||||
|
||||
// 记录日志
|
||||
logger.log(&LogEntry::new("/query", "http")
|
||||
.with_conn(&req.conn)
|
||||
.with_sql(&req.sql)
|
||||
.with_duration(duration.as_millis() as u64)
|
||||
.with_rows(row_count));
|
||||
|
||||
Ok(Json(QueryResponse {
|
||||
columns,
|
||||
rows: data,
|
||||
row_count,
|
||||
duration_ms: duration.as_millis() as u64,
|
||||
}))
|
||||
}
|
||||
|
||||
/// 执行处理器 (INSERT/UPDATE/DELETE)
|
||||
pub async fn execute(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<ExecuteRequest>,
|
||||
) -> Result<Json<ExecuteResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let (manager, logger) = &*state;
|
||||
let start = Instant::now();
|
||||
|
||||
// 获取连接
|
||||
let mut conn = manager.get_conn(&req.conn)
|
||||
.map_err(|e| {
|
||||
let usage = "Usage: POST /execute {\"conn\": \"connection_name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}\n"
|
||||
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
|
||||
error_response_with_usage(&e.to_string(), &usage)
|
||||
})?;
|
||||
|
||||
// 执行
|
||||
let result = conn.query_iter(&req.sql)
|
||||
.map_err(|e| error_response(&e.to_string()))?;
|
||||
|
||||
let duration = start.elapsed();
|
||||
let affected = result.affected_rows();
|
||||
let last_id = result.last_insert_id().unwrap_or(0);
|
||||
|
||||
// 记录日志
|
||||
logger.log(&LogEntry::new("/execute", "http")
|
||||
.with_conn(&req.conn)
|
||||
.with_sql(&req.sql)
|
||||
.with_duration(duration.as_millis() as u64));
|
||||
|
||||
Ok(Json(ExecuteResponse {
|
||||
affected_rows: affected,
|
||||
last_insert_id: last_id,
|
||||
duration_ms: duration.as_millis() as u64,
|
||||
}))
|
||||
}
|
||||
|
||||
/// 获取连接列表
|
||||
pub async fn connections(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ConnectionsResponse> {
|
||||
let (manager, _) = &*state;
|
||||
Json(ConnectionsResponse {
|
||||
connections: manager.list_connections(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 健康检查
|
||||
pub async fn health(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<HealthResponse> {
|
||||
let (manager, _) = &*state;
|
||||
Json(HealthResponse {
|
||||
status: "ok".to_string(),
|
||||
connections: manager.list_connections().len(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 动态添加连接
|
||||
pub async fn add_connection(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<AddConnectionRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
|
||||
let (manager, logger) = &*state;
|
||||
use crate::config::ConnectionConfig;
|
||||
|
||||
let cfg = ConnectionConfig {
|
||||
name: req.name.clone(),
|
||||
host: req.host,
|
||||
port: req.port,
|
||||
user: req.user,
|
||||
password: req.password,
|
||||
database: req.database,
|
||||
max_connections: None,
|
||||
};
|
||||
|
||||
manager.add_connection(cfg)
|
||||
.map_err(|e| error_response(&e.to_string()))?;
|
||||
|
||||
// 记录日志
|
||||
logger.log(&LogEntry::new("/connections/add", "http")
|
||||
.with_conn(&req.name)
|
||||
.with_duration(0));
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"success": true,
|
||||
"message": format!("Connection '{}' added (temporary, will be lost on restart)", req.name)
|
||||
})))
|
||||
}
|
||||
|
||||
// ============== 辅助函数 ==============
|
||||
|
||||
fn error_response(msg: &str) -> (StatusCode, Json<ErrorResponse>) {
|
||||
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg)))
|
||||
}
|
||||
|
||||
fn error_response_with_usage(msg: &str, usage: &str) -> (StatusCode, Json<ErrorResponse>) {
|
||||
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg).with_usage(usage)))
|
||||
}
|
||||
|
||||
fn list_conn_names(manager: &ConnectionManager) -> String {
|
||||
manager.list_connections()
|
||||
.iter()
|
||||
.map(|c| c.name.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
|
||||
/// 将 MySQL Row 转换为字符串向量
|
||||
fn row_to_strings(row: &MysqlRow, col_count: usize) -> Vec<Option<String>> {
|
||||
(0..col_count)
|
||||
.map(|i| {
|
||||
match row.get::<Value, usize>(i) {
|
||||
Some(value) => value_to_string(&value),
|
||||
None => None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 将 Value 转换为字符串
|
||||
fn value_to_string(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
Value::NULL => None,
|
||||
Value::Bytes(bytes) => String::from_utf8(bytes.clone()).ok(),
|
||||
Value::Int(i) => Some(i.to_string()),
|
||||
Value::UInt(u) => Some(u.to_string()),
|
||||
Value::Float(f) => Some(f.to_string()),
|
||||
Value::Double(d) => Some(d.to_string()),
|
||||
Value::Date(year, month, day, hour, min, sec, micro) => {
|
||||
match (hour, min, sec, micro) {
|
||||
(0, 0, 0, 0) => Some(format!("{:04}-{:02}-{:02}", year, month, day)),
|
||||
_ => Some(format!("{:04}-{:02}-{:02} {:02}:{:02}:{:02}", year, month, day, hour, min, sec)),
|
||||
}
|
||||
}
|
||||
Value::Time(neg, days, hours, minutes, seconds, microseconds) => {
|
||||
let sign = if *neg { "-" } else { "" };
|
||||
Some(format!("{}{} {}:{:02}:{:02}.{:06}", sign, days, hours, minutes, seconds, microseconds))
|
||||
}
|
||||
}
|
||||
}
|
||||
228
mysql-proxy/src/logger.rs
Normal file
228
mysql-proxy/src/logger.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// 日志记录器
|
||||
pub struct RequestLogger {
|
||||
log_file: Mutex<Option<File>>,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl RequestLogger {
|
||||
pub fn new(log_path: Option<&str>) -> Self {
|
||||
let (log_file, enabled) = if let Some(path) = log_path {
|
||||
let path = expand_path(path);
|
||||
if let Some(parent) = path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
match OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)
|
||||
{
|
||||
Ok(file) => (Some(file), true),
|
||||
Err(e) => {
|
||||
eprintln!("[Logger] Failed to open log file {}: {}", path.display(), e);
|
||||
(None, false)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(None, false)
|
||||
};
|
||||
|
||||
Self {
|
||||
log_file: Mutex::new(log_file),
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
/// 记录请求日志
|
||||
pub fn log(&self, entry: &LogEntry) {
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
let json = match serde_json::to_string(entry) {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
eprintln!("[Logger] Failed to serialize log: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Ok(mut file) = self.log_file.lock() {
|
||||
if let Some(ref mut f) = *file {
|
||||
let _ = writeln!(f, "{}", json);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 日志条目
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LogEntry {
|
||||
pub timestamp: String,
|
||||
pub level: String,
|
||||
#[serde(rename = "type")]
|
||||
pub log_type: String,
|
||||
pub client: String,
|
||||
pub endpoint: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub conn: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub server: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sql: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub command: Option<String>,
|
||||
#[serde(rename = "durationMs")]
|
||||
pub duration_ms: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rows: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "exitCode")]
|
||||
pub exit_code: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl LogEntry {
|
||||
pub fn new(endpoint: &str, client: &str) -> Self {
|
||||
Self {
|
||||
timestamp: current_timestamp(),
|
||||
level: "INFO".to_string(),
|
||||
log_type: "request".to_string(),
|
||||
client: client.to_string(),
|
||||
endpoint: endpoint.to_string(),
|
||||
conn: None,
|
||||
server: None,
|
||||
sql: None,
|
||||
command: None,
|
||||
duration_ms: 0,
|
||||
rows: None,
|
||||
exit_code: None,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_conn(mut self, conn: &str) -> Self {
|
||||
self.conn = Some(conn.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_server(mut self, server: &str) -> Self {
|
||||
self.server = Some(server.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_sql(mut self, sql: &str) -> Self {
|
||||
// 截断过长的 SQL
|
||||
self.sql = Some(truncate_string(sql, 1000));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_command(mut self, command: &str) -> Self {
|
||||
// 截断过长的命令
|
||||
self.command = Some(truncate_string(command, 500));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_duration(mut self, ms: u64) -> Self {
|
||||
self.duration_ms = ms;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_rows(mut self, rows: usize) -> Self {
|
||||
self.rows = Some(rows);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_exit_code(mut self, code: i32) -> Self {
|
||||
self.exit_code = Some(code);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_error(mut self, error: &str) -> Self {
|
||||
self.level = "ERROR".to_string();
|
||||
self.error = Some(error.to_string());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn current_timestamp() -> String {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default();
|
||||
let secs = now.as_secs();
|
||||
let datetime = chrono_timestamp(secs);
|
||||
format!("{}Z", datetime)
|
||||
}
|
||||
|
||||
fn chrono_timestamp(secs: u64) -> String {
|
||||
let days = secs / 86400;
|
||||
let remaining = secs % 86400;
|
||||
let hours = remaining / 3600;
|
||||
let minutes = (remaining % 3600) / 60;
|
||||
let seconds = remaining % 60;
|
||||
|
||||
// 从 1970-01-01 开始计算日期
|
||||
let mut year = 1970;
|
||||
let mut days_left = days;
|
||||
|
||||
loop {
|
||||
let days_in_year = if is_leap_year(year) { 366 } else { 365 };
|
||||
if days_left < days_in_year {
|
||||
break;
|
||||
}
|
||||
days_left -= days_in_year;
|
||||
year += 1;
|
||||
}
|
||||
|
||||
let month_days = if is_leap_year(year) {
|
||||
[31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
|
||||
} else {
|
||||
[31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
|
||||
};
|
||||
|
||||
let mut month = 1;
|
||||
for &days_in_month in &month_days {
|
||||
if days_left < days_in_month {
|
||||
break;
|
||||
}
|
||||
days_left -= days_in_month;
|
||||
month += 1;
|
||||
}
|
||||
let day = days_left + 1;
|
||||
|
||||
format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}", year, month, day, hours, minutes, seconds)
|
||||
}
|
||||
|
||||
fn is_leap_year(year: u64) -> bool {
|
||||
(year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
|
||||
}
|
||||
|
||||
fn truncate_string(s: &str, max_len: usize) -> String {
|
||||
if s.len() <= max_len {
|
||||
s.to_string()
|
||||
} else {
|
||||
format!("{}... (truncated)", &s[..max_len])
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_path(path: &str) -> PathBuf {
|
||||
if path.starts_with('~') {
|
||||
let home = std::env::var("HOME")
|
||||
.or_else(|_| std::env::var("USERPROFILE"))
|
||||
.unwrap_or_default();
|
||||
PathBuf::from(path.replacen('~', &home, 1))
|
||||
} else {
|
||||
PathBuf::from(path)
|
||||
}
|
||||
}
|
||||
212
mysql-proxy/src/main.rs
Normal file
212
mysql-proxy/src/main.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
mod cli;
|
||||
mod config;
|
||||
mod db;
|
||||
mod handler;
|
||||
mod logger;
|
||||
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use clap::{Parser, Subcommand};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "mysql-proxy")]
|
||||
#[command(about = "MySQL HTTP proxy with connection pooling")]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
|
||||
/// Config file path
|
||||
#[arg(short, long, default_value = "mysql-proxy.toml", global = true)]
|
||||
config: String,
|
||||
|
||||
/// Server URL (for CLI mode)
|
||||
#[arg(short = 'S', long, default_value = "http://127.0.0.1:3307", global = true)]
|
||||
server: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
enum Commands {
|
||||
/// Start HTTP server (default)
|
||||
Server {
|
||||
/// Port to listen on
|
||||
#[arg(short = 'P', long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Host to bind
|
||||
#[arg(short = 'H', long)]
|
||||
host: Option<String>,
|
||||
},
|
||||
|
||||
/// Execute SQL via proxy (for AI clients)
|
||||
Cli {
|
||||
/// Connection name
|
||||
#[arg(short, long, default_value = "flux_dev")]
|
||||
conn: String,
|
||||
|
||||
/// SQL to execute (与 mysql 官方一致)
|
||||
#[arg(short = 'e', long)]
|
||||
sql: Option<String>,
|
||||
|
||||
/// Output format: table, json, csv, vertical
|
||||
#[arg(short = 'F', long, default_value = "table")]
|
||||
format: String,
|
||||
|
||||
/// Execute INSERT/UPDATE/DELETE
|
||||
#[arg(short = 'x', long)]
|
||||
execute: bool,
|
||||
},
|
||||
|
||||
/// List available connections
|
||||
Connections,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
match args.command {
|
||||
Some(Commands::Server { port, host }) => {
|
||||
run_server(&args.config, port, host).await
|
||||
}
|
||||
Some(Commands::Cli { conn, sql, format, execute }) => {
|
||||
run_cli(&args.server, &conn, sql.as_deref(), &format, execute)
|
||||
}
|
||||
Some(Commands::Connections) => {
|
||||
list_connections(&args.server)
|
||||
}
|
||||
None => {
|
||||
run_server(&args.config, None, None).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 启动 HTTP 服务器
|
||||
async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>) -> anyhow::Result<()> {
|
||||
println!("MySQL HTTP Proxy v0.1.0\n");
|
||||
|
||||
// 加载配置
|
||||
let mut config = config::Config::from_file(config_path)?;
|
||||
|
||||
// 命令行参数覆盖配置文件
|
||||
if let Some(port) = port {
|
||||
config.server.port = port;
|
||||
}
|
||||
if let Some(host) = host {
|
||||
config.server.host = host;
|
||||
}
|
||||
|
||||
// 初始化日志
|
||||
let log_path = std::env::var("MYSQL_PROXY_LOG").ok();
|
||||
let logger = Arc::new(logger::RequestLogger::new(log_path.as_deref()));
|
||||
if logger.is_enabled() {
|
||||
println!("Request logging: enabled");
|
||||
}
|
||||
|
||||
// 初始化连接池
|
||||
println!("Initializing connection pools...\n");
|
||||
let manager = Arc::new(db::ConnectionManager::new(&config.connections, config.pool.clone())?);
|
||||
|
||||
// 启动后台清理任务
|
||||
let manager_clone = manager.clone();
|
||||
let check_interval = config.pool.check_interval_secs;
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(
|
||||
std::time::Duration::from_secs(check_interval)
|
||||
);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
manager_clone.cleanup_idle();
|
||||
}
|
||||
});
|
||||
|
||||
// 构建路由
|
||||
let app = Router::new()
|
||||
.route("/query", post(handler::query))
|
||||
.route("/execute", post(handler::execute))
|
||||
.route("/connections", get(handler::connections))
|
||||
.route("/connections/add", post(handler::add_connection))
|
||||
.route("/health", get(handler::health))
|
||||
.with_state(Arc::new((manager, logger)));
|
||||
|
||||
// 启动服务器
|
||||
let addr = format!("{}:{}", config.server.host, config.server.port);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
|
||||
println!("\nServer started at http://{}", addr);
|
||||
println!("\nAPI Endpoints:");
|
||||
println!(" POST /query - Execute SELECT queries");
|
||||
println!(" POST /execute - Execute INSERT/UPDATE/DELETE");
|
||||
println!(" GET /connections - List all connections");
|
||||
println!(" POST /connections/add - Add connection (temporary)");
|
||||
println!(" GET /health - Health check");
|
||||
println!("\nCLI Usage:");
|
||||
println!(" mysql-proxy cli -c flux_dev -e \"SELECT 1\"");
|
||||
println!(" mysql-proxy connections");
|
||||
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// CLI 模式
|
||||
fn run_cli(server: &str, conn: &str, sql: Option<&str>, format: &str, execute: bool) -> anyhow::Result<()> {
|
||||
let cli = cli::Cli::new(Some(server.to_string()));
|
||||
|
||||
// 检查服务器是否运行
|
||||
if !cli.check_server()? {
|
||||
eprintln!();
|
||||
eprintln!("=== mysql-proxy 未运行,请使用降级方案 ===");
|
||||
eprintln!();
|
||||
eprintln!("启动代理: mysql-proxy");
|
||||
eprintln!();
|
||||
eprintln!("或直接使用 mysql 命令:");
|
||||
eprintln!(" mysql -h<host> -u<user> -p<password> -D<database> -e \"<SQL>\"");
|
||||
eprintln!();
|
||||
eprintln!("配置文件中的连接信息见: mysql-proxy.toml");
|
||||
anyhow::bail!("Proxy server not running at {}", server);
|
||||
}
|
||||
|
||||
// 如果没有提供 SQL,从 stdin 读取
|
||||
let sql = match sql {
|
||||
Some(s) => s.to_string(),
|
||||
None => {
|
||||
use std::io::{self, BufRead};
|
||||
let stdin = io::stdin();
|
||||
let mut lines = Vec::new();
|
||||
for line in stdin.lock().lines() {
|
||||
lines.push(line?);
|
||||
}
|
||||
lines.join(" ")
|
||||
}
|
||||
};
|
||||
|
||||
if sql.trim().is_empty() {
|
||||
anyhow::bail!("No SQL provided. Use -e \"SQL\" or pipe SQL via stdin");
|
||||
}
|
||||
|
||||
if execute {
|
||||
cli.execute(conn, &sql)?;
|
||||
} else {
|
||||
cli.query(conn, &sql, Some(format))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 列出连接
|
||||
fn list_connections(server: &str) -> anyhow::Result<()> {
|
||||
let cli = cli::Cli::new(Some(server.to_string()));
|
||||
|
||||
if !cli.check_server()? {
|
||||
eprintln!();
|
||||
eprintln!("=== mysql-proxy 未运行 ===");
|
||||
eprintln!("启动代理: mysql-proxy");
|
||||
eprintln!("配置文件: mysql-proxy.toml");
|
||||
anyhow::bail!("Proxy server not running at {}", server);
|
||||
}
|
||||
|
||||
cli.list_connections()
|
||||
}
|
||||
Reference in New Issue
Block a user