# lib/skuld/fiber_pool/batching.ex

# Batch grouping, execution, and fiber resumption for the FiberPool.
#
# This module provides the complete batch lifecycle:
# - Group suspended fibers by their batch_key
# - Execute batch groups using registered executors
# - Match results back to the requesting fibers
# - Pop batch suspensions from state, execute them, and resume fibers with results
defmodule Skuld.FiberPool.Batching do
  @moduledoc false

  alias Skuld.Comp
  alias Skuld.Comp.Throw
  alias Skuld.Comp.InternalSuspend
  alias Skuld.FiberPool.BatchExecutor
  alias Skuld.FiberPool.FiberPoolState

  @type fiber_id :: reference()
  @type batch_key :: term()

  @doc """
  Partition suspended fibers into groups that share a batch key.

  Each element of `suspended_fibers` is a `{fiber_id, suspension}` tuple. The
  grouping key is read from `suspension.payload.batch_key`. The result is a
  map of `batch_key => [{fiber_id, suspension}]`, where every list keeps the
  relative order its entries had in the input.

  ## Example

      suspended = [
        {fid1, %InternalSuspend{payload: %InternalSuspend.Batch{batch_key: {:db_fetch, User}, ...}, ...}},
        {fid2, %InternalSuspend{payload: %InternalSuspend.Batch{batch_key: {:db_fetch, User}, ...}, ...}},
        {fid3, %InternalSuspend{payload: %InternalSuspend.Batch{batch_key: {:db_fetch, Post}, ...}, ...}}
      ]

      Batching.group_suspended(suspended)
      # => %{
      #      {:db_fetch, User} => [{fid1, suspend1}, {fid2, suspend2}],
      #      {:db_fetch, Post} => [{fid3, suspend3}]
      #    }
  """
  @spec group_suspended([{fiber_id, InternalSuspend.t()}]) ::
          %{batch_key => [{fiber_id, InternalSuspend.t()}]}
  def group_suspended(suspended_fibers) do
    # The fiber id plays no part in grouping; only the payload's key does.
    key_of = fn {_fiber_id, suspension} -> suspension.payload.batch_key end

    Enum.group_by(suspended_fibers, key_of)
  end

  @doc """
  Execute a batch group using the registered executor.

  Returns a computation that yields `[{fiber_id, result}]` - a list of
  fiber IDs paired with their individual results, in the same order as
  `group`.

  If no executor is registered for `batch_key`, this function does not raise.
  Instead the returned computation yields
  `%Throw{error: {:no_batch_executor, batch_key}}` directly, without going
  through the Throw effect.

  The executor is called with a list of `{request_id, op}` tuples and its
  result is read with `Map.fetch!/2`, keyed by `request_id`. A result that has
  no entry for one of the group's request ids therefore raises when the
  continuation runs.

  ## Parameters

  - `batch_key` - The batch key for this group
  - `group` - List of `{fiber_id, InternalSuspend.t()}` tuples
  - `env` - The current environment (for executor lookup)
  """
  @spec execute_group(batch_key, [{fiber_id, InternalSuspend.t()}], Comp.Types.env()) ::
          Comp.Types.computation()
  def execute_group(batch_key, group, env) do
    case BatchExecutor.get_executor(env, batch_key) do
      nil ->
        # Return a computation that yields a Throw struct directly
        # (not through the Throw effect, which would need a handler)
        fn e, _k -> {%Throw{error: {:no_batch_executor, batch_key}}, e} end

      executor ->
        # Build the ops list for the executor: [{request_id, op}].
        # Only needed on this branch, so it is not built when no executor exists.
        ops =
          Enum.map(group, fn {_fid, suspend} ->
            {suspend.payload.request_id, suspend.payload.op}
          end)

        # Execute the batch and map results back to fiber_ids
        Comp.bind(executor.(ops), fn results ->
          Enum.map(group, fn {fiber_id, suspend} ->
            {fiber_id, Map.fetch!(results, suspend.payload.request_id)}
          end)
        end)
    end
  end

  @doc """
  Execute every batch group and gather the results.

  Takes the `batch_key => group` map and chains one `execute_group/3`
  computation per entry, in `Map.to_list/1` order. The combined computation
  yields a single flat list of `{fiber_id, result}` tuples covering all fibers
  in all groups.

  An empty map short-circuits to `[]`.
  """
  @spec execute_all_groups(%{batch_key => [{fiber_id, InternalSuspend.t()}]}, Comp.Types.env()) ::
          Comp.Types.computation()
  def execute_all_groups(groups, _env) when groups == %{} do
    # NOTE(review): a bare list is returned here (and used as the initial
    # accumulator below) where a computation is expected; presumably
    # Comp.bind/2 accepts plain values - confirm.
    []
  end

  def execute_all_groups(groups, env) do
    groups
    |> Map.to_list()
    |> Enum.reduce([], fn {batch_key, group}, collected_comp ->
      Comp.bind(collected_comp, &append_group_results(&1, batch_key, group, env))
    end)
  end

  # Runs one group and appends its results after those already collected.
  defp append_group_results(collected, batch_key, group, env) do
    batch_key
    |> execute_group(group, env)
    |> Comp.bind(fn group_results -> collected ++ group_results end)
  end

  #############################################################################
  ## Batch Execution and Fiber Resumption
  #############################################################################

  @doc """
  Drain the pending batch suspensions from `state` and run them.

  The suspensions are popped with `FiberPoolState.pop_all_batch_suspensions/1`,
  grouped by batch key, and each group is executed in turn. The state and
  environment produced by one group are passed on to the next.

  When nothing is pending, the popped state and the given `env` are returned
  unchanged.

  Returns `{state, env}`.
  """
  @spec execute_pending_batches(FiberPoolState.t(), Comp.Types.env()) ::
          {FiberPoolState.t(), Comp.Types.env()}
  def execute_pending_batches(state, env) do
    {pending, drained_state} = FiberPoolState.pop_all_batch_suspensions(state)

    case pending do
      [] ->
        {drained_state, env}

      _ ->
        pending
        |> group_suspended()
        |> Enum.reduce({drained_state, env}, fn {batch_key, group}, {current_state, current_env} ->
          execute_and_resume(current_state, current_env, batch_key, group)
        end)
    end
  end

  # Runs the computation for one batch group and wakes the group's fibers.
  #
  # Returns `{state, env}`, where `env` is the environment produced by running
  # the batch computation.
  #
  # Only two outcome shapes are handled: a `%Throw{}` and a list of
  # `{fiber_id, result}` tuples. Any other outcome matches no clause and
  # raises `CaseClauseError`.
  defp execute_and_resume(state, env, batch_key, group) do
    outcome =
      batch_key
      |> execute_group(group, env)
      |> Comp.call(env, &Comp.identity_k/2)

    case outcome do
      {%Throw{error: reason}, env_after} ->
        # The whole batch failed: every fiber in the group gets the same error.
        failure = {:error, reason}

        next_state =
          Enum.reduce(group, state, fn {fiber_id, _suspend}, st ->
            resume_fiber_with_result(st, fiber_id, failure)
          end)

        {next_state, env_after}

      {per_fiber, env_after} when is_list(per_fiber) ->
        # The batch succeeded: hand each fiber its own result.
        next_state =
          Enum.reduce(per_fiber, state, fn {fiber_id, value}, st ->
            resume_fiber_with_result(st, fiber_id, {:ok, value})
          end)

        {next_state, env_after}
    end
  end

  # Wakes a fiber with its batch result.
  #
  # If the pool no longer knows the fiber, `state` is returned untouched.
  # Otherwise the fiber's batch suspension is removed, a
  # `{:batch_wake, value}` marker is stored under `wake_signals[fiber_id]`,
  # and the fiber is enqueued.
  defp resume_fiber_with_result(state, fiber_id, result) do
    if is_nil(FiberPoolState.get_fiber(state, fiber_id)) do
      state
    else
      wake_signal = {:batch_wake, unwrap_batch_result(result)}

      state
      |> FiberPoolState.remove_batch_suspension(fiber_id)
      |> put_in([Access.key(:wake_signals), fiber_id], wake_signal)
      |> FiberPoolState.enqueue(fiber_id)
    end
  end

  # Converts a tagged batch result into the value a woken fiber receives:
  # the bare payload on success, a `%Throw{}` carrying the reason on failure.
  defp unwrap_batch_result({:ok, payload}) do
    payload
  end

  defp unwrap_batch_result({:error, why}), do: %Throw{error: why}
end