diff --git a/app.ini b/app.ini index e5417bb..c578968 100644 --- a/app.ini +++ b/app.ini @@ -17,3 +17,5 @@ file=zhub.log # db.user=root # db.password=123456 # db.database=zhub +# db.schema=public +# db.type=postgres # mysql|postgres diff --git a/internal/config/config.go b/internal/config/config.go index d88fc9c..1fd469b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -30,6 +30,8 @@ type Config struct { User string Password string Database string + Schema string + Type string } } Auth map[string]string diff --git a/internal/zbus/ztimer.go b/internal/zbus/ztimer.go index 43ff5d8..b8add2a 100644 --- a/internal/zbus/ztimer.go +++ b/internal/zbus/ztimer.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" // 导入 pq 驱动 "github.com/robfig/cron" "log" "regexp" @@ -162,19 +163,38 @@ func (s *ZBus) ReloadTimer() { return } - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8", - Conf.Ztimer.Db.User, - Conf.Ztimer.Db.Password, - Conf.Ztimer.Db.Addr, - Conf.Ztimer.Db.Database, - )) + var db *sql.DB + var err error + + if Conf.Ztimer.Db.Type == "postgres" { + hostPort := strings.Split(Conf.Ztimer.Db.Addr, ":") + db, err = sql.Open("postgres", fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=disable", + Conf.Ztimer.Db.User, + Conf.Ztimer.Db.Password, + hostPort[0], + hostPort[1], + Conf.Ztimer.Db.Database, + )) + // 设置当前会话的 schema + _, err = db.Exec("SET search_path TO " + Conf.Ztimer.Db.Schema) + if err != nil { + log.Println(err) + } + } else { + db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8", + Conf.Ztimer.Db.User, + Conf.Ztimer.Db.Password, + Conf.Ztimer.Db.Addr, + Conf.Ztimer.Db.Database, + )) + } if err != nil { log.Println(err) return } defer db.Close() - rows, err := db.Query("SELECT t.`name`, IF(t.`status`=10,t.`expr`,''), IF(t.`single`=1,'a','x') 'single' FROM tasktimer t ORDER BY t.`timerid`") + rows, err := db.Query("SELECT t.`name`, IF(t.`status`=10,t.`expr`,''), IF(t.`single`=1,'a','x') single FROM tasktimer t ORDER BY t.`timerid`") if err != nil { log.Println(err) return