# lib/amarula/protocol/crypto/noise_handler.ex

defmodule Amarula.Protocol.Crypto.NoiseHandler do
  @moduledoc """
  Noise protocol handler for WhatsApp WebSocket communication.

  This module implements the Noise_XX_25519_AESGCM_SHA256 protocol used by WhatsApp
  for secure WebSocket communication. It provides stateless functions that operate
  on noise state, which is stored in the Connection's GenServer state.

  The noise state is recreated on every new WebSocket connection with fresh ephemeral keys.
  """

  import Bitwise
  require Logger
  alias Amarula.Protocol.Crypto.{Crypto, Constants}

  # Pinned WhatsApp root certificate (Baileys WA_CERT_DETAILS). The Noise_XX
  # handshake's server authentication is only anchored to WhatsApp by verifying the
  # cert chain's intermediate signature against THIS key — without it any cert with
  # issuerSerial=0 would pass.
  #
  # The 64 hex characters decode to a 32-byte key; verify_certificate/1 checks the
  # intermediate certificate's signature against it.
  @wa_cert_public_key Base.decode16!(
                        "142375574D0A587166AAE71EBE516437C4A28B73E3695C6CE1F7F9545DA8EE6B"
                      )
  # Issuer serial the intermediate certificate must carry; verify_certificate/1
  # returns {:error, {:issuer_serial, actual}} when it differs.
  @wa_cert_serial 0

  # Shape of the noise state map threaded through every function in this module.
  #
  #   * ephemeral_key_pair - key pair generated for this connection
  #   * hash               - running handshake hash (set to <<>> by finish_init/1)
  #   * salt               - HKDF salt / chaining key, replaced by mix_into_key/2
  #   * enc_key / dec_key  - current cipher keys (identical until finish_init/1)
  #   * read_counter / write_counter - nonce counters, reset on every key change
  #   * handshake_state    - current phase; mix_into_key/2 sets :handshake, which
  #                          must therefore be part of this union
  #   * sent_intro         - whether the intro header has already been emitted
  #   * in_bytes           - buffered, not yet framed inbound bytes
  #   * routing_info       - optional routing blob placed in the intro header
  #   * noise_header       - WA noise header sent once at the start of the stream
  @type noise_state :: %{
          ephemeral_key_pair: Crypto.key_pair(),
          hash: binary(),
          salt: binary(),
          enc_key: binary(),
          dec_key: binary(),
          read_counter: non_neg_integer(),
          write_counter: non_neg_integer(),
          handshake_state:
            :init
            | :awaiting_server_hello
            | :handshake
            | :handshake_complete
            | :transport,
          sent_intro: boolean(),
          in_bytes: binary(),
          routing_info: binary() | nil,
          noise_header: binary()
        }

  # Result of process_handshake/3: the encrypted client noise key plus the new state,
  # or an error tuple.
  @type handshake_result :: {:ok, binary(), noise_state()} | {:error, term()}
  # Same shape decode_frame/2 returns: decoded frames in arrival order plus the new state.
  @type frame_result :: {:ok, list(binary()), noise_state()}
  # Returned by encrypt/2 and encode_frame/2: output bytes plus the new state.
  @type encrypt_result :: {binary(), noise_state()}
  # Returned by decrypt/2 and decrypt/3.
  @type decrypt_result :: {:ok, binary(), noise_state()} | {:error, term()}

  @doc """
  Builds the noise state for a brand-new connection.

  The running hash is seeded from the noise mode constant: used verbatim when it
  is exactly 32 bytes, SHA-256 hashed otherwise (mirroring Baileys). The same seed
  initialises the salt and both cipher keys. The WA noise header and then the
  ephemeral public key are folded into the hash.

  `opts` may carry `:routing_info`, which is stored on the state.

  Returns a noise state map in the `:init` phase.
  """
  @spec new(Crypto.key_pair(), keyword()) :: noise_state()
  def new(ephemeral_key_pair, opts \\ []) do
    routing = Keyword.get(opts, :routing_info)
    wa_header = Constants.noise_wa_header()
    mode = Constants.noise_mode()

    seed =
      case byte_size(mode) do
        32 -> mode
        _ -> Crypto.sha256(mode)
      end

    initial = %{
      ephemeral_key_pair: ephemeral_key_pair,
      hash: seed,
      salt: seed,
      enc_key: seed,
      dec_key: seed,
      read_counter: 0,
      write_counter: 0,
      handshake_state: :init,
      sent_intro: false,
      in_bytes: <<>>,
      routing_info: routing,
      noise_header: wa_header
    }

    # Header first, then our ephemeral public key — the order is part of the transcript.
    Enum.reduce([wa_header, ephemeral_key_pair.public], initial, fn chunk, acc ->
      authenticate(acc, chunk)
    end)
  end

  @doc """
  Folds `data` into the running handshake hash.

  The hash becomes `SHA256(hash <> data)`. Once the state is in the `:transport`
  phase the hash is no longer maintained and the state is returned untouched.

  Returns the (possibly updated) noise state.
  """
  @spec authenticate(noise_state(), binary()) :: noise_state()
  def authenticate(state, data) do
    case state do
      %{handshake_state: :transport} ->
        state

      _ ->
        %{state | hash: Crypto.sha256(state.hash <> data)}
    end
  end

  @doc """
  Encrypts `plaintext` with the current write key and write counter.

  Outside the transport phase the running hash is used as AAD and the ciphertext
  is folded back into the hash; in the transport phase the AAD is empty.
  Raises when called in the `:init` phase, before any key has been mixed in.

  Returns `{ciphertext, updated_state}` with the write counter advanced by one.
  """
  @spec encrypt(noise_state(), binary()) :: encrypt_result()
  def encrypt(%{handshake_state: :init}, _plaintext) do
    raise "Cannot encrypt before keys are established via mix_into_key"
  end

  def encrypt(state, plaintext) do
    nonce = Crypto.generate_iv(state.write_counter)

    associated_data =
      case state.handshake_state do
        :transport -> <<>>
        _ -> state.hash
      end

    {:ok, ciphertext} = Crypto.aes_encrypt_gcm(plaintext, state.enc_key, nonce, associated_data)

    advanced = authenticate(%{state | write_counter: state.write_counter + 1}, ciphertext)
    {ciphertext, advanced}
  end

  @doc """
  Decrypts `ciphertext` using the AAD implied by the current phase.

  The AAD is empty in the transport phase and the running hash otherwise; the
  actual work is delegated to `decrypt/3`.

  Returns `{:ok, plaintext, updated_state}` or `{:error, reason}`.
  """
  @spec decrypt(noise_state(), binary()) :: decrypt_result()
  def decrypt(state, ciphertext) do
    default_aad =
      case state.handshake_state do
        :transport -> <<>>
        _ -> state.hash
      end

    decrypt(state, ciphertext, default_aad)
  end

  @doc """
  Decrypts `ciphertext` with a caller-supplied AAD (tests use this to pin exact
  handshake semantics).

  The nonce comes from `read_counter` in the transport phase and from
  `write_counter` in every other phase; whichever counter was used is advanced on
  success, and the ciphertext is folded into the running hash. On failure the
  error is logged and the state is left as it was.

  Returns `{:ok, plaintext, updated_state}` or `{:error, reason}`.
  """
  @spec decrypt(noise_state(), binary(), binary()) :: decrypt_result()
  def decrypt(state, ciphertext, aad) do
    # During the handshake a single shared counter (write_counter) drives both directions.
    counter_field =
      if state.handshake_state == :transport, do: :read_counter, else: :write_counter

    nonce = state |> Map.fetch!(counter_field) |> Crypto.generate_iv()

    case Crypto.aes_decrypt_gcm(ciphertext, state.dec_key, nonce, aad) do
      {:ok, plaintext} ->
        advanced =
          state
          |> Map.update!(counter_field, &(&1 + 1))
          |> authenticate(ciphertext)

        {:ok, plaintext, advanced}

      {:error, reason} ->
        Logger.error("Decryption failed: #{inspect(reason)}")
        {:error, reason}
    end
  end

  @doc """
  Mixes key material into the cipher state via HKDF (Noise `MixKey`).

  HKDF is run over `data` with the current salt; the first 32 bytes of the output
  become the new salt and the remainder becomes the cipher key.

  Returns the noise state with the new salt, both keys set to the cipher key,
  both counters reset to 0 and `handshake_state` set to `:handshake`.
  """
  @spec mix_into_key(noise_state(), binary()) :: noise_state()
  def mix_into_key(state, data) do
    okm = Crypto.hkdf(data, Constants.hkdf_output_length(), state.salt, <<>>)
    {chaining_key, handshake_key} = :erlang.split_binary(okm, 32)

    # Handshake messages alternate, so one key serves both directions here; the
    # read/write split only happens in finish_init/1. Counters restart with each new key.
    %{
      state
      | handshake_state: :handshake,
        salt: chaining_key,
        enc_key: handshake_key,
        dec_key: handshake_key,
        write_counter: 0,
        read_counter: 0
    }
  end

  @doc """
  Ends the handshake by deriving the transport keys (Noise `Split`).

  HKDF is run over empty input with the current salt; the first 32 bytes become
  the write (encryption) key and the remainder the read (decryption) key.

  Returns the noise state in the `:transport` phase with an empty hash and both
  counters reset to 0.
  """
  @spec finish_init(noise_state()) :: noise_state()
  def finish_init(state) do
    okm = Crypto.hkdf(<<>>, Constants.hkdf_output_length(), state.salt, <<>>)
    {outbound_key, inbound_key} = :erlang.split_binary(okm, 32)

    Logger.debug("Noise finish_init: split keys, reset counters, entering transport phase")

    %{
      state
      | handshake_state: :transport,
        hash: <<>>,
        enc_key: outbound_key,
        dec_key: inbound_key,
        write_counter: 0,
        read_counter: 0
    }
  end

  @doc """
  Processes the server hello message of the handshake.

  Steps, in order:

    1. fold the server ephemeral key into the hash and mix in the
       ephemeral/ephemeral shared secret,
    2. decrypt the server static key and mix in the ephemeral/static shared secret,
    3. decrypt the certificate payload and verify it with `verify_certificate/1`,
    4. encrypt our noise public key and mix in the static/ephemeral shared secret.

  `noise_key` is the client's long-term noise key pair.

  Returns `{:ok, encrypted_noise_key, updated_state}`. A decryption failure or a
  rejected certificate is returned as `{:error, reason}` instead of raising.
  """
  @spec process_handshake(noise_state(), map(), Crypto.key_pair()) :: handshake_result()
  def process_handshake(state, %{serverHello: server_hello}, noise_key) do
    ephemeral_private = state.ephemeral_key_pair.private

    # Server ephemeral goes into the transcript before any key derivation.
    state = authenticate(state, server_hello.ephemeral)

    shared_key = Crypto.shared_key(ephemeral_private, server_hello.ephemeral)
    state = mix_into_key(state, shared_key)
    state = %{state | handshake_state: :awaiting_server_hello}

    # Any {:error, reason} from decrypt/2 or verify_certificate/1 falls straight
    # through the `with` and becomes the return value.
    with {:ok, server_static, state} <- decrypt(state, server_hello.static),
         shared_static = Crypto.shared_key(ephemeral_private, server_static),
         state = mix_into_key(state, shared_static),
         {:ok, cert_decoded, state} <- decrypt(state, server_hello.payload),
         :ok <- verify_certificate(cert_decoded) do
      {key_encrypted, state} = encrypt(state, noise_key.public)

      # Final mix: our static noise key against the server ephemeral.
      noise_shared = Crypto.shared_key(noise_key.private, server_hello.ephemeral)

      {:ok, key_encrypted, mix_into_key(state, noise_shared)}
    end
  end

  @doc """
  Wraps `data` in a wire frame, encrypting it first when in the transport phase.

  Every frame is a 3-byte big-endian length followed by the payload. The very
  first frame of a connection is additionally prefixed with the intro header:
  the noise header alone, or — when routing info is present — `"ED", 0, 1`, the
  routing info length as a 3-byte big-endian integer, the routing info, and then
  the noise header (7 bytes before the routing info, matching Baileys).

  Returns `{frame_binary, updated_state}` with `sent_intro` set to `true`.
  """
  @spec encode_frame(noise_state(), binary()) :: encrypt_result()
  def encode_frame(state, data) do
    # Handshake messages travel in the clear; only transport frames are encrypted.
    {payload, state} =
      case state.handshake_state do
        :transport -> encrypt(state, data)
        _ -> {data, state}
      end

    intro =
      if state.routing_info do
        # The length field is 24 bits wide, not 16: a 16-bit field shifts the
        # routing info by one byte relative to the layout the server parses.
        <<"ED", 0, 1, byte_size(state.routing_info)::24, state.routing_info::binary,
          state.noise_header::binary>>
      else
        state.noise_header
      end

    # ::24 keeps the low 24 bits, i.e. the same three big-endian length bytes.
    length_prefix = <<byte_size(payload)::24>>

    frame =
      if state.sent_intro do
        length_prefix <> payload
      else
        intro <> length_prefix <> payload
      end

    {frame, %{state | sent_intro: true}}
  end

  @doc """
  Feeds newly received bytes into the frame decoder.

  `new_data` is appended to the bytes buffered on the state, every complete
  frame is extracted (and decrypted when in the transport phase), and any
  trailing partial frame is kept in `in_bytes` for the next call.

  Returns `{:ok, frames, updated_state}` with `frames` in arrival order.
  """
  @spec decode_frame(noise_state(), binary()) :: {:ok, list(binary()), noise_state()}
  def decode_frame(state, new_data) do
    buffered = state.in_bytes <> new_data

    # The helper prepends as it goes, so the list comes back newest-first.
    {leftover, newest_first, advanced} = process_frames(buffered, [], state)

    {:ok, Enum.reverse(newest_first), %{advanced | in_bytes: leftover}}
  end

  # Private helper functions

  # Fewer than three bytes buffered: the length prefix itself is still incomplete.
  defp process_frames(in_bytes, frames, state) when byte_size(in_bytes) < 3 do
    {in_bytes, frames, state}
  end

  # Peels one length-prefixed frame off the buffer and recurses on the rest.
  # Returns {unconsumed_bytes, frames_newest_first, state}.
  defp process_frames(in_bytes, frames, state) do
    # 24-bit big-endian length prefix: (byte1 << 16) | (byte2 << 8) | byte3.
    <<length::24, rest::binary>> = in_bytes

    Logger.debug(
      "Processing Noise frame: length=#{length}, remaining_bytes=#{byte_size(rest)}, frames_so_far=#{length(frames)}"
    )

    if byte_size(rest) >= length do
      {frame_data, remaining} = :erlang.split_binary(rest, length)

      Logger.debug(
        "Extracted frame data, remaining=#{byte_size(remaining)} bytes after this frame"
      )

      # State is threaded through because decrypting advances the read counter.
      {processed_frame, updated_state} = decrypt_if_transport(state, frame_data)

      process_frames(remaining, [processed_frame | frames], updated_state)
    else
      # Partial frame: keep the whole buffer, length prefix included, for the next call.
      Logger.debug("Not enough data for complete frame, waiting for more")
      {in_bytes, frames, state}
    end
  end

  # Decrypts a frame when the state is in the transport phase; in any other phase
  # the frame is passed through unchanged. A transport decryption failure raises
  # rather than handing the still-encrypted bytes to the caller.
  defp decrypt_if_transport(state, frame_data) do
    if match?(%{handshake_state: :transport}, state) do
      case decrypt(state, frame_data) do
        {:ok, plaintext, advanced} ->
          {plaintext, advanced}

        {:error, reason} ->
          raise "Transport phase decryption failed: #{inspect(reason)}"
      end
    else
      {frame_data, state}
    end
  end

  # Validates the server's Noise certificate chain (after Baileys noise-handler.ts).
  # All three checks must pass, evaluated in this order:
  #   1. the leaf's signature verifies under the key carried by the intermediate,
  #   2. the intermediate's signature verifies under the pinned WhatsApp root key
  #      (@wa_cert_public_key),
  #   3. the intermediate's issuerSerial equals the pinned serial (@wa_cert_serial).
  # Check 2 is what ties the server to WhatsApp; check 1 links the leaf to it.
  # Signatures are XEd25519 (Crypto.verify = XEdDSA) over a 32-byte Montgomery key;
  # a leading 0x05 type byte is stripped by verify_sig/4.
  # Anything that is not the expected shape comes back as
  # {:error, {:invalid_certificate, value}}.
  @doc false
  @spec verify_certificate(binary()) :: :ok | {:error, term()}
  def verify_certificate(cert_decoded) do
    alias Amarula.Protocol.Proto.CertChain
    alias Amarula.Protocol.Proto.CertChain.NoiseCertificate.Details

    with %{intermediate: inter, leaf: leaf_cert} when inter != nil and leaf_cert != nil <-
           CertChain.decode(cert_decoded),
         %{details: inter_body, signature: inter_sig}
         when is_binary(inter_body) and is_binary(inter_sig) <- inter,
         %{details: leaf_body, signature: leaf_sig}
         when is_binary(leaf_body) and is_binary(leaf_sig) <- leaf_cert,
         issuer = Details.decode(inter_body),
         :ok <- verify_sig(leaf_body, leaf_sig, issuer.key, :leaf_signature_invalid),
         :ok <-
           verify_sig(inter_body, inter_sig, @wa_cert_public_key, :intermediate_signature_invalid),
         :ok <-
           check(issuer.issuerSerial == @wa_cert_serial, {:issuer_serial, issuer.issuerSerial}) do
      :ok
    else
      {:error, _reason} = failure -> failure
      unexpected -> {:error, {:invalid_certificate, unexpected}}
    end
  end

  # Verifies `signature` over `data` with `public_key` (0x05 prefix stripped first)
  # and maps the boolean outcome to :ok | {:error, reason}.
  defp verify_sig(data, signature, public_key, reason) do
    raw_key = strip5(public_key)

    data
    |> Crypto.verify(signature, raw_key)
    |> check(reason)
  end

  # Turns a strict boolean into :ok | {:error, reason}; any non-boolean first
  # argument is rejected with a FunctionClauseError.
  defp check(passed?, reason) when is_boolean(passed?) do
    if passed?, do: :ok, else: {:error, reason}
  end

  # WhatsApp public keys sometimes arrive as 33 bytes with a leading 0x05 type
  # byte, while XEdDSA wants the bare 32-byte Montgomery key. Only that exact
  # 33-byte shape is trimmed; every other value is returned as given.
  defp strip5(key) do
    case key do
      <<5, raw::binary-size(32)>> -> raw
      other -> other
    end
  end
end