# lib/ex_sql/tokenizer.ex

defmodule ExSQL.Tokenizer do
  @moduledoc """
  Lexical analysis for SQL text.

  Mirrors SQLite's `tokenize.c`, which classifies bytes through a 256-entry
  lookup table and resolves keywords with a generated perfect hash. In Elixir
  the same job is done with binary pattern matching: each clause of `scan/3`
  corresponds to a character class, and keyword recognition is a single map
  lookup on the ASCII-downcased identifier.

  Tokens are `{type, value, line}` tuples, where `line` is the 1-based line on
  which the token starts:

    * `{:keyword, :select, 1}` — case-insensitive SQL keywords
    * `{:id, "users", 1}` — identifiers (bare, `"quoted"`, `[bracketed]`, or backticked)
    * `{:int, 42, 1}` / `{:float, 3.14, 1}` — numeric literals
    * `{:string, "abc", 1}` — single-quoted string literals (`''` escapes a quote)
    * `{:blob, <<0xAB>>, 1}` — `x'AB'` hex blob literals
    * `{:placeholder, {1, "?"}, 1}` — bind parameters (`?`, `?NNN`, `:name`,
      `@name`, `$name`) carrying their assigned index and raw source text
    * punctuation and operators such as `{:comma, ",", 1}`, `{:le, "<=", 1}`
  """

  # NOTE: the type names integer/text/real/blob are deliberately NOT keywords —
  # they tokenize as identifiers so a column's declared type keeps its original
  # case for PRAGMA table_info (SQLite stores the type exactly as written), and
  # so they can be used as identifiers. type_name/2 reads them as plain words.
  @keywords ~w(
    abort add all alter always analyze and as asc autoincrement begin between by case cast check
    collate column commit conflict constraint create cross current default delete desc distinct
    database do drop else end escape except exists explain fail false filter following foreign from full generated glob group having if
    ignore in index inner insert
    intersect into is join key left like limit match natural not nothing null offset on or
    order outer over partition plan pragma preceding primary query range recursive release rename replace right
    references regexp reindex returning rollback rows
    savepoint select
    row set table then to transaction true unbounded union unique update using values attach detach
    stored strict vacuum view virtual when where window with without
  )a

  # Lowercase keyword text => keyword atom. One map answers both "is this a
  # keyword?" and "which atom is it?" in a single lookup.
  @keyword_map Map.new(@keywords, &{Atom.to_string(&1), &1})

  @type token :: {atom(), term(), pos_integer()}

  @doc """
  Tokenizes `sql`, returning `{:ok, tokens}` or `{:error, reason}`.

  The token list does not include trailing end-of-input markers; the parser
  treats the empty list as end of input.
  """
  @spec tokenize(String.t()) :: {:ok, [token()]} | {:error, term()}
  def tokenize(sql) when is_binary(sql) do
    case scan(sql, 1, []) do
      {:ok, tokens} -> {:ok, number_placeholders(tokens)}
      error -> error
    end
  end

  # Assigns bind-parameter indexes in source order, per statement, following
  # sqlite3_bind_parameter_index semantics: `?` takes one more than the
  # largest index so far, `?NNN` takes exactly NNN, and a named parameter
  # reuses its previous index or takes the next free one. Placeholder tokens
  # come out as `{:placeholder, {index, raw_text}, line}`.
  defp number_placeholders(tokens) do
    # The map_reduce rebuilds the whole token list; skip it entirely when there
    # are no placeholders (all literal SQL — the common non-prepared case).
    if Enum.any?(tokens, &match?({:placeholder, _, _}, &1)) do
      do_number_placeholders(tokens)
    else
      tokens
    end
  end

  # State is `{highest_index_so_far, %{name => index}}`, reset at each `;` so
  # every statement in a multi-statement string numbers from 1.
  defp do_number_placeholders(tokens) do
    {numbered, _state} =
      Enum.map_reduce(tokens, {0, %{}}, fn
        {:semicolon, _, _} = token, _state ->
          {token, {0, %{}}}

        {:placeholder, raw, line}, {highest, names} ->
          {index, highest, names} = assign_index(raw, highest, names)
          {{:placeholder, {index, raw}, line}, {highest, names}}

        token, state ->
          {token, state}
      end)

    numbered
  end

  # Returns `{index, new_highest, new_names}` for one placeholder's raw text.
  # `?` takes one past the highest index seen so far.
  defp assign_index("?", highest, names), do: {highest + 1, highest + 1, names}

  # `?NNN` takes exactly NNN and raises the high-water mark if NNN is larger.
  defp assign_index("?" <> digits, highest, names) do
    n = String.to_integer(digits)
    {n, max(highest, n), names}
  end

  # `:name` / `@name` / `$name` (sigil included in the key, so `:a` and `@a`
  # are distinct) reuse an earlier index or take the next free one.
  defp assign_index(name, highest, names) do
    case Map.fetch(names, name) do
      {:ok, index} -> {index, highest, names}
      :error -> {highest + 1, highest + 1, Map.put(names, name, highest + 1)}
    end
  end

  defp scan(<<>>, _line, acc), do: {:ok, Enum.reverse(acc)}

  defp scan(<<"\n", rest::binary>>, line, acc), do: scan(rest, line + 1, acc)

  defp scan(<<c, rest::binary>>, line, acc) when c in [?\s, ?\t, ?\r, ?\f] do
    scan(rest, line, acc)
  end

  # -- comments ---------------------------------------------------------

  defp scan(<<"--", rest::binary>>, line, acc) do
    case String.split(rest, "\n", parts: 2) do
      [_comment, rest] -> scan(rest, line + 1, acc)
      [_comment] -> scan(<<>>, line, acc)
    end
  end

  defp scan(<<"/*", rest::binary>>, line, acc) do
    case String.split(rest, "*/", parts: 2) do
      [comment, rest] ->
        scan(rest, line + count_newlines(comment), acc)

      [_unterminated] ->
        {:error, {:unterminated_comment, line}}
    end
  end

  # -- multi-character operators (longest match first) ------------------

  defp scan(<<"<>", rest::binary>>, line, acc), do: scan(rest, line, [{:ne, "<>", line} | acc])
  defp scan(<<"!=", rest::binary>>, line, acc), do: scan(rest, line, [{:ne, "!=", line} | acc])
  defp scan(<<"<=", rest::binary>>, line, acc), do: scan(rest, line, [{:le, "<=", line} | acc])
  defp scan(<<">=", rest::binary>>, line, acc), do: scan(rest, line, [{:ge, ">=", line} | acc])
  defp scan(<<"==", rest::binary>>, line, acc), do: scan(rest, line, [{:eq, "==", line} | acc])

  defp scan(<<"||", rest::binary>>, line, acc),
    do: scan(rest, line, [{:concat, "||", line} | acc])

  defp scan(<<"<<", rest::binary>>, line, acc), do: scan(rest, line, [{:shl, "<<", line} | acc])
  defp scan(<<">>", rest::binary>>, line, acc), do: scan(rest, line, [{:shr, ">>", line} | acc])

  defp scan(<<"->>", rest::binary>>, line, acc),
    do: scan(rest, line, [{:arrow_text, "->>", line} | acc])

  defp scan(<<"->", rest::binary>>, line, acc),
    do: scan(rest, line, [{:arrow, "->", line} | acc])

  # -- single-character tokens ------------------------------------------

  defp scan(<<"=", rest::binary>>, line, acc), do: scan(rest, line, [{:eq, "=", line} | acc])
  defp scan(<<"<", rest::binary>>, line, acc), do: scan(rest, line, [{:lt, "<", line} | acc])
  defp scan(<<">", rest::binary>>, line, acc), do: scan(rest, line, [{:gt, ">", line} | acc])
  defp scan(<<"(", rest::binary>>, line, acc), do: scan(rest, line, [{:lparen, "(", line} | acc])
  defp scan(<<")", rest::binary>>, line, acc), do: scan(rest, line, [{:rparen, ")", line} | acc])
  defp scan(<<",", rest::binary>>, line, acc), do: scan(rest, line, [{:comma, ",", line} | acc])

  defp scan(<<";", rest::binary>>, line, acc),
    do: scan(rest, line, [{:semicolon, ";", line} | acc])

  defp scan(<<"*", rest::binary>>, line, acc), do: scan(rest, line, [{:star, "*", line} | acc])
  defp scan(<<"+", rest::binary>>, line, acc), do: scan(rest, line, [{:plus, "+", line} | acc])
  defp scan(<<"-", rest::binary>>, line, acc), do: scan(rest, line, [{:minus, "-", line} | acc])
  defp scan(<<"/", rest::binary>>, line, acc), do: scan(rest, line, [{:slash, "/", line} | acc])
  defp scan(<<"%", rest::binary>>, line, acc), do: scan(rest, line, [{:percent, "%", line} | acc])
  defp scan(<<"&", rest::binary>>, line, acc), do: scan(rest, line, [{:bitand, "&", line} | acc])
  defp scan(<<"|", rest::binary>>, line, acc), do: scan(rest, line, [{:bitor, "|", line} | acc])
  defp scan(<<"~", rest::binary>>, line, acc), do: scan(rest, line, [{:tilde, "~", line} | acc])

  # A leading-dot number such as `.5`; a dot not followed by a digit is `:dot`.
  defp scan(<<".", c, _::binary>> = input, line, acc) when c in ?0..?9 do
    <<_dot, rest::binary>> = input
    number(rest, line, acc, "0.")
  end

  defp scan(<<".", rest::binary>>, line, acc), do: scan(rest, line, [{:dot, ".", line} | acc])

  # -- bind parameters: ?, ?NNN, :name, @name, $name ---------------------

  defp scan(<<"?", rest::binary>>, line, acc) do
    {digits, rest} = take_while(rest, &(&1 in ?0..?9))
    scan(rest, line, [{:placeholder, "?" <> digits, line} | acc])
  end

  defp scan(<<sigil, rest::binary>>, line, acc) when sigil in [?:, ?@, ?$] do
    case take_while(rest, &identifier_char?/1) do
      {"", _rest} -> {:error, {:unexpected_character, <<sigil>>, line}}
      {name, rest} -> scan(rest, line, [{:placeholder, <<sigil, name::binary>>, line} | acc])
    end
  end

  # -- literals ----------------------------------------------------------

  # String literals may span lines; the token keeps its starting line and the
  # embedded newlines advance the counter for the tokens that follow.
  defp scan(<<?', rest::binary>>, line, acc) do
    case quoted(rest, ?', "") do
      {:ok, value, rest} ->
        scan(rest, line + count_newlines(value), [{:string, value, line} | acc])

      :error ->
        {:error, {:unterminated_string, line}}
    end
  end

  defp scan(<<x, ?', rest::binary>>, line, acc) when x in [?x, ?X] do
    case quoted(rest, ?', "") do
      {:ok, hex, rest} ->
        case Base.decode16(hex, case: :mixed) do
          {:ok, blob} -> scan(rest, line, [{:blob, blob, line} | acc])
          :error -> {:error, {:malformed_blob_literal, hex, line}}
        end

      :error ->
        {:error, {:unterminated_string, line}}
    end
  end

  # Hexadecimal integer literal `0x…`, read as a 64-bit two's-complement value:
  # `0xffffffffffffffff` is -1. Literals wider than 64 bits wrap modulo 2^64
  # here, whereas stock SQLite rejects more than 16 significant hex digits with
  # "hex literal too big". Requires at least one hex digit; otherwise it falls
  # through to the decimal number scan (`0` then identifier).
  defp scan(<<"0", x, c, _::binary>> = input, line, acc)
       when x in [?x, ?X] and (c in ?0..?9 or c in ?a..?f or c in ?A..?F) do
    <<_::binary-size(2), rest::binary>> = input
    {hex, rest} = take_while(rest, &hex_digit?/1)

    wrapped = rem(String.to_integer(hex, 16), 0x1_0000_0000_0000_0000)

    value =
      if wrapped >= 0x8000_0000_0000_0000, do: wrapped - 0x1_0000_0000_0000_0000, else: wrapped

    scan(rest, line, [{:int, value, line} | acc])
  end

  defp scan(<<c, _::binary>> = input, line, acc) when c in ?0..?9 do
    number(input, line, acc, "")
  end

  # -- quoted identifiers -------------------------------------------------
  # Like string literals, these may contain newlines; count them so later
  # tokens report the correct line.

  defp scan(<<?", rest::binary>>, line, acc) do
    case quoted(rest, ?", "") do
      {:ok, name, rest} -> scan(rest, line + count_newlines(name), [{:id, name, line} | acc])
      :error -> {:error, {:unterminated_identifier, line}}
    end
  end

  defp scan(<<?`, rest::binary>>, line, acc) do
    case quoted(rest, ?`, "") do
      {:ok, name, rest} -> scan(rest, line + count_newlines(name), [{:id, name, line} | acc])
      :error -> {:error, {:unterminated_identifier, line}}
    end
  end

  defp scan(<<?[, rest::binary>>, line, acc) do
    case String.split(rest, "]", parts: 2) do
      [name, rest] -> scan(rest, line + count_newlines(name), [{:id, name, line} | acc])
      [_] -> {:error, {:unterminated_identifier, line}}
    end
  end

  # -- identifiers and keywords -------------------------------------------

  defp scan(<<c, _::binary>> = input, line, acc)
       when c in ?a..?z or c in ?A..?Z or c == ?_ do
    {word, rest} = take_while(input, &identifier_char?/1)

    # Keywords are ASCII; fold ASCII case only (and skip allocating when the
    # word is already lowercase) instead of the Unicode-aware String.downcase.
    # Identifiers keep their original spelling.
    token =
      case Map.fetch(@keyword_map, ascii_downcase(word)) do
        {:ok, keyword} -> {:keyword, keyword, line}
        :error -> {:id, word, line}
      end

    scan(rest, line, [token | acc])
  end

  defp scan(<<c, _::binary>>, line, _acc) do
    {:error, {:unexpected_character, <<c>>, line}}
  end

  # -- helpers -------------------------------------------------------------

  # Scans a decimal literal. `prefix` is "" for a plain number or "0." when the
  # caller already consumed a leading dot (`.5`). Emits `:float` when a dot or
  # exponent is present, `:int` otherwise.
  defp number(input, line, acc, prefix) do
    # Track whether a `.` was seen while scanning the digits, instead of
    # re-scanning the (potentially long) number with `String.contains?/2` per
    # literal — that re-scan was a measurable cost on numeric-heavy workloads.
    {count, digits_dot?} = take_number_count(input, 0, false)
    <<digits::binary-size(^count), rest::binary>> = input
    {exponent, rest} = take_exponent(rest)
    text = prefix <> digits
    has_dot = digits_dot? or prefix_dot?(prefix)

    cond do
      exponent != "" ->
        # Float.parse needs digits on both sides of the dot before `e`.
        mantissa = if String.ends_with?(text, "."), do: text <> "0", else: text
        mantissa = if has_dot, do: mantissa, else: mantissa <> ".0"

        case Float.parse(mantissa <> exponent) do
          {f, ""} -> scan(rest, line, [{:float, f, line} | acc])
          _ -> {:error, {:malformed_number, text <> exponent, line}}
        end

      has_dot ->
        case Float.parse(if(String.ends_with?(text, "."), do: text <> "0", else: text)) do
          {f, ""} -> scan(rest, line, [{:float, f, line} | acc])
          _ -> {:error, {:malformed_number, text, line}}
        end

      true ->
        scan(rest, line, [{:int, String.to_integer(text), line} | acc])
    end
  end

  # Counts leading digit/dot bytes; returns `{byte_count, saw_dot?}`.
  defp take_number_count(<<c, rest::binary>>, n, dot?) when c >= ?0 and c <= ?9,
    do: take_number_count(rest, n + 1, dot?)

  defp take_number_count(<<?., rest::binary>>, n, _dot?),
    do: take_number_count(rest, n + 1, true)

  defp take_number_count(_binary, n, dot?), do: {n, dot?}

  # `prefix` is "" (a plain number) or "0." (a leading-dot number like `.5`).
  defp prefix_dot?("0."), do: true
  defp prefix_dot?(""), do: false

  # Scientific notation: `e`/`E` with an optional sign and at least one
  # digit. Anything else leaves the input untouched (`1e` is `1` then `e`).
  defp take_exponent(<<e, rest::binary>>) when e in [?e, ?E] do
    {sign, digits_input} =
      case rest do
        <<s, more::binary>> when s in [?+, ?-] -> {<<s>>, more}
        rest -> {"", rest}
      end

    case take_while(digits_input, &(&1 in ?0..?9)) do
      {"", _} -> {"", <<e, rest::binary>>}
      {digits, rest} -> {<<?e, sign::binary, digits::binary>>, rest}
    end
  end

  defp take_exponent(input), do: {"", input}

  # Counts `\n` bytes. Works byte-wise so text that is not valid UTF-8 (which
  # comments and quoted literals may legally contain) cannot crash the scan.
  defp count_newlines(text), do: text |> :binary.matches("\n") |> length()

  # Scans up to a closing `quote_char`; a doubled quote char is an escape.
  # Accumulate with `<<acc::binary, _>>` (the BEAM's append-buffer form) rather
  # than `acc <> <<_>>`, to avoid a fresh copy per character.
  defp quoted(<<q, q, rest::binary>>, q, acc), do: quoted(rest, q, <<acc::binary, q>>)
  defp quoted(<<q, rest::binary>>, q, acc), do: {:ok, acc, rest}
  defp quoted(<<c, rest::binary>>, q, acc), do: quoted(rest, q, <<acc::binary, c>>)
  defp quoted(<<>>, _q, _acc), do: :error

  # Count the matching leading bytes, then slice the prefix once — instead of
  # rebuilding the accumulator binary one byte at a time (`acc <> <<c>>`), which
  # made tokenizing identifiers/numbers a per-character allocation hotspot.
  defp take_while(binary, fun) do
    count = take_count(binary, fun, 0)
    <<taken::binary-size(^count), rest::binary>> = binary
    {taken, rest}
  end

  defp take_count(<<c, rest::binary>>, fun, n) do
    if fun.(c), do: take_count(rest, fun, n + 1), else: n
  end

  defp take_count(<<>>, _fun, n), do: n

  defp identifier_char?(c), do: c in ?a..?z or c in ?A..?Z or c in ?0..?9 or c == ?_

  defp ascii_downcase(word) do
    if any_upper_ascii?(word), do: fold_ascii(word), else: word
  end

  defp any_upper_ascii?(<<c, _rest::binary>>) when c >= ?A and c <= ?Z, do: true
  defp any_upper_ascii?(<<_c, rest::binary>>), do: any_upper_ascii?(rest)
  defp any_upper_ascii?(<<>>), do: false

  defp fold_ascii(word), do: for(<<c <- word>>, into: "", do: <<fold_byte(c)>>)

  defp fold_byte(c) when c >= ?A and c <= ?Z, do: c + 32
  defp fold_byte(c), do: c

  defp hex_digit?(c), do: c in ?0..?9 or c in ?a..?f or c in ?A..?F
end