defmodule ExSQL.Tokenizer do
  @moduledoc """
  Lexical analysis for SQL text.

  Mirrors SQLite's `tokenize.c`, which classifies bytes through a 256-entry
  lookup table and resolves keywords with a generated perfect hash. In Elixir
  the same job is done with binary pattern matching: each clause of `scan/3`
  corresponds to a character class, and keyword recognition is a single map
  lookup on the ASCII-downcased identifier.

  Tokens are `{type, value, line}` tuples:

    * `{:keyword, :select, 1}` — case-insensitive SQL keywords
    * `{:id, "users", 1}` — identifiers (bare, `"quoted"`, `[bracketed]`, or backticked);
      bare identifiers may contain non-ASCII bytes, as in SQLite
    * `{:int, 42, 1}` / `{:float, 3.14, 1}` — numeric literals
    * `{:string, "abc", 1}` — single-quoted string literals (`''` escapes a quote)
    * `{:blob, <<0xAB>>, 1}` — `x'AB'` hex blob literals
    * `{:placeholder, {1, "?"}, 1}` — bind parameters with their assigned index
    * punctuation and operators such as `{:comma, ",", 1}`, `{:le, "<=", 1}`

  A token's line is the line on which it starts. Strings and quoted
  identifiers that span several lines advance the line counter for the
  tokens that follow them.
  """

  # NOTE: the type names integer/text/real/blob are deliberately NOT keywords —
  # they tokenize as identifiers so a column's declared type keeps its original
  # case for PRAGMA table_info (SQLite stores the type exactly as written), and
  # so they can be used as identifiers. type_name/2 reads them as plain words.
  @keywords ~w(
    abort add all alter always analyze and as asc attach autoincrement
    begin between by
    case cast check collate column commit conflict constraint create cross current
    database default delete desc detach distinct do drop
    else end escape except exists explain
    fail false filter following foreign from full
    generated glob group
    having
    if ignore in index inner insert intersect into is
    join
    key
    left like limit
    match
    natural not nothing null
    offset on or order outer over
    partition plan pragma preceding primary
    query
    range recursive references regexp reindex release rename replace returning right
    rollback row rows
    savepoint select set stored strict
    table then to transaction true
    unbounded union unique update using
    vacuum values view virtual
    when where window with without
  )a

  # Downcased keyword text => keyword atom. One `Map.fetch/2` both tests
  # membership and yields the atom.
  @keyword_map Map.new(@keywords, &{Atom.to_string(&1), &1})

  @type token :: {atom(), term(), pos_integer()}

  # Like SQLite's IdChar class, every byte >= 0x80 counts as an identifier
  # character, so UTF-8 identifiers such as `café` tokenize as one word.
  defguardp is_identifier_start(c) when c in ?a..?z or c in ?A..?Z or c == ?_ or c >= 0x80
  defguardp is_identifier_char(c) when is_identifier_start(c) or c in ?0..?9

  @doc """
  Tokenizes `sql`, returning `{:ok, tokens}` or `{:error, reason}`.

  The token list does not include trailing end-of-input markers; the parser
  treats the empty list as end of input.
  """
  @spec tokenize(String.t()) :: {:ok, [token()]} | {:error, term()}
  def tokenize(sql) when is_binary(sql) do
    with {:ok, tokens} <- scan(sql, 1, []) do
      {:ok, number_placeholders(tokens)}
    end
  end

  # Assigns bind-parameter indexes in source order, per statement, following
  # sqlite3_bind_parameter_index semantics: `?` takes one more than the
  # largest index so far, `?NNN` takes exactly NNN, and a named parameter
  # reuses its previous index or takes the next free one. Placeholder tokens
  # come out as `{:placeholder, {index, raw_text}, line}`.
  defp number_placeholders(tokens) do
    # The map_reduce rebuilds the whole token list; skip it entirely when there
    # are no placeholders (all literal SQL — the common non-prepared case).
    if Enum.any?(tokens, &match?({:placeholder, _, _}, &1)) do
      {numbered, _state} = Enum.map_reduce(tokens, {0, %{}}, &number_token/2)
      numbered
    else
      tokens
    end
  end

  # A `;` ends the statement, so numbering restarts for the next one.
  defp number_token({:semicolon, _, _} = token, _state), do: {token, {0, %{}}}

  defp number_token({:placeholder, raw, line}, state) do
    {index, state} = placeholder_index(raw, state)
    {{:placeholder, {index, raw}, line}, state}
  end

  defp number_token(token, state), do: {token, state}

  # `state` is `{largest_index_so_far, %{raw_name => index}}`.
  defp placeholder_index("?", {top, names}), do: {top + 1, {top + 1, names}}

  defp placeholder_index("?" <> digits, {top, names}) do
    n = String.to_integer(digits)
    {n, {max(top, n), names}}
  end

  defp placeholder_index(name, {top, names} = state) do
    case Map.fetch(names, name) do
      {:ok, index} -> {index, state}
      :error -> {top + 1, {top + 1, Map.put(names, name, top + 1)}}
    end
  end

  defp scan(<<>>, _line, acc), do: {:ok, Enum.reverse(acc)}
  defp scan(<<"\n", rest::binary>>, line, acc), do: scan(rest, line + 1, acc)

  defp scan(<<c, rest::binary>>, line, acc) when c in [?\s, ?\t, ?\r, ?\f] do
    scan(rest, line, acc)
  end

  # -- comments ---------------------------------------------------------
  defp scan(<<"--", rest::binary>>, line, acc) do
    case String.split(rest, "\n", parts: 2) do
      [_comment, rest] -> scan(rest, line + 1, acc)
      [_comment] -> scan(<<>>, line, acc)
    end
  end

  defp scan(<<"/*", rest::binary>>, line, acc) do
    case String.split(rest, "*/", parts: 2) do
      [comment, rest] -> scan(rest, line + count_newlines(comment), acc)
      [_unterminated] -> {:error, {:unterminated_comment, line}}
    end
  end

  # -- multi-character operators (longest match first) ------------------
  defp scan(<<"<>", rest::binary>>, line, acc), do: scan(rest, line, [{:ne, "<>", line} | acc])
  defp scan(<<"!=", rest::binary>>, line, acc), do: scan(rest, line, [{:ne, "!=", line} | acc])
  defp scan(<<"<=", rest::binary>>, line, acc), do: scan(rest, line, [{:le, "<=", line} | acc])
  defp scan(<<">=", rest::binary>>, line, acc), do: scan(rest, line, [{:ge, ">=", line} | acc])
  defp scan(<<"==", rest::binary>>, line, acc), do: scan(rest, line, [{:eq, "==", line} | acc])

  defp scan(<<"||", rest::binary>>, line, acc),
    do: scan(rest, line, [{:concat, "||", line} | acc])

  defp scan(<<"<<", rest::binary>>, line, acc), do: scan(rest, line, [{:shl, "<<", line} | acc])
  defp scan(<<">>", rest::binary>>, line, acc), do: scan(rest, line, [{:shr, ">>", line} | acc])

  defp scan(<<"->>", rest::binary>>, line, acc),
    do: scan(rest, line, [{:arrow_text, "->>", line} | acc])

  defp scan(<<"->", rest::binary>>, line, acc),
    do: scan(rest, line, [{:arrow, "->", line} | acc])

  # -- single-character tokens ------------------------------------------
  defp scan(<<"=", rest::binary>>, line, acc), do: scan(rest, line, [{:eq, "=", line} | acc])
  defp scan(<<"<", rest::binary>>, line, acc), do: scan(rest, line, [{:lt, "<", line} | acc])
  defp scan(<<">", rest::binary>>, line, acc), do: scan(rest, line, [{:gt, ">", line} | acc])
  defp scan(<<"(", rest::binary>>, line, acc), do: scan(rest, line, [{:lparen, "(", line} | acc])
  defp scan(<<")", rest::binary>>, line, acc), do: scan(rest, line, [{:rparen, ")", line} | acc])
  defp scan(<<",", rest::binary>>, line, acc), do: scan(rest, line, [{:comma, ",", line} | acc])

  defp scan(<<";", rest::binary>>, line, acc),
    do: scan(rest, line, [{:semicolon, ";", line} | acc])

  defp scan(<<"*", rest::binary>>, line, acc), do: scan(rest, line, [{:star, "*", line} | acc])
  defp scan(<<"+", rest::binary>>, line, acc), do: scan(rest, line, [{:plus, "+", line} | acc])
  defp scan(<<"-", rest::binary>>, line, acc), do: scan(rest, line, [{:minus, "-", line} | acc])
  defp scan(<<"/", rest::binary>>, line, acc), do: scan(rest, line, [{:slash, "/", line} | acc])
  defp scan(<<"%", rest::binary>>, line, acc), do: scan(rest, line, [{:percent, "%", line} | acc])
  defp scan(<<"&", rest::binary>>, line, acc), do: scan(rest, line, [{:bitand, "&", line} | acc])
  defp scan(<<"|", rest::binary>>, line, acc), do: scan(rest, line, [{:bitor, "|", line} | acc])
  defp scan(<<"~", rest::binary>>, line, acc), do: scan(rest, line, [{:tilde, "~", line} | acc])

  # A leading-dot number such as `.5` is scanned as if written `0.5`.
  defp scan(<<?., c, _::binary>> = input, line, acc) when c in ?0..?9 do
    <<?., rest::binary>> = input
    number(rest, line, acc, "0.")
  end

  defp scan(<<".", rest::binary>>, line, acc), do: scan(rest, line, [{:dot, ".", line} | acc])

  # -- bind parameters: ?, ?NNN, :name, @name, $name ---------------------
  defp scan(<<"?", rest::binary>>, line, acc) do
    {digits, rest} = take_while(rest, &(&1 in ?0..?9))
    scan(rest, line, [{:placeholder, "?" <> digits, line} | acc])
  end

  defp scan(<<sigil, rest::binary>>, line, acc) when sigil in [?:, ?@, ?$] do
    case take_while(rest, &identifier_char?/1) do
      {"", _rest} -> {:error, {:unexpected_character, <<sigil>>, line}}
      {name, rest} -> scan(rest, line, [{:placeholder, <<sigil, name::binary>>, line} | acc])
    end
  end

  # -- literals ----------------------------------------------------------
  # A string may span lines; the token keeps its starting line and the
  # counter advances past the newlines it contains.
  defp scan(<<?', rest::binary>>, line, acc) do
    case quoted(rest, ?', "") do
      {:ok, value, rest} ->
        scan(rest, line + count_newlines(value), [{:string, value, line} | acc])

      :error ->
        {:error, {:unterminated_string, line}}
    end
  end

  defp scan(<<x, ?', rest::binary>>, line, acc) when x in [?x, ?X] do
    case quoted(rest, ?', "") do
      {:ok, hex, rest} ->
        case Base.decode16(hex, case: :mixed) do
          {:ok, blob} -> scan(rest, line, [{:blob, blob, line} | acc])
          :error -> {:error, {:malformed_blob_literal, hex, line}}
        end

      :error ->
        {:error, {:unterminated_string, line}}
    end
  end

  # Hexadecimal integer literal `0x…`, read as a 64-bit two's-complement value
  # (`0xffffffffffffffff` is -1). NOTE: SQLite itself rejects literals with
  # more than 16 significant hex digits ("hex literal too big"); this
  # tokenizer instead keeps only the low 64 bits. Requires at least one hex
  # digit; otherwise it falls through to the decimal number scan (`0` then an
  # identifier).
  defp scan(<<"0", x, c, _::binary>> = input, line, acc)
       when x in [?x, ?X] and (c in ?0..?9 or c in ?a..?f or c in ?A..?F) do
    <<_::binary-size(2), rest::binary>> = input
    {hex, rest} = take_while(rest, &hex_digit?/1)
    # Building a 64-bit segment truncates to the low 64 bits; reading it back
    # as signed yields the two's-complement value.
    <<value::signed-64>> = <<String.to_integer(hex, 16)::64>>
    scan(rest, line, [{:int, value, line} | acc])
  end

  defp scan(<<c, _::binary>> = input, line, acc) when c in ?0..?9 do
    number(input, line, acc, "")
  end

  # -- quoted identifiers -------------------------------------------------
  defp scan(<<q, rest::binary>>, line, acc) when q in [?", ?`] do
    case quoted(rest, q, "") do
      {:ok, name, rest} -> scan(rest, line + count_newlines(name), [{:id, name, line} | acc])
      :error -> {:error, {:unterminated_identifier, line}}
    end
  end

  defp scan(<<?[, rest::binary>>, line, acc) do
    case String.split(rest, "]", parts: 2) do
      [name, rest] -> scan(rest, line + count_newlines(name), [{:id, name, line} | acc])
      [_] -> {:error, {:unterminated_identifier, line}}
    end
  end

  # -- identifiers and keywords -------------------------------------------
  defp scan(<<c, _::binary>> = input, line, acc) when is_identifier_start(c) do
    {word, rest} = take_while(input, &identifier_char?/1)

    # Keywords are ASCII, so only ASCII case is folded. This avoids the
    # Unicode-aware String.downcase on every identifier token.
    token =
      case Map.fetch(@keyword_map, ascii_downcase(word)) do
        {:ok, keyword} -> {:keyword, keyword, line}
        :error -> {:id, word, line}
      end

    scan(rest, line, [token | acc])
  end

  defp scan(<<c, _::binary>>, line, _acc) do
    {:error, {:unexpected_character, <<c>>, line}}
  end

  # -- helpers -------------------------------------------------------------

  # Scans a numeric literal. `prefix` is "" for a plain number or "0." for a
  # leading-dot number such as `.5`.
  defp number(input, line, acc, prefix) do
    # Track whether a `.` was seen while scanning the digits, instead of
    # re-scanning the (potentially long) number with `String.contains?/2`.
    {count, digits_dot?} = take_number_count(input, 0, false)
    <<digits::binary-size(^count), rest::binary>> = input
    {exponent, rest} = take_exponent(rest)
    text = prefix <> digits
    has_dot = digits_dot? or prefix == "0."

    cond do
      exponent != "" ->
        mantissa = if has_dot, do: close_fraction(text), else: text <> ".0"
        float_token(mantissa <> exponent, text <> exponent, rest, line, acc)

      has_dot ->
        float_token(close_fraction(text), text, rest, line, acc)

      true ->
        scan(rest, line, [{:int, String.to_integer(text), line} | acc])
    end
  end

  # `1.` gets a zero fraction so that Float.parse/1 consumes the entire text.
  defp close_fraction(text) do
    if String.ends_with?(text, "."), do: text <> "0", else: text
  end

  # Parses `parsable` as a float token. `source` is the literal as written and
  # is what a `:malformed_number` error reports.
  defp float_token(parsable, source, rest, line, acc) do
    case Float.parse(parsable) do
      {f, ""} -> scan(rest, line, [{:float, f, line} | acc])
      _ -> {:error, {:malformed_number, source, line}}
    end
  end

  defp take_number_count(<<c, rest::binary>>, n, dot?) when c >= ?0 and c <= ?9,
    do: take_number_count(rest, n + 1, dot?)

  defp take_number_count(<<?., rest::binary>>, n, _dot?),
    do: take_number_count(rest, n + 1, true)

  defp take_number_count(_binary, n, dot?), do: {n, dot?}

  # Scientific notation: `e`/`E` with an optional sign and at least one
  # digit. Anything else leaves the input untouched (`1e` is `1` then `e`).
  defp take_exponent(<<e, rest::binary>>) when e in [?e, ?E] do
    {sign, digits_input} =
      case rest do
        <<s, more::binary>> when s in [?+, ?-] -> {<<s>>, more}
        rest -> {"", rest}
      end

    case take_while(digits_input, &(&1 in ?0..?9)) do
      {"", _} -> {"", <<e, rest::binary>>}
      {digits, rest} -> {<<?e, sign::binary, digits::binary>>, rest}
    end
  end

  defp take_exponent(input), do: {"", input}

  # Counts "\n" bytes. Matching bytes rather than UTF-8 codepoints cannot fail
  # on invalid UTF-8, and a "\n" byte never occurs inside a multi-byte UTF-8
  # sequence, so the count is the same for valid text.
  defp count_newlines(text), do: length(:binary.matches(text, "\n"))

  # Scans up to a closing `quote_char`; a doubled quote char is an escape.
  # Accumulate with `<<acc::binary, _>>` (the BEAM's append-buffer form) rather
  # than `acc <> <<_>>`, to avoid a fresh copy per character.
  defp quoted(<<q, q, rest::binary>>, q, acc), do: quoted(rest, q, <<acc::binary, q>>)
  defp quoted(<<q, rest::binary>>, q, acc), do: {:ok, acc, rest}
  defp quoted(<<c, rest::binary>>, q, acc), do: quoted(rest, q, <<acc::binary, c>>)
  defp quoted(<<>>, _q, _acc), do: :error

  # Count the matching leading bytes, then slice the prefix once. This avoids
  # rebuilding an accumulator binary one byte at a time (`acc <> <<c>>`).
  defp take_while(binary, fun) do
    count = take_count(binary, fun, 0)
    <<taken::binary-size(^count), rest::binary>> = binary
    {taken, rest}
  end

  defp take_count(<<c, rest::binary>>, fun, n) do
    if fun.(c), do: take_count(rest, fun, n + 1), else: n
  end

  defp take_count(<<>>, _fun, n), do: n

  defp identifier_char?(c), do: is_identifier_char(c)

  defp ascii_downcase(word) do
    if any_upper_ascii?(word), do: fold_ascii(word), else: word
  end

  defp any_upper_ascii?(<<c, _rest::binary>>) when c >= ?A and c <= ?Z, do: true
  defp any_upper_ascii?(<<_c, rest::binary>>), do: any_upper_ascii?(rest)
  defp any_upper_ascii?(<<>>), do: false

  defp fold_ascii(word), do: for(<<c <- word>>, into: "", do: <<fold_byte(c)>>)

  defp fold_byte(c) when c >= ?A and c <= ?Z, do: c + 32
  defp fold_byte(c), do: c

  defp hex_digit?(c), do: c in ?0..?9 or c in ?a..?f or c in ?A..?F
end