新增: MySQL/SSH 代理工具
- mysql-proxy: MySQL HTTP 代理,连接池复用 - ssh-proxy: SSH HTTP 代理,会话复用 - mysql-cli: 轻量级 MySQL CLI 工具 功能特性: - 延迟初始化,启动快 - CLI 和 HTTP API 双模式 - 请求日志支持 - 错误友好提示 - JSON 极简输出格式
This commit is contained in:
14
.gitignore
vendored
Normal file
14
.gitignore
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# 编译输出
|
||||||
|
target/
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
logs/
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
# 配置文件(含敏感信息)
|
||||||
|
mysql-proxy/mysql-proxy.toml
|
||||||
|
ssh-proxy/ssh-proxy.toml
|
||||||
100
INCUBATOR.md
Normal file
100
INCUBATOR.md
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# 代理工具孵化记录
|
||||||
|
|
||||||
|
> 状态:孵化中 | 更新:2026-03-19
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 工具列表
|
||||||
|
|
||||||
|
| 工具 | 端口 | 状态 | 项目目录 |
|
||||||
|
|------|------|------|----------|
|
||||||
|
| mysql-proxy | 3307 | 可用 | `mysql-proxy` |
|
||||||
|
| ssh-proxy | 3308 | 可用 | `ssh-proxy` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 新增功能
|
||||||
|
|
||||||
|
### 2026-03-19
|
||||||
|
|
||||||
|
1. **请求日志** (P0)
|
||||||
|
- 设置环境变量启用:`MYSQL_PROXY_LOG` / `SSH_PROXY_LOG`
|
||||||
|
- 日志格式:JSON,每行一条记录
|
||||||
|
|
||||||
|
2. **ssh-proxy CLI JSON 输出** (P1)
|
||||||
|
```bash
|
||||||
|
ssh-proxy exec -n flux_dev -c "docker ps" -F json
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **ssh-proxy 动态添加服务器** (P2)
|
||||||
|
```bash
|
||||||
|
ssh-proxy add-server -n myserver -H 192.168.1.100 -u root -p password
|
||||||
|
# 或 API
|
||||||
|
curl -X POST http://127.0.0.1:3308/servers/add \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"name":"myserver","host":"192.168.1.100","user":"root","password":"secret"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **API 错误友好提示**
|
||||||
|
- 错误时返回 `usage` 字段,引导正确使用
|
||||||
|
|
||||||
|
### 2026-03-18
|
||||||
|
|
||||||
|
- mysql-proxy 动态添加连接 API
|
||||||
|
- ssh-proxy Windows 路径修复
|
||||||
|
- ssh-proxy 读取错误日志修复
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实测性能
|
||||||
|
|
||||||
|
| 操作 | 代理 | 直连 | 提升 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| mysql 首次查询 | ~150ms | ~500ms | 3x |
|
||||||
|
| mysql 复用会话 | ~50ms | ~500ms | 10x |
|
||||||
|
| ssh 首次执行 | ~900ms | ~1500ms | 1.7x |
|
||||||
|
| ssh 复用会话 | ~100-200ms | ~1500ms | 7-15x |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## AI 推荐使用方式
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# MySQL 查询 (AI 友好)
|
||||||
|
curl -X POST http://127.0.0.1:3307/query \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"conn":"flux_dev","sql":"SELECT VERSION()"}'
|
||||||
|
|
||||||
|
# SSH 执行 (AI 友好)
|
||||||
|
curl -X POST http://127.0.0.1:3308/exec \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"server":"flux_dev","command":"docker ps"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
**优势**:
|
||||||
|
- JSON 格式,易于解析
|
||||||
|
- 会话复用,响应快
|
||||||
|
- 错误提示友好
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 已知问题
|
||||||
|
|
||||||
|
### ssh-proxy
|
||||||
|
|
||||||
|
1. **ssh2 不支持 ed25519 密钥**
|
||||||
|
- 错误: `[Session(-19)] Callback returned error`
|
||||||
|
- 解决: 使用 PEM 格式 RSA 密钥
|
||||||
|
```bash
|
||||||
|
ssh-keygen -t rsa -b 2048 -f ~/.ssh/id_rsa_pem -m PEM -N ""
|
||||||
|
```
|
||||||
|
|
||||||
|
### mysql-proxy
|
||||||
|
|
||||||
|
- 暂无
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 待优化
|
||||||
|
|
||||||
|
*由 AI 在实际使用过程中根据遇到的问题填写*
|
||||||
25
INDEX.md
Normal file
25
INDEX.md
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Rust Work 索引
|
||||||
|
|
||||||
|
> 更新:2026-03-18
|
||||||
|
|
||||||
|
## 项目列表
|
||||||
|
|
||||||
|
| 项目 | 端口 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| [mysql-proxy](mysql-proxy) | 3307 | MySQL 会话复用代理 |
|
||||||
|
| [ssh-proxy](ssh-proxy) | 3308 | SSH 会话复用代理 |
|
||||||
|
| [mysql-cli](mysql-cli) | - | MySQL CLI 工具 (开发中) |
|
||||||
|
|
||||||
|
## 孵化记录
|
||||||
|
|
||||||
|
- [INCUBATOR.md](INCUBATOR.md) - 问题/性能/建议
|
||||||
|
|
||||||
|
## 快速启动
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# mysql-proxy
|
||||||
|
cd mysql-proxy && ./target/release/mysql-proxy.exe
|
||||||
|
|
||||||
|
# ssh-proxy
|
||||||
|
cd ssh-proxy && ./target/release/ssh-proxy.exe
|
||||||
|
```
|
||||||
34
mysql-cli/Cargo.toml
Normal file
34
mysql-cli/Cargo.toml
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
[package]
|
||||||
|
name = "mysql-cli"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
description = "A lightweight MySQL command-line tool written in Rust"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
# MySQL driver
|
||||||
|
mysql = "25"
|
||||||
|
|
||||||
|
# Command line argument parsing
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
|
||||||
|
# REPL with history and completion
|
||||||
|
rustyline = { version = "14", features = ["derive"] }
|
||||||
|
|
||||||
|
# Serialization
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
|
||||||
|
# CSV output
|
||||||
|
csv = "1"
|
||||||
|
|
||||||
|
# Error handling
|
||||||
|
anyhow = "1"
|
||||||
|
thiserror = "1"
|
||||||
|
|
||||||
|
# Check if stdin is a tty
|
||||||
|
atty = "0.2"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = "z"
|
||||||
|
lto = true
|
||||||
|
strip = true
|
||||||
148
mysql-cli/README.md
Normal file
148
mysql-cli/README.md
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
# MySQL CLI (Rust)
|
||||||
|
|
||||||
|
> 💡 **快速开始**:运行 `mysql --help` 查看完整使用文档
|
||||||
|
|
||||||
|
一个轻量级的 Rust 语言 MySQL 命令行工具,兼容 MySQL 官方客户端用法,可以在未安装 MySQL 客户端的系统中执行 SQL 语句。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- ✅ **交互式模式**:支持完整的 REPL 交互式 Shell
|
||||||
|
- ✅ **批量执行**:支持从 SQL 文件执行多条语句
|
||||||
|
- ✅ **多种输出格式**:支持表格、CSV、JSON、垂直格式
|
||||||
|
- ✅ **完整 SQL 支持**:支持 SELECT、INSERT、UPDATE、DELETE 等所有 SQL 语句
|
||||||
|
- ✅ **特殊命令**:支持 USE、SHOW、DESCRIBE、EXPLAIN 等 MySQL 命令
|
||||||
|
- ✅ **紧凑体积**:编译后仅 2.5MB,无需安装 MySQL 客户端
|
||||||
|
|
||||||
|
## 编译
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo build --release
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 基本语法
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./mysql.exe -h <host> -P <port> -u <username> -p <password> -D <database> -e "<SQL>"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 参数说明
|
||||||
|
|
||||||
|
| 参数 | 说明 | 默认值 |
|
||||||
|
|------|------|--------|
|
||||||
|
| -h | MySQL 主机地址 | 127.0.0.1 |
|
||||||
|
| -P | MySQL 端口 | 3306 |
|
||||||
|
| -u | 用户名 | root |
|
||||||
|
| -p | 密码 | (空) |
|
||||||
|
| -D | 数据库名 | (空) |
|
||||||
|
| -e | 要执行的 SQL 语句 | (可选) |
|
||||||
|
| -f | 从 SQL 文件执行 | (可选) |
|
||||||
|
| -F | 输出格式:table, csv, json, vertical | table |
|
||||||
|
| -o | 输出到文件 | stdout |
|
||||||
|
| -i | 强制进入交互模式 | (可选) |
|
||||||
|
|
||||||
|
**注意**:不指定 `-e` 或 `-f` 时,需要使用 `-i` 参数进入交互模式
|
||||||
|
|
||||||
|
### 使用示例
|
||||||
|
|
||||||
|
#### 1. 交互式模式
|
||||||
|
|
||||||
|
进入交互式 Shell:
|
||||||
|
```bash
|
||||||
|
# 基本连接并进入交互模式
|
||||||
|
mysql -p123456 -i
|
||||||
|
|
||||||
|
# 指定主机和用户
|
||||||
|
mysql -h127.0.0.1 -uroot -p123456 -i
|
||||||
|
```
|
||||||
|
|
||||||
|
交互式命令:
|
||||||
|
- `use <database>` - 切换数据库
|
||||||
|
- `status` - 显示连接状态
|
||||||
|
- `source <file>` - 执行 SQL 文件
|
||||||
|
- `format <type>` - 设置输出格式(table, csv, json, vertical)
|
||||||
|
- `output <file>` - 设置输出文件
|
||||||
|
- `exit / quit` - 退出程序
|
||||||
|
|
||||||
|
#### 2. 直接执行 SQL
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 查询数据
|
||||||
|
mysql -h127.0.0.1 -uroot -p123456 -Dmydb -e "SELECT * FROM users LIMIT 10"
|
||||||
|
|
||||||
|
# 显示表结构
|
||||||
|
mysql -p123456 -Dmydb -e "DESCRIBE users"
|
||||||
|
|
||||||
|
# 显示所有数据库
|
||||||
|
mysql -p123456 -e "SHOW DATABASES"
|
||||||
|
|
||||||
|
# 插入数据
|
||||||
|
mysql -p123456 -Dmydb -e "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')"
|
||||||
|
|
||||||
|
# 更新数据
|
||||||
|
mysql -p123456 -Dmydb -e "UPDATE users SET email='new@email.com' WHERE id=1"
|
||||||
|
|
||||||
|
# 删除数据
|
||||||
|
mysql -p123456 -Dmydb -e "DELETE FROM users WHERE id=1"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. 从 SQL 文件执行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 执行 SQL 文件
|
||||||
|
mysql -p123456 -f script.sql
|
||||||
|
|
||||||
|
# 执行并输出到 CSV
|
||||||
|
mysql -p123456 -f script.sql -F csv -o result.csv
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. 不同输出格式
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# CSV 格式
|
||||||
|
mysql -p123456 -e "SELECT * FROM users" -F csv
|
||||||
|
|
||||||
|
# JSON 格式
|
||||||
|
mysql -p123456 -e "SELECT * FROM users" -F json
|
||||||
|
|
||||||
|
# 垂直格式(适合字段较多的记录)
|
||||||
|
mysql -p123456 -e "SELECT * FROM users" -F vertical
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
mysql-cli/
|
||||||
|
├── Cargo.toml # Rust 项目配置
|
||||||
|
├── README.md # 项目文档
|
||||||
|
└── src/
|
||||||
|
├── main.rs # 主程序入口
|
||||||
|
├── config.rs # 配置管理
|
||||||
|
├── db.rs # 数据库连接
|
||||||
|
├── executor.rs # SQL 执行引擎
|
||||||
|
└── repl.rs # 交互式 REPL
|
||||||
|
```
|
||||||
|
|
||||||
|
## 依赖
|
||||||
|
|
||||||
|
- Rust 1.70+
|
||||||
|
- mysql (25.x)
|
||||||
|
- clap (4.x) - 命令行参数解析
|
||||||
|
- rustyline (14.x) - 交互式 REPL
|
||||||
|
- serde / serde_json - JSON 序列化
|
||||||
|
- csv - CSV 输出
|
||||||
|
|
||||||
|
## 与 Go 版本对比
|
||||||
|
|
||||||
|
| 功能 | Rust 版本 | Go 版本 |
|
||||||
|
|------|-----------|---------|
|
||||||
|
| 交互式模式 | ✅ | ✅ |
|
||||||
|
| SQL 文件执行 | ✅ | ✅ |
|
||||||
|
| 输出格式 | 表格/CSV/JSON/垂直 | 表格/CSV/JSON/垂直 |
|
||||||
|
| 二进制大小 | ~2.5MB | ~7.3MB |
|
||||||
|
| 内存占用 | 更低 | 较高 |
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
115
mysql-cli/src/config.rs
Normal file
115
mysql-cli/src/config.rs
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
/// 数据库连接配置
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
pub database: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
port: 3306,
|
||||||
|
username: "root".to_string(),
|
||||||
|
password: String::new(),
|
||||||
|
database: String::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
/// 构建 MySQL DSN
|
||||||
|
pub fn build_dsn(&self) -> String {
|
||||||
|
format!(
|
||||||
|
"mysql://{}:{}@{}:{}/{}",
|
||||||
|
self.username,
|
||||||
|
urlencoding(&self.password),
|
||||||
|
self.host,
|
||||||
|
self.port,
|
||||||
|
self.database
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 构建不带数据库的 DSN
|
||||||
|
pub fn build_dsn_without_db(&self) -> String {
|
||||||
|
format!(
|
||||||
|
"mysql://{}:{}@{}:{}/",
|
||||||
|
self.username,
|
||||||
|
urlencoding(&self.password),
|
||||||
|
self.host,
|
||||||
|
self.port
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Config {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"{}@{}:{}",
|
||||||
|
self.username, self.host, self.port
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// URL 编码密码中的特殊字符
|
||||||
|
fn urlencoding(s: &str) -> String {
|
||||||
|
let mut result = String::new();
|
||||||
|
for c in s.chars() {
|
||||||
|
match c {
|
||||||
|
'@' => result.push_str("%40"),
|
||||||
|
':' => result.push_str("%3A"),
|
||||||
|
'/' => result.push_str("%2F"),
|
||||||
|
'?' => result.push_str("%3F"),
|
||||||
|
'#' => result.push_str("%23"),
|
||||||
|
'[' => result.push_str("%5B"),
|
||||||
|
']' => result.push_str("%5D"),
|
||||||
|
'!' => result.push_str("%21"),
|
||||||
|
'$' => result.push_str("%24"),
|
||||||
|
'&' => result.push_str("%26"),
|
||||||
|
'\'' => result.push_str("%27"),
|
||||||
|
'(' => result.push_str("%28"),
|
||||||
|
')' => result.push_str("%29"),
|
||||||
|
'*' => result.push_str("%2A"),
|
||||||
|
'+' => result.push_str("%2B"),
|
||||||
|
',' => result.push_str("%2C"),
|
||||||
|
';' => result.push_str("%3B"),
|
||||||
|
'=' => result.push_str("%3D"),
|
||||||
|
'%' => result.push_str("%25"),
|
||||||
|
' ' => result.push_str("%20"),
|
||||||
|
_ => result.push(c),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_dsn() {
|
||||||
|
let config = Config {
|
||||||
|
host: "localhost".to_string(),
|
||||||
|
port: 3306,
|
||||||
|
username: "root".to_string(),
|
||||||
|
password: "123456".to_string(),
|
||||||
|
database: "test".to_string(),
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
config.build_dsn(),
|
||||||
|
"mysql://root:123456@localhost:3306/test"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_urlencoding() {
|
||||||
|
assert_eq!(urlencoding("a@b:c"), "a%40b%3Ac");
|
||||||
|
assert_eq!(urlencoding("p@ss!word"), "p%40ss%21word");
|
||||||
|
}
|
||||||
|
}
|
||||||
60
mysql-cli/src/db.rs
Normal file
60
mysql-cli/src/db.rs
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use mysql::{Pool, PooledConn, Opts, prelude::*};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
|
||||||
|
/// 连接数据库
|
||||||
|
pub fn connect(cfg: &Config) -> Result<Pool> {
|
||||||
|
let dsn = if cfg.database.is_empty() {
|
||||||
|
cfg.build_dsn_without_db()
|
||||||
|
} else {
|
||||||
|
cfg.build_dsn()
|
||||||
|
};
|
||||||
|
|
||||||
|
let opts: Opts = Opts::from_url(&dsn)?;
|
||||||
|
let pool = Pool::new(opts)?;
|
||||||
|
|
||||||
|
// 测试连接
|
||||||
|
let mut conn = pool.get_conn()?;
|
||||||
|
conn.query::<String, _>("SELECT 1")?;
|
||||||
|
|
||||||
|
Ok(pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取连接
|
||||||
|
pub fn get_conn(pool: &Pool) -> Result<PooledConn> {
|
||||||
|
Ok(pool.get_conn()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取服务器版本
|
||||||
|
pub fn get_server_version(conn: &mut PooledConn) -> String {
|
||||||
|
conn.query_first::<String, _>("SELECT VERSION()")
|
||||||
|
.ok()
|
||||||
|
.flatten()
|
||||||
|
.unwrap_or_else(|| "Unknown".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取当前数据库
|
||||||
|
pub fn get_current_database(conn: &mut PooledConn) -> Option<String> {
|
||||||
|
conn.query_first::<String, _>("SELECT DATABASE()").ok().flatten()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有数据库
|
||||||
|
pub fn get_databases(conn: &mut PooledConn) -> Vec<String> {
|
||||||
|
conn.query::<String, _>("SHOW DATABASES")
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有表名
|
||||||
|
pub fn get_tables(conn: &mut PooledConn) -> Vec<String> {
|
||||||
|
conn.query::<String, _>("SHOW TABLES")
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取表的列信息
|
||||||
|
pub fn get_columns(conn: &mut PooledConn, table: &str) -> Vec<String> {
|
||||||
|
let sql = format!("SHOW COLUMNS FROM `{}`", table);
|
||||||
|
conn.query_map(sql, |row: mysql::Row| {
|
||||||
|
row.get::<String, _>(0).unwrap_or_default()
|
||||||
|
}).unwrap_or_default()
|
||||||
|
}
|
||||||
348
mysql-cli/src/executor.rs
Normal file
348
mysql-cli/src/executor.rs
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use mysql::{Pool, PooledConn, prelude::*, Row, Value};
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::{self, Write, BufWriter};
|
||||||
|
|
||||||
|
/// 输出格式
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub enum OutputFormat {
|
||||||
|
#[default]
|
||||||
|
Table,
|
||||||
|
Csv,
|
||||||
|
Json,
|
||||||
|
Vertical,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::str::FromStr for OutputFormat {
|
||||||
|
type Err = anyhow::Error;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self> {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"table" => Ok(Self::Table),
|
||||||
|
"csv" => Ok(Self::Csv),
|
||||||
|
"json" => Ok(Self::Json),
|
||||||
|
"vertical" => Ok(Self::Vertical),
|
||||||
|
_ => bail!("Invalid format: {}. Valid formats: table, csv, json, vertical", s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SQL 执行器
|
||||||
|
pub struct Executor {
|
||||||
|
pool: Pool,
|
||||||
|
format: OutputFormat,
|
||||||
|
output: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Executor {
|
||||||
|
pub fn new(pool: Pool, format: OutputFormat, output: Option<String>) -> Self {
|
||||||
|
Self { pool, format, output }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行 SQL(支持多条语句)
|
||||||
|
pub fn execute(&self, query: &str) -> Result<()> {
|
||||||
|
let queries = split_queries(query);
|
||||||
|
let mut conn = self.pool.get_conn()?;
|
||||||
|
|
||||||
|
for (i, q) in queries.iter().enumerate() {
|
||||||
|
if q.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if i > 0 {
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
|
||||||
|
self.execute_one(&mut conn, q)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行单条 SQL
|
||||||
|
fn execute_one(&self, conn: &mut PooledConn, query: &str) -> Result<()> {
|
||||||
|
let query = query.trim();
|
||||||
|
if query.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let upper_query = query.to_uppercase();
|
||||||
|
let is_query = upper_query.starts_with("SELECT")
|
||||||
|
|| upper_query.starts_with("SHOW")
|
||||||
|
|| upper_query.starts_with("DESCRIBE")
|
||||||
|
|| upper_query.starts_with("DESC ")
|
||||||
|
|| upper_query.starts_with("EXPLAIN")
|
||||||
|
|| upper_query.starts_with("WITH");
|
||||||
|
|
||||||
|
if is_query {
|
||||||
|
self.execute_query(conn, query)
|
||||||
|
} else {
|
||||||
|
self.execute_exec(conn, query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行查询语句
|
||||||
|
fn execute_query(&self, conn: &mut PooledConn, query: &str) -> Result<()> {
|
||||||
|
let result = conn.query_iter(query)?;
|
||||||
|
|
||||||
|
let mut column_names: Vec<String> = Vec::new();
|
||||||
|
let mut data: Vec<Vec<Option<String>>> = Vec::new();
|
||||||
|
|
||||||
|
// 获取数据
|
||||||
|
for row_result in result {
|
||||||
|
let row = row_result?;
|
||||||
|
// 从第一行获取列名
|
||||||
|
if column_names.is_empty() {
|
||||||
|
column_names = row.columns()
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.name_str().to_string())
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
let values = row_to_strings(&row);
|
||||||
|
data.push(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 输出结果
|
||||||
|
self.output_result(&column_names, &data)?;
|
||||||
|
|
||||||
|
// 显示行数
|
||||||
|
if !data.is_empty() {
|
||||||
|
println!();
|
||||||
|
println!("{} rows in set", data.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行非查询语句
|
||||||
|
fn execute_exec(&self, conn: &mut PooledConn, query: &str) -> Result<()> {
|
||||||
|
let result = conn.query_iter(query)?;
|
||||||
|
let upper_query = query.to_uppercase();
|
||||||
|
|
||||||
|
if upper_query.starts_with("INSERT")
|
||||||
|
|| upper_query.starts_with("UPDATE")
|
||||||
|
|| upper_query.starts_with("DELETE")
|
||||||
|
|| upper_query.starts_with("REPLACE")
|
||||||
|
{
|
||||||
|
let affected = result.affected_rows();
|
||||||
|
if upper_query.starts_with("INSERT") {
|
||||||
|
let last_id = result.last_insert_id().unwrap_or(0);
|
||||||
|
println!("Query OK, {} row affected, last insert ID: {}", affected, last_id);
|
||||||
|
} else {
|
||||||
|
println!("Query OK, {} row affected", affected);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
println!("Query OK");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 根据格式输出结果
|
||||||
|
fn output_result(&self, columns: &[String], data: &[Vec<Option<String>>]) -> Result<()> {
|
||||||
|
let mut output: Box<dyn Write> = match &self.output {
|
||||||
|
Some(path) => Box::new(BufWriter::new(File::create(path)?)),
|
||||||
|
None => Box::new(BufWriter::new(io::stdout())),
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.format {
|
||||||
|
OutputFormat::Table => self.output_table(&mut output, columns, data),
|
||||||
|
OutputFormat::Csv => self.output_csv(&mut output, columns, data),
|
||||||
|
OutputFormat::Json => self.output_json(&mut output, columns, data),
|
||||||
|
OutputFormat::Vertical => self.output_vertical(&mut output, columns, data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 表格格式输出
|
||||||
|
fn output_table(&self, output: &mut dyn Write, columns: &[String], data: &[Vec<Option<String>>]) -> Result<()> {
|
||||||
|
// 计算列宽
|
||||||
|
let mut widths: Vec<usize> = columns.iter().map(|c| c.len()).collect();
|
||||||
|
for row in data {
|
||||||
|
for (i, val) in row.iter().enumerate() {
|
||||||
|
let len = val.as_ref().map_or(4, |v| v.len());
|
||||||
|
if len > widths[i] {
|
||||||
|
widths[i] = len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 打印表头
|
||||||
|
print_table_row(output, columns, &widths)?;
|
||||||
|
print_table_separator(output, &widths)?;
|
||||||
|
|
||||||
|
// 打印数据行
|
||||||
|
for row in data {
|
||||||
|
let str_row: Vec<String> = row.iter().map(|v| v.clone().unwrap_or_else(|| "NULL".to_string())).collect();
|
||||||
|
print_table_row(output, &str_row, &widths)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CSV 格式输出
|
||||||
|
fn output_csv(&self, output: &mut dyn Write, columns: &[String], data: &[Vec<Option<String>>]) -> Result<()> {
|
||||||
|
let mut writer = csv::Writer::from_writer(output);
|
||||||
|
|
||||||
|
writer.write_record(columns)?;
|
||||||
|
for row in data {
|
||||||
|
let str_row: Vec<String> = row.iter().map(|v| v.clone().unwrap_or_default()).collect();
|
||||||
|
writer.write_record(&str_row)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.flush()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON 格式输出
|
||||||
|
fn output_json(&self, output: &mut dyn Write, columns: &[String], data: &[Vec<Option<String>>]) -> Result<()> {
|
||||||
|
let result: Vec<serde_json::Map<String, serde_json::Value>> = data
|
||||||
|
.iter()
|
||||||
|
.map(|row| {
|
||||||
|
columns
|
||||||
|
.iter()
|
||||||
|
.zip(row.iter())
|
||||||
|
.map(|(col, val)| {
|
||||||
|
let v = match val {
|
||||||
|
Some(s) => serde_json::Value::String(s.clone()),
|
||||||
|
None => serde_json::Value::Null,
|
||||||
|
};
|
||||||
|
(col.clone(), v)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let json = serde_json::to_string_pretty(&result)?;
|
||||||
|
writeln!(output, "{}", json)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 垂直格式输出
|
||||||
|
fn output_vertical(&self, output: &mut dyn Write, columns: &[String], data: &[Vec<Option<String>>]) -> Result<()> {
|
||||||
|
let max_len = columns.iter().map(|c| c.len()).max().unwrap_or(0);
|
||||||
|
|
||||||
|
for (row_idx, row) in data.iter().enumerate() {
|
||||||
|
if row_idx > 0 {
|
||||||
|
writeln!(output)?;
|
||||||
|
}
|
||||||
|
for (i, val) in row.iter().enumerate() {
|
||||||
|
let v = val.as_ref().map(|s| s.as_str()).unwrap_or("NULL");
|
||||||
|
writeln!(output, "{:width$}: {}", columns[i], v, width = max_len)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从文件执行 SQL
|
||||||
|
pub fn execute_from_file(&self, filename: &str) -> Result<()> {
|
||||||
|
let content = std::fs::read_to_string(filename)?;
|
||||||
|
self.execute(&content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将 Row 转换为字符串向量
|
||||||
|
fn row_to_strings(row: &Row) -> Vec<Option<String>> {
|
||||||
|
row.columns()
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, _col)| {
|
||||||
|
match row.get::<Value, usize>(i) {
|
||||||
|
Some(value) => value_to_string(&value),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将 Value 转换为字符串
|
||||||
|
fn value_to_string(value: &Value) -> Option<String> {
|
||||||
|
match value {
|
||||||
|
Value::NULL => None,
|
||||||
|
Value::Bytes(bytes) => String::from_utf8(bytes.clone()).ok(),
|
||||||
|
Value::Int(i) => Some(i.to_string()),
|
||||||
|
Value::UInt(u) => Some(u.to_string()),
|
||||||
|
Value::Float(f) => Some(f.to_string()),
|
||||||
|
Value::Double(d) => Some(d.to_string()),
|
||||||
|
Value::Date(year, month, day, hour, min, sec, micro) => {
|
||||||
|
match (hour, min, sec, micro) {
|
||||||
|
(0, 0, 0, 0) => Some(format!("{:04}-{:02}-{:02}", year, month, day)),
|
||||||
|
_ => Some(format!("{:04}-{:02}-{:02} {:02}:{:02}:{:02}", year, month, day, hour, min, sec)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value::Time(neg, days, hours, minutes, seconds, microseconds) => {
|
||||||
|
let sign = if *neg { "-" } else { "" };
|
||||||
|
Some(format!("{}{} {}:{:02}:{:02}.{:06}", sign, days, hours, minutes, seconds, microseconds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 分割多个 SQL 语句
|
||||||
|
fn split_queries(query: &str) -> Vec<String> {
|
||||||
|
let mut queries = Vec::new();
|
||||||
|
let mut current = String::new();
|
||||||
|
let mut in_quotes = false;
|
||||||
|
let mut quote_char = ' ';
|
||||||
|
let chars: Vec<char> = query.chars().collect();
|
||||||
|
|
||||||
|
for i in 0..chars.len() {
|
||||||
|
let c = chars[i];
|
||||||
|
|
||||||
|
match c {
|
||||||
|
'\'' | '"' | '`' if !in_quotes => {
|
||||||
|
in_quotes = true;
|
||||||
|
quote_char = c;
|
||||||
|
current.push(c);
|
||||||
|
}
|
||||||
|
'\'' | '"' | '`' if in_quotes && c == quote_char => {
|
||||||
|
// 检查是否是转义
|
||||||
|
let is_escaped = i > 0 && chars[i - 1] == '\\';
|
||||||
|
if !is_escaped {
|
||||||
|
in_quotes = false;
|
||||||
|
}
|
||||||
|
current.push(c);
|
||||||
|
}
|
||||||
|
';' if !in_quotes => {
|
||||||
|
queries.push(current.trim().to_string());
|
||||||
|
current.clear();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
current.push(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.trim().is_empty() {
|
||||||
|
queries.push(current.trim().to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
queries
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 打印表格行
|
||||||
|
fn print_table_row(output: &mut dyn Write, data: &[String], widths: &[usize]) -> Result<()> {
|
||||||
|
for (i, val) in data.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
write!(output, " | ")?;
|
||||||
|
}
|
||||||
|
write!(output, "{:<width$}", val, width = widths[i])?;
|
||||||
|
}
|
||||||
|
writeln!(output)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 打印表格分隔线
|
||||||
|
fn print_table_separator(output: &mut dyn Write, widths: &[usize]) -> Result<()> {
|
||||||
|
for (i, w) in widths.iter().enumerate() {
|
||||||
|
if i > 0 {
|
||||||
|
write!(output, "-+-")?;
|
||||||
|
} else {
|
||||||
|
write!(output, "+")?;
|
||||||
|
}
|
||||||
|
for _ in 0..*w {
|
||||||
|
write!(output, "-")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeln!(output, "+")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
132
mysql-cli/src/main.rs
Normal file
132
mysql-cli/src/main.rs
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use clap::Parser;
|
||||||
|
use std::io::{self, Read};
|
||||||
|
|
||||||
|
mod config;
|
||||||
|
mod db;
|
||||||
|
mod executor;
|
||||||
|
mod repl;
|
||||||
|
|
||||||
|
use config::Config;
|
||||||
|
use executor::{Executor, OutputFormat};
|
||||||
|
|
||||||
|
/// MySQL CLI - A lightweight MySQL command-line tool
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(name = "mysql")]
|
||||||
|
#[command(about = "A lightweight MySQL command-line tool", long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// MySQL host
|
||||||
|
#[arg(short = 'h', long, default_value = "127.0.0.1")]
|
||||||
|
host: String,
|
||||||
|
|
||||||
|
/// MySQL port
|
||||||
|
#[arg(short = 'P', long, default_value_t = 3306)]
|
||||||
|
port: u16,
|
||||||
|
|
||||||
|
/// MySQL username
|
||||||
|
#[arg(short = 'u', long, default_value = "root")]
|
||||||
|
username: String,
|
||||||
|
|
||||||
|
/// MySQL password
|
||||||
|
#[arg(short = 'p', long, default_value = "")]
|
||||||
|
password: String,
|
||||||
|
|
||||||
|
/// Database name
|
||||||
|
#[arg(short = 'D', long, default_value = "")]
|
||||||
|
database: String,
|
||||||
|
|
||||||
|
/// Execute SQL statement
|
||||||
|
#[arg(short = 'e', long)]
|
||||||
|
execute: Option<String>,
|
||||||
|
|
||||||
|
/// Execute SQL from file
|
||||||
|
#[arg(short = 'f', long)]
|
||||||
|
file: Option<String>,
|
||||||
|
|
||||||
|
/// Output format: table, csv, json, vertical
|
||||||
|
#[arg(short = 'F', long, default_value = "table")]
|
||||||
|
format: String,
|
||||||
|
|
||||||
|
/// Output file (default: stdout)
|
||||||
|
#[arg(short = 'o', long)]
|
||||||
|
output: Option<String>,
|
||||||
|
|
||||||
|
/// Interactive mode (REPL)
|
||||||
|
#[arg(short = 'i', long)]
|
||||||
|
interactive: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let config = Config {
|
||||||
|
host: args.host,
|
||||||
|
port: args.port,
|
||||||
|
username: args.username,
|
||||||
|
password: args.password,
|
||||||
|
database: args.database,
|
||||||
|
};
|
||||||
|
|
||||||
|
// 连接数据库
|
||||||
|
let pool = match db::connect(&config) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Error connecting to database: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let format: OutputFormat = args.format.parse().unwrap_or_default();
|
||||||
|
let executor = Executor::new(pool.clone(), format, args.output.clone());
|
||||||
|
|
||||||
|
// 检测 stdin 是否被重定向
|
||||||
|
let is_piped = !atty::is(atty::Stream::Stdin);
|
||||||
|
|
||||||
|
// 确定运行模式
|
||||||
|
if let Some(file) = &args.file {
|
||||||
|
// 从文件执行
|
||||||
|
if let Err(e) = executor.execute_from_file(file) {
|
||||||
|
eprintln!("Error executing file: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
} else if let Some(query) = &args.execute {
|
||||||
|
// 执行 SQL
|
||||||
|
if let Err(e) = executor.execute(query) {
|
||||||
|
eprintln!("Error executing SQL: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
} else if is_piped {
|
||||||
|
// stdin 被重定向,读取所有输入并执行
|
||||||
|
let mut content = String::new();
|
||||||
|
if let Err(e) = io::stdin().read_to_string(&mut content) {
|
||||||
|
eprintln!("Error reading stdin: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
if !content.trim().is_empty() {
|
||||||
|
if let Err(e) = executor.execute(&content) {
|
||||||
|
eprintln!("Error executing SQL: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if args.interactive {
|
||||||
|
// 交互模式
|
||||||
|
if let Err(e) = repl::run_repl(pool, config) {
|
||||||
|
eprintln!("Error: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 没有指定任何操作,显示使用提示
|
||||||
|
eprintln!("ERROR: Missing required argument.");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("Usage:");
|
||||||
|
eprintln!(" mysql -h <host> -P <port> -u <user> -p <password> -D <database> -e \"<SQL>\"");
|
||||||
|
eprintln!(" mysql -i # Interactive mode");
|
||||||
|
eprintln!(" mysql -f <file> # Execute SQL file");
|
||||||
|
eprintln!(" echo \"SELECT 1\" | mysql # Execute from stdin");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("Run 'mysql --help' for detailed documentation.");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
297
mysql-cli/src/repl.rs
Normal file
297
mysql-cli/src/repl.rs
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use mysql::{Pool, prelude::*};
|
||||||
|
use rustyline::error::ReadlineError;
|
||||||
|
use rustyline::highlight::Highlighter;
|
||||||
|
use rustyline::hint::Hinter;
|
||||||
|
use rustyline::history::DefaultHistory;
|
||||||
|
use rustyline::{Completer, CompletionType, Config, Context, EditMode, Editor, Helper, Validator};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use crate::config::Config as DbConfig;
|
||||||
|
use crate::db;
|
||||||
|
use crate::executor::{Executor, OutputFormat};
|
||||||
|
|
||||||
|
/// REPL 状态
|
||||||
|
pub struct ReplState {
|
||||||
|
config: DbConfig,
|
||||||
|
format: OutputFormat,
|
||||||
|
output: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReplState {
|
||||||
|
pub fn new(config: DbConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
format: OutputFormat::Table,
|
||||||
|
output: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt(&self) -> String {
|
||||||
|
let db = if self.config.database.is_empty() {
|
||||||
|
"(none)"
|
||||||
|
} else {
|
||||||
|
&self.config.database
|
||||||
|
};
|
||||||
|
format!("mysql [{}]> ", db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 智能补全器
|
||||||
|
#[derive(Completer, Helper, Validator)]
|
||||||
|
pub struct SqlCompleter {
|
||||||
|
sql_keywords: Vec<&'static str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SqlCompleter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sql_keywords: vec![
|
||||||
|
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP",
|
||||||
|
"CREATE", "ALTER", "SHOW", "DESCRIBE", "DESC", "USE",
|
||||||
|
"FROM", "WHERE", "ORDER", "GROUP", "LIMIT", "OFFSET",
|
||||||
|
"JOIN", "LEFT", "RIGHT", "INNER", "ON", "AS",
|
||||||
|
"AND", "OR", "NOT", "IN", "LIKE", "BETWEEN",
|
||||||
|
"BY", "TABLE", "DATABASE", "INDEX", "VIEW",
|
||||||
|
"SET", "VALUES", "INTO", "DISTINCT", "COUNT",
|
||||||
|
"SUM", "AVG", "MAX", "MIN", "NULL", "IS",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Hinter for SqlCompleter {
|
||||||
|
type Hint = String;
|
||||||
|
|
||||||
|
fn hint(&self, line: &str, pos: usize, _ctx: &Context<'_>) -> Option<String> {
|
||||||
|
if line.is_empty() || pos < line.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let line_upper = line.to_uppercase();
|
||||||
|
for keyword in &self.sql_keywords {
|
||||||
|
if keyword.starts_with(&line_upper) && keyword.len() > line.len() {
|
||||||
|
return Some(keyword[line.len()..].to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Highlighter for SqlCompleter {
|
||||||
|
fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
|
||||||
|
Cow::Borrowed(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn highlight_char(&self, _line: &str, _pos: usize, _forced: bool) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 运行交互式 REPL
|
||||||
|
pub fn run_repl(pool: Pool, config: DbConfig) -> Result<()> {
|
||||||
|
let mut state = ReplState::new(config);
|
||||||
|
let executor = Executor::new(pool.clone(), state.format, state.output.clone());
|
||||||
|
|
||||||
|
// 打印欢迎信息
|
||||||
|
print_welcome(&pool, &state.config)?;
|
||||||
|
|
||||||
|
// 初始化 rustyline
|
||||||
|
let rl_config = Config::builder()
|
||||||
|
.history_ignore_space(true)
|
||||||
|
.completion_type(CompletionType::List)
|
||||||
|
.edit_mode(EditMode::Emacs)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let mut rl: Editor<SqlCompleter, DefaultHistory> = Editor::with_config(rl_config)?;
|
||||||
|
rl.set_helper(Some(SqlCompleter::new()));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let prompt = state.prompt();
|
||||||
|
let readline = rl.readline(&prompt);
|
||||||
|
|
||||||
|
match readline {
|
||||||
|
Ok(line) => {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加到历史记录
|
||||||
|
let _ = rl.add_history_entry(line);
|
||||||
|
|
||||||
|
// 处理命令
|
||||||
|
if let Err(e) = handle_command(&pool, &mut state, &executor, line) {
|
||||||
|
eprintln!("Error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(ReadlineError::Interrupted) => {
|
||||||
|
println!("^C");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(ReadlineError::Eof) => {
|
||||||
|
println!("Bye");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("Error: {}", err);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 打印欢迎信息
|
||||||
|
fn print_welcome(pool: &Pool, config: &DbConfig) -> Result<()> {
|
||||||
|
println!();
|
||||||
|
println!("Welcome to MySQL CLI (Rust). Commands end with ;");
|
||||||
|
|
||||||
|
let mut conn = pool.get_conn()?;
|
||||||
|
let version = db::get_server_version(&mut conn);
|
||||||
|
println!("Server version: {}", version);
|
||||||
|
println!("Connected to: {}", config);
|
||||||
|
println!();
|
||||||
|
println!("Commands:");
|
||||||
|
println!(" use <database> - Switch database");
|
||||||
|
println!(" status - Show connection status");
|
||||||
|
println!(" source <file> - Execute SQL from file");
|
||||||
|
println!(" format <type> - Set output format (table, csv, json, vertical)");
|
||||||
|
println!(" output <file> - Set output file");
|
||||||
|
println!(" exit/quit - Exit the program");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 处理命令
|
||||||
|
fn handle_command(pool: &Pool, state: &mut ReplState, executor: &Executor, line: &str) -> Result<()> {
|
||||||
|
let lower_line = line.to_lowercase();
|
||||||
|
|
||||||
|
// 处理特殊命令
|
||||||
|
match lower_line.as_str() {
|
||||||
|
"exit" | "quit" => {
|
||||||
|
println!("Bye");
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
"status" => {
|
||||||
|
return show_status(pool, state);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if lower_line.starts_with("use ") {
|
||||||
|
return use_database(pool, state, &line[4..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if lower_line.starts_with("source ") {
|
||||||
|
return source_file(executor, &line[7..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if lower_line.starts_with("format ") {
|
||||||
|
return set_format(state, &line[7..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if lower_line.starts_with("output ") {
|
||||||
|
return set_output(state, &line[7..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行 SQL
|
||||||
|
execute_sql(pool, state, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 显示状态
|
||||||
|
fn show_status(pool: &Pool, state: &ReplState) -> Result<()> {
|
||||||
|
println!();
|
||||||
|
println!("Connection Status:");
|
||||||
|
println!(" Host: {}:{}", state.config.host, state.config.port);
|
||||||
|
println!(" User: {}", state.config.username);
|
||||||
|
println!(" Database: {}", state.config.database);
|
||||||
|
|
||||||
|
let mut conn = pool.get_conn()?;
|
||||||
|
let version = db::get_server_version(&mut conn);
|
||||||
|
println!(" Server Version: {}", version);
|
||||||
|
|
||||||
|
if let Some(current_db) = db::get_current_database(&mut conn) {
|
||||||
|
println!(" Current Database: {}", current_db);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 切换数据库
|
||||||
|
fn use_database(pool: &Pool, state: &mut ReplState, db_name: &str) -> Result<()> {
|
||||||
|
let db_name = db_name.trim().trim_end_matches(';').trim();
|
||||||
|
|
||||||
|
if db_name.is_empty() {
|
||||||
|
anyhow::bail!("database name is required");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut conn = pool.get_conn()?;
|
||||||
|
conn.query_iter(format!("USE `{}`", db_name))?;
|
||||||
|
|
||||||
|
state.config.database = db_name.to_string();
|
||||||
|
println!("Database changed");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从文件执行 SQL
|
||||||
|
fn source_file(executor: &Executor, filename: &str) -> Result<()> {
|
||||||
|
let filename = filename.trim().trim_end_matches(';').trim();
|
||||||
|
|
||||||
|
if filename.is_empty() {
|
||||||
|
anyhow::bail!("filename is required");
|
||||||
|
}
|
||||||
|
|
||||||
|
executor.execute_from_file(filename)?;
|
||||||
|
println!();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置输出格式
|
||||||
|
fn set_format(state: &mut ReplState, format: &str) -> Result<()> {
|
||||||
|
let format = format.trim().trim_end_matches(';').trim();
|
||||||
|
state.format = format.parse()?;
|
||||||
|
println!("Output format set to {}", format);
|
||||||
|
println!();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置输出文件
|
||||||
|
fn set_output(state: &mut ReplState, output: &str) -> Result<()> {
|
||||||
|
let output = output.trim().trim_end_matches(';').trim();
|
||||||
|
|
||||||
|
if output.is_empty() {
|
||||||
|
state.output = None;
|
||||||
|
println!("Output reset to stdout");
|
||||||
|
} else {
|
||||||
|
state.output = Some(output.to_string());
|
||||||
|
println!("Output set to {}", output);
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行 SQL(支持多行)
|
||||||
|
fn execute_sql(pool: &Pool, state: &ReplState, line: &str) -> Result<()> {
|
||||||
|
let full_query = line.to_string();
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
let executor = Executor::new(pool.clone(), state.format, state.output.clone());
|
||||||
|
executor.execute(&full_query)?;
|
||||||
|
|
||||||
|
let elapsed = start.elapsed();
|
||||||
|
if elapsed.as_millis() > 100 {
|
||||||
|
println!();
|
||||||
|
println!("({:.2} sec)", elapsed.as_secs_f64());
|
||||||
|
} else {
|
||||||
|
println!();
|
||||||
|
println!("({:.3} sec)", elapsed.as_secs_f64());
|
||||||
|
}
|
||||||
|
|
||||||
|
println!();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
21
mysql-proxy/Cargo.toml
Normal file
21
mysql-proxy/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
[package]
|
||||||
|
name = "mysql-proxy"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
description = "MySQL HTTP proxy with connection pooling"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
mysql = "25"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
axum = "0.7"
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
toml = "0.8"
|
||||||
|
anyhow = "1"
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
ureq = "2" # CLI HTTP client (2.x has simpler API)
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = "z"
|
||||||
|
lto = true
|
||||||
|
strip = true
|
||||||
287
mysql-proxy/src/cli.rs
Normal file
287
mysql-proxy/src/cli.rs
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
const DEFAULT_SERVER: &str = "http://127.0.0.1:3307";
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct QueryRequest {
|
||||||
|
conn: String,
|
||||||
|
sql: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct QueryResponse {
|
||||||
|
columns: Vec<String>,
|
||||||
|
rows: Vec<Vec<Option<String>>>,
|
||||||
|
#[serde(rename = "rowCount")]
|
||||||
|
row_count: usize,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ErrorResponse {
|
||||||
|
error: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ConnectionsResponse {
|
||||||
|
connections: Vec<ConnectionInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ConnectionInfo {
|
||||||
|
name: String,
|
||||||
|
database: String,
|
||||||
|
host: String,
|
||||||
|
status: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CLI 客户端
|
||||||
|
pub struct Cli {
|
||||||
|
server: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Cli {
|
||||||
|
pub fn new(server: Option<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行查询并输出结果
|
||||||
|
pub fn query(&self, conn: &str, sql: &str, format: Option<&str>) -> Result<()> {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
let body = serde_json::to_string(&QueryRequest {
|
||||||
|
conn: conn.to_string(),
|
||||||
|
sql: sql.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let response = ureq::post(&format!("{}/query", self.server))
|
||||||
|
.set("Content-Type", "application/json")
|
||||||
|
.send_string(&body);
|
||||||
|
|
||||||
|
// ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中
|
||||||
|
let body = match response {
|
||||||
|
Ok(r) => r.into_string()?,
|
||||||
|
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
|
||||||
|
Err(e) => bail!("HTTP error: {}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
let total_ms = start.elapsed().as_millis();
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||||
|
eprintln!("Error: {}", err.error);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: QueryResponse = serde_json::from_str(&body)?;
|
||||||
|
|
||||||
|
// 输出结果
|
||||||
|
match format {
|
||||||
|
Some("json") => self.print_json(&result),
|
||||||
|
Some("csv") => self.print_csv(&result),
|
||||||
|
Some("vertical") | Some("vert") => self.print_vertical(&result),
|
||||||
|
_ => self.print_table(&result),
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("\n{} rows in set ({}ms db, {}ms total)",
|
||||||
|
result.row_count, result.duration_ms, total_ms);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行语句 (INSERT/UPDATE/DELETE)
|
||||||
|
pub fn execute(&self, conn: &str, sql: &str) -> Result<()> {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
let body = serde_json::to_string(&QueryRequest {
|
||||||
|
conn: conn.to_string(),
|
||||||
|
sql: sql.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let response = ureq::post(&format!("{}/execute", self.server))
|
||||||
|
.set("Content-Type", "application/json")
|
||||||
|
.send_string(&body);
|
||||||
|
|
||||||
|
// ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中
|
||||||
|
let body = match response {
|
||||||
|
Ok(r) => r.into_string()?,
|
||||||
|
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
|
||||||
|
Err(e) => bail!("HTTP error: {}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
let total_ms = start.elapsed().as_millis();
|
||||||
|
|
||||||
|
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||||
|
eprintln!("Error: {}", err.error);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ExecuteResponse {
|
||||||
|
#[serde(rename = "affectedRows")]
|
||||||
|
affected_rows: u64,
|
||||||
|
#[serde(rename = "lastInsertId")]
|
||||||
|
last_insert_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ExecuteResponse = serde_json::from_str(&body)?;
|
||||||
|
println!("Query OK, {} rows affected ({}ms total)", result.affected_rows, total_ms);
|
||||||
|
|
||||||
|
if result.last_insert_id > 0 {
|
||||||
|
println!("Last insert ID: {}", result.last_insert_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列出连接
|
||||||
|
pub fn list_connections(&self) -> Result<()> {
|
||||||
|
let response = ureq::get(&format!("{}/connections", self.server))
|
||||||
|
.call()?;
|
||||||
|
|
||||||
|
let body = response.into_string()?;
|
||||||
|
|
||||||
|
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||||
|
bail!("{}", err.error);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ConnectionsResponse = serde_json::from_str(&body)?;
|
||||||
|
|
||||||
|
println!("Connections:");
|
||||||
|
println!("{:<15} {:<20} {:<40} {:<10}", "Name", "Database", "Host", "Status");
|
||||||
|
println!("{}", "-".repeat(85));
|
||||||
|
|
||||||
|
for conn in result.connections {
|
||||||
|
println!("{:<15} {:<20} {:<40} {:<10}",
|
||||||
|
conn.name, conn.database, conn.host, conn.status);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查代理是否运行
|
||||||
|
pub fn check_server(&self) -> Result<bool> {
|
||||||
|
match ureq::get(&format!("{}/health", self.server)).call() {
|
||||||
|
Ok(_) => Ok(true),
|
||||||
|
Err(_) => Ok(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 输出格式化方法
|
||||||
|
fn print_table(&self, result: &QueryResponse) {
|
||||||
|
if result.columns.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算列宽
|
||||||
|
let mut widths: Vec<usize> = result.columns.iter().map(|c| c.len()).collect();
|
||||||
|
|
||||||
|
for row in &result.rows {
|
||||||
|
for (i, val) in row.iter().enumerate() {
|
||||||
|
let len = val.as_ref().map(|s| s.len()).unwrap_or(4); // NULL
|
||||||
|
widths[i] = widths[i].max(len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 打印表头
|
||||||
|
let header: String = result.columns.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, c)| format!(" {:width$} ", c, width = widths[i]))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("|");
|
||||||
|
|
||||||
|
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
|
||||||
|
println!("|{}|", header);
|
||||||
|
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
|
||||||
|
|
||||||
|
// 打印数据行
|
||||||
|
for row in &result.rows {
|
||||||
|
let line: String = row.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, val)| {
|
||||||
|
match val {
|
||||||
|
Some(s) => format!(" {:width$} ", s, width = widths[i]),
|
||||||
|
None => format!(" {:width$} ", "NULL", width = widths[i]),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("|");
|
||||||
|
println!("|{}|", line);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("+{}+", widths.iter().map(|w| "-".repeat(*w + 2)).collect::<Vec<_>>().join("+"));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_json(&self, result: &QueryResponse) {
|
||||||
|
// 极简格式:对象数组
|
||||||
|
let rows: Vec<serde_json::Value> = result.rows.iter().map(|row| {
|
||||||
|
let mut obj = serde_json::Map::new();
|
||||||
|
for (i, col) in result.columns.iter().enumerate() {
|
||||||
|
let val = row.get(i)
|
||||||
|
.and_then(|v| v.as_ref())
|
||||||
|
.map(|s| {
|
||||||
|
// 尝试解析为数字
|
||||||
|
if let Ok(n) = s.parse::<i64>() {
|
||||||
|
serde_json::Value::Number(n.into())
|
||||||
|
} else if let Ok(n) = s.parse::<f64>() {
|
||||||
|
serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap_or_else(|| serde_json::Number::from(0)))
|
||||||
|
} else {
|
||||||
|
serde_json::Value::String(s.clone())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or(serde_json::Value::Null);
|
||||||
|
obj.insert(col.clone(), val);
|
||||||
|
}
|
||||||
|
serde_json::Value::Object(obj)
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
println!("{}", serde_json::to_string(&rows).unwrap_or_default());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_csv(&self, result: &QueryResponse) {
|
||||||
|
// 打印表头
|
||||||
|
println!("{}", result.columns.join(","));
|
||||||
|
|
||||||
|
// 打印数据
|
||||||
|
for row in &result.rows {
|
||||||
|
let line: String = row.iter()
|
||||||
|
.map(|val| {
|
||||||
|
match val {
|
||||||
|
Some(s) => {
|
||||||
|
if s.contains(',') || s.contains('"') || s.contains('\n') {
|
||||||
|
format!("\"{}\"", s.replace('"', "\"\""))
|
||||||
|
} else {
|
||||||
|
s.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => "".to_string(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(",");
|
||||||
|
println!("{}", line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_vertical(&self, result: &QueryResponse) {
|
||||||
|
for (row_idx, row) in result.rows.iter().enumerate() {
|
||||||
|
if row_idx > 0 {
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
println!("*************************** {}. row ***************************", row_idx + 1);
|
||||||
|
|
||||||
|
for (col_idx, col) in result.columns.iter().enumerate() {
|
||||||
|
let val = row.get(col_idx)
|
||||||
|
.and_then(|v| v.as_ref())
|
||||||
|
.map(|s| s.as_str())
|
||||||
|
.unwrap_or("NULL");
|
||||||
|
println!("{:>20}: {}", col, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
156
mysql-proxy/src/config.rs
Normal file
156
mysql-proxy/src/config.rs
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
|
pub struct ServerConfig {
|
||||||
|
pub port: u16,
|
||||||
|
#[serde(default = "default_host")]
|
||||||
|
pub host: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_host() -> String {
|
||||||
|
"127.0.0.1".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 连接池配置
|
||||||
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
|
pub struct PoolConfig {
|
||||||
|
/// 默认最大连接数
|
||||||
|
#[serde(default = "default_max_connections")]
|
||||||
|
pub default_max_connections: usize,
|
||||||
|
/// 空闲超时 (秒)
|
||||||
|
#[serde(default = "default_idle_timeout")]
|
||||||
|
pub idle_timeout_secs: u64,
|
||||||
|
/// 检查间隔 (秒)
|
||||||
|
#[serde(default = "default_check_interval")]
|
||||||
|
pub check_interval_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_connections() -> usize { 5 }
|
||||||
|
fn default_idle_timeout() -> u64 { 300 }
|
||||||
|
fn default_check_interval() -> u64 { 60 }
|
||||||
|
|
||||||
|
impl Default for PoolConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
default_max_connections: default_max_connections(),
|
||||||
|
idle_timeout_secs: default_idle_timeout(),
|
||||||
|
check_interval_secs: default_check_interval(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
|
pub struct ConnectionConfig {
|
||||||
|
pub name: String,
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub user: String,
|
||||||
|
pub password: String,
|
||||||
|
pub database: String,
|
||||||
|
/// 最大连接数 (可选,覆盖默认值)
|
||||||
|
#[serde(default)]
|
||||||
|
pub max_connections: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
#[serde(default = "default_server")]
|
||||||
|
pub server: ServerConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub pool: PoolConfig,
|
||||||
|
pub connections: Vec<ConnectionConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_server() -> ServerConfig {
|
||||||
|
ServerConfig {
|
||||||
|
port: 3307,
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
/// 从文件加载配置
|
||||||
|
pub fn from_file(path: &str) -> Result<Self> {
|
||||||
|
let content = fs::read_to_string(path)?;
|
||||||
|
let config: Config = toml::from_str(&content)?;
|
||||||
|
config.validate()?;
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从默认路径加载配置
|
||||||
|
pub fn load() -> Result<Self> {
|
||||||
|
// 尝试多个配置文件路径
|
||||||
|
let paths = [
|
||||||
|
"mysql-proxy.toml",
|
||||||
|
"./config/mysql-proxy.toml",
|
||||||
|
&format!("{}/.mysql-proxy.toml", std::env::var("HOME").unwrap_or_default()),
|
||||||
|
];
|
||||||
|
|
||||||
|
for path in &paths {
|
||||||
|
if std::path::Path::new(path).exists() {
|
||||||
|
return Self::from_file(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Config file not found. Create mysql-proxy.toml in current directory or use --config flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 验证配置
|
||||||
|
fn validate(&self) -> Result<()> {
|
||||||
|
let mut names = std::collections::HashSet::new();
|
||||||
|
for conn in &self.connections {
|
||||||
|
if conn.name.is_empty() {
|
||||||
|
bail!("Connection name cannot be empty");
|
||||||
|
}
|
||||||
|
if names.contains(&conn.name) {
|
||||||
|
bail!("Duplicate connection name: {}", conn.name);
|
||||||
|
}
|
||||||
|
names.insert(conn.name.clone());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionConfig {
|
||||||
|
/// 构建 DSN
|
||||||
|
pub fn build_dsn(&self) -> String {
|
||||||
|
let password = urlencoding(&self.password);
|
||||||
|
format!(
|
||||||
|
"mysql://{}:{}@{}:{}/{}",
|
||||||
|
self.user, password, self.host, self.port, self.database
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// URL 编码密码中的特殊字符
|
||||||
|
fn urlencoding(s: &str) -> String {
|
||||||
|
let mut result = String::new();
|
||||||
|
for c in s.chars() {
|
||||||
|
match c {
|
||||||
|
'@' => result.push_str("%40"),
|
||||||
|
':' => result.push_str("%3A"),
|
||||||
|
'/' => result.push_str("%2F"),
|
||||||
|
'?' => result.push_str("%3F"),
|
||||||
|
'#' => result.push_str("%23"),
|
||||||
|
'[' => result.push_str("%5B"),
|
||||||
|
']' => result.push_str("%5D"),
|
||||||
|
'!' => result.push_str("%21"),
|
||||||
|
'$' => result.push_str("%24"),
|
||||||
|
'&' => result.push_str("%26"),
|
||||||
|
'\'' => result.push_str("%27"),
|
||||||
|
'(' => result.push_str("%28"),
|
||||||
|
')' => result.push_str("%29"),
|
||||||
|
'*' => result.push_str("%2A"),
|
||||||
|
'+' => result.push_str("%2B"),
|
||||||
|
',' => result.push_str("%2C"),
|
||||||
|
';' => result.push_str("%3B"),
|
||||||
|
'=' => result.push_str("%3D"),
|
||||||
|
'%' => result.push_str("%25"),
|
||||||
|
' ' => result.push_str("%20"),
|
||||||
|
_ => result.push(c),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
173
mysql-proxy/src/db.rs
Normal file
173
mysql-proxy/src/db.rs
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use mysql::{Pool, PooledConn, Opts,OptsBuilder, prelude::*};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use crate::config::{ConnectionConfig, PoolConfig};
|
||||||
|
|
||||||
|
/// 连接池状态
|
||||||
|
struct PoolState {
|
||||||
|
pool: Arc<Pool>,
|
||||||
|
last_used: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 连接池管理器 (延迟初始化)
|
||||||
|
pub struct ConnectionManager {
|
||||||
|
pools: Mutex<HashMap<String, PoolState>>,
|
||||||
|
configs: Mutex<HashMap<String, ConnectionConfig>>,
|
||||||
|
pool_config: PoolConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionManager {
|
||||||
|
/// 创建连接管理器 (不初始化连接)
|
||||||
|
pub fn new(configs: &[ConnectionConfig], pool_config: PoolConfig) -> Result<Self> {
|
||||||
|
let mut config_map = HashMap::new();
|
||||||
|
|
||||||
|
for cfg in configs {
|
||||||
|
config_map.insert(cfg.name.clone(), cfg.clone());
|
||||||
|
println!(" Registered: {} ({}/{})", cfg.name, cfg.host, cfg.database);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("\n{} connection(s) configured (lazy init)", config_map.len());
|
||||||
|
println!("Pool config: max_connections={}, idle_timeout={}s",
|
||||||
|
pool_config.default_max_connections, pool_config.idle_timeout_secs);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
pools: Mutex::new(HashMap::new()),
|
||||||
|
configs: Mutex::new(config_map),
|
||||||
|
pool_config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取或创建连接
|
||||||
|
pub fn get_conn(&self, name: &str) -> Result<PooledConn> {
|
||||||
|
// 1. 先尝试从已有池中获取
|
||||||
|
{
|
||||||
|
let mut pools = self.pools.lock().unwrap();
|
||||||
|
if let Some(state) = pools.get_mut(name) {
|
||||||
|
state.last_used = Instant::now();
|
||||||
|
return Ok(state.pool.get_conn()?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 没有则创建新池
|
||||||
|
let cfg = self.configs.lock().unwrap().get(name)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Connection '{}' not found", name))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
println!("[LazyInit] Creating connection pool for: {}", name);
|
||||||
|
let pool = self.create_pool(&cfg)?;
|
||||||
|
let arc_pool = Arc::new(pool);
|
||||||
|
|
||||||
|
// 3. 保存并返回连接
|
||||||
|
{
|
||||||
|
let mut pools = self.pools.lock().unwrap();
|
||||||
|
pools.insert(name.to_string(), PoolState {
|
||||||
|
pool: arc_pool.clone(),
|
||||||
|
last_used: Instant::now(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(arc_pool.get_conn()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建连接池
|
||||||
|
fn create_pool(&self, cfg: &ConnectionConfig) -> Result<Pool> {
|
||||||
|
let dsn = cfg.build_dsn();
|
||||||
|
let opts: Opts = Opts::from_url(&dsn)?;
|
||||||
|
|
||||||
|
// 获取最大连接数 (单个配置 > 全局默认)
|
||||||
|
let max_conn = cfg.max_connections.unwrap_or(self.pool_config.default_max_connections);
|
||||||
|
|
||||||
|
// 构建带连接池参数的选项
|
||||||
|
let pool_opts = OptsBuilder::from_opts(opts)
|
||||||
|
.pool_opts(mysql::PoolOpts::new()
|
||||||
|
.with_constraints(mysql::PoolConstraints::new(1, max_conn).unwrap()));
|
||||||
|
|
||||||
|
let pool = Pool::new(pool_opts)?;
|
||||||
|
|
||||||
|
// 测试连接
|
||||||
|
let mut conn = pool.get_conn()?;
|
||||||
|
conn.query::<String, _>("SELECT 1")?;
|
||||||
|
|
||||||
|
println!("[LazyInit] ✓ Connected: {} (max_conn={})", cfg.name, max_conn);
|
||||||
|
Ok(pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有连接信息
|
||||||
|
pub fn list_connections(&self) -> Vec<ConnectionInfo> {
|
||||||
|
let pools = self.pools.lock().unwrap();
|
||||||
|
let configs = self.configs.lock().unwrap();
|
||||||
|
|
||||||
|
configs.iter().map(|(name, cfg)| {
|
||||||
|
let status = if pools.contains_key(name) { "connected" } else { "pending" };
|
||||||
|
ConnectionInfo {
|
||||||
|
name: name.clone(),
|
||||||
|
database: cfg.database.clone(),
|
||||||
|
host: cfg.host.clone(),
|
||||||
|
status: status.to_string(),
|
||||||
|
}
|
||||||
|
}).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 清理空闲连接池
|
||||||
|
pub fn cleanup_idle(&self) {
|
||||||
|
let mut pools = self.pools.lock().unwrap();
|
||||||
|
let now = Instant::now();
|
||||||
|
let idle_timeout = Duration::from_secs(self.pool_config.idle_timeout_secs);
|
||||||
|
|
||||||
|
pools.retain(|name, state| {
|
||||||
|
let elapsed = now.duration_since(state.last_used);
|
||||||
|
if elapsed > idle_timeout {
|
||||||
|
println!("[Cleanup] Removing idle pool: {} (idle {}s)", name, elapsed.as_secs());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
true
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查连接是否健康
|
||||||
|
pub fn health_check(&self, name: &str) -> bool {
|
||||||
|
let pools = self.pools.lock().unwrap();
|
||||||
|
if let Some(state) = pools.get(name) {
|
||||||
|
return state.pool.get_conn().ok()
|
||||||
|
.and_then(|mut c| c.query::<String, _>("SELECT 1").ok())
|
||||||
|
.is_some();
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 动态添加连接配置 (临时,重启后消失)
|
||||||
|
pub fn add_connection(&self, cfg: ConnectionConfig) -> Result<()> {
|
||||||
|
let name = cfg.name.clone();
|
||||||
|
let mut configs = self.configs.lock().unwrap();
|
||||||
|
if configs.contains_key(&name) {
|
||||||
|
bail!("Connection '{}' already exists", name);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("[Dynamic] Adding: {} ({}/{})", name, cfg.host, cfg.database);
|
||||||
|
|
||||||
|
// 测试连接
|
||||||
|
let pool = self.create_pool(&cfg)?;
|
||||||
|
let arc_pool = Arc::new(pool);
|
||||||
|
|
||||||
|
// 保存
|
||||||
|
self.pools.lock().unwrap().insert(name.clone(), PoolState {
|
||||||
|
pool: arc_pool,
|
||||||
|
last_used: Instant::now(),
|
||||||
|
});
|
||||||
|
configs.insert(name.clone(), cfg);
|
||||||
|
|
||||||
|
println!("[Dynamic] ✓ Added: {}", name);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct ConnectionInfo {
|
||||||
|
pub name: String,
|
||||||
|
pub database: String,
|
||||||
|
pub host: String,
|
||||||
|
pub status: String,
|
||||||
|
}
|
||||||
320
mysql-proxy/src/handler.rs
Normal file
320
mysql-proxy/src/handler.rs
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::StatusCode,
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use mysql::{prelude::*, Value, Row as MysqlRow};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use crate::db::ConnectionManager;
|
||||||
|
use crate::logger::{LogEntry, RequestLogger};
|
||||||
|
|
||||||
|
/// 应用状态
|
||||||
|
pub type AppState = Arc<(Arc<ConnectionManager>, Arc<RequestLogger>)>;
|
||||||
|
|
||||||
|
// ============== 请求结构 ==============
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct QueryRequest {
|
||||||
|
/// 连接名称
|
||||||
|
pub conn: String,
|
||||||
|
/// SQL 语句
|
||||||
|
pub sql: String,
|
||||||
|
/// 输出格式: json, table, csv, vertical
|
||||||
|
#[serde(default = "default_format")]
|
||||||
|
pub format: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_format() -> String {
|
||||||
|
"json".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ExecuteRequest {
|
||||||
|
pub conn: String,
|
||||||
|
pub sql: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct AddConnectionRequest {
|
||||||
|
pub name: String,
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub user: String,
|
||||||
|
pub password: String,
|
||||||
|
pub database: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== 响应结构 ==============
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct QueryResponse {
|
||||||
|
pub columns: Vec<String>,
|
||||||
|
pub rows: Vec<Vec<Option<String>>>,
|
||||||
|
#[serde(rename = "rowCount")]
|
||||||
|
pub row_count: usize,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ExecuteResponse {
|
||||||
|
#[serde(rename = "affectedRows")]
|
||||||
|
pub affected_rows: u64,
|
||||||
|
#[serde(rename = "lastInsertId")]
|
||||||
|
pub last_insert_id: u64,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ConnectionsResponse {
|
||||||
|
pub connections: Vec<crate::db::ConnectionInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct HealthResponse {
|
||||||
|
pub status: String,
|
||||||
|
pub connections: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
pub error: String,
|
||||||
|
#[serde(rename = "usage")]
|
||||||
|
pub usage: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErrorResponse {
|
||||||
|
pub fn new(msg: &str) -> Self {
|
||||||
|
Self { error: msg.to_string(), usage: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_usage(mut self, usage: &str) -> Self {
|
||||||
|
self.usage = Some(usage.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== 处理器 ==============
|
||||||
|
|
||||||
|
/// 查询处理器
|
||||||
|
pub async fn query(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(req): Json<QueryRequest>,
|
||||||
|
) -> Result<Json<QueryResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let (manager, logger) = &*state;
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
// 获取连接
|
||||||
|
let mut conn = manager.get_conn(&req.conn)
|
||||||
|
.map_err(|e| {
|
||||||
|
let usage = "Usage: POST /query {\"conn\": \"connection_name\", \"sql\": \"SELECT ...\"}\n"
|
||||||
|
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
|
||||||
|
error_response_with_usage(&e.to_string(), &usage)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// 判断是否是查询语句
|
||||||
|
let sql_upper = req.sql.trim().to_uppercase();
|
||||||
|
let is_query = sql_upper.starts_with("SELECT")
|
||||||
|
|| sql_upper.starts_with("SHOW")
|
||||||
|
|| sql_upper.starts_with("DESCRIBE")
|
||||||
|
|| sql_upper.starts_with("DESC ")
|
||||||
|
|| sql_upper.starts_with("EXPLAIN")
|
||||||
|
|| sql_upper.starts_with("WITH");
|
||||||
|
|
||||||
|
if !is_query {
|
||||||
|
let err = error_response_with_usage(
|
||||||
|
"Not a SELECT query. Use /execute for INSERT/UPDATE/DELETE",
|
||||||
|
"Usage:\n POST /query {\"conn\": \"name\", \"sql\": \"SELECT ...\"}\n POST /execute {\"conn\": \"name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}"
|
||||||
|
);
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
let result = conn.query_iter(&req.sql)
|
||||||
|
.map_err(|e| {
|
||||||
|
let usage = format!("SQL Error: {}\n\nUsage: POST /query {{\"conn\": \"name\", \"sql\": \"SELECT ...\"}}", e);
|
||||||
|
error_response_with_usage(&e.to_string(), &usage)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut columns: Vec<String> = Vec::new();
|
||||||
|
let mut data: Vec<Vec<Option<String>>> = Vec::new();
|
||||||
|
|
||||||
|
// 获取数据并从第一行提取列名
|
||||||
|
for row_result in result {
|
||||||
|
let row = row_result.map_err(|e| error_response(&e.to_string()))?;
|
||||||
|
|
||||||
|
// 从第一行获取列名
|
||||||
|
if columns.is_empty() {
|
||||||
|
columns = row.columns()
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.name_str().to_string())
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
|
||||||
|
let values = row_to_strings(&row, columns.len());
|
||||||
|
data.push(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
let duration = start.elapsed();
|
||||||
|
let row_count = data.len();
|
||||||
|
|
||||||
|
// 记录日志
|
||||||
|
logger.log(&LogEntry::new("/query", "http")
|
||||||
|
.with_conn(&req.conn)
|
||||||
|
.with_sql(&req.sql)
|
||||||
|
.with_duration(duration.as_millis() as u64)
|
||||||
|
.with_rows(row_count));
|
||||||
|
|
||||||
|
Ok(Json(QueryResponse {
|
||||||
|
columns,
|
||||||
|
rows: data,
|
||||||
|
row_count,
|
||||||
|
duration_ms: duration.as_millis() as u64,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行处理器 (INSERT/UPDATE/DELETE)
|
||||||
|
pub async fn execute(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(req): Json<ExecuteRequest>,
|
||||||
|
) -> Result<Json<ExecuteResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let (manager, logger) = &*state;
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
// 获取连接
|
||||||
|
let mut conn = manager.get_conn(&req.conn)
|
||||||
|
.map_err(|e| {
|
||||||
|
let usage = "Usage: POST /execute {\"conn\": \"connection_name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}\n"
|
||||||
|
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
|
||||||
|
error_response_with_usage(&e.to_string(), &usage)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// 执行
|
||||||
|
let result = conn.query_iter(&req.sql)
|
||||||
|
.map_err(|e| error_response(&e.to_string()))?;
|
||||||
|
|
||||||
|
let duration = start.elapsed();
|
||||||
|
let affected = result.affected_rows();
|
||||||
|
let last_id = result.last_insert_id().unwrap_or(0);
|
||||||
|
|
||||||
|
// 记录日志
|
||||||
|
logger.log(&LogEntry::new("/execute", "http")
|
||||||
|
.with_conn(&req.conn)
|
||||||
|
.with_sql(&req.sql)
|
||||||
|
.with_duration(duration.as_millis() as u64));
|
||||||
|
|
||||||
|
Ok(Json(ExecuteResponse {
|
||||||
|
affected_rows: affected,
|
||||||
|
last_insert_id: last_id,
|
||||||
|
duration_ms: duration.as_millis() as u64,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取连接列表
|
||||||
|
pub async fn connections(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> Json<ConnectionsResponse> {
|
||||||
|
let (manager, _) = &*state;
|
||||||
|
Json(ConnectionsResponse {
|
||||||
|
connections: manager.list_connections(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 健康检查
|
||||||
|
pub async fn health(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> Json<HealthResponse> {
|
||||||
|
let (manager, _) = &*state;
|
||||||
|
Json(HealthResponse {
|
||||||
|
status: "ok".to_string(),
|
||||||
|
connections: manager.list_connections().len(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 动态添加连接
|
||||||
|
pub async fn add_connection(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(req): Json<AddConnectionRequest>,
|
||||||
|
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let (manager, logger) = &*state;
|
||||||
|
use crate::config::ConnectionConfig;
|
||||||
|
|
||||||
|
let cfg = ConnectionConfig {
|
||||||
|
name: req.name.clone(),
|
||||||
|
host: req.host,
|
||||||
|
port: req.port,
|
||||||
|
user: req.user,
|
||||||
|
password: req.password,
|
||||||
|
database: req.database,
|
||||||
|
max_connections: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.add_connection(cfg)
|
||||||
|
.map_err(|e| error_response(&e.to_string()))?;
|
||||||
|
|
||||||
|
// 记录日志
|
||||||
|
logger.log(&LogEntry::new("/connections/add", "http")
|
||||||
|
.with_conn(&req.name)
|
||||||
|
.with_duration(0));
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"message": format!("Connection '{}' added (temporary, will be lost on restart)", req.name)
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== 辅助函数 ==============
|
||||||
|
|
||||||
|
fn error_response(msg: &str) -> (StatusCode, Json<ErrorResponse>) {
|
||||||
|
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error_response_with_usage(msg: &str, usage: &str) -> (StatusCode, Json<ErrorResponse>) {
|
||||||
|
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg).with_usage(usage)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_conn_names(manager: &ConnectionManager) -> String {
|
||||||
|
manager.list_connections()
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.name.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将 MySQL Row 转换为字符串向量
|
||||||
|
fn row_to_strings(row: &MysqlRow, col_count: usize) -> Vec<Option<String>> {
|
||||||
|
(0..col_count)
|
||||||
|
.map(|i| {
|
||||||
|
match row.get::<Value, usize>(i) {
|
||||||
|
Some(value) => value_to_string(&value),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将 Value 转换为字符串
|
||||||
|
fn value_to_string(value: &Value) -> Option<String> {
|
||||||
|
match value {
|
||||||
|
Value::NULL => None,
|
||||||
|
Value::Bytes(bytes) => String::from_utf8(bytes.clone()).ok(),
|
||||||
|
Value::Int(i) => Some(i.to_string()),
|
||||||
|
Value::UInt(u) => Some(u.to_string()),
|
||||||
|
Value::Float(f) => Some(f.to_string()),
|
||||||
|
Value::Double(d) => Some(d.to_string()),
|
||||||
|
Value::Date(year, month, day, hour, min, sec, micro) => {
|
||||||
|
match (hour, min, sec, micro) {
|
||||||
|
(0, 0, 0, 0) => Some(format!("{:04}-{:02}-{:02}", year, month, day)),
|
||||||
|
_ => Some(format!("{:04}-{:02}-{:02} {:02}:{:02}:{:02}", year, month, day, hour, min, sec)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value::Time(neg, days, hours, minutes, seconds, microseconds) => {
|
||||||
|
let sign = if *neg { "-" } else { "" };
|
||||||
|
Some(format!("{}{} {}:{:02}:{:02}.{:06}", sign, days, hours, minutes, seconds, microseconds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
228
mysql-proxy/src/logger.rs
Normal file
228
mysql-proxy/src/logger.rs
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs::{File, OpenOptions};
|
||||||
|
use std::io::Write;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
/// 日志记录器
|
||||||
|
pub struct RequestLogger {
|
||||||
|
log_file: Mutex<Option<File>>,
|
||||||
|
enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestLogger {
|
||||||
|
pub fn new(log_path: Option<&str>) -> Self {
|
||||||
|
let (log_file, enabled) = if let Some(path) = log_path {
|
||||||
|
let path = expand_path(path);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
let _ = std::fs::create_dir_all(parent);
|
||||||
|
}
|
||||||
|
match OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(&path)
|
||||||
|
{
|
||||||
|
Ok(file) => (Some(file), true),
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[Logger] Failed to open log file {}: {}", path.display(), e);
|
||||||
|
(None, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(None, false)
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
log_file: Mutex::new(log_file),
|
||||||
|
enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_enabled(&self) -> bool {
|
||||||
|
self.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记录请求日志
|
||||||
|
pub fn log(&self, entry: &LogEntry) {
|
||||||
|
if !self.enabled {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let json = match serde_json::to_string(entry) {
|
||||||
|
Ok(j) => j,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[Logger] Failed to serialize log: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Ok(mut file) = self.log_file.lock() {
|
||||||
|
if let Some(ref mut f) = *file {
|
||||||
|
let _ = writeln!(f, "{}", json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 日志条目
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct LogEntry {
|
||||||
|
pub timestamp: String,
|
||||||
|
pub level: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub log_type: String,
|
||||||
|
pub client: String,
|
||||||
|
pub endpoint: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub conn: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub server: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sql: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub command: Option<String>,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
pub duration_ms: u64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub rows: Option<usize>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[serde(rename = "exitCode")]
|
||||||
|
pub exit_code: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LogEntry {
|
||||||
|
pub fn new(endpoint: &str, client: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
timestamp: current_timestamp(),
|
||||||
|
level: "INFO".to_string(),
|
||||||
|
log_type: "request".to_string(),
|
||||||
|
client: client.to_string(),
|
||||||
|
endpoint: endpoint.to_string(),
|
||||||
|
conn: None,
|
||||||
|
server: None,
|
||||||
|
sql: None,
|
||||||
|
command: None,
|
||||||
|
duration_ms: 0,
|
||||||
|
rows: None,
|
||||||
|
exit_code: None,
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_conn(mut self, conn: &str) -> Self {
|
||||||
|
self.conn = Some(conn.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_server(mut self, server: &str) -> Self {
|
||||||
|
self.server = Some(server.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_sql(mut self, sql: &str) -> Self {
|
||||||
|
// 截断过长的 SQL
|
||||||
|
self.sql = Some(truncate_string(sql, 1000));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_command(mut self, command: &str) -> Self {
|
||||||
|
// 截断过长的命令
|
||||||
|
self.command = Some(truncate_string(command, 500));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_duration(mut self, ms: u64) -> Self {
|
||||||
|
self.duration_ms = ms;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_rows(mut self, rows: usize) -> Self {
|
||||||
|
self.rows = Some(rows);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_exit_code(mut self, code: i32) -> Self {
|
||||||
|
self.exit_code = Some(code);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_error(mut self, error: &str) -> Self {
|
||||||
|
self.level = "ERROR".to_string();
|
||||||
|
self.error = Some(error.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_timestamp() -> String {
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default();
|
||||||
|
let secs = now.as_secs();
|
||||||
|
let datetime = chrono_timestamp(secs);
|
||||||
|
format!("{}Z", datetime)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chrono_timestamp(secs: u64) -> String {
|
||||||
|
let days = secs / 86400;
|
||||||
|
let remaining = secs % 86400;
|
||||||
|
let hours = remaining / 3600;
|
||||||
|
let minutes = (remaining % 3600) / 60;
|
||||||
|
let seconds = remaining % 60;
|
||||||
|
|
||||||
|
// 从 1970-01-01 开始计算日期
|
||||||
|
let mut year = 1970;
|
||||||
|
let mut days_left = days;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let days_in_year = if is_leap_year(year) { 366 } else { 365 };
|
||||||
|
if days_left < days_in_year {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
days_left -= days_in_year;
|
||||||
|
year += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let month_days = if is_leap_year(year) {
|
||||||
|
[31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
|
||||||
|
} else {
|
||||||
|
[31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut month = 1;
|
||||||
|
for &days_in_month in &month_days {
|
||||||
|
if days_left < days_in_month {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
days_left -= days_in_month;
|
||||||
|
month += 1;
|
||||||
|
}
|
||||||
|
let day = days_left + 1;
|
||||||
|
|
||||||
|
format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}", year, month, day, hours, minutes, seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_leap_year(year: u64) -> bool {
|
||||||
|
(year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_string(s: &str, max_len: usize) -> String {
|
||||||
|
if s.len() <= max_len {
|
||||||
|
s.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{}... (truncated)", &s[..max_len])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expand_path(path: &str) -> PathBuf {
|
||||||
|
if path.starts_with('~') {
|
||||||
|
let home = std::env::var("HOME")
|
||||||
|
.or_else(|_| std::env::var("USERPROFILE"))
|
||||||
|
.unwrap_or_default();
|
||||||
|
PathBuf::from(path.replacen('~', &home, 1))
|
||||||
|
} else {
|
||||||
|
PathBuf::from(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
212
mysql-proxy/src/main.rs
Normal file
212
mysql-proxy/src/main.rs
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
mod cli;
|
||||||
|
mod config;
|
||||||
|
mod db;
|
||||||
|
mod handler;
|
||||||
|
mod logger;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
routing::{get, post},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(name = "mysql-proxy")]
|
||||||
|
#[command(about = "MySQL HTTP proxy with connection pooling")]
|
||||||
|
struct Args {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Option<Commands>,
|
||||||
|
|
||||||
|
/// Config file path
|
||||||
|
#[arg(short, long, default_value = "mysql-proxy.toml", global = true)]
|
||||||
|
config: String,
|
||||||
|
|
||||||
|
/// Server URL (for CLI mode)
|
||||||
|
#[arg(short = 'S', long, default_value = "http://127.0.0.1:3307", global = true)]
|
||||||
|
server: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
/// Start HTTP server (default)
|
||||||
|
Server {
|
||||||
|
/// Port to listen on
|
||||||
|
#[arg(short = 'P', long)]
|
||||||
|
port: Option<u16>,
|
||||||
|
|
||||||
|
/// Host to bind
|
||||||
|
#[arg(short = 'H', long)]
|
||||||
|
host: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Execute SQL via proxy (for AI clients)
|
||||||
|
Cli {
|
||||||
|
/// Connection name
|
||||||
|
#[arg(short, long, default_value = "flux_dev")]
|
||||||
|
conn: String,
|
||||||
|
|
||||||
|
/// SQL to execute (与 mysql 官方一致)
|
||||||
|
#[arg(short = 'e', long)]
|
||||||
|
sql: Option<String>,
|
||||||
|
|
||||||
|
/// Output format: table, json, csv, vertical
|
||||||
|
#[arg(short = 'F', long, default_value = "table")]
|
||||||
|
format: String,
|
||||||
|
|
||||||
|
/// Execute INSERT/UPDATE/DELETE
|
||||||
|
#[arg(short = 'x', long)]
|
||||||
|
execute: bool,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// List available connections
|
||||||
|
Connections,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
match args.command {
|
||||||
|
Some(Commands::Server { port, host }) => {
|
||||||
|
run_server(&args.config, port, host).await
|
||||||
|
}
|
||||||
|
Some(Commands::Cli { conn, sql, format, execute }) => {
|
||||||
|
run_cli(&args.server, &conn, sql.as_deref(), &format, execute)
|
||||||
|
}
|
||||||
|
Some(Commands::Connections) => {
|
||||||
|
list_connections(&args.server)
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
run_server(&args.config, None, None).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 启动 HTTP 服务器
|
||||||
|
async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>) -> anyhow::Result<()> {
|
||||||
|
println!("MySQL HTTP Proxy v0.1.0\n");
|
||||||
|
|
||||||
|
// 加载配置
|
||||||
|
let mut config = config::Config::from_file(config_path)?;
|
||||||
|
|
||||||
|
// 命令行参数覆盖配置文件
|
||||||
|
if let Some(port) = port {
|
||||||
|
config.server.port = port;
|
||||||
|
}
|
||||||
|
if let Some(host) = host {
|
||||||
|
config.server.host = host;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化日志
|
||||||
|
let log_path = std::env::var("MYSQL_PROXY_LOG").ok();
|
||||||
|
let logger = Arc::new(logger::RequestLogger::new(log_path.as_deref()));
|
||||||
|
if logger.is_enabled() {
|
||||||
|
println!("Request logging: enabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化连接池
|
||||||
|
println!("Initializing connection pools...\n");
|
||||||
|
let manager = Arc::new(db::ConnectionManager::new(&config.connections, config.pool.clone())?);
|
||||||
|
|
||||||
|
// 启动后台清理任务
|
||||||
|
let manager_clone = manager.clone();
|
||||||
|
let check_interval = config.pool.check_interval_secs;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(
|
||||||
|
std::time::Duration::from_secs(check_interval)
|
||||||
|
);
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
manager_clone.cleanup_idle();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 构建路由
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/query", post(handler::query))
|
||||||
|
.route("/execute", post(handler::execute))
|
||||||
|
.route("/connections", get(handler::connections))
|
||||||
|
.route("/connections/add", post(handler::add_connection))
|
||||||
|
.route("/health", get(handler::health))
|
||||||
|
.with_state(Arc::new((manager, logger)));
|
||||||
|
|
||||||
|
// 启动服务器
|
||||||
|
let addr = format!("{}:{}", config.server.host, config.server.port);
|
||||||
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||||
|
|
||||||
|
println!("\nServer started at http://{}", addr);
|
||||||
|
println!("\nAPI Endpoints:");
|
||||||
|
println!(" POST /query - Execute SELECT queries");
|
||||||
|
println!(" POST /execute - Execute INSERT/UPDATE/DELETE");
|
||||||
|
println!(" GET /connections - List all connections");
|
||||||
|
println!(" POST /connections/add - Add connection (temporary)");
|
||||||
|
println!(" GET /health - Health check");
|
||||||
|
println!("\nCLI Usage:");
|
||||||
|
println!(" mysql-proxy cli -c flux_dev -e \"SELECT 1\"");
|
||||||
|
println!(" mysql-proxy connections");
|
||||||
|
|
||||||
|
axum::serve(listener, app).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CLI 模式
|
||||||
|
fn run_cli(server: &str, conn: &str, sql: Option<&str>, format: &str, execute: bool) -> anyhow::Result<()> {
|
||||||
|
let cli = cli::Cli::new(Some(server.to_string()));
|
||||||
|
|
||||||
|
// 检查服务器是否运行
|
||||||
|
if !cli.check_server()? {
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("=== mysql-proxy 未运行,请使用降级方案 ===");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("启动代理: mysql-proxy");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("或直接使用 mysql 命令:");
|
||||||
|
eprintln!(" mysql -h<host> -u<user> -p<password> -D<database> -e \"<SQL>\"");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("配置文件中的连接信息见: mysql-proxy.toml");
|
||||||
|
anyhow::bail!("Proxy server not running at {}", server);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有提供 SQL,从 stdin 读取
|
||||||
|
let sql = match sql {
|
||||||
|
Some(s) => s.to_string(),
|
||||||
|
None => {
|
||||||
|
use std::io::{self, BufRead};
|
||||||
|
let stdin = io::stdin();
|
||||||
|
let mut lines = Vec::new();
|
||||||
|
for line in stdin.lock().lines() {
|
||||||
|
lines.push(line?);
|
||||||
|
}
|
||||||
|
lines.join(" ")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if sql.trim().is_empty() {
|
||||||
|
anyhow::bail!("No SQL provided. Use -e \"SQL\" or pipe SQL via stdin");
|
||||||
|
}
|
||||||
|
|
||||||
|
if execute {
|
||||||
|
cli.execute(conn, &sql)?;
|
||||||
|
} else {
|
||||||
|
cli.query(conn, &sql, Some(format))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列出连接
|
||||||
|
fn list_connections(server: &str) -> anyhow::Result<()> {
|
||||||
|
let cli = cli::Cli::new(Some(server.to_string()));
|
||||||
|
|
||||||
|
if !cli.check_server()? {
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("=== mysql-proxy 未运行 ===");
|
||||||
|
eprintln!("启动代理: mysql-proxy");
|
||||||
|
eprintln!("配置文件: mysql-proxy.toml");
|
||||||
|
anyhow::bail!("Proxy server not running at {}", server);
|
||||||
|
}
|
||||||
|
|
||||||
|
cli.list_connections()
|
||||||
|
}
|
||||||
21
ssh-proxy/Cargo.toml
Normal file
21
ssh-proxy/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
[package]
|
||||||
|
name = "ssh-proxy"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
description = "SSH HTTP proxy with session pooling"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
ssh2 = "0.9"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
axum = "0.7"
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
toml = "0.8"
|
||||||
|
anyhow = "1"
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
ureq = "2" # CLI HTTP client (2.x has simpler API)
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = "z"
|
||||||
|
lto = true
|
||||||
|
strip = true
|
||||||
218
ssh-proxy/src/cli.rs
Normal file
218
ssh-proxy/src/cli.rs
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
const DEFAULT_SERVER: &str = "http://127.0.0.1:3308";
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct ExecRequest {
|
||||||
|
server: String,
|
||||||
|
command: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ExecResponse {
|
||||||
|
stdout: String,
|
||||||
|
stderr: String,
|
||||||
|
#[serde(rename = "exitCode")]
|
||||||
|
exit_code: i32,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ErrorResponse {
|
||||||
|
error: String,
|
||||||
|
#[serde(default)]
|
||||||
|
usage: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ServersResponse {
|
||||||
|
servers: Vec<ServerInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ServerInfo {
|
||||||
|
name: String,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
user: String,
|
||||||
|
status: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// CLI 客户端
|
||||||
|
pub struct Cli {
|
||||||
|
server: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Cli {
|
||||||
|
pub fn new(server: Option<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
server: server.unwrap_or_else(|| DEFAULT_SERVER.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行远程命令
|
||||||
|
pub fn exec(&self, server: &str, command: &str, format: &str) -> Result<()> {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
let body = serde_json::to_string(&ExecRequest {
|
||||||
|
server: server.to_string(),
|
||||||
|
command: command.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let response = ureq::post(&format!("{}/exec", self.server))
|
||||||
|
.set("Content-Type", "application/json")
|
||||||
|
.send_string(&body);
|
||||||
|
|
||||||
|
// ureq 2.x: 非 2xx 会返回 Err,但响应体在 error 中
|
||||||
|
let body = match response {
|
||||||
|
Ok(r) => r.into_string()?,
|
||||||
|
Err(ureq::Error::Status(_, resp)) => resp.into_string()?,
|
||||||
|
Err(e) => bail!("HTTP error: {}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||||
|
eprintln!("Error: {}", err.error);
|
||||||
|
if let Some(usage) = err.usage {
|
||||||
|
eprintln!("\n{}", usage);
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ExecResponse = serde_json::from_str(&body)?;
|
||||||
|
|
||||||
|
// 根据格式输出
|
||||||
|
if format == "json" {
|
||||||
|
self.print_json(&result, start.elapsed().as_millis() as u64);
|
||||||
|
} else {
|
||||||
|
self.print_text(&result, start.elapsed().as_millis() as u64);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_text(&self, result: &ExecResponse, total_ms: u64) {
|
||||||
|
if !result.stdout.is_empty() {
|
||||||
|
print!("{}", result.stdout);
|
||||||
|
}
|
||||||
|
if !result.stderr.is_empty() {
|
||||||
|
eprint!("{}", result.stderr);
|
||||||
|
}
|
||||||
|
|
||||||
|
eprintln!("\n--- exit: {}, {}ms db, {}ms total ---",
|
||||||
|
result.exit_code, result.duration_ms, total_ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_json(&self, result: &ExecResponse, total_ms: u64) {
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Output {
|
||||||
|
#[serde(rename = "exitCode")]
|
||||||
|
exit_code: i32,
|
||||||
|
stdout: String,
|
||||||
|
stderr: String,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
duration_ms: u64,
|
||||||
|
#[serde(rename = "totalMs")]
|
||||||
|
total_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = Output {
|
||||||
|
exit_code: result.exit_code,
|
||||||
|
stdout: result.stdout.clone(),
|
||||||
|
stderr: result.stderr.clone(),
|
||||||
|
duration_ms: result.duration_ms,
|
||||||
|
total_ms,
|
||||||
|
};
|
||||||
|
|
||||||
|
match serde_json::to_string(&output) {
|
||||||
|
Ok(json) => println!("{}", json),
|
||||||
|
Err(e) => eprintln!("JSON error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列出服务器
|
||||||
|
pub fn list_servers(&self) -> Result<()> {
|
||||||
|
let response = ureq::get(&format!("{}/servers", self.server))
|
||||||
|
.call()?;
|
||||||
|
|
||||||
|
let body = response.into_string()?;
|
||||||
|
|
||||||
|
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||||
|
anyhow::bail!("{}", err.error);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ServersResponse = serde_json::from_str(&body)?;
|
||||||
|
|
||||||
|
println!("Servers:");
|
||||||
|
println!("{:<15} {:<30} {:<8} {:<10} {:<10}", "Name", "Host", "Port", "User", "Status");
|
||||||
|
println!("{}", "-".repeat(73));
|
||||||
|
|
||||||
|
for srv in result.servers {
|
||||||
|
println!("{:<15} {:<30} {:<8} {:<10} {:<10}",
|
||||||
|
srv.name, srv.host, srv.port, srv.user, srv.status);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 动态添加服务器
|
||||||
|
pub fn add_server(
|
||||||
|
&self,
|
||||||
|
name: String,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
user: String,
|
||||||
|
password: Option<String>,
|
||||||
|
private_key: Option<String>,
|
||||||
|
) -> Result<()> {
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct AddRequest {
|
||||||
|
name: String,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
user: String,
|
||||||
|
password: Option<String>,
|
||||||
|
private_key: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let body = serde_json::to_string(&AddRequest {
|
||||||
|
name,
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
user,
|
||||||
|
password,
|
||||||
|
private_key,
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let response = ureq::post(&format!("{}/servers/add", self.server))
|
||||||
|
.set("Content-Type", "application/json")
|
||||||
|
.send_string(&body)?;
|
||||||
|
|
||||||
|
let body = response.into_string()?;
|
||||||
|
|
||||||
|
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
|
||||||
|
anyhow::bail!("{}", err.error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct AddResponse {
|
||||||
|
success: bool,
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: AddResponse = serde_json::from_str(&body)?;
|
||||||
|
println!("{}", result.message);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查代理是否运行
|
||||||
|
pub fn check_server(&self) -> Result<bool> {
|
||||||
|
match ureq::get(&format!("{}/health", self.server)).call() {
|
||||||
|
Ok(_) => Ok(true),
|
||||||
|
Err(_) => Ok(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
112
ssh-proxy/src/config.rs
Normal file
112
ssh-proxy/src/config.rs
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
|
pub struct ServerConfig {
|
||||||
|
pub port: u16,
|
||||||
|
#[serde(default = "default_host")]
|
||||||
|
pub host: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_host() -> String {
|
||||||
|
"127.0.0.1".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
|
pub struct PoolConfig {
|
||||||
|
#[serde(default = "default_idle_timeout")]
|
||||||
|
pub idle_timeout_secs: u64,
|
||||||
|
#[serde(default = "default_check_interval")]
|
||||||
|
pub check_interval_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_idle_timeout() -> u64 { 300 }
|
||||||
|
fn default_check_interval() -> u64 { 60 }
|
||||||
|
|
||||||
|
impl Default for PoolConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
idle_timeout_secs: default_idle_timeout(),
|
||||||
|
check_interval_secs: default_check_interval(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
|
pub struct SshServerConfig {
|
||||||
|
pub name: String,
|
||||||
|
pub host: String,
|
||||||
|
#[serde(default = "default_ssh_port")]
|
||||||
|
pub port: u16,
|
||||||
|
pub user: String,
|
||||||
|
/// 私钥路径 (优先)
|
||||||
|
pub private_key: Option<String>,
|
||||||
|
/// 密码 (备选)
|
||||||
|
pub password: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_ssh_port() -> u16 { 22 }
|
||||||
|
|
||||||
|
impl SshServerConfig {
|
||||||
|
/// 获取私钥路径
|
||||||
|
pub fn get_private_key_path(&self) -> Option<PathBuf> {
|
||||||
|
self.private_key.as_ref().map(|p| {
|
||||||
|
// 支持 ~ 路径展开 (Unix 和 Windows)
|
||||||
|
let path = if p.starts_with('~') {
|
||||||
|
let home = std::env::var("HOME")
|
||||||
|
.or_else(|_| std::env::var("USERPROFILE"))
|
||||||
|
.unwrap_or_default();
|
||||||
|
p.replacen('~', &home, 1)
|
||||||
|
} else {
|
||||||
|
p.clone()
|
||||||
|
};
|
||||||
|
PathBuf::from(path)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
#[serde(default = "default_server")]
|
||||||
|
pub server: ServerConfig,
|
||||||
|
#[serde(default)]
|
||||||
|
pub pool: PoolConfig,
|
||||||
|
pub servers: Vec<SshServerConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_server() -> ServerConfig {
|
||||||
|
ServerConfig {
|
||||||
|
port: 3308,
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn from_file(path: &str) -> Result<Self> {
|
||||||
|
let content = fs::read_to_string(path)?;
|
||||||
|
let config: Config = toml::from_str(&content)?;
|
||||||
|
config.validate()?;
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate(&self) -> Result<()> {
|
||||||
|
let mut names = HashSet::new();
|
||||||
|
for srv in &self.servers {
|
||||||
|
if srv.name.is_empty() {
|
||||||
|
bail!("Server name cannot be empty");
|
||||||
|
}
|
||||||
|
if names.contains(&srv.name) {
|
||||||
|
bail!("Duplicate server name: {}", srv.name);
|
||||||
|
}
|
||||||
|
names.insert(srv.name.clone());
|
||||||
|
|
||||||
|
if srv.private_key.is_none() && srv.password.is_none() {
|
||||||
|
bail!("Server '{}' needs either private_key or password", srv.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
183
ssh-proxy/src/handler.rs
Normal file
183
ssh-proxy/src/handler.rs
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
use axum::{extract::State, http::StatusCode, Json};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use crate::session::SessionManager;
|
||||||
|
use crate::logger::{LogEntry, RequestLogger};
|
||||||
|
|
||||||
|
pub type AppState = Arc<(Arc<SessionManager>, Arc<RequestLogger>)>;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ExecRequest {
|
||||||
|
pub server: String,
|
||||||
|
pub command: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct AddServerRequest {
|
||||||
|
pub name: String,
|
||||||
|
pub host: String,
|
||||||
|
#[serde(default = "default_ssh_port")]
|
||||||
|
pub port: u16,
|
||||||
|
#[serde(default = "default_user")]
|
||||||
|
pub user: String,
|
||||||
|
pub password: Option<String>,
|
||||||
|
pub private_key: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_ssh_port() -> u16 { 22 }
|
||||||
|
fn default_user() -> String { "root".to_string() }
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ExecResponse {
|
||||||
|
pub stdout: String,
|
||||||
|
pub stderr: String,
|
||||||
|
#[serde(rename = "exitCode")]
|
||||||
|
pub exit_code: i32,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ServersResponse {
|
||||||
|
pub servers: Vec<crate::session::ServerInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct HealthResponse {
|
||||||
|
pub status: String,
|
||||||
|
pub servers: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
pub error: String,
|
||||||
|
#[serde(rename = "usage")]
|
||||||
|
pub usage: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErrorResponse {
|
||||||
|
pub fn new(msg: &str) -> Self {
|
||||||
|
Self { error: msg.to_string(), usage: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_usage(mut self, usage: &str) -> Self {
|
||||||
|
self.usage = Some(usage.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn exec(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(req): Json<ExecRequest>,
|
||||||
|
) -> Result<Json<ExecResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let start = Instant::now();
|
||||||
|
let (manager, logger) = &*state;
|
||||||
|
|
||||||
|
let manager_clone = manager.clone();
|
||||||
|
let server = req.server.clone();
|
||||||
|
let command = req.command.clone();
|
||||||
|
|
||||||
|
let result = tokio::task::spawn_blocking(move || {
|
||||||
|
manager_clone.exec(&server, &command)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error_response_with_usage(&e.to_string(), "Internal error occurred")
|
||||||
|
})?
|
||||||
|
.map_err(|e| {
|
||||||
|
let usage = format!(
|
||||||
|
"Usage: POST /exec {{\"server\": \"server_name\", \"command\": \"your command\"}}\n\nAvailable servers: {}",
|
||||||
|
list_server_names(&manager)
|
||||||
|
);
|
||||||
|
error_response_with_usage(&e.to_string(), &usage)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let duration_ms = start.elapsed().as_millis() as u64;
|
||||||
|
|
||||||
|
// 记录日志
|
||||||
|
logger.log(&LogEntry::new("/exec", "http")
|
||||||
|
.with_server(&req.server)
|
||||||
|
.with_command(&req.command)
|
||||||
|
.with_duration(duration_ms)
|
||||||
|
.with_exit_code(result.exit_code));
|
||||||
|
|
||||||
|
Ok(Json(ExecResponse {
|
||||||
|
stdout: result.stdout,
|
||||||
|
stderr: result.stderr,
|
||||||
|
exit_code: result.exit_code,
|
||||||
|
duration_ms,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn servers(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> Json<ServersResponse> {
|
||||||
|
let (manager, _) = &*state;
|
||||||
|
Json(ServersResponse { servers: manager.list_servers() })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_server(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(req): Json<AddServerRequest>,
|
||||||
|
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let (manager, logger) = &*state;
|
||||||
|
use crate::config::SshServerConfig;
|
||||||
|
|
||||||
|
// 验证必须有密码或私钥
|
||||||
|
if req.password.is_none() && req.private_key.is_none() {
|
||||||
|
return Err(error_response_with_usage(
|
||||||
|
"Server requires either password or private_key",
|
||||||
|
"Usage: POST /servers/add {\"name\": \"myserver\", \"host\": \"192.168.1.100\", \"user\": \"root\", \"password\": \"secret\"}\n or: {\"name\": \"myserver\", \"host\": \"...\", \"private_key\": \"~/.ssh/id_rsa\"}"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let cfg = SshServerConfig {
|
||||||
|
name: req.name.clone(),
|
||||||
|
host: req.host,
|
||||||
|
port: req.port,
|
||||||
|
user: req.user,
|
||||||
|
password: req.password,
|
||||||
|
private_key: req.private_key,
|
||||||
|
};
|
||||||
|
|
||||||
|
manager.add_server(cfg)
|
||||||
|
.map_err(|e| error_response(&e.to_string()))?;
|
||||||
|
|
||||||
|
// 记录日志
|
||||||
|
logger.log(&LogEntry::new("/servers/add", "http")
|
||||||
|
.with_server(&req.name)
|
||||||
|
.with_duration(0));
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"message": format!("Server '{}' added (temporary, will be lost on restart)", req.name)
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn health(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> Json<HealthResponse> {
|
||||||
|
let (manager, _) = &*state;
|
||||||
|
Json(HealthResponse {
|
||||||
|
status: "ok".to_string(),
|
||||||
|
servers: manager.list_servers().len(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error_response(msg: &str) -> (StatusCode, Json<ErrorResponse>) {
|
||||||
|
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error_response_with_usage(msg: &str, usage: &str) -> (StatusCode, Json<ErrorResponse>) {
|
||||||
|
(StatusCode::BAD_REQUEST, Json(ErrorResponse::new(msg).with_usage(usage)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_server_names(manager: &SessionManager) -> String {
|
||||||
|
manager.list_servers()
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.name.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
|
}
|
||||||
228
ssh-proxy/src/logger.rs
Normal file
228
ssh-proxy/src/logger.rs
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs::{File, OpenOptions};
|
||||||
|
use std::io::Write;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
/// 日志记录器
|
||||||
|
pub struct RequestLogger {
|
||||||
|
log_file: Mutex<Option<File>>,
|
||||||
|
enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestLogger {
|
||||||
|
pub fn new(log_path: Option<&str>) -> Self {
|
||||||
|
let (log_file, enabled) = if let Some(path) = log_path {
|
||||||
|
let path = expand_path(path);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
let _ = std::fs::create_dir_all(parent);
|
||||||
|
}
|
||||||
|
match OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(&path)
|
||||||
|
{
|
||||||
|
Ok(file) => (Some(file), true),
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[Logger] Failed to open log file {}: {}", path.display(), e);
|
||||||
|
(None, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(None, false)
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
log_file: Mutex::new(log_file),
|
||||||
|
enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_enabled(&self) -> bool {
|
||||||
|
self.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记录请求日志
|
||||||
|
pub fn log(&self, entry: &LogEntry) {
|
||||||
|
if !self.enabled {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let json = match serde_json::to_string(entry) {
|
||||||
|
Ok(j) => j,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[Logger] Failed to serialize log: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Ok(mut file) = self.log_file.lock() {
|
||||||
|
if let Some(ref mut f) = *file {
|
||||||
|
let _ = writeln!(f, "{}", json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 日志条目
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct LogEntry {
|
||||||
|
pub timestamp: String,
|
||||||
|
pub level: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub log_type: String,
|
||||||
|
pub client: String,
|
||||||
|
pub endpoint: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub conn: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub server: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sql: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub command: Option<String>,
|
||||||
|
#[serde(rename = "durationMs")]
|
||||||
|
pub duration_ms: u64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub rows: Option<usize>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[serde(rename = "exitCode")]
|
||||||
|
pub exit_code: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LogEntry {
|
||||||
|
pub fn new(endpoint: &str, client: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
timestamp: current_timestamp(),
|
||||||
|
level: "INFO".to_string(),
|
||||||
|
log_type: "request".to_string(),
|
||||||
|
client: client.to_string(),
|
||||||
|
endpoint: endpoint.to_string(),
|
||||||
|
conn: None,
|
||||||
|
server: None,
|
||||||
|
sql: None,
|
||||||
|
command: None,
|
||||||
|
duration_ms: 0,
|
||||||
|
rows: None,
|
||||||
|
exit_code: None,
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_conn(mut self, conn: &str) -> Self {
|
||||||
|
self.conn = Some(conn.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_server(mut self, server: &str) -> Self {
|
||||||
|
self.server = Some(server.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_sql(mut self, sql: &str) -> Self {
|
||||||
|
// 截断过长的 SQL
|
||||||
|
self.sql = Some(truncate_string(sql, 1000));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_command(mut self, command: &str) -> Self {
|
||||||
|
// 截断过长的命令
|
||||||
|
self.command = Some(truncate_string(command, 500));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_duration(mut self, ms: u64) -> Self {
|
||||||
|
self.duration_ms = ms;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_rows(mut self, rows: usize) -> Self {
|
||||||
|
self.rows = Some(rows);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_exit_code(mut self, code: i32) -> Self {
|
||||||
|
self.exit_code = Some(code);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_error(mut self, error: &str) -> Self {
|
||||||
|
self.level = "ERROR".to_string();
|
||||||
|
self.error = Some(error.to_string());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_timestamp() -> String {
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default();
|
||||||
|
let secs = now.as_secs();
|
||||||
|
let datetime = chrono_timestamp(secs);
|
||||||
|
format!("{}Z", datetime)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chrono_timestamp(secs: u64) -> String {
|
||||||
|
let days = secs / 86400;
|
||||||
|
let remaining = secs % 86400;
|
||||||
|
let hours = remaining / 3600;
|
||||||
|
let minutes = (remaining % 3600) / 60;
|
||||||
|
let seconds = remaining % 60;
|
||||||
|
|
||||||
|
// 从 1970-01-01 开始计算日期
|
||||||
|
let mut year = 1970;
|
||||||
|
let mut days_left = days;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let days_in_year = if is_leap_year(year) { 366 } else { 365 };
|
||||||
|
if days_left < days_in_year {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
days_left -= days_in_year;
|
||||||
|
year += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let month_days = if is_leap_year(year) {
|
||||||
|
[31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
|
||||||
|
} else {
|
||||||
|
[31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut month = 1;
|
||||||
|
for &days_in_month in &month_days {
|
||||||
|
if days_left < days_in_month {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
days_left -= days_in_month;
|
||||||
|
month += 1;
|
||||||
|
}
|
||||||
|
let day = days_left + 1;
|
||||||
|
|
||||||
|
format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}", year, month, day, hours, minutes, seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_leap_year(year: u64) -> bool {
|
||||||
|
(year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_string(s: &str, max_len: usize) -> String {
|
||||||
|
if s.len() <= max_len {
|
||||||
|
s.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{}... (truncated)", &s[..max_len])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expand_path(path: &str) -> PathBuf {
|
||||||
|
if path.starts_with('~') {
|
||||||
|
let home = std::env::var("HOME")
|
||||||
|
.or_else(|_| std::env::var("USERPROFILE"))
|
||||||
|
.unwrap_or_default();
|
||||||
|
PathBuf::from(path.replacen('~', &home, 1))
|
||||||
|
} else {
|
||||||
|
PathBuf::from(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
170
ssh-proxy/src/main.rs
Normal file
170
ssh-proxy/src/main.rs
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
mod cli;
|
||||||
|
mod config;
|
||||||
|
mod handler;
|
||||||
|
mod logger;
|
||||||
|
mod session;
|
||||||
|
|
||||||
|
use axum::{routing::{get, post}, Router};
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(name = "ssh-proxy")]
|
||||||
|
#[command(about = "SSH HTTP proxy with session pooling")]
|
||||||
|
struct Args {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Option<Commands>,
|
||||||
|
#[arg(long, default_value = "ssh-proxy.toml", global = true)]
|
||||||
|
config: String,
|
||||||
|
#[arg(short = 'S', long, default_value = "http://127.0.0.1:3308", global = true)]
|
||||||
|
server: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
Server {
|
||||||
|
#[arg(short = 'P', long)]
|
||||||
|
port: Option<u16>,
|
||||||
|
#[arg(short = 'H', long)]
|
||||||
|
host: Option<String>,
|
||||||
|
},
|
||||||
|
Exec {
|
||||||
|
#[arg(short = 'n', long)]
|
||||||
|
name: String,
|
||||||
|
#[arg(short, long)]
|
||||||
|
command: String,
|
||||||
|
/// Output format: text, json
|
||||||
|
#[arg(short = 'F', long, default_value = "text")]
|
||||||
|
format: String,
|
||||||
|
},
|
||||||
|
Servers,
|
||||||
|
/// Add server dynamically (temporary)
|
||||||
|
AddServer {
|
||||||
|
#[arg(short = 'n', long)]
|
||||||
|
name: String,
|
||||||
|
#[arg(short = 'H', long)]
|
||||||
|
host: String,
|
||||||
|
#[arg(short = 'P', long, default_value = "22")]
|
||||||
|
port: u16,
|
||||||
|
#[arg(short = 'u', long, default_value = "root")]
|
||||||
|
user: String,
|
||||||
|
#[arg(short = 'p', long)]
|
||||||
|
password: Option<String>,
|
||||||
|
#[arg(short = 'k', long)]
|
||||||
|
private_key: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
match args.command {
|
||||||
|
Some(Commands::Server { port, host }) => run_server(&args.config, port, host).await,
|
||||||
|
Some(Commands::Exec { name, command, format }) => run_cli_exec(&args.server, &name, &command, &format),
|
||||||
|
Some(Commands::Servers) => run_cli_servers(&args.server),
|
||||||
|
Some(Commands::AddServer { name, host, port, user, password, private_key }) => {
|
||||||
|
run_cli_add_server(&args.server, name, host, port, user, password, private_key)
|
||||||
|
}
|
||||||
|
None => run_server(&args.config, None, None).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_server(config_path: &str, port: Option<u16>, host: Option<String>) -> anyhow::Result<()> {
|
||||||
|
println!("SSH HTTP Proxy v0.1.0\n");
|
||||||
|
let mut config = config::Config::from_file(config_path)?;
|
||||||
|
if let Some(p) = port { config.server.port = p; }
|
||||||
|
if let Some(h) = host { config.server.host = h; }
|
||||||
|
|
||||||
|
// 初始化日志
|
||||||
|
let log_path = std::env::var("SSH_PROXY_LOG").ok();
|
||||||
|
let logger = Arc::new(logger::RequestLogger::new(log_path.as_deref()));
|
||||||
|
if logger.is_enabled() {
|
||||||
|
println!("Request logging: enabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Initializing SSH sessions...\n");
|
||||||
|
let manager = Arc::new(session::SessionManager::new(&config.servers)?);
|
||||||
|
|
||||||
|
let manager_clone = manager.clone();
|
||||||
|
let idle_timeout = config.pool.idle_timeout_secs;
|
||||||
|
let check_interval = config.pool.check_interval_secs;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(std::time::Duration::from_secs(check_interval));
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
manager_clone.cleanup_idle(idle_timeout);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/exec", post(handler::exec))
|
||||||
|
.route("/servers", get(handler::servers))
|
||||||
|
.route("/servers/add", post(handler::add_server))
|
||||||
|
.route("/health", get(handler::health))
|
||||||
|
.with_state(Arc::new((manager, logger)));
|
||||||
|
|
||||||
|
let addr = format!("{}:{}", config.server.host, config.server.port);
|
||||||
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||||
|
|
||||||
|
println!("\nServer started at http://{}", addr);
|
||||||
|
println!("\nAPI Endpoints:");
|
||||||
|
println!(" POST /exec - Execute remote command");
|
||||||
|
println!(" POST /servers/add - Add server (temporary)");
|
||||||
|
println!(" GET /servers - List all servers");
|
||||||
|
println!(" GET /health - Health check");
|
||||||
|
println!("\nCLI Usage:");
|
||||||
|
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\"");
|
||||||
|
println!(" ssh-proxy exec -n flux_dev -c \"docker ps\" -F json");
|
||||||
|
println!(" ssh-proxy servers");
|
||||||
|
|
||||||
|
axum::serve(listener, app).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_cli_exec(server_url: &str, server: &str, command: &str, format: &str) -> anyhow::Result<()> {
|
||||||
|
let cli = cli::Cli::new(Some(server_url.to_string()));
|
||||||
|
if !cli.check_server()? {
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("=== ssh-proxy 未运行,请使用降级方案 ===");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("启动代理: ssh-proxy");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("或直接使用 ssh 命令:");
|
||||||
|
eprintln!(" ssh <user>@<host> \"<command>\"");
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("配置文件中的服务器信息见: ssh-proxy.toml");
|
||||||
|
anyhow::bail!("Proxy server not running at {}", server_url);
|
||||||
|
}
|
||||||
|
cli.exec(server, command, format)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_cli_servers(server_url: &str) -> anyhow::Result<()> {
|
||||||
|
let cli = cli::Cli::new(Some(server_url.to_string()));
|
||||||
|
if !cli.check_server()? {
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("=== ssh-proxy 未运行 ===");
|
||||||
|
eprintln!("启动代理: ssh-proxy");
|
||||||
|
eprintln!("配置文件: ssh-proxy.toml");
|
||||||
|
anyhow::bail!("Proxy server not running at {}", server_url);
|
||||||
|
}
|
||||||
|
cli.list_servers()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_cli_add_server(
|
||||||
|
server_url: &str,
|
||||||
|
name: String,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
user: String,
|
||||||
|
password: Option<String>,
|
||||||
|
private_key: Option<String>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let cli = cli::Cli::new(Some(server_url.to_string()));
|
||||||
|
if !cli.check_server()? {
|
||||||
|
eprintln!();
|
||||||
|
eprintln!("=== ssh-proxy 未运行 ===");
|
||||||
|
eprintln!("启动代理: ssh-proxy");
|
||||||
|
anyhow::bail!("Proxy server not running at {}", server_url);
|
||||||
|
}
|
||||||
|
cli.add_server(name, host, port, user, password, private_key)
|
||||||
|
}
|
||||||
192
ssh-proxy/src/session.rs
Normal file
192
ssh-proxy/src/session.rs
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
use anyhow::{Result, bail};
|
||||||
|
use ssh2::Session;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::TcpStream;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use std::io::Read;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use crate::config::SshServerConfig;
|
||||||
|
|
||||||
|
struct SessionState {
|
||||||
|
session: Session,
|
||||||
|
last_used: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SessionManager {
|
||||||
|
sessions: Mutex<HashMap<String, SessionState>>,
|
||||||
|
configs: Mutex<HashMap<String, SshServerConfig>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionManager {
|
||||||
|
pub fn new(configs: &[SshServerConfig]) -> Result<Self> {
|
||||||
|
let mut config_map = HashMap::new();
|
||||||
|
for cfg in configs {
|
||||||
|
config_map.insert(cfg.name.clone(), cfg.clone());
|
||||||
|
println!(" Registered: {} ({}:{})", cfg.name, cfg.host, cfg.port);
|
||||||
|
}
|
||||||
|
println!("\n{} server(s) configured (lazy init)", config_map.len());
|
||||||
|
Ok(Self {
|
||||||
|
sessions: Mutex::new(HashMap::new()),
|
||||||
|
configs: Mutex::new(config_map),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_or_create_session(&self, name: &str) -> Result<Session> {
|
||||||
|
{
|
||||||
|
let mut sessions = self.sessions.lock().unwrap();
|
||||||
|
if let Some(state) = sessions.get_mut(name) {
|
||||||
|
if state.session.authenticated() {
|
||||||
|
state.last_used = Instant::now();
|
||||||
|
return Ok(state.session.clone());
|
||||||
|
} else {
|
||||||
|
println!("[Session] Session expired, reconnecting: {}", name);
|
||||||
|
sessions.remove(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let cfg = self.configs.lock().unwrap()
|
||||||
|
.get(name)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Server '{}' not found", name))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
println!("[LazyInit] Connecting to: {}", name);
|
||||||
|
let session = self.create_session(&cfg)?;
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut sessions = self.sessions.lock().unwrap();
|
||||||
|
sessions.insert(name.to_string(), SessionState {
|
||||||
|
session: session.clone(),
|
||||||
|
last_used: Instant::now(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("[LazyInit] Connected: {}", name);
|
||||||
|
Ok(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_session(&self, cfg: &SshServerConfig) -> Result<Session> {
|
||||||
|
let addr = format!("{}:{}", cfg.host, cfg.port);
|
||||||
|
let tcp = TcpStream::connect(&addr)
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to connect to {}: {}", addr, e))?;
|
||||||
|
|
||||||
|
let mut session = Session::new()?;
|
||||||
|
session.set_tcp_stream(tcp);
|
||||||
|
session.handshake()?;
|
||||||
|
|
||||||
|
if let Some(key_path) = cfg.get_private_key_path() {
|
||||||
|
let pubkey_path: PathBuf = key_path.with_extension("pub");
|
||||||
|
session.userauth_pubkey_file(&cfg.user, Some(&pubkey_path), &key_path, None)?;
|
||||||
|
} else if let Some(ref password) = cfg.password {
|
||||||
|
session.userauth_password(&cfg.user, password)?;
|
||||||
|
} else {
|
||||||
|
bail!("No authentication method configured");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !session.authenticated() {
|
||||||
|
bail!("SSH authentication failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn exec(&self, name: &str, command: &str) -> Result<ExecResult> {
|
||||||
|
let start = Instant::now();
|
||||||
|
let session = self.get_or_create_session(name)?;
|
||||||
|
|
||||||
|
let mut channel = session.channel_session()?;
|
||||||
|
channel.exec(command)?;
|
||||||
|
|
||||||
|
let mut stdout = String::new();
|
||||||
|
let mut stderr = String::new();
|
||||||
|
if let Err(e) = channel.read_to_string(&mut stdout) {
|
||||||
|
eprintln!("[SSH] stdout read error: {}", e);
|
||||||
|
}
|
||||||
|
if let Err(e) = channel.stderr().read_to_string(&mut stderr) {
|
||||||
|
eprintln!("[SSH] stderr read error: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.wait_close()?;
|
||||||
|
let exit_code = channel.exit_status()?;
|
||||||
|
|
||||||
|
Ok(ExecResult {
|
||||||
|
stdout,
|
||||||
|
stderr,
|
||||||
|
exit_code,
|
||||||
|
duration_ms: start.elapsed().as_millis() as u64,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list_servers(&self) -> Vec<ServerInfo> {
|
||||||
|
let sessions = self.sessions.lock().unwrap();
|
||||||
|
let configs = self.configs.lock().unwrap();
|
||||||
|
configs.iter().map(|(name, cfg)| {
|
||||||
|
let status = if sessions.contains_key(name) { "connected" } else { "pending" };
|
||||||
|
ServerInfo {
|
||||||
|
name: name.clone(),
|
||||||
|
host: cfg.host.clone(),
|
||||||
|
port: cfg.port,
|
||||||
|
user: cfg.user.clone(),
|
||||||
|
status: status.to_string(),
|
||||||
|
}
|
||||||
|
}).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 动态添加服务器 (临时,重启后消失)
|
||||||
|
pub fn add_server(&self, cfg: SshServerConfig) -> Result<()> {
|
||||||
|
let name = cfg.name.clone();
|
||||||
|
let mut configs = self.configs.lock().unwrap();
|
||||||
|
if configs.contains_key(&name) {
|
||||||
|
anyhow::bail!("Server '{}' already exists", name);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("[Dynamic] Adding: {} ({}:{})", name, cfg.host, cfg.port);
|
||||||
|
|
||||||
|
// 测试连接
|
||||||
|
let session = self.create_session(&cfg)?;
|
||||||
|
{
|
||||||
|
let mut sessions = self.sessions.lock().unwrap();
|
||||||
|
sessions.insert(name.clone(), SessionState {
|
||||||
|
session,
|
||||||
|
last_used: Instant::now(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
configs.insert(name.clone(), cfg);
|
||||||
|
println!("[Dynamic] ✓ Added: {}", name);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cleanup_idle(&self, timeout_secs: u64) {
|
||||||
|
let mut sessions = self.sessions.lock().unwrap();
|
||||||
|
let now = Instant::now();
|
||||||
|
sessions.retain(|name, state| {
|
||||||
|
let elapsed = now.duration_since(state.last_used);
|
||||||
|
if elapsed > Duration::from_secs(timeout_secs) {
|
||||||
|
println!("[Cleanup] Removing idle session: {} (idle {}s)", name, elapsed.as_secs());
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct ExecResult {
|
||||||
|
pub stdout: String,
|
||||||
|
pub stderr: String,
|
||||||
|
pub exit_code: i32,
|
||||||
|
pub duration_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct ServerInfo {
|
||||||
|
pub name: String,
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub user: String,
|
||||||
|
pub status: String,
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user