优化:配置文件管理、启动参数解析等

git-svn-id: svn://47.119.165.148/zhub@167 e63fbceb-bcc3-4977-ac22-735b83d8d0f4
This commit is contained in:
lxy
2023-05-21 17:47:04 +00:00
parent fd7ac85045
commit 66321ce7a8
12 changed files with 290 additions and 306 deletions

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/go-basic/uuid"
"io"
"unicode/utf8"
//"github.com/go-basic/uuid"
@@ -128,12 +129,7 @@ func (c *Client) init() {
go c.receive()
}
/*
// subscribe topic
---
subscribe x y z
---
*/
func (c *Client) Subscribe(topic string, fun func(v string)) {
c.send("subscribe " + topic)
if fun != nil {
@@ -158,18 +154,6 @@ func (c *Client) ping() {
}
//Publish -------------------------------------- pub-sub --------------------------------------
/*
send topic message :
---
*3
$7
message
$8
my-topic
$24
{username:xx,mobile:xxx}
---
*/
func (c *Client) Publish(topic string, message string) error {
return c.send("publish", topic, message)
}
@@ -344,19 +328,6 @@ func (c Client) RpcSubscribe(topic string, fun func(Rpc Rpc) RpcResult) {
// --------------------------------------------------------------------------------
/*func (c *Client) subscribes(topics ...string) error {
if len(topics) == 0 {
return nil
}
messages := "subscribe"
for _, topic := range topics {
messages += " " + topic
}
c.send(messages)
return nil
}*/
/*
send socket message :
if len(vs) equal 1 will send message `vs[0] + "\r\n"`
@@ -387,85 +358,63 @@ a:
}
func (c *Client) receive() {
c.rlock.Lock()
defer c.rlock.Unlock()
r := bufio.NewReader(c.conn)
for {
v, _, err := r.ReadLine()
if err != nil {
log.Println("receive error and reconn: ", err)
if err = c.reconn(); err == nil {
r = bufio.NewReader(c.conn)
} else {
}
time.Sleep(time.Second * 3)
continue
} else if len(v) == 0 {
log.Println("receive empty")
log.Println(err)
return
}
if len(v) == 0 {
continue
}
switch string(v[0:1]) {
case "*": // 订阅消息
// 数据行数
vlen, err := strconv.Atoi(string(v[1:]))
if err != nil {
log.Println("receive parse len error: ", err, string(v))
continue
}
// 读取完整数据
vs := make([]string, 0)
for i := 0; i < vlen; i++ {
r.ReadLine() // $x
v, _, err = r.ReadLine()
if err != nil {
log.Println("receive parse v error: ", err)
}
vs = append(vs, string(v))
}
if len(vs) == 3 && strings.EqualFold(vs[0], "message") {
if strings.EqualFold(vs[1], "lock") { // message lock Uuid
go func() {
log.Println("lock:" + vs[2])
c.wlock.Lock()
defer c.wlock.Unlock()
if c.lockFlag[vs[2]] == nil {
return
}
c.lockFlag[vs[2]].flagChan <- 0
}()
continue
}
c.chReceive <- vs
continue
}
if len(vs) == 2 && strings.EqualFold(vs[0], "timer") {
c.timerReceive <- vs
continue
}
/*if len(vs) == 2 && strings.EqualFold(vs[0], "delay") {
c.delayFun[vs[1]]()
delete(c.delayFun, vs[1])
continue
}*/
continue
case "+": // +pong, +xxx
if strings.EqualFold("+ping", string(v)) { // 心跳消息回复
switch string(v[0]) {
case "+":
if string(v) == "+ping" {
c.send("+pong")
}
case "-":
fmt.Println("error:", string(v))
case ":":
log.Println("error:", string(v))
case "*":
n, err := strconv.Atoi(string(v[1:]))
if err != nil {
log.Println(err)
continue
}
var vs []string
for i := 0; i < n; i++ {
line, _, err := r.ReadLine()
if err != nil {
log.Println(err)
continue
}
clen, _ := strconv.Atoi(string(line[1:]))
buf := make([]byte, clen)
_, err = io.ReadFull(r, buf)
if err != nil {
log.Println(err)
continue
}
vs = append(vs, string(buf))
}
if len(vs) == 3 && vs[0] == "message" && vs[1] == "lock" {
go func() {
log.Println("lock:" + vs[2])
c.wlock.Lock()
defer c.wlock.Unlock()
if c.lockFlag[vs[2]] == nil {
return
}
c.lockFlag[vs[2]].flagChan <- 0
}()
continue
}
if len(vs) == 2 && vs[0] == "timer" {
c.timerReceive <- vs
continue
}
}
}
}
// -------------------------------------- k-v --------------------------------------

21
go.mod
View File

@@ -3,7 +3,26 @@ module zhub
go 1.18
require (
github.com/go-basic/uuid v1.0.0 // indirect
github.com/go-basic/uuid v1.0.0
github.com/go-sql-driver/mysql v1.5.0
github.com/mitchellh/go-homedir v1.1.0
github.com/robfig/cron v1.2.0
github.com/spf13/viper v1.15.0
)
require (
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pelletier/go-toml/v2 v2.0.6 // indirect
github.com/spf13/afero v1.9.3 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.4.2 // indirect
golang.org/x/sys v0.3.0 // indirect
golang.org/x/text v0.5.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

124
internal/config/config.go Normal file
View File

@@ -0,0 +1,124 @@
package config
import (
"fmt"
"github.com/spf13/viper"
"log"
"os"
)
type Log struct {
Handlers string
Level string
File string
}
type Config struct {
Log Log
Service struct {
Watch string
Addr string
Auth bool
}
Data struct {
Dir string
}
Ztimer struct {
Db struct {
Addr string
User string
Password string
Database string
}
}
Auth map[string]string
}
func main() {
config := ReadConfig()
fmt.Printf("%+v", config)
}
func ReadConfig() Config {
conf := Config{}
viper.SetDefault("log.handlers", "console")
viper.SetDefault("log.level", "info")
viper.SetDefault("service.auth", true)
/*// 读取指定的配置文件
if !strings.EqualFold("", fileName) {
viper.AddConfigPath(fileName) // 指定配置文件
if err := viper.ReadInConfig(); err == nil {
if err := viper.Unmarshal(&conf); err != nil {
log.Fatalf("Failed to unmarshal config: %s", err.Error())
}
return conf
}
log.Fatalf("Config file not found: " + fileName)
return conf
}*/
// 尝试从 /etc/ 目录下查找 zhub.ini 配置文件
viper.AddConfigPath("/etc/") // 添加 /etc/ 目录作为配置文件搜索路径
viper.SetConfigName("zhub") // 指定配置文件名为 zhub
if err := viper.ReadInConfig(); err == nil {
if err := viper.Unmarshal(&conf); err != nil {
log.Fatalf("Failed to unmarshal config: %s", err.Error())
}
return conf
}
// 如果 /etc/ 目录下未找到配置文件,则尝试从当前程序运行目录下查找 app.ini 配置文件
dir, err := os.Getwd() // 获取程序运行目录
if err != nil {
log.Fatalf("Failed to get current directory: %s", err.Error())
}
viper.SetConfigName("app") // 指定配置文件名为 app
viper.SetConfigType("ini") // 指定配置文件类型为 ini
viper.AddConfigPath(dir) // 添加当前程序所在目录作为配置文件搜索路径
if err := viper.ReadInConfig(); err == nil {
if err := viper.Unmarshal(&conf); err != nil {
log.Fatalf("Failed to unmarshal config: %s", err.Error())
}
return conf
}
// 如果在 /etc/ 目录和当前程序所在目录下均未找到配置文件,则报错
log.Fatalf("Config file not found")
return conf
}
func InitLog(logConfig Log) {
logHandlers := logConfig.Handlers
logLevel := logConfig.Level
logFile := logConfig.File
if logHandlers == "console" {
log.SetOutput(os.Stdout)
} else if logHandlers == "file" {
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_APPEND|os.O_SYNC|os.O_RDWR, 0777)
if err != nil {
log.Println(err)
}
log.SetOutput(file)
} else {
log.SetOutput(os.Stdout)
}
switch logLevel {
case "info":
log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile)
log.SetPrefix("[Info] ")
log.Println("Logger is set up with log level: info")
case "debug":
log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile)
log.SetPrefix("[Debug] ")
log.Println("Logger is set up with log level: debug")
case "error":
log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile)
log.SetPrefix("[Error] ")
log.Println("Logger is set up with log level: error")
default:
log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile)
log.SetPrefix("[Info] ")
log.Println("Logger is set up with default log level: info")
}
}

