Files
rust-work/mysql-proxy/src/cli.rs
绝尘 11203f036f 新增: MySQL/SSH 代理工具
- mysql-proxy: MySQL HTTP 代理,连接池复用
- ssh-proxy: SSH HTTP 代理,会话复用
- mysql-cli: 轻量级 MySQL CLI 工具

功能特性:
- 延迟初始化,启动快
- CLI 和 HTTP API 双模式
- 请求日志支持
- 错误友好提示
- JSON 极简输出格式
2026-03-19 14:03:12 +08:00

287 lines
8.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};
use std::time::Instant;
/// Proxy base URL used when [`Cli::new`] is given `None`.
const DEFAULT_SERVER: &str = "http://127.0.0.1:3307";
/// JSON body POSTed to the proxy's `/query` and `/execute` endpoints.
///
/// Field order is significant only for the byte layout of the serialized JSON.
#[derive(Serialize, Debug)]
struct QueryRequest {
    /// Connection identifier understood by the proxy.
    conn: String,
    /// SQL text, forwarded verbatim.
    sql: String,
}
/// Successful result of the `/query` endpoint.
///
/// On the wire the multi-word fields are camelCase (`rowCount`, `durationMs`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct QueryResponse {
    /// Column names, in result-set order.
    columns: Vec<String>,
    /// Row values as text; `None` is SQL NULL.
    rows: Vec<Vec<Option<String>>>,
    /// Row count as reported by the proxy.
    row_count: usize,
    /// Time reported by the proxy, printed as "ms db".
    duration_ms: u64,
}
/// Error payload returned by the proxy: `{"error": "<message>"}`.
#[derive(Deserialize, Debug)]
struct ErrorResponse {
    /// Human-readable error message.
    error: String,
}
/// Successful result of the `/connections` endpoint.
#[derive(Deserialize, Debug)]
struct ConnectionsResponse {
    /// Connections known to the proxy.
    connections: Vec<ConnectionInfo>,
}
/// One entry of the `/connections` listing; every field is displayed as-is.
#[derive(Deserialize, Debug)]
struct ConnectionInfo {
    /// Connection name (the "Name" column).
    name: String,
    /// Database (the "Database" column).
    database: String,
    /// Host (the "Host" column).
    host: String,
    /// Status string reported by the proxy (the "Status" column).
    status: String,
}
/// Command-line client that talks to the MySQL HTTP proxy.
pub struct Cli {
    /// Base URL of the proxy, without a trailing path (e.g. `http://127.0.0.1:3307`).
    server: String,
}
impl Cli {
pub fn new(server: Option<String>) -> Self {
Self {
server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()),
}
}
/// 执行查询并输出结果
pub fn query(&self, conn: &str, sql: &str, format: Option<&str>) -> Result<()> {
let start = Instant::now();
let body = serde_json::to_string(&QueryRequest {
conn: conn.to_string(),
sql: sql.to_string(),
})?;
let response = ureq::post(&format!("{}/query", self.server))
.set("Content-Type", "application/json")
.send_string(&body);
// ureq 2.x: 非 2xx 会返回 Err但响应体在 error 中
let body = match response {
Ok(r) => r.into_string()?,
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
Err(e) => bail!("HTTP error: {}", e),
};
let total_ms = start.elapsed().as_millis();
// 解析响应
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
eprintln!("Error: {}", err.error);
return Ok(());
}
let result: QueryResponse = serde_json::from_str(&body)?;
// 输出结果
match format {
Some("json") => self.print_json(&result),
Some("csv") => self.print_csv(&result),
Some("vertical") | Some("vert") => self.print_vertical(&result),
_ => self.print_table(&result),
}
println!("\n{} rows in set ({}ms db, {}ms total)",
result.row_count, result.duration_ms, total_ms);
Ok(())
}
/// 执行语句 (INSERT/UPDATE/DELETE)
pub fn execute(&self, conn: &str, sql: &str) -> Result<()> {
let start = Instant::now();
let body = serde_json::to_string(&QueryRequest {
conn: conn.to_string(),
sql: sql.to_string(),
})?;
let response = ureq::post(&format!("{}/execute", self.server))
.set("Content-Type", "application/json")
.send_string(&body);
// ureq 2.x: 非 2xx 会返回 Err但响应体在 error 中
let body = match response {
Ok(r) => r.into_string()?,
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
Err(e) => bail!("HTTP error: {}", e),
};
let total_ms = start.elapsed().as_millis();
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
eprintln!("Error: {}", err.error);
return Ok(());
}
#[derive(Deserialize)]
struct ExecuteResponse {
#[serde(rename = "affectedRows")]
affected_rows: u64,
#[serde(rename = "lastInsertId")]
last_insert_id: u64,
}
let result: ExecuteResponse = serde_json::from_str(&body)?;
println!("Query OK, {} rows affected ({}ms total)", result.affected_rows, total_ms);
if result.last_insert_id > 0 {
println!("Last insert ID: {}", result.last_insert_id);
}
Ok(())
}
/// 列出连接
pub fn list_connections(&self) -> Result<()> {
let response = ureq::get(&format!("{}/connections", self.server))
.call()?;
let body = response.into_string()?;
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
bail!("{}", err.error);
}
let result: ConnectionsResponse = serde_json::from_str(&body)?;
println!("Connections:");
println!("{:<15} {:<20} {:<40} {:<10}", "Name", "Database", "Host", "Status");
println!("{}", "-".repeat(85));
for conn in result.connections {
println!("{:<15} {:<20} {:<40} {:<10}",
conn.name, conn.database, conn.host, conn.status);
}
Ok(())
}
/// 检查代理是否运行
pub fn check_server(&self) -> Result<bool> {
match ureq::get(&format!("{}/health", self.server)).call() {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
// 输出格式化方法
fn print_table(&self, result: &QueryResponse) {
if result.columns.is_empty() {
return;
}
// 计算列宽
let mut widths: Vec<usize> = result.columns.iter().map(|c| c.len()).collect();
for row in &result.rows {
for (i, val) in row.iter().enumerate() {
let len = val.as_ref().map(|s| s.len()).unwrap_or(4); // NULL
widths[i] = widths[i].max(len);
}
}
// 打印表头
let header: String = result.columns.iter()
.enumerate()
.map(|(i, c)| format!(" {:width$} ", c, width = widths[i]))
.collect::<Vec<_>>()
.join("|");
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
println!("|{}|", header);
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
// 打印数据行
for row in &result.rows {
let line: String = row.iter()
.enumerate()
.map(|(i, val)| {
match val {
Some(s) => format!(" {:width$} ", s, width = widths[i]),
None => format!(" {:width$} ", "NULL", width = widths[i]),
}
})
.collect::<Vec<_>>()
.join("|");
println!("|{}|", line);
}
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
}
fn print_json(&self, result: &QueryResponse) {
// 极简格式:对象数组
let rows: Vec<serde_json::Value> = result.rows.iter().map(|row| {
let mut obj = serde_json::Map::new();
for (i, col) in result.columns.iter().enumerate() {
let val = row.get(i)
.and_then(|v| v.as_ref())
.map(|s| {
// 尝试解析为数字
if let Ok(n) = s.parse::<i64>() {
serde_json::Value::Number(n.into())
} else if let Ok(n) = s.parse::<f64>() {
serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)))
} else {
serde_json::Value::String(s.clone())
}
})
.unwrap_or(serde_json::Value::Null);
obj.insert(col.clone(), val);
}
serde_json::Value::Object(obj)
}).collect();
println!("{}", serde_json::to_string(&rows).unwrap_or_default());
}
fn print_csv(&self, result: &QueryResponse) {
// 打印表头
println!("{}", result.columns.join(","));
// 打印数据
for row in &result.rows {
let line: String = row.iter()
.map(|val| {
match val {
Some(s) => {
if s.contains(',') || s.contains('"') || s.contains('\n') {
format!("\"{}\"", s.replace('"', "\"\""))
} else {
s.clone()
}
}
None => "".to_string(),
}
})
.collect::<Vec<_>>()
.join(",");
println!("{}", line);
}
}
fn print_vertical(&self, result: &QueryResponse) {
for (row_idx, row) in result.rows.iter().enumerate() {
if row_idx > 0 {
println!();
}
println!("*************************** {}. row ***************************", row_idx + 1);
for (col_idx, col) in result.columns.iter().enumerate() {
let val = row.get(col_idx)
.and_then(|v| v.as_ref())
.map(|s| s.as_str())
.unwrap_or("NULL");
println!("{:>20}: {}", col, val);
}
}
}
}