package utpl

import (
	"errors"
	"strings"
	"testing"
)

// parse compiles source under the name "test" using a fresh engine built
// from opts, failing the test immediately if parsing returns an error.
// Calling it without opts is equivalent to New().Parse.
func parse(t *testing.T, source string, opts ...Option) *Template {
	t.Helper()
	tpl, err := New(opts...).Parse("test", source)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	return tpl
}

// exec parses source with the given engine options and executes it against
// vars, failing the test immediately on either a parse or an execute error.
func exec(t *testing.T, source string, vars map[string]any, opts ...Option) *Result {
	t.Helper()
	result, err := parse(t, source, opts...).Execute(vars)
	if err != nil {
		t.Fatalf("execute failed: %v", err)
	}
	return result
}

// ---------- 1. TestParam ----------

func TestParam(t *testing.T) {
	t.Run("single param", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE id = #{id}", map[string]any{"id": 42})
		if r.SQL != "SELECT * FROM users WHERE id = ?" {
			t.Errorf("SQL = %q, want %q", r.SQL, "SELECT * FROM users WHERE id = ?")
		}
		if len(r.Args) != 1 || r.Args[0] != 42 {
			t.Errorf("Args = %v, want [42]", r.Args)
		}
	})

	t.Run("multiple params", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE id = #{id} AND name = #{name}", map[string]any{"id": 1, "name": "alice"})
		if r.SQL != "SELECT * FROM users WHERE id = ? AND name = ?" {
			t.Errorf("SQL = %q", r.SQL)
		}
		if len(r.Args) != 2 || r.Args[0] != 1 || r.Args[1] != "alice" {
			t.Errorf("Args = %v, want [1 alice]", r.Args)
		}
	})

	t.Run("nil value param", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE name = #{name}", map[string]any{"name": nil})
		if r.SQL != "SELECT * FROM users WHERE name = ?" {
			t.Errorf("SQL = %q", r.SQL)
		}
		if len(r.Args) != 1 || r.Args[0] != nil {
			t.Errorf("Args = %v, want [nil]", r.Args)
		}
	})

	t.Run("param with nested dot path", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE name = #{user.name}", map[string]any{"user": map[string]any{"name": "bob"}})
		if r.SQL != "SELECT * FROM users WHERE name = ?" {
			t.Errorf("SQL = %q", r.SQL)
		}
		if len(r.Args) != 1 || r.Args[0] != "bob" {
			t.Errorf("Args = %v, want [bob]", r.Args)
		}
	})
}

// ---------- 2. TestRaw ----------

