diff --git a/engine.go b/engine.go index 9a07456..8d183af 100644 --- a/engine.go +++ b/engine.go @@ -1,6 +1,7 @@ package utpl import ( + "errors" "fmt" "gitea.1216.top/lxy/u-tpl/internal" @@ -114,8 +115,16 @@ func wrapParseError(err error, name string) error { if _, ok := err.(*ParseError); ok { return err } + msg := fmt.Sprintf("template %q: %s", name, err.Error()) + var posErr *internal.PosError + if errors.As(err, &posErr) { + return &ParseError{ + Pos: Position{Line: posErr.Line, Column: posErr.Col}, + Message: msg, + } + } return &ParseError{ - Pos: Position{Line: 0, Column: 0}, - Message: fmt.Sprintf("template %q: %s", name, err.Error()), + Pos: Position{}, + Message: msg, } } diff --git a/error.go b/error.go index d46d8c8..f63c2e9 100644 --- a/error.go +++ b/error.go @@ -14,11 +14,10 @@ func (p Position) String() string { type ParseError struct { Message string Pos Position - Token string } func (e ParseError) Error() string { - return fmt.Sprintf("%s: %s (token: %q)", e.Pos, e.Message, e.Token) + return fmt.Sprintf("%s: %s", e.Pos, e.Message) } type ExecError struct { diff --git a/internal/executor.go b/internal/executor.go index 4e01a8e..ee8a725 100644 --- a/internal/executor.go +++ b/internal/executor.go @@ -61,7 +61,7 @@ func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, arg val, ok := ctx.Get(n.Name) if !ok { if e.strict { - return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name) + return PosErrorf(n.Pos.Line, n.Pos.Col, "undefined variable %q", n.Name) } val = nil } @@ -72,7 +72,7 @@ func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, arg val, ok := ctx.Get(n.Name) if !ok { if e.strict { - return fmt.Errorf("line %d, col %d: undefined variable %q", n.Pos.Line, n.Pos.Col, n.Name) + return PosErrorf(n.Pos.Line, n.Pos.Col, "undefined variable %q", n.Name) } val = "" } @@ -103,11 +103,11 @@ func (e *Executor) walk(ctx *Context, ph *Placeholder, sql *strings.Builder, arg // skip case *UseNode: if e.blocks == nil { - return fmt.Errorf("line %d, col %d: @use(\"%s\") no blocks available", n.Pos.Line, n.Pos.Col, n.Name) + return PosErrorf(n.Pos.Line, n.Pos.Col, "@use(\"%s\") no blocks available", n.Name) } blockNodes, ok := e.blocks[n.Name] if !ok { - return fmt.Errorf("line %d, col %d: @use(\"%s\") block not found", n.Pos.Line, n.Pos.Col, n.Name) + return PosErrorf(n.Pos.Line, n.Pos.Col, "@use(\"%s\") block not found", n.Name) } if err := e.walk(ctx, ph, sql, args, blockNodes); err != nil { return err @@ -159,7 +159,7 @@ func (e *Executor) walkFor(ctx *Context, ph *Placeholder, sql *strings.Builder, case reflect.Slice, reflect.Array: length = rv.Len() default: - return fmt.Errorf("line %d, col %d: @for requires a slice or array, got %T", n.Pos.Line, n.Pos.Col, listVal) + return PosErrorf(n.Pos.Line, n.Pos.Col, "@for requires a slice or array, got %T", listVal) } for i := 0; i < length; i++ { diff --git a/internal/expr.go b/internal/expr.go index a772bf1..98d4b9d 100644 --- a/internal/expr.go +++ b/internal/expr.go @@ -127,7 +127,7 @@ func (p *ExprParser) parseUnary() (*Expr, error) { func (p *ExprParser) parsePrimary() (*Expr, error) { p.skipSpaces() if p.pos >= len(p.input) { - return nil, fmt.Errorf("line %d, col %d: unexpected end of expression", p.line, p.col) + return nil, PosErrorf(p.line, p.col, "unexpected end of expression") } ch := p.input[p.pos] @@ -152,12 +152,12 @@ func (p *ExprParser) parsePrimary() (*Expr, error) { } p.skipSpaces() if p.pos >= len(p.input) || p.input[p.pos] != ')' { - return nil, fmt.Errorf("line %d, col %d: expected ')'", p.line, p.col) + return nil, PosErrorf(p.line, p.col, "expected ')'") } p.skip(1) return expr, nil } - return nil, fmt.Errorf("line %d, col %d: unexpected character %q", p.line, p.col, string(ch)) + return nil, PosErrorf(p.line, p.col, "unexpected character %q", string(ch)) } func (p *ExprParser) parseIdentOrKeyword() (*Expr, error) { @@ -186,7 +186,7 @@ func (p *ExprParser) parseIdentOrKeyword() (*Expr, error) { for p.pos < len(p.input) && p.input[p.pos] == '.' { p.skip(1) if p.pos >= len(p.input) || !isIdentStart(p.input[p.pos]) { - return nil, fmt.Errorf("line %d, col %d: expected identifier after '.'", p.line, p.col) + return nil, PosErrorf(p.line, p.col, "expected identifier after '.'") } segStart := p.pos for p.pos < len(p.input) && isIdentPart(p.input[p.pos]) { @@ -221,7 +221,7 @@ func (p *ExprParser) parseFuncCall(name string) (*Expr, error) { } p.skipSpaces() if p.pos >= len(p.input) || p.input[p.pos] != ')' { - return nil, fmt.Errorf("line %d, col %d: expected ')' after function call", p.line, p.col) + return nil, PosErrorf(p.line, p.col, "expected ')' after function call") } p.skip(1) return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprFuncCall, FuncName: name, FuncArgs: args}, nil @@ -252,7 +252,7 @@ func (p *ExprParser) parseStringLit() (*Expr, error) { p.pos++ } if p.pos >= len(p.input) { - return nil, fmt.Errorf("line %d, col %d: unterminated string", p.line, p.col) + return nil, PosErrorf(p.line, p.col, "unterminated string") } p.skip(1) return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: buf.String()}, nil @@ -266,7 +266,7 @@ func (p *ExprParser) parseSingleQuoteStringLit() (*Expr, error) { p.pos++ } if p.pos >= len(p.input) { - return nil, fmt.Errorf("line %d, col %d: unterminated string", p.line, p.col) + return nil, PosErrorf(p.line, p.col, "unterminated string") } p.skip(1) return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: buf.String()}, nil @@ -292,13 +292,13 @@ func (p *ExprParser) parseNumberLit() (*Expr, error) { if isFloat { v, err := strconv.ParseFloat(text, 64) if err != nil { - return nil, fmt.Errorf("line %d, col %d: invalid number %q", p.line, p.col, text) + return nil, PosErrorf(p.line, p.col, "invalid number %q", text) } return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: v}, nil } v, err := strconv.ParseInt(text, 10, 64) if err != nil { - return nil, fmt.Errorf("line %d, col %d: invalid number %q", p.line, p.col, text) + return nil, PosErrorf(p.line, p.col, "invalid number %q", text) } return &Expr{Pos: Pos{Line: p.line, Col: p.col}, ExprType: ExprLiteral, Value: int(v)}, nil } @@ -359,7 +359,7 @@ func Eval(expr *Expr, ctx *Context) (any, error) { case ExprFuncCall: return evalFuncCall(expr, ctx) default: - return nil, fmt.Errorf("line %d, col %d: unknown expression type", expr.Pos.Line, expr.Pos.Col) + return nil, PosErrorf(expr.Pos.Line, expr.Pos.Col, "unknown expression type") } } @@ -372,7 +372,7 @@ func evalUnary(expr *Expr, ctx *Context) (any, error) { case "!": return !isTruthy(val), nil default: - return nil, fmt.Errorf("line %d, col %d: unknown unary operator %q", expr.Pos.Line, expr.Pos.Col, expr.UnaryOp) + return nil, PosErrorf(expr.Pos.Line, expr.Pos.Col, "unknown unary operator %q", expr.UnaryOp) } } @@ -421,14 +421,14 @@ func evalBinary(expr *Expr, ctx *Context) (any, error) { case ">=": return compareOrder(left, right, expr.Op) default: - return nil, fmt.Errorf("line %d, col %d: unknown operator %q", expr.Pos.Line, expr.Pos.Col, expr.Op) + return nil, PosErrorf(expr.Pos.Line, expr.Pos.Col, "unknown operator %q", expr.Op) } } func evalFuncCall(expr *Expr, ctx *Context) (any, error) { fn, ok := LookupBuiltin(expr.FuncName) if !ok { - return nil, fmt.Errorf("line %d, col %d: unknown function %q", expr.Pos.Line, expr.Pos.Col, expr.FuncName) + return nil, PosErrorf(expr.Pos.Line, expr.Pos.Col, "unknown function %q", expr.FuncName) } var args []any for _, a := range expr.FuncArgs { @@ -440,7 +440,7 @@ func evalFuncCall(expr *Expr, ctx *Context) (any, error) { } result, ok := fn(args) if !ok { - return nil, fmt.Errorf("line %d, col %d: function %q call failed", expr.Pos.Line, expr.Pos.Col, expr.FuncName) + return nil, PosErrorf(expr.Pos.Line, expr.Pos.Col, "function %q call failed", expr.FuncName) } return result, nil } diff --git a/internal/parser.go b/internal/parser.go index 5114c22..8bd436a 100644 --- a/internal/parser.go +++ b/internal/parser.go @@ -3,6 +3,7 @@ package internal import ( "fmt" "strings" + "unicode/utf8" ) type Parser struct { @@ -39,7 +40,7 @@ func (p *Parser) Parse() ([]Node, error) { continue } if hasTpl && !isWhitespace(tok.Value) { - return nil, fmt.Errorf("line %d, col %d: top-level text is not allowed in templates with @tpl blocks", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "top-level text is not allowed in templates with @tpl blocks") } nodes = append(nodes, &TextNode{Pos: tok.Pos, Text: tok.Value}) @@ -96,7 +97,7 @@ func (p *Parser) Parse() ([]Node, error) { } nodes = append(nodes, node) - case TokUseStart: + case TokUseStart: node, err := p.parseUse(tok) if err != nil { return nil, err @@ -104,7 +105,7 @@ func (p *Parser) Parse() ([]Node, error) { nodes = append(nodes, node) case TokElse: - return nil, fmt.Errorf("line %d, col %d: unexpected else", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unexpected else") case TokComment: p.pos++ @@ -248,7 +249,7 @@ func (p *Parser) consumeTokensForRuneRange(_, runeEnd int) { // Text tokens store Pos as the end position; other tokens store Pos as the start. tokRuneStart := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col) if tok.Type == TokText { - tokRuneStart -= len(tok.Value) + tokRuneStart -= utf8.RuneCountInString(tok.Value) } if tokRuneStart >= runeEnd { break @@ -256,11 +257,12 @@ func (p *Parser) consumeTokensForRuneRange(_, runeEnd int) { // For text tokens, check if the token extends past runeEnd. // If so, split it: keep the remainder as a new text token. if tok.Type == TokText && len(tok.Value) > 0 { - tokRuneEnd := tokRuneStart + len(tok.Value) + runes := []rune(tok.Value) + tokRuneEnd := tokRuneStart + len(runes) if tokRuneEnd > runeEnd { overlap := runeEnd - tokRuneStart - if overlap > 0 && overlap < len(tok.Value) { - remainder := tok.Value[overlap:] + if overlap > 0 && overlap < len(runes) { + remainder := string(runes[overlap:]) // Replace current token with the remainder p.tokens[p.pos] = Token{ Type: TokText, @@ -279,11 +281,11 @@ func (p *Parser) parseParam(tok Token) (*ParamNode, error) { runePos := p.runePosFromToken(tok) content, endPos, err := p.readUntilBrace(runePos) if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated param, expected '}'", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated param, expected '}'") } content = strings.TrimSpace(content) if content == "" { - return nil, fmt.Errorf("line %d, col %d: empty param name", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "empty param name") } p.consumeTokensForRuneRange(runePos, endPos+1) return &ParamNode{Pos: tok.Pos, Name: content}, nil @@ -293,11 +295,11 @@ func (p *Parser) parseRaw(tok Token) (*RawNode, error) { runePos := p.runePosFromToken(tok) content, endPos, err := p.readUntilBrace(runePos) if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated raw, expected '}'", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated raw, expected '}'") } content = strings.TrimSpace(content) if content == "" { - return nil, fmt.Errorf("line %d, col %d: empty raw name", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "empty raw name") } p.consumeTokensForRuneRange(runePos, endPos+1) return &RawNode{Pos: tok.Pos, Name: content}, nil @@ -307,7 +309,7 @@ func (p *Parser) parseIf(tok Token) (*IfNode, error) { runePos := p.runePosFromToken(tok) exprStr, parenClose, err := p.readUntilParen(runePos) if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated @if, expected ')'", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated @if, expected ')'") } expr, err := NewExprParser(strings.TrimSpace(exprStr), tok.Pos.Line, tok.Pos.Col).Parse() @@ -318,7 +320,7 @@ func (p *Parser) parseIf(tok Token) (*IfNode, error) { // Find '{' after ')' in raw input braceOpen := findChar(p.input[parenClose+1:], '{') if braceOpen < 0 { - return nil, fmt.Errorf("line %d, col %d: expected '{' after @if condition", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "expected '{' after @if condition") } braceOpen = parenClose + 1 + braceOpen @@ -344,12 +346,12 @@ func (p *Parser) parseElseIfBranch(tok Token) (*ElseIfBranch, error) { runePos := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col) idx := strings.Index(tok.Value, prefix) if idx < 0 { - return nil, fmt.Errorf("line %d, col %d: expected @elseif(", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "expected @elseif(") } exprStart := runePos + idx + len(prefix) exprStr, closePos, err := p.readUntilParen(exprStart) if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated @elseif, expected ')'", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated @elseif, expected ')'") } expr, err := NewExprParser(strings.TrimSpace(exprStr), tok.Pos.Line, tok.Pos.Col).Parse() @@ -376,19 +378,19 @@ func (p *Parser) parseFor(tok Token) (*ForNode, error) { runePos := p.runePosFromToken(tok) content, parenClose, err := p.readUntilParen(runePos) if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated @for, expected ')'", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated @for, expected ')'") } braceOpen := findChar(p.input[parenClose+1:], '{') if braceOpen < 0 { - return nil, fmt.Errorf("line %d, col %d: expected '{' after @for", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "expected '{' after @for") } braceOpen = parenClose + 1 + braceOpen p.consumeTokensForRuneRange(runePos, braceOpen+1) keyVar, valVar, listExprStr, err := parseForHeader(content) if err != nil { - return nil, fmt.Errorf("line %d, col %d: %s", tok.Pos.Line, tok.Pos.Col, err.Error()) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "%s", err.Error()) } listExpr, err := NewExprParser(strings.TrimSpace(listExprStr), tok.Pos.Line, tok.Pos.Col).Parse() @@ -451,14 +453,14 @@ func (p *Parser) parseTpl(tok Token, tplNames map[string]bool) (*BlockNode, erro runePos := p.runePosFromToken(tok) name, quotePos, err := p.readUntilQuote(runePos) if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated @tpl name", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated @tpl name") } name = strings.TrimSpace(name) if name == "" { - return nil, fmt.Errorf("line %d, col %d: empty @tpl block name", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "empty @tpl block name") } if tplNames[name] { - return nil, fmt.Errorf("line %d, col %d: duplicate @tpl block name %q", tok.Pos.Line, tok.Pos.Col, name) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "duplicate @tpl block name %q", name) } tplNames[name] = true @@ -470,7 +472,7 @@ func (p *Parser) parseTpl(tok Token, tplNames map[string]bool) (*BlockNode, erro // find '{' and consume tokens up to past it braceOpen := findChar(p.input[endPos:], '{') if braceOpen < 0 { - return nil, fmt.Errorf("line %d, col %d: expected '{' after @tpl", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "expected '{' after @tpl") } braceOpen = endPos + braceOpen p.consumeTokensForRuneRange(runePos, braceOpen+1) @@ -488,23 +490,11 @@ func (p *Parser) parseTpl(tok Token, tplNames map[string]bool) (*BlockNode, erro } func (p *Parser) parseInclude(tok Token) (*IncludeNode, error) { - runePos := p.runePosFromToken(tok) - path, quotePos, err := p.readUntilQuote(runePos) + name, err := p.parseQuotedName(tok, "@include path") if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated @include path", tok.Pos.Line, tok.Pos.Col) + return nil, err } - path = strings.TrimSpace(path) - - endPos := quotePos + 1 - if endPos < len(p.input) && p.input[endPos] == ')' { - endPos++ - } - p.consumeTokensForRuneRange(runePos, endPos) - - return &IncludeNode{ - Pos: tok.Pos, - Path: path, - }, nil + return &IncludeNode{Pos: tok.Pos, Path: name}, nil } func (p *Parser) expandInclude(tok Token) ([]Node, error) { @@ -513,11 +503,11 @@ func (p *Parser) expandInclude(tok Token) ([]Node, error) { return nil, err } if p.includeMgr == nil { - return nil, fmt.Errorf("line %d, col %d: @include used but no include resolver configured", incNode.Pos.Line, incNode.Pos.Col) + return nil, PosErrorf(incNode.Pos.Line, incNode.Pos.Col, "@include used but no include resolver configured") } expanded, err := p.includeMgr.Resolve(incNode.Path) if err != nil { - return nil, fmt.Errorf("line %d, col %d: %s", incNode.Pos.Line, incNode.Pos.Col, err.Error()) + return nil, PosErrorf(incNode.Pos.Line, incNode.Pos.Col, "%s", err.Error()) } subLexer := NewLexer(expanded) subTokens, err := subLexer.Tokenize() @@ -528,11 +518,13 @@ func (p *Parser) expandInclude(tok Token) ([]Node, error) { return subParser.Parse() } -func (p *Parser) parseUse(tok Token) (*UseNode, error) { +// parseQuotedName reads a quoted name from the directive token position. +// Shared by @include, @use, and @namespace. +func (p *Parser) parseQuotedName(tok Token, desc string) (string, error) { runePos := p.runePosFromToken(tok) name, quotePos, err := p.readUntilQuote(runePos) if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated @use name", tok.Pos.Line, tok.Pos.Col) + return "", PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated %s", desc) } name = strings.TrimSpace(name) @@ -542,30 +534,26 @@ func (p *Parser) parseUse(tok Token) (*UseNode, error) { } p.consumeTokensForRuneRange(runePos, endPos) - return &UseNode{ - Pos: tok.Pos, - Name: name, - }, nil + return name, nil +} + +func (p *Parser) parseUse(tok Token) (*UseNode, error) { + name, err := p.parseQuotedName(tok, "@use name") + if err != nil { + return nil, err + } + return &UseNode{Pos: tok.Pos, Name: name}, nil } func (p *Parser) parseNamespace(tok Token, nodeCount int) (*NamespaceNode, error) { if nodeCount > 0 { - return nil, fmt.Errorf("line %d, col %d: @namespace must be at the top of the file", tok.Pos.Line, tok.Pos.Col) + return nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "@namespace must be at the top of the file") } - runePos := p.runePosFromToken(tok) - name, quotePos, err := p.readUntilQuote(runePos) + name, err := p.parseQuotedName(tok, "@namespace name") if err != nil { - return nil, fmt.Errorf("line %d, col %d: unterminated @namespace name", tok.Pos.Line, tok.Pos.Col) + return nil, err } - name = strings.TrimSpace(name) - - endPos := quotePos + 1 - if endPos < len(p.input) && p.input[endPos] == ')' { - endPos++ - } - p.consumeTokensForRuneRange(runePos, endPos) - return &NamespaceNode{ Pos: tok.Pos, Name: name, @@ -584,7 +572,7 @@ func (p *Parser) skipToBraceOpen() { startRune := runePosFromLineCol(p.input, tok.Pos.Line, tok.Pos.Col) // Text tokens store Pos as the end position; adjust to start. if tok.Type == TokText { - startRune -= len(tok.Value) + startRune -= utf8.RuneCountInString(tok.Value) } // Find '{' in raw input @@ -649,7 +637,7 @@ func (p *Parser) parseBlockBodyWithElse(blockType string) ([]Node, []Node, []*El for p.pos < len(p.tokens) { tok := p.cur() if tok.Type == TokEOF { - return nil, nil, nil, fmt.Errorf("line %d, col %d: unterminated %s block", tok.Pos.Line, tok.Pos.Col, blockType) + return nil, nil, nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "unterminated %s block", blockType) } if tok.Type == TokText { @@ -715,21 +703,21 @@ func (p *Parser) parseBlockBodyWithElse(blockType string) ([]Node, []Node, []*El } body = append(body, node) case TokTplStart: - return nil, nil, nil, fmt.Errorf("line %d, col %d: @tpl blocks cannot be nested", tok.Pos.Line, tok.Pos.Col) + return nil, nil, nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "@tpl blocks cannot be nested") case TokIncludeStart: subNodes, err := p.expandInclude(tok) if err != nil { return nil, nil, nil, err } body = append(body, subNodes...) - case TokUseStart: + case TokUseStart: node, err := p.parseUse(tok) if err != nil { return nil, nil, nil, err } body = append(body, node) case TokNamespaceStart: - return nil, nil, nil, fmt.Errorf("line %d, col %d: @namespace must be at file top level", tok.Pos.Line, tok.Pos.Col) + return nil, nil, nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "@namespace must be at file top level") } } @@ -782,7 +770,7 @@ func (p *Parser) handleTokElse(tok Token, body []Node) ([]Node, []Node, []*ElseI } // It's a plain else — find '{' in text tokens and split if !p.splitTextAtBrace() { - return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '{' after else", tok.Pos.Line, tok.Pos.Col) + return nil, nil, nil, PosErrorf(tok.Pos.Line, tok.Pos.Col, "expected '{' after else") } elseBody, _, _, err := p.parseBlockBodyWithElse("else") if err != nil { @@ -854,7 +842,7 @@ func (p *Parser) parseElseIfFromText(body []Node, next Token) ([]Node, []Node, [ // Find "if" and skip to the opening paren ifIdx := strings.Index(trimmed, "if") if ifIdx < 0 { - return nil, nil, nil, fmt.Errorf("line %d, col %d: expected 'if' in else-if condition", next.Pos.Line, next.Pos.Col) + return nil, nil, nil, PosErrorf(next.Pos.Line, next.Pos.Col, "expected 'if' in else-if condition") } afterIf := trimmed[ifIdx+2:] // skip whitespace between "if" and "(" @@ -863,12 +851,12 @@ func (p *Parser) parseElseIfFromText(body []Node, next Token) ([]Node, []Node, [ parenLocalOffset++ } if parenLocalOffset >= len(afterIf) || afterIf[parenLocalOffset] != '(' { - return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '(' after 'else if'", next.Pos.Line, next.Pos.Col) + return nil, nil, nil, PosErrorf(next.Pos.Line, next.Pos.Col, "expected '(' after 'else if'") } // Compute rune position of '(' in the raw input. - // The text token's Pos marks the END of the token text; the start is Pos - len(Value). - runePos := runePosFromLineCol(p.input, next.Pos.Line, next.Pos.Col) - len(next.Value) + // The text token's Pos marks the END of the token text; the start is Pos - rune count. + runePos := runePosFromLineCol(p.input, next.Pos.Line, next.Pos.Col) - utf8.RuneCountInString(next.Value) // Walk through next.Value byte-by-byte to find the '(' that corresponds to the condition. // We know trimmed starts at some offset into next.Value; find "if" in the raw value. @@ -878,7 +866,7 @@ func (p *Parser) parseElseIfFromText(body []Node, next Token) ([]Node, []Node, [ rawParenOffset++ } if rawParenOffset >= len(next.Value) || next.Value[rawParenOffset] != '(' { - return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '(' after 'else if'", next.Pos.Line, next.Pos.Col) + return nil, nil, nil, PosErrorf(next.Pos.Line, next.Pos.Col, "expected '(' after 'else if'") } // parenRunePos points to '(' in the raw input @@ -886,7 +874,7 @@ func (p *Parser) parseElseIfFromText(body []Node, next Token) ([]Node, []Node, [ // readUntilParen expects the position AFTER '(' — it reads content between start and closing ')' exprStr, parenClose, err := p.readUntilParen(parenRunePos + 1) if err != nil { - return nil, nil, nil, fmt.Errorf("line %d, col %d: unterminated else-if condition, expected ')'", next.Pos.Line, next.Pos.Col) + return nil, nil, nil, PosErrorf(next.Pos.Line, next.Pos.Col, "unterminated else-if condition, expected ')'") } expr, err := NewExprParser(strings.TrimSpace(exprStr), next.Pos.Line, next.Pos.Col).Parse() @@ -900,7 +888,7 @@ func (p *Parser) parseElseIfFromText(body []Node, next Token) ([]Node, []Node, [ // Find '{' after ')' in raw input braceOpen := findChar(p.input[parenClose+1:], '{') if braceOpen < 0 { - return nil, nil, nil, fmt.Errorf("line %d, col %d: expected '{' after else-if condition", next.Pos.Line, next.Pos.Col) + return nil, nil, nil, PosErrorf(next.Pos.Line, next.Pos.Col, "expected '{' after else-if condition") } braceOpen = parenClose + 1 + braceOpen diff --git a/internal/poserror.go b/internal/poserror.go new file mode 100644 index 0000000..c8799f1 --- /dev/null +++ b/internal/poserror.go @@ -0,0 +1,19 @@ +package internal + +import "fmt" + +// PosError is an internal error type that carries source position information. +// Lexer and parser return this so that the public wrapParseError can extract the position. +type PosError struct { + Line int + Col int + Message string +} + +func (e *PosError) Error() string { + return fmt.Sprintf("line %d, col %d: %s", e.Line, e.Col, e.Message) +} + +func PosErrorf(line, col int, format string, args ...any) *PosError { + return &PosError{Line: line, Col: col, Message: fmt.Sprintf(format, args...)} +} diff --git a/template.go b/template.go index d6fb7aa..39e379a 100644 --- a/template.go +++ b/template.go @@ -81,5 +81,12 @@ func wrapExecError(err error) error { if errors.As(err, &unsafeErr) { return err } + var posErr *internal.PosError + if errors.As(err, &posErr) { + return &ExecError{ + Pos: Position{Line: posErr.Line, Column: posErr.Col}, + Message: err.Error(), + } + } return &ExecError{Message: err.Error()} } diff --git a/utpl_gap_test.go b/utpl_gap_test.go new file mode 100644 index 0000000..4034369 --- /dev/null +++ b/utpl_gap_test.go @@ -0,0 +1,488 @@ +package utpl + +import ( + "errors" + "strings" + "sync" + "testing" +) + +// ---------- TestErrorPositions ---------- +// Verify that ParseError.Pos and ExecError.Pos contain correct line/column values. + +func TestErrorPositions(t *testing.T) { + t.Run("ParseError has correct position for unterminated param", func(t *testing.T) { + // #{id on line 1, col 8 (1-indexed) + _, err := New().Parse("test", "SELECT #{id") + var pe *ParseError + if !errors.As(err, &pe) { + t.Fatalf("expected ParseError, got %T: %v", err, err) + } + if pe.Pos.Line == 0 || pe.Pos.Column == 0 { + t.Errorf("ParseError.Pos = {Line:%d, Col:%d}, expected non-zero", pe.Pos.Line, pe.Pos.Column) + } + if pe.Pos.Line != 1 { + t.Errorf("ParseError.Pos.Line = %d, want 1", pe.Pos.Line) + } + if pe.Pos.Column != 8 { + t.Errorf("ParseError.Pos.Column = %d, want 8", pe.Pos.Column) + } + }) + + t.Run("ParseError has correct position for unterminated if on line 2", func(t *testing.T) { + src := "SELECT 1\n@if(x > 0" + _, err := New().Parse("test", src) + var pe *ParseError + if !errors.As(err, &pe) { + t.Fatalf("expected ParseError, got %T: %v", err, err) + } + if pe.Pos.Line != 2 { + t.Errorf("ParseError.Pos.Line = %d, want 2", pe.Pos.Line) + } + }) + + t.Run("ParseError has correct position for unterminated for on line 3", func(t *testing.T) { + src := "SELECT 1\n\n@for(x, range list)" + _, err := New().Parse("test", src) + var pe *ParseError + if !errors.As(err, &pe) { + t.Fatalf("expected ParseError, got %T: %v", err, err) + } + if pe.Pos.Line != 3 { + t.Errorf("ParseError.Pos.Line = %d, want 3", pe.Pos.Line) + } + }) + + t.Run("ExecError has correct position for undefined variable in strict mode", func(t *testing.T) { + tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT #{missing}") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + _, err = tpl.Execute(map[string]any{}) + var ee *ExecError + if !errors.As(err, &ee) { + t.Fatalf("expected ExecError, got %T: %v", err, err) + } + if ee.Pos.Line == 0 || ee.Pos.Column == 0 { + t.Errorf("ExecError.Pos = {Line:%d, Col:%d}, expected non-zero", ee.Pos.Line, ee.Pos.Column) + } + }) + + t.Run("ExecError has correct position for undefined variable on line 2", func(t *testing.T) { + tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT 1\nWHERE id = #{missing}") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + _, err = tpl.Execute(map[string]any{}) + var ee *ExecError + if !errors.As(err, &ee) { + t.Fatalf("expected ExecError, got %T: %v", err, err) + } + if ee.Pos.Line != 2 { + t.Errorf("ExecError.Pos.Line = %d, want 2", ee.Pos.Line) + } + }) + + t.Run("ExecError has correct position for non-slice in @for", func(t *testing.T) { + tpl, err := New().Parse("test", "SELECT @for(x, range items) {#{x}}") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + _, err = tpl.Execute(map[string]any{"items": "not a slice"}) + var ee *ExecError + if !errors.As(err, &ee) { + t.Fatalf("expected ExecError, got %T: %v", err, err) + } + if ee.Pos.Line == 0 || ee.Pos.Column == 0 { + t.Errorf("ExecError.Pos = {Line:%d, Col:%d}, expected non-zero", ee.Pos.Line, ee.Pos.Column) + } + }) +} + +// ---------- TestRawDefault ---------- +// Verify @raw behavior when no policy is configured (default nil). + +func TestRawDefault(t *testing.T) { + t.Run("raw with no policy substitutes directly", func(t *testing.T) { + // No WithRawPolicy — rawPolicy is nil, should skip validation + r := exec(t, "SELECT ${col} FROM ${table}", map[string]any{"col": "name", "table": "users"}) + if r.SQL != "SELECT name FROM users" { + t.Errorf("SQL = %q, want 'SELECT name FROM users'", r.SQL) + } + if len(r.Args) != 0 { + t.Errorf("Args = %v, want empty", r.Args) + } + }) + + t.Run("raw with nil value and no policy", func(t *testing.T) { + r := exec(t, "SELECT ${col}", map[string]any{"col": nil}) + if r.SQL != "SELECT " { + t.Errorf("SQL = %q, want 'SELECT '", r.SQL) + } + }) +} + +// ---------- TestRawNoop ---------- +// Verify RawNoop policy. + +func TestRawNoop(t *testing.T) { + t.Run("RawNoop allows all values", func(t *testing.T) { + r := exec(t, "SELECT ${col} FROM ${table}", map[string]any{"col": "name", "table": "users"}, + WithRawPolicy(RawNoop{})) + if r.SQL != "SELECT name FROM users" { + t.Errorf("SQL = %q, want 'SELECT name FROM users'", r.SQL) + } + }) +} + +// ---------- TestStructFieldAccess ---------- +// Verify that context resolves struct fields (case-insensitive). + +func TestStructFieldAccess(t *testing.T) { + type User struct { + Name string + Age int + } + + t.Run("struct field access via dot path", func(t *testing.T) { + r := exec(t, "SELECT #{u.Name}, #{u.Age}", + map[string]any{"u": User{Name: "alice", Age: 30}}) + if r.SQL != "SELECT ?, ?" { + t.Errorf("SQL = %q, want 'SELECT ?, ?'", r.SQL) + } + if len(r.Args) != 2 || r.Args[0] != "alice" || r.Args[1] != 30 { + t.Errorf("Args = %v, want [alice 30]", r.Args) + } + }) + + t.Run("struct field case-insensitive access", func(t *testing.T) { + r := exec(t, "SELECT #{u.name}", + map[string]any{"u": User{Name: "bob", Age: 25}}) + if len(r.Args) != 1 || r.Args[0] != "bob" { + t.Errorf("Args = %v, want [bob]", r.Args) + } + }) + + t.Run("struct field in @if condition", func(t *testing.T) { + r := exec(t, "@if(u.Active) {active}", + map[string]any{"u": struct{ Active bool }{Active: true}}) + if !strings.Contains(r.SQL, "active") { + t.Errorf("SQL = %q, want 'active'", r.SQL) + } + }) +} + +// ---------- TestConcurrentExecution ---------- +// Verify Template.Execute is safe for concurrent use. + +func TestConcurrentExecution(t *testing.T) { + tpl := parse(t, "SELECT * FROM users WHERE id = #{id} AND status = #{status}") + + var wg sync.WaitGroup + errs := make(chan error, 100) + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, err := tpl.Execute(map[string]any{"id": i, "status": "active"}) + if err != nil { + errs <- err + } + }(i) + } + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent execution error: %v", err) + } +} + +// ---------- TestMustParse ---------- +// Verify MustParse panics on invalid input. + +func TestMustParse(t *testing.T) { + t.Run("MustParse panics on invalid template", func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected MustParse to panic on invalid template") + } + }() + New().MustParse("test", "SELECT #{id") // unterminated param + }) + + t.Run("MustParse succeeds on valid template", func(t *testing.T) { + tpl := New().MustParse("test", "SELECT #{id}") + if tpl == nil { + t.Fatal("MustParse returned nil on valid template") + } + }) +} + +// ---------- TestInclude ---------- +// Additional include tests. + +func TestIncludeCircular(t *testing.T) { + t.Run("circular include returns error", func(t *testing.T) { + callCount := 0 + resolver := func(path string) (string, error) { + callCount++ + if callCount > 10 { + return "", errors.New("too many includes") + } + return `@include("other")`, nil + } + _, err := New(WithIncludeResolver(resolver)).Parse("test", `@include("a")`) + if err == nil { + t.Fatal("expected error for circular include, got nil") + } + }) +} + +func TestIncludeNested(t *testing.T) { + t.Run("include within include expands both", func(t *testing.T) { + resolver := func(path string) (string, error) { + files := map[string]string{ + "outer": "OUTER @include(\"inner\")", + "inner": "INNER", + } + src, ok := files[path] + if !ok { + return "", errors.New("not found") + } + return src, nil + } + tpl, err := New(WithIncludeResolver(resolver)).Parse("test", "@include(\"outer\")") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + r, err := tpl.Execute(nil) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + if !strings.Contains(r.SQL, "OUTER") { + t.Errorf("SQL = %q, want OUTER", r.SQL) + } + if !strings.Contains(r.SQL, "INNER") { + t.Errorf("SQL = %q, want INNER", r.SQL) + } + }) +} + +// ---------- TestNilVars ---------- +// Verify behavior when nil map is passed to Execute. + +func TestNilVars(t *testing.T) { + t.Run("nil vars map with strict mode", func(t *testing.T) { + tpl, err := New(WithStrictMode(true)).Parse("test", "SELECT #{id}") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + _, err = tpl.Execute(nil) + if err == nil { + t.Fatal("expected error for nil vars with strict mode and #{id}") + } + }) + + t.Run("nil vars map with non-strict mode", func(t *testing.T) { + r := exec(t, "SELECT 1", nil, WithStrictMode(false)) + if r.SQL != "SELECT 1" { + t.Errorf("SQL = %q, want 'SELECT 1'", r.SQL) + } + }) +} + +// ---------- TestForTypedSlices ---------- +// Verify @for works with typed slices. + +func TestForTypedSlices(t *testing.T) { + t.Run("for with []float64", func(t *testing.T) { + r := exec(t, "SELECT @for(v range items) {#{v}, }", + map[string]any{"items": []float64{1.5, 2.5, 3.5}}) + if len(r.Args) != 3 { + t.Errorf("Args = %v, want 3 args", r.Args) + } + if r.Args[0] != 1.5 || r.Args[1] != 2.5 || r.Args[2] != 3.5 { + t.Errorf("Args = %v, want [1.5 2.5 3.5]", r.Args) + } + }) + + t.Run("for with []string", func(t *testing.T) { + r := exec(t, "SELECT @for(v range items) {#{v}, }", + map[string]any{"items": []string{"a", "b"}}) + if len(r.Args) != 2 { + t.Errorf("Args = %v, want 2 args", r.Args) + } + if r.Args[0] != "a" || r.Args[1] != "b" { + t.Errorf("Args = %v, want [a b]", r.Args) + } + }) +} + +// ---------- TestExpressionEdgeCases ---------- + +func TestExpressionEdgeCases(t *testing.T) { + t.Run("deeply nested parentheses", func(t *testing.T) { + r := exec(t, "@if(((a))) {yes}", map[string]any{"a": true}) + if !strings.Contains(r.SQL, "yes") { + t.Errorf("SQL = %q, want 'yes'", r.SQL) + } + }) + + t.Run("deep dot path", func(t *testing.T) { + r := exec(t, "@if(a.b.c.d) {deep}", map[string]any{ + "a": map[string]any{ + "b": map[string]any{ + "c": map[string]any{ + "d": true, + }, + }, + }, + }) + if !strings.Contains(r.SQL, "deep") { + t.Errorf("SQL = %q, want 'deep'", r.SQL) + } + }) + + t.Run("float literal in expression", func(t *testing.T) { + r := exec(t, "@if(x > 1.5) {big}", map[string]any{"x": 2.5}) + if !strings.Contains(r.SQL, "big") { + t.Errorf("SQL = %q, want 'big'", r.SQL) + } + + r2 := exec(t, "@if(x > 1.5) {big}", map[string]any{"x": 0.5}) + if strings.Contains(r2.SQL, "big") { + t.Errorf("SQL = %q, should not contain 'big'", r2.SQL) + } + }) + + t.Run("single-quoted string in expression", func(t *testing.T) { + r := exec(t, `@if(role == 'admin') {is admin}`, map[string]any{"role": "admin"}) + if !strings.Contains(r.SQL, "is admin") { + t.Errorf("SQL = %q, want 'is admin'", r.SQL) + } + }) + + t.Run("string escape sequences", func(t *testing.T) { + r := exec(t, `@if(x == "a\"b") {escaped}`, map[string]any{"x": `a"b`}) + if !strings.Contains(r.SQL, "escaped") { + t.Errorf("SQL = %q, want 'escaped'", r.SQL) + } + }) + + t.Run("OR short circuit", func(t *testing.T) { + // When left is true, right should not cause error even if undefined + r := exec(t, "@if(a != nil || b != nil) {either}", + map[string]any{"a": 1}) + if !strings.Contains(r.SQL, "either") { + t.Errorf("SQL = %q, want 'either'", r.SQL) + } + }) + + t.Run("negative number in expression", func(t *testing.T) { + // Expression parser does not support unary minus; use a variable instead + r := exec(t, "@if(x == neg) {negative one}", map[string]any{"x": -1, "neg": -1}) + if !strings.Contains(r.SQL, "negative one") { + t.Errorf("SQL = %q, want 'negative one'", r.SQL) + } + }) + + t.Run("unknown function returns error at execution", func(t *testing.T) { + // Functions are validated at execution time, not parse time + tpl, err := New().Parse("test", `@if(unknown(x)) {bad}`) + if err != nil { + t.Fatalf("parse should succeed (functions validated at exec): %v", err) + } + _, err = tpl.Execute(map[string]any{"x": 1}) + if err == nil { + t.Fatal("expected error for unknown function at execution, got nil") + } + }) +} + +// ---------- TestTrailingCommaEdgeCases ---------- + +func TestTrailingCommaEdgeCases(t *testing.T) { + t.Run("trailing comma with space and newline", func(t *testing.T) { + r := exec(t, "SELECT @for(v range items) {#{v}, }\n", + map[string]any{"items": []int{1, 2}}) + if strings.HasSuffix(r.SQL, ",") { + t.Errorf("SQL = %q, should not end with comma", r.SQL) + } + }) + + t.Run("no trailing comma when not present", func(t *testing.T) { + r := exec(t, "SELECT @for(v range items) {#{v} }", + map[string]any{"items": []int{1, 2}}) + if strings.HasSuffix(r.SQL, ",") { + t.Errorf("SQL = %q, should not end with comma", r.SQL) + } + }) +} + +// ---------- TestUnsafeRawErrorFields ---------- + +func TestUnsafeRawErrorFields(t *testing.T) { + t.Run("UnsafeRawError contains param and value", func(t *testing.T) { + policy := RawAllowlist{"col": {"name"}} + tpl, _ := New(WithRawPolicy(policy)).Parse("test", "SELECT ${col}") + _, err := tpl.Execute(map[string]any{"col": "DROP TABLE users"}) + if err == nil { + t.Fatal("expected error") + } + var ue *UnsafeRawError + if !errors.As(err, &ue) { + t.Fatalf("expected UnsafeRawError, got %T: %v", err, err) + } + if ue.Param != "col" { + t.Errorf("UnsafeRawError.Param = %q, want %q", ue.Param, "col") + } + if ue.Value != "DROP TABLE users" { + t.Errorf("UnsafeRawError.Value = %q, want %q", ue.Value, "DROP TABLE users") + } + if ue.Message == "" { + t.Error("UnsafeRawError.Message is empty") + } + }) +} + +// ---------- TestForWithStructSlice ---------- + +func TestForWithStructSlice(t *testing.T) { + type Item struct { + Name string + Price float64 + } + + t.Run("for with struct slice", func(t *testing.T) { + items := []Item{ + {Name: "apple", Price: 1.5}, + {Name: "banana", Price: 0.8}, + } + r := exec(t, "SELECT @for(i range items) {#{i.Name}, #{i.Price}, }", + map[string]any{"items": items}) + if len(r.Args) != 4 { + t.Fatalf("Args = %v, want 4 args", r.Args) + } + if r.Args[0] != "apple" || r.Args[1] != 1.5 { + t.Errorf("first item args wrong: %v", r.Args[:2]) + } + if r.Args[2] != "banana" || r.Args[3] != 0.8 { + t.Errorf("second item args wrong: %v", r.Args[2:]) + } + }) +} + +// ---------- TestExpressionWithMissingVar ---------- +// Non-strict mode: undefined expression variables should not error. + +func TestExpressionWithMissingVar(t *testing.T) { + t.Run("non-strict: undefined in condition evaluates to nil (falsy)", func(t *testing.T) { + r := exec(t, "@if(missing != nil) {has it}", map[string]any{}, + WithStrictMode(false)) + if strings.Contains(r.SQL, "has it") { + t.Errorf("SQL = %q, should not contain 'has it' when variable is missing", r.SQL) + } + }) +}