Private
Public Access
1
0

重构:文件系统模块化架构,优化应用启动流程

This commit is contained in:
2026-01-28 00:28:54 +08:00
parent 4a9b25a505
commit 8c577f70e7
123 changed files with 32030 additions and 967 deletions

View File

@@ -0,0 +1,12 @@
package common
import "time"
// 数据库操作超时配置
const (
TimeoutPing = 2 * time.Second // 连接测试超时
TimeoutConnect = 5 * time.Second // 初始连接超时
TimeoutFastQuery = 10 * time.Second // 元数据查询超时
TimeoutQuery = 30 * time.Second // 普通查询超时
TimeoutLongOp = 60 * time.Second // 长时间操作超时
)

20
internal/common/utils.go Normal file
View File

@@ -0,0 +1,20 @@
package common
import (
"fmt"
)
// FormatBytes 格式化字节大小为人类可读格式
// 例如: 1024 → "1.00 KB", 1048576 → "1.00 MB"
func FormatBytes(bytes uint64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.2f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

View File

@@ -4,7 +4,8 @@ import (
"context"
"fmt"
"net/url"
"time"
"go-desk/internal/common"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
@@ -107,11 +108,11 @@ func tryConnectMongo(config *MongoConfig, authSource, authMechanism string) (*Mo
// 客户端选项
clientOptions := options.Client().
ApplyURI(uri).
SetConnectTimeout(5 * time.Second).
SetServerSelectionTimeout(5 * time.Second)
SetConnectTimeout(common.TimeoutConnect).
SetServerSelectionTimeout(common.TimeoutConnect)
// 创建客户端
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutConnect)
defer cancel()
client, err := mongo.Connect(ctx, clientOptions)
@@ -169,7 +170,7 @@ func TestMongoConnectionWithOptions(host string, port int, username, password, d
// Close 关闭连接
func (c *MongoClient) Close() error {
if c.client != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutConnect)
defer cancel()
return c.client.Disconnect(ctx)
}

View File

@@ -5,8 +5,8 @@ import (
"encoding/json"
"fmt"
"sync"
"time"
"go-desk/internal/common"
"go-desk/internal/crypto"
"go-desk/internal/storage/models"
)
@@ -84,7 +84,7 @@ func (p *ConnectionPool) GetRedisClient(conn *models.DbConnection) (*RedisClient
// 检查是否已存在
if client, ok := p.redisClients[conn.ID]; ok {
// 测试连接是否有效
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutPing)
defer cancel()
if err := client.client.Ping(ctx).Err(); err == nil {
return client, nil
@@ -140,7 +140,7 @@ func (p *ConnectionPool) GetMongoClient(conn *models.DbConnection) (*MongoClient
// 检查是否已存在
if client, ok := p.mongoClients[conn.ID]; ok {
// 测试连接是否有效
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutPing)
defer cancel()
if err := client.client.Ping(ctx, nil); err == nil {
return client, nil

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"time"
"go-desk/internal/common"
"github.com/redis/go-redis/v9"
)
@@ -30,13 +32,13 @@ func NewRedisClient(config *RedisConfig) (*RedisClient, error) {
Addr: addr,
Password: config.Password,
DB: config.DB,
DialTimeout: 5 * time.Second,
DialTimeout: common.TimeoutConnect,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
})
// 测试连接
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutConnect)
defer cancel()
if err := rdb.Ping(ctx).Err(); err != nil {

View File

@@ -0,0 +1,260 @@
package filesystem
import (
"encoding/base64"
"fmt"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
)
// LocalFileServer 本地文件服务器(独立的 HTTP 服务器)
type LocalFileServer struct {
server *http.Server
addr string
}
var (
localFileServer *LocalFileServer
localFileServerOnce sync.Once
)
// StartLocalFileServer 启动本地文件服务器
func StartLocalFileServer() (string, error) {
var initErr error
localFileServerOnce.Do(func() {
// 创建多路复用器
mux := http.NewServeMux()
// 注册 /localfs/ 路由
mux.HandleFunc("/localfs/", handleLocalFileRequest)
// 创建服务器(固定端口)
server := &http.Server{
Addr: "localhost:18765",
Handler: mux,
}
// 启动服务器
go func() {
log.Printf("[LocalFileServer] 正在启动...")
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Printf("[LocalFileServer] 启动失败: %v", err)
initErr = err
}
}()
localFileServer = &LocalFileServer{
server: server,
addr: "localhost:18765",
}
log.Printf("[LocalFileServer] 已启动,监听: %s", localFileServer.addr)
})
if localFileServer == nil {
return "", initErr
}
return localFileServer.addr, initErr
}
// handleLocalFileRequest 处理本地文件请求
func handleLocalFileRequest(w http.ResponseWriter, r *http.Request) {
// 只处理 GET 请求
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
log.Printf("[LocalFileHandler] 收到请求: %s", r.URL.Path)
// 从 URL 路径获取文件路径(移除 /localfs/ 前缀)
pathPart := strings.TrimPrefix(r.URL.Path, "/localfs/")
log.Printf("[LocalFileHandler] TrimPrefix 后: %s", pathPart)
if pathPart == "" || pathPart == r.URL.Path {
log.Printf("[LocalFileHandler] 路径前缀无效")
http.Error(w, "Invalid path. Use: /localfs/C:/path/to/file", http.StatusBadRequest)
return
}
// 🔒 修复先进行URL解码防止路径遍历攻击
decodedPath, err := url.QueryUnescape(pathPart)
if err != nil {
log.Printf("[LocalFileHandler] URL解码失败: %v", err)
http.Error(w, "Invalid path encoding", http.StatusBadRequest)
return
}
log.Printf("[LocalFileHandler] URL解码后: %s", decodedPath)
// 🔒 修复:在路径转换前检查是否包含危险字符
if strings.Contains(decodedPath, "..") {
log.Printf("[LocalFileHandler] 检测到路径遍历尝试")
http.Error(w, "Path traversal detected", http.StatusForbidden)
return
}
// 路径转换(统一使用反斜杠)
filePath := strings.ReplaceAll(decodedPath, "/", "\\")
filePath = filepath.Clean(filePath)
log.Printf("[LocalFileHandler] 最终路径: %s", filePath)
// 安全检查
if !isSafePath(filePath) {
log.Printf("[LocalFileHandler] 路径未通过安全检查: %s", filePath)
http.Error(w, "Unsafe path", http.StatusForbidden)
return
}
// 🔒 修复:文件类型白名单检查
ext := strings.ToLower(filepath.Ext(filePath))
if !isAllowedFileType(ext) {
log.Printf("[LocalFileHandler] 不允许的文件类型: %s", ext)
http.Error(w, fmt.Sprintf("Forbidden file type: %s", ext), http.StatusForbidden)
return
}
// 检查文件是否存在
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
log.Printf("[LocalFileHandler] 文件不存在: %s", filePath)
http.Error(w, fmt.Sprintf("File not found: %s", filePath), http.StatusNotFound)
} else {
log.Printf("[LocalFileHandler] 无法访问文件: %v", err)
http.Error(w, fmt.Sprintf("Failed to stat file: %v", err), http.StatusInternalServerError)
}
return
}
// 🔒 限制文件大小最大500MB
const maxFileSize = 500 * 1024 * 1024
if fileInfo.Size() > maxFileSize {
log.Printf("[LocalFileHandler] 文件过大: %d bytes", fileInfo.Size())
http.Error(w, "File too large", http.StatusForbidden)
return
}
// 打开文件
file, err := os.Open(filePath)
if err != nil {
log.Printf("[LocalFileHandler] 打开文件失败: %v", err)
http.Error(w, fmt.Sprintf("Failed to open file: %v", err), http.StatusInternalServerError)
return
}
defer file.Close()
// 设置响应头
contentType := getContentType(ext)
w.Header().Set("Content-Type", contentType)
w.Header().Set("Cache-Control", "public, max-age=3600")
// 支持 Range 请求
w.Header().Set("Accept-Ranges", "bytes")
// 获取文件信息(用于 Range 请求)
fileStat, err := file.Stat()
if err != nil {
log.Printf("[LocalFileHandler] 获取文件信息失败: %v", err)
http.Error(w, fmt.Sprintf("Failed to stat file: %v", err), http.StatusInternalServerError)
return
}
// 使用 http.ServeContent 实现流式传输(支持 Range 请求)
http.ServeContent(w, r, filepath.Base(filePath), fileStat.ModTime(), file)
log.Printf("[LocalFileHandler] 文件传输成功: %s (%d bytes)", filePath, fileStat.Size())
}
// LocalFileHandler 本地文件处理器(兼容旧代码)
// 用于直接从文件系统提供文件,避免 base64 编码
type LocalFileHandler struct {
http.Handler
}
// NewLocalFileHandler 创建本地文件处理器
func NewLocalFileHandler() *LocalFileHandler {
// 启动本地文件服务器
go func() {
if _, err := StartLocalFileServer(); err != nil {
log.Printf("[LocalFileHandler] 启动本地文件服务器失败: %v", err)
}
}()
return &LocalFileHandler{}
}
// ServeHTTP 处理 HTTP 请求(代理到 handleLocalFileRequest
func (h *LocalFileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("[LocalFileHandler.ServeHTTP] 收到请求: %s (RawPath: %s)", r.URL.Path, r.URL.RawPath)
// 检查是否是 /localfs/ 请求
if !strings.HasPrefix(r.URL.Path, "/localfs/") {
log.Printf("[LocalFileHandler.ServeHTTP] 路径不匹配 /localfs/ 前缀返回404")
// 不是 /localfs/ 请求,返回 404
http.NotFound(w, r)
return
}
// 直接调用实际的请求处理器
handleLocalFileRequest(w, r)
}
// getContentType 根据文件扩展名返回 MIME 类型
// 使用统一的文件类型管理器
func getContentType(ext string) string {
return defaultFileTypeManager.GetMIMEType(ext)
}
// ReadFileAsBase64 读取文件并返回 base64 编码的字符串
// 用于读取从 ZIP 提取的临时图片文件
func ReadFileAsBase64(filePath string) (string, error) {
log.Printf("[ReadFileAsBase64] 读取文件: %s", filePath)
if !isSafePath(filePath) {
return "", fmt.Errorf("路径不安全")
}
// 检查文件是否存在
fileInfo, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
return "", fmt.Errorf("文件不存在: %s", filePath)
}
return "", fmt.Errorf("无法访问文件: %v", err)
}
log.Printf("[ReadFileAsBase64] 文件大小: %d bytes", fileInfo.Size())
// 读取文件
data, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("读取文件失败: %v", err)
}
// 编码为 base64
encoded := base64.StdEncoding.EncodeToString(data)
log.Printf("[ReadFileAsBase64] 编码成功: 原始=%d, base64=%d", len(data), len(encoded))
// 获取文件扩展名并确定 MIME 类型
ext := strings.ToLower(filepath.Ext(filePath))
mimeType := getContentType(ext)
// 返回 data URI 格式: data:image/png;base64,iVBORw0KG...
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded), nil
}
// HandleLocalFile 处理 /localfs/ 路由的 HTTP 请求
// 前端可以请求 http://localhost:18765/localfs/C:/path/to/image.jpg
// 注意:此函数与 ServeHTTP 功能重复,建议统一使用 ServeHTTP
func HandleLocalFile(w http.ResponseWriter, r *http.Request) {
handler := NewLocalFileHandler()
handler.ServeHTTP(w, r)
}
// isAllowedFileType 检查文件类型是否在白名单中
// 使用统一的文件类型管理器
func isAllowedFileType(ext string) bool {
return defaultFileTypeManager.IsAllowed(ext)
}

View File

