Files
u-tpl/internal/expr.go
绝尘 a6847c7d18 优化: 补全所有整数类型的跨类型比较支持
toFloat64 和 isTruthy 补充 int8/int16/int32/uint8/uint16/uint32
2026-04-01 01:01:49 +08:00

593 lines
12 KiB
Go

package internal
import (
"fmt"
"reflect"
"strconv"
"strings"
"unicode"
)
// ExprParser is a recursive-descent parser for template expressions.
// It reads its input rune by rune and keeps a running line/column so that
// syntax errors and the Pos of every produced Expr can point into the source.
type ExprParser struct {
	input     []rune // source text, decoded to runes so indexing is per character
	pos       int    // index in input of the next rune to read
	line, col int    // current position, used in error messages and Expr.Pos
}
// NewExprParser returns a parser positioned at the first rune of input.
// line and col are the starting position used in error messages and node
// positions (presumably where the expression sits in the enclosing
// template — confirm against callers).
func NewExprParser(input string, line, col int) *ExprParser {
	p := new(ExprParser)
	p.input = []rune(input)
	p.line, p.col = line, col
	return p
}
// Parse parses the entire input as a single expression and returns its AST.
//
// Precedence, loosest to tightest: "||", "&&", one comparison
// ("==", "!=", "<", ">", "<=", ">="; not chainable), unary "!", then
// primaries (literals, identifiers, dotted paths, function calls and
// parenthesised expressions).
//
// Anything other than whitespace left after a complete expression is an
// error. Examples are "= 1" in "a = 1" and the second "== c" in
// "a == b == c". Previously such input was silently discarded, so a typo
// changed the meaning of the expression instead of being reported.
func (p *ExprParser) Parse() (*Expr, error) {
	expr, err := p.parseOr()
	if err != nil {
		return nil, err
	}
	p.skipSpaces()
	if p.pos < len(p.input) {
		return nil, fmt.Errorf("line %d, col %d: unexpected %q after expression", p.line, p.col, string(p.input[p.pos:]))
	}
	return expr, nil
}
// parseOr parses a left-associative chain of "&&"-level operands joined by
// "||". A single operand with no "||" is returned unchanged.
//
// Each resulting binary node is positioned at its "||" operator. Before,
// the position was taken after the right operand had been parsed, which
// pointed evaluation errors past the expression.
func (p *ExprParser) parseOr() (*Expr, error) {
	left, err := p.parseAnd()
	if err != nil {
		return nil, err
	}
	for {
		p.skipSpaces()
		if !p.peekStr("||") {
			return left, nil
		}
		opPos := Pos{Line: p.line, Col: p.col}
		p.skip(2)
		p.skipSpaces()
		right, err := p.parseAnd()
		if err != nil {
			return nil, err
		}
		left = &Expr{Pos: opPos, ExprType: ExprBinary, Left: left, Op: "||", Right: right}
	}
}
// parseAnd parses a left-associative chain of comparison operands joined by
// "&&". A single operand with no "&&" is returned unchanged.
//
// Each resulting binary node is positioned at its "&&" operator rather
// than at wherever the parser stopped after the right operand.
func (p *ExprParser) parseAnd() (*Expr, error) {
	left, err := p.parseCompare()
	if err != nil {
		return nil, err
	}
	for {
		p.skipSpaces()
		if !p.peekStr("&&") {
			return left, nil
		}
		opPos := Pos{Line: p.line, Col: p.col}
		p.skip(2)
		p.skipSpaces()
		right, err := p.parseCompare()
		if err != nil {
			return nil, err
		}
		left = &Expr{Pos: opPos, ExprType: ExprBinary, Left: left, Op: "&&", Right: right}
	}
}
// parseCompare parses a unary operand optionally followed by exactly one
// comparison operator and a second unary operand. Comparisons do not
// chain: after "a < b" the parser returns, leaving any further operator to
// the caller.
//
// The binary node is positioned at the operator.
func (p *ExprParser) parseCompare() (*Expr, error) {
	left, err := p.parseUnary()
	if err != nil {
		return nil, err
	}
	p.skipSpaces()
	opPos := Pos{Line: p.line, Col: p.col}
	op := ""
	// Two-rune operators are tried before their one-rune prefixes so that
	// "<=" is not read as "<" followed by "=".
	for _, candidate := range [...]string{"==", "!=", "<=", ">=", "<", ">"} {
		if p.peekStr(candidate) {
			op = candidate
			break
		}
	}
	if op == "" {
		return left, nil
	}
	p.skip(len(op)) // all operators are ASCII, so bytes == runes
	p.skipSpaces()
	right, err := p.parseUnary()
	if err != nil {
		return nil, err
	}
	return &Expr{Pos: opPos, ExprType: ExprBinary, Left: left, Op: op, Right: right}, nil
}
// parseUnary parses any number of prefix "!" operators followed by a
// primary expression. A "!" that begins "!=" is not treated as negation.
//
// The unary node is positioned at its "!" rather than after its operand.
func (p *ExprParser) parseUnary() (*Expr, error) {
	p.skipSpaces()
	if !p.peekStr("!") || p.peekStr("!=") {
		return p.parsePrimary()
	}
	opPos := Pos{Line: p.line, Col: p.col}
	p.skip(1)
	operand, err := p.parseUnary()
	if err != nil {
		return nil, err
	}
	return &Expr{Pos: opPos, ExprType: ExprUnary, UnaryOp: "!", Operand: operand}, nil
}
// parsePrimary parses the tightest-binding expression forms. The first
// non-space rune decides which one:
//   - '"'    double-quoted string literal
//   - '\''   single-quoted string literal
//   - 0-9    number literal
//   - letter or '_' identifier, keyword, variable path or function call
//   - '('    parenthesised sub-expression (a full "||"-level expression)
//
// Any other rune, or end of input, is a syntax error.
func (p *ExprParser) parsePrimary() (*Expr, error) {
	p.skipSpaces()
	if p.pos >= len(p.input) {
		return nil, fmt.Errorf("line %d, col %d: unexpected end of expression", p.line, p.col)
	}
	switch c := p.input[p.pos]; {
	case c == '"':
		return p.parseStringLit()
	case c == '\'':
		return p.parseSingleQuoteStringLit()
	case '0' <= c && c <= '9':
		return p.parseNumberLit()
	case isIdentStart(c):
		return p.parseIdentOrKeyword()
	case c == '(':
		p.skip(1)
		p.skipSpaces()
		inner, err := p.parseOr()
		if err != nil {
			return nil, err
		}
		p.skipSpaces()
		if p.pos < len(p.input) && p.input[p.pos] == ')' {
			p.skip(1)
			return inner, nil
		}
		return nil, fmt.Errorf("line %d, col %d: expected ')'", p.line, p.col)
	default:
		return nil, fmt.Errorf("line %d, col %d: unexpected character %q", p.line, p.col, string(c))
	}
}
// parseIdentOrKeyword parses a primary that starts with an identifier rune.
// The identifier is one of:
//   - the keyword true, false or nil (a literal / nil node);
//   - a function name followed (after optional spaces) by '(' — handed to
//     parseFuncCall;
//   - a variable reference, optionally a dotted path such as "a.b.c".
//
// Whitespace is allowed before '(' and before each '.'. No whitespace is
// allowed between '.' and the following segment.
//
// Identifier runes are consumed via skip so that col stays accurate.
// Previously p.pos was bumped directly, which left col behind and skewed
// every later error column. The node is positioned at the first rune of
// the identifier.
func (p *ExprParser) parseIdentOrKeyword() (*Expr, error) {
	start := Pos{Line: p.line, Col: p.col}

	// readIdent consumes a run of identifier runes and returns it. The
	// caller has already checked that the current rune can start one.
	readIdent := func() string {
		end := p.pos
		for end < len(p.input) && isIdentPart(p.input[end]) {
			end++
		}
		ident := string(p.input[p.pos:end])
		p.skip(end - p.pos)
		return ident
	}

	name := readIdent()
	switch name {
	case "true":
		return &Expr{Pos: start, ExprType: ExprLiteral, Value: true}, nil
	case "false":
		return &Expr{Pos: start, ExprType: ExprLiteral, Value: false}, nil
	case "nil":
		return &Expr{Pos: start, ExprType: ExprNil}, nil
	}

	p.skipSpaces()
	if p.pos < len(p.input) && p.input[p.pos] == '(' {
		return p.parseFuncCall(name)
	}

	varName := name
	for p.pos < len(p.input) && p.input[p.pos] == '.' {
		p.skip(1)
		if p.pos >= len(p.input) || !isIdentStart(p.input[p.pos]) {
			return nil, fmt.Errorf("line %d, col %d: expected identifier after '.'", p.line, p.col)
		}
		varName += "." + readIdent()
		p.skipSpaces()
	}
	return &Expr{Pos: start, ExprType: ExprVariable, Name: varName}, nil
}
// parseFuncCall parses the argument list of a call to name. It expects the
// parser to be on the opening '('. Arguments are comma-separated
// "||"-level expressions; an empty list "()" yields nil FuncArgs.
//
// The call node is positioned at the opening '('. Before, the position
// was taken after the closing ')', past the whole call.
func (p *ExprParser) parseFuncCall(name string) (*Expr, error) {
	callPos := Pos{Line: p.line, Col: p.col}
	p.skip(1) // '('
	var args []*Expr
	p.skipSpaces()
	if p.pos < len(p.input) && p.input[p.pos] != ')' {
		for {
			arg, err := p.parseOr()
			if err != nil {
				return nil, err
			}
			args = append(args, arg)
			p.skipSpaces()
			if p.pos >= len(p.input) || p.input[p.pos] != ',' {
				break
			}
			p.skip(1) // ','
			p.skipSpaces()
		}
	}
	p.skipSpaces()
	if p.pos >= len(p.input) || p.input[p.pos] != ')' {
		return nil, fmt.Errorf("line %d, col %d: expected ')' after function call", p.line, p.col)
	}
	p.skip(1) // ')'
	return &Expr{Pos: callPos, ExprType: ExprFuncCall, FuncName: name, FuncArgs: args}, nil
}
// parseStringLit parses a double-quoted string literal. The parser must be
// on the opening '"'.
//
// Recognised escapes are \n, \t, \\ and \". Any other backslash sequence
// is kept verbatim, backslash included. The literal may span lines.
//
// Every consumed rune now advances line/col the same way skipSpaces does
// ('\n' bumps line and resets col to 0). Before, p.pos was bumped directly,
// so positions after a string literal were wrong. The node and any
// "unterminated string" error are positioned at the opening quote.
func (p *ExprParser) parseStringLit() (*Expr, error) {
	start := Pos{Line: p.line, Col: p.col}
	p.skip(1) // opening quote
	var buf strings.Builder
	for p.pos < len(p.input) && p.input[p.pos] != '"' {
		ch := p.input[p.pos]
		if ch == '\\' && p.pos+1 < len(p.input) {
			p.skip(1) // the backslash
			ch = p.input[p.pos]
			switch ch {
			case 'n':
				buf.WriteRune('\n')
			case 't':
				buf.WriteRune('\t')
			case '\\', '"':
				buf.WriteRune(ch)
			default:
				buf.WriteRune('\\')
				buf.WriteRune(ch)
			}
		} else {
			buf.WriteRune(ch)
		}
		// Step over ch (the plain rune, or the rune after a backslash).
		if ch == '\n' {
			p.pos++
			p.line++
			p.col = 0
		} else {
			p.skip(1)
		}
	}
	if p.pos >= len(p.input) {
		return nil, fmt.Errorf("line %d, col %d: unterminated string", start.Line, start.Col)
	}
	p.skip(1) // closing quote
	return &Expr{Pos: start, ExprType: ExprLiteral, Value: buf.String()}, nil
}
// parseSingleQuoteStringLit parses a single-quoted string literal. The
// parser must be on the opening '\''. No escape sequences are
// interpreted: every rune up to the next '\'' is taken verbatim, so a
// single-quoted string cannot contain a '\''.
//
// Consumed runes advance line/col the same way skipSpaces does. The node
// and any "unterminated string" error are positioned at the opening quote.
func (p *ExprParser) parseSingleQuoteStringLit() (*Expr, error) {
	start := Pos{Line: p.line, Col: p.col}
	p.skip(1) // opening quote
	var buf strings.Builder
	for p.pos < len(p.input) && p.input[p.pos] != '\'' {
		ch := p.input[p.pos]
		buf.WriteRune(ch)
		if ch == '\n' {
			p.pos++
			p.line++
			p.col = 0
		} else {
			p.skip(1)
		}
	}
	if p.pos >= len(p.input) {
		return nil, fmt.Errorf("line %d, col %d: unterminated string", start.Line, start.Col)
	}
	p.skip(1) // closing quote
	return &Expr{Pos: start, ExprType: ExprLiteral, Value: buf.String()}, nil
}
// parseNumberLit parses an unsigned decimal literal: digits, optionally
// followed by '.' and at least one more digit. The result is a float64
// when a fractional part is present and an int otherwise.
//
// A '.' not followed by a digit is left unconsumed, so "1." yields the
// int 1 with "." remaining. There is no sign and no exponent.
//
// Changes from the previous version:
//   - Consumed runes advance col via skip; before, col was left behind.
//   - The node and any "invalid number" error are positioned at the first
//     digit.
//   - Integers are parsed with strconv.Atoi, which respects the
//     platform's int size. Before, ParseInt(...,64) followed by int(v)
//     silently truncated large literals on 32-bit platforms; they are
//     now rejected.
func (p *ExprParser) parseNumberLit() (*Expr, error) {
	start := Pos{Line: p.line, Col: p.col}
	isDigitAt := func(i int) bool {
		return i < len(p.input) && p.input[i] >= '0' && p.input[i] <= '9'
	}

	end := p.pos
	for isDigitAt(end) {
		end++
	}
	isFloat := false
	if end < len(p.input) && p.input[end] == '.' && isDigitAt(end+1) {
		isFloat = true
		end++
		for isDigitAt(end) {
			end++
		}
	}
	text := string(p.input[p.pos:end])
	p.skip(end - p.pos)

	if isFloat {
		v, err := strconv.ParseFloat(text, 64)
		if err != nil {
			return nil, fmt.Errorf("line %d, col %d: invalid number %q", start.Line, start.Col, text)
		}
		return &Expr{Pos: start, ExprType: ExprLiteral, Value: v}, nil
	}
	v, err := strconv.Atoi(text)
	if err != nil {
		return nil, fmt.Errorf("line %d, col %d: invalid number %q", start.Line, start.Col, text)
	}
	return &Expr{Pos: start, ExprType: ExprLiteral, Value: v}, nil
}
// peekStr reports whether the unread input begins with s, without
// consuming anything. The empty string always matches.
//
// p.input is []rune, but ranging over a string yields byte offsets. The
// rune index into p.input is therefore tracked separately, and the
// comparison stops at end of input. This keeps the match correct for
// non-ASCII s; before, it indexed runes by byte offset and checked the
// byte length of s.
func (p *ExprParser) peekStr(s string) bool {
	i := p.pos
	for _, want := range s {
		if i >= len(p.input) || p.input[i] != want {
			return false
		}
		i++
	}
	return true
}
// skip consumes n runes, advancing both pos and col by n. It never touches
// line, so callers must only skip runes known not to be '\n'.
func (p *ExprParser) skip(n int) {
	p.pos, p.col = p.pos+n, p.col+n
}
// skipSpaces consumes a run of whitespace as defined by unicode.IsSpace.
// A '\n' increments line and resets col to 0; any other whitespace rune
// increments col.
func (p *ExprParser) skipSpaces() {
	for ; p.pos < len(p.input); p.pos++ {
		r := p.input[p.pos]
		if !unicode.IsSpace(r) {
			return
		}
		if r == '\n' {
			p.line++
			p.col = 0
			continue
		}
		p.col++
	}
}
// isIdentStart reports whether ch may begin an identifier: an ASCII letter
// or '_'. Non-ASCII letters are not accepted.
func isIdentStart(ch rune) bool {
	switch {
	case 'a' <= ch && ch <= 'z', 'A' <= ch && ch <= 'Z':
		return true
	default:
		return ch == '_'
	}
}

