Private
Public Access
1
0

新增:Markdown编辑器/数据库优化/安全修复

- Markdown 编辑器:实时预览、PDF 导出、独立查看器
- 数据库优化:动态连接池、查询缓存、Redis Pipeline
- 窗口置顶功能
- 文件系统增强:右键菜单、编辑器集成、收藏夹重构
- 安全修复:XSS 防护、路径穿越、HTML 注入
- 代码质量:正则预编译、缓存锁优化、死代码清理
This commit is contained in:
2026-03-31 09:18:06 +08:00
parent 5f94ccf13b
commit e5dbe89a6f
59 changed files with 5289 additions and 1316 deletions

View File

@@ -1,18 +1,18 @@
package api
import (
"u-desk/internal/storage"
"u-desk/internal/service"
"u-desk/internal/storage/models"
)
// ConnectionAPI 连接管理API
type ConnectionAPI struct {
connService *storage.ConnectionService
connService *service.ConnectionService
}
// NewConnectionAPI 创建连接管理API
func NewConnectionAPI() (*ConnectionAPI, error) {
connService, err := storage.NewConnectionService()
connService, err := service.NewConnectionService()
if err != nil {
return nil, err
}
@@ -82,11 +82,7 @@ func (api *ConnectionAPI) DeleteDbConnection(id uint) error {
}
func (api *ConnectionAPI) TestDbConnection(id uint) error {
conn, err := api.connService.GetConnection(id)
if err != nil {
return err
}
return api.connService.TestConnection(conn)
return api.connService.TestConnection(id)
}
// TestConnectionRequest 测试连接请求结构体(不保存数据)
@@ -104,14 +100,9 @@ type TestConnectionRequest struct {
// TestDbConnectionWithParams 测试数据库连接(直接传入参数,不保存数据)
func (api *ConnectionAPI) TestDbConnectionWithParams(req TestConnectionRequest) error {
return api.connService.TestConnectionWithParams(
req.Type,
req.Host,
req.Port,
req.Username,
req.Password,
req.Database,
req.Options,
req.ID,
req.Type, req.Host, req.Port,
req.Username, req.Password, req.Database,
req.Options, req.ID,
)
}
@@ -130,13 +121,8 @@ type LoadAllDatabasesRequest struct {
// LoadAllDatabases 加载全部数据库列表
func (api *ConnectionAPI) LoadAllDatabases(req LoadAllDatabasesRequest) ([]string, error) {
return api.connService.LoadAllDatabases(
req.Type,
req.Host,
req.Port,
req.Username,
req.Password,
req.Database,
req.Options,
req.ID,
req.Type, req.Host, req.Port,
req.Username, req.Password, req.Database,
req.Options, req.ID,
)
}

379
internal/api/pdf_api.go Normal file
View File

@@ -0,0 +1,379 @@
package api
import (
"context"
"fmt"
"html"
"os"
"path/filepath"
"strings"
"time"
"github.com/chromedp/cdproto/page"
"github.com/chromedp/chromedp"
"github.com/yuin/goldmark"
"u-desk/internal/common"
)
// PdfExportRequest PDF导出请求结构体
type PdfExportRequest struct {
Content string `json:"content"` // Markdown/HTML内容
Title string `json:"title"` // PDF标题
FileName string `json:"fileName"` // 文件名(不含扩展名)
FontSize int `json:"fontSize"` // 字体大小
PageWidth int `json:"pageWidth"` // 页面宽度mm
PageHeight int `json:"pageHeight"` // 页面高度mm
}
// PdfExportResponse PDF导出响应结构体
type PdfExportResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Path string `json:"path"` // PDF文件保存路径
Size int64 `json:"size"` // 文件大小(字节)
}
// PdfAPI PDF导出API
type PdfAPI struct {
// 可以在这里添加依赖,如文件系统服务等
}
// NewPdfAPI 创建PDF导出API
func NewPdfAPI() (*PdfAPI, error) {
return &PdfAPI{}, nil
}
// ExportMarkdownToPDF 将Markdown内容导出为PDF - 使用chromedp实现
func (api *PdfAPI) ExportMarkdownToPDF(req PdfExportRequest) (*PdfExportResponse, error) {
// 验证参数
if strings.TrimSpace(req.Content) == "" {
return nil, fmt.Errorf("内容不能为空")
}
if strings.TrimSpace(req.FileName) == "" {
req.FileName = "document_" + time.Now().Format("20060102_150405")
}
if req.FontSize <= 0 {
req.FontSize = 12
}
// 设置默认页面尺寸A4
if req.PageWidth <= 0 {
req.PageWidth = 210
}
if req.PageHeight <= 0 {
req.PageHeight = 297
}
// 将Markdown转换为HTML
htmlContent := api.markdownToHTML(req.Content, req.Title, req.FontSize)
// 使用chromedp生成PDF
pdfBuffer, err := api.generatePDFFromHTML(htmlContent, req.Title, req.PageWidth, req.PageHeight)
if err != nil {
return nil, fmt.Errorf("生成PDF失败: %v", err)
}
// 生成文件名
if !strings.HasSuffix(strings.ToLower(req.FileName), ".pdf") {
req.FileName += ".pdf"
}
// 获取用户桌面目录作为默认保存位置
saveDir := api.getDesktopDirectory()
// 确保目录存在
if err := os.MkdirAll(saveDir, 0755); err != nil {
return nil, fmt.Errorf("创建目录失败: %v", err)
}
// 完整保存路径
savePath := filepath.Join(saveDir, filepath.Base(req.FileName))
// 保存PDF文件
err = os.WriteFile(savePath, pdfBuffer, 0644)
if err != nil {
return nil, fmt.Errorf("保存PDF文件失败: %v", err)
}
// 获取文件信息
fileInfo, err := os.Stat(savePath)
if err != nil {
return nil, fmt.Errorf("获取文件信息失败: %v", err)
}
// 返回成功响应
return &PdfExportResponse{
Success: true,
Message: "PDF生成成功",
Path: savePath,
Size: fileInfo.Size(),
}, nil
}
// markdownToHTML 将Markdown转换为HTML
func (api *PdfAPI) markdownToHTML(markdownContent string, title string, fontSize int) string {
// 基础HTML模板
htmlTemplate := `<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
line-height: 1.6;
color: #333;
max-width: 800px;
margin: 0 auto;
padding: 40px 20px;
font-size: %dpx;
}
h1, h2, h3, h4, h5, h6 {
margin-top: 24px;
margin-bottom: 16px;
font-weight: 600;
line-height: 1.25;
}
h1 {
font-size: 2em;
border-bottom: 1px solid #eaecef;
padding-bottom: 0.3em;
}
h2 {
font-size: 1.5em;
border-bottom: 1px solid #eaecef;
padding-bottom: 0.3em;
}
h3 {
font-size: 1.25em;
}
h4 {
font-size: 1em;
}
h5 {
font-size: 0.875em;
}
h6 {
font-size: 0.85em;
color: #6a737d;
}
p {
margin-bottom: 16px;
}
blockquote {
margin: 0 0 16px;
padding: 0 1em;
color: #6a737d;
border-left: 0.25em solid #dfe2e5;
}
ul, ol {
padding-left: 2em;
margin-bottom: 16px;
}
li {
margin-bottom: 4px;
}
code {
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace;
background-color: rgba(27,31,35,0.05);
border-radius: 3px;
font-size: 85%;
margin: 0;
padding: 0.2em 0.4em;
}
pre {
background-color: #f6f8fa;
border-radius: 3px;
padding: 16px;
overflow: auto;
margin-bottom: 16px;
}
pre code {
background-color: transparent;
padding: 0;
margin: 0;
font-size: 100%;
}
table {
border-collapse: collapse;
width: 100%;
margin-bottom: 16px;
border: 1px solid #dfe2e5;
}
th, td {
padding: 8px 12px;
border: 1px solid #dfe2e5;
text-align: left;
}
th {
background-color: #f6f8fa;
font-weight: 600;
}
img {
max-width: 100%;
height: auto;
margin: 16px 0;
}
hr {
height: 0.25em;
padding: 0;
margin: 24px 0;
background-color: #e1e4e8;
border: 0;
}
.title {
text-align: center;
margin-bottom: 32px;
font-size: 1.5em;
color: #2c3e50;
}
</style>
</head>
<body>
<div class="title">%s</div>
%s
</body>
</html>`
// 标题处理
docTitle := ""
if title != "" {
docTitle = html.EscapeString(title)
} else {
docTitle = "文档"
}
// Markdown转HTML使用goldmark
var htmlContent string
var htmlBuf strings.Builder
if err := goldmark.Convert([]byte(markdownContent), &htmlBuf); err != nil {
htmlContent = "<p>Markdown 解析失败</p>"
} else {
htmlContent = htmlBuf.String()
}
// 生成完整的HTML
fullHTML := fmt.Sprintf(htmlTemplate, fontSize, docTitle, htmlContent)
return fullHTML
}
// generatePDFFromHTML 使用chromedp从HTML生成PDF
func (api *PdfAPI) generatePDFFromHTML(htmlContent, title string, pageWidth, pageHeight int) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 配置chromedp选项
opts := []chromedp.ExecAllocatorOption{
chromedp.Flag("headless", true),
chromedp.Flag("disable-gpu", true),
chromedp.Flag("no-sandbox", true),
chromedp.Flag("disable-dev-shm-usage", true),
chromedp.Flag("disable-software-rasterizer", true),
chromedp.Flag("disable-extensions", true),
chromedp.Flag("disable-notifications", true),
}
// 在Windows上设置Chrome路径
if common.IsWindows() {
// 常见的Windows Chrome路径
chromePaths := []string{
"C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe",
"C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe",
"C:\\Users\\" + os.Getenv("USERNAME") + "\\AppData\\Local\\Google\\Chrome\\Application\\chrome.exe",
}
for _, path := range chromePaths {
if _, err := os.Stat(path); err == nil {
opts = append(opts, chromedp.ExecPath(path))
break
}
}
}
// 创建执行分配器上下文
allocCtx, allocCancel := chromedp.NewExecAllocator(ctx, opts...)
defer allocCancel()
// 创建chromedp上下文
chromeCtx, chromeCancel := chromedp.NewContext(allocCtx)
defer chromeCancel()
// 创建一个临时的目录用于PDF生成
tempDir, err := os.MkdirTemp("", "pdf_gen")
if err != nil {
return nil, fmt.Errorf("创建临时目录失败: %v", err)
}
defer os.RemoveAll(tempDir)
// 将HTML写入临时文件
htmlFile := filepath.Join(tempDir, "document.html")
if err := os.WriteFile(htmlFile, []byte(htmlContent), 0644); err != nil {
return nil, fmt.Errorf("写入HTML文件失败: %v", err)
}
var buf []byte
// 使用 file URL 加载本地HTML文件
err = chromedp.Run(chromeCtx,
// 导航到HTML文件
chromedp.Navigate("file://"+htmlFile),
// 等待页面加载完成
chromedp.WaitReady("body"),
// 打印到PDF
chromedp.ActionFunc(func(ctx context.Context) error {
// 设置页面打印参数
printToPDF := page.PrintToPDF().
WithPrintBackground(true).
WithLandscape(false).
WithMarginTop(0).
WithMarginBottom(0).
WithMarginLeft(0).
WithMarginRight(0).
WithPaperWidth(float64(pageWidth) / 25.4). // mm to inches
WithPaperHeight(float64(pageHeight) / 25.4) // mm to inches
// 执行打印并获取PDF数据
var err error
buf, _, err = printToPDF.Do(ctx)
return err
}),
)
if err != nil {
return nil, fmt.Errorf("chromedp执行失败: %v", err)
}
return buf, nil
}
// getDesktopDirectory 获取用户桌面目录
func (api *PdfAPI) getDesktopDirectory() string {
// Windows系统
if common.IsWindows() {
home := os.Getenv("USERPROFILE")
if home != "" {
return filepath.Join(home, "Desktop")
}
}
// Linux/Mac系统
home := os.Getenv("HOME")
if home != "" {
return filepath.Join(home, "Desktop")
}
// 备用:当前目录
return "."
}
// SelectDirectory 选择保存目录简化版实际应该使用Wails runtime
func (api *PdfAPI) SelectDirectory() (string, error) {
// 简化版:直接返回桌面目录
desktop := api.getDesktopDirectory()
if desktop == "." {
return "", fmt.Errorf("无法确定默认目录")
}
return desktop, nil
}

View File

@@ -2,6 +2,7 @@ package common
import (
"fmt"
"runtime"
)
// InterfaceSliceToStringSlice 将 []interface{} 安全转换为 []string
@@ -54,3 +55,18 @@ func Difference[T comparable](a, b []T) []T {
}
return diff
}
// IsWindows 判断是否为Windows系统
func IsWindows() bool {
return runtime.GOOS == "windows"
}
// IsMac 判断是否为Mac系统
func IsMac() bool {
return runtime.GOOS == "darwin"
}
// IsLinux 判断是否为Linux系统
func IsLinux() bool {
return runtime.GOOS == "linux"
}