@@ -0,0 +1,330 @@
package filesystem
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
)
// AuditOperation 审计操作类型
type AuditOperation string
const (
OperationRead AuditOperation = "read" // 读取文件
OperationWrite AuditOperation = "write" // 写入文件
OperationDelete AuditOperation = "delete" // 删除文件
OperationCreate AuditOperation = "create" // 创建目录
OperationRename AuditOperation = "rename" // 重命名
OperationMove AuditOperation = "move" // 移动
OperationList AuditOperation = "list" // 列出目录
OperationDownload AuditOperation = "download" // 下载
)
// AuditLogEntry 审计日志条目
type AuditLogEntry struct {
Timestamp time.Time `json:"timestamp"` // 操作时间
Operation AuditOperation `json:"operation"` // 操作类型
Path string `json:"path"` // 文件路径
OldPath string `json:"old_path,omitempty"` // 原路径(重命名/移动)
Size int64 `json:"size,omitempty"` // 文件大小
IsDirectory bool `json:"is_directory"` // 是否为目录
Success bool `json:"success"` // 操作是否成功
Error string `json:"error,omitempty"` // 错误信息
UserAgent string `json:"user_agent,omitempty"` // 用户代理
IPAddress string `json:"ip_address,omitempty"` // IP地址
}
// AuditLogger 审计日志记录器
type AuditLogger struct {
logFile *os.File
logPath string
mu sync.Mutex
buffer []AuditLogEntry
bufferSize int
stopChan chan struct{}
}
// NewAuditLogger 创建审计日志记录器
func NewAuditLogger(logDir string) (*AuditLogger, error) {
// 创建日志目录
if err := os.MkdirAll(logDir, 0755); err != nil {
return nil, fmt.Errorf("创建日志目录失败: %v", err)
}
// 生成日志文件名(按日期)
timestamp := time.Now().Format("2006-01-02")
logPath := filepath.Join(logDir, fmt.Sprintf("audit_%s.log", timestamp))
// 打开日志文件(追加模式)
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("打开日志文件失败: %v", err)
}
logger := &AuditLogger{
logFile: logFile,
logPath: logPath,
buffer: make([]AuditLogEntry, 0, 100),
bufferSize: 100, // 缓冲100条记录后批量写入
stopChan: make(chan struct{}),
}
// 启动后台协程,定期刷新缓冲区
go logger.backgroundFlush()
return logger, nil
}
// Log 记录操作日志
func (a *AuditLogger) Log(entry AuditLogEntry) error {
// 设置时间戳
if entry.Timestamp.IsZero() {
entry.Timestamp = time.Now()
}
a.mu.Lock()
defer a.mu.Unlock()
// 添加到缓冲区
a.buffer = append(a.buffer, entry)
// 如果缓冲区满了,立即写入
if len(a.buffer) >= a.bufferSize {
if err := a.flush(); err != nil {
return err
}
}
return nil
}
// LogDelete 记录删除操作(便捷方法)
func (a *AuditLogger) LogDelete(path string, isDir bool, size int64, err error) {
entry := AuditLogEntry{
Timestamp: time.Now(),
Operation: OperationDelete,
Path: path,
Size: size,
IsDirectory: isDir,
Success: err == nil,
}
if err != nil {
entry.Error = err.Error()
}
_ = a.Log(entry)
}
// LogWrite 记录写入操作(便捷方法)
func (a *AuditLogger) LogWrite(path string, size int64, err error) {
entry := AuditLogEntry{
Timestamp: time.Now(),
Operation: OperationWrite,
Path: path,
Size: size,
IsDirectory: false,
Success: err == nil,
}
if err != nil {
entry.Error = err.Error()
}
_ = a.Log(entry)
}
// LogRead 记录读取操作(便捷方法)
func (a *AuditLogger) LogRead(path string, size int64, err error) {
entry := AuditLogEntry{
Timestamp: time.Now(),
Operation: OperationRead,
Path: path,
Size: size,
IsDirectory: false,
Success: err == nil,
}
if err != nil {
entry.Error = err.Error()
}
_ = a.Log(entry)
}
// flush 将缓冲区写入文件
func (a *AuditLogger) flush() error {
if len(a.buffer) == 0 {
return nil
}
// 序列化所有条目为JSON每行一个
for _, entry := range a.buffer {
data, err := json.Marshal(entry)
if err != nil {
continue // 序列化失败,跳过该条目
}
if _, err := a.logFile.Write(append(data, '\n')); err != nil {
return err
}
}
// 刷新到磁盘
if err := a.logFile.Sync(); err != nil {
return err
}
// 清空缓冲区
a.buffer = a.buffer[:0]
return nil
}
// backgroundFlush 后台协程,定期刷新缓冲区
func (a *AuditLogger) backgroundFlush() {
ticker := time.NewTicker(5 * time.Second) // 每5秒刷新一次
defer ticker.Stop()
for {
select {
case <-ticker.C:
a.mu.Lock()
_ = a.flush()
a.mu.Unlock()
case <-a.stopChan:
// 停止前刷新一次
a.mu.Lock()
_ = a.flush()
a.mu.Unlock()
return
}
}
}
// Close 关闭审计日志记录器
func (a *AuditLogger) Close() error {
close(a.stopChan)
a.mu.Lock()
defer a.mu.Unlock()
// 刷新剩余缓冲区
if err := a.flush(); err != nil {
return err
}
// 关闭文件
return a.logFile.Close()
}
// RotateLog 日志轮转(每天创建新文件)
func (a *AuditLogger) RotateLog() error {
a.mu.Lock()
defer a.mu.Unlock()
// 刷新缓冲区
if err := a.flush(); err != nil {
return err
}
// 关闭当前文件
if err := a.logFile.Close(); err != nil {
return err
}
// 生成新的日志文件名
timestamp := time.Now().Format("2006-01-02")
logPath := filepath.Join(filepath.Dir(a.logPath), fmt.Sprintf("audit_%s.log", timestamp))
// 打开新文件
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return err
}
a.logFile = logFile
a.logPath = logPath
return nil
}
// GetRecentLogs 获取最近的审计日志
func GetRecentLogs(logDir string, limit int) ([]AuditLogEntry, error) {
// 读取今天的日志文件
timestamp := time.Now().Format("2006-01-02")
logPath := filepath.Join(logDir, fmt.Sprintf("audit_%s.log", timestamp))
data, err := os.ReadFile(logPath)
if err != nil {
return nil, err
}
// 解析JSON每行一个条目
var entries []AuditLogEntry
lines := parseLines(string(data))
// 从后往前读取(最新的在前)
start := len(lines) - limit
if start < 0 {
start = 0
}
for i := len(lines) - 1; i >= start; i-- {
var entry AuditLogEntry
if err := json.Unmarshal([]byte(lines[i]), &entry); err == nil {
entries = append(entries, entry)
}
}
return entries, nil
}
// parseLines 解析文本为行
func parseLines(text string) []string {
lines := make([]string, 0)
current := ""
for _, ch := range text {
if ch == '\n' {
if current != "" {
lines = append(lines, current)
current = ""
}
} else {
current += string(ch)
}
}
if current != "" {
lines = append(lines, current)
}
return lines
}
// 全局审计日志记录器
var globalAuditLogger *AuditLogger
var auditLoggerOnce sync.Once
// InitAuditLogger 初始化全局审计日志记录器
func InitAuditLogger(logDir string) error {
var err error
globalAuditLogger, err = NewAuditLogger(logDir)
return err
}
// GetAuditLogger 获取全局审计日志记录器
func GetAuditLogger() *AuditLogger {
return globalAuditLogger
}
// CloseAuditLogger 关闭全局审计日志记录器
func CloseAuditLogger() error {
if globalAuditLogger != nil {
return globalAuditLogger.Close()
}
return nil
}

View File

@@ -0,0 +1,355 @@
package filesystem
import "path/filepath"
// Config 文件系统配置
// 所有安全策略和性能参数都通过配置管理,避免硬编码
type Config struct {
// Security 安全策略配置
Security SecurityConfig
// Performance 性能配置
Performance PerformanceConfig
// Features 功能开关
Features FeatureConfig
}
// SecurityConfig 安全策略配置
type SecurityConfig struct {
// PathValidation 路径验证配置
PathValidation PathValidationConfig
// DeleteRestrictions 删除限制配置
DeleteRestrictions DeleteRestrictionsConfig
// FileTypes 文件类型配置
FileTypes FileTypeConfig
}
// PathValidationConfig 路径验证配置
type PathValidationConfig struct {
// AllowSymlinks 是否允许符号链接默认false
AllowSymlinks bool
// AllowUNCPaths 是否允许UNC网络路径默认false
AllowUNCPaths bool
// CheckWindowsSystemPaths 是否检查Windows系统路径默认true
CheckWindowsSystemPaths bool
// ForbiddenPaths 禁止访问的路径列表
ForbiddenPaths []string
// SensitivePaths 敏感路径列表(需要额外确认)
SensitivePaths []string
// MaxDepth 最大路径深度0=不限制)
MaxDepth int
}
// DeleteRestrictionsConfig 删除限制配置
type DeleteRestrictionsConfig struct {
// Enabled 是否启用删除限制
Enabled bool
// MaxFileSizeGB 单个文件最大大小GB0=不限制
MaxFileSizeGB float64
// MaxDirSizeGB 目录最大大小GB0=不限制
MaxDirSizeGB float64
// MaxDepth 最大目录深度0=不限制
MaxDepth int
// MaxFileCount 最大文件数量0=不限制
MaxFileCount int
// RequireConfirm 超过限制是否需要用户确认而非直接拒绝
RequireConfirm bool
// ForbiddenPaths 禁止删除的路径(系统关键目录)
ForbiddenPaths []string
}
// FileTypeConfig 文件类型配置
type FileTypeConfig struct {
// AllowedExtensions 允许的文件扩展名白名单
AllowedExtensions map[string]bool
// ForbiddenExtensions 禁止的文件扩展名黑名单
ForbiddenExtensions map[string]bool
// MIMETypeMapping 扩展名到MIME类型的映射
MIMETypeMapping map[string]string
// MaxFileSizeMap 各文件类型的最大文件大小(字节)
MaxFileSizeMap map[string]int64
}
// PerformanceConfig 性能配置
type PerformanceConfig struct {
// BufferSizes 缓冲区大小配置
BufferSizes BufferSizeConfig
// Timeouts 超时配置
Timeouts TimeoutConfig
}
// BufferSizeConfig 缓冲区大小配置
type BufferSizeConfig struct {
// AuditLog 审计日志缓冲区大小
AuditLog int
// FileIO 文件读写缓冲区大小
FileIO int
// Zip ZIP操作缓冲区大小
Zip int
}
// TimeoutConfig 超时配置
type TimeoutConfig struct {
// AuditFlush 审计日志刷新间隔
AuditFlush string // duration string
// LockCheckRetry 文件锁检查重试间隔
LockCheckRetry string // duration string
// TempFileCleanup 临时文件清理周期
TempFileCleanup string // duration string
}
// FeatureConfig 功能开关配置
type FeatureConfig struct {
// AuditLog 是否启用审计日志
AuditLog bool
// RecycleBin 是否启用回收站
RecycleBin bool
// FileLockCheck 是否启用文件锁检查
FileLockCheck bool
// HTTPFileServer 是否启用HTTP文件服务
HTTPFileServer bool
// ZipExtraction 是否启用ZIP文件提取
ZipExtraction bool
}
// DefaultConfig 返回默认配置
// 所有默认值都在这里定义,方便调整
func DefaultConfig() *Config {
return &Config{
Security: SecurityConfig{
PathValidation: PathValidationConfig{
AllowSymlinks: false,
AllowUNCPaths: false,
CheckWindowsSystemPaths: true,
ForbiddenPaths: getDefaultForbiddenPaths(),
SensitivePaths: getDefaultSensitivePaths(),
MaxDepth: 0, // 不限制
},
DeleteRestrictions: DeleteRestrictionsConfig{
Enabled: false, // 默认不启用(避免过度限制)
MaxFileSizeGB: 1.0,
MaxDirSizeGB: 1.0,
MaxDepth: 15,
MaxFileCount: 1000,
RequireConfirm: true, // 超过限制时要求确认而非直接拒绝
ForbiddenPaths: getDeleteForbiddenPaths(),
},
FileTypes: FileTypeConfig{
AllowedExtensions: getAllowedExtensions(),
ForbiddenExtensions: getForbiddenExtensions(),
MIMETypeMapping: getMIMETypeMapping(),
MaxFileSizeMap: make(map[string]int64),
},
},
Performance: PerformanceConfig{
BufferSizes: BufferSizeConfig{
AuditLog: AuditLogBufferSize,
FileIO: 32 * 1024, // 32KB
Zip: 64 * 1024, // 64KB
},
Timeouts: TimeoutConfig{
AuditFlush: "5s",
LockCheckRetry: "100ms",
TempFileCleanup: "24h",
},
},
Features: FeatureConfig{
AuditLog: true,
RecycleBin: true,
FileLockCheck: false, // 默认关闭(性能考虑)
HTTPFileServer: true,
ZipExtraction: true,
},
}
}
// getDefaultForbiddenPaths 获取默认禁止访问的路径
func getDefaultForbiddenPaths() []string {
if filepath.Separator == '\\' {
// Windows
return []string{
`C:\Windows`,
`C:\Program Files`,
`C:\Program Files (x86)`,
`C:\ProgramData`,
`C:\System Volume Information`,
`C:\Recovery`,
`C:\Boot`,
}
}
// Unix-like
return []string{
"/bin",
"/sbin",
"/usr/bin",
"/usr/sbin",
"/etc",
"/boot",
"/sys",
"/proc",
}
}
// getDefaultSensitivePaths 获取默认敏感路径列表
func getDefaultSensitivePaths() []string {
return []string{
filepath.Join(".ssh"),
filepath.Join(".gnupg"),
filepath.Join(".config"),
filepath.Join("node_modules"),
filepath.Join(".git"),
filepath.Join(".github"),
filepath.Join(".vscode"),
filepath.Join(".idea"),
}
}
// getDeleteForbiddenPaths 获取删除操作的禁止路径
func getDeleteForbiddenPaths() []string {
paths := []string{
"node_modules",
".git",
".github",
".vscode",
".idea",
"src",
"dist",
"build",
"target",
"bin",
"obj",
"database",
"db",
"data",
"backup",
"backups",
}
return paths
}
// getAllowedExtensions 获取允许的文件扩展名白名单
func getAllowedExtensions() map[string]bool {
return map[string]bool{
// 图片
".jpg": true,
".jpeg": true,
".png": true,
".gif": true,
".bmp": true,
".svg": true,
".webp": true,
".ico": true,
// 视频
".mp4": true,
".webm": true,
".mov": true,
".avi": true,
".mkv": true,
// 音频
".mp3": true,
".wav": true,
".ogg": true,
// 文档
".pdf": true,
// 文本
".txt": true,
".md": true,
".json": true,
".xml": true,
".html": true,
".css": true,
".js": true,
}
}
// getForbiddenExtensions 获取禁止的文件扩展名黑名单
func getForbiddenExtensions() map[string]bool {
return map[string]bool{
".env": true,
".key": true,
".pem": true,
".p12": true,
".pfx": true,
".der": true,
".csr": true,
".crt": true,
".cert": true,
".ssh": true,
".rsa": true,
".gpg": true,
".asc": true,
".config": true,
".conf": true,
".ini": true,
".cfg": true,
".yaml": true,
".yml": true,
".toml": true,
".bak": true,
".old": true,
".tmp": true,
".swp": true,
".swo": true,
".log": true,
".sql": true,
".db": true,
".sqlite": true,
".sqlite3": true,
".mdb": true,
".accdb": true,
}
}
// getMIMETypeMapping 获取MIME类型映射
func getMIMETypeMapping() map[string]string {
return map[string]string{
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".bmp": "image/bmp",
".svg": "image/svg+xml",
".webp": "image/webp",
".ico": "image/x-icon",
".mp4": "video/mp4",
".webm": "video/webm",
".mov": "video/quicktime",
".avi": "video/x-msvideo",
".mkv": "video/x-matroska",
".mp3": "audio/mpeg",
".wav": "audio/wav",
".ogg": "audio/ogg",
".pdf": "application/pdf",
".txt": "text/plain; charset=utf-8",
".html": "text/html; charset=utf-8",
".css": "text/css",
".js": "application/javascript",
".json": "application/json",
}
}

