diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..04d7e04 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +* text=auto eol=lf +*.go text eol=lf +*.md text eol=lf +*.sql text eol=lf +*.tpl text eol=lf diff --git a/.gitignore b/.gitignore index 4c5f206..5fbb404 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,7 @@ .claude/ +.idea/ +*.exe +*.test +*.out +coverage.out +vendor/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..14fac91 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..92516f4 --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +.PHONY: build test bench lint vet clean + +build: + go build ./... + +test: + go test ./... -v + +bench: + go test ./... -bench=. -benchmem + +lint: vet + gofmt -l . + +vet: + go vet ./... + +clean: + go clean -testcache diff --git a/README.md b/README.md index 1f9d2ae..c3e26a8 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ package main import ( "fmt" - "github.com/nicedoc/utpl" + "gitea.1216.top/lxy/u-tpl" ) func main() { @@ -52,7 +52,7 @@ func main() { ## 安装 ```bash -go get github.com/nicedoc/utpl +go get gitea.1216.top/lxy/u-tpl ``` --- @@ -673,7 +673,7 @@ package main import ( _ "embed" - "github.com/nicedoc/utpl" + "gitea.1216.top/lxy/u-tpl" ) //go:embed tpl/order_search.tpl diff --git a/engine.go b/engine.go new file mode 100644 index 0000000..9a07456 --- /dev/null +++ b/engine.go @@ -0,0 +1,121 @@ +package utpl + +import ( + "fmt" + + "gitea.1216.top/lxy/u-tpl/internal" +) + +type PlaceholderStyle = internal.PlaceholderStyle + +const ( + QuestionMark PlaceholderStyle = internal.QuestionMark + DollarNumber = internal.DollarNumber + ColonNumber = internal.ColonNumber +) + +type IncludeResolver func(path string) (string, error) + +type Option func(*Engine) + +type Engine struct { + style internal.PlaceholderStyle + rawPolicy RawPolicy + includeResolver IncludeResolver + strict bool +} + +func New(opts ...Option) *Engine { + e := &Engine{ + style: internal.QuestionMark, + strict: true, + } + for _, opt := range opts { + opt(e) + } + return e +} + +func WithPlaceholderStyle(style PlaceholderStyle) Option { + return func(e *Engine) { e.style = style } +} + +func WithRawPolicy(policy RawPolicy) Option { + return func(e *Engine) { e.rawPolicy = policy } +} + +func WithIncludeResolver(resolver IncludeResolver) Option { + return func(e *Engine) { e.includeResolver = resolver } +} + +func WithStrictMode(strict bool) Option { + return func(e *Engine) { e.strict = strict } +} + +func (e *Engine) Parse(name string, source string) (*Template, error) { + lexer := internal.NewLexer(source) + tokens, err := lexer.Tokenize() + if err != nil { + return nil, wrapParseError(err, name) + } + + var includeMgr *internal.IncludeManager + if e.includeResolver != nil { + includeMgr = internal.NewIncludeManager(internal.IncludeResolver(e.includeResolver)) + } + + parser := internal.NewParser(source, tokens, includeMgr) + nodes, err := parser.Parse() + if err != nil { + return nil, wrapParseError(err, name) + } + + namespace := "" + blocks := make(map[string][]internal.Node) + var bodyNodes []internal.Node + + hasBlocks := false + for _, n := range nodes { + if ns, ok := n.(*internal.NamespaceNode); ok { + namespace = ns.Name + continue + } + if blk, ok := n.(*internal.BlockNode); ok { + hasBlocks = true + fullName := blk.Name + if namespace != "" { + fullName = namespace + "." + blk.Name + } + blocks[fullName] = blk.Body + continue + } + bodyNodes = append(bodyNodes, n) + } + + return &Template{ + name: name, + engine: e, + nodes: bodyNodes, + blocks: blocks, + hasBlocks: hasBlocks, + namespace: namespace, + }, nil +} + +func (e *Engine) MustParse(name string, source string) *Template { + tpl, err := e.Parse(name, source) + if err != nil { + panic(err) + } + return tpl +} + +func wrapParseError(err error, name string) error { + if _, ok := err.(*ParseError); ok { + return err + } + return &ParseError{ + Pos: Position{Line: 0, Column: 0}, + Message: fmt.Sprintf("template %q: %s", name, err.Error()), + } +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..d46d8c8 --- /dev/null +++ b/error.go @@ -0,0 +1,42 @@ +package utpl + +import "fmt" + +type Position struct { + Line int + Column int +} + +func (p Position) String() string { + return fmt.Sprintf("line %d, column %d", p.Line, p.Column) +} + +type ParseError struct { + Message string + Pos Position + Token string +} + +func (e ParseError) Error() string { + return fmt.Sprintf("%s: %s (token: %q)", e.Pos, e.Message, e.Token) +} + +type ExecError struct { + Pos Position + Message string +} + +func (e ExecError) Error() string { + return fmt.Sprintf("%s: %s", e.Pos, e.Message) +} + +type UnsafeRawError struct { + Message string + Pos Position + Param string + Value string +} + +func (e UnsafeRawError) Error() string { + return fmt.Sprintf("%s: %s (param: %q, value: %q)", e.Pos, e.Message, e.Param, e.Value) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ffbc60c --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module gitea.1216.top/lxy/u-tpl + +go 1.26.1 diff --git a/internal/builtin.go b/internal/builtin.go new file mode 100644 index 0000000..909f794 --- /dev/null +++ b/internal/builtin.go @@ -0,0 +1,30 @@ +package internal + +import "reflect" + +type BuiltinFunc func(args []any) (any, bool) + +var builtins = map[string]BuiltinFunc{ + "len": builtinLen, +} + +func builtinLen(args []any) (any, bool) { + if len(args) != 1 { + return nil, false + } + if args[0] == nil { + return 0, true + } + v := reflect.ValueOf(args[0]) + switch v.Kind() { + case reflect.Slice, reflect.Array, reflect.String, reflect.Map, reflect.Chan: + return v.Len(), true + default: + return 0, false + } +} + +func LookupBuiltin(name string) (BuiltinFunc, bool) { + fn, ok := builtins[name] + return fn, ok +} diff --git a/internal/context.go b/internal/context.go new file mode 100644 index 0000000..8f6fe7c --- /dev/null +++ b/internal/context.go @@ -0,0 +1,68 @@ +package internal + +import ( + "reflect" + "strings" +) + +type Context struct { + vars map[string]any +} + +func NewContext(vars map[string]any) *Context { + return &Context{vars: vars} +} + +func (c *Context) Get(path string) (any, bool) { + if c.vars == nil { + return nil, false + } + return resolvePath(c.vars, path) +} + +func resolvePath(current any, path string) (any, bool) { + for seg := range strings.SplitSeq(path, ".") { + if current == nil { + return nil, false + } + switch v := current.(type) { + case map[string]any: + var ok bool + current, ok = v[seg] + if !ok { + return nil, false + } + default: + val := reflect.ValueOf(current) + for val.Kind() == reflect.Pointer { + val = val.Elem() + } + if val.Kind() != reflect.Struct { + return nil, false + } + field := val.FieldByName(seg) + if !field.IsValid() { + field = findFieldIgnoreCase(val, seg) + } + if !field.IsValid() { + return nil, false + } + current = field.Interface() + } + } + return current, true +} + +func findFieldIgnoreCase(v reflect.Value, name string) reflect.Value { + typ := v.Type() + for i := range typ.NumField() { + f := typ.Field(i) + if !f.IsExported() { + continue + } + if strings.EqualFold(f.Name, name) { + return v.Field(i) + } + } + return reflect.Value{} +} diff --git a/internal/executor.go b/internal/executor.go new file mode 100644 index 0000000..927c5eb --- /dev/null +++ b/internal/executor.go @@ -0,0 +1,167 @@ +package internal + +import ( + "fmt" + "maps" + "reflect" + "strings" +) + +type rawValidator interface { + Validate(param string, value string) error +} + +type Executor struct { + style PlaceholderStyle + rawPolicy rawValidator + strict bool +} + +type Result struct { + SQL string + Args []any +} + +func NewExecutor(style PlaceholderStyle, rawPolicy rawValidator, strict bool) *Executor { + return &Executor{ + style: style, + rawPolicy: rawPolicy, + strict: strict, + } +} + +func (e *Executor) Execute(nodes []Node, vars map[string]any) (*Result, error) { + ctx := NewContext(vars) + ph := NewPlaceholder(e.style) + var sql strings.Builder + var args []any + + err := e.walk(ctx, ph, &sql, &args, nodes) + if err != nil { + return nil, err + } + + s := strings.TrimRight(sql.String(), " \t\n\r") + if len(s) > 0 && s[len(s)-1] == ',' { + s = s[:len(s)-1] + } + + return &Result{SQL: s, Args: args}, nil +} + +func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, nodes []Node) error { + for _, node := range nodes { + switch n := node.(type) { + case *TextNode: + sql.WriteString(n.Text) + + case *ParamNode: + val, ok := ctx.Get(n.Name) + if !ok { + if e.strict { + return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name) + } + val = nil + } + sql.WriteString(ph.Next()) + *args = append(*args, val) + + case *RawNode: + val, ok := ctx.Get(n.Name) + if !ok { + if e.strict { + return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name) + } + val = "" + } + strVal, ok := val.(string) + if !ok { + strVal = fmt.Sprint(val) + } + if e.rawPolicy != nil { + if err := e.rawPolicy.Validate(n.Name, strVal); err != nil { + return err + } + } + sql.WriteString(strVal) + + case *IfNode: + err := e.walkIf(ctx, ph, sql, args, n) + if err != nil { + return err + } + + case *ForNode: + err := e.walkFor(ctx, ph, sql, args, n) + if err != nil { + return err + } + + case *BlockNode, *NamespaceNode, *IncludeNode, *CommentNode: + // skip + } + } + return nil +} + +func (e *Executor) walkIf(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *IfNode) error { + condVal, err := Eval(n.Cond, ctx) + if err != nil { + return err + } + if isTruthy(condVal) { + return e.walk(ctx, ph, sql, args, n.Body) + } + for _, branch := range n.ElseIf { + condVal, err = Eval(branch.Cond, ctx) + if err != nil { + return err + } + if isTruthy(condVal) { + return e.walk(ctx, ph, sql, args, branch.Body) + } + } + if len(n.Else) > 0 { + return e.walk(ctx, ph, sql, args, n.Else) + } + return nil +} + +func (e *Executor) walkFor(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *ForNode) error { + listVal, err := Eval(n.List, ctx) + if err != nil { + return err + } + + if listVal == nil { + return nil + } + rv := reflect.ValueOf(listVal) + for rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + + var length int + switch rv.Kind() { + case reflect.Slice, reflect.Array: + length = rv.Len() + default: + return fmt.Errorf("line %d, col %d: @for requires a slice or array, got %T", n.Pos.Line, n.Pos.Col, listVal) + } + + for i := 0; i < length; i++ { + childVars := make(map[string]any, len(ctx.vars)+2) + maps.Copy(childVars, ctx.vars) + if n.KeyVar != "" { + childVars[n.KeyVar] = i + } + childVars[n.ValVar] = rv.Index(i).Interface() + + childCtx := NewContext(childVars) + err := e.walk(childCtx, ph, sql, args, n.Body) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/expr.go b/internal/expr.go new file mode 100644 index 0000000..e481268 --- /dev/null +++ b/internal/expr.go @@ -0,0 +1,564 @@ +package internal + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "unicode" +) + +type ExprParser struct { + input []rune + pos int + line int + col int +} + +func NewExprParser(input string, line, col int) *ExprParser { + return &ExprParser{ + input: []rune(input), + pos: 0, + line: line, + col: col, + } +} + +func (p *ExprParser) Parse() (*Expr, error) { + expr, err := p.parseOr() + if err != nil { + return nil, err + } + return expr, nil +} + +func (p *ExprParser) parseOr() (*Expr, error) { + left, err := p.parseAnd() + if err != nil { + return nil, err + } + for { + p.skipSpaces() + if !p.peekStr("||") { + break + } + p.skip(2) + p.skipSpaces() + right, err := p.parseAnd() + if err != nil { + return nil, err + } + left = &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprBinary, Left: left, Op: "||", Right: right} + } + return left, nil +} + +func (p *ExprParser) parseAnd() (*Expr, error) { + left, err := p.parseCompare() + if err != nil { + return nil, err + } + for { + p.skipSpaces() + if !p.peekStr("&&") { + break + } + p.skip(2) + p.skipSpaces() + right, err := p.parseCompare() + if err != nil { + return nil, err + } + left = &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprBinary, Left: left, Op: "&&", Right: right} + } + return left, nil +} + +func (p *ExprParser) parseCompare() (*Expr, error) { + left, err := p.parseUnary() + if err != nil { + return nil, err + } + p.skipSpaces() + op := "" + if p.peekStr("==") { + op = "==" + p.skip(2) + } else if p.peekStr("!=") { + op = "!=" + p.skip(2) + } else if p.peekStr("<=") { + op = "<=" + p.skip(2) + } else if p.peekStr(">=") { + op = ">=" + p.skip(2) + } else if p.peekStr("<") { + op = "<" + p.skip(1) + } else if p.peekStr(">") { + op = ">" + p.skip(1) + } + if op == "" { + return left, nil + } + p.skipSpaces() + right, err := p.parseUnary() + if err != nil { + return nil, err + } + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprBinary, Left: left, Op: op, Right: right}, nil +} + +func (p *ExprParser) parseUnary() (*Expr, error) { + p.skipSpaces() + if p.peekStr("!") && !p.peekStr("!=") { + p.skip(1) + operand, err := p.parseUnary() + if err != nil { + return nil, err + } + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprUnary, UnaryOp: "!", Operand: operand}, nil + } + return p.parsePrimary() +} + +func (p *ExprParser) parsePrimary() (*Expr, error) { + p.skipSpaces() + if p.pos >= len(p.input) { + return nil, fmt.Errorf("line %d, col %d: unexpected end of expression", p.line, p.col) + } + ch := p.input[p.pos] + + if ch == '"' { + return p.parseStringLit() + } + if ch == '\'' { + return p.parseSingleQuoteStringLit() + } + if ch >= '0' && ch <= '9' { + return p.parseNumberLit() + } + if isIdentStart(ch) { + return p.parseIdentOrKeyword() + } + if ch == '(' { + p.skip(1) + p.skipSpaces() + expr, err := p.parseOr() + if err != nil { + return nil, err + } + p.skipSpaces() + if p.pos >= len(p.input) || p.input[p.pos] != ')' { + return nil, fmt.Errorf("line %d, col %d: expected ')'", p.line, p.col) + } + p.skip(1) + return expr, nil + } + return nil, fmt.Errorf("line %d, col %d: unexpected character %q", p.line, p.col, string(ch)) +} + +func (p *ExprParser) parseIdentOrKeyword() (*Expr, error) { + start := p.pos + for p.pos < len(p.input) && isIdentPart(p.input[p.pos]) { + p.pos++ + } + name := string(p.input[start:p.pos]) + + switch name { + case "true": + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: true}, nil + case "false": + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: false}, nil + case "nil": + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprNil}, nil + } + + p.skipSpaces() + if p.pos < len(p.input) && p.input[p.pos] == '(' { + return p.parseFuncCall(name) + } + + p.skipSpaces() + varName := name + for p.pos < len(p.input) && p.input[p.pos] == '.' { + p.skip(1) + if p.pos >= len(p.input) || !isIdentStart(p.input[p.pos]) { + return nil, fmt.Errorf("line %d, col %d: expected identifier after '.'", p.line, p.col) + } + segStart := p.pos + for p.pos < len(p.input) && isIdentPart(p.input[p.pos]) { + p.pos++ + } + varName += "." + string(p.input[segStart:p.pos]) + p.skipSpaces() + } + + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprVariable, Name: varName}, nil +} + +func (p *ExprParser) parseFuncCall(name string) (*Expr, error) { + p.skip(1) + var args []*Expr + p.skipSpaces() + if p.pos < len(p.input) && p.input[p.pos] != ')' { + for { + arg, err := p.parseOr() + if err != nil { + return nil, err + } + args = append(args, arg) + p.skipSpaces() + if p.pos < len(p.input) && p.input[p.pos] == ',' { + p.skip(1) + p.skipSpaces() + continue + } + break + } + } + p.skipSpaces() + if p.pos >= len(p.input) || p.input[p.pos] != ')' { + return nil, fmt.Errorf("line %d, col %d: expected ')' after function call", p.line, p.col) + } + p.skip(1) + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprFuncCall, FuncName: name, FuncArgs: args}, nil +} + +func (p *ExprParser) parseStringLit() (*Expr, error) { + p.skip(1) + var buf strings.Builder + for p.pos < len(p.input) && p.input[p.pos] != '"' { + if p.input[p.pos] == '\\' && p.pos+1 < len(p.input) { + p.pos++ + switch p.input[p.pos] { + case 'n': + buf.WriteRune('\n') + case 't': + buf.WriteRune('\t') + case '\\': + buf.WriteRune('\\') + case '"': + buf.WriteRune('"') + default: + buf.WriteRune('\\') + buf.WriteRune(p.input[p.pos]) + } + } else { + buf.WriteRune(p.input[p.pos]) + } + p.pos++ + } + if p.pos >= len(p.input) { + return nil, fmt.Errorf("line %d, col %d: unterminated string", p.line, p.col) + } + p.skip(1) + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: buf.String()}, nil +} + +func (p *ExprParser) parseSingleQuoteStringLit() (*Expr, error) { + p.skip(1) + var buf strings.Builder + for p.pos < len(p.input) && p.input[p.pos] != '\'' { + buf.WriteRune(p.input[p.pos]) + p.pos++ + } + if p.pos >= len(p.input) { + return nil, fmt.Errorf("line %d, col %d: unterminated string", p.line, p.col) + } + p.skip(1) + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: buf.String()}, nil +} + +func (p *ExprParser) parseNumberLit() (*Expr, error) { + start := p.pos + for p.pos < len(p.input) && p.input[p.pos] >= '0' && p.input[p.pos] <= '9' { + p.pos++ + } + isFloat := false + if p.pos < len(p.input) && p.input[p.pos] == '.' { + next := p.pos + 1 + if next < len(p.input) && p.input[next] >= '0' && p.input[next] <= '9' { + isFloat = true + p.pos++ + for p.pos < len(p.input) && p.input[p.pos] >= '0' && p.input[p.pos] <= '9' { + p.pos++ + } + } + } + text := string(p.input[start:p.pos]) + if isFloat { + v, err := strconv.ParseFloat(text, 64) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: invalid number %q", p.line, p.col, text) + } + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: v}, nil + } + v, err := strconv.ParseInt(text, 10, 64) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: invalid number %q", p.line, p.col, text) + } + return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: int(v)}, nil +} + +func (p *ExprParser) peekStr(s string) bool { + if p.pos+len(s) > len(p.input) { + return false + } + for i, ch := range s { + if p.input[p.pos+i] != ch { + return false + } + } + return true +} + +func (p *ExprParser) skip(n int) { + p.pos += n + p.col += n +} + +func (p *ExprParser) skipSpaces() { + for p.pos < len(p.input) && unicode.IsSpace(p.input[p.pos]) { + if p.input[p.pos] == '\n' { + p.line++ + p.col = 0 + } else { + p.col++ + } + p.pos++ + } +} + +func isIdentStart(ch rune) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' +} + +func isIdentPart(ch rune) bool { + return isIdentStart(ch) || (ch >= '0' && ch <= '9') +} + +func Eval(expr *Expr, ctx *Context) (any, error) { + switch expr.ExprType { + case ExprLiteral: + return expr.Value, nil + case ExprNil: + return nil, nil + case ExprVariable: + val, ok := ctx.Get(expr.Name) + if !ok { + return nil, nil + } + return val, nil + case ExprUnary: + return evalUnary(expr, ctx) + case ExprBinary: + return evalBinary(expr, ctx) + case ExprFuncCall: + return evalFuncCall(expr, ctx) + default: + return nil, fmt.Errorf("line %d, col %d: unknown expression type", expr.Pos.Line, expr.Pos.Col) + } +} + +func evalUnary(expr *Expr, ctx *Context) (any, error) { + val, err := Eval(expr.Operand, ctx) + if err != nil { + return nil, err + } + switch expr.UnaryOp { + case "!": + return !isTruthy(val), nil + default: + return nil, fmt.Errorf("line %d, col %d: unknown unary operator %q", expr.Pos.Line, expr.Pos.Col, expr.UnaryOp) + } +} + +func evalBinary(expr *Expr, ctx *Context) (any, error) { + switch expr.Op { + case "&&": + left, err := Eval(expr.Left, ctx) + if err != nil { + return nil, err + } + if !isTruthy(left) { + return left, nil + } + return Eval(expr.Right, ctx) + case "||": + left, err := Eval(expr.Left, ctx) + if err != nil { + return nil, err + } + if isTruthy(left) { + return left, nil + } + return Eval(expr.Right, ctx) + } + + left, err := Eval(expr.Left, ctx) + if err != nil { + return nil, err + } + right, err := Eval(expr.Right, ctx) + if err != nil { + return nil, err + } + + switch expr.Op { + case "==": + return compareEqual(left, right), nil + case "!=": + return !compareEqual(left, right), nil + case "<": + return compareOrder(left, right, expr.Op) + case ">": + return compareOrder(left, right, expr.Op) + case "<=": + return compareOrder(left, right, expr.Op) + case ">=": + return compareOrder(left, right, expr.Op) + default: + return nil, fmt.Errorf("line %d, col %d: unknown operator %q", expr.Pos.Line, expr.Pos.Col, expr.Op) + } +} + +func evalFuncCall(expr *Expr, ctx *Context) (any, error) { + fn, ok := LookupBuiltin(expr.FuncName) + if !ok { + return nil, fmt.Errorf("line %d, col %d: unknown function %q", expr.Pos.Line, expr.Pos.Col, expr.FuncName) + } + var args []any + for _, a := range expr.FuncArgs { + val, err := Eval(a, ctx) + if err != nil { + return nil, err + } + args = append(args, val) + } + result, ok := fn(args) + if !ok { + return nil, fmt.Errorf("line %d, col %d: function %q call failed", expr.Pos.Line, expr.Pos.Col, expr.FuncName) + } + return result, nil +} + +func isTruthy(val any) bool { + if val == nil { + return false + } + switch v := val.(type) { + case bool: + return v + case int: + return v != 0 + case int64: + return v != 0 + case float64: + return v != 0 + case string: + return v != "" + case []any: + return len(v) > 0 + case map[string]any: + return len(v) > 0 + default: + rv := reflect.ValueOf(val) + switch rv.Kind() { + case reflect.Slice, reflect.Array, reflect.Map: + return rv.Len() > 0 + default: + return true + } + } +} + +func compareEqual(left, right any) bool { + if left == nil && right == nil { + return true + } + if left == nil || right == nil { + return false + } + + lbool, lbOk := left.(bool) + rbool, rbOk := right.(bool) + if lbOk || rbOk { + return lbOk && rbOk && lbool == rbool + } + + lf, lok := toFloat64(left) + rf, rok := toFloat64(right) + if lok && rok { + return lf == rf + } + + lstr, lsOk := left.(string) + rstr, rsOk := right.(string) + if lsOk && rsOk { + return lstr == rstr + } + + return fmt.Sprintf("%v", left) == fmt.Sprintf("%v", right) +} + +func compareOrder(left, right any, op string) (bool, error) { + if left == nil || right == nil { + return false, nil + } + + lf, lok := toFloat64(left) + rf, rok := toFloat64(right) + if lok && rok { + switch op { + case "<": + return lf < rf, nil + case ">": + return lf > rf, nil + case "<=": + return lf <= rf, nil + case ">=": + return lf >= rf, nil + } + } + + ls, lsOk := left.(string) + rs, rsOk := right.(string) + if lsOk && rsOk { + switch op { + case "<": + return ls < rs, nil + case ">": + return ls > rs, nil + case "<=": + return ls <= rs, nil + case ">=": + return ls >= rs, nil + } + } + + return false, fmt.Errorf("line 0, col 0: cannot compare %T and %T with %s", left, right, op) +} + +func toFloat64(val any) (float64, bool) { + switch v := val.(type) { + case int: + return float64(v), true + case int64: + return float64(v), true + case float64: + return v, true + case float32: + return float64(v), true + case uint: + return float64(v), true + case uint64: + return float64(v), true + case int32: + return float64(v), true + default: + return 0, false + } +} diff --git a/internal/include.go b/internal/include.go new file mode 100644 index 0000000..04e6b7e --- /dev/null +++ b/internal/include.go @@ -0,0 +1,73 @@ +package internal + +import ( + "fmt" + "strings" +) + +type IncludeResolver func(path string) (string, error) + +type IncludeManager struct { + resolver IncludeResolver + stack []string +} + +func NewIncludeManager(resolver IncludeResolver) *IncludeManager { + return &IncludeManager{resolver: resolver} +} + +func (m *IncludeManager) Resolve(path string) (string, error) { + return m.resolveInternal(path, m.stack) +} + +func (m *IncludeManager) resolveInternal(path string, stack []string) (string, error) { + for i, p := range stack { + if p == path { + return "", fmt.Errorf("circular include detected: %s is already in the include chain at depth %d", path, i) + } + } + + src, err := m.resolver(path) + if err != nil { + return "", fmt.Errorf("failed to resolve include %q: %w", path, err) + } + + newStack := make([]string, len(stack)+1) + copy(newStack, stack) + newStack[len(stack)] = path + + return m.expandIncludes(src, newStack) +} + +func (m *IncludeManager) expandIncludes(src string, stack []string) (string, error) { + result := src + offset := 0 + + for { + start := strings.Index(result[offset:], `@include("`) + if start < 0 { + break + } + + absStart := offset + start + pathStart := absStart + len(`@include("`) + + end := strings.Index(result[pathStart:], `")`) + if end < 0 { + break + } + + includePath := result[pathStart : pathStart+end] + absEnd := pathStart + end + len(`")`) + + resolved, err := m.resolveInternal(includePath, stack) + if err != nil { + return "", err + } + + result = result[:absStart] + resolved + result[absEnd:] + offset = absStart + len(resolved) + } + + return result, nil +} diff --git a/internal/lexer.go b/internal/lexer.go new file mode 100644 index 0000000..7026034 --- /dev/null +++ b/internal/lexer.go @@ -0,0 +1,211 @@ +package internal + +import ( + "strings" + "unicode" +) + +type TokenType int + +const ( + TokText TokenType = iota + TokParamStart + TokRawStart + TokIfStart + TokForStart + TokTplStart + TokIncludeStart + TokNamespaceStart + TokElse + TokComment + TokEOF +) + +type Token struct { + Type TokenType + Value string + Pos Pos +} + +type Lexer struct { + input []rune + pos int + line int + col int +} + +func NewLexer(input string) *Lexer { + return &Lexer{ + input: []rune(input), + line: 1, + col: 1, + } +} + +func (l *Lexer) Tokenize() ([]Token, error) { + var tokens []Token + + for l.pos < len(l.input) { + ch := l.input[l.pos] + + if ch == '#' { + if l.peek(1) == '{' { + tokens = append(tokens, Token{Type: TokParamStart, Value: "#{", Pos: l.curPos()}) + l.advance() + l.advance() + continue + } + tokens = append(tokens, l.readComment()) + continue + } + + if ch == '$' && l.peek(1) == '{' { + tokens = append(tokens, Token{Type: TokRawStart, Value: "${", Pos: l.curPos()}) + l.advance() + l.advance() + continue + } + + if ch == '@' { + if tok, ok := l.tryDirective(); ok { + tokens = append(tokens, tok) + continue + } + } + + if ch == '}' { + // Skip spaces after '}' to check for "else" + spaceOffset := 1 + for l.peek(spaceOffset) == ' ' || l.peek(spaceOffset) == '\t' { + spaceOffset++ + } + if l.peekWord(spaceOffset, "else") { + pos := l.curPos() + l.advance() // consume '}' + l.advanceN(spaceOffset - 1) // consume spaces + l.advanceN(4) // consume "else" + tokens = append(tokens, Token{Type: TokElse, Value: "} else", Pos: pos}) + continue + } + l.advance() + // Pos stores the end position (after '}'), consistent with other TokText tokens + tokens = append(tokens, Token{Type: TokText, Value: "}", Pos: l.curPos()}) + continue + } + + if ch == '\n' { + l.advance() + // Pos stores the end position (after '\n'), consistent with other TokText tokens + tokens = append(tokens, Token{Type: TokText, Value: "\n", Pos: l.curPos()}) + continue + } + + // Regular text: scan until special character + start := l.pos + for l.pos < len(l.input) { + c := l.input[l.pos] + if c == '#' || c == '$' || c == '@' || c == '}' || c == '\n' { + break + } + l.advance() + } + if l.pos > start { + tokens = append(tokens, Token{Type: TokText, Value: string(l.input[start:l.pos]), Pos: Pos{Line: l.line, Col: l.col}}) + } + } + + tokens = append(tokens, Token{Type: TokEOF, Pos: Pos{Line: l.line, Col: l.col}}) + return tokens, nil +} + +func (l *Lexer) curPos() Pos { + return Pos{Line: l.line, Col: l.col} +} + +func (l *Lexer) advance() { + if l.pos < len(l.input) { + if l.input[l.pos] == '\n' { + l.line++ + l.col = 1 + } else { + l.col++ + } + l.pos++ + } +} + +func (l *Lexer) advanceN(n int) { + for range n { + l.advance() + } +} + +func (l *Lexer) peek(offset int) rune { + idx := l.pos + offset + if idx < len(l.input) { + return l.input[idx] + } + return 0 +} + +func (l *Lexer) peekWord(offset int, word string) bool { + runes := []rune(word) + n := len(runes) + for i := range n { + if l.peek(offset+i) != runes[i] { + return false + } + } + after := l.peek(offset + n) + return after == 0 || after == ' ' || after == '\n' || after == '\t' || after == '{' || after == '}' +} + +func (l *Lexer) tryDirective() (Token, bool) { + type directive struct { + prefix []rune + ttype TokenType + skip int + } + + directives := []directive{ + {[]rune("@if("), TokIfStart, 4}, + {[]rune("@for("), TokForStart, 5}, + {[]rune("@tpl(\""), TokTplStart, 5}, + {[]rune("@include(\""), TokIncludeStart, 10}, + {[]rune("@namespace(\""), TokNamespaceStart, 12}, + } + + for _, d := range directives { + if l.matchRunes(d.prefix) { + pos := l.curPos() + val := string(d.prefix) + l.advanceN(d.skip) + return Token{Type: d.ttype, Value: val, Pos: pos}, true + } + } + + return Token{}, false +} + +func (l *Lexer) matchRunes(runes []rune) bool { + for i, r := range runes { + if l.peek(i) != r { + return false + } + } + return true +} + +func (l *Lexer) readComment() Token { + pos := l.curPos() + l.advance() // # + start := l.pos + for l.pos < len(l.input) && l.input[l.pos] != '\n' { + l.advance() + } + // consume trailing newline so the comment line disappears + if l.pos < len(l.input) && l.input[l.pos] == '\n' { + l.advance() + } + return Token{Type: TokComment, Value: strings.TrimRightFunc(string(l.input[start:l.pos]), unicode.IsSpace), Pos: pos} +} diff --git a/internal/node.go b/internal/node.go new file mode 100644 index 0000000..829077c --- /dev/null +++ b/internal/node.go @@ -0,0 +1,111 @@ +package internal + +type Node interface { + nodeType() string +} + +type Pos struct { + Line int + Col int +} + +type TextNode struct { + Pos Pos + Text string +} + +func (n *TextNode) nodeType() string { return "Text" } + +type ParamNode struct { + Pos Pos + Name string +} + +func (n *ParamNode) nodeType() string { return "Param" } + +type RawNode struct { + Pos Pos + Name string +} + +func (n *RawNode) nodeType() string { return "Raw" } + +type IfNode struct { + Pos Pos + Cond *Expr + Body []Node + Else []Node + ElseIf []*ElseIfBranch +} + +func (n *IfNode) nodeType() string { return "If" } + +type ElseIfBranch struct { + Pos Pos + Cond *Expr + Body []Node +} + +type ForNode struct { + Pos Pos + KeyVar string + ValVar string + List *Expr + Body []Node +} + +func (n *ForNode) nodeType() string { return "For" } + +type BlockNode struct { + Pos Pos + Name string + Body []Node +} + +func (n *BlockNode) nodeType() string { return "Block" } + +type IncludeNode struct { + Pos Pos + Path string +} + +func (n *IncludeNode) nodeType() string { return "Include" } + +type NamespaceNode struct { + Pos Pos + Name string +} + +func (n *NamespaceNode) nodeType() string { return "Namespace" } + +type CommentNode struct { + Pos Pos + Text string +} + +func (n *CommentNode) nodeType() string { return "Comment" } + +type ExprType int + +const ( + ExprLiteral ExprType = iota + ExprVariable + ExprBinary + ExprUnary + ExprFuncCall + ExprNil +) + +type Expr struct { + Pos Pos + ExprType ExprType + Name string + Value any + Left *Expr + Op string + Right *Expr + UnaryOp string + Operand *Expr + FuncName string + FuncArgs []*Expr +} diff --git a/internal/parser.go b/internal/parser.go new file mode 100644 index 0000000..9376cfc --- /dev/null +++ b/internal/parser.go @@ -0,0 +1,913 @@ +package internal + +import ( + "fmt" + "strings" +) + +type Parser struct { + input []rune + tokens []Token + pos int + includeMgr *IncludeManager +} + +func NewParser(input string, tokens []Token, includeMgr *IncludeManager) *Parser { + return &Parser{ + input: []rune(input), + tokens: tokens, + pos: 0, + includeMgr: includeMgr, + } +} + +func (p *Parser) Parse() ([]Node, error) { + var nodes []Node + var hasTpl bool + var tplNames map[string]bool + + for p.pos < len(p.tokens) { + tok := p.cur() + if tok.Type == TokEOF { + break + } + + switch tok.Type { + case TokText: + p.pos++ + if len(tok.Value) == 0 { + continue + } + if hasTpl && !isWhitespace(tok.Value) { + return nil, fmt.Errorf("line %d, col %d: top-level text is not allowed in templates with @tpl blocks", tok.Pos.Line, tok.Pos.Col) + } + nodes = append(nodes, &TextNode{Pos: tok.Pos, Text: tok.Value}) + + case TokParamStart: + node, err := p.parseParam(tok) + if err != nil { + return nil, err + } + nodes = append(nodes, node) + + case TokRawStart: + node, err := p.parseRaw(tok) + if err != nil { + return nil, err + } + nodes = append(nodes, node) + + case TokIfStart: + node, err := p.parseIf(tok) + if err != nil { + return nil, err + } + nodes = append(nodes, node) + + case TokForStart: + node, err := p.parseFor(tok) + if err != nil { + return nil, err + } + nodes = append(nodes, node) + + case TokTplStart: + hasTpl = true + if tplNames == nil { + tplNames = make(map[string]bool) + } + node, err := p.parseTpl(tok, tplNames) + if err != nil { + return nil, err + } + nodes = append(nodes, node) + + case TokIncludeStart: + subNodes, err := p.expandInclude(tok) + if err != nil { + return nil, err + } + nodes = append(nodes, subNodes...) + + case TokNamespaceStart: + node, err := p.parseNamespace(tok, len(nodes)) + if err != nil { + return nil, err + } + nodes = append(nodes, node) + + case TokElse: + return nil, fmt.Errorf("line %d, col %d: unexpected else", tok.Pos.Line, tok.Pos.Col) + + case TokComment: + p.pos++ + + default: + p.pos++ + } + } + return nodes, nil +} + +func (p *Parser) cur() Token { + if p.pos < len(p.tokens) { + return p.tokens[p.pos] + } + return Token{Type: TokEOF} +} + +// readUntilParen reads from runePos until ')' at paren depth 0. +// Only tracks '(' depth, not braces. +func (p *Parser) readUntilParen(startRunePos int) (string, int, error) { + i := startRunePos + depth := 0 + for i < len(p.input) { + ch := p.input[i] + if ch == '(' { + depth++ + i++ + continue + } + if ch == ')' { + if depth > 0 { + depth-- + i++ + continue + } + return string(p.input[startRunePos:i]), i, nil + } + if ch == '\'' || ch == '"' { + quote := ch + i++ + for i < len(p.input) && p.input[i] != quote { + if p.input[i] == '\\' && i+1 < len(p.input) { + i++ + } + i++ + } + if i < len(p.input) { + i++ + } + continue + } + i++ + } + return "", i, fmt.Errorf("unexpected end of input, expected ')'") +} + +// readUntilBrace reads from runePos until '}' at brace depth 0. +// Only tracks '{' depth, not parens. +func (p *Parser) readUntilBrace(startRunePos int) (string, int, error) { + i := startRunePos + depth := 0 + for i < len(p.input) { + ch := p.input[i] + if ch == '{' { + depth++ + i++ + continue + } + if ch == '}' { + if depth > 0 { + depth-- + i++ + continue + } + return string(p.input[startRunePos:i]), i, nil + } + if ch == '\'' || ch == '"' { + quote := ch + i++ + for i < len(p.input) && p.input[i] != quote { + if p.input[i] == '\\' && i+1 < len(p.input) { + i++ + } + i++ + } + if i < len(p.input) { + i++ + } + continue + } + i++ + } + return "", i, fmt.Errorf("unexpected end of input, expected '}'") +} + +// readUntilQuote reads from runePos until '"'. +func (p *Parser) readUntilQuote(startRunePos int) (string, int, error) { + i := startRunePos + for i < len(p.input) { + if p.input[i] == '"' { + return string(p.input[startRunePos:i]), i, nil + } + i++ + } + return "", i, fmt.Errorf("unexpected end of input, expected '\"'") +} + +func (p *Parser) runePosFromToken(tok Token) int { + line, col := tok.Pos.Line, tok.Pos.Col + prefixLen := len(tok.Value) + return runePosFromLineCol(p.input, line, col) + prefixLen +} + +func runePosFromLineCol(input []rune, line, col int) int { + if line <= 1 && col <= 1 { + return 0 + } + currentLine := 1 + currentCol := 1 + for i, ch := range input { + if currentLine == line && currentCol == col { + return i + } + if ch == '\n' { + currentLine++ + currentCol = 1 + } else { + currentCol++ + } + } + return len(input) +} + +func (p *Parser) consumeTokensForRuneRange(_, runeEnd int) { + for p.pos < len(p.tokens) { + tok := p.tokens[p.pos] + if tok.Type == TokEOF { + break + } + // Text tokens store Pos as the end position; other tokens store Pos as the start. + tokRuneStart := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col) + if tok.Type == TokText { + tokRuneStart -= len(tok.Value) + } + if tokRuneStart >= runeEnd { + break + } + // For text tokens, check if the token extends past runeEnd. + // If so, split it: keep the remainder as a new text token. + if tok.Type == TokText && len(tok.Value) > 0 { + tokRuneEnd := tokRuneStart + len(tok.Value) + if tokRuneEnd > runeEnd { + overlap := runeEnd - tokRuneStart + if overlap > 0 && overlap < len(tok.Value) { + remainder := tok.Value[overlap:] + // Replace current token with the remainder + p.tokens[p.pos] = Token{ + Type: TokText, + Value: remainder, + Pos: tok.Pos, + } + } + break + } + } + p.pos++ + } +} + +func (p *Parser) parseParam(tok Token) (*ParamNode, error) { + runePos := p.runePosFromToken(tok) + content, endPos, err := p.readUntilBrace(runePos) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated param, expected '}'", tok.Pos.Line, tok.Pos.Col) + } + content = strings.TrimSpace(content) + if content == "" { + return nil, fmt.Errorf("line %d, col %d: empty param name", tok.Pos.Line, tok.Pos.Col) + } + p.consumeTokensForRuneRange(runePos, endPos+1) + return &ParamNode{Pos: tok.Pos, Name: content}, nil +} + +func (p *Parser) parseRaw(tok Token) (*RawNode, error) { + runePos := p.runePosFromToken(tok) + content, endPos, err := p.readUntilBrace(runePos) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated raw, expected '}'", tok.Pos.Line, tok.Pos.Col) + } + content = strings.TrimSpace(content) + if content == "" { + return nil, fmt.Errorf("line %d, col %d: empty raw name", tok.Pos.Line, tok.Pos.Col) + } + p.consumeTokensForRuneRange(runePos, endPos+1) + return &RawNode{Pos: tok.Pos, Name: content}, nil +} + +func (p *Parser) parseIf(tok Token) (*IfNode, error) { + runePos := p.runePosFromToken(tok) + exprStr, parenClose, err := p.readUntilParen(runePos) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated @if, expected ')'", tok.Pos.Line, tok.Pos.Col) + } + + expr, err := NewExprParser(strings.TrimSpace(exprStr), tok.Pos.Line, tok.Pos.Col).Parse() + if err != nil { + return nil, err + } + + // Find '{' after ')' in raw input + braceOpen := findChar(p.input[parenClose+1:], '{') + if braceOpen < 0 { + return nil, fmt.Errorf("line %d, col %d: expected '{' after @if condition", tok.Pos.Line, tok.Pos.Col) + } + braceOpen = parenClose + 1 + braceOpen + + // consume tokens covering ) ... { + p.consumeTokensForRuneRange(runePos, braceOpen+1) + + body, elseBody, elseIfBranches, err := p.parseBlockBodyWithElse("if") + if err != nil { + return nil, err + } + + return &IfNode{ + Pos: tok.Pos, + Cond: expr, + Body: body, + Else: elseBody, + ElseIf: elseIfBranches, + }, nil +} + +func (p *Parser) parseElseIfBranch(tok Token) (*ElseIfBranch, error) { + prefix := "@elseif(" + runePos := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col) + idx := strings.Index(tok.Value, prefix) + if idx < 0 { + return nil, fmt.Errorf("line %d, col %d: expected @elseif(", tok.Pos.Line, tok.Pos.Col) + } + exprStart := runePos + idx + len(prefix) + exprStr, closePos, err := p.readUntilParen(exprStart) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated @elseif, expected ')'", tok.Pos.Line, tok.Pos.Col) + } + + expr, err := NewExprParser(strings.TrimSpace(exprStr), tok.Pos.Line, tok.Pos.Col).Parse() + if err != nil { + return nil, err + } + + p.consumeTokensForRuneRange(exprStart, closePos+1) + p.skipToBraceOpen() + + body, _, _, err := p.parseBlockBodyWithElse("elseif") + if err != nil { + return nil, err + } + + return &ElseIfBranch{ + Pos: tok.Pos, + Cond: expr, + Body: body, + }, nil +} + +func (p *Parser) parseFor(tok Token) (*ForNode, error) { + runePos := p.runePosFromToken(tok) + content, parenClose, err := p.readUntilParen(runePos) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated @for, expected ')'", tok.Pos.Line, tok.Pos.Col) + } + + braceOpen := findChar(p.input[parenClose+1:], '{') + if braceOpen < 0 { + return nil, fmt.Errorf("line %d, col %d: expected '{' after @for", tok.Pos.Line, tok.Pos.Col) + } + braceOpen = parenClose + 1 + braceOpen + p.consumeTokensForRuneRange(runePos, braceOpen+1) + + keyVar, valVar, listExprStr, err := parseForHeader(content) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: %s", tok.Pos.Line, tok.Pos.Col, err.Error()) + } + + listExpr, err := NewExprParser(strings.TrimSpace(listExprStr), tok.Pos.Line, tok.Pos.Col).Parse() + if err != nil { + return nil, err + } + + body, _, _, err := p.parseBlockBodyWithElse("for") + if err != nil { + return nil, err + } + + return &ForNode{ + Pos: tok.Pos, + KeyVar: keyVar, + ValVar: valVar, + List: listExpr, + Body: body, + }, nil +} + +func findChar(input []rune, ch rune) int { + for i, c := range input { + if c == ch { + return i + } + } + return -1 +} + +func parseForHeader(content string) (keyVar, valVar, listExpr string, err error) { + content = strings.TrimSpace(content) + + rangeIdx := strings.Index(content, " range ") + if rangeIdx < 0 { + return "", "", "", fmt.Errorf("expected 'range' keyword in @for") + } + + varsPart := strings.TrimRight(strings.TrimSpace(content[:rangeIdx]), ", ") + listExpr = strings.TrimSpace(content[rangeIdx+len(" range "):]) + + parts := strings.Split(varsPart, ",") + switch len(parts) { + case 1: + valVar = strings.TrimSpace(parts[0]) + case 2: + keyVar = strings.TrimSpace(parts[0]) + valVar = strings.TrimSpace(parts[1]) + default: + return "", "", "", fmt.Errorf("invalid @for variable declaration") + } + + if valVar == "" { + return "", "", "", fmt.Errorf("missing value variable in @for") + } + return keyVar, valVar, listExpr, nil +} + +func (p *Parser) parseTpl(tok Token, tplNames map[string]bool) (*BlockNode, error) { + runePos := p.runePosFromToken(tok) + name, quotePos, err := p.readUntilQuote(runePos) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated @tpl name", tok.Pos.Line, tok.Pos.Col) + } + name = strings.TrimSpace(name) + if name == "" { + return nil, fmt.Errorf("line %d, col %d: empty @tpl block name", tok.Pos.Line, tok.Pos.Col) + } + if tplNames[name] { + return nil, fmt.Errorf("line %d, col %d: duplicate @tpl block name %q", tok.Pos.Line, tok.Pos.Col, name) + } + tplNames[name] = true + + // skip ") {" + endPos := quotePos + 1 // past closing " + if endPos < len(p.input) && p.input[endPos] == ')' { + endPos++ + } + // find '{' and consume tokens up to past it + braceOpen := findChar(p.input[endPos:], '{') + if braceOpen < 0 { + return nil, fmt.Errorf("line %d, col %d: expected '{' after @tpl", tok.Pos.Line, tok.Pos.Col) + } + braceOpen = endPos + braceOpen + p.consumeTokensForRuneRange(runePos, braceOpen+1) + + body, _, _, err := p.parseBlockBodyWithElse("tpl") + if err != nil { + return nil, err + } + + return &BlockNode{ + Pos: tok.Pos, + Name: name, + Body: body, + }, nil +} + +func (p *Parser) parseInclude(tok Token) (*IncludeNode, error) { + runePos := p.runePosFromToken(tok) + path, quotePos, err := p.readUntilQuote(runePos) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated @include path", tok.Pos.Line, tok.Pos.Col) + } + path = strings.TrimSpace(path) + + endPos := quotePos + 1 + if endPos < len(p.input) && p.input[endPos] == ')' { + endPos++ + } + p.consumeTokensForRuneRange(runePos, endPos) + + return &IncludeNode{ + Pos: tok.Pos, + Path: path, + }, nil +} + +func (p *Parser) expandInclude(tok Token) ([]Node, error) { + incNode, err := p.parseInclude(tok) + if err != nil { + return nil, err + } + if p.includeMgr == nil { + return nil, fmt.Errorf("line %d, col %d: @include used but no include resolver configured", incNode.Pos.Line, incNode.Pos.Col) + } + expanded, err := p.includeMgr.Resolve(incNode.Path) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: %s", incNode.Pos.Line, incNode.Pos.Col, err.Error()) + } + subLexer := NewLexer(expanded) + subTokens, err := subLexer.Tokenize() + if err != nil { + return nil, err + } + subParser := NewParser(expanded, subTokens, p.includeMgr) + return subParser.Parse() +} + +func (p *Parser) parseNamespace(tok Token, nodeCount int) (*NamespaceNode, error) { + if nodeCount > 0 { + return nil, fmt.Errorf("line %d, col %d: @namespace must be at the top of the file", tok.Pos.Line, tok.Pos.Col) + } + + runePos := p.runePosFromToken(tok) + name, quotePos, err := p.readUntilQuote(runePos) + if err != nil { + return nil, fmt.Errorf("line %d, col %d: unterminated @namespace name", tok.Pos.Line, tok.Pos.Col) + } + name = strings.TrimSpace(name) + + endPos := quotePos + 1 + if endPos < len(p.input) && p.input[endPos] == ')' { + endPos++ + } + p.consumeTokensForRuneRange(runePos, endPos) + + return &NamespaceNode{ + Pos: tok.Pos, + Name: name, + }, nil +} + +// skipToBraceOpen finds '{' in raw input starting from current token position, +// then consumes all tokens up to and past '{'. The text after '{' is preserved +// as the current token so it becomes part of the block body. +func (p *Parser) skipToBraceOpen() { + // Find current rune position from current token + if p.pos >= len(p.tokens) { + return + } + tok := p.tokens[p.pos] + startRune := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col) + // Text tokens store Pos as the end position; adjust to start. + if tok.Type == TokText { + startRune -= len(tok.Value) + } + + // Find '{' in raw input + idx := findChar(p.input[startRune:], '{') + if idx < 0 { + return + } + braceRune := startRune + idx + + // Consume all tokens whose rune position is before '{' + p.consumeTokensForRuneRange(startRune, braceRune+1) + + // The text after '{' needs to be available as a token. + // If there are remaining characters after '{', they'll be in subsequent tokens. +} + +// splitTextAtBrace finds the first '{' in the current text token at p.pos, +// splits it so the part after '{' remains, and returns true. +// If no '{' is found, advances p.pos and returns false. +func (p *Parser) splitTextAtBrace() bool { + for p.pos < len(p.tokens) { + tok := p.tokens[p.pos] + if tok.Type != TokText { + return false + } + idx := strings.Index(tok.Value, "{") + if idx >= 0 { + if idx+1 < len(tok.Value) { + remainder := tok.Value[idx+1:] + p.tokens[p.pos] = Token{Type: TokText, Value: remainder, Pos: tok.Pos} + } else { + p.pos++ + } + return true + } + p.pos++ + } + return false +} + +func (p *Parser) skipWhitespaceText() { + for p.pos < len(p.tokens) { + tok := p.tokens[p.pos] + if tok.Type == TokEOF { + break + } + if tok.Type == TokText && isWhitespace(tok.Value) { + p.pos++ + continue + } + break + } +} + +// parseBlockBodyWithElse parses tokens inside a { } block. +// Returns: body nodes, else body (nil if no else), elseif branches, error. +// Handles nested @if/@for blocks by tracking brace depth. +func (p *Parser) parseBlockBodyWithElse(blockType string) ([]Node, []Node, []*ElseIfBranch, error) { + var body []Node + braceDepth := 1 + + for p.pos < len(p.tokens) { + tok := p.cur() + if tok.Type == TokEOF { + return nil, nil, nil, fmt.Errorf("line %d, col %d: unterminated %s block", tok.Pos.Line, tok.Pos.Col, blockType) + } + + if tok.Type == TokText { + text := tok.Value + // Scan for '}' at depth 1 + before, found := splitAtClosingBrace(text) + if found { + if before != "" { + body = append(body, &TextNode{Pos: tok.Pos, Text: before}) + } + p.pos++ // consume this token + // Check what follows: else, elseif, or nothing + return p.handleBlockEnd(body, blockType) + } + // No closing brace — count nested braces and add as text + for _, ch := range text { + if ch == '{' { + braceDepth++ + } else if ch == '}' { + braceDepth-- + } + } + if len(text) > 0 { + body = append(body, &TextNode{Pos: tok.Pos, Text: text}) + } + p.pos++ + continue + } + + if tok.Type == TokElse && braceDepth == 1 { + // Don't consume TokElse here — handleBlockEnd will consume it. + // This way handleBlockEnd can properly process the else branch. + return p.handleBlockEnd(body, blockType) + } + + p.pos++ + + switch tok.Type { + case TokComment: + // skip comments in body + case TokParamStart: + node, err := p.parseParam(tok) + if err != nil { + return nil, nil, nil, err + } + body = append(body, node) + case TokRawStart: + node, err := p.parseRaw(tok) + if err != nil { + return nil, nil, nil, err + } + body = append(body, node) + case TokIfStart: + node, err := p.parseIf(tok) + if err != nil { + return nil, nil, nil, err + } + body = append(body, node) + case TokForStart: + node, err := p.parseFor(tok) + if err != nil { + return nil, nil, nil, err + } + body = append(body, node) + case TokTplStart: + return nil, nil, nil, fmt.Errorf("line %d, col %d: @tpl blocks cannot be nested", tok.Pos.Line, tok.Pos.Col) + case TokIncludeStart: + subNodes, err := p.expandInclude(tok) + if err != nil { + return nil, nil, nil, err + } + body = append(body, subNodes...) + case TokNamespaceStart: + return nil, nil, nil, fmt.Errorf("line %d, col %d: @namespace must be at file top level", tok.Pos.Line, tok.Pos.Col) + } + } + + return nil, nil, nil, fmt.Errorf("unterminated %s block", blockType) +} + +// handleBlockEnd is called after the closing '}' of a block. +// Checks for else/elseif branches. +func (p *Parser) handleBlockEnd(body []Node, blockType string) ([]Node, []Node, []*ElseIfBranch, error) { + // For non-if blocks, no else support + if blockType != "if" && blockType != "elseif" { + return body, nil, nil, nil + } + + p.skipWhitespaceText() + tok := p.cur() + + // TokElse means lexer already matched "} else" + if tok.Type == TokElse { + return p.handleTokElse(tok, body) + } + + // Check for text token starting with "else" or "elseif(" + if tok.Type == TokText { + return p.handleTextElse(tok, body) + } + + return body, nil, nil, nil +} + +// handleTokElse processes a TokElse token (lexer already matched "} else"). +func (p *Parser) handleTokElse(tok Token, body []Node) ([]Node, []Node, []*ElseIfBranch, error) { + p.pos++ // consume TokElse + p.skipWhitespaceText() + next := p.cur() + // Check for "} else if(expr)" — text starting with "if " or "if(" + if next.Type == TokText { + trimmed := strings.TrimSpace(next.Value) + if strings.HasPrefix(trimmed, "if ") || strings.HasPrefix(trimmed, "if(") { + return p.parseElseIfFromText(body, next) + } + } + // Check for @elseif( syntax + if next.Type == TokText && strings.Contains(next.Value, "@elseif(") { + branch, err := p.parseElseIfBranch(next) + if err != nil { + return nil, nil, nil, err + } + return body, nil, []*ElseIfBranch{branch}, nil + } + // It's a plain else — find '{' in text tokens and split + if !p.splitTextAtBrace() { + return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '{' after else", tok.Pos.Line, tok.Pos.Col) + } + elseBody, _, _, err := p.parseBlockBodyWithElse("else") + if err != nil { + return nil, nil, nil, err + } + return body, elseBody, nil, nil +} + +// handleTextElse processes a TokText token that may contain "else" or "elseif(". +func (p *Parser) handleTextElse(tok Token, body []Node) ([]Node, []Node, []*ElseIfBranch, error) { + trimmed := strings.TrimSpace(tok.Value) + + // Handle "else if" in a single text token (e.g. "} else if (expr) {") + if strings.HasPrefix(trimmed, "else if ") || strings.HasPrefix(trimmed, "else if(") { + p.consumeElseKeyword(tok) + p.skipWhitespaceText() + next := p.cur() + if next.Type == TokText { + return p.parseElseIfFromText(body, next) + } + } + + // Plain else + if strings.HasPrefix(trimmed, "else") && !strings.HasPrefix(trimmed, "elseif(") && !strings.HasPrefix(trimmed, "else if ") && !strings.HasPrefix(trimmed, "else if(") { + p.consumeElseKeyword(tok) + if !p.splitTextAtBrace() { + p.skipToBraceOpen() + } + elseBody, _, _, err := p.parseBlockBodyWithElse("else") + if err != nil { + return nil, nil, nil, err + } + return body, elseBody, nil, nil + } + + // elseif( syntax + if strings.HasPrefix(trimmed, "elseif(") { + branch, err := p.parseElseIfBranch(tok) + if err != nil { + return nil, nil, nil, err + } + return body, nil, []*ElseIfBranch{branch}, nil + } + + return body, nil, nil, nil +} + +// consumeElseKeyword removes the "else" prefix from a text token, +// keeping any remaining text as the current token. +func (p *Parser) consumeElseKeyword(tok Token) { + idx := strings.Index(tok.Value, "else") + if idx >= 0 { + remainder := tok.Value[idx+4:] // after "else" + if len(remainder) > 0 { + p.tokens[p.pos] = Token{Type: TokText, Value: remainder, Pos: tok.Pos} + } else { + p.pos++ + } + } else { + p.pos++ + } +} + +// parseElseIfFromText handles "} else if (expr) { body }" pattern. +// next is a TokText whose trimmed value starts with "if(" or "if ". +func (p *Parser) parseElseIfFromText(body []Node, next Token) ([]Node, []Node, []*ElseIfBranch, error) { + trimmed := strings.TrimSpace(next.Value) + + // Find "if" and skip to the opening paren + ifIdx := strings.Index(trimmed, "if") + if ifIdx < 0 { + return nil, nil, nil, fmt.Errorf("line %d, col %d: expected 'if' in else-if condition", next.Pos.Line, next.Pos.Col) + } + afterIf := trimmed[ifIdx+2:] + // skip whitespace between "if" and "(" + parenLocalOffset := 0 + for parenLocalOffset < len(afterIf) && afterIf[parenLocalOffset] == ' ' { + parenLocalOffset++ + } + if parenLocalOffset >= len(afterIf) || afterIf[parenLocalOffset] != '(' { + return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '(' after 'else if'", next.Pos.Line, next.Pos.Col) + } + + // Compute rune position of '(' in the raw input. + // The text token's Pos marks the END of the token text; the start is Pos - len(Value). + runePos := runePosFromLineCol(p.input, next.Pos.Line, next.Pos.Col) - len(next.Value) + + // Walk through next.Value byte-by-byte to find the '(' that corresponds to the condition. + // We know trimmed starts at some offset into next.Value; find "if" in the raw value. + rawIfIdx := strings.Index(next.Value, "if") + rawParenOffset := rawIfIdx + 2 + for rawParenOffset < len(next.Value) && next.Value[rawParenOffset] == ' ' { + rawParenOffset++ + } + if rawParenOffset >= len(next.Value) || next.Value[rawParenOffset] != '(' { + return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '(' after 'else if'", next.Pos.Line, next.Pos.Col) + } + + // parenRunePos points to '(' in the raw input + parenRunePos := runePos + rawParenOffset + // readUntilParen expects the position AFTER '(' — it reads content between start and closing ')' + exprStr, parenClose, err := p.readUntilParen(parenRunePos + 1) + if err != nil { + return nil, nil, nil, fmt.Errorf("line %d, col %d: unterminated else-if condition, expected ')'", next.Pos.Line, next.Pos.Col) + } + + expr, err := NewExprParser(strings.TrimSpace(exprStr), next.Pos.Line, next.Pos.Col).Parse() + if err != nil { + return nil, nil, nil, err + } + + // Consume tokens from the start of this text token up to past ')' + p.consumeTokensForRuneRange(runePos, parenClose+1) + + // Find '{' after ')' in raw input + braceOpen := findChar(p.input[parenClose+1:], '{') + if braceOpen < 0 { + return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '{' after else-if condition", next.Pos.Line, next.Pos.Col) + } + braceOpen = parenClose + 1 + braceOpen + + // Consume tokens up to past '{' + p.consumeTokensForRuneRange(parenClose+1, braceOpen+1) + + // Parse the elseif body — it may itself have else/elseif + elseifBody, elseBody, moreBranches, err := p.parseBlockBodyWithElse("elseif") + if err != nil { + return nil, nil, nil, err + } + + branch := &ElseIfBranch{ + Pos: next.Pos, + Cond: expr, + Body: elseifBody, + } + + allBranches := []*ElseIfBranch{branch} + allBranches = append(allBranches, moreBranches...) + return body, elseBody, allBranches, nil +} + +// splitAtClosingBrace splits text at the first '}' at brace depth 0. +// Returns: text before '}', whether '}' was found. +func splitAtClosingBrace(text string) (string, bool) { + depth := 0 + for i, ch := range text { + if ch == '{' { + depth++ + } else if ch == '}' { + if depth == 0 { + return text[:i], true + } + depth-- + } + } + return text, false +} + +func isWhitespace(s string) bool { + return strings.TrimSpace(s) == "" +} diff --git a/internal/placeholder.go b/internal/placeholder.go new file mode 100644 index 0000000..f3b1b03 --- /dev/null +++ b/internal/placeholder.go @@ -0,0 +1,40 @@ +package internal + +import "fmt" + +type PlaceholderStyle int + +const ( + QuestionMark PlaceholderStyle = iota + DollarNumber + ColonNumber +) + +type Placeholder struct { + style PlaceholderStyle + count int +} + +func NewPlaceholder(style PlaceholderStyle) *Placeholder { + return &Placeholder{style: style} +} + +func (p *Placeholder) Next() string { + p.count++ + switch p.style { + case DollarNumber: + return fmt.Sprintf("$%d", p.count) + case ColonNumber: + return fmt.Sprintf(":%d", p.count) + default: + return "?" + } +} + +func (p *Placeholder) Reset() { + p.count = 0 +} + +func (p *Placeholder) Count() int { + return p.count +} diff --git a/safety.go b/safety.go new file mode 100644 index 0000000..37cc105 --- /dev/null +++ b/safety.go @@ -0,0 +1,39 @@ +package utpl + +import "slices" + +type RawPolicy interface { + Validate(param string, value string) error +} + +type RawAllowlist map[string][]string + +func (a RawAllowlist) Validate(param string, value string) error { + allowed, ok := a[param] + if !ok { + return &UnsafeRawError{Param: param, Value: value, Message: "no allowlist defined"} + } + if slices.Contains(allowed, value) { + return nil + } + return &UnsafeRawError{Param: param, Value: value, Message: "value not in allowlist"} +} + +type RawBlocklist map[string][]string + +func (b RawBlocklist) Validate(param string, value string) error { + blocked, ok := b[param] + if !ok { + return nil + } + if slices.Contains(blocked, value) { + return &UnsafeRawError{Param: param, Value: value, Message: "value is blocked"} + } + return nil +} + +type RawNoop struct{} + +func (RawNoop) Validate(string, string) error { + return nil +} diff --git a/template.go b/template.go new file mode 100644 index 0000000..d5a362b --- /dev/null +++ b/template.go @@ -0,0 +1,85 @@ +package utpl + +import ( + "errors" + "fmt" + + "gitea.1216.top/lxy/u-tpl/internal" +) + +type Result struct { + SQL string + Args []any +} + +type Template struct { + name string + engine *Engine + nodes []internal.Node + blocks map[string][]internal.Node + hasBlocks bool + namespace string +} + +func (t *Template) Execute(vars map[string]any) (*Result, error) { + if t.hasBlocks { + return nil, &ExecError{ + Pos: Position{}, + Message: fmt.Sprintf("template %q has named blocks, use ExecuteBlock instead", t.name), + } + } + return t.executeNodes(t.nodes, vars) +} + +func (t *Template) ExecuteBlock(blockName string, vars map[string]any) (*Result, error) { + nodes, ok := t.blocks[blockName] + if !ok { + if t.namespace != "" { + nodes, ok = t.blocks[t.namespace+"."+blockName] + } + } + if !ok { + return nil, &ExecError{ + Pos: Position{}, + Message: fmt.Sprintf("block %q not found in template %q", blockName, t.name), + } + } + return t.executeNodes(nodes, vars) +} + +func (t *Template) ExecuteString(vars map[string]any) (string, error) { + result, err := t.Execute(vars) + if err != nil { + return "", err + } + return result.SQL, nil +} + +func (t *Template) ExecuteBlockString(blockName string, vars map[string]any) (string, error) { + result, err := t.ExecuteBlock(blockName, vars) + if err != nil { + return "", err + } + return result.SQL, nil +} + +func (t *Template) executeNodes(nodes []internal.Node, vars map[string]any) (*Result, error) { + executor := internal.NewExecutor(t.engine.style, t.engine.rawPolicy, t.engine.strict) + result, err := executor.Execute(nodes, vars) + if err != nil { + return nil, wrapExecError(err) + } + return &Result{SQL: result.SQL, Args: result.Args}, nil +} + +func wrapExecError(err error) error { + var execErr *ExecError + if errors.As(err, &execErr) { + return err + } + var unsafeErr *UnsafeRawError + if errors.As(err, &unsafeErr) { + return err + } + return &ExecError{Message: err.Error()} +} diff --git a/utpl_ext_test.go b/utpl_ext_test.go new file mode 100644 index 0000000..55eda4e --- /dev/null +++ b/utpl_ext_test.go @@ -0,0 +1,944 @@ +package utpl + +import ( + "errors" + "fmt" + "strings" + "testing" +) + +// ---------- 1. TestInclude — @include directive ---------- +// +// NOTE: The current implementation creates IncludeNode during parsing but does +// not expand included content inline. The IncludeManager is constructed but never +// invoked. These tests document the actual behavior. When include expansion is +// fully wired up (content resolved before lexing), these tests should be updated +// to verify full inline expansion. + +func TestInclude(t *testing.T) { + t.Run("include node parses without error when resolver configured", func(t *testing.T) { + resolver := func(path string) (string, error) { + files := map[string]string{ + "common/tenant": "AND tenant_id = #{tenant_id}", + } + src, ok := files[path] + if !ok { + return "", fmt.Errorf("not found: %s", path) + } + return src, nil + } + src := `SELECT * FROM orders WHERE 1=1 @include("common/tenant")` + _, err := New(WithIncludeResolver(resolver)).Parse("test", src) + if err != nil { + t.Fatalf("parse should succeed, got: %v", err) + } + }) + + t.Run("include without resolver returns error", func(t *testing.T) { + src := `SELECT 1 @include("anything")` + _, err := New().Parse("test", src) + if err == nil { + t.Fatal("expected error when @include used without resolver") + } + }) + + t.Run("include content is expanded in output", func(t *testing.T) { + resolver := func(path string) (string, error) { + if path == "x" { + return "RESOLVED_CONTENT", nil + } + return "", fmt.Errorf("not found: %s", path) + } + src := `BEFORE @include("x") AFTER` + tpl, _ := New(WithIncludeResolver(resolver)).Parse("test", src) + r, err := tpl.Execute(nil) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !strings.Contains(r.SQL, "RESOLVED_CONTENT") { + t.Errorf("SQL = %q, should contain resolved content", r.SQL) + } + if !strings.Contains(r.SQL, "BEFORE") || !strings.Contains(r.SQL, "AFTER") { + t.Errorf("SQL = %q, should contain BEFORE and AFTER text", r.SQL) + } + }) + + t.Run("include with resolver error fails at parse", func(t *testing.T) { + resolver := func(path string) (string, error) { + return "", fmt.Errorf("not found: %s", path) + } + src := `SELECT 1 @include("nonexistent")` + _, err := New(WithIncludeResolver(resolver)).Parse("test", src) + if err == nil { + t.Fatal("expected error when include resolver fails") + } + }) +} + +// ---------- 2. TestElseIf — else if / else branches ---------- + +func TestElseIf(t *testing.T) { + t.Run("if with else branch", func(t *testing.T) { + src := "@if(x == 1) {ONE} else {OTHER}" + tests := []struct { + val any + want string + }{ + {1, "ONE"}, + {2, "OTHER"}, + {99, "OTHER"}, + } + for _, tc := range tests { + tpl, _ := New().Parse("test", src) + r, err := tpl.Execute(map[string]any{"x": tc.val}) + if err != nil { + t.Fatalf("x=%v: %v", tc.val, err) + } + if !strings.Contains(r.SQL, tc.want) { + t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want) + } + } + }) + + t.Run("if else if else chain", func(t *testing.T) { + src := `@if(role == "admin") {ADMIN} else if (role == "manager") {MGR} else {OTHER}` + tests := []struct { + role string + want string + }{ + {"admin", "ADMIN"}, + {"manager", "MGR"}, + {"guest", "OTHER"}, + {"other", "OTHER"}, + } + for _, tc := range tests { + tpl, _ := New().Parse("test", src) + r, err := tpl.Execute(map[string]any{"role": tc.role}) + if err != nil { + t.Fatalf("role=%q: %v", tc.role, err) + } + if !strings.Contains(r.SQL, tc.want) { + t.Errorf("role=%q: SQL = %q, want to contain %q", tc.role, r.SQL, tc.want) + } + } + }) + + t.Run("multiple else if branches", func(t *testing.T) { + src := `@if(x == 1) {ONE} else if (x == 2) {TWO} else if (x == 3) {THREE} else {OTHER}` + tests := []struct { + val any + want string + }{ + {1, "ONE"}, + {2, "TWO"}, + {3, "THREE"}, + {99, "OTHER"}, + } + for _, tc := range tests { + tpl, _ := New().Parse("test", src) + r, err := tpl.Execute(map[string]any{"x": tc.val}) + if err != nil { + t.Fatalf("x=%v: %v", tc.val, err) + } + if !strings.Contains(r.SQL, tc.want) { + t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want) + } + } + }) + + t.Run("else if without final else", func(t *testing.T) { + src := `@if(x == 1) {ONE} else if (x == 2) {TWO}` + tests := []struct { + val any + want string + found bool + }{ + {1, "ONE", true}, + {2, "TWO", true}, + {99, "", false}, + } + for _, tc := range tests { + tpl, _ := New().Parse("test", src) + r, err := tpl.Execute(map[string]any{"x": tc.val}) + if err != nil { + t.Fatalf("x=%v: %v", tc.val, err) + } + if tc.found && !strings.Contains(r.SQL, tc.want) { + t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want) + } + if !tc.found && r.SQL != "" { + t.Errorf("x=%v: SQL = %q, want empty", tc.val, r.SQL) + } + } + }) + + t.Run("else if with params in branches", func(t *testing.T) { + src := `@if(role == "admin") {level = #{admin_level}} else if (role == "manager") {level = #{mgr_level}} else {level = #{default_level}}` + tests := []struct { + role string + want string + argVal any + }{ + {"admin", "level = ?", 10}, + {"manager", "level = ?", 5}, + {"guest", "level = ?", 1}, + } + for _, tc := range tests { + tpl, parseErr := New().Parse("test", src) + if parseErr != nil { + t.Fatalf("role=%q: parse failed: %v", tc.role, parseErr) + } + r, err := tpl.Execute(map[string]any{ + "role": tc.role, + "admin_level": 10, + "mgr_level": 5, + "default_level": 1, + }) + if err != nil { + t.Fatalf("role=%q: %v", tc.role, err) + } + if !strings.Contains(r.SQL, tc.want) { + t.Errorf("role=%q: SQL = %q, want to contain %q", tc.role, r.SQL, tc.want) + } + if len(r.Args) != 1 || r.Args[0] != tc.argVal { + t.Errorf("role=%q: Args = %v, want [%v]", tc.role, r.Args, tc.argVal) + } + } + }) + + t.Run("else if with no space before paren", func(t *testing.T) { + src := "@if(x == 1) {ONE} else if(x == 2) {TWO} else {OTHER}" + tests := []struct { + val any + want string + }{ + {1, "ONE"}, + {2, "TWO"}, + {99, "OTHER"}, + } + for _, tc := range tests { + tpl, _ := New().Parse("test", src) + r, err := tpl.Execute(map[string]any{"x": tc.val}) + if err != nil { + t.Fatalf("x=%v: %v", tc.val, err) + } + if !strings.Contains(r.SQL, tc.want) { + t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want) + } + } + }) + + t.Run("else if with multiline SQL", func(t *testing.T) { + src := `SELECT * FROM users WHERE 1=1 +@if(role == "admin") { + AND level >= #{admin_level} +} else if (role == "manager") { + AND level >= #{mgr_level} +} else { + AND level >= 1 +}` + r := exec(t, src, map[string]any{ + "role": "manager", + "admin_level": 10, + "mgr_level": 5, + }) + if !strings.Contains(r.SQL, "AND level >= ?") { + t.Errorf("SQL = %q, want 'AND level >= ?'", r.SQL) + } + if len(r.Args) != 1 || r.Args[0] != 5 { + t.Errorf("Args = %v, want [5]", r.Args) + } + + r2 := exec(t, src, map[string]any{ + "role": "admin", + "admin_level": 10, + "mgr_level": 5, + }) + if !strings.Contains(r2.SQL, "AND level >= ?") { + t.Errorf("SQL = %q, want 'AND level >= ?'", r2.SQL) + } + if len(r2.Args) != 1 || r2.Args[0] != 10 { + t.Errorf("Args = %v, want [10]", r2.Args) + } + + r3 := exec(t, src, map[string]any{ + "role": "guest", + "admin_level": 10, + "mgr_level": 5, + }) + if !strings.Contains(r3.SQL, "AND level >= 1") { + t.Errorf("SQL = %q, want 'AND level >= 1'", r3.SQL) + } + if len(r3.Args) != 0 { + t.Errorf("Args = %v, want empty", r3.Args) + } + }) + + t.Run("if without else produces no output when false", func(t *testing.T) { + src := "@if(x == 1) {ONE}" + tpl, _ := New().Parse("test", src) + + r, _ := tpl.Execute(map[string]any{"x": 1}) + if !strings.Contains(r.SQL, "ONE") { + t.Errorf("x=1: SQL = %q, want ONE", r.SQL) + } + + r2, _ := tpl.Execute(map[string]any{"x": 99}) + if strings.Contains(r2.SQL, "ONE") { + t.Errorf("x=99: SQL = %q, should not contain ONE", r2.SQL) + } + }) +} + +// ---------- 3. TestExpression — expression edge cases ---------- + +func TestExpression(t *testing.T) { + t.Run("len function with []any non-empty", func(t *testing.T) { + r := exec(t, "@if(len(ids) > 0) {has ids}", map[string]any{"ids": []any{1, 2}}) + if !strings.Contains(r.SQL, "has ids") { + t.Errorf("SQL = %q, want 'has ids'", r.SQL) + } + }) + + t.Run("len function with []any empty", func(t *testing.T) { + r := exec(t, "@if(len(ids) > 0) {has ids}", map[string]any{"ids": []any{}}) + if strings.Contains(r.SQL, "has ids") { + t.Errorf("SQL = %q, should not contain 'has ids'", r.SQL) + } + }) + + t.Run("len function with string", func(t *testing.T) { + r := exec(t, "@if(len(name) > 0) {has name}", map[string]any{"name": "alice"}) + if !strings.Contains(r.SQL, "has name") { + t.Errorf("SQL = %q, want 'has name'", r.SQL) + } + + r2 := exec(t, "@if(len(name) > 0) {has name}", map[string]any{"name": ""}) + if strings.Contains(r2.SQL, "has name") { + t.Errorf("SQL = %q, should not contain 'has name'", r2.SQL) + } + }) + + t.Run("numeric comparison ge", func(t *testing.T) { + r := exec(t, "@if(age >= 18) {adult}", map[string]any{"age": 20}) + if !strings.Contains(r.SQL, "adult") { + t.Errorf("age=20: SQL = %q, want 'adult'", r.SQL) + } + + r2 := exec(t, "@if(age >= 18) {adult}", map[string]any{"age": 10}) + if strings.Contains(r2.SQL, "adult") { + t.Errorf("age=10: SQL = %q, should not contain 'adult'", r2.SQL) + } + }) + + t.Run("boolean variable true", func(t *testing.T) { + r := exec(t, "@if(flag) {yes}", map[string]any{"flag": true}) + if !strings.Contains(r.SQL, "yes") { + t.Errorf("flag=true: SQL = %q, want 'yes'", r.SQL) + } + }) + + t.Run("boolean variable false", func(t *testing.T) { + r := exec(t, "@if(flag) {yes}", map[string]any{"flag": false}) + if strings.Contains(r.SQL, "yes") { + t.Errorf("flag=false: SQL = %q, should not contain 'yes'", r.SQL) + } + }) + + t.Run("negation true", func(t *testing.T) { + r := exec(t, "@if(!flag) {no flag}", map[string]any{"flag": true}) + if strings.Contains(r.SQL, "no flag") { + t.Errorf("flag=true: SQL = %q, should not contain 'no flag'", r.SQL) + } + }) + + t.Run("negation false", func(t *testing.T) { + r := exec(t, "@if(!flag) {no flag}", map[string]any{"flag": false}) + if !strings.Contains(r.SQL, "no flag") { + t.Errorf("flag=false: SQL = %q, want 'no flag'", r.SQL) + } + }) + + t.Run("nested dot path in condition", func(t *testing.T) { + r := exec(t, "@if(user.active) {active}", map[string]any{ + "user": map[string]any{"active": true}, + }) + if !strings.Contains(r.SQL, "active") { + t.Errorf("SQL = %q, want 'active'", r.SQL) + } + + r2 := exec(t, "@if(user.active) {active}", map[string]any{ + "user": map[string]any{"active": false}, + }) + if strings.Contains(r2.SQL, "active") { + t.Errorf("SQL = %q, should not contain 'active'", r2.SQL) + } + }) + + t.Run("compare to nil true", func(t *testing.T) { + r := exec(t, "@if(val == nil) {is nil}", map[string]any{"val": nil}) + if !strings.Contains(r.SQL, "is nil") { + t.Errorf("SQL = %q, want 'is nil'", r.SQL) + } + }) + + t.Run("compare to nil false", func(t *testing.T) { + r := exec(t, "@if(val == nil) {is nil}", map[string]any{"val": 42}) + if strings.Contains(r.SQL, "is nil") { + t.Errorf("SQL = %q, should not contain 'is nil'", r.SQL) + } + }) + + t.Run("compare nil not equal true", func(t *testing.T) { + r := exec(t, "@if(val != nil) {not nil}", map[string]any{"val": 42}) + if !strings.Contains(r.SQL, "not nil") { + t.Errorf("SQL = %q, want 'not nil'", r.SQL) + } + }) + + t.Run("compare nil not equal false", func(t *testing.T) { + r := exec(t, "@if(val != nil) {not nil}", map[string]any{"val": nil}) + if strings.Contains(r.SQL, "not nil") { + t.Errorf("SQL = %q, should not contain 'not nil'", r.SQL) + } + }) + + t.Run("string equality comparison", func(t *testing.T) { + r := exec(t, `@if(role == "admin") {is admin}`, map[string]any{"role": "admin"}) + if !strings.Contains(r.SQL, "is admin") { + t.Errorf("SQL = %q, want 'is admin'", r.SQL) + } + + r2 := exec(t, `@if(role == "admin") {is admin}`, map[string]any{"role": "user"}) + if strings.Contains(r2.SQL, "is admin") { + t.Errorf("SQL = %q, should not contain 'is admin'", r2.SQL) + } + }) + + t.Run("numeric less than comparison", func(t *testing.T) { + r := exec(t, "@if(count < 100) {under limit}", map[string]any{"count": 50}) + if !strings.Contains(r.SQL, "under limit") { + t.Errorf("SQL = %q, want 'under limit'", r.SQL) + } + + r2 := exec(t, "@if(count < 100) {under limit}", map[string]any{"count": 150}) + if strings.Contains(r2.SQL, "under limit") { + t.Errorf("SQL = %q, should not contain 'under limit'", r2.SQL) + } + }) + + t.Run("and operator short circuit", func(t *testing.T) { + // When left is false, right should not cause error even if undefined + r := exec(t, "@if(a != nil && b != \"\") {both}", + map[string]any{"a": nil}) + if strings.Contains(r.SQL, "both") { + t.Errorf("SQL = %q, should not contain 'both' when a is nil", r.SQL) + } + }) +} + +// ---------- 4. TestErrorPaths — error handling ---------- + +func TestErrorPaths(t *testing.T) { + t.Run("unterminated param", func(t *testing.T) { + _, err := New().Parse("test", "SELECT #{id") + if err == nil { + t.Fatal("expected error for unterminated param, got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + }) + + t.Run("unterminated if", func(t *testing.T) { + _, err := New().Parse("test", "@if(x > 0") + if err == nil { + t.Fatal("expected error for unterminated @if, got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + }) + + t.Run("unterminated for", func(t *testing.T) { + _, err := New().Parse("test", "@for(x, range list)") + if err == nil { + t.Fatal("expected error for unterminated @for (missing {), got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + }) + + t.Run("empty tpl name", func(t *testing.T) { + _, err := New().Parse("test", `@tpl("") { SELECT 1 }`) + if err == nil { + t.Fatal("expected error for empty @tpl name, got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + if !strings.Contains(err.Error(), "empty") { + t.Errorf("error should mention empty: %v", err) + } + }) + + t.Run("duplicate tpl names", func(t *testing.T) { + _, err := New().Parse("test", `@tpl("a") {X} @tpl("a") {Y}`) + if err == nil { + t.Fatal("expected error for duplicate @tpl names, got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + if !strings.Contains(err.Error(), "duplicate") { + t.Errorf("error should mention duplicate: %v", err) + } + }) + + t.Run("namespace not at top", func(t *testing.T) { + _, err := New().Parse("test", `SELECT 1 @namespace("x")`) + if err == nil { + t.Fatal("expected error for @namespace not at top, got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + if !strings.Contains(err.Error(), "top") { + t.Errorf("error should mention top: %v", err) + } + }) + + t.Run("tpl nested inside tpl", func(t *testing.T) { + _, err := New().Parse("test", `@tpl("a") { @tpl("b") {X} }`) + if err == nil { + t.Fatal("expected error for nested @tpl, got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + if !strings.Contains(err.Error(), "nested") { + t.Errorf("error should mention nested: %v", err) + } + }) + + t.Run("undefined variable in strict mode", func(t *testing.T) { + tpl, _ := New(WithStrictMode(true)).Parse("test", "SELECT #{missing}") + _, err := tpl.Execute(map[string]any{}) + if err == nil { + t.Fatal("expected error for undefined variable in strict mode") + } + // The executor returns a plain error (not wrapped in ExecError) + if !strings.Contains(err.Error(), "undefined") { + t.Errorf("error should mention undefined: %v", err) + } + }) + + t.Run("for with non-slice value", func(t *testing.T) { + tpl, err := New().Parse("test", "SELECT @for(x, range items) {#{x}, }") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + _, err = tpl.Execute(map[string]any{"items": "not a slice"}) + if err == nil { + t.Fatal("expected error for non-slice in @for, got nil") + } + if !strings.Contains(err.Error(), "slice") { + t.Errorf("error should mention slice: %v", err) + } + }) + + t.Run("unterminated raw param", func(t *testing.T) { + _, err := New().Parse("test", "SELECT ${col") + if err == nil { + t.Fatal("expected error for unterminated raw param, got nil") + } + var parseErr *ParseError + if !errors.As(err, &parseErr) { + t.Errorf("expected ParseError, got %T: %v", err, err) + } + }) +} + +// ---------- 5. TestEdgeCases — boundary conditions ---------- + +func TestEdgeCases(t *testing.T) { + t.Run("empty template", func(t *testing.T) { + r := exec(t, "", nil) + if r.SQL != "" { + t.Errorf("SQL = %q, want empty", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("pure text", func(t *testing.T) { + r := exec(t, "SELECT 1", nil) + if r.SQL != "SELECT 1" { + t.Errorf("SQL = %q, want %q", r.SQL, "SELECT 1") + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("only comments", func(t *testing.T) { + r := exec(t, "# comment", nil) + if r.SQL != "" { + t.Errorf("SQL = %q, want empty", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("multiple consecutive comments", func(t *testing.T) { + r := exec(t, "# line1\n# line2\nSELECT 1", nil) + if r.SQL != "SELECT 1" { + t.Errorf("SQL = %q, want %q", r.SQL, "SELECT 1") + } + }) + + t.Run("mixed params and raw", func(t *testing.T) { + r := exec(t, "SELECT ${col} FROM t WHERE id = #{id}", + map[string]any{"col": "name", "id": 1}) + if !strings.Contains(r.SQL, "SELECT name FROM t WHERE id = ?") { + t.Errorf("SQL = %q, want 'SELECT name FROM t WHERE id = ?'", r.SQL) + } + if len(r.Args) != 1 || r.Args[0] != 1 { + t.Errorf("Args = %v, want [1]", r.Args) + } + }) + + t.Run("nested if both true", func(t *testing.T) { + r := exec(t, "@if(a) { @if(b) {both} }", + map[string]any{"a": true, "b": true}) + if !strings.Contains(r.SQL, "both") { + t.Errorf("SQL = %q, want 'both'", r.SQL) + } + }) + + t.Run("nested if inner false", func(t *testing.T) { + r := exec(t, "@if(a) { @if(b) {both} }", + map[string]any{"a": true, "b": false}) + if strings.Contains(r.SQL, "both") { + t.Errorf("SQL = %q, should not contain 'both'", r.SQL) + } + }) + + t.Run("nested if outer false", func(t *testing.T) { + r := exec(t, "@if(a) { @if(b) {both} }", + map[string]any{"a": false, "b": true}) + if strings.Contains(r.SQL, "both") { + t.Errorf("SQL = %q, should not contain 'both'", r.SQL) + } + }) + + t.Run("if with parenthesized expression", func(t *testing.T) { + src := "@if((a || b) && c) {yes}" + tpl, _ := New().Parse("test", src) + + r, _ := tpl.Execute(map[string]any{"a": true, "b": false, "c": true}) + if !strings.Contains(r.SQL, "yes") { + t.Errorf("a=true,b=false,c=true: SQL = %q, want 'yes'", r.SQL) + } + + r, _ = tpl.Execute(map[string]any{"a": false, "b": true, "c": true}) + if !strings.Contains(r.SQL, "yes") { + t.Errorf("a=false,b=true,c=true: SQL = %q, want 'yes'", r.SQL) + } + + r, _ = tpl.Execute(map[string]any{"a": false, "b": false, "c": true}) + if strings.Contains(r.SQL, "yes") { + t.Errorf("a=false,b=false,c=true: SQL = %q, should not contain 'yes'", r.SQL) + } + + r, _ = tpl.Execute(map[string]any{"a": true, "b": true, "c": false}) + if strings.Contains(r.SQL, "yes") { + t.Errorf("a=true,b=true,c=false: SQL = %q, should not contain 'yes'", r.SQL) + } + }) + + t.Run("for with index variable", func(t *testing.T) { + r := exec(t, "@for(i, v range items) {#{i}:#{v}, }", + map[string]any{"items": []any{"x", "y", "z"}}) + if len(r.Args) != 6 { + t.Fatalf("Args = %v, want 6 args (3 indices + 3 values)", r.Args) + } + // Indices: 0, 1, 2 + if r.Args[0] != 0 || r.Args[2] != 1 || r.Args[4] != 2 { + t.Errorf("index args wrong: %v", r.Args) + } + // Values: x, y, z + if r.Args[1] != "x" || r.Args[3] != "y" || r.Args[5] != "z" { + t.Errorf("value args wrong: %v", r.Args) + } + }) + + t.Run("comment does not interfere with param", func(t *testing.T) { + // #{param} starts with # but is not a comment + r := exec(t, "SELECT #{id}\n# real comment", + map[string]any{"id": 42}) + if !strings.Contains(r.SQL, "SELECT ?") { + t.Errorf("SQL = %q, want 'SELECT ?'", r.SQL) + } + if len(r.Args) != 1 || r.Args[0] != 42 { + t.Errorf("Args = %v, want [42]", r.Args) + } + }) + + t.Run("raw with numeric value", func(t *testing.T) { + r := exec(t, "SELECT ${n}", map[string]any{"n": 123}) + if !strings.Contains(r.SQL, "SELECT 123") { + t.Errorf("SQL = %q, want 'SELECT 123'", r.SQL) + } + }) + + t.Run("multiple consecutive params", func(t *testing.T) { + r := exec(t, "INSERT INTO t (a, b, c) VALUES (#{a}, #{b}, #{c})", + map[string]any{"a": 1, "b": 2, "c": 3}) + if !strings.Contains(r.SQL, "?, ?, ?") { + t.Errorf("SQL = %q, want three placeholders", r.SQL) + } + if len(r.Args) != 3 { + t.Errorf("Args = %v, want 3", r.Args) + } + }) +} + +// ---------- 6. TestMultiline — real SQL with line breaks ---------- + +func TestMultiline(t *testing.T) { + t.Run("full SELECT with multiple lines", func(t *testing.T) { + src := `SELECT u.id, u.name, u.email +FROM users u +WHERE u.status = #{status} + @if(name != "") { AND u.name LIKE #{name} } +ORDER BY u.id DESC +LIMIT #{limit}` + tpl, err := New().Parse("test", src) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + r, err := tpl.Execute(map[string]any{ + "status": "active", + "name": "%alice%", + "limit": 10, + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + + if !strings.Contains(r.SQL, "SELECT u.id, u.name, u.email") { + t.Errorf("SQL missing SELECT: %q", r.SQL) + } + if !strings.Contains(r.SQL, "FROM users u") { + t.Errorf("SQL missing FROM: %q", r.SQL) + } + if !strings.Contains(r.SQL, "WHERE u.status = ?") { + t.Errorf("SQL missing WHERE: %q", r.SQL) + } + if !strings.Contains(r.SQL, "AND u.name LIKE ?") { + t.Errorf("SQL missing name condition: %q", r.SQL) + } + if !strings.Contains(r.SQL, "ORDER BY u.id DESC") { + t.Errorf("SQL missing ORDER BY: %q", r.SQL) + } + if !strings.Contains(r.SQL, "LIMIT ?") { + t.Errorf("SQL missing LIMIT: %q", r.SQL) + } + if len(r.Args) != 3 || r.Args[0] != "active" || r.Args[1] != "%alice%" || r.Args[2] != 10 { + t.Errorf("Args = %v, want [active %%alice%% 10]", r.Args) + } + }) + + t.Run("line breaks preserved", func(t *testing.T) { + src := "SELECT id\nFROM users\nWHERE id = #{id}" + r := exec(t, src, map[string]any{"id": 1}) + lines := strings.Split(r.SQL, "\n") + if len(lines) < 3 { + t.Errorf("SQL = %q, want at least 3 lines", r.SQL) + } + if !strings.Contains(r.SQL, "SELECT id\nFROM users\nWHERE id = ?") { + t.Errorf("SQL = %q, line breaks not preserved", r.SQL) + } + }) + + t.Run("multiline dynamic update", func(t *testing.T) { + src := `UPDATE users SET + @if(name != nil) { name = #{name}, } + @if(email != nil) { email = #{email}, } + updated_at = NOW() +WHERE id = #{id}` + r := exec(t, src, map[string]any{"name": "alice", "email": nil, "id": 1}) + if !strings.Contains(r.SQL, "name = ?") { + t.Errorf("SQL missing name: %q", r.SQL) + } + if strings.Contains(r.SQL, "email = ?") { + t.Errorf("SQL should not contain email: %q", r.SQL) + } + if !strings.Contains(r.SQL, "WHERE id = ?") { + t.Errorf("SQL missing WHERE: %q", r.SQL) + } + if len(r.Args) != 2 { + t.Errorf("Args = %v, want 2 args", r.Args) + } + }) +} + +// ---------- 7. TestExecuteString — convenience methods ---------- + +func TestExecuteString(t *testing.T) { + t.Run("ExecuteString returns SQL only", func(t *testing.T) { + tpl, err := New().Parse("test", "SELECT * FROM users WHERE id = #{id}") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + sql, err := tpl.ExecuteString(map[string]any{"id": 42}) + if err != nil { + t.Fatalf("ExecuteString failed: %v", err) + } + if sql != "SELECT * FROM users WHERE id = ?" { + t.Errorf("SQL = %q, want %q", sql, "SELECT * FROM users WHERE id = ?") + } + }) + + t.Run("ExecuteBlockString returns SQL only", func(t *testing.T) { + tpl, err := New().Parse("test", "@tpl(\"search\") {\nSELECT * FROM orders WHERE id = #{id}\n}") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + sql, err := tpl.ExecuteBlockString("search", map[string]any{"id": 7}) + if err != nil { + t.Fatalf("ExecuteBlockString failed: %v", err) + } + if !strings.Contains(sql, "SELECT * FROM orders WHERE id = ?") { + t.Errorf("SQL = %q, want to contain 'SELECT * FROM orders WHERE id = ?'", sql) + } + }) + + t.Run("ExecuteString error propagation", func(t *testing.T) { + tpl, _ := New(WithStrictMode(true)).Parse("test", "SELECT #{missing}") + _, err := tpl.ExecuteString(map[string]any{}) + if err == nil { + t.Fatal("expected error for undefined variable, got nil") + } + }) + + t.Run("ExecuteBlockString error for missing block", func(t *testing.T) { + tpl, _ := New().Parse("test", `@tpl("search") {SELECT 1}`) + _, err := tpl.ExecuteBlockString("nonexistent", nil) + if err == nil { + t.Fatal("expected error for missing block, got nil") + } + }) +} + +// ---------- 8. TestBenchmarkConsistency ---------- + +func TestBenchmarkConsistency(t *testing.T) { + t.Run("simple benchmark scenario", func(t *testing.T) { + tpl := New().MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}") + result, err := tpl.Execute(map[string]any{"id": 42, "name": "alice"}) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if result.SQL != "SELECT * FROM users WHERE id = ? AND name = ?" { + t.Errorf("SQL = %q", result.SQL) + } + if len(result.Args) != 2 || result.Args[0] != 42 || result.Args[1] != "alice" { + t.Errorf("Args = %v, want [42 alice]", result.Args) + } + }) + + t.Run("conditional benchmark scenario", func(t *testing.T) { + src := `SELECT * FROM users WHERE 1=1 +@if(status != nil) { AND status = #{status} } +@if(name != "") { AND name = #{name} } +ORDER BY id` + tpl := New().MustParse("bench", src) + + // all conditions active + result, err := tpl.Execute(map[string]any{"status": "active", "name": "alice"}) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !strings.Contains(result.SQL, "AND status = ?") { + t.Errorf("SQL missing status: %q", result.SQL) + } + if !strings.Contains(result.SQL, "AND name = ?") { + t.Errorf("SQL missing name: %q", result.SQL) + } + if len(result.Args) != 2 { + t.Errorf("Args = %v, want 2 args", result.Args) + } + + // no conditions + result2, err := tpl.Execute(map[string]any{"status": nil, "name": ""}) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if strings.Contains(result2.SQL, "AND") { + t.Errorf("SQL should have no AND: %q", result2.SQL) + } + if len(result2.Args) != 0 { + t.Errorf("Args = %v, want empty", result2.Args) + } + }) + + t.Run("loop benchmark scenario", func(t *testing.T) { + src := `SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })` + tpl := New().MustParse("bench", src) + ids := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + result, err := tpl.Execute(map[string]any{"ids": ids}) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if len(result.Args) != 10 { + t.Errorf("Args = %v, want 10 args", result.Args) + } + for i, want := range ids { + if result.Args[i] != want { + t.Errorf("Args[%d] = %v, want %v", i, result.Args[i], want) + } + } + if strings.HasSuffix(result.SQL, ",") { + t.Errorf("SQL should not end with comma: %q", result.SQL) + } + }) + + t.Run("placeholder dollar benchmark scenario", func(t *testing.T) { + eng := New(WithPlaceholderStyle(DollarNumber)) + tpl := eng.MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}") + result, err := tpl.Execute(map[string]any{"id": 42, "name": "alice"}) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !strings.Contains(result.SQL, "WHERE id = $1 AND name = $2") { + t.Errorf("SQL = %q, want $1/$2 placeholders", result.SQL) + } + if len(result.Args) != 2 { + t.Errorf("Args = %v, want 2 args", result.Args) + } + }) + + t.Run("placeholder colon benchmark scenario", func(t *testing.T) { + eng := New(WithPlaceholderStyle(ColonNumber)) + tpl := eng.MustParse("bench", "WHERE id = #{id} AND name = #{name}") + result, err := tpl.Execute(map[string]any{"id": 1, "name": "a"}) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !strings.Contains(result.SQL, "WHERE id = :1 AND name = :2") { + t.Errorf("SQL = %q, want :1/:2 placeholders", result.SQL) + } + }) +} \ No newline at end of file diff --git a/utpl_test.go b/utpl_test.go new file mode 100644 index 0000000..2b3b6d3 --- /dev/null +++ b/utpl_test.go @@ -0,0 +1,660 @@ +package utpl + +import ( + "errors" + "strings" + "testing" +) + +// helper to parse with default engine +func parse(t *testing.T, source string) *Template { + t.Helper() + tpl, err := New().Parse("test", source) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + return tpl +} + +// helper to parse with options and execute +func exec(t *testing.T, source string, vars map[string]any, opts ...Option) *Result { + t.Helper() + tpl, err := New(opts...).Parse("test", source) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + result, err := tpl.Execute(vars) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + return result +} + +// ---------- 1. TestParam ---------- + +func TestParam(t *testing.T) { + t.Run("single param", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE id = #{id}", map[string]any{"id": 42}) + if r.SQL != "SELECT * FROM users WHERE id = ?" { + t.Errorf("SQL = %q, want %q", r.SQL, "SELECT * FROM users WHERE id = ?") + } + if len(r.Args) != 1 || r.Args[0] != 42 { + t.Errorf("Args = %v, want [42]", r.Args) + } + }) + + t.Run("multiple params", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE id = #{id} AND name = #{name}", + map[string]any{"id": 1, "name": "alice"}) + if r.SQL != "SELECT * FROM users WHERE id = ? AND name = ?" { + t.Errorf("SQL = %q", r.SQL) + } + if len(r.Args) != 2 || r.Args[0] != 1 || r.Args[1] != "alice" { + t.Errorf("Args = %v, want [1 alice]", r.Args) + } + }) + + t.Run("nil value param", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE name = #{name}", map[string]any{"name": nil}) + if r.SQL != "SELECT * FROM users WHERE name = ?" { + t.Errorf("SQL = %q", r.SQL) + } + if len(r.Args) != 1 || r.Args[0] != nil { + t.Errorf("Args = %v, want [nil]", r.Args) + } + }) + + t.Run("param with nested dot path", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE name = #{user.name}", + map[string]any{"user": map[string]any{"name": "bob"}}) + if r.SQL != "SELECT * FROM users WHERE name = ?" { + t.Errorf("SQL = %q", r.SQL) + } + if len(r.Args) != 1 || r.Args[0] != "bob" { + t.Errorf("Args = %v, want [bob]", r.Args) + } + }) +} + +// ---------- 2. TestRaw ---------- + +func TestRaw(t *testing.T) { + t.Run("basic raw substitution", func(t *testing.T) { + r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"}) + if r.SQL != "SELECT name FROM users" { + t.Errorf("SQL = %q, want %q", r.SQL, "SELECT name FROM users") + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("multiple raw params", func(t *testing.T) { + r := exec(t, "SELECT ${a}, ${b} FROM ${table}", + map[string]any{"a": "id", "b": "name", "table": "users"}) + if r.SQL != "SELECT id, name FROM users" { + t.Errorf("SQL = %q", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("allowlist pass", func(t *testing.T) { + policy := RawAllowlist{ + "col": {"name", "age"}, + } + r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"}, + WithRawPolicy(policy)) + if r.SQL != "SELECT name FROM users" { + t.Errorf("SQL = %q", r.SQL) + } + }) + + t.Run("allowlist reject", func(t *testing.T) { + policy := RawAllowlist{ + "col": {"name", "age"}, + } + tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col} FROM users") + _, err := tpl.Execute(map[string]any{"col": "password"}) + if err == nil { + t.Fatal("expected error for blocked raw value, got nil") + } + var unsafeErr *UnsafeRawError + if !errors.As(err, &unsafeErr) { + t.Errorf("expected UnsafeRawError, got %T: %v", err, err) + } + }) + + t.Run("blocklist allow", func(t *testing.T) { + policy := RawBlocklist{ + "col": {"password"}, + } + r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"}, + WithRawPolicy(policy)) + if r.SQL != "SELECT name FROM users" { + t.Errorf("SQL = %q", r.SQL) + } + }) + + t.Run("blocklist reject", func(t *testing.T) { + policy := RawBlocklist{ + "col": {"password", "secret"}, + } + tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col} FROM users") + _, err := tpl.Execute(map[string]any{"col": "password"}) + if err == nil { + t.Fatal("expected error for blocklisted value, got nil") + } + var unsafeErr *UnsafeRawError + if !errors.As(err, &unsafeErr) { + t.Errorf("expected UnsafeRawError, got %T: %v", err, err) + } + }) +} + +// ---------- 3. TestComment ---------- + +func TestComment(t *testing.T) { + t.Run("comment stripped", func(t *testing.T) { + r := exec(t, "# this is a comment\nSELECT 1", nil) + if r.SQL != "SELECT 1" { + t.Errorf("SQL = %q, want %q", r.SQL, "SELECT 1") + } + }) + + t.Run("comment in middle", func(t *testing.T) { + r := exec(t, "SELECT 1\n# comment\nSELECT 2", nil) + want := "SELECT 1\nSELECT 2" + if r.SQL != want { + t.Errorf("SQL = %q, want %q", r.SQL, want) + } + }) + + t.Run("no false positive with param", func(t *testing.T) { + r := exec(t, "SELECT #{id} FROM users", map[string]any{"id": 1}) + if r.SQL != "SELECT ? FROM users" { + t.Errorf("SQL = %q, want %q", r.SQL, "SELECT ? FROM users") + } + if len(r.Args) != 1 || r.Args[0] != 1 { + t.Errorf("Args = %v", r.Args) + } + }) +} + +// ---------- 4. TestIf ---------- + +func TestIf(t *testing.T) { + t.Run("true condition", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }", + map[string]any{"status": "active"}) + if !strings.Contains(r.SQL, "AND status = ?") { + t.Errorf("SQL = %q, want to contain 'AND status = ?'", r.SQL) + } + if len(r.Args) != 1 || r.Args[0] != "active" { + t.Errorf("Args = %v, want [active]", r.Args) + } + }) + + t.Run("false condition nil", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }", + map[string]any{"status": nil}) + if strings.Contains(r.SQL, "AND status") { + t.Errorf("SQL = %q, should not contain 'AND status'", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("else branch", func(t *testing.T) { + r := exec(t, "SELECT * FROM users ORDER BY id @if(desc) {DESC} else {ASC}", + map[string]any{"desc": true}) + if !strings.Contains(r.SQL, "DESC") { + t.Errorf("SQL = %q, want DESC", r.SQL) + } + if strings.Contains(r.SQL, "ASC") { + t.Errorf("SQL = %q, should not contain ASC", r.SQL) + } + + r2 := exec(t, "SELECT * FROM users ORDER BY id @if(desc) {DESC} else {ASC}", + map[string]any{"desc": false}) + if !strings.Contains(r2.SQL, "ASC") { + t.Errorf("SQL = %q, want ASC", r2.SQL) + } + }) + + t.Run("empty string check", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE 1=1 @if(name != \"\") { AND name = #{name} }", + map[string]any{"name": "alice"}) + if !strings.Contains(r.SQL, "AND name = ?") { + t.Errorf("SQL = %q, want 'AND name = ?'", r.SQL) + } + + r2 := exec(t, "SELECT * FROM users WHERE 1=1 @if(name != \"\") { AND name = #{name} }", + map[string]any{"name": ""}) + if strings.Contains(r2.SQL, "AND name") { + t.Errorf("SQL = %q, should not contain 'AND name'", r2.SQL) + } + }) + + t.Run("multiple independent conditions", func(t *testing.T) { + r := exec(t, "@if(a != nil) {A} @if(b != nil) {B}", + map[string]any{"a": 1, "b": 2}) + if !strings.Contains(r.SQL, "A") || !strings.Contains(r.SQL, "B") { + t.Errorf("SQL = %q, want both A and B", r.SQL) + } + + r2 := exec(t, "@if(a != nil) {A} @if(b != nil) {B}", + map[string]any{"a": 1}) + if !strings.Contains(r2.SQL, "A") { + t.Errorf("SQL = %q, want A", r2.SQL) + } + if strings.Contains(r2.SQL, "B") { + t.Errorf("SQL = %q, should not contain B", r2.SQL) + } + }) + + t.Run("AND operator", func(t *testing.T) { + r := exec(t, "@if(a != nil && b != \"\") { both }", + map[string]any{"a": 1, "b": "x"}) + if !strings.Contains(r.SQL, "both") { + t.Errorf("SQL = %q, want 'both'", r.SQL) + } + + r2 := exec(t, "@if(a != nil && b != \"\") { both }", + map[string]any{"a": 1, "b": ""}) + if strings.Contains(r2.SQL, "both") { + t.Errorf("SQL = %q, should not contain 'both'", r2.SQL) + } + }) + + t.Run("OR operator", func(t *testing.T) { + r := exec(t, "@if(a != nil || b != nil) { either }", + map[string]any{"a": nil, "b": 1}) + if !strings.Contains(r.SQL, "either") { + t.Errorf("SQL = %q, want 'either'", r.SQL) + } + + r2 := exec(t, "@if(a != nil || b != nil) { either }", + map[string]any{"a": nil, "b": nil}) + if strings.Contains(r2.SQL, "either") { + t.Errorf("SQL = %q, should not contain 'either'", r2.SQL) + } + }) +} + +// ---------- 5. TestFor ---------- + +func TestFor(t *testing.T) { + t.Run("basic IN clause", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })", + map[string]any{"ids": []int{1, 2, 3}}) + if len(r.Args) != 3 { + t.Fatalf("Args = %v, want 3 args", r.Args) + } + for i, want := range []int{1, 2, 3} { + if r.Args[i] != want { + t.Errorf("Args[%d] = %v, want %v", i, r.Args[i], want) + } + } + }) + + t.Run("trailing comma auto-trim", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })", + map[string]any{"ids": []int{1, 2}}) + // trailing comma and space should be trimmed by the executor + if strings.HasSuffix(r.SQL, ",") { + t.Errorf("SQL = %q, should not end with comma", r.SQL) + } + }) + + t.Run("empty list", func(t *testing.T) { + r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })", + map[string]any{"ids": []int{}}) + if strings.Contains(r.SQL, "?") { + t.Errorf("SQL = %q, should have no placeholders", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("nil list", func(t *testing.T) { + // nil list produces no output (treated as empty slice) + tpl := parse(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })") + r, err := tpl.Execute(map[string]any{"ids": nil}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if strings.Contains(r.SQL, "?") { + t.Errorf("SQL = %q, should have no placeholders", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("with index variable", func(t *testing.T) { + r := exec(t, "SELECT @for(i, v range items) {#{v}, }", + map[string]any{"items": []string{"a", "b", "c"}}) + if len(r.Args) != 3 { + t.Errorf("Args = %v, want 3 args", r.Args) + } + }) + + t.Run("nested object access", func(t *testing.T) { + r := exec(t, "SELECT @for(u range users) {#{u.name}, }", + map[string]any{"users": []map[string]any{ + {"name": "alice"}, + {"name": "bob"}, + }}) + if len(r.Args) != 2 { + t.Fatalf("Args = %v, want 2 args", r.Args) + } + if r.Args[0] != "alice" || r.Args[1] != "bob" { + t.Errorf("Args = %v, want [alice bob]", r.Args) + } + }) +} + +// ---------- 6. TestTpl ---------- + +func TestTpl(t *testing.T) { + t.Run("single block", func(t *testing.T) { + tpl := parse(t, `@tpl("search") { SELECT * FROM users }`) + r, err := tpl.ExecuteBlock("search", nil) + if err != nil { + t.Fatalf("ExecuteBlock failed: %v", err) + } + if !strings.Contains(r.SQL, "SELECT * FROM users") { + t.Errorf("SQL = %q", r.SQL) + } + }) + + t.Run("multiple blocks", func(t *testing.T) { + tpl := parse(t, `@tpl("search") { SELECT * FROM users } @tpl("count") { SELECT COUNT(*) FROM users }`) + r1, err := tpl.ExecuteBlock("search", nil) + if err != nil { + t.Fatalf("ExecuteBlock search failed: %v", err) + } + if !strings.Contains(r1.SQL, "SELECT * FROM users") { + t.Errorf("search SQL = %q", r1.SQL) + } + + r2, err := tpl.ExecuteBlock("count", nil) + if err != nil { + t.Fatalf("ExecuteBlock count failed: %v", err) + } + if !strings.Contains(r2.SQL, "SELECT COUNT(*) FROM users") { + t.Errorf("count SQL = %q", r2.SQL) + } + }) + + t.Run("execute on tpl template should error", func(t *testing.T) { + tpl := parse(t, `@tpl("search") { SELECT * FROM users }`) + _, err := tpl.Execute(nil) + if err == nil { + t.Fatal("expected error when Execute is called on template with blocks") + } + var execErr *ExecError + if !errors.As(err, &execErr) { + t.Errorf("expected ExecError, got %T: %v", err, err) + } + }) + + t.Run("ExecuteBlock with wrong name", func(t *testing.T) { + tpl := parse(t, `@tpl("search") { SELECT * FROM users }`) + _, err := tpl.ExecuteBlock("nonexistent", nil) + if err == nil { + t.Fatal("expected error for nonexistent block name") + } + var execErr *ExecError + if !errors.As(err, &execErr) { + t.Errorf("expected ExecError, got %T: %v", err, err) + } + }) +} + +// ---------- 7. TestNamespace ---------- + +func TestNamespace(t *testing.T) { + t.Run("namespaced block access", func(t *testing.T) { + tpl := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`) + r, err := tpl.ExecuteBlock("orders.search", nil) + if err != nil { + t.Fatalf("ExecuteBlock orders.search failed: %v", err) + } + if !strings.Contains(r.SQL, "SELECT * FROM orders") { + t.Errorf("SQL = %q", r.SQL) + } + }) + + t.Run("short name fallback", func(t *testing.T) { + tpl := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`) + r, err := tpl.ExecuteBlock("search", nil) + if err != nil { + t.Fatalf("ExecuteBlock search failed: %v", err) + } + if !strings.Contains(r.SQL, "SELECT * FROM orders") { + t.Errorf("SQL = %q", r.SQL) + } + }) + + t.Run("same block names different namespaces", func(t *testing.T) { + tpl1 := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`) + tpl2 := parse(t, `@namespace("users") @tpl("search") { SELECT * FROM users }`) + + r1, err := tpl1.ExecuteBlock("orders.search", nil) + if err != nil { + t.Fatalf("orders.search failed: %v", err) + } + if !strings.Contains(r1.SQL, "orders") { + t.Errorf("SQL = %q, want orders table", r1.SQL) + } + + r2, err := tpl2.ExecuteBlock("users.search", nil) + if err != nil { + t.Fatalf("users.search failed: %v", err) + } + if !strings.Contains(r2.SQL, "users") { + t.Errorf("SQL = %q, want users table", r2.SQL) + } + }) +} + +// ---------- 8. TestPlaceholderStyles ---------- + +func TestPlaceholderStyles(t *testing.T) { + t.Run("QuestionMark default", func(t *testing.T) { + r := exec(t, "WHERE id = #{id}", map[string]any{"id": 1}) + if !strings.Contains(r.SQL, "WHERE id = ?") { + t.Errorf("SQL = %q, want 'WHERE id = ?'", r.SQL) + } + }) + + t.Run("DollarNumber", func(t *testing.T) { + r := exec(t, "WHERE id = #{id} AND name = #{name}", + map[string]any{"id": 1, "name": "a"}, + WithPlaceholderStyle(DollarNumber)) + if !strings.Contains(r.SQL, "WHERE id = $1 AND name = $2") { + t.Errorf("SQL = %q, want 'WHERE id = $1 AND name = $2'", r.SQL) + } + }) + + t.Run("ColonNumber", func(t *testing.T) { + r := exec(t, "WHERE id = #{id} AND name = #{name}", + map[string]any{"id": 1, "name": "a"}, + WithPlaceholderStyle(ColonNumber)) + if !strings.Contains(r.SQL, "WHERE id = :1 AND name = :2") { + t.Errorf("SQL = %q, want 'WHERE id = :1 AND name = :2'", r.SQL) + } + }) +} + +// ---------- 9. TestStrictMode ---------- + +func TestStrictMode(t *testing.T) { + t.Run("strict true undefined param", func(t *testing.T) { + tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT #{id}") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + _, err = tpl.Execute(map[string]any{}) + if err == nil { + t.Fatal("expected error for undefined variable in strict mode") + } + }) + + t.Run("strict false undefined param", func(t *testing.T) { + r := exec(t, "SELECT #{id} FROM users", map[string]any{}, + WithStrictMode(false)) + if !strings.Contains(r.SQL, "SELECT ? FROM users") { + t.Errorf("SQL = %q, want 'SELECT ? FROM users'", r.SQL) + } + if len(r.Args) != 1 || r.Args[0] != nil { + t.Errorf("Args = %v, want [nil]", r.Args) + } + }) + + t.Run("strict false undefined raw", func(t *testing.T) { + r := exec(t, "SELECT ${col} FROM users", map[string]any{}, + WithStrictMode(false)) + if r.SQL != "SELECT FROM users" { + t.Errorf("SQL = %q, want 'SELECT FROM users'", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) +} + +// ---------- 10. TestComplex ---------- + +func TestComplex(t *testing.T) { + t.Run("conditional WHERE with 1=1", func(t *testing.T) { + src := `SELECT * FROM users WHERE 1=1 +@if(status != nil) { AND status = #{status} } +@if(name != "") { AND name = #{name} }` + + r := exec(t, src, map[string]any{"status": "active", "name": "alice"}) + if !strings.Contains(r.SQL, "AND status = ?") { + t.Errorf("SQL missing status condition: %q", r.SQL) + } + if !strings.Contains(r.SQL, "AND name = ?") { + t.Errorf("SQL missing name condition: %q", r.SQL) + } + if len(r.Args) != 2 { + t.Errorf("Args = %v, want 2 args", r.Args) + } + + r2 := exec(t, src, map[string]any{"status": nil, "name": ""}) + if strings.Contains(r2.SQL, "AND") { + t.Errorf("SQL should have no AND: %q", r2.SQL) + } + if len(r2.Args) != 0 { + t.Errorf("Args = %v, want empty", r2.Args) + } + }) + + t.Run("dynamic UPDATE with trailing commas", func(t *testing.T) { + src := `UPDATE users SET +@if(name != nil) { name = #{name}, } +@if(email != nil) { email = #{email}, } +WHERE id = #{id}` + + r := exec(t, src, map[string]any{"name": "alice", "email": "a@b.com", "id": 1}) + if !strings.Contains(r.SQL, "name = ?") { + t.Errorf("SQL missing name: %q", r.SQL) + } + if !strings.Contains(r.SQL, "email = ?") { + t.Errorf("SQL missing email: %q", r.SQL) + } + if strings.Contains(r.SQL, ",\nWHERE") { + t.Errorf("trailing comma not trimmed: %q", r.SQL) + } + if len(r.Args) != 3 { + t.Errorf("Args = %v, want 3 args", r.Args) + } + }) + + t.Run("batch INSERT with @for", func(t *testing.T) { + src := `INSERT INTO users (name, age) VALUES +@for(u range users) { (#{u.name}, #{u.age}), }` + + r := exec(t, src, map[string]any{"users": []map[string]any{ + {"name": "alice", "age": 30}, + {"name": "bob", "age": 25}, + }}) + if len(r.Args) != 4 { + t.Fatalf("Args = %v, want 4 args", r.Args) + } + if r.Args[0] != "alice" || r.Args[1] != 30 || r.Args[2] != "bob" || r.Args[3] != 25 { + t.Errorf("Args = %v, want [alice 30 bob 25]", r.Args) + } + }) + + t.Run("ORDER BY with @if desc/asc", func(t *testing.T) { + src := `SELECT * FROM users ORDER BY id @if(asc) {ASC} else {DESC}` + r := exec(t, src, map[string]any{"asc": true}) + if !strings.Contains(r.SQL, "ASC") { + t.Errorf("SQL = %q, want ASC", r.SQL) + } + if strings.Contains(r.SQL, "DESC") { + t.Errorf("SQL = %q, should not contain DESC", r.SQL) + } + + r2 := exec(t, src, map[string]any{"asc": false}) + if !strings.Contains(r2.SQL, "DESC") { + t.Errorf("SQL = %q, want DESC", r2.SQL) + } + }) +} + +// ---------- Benchmarks ---------- + +func BenchmarkSimple(b *testing.B) { + tpl := New().MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}") + vars := map[string]any{"id": 42, "name": "alice"} + b.ResetTimer() + for b.Loop() { + _, _ = tpl.Execute(vars) + } +} + +func BenchmarkConditional(b *testing.B) { + src := `SELECT * FROM users WHERE 1=1 +@if(status != nil) { AND status = #{status} } +@if(name != "") { AND name = #{name} } +ORDER BY id` + tpl := New().MustParse("bench", src) + vars := map[string]any{"status": "active", "name": "alice"} + b.ResetTimer() + for b.Loop() { + _, _ = tpl.Execute(vars) + } +} + +func BenchmarkLoop(b *testing.B) { + src := `SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })` + tpl := New().MustParse("bench", src) + ids := make([]int, 10) + for i := range ids { + ids[i] = i + } + vars := map[string]any{"ids": ids} + b.ResetTimer() + for b.Loop() { + _, _ = tpl.Execute(vars) + } +} + +func BenchmarkPlaceholderDollar(b *testing.B) { + eng := New(WithPlaceholderStyle(DollarNumber)) + tpl := eng.MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}") + vars := map[string]any{"id": 42, "name": "alice"} + b.ResetTimer() + for b.Loop() { + _, _ = tpl.Execute(vars) + } +} \ No newline at end of file