func TestRaw(t *testing.T) {
	t.Run("basic raw substitution", func(t *testing.T) {
		r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"})
		if r.SQL != "SELECT name FROM users" {
			t.Errorf("SQL = %q, want %q", r.SQL, "SELECT name FROM users")
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})

	t.Run("multiple raw params", func(t *testing.T) {
		r := exec(t, "SELECT ${a}, ${b} FROM ${table}", map[string]any{"a": "id", "b": "name", "table": "users"})
		if r.SQL != "SELECT id, name FROM users" {
			t.Errorf("SQL = %q", r.SQL)
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})

	t.Run("allowlist pass", func(t *testing.T) {
		policy := RawAllowlist{
			"col": {"name", "age"},
		}
		r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"}, WithRawPolicy(policy))
		if r.SQL != "SELECT name FROM users" {
			t.Errorf("SQL = %q", r.SQL)
		}
	})

	t.Run("allowlist reject", func(t *testing.T) {
		policy := RawAllowlist{
			"col": {"name", "age"},
		}
		// Parse errors must be surfaced, not discarded: a nil template
		// would otherwise panic on Execute and hide the real failure.
		tpl := parse(t, "SELECT ${col} FROM users", WithRawPolicy(policy))
		_, err := tpl.Execute(map[string]any{"col": "password"})
		if err == nil {
			t.Fatal("expected error for blocked raw value, got nil")
		}
		var unsafeErr *UnsafeRawError
		if !errors.As(err, &unsafeErr) {
			t.Errorf("expected UnsafeRawError, got %T: %v", err, err)
		}
	})

	t.Run("blocklist allow", func(t *testing.T) {
		policy := RawBlocklist{
			"col": {"password"},
		}
		r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"}, WithRawPolicy(policy))
		if r.SQL != "SELECT name FROM users" {
			t.Errorf("SQL = %q", r.SQL)
		}
	})

	t.Run("blocklist reject", func(t *testing.T) {
		policy := RawBlocklist{
			"col": {"password", "secret"},
		}
		tpl := parse(t, "SELECT ${col} FROM users", WithRawPolicy(policy))
		_, err := tpl.Execute(map[string]any{"col": "password"})
		if err == nil {
			t.Fatal("expected error for blocklisted value, got nil")
		}
		var unsafeErr *UnsafeRawError
		if !errors.As(err, &unsafeErr) {
			t.Errorf("expected UnsafeRawError, got %T: %v", err, err)
		}
	})
}

// ---------- 3. TestComment ----------

func TestComment(t *testing.T) {
	t.Run("comment stripped", func(t *testing.T) {
		r := exec(t, "# this is a comment\nSELECT 1", nil)
		if r.SQL != "SELECT 1" {
			t.Errorf("SQL = %q, want %q", r.SQL, "SELECT 1")
		}
	})

	t.Run("comment in middle", func(t *testing.T) {
		r := exec(t, "SELECT 1\n# comment\nSELECT 2", nil)
		want := "SELECT 1\nSELECT 2"
		if r.SQL != want {
			t.Errorf("SQL = %q, want %q", r.SQL, want)
		}
	})

	t.Run("no false positive with param", func(t *testing.T) {
		r := exec(t, "SELECT #{id} FROM users", map[string]any{"id": 1})
		if r.SQL != "SELECT ? FROM users" {
			t.Errorf("SQL = %q, want %q", r.SQL, "SELECT ? FROM users")
		}
		if len(r.Args) != 1 || r.Args[0] != 1 {
			t.Errorf("Args = %v", r.Args)
		}
	})
}

// ---------- 4. TestIf ----------

func TestIf(t *testing.T) {
	t.Run("true condition", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }", map[string]any{"status": "active"})
		if !strings.Contains(r.SQL, "AND status = ?") {
			t.Errorf("SQL = %q, want to contain 'AND status = ?'", r.SQL)
		}
		if len(r.Args) != 1 || r.Args[0] != "active" {
			t.Errorf("Args = %v, want [active]", r.Args)
		}
	})

	t.Run("false condition nil", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }", map[string]any{"status": nil})
		if strings.Contains(r.SQL, "AND status") {
			t.Errorf("SQL = %q, should not contain 'AND status'", r.SQL)
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})

	t.Run("else branch", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users ORDER BY id @if(desc) {DESC} else {ASC}", map[string]any{"desc": true})
		if !strings.Contains(r.SQL, "DESC") {
			t.Errorf("SQL = %q, want DESC", r.SQL)
		}
		if strings.Contains(r.SQL, "ASC") {
			t.Errorf("SQL = %q, should not contain ASC", r.SQL)
		}

		r2 := exec(t, "SELECT * FROM users ORDER BY id @if(desc) {DESC} else {ASC}", map[string]any{"desc": false})
		if !strings.Contains(r2.SQL, "ASC") {
			t.Errorf("SQL = %q, want ASC", r2.SQL)
		}
		// Only the else branch may be rendered when the condition is false.
		if strings.Contains(r2.SQL, "DESC") {
			t.Errorf("SQL = %q, should not contain DESC", r2.SQL)
		}
	})

	t.Run("empty string check", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE 1=1 @if(name != \"\") { AND name = #{name} }", map[string]any{"name": "alice"})
		if !strings.Contains(r.SQL, "AND name = ?") {
			t.Errorf("SQL = %q, want 'AND name = ?'", r.SQL)
		}

		r2 := exec(t, "SELECT * FROM users WHERE 1=1 @if(name != \"\") { AND name = #{name} }", map[string]any{"name": ""})
		if strings.Contains(r2.SQL, "AND name") {
			t.Errorf("SQL = %q, should not contain 'AND name'", r2.SQL)
		}
	})

	t.Run("multiple independent conditions", func(t *testing.T) {
		r := exec(t, "@if(a != nil) {A} @if(b != nil) {B}", map[string]any{"a": 1, "b": 2})
		if !strings.Contains(r.SQL, "A") || !strings.Contains(r.SQL, "B") {
			t.Errorf("SQL = %q, want both A and B", r.SQL)
		}

		r2 := exec(t, "@if(a != nil) {A} @if(b != nil) {B}", map[string]any{"a": 1})
		if !strings.Contains(r2.SQL, "A") {
			t.Errorf("SQL = %q, want A", r2.SQL)
		}
		if strings.Contains(r2.SQL, "B") {
			t.Errorf("SQL = %q, should not contain B", r2.SQL)
		}
	})

	t.Run("AND operator", func(t *testing.T) {
		r := exec(t, "@if(a != nil && b != \"\") { both }", map[string]any{"a": 1, "b": "x"})
		if !strings.Contains(r.SQL, "both") {
			t.Errorf("SQL = %q, want 'both'", r.SQL)
		}

		r2 := exec(t, "@if(a != nil && b != \"\") { both }", map[string]any{"a": 1, "b": ""})
		if strings.Contains(r2.SQL, "both") {
			t.Errorf("SQL = %q, should not contain 'both'", r2.SQL)
		}
	})

	t.Run("OR operator", func(t *testing.T) {
		r := exec(t, "@if(a != nil || b != nil) { either }", map[string]any{"a": nil, "b": 1})
		if !strings.Contains(r.SQL, "either") {
			t.Errorf("SQL = %q, want 'either'", r.SQL)
		}

		r2 := exec(t, "@if(a != nil || b != nil) { either }", map[string]any{"a": nil, "b": nil})
		if strings.Contains(r2.SQL, "either") {
			t.Errorf("SQL = %q, should not contain 'either'", r2.SQL)
		}
	})
}

