新增: u-tpl SQL 模板引擎完整实现
- Lexer/Parser/Executor 三阶段架构
- #{param} 参数化 + ${raw} 原样替换 + 白名单安全策略
- @if/@for/@tpl/@include/@namespace 控制流
- 表达式引擎: 比较、逻辑、nil 检查、len() 内置函数
- 支持 ?/$1/:1 多数据库占位符风格
- 零依赖,纯 Go 标准库实现
This commit is contained in:
5
.gitattributes
vendored
Normal file
5
.gitattributes
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
* text=auto eol=lf
|
||||||
|
*.go text eol=lf
|
||||||
|
*.md text eol=lf
|
||||||
|
*.sql text eol=lf
|
||||||
|
*.tpl text eol=lf
|
||||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1 +1,7 @@
|
|||||||
.claude/
|
.claude/
|
||||||
|
.idea/
|
||||||
|
*.exe
|
||||||
|
*.test
|
||||||
|
*.out
|
||||||
|
coverage.out
|
||||||
|
vendor/
|
||||||
|
|||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
19
Makefile
Normal file
19
Makefile
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
.PHONY: build test bench lint vet clean
|
||||||
|
|
||||||
|
build:
|
||||||
|
go build ./...
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./... -v
|
||||||
|
|
||||||
|
bench:
|
||||||
|
go test ./... -bench=. -benchmem
|
||||||
|
|
||||||
|
lint: vet
|
||||||
|
gofmt -l .
|
||||||
|
|
||||||
|
vet:
|
||||||
|
go vet ./...
|
||||||
|
|
||||||
|
clean:
|
||||||
|
go clean -testcache
|
||||||
@@ -27,7 +27,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/nicedoc/utpl"
|
"gitea.1216.top/lxy/u-tpl"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -52,7 +52,7 @@ func main() {
|
|||||||
## 安装
|
## 安装
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go get github.com/nicedoc/utpl
|
go get gitea.1216.top/lxy/u-tpl
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -673,7 +673,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"github.com/nicedoc/utpl"
|
"gitea.1216.top/lxy/u-tpl"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed tpl/order_search.tpl
|
//go:embed tpl/order_search.tpl
|
||||||
|
|||||||
121
engine.go
Normal file
121
engine.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package utpl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.1216.top/lxy/u-tpl/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PlaceholderStyle = internal.PlaceholderStyle
|
||||||
|
|
||||||
|
const (
|
||||||
|
QuestionMark PlaceholderStyle = internal.QuestionMark
|
||||||
|
DollarNumber = internal.DollarNumber
|
||||||
|
ColonNumber = internal.ColonNumber
|
||||||
|
)
|
||||||
|
|
||||||
|
type IncludeResolver func(path string) (string, error)
|
||||||
|
|
||||||
|
type Option func(*Engine)
|
||||||
|
|
||||||
|
type Engine struct {
|
||||||
|
style internal.PlaceholderStyle
|
||||||
|
rawPolicy RawPolicy
|
||||||
|
includeResolver IncludeResolver
|
||||||
|
strict bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(opts ...Option) *Engine {
|
||||||
|
e := &Engine{
|
||||||
|
style: internal.QuestionMark,
|
||||||
|
strict: true,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(e)
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPlaceholderStyle(style PlaceholderStyle) Option {
|
||||||
|
return func(e *Engine) { e.style = style }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRawPolicy(policy RawPolicy) Option {
|
||||||
|
return func(e *Engine) { e.rawPolicy = policy }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithIncludeResolver(resolver IncludeResolver) Option {
|
||||||
|
return func(e *Engine) { e.includeResolver = resolver }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStrictMode(strict bool) Option {
|
||||||
|
return func(e *Engine) { e.strict = strict }
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) Parse(name string, source string) (*Template, error) {
|
||||||
|
lexer := internal.NewLexer(source)
|
||||||
|
tokens, err := lexer.Tokenize()
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapParseError(err, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var includeMgr *internal.IncludeManager
|
||||||
|
if e.includeResolver != nil {
|
||||||
|
includeMgr = internal.NewIncludeManager(internal.IncludeResolver(e.includeResolver))
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := internal.NewParser(source, tokens, includeMgr)
|
||||||
|
nodes, err := parser.Parse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapParseError(err, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace := ""
|
||||||
|
blocks := make(map[string][]internal.Node)
|
||||||
|
var bodyNodes []internal.Node
|
||||||
|
|
||||||
|
hasBlocks := false
|
||||||
|
for _, n := range nodes {
|
||||||
|
if ns, ok := n.(*internal.NamespaceNode); ok {
|
||||||
|
namespace = ns.Name
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if blk, ok := n.(*internal.BlockNode); ok {
|
||||||
|
hasBlocks = true
|
||||||
|
fullName := blk.Name
|
||||||
|
if namespace != "" {
|
||||||
|
fullName = namespace + "." + blk.Name
|
||||||
|
}
|
||||||
|
blocks[fullName] = blk.Body
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bodyNodes = append(bodyNodes, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Template{
|
||||||
|
name: name,
|
||||||
|
engine: e,
|
||||||
|
nodes: bodyNodes,
|
||||||
|
blocks: blocks,
|
||||||
|
hasBlocks: hasBlocks,
|
||||||
|
namespace: namespace,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) MustParse(name string, source string) *Template {
|
||||||
|
tpl, err := e.Parse(name, source)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return tpl
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapParseError(err error, name string) error {
|
||||||
|
if _, ok := err.(*ParseError); ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return &ParseError{
|
||||||
|
Pos: Position{Line: 0, Column: 0},
|
||||||
|
Message: fmt.Sprintf("template %q: %s", name, err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
42
error.go
Normal file
42
error.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package utpl
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type Position struct {
|
||||||
|
Line int
|
||||||
|
Column int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Position) String() string {
|
||||||
|
return fmt.Sprintf("line %d, column %d", p.Line, p.Column)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ParseError struct {
|
||||||
|
Message string
|
||||||
|
Pos Position
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ParseError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %s (token: %q)", e.Pos, e.Message, e.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExecError struct {
|
||||||
|
Pos Position
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ExecError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %s", e.Pos, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnsafeRawError struct {
|
||||||
|
Message string
|
||||||
|
Pos Position
|
||||||
|
Param string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e UnsafeRawError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %s (param: %q, value: %q)", e.Pos, e.Message, e.Param, e.Value)
|
||||||
|
}
|
||||||
30
internal/builtin.go
Normal file
30
internal/builtin.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
type BuiltinFunc func(args []any) (any, bool)
|
||||||
|
|
||||||
|
var builtins = map[string]BuiltinFunc{
|
||||||
|
"len": builtinLen,
|
||||||
|
}
|
||||||
|
|
||||||
|
func builtinLen(args []any) (any, bool) {
|
||||||
|
if len(args) != 1 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if args[0] == nil {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(args[0])
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array, reflect.String, reflect.Map, reflect.Chan:
|
||||||
|
return v.Len(), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LookupBuiltin(name string) (BuiltinFunc, bool) {
|
||||||
|
fn, ok := builtins[name]
|
||||||
|
return fn, ok
|
||||||
|
}
|
||||||
68
internal/context.go
Normal file
68
internal/context.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Context struct {
|
||||||
|
vars map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContext(vars map[string]any) *Context {
|
||||||
|
return &Context{vars: vars}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) Get(path string) (any, bool) {
|
||||||
|
if c.vars == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return resolvePath(c.vars, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePath(current any, path string) (any, bool) {
|
||||||
|
for seg := range strings.SplitSeq(path, ".") {
|
||||||
|
if current == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
switch v := current.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
var ok bool
|
||||||
|
current, ok = v[seg]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
val := reflect.ValueOf(current)
|
||||||
|
for val.Kind() == reflect.Pointer {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
if val.Kind() != reflect.Struct {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
field := val.FieldByName(seg)
|
||||||
|
if !field.IsValid() {
|
||||||
|
field = findFieldIgnoreCase(val, seg)
|
||||||
|
}
|
||||||
|
if !field.IsValid() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
current = field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return current, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func findFieldIgnoreCase(v reflect.Value, name string) reflect.Value {
|
||||||
|
typ := v.Type()
|
||||||
|
for i := range typ.NumField() {
|
||||||
|
f := typ.Field(i)
|
||||||
|
if !f.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(f.Name, name) {
|
||||||
|
return v.Field(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reflect.Value{}
|
||||||
|
}
|
||||||
167
internal/executor.go
Normal file
167
internal/executor.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rawValidator interface {
|
||||||
|
Validate(param string, value string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Executor struct {
|
||||||
|
style PlaceholderStyle
|
||||||
|
rawPolicy rawValidator
|
||||||
|
strict bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
SQL string
|
||||||
|
Args []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewExecutor(style PlaceholderStyle, rawPolicy rawValidator, strict bool) *Executor {
|
||||||
|
return &Executor{
|
||||||
|
style: style,
|
||||||
|
rawPolicy: rawPolicy,
|
||||||
|
strict: strict,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Executor) Execute(nodes []Node, vars map[string]any) (*Result, error) {
|
||||||
|
ctx := NewContext(vars)
|
||||||
|
ph := NewPlaceholder(e.style)
|
||||||
|
var sql strings.Builder
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
err := e.walk(ctx, ph, &sql, &args, nodes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s := strings.TrimRight(sql.String(), " \t\n\r")
|
||||||
|
if len(s) > 0 && s[len(s)-1] == ',' {
|
||||||
|
s = s[:len(s)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Result{SQL: s, Args: args}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, nodes []Node) error {
|
||||||
|
for _, node := range nodes {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *TextNode:
|
||||||
|
sql.WriteString(n.Text)
|
||||||
|
|
||||||
|
case *ParamNode:
|
||||||
|
val, ok := ctx.Get(n.Name)
|
||||||
|
if !ok {
|
||||||
|
if e.strict {
|
||||||
|
return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name)
|
||||||
|
}
|
||||||
|
val = nil
|
||||||
|
}
|
||||||
|
sql.WriteString(ph.Next())
|
||||||
|
*args = append(*args, val)
|
||||||
|
|
||||||
|
case *RawNode:
|
||||||
|
val, ok := ctx.Get(n.Name)
|
||||||
|
if !ok {
|
||||||
|
if e.strict {
|
||||||
|
return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name)
|
||||||
|
}
|
||||||
|
val = ""
|
||||||
|
}
|
||||||
|
strVal, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
strVal = fmt.Sprint(val)
|
||||||
|
}
|
||||||
|
if e.rawPolicy != nil {
|
||||||
|
if err := e.rawPolicy.Validate(n.Name, strVal); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql.WriteString(strVal)
|
||||||
|
|
||||||
|
case *IfNode:
|
||||||
|
err := e.walkIf(ctx, ph, sql, args, n)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case *ForNode:
|
||||||
|
err := e.walkFor(ctx, ph, sql, args, n)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case *BlockNode, *NamespaceNode, *IncludeNode, *CommentNode:
|
||||||
|
// skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Executor) walkIf(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *IfNode) error {
|
||||||
|
condVal, err := Eval(n.Cond, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if isTruthy(condVal) {
|
||||||
|
return e.walk(ctx, ph, sql, args, n.Body)
|
||||||
|
}
|
||||||
|
for _, branch := range n.ElseIf {
|
||||||
|
condVal, err = Eval(branch.Cond, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if isTruthy(condVal) {
|
||||||
|
return e.walk(ctx, ph, sql, args, branch.Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(n.Else) > 0 {
|
||||||
|
return e.walk(ctx, ph, sql, args, n.Else)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Executor) walkFor(ctx *Context, ph *Placeholder, sql *strings.Builder, args *[]any, n *ForNode) error {
|
||||||
|
listVal, err := Eval(n.List, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if listVal == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rv := reflect.ValueOf(listVal)
|
||||||
|
for rv.Kind() == reflect.Pointer {
|
||||||
|
rv = rv.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
var length int
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
length = rv.Len()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("line %d, col %d: @for requires a slice or array, got %T", n.Pos.Line, n.Pos.Col, listVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
childVars := make(map[string]any, len(ctx.vars)+2)
|
||||||
|
maps.Copy(childVars, ctx.vars)
|
||||||
|
if n.KeyVar != "" {
|
||||||
|
childVars[n.KeyVar] = i
|
||||||
|
}
|
||||||
|
childVars[n.ValVar] = rv.Index(i).Interface()
|
||||||
|
|
||||||
|
childCtx := NewContext(childVars)
|
||||||
|
err := e.walk(childCtx, ph, sql, args, n.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
564
internal/expr.go
Normal file
564
internal/expr.go
Normal file
@@ -0,0 +1,564 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ExprParser struct {
|
||||||
|
input []rune
|
||||||
|
pos int
|
||||||
|
line int
|
||||||
|
col int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewExprParser(input string, line, col int) *ExprParser {
|
||||||
|
return &ExprParser{
|
||||||
|
input: []rune(input),
|
||||||
|
pos: 0,
|
||||||
|
line: line,
|
||||||
|
col: col,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) Parse() (*Expr, error) {
|
||||||
|
expr, err := p.parseOr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return expr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseOr() (*Expr, error) {
|
||||||
|
left, err := p.parseAnd()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
p.skipSpaces()
|
||||||
|
if !p.peekStr("||") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
p.skip(2)
|
||||||
|
p.skipSpaces()
|
||||||
|
right, err := p.parseAnd()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
left = &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprBinary, Left: left, Op: "||", Right: right}
|
||||||
|
}
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseAnd() (*Expr, error) {
|
||||||
|
left, err := p.parseCompare()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
p.skipSpaces()
|
||||||
|
if !p.peekStr("&&") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
p.skip(2)
|
||||||
|
p.skipSpaces()
|
||||||
|
right, err := p.parseCompare()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
left = &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprBinary, Left: left, Op: "&&", Right: right}
|
||||||
|
}
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseCompare() (*Expr, error) {
|
||||||
|
left, err := p.parseUnary()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p.skipSpaces()
|
||||||
|
op := ""
|
||||||
|
if p.peekStr("==") {
|
||||||
|
op = "=="
|
||||||
|
p.skip(2)
|
||||||
|
} else if p.peekStr("!=") {
|
||||||
|
op = "!="
|
||||||
|
p.skip(2)
|
||||||
|
} else if p.peekStr("<=") {
|
||||||
|
op = "<="
|
||||||
|
p.skip(2)
|
||||||
|
} else if p.peekStr(">=") {
|
||||||
|
op = ">="
|
||||||
|
p.skip(2)
|
||||||
|
} else if p.peekStr("<") {
|
||||||
|
op = "<"
|
||||||
|
p.skip(1)
|
||||||
|
} else if p.peekStr(">") {
|
||||||
|
op = ">"
|
||||||
|
p.skip(1)
|
||||||
|
}
|
||||||
|
if op == "" {
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
p.skipSpaces()
|
||||||
|
right, err := p.parseUnary()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprBinary, Left: left, Op: op, Right: right}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseUnary() (*Expr, error) {
|
||||||
|
p.skipSpaces()
|
||||||
|
if p.peekStr("!") && !p.peekStr("!=") {
|
||||||
|
p.skip(1)
|
||||||
|
operand, err := p.parseUnary()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprUnary, UnaryOp: "!", Operand: operand}, nil
|
||||||
|
}
|
||||||
|
return p.parsePrimary()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parsePrimary() (*Expr, error) {
|
||||||
|
p.skipSpaces()
|
||||||
|
if p.pos >= len(p.input) {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unexpected end of expression", p.line, p.col)
|
||||||
|
}
|
||||||
|
ch := p.input[p.pos]
|
||||||
|
|
||||||
|
if ch == '"' {
|
||||||
|
return p.parseStringLit()
|
||||||
|
}
|
||||||
|
if ch == '\'' {
|
||||||
|
return p.parseSingleQuoteStringLit()
|
||||||
|
}
|
||||||
|
if ch >= '0' && ch <= '9' {
|
||||||
|
return p.parseNumberLit()
|
||||||
|
}
|
||||||
|
if isIdentStart(ch) {
|
||||||
|
return p.parseIdentOrKeyword()
|
||||||
|
}
|
||||||
|
if ch == '(' {
|
||||||
|
p.skip(1)
|
||||||
|
p.skipSpaces()
|
||||||
|
expr, err := p.parseOr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p.skipSpaces()
|
||||||
|
if p.pos >= len(p.input) || p.input[p.pos] != ')' {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: expected ')'", p.line, p.col)
|
||||||
|
}
|
||||||
|
p.skip(1)
|
||||||
|
return expr, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unexpected character %q", p.line, p.col, string(ch))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseIdentOrKeyword() (*Expr, error) {
|
||||||
|
start := p.pos
|
||||||
|
for p.pos < len(p.input) && isIdentPart(p.input[p.pos]) {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
name := string(p.input[start:p.pos])
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "true":
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: true}, nil
|
||||||
|
case "false":
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: false}, nil
|
||||||
|
case "nil":
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprNil}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.skipSpaces()
|
||||||
|
if p.pos < len(p.input) && p.input[p.pos] == '(' {
|
||||||
|
return p.parseFuncCall(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.skipSpaces()
|
||||||
|
varName := name
|
||||||
|
for p.pos < len(p.input) && p.input[p.pos] == '.' {
|
||||||
|
p.skip(1)
|
||||||
|
if p.pos >= len(p.input) || !isIdentStart(p.input[p.pos]) {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: expected identifier after '.'", p.line, p.col)
|
||||||
|
}
|
||||||
|
segStart := p.pos
|
||||||
|
for p.pos < len(p.input) && isIdentPart(p.input[p.pos]) {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
varName += "." + string(p.input[segStart:p.pos])
|
||||||
|
p.skipSpaces()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprVariable, Name: varName}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseFuncCall(name string) (*Expr, error) {
|
||||||
|
p.skip(1)
|
||||||
|
var args []*Expr
|
||||||
|
p.skipSpaces()
|
||||||
|
if p.pos < len(p.input) && p.input[p.pos] != ')' {
|
||||||
|
for {
|
||||||
|
arg, err := p.parseOr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
args = append(args, arg)
|
||||||
|
p.skipSpaces()
|
||||||
|
if p.pos < len(p.input) && p.input[p.pos] == ',' {
|
||||||
|
p.skip(1)
|
||||||
|
p.skipSpaces()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.skipSpaces()
|
||||||
|
if p.pos >= len(p.input) || p.input[p.pos] != ')' {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: expected ')' after function call", p.line, p.col)
|
||||||
|
}
|
||||||
|
p.skip(1)
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprFuncCall, FuncName: name, FuncArgs: args}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseStringLit() (*Expr, error) {
|
||||||
|
p.skip(1)
|
||||||
|
var buf strings.Builder
|
||||||
|
for p.pos < len(p.input) && p.input[p.pos] != '"' {
|
||||||
|
if p.input[p.pos] == '\\' && p.pos+1 < len(p.input) {
|
||||||
|
p.pos++
|
||||||
|
switch p.input[p.pos] {
|
||||||
|
case 'n':
|
||||||
|
buf.WriteRune('\n')
|
||||||
|
case 't':
|
||||||
|
buf.WriteRune('\t')
|
||||||
|
case '\\':
|
||||||
|
buf.WriteRune('\\')
|
||||||
|
case '"':
|
||||||
|
buf.WriteRune('"')
|
||||||
|
default:
|
||||||
|
buf.WriteRune('\\')
|
||||||
|
buf.WriteRune(p.input[p.pos])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(p.input[p.pos])
|
||||||
|
}
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
if p.pos >= len(p.input) {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated string", p.line, p.col)
|
||||||
|
}
|
||||||
|
p.skip(1)
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: buf.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseSingleQuoteStringLit() (*Expr, error) {
|
||||||
|
p.skip(1)
|
||||||
|
var buf strings.Builder
|
||||||
|
for p.pos < len(p.input) && p.input[p.pos] != '\'' {
|
||||||
|
buf.WriteRune(p.input[p.pos])
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
if p.pos >= len(p.input) {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated string", p.line, p.col)
|
||||||
|
}
|
||||||
|
p.skip(1)
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: buf.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) parseNumberLit() (*Expr, error) {
|
||||||
|
start := p.pos
|
||||||
|
for p.pos < len(p.input) && p.input[p.pos] >= '0' && p.input[p.pos] <= '9' {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
isFloat := false
|
||||||
|
if p.pos < len(p.input) && p.input[p.pos] == '.' {
|
||||||
|
next := p.pos + 1
|
||||||
|
if next < len(p.input) && p.input[next] >= '0' && p.input[next] <= '9' {
|
||||||
|
isFloat = true
|
||||||
|
p.pos++
|
||||||
|
for p.pos < len(p.input) && p.input[p.pos] >= '0' && p.input[p.pos] <= '9' {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
text := string(p.input[start:p.pos])
|
||||||
|
if isFloat {
|
||||||
|
v, err := strconv.ParseFloat(text, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: invalid number %q", p.line, p.col, text)
|
||||||
|
}
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: v}, nil
|
||||||
|
}
|
||||||
|
v, err := strconv.ParseInt(text, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: invalid number %q", p.line, p.col, text)
|
||||||
|
}
|
||||||
|
return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: int(v)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) peekStr(s string) bool {
|
||||||
|
if p.pos+len(s) > len(p.input) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, ch := range s {
|
||||||
|
if p.input[p.pos+i] != ch {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) skip(n int) {
|
||||||
|
p.pos += n
|
||||||
|
p.col += n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ExprParser) skipSpaces() {
|
||||||
|
for p.pos < len(p.input) && unicode.IsSpace(p.input[p.pos]) {
|
||||||
|
if p.input[p.pos] == '\n' {
|
||||||
|
p.line++
|
||||||
|
p.col = 0
|
||||||
|
} else {
|
||||||
|
p.col++
|
||||||
|
}
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdentStart(ch rune) bool {
|
||||||
|
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdentPart(ch rune) bool {
|
||||||
|
return isIdentStart(ch) || (ch >= '0' && ch <= '9')
|
||||||
|
}
|
||||||
|
|
||||||
|
func Eval(expr *Expr, ctx *Context) (any, error) {
|
||||||
|
switch expr.ExprType {
|
||||||
|
case ExprLiteral:
|
||||||
|
return expr.Value, nil
|
||||||
|
case ExprNil:
|
||||||
|
return nil, nil
|
||||||
|
case ExprVariable:
|
||||||
|
val, ok := ctx.Get(expr.Name)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
case ExprUnary:
|
||||||
|
return evalUnary(expr, ctx)
|
||||||
|
case ExprBinary:
|
||||||
|
return evalBinary(expr, ctx)
|
||||||
|
case ExprFuncCall:
|
||||||
|
return evalFuncCall(expr, ctx)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unknown expression type", expr.Pos.Line, expr.Pos.Col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func evalUnary(expr *Expr, ctx *Context) (any, error) {
|
||||||
|
val, err := Eval(expr.Operand, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch expr.UnaryOp {
|
||||||
|
case "!":
|
||||||
|
return !isTruthy(val), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unknown unary operator %q", expr.Pos.Line, expr.Pos.Col, expr.UnaryOp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func evalBinary(expr *Expr, ctx *Context) (any, error) {
|
||||||
|
switch expr.Op {
|
||||||
|
case "&&":
|
||||||
|
left, err := Eval(expr.Left, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !isTruthy(left) {
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
return Eval(expr.Right, ctx)
|
||||||
|
case "||":
|
||||||
|
left, err := Eval(expr.Left, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if isTruthy(left) {
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
return Eval(expr.Right, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, err := Eval(expr.Left, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
right, err := Eval(expr.Right, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch expr.Op {
|
||||||
|
case "==":
|
||||||
|
return compareEqual(left, right), nil
|
||||||
|
case "!=":
|
||||||
|
return !compareEqual(left, right), nil
|
||||||
|
case "<":
|
||||||
|
return compareOrder(left, right, expr.Op)
|
||||||
|
case ">":
|
||||||
|
return compareOrder(left, right, expr.Op)
|
||||||
|
case "<=":
|
||||||
|
return compareOrder(left, right, expr.Op)
|
||||||
|
case ">=":
|
||||||
|
return compareOrder(left, right, expr.Op)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unknown operator %q", expr.Pos.Line, expr.Pos.Col, expr.Op)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func evalFuncCall(expr *Expr, ctx *Context) (any, error) {
|
||||||
|
fn, ok := LookupBuiltin(expr.FuncName)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unknown function %q", expr.Pos.Line, expr.Pos.Col, expr.FuncName)
|
||||||
|
}
|
||||||
|
var args []any
|
||||||
|
for _, a := range expr.FuncArgs {
|
||||||
|
val, err := Eval(a, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
args = append(args, val)
|
||||||
|
}
|
||||||
|
result, ok := fn(args)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: function %q call failed", expr.Pos.Line, expr.Pos.Col, expr.FuncName)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTruthy(val any) bool {
|
||||||
|
if val == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch v := val.(type) {
|
||||||
|
case bool:
|
||||||
|
return v
|
||||||
|
case int:
|
||||||
|
return v != 0
|
||||||
|
case int64:
|
||||||
|
return v != 0
|
||||||
|
case float64:
|
||||||
|
return v != 0
|
||||||
|
case string:
|
||||||
|
return v != ""
|
||||||
|
case []any:
|
||||||
|
return len(v) > 0
|
||||||
|
case map[string]any:
|
||||||
|
return len(v) > 0
|
||||||
|
default:
|
||||||
|
rv := reflect.ValueOf(val)
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array, reflect.Map:
|
||||||
|
return rv.Len() > 0
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareEqual(left, right any) bool {
|
||||||
|
if left == nil && right == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if left == nil || right == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
lbool, lbOk := left.(bool)
|
||||||
|
rbool, rbOk := right.(bool)
|
||||||
|
if lbOk || rbOk {
|
||||||
|
return lbOk && rbOk && lbool == rbool
|
||||||
|
}
|
||||||
|
|
||||||
|
lf, lok := toFloat64(left)
|
||||||
|
rf, rok := toFloat64(right)
|
||||||
|
if lok && rok {
|
||||||
|
return lf == rf
|
||||||
|
}
|
||||||
|
|
||||||
|
lstr, lsOk := left.(string)
|
||||||
|
rstr, rsOk := right.(string)
|
||||||
|
if lsOk && rsOk {
|
||||||
|
return lstr == rstr
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%v", left) == fmt.Sprintf("%v", right)
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareOrder(left, right any, op string) (bool, error) {
|
||||||
|
if left == nil || right == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lf, lok := toFloat64(left)
|
||||||
|
rf, rok := toFloat64(right)
|
||||||
|
if lok && rok {
|
||||||
|
switch op {
|
||||||
|
case "<":
|
||||||
|
return lf < rf, nil
|
||||||
|
case ">":
|
||||||
|
return lf > rf, nil
|
||||||
|
case "<=":
|
||||||
|
return lf <= rf, nil
|
||||||
|
case ">=":
|
||||||
|
return lf >= rf, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ls, lsOk := left.(string)
|
||||||
|
rs, rsOk := right.(string)
|
||||||
|
if lsOk && rsOk {
|
||||||
|
switch op {
|
||||||
|
case "<":
|
||||||
|
return ls < rs, nil
|
||||||
|
case ">":
|
||||||
|
return ls > rs, nil
|
||||||
|
case "<=":
|
||||||
|
return ls <= rs, nil
|
||||||
|
case ">=":
|
||||||
|
return ls >= rs, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, fmt.Errorf("line 0, col 0: cannot compare %T and %T with %s", left, right, op)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toFloat64(val any) (float64, bool) {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case int:
|
||||||
|
return float64(v), true
|
||||||
|
case int64:
|
||||||
|
return float64(v), true
|
||||||
|
case float64:
|
||||||
|
return v, true
|
||||||
|
case float32:
|
||||||
|
return float64(v), true
|
||||||
|
case uint:
|
||||||
|
return float64(v), true
|
||||||
|
case uint64:
|
||||||
|
return float64(v), true
|
||||||
|
case int32:
|
||||||
|
return float64(v), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
73
internal/include.go
Normal file
73
internal/include.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IncludeResolver func(path string) (string, error)
|
||||||
|
|
||||||
|
type IncludeManager struct {
|
||||||
|
resolver IncludeResolver
|
||||||
|
stack []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIncludeManager(resolver IncludeResolver) *IncludeManager {
|
||||||
|
return &IncludeManager{resolver: resolver}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *IncludeManager) Resolve(path string) (string, error) {
|
||||||
|
return m.resolveInternal(path, m.stack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *IncludeManager) resolveInternal(path string, stack []string) (string, error) {
|
||||||
|
for i, p := range stack {
|
||||||
|
if p == path {
|
||||||
|
return "", fmt.Errorf("circular include detected: %s is already in the include chain at depth %d", path, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
src, err := m.resolver(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to resolve include %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newStack := make([]string, len(stack)+1)
|
||||||
|
copy(newStack, stack)
|
||||||
|
newStack[len(stack)] = path
|
||||||
|
|
||||||
|
return m.expandIncludes(src, newStack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *IncludeManager) expandIncludes(src string, stack []string) (string, error) {
|
||||||
|
result := src
|
||||||
|
offset := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
start := strings.Index(result[offset:], `@include("`)
|
||||||
|
if start < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
absStart := offset + start
|
||||||
|
pathStart := absStart + len(`@include("`)
|
||||||
|
|
||||||
|
end := strings.Index(result[pathStart:], `")`)
|
||||||
|
if end < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
includePath := result[pathStart : pathStart+end]
|
||||||
|
absEnd := pathStart + end + len(`")`)
|
||||||
|
|
||||||
|
resolved, err := m.resolveInternal(includePath, stack)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = result[:absStart] + resolved + result[absEnd:]
|
||||||
|
offset = absStart + len(resolved)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
211
internal/lexer.go
Normal file
211
internal/lexer.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokText TokenType = iota
|
||||||
|
TokParamStart
|
||||||
|
TokRawStart
|
||||||
|
TokIfStart
|
||||||
|
TokForStart
|
||||||
|
TokTplStart
|
||||||
|
TokIncludeStart
|
||||||
|
TokNamespaceStart
|
||||||
|
TokElse
|
||||||
|
TokComment
|
||||||
|
TokEOF
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
Type TokenType
|
||||||
|
Value string
|
||||||
|
Pos Pos
|
||||||
|
}
|
||||||
|
|
||||||
|
type Lexer struct {
|
||||||
|
input []rune
|
||||||
|
pos int
|
||||||
|
line int
|
||||||
|
col int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLexer(input string) *Lexer {
|
||||||
|
return &Lexer{
|
||||||
|
input: []rune(input),
|
||||||
|
line: 1,
|
||||||
|
col: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) Tokenize() ([]Token, error) {
|
||||||
|
var tokens []Token
|
||||||
|
|
||||||
|
for l.pos < len(l.input) {
|
||||||
|
ch := l.input[l.pos]
|
||||||
|
|
||||||
|
if ch == '#' {
|
||||||
|
if l.peek(1) == '{' {
|
||||||
|
tokens = append(tokens, Token{Type: TokParamStart, Value: "#{", Pos: l.curPos()})
|
||||||
|
l.advance()
|
||||||
|
l.advance()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokens = append(tokens, l.readComment())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ch == '$' && l.peek(1) == '{' {
|
||||||
|
tokens = append(tokens, Token{Type: TokRawStart, Value: "${", Pos: l.curPos()})
|
||||||
|
l.advance()
|
||||||
|
l.advance()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ch == '@' {
|
||||||
|
if tok, ok := l.tryDirective(); ok {
|
||||||
|
tokens = append(tokens, tok)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ch == '}' {
|
||||||
|
// Skip spaces after '}' to check for "else"
|
||||||
|
spaceOffset := 1
|
||||||
|
for l.peek(spaceOffset) == ' ' || l.peek(spaceOffset) == '\t' {
|
||||||
|
spaceOffset++
|
||||||
|
}
|
||||||
|
if l.peekWord(spaceOffset, "else") {
|
||||||
|
pos := l.curPos()
|
||||||
|
l.advance() // consume '}'
|
||||||
|
l.advanceN(spaceOffset - 1) // consume spaces
|
||||||
|
l.advanceN(4) // consume "else"
|
||||||
|
tokens = append(tokens, Token{Type: TokElse, Value: "} else", Pos: pos})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
l.advance()
|
||||||
|
// Pos stores the end position (after '}'), consistent with other TokText tokens
|
||||||
|
tokens = append(tokens, Token{Type: TokText, Value: "}", Pos: l.curPos()})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ch == '\n' {
|
||||||
|
l.advance()
|
||||||
|
// Pos stores the end position (after '\n'), consistent with other TokText tokens
|
||||||
|
tokens = append(tokens, Token{Type: TokText, Value: "\n", Pos: l.curPos()})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular text: scan until special character
|
||||||
|
start := l.pos
|
||||||
|
for l.pos < len(l.input) {
|
||||||
|
c := l.input[l.pos]
|
||||||
|
if c == '#' || c == '$' || c == '@' || c == '}' || c == '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
l.advance()
|
||||||
|
}
|
||||||
|
if l.pos > start {
|
||||||
|
tokens = append(tokens, Token{Type: TokText, Value: string(l.input[start:l.pos]), Pos: Pos{Line: l.line, Col: l.col}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = append(tokens, Token{Type: TokEOF, Pos: Pos{Line: l.line, Col: l.col}})
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) curPos() Pos {
|
||||||
|
return Pos{Line: l.line, Col: l.col}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) advance() {
|
||||||
|
if l.pos < len(l.input) {
|
||||||
|
if l.input[l.pos] == '\n' {
|
||||||
|
l.line++
|
||||||
|
l.col = 1
|
||||||
|
} else {
|
||||||
|
l.col++
|
||||||
|
}
|
||||||
|
l.pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) advanceN(n int) {
|
||||||
|
for range n {
|
||||||
|
l.advance()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) peek(offset int) rune {
|
||||||
|
idx := l.pos + offset
|
||||||
|
if idx < len(l.input) {
|
||||||
|
return l.input[idx]
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) peekWord(offset int, word string) bool {
|
||||||
|
runes := []rune(word)
|
||||||
|
n := len(runes)
|
||||||
|
for i := range n {
|
||||||
|
if l.peek(offset+i) != runes[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
after := l.peek(offset + n)
|
||||||
|
return after == 0 || after == ' ' || after == '\n' || after == '\t' || after == '{' || after == '}'
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) tryDirective() (Token, bool) {
|
||||||
|
type directive struct {
|
||||||
|
prefix []rune
|
||||||
|
ttype TokenType
|
||||||
|
skip int
|
||||||
|
}
|
||||||
|
|
||||||
|
directives := []directive{
|
||||||
|
{[]rune("@if("), TokIfStart, 4},
|
||||||
|
{[]rune("@for("), TokForStart, 5},
|
||||||
|
{[]rune("@tpl(\""), TokTplStart, 5},
|
||||||
|
{[]rune("@include(\""), TokIncludeStart, 10},
|
||||||
|
{[]rune("@namespace(\""), TokNamespaceStart, 12},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, d := range directives {
|
||||||
|
if l.matchRunes(d.prefix) {
|
||||||
|
pos := l.curPos()
|
||||||
|
val := string(d.prefix)
|
||||||
|
l.advanceN(d.skip)
|
||||||
|
return Token{Type: d.ttype, Value: val, Pos: pos}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Token{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) matchRunes(runes []rune) bool {
|
||||||
|
for i, r := range runes {
|
||||||
|
if l.peek(i) != r {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Lexer) readComment() Token {
|
||||||
|
pos := l.curPos()
|
||||||
|
l.advance() // #
|
||||||
|
start := l.pos
|
||||||
|
for l.pos < len(l.input) && l.input[l.pos] != '\n' {
|
||||||
|
l.advance()
|
||||||
|
}
|
||||||
|
// consume trailing newline so the comment line disappears
|
||||||
|
if l.pos < len(l.input) && l.input[l.pos] == '\n' {
|
||||||
|
l.advance()
|
||||||
|
}
|
||||||
|
return Token{Type: TokComment, Value: strings.TrimRightFunc(string(l.input[start:l.pos]), unicode.IsSpace), Pos: pos}
|
||||||
|
}
|
||||||
111
internal/node.go
Normal file
111
internal/node.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
type Node interface {
|
||||||
|
nodeType() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Pos struct {
|
||||||
|
Line int
|
||||||
|
Col int
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Text string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *TextNode) nodeType() string { return "Text" }
|
||||||
|
|
||||||
|
type ParamNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *ParamNode) nodeType() string { return "Param" }
|
||||||
|
|
||||||
|
type RawNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *RawNode) nodeType() string { return "Raw" }
|
||||||
|
|
||||||
|
type IfNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Cond *Expr
|
||||||
|
Body []Node
|
||||||
|
Else []Node
|
||||||
|
ElseIf []*ElseIfBranch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *IfNode) nodeType() string { return "If" }
|
||||||
|
|
||||||
|
type ElseIfBranch struct {
|
||||||
|
Pos Pos
|
||||||
|
Cond *Expr
|
||||||
|
Body []Node
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForNode struct {
|
||||||
|
Pos Pos
|
||||||
|
KeyVar string
|
||||||
|
ValVar string
|
||||||
|
List *Expr
|
||||||
|
Body []Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *ForNode) nodeType() string { return "For" }
|
||||||
|
|
||||||
|
type BlockNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Name string
|
||||||
|
Body []Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *BlockNode) nodeType() string { return "Block" }
|
||||||
|
|
||||||
|
type IncludeNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *IncludeNode) nodeType() string { return "Include" }
|
||||||
|
|
||||||
|
type NamespaceNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NamespaceNode) nodeType() string { return "Namespace" }
|
||||||
|
|
||||||
|
type CommentNode struct {
|
||||||
|
Pos Pos
|
||||||
|
Text string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *CommentNode) nodeType() string { return "Comment" }
|
||||||
|
|
||||||
|
type ExprType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ExprLiteral ExprType = iota
|
||||||
|
ExprVariable
|
||||||
|
ExprBinary
|
||||||
|
ExprUnary
|
||||||
|
ExprFuncCall
|
||||||
|
ExprNil
|
||||||
|
)
|
||||||
|
|
||||||
|
type Expr struct {
|
||||||
|
Pos Pos
|
||||||
|
ExprType ExprType
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
Left *Expr
|
||||||
|
Op string
|
||||||
|
Right *Expr
|
||||||
|
UnaryOp string
|
||||||
|
Operand *Expr
|
||||||
|
FuncName string
|
||||||
|
FuncArgs []*Expr
|
||||||
|
}
|
||||||
913
internal/parser.go
Normal file
913
internal/parser.go
Normal file
@@ -0,0 +1,913 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Parser struct {
|
||||||
|
input []rune
|
||||||
|
tokens []Token
|
||||||
|
pos int
|
||||||
|
includeMgr *IncludeManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewParser(input string, tokens []Token, includeMgr *IncludeManager) *Parser {
|
||||||
|
return &Parser{
|
||||||
|
input: []rune(input),
|
||||||
|
tokens: tokens,
|
||||||
|
pos: 0,
|
||||||
|
includeMgr: includeMgr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) Parse() ([]Node, error) {
|
||||||
|
var nodes []Node
|
||||||
|
var hasTpl bool
|
||||||
|
var tplNames map[string]bool
|
||||||
|
|
||||||
|
for p.pos < len(p.tokens) {
|
||||||
|
tok := p.cur()
|
||||||
|
if tok.Type == TokEOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tok.Type {
|
||||||
|
case TokText:
|
||||||
|
p.pos++
|
||||||
|
if len(tok.Value) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hasTpl && !isWhitespace(tok.Value) {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: top-level text is not allowed in templates with @tpl blocks", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
nodes = append(nodes, &TextNode{Pos: tok.Pos, Text: tok.Value})
|
||||||
|
|
||||||
|
case TokParamStart:
|
||||||
|
node, err := p.parseParam(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
|
case TokRawStart:
|
||||||
|
node, err := p.parseRaw(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
|
case TokIfStart:
|
||||||
|
node, err := p.parseIf(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
|
case TokForStart:
|
||||||
|
node, err := p.parseFor(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
|
case TokTplStart:
|
||||||
|
hasTpl = true
|
||||||
|
if tplNames == nil {
|
||||||
|
tplNames = make(map[string]bool)
|
||||||
|
}
|
||||||
|
node, err := p.parseTpl(tok, tplNames)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
|
case TokIncludeStart:
|
||||||
|
subNodes, err := p.expandInclude(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, subNodes...)
|
||||||
|
|
||||||
|
case TokNamespaceStart:
|
||||||
|
node, err := p.parseNamespace(tok, len(nodes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
|
case TokElse:
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unexpected else", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
|
||||||
|
case TokComment:
|
||||||
|
p.pos++
|
||||||
|
|
||||||
|
default:
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) cur() Token {
|
||||||
|
if p.pos < len(p.tokens) {
|
||||||
|
return p.tokens[p.pos]
|
||||||
|
}
|
||||||
|
return Token{Type: TokEOF}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readUntilParen reads from runePos until ')' at paren depth 0.
|
||||||
|
// Only tracks '(' depth, not braces.
|
||||||
|
func (p *Parser) readUntilParen(startRunePos int) (string, int, error) {
|
||||||
|
i := startRunePos
|
||||||
|
depth := 0
|
||||||
|
for i < len(p.input) {
|
||||||
|
ch := p.input[i]
|
||||||
|
if ch == '(' {
|
||||||
|
depth++
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == ')' {
|
||||||
|
if depth > 0 {
|
||||||
|
depth--
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return string(p.input[startRunePos:i]), i, nil
|
||||||
|
}
|
||||||
|
if ch == '\'' || ch == '"' {
|
||||||
|
quote := ch
|
||||||
|
i++
|
||||||
|
for i < len(p.input) && p.input[i] != quote {
|
||||||
|
if p.input[i] == '\\' && i+1 < len(p.input) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i < len(p.input) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return "", i, fmt.Errorf("unexpected end of input, expected ')'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// readUntilBrace reads from runePos until '}' at brace depth 0.
|
||||||
|
// Only tracks '{' depth, not parens.
|
||||||
|
func (p *Parser) readUntilBrace(startRunePos int) (string, int, error) {
|
||||||
|
i := startRunePos
|
||||||
|
depth := 0
|
||||||
|
for i < len(p.input) {
|
||||||
|
ch := p.input[i]
|
||||||
|
if ch == '{' {
|
||||||
|
depth++
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '}' {
|
||||||
|
if depth > 0 {
|
||||||
|
depth--
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return string(p.input[startRunePos:i]), i, nil
|
||||||
|
}
|
||||||
|
if ch == '\'' || ch == '"' {
|
||||||
|
quote := ch
|
||||||
|
i++
|
||||||
|
for i < len(p.input) && p.input[i] != quote {
|
||||||
|
if p.input[i] == '\\' && i+1 < len(p.input) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i < len(p.input) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return "", i, fmt.Errorf("unexpected end of input, expected '}'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// readUntilQuote reads from runePos until '"'.
|
||||||
|
func (p *Parser) readUntilQuote(startRunePos int) (string, int, error) {
|
||||||
|
i := startRunePos
|
||||||
|
for i < len(p.input) {
|
||||||
|
if p.input[i] == '"' {
|
||||||
|
return string(p.input[startRunePos:i]), i, nil
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return "", i, fmt.Errorf("unexpected end of input, expected '\"'")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) runePosFromToken(tok Token) int {
|
||||||
|
line, col := tok.Pos.Line, tok.Pos.Col
|
||||||
|
prefixLen := len(tok.Value)
|
||||||
|
return runePosFromLineCol(p.input, line, col) + prefixLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func runePosFromLineCol(input []rune, line, col int) int {
|
||||||
|
if line <= 1 && col <= 1 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
currentLine := 1
|
||||||
|
currentCol := 1
|
||||||
|
for i, ch := range input {
|
||||||
|
if currentLine == line && currentCol == col {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
if ch == '\n' {
|
||||||
|
currentLine++
|
||||||
|
currentCol = 1
|
||||||
|
} else {
|
||||||
|
currentCol++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) consumeTokensForRuneRange(_, runeEnd int) {
|
||||||
|
for p.pos < len(p.tokens) {
|
||||||
|
tok := p.tokens[p.pos]
|
||||||
|
if tok.Type == TokEOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Text tokens store Pos as the end position; other tokens store Pos as the start.
|
||||||
|
tokRuneStart := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col)
|
||||||
|
if tok.Type == TokText {
|
||||||
|
tokRuneStart -= len(tok.Value)
|
||||||
|
}
|
||||||
|
if tokRuneStart >= runeEnd {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// For text tokens, check if the token extends past runeEnd.
|
||||||
|
// If so, split it: keep the remainder as a new text token.
|
||||||
|
if tok.Type == TokText && len(tok.Value) > 0 {
|
||||||
|
tokRuneEnd := tokRuneStart + len(tok.Value)
|
||||||
|
if tokRuneEnd > runeEnd {
|
||||||
|
overlap := runeEnd - tokRuneStart
|
||||||
|
if overlap > 0 && overlap < len(tok.Value) {
|
||||||
|
remainder := tok.Value[overlap:]
|
||||||
|
// Replace current token with the remainder
|
||||||
|
p.tokens[p.pos] = Token{
|
||||||
|
Type: TokText,
|
||||||
|
Value: remainder,
|
||||||
|
Pos: tok.Pos,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseParam(tok Token) (*ParamNode, error) {
|
||||||
|
runePos := p.runePosFromToken(tok)
|
||||||
|
content, endPos, err := p.readUntilBrace(runePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated param, expected '}'", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
if content == "" {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: empty param name", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
p.consumeTokensForRuneRange(runePos, endPos+1)
|
||||||
|
return &ParamNode{Pos: tok.Pos, Name: content}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseRaw(tok Token) (*RawNode, error) {
|
||||||
|
runePos := p.runePosFromToken(tok)
|
||||||
|
content, endPos, err := p.readUntilBrace(runePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated raw, expected '}'", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
if content == "" {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: empty raw name", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
p.consumeTokensForRuneRange(runePos, endPos+1)
|
||||||
|
return &RawNode{Pos: tok.Pos, Name: content}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseIf(tok Token) (*IfNode, error) {
|
||||||
|
runePos := p.runePosFromToken(tok)
|
||||||
|
exprStr, parenClose, err := p.readUntilParen(runePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated @if, expected ')'", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
|
||||||
|
expr, err := NewExprParser(strings.TrimSpace(exprStr), tok.Pos.Line, tok.Pos.Col).Parse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find '{' after ')' in raw input
|
||||||
|
braceOpen := findChar(p.input[parenClose+1:], '{')
|
||||||
|
if braceOpen < 0 {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: expected '{' after @if condition", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
braceOpen = parenClose + 1 + braceOpen
|
||||||
|
|
||||||
|
// consume tokens covering ) ... {
|
||||||
|
p.consumeTokensForRuneRange(runePos, braceOpen+1)
|
||||||
|
|
||||||
|
body, elseBody, elseIfBranches, err := p.parseBlockBodyWithElse("if")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &IfNode{
|
||||||
|
Pos: tok.Pos,
|
||||||
|
Cond: expr,
|
||||||
|
Body: body,
|
||||||
|
Else: elseBody,
|
||||||
|
ElseIf: elseIfBranches,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseElseIfBranch(tok Token) (*ElseIfBranch, error) {
|
||||||
|
prefix := "@elseif("
|
||||||
|
runePos := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col)
|
||||||
|
idx := strings.Index(tok.Value, prefix)
|
||||||
|
if idx < 0 {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: expected @elseif(", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
exprStart := runePos + idx + len(prefix)
|
||||||
|
exprStr, closePos, err := p.readUntilParen(exprStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated @elseif, expected ')'", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
|
||||||
|
expr, err := NewExprParser(strings.TrimSpace(exprStr), tok.Pos.Line, tok.Pos.Col).Parse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.consumeTokensForRuneRange(exprStart, closePos+1)
|
||||||
|
p.skipToBraceOpen()
|
||||||
|
|
||||||
|
body, _, _, err := p.parseBlockBodyWithElse("elseif")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ElseIfBranch{
|
||||||
|
Pos: tok.Pos,
|
||||||
|
Cond: expr,
|
||||||
|
Body: body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseFor(tok Token) (*ForNode, error) {
|
||||||
|
runePos := p.runePosFromToken(tok)
|
||||||
|
content, parenClose, err := p.readUntilParen(runePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated @for, expected ')'", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
|
||||||
|
braceOpen := findChar(p.input[parenClose+1:], '{')
|
||||||
|
if braceOpen < 0 {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: expected '{' after @for", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
braceOpen = parenClose + 1 + braceOpen
|
||||||
|
p.consumeTokensForRuneRange(runePos, braceOpen+1)
|
||||||
|
|
||||||
|
keyVar, valVar, listExprStr, err := parseForHeader(content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: %s", tok.Pos.Line, tok.Pos.Col, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
listExpr, err := NewExprParser(strings.TrimSpace(listExprStr), tok.Pos.Line, tok.Pos.Col).Parse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _, _, err := p.parseBlockBodyWithElse("for")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ForNode{
|
||||||
|
Pos: tok.Pos,
|
||||||
|
KeyVar: keyVar,
|
||||||
|
ValVar: valVar,
|
||||||
|
List: listExpr,
|
||||||
|
Body: body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findChar(input []rune, ch rune) int {
|
||||||
|
for i, c := range input {
|
||||||
|
if c == ch {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseForHeader(content string) (keyVar, valVar, listExpr string, err error) {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
|
||||||
|
rangeIdx := strings.Index(content, " range ")
|
||||||
|
if rangeIdx < 0 {
|
||||||
|
return "", "", "", fmt.Errorf("expected 'range' keyword in @for")
|
||||||
|
}
|
||||||
|
|
||||||
|
varsPart := strings.TrimRight(strings.TrimSpace(content[:rangeIdx]), ", ")
|
||||||
|
listExpr = strings.TrimSpace(content[rangeIdx+len(" range "):])
|
||||||
|
|
||||||
|
parts := strings.Split(varsPart, ",")
|
||||||
|
switch len(parts) {
|
||||||
|
case 1:
|
||||||
|
valVar = strings.TrimSpace(parts[0])
|
||||||
|
case 2:
|
||||||
|
keyVar = strings.TrimSpace(parts[0])
|
||||||
|
valVar = strings.TrimSpace(parts[1])
|
||||||
|
default:
|
||||||
|
return "", "", "", fmt.Errorf("invalid @for variable declaration")
|
||||||
|
}
|
||||||
|
|
||||||
|
if valVar == "" {
|
||||||
|
return "", "", "", fmt.Errorf("missing value variable in @for")
|
||||||
|
}
|
||||||
|
return keyVar, valVar, listExpr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseTpl(tok Token, tplNames map[string]bool) (*BlockNode, error) {
|
||||||
|
runePos := p.runePosFromToken(tok)
|
||||||
|
name, quotePos, err := p.readUntilQuote(runePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated @tpl name", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: empty @tpl block name", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
if tplNames[name] {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: duplicate @tpl block name %q", tok.Pos.Line, tok.Pos.Col, name)
|
||||||
|
}
|
||||||
|
tplNames[name] = true
|
||||||
|
|
||||||
|
// skip ") {"
|
||||||
|
endPos := quotePos + 1 // past closing "
|
||||||
|
if endPos < len(p.input) && p.input[endPos] == ')' {
|
||||||
|
endPos++
|
||||||
|
}
|
||||||
|
// find '{' and consume tokens up to past it
|
||||||
|
braceOpen := findChar(p.input[endPos:], '{')
|
||||||
|
if braceOpen < 0 {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: expected '{' after @tpl", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
braceOpen = endPos + braceOpen
|
||||||
|
p.consumeTokensForRuneRange(runePos, braceOpen+1)
|
||||||
|
|
||||||
|
body, _, _, err := p.parseBlockBodyWithElse("tpl")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &BlockNode{
|
||||||
|
Pos: tok.Pos,
|
||||||
|
Name: name,
|
||||||
|
Body: body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseInclude(tok Token) (*IncludeNode, error) {
|
||||||
|
runePos := p.runePosFromToken(tok)
|
||||||
|
path, quotePos, err := p.readUntilQuote(runePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated @include path", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
|
||||||
|
endPos := quotePos + 1
|
||||||
|
if endPos < len(p.input) && p.input[endPos] == ')' {
|
||||||
|
endPos++
|
||||||
|
}
|
||||||
|
p.consumeTokensForRuneRange(runePos, endPos)
|
||||||
|
|
||||||
|
return &IncludeNode{
|
||||||
|
Pos: tok.Pos,
|
||||||
|
Path: path,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) expandInclude(tok Token) ([]Node, error) {
|
||||||
|
incNode, err := p.parseInclude(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if p.includeMgr == nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: @include used but no include resolver configured", incNode.Pos.Line, incNode.Pos.Col)
|
||||||
|
}
|
||||||
|
expanded, err := p.includeMgr.Resolve(incNode.Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: %s", incNode.Pos.Line, incNode.Pos.Col, err.Error())
|
||||||
|
}
|
||||||
|
subLexer := NewLexer(expanded)
|
||||||
|
subTokens, err := subLexer.Tokenize()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
subParser := NewParser(expanded, subTokens, p.includeMgr)
|
||||||
|
return subParser.Parse()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseNamespace(tok Token, nodeCount int) (*NamespaceNode, error) {
|
||||||
|
if nodeCount > 0 {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: @namespace must be at the top of the file", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
|
||||||
|
runePos := p.runePosFromToken(tok)
|
||||||
|
name, quotePos, err := p.readUntilQuote(runePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("line %d, col %d: unterminated @namespace name", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
|
||||||
|
endPos := quotePos + 1
|
||||||
|
if endPos < len(p.input) && p.input[endPos] == ')' {
|
||||||
|
endPos++
|
||||||
|
}
|
||||||
|
p.consumeTokensForRuneRange(runePos, endPos)
|
||||||
|
|
||||||
|
return &NamespaceNode{
|
||||||
|
Pos: tok.Pos,
|
||||||
|
Name: name,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// skipToBraceOpen finds '{' in raw input starting from current token position,
|
||||||
|
// then consumes all tokens up to and past '{'. The text after '{' is preserved
|
||||||
|
// as the current token so it becomes part of the block body.
|
||||||
|
func (p *Parser) skipToBraceOpen() {
|
||||||
|
// Find current rune position from current token
|
||||||
|
if p.pos >= len(p.tokens) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tok := p.tokens[p.pos]
|
||||||
|
startRune := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col)
|
||||||
|
// Text tokens store Pos as the end position; adjust to start.
|
||||||
|
if tok.Type == TokText {
|
||||||
|
startRune -= len(tok.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find '{' in raw input
|
||||||
|
idx := findChar(p.input[startRune:], '{')
|
||||||
|
if idx < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
braceRune := startRune + idx
|
||||||
|
|
||||||
|
// Consume all tokens whose rune position is before '{'
|
||||||
|
p.consumeTokensForRuneRange(startRune, braceRune+1)
|
||||||
|
|
||||||
|
// The text after '{' needs to be available as a token.
|
||||||
|
// If there are remaining characters after '{', they'll be in subsequent tokens.
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitTextAtBrace finds the first '{' in the current text token at p.pos,
|
||||||
|
// splits it so the part after '{' remains, and returns true.
|
||||||
|
// If no '{' is found, advances p.pos and returns false.
|
||||||
|
func (p *Parser) splitTextAtBrace() bool {
|
||||||
|
for p.pos < len(p.tokens) {
|
||||||
|
tok := p.tokens[p.pos]
|
||||||
|
if tok.Type != TokText {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
idx := strings.Index(tok.Value, "{")
|
||||||
|
if idx >= 0 {
|
||||||
|
if idx+1 < len(tok.Value) {
|
||||||
|
remainder := tok.Value[idx+1:]
|
||||||
|
p.tokens[p.pos] = Token{Type: TokText, Value: remainder, Pos: tok.Pos}
|
||||||
|
} else {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) skipWhitespaceText() {
|
||||||
|
for p.pos < len(p.tokens) {
|
||||||
|
tok := p.tokens[p.pos]
|
||||||
|
if tok.Type == TokEOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if tok.Type == TokText && isWhitespace(tok.Value) {
|
||||||
|
p.pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseBlockBodyWithElse parses tokens inside a { } block.
|
||||||
|
// Returns: body nodes, else body (nil if no else), elseif branches, error.
|
||||||
|
// Handles nested @if/@for blocks by tracking brace depth.
|
||||||
|
func (p *Parser) parseBlockBodyWithElse(blockType string) ([]Node, []Node, []*ElseIfBranch, error) {
|
||||||
|
var body []Node
|
||||||
|
braceDepth := 1
|
||||||
|
|
||||||
|
for p.pos < len(p.tokens) {
|
||||||
|
tok := p.cur()
|
||||||
|
if tok.Type == TokEOF {
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: unterminated %s block", tok.Pos.Line, tok.Pos.Col, blockType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tok.Type == TokText {
|
||||||
|
text := tok.Value
|
||||||
|
// Scan for '}' at depth 1
|
||||||
|
before, found := splitAtClosingBrace(text)
|
||||||
|
if found {
|
||||||
|
if before != "" {
|
||||||
|
body = append(body, &TextNode{Pos: tok.Pos, Text: before})
|
||||||
|
}
|
||||||
|
p.pos++ // consume this token
|
||||||
|
// Check what follows: else, elseif, or nothing
|
||||||
|
return p.handleBlockEnd(body, blockType)
|
||||||
|
}
|
||||||
|
// No closing brace — count nested braces and add as text
|
||||||
|
for _, ch := range text {
|
||||||
|
if ch == '{' {
|
||||||
|
braceDepth++
|
||||||
|
} else if ch == '}' {
|
||||||
|
braceDepth--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(text) > 0 {
|
||||||
|
body = append(body, &TextNode{Pos: tok.Pos, Text: text})
|
||||||
|
}
|
||||||
|
p.pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if tok.Type == TokElse && braceDepth == 1 {
|
||||||
|
// Don't consume TokElse here — handleBlockEnd will consume it.
|
||||||
|
// This way handleBlockEnd can properly process the else branch.
|
||||||
|
return p.handleBlockEnd(body, blockType)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pos++
|
||||||
|
|
||||||
|
switch tok.Type {
|
||||||
|
case TokComment:
|
||||||
|
// skip comments in body
|
||||||
|
case TokParamStart:
|
||||||
|
node, err := p.parseParam(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
body = append(body, node)
|
||||||
|
case TokRawStart:
|
||||||
|
node, err := p.parseRaw(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
body = append(body, node)
|
||||||
|
case TokIfStart:
|
||||||
|
node, err := p.parseIf(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
body = append(body, node)
|
||||||
|
case TokForStart:
|
||||||
|
node, err := p.parseFor(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
body = append(body, node)
|
||||||
|
case TokTplStart:
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: @tpl blocks cannot be nested", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
case TokIncludeStart:
|
||||||
|
subNodes, err := p.expandInclude(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
body = append(body, subNodes...)
|
||||||
|
case TokNamespaceStart:
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: @namespace must be at file top level", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, nil, fmt.Errorf("unterminated %s block", blockType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleBlockEnd is called after the closing '}' of a block.
|
||||||
|
// Checks for else/elseif branches.
|
||||||
|
func (p *Parser) handleBlockEnd(body []Node, blockType string) ([]Node, []Node, []*ElseIfBranch, error) {
|
||||||
|
// For non-if blocks, no else support
|
||||||
|
if blockType != "if" && blockType != "elseif" {
|
||||||
|
return body, nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.skipWhitespaceText()
|
||||||
|
tok := p.cur()
|
||||||
|
|
||||||
|
// TokElse means lexer already matched "} else"
|
||||||
|
if tok.Type == TokElse {
|
||||||
|
return p.handleTokElse(tok, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for text token starting with "else" or "elseif("
|
||||||
|
if tok.Type == TokText {
|
||||||
|
return p.handleTextElse(tok, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTokElse processes a TokElse token (lexer already matched "} else").
|
||||||
|
func (p *Parser) handleTokElse(tok Token, body []Node) ([]Node, []Node, []*ElseIfBranch, error) {
|
||||||
|
p.pos++ // consume TokElse
|
||||||
|
p.skipWhitespaceText()
|
||||||
|
next := p.cur()
|
||||||
|
// Check for "} else if(expr)" — text starting with "if " or "if("
|
||||||
|
if next.Type == TokText {
|
||||||
|
trimmed := strings.TrimSpace(next.Value)
|
||||||
|
if strings.HasPrefix(trimmed, "if ") || strings.HasPrefix(trimmed, "if(") {
|
||||||
|
return p.parseElseIfFromText(body, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check for @elseif( syntax
|
||||||
|
if next.Type == TokText && strings.Contains(next.Value, "@elseif(") {
|
||||||
|
branch, err := p.parseElseIfBranch(next)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
return body, nil, []*ElseIfBranch{branch}, nil
|
||||||
|
}
|
||||||
|
// It's a plain else — find '{' in text tokens and split
|
||||||
|
if !p.splitTextAtBrace() {
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '{' after else", tok.Pos.Line, tok.Pos.Col)
|
||||||
|
}
|
||||||
|
elseBody, _, _, err := p.parseBlockBodyWithElse("else")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
return body, elseBody, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTextElse processes a TokText token that may contain "else" or "elseif(".
|
||||||
|
func (p *Parser) handleTextElse(tok Token, body []Node) ([]Node, []Node, []*ElseIfBranch, error) {
|
||||||
|
trimmed := strings.TrimSpace(tok.Value)
|
||||||
|
|
||||||
|
// Handle "else if" in a single text token (e.g. "} else if (expr) {")
|
||||||
|
if strings.HasPrefix(trimmed, "else if ") || strings.HasPrefix(trimmed, "else if(") {
|
||||||
|
p.consumeElseKeyword(tok)
|
||||||
|
p.skipWhitespaceText()
|
||||||
|
next := p.cur()
|
||||||
|
if next.Type == TokText {
|
||||||
|
return p.parseElseIfFromText(body, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plain else
|
||||||
|
if strings.HasPrefix(trimmed, "else") && !strings.HasPrefix(trimmed, "elseif(") && !strings.HasPrefix(trimmed, "else if ") && !strings.HasPrefix(trimmed, "else if(") {
|
||||||
|
p.consumeElseKeyword(tok)
|
||||||
|
if !p.splitTextAtBrace() {
|
||||||
|
p.skipToBraceOpen()
|
||||||
|
}
|
||||||
|
elseBody, _, _, err := p.parseBlockBodyWithElse("else")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
return body, elseBody, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// elseif( syntax
|
||||||
|
if strings.HasPrefix(trimmed, "elseif(") {
|
||||||
|
branch, err := p.parseElseIfBranch(tok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
return body, nil, []*ElseIfBranch{branch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// consumeElseKeyword removes the "else" prefix from a text token,
|
||||||
|
// keeping any remaining text as the current token.
|
||||||
|
func (p *Parser) consumeElseKeyword(tok Token) {
|
||||||
|
idx := strings.Index(tok.Value, "else")
|
||||||
|
if idx >= 0 {
|
||||||
|
remainder := tok.Value[idx+4:] // after "else"
|
||||||
|
if len(remainder) > 0 {
|
||||||
|
p.tokens[p.pos] = Token{Type: TokText, Value: remainder, Pos: tok.Pos}
|
||||||
|
} else {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
p.pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseElseIfFromText handles "} else if (expr) { body }" pattern.
|
||||||
|
// next is a TokText whose trimmed value starts with "if(" or "if ".
|
||||||
|
func (p *Parser) parseElseIfFromText(body []Node, next Token) ([]Node, []Node, []*ElseIfBranch, error) {
|
||||||
|
trimmed := strings.TrimSpace(next.Value)
|
||||||
|
|
||||||
|
// Find "if" and skip to the opening paren
|
||||||
|
ifIdx := strings.Index(trimmed, "if")
|
||||||
|
if ifIdx < 0 {
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: expected 'if' in else-if condition", next.Pos.Line, next.Pos.Col)
|
||||||
|
}
|
||||||
|
afterIf := trimmed[ifIdx+2:]
|
||||||
|
// skip whitespace between "if" and "("
|
||||||
|
parenLocalOffset := 0
|
||||||
|
for parenLocalOffset < len(afterIf) && afterIf[parenLocalOffset] == ' ' {
|
||||||
|
parenLocalOffset++
|
||||||
|
}
|
||||||
|
if parenLocalOffset >= len(afterIf) || afterIf[parenLocalOffset] != '(' {
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '(' after 'else if'", next.Pos.Line, next.Pos.Col)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute rune position of '(' in the raw input.
|
||||||
|
// The text token's Pos marks the END of the token text; the start is Pos - len(Value).
|
||||||
|
runePos := runePosFromLineCol(p.input, next.Pos.Line, next.Pos.Col) - len(next.Value)
|
||||||
|
|
||||||
|
// Walk through next.Value byte-by-byte to find the '(' that corresponds to the condition.
|
||||||
|
// We know trimmed starts at some offset into next.Value; find "if" in the raw value.
|
||||||
|
rawIfIdx := strings.Index(next.Value, "if")
|
||||||
|
rawParenOffset := rawIfIdx + 2
|
||||||
|
for rawParenOffset < len(next.Value) && next.Value[rawParenOffset] == ' ' {
|
||||||
|
rawParenOffset++
|
||||||
|
}
|
||||||
|
if rawParenOffset >= len(next.Value) || next.Value[rawParenOffset] != '(' {
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '(' after 'else if'", next.Pos.Line, next.Pos.Col)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parenRunePos points to '(' in the raw input
|
||||||
|
parenRunePos := runePos + rawParenOffset
|
||||||
|
// readUntilParen expects the position AFTER '(' — it reads content between start and closing ')'
|
||||||
|
exprStr, parenClose, err := p.readUntilParen(parenRunePos + 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: unterminated else-if condition, expected ')'", next.Pos.Line, next.Pos.Col)
|
||||||
|
}
|
||||||
|
|
||||||
|
expr, err := NewExprParser(strings.TrimSpace(exprStr), next.Pos.Line, next.Pos.Col).Parse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume tokens from the start of this text token up to past ')'
|
||||||
|
p.consumeTokensForRuneRange(runePos, parenClose+1)
|
||||||
|
|
||||||
|
// Find '{' after ')' in raw input
|
||||||
|
braceOpen := findChar(p.input[parenClose+1:], '{')
|
||||||
|
if braceOpen < 0 {
|
||||||
|
return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '{' after else-if condition", next.Pos.Line, next.Pos.Col)
|
||||||
|
}
|
||||||
|
braceOpen = parenClose + 1 + braceOpen
|
||||||
|
|
||||||
|
// Consume tokens up to past '{'
|
||||||
|
p.consumeTokensForRuneRange(parenClose+1, braceOpen+1)
|
||||||
|
|
||||||
|
// Parse the elseif body — it may itself have else/elseif
|
||||||
|
elseifBody, elseBody, moreBranches, err := p.parseBlockBodyWithElse("elseif")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
branch := &ElseIfBranch{
|
||||||
|
Pos: next.Pos,
|
||||||
|
Cond: expr,
|
||||||
|
Body: elseifBody,
|
||||||
|
}
|
||||||
|
|
||||||
|
allBranches := []*ElseIfBranch{branch}
|
||||||
|
allBranches = append(allBranches, moreBranches...)
|
||||||
|
return body, elseBody, allBranches, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitAtClosingBrace splits text at the first '}' at brace depth 0.
|
||||||
|
// Returns: text before '}', whether '}' was found.
|
||||||
|
func splitAtClosingBrace(text string) (string, bool) {
|
||||||
|
depth := 0
|
||||||
|
for i, ch := range text {
|
||||||
|
if ch == '{' {
|
||||||
|
depth++
|
||||||
|
} else if ch == '}' {
|
||||||
|
if depth == 0 {
|
||||||
|
return text[:i], true
|
||||||
|
}
|
||||||
|
depth--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return text, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWhitespace(s string) bool {
|
||||||
|
return strings.TrimSpace(s) == ""
|
||||||
|
}
|
||||||
40
internal/placeholder.go
Normal file
40
internal/placeholder.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type PlaceholderStyle int
|
||||||
|
|
||||||
|
const (
|
||||||
|
QuestionMark PlaceholderStyle = iota
|
||||||
|
DollarNumber
|
||||||
|
ColonNumber
|
||||||
|
)
|
||||||
|
|
||||||
|
type Placeholder struct {
|
||||||
|
style PlaceholderStyle
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPlaceholder(style PlaceholderStyle) *Placeholder {
|
||||||
|
return &Placeholder{style: style}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Placeholder) Next() string {
|
||||||
|
p.count++
|
||||||
|
switch p.style {
|
||||||
|
case DollarNumber:
|
||||||
|
return fmt.Sprintf("$%d", p.count)
|
||||||
|
case ColonNumber:
|
||||||
|
return fmt.Sprintf(":%d", p.count)
|
||||||
|
default:
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Placeholder) Reset() {
|
||||||
|
p.count = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Placeholder) Count() int {
|
||||||
|
return p.count
|
||||||
|
}
|
||||||
39
safety.go
Normal file
39
safety.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package utpl
|
||||||
|
|
||||||
|
import "slices"
|
||||||
|
|
||||||
|
type RawPolicy interface {
|
||||||
|
Validate(param string, value string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type RawAllowlist map[string][]string
|
||||||
|
|
||||||
|
func (a RawAllowlist) Validate(param string, value string) error {
|
||||||
|
allowed, ok := a[param]
|
||||||
|
if !ok {
|
||||||
|
return &UnsafeRawError{Param: param, Value: value, Message: "no allowlist defined"}
|
||||||
|
}
|
||||||
|
if slices.Contains(allowed, value) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UnsafeRawError{Param: param, Value: value, Message: "value not in allowlist"}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RawBlocklist map[string][]string
|
||||||
|
|
||||||
|
func (b RawBlocklist) Validate(param string, value string) error {
|
||||||
|
blocked, ok := b[param]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if slices.Contains(blocked, value) {
|
||||||
|
return &UnsafeRawError{Param: param, Value: value, Message: "value is blocked"}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type RawNoop struct{}
|
||||||
|
|
||||||
|
func (RawNoop) Validate(string, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
85
template.go
Normal file
85
template.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package utpl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.1216.top/lxy/u-tpl/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
SQL string
|
||||||
|
Args []any
|
||||||
|
}
|
||||||
|
|
||||||
|
type Template struct {
|
||||||
|
name string
|
||||||
|
engine *Engine
|
||||||
|
nodes []internal.Node
|
||||||
|
blocks map[string][]internal.Node
|
||||||
|
hasBlocks bool
|
||||||
|
namespace string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) Execute(vars map[string]any) (*Result, error) {
|
||||||
|
if t.hasBlocks {
|
||||||
|
return nil, &ExecError{
|
||||||
|
Pos: Position{},
|
||||||
|
Message: fmt.Sprintf("template %q has named blocks, use ExecuteBlock instead", t.name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return t.executeNodes(t.nodes, vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) ExecuteBlock(blockName string, vars map[string]any) (*Result, error) {
|
||||||
|
nodes, ok := t.blocks[blockName]
|
||||||
|
if !ok {
|
||||||
|
if t.namespace != "" {
|
||||||
|
nodes, ok = t.blocks[t.namespace+"."+blockName]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, &ExecError{
|
||||||
|
Pos: Position{},
|
||||||
|
Message: fmt.Sprintf("block %q not found in template %q", blockName, t.name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return t.executeNodes(nodes, vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) ExecuteString(vars map[string]any) (string, error) {
|
||||||
|
result, err := t.Execute(vars)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return result.SQL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) ExecuteBlockString(blockName string, vars map[string]any) (string, error) {
|
||||||
|
result, err := t.ExecuteBlock(blockName, vars)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return result.SQL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) executeNodes(nodes []internal.Node, vars map[string]any) (*Result, error) {
|
||||||
|
executor := internal.NewExecutor(t.engine.style, t.engine.rawPolicy, t.engine.strict)
|
||||||
|
result, err := executor.Execute(nodes, vars)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapExecError(err)
|
||||||
|
}
|
||||||
|
return &Result{SQL: result.SQL, Args: result.Args}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapExecError(err error) error {
|
||||||
|
var execErr *ExecError
|
||||||
|
if errors.As(err, &execErr) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var unsafeErr *UnsafeRawError
|
||||||
|
if errors.As(err, &unsafeErr) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return &ExecError{Message: err.Error()}
|
||||||
|
}
|
||||||
944
utpl_ext_test.go
Normal file
944
utpl_ext_test.go
Normal file
@@ -0,0 +1,944 @@
|
|||||||
|
package utpl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------- 1. TestInclude — @include directive ----------
|
||||||
|
//
|
||||||
|
// NOTE: The current implementation creates IncludeNode during parsing but does
|
||||||
|
// not expand included content inline. The IncludeManager is constructed but never
|
||||||
|
// invoked. These tests document the actual behavior. When include expansion is
|
||||||
|
// fully wired up (content resolved before lexing), these tests should be updated
|
||||||
|
// to verify full inline expansion.
|
||||||
|
|
||||||
|
func TestInclude(t *testing.T) {
|
||||||
|
t.Run("include node parses without error when resolver configured", func(t *testing.T) {
|
||||||
|
resolver := func(path string) (string, error) {
|
||||||
|
files := map[string]string{
|
||||||
|
"common/tenant": "AND tenant_id = #{tenant_id}",
|
||||||
|
}
|
||||||
|
src, ok := files[path]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("not found: %s", path)
|
||||||
|
}
|
||||||
|
return src, nil
|
||||||
|
}
|
||||||
|
src := `SELECT * FROM orders WHERE 1=1 @include("common/tenant")`
|
||||||
|
_, err := New(WithIncludeResolver(resolver)).Parse("test", src)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse should succeed, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("include without resolver returns error", func(t *testing.T) {
|
||||||
|
src := `SELECT 1 @include("anything")`
|
||||||
|
_, err := New().Parse("test", src)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when @include used without resolver")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("include content is expanded in output", func(t *testing.T) {
|
||||||
|
resolver := func(path string) (string, error) {
|
||||||
|
if path == "x" {
|
||||||
|
return "RESOLVED_CONTENT", nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("not found: %s", path)
|
||||||
|
}
|
||||||
|
src := `BEFORE @include("x") AFTER`
|
||||||
|
tpl, _ := New(WithIncludeResolver(resolver)).Parse("test", src)
|
||||||
|
r, err := tpl.Execute(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "RESOLVED_CONTENT") {
|
||||||
|
t.Errorf("SQL = %q, should contain resolved content", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "BEFORE") || !strings.Contains(r.SQL, "AFTER") {
|
||||||
|
t.Errorf("SQL = %q, should contain BEFORE and AFTER text", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("include with resolver error fails at parse", func(t *testing.T) {
|
||||||
|
resolver := func(path string) (string, error) {
|
||||||
|
return "", fmt.Errorf("not found: %s", path)
|
||||||
|
}
|
||||||
|
src := `SELECT 1 @include("nonexistent")`
|
||||||
|
_, err := New(WithIncludeResolver(resolver)).Parse("test", src)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when include resolver fails")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 2. TestElseIf — else if / else branches ----------
|
||||||
|
|
||||||
|
func TestElseIf(t *testing.T) {
|
||||||
|
t.Run("if with else branch", func(t *testing.T) {
|
||||||
|
src := "@if(x == 1) {ONE} else {OTHER}"
|
||||||
|
tests := []struct {
|
||||||
|
val any
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{1, "ONE"},
|
||||||
|
{2, "OTHER"},
|
||||||
|
{99, "OTHER"},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
tpl, _ := New().Parse("test", src)
|
||||||
|
r, err := tpl.Execute(map[string]any{"x": tc.val})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x=%v: %v", tc.val, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, tc.want) {
|
||||||
|
t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("if else if else chain", func(t *testing.T) {
|
||||||
|
src := `@if(role == "admin") {ADMIN} else if (role == "manager") {MGR} else {OTHER}`
|
||||||
|
tests := []struct {
|
||||||
|
role string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"admin", "ADMIN"},
|
||||||
|
{"manager", "MGR"},
|
||||||
|
{"guest", "OTHER"},
|
||||||
|
{"other", "OTHER"},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
tpl, _ := New().Parse("test", src)
|
||||||
|
r, err := tpl.Execute(map[string]any{"role": tc.role})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("role=%q: %v", tc.role, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, tc.want) {
|
||||||
|
t.Errorf("role=%q: SQL = %q, want to contain %q", tc.role, r.SQL, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple else if branches", func(t *testing.T) {
|
||||||
|
src := `@if(x == 1) {ONE} else if (x == 2) {TWO} else if (x == 3) {THREE} else {OTHER}`
|
||||||
|
tests := []struct {
|
||||||
|
val any
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{1, "ONE"},
|
||||||
|
{2, "TWO"},
|
||||||
|
{3, "THREE"},
|
||||||
|
{99, "OTHER"},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
tpl, _ := New().Parse("test", src)
|
||||||
|
r, err := tpl.Execute(map[string]any{"x": tc.val})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x=%v: %v", tc.val, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, tc.want) {
|
||||||
|
t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("else if without final else", func(t *testing.T) {
|
||||||
|
src := `@if(x == 1) {ONE} else if (x == 2) {TWO}`
|
||||||
|
tests := []struct {
|
||||||
|
val any
|
||||||
|
want string
|
||||||
|
found bool
|
||||||
|
}{
|
||||||
|
{1, "ONE", true},
|
||||||
|
{2, "TWO", true},
|
||||||
|
{99, "", false},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
tpl, _ := New().Parse("test", src)
|
||||||
|
r, err := tpl.Execute(map[string]any{"x": tc.val})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x=%v: %v", tc.val, err)
|
||||||
|
}
|
||||||
|
if tc.found && !strings.Contains(r.SQL, tc.want) {
|
||||||
|
t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want)
|
||||||
|
}
|
||||||
|
if !tc.found && r.SQL != "" {
|
||||||
|
t.Errorf("x=%v: SQL = %q, want empty", tc.val, r.SQL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("else if with params in branches", func(t *testing.T) {
|
||||||
|
src := `@if(role == "admin") {level = #{admin_level}} else if (role == "manager") {level = #{mgr_level}} else {level = #{default_level}}`
|
||||||
|
tests := []struct {
|
||||||
|
role string
|
||||||
|
want string
|
||||||
|
argVal any
|
||||||
|
}{
|
||||||
|
{"admin", "level = ?", 10},
|
||||||
|
{"manager", "level = ?", 5},
|
||||||
|
{"guest", "level = ?", 1},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
tpl, parseErr := New().Parse("test", src)
|
||||||
|
if parseErr != nil {
|
||||||
|
t.Fatalf("role=%q: parse failed: %v", tc.role, parseErr)
|
||||||
|
}
|
||||||
|
r, err := tpl.Execute(map[string]any{
|
||||||
|
"role": tc.role,
|
||||||
|
"admin_level": 10,
|
||||||
|
"mgr_level": 5,
|
||||||
|
"default_level": 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("role=%q: %v", tc.role, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, tc.want) {
|
||||||
|
t.Errorf("role=%q: SQL = %q, want to contain %q", tc.role, r.SQL, tc.want)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != tc.argVal {
|
||||||
|
t.Errorf("role=%q: Args = %v, want [%v]", tc.role, r.Args, tc.argVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("else if with no space before paren", func(t *testing.T) {
|
||||||
|
src := "@if(x == 1) {ONE} else if(x == 2) {TWO} else {OTHER}"
|
||||||
|
tests := []struct {
|
||||||
|
val any
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{1, "ONE"},
|
||||||
|
{2, "TWO"},
|
||||||
|
{99, "OTHER"},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
tpl, _ := New().Parse("test", src)
|
||||||
|
r, err := tpl.Execute(map[string]any{"x": tc.val})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x=%v: %v", tc.val, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, tc.want) {
|
||||||
|
t.Errorf("x=%v: SQL = %q, want to contain %q", tc.val, r.SQL, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("else if with multiline SQL", func(t *testing.T) {
|
||||||
|
src := `SELECT * FROM users WHERE 1=1
|
||||||
|
@if(role == "admin") {
|
||||||
|
AND level >= #{admin_level}
|
||||||
|
} else if (role == "manager") {
|
||||||
|
AND level >= #{mgr_level}
|
||||||
|
} else {
|
||||||
|
AND level >= 1
|
||||||
|
}`
|
||||||
|
r := exec(t, src, map[string]any{
|
||||||
|
"role": "manager",
|
||||||
|
"admin_level": 10,
|
||||||
|
"mgr_level": 5,
|
||||||
|
})
|
||||||
|
if !strings.Contains(r.SQL, "AND level >= ?") {
|
||||||
|
t.Errorf("SQL = %q, want 'AND level >= ?'", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != 5 {
|
||||||
|
t.Errorf("Args = %v, want [5]", r.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, src, map[string]any{
|
||||||
|
"role": "admin",
|
||||||
|
"admin_level": 10,
|
||||||
|
"mgr_level": 5,
|
||||||
|
})
|
||||||
|
if !strings.Contains(r2.SQL, "AND level >= ?") {
|
||||||
|
t.Errorf("SQL = %q, want 'AND level >= ?'", r2.SQL)
|
||||||
|
}
|
||||||
|
if len(r2.Args) != 1 || r2.Args[0] != 10 {
|
||||||
|
t.Errorf("Args = %v, want [10]", r2.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
r3 := exec(t, src, map[string]any{
|
||||||
|
"role": "guest",
|
||||||
|
"admin_level": 10,
|
||||||
|
"mgr_level": 5,
|
||||||
|
})
|
||||||
|
if !strings.Contains(r3.SQL, "AND level >= 1") {
|
||||||
|
t.Errorf("SQL = %q, want 'AND level >= 1'", r3.SQL)
|
||||||
|
}
|
||||||
|
if len(r3.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r3.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("if without else produces no output when false", func(t *testing.T) {
|
||||||
|
src := "@if(x == 1) {ONE}"
|
||||||
|
tpl, _ := New().Parse("test", src)
|
||||||
|
|
||||||
|
r, _ := tpl.Execute(map[string]any{"x": 1})
|
||||||
|
if !strings.Contains(r.SQL, "ONE") {
|
||||||
|
t.Errorf("x=1: SQL = %q, want ONE", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2, _ := tpl.Execute(map[string]any{"x": 99})
|
||||||
|
if strings.Contains(r2.SQL, "ONE") {
|
||||||
|
t.Errorf("x=99: SQL = %q, should not contain ONE", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 3. TestExpression — expression edge cases ----------
|
||||||
|
|
||||||
|
func TestExpression(t *testing.T) {
|
||||||
|
t.Run("len function with []any non-empty", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(len(ids) > 0) {has ids}", map[string]any{"ids": []any{1, 2}})
|
||||||
|
if !strings.Contains(r.SQL, "has ids") {
|
||||||
|
t.Errorf("SQL = %q, want 'has ids'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("len function with []any empty", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(len(ids) > 0) {has ids}", map[string]any{"ids": []any{}})
|
||||||
|
if strings.Contains(r.SQL, "has ids") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'has ids'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("len function with string", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(len(name) > 0) {has name}", map[string]any{"name": "alice"})
|
||||||
|
if !strings.Contains(r.SQL, "has name") {
|
||||||
|
t.Errorf("SQL = %q, want 'has name'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "@if(len(name) > 0) {has name}", map[string]any{"name": ""})
|
||||||
|
if strings.Contains(r2.SQL, "has name") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'has name'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("numeric comparison ge", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(age >= 18) {adult}", map[string]any{"age": 20})
|
||||||
|
if !strings.Contains(r.SQL, "adult") {
|
||||||
|
t.Errorf("age=20: SQL = %q, want 'adult'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "@if(age >= 18) {adult}", map[string]any{"age": 10})
|
||||||
|
if strings.Contains(r2.SQL, "adult") {
|
||||||
|
t.Errorf("age=10: SQL = %q, should not contain 'adult'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("boolean variable true", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(flag) {yes}", map[string]any{"flag": true})
|
||||||
|
if !strings.Contains(r.SQL, "yes") {
|
||||||
|
t.Errorf("flag=true: SQL = %q, want 'yes'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("boolean variable false", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(flag) {yes}", map[string]any{"flag": false})
|
||||||
|
if strings.Contains(r.SQL, "yes") {
|
||||||
|
t.Errorf("flag=false: SQL = %q, should not contain 'yes'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negation true", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(!flag) {no flag}", map[string]any{"flag": true})
|
||||||
|
if strings.Contains(r.SQL, "no flag") {
|
||||||
|
t.Errorf("flag=true: SQL = %q, should not contain 'no flag'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("negation false", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(!flag) {no flag}", map[string]any{"flag": false})
|
||||||
|
if !strings.Contains(r.SQL, "no flag") {
|
||||||
|
t.Errorf("flag=false: SQL = %q, want 'no flag'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested dot path in condition", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(user.active) {active}", map[string]any{
|
||||||
|
"user": map[string]any{"active": true},
|
||||||
|
})
|
||||||
|
if !strings.Contains(r.SQL, "active") {
|
||||||
|
t.Errorf("SQL = %q, want 'active'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "@if(user.active) {active}", map[string]any{
|
||||||
|
"user": map[string]any{"active": false},
|
||||||
|
})
|
||||||
|
if strings.Contains(r2.SQL, "active") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'active'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("compare to nil true", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(val == nil) {is nil}", map[string]any{"val": nil})
|
||||||
|
if !strings.Contains(r.SQL, "is nil") {
|
||||||
|
t.Errorf("SQL = %q, want 'is nil'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("compare to nil false", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(val == nil) {is nil}", map[string]any{"val": 42})
|
||||||
|
if strings.Contains(r.SQL, "is nil") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'is nil'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("compare nil not equal true", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(val != nil) {not nil}", map[string]any{"val": 42})
|
||||||
|
if !strings.Contains(r.SQL, "not nil") {
|
||||||
|
t.Errorf("SQL = %q, want 'not nil'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("compare nil not equal false", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(val != nil) {not nil}", map[string]any{"val": nil})
|
||||||
|
if strings.Contains(r.SQL, "not nil") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'not nil'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("string equality comparison", func(t *testing.T) {
|
||||||
|
r := exec(t, `@if(role == "admin") {is admin}`, map[string]any{"role": "admin"})
|
||||||
|
if !strings.Contains(r.SQL, "is admin") {
|
||||||
|
t.Errorf("SQL = %q, want 'is admin'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, `@if(role == "admin") {is admin}`, map[string]any{"role": "user"})
|
||||||
|
if strings.Contains(r2.SQL, "is admin") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'is admin'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("numeric less than comparison", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(count < 100) {under limit}", map[string]any{"count": 50})
|
||||||
|
if !strings.Contains(r.SQL, "under limit") {
|
||||||
|
t.Errorf("SQL = %q, want 'under limit'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "@if(count < 100) {under limit}", map[string]any{"count": 150})
|
||||||
|
if strings.Contains(r2.SQL, "under limit") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'under limit'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("and operator short circuit", func(t *testing.T) {
|
||||||
|
// When left is false, right should not cause error even if undefined
|
||||||
|
r := exec(t, "@if(a != nil && b != \"\") {both}",
|
||||||
|
map[string]any{"a": nil})
|
||||||
|
if strings.Contains(r.SQL, "both") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'both' when a is nil", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 4. TestErrorPaths — error handling ----------
|
||||||
|
|
||||||
|
func TestErrorPaths(t *testing.T) {
|
||||||
|
t.Run("unterminated param", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", "SELECT #{id")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unterminated param, got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unterminated if", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", "@if(x > 0")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unterminated @if, got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unterminated for", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", "@for(x, range list)")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unterminated @for (missing {), got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty tpl name", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", `@tpl("") { SELECT 1 }`)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty @tpl name, got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "empty") {
|
||||||
|
t.Errorf("error should mention empty: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("duplicate tpl names", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", `@tpl("a") {X} @tpl("a") {Y}`)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for duplicate @tpl names, got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "duplicate") {
|
||||||
|
t.Errorf("error should mention duplicate: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("namespace not at top", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", `SELECT 1 @namespace("x")`)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for @namespace not at top, got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "top") {
|
||||||
|
t.Errorf("error should mention top: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tpl nested inside tpl", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", `@tpl("a") { @tpl("b") {X} }`)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nested @tpl, got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "nested") {
|
||||||
|
t.Errorf("error should mention nested: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("undefined variable in strict mode", func(t *testing.T) {
|
||||||
|
tpl, _ := New(WithStrictMode(true)).Parse("test", "SELECT #{missing}")
|
||||||
|
_, err := tpl.Execute(map[string]any{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for undefined variable in strict mode")
|
||||||
|
}
|
||||||
|
// The executor returns a plain error (not wrapped in ExecError)
|
||||||
|
if !strings.Contains(err.Error(), "undefined") {
|
||||||
|
t.Errorf("error should mention undefined: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("for with non-slice value", func(t *testing.T) {
|
||||||
|
tpl, err := New().Parse("test", "SELECT @for(x, range items) {#{x}, }")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
_, err = tpl.Execute(map[string]any{"items": "not a slice"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-slice in @for, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "slice") {
|
||||||
|
t.Errorf("error should mention slice: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unterminated raw param", func(t *testing.T) {
|
||||||
|
_, err := New().Parse("test", "SELECT ${col")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unterminated raw param, got nil")
|
||||||
|
}
|
||||||
|
var parseErr *ParseError
|
||||||
|
if !errors.As(err, &parseErr) {
|
||||||
|
t.Errorf("expected ParseError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 5. TestEdgeCases — boundary conditions ----------
|
||||||
|
|
||||||
|
func TestEdgeCases(t *testing.T) {
|
||||||
|
t.Run("empty template", func(t *testing.T) {
|
||||||
|
r := exec(t, "", nil)
|
||||||
|
if r.SQL != "" {
|
||||||
|
t.Errorf("SQL = %q, want empty", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pure text", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT 1", nil)
|
||||||
|
if r.SQL != "SELECT 1" {
|
||||||
|
t.Errorf("SQL = %q, want %q", r.SQL, "SELECT 1")
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("only comments", func(t *testing.T) {
|
||||||
|
r := exec(t, "# comment", nil)
|
||||||
|
if r.SQL != "" {
|
||||||
|
t.Errorf("SQL = %q, want empty", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple consecutive comments", func(t *testing.T) {
|
||||||
|
r := exec(t, "# line1\n# line2\nSELECT 1", nil)
|
||||||
|
if r.SQL != "SELECT 1" {
|
||||||
|
t.Errorf("SQL = %q, want %q", r.SQL, "SELECT 1")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mixed params and raw", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT ${col} FROM t WHERE id = #{id}",
|
||||||
|
map[string]any{"col": "name", "id": 1})
|
||||||
|
if !strings.Contains(r.SQL, "SELECT name FROM t WHERE id = ?") {
|
||||||
|
t.Errorf("SQL = %q, want 'SELECT name FROM t WHERE id = ?'", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != 1 {
|
||||||
|
t.Errorf("Args = %v, want [1]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested if both true", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(a) { @if(b) {both} }",
|
||||||
|
map[string]any{"a": true, "b": true})
|
||||||
|
if !strings.Contains(r.SQL, "both") {
|
||||||
|
t.Errorf("SQL = %q, want 'both'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested if inner false", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(a) { @if(b) {both} }",
|
||||||
|
map[string]any{"a": true, "b": false})
|
||||||
|
if strings.Contains(r.SQL, "both") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'both'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested if outer false", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(a) { @if(b) {both} }",
|
||||||
|
map[string]any{"a": false, "b": true})
|
||||||
|
if strings.Contains(r.SQL, "both") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'both'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("if with parenthesized expression", func(t *testing.T) {
|
||||||
|
src := "@if((a || b) && c) {yes}"
|
||||||
|
tpl, _ := New().Parse("test", src)
|
||||||
|
|
||||||
|
r, _ := tpl.Execute(map[string]any{"a": true, "b": false, "c": true})
|
||||||
|
if !strings.Contains(r.SQL, "yes") {
|
||||||
|
t.Errorf("a=true,b=false,c=true: SQL = %q, want 'yes'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _ = tpl.Execute(map[string]any{"a": false, "b": true, "c": true})
|
||||||
|
if !strings.Contains(r.SQL, "yes") {
|
||||||
|
t.Errorf("a=false,b=true,c=true: SQL = %q, want 'yes'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _ = tpl.Execute(map[string]any{"a": false, "b": false, "c": true})
|
||||||
|
if strings.Contains(r.SQL, "yes") {
|
||||||
|
t.Errorf("a=false,b=false,c=true: SQL = %q, should not contain 'yes'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _ = tpl.Execute(map[string]any{"a": true, "b": true, "c": false})
|
||||||
|
if strings.Contains(r.SQL, "yes") {
|
||||||
|
t.Errorf("a=true,b=true,c=false: SQL = %q, should not contain 'yes'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("for with index variable", func(t *testing.T) {
|
||||||
|
r := exec(t, "@for(i, v range items) {#{i}:#{v}, }",
|
||||||
|
map[string]any{"items": []any{"x", "y", "z"}})
|
||||||
|
if len(r.Args) != 6 {
|
||||||
|
t.Fatalf("Args = %v, want 6 args (3 indices + 3 values)", r.Args)
|
||||||
|
}
|
||||||
|
// Indices: 0, 1, 2
|
||||||
|
if r.Args[0] != 0 || r.Args[2] != 1 || r.Args[4] != 2 {
|
||||||
|
t.Errorf("index args wrong: %v", r.Args)
|
||||||
|
}
|
||||||
|
// Values: x, y, z
|
||||||
|
if r.Args[1] != "x" || r.Args[3] != "y" || r.Args[5] != "z" {
|
||||||
|
t.Errorf("value args wrong: %v", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("comment does not interfere with param", func(t *testing.T) {
|
||||||
|
// #{param} starts with # but is not a comment
|
||||||
|
r := exec(t, "SELECT #{id}\n# real comment",
|
||||||
|
map[string]any{"id": 42})
|
||||||
|
if !strings.Contains(r.SQL, "SELECT ?") {
|
||||||
|
t.Errorf("SQL = %q, want 'SELECT ?'", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != 42 {
|
||||||
|
t.Errorf("Args = %v, want [42]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw with numeric value", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT ${n}", map[string]any{"n": 123})
|
||||||
|
if !strings.Contains(r.SQL, "SELECT 123") {
|
||||||
|
t.Errorf("SQL = %q, want 'SELECT 123'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple consecutive params", func(t *testing.T) {
|
||||||
|
r := exec(t, "INSERT INTO t (a, b, c) VALUES (#{a}, #{b}, #{c})",
|
||||||
|
map[string]any{"a": 1, "b": 2, "c": 3})
|
||||||
|
if !strings.Contains(r.SQL, "?, ?, ?") {
|
||||||
|
t.Errorf("SQL = %q, want three placeholders", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 3 {
|
||||||
|
t.Errorf("Args = %v, want 3", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 6. TestMultiline — real SQL with line breaks ----------
|
||||||
|
|
||||||
|
func TestMultiline(t *testing.T) {
|
||||||
|
t.Run("full SELECT with multiple lines", func(t *testing.T) {
|
||||||
|
src := `SELECT u.id, u.name, u.email
|
||||||
|
FROM users u
|
||||||
|
WHERE u.status = #{status}
|
||||||
|
@if(name != "") { AND u.name LIKE #{name} }
|
||||||
|
ORDER BY u.id DESC
|
||||||
|
LIMIT #{limit}`
|
||||||
|
tpl, err := New().Parse("test", src)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
r, err := tpl.Execute(map[string]any{
|
||||||
|
"status": "active",
|
||||||
|
"name": "%alice%",
|
||||||
|
"limit": 10,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(r.SQL, "SELECT u.id, u.name, u.email") {
|
||||||
|
t.Errorf("SQL missing SELECT: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "FROM users u") {
|
||||||
|
t.Errorf("SQL missing FROM: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "WHERE u.status = ?") {
|
||||||
|
t.Errorf("SQL missing WHERE: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "AND u.name LIKE ?") {
|
||||||
|
t.Errorf("SQL missing name condition: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "ORDER BY u.id DESC") {
|
||||||
|
t.Errorf("SQL missing ORDER BY: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "LIMIT ?") {
|
||||||
|
t.Errorf("SQL missing LIMIT: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 3 || r.Args[0] != "active" || r.Args[1] != "%alice%" || r.Args[2] != 10 {
|
||||||
|
t.Errorf("Args = %v, want [active %%alice%% 10]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("line breaks preserved", func(t *testing.T) {
|
||||||
|
src := "SELECT id\nFROM users\nWHERE id = #{id}"
|
||||||
|
r := exec(t, src, map[string]any{"id": 1})
|
||||||
|
lines := strings.Split(r.SQL, "\n")
|
||||||
|
if len(lines) < 3 {
|
||||||
|
t.Errorf("SQL = %q, want at least 3 lines", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "SELECT id\nFROM users\nWHERE id = ?") {
|
||||||
|
t.Errorf("SQL = %q, line breaks not preserved", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiline dynamic update", func(t *testing.T) {
|
||||||
|
src := `UPDATE users SET
|
||||||
|
@if(name != nil) { name = #{name}, }
|
||||||
|
@if(email != nil) { email = #{email}, }
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = #{id}`
|
||||||
|
r := exec(t, src, map[string]any{"name": "alice", "email": nil, "id": 1})
|
||||||
|
if !strings.Contains(r.SQL, "name = ?") {
|
||||||
|
t.Errorf("SQL missing name: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if strings.Contains(r.SQL, "email = ?") {
|
||||||
|
t.Errorf("SQL should not contain email: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "WHERE id = ?") {
|
||||||
|
t.Errorf("SQL missing WHERE: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 2 {
|
||||||
|
t.Errorf("Args = %v, want 2 args", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 7. TestExecuteString — convenience methods ----------
|
||||||
|
|
||||||
|
func TestExecuteString(t *testing.T) {
|
||||||
|
t.Run("ExecuteString returns SQL only", func(t *testing.T) {
|
||||||
|
tpl, err := New().Parse("test", "SELECT * FROM users WHERE id = #{id}")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
sql, err := tpl.ExecuteString(map[string]any{"id": 42})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteString failed: %v", err)
|
||||||
|
}
|
||||||
|
if sql != "SELECT * FROM users WHERE id = ?" {
|
||||||
|
t.Errorf("SQL = %q, want %q", sql, "SELECT * FROM users WHERE id = ?")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecuteBlockString returns SQL only", func(t *testing.T) {
|
||||||
|
tpl, err := New().Parse("test", "@tpl(\"search\") {\nSELECT * FROM orders WHERE id = #{id}\n}")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
sql, err := tpl.ExecuteBlockString("search", map[string]any{"id": 7})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteBlockString failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(sql, "SELECT * FROM orders WHERE id = ?") {
|
||||||
|
t.Errorf("SQL = %q, want to contain 'SELECT * FROM orders WHERE id = ?'", sql)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecuteString error propagation", func(t *testing.T) {
|
||||||
|
tpl, _ := New(WithStrictMode(true)).Parse("test", "SELECT #{missing}")
|
||||||
|
_, err := tpl.ExecuteString(map[string]any{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for undefined variable, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecuteBlockString error for missing block", func(t *testing.T) {
|
||||||
|
tpl, _ := New().Parse("test", `@tpl("search") {SELECT 1}`)
|
||||||
|
_, err := tpl.ExecuteBlockString("nonexistent", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for missing block, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 8. TestBenchmarkConsistency ----------
|
||||||
|
|
||||||
|
func TestBenchmarkConsistency(t *testing.T) {
|
||||||
|
t.Run("simple benchmark scenario", func(t *testing.T) {
|
||||||
|
tpl := New().MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}")
|
||||||
|
result, err := tpl.Execute(map[string]any{"id": 42, "name": "alice"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.SQL != "SELECT * FROM users WHERE id = ? AND name = ?" {
|
||||||
|
t.Errorf("SQL = %q", result.SQL)
|
||||||
|
}
|
||||||
|
if len(result.Args) != 2 || result.Args[0] != 42 || result.Args[1] != "alice" {
|
||||||
|
t.Errorf("Args = %v, want [42 alice]", result.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("conditional benchmark scenario", func(t *testing.T) {
|
||||||
|
src := `SELECT * FROM users WHERE 1=1
|
||||||
|
@if(status != nil) { AND status = #{status} }
|
||||||
|
@if(name != "") { AND name = #{name} }
|
||||||
|
ORDER BY id`
|
||||||
|
tpl := New().MustParse("bench", src)
|
||||||
|
|
||||||
|
// all conditions active
|
||||||
|
result, err := tpl.Execute(map[string]any{"status": "active", "name": "alice"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.SQL, "AND status = ?") {
|
||||||
|
t.Errorf("SQL missing status: %q", result.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.SQL, "AND name = ?") {
|
||||||
|
t.Errorf("SQL missing name: %q", result.SQL)
|
||||||
|
}
|
||||||
|
if len(result.Args) != 2 {
|
||||||
|
t.Errorf("Args = %v, want 2 args", result.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// no conditions
|
||||||
|
result2, err := tpl.Execute(map[string]any{"status": nil, "name": ""})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if strings.Contains(result2.SQL, "AND") {
|
||||||
|
t.Errorf("SQL should have no AND: %q", result2.SQL)
|
||||||
|
}
|
||||||
|
if len(result2.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", result2.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("loop benchmark scenario", func(t *testing.T) {
|
||||||
|
src := `SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })`
|
||||||
|
tpl := New().MustParse("bench", src)
|
||||||
|
ids := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||||
|
result, err := tpl.Execute(map[string]any{"ids": ids})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(result.Args) != 10 {
|
||||||
|
t.Errorf("Args = %v, want 10 args", result.Args)
|
||||||
|
}
|
||||||
|
for i, want := range ids {
|
||||||
|
if result.Args[i] != want {
|
||||||
|
t.Errorf("Args[%d] = %v, want %v", i, result.Args[i], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(result.SQL, ",") {
|
||||||
|
t.Errorf("SQL should not end with comma: %q", result.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("placeholder dollar benchmark scenario", func(t *testing.T) {
|
||||||
|
eng := New(WithPlaceholderStyle(DollarNumber))
|
||||||
|
tpl := eng.MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}")
|
||||||
|
result, err := tpl.Execute(map[string]any{"id": 42, "name": "alice"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.SQL, "WHERE id = $1 AND name = $2") {
|
||||||
|
t.Errorf("SQL = %q, want $1/$2 placeholders", result.SQL)
|
||||||
|
}
|
||||||
|
if len(result.Args) != 2 {
|
||||||
|
t.Errorf("Args = %v, want 2 args", result.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("placeholder colon benchmark scenario", func(t *testing.T) {
|
||||||
|
eng := New(WithPlaceholderStyle(ColonNumber))
|
||||||
|
tpl := eng.MustParse("bench", "WHERE id = #{id} AND name = #{name}")
|
||||||
|
result, err := tpl.Execute(map[string]any{"id": 1, "name": "a"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.SQL, "WHERE id = :1 AND name = :2") {
|
||||||
|
t.Errorf("SQL = %q, want :1/:2 placeholders", result.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
660
utpl_test.go
Normal file
660
utpl_test.go
Normal file
@@ -0,0 +1,660 @@
|
|||||||
|
package utpl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// helper to parse with default engine
|
||||||
|
func parse(t *testing.T, source string) *Template {
|
||||||
|
t.Helper()
|
||||||
|
tpl, err := New().Parse("test", source)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
return tpl
|
||||||
|
}
|
||||||
|
|
||||||
|
// helper to parse with options and execute
|
||||||
|
func exec(t *testing.T, source string, vars map[string]any, opts ...Option) *Result {
|
||||||
|
t.Helper()
|
||||||
|
tpl, err := New(opts...).Parse("test", source)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
result, err := tpl.Execute(vars)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 1. TestParam ----------
|
||||||
|
|
||||||
|
func TestParam(t *testing.T) {
|
||||||
|
t.Run("single param", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE id = #{id}", map[string]any{"id": 42})
|
||||||
|
if r.SQL != "SELECT * FROM users WHERE id = ?" {
|
||||||
|
t.Errorf("SQL = %q, want %q", r.SQL, "SELECT * FROM users WHERE id = ?")
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != 42 {
|
||||||
|
t.Errorf("Args = %v, want [42]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple params", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE id = #{id} AND name = #{name}",
|
||||||
|
map[string]any{"id": 1, "name": "alice"})
|
||||||
|
if r.SQL != "SELECT * FROM users WHERE id = ? AND name = ?" {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 2 || r.Args[0] != 1 || r.Args[1] != "alice" {
|
||||||
|
t.Errorf("Args = %v, want [1 alice]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil value param", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE name = #{name}", map[string]any{"name": nil})
|
||||||
|
if r.SQL != "SELECT * FROM users WHERE name = ?" {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != nil {
|
||||||
|
t.Errorf("Args = %v, want [nil]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("param with nested dot path", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE name = #{user.name}",
|
||||||
|
map[string]any{"user": map[string]any{"name": "bob"}})
|
||||||
|
if r.SQL != "SELECT * FROM users WHERE name = ?" {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != "bob" {
|
||||||
|
t.Errorf("Args = %v, want [bob]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 2. TestRaw ----------
|
||||||
|
|
||||||
|
func TestRaw(t *testing.T) {
|
||||||
|
t.Run("basic raw substitution", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"})
|
||||||
|
if r.SQL != "SELECT name FROM users" {
|
||||||
|
t.Errorf("SQL = %q, want %q", r.SQL, "SELECT name FROM users")
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple raw params", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT ${a}, ${b} FROM ${table}",
|
||||||
|
map[string]any{"a": "id", "b": "name", "table": "users"})
|
||||||
|
if r.SQL != "SELECT id, name FROM users" {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allowlist pass", func(t *testing.T) {
|
||||||
|
policy := RawAllowlist{
|
||||||
|
"col": {"name", "age"},
|
||||||
|
}
|
||||||
|
r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"},
|
||||||
|
WithRawPolicy(policy))
|
||||||
|
if r.SQL != "SELECT name FROM users" {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allowlist reject", func(t *testing.T) {
|
||||||
|
policy := RawAllowlist{
|
||||||
|
"col": {"name", "age"},
|
||||||
|
}
|
||||||
|
tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col} FROM users")
|
||||||
|
_, err := tpl.Execute(map[string]any{"col": "password"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for blocked raw value, got nil")
|
||||||
|
}
|
||||||
|
var unsafeErr *UnsafeRawError
|
||||||
|
if !errors.As(err, &unsafeErr) {
|
||||||
|
t.Errorf("expected UnsafeRawError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blocklist allow", func(t *testing.T) {
|
||||||
|
policy := RawBlocklist{
|
||||||
|
"col": {"password"},
|
||||||
|
}
|
||||||
|
r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"},
|
||||||
|
WithRawPolicy(policy))
|
||||||
|
if r.SQL != "SELECT name FROM users" {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blocklist reject", func(t *testing.T) {
|
||||||
|
policy := RawBlocklist{
|
||||||
|
"col": {"password", "secret"},
|
||||||
|
}
|
||||||
|
tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col} FROM users")
|
||||||
|
_, err := tpl.Execute(map[string]any{"col": "password"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for blocklisted value, got nil")
|
||||||
|
}
|
||||||
|
var unsafeErr *UnsafeRawError
|
||||||
|
if !errors.As(err, &unsafeErr) {
|
||||||
|
t.Errorf("expected UnsafeRawError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 3. TestComment ----------
|
||||||
|
|
||||||
|
func TestComment(t *testing.T) {
|
||||||
|
t.Run("comment stripped", func(t *testing.T) {
|
||||||
|
r := exec(t, "# this is a comment\nSELECT 1", nil)
|
||||||
|
if r.SQL != "SELECT 1" {
|
||||||
|
t.Errorf("SQL = %q, want %q", r.SQL, "SELECT 1")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("comment in middle", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT 1\n# comment\nSELECT 2", nil)
|
||||||
|
want := "SELECT 1\nSELECT 2"
|
||||||
|
if r.SQL != want {
|
||||||
|
t.Errorf("SQL = %q, want %q", r.SQL, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no false positive with param", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT #{id} FROM users", map[string]any{"id": 1})
|
||||||
|
if r.SQL != "SELECT ? FROM users" {
|
||||||
|
t.Errorf("SQL = %q, want %q", r.SQL, "SELECT ? FROM users")
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != 1 {
|
||||||
|
t.Errorf("Args = %v", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 4. TestIf ----------
|
||||||
|
|
||||||
|
func TestIf(t *testing.T) {
|
||||||
|
t.Run("true condition", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }",
|
||||||
|
map[string]any{"status": "active"})
|
||||||
|
if !strings.Contains(r.SQL, "AND status = ?") {
|
||||||
|
t.Errorf("SQL = %q, want to contain 'AND status = ?'", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != "active" {
|
||||||
|
t.Errorf("Args = %v, want [active]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("false condition nil", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }",
|
||||||
|
map[string]any{"status": nil})
|
||||||
|
if strings.Contains(r.SQL, "AND status") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'AND status'", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("else branch", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users ORDER BY id @if(desc) {DESC} else {ASC}",
|
||||||
|
map[string]any{"desc": true})
|
||||||
|
if !strings.Contains(r.SQL, "DESC") {
|
||||||
|
t.Errorf("SQL = %q, want DESC", r.SQL)
|
||||||
|
}
|
||||||
|
if strings.Contains(r.SQL, "ASC") {
|
||||||
|
t.Errorf("SQL = %q, should not contain ASC", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "SELECT * FROM users ORDER BY id @if(desc) {DESC} else {ASC}",
|
||||||
|
map[string]any{"desc": false})
|
||||||
|
if !strings.Contains(r2.SQL, "ASC") {
|
||||||
|
t.Errorf("SQL = %q, want ASC", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty string check", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE 1=1 @if(name != \"\") { AND name = #{name} }",
|
||||||
|
map[string]any{"name": "alice"})
|
||||||
|
if !strings.Contains(r.SQL, "AND name = ?") {
|
||||||
|
t.Errorf("SQL = %q, want 'AND name = ?'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "SELECT * FROM users WHERE 1=1 @if(name != \"\") { AND name = #{name} }",
|
||||||
|
map[string]any{"name": ""})
|
||||||
|
if strings.Contains(r2.SQL, "AND name") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'AND name'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple independent conditions", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(a != nil) {A} @if(b != nil) {B}",
|
||||||
|
map[string]any{"a": 1, "b": 2})
|
||||||
|
if !strings.Contains(r.SQL, "A") || !strings.Contains(r.SQL, "B") {
|
||||||
|
t.Errorf("SQL = %q, want both A and B", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "@if(a != nil) {A} @if(b != nil) {B}",
|
||||||
|
map[string]any{"a": 1})
|
||||||
|
if !strings.Contains(r2.SQL, "A") {
|
||||||
|
t.Errorf("SQL = %q, want A", r2.SQL)
|
||||||
|
}
|
||||||
|
if strings.Contains(r2.SQL, "B") {
|
||||||
|
t.Errorf("SQL = %q, should not contain B", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AND operator", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(a != nil && b != \"\") { both }",
|
||||||
|
map[string]any{"a": 1, "b": "x"})
|
||||||
|
if !strings.Contains(r.SQL, "both") {
|
||||||
|
t.Errorf("SQL = %q, want 'both'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "@if(a != nil && b != \"\") { both }",
|
||||||
|
map[string]any{"a": 1, "b": ""})
|
||||||
|
if strings.Contains(r2.SQL, "both") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'both'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OR operator", func(t *testing.T) {
|
||||||
|
r := exec(t, "@if(a != nil || b != nil) { either }",
|
||||||
|
map[string]any{"a": nil, "b": 1})
|
||||||
|
if !strings.Contains(r.SQL, "either") {
|
||||||
|
t.Errorf("SQL = %q, want 'either'", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, "@if(a != nil || b != nil) { either }",
|
||||||
|
map[string]any{"a": nil, "b": nil})
|
||||||
|
if strings.Contains(r2.SQL, "either") {
|
||||||
|
t.Errorf("SQL = %q, should not contain 'either'", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 5. TestFor ----------
|
||||||
|
|
||||||
|
func TestFor(t *testing.T) {
|
||||||
|
t.Run("basic IN clause", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })",
|
||||||
|
map[string]any{"ids": []int{1, 2, 3}})
|
||||||
|
if len(r.Args) != 3 {
|
||||||
|
t.Fatalf("Args = %v, want 3 args", r.Args)
|
||||||
|
}
|
||||||
|
for i, want := range []int{1, 2, 3} {
|
||||||
|
if r.Args[i] != want {
|
||||||
|
t.Errorf("Args[%d] = %v, want %v", i, r.Args[i], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("trailing comma auto-trim", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })",
|
||||||
|
map[string]any{"ids": []int{1, 2}})
|
||||||
|
// trailing comma and space should be trimmed by the executor
|
||||||
|
if strings.HasSuffix(r.SQL, ",") {
|
||||||
|
t.Errorf("SQL = %q, should not end with comma", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty list", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })",
|
||||||
|
map[string]any{"ids": []int{}})
|
||||||
|
if strings.Contains(r.SQL, "?") {
|
||||||
|
t.Errorf("SQL = %q, should have no placeholders", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil list", func(t *testing.T) {
|
||||||
|
// nil list produces no output (treated as empty slice)
|
||||||
|
tpl := parse(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })")
|
||||||
|
r, err := tpl.Execute(map[string]any{"ids": nil})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if strings.Contains(r.SQL, "?") {
|
||||||
|
t.Errorf("SQL = %q, should have no placeholders", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with index variable", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT @for(i, v range items) {#{v}, }",
|
||||||
|
map[string]any{"items": []string{"a", "b", "c"}})
|
||||||
|
if len(r.Args) != 3 {
|
||||||
|
t.Errorf("Args = %v, want 3 args", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested object access", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT @for(u range users) {#{u.name}, }",
|
||||||
|
map[string]any{"users": []map[string]any{
|
||||||
|
{"name": "alice"},
|
||||||
|
{"name": "bob"},
|
||||||
|
}})
|
||||||
|
if len(r.Args) != 2 {
|
||||||
|
t.Fatalf("Args = %v, want 2 args", r.Args)
|
||||||
|
}
|
||||||
|
if r.Args[0] != "alice" || r.Args[1] != "bob" {
|
||||||
|
t.Errorf("Args = %v, want [alice bob]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 6. TestTpl ----------
|
||||||
|
|
||||||
|
func TestTpl(t *testing.T) {
|
||||||
|
t.Run("single block", func(t *testing.T) {
|
||||||
|
tpl := parse(t, `@tpl("search") { SELECT * FROM users }`)
|
||||||
|
r, err := tpl.ExecuteBlock("search", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteBlock failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "SELECT * FROM users") {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple blocks", func(t *testing.T) {
|
||||||
|
tpl := parse(t, `@tpl("search") { SELECT * FROM users } @tpl("count") { SELECT COUNT(*) FROM users }`)
|
||||||
|
r1, err := tpl.ExecuteBlock("search", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteBlock search failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r1.SQL, "SELECT * FROM users") {
|
||||||
|
t.Errorf("search SQL = %q", r1.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2, err := tpl.ExecuteBlock("count", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteBlock count failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r2.SQL, "SELECT COUNT(*) FROM users") {
|
||||||
|
t.Errorf("count SQL = %q", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("execute on tpl template should error", func(t *testing.T) {
|
||||||
|
tpl := parse(t, `@tpl("search") { SELECT * FROM users }`)
|
||||||
|
_, err := tpl.Execute(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when Execute is called on template with blocks")
|
||||||
|
}
|
||||||
|
var execErr *ExecError
|
||||||
|
if !errors.As(err, &execErr) {
|
||||||
|
t.Errorf("expected ExecError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ExecuteBlock with wrong name", func(t *testing.T) {
|
||||||
|
tpl := parse(t, `@tpl("search") { SELECT * FROM users }`)
|
||||||
|
_, err := tpl.ExecuteBlock("nonexistent", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nonexistent block name")
|
||||||
|
}
|
||||||
|
var execErr *ExecError
|
||||||
|
if !errors.As(err, &execErr) {
|
||||||
|
t.Errorf("expected ExecError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 7. TestNamespace ----------
|
||||||
|
|
||||||
|
func TestNamespace(t *testing.T) {
|
||||||
|
t.Run("namespaced block access", func(t *testing.T) {
|
||||||
|
tpl := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`)
|
||||||
|
r, err := tpl.ExecuteBlock("orders.search", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteBlock orders.search failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "SELECT * FROM orders") {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("short name fallback", func(t *testing.T) {
|
||||||
|
tpl := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`)
|
||||||
|
r, err := tpl.ExecuteBlock("search", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteBlock search failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "SELECT * FROM orders") {
|
||||||
|
t.Errorf("SQL = %q", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("same block names different namespaces", func(t *testing.T) {
|
||||||
|
tpl1 := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`)
|
||||||
|
tpl2 := parse(t, `@namespace("users") @tpl("search") { SELECT * FROM users }`)
|
||||||
|
|
||||||
|
r1, err := tpl1.ExecuteBlock("orders.search", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("orders.search failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r1.SQL, "orders") {
|
||||||
|
t.Errorf("SQL = %q, want orders table", r1.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2, err := tpl2.ExecuteBlock("users.search", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("users.search failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r2.SQL, "users") {
|
||||||
|
t.Errorf("SQL = %q, want users table", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 8. TestPlaceholderStyles ----------
|
||||||
|
|
||||||
|
func TestPlaceholderStyles(t *testing.T) {
|
||||||
|
t.Run("QuestionMark default", func(t *testing.T) {
|
||||||
|
r := exec(t, "WHERE id = #{id}", map[string]any{"id": 1})
|
||||||
|
if !strings.Contains(r.SQL, "WHERE id = ?") {
|
||||||
|
t.Errorf("SQL = %q, want 'WHERE id = ?'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DollarNumber", func(t *testing.T) {
|
||||||
|
r := exec(t, "WHERE id = #{id} AND name = #{name}",
|
||||||
|
map[string]any{"id": 1, "name": "a"},
|
||||||
|
WithPlaceholderStyle(DollarNumber))
|
||||||
|
if !strings.Contains(r.SQL, "WHERE id = $1 AND name = $2") {
|
||||||
|
t.Errorf("SQL = %q, want 'WHERE id = $1 AND name = $2'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ColonNumber", func(t *testing.T) {
|
||||||
|
r := exec(t, "WHERE id = #{id} AND name = #{name}",
|
||||||
|
map[string]any{"id": 1, "name": "a"},
|
||||||
|
WithPlaceholderStyle(ColonNumber))
|
||||||
|
if !strings.Contains(r.SQL, "WHERE id = :1 AND name = :2") {
|
||||||
|
t.Errorf("SQL = %q, want 'WHERE id = :1 AND name = :2'", r.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 9. TestStrictMode ----------
|
||||||
|
|
||||||
|
func TestStrictMode(t *testing.T) {
|
||||||
|
t.Run("strict true undefined param", func(t *testing.T) {
|
||||||
|
tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT #{id}")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
_, err = tpl.Execute(map[string]any{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for undefined variable in strict mode")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("strict false undefined param", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT #{id} FROM users", map[string]any{},
|
||||||
|
WithStrictMode(false))
|
||||||
|
if !strings.Contains(r.SQL, "SELECT ? FROM users") {
|
||||||
|
t.Errorf("SQL = %q, want 'SELECT ? FROM users'", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 1 || r.Args[0] != nil {
|
||||||
|
t.Errorf("Args = %v, want [nil]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("strict false undefined raw", func(t *testing.T) {
|
||||||
|
r := exec(t, "SELECT ${col} FROM users", map[string]any{},
|
||||||
|
WithStrictMode(false))
|
||||||
|
if r.SQL != "SELECT FROM users" {
|
||||||
|
t.Errorf("SQL = %q, want 'SELECT FROM users'", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- 10. TestComplex ----------
|
||||||
|
|
||||||
|
func TestComplex(t *testing.T) {
|
||||||
|
t.Run("conditional WHERE with 1=1", func(t *testing.T) {
|
||||||
|
src := `SELECT * FROM users WHERE 1=1
|
||||||
|
@if(status != nil) { AND status = #{status} }
|
||||||
|
@if(name != "") { AND name = #{name} }`
|
||||||
|
|
||||||
|
r := exec(t, src, map[string]any{"status": "active", "name": "alice"})
|
||||||
|
if !strings.Contains(r.SQL, "AND status = ?") {
|
||||||
|
t.Errorf("SQL missing status condition: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "AND name = ?") {
|
||||||
|
t.Errorf("SQL missing name condition: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 2 {
|
||||||
|
t.Errorf("Args = %v, want 2 args", r.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, src, map[string]any{"status": nil, "name": ""})
|
||||||
|
if strings.Contains(r2.SQL, "AND") {
|
||||||
|
t.Errorf("SQL should have no AND: %q", r2.SQL)
|
||||||
|
}
|
||||||
|
if len(r2.Args) != 0 {
|
||||||
|
t.Errorf("Args = %v, want empty", r2.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("dynamic UPDATE with trailing commas", func(t *testing.T) {
|
||||||
|
src := `UPDATE users SET
|
||||||
|
@if(name != nil) { name = #{name}, }
|
||||||
|
@if(email != nil) { email = #{email}, }
|
||||||
|
WHERE id = #{id}`
|
||||||
|
|
||||||
|
r := exec(t, src, map[string]any{"name": "alice", "email": "a@b.com", "id": 1})
|
||||||
|
if !strings.Contains(r.SQL, "name = ?") {
|
||||||
|
t.Errorf("SQL missing name: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.SQL, "email = ?") {
|
||||||
|
t.Errorf("SQL missing email: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if strings.Contains(r.SQL, ",\nWHERE") {
|
||||||
|
t.Errorf("trailing comma not trimmed: %q", r.SQL)
|
||||||
|
}
|
||||||
|
if len(r.Args) != 3 {
|
||||||
|
t.Errorf("Args = %v, want 3 args", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("batch INSERT with @for", func(t *testing.T) {
|
||||||
|
src := `INSERT INTO users (name, age) VALUES
|
||||||
|
@for(u range users) { (#{u.name}, #{u.age}), }`
|
||||||
|
|
||||||
|
r := exec(t, src, map[string]any{"users": []map[string]any{
|
||||||
|
{"name": "alice", "age": 30},
|
||||||
|
{"name": "bob", "age": 25},
|
||||||
|
}})
|
||||||
|
if len(r.Args) != 4 {
|
||||||
|
t.Fatalf("Args = %v, want 4 args", r.Args)
|
||||||
|
}
|
||||||
|
if r.Args[0] != "alice" || r.Args[1] != 30 || r.Args[2] != "bob" || r.Args[3] != 25 {
|
||||||
|
t.Errorf("Args = %v, want [alice 30 bob 25]", r.Args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ORDER BY with @if desc/asc", func(t *testing.T) {
|
||||||
|
src := `SELECT * FROM users ORDER BY id @if(asc) {ASC} else {DESC}`
|
||||||
|
r := exec(t, src, map[string]any{"asc": true})
|
||||||
|
if !strings.Contains(r.SQL, "ASC") {
|
||||||
|
t.Errorf("SQL = %q, want ASC", r.SQL)
|
||||||
|
}
|
||||||
|
if strings.Contains(r.SQL, "DESC") {
|
||||||
|
t.Errorf("SQL = %q, should not contain DESC", r.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
r2 := exec(t, src, map[string]any{"asc": false})
|
||||||
|
if !strings.Contains(r2.SQL, "DESC") {
|
||||||
|
t.Errorf("SQL = %q, want DESC", r2.SQL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Benchmarks ----------
|
||||||
|
|
||||||
|
func BenchmarkSimple(b *testing.B) {
|
||||||
|
tpl := New().MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}")
|
||||||
|
vars := map[string]any{"id": 42, "name": "alice"}
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = tpl.Execute(vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkConditional(b *testing.B) {
|
||||||
|
src := `SELECT * FROM users WHERE 1=1
|
||||||
|
@if(status != nil) { AND status = #{status} }
|
||||||
|
@if(name != "") { AND name = #{name} }
|
||||||
|
ORDER BY id`
|
||||||
|
tpl := New().MustParse("bench", src)
|
||||||
|
vars := map[string]any{"status": "active", "name": "alice"}
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = tpl.Execute(vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLoop(b *testing.B) {
|
||||||
|
src := `SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })`
|
||||||
|
tpl := New().MustParse("bench", src)
|
||||||
|
ids := make([]int, 10)
|
||||||
|
for i := range ids {
|
||||||
|
ids[i] = i
|
||||||
|
}
|
||||||
|
vars := map[string]any{"ids": ids}
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = tpl.Execute(vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPlaceholderDollar(b *testing.B) {
|
||||||
|
eng := New(WithPlaceholderStyle(DollarNumber))
|
||||||
|
tpl := eng.MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}")
|
||||||
|
vars := map[string]any{"id": 42, "name": "alice"}
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = tpl.Execute(vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user