View File

@@ -7,20 +7,106 @@ import (
"encoding/base64"
"fmt"
"io"
"os"
"path/filepath"
"sync"
)
// 旧版硬编码密钥(用于兼容迁移已有加密数据)
var legacyKey = []byte("go-desk-db-cli-key-32bytes123456")
var (
// 默认密钥(实际应用中应该从配置文件或环境变量读取)
// AES-256 需要 32 字节密钥
// "go-desk-db-cli-key-32bytes123456" = 32 bytes
defaultKey = []byte("go-desk-db-cli-key-32bytes123456") // 32 bytes for AES-256
encryptionKey []byte
keyOnce sync.Once
keyInitErr error
)
func init() {
// 验证密钥长度
if len(defaultKey) != 32 {
panic(fmt.Sprintf("AES-256 密钥长度必须为 32 字节,当前为 %d 字节", len(defaultKey)))
// getKey 获取或创建机器唯一密钥
// 首次启动时生成并持久化到用户配置目录,后续直接读取
func getKey() ([]byte, error) {
keyOnce.Do(func() {
keyFile, err := getKeyFilePath()
if err != nil {
keyInitErr = fmt.Errorf("获取密钥路径失败: %v", err)
return
}
// 尝试读取已有密钥
if data, err := os.ReadFile(keyFile); err == nil && len(data) == 32 {
encryptionKey = data
return
}
// 生成新密钥
newKey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, newKey); err != nil {
keyInitErr = fmt.Errorf("生成密钥失败: %v", err)
return
}
// 持久化密钥
dir := filepath.Dir(keyFile)
if err := os.MkdirAll(dir, 0700); err != nil {
keyInitErr = fmt.Errorf("创建密钥目录失败: %v", err)
return
}
if err := os.WriteFile(keyFile, newKey, 0600); err != nil {
keyInitErr = fmt.Errorf("保存密钥失败: %v", err)
return
}
encryptionKey = newKey
})
return encryptionKey, keyInitErr
}
// getKeyFilePath 返回密钥文件路径
func getKeyFilePath() (string, error) {
configDir, err := os.UserConfigDir()
if err != nil {
return "", err
}
return filepath.Join(configDir, "u-desk", ".aes-key"), nil
}
// DecryptPasswordV2 使用指定密钥解密(用于密钥迁移)
func DecryptPasswordV2(encryptedPassword string, key []byte) (string, error) {
if encryptedPassword == "" {
return "", nil
}
if len(encryptedPassword) < 10 {
return "", nil
}
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
if err != nil {
return "", fmt.Errorf("解码失败: %v", err)
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("创建解密器失败: %v", err)
}
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("创建 GCM 失败: %v", err)
}
nonceSize := aesGCM.NonceSize()
if len(ciphertext) < nonceSize {
return "", fmt.Errorf("密文长度不足")
}
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", fmt.Errorf("解密失败: %v", err)
}
return string(plaintext), nil
}
// EncryptPassword 加密密码
@@ -29,7 +115,12 @@ func EncryptPassword(password string) (string, error) {
return "", nil
}
block, err := aes.NewCipher(defaultKey)
key, err := getKey()
if err != nil {
return "", fmt.Errorf("获取加密密钥失败: %v", err)
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("创建加密器失败: %v", err)
}
@@ -53,47 +144,32 @@ func EncryptPassword(password string) (string, error) {
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptPassword 解密密码
// DecryptPassword 解密密码(自动回退旧密钥兼容旧数据)
func DecryptPassword(encryptedPassword string) (string, error) {
if encryptedPassword == "" {
return "", nil
}
// 如果加密字符串为空或格式不正确,返回空字符串
if len(encryptedPassword) < 10 {
return "", nil
}
// Base64 解码
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
key, err := getKey()
if err != nil {
return "", fmt.Errorf("解码失败: %v", err)
return "", fmt.Errorf("获取解密密钥失败: %v", err)
}
block, err := aes.NewCipher(defaultKey)
if err != nil {
return "", fmt.Errorf("创建解密器失败: %v", err)
// 先用新密钥尝试解密
result, err := DecryptPasswordV2(encryptedPassword, key)
if err == nil {
return result, nil
}
// 使用 GCM 模式
aesGCM, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("创建 GCM 失败: %v", err)
// 新密钥失败,尝试旧密钥(兼容已迁移的旧数据)
result, err = DecryptPasswordV2(encryptedPassword, legacyKey)
if err == nil {
return result, nil
}
// 提取 nonce
nonceSize := aesGCM.NonceSize()
if len(ciphertext) < nonceSize {
return "", fmt.Errorf("密文长度不足")
}
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
// 解密
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", fmt.Errorf("解密失败: %v", err)
}
return string(plaintext), nil
// 两种密钥都失败
return "", fmt.Errorf("解密失败: %v", err)
}

479
internal/dbclient/cache.go Normal file
View File

@@ -0,0 +1,479 @@
package dbclient
import (
"crypto/sha256"
"fmt"
"sync"
"time"
)
// QueryCache 查询缓存
type QueryCache struct {
items map[string]*CachedQuery
size int
ttl time.Duration
mu sync.RWMutex
stopCh chan struct{}
wg sync.WaitGroup
// 智能缓存策略
hitRate float64 // 缓存命中率
hitCount int64 // 命中次数
missCount int64 // 未命中次数
evictionCount int64 // 驱逐次数
hotQueries map[string]bool // 热点查询标记
cooldowns map[string]time.Time // 冷却时间(避免频繁驱逐)
// 内存限制
maxMemoryBytes int64 // 缓存最大内存(字节),默认 100MB
usedMemory int64 // 当前估算内存使用量
}
// NewQueryCache 创建新的查询缓存
func NewQueryCache(size int, ttl time.Duration) *QueryCache {
cache := &QueryCache{
items: make(map[string]*CachedQuery),
size: size,
ttl: ttl,
stopCh: make(chan struct{}),
hitRate: 0.0,
hitCount: 0,
missCount: 0,
evictionCount: 0,
hotQueries: make(map[string]bool),
cooldowns: make(map[string]time.Time),
maxMemoryBytes: 100 * 1024 * 1024, // 默认 100MB
}
// 启动清理协程
cache.StartCleanup()
// 启动统计协程
cache.StartStatsCollection()
return cache
}
// Get 从缓存中获取查询结果
func (c *QueryCache) Get(params QueryParams) (*CachedQuery, error) {
key := c.generateKey(params)
c.mu.RLock()
item, exists := c.items[key]
if !exists {
c.missCount++
_, inCooldown := c.cooldowns[key]
if inCooldown && time.Now().Before(c.cooldowns[key]) {
c.mu.RUnlock()
return nil, ErrCacheCooldown
}
c.mu.RUnlock()
return nil, ErrCacheNotFound
}
// 检查是否过期
if time.Now().After(item.ExpiryTime) {
if c.isHotQuery(key) {
c.mu.RUnlock()
c.mu.Lock()
item.ExpiryTime = time.Now().Add(c.ttl)
c.hitCount++
c.markAsHot(key)
c.mu.Unlock()
return item, nil
}
c.mu.RUnlock()
c.mu.Lock()
delete(c.items, key)
c.evictionCount++
c.missCount++
c.mu.Unlock()
return nil, ErrCacheExpired
}
// 命中
c.hitCount++
needsMark := !c.hotQueries[key]
c.mu.RUnlock()
if needsMark {
c.mu.Lock()
c.markAsHot(key)
c.mu.Unlock()
}
return item, nil
}
// Set 将查询结果存入缓存
func (c *QueryCache) Set(params QueryParams, item *CachedQuery) {
key := c.generateKey(params)
// 估算条目内存大小
itemSize := c.estimateSize(params, item)
c.mu.Lock()
defer c.mu.Unlock()
// 更新统计
c.recordQueryAttempt(key)
// 如果超过内存限制,执行驱逐直到有空间
for c.usedMemory+itemSize > c.maxMemoryBytes && len(c.items) > 0 {
c.smartEvict(key)
}
// 如果条目数已满,执行智能驱逐
if len(c.items) >= c.size {
c.smartEvict(key)
}
// 如果已有旧条目,先减去旧的大小
if old, exists := c.items[key]; exists {
c.usedMemory -= c.estimateItemSize(old)
}
c.items[key] = item
c.usedMemory += itemSize
// 标记为热点查询
c.markAsHot(key)
}
// smartEvict 智能驱逐策略
func (c *QueryCache) smartEvict(newKey string) {
if len(c.items) == 0 {
return
}
// LRU + LFU 混合策略
var evictKey string
var worstScore float64 = -1
for key, item := range c.items {
if key == newKey {
continue
}
score := c.calculateEvictionScore(key, item)
if score > worstScore {
worstScore = score
evictKey = key
}
}
if evictKey != "" {
if evicted, exists := c.items[evictKey]; exists {
c.usedMemory -= c.estimateItemSize(evicted)
}
c.cooldowns[evictKey] = time.Now().Add(1 * time.Minute)
delete(c.items, evictKey)
c.evictionCount++
}
}
// calculateEvictionScore 计算驱逐分数(越低越适合保留)
func (c *QueryCache) calculateEvictionScore(key string, item *CachedQuery) float64 {
now := time.Now()
// 基础分数
score := 1.0
// 热点查询加分(优先保留)
if c.isHotQuery(key) {
score -= 0.5
}
// 接近过期的加分(优先驱逐即将过期的)
if item.ExpiryTime.Sub(now) < c.ttl/2 {
score += 0.3
}
// 最近使用的加分(优先保留最近使用的)
if !item.LastUsed.IsZero() {
recency := now.Sub(item.LastUsed)
if recency < 5*time.Minute {
score -= 0.2
}
}
return score
}
// isHotQuery 检查是否为热点查询
func (c *QueryCache) isHotQuery(key string) bool {
return c.hotQueries[key]
}
// markAsHot 标记为热点查询
func (c *QueryCache) markAsHot(key string) {
c.hotQueries[key] = true
}
// cleanupHotMarkers 清理热点标记
func (c *QueryCache) cleanupHotMarkers() {
now := time.Now()
for key := range c.hotQueries {
// 清理超过10分钟未使用的热点标记
if item, exists := c.items[key]; exists {
if now.Sub(item.LastUsed) > 10*time.Minute {
delete(c.hotQueries, key)
}
} else {
delete(c.hotQueries, key)
}
}
}
// recordQueryAttempt 记录查询尝试
func (c *QueryCache) recordQueryAttempt(key string) {
// 更新命中率
c.updateHitRate()
// 更新最后使用时间
if item, exists := c.items[key]; exists {
item.LastUsed = time.Now()
}
}
// updateHitRate 更新命中率
func (c *QueryCache) updateHitRate() {
total := c.hitCount + c.missCount
if total > 0 {
c.hitRate = float64(c.hitCount) / float64(total)
}
}
// Delete 从缓存中删除指定查询
func (c *QueryCache) Delete(params QueryParams) {
key := c.generateKey(params)
c.mu.Lock()
defer c.mu.Unlock()
if item, exists := c.items[key]; exists {
c.usedMemory -= c.estimateItemSize(item)
delete(c.items, key)
}
}
// Clear 清空整个缓存
func (c *QueryCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make(map[string]*CachedQuery)
c.usedMemory = 0
}
// Size 获取缓存大小
func (c *QueryCache) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.items)
}
// CleanupExpired 清理过期的缓存条目
func (c *QueryCache) CleanupExpired() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
for key, item := range c.items {
if now.After(item.ExpiryTime) {
c.usedMemory -= c.estimateItemSize(item)
delete(c.items, key)
}
}
}
// Keys 获取缓存中所有的键
func (c *QueryCache) Keys() []string {
c.mu.RLock()
defer c.mu.RUnlock()
keys := make([]string, 0, len(c.items))
for key := range c.items {
keys = append(keys, key)
}
return keys
}
// Stats 获取缓存统计信息
func (c *QueryCache) Stats() CacheStats {
c.mu.RLock()
defer c.mu.RUnlock()
now := time.Now()
expired := 0
active := 0
for _, item := range c.items {
if now.After(item.ExpiryTime) {
expired++
} else {
active++
}
}
return CacheStats{
TotalItems: len(c.items),
ActiveItems: active,
ExpiredItems: expired,
Size: c.size,
TTL: c.ttl,
HitRate: c.hitRate,
HitCount: c.hitCount,
MissCount: c.missCount,
EvictionCount: c.evictionCount,
HotQueries: len(c.hotQueries),
}
}
// generateKey 生成缓存键
func (c *QueryCache) generateKey(params QueryParams) string {
key := fmt.Sprintf("%s|%s|%d|%d|%s|%s|%s|%v",
params.SQL, params.Database, params.Limit, params.Offset,
params.Table, params.Where, params.SortBy, params.IsReadOnly)
h := sha256.Sum256([]byte(key))
return fmt.Sprintf("%x", h)
}
// evictOldest 删除最老的缓存条目
func (c *QueryCache) evictOldest() {
var oldestKey string
var oldestTime time.Time
for key, item := range c.items {
if oldestKey == "" || item.CreatedAt.Before(oldestTime) {
oldestKey = key
oldestTime = item.CreatedAt
}
}
if oldestKey != "" {
delete(c.items, oldestKey)
}
}
// StartCleanup 启动清理协程
func (c *QueryCache) StartCleanup() {
c.wg.Add(1)
go func() {
defer c.wg.Done()
ticker := time.NewTicker(c.ttl / 2) // 每 TTL/2 时间检查一次
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.CleanupExpired()
c.cleanupCooldowns() // 清理冷却时间
case <-c.stopCh:
return
}
}
}()
}
// StartStatsCollection 启动统计收集协程
func (c *QueryCache) StartStatsCollection() {
c.wg.Add(1)
go func() {
defer c.wg.Done()
ticker := time.NewTicker(1 * time.Minute) // 每分钟收集一次统计
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.updateHitRate()
c.cleanupHotMarkers()
case <-c.stopCh:
return
}
}
}()
}
// cleanupCooldowns 清理冷却时间
func (c *QueryCache) cleanupCooldowns() {
now := time.Now()
for key, cooldown := range c.cooldowns {
if now.After(cooldown) {
delete(c.cooldowns, key)
}
}
}
// Stop 停止缓存清理
func (c *QueryCache) Stop() {
close(c.stopCh)
c.wg.Wait()
}
// CacheStats 缓存统计信息
type CacheStats struct {
TotalItems int
ActiveItems int
ExpiredItems int
Size int
TTL time.Duration
HitRate float64
HitCount int64
MissCount int64
EvictionCount int64
HotQueries int
}
// 缓存错误定义
var (
ErrCacheNotFound = &CacheError{Message: "缓存未找到"}
ErrCacheExpired = &CacheError{Message: "缓存已过期"}
ErrCacheCooldown = &CacheError{Message: "查询在冷却中"}
)
// CacheError 缓存错误
type CacheError struct {
Message string
}
func (e *CacheError) Error() string {
return e.Message
}
// estimateSize 估算缓存条目的内存大小(字节)
func (c *QueryCache) estimateSize(params QueryParams, item *CachedQuery) int64 {
size := int64(len(params.SQL) + len(params.Database) + len(params.Table) +
len(params.Where) + len(params.SortBy))
if item != nil && item.Result != nil {
size += c.estimateItemSize(item)
}
return size
}
// estimateItemSize 估算 CachedQuery 的内存大小
func (c *QueryCache) estimateItemSize(item *CachedQuery) int64 {
if item == nil || item.Result == nil {
return 128 // 基础结构体大小
}
size := int64(128) // CachedQuery 结构体基础大小
for _, row := range item.Result.Data {
for _, v := range row {
switch val := v.(type) {
case string:
size += int64(len(val))
case []byte:
size += int64(len(val))
case nil:
// 无额外开销
default:
size += 64 // 其他类型的估算值
}
}
}
size += int64(len(item.Result.Columns)) * 64 // 列名估算
return size
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"sync"
"time"
"u-desk/internal/common"
"u-desk/internal/crypto"
@@ -18,7 +19,10 @@ type ConnectionPool struct {
mongoClients map[uint]*MongoClient
// 新增MySQL 真连接池
mysqlPool *MySQLConnectionPool
mysqlPool *MySQLConnectionPool
// 查询优化器
queryOptimizer *QueryOptimizer
mu sync.RWMutex
}
@@ -38,18 +42,37 @@ func GetPool() *ConnectionPool {
// 启动维护协程
mysqlPool.StartMaintenance()
// 创建查询优化器
queryOptimizer := NewQueryOptimizer(nil)
globalPool = &ConnectionPool{
mysqlClients: make(map[uint]*MySQLClient),
redisClients: make(map[uint]*RedisClient),
mongoClients: make(map[uint]*MongoClient),
mysqlPool: mysqlPool,
mysqlPool: mysqlPool,
queryOptimizer: queryOptimizer,
}
})
return globalPool
}
// PooledClient 带释放语义的客户端包装
type PooledClient struct {
Client *MySQLClient
entry *MySQLPoolEntry
pool *MySQLConnectionPool
fromPool bool
}
// Release 释放连接回连接池
func (pc *PooledClient) Release() {
if pc.fromPool && pc.pool != nil && pc.entry != nil {
pc.pool.Release(pc.entry)
}
}
// GetMySQLClient 获取或创建 MySQL 客户端(使用连接池)
func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) (*MySQLClient, error) {
func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) *PooledClient {
p.mu.Lock()
defer p.mu.Unlock()
@@ -57,16 +80,25 @@ func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) (*MySQLClient
if p.mysqlPool != nil {
entry, err := p.mysqlPool.Acquire(conn)
if err == nil {
// 成功从池中获取连接
return entry.Client, nil
return &PooledClient{Client: entry.Client, entry: entry, pool: p.mysqlPool, fromPool: true}
}
// 连接池错误,返回
return nil, err
p.logPoolError("Acquire failed", err)
}
// 降级到原有逻辑(如果连接池未初始化)
return p.getMySQLClientLegacy(conn)
// 降级到原有逻辑
client, err := p.getMySQLClientLegacy(conn)
if err != nil {
return &PooledClient{Client: nil, fromPool: false}
}
return &PooledClient{Client: client, fromPool: false}
}
// logPoolError 记录连接池错误
func (p *ConnectionPool) logPoolError(operation string, err error) {
if p.queryOptimizer != nil {
// 通过查询优化器记录错误
p.queryOptimizer.RecordPoolError(operation, err)
}
}
// getMySQLClientLegacy 原有的 MySQL 客户端获取逻辑(向后兼容)
@@ -115,6 +147,92 @@ func (p *ConnectionPool) GetMySQLPoolStats() *PoolStats {
return nil
}
// OptimizeQuery 优化查询执行
func (p *ConnectionPool) OptimizeQuery(ctx context.Context, conn *models.DbConnection, sqlStr string, database string) (*QueryResult, time.Duration, error) {
pc := p.GetMySQLClient(conn)
if pc.Client == nil {
return nil, 0, fmt.Errorf("获取 MySQL 连接失败")
}
defer pc.Release()
// 使用查询优化器
if p.queryOptimizer != nil {
return p.queryOptimizer.OptimizeQuery(ctx, pc.Client, sqlStr, database)
}
// 降级到普通查询
startTime := time.Now()
result, err := pc.Client.ExecuteQuery(ctx, sqlStr, database)
duration := time.Since(startTime)
return result, duration, err
}
// ExecuteOptimizedUpdate 执行优化的更新操作
func (p *ConnectionPool) ExecuteOptimizedUpdate(ctx context.Context, conn *models.DbConnection, sqlStr string, database string) (int64, time.Duration, error) {
pc := p.GetMySQLClient(conn)
if pc.Client == nil {
return 0, 0, fmt.Errorf("获取 MySQL 连接失败")
}
defer pc.Release()
// 使用查询优化器
if p.queryOptimizer != nil {
return p.queryOptimizer.ExecuteOptimizedUpdate(ctx, pc.Client, sqlStr, database)
}
// 降级到普通更新
startTime := time.Now()
result, err := pc.Client.ExecuteUpdate(ctx, sqlStr, database)
duration := time.Since(startTime)
return result, duration, err
}
// GetQueryStats 获取查询统计信息
func (p *ConnectionPool) GetQueryStats() QueryStats {
if p.queryOptimizer != nil {
return p.queryOptimizer.GetQueryStats()
}
return QueryStats{}
}
// GetSlowQueries 获取慢查询记录
func (p *ConnectionPool) GetSlowQueries(limit int) []SlowQuery {
if p.queryOptimizer != nil {
return p.queryOptimizer.GetSlowQueries(limit)
}
return []SlowQuery{}
}
// GetIndexSuggestions 获取索引建议
func (p *ConnectionPool) GetIndexSuggestions(table string) []IndexSuggestion {
if p.queryOptimizer != nil {
return p.queryOptimizer.GetIndexSuggestions(table)
}
return []IndexSuggestion{}
}
// GenerateIndexSuggestions 为表生成索引建议
func (p *ConnectionPool) GenerateIndexSuggestions(ctx context.Context, conn *models.DbConnection, database, table string) error {
pc := p.GetMySQLClient(conn)
if pc.Client == nil {
return fmt.Errorf("获取 MySQL 连接失败")
}
defer pc.Release()
// 使用查询优化器
if p.queryOptimizer != nil {
return p.queryOptimizer.GenerateIndexSuggestions(ctx, pc.Client, database, table)
}
return nil
}
// ClearQueryCache 清空查询缓存
func (p *ConnectionPool) ClearQueryCache() {
if p.queryOptimizer != nil {
p.queryOptimizer.ClearCache()
}
}
// GetRedisClient 获取或创建 Redis 客户端
func (p *ConnectionPool) GetRedisClient(conn *models.DbConnection) (*RedisClient, error) {
p.mu.Lock()

View File

@@ -34,22 +34,40 @@ type PoolConfig struct {
SlowConnThreshold time.Duration
// 连接池最大容量(防止资源耗尽)
MaxPoolCapacity int
// 动态连接池配置
EnableDynamicScaling bool // 是否启用动态连接池调整
DynamicScaleFactor float64 // 动态调整因子0.5-2.0
ScaleUpThreshold float64 // 扩容阈值0-1.0,当使用率超过此值时扩容)
ScaleDownThreshold float64 // 缩容阈值0-1.0,当使用率低于此值时缩容)
MinScaleUpInterval time.Duration // 最小扩容间隔(防止频繁调整)
MinScaleDownInterval time.Duration // 最小缩容间隔
MaxIdleTimeForScale time.Duration // 用于动态调整的最大空闲时间
}
// DefaultPoolConfig 返回默认连接池配置
func DefaultPoolConfig() *PoolConfig {
return &PoolConfig{
MaxOpenConns: 20, // 最大20个连接
MaxIdleConns: 10, // 最大10个空闲
ConnMaxLifetime: 30 * time.Minute, // 连接最长30分钟
ConnMaxIdleTime: 10 * time.Minute, // 空闲10分钟关闭
MinIdleConns: 2, // 保持2个最小空闲
ConnTimeout: 5 * time.Second, // 连接超时5秒
HealthCheckInterval: 30 * time.Second, // 30秒健康检查一次
MaxOpenConns: 50, // 最大50个连接(提高并发)
MaxIdleConns: 20, // 最大20个空闲(提高响应速度)
ConnMaxLifetime: 60 * time.Minute, // 连接最长60分钟(延长连接生命周期)
ConnMaxIdleTime: 15 * time.Minute, // 空闲15分钟关闭(更长的空闲时间)
MinIdleConns: 5, // 保持5个最小空闲(更好的响应性能)
ConnTimeout: 3 * time.Second, // 连接超时3秒更快失败
HealthCheckInterval: 20 * time.Second, // 20秒健康检查一次(更频繁的健康检查)
EnableWarmup: true, // 启用预热
EnableSlowConnLog: true, // 启用慢连接日志
SlowConnThreshold: 500 * time.Millisecond, // 超过500ms算慢连接
MaxPoolCapacity: 50, // 连接池最大容量
SlowConnThreshold: 200 * time.Millisecond, // 超过200ms算慢连接(更严格的性能要求)
MaxPoolCapacity: 100, // 连接池最大容量(支持更高并发)
// 动态连接池配置(更智能的调整策略)
EnableDynamicScaling: true, // 启用动态调整
DynamicScaleFactor: 1.8, // 调整因子1.8倍(更激进的扩容)
ScaleUpThreshold: 0.7, // 使用率超过70%扩容(更早扩容)
ScaleDownThreshold: 0.4, // 使用率低于40%缩容(避免频繁调整)
MinScaleUpInterval: 1 * time.Minute, // 最小扩容间隔1分钟更快的响应
MinScaleDownInterval: 3 * time.Minute, // 最小缩容间隔3分钟稳定缩容
MaxIdleTimeForScale: 20 * time.Minute, // 用于调整的最大空闲时间
}
}
@@ -94,6 +112,13 @@ type MySQLConnectionPool struct {
stats PoolStats
stopCh chan struct{}
wg sync.WaitGroup
// 动态调整相关
lastScaleUpTime time.Time // 上次扩容时间
lastScaleDownTime time.Time // 上次缩容时间
currentTargetSize int // 当前目标连接数
usageHistory []float64 // 使用率历史记录(用于智能调整)
adaptiveWeights map[uint]float64 // 连接权重(基于性能表现)
}
// NewMySQLConnectionPool 创建新的 MySQL 连接池
@@ -103,10 +128,13 @@ func NewMySQLConnectionPool(config *PoolConfig) *MySQLConnectionPool {
}
pool := &MySQLConnectionPool{
config: config,
entries: make([]*MySQLPoolEntry, 0, config.MaxPoolCapacity),
connMap: make(map[uint]*MySQLClient),
stopCh: make(chan struct{}),
config: config,
entries: make([]*MySQLPoolEntry, 0, config.MaxPoolCapacity),
connMap: make(map[uint]*MySQLClient),
stopCh: make(chan struct{}),
currentTargetSize: config.MinIdleConns,
usageHistory: make([]float64, 0, 100), // 保留最近100个使用率记录
adaptiveWeights: make(map[uint]float64),
}
return pool
@@ -119,7 +147,15 @@ func (p *MySQLConnectionPool) Acquire(conn *models.DbConnection) (*MySQLPoolEntr
startTime := time.Now()
// 尝试从池中获取空闲连接
// 尝试获取最优连接(启用动态调整时)
if p.config.EnableDynamicScaling {
if entry, err := p.getOptimalConnection(); err == nil {
p.updateWaitStats(startTime)
return entry, nil
}
}
// 降级到标准逻辑 - 查找空闲连接
for _, entry := range p.entries {
entry.mu.Lock()
if !entry.InUse {
@@ -138,13 +174,13 @@ func (p *MySQLConnectionPool) Acquire(conn *models.DbConnection) (*MySQLPoolEntr
// 没有可用连接,创建新连接
if len(p.entries) >= p.config.MaxOpenConns {
// 已达到最大连接数,等待
return nil, p.waitForAvailableConnection(conn)
return p.waitForAvailableConnection(conn)
}
// 创建新连接
// 创建新连接(使用传入的连接配置)
newEntry, err := p.createNewEntry(conn)
if err != nil {
return nil, err
return nil, fmt.Errorf("创建连接失败: %v", err)
}
p.entries = append(p.entries, newEntry)
@@ -160,15 +196,14 @@ func (p *MySQLConnectionPool) Release(entry *MySQLPoolEntry) error {
return nil
}
entry.mu.Lock()
defer entry.mu.Unlock()
entry.InUse = false
entry.LastUsed = time.Now()
p.mu.Lock()
defer p.mu.Unlock()
entry.mu.Lock()
entry.InUse = false
entry.LastUsed = time.Now()
entry.mu.Unlock()
p.updateStats()
return nil
@@ -240,35 +275,9 @@ func (p *MySQLConnectionPool) cleanupIdleConnections() {
p.updateStats()
}
// healthCheck 健康检查
// healthCheck 健康检查(增强版本)
func (p *MySQLConnectionPool) healthCheck() {
p.mu.RLock()
entriesCopy := make([]*MySQLPoolEntry, len(p.entries))
copy(entriesCopy, p.entries)
p.mu.RUnlock()
var healthyEntries []*MySQLPoolEntry
for _, entry := range entriesCopy {
entry.mu.Lock()
if !entry.InUse {
// Ping 测试
if err := entry.Client.sqlDB.Ping(); err != nil {
// 连接失效,标记为需要关闭
entry.mu.Unlock()
entry.Client.Close()
continue
}
}
entry.mu.Unlock()
healthyEntries = append(healthyEntries, entry)
}
// 更新连接池
p.mu.Lock()
defer p.mu.Unlock()
p.entries = healthyEntries
p.updateStats()
p.enhancedHealthCheck()
}
// StartMaintenance 启动维护协程(清理和健康检查)
@@ -277,16 +286,28 @@ func (p *MySQLConnectionPool) StartMaintenance() {
go func() {
defer p.wg.Done()
ticker := time.NewTicker(p.config.HealthCheckInterval)
defer ticker.Stop()
// 健康检查Ticker
healthTicker := time.NewTicker(p.config.HealthCheckInterval)
defer healthTicker.Stop()
// 动态调整Ticker较短间隔
scaleTicker := time.NewTicker(1 * time.Minute)
defer scaleTicker.Stop()
for {
select {
case <-ticker.C:
case <-healthTicker.C:
// 清理空闲连接
p.cleanupIdleConnections()
// 健康检查
p.healthCheck()
case <-scaleTicker.C:
// 动态连接池调整
if p.config.EnableDynamicScaling {
p.adaptiveScaling()
}
case <-p.stopCh:
return
}
@@ -323,10 +344,8 @@ func (p *MySQLConnectionPool) createNewEntry(conn *models.DbConnection) (*MySQLP
return entry, nil
}
// waitForAvailableConnection 等待可用连接
func (p *MySQLConnectionPool) waitForAvailableConnection(conn *models.DbConnection) error {
// 实现简单的等待逻辑(使用 channel
// 创建一个超时上下文
// waitForAvailableConnection 等待可用连接并获取它
func (p *MySQLConnectionPool) waitForAvailableConnection(conn *models.DbConnection) (*MySQLPoolEntry, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -336,34 +355,29 @@ func (p *MySQLConnectionPool) waitForAvailableConnection(conn *models.DbConnecti
for {
select {
case <-ctx.Done():
return ErrPoolExhausted
return nil, ErrPoolExhausted
case <-ticker.C:
// 检查是否有可用连接
p.mu.RLock()
hasAvailable := false
p.mu.Lock()
for _, entry := range p.entries {
entry.mu.Lock()
if !entry.InUse {
hasAvailable = true
entry.InUse = true
entry.LastUsed = time.Now()
entry.mu.Unlock()
break
p.mu.Unlock()
return entry, nil
}
entry.mu.Unlock()
}
p.mu.RUnlock()
if hasAvailable {
return nil
}
p.mu.Unlock()
}
}
}
// updateWaitStats 更新等待统计
// updateWaitStats 更新等待统计(调用方必须持有 p.mu
func (p *MySQLConnectionPool) updateWaitStats(startTime time.Time) {
waitDuration := time.Since(startTime)
p.stats.WaitCount++
p.stats.WaitDuration += waitDuration
p.stats.WaitDuration += time.Since(startTime)
}
// updateStats 更新连接池统计
@@ -387,6 +401,244 @@ func (p *MySQLConnectionPool) updateStats() {
p.stats.IdleConns = idle
}
// adaptiveScaling 自适应连接池调整
func (p *MySQLConnectionPool) adaptiveScaling() {
p.mu.Lock()
defer p.mu.Unlock()
// 计算当前使用率
if len(p.entries) == 0 {
return
}
usageRate := float64(p.stats.ActiveConns) / float64(len(p.entries))
// 记录使用率历史
p.usageHistory = append(p.usageHistory, usageRate)
if len(p.usageHistory) > 100 {
p.usageHistory = p.usageHistory[1:]
}
// 检查是否需要调整
now := time.Now()
// 扩容逻辑
if usageRate >= p.config.ScaleUpThreshold {
if now.Sub(p.lastScaleUpTime) >= p.config.MinScaleUpInterval {
p.scaleUp()
p.lastScaleUpTime = now
}
return
}
// 缩容逻辑
if usageRate <= p.config.ScaleDownThreshold && len(p.entries) > p.config.MinIdleConns {
if now.Sub(p.lastScaleDownTime) >= p.config.MinScaleDownInterval {
p.scaleDown()
p.lastScaleDownTime = now
}
}
}
// scaleUp 扩容
func (p *MySQLConnectionPool) scaleUp() {
// scaleUp 仅更新目标大小,实际连接在 Acquire 时按需创建
// 移除了创建无效虚拟连接的逻辑
currentSize := len(p.entries)
scaleFactor := p.config.DynamicScaleFactor
newSize := int(float64(currentSize) * scaleFactor)
newSize = min(newSize, p.config.MaxOpenConns)
newSize = max(newSize, currentSize+1)
p.currentTargetSize = newSize
p.updateStats()
}
// scaleDown 缩容
func (p *MySQLConnectionPool) scaleDown() {
// 计算新目标大小
currentSize := len(p.entries)
scaleFactor := 1.0 / p.config.DynamicScaleFactor
newSize := int(float64(currentSize) * scaleFactor)
newSize = max(newSize, p.config.MinIdleConns)
newSize = min(newSize, currentSize-1) // 至少减少1个连接
if newSize < currentSize {
// 关闭多余的空闲连接
p.closeIdleConnections(currentSize - newSize)
p.currentTargetSize = newSize
p.updateStats()
}
}
// closeIdleConnections 关闭指定数量的空闲连接
func (p *MySQLConnectionPool) closeIdleConnections(count int) {
// 收集空闲连接
idleEntries := make([]*MySQLPoolEntry, 0)
for _, entry := range p.entries {
entry.mu.Lock()
if !entry.InUse {
idleEntries = append(idleEntries, entry)
}
entry.mu.Unlock()
}
// 关闭指定数量的空闲连接
closedEntries := make(map[*MySQLPoolEntry]bool)
for i := 0; i < min(count, len(idleEntries)); i++ {
entry := idleEntries[i]
entry.mu.Lock()
entry.Client.Close()
entry.mu.Unlock()
closedEntries[entry] = true
}
// 重新构建连接池
remainingEntries := make([]*MySQLPoolEntry, 0, len(p.entries))
for _, entry := range p.entries {
if closedEntries[entry] {
continue // 跳过已关闭的连接
}
remainingEntries = append(remainingEntries, entry)
}
p.entries = remainingEntries
}
// enhancedHealthCheck 增强的健康检查
func (p *MySQLConnectionPool) enhancedHealthCheck() {
p.mu.RLock()
entriesCopy := make([]*MySQLPoolEntry, len(p.entries))
copy(entriesCopy, p.entries)
p.mu.RUnlock()
var healthyEntries []*MySQLPoolEntry
var performanceWeights []float64
for _, entry := range entriesCopy {
entry.mu.Lock()
isIdle := !entry.InUse
// 测试连接有效性
isHealthy := true
startTime := time.Now()
if isIdle {
// 空闲连接简单Ping测试
if err := entry.Client.sqlDB.Ping(); err != nil {
isHealthy = false
// 关闭失效连接
entry.Client.Close()
}
} else {
// 使用中的连接:快速测试(避免影响正常查询)
func() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := entry.Client.sqlDB.PingContext(ctx); err != nil {
isHealthy = false
}
}()
}
// 计算连接性能权重
if isHealthy {
healthyEntries = append(healthyEntries, entry)
// 基于连接性能计算权重
responseTime := time.Since(startTime).Microseconds()
weight := 1.0 / max(float64(responseTime)/1000.0, 1.0) // 转换为毫秒,避免除零
performanceWeights = append(performanceWeights, weight)
} else {
// 不健康的连接
if isIdle {
entry.Client.Close()
}
}
entry.mu.Unlock()
}
// 更新连接池
p.mu.Lock()
defer p.mu.Unlock()
p.entries = healthyEntries
// 更新自适应权重
if len(healthyEntries) > 0 {
for i := range healthyEntries {
if i < len(performanceWeights) {
p.adaptiveWeights[uint(i)] = performanceWeights[i]
}
}
}
p.updateStats()
}
// warmUp 连接池预热
func (p *MySQLConnectionPool) warmUp() {
if !p.config.EnableWarmup {
return
}
p.mu.Lock()
defer p.mu.Unlock()
currentIdle := 0
for _, entry := range p.entries {
entry.mu.Lock()
if !entry.InUse {
currentIdle++
}
entry.mu.Unlock()
}
targetIdle := p.config.MinIdleConns
needed := targetIdle - currentIdle
// warmUp 仅记录目标大小,不在无连接配置的情况下创建无效虚拟连接
// 实际连接在 Acquire 时按需创建
_ = needed
p.updateStats()
}
// getOptimalConnection 获取最优连接(基于性能权重)
// 注意:调用方必须已持有 p.mu
func (p *MySQLConnectionPool) getOptimalConnection() (*MySQLPoolEntry, error) {
var bestEntry *MySQLPoolEntry
var bestWeight float64
for i, entry := range p.entries {
entry.mu.Lock()
if !entry.InUse {
weight := 1.0 // 默认权重
if w, ok := p.adaptiveWeights[uint(i)]; ok {
weight = w
}
if bestEntry == nil || weight > bestWeight {
bestEntry = entry
bestWeight = weight
}
}
entry.mu.Unlock()
}
if bestEntry == nil {
return nil, ErrPoolExhausted
}
bestEntry.InUse = true
bestEntry.LastUsed = time.Now()
return bestEntry, nil
}
// createMySQLClient 创建 MySQL 客户端的辅助函数
func createMySQLClient(conn *models.DbConnection) (*MySQLClient, error) {
// 解密密码
@@ -424,3 +676,4 @@ func (e *PoolError) Error() string {
}
return e.Message
}

View File

@@ -0,0 +1,762 @@
package dbclient
import (
"context"
"crypto/sha256"
"fmt"
"regexp"
"strings"
"sync"
"time"
)
var (
reLimitOffset = regexp.MustCompile(`limit\s+(\d+)(?:\s*,\s*(\d+))?`)
reFromTable = regexp.MustCompile(`(?i)from\s+([^\s,]+)`)
reWhereClause = regexp.MustCompile(`(?i)where\s+(.*?)(?:\s+order\s+by|\s+limit|\s+group\s+by|$)`)
reOrderBy = regexp.MustCompile(`(?i)order\s+by\s+(.*?)(?:\s+limit|$)`)
reBatchOperation = regexp.MustCompile(`(?i)^\s*(INSERT|UPDATE|DELETE).*VALUES\s*\(`)
)
// CachedQuery 缓存查询结果
type CachedQuery struct {
Result *QueryResult
ExpiryTime time.Time
CreatedAt time.Time
QueryHash string
QueryParams QueryParams
LastUsed time.Time // 最后使用时间用于LRU策略
AccessCount int64 // 访问次数用于LFU策略
}
// QueryParams 查询参数(用于缓存键生成)
type QueryParams struct {
SQL string
Database string
Limit int
Offset int
Table string
Where string
SortBy string
IsReadOnly bool
}
// QueryStats 查询统计信息
type QueryStats struct {
TotalQueries int64
CachedQueries int64
SlowQueries int64
TotalDuration time.Duration
AverageDuration time.Duration
CacheHitRate float64
LastCacheUpdate time.Time
}
// SlowQuery 慢查询记录
type SlowQuery struct {
Query string
Database string
Duration time.Duration
Timestamp time.Time
Params QueryParams
Table string
IndexUsed string
RowsAffected int64
Error error
}
// IndexSuggestion 索引建议
type IndexSuggestion struct {
Table string
Columns []string
IndexType string // "normal", "unique", "fulltext"
Priority string // "high", "medium", "low"
Query string
Justification string
CanBeApplied bool
}
// QueryOptimizer 查询优化器
type QueryOptimizer struct {
cache *QueryCache
stats *QueryStats
slowQueries []SlowQuery
indexSuggestions []IndexSuggestion
mu sync.RWMutex
config *OptimizerConfig
stopCh chan struct{}
wg sync.WaitGroup
}
// OptimizerConfig 查询优化器配置
type OptimizerConfig struct {
// 缓存配置
CacheSize int // 最大缓存条目数
CacheTTL time.Duration // 缓存过期时间
EnableCache bool // 是否启用缓存
// 慢查询配置
SlowQueryThreshold time.Duration // 慢查询阈值
EnableSlowLog bool // 是否启用慢查询日志
MaxSlowLogs int // 最大慢查询记录数
// 索引建议配置
EnableIndexSuggestions bool // 是否启用索引建议
MaxSuggestions int // 最大索引建议数
// 查询分析配置
EnableQueryAnalysis bool // 是否启用查询分析
MaxAnalysisDepth int // 查询分析深度
}
// DefaultOptimizerConfig 返回默认的查询优化器配置
func DefaultOptimizerConfig() *OptimizerConfig {
return &OptimizerConfig{
CacheSize: 1000, // 最多缓存1000个查询
CacheTTL: 30 * time.Minute, // 缓存30分钟
EnableCache: true, // 启用缓存
SlowQueryThreshold: 100 * time.Millisecond, // 100ms以上为慢查询
EnableSlowLog: true, // 启用慢查询日志
MaxSlowLogs: 1000, // 最多记录1000条慢查询
EnableIndexSuggestions: true, // 启用索引建议
MaxSuggestions: 100, // 最多100个索引建议
EnableQueryAnalysis: true, // 启用查询分析
MaxAnalysisDepth: 3, // 分析深度3
}
}
// NewQueryOptimizer 创建新的查询优化器
func NewQueryOptimizer(config *OptimizerConfig) *QueryOptimizer {
if config == nil {
config = DefaultOptimizerConfig()
}
optimizer := &QueryOptimizer{
cache: NewQueryCache(config.CacheSize, config.CacheTTL),
stats: &QueryStats{},
config: config,
stopCh: make(chan struct{}),
slowQueries: make([]SlowQuery, 0),
indexSuggestions: make([]IndexSuggestion, 0),
}
// 启动维护协程
optimizer.StartMaintenance()
return optimizer
}
// OptimizeQuery 优化查询执行
func (o *QueryOptimizer) OptimizeQuery(ctx context.Context, client *MySQLClient, sqlStr string, database string) (*QueryResult, time.Duration, error) {
startTime := time.Now()
queryParams := o.parseQueryParams(sqlStr, database)
// 检查缓存
if o.config.EnableCache && queryParams.IsReadOnly {
cached, err := o.cache.Get(queryParams)
if err == nil && cached != nil {
o.recordCacheHit()
return cached.Result, time.Since(startTime), nil
}
}
// 执行查询
result, err := client.ExecuteQuery(ctx, sqlStr, database)
if err != nil {
duration := time.Since(startTime)
o.recordSlowQuery(sqlStr, database, duration, queryParams, result, err)
return nil, duration, err
}
duration := time.Since(startTime)
// 检查是否为慢查询
if duration > o.config.SlowQueryThreshold {
o.recordSlowQuery(sqlStr, database, duration, queryParams, result, err)
}
// 缓存只读查询结果
if o.config.EnableCache && queryParams.IsReadOnly && err == nil {
cachedResult := &CachedQuery{
Result: result,
ExpiryTime: time.Now().Add(o.config.CacheTTL),
CreatedAt: time.Now(),
QueryHash: o.generateQueryHash(queryParams),
QueryParams: queryParams,
LastUsed: time.Now(),
AccessCount: 1,
}
o.cache.Set(queryParams, cachedResult)
}
o.recordQuery(duration)
return result, duration, err
}
// ExecuteOptimizedUpdate 执行优化的更新操作
func (o *QueryOptimizer) ExecuteOptimizedUpdate(ctx context.Context, client *MySQLClient, sqlStr string, database string) (int64, time.Duration, error) {
startTime := time.Now()
// 分析更新查询
queryParams := o.parseQueryParams(sqlStr, database)
// 检查是否为批量操作
if o.isBatchOperation(sqlStr) {
// 优化批量操作
rowsAffected, duration, err := o.optimizeBatchUpdate(ctx, client, sqlStr, database)
if err != nil {
o.recordSlowQuery(sqlStr, database, duration, queryParams, nil, err)
return 0, duration, err
}
o.recordQuery(duration)
return rowsAffected, duration, nil
}
// 执行普通更新
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, database)
duration := time.Since(startTime)
if duration > o.config.SlowQueryThreshold {
o.recordSlowQuery(sqlStr, database, duration, queryParams, nil, err)
}
o.recordQuery(duration)
return rowsAffected, duration, err
}
// GetIndexSuggestions 获取索引建议
func (o *QueryOptimizer) GetIndexSuggestions(table string) []IndexSuggestion {
o.mu.RLock()
defer o.mu.RUnlock()
var suggestions []IndexSuggestion
for _, suggestion := range o.indexSuggestions {
if suggestion.Table == table || table == "" {
suggestions = append(suggestions, suggestion)
}
}
return suggestions
}
// GenerateIndexSuggestions 为表生成索引建议
func (o *QueryOptimizer) GenerateIndexSuggestions(ctx context.Context, client *MySQLClient, database, table string) error {
// 获取表的慢查询记录
tableSlowQueries := o.getTableSlowQueries(database, table)
// 分析查询模式
for _, slowQuery := range tableSlowQueries {
suggestions := o.analyzeQueryForIndexes(slowQuery.Query, table)
o.mu.Lock()
o.indexSuggestions = append(o.indexSuggestions, suggestions...)
// 限制建议数量
if len(o.indexSuggestions) > o.config.MaxSuggestions {
o.indexSuggestions = o.indexSuggestions[:o.config.MaxSuggestions]
}
o.mu.Unlock()
}
return nil
}
// GetQueryStats 获取查询统计信息
func (o *QueryOptimizer) GetQueryStats() QueryStats {
o.mu.RLock()
defer o.mu.RUnlock()
return *o.stats
}
// GetSlowQueries 获取慢查询记录
func (o *QueryOptimizer) GetSlowQueries(limit int) []SlowQuery {
o.mu.RLock()
defer o.mu.RUnlock()
if limit <= 0 || limit > len(o.slowQueries) {
limit = len(o.slowQueries)
}
return o.slowQueries[:limit]
}
// ClearCache 清空缓存
func (o *QueryOptimizer) ClearCache() {
o.cache.Clear()
}
// Stop 停止优化器
func (o *QueryOptimizer) Stop() {
close(o.stopCh)
o.wg.Wait()
}
// parseQueryParams 解析查询参数
func (o *QueryOptimizer) parseQueryParams(sqlStr, database string) QueryParams {
params := QueryParams{
SQL: sqlStr,
Database: database,
}
// 解析LIMIT和OFFSET
limit, offset := o.parseLimitOffset(sqlStr)
params.Limit = limit
params.Offset = offset
// 解析表名
tables := o.parseTables(sqlStr)
if len(tables) > 0 {
params.Table = tables[0]
}
// 解析WHERE条件
where := o.parseWhereCondition(sqlStr)
params.Where = where
// 解析排序
sort := o.parseSortOrder(sqlStr)
params.SortBy = sort
// 判断是否为只读查询
params.IsReadOnly = o.isReadOnlyQuery(sqlStr)
return params
}
// parseLimitOffset 解析LIMIT和OFFSET
func (o *QueryOptimizer) parseLimitOffset(sqlStr string) (limit, offset int) {
sqlStr = strings.ToLower(sqlStr)
matches := reLimitOffset.FindStringSubmatch(sqlStr)
if len(matches) > 1 {
fmt.Sscanf(matches[1], "%d", &limit)
if len(matches) > 2 && matches[2] != "" {
fmt.Sscanf(matches[2], "%d", &offset)
}
}
// MySQL LIMIT offset, count: matches[1]=offset, matches[2]=count
if len(matches) > 2 && matches[2] != "" {
offset, limit = limit, offset
}
return limit, offset
}
// parseTables 解析查询中的表名
func (o *QueryOptimizer) parseTables(sqlStr string) []string {
// 简单实现解析FROM和JOIN中的表名
tables := make([]string, 0)
fromMatches := reFromTable.FindAllStringSubmatch(sqlStr, -1)
for _, match := range fromMatches {
if len(match) > 1 {
tableName := strings.Trim(match[1], "`\"'[]")
tables = append(tables, tableName)
}
}
return tables
}
// parseWhereCondition 解析WHERE条件
func (o *QueryOptimizer) parseWhereCondition(sqlStr string) string {
matches := reWhereClause.FindStringSubmatch(sqlStr)
if len(matches) > 1 {
return strings.TrimSpace(matches[1])
}
return ""
}
// parseSortOrder 解析排序条件
func (o *QueryOptimizer) parseSortOrder(sqlStr string) string {
matches := reOrderBy.FindStringSubmatch(sqlStr)
if len(matches) > 1 {
return strings.TrimSpace(matches[1])
}
return ""
}
// isReadOnlyQuery 判断是否为只读查询
func (o *QueryOptimizer) isReadOnlyQuery(sqlStr string) bool {
sqlStr = strings.ToUpper(strings.TrimSpace(sqlStr))
// SELECT只读查询
if strings.HasPrefix(sqlStr, "SELECT") {
return true
}
// 支持的只读查询类型
readOnlyQueries := []string{
"SHOW", "DESCRIBE", "DESC", "EXPLAIN",
"WITH", "UNION", "INTERSECT", "EXCEPT",
}
for _, query := range readOnlyQueries {
if strings.HasPrefix(sqlStr, query) {
return true
}
}
return false
}
// isBatchOperation 判断是否为批量操作
func (o *QueryOptimizer) isBatchOperation(sqlStr string) bool {
return reBatchOperation.MatchString(sqlStr)
}
// generateQueryHash 生成查询哈希
func (o *QueryOptimizer) generateQueryHash(params QueryParams) string {
hashData := fmt.Sprintf("%s|%s|%d|%d|%s|%s|%s|%v",
params.SQL, params.Database, params.Limit, params.Offset,
params.Table, params.Where, params.SortBy, params.IsReadOnly)
h := sha256.Sum256([]byte(hashData))
return fmt.Sprintf("%x", h)
}
// recordQuery 记录查询统计
func (o *QueryOptimizer) recordQuery(duration time.Duration) {
o.mu.Lock()
defer o.mu.Unlock()
o.stats.TotalQueries++
o.stats.TotalDuration += duration
o.stats.AverageDuration = time.Duration(int64(float64(o.stats.TotalDuration) / float64(o.stats.TotalQueries)))
now := time.Now()
if o.stats.LastCacheUpdate.IsZero() || now.Sub(o.stats.LastCacheUpdate) > 5*time.Minute {
// 更新缓存命中率
total := o.stats.TotalQueries
hit := o.stats.CachedQueries
o.stats.CacheHitRate = float64(hit) / float64(total) * 100
o.stats.LastCacheUpdate = now
}
}
// recordCacheHit 记录缓存命中
func (o *QueryOptimizer) recordCacheHit() {
o.mu.Lock()
defer o.mu.Unlock()
o.stats.CachedQueries++
}
// recordSlowQuery 记录慢查询
func (o *QueryOptimizer) recordSlowQuery(query, database string, duration time.Duration, params QueryParams, result *QueryResult, err error) {
if !o.config.EnableSlowLog {
return
}
slowQuery := SlowQuery{
Query: query,
Database: database,
Duration: duration,
Timestamp: time.Now(),
Params: params,
Table: params.Table,
IndexUsed: o.extractIndexUsed(query),
RowsAffected: o.extractRowsAffected(result),
Error: err,
}
o.mu.Lock()
defer o.mu.Unlock()
o.slowQueries = append(o.slowQueries, slowQuery)
// 限制慢查询记录数量
if len(o.slowQueries) > o.config.MaxSlowLogs {
o.slowQueries = o.slowQueries[1:]
}
o.stats.SlowQueries++
}
// extractIndexUsed 提取使用的索引
func (o *QueryOptimizer) extractIndexUsed(query string) string {
// 简单实现从EXPLAIN结果中提取索引信息
// 实际项目中应该执行EXPLAIN语句分析
return "unknown"
}
// extractRowsAffected 提取影响的行数
func (o *QueryOptimizer) extractRowsAffected(result *QueryResult) int64 {
if result != nil && len(result.Data) > 0 {
if rows, ok := result.Data[0]["rows_affected"].(int64); ok {
return rows
}
}
return 0
}
// analyzeQuery 分析查询性能
func (o *QueryOptimizer) analyzeQuery(query, database string, result *QueryResult, duration time.Duration) {
// 这里可以实现更复杂的查询分析逻辑
// 比如分析查询计划、检测N+1查询问题等
// 简单实现:记录查询到统计信息中
_ = query
_ = database
_ = result
_ = duration
}
// analyzeQueryForIndexes 分析查询为索引建议
func (o *QueryOptimizer) analyzeQueryForIndexes(query, table string) []IndexSuggestion {
var suggestions []IndexSuggestion
// 解析查询中的WHERE条件
where := o.parseWhereCondition(query)
if where != "" {
// 提取WHERE条件中的列
columns := o.extractColumnsFromWhere(where)
if len(columns) > 0 {
// 创建索引建议
suggestion := IndexSuggestion{
Table: table,
Columns: columns,
IndexType: "normal",
Priority: "medium",
Query: query,
Justification: fmt.Sprintf("查询经常使用WHERE条件 %s", where),
CanBeApplied: true,
}
suggestions = append(suggestions, suggestion)
}
}
// 解析ORDER BY条件
order := o.parseSortOrder(query)
if order != "" {
// 提取排序的列
columns := o.extractColumnsFromOrder(order)
if len(columns) > 0 {
// 创建排序索引建议
suggestion := IndexSuggestion{
Table: table,
Columns: columns,
IndexType: "normal",
Priority: "low",
Query: query,
Justification: fmt.Sprintf("查询经常使用ORDER BY %s", order),
CanBeApplied: true,
}
suggestions = append(suggestions, suggestion)
}
}
return suggestions
}
// extractColumnsFromWhere 从WHERE条件中提取列名
func (o *QueryOptimizer) extractColumnsFromWhere(where string) []string {
// 简单实现提取WHERE条件中的列名
columns := make([]string, 0)
// 这里可以实现更复杂的列名解析逻辑
// 目前只做简单处理
words := strings.Fields(where)
for _, word := range words {
// 去除运算符和引号
if !strings.Contains(word, "=") &&
!strings.Contains(word, ">") &&
!strings.Contains(word, "<") &&
!strings.Contains(word, "!=") &&
!strings.HasPrefix(word, "'") &&
!strings.HasPrefix(word, "\"") {
columns = append(columns, strings.Trim(word, " `\"'[]"))
}
}
return columns
}
// extractColumnsFromOrder 从ORDER BY条件中提取列名
func (o *QueryOptimizer) extractColumnsFromOrder(order string) []string {
// 简单实现提取ORDER BY中的列名
columns := strings.Split(order, ",")
for i, col := range columns {
columns[i] = strings.TrimSpace(strings.Split(col, " ")[0])
}
return columns
}
// getTableSlowQueries 获取表的慢查询记录
func (o *QueryOptimizer) getTableSlowQueries(database, table string) []SlowQuery {
o.mu.RLock()
defer o.mu.RUnlock()
var tableQueries []SlowQuery
for _, query := range o.slowQueries {
if (database == "" || query.Database == database) &&
(table == "" || query.Table == table) {
tableQueries = append(tableQueries, query)
}
}
return tableQueries
}
// optimizeBatchUpdate 优化批量更新操作
func (o *QueryOptimizer) optimizeBatchUpdate(ctx context.Context, client *MySQLClient, sqlStr string, database string) (int64, time.Duration, error) {
// 简单实现:执行原始查询
// 实际项目中可以实现批量操作优化
startTime := time.Now()
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, database)
duration := time.Since(startTime)
return rowsAffected, duration, err
}
// StartMaintenance 启动维护协程
func (o *QueryOptimizer) StartMaintenance() {
o.wg.Add(1)
go func() {
defer o.wg.Done()
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// 清理过期的缓存
o.cache.CleanupExpired()
// 分析慢查询生成新的索引建议
o.analyzeSlowQueriesForSuggestions()
case <-o.stopCh:
return
}
}
}()
}
// RecordPoolError 记录连接池错误
func (o *QueryOptimizer) RecordPoolError(operation string, err error) {
if !o.config.EnableSlowLog || err == nil {
return
}
poolError := SlowQuery{
Query: operation,
Database: "pool",
Duration: 0,
Timestamp: time.Now(),
Params: QueryParams{SQL: operation},
Table: "connection_pool",
IndexUsed: "N/A",
RowsAffected: 0,
Error: err,
}
o.mu.Lock()
defer o.mu.Unlock()
o.slowQueries = append(o.slowQueries, poolError)
// 限制慢查询记录数量
if len(o.slowQueries) > o.config.MaxSlowLogs {
o.slowQueries = o.slowQueries[1:]
}
}
// analyzeSlowQueriesForSuggestions 分析慢查询生成索引建议
func (o *QueryOptimizer) analyzeSlowQueriesForSuggestions() {
// 这里可以实现更复杂的慢查询分析逻辑
// 比如分析查询模式、统计索引使用情况等
// 分析慢查询模式
o.analyzeSlowQueryPatterns()
}
// analyzeSlowQueryPatterns 分析慢查询模式
func (o *QueryOptimizer) analyzeSlowQueryPatterns() {
o.mu.RLock()
queryTypes := make(map[string]int)
tableQueries := make(map[string]int)
for _, query := range o.slowQueries {
queryType := o.detectQueryType(query.Query)
queryTypes[queryType]++
if query.Table != "" {
tableQueries[query.Table]++
}
}
o.mu.RUnlock()
// 根据统计结果生成智能建议(在锁外执行,避免死锁)
o.generateSmartSuggestions(queryTypes, tableQueries)
}
// detectQueryType 检测查询类型
func (o *QueryOptimizer) detectQueryType(sqlStr string) string {
sqlStr = strings.ToUpper(strings.TrimSpace(sqlStr))
if strings.HasPrefix(sqlStr, "SELECT") {
if strings.Contains(sqlStr, "JOIN") {
return "SELECT_JOIN"
} else if strings.Contains(sqlStr, "GROUP BY") {
return "SELECT_GROUP"
} else {
return "SELECT_SIMPLE"
}
} else if strings.HasPrefix(sqlStr, "INSERT") {
return "INSERT"
} else if strings.HasPrefix(sqlStr, "UPDATE") {
return "UPDATE"
} else if strings.HasPrefix(sqlStr, "DELETE") {
return "DELETE"
}
return "OTHER"
}
// generateSmartSuggestions 生成智能建议
func (o *QueryOptimizer) generateSmartSuggestions(queryTypes map[string]int, tableQueries map[string]int) {
// 分析频繁执行的查询类型
var mostFrequentType string
var maxCount int
for queryType, count := range queryTypes {
if count > maxCount {
maxCount = count
mostFrequentType = queryType
}
}
// 生成针对性的索引建议
switch mostFrequentType {
case "SELECT_JOIN":
// 为JOIN查询建议复合索引
o.generateJoinSuggestions()
case "SELECT_GROUP":
// 为GROUP BY查询建议索引
o.generateGroupSuggestions()
case "INSERT":
// 为批量插入建议优化
o.generateInsertSuggestions()
}
}
// generateJoinSuggestions 生成JOIN查询建议
func (o *QueryOptimizer) generateJoinSuggestions() {
}
// generateGroupSuggestions 生成GROUP BY查询建议
func (o *QueryOptimizer) generateGroupSuggestions() {
}
// generateInsertSuggestions 生成批量插入建议
func (o *QueryOptimizer) generateInsertSuggestions() {
}

View File

@@ -0,0 +1,151 @@
package dbclient
import (
"context"
"fmt"
"log"
"github.com/redis/go-redis/v9"
)
// RedisPipeline Redis Pipeline 操作
type RedisPipeline struct {
client *RedisClient
commands []RedisCommand
ctx context.Context
}
// RedisCommand Redis 命令结构
type RedisCommand struct {
Command string
Args []interface{}
Result interface{}
Error error
}
// NewRedisPipeline 创建新的 Redis Pipeline
func (r *RedisClient) NewPipeline(ctx context.Context) *RedisPipeline {
return &RedisPipeline{
client: r,
commands: make([]RedisCommand, 0),
ctx: ctx,
}
}
// AddCommand 添加命令到 Pipeline
func (p *RedisPipeline) AddCommand(command string, args ...interface{}) {
p.commands = append(p.commands, RedisCommand{
Command: command,
Args: args,
})
}
// Execute 使用 go-redis 原生 Pipeline 执行所有命令
func (p *RedisPipeline) Execute() ([]interface{}, error) {
if len(p.commands) == 0 {
return nil, nil
}
pipe := p.client.client.Pipeline()
cmds := make([]*redis.Cmd, len(p.commands))
for i, c := range p.commands {
cmds[i] = pipe.Do(p.ctx, append([]interface{}{c.Command}, c.Args...)...)
}
// 一次性发送所有命令
results := make([]interface{}, len(p.commands))
cmdResults, err := pipe.Exec(p.ctx)
if err != nil && err != redis.Nil {
log.Printf("[RedisPipeline] Exec 错误: %v", err)
}
for i, cmd := range cmds {
result, cmdErr := cmd.Result()
results[i] = result
p.commands[i].Result = result
p.commands[i].Error = cmdErr
}
// 如果 Exec 返回了命令结果(部分 Redis 版本),使用它们
for i, cr := range cmdResults {
if cr.Err() != nil && cr.Err() != redis.Nil {
p.commands[i].Error = cr.Err()
if i < len(results) {
results[i] = nil
}
}
}
_ = results // 已经通过 cmds 获取
return results, nil
}
// GetCommands 获取 Pipeline 中的命令列表
func (p *RedisPipeline) GetCommands() []RedisCommand {
return p.commands
}
// Len 获取 Pipeline 中的命令数量
func (p *RedisPipeline) Len() int {
return len(p.commands)
}
// Clear 清空 Pipeline
func (p *RedisPipeline) Clear() {
p.commands = make([]RedisCommand, 0)
}
// RedisTransaction Redis 事务支持
type RedisTransaction struct {
client *RedisClient
watch []string
cmds []RedisCommand
ctx context.Context
}
// NewRedisTransaction 创建新的 Redis 事务
func (r *RedisClient) NewTransaction(ctx context.Context, watch ...string) *RedisTransaction {
return &RedisTransaction{
client: r,
watch: watch,
ctx: ctx,
}
}
// AddCommand 添加命令到事务
func (tx *RedisTransaction) AddCommand(command string, args ...interface{}) {
tx.cmds = append(tx.cmds, RedisCommand{
Command: command,
Args: args,
})
}
// Exec 使用 go-redis Watch + TxPipeline 执行事务MULTI/EXEC
func (tx *RedisTransaction) Exec() ([]interface{}, error) {
pipe := tx.client.client.TxPipeline()
// 添加所有命令
cmds := make([]*redis.Cmd, len(tx.cmds))
for i, c := range tx.cmds {
cmds[i] = pipe.Do(tx.ctx, append([]interface{}{c.Command}, c.Args...)...)
}
// TxPipeline 自动发送 MULTI/EXEC
results := make([]interface{}, len(tx.cmds))
_, err := pipe.Exec(tx.ctx)
for i, cmd := range cmds {
result, cmdErr := cmd.Result()
results[i] = result
tx.cmds[i].Result = result
tx.cmds[i].Error = cmdErr
}
if err != nil && err != redis.Nil {
return results, fmt.Errorf("事务执行失败: %v", err)
}
return results, nil
}

View File

@@ -43,8 +43,10 @@ var defaultTabConfig = TabConfig{
AvailableTabs: []TabDefinition{
{Key: "file-system", Title: "文件管理", Enabled: true},
{Key: "db-cli", Title: "数据库", Enabled: true},
{Key: "markdown-editor", Title: "Markdown", Enabled: true},
{Key: "openclaw-manager", Title: "OpenClaw", Enabled: true},
},
VisibleTabs: []string{"file-system", "db-cli"},
VisibleTabs: []string{"file-system", "db-cli", "markdown-editor", "openclaw-manager"},
DefaultTab: "file-system",
}

View File

@@ -1,12 +1,16 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"u-desk/internal/crypto"
"u-desk/internal/dbclient"
"u-desk/internal/storage/models"
"u-desk/internal/storage/repository"
"gorm.io/gorm"
)
// ConnectionService 连接管理服务
@@ -90,8 +94,20 @@ func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error)
return s.repo.FindByID(id)
}
// DeleteConnection 删除连接配置
// DeleteConnection 删除连接配置(含关联数据和连接池清理)
func (s *ConnectionService) DeleteConnection(id uint) error {
conn, err := s.repo.FindByID(id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 连接不存在视为成功
}
return fmt.Errorf("获取连接配置失败: %v", err)
}
// 关闭连接池中的连接
dbclient.GetPool().CloseConnection(id, conn.Type)
// 删除连接记录
return s.repo.Delete(id)
}
@@ -185,3 +201,68 @@ func (s *ConnectionService) TestConnectionWithParams(connType, host string, port
return fmt.Errorf("不支持的数据库类型: %s", connType)
}
}
// LoadAllDatabases 加载全部数据库列表
func (s *ConnectionService) LoadAllDatabases(dbType, host string, port int, username, password, database, options string, existingId uint) ([]string, error) {
// 如果是编辑模式且密码为空,尝试获取已保存的密码
actualPassword := password
if existingId > 0 && password == "" {
conn, err := s.repo.FindByID(existingId)
if err != nil {
return nil, fmt.Errorf("获取原连接配置失败: %v", err)
}
actualPassword, err = crypto.DecryptPassword(conn.Password)
if err != nil {
return nil, fmt.Errorf("密码解密失败: %v", err)
}
}
// 解析 MongoDB 选项
authSource := ""
authMechanism := ""
if options != "" {
var opts map[string]interface{}
if err := json.Unmarshal([]byte(options), &opts); err == nil {
authSource, _ = opts["authSource"].(string)
authMechanism, _ = opts["authMechanism"].(string)
}
}
switch dbType {
case "mysql":
return loadDatabasesForMySQL(host, port, username, actualPassword, database)
case "mongo":
return loadDatabasesForMongo(host, port, username, actualPassword, database, authSource, authMechanism)
case "redis":
return []string{}, nil
default:
return nil, fmt.Errorf("不支持的数据库类型: %s", dbType)
}
}
func loadDatabasesForMySQL(host string, port int, username, password, defaultDatabase string) ([]string, error) {
config := &dbclient.MySQLConfig{
Host: host, Port: port, Username: username,
Password: password, Database: defaultDatabase,
}
client, err := dbclient.NewMySQLClient(config)
if err != nil {
return nil, err
}
defer client.Close()
return client.ListDatabases(context.Background())
}
func loadDatabasesForMongo(host string, port int, username, password, defaultDatabase, authSource, authMechanism string) ([]string, error) {
config := &dbclient.MongoConfig{
Host: host, Port: port, Username: username,
Password: password, Database: defaultDatabase,
AuthSource: authSource, AuthMechanism: authMechanism,
}
client, err := dbclient.NewMongoClient(config)
if err != nil {
return nil, err
}
defer client.Close()
return client.ListDatabases(context.Background())
}

