diff --git a/cmd/client.go b/cmd/client.go index 2cad867..7eda4c7 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "github.com/go-basic/uuid" + "io" "unicode/utf8" //"github.com/go-basic/uuid" @@ -128,12 +129,7 @@ func (c *Client) init() { go c.receive() } -/* // subscribe topic ---- -subscribe x y z ---- -*/ func (c *Client) Subscribe(topic string, fun func(v string)) { c.send("subscribe " + topic) if fun != nil { @@ -158,18 +154,6 @@ func (c *Client) ping() { } //Publish -------------------------------------- pub-sub -------------------------------------- -/* -send topic message : ---- -*3 -$7 -message -$8 -my-topic -$24 -{username:xx,mobile:xxx} ---- -*/ func (c *Client) Publish(topic string, message string) error { return c.send("publish", topic, message) } @@ -344,19 +328,6 @@ func (c Client) RpcSubscribe(topic string, fun func(Rpc Rpc) RpcResult) { // -------------------------------------------------------------------------------- -/*func (c *Client) subscribes(topics ...string) error { - if len(topics) == 0 { - return nil - } - - messages := "subscribe" - for _, topic := range topics { - messages += " " + topic - } - c.send(messages) - return nil -}*/ - /* send socket message : if len(vs) equal 1 will send message `vs[0] + "\r\n"` @@ -387,85 +358,63 @@ a: } func (c *Client) receive() { - c.rlock.Lock() - defer c.rlock.Unlock() - r := bufio.NewReader(c.conn) for { v, _, err := r.ReadLine() if err != nil { - log.Println("receive error and reconn: ", err) - if err = c.reconn(); err == nil { - r = bufio.NewReader(c.conn) - } else { - - } - time.Sleep(time.Second * 3) - continue - } else if len(v) == 0 { - log.Println("receive empty") + log.Println(err) + return + } + if len(v) == 0 { continue } - - switch string(v[0:1]) { - case "*": // 订阅消息 - // 数据行数 - vlen, err := strconv.Atoi(string(v[1:])) - if err != nil { - log.Println("receive parse len error: ", err, string(v)) - continue - } - - // 读取完整数据 - vs := make([]string, 0) - for i := 0; i < vlen; i++ { - r.ReadLine() // $x - v, _, err = r.ReadLine() - if err != nil { - log.Println("receive parse v error: ", err) - } - vs = append(vs, string(v)) - } - - if len(vs) == 3 && strings.EqualFold(vs[0], "message") { - if strings.EqualFold(vs[1], "lock") { // message lock Uuid - go func() { - log.Println("lock:" + vs[2]) - c.wlock.Lock() - defer c.wlock.Unlock() - - if c.lockFlag[vs[2]] == nil { - return - } - c.lockFlag[vs[2]].flagChan <- 0 - }() - continue - } - c.chReceive <- vs - continue - } - if len(vs) == 2 && strings.EqualFold(vs[0], "timer") { - c.timerReceive <- vs - continue - } - /*if len(vs) == 2 && strings.EqualFold(vs[0], "delay") { - c.delayFun[vs[1]]() - delete(c.delayFun, vs[1]) - continue - }*/ - - continue - case "+": // +pong, +xxx - if strings.EqualFold("+ping", string(v)) { // 心跳消息回复 + switch string(v[0]) { + case "+": + if string(v) == "+ping" { c.send("+pong") } case "-": - fmt.Println("error:", string(v)) - case ":": - + log.Println("error:", string(v)) + case "*": + n, err := strconv.Atoi(string(v[1:])) + if err != nil { + log.Println(err) + continue + } + var vs []string + for i := 0; i < n; i++ { + line, _, err := r.ReadLine() + if err != nil { + log.Println(err) + continue + } + clen, _ := strconv.Atoi(string(line[1:])) + buf := make([]byte, clen) + _, err = io.ReadFull(r, buf) + if err != nil { + log.Println(err) + continue + } + vs = append(vs, string(buf)) + } + if len(vs) == 3 && vs[0] == "message" && vs[1] == "lock" { + go func() { + log.Println("lock:" + vs[2]) + c.wlock.Lock() + defer c.wlock.Unlock() + if c.lockFlag[vs[2]] == nil { + return + } + c.lockFlag[vs[2]].flagChan <- 0 + }() + continue + } + if len(vs) == 2 && vs[0] == "timer" { + c.timerReceive <- vs + continue + } } } - } // -------------------------------------- k-v -------------------------------------- diff --git a/go.mod b/go.mod index 96c742f..ae7832e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,26 @@ module zhub go 1.18 require ( - github.com/go-basic/uuid v1.0.0 // indirect + github.com/go-basic/uuid v1.0.0 github.com/go-sql-driver/mysql v1.5.0 + github.com/mitchellh/go-homedir v1.1.0 github.com/robfig/cron v1.2.0 + github.com/spf13/viper v1.15.0 +) + +require ( + github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.0.6 // indirect + github.com/spf13/afero v1.9.3 // indirect + github.com/spf13/cast v1.5.0 // indirect + github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.4.2 // indirect + golang.org/x/sys v0.3.0 // indirect + golang.org/x/text v0.5.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..185d99b --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,124 @@ +package config + +import ( + "fmt" + "github.com/spf13/viper" + "log" + "os" +) + +type Log struct { + Handlers string + Level string + File string +} + +type Config struct { + Log Log + Service struct { + Watch string + Addr string + Auth bool + } + Data struct { + Dir string + } + Ztimer struct { + Db struct { + Addr string + User string + Password string + Database string + } + } + Auth map[string]string +} + +func main() { + config := ReadConfig() + + fmt.Printf("%+v", config) +} +func ReadConfig() Config { + conf := Config{} + viper.SetDefault("log.handlers", "console") + viper.SetDefault("log.level", "info") + viper.SetDefault("service.auth", true) + + /*// 读取指定的配置文件 + if !strings.EqualFold("", fileName) { + viper.AddConfigPath(fileName) // 指定配置文件 + if err := viper.ReadInConfig(); err == nil { + if err := viper.Unmarshal(&conf); err != nil { + log.Fatalf("Failed to unmarshal config: %s", err.Error()) + } + return conf + } + + log.Fatalf("Config file not found: " + fileName) + return conf + }*/ + + // 尝试从 /etc/ 目录下查找 zhub.ini 配置文件 + viper.AddConfigPath("/etc/") // 添加 /etc/ 目录作为配置文件搜索路径 + viper.SetConfigName("zhub") // 指定配置文件名为 zhub + if err := viper.ReadInConfig(); err == nil { + if err := viper.Unmarshal(&conf); err != nil { + log.Fatalf("Failed to unmarshal config: %s", err.Error()) + } + return conf + } + // 如果 /etc/ 目录下未找到配置文件,则尝试从当前程序运行目录下查找 app.ini 配置文件 + dir, err := os.Getwd() // 获取程序运行目录 + if err != nil { + log.Fatalf("Failed to get current directory: %s", err.Error()) + } + viper.SetConfigName("app") // 指定配置文件名为 app + viper.SetConfigType("ini") // 指定配置文件类型为 ini + viper.AddConfigPath(dir) // 添加当前程序所在目录作为配置文件搜索路径 + if err := viper.ReadInConfig(); err == nil { + if err := viper.Unmarshal(&conf); err != nil { + log.Fatalf("Failed to unmarshal config: %s", err.Error()) + } + return conf + } + // 如果在 /etc/ 目录和当前程序所在目录下均未找到配置文件,则报错 + log.Fatalf("Config file not found") + return conf +} +func InitLog(logConfig Log) { + logHandlers := logConfig.Handlers + logLevel := logConfig.Level + logFile := logConfig.File + + if logHandlers == "console" { + log.SetOutput(os.Stdout) + } else if logHandlers == "file" { + file, err := os.OpenFile(logFile, os.O_CREATE|os.O_APPEND|os.O_SYNC|os.O_RDWR, 0777) + if err != nil { + log.Println(err) + } + log.SetOutput(file) + } else { + log.SetOutput(os.Stdout) + } + + switch logLevel { + case "info": + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix("[Info] ") + log.Println("Logger is set up with log level: info") + case "debug": + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix("[Debug] ") + log.Println("Logger is set up with log level: debug") + case "error": + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix("[Error] ") + log.Println("Logger is set up with log level: error") + default: + log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile) + log.SetPrefix("[Info] ") + log.Println("Logger is set up with default log level: info") + } +} diff --git a/main.go b/main.go index 5b2c3cd..ac82b5c 100644 --- a/main.go +++ b/main.go @@ -1,57 +1,48 @@ package main import ( + "flag" "log" - "os" - "strings" - "time" "zhub/cmd" + "zhub/internal/config" "zhub/zsub" ) -var ( - dir, _ = os.Getwd() - confPath = dir + "/app.ini" // 配置文件地址 - server = true - addr = "" // 服务地址 -) - func main() { - for _, arg := range os.Args[1:] { - if strings.EqualFold(arg, "cli") { - server = false - } else if strings.Index(arg, "-d=") == 0 { - addr = arg[3:] - } else if strings.Index(arg, "-c=") == 0 { - confPath = arg[3:] - } - } - zsub.LoadConf(confPath) - if len(addr) == 0 { - addr = zsub.GetStr("service.zhub.servers", "127.0.0.1:1216") - } + var isCliMode bool // 是否以客户端模式运行的标志 + var rcmd string // 客户端模式下运行的命令 + flag.BoolVar(&isCliMode, "cli", false, "run as client mode") // 定义 cli 参数 + flag.StringVar(&rcmd, "r", "", "run as client mode") // 定义 r 参数 + flag.Parse() // 解析命令行参数 - if len(os.Args) == 3 && strings.EqualFold(os.Args[1], "-r") { - if cli, err := cmd.Create("zhub-local", addr, "group-admin", "zchd@123456"); err != nil { - log.Println(err) - } else { - switch os.Args[2] { - case "timer": - cli.Cmd("reload-timer") - case "shutdown", "stop": - cli.Cmd("shutdown") - } - cli.Close() - time.Sleep(time.Millisecond * 10) + conf := config.ReadConfig() // 读取配置文件 + addr := conf.Service.Addr // 获取服务地址 + config.InitLog(conf.Log) // 初始化日志配置 + + if rcmd != "" { // 如果指定了客户端命令 + auth := "" // 认证信息 + for key, value := range conf.Auth { // 遍历找到一个认证信息 + auth = key + "@" + value + break + } + cli, err := cmd.Create("zhub-local", addr, "group-admin", auth) // 创建客户端连接 + if err != nil { + log.Println(err) // 如果连接失败则打印错误信息 + return + } + defer cli.Close() // 延迟关闭客户端连接 + switch rcmd { + case "timer": + cli.Cmd("reload-timer") + case "shutdown", "stop": + cli.Cmd("shutdown") } return } - - if server { - go zsub.StartWatch() - zsub.StartServer(addr) // 服务进程启动 + if isCliMode { + cmd.ClientRun(addr) // 客户端运行 } else { - cmd.ClientRun(addr) + go zsub.StartWatch() // 启动监控协程 + zsub.StartServer(addr, conf) // 启动服务进程 } - } diff --git a/public/index.html b/public/index.html index 4c2078e..fb7820a 100644 --- a/public/index.html +++ b/public/index.html @@ -37,8 +37,8 @@ width:100%; height:50px; position:absolute; - top:100%; - margin-top:-50px; + bottom: 10px; + left: 0px; } diff --git a/zsub/config.go b/zsub/config.go deleted file mode 100644 index 6c23de4..0000000 --- a/zsub/config.go +++ /dev/null @@ -1,133 +0,0 @@ -package zsub - -import ( - "bufio" - "io" - "log" - "os" - "strconv" - "strings" -) - -var ( - dir, _ = os.Getwd() - config = make(map[string]string) - LogDebug bool - datadir = dir + "/data" -) - -func LoadConf(path string) { - //log.Println("APP_CONF =", path) - f, err := os.Open(path) - if err != nil { - log.Panicln(err) - } - - reader := bufio.NewReader(f) - space := "" - for { - bytes, err := reader.ReadBytes('\n') - if err == io.EOF { - break - } - line := string(bytes) - line = strings.Trim(line, " \r\n") - if len(line) == 0 { - continue - } - if strings.Contains(line, "#") { - line = line[0:strings.Index(line, "#")] - } - - switch { - case strings.EqualFold(line, ""): - case strings.Index(line, "[") == 0 && strings.Index(line, "]") > 0: - space = line[1:strings.Index(line, "]")] - space = strings.Trim(space, " ") - case strings.Index(line, "=") > 0: - arr := strings.Split(line, "=") - if len(arr) < 2 { - continue - } - - config[space+"."+strings.Trim(arr[0], " ")] = strings.Trim(arr[1], " ") - default: - continue - } - } - - LogDebug = strings.EqualFold(config["log.level"], "debug") - - datadir = GetStr("data.dir", "${APP_HOME}/data") - datadir = strings.ReplaceAll(datadir, "${APP_HOME}", dir) - - os.MkdirAll(datadir, os.ModeDir) - os.Chmod(datadir, 0777) - - initLog() -} - -func GetStr(key string, def string) string { - if len(config[key]) == 0 { - return def - } - return config[key] -} - -func GetInt(key string, def int) int { - if len(config[key]) == 0 { - return def - } - n, err := strconv.Atoi(config[key]) - if err != nil { - log.Println(err, "return def;") - return def - } - return n -} - -func initLog() { - defer func() { - if r := recover(); r != nil { - log.Println("initLog Err:", r) - } - }() - - file, err := os.OpenFile("zhub.log", os.O_CREATE|os.O_APPEND|os.O_SYNC|os.O_RDWR, 0777) - if err != nil { - log.Println(err) - } - log.SetOutput(file) - - /* - if strings.EqualFold(GetStr("log.handlers", "console"), "console") { - return - } - - var logfile = GetStr("log.pattern", "${APP_HOME}/logs-200601/log-20060102.log") - - c := cron.New() - fun := func() { - now := time.Now() - logfile := strings.ReplaceAll(logfile, "${APP_HOME}", dir) - logfile = now.Format(logfile) - - if strings.LastIndexAny(logfile, "/") > 0 { - logdir := logfile[0:strings.LastIndexAny(logfile, "/")] - os.MkdirAll(logdir, 0666) - } - - file, err := os.OpenFile(logfile, os.O_CREATE|os.O_APPEND|os.O_SYNC|os.O_RDWR, 0777) - if err != nil { - log.Println(err) - } - - //log.Println("SET LOG_FILE =", file.Name()) - log.SetOutput(file) - } - fun() - - c.AddFunc("0 0 * * * *", fun) - go c.Run() - */ -} diff --git a/zsub/monitor.go b/zsub/monitor.go index e87f482..5db155f 100644 --- a/zsub/monitor.go +++ b/zsub/monitor.go @@ -23,7 +23,7 @@ func StartWatch() { http.HandleFunc("/retimer", retimer) http.HandleFunc("/topic/publish", publish) - watchAddr := GetStr("service.zhub.watch", "0.0.0.0:1217") + watchAddr := Conf.Service.Watch log.Println("zhub.watch = ", watchAddr) http.ListenAndServe(watchAddr, nil) } diff --git a/zsub/msg-consumer.go b/zsub/msg-consumer.go index 779b372..389e0fb 100644 --- a/zsub/msg-consumer.go +++ b/zsub/msg-consumer.go @@ -29,10 +29,30 @@ func msgAccept(v Message) { return } - if LogDebug { - log.Printf("[%d] rcmd: %s\n", v.Conn.sn, strings.Join(rcmd, " ")) + if Conf.Log.Level == "debug" && rcmd[0] != "auth" { + log.Printf("[%d] cmd: %s\n", v.Conn.sn, strings.Join(rcmd, " ")) + } else if rcmd[0] == "auth" { + if len(rcmd) != 2 || strings.IndexAny(rcmd[1], "@") == -1 { + c.send("-Error: invalid password!") + return + } + + inx := strings.IndexAny(rcmd[1], "@") //user@pwd + + authKey := rcmd[1][:inx] //user + authValue := Conf.Auth[rcmd[1][:inx]] //pwd + if strings.EqualFold(authValue, rcmd[1][inx+1:]) { + c.auth = rcmd[1][:inx] + c.send("+Auth: ok!") + log.Printf("[%d] cmd: %s\n", v.Conn.sn, "auth "+authKey+"@******* "+"[OK]") + } else { + c.send("-Auth: invalid password!") + log.Printf("[%d] cmd: %s\n", v.Conn.sn, "auth "+authKey+"@******* "+"[Error]") + } + return } - if strings.TrimSpace(c.auth) == "" && !strings.EqualFold("auth", rcmd[0]) && strings.EqualFold(GetStr("service.auth", "0"), "1") { + + if strings.TrimSpace(c.auth) == "" && rcmd[0] != "auth" && Conf.Service.Auth { c.send("-Auth: NOAUTH Authentication required:" + rcmd[0]) return } @@ -153,21 +173,22 @@ func msgAccept(v Message) { return } zsub._unlock(Lock{key: rcmd[1], uuid: rcmd[2]}) - case "auth": - if len(rcmd) != 2 || strings.IndexAny(rcmd[1], "@") == -1 { - c.send("-Error: invalid password!") - return - } - - inx := strings.IndexAny(rcmd[1], "@") //user@pwd - - if strings.EqualFold(GetStr("auth."+rcmd[1][:inx], ""), rcmd[1][inx+1:]) { - c.auth = rcmd[1][:inx] - c.send("+Auth: ok!") - } else { - c.send("-Auth: invalid password!") - } + /*case "auth": + if len(rcmd) != 2 || strings.IndexAny(rcmd[1], "@") == -1 { + c.send("-Error: invalid password!") return + } + + inx := strings.IndexAny(rcmd[1], "@") //user@pwd + + authKey := Conf.Auth[rcmd[1][:inx]] + if strings.EqualFold(authKey, rcmd[1][inx+1:]) { + c.auth = rcmd[1][:inx] + c.send("+Auth: ok!") + } else { + c.send("-Auth: invalid password!") + } + return*/ default: c.send("-Error: default not supported:[" + strings.Join(rcmd, " ") + "]") return diff --git a/zsub/zdb.go b/zsub/zdb.go index 13bd0c3..6a3761c 100644 --- a/zsub/zdb.go +++ b/zsub/zdb.go @@ -64,16 +64,19 @@ func (s *ZSub) dataStorage() { fmt.Println(err) } defer file.Close() - writer := bufio.NewWriter(file) - delays2 := s.delays - for _, delay := range delays2 { - writer.WriteString(delay.topic) + writer := bufio.NewWriter(file) + _delays := s.delays + + for _, delay := range _delays { + delayStr := fmt.Sprintf("%s %s %d\n", delay.topic, delay.value, delay.exectime.Unix()) + writer.WriteString(delayStr) + /*writer.WriteString(delay.topic) writer.WriteString(" ") writer.WriteString(delay.value) writer.WriteString(" ") writer.WriteString(strconv.FormatInt(delay.exectime.Unix(), 10)) - writer.WriteString("\n") + writer.WriteString("\n")*/ } writer.Flush() }() diff --git a/zsub/zsub.go b/zsub/zsub.go index 2e7bf9e..1d4276f 100644 --- a/zsub/zsub.go +++ b/zsub/zsub.go @@ -12,10 +12,13 @@ import ( "sync/atomic" "time" "unicode/utf8" + "zhub/internal/config" ) var ( - zsub = &ZSub{ + Conf config.Config + datadir string + zsub = &ZSub{ topics: make(map[string]*ZTopic), timers: make(map[string]*ZTimer), delays: make(map[string]*ZDelay), @@ -55,7 +58,7 @@ func init() { // close for _, c := range conns { - log.Println("========================================= conn ping close:", (*c.conn).RemoteAddr(), "[", c.groupid, "] =========================================") + log.Printf("========================================= conn ping close:%s [%d] =========================================\n", (*c.conn).RemoteAddr(), c.sn) c.close() } @@ -250,7 +253,10 @@ StartServer 1、load history data 2、init server */ -func StartServer(addr string) { +func StartServer(addr string, conf config.Config) { + Conf = conf + datadir = conf.Data.Dir + go func() { for { fun, ok := <-funChan @@ -281,7 +287,7 @@ func StartServer(addr string) { } zConn := NewZConn(&conn) - log.Println("conn start:", conn.RemoteAddr(), "[", zConn.sn, "]") + log.Printf("conn start: %s [%d]\n", conn.RemoteAddr(), zConn.sn) go zsub.acceptHandler(zConn) } } diff --git a/zsub/ztimer.go b/zsub/ztimer.go index 116bc34..79c4962 100644 --- a/zsub/ztimer.go +++ b/zsub/ztimer.go @@ -99,10 +99,14 @@ func (s *ZSub) timer(rcmd []string, c *ZConn) { func (s *ZSub) ReloadTimer() { db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8", - GetStr("ztimer.db.user", "root"), + Conf.Ztimer.Db.User, + Conf.Ztimer.Db.Password, + Conf.Ztimer.Db.Addr, + Conf.Ztimer.Db.Database, + /*GetStr("ztimer.db.user", "root"), GetStr("ztimer.db.pwd", "123456"), GetStr("ztimer.db.addr", "127.0.0.1:3306"), - GetStr("ztimer.db.database", "zhub"), + GetStr("ztimer.db.database", "zhub"),*/ )) if err != nil { diff --git a/zsub/ztopic.go b/zsub/ztopic.go index 9b41eb4..023e8b1 100644 --- a/zsub/ztopic.go +++ b/zsub/ztopic.go @@ -23,10 +23,10 @@ func (t *ZTopic) init() { break } - for name, group := range t.groups { + for groupName, group := range t.groups { // zgroup chan overload check if len(group.chMsg) == cap(group.chMsg) { - log.Println(fmt.Sprintf("zgroup no cap: [%s.%s %s]", name, t.topic, msg)) + log.Println(fmt.Sprintf("zgroup no cap: [%s.%s %s]", groupName, t.topic, msg)) continue } group.chMsg <- msg