Skip to main content

lib/skuld/fiber_pool/tasks.ex

# BEAM Task integration for FiberPool.
#
# This module handles spawning and receiving results from BEAM Tasks
# that run in parallel processes. Tasks are used when you need true
# parallelism (multiple CPU cores) rather than cooperative concurrency
# (fibers in a single process).
#
# ## Usage
#
# Tasks are spawned via `FiberPool.task/2` and awaited like fibers:
#
#     comp do
#       # Spawn a task (runs in separate process)
#       h <- FiberPool.task(fn -> expensive_cpu_work() end)
#
#       # Await result (suspends until task completes)
#       result <- FiberPool.await!(h)
#     end
#
# ## Integration
#
# The FiberPool scheduler calls these functions to:
# 1. Spawn pending tasks after the main computation runs
# 2. Wait for task completion messages
# 3. Record task results for await satisfaction
defmodule Skuld.FiberPool.Tasks do
  @moduledoc false

  alias Skuld.FiberPool.FiberPoolState
  alias Skuld.Coroutine.Error
  alias Skuld.Comp.Throw

  # `{handle_id, thunk, opts}` as consumed by `spawn_pending/2`.
  @type task_info :: {reference(), (-> term()), keyword()}

  @doc """
  Spawn pending tasks using the Task.Supervisor.

  Takes a list of `{handle_id, thunk, opts}` tuples and spawns each thunk via
  `Task.Supervisor.async_nolink/2` under `state.task_supervisor`. Each task's
  ref is recorded against its `handle_id` (`FiberPoolState.add_task/3`) so the
  reply or `:DOWN` message can later be matched up by `receive_message/1`.

  `opts` is accepted per task but is currently ignored; in particular a
  `:timeout` option is NOT enforced here.

  An empty list returns `state` unchanged, even when no supervisor is set.
  Raises `ArgumentError` when there are tasks to spawn but
  `state.task_supervisor` is `nil`.
  """
  @spec spawn_pending(FiberPoolState.t(), [task_info()]) :: FiberPoolState.t()
  def spawn_pending(state, []), do: state

  def spawn_pending(%{task_supervisor: nil}, [_ | _]) do
    raise ArgumentError, """
    Task.task/2 requires a Task.Supervisor, but none is installed.

    Wrap your computation with Task.with_task_supervisor/1:

        comp
        |> FiberPool.with_handler()
        |> Task.with_handler()
        |> Task.with_task_supervisor()
        |> Comp.run!()
    """
  end

  def spawn_pending(state, pending_tasks) do
    task_sup = state.task_supervisor

    Enum.reduce(pending_tasks, state, fn {handle_id, thunk, _opts}, acc ->
      # The thunk runs in a separate process with no access to this
      # computation's env/effects; only its return value is sent back.
      task = Task.Supervisor.async_nolink(task_sup, fn -> thunk.() end)

      FiberPoolState.add_task(acc, task.ref, handle_id)
    end)
  end

  @doc """
  Wait for all remaining tasks to complete.

  Repeatedly calls `receive_message/1` until `FiberPoolState.has_tasks?/1`
  is false. There is no timeout: this blocks for as long as any tracked task
  is still running. Returns the updated state with all task results recorded.
  """
  @spec wait_for_all(FiberPoolState.t()) :: FiberPoolState.t()
  def wait_for_all(state) do
    if FiberPoolState.has_tasks?(state) do
      {:task_completed, state} = receive_message(state)
      wait_for_all(state)
    else
      state
    end
  end

  @doc """
  Receive and handle a single task message.

  Blocks (no timeout) until either a `{ref, result}` reply or a
  `{:DOWN, ref, :process, pid, reason}` message is received, then records
  the corresponding success or error for the task's handle.

  NOTE(review): the receive patterns are broad. Any two-tuple whose first
  element is a reference, and any process `:DOWN` message, is consumed. When
  the ref is not a tracked task, the message is dropped and the state is
  returned unchanged. Confirm that nothing else in the scheduler process
  expects such messages.

  Returns `{:task_completed, state}` with the updated state.
  """
  @spec receive_message(FiberPoolState.t()) :: {:task_completed, FiberPoolState.t()}
  def receive_message(state) do
    receive do
      {ref, result} when is_reference(ref) ->
        handle_task_result(state, ref, result)

      {:DOWN, ref, :process, _pid, reason} ->
        handle_task_crash(state, ref, reason)
    end
  end

  #############################################################################
  ## Internal
  #############################################################################

  # Record a task's normal reply. The monitor is removed only for refs that
  # are actually tracked tasks, so monitors owned by other code running in
  # this process are never torn down by a stray `{ref, term}` message.
  defp handle_task_result(state, ref, result) do
    case FiberPoolState.pop_task(state, ref) do
      {nil, state} ->
        {:task_completed, state}

      {handle_id, state} ->
        # The task replied, so its monitor is no longer needed; `:flush`
        # also discards the `:DOWN` message the task's exit will produce.
        Process.demonitor(ref, [:flush])

        completion = result_to_completion(result)
        state = FiberPoolState.record_completion(state, handle_id, completion)
        {:task_completed, state}
    end
  end

  # Record a task process that went down before replying.
  defp handle_task_crash(state, ref, reason) do
    case FiberPoolState.pop_task(state, ref) do
      {nil, state} ->
        {:task_completed, state}

      {handle_id, state} ->
        error = crash_reason_to_error(reason)
        state = FiberPoolState.record_completion(state, handle_id, {:error, error})
        {:task_completed, state}
    end
  end

  # A thunk that returns a `%Throw{}` is treated as a failed task; anything
  # else it returns is a successful result.
  defp result_to_completion(%Throw{error: error}), do: {:error, normalize_task_throw_error(error)}
  defp result_to_completion(result), do: {:ok, result}

  # Map a task process's exit reason to an `%Error{}`. An uncaught `throw`
  # inside a task exits with `{{:nocatch, value}, stacktrace}`. It is
  # classified as `:throw`, consistent with `normalize_task_throw_error/1`.
  defp crash_reason_to_error({{:nocatch, value}, stacktrace}) when is_list(stacktrace) do
    %Error{type: :throw, error: value, stacktrace: stacktrace}
  end

  defp crash_reason_to_error({exception, stacktrace}) when is_list(stacktrace) do
    %Error{type: :exception, error: exception, stacktrace: stacktrace}
  end

  defp crash_reason_to_error(reason) do
    %Error{type: :exit, error: reason}
  end

  # Convert the `error` field of a `%Throw{}` returned by a thunk into an
  # `%Error{}`, keyed on its `kind`; a value without that shape is treated
  # as a plain thrown value with no stacktrace.
  defp normalize_task_throw_error(%{kind: :error, payload: exception, stacktrace: stacktrace}) do
    %Error{type: :exception, error: exception, stacktrace: stacktrace}
  end

  defp normalize_task_throw_error(%{kind: :throw, payload: value, stacktrace: stacktrace}) do
    %Error{type: :throw, error: value, stacktrace: stacktrace}
  end

  defp normalize_task_throw_error(%{kind: :exit, payload: reason, stacktrace: stacktrace}) do
    %Error{type: :exit, error: reason, stacktrace: stacktrace}
  end

  defp normalize_task_throw_error(plain_value) do
    %Error{type: :throw, error: plain_value}
  end
end