lib/lockstep/supervisor.ex

defmodule Lockstep.Supervisor do
  @moduledoc """
  A supervisor that runs under Lockstep's controller.
  Builds on `Lockstep.spawn_link/1` + `Lockstep.flag(:trap_exit, true)`
  so child crashes are observed via `{:EXIT, child, reason}` and
  restarts can be issued in-band with the rest of the test schedule.

  ## Supported

    * Strategies `:one_for_one` (default), `:rest_for_one`, and `:one_for_all`.
    * Restart options `:permanent`, `:transient`, `:temporary`.
    * `max_restarts` / `max_seconds` intensity. When exceeded the
      supervisor itself exits with reason `:shutdown`.
    * `start_link/2`, `which_children/1`, `count_children/1`,
      `start_child/2`, `terminate_child/2`, `restart_child/2`.

  ## Child specs

      %{id: :counter, start: {Counter, :start_link, [42]}, restart: :permanent}

  Or the shorthand `{Counter, 42}`, which maps to
  `%{id: Counter, start: {Counter, :start_link, [42]}, restart: :permanent}`.

  Or the bare module `Counter`, which calls `Counter.child_spec(:no_arg)`
  if exported, otherwise starts it via `{Counter, :start_link, [[]]}`.

  ## Caveats

    * `Lockstep.Supervisor` does not invoke `terminate/2` callbacks on
      shutdown; see `Lockstep.GenServer` for the same caveat.
    * `max_seconds` uses Lockstep's virtual clock (`Lockstep.now/0`),
      so timing-based restart-intensity is reproducible across replays.
  """

  alias Lockstep.GenServer, as: LGS

  @type child_id :: any()
  @type child_pid :: pid() | :undefined
  @type child_info :: {child_id(), child_pid(), :worker | :supervisor, [module()]}

  @default_max_restarts 3
  @default_max_seconds 5

  # ============================================================
  # Public API
  # ============================================================

  @doc """
  Starts a supervisor process under Lockstep's controller.

  Two call shapes are accepted:

    * `start_link(children, opts)` -- `children` is a literal list of
      child specs.
    * `start_link(module, init_arg, opts)` -- the child specs come from
      `module.init(init_arg)`, which must return
      `{:ok, {sup_flags, child_specs}}` (typically built with
      `Lockstep.Supervisor.init/2`) or `:ignore`. Any other return value
      yields `{:error, {:bad_return_value, other}}`.

  ## Options

    * `:strategy` -- `:one_for_one` (default), `:rest_for_one`, or
      `:one_for_all`.
    * `:max_restarts` -- default `3`.
    * `:max_seconds` -- default `5`.
    * `:name` -- passed on to the underlying server start when present.

  In the module form, a key given in `opts` wins over the matching
  `sup_flags` entry returned by `module.init/1`.
  """
  @spec start_link([child_spec()] | module(), keyword() | any()) ::
          {:ok, pid()} | :ignore | {:error, term()}
  def start_link(children_or_module, opts_or_arg \\ [])

  def start_link(children, opts) when is_list(children),
    do: do_start_with_children(children, opts)

  def start_link(module, init_arg) when is_atom(module),
    do: start_link(module, init_arg, [])

  @spec start_link(module(), any(), keyword()) ::
          {:ok, pid()} | :ignore | {:error, term()}
  def start_link(module, init_arg, opts) when is_atom(module) do
    case module.init(init_arg) do
      :ignore ->
        :ignore

      {:ok, {sup_flags, children}} ->
        do_start_with_children(children, merge_sup_flags(opts, sup_flags))

      unexpected ->
        {:error, {:bad_return_value, unexpected}}
    end
  end

  # Fills in `:strategy`, `:max_restarts` and `:max_seconds` from the
  # `sup_flags` map (keys `:strategy`, `:intensity`, `:period`) for every
  # option the caller did not pass explicitly.
  defp merge_sup_flags(opts, sup_flags) do
    flag_mapping = [
      {:strategy, :strategy, :one_for_one},
      {:max_restarts, :intensity, @default_max_restarts},
      {:max_seconds, :period, @default_max_seconds}
    ]

    Enum.reduce(flag_mapping, opts, fn {opt_key, flag_key, default}, acc ->
      Keyword.put_new(acc, opt_key, Map.get(sup_flags, flag_key, default))
    end)
  end

  # Shared tail of every `start_link` shape: checks the strategy,
  # normalizes the child specs, and starts `Lockstep.Supervisor.Server`
  # with the resulting init map. Raises `ArgumentError` when the strategy
  # is not one of the three accepted atoms (or when a child spec is
  # rejected by `normalize_child/1`).
  defp do_start_with_children(children, opts) do
    strategy = Keyword.get(opts, :strategy, :one_for_one)

    if strategy not in [:one_for_one, :rest_for_one, :one_for_all] do
      raise ArgumentError,
            "Lockstep.Supervisor: only :one_for_one, :rest_for_one, :one_for_all strategies are supported, got #{inspect(strategy)}"
    end

    normalized = Enum.map(children, &normalize_child/1)
    max_restarts = Keyword.get(opts, :max_restarts, @default_max_restarts)
    max_seconds = Keyword.get(opts, :max_seconds, @default_max_seconds)

    server_arg = %{
      children: normalized,
      strategy: strategy,
      max_restarts: max_restarts,
      max_seconds: max_seconds
    }

    LGS.start_link(__MODULE__.Server, server_arg, name_option(opts))
  end

  @doc """
  Builds the `{:ok, {sup_flags, children}}` tuple that a supervisor
  module returns from its `init/1` callback, in the spirit of OTP's
  `Supervisor.init/2`.

      def init(_arg) do
        children = [...]
        Lockstep.Supervisor.init(children, strategy: :one_for_one)
      end

  The `sup_flags` map has three keys:

    * `:strategy` -- from `opts[:strategy]`, default `:one_for_one`.
    * `:intensity` -- from `opts[:max_restarts]`, default `3`.
    * `:period` -- from `opts[:max_seconds]`, default `5`.

  `children` is passed through untouched.
  """
  @spec init([child_spec()], keyword()) :: {:ok, {map(), [child_spec()]}}
  def init(children, opts) when is_list(children) and is_list(opts) do
    strategy = Keyword.get(opts, :strategy, :one_for_one)
    intensity = Keyword.get(opts, :max_restarts, @default_max_restarts)
    period = Keyword.get(opts, :max_seconds, @default_max_seconds)

    {:ok, {%{strategy: strategy, intensity: intensity, period: period}, children}}
  end

  # Extracts the `:name` option as a keyword list suitable for passing
  # straight to the server start call: `[]` when no name (or `nil`) was
  # given, `[name: name]` otherwise.
  defp name_option(opts) do
    registered = Keyword.get(opts, :name)
    if is_nil(registered), do: [], else: [name: registered]
  end

  @doc """
  Lists the supervisor's child slots as
  `{id, pid | :undefined, :worker, [module]}` tuples.
  """
  @spec which_children(pid()) :: [child_info()]
  def which_children(sup), do: LGS.call(sup, :which_children)

  @doc """
  Returns how many child slots the supervisor tracks. Slots whose pid
  is `:undefined` are included in the count.
  """
  @spec count_children(pid()) :: non_neg_integer()
  def count_children(sup), do: LGS.call(sup, :count_children)

  @doc """
  Dynamically adds a child to a running supervisor.

  `spec` may be any shape accepted in the children list; it is
  normalized in the calling process before being sent to the
  supervisor. Returns `{:ok, pid}` on success or `{:error, reason}`.
  """
  @spec start_child(pid(), child_spec()) :: {:ok, pid()} | {:error, term()}
  def start_child(sup, spec) do
    normalized = normalize_child(spec)
    LGS.call(sup, {:start_child, normalized})
  end

  @doc """
  Stops the child registered under `child_id`.

  The child's slot stays in the supervisor with pid `:undefined`, so it
  can later be brought back with `restart_child/2`. Returns `:ok` (also
  when the child was already stopped) or `{:error, :not_found}` when no
  such id is known.
  """
  @spec terminate_child(pid(), child_id()) :: :ok | {:error, :not_found}
  def terminate_child(sup, child_id) do
    request = {:terminate_child, child_id}
    LGS.call(sup, request)
  end

  @doc """
  Starts a stopped child (one whose pid is `:undefined`) again.

  Returns `{:ok, pid}` on success, `{:error, :running}` when the child
  is currently alive, `{:error, :not_found}` for an unknown id, or the
  `{:error, reason}` produced by the failed start.
  """
  @spec restart_child(pid(), child_id()) :: {:ok, pid()} | {:error, term()}
  def restart_child(sup, child_id) do
    request = {:restart_child, child_id}
    LGS.call(sup, request)
  end

  @type child_spec :: %{
          required(:id) => any(),
          required(:start) => {module(), atom(), [any()]},
          optional(:restart) => :permanent | :transient | :temporary
        }

  # ============================================================
  # Spec normalization
  # ============================================================

  # Modules whose vanilla OTP version we want to redirect to a
  # Lockstep-aware equivalent when they appear in child specs.
  # User code that writes `{Registry, [...]}` in a child spec should
  # actually start `Lockstep.Registry` under Lockstep, since
  # `Registry.start_link` would spawn a vanilla GenServer outside
  # the controller.
  @module_aliases %{
    Registry => Lockstep.Registry,
    GenServer => Lockstep.GenServer,
    Task => Lockstep.Task,
    Task.Supervisor => Lockstep.Supervisor
  }

  # Normalizes every accepted child-spec shape into a full spec map with
  # `:id`, `:start` and `:restart` keys.
  #
  # Accepted shapes:
  #   * a map with `:id` and `start: {module, function, args}` -- the
  #     start module is redirected through `@module_aliases`, the id is
  #     redirected too when it equals the start module, and `:restart`
  #     defaults to `:permanent`;
  #   * `{module, arg}` -- uses `module.child_spec(arg)` when exported,
  #     otherwise a default spec starting `module.start_link(arg)`;
  #   * a bare `module` -- uses `module.child_spec(:no_arg)` when
  #     exported, otherwise a default spec starting `module.start_link([])`.
  #
  # Anything else raises `ArgumentError`.
  defp normalize_child(%{id: id, start: {m, f, a}} = spec)
       when is_atom(m) and is_atom(f) and is_list(a) do
    aliased_m = Map.get(@module_aliases, m, m)
    aliased_id = if id == m, do: aliased_m, else: id

    spec
    |> Map.put_new(:restart, :permanent)
    |> Map.put(:start, {aliased_m, f, a})
    |> Map.put(:id, aliased_id)
  end

  defp normalize_child({mod, arg}) when is_atom(mod) do
    aliased = Map.get(@module_aliases, mod, mod)

    # Mirror OTP's behaviour: `{Module, arg}` is shorthand for
    # `Module.child_spec(arg)`. Falling through to a default
    # `{Module, :start_link, [arg]}` is only correct when `child_spec/1`
    # isn't defined.
    if exports_child_spec?(aliased) do
      normalize_child(aliased.child_spec(arg))
    else
      %{id: aliased, start: {aliased, :start_link, [arg]}, restart: :permanent}
    end
  end

  defp normalize_child(mod) when is_atom(mod) do
    aliased = Map.get(@module_aliases, mod, mod)

    if exports_child_spec?(aliased) do
      normalize_child(aliased.child_spec(:no_arg))
    else
      %{id: aliased, start: {aliased, :start_link, [[]]}, restart: :permanent}
    end
  end

  defp normalize_child(other) do
    raise ArgumentError,
          "Lockstep.Supervisor: invalid child spec #{inspect(other)}"
  end

  # True when `mod` can be loaded and exports `child_spec/1`.
  #
  # `function_exported?/3` does not load the module by itself, so a module
  # that has not been loaded yet would look as if it had no `child_spec/1`
  # and its custom spec would be silently replaced by the default one.
  # Loading it first makes the check independent of load order.
  defp exports_child_spec?(mod) do
    match?({:module, _}, Code.ensure_loaded(mod)) and
      function_exported?(mod, :child_spec, 1)
  end
