# lib/air_play/v2/secure_channel.ex

defmodule AirPlay.V2.SecureChannel do
  @moduledoc """
  AirPlay 2 encrypted RTSP control/event channel framing.

  Plain RTSP bytes are split into chunks of at most 1024 bytes. Each encrypted
  chunk is `uint16_le(length) <> ciphertext <> tag`, where the 2-byte length
  prefix is used as the AAD and `tag` is the 16-byte ChaCha20-Poly1305 tag.

  The 12-byte nonce is four zero bytes followed by a 64-bit little-endian chunk
  counter. The write and read directions keep independent counters, each
  starting at 0 and advanced by one per chunk.
  """

  alias AirPlay.V2.Crypto

  # `encrypt_key`/`write_key`, `decrypt_key`/`read_key`, `encrypt_counter`/`write_ctr`
  # and `decrypt_counter`/`read_ctr` are always set to the same values; both
  # spellings are kept so existing readers of either field keep working.
  defstruct [
    :encrypt_key,
    :decrypt_key,
    :write_key,
    :read_key,
    encrypt_counter: 0,
    decrypt_counter: 0,
    write_ctr: 0,
    read_ctr: 0,
    cipher_buffer: <<>>
  ]

  @block_max 1024
  @tag_len 16

  @typedoc "Per-direction keys and chunk counters plus any buffered partial ciphertext chunk."
  @type t :: %__MODULE__{
          encrypt_key: binary() | nil,
          decrypt_key: binary() | nil,
          write_key: binary() | nil,
          read_key: binary() | nil,
          encrypt_counter: non_neg_integer(),
          decrypt_counter: non_neg_integer(),
          write_ctr: non_neg_integer(),
          read_ctr: non_neg_integer(),
          cipher_buffer: binary()
        }

  @doc "Create a client-side secure control channel from the SRP shared secret."
  @spec control(binary()) :: t()
  def control(shared_secret), do: new(shared_secret, :control)

  @doc """
  Create a secure channel for `:control` or `:events`.

  Derives a 32-byte write key and a 32-byte read key from `shared_secret` via
  `Crypto.hkdf_sha512/4`, using the salt/info strings selected by `channel`.
  Both chunk counters start at 0 and the ciphertext buffer starts empty.
  """
  @spec new(binary(), :control | :events) :: t()
  def new(shared_secret, channel) do
    {write_salt, write_info, read_salt, read_info} = salts(channel)

    write_key = Crypto.hkdf_sha512(shared_secret, 32, write_salt, write_info)
    read_key = Crypto.hkdf_sha512(shared_secret, 32, read_salt, read_info)

    %__MODULE__{
      encrypt_key: write_key,
      decrypt_key: read_key,
      write_key: write_key,
      read_key: read_key
    }
  end

  @doc "Create a control secure channel from a shared secret."
  @spec new(binary()) :: t()
  def new(shared_secret), do: new(shared_secret, :control)

  @doc """
  Encrypt plaintext into AP2 secure-channel chunks.

  Returns `{ciphertext, channel}` where `ciphertext` is the concatenation of all
  framed chunks and the write counter has been advanced once per chunk. Empty
  plaintext yields `<<>>` and an unchanged channel.
  """
  @spec encrypt(t(), binary()) :: {binary(), t()}
  def encrypt(%__MODULE__{} = channel, plaintext) do
    do_encrypt(channel, plaintext, [])
  end

  @doc """
  Decrypt as many complete chunks as are available.

  Returns `{:ok, plaintext, rest, channel}`. `plaintext` holds the decrypted
  bytes of every complete chunk. `rest` holds the unconsumed trailing bytes:
  a partial chunk, or `<<>>`. `channel.cipher_buffer` is neither read nor
  written here. Returns `{:error, :decrypt_failed}` if a chunk fails
  authentication.
  """
  @spec decrypt_available(t(), binary()) :: {:ok, binary(), binary(), t()} | {:error, term()}
  def decrypt_available(%__MODULE__{} = channel, ciphertext) do
    do_decrypt(channel, ciphertext, [])
  end

  @doc """
  Decrypt AP2 secure-channel chunks, buffering an incomplete trailing chunk.

  `ciphertext` is appended to any bytes buffered by a previous call. On success,
  returns `{plaintext, channel}`. The leftover partial chunk is stored in
  `channel.cipher_buffer` for the next call. If any chunk fails authentication,
  returns `{:error, channel}` with the channel exactly as it was passed in: the
  buffer and counters are unchanged, and plaintext from earlier chunks in this
  call is discarded.
  """
  @spec decrypt(t(), binary()) :: {binary(), t()} | {:error, t()}
  def decrypt(%__MODULE__{cipher_buffer: buffered} = channel, ciphertext) do
    case decrypt_available(channel, buffered <> ciphertext) do
      {:ok, plaintext, rest, updated} -> {plaintext, %{updated | cipher_buffer: rest}}
      {:error, _reason} -> {:error, channel}
    end
  end

  defp do_encrypt(channel, <<>>, acc), do: {collect(acc), channel}

  defp do_encrypt(%__MODULE__{write_ctr: ctr} = channel, plaintext, acc) do
    size = min(byte_size(plaintext), @block_max)
    <<chunk::binary-size(size), rest::binary>> = plaintext
    aad = <<size::16-little>>

    {encrypted, tag} =
      Crypto.chacha20_poly1305_encrypt(channel.write_key, nonce(ctr), chunk, aad)

    channel = %{channel | encrypt_counter: ctr + 1, write_ctr: ctr + 1}
    do_encrypt(channel, rest, [[aad, encrypted, tag] | acc])
  end

  # Not even a full length prefix yet: return the bytes untouched.
  defp do_decrypt(channel, ciphertext, acc) when byte_size(ciphertext) < 2 do
    {:ok, collect(acc), ciphertext, channel}
  end

  # Length prefix present but body/tag incomplete: keep the whole partial chunk,
  # prefix included, so it can be re-parsed once more bytes arrive.
  defp do_decrypt(%__MODULE__{} = channel, <<size::16-little, rest::binary>> = ciphertext, acc)
       when byte_size(rest) < size + @tag_len do
    {:ok, collect(acc), ciphertext, channel}
  end

  defp do_decrypt(%__MODULE__{read_ctr: ctr} = channel, <<size::16-little, rest::binary>>, acc) do
    <<encrypted::binary-size(size), tag::binary-size(@tag_len), tail::binary>> = rest
    aad = <<size::16-little>>

    case Crypto.chacha20_poly1305_decrypt(channel.read_key, nonce(ctr), encrypted, tag, aad) do
      {:ok, plaintext} ->
        channel = %{channel | decrypt_counter: ctr + 1, read_ctr: ctr + 1}
        do_decrypt(channel, tail, [plaintext | acc])

      :error ->
        {:error, :decrypt_failed}
    end
  end

  # 12-byte nonce: four zero bytes followed by the 64-bit little-endian chunk counter.
  defp nonce(counter), do: <<0::32-little, counter::64-little>>

  # Flatten the reversed accumulator of chunks into a single binary.
  defp collect(acc), do: acc |> Enum.reverse() |> IO.iodata_to_binary()

  # Returns {write_salt, write_info, read_salt, read_info} for HKDF.
  defp salts(:control) do
    {
      "Control-Salt",
      "Control-Write-Encryption-Key",
      "Control-Salt",
      "Control-Read-Encryption-Key"
    }
  end

  # The events key names are used reversed on this side: we read with the
  # "Events-Write" key and write with the "Events-Read" key. NOTE(review):
  # presumably because the receiver originates the event connection and the key
  # names follow its perspective; confirm against a receiver implementation.
  defp salts(:events) do
    {
      "Events-Salt",
      "Events-Read-Encryption-Key",
      "Events-Salt",
      "Events-Write-Encryption-Key"
    }
  end
end