lib/mongo/messages.ex

defmodule Mongo.Messages do
  @moduledoc """
    This module encodes and decodes the data from and to the mongodb server.
    We only support MongoDB >= 3.2 and use op_query with the hack collection "$cmd"
    Other op codes are deprecated. Therefore only op_reply and op_query are supported.
  """

  defmacro __using__(_opts) do
    quote do
      import unquote(__MODULE__)
    end
  end

  import Record
  import Mongo.BinaryUtils

  @op_reply 1
  @op_query 2004
  @op_msg_code 2013

  @query_flags [
    tailable_cursor: 0x2,
    slave_ok: 0x4,
    oplog_replay: 0x8,
    no_cursor_timeout: 0x10,
    await_data: 0x20,
    exhaust: 0x40,
    partial: 0x80
  ]

  @msg_flags [
    # Checksum present
    checksum_present: 0x00001,
    # Sender will send another message and is not prepared for overlapping messages
    more_to_come: 0x00002,
    # Client is prepared for multiple replies (using the moreToCome bit) to this request
    exhaust_allowed: 0x10000
  ]

  @header_size 4 * 4

  defrecord :msg_header, [:length, :request_id, :response_to, :op_code]
  defrecord :op_query, [:flags, :coll, :num_skip, :num_return, :query, :select]
  defrecord :op_reply, [:flags, :cursor_id, :from, :num, :docs]
  defrecord :sequence, [:size, :identifier, :docs]
  defrecord :payload, [:doc, :sequence]
  defrecord :section, [:payload_type, :payload]
  defrecord :op_msg, [:flags, :sections]

  @doc """
    Decodes the header from response of a request sent by the mongodb server
  """
  def decode_header(iolist) when is_list(iolist) do
    case IO.iodata_length(iolist) == @header_size do
      true ->
        iolist
        |> IO.iodata_to_binary()
        |> decode_header()

      false ->
        :error
    end
  end

  def decode_header(<<length::int32(), request_id::int32(), response_to::int32(), op_code::int32()>>) do
    header = msg_header(length: length - @header_size, request_id: request_id, response_to: response_to, op_code: op_code)
    {:ok, header}
  end

  def decode_header(_binary), do: :error

  @doc """
    Decodes the response body of a request sent by the mongodb server
  """
  def decode_response(msg_header(length: length) = header, iolist) when is_list(iolist) do
    case IO.iodata_length(iolist) >= length do
      true -> decode_response(header, IO.iodata_to_binary(iolist))
      false -> :error
    end
  end

  def decode_response(msg_header(length: length, response_to: response_to, op_code: op_code), binary) when byte_size(binary) >= length do
    <<response::binary(length), rest::binary>> = binary

    case op_code do
      @op_reply -> {:ok, response_to, decode_reply(response), rest}
      @op_msg_code -> {:ok, response_to, decode_msg(response), rest}
      _ -> :error
    end
  end

  def decode_response(_header, _binary), do: :error

  @doc """
    Decodes a reply message from the response
  """
  def decode_reply(<<flags::int32(), cursor_id::int64(), from::int32(), num::int32(), rest::binary>>) do
    op_reply(flags: flags, cursor_id: cursor_id, from: from, num: num, docs: BSON.Decoder.documents(rest))
  end

  def decode_msg(<<flags::int32(), rest::binary>>) do
    op_msg(flags: flags, sections: decode_sections(rest))
  end

  def decode_sections(binary), do: decode_sections(binary, [])
  def decode_sections("", acc), do: Enum.reverse(acc)

  def decode_sections(<<0x00::int8(), payload::binary>>, acc) do
    <<size::int32(), _rest::binary>> = payload
    <<doc::binary(size), rest::binary>> = payload

    with {doc, ""} <- BSON.Decoder.document(doc) do
      decode_sections(rest, [section(payload_type: 0, payload: payload(doc: doc)) | acc])
    end
  end

  def decode_sections(<<0x01::int8(), payload::binary>>, acc) do
    <<size::int32(), _rest::binary>> = payload
    <<sequence::binary(size), rest::binary>> = payload
    decode_sections(rest, [section(payload_type: 1, payload: payload(sequence: decode_sequence(sequence))) | acc])
  end

  def decode_sequence(<<size::int32(), rest::binary>>) do
    with {identifier, docs} <- cstring(rest) do
      sequence(size: size, identifier: identifier, docs: BSON.Decoder.documents(docs))
    end
  end

  defp cstring(binary) do
    [string, rest] = :binary.split(binary, <<0x00>>)
    {string, rest}
  end

  def encode(request_id, op_query() = op) do
    iodata = encode_op(op)
    header = msg_header(length: IO.iodata_length(iodata) + @header_size, request_id: request_id, response_to: 0, op_code: @op_query)
    [encode_header(header) | iodata]
  end

  def encode(request_id, op_msg() = op) do
    iodata = encode_op(op)
    header = msg_header(length: IO.iodata_length(iodata) + @header_size, request_id: request_id, response_to: 0, op_code: @op_msg_code)
    [encode_header(header) | iodata]
  end

  defp encode_header(msg_header(length: length, request_id: request_id, response_to: response_to, op_code: op_code)) do
    <<length::int32(), request_id::int32(), response_to::int32(), op_code::int32()>>
  end

  defp encode_op(op_query(flags: flags, coll: coll, num_skip: num_skip, num_return: num_return, query: query, select: select)) do
    [<<blit_flags(:query, flags)::int32()>>, coll, <<0x00, num_skip::int32(), num_return::int32()>>, BSON.Encoder.document(query), select]
  end

  defp encode_op(op_msg(flags: flags, sections: sections)) do
    [<<blit_flags(:msg, flags)::int32()>> | encode_sections(sections)]
  end

  defp encode_sections(sections) do
    Enum.map(sections, fn section -> encode_section(section) end)
  end

  defp encode_section(section(payload_type: t, payload: payload)) do
    [<<t::int8()>> | encode_payload(payload)]
  end

  defp encode_payload(payload(doc: doc, sequence: nil)) do
    BSON.Encoder.document(doc)
  end

  defp encode_payload(payload(doc: nil, sequence: sequence(identifier: identifier, docs: docs))) do
    iodata = [identifier, <<0x00>> | Enum.map(docs, fn doc -> BSON.Encoder.encode(doc) end)]
    size = IO.iodata_length(iodata) + 4
    [<<size::int32()>> | iodata]
  end

  defp blit_flags(op, flags) when is_list(flags) do
    import Bitwise
    Enum.reduce(flags, 0x0, &(flag_to_bit(op, &1) ||| &2))
  end

  defp blit_flags(_op, flags) when is_integer(flags) do
    flags
  end

  Enum.each(@query_flags, fn {flag, bit} ->
    defp flag_to_bit(:query, unquote(flag)), do: unquote(bit)
  end)

  Enum.each(@msg_flags, fn {flag, bit} ->
    defp flag_to_bit(:msg, unquote(flag)), do: unquote(bit)
  end)

  defp flag_to_bit(_op, _flag), do: 0x0
end