489 lines
14 KiB
Go
489 lines
14 KiB
Go
package utpl
|
|
|
|
import (
|
|
"errors"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
)
|
|
|
|
// ---------- TestErrorPositions ----------
|
|
// Verify that ParseError.Pos and ExecError.Pos contain correct line/column values.
|
|
|
|
func TestErrorPositions(t *testing.T) {
|
|
t.Run("ParseError has correct position for unterminated param", func(t *testing.T) {
|
|
// #{id on line 1, col 8 (1-indexed)
|
|
_, err := New().Parse("test", "SELECT #{id")
|
|
var pe *ParseError
|
|
if !errors.As(err, &pe) {
|
|
t.Fatalf("expected ParseError, got %T: %v", err, err)
|
|
}
|
|
if pe.Pos.Line == 0 || pe.Pos.Column == 0 {
|
|
t.Errorf("ParseError.Pos = {Line:%d, Col:%d}, expected non-zero", pe.Pos.Line, pe.Pos.Column)
|
|
}
|
|
if pe.Pos.Line != 1 {
|
|
t.Errorf("ParseError.Pos.Line = %d, want 1", pe.Pos.Line)
|
|
}
|
|
if pe.Pos.Column != 8 {
|
|
t.Errorf("ParseError.Pos.Column = %d, want 8", pe.Pos.Column)
|
|
}
|
|
})
|
|
|
|
t.Run("ParseError has correct position for unterminated if on line 2", func(t *testing.T) {
|
|
src := "SELECT 1\n@if(x > 0"
|
|
_, err := New().Parse("test", src)
|
|
var pe *ParseError
|
|
if !errors.As(err, &pe) {
|
|
t.Fatalf("expected ParseError, got %T: %v", err, err)
|
|
}
|
|
if pe.Pos.Line != 2 {
|
|
t.Errorf("ParseError.Pos.Line = %d, want 2", pe.Pos.Line)
|
|
}
|
|
})
|
|
|
|
t.Run("ParseError has correct position for unterminated for on line 3", func(t *testing.T) {
|
|
src := "SELECT 1\n\n@for(x, range list)"
|
|
_, err := New().Parse("test", src)
|
|
var pe *ParseError
|
|
if !errors.As(err, &pe) {
|
|
t.Fatalf("expected ParseError, got %T: %v", err, err)
|
|
}
|
|
if pe.Pos.Line != 3 {
|
|
t.Errorf("ParseError.Pos.Line = %d, want 3", pe.Pos.Line)
|
|
}
|
|
})
|
|
|
|
t.Run("ExecError has correct position for undefined variable in strict mode", func(t *testing.T) {
|
|
tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT #{missing}")
|
|
if err != nil {
|
|
t.Fatalf("parse failed: %v", err)
|
|
}
|
|
_, err = tpl.Execute(map[string]any{})
|
|
var ee *ExecError
|
|
if !errors.As(err, &ee) {
|
|
t.Fatalf("expected ExecError, got %T: %v", err, err)
|
|
}
|
|
if ee.Pos.Line == 0 || ee.Pos.Column == 0 {
|
|
t.Errorf("ExecError.Pos = {Line:%d, Col:%d}, expected non-zero", ee.Pos.Line, ee.Pos.Column)
|
|
}
|
|
})
|
|
|
|
t.Run("ExecError has correct position for undefined variable on line 2", func(t *testing.T) {
|
|
tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT 1\nWHERE id = #{missing}")
|
|
if err != nil {
|
|
t.Fatalf("parse failed: %v", err)
|
|
}
|
|
_, err = tpl.Execute(map[string]any{})
|
|
var ee *ExecError
|
|
if !errors.As(err, &ee) {
|
|
t.Fatalf("expected ExecError, got %T: %v", err, err)
|
|
}
|
|
if ee.Pos.Line != 2 {
|
|
t.Errorf("ExecError.Pos.Line = %d, want 2", ee.Pos.Line)
|
|
}
|
|
})
|
|
|
|
t.Run("ExecError has correct position for non-slice in @for", func(t *testing.T) {
|
|
tpl, err := New().Parse("test", "SELECT @for(x, range items) {#{x}}")
|
|
if err != nil {
|
|
t.Fatalf("parse failed: %v", err)
|
|
}
|
|
_, err = tpl.Execute(map[string]any{"items": "not a slice"})
|
|
var ee *ExecError
|
|
if !errors.As(err, &ee) {
|
|
t.Fatalf("expected ExecError, got %T: %v", err, err)
|
|
}
|
|
if ee.Pos.Line == 0 || ee.Pos.Column == 0 {
|
|
t.Errorf("ExecError.Pos = {Line:%d, Col:%d}, expected non-zero", ee.Pos.Line, ee.Pos.Column)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestRawDefault ----------
|
|
// Verify @raw behavior when no policy is configured (default nil).
|
|
|
|
func TestRawDefault(t *testing.T) {
|
|
t.Run("raw with no policy substitutes directly", func(t *testing.T) {
|
|
// No WithRawPolicy — rawPolicy is nil, should skip validation
|
|
r := exec(t, "SELECT ${col} FROM ${table}", map[string]any{"col": "name", "table": "users"})
|
|
if r.SQL != "SELECT name FROM users" {
|
|
t.Errorf("SQL = %q, want 'SELECT name FROM users'", r.SQL)
|
|
}
|
|
if len(r.Args) != 0 {
|
|
t.Errorf("Args = %v, want empty", r.Args)
|
|
}
|
|
})
|
|
|
|
t.Run("raw with nil value and no policy", func(t *testing.T) {
|
|
r := exec(t, "SELECT ${col}", map[string]any{"col": nil})
|
|
if r.SQL != "SELECT <nil>" {
|
|
t.Errorf("SQL = %q, want 'SELECT <nil>'", r.SQL)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestRawNoop ----------
|
|
// Verify RawNoop policy.
|
|
|
|
func TestRawNoop(t *testing.T) {
|
|
t.Run("RawNoop allows all values", func(t *testing.T) {
|
|
r := exec(t, "SELECT ${col} FROM ${table}", map[string]any{"col": "name", "table": "users"},
|
|
WithRawPolicy(RawNoop{}))
|
|
if r.SQL != "SELECT name FROM users" {
|
|
t.Errorf("SQL = %q, want 'SELECT name FROM users'", r.SQL)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestStructFieldAccess ----------
|
|
// Verify that context resolves struct fields (case-insensitive).
|
|
|
|
func TestStructFieldAccess(t *testing.T) {
|
|
type User struct {
|
|
Name string
|
|
Age int
|
|
}
|
|
|
|
t.Run("struct field access via dot path", func(t *testing.T) {
|
|
r := exec(t, "SELECT #{u.Name}, #{u.Age}",
|
|
map[string]any{"u": User{Name: "alice", Age: 30}})
|
|
if r.SQL != "SELECT ?, ?" {
|
|
t.Errorf("SQL = %q, want 'SELECT ?, ?'", r.SQL)
|
|
}
|
|
if len(r.Args) != 2 || r.Args[0] != "alice" || r.Args[1] != 30 {
|
|
t.Errorf("Args = %v, want [alice 30]", r.Args)
|
|
}
|
|
})
|
|
|
|
t.Run("struct field case-insensitive access", func(t *testing.T) {
|
|
r := exec(t, "SELECT #{u.name}",
|
|
map[string]any{"u": User{Name: "bob", Age: 25}})
|
|
if len(r.Args) != 1 || r.Args[0] != "bob" {
|
|
t.Errorf("Args = %v, want [bob]", r.Args)
|
|
}
|
|
})
|
|
|
|
t.Run("struct field in @if condition", func(t *testing.T) {
|
|
r := exec(t, "@if(u.Active) {active}",
|
|
map[string]any{"u": struct{ Active bool }{Active: true}})
|
|
if !strings.Contains(r.SQL, "active") {
|
|
t.Errorf("SQL = %q, want 'active'", r.SQL)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestConcurrentExecution ----------
|
|
// Verify Template.Execute is safe for concurrent use.
|
|
|
|
func TestConcurrentExecution(t *testing.T) {
|
|
tpl := parse(t, "SELECT * FROM users WHERE id = #{id} AND status = #{status}")
|
|
|
|
var wg sync.WaitGroup
|
|
errs := make(chan error, 100)
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func(i int) {
|
|
defer wg.Done()
|
|
_, err := tpl.Execute(map[string]any{"id": i, "status": "active"})
|
|
if err != nil {
|
|
errs <- err
|
|
}
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
close(errs)
|
|
|
|
for err := range errs {
|
|
t.Errorf("concurrent execution error: %v", err)
|
|
}
|
|
}
|
|
|
|
// ---------- TestMustParse ----------
|
|
// Verify MustParse panics on invalid input.
|
|
|
|
func TestMustParse(t *testing.T) {
|
|
t.Run("MustParse panics on invalid template", func(t *testing.T) {
|
|
defer func() {
|
|
r := recover()
|
|
if r == nil {
|
|
t.Fatal("expected MustParse to panic on invalid template")
|
|
}
|
|
}()
|
|
New().MustParse("test", "SELECT #{id") // unterminated param
|
|
})
|
|
|
|
t.Run("MustParse succeeds on valid template", func(t *testing.T) {
|
|
tpl := New().MustParse("test", "SELECT #{id}")
|
|
if tpl == nil {
|
|
t.Fatal("MustParse returned nil on valid template")
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestIncludeCircular / TestIncludeNested ----------
|
|
// Additional @include tests: rejection of circular includes and expansion of nested includes.
|
|
|
|
func TestIncludeCircular(t *testing.T) {
|
|
t.Run("circular include returns error", func(t *testing.T) {
|
|
callCount := 0
|
|
resolver := func(path string) (string, error) {
|
|
callCount++
|
|
if callCount > 10 {
|
|
return "", errors.New("too many includes")
|
|
}
|
|
return `@include("other")`, nil
|
|
}
|
|
_, err := New(WithIncludeResolver(resolver)).Parse("test", `@include("a")`)
|
|
if err == nil {
|
|
t.Fatal("expected error for circular include, got nil")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestIncludeNested(t *testing.T) {
|
|
t.Run("include within include expands both", func(t *testing.T) {
|
|
resolver := func(path string) (string, error) {
|
|
files := map[string]string{
|
|
"outer": "OUTER @include(\"inner\")",
|
|
"inner": "INNER",
|
|
}
|
|
src, ok := files[path]
|
|
if !ok {
|
|
return "", errors.New("not found")
|
|
}
|
|
return src, nil
|
|
}
|
|
tpl, err := New(WithIncludeResolver(resolver)).Parse("test", "@include(\"outer\")")
|
|
if err != nil {
|
|
t.Fatalf("parse failed: %v", err)
|
|
}
|
|
r, err := tpl.Execute(nil)
|
|
if err != nil {
|
|
t.Fatalf("execute failed: %v", err)
|
|
}
|
|
if !strings.Contains(r.SQL, "OUTER") {
|
|
t.Errorf("SQL = %q, want OUTER", r.SQL)
|
|
}
|
|
if !strings.Contains(r.SQL, "INNER") {
|
|
t.Errorf("SQL = %q, want INNER", r.SQL)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestNilVars ----------
|
|
// Verify behavior when nil map is passed to Execute.
|
|
|
|
func TestNilVars(t *testing.T) {
|
|
t.Run("nil vars map with strict mode", func(t *testing.T) {
|
|
tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT #{id}")
|
|
if err != nil {
|
|
t.Fatalf("parse failed: %v", err)
|
|
}
|
|
_, err = tpl.Execute(nil)
|
|
if err == nil {
|
|
t.Fatal("expected error for nil vars with strict mode and #{id}")
|
|
}
|
|
})
|
|
|
|
t.Run("nil vars map with non-strict mode", func(t *testing.T) {
|
|
r := exec(t, "SELECT 1", nil, WithStrictMode(false))
|
|
if r.SQL != "SELECT 1" {
|
|
t.Errorf("SQL = %q, want 'SELECT 1'", r.SQL)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestForTypedSlices ----------
|
|
// Verify @for works with typed slices.
|
|
|
|
func TestForTypedSlices(t *testing.T) {
|
|
t.Run("for with []float64", func(t *testing.T) {
|
|
r := exec(t, "SELECT @for(v range items) {#{v}, }",
|
|
map[string]any{"items": []float64{1.5, 2.5, 3.5}})
|
|
if len(r.Args) != 3 {
|
|
t.Errorf("Args = %v, want 3 args", r.Args)
|
|
}
|
|
if r.Args[0] != 1.5 || r.Args[1] != 2.5 || r.Args[2] != 3.5 {
|
|
t.Errorf("Args = %v, want [1.5 2.5 3.5]", r.Args)
|
|
}
|
|
})
|
|
|
|
t.Run("for with []string", func(t *testing.T) {
|
|
r := exec(t, "SELECT @for(v range items) {#{v}, }",
|
|
map[string]any{"items": []string{"a", "b"}})
|
|
if len(r.Args) != 2 {
|
|
t.Errorf("Args = %v, want 2 args", r.Args)
|
|
}
|
|
if r.Args[0] != "a" || r.Args[1] != "b" {
|
|
t.Errorf("Args = %v, want [a b]", r.Args)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestExpressionEdgeCases ----------
|
|
|
|
func TestExpressionEdgeCases(t *testing.T) {
|
|
t.Run("deeply nested parentheses", func(t *testing.T) {
|
|
r := exec(t, "@if(((a))) {yes}", map[string]any{"a": true})
|
|
if !strings.Contains(r.SQL, "yes") {
|
|
t.Errorf("SQL = %q, want 'yes'", r.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("deep dot path", func(t *testing.T) {
|
|
r := exec(t, "@if(a.b.c.d) {deep}", map[string]any{
|
|
"a": map[string]any{
|
|
"b": map[string]any{
|
|
"c": map[string]any{
|
|
"d": true,
|
|
},
|
|
},
|
|
},
|
|
})
|
|
if !strings.Contains(r.SQL, "deep") {
|
|
t.Errorf("SQL = %q, want 'deep'", r.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("float literal in expression", func(t *testing.T) {
|
|
r := exec(t, "@if(x > 1.5) {big}", map[string]any{"x": 2.5})
|
|
if !strings.Contains(r.SQL, "big") {
|
|
t.Errorf("SQL = %q, want 'big'", r.SQL)
|
|
}
|
|
|
|
r2 := exec(t, "@if(x > 1.5) {big}", map[string]any{"x": 0.5})
|
|
if strings.Contains(r2.SQL, "big") {
|
|
t.Errorf("SQL = %q, should not contain 'big'", r2.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("single-quoted string in expression", func(t *testing.T) {
|
|
r := exec(t, `@if(role == 'admin') {is admin}`, map[string]any{"role": "admin"})
|
|
if !strings.Contains(r.SQL, "is admin") {
|
|
t.Errorf("SQL = %q, want 'is admin'", r.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("string escape sequences", func(t *testing.T) {
|
|
r := exec(t, `@if(x == "a\"b") {escaped}`, map[string]any{"x": `a"b`})
|
|
if !strings.Contains(r.SQL, "escaped") {
|
|
t.Errorf("SQL = %q, want 'escaped'", r.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("OR short circuit", func(t *testing.T) {
|
|
// When left is true, right should not cause error even if undefined
|
|
r := exec(t, "@if(a != nil || b != nil) {either}",
|
|
map[string]any{"a": 1})
|
|
if !strings.Contains(r.SQL, "either") {
|
|
t.Errorf("SQL = %q, want 'either'", r.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("negative number in expression", func(t *testing.T) {
|
|
// Expression parser does not support unary minus; use a variable instead
|
|
r := exec(t, "@if(x == neg) {negative one}", map[string]any{"x": -1, "neg": -1})
|
|
if !strings.Contains(r.SQL, "negative one") {
|
|
t.Errorf("SQL = %q, want 'negative one'", r.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("unknown function returns error at execution", func(t *testing.T) {
|
|
// Functions are validated at execution time, not parse time
|
|
tpl, err := New().Parse("test", `@if(unknown(x)) {bad}`)
|
|
if err != nil {
|
|
t.Fatalf("parse should succeed (functions validated at exec): %v", err)
|
|
}
|
|
_, err = tpl.Execute(map[string]any{"x": 1})
|
|
if err == nil {
|
|
t.Fatal("expected error for unknown function at execution, got nil")
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestTrailingCommaEdgeCases ----------
|
|
|
|
func TestTrailingCommaEdgeCases(t *testing.T) {
|
|
t.Run("trailing comma with space and newline", func(t *testing.T) {
|
|
r := exec(t, "SELECT @for(v range items) {#{v}, }\n",
|
|
map[string]any{"items": []int{1, 2}})
|
|
if strings.HasSuffix(r.SQL, ",") {
|
|
t.Errorf("SQL = %q, should not end with comma", r.SQL)
|
|
}
|
|
})
|
|
|
|
t.Run("no trailing comma when not present", func(t *testing.T) {
|
|
r := exec(t, "SELECT @for(v range items) {#{v} }",
|
|
map[string]any{"items": []int{1, 2}})
|
|
if strings.HasSuffix(r.SQL, ",") {
|
|
t.Errorf("SQL = %q, should not end with comma", r.SQL)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestUnsafeRawErrorFields ----------
|
|
|
|
func TestUnsafeRawErrorFields(t *testing.T) {
|
|
t.Run("UnsafeRawError contains param and value", func(t *testing.T) {
|
|
policy := RawAllowlist{"col": {"name"}}
|
|
tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col}")
|
|
_, err := tpl.Execute(map[string]any{"col": "DROP TABLE users"})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
var ue *UnsafeRawError
|
|
if !errors.As(err, &ue) {
|
|
t.Fatalf("expected UnsafeRawError, got %T: %v", err, err)
|
|
}
|
|
if ue.Param != "col" {
|
|
t.Errorf("UnsafeRawError.Param = %q, want %q", ue.Param, "col")
|
|
}
|
|
if ue.Value != "DROP TABLE users" {
|
|
t.Errorf("UnsafeRawError.Value = %q, want %q", ue.Value, "DROP TABLE users")
|
|
}
|
|
if ue.Message == "" {
|
|
t.Error("UnsafeRawError.Message is empty")
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestForWithStructSlice ----------
|
|
|
|
func TestForWithStructSlice(t *testing.T) {
|
|
type Item struct {
|
|
Name string
|
|
Price float64
|
|
}
|
|
|
|
t.Run("for with struct slice", func(t *testing.T) {
|
|
items := []Item{
|
|
{Name: "apple", Price: 1.5},
|
|
{Name: "banana", Price: 0.8},
|
|
}
|
|
r := exec(t, "SELECT @for(i range items) {#{i.Name}, #{i.Price}, }",
|
|
map[string]any{"items": items})
|
|
if len(r.Args) != 4 {
|
|
t.Fatalf("Args = %v, want 4 args", r.Args)
|
|
}
|
|
if r.Args[0] != "apple" || r.Args[1] != 1.5 {
|
|
t.Errorf("first item args wrong: %v", r.Args[:2])
|
|
}
|
|
if r.Args[2] != "banana" || r.Args[3] != 0.8 {
|
|
t.Errorf("second item args wrong: %v", r.Args[2:])
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------- TestExpressionWithMissingVar ----------
|
|
// Non-strict mode: undefined expression variables should not error.
|
|
|
|
func TestExpressionWithMissingVar(t *testing.T) {
|
|
t.Run("non-strict: undefined in condition evaluates to nil (falsy)", func(t *testing.T) {
|
|
r := exec(t, "@if(missing != nil) {has it}", map[string]any{},
|
|
WithStrictMode(false))
|
|
if strings.Contains(r.SQL, "has it") {
|
|
t.Errorf("SQL = %q, should not contain 'has it' when variable is missing", r.SQL)
|
|
}
|
|
})
|
|
}
|