Files
u-tpl/utpl_test.go
绝尘 861d58d718 新增: u-tpl SQL 模板引擎完整实现
- Lexer/Parser/Executor 三阶段架构
- #{param} 参数化 + ${raw} 原样替换 + 白名单安全策略
- @if/@for/@tpl/@include/@namespace 控制流
- 表达式引擎: 比较、逻辑、nil 检查、len() 内置函数
- 支持 ?/$1/:1 多数据库占位符风格
- 零依赖,纯 Go 标准库实现
2026-04-01 00:27:50 +08:00

660 lines
19 KiB
Go

package utpl
import (
"errors"
"strings"
"testing"
)
// helper to parse with default engine
func parse(t *testing.T, source string) *Template {
t.Helper()
tpl, err := New().Parse("test", source)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
return tpl
}
// exec compiles source with an engine configured by opts, runs it against
// vars, and hands back the result. A failure in either the parse step or the
// execute step stops the calling test immediately via t.Fatalf.
func exec(t *testing.T, source string, vars map[string]any, opts ...Option) *Result {
	t.Helper()
	engine := New(opts...)
	tpl, parseErr := engine.Parse("test", source)
	if parseErr != nil {
		t.Fatalf("parse failed: %v", parseErr)
	}
	out, execErr := tpl.Execute(vars)
	if execErr != nil {
		t.Fatalf("execute failed: %v", execErr)
	}
	return out
}
// ---------- 1. TestParam ----------

// TestParam checks that #{...} markers are rendered as "?" by the default
// engine. It also checks that the referenced values are appended to Args in
// template order. The cases cover plain values, explicit nil, and a
// dot-path lookup into a nested map.
func TestParam(t *testing.T) {
	t.Run("single param", func(t *testing.T) {
		const wantSQL = "SELECT * FROM users WHERE id = ?"
		res := exec(t, "SELECT * FROM users WHERE id = #{id}", map[string]any{"id": 42})
		if res.SQL != wantSQL {
			t.Errorf("SQL = %q, want %q", res.SQL, wantSQL)
		}
		if !(len(res.Args) == 1 && res.Args[0] == 42) {
			t.Errorf("Args = %v, want [42]", res.Args)
		}
	})
	t.Run("multiple params", func(t *testing.T) {
		vars := map[string]any{"id": 1, "name": "alice"}
		res := exec(t, "SELECT * FROM users WHERE id = #{id} AND name = #{name}", vars)
		if got := res.SQL; got != "SELECT * FROM users WHERE id = ? AND name = ?" {
			t.Errorf("SQL = %q", got)
		}
		if ok := len(res.Args) == 2 && res.Args[0] == 1 && res.Args[1] == "alice"; !ok {
			t.Errorf("Args = %v, want [1 alice]", res.Args)
		}
	})
	t.Run("nil value param", func(t *testing.T) {
		res := exec(t, "SELECT * FROM users WHERE name = #{name}", map[string]any{"name": nil})
		if got := res.SQL; got != "SELECT * FROM users WHERE name = ?" {
			t.Errorf("SQL = %q", got)
		}
		// A nil value must still occupy its argument slot rather than vanish.
		if ok := len(res.Args) == 1 && res.Args[0] == nil; !ok {
			t.Errorf("Args = %v, want [nil]", res.Args)
		}
	})
	t.Run("param with nested dot path", func(t *testing.T) {
		vars := map[string]any{"user": map[string]any{"name": "bob"}}
		res := exec(t, "SELECT * FROM users WHERE name = #{user.name}", vars)
		if got := res.SQL; got != "SELECT * FROM users WHERE name = ?" {
			t.Errorf("SQL = %q", got)
		}
		if ok := len(res.Args) == 1 && res.Args[0] == "bob"; !ok {
			t.Errorf("Args = %v, want [bob]", res.Args)
		}
	})
}
// ---------- 2. TestRaw ----------
func TestRaw(t *testing.T) {
t.Run("basic raw substitution", func(t *testing.T) {
r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"})
if r.SQL != "SELECT name FROM users" {
t.Errorf("SQL = %q, want %q", r.SQL, "SELECT name FROM users")
}
if len(r.Args) != 0 {
t.Errorf("Args = %v, want empty", r.Args)
}
})
t.Run("multiple raw params", func(t *testing.T) {
r := exec(t, "SELECT ${a}, ${b} FROM ${table}",
map[string]any{"a": "id", "b": "name", "table": "users"})
if r.SQL != "SELECT id, name FROM users" {
t.Errorf("SQL = %q", r.SQL)
}
if len(r.Args) != 0 {
t.Errorf("Args = %v, want empty", r.Args)
}
})
t.Run("allowlist pass", func(t *testing.T) {
policy := RawAllowlist{
"col": {"name", "age"},
}
r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"},
WithRawPolicy(policy))
if r.SQL != "SELECT name FROM users" {
t.Errorf("SQL = %q", r.SQL)
}
})
t.Run("allowlist reject", func(t *testing.T) {
policy := RawAllowlist{
"col": {"name", "age"},
}
tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col} FROM users")
_, err := tpl.Execute(map[string]any{"col": "password"})
if err == nil {
t.Fatal("expected error for blocked raw value, got nil")
}
var unsafeErr *UnsafeRawError
if !errors.As(err, &unsafeErr) {
t.Errorf("expected UnsafeRawError, got %T: %v", err, err)
}
})
t.Run("blocklist allow", func(t *testing.T) {
policy := RawBlocklist{
"col": {"password"},
}
r := exec(t, "SELECT ${col} FROM users", map[string]any{"col": "name"},
WithRawPolicy(policy))
if r.SQL != "SELECT name FROM users" {
t.Errorf("SQL = %q", r.SQL)
}
})
t.Run("blocklist reject", func(t *testing.T) {
policy := RawBlocklist{
"col": {"password", "secret"},
}
tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col} FROM users")
_, err := tpl.Execute(map[string]any{"col": "password"})
if err == nil {
t.Fatal("expected error for blocklisted value, got nil")
}
var unsafeErr *UnsafeRawError
if !errors.As(err, &unsafeErr) {
t.Errorf("expected UnsafeRawError, got %T: %v", err, err)
}
})
}
// ---------- 3. TestComment ----------