end

defmodule Lockstep.Supervisor.Server do
  @moduledoc false

  # State:
  #   children: %{id => %{spec, pid, history, order}}
  #     pid is :undefined if the child is currently dead/never-started
  #     history is a per-child list; it is initialised to [] and this
  #       module never appends to it
  #     order is the child's start index, used to decide which siblings
  #       count as "later" under :rest_for_one
  #   strategy: :one_for_one | :rest_for_one | :one_for_all
  #   max_restarts, max_seconds: intensity limits applied across the
  #     supervisor (sum of all child restarts in the rolling window)
  #   global_history: ms timestamps (most recent first) of every restart
  #     we issued in the current virtual-time window

  # Server init callback. Enables exit trapping, starts every child spec
  # in list order, and records one slot per spec id with its start index
  # (`order`), which :rest_for_one later uses to find "later" siblings.
  #
  # All specs are attempted even when an earlier one fails. If any start
  # failed, the callback raises with the collected `{id, reason}` pairs
  # (most recent failure first); otherwise it returns `{:ok, state}`.
  def init(%{children: specs, max_restarts: mr, max_seconds: ms} = init) do
    Lockstep.flag(:trap_exit, true)

    strategy = Map.get(init, :strategy, :one_for_one)

    attempts =
      for {spec, order} <- Enum.with_index(specs) do
        case start_child_inline(spec) do
          {:ok, pid} ->
            {spec.id, %{spec: spec, pid: pid, history: [], order: order}, :started}

          {:error, reason} ->
            slot = %{spec: spec, pid: :undefined, history: [], order: order}
            {spec.id, slot, {:failed, reason}}
        end
      end

    # Later specs overwrite earlier ones that share an id.
    children = Map.new(attempts, fn {id, slot, _outcome} -> {id, slot} end)

    failures =
      for {id, _slot, {:failed, reason}} <- Enum.reverse(attempts), do: {id, reason}

    case failures do
      [] ->
        state = %{
          children: children,
          strategy: strategy,
          max_restarts: mr,
          max_seconds: ms,
          global_history: []
        }

        {:ok, state}

      _ ->
        # OTP would shut down on init failure. Match by raising.
        raise "Lockstep.Supervisor: child(ren) failed to start: #{inspect(failures)}"
    end
  end

  # Replies with one `{id, pid_or_:undefined, :worker, [start_module]}`
  # tuple per child slot, sorted by child id. State is unchanged.
  def handle_call(:which_children, _from, state) do
    infos =
      Enum.map(state.children, fn {id, slot} ->
        {mod, _fun, _args} = slot.spec.start
        {id, slot.pid, :worker, [mod]}
      end)

    {:reply, Enum.sort_by(infos, fn {id, _pid, _type, _mods} -> id end), state}
  end

  # Replies with the number of tracked child slots (alive or not).
  def handle_call(:count_children, _from, state) do
    slot_count = map_size(state.children)
    {:reply, slot_count, state}
  end

  # Adds and starts a new child. Replies `{:error, {:already_present, id}}`
  # when a slot with the same id exists; otherwise tries to start the
  # child and, on success, stores it with the next free order index.
  # A failed start leaves the state untouched.
  def handle_call({:start_child, spec}, _from, state) do
    id = spec.id

    case Map.fetch(state.children, id) do
      {:ok, _existing} ->
        {:reply, {:error, {:already_present, id}}, state}

      :error ->
        order = next_order(state)

        case start_child_inline(spec) do
          {:ok, pid} ->
            slot = %{spec: spec, pid: pid, history: [], order: order}
            {:reply, {:ok, pid}, put_in(state.children[id], slot)}

          {:error, _reason} = failure ->
            {:reply, failure, state}
        end
    end
  end

  # Stops the child with the given id. Unknown ids reply
  # `{:error, :not_found}`; an already-stopped child replies `:ok`
  # without changes; a running child is sent a `:kill` exit signal and
  # its slot is reset to `:undefined` right away (the reply does not
  # wait for the process to actually exit).
  def handle_call({:terminate_child, id}, _from, state) do
    case state.children[id] do
      nil ->
        {:reply, {:error, :not_found}, state}

      %{pid: :undefined} ->
        {:reply, :ok, state}

      %{pid: victim} = slot ->
        # :kill cannot be trapped by the child.
        try do
          Process.exit(victim, :kill)
        catch
          _, _ -> :ok
        end

        stopped = %{slot | pid: :undefined}
        {:reply, :ok, put_in(state.children[id], stopped)}
    end
  end

  # Restarts a stopped child. Replies `{:error, :not_found}` for an
  # unknown id and `{:error, :running}` when the slot holds a live pid;
  # otherwise attempts the start and stores the new pid on success. A
  # failed start is returned to the caller and leaves the state as is.
  def handle_call({:restart_child, id}, _from, state) do
    case state.children[id] do
      nil ->
        {:reply, {:error, :not_found}, state}

      %{pid: current} when is_pid(current) ->
        {:reply, {:error, :running}, state}

      %{spec: spec} = slot ->
        case start_child_inline(spec) do
          {:ok, new_pid} ->
            revived = %{slot | pid: new_pid}
            {:reply, {:ok, new_pid}, put_in(state.children[id], revived)}

          {:error, _} = failure ->
            {:reply, failure, state}
        end
    end
  end

  # Child crashed -- handle the EXIT signal.
  #
  # Exits from pids that do not match a tracked child slot are ignored;
  # this covers children whose slot was already reset to `:undefined`
  # by a deliberate termination.
  #
  # For a tracked child:
  #   * if its restart policy says "don't restart" for this reason, the
  #     slot is kept with pid `:undefined`;
  #   * otherwise the restart is recorded in `global_history`; when more
  #     than `max_restarts` stamps fall inside the `max_seconds` window
  #     the supervisor stops with reason `:shutdown`;
  #   * otherwise siblings are terminated as the strategy requires, the
  #     child is started again, and the affected siblings are restarted.
  def handle_info({:EXIT, dead_pid, reason}, state) do
    case find_child_by_pid(state.children, dead_pid) do
      nil ->
        {:noreply, state}

      {id, %{spec: spec} = child} ->
        if should_restart?(spec, reason) do
          restart_crashed_child(state, id, child)
        else
          # Don't restart -- mark slot as :undefined and continue.
          {:noreply, put_in(state.children[id], %{child | pid: :undefined})}
        end
    end
  end

  def handle_info(_, state), do: {:noreply, state}

  # Applies restart intensity and the supervision strategy for a child
  # that must be restarted. Returns the `handle_info/2` result tuple.
  defp restart_crashed_child(state, id, %{spec: spec} = child) do
    # Record the restart *before* checking intensity, so the restart we
    # are about to issue counts towards the limit.
    now = Lockstep.now()

    state =
      update_in(state.global_history, fn history ->
        add_within_window(history, now, state.max_seconds)
      end)

    if length(state.global_history) > state.max_restarts do
      # Intensity exceeded. Supervisor itself exits with :shutdown.
      {:stop, :shutdown, state}
    else
      state = terminate_siblings_for_strategy(state, id, child.order)

      case start_child_inline(spec) do
        {:ok, new_pid} ->
          state = put_in(state.children[id], %{child | pid: new_pid})
          {:noreply, restart_siblings_for_strategy(state, child.order)}

        {:error, _} ->
          # Restart failed. Mark as :undefined; another EXIT may
          # arrive if the start_link partially started something.
          {:noreply, put_in(state.children[id], %{child | pid: :undefined})}
      end
    end
  end

  # :rest_for_one stops every alive child started after the dead one;
  # :one_for_all stops every other alive child; any other strategy
  # leaves the siblings alone.
  defp terminate_siblings_for_strategy(%{strategy: :rest_for_one} = state, _dead_id, dead_order),
    do: terminate_later_siblings(state, dead_order)

  defp terminate_siblings_for_strategy(%{strategy: :one_for_all} = state, dead_id, _dead_order),
    do: terminate_all_siblings(state, dead_id)

  defp terminate_siblings_for_strategy(state, _dead_id, _dead_order), do: state

  # Brings siblings back after the crashed child has been restarted.
  #
  # Under :one_for_one nothing else is restarted: sibling slots that are
  # `:undefined` got there on purpose (explicit termination, or a
  # temporary/transient child that exited) and must stay stopped until
  # they are restarted explicitly.
  defp restart_siblings_for_strategy(%{strategy: :rest_for_one} = state, dead_order),
    do: restart_terminated_siblings(state, dead_order)

  defp restart_siblings_for_strategy(%{strategy: :one_for_all} = state, dead_order) do
    # Later siblings first, then whatever is still :undefined (the
    # siblings that were started before the crashed child).
    state
    |> restart_terminated_siblings(dead_order)
    |> restart_remaining_undefined()
  end

  defp restart_siblings_for_strategy(state, _dead_order), do: state

  # :rest_for_one helper. Sends a `:kill` exit signal to every alive
  # child whose order index is greater than `dead_order` and resets its
  # slot to `:undefined`. Other slots are left untouched. Returns the
  # updated state.
  defp terminate_later_siblings(state, dead_order) do
    Enum.reduce(state.children, state, fn {id, %{order: order, pid: pid} = slot}, acc ->
      if order > dead_order and is_pid(pid) do
        try do
          Process.exit(pid, :kill)
        catch
          _, _ -> :ok
        end

        put_in(acc.children[id], %{slot | pid: :undefined})
      else
        acc
      end
    end)
  end

  # :one_for_all helper. Sends a `:kill` exit signal to every alive
  # child other than `dead_id`, regardless of order, and resets each of
  # those slots to `:undefined`. Returns the updated state.
  defp terminate_all_siblings(state, dead_id) do
    Enum.reduce(state.children, state, fn {id, %{pid: pid} = slot}, acc ->
      if id != dead_id and is_pid(pid) do
        try do
          Process.exit(pid, :kill)
        catch
          _, _ -> :ok
        end

        put_in(acc.children[id], %{slot | pid: :undefined})
      else
        acc
      end
    end)
  end

  # Starts every child whose slot is currently `:undefined`, in ascending
  # order index. Used by :one_for_all so that siblings started before the
  # crashed child come back too. A child that fails to start keeps its
  # `:undefined` slot. Returns the updated state.
  defp restart_remaining_undefined(state) do
    state.children
    |> Enum.sort_by(fn {_id, %{order: order}} -> order end)
    |> Enum.reduce(state, fn
      {id, %{pid: :undefined, spec: spec} = slot}, acc ->
        case start_child_inline(spec) do
          {:ok, pid} -> put_in(acc.children[id], %{slot | pid: pid})
          {:error, _} -> acc
        end

      {_id, %{pid: _alive}}, acc ->
        acc
    end)
  end

  # Starts every `:undefined` child whose order index is greater than
  # `dead_order`, walking the slots in ascending order. A child that
  # fails to start keeps its `:undefined` slot. Returns the updated state.
  defp restart_terminated_siblings(state, dead_order) do
    state.children
    |> Enum.sort_by(fn {_id, %{order: order}} -> order end)
    |> Enum.reduce(state, fn {id, %{order: order, pid: pid, spec: spec} = slot}, acc ->
      if order > dead_order and pid == :undefined do
        case start_child_inline(spec) do
          {:ok, new_pid} -> put_in(acc.children[id], %{slot | pid: new_pid})
          {:error, _} -> acc
        end
      else
        acc
      end
    end)
  end

  # ============================================================
  # Helpers
  # ============================================================

  # Runs a spec's `{module, function, args}` start tuple and normalizes
  # the outcome to `{:ok, pid}` or `{:error, reason}`. Raised exceptions
  # and exits inside the start call are turned into `{:error, _}`.
  #
  # On success the new pid is linked to the calling process (the
  # supervisor). The original author's note: OTP's contract is that
  # start_link links the child to its supervisor, but
  # `Lockstep.GenServer.start_link/2-3` uses `Lockstep.spawn` (no link),
  # so the link is made explicitly here. If the child died between start
  # and link we'll receive `{:EXIT, child, :noproc}` and the restart
  # machinery handles it normally. Failures while linking are ignored.
  defp start_child_inline(%{start: {mod, fun, args}}) do
    outcome =
      try do
        apply(mod, fun, args)
      rescue
        exception -> {:error, exception}
      catch
        :exit, reason -> {:error, reason}
      end

    case normalize_start_result(outcome) do
      {:ok, pid} = started ->
        try do
          Lockstep.link(pid)
        catch
          _, _ -> :ok
        end

        started

      failure ->
        failure
    end
  end

  # Maps the raw return value of a child's start function onto
  # `{:ok, pid}` / `{:error, reason}`:
  #   * `{:ok, pid}` and `{:ok, pid, info}` become `{:ok, pid}`;
  #   * `:ignore` becomes `{:error, :ignore}`;
  #   * `{:error, reason}` passes through;
  #   * anything else becomes `{:error, {:bad_return_value, value}}`.
  defp normalize_start_result(result) do
    case result do
      {:ok, pid} when is_pid(pid) -> {:ok, pid}
      {:ok, pid, _info} when is_pid(pid) -> {:ok, pid}
      :ignore -> {:error, :ignore}
      {:error, _} = failure -> failure
      unexpected -> {:error, {:bad_return_value, unexpected}}
    end
  end

  # Returns the `{id, slot}` entry whose slot holds `pid`, or `nil` when
  # no tracked child has that pid.
  defp find_child_by_pid(children, pid) do
    Enum.find(children, fn
      {_id, %{pid: ^pid}} -> true
      {_id, %{pid: _other}} -> false
    end)
  end

  # Decides whether a child that exited with `reason` must be restarted:
  #   * `:temporary` children never are;
  #   * `:transient` children are, unless they exited with `:normal`,
  #     `:shutdown` or `{:shutdown, term}`;
  #   * everything else (including `:permanent` and specs without a
  #     `:restart` key) always is.
  defp should_restart?(spec, reason) do
    case {spec, reason} do
      {%{restart: :temporary}, _} -> false
      {%{restart: :transient}, clean} when clean in [:normal, :shutdown] -> false
      {%{restart: :transient}, {:shutdown, _}} -> false
      _ -> true
    end
  end

  # Prepends `now_ms` to `history` (most recent first) and keeps only the
  # leading run of stamps that are no older than `window_seconds`
  # seconds, i.e. >= `now_ms - window_seconds * 1000`.
  defp add_within_window(history, now_ms, window_seconds) do
    oldest_allowed = now_ms - window_seconds * 1000
    Enum.take_while([now_ms | history], fn stamp -> stamp >= oldest_allowed end)
  end

  # Order index for a newly added child: one past the highest index in
  # use, or 0 when the supervisor has no children.
  defp next_order(state) do
    used = Enum.map(state.children, fn {_id, %{order: order}} -> order end)
    if used == [], do: 0, else: Enum.max(used) + 1
  end
end