lib/membrane_rtmp_plugin/rtmp/sink/sink.ex

defmodule Membrane.RTMP.Sink do
  @moduledoc """
  Membrane element being client-side of RTMP streams.
  To work successfuly it requires to receive both audio and video streams in AAC and H264 format respectively.
  Currently it supports only:
    - RTMP proper - "plain" RTMP protocol
    - RTMPS - RTMP over TLS/SSL
  other RTMP variants - RTMPT, RTMPE, RTMFP are not supported.
  Implementation based on FFmpeg.
  """
  use Membrane.Sink

  require Membrane.Logger

  alias __MODULE__.Native
  alias Membrane.{AAC, Buffer, MP4}

  @supported_protocols ["rtmp://", "rtmps://"]
  @connection_attempt_interval 500

  def_input_pad :audio,
    availability: :always,
    accepted_format: AAC,
    mode: :pull,
    demand_unit: :buffers

  def_input_pad :video,
    availability: :always,
    accepted_format: MP4.Payload,
    mode: :pull,
    demand_unit: :buffers

  def_options rtmp_url: [
                spec: String.t(),
                description: """
                Destination URL of the stream. It needs to start with rtmp:// or rtmps:// depending on the protocol variant.
                This URL should be provided by your streaming service.
                """
              ],
              max_attempts: [
                spec: pos_integer() | :infinity,
                default: 1,
                description: """
                Maximum number of connection attempts before failing with an error.
                The attempts will happen every #{@connection_attempt_interval} ms
                """
              ]

  @impl true
  def handle_init(_ctx, options) do
    unless String.starts_with?(options.rtmp_url, @supported_protocols) do
      raise ArgumentError, "Invalid destination URL provided"
    end

    unless options.max_attempts == :infinity or
             (is_integer(options.max_attempts) and options.max_attempts >= 1) do
      raise ArgumentError, "Invalid max_attempts option value: #{options.max_attempts}"
    end

    state =
      options
      |> Map.from_struct()
      |> Map.merge(%{
        attempts: 0,
        native: nil,
        # Keys here are the pad names.
        frame_buffer: %{audio: nil, video: nil},
        ready?: false,
        # Activated when one of the source inputs gets closed. Interleaving is
        # disabled, frame buffer is flushed and from that point buffers on the
        # remaining pad are simply forwarded to the output.
        forward_mode?: false
      })

    {[], state}
  end

  @impl true
  def handle_setup(_ctx, state) do
    {:ok, native} = Native.create(state.rtmp_url)

    state
    |> Map.put(:native, native)
    |> try_connect()
    |> then(&{[], &1})
  end

  @impl true
  def handle_playing(_ctx, state) do
    {build_demand(state), state}
  end

  @impl true
  def handle_stream_format(
        :video,
        %MP4.Payload{content: %MP4.Payload.AVC1{avcc: avc_config}} = stream_format,
        _ctx,
        state
      ) do
    case Native.init_video_stream(
           state.native,
           stream_format.width,
           stream_format.height,
           avc_config
         ) do
      {:ok, ready?, native} ->
        Membrane.Logger.debug("Correctly initialized video stream.")
        {[], %{state | native: native, ready?: ready?}}

      {:error, :stream_format_resent} ->
        Membrane.Logger.error(
          "Input stream format redefined on pad :video. RTMP Sink does not support dynamic stream parameters"
        )

        {[], state}
    end
  end

  @impl true
  def handle_stream_format(:audio, %Membrane.AAC{} = stream_format, _ctx, state) do
    profile = AAC.profile_to_aot_id(stream_format.profile)
    sr_index = AAC.sample_rate_to_sampling_frequency_id(stream_format.sample_rate)
    channel_configuration = AAC.channels_to_channel_config_id(stream_format.channels)
    frame_length_id = AAC.samples_per_frame_to_frame_length_id(stream_format.samples_per_frame)

    aac_config =
      <<profile::5, sr_index::4, channel_configuration::4, frame_length_id::1, 0::1, 0::1>>

    case Native.init_audio_stream(
           state.native,
           stream_format.channels,
           stream_format.sample_rate,
           aac_config
         ) do
      {:ok, ready?, native} ->
        Membrane.Logger.debug("Correctly initialized audio stream.")
        {[], %{state | native: native, ready?: ready?}}

      {:error, :stream_format_resent} ->
        Membrane.Logger.error(
          "Input stream format redefined on pad :audio. RTMP Sink does not support dynamic stream parameters"
        )

        {[], state}
    end
  end

  @impl true
  def handle_write(pad, buffer, _ctx, %{ready?: false} = state) do
    {[], fill_frame_buffer(state, pad, buffer)}
  end

  def handle_write(pad, buffer, _ctx, %{forward_mode?: true} = state) do
    {[demand: pad], write_frame(state, pad, buffer)}
  end

  def handle_write(pad, buffer, _ctx, state) do
    state
    |> fill_frame_buffer(pad, buffer)
    |> write_frame_interleaved()
  end

  @impl true
  def handle_end_of_stream(pad, _ctx, state) do
    if state.forward_mode? do
      Native.finalize_stream(state.native)
      {[], state}
    else
      # The interleave logic does not work if either one of the inputs does not
      # produce buffers. From this point on we act as a "forward" filter.
      other_pad =
        case pad do
          :audio -> :video
          :video -> :audio
        end

      state = flush_frame_buffer(state)
      {[demand: other_pad], %{state | forward_mode?: true}}
    end
  end

  defp try_connect(%{attempts: attempts, max_attempts: max_attempts} = state)
       when max_attempts != :infinity and attempts >= max_attempts do
    raise "failed to connect to '#{state.rtmp_url}' #{attempts} times, aborting"
  end

  defp try_connect(state) do
    state = %{state | attempts: state.attempts + 1}

    case Native.try_connect(state.native) do
      :ok ->
        Membrane.Logger.debug("Correctly initialized connection with: #{state.rtmp_url}")

        state

      {:error, error} when error in [:econnrefused, :etimedout] ->
        Membrane.Logger.warn(
          "Connection to #{state.rtmp_url} refused, retrying in #{@connection_attempt_interval}ms"
        )

        Process.sleep(@connection_attempt_interval)

        try_connect(state)

      {:error, reason} ->
        raise "failed to connect to '#{state.rtmp_url}': #{inspect(reason)}"
    end
  end

  defp build_demand(%{frame_buffer: frame_buffer}) do
    frame_buffer
    |> Enum.filter(fn {_pad, buffer} -> buffer == nil end)
    |> Enum.map(fn {pad, _buffer} -> {:demand, pad} end)
  end

  defp fill_frame_buffer(state, pad, buffer) do
    if get_in(state, [:frame_buffer, pad]) == nil do
      put_in(state, [:frame_buffer, pad], buffer)
    else
      raise "attempted to overwrite frame buffer on pad #{inspect(pad)}"
    end
  end

  defp write_frame_interleaved(%{frame_buffer: %{audio: audio, video: video}} = state)
       when audio == nil or video == nil do
    # We still have to wait for the other frame.
    {[], state}
  end

  defp write_frame_interleaved(%{frame_buffer: frame_buffer} = state) do
    {pad, buffer} =
      Enum.min_by(frame_buffer, fn {_pad, buffer} ->
        buffer
        |> Buffer.get_dts_or_pts()
        |> Ratio.ceil()
      end)

    state =
      state
      |> write_frame(pad, buffer)
      |> put_in([:frame_buffer, pad], nil)

    {build_demand(state), state}
  end

  defp flush_frame_buffer(%{frame_buffer: frame_buffer} = state) do
    pads_with_buffer =
      frame_buffer
      |> Enum.filter(fn {_pad, buffer} -> buffer != nil end)
      |> Enum.sort(fn {_, left}, {_, right} ->
        Buffer.get_dts_or_pts(left) <= Buffer.get_dts_or_pts(right)
      end)

    Enum.reduce(pads_with_buffer, state, fn {pad, buffer}, state ->
      state
      |> write_frame(pad, buffer)
      |> put_in([:frame_buffer, pad], nil)
    end)
  end

  defp write_frame(state, :audio, buffer) do
    buffer_pts = buffer.pts |> Ratio.ceil()

    case Native.write_audio_frame(state.native, buffer.payload, buffer_pts) do
      {:ok, native} ->
        Map.put(state, :native, native)

      {:error, reason} ->
        raise "writing audio frame failed with reason: #{inspect(reason)}"
    end
  end

  defp write_frame(state, :video, buffer) do
    case Native.write_video_frame(
           state.native,
           buffer.payload,
           buffer.dts,
           buffer.pts || buffer.dts,
           buffer.metadata.h264.key_frame?
         ) do
      {:ok, native} ->
        Map.put(state, :native, native)

      {:error, reason} ->
        raise "writing video frame failed with reason: #{inspect(reason)}"
    end
  end
end