// ---------- 5. TestFor ----------

func TestFor(t *testing.T) {
	t.Run("basic IN clause", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })", map[string]any{"ids": []int{1, 2, 3}})
		if n := strings.Count(r.SQL, "?"); n != 3 {
			t.Errorf("SQL = %q, want 3 placeholders, got %d", r.SQL, n)
		}
		if len(r.Args) != 3 {
			t.Fatalf("Args = %v, want 3 args", r.Args)
		}
		for i, want := range []int{1, 2, 3} {
			if r.Args[i] != want {
				t.Errorf("Args[%d] = %v, want %v", i, r.Args[i], want)
			}
		}
	})

	t.Run("trailing comma auto-trim", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })", map[string]any{"ids": []int{1, 2}})
		// The rendered SQL always ends with the template's own ")", so a
		// HasSuffix(",") check could never fail. Compare with all whitespace
		// removed instead: the executor should trim the loop's final ", "
		// so the list closes cleanly as "(?,?)".
		compact := strings.Join(strings.Fields(r.SQL), "")
		if strings.Contains(compact, ",)") {
			t.Errorf("SQL = %q, trailing comma not trimmed before ')'", r.SQL)
		}
		if !strings.Contains(compact, "(?,?)") {
			t.Errorf("SQL = %q, want placeholder list (?, ?)", r.SQL)
		}
		if len(r.Args) != 2 {
			t.Errorf("Args = %v, want 2 args", r.Args)
		}
	})

	t.Run("empty list", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })", map[string]any{"ids": []int{}})
		if strings.Contains(r.SQL, "?") {
			t.Errorf("SQL = %q, should have no placeholders", r.SQL)
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})

	t.Run("nil list", func(t *testing.T) {
		// nil list produces no output (treated as empty slice)
		tpl := parse(t, "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })")
		r, err := tpl.Execute(map[string]any{"ids": nil})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if strings.Contains(r.SQL, "?") {
			t.Errorf("SQL = %q, should have no placeholders", r.SQL)
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})

	t.Run("with index variable", func(t *testing.T) {
		r := exec(t, "SELECT @for(i, v range items) {#{v}, }", map[string]any{"items": []string{"a", "b", "c"}})
		if len(r.Args) != 3 {
			t.Fatalf("Args = %v, want 3 args", r.Args)
		}
		// With the two-variable form, #{v} must bind the element, not the index.
		for i, want := range []string{"a", "b", "c"} {
			if r.Args[i] != want {
				t.Errorf("Args[%d] = %v, want %v", i, r.Args[i], want)
			}
		}
	})

	t.Run("nested object access", func(t *testing.T) {
		r := exec(t, "SELECT @for(u range users) {#{u.name}, }", map[string]any{"users": []map[string]any{
			{"name": "alice"},
			{"name": "bob"},
		}})
		if len(r.Args) != 2 {
			t.Fatalf("Args = %v, want 2 args", r.Args)
		}
		if r.Args[0] != "alice" || r.Args[1] != "bob" {
			t.Errorf("Args = %v, want [alice bob]", r.Args)
		}
	})
}

