lib/runbox/runtime/stage/timezip/multi_queue.ex

defmodule Runbox.Runtime.Stage.Timezip.MultiQueue do
  @moduledoc """
  Buffering multi-queue data structure for Timezip.

  Keeps one FIFO queue per queue id (producer) and emits the buffered messages of all queues
  merged by timestamp.

  If the structure were a heap, the complexity would be O(log N) where N is the number of
  messages (capped by max demand).
  With the multi-queue, an operation is ~ O(M) where M is the number of producers.
  """

  alias __MODULE__

  defstruct queues: %{}

  @type t() :: %MultiQueue{
          queues: %{optional(term()) => :queue.queue()}
        }

  @typedoc """
  Number of emitted messages per queue id.

  A queue that contributed nothing while other queues were being drained is present with the
  value `0`, hence `non_neg_integer()`.
  """
  @type demand_distribution() :: %{optional(term) => non_neg_integer()}

  @doc "Creates an empty multi-queue with no queues registered."
  @spec new() :: t()
  def new do
    %MultiQueue{queues: %{}}
  end

  @doc """
  Adds new queue to MultiQueue.

  If the same id is added again, the messages buffered for that id are dropped and the queue
  starts empty.
  """
  @spec add_queue(t(), term()) :: t()
  def add_queue(%MultiQueue{queues: queues} = multi_queue, queue_id) do
    %MultiQueue{multi_queue | queues: Map.put(queues, queue_id, :queue.new())}
  end

  @doc "Removes the queue and its content from the multi-queue. Unknown ids are ignored."
  @spec remove_queue(t(), term()) :: t()
  def remove_queue(%MultiQueue{queues: queues} = multi_queue, queue_id) do
    %MultiQueue{multi_queue | queues: Map.delete(queues, queue_id)}
  end

  @doc """
  Appends `messages` (in the given order) to the queue registered under `id`.

  Returns `{:error, {:bad_queue_id, id}}` when no queue was registered for `id`.
  """
  @spec enqueue(t(), term(), [term()]) :: {:ok, t()} | {:error, term()}
  def enqueue(%MultiQueue{queues: queues} = buffer, id, messages) do
    case queues do
      %{^id => queue} ->
        new_queue = enqueue_many(queue, messages)
        {:ok, %MultiQueue{buffer | queues: Map.put(queues, id, new_queue)}}

      %{} ->
        {:error, {:bad_queue_id, id}}
    end
  end

  @doc """
  Main logic, returns messages from all queues zipped/sorted by min timestamp.

  Returns as many messages as possible and stops as soon as any queue is empty. A multi-queue
  with no queues registered yields an empty result.

  `comparator` receives two messages and must return `true` only if the first one is strictly
  before the second one. The default compares the first elements of `{timestamp, value}`
  tuples with `<`. A non-strict comparator (one returning `true` for equal messages) never
  lets a message be popped and makes this function loop forever.

  In the returned tuple, the values mean:
   - *multi_queue:* multi-queue without emitted messages

   - *demand_distribution:* map with `id -> integer()`, meaning how many messages in output
     originate from particular queue id, can be used for manual demand handling in genstage.
     Once at least one message was emitted, every registered id is present (possibly with `0`).

   - *emitted:* returned messages from multiqueue front
  """
  @spec dequeue_all(t(), (msg, msg -> boolean())) ::
          {:ok, t(), demand_distribution(), emitted :: [msg]} | {:error, term()}
        when msg: term()
  def dequeue_all(multi_queue, comparator \\ &default_ts_tuple_comparator/2) do
    dequeue_all(multi_queue, [], %{}, comparator)
  end

  defp dequeue_all(multi_queue, acc_msgs, acc_demand_dist, comparator) do
    # find any message with the lowest timestamp
    case get_min_message(multi_queue, comparator) do
      :empty ->
        # acc_msgs is a nested list of per-queue chunks, each in reverse pop order and the
        # newest chunk first, so flatten + reverse restores the emission order.
        {:ok, multi_queue, acc_demand_dist, Enum.reverse(List.flatten(acc_msgs))}

      {:value, min_msg} ->
        # emit all messages with that timestamp
        {multi_queue, acc_msgs, acc_demand_dist} =
          dequeue_all_at(multi_queue, min_msg, acc_msgs, acc_demand_dist, comparator)

        # try again, maybe there is another timestamp to emit
        dequeue_all(multi_queue, acc_msgs, acc_demand_dist, comparator)
    end
  end

  # Dequeue all messages at a specified time. The time is specified by a message at that time (so we
  # can use comparator).
  defp dequeue_all_at(
         %MultiQueue{queues: queues} = multi_queue,
         min_msg,
         acc_msgs,
         acc_demand_dist,
         comparator
       ) do
    {new_queues, acc_msgs, acc_demand_dist} =
      Enum.reduce(
        queues,
        {queues, acc_msgs, acc_demand_dist},
        fn {queue_id, queue}, {queues_acc, msgs_acc, demand_acc} ->
          # for every queue pop all messages at the min_msg time
          {rest, popped} = pop_all_at(queue, min_msg, comparator)
          count = length(popped)

          {
            Map.put(queues_acc, queue_id, rest),
            [popped, msgs_acc],
            # recorded even when count == 0, so every queue id shows up in the distribution
            Map.update(demand_acc, queue_id, count, &(&1 + count))
          }
        end
      )

    {%MultiQueue{multi_queue | queues: new_queues}, acc_msgs, acc_demand_dist}
  end

  # Appends all items to the rear of the queue, preserving their order.
  defp enqueue_many(queue, items) do
    Enum.reduce(items, queue, fn item, q_acc -> :queue.in(item, q_acc) end)
  end

  # Find any message with minimal time as `{:value, msg}`. If any queue is empty, or there are
  # no queues at all, `:empty` is returned instead.
  defp get_min_message(%MultiQueue{queues: queues}, _comparator) when map_size(queues) == 0 do
    # Enum.reduce/2 raises Enum.EmptyError on an empty enumerable, so handle it explicitly.
    :empty
  end

  defp get_min_message(%MultiQueue{queues: queues}, comparator) do
    queues
    |> Stream.map(fn {_id, queue} -> :queue.peek(queue) end)
    |> Enum.reduce(&queue_min_reducer(&1, &2, comparator))
  end

  # An empty queue anywhere makes the whole result `:empty`; otherwise keep the earlier message
  # (the accumulator wins on ties).
  defp queue_min_reducer(_item, :empty = acc, _comparator), do: acc
  defp queue_min_reducer(:empty = item, _acc, _comparator), do: item

  defp queue_min_reducer({:value, msg1} = item, {:value, msg2} = acc, comparator) do
    if comparator.(msg1, msg2), do: item, else: acc
  end

  # Pop all messages at time specified by minimal message. Note the provided message must be
  # minimal! Returns `{remaining_queue, popped}` with `popped` in reverse pop order.
  defp pop_all_at(queue, min_msg, acc_msgs \\ [], comparator) do
    case :queue.out(queue) do
      {:empty, new_queue} ->
        {new_queue, acc_msgs}

      {{:value, value}, new_queue} ->
        # this checks !(min_msg.timestamp < value.timestamp) (the else clause is the main one)
        # which is equivalent to value.timestamp <= min_msg.timestamp
        # given the min_msg is minimal it can only be true that value.timestamp == min_msg.timestamp
        if comparator.(min_msg, value) do
          # value.timestamp > min_msg.timestamp => we cannot remove this yet,
          # so hand back the queue with `value` still at its front
          {queue, acc_msgs}
        else
          # value.timestamp == min_msg.timestamp => we can pop this msg
          pop_all_at(new_queue, min_msg, [value | acc_msgs], comparator)
        end
    end
  end

  # return true if first msg is before second (smaller ts)
  defp default_ts_tuple_comparator({ts1, _val}, {ts2, _val2}) do
    ts1 < ts2
  end
end