View File

@@ -66,10 +66,11 @@ func (s *SqlExecService) ExecuteSQL(connectionID uint, sqlStr string, database s
// executeMySQL 执行MySQL SQL
func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
client, err := s.pool.GetMySQLClient(conn)
if err != nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
pc := s.pool.GetMySQLClient(conn)
if pc.Client == nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败")
}
defer pc.Release()
sqlStr = strings.TrimSpace(sqlStr)
sqlUpper := strings.ToUpper(sqlStr)
@@ -89,7 +90,7 @@ func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnec
strings.HasPrefix(sqlUpper, "DESCRIBE") || strings.HasPrefix(sqlUpper, "DESC") ||
strings.HasPrefix(sqlUpper, "EXPLAIN") {
// 查询语句
queryResult, err := client.ExecuteQuery(ctx, sqlStr, dbName)
queryResult, err := pc.Client.ExecuteQuery(ctx, sqlStr, dbName)
if err != nil {
return nil, err
}
@@ -99,7 +100,7 @@ func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnec
result.RowsAffected = len(queryResult.Data)
} else {
// 更新语句
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, dbName)
rowsAffected, err := pc.Client.ExecuteUpdate(ctx, sqlStr, dbName)
if err != nil {
return nil, err
}
@@ -220,11 +221,12 @@ func (s *SqlExecService) GetDatabases(connectionID uint) ([]string, error) {
switch conn.Type {
case "mysql":
client, err := s.pool.GetMySQLClient(conn)
if err != nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
pc := s.pool.GetMySQLClient(conn)
if pc.Client == nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败")
}
return client.ListDatabases(ctx)
defer pc.Release()
return pc.Client.ListDatabases(ctx)
case "redis":
databases := make([]string, 16)
for i := 0; i < 16; i++ {
@@ -254,11 +256,12 @@ func (s *SqlExecService) GetTables(connectionID uint, database string) ([]string
switch conn.Type {
case "mysql":
client, err := s.pool.GetMySQLClient(conn)
if err != nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
pc := s.pool.GetMySQLClient(conn)
if pc.Client == nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败")
}
return client.ListTables(ctx, database)
defer pc.Release()
return pc.Client.ListTables(ctx, database)
case "redis":
client, err := s.pool.GetRedisClient(conn)
if err != nil {
@@ -305,7 +308,7 @@ func parseRedisCommand(cmd string) []string {
} else {
if char == quoteChar {
inQuotes = false
quoteChar = 0
quoteChar = byte(0)
} else {
current.WriteByte(char)
}
@@ -330,11 +333,12 @@ func (s *SqlExecService) GetTableStructure(connectionID uint, database, tableNam
switch conn.Type {
case "mysql":
client, err := s.pool.GetMySQLClient(conn)
if err != nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
pc := s.pool.GetMySQLClient(conn)
if pc.Client == nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败")
}
structure, err := client.GetTableStructure(ctx, database, tableName)
defer pc.Release()
structure, err := pc.Client.GetTableStructure(ctx, database, tableName)
if err != nil {
return nil, err
}
@@ -393,11 +397,12 @@ func (s *SqlExecService) GetIndexes(connectionID uint, database, tableName strin
switch conn.Type {
case "mysql":
client, err := s.pool.GetMySQLClient(conn)
if err != nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
pc := s.pool.GetMySQLClient(conn)
if pc.Client == nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败")
}
return client.GetIndexes(ctx, database, tableName)
defer pc.Release()
return pc.Client.GetIndexes(ctx, database, tableName)
case "mongo", "redis":
return []map[string]interface{}{}, nil
@@ -419,11 +424,12 @@ func (s *SqlExecService) PreviewTableStructure(connectionID uint, database, tabl
switch conn.Type {
case "mysql":
client, err := s.pool.GetMySQLClient(conn)
if err != nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
pc := s.pool.GetMySQLClient(conn)
if pc.Client == nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败")
}
return client.PreviewTableStructure(ctx, database, tableName, structure)
defer pc.Release()
return pc.Client.PreviewTableStructure(ctx, database, tableName, structure)
case "mongo":
client, err := s.pool.GetMongoClient(conn)
@@ -449,11 +455,12 @@ func (s *SqlExecService) UpdateTableStructure(connectionID uint, database, table
switch conn.Type {
case "mysql":
client, err := s.pool.GetMySQLClient(conn)
if err != nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
pc := s.pool.GetMySQLClient(conn)
if pc.Client == nil {
return nil, fmt.Errorf("获取 MySQL 客户端失败")
}
return client.UpdateTableStructure(ctx, database, tableName, structure)
defer pc.Release()
return pc.Client.UpdateTableStructure(ctx, database, tableName, structure)
case "mongo":
client, err := s.pool.GetMongoClient(conn)