Skip to main content

lib/ex_sql/ecto/connection.ex

defmodule ExSQL.Ecto.Connection do
  @moduledoc false

  # DBConnection implementation for the ExSQL Ecto adapter.
  #
  # The connection's `%ExSQL.Database{}` lives in `:persistent_term` under a
  # per-connection key (see `db_ref/1`); the struct below holds that key plus
  # bookkeeping. Durability is selected by the `:persist` option:
  #
  #   * `:file` (default) — after a mutating statement outside a transaction,
  #     or after COMMIT, the database is written with `FileFormat.write/3` and
  #     the `Registry` version for the path is bumped. Idle, clean connections
  #     compare that version before running a statement or BEGIN and re-read
  #     the file only when another writer moved it.
  #   * `:log` — every statement that is not a pure read is appended to
  #     `ExSQL.Log`; the log is checkpointed into the base file every
  #     `checkpoint_after` appends, but only while no transaction is open.
  #     Only used for real file paths (see `persist_mode/2`).

  use DBConnection

  alias ExSQL.{Database, Executor, FileFormat, Log, Registry}
  alias ExSQL.Ecto.{Error, Query, Result}

  @default_checkpoint_after 256
  # Rows handed back per `handle_fetch/4` call when the caller sets no `:max_rows`.
  @default_max_rows 500

  # Field notes:
  #   * `dirty?`         — in-memory state may differ from what was last
  #                        persisted; while set, `maybe_reload/1` is skipped.
  #   * `synced_version` — Registry version the in-memory db reflects (:file).
  #   * `log_count`      — log appends since the last checkpoint (:log).
  #   * `cursors`        — next row offset per open cursor ref. DBConnection
  #                        hands the same cursor term to every fetch, so the
  #                        read position has to be tracked in the state.
  defstruct db: nil,
            database: nil,
            dirty?: false,
            transaction_status: :idle,
            status: :idle,
            journal_mode: :memory,
            synced_version: 0,
            persist: :file,
            log_count: 0,
            checkpoint_after: @default_checkpoint_after,
            cursors: %{}

  @type t :: %__MODULE__{
          db: {module(), reference()} | nil,
          database: String.t() | :memory | nil,
          dirty?: boolean(),
          transaction_status: :idle | :transaction,
          status: :idle | :busy,
          journal_mode: atom() | String.t(),
          synced_version: non_neg_integer() | :no_registry,
          persist: :file | :log,
          log_count: non_neg_integer(),
          checkpoint_after: pos_integer(),
          cursors: %{optional(reference()) => non_neg_integer()}
        }

  # Opens the database named by `opts[:database]` (default `:memory`) in
  # either :log or :file persistence mode and builds the connection state.
  # Errors from `Log.open/2` or `open_database/1` are returned unchanged.
  @impl true
  def connect(opts) do
    database = Keyword.get(opts, :database, :memory)
    journal_mode = Keyword.get(opts, :journal_mode, :memory)

    case persist_mode(Keyword.get(opts, :persist, :file), database) do
      :log ->
        with {:ok, db} <- Log.open(database, sync: Keyword.get(opts, :log_sync, false)) do
          {:ok,
           %__MODULE__{
             db: db_ref(db),
             database: database,
             journal_mode: journal_mode,
             persist: :log,
             checkpoint_after: Keyword.get(opts, :checkpoint_after, @default_checkpoint_after)
           }}
        end

      :file ->
        # Capture the version *before* reading the file: if a write lands in
        # between, we record the older version and reload on our next read
        # rather than trusting a snapshot that is already one commit stale.
        synced_version = Registry.current_version(database)

        with {:ok, db} <- open_database(database) do
          {:ok,
           %__MODULE__{
             db: db_ref(db),
             database: database,
             journal_mode: journal_mode,
             synced_version: synced_version
           }}
        end
    end
  end

  # The redo log needs a real file path; in-memory databases stay on :file.
  defp persist_mode(:log, database) when is_binary(database) and database != ":memory:", do: :log
  defp persist_mode(_other, _database), do: :file

  # Flushes (:log) or persists (:file) outstanding state, then releases the
  # `:persistent_term` entry holding the database.
  @impl true
  def disconnect(_err, %__MODULE__{persist: :log} = state) do
    _ = Log.flush(state.database)
    drop_db(state)
    :ok
  end

  def disconnect(_err, state) do
    # `persist/1` refuses to write while a transaction is open, so a client
    # that dies mid-transaction does not leave uncommitted rows on disk.
    _ = persist(state)
    drop_db(state)
    :ok
  end

  # Nothing in this module resets `status` to :idle, so only the first
  # checkout of a connection succeeds and any later one disconnects it.
  # NOTE(review): relies on DBConnection calling checkout/1 once per
  # connection — confirm against the DBConnection version in use.
  @impl true
  def checkout(%__MODULE__{status: :idle} = state) do
    {:ok, %{state | status: :busy}}
  end

  def checkout(%__MODULE__{status: :busy} = state) do
    {:disconnect, %Error{message: "Database is busy"}, state}
  end

  @impl true
  def ping(state), do: {:ok, state}

  # Queries are interpreted from their SQL text at execution time, so there
  # is nothing to prepare.
  @impl true
  def handle_prepare(%Query{} = query, _opts, state) do
    {:ok, query, state}
  end

  @impl true
  def handle_execute(%Query{} = query, params, _opts, state) do
    execute_query(query, params, state)
  end

  # Starts a transaction (mode :immediate / :exclusive / default deferred),
  # or a savepoint when one is already open.
  @impl true
  def handle_begin(opts, state) do
    state = maybe_reload(state)
    mode = Keyword.get(opts, :mode, :deferred)

    sql =
      case state.transaction_status do
        :idle -> begin_sql(mode)
        _nested -> "SAVEPOINT exsql_savepoint"
      end

    run_transaction_sql(sql, state, :transaction)
  end

  # Commits the transaction, or releases the savepoint when `mode: :savepoint`.
  # A COMMIT that returns the connection to :idle is persisted/logged via
  # `run_transaction_sql/3`.
  @impl true
  def handle_commit(opts, state) do
    mode = Keyword.get(opts, :mode, :deferred)

    sql =
      if mode == :savepoint do
        "RELEASE SAVEPOINT exsql_savepoint"
      else
        "COMMIT"
      end

    run_transaction_sql(sql, state, status_after_end(state, mode))
  end

  # Rolls back the transaction, or to the savepoint when `mode: :savepoint`.
  # A full rollback leaves the in-memory db at its last committed state, so
  # the connection is marked clean afterwards.
  @impl true
  def handle_rollback(opts, state) do
    mode = Keyword.get(opts, :mode, :deferred)

    sql =
      if mode == :savepoint do
        "ROLLBACK TO SAVEPOINT exsql_savepoint; RELEASE SAVEPOINT exsql_savepoint"
      else
        "ROLLBACK"
      end

    case run_transaction_sql(sql, state, status_after_end(state, mode)) do
      {:ok, result, state} when mode != :savepoint -> {:ok, result, %{state | dirty?: false}}
      other -> other
    end
  end

  @impl true
  def handle_status(_opts, state), do: {state.transaction_status, state}

  @impl true
  def handle_close(_query, _opts, state), do: {:ok, nil, state}

  # Runs the query eagerly (same bookkeeping as `handle_execute/4`) and wraps
  # the full result in a cursor; `handle_fetch/4` then pages through it.
  @impl true
  def handle_declare(%Query{} = query, params, _opts, state) do
    with {:ok, result, state} <- run_statement(query, params, state) do
      cursor = %{ref: make_ref(), result: Result.from_exsql(result), offset: 0}
      {:ok, query, cursor, state}
    end
  end

  # Returns the next `max_rows` rows of a declared cursor. DBConnection passes
  # the cursor unchanged on every call, so the read position is kept in
  # `state.cursors` (falling back to the cursor's initial offset) and dropped
  # once the last chunk is handed out.
  @impl true
  def handle_fetch(_query, %{ref: ref, result: result, offset: start}, opts, state) do
    max_rows = Keyword.get(opts, :max_rows, @default_max_rows)
    offset = Map.get(state.cursors, ref, start)
    rows = Enum.slice(result.rows, offset, max_rows)
    count = length(rows)
    chunk = %{result | rows: rows, num_rows: count}

    if offset + count >= length(result.rows) do
      {:halt, chunk, forget_cursor(state, ref)}
    else
      {:cont, chunk, %{state | cursors: Map.put(state.cursors, ref, offset + count)}}
    end
  end

  @impl true
  def handle_deallocate(_query, %{ref: ref}, _opts, state) do
    {:ok, nil, forget_cursor(state, ref)}
  end

  def handle_deallocate(_query, _cursor, _opts, state), do: {:ok, nil, state}

  defp forget_cursor(state, ref), do: %{state | cursors: Map.delete(state.cursors, ref)}

  defp execute_query(%Query{} = query, params, state) do
    with {:ok, result, state} <- run_statement(query, params, state) do
      {:ok, query, Result.from_exsql(result), state}
    end
  end

  # Shared by execute and declare: reloads if stale, runs exactly one
  # statement, stores the new database, updates `dirty?`, and persists or
  # logs the change. Returns `{:ok, raw_exsql_result, state}` or the
  # `{:error, %Error{}, state}` tuple DBConnection expects.
  defp run_statement(%Query{} = query, params, state) do
    state = maybe_reload(state)
    statement = IO.iodata_to_binary(query.statement)
    mutating_statement = mutating_statement?(statement)

    case Executor.run(db(state), statement, params) do
      {:ok, [result], new_db} ->
        command = result.command
        dirty? = state.dirty? or mutating?(command, mutating_statement)

        state =
          state
          |> put_db(new_db, db_changed?(command, mutating_statement))
          |> struct!(dirty?: dirty?)
          |> persist_after(command, mutating_statement, statement, params)

        {:ok, result, state}

      {:ok, results, new_db} ->
        # NOTE(review): the statements have already been applied to `new_db`,
        # but nothing is appended to the log here, so in :log mode these
        # changes are presumably lost on replay. Confirm this is intended.
        error = %Error{
          message: "expected one statement, got #{length(results)}",
          statement: statement
        }

        {:error, error, mark_failed(state, new_db)}

      {:error, error, new_db} ->
        {:error, Error.from_exsql(error, statement), mark_failed(state, new_db)}
    end
  end

  # Runs transaction-control SQL, moves the connection to `status`, and
  # persists/logs the outcome. For multi-statement SQL (savepoint rollback)
  # the result of the last statement is returned.
  defp run_transaction_sql(sql, state, status) do
    mutating_statement = mutating_statement?(sql)

    case Executor.run(db(state), sql, []) do
      {:ok, results, new_db} ->
        last = List.last(results)

        state =
          state
          |> put_db(new_db)
          |> struct!(transaction_status: status, dirty?: true)
          |> persist_after(last.command, mutating_statement, sql, [])

        {:ok, Result.from_exsql(last), state}

      {:error, error, new_db} ->
        {:error, Error.from_exsql(error, sql), mark_failed(state, new_db)}
    end
  end

  # After a failed statement, keep whatever database the engine handed back
  # and mark the connection dirty so `maybe_reload/1` skips it and a later
  # persist may write it.
  defp mark_failed(state, new_db) do
    state = put_db(state, new_db)
    %{state | dirty?: true}
  end

  # :idle when this ends the outermost transaction, :transaction otherwise.
  defp status_after_end(state, mode) do
    if outer_transaction_end?(state, mode), do: :idle, else: :transaction
  end

  defp open_database(database) when database in [:memory, ":memory:"] do
    {:ok, Database.new()}
  end

  defp open_database(path) when is_binary(path) do
    if File.exists?(path) do
      case FileFormat.read(path) do
        {:ok, db} -> {:ok, db}
        {:error, message} -> {:error, %Error{message: message}}
      end
    else
      path |> Path.dirname() |> File.mkdir_p!()
      {:ok, Database.new()}
    end
  end

  defp open_database(_database), do: {:ok, Database.new()}

  # In :log mode the in-memory db is authoritative and the base file is only
  # checkpoint-current, so reloading from it would lose un-checkpointed commits.
  defp maybe_reload(%__MODULE__{persist: :log} = state), do: state

  defp maybe_reload(
         %__MODULE__{database: database, transaction_status: :idle, dirty?: false} = state
       )
       when is_binary(database) and database != ":memory:" do
    current = Registry.current_version(database)

    cond do
      current != :no_registry and current == state.synced_version ->
        # Our in-memory database already reflects the latest committed write
        # (no other connection has committed since), so skip re-parsing the
        # whole file. This is the common single-writer hot path.
        state

      File.exists?(database) ->
        case FileFormat.read(database) do
          {:ok, db} -> state |> put_db(db) |> struct!(synced_version: current)
          {:error, _message} -> state
        end

      true ->
        state
    end
  end

  defp maybe_reload(state), do: state

  # Writes a dirty file-backed database and bumps its Registry version.
  # Only between transactions: mid-transaction the in-memory db holds
  # uncommitted rows that must never reach disk.
  defp persist(
         %__MODULE__{database: database, dirty?: true, transaction_status: :idle} = state
       )
       when is_binary(database) and database != ":memory:" do
    case FileFormat.write(db(state), database, journal_mode: state.journal_mode) do
      {:ok, _path} -> {:ok, Registry.bump(database)}
      {:error, message} -> {:error, %Error{message: message}}
    end
  end

  defp persist(_state), do: :ok

  defp persist_after(%__MODULE__{persist: :log} = state, command, _mutating, sql, params) do
    maybe_log(state, command, sql, params)
  end

  defp persist_after(state, command, mutating_statement, _sql, _params) do
    maybe_persist_after(state, command, mutating_statement)
  end

  # Pure reads have no durable effect, so they are not logged. Everything else
  # (DML, DDL, transaction control, pragmas) is appended verbatim and replayed
  # in order on open; the engine reproduces transaction semantics from the
  # statement stream. See `ExSQL.Log`.
  defp maybe_log(state, command, _sql, _params) when command in [:select, :explain], do: state

  defp maybe_log(state, _command, sql, params) do
    :ok = Log.append(state.database, [{sql, params}])
    maybe_checkpoint(%{state | dirty?: false, log_count: state.log_count + 1})
  end

  # Fold the log into the base only between transactions (txn_stack empty) —
  # never mid-transaction, which would persist uncommitted rows or strand
  # committed ones.
  defp maybe_checkpoint(%__MODULE__{log_count: count, checkpoint_after: threshold} = state)
       when count >= threshold do
    case db(state) do
      %Database{txn_stack: []} ->
        :ok = Log.checkpoint(state.database)
        %{state | log_count: 0}

      %Database{} ->
        state
    end
  end

  defp maybe_checkpoint(state), do: state

  # :file mode: write after a mutating statement, but only outside a
  # transaction. A failed write leaves the state dirty so a later persist
  # (or disconnect) retries.
  defp maybe_persist_after(state, command, mutating_statement) do
    if state.transaction_status == :idle and mutating?(command, mutating_statement) do
      case persist(state) do
        {:ok, version} -> %{state | dirty?: false, synced_version: version}
        :ok -> %{state | dirty?: false}
        {:error, _error} -> state
      end
    else
      state
    end
  end

  defp mutating?(_command, true), do: true

  defp mutating?(command, false) do
    command not in [
      nil,
      :select,
      :pragma,
      :explain,
      :begin,
      :savepoint,
      :rollback
    ]
  end

  defp db_changed?(command, mutating_statement) do
    mutating_statement or command not in [:select, :explain]
  end

  # True when the SQL starts with INSERT/UPDATE/DELETE/REPLACE, or is a WITH
  # clause that contains one (e.g. a CTE feeding an INSERT). Matched
  # case-insensitively so the whole statement need not be upcased.
  defp mutating_statement?(statement) when is_binary(statement) do
    Regex.match?(
      ~r/\A(?:WITH\b[\s\S]*?\b(INSERT|UPDATE|DELETE|REPLACE)\b|\b(INSERT|UPDATE|DELETE|REPLACE)\b)/i,
      String.trim_leading(statement)
    )
  end

  defp mutating_statement?(_statement), do: false

  defp begin_sql(:immediate), do: "BEGIN IMMEDIATE TRANSACTION"
  defp begin_sql(:exclusive), do: "BEGIN EXCLUSIVE TRANSACTION"
  defp begin_sql(_mode), do: "BEGIN TRANSACTION"

  defp outer_transaction_end?(state, mode) do
    state.transaction_status == :transaction and mode != :savepoint
  end

  # Stores the database under a fresh per-connection `:persistent_term` key.
  # NOTE(review): updating or erasing a persistent term triggers a global GC
  # scan of all processes; with a put after every mutating statement this
  # may be costly under load — consider ETS or keeping the db in the state.
  defp db_ref(%Database{} = db) do
    key = {__MODULE__, make_ref()}
    :persistent_term.put(key, db)
    key
  end

  defp db(%__MODULE__{db: key}) do
    :persistent_term.get(key)
  end

  # put_db/3 skips the (costly) `:persistent_term` write when the statement
  # could not have changed the database.
  defp put_db(state, _db, false), do: state
  defp put_db(state, %Database{} = db, true), do: put_db(state, db)

  defp put_db(%__MODULE__{db: key} = state, %Database{} = db) do
    :persistent_term.put(key, db)
    state
  end

  defp drop_db(%__MODULE__{db: nil}), do: :ok

  defp drop_db(%__MODULE__{db: key}) do
    :persistent_term.erase(key)
    :ok
  end
end