lib/runbox/runtime/stage/unit_registry.ex

defmodule Runbox.Runtime.Stage.UnitRegistry do
  @moduledoc """
  Unit registry used in the stage-based runtime to manage the state of units.

  UnitRegistry is configured via the result of the
  `c:Toolbox.Scenario.Template.StageBased.subscriptions/0` callback. The callback returns template
  subscriptions; each subscription is defined as `{message_type, routing_rule}`. The unit registry
  contains one alternative registry per configured `message_type`. The routing rule is used to
  store a unit and to look it up later when needed.

  Units are registered using unit attributes as defined in the `routing_rule`. Unit lookup uses the
  `routing_rule` to parse the message body and compares the result to the stored unit attributes.
  """

  alias __MODULE__
  alias Runbox.Runtime.Stage.UnitRegistry.AlternativeRegistry
  alias Runbox.Runtime.Stage.UnitRegistry.RoutingKeyBuilder
  alias Runbox.StateStore.ScheduleUtils
  alias Toolbox.Message, as: Msg
  alias Toolbox.Runtime.Stage.Unit

  # Registry state. `parse_msg_fns` and `register_unit_fns` are derived from `config` and are
  # rebuilt on every `init/1` call; `alt_regs` is created by `init/1` only while it is still empty.
  defstruct units: %{},
            parse_msg_fns: %{},
            register_unit_fns: %{},
            alt_regs: %{},
            config: [],
            timeouts: []

  # Key under which a unit is stored in the `units` map (the unit's `id`).
  @type unit_id :: String.t()
  # Presumably a list of keys leading to a value inside a message body or unit attributes;
  # the exact semantics live in RoutingKeyBuilder - TODO confirm.
  @type path :: [String.t() | atom]
  # Value produced by a message parser and used to look units up in an alternative registry.
  @type routing_key :: [String.t()]
  # Routing rule paired with a message type in `config`; it is handed to RoutingKeyBuilder to
  # build both the message parser and the unit register function.
  @type routing_key_def :: {:=, [path], [path]} | {:in, [path], [path]}
  # Function turning a message into the routing key used for lookup.
  @type msg_parser :: (Msg.t() -> routing_key)
  # Function called as `fun.(alt_regs, msg_type, unit)`; returns the updated alternative registries.
  @type unit_register :: (map, Msg.type(), Unit.t() -> map)

  @type t :: %UnitRegistry{
          # all registered units keyed by their id
          units: %{unit_id => Unit.t()},
          # one message parser per configured message type
          parse_msg_fns: %{Msg.type() => msg_parser},
          # one unit register function per configured message type
          register_unit_fns: %{Msg.type() => unit_register},
          # one alternative registry per configured message type
          alt_regs: %{Msg.type() => %{routing_key => unit_id}},
          # subscriptions the registry was created with
          config: [{Msg.type(), routing_key_def}],
          # pending timeouts, kept sorted by timestamp in ascending order
          timeouts: [{ScheduleUtils.epoch_ms(), unit_id, Msg.t()}]
        }

  @doc """
  Builds a fresh unit registry for the given subscription config.

  The config is stored in the struct and the remaining derived fields are filled in by `init/1`.
  """
  @spec new([{Msg.type(), routing_key_def}]) :: t
  def new(config) do
    init(%UnitRegistry{config: config})
  end

  @doc """
  Fills in the UnitRegistry fields that are derived from its config.

  Call this right before a new unit registry is used (`new/1` does it automatically) and also
  after a UnitRegistry has been loaded from a savepoint, because not all fields are persisted
  into a savepoint.

  The message parsers and unit register functions are always rebuilt from the config. The
  alternative registries are created only when the registry does not hold any yet; otherwise
  the existing ones are kept as they are.
  """
  @spec init(t()) :: t()
  def init(%UnitRegistry{config: config, alt_regs: stored_alt_regs} = registry) do
    # builds a `msg_type => function` map using the given RoutingKeyBuilder constructor
    build_fns = fn builder ->
      Map.new(config, fn {msg_type, unit_routing_def} ->
        {msg_type, builder.(unit_routing_def)}
      end)
    end

    msg_parsers = build_fns.(&RoutingKeyBuilder.build_msg_parser/1)
    unit_registers = build_fns.(&RoutingKeyBuilder.build_unit_register/1)

    alt_regs =
      if stored_alt_regs != %{} do
        stored_alt_regs
      else
        # Old versions of the unit registry created alt regs during the first unit registration.
        # Creating all of them up front means the other functions here don't have to know about
        # this -> simpler code.
        Map.new(config, fn {msg_type, _} -> {msg_type, AlternativeRegistry.new()} end)
      end

    %UnitRegistry{
      registry
      | register_unit_fns: unit_registers,
        parse_msg_fns: msg_parsers,
        alt_regs: alt_regs
    }
  end

  # Returns the subscription config stored in the registry.
  @doc false
  @spec config(t) :: [{Msg.type(), routing_key_def}]
  def config(%UnitRegistry{} = unit_registry), do: unit_registry.config

  # Returns only the registered units and the alternative registries as a plain map.
  @doc false
  @spec state(t) :: %{
          units: %{unit_id => Unit.t()},
          alt_regs: %{Msg.type() => %{routing_key => unit_id}}
        }
  def state(%UnitRegistry{} = unit_registry) do
    Map.take(unit_registry, [:units, :alt_regs])
  end

  # Returns all registered units as a list (ids are dropped).
  @doc false
  @spec units(t) :: [Unit.t()]
  def units(%UnitRegistry{} = unit_registry), do: Map.values(unit_registry.units)

  @doc """
  Adds the given unit to the unit registry.

  The unit is stored under its id and every configured unit register function is applied to the
  alternative registries, so the unit can later be found via `lookup/2`.
  """
  @spec register(t, Unit.t()) :: t
  def register(%UnitRegistry{} = unit_registry, unit) do
    units_with_new = Map.put(unit_registry.units, unit.id, unit)

    # thread the alternative registries through each per-message-type register function
    updated_alt_regs =
      for {msg_type, register_fn} <- unit_registry.register_unit_fns,
          reduce: unit_registry.alt_regs do
        acc -> register_fn.(acc, msg_type, unit)
      end

    %UnitRegistry{unit_registry | units: units_with_new, alt_regs: updated_alt_regs}
  end

  @doc """
  Removes the given unit from the unit registry.

  Besides the unit itself and its alternative registry entries, all timeouts registered for the
  unit are dropped as well.
  """
  @spec unregister(t, Unit.t()) :: t
  def unregister(%UnitRegistry{} = unit_registry, unit) do
    without_unit = unregister_unit(unit_registry, unit)
    unregister_timeouts_for(without_unit, unit)
  end

  # Removes the unit from `units` and from every alternative registry; timeouts are not touched.
  @spec unregister_unit(t, Unit.t()) :: t
  defp unregister_unit(%UnitRegistry{} = unit_registry, unit) do
    remaining_units = Map.delete(unit_registry.units, unit.id)

    cleaned_alt_regs =
      Map.new(unit_registry.alt_regs, fn {msg_type, alt_reg} ->
        {msg_type, AlternativeRegistry.unregister_unit(alt_reg, unit.id)}
      end)

    %UnitRegistry{unit_registry | units: remaining_units, alt_regs: cleaned_alt_regs}
  end

  # Drops every pending timeout that belongs to the given unit; the order of the rest is kept.
  @spec unregister_timeouts_for(t, Unit.t()) :: t
  defp unregister_timeouts_for(%UnitRegistry{timeouts: timeouts} = unit_registry, unit) do
    kept = Enum.reject(timeouts, fn {_ts, owner_id, _msg} -> owner_id == unit.id end)

    %UnitRegistry{unit_registry | timeouts: kept}
  end

  @doc """
  Registers the unit again, replacing its previous registration.

  The unit is first removed from the stored units and the alternative registries and then
  registered with its current attributes. Timeouts of the unit are left as they are.
  """
  @spec reregister(t, Unit.t()) :: t
  def reregister(%UnitRegistry{} = unit_registry, unit) do
    register(unregister_unit(unit_registry, unit), unit)
  end

  @doc """
  Fetches the unit stored under the given id.

  Returns `{:ok, unit}` when the id is registered and `{:error, :not_found}` otherwise.
  """
  @spec get_unit(t, unit_id) :: {:ok, Unit.t()} | {:error, :not_found}
  def get_unit(%UnitRegistry{} = unit_registry, id) do
    # `{:ok, unit}` does not match `:error` and is therefore returned unchanged
    with :error <- Map.fetch(unit_registry.units, id) do
      {:error, :not_found}
    end
  end

  @doc """
  Finds the registered units the given message should be routed to.

  The message is parsed with the parser configured for its type and the resulting routing key is
  looked up in the alternative registry of that type. Returns `{:ok, units}` or
  `{:error, :unknown_message_type}` when no parser is configured for the message type.
  """
  @spec lookup(t, Msg.t()) ::
          {:ok, [Unit.t()]} | {:error, :unknown_message_type}
  def lookup(%UnitRegistry{} = unit_registry, %Msg{type: msg_type} = msg) do
    case Map.fetch(unit_registry.parse_msg_fns, msg_type) do
      {:ok, parse_msg_fn} ->
        routing_key = parse_msg_fn.(msg)

        matching_ids =
          unit_registry.alt_regs
          |> Map.get(msg_type)
          |> AlternativeRegistry.lookup(routing_key)

        found_units =
          for unit_id <- matching_ids do
            # every id held by an alternative registry is expected to have a stored unit
            {:ok, unit} = get_unit(unit_registry, unit_id)
            unit
          end

        {:ok, found_units}

      :error ->
        {:error, :unknown_message_type}
    end
  end

  @doc """
  Replaces the stored version of the unit with the given one.

  When the unit attributes are the same as the stored ones, only the stored unit is replaced.
  When they differ, the unit is reregistered so that the alternative registries reflect the new
  attributes. Returns `{:error, :not_found}` when no unit with the given id is registered.
  """
  @spec update(t, Unit.t()) :: {:ok, t} | {:error, :not_found}
  def update(%UnitRegistry{} = unit_registry, %Unit{attributes: new_attributes} = unit) do
    # `{:error, :not_found}` from `get_unit/2` falls through `with` unchanged
    with {:ok, stored_unit} <- get_unit(unit_registry, unit.id) do
      updated_registry =
        if match?(%Unit{attributes: ^new_attributes}, stored_unit) do
          # attributes are unchanged -> alternative registries stay valid, just swap the unit
          %UnitRegistry{unit_registry | units: Map.put(unit_registry.units, unit.id, unit)}
        else
          # attributes changed -> alternative registries have to be rebuilt for this unit
          reregister(unit_registry, unit)
        end

      {:ok, updated_registry}
    end
  end

  @doc """
  Schedules a timeout message for the given unit at timestamp `ts`.

  The timeout is stored as `{ts, unit_id, msg}` in the sorted list of pending timeouts.
  """
  @spec register_timeout(t, Unit.t(), ScheduleUtils.epoch_ms(), Msg.t()) :: t
  def register_timeout(%UnitRegistry{timeouts: pending} = unit_registry, unit, ts, msg) do
    %UnitRegistry{unit_registry | timeouts: insert_timeout(pending, {ts, unit.id, msg})}
  end

  # for tests only
  # Returns the timeouts whose timestamp is lower than or equal to `ts`, without removing them.
  @doc false
  @spec registered_timeouts(t, ScheduleUtils.epoch_ms()) ::
          [{ScheduleUtils.epoch_ms(), unit_id, Msg.t()}]
  def registered_timeouts(%UnitRegistry{} = unit_registry, ts) do
    # timeouts are kept sorted by timestamp, so the reached ones form a prefix of the list
    Enum.take_while(unit_registry.timeouts, fn {timeout_ts, _, _} -> timeout_ts <= ts end)
  end

  @doc """
  Pops the earliest timeout if it has already been reached at `ts`.

  Returns `{:ok, unit, timeout_msg, unit_registry}`, where the returned registry no longer
  contains the popped timeout, or `:no_reached_timeout` when no timeout is registered or the
  first one is scheduled after `ts`.
  """
  @spec pop_reached_timeout(t, ScheduleUtils.epoch_ms()) ::
          {:ok, Unit.t(), Msg.t(), t} | :no_reached_timeout
  def pop_reached_timeout(%UnitRegistry{timeouts: timeouts} = unit_registry, ts) do
    # timeouts are sorted, so only the head of the list has to be inspected
    case timeouts do
      [] ->
        :no_reached_timeout

      [{first_ts, _, _} | _] when first_ts > ts ->
        :no_reached_timeout

      [{_, unit_id, timeout_msg} | remaining] ->
        {:ok, unit} = get_unit(unit_registry, unit_id)
        {:ok, unit, timeout_msg, %UnitRegistry{unit_registry | timeouts: remaining}}
    end
  end

  # Timeouts are sorted by timestamp. The new timeout goes right before the first entry with a
  # strictly greater timestamp, i.e. after all entries with a lower or equal timestamp.
  defp insert_timeout(timeouts, {new_ts, _, _} = timeout) do
    {not_later, later} = Enum.split_while(timeouts, fn {ts, _, _} -> ts <= new_ts end)

    not_later ++ [timeout | later]
  end
end