# lib/graphql/dataloader.ex

defmodule AshGraphql.Dataloader do
  @moduledoc "The dataloader source in charge of resolving Ash relationships and calculations."
  # State carried by the source between `load`, `run` and `fetch` calls.
  #
  #   * `api`            - the Ash API module used to load records.
  #   * `batches`        - pending work: batch key => `MapSet` of `{item_key, item}`.
  #   * `results`        - finished work: batch key => `{:ok, %{item_key => value}}`
  #                        or `{:error, reason}`.
  #   * `default_params` - params merged underneath the params of every batch key.
  defstruct [
    :api,
    batches: %{},
    results: %{},
    default_params: %{}
  ]

  @type t :: %__MODULE__{
          api: Ash.Api.t(),
          batches: map,
          results: map,
          default_params: map
        }

  # Options forwarded to Ash API calls.
  @type api_opts :: Keyword.t()
  # NOTE(review): `batch_fun` is not referenced in the visible part of this file;
  # confirm it is still needed before relying on it.
  @type batch_fun :: (Ash.Resource.t(), Ash.Query.t(), any, [any], api_opts -> [any])

  import AshGraphql.TraceHelpers

  @doc """
  Builds an Ash-backed Dataloader source for the given API.

  The returned struct starts with no pending batches, no stored results and
  empty default params; `api` is the Ash API module used to load data.
  """
  @spec new(Ash.Api.t()) :: t
  def new(api), do: struct!(__MODULE__, api: api)

  defimpl Dataloader.Source do
    # Executes every pending batch and folds the outcomes into `source.results`.
    #
    # Each batch outcome is `{:ok, items}` or `{:error, reason}`. When a batch key
    # already has a stored outcome, two successes are merged item-by-item; if
    # either side is an error, the newest outcome wins. A batch that failed
    # earlier is not considered fetched by `load/3`, so it can be queued and run
    # again; without the catch-all clause that retry raised a
    # `FunctionClauseError` inside the merge callback.
    #
    # Returns the source with the merged results and an empty batch queue.
    def run(source) do
      fresh = Dataloader.async_safely(__MODULE__, :run_batches, [source])

      results =
        Map.merge(source.results, fresh, fn
          _batch_key, {:ok, existing}, {:ok, incoming} ->
            {:ok, Map.merge(existing, incoming)}

          # At least one side is `{:error, _}`: keep the most recent outcome.
          _batch_key, _existing, incoming ->
            incoming
        end)

      %{source | results: results, batches: %{}}
    end

    # Returns the stored outcome for `item` within the batch identified by
    # `batch_key`: `{:ok, value}` when present, otherwise an `{:error, _}` tuple
    # (missing batch, missing item, or the batch's own failure).
    def fetch(source, batch_key, item) do
      normalized = normalize_key(batch_key, source.default_params)
      {results_key, item_key, _record} = get_keys(normalized, item)

      with {:ok, batch} <- Map.fetch(source.results, results_key) do
        fetch_item_from_batch(batch, item_key)
      else
        :error -> {:error, "Unable to find batch #{inspect(results_key)}"}
      end
    end

    # Pulls a single item out of a batch outcome. A successful batch yields
    # `{:ok, value}` or a "not found" error; a failed batch is returned as-is.
    defp fetch_item_from_batch({:ok, items}, item_key) do
      with :error <- Map.fetch(items, item_key) do
        {:error, "Unable to find item #{inspect(item_key)} in batch"}
      end
    end

    defp fetch_item_from_batch({:error, _reason} = failure, _item_key), do: failure

    # Seeds the source with an already-known `result` for `item` under `batch`.
    #
    # A relationship that is still `%Ash.NotLoaded{}` carries no data, so it is
    # ignored and the source is returned unchanged.
    def put(source, _batch, _item, %Ash.NotLoaded{type: :relationship}) do
      source
    end

    # Stores `result` under the normalized batch key. If that batch previously
    # failed, its stored outcome is `{:error, reason}`; the explicit value then
    # replaces the failure instead of raising a `FunctionClauseError` in the
    # update callback.
    def put(source, batch, item, result) do
      {batch_key, item_key, _item} =
        batch
        |> normalize_key(source.default_params)
        |> get_keys(item)

      results =
        Map.update(source.results, batch_key, {:ok, %{item_key => result}}, fn
          {:ok, items} -> {:ok, Map.put(items, item_key, result)}
          {:error, _reason} -> {:ok, %{item_key => result}}
        end)

      %{source | results: results}
    end

    # Queues `item` for loading under `batch`, unless a successful result for it
    # is already stored. Returns the (possibly updated) source.
    def load(source, batch, item) do
      normalized = normalize_key(batch, source.default_params)
      {batch_key, item_key, record} = get_keys(normalized, item)

      if fetched?(source.results, batch_key, item_key) do
        source
      else
        queued =
          Map.update(
            source.batches,
            batch_key,
            MapSet.new([{item_key, record}]),
            &MapSet.put(&1, {item_key, record})
          )

        %{source | batches: queued}
      end
    end

    # True only when `results` holds a successful batch under `batch_key` that
    # already contains `item_key`.
    defp fetched?(results, batch_key, item_key) do
      match?(%{^batch_key => {:ok, %{^item_key => _}}}, results)
    end

    # Reports whether any batch is still waiting to be run.
    def pending_batches?(%{batches: batches}) when batches == %{}, do: false
    def pending_batches?(%{batches: _batches}), do: true

    # The source has no timeout of its own; it defers to Dataloader's default.
    def timeout(_source), do: Dataloader.default_timeout()

    # Resolves the resource reached by following `path` from `resource`,
    # raising when the lookup comes back empty.
    defp related(path, resource) do
      case Ash.Resource.Info.related(resource, path) do
        missing when missing in [nil, false] ->
          raise """
          Valid relationship for path #{inspect(path)} not found on resource #{inspect(resource)}
          """

        destination ->
          destination
      end
    end

    # Turns a normalized batch key plus a record into
    # `{batch_key, item_key, record}`.
    #
    # The batch key embeds `self()`, so work queued from different processes
    # lands in different batches. The item key is the list of the record's
    # primary key values.
    defp get_keys({assoc_field, %{type: :relationship} = opts}, %resource{} = record)
         when is_atom(assoc_field) do
      item_key = record_id(resource, record)
      queryable = related([assoc_field], resource)

      {{:assoc, resource, self(), assoc_field, queryable, opts}, item_key, record}
    end

    defp get_keys({calc, %{type: :calculation} = opts}, %resource{} = record) do
      item_key = record_id(resource, record)

      {{:calc, resource, self(), calc, opts}, item_key, record}
    end

    # Anything else is a caller error; include both parts in the message.
    defp get_keys(key, item) do
      raise """
      Invalid batch key: #{inspect(key)}
      #{inspect(item)}
      """
    end

    # Checks that `resource` is an Ash resource, then reads the record's primary
    # key values in primary-key order.
    defp record_id(resource, record) do
      validate_resource(resource)

      resource
      |> Ash.Resource.Info.primary_key()
      |> Enum.map(fn attribute -> Map.get(record, attribute) end)
    end

    # Raises unless `resource` is an Ash resource; returns `nil` otherwise.
    #
    # The module is rendered with `inspect/1` so the message shows `MyApp.Post`
    # rather than `Elixir.MyApp.Post`.
    defp validate_resource(resource) do
      if not Ash.Resource.Info.resource?(resource) do
        raise "The given module - #{inspect(resource)} - is not an Ash resource."
      end
    end

    # Normalizes a batch key to `{key, params_map}`. Explicit params are layered
    # on top of `default_params`; a bare key just gets the defaults.
    defp normalize_key(batch, default_params) do
      case batch do
        {key, params} -> {key, Enum.into(params, default_params)}
        key -> {key, default_params}
      end
    end

    # Runs every pending batch concurrently and returns a map of
    # batch key => `{:ok, items}` | `{:error, exit_reason}`.
    #
    # The batches are materialised into a list once, so the outcomes (which
    # `Task.async_stream/3` yields in input order) can be zipped back onto their
    # keys.
    def run_batches(source) do
      stream_opts = [timeout: Dataloader.default_timeout(), on_timeout: :kill_task]
      pending = Enum.to_list(source.batches)

      outcomes =
        pending
        |> Task.async_stream(&run_instrumented(&1, source), stream_opts)
        |> Enum.map(fn
          {:ok, {_key, result}} -> {:ok, result}
          {:exit, reason} -> {:error, reason}
        end)

      pending
      |> Enum.zip(outcomes)
      |> Map.new(fn {{key, _entries}, outcome} -> {key, outcome} end)
    end

    # Wraps a single batch run in start/stop telemetry events.
    defp run_instrumented(batch, source) do
      id = :erlang.unique_integer()
      system_time = System.system_time()
      started_at = System.monotonic_time()

      emit_start_event(id, system_time, batch)
      result = run_batch(batch, source)
      emit_stop_event(id, started_at, batch)

      result
    end

    # Runs one queued batch and returns `{batch_key, %{item_key => value}}`.
    #
    # Relationship batch (`:assoc` key): loads `field` on every queued record
    # through `source.api.load!/3` and maps each record's item key to the loaded
    # value (always a list for `:many` cardinality).
    defp run_batch(
           {{:assoc, source_resource, _pid, field, _resource, opts} = key, records},
           source
         ) do
      tracer = AshGraphql.Api.Info.tracer(source.api)

      # Batches are executed from `run_batches/1` inside `Task.async_stream/3`
      # workers; the span context supplied in the batch options is applied here
      # (presumably so the trace continues from the requesting process).
      if tracer && opts[:span_context] do
        tracer.set_span_context(opts[:span_context])
      end

      resource_short_name = Ash.Resource.Info.short_name(source_resource)

      metadata = %{
        api: source.api,
        resource: source_resource,
        resource_short_name: resource_short_name,
        actor: opts[:api_opts][:actor],
        tenant: opts[:api_opts][:tenant],
        relationship: field,
        source: :graphql,
        authorize?: AshGraphql.Api.Info.authorize?(source.api)
      }

      trace source.api,
            source_resource,
            :gql_relationship_batch,
            "#{resource_short_name}.#{field}",
            metadata do
        # Queue entries are `{item_key, record}` pairs.
        {ids, records} = Enum.unzip(records)
        query = opts[:query]
        api_opts = opts[:api_opts]
        # NOTE(review): the query tenant comes from `opts[:tenant]` while the trace
        # metadata above reads `opts[:api_opts][:tenant]` - confirm this is intended.
        tenant = opts[:tenant] || tenant_from_records(records)
        # Reset the relationship to the struct's default value before loading.
        empty = source_resource |> struct |> Map.fetch!(field)
        records = records |> Enum.map(&Map.put(&1, field, empty))
        relationship = Ash.Resource.Info.relationship(source_resource, field)

        cardinality = relationship.cardinality

        loads =
          if Map.has_key?(relationship, :manual) && relationship.manual do
            # Manual relationships are loaded by name only, without a query.
            field
          else
            query =
              query
              |> Ash.Query.new()
              |> Ash.Query.set_tenant(tenant)
              |> Ash.Query.for_read(
                relationship.read_action ||
                  Ash.Resource.Info.primary_action!(relationship.destination, :read).name,
                opts[:args],
                api_opts
              )

            {field, query}
          end

        loaded = source.api.load!(records, [loads], api_opts || [])

        # Unwrap paginated results into the plain list of records.
        loaded =
          case loaded do
            %struct{results: results} when struct in [Ash.Page.Offset, Ash.Page.Keyset] ->
              results

            loaded ->
              loaded
          end

        results =
          case cardinality do
            :many ->
              Enum.map(loaded, fn record ->
                List.wrap(Map.get(record, field))
              end)

            :one ->
              Enum.map(loaded, fn record ->
                Map.get(record, field)
              end)
          end

        # Assumes `load!` returns records in the order they were given, so the
        # results line up with `ids`.
        {key, Map.new(Enum.zip(ids, results))}
      end
    end

    # Calculation batch (`:calc` key): loads `calc` with `args` on every queued
    # record and maps each record's item key to the calculated value. Values of a
    # union new type that declares unnested union members are unwrapped (see
    # below).
    defp run_batch(
           {{:calc, _, _pid, calc,
             %{resource: resource, args: args, api_opts: api_opts, span_context: span_context}} =
              key, records},
           source
         ) do
      tracer = AshGraphql.Api.Info.tracer(source.api)

      if tracer && span_context do
        tracer.set_span_context(span_context)
      end

      resource_short_name = Ash.Resource.Info.short_name(resource)

      metadata = %{
        api: source.api,
        resource: resource,
        resource_short_name: resource_short_name,
        actor: api_opts[:actor],
        tenant: api_opts[:tenant],
        calculation: calc,
        source: :graphql,
        authorize?: AshGraphql.Api.Info.authorize?(source.api)
      }

      trace source.api,
            resource,
            :gql_calculation_batch,
            "#{resource_short_name}.#{calc}.batch",
            metadata do
        {ids, records} = Enum.unzip(records)

        calculation = Ash.Resource.Info.calculation(resource, calc)

        results =
          records
          |> source.api.load!([{calc, args}], api_opts)
          |> Enum.map(&Map.get(&1, calc))

        results =
          # The subtype check must inspect the calculation's own type.
          # `subtype_of(Ash.Type.Union)` is always truthy and so never filtered
          # out new types that are not unions.
          if Ash.Type.NewType.new_type?(calculation.type) &&
               Ash.Type.NewType.subtype_of(calculation.type) == Ash.Type.Union &&
               function_exported?(calculation.type, :graphql_unnested_unions, 1) do
            unnested_types = calculation.type.graphql_unnested_unions(calculation.constraints)
            constraints = Ash.Type.NewType.constraints(calculation.type, calculation.constraints)

            Enum.map(results, fn
              nil ->
                nil

              %Ash.Union{type: type, value: value} = result ->
                if type in unnested_types do
                  # Unnested members are returned as the bare value, tagged with
                  # the GraphQL type; a nil value stays nil.
                  if value do
                    type =
                      AshGraphql.Resource.field_type(
                        constraints[:types][type][:type],
                        calculation,
                        resource
                      )

                    Map.put(value, :__union_type__, type)
                  end
                else
                  result
                end
            end)
          else
            results
          end

        {key, Map.new(Enum.zip(ids, results))}
      end
    end

    # Reads the tenant from the metadata of the first record in the batch.
    #
    # The previous pattern matched only a single-element list, so any batch with
    # more than one record silently produced a `nil` tenant. Matching on the head
    # keeps the single-record behavior and covers larger batches too.
    # Returns `nil` for an empty list or when the first record has no tenant.
    defp tenant_from_records([%{__metadata__: %{tenant: tenant}} | _rest])
         when not is_nil(tenant) do
      tenant
    end

    defp tenant_from_records(_), do: nil

    # Emits the telemetry event marking the start of a batch run.
    defp emit_start_event(id, system_time, batch) do
      measurements = %{system_time: system_time}
      metadata = %{id: id, batch: batch}

      :telemetry.execute([:dataloader, :source, :batch, :run, :start], measurements, metadata)
    end

    # Emits the telemetry event marking the end of a batch run; the duration is
    # the monotonic time elapsed since `start_time_mono`.
    defp emit_stop_event(id, start_time_mono, batch) do
      measurements = %{duration: System.monotonic_time() - start_time_mono}
      metadata = %{id: id, batch: batch}

      :telemetry.execute([:dataloader, :source, :batch, :run, :stop], measurements, metadata)
    end
  end
end