// isIdentPart reports whether ch may continue an identifier: anything
// isIdentStart accepts, plus the ASCII digits 0-9.
func isIdentPart(ch rune) bool {
	if '0' <= ch && ch <= '9' {
		return true
	}
	return isIdentStart(ch)
}
// Eval evaluates expr against ctx and returns its value.
//
// Literal nodes yield their stored Value and nil nodes yield nil. Variable
// nodes are resolved with ctx.Get; a name that is not found evaluates to
// nil rather than producing an error. Unary, binary and function-call nodes
// are delegated to evalUnary, evalBinary and evalFuncCall. Any other node
// type is reported as an error at the node's position.
func Eval(expr *Expr, ctx *Context) (any, error) {
	switch expr.ExprType {
	case ExprLiteral:
		return expr.Value, nil
	case ExprNil:
		return nil, nil
	case ExprVariable:
		if val, found := ctx.Get(expr.Name); found {
			return val, nil
		}
		return nil, nil
	case ExprUnary:
		return evalUnary(expr, ctx)
	case ExprBinary:
		return evalBinary(expr, ctx)
	case ExprFuncCall:
		return evalFuncCall(expr, ctx)
	}
	return nil, fmt.Errorf("line %d, col %d: unknown expression type", expr.Pos.Line, expr.Pos.Col)
}
// evalUnary evaluates a unary node. The operand is evaluated first, so its
// errors take precedence. The only supported operator is "!", which yields
// the bool negation of the operand's truthiness (see isTruthy). Any other
// operator is an error.
func evalUnary(expr *Expr, ctx *Context) (any, error) {
	operand, err := Eval(expr.Operand, ctx)
	if err != nil {
		return nil, err
	}
	if expr.UnaryOp != "!" {
		return nil, fmt.Errorf("line %d, col %d: unknown unary operator %q", expr.Pos.Line, expr.Pos.Col, expr.UnaryOp)
	}
	return !isTruthy(operand), nil
}
// evalBinary evaluates a binary node.
//
// "&&" and "||" short-circuit and yield the deciding operand itself, not a
// bool:
//   - "&&" returns the left operand if it is falsy, otherwise the
//     evaluated right operand;
//   - "||" returns the left operand if it is truthy, otherwise the
//     evaluated right operand.
//
// All other operators evaluate both operands left to right.
// "==" and "!=" use compareEqual. "<", ">", "<=" and ">=" use
// compareOrder.
//
// compareOrder cannot know the source position and reports a placeholder
// "line 0, col 0". Its error is therefore re-reported at this node's
// position, and the value returned alongside an error is now nil.
func evalBinary(expr *Expr, ctx *Context) (any, error) {
	switch expr.Op {
	case "&&":
		left, err := Eval(expr.Left, ctx)
		if err != nil {
			return nil, err
		}
		if !isTruthy(left) {
			return left, nil
		}
		return Eval(expr.Right, ctx)
	case "||":
		left, err := Eval(expr.Left, ctx)
		if err != nil {
			return nil, err
		}
		if isTruthy(left) {
			return left, nil
		}
		return Eval(expr.Right, ctx)
	}

	left, err := Eval(expr.Left, ctx)
	if err != nil {
		return nil, err
	}
	right, err := Eval(expr.Right, ctx)
	if err != nil {
		return nil, err
	}
	switch expr.Op {
	case "==":
		return compareEqual(left, right), nil
	case "!=":
		return !compareEqual(left, right), nil
	case "<", ">", "<=", ">=":
		ordered, cmpErr := compareOrder(left, right, expr.Op)
		if cmpErr != nil {
			reason := strings.TrimPrefix(cmpErr.Error(), "line 0, col 0: ")
			return nil, fmt.Errorf("line %d, col %d: %s", expr.Pos.Line, expr.Pos.Col, reason)
		}
		return ordered, nil
	default:
		return nil, fmt.Errorf("line %d, col %d: unknown operator %q", expr.Pos.Line, expr.Pos.Col, expr.Op)
	}
}
// evalFuncCall evaluates a call to a builtin function.
//
// The builtin is looked up by name before any argument is evaluated. An
// unknown name is an error. Arguments are then evaluated left to right,
// stopping at the first error, and passed to the builtin as a slice; the
// slice is nil when there are no arguments. A builtin signalling failure
// through its bool result is reported as a generic "call failed" error at
// the call's position.
func evalFuncCall(expr *Expr, ctx *Context) (any, error) {
	builtin, found := LookupBuiltin(expr.FuncName)
	if !found {
		return nil, fmt.Errorf("line %d, col %d: unknown function %q", expr.Pos.Line, expr.Pos.Col, expr.FuncName)
	}
	var argVals []any
	for i := range expr.FuncArgs {
		v, err := Eval(expr.FuncArgs[i], ctx)
		if err != nil {
			return nil, err
		}
		argVals = append(argVals, v)
	}
	out, succeeded := builtin(argVals)
	if !succeeded {
		return nil, fmt.Errorf("line %d, col %d: function %q call failed", expr.Pos.Line, expr.Pos.Col, expr.FuncName)
	}
	return out, nil
}
// isTruthy reports whether val counts as true in a condition.
//
// The following values are falsy:
//   - nil and false;
//   - numeric zero of any integer, float or complex kind;
//   - the empty string;
//   - empty slices, arrays and maps;
//   - nil pointers, funcs and channels.
//
// Everything else, including any struct, is truthy.
//
// Values are classified by reflect.Kind, so named types behave like their
// underlying type. For example, `type Status int` with value 0, or
// `type Name string` with value "", is now falsy. Previously such values,
// and typed nil pointers, fell through to the "truthy" default.
func isTruthy(val any) bool {
	// Fast paths for the most common concrete types; avoids reflection.
	switch v := val.(type) {
	case nil:
		return false
	case bool:
		return v
	case int:
		return v != 0
	case int64:
		return v != 0
	case float64:
		return v != 0
	case string:
		return v != ""
	case []any:
		return len(v) > 0
	case map[string]any:
		return len(v) > 0
	}

	rv := reflect.ValueOf(val)
	switch rv.Kind() {
	case reflect.Bool:
		return rv.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int() != 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return rv.Uint() != 0
	case reflect.Float32, reflect.Float64:
		return rv.Float() != 0
	case reflect.Complex64, reflect.Complex128:
		return rv.Complex() != 0
	case reflect.String, reflect.Slice, reflect.Array, reflect.Map:
		return rv.Len() > 0
	case reflect.Pointer, reflect.Func, reflect.Chan:
		return !rv.IsNil()
	default:
		return true
	}
}
// compareEqual reports whether left and right are equal under the
// expression language's loose equality. The rules are applied in order:
//   - Two nils are equal; nil never equals a non-nil value.
//   - A bool only equals another bool with the same value.
//   - Two built-in integers of any width or signedness are compared
//     exactly (see compareIntegers).
//   - Other numeric pairs are compared as float64.
//   - Two strings are compared directly.
//   - Anything else is compared by its %v formatting, so "1" == 1.
func compareEqual(left, right any) bool {
	if left == nil && right == nil {
		return true
	}
	if left == nil || right == nil {
		return false
	}
	lbool, lbOk := left.(bool)
	rbool, rbOk := right.(bool)
	if lbOk || rbOk {
		return lbOk && rbOk && lbool == rbool
	}
	if c, ok := compareIntegers(left, right); ok {
		return c == 0
	}
	lf, lok := toFloat64(left)
	rf, rok := toFloat64(right)
	if lok && rok {
		return lf == rf
	}
	lstr, lsOk := left.(string)
	rstr, rsOk := right.(string)
	if lsOk && rsOk {
		return lstr == rstr
	}
	return fmt.Sprintf("%v", left) == fmt.Sprintf("%v", right)
}

