defmodule ExSQL.Database do
  @moduledoc """
  An immutable in-memory database: a schema of named tables.

  This plays the role of SQLite's connection + pager state, except every
  operation returns a new database value. Use it directly for a purely
  functional workflow, or hold it in an `ExSQL.Connection` process for a
  stateful, sqlite3-like handle.

      {:ok, _, db} = ExSQL.Executor.run(ExSQL.Database.new(), "CREATE TABLE t (a)")
  """

  alias ExSQL.Table

  # Upper bound on the SQL arity accepted for user-defined functions.
  @max_function_arity 127

  # Header reported for an attached schema that has no stored header yet.
  @empty_schema_header %{schema_version: 0, user_version: 0, application_id: 0}

  # Callbacks required for an incremental window function, validated in this order.
  @incremental_window_callbacks [init: 0, step: 2, inverse: 2, value: 1, final: 1]

  defstruct tables: %{},
            views: %{},
            triggers: %{},
            attached_databases: [],
            txn_stack: [],
            ctes: %{},
            pending_ctes: %{},
            changes: 0,
            total_changes: 0,
            last_insert_rowid: 0,
            foreign_keys: false,
            defer_foreign_keys: false,
            recursive_triggers: false,
            ignore_check_constraints: false,
            count_changes: false,
            read_uncommitted: false,
            case_sensitive_like: false,
            short_column_names: true,
            full_column_names: false,
            reverse_unordered_selects: false,
            query_only: false,
            schema_version: 0,
            user_version: 0,
            application_id: 0,
            schema_headers: %{},
            page_size: 4096,
            page_size_locked: false,
            cache_size: -2000,
            max_page_count: 1_073_741_823,
            cache_spill: 20_000,
            default_cache_size: 2000,
            auto_vacuum: 0,
            journal_mode: "memory",
            journal_size_limit: 32_768,
            locking_mode: "normal",
            synchronous: 2,
            temp_store: 0,
            automatic_index: true,
            busy_timeout: 0,
            soft_heap_limit: 0,
            threads: 0,
            secure_delete: 2,
            analysis_limit: 0,
            wal_autocheckpoint: 1000,
            cell_size_check: false,
            checkpoint_fullfsync: false,
            fullfsync: false,
            trusted_schema: false,
            empty_result_callbacks: false,
            active_triggers: [],
            sqlite_sequence_orphans: %{},
            scalar_functions: %{},
            aggregate_functions: %{},
            collations: %{}

  @typedoc """
  A stored view: the parsed query and an optional explicit column list.
  """
  @type view :: %{
          name: String.t(),
          schema: String.t() | nil,
          columns: [String.t()] | nil,
          query: term()
        }

  @typedoc """
  A materialized CTE result, threaded through query execution.
  """
  @type cte :: %{columns: [String.t()], rows: [[term()]], affinities: [atom()]}

  @typedoc """
  Open transactions and savepoints, newest first. Each entry snapshots both
  tables and views as they were when it was opened — rollback restores both.
  """
  @type txn_entry ::
          {:begin | {:savepoint, String.t()},
           %{tables: %{String.t() => Table.t()}, views: %{String.t() => view()}}}

  @typedoc """
  A stored trigger: parsed definition plus a creation sequence number that
  fixes firing order.
  """
  @type trigger :: %{
          key: String.t(),
          name: String.t(),
          schema: String.t() | nil,
          table_schema: String.t() | nil,
          table_key: String.t(),
          timing: :before | :after | :instead_of,
          event: :insert | :delete | :update,
          update_columns: [String.t()] | nil,
          when: term() | nil,
          body: [term()],
          seq: non_neg_integer()
        }

  @typedoc """
  A connection-local scalar function callback. It receives evaluated SQL values
  and must return a SQL value, `{:ok, value}`, or `{:error, message}`.
  """
  @type scalar_function :: %{
          name: String.t(),
          arity: non_neg_integer(),
          callback: function()
        }

  @typedoc """
  A connection-local collation callback. It receives two TEXT values and
  returns `:lt`/`:eq`/`:gt`, a negative/zero/positive integer, or
  `{:ok, result}` / `{:error, message}`.
  """
  @type collation :: %{
          name: String.t(),
          callback: function()
        }

  @typedoc """
  A connection-local aggregate callback. It receives a list of evaluated,
  non-NULL argument rows and must return a SQL value, `{:ok, value}`, or
  `{:error, message}`. Incremental window callbacks use the same registry but
  store a callback map and `kind: :incremental_window`.
  """
  @type aggregate_function :: %{
          required(:name) => String.t(),
          required(:arity) => non_neg_integer(),
          required(:callback) => function() | map(),
          optional(:kind) => :frame | :incremental_window
        }

  @type t :: %__MODULE__{
          tables: %{String.t() => Table.t()},
          views: %{String.t() => view()},
          triggers: %{String.t() => trigger()},
          attached_databases: [%{seq: pos_integer(), name: String.t(), file: String.t()}],
          txn_stack: [txn_entry()],
          ctes: %{String.t() => cte()},
          pending_ctes: map(),
          changes: non_neg_integer(),
          total_changes: non_neg_integer(),
          last_insert_rowid: integer(),
          foreign_keys: boolean(),
          defer_foreign_keys: boolean(),
          recursive_triggers: boolean(),
          ignore_check_constraints: boolean(),
          count_changes: boolean(),
          read_uncommitted: boolean(),
          case_sensitive_like: boolean(),
          short_column_names: boolean(),
          full_column_names: boolean(),
          reverse_unordered_selects: boolean(),
          query_only: boolean(),
          schema_version: integer(),
          user_version: integer(),
          application_id: integer(),
          schema_headers: %{String.t() => map()},
          page_size: pos_integer(),
          page_size_locked: boolean(),
          cache_size: integer(),
          max_page_count: pos_integer(),
          cache_spill: non_neg_integer(),
          default_cache_size: non_neg_integer(),
          auto_vacuum: 0..2,
          journal_mode: String.t(),
          journal_size_limit: integer(),
          locking_mode: String.t(),
          synchronous: 0..3,
          temp_store: 0..2,
          automatic_index: boolean(),
          busy_timeout: non_neg_integer(),
          soft_heap_limit: non_neg_integer(),
          threads: 0..8,
          secure_delete: 0..2,
          analysis_limit: non_neg_integer(),
          wal_autocheckpoint: non_neg_integer(),
          cell_size_check: boolean(),
          checkpoint_fullfsync: boolean(),
          fullfsync: boolean(),
          trusted_schema: boolean(),
          empty_result_callbacks: boolean(),
          active_triggers: [String.t()],
          sqlite_sequence_orphans: map(),
          scalar_functions: %{{String.t(), non_neg_integer()} => scalar_function()},
          aggregate_functions: %{{String.t(), non_neg_integer()} => aggregate_function()},
          collations: %{String.t() => collation()}
        }

  # ---------------------------------------------------------------------------
  # Construction
  # ---------------------------------------------------------------------------

  @doc "Returns an empty database."
  @spec new() :: t()
  def new, do: %__MODULE__{}

  # ---------------------------------------------------------------------------
  # User-defined functions and collations
  # ---------------------------------------------------------------------------

  @doc """
  Registers or replaces a connection-local scalar function.

  The callback arity must match the SQL arity exactly. The callback receives
  evaluated SQL values as positional arguments and may return a SQL value,
  `{:ok, value}`, or `{:error, message}`.

  Returns `{:error, message}` when the name is blank, the arity is outside
  `0..127`, or the callback is not a function of the matching arity.
  """
  @spec create_scalar_function(t(), String.t(), non_neg_integer(), function()) ::
          {:ok, t()} | {:error, String.t()}
  def create_scalar_function(db, name, arity, callback) do
    with :ok <- validate_name(name, "function"),
         :ok <- validate_arity(arity, "function"),
         :ok <-
           validate_callback(
             callback,
             arity,
             "function",
             "function callback arity must match SQL arity"
           ) do
      function = %{name: String.downcase(name), arity: arity, callback: callback}
      key = function_key(name, arity)
      {:ok, %{db | scalar_functions: Map.put(db.scalar_functions, key, function)}}
    end
  end

  @doc "Fetches a registered scalar function by name and exact arity."
  @spec fetch_scalar_function(t(), String.t(), non_neg_integer()) ::
          {:ok, scalar_function()} | :error
  def fetch_scalar_function(db, name, arity) do
    Map.fetch(db.scalar_functions, function_key(name, arity))
  end

  @doc "Returns true if any arity is registered for this scalar function name."
  @spec scalar_function_exists?(t(), String.t()) :: boolean()
  def scalar_function_exists?(db, name), do: registered_name?(db.scalar_functions, name)

  @doc """
  Registers or replaces a connection-local aggregate function.

  The callback receives one list argument containing evaluated, non-NULL
  argument rows. A one-argument SQL aggregate gets rows like `[[value], ...]`;
  a two-argument aggregate gets rows like `[[a, b], ...]`.
  """
  @spec create_aggregate_function(t(), String.t(), non_neg_integer(), function()) ::
          {:ok, t()} | {:error, String.t()}
  def create_aggregate_function(db, name, arity, callback) do
    create_frame_function(db, name, arity, callback, "aggregate")
  end

  @doc """
  Registers or replaces a connection-local aggregate window function.

  The callback receives one list argument containing evaluated, non-NULL
  argument rows from the current window frame.
  """
  @spec create_window_function(t(), String.t(), non_neg_integer(), function()) ::
          {:ok, t()} | {:error, String.t()}
  def create_window_function(db, name, arity, callback) do
    create_frame_function(db, name, arity, callback, "window function")
  end

  @doc """
  Registers or replaces a connection-local incremental aggregate window
  function.

  Callback map keys follow SQLite's aggregate window model: `:init` (arity 0),
  `:step` and `:inverse` (arity 2, receiving state and evaluated SQL argument
  list), and `:value` / `:final` (arity 1). All five keys are required.
  """
  @spec create_incremental_window_function(t(), String.t(), non_neg_integer(), map()) ::
          {:ok, t()} | {:error, String.t()}
  def create_incremental_window_function(db, name, arity, callbacks) do
    label = "incremental window function"

    with :ok <- validate_name(name, label),
         :ok <- validate_arity(arity, label),
         :ok <- validate_incremental_window_callbacks(callbacks) do
      function = %{
        name: String.downcase(name),
        arity: arity,
        callback: callbacks,
        kind: :incremental_window
      }

      {:ok, put_aggregate_function(db, name, arity, function)}
    end
  end

  # Shared implementation of create_aggregate_function/4 and
  # create_window_function/4; `label` only changes the error wording.
  defp create_frame_function(db, name, arity, callback, label) do
    with :ok <- validate_name(name, label),
         :ok <- validate_arity(arity, label),
         :ok <- validate_callback(callback, 1, label, "#{label} callback arity must be 1") do
      function = %{name: String.downcase(name), arity: arity, callback: callback, kind: :frame}
      {:ok, put_aggregate_function(db, name, arity, function)}
    end
  end

  # Checks each required incremental-window callback in declaration order and
  # returns the first problem found, or :ok.
  defp validate_incremental_window_callbacks(callbacks) when is_map(callbacks) do
    Enum.find_value(@incremental_window_callbacks, :ok, fn {name, arity} ->
      label = "incremental window function #{name}"

      case validate_callback(
             Map.get(callbacks, name),
             arity,
             label,
             "#{label} callback arity must be #{arity}"
           ) do
        :ok -> nil
        error -> error
      end
    end)
  end

  defp validate_incremental_window_callbacks(_callbacks),
    do: {:error, "incremental window function callbacks must be a map"}

  @doc "Fetches a registered aggregate function by name and exact SQL arity."
  @spec fetch_aggregate_function(t(), String.t(), non_neg_integer()) ::
          {:ok, aggregate_function()} | :error
  def fetch_aggregate_function(db, name, arity) do
    Map.fetch(db.aggregate_functions, function_key(name, arity))
  end

  @doc "Returns true if any arity is registered for this aggregate function name."
  @spec aggregate_function_exists?(t(), String.t()) :: boolean()
  def aggregate_function_exists?(db, name), do: registered_name?(db.aggregate_functions, name)

  @doc """
  Registers or replaces a connection-local collation.

  The callback receives two TEXT values and returns `:lt`/`:eq`/`:gt`, a
  negative/zero/positive integer, or `{:ok, result}` / `{:error, message}`.
  """
  @spec create_collation(t(), String.t(), function()) :: {:ok, t()} | {:error, String.t()}
  def create_collation(db, name, callback) do
    with :ok <- validate_name(name, "collation"),
         :ok <- validate_callback(callback, 2, "collation", "collation callback arity must be 2") do
      collation = %{name: String.downcase(name), callback: callback}
      {:ok, %{db | collations: Map.put(db.collations, Table.key(name), collation)}}
    end
  end

  @doc "Fetches a registered collation by name."
  @spec fetch_collation(t(), String.t()) :: {:ok, collation()} | :error
  def fetch_collation(db, name), do: Map.fetch(db.collations, Table.key(name))

  # Registry key shared by scalar and aggregate functions: {name key, SQL arity}.
  defp function_key(name, arity), do: {Table.key(name), arity}

  defp put_aggregate_function(db, name, arity, function) do
    %{db | aggregate_functions: Map.put(db.aggregate_functions, function_key(name, arity), function)}
  end

  # True when some {name key, arity} entry in a function registry matches `name`.
  defp registered_name?(registry, name) do
    key = Table.key(name)
    Enum.any?(registry, fn {{registered, _arity}, _function} -> registered == key end)
  end

  # :ok for a non-blank binary name, otherwise the shared "#{label} name" error.
  defp validate_name(name, label) do
    if is_binary(name) and String.trim(name) != "" do
      :ok
    else
      {:error, "#{label} name must be a non-empty string"}
    end
  end

  # :ok for an integer arity within 0..@max_function_arity.
  defp validate_arity(arity, label) do
    if is_integer(arity) and arity >= 0 and arity <= @max_function_arity do
      :ok
    else
      {:error, "#{label} arity must be between 0 and #{@max_function_arity}"}
    end
  end

  # :ok when `callback` is a function of `expected_arity`; `arity_error` is the
  # caller-specific message used when only the arity is wrong.
  defp validate_callback(callback, expected_arity, label, arity_error) do
    cond do
      not is_function(callback) -> {:error, "#{label} callback must be a function"}
      function_arity(callback) != expected_arity -> {:error, arity_error}
      true -> :ok
    end
  end

  defp function_arity(callback) do
    {:arity, arity} = Function.info(callback, :arity)
    arity
  end

  # ---------------------------------------------------------------------------
  # Views
  # ---------------------------------------------------------------------------

  @doc "Fetches a view by name (case-insensitive)."
  @spec fetch_view(t(), String.t()) :: {:ok, view()} | :error
  def fetch_view(db, name), do: fetch_view(db, nil, name)

  @doc "Fetches a view by name within `schema` (`nil` means the main schema)."
  @spec fetch_view(t(), String.t() | nil, String.t()) :: {:ok, view()} | :error
  def fetch_view(db, schema, name) do
    Map.fetch(db.views, table_storage_key(schema, name))
  end

  @doc """
  Fetches a view by unqualified name using SQLite's schema lookup order.
  """
  @spec lookup_view(t(), String.t()) :: {:ok, view()} | :error
  def lookup_view(db, name), do: lookup_schema_object(db, :views, name)

  @doc """
  Adds a new view; errors if the name collides with a table, view, or index in
  the same schema. Bumps that schema's schema version on success.
  """
  @spec create_view(t(), view()) :: {:ok, t()} | {:error, String.t()}
  def create_view(db, view) do
    key = table_storage_key(view.schema, view.name)

    cond do
      Map.has_key?(db.tables, key) ->
        {:error, "table #{view.name} already exists"}

      Map.has_key?(db.views, key) ->
        {:error, "view #{view.name} already exists"}

      index_exists?(db, view.schema, view.name) ->
        {:error, "there is already an index named #{view.name}"}

      true ->
        {:ok, schema_changed(%{db | views: Map.put(db.views, key, view)}, view.schema)}
    end
  end

  @doc "Removes a view; errors if it does not exist or if the name is a table."
  @spec drop_view(t(), String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_view(db, name), do: drop_view(db, nil, name)

  @doc "Removes a view from `schema` (`nil` means main); same errors as `drop_view/2`."
  @spec drop_view(t(), String.t() | nil, String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_view(db, schema, name) do
    key = table_storage_key(schema, name)

    cond do
      Map.has_key?(db.tables, key) ->
        {:error, "use DROP TABLE to delete table #{name}"}

      Map.has_key?(db.views, key) ->
        {:ok, schema_changed(%{db | views: Map.delete(db.views, key)}, schema)}

      true ->
        {:error, "no such view: #{name}"}
    end
  end

  # ---------------------------------------------------------------------------
  # Tables
  # ---------------------------------------------------------------------------

  @doc "Fetches a table by name (case-insensitive)."
  @spec fetch_table(t(), String.t()) :: {:ok, Table.t()} | {:error, String.t()}
  def fetch_table(db, name) do
    case Map.fetch(db.tables, Table.key(name)) do
      {:ok, table} -> {:ok, table}
      :error -> {:error, "no such table: #{name}"}
    end
  end

  @doc """
  Fetches a table by unqualified name using SQLite's schema lookup order.
  """
  @spec lookup_table(t(), String.t()) :: {:ok, Table.t()} | {:error, String.t()}
  def lookup_table(db, name) do
    case lookup_schema_object(db, :tables, name) do
      {:ok, table} -> {:ok, table}
      :error -> {:error, "no such table: #{name}"}
    end
  end

  @doc """
  Adds a new table; errors if the name collides with a table, view, or index
  in the same schema. Bumps that schema's schema version on success.
  """
  @spec create_table(t(), Table.t()) :: {:ok, t()} | {:error, String.t()}
  def create_table(db, table) do
    table_key = table_storage_key(table.schema, table.name)

    cond do
      Map.has_key?(db.tables, table_key) ->
        {:error, "table #{table.name} already exists"}

      Map.has_key?(db.views, table_key) ->
        {:error, "there is already a table named #{table.name}"}

      index_exists?(db, table.schema, table.name) ->
        {:error, "there is already an index named #{table.name}"}

      true ->
        {:ok, db |> put_table(table) |> schema_changed(table.schema)}
    end
  end

  @doc "Replaces a table's state after DML."
  @spec put_table(t(), Table.t()) :: t()
  def put_table(db, table) do
    %{db | tables: Map.put(db.tables, table_storage_key(table.schema, table.name), table)}
  end

  @doc "Removes a table; errors if it does not exist or if the name is a view."
  @spec drop_table(t(), String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_table(db, name), do: drop_table(db, nil, name)

  @doc "Removes a table from `schema` (`nil` means main); same errors as `drop_table/2`."
  @spec drop_table(t(), String.t() | nil, String.t()) :: {:ok, t()} | {:error, String.t()}
  def drop_table(db, schema, name) do
    table_key = table_storage_key(schema, name)

    cond do
      Map.has_key?(db.tables, table_key) ->
        {:ok, schema_changed(%{db | tables: Map.delete(db.tables, table_key)}, schema)}

      Map.has_key?(db.views, table_key) ->
        {:error, "use DROP VIEW to delete view #{name}"}

      true ->
        {:error, "no such table: #{name}"}
    end
  end

  @doc """
  Internal storage key for schema-qualified tables.

  Objects in the main schema (`nil`, `"main"`, or any spelling that
  `Table.key/1` maps to the same key as `"main"`) are keyed by name alone;
  other schemas use `"schema.name"`.
  """
  @spec table_storage_key(String.t() | nil, String.t()) :: String.t()
  def table_storage_key(schema, name) do
    case schema_key(schema) do
      nil -> Table.key(name)
      key -> key <> "." <> Table.key(name)
    end
  end

  # Canonical key for a schema name, or nil for the main schema. Non-literal
  # spellings are compared through Table.key/1 so e.g. "MAIN" resolves to main
  # whenever Table.key/1 folds case — the same treatment other schemas get.
  defp schema_key(schema) when schema in [nil, "main"], do: nil

  defp schema_key(schema) do
    key = Table.key(schema)
    if key == Table.key("main"), do: nil, else: key
  end

  # ---------------------------------------------------------------------------
  # Indexes
  # ---------------------------------------------------------------------------

  @doc """
  Finds which table owns the index with the given name, or nil.
  Index lookup is database-global and case-insensitive.
  """
  @spec find_index_owner(t(), String.t()) :: {Table.t(), map()} | nil
  def find_index_owner(db, index_name), do: find_index_owner(db, :any, index_name)

  @doc """
  Finds which table owns the index with the given name within a schema, or nil.

  Pass `:any` as the schema to search every schema.
  """
  @spec find_index_owner(t(), String.t() | nil | :any, String.t()) :: {Table.t(), map()} | nil
  def find_index_owner(db, schema, index_name) do
    key = Table.key(index_name)

    Enum.find_value(db.tables, fn {_table_key, table} ->
      if index_schema_matches?(table.schema, schema) do
        case Enum.find(table.indexes, &(Table.key(&1.name) == key)) do
          nil -> nil
          index -> {table, index}
        end
      end
    end)
  end

  @doc "Returns true if any table has an index with this name."
  @spec index_exists?(t(), String.t()) :: boolean()
  def index_exists?(db, index_name), do: find_index_owner(db, index_name) != nil

  @doc "Returns true if a table in the schema has an index with this name."
  @spec index_exists?(t(), String.t() | nil, String.t()) :: boolean()
  def index_exists?(db, schema, index_name), do: find_index_owner(db, schema, index_name) != nil

  defp index_schema_matches?(_table_schema, :any), do: true
  defp index_schema_matches?(table_schema, schema), do: schema_key(table_schema) == schema_key(schema)

  # Returns the first object named `name` in `field` (:tables or :views),
  # searching schemas in schema_lookup_order/1; :error if none matches.
  defp lookup_schema_object(db, field, name) do
    objects = Map.fetch!(db, field)

    Enum.find_value(schema_lookup_order(db), :error, fn schema ->
      case Map.fetch(objects, table_storage_key(schema, name)) do
        {:ok, object} -> {:ok, object}
        :error -> nil
      end
    end)
  end

  # temp first, then main (nil), then attached databases in list order.
  defp schema_lookup_order(db), do: ["temp", nil] ++ Enum.map(db.attached_databases, & &1.name)

  # ---------------------------------------------------------------------------
  # Snapshots and schema headers
  # ---------------------------------------------------------------------------

  @doc "Returns a snapshot of mutable schema state (tables, views, triggers) for transactions."
  @spec schema_snapshot(t()) :: %{
          tables: map(),
          views: map(),
          triggers: map(),
          sqlite_sequence_orphans: map(),
          schema_version: integer(),
          user_version: integer(),
          application_id: integer(),
          schema_headers: map()
        }
  def schema_snapshot(db) do
    Map.take(db, [
      :tables,
      :views,
      :triggers,
      :sqlite_sequence_orphans,
      :schema_version,
      :user_version,
      :application_id,
      :schema_headers
    ])
  end

  @doc """
  Restores schema from a snapshot.

  `:tables` and `:views` are required; any other snapshot key that is missing
  leaves the current value in place.
  """
  @spec restore_schema(t(), map()) :: t()
  def restore_schema(db, snapshot) do
    %{
      db
      | tables: snapshot.tables,
        views: snapshot.views,
        triggers: Map.get(snapshot, :triggers, db.triggers),
        sqlite_sequence_orphans:
          Map.get(snapshot, :sqlite_sequence_orphans, db.sqlite_sequence_orphans),
        schema_version: Map.get(snapshot, :schema_version, db.schema_version),
        user_version: Map.get(snapshot, :user_version, db.user_version),
        application_id: Map.get(snapshot, :application_id, db.application_id),
        schema_headers: Map.get(snapshot, :schema_headers, db.schema_headers)
    }
  end

  @doc "Bumps the schema cookie using SQLite's signed 32-bit wraparound."
  @spec schema_changed(t()) :: t()
  def schema_changed(db), do: schema_changed(db, nil)

  @doc """
  Bumps the schema cookie of `schema` (`nil` means main) and marks the page
  size as locked.
  """
  @spec schema_changed(t(), String.t() | nil) :: t()
  def schema_changed(db, schema) do
    version = schema_header_value(db, schema, :schema_version)
    bumped = put_schema_header_value(db, schema, :schema_version, next_schema_version(version))
    %{bumped | page_size_locked: true}
  end

  @doc "Returns a schema header value for main or an attached schema."
  @spec schema_header_value(
          t(),
          String.t() | nil,
          :schema_version | :user_version | :application_id
        ) ::
          integer()
  def schema_header_value(db, schema, field) do
    db
    |> schema_header(schema)
    |> Map.fetch!(field)
  end

  @doc "Stores a schema header value for main or an attached schema."
  @spec put_schema_header_value(
          t(),
          String.t() | nil,
          :schema_version | :user_version | :application_id,
          integer()
        ) :: t()
  def put_schema_header_value(db, schema, field, value) do
    case schema_key(schema) do
      nil ->
        # Main's header lives in struct fields; replace! rejects unknown fields
        # instead of silently adding a stray key to the struct.
        Map.replace!(db, field, value)

      key ->
        header = Map.get(db.schema_headers, key, @empty_schema_header)
        %{db | schema_headers: Map.put(db.schema_headers, key, Map.put(header, field, value))}
    end
  end

  @doc "Removes attached schema header state."
  @spec drop_schema_header(t(), String.t()) :: t()
  def drop_schema_header(db, schema) do
    %{db | schema_headers: Map.delete(db.schema_headers, Table.key(schema))}
  end

  # Header map for `schema`: main reads the struct fields, attached schemas read
  # schema_headers (defaulting to all zeros).
  defp schema_header(db, schema) do
    case schema_key(schema) do
      nil -> Map.take(db, [:schema_version, :user_version, :application_id])
      key -> Map.get(db.schema_headers, key, @empty_schema_header)
    end
  end

  defp next_schema_version(2_147_483_647), do: -2_147_483_648
  defp next_schema_version(value), do: value + 1

  # ---------------------------------------------------------------------------
  # Change counters
  # ---------------------------------------------------------------------------

  @doc "Records the row-count state exposed by changes() and total_changes()."
  @spec record_changes(t(), non_neg_integer()) :: t()
  def record_changes(db, count), do: record_changes(db, count, nil)

  @doc """
  Like `record_changes/2`, and also updates `last_insert_rowid` unless
  `last_insert_rowid` is nil.
  """
  @spec record_changes(t(), non_neg_integer(), integer() | nil) :: t()
  def record_changes(db, count, last_insert_rowid) do
    counted = %{db | changes: count, total_changes: db.total_changes + count}

    case last_insert_rowid do
      nil -> counted
      rowid -> %{counted | last_insert_rowid: rowid}
    end
  end
end