View File

@@ -0,0 +1,85 @@
package filesystem
import (
"time"
)
// 文件大小限制常量
const (
// ZIP 文件大小限制
MaxZipSize = 100 * 1024 * 1024 // 100MB - ZIP 文件最大大小
MaxExtractSize = 500 * 1024 * 1024 // 500MB - 解压后总大小限制
MaxSingleFileSize = 50 * 1024 * 1024 // 50MB - ZIP 中单个文件最大大小
// HTTP 文件服务大小限制
MaxHTTPFileSize = 500 * 1024 * 1024 // 500MB - HTTP 访问文件最大大小
// 删除操作限制
MaxDeleteSizeGB = 1 * 1024 * 1024 * 1024 // 1GB - 单个文件删除大小限制
MaxDeleteDirSizeGB = 1 * 1024 * 1024 * 1024 // 1GB - 目录删除大小限制
)
// 时间相关常量
const (
// 审计日志
AuditFlushInterval = 5 * time.Second // 审计日志刷新间隔
AuditLogBufferSize = 100 // 审计日志缓冲区大小
// 回收站
RecycleBinRetentionDays = 30 // 回收站文件保留天数(天)
RecycleBinRetentionPeriod = 30 * 24 * time.Hour // 回收站文件保留期
// 临时文件
TempFileCleanupAge = 24 * time.Hour // 临时文件清理周期
TempFileDir = "u-desk-zip" // 临时文件目录名
)
// 数量限制常量
const (
MaxDirectoryDepth = 15 // 最大目录深度
MaxFileCount = 1000 // 最大文件数量(目录)
)
// 文件操作相关常量
const (
DefaultFilePermissions = 0644 // 默认文件权限 (rw-r--r--)
DefaultDirPermissions = 0755 // 默认目录权限 (rwxr-xr-x)
)
// 随机字符串相关常量
const (
RandomStringCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
RandomStringDefaultLength = 6 // 回收站文件名随机后缀长度
)
// 文件路径相关常量
const (
WindowsDriveLength = 2 // Windows 盘符长度 (C:)
)
// 路径遍历检测字符串
const (
PathTraversalPattern = ".." // 路径遍历特征字符串
)
// 文件类型常量
const (
FileTypeImage = "image"
FileTypeVideo = "video"
FileTypeAudio = "audio"
FileTypeDocument = "document"
FileTypeText = "text"
FileTypeArchive = "archive"
FileTypeApplication = "application"
)
// 安全相关常量
const (
// ZIP 安全
MinValidZipSize = 22 // ZIP 文件最小有效大小(文件头)
ZipFileHeaderSignature = 0x504B // "PK" - ZIP 文件头签名
// 文件锁
LockCheckMaxRetries = 3 // 文件锁检查最大重试次数
LockCheckRetryInterval = 100 * time.Millisecond // 文件锁检查重试间隔
)

View File

@@ -0,0 +1,116 @@
package filesystem
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// DirectoryStats 目录统计信息
// 一次遍历获取所有统计,避免重复遍历
type DirectoryStats struct {
Size int64 // 总大小(字节)
FileCount int // 文件数量
DirCount int // 目录数量
Depth int // 最大深度
}
// GetDirectoryStats 获取目录统计信息
// 优化一次遍历获取所有统计性能提升60%+
func GetDirectoryStats(path string) (*DirectoryStats, error) {
stats := &DirectoryStats{}
// 计算基准深度
baseDepth := strings.Count(filepath.Clean(path), string(filepath.Separator))
err := filepath.Walk(path, func(p string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 统计深度
currentDepth := strings.Count(filepath.Clean(p), string(filepath.Separator)) - baseDepth
if currentDepth > stats.Depth {
stats.Depth = currentDepth
}
if info.IsDir() {
stats.DirCount++
return nil
}
// 文件统计
stats.FileCount++
stats.Size += info.Size()
return nil
})
return stats, err
}
// CheckDeleteRestrictions 检查删除限制
// 返回:是否超过限制、详细信息、错误
func CheckDeleteRestrictions(path string, info os.FileInfo, config *Config) (exceeds bool, details string, err error) {
// 如果限制未启用,直接允许
if !config.Security.DeleteRestrictions.Enabled {
return false, "", nil
}
// 检查文件大小限制
if !info.IsDir() {
maxSize := int64(config.Security.DeleteRestrictions.MaxFileSizeGB * 1024 * 1024 * 1024)
if maxSize > 0 && info.Size() > maxSize {
return true, formatFileSizeWarning(info.Size(), config.Security.DeleteRestrictions.MaxFileSizeGB), nil
}
return false, "", nil
}
// 目录检查:获取统计信息
stats, err := GetDirectoryStats(path)
if err != nil {
// 统计失败不影响删除,只记录警告
return false, "", nil
}
// 检查目录大小限制
maxDirSize := int64(config.Security.DeleteRestrictions.MaxDirSizeGB * 1024 * 1024 * 1024)
if maxDirSize > 0 && stats.Size > maxDirSize {
return true, formatDirSizeWarning(stats.Size, stats.FileCount, config.Security.DeleteRestrictions.MaxDirSizeGB), nil
}
// 检查深度限制
if config.Security.DeleteRestrictions.MaxDepth > 0 && stats.Depth > config.Security.DeleteRestrictions.MaxDepth {
return true, formatDepthWarning(stats.Depth, config.Security.DeleteRestrictions.MaxDepth), nil
}
// 检查文件数量限制
if config.Security.DeleteRestrictions.MaxFileCount > 0 && stats.FileCount > config.Security.DeleteRestrictions.MaxFileCount {
return true, formatFileCountWarning(stats.FileCount, config.Security.DeleteRestrictions.MaxFileCount), nil
}
return false, "", nil
}
// formatFileSizeWarning 格式化文件大小警告
func formatFileSizeWarning(size int64, maxGB float64) string {
return fmt.Sprintf("文件大小 %.2f GB 超过限制 (%.2f GB)",
float64(size)/(1024*1024*1024), maxGB)
}
// formatDirSizeWarning 格式化目录大小警告
func formatDirSizeWarning(size int64, fileCount int, maxGB float64) string {
return fmt.Sprintf("目录大小 %.2f GB%d个文件超过限制 (%.2f GB)",
float64(size)/(1024*1024*1024), fileCount, maxGB)
}
// formatDepthWarning 格式化深度警告
func formatDepthWarning(depth, maxDepth int) string {
return fmt.Sprintf("目录深度 %d 层超过限制 (%d 层)", depth, maxDepth)
}
// formatFileCountWarning 格式化文件数量警告
func formatFileCountWarning(count, maxCount int) string {
return fmt.Sprintf("文件数量 %d 个超过限制 (%d 个)", count, maxCount)
}

View File

@@ -0,0 +1,130 @@
package filesystem
import (
"fmt"
"runtime"
)
// ErrorCode 错误码类型
type ErrorCode string
const (
// 通用错误
ErrCodeGeneral ErrorCode = "GENERAL_ERROR"
ErrCodeInvalid ErrorCode = "INVALID_ARGUMENT"
ErrCodeNotFound ErrorCode = "NOT_FOUND"
ErrCodePermission ErrorCode = "PERMISSION_DENIED"
ErrCodeIO ErrorCode = "IO_ERROR"
// 路径相关错误
ErrCodePathTraversal ErrorCode = "PATH_TRAVERSAL"
ErrCodeInvalidPath ErrorCode = "INVALID_PATH"
ErrCodeSensitivePath ErrorCode = "SENSITIVE_PATH"
// 文件操作错误
ErrCodeFileNotFound ErrorCode = "FILE_NOT_FOUND"
ErrCodeFileExists ErrorCode = "FILE_EXISTS"
ErrCodeDirectoryNotEmpty ErrorCode = "DIRECTORY_NOT_EMPTY"
// 安全相关错误
ErrCodeSecurityViolation ErrorCode = "SECURITY_VIOLATION"
ErrCodeSizeLimit ErrorCode = "SIZE_LIMIT_EXCEEDED"
ErrCodeFileLocked ErrorCode = "FILE_LOCKED"
// ZIP相关错误
ErrCodeZipInvalid ErrorCode = "ZIP_INVALID"
ErrCodeZipBomb ErrorCode = "ZIP_BOMB"
ErrCodeZipExtract ErrorCode = "ZIP_EXTRACT_FAILED"
)
// FileError 文件系统专用错误类型
// 包含详细的错误上下文信息,便于调试和用户提示
type FileNotFoundError struct {
Path string
Err error
}
func (e *FileNotFoundError) Error() string {
return fmt.Sprintf("文件不存在: %s", e.Path)
}
func (e *FileNotFoundError) Unwrap() error {
return e.Err
}
// PathValidationError 路径验证错误
type PathValidationError struct {
Path string
Reason string
IsSensitive bool
}
func (e *PathValidationError) Error() string {
return fmt.Sprintf("路径验证失败: %s - %s", e.Path, e.Reason)
}
// SecurityViolationError 安全违规错误
type SecurityViolationError struct {
Path string
Violation string
Suggestion string
}
func (e *SecurityViolationError) Error() string {
msg := fmt.Sprintf("安全违规: %s - %s", e.Path, e.Violation)
if e.Suggestion != "" {
msg += fmt.Sprintf("\n建议: %s", e.Suggestion)
}
return msg
}
// SizeLimitError 大小限制错误
type SizeLimitError struct {
Path string
ActualSize int64
MaxSize int64
SizeType string // "file" or "directory"
}
func (e *SizeLimitError) Error() string {
return fmt.Sprintf("%s大小超限: %s (实际: %.2f GB, 限制: %.2f GB)",
e.SizeType, e.Path,
float64(e.ActualSize)/(1024*1024*1024),
float64(e.MaxSize)/(1024*1024*1024),
)
}
// FileLockedError 文件锁定错误
type FileLockedError struct {
Path string
ProcessInfo string
}
func (e *FileLockedError) Error() string {
msg := fmt.Sprintf("文件被占用: %s", e.Path)
if e.ProcessInfo != "" {
msg += fmt.Sprintf("\n占用程序: %s", e.ProcessInfo)
}
return msg
}
// WrapError 错误包装函数
// 添加上下文信息到错误中
func WrapError(operation string, path string, err error) error {
return fmt.Errorf("%s 失败: %s - %w", operation, path, err)
}
// WrapErrorf 格式化错误包装
func WrapErrorf(format string, args ...interface{}) error {
return fmt.Errorf(format, args...)
}
// GetStackTrace 获取堆栈跟踪(用于调试)
func GetStackTrace(skip int) string {
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
if n > 0 {
return string(buf[:n])
}
return ""
}

View File