// ---------- 6. TestTpl ----------

func TestTpl(t *testing.T) {
	t.Run("single block", func(t *testing.T) {
		tpl := parse(t, `@tpl("search") { SELECT * FROM users }`)
		r, err := tpl.ExecuteBlock("search", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock failed: %v", err)
		}
		if !strings.Contains(r.SQL, "SELECT * FROM users") {
			t.Errorf("SQL = %q", r.SQL)
		}
	})

	t.Run("multiple blocks", func(t *testing.T) {
		tpl := parse(t, `@tpl("search") { SELECT * FROM users } @tpl("count") { SELECT COUNT(*) FROM users }`)
		r1, err := tpl.ExecuteBlock("search", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock search failed: %v", err)
		}
		if !strings.Contains(r1.SQL, "SELECT * FROM users") {
			t.Errorf("search SQL = %q", r1.SQL)
		}
		r2, err := tpl.ExecuteBlock("count", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock count failed: %v", err)
		}
		if !strings.Contains(r2.SQL, "SELECT COUNT(*) FROM users") {
			t.Errorf("count SQL = %q", r2.SQL)
		}
	})

	t.Run("execute on tpl template should error", func(t *testing.T) {
		tpl := parse(t, `@tpl("search") { SELECT * FROM users }`)
		_, err := tpl.Execute(nil)
		if err == nil {
			t.Fatal("expected error when Execute is called on template with blocks")
		}
		var execErr *ExecError
		if !errors.As(err, &execErr) {
			t.Errorf("expected ExecError, got %T: %v", err, err)
		}
	})

	t.Run("ExecuteBlock with wrong name", func(t *testing.T) {
		tpl := parse(t, `@tpl("search") { SELECT * FROM users }`)
		_, err := tpl.ExecuteBlock("nonexistent", nil)
		if err == nil {
			t.Fatal("expected error for nonexistent block name")
		}
		var execErr *ExecError
		if !errors.As(err, &execErr) {
			t.Errorf("expected ExecError, got %T: %v", err, err)
		}
	})
}
TestNamespace ---------- func TestNamespace(t *testing.T) { t.Run("namespaced block access", func(t *testing.T) { tpl := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`) r, err := tpl.ExecuteBlock("orders.search", nil) if err != nil { t.Fatalf("ExecuteBlock orders.search failed: %v", err) } if !strings.Contains(r.SQL, "SELECT * FROM orders") { t.Errorf("SQL = %q", r.SQL) } }) t.Run("short name fallback", func(t *testing.T) { tpl := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`) r, err := tpl.ExecuteBlock("search", nil) if err != nil { t.Fatalf("ExecuteBlock search failed: %v", err) } if !strings.Contains(r.SQL, "SELECT * FROM orders") { t.Errorf("SQL = %q", r.SQL) } }) t.Run("same block names different namespaces", func(t *testing.T) { tpl1 := parse(t, `@namespace("orders") @tpl("search") { SELECT * FROM orders }`) tpl2 := parse(t, `@namespace("users") @tpl("search") { SELECT * FROM users }`) r1, err := tpl1.ExecuteBlock("orders.search", nil) if err != nil { t.Fatalf("orders.search failed: %v", err) } if !strings.Contains(r1.SQL, "orders") { t.Errorf("SQL = %q, want orders table", r1.SQL) } r2, err := tpl2.ExecuteBlock("users.search", nil) if err != nil { t.Fatalf("users.search failed: %v", err) } if !strings.Contains(r2.SQL, "users") { t.Errorf("SQL = %q, want users table", r2.SQL) } }) } // ---------- 8. 
TestPlaceholderStyles ---------- func TestPlaceholderStyles(t *testing.T) { t.Run("QuestionMark default", func(t *testing.T) { r := exec(t, "WHERE id = #{id}", map[string]any{"id": 1}) if !strings.Contains(r.SQL, "WHERE id = ?") { t.Errorf("SQL = %q, want 'WHERE id = ?'", r.SQL) } }) t.Run("DollarNumber", func(t *testing.T) { r := exec(t, "WHERE id = #{id} AND name = #{name}", map[string]any{"id": 1, "name": "a"}, WithPlaceholderStyle(DollarNumber)) if !strings.Contains(r.SQL, "WHERE id = $1 AND name = $2") { t.Errorf("SQL = %q, want 'WHERE id = $1 AND name = $2'", r.SQL) } }) t.Run("ColonNumber", func(t *testing.T) { r := exec(t, "WHERE id = #{id} AND name = #{name}", map[string]any{"id": 1, "name": "a"}, WithPlaceholderStyle(ColonNumber)) if !strings.Contains(r.SQL, "WHERE id = :1 AND name = :2") { t.Errorf("SQL = %q, want 'WHERE id = :1 AND name = :2'", r.SQL) } }) } // ---------- 9. TestStrictMode ---------- func TestStrictMode(t *testing.T) { t.Run("strict true undefined param", func(t *testing.T) { tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT #{id}") if err != nil { t.Fatalf("parse failed: %v", err) } _, err = tpl.Execute(map[string]any{}) if err == nil { t.Fatal("expected error for undefined variable in strict mode") } }) t.Run("strict false undefined param", func(t *testing.T) { r := exec(t, "SELECT #{id} FROM users", map[string]any{}, WithStrictMode(false)) if !strings.Contains(r.SQL, "SELECT ? FROM users") { t.Errorf("SQL = %q, want 'SELECT ? FROM users'", r.SQL) } if len(r.Args) != 1 || r.Args[0] != nil { t.Errorf("Args = %v, want [nil]", r.Args) } }) t.Run("strict false undefined raw", func(t *testing.T) { r := exec(t, "SELECT ${col} FROM users", map[string]any{}, WithStrictMode(false)) if r.SQL != "SELECT FROM users" { t.Errorf("SQL = %q, want 'SELECT FROM users'", r.SQL) } if len(r.Args) != 0 { t.Errorf("Args = %v, want empty", r.Args) } }) } // ---------- 10. 
TestComplex ---------- func TestComplex(t *testing.T) { t.Run("conditional WHERE with 1=1", func(t *testing.T) { src := `SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} } @if(name != "") { AND name = #{name} }` r := exec(t, src, map[string]any{"status": "active", "name": "alice"}) if !strings.Contains(r.SQL, "AND status = ?") { t.Errorf("SQL missing status condition: %q", r.SQL) } if !strings.Contains(r.SQL, "AND name = ?") { t.Errorf("SQL missing name condition: %q", r.SQL) } if len(r.Args) != 2 { t.Errorf("Args = %v, want 2 args", r.Args) } r2 := exec(t, src, map[string]any{"status": nil, "name": ""}) if strings.Contains(r2.SQL, "AND") { t.Errorf("SQL should have no AND: %q", r2.SQL) } if len(r2.Args) != 0 { t.Errorf("Args = %v, want empty", r2.Args) } }) t.Run("dynamic UPDATE with trailing commas", func(t *testing.T) { src := `UPDATE users SET @if(name != nil) { name = #{name}, } @if(email != nil) { email = #{email}, } WHERE id = #{id}` r := exec(t, src, map[string]any{"name": "alice", "email": "a@b.com", "id": 1}) if !strings.Contains(r.SQL, "name = ?") { t.Errorf("SQL missing name: %q", r.SQL) } if !strings.Contains(r.SQL, "email = ?") { t.Errorf("SQL missing email: %q", r.SQL) } if strings.Contains(r.SQL, ",\nWHERE") { t.Errorf("trailing comma not trimmed: %q", r.SQL) } if len(r.Args) != 3 { t.Errorf("Args = %v, want 3 args", r.Args) } }) t.Run("batch INSERT with @for", func(t *testing.T) { src := `INSERT INTO users (name, age) VALUES @for(u range users) { (#{u.name}, #{u.age}), }` r := exec(t, src, map[string]any{"users": []map[string]any{ {"name": "alice", "age": 30}, {"name": "bob", "age": 25}, }}) if len(r.Args) != 4 { t.Fatalf("Args = %v, want 4 args", r.Args) } if r.Args[0] != "alice" || r.Args[1] != 30 || r.Args[2] != "bob" || r.Args[3] != 25 { t.Errorf("Args = %v, want [alice 30 bob 25]", r.Args) } }) t.Run("ORDER BY with @if desc/asc", func(t *testing.T) { src := `SELECT * FROM users ORDER BY id @if(asc) {ASC} else {DESC}` r 
:= exec(t, src, map[string]any{"asc": true}) if !strings.Contains(r.SQL, "ASC") { t.Errorf("SQL = %q, want ASC", r.SQL) } if strings.Contains(r.SQL, "DESC") { t.Errorf("SQL = %q, should not contain DESC", r.SQL) } r2 := exec(t, src, map[string]any{"asc": false}) if !strings.Contains(r2.SQL, "DESC") { t.Errorf("SQL = %q, want DESC", r2.SQL) } }) } // ---------- Benchmarks ---------- func BenchmarkSimple(b *testing.B) { tpl := New().MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}") vars := map[string]any{"id": 42, "name": "alice"} b.ResetTimer() for b.Loop() { _, _ = tpl.Execute(vars) } } func BenchmarkConditional(b *testing.B) { src := `SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} } @if(name != "") { AND name = #{name} } ORDER BY id` tpl := New().MustParse("bench", src) vars := map[string]any{"status": "active", "name": "alice"} b.ResetTimer() for b.Loop() { _, _ = tpl.Execute(vars) } } func BenchmarkLoop(b *testing.B) { src := `SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })` tpl := New().MustParse("bench", src) ids := make([]int, 10) for i := range ids { ids[i] = i } vars := map[string]any{"ids": ids} b.ResetTimer() for b.Loop() { _, _ = tpl.Execute(vars) } } func BenchmarkPlaceholderDollar(b *testing.B) { eng := New(WithPlaceholderStyle(DollarNumber)) tpl := eng.MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}") vars := map[string]any{"id": 42, "name": "alice"} b.ResetTimer() for b.Loop() { _, _ = tpl.Execute(vars) } }