// compareOrder evaluates left op right for op in "<", ">", "<=", ">=".
//
// If either side is nil, the result is false with no error. Two built-in
// integers are compared exactly. Other numeric pairs are compared as
// float64, and two strings are compared lexically by bytes. Any other
// combination (or an unknown op) is an error.
//
// The error carries a placeholder "line 0, col 0" position because no
// source position is available here. Callers that know the position
// should re-report it.
func compareOrder(left, right any, op string) (bool, error) {
	if left == nil || right == nil {
		return false, nil
	}
	if c, ok := compareIntegers(left, right); ok {
		switch op {
		case "<":
			return c < 0, nil
		case ">":
			return c > 0, nil
		case "<=":
			return c <= 0, nil
		case ">=":
			return c >= 0, nil
		}
	}
	lf, lok := toFloat64(left)
	rf, rok := toFloat64(right)
	if lok && rok {
		switch op {
		case "<":
			return lf < rf, nil
		case ">":
			return lf > rf, nil
		case "<=":
			return lf <= rf, nil
		case ">=":
			return lf >= rf, nil
		}
	}
	ls, lsOk := left.(string)
	rs, rsOk := right.(string)
	if lsOk && rsOk {
		switch op {
		case "<":
			return ls < rs, nil
		case ">":
			return ls > rs, nil
		case "<=":
			return ls <= rs, nil
		case ">=":
			return ls >= rs, nil
		}
	}
	return false, fmt.Errorf("line 0, col 0: cannot compare %T and %T with %s", left, right, op)
}

