新增:Markdown编辑器/数据库优化/安全修复
- Markdown 编辑器:实时预览、PDF 导出、独立查看器 - 数据库优化:动态连接池、查询缓存、Redis Pipeline - 窗口置顶功能 - 文件系统增强:右键菜单、编辑器集成、收藏夹重构 - 安全修复:XSS 防护、路径穿越、HTML 注入 - 代码质量:正则预编译、缓存锁优化、死代码清理
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"u-desk/internal/storage"
|
||||
"u-desk/internal/service"
|
||||
"u-desk/internal/storage/models"
|
||||
)
|
||||
|
||||
// ConnectionAPI 连接管理API
|
||||
type ConnectionAPI struct {
|
||||
connService *storage.ConnectionService
|
||||
connService *service.ConnectionService
|
||||
}
|
||||
|
||||
// NewConnectionAPI 创建连接管理API
|
||||
func NewConnectionAPI() (*ConnectionAPI, error) {
|
||||
connService, err := storage.NewConnectionService()
|
||||
connService, err := service.NewConnectionService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,11 +82,7 @@ func (api *ConnectionAPI) DeleteDbConnection(id uint) error {
|
||||
}
|
||||
|
||||
func (api *ConnectionAPI) TestDbConnection(id uint) error {
|
||||
conn, err := api.connService.GetConnection(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return api.connService.TestConnection(conn)
|
||||
return api.connService.TestConnection(id)
|
||||
}
|
||||
|
||||
// TestConnectionRequest 测试连接请求结构体(不保存数据)
|
||||
@@ -104,14 +100,9 @@ type TestConnectionRequest struct {
|
||||
// TestDbConnectionWithParams 测试数据库连接(直接传入参数,不保存数据)
|
||||
func (api *ConnectionAPI) TestDbConnectionWithParams(req TestConnectionRequest) error {
|
||||
return api.connService.TestConnectionWithParams(
|
||||
req.Type,
|
||||
req.Host,
|
||||
req.Port,
|
||||
req.Username,
|
||||
req.Password,
|
||||
req.Database,
|
||||
req.Options,
|
||||
req.ID,
|
||||
req.Type, req.Host, req.Port,
|
||||
req.Username, req.Password, req.Database,
|
||||
req.Options, req.ID,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -130,13 +121,8 @@ type LoadAllDatabasesRequest struct {
|
||||
// LoadAllDatabases 加载全部数据库列表
|
||||
func (api *ConnectionAPI) LoadAllDatabases(req LoadAllDatabasesRequest) ([]string, error) {
|
||||
return api.connService.LoadAllDatabases(
|
||||
req.Type,
|
||||
req.Host,
|
||||
req.Port,
|
||||
req.Username,
|
||||
req.Password,
|
||||
req.Database,
|
||||
req.Options,
|
||||
req.ID,
|
||||
req.Type, req.Host, req.Port,
|
||||
req.Username, req.Password, req.Database,
|
||||
req.Options, req.ID,
|
||||
)
|
||||
}
|
||||
|
||||
379
internal/api/pdf_api.go
Normal file
379
internal/api/pdf_api.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/chromedp/cdproto/page"
|
||||
"github.com/chromedp/chromedp"
|
||||
"github.com/yuin/goldmark"
|
||||
"u-desk/internal/common"
|
||||
)
|
||||
|
||||
// PdfExportRequest PDF导出请求结构体
|
||||
type PdfExportRequest struct {
|
||||
Content string `json:"content"` // Markdown/HTML内容
|
||||
Title string `json:"title"` // PDF标题
|
||||
FileName string `json:"fileName"` // 文件名(不含扩展名)
|
||||
FontSize int `json:"fontSize"` // 字体大小
|
||||
PageWidth int `json:"pageWidth"` // 页面宽度(mm)
|
||||
PageHeight int `json:"pageHeight"` // 页面高度(mm)
|
||||
}
|
||||
|
||||
// PdfExportResponse PDF导出响应结构体
|
||||
type PdfExportResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Path string `json:"path"` // PDF文件保存路径
|
||||
Size int64 `json:"size"` // 文件大小(字节)
|
||||
}
|
||||
|
||||
// PdfAPI PDF导出API
|
||||
type PdfAPI struct {
|
||||
// 可以在这里添加依赖,如文件系统服务等
|
||||
}
|
||||
|
||||
// NewPdfAPI 创建PDF导出API
|
||||
func NewPdfAPI() (*PdfAPI, error) {
|
||||
return &PdfAPI{}, nil
|
||||
}
|
||||
|
||||
// ExportMarkdownToPDF 将Markdown内容导出为PDF - 使用chromedp实现
|
||||
func (api *PdfAPI) ExportMarkdownToPDF(req PdfExportRequest) (*PdfExportResponse, error) {
|
||||
// 验证参数
|
||||
if strings.TrimSpace(req.Content) == "" {
|
||||
return nil, fmt.Errorf("内容不能为空")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.FileName) == "" {
|
||||
req.FileName = "document_" + time.Now().Format("20060102_150405")
|
||||
}
|
||||
|
||||
if req.FontSize <= 0 {
|
||||
req.FontSize = 12
|
||||
}
|
||||
|
||||
// 设置默认页面尺寸(A4)
|
||||
if req.PageWidth <= 0 {
|
||||
req.PageWidth = 210
|
||||
}
|
||||
if req.PageHeight <= 0 {
|
||||
req.PageHeight = 297
|
||||
}
|
||||
|
||||
// 将Markdown转换为HTML
|
||||
htmlContent := api.markdownToHTML(req.Content, req.Title, req.FontSize)
|
||||
|
||||
// 使用chromedp生成PDF
|
||||
pdfBuffer, err := api.generatePDFFromHTML(htmlContent, req.Title, req.PageWidth, req.PageHeight)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成PDF失败: %v", err)
|
||||
}
|
||||
|
||||
// 生成文件名
|
||||
if !strings.HasSuffix(strings.ToLower(req.FileName), ".pdf") {
|
||||
req.FileName += ".pdf"
|
||||
}
|
||||
|
||||
// 获取用户桌面目录作为默认保存位置
|
||||
saveDir := api.getDesktopDirectory()
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(saveDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 完整保存路径
|
||||
savePath := filepath.Join(saveDir, filepath.Base(req.FileName))
|
||||
|
||||
// 保存PDF文件
|
||||
err = os.WriteFile(savePath, pdfBuffer, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("保存PDF文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取文件信息
|
||||
fileInfo, err := os.Stat(savePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取文件信息失败: %v", err)
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
return &PdfExportResponse{
|
||||
Success: true,
|
||||
Message: "PDF生成成功",
|
||||
Path: savePath,
|
||||
Size: fileInfo.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// markdownToHTML 将Markdown转换为HTML
|
||||
func (api *PdfAPI) markdownToHTML(markdownContent string, title string, fontSize int) string {
|
||||
// 基础HTML模板
|
||||
htmlTemplate := `<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: #333;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
padding: 40px 20px;
|
||||
font-size: %dpx;
|
||||
}
|
||||
h1, h2, h3, h4, h5, h6 {
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
font-weight: 600;
|
||||
line-height: 1.25;
|
||||
}
|
||||
h1 {
|
||||
font-size: 2em;
|
||||
border-bottom: 1px solid #eaecef;
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
h2 {
|
||||
font-size: 1.5em;
|
||||
border-bottom: 1px solid #eaecef;
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
h3 {
|
||||
font-size: 1.25em;
|
||||
}
|
||||
h4 {
|
||||
font-size: 1em;
|
||||
}
|
||||
h5 {
|
||||
font-size: 0.875em;
|
||||
}
|
||||
h6 {
|
||||
font-size: 0.85em;
|
||||
color: #6a737d;
|
||||
}
|
||||
p {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
blockquote {
|
||||
margin: 0 0 16px;
|
||||
padding: 0 1em;
|
||||
color: #6a737d;
|
||||
border-left: 0.25em solid #dfe2e5;
|
||||
}
|
||||
ul, ol {
|
||||
padding-left: 2em;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
li {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
code {
|
||||
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace;
|
||||
background-color: rgba(27,31,35,0.05);
|
||||
border-radius: 3px;
|
||||
font-size: 85%;
|
||||
margin: 0;
|
||||
padding: 0.2em 0.4em;
|
||||
}
|
||||
pre {
|
||||
background-color: #f6f8fa;
|
||||
border-radius: 3px;
|
||||
padding: 16px;
|
||||
overflow: auto;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
pre code {
|
||||
background-color: transparent;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
font-size: 100%;
|
||||
}
|
||||
table {
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
margin-bottom: 16px;
|
||||
border: 1px solid #dfe2e5;
|
||||
}
|
||||
th, td {
|
||||
padding: 8px 12px;
|
||||
border: 1px solid #dfe2e5;
|
||||
text-align: left;
|
||||
}
|
||||
th {
|
||||
background-color: #f6f8fa;
|
||||
font-weight: 600;
|
||||
}
|
||||
img {
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
margin: 16px 0;
|
||||
}
|
||||
hr {
|
||||
height: 0.25em;
|
||||
padding: 0;
|
||||
margin: 24px 0;
|
||||
background-color: #e1e4e8;
|
||||
border: 0;
|
||||
}
|
||||
.title {
|
||||
text-align: center;
|
||||
margin-bottom: 32px;
|
||||
font-size: 1.5em;
|
||||
color: #2c3e50;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="title">%s</div>
|
||||
%s
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
// 标题处理
|
||||
docTitle := ""
|
||||
if title != "" {
|
||||
docTitle = html.EscapeString(title)
|
||||
} else {
|
||||
docTitle = "文档"
|
||||
}
|
||||
|
||||
// Markdown转HTML(使用goldmark)
|
||||
var htmlContent string
|
||||
var htmlBuf strings.Builder
|
||||
if err := goldmark.Convert([]byte(markdownContent), &htmlBuf); err != nil {
|
||||
htmlContent = "<p>Markdown 解析失败</p>"
|
||||
} else {
|
||||
htmlContent = htmlBuf.String()
|
||||
}
|
||||
|
||||
// 生成完整的HTML
|
||||
fullHTML := fmt.Sprintf(htmlTemplate, fontSize, docTitle, htmlContent)
|
||||
|
||||
return fullHTML
|
||||
}
|
||||
|
||||
// generatePDFFromHTML 使用chromedp从HTML生成PDF
|
||||
func (api *PdfAPI) generatePDFFromHTML(htmlContent, title string, pageWidth, pageHeight int) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
// 配置chromedp选项
|
||||
opts := []chromedp.ExecAllocatorOption{
|
||||
chromedp.Flag("headless", true),
|
||||
chromedp.Flag("disable-gpu", true),
|
||||
chromedp.Flag("no-sandbox", true),
|
||||
chromedp.Flag("disable-dev-shm-usage", true),
|
||||
chromedp.Flag("disable-software-rasterizer", true),
|
||||
chromedp.Flag("disable-extensions", true),
|
||||
chromedp.Flag("disable-notifications", true),
|
||||
}
|
||||
|
||||
// 在Windows上设置Chrome路径
|
||||
if common.IsWindows() {
|
||||
// 常见的Windows Chrome路径
|
||||
chromePaths := []string{
|
||||
"C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe",
|
||||
"C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe",
|
||||
"C:\\Users\\" + os.Getenv("USERNAME") + "\\AppData\\Local\\Google\\Chrome\\Application\\chrome.exe",
|
||||
}
|
||||
|
||||
for _, path := range chromePaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
opts = append(opts, chromedp.ExecPath(path))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建执行分配器上下文
|
||||
allocCtx, allocCancel := chromedp.NewExecAllocator(ctx, opts...)
|
||||
defer allocCancel()
|
||||
|
||||
// 创建chromedp上下文
|
||||
chromeCtx, chromeCancel := chromedp.NewContext(allocCtx)
|
||||
defer chromeCancel()
|
||||
|
||||
// 创建一个临时的目录用于PDF生成
|
||||
tempDir, err := os.MkdirTemp("", "pdf_gen")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建临时目录失败: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// 将HTML写入临时文件
|
||||
htmlFile := filepath.Join(tempDir, "document.html")
|
||||
if err := os.WriteFile(htmlFile, []byte(htmlContent), 0644); err != nil {
|
||||
return nil, fmt.Errorf("写入HTML文件失败: %v", err)
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
|
||||
// 使用 file URL 加载本地HTML文件
|
||||
err = chromedp.Run(chromeCtx,
|
||||
// 导航到HTML文件
|
||||
chromedp.Navigate("file://"+htmlFile),
|
||||
// 等待页面加载完成
|
||||
chromedp.WaitReady("body"),
|
||||
// 打印到PDF
|
||||
chromedp.ActionFunc(func(ctx context.Context) error {
|
||||
// 设置页面打印参数
|
||||
printToPDF := page.PrintToPDF().
|
||||
WithPrintBackground(true).
|
||||
WithLandscape(false).
|
||||
WithMarginTop(0).
|
||||
WithMarginBottom(0).
|
||||
WithMarginLeft(0).
|
||||
WithMarginRight(0).
|
||||
WithPaperWidth(float64(pageWidth) / 25.4). // mm to inches
|
||||
WithPaperHeight(float64(pageHeight) / 25.4) // mm to inches
|
||||
|
||||
// 执行打印并获取PDF数据
|
||||
var err error
|
||||
buf, _, err = printToPDF.Do(ctx)
|
||||
return err
|
||||
}),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chromedp执行失败: %v", err)
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// getDesktopDirectory 获取用户桌面目录
|
||||
func (api *PdfAPI) getDesktopDirectory() string {
|
||||
// Windows系统
|
||||
if common.IsWindows() {
|
||||
home := os.Getenv("USERPROFILE")
|
||||
if home != "" {
|
||||
return filepath.Join(home, "Desktop")
|
||||
}
|
||||
}
|
||||
|
||||
// Linux/Mac系统
|
||||
home := os.Getenv("HOME")
|
||||
if home != "" {
|
||||
return filepath.Join(home, "Desktop")
|
||||
}
|
||||
|
||||
// 备用:当前目录
|
||||
return "."
|
||||
}
|
||||
|
||||
// SelectDirectory 选择保存目录(简化版,实际应该使用Wails runtime)
|
||||
func (api *PdfAPI) SelectDirectory() (string, error) {
|
||||
// 简化版:直接返回桌面目录
|
||||
desktop := api.getDesktopDirectory()
|
||||
if desktop == "." {
|
||||
return "", fmt.Errorf("无法确定默认目录")
|
||||
}
|
||||
return desktop, nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// InterfaceSliceToStringSlice 将 []interface{} 安全转换为 []string
|
||||
@@ -54,3 +55,18 @@ func Difference[T comparable](a, b []T) []T {
|
||||
}
|
||||
return diff
|
||||
}
|
||||
|
||||
// IsWindows 判断是否为Windows系统
|
||||
func IsWindows() bool {
|
||||
return runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
// IsMac 判断是否为Mac系统
|
||||
func IsMac() bool {
|
||||
return runtime.GOOS == "darwin"
|
||||
}
|
||||
|
||||
// IsLinux 判断是否为Linux系统
|
||||
func IsLinux() bool {
|
||||
return runtime.GOOS == "linux"
|
||||
}
|
||||
|
||||
@@ -7,20 +7,106 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// 旧版硬编码密钥(用于兼容迁移已有加密数据)
|
||||
var legacyKey = []byte("go-desk-db-cli-key-32bytes123456")
|
||||
|
||||
var (
|
||||
// 默认密钥(实际应用中应该从配置文件或环境变量读取)
|
||||
// AES-256 需要 32 字节密钥
|
||||
// "go-desk-db-cli-key-32bytes123456" = 32 bytes
|
||||
defaultKey = []byte("go-desk-db-cli-key-32bytes123456") // 32 bytes for AES-256
|
||||
encryptionKey []byte
|
||||
keyOnce sync.Once
|
||||
keyInitErr error
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 验证密钥长度
|
||||
if len(defaultKey) != 32 {
|
||||
panic(fmt.Sprintf("AES-256 密钥长度必须为 32 字节,当前为 %d 字节", len(defaultKey)))
|
||||
// getKey 获取或创建机器唯一密钥
|
||||
// 首次启动时生成并持久化到用户配置目录,后续直接读取
|
||||
func getKey() ([]byte, error) {
|
||||
keyOnce.Do(func() {
|
||||
keyFile, err := getKeyFilePath()
|
||||
if err != nil {
|
||||
keyInitErr = fmt.Errorf("获取密钥路径失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试读取已有密钥
|
||||
if data, err := os.ReadFile(keyFile); err == nil && len(data) == 32 {
|
||||
encryptionKey = data
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新密钥
|
||||
newKey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, newKey); err != nil {
|
||||
keyInitErr = fmt.Errorf("生成密钥失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 持久化密钥
|
||||
dir := filepath.Dir(keyFile)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
keyInitErr = fmt.Errorf("创建密钥目录失败: %v", err)
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(keyFile, newKey, 0600); err != nil {
|
||||
keyInitErr = fmt.Errorf("保存密钥失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
encryptionKey = newKey
|
||||
})
|
||||
|
||||
return encryptionKey, keyInitErr
|
||||
}
|
||||
|
||||
// getKeyFilePath 返回密钥文件路径
|
||||
func getKeyFilePath() (string, error) {
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(configDir, "u-desk", ".aes-key"), nil
|
||||
}
|
||||
|
||||
// DecryptPasswordV2 使用指定密钥解密(用于密钥迁移)
|
||||
func DecryptPasswordV2(encryptedPassword string, key []byte) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(encryptedPassword) < 10 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解码失败: %v", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建解密器失败: %v", err)
|
||||
}
|
||||
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 GCM 失败: %v", err)
|
||||
}
|
||||
|
||||
nonceSize := aesGCM.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return "", fmt.Errorf("密文长度不足")
|
||||
}
|
||||
|
||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
|
||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解密失败: %v", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// EncryptPassword 加密密码
|
||||
@@ -29,7 +115,12 @@ func EncryptPassword(password string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(defaultKey)
|
||||
key, err := getKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取加密密钥失败: %v", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建加密器失败: %v", err)
|
||||
}
|
||||
@@ -53,47 +144,32 @@ func EncryptPassword(password string) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptPassword 解密密码
|
||||
// DecryptPassword 解密密码(自动回退旧密钥兼容旧数据)
|
||||
func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// 如果加密字符串为空或格式不正确,返回空字符串
|
||||
if len(encryptedPassword) < 10 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Base64 解码
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
||||
key, err := getKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解码失败: %v", err)
|
||||
return "", fmt.Errorf("获取解密密钥失败: %v", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(defaultKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建解密器失败: %v", err)
|
||||
// 先用新密钥尝试解密
|
||||
result, err := DecryptPasswordV2(encryptedPassword, key)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 使用 GCM 模式
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建 GCM 失败: %v", err)
|
||||
// 新密钥失败,尝试旧密钥(兼容已迁移的旧数据)
|
||||
result, err = DecryptPasswordV2(encryptedPassword, legacyKey)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 提取 nonce
|
||||
nonceSize := aesGCM.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return "", fmt.Errorf("密文长度不足")
|
||||
}
|
||||
|
||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
|
||||
// 解密
|
||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解密失败: %v", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
// 两种密钥都失败
|
||||
return "", fmt.Errorf("解密失败: %v", err)
|
||||
}
|
||||
|
||||
479
internal/dbclient/cache.go
Normal file
479
internal/dbclient/cache.go
Normal file
@@ -0,0 +1,479 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// QueryCache 查询缓存
|
||||
type QueryCache struct {
|
||||
items map[string]*CachedQuery
|
||||
size int
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// 智能缓存策略
|
||||
hitRate float64 // 缓存命中率
|
||||
hitCount int64 // 命中次数
|
||||
missCount int64 // 未命中次数
|
||||
evictionCount int64 // 驱逐次数
|
||||
hotQueries map[string]bool // 热点查询标记
|
||||
cooldowns map[string]time.Time // 冷却时间(避免频繁驱逐)
|
||||
|
||||
// 内存限制
|
||||
maxMemoryBytes int64 // 缓存最大内存(字节),默认 100MB
|
||||
usedMemory int64 // 当前估算内存使用量
|
||||
}
|
||||
|
||||
// NewQueryCache 创建新的查询缓存
|
||||
func NewQueryCache(size int, ttl time.Duration) *QueryCache {
|
||||
cache := &QueryCache{
|
||||
items: make(map[string]*CachedQuery),
|
||||
size: size,
|
||||
ttl: ttl,
|
||||
stopCh: make(chan struct{}),
|
||||
hitRate: 0.0,
|
||||
hitCount: 0,
|
||||
missCount: 0,
|
||||
evictionCount: 0,
|
||||
hotQueries: make(map[string]bool),
|
||||
cooldowns: make(map[string]time.Time),
|
||||
maxMemoryBytes: 100 * 1024 * 1024, // 默认 100MB
|
||||
}
|
||||
|
||||
// 启动清理协程
|
||||
cache.StartCleanup()
|
||||
|
||||
// 启动统计协程
|
||||
cache.StartStatsCollection()
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// Get 从缓存中获取查询结果
|
||||
func (c *QueryCache) Get(params QueryParams) (*CachedQuery, error) {
|
||||
key := c.generateKey(params)
|
||||
|
||||
c.mu.RLock()
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
c.missCount++
|
||||
_, inCooldown := c.cooldowns[key]
|
||||
if inCooldown && time.Now().Before(c.cooldowns[key]) {
|
||||
c.mu.RUnlock()
|
||||
return nil, ErrCacheCooldown
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
return nil, ErrCacheNotFound
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(item.ExpiryTime) {
|
||||
if c.isHotQuery(key) {
|
||||
c.mu.RUnlock()
|
||||
c.mu.Lock()
|
||||
item.ExpiryTime = time.Now().Add(c.ttl)
|
||||
c.hitCount++
|
||||
c.markAsHot(key)
|
||||
c.mu.Unlock()
|
||||
return item, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.evictionCount++
|
||||
c.missCount++
|
||||
c.mu.Unlock()
|
||||
return nil, ErrCacheExpired
|
||||
}
|
||||
|
||||
// 命中
|
||||
c.hitCount++
|
||||
needsMark := !c.hotQueries[key]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if needsMark {
|
||||
c.mu.Lock()
|
||||
c.markAsHot(key)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// Set 将查询结果存入缓存
|
||||
func (c *QueryCache) Set(params QueryParams, item *CachedQuery) {
|
||||
key := c.generateKey(params)
|
||||
|
||||
// 估算条目内存大小
|
||||
itemSize := c.estimateSize(params, item)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// 更新统计
|
||||
c.recordQueryAttempt(key)
|
||||
|
||||
// 如果超过内存限制,执行驱逐直到有空间
|
||||
for c.usedMemory+itemSize > c.maxMemoryBytes && len(c.items) > 0 {
|
||||
c.smartEvict(key)
|
||||
}
|
||||
|
||||
// 如果条目数已满,执行智能驱逐
|
||||
if len(c.items) >= c.size {
|
||||
c.smartEvict(key)
|
||||
}
|
||||
|
||||
// 如果已有旧条目,先减去旧的大小
|
||||
if old, exists := c.items[key]; exists {
|
||||
c.usedMemory -= c.estimateItemSize(old)
|
||||
}
|
||||
|
||||
c.items[key] = item
|
||||
c.usedMemory += itemSize
|
||||
|
||||
// 标记为热点查询
|
||||
c.markAsHot(key)
|
||||
}
|
||||
|
||||
// smartEvict 智能驱逐策略
|
||||
func (c *QueryCache) smartEvict(newKey string) {
|
||||
if len(c.items) == 0 {
|
||||
return
|
||||
}
|
||||
// LRU + LFU 混合策略
|
||||
var evictKey string
|
||||
var worstScore float64 = -1
|
||||
|
||||
for key, item := range c.items {
|
||||
if key == newKey {
|
||||
continue
|
||||
}
|
||||
|
||||
score := c.calculateEvictionScore(key, item)
|
||||
if score > worstScore {
|
||||
worstScore = score
|
||||
evictKey = key
|
||||
}
|
||||
}
|
||||
|
||||
if evictKey != "" {
|
||||
if evicted, exists := c.items[evictKey]; exists {
|
||||
c.usedMemory -= c.estimateItemSize(evicted)
|
||||
}
|
||||
c.cooldowns[evictKey] = time.Now().Add(1 * time.Minute)
|
||||
delete(c.items, evictKey)
|
||||
c.evictionCount++
|
||||
}
|
||||
}
|
||||
|
||||
// calculateEvictionScore 计算驱逐分数(越低越适合保留)
|
||||
func (c *QueryCache) calculateEvictionScore(key string, item *CachedQuery) float64 {
|
||||
now := time.Now()
|
||||
|
||||
// 基础分数
|
||||
score := 1.0
|
||||
|
||||
// 热点查询加分(优先保留)
|
||||
if c.isHotQuery(key) {
|
||||
score -= 0.5
|
||||
}
|
||||
|
||||
// 接近过期的加分(优先驱逐即将过期的)
|
||||
if item.ExpiryTime.Sub(now) < c.ttl/2 {
|
||||
score += 0.3
|
||||
}
|
||||
|
||||
// 最近使用的加分(优先保留最近使用的)
|
||||
if !item.LastUsed.IsZero() {
|
||||
recency := now.Sub(item.LastUsed)
|
||||
if recency < 5*time.Minute {
|
||||
score -= 0.2
|
||||
}
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// isHotQuery 检查是否为热点查询
|
||||
func (c *QueryCache) isHotQuery(key string) bool {
|
||||
return c.hotQueries[key]
|
||||
}
|
||||
|
||||
// markAsHot 标记为热点查询
|
||||
func (c *QueryCache) markAsHot(key string) {
|
||||
c.hotQueries[key] = true
|
||||
}
|
||||
|
||||
// cleanupHotMarkers 清理热点标记
|
||||
func (c *QueryCache) cleanupHotMarkers() {
|
||||
now := time.Now()
|
||||
for key := range c.hotQueries {
|
||||
// 清理超过10分钟未使用的热点标记
|
||||
if item, exists := c.items[key]; exists {
|
||||
if now.Sub(item.LastUsed) > 10*time.Minute {
|
||||
delete(c.hotQueries, key)
|
||||
}
|
||||
} else {
|
||||
delete(c.hotQueries, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordQueryAttempt 记录查询尝试
|
||||
func (c *QueryCache) recordQueryAttempt(key string) {
|
||||
// 更新命中率
|
||||
c.updateHitRate()
|
||||
|
||||
// 更新最后使用时间
|
||||
if item, exists := c.items[key]; exists {
|
||||
item.LastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// updateHitRate 更新命中率
|
||||
func (c *QueryCache) updateHitRate() {
|
||||
total := c.hitCount + c.missCount
|
||||
if total > 0 {
|
||||
c.hitRate = float64(c.hitCount) / float64(total)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete 从缓存中删除指定查询
|
||||
func (c *QueryCache) Delete(params QueryParams) {
|
||||
key := c.generateKey(params)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if item, exists := c.items[key]; exists {
|
||||
c.usedMemory -= c.estimateItemSize(item)
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear 清空整个缓存
|
||||
func (c *QueryCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.items = make(map[string]*CachedQuery)
|
||||
c.usedMemory = 0
|
||||
}
|
||||
|
||||
// Size 获取缓存大小
|
||||
func (c *QueryCache) Size() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return len(c.items)
|
||||
}
|
||||
|
||||
// CleanupExpired 清理过期的缓存条目
|
||||
func (c *QueryCache) CleanupExpired() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiryTime) {
|
||||
c.usedMemory -= c.estimateItemSize(item)
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keys 获取缓存中所有的键
|
||||
func (c *QueryCache) Keys() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
keys := make([]string, 0, len(c.items))
|
||||
for key := range c.items {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Stats 获取缓存统计信息
|
||||
func (c *QueryCache) Stats() CacheStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
expired := 0
|
||||
active := 0
|
||||
|
||||
for _, item := range c.items {
|
||||
if now.After(item.ExpiryTime) {
|
||||
expired++
|
||||
} else {
|
||||
active++
|
||||
}
|
||||
}
|
||||
|
||||
return CacheStats{
|
||||
TotalItems: len(c.items),
|
||||
ActiveItems: active,
|
||||
ExpiredItems: expired,
|
||||
Size: c.size,
|
||||
TTL: c.ttl,
|
||||
HitRate: c.hitRate,
|
||||
HitCount: c.hitCount,
|
||||
MissCount: c.missCount,
|
||||
EvictionCount: c.evictionCount,
|
||||
HotQueries: len(c.hotQueries),
|
||||
}
|
||||
}
|
||||
|
||||
// generateKey 生成缓存键
|
||||
func (c *QueryCache) generateKey(params QueryParams) string {
|
||||
key := fmt.Sprintf("%s|%s|%d|%d|%s|%s|%s|%v",
|
||||
params.SQL, params.Database, params.Limit, params.Offset,
|
||||
params.Table, params.Where, params.SortBy, params.IsReadOnly)
|
||||
h := sha256.Sum256([]byte(key))
|
||||
return fmt.Sprintf("%x", h)
|
||||
}
|
||||
|
||||
// evictOldest 删除最老的缓存条目
|
||||
func (c *QueryCache) evictOldest() {
|
||||
var oldestKey string
|
||||
var oldestTime time.Time
|
||||
|
||||
for key, item := range c.items {
|
||||
if oldestKey == "" || item.CreatedAt.Before(oldestTime) {
|
||||
oldestKey = key
|
||||
oldestTime = item.CreatedAt
|
||||
}
|
||||
}
|
||||
|
||||
if oldestKey != "" {
|
||||
delete(c.items, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
// StartCleanup 启动清理协程
|
||||
func (c *QueryCache) StartCleanup() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(c.ttl / 2) // 每 TTL/2 时间检查一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.CleanupExpired()
|
||||
c.cleanupCooldowns() // 清理冷却时间
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StartStatsCollection 启动统计收集协程
|
||||
func (c *QueryCache) StartStatsCollection() {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute) // 每分钟收集一次统计
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.updateHitRate()
|
||||
c.cleanupHotMarkers()
|
||||
case <-c.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupCooldowns 清理冷却时间
|
||||
func (c *QueryCache) cleanupCooldowns() {
|
||||
now := time.Now()
|
||||
for key, cooldown := range c.cooldowns {
|
||||
if now.After(cooldown) {
|
||||
delete(c.cooldowns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止缓存清理
|
||||
func (c *QueryCache) Stop() {
|
||||
close(c.stopCh)
|
||||
c.wg.Wait()
|
||||
}
|
||||
|
||||
// CacheStats 缓存统计信息
|
||||
type CacheStats struct {
|
||||
TotalItems int
|
||||
ActiveItems int
|
||||
ExpiredItems int
|
||||
Size int
|
||||
TTL time.Duration
|
||||
HitRate float64
|
||||
HitCount int64
|
||||
MissCount int64
|
||||
EvictionCount int64
|
||||
HotQueries int
|
||||
}
|
||||
|
||||
// 缓存错误定义
|
||||
var (
|
||||
ErrCacheNotFound = &CacheError{Message: "缓存未找到"}
|
||||
ErrCacheExpired = &CacheError{Message: "缓存已过期"}
|
||||
ErrCacheCooldown = &CacheError{Message: "查询在冷却中"}
|
||||
)
|
||||
|
||||
// CacheError 缓存错误
|
||||
type CacheError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *CacheError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// estimateSize 估算缓存条目的内存大小(字节)
|
||||
func (c *QueryCache) estimateSize(params QueryParams, item *CachedQuery) int64 {
|
||||
size := int64(len(params.SQL) + len(params.Database) + len(params.Table) +
|
||||
len(params.Where) + len(params.SortBy))
|
||||
if item != nil && item.Result != nil {
|
||||
size += c.estimateItemSize(item)
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// estimateItemSize 估算 CachedQuery 的内存大小
|
||||
func (c *QueryCache) estimateItemSize(item *CachedQuery) int64 {
|
||||
if item == nil || item.Result == nil {
|
||||
return 128 // 基础结构体大小
|
||||
}
|
||||
size := int64(128) // CachedQuery 结构体基础大小
|
||||
for _, row := range item.Result.Data {
|
||||
for _, v := range row {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
size += int64(len(val))
|
||||
case []byte:
|
||||
size += int64(len(val))
|
||||
case nil:
|
||||
// 无额外开销
|
||||
default:
|
||||
size += 64 // 其他类型的估算值
|
||||
}
|
||||
}
|
||||
}
|
||||
size += int64(len(item.Result.Columns)) * 64 // 列名估算
|
||||
return size
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"u-desk/internal/common"
|
||||
"u-desk/internal/crypto"
|
||||
@@ -18,7 +19,10 @@ type ConnectionPool struct {
|
||||
mongoClients map[uint]*MongoClient
|
||||
|
||||
// 新增:MySQL 真连接池
|
||||
mysqlPool *MySQLConnectionPool
|
||||
mysqlPool *MySQLConnectionPool
|
||||
|
||||
// 查询优化器
|
||||
queryOptimizer *QueryOptimizer
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -38,18 +42,37 @@ func GetPool() *ConnectionPool {
|
||||
// 启动维护协程
|
||||
mysqlPool.StartMaintenance()
|
||||
|
||||
// 创建查询优化器
|
||||
queryOptimizer := NewQueryOptimizer(nil)
|
||||
|
||||
globalPool = &ConnectionPool{
|
||||
mysqlClients: make(map[uint]*MySQLClient),
|
||||
redisClients: make(map[uint]*RedisClient),
|
||||
mongoClients: make(map[uint]*MongoClient),
|
||||
mysqlPool: mysqlPool,
|
||||
mysqlPool: mysqlPool,
|
||||
queryOptimizer: queryOptimizer,
|
||||
}
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// PooledClient 带释放语义的客户端包装
|
||||
type PooledClient struct {
|
||||
Client *MySQLClient
|
||||
entry *MySQLPoolEntry
|
||||
pool *MySQLConnectionPool
|
||||
fromPool bool
|
||||
}
|
||||
|
||||
// Release 释放连接回连接池
|
||||
func (pc *PooledClient) Release() {
|
||||
if pc.fromPool && pc.pool != nil && pc.entry != nil {
|
||||
pc.pool.Release(pc.entry)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMySQLClient 获取或创建 MySQL 客户端(使用连接池)
|
||||
func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) (*MySQLClient, error) {
|
||||
func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) *PooledClient {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
@@ -57,16 +80,25 @@ func (p *ConnectionPool) GetMySQLClient(conn *models.DbConnection) (*MySQLClient
|
||||
if p.mysqlPool != nil {
|
||||
entry, err := p.mysqlPool.Acquire(conn)
|
||||
if err == nil {
|
||||
// 成功从池中获取连接
|
||||
return entry.Client, nil
|
||||
return &PooledClient{Client: entry.Client, entry: entry, pool: p.mysqlPool, fromPool: true}
|
||||
}
|
||||
|
||||
// 连接池错误,返回
|
||||
return nil, err
|
||||
p.logPoolError("Acquire failed", err)
|
||||
}
|
||||
|
||||
// 降级到原有逻辑(如果连接池未初始化)
|
||||
return p.getMySQLClientLegacy(conn)
|
||||
// 降级到原有逻辑
|
||||
client, err := p.getMySQLClientLegacy(conn)
|
||||
if err != nil {
|
||||
return &PooledClient{Client: nil, fromPool: false}
|
||||
}
|
||||
return &PooledClient{Client: client, fromPool: false}
|
||||
}
|
||||
|
||||
// logPoolError 记录连接池错误
|
||||
func (p *ConnectionPool) logPoolError(operation string, err error) {
|
||||
if p.queryOptimizer != nil {
|
||||
// 通过查询优化器记录错误
|
||||
p.queryOptimizer.RecordPoolError(operation, err)
|
||||
}
|
||||
}
|
||||
|
||||
// getMySQLClientLegacy 原有的 MySQL 客户端获取逻辑(向后兼容)
|
||||
@@ -115,6 +147,92 @@ func (p *ConnectionPool) GetMySQLPoolStats() *PoolStats {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OptimizeQuery 优化查询执行
|
||||
func (p *ConnectionPool) OptimizeQuery(ctx context.Context, conn *models.DbConnection, sqlStr string, database string) (*QueryResult, time.Duration, error) {
|
||||
pc := p.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, 0, fmt.Errorf("获取 MySQL 连接失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
// 使用查询优化器
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.OptimizeQuery(ctx, pc.Client, sqlStr, database)
|
||||
}
|
||||
|
||||
// 降级到普通查询
|
||||
startTime := time.Now()
|
||||
result, err := pc.Client.ExecuteQuery(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
return result, duration, err
|
||||
}
|
||||
|
||||
// ExecuteOptimizedUpdate 执行优化的更新操作
|
||||
func (p *ConnectionPool) ExecuteOptimizedUpdate(ctx context.Context, conn *models.DbConnection, sqlStr string, database string) (int64, time.Duration, error) {
|
||||
pc := p.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return 0, 0, fmt.Errorf("获取 MySQL 连接失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
// 使用查询优化器
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.ExecuteOptimizedUpdate(ctx, pc.Client, sqlStr, database)
|
||||
}
|
||||
|
||||
// 降级到普通更新
|
||||
startTime := time.Now()
|
||||
result, err := pc.Client.ExecuteUpdate(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
return result, duration, err
|
||||
}
|
||||
|
||||
// GetQueryStats 获取查询统计信息
|
||||
func (p *ConnectionPool) GetQueryStats() QueryStats {
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GetQueryStats()
|
||||
}
|
||||
return QueryStats{}
|
||||
}
|
||||
|
||||
// GetSlowQueries 获取慢查询记录
|
||||
func (p *ConnectionPool) GetSlowQueries(limit int) []SlowQuery {
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GetSlowQueries(limit)
|
||||
}
|
||||
return []SlowQuery{}
|
||||
}
|
||||
|
||||
// GetIndexSuggestions 获取索引建议
|
||||
func (p *ConnectionPool) GetIndexSuggestions(table string) []IndexSuggestion {
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GetIndexSuggestions(table)
|
||||
}
|
||||
return []IndexSuggestion{}
|
||||
}
|
||||
|
||||
// GenerateIndexSuggestions 为表生成索引建议
|
||||
func (p *ConnectionPool) GenerateIndexSuggestions(ctx context.Context, conn *models.DbConnection, database, table string) error {
|
||||
pc := p.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return fmt.Errorf("获取 MySQL 连接失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
// 使用查询优化器
|
||||
if p.queryOptimizer != nil {
|
||||
return p.queryOptimizer.GenerateIndexSuggestions(ctx, pc.Client, database, table)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearQueryCache 清空查询缓存
|
||||
func (p *ConnectionPool) ClearQueryCache() {
|
||||
if p.queryOptimizer != nil {
|
||||
p.queryOptimizer.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// GetRedisClient 获取或创建 Redis 客户端
|
||||
func (p *ConnectionPool) GetRedisClient(conn *models.DbConnection) (*RedisClient, error) {
|
||||
p.mu.Lock()
|
||||
|
||||
@@ -34,22 +34,40 @@ type PoolConfig struct {
|
||||
SlowConnThreshold time.Duration
|
||||
// 连接池最大容量(防止资源耗尽)
|
||||
MaxPoolCapacity int
|
||||
|
||||
// 动态连接池配置
|
||||
EnableDynamicScaling bool // 是否启用动态连接池调整
|
||||
DynamicScaleFactor float64 // 动态调整因子(0.5-2.0)
|
||||
ScaleUpThreshold float64 // 扩容阈值(0-1.0,当使用率超过此值时扩容)
|
||||
ScaleDownThreshold float64 // 缩容阈值(0-1.0,当使用率低于此值时缩容)
|
||||
MinScaleUpInterval time.Duration // 最小扩容间隔(防止频繁调整)
|
||||
MinScaleDownInterval time.Duration // 最小缩容间隔
|
||||
MaxIdleTimeForScale time.Duration // 用于动态调整的最大空闲时间
|
||||
}
|
||||
|
||||
// DefaultPoolConfig 返回默认连接池配置
|
||||
func DefaultPoolConfig() *PoolConfig {
|
||||
return &PoolConfig{
|
||||
MaxOpenConns: 20, // 最大20个连接
|
||||
MaxIdleConns: 10, // 最大10个空闲
|
||||
ConnMaxLifetime: 30 * time.Minute, // 连接最长30分钟
|
||||
ConnMaxIdleTime: 10 * time.Minute, // 空闲10分钟关闭
|
||||
MinIdleConns: 2, // 保持2个最小空闲
|
||||
ConnTimeout: 5 * time.Second, // 连接超时5秒
|
||||
HealthCheckInterval: 30 * time.Second, // 30秒健康检查一次
|
||||
MaxOpenConns: 50, // 最大50个连接(提高并发)
|
||||
MaxIdleConns: 20, // 最大20个空闲(提高响应速度)
|
||||
ConnMaxLifetime: 60 * time.Minute, // 连接最长60分钟(延长连接生命周期)
|
||||
ConnMaxIdleTime: 15 * time.Minute, // 空闲15分钟关闭(更长的空闲时间)
|
||||
MinIdleConns: 5, // 保持5个最小空闲(更好的响应性能)
|
||||
ConnTimeout: 3 * time.Second, // 连接超时3秒(更快失败)
|
||||
HealthCheckInterval: 20 * time.Second, // 20秒健康检查一次(更频繁的健康检查)
|
||||
EnableWarmup: true, // 启用预热
|
||||
EnableSlowConnLog: true, // 启用慢连接日志
|
||||
SlowConnThreshold: 500 * time.Millisecond, // 超过500ms算慢连接
|
||||
MaxPoolCapacity: 50, // 连接池最大容量
|
||||
SlowConnThreshold: 200 * time.Millisecond, // 超过200ms算慢连接(更严格的性能要求)
|
||||
MaxPoolCapacity: 100, // 连接池最大容量(支持更高并发)
|
||||
|
||||
// 动态连接池配置(更智能的调整策略)
|
||||
EnableDynamicScaling: true, // 启用动态调整
|
||||
DynamicScaleFactor: 1.8, // 调整因子1.8倍(更激进的扩容)
|
||||
ScaleUpThreshold: 0.7, // 使用率超过70%扩容(更早扩容)
|
||||
ScaleDownThreshold: 0.4, // 使用率低于40%缩容(避免频繁调整)
|
||||
MinScaleUpInterval: 1 * time.Minute, // 最小扩容间隔1分钟(更快的响应)
|
||||
MinScaleDownInterval: 3 * time.Minute, // 最小缩容间隔3分钟(稳定缩容)
|
||||
MaxIdleTimeForScale: 20 * time.Minute, // 用于调整的最大空闲时间
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +112,13 @@ type MySQLConnectionPool struct {
|
||||
stats PoolStats
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// 动态调整相关
|
||||
lastScaleUpTime time.Time // 上次扩容时间
|
||||
lastScaleDownTime time.Time // 上次缩容时间
|
||||
currentTargetSize int // 当前目标连接数
|
||||
usageHistory []float64 // 使用率历史记录(用于智能调整)
|
||||
adaptiveWeights map[uint]float64 // 连接权重(基于性能表现)
|
||||
}
|
||||
|
||||
// NewMySQLConnectionPool 创建新的 MySQL 连接池
|
||||
@@ -103,10 +128,13 @@ func NewMySQLConnectionPool(config *PoolConfig) *MySQLConnectionPool {
|
||||
}
|
||||
|
||||
pool := &MySQLConnectionPool{
|
||||
config: config,
|
||||
entries: make([]*MySQLPoolEntry, 0, config.MaxPoolCapacity),
|
||||
connMap: make(map[uint]*MySQLClient),
|
||||
stopCh: make(chan struct{}),
|
||||
config: config,
|
||||
entries: make([]*MySQLPoolEntry, 0, config.MaxPoolCapacity),
|
||||
connMap: make(map[uint]*MySQLClient),
|
||||
stopCh: make(chan struct{}),
|
||||
currentTargetSize: config.MinIdleConns,
|
||||
usageHistory: make([]float64, 0, 100), // 保留最近100个使用率记录
|
||||
adaptiveWeights: make(map[uint]float64),
|
||||
}
|
||||
|
||||
return pool
|
||||
@@ -119,7 +147,15 @@ func (p *MySQLConnectionPool) Acquire(conn *models.DbConnection) (*MySQLPoolEntr
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// 尝试从池中获取空闲连接
|
||||
// 尝试获取最优连接(启用动态调整时)
|
||||
if p.config.EnableDynamicScaling {
|
||||
if entry, err := p.getOptimalConnection(); err == nil {
|
||||
p.updateWaitStats(startTime)
|
||||
return entry, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 降级到标准逻辑 - 查找空闲连接
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
@@ -138,13 +174,13 @@ func (p *MySQLConnectionPool) Acquire(conn *models.DbConnection) (*MySQLPoolEntr
|
||||
// 没有可用连接,创建新连接
|
||||
if len(p.entries) >= p.config.MaxOpenConns {
|
||||
// 已达到最大连接数,等待
|
||||
return nil, p.waitForAvailableConnection(conn)
|
||||
return p.waitForAvailableConnection(conn)
|
||||
}
|
||||
|
||||
// 创建新连接
|
||||
// 创建新连接(使用传入的连接配置)
|
||||
newEntry, err := p.createNewEntry(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("创建连接失败: %v", err)
|
||||
}
|
||||
|
||||
p.entries = append(p.entries, newEntry)
|
||||
@@ -160,15 +196,14 @@ func (p *MySQLConnectionPool) Release(entry *MySQLPoolEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
entry.InUse = false
|
||||
entry.LastUsed = time.Now()
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
entry.mu.Lock()
|
||||
entry.InUse = false
|
||||
entry.LastUsed = time.Now()
|
||||
entry.mu.Unlock()
|
||||
|
||||
p.updateStats()
|
||||
|
||||
return nil
|
||||
@@ -240,35 +275,9 @@ func (p *MySQLConnectionPool) cleanupIdleConnections() {
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// healthCheck 健康检查
|
||||
// healthCheck 健康检查(增强版本)
|
||||
func (p *MySQLConnectionPool) healthCheck() {
|
||||
p.mu.RLock()
|
||||
entriesCopy := make([]*MySQLPoolEntry, len(p.entries))
|
||||
copy(entriesCopy, p.entries)
|
||||
p.mu.RUnlock()
|
||||
|
||||
var healthyEntries []*MySQLPoolEntry
|
||||
|
||||
for _, entry := range entriesCopy {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
// Ping 测试
|
||||
if err := entry.Client.sqlDB.Ping(); err != nil {
|
||||
// 连接失效,标记为需要关闭
|
||||
entry.mu.Unlock()
|
||||
entry.Client.Close()
|
||||
continue
|
||||
}
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
healthyEntries = append(healthyEntries, entry)
|
||||
}
|
||||
|
||||
// 更新连接池
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.entries = healthyEntries
|
||||
p.updateStats()
|
||||
p.enhancedHealthCheck()
|
||||
}
|
||||
|
||||
// StartMaintenance 启动维护协程(清理和健康检查)
|
||||
@@ -277,16 +286,28 @@ func (p *MySQLConnectionPool) StartMaintenance() {
|
||||
go func() {
|
||||
defer p.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(p.config.HealthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
// 健康检查Ticker
|
||||
healthTicker := time.NewTicker(p.config.HealthCheckInterval)
|
||||
defer healthTicker.Stop()
|
||||
|
||||
// 动态调整Ticker(较短间隔)
|
||||
scaleTicker := time.NewTicker(1 * time.Minute)
|
||||
defer scaleTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-healthTicker.C:
|
||||
// 清理空闲连接
|
||||
p.cleanupIdleConnections()
|
||||
// 健康检查
|
||||
p.healthCheck()
|
||||
|
||||
case <-scaleTicker.C:
|
||||
// 动态连接池调整
|
||||
if p.config.EnableDynamicScaling {
|
||||
p.adaptiveScaling()
|
||||
}
|
||||
|
||||
case <-p.stopCh:
|
||||
return
|
||||
}
|
||||
@@ -323,10 +344,8 @@ func (p *MySQLConnectionPool) createNewEntry(conn *models.DbConnection) (*MySQLP
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// waitForAvailableConnection 等待可用连接
|
||||
func (p *MySQLConnectionPool) waitForAvailableConnection(conn *models.DbConnection) error {
|
||||
// 实现简单的等待逻辑(使用 channel)
|
||||
// 创建一个超时上下文
|
||||
// waitForAvailableConnection 等待可用连接并获取它
|
||||
func (p *MySQLConnectionPool) waitForAvailableConnection(conn *models.DbConnection) (*MySQLPoolEntry, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -336,34 +355,29 @@ func (p *MySQLConnectionPool) waitForAvailableConnection(conn *models.DbConnecti
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrPoolExhausted
|
||||
return nil, ErrPoolExhausted
|
||||
case <-ticker.C:
|
||||
// 检查是否有可用连接
|
||||
p.mu.RLock()
|
||||
hasAvailable := false
|
||||
p.mu.Lock()
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
hasAvailable = true
|
||||
entry.InUse = true
|
||||
entry.LastUsed = time.Now()
|
||||
entry.mu.Unlock()
|
||||
break
|
||||
p.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
if hasAvailable {
|
||||
return nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateWaitStats 更新等待统计
|
||||
// updateWaitStats 更新等待统计(调用方必须持有 p.mu)
|
||||
func (p *MySQLConnectionPool) updateWaitStats(startTime time.Time) {
|
||||
waitDuration := time.Since(startTime)
|
||||
p.stats.WaitCount++
|
||||
p.stats.WaitDuration += waitDuration
|
||||
p.stats.WaitDuration += time.Since(startTime)
|
||||
}
|
||||
|
||||
// updateStats 更新连接池统计
|
||||
@@ -387,6 +401,244 @@ func (p *MySQLConnectionPool) updateStats() {
|
||||
p.stats.IdleConns = idle
|
||||
}
|
||||
|
||||
// adaptiveScaling 自适应连接池调整
|
||||
func (p *MySQLConnectionPool) adaptiveScaling() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// 计算当前使用率
|
||||
if len(p.entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
usageRate := float64(p.stats.ActiveConns) / float64(len(p.entries))
|
||||
|
||||
// 记录使用率历史
|
||||
p.usageHistory = append(p.usageHistory, usageRate)
|
||||
if len(p.usageHistory) > 100 {
|
||||
p.usageHistory = p.usageHistory[1:]
|
||||
}
|
||||
|
||||
// 检查是否需要调整
|
||||
now := time.Now()
|
||||
|
||||
// 扩容逻辑
|
||||
if usageRate >= p.config.ScaleUpThreshold {
|
||||
if now.Sub(p.lastScaleUpTime) >= p.config.MinScaleUpInterval {
|
||||
p.scaleUp()
|
||||
p.lastScaleUpTime = now
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 缩容逻辑
|
||||
if usageRate <= p.config.ScaleDownThreshold && len(p.entries) > p.config.MinIdleConns {
|
||||
if now.Sub(p.lastScaleDownTime) >= p.config.MinScaleDownInterval {
|
||||
p.scaleDown()
|
||||
p.lastScaleDownTime = now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scaleUp 扩容
|
||||
func (p *MySQLConnectionPool) scaleUp() {
|
||||
// scaleUp 仅更新目标大小,实际连接在 Acquire 时按需创建
|
||||
// 移除了创建无效虚拟连接的逻辑
|
||||
currentSize := len(p.entries)
|
||||
scaleFactor := p.config.DynamicScaleFactor
|
||||
|
||||
newSize := int(float64(currentSize) * scaleFactor)
|
||||
newSize = min(newSize, p.config.MaxOpenConns)
|
||||
newSize = max(newSize, currentSize+1)
|
||||
|
||||
p.currentTargetSize = newSize
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// scaleDown 缩容
|
||||
func (p *MySQLConnectionPool) scaleDown() {
|
||||
// 计算新目标大小
|
||||
currentSize := len(p.entries)
|
||||
scaleFactor := 1.0 / p.config.DynamicScaleFactor
|
||||
|
||||
newSize := int(float64(currentSize) * scaleFactor)
|
||||
newSize = max(newSize, p.config.MinIdleConns)
|
||||
newSize = min(newSize, currentSize-1) // 至少减少1个连接
|
||||
|
||||
if newSize < currentSize {
|
||||
// 关闭多余的空闲连接
|
||||
p.closeIdleConnections(currentSize - newSize)
|
||||
p.currentTargetSize = newSize
|
||||
p.updateStats()
|
||||
}
|
||||
}
|
||||
|
||||
// closeIdleConnections 关闭指定数量的空闲连接
|
||||
func (p *MySQLConnectionPool) closeIdleConnections(count int) {
|
||||
// 收集空闲连接
|
||||
idleEntries := make([]*MySQLPoolEntry, 0)
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
idleEntries = append(idleEntries, entry)
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
// 关闭指定数量的空闲连接
|
||||
closedEntries := make(map[*MySQLPoolEntry]bool)
|
||||
for i := 0; i < min(count, len(idleEntries)); i++ {
|
||||
entry := idleEntries[i]
|
||||
entry.mu.Lock()
|
||||
entry.Client.Close()
|
||||
entry.mu.Unlock()
|
||||
closedEntries[entry] = true
|
||||
}
|
||||
|
||||
// 重新构建连接池
|
||||
remainingEntries := make([]*MySQLPoolEntry, 0, len(p.entries))
|
||||
for _, entry := range p.entries {
|
||||
if closedEntries[entry] {
|
||||
continue // 跳过已关闭的连接
|
||||
}
|
||||
remainingEntries = append(remainingEntries, entry)
|
||||
}
|
||||
|
||||
p.entries = remainingEntries
|
||||
}
|
||||
|
||||
// enhancedHealthCheck 增强的健康检查
|
||||
func (p *MySQLConnectionPool) enhancedHealthCheck() {
|
||||
p.mu.RLock()
|
||||
entriesCopy := make([]*MySQLPoolEntry, len(p.entries))
|
||||
copy(entriesCopy, p.entries)
|
||||
p.mu.RUnlock()
|
||||
|
||||
var healthyEntries []*MySQLPoolEntry
|
||||
var performanceWeights []float64
|
||||
|
||||
for _, entry := range entriesCopy {
|
||||
entry.mu.Lock()
|
||||
isIdle := !entry.InUse
|
||||
|
||||
// 测试连接有效性
|
||||
isHealthy := true
|
||||
startTime := time.Now()
|
||||
|
||||
if isIdle {
|
||||
// 空闲连接:简单Ping测试
|
||||
if err := entry.Client.sqlDB.Ping(); err != nil {
|
||||
isHealthy = false
|
||||
// 关闭失效连接
|
||||
entry.Client.Close()
|
||||
}
|
||||
} else {
|
||||
// 使用中的连接:快速测试(避免影响正常查询)
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := entry.Client.sqlDB.PingContext(ctx); err != nil {
|
||||
isHealthy = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 计算连接性能权重
|
||||
if isHealthy {
|
||||
healthyEntries = append(healthyEntries, entry)
|
||||
|
||||
// 基于连接性能计算权重
|
||||
responseTime := time.Since(startTime).Microseconds()
|
||||
weight := 1.0 / max(float64(responseTime)/1000.0, 1.0) // 转换为毫秒,避免除零
|
||||
|
||||
performanceWeights = append(performanceWeights, weight)
|
||||
} else {
|
||||
// 不健康的连接
|
||||
if isIdle {
|
||||
entry.Client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
// 更新连接池
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.entries = healthyEntries
|
||||
|
||||
// 更新自适应权重
|
||||
if len(healthyEntries) > 0 {
|
||||
for i := range healthyEntries {
|
||||
if i < len(performanceWeights) {
|
||||
p.adaptiveWeights[uint(i)] = performanceWeights[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// warmUp 连接池预热
|
||||
func (p *MySQLConnectionPool) warmUp() {
|
||||
if !p.config.EnableWarmup {
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
currentIdle := 0
|
||||
for _, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
currentIdle++
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
targetIdle := p.config.MinIdleConns
|
||||
needed := targetIdle - currentIdle
|
||||
|
||||
// warmUp 仅记录目标大小,不在无连接配置的情况下创建无效虚拟连接
|
||||
// 实际连接在 Acquire 时按需创建
|
||||
_ = needed
|
||||
|
||||
p.updateStats()
|
||||
}
|
||||
|
||||
// getOptimalConnection 获取最优连接(基于性能权重)
|
||||
// 注意:调用方必须已持有 p.mu
|
||||
func (p *MySQLConnectionPool) getOptimalConnection() (*MySQLPoolEntry, error) {
|
||||
var bestEntry *MySQLPoolEntry
|
||||
var bestWeight float64
|
||||
|
||||
for i, entry := range p.entries {
|
||||
entry.mu.Lock()
|
||||
if !entry.InUse {
|
||||
weight := 1.0 // 默认权重
|
||||
if w, ok := p.adaptiveWeights[uint(i)]; ok {
|
||||
weight = w
|
||||
}
|
||||
|
||||
if bestEntry == nil || weight > bestWeight {
|
||||
bestEntry = entry
|
||||
bestWeight = weight
|
||||
}
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
if bestEntry == nil {
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
|
||||
bestEntry.InUse = true
|
||||
bestEntry.LastUsed = time.Now()
|
||||
return bestEntry, nil
|
||||
}
|
||||
|
||||
// createMySQLClient 创建 MySQL 客户端的辅助函数
|
||||
func createMySQLClient(conn *models.DbConnection) (*MySQLClient, error) {
|
||||
// 解密密码
|
||||
@@ -424,3 +676,4 @@ func (e *PoolError) Error() string {
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
|
||||
762
internal/dbclient/query_optimizer.go
Normal file
762
internal/dbclient/query_optimizer.go
Normal file
@@ -0,0 +1,762 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
reLimitOffset = regexp.MustCompile(`limit\s+(\d+)(?:\s*,\s*(\d+))?`)
|
||||
reFromTable = regexp.MustCompile(`(?i)from\s+([^\s,]+)`)
|
||||
reWhereClause = regexp.MustCompile(`(?i)where\s+(.*?)(?:\s+order\s+by|\s+limit|\s+group\s+by|$)`)
|
||||
reOrderBy = regexp.MustCompile(`(?i)order\s+by\s+(.*?)(?:\s+limit|$)`)
|
||||
reBatchOperation = regexp.MustCompile(`(?i)^\s*(INSERT|UPDATE|DELETE).*VALUES\s*\(`)
|
||||
)
|
||||
|
||||
// CachedQuery 缓存查询结果
|
||||
type CachedQuery struct {
|
||||
Result *QueryResult
|
||||
ExpiryTime time.Time
|
||||
CreatedAt time.Time
|
||||
QueryHash string
|
||||
QueryParams QueryParams
|
||||
LastUsed time.Time // 最后使用时间(用于LRU策略)
|
||||
AccessCount int64 // 访问次数(用于LFU策略)
|
||||
}
|
||||
|
||||
// QueryParams 查询参数(用于缓存键生成)
|
||||
type QueryParams struct {
|
||||
SQL string
|
||||
Database string
|
||||
Limit int
|
||||
Offset int
|
||||
Table string
|
||||
Where string
|
||||
SortBy string
|
||||
IsReadOnly bool
|
||||
}
|
||||
|
||||
// QueryStats 查询统计信息
|
||||
type QueryStats struct {
|
||||
TotalQueries int64
|
||||
CachedQueries int64
|
||||
SlowQueries int64
|
||||
TotalDuration time.Duration
|
||||
AverageDuration time.Duration
|
||||
CacheHitRate float64
|
||||
LastCacheUpdate time.Time
|
||||
}
|
||||
|
||||
// SlowQuery 慢查询记录
|
||||
type SlowQuery struct {
|
||||
Query string
|
||||
Database string
|
||||
Duration time.Duration
|
||||
Timestamp time.Time
|
||||
Params QueryParams
|
||||
Table string
|
||||
IndexUsed string
|
||||
RowsAffected int64
|
||||
Error error
|
||||
}
|
||||
|
||||
// IndexSuggestion 索引建议
|
||||
type IndexSuggestion struct {
|
||||
Table string
|
||||
Columns []string
|
||||
IndexType string // "normal", "unique", "fulltext"
|
||||
Priority string // "high", "medium", "low"
|
||||
Query string
|
||||
Justification string
|
||||
CanBeApplied bool
|
||||
}
|
||||
|
||||
// QueryOptimizer 查询优化器
|
||||
type QueryOptimizer struct {
|
||||
cache *QueryCache
|
||||
stats *QueryStats
|
||||
slowQueries []SlowQuery
|
||||
indexSuggestions []IndexSuggestion
|
||||
mu sync.RWMutex
|
||||
config *OptimizerConfig
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// OptimizerConfig 查询优化器配置
|
||||
type OptimizerConfig struct {
|
||||
// 缓存配置
|
||||
CacheSize int // 最大缓存条目数
|
||||
CacheTTL time.Duration // 缓存过期时间
|
||||
EnableCache bool // 是否启用缓存
|
||||
|
||||
// 慢查询配置
|
||||
SlowQueryThreshold time.Duration // 慢查询阈值
|
||||
EnableSlowLog bool // 是否启用慢查询日志
|
||||
MaxSlowLogs int // 最大慢查询记录数
|
||||
|
||||
// 索引建议配置
|
||||
EnableIndexSuggestions bool // 是否启用索引建议
|
||||
MaxSuggestions int // 最大索引建议数
|
||||
|
||||
// 查询分析配置
|
||||
EnableQueryAnalysis bool // 是否启用查询分析
|
||||
MaxAnalysisDepth int // 查询分析深度
|
||||
}
|
||||
|
||||
// DefaultOptimizerConfig 返回默认的查询优化器配置
|
||||
func DefaultOptimizerConfig() *OptimizerConfig {
|
||||
return &OptimizerConfig{
|
||||
CacheSize: 1000, // 最多缓存1000个查询
|
||||
CacheTTL: 30 * time.Minute, // 缓存30分钟
|
||||
EnableCache: true, // 启用缓存
|
||||
SlowQueryThreshold: 100 * time.Millisecond, // 100ms以上为慢查询
|
||||
EnableSlowLog: true, // 启用慢查询日志
|
||||
MaxSlowLogs: 1000, // 最多记录1000条慢查询
|
||||
EnableIndexSuggestions: true, // 启用索引建议
|
||||
MaxSuggestions: 100, // 最多100个索引建议
|
||||
EnableQueryAnalysis: true, // 启用查询分析
|
||||
MaxAnalysisDepth: 3, // 分析深度3
|
||||
}
|
||||
}
|
||||
|
||||
// NewQueryOptimizer 创建新的查询优化器
|
||||
func NewQueryOptimizer(config *OptimizerConfig) *QueryOptimizer {
|
||||
if config == nil {
|
||||
config = DefaultOptimizerConfig()
|
||||
}
|
||||
|
||||
optimizer := &QueryOptimizer{
|
||||
cache: NewQueryCache(config.CacheSize, config.CacheTTL),
|
||||
stats: &QueryStats{},
|
||||
config: config,
|
||||
stopCh: make(chan struct{}),
|
||||
slowQueries: make([]SlowQuery, 0),
|
||||
indexSuggestions: make([]IndexSuggestion, 0),
|
||||
}
|
||||
|
||||
// 启动维护协程
|
||||
optimizer.StartMaintenance()
|
||||
|
||||
return optimizer
|
||||
}
|
||||
|
||||
// OptimizeQuery 优化查询执行
|
||||
func (o *QueryOptimizer) OptimizeQuery(ctx context.Context, client *MySQLClient, sqlStr string, database string) (*QueryResult, time.Duration, error) {
|
||||
startTime := time.Now()
|
||||
queryParams := o.parseQueryParams(sqlStr, database)
|
||||
|
||||
// 检查缓存
|
||||
if o.config.EnableCache && queryParams.IsReadOnly {
|
||||
cached, err := o.cache.Get(queryParams)
|
||||
if err == nil && cached != nil {
|
||||
o.recordCacheHit()
|
||||
return cached.Result, time.Since(startTime), nil
|
||||
}
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
result, err := client.ExecuteQuery(ctx, sqlStr, database)
|
||||
if err != nil {
|
||||
duration := time.Since(startTime)
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, result, err)
|
||||
return nil, duration, err
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// 检查是否为慢查询
|
||||
if duration > o.config.SlowQueryThreshold {
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, result, err)
|
||||
}
|
||||
|
||||
// 缓存只读查询结果
|
||||
if o.config.EnableCache && queryParams.IsReadOnly && err == nil {
|
||||
cachedResult := &CachedQuery{
|
||||
Result: result,
|
||||
ExpiryTime: time.Now().Add(o.config.CacheTTL),
|
||||
CreatedAt: time.Now(),
|
||||
QueryHash: o.generateQueryHash(queryParams),
|
||||
QueryParams: queryParams,
|
||||
LastUsed: time.Now(),
|
||||
AccessCount: 1,
|
||||
}
|
||||
o.cache.Set(queryParams, cachedResult)
|
||||
}
|
||||
|
||||
o.recordQuery(duration)
|
||||
return result, duration, err
|
||||
}
|
||||
|
||||
// ExecuteOptimizedUpdate 执行优化的更新操作
|
||||
func (o *QueryOptimizer) ExecuteOptimizedUpdate(ctx context.Context, client *MySQLClient, sqlStr string, database string) (int64, time.Duration, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 分析更新查询
|
||||
queryParams := o.parseQueryParams(sqlStr, database)
|
||||
|
||||
// 检查是否为批量操作
|
||||
if o.isBatchOperation(sqlStr) {
|
||||
// 优化批量操作
|
||||
rowsAffected, duration, err := o.optimizeBatchUpdate(ctx, client, sqlStr, database)
|
||||
if err != nil {
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, nil, err)
|
||||
return 0, duration, err
|
||||
}
|
||||
|
||||
o.recordQuery(duration)
|
||||
return rowsAffected, duration, nil
|
||||
}
|
||||
|
||||
// 执行普通更新
|
||||
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
|
||||
if duration > o.config.SlowQueryThreshold {
|
||||
o.recordSlowQuery(sqlStr, database, duration, queryParams, nil, err)
|
||||
}
|
||||
|
||||
o.recordQuery(duration)
|
||||
return rowsAffected, duration, err
|
||||
}
|
||||
|
||||
// GetIndexSuggestions 获取索引建议
|
||||
func (o *QueryOptimizer) GetIndexSuggestions(table string) []IndexSuggestion {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
var suggestions []IndexSuggestion
|
||||
for _, suggestion := range o.indexSuggestions {
|
||||
if suggestion.Table == table || table == "" {
|
||||
suggestions = append(suggestions, suggestion)
|
||||
}
|
||||
}
|
||||
return suggestions
|
||||
}
|
||||
|
||||
// GenerateIndexSuggestions 为表生成索引建议
|
||||
func (o *QueryOptimizer) GenerateIndexSuggestions(ctx context.Context, client *MySQLClient, database, table string) error {
|
||||
// 获取表的慢查询记录
|
||||
tableSlowQueries := o.getTableSlowQueries(database, table)
|
||||
|
||||
// 分析查询模式
|
||||
for _, slowQuery := range tableSlowQueries {
|
||||
suggestions := o.analyzeQueryForIndexes(slowQuery.Query, table)
|
||||
o.mu.Lock()
|
||||
o.indexSuggestions = append(o.indexSuggestions, suggestions...)
|
||||
|
||||
// 限制建议数量
|
||||
if len(o.indexSuggestions) > o.config.MaxSuggestions {
|
||||
o.indexSuggestions = o.indexSuggestions[:o.config.MaxSuggestions]
|
||||
}
|
||||
o.mu.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetQueryStats 获取查询统计信息
|
||||
func (o *QueryOptimizer) GetQueryStats() QueryStats {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
return *o.stats
|
||||
}
|
||||
|
||||
// GetSlowQueries 获取慢查询记录
|
||||
func (o *QueryOptimizer) GetSlowQueries(limit int) []SlowQuery {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
if limit <= 0 || limit > len(o.slowQueries) {
|
||||
limit = len(o.slowQueries)
|
||||
}
|
||||
|
||||
return o.slowQueries[:limit]
|
||||
}
|
||||
|
||||
// ClearCache 清空缓存
|
||||
func (o *QueryOptimizer) ClearCache() {
|
||||
o.cache.Clear()
|
||||
}
|
||||
|
||||
// Stop 停止优化器
|
||||
func (o *QueryOptimizer) Stop() {
|
||||
close(o.stopCh)
|
||||
o.wg.Wait()
|
||||
}
|
||||
|
||||
// parseQueryParams 解析查询参数
|
||||
func (o *QueryOptimizer) parseQueryParams(sqlStr, database string) QueryParams {
|
||||
params := QueryParams{
|
||||
SQL: sqlStr,
|
||||
Database: database,
|
||||
}
|
||||
|
||||
// 解析LIMIT和OFFSET
|
||||
limit, offset := o.parseLimitOffset(sqlStr)
|
||||
params.Limit = limit
|
||||
params.Offset = offset
|
||||
|
||||
// 解析表名
|
||||
tables := o.parseTables(sqlStr)
|
||||
if len(tables) > 0 {
|
||||
params.Table = tables[0]
|
||||
}
|
||||
|
||||
// 解析WHERE条件
|
||||
where := o.parseWhereCondition(sqlStr)
|
||||
params.Where = where
|
||||
|
||||
// 解析排序
|
||||
sort := o.parseSortOrder(sqlStr)
|
||||
params.SortBy = sort
|
||||
|
||||
// 判断是否为只读查询
|
||||
params.IsReadOnly = o.isReadOnlyQuery(sqlStr)
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// parseLimitOffset 解析LIMIT和OFFSET
|
||||
func (o *QueryOptimizer) parseLimitOffset(sqlStr string) (limit, offset int) {
|
||||
sqlStr = strings.ToLower(sqlStr)
|
||||
|
||||
matches := reLimitOffset.FindStringSubmatch(sqlStr)
|
||||
|
||||
if len(matches) > 1 {
|
||||
fmt.Sscanf(matches[1], "%d", &limit)
|
||||
if len(matches) > 2 && matches[2] != "" {
|
||||
fmt.Sscanf(matches[2], "%d", &offset)
|
||||
}
|
||||
}
|
||||
|
||||
// MySQL LIMIT offset, count: matches[1]=offset, matches[2]=count
|
||||
if len(matches) > 2 && matches[2] != "" {
|
||||
offset, limit = limit, offset
|
||||
}
|
||||
|
||||
return limit, offset
|
||||
}
|
||||
|
||||
// parseTables 解析查询中的表名
|
||||
func (o *QueryOptimizer) parseTables(sqlStr string) []string {
|
||||
// 简单实现:解析FROM和JOIN中的表名
|
||||
tables := make([]string, 0)
|
||||
|
||||
fromMatches := reFromTable.FindAllStringSubmatch(sqlStr, -1)
|
||||
|
||||
for _, match := range fromMatches {
|
||||
if len(match) > 1 {
|
||||
tableName := strings.Trim(match[1], "`\"'[]")
|
||||
tables = append(tables, tableName)
|
||||
}
|
||||
}
|
||||
|
||||
return tables
|
||||
}
|
||||
|
||||
// parseWhereCondition 解析WHERE条件
|
||||
func (o *QueryOptimizer) parseWhereCondition(sqlStr string) string {
|
||||
matches := reWhereClause.FindStringSubmatch(sqlStr)
|
||||
|
||||
if len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseSortOrder 解析排序条件
|
||||
func (o *QueryOptimizer) parseSortOrder(sqlStr string) string {
|
||||
matches := reOrderBy.FindStringSubmatch(sqlStr)
|
||||
|
||||
if len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isReadOnlyQuery 判断是否为只读查询
|
||||
func (o *QueryOptimizer) isReadOnlyQuery(sqlStr string) bool {
|
||||
sqlStr = strings.ToUpper(strings.TrimSpace(sqlStr))
|
||||
|
||||
// SELECT只读查询
|
||||
if strings.HasPrefix(sqlStr, "SELECT") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 支持的只读查询类型
|
||||
readOnlyQueries := []string{
|
||||
"SHOW", "DESCRIBE", "DESC", "EXPLAIN",
|
||||
"WITH", "UNION", "INTERSECT", "EXCEPT",
|
||||
}
|
||||
|
||||
for _, query := range readOnlyQueries {
|
||||
if strings.HasPrefix(sqlStr, query) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isBatchOperation 判断是否为批量操作
|
||||
func (o *QueryOptimizer) isBatchOperation(sqlStr string) bool {
|
||||
return reBatchOperation.MatchString(sqlStr)
|
||||
}
|
||||
|
||||
// generateQueryHash 生成查询哈希
|
||||
func (o *QueryOptimizer) generateQueryHash(params QueryParams) string {
|
||||
hashData := fmt.Sprintf("%s|%s|%d|%d|%s|%s|%s|%v",
|
||||
params.SQL, params.Database, params.Limit, params.Offset,
|
||||
params.Table, params.Where, params.SortBy, params.IsReadOnly)
|
||||
|
||||
h := sha256.Sum256([]byte(hashData))
|
||||
return fmt.Sprintf("%x", h)
|
||||
}
|
||||
|
||||
// recordQuery 记录查询统计
|
||||
func (o *QueryOptimizer) recordQuery(duration time.Duration) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.stats.TotalQueries++
|
||||
o.stats.TotalDuration += duration
|
||||
o.stats.AverageDuration = time.Duration(int64(float64(o.stats.TotalDuration) / float64(o.stats.TotalQueries)))
|
||||
|
||||
now := time.Now()
|
||||
if o.stats.LastCacheUpdate.IsZero() || now.Sub(o.stats.LastCacheUpdate) > 5*time.Minute {
|
||||
// 更新缓存命中率
|
||||
total := o.stats.TotalQueries
|
||||
hit := o.stats.CachedQueries
|
||||
o.stats.CacheHitRate = float64(hit) / float64(total) * 100
|
||||
o.stats.LastCacheUpdate = now
|
||||
}
|
||||
}
|
||||
|
||||
// recordCacheHit 记录缓存命中
|
||||
func (o *QueryOptimizer) recordCacheHit() {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.stats.CachedQueries++
|
||||
}
|
||||
|
||||
// recordSlowQuery 记录慢查询
|
||||
func (o *QueryOptimizer) recordSlowQuery(query, database string, duration time.Duration, params QueryParams, result *QueryResult, err error) {
|
||||
if !o.config.EnableSlowLog {
|
||||
return
|
||||
}
|
||||
|
||||
slowQuery := SlowQuery{
|
||||
Query: query,
|
||||
Database: database,
|
||||
Duration: duration,
|
||||
Timestamp: time.Now(),
|
||||
Params: params,
|
||||
Table: params.Table,
|
||||
IndexUsed: o.extractIndexUsed(query),
|
||||
RowsAffected: o.extractRowsAffected(result),
|
||||
Error: err,
|
||||
}
|
||||
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.slowQueries = append(o.slowQueries, slowQuery)
|
||||
|
||||
// 限制慢查询记录数量
|
||||
if len(o.slowQueries) > o.config.MaxSlowLogs {
|
||||
o.slowQueries = o.slowQueries[1:]
|
||||
}
|
||||
|
||||
o.stats.SlowQueries++
|
||||
}
|
||||
|
||||
// extractIndexUsed 提取使用的索引
|
||||
func (o *QueryOptimizer) extractIndexUsed(query string) string {
|
||||
// 简单实现:从EXPLAIN结果中提取索引信息
|
||||
// 实际项目中应该执行EXPLAIN语句分析
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// extractRowsAffected 提取影响的行数
|
||||
func (o *QueryOptimizer) extractRowsAffected(result *QueryResult) int64 {
|
||||
if result != nil && len(result.Data) > 0 {
|
||||
if rows, ok := result.Data[0]["rows_affected"].(int64); ok {
|
||||
return rows
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// analyzeQuery 分析查询性能
|
||||
func (o *QueryOptimizer) analyzeQuery(query, database string, result *QueryResult, duration time.Duration) {
|
||||
// 这里可以实现更复杂的查询分析逻辑
|
||||
// 比如分析查询计划、检测N+1查询问题等
|
||||
|
||||
// 简单实现:记录查询到统计信息中
|
||||
_ = query
|
||||
_ = database
|
||||
_ = result
|
||||
_ = duration
|
||||
}
|
||||
|
||||
// analyzeQueryForIndexes 分析查询为索引建议
|
||||
func (o *QueryOptimizer) analyzeQueryForIndexes(query, table string) []IndexSuggestion {
|
||||
var suggestions []IndexSuggestion
|
||||
|
||||
// 解析查询中的WHERE条件
|
||||
where := o.parseWhereCondition(query)
|
||||
if where != "" {
|
||||
// 提取WHERE条件中的列
|
||||
columns := o.extractColumnsFromWhere(where)
|
||||
|
||||
if len(columns) > 0 {
|
||||
// 创建索引建议
|
||||
suggestion := IndexSuggestion{
|
||||
Table: table,
|
||||
Columns: columns,
|
||||
IndexType: "normal",
|
||||
Priority: "medium",
|
||||
Query: query,
|
||||
Justification: fmt.Sprintf("查询经常使用WHERE条件 %s", where),
|
||||
CanBeApplied: true,
|
||||
}
|
||||
suggestions = append(suggestions, suggestion)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析ORDER BY条件
|
||||
order := o.parseSortOrder(query)
|
||||
if order != "" {
|
||||
// 提取排序的列
|
||||
columns := o.extractColumnsFromOrder(order)
|
||||
|
||||
if len(columns) > 0 {
|
||||
// 创建排序索引建议
|
||||
suggestion := IndexSuggestion{
|
||||
Table: table,
|
||||
Columns: columns,
|
||||
IndexType: "normal",
|
||||
Priority: "low",
|
||||
Query: query,
|
||||
Justification: fmt.Sprintf("查询经常使用ORDER BY %s", order),
|
||||
CanBeApplied: true,
|
||||
}
|
||||
suggestions = append(suggestions, suggestion)
|
||||
}
|
||||
}
|
||||
|
||||
return suggestions
|
||||
}
|
||||
|
||||
// extractColumnsFromWhere 从WHERE条件中提取列名
|
||||
func (o *QueryOptimizer) extractColumnsFromWhere(where string) []string {
|
||||
// 简单实现:提取WHERE条件中的列名
|
||||
columns := make([]string, 0)
|
||||
|
||||
// 这里可以实现更复杂的列名解析逻辑
|
||||
// 目前只做简单处理
|
||||
words := strings.Fields(where)
|
||||
for _, word := range words {
|
||||
// 去除运算符和引号
|
||||
if !strings.Contains(word, "=") &&
|
||||
!strings.Contains(word, ">") &&
|
||||
!strings.Contains(word, "<") &&
|
||||
!strings.Contains(word, "!=") &&
|
||||
!strings.HasPrefix(word, "'") &&
|
||||
!strings.HasPrefix(word, "\"") {
|
||||
columns = append(columns, strings.Trim(word, " `\"'[]"))
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// extractColumnsFromOrder 从ORDER BY条件中提取列名
|
||||
func (o *QueryOptimizer) extractColumnsFromOrder(order string) []string {
|
||||
// 简单实现:提取ORDER BY中的列名
|
||||
columns := strings.Split(order, ",")
|
||||
for i, col := range columns {
|
||||
columns[i] = strings.TrimSpace(strings.Split(col, " ")[0])
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// getTableSlowQueries 获取表的慢查询记录
|
||||
func (o *QueryOptimizer) getTableSlowQueries(database, table string) []SlowQuery {
|
||||
o.mu.RLock()
|
||||
defer o.mu.RUnlock()
|
||||
|
||||
var tableQueries []SlowQuery
|
||||
for _, query := range o.slowQueries {
|
||||
if (database == "" || query.Database == database) &&
|
||||
(table == "" || query.Table == table) {
|
||||
tableQueries = append(tableQueries, query)
|
||||
}
|
||||
}
|
||||
return tableQueries
|
||||
}
|
||||
|
||||
// optimizeBatchUpdate 优化批量更新操作
|
||||
func (o *QueryOptimizer) optimizeBatchUpdate(ctx context.Context, client *MySQLClient, sqlStr string, database string) (int64, time.Duration, error) {
|
||||
// 简单实现:执行原始查询
|
||||
// 实际项目中可以实现批量操作优化
|
||||
startTime := time.Now()
|
||||
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, database)
|
||||
duration := time.Since(startTime)
|
||||
return rowsAffected, duration, err
|
||||
}
|
||||
|
||||
// StartMaintenance 启动维护协程
|
||||
func (o *QueryOptimizer) StartMaintenance() {
|
||||
o.wg.Add(1)
|
||||
go func() {
|
||||
defer o.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 清理过期的缓存
|
||||
o.cache.CleanupExpired()
|
||||
|
||||
// 分析慢查询生成新的索引建议
|
||||
o.analyzeSlowQueriesForSuggestions()
|
||||
|
||||
case <-o.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// RecordPoolError 记录连接池错误
|
||||
func (o *QueryOptimizer) RecordPoolError(operation string, err error) {
|
||||
if !o.config.EnableSlowLog || err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
poolError := SlowQuery{
|
||||
Query: operation,
|
||||
Database: "pool",
|
||||
Duration: 0,
|
||||
Timestamp: time.Now(),
|
||||
Params: QueryParams{SQL: operation},
|
||||
Table: "connection_pool",
|
||||
IndexUsed: "N/A",
|
||||
RowsAffected: 0,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.slowQueries = append(o.slowQueries, poolError)
|
||||
|
||||
// 限制慢查询记录数量
|
||||
if len(o.slowQueries) > o.config.MaxSlowLogs {
|
||||
o.slowQueries = o.slowQueries[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// analyzeSlowQueriesForSuggestions 分析慢查询生成索引建议
|
||||
func (o *QueryOptimizer) analyzeSlowQueriesForSuggestions() {
|
||||
// 这里可以实现更复杂的慢查询分析逻辑
|
||||
// 比如分析查询模式、统计索引使用情况等
|
||||
|
||||
// 分析慢查询模式
|
||||
o.analyzeSlowQueryPatterns()
|
||||
}
|
||||
|
||||
// analyzeSlowQueryPatterns 分析慢查询模式
|
||||
func (o *QueryOptimizer) analyzeSlowQueryPatterns() {
|
||||
o.mu.RLock()
|
||||
queryTypes := make(map[string]int)
|
||||
tableQueries := make(map[string]int)
|
||||
|
||||
for _, query := range o.slowQueries {
|
||||
queryType := o.detectQueryType(query.Query)
|
||||
queryTypes[queryType]++
|
||||
|
||||
if query.Table != "" {
|
||||
tableQueries[query.Table]++
|
||||
}
|
||||
}
|
||||
o.mu.RUnlock()
|
||||
|
||||
// 根据统计结果生成智能建议(在锁外执行,避免死锁)
|
||||
o.generateSmartSuggestions(queryTypes, tableQueries)
|
||||
}
|
||||
|
||||
// detectQueryType 检测查询类型
|
||||
func (o *QueryOptimizer) detectQueryType(sqlStr string) string {
|
||||
sqlStr = strings.ToUpper(strings.TrimSpace(sqlStr))
|
||||
|
||||
if strings.HasPrefix(sqlStr, "SELECT") {
|
||||
if strings.Contains(sqlStr, "JOIN") {
|
||||
return "SELECT_JOIN"
|
||||
} else if strings.Contains(sqlStr, "GROUP BY") {
|
||||
return "SELECT_GROUP"
|
||||
} else {
|
||||
return "SELECT_SIMPLE"
|
||||
}
|
||||
} else if strings.HasPrefix(sqlStr, "INSERT") {
|
||||
return "INSERT"
|
||||
} else if strings.HasPrefix(sqlStr, "UPDATE") {
|
||||
return "UPDATE"
|
||||
} else if strings.HasPrefix(sqlStr, "DELETE") {
|
||||
return "DELETE"
|
||||
}
|
||||
|
||||
return "OTHER"
|
||||
}
|
||||
|
||||
// generateSmartSuggestions 生成智能建议
|
||||
func (o *QueryOptimizer) generateSmartSuggestions(queryTypes map[string]int, tableQueries map[string]int) {
|
||||
// 分析频繁执行的查询类型
|
||||
var mostFrequentType string
|
||||
var maxCount int
|
||||
|
||||
for queryType, count := range queryTypes {
|
||||
if count > maxCount {
|
||||
maxCount = count
|
||||
mostFrequentType = queryType
|
||||
}
|
||||
}
|
||||
|
||||
// 生成针对性的索引建议
|
||||
switch mostFrequentType {
|
||||
case "SELECT_JOIN":
|
||||
// 为JOIN查询建议复合索引
|
||||
o.generateJoinSuggestions()
|
||||
case "SELECT_GROUP":
|
||||
// 为GROUP BY查询建议索引
|
||||
o.generateGroupSuggestions()
|
||||
case "INSERT":
|
||||
// 为批量插入建议优化
|
||||
o.generateInsertSuggestions()
|
||||
}
|
||||
}
|
||||
|
||||
// generateJoinSuggestions 生成JOIN查询建议
|
||||
func (o *QueryOptimizer) generateJoinSuggestions() {
|
||||
}
|
||||
|
||||
// generateGroupSuggestions 生成GROUP BY查询建议
|
||||
func (o *QueryOptimizer) generateGroupSuggestions() {
|
||||
}
|
||||
|
||||
// generateInsertSuggestions 生成批量插入建议
|
||||
func (o *QueryOptimizer) generateInsertSuggestions() {
|
||||
}
|
||||
151
internal/dbclient/redis_pipeline.go
Normal file
151
internal/dbclient/redis_pipeline.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package dbclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisPipeline Redis Pipeline 操作
|
||||
type RedisPipeline struct {
|
||||
client *RedisClient
|
||||
commands []RedisCommand
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// RedisCommand Redis 命令结构
|
||||
type RedisCommand struct {
|
||||
Command string
|
||||
Args []interface{}
|
||||
Result interface{}
|
||||
Error error
|
||||
}
|
||||
|
||||
// NewRedisPipeline 创建新的 Redis Pipeline
|
||||
func (r *RedisClient) NewPipeline(ctx context.Context) *RedisPipeline {
|
||||
return &RedisPipeline{
|
||||
client: r,
|
||||
commands: make([]RedisCommand, 0),
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// AddCommand 添加命令到 Pipeline
|
||||
func (p *RedisPipeline) AddCommand(command string, args ...interface{}) {
|
||||
p.commands = append(p.commands, RedisCommand{
|
||||
Command: command,
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute 使用 go-redis 原生 Pipeline 执行所有命令
|
||||
func (p *RedisPipeline) Execute() ([]interface{}, error) {
|
||||
if len(p.commands) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pipe := p.client.client.Pipeline()
|
||||
|
||||
cmds := make([]*redis.Cmd, len(p.commands))
|
||||
for i, c := range p.commands {
|
||||
cmds[i] = pipe.Do(p.ctx, append([]interface{}{c.Command}, c.Args...)...)
|
||||
}
|
||||
|
||||
// 一次性发送所有命令
|
||||
results := make([]interface{}, len(p.commands))
|
||||
cmdResults, err := pipe.Exec(p.ctx)
|
||||
if err != nil && err != redis.Nil {
|
||||
log.Printf("[RedisPipeline] Exec 错误: %v", err)
|
||||
}
|
||||
|
||||
for i, cmd := range cmds {
|
||||
result, cmdErr := cmd.Result()
|
||||
results[i] = result
|
||||
p.commands[i].Result = result
|
||||
p.commands[i].Error = cmdErr
|
||||
}
|
||||
|
||||
// 如果 Exec 返回了命令结果(部分 Redis 版本),使用它们
|
||||
for i, cr := range cmdResults {
|
||||
if cr.Err() != nil && cr.Err() != redis.Nil {
|
||||
p.commands[i].Error = cr.Err()
|
||||
if i < len(results) {
|
||||
results[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = results // 已经通过 cmds 获取
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetCommands 获取 Pipeline 中的命令列表
|
||||
func (p *RedisPipeline) GetCommands() []RedisCommand {
|
||||
return p.commands
|
||||
}
|
||||
|
||||
// Len 获取 Pipeline 中的命令数量
|
||||
func (p *RedisPipeline) Len() int {
|
||||
return len(p.commands)
|
||||
}
|
||||
|
||||
// Clear 清空 Pipeline
|
||||
func (p *RedisPipeline) Clear() {
|
||||
p.commands = make([]RedisCommand, 0)
|
||||
}
|
||||
|
||||
// RedisTransaction Redis 事务支持
|
||||
type RedisTransaction struct {
|
||||
client *RedisClient
|
||||
watch []string
|
||||
cmds []RedisCommand
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRedisTransaction 创建新的 Redis 事务
|
||||
func (r *RedisClient) NewTransaction(ctx context.Context, watch ...string) *RedisTransaction {
|
||||
return &RedisTransaction{
|
||||
client: r,
|
||||
watch: watch,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// AddCommand 添加命令到事务
|
||||
func (tx *RedisTransaction) AddCommand(command string, args ...interface{}) {
|
||||
tx.cmds = append(tx.cmds, RedisCommand{
|
||||
Command: command,
|
||||
Args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Exec 使用 go-redis Watch + TxPipeline 执行事务(MULTI/EXEC)
|
||||
func (tx *RedisTransaction) Exec() ([]interface{}, error) {
|
||||
pipe := tx.client.client.TxPipeline()
|
||||
|
||||
// 添加所有命令
|
||||
cmds := make([]*redis.Cmd, len(tx.cmds))
|
||||
for i, c := range tx.cmds {
|
||||
cmds[i] = pipe.Do(tx.ctx, append([]interface{}{c.Command}, c.Args...)...)
|
||||
}
|
||||
|
||||
// TxPipeline 自动发送 MULTI/EXEC
|
||||
results := make([]interface{}, len(tx.cmds))
|
||||
_, err := pipe.Exec(tx.ctx)
|
||||
|
||||
for i, cmd := range cmds {
|
||||
result, cmdErr := cmd.Result()
|
||||
results[i] = result
|
||||
tx.cmds[i].Result = result
|
||||
tx.cmds[i].Error = cmdErr
|
||||
}
|
||||
|
||||
if err != nil && err != redis.Nil {
|
||||
return results, fmt.Errorf("事务执行失败: %v", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -43,8 +43,10 @@ var defaultTabConfig = TabConfig{
|
||||
AvailableTabs: []TabDefinition{
|
||||
{Key: "file-system", Title: "文件管理", Enabled: true},
|
||||
{Key: "db-cli", Title: "数据库", Enabled: true},
|
||||
{Key: "markdown-editor", Title: "Markdown", Enabled: true},
|
||||
{Key: "openclaw-manager", Title: "OpenClaw", Enabled: true},
|
||||
},
|
||||
VisibleTabs: []string{"file-system", "db-cli"},
|
||||
VisibleTabs: []string{"file-system", "db-cli", "markdown-editor", "openclaw-manager"},
|
||||
DefaultTab: "file-system",
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"u-desk/internal/crypto"
|
||||
"u-desk/internal/dbclient"
|
||||
"u-desk/internal/storage/models"
|
||||
"u-desk/internal/storage/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ConnectionService 连接管理服务
|
||||
@@ -90,8 +94,20 @@ func (s *ConnectionService) GetConnection(id uint) (*models.DbConnection, error)
|
||||
return s.repo.FindByID(id)
|
||||
}
|
||||
|
||||
// DeleteConnection 删除连接配置
|
||||
// DeleteConnection 删除连接配置(含关联数据和连接池清理)
|
||||
func (s *ConnectionService) DeleteConnection(id uint) error {
|
||||
conn, err := s.repo.FindByID(id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil // 连接不存在视为成功
|
||||
}
|
||||
return fmt.Errorf("获取连接配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 关闭连接池中的连接
|
||||
dbclient.GetPool().CloseConnection(id, conn.Type)
|
||||
|
||||
// 删除连接记录
|
||||
return s.repo.Delete(id)
|
||||
}
|
||||
|
||||
@@ -185,3 +201,68 @@ func (s *ConnectionService) TestConnectionWithParams(connType, host string, port
|
||||
return fmt.Errorf("不支持的数据库类型: %s", connType)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAllDatabases 加载全部数据库列表
|
||||
func (s *ConnectionService) LoadAllDatabases(dbType, host string, port int, username, password, database, options string, existingId uint) ([]string, error) {
|
||||
// 如果是编辑模式且密码为空,尝试获取已保存的密码
|
||||
actualPassword := password
|
||||
if existingId > 0 && password == "" {
|
||||
conn, err := s.repo.FindByID(existingId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取原连接配置失败: %v", err)
|
||||
}
|
||||
actualPassword, err = crypto.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码解密失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 MongoDB 选项
|
||||
authSource := ""
|
||||
authMechanism := ""
|
||||
if options != "" {
|
||||
var opts map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(options), &opts); err == nil {
|
||||
authSource, _ = opts["authSource"].(string)
|
||||
authMechanism, _ = opts["authMechanism"].(string)
|
||||
}
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
return loadDatabasesForMySQL(host, port, username, actualPassword, database)
|
||||
case "mongo":
|
||||
return loadDatabasesForMongo(host, port, username, actualPassword, database, authSource, authMechanism)
|
||||
case "redis":
|
||||
return []string{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的数据库类型: %s", dbType)
|
||||
}
|
||||
}
|
||||
|
||||
func loadDatabasesForMySQL(host string, port int, username, password, defaultDatabase string) ([]string, error) {
|
||||
config := &dbclient.MySQLConfig{
|
||||
Host: host, Port: port, Username: username,
|
||||
Password: password, Database: defaultDatabase,
|
||||
}
|
||||
client, err := dbclient.NewMySQLClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
return client.ListDatabases(context.Background())
|
||||
}
|
||||
|
||||
func loadDatabasesForMongo(host string, port int, username, password, defaultDatabase, authSource, authMechanism string) ([]string, error) {
|
||||
config := &dbclient.MongoConfig{
|
||||
Host: host, Port: port, Username: username,
|
||||
Password: password, Database: defaultDatabase,
|
||||
AuthSource: authSource, AuthMechanism: authMechanism,
|
||||
}
|
||||
client, err := dbclient.NewMongoClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
return client.ListDatabases(context.Background())
|
||||
}
|
||||
|
||||
@@ -66,10 +66,11 @@ func (s *SqlExecService) ExecuteSQL(connectionID uint, sqlStr string, database s
|
||||
|
||||
// executeMySQL 执行MySQL SQL
|
||||
func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnection, sqlStr string, database string, startTime time.Time) (*SqlResult, error) {
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
defer pc.Release()
|
||||
|
||||
sqlStr = strings.TrimSpace(sqlStr)
|
||||
sqlUpper := strings.ToUpper(sqlStr)
|
||||
@@ -89,7 +90,7 @@ func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnec
|
||||
strings.HasPrefix(sqlUpper, "DESCRIBE") || strings.HasPrefix(sqlUpper, "DESC") ||
|
||||
strings.HasPrefix(sqlUpper, "EXPLAIN") {
|
||||
// 查询语句
|
||||
queryResult, err := client.ExecuteQuery(ctx, sqlStr, dbName)
|
||||
queryResult, err := pc.Client.ExecuteQuery(ctx, sqlStr, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -99,7 +100,7 @@ func (s *SqlExecService) executeMySQL(ctx context.Context, conn *models.DbConnec
|
||||
result.RowsAffected = len(queryResult.Data)
|
||||
} else {
|
||||
// 更新语句
|
||||
rowsAffected, err := client.ExecuteUpdate(ctx, sqlStr, dbName)
|
||||
rowsAffected, err := pc.Client.ExecuteUpdate(ctx, sqlStr, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -220,11 +221,12 @@ func (s *SqlExecService) GetDatabases(connectionID uint) ([]string, error) {
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
return client.ListDatabases(ctx)
|
||||
defer pc.Release()
|
||||
return pc.Client.ListDatabases(ctx)
|
||||
case "redis":
|
||||
databases := make([]string, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
@@ -254,11 +256,12 @@ func (s *SqlExecService) GetTables(connectionID uint, database string) ([]string
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
return client.ListTables(ctx, database)
|
||||
defer pc.Release()
|
||||
return pc.Client.ListTables(ctx, database)
|
||||
case "redis":
|
||||
client, err := s.pool.GetRedisClient(conn)
|
||||
if err != nil {
|
||||
@@ -305,7 +308,7 @@ func parseRedisCommand(cmd string) []string {
|
||||
} else {
|
||||
if char == quoteChar {
|
||||
inQuotes = false
|
||||
quoteChar = 0
|
||||
quoteChar = byte(0)
|
||||
} else {
|
||||
current.WriteByte(char)
|
||||
}
|
||||
@@ -330,11 +333,12 @@ func (s *SqlExecService) GetTableStructure(connectionID uint, database, tableNam
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
structure, err := client.GetTableStructure(ctx, database, tableName)
|
||||
defer pc.Release()
|
||||
structure, err := pc.Client.GetTableStructure(ctx, database, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -393,11 +397,12 @@ func (s *SqlExecService) GetIndexes(connectionID uint, database, tableName strin
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
return client.GetIndexes(ctx, database, tableName)
|
||||
defer pc.Release()
|
||||
return pc.Client.GetIndexes(ctx, database, tableName)
|
||||
|
||||
case "mongo", "redis":
|
||||
return []map[string]interface{}{}, nil
|
||||
@@ -419,11 +424,12 @@ func (s *SqlExecService) PreviewTableStructure(connectionID uint, database, tabl
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
return client.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
defer pc.Release()
|
||||
return pc.Client.PreviewTableStructure(ctx, database, tableName, structure)
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
@@ -449,11 +455,12 @@ func (s *SqlExecService) UpdateTableStructure(connectionID uint, database, table
|
||||
|
||||
switch conn.Type {
|
||||
case "mysql":
|
||||
client, err := s.pool.GetMySQLClient(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败: %v", err)
|
||||
pc := s.pool.GetMySQLClient(conn)
|
||||
if pc.Client == nil {
|
||||
return nil, fmt.Errorf("获取 MySQL 客户端失败")
|
||||
}
|
||||
return client.UpdateTableStructure(ctx, database, tableName, structure)
|
||||
defer pc.Release()
|
||||
return pc.Client.UpdateTableStructure(ctx, database, tableName, structure)
|
||||
|
||||
case "mongo":
|
||||
client, err := s.pool.GetMongoClient(conn)
|
||||
|
||||
Reference in New Issue
Block a user