lib/barytherium/frame.ex

defmodule Barytherium.Frame do
  @moduledoc """
  Representation of a STOMP frame.

  In STOMP, each frame has three components:
  - command (1, mandatory) - only specified commands are permitted
  - headers (0 or more, see below) - unspecified headers permitted
  - body (empty for commands other than SEND, MESSAGE and ERROR)

  Most commands have required and optional headers, and non-specified headers
  can also be given (though these have no specified semantics, of course).
  Most headers aren't relevant to frame parsing, with the exception of
  content-length, which defines the length of a frame's body.

  Header keys may legally be duplicated. In that case only the first
  occurrence takes effect (see `headers_to_map/1`).

  Header keys and values are escaped on the wire for every command except
  STOMP, CONNECT and CONNECTED, whose headers are written and read verbatim.

  Supported client frames:
  - `:stomp` (STOMP)
  - `:connect` (CONNECT)
  - `:disconnect` (DISCONNECT)
  - `:ack` (ACK)
  - `:nack` (NACK)
  - `:begin` (BEGIN)
  - `:commit` (COMMIT)
  - `:abort` (ABORT)
  - `:send` (SEND)
  - `:subscribe` (SUBSCRIBE)
  - `:unsubscribe` (UNSUBSCRIBE)

  Supported server frames:
  - `:connected` (CONNECTED)
  - `:error` (ERROR)
  - `:message` (MESSAGE)
  - `:receipt` (RECEIPT)

  Specification section: https://stomp.github.io/stomp-specification-1.2.html#STOMP_Frames
  """

  alias Barytherium.Frame
  alias Barytherium.Frame.ValueEncoding
  require Barytherium.Frame.Macros
  import Barytherium.Frame.Macros

  @enforce_keys [:command]
  defstruct command: nil, headers: [], body: ""

  @cr <<13>>
  @lf <<10>>
  @crlf @cr <> @lf

  # Commands whose headers are never escaped/unescaped (STOMP 1.2 "Value
  # Encoding"). Shared by format/2 and parse_header/2 so both directions agree.
  @unescaped_commands [:stomp, :connect, :connected]

  @type frame_command ::
          :stomp
          | :connect
          | :disconnect
          | :ack
          | :nack
          | :begin
          | :commit
          | :abort
          | :send
          | :subscribe
          | :unsubscribe
          | :connected
          | :error
          | :message
          | :receipt
  @type frame_headers :: list({binary(), binary()})
  @type frame_body :: binary()
  @type t :: %Frame{command: frame_command(), headers: frame_headers(), body: frame_body()}

  # {stage, frame, interval, content_length, unconsumed_data}
  @type parse_state ::
          {:command, nil, integer(), nil, binary()}
          | {:headers | :body, t(), integer(), integer() | nil, binary()}
  # A finished frame is reported as {:done, frame, 0, nil, rest}, matching what
  # the body clauses of parse/1 actually return.
  @type parse_state_return ::
          parse_state() | {:done, t(), 0, nil, binary()} | {:error, atom(), binary()}

  @doc """
  Converts a header list into a map.

  When a key occurs more than once the first occurrence wins: the list is
  reversed before building the map, so earlier entries overwrite later ones.
  """
  @spec headers_to_map(frame_headers()) :: map()
  def headers_to_map(headers) do
    headers
    |> Enum.reverse()
    |> Map.new()
  end

  @doc """
  Serializes a frame, or a list of frames, into its wire representation.

  A list is serialized frame by frame, in order, and concatenated. A
  `content-length` header is prepended for any frame with a non-empty body.
  """
  @spec format(t() | list(t())) :: binary()
  def format(frames) when is_list(frames) do
    Enum.map_join(frames, "", &format/1)
  end

  # Single-frame format/1 clauses come from these macro calls (defined in
  # Barytherium.Frame.Macros, not visible here); presumably each maps the
  # command atom to its wire name and delegates to format/2.
  format_fn_command("STOMP", :stomp)
  format_fn_command("CONNECT", :connect)
  format_fn_command("DISCONNECT", :disconnect)
  format_fn_command("ACK", :ack)
  format_fn_command("NACK", :nack)
  format_fn_command("BEGIN", :begin)
  format_fn_command("COMMIT", :commit)
  format_fn_command("ABORT", :abort)
  format_fn_command("SEND", :send)
  format_fn_command("SUBSCRIBE", :subscribe)
  format_fn_command("UNSUBSCRIBE", :unsubscribe)
  format_fn_command("CONNECTED", :connected)
  format_fn_command("ERROR", :error)
  format_fn_command("MESSAGE", :message)
  format_fn_command("RECEIPT", :receipt)

  # Renders one frame as: command line, header lines, blank line, body, NUL.
  # Headers of STOMP/CONNECT/CONNECTED frames are written verbatim (mirroring
  # parse_header/2); callers must therefore keep CR/LF out of those values.
  defp format(command, %Frame{command: frame_command, body: body, headers: headers}) do
    headers =
      case byte_size(body) do
        0 -> headers
        size -> [{"content-length", Integer.to_string(size)} | headers]
      end

    encode =
      if frame_command in @unescaped_commands,
        do: fn text -> text end,
        else: &ValueEncoding.encode/1

    header_lines =
      Enum.map(headers, fn {key, value} -> encode.(key) <> ":" <> encode.(value) end)

    Enum.join([command | header_lines], "\n") <> "\n\n" <> body <> "\0"
  end

  @doc """
  Parses as many complete frames as possible from `parse_state`.

  Returns `{state, frames}` with `frames` in wire order. `state` is either the
  in-progress parse state (its last element holds the unconsumed data) or an
  `{:error, stage, message}` tuple; frames completed before an error are
  still returned.
  """
  @spec parse_all(parse_state()) :: {parse_state_return(), list(t())}
  def parse_all(parse_state = {_, _, _, _, _}) do
    parse_all(parse(parse_state), [])
  end

  defp parse_all({:done, frame, _, _, remainder}, accumulator) do
    parse_all(parse({:command, nil, 0, nil, remainder}), [frame | accumulator])
  end

  # Frames are accumulated newest-first, so every exit reverses them.
  defp parse_all(error = {:error, _stage, _message}, accumulator) do
    {error, Enum.reverse(accumulator)}
  end

  defp parse_all(in_progress = {_stage, _, _, _, _}, accumulator) do
    {in_progress, Enum.reverse(accumulator)}
  end

  @doc """
  Advances the incremental parser as far as the buffered data allows.

  The state is `{stage, frame, interval, content_length, data}`, where `stage`
  is `:command`, `:headers` or `:body`, `interval` is the number of body bytes
  already scanned for NUL (only used when there is no content-length), and
  `data` is the unconsumed buffer. Returns the next state, a
  `{:done, frame, 0, nil, rest}` tuple once a frame is complete, or
  `{:error, stage, message}` for malformed input.
  """
  @spec parse(parse_state()) :: parse_state_return()
  def parse(state = {_stage, _frame, interval, _length, data})
      when byte_size(data) < interval do
    state
  end

  # Skip bare EOLs preceding a command (e.g. heart-beats between frames).
  def parse({:command, nil, 0, nil, <<@lf, rest::binary>>}) do
    parse({:command, nil, 0, nil, rest})
  end

  def parse({:command, nil, 0, nil, <<@crlf, rest::binary>>}) do
    parse({:command, nil, 0, nil, rest})
  end

  # Per-command parse/1 clauses generated by Barytherium.Frame.Macros.
  parse_fn_command("STOMP", :stomp)
  parse_fn_command("CONNECT", :connect)
  parse_fn_command("DISCONNECT", :disconnect)
  parse_fn_command("ACK", :ack)
  parse_fn_command("NACK", :nack)
  parse_fn_command("BEGIN", :begin)
  parse_fn_command("COMMIT", :commit)
  parse_fn_command("ABORT", :abort)
  parse_fn_command("SEND", :send)
  parse_fn_command("SUBSCRIBE", :subscribe)
  parse_fn_command("UNSUBSCRIBE", :unsubscribe)
  parse_fn_command("CONNECTED", :connected)
  parse_fn_command("ERROR", :error)
  parse_fn_command("MESSAGE", :message)
  parse_fn_command("RECEIPT", :receipt)

  # 13 bytes == "UNSUBSCRIBE\r\n", the longest command line. NOTE(review):
  # assumes the generated clauses match a command followed by its EOL.
  def parse({:command, nil, 0, nil, data}) when byte_size(data) >= 13 do
    {:error, :message,
     "Frame too long without command match: #{inspect(binary_part(data, 0, 13))}"}
  end

  # Not enough data yet to recognise a command; wait for more.
  def parse({:command, nil, 0, nil, data}) do
    {:command, nil, 0, nil, data}
  end

  def parse({:headers, %Frame{command: command, headers: headers} = frame, 0, length, data}) do
    case Regex.split(~r/\r?\n/, data, parts: 2) do
      ["", rest] ->
        # Blank line ends the header block; headers were collected newest-first.
        parse({:body, %{frame | headers: Enum.reverse(headers)}, 0, length, rest})

      [line, rest] ->
        with {:ok, {key, value} = header} <- parse_header(line, command),
             {:ok, length} <- update_length(length, key, value) do
          parse({:headers, %{frame | headers: [header | headers]}, 0, length, rest})
        end

      [_incomplete_line] ->
        {:headers, frame, 0, length, data}
    end
  end

  # No content-length: the body runs up to the first NUL. The first `interval`
  # bytes were scanned on an earlier call, so the search resumes there.
  def parse({:body, frame = %Frame{}, interval, nil, data}) do
    case :binary.match(data, <<0>>, scope: {interval, byte_size(data) - interval}) do
      :nomatch ->
        {:body, frame, byte_size(data), nil, data}

      {nul_position, _} ->
        <<body::binary-size(nul_position), 0, rest::binary>> = data
        {:done, %{frame | body: body}, 0, nil, rest}
    end
  end

  # content-length known: need the body bytes plus the terminating NUL.
  def parse({:body, frame = %Frame{}, 0, length, data}) when byte_size(data) < length + 1 do
    {:body, frame, 0, length, data}
  end

  def parse({:body, frame = %Frame{}, 0, length, data}) do
    case data do
      <<body::binary-size(length), 0, rest::binary>> ->
        {:done, %{frame | body: body}, 0, nil, rest}

      _ ->
        {:error, :body, "Frame body of content-length #{length} not followed by NULL"}
    end
  end

  # Honours only the first content-length header (first-wins rule). A value
  # that is not a non-negative integer becomes a parse error instead of a crash.
  defp update_length(nil, "content-length", value) do
    case Integer.parse(value) do
      {length, _} when length >= 0 -> {:ok, length}
      _ -> {:error, :headers, "Invalid content-length header: #{inspect(value)}"}
    end
  end

  defp update_length(length, _key, _value), do: {:ok, length}

  # Splits a header line at its first ":". STOMP/CONNECT/CONNECTED headers are
  # taken verbatim; all others are unescaped via ValueEncoding.decode/1. A raw
  # ":" inside a value is tolerated rather than crashing the parser.
  defp parse_header(header_line, command) do
    case :binary.split(header_line, ":") do
      [key, value] when command in @unescaped_commands ->
        {:ok, {key, value}}

      [key, value] ->
        {:ok, {ValueEncoding.decode(key), ValueEncoding.decode(value)}}

      [_without_colon] ->
        {:error, :headers, "Malformed header line: #{inspect(header_line)}"}
    end
  end
end