Skip to main content

lib/air_play/v2/rtsp2.ex

defmodule AirPlay.V2.Rtsp2 do
  @moduledoc """
  RTSP client for AirPlay 2.

  Starts as ordinary RTSP for `/pair-setup`, then can be upgraded to the AP2
  ChaCha20-Poly1305 secure channel after transient pairing succeeds.
  """

  import Bitwise, only: [&&&: 2, |||: 2]

  alias AirPlay.V2.SecureChannel

  @user_agent "AirPlay/690.7.1"
  # Applied to the TCP connect and to each individual `recv`, not to a whole request.
  @timeout 5_000

  defstruct [
    :sock,
    :host,
    :port,
    :session_id,
    :rtsp_session,
    # SecureChannel state installed by `enable_encryption/2`; `nil` means plaintext RTSP.
    :secure,
    # DACP remote-control identity. When set, every request advertises where the
    # receiver should send on-device transport commands (play/pause/next/volume).
    :dacp_id,
    :active_remote,
    cseq: 0,
    # Ciphertext left over by `SecureChannel.decrypt_available/2` (secure mode only).
    cipher_buffer: <<>>,
    # Decrypted bytes not yet consumed by a parsed response (secure mode only).
    plain_buffer: <<>>
  ]

  @type t :: %__MODULE__{}

  @doc """
  Opens a TCP connection to the receiver at `host:port` (default port 7000).

  The socket is opened in binary, passive mode with a #{@timeout} ms connect timeout,
  and a random UUID-formatted session id is generated for `session_url/1`.
  Returns whatever error `:gen_tcp.connect/4` reports on failure.
  """
  @spec connect(String.t(), :inet.port_number()) :: {:ok, t()} | {:error, term()}
  def connect(host, port \\ 7000) do
    with {:ok, sock} <-
           :gen_tcp.connect(String.to_charlist(host), port, [:binary, active: false], @timeout) do
      {:ok,
       %__MODULE__{
         sock: sock,
         host: host,
         port: port,
         session_id: random_session_id()
       }}
    end
  end

  @doc "Closes the underlying TCP socket."
  @spec close(t()) :: :ok
  def close(%__MODULE__{sock: sock}), do: :gen_tcp.close(sock)

  @doc """
  Switches the connection to the secure channel derived from `shared_secret`.

  After this call, `request/5` encrypts outgoing requests and decrypts responses.
  """
  @spec enable_encryption(t(), binary()) :: t()
  def enable_encryption(%__MODULE__{} = state, shared_secret) do
    %{state | secure: SecureChannel.control(shared_secret)}
  end

  @doc "Returns the `rtsp://host/session_id` URL for this connection."
  @spec session_url(t()) :: String.t()
  def session_url(%__MODULE__{} = state), do: "rtsp://#{state.host}/#{state.session_id}"

  @doc """
  Sends one RTSP request and waits for its response.

  The CSeq is incremented. `Session`, `Content-Length` and DACP headers are added
  when applicable. `headers` (a list of `{name, value}` pairs) are appended after the
  built-in ones. The message is encrypted once `enable_encryption/2` has been called.

  On success returns `{:ok, status, headers, body, state}`. Response header names are
  lower-cased. The returned `state` carries the new CSeq, the cipher state, and any
  `Session` id from the response.

  Any non-matching result from `:gen_tcp` or `SecureChannel.decrypt_available/2` is
  returned unchanged. In that case the advanced CSeq and cipher state are not
  returned to the caller.
  """
  @spec request(t(), String.t(), String.t(), list(), binary()) ::
          {:ok, non_neg_integer(), map(), binary(), t()} | {:error, term()}
  def request(%__MODULE__{} = state, method, path, headers \\ [], body \\ <<>>) do
    state = %{state | cseq: state.cseq + 1}
    message = build_request(state, method, path, headers, body)
    {wire_message, state} = maybe_encrypt(state, message)

    with :ok <- :gen_tcp.send(state.sock, wire_message),
         {:ok, status, response_headers, response_body, state} <- recv_response(state) do
      {:ok, status, response_headers, response_body, maybe_store_session(state, response_headers)}
    end
  end

  @doc "Parse a buffered RTSP response, returning `:more` until the body is complete."
  @spec parse(binary()) :: {:ok, non_neg_integer(), map(), binary(), binary()} | :more
  def parse(buffer) when is_binary(buffer), do: try_parse_response(buffer)

  # RFC 2326 §12.37: a response's Session header may carry `;timeout=N`, but requests
  # must echo only the session id, so any parameters are stripped before storing.
  defp maybe_store_session(state, %{"session" => session}) do
    [id | _params] = String.split(session, ";", parts: 2)

    case String.trim(id) do
      "" -> state
      id -> %{state | rtsp_session: id}
    end
  end

  defp maybe_store_session(state, _headers), do: state

  # Serialises the request line, built-in headers, caller headers and body.
  defp build_request(state, method, path, headers, body) do
    base = [
      {"CSeq", Integer.to_string(state.cseq)},
      {"User-Agent", @user_agent}
    ]

    base = if state.rtsp_session, do: base ++ [{"Session", state.rtsp_session}], else: base

    base =
      if body == <<>>,
        do: base,
        else: base ++ [{"Content-Length", Integer.to_string(byte_size(body))}]

    base = base ++ dacp_headers(state)

    lines =
      ["#{method} #{path} RTSP/1.0"] ++
        Enum.map(base ++ headers, fn {key, value} -> "#{key}: #{value}" end)

    IO.iodata_to_binary([Enum.join(lines, "\r\n"), "\r\n\r\n", body])
  end

  # DACP headers are only sent when both identifiers are configured.
  defp dacp_headers(%{dacp_id: id, active_remote: remote})
       when is_binary(id) and is_binary(remote) do
    [{"DACP-ID", id}, {"Active-Remote", remote}]
  end

  defp dacp_headers(_state), do: []

  defp maybe_encrypt(%{secure: nil} = state, message), do: {message, state}

  defp maybe_encrypt(%{secure: secure} = state, message) do
    {encrypted, secure} = SecureChannel.encrypt(secure, message)
    {encrypted, %{state | secure: secure}}
  end

  # Plaintext mode reads straight from the socket; secure mode goes through the
  # decrypt buffers kept on the struct.
  defp recv_response(%{secure: nil} = state) do
    with {:ok, head, body0} <- recv_until_headers(state.sock, <<>>),
         {:ok, status, headers, body} <- parse_complete_response(head, body0, state.sock) do
      {:ok, status, headers, body, state}
    end
  end

  defp recv_response(state) do
    recv_secure_response(state, state.plain_buffer, state.cipher_buffer)
  end

  # Decrypts socket data until `plain` holds one complete response; leftovers are
  # stored back on the state for the next call.
  defp recv_secure_response(state, plain, cipher) do
    case try_parse_response(plain) do
      {:ok, status, headers, body, rest} ->
        {:ok, status, headers, body, %{state | plain_buffer: rest, cipher_buffer: cipher}}

      :more ->
        with {:ok, data} <- :gen_tcp.recv(state.sock, 0, @timeout),
             {:ok, decrypted, cipher, secure} <-
               SecureChannel.decrypt_available(state.secure, cipher <> data) do
          recv_secure_response(%{state | secure: secure}, plain <> decrypted, cipher)
        end
    end
  end

  defp recv_until_headers(sock, acc) do
    case :binary.split(acc, "\r\n\r\n") do
      [head, body] ->
        {:ok, head, body}

      [_] ->
        case :gen_tcp.recv(sock, 0, @timeout) do
          {:ok, data} -> recv_until_headers(sock, acc <> data)
          error -> error
        end
    end
  end

  defp parse_complete_response(head, body0, sock) do
    {status, headers} = parse_head(head)

    with {:ok, body} <- recv_body(sock, content_length(headers), body0) do
      {:ok, status, headers, body}
    end
  end

  defp try_parse_response(buffer) do
    case :binary.split(buffer, "\r\n\r\n") do
      [head, body0] ->
        {status, headers} = parse_head(head)
        length = content_length(headers)

        if byte_size(body0) >= length do
          {body, rest} = :erlang.split_binary(body0, length)
          {:ok, status, headers, body, rest}
        else
          :more
        end

      [_] ->
        :more
    end
  end

  # Missing, non-numeric or negative Content-Length values all mean "no body";
  # a negative length would otherwise crash the body split.
  defp content_length(headers) do
    headers |> Map.get("content-length", "0") |> to_int() |> max(0)
  end

  defp recv_body(_sock, length, acc) when byte_size(acc) >= length do
    {:ok, binary_part(acc, 0, length)}
  end

  defp recv_body(sock, length, acc) do
    case :gen_tcp.recv(sock, length - byte_size(acc), @timeout) do
      {:ok, data} -> recv_body(sock, length, acc <> data)
      error -> error
    end
  end

  # Returns `{status, headers}`; header names are trimmed and lower-cased, and lines
  # without a `:` (including blank ones) are ignored rather than crashing.
  # Duplicate headers keep the last value.
  defp parse_head(head) do
    [status_line | header_lines] = String.split(head, "\r\n")
    status = status_line |> String.split(" ", parts: 3) |> Enum.at(1) |> to_int()

    headers =
      for line <- header_lines,
          [key, value] <- [String.split(line, ":", parts: 2)],
          into: %{} do
        {key |> String.trim() |> String.downcase(), String.trim(value)}
      end

    {status, headers}
  end

  # Random version-4 UUID string, upper-cased.
  defp random_session_id do
    <<a::32, b::16, c0::16, d0::16, e::48>> = :crypto.strong_rand_bytes(16)
    c = (c0 &&& 0x0FFF) ||| 0x4000
    d = (d0 &&& 0x3FFF) ||| 0x8000

    [a, b, c, d, e]
    |> Enum.zip([8, 4, 4, 4, 12])
    |> Enum.map_join("-", fn {part, width} ->
      part |> Integer.to_string(16) |> String.pad_leading(width, "0")
    end)
    |> String.upcase()
  end

  # Lenient integer parse: leading digits win, anything unparsable (incl. nil) is 0.
  defp to_int(value) do
    case Integer.parse(to_string(value)) do
      {int, _rest} -> int
      :error -> 0
    end
  end
end