# lib/ex_sql/database.ex

defmodule ExSQL.Database do
  @moduledoc """
  An immutable in-memory database: a schema of named tables.

  This plays the role of SQLite's connection + pager state, except every
  operation returns a new database value. Use it directly for a purely
  functional workflow, or hold it in an `ExSQL.Connection` process for a
  stateful, sqlite3-like handle.

      {:ok, _, db} = ExSQL.Executor.run(ExSQL.Database.new(), "CREATE TABLE t (a)")

  Schema names are compared case-insensitively (ASCII folding, as in SQLite):
  `nil`, `"main"`, `"MAIN"` and `"Main"` all denote the main schema.
  """

  alias ExSQL.Table

  defstruct tables: %{},
            views: %{},
            triggers: %{},
            attached_databases: [],
            txn_stack: [],
            ctes: %{},
            pending_ctes: %{},
            changes: 0,
            total_changes: 0,
            last_insert_rowid: 0,
            foreign_keys: false,
            defer_foreign_keys: false,
            recursive_triggers: false,
            ignore_check_constraints: false,
            count_changes: false,
            read_uncommitted: false,
            case_sensitive_like: false,
            short_column_names: true,
            full_column_names: false,
            reverse_unordered_selects: false,
            query_only: false,
            schema_version: 0,
            user_version: 0,
            application_id: 0,
            schema_headers: %{},
            page_size: 4096,
            page_size_locked: false,
            cache_size: -2000,
            max_page_count: 1_073_741_823,
            cache_spill: 20_000,
            default_cache_size: 2000,
            auto_vacuum: 0,
            journal_mode: "memory",
            journal_size_limit: 32_768,
            locking_mode: "normal",
            synchronous: 2,
            temp_store: 0,
            automatic_index: true,
            busy_timeout: 0,
            soft_heap_limit: 0,
            threads: 0,
            secure_delete: 2,
            analysis_limit: 0,
            wal_autocheckpoint: 1000,
            cell_size_check: false,
            checkpoint_fullfsync: false,
            fullfsync: false,
            trusted_schema: false,
            empty_result_callbacks: false,
            active_triggers: [],
            sqlite_sequence_orphans: %{},
            scalar_functions: %{},
            aggregate_functions: %{},
            collations: %{}

  @typedoc """
  A stored view: the parsed query and an optional explicit column list.
  """
  @type view :: %{
          name: String.t(),
          schema: String.t() | nil,
          columns: [String.t()] | nil,
          query: term()
        }

  @typedoc """
  A materialized CTE result, threaded through query execution.
  """
  @type cte :: %{columns: [String.t()], rows: [[term()]], affinities: [atom()]}

  @typedoc """
  Mutable schema state captured by `schema_snapshot/1` and accepted by
  `restore_schema/2`.
  """
  @type schema_snapshot :: %{
          tables: %{String.t() => Table.t()},
          views: %{String.t() => view()},
          triggers: %{String.t() => trigger()},
          sqlite_sequence_orphans: map(),
          schema_version: integer(),
          user_version: integer(),
          application_id: integer(),
          schema_headers: %{String.t() => map()}
        }

  @typedoc """
  Open transactions and savepoints, newest first. Each entry snapshots at least
  the tables and views as they were when it was opened. Snapshots built by
  `schema_snapshot/1` also carry triggers and header fields, which
  `restore_schema/2` restores when present.
  """
  @type txn_entry ::
          {:begin | {:savepoint, String.t()},
           %{
             required(:tables) => %{String.t() => Table.t()},
             required(:views) => %{String.t() => view()},
             optional(atom()) => term()
           }}

  @typedoc """
  A stored trigger: parsed definition plus a creation sequence number that
  fixes firing order.
  """
  @type trigger :: %{
          key: String.t(),
          name: String.t(),
          schema: String.t() | nil,
          table_schema: String.t() | nil,
          table_key: String.t(),
          timing: :before | :after | :instead_of,
          event: :insert | :delete | :update,
          update_columns: [String.t()] | nil,
          when: term() | nil,
          body: [term()],
          seq: non_neg_integer()
        }

  @typedoc """
  A connection-local scalar function callback. It receives evaluated SQL values
  and must return a SQL value, `{:ok, value}`, or `{:error, message}`.
  """
  @type scalar_function :: %{
          name: String.t(),
          arity: non_neg_integer(),
          callback: function()
        }

  @typedoc """
  A connection-local collation callback. It receives two TEXT values and
  returns `:lt`/`:eq`/`:gt`, a negative/zero/positive integer, or
  `{:ok, result}` / `{:error, message}`.
  """
  @type collation :: %{
          name: String.t(),
          callback: function()
        }

  @typedoc """
  A connection-local aggregate callback. It receives a list of evaluated,
  non-NULL argument rows and must return a SQL value, `{:ok, value}`, or
  `{:error, message}`. Incremental window callbacks use the same registry but
  store a callback map and `kind: :incremental_window`.
  """
  @type aggregate_function :: %{
          required(:name) => String.t(),
          required(:arity) => non_neg_integer(),
          required(:callback) => function() | map(),
          optional(:kind) => :frame | :incremental_window
        }

  @type t :: %__MODULE__{
          tables: %{String.t() => Table.t()},
          views: %{String.t() => view()},
          triggers: %{String.t() => trigger()},
          attached_databases: [%{seq: pos_integer(), name: String.t(), file: String.t()}],
          txn_stack: [txn_entry()],
          ctes: %{String.t() => cte()},
          # NOTE(review): value shape of pending_ctes is not visible in this module.
          pending_ctes: map(),
          changes: non_neg_integer(),
          total_changes: non_neg_integer(),
          last_insert_rowid: integer(),
          foreign_keys: boolean(),
          defer_foreign_keys: boolean(),
          recursive_triggers: boolean(),
          ignore_check_constraints: boolean(),
          count_changes: boolean(),
          read_uncommitted: boolean(),
          case_sensitive_like: boolean(),
          short_column_names: boolean(),
          full_column_names: boolean(),
          reverse_unordered_selects: boolean(),
          query_only: boolean(),
          schema_version: integer(),
          user_version: integer(),
          application_id: integer(),
          schema_headers: %{String.t() => map()},
          page_size: pos_integer(),
          page_size_locked: boolean(),
          cache_size: integer(),
          max_page_count: pos_integer(),
          cache_spill: non_neg_integer(),
          default_cache_size: non_neg_integer(),
          auto_vacuum: 0..2,
          journal_mode: String.t(),
          journal_size_limit: integer(),
          locking_mode: String.t(),
          synchronous: 0..3,
          temp_store: 0..2,
          automatic_index: boolean(),
          busy_timeout: non_neg_integer(),
          soft_heap_limit: non_neg_integer(),
          threads: 0..8,
          secure_delete: 0..2,
          analysis_limit: non_neg_integer(),
          wal_autocheckpoint: non_neg_integer(),
          cell_size_check: boolean(),
          checkpoint_fullfsync: boolean(),
          fullfsync: boolean(),
          trusted_schema: boolean(),
          empty_result_callbacks: boolean(),
          active_triggers: [String.t()],
          sqlite_sequence_orphans: map(),
          scalar_functions: %{{String.t(), non_neg_integer()} => scalar_function()},
          aggregate_functions: %{{String.t(), non_neg_integer()} => aggregate_function()},
          collations: %{String.t() => collation()}
        }

  @doc "Returns an empty database."
  @spec new() :: t()
  def new, do: %__MODULE__{}

  @doc """
  Registers or replaces a connection-local scalar function.

  The callback arity must match the SQL arity exactly. The callback receives
  evaluated SQL values as positional arguments and may return a SQL value,
  `{:ok, value}`, or `{:error, message}`.

  Returns `{:error, message}` when the name is blank, the arity is outside
  `0..127`, or the callback is not a function of the given arity.
  """
  @spec create_scalar_function(t(), String.t(), non_neg_integer(), function()) ::
          {:ok, t()} | {:error, String.t()}
  def create_scalar_function(db, name, arity, callback) do
    cond do
      not is_binary(name) or String.trim(name) == "" ->
        {:error, "function name must be a non-empty string"}

      not is_integer(arity) or arity < 0 or arity > 127 ->
        {:error, "function arity must be between 0 and 127"}

      not is_function(callback) ->
        {:error, "function callback must be a function"}

      function_arity(callback) != arity ->
        {:error, "function callback arity must match SQL arity"}

      true ->
        key = scalar_function_key(name, arity)
        function = %{name: String.downcase(name), arity: arity, callback: callback}
        {:ok, %{db | scalar_functions: Map.put(db.scalar_functions, key, function)}}
    end
  end

  @doc "Fetches a registered scalar function by name and exact arity."
  @spec fetch_scalar_function(t(), String.t(), non_neg_integer()) ::
          {:ok, scalar_function()} | :error
  def fetch_scalar_function(db, name, arity) do
    Map.fetch(db.scalar_functions, scalar_function_key(name, arity))
  end

  @doc "Returns true if any arity is registered for this scalar function name."
  @spec scalar_function_exists?(t(), String.t()) :: boolean()
  def scalar_function_exists?(db, name) do
    key = Table.key(name)
    Enum.any?(db.scalar_functions, fn {{registered, _arity}, _function} -> registered == key end)
  end

  defp scalar_function_key(name, arity), do: {Table.key(name), arity}

  # Arity of an anonymous or captured function, used to validate callbacks.
  defp function_arity(callback) do
    {:arity, arity} = Function.info(callback, :arity)
    arity
  end

  @doc """
  Registers or replaces a connection-local aggregate function.

  The callback receives one list argument containing evaluated, non-NULL
  argument rows. A one-argument SQL aggregate gets rows like `[[value], ...]`;
  a two-argument aggregate gets rows like `[[a, b], ...]`.
  """
  @spec create_aggregate_function(t(), String.t(), non_neg_integer(), function()) ::
          {:ok, t()} | {:error, String.t()}
  def create_aggregate_function(db, name, arity, callback) do
    create_frame_function(db, name, arity, callback, "aggregate")
  end

  @doc """
  Registers or replaces a connection-local aggregate window function.

  The callback receives one list argument containing evaluated, non-NULL
  argument rows from the current window frame.
  """
  @spec create_window_function(t(), String.t(), non_neg_integer(), function()) ::
          {:ok, t()} | {:error, String.t()}
  def create_window_function(db, name, arity, callback) do
    create_frame_function(db, name, arity, callback, "window function")
  end

  @doc """
  Registers or replaces a connection-local incremental aggregate window
  function.

  Callback map keys follow SQLite's aggregate window model: `:init` (arity 0),
  `:step` and `:inverse` (arity 2, receiving state and evaluated SQL argument
  list), and `:value` / `:final` (arity 1). All five keys are required.
  """
  @spec create_incremental_window_function(t(), String.t(), non_neg_integer(), map()) ::
          {:ok, t()} | {:error, String.t()}
  def create_incremental_window_function(db, name, arity, callbacks) do
    cond do
      not is_binary(name) or String.trim(name) == "" ->
        {:error, "incremental window function name must be a non-empty string"}

      not is_integer(arity) or arity < 0 or arity > 127 ->
        {:error, "incremental window function arity must be between 0 and 127"}

      not is_map(callbacks) ->
        {:error, "incremental window function callbacks must be a map"}

      error = invalid_incremental_window_callbacks(callbacks) ->
        {:error, error}

      true ->
        key = aggregate_function_key(name, arity)

        function = %{
          name: String.downcase(name),
          arity: arity,
          callback: callbacks,
          kind: :incremental_window
        }

        {:ok, %{db | aggregate_functions: Map.put(db.aggregate_functions, key, function)}}
    end
  end

  # Shared validation/registration for whole-frame aggregate and window callbacks;
  # `label` only customizes the error messages.
  defp create_frame_function(db, name, arity, callback, label) do
    cond do
      not is_binary(name) or String.trim(name) == "" ->
        {:error, "#{label} name must be a non-empty string"}

      not is_integer(arity) or arity < 0 or arity > 127 ->
        {:error, "#{label} arity must be between 0 and 127"}

      not is_function(callback) ->
        {:error, "#{label} callback must be a function"}

      function_arity(callback) != 1 ->
        {:error, "#{label} callback arity must be 1"}

      true ->
        key = aggregate_function_key(name, arity)
        function = %{name: String.downcase(name), arity: arity, callback: callback, kind: :frame}
        {:ok, %{db | aggregate_functions: Map.put(db.aggregate_functions, key, function)}}
    end
  end

  # Returns the first validation message for the callback map, or nil when every
  # required callback is present with the expected arity.
  defp invalid_incremental_window_callbacks(callbacks) do
    [
      init: 0,
      step: 2,
      inverse: 2,
      value: 1,
      final: 1
    ]
    |> Enum.find_value(fn {name, arity} ->
      callback = Map.get(callbacks, name)

      cond do
        not is_function(callback) ->
          "incremental window function #{name} callback must be a function"

        function_arity(callback) != arity ->
          "incremental window function #{name} callback arity must be #{arity}"

        true ->
          nil
      end
    end)
  end

  @doc "Fetches a registered aggregate function by name and exact SQL arity."
  @spec fetch_aggregate_function(t(), String.t(), non_neg_integer()) ::
          {:ok, aggregate_function()} | :error
  def fetch_aggregate_function(db, name, arity) do
    Map.fetch(db.aggregate_functions, aggregate_function_key(name, arity))
  end

  @doc "Returns true if any arity is registered for this aggregate function name."
  @spec aggregate_function_exists?(t(), String.t()) :: boolean()
  def aggregate_function_exists?(db, name) do
    key = Table.key(name)

    Enum.any?(db.aggregate_functions, fn {{registered, _arity}, _function} ->
      registered == key
    end)
  end

  defp aggregate_function_key(name, arity), do: {Table.key(name), arity}

  @doc """
  Registers or replaces a connection-local collation.

  The callback receives two TEXT values and returns `:lt`/`:eq`/`:gt`, a
  negative/zero/positive integer, or `{:ok, result}` / `{:error, message}`.
  """
  @spec create_collation(t(), String.t(), function()) :: {:ok, t()} | {:error, String.t()}
  def create_collation(db, name, callback) do
    cond do
      not is_binary(name) or String.trim(name) == "" ->
        {:error, "collation name must be a non-empty string"}

      not is_function(callback) ->
        {:error, "collation callback must be a function"}

      function_arity(callback) != 2 ->
        {:error, "collation callback arity must be 2"}

      true ->
        collation = %{name: String.downcase(name), callback: callback}
        {:ok, %{db | collations: Map.put(db.collations, Table.key(name), collation)}}
    end
  end

  @doc "Fetches a registered collation by name."
  @spec fetch_collation(t(), String.t()) :: {:ok, collation()} | :error
  def fetch_collation(db, name), do: Map.fetch(db.collations, Table.key(name))

  @doc "Fetches a main-schema view by name (case-insensitive)."
  @spec fetch_view(t(), String.t()) :: {:ok, view()} | :error
  def fetch_view(db, name) do
    fetch_view(db, nil, name)
  end

  @doc "Fetches a view by schema and name (both case-insensitive)."
  @spec fetch_view(t(), String.t() | nil, String.t()) :: {:ok, view()} | :error
  def fetch_view(db, schema, name) do
    Map.fetch(db.views, table_storage_key(schema, name))
  end

  @doc """
  Fetches a view by unqualified name using SQLite's schema lookup order
  (temp, then main, then attached databases in list order).
  """
  @spec lookup_view(t(), String.t()) :: {:ok, view()} | :error
  def lookup_view(db, name) do
    lookup_schema_object(db, :views, name)
  end

  @doc """
  Adds a new view; errors if the name collides with a table, view, or index in
  the view's schema. Bumps that schema's cookie on success.
  """
  @spec create_view(t(), view()) :: {:ok, t()} | {:error, String.t()}
  def create_view(db, view) do
    key = table_storage_key(view.schema, view.name)

    cond do
      Map.has_key?(db.tables, key) ->
        {:error, "table #{view.name} already exists"}

      Map.has_key?(db.views, key) ->
        {:error, "view #{view.name} already exists"}

      index_exists?(db, view.schema, view.name) ->
        {:error, "there is already an index named #{view.name}"}

      true ->
        {:ok, schema_changed(%{db | views: Map.put(db.views, key, view)}, view.schema)}
    end
  end

  @doc "Removes a main-schema view; errors if it does not exist or if the name is a table."
  @spec drop_view(t(), String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_view(db, name) do
    drop_view(db, nil, name)
  end

  @doc "Removes a view from the given schema; errors if missing or if the name is a table."
  @spec drop_view(t(), String.t() | nil, String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_view(db, schema, name) do
    key = table_storage_key(schema, name)

    cond do
      Map.has_key?(db.tables, key) ->
        {:error, "use DROP TABLE to delete table #{name}"}

      Map.has_key?(db.views, key) ->
        {:ok, schema_changed(%{db | views: Map.delete(db.views, key)}, schema)}

      true ->
        {:error, "no such view: #{name}"}
    end
  end

  @doc "Fetches a main-schema table by name (case-insensitive)."
  @spec fetch_table(t(), String.t()) :: {:ok, Table.t()} | {:error, String.t()}
  def fetch_table(db, name) do
    case Map.fetch(db.tables, Table.key(name)) do
      {:ok, table} -> {:ok, table}
      :error -> {:error, "no such table: #{name}"}
    end
  end

  @doc """
  Fetches a table by unqualified name using SQLite's schema lookup order
  (temp, then main, then attached databases in list order).
  """
  @spec lookup_table(t(), String.t()) :: {:ok, Table.t()} | {:error, String.t()}
  def lookup_table(db, name) do
    case lookup_schema_object(db, :tables, name) do
      {:ok, table} -> {:ok, table}
      :error -> {:error, "no such table: #{name}"}
    end
  end

  @doc """
  Adds a new table; errors if the name collides with a table, view, or index in
  the table's schema. Bumps that schema's cookie on success.
  """
  @spec create_table(t(), Table.t()) :: {:ok, t()} | {:error, String.t()}
  def create_table(db, table) do
    table_key = table_storage_key(table.schema, table.name)

    cond do
      Map.has_key?(db.tables, table_key) ->
        {:error, "table #{table.name} already exists"}

      Map.has_key?(db.views, table_key) ->
        {:error, "there is already a table named #{table.name}"}

      index_exists?(db, table.schema, table.name) ->
        {:error, "there is already an index named #{table.name}"}

      true ->
        {:ok, db |> put_table(table) |> schema_changed(table.schema)}
    end
  end

  @doc "Replaces a table's state after DML (does not bump the schema cookie)."
  @spec put_table(t(), Table.t()) :: t()
  def put_table(db, table),
    do: %{db | tables: Map.put(db.tables, table_storage_key(table.schema, table.name), table)}

  @doc "Removes a main-schema table; errors if it does not exist or if the name is a view."
  @spec drop_table(t(), String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_table(db, name) do
    drop_table(db, nil, name)
  end

  @doc "Removes a table from the given schema; errors if missing or if the name is a view."
  @spec drop_table(t(), String.t() | nil, String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_table(db, schema, name) do
    table_key = table_storage_key(schema, name)

    cond do
      Map.has_key?(db.tables, table_key) ->
        {:ok, schema_changed(%{db | tables: Map.delete(db.tables, table_key)}, schema)}

      Map.has_key?(db.views, table_key) ->
        {:error, "use DROP VIEW to delete view #{name}"}

      true ->
        {:error, "no such table: #{name}"}
    end
  end

  @doc """
  Internal storage key for schema-qualified tables and views.

  Objects in the main schema (`nil` or any ASCII case variant of `"main"`) are
  keyed by `Table.key/1` of their bare name; every other schema yields
  `"schema.name"` with both parts normalized by `Table.key/1`.
  """
  @spec table_storage_key(String.t() | nil, String.t()) :: String.t()
  def table_storage_key(schema, name) do
    if main_schema?(schema) do
      Table.key(name)
    else
      Table.key(schema) <> "." <> Table.key(name)
    end
  end

  @doc """
  Finds which table owns the index with the given name, or nil.
  Index lookup is database-global and case-insensitive.
  """
  @spec find_index_owner(t(), String.t()) :: {Table.t(), map()} | nil
  def find_index_owner(db, index_name) do
    find_index_owner(db, :any, index_name)
  end

  @doc """
  Finds which table owns the index with the given name within a schema, or nil.
  Pass `:any` as the schema to search every schema.
  """
  @spec find_index_owner(t(), String.t() | nil | :any, String.t()) ::
          {Table.t(), map()} | nil
  def find_index_owner(db, schema, index_name) do
    key = Table.key(index_name)

    Enum.find_value(db.tables, fn {_table_key, table} ->
      if index_schema_matches?(table.schema, schema) do
        case Enum.find(table.indexes, &(Table.key(&1.name) == key)) do
          nil -> nil
          index -> {table, index}
        end
      else
        nil
      end
    end)
  end

  @doc "Returns true if any table has an index with this name."
  @spec index_exists?(t(), String.t()) :: boolean()
  def index_exists?(db, index_name) do
    find_index_owner(db, index_name) != nil
  end

  @doc "Returns true if a table in the schema has an index with this name."
  @spec index_exists?(t(), String.t() | nil | :any, String.t()) :: boolean()
  def index_exists?(db, schema, index_name) do
    find_index_owner(db, schema, index_name) != nil
  end

  defp index_schema_matches?(_table_schema, :any), do: true

  defp index_schema_matches?(table_schema, schema) do
    schema_key(table_schema) == schema_key(schema)
  end

  # SQLite folds identifier case with ASCII rules, so "MAIN"/"Main" name the
  # main schema exactly like "main"; nil means an unqualified (main) reference.
  defp main_schema?(nil), do: true
  defp main_schema?(schema), do: String.downcase(schema, :ascii) == "main"

  # Normalized schema identity: the main schema collapses to :main, any other
  # schema to its Table.key/1 form.
  defp schema_key(schema) do
    if main_schema?(schema), do: :main, else: Table.key(schema)
  end

  # Searches temp, main, then attached schemas in list order; returns
  # {:ok, object} for the first hit or :error.
  defp lookup_schema_object(db, field, name) do
    objects = Map.fetch!(db, field)

    Enum.find_value(schema_lookup_order(db), :error, fn schema ->
      case Map.fetch(objects, table_storage_key(schema, name)) do
        {:ok, object} -> {:ok, object}
        :error -> nil
      end
    end)
  end

  defp schema_lookup_order(db), do: ["temp", nil] ++ Enum.map(db.attached_databases, & &1.name)

  @doc """
  Returns a snapshot of mutable schema state for transactions: tables, views,
  triggers, `sqlite_sequence` orphans, and main/attached schema header values.
  """
  @spec schema_snapshot(t()) :: schema_snapshot()
  def schema_snapshot(db) do
    %{
      tables: db.tables,
      views: db.views,
      triggers: db.triggers,
      sqlite_sequence_orphans: db.sqlite_sequence_orphans,
      schema_version: db.schema_version,
      user_version: db.user_version,
      application_id: db.application_id,
      schema_headers: db.schema_headers
    }
  end

  @doc """
  Restores schema from a snapshot.

  `:tables` and `:views` are required; every other snapshot field is optional
  and, when absent, the database's current value is kept.
  """
  @spec restore_schema(t(), map()) :: t()
  def restore_schema(db, snapshot) do
    %{
      db
      | tables: snapshot.tables,
        views: snapshot.views,
        triggers: Map.get(snapshot, :triggers, db.triggers),
        sqlite_sequence_orphans:
          Map.get(snapshot, :sqlite_sequence_orphans, db.sqlite_sequence_orphans),
        schema_version: Map.get(snapshot, :schema_version, db.schema_version),
        user_version: Map.get(snapshot, :user_version, db.user_version),
        application_id: Map.get(snapshot, :application_id, db.application_id),
        schema_headers: Map.get(snapshot, :schema_headers, db.schema_headers)
    }
  end

  @doc "Bumps the main schema cookie using SQLite's signed 32-bit wraparound."
  @spec schema_changed(t()) :: t()
  def schema_changed(db), do: schema_changed(db, nil)

  @doc """
  Bumps the schema cookie of `schema` (main when `nil`) using signed 32-bit
  wraparound, and marks the page size as locked.
  """
  @spec schema_changed(t(), String.t() | nil) :: t()
  def schema_changed(db, schema) do
    version = schema_header_value(db, schema, :schema_version)
    db = put_schema_header_value(db, schema, :schema_version, next_schema_version(version))

    # Any schema change locks the page size (a single, connection-wide flag).
    %{db | page_size_locked: true}
  end

  @doc """
  Returns a schema header value for main or an attached schema. Attached
  schemas without stored header state report 0.
  """
  @spec schema_header_value(
          t(),
          String.t() | nil,
          :schema_version | :user_version | :application_id
        ) ::
          integer()
  def schema_header_value(db, schema, field) do
    db
    |> schema_header(schema)
    |> Map.fetch!(field)
  end

  @doc """
  Stores a schema header value for main or an attached schema.

  Main-schema values live directly on the struct; attached-schema values live
  in `schema_headers`, keyed by the normalized schema name.
  """
  @spec put_schema_header_value(
          t(),
          String.t() | nil,
          :schema_version | :user_version | :application_id,
          integer()
        ) :: t()
  def put_schema_header_value(db, schema, field, value) do
    if main_schema?(schema) do
      # replace! (unlike put) refuses to add a key the struct does not define.
      Map.replace!(db, field, value)
    else
      key = Table.key(schema)
      header = schema_header(db, schema)

      %{db | schema_headers: Map.put(db.schema_headers, key, Map.put(header, field, value))}
    end
  end

  @doc "Removes attached schema header state."
  @spec drop_schema_header(t(), String.t()) :: t()
  def drop_schema_header(db, schema) do
    %{db | schema_headers: Map.delete(db.schema_headers, Table.key(schema))}
  end

  # Header triple for a schema: read from the struct for main, from
  # schema_headers (defaulting to zeros) for attached schemas.
  defp schema_header(db, schema) do
    if main_schema?(schema) do
      %{
        schema_version: db.schema_version,
        user_version: db.user_version,
        application_id: db.application_id
      }
    else
      Map.get(db.schema_headers, Table.key(schema), %{
        schema_version: 0,
        user_version: 0,
        application_id: 0
      })
    end
  end

  # The schema cookie is a signed 32-bit integer that wraps to the minimum.
  defp next_schema_version(2_147_483_647), do: -2_147_483_648
  defp next_schema_version(value), do: value + 1

  @doc "Records the row-count state exposed by changes() and total_changes()."
  @spec record_changes(t(), non_neg_integer()) :: t()
  def record_changes(db, count), do: record_changes(db, count, nil)

  @doc """
  Records the row-count state and, when `last_insert_rowid` is not nil, the
  value exposed by last_insert_rowid().
  """
  @spec record_changes(t(), non_neg_integer(), integer() | nil) :: t()
  def record_changes(db, count, last_insert_rowid) do
    db = %{db | changes: count, total_changes: db.total_changes + count}

    if last_insert_rowid == nil do
      db
    else
      %{db | last_insert_rowid: last_insert_rowid}
    end
  end
end