// TestComment checks that '#' comment lines, whether leading or in the
// middle of a template, are dropped from the rendered SQL. It also checks
// that a '#{...}' parameter marker is not mistaken for a comment.
func TestComment(t *testing.T) {
	t.Run("comment stripped", func(t *testing.T) {
		const want = "SELECT 1"
		if got := exec(t, "# this is a comment\nSELECT 1", nil).SQL; got != want {
			t.Errorf("SQL = %q, want %q", got, want)
		}
	})
	t.Run("comment in middle", func(t *testing.T) {
		const src, want = "SELECT 1\n# comment\nSELECT 2", "SELECT 1\nSELECT 2"
		if got := exec(t, src, nil).SQL; got != want {
			t.Errorf("SQL = %q, want %q", got, want)
		}
	})
	t.Run("no false positive with param", func(t *testing.T) {
		const want = "SELECT ? FROM users"
		res := exec(t, "SELECT #{id} FROM users", map[string]any{"id": 1})
		if res.SQL != want {
			t.Errorf("SQL = %q, want %q", res.SQL, want)
		}
		if !(len(res.Args) == 1 && res.Args[0] == 1) {
			t.Errorf("Args = %v", res.Args)
		}
	})
}
// ---------- 4. TestIf ----------

// TestIf checks @if(cond) { ... } blocks with an optional else { ... }.
// The body, and any #{...} args inside it, is emitted only when the
// condition holds; otherwise the else body is emitted. Conditions cover nil
// checks, comparison against "", && and ||.
func TestIf(t *testing.T) {
	t.Run("true condition", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }",
			map[string]any{"status": "active"})
		if !strings.Contains(r.SQL, "AND status = ?") {
			t.Errorf("SQL = %q, want to contain 'AND status = ?'", r.SQL)
		}
		if len(r.Args) != 1 || r.Args[0] != "active" {
			t.Errorf("Args = %v, want [active]", r.Args)
		}
	})
	t.Run("false condition nil", func(t *testing.T) {
		r := exec(t, "SELECT * FROM users WHERE 1=1 @if(status != nil) { AND status = #{status} }",
			map[string]any{"status": nil})
		if strings.Contains(r.SQL, "AND status") {
			t.Errorf("SQL = %q, should not contain 'AND status'", r.SQL)
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})
	t.Run("else branch", func(t *testing.T) {
		const src = "SELECT * FROM users ORDER BY id @if(desc) {DESC} else {ASC}"
		r := exec(t, src, map[string]any{"desc": true})
		if !strings.Contains(r.SQL, "DESC") {
			t.Errorf("SQL = %q, want DESC", r.SQL)
		}
		if strings.Contains(r.SQL, "ASC") {
			t.Errorf("SQL = %q, should not contain ASC", r.SQL)
		}
		r2 := exec(t, src, map[string]any{"desc": false})
		if !strings.Contains(r2.SQL, "ASC") {
			t.Errorf("SQL = %q, want ASC", r2.SQL)
		}
		// Mirror the negative check above so a template that emits both
		// branches is caught. "DESC" does not contain "ASC" as a substring,
		// so the two checks are independent.
		if strings.Contains(r2.SQL, "DESC") {
			t.Errorf("SQL = %q, should not contain DESC", r2.SQL)
		}
	})
	t.Run("empty string check", func(t *testing.T) {
		const src = "SELECT * FROM users WHERE 1=1 @if(name != \"\") { AND name = #{name} }"
		r := exec(t, src, map[string]any{"name": "alice"})
		if !strings.Contains(r.SQL, "AND name = ?") {
			t.Errorf("SQL = %q, want 'AND name = ?'", r.SQL)
		}
		if len(r.Args) != 1 || r.Args[0] != "alice" {
			t.Errorf("Args = %v, want [alice]", r.Args)
		}
		r2 := exec(t, src, map[string]any{"name": ""})
		if strings.Contains(r2.SQL, "AND name") {
			t.Errorf("SQL = %q, should not contain 'AND name'", r2.SQL)
		}
		// A skipped body must not leak its #{...} bindings.
		if len(r2.Args) != 0 {
			t.Errorf("Args = %v, want empty", r2.Args)
		}
	})
	t.Run("multiple independent conditions", func(t *testing.T) {
		r := exec(t, "@if(a != nil) {A} @if(b != nil) {B}",
			map[string]any{"a": 1, "b": 2})
		if !strings.Contains(r.SQL, "A") || !strings.Contains(r.SQL, "B") {
			t.Errorf("SQL = %q, want both A and B", r.SQL)
		}
		r2 := exec(t, "@if(a != nil) {A} @if(b != nil) {B}",
			map[string]any{"a": 1})
		if !strings.Contains(r2.SQL, "A") {
			t.Errorf("SQL = %q, want A", r2.SQL)
		}
		if strings.Contains(r2.SQL, "B") {
			t.Errorf("SQL = %q, should not contain B", r2.SQL)
		}
	})
	t.Run("AND operator", func(t *testing.T) {
		r := exec(t, "@if(a != nil && b != \"\") { both }",
			map[string]any{"a": 1, "b": "x"})
		if !strings.Contains(r.SQL, "both") {
			t.Errorf("SQL = %q, want 'both'", r.SQL)
		}
		r2 := exec(t, "@if(a != nil && b != \"\") { both }",
			map[string]any{"a": 1, "b": ""})
		if strings.Contains(r2.SQL, "both") {
			t.Errorf("SQL = %q, should not contain 'both'", r2.SQL)
		}
	})
	t.Run("OR operator", func(t *testing.T) {
		r := exec(t, "@if(a != nil || b != nil) { either }",
			map[string]any{"a": nil, "b": 1})
		if !strings.Contains(r.SQL, "either") {
			t.Errorf("SQL = %q, want 'either'", r.SQL)
		}
		r2 := exec(t, "@if(a != nil || b != nil) { either }",
			map[string]any{"a": nil, "b": nil})
		if strings.Contains(r2.SQL, "either") {
			t.Errorf("SQL = %q, should not contain 'either'", r2.SQL)
		}
	})
}
// ---------- 5. TestFor ----------

