优化: 代理工具代码审查修复

This commit is contained in:
2026-04-01 18:11:06 +08:00
parent f77ed2572a
commit 38b16a6efc
9 changed files with 321 additions and 91 deletions

View File

@@ -77,8 +77,38 @@ impl Config {
impl ConnectionConfig {
pub fn build_url(&self) -> String {
match &self.password {
Some(pass) => format!("redis://:{}@{}:{}/{}", pass, self.host, self.port, self.db),
Some(pass) => format!("redis://:{}@{}:{}/{}", url_encode_password(pass), self.host, self.port, self.db),
None => format!("redis://{}:{}/{}", self.host, self.port, self.db),
}
}
}
fn url_encode_password(s: &str) -> String {
let mut result = String::new();
for c in s.chars() {
match c {
'@' => result.push_str("%40"),
':' => result.push_str("%3A"),
'/' => result.push_str("%2F"),
'?' => result.push_str("%3F"),
'#' => result.push_str("%23"),
'[' => result.push_str("%5B"),
']' => result.push_str("%5D"),
'!' => result.push_str("%21"),
'$' => result.push_str("%24"),
'&' => result.push_str("%26"),
'\'' => result.push_str("%27"),
'(' => result.push_str("%28"),
')' => result.push_str("%29"),
'*' => result.push_str("%2A"),
'+' => result.push_str("%2B"),
',' => result.push_str("%2C"),
';' => result.push_str("%3B"),
'=' => result.push_str("%3D"),
'%' => result.push_str("%25"),
' ' => result.push_str("%20"),
_ => result.push(c),
}
}
result
}

View File

@@ -38,8 +38,16 @@ impl ConnectionManager {
} else { None }
};
if let Some(client) = client {
let cfg = self.configs.lock().unwrap().get(name).unwrap().clone();
return Ok((client, cfg));
// Validate connection is still alive
if let Ok(mut conn) = client.get_connection() {
if redis::cmd("PING").query::<String>(&mut conn).is_ok() {
let cfg = self.configs.lock().unwrap().get(name).unwrap().clone();
return Ok((client, cfg));
}
}
// Connection dead, remove stale client and recreate
println!("[Reconnect] Stale Redis client detected: {}, removing", name);
self.clients.lock().unwrap().remove(name);
}
let cfg = self.configs.lock().unwrap().get(name)
.ok_or_else(|| anyhow::anyhow!("Connection '{}' not found", name))?.clone();

View File

