181 lines
3.9 KiB
Go
181 lines
3.9 KiB
Go
package internal
|
|
|
|
import (
|
|
"fmt"
|
|
"maps"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
// rawValidator vets the string that a raw substitution is about to write
// verbatim into the SQL text. param is the template variable's name and
// value is the string form of its value. A non-nil error stops rendering
// and is returned to the caller of Execute.
type rawValidator interface {
	Validate(param, value string) error
}
|
|
|
|
type Executor struct {
|
|
style PlaceholderStyle
|
|
rawPolicy rawValidator
|
|
strict bool
|
|
blocks map[string][]Node
|
|
}
|
|
|
|
// Result is what Executor.Execute produces: the rendered SQL statement and
// the bind arguments for its placeholders, in the order the placeholders
// appear.
type Result struct {
	// SQL is the rendered statement. Execute strips trailing whitespace and
	// at most one trailing comma.
	SQL string

	// Args holds one value per parameter placeholder. Execute leaves it nil
	// when the template contains no parameters.
	Args []any
}
|
|
|
|
func NewExecutor(style PlaceholderStyle, rawPolicy rawValidator, strict bool, blocks map[string][]Node) *Executor {
|
|
return &Executor{
|
|
style: style,
|
|
rawPolicy: rawPolicy,
|
|
strict: strict,
|
|
blocks: blocks,
|
|
}
|
|
}
|
|
|
|
func (e *Executor) Execute(nodes []Node, vars map[string]any) (*Result, error) {
|
|
ctx := NewContext(vars)
|
|
ph := NewPlaceholder(e.style)
|
|
var sql strings.Builder
|
|
var args []any
|
|
|
|
err := e.walk(ctx, ph, &sql, &args, nodes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s := strings.TrimRight(sql.String(), " \t\n\r")
|
|
if len(s) > 0 && s[len(s)-1] == ',' {
|
|
s = s[:len(s)-1]
|
|
}
|
|
|
|
return &Result{SQL: s, Args: args}, nil
|
|
}
|
|
|
|
func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, nodes []Node) error {
|
|
for _, node := range nodes {
|
|
switch n := node.(type) {
|
|
case *TextNode:
|
|
sql.WriteString(n.Text)
|
|
|
|
case *ParamNode:
|
|
val, ok := ctx.Get(n.Name)
|
|
if !ok {
|
|
if e.strict {
|
|
return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name)
|
|
}
|
|
val = nil
|
|
}
|
|
sql.WriteString(ph.Next())
|
|
*args = append(*args, val)
|
|
|
|
case *RawNode:
|
|
val, ok := ctx.Get(n.Name)
|
|
if !ok {
|
|
if e.strict {
|
|
return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name)
|
|
}
|
|
val = ""
|
|
}
|
|
strVal, ok := val.(string)
|
|
if !ok {
|
|
strVal = fmt.Sprint(val)
|
|
}
|
|
if e.rawPolicy != nil {
|
|
if err := e.rawPolicy.Validate(n.Name, strVal); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
sql.WriteString(strVal)
|
|
|
|
case *IfNode:
|
|
err := e.walkIf(ctx, ph, sql, args, n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case *ForNode:
|
|
err := e.walkFor(ctx, ph, sql, args, n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case *BlockNode, *NamespaceNode, *IncludeNode, *CommentNode:
|
|
// skip
|
|
case *UseNode:
|
|
if e.blocks == nil {
|
|
return fmt.Errorf("line %d, col %d: @use(\"%s\") no blocks available", n.Pos.Line, n.Pos.Col, n.Name)
|
|
}
|
|
blockNodes, ok := e.blocks[n.Name]
|
|
if !ok {
|
|
return fmt.Errorf("line %d, col %d: @use(\"%s\") block not found", n.Pos.Line, n.Pos.Col, n.Name)
|
|
}
|
|
if err := e.walk(ctx, ph, sql, args, blockNodes); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Executor) walkIf(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *IfNode) error {
|
|
condVal, err := Eval(n.Cond, ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if isTruthy(condVal) {
|
|
return e.walk(ctx, ph, sql, args, n.Body)
|
|
}
|
|
for _, branch := range n.ElseIf {
|
|
condVal, err = Eval(branch.Cond, ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if isTruthy(condVal) {
|
|
return e.walk(ctx, ph, sql, args, branch.Body)
|
|
}
|
|
}
|
|
if len(n.Else) > 0 {
|
|
return e.walk(ctx, ph, sql, args, n.Else)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Executor) walkFor(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *ForNode) error {
|
|
listVal, err := Eval(n.List, ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if listVal == nil {
|
|
return nil
|
|
}
|
|
rv := reflect.ValueOf(listVal)
|
|
for rv.Kind() == reflect.Pointer {
|
|
rv = rv.Elem()
|
|
}
|
|
|
|
var length int
|
|
switch rv.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
length = rv.Len()
|
|
default:
|
|
return fmt.Errorf("line %d, col %d: @for requires a slice or array, got %T", n.Pos.Line, n.Pos.Col, listVal)
|
|
}
|
|
|
|
for i := 0; i < length; i++ {
|
|
childVars := make(map[string]any, len(ctx.vars)+2)
|
|
maps.Copy(childVars, ctx.vars)
|
|
if n.KeyVar != "" {
|
|
childVars[n.KeyVar] = i
|
|
}
|
|
childVars[n.ValVar] = rv.Index(i).Interface()
|
|
|
|
childCtx := NewContext(childVars)
|
|
err := e.walk(childCtx, ph, sql, args, n.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|