// TestFor checks @for(v range list) and @for(i, v range list) loops.
// The body is emitted once per element, and #{...} inside it binds the
// element or a field of it. A trailing ", " left by the last iteration is
// trimmed. An empty or nil list produces no placeholders and no Args.
func TestFor(t *testing.T) {
	const inSrc = "SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })"

	t.Run("basic IN clause", func(t *testing.T) {
		r := exec(t, inSrc, map[string]any{"ids": []int{1, 2, 3}})
		if len(r.Args) != 3 {
			t.Fatalf("Args = %v, want 3 args", r.Args)
		}
		for i, want := range []int{1, 2, 3} {
			if r.Args[i] != want {
				t.Errorf("Args[%d] = %v, want %v", i, r.Args[i], want)
			}
		}
	})
	t.Run("trailing comma auto-trim", func(t *testing.T) {
		r := exec(t, inSrc, map[string]any{"ids": []int{1, 2}})
		// The loop is followed by ")" in the template, so the SQL as a whole
		// can never end with ",". Inspect what precedes the closing
		// parenthesis instead; any whitespace between a stray comma and ")"
		// is ignored.
		end := strings.LastIndex(r.SQL, ")")
		if end < 0 {
			t.Fatalf("SQL = %q, missing closing ')'", r.SQL)
		}
		if inner := strings.TrimRight(r.SQL[:end], " \t\r\n"); strings.HasSuffix(inner, ",") {
			t.Errorf("SQL = %q, trailing comma not trimmed before ')'", r.SQL)
		}
		if n := strings.Count(r.SQL, "?"); n != 2 {
			t.Errorf("SQL = %q, want 2 placeholders, got %d", r.SQL, n)
		}
	})
	t.Run("empty list", func(t *testing.T) {
		r := exec(t, inSrc, map[string]any{"ids": []int{}})
		if strings.Contains(r.SQL, "?") {
			t.Errorf("SQL = %q, should have no placeholders", r.SQL)
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})
	t.Run("nil list", func(t *testing.T) {
		// A nil list is expected to behave like an empty slice: no output, no error.
		tpl := parse(t, inSrc)
		r, err := tpl.Execute(map[string]any{"ids": nil})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if strings.Contains(r.SQL, "?") {
			t.Errorf("SQL = %q, should have no placeholders", r.SQL)
		}
		if len(r.Args) != 0 {
			t.Errorf("Args = %v, want empty", r.Args)
		}
	})
	t.Run("with index variable", func(t *testing.T) {
		r := exec(t, "SELECT @for(i, v range items) {#{v}, }",
			map[string]any{"items": []string{"a", "b", "c"}})
		if len(r.Args) != 3 {
			t.Fatalf("Args = %v, want 3 args", r.Args)
		}
		// The bound values must be the elements (v), not the indices (i).
		for i, want := range []string{"a", "b", "c"} {
			if r.Args[i] != want {
				t.Errorf("Args[%d] = %v, want %v", i, r.Args[i], want)
			}
		}
	})
	t.Run("nested object access", func(t *testing.T) {
		r := exec(t, "SELECT @for(u range users) {#{u.name}, }",
			map[string]any{"users": []map[string]any{
				{"name": "alice"},
				{"name": "bob"},
			}})
		if len(r.Args) != 2 {
			t.Fatalf("Args = %v, want 2 args", r.Args)
		}
		if r.Args[0] != "alice" || r.Args[1] != "bob" {
			t.Errorf("Args = %v, want [alice bob]", r.Args)
		}
	})
}
// ---------- 6. TestTpl ----------

// TestTpl checks named @tpl("...") blocks. Each block runs independently
// through ExecuteBlock. Calling Execute on a template that defines blocks
// must return an *ExecError, and so must asking ExecuteBlock for a name
// that was never defined.
func TestTpl(t *testing.T) {
	t.Run("single block", func(t *testing.T) {
		res, err := parse(t, `@tpl("search") { SELECT * FROM users }`).ExecuteBlock("search", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock failed: %v", err)
		}
		if got := res.SQL; !strings.Contains(got, "SELECT * FROM users") {
			t.Errorf("SQL = %q", got)
		}
	})
	t.Run("multiple blocks", func(t *testing.T) {
		tpl := parse(t, `@tpl("search") { SELECT * FROM users } @tpl("count") { SELECT COUNT(*) FROM users }`)

		search, err := tpl.ExecuteBlock("search", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock search failed: %v", err)
		}
		if got := search.SQL; !strings.Contains(got, "SELECT * FROM users") {
			t.Errorf("search SQL = %q", got)
		}

		count, err := tpl.ExecuteBlock("count", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock count failed: %v", err)
		}
		if got := count.SQL; !strings.Contains(got, "SELECT COUNT(*) FROM users") {
			t.Errorf("count SQL = %q", got)
		}
	})
	t.Run("execute on tpl template should error", func(t *testing.T) {
		_, err := parse(t, `@tpl("search") { SELECT * FROM users }`).Execute(nil)
		if err == nil {
			t.Fatal("expected error when Execute is called on template with blocks")
		}
		var target *ExecError
		if !errors.As(err, &target) {
			t.Errorf("expected ExecError, got %T: %v", err, err)
		}
	})
	t.Run("ExecuteBlock with wrong name", func(t *testing.T) {
		_, err := parse(t, `@tpl("search") { SELECT * FROM users }`).ExecuteBlock("nonexistent", nil)
		if err == nil {
			t.Fatal("expected error for nonexistent block name")
		}
		var target *ExecError
		if !errors.As(err, &target) {
			t.Errorf("expected ExecError, got %T: %v", err, err)
		}
	})
}
// ---------- 7. TestNamespace ----------

// TestNamespace checks that blocks declared under @namespace("ns") are
// reachable as "ns.block" and also by their bare block name. It also checks
// that same-named blocks in different namespaces resolve to their own SQL.
func TestNamespace(t *testing.T) {
	const ordersSrc = `@namespace("orders") @tpl("search") { SELECT * FROM orders }`

	t.Run("namespaced block access", func(t *testing.T) {
		res, err := parse(t, ordersSrc).ExecuteBlock("orders.search", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock orders.search failed: %v", err)
		}
		if got := res.SQL; !strings.Contains(got, "SELECT * FROM orders") {
			t.Errorf("SQL = %q", got)
		}
	})
	t.Run("short name fallback", func(t *testing.T) {
		res, err := parse(t, ordersSrc).ExecuteBlock("search", nil)
		if err != nil {
			t.Fatalf("ExecuteBlock search failed: %v", err)
		}
		if got := res.SQL; !strings.Contains(got, "SELECT * FROM orders") {
			t.Errorf("SQL = %q", got)
		}
	})
	t.Run("same block names different namespaces", func(t *testing.T) {
		ordersTpl := parse(t, ordersSrc)
		usersTpl := parse(t, `@namespace("users") @tpl("search") { SELECT * FROM users }`)

		ordersRes, err := ordersTpl.ExecuteBlock("orders.search", nil)
		if err != nil {
			t.Fatalf("orders.search failed: %v", err)
		}
		if got := ordersRes.SQL; !strings.Contains(got, "orders") {
			t.Errorf("SQL = %q, want orders table", got)
		}

		usersRes, err := usersTpl.ExecuteBlock("users.search", nil)
		if err != nil {
			t.Fatalf("users.search failed: %v", err)
		}
		if got := usersRes.SQL; !strings.Contains(got, "users") {
			t.Errorf("SQL = %q, want users table", got)
		}
	})
}
// ---------- 8. TestPlaceholderStyles ----------

// TestPlaceholderStyles checks how #{...} markers are rendered: "?" by
// default, and 1-based positional "$n" or ":n" markers when the engine is
// built with WithPlaceholderStyle(DollarNumber) or
// WithPlaceholderStyle(ColonNumber).
func TestPlaceholderStyles(t *testing.T) {
	t.Run("QuestionMark default", func(t *testing.T) {
		res := exec(t, "WHERE id = #{id}", map[string]any{"id": 1})
		if !strings.Contains(res.SQL, "WHERE id = ?") {
			t.Errorf("SQL = %q, want 'WHERE id = ?'", res.SQL)
		}
	})

	// Both numbered styles render the same two-parameter template; only the
	// marker syntax differs.
	numbered := []struct {
		name string
		opt  Option
		want string
	}{
		{"DollarNumber", WithPlaceholderStyle(DollarNumber), "WHERE id = $1 AND name = $2"},
		{"ColonNumber", WithPlaceholderStyle(ColonNumber), "WHERE id = :1 AND name = :2"},
	}
	for _, tc := range numbered {
		t.Run(tc.name, func(t *testing.T) {
			res := exec(t, "WHERE id = #{id} AND name = #{name}",
				map[string]any{"id": 1, "name": "a"}, tc.opt)
			if !strings.Contains(res.SQL, tc.want) {
				t.Errorf("SQL = %q, want '%s'", res.SQL, tc.want)
			}
		})
	}
}
// ---------- 9. TestStrictMode ----------

// TestStrictMode checks WithStrictMode. When it is enabled, referencing an
// undefined variable makes Execute fail. When it is disabled, an undefined
// #{...} binds a nil argument and an undefined ${...} expands to nothing.
func TestStrictMode(t *testing.T) {
	t.Run("strict true undefined param", func(t *testing.T) {
		strict := New(WithStrictMode(true))
		tpl, parseErr := strict.Parse("test", "SELECT #{id}")
		if parseErr != nil {
			t.Fatalf("parse failed: %v", parseErr)
		}
		if _, execErr := tpl.Execute(map[string]any{}); execErr == nil {
			t.Fatal("expected error for undefined variable in strict mode")
		}
	})
	t.Run("strict false undefined param", func(t *testing.T) {
		res := exec(t, "SELECT #{id} FROM users", map[string]any{}, WithStrictMode(false))
		if got := res.SQL; !strings.Contains(got, "SELECT ? FROM users") {
			t.Errorf("SQL = %q, want 'SELECT ? FROM users'", got)
		}
		if !(len(res.Args) == 1 && res.Args[0] == nil) {
			t.Errorf("Args = %v, want [nil]", res.Args)
		}
	})
	t.Run("strict false undefined raw", func(t *testing.T) {
		res := exec(t, "SELECT ${col} FROM users", map[string]any{}, WithStrictMode(false))
		if want := "SELECT FROM users"; res.SQL != want {
			t.Errorf("SQL = %q, want 'SELECT FROM users'", res.SQL)
		}
		if len(res.Args) > 0 {
			t.Errorf("Args = %v, want empty", res.Args)
		}
	})
}
// ---------- 10. TestComplex ----------

// TestComplex exercises realistic combinations of directives:
//   - a conditional WHERE built on "1=1";
//   - a dynamic UPDATE whose SET items end in commas;
//   - a batch INSERT generated by @for;
//   - an ORDER BY direction chosen by @if/else.
func TestComplex(t *testing.T) {
	t.Run("conditional WHERE with 1=1", func(t *testing.T) {
		src := `SELECT * FROM users WHERE 1=1
@if(status != nil) { AND status = #{status} }
@if(name != "") { AND name = #{name} }`
		r := exec(t, src, map[string]any{"status": "active", "name": "alice"})
		if !strings.Contains(r.SQL, "AND status = ?") {
			t.Errorf("SQL missing status condition: %q", r.SQL)
		}
		if !strings.Contains(r.SQL, "AND name = ?") {
			t.Errorf("SQL missing name condition: %q", r.SQL)
		}
		if len(r.Args) != 2 {
			t.Errorf("Args = %v, want 2 args", r.Args)
		}
		r2 := exec(t, src, map[string]any{"status": nil, "name": ""})
		if strings.Contains(r2.SQL, "AND") {
			t.Errorf("SQL should have no AND: %q", r2.SQL)
		}
		if len(r2.Args) != 0 {
			t.Errorf("Args = %v, want empty", r2.Args)
		}
	})
	t.Run("dynamic UPDATE with trailing commas", func(t *testing.T) {
		src := `UPDATE users SET
@if(name != nil) { name = #{name}, }
@if(email != nil) { email = #{email}, }
WHERE id = #{id}`
		r := exec(t, src, map[string]any{"name": "alice", "email": "a@b.com", "id": 1})
		if !strings.Contains(r.SQL, "name = ?") {
			t.Errorf("SQL missing name: %q", r.SQL)
		}
		if !strings.Contains(r.SQL, "email = ?") {
			t.Errorf("SQL missing email: %q", r.SQL)
		}
		// Whatever whitespace separates the last SET item from WHERE, no
		// comma may remain in front of WHERE. Matching only the exact bytes
		// ",\nWHERE" would miss ", \nWHERE", which is equally invalid SQL.
		where := strings.Index(r.SQL, "WHERE")
		if where < 0 {
			t.Fatalf("SQL missing WHERE: %q", r.SQL)
		}
		if set := strings.TrimRight(r.SQL[:where], " \t\r\n"); strings.HasSuffix(set, ",") {
			t.Errorf("trailing comma not trimmed: %q", r.SQL)
		}
		if len(r.Args) != 3 {
			t.Errorf("Args = %v, want 3 args", r.Args)
		}
	})
	t.Run("batch INSERT with @for", func(t *testing.T) {
		src := `INSERT INTO users (name, age) VALUES
@for(u range users) { (#{u.name}, #{u.age}), }`
		r := exec(t, src, map[string]any{"users": []map[string]any{
			{"name": "alice", "age": 30},
			{"name": "bob", "age": 25},
		}})
		if len(r.Args) != 4 {
			t.Fatalf("Args = %v, want 4 args", r.Args)
		}
		if r.Args[0] != "alice" || r.Args[1] != 30 || r.Args[2] != "bob" || r.Args[3] != 25 {
			t.Errorf("Args = %v, want [alice 30 bob 25]", r.Args)
		}
		// The last row's trailing ", " must be trimmed, or the VALUES list
		// is invalid SQL.
		if strings.HasSuffix(strings.TrimRight(r.SQL, " \t\r\n"), ",") {
			t.Errorf("trailing comma not trimmed: %q", r.SQL)
		}
	})
	t.Run("ORDER BY with @if desc/asc", func(t *testing.T) {
		src := `SELECT * FROM users ORDER BY id @if(asc) {ASC} else {DESC}`
		r := exec(t, src, map[string]any{"asc": true})
		if !strings.Contains(r.SQL, "ASC") {
			t.Errorf("SQL = %q, want ASC", r.SQL)
		}
		if strings.Contains(r.SQL, "DESC") {
			t.Errorf("SQL = %q, should not contain DESC", r.SQL)
		}
		r2 := exec(t, src, map[string]any{"asc": false})
		if !strings.Contains(r2.SQL, "DESC") {
			t.Errorf("SQL = %q, want DESC", r2.SQL)
		}
		// "DESC" does not contain "ASC" as a substring, so this catches a
		// template that emits both branches.
		if strings.Contains(r2.SQL, "ASC") {
			t.Errorf("SQL = %q, should not contain ASC", r2.SQL)
		}
	})
}
// ---------- Benchmarks ----------

// BenchmarkSimple measures Execute on a flat template with two #{...}
// parameters under the default placeholder style. The template is parsed
// once before the timed loop.
func BenchmarkSimple(b *testing.B) {
	tpl := New().MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}")
	vars := map[string]any{"id": 42, "name": "alice"}
	// b.Loop resets the timer on its first call, so setup above is excluded
	// without an explicit ResetTimer. The error is checked so that a broken
	// template fails the benchmark instead of being timed as a fast error path.
	for b.Loop() {
		if _, err := tpl.Execute(vars); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkConditional measures Execute on a template with two @if blocks.
// Both conditions are true for the supplied vars, so both bodies are
// rendered.
func BenchmarkConditional(b *testing.B) {
	src := `SELECT * FROM users WHERE 1=1
@if(status != nil) { AND status = #{status} }
@if(name != "") { AND name = #{name} }
ORDER BY id`
	tpl := New().MustParse("bench", src)
	vars := map[string]any{"status": "active", "name": "alice"}
	// b.Loop excludes the setup above from timing. A failing Execute must
	// abort rather than be measured.
	for b.Loop() {
		if _, err := tpl.Execute(vars); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkLoop measures Execute on an IN-clause @for over a 10-element
// []int, producing one bound argument per element.
func BenchmarkLoop(b *testing.B) {
	src := `SELECT * FROM users WHERE id IN (@for(id range ids) {#{id}, })`
	tpl := New().MustParse("bench", src)
	ids := make([]int, 10)
	for i := range ids {
		ids[i] = i
	}
	vars := map[string]any{"ids": ids}
	// b.Loop excludes the setup above from timing. A failing Execute must
	// abort rather than be measured.
	for b.Loop() {
		if _, err := tpl.Execute(vars); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkPlaceholderDollar measures the same template as BenchmarkSimple
// but rendered with the DollarNumber ($1, $2, ...) placeholder style.
func BenchmarkPlaceholderDollar(b *testing.B) {
	eng := New(WithPlaceholderStyle(DollarNumber))
	tpl := eng.MustParse("bench", "SELECT * FROM users WHERE id = #{id} AND name = #{name}")
	vars := map[string]any{"id": 42, "name": "alice"}
	// b.Loop excludes the setup above from timing. A failing Execute must
	// abort rather than be measured.
	for b.Loop() {
		if _, err := tpl.Execute(vars); err != nil {
			b.Fatal(err)
		}
	}
}