新增: 自动更新与一键自升级
This commit is contained in:
269
internal/update.go
Normal file
269
internal/update.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/bubbletea/v2"
|
||||
)
|
||||
|
||||
// Version 当前版本,发布时更新
|
||||
const Version = "0.1.0"
|
||||
|
||||
// --- 远程 JSON 结构 ---
|
||||
|
||||
type platformInfo struct {
|
||||
DownloadURL string `json:"download_url"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
SHA256 string `json:"sha256"`
|
||||
}
|
||||
|
||||
type versionInfo struct {
|
||||
Version string `json:"version"`
|
||||
ReleaseDate string `json:"release_date"`
|
||||
Changelog string `json:"changelog"`
|
||||
Platforms map[string]platformInfo `json:"platforms"`
|
||||
}
|
||||
|
||||
// --- bubbletea 消息类型 ---
|
||||
|
||||
// UpdateAvailableMsg 发现新版本
|
||||
type UpdateAvailableMsg struct {
|
||||
NewVersion string
|
||||
Changelog string
|
||||
DownloadURL string
|
||||
SHA256 string
|
||||
FileSize int64
|
||||
}
|
||||
|
||||
// UpdateCompleteMsg 更新完成
|
||||
type UpdateCompleteMsg struct {
|
||||
NewVersion string
|
||||
}
|
||||
|
||||
// UpdateErrorMsg 更新失败
|
||||
type UpdateErrorMsg struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// --- 更新状态 ---
|
||||
|
||||
// UpdateState 更新流程状态
|
||||
type UpdateState struct {
|
||||
Checking bool
|
||||
Available bool
|
||||
Updating bool
|
||||
NewVersion string
|
||||
Changelog string
|
||||
DownloadURL string
|
||||
SHA256 string
|
||||
FileSize int64
|
||||
Done bool
|
||||
Error error
|
||||
}
|
||||
|
||||
// --- Cmd 函数 ---
|
||||
|
||||
// CheckUpdateCmd 检查远程是否有新版本,返回 bubbletea Cmd
|
||||
func CheckUpdateCmd() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://c.1216.top/u-tabs/last-version.json", nil)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "u-tabs/"+Version)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var info versionInfo
|
||||
if err := json.Unmarshal(body, &info); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查找当前平台的下载信息
|
||||
platformKey := runtime.GOOS + "-" + runtime.GOARCH
|
||||
pi, ok := info.Platforms[platformKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 比较版本
|
||||
if !semverCompare(Version, info.Version) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return UpdateAvailableMsg{
|
||||
NewVersion: info.Version,
|
||||
Changelog: info.Changelog,
|
||||
DownloadURL: pi.DownloadURL,
|
||||
SHA256: pi.SHA256,
|
||||
FileSize: pi.FileSize,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SelfUpdateCmd 下载并替换当前二进制,返回 bubbletea Cmd
|
||||
func SelfUpdateCmd(downloadURL, expectedSHA256, newVersion string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// 下载到临时文件
|
||||
client := &http.Client{Timeout: 5 * time.Minute}
|
||||
|
||||
req, err := http.NewRequest("GET", downloadURL, nil)
|
||||
if err != nil {
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("create request failed: %w", err)}
|
||||
}
|
||||
req.Header.Set("User-Agent", "u-tabs/"+Version)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("download failed: %w", err)}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("download returned status %d", resp.StatusCode)}
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "u-tabs-update-*")
|
||||
if err != nil {
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("create temp file failed: %w", err)}
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(io.MultiWriter(tmpFile, hasher), resp.Body); err != nil {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpPath)
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("download write failed: %w", err)}
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// 校验 SHA256
|
||||
actualSHA := fmt.Sprintf("%x", hasher.Sum(nil))
|
||||
if !strings.EqualFold(actualSHA, expectedSHA256) {
|
||||
os.Remove(tmpPath)
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("sha256 mismatch: expected %s, got %s", expectedSHA256, actualSHA)}
|
||||
}
|
||||
|
||||
// 替换当前二进制
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
os.Remove(tmpPath)
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("get executable path failed: %w", err)}
|
||||
}
|
||||
|
||||
// 备份旧二进制
|
||||
oldPath := exePath + ".old"
|
||||
_ = os.Remove(oldPath) // 清理可能残留的旧备份
|
||||
if err := os.Rename(exePath, oldPath); err != nil {
|
||||
os.Remove(tmpPath)
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("rename old binary failed: %w", err)}
|
||||
}
|
||||
|
||||
// 复制新二进制到原路径
|
||||
if err := copyFile(tmpPath, exePath); err != nil {
|
||||
// 回滚:恢复旧二进制
|
||||
os.Remove(exePath)
|
||||
os.Rename(oldPath, exePath)
|
||||
os.Remove(tmpPath)
|
||||
return UpdateErrorMsg{Err: fmt.Errorf("copy new binary failed: %w", err)}
|
||||
}
|
||||
|
||||
// Linux/macOS: 设置可执行权限
|
||||
if runtime.GOOS != "windows" {
|
||||
os.Chmod(exePath, 0755)
|
||||
}
|
||||
|
||||
// 清理临时文件和旧备份
|
||||
os.Remove(tmpPath)
|
||||
_ = os.Remove(oldPath)
|
||||
|
||||
return UpdateCompleteMsg{NewVersion: newVersion}
|
||||
}
|
||||
}
|
||||
|
||||
// --- 辅助函数 ---
|
||||
|
||||
// semverCompare 比较 v1 和 v2,v2 > v1 时返回 true
|
||||
func semverCompare(v1, v2 string) bool {
|
||||
parts1 := strings.Split(v1, ".")
|
||||
parts2 := strings.Split(v2, ".")
|
||||
|
||||
maxLen := len(parts1)
|
||||
if len(parts2) > maxLen {
|
||||
maxLen = len(parts2)
|
||||
}
|
||||
|
||||
for i := 0; i < maxLen; i++ {
|
||||
var n1, n2 int
|
||||
if i < len(parts1) {
|
||||
n1 = parseSemverPart(parts1[i])
|
||||
}
|
||||
if i < len(parts2) {
|
||||
n2 = parseSemverPart(parts2[i])
|
||||
}
|
||||
if n2 > n1 {
|
||||
return true
|
||||
}
|
||||
if n2 < n1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false // 版本相同
|
||||
}
|
||||
|
||||
// parseSemverPart 解析语义版本的一段为数字,非数字前缀的部分返回 0
|
||||
func parseSemverPart(s string) int {
|
||||
// 去除可能的前缀字母 (如 "v1" → "1")
|
||||
n := 0
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int(c-'0')
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// copyFile 复制文件内容
|
||||
func copyFile(src, dst string) error {
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
out, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if _, err := io.Copy(out, in); err != nil {
|
||||
return err
|
||||
}
|
||||
return out.Sync()
|
||||
}
|
||||
Reference in New Issue
Block a user