Files
u-tpl/internal/executor.go
绝尘 1b5b6aff8f 新增: @use 同文件片段复用
支持 @use("name") 引用同一文件内 @tpl 定义的块,
消除 _list/_count 模板中 WHERE 条件重复问题。
2026-04-01 01:59:51 +08:00

181 lines
3.9 KiB
Go

package internal
import (
"fmt"
"maps"
"reflect"
"strings"
)
// rawValidator is the hook the executor consults before it splices a raw
// (non-parameterised) value directly into the generated SQL text.
type rawValidator interface {
	// Validate is given the template variable name and the string form of its
	// value. A non-nil error aborts rendering and is returned to the caller.
	Validate(param, value string) error
}
// Executor renders a parsed template (a list of Node values) into SQL text
// plus an ordered list of bind arguments.
type Executor struct {
	// style is handed to NewPlaceholder at the start of every Execute call.
	style PlaceholderStyle

	// rawPolicy, when non-nil, validates every raw value before it is written
	// into the SQL text. A nil policy disables validation.
	rawPolicy rawValidator

	// strict makes a reference to an undefined variable an error instead of
	// substituting nil (parameters) or an empty string (raw values).
	strict bool

	// blocks maps a block name to the nodes rendered when a UseNode refers to
	// that name. A nil map makes every UseNode an error.
	blocks map[string][]Node
}
// Result is the outcome of a successful Execute call.
type Result struct {
	// SQL is the rendered statement text.
	SQL string

	// Args holds the bind values in the order their placeholders were written
	// into SQL. It is nil when the template produced no parameters.
	Args []any
}
// NewExecutor builds an Executor from its four settings. The arguments are
// stored as given; in particular the blocks map is not copied.
//
//   - style: placeholder style passed to NewPlaceholder on each Execute.
//   - rawPolicy: optional validator for raw values; nil disables validation.
//   - strict: when true, undefined variables are reported as errors.
//   - blocks: named node lists available to UseNode; may be nil.
func NewExecutor(style PlaceholderStyle, rawPolicy rawValidator, strict bool, blocks map[string][]Node) *Executor {
	exec := new(Executor)
	exec.style = style
	exec.rawPolicy = rawPolicy
	exec.strict = strict
	exec.blocks = blocks
	return exec
}
// Execute renders nodes against vars and returns the resulting SQL text
// together with the bind arguments collected along the way.
//
// A fresh Context and Placeholder are created for every call. An error from
// walking the nodes is returned unchanged, with a nil Result.
//
// On success the SQL has trailing spaces, tabs and newlines removed. If it
// then ends in a comma, that single comma is dropped as well, together with
// any whitespace that preceded it.
func (e *Executor) Execute(nodes []Node, vars map[string]any) (*Result, error) {
	const trailing = " \t\n\r"

	ctx := NewContext(vars)
	ph := NewPlaceholder(e.style)

	var (
		sql  strings.Builder
		args []any
	)
	if err := e.walk(ctx, ph, &sql, &args, nodes); err != nil {
		return nil, err
	}

	out := strings.TrimRight(sql.String(), trailing)
	if rest, ok := strings.CutSuffix(out, ","); ok {
		// Trim again: whitespace that sat in front of the dangling comma
		// (e.g. "a = ? ,") would otherwise be left at the end of the output.
		out = strings.TrimRight(rest, trailing)
	}
	return &Result{SQL: out, Args: args}, nil
}
// walk renders nodes in order, appending SQL text to sql and bind values to
// args. It stops at, and returns, the first error.
//
// Handling per node type:
//   - TextNode: its text is written verbatim.
//   - ParamNode: the next placeholder from ph is written and the variable's
//     value is appended to args. An undefined variable is an error in strict
//     mode, otherwise nil is bound.
//   - RawNode: the variable's value is written directly into the SQL text.
//     Strings are used as-is, nil becomes the empty string, and anything
//     else is formatted with fmt.Sprint. An undefined variable is an error
//     in strict mode, otherwise the empty string is used. If a rawPolicy is
//     set, the text must pass Validate before it is written.
//   - IfNode / ForNode: delegated to walkIf / walkFor.
//   - BlockNode, NamespaceNode, IncludeNode, CommentNode: produce no output.
//   - UseNode: the named entry of e.blocks is rendered in place, sharing the
//     current context, placeholder sequence and output buffers.
//
// NOTE(review): node types not listed above fall through the switch and are
// ignored without an error.
//
// NOTE(review): UseNode expansion has no recursion guard here. A block that
// (directly or indirectly) uses itself would recurse without bound unless
// such cycles are rejected before execution — confirm against the loader.
func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, nodes []Node) error {
	for _, node := range nodes {
		switch n := node.(type) {
		case *TextNode:
			sql.WriteString(n.Text)

		case *ParamNode:
			val, ok := ctx.Get(n.Name)
			if !ok {
				if e.strict {
					return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name)
				}
				val = nil
			}
			sql.WriteString(ph.Next())
			*args = append(*args, val)

		case *RawNode:
			val, ok := ctx.Get(n.Name)
			if !ok {
				if e.strict {
					return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name)
				}
				val = nil
			}

			var text string
			switch v := val.(type) {
			case nil:
				// Leave text empty. fmt.Sprint(nil) would yield the literal
				// "<nil>", which must never end up inside the SQL text.
			case string:
				text = v
			default:
				text = fmt.Sprint(v)
			}

			if e.rawPolicy != nil {
				if err := e.rawPolicy.Validate(n.Name, text); err != nil {
					return err
				}
			}
			sql.WriteString(text)

		case *IfNode:
			if err := e.walkIf(ctx, ph, sql, args, n); err != nil {
				return err
			}

		case *ForNode:
			if err := e.walkFor(ctx, ph, sql, args, n); err != nil {
				return err
			}

		case *BlockNode, *NamespaceNode, *IncludeNode, *CommentNode:
			// These nodes contribute nothing to the rendered output.

		case *UseNode:
			if e.blocks == nil {
				return fmt.Errorf("line %d, col %d: @use(\"%s\") no blocks available", n.Pos.Line, n.Pos.Col, n.Name)
			}
			blockNodes, ok := e.blocks[n.Name]
			if !ok {
				return fmt.Errorf("line %d, col %d: @use(\"%s\") block not found", n.Pos.Line, n.Pos.Col, n.Name)
			}
			if err := e.walk(ctx, ph, sql, args, blockNodes); err != nil {
				return err
			}
		}
	}
	return nil
}
// walkIf renders at most one branch of an if / else-if / else construct.
//
// Conditions are evaluated lazily and in order: first n.Cond, then each
// entry of n.ElseIf. The body of the first truthy condition is rendered and
// no later condition is evaluated. If none is truthy, n.Else is rendered when
// it is non-empty. Errors from Eval or from rendering are returned as-is.
func (e *Executor) walkIf(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *IfNode) error {
	result, err := Eval(n.Cond, ctx)
	if err != nil {
		return err
	}
	if isTruthy(result) {
		return e.walk(ctx, ph, sql, args, n.Body)
	}

	for _, alt := range n.ElseIf {
		if result, err = Eval(alt.Cond, ctx); err != nil {
			return err
		}
		if isTruthy(result) {
			return e.walk(ctx, ph, sql, args, alt.Body)
		}
	}

	if len(n.Else) == 0 {
		return nil
	}
	return e.walk(ctx, ph, sql, args, n.Else)
}
// walkFor renders n.Body once per element of the list that n.List evaluates
// to.
//
// A nil list renders nothing. Pointers are dereferenced until a non-pointer
// value is reached; that value must be a slice or an array, otherwise an
// error carrying the node position and the original dynamic type is
// returned.
//
// The body runs in a child scope: a copy of the parent's variables plus
// n.ValVar (the element) and, when n.KeyVar is non-empty, n.KeyVar (the
// zero-based index). The parent's variable map is never written to.
// Placeholder numbering, SQL output and args are shared with the parent.
func (e *Executor) walkFor(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *ForNode) error {
	listVal, err := Eval(n.List, ctx)
	if err != nil {
		return err
	}
	if listVal == nil {
		return nil
	}

	rv := reflect.ValueOf(listVal)
	for rv.Kind() == reflect.Pointer {
		rv = rv.Elem()
	}
	if kind := rv.Kind(); kind != reflect.Slice && kind != reflect.Array {
		return fmt.Errorf("line %d, col %d: @for requires a slice or array, got %T", n.Pos.Line, n.Pos.Col, listVal)
	}

	length := rv.Len()
	if length == 0 {
		return nil
	}

	// Copy the parent's variables once; only the loop key and value change
	// between iterations, so they are overwritten in place rather than
	// rebuilding the whole map every time.
	// NOTE(review): this relies on the loop body not writing to its scope
	// map, which holds for every node type handled by walk in this file.
	scope := make(map[string]any, len(ctx.vars)+2)
	maps.Copy(scope, ctx.vars)

	for i := 0; i < length; i++ {
		if n.KeyVar != "" {
			scope[n.KeyVar] = i
		}
		scope[n.ValVar] = rv.Index(i).Interface()

		// A new Context per iteration, as before, so no per-context state
		// carries over from one element to the next.
		if err := e.walk(NewContext(scope), ph, sql, args, n.Body); err != nil {
			return err
		}
	}
	return nil
}