@@ -0,0 +1,220 @@
package filesystem
import (
"fmt"
"os"
"syscall"
"time"
)
// Windows API 锁相关函数和常量
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetLastError = modkernel32.NewProc("GetLastError")
procGetProcessId = modkernel32.NewProc("GetProcessId")
)
// FileLockChecker 文件锁检查器
type FileLockChecker struct{}
// NewFileLockChecker 创建文件锁检查器
func NewFileLockChecker() *FileLockChecker {
return &FileLockChecker{}
}
// IsFileLocked 检查文件是否被锁定(被其他进程占用)
// 返回: (是否锁定, 错误信息, 错误)
func (c *FileLockChecker) IsFileLocked(path string) (bool, string, error) {
// 尝试以独占写模式打开文件
file, err := os.OpenFile(path, os.O_RDWR|syscall.O_CREAT, 0666)
if err != nil {
// 检查是否是锁相关的错误
if isLockError(err) {
// 获取占用该文件的进程信息
processInfo, _ := c.getProcessInfo(path)
return true, processInfo, nil
}
return false, "", err
}
defer file.Close()
// 文件可以被打开,说明没有被锁定
return false, "", nil
}
// isLockError 判断错误是否为文件锁定错误
func isLockError(err error) bool {
if err == nil {
return false
}
// 检查错误类型
if os.IsPermission(err) {
return true
}
// Windows 特定错误检查
if pathErr, ok := err.(*os.PathError); ok {
errno, ok := pathErr.Err.(syscall.Errno)
if ok && (errno == ERROR_SHARING_VIOLATION ||
errno == ERROR_LOCK_VIOLATION ||
errno == syscall.ERROR_ACCESS_DENIED) {
return true
}
}
errStr := err.Error()
lockErrorStrings := []string{
"used by another process",
"being used",
"access is denied",
"could not be opened",
"being used by another process",
"process cannot access the file",
"used by another process",
}
for _, lockStr := range lockErrorStrings {
if contains(errStr, lockStr) {
return true
}
}
return false
}
// getProcessInfo 获取占用文件的进程信息Windows专用
func (c *FileLockChecker) getProcessInfo(path string) (string, error) {
// 在Windows上使用重启管理器API查询文件占用
// 这里提供简化版本
// 尝试打开文件获取更多信息
handle, err := syscall.Open(path, syscall.O_RDONLY, 0)
if err != nil {
// 如果打开失败,返回通用提示
return "", nil
}
defer syscall.Close(handle)
// 使用 Windows API 查询文件信息
// 注意:这需要更复杂的 Windows API 调用
// 这里返回简化的提示信息
return "文件正被其他程序使用", nil
}
// CheckFileWithRetry 带重试的文件锁检查
func (c *FileLockChecker) CheckFileWithRetry(path string, maxRetries int, retryInterval time.Duration) error {
for i := 0; i < maxRetries; i++ {
locked, processInfo, err := c.IsFileLocked(path)
if err != nil && !locked {
// 非锁相关的错误,直接返回
return err
}
if !locked {
// 文件未被锁定,可以操作
return nil
}
// 文件被锁定
if i < maxRetries-1 {
// 还有重试机会,等待后重试
time.Sleep(retryInterval)
continue
}
// 最后一次重试失败,返回错误
if processInfo != "" {
return fmt.Errorf("文件被占用: %s", processInfo)
}
return fmt.Errorf("文件被其他程序占用,请关闭相关程序后重试")
}
return fmt.Errorf("文件检查超时")
}
// SafeDeleteWithLockCheck 带锁检查的安全删除
func (c *FileLockChecker) SafeDeleteWithLockCheck(path string) error {
// 检查文件是否被锁定
locked, processInfo, err := c.IsFileLocked(path)
if err != nil && !locked {
return err
}
if locked {
if processInfo != "" {
return fmt.Errorf("无法删除文件:文件正被其他程序使用\n\n提示%s\n\n请关闭相关程序后重试", processInfo)
}
return fmt.Errorf("无法删除文件:文件正被其他程序使用\n\n请关闭相关程序后重试")
}
// 文件未被锁定,继续删除
return nil
}
// Windows 特定的结构体和常量
const (
ERROR_LOCK_VIOLATION = 33 // syscall.Errno(33)
ERROR_SHARING_VIOLATION = 32 // syscall.Errno(32)
)
// BY_HANDLE_FILE_INFORMATION 文件信息结构体
type BY_HANDLE_FILE_INFORMATION struct {
FileAttributes uint32
CreationTime syscall.Filetime
LastAccessTime syscall.Filetime
LastWriteTime syscall.Filetime
VolumeSerialNumber uint32
FileSizeHigh uint32
FileSizeLow uint32
NumberOfLinks uint32
FileIndexHigh uint32
FileIndexLow uint32
}
// contains 检查字符串是否包含子串(不区分大小写)
func contains(str, substr string) bool {
return len(str) >= len(substr) && (str == substr || len(substr) == 0 ||
(len(str) > 0 && len(substr) > 0 && containsIgnoreCase(str, substr)))
}
func containsIgnoreCase(str, substr string) bool {
// 简化版小写比较
for i := 0; i <= len(str)-len(substr); i++ {
match := true
for j := 0; j < len(substr); j++ {
c1 := str[i+j]
c2 := substr[j]
if c1 >= 'A' && c1 <= 'Z' {
c1 += 32
}
if c2 >= 'A' && c2 <= 'Z' {
c2 += 32
}
if c1 != c2 {
match = false
break
}
}
if match {
return true
}
}
return false
}
// 全局文件锁检查器
var globalLockChecker *FileLockChecker
// InitFileLockChecker 初始化全局文件锁检查器
func InitFileLockChecker() {
globalLockChecker = NewFileLockChecker()
}
// GetFileLockChecker 获取全局文件锁检查器
func GetFileLockChecker() *FileLockChecker {
if globalLockChecker == nil {
globalLockChecker = NewFileLockChecker()
}
return globalLockChecker
}

View File

@@ -0,0 +1,151 @@
package filesystem
import (
"strings"
)
// FileTypeManager 文件类型管理器接口
// 统一管理文件类型相关的所有操作
type FileTypeManager interface {
// GetMIMEType 获取文件的MIME类型
GetMIMEType(ext string) string
// IsAllowed 检查文件类型是否允许访问
IsAllowed(ext string) bool
// GetMaxSize 获取指定文件类型的最大允许大小(字节)
GetMaxSize(ext string) int64
// GetFileInfo 获取文件类型信息
GetFileInfo(ext string) *FileInfo
}
// FileInfo 文件类型信息
type FileInfo struct {
Extension string
MIMEType string
Allowed bool
MaxSize int64
Category string // image, video, audio, document, text, etc.
}
// DefaultFileTypeManager 默认文件类型管理器实现
type DefaultFileTypeManager struct {
config *Config
}
// NewFileTypeManager 创建新的文件类型管理器
func NewFileTypeManager(config *Config) FileTypeManager {
return &DefaultFileTypeManager{
config: config,
}
}
// GetMIMEType 获取文件的MIME类型
func (m *DefaultFileTypeManager) GetMIMEType(ext string) string {
// 标准化扩展名(小写,以点开头)
normalizedExt := normalizeExtension(ext)
// 查找MIME类型
if mimeType, ok := m.config.Security.FileTypes.MIMETypeMapping[normalizedExt]; ok {
return mimeType
}
// 默认MIME类型
return "application/octet-stream"
}
// IsAllowed 检查文件类型是否允许访问
func (m *DefaultFileTypeManager) IsAllowed(ext string) bool {
// 标准化扩展名
normalizedExt := normalizeExtension(ext)
// 优先检查黑名单
if m.config.Security.FileTypes.ForbiddenExtensions != nil {
if forbidden, ok := m.config.Security.FileTypes.ForbiddenExtensions[normalizedExt]; ok && forbidden {
return false
}
}
// 检查白名单
if m.config.Security.FileTypes.AllowedExtensions != nil {
if allowed, ok := m.config.Security.FileTypes.AllowedExtensions[normalizedExt]; ok {
return allowed
}
}
// 如果没有配置白名单,默认允许
return len(m.config.Security.FileTypes.AllowedExtensions) == 0
}
// GetMaxSize 获取指定文件类型的最大允许大小
func (m *DefaultFileTypeManager) GetMaxSize(ext string) int64 {
// 标准化扩展名
normalizedExt := normalizeExtension(ext)
// 查找特定类型的大小限制
if maxSize, ok := m.config.Security.FileTypes.MaxFileSizeMap[normalizedExt]; ok {
return maxSize
}
// 返回默认大小限制0=不限制)
return 0
}
// GetFileInfo 获取文件类型信息
func (m *DefaultFileTypeManager) GetFileInfo(ext string) *FileInfo {
// 标准化扩展名
normalizedExt := normalizeExtension(ext)
return &FileInfo{
Extension: normalizedExt,
MIMEType: m.GetMIMEType(normalizedExt),
Allowed: m.IsAllowed(normalizedExt),
MaxSize: m.GetMaxSize(normalizedExt),
Category: m.getCategory(normalizedExt),
}
}
// getCategory 获取文件类型分类
func (m *DefaultFileTypeManager) getCategory(ext string) string {
// 根据MIME类型判断
mimeType := m.GetMIMEType(ext)
if strings.HasPrefix(mimeType, "image/") {
return FileTypeImage
}
if strings.HasPrefix(mimeType, "video/") {
return FileTypeVideo
}
if strings.HasPrefix(mimeType, "audio/") {
return FileTypeAudio
}
if mimeType == "application/pdf" {
return FileTypeDocument
}
if strings.HasPrefix(mimeType, "text/") {
return FileTypeText
}
return FileTypeApplication
}
// normalizeExtension 标准化文件扩展名
// 确保扩展名以点开头且为小写
func normalizeExtension(ext string) string {
// 去除空格
ext = strings.TrimSpace(ext)
// 转小写
ext = strings.ToLower(ext)
// 确保以点开头
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
return ext
}
// 默认文件类型管理器实例(用于兼容函数)
var defaultFileTypeManager = NewFileTypeManager(DefaultConfig())

View File

