Skip to main content

lib/ex_sql/table.ex

defmodule ExSQL.Table do
  @moduledoc """
  In-memory table storage.

  Rows live in a map keyed by rowid, mirroring SQLite's model where every
  table is a B-tree keyed by a 64-bit rowid. Each row is stored as a
  **positional tuple** in `columns` order (~63% smaller than a per-row
  key-bearing map); the executor's hot single-table scan reads columns by
  position via the frame's `column_index`, while other call sites widen a row
  back to a `key => value` map at the read boundary (`scan/1`, `fetch_row/2`).
  A single `INTEGER PRIMARY KEY` column is detected as the rowid alias, with
  SQLite's semantics: inserting NULL into it auto-assigns the next rowid, and
  the column's value *is* the row's key.

  All functions are pure — they return an updated table or an error tuple.
  Constraint enforcement (NOT NULL, UNIQUE, PRIMARY KEY) happens here, by
  scanning; real indexes can replace the scans later without changing the
  interface.
  """

  alias ExSQL.AST.ColumnDef
  alias ExSQL.Value

  # Table state. Every field has a default so `new/3` only sets what it
  # derives from its arguments.
  defstruct name: nil,
            schema: nil,
            # Column definitions, in declaration order; this order is also the
            # positional order of every stored row tuple.
            columns: [],
            # rowid => stored row. `store/3` always writes a positional tuple.
            rows: %{},
            # `store/3` keeps this at least one past every rowid it has stored.
            next_rowid: 1,
            # Folded key of the column that aliases the rowid, or `nil`.
            rowid_alias: nil,
            # True when the rowid-alias column was declared AUTOINCREMENT.
            autoincrement: false,
            # For autoincrement tables: the highest rowid `store/3` has seen.
            sequence: 0,
            # Becomes true once `store/3` has run on an autoincrement table.
            sequence_row: false,
            without_rowid: false,
            strict: false,
            # NOTE(review): `indexes` / `autoindexes` are not read or written by
            # the functions in this module — presumably maintained by callers.
            indexes: [],
            autoindexes: [],
            composite_keys: [],
            composite_uniques: [],
            foreign_keys: [],
            checks: [],
            # Cached `{key, name, affinity, collate}` tuples for the executor's
            # frame template, so building a frame per query skips recomputing
            # `key/1` for every column. `nil` means "not cached" — recompute;
            # column mutations reset it to `nil`. See `frame_columns/1`.
            frame_columns: nil,
            # Cached column key => positional-index map (see column_index/1).
            # `nil` means "not cached"; `column_index/1` then recomputes it.
            column_index: nil

  @typedoc "An index description: its name, the columns it covers, and whether it is UNIQUE."
  @type index :: %{name: String.t(), columns: [String.t()], unique: boolean()}

  @typedoc """
  A row as handed to and from callers: folded column key => value. This is the
  *widened* form; rows are stored positionally (see the `rows` field of `t:t/0`).
  """
  @type row :: %{String.t() => Value.t()}

  @typedoc "A composite PK or UNIQUE constraint: {constraint_name | nil, [column_key]}."
  @type composite_constraint :: {String.t() | nil, [String.t()]}

  @typedoc "A CHECK constraint: {constraint_name | nil, expr}."
  @type check_constraint :: {String.t() | nil, term()}

  @typedoc """
  A table-level FK: {child_keys, parent_table, parent_keys, actions}, where
  actions is `%{on_delete: action, on_update: action, deferred: boolean}`.
  """
  @type foreign_key ::
          {[String.t()], String.t(), [String.t()], ExSQL.AST.CreateTable.fk_actions()}

  @typedoc "The table struct. `frame_columns` and `column_index` are caches; `nil` means not cached."
  @type t :: %__MODULE__{
          name: String.t(),
          schema: String.t() | nil,
          columns: [ColumnDef.t()],
          # Stored rows are positional tuples. They are `key => value` maps only
          # transiently, between an ALTER TABLE rebuild and `narrow_all_rows/1`.
          rows: %{integer() => tuple() | row()},
          next_rowid: pos_integer(),
          rowid_alias: String.t() | nil,
          autoincrement: boolean(),
          sequence: integer(),
          sequence_row: boolean(),
          without_rowid: boolean(),
          strict: boolean(),
          indexes: [index()],
          autoindexes: [index()],
          composite_keys: [composite_constraint()],
          composite_uniques: [composite_constraint()],
          foreign_keys: [foreign_key()],
          checks: [check_constraint()],
          frame_columns: [{String.t(), String.t(), atom(), String.t() | nil}] | nil,
          column_index: %{String.t() => non_neg_integer()} | nil
        }

  @doc """
  Creates a table from a name, parsed column definitions, and optional
  table-level constraints.

  Options (all optional):

    * `:composite_keys`, `:composite_uniques`, `:foreign_keys`, `:checks` —
      table-level constraints, stored as given (default `[]`)
    * `:without_rowid` — when true, no column is treated as a rowid alias
    * `:strict` — when true, columns declared `ANY` get `:any` affinity
    * `:schema` — the schema name (default `nil`)

  The frame-column and column-index caches are filled eagerly.
  """
  @spec new(String.t(), [ColumnDef.t()], keyword()) :: t()
  def new(name, columns, opts \\ []) do
    composite_keys = Keyword.get(opts, :composite_keys, [])
    composite_uniques = Keyword.get(opts, :composite_uniques, [])
    foreign_keys = Keyword.get(opts, :foreign_keys, [])
    checks = Keyword.get(opts, :checks, [])
    without_rowid = Keyword.get(opts, :without_rowid, false)
    strict = Keyword.get(opts, :strict, false)
    schema = Keyword.get(opts, :schema)
    columns = if strict, do: strict_columns(columns), else: columns

    # A table-level PRIMARY KEY(a) over a single INTEGER column → rowid alias,
    # same as inline INTEGER PRIMARY KEY.
    rowid_alias =
      if without_rowid, do: nil, else: find_rowid_alias(columns, composite_keys)

    autoincrement =
      rowid_alias != nil and
        Enum.any?(columns, &(key(&1.name) == rowid_alias and &1.autoincrement))

    # Fold the column names once; the index cache is derived from the same list.
    frame = frame_columns_for(columns)

    %__MODULE__{
      name: name,
      schema: schema,
      columns: columns,
      rowid_alias: rowid_alias,
      autoincrement: autoincrement,
      without_rowid: without_rowid,
      strict: strict,
      composite_keys: composite_keys,
      composite_uniques: composite_uniques,
      foreign_keys: foreign_keys,
      checks: checks,
      frame_columns: frame,
      column_index: column_index_for(frame)
    }
  end

  @doc """
  Returns one `{key, name, affinity, collate}` tuple per column, in column
  order, for building a query frame.

  The cached list is returned when present; a `nil` cache (left behind by a
  column mutation) is recomputed from `columns`.
  """
  @spec frame_columns(t()) :: [{String.t(), String.t(), atom(), String.t() | nil}]
  def frame_columns(%__MODULE__{frame_columns: cached, columns: columns}) do
    case cached do
      nil -> frame_columns_for(columns)
      _ -> cached
    end
  end

  # Builds the frame tuples, folding each column name through `key/1`.
  defp frame_columns_for(columns) do
    for column <- columns do
      {key(column.name), column.name, column.affinity, column.collate}
    end
  end

  # --- positional row access -------------------------------------------------
  #
  # Rows are stored as positional tuples in `columns` order: ~63% smaller than
  # the per-row key-bearing maps, and read by `elem/2` instead of a hashed
  # `Map.get`. The contract is deliberately tolerant of the sparse cases the map
  # relied on — a position past the tuple's arity (rows predating an
  # `ALTER TABLE ADD COLUMN`, or a virtual column not materialized) reads as
  # `nil`, exactly like a missing map key did.

  @doc """
  Maps each column key to its zero-based position in `columns` order, which is
  also its position in a stored row tuple. Served from the cache when present.
  """
  @spec column_index(t()) :: %{String.t() => non_neg_integer()}
  def column_index(%__MODULE__{column_index: cached} = table) do
    if is_nil(cached), do: column_index_for(frame_columns(table)), else: cached
  end

  # Turns the ordered frame tuples into a key => position map.
  defp column_index_for(frame_columns) do
    for {{column_key, _name, _affinity, _collate}, position} <- Enum.with_index(frame_columns),
        into: %{} do
      {column_key, position}
    end
  end

  @doc """
  Returns the value at zero-based `index` of a positional row tuple.

  Anything that is not a valid position in `row` — a negative or non-integer
  index, or one at or past the tuple's size — yields `nil`, the way a missing
  key did in the old sparse-map representation.
  """
  @spec cell(tuple(), non_neg_integer()) :: Value.t()
  def cell(row, index)
      when is_tuple(row) and is_integer(index) and index >= 0 and index < tuple_size(row) do
    elem(row, index)
  end

  def cell(_row, _index), do: nil

  @doc """
  Returns the value of column `key` in a positional row, looking the position
  up through the table's column index. An unknown key yields `nil`.
  """
  @spec cell(t(), tuple(), String.t()) :: Value.t()
  def cell(table, row, key) do
    with {:ok, index} <- Map.fetch(column_index(table), key) do
      cell(row, index)
    else
      :error -> nil
    end
  end

  @doc """
  Narrows a `key => value` map to a positional row tuple in column order.
  A column whose key is absent from the map becomes `nil`.
  """
  @spec row_from_map(t(), row()) :: tuple()
  def row_from_map(table, map) do
    values =
      for {column_key, _name, _affinity, _collate} <- frame_columns(table) do
        Map.get(map, column_key)
      end

    List.to_tuple(values)
  end

  @doc """
  Widens a positional row tuple back into a `key => value` map (a compatibility
  path for callers not yet reading positionally). Every column key is present
  in the result; positions the tuple does not have read as `nil`.
  """
  @spec row_to_map(t(), tuple()) :: row()
  def row_to_map(table, row) do
    for {{column_key, _name, _affinity, _collate}, position} <-
          Enum.with_index(frame_columns(table)),
        into: %{} do
      {column_key, cell(row, position)}
    end
  end

  # For STRICT tables: a column whose declared type is `ANY` (in any letter
  # case) gets `:any` affinity; every other column passes through unchanged.
  defp strict_columns(columns) do
    for column <- columns do
      declared = column.declared_type

      cond do
        not is_binary(declared) -> column
        String.upcase(declared) == "ANY" -> %{column | affinity: :any}
        true -> column
      end
    end
  end

  # Finds the rowid alias, returning its folded column key or `nil`.
  # An inline PRIMARY KEY on an INTEGER-affinity column wins; otherwise a lone
  # table-level PRIMARY KEY(a) over one INTEGER-affinity column qualifies.
  defp find_rowid_alias(columns, composite_keys) do
    inline_alias =
      Enum.find_value(columns, fn column ->
        integer_pk? = column.primary_key and column.affinity == :integer
        if integer_pk?, do: key(column.name)
      end)

    inline_alias || table_level_rowid_alias(columns, composite_keys)
  end

  # Exactly one table-level key over exactly one column is the only shape that
  # can alias the rowid.
  defp table_level_rowid_alias(columns, [{_name, [col_key]}]) do
    column = Enum.find(columns, &(key(&1.name) == col_key))
    if column != nil and column.affinity == :integer, do: col_key
  end

  defp table_level_rowid_alias(_columns, _composite_keys), do: nil

  @doc """
  Folds a column name to its case-insensitive lookup key.

  Only ASCII `A-Z` is lowered; all other bytes, including multi-byte UTF-8
  sequences, are left untouched.
  """
  @spec key(String.t()) :: String.t()
  def key(name) when is_binary(name) do
    # Names without an ASCII capital are returned as the same binary, so the
    # common already-lowercase case builds nothing new.
    if has_ascii_upper?(name), do: fold_ascii(name), else: name
  end

  # True as soon as any byte falls in `A..Z`.
  defp has_ascii_upper?(<<>>), do: false
  defp has_ascii_upper?(<<byte, _tail::binary>>) when byte in ?A..?Z, do: true
  defp has_ascii_upper?(<<_byte, tail::binary>>), do: has_ascii_upper?(tail)

  # Lowers every `A..Z` byte; works byte-wise, so UTF-8 sequences pass through.
  defp fold_ascii(name) do
    name
    |> :binary.bin_to_list()
    |> Enum.map(&fold_byte/1)
    |> IO.iodata_to_binary()
  end

  defp fold_byte(byte) when byte in ?A..?Z, do: byte + 32
  defp fold_byte(byte), do: byte

  @doc "Looks up the ColumnDef named `name` (case-insensitively); `nil` when there is none."
  @spec column(t(), String.t()) :: ColumnDef.t() | nil
  def column(table, name) do
    wanted = key(name)

    # Only the requested name is folded; it is compared against the keys
    # already folded in `frame_columns`, because this lookup is hot
    # (`insert_targets` calls it per target column).
    position =
      table
      |> frame_columns()
      |> Enum.find_index(fn {column_key, _name, _affinity, _collate} -> column_key == wanted end)

    if position, do: Enum.at(table.columns, position)
  end

  @doc "Returns every row as a `{rowid, row}` pair, ordered by rowid, with rows widened to maps."
  @spec scan(t()) :: [{integer(), row()}]
  def scan(table) do
    # `scan_positional/1` does the ordering: rows sit in an unordered map, and
    # it walks dense rowids `1..n` directly, sorting only when there is a gap
    # (deletes / explicit rowids).
    ordered = scan_positional(table)

    # Rows are stored as positional tuples; callers of this function consume
    # `key => value` maps, so each tuple is widened here, at the read boundary.
    keys = column_keys(table)

    for {rowid, tuple} <- ordered, do: {rowid, widen_row(keys, tuple)}
  end

  # Just the folded keys, in positional order, taken from the frame tuples
  # (so the cached list is used when there is one).
  defp column_keys(table) do
    for {column_key, _name, _affinity, _collate} <- frame_columns(table), do: column_key
  end

  @doc """
  Returns every row as a `{rowid, tuple}` pair in rowid order, leaving rows in
  their stored positional form.

  This is `scan/1` without the per-row map build, for readers that address
  columns by position through the frame's `column_index`.
  """
  @spec scan_positional(t()) :: [{integer(), tuple()}]
  def scan_positional(table) do
    stored = table.rows

    case map_size(stored) do
      0 -> []
      count -> dense_scan(stored, count, [])
    end
  end

  @doc "Looks up the row stored at `rowid`: `{:ok, row_map}` when present, `:error` otherwise."
  @spec fetch_row(t(), integer()) :: {:ok, row()} | :error
  def fetch_row(table, rowid) do
    # A miss makes `Map.fetch/2` return `:error`, which `with` passes through.
    with {:ok, tuple} <- Map.fetch(table.rows, rowid) do
      {:ok, widen_row(column_keys(table), tuple)}
    end
  end

  @doc "Returns the row stored at `rowid` as a map; raises `KeyError` when there is no such row."
  @spec fetch_row!(t(), integer()) :: row()
  def fetch_row!(table, rowid) do
    keys = column_keys(table)
    stored = Map.fetch!(table.rows, rowid)
    widen_row(keys, stored)
  end

  @doc "Returns the row stored at `rowid` as a map, or `nil` when there is no such row."
  @spec get_row(t(), integer()) :: row() | nil
  def get_row(table, rowid) do
    case fetch_row(table, rowid) do
      {:ok, row} -> row
      :error -> nil
    end
  end

  @doc """
  Re-narrows all stored rows — currently held as `key => value` maps — to
  positional tuples against the table's *current* columns. Used after an
  `ALTER TABLE` rebuilds rows as maps under a changed column layout.

  A column key missing from a row's map becomes `nil` in its tuple. Every
  stored row must be a map when this is called.
  """
  @spec narrow_all_rows(t()) :: t()
  def narrow_all_rows(table) do
    # Resolve the column keys once. This runs right after a column mutation,
    # when the frame cache is `nil`, so resolving them per row would re-fold
    # every column name for every row.
    keys = column_keys(table)

    narrowed =
      Map.new(table.rows, fn {rowid, map} ->
        {rowid, keys |> Enum.map(&Map.get(map, &1)) |> List.to_tuple()}
      end)

    %{table | rows: narrowed}
  end

  # Pairs each column key with the tuple element at the same position and
  # collects the pairs into a map. Pairing stops at the shorter side, so a
  # tuple with fewer elements than there are keys yields a map without the
  # trailing keys.
  defp widen_row(keys, tuple) do
    keys
    |> Enum.zip(Tuple.to_list(tuple))
    |> Map.new()
  end

  # Collects `{rowid, row}` pairs in ascending rowid order. Walks rowids from
  # `i` down to 1, prepending each hit; if any rowid in that range is missing,
  # the walk is abandoned and the whole map is sorted by rowid instead.
  defp dense_scan(_rows, 0, acc), do: acc

  defp dense_scan(rows, i, acc) do
    case Map.fetch(rows, i) do
      {:ok, row} -> dense_scan(rows, i - 1, [{i, row} | acc])
      :error -> rows |> Map.to_list() |> List.keysort(0)
    end
  end

  @doc """
  Inserts one row, given as a map of column key => value (the executor has
  already evaluated and affinity-coerced the values). A column absent from
  the map takes its DEFAULT, or NULL when it has none.

  Options:

    * `:rowid` — an explicit rowid (from inserting into `rowid` by name)
    * `:on_conflict` — `:abort` (default) errors, `:replace` deletes the
      conflicting rows first, `:ignore` skips the insert and returns `:ignore`

  Returns `{:ok, table, rowid}` on success.
  """
  @spec insert(t(), row(), keyword()) :: {:ok, t(), integer()} | :ignore | {:error, String.t()}
  def insert(table, values, opts \\ []) do
    on_conflict = Keyword.get(opts, :on_conflict) || :abort

    cells =
      for {column_key, column} <- keyed_columns(table) do
        case Map.get(values, column_key, :missing) do
          :missing -> default_value(column)
          supplied -> supplied
        end
      end

    case assign_rowid(table, List.to_tuple(cells), Keyword.get(opts, :rowid)) do
      {:ok, rowid, row} ->
        finish_insert(table, rowid, row, violations(table, row, rowid), on_conflict)

      # A rowid that cannot be assigned is returned to the caller as-is.
      failure ->
        failure
    end
  end

  # Turns the constraint-check outcome into the insert result, per `on_conflict`.
  defp finish_insert(table, rowid, row, :ok, _on_conflict),
    do: {:ok, store(table, rowid, row), rowid}

  defp finish_insert(_table, _rowid, _row, {:conflict, _rowids, _message}, :ignore),
    do: :ignore

  defp finish_insert(table, rowid, row, {:conflict, rowids, _message}, :replace),
    do: {:ok, table |> delete_rows(rowids) |> store(rowid, row), rowid}

  defp finish_insert(_table, _rowid, _row, {:conflict, _rowids, message}, _on_conflict),
    do: {:error, message}

  defp finish_insert(_table, _rowid, _row, {:error, _message}, :ignore), do: :ignore

  defp finish_insert(_table, _rowid, _row, {:error, message}, _on_conflict),
    do: {:error, message}

  # Writes `row` under `rowid` and advances the rowid bookkeeping.
  # `row` may be a positional tuple (the insert path builds one directly) or a
  # `key => value` map, which is narrowed to a tuple first.
  defp store(table, rowid, row) do
    tuple =
      case row do
        positional when is_tuple(positional) -> positional
        keyed -> row_from_map(table, keyed)
      end

    # Autoincrement tables track the highest rowid stored and record that a
    # sequence entry now exists; other tables keep both fields as they were.
    {sequence, sequence_row} =
      if table.autoincrement do
        {max(table.sequence, rowid), true}
      else
        {table.sequence, table.sequence_row}
      end

    %{
      table
      | rows: Map.put(table.rows, rowid, tuple),
        next_rowid: max(table.next_rowid, rowid + 1),
        sequence: sequence,
        sequence_row: sequence_row
    }
  end

  # Decides the rowid for a row about to be inserted and keeps the rowid-alias
  # cell (if the table has one) in agreement with it.
  # Returns `{:ok, rowid, row}` or `{:error, "datatype mismatch"}`.

  # Explicit integer rowid, no alias column: use it as-is.
  defp assign_rowid(%{rowid_alias: nil}, row, explicit) when is_integer(explicit),
    do: {:ok, explicit, row}

  # Explicit integer rowid with an alias column: the alias cell takes the same value.
  defp assign_rowid(table, row, explicit) when is_integer(explicit) do
    pos = Map.fetch!(column_index(table), table.rowid_alias)
    {:ok, explicit, put_elem(row, pos, explicit)}
  end

  # No explicit rowid and no alias column: take the table's next rowid.
  defp assign_rowid(%{rowid_alias: nil} = table, row, nil), do: {:ok, table.next_rowid, row}

  # No explicit rowid, alias column present: the alias cell decides.
  defp assign_rowid(table, row, nil) do
    pos = Map.fetch!(column_index(table), table.rowid_alias)

    case elem(row, pos) do
      nil ->
        # NULL in the alias column means "auto-assign".
        rowid = auto_rowid(table)
        {:ok, rowid, put_elem(row, pos, rowid)}

      value when is_integer(value) ->
        {:ok, value, row}

      _other ->
        {:error, "datatype mismatch"}
    end
  end

  # Any explicit rowid that is neither an integer nor `nil`.
  defp assign_rowid(_table, _row, _explicit), do: {:error, "datatype mismatch"}

  # The rowid to auto-assign, depending on the table's autoincrement state.
  defp auto_rowid(%{autoincrement: true, sequence_row: true} = table),
    do: max(table.sequence + 1, table.next_rowid)

  defp auto_rowid(%{autoincrement: true} = table), do: next_available_rowid(table)
  defp auto_rowid(table), do: table.next_rowid

  # One past the largest rowid currently stored; an empty table yields 1.
  defp next_available_rowid(table) do
    highest = Enum.max(Map.keys(table.rows), fn -> 0 end)
    highest + 1
  end

  # Checks `row` (a positional tuple about to be stored under `rowid`) against
  # the table's constraints. Returns:
  #
  #   * `:ok` — nothing is violated
  #   * `{:error, message}` — a NOT NULL violation (a hard error)
  #   * `{:conflict, rowids, message}` — rowid/PRIMARY KEY/UNIQUE collisions;
  #     `rowids` lists the existing conflicting rows so OR REPLACE can delete
  #     them, and `message` is that of the first conflict found.
  defp violations(table, row, rowid) do
    ci = column_index(table)
    keyed = keyed_columns(table)

    not_null =
      Enum.find_value(keyed, fn {column_key, column} ->
        if not_null_violation?(table, row, Map.fetch!(ci, column_key), column_key, column),
          do: column
      end) || composite_primary_key_not_null_violation(table, row, ci)

    if not_null do
      {:error, "NOT NULL constraint failed: #{table.name}.#{not_null.name}"}
    else
      rowid_conflicts =
        if Map.has_key?(table.rows, rowid) do
          # `rowid_alias` is the folded key; report the column under its
          # declared name, as every other constraint message here does.
          rowid_name =
            case table.rowid_alias do
              nil -> "rowid"
              alias_key -> display_column_name(table, alias_key)
            end

          [{rowid, "UNIQUE constraint failed: #{table.name}.#{rowid_name}"}]
        else
          []
        end

      # Single-column UNIQUE/PK on non-rowid columns. NULLs never collide.
      single_unique_conflicts =
        for {column_key, column} <- keyed,
            needs_uniqueness?(table, column_key, column),
            value = cell(row, Map.fetch!(ci, column_key)),
            not is_nil(value),
            conflicting <- duplicates(table, column_key, value, rowid) do
          {conflicting, "UNIQUE constraint failed: #{table.name}.#{column.name}"}
        end

      # Composite PRIMARY KEY conflicts. A single-column PK that became the
      # rowid alias is already handled via the rowid itself.
      composite_pk_conflicts =
        for {_cname, col_keys} <- table.composite_keys,
            col_keys != [table.rowid_alias],
            values = Enum.map(col_keys, &cell(row, Map.fetch!(ci, &1))),
            Enum.all?(values, &(not is_nil(&1))),
            conflicting <- composite_duplicates(table, col_keys, values, rowid) do
          {conflicting, "UNIQUE constraint failed: #{column_list(table, col_keys)}"}
        end

      # Composite UNIQUE conflicts; a key with any NULL component is skipped.
      composite_unique_conflicts =
        for {_cname, col_keys} <- table.composite_uniques,
            values = Enum.map(col_keys, &cell(row, Map.fetch!(ci, &1))),
            Enum.all?(values, &(not is_nil(&1))),
            conflicting <- composite_duplicates(table, col_keys, values, rowid) do
          {conflicting, "UNIQUE constraint failed: #{column_list(table, col_keys)}"}
        end

      # Order matters: the reported message is that of the first conflict.
      all_conflicts =
        rowid_conflicts ++
          single_unique_conflicts ++
          composite_pk_conflicts ++
          composite_unique_conflicts

      case all_conflicts do
        [] -> :ok
        conflicts -> {:conflict, Enum.map(conflicts, &elem(&1, 0)), elem(hd(conflicts), 1)}
      end
    end
  end

  # `{folded_key, column}` pairs in column order. The keys come from the frame
  # tuples, so inserts and violation checks reuse the already-folded keys
  # rather than calling `key(column.name)` per column each time.
  defp keyed_columns(table) do
    Enum.zip_with(frame_columns(table), table.columns, &{elem(&1, 0), &2})
  end

  # True when the cell at `pos` is NULL and the column forbids it: either the
  # column is declared NOT NULL, or it is a PRIMARY KEY column (other than the
  # rowid alias) of a table whose primary-key columns are implicitly NOT NULL.
  defp not_null_violation?(table, row, pos, column_key, column) do
    cond do
      not is_nil(cell(row, pos)) ->
        false

      column.not_null ->
        true

      true ->
        primary_key_columns_are_not_null?(table) and column.primary_key and
          column_key != table.rowid_alias
    end
  end

  # Returns the first composite-PRIMARY-KEY column (other than the rowid alias)
  # whose cell in `row` is NULL, as a ColumnDef, or `nil` when there is none.
  # Tables that are neither STRICT nor WITHOUT ROWID are never checked.
  defp composite_primary_key_not_null_violation(table, row, ci) do
    if table.strict == false and table.without_rowid == false do
      nil
    else
      Enum.find_value(table.composite_keys, fn {_name, col_keys} ->
        Enum.find_value(col_keys, fn col_key ->
          null_cell? = is_nil(cell(row, Map.fetch!(ci, col_key)))
          if null_cell? and col_key != table.rowid_alias, do: column(table, col_key)
        end)
      end)
    end
  end

  # PRIMARY KEY columns are implicitly NOT NULL in STRICT and WITHOUT ROWID tables.
  defp primary_key_columns_are_not_null?(table) do
    if table.strict, do: true, else: table.without_rowid
  end

  # Renders column keys as "t.a, t.b" — the form used in UNIQUE constraint
  # failure messages — with each column under its declared name.
  defp column_list(table, col_keys) do
    Enum.map_join(col_keys, ", ", fn col_key ->
      shown = display_column_name(table, col_key)
      "#{table.name}.#{shown}"
    end)
  end

  # The declared (original-case) name of the column whose folded key is
  # `col_key`; falls back to `col_key` itself when no column matches.
  defp display_column_name(table, col_key) do
    Enum.find_value(table.columns, col_key, fn column ->
      if key(column.name) == col_key, do: column.name
    end)
  end

  # The value a column takes when an insert does not supply one. Literal
  # defaults, negated numeric literals, and bare-word defaults
  # (`f3 text default hi` → the word as a string) are coerced to the column's
  # affinity; no default, or any other default expression, gives `nil`.
  defp default_value(%ColumnDef{default: default, affinity: affinity}) do
    case default do
      {:literal, value} ->
        Value.apply_affinity(value, affinity)

      {:negate, {:literal, value}} when is_number(value) ->
        Value.apply_affinity(-value, affinity)

      {:column, nil, word} ->
        Value.apply_affinity(word, affinity)

      _unsupported_or_absent ->
        nil
    end
  end

  defp default_value(_column), do: nil

  # Whether a column needs an explicit uniqueness scan: it is PRIMARY KEY or
  # UNIQUE, and it is not the rowid alias (which is unique by construction,
  # being the map key).
  defp needs_uniqueness?(table, column_key, column) do
    constrained? = column.primary_key or column.unique
    constrained? and column_key != table.rowid_alias
  end

  # Rowids of stored rows, other than `excluding_rowid`, whose `column_key`
  # cell compares equal to `value`.
  defp duplicates(table, column_key, value, excluding_rowid) do
    pos = Map.fetch!(column_index(table), column_key)

    table.rows
    |> Enum.filter(fn {rowid, row} ->
      rowid != excluding_rowid and Value.compare(cell(row, pos), value) == :eq
    end)
    |> Enum.map(fn {rowid, _row} -> rowid end)
  end

  # Rowids of stored rows, other than `excluding_rowid`, that compare equal to
  # `values` on every one of `col_keys` (paired up by position).
  defp composite_duplicates(table, col_keys, values, excluding_rowid) do
    positions = column_index(table)

    wanted =
      Enum.zip_with(col_keys, values, fn col_key, value ->
        {Map.fetch!(positions, col_key), value}
      end)

    table.rows
    |> Enum.filter(fn {rowid, row} ->
      rowid != excluding_rowid and
        Enum.all?(wanted, fn {pos, value} -> Value.compare(cell(row, pos), value) == :eq end)
    end)
    |> Enum.map(fn {rowid, _row} -> rowid end)
  end

  @doc """
  Replaces the row stored at `rowid` with `row` (a full row map) after
  re-checking constraints. When the new rowid differs from `rowid` — because
  the rowid-alias column changed or `:rowid` was given — the row is re-keyed.

  Options:

    * `:rowid` — an explicit new rowid; it must be an integer, otherwise the
      result is `{:error, "datatype mismatch"}`
    * `:on_conflict` — same meaning as for `insert/3`
  """
  @spec update_row(t(), integer(), row(), keyword()) ::
          {:ok, t()} | :ignore | {:error, String.t()}
  def update_row(table, rowid, row, opts \\ []) do
    on_conflict = Keyword.get(opts, :on_conflict) || :abort

    with {:ok, new_rowid, row} <-
           resolve_update_rowid(table, rowid, row, Keyword.fetch(opts, :rowid)),
         true <- table.rowid_alias == nil or is_integer(new_rowid) do
      # The row being updated does not conflict with itself.
      shadow = delete_rows(table, [rowid])
      tuple = row_from_map(table, row)

      finish_update(shadow, new_rowid, tuple, violations(shadow, tuple, new_rowid), on_conflict)
    else
      _not_an_integer_rowid -> {:error, "datatype mismatch"}
    end
  end

  # Works out the rowid the updated row will live under, and the row map to store.
  # An explicit integer `:rowid` wins and is written into the alias column, if any.
  defp resolve_update_rowid(table, _rowid, row, {:ok, explicit}) when is_integer(explicit) do
    row =
      case table.rowid_alias do
        nil -> row
        alias_key -> Map.put(row, alias_key, explicit)
      end

    {:ok, explicit, row}
  end

  # An explicit `:rowid` that is not an integer.
  defp resolve_update_rowid(_table, _rowid, _row, {:ok, _explicit}), do: :error

  # No `:rowid` option: the alias column's value is the rowid; without an alias
  # the row keeps its current rowid.
  defp resolve_update_rowid(table, rowid, row, :error) do
    case table.rowid_alias do
      nil -> {:ok, rowid, row}
      alias_key -> {:ok, Map.fetch!(row, alias_key), row}
    end
  end

  # Turns the constraint-check outcome into the update result, per `on_conflict`.
  defp finish_update(shadow, new_rowid, tuple, :ok, _on_conflict),
    do: {:ok, store(shadow, new_rowid, tuple)}

  defp finish_update(_shadow, _new_rowid, _tuple, {:conflict, _rowids, _message}, :ignore),
    do: :ignore

  defp finish_update(shadow, new_rowid, tuple, {:conflict, rowids, _message}, :replace),
    do: {:ok, shadow |> delete_rows(rowids) |> store(new_rowid, tuple)}

  defp finish_update(_shadow, _new_rowid, _tuple, {:conflict, _rowids, message}, _on_conflict),
    do: {:error, message}

  defp finish_update(_shadow, _new_rowid, _tuple, {:error, _message}, :ignore), do: :ignore

  defp finish_update(_shadow, _new_rowid, _tuple, {:error, message}, _on_conflict),
    do: {:error, message}

  @doc "Removes the rows stored under `rowids`; rowids with no stored row are ignored."
  @spec delete_rows(t(), [integer()]) :: t()
  def delete_rows(table, rowids) do
    remaining = Map.drop(table.rows, rowids)
    %{table | rows: remaining}
  end
end