// compareIntegers compares left and right exactly when both are built-in
// integers (int, int8..int64, uint, uint8..uint64). It returns -1, 0 or +1
// as left is less than, equal to, or greater than right, with ok == true.
// It returns ok == false when either side is not such an integer, in which
// case the caller falls back to another comparison.
//
// Why not toFloat64: float64 has a 53-bit mantissa, so distinct integers
// above 2^53 (e.g. nanosecond timestamps) could compare equal.
func compareIntegers(left, right any) (int, bool) {
	li, lSigned := toInt64Exact(left)
	lu, lUnsigned := toUint64Exact(left)
	ri, rSigned := toInt64Exact(right)
	ru, rUnsigned := toUint64Exact(right)
	switch {
	case lSigned && rSigned:
		return threeWayCompare(li, ri), true
	case lUnsigned && rUnsigned:
		return threeWayCompare(lu, ru), true
	case lSigned && rUnsigned:
		// A negative signed value is below every unsigned value.
		if li < 0 {
			return -1, true
		}
		return threeWayCompare(uint64(li), ru), true
	case lUnsigned && rSigned:
		if ri < 0 {
			return 1, true
		}
		return threeWayCompare(lu, uint64(ri)), true
	}
	return 0, false
}

// threeWayCompare returns -1, 0 or +1 as a is less than, equal to, or
// greater than b.
func threeWayCompare[T int64 | uint64](a, b T) int {
	switch {
	case a < b:
		return -1
	case a > b:
		return 1
	}
	return 0
}