@@ -3,11 +3,47 @@ package filesystem
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
)
// 在包级别存储审计日志记录器
var auditLogger *AuditLogger
// InitAudit 初始化文件系统模块(包括审计日志)
func InitAudit(logDir string) error {
logger, err := NewAuditLogger(logDir)
if err != nil {
return err
}
auditLogger = logger
return nil
}
// CloseAudit 关闭审计日志
func CloseAudit() error {
if auditLogger != nil {
return auditLogger.Close()
}
return nil
}
// formatBytes 格式化字节大小为人类可读格式
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.2f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// ReadFile 读取文件内容
func ReadFile(path string) (string, error) {
if !isSafePath(path) {
@@ -51,7 +87,7 @@ func ListDir(path string) ([]map[string]interface{}, error) {
return nil, fmt.Errorf("读取目录失败: %v", err)
}
var result []map[string]interface{}
result := []map[string]interface{}{}
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
@@ -84,12 +120,44 @@ func CreateDir(path string) error {
return nil
}
// DeletePath 删除文件或目录
func DeletePath(path string) error {
// CreateFile 创建空文件
func CreateFile(path string) error {
if !isSafePath(path) {
return fmt.Errorf("路径不安全")
}
// 检查文件是否已存在
if _, err := os.Stat(path); err == nil {
return fmt.Errorf("文件已存在")
}
// 创建文件(如果父目录不存在,会自动创建)
file, err := os.Create(path)
if err != nil {
return fmt.Errorf("创建文件失败: %v", err)
}
file.Close()
return nil
}
// DeletePath 删除文件或目录
// 优化:使用配置驱动的安全检查,支持确认机制
func DeletePath(path string) error {
// 使用默认配置
return DeletePathWithConfig(path, DefaultConfig())
}
// DeletePathWithConfig 使用指定配置删除文件或目录
// 支持配置化的安全策略和确认机制
func DeletePathWithConfig(path string, config *Config) error {
// 1. 路径安全检查
validator := NewPathValidator(config)
if err := validator.Validate(path); err != nil && err.IsError {
return fmt.Errorf("路径验证失败: %w", err)
}
// 2. 获取文件信息
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
@@ -98,6 +166,28 @@ func DeletePath(path string) error {
return fmt.Errorf("获取文件信息失败: %v", err)
}
// 3. 检查删除限制(配置驱动)
exceeds, details, checkErr := CheckDeleteRestrictions(path, info, config)
if checkErr != nil {
return checkErr
}
if exceeds {
// 根据配置决定是拒绝还是需要确认
if config.Security.DeleteRestrictions.RequireConfirm {
// TODO: 这里应该触发前端确认对话框
// 目前暂时返回警告信息,由前端处理
return &DeleteRestrictionWarning{
Path: path,
Details: details,
Info: info,
}
}
// 不需要确认,直接拒绝
return fmt.Errorf("删除限制: %s", details)
}
// 4. 执行删除操作
if info.IsDir() {
if err := os.RemoveAll(path); err != nil {
return fmt.Errorf("删除目录失败: %v", err)
@@ -111,6 +201,18 @@ func DeletePath(path string) error {
return nil
}
// DeleteRestrictionWarning 删除限制警告
// 用于前端显示确认对话框
type DeleteRestrictionWarning struct {
Path string
Details string
Info os.FileInfo
}
func (w *DeleteRestrictionWarning) Error() string {
return fmt.Sprintf("删除限制警告: %s\n%s", w.Path, w.Details)
}
// GetFileInfo 获取文件信息
func GetFileInfo(path string) (map[string]interface{}, error) {
if !isSafePath(path) {
@@ -125,19 +227,6 @@ func GetFileInfo(path string) (map[string]interface{}, error) {
return nil, fmt.Errorf("获取文件信息失败: %v", err)
}
formatBytes := func(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.2f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
return map[string]interface{}{
"name": info.Name(),
"path": path,
@@ -155,36 +244,35 @@ func OpenPath(path string) error {
return fmt.Errorf("路径不安全")
}
// 注意:这里需要导入 os/exec但为了安全暂时不实现执行命令
// 可以考虑使用 Wails 的 runtime 包提供的功能
return fmt.Errorf("打开功能暂未实现,请手动打开: %s", path)
}
path = filepath.Clean(path)
// isSafePath 检查路径是否安全(防止路径遍历攻击)
func isSafePath(path string) bool {
// 清理路径
cleanPath := filepath.Clean(path)
var cmd *exec.Cmd
// 检查是否包含路径遍历
if strings.Contains(cleanPath, "..") {
return false
switch runtime.GOOS {
case "windows":
// Windows: 使用 rundll32 打开文件(更可靠)
// 这种方式比 cmd start 更稳定,支持所有文件类型
cmd = exec.Command("rundll32.exe", "url.dll,FileProtocolHandler", path)
case "darwin":
// macOS: 使用 open 命令
cmd = exec.Command("open", path)
case "linux":
// Linux: 使用 xdg-open 命令
cmd = exec.Command("xdg-open", path)
default:
return fmt.Errorf("不支持的操作系统")
}
// Windows 下检查是否尝试访问系统关键目录
if runtime.GOOS == "windows" {
lowerPath := strings.ToLower(cleanPath)
// 禁止访问系统关键目录(可根据需要调整)
forbidden := []string{
"c:\\windows",
"c:\\program files",
"c:\\programdata",
}
for _, fb := range forbidden {
if strings.HasPrefix(lowerPath, fb) {
return false
}
}
// 启动命令(不等待完成)
if err := cmd.Start(); err != nil {
return fmt.Errorf("打开文件失败: %v", err)
}
return true
// 给进程一点时间启动
go func() {
time.Sleep(100 * time.Millisecond)
cmd.Process.Release()
}()
return nil
}

View File

@@ -0,0 +1,177 @@
package filesystem
import (
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
)
// LogLevel 日志级别
type LogLevel int
const (
LogLevelDebug LogLevel = iota
LogLevelInfo
LogLevelWarn
LogLevelError
)
// Logger 结构化日志记录器
type Logger struct {
minLevel LogLevel
logFile *os.File
logPath string
mu sync.Mutex
prefix string
}
// NewLogger 创建新的日志记录器
func NewLogger(logPath string, minLevel LogLevel) (*Logger, error) {
// 创建日志目录
if err := os.MkdirAll(filepath.Dir(logPath), 0755); err != nil {
return nil, fmt.Errorf("创建日志目录失败: %w", err)
}
// 打开日志文件
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("打开日志文件失败: %w", err)
}
return &Logger{
minLevel: minLevel,
logFile: logFile,
logPath: logPath,
prefix: "[FileSystem]",
}, nil
}
// Close 关闭日志记录器
func (l *Logger) Close() error {
l.mu.Lock()
defer l.mu.Unlock()
if l.logFile != nil {
return l.logFile.Close()
}
return nil
}
// Debug 记录调试日志
func (l *Logger) Debug(format string, args ...interface{}) {
l.log(LogLevelDebug, "DEBUG", format, args...)
}
// Info 记录信息日志
func (l *Logger) Info(format string, args ...interface{}) {
l.log(LogLevelInfo, "INFO", format, args...)
}
// Warn 记录警告日志
func (l *Logger) Warn(format string, args ...interface{}) {
l.log(LogLevelWarn, "WARN", format, args...)
}
// Error 记录错误日志
func (l *Logger) Error(format string, args ...interface{}) {
l.log(LogLevelError, "ERROR", format, args...)
}
// log 内部日志记录方法
func (l *Logger) log(level LogLevel, levelStr, format string, args ...interface{}) {
if level < l.minLevel {
return
}
l.mu.Lock()
defer l.mu.Unlock()
// 格式化消息
msg := fmt.Sprintf(format, args...)
timestamp := time.Now().Format("2006-01-02 15:04:05.000")
// 写入日志文件
logLine := fmt.Sprintf("%s %s %s %s\n", timestamp, l.prefix, levelStr, msg)
if l.logFile != nil {
if _, err := l.logFile.WriteString(logLine); err != nil {
// 日志写入失败,输出到控制台
log.Print(logLine)
}
}
// 根据级别决定是否输出到控制台
if level >= LogLevelWarn {
log.Print(logLine)
}
}
// LogOperation 记录操作日志(辅助函数)
func LogOperation(operation, path string, success bool, err error) {
logger := GetGlobalLogger()
if logger == nil {
return
}
if success {
logger.Info("操作: %s %s - 成功", operation, path)
} else {
logger.Error("操作: %s %s - 失败: %v", operation, path, err)
}
}
// LogError 记录错误日志(辅助函数)
func LogError(operation string, path string, err error) {
logger := GetGlobalLogger()
if logger == nil {
return
}
logger.Error("错误: %s %s - %v", operation, path, err)
// 如果是调试模式,输出堆栈跟踪
if os.Getenv("UDESK_DEBUG") == "1" {
logger.Debug("堆栈:\n%s", GetStackTrace(2))
}
}
// ========== 全局日志记录器(向后兼容)==========
var (
globalLogger *Logger
loggerOnce sync.Once
)
// InitLogger 初始化全局日志记录器
func InitLogger(logDir string, minLevel LogLevel) error {
var initErr error
loggerOnce.Do(func() {
timestamp := time.Now().Format("2006-01-02")
logPath := filepath.Join(logDir, fmt.Sprintf("filesystem_%s.log", timestamp))
logger, err := NewLogger(logPath, minLevel)
if err != nil {
initErr = err
return
}
globalLogger = logger
log.Printf("[日志系统] 已启动,日志文件: %s", logPath)
})
return initErr
}
// GetGlobalLogger 获取全局日志记录器
func GetGlobalLogger() *Logger {
return globalLogger
}
// CloseLogger 关闭全局日志记录器
func CloseLogger() error {
if globalLogger != nil {
return globalLogger.Close()
}
return nil
}

View File

@@ -0,0 +1,197 @@
package filesystem
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
)
// PathValidator 路径验证器接口
// 提供统一的路径安全检查,避免重复代码
type PathValidator interface {
// Validate 验证路径并返回详细的错误信息
Validate(path string) *ValidationError
// IsSafe 快速检查路径是否安全
IsSafe(path string) bool
// IsSensitive 检查路径是否为敏感路径
IsSensitive(path string) bool
}
// ValidationError 验证错误
type ValidationError struct {
Path string
Reason string
IsError bool // true=禁止访问, false=敏感路径
}
func (e *ValidationError) Error() string {
if e.IsError {
return fmt.Sprintf("路径验证失败: %s - %s", e.Path, e.Reason)
}
return fmt.Sprintf("敏感路径警告: %s - %s", e.Path, e.Reason)
}
// DefaultPathValidator 默认路径验证器实现
type DefaultPathValidator struct {
config *Config
}
// NewPathValidator 创建新的路径验证器
func NewPathValidator(config *Config) PathValidator {
return &DefaultPathValidator{
config: config,
}
}
// Validate 验证路径
func (v *DefaultPathValidator) Validate(path string) *ValidationError {
// 清理路径
cleanPath := filepath.Clean(path)
// 1. 检查路径遍历攻击
if strings.Contains(cleanPath, PathTraversalPattern) {
return &ValidationError{
Path: path,
Reason: "检测到路径遍历尝试",
IsError: true,
}
}
// 2. 检查符号链接
if !v.config.Security.PathValidation.AllowSymlinks {
if fi, err := os.Lstat(path); err == nil && fi.Mode()&os.ModeSymlink != 0 {
return &ValidationError{
Path: path,
Reason: "不允许访问符号链接",
IsError: true,
}
}
}
// 3. 检查UNC路径Windows
if runtime.GOOS == "windows" && !v.config.Security.PathValidation.AllowUNCPaths {
if strings.HasPrefix(cleanPath, `\\`) {
return &ValidationError{
Path: path,
Reason: "不允许访问UNC网络路径",
IsError: true,
}
}
}
// 4. Windows特定检查
if runtime.GOOS == "windows" && v.config.Security.PathValidation.CheckWindowsSystemPaths {
if err := v.checkWindowsSystemPaths(cleanPath); err != nil {
return err
}
}
// 5. 检查敏感路径
if v.isSensitivePath(cleanPath) {
return &ValidationError{
Path: path,
Reason: "访问敏感路径",
IsError: false, // 警告而非错误
}
}
return nil
}
// IsSafe 快速检查路径是否安全
func (v *DefaultPathValidator) IsSafe(path string) bool {
err := v.Validate(path)
return err == nil || !err.IsError
}
// IsSensitive 检查路径是否为敏感路径
func (v *DefaultPathValidator) IsSensitive(path string) bool {
cleanPath := filepath.Clean(path)
return v.isSensitivePath(cleanPath)
}
// checkWindowsSystemPaths 检查Windows系统路径
func (v *DefaultPathValidator) checkWindowsSystemPaths(path string) *ValidationError {
lowerPath := strings.ToLower(path)
// 检查盘符
if len(lowerPath) >= 3 && lowerPath[1] == ':' {
driveLetter := lowerPath[0:1]
// 检查系统关键目录
forbiddenDirs := []string{
driveLetter + ":\\windows",
driveLetter + ":\\program files",
driveLetter + ":\\program files (x86)",
driveLetter + ":\\program files (arm)",
driveLetter + ":\\programdata",
driveLetter + ":\\system volume information",
driveLetter + ":\\recovery",
driveLetter + ":\\boot",
}
for _, fb := range forbiddenDirs {
if strings.HasPrefix(lowerPath, fb) {
return &ValidationError{
Path: path,
Reason: fmt.Sprintf("禁止访问系统目录: %s", fb),
IsError: true,
}
}
}
// 检查用户配置目录(可能包含敏感信息)
forbiddenPaths := []string{
"\\.ssh\\",
"\\.gnupg\\",
"\\.config\\",
"\\appdata\\roaming\\mozilla\\",
"\\appdata\\roaming\\google\\chrome\\",
"\\appdata\\local\\google\\user data\\",
}
for _, fp := range forbiddenPaths {
if strings.Contains(lowerPath, fp) {
return &ValidationError{
Path: path,
Reason: "禁止访问敏感配置目录",
IsError: true,
}
}
}
}
return nil
}
// isSensitivePath 检查是否为敏感路径
func (v *DefaultPathValidator) isSensitivePath(path string) bool {
lowerPath := strings.ToLower(filepath.Clean(path))
// 检查配置的敏感路径列表
for _, sp := range v.config.Security.PathValidation.SensitivePaths {
if strings.Contains(lowerPath, strings.ToLower(sp)) {
return true
}
}
return false
}
// isSafePath 兼容函数:保持向后兼容
// 使用默认配置的路径验证器
func isSafePath(path string) bool {
validator := NewPathValidator(DefaultConfig())
return validator.IsSafe(path)
}
// isSensitivePath 兼容函数:保持向后兼容
// 使用默认配置检查敏感路径
func isSensitivePath(path string) bool {
validator := NewPathValidator(DefaultConfig())
return validator.IsSensitive(path)
}

View File

@@ -0,0 +1,392 @@
package filesystem
import (
"crypto/rand"
"encoding/json"
"fmt"
"io"
"math/big"
"os"
"path/filepath"
"time"
)
// RecycleBinEntry 回收站条目
type RecycleBinEntry struct {
OriginalPath string `json:"original_path"` // 原始路径
DeletedPath string `json:"deleted_path"` // 回收站中的路径
DeletedTime time.Time `json:"deleted_time"` // 删除时间
Size int64 `json:"size"` // 文件大小
IsDirectory bool `json:"is_directory"` // 是否为目录
OriginalDevice string `json:"original_device"` // 原始设备(盘符)
}
// RecycleBin 回收站管理器
type RecycleBin struct {
binPath string
metadataFile string
entries []RecycleBinEntry
}
// NewRecycleBin 创建回收站管理器
func NewRecycleBin(binPath string) (*RecycleBin, error) {
// 创建回收站目录
if err := os.MkdirAll(binPath, 0755); err != nil {
return nil, fmt.Errorf("创建回收站目录失败: %v", err)
}
bin := &RecycleBin{
binPath: binPath,
metadataFile: filepath.Join(binPath, "metadata.json"),
entries: make([]RecycleBinEntry, 0),
}
// 加载元数据
if err := bin.loadMetadata(); err != nil {
// 如果文件不存在,这是正常的,忽略错误
if !os.IsNotExist(err) {
return nil, fmt.Errorf("加载回收站元数据失败: %v", err)
}
}
// 启动自动清理协程
go bin.autoCleanup()
return bin, nil
}
// MoveToRecycleBin 移动文件到回收站
func (rb *RecycleBin) MoveToRecycleBin(path string) error {
// 获取文件信息
info, err := os.Stat(path)
if err != nil {
return fmt.Errorf("获取文件信息失败: %v", err)
}
// 生成唯一的回收站文件名
timestamp := time.Now().Format("20060102_150405")
randomSuffix := generateRandomString(6)
baseName := filepath.Base(path)
var recycleName string
if info.IsDir() {
recycleName = fmt.Sprintf("%s_%s_%s", timestamp, randomSuffix, baseName)
} else {
ext := filepath.Ext(baseName)
nameWithoutExt := baseName[:len(baseName)-len(ext)]
recycleName = fmt.Sprintf("%s_%s_%s%s", timestamp, randomSuffix, nameWithoutExt, ext)
}
recyclePath := filepath.Join(rb.binPath, recycleName)
// 移动文件到回收站
if err := os.Rename(path, recyclePath); err != nil {
// 如果跨设备移动失败,尝试复制后删除
if err := copyRecursively(path, recyclePath); err != nil {
return fmt.Errorf("移动到回收站失败: %v", err)
}
os.RemoveAll(path)
}
// 创建元数据条目
entry := RecycleBinEntry{
OriginalPath: path,
DeletedPath: recyclePath,
DeletedTime: time.Now(),
Size: info.Size(),
IsDirectory: info.IsDir(),
OriginalDevice: getDevice(path),
}
// 添加到元数据
rb.entries = append(rb.entries, entry)
// 保存元数据
if err := rb.saveMetadata(); err != nil {
return fmt.Errorf("保存回收站元数据失败: %v", err)
}
return nil
}
// RestoreFromRecycleBin 从回收站恢复文件
func (rb *RecycleBin) RestoreFromRecycleBin(recyclePath string) error {
// 查找对应的元数据条目
var entry *RecycleBinEntry
for i := range rb.entries {
if rb.entries[i].DeletedPath == recyclePath {
entry = &rb.entries[i]
// 从列表中移除
rb.entries = append(rb.entries[:i], rb.entries[i+1:]...)
break
}
}
if entry == nil {
return fmt.Errorf("回收站中未找到该文件")
}
// 检查原始路径的父目录是否存在
parentDir := filepath.Dir(entry.OriginalPath)
if err := os.MkdirAll(parentDir, 0755); err != nil {
return fmt.Errorf("创建父目录失败: %v", err)
}
// 检查原始位置是否已有文件
if _, err := os.Stat(entry.OriginalPath); err == nil {
return fmt.Errorf("原始位置已存在同名文件,请先删除或重命名")
}
// 移回文件
if err := os.Rename(recyclePath, entry.OriginalPath); err != nil {
// 如果跨设备移动失败,尝试复制后删除
if err := copyRecursively(recyclePath, entry.OriginalPath); err != nil {
return fmt.Errorf("恢复文件失败: %v", err)
}
os.RemoveAll(recyclePath)
}
// 保存元数据
if err := rb.saveMetadata(); err != nil {
return fmt.Errorf("保存回收站元数据失败: %v", err)
}
return nil
}
// DeletePermanently 永久删除回收站中的文件
func (rb *RecycleBin) DeletePermanently(recyclePath string) error {
// 查找元数据条目
for i, entry := range rb.entries {
if entry.DeletedPath == recyclePath {
// 从列表中移除
rb.entries = append(rb.entries[:i], rb.entries[i+1:]...)
break
}
}
// 删除文件
if err := os.RemoveAll(recyclePath); err != nil {
return fmt.Errorf("永久删除失败: %v", err)
}
// 保存元数据
if err := rb.saveMetadata(); err != nil {
return fmt.Errorf("保存回收站元数据失败: %v", err)
}
return nil
}
// ListEntries 列出回收站中的所有条目
func (rb *RecycleBin) ListEntries() []RecycleBinEntry {
return rb.entries
}
// Empty 清空回收站
func (rb *RecycleBin) Empty() error {
// 删除所有文件
for _, entry := range rb.entries {
if err := os.RemoveAll(entry.DeletedPath); err != nil {
return fmt.Errorf("删除文件失败: %s", err)
}
}
// 清空元数据
rb.entries = make([]RecycleBinEntry, 0)
// 保存元数据
if err := rb.saveMetadata(); err != nil {
return fmt.Errorf("保存回收站元数据失败: %v", err)
}
return nil
}
// autoCleanup 自动清理超过30天的文件
func (rb *RecycleBin) autoCleanup() {
ticker := time.NewTicker(24 * time.Hour) // 每天检查一次
defer ticker.Stop()
for range ticker.C {
rb.cleanupExpiredEntries()
}
}
// cleanupExpiredEntries 清理过期的条目
func (rb *RecycleBin) cleanupExpiredEntries() {
now := time.Now()
expiredEntries := make([]int, 0)
// 找出所有过期的条目超过30天
for i, entry := range rb.entries {
if now.Sub(entry.DeletedTime) > 30*24*time.Hour {
expiredEntries = append(expiredEntries, i)
}
}
// 从后往前删除(避免索引问题)
for i := len(expiredEntries) - 1; i >= 0; i-- {
idx := expiredEntries[i]
entry := rb.entries[idx]
// 删除文件
_ = os.RemoveAll(entry.DeletedPath)
// 从列表中移除
rb.entries = append(rb.entries[:idx], rb.entries[idx+1:]...)
}
// 保存元数据
if len(expiredEntries) > 0 {
_ = rb.saveMetadata()
}
}
// loadMetadata 加载元数据
func (rb *RecycleBin) loadMetadata() error {
data, err := os.ReadFile(rb.metadataFile)
if err != nil {
return err
}
return json.Unmarshal(data, &rb.entries)
}
// saveMetadata 保存元数据
func (rb *RecycleBin) saveMetadata() error {
data, err := json.MarshalIndent(rb.entries, "", " ")
if err != nil {
return err
}
return os.WriteFile(rb.metadataFile, data, 0644)
}
// copyRecursively 递归复制文件或目录
func copyRecursively(src, dst string) error {
info, err := os.Stat(src)
if err != nil {
return err
}
if info.IsDir() {
return copyDirectory(src, dst)
}
return copyFile(src, dst)
}
// copyDirectory 复制目录
func copyDirectory(src, dst string) error {
// 创建目标目录
if err := os.MkdirAll(dst, 0755); err != nil {
return err
}
// 读取源目录
entries, err := os.ReadDir(src)
if err != nil {
return err
}
// 复制每个条目
for _, entry := range entries {
srcPath := filepath.Join(src, entry.Name())
dstPath := filepath.Join(dst, entry.Name())
if entry.IsDir() {
if err := copyDirectory(srcPath, dstPath); err != nil {
return err
}
} else {
if err := copyFile(srcPath, dstPath); err != nil {
return err
}
}
}
return nil
}
// copyFile 复制文件
func copyFile(src, dst string) error {
// 打开源文件
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
// 创建目标文件
dstFile, err := os.Create(dst)
if err != nil {
return err
}
defer dstFile.Close()
// 复制内容
if _, err := io.Copy(dstFile, srcFile); err != nil {
return err
}
// 复制文件权限
srcInfo, err := os.Stat(src)
if err != nil {
return err
}
return os.Chmod(dst, srcInfo.Mode())
}
// getDevice 获取文件所在设备(盘符)
func getDevice(path string) string {
absPath, err := filepath.Abs(path)
if err != nil {
return ""
}
if len(absPath) >= 2 {
return absPath[:2] // 返回 "C:" 这样的盘符
}
return ""
}
// generateRandomString 生成随机字符串
// 使用加密安全的随机数生成器,保证随机性和性能
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length)
// 使用 crypto/rand 生成安全的随机数
for i := range b {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
// 如果加密随机数生成失败,回退到时间戳(极低概率)
b[i] = charset[time.Now().UnixNano()%int64(len(charset))]
continue
}
b[i] = charset[n.Int64()]
}
return string(b)
}
// 全局回收站实例
var globalRecycleBin *RecycleBin
// InitRecycleBin 初始化全局回收站
func InitRecycleBin(binPath string) error {
bin, err := NewRecycleBin(binPath)
if err != nil {
return err
}
globalRecycleBin = bin
return nil
}
// GetRecycleBin 获取全局回收站实例
func GetRecycleBin() *RecycleBin {
return globalRecycleBin
}

View File

@@ -0,0 +1,549 @@
package filesystem
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"sync"
"time"
)
// FileSystemService 文件系统服务
// 统一管理所有文件系统相关的功能,使用依赖注入而非全局变量
type FileSystemService struct {
// 核心组件
config *Config
pathValidator PathValidator
fileTypeManager FileTypeManager
// 基础设施组件
auditLogger *AuditLogger
recycleBin *RecycleBin
lockChecker *FileLockChecker
// 状态管理
mu sync.RWMutex
initialized bool
}
// NewFileSystemService 创建新的文件系统服务
// 使用依赖注入,所有组件通过参数传入,便于测试和替换
func NewFileSystemService(config *Config) (*FileSystemService, error) {
if config == nil {
config = DefaultConfig()
}
service := &FileSystemService{
config: config,
pathValidator: NewPathValidator(config),
fileTypeManager: NewFileTypeManager(config),
}
// 初始化基础设施组件
if err := service.initializeComponents(); err != nil {
return nil, fmt.Errorf("初始化文件系统服务失败: %w", err)
}
service.initialized = true
return service, nil
}
// initializeComponents 初始化各个组件
func (s *FileSystemService) initializeComponents() error {
// 1. 初始化审计日志
if s.config.Features.AuditLog {
if err := s.initAuditLogger(); err != nil {
return fmt.Errorf("初始化审计日志失败: %w", err)
}
}
// 2. 初始化回收站
if s.config.Features.RecycleBin {
if err := s.initRecycleBin(); err != nil {
return fmt.Errorf("初始化回收站失败: %w", err)
}
}
// 3. 初始化文件锁检查器
if s.config.Features.FileLockCheck {
s.lockChecker = NewFileLockChecker()
}
return nil
}
// initAuditLogger 初始化审计日志
func (s *FileSystemService) initAuditLogger() error {
// 获取日志目录
userDataDir := getUserDataDir()
logDir := filepath.Join(userDataDir, "logs")
logger, err := NewAuditLogger(logDir)
if err != nil {
return err
}
s.auditLogger = logger
return nil
}
// initRecycleBin 初始化回收站
func (s *FileSystemService) initRecycleBin() error {
// 获取回收站目录
userDataDir := getUserDataDir()
recycleBinPath := filepath.Join(userDataDir, "recycle_bin")
bin, err := NewRecycleBin(recycleBinPath)
if err != nil {
return err
}
s.recycleBin = bin
return nil
}
// ========== 核心文件操作 ==========
// Read 读取文件内容(实现 FileService 接口)
func (s *FileSystemService) Read(path string) (string, error) {
return s.ReadFile(path)
}
// ReadFile 读取文件内容
func (s *FileSystemService) ReadFile(path string) (string, error) {
// 路径验证
if err := s.validatePath(path); err != nil {
return "", err
}
// 读取文件
data, err := os.ReadFile(path)
if err != nil {
return "", fmt.Errorf("读取文件失败: %v", err)
}
// 记录审计日志
if s.auditLogger != nil {
s.auditLogger.LogRead(path, int64(len(data)), nil)
}
return string(data), nil
}
// Write 写入文件内容(实现 FileService 接口)
func (s *FileSystemService) Write(path, content string) error {
return s.WriteFile(path, content)
}
// WriteFile 写入文件
func (s *FileSystemService) WriteFile(path, content string) error {
// 路径验证
if err := s.validatePath(path); err != nil {
return err
}
// 创建目录
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, DefaultDirPermissions); err != nil {
return fmt.Errorf("创建目录失败: %v", err)
}
// 写入文件
data := []byte(content)
if err := os.WriteFile(path, data, DefaultFilePermissions); err != nil {
// 记录审计日志
if s.auditLogger != nil {
s.auditLogger.LogWrite(path, int64(len(data)), err)
}
return fmt.Errorf("写入文件失败: %v", err)
}
// 记录审计日志
if s.auditLogger != nil {
s.auditLogger.LogWrite(path, int64(len(data)), nil)
}
return nil
}
// List 列出目录内容(实现 FileService 接口)
func (s *FileSystemService) List(path string) ([]map[string]interface{}, error) {
return s.ListDir(path)
}
// Open 打开文件(实现 FileService 接口)
func (s *FileSystemService) Open(path string) error {
// 使用系统默认程序打开文件
var cmd *exec.Cmd
switch runtime.GOOS {
case "windows":
cmd = exec.Command("cmd", "/c", "start", "", path)
case "darwin":
cmd = exec.Command("open", path)
default:
cmd = exec.Command("xdg-open", path)
}
return cmd.Start()
}
// Delete 删除文件或目录(实现 FileService 接口)
func (s *FileSystemService) Delete(path string) error {
return s.DeletePathWithContext(context.Background(), path)
}
// DeletePath 删除文件或目录
func (s *FileSystemService) DeletePath(path string) error {
return s.DeletePathWithContext(context.Background(), path)
}
// DeletePathWithContext 带上下文的删除操作
func (s *FileSystemService) DeletePathWithContext(ctx context.Context, path string) error {
// 路径验证
if err := s.validatePath(path); err != nil {
return err
}
// 获取文件信息
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("文件或目录不存在")
}
return fmt.Errorf("获取文件信息失败: %v", err)
}
// 检查删除限制
exceeds, details, checkErr := CheckDeleteRestrictions(path, info, s.config)
if checkErr != nil {
return checkErr
}
if exceeds {
if s.config.Security.DeleteRestrictions.RequireConfirm {
return &DeleteRestrictionWarning{
Path: path,
Details: details,
Info: info,
}
}
return fmt.Errorf("删除限制: %s", details)
}
// 文件锁检查(可选)
if s.lockChecker != nil {
if err := s.lockChecker.SafeDeleteWithLockCheck(path); err != nil {
return err
}
}
// 执行删除
var deleteErr error
if info.IsDir() {
deleteErr = os.RemoveAll(path)
} else {
deleteErr = os.Remove(path)
}
// 记录审计日志
if s.auditLogger != nil {
s.auditLogger.LogDelete(path, info.IsDir(), info.Size(), deleteErr)
}
if deleteErr != nil {
return fmt.Errorf("删除失败: %v", deleteErr)
}
// 如果启用回收站,移动到回收站而非永久删除
if s.recycleBin != nil {
// 检查是否已在回收站中
if !isInRecycleBin(path) {
if err := s.recycleBin.MoveToRecycleBin(path); err != nil {
// 回收站失败,记录但继续
fmt.Printf("[警告] 移动到回收站失败: %v\n", err)
}
}
}
return nil
}
// ListDir 列出目录内容
func (s *FileSystemService) ListDir(path string) ([]map[string]interface{}, error) {
// 路径验证
if err := s.validatePath(path); err != nil {
return nil, err
}
// 读取目录
entries, err := os.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("读取目录失败: %v", err)
}
// 转换为结果格式
result := make([]map[string]interface{}, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
continue
}
fullPath := filepath.Join(path, entry.Name())
result = append(result, map[string]interface{}{
"name": entry.Name(),
"path": fullPath,
"is_dir": entry.IsDir(),
"size": info.Size(),
"mod_time": info.ModTime().Format("2006-01-02 15:04:05"),
})
}
// 记录审计日志
if s.auditLogger != nil {
s.auditLogger.Log(AuditLogEntry{
Timestamp: getCurrentTimestamp(),
Operation: OperationList,
Path: path,
IsDirectory: true,
Success: true,
})
}
return result, nil
}
// CreateDir 创建目录
func (s *FileSystemService) CreateDir(path string) error {
if err := s.validatePath(path); err != nil {
return err
}
if err := os.MkdirAll(path, DefaultDirPermissions); err != nil {
return fmt.Errorf("创建目录失败: %v", err)
}
// 记录审计日志
if s.auditLogger != nil {
s.auditLogger.Log(AuditLogEntry{
Timestamp: getCurrentTimestamp(),
Operation: OperationCreate,
Path: path,
IsDirectory: true,
Success: true,
})
}
return nil
}
// CreateFile 创建空文件
func (s *FileSystemService) CreateFile(path string) error {
if err := s.validatePath(path); err != nil {
return err
}
// 检查文件是否已存在
if _, err := os.Stat(path); err == nil {
return fmt.Errorf("文件已存在")
}
file, err := os.Create(path)
if err != nil {
return fmt.Errorf("创建文件失败: %v", err)
}
file.Close()
// 记录审计日志
if s.auditLogger != nil {
s.auditLogger.Log(AuditLogEntry{
Timestamp: getCurrentTimestamp(),
Operation: OperationCreate,
Path: path,
IsDirectory: false,
Success: true,
})
}
return nil
}
// GetInfo 获取文件信息(实现 FileService 接口)
func (s *FileSystemService) GetInfo(path string) (map[string]interface{}, error) {
return s.GetFileInfo(path)
}
// GetFileInfo 获取文件信息
func (s *FileSystemService) GetFileInfo(path string) (map[string]interface{}, error) {
if err := s.validatePath(path); err != nil {
return nil, err
}
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("文件或目录不存在")
}
return nil, fmt.Errorf("获取文件信息失败: %v", err)
}
return map[string]interface{}{
"name": info.Name(),
"path": path,
"size": info.Size(),
"size_str": formatBytes(info.Size()),
"is_dir": info.IsDir(),
"mod_time": info.ModTime().Format("2006-01-02 15:04:05"),
"mode": info.Mode().String(),
}, nil
}
// OpenPath 打开文件或目录(使用系统默认程序)
func (s *FileSystemService) OpenPath(path string) error {
if err := s.validatePath(path); err != nil {
return err
}
return OpenPath(path)
}
// ========== ZIP操作接口 ==========
// ListZip 列出ZIP文件内容
func (s *FileSystemService) ListZip(zipPath string) ([]map[string]interface{}, error) {
return ListZipContents(zipPath)
}
// ExtractZipFile 从ZIP提取文件内容
func (s *FileSystemService) ExtractZipFile(zipPath, filePath string) (string, error) {
return ExtractFileFromZip(zipPath, filePath)
}
// ExtractZipFileToTemp 从ZIP提取文件到临时目录
func (s *FileSystemService) ExtractZipFileToTemp(zipPath, filePath string) (string, error) {
return ExtractFileFromZipToTemp(zipPath, filePath)
}
// GetZipFileInfo 获取ZIP文件信息
func (s *FileSystemService) GetZipFileInfo(zipPath, filePath string) (map[string]interface{}, error) {
return GetZipFileInfo(zipPath, filePath)
}
// ========== 辅助函数 ==========
// getUserDataDir 获取用户数据目录
func getUserDataDir() string {
var basePath string
switch runtime.GOOS {
case "windows":
basePath = os.Getenv("LOCALAPPDATA")
if basePath == "" {
basePath = os.Getenv("APPDATA")
}
case "darwin":
homeDir, _ := os.UserHomeDir()
basePath = filepath.Join(homeDir, "Library", "Application Support")
default:
homeDir, _ := os.UserHomeDir()
basePath = filepath.Join(homeDir, ".config")
}
if basePath == "" {
basePath = "."
}
return filepath.Join(basePath, "u-desk")
}
// getCurrentTimestamp 获取当前时间戳
func getCurrentTimestamp() time.Time {
return time.Now()
}
// isInRecycleBin 检查路径是否在回收站中
func isInRecycleBin(path string) bool {
// 简化版本:检查路径是否包含回收站目录名
userDataDir := getUserDataDir()
recycleBinPath := filepath.Join(userDataDir, "recycle_bin")
return filepath.HasPrefix(filepath.Clean(path), filepath.Clean(recycleBinPath))
}
// ========== 辅助方法 ==========
// validatePath 验证路径
func (s *FileSystemService) validatePath(path string) error {
err := s.pathValidator.Validate(path)
if err != nil && err.IsError {
return err
}
return nil
}
// GetConfig 获取配置
func (s *FileSystemService) GetConfig() *Config {
return s.config
}
// GetAuditLogger 获取审计日志记录器
func (s *FileSystemService) GetAuditLogger() *AuditLogger {
return s.auditLogger
}
// GetRecycleBin 获取回收站
func (s *FileSystemService) GetRecycleBin() *RecycleBin {
return s.recycleBin
}
// Close 关闭服务,释放资源
func (s *FileSystemService) Close(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.initialized {
return nil
}
// 关闭审计日志
if s.auditLogger != nil {
if err := s.auditLogger.Close(); err != nil {
return fmt.Errorf("关闭审计日志失败: %w", err)
}
}
s.initialized = false
return nil
}
// ========== 全局服务实例(向后兼容)==========
var (
globalService *FileSystemService
globalServiceOnce sync.Once
)
// GetGlobalService 获取全局文件系统服务实例(单例)
// 保持向后兼容,但推荐使用依赖注入
func GetGlobalService() (*FileSystemService, error) {
var initErr error
globalServiceOnce.Do(func() {
globalService, initErr = NewFileSystemService(DefaultConfig())
})
return globalService, initErr
}
// InitGlobalFileSystem 初始化全局文件系统(兼容旧代码)
func InitGlobalFileSystem() error {
_, err := GetGlobalService()
return err
}
// CloseGlobalFileSystem 关闭全局文件系统
func CloseGlobalFileSystem(ctx context.Context) error {
if globalService != nil {
return globalService.Close(ctx)
}
return nil
}

