lib/runbox/runtime/simple/template_carrier.ex

defmodule Runbox.Runtime.Simple.TemplateCarrier do
  @moduledoc group: :internal
  @moduledoc """
  Template carrier for Simple scenario.

  GenStage handling Simple scenario template. It carries the state of the template and handles
  `init` and `handle_message` callbacks. It expects to consume messages from input streams and
  produces output-stream-ready events.
  """
  alias Runbox.Message
  alias Runbox.RunContext
  alias Runbox.RunStartContext
  alias Runbox.Runtime.OutputAction
  alias Runbox.Runtime.RuntimeInstruction
  alias Runbox.Runtime.RuntimeInstruction.Timeout
  alias Runbox.Runtime.Simple.TemplateCarrier
  alias Runbox.Scenario.Simple
  alias Runbox.Scenario.Simple.Config
  alias Runbox.StateStore.Entity
  alias Runbox.Utils.Traversal
  use GenStage
  require Logger

  defmodule Snapshot do
    @moduledoc group: :internal
    @moduledoc "State that is periodically persisted to savepoints and loaded when run continues."

    # Persisted fields:
    #   * `:template_state`  - state handed over by the scenario template
    #   * `:timeouts`        - min-heap of registered timeouts
    #   * `:timeout_counter` - number added to every registered timeout, incremented on each
    #                          registration
    #   * `:runbox_version`  - Runbox version the snapshot format corresponds to
    defstruct template_state: nil,
              timeouts: Heap.min(),
              timeout_counter: 0,
              runbox_version: Runbox.get_runbox_version()

    @type t :: %__MODULE__{
            template_state: Simple.state(),
            timeouts: Heap.t(),
            timeout_counter: non_neg_integer(),
            runbox_version: Version.t()
          }

    @doc """
    Brings a snapshot stored by an older Runbox up to the current snapshot format.

    Upgrading works like database migrations: it is a chain of steps, each applying the state
    transformation that belongs to one specific version. Steps are executed one after another,
    starting from the version recorded in the given snapshot, until its `:runbox_version`
    equals the current Runbox version. A snapshot that is already current is returned as is.

    ## Adding a new upgrade step

    Suppose version 42.0.0 introduces the field `:new_field`. Add a matching clause to
    `do_upgrade_step/1` that upgrades the state from the last preceding version
    (e.g., 41.1.0) and bumps the `:runbox_version` field:

        defp do_upgrade_step(%{runbox_version: %{major: major}} = snapshot) when major < 42 do
          snapshot
          |> Map.put(:new_field, :some_value)
          |> Map.put(:runbox_version, Version.parse!("42.0.0"))
        end

    The new clause belongs right before the catch-all clause (`defp do_upgrade_step(snapshot)`).
    """
    def upgrade(snapshot) do
      stored_version = Map.get(snapshot, :runbox_version)
      current_version = Runbox.get_runbox_version()

      if stored_version == current_version do
        snapshot
      else
        # one migration step at a time, then re-check against the current version
        next_snapshot = do_upgrade_step(snapshot)
        upgrade(next_snapshot)
      end
    end

    # Performs a single upgrade step. Clauses are ordered from the oldest snapshot format to the
    # newest; the last clause is a catch-all that only stamps the current version.
    defp do_upgrade_step(snapshot) when not is_map_key(snapshot, :runbox_version) do
      # The runbox_version field did not exist yet in Runbox 12.1.0, which is also the version
      # that introduced timeout_counter. A snapshot coming from Runbox 12.0.0 lacks the counter,
      # so it is added here unless it is already present.
      with_counter = Map.put_new(snapshot, :timeout_counter, 0)
      Map.put(with_counter, :runbox_version, Version.parse!("12.1.0"))
    end

    defp do_upgrade_step(%{runbox_version: %{major: major}} = snapshot) when major < 13 do
      # Runbox 13.0.0 moved Toolbox.Message, Toolbox.Runtime.Stage.Unit and
      # Toolbox.Scenario.UserAction to the Runbox namespace, so every occurrence of them in the
      # snapshot is rebuilt as the corresponding Runbox struct.
      # credo:disable-for-lines:13 Credo.Check.Design.DuplicatedCode
      rename_struct = fn
        %{__struct__: Toolbox.Message} = legacy ->
          struct(Runbox.Message, Map.from_struct(legacy))

        %{__struct__: Toolbox.Runtime.Stage.Unit} = legacy ->
          struct(Runbox.Runtime.Stage.Unit, Map.from_struct(legacy))

        %{__struct__: Toolbox.Scenario.UserAction} = legacy ->
          struct(Runbox.Scenario.UserAction, Map.from_struct(legacy))

        untouched ->
          untouched
      end

      migrated = Traversal.prewalk(snapshot, rename_struct)
      Map.put(migrated, :runbox_version, Version.parse!("13.0.0"))
    end

    defp do_upgrade_step(snapshot) do
      # No format change is associated with the remaining versions, so stamping the snapshot
      # with the current Runbox version is all that is needed.
      current_version = Runbox.get_runbox_version()
      Map.put(snapshot, :runbox_version, current_version)
    end
  end

  @type state_entity() :: Entity.t(Snapshot.t())

  @typedoc "Stage output (input for the next stage) is either a tick message or an output action"
  @type stage_outputs() :: [Message.t() | OutputAction.t()]

  @type timestamp() :: non_neg_integer()
  @type origin() :: Message.t() | :init

  defmodule State do
    @moduledoc group: :internal
    @moduledoc "Process state kept by a running template carrier of a Simple scenario."

    # All fields default to nil; they are filled in by `TemplateCarrier.init/1`.
    defstruct module: nil,
              state_entity: nil,
              start_from: nil,
              runbox_ctx: nil,
              run_id: nil,
              scenario_id: nil

    @type t() :: %__MODULE__{
            module: module(),
            state_entity: TemplateCarrier.state_entity(),
            start_from: TemplateCarrier.timestamp(),
            runbox_ctx: RunContext.t(),
            run_id: String.t(),
            scenario_id: String.t()
          }
  end

  @doc "Returns the name of this component, which is always `:template`."
  def component_name, do: :template

  @doc """
  Starts the template carrier as a GenStage process.

  The three arguments are bundled into a single tuple that is handed to `init/1`.
  """
  def start_link(args, runbox_ctx, start_ctx) do
    init_arg = {args, runbox_ctx, start_ctx}
    GenStage.start_link(__MODULE__, init_arg)
  end

  @doc "Synchronously fetches the state entity held by the template carrier running as `pid`."
  def get_state_entity(pid), do: GenStage.call(pid, :fetch_state_entity)

  # Sets the stage up as a producer-consumer subscribed to the single upstream component listed
  # in the config. The template state itself is prepared later, in `handle_info/2`, which is
  # triggered by a message the process sends to itself here.
  @impl true
  def init({args, runbox_ctx, start_ctx}) do
    %{run_id: run_id, config: config, state_entity: state_entity} = args

    %{
      module: module,
      subscribe_to: [upstream],
      scenario_id: scenario_id,
      start_from: start_from,
      start_or_continue: mode
    } = config

    upstream_pid = RunStartContext.component_pid(start_ctx, upstream)

    # a fresh run initializes the template, a continued run restores the stored state
    bootstrap_msg =
      case mode do
        :start -> :init_state
        :continue -> :load_state
      end

    send(self(), bootstrap_msg)

    Logger.metadata(run_id: run_id, scenario_id: scenario_id)

    carrier_state = %State{
      module: module,
      state_entity: state_entity,
      start_from: start_from,
      runbox_ctx: runbox_ctx,
      run_id: run_id,
      scenario_id: scenario_id
    }

    {:producer_consumer, carrier_state, subscribe_to: [{upstream_pid, []}]}
  end

  # Initializes the template for a fresh run: calls the template's init, registers the timeouts
  # it requested, emits its output actions and stores the first snapshot at `start_from`.
  # Stops the stage with `:init_error` when the initialization fails.
  @impl true
  def handle_info(:init_state, %State{module: module, start_from: start_ts} = state) do
    case Simple.init(module, %Config{start_from: start_ts}) do
      {:error, reason} ->
        log_init_error(reason)
        {:stop, :init_error, state}

      {:ok, template_outputs, template_state} ->
        {oas, with_timeouts} = process_timeouts(template_outputs, start_ts, :init, %Snapshot{})
        stage_outputs = process_oas(oas, start_ts, state, :init, nil)

        initial_snapshot = %Snapshot{with_timeouts | template_state: template_state}
        entity = Entity.update_state(state.state_entity, start_ts, initial_snapshot)

        {:noreply, stage_outputs, %State{state | state_entity: entity}}
    end
  end

  # Restores the template for a continued run: upgrades the stored snapshot to the current
  # format, passes the stored template state through the template's set_state and keeps the
  # entity timestamp unchanged. Stops the stage with `:set_state_error` when that fails.
  def handle_info(:load_state, %State{state_entity: entity} = state) do
    snapshot = Snapshot.upgrade(Entity.state(entity))
    stored_template_state = snapshot.template_state

    case Simple.set_state(state.module, stored_template_state) do
      {:error, reason} ->
        log_state_load_error(reason, stored_template_state)
        {:stop, :set_state_error, state}

      {:ok, restored_template_state} ->
        restored_snapshot = %Snapshot{snapshot | template_state: restored_template_state}
        restored_entity = Entity.update_state(entity, Entity.timestamp(entity), restored_snapshot)

        {:noreply, [], %State{state | state_entity: restored_entity}}
    end
  end

  # Runs the whole batch of events through the template. On success the collected outputs are
  # emitted (a tick is substituted when there are none); any other result stops the stage.
  @impl true
  def handle_events(events, _from, %State{state_entity: entity} = state) do
    result = handle_messages(events, [], state, entity)

    case result do
      {:ok, outputs, updated_entity} ->
        emitted = add_tick_if_empty(outputs, events)
        {:noreply, emitted, %State{state | state_entity: updated_entity}}

      _failure ->
        # error was already logged at this point
        {:stop, :handle_message_error, state}
    end
  end

  # Processes a batch of messages one by one and returns the stage outputs appended to
  # `stage_outputs_acc` together with the updated state entity.
  #
  # For every step:
  #   1. the next message is chosen by `maybe_pop_timeout_msg/2` (a due timeout is processed
  #      before the head of the batch),
  #   2. its timestamp is acknowledged via `Entity.ack_processed_time/4`,
  #   3. the message is run through `handle_message/3` and the resulting snapshot is stored in
  #      the entity under the message timestamp.
  #
  # Processing stops at the first acknowledgement that does not return `{:ok, entity}`; that
  # result is returned to the caller unchanged.
  @spec handle_messages([Message.t()], stage_outputs(), State.t(), state_entity()) ::
          {:ok, stage_outputs(), state_entity()} | {:error, any()}
  defp handle_messages(msgs, stage_outputs_acc, state, state_entity) do
    # Outputs are collected in reverse order internally: prepending costs time proportional to
    # the new outputs only, whereas `acc ++ new` copies the whole accumulator on every message.
    do_handle_messages(msgs, Enum.reverse(stage_outputs_acc), state, state_entity)
  end

  # Worker for `handle_messages/4`; `reversed_outputs` holds the outputs newest-first.
  @spec do_handle_messages([Message.t()], stage_outputs(), State.t(), state_entity()) ::
          {:ok, stage_outputs(), state_entity()} | {:error, any()}
  defp do_handle_messages([_h | _t] = msgs, reversed_outputs, state, state_entity) do
    {[msg | next_msgs], new_timeouts} = maybe_pop_timeout_msg(msgs, state_entity)

    entity_save_result =
      Entity.ack_processed_time(
        state_entity,
        msg.timestamp,
        {Runbox, :save_entity, [state.runbox_ctx]},
        fn entity -> transform_entity_before_save(entity, state.module) end
      )

    with {:ok, state_entity} <- entity_save_result do
      snapshot = Entity.state(state_entity)
      # the heap returned by maybe_pop_timeout_msg/2 no longer contains a popped timeout
      snapshot = %Snapshot{snapshot | timeouts: new_timeouts}
      {new_stage_outputs, snapshot} = handle_message(msg, state, snapshot)
      state_entity = Entity.update_state(state_entity, msg.timestamp, snapshot)

      # Enum.reverse/2 reverses the new outputs and puts them in front of the accumulator
      reversed_outputs = Enum.reverse(new_stage_outputs, reversed_outputs)
      do_handle_messages(next_msgs, reversed_outputs, state, state_entity)
    end
  end

  defp do_handle_messages([], reversed_outputs, _state, state_entity) do
    {:ok, Enum.reverse(reversed_outputs), state_entity}
  end

  # Handles a single message and returns the stage outputs it produced together with the
  # updated snapshot.
  @spec handle_message(Message.t(), State.t(), Snapshot.t()) :: {stage_outputs(), Snapshot.t()}
  # Ticks bypass the template; they are forwarded as they are to advance time downstream.
  defp handle_message(%Message{type: :tick} = tick, _state, snapshot), do: {[tick], snapshot}

  defp handle_message(msg, %State{module: module} = state, snapshot) do
    previous_template_state = snapshot.template_state

    case Simple.handle_message(module, msg, previous_template_state) do
      {:error, reason} ->
        # the failing message is dropped and the snapshot stays as it was
        log_message_error(reason, msg, previous_template_state)
        {[], snapshot}

      {:ok, template_outputs, next_template_state} ->
        now_ts = msg.timestamp
        {candidate_oas, with_timeouts} = process_timeouts(template_outputs, now_ts, msg, snapshot)
        stage_outputs = process_oas(candidate_oas, now_ts, state, msg, previous_template_state)

        {stage_outputs, %Snapshot{with_timeouts | template_state: next_template_state}}
    end
  end

  # Replies with the current state entity; emits no events and leaves the state untouched.
  @impl true
  def handle_call(:fetch_state_entity, _from, %State{state_entity: entity} = state),
    do: {:reply, entity, [], state}

  # Wraps every recognized output action body into an `OutputAction` stamped with `ts`, the
  # scenario id and the run id. Anything that is not an output action body is logged and
  # dropped. The relative order of the kept items is preserved.
  @spec process_oas(
          hopefully_oas :: [OutputAction.t() | any()],
          timestamp(),
          State.t(),
          origin(),
          template_state :: Simple.state()
        ) :: [OutputAction.t()]
  defp process_oas(oas, ts, %State{} = state, origin, origin_state) do
    oas
    |> Enum.reduce([], fn candidate, kept ->
      if OutputAction.oa_body?(candidate) do
        [OutputAction.new(candidate, ts, state.scenario_id, state.run_id) | kept]
      else
        log_unknown_oa(candidate, origin, origin_state)
        kept
      end
    end)
    |> Enum.reverse()
  end

  # Separates timeout instructions from the rest of the template outputs. Every timeout
  # instruction carrying a `Message` is registered in the snapshot (in the order in which the
  # template returned them); all remaining outputs are returned, in their original order, as
  # output action candidates.
  @spec process_timeouts(
          [OutputAction.t() | RuntimeInstruction.t() | any()],
          timestamp(),
          origin(),
          Snapshot.t()
        ) ::
          {[OutputAction.t() | any()], Snapshot.t()}
  defp process_timeouts(template_outputs, ts, origin, snapshot) do
    {reversed_oas, updated_snapshot} =
      Enum.reduce(template_outputs, {[], snapshot}, fn
        %RuntimeInstruction{body: %Timeout{timeout_message: %Message{}} = timeout}, {oas, snap} ->
          {oas, register_timeout(timeout, ts, origin, snap)}

        other, {oas, snap} ->
          {[other | oas], snap}
      end)

    {Enum.reverse(reversed_oas), updated_snapshot}
  end

  # Pushes the timeout message onto the snapshot's heap as `{timestamp, counter, message}` and
  # increments the counter. A timeout set before `ts` is shifted to `ts` first.
  @spec register_timeout(Timeout.t(), timestamp(), origin(), Snapshot.t()) :: Snapshot.t()
  defp register_timeout(%Timeout{timeout_message: msg}, ts, origin, %Snapshot{} = snapshot) do
    counter = snapshot.timeout_counter
    due_msg = maybe_fix_timeout_in_the_past(msg, ts, origin, snapshot.template_state)
    heap = Heap.push(snapshot.timeouts, {due_msg.timestamp, counter, due_msg})

    %Snapshot{snapshot | timeouts: heap, timeout_counter: counter + 1}
  end

  # Returns the timeout message unchanged when it is due at `now_ts` or later (the current
  # timestamp is allowed). A timeout due earlier is logged and moved to `now_ts`.
  defp maybe_fix_timeout_in_the_past(timeout_msg, now_ts, origin, template_state) do
    if timeout_msg.timestamp >= now_ts do
      timeout_msg
    else
      log_timeout_registered_in_the_past(timeout_msg, now_ts, origin, template_state)
      %Message{timeout_msg | timestamp: now_ts}
    end
  end

  # Looks at the earliest registered timeout. When it is due no later than the timestamp of the
  # first message, the timeout message is put in front of the batch and removed from the heap;
  # otherwise both the batch and the heap are returned as they are.
  @spec maybe_pop_timeout_msg([Message.t(), ...], state_entity()) :: {[Message.t(), ...], Heap.t()}
  defp maybe_pop_timeout_msg([msg | _rest] = msgs, state_entity) do
    snapshot = Entity.state(state_entity)
    now_ts = msg.timestamp
    timeouts = snapshot.timeouts

    due_timeout =
      case Heap.root(timeouts) do
        # legacy 2-tuple entries from the previous state structure
        {ts, %Message{} = timeout} when ts <= now_ts -> timeout
        # current 3-tuple entries
        {ts, _counter, %Message{} = timeout} when ts <= now_ts -> timeout
        # no remaining timeout, or all timeouts are in the future
        _ -> nil
      end

    if due_timeout do
      {[due_timeout | msgs], Heap.pop(timeouts)}
    else
      {msgs, timeouts}
    end
  end

  # Prepares the entity for saving: the template state is passed through the template's
  # get_state and the result replaces it in the snapshot, while the entity timestamp is kept.
  # An error from get_state is logged and returned unchanged.
  @spec transform_entity_before_save(state_entity(), module()) ::
          {:ok, state_entity()} | {:error, any()}
  defp transform_entity_before_save(state_entity, module) do
    snapshot = Entity.state(state_entity)
    current_template_state = snapshot.template_state

    case Simple.get_state(module, current_template_state) do
      {:error, reason} = failure ->
        log_state_save_error(reason, current_template_state)
        failure

      {:ok, state_to_save} ->
        snapshot_to_save = %Snapshot{snapshot | template_state: state_to_save}
        saved_at = Entity.timestamp(state_entity)
        {:ok, Entity.update_state(state_entity, saved_at, snapshot_to_save)}
    end
  end

  # The stage output must never be empty. When the template consumes all events without
  # producing an output action and no ticks are coming in, the Output Sink at the end would
  # receive nothing, so it would not ack the time to StateStore and StateStore could not store
  # savepoints. To avoid that, a tick carrying the timestamp of the last event is emitted
  # whenever there is no other output.
  defp add_tick_if_empty(output, events) do
    case output do
      [] ->
        last_event = List.last(events)
        [%Message{type: :tick, timestamp: last_event.timestamp, body: %{}}]

      _non_empty ->
        output
    end
  end

  # Logs a warning describing why the template initialization failed: either an exception
  # (reported with its formatted stacktrace) or an unexpected return value.
  defp log_init_error({:exception, exception, stacktrace}) do
    info =
      :error
      |> Exception.format(exception, stacktrace)
      |> String.trim()

    Logger.warning("""
    Simple scenario has crashed during initialization because of exception.
    Exception: #{info}\
    """)
  end

  defp log_init_error({:bad_return_value, value}) do
    Logger.warning("""
    Simple scenario has crashed during initialization because of unexpected return value.
    Value: #{inspect(value)}\
    """)
  end

  # Logs a warning describing why the set_state callback failed, together with the state that
  # was passed to it: either an exception or an unexpected return value.
  defp log_state_load_error({:exception, exception, stacktrace}, state) do
    info =
      :error
      |> Exception.format(exception, stacktrace)
      |> String.trim()

    Logger.warning("""
    Simple scenario has crashed during set_state callback because of exception.
    Exception: #{info}
    External state: #{inspect(state)}\
    """)
  end

  defp log_state_load_error({:bad_return_value, value}, state) do
    Logger.warning("""
    Simple scenario has crashed during set_state callback because of unexpected return value.
    Value: #{inspect(value)}
    External state: #{inspect(state)}\
    """)
  end

  # Logs a warning describing why the get_state callback failed, together with the state that
  # was passed to it: either an exception or a bad return value.
  defp log_state_save_error({:exception, exception, stacktrace}, state) do
    info =
      :error
      |> Exception.format(exception, stacktrace)
      |> String.trim()

    Logger.warning("""
    Simple scenario has crashed during get_state callback because of exception.
    Exception: #{info}
    Internal state: #{inspect(state)}\
    """)
  end

  defp log_state_save_error({:bad_return_value, value}, state) do
    Logger.warning("""
    Simple scenario has crashed during get_state callback because of bad return value.
    Value: #{inspect(value)}
    Internal state: #{inspect(state)}\
    """)
  end

  # Logs a warning about a message the template failed to handle, together with the message
  # and the template state it was handled with: either an exception or an unexpected return
  # value.
  defp log_message_error({:exception, exception, stacktrace}, msg, state) do
    info =
      :error
      |> Exception.format(exception, stacktrace)
      |> String.trim()

    Logger.warning("""
    Simple scenario has ignored message because of exception.
    Exception: #{info}
    Message: #{inspect(msg)}
    State: #{inspect(state)}\
    """)
  end

  defp log_message_error({:bad_return_value, value}, msg, state) do
    Logger.warning("""
    Simple scenario has ignored message because of unexpected return value.
    Value: #{inspect(value)}
    Message: #{inspect(msg)}
    State: #{inspect(state)}\
    """)
  end

  # Logs a warning about an output that is not a known output action. Outputs returned from
  # initialization (origin `:init`) are reported without message and state; for any other
  # origin, the origin is reported as the message, together with the template state.
  defp log_unknown_oa(oa, origin, state) do
    case origin do
      :init ->
        Logger.warning("""
        Simple scenario has ignored unknown output action returned from initialization.
        Output action: #{inspect(oa)}\
        """)

      msg ->
        Logger.warning("""
        Simple scenario has ignored unknown output action.
        Output action: #{inspect(oa)}
        Message: #{inspect(msg)}
        State: #{inspect(state)}\
        """)
    end
  end

  # Logs a warning about a timeout that was registered before `now_ts` and therefore shifted.
  # For origin `:init` the report refers to start_from; for any other origin, the origin is
  # reported as the message, together with the template state.
  defp log_timeout_registered_in_the_past(timeout_msg, now_ts, origin, state) do
    case origin do
      :init ->
        Logger.warning("""
        Simple scenario has registered timeout in init, but the timeout is before start_from.
        The timeout has been shifted to start_from, which is #{now_ts}.
        Original timeout message: #{inspect(timeout_msg)}\
        """)

      msg ->
        Logger.warning("""
        Simple scenario has registered timeout in the past.
        The timeout has been shifted to the current timestamp, which is #{now_ts}.
        Original timeout message: #{inspect(timeout_msg)}
        Message: #{inspect(msg)}
        State: #{inspect(state)}\
        """)
    end
  end
end