lib/dagex/repo.ex

defmodule Dagex.Repo do
  @moduledoc """
  Adds Dagex-specific functionality to your application's `Ecto.Repo` module.

  ```elixir
  defmodule MyApp.Repo do
    use Ecto.Repo, otp_app: :my_app, adapter: Ecto.Adapters.Postgres
    use Dagex.Repo
  end
  ```

  `use Dagex.Repo` declares `Dagex.Repo` as a behaviour of the using module and
  defines `c:dagex_update/1` and `c:dagex_paths/1` on it. Both generated
  functions delegate to the two-arity functions of this module, passing the
  using module as the repo.
  """
  alias Dagex.Operations.{CreateEdge, RemoveEdge}

  # SQL statements invoking the database functions that mutate the graph. Both
  # take the same three positional parameters: node type, parent id, child id.
  @create_edge_sql "SELECT dagex_create_edge($1, $2, $3)"
  @remove_edge_sql "SELECT dagex_remove_edge($1, $2, $3)"

  # Runs an edge operation against `repo` and hands the normalised outcome to
  # the operation module's own `process_result/2`.
  #
  # An `{:error, reason}` tuple given in place of an operation is returned
  # untouched and the database is never queried, so callers can pipe the result
  # of building an operation straight into this function.
  @doc false
  @spec dagex_update(Ecto.Repo.t(), CreateEdge.t() | RemoveEdge.t() | {:error, term()}) ::
          CreateEdge.result() | RemoveEdge.result() | {:error, term()}
  def dagex_update(_repo, {:error, _reason} = error), do: error

  def dagex_update(repo, %CreateEdge{} = op) do
    repo
    |> run_edge_statement(@create_edge_sql, op)
    |> CreateEdge.process_result(op)
  end

  def dagex_update(repo, %RemoveEdge{} = op) do
    repo
    |> run_edge_statement(@remove_edge_sql, op)
    |> RemoveEdge.process_result(op)
  end

  @doc """
  Executes a Dagex repo operation such as `Dagex.Operations.CreateEdge` and processes the result.
  """
  @callback dagex_update(operation :: CreateEdge.t() | RemoveEdge.t()) ::
              CreateEdge.result() | RemoveEdge.result()

  # Fetches every row of `queryable` through `repo` and regroups the rows into
  # paths. Each row is read through its `path`, `position` and `node` fields:
  # rows sharing a `path` belong to the same path, `position` orders the rows
  # within it, and `node` is the value kept in the result.
  #
  # The grouping goes through a map, so the order of the paths in the outer list
  # is whatever map iteration yields and should not be relied upon.
  @doc false
  @spec dagex_paths(repo :: Ecto.Repo.t(), queryable :: Ecto.Queryable.t()) ::
          list(list(Ecto.Schema.t()))
  def dagex_paths(repo, queryable) do
    queryable
    |> repo.all()
    |> Enum.group_by(& &1.path)
    |> Enum.map(fn {_path, rows} ->
      rows
      |> Enum.sort_by(& &1.position)
      |> Enum.map(& &1.node)
    end)
  end

  @doc """
  Executes the query generated by `c:Dagex.all_paths/2` and processes the result
  into a list of paths where each path is a list of the nodes between (and
  including) the `ancestor` and the `descendant` nodes.
  """
  @callback dagex_paths(all_paths_query :: Ecto.Queryable.t()) :: list(list(Ecto.Schema.t()))

  @spec __using__(any()) :: Macro.t()
  defmacro __using__(_opts) do
    quote do
      @behaviour Dagex.Repo

      @impl Dagex.Repo
      @spec dagex_update(CreateEdge.t() | RemoveEdge.t()) ::
              CreateEdge.result() | RemoveEdge.result()
      def dagex_update(operation), do: Dagex.Repo.dagex_update(__MODULE__, operation)

      @impl Dagex.Repo
      @spec dagex_paths(Ecto.Queryable.t()) :: list(list(Ecto.Schema.t()))
      def dagex_paths(queryable), do: Dagex.Repo.dagex_paths(__MODULE__, queryable)
    end
  end

  # Executes `sql` with the operation's node type, parent id and child id as
  # parameters and reduces the outcome to `:ok` or `{:error, constraint_name}`.
  #
  # Only errors that carry a constraint name are translated. Any other error
  # deliberately matches no clause and crashes with a `CaseClauseError`, rather
  # than being reported as a constraint violation.
  @spec run_edge_statement(Ecto.Repo.t(), String.t(), CreateEdge.t() | RemoveEdge.t()) ::
          :ok | {:error, String.t()}
  defp run_edge_statement(repo, sql, op) do
    case repo.query(sql, [op.node_type, op.parent_id, op.child_id]) do
      {:ok, _result} ->
        :ok

      {:error, %Postgrex.Error{postgres: %{constraint: constraint_name}}}
      when is_binary(constraint_name) ->
        {:error, constraint_name}
    end
  end
end