@@ -78,14 +78,20 @@ pub async fn run_cmd(State(state): State<AppState>, Json(req): Json<RunRequest>)
let conn_name = req.conn.clone();
let cmd_str = format!("{} {}", req.command, req.args.join(" "));
let result: serde_json::Value = tokio::task::spawn_blocking(move || -> Result<serde_json::Value, ApiError> {
let result: Result<serde_json::Value, ApiError> = tokio::task::spawn_blocking(move || {
let (client, _) = mgr.get_conn(&conn_name).map_err(|e| err_usage(&e.to_string(), "Usage: POST /run {\"conn\": \"name\", \"command\": \"GET\", \"args\": [\"key\"]}"))?;
let mut conn = client.get_connection().map_err(|e| err(&e.to_string()))?;
let mut cmd = redis::cmd(&req.command);
for arg in &req.args { cmd.arg(arg); }
let val: redis::Value = cmd.query(&mut conn).map_err(|e| err(&e.to_string()))?;
Ok(redis_value_to_json(&val))
}).await.map_err(|_| err("Task join error"))??;
}).await.map_err(|_| err("Task join error"))?;
let result = result.map_err(|e| {
let msg = &e.1.error;
logger.log(&LogEntry::new("/run", "http").with_conn(&req.conn).with_command(&cmd_str).with_error(msg));
e
})?;
let duration = start.elapsed();
logger.log(&LogEntry::new("/run", "http").with_conn(&req.conn).with_command(&cmd_str).with_duration(duration.as_millis() as u64));
@@ -99,12 +105,17 @@ pub async fn get(State(state): State<AppState>, Json(req): Json<GetRequest>) ->
let conn_name = req.conn.clone();
let key_name = req.key.clone();
let value: Option<String> = tokio::task::spawn_blocking(move || -> Result<Option<String>, ApiError> {
let value: Result<Option<String>, ApiError> = tokio::task::spawn_blocking(move || {
let (client, _) = mgr.get_conn(&conn_name).map_err(|e| err(&e.to_string()))?;
let mut conn = client.get_connection().map_err(|e| err(&e.to_string()))?;
let val: Option<String> = conn.get(&key_name).map_err(|e| err(&e.to_string()))?;
Ok(val)
}).await.map_err(|_| err("Task join error"))??;
}).await.map_err(|_| err("Task join error"))?;
let value = value.map_err(|e| {
logger.log(&LogEntry::new("/get", "http").with_conn(&req.conn).with_command(&format!("GET {}", req.key)).with_error(&e.1.error));
e
})?;
let duration = start.elapsed();
logger.log(&LogEntry::new("/get", "http").with_conn(&req.conn).with_command(&format!("GET {}", req.key)).with_duration(duration.as_millis() as u64));
@@ -120,7 +131,7 @@ pub async fn set(State(state): State<AppState>, Json(req): Json<SetRequest>) ->
let value_str = req.value.clone();
let ttl = req.ttl;
tokio::task::spawn_blocking(move || -> Result<(), ApiError> {
let result: Result<(), ApiError> = tokio::task::spawn_blocking(move || {
let (client, _) = mgr.get_conn(&conn_name).map_err(|e| err(&e.to_string()))?;
let mut conn = client.get_connection().map_err(|e| err(&e.to_string()))?;
if let Some(secs) = ttl {
@@ -129,7 +140,12 @@ pub async fn set(State(state): State<AppState>, Json(req): Json<SetRequest>) ->
conn.set::<_, _, ()>(&key_name, &value_str).map_err(|e| err(&e.to_string()))?;
}
Ok(())
}).await.map_err(|_| err("Task join error"))??;
}).await.map_err(|_| err("Task join error"))?;
result.map_err(|e| {
logger.log(&LogEntry::new("/set", "http").with_conn(&req.conn).with_command(&format!("SET {} {}", req.key, req.value)).with_error(&e.1.error));
e
})?;
let duration = start.elapsed();
logger.log(&LogEntry::new("/set", "http").with_conn(&req.conn).with_command(&format!("SET {} {}", req.key, req.value)).with_duration(duration.as_millis() as u64));
@@ -143,12 +159,17 @@ pub async fn del(State(state): State<AppState>, Json(req): Json<DelRequest>) ->
let conn_name = req.conn.clone();
let del_keys = req.keys.clone();
let deleted: u64 = tokio::task::spawn_blocking(move || -> Result<u64, ApiError> {
let deleted: Result<u64, ApiError> = tokio::task::spawn_blocking(move || {
let (client, _) = mgr.get_conn(&conn_name).map_err(|e| err(&e.to_string()))?;
let mut conn = client.get_connection().map_err(|e| err(&e.to_string()))?;
let count: u64 = conn.del(&del_keys).map_err(|e| err(&e.to_string()))?;
Ok(count)
}).await.map_err(|_| err("Task join error"))??;
}).await.map_err(|_| err("Task join error"))?;
let deleted = deleted.map_err(|e| {
logger.log(&LogEntry::new("/del", "http").with_conn(&req.conn).with_command(&format!("DEL {}", req.keys.join(" "))).with_error(&e.1.error));
e
})?;
let duration = start.elapsed();
logger.log(&LogEntry::new("/del", "http").with_conn(&req.conn).with_command(&format!("DEL {}", req.keys.join(" "))).with_duration(duration.as_millis() as u64));
@@ -162,12 +183,17 @@ pub async fn keys(State(state): State<AppState>, Json(req): Json<KeysRequest>) -
let conn_name = req.conn.clone();
let pattern = req.pattern.clone();
let result: Vec<String> = tokio::task::spawn_blocking(move || -> Result<Vec<String>, ApiError> {
let result: Result<Vec<String>, ApiError> = tokio::task::spawn_blocking(move || {
let (client, _) = mgr.get_conn(&conn_name).map_err(|e| err(&e.to_string()))?;
let mut conn = client.get_connection().map_err(|e| err(&e.to_string()))?;
let keys: Vec<String> = conn.keys(&pattern).map_err(|e| err(&e.to_string()))?;
Ok(keys)
}).await.map_err(|_| err("Task join error"))??;
}).await.map_err(|_| err("Task join error"))?;
let result = result.map_err(|e| {
logger.log(&LogEntry::new("/keys", "http").with_conn(&req.conn).with_command(&format!("KEYS {}", req.pattern)).with_error(&e.1.error));
e
})?;
let duration = start.elapsed();
logger.log(&LogEntry::new("/keys", "http").with_conn(&req.conn).with_command(&format!("KEYS {}", req.pattern)).with_duration(duration.as_millis() as u64));
@@ -180,12 +206,17 @@ pub async fn info(State(state): State<AppState>, Json(req): Json<InfoRequest>) -
let mgr = manager.clone();
let conn_name = req.conn.clone();
let info: String = tokio::task::spawn_blocking(move || -> Result<String, ApiError> {
let info: Result<String, ApiError> = tokio::task::spawn_blocking(move || {
let (client, _) = mgr.get_conn(&conn_name).map_err(|e| err(&e.to_string()))?;
let mut conn = client.get_connection().map_err(|e| err(&e.to_string()))?;
let info: String = redis::cmd("INFO").query(&mut conn).map_err(|e| err(&e.to_string()))?;
Ok(info)
}).await.map_err(|_| err("Task join error"))??;
}).await.map_err(|_| err("Task join error"))?;
let info = info.map_err(|e| {
logger.log(&LogEntry::new("/info", "http").with_conn(&req.conn).with_command("INFO").with_error(&e.1.error));
e
})?;
let duration = start.elapsed();
logger.log(&LogEntry::new("/info", "http").with_conn(&req.conn).with_command("INFO").with_duration(duration.as_millis() as u64));
@@ -203,7 +234,10 @@ pub async fn health(State(state): State<AppState>) -> Json<HealthResponse> {
pub async fn add_connection(State(state): State<AppState>, Json(req): Json<AddConnectionRequest>) -> Result<Json<serde_json::Value>, ApiError> {
let (manager, logger) = state.as_ref();
let cfg = crate::config::ConnectionConfig { name: req.name.clone(), host: req.host, port: req.port, password: req.password, db: req.db };
manager.add_connection(cfg).map_err(|e| err(&e.to_string()))?;
manager.add_connection(cfg).map_err(|e| {
logger.log(&LogEntry::new("/connections/add", "http").with_conn(&req.name).with_error(&e.to_string()));
err(&e.to_string())
})?;
logger.log(&LogEntry::new("/connections/add", "http").with_conn(&req.name).with_duration(0));
Ok(Json(serde_json::json!({"success": true, "message": format!("Connection '{}' added (temporary)", req.name)})))
}