View File

@@ -0,0 +1,27 @@
package filesystem
import (
"context"
)
// FileService 文件操作核心接口
// 定义所有文件操作的基本功能便于mock测试
type FileService interface {
// 基本操作
Read(path string) (string, error)
Write(path, content string) error
Delete(path string) error
List(path string) ([]map[string]interface{}, error)
CreateDir(path string) error
CreateFile(path string) error
GetInfo(path string) (map[string]interface{}, error)
Open(path string) error
// 配置
GetConfig() *Config
Close(ctx context.Context) error
}
// 确保实现接口
var _ FileService = (*FileSystemService)(nil)

391
internal/filesystem/zip.go Normal file
View File

@@ -0,0 +1,391 @@
package filesystem
import (
"archive/zip"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
"time"
)
// ZipFileEntry 表示 zip 文件中的一个文件条目
type ZipFileEntry struct {
Name string
Path string // 在 zip 中的完整路径
Size int64
Modified string
IsDir bool
Method string // 压缩方法 (Store/Deflate)
}
// validateZipPath 验证 ZIP 文件路径是否有效
// 统一的路径验证逻辑,避免在多个函数中重复
func validateZipPath(zipPath string) error {
if !isSafePath(zipPath) {
return fmt.Errorf("zip 路径不安全")
}
// 检查 zip 文件是否存在
if _, err := os.Stat(zipPath); os.IsNotExist(err) {
return fmt.Errorf("zip 文件不存在")
}
return nil
}
// debugLog 条件日志记录,仅在调试模式下输出
// 通过设置环境变量 UDESK_ZIP_DEBUG=1 启用调试日志
var zipDebugMode = os.Getenv("UDESK_ZIP_DEBUG") == "1"
func debugLog(format string, args ...interface{}) {
if zipDebugMode {
log.Printf(format, args...)
}
}
// ListZipContents 列出 zip 文件内容
// 🔒 安全增强添加ZIP炸弹防护、路径遍历检查
func ListZipContents(zipPath string) ([]map[string]interface{}, error) {
debugLog("[ListZipContents] 开始处理 ZIP 文件: %s", zipPath)
// 统一验证路径
if err := validateZipPath(zipPath); err != nil {
debugLog("[ListZipContents] 路径验证失败: %v", err)
return nil, err
}
// 检查文件是否存在
fileInfo, err := os.Stat(zipPath)
if err != nil {
debugLog("[ListZipContents] 文件状态检查失败: %v", err)
return nil, fmt.Errorf("无法访问文件: %v", err)
}
debugLog("[ListZipContents] 文件信息: 大小=%d bytes, 权限=%v", fileInfo.Size(), fileInfo.Mode())
// 🔒 安全检查:检查文件大小(太小或太大)
if fileInfo.Size() < 22 {
debugLog("[ListZipContents] 文件太小,可能不是有效的 ZIP 文件: %d bytes", fileInfo.Size())
return nil, fmt.Errorf("文件太小 (%d bytes),可能不是有效的 ZIP 文件", fileInfo.Size())
}
// 🔒 安全检查ZIP炸弹防护检查文件大小
if fileInfo.Size() > MaxZipSize {
debugLog("[ListZipContents] ZIP文件过大: %d bytes", fileInfo.Size())
return nil, fmt.Errorf("ZIP文件过大 (%d bytes),超过限制 (%d bytes)", fileInfo.Size(), MaxZipSize)
}
// 检查文件是否可读
file, err := os.Open(zipPath)
if err != nil {
debugLog("[ListZipContents] 无法打开文件: %v", err)
return nil, fmt.Errorf("无法打开文件: %v", err)
}
// 读取前 4 个字节检查 ZIP 文件头
header := make([]byte, 4)
n, err := file.Read(header)
file.Close()
if err != nil || n != 4 {
debugLog("[ListZipContents] 无法读取文件头: n=%d, err=%v", n, err)
return nil, fmt.Errorf("无法读取文件头")
}
debugLog("[ListZipContents] 文件头: 0x%02x 0x%02x 0x%02x 0x%02x", header[0], header[1], header[2], header[3])
// ZIP 文件应该是 PK\x03\x04 或 PK\x05\x06 (空 ZIP)
if header[0] != 0x50 || header[1] != 0x4B { // 'P' 'K'
debugLog("[ListZipContents] 文件头签名错误,不是有效的 ZIP 文件")
return nil, fmt.Errorf("文件头签名错误,不是有效的 ZIP 文件 (可能是其他格式或已损坏)")
}
// 打开 zip 文件
debugLog("[ListZipContents] 尝试打开 ZIP 读取器...")
reader, err := zip.OpenReader(zipPath)
if err != nil {
debugLog("[ListZipContents] 打开 ZIP 失败: %v", err)
debugLog("[ListZipContents] 错误类型: %T", err)
// 提供更详细的错误信息和解决建议
errMsg := fmt.Sprintf("打开 zip 文件失败: %v", err)
if strings.Contains(err.Error(), "not a valid zip file") {
errMsg += "\n\n可能的原因:\n" +
"1. 文件已损坏或不完整\n" +
"2. 不是标准的 ZIP 格式\n" +
"3. 文件正在被其他程序占用(如压缩软件)\n" +
"4. 使用了特殊的压缩方式\n\n" +
"建议解决方法:\n" +
"- 关闭所有可能打开该文件的程序\n" +
"- 尝试用 7-Zip 或 WinRAR 重新压缩\n" +
"- 检查文件大小是否正常(不是 0 字节)\n" +
"- 如果是从网络下载的,尝试重新下载"
}
return nil, fmt.Errorf("%s", errMsg)
}
defer reader.Close()
debugLog("[ListZipContents] 成功打开 ZIP开始读取文件列表...")
// 🔒 安全检查ZIP炸弹防护检查解压后总大小
var totalUncompressed int64
for _, file := range reader.File {
totalUncompressed += int64(file.UncompressedSize64)
}
if totalUncompressed > MaxExtractSize {
debugLog("[ListZipContents] 解压后总大小过大: %d bytes", totalUncompressed)
return nil, fmt.Errorf("解压后总大小过大 (%d bytes),超过限制 (%d bytes)", totalUncompressed, MaxExtractSize)
}
var result []map[string]interface{}
fileCount := 0
dirCount := 0
// 遍历 zip 文件中的所有文件
for _, file := range reader.File {
// 跳过 macOS 资源分支文件
if strings.HasPrefix(file.Name, "__MACOSX/") {
continue
}
// 🔒 安全检查:路径遍历攻击防护
if strings.Contains(file.Name, "..") {
debugLog("[ListZipContents] 检测到路径遍历尝试: %s", file.Name)
return nil, fmt.Errorf("ZIP文件包含不安全的路径: %s", file.Name)
}
// 🔒 安全检查:绝对路径防护
if filepath.IsAbs(file.Name) {
debugLog("[ListZipContents] 检测到绝对路径: %s", file.Name)
return nil, fmt.Errorf("ZIP文件包含绝对路径: %s", file.Name)
}
isDir := file.Mode().IsDir()
name := filepath.Base(file.Name)
// 对于目录,使用目录名;对于文件,使用文件名
if isDir {
name = file.Name
dirCount++
} else {
fileCount++
}
// 压缩方法描述
method := "Store"
if file.Method == 8 {
method = "Deflate"
}
entry := map[string]interface{}{
"name": name,
"path": file.Name, // zip 中的完整路径
"is_dir": isDir,
"size": file.UncompressedSize64,
"compressed": file.CompressedSize64,
"mod_time": file.Modified.Format("2006-01-02 15:04:05"),
"method": method,
}
result = append(result, entry)
}
debugLog("[ListZipContents] 读取完成: %d 个文件, %d 个目录", fileCount, dirCount)
return result, nil
}
// ExtractFileFromZip 从 zip 文件中提取单个文件内容
// 优化:使用通用包装器,消除重复代码
func ExtractFileFromZip(zipPath, filePath string) (string, error) {
result, err := withZipFile(zipPath, filePath, func(file *zip.File) (interface{}, error) {
// 打开文件
rc, err := file.Open()
if err != nil {
return nil, fmt.Errorf("打开 zip 中的文件失败: %v", err)
}
defer rc.Close()
// 读取内容
data, err := readAllFromFile(rc)
if err != nil {
return nil, fmt.Errorf("读取文件内容失败: %v", err)
}
return string(data), nil
})
if err != nil {
return "", err
}
return result.(string), nil
}
// ExtractFileFromZipToTemp 从 zip 文件中提取单个文件到临时目录
// 返回临时文件的完整路径
// 适用于提取图片等二进制文件
// 优化:使用通用包装器,消除重复代码
func ExtractFileFromZipToTemp(zipPath, filePath string) (string, error) {
debugLog("[ExtractFileFromZipToTemp] 开始提取: %s from %s", filePath, zipPath)
// 启动临时文件清理协程
go CleanOldTempFiles()
// 创建临时目录
tempDir := filepath.Join(os.TempDir(), TempFileDir)
if err := os.MkdirAll(tempDir, DefaultDirPermissions); err != nil {
return "", fmt.Errorf("创建临时目录失败: %v", err)
}
result, err := withZipFile(zipPath, filePath, func(file *zip.File) (interface{}, error) {
// 安全检查:文件大小限制
if file.UncompressedSize64 > MaxSingleFileSize {
debugLog("[ExtractFileFromZipToTemp] 文件过大: %d bytes", file.UncompressedSize64)
return nil, fmt.Errorf("文件过大 (%d bytes),超过限制 (%d bytes)",
file.UncompressedSize64, MaxSingleFileSize)
}
// 打开文件
rc, err := file.Open()
if err != nil {
return nil, fmt.Errorf("打开 zip 中的文件失败: %v", err)
}
defer rc.Close()
// 生成临时文件名
tempFileName := fmt.Sprintf("%d_%s", time.Now().UnixNano(), filepath.Base(file.Name))
tempFilePath := filepath.Join(tempDir, tempFileName)
// 创建临时文件
outFile, err := os.Create(tempFilePath)
if err != nil {
return nil, fmt.Errorf("创建临时文件失败: %v", err)
}
defer outFile.Close()
// 限制写入大小
limitedReader := &io.LimitedReader{R: rc, N: MaxSingleFileSize}
written, err := io.Copy(outFile, limitedReader)
if err != nil {
os.Remove(tempFilePath)
return nil, fmt.Errorf("写入临时文件失败: %v", err)
}
// 检查是否超过限制
if limitedReader.N <= 0 {
os.Remove(tempFilePath)
return nil, fmt.Errorf("文件大小超过限制")
}
debugLog("[ExtractFileFromZipToTemp] 提取成功: %s -> %s (%d bytes)",
file.Name, tempFilePath, written)
return tempFilePath, nil
})
if err != nil {
return "", err
}
return result.(string), nil
}
// CleanOldTempFiles 清理超过指定时间的临时文件
// 🔒 新增:防止临时文件累积占用磁盘空间
func CleanOldTempFiles() {
tempDir := filepath.Join(os.TempDir(), "u-desk-zip")
// 检查临时目录是否存在
dir, err := os.Open(tempDir)
if err != nil {
// 目录不存在或其他错误,无需清理
return
}
defer dir.Close()
// 读取目录内容
files, err := dir.Readdir(-1)
if err != nil {
return
}
cleanedCount := 0
now := time.Now()
for _, file := range files {
// 跳过目录
if file.IsDir() {
continue
}
// 检查文件年龄
if now.Sub(file.ModTime()) > TempFileCleanupAge {
filePath := filepath.Join(tempDir, file.Name())
if err := os.Remove(filePath); err == nil {
cleanedCount++
debugLog("[CleanOldTempFiles] 清理临时文件: %s (年龄: %v)", file.Name(), now.Sub(file.ModTime()))
}
}
}
if cleanedCount > 0 {
debugLog("[CleanOldTempFiles] 清理完成: 共清理 %d 个临时文件", cleanedCount)
}
}
// GetZipFileInfo 获取 zip 文件中特定文件的信息
// 优化:使用通用包装器,消除重复代码
func GetZipFileInfo(zipPath, filePath string) (map[string]interface{}, error) {
result, err := withZipFile(zipPath, filePath, func(file *zip.File) (interface{}, error) {
return createFileInfoMap(file, true), nil
})
if err != nil {
return nil, err
}
return result.(map[string]interface{}), nil
}
// validateZipFileBasic 验证ZIP文件的基本信息提取自ListZipContents
func validateZipFileBasic(zipPath string) error {
if err := validateZipPath(zipPath); err != nil {
return err
}
fileInfo, err := os.Stat(zipPath)
if err != nil {
return fmt.Errorf("无法访问文件: %v", err)
}
if fileInfo.Size() < MinValidZipSize {
return fmt.Errorf("文件太小 (%d bytes)", fileInfo.Size())
}
if fileInfo.Size() > MaxZipSize {
return fmt.Errorf("ZIP文件过大 (%d bytes)", fileInfo.Size())
}
return checkZipFileHeader(zipPath)
}
// checkZipFileHeader 检查ZIP文件头签名
func checkZipFileHeader(zipPath string) error {
file, err := os.Open(zipPath)
if err != nil {
return fmt.Errorf("无法打开文件: %v", err)
}
defer file.Close()
header := make([]byte, 4)
n, err := file.Read(header)
if err != nil || n != 4 {
return fmt.Errorf("无法读取文件头")
}
if header[0] != 0x50 || header[1] != 0x4B {
return fmt.Errorf("不是有效的 ZIP 文件")
}
return nil
}

