lib/mammoth.ex

defmodule Mammoth do
  @moduledoc ~S"""
  Mammoth: A STOMP client.

  Use `Mammoth` as the primary API and `Mammoth.Message` for working with received messages.

  The recommended setup is to start mammoth from within your own implementation of the
  callback handler, and to put the callback handler itself into a supervisor.

  Do not try to reconnect a disconnected mammoth process: at the moment its internal state
  may be inconsistent. Start a new instance instead; this is easily handled with OTP.

  ## Example

      {:ok, callback_pid} = Mammoth.DefaultCallbackHandler.start_link()
      {:ok, pid} = Mammoth.start_link(callback_pid)
      :ok = Mammoth.DefaultCallbackHandler.set_mammoth_pid(callback_pid, pid)
      Mammoth.connect(pid, {127,0,0,1}, 61613, "/", "admin", "admin")
      Mammoth.subscribe(pid, "foo.bar", :client)
      Mammoth.disconnect(pid)

  ## Starting in a supervision tree

  `start_link/3` expects the callback handler as its first argument, followed by the
  options map and the `GenServer` link options:

      children = [
        %{
          id: Mammoth,
          start: {Mammoth, :start_link, [callback_pid, %{}, [name: Mammoth]]}
        }
      ]
  """

  use GenServer

  require Logger
  alias Mammoth.{Message, Receiver, Socket, Subscriber}

  # Slack (ms) added to the negotiated outgoing interval before a heartbeat is sent.
  @timeout_allowance_out 200
  # Slack (ms) added to the negotiated incoming interval before the peer counts as dead.
  @timeout_allowance_in 500
  # Period (ms) of the :heartbeat_watch timer while connected.
  @heartbeat_check_interval 500

  @impl true
  def init(args) do
    # Fill in every state key the caller did not supply; keys already in `args` are kept.
    state =
      args
      |> Map.put_new(:socket, nil)
      |> Map.put_new(:receiver, nil)
      |> Map.put_new(:subscriber, nil)
      |> Map.put_new(:heartbeat, [0, 0])
      |> Map.put_new(:target_heartbeat, [5000, 5000])
      |> Map.put_new(:last_received_timestamp, 0)
      |> Map.put_new(:last_sent_timestamp, 0)
      |> Map.put_new(:is_connected, false)

    {:ok, state}
  end

  # Closes the TCP socket if one was opened; a never-connected process has nothing to clean up.
  @impl true
  def terminate(_reason, %{socket: socket}) when not is_nil(socket) do
    Logger.debug("Mammoth is being terminated, closing socket")
    :gen_tcp.close(socket)
  end

  def terminate(_reason, _state), do: :ok

  @doc """
  Start link.

  `opts` is a map describing additional configuration options for this mammoth instance.
  Currently supported arguments, any or none of:
  * :recbuf (int) - receive buffer size in bytes (default 0x400000: 4MiB)

  `link_opts` is a keyword list with options passed to `GenServer.start_link`.
  """
  def start_link(callback_handler, opts \\ %{}, link_opts \\ []) do
    GenServer.start_link(
      __MODULE__,
      %{callback_handler: callback_handler, opts: opts},
      link_opts
    )
  end

  @doc """
  Connect to server.

  `host` must be `inet:socket_address() | inet:hostname()`, for example `{127,0,0,1}` or `'example.com'`.
  `virtual_host` (`host` header) is mandatory in STOMP 1.2, behaviour is server defined. Default on RabbitMQ would be `"/"`.

  Returns the connected socket on success, or `{:error, reason}` when the TCP connection
  could not be established (e.g. `{:error, :econnrefused}`).

  https://stomp.github.io/stomp-specification-1.2.html#CONNECT_or_STOMP_Frame
  """
  def connect(
        pid,
        host,
        port,
        virtual_host,
        login,
        password,
        additional_headers \\ [],
        send_timeout \\ 750
      ) do
    GenServer.call(
      pid,
      {:connect, host, port, virtual_host, login, password, additional_headers, send_timeout}
    )
  end

  @doc """
  Subscribe to a queue.

  Note that if you set ack_mode to :client_individual or :client,
  you must send ACK frames when you receive MESSAGE frames (not required with :auto)
  https://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE_ack_Header
  """
  def subscribe(pid, destination, ack_mode \\ :auto, additional_headers \\ []) do
    GenServer.call(pid, {:subscribe, destination, get_ack_mode(ack_mode), additional_headers})
  end

  # Maps the ack-mode atom to its STOMP `ack` header value; other atoms raise CaseClauseError.
  defp get_ack_mode(ack_mode) do
    case ack_mode do
      :auto -> "auto"
      :client_individual -> "client-individual"
      :client -> "client"
    end
  end

  @doc """
  Unsubscribe from a queue.
  """
  def unsubscribe(pid, destination) do
    GenServer.call(pid, {:unsubscribe, destination})
  end

  @doc """
  Receive messages from the TCP socket.

  Is called automatically when necessary. Should not be called manually.
  """
  def receive(pid, message) do
    GenServer.call(pid, {:receive, message})
  end

  @doc """
  Disconnect from server.

  Sends a DISCONNECT frame (with a `receipt` header) and returns `:ok` right away. The callback
  handler later receives `{:mammoth, :disconnected, :local}` when the connection is torn down.
  """
  def disconnect(pid) do
    GenServer.call(pid, :disconnect)
  end

  @doc """
  Send a frame to the remote server
  """
  def send_frame(pid, message = %Message{}) do
    GenServer.call(pid, {:send, message})
  end

  @doc """
  Send a SEND message to the remote server

  content-length header will be automatically added.
  """
  def send_send(pid, destination, body, additional_headers \\ []) do
    GenServer.call(
      pid,
      {:send,
       %Message{
         command: :send,
         headers: additional_headers ++ [{"destination", destination}],
         body: body
       }}
    )
  end

  @doc """
  Send an ACK message to the remote server for the specified frame if the 'ack' header is present.

  Returns `nil` (and sends nothing) when the frame carries no `ack` header.
  """
  def send_ack_frame(pid, message = %Message{}, additional_headers \\ []) do
    case Message.get_header(message, "ack") do
      {:ok, message_id} -> send_ack_id(pid, message_id, additional_headers)
      {:error, :notfound} -> nil
    end
  end

  @doc """
  Send an ACK message to the remote server for the specified frame ID (referenced in the MESSAGE `ack` header)
  """
  def send_ack_id(pid, id, additional_headers \\ []) do
    GenServer.call(
      pid,
      {:send, %Message{command: :ack, headers: additional_headers ++ [{"id", id}]}}
    )
  end

  @doc """
  Send a NACK message to the remote server for the specified frame if the 'ack' header is present.

  Returns `nil` (and sends nothing) when the frame carries no `ack` header.
  """
  def send_nack_frame(pid, message = %Message{}, additional_headers \\ []) do
    case Message.get_header(message, "ack") do
      {:ok, message_id} -> send_nack_id(pid, message_id, additional_headers)
      {:error, :notfound} -> nil
    end
  end

  @doc """
  Send a NACK message to the remote server for the specified frame ID (referenced in the MESSAGE `ack` header)
  """
  def send_nack_id(pid, id, additional_headers \\ []) do
    # Same header order as send_ack_id/3: additional headers first (STOMP keeps the first
    # occurrence of a repeated header, so caller-supplied headers take precedence).
    GenServer.call(
      pid,
      {:send, %Message{command: :nack, headers: additional_headers ++ [{"id", id}]}}
    )
  end

  # Negotiates heart-beat intervals from a CONNECTED frame, following STOMP 1.2 "Heart-beating":
  # our CONNECT sent "cx,cy" (target_heartbeat), the server answers "sx,sy".
  #   client -> server: 0 if cx == 0 or sy == 0, else max(cx, sy)
  #   server -> client: 0 if sx == 0 or cy == 0, else max(sx, cy)
  # Returns [outgoing_ms, incoming_ms]; a missing or malformed header disables heart-beating.
  defp determine_heartbeat(message, %{target_heartbeat: [client_out, client_in]}) do
    with {:ok, value} <- Message.get_header(message, "heart-beat"),
         [server_out_text, server_in_text] <- String.split(value, ","),
         {:ok, server_out} <- parse_interval(server_out_text),
         {:ok, server_in} <- parse_interval(server_in_text) do
      [negotiate_interval(client_out, server_in), negotiate_interval(client_in, server_out)]
    else
      _ -> [0, 0]
    end
  end

  # Parses one non-negative integer millisecond value from a heart-beat header component.
  defp parse_interval(text) do
    case Integer.parse(String.trim(text)) do
      {ms, ""} when ms >= 0 -> {:ok, ms}
      _ -> :error
    end
  end

  # A zero on either side disables that direction; otherwise the slower side wins.
  defp negotiate_interval(0, _theirs), do: 0
  defp negotiate_interval(_ours, 0), do: 0
  defp negotiate_interval(ours, theirs), do: max(ours, theirs)

  # Periodic watchdog while connected: sends a bare-EOL heartbeat once the outgoing interval
  # (plus allowance) has elapsed since the last send, and casts {:fatal, :timeout} to itself
  # when nothing was received within the incoming interval (plus allowance). A value of 0
  # disables the corresponding check. Once disconnected the fallback clause stops the timer.
  # NOTE(review): a heartbeat goes out only after heartbeat_out + allowance (and up to one
  # tick later), i.e. later than the negotiated interval — confirm the server tolerates this.
  @impl true
  def handle_info(
        :heartbeat_watch,
        %{
          socket: socket,
          heartbeat: [heartbeat_out, heartbeat_in],
          last_received_timestamp: last_received_timestamp,
          last_sent_timestamp: last_sent_timestamp,
          is_connected: true
        } = state
      ) do
    now = :os.system_time(:millisecond)

    last_sent_timestamp =
      if heartbeat_out > 0 and
           now - last_sent_timestamp > heartbeat_out + @timeout_allowance_out do
        Socket.send(socket, "\r\n")
        Logger.debug("Sending heartbeat")
        :os.system_time(:millisecond)
      else
        last_sent_timestamp
      end

    if heartbeat_in > 0 and now - last_received_timestamp > heartbeat_in + @timeout_allowance_in do
      GenServer.cast(self(), {:fatal, :timeout})
      Logger.debug("Timed out on remote heartbeat")
    end

    Process.send_after(self(), :heartbeat_watch, @heartbeat_check_interval)
    {:noreply, %{state | last_sent_timestamp: last_sent_timestamp}}
  end

  def handle_info(:heartbeat_watch, state), do: {:noreply, state}

  # Writes an arbitrary frame to the socket and records the send time for heartbeat tracking.
  @impl true
  def handle_call({:send, message}, _from, %{socket: socket} = state) do
    Socket.send(socket, message)
    {:reply, :ok, %{state | last_sent_timestamp: :os.system_time(:millisecond)}}
  end

  # A heartbeat from the server only refreshes the receive timestamp.
  def handle_call({:receive, :heartbeat}, _from, state) do
    {:reply, :ok, %{state | last_received_timestamp: :os.system_time(:millisecond)}}
  end

  # A full frame from the server: CONNECTED additionally fixes the negotiated heart-beat;
  # every frame is forwarded to the callback handler.
  def handle_call(
        {:receive, %Message{command: command} = message},
        _from,
        %{callback_handler: callback_handler} = state
      ) do
    state =
      case command do
        :connected -> %{state | heartbeat: determine_heartbeat(message, state)}
        _ -> state
      end

    Kernel.send(callback_handler, {:mammoth, :receive_frame, message})

    {:reply, :ok, %{state | last_received_timestamp: :os.system_time(:millisecond)}}
  end

  # Opens the TCP connection, sends CONNECT, starts the subscriber/receiver helpers and the
  # heartbeat watchdog. Replies with the socket, or with {:error, reason} if connecting failed
  # (state is then left untouched).
  def handle_call(
        {:connect, host, port, virtual_host, login, password, additional_headers, send_timeout},
        _from,
        %{target_heartbeat: target_heartbeat, opts: opts} = state
      ) do
    case Socket.connect(host, port, send_timeout, Map.get(opts, :recbuf)) do
      {:ok, socket} ->
        Socket.send(
          socket,
          connect_message(virtual_host, login, password, target_heartbeat, additional_headers)
        )

        {:ok, subscriber} = Subscriber.start_link()
        {:ok, receiver} = Receiver.start_link(%{socket: socket, consumer: self()})
        Receiver.listen(receiver)

        Process.send_after(self(), :heartbeat_watch, @heartbeat_check_interval)

        {:reply, socket,
         %{
           state
           | socket: socket,
             subscriber: subscriber,
             receiver: receiver,
             is_connected: true,
             # The CONNECT frame just went out; don't count it as an overdue heartbeat.
             last_sent_timestamp: :os.system_time(:millisecond)
         }}

      {:error, _reason} = error ->
        {:reply, error, state}
    end
  end

  # Requests disconnection from the remote server. Replies :ok immediately (the caller must
  # not block); the stored disconnect_id later makes the :disconnected cast report :local.
  def handle_call(:disconnect, _from, %{socket: socket} = state) do
    receipt_id = Integer.to_string(Enum.random(1_000..1_000_000))
    Socket.send(socket, disconnect_message(receipt_id))

    {:reply, :ok, Map.put(state, :disconnect_id, receipt_id)}
  end

  def handle_call(
        {:subscribe, destination, ack_mode, additional_headers},
        _from,
        %{socket: socket, subscriber: subscriber} = state
      ) do
    {:ok, entry} = Subscriber.subscribe(subscriber, destination, ack_mode)
    message = subscribe_message(destination, entry.id, ack_mode, additional_headers)
    Socket.send(socket, message)
    {:reply, {:ok, entry}, state}
  end

  def handle_call(
        {:unsubscribe, destination},
        _from,
        %{socket: socket, subscriber: subscriber} = state
      ) do
    {:ok, entry} = Subscriber.unsubscribe(subscriber, destination)
    message = unsubscribe_message(entry.id)
    Socket.send(socket, message)
    {:reply, :ok, state}
  end

  # The connection went away: stop the helpers and tell the callback handler whether we asked
  # for it (:local, a disconnect_id was recorded) or the server dropped us (:remote).
  @impl true
  def handle_cast(
        :disconnected,
        %{
          subscriber: subscriber,
          receiver: receiver,
          callback_handler: callback_handler,
          is_connected: true
        } = state
      ) do
    Subscriber.stop(subscriber)
    Receiver.stop(receiver)

    origin =
      case Map.get(state, :disconnect_id) do
        nil -> :remote
        _ -> :local
      end

    Kernel.send(callback_handler, {:mammoth, :disconnected, origin})

    {:noreply, %{state | is_connected: false}}
  end

  def handle_cast(:disconnected, state), do: {:noreply, state}

  # Unrecoverable error: on :timeout the socket is closed outright; otherwise a DISCONNECT
  # (without receipt) is sent and reading is shut down. Helpers are stopped and the callback
  # handler is notified with the reason.
  def handle_cast(
        {:fatal, reason},
        %{
          socket: socket,
          subscriber: subscriber,
          receiver: receiver,
          callback_handler: callback_handler,
          is_connected: true
        } = state
      ) do
    if reason == :timeout do
      :gen_tcp.close(socket)
    else
      Socket.send(socket, disconnect_message(nil))
      :gen_tcp.shutdown(socket, :read)
    end

    Subscriber.stop(subscriber)
    Receiver.stop(receiver)

    Kernel.send(callback_handler, {:mammoth, :disconnected, reason})

    {:noreply, %{state | is_connected: false}}
  end

  # Already disconnected (e.g. a queued timeout arriving after :disconnected): nothing to do.
  def handle_cast({:fatal, _reason}, state), do: {:noreply, state}

  # Builds the CONNECT frame; additional headers come first so they take precedence.
  defp connect_message(virtual_host, login, password, heartbeat, additional_headers) do
    %Message{
      command: :connect,
      headers:
        additional_headers ++
          [
            {"accept-version", "1.2"},
            {"heart-beat", Enum.join(heartbeat, ",")},
            {"host", virtual_host},
            {"login", login},
            {"passcode", password}
          ]
    }
  end

  # Builds a DISCONNECT frame. STOMP 1.2 names the client header `receipt`; the server echoes
  # its value back as `receipt-id` in the RECEIPT frame. `nil` requests no receipt.
  defp disconnect_message(nil), do: %Message{command: :disconnect, headers: []}

  defp disconnect_message(receipt_id) do
    %Message{
      command: :disconnect,
      headers: [
        {"receipt", receipt_id}
      ]
    }
  end

  defp subscribe_message(destination, id, ack_mode, additional_headers) do
    %Message{
      command: :subscribe,
      headers:
        additional_headers ++
          [
            {"destination", destination},
            {"ack", ack_mode},
            {"id", id}
          ]
    }
  end

  defp unsubscribe_message(id) do
    %Message{
      command: :unsubscribe,
      headers: [
        {"id", id}
      ]
    }
  end
end