// Package config defines the application's configuration schema, its built-in
// defaults, and a loader that overlays values read from a YAML file.
package config

import (
	"errors"
	"fmt"
	"os"
	"strings"

	"gopkg.in/yaml.v3"
)

// Config is the root of the configuration tree. Each field is populated from
// the top-level YAML key named in its struct tag.
type Config struct {
	Server     ServerConfig     `yaml:"server"`
	Auth       AuthConfig       `yaml:"auth"`
	CORS       CORSConfig       `yaml:"cors"`
	Log        LogConfig        `yaml:"log"`
	FileServer FileServerConfig `yaml:"file_server"`
	Security   SecurityConfig   `yaml:"security"`
}

// ServerConfig holds the port and host read from the "server" section.
type ServerConfig struct {
	Port int    `yaml:"port"`
	Host string `yaml:"host"`
}

// AuthConfig holds the token read from the "auth" section. The default token
// is the empty string.
type AuthConfig struct {
	Token string `yaml:"token"`
}

// CORSConfig holds the list of allowed origins read from the "cors" section.
// Load trims, de-duplicates, and drops blank entries from this list.
type CORSConfig struct {
	AllowedOrigins []string `yaml:"allowed_origins"`
}

// LogConfig holds the log level and output format read from the "log" section.
type LogConfig struct {
	Level  string `yaml:"level"`
	Format string `yaml:"format"`
}

// FileServerConfig holds the settings read from the "file_server" section.
type FileServerConfig struct {
	Port int `yaml:"port"`
	// MaxFileSize defaults to 500 * 1024 * 1024.
	// NOTE(review): the unit is presumably bytes; confirm against the consumer.
	MaxFileSize int64 `yaml:"max_file_size"`
}

// SecurityConfig holds the flags read from the "security" section.
type SecurityConfig struct {
	AllowSymlinks    bool `yaml:"allow_symlinks"`
	CheckSystemPaths bool `yaml:"check_system_paths"`
}

// FileServerAddr returns the file server's base URL. The host is always
// "localhost"; only the port comes from the configuration.
func (c *Config) FileServerAddr() string {
	return fmt.Sprintf("http://localhost:%d", c.FileServer.Port)
}

// Load reads the YAML file at path and overlays it on top of Default.
//
// If the file does not exist, Load returns the defaults and a nil error.
// Any other read failure, or a YAML decoding failure, is returned wrapped
// with context; the underlying error stays reachable through errors.Is and
// errors.As. After decoding, CORS.AllowedOrigins is trimmed, de-duplicated,
// and stripped of blank entries.
func Load(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		// A missing config file is not an error: fall back to the defaults.
		if errors.Is(err, os.ErrNotExist) {
			return Default(), nil
		}
		// The *PathError from os.ReadFile already carries the path.
		return nil, fmt.Errorf("reading config: %w", err)
	}

	// Decode on top of the defaults so keys absent from the file keep them.
	cfg := Default()
	if err := yaml.Unmarshal(data, cfg); err != nil {
		return nil, fmt.Errorf("parsing config %q: %w", path, err)
	}

	cfg.CORS.AllowedOrigins = normalizeOrigins(cfg.CORS.AllowedOrigins)
	return cfg, nil
}

// normalizeOrigins trims surrounding whitespace from each origin, drops blank
// entries, and removes duplicates while keeping first-seen order. It filters
// in place, reusing the backing array of origins.
func normalizeOrigins(origins []string) []string {
	seen := make(map[string]struct{}, len(origins))
	kept := origins[:0]
	for _, raw := range origins {
		origin := strings.TrimSpace(raw)
		if origin == "" {
			continue
		}
		if _, dup := seen[origin]; dup {
			continue
		}
		seen[origin] = struct{}{}
		kept = append(kept, origin)
	}
	return kept
}

// Default returns a freshly allocated Config populated with the built-in
// default values.
func Default() *Config {
	return &Config{
		Server: ServerConfig{
			Port: 9876,
			Host: "0.0.0.0",
		},
		Auth: AuthConfig{
			Token: "",
		},
		CORS: CORSConfig{
			AllowedOrigins: []string{"*"},
		},
		Log: LogConfig{
			Level:  "info",
			Format: "json",
		},
		FileServer: FileServerConfig{
			Port:        2652,
			MaxFileSize: 500 * 1024 * 1024,
		},
		Security: SecurityConfig{
			AllowSymlinks:    false,
			CheckSystemPaths: true,
		},
	}
}