优化: 代理工具代码审查修复

This commit is contained in:
2026-04-01 18:11:06 +08:00
parent f77ed2572a
commit 38b16a6efc
9 changed files with 321 additions and 91 deletions

View File

@@ -100,6 +100,9 @@ impl ErrorResponse {
// ============== 处理器 ==============
struct QueryResult { columns: Vec<String>, rows: Vec<Vec<Option<String>>> }
struct ExecResult { affected: u64, last_id: u64 }
/// 查询处理器
pub async fn query(
State(state): State<AppState>,
@@ -108,14 +111,6 @@ pub async fn query(
let (manager, logger) = &*state;
let start = Instant::now();
// 获取连接
let mut conn = manager.get_conn(&req.conn)
.map_err(|e| {
let usage = "Usage: POST /query {\"conn\": \"connection_name\", \"sql\": \"SELECT ...\"}\n"
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
error_response_with_usage(&e.to_string(), &usage)
})?;
// 判断是否是查询语句
let sql_upper = req.sql.trim().to_uppercase();
let is_query = sql_upper.starts_with("SELECT")
@@ -126,43 +121,45 @@ pub async fn query(
|| sql_upper.starts_with("WITH");
if !is_query {
let err = error_response_with_usage(
logger.log(&LogEntry::new("/query", "http")
.with_conn(&req.conn).with_sql(&req.sql).with_error("Not a SELECT query"));
return Err(error_response_with_usage(
"Not a SELECT query. Use /execute for INSERT/UPDATE/DELETE",
"Usage:\n POST /query {\"conn\": \"name\", \"sql\": \"SELECT ...\"}\n POST /execute {\"conn\": \"name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}"
);
return Err(err);
));
}
// 执行查询
let result = conn.query_iter(&req.sql)
.map_err(|e| {
let usage = format!("SQL Error: {}\n\nUsage: POST /query {{\"conn\": \"name\", \"sql\": \"SELECT ...\"}}", e);
let mgr = manager.clone();
let conn_name = req.conn.clone();
let sql = req.sql.clone();
let conn_names = list_conn_names(manager);
let result = tokio::task::spawn_blocking(move || -> Result<QueryResult, (StatusCode, Json<ErrorResponse>)> {
let mut conn = mgr.get_conn(&conn_name).map_err(|e| {
let usage = format!("Usage: POST /query {{\"conn\": \"connection_name\", \"sql\": \"SELECT ...\"}}\nAvailable connections: {}", conn_names);
error_response_with_usage(&e.to_string(), &usage)
})?;
let mut columns: Vec<String> = Vec::new();
let mut data: Vec<Vec<Option<String>>> = Vec::new();
// 获取数据并从第一行提取列名
for row_result in result {
let row = row_result.map_err(|e| error_response(&e.to_string()))?;
// 从第一行获取列名
if columns.is_empty() {
columns = row.columns()
.iter()
.map(|c| c.name_str().to_string())
.collect();
let result = conn.query_iter(&sql).map_err(|e| error_response(&e.to_string()))?;
let mut columns: Vec<String> = Vec::new();
let mut data: Vec<Vec<Option<String>>> = Vec::new();
for row_result in result {
let row = row_result.map_err(|e| error_response(&e.to_string()))?;
if columns.is_empty() {
columns = row.columns().iter().map(|c| c.name_str().to_string()).collect();
}
data.push(row_to_strings(&row, columns.len()));
}
Ok(QueryResult { columns, rows: data })
}).await.map_err(|_| error_response("Task join error"))?;
let values = row_to_strings(&row, columns.len());
data.push(values);
}
let result = result.map_err(|e| {
logger.log(&LogEntry::new("/query", "http").with_conn(&req.conn).with_sql(&req.sql).with_error(&e.1.error));
e
})?;
let duration = start.elapsed();
let row_count = data.len();
let row_count = result.rows.len();
// 记录日志
logger.log(&LogEntry::new("/query", "http")
.with_conn(&req.conn)
.with_sql(&req.sql)
@@ -170,8 +167,8 @@ pub async fn query(
.with_rows(row_count));
Ok(Json(QueryResponse {
columns,
rows: data,
columns: result.columns,
rows: result.rows,
row_count,
duration_ms: duration.as_millis() as u64,
}))
@@ -185,31 +182,35 @@ pub async fn execute(
let (manager, logger) = &*state;
let start = Instant::now();
// 获取连接
let mut conn = manager.get_conn(&req.conn)
.map_err(|e| {
let usage = "Usage: POST /execute {\"conn\": \"connection_name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}\n"
.to_string() + &format!("Available connections: {}", list_conn_names(&manager));
let mgr = manager.clone();
let conn_name = req.conn.clone();
let sql = req.sql.clone();
let conn_names = list_conn_names(manager);
let exec_result = tokio::task::spawn_blocking(move || -> Result<ExecResult, (StatusCode, Json<ErrorResponse>)> {
let mut conn = mgr.get_conn(&conn_name).map_err(|e| {
let usage = format!("Usage: POST /execute {{\"conn\": \"connection_name\", \"sql\": \"INSERT/UPDATE/DELETE ...\"}}\nAvailable connections: {}", conn_names);
error_response_with_usage(&e.to_string(), &usage)
})?;
let result = conn.query_iter(&sql).map_err(|e| error_response(&e.to_string()))?;
Ok(ExecResult { affected: result.affected_rows(), last_id: result.last_insert_id().unwrap_or(0) })
}).await.map_err(|_| error_response("Task join error"))?;
// 执行
let result = conn.query_iter(&req.sql)
.map_err(|e| error_response(&e.to_string()))?;
let exec_result = exec_result.map_err(|e| {
logger.log(&LogEntry::new("/execute", "http").with_conn(&req.conn).with_sql(&req.sql).with_error(&e.1.error));
e
})?;
let duration = start.elapsed();
let affected = result.affected_rows();
let last_id = result.last_insert_id().unwrap_or(0);
// 记录日志
logger.log(&LogEntry::new("/execute", "http")
.with_conn(&req.conn)
.with_sql(&req.sql)
.with_duration(duration.as_millis() as u64));
Ok(Json(ExecuteResponse {
affected_rows: affected,
last_insert_id: last_id,
affected_rows: exec_result.affected,
last_insert_id: exec_result.last_id,
duration_ms: duration.as_millis() as u64,
}))
}
@@ -254,7 +255,10 @@ pub async fn add_connection(
};
manager.add_connection(cfg)
.map_err(|e| error_response(&e.to_string()))?;
.map_err(|e| {
logger.log(&LogEntry::new("/connections/add", "http").with_conn(&req.name).with_error(&e.to_string()));
error_response(&e.to_string())
})?;
// 记录日志
logger.log(&LogEntry::new("/connections/add", "http")