73
main.go
View File

@@ -1,57 +1,48 @@
package main
import (
"flag"
"log"
"os"
"strings"
"time"
"zhub/cmd"
"zhub/internal/config"
"zhub/zsub"
)
var (
dir, _ = os.Getwd()
confPath = dir + "/app.ini" // 配置文件地址
server = true
addr = "" // 服务地址
)
func main() {
for _, arg := range os.Args[1:] {
if strings.EqualFold(arg, "cli") {
server = false
} else if strings.Index(arg, "-d=") == 0 {
addr = arg[3:]
} else if strings.Index(arg, "-c=") == 0 {
confPath = arg[3:]
}
}
zsub.LoadConf(confPath)
if len(addr) == 0 {
addr = zsub.GetStr("service.zhub.servers", "127.0.0.1:1216")
}
var isCliMode bool // 是否以客户端模式运行的标志
var rcmd string // 客户端模式下运行的命令
flag.BoolVar(&isCliMode, "cli", false, "run as client mode") // 定义 cli 参数
flag.StringVar(&rcmd, "r", "", "run as client mode") // 定义 r 参数
flag.Parse() // 解析命令行参数
if len(os.Args) == 3 && strings.EqualFold(os.Args[1], "-r") {
if cli, err := cmd.Create("zhub-local", addr, "group-admin", "zchd@123456"); err != nil {
log.Println(err)
} else {
switch os.Args[2] {
case "timer":
cli.Cmd("reload-timer")
case "shutdown", "stop":
cli.Cmd("shutdown")
}
cli.Close()
time.Sleep(time.Millisecond * 10)
conf := config.ReadConfig() // 读取配置文件
addr := conf.Service.Addr // 获取服务地址
config.InitLog(conf.Log) // 初始化日志配置
if rcmd != "" { // 如果指定了客户端命令
auth := "" // 认证信息
for key, value := range conf.Auth { // 遍历找到一个认证信息
auth = key + "@" + value
break
}
cli, err := cmd.Create("zhub-local", addr, "group-admin", auth) // 创建客户端连接
if err != nil {
log.Println(err) // 如果连接失败则打印错误信息
return
}
defer cli.Close() // 延迟关闭客户端连接
switch rcmd {
case "timer":
cli.Cmd("reload-timer")
case "shutdown", "stop":
cli.Cmd("shutdown")
}
return
}
if server {
go zsub.StartWatch()
zsub.StartServer(addr) // 服务进程启动
if isCliMode {
cmd.ClientRun(addr) // 客户端运行
} else {
cmd.ClientRun(addr)
go zsub.StartWatch() // 启动监控协程
zsub.StartServer(addr, conf) // 启动服务进程
}
}

View File

@@ -37,8 +37,8 @@
width:100%;
height:50px;
position:absolute;
top:100%;
margin-top:-50px;
bottom: 10px;
left: 0px;
}
</style>
</head>

View File

@@ -1,133 +0,0 @@
package zsub
import (
"bufio"
"io"
"log"
"os"
"strconv"
"strings"
)
var (
dir, _ = os.Getwd()
config = make(map[string]string)
LogDebug bool
datadir = dir + "/data"
)
func LoadConf(path string) {
//log.Println("APP_CONF =", path)
f, err := os.Open(path)
if err != nil {
log.Panicln(err)
}
reader := bufio.NewReader(f)
space := ""
for {
bytes, err := reader.ReadBytes('\n')
if err == io.EOF {
break
}
line := string(bytes)
line = strings.Trim(line, " \r\n")
if len(line) == 0 {
continue
}
if strings.Contains(line, "#") {
line = line[0:strings.Index(line, "#")]
}
switch {
case strings.EqualFold(line, ""):
case strings.Index(line, "[") == 0 && strings.Index(line, "]") > 0:
space = line[1:strings.Index(line, "]")]
space = strings.Trim(space, " ")
case strings.Index(line, "=") > 0:
arr := strings.Split(line, "=")
if len(arr) < 2 {
continue
}
config[space+"."+strings.Trim(arr[0], " ")] = strings.Trim(arr[1], " ")
default:
continue
}
}
LogDebug = strings.EqualFold(config["log.level"], "debug")
datadir = GetStr("data.dir", "${APP_HOME}/data")
datadir = strings.ReplaceAll(datadir, "${APP_HOME}", dir)
os.MkdirAll(datadir, os.ModeDir)
os.Chmod(datadir, 0777)
initLog()
}
func GetStr(key string, def string) string {
if len(config[key]) == 0 {
return def
}
return config[key]
}
func GetInt(key string, def int) int {
if len(config[key]) == 0 {
return def
}
n, err := strconv.Atoi(config[key])
if err != nil {
log.Println(err, "return def;")
return def
}
return n
}
func initLog() {
defer func() {
if r := recover(); r != nil {
log.Println("initLog Err:", r)
}
}()
file, err := os.OpenFile("zhub.log", os.O_CREATE|os.O_APPEND|os.O_SYNC|os.O_RDWR, 0777)
if err != nil {
log.Println(err)
}
log.SetOutput(file)
/*
if strings.EqualFold(GetStr("log.handlers", "console"), "console") {
return
}
var logfile = GetStr("log.pattern", "${APP_HOME}/logs-200601/log-20060102.log")
c := cron.New()
fun := func() {
now := time.Now()
logfile := strings.ReplaceAll(logfile, "${APP_HOME}", dir)
logfile = now.Format(logfile)
if strings.LastIndexAny(logfile, "/") > 0 {
logdir := logfile[0:strings.LastIndexAny(logfile, "/")]
os.MkdirAll(logdir, 0666)
}
file, err := os.OpenFile(logfile, os.O_CREATE|os.O_APPEND|os.O_SYNC|os.O_RDWR, 0777)
if err != nil {
log.Println(err)
}
//log.Println("SET LOG_FILE =", file.Name())
log.SetOutput(file)
}
fun()
c.AddFunc("0 0 * * * *", fun)
go c.Run()
*/
}

View File

@@ -23,7 +23,7 @@ func StartWatch() {
http.HandleFunc("/retimer", retimer)
http.HandleFunc("/topic/publish", publish)
watchAddr := GetStr("service.zhub.watch", "0.0.0.0:1217")
watchAddr := Conf.Service.Watch
log.Println("zhub.watch = ", watchAddr)
http.ListenAndServe(watchAddr, nil)
}

View File

@@ -29,10 +29,30 @@ func msgAccept(v Message) {
return
}
if LogDebug {
log.Printf("[%d] rcmd: %s\n", v.Conn.sn, strings.Join(rcmd, " "))
if Conf.Log.Level == "debug" && rcmd[0] != "auth" {
log.Printf("[%d] cmd: %s\n", v.Conn.sn, strings.Join(rcmd, " "))
} else if rcmd[0] == "auth" {
if len(rcmd) != 2 || strings.IndexAny(rcmd[1], "@") == -1 {
c.send("-Error: invalid password!")
return
}
inx := strings.IndexAny(rcmd[1], "@") //user@pwd
authKey := rcmd[1][:inx] //user
authValue := Conf.Auth[rcmd[1][:inx]] //pwd
if strings.EqualFold(authValue, rcmd[1][inx+1:]) {
c.auth = rcmd[1][:inx]
c.send("+Auth: ok!")
log.Printf("[%d] cmd: %s\n", v.Conn.sn, "auth "+authKey+"@******* "+"[OK]")
} else {
c.send("-Auth: invalid password!")
log.Printf("[%d] cmd: %s\n", v.Conn.sn, "auth "+authKey+"@******* "+"[Error]")
}
return
}
if strings.TrimSpace(c.auth) == "" && !strings.EqualFold("auth", rcmd[0]) && strings.EqualFold(GetStr("service.auth", "0"), "1") {
if strings.TrimSpace(c.auth) == "" && rcmd[0] != "auth" && Conf.Service.Auth {
c.send("-Auth: NOAUTH Authentication required:" + rcmd[0])
return
}
@@ -153,21 +173,22 @@ func msgAccept(v Message) {
return
}
zsub._unlock(Lock{key: rcmd[1], uuid: rcmd[2]})
case "auth":
if len(rcmd) != 2 || strings.IndexAny(rcmd[1], "@") == -1 {
c.send("-Error: invalid password!")
return
}
inx := strings.IndexAny(rcmd[1], "@") //user@pwd
if strings.EqualFold(GetStr("auth."+rcmd[1][:inx], ""), rcmd[1][inx+1:]) {
c.auth = rcmd[1][:inx]
c.send("+Auth: ok!")
} else {
c.send("-Auth: invalid password!")
}
/*case "auth":
if len(rcmd) != 2 || strings.IndexAny(rcmd[1], "@") == -1 {
c.send("-Error: invalid password!")
return
}
inx := strings.IndexAny(rcmd[1], "@") //user@pwd
authKey := Conf.Auth[rcmd[1][:inx]]
if strings.EqualFold(authKey, rcmd[1][inx+1:]) {
c.auth = rcmd[1][:inx]
c.send("+Auth: ok!")
} else {
c.send("-Auth: invalid password!")
}
return*/
default:
c.send("-Error: default not supported:[" + strings.Join(rcmd, " ") + "]")
return

View File

@@ -64,16 +64,19 @@ func (s *ZSub) dataStorage() {
fmt.Println(err)
}
defer file.Close()
writer := bufio.NewWriter(file)
delays2 := s.delays
for _, delay := range delays2 {
writer.WriteString(delay.topic)
writer := bufio.NewWriter(file)
_delays := s.delays
for _, delay := range _delays {
delayStr := fmt.Sprintf("%s %s %d\n", delay.topic, delay.value, delay.exectime.Unix())
writer.WriteString(delayStr)
/*writer.WriteString(delay.topic)
writer.WriteString(" ")
writer.WriteString(delay.value)
writer.WriteString(" ")
writer.WriteString(strconv.FormatInt(delay.exectime.Unix(), 10))
writer.WriteString("\n")
writer.WriteString("\n")*/
}
writer.Flush()
}()

View File

@@ -12,10 +12,13 @@ import (
"sync/atomic"
"time"
"unicode/utf8"
"zhub/internal/config"
)
var (
zsub = &ZSub{
Conf config.Config
datadir string
zsub = &ZSub{
topics: make(map[string]*ZTopic),
timers: make(map[string]*ZTimer),
delays: make(map[string]*ZDelay),
@@ -55,7 +58,7 @@ func init() {
// close
for _, c := range conns {
log.Println("========================================= conn ping close:", (*c.conn).RemoteAddr(), "[", c.groupid, "] =========================================")
log.Printf("========================================= conn ping close:%s [%d] =========================================\n", (*c.conn).RemoteAddr(), c.sn)
c.close()
}
@@ -250,7 +253,10 @@ StartServer
1、load history data
2、init server
*/
func StartServer(addr string) {
func StartServer(addr string, conf config.Config) {
Conf = conf
datadir = conf.Data.Dir
go func() {
for {
fun, ok := <-funChan
@@ -281,7 +287,7 @@ func StartServer(addr string) {
}
zConn := NewZConn(&conn)
log.Println("conn start:", conn.RemoteAddr(), "[", zConn.sn, "]")
log.Printf("conn start: %s [%d]\n", conn.RemoteAddr(), zConn.sn)
go zsub.acceptHandler(zConn)
}
}

View File

@@ -99,10 +99,14 @@ func (s *ZSub) timer(rcmd []string, c *ZConn) {
func (s *ZSub) ReloadTimer() {
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8",
GetStr("ztimer.db.user", "root"),
Conf.Ztimer.Db.User,
Conf.Ztimer.Db.Password,
Conf.Ztimer.Db.Addr,
Conf.Ztimer.Db.Database,
/*GetStr("ztimer.db.user", "root"),
GetStr("ztimer.db.pwd", "123456"),
GetStr("ztimer.db.addr", "127.0.0.1:3306"),
GetStr("ztimer.db.database", "zhub"),
GetStr("ztimer.db.database", "zhub"),*/
))
if err != nil {

View File

@@ -23,10 +23,10 @@ func (t *ZTopic) init() {
break
}
for name, group := range t.groups {
for groupName, group := range t.groups {
// zgroup chan overload check
if len(group.chMsg) == cap(group.chMsg) {
log.Println(fmt.Sprintf("zgroup no cap: [%s.%s %s]", name, t.topic, msg))
log.Println(fmt.Sprintf("zgroup no cap: [%s.%s %s]", groupName, t.topic, msg))
continue
}
group.chMsg <- msg