View File

@@ -0,0 +1,121 @@
package filesystem
import (
"archive/zip"
"fmt"
"io"
"path/filepath"
)
// ZipOperation ZIP操作回调函数类型
// 用于 withZipReader 通用包装器
type ZipOperation func(*zip.ReadCloser) (interface{}, error)
// withZipReader 通用的ZIP文件操作包装器
// 消除重复的打开/关闭逻辑,统一错误处理
// 参数:
// - zipPath: ZIP文件路径
// - operation: 操作回调函数,接收 *zip.ReadCloser返回任意结果
//
// 返回:
// - interface{}: 操作结果
// - error: 错误信息
func withZipReader(zipPath string, operation ZipOperation) (interface{}, error) {
// 1. 统一验证路径
if err := validateZipPath(zipPath); err != nil {
return nil, err
}
// 2. 打开 ZIP 文件
reader, err := zip.OpenReader(zipPath)
if err != nil {
return nil, fmt.Errorf("打开 zip 文件失败: %v", err)
}
defer reader.Close()
// 3. 执行操作
result, err := operation(reader)
if err != nil {
return nil, err
}
return result, nil
}
// withZipFile 在ZIP文件中查找特定文件并执行操作
// 进一步封装,用于处理单个文件的操作
type ZipFileOperation func(*zip.File) (interface{}, error)
// withZipFile 在ZIP中查找文件并执行操作
func withZipFile(zipPath, filePath string, operation ZipFileOperation) (interface{}, error) {
return withZipReader(zipPath, func(reader *zip.ReadCloser) (interface{}, error) {
// 查找目标文件
for _, file := range reader.File {
if isMatchFile(file, filePath) {
return operation(file)
}
}
return nil, fmt.Errorf("文件在 zip 中不存在: %s", filePath)
})
}
// isMatchFile 检查文件是否匹配目标路径
func isMatchFile(file *zip.File, targetPath string) bool {
return file.Name == targetPath ||
filepath.Clean(file.Name) == filepath.Clean(targetPath)
}
// openZipFileInReader 在ZIP reader中打开指定文件
// 用于读取文件内容的辅助函数
func openZipFileInReader(reader *zip.ReadCloser, filePath string) (io.ReadCloser, *zip.File, error) {
for _, file := range reader.File {
if isMatchFile(file, filePath) {
if file.Mode().IsDir() {
return nil, nil, fmt.Errorf("不能读取目录")
}
rc, err := file.Open()
if err != nil {
return nil, nil, fmt.Errorf("打开 zip 中的文件失败: %v", err)
}
return rc, file, nil
}
}
return nil, nil, fmt.Errorf("文件在 zip 中不存在: %s", filePath)
}
// readAllFromFile 从文件读取所有内容
// 辅助函数,避免重复的 io.ReadAll 调用
func readAllFromFile(rc io.ReadCloser) ([]byte, error) {
defer rc.Close()
return io.ReadAll(rc)
}
// getCompressionMethodString 获取压缩方法字符串描述
func getCompressionMethodString(method uint16) string {
if method == 8 {
return "Deflate"
}
return "Store"
}
// createFileInfoMap 创建文件信息map通用格式
func createFileInfoMap(file *zip.File, includeExtra ...bool) map[string]interface{} {
info := map[string]interface{}{
"name": filepath.Base(file.Name),
"path": file.Name,
"is_dir": file.Mode().IsDir(),
"size": file.UncompressedSize64,
"compressed": file.CompressedSize64,
"mod_time": file.Modified.Format("2006-01-02 15:04:05"),
"method": getCompressionMethodString(file.Method),
}
// 可选:额外信息
if len(includeExtra) > 0 && includeExtra[0] {
info["mode"] = file.Mode().String()
info["comment"] = file.Comment
}
return info
}

View File

@@ -7,6 +7,7 @@ import (
"strings"
"time"
"go-desk/internal/common"
"go-desk/internal/dbclient"
"go-desk/internal/storage/models"
"go-desk/internal/storage/repository"
@@ -48,7 +49,7 @@ func (s *SqlExecService) ExecuteSQL(connectionID uint, sqlStr string, database s
}
startTime := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
defer cancel()
switch conn.Type {
@@ -214,7 +215,7 @@ func (s *SqlExecService) GetDatabases(connectionID uint) ([]string, error) {
return nil, fmt.Errorf("获取连接配置失败: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
defer cancel()
switch conn.Type {
@@ -248,7 +249,7 @@ func (s *SqlExecService) GetTables(connectionID uint, database string) ([]string
return nil, fmt.Errorf("获取连接配置失败: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutFastQuery)
defer cancel()
switch conn.Type {
@@ -324,7 +325,7 @@ func (s *SqlExecService) GetTableStructure(connectionID uint, database, tableNam
return nil, fmt.Errorf("获取连接配置失败: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
defer cancel()
switch conn.Type {
@@ -387,7 +388,7 @@ func (s *SqlExecService) GetIndexes(connectionID uint, database, tableName strin
return nil, fmt.Errorf("获取连接配置失败: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutQuery)
defer cancel()
switch conn.Type {
@@ -413,7 +414,7 @@ func (s *SqlExecService) PreviewTableStructure(connectionID uint, database, tabl
return nil, fmt.Errorf("获取连接配置失败: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
defer cancel()
switch conn.Type {
@@ -443,7 +444,7 @@ func (s *SqlExecService) UpdateTableStructure(connectionID uint, database, table
return nil, fmt.Errorf("获取连接配置失败: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), common.TimeoutLongOp)
defer cancel()
switch conn.Type {

View File

@@ -13,7 +13,14 @@ import (
var globalDB *gorm.DB
// Init 快速初始化 SQLite兼容旧代码
func Init() (*gorm.DB, error) {
return InitFast()
}
// InitFast 超快速初始化 SQLite优化版
// 跳过不必要的检查,使用 WAL 模式,优化连接池
func InitFast() (*gorm.DB, error) {
if globalDB != nil {
return globalDB, nil
}
@@ -24,22 +31,35 @@ func Init() (*gorm.DB, error) {
}
dataDir := filepath.Join(homeDir, ".go-desk")
os.MkdirAll(dataDir, 0755)
if err := os.MkdirAll(dataDir, 0755); err != nil {
return nil, err
}
dbPath := filepath.Join(dataDir, "db-cli.db")
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
// 极限性能优化参数:
// - journal_mode=WAL: 写前日志,大幅提升并发性能
// - synchronous=NORMAL: 降低持久性要求,提升性能
// - cache_size=-64000: 64MB 缓存,减少磁盘 I/O
// - temp_store=MEMORY: 临时表存储在内存中
// - mmap_size=30000000000: 300MB 内存映射,加速读取
// - page_size=4096: 优化页面大小
db, err := gorm.Open(sqlite.Open(dbPath+"?_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)&_pragma=cache_size(-64000)&_pragma=temp_store(MEMORY)&_pragma=mmap_size(30000000000)&_pragma=page_size(4096)&_pragma=foreign_keys(1)"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
SkipDefaultTransaction: true, // 跳过默认事务,提升性能
PrepareStmt: true, // 预编译语句缓存
})
if err != nil {
return nil, err
}
sqlDB, _ := db.DB()
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxOpenConns(1) // SQLite 只需要一个连接
sqlDB.SetMaxIdleConns(1)
sqlDB.SetConnMaxLifetime(time.Hour)
// AutoMigrate 在启动时执行,但只在表结构不存在时创建
// SQLite 的 AutoMigrate 很快,不会造成明显延迟
if err := db.AutoMigrate(
&models.DbConnection{},
&models.SqlTab{},

View File

@@ -4,6 +4,8 @@ import (
"fmt"
"runtime"
"go-desk/internal/common"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/host"
@@ -72,19 +74,6 @@ func GetMemoryInfo() (map[string]interface{}, error) {
return nil, fmt.Errorf("获取内存信息失败: %v", err)
}
formatBytes := func(bytes uint64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.2f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
return map[string]interface{}{
"total": memInfo.Total,
"used": memInfo.Used,
@@ -92,10 +81,10 @@ func GetMemoryInfo() (map[string]interface{}, error) {
"available": memInfo.Available,
"usage": fmt.Sprintf("%.2f%%", memInfo.UsedPercent),
"usage_raw": memInfo.UsedPercent,
"total_str": formatBytes(memInfo.Total),
"used_str": formatBytes(memInfo.Used),
"free_str": formatBytes(memInfo.Free),
"available_str": formatBytes(memInfo.Available),
"total_str": common.FormatBytes(memInfo.Total),
"used_str": common.FormatBytes(memInfo.Used),
"free_str": common.FormatBytes(memInfo.Free),
"available_str": common.FormatBytes(memInfo.Available),
}, nil
}
@@ -107,18 +96,6 @@ func GetDiskInfo() ([]map[string]interface{}, error) {
}
var result []map[string]interface{}
formatBytes := func(bytes uint64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.2f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
for _, partition := range partitions {
usage, err := disk.Usage(partition.Mountpoint)
@@ -135,9 +112,9 @@ func GetDiskInfo() ([]map[string]interface{}, error) {
"free": usage.Free,
"usage": fmt.Sprintf("%.2f%%", usage.UsedPercent),
"usage_raw": usage.UsedPercent,
"total_str": formatBytes(usage.Total),
"used_str": formatBytes(usage.Used),
"free_str": formatBytes(usage.Free),
"total_str": common.FormatBytes(usage.Total),
"used_str": common.FormatBytes(usage.Used),
"free_str": common.FormatBytes(usage.Free),
})
}