package internal import ( "fmt" "maps" "reflect" "strings" ) type rawValidator interface { Validate(param string, value string) error } type Executor struct { style PlaceholderStyle rawPolicy rawValidator strict bool blocks map[string][]Node } type Result struct { SQL string Args []any } func NewExecutor(style PlaceholderStyle, rawPolicy rawValidator, strict bool, blocks map[string][]Node) *Executor { return &Executor{ style: style, rawPolicy: rawPolicy, strict: strict, blocks: blocks, } } func (e *Executor) Execute(nodes []Node, vars map[string]any) (*Result, error) { ctx := NewContext(vars) ph := NewPlaceholder(e.style) var sql strings.Builder var args []any err := e.walk(ctx, ph, &sql, &args, nodes) if err != nil { return nil, err } s := strings.TrimRight(sql.String(), " \t\n\r") if len(s) > 0 && s[len(s)-1] == ',' { s = s[:len(s)-1] } return &Result{SQL: s, Args: args}, nil } func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, nodes []Node) error { for _, node := range nodes { switch n := node.(type) { case *TextNode: sql.WriteString(n.Text) case *ParamNode: val, ok := ctx.Get(n.Name) if !ok { if e.strict { return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name) } val = nil } sql.WriteString(ph.Next()) *args = append(*args, val) case *RawNode: val, ok := ctx.Get(n.Name) if !ok { if e.strict { return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name) } val = "" } strVal, ok := val.(string) if !ok { strVal = fmt.Sprint(val) } if e.rawPolicy != nil { if err := e.rawPolicy.Validate(n.Name, strVal); err != nil { return err } } sql.WriteString(strVal) case *IfNode: err := e.walkIf(ctx, ph, sql, args, n) if err != nil { return err } case *ForNode: err := e.walkFor(ctx, ph, sql, args, n) if err != nil { return err } case *BlockNode, *NamespaceNode, *IncludeNode, *CommentNode: // skip case *UseNode: if e.blocks == nil { return fmt.Errorf("line %d, col %d: @use(\"%s\") no blocks available", n.Pos.Line, n.Pos.Col, n.Name) } blockNodes, ok := e.blocks[n.Name] if !ok { return fmt.Errorf("line %d, col %d: @use(\"%s\") block not found", n.Pos.Line, n.Pos.Col, n.Name) } if err := e.walk(ctx, ph, sql, args, blockNodes); err != nil { return err } } } return nil } func (e *Executor) walkIf(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *IfNode) error { condVal, err := Eval(n.Cond, ctx) if err != nil { return err } if isTruthy(condVal) { return e.walk(ctx, ph, sql, args, n.Body) } for _, branch := range n.ElseIf { condVal, err = Eval(branch.Cond, ctx) if err != nil { return err } if isTruthy(condVal) { return e.walk(ctx, ph, sql, args, branch.Body) } } if len(n.Else) > 0 { return e.walk(ctx, ph, sql, args, n.Else) } return nil } func (e *Executor) walkFor(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *ForNode) error { listVal, err := Eval(n.List, ctx) if err != nil { return err } if listVal == nil { return nil } rv := reflect.ValueOf(listVal) for rv.Kind() == reflect.Pointer { rv = rv.Elem() } var length int switch rv.Kind() { case reflect.Slice, reflect.Array: length = rv.Len() default: return fmt.Errorf("line %d, col %d: @for requires a slice or array, got %T", n.Pos.Line, n.Pos.Col, listVal) } for i := 0; i < length; i++ { childVars := make(map[string]any, len(ctx.vars)+2) maps.Copy(childVars, ctx.vars) if n.KeyVar != "" { childVars[n.KeyVar] = i } childVars[n.ValVar] = rv.Index(i).Interface() childCtx := NewContext(childVars) err := e.walk(childCtx, ph, sql, args, n.Body) if err != nil { return err } } return nil }