// toInt64Exact widens any built-in signed integer to int64 (lossless).
func toInt64Exact(val any) (int64, bool) {
	switch v := val.(type) {
	case int:
		return int64(v), true
	case int8:
		return int64(v), true
	case int16:
		return int64(v), true
	case int32:
		return int64(v), true
	case int64:
		return v, true
	}
	return 0, false
}

// toUint64Exact widens any built-in unsigned integer to uint64 (lossless).
func toUint64Exact(val any) (uint64, bool) {
	switch v := val.(type) {
	case uint:
		return uint64(v), true
	case uint8:
		return uint64(v), true
	case uint16:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case uint64:
		return v, true
	}
	return 0, false
}

// toFloat64 converts any built-in integer or float value to float64.
// It reports false for every other type, including named numeric types.
// Integers beyond 2^53 in magnitude may be rounded.
func toFloat64(val any) (float64, bool) {
	switch v := val.(type) {
	case int:
		return float64(v), true
	case int8:
		return float64(v), true
	case int16:
		return float64(v), true
	case int32:
		return float64(v), true
	case int64:
		return float64(v), true
	case uint:
		return float64(v), true
	case uint8:
		return float64(v), true
	case uint16:
		return float64(v), true
	case uint32:
		return float64(v), true
	case uint64:
		return float64(v), true
	case float32:
		return float64(v), true
	case float64:
		return v, true
	default:
		return 0, false
	}
}