Skip to main content

lib/ex_sql/sql_logic_test.ex

defmodule ExSQL.SqlLogicTest do
  @moduledoc """
  A runner for SQLite's sqllogictest conformance corpus
  (https://sqlite.org/sqllogictest) — the cross-engine suite of millions of
  generated queries with stored results.

  Each `.test` file is a sequence of records:

      statement ok          -- DDL/DML that must succeed
      statement error       -- a statement that must fail
      query <types> <sort>  -- a query whose flattened, formatted, sorted
      SELECT ...               values must match the stored lines or the
      ----                     stored `N values hashing to <md5>` digest
      <expected values>

  `skipif <engine>` / `onlyif <engine>` conditions are evaluated with this
  runner acting as the engine named "sqlite", since ExSQL implements SQLite
  semantics. Anything after the engine name on a condition line (the corpus
  often appends `# comment` text) is ignored.

  Statement failures abort the rest of the file (the schema has diverged)
  and set `aborted: true`. Query failures are counted and execution
  continues. A `halt` directive stops the file early without marking it
  aborted.

      ExSQL.SqlLogicTest.run_file("path/to/select1.test")
      #=> %{ok: 990, fail: 10, skip: 0, aborted: false, failures: [...]}

  Pass `max_records: n` to run only the first `n` parsed records in a large
  corpus file.
  """

  alias ExSQL.{Database, Executor}

  @engine "sqlite"
  @max_failure_examples 10

  # SQLite saturates REAL->INTEGER and TEXT->INTEGER conversions at these bounds.
  @int64_min -9_223_372_036_854_775_808
  @int64_max 9_223_372_036_854_775_807

  # ok:       records that passed
  # fail:     records that failed (all of them, not only the recorded examples)
  # skip:     records excluded by skipif/onlyif conditions
  # aborted:  true when a statement failure/crash stopped the file early
  # failures: at most @max_failure_examples example failures, in file order
  @type stats :: %{
          ok: non_neg_integer(),
          fail: non_neg_integer(),
          skip: non_neg_integer(),
          aborted: boolean(),
          failures: [map()]
        }

  @doc "Runs one corpus file against a fresh database."
  @spec run_file(Path.t(), keyword()) :: stats()
  def run_file(path, opts \\ []) do
    path |> File.read!() |> run_string(path, opts)
  end

  @doc """
  Runs corpus content (for tests and tools) against a fresh database.

  `name` identifies the content (for example the source path). `opts`
  accepts `max_records: n` to run only the first `n` parsed records.
  """
  @spec run_string(String.t(), String.t(), keyword()) :: stats()
  def run_string(content, name \\ "(inline)", opts \\ []) do
    records =
      content
      |> parse()
      |> limit_records(opts[:max_records])

    state = %{db: Database.new(), name: name}
    stats = %{ok: 0, fail: 0, skip: 0, aborted: false, failures: []}

    {_state, stats} =
      Enum.reduce_while(records, {state, stats}, fn record, {state, stats} ->
        case execute_record(record, state, stats) do
          {:cont, state, stats} ->
            {:cont, {state, stats}}

          # An explicit `halt` directive: a deliberate stop, not an abort.
          {:stop, state, stats} ->
            {:halt, {state, stats}}

          # A statement failure or crash: the rest of the file is untrustworthy.
          {:abort, state, stats} ->
            {:halt, {state, %{stats | aborted: true}}}
        end
      end)

    %{stats | failures: Enum.reverse(stats.failures)}
  end

  defp limit_records(records, nil), do: records
  defp limit_records(records, n) when is_integer(n) and n >= 0, do: Enum.take(records, n)

  # -- record execution ------------------------------------------------------------

  # Every clause returns {:cont | :stop | :abort, state, stats}.
  defp execute_record({:halt, _line}, state, stats), do: {:stop, state, stats}

  defp execute_record({:skip, _reason}, state, stats),
    do: {:cont, state, %{stats | skip: stats.skip + 1}}

  defp execute_record({:statement, expect, sql, line}, state, stats) do
    case Executor.run(state.db, sql) do
      {:ok, _results, db} when expect == :ok ->
        {:cont, %{state | db: db}, %{stats | ok: stats.ok + 1}}

      {:error, _error, db} when expect == :error ->
        {:cont, %{state | db: db}, %{stats | ok: stats.ok + 1}}

      {:ok, _results, db} ->
        stats = record_failure(stats, line, sql, "expected the statement to fail")
        {:cont, %{state | db: db}, stats}

      {:error, error, db} ->
        # The schema has diverged; the rest of the file cannot be trusted.
        stats = record_failure(stats, line, sql, "statement failed: #{error.message}")
        {:abort, %{state | db: db}, stats}
    end
  rescue
    e -> {:abort, state, record_failure(stats, line, sql, "crashed: #{Exception.message(e)}")}
  end

  defp execute_record({:query, types, sort_mode, expected, sql, line}, state, stats) do
    case Executor.run(state.db, sql) do
      {:error, error, _db} ->
        {:cont, state, record_failure(stats, line, sql, "query failed: #{error.message}")}

      {:ok, results, db} ->
        state = %{state | db: db}

        case check_query(results, types, sort_mode, expected) do
          :ok ->
            {:cont, state, %{stats | ok: stats.ok + 1}}

          {:mismatch, detail} ->
            {:cont, state, record_failure(stats, line, sql, detail)}
        end
    end
  rescue
    # A crashing query does not invalidate the schema, so keep going with the
    # database as it was before the query.
    e -> {:cont, state, record_failure(stats, line, sql, "crashed: #{Exception.message(e)}")}
  end

  # Formats the last result set of a query according to `types` (one type
  # character per column), sorts it per `sort_mode`, and compares it with the
  # expected values/hash. A row whose width differs from the type string is a
  # mismatch; zipping would otherwise silently drop or ignore columns.
  defp check_query(results, types, sort_mode, expected) do
    ncols = String.length(types)

    case List.last(results) do
      %{rows: rows} ->
        case Enum.find(rows, &(Enum.count(&1) != ncols)) do
          nil ->
            rows
            |> result_values(types)
            |> apply_sort(sort_mode, ncols)
            |> compare(expected)

          row ->
            {:mismatch, "expected #{ncols} columns, got #{Enum.count(row)}"}
        end

      _no_result_set ->
        {:mismatch, "query produced no result set"}
    end
  end

  # Counts a failure; only the first @max_failure_examples are kept as examples.
  defp record_failure(stats, line, sql, detail) do
    failures =
      if length(stats.failures) < @max_failure_examples do
        [%{line: line, sql: String.slice(sql, 0, 200), detail: detail} | stats.failures]
      else
        stats.failures
      end

    %{stats | fail: stats.fail + 1, failures: failures}
  end

  # -- result formatting -------------------------------------------------------------

  # Flattens rows into a single list of formatted values, row by row.
  defp result_values(rows, types) do
    type_chars = String.to_charlist(types)

    for row <- rows, {value, type} <- Enum.zip(row, type_chars) do
      format_value(value, type)
    end
  end

  defp format_value(nil, _type), do: "NULL"

  defp format_value(value, ?I) do
    case value do
      v when is_integer(v) -> Integer.to_string(v)
      v when is_float(v) -> v |> trunc() |> clamp_int64() |> Integer.to_string()
      v when is_binary(v) -> v |> text_to_integer() |> Integer.to_string()
      {:blob, _} -> "0"
    end
  end

  defp format_value(value, ?R) do
    float =
      case value do
        v when is_number(v) -> v * 1.0
        v when is_binary(v) -> text_to_float(v)
        {:blob, _} -> 0.0
      end

    :erlang.float_to_binary(float, decimals: 3)
  end

  defp format_value(value, ?T) do
    text =
      case value do
        v when is_binary(v) -> v
        v when is_integer(v) -> Integer.to_string(v)
        v when is_float(v) -> ExSQL.Value.to_text(v)
        {:blob, b} -> b
      end

    # Empty strings and non-printable bytes are rendered the way the stored
    # corpus results expect.
    case text do
      "" -> "(empty)"
      text -> for <<c <- text>>, into: "", do: if(c < 0x20 or c > 0x7E, do: "@", else: <<c>>)
    end
  end

  # Leading integer prefix of the text (0 when there is none), saturated to int64.
  defp text_to_integer(text) do
    case Integer.parse(String.trim_leading(text)) do
      {n, _rest} -> clamp_int64(n)
      :error -> 0
    end
  end

  defp clamp_int64(n), do: n |> max(@int64_min) |> min(@int64_max)

  defp text_to_float(text) do
    trimmed = String.trim_leading(text)

    case Float.parse(trimmed) do
      {f, _rest} ->
        f

      :error ->
        case Integer.parse(trimmed) do
          {n, _rest} -> n * 1.0
          :error -> 0.0
        end
    end
  end

  defp apply_sort(values, :nosort, _ncols), do: values

  defp apply_sort(values, :valuesort, _ncols), do: Enum.sort(values)

  # Sorts whole rows (ncols values each), then flattens them back.
  defp apply_sort(values, :rowsort, ncols) do
    values
    |> Enum.chunk_every(ncols)
    |> Enum.sort()
    |> List.flatten()
  end

  # -- comparison ---------------------------------------------------------------------

  defp compare(values, {:hash, count, md5}) do
    cond do
      length(values) != count ->
        {:mismatch, "expected #{count} values, got #{length(values)}"}

      result_hash(values) != md5 ->
        {:mismatch, "result hash mismatch (#{length(values)} values)"}

      true ->
        :ok
    end
  end

  defp compare(values, {:values, expected}) do
    cond do
      length(values) != length(expected) ->
        {:mismatch, "expected #{length(expected)} values, got #{length(values)}"}

      values != expected ->
        diff =
          Enum.zip(expected, values)
          |> Enum.find(fn {e, v} -> e != v end)
          |> then(fn {e, v} -> "first difference: expected #{inspect(e)}, got #{inspect(v)}" end)

        {:mismatch, diff}

      true ->
        :ok
    end
  end

  # MD5 (lowercase hex) of every value followed by a newline.
  defp result_hash(values) do
    values
    |> Enum.map(&[&1, "\n"])
    |> IO.iodata_to_binary()
    |> then(&:crypto.hash(:md5, &1))
    |> Base.encode16(case: :lower)
  end

  # -- parsing ----------------------------------------------------------------------

  # Splits content into {line, line_number} pairs and parses them into records.
  defp parse(content) do
    content
    |> String.split("\n")
    |> Enum.map(&String.trim_trailing(&1, "\r"))
    |> Enum.with_index(1)
    |> parse_records([])
    |> Enum.reverse()
  end

  defp parse_records([], acc), do: acc

  defp parse_records(lines, acc) do
    case skip_blank(lines) do
      [] -> acc
      lines -> parse_record(lines, acc)
    end
  end

  # Drops blank lines and `#` comment lines between records.
  defp skip_blank(lines) do
    Enum.drop_while(lines, fn {line, _n} ->
      trimmed = String.trim(line)
      trimmed == "" or String.starts_with?(trimmed, "#")
    end)
  end

  defp parse_record(lines, acc) do
    {conditions, lines} = take_conditions(lines, [])

    case lines do
      [] ->
        acc

      [{line, n} | rest] ->
        cond do
          not applicable?(conditions) ->
            {rest, was_record} = skip_record_body(line, rest)

            if was_record,
              do: parse_records(rest, [{:skip, :condition} | acc]),
              else: parse_records(rest, acc)

          String.starts_with?(line, "hash-threshold") ->
            parse_records(rest, acc)

          # `halt` ends parsing: nothing after it is run.
          String.trim(line) == "halt" ->
            [{:halt, n} | acc]

          String.starts_with?(line, "statement") ->
            expect = if String.contains?(line, "error"), do: :error, else: :ok
            {sql_lines, rest} = take_until_blank(rest)
            parse_records(rest, [{:statement, expect, join_sql(sql_lines), n} | acc])

          String.starts_with?(line, "query") ->
            {record, rest} = parse_query(line, n, rest)
            parse_records(rest, [record | acc])

          true ->
            # Unrecognized directive: skip through the record body.
            {rest, _was_record} = skip_record_body(line, rest)
            parse_records(rest, acc)
        end
    end
  end

  # Collects leading `skipif <engine>` / `onlyif <engine>` lines. Tokens after
  # the engine name (e.g. `# not compatible`) are ignored.
  defp take_conditions([{line, _n} | rest] = lines, acc) do
    case String.split(String.trim(line)) do
      ["skipif", engine | _comment] -> take_conditions(rest, [{:skipif, engine} | acc])
      ["onlyif", engine | _comment] -> take_conditions(rest, [{:onlyif, engine} | acc])
      _other -> {acc, lines}
    end
  end

  defp take_conditions([], acc), do: {acc, []}

  defp applicable?(conditions) do
    Enum.all?(conditions, fn
      {:skipif, @engine} -> false
      {:skipif, _other} -> true
      {:onlyif, @engine} -> true
      {:onlyif, _other} -> false
    end)
  end

  # Consumes a record body without interpreting it (for skipped records).
  # Returns {remaining_lines, was_a_record?}.
  defp skip_record_body(line, rest) do
    cond do
      String.starts_with?(line, "hash-threshold") or String.trim(line) == "halt" ->
        {rest, false}

      String.starts_with?(line, "query") ->
        {_sql, rest} = take_query_sql(rest)
        {_expected, rest} = take_expected(rest)
        {rest, true}

      true ->
        {_sql, rest} = take_until_blank(rest)
        {rest, true}
    end
  end

  # Parses `query <types> [<sort> [<label>]]`, its SQL, and its expected results.
  defp parse_query(line, n, rest) do
    {types, sort_mode} =
      case String.split(String.trim(line)) do
        ["query", types | opts] -> {types, sort_mode(opts)}
        _other -> {"", :nosort}
      end

    {sql_lines, rest} = take_query_sql(rest)
    {expected_lines, rest} = take_expected(rest)

    {{:query, types, sort_mode, expected_result(expected_lines), join_sql(sql_lines), n}, rest}
  end

  defp sort_mode(["rowsort" | _]), do: :rowsort
  defp sort_mode(["valuesort" | _]), do: :valuesort
  defp sort_mode(_other), do: :nosort

  # A single `N values hashing to <md5>` line is a digest; anything else is literal values.
  defp expected_result([single]) do
    case Regex.run(~r/^(\d+) values hashing to ([0-9a-f]{32})$/, String.trim(single)) do
      [_, count, md5] -> {:hash, String.to_integer(count), md5}
      nil -> {:values, [single]}
    end
  end

  defp expected_result(lines), do: {:values, lines}

  # A query's SQL runs until the `----` separator or the blank line ending the record.
  defp take_query_sql(lines), do: take_until(lines, &(separator?(&1) or blank?(&1)))

  # Consumes the `----` separator and the expected-result lines after it, if present.
  defp take_expected([{line, _n} | rest] = lines) do
    if separator?(line), do: take_until_blank(rest), else: {[], lines}
  end

  defp take_expected([]), do: {[], []}

  defp separator?(line), do: String.trim(line) == "----"

  defp blank?(line), do: String.trim(line) == ""

  defp take_until_blank(lines), do: take_until(lines, &blank?/1)

  # Splits {line, n} pairs at the first line satisfying `stop?`; returns
  # {taken_line_texts, remaining_pairs}.
  defp take_until(lines, stop?) do
    {taken, rest} = Enum.split_while(lines, fn {line, _n} -> not stop?.(line) end)
    {Enum.map(taken, &elem(&1, 0)), rest}
  end

  defp join_sql(lines), do: lines |> Enum.join("\n") |> String.trim()
end