lib/membrane_mp4/muxer/cmaf.ex

defmodule Membrane.MP4.Muxer.CMAF do
  @moduledoc """
  Puts a payloaded stream into the [Common Media Application Format](https://www.wowza.com/blog/what-is-cmaf) (CMAF),
  an MP4-based container commonly used in adaptive streaming over HTTP.

  Multiple input streams are supported. When more than one is connected, they are muxed into a single CMAF Track.
  Given that all input streams need to have a keyframe at the beginning of each CMAF Segment, it is recommended
  that all input streams are renditions of the same content.

  If a stream contains non-key frames (like H264 P or B frames), they should be marked
  with a `mp4_payload: %{key_frame?: false}` metadata entry.

  """
  use Membrane.Filter

  require Membrane.Logger

  alias __MODULE__.{Header, Segment}
  alias Membrane.{Buffer, Time}
  alias Membrane.MP4.Payload.{AAC, AVC1}
  alias Membrane.MP4.{Helper, Track}

  # Input pads are requested dynamically (one per muxed track, see `handle_pad_added/3`)
  # and accept streams described by `Membrane.MP4.Payload` caps.
  def_input_pad :input,
    availability: :on_request,
    demand_unit: :buffers,
    caps: Membrane.MP4.Payload

  # The single output pad; its caps are built in `generate_output_caps/1`.
  def_output_pad :output, caps: Membrane.CMAF.Track

  # `segment_duration` is forwarded to `Segment.Helper.get_segment/2` when regular segments are
  # collected. NOTE(review): presumably the target duration of one CMAF segment - confirm in the helper.
  def_options segment_duration: [
                type: :time,
                spec: Membrane.Time.t(),
                default: 2 |> Time.seconds()
              ]

  # Builds the initial element state: the options converted to a plain map and extended
  # with the muxer's bookkeeping fields.
  @impl true
  def handle_init(options) do
    bookkeeping = %{
      # Sequence number assigned to the next generated segment
      seq_num: 0,
      # Output caps postponed until a discontinuity segment is produced. `nil` when nothing is
      # pending, otherwise `{{:update_with_next, pad}, caps}` and, once the next buffer arrived
      # on that pad, `{duration, caps}`
      awaiting_caps: nil,
      # Per-pad track bookkeeping, populated when an input pad is added
      pad_to_track_data: %{},
      # ID for the next input track
      next_track_id: 1,
      # Per-pad queues of samples waiting to be put into a segment
      samples: %{}
    }

    {:ok, options |> Map.from_struct() |> Map.merge(bookkeeping)}
  end

  # Adding an input pad while the element is already playing is not supported.
  @impl true
  def handle_pad_added(_pad, %{playback_state: :playing}, _state) do
    raise(
      "New tracks can be added to #{inspect(__MODULE__)} only before transition to state: :playing"
    )
  end

  # Registers a new input pad: assigns it the next free track ID and creates
  # its (initially empty) track bookkeeping entry and sample queue.
  @impl true
  def handle_pad_added(Pad.ref(:input, _id) = pad, _ctx, state) do
    assigned_id = state.next_track_id

    initial_track_data = %{
      id: assigned_id,
      track: nil,
      elapsed_time: 0,
      end_timestamp: 0,
      buffer_awaiting_duration: nil
    }

    state = %{state | next_track_id: assigned_id + 1}
    state = put_in(state, [:pad_to_track_data, pad], initial_track_data)
    state = put_in(state, [:samples, pad], [])

    {:ok, state}
  end

  # Requests a single buffer from the input pad whose `end_timestamp` is the lowest.
  # Pads with a `nil` `end_timestamp` (set once a pad reaches end of stream while
  # others are still active) are not considered.
  @impl true
  def handle_demand(:output, _size, _unit, _ctx, state) do
    {most_delayed_pad, _track_data} =
      state.pad_to_track_data
      |> Enum.reject(fn {_pad, track_data} -> is_nil(track_data.end_timestamp) end)
      |> Enum.min_by(fn {_pad, track_data} -> Ratio.to_float(track_data.end_timestamp) end)

    {{:ok, demand: {most_delayed_pad, 1}}, state}
  end

  # Rebuilds the track description of `pad` from the received caps and, once every other
  # input pad has caps as well, decides what to do with the resulting output caps:
  #   * no output caps sent yet      -> send them right away,
  #   * output caps changed          -> postpone them until the next buffer arrives on `pad`,
  #   * output caps stayed the same  -> nothing to do.
  @impl true
  def handle_caps(pad, %Membrane.MP4.Payload{} = caps, ctx, state) do
    track_fields =
      caps
      |> Map.from_struct()
      |> Map.take([:width, :height, :content, :timescale])

    state =
      update_in(state, [:pad_to_track_data, pad], fn track_data ->
        track = track_fields |> Map.put(:id, track_data.id) |> Track.new()
        %{track_data | track: track}
      end)

    other_inputs_have_caps? =
      ctx.pads
      |> Map.drop([:output, pad])
      |> Enum.all?(fn {_ref, pad_data} -> pad_data.caps != nil end)

    if other_inputs_have_caps? do
      new_caps = generate_output_caps(state)

      case ctx.pads.output.caps do
        nil ->
          {{:ok, caps: {:output, new_caps}}, state}

        current_caps when current_caps == new_caps ->
          {:ok, state}

        _outdated_caps ->
          {:ok, %{state | awaiting_caps: {{:update_with_next, pad}, new_caps}}}
      end
    else
      {:ok, state}
    end
  end

  # Queues the incoming sample and tries to collect a segment. When caps are awaiting,
  # a discontinuity segment is requested instead of a regular one and the new caps are
  # sent right after the segment buffer. The output is always redemanded.
  @impl true
  def handle_process(Pad.ref(:input, _id) = pad, sample, ctx, state) do
    use Ratio, comparison: true

    state = process_buffer_awaiting_duration(state, pad, sample)
    state = update_awaiting_caps(state, pad)

    # NOTE(review): if the awaiting caps are still tagged `{:update_with_next, other_pad}`
    # (the buffer arrived on a different pad), that tag is passed on as `duration` -
    # confirm `Segment.Helper.get_discontinuity_segment/2` copes with it.
    {collected, caps_actions} =
      case state.awaiting_caps do
        nil ->
          {Segment.Helper.get_segment(state, state.segment_duration), []}

        {duration, pending_caps} ->
          {Segment.Helper.get_discontinuity_segment(state, duration),
           [caps: {:output, pending_caps}]}
      end

    case collected do
      {:error, :not_enough_data} ->
        {{:ok, redemand: :output}, state}

      {:ok, segment, state} ->
        {buffer, state} = generate_segment(segment, ctx, state)
        {{:ok, [{:buffer, {:output, buffer}} | caps_actions] ++ [redemand: :output]}, state}
    end
  end

  # Flushes the buffer that was still waiting for its duration. As no further buffer will
  # arrive on this pad, it takes over the duration of the sample at the head of the pad's queue.
  # When all other inputs have finished too, the remaining samples are emitted as the
  # final segment followed by end of stream; otherwise the pad is excluded from demand
  # selection by clearing its `end_timestamp`.
  #
  # NOTE(review): this raises when the pad ends without having received any buffer
  # (`buffer_awaiting_duration` is `nil`) or when its sample queue is empty - confirm
  # whether such streams can occur.
  @impl true
  def handle_end_of_stream(Pad.ref(:input, _track_id) = pad, ctx, state) do
    last_buffer = state.pad_to_track_data[pad].buffer_awaiting_duration
    last_metadata = last_buffer.metadata
    inherited_duration = hd(state.samples[pad]).metadata.duration

    last_buffer = %Buffer{
      last_buffer
      | metadata: Map.put(last_metadata, :duration, inherited_duration)
    }

    state = update_in(state, [:samples, pad], fn queue -> [last_buffer | queue] end)

    other_inputs_finished? =
      ctx.pads
      |> Map.drop([:output, pad])
      |> Enum.all?(fn {_ref, pad_data} -> pad_data.end_of_stream? end)

    if other_inputs_finished? do
      with {:ok, segment, state} <- Segment.Helper.take_all_samples(state) do
        {final_buffer, state} = generate_segment(segment, ctx, state)
        {{:ok, buffer: {:output, final_buffer}, end_of_stream: :output}, state}
      else
        {:error, :not_enough_data} -> {{:ok, end_of_stream: :output}, state}
      end
    else
      {{:ok, redemand: :output},
       put_in(state, [:pad_to_track_data, pad, :end_timestamp], nil)}
    end
  end

  # Builds the output caps from all registered tracks: the serialized header and the content
  # type, which is a single atom for one track and a list of atoms for multiple tracks.
  defp generate_output_caps(state) do
    tracks = for {_pad, track_data} <- state.pad_to_track_data, do: track_data.track

    serialized_header = Header.serialize(tracks)

    track_kinds =
      Enum.map(tracks, fn
        %{content: %AAC{}} -> :audio
        %{content: %AVC1{}} -> :video
      end)

    content_type =
      case track_kinds do
        [only_kind] -> only_kind
        kinds -> kinds
      end

    %Membrane.CMAF.Track{content_type: content_type, header: serialized_header}
  end

  # Serializes the samples collected for each track into a single segment buffer.
  #
  # `acc` is an enumerable of `{pad, samples}` pairs where `samples` is a non-empty list of
  # buffers that carry `dts`, `pts`, `payload` and a `:duration` metadata entry.
  #
  # Returns `{buffer, state}`: the buffer holds the serialized segment and its duration in the
  # metadata, the state has the per-track `elapsed_time` counters advanced and `seq_num` bumped.
  defp generate_segment(acc, ctx, state) do
    use Ratio, comparison: true

    tracks_data =
      Enum.map(acc, fn {pad, samples} ->
        %{timescale: timescale} = ctx.pads[pad].caps
        first_sample = hd(samples)
        last_sample = List.last(samples)

        samples_table =
          Enum.map(samples, fn sample ->
            %{
              sample_size: byte_size(sample.payload),
              sample_flags: generate_sample_flags(sample.metadata),
              sample_duration:
                Helper.timescalify(
                  sample.metadata.duration,
                  timescale
                )
                |> Ratio.trunc(),
              # The composition offset (pts - dts) is a time span expressed in the same unit
              # as the sample durations, so it has to be converted to the track timescale
              # the same way they are, instead of being divided by the timescale.
              sample_offset:
                Helper.timescalify(sample.pts - sample.dts, timescale)
                |> Ratio.floor()
            }
          end)

        samples_data = Enum.map_join(samples, & &1.payload)

        # Time covered by the samples: from the first dts up to the end of the last sample
        duration = last_sample.dts - first_sample.dts + last_sample.metadata.duration

        %{
          pad: pad,
          id: state.pad_to_track_data[pad].id,
          sequence_number: state.seq_num,
          elapsed_time:
            Helper.timescalify(state.pad_to_track_data[pad].elapsed_time, timescale)
            |> Ratio.trunc(),
          unscaled_duration: duration,
          duration: Helper.timescalify(duration, timescale),
          timescale: timescale,
          samples_table: samples_table,
          samples_data: samples_data
        }
      end)

    payload = Segment.serialize(tracks_data)

    # Duration of the tracks will never be exactly the same. To minimize the error and avoid its magnification over time,
    # duration of the segment is assumed to be the average of tracks' durations.
    duration =
      tracks_data
      |> Enum.map(&Ratio.to_float(&1.unscaled_duration))
      |> then(&(Enum.sum(&1) / length(&1)))
      |> floor()

    buffer = %Buffer{payload: payload, metadata: %{duration: duration}}

    # Update elapsed time counters for each track
    state =
      Enum.reduce(tracks_data, state, fn %{unscaled_duration: duration, pad: pad}, state ->
        update_in(state, [:pad_to_track_data, pad, :elapsed_time], &(&1 + duration))
      end)
      |> Map.update!(:seq_num, &(&1 + 1))

    {buffer, state}
  end

  # Builds the 32-bit sample flags field. A sample is treated as a key frame unless its
  # metadata contains `mp4_payload: %{key_frame?: false}` (or another falsy value).
  # Only `depends_on` and the non-sync bit vary; all remaining fields are zeroed.
  defp generate_sample_flags(metadata) do
    payload_info = Map.get(metadata, :mp4_payload, %{})

    # depends_on: 2 for key frames, 1 otherwise; non_sync: 0 for key frames, 1 otherwise
    {depends_on, non_sync} =
      if Map.get(payload_info, :key_frame?, true), do: {2, 0}, else: {1, 1}

    # Layout: reserved::4, is_leading::2, depends_on::2, is_depended_on::2,
    # has_redundancy::2, padding_value::3, non_sync::1, degradation_priority::16
    <<0::4, 0::2, depends_on::2, 0::2, 0::2, 0::3, non_sync::1, 0::16>>
  end

  # Stores `sample` as the buffer awaiting its duration. If another buffer was already awaiting,
  # its duration is set to the dts difference between the two and it is moved to the pad's
  # sample queue; its dts also becomes the pad's `end_timestamp`.
  defp process_buffer_awaiting_duration(state, pad, sample) do
    use Ratio

    case state.pad_to_track_data[pad].buffer_awaiting_duration do
      nil ->
        put_in(state, [:pad_to_track_data, pad, :buffer_awaiting_duration], sample)

      awaiting ->
        duration = Ratio.to_float(sample.dts - awaiting.dts)
        completed = %Buffer{awaiting | metadata: Map.put(awaiting.metadata, :duration, duration)}

        track_updates = %{end_timestamp: completed.dts, buffer_awaiting_duration: sample}

        state
        |> update_in([:pad_to_track_data, pad], &Map.merge(&1, track_updates))
        |> update_in([:samples, pad], fn queue -> [completed | queue] end)
    end
  end

  # The duration of a segment preceding a discontinuity cannot be known before the next sample
  # arrives on the pad that caused it. Once that sample is here, the `{:update_with_next, pad}`
  # tag is replaced with the dts distance between the buffer awaiting its duration and the last
  # element of the pad's sample queue. In any other situation the state is returned untouched.
  defp update_awaiting_caps(%{awaiting_caps: {{:update_with_next, pad}, caps}} = state, pad) do
    use Ratio

    awaiting_buffer = state.pad_to_track_data[pad].buffer_awaiting_duration
    queue_tail = List.last(state.samples[pad])
    duration = awaiting_buffer.dts - queue_tail.dts

    Map.put(state, :awaiting_caps, {duration, caps})
  end

  defp update_awaiting_caps(state, _pad), do: state
end