lib/kvasir/pub_sub/protocols/mqtt/mqtt.ex

defmodule Kvasir.PubSub.Protocols.MQTT do
  alias __MODULE__.V5
  @behaviour Kvasir.PubSub.Protocol

  @impl Kvasir.PubSub.Protocol
  def port, do: 1883

  @impl Kvasir.PubSub.Protocol
  def secure_port, do: 8883

  @impl Kvasir.PubSub.Protocol
  def handshake(transport, socket) do
    case __MODULE__.read(transport, socket) do
      {:ok, {:mqtt, :connect, {version, flags, opts}}} ->
        u = opts[:username]
        p = opts[:password]

        auth =
          cond do
            is_nil(u) -> :anonymous
            is_nil(p) -> {:user, u}
            :auth -> {:password, u, p}
          end

        {:ok, auth, %{flags: flags, version: version, opts: opts}}

      _ ->
        {:error, :no_mqtt_connect}
    end
  end

  @impl Kvasir.PubSub.Protocol
  @spec run(
          transport :: term,
          socket :: term,
          state :: term,
          server_state :: Kvasir.PubSub.server_state()
        ) :: :ok
  def run(transport, socket, state = %{opts: opts}, server_state) do
    state =
      if opts[:client_id] == :requested do
        client_id = Base.encode16(:crypto.strong_rand_bytes(8), padding: false, case: :lower)
        %{state | opts: Keyword.put(opts, :client_id, client_id)}
      else
        state
      end

    write(transport, socket, {:mqtt, :connack, client_id: state.opts[:client_id]})
    do_run(transport, socket, state, server_state)
  end

  defp do_run(transport, socket, state, server_state) do
    timeout = (Keyword.get(state.opts, :keep_alive, 60) + 5) * 1_000

    with {:ok, mqtt} <- __MODULE__.read(transport, socket, timeout) do
      case mqtt do
        {:mqtt, :subscribe, payload} ->
          id = payload[:id]
          subs = payload[:subscriptions]
          {:ok, groups} = subscribe(subs, server_state)

          write(
            transport,
            socket,
            {:mqtt, :suback,
             id: id,
             subscriptions:
               Enum.map(groups, fn %{group: g} ->
                 if(g, do: :granted_qos_1, else: :granted_qos_0)
               end)}
          )

          start_listen(transport, socket, id, groups, server_state)

        {:mqtt, :pingreq} ->
          write(transport, socket, {:mqtt, :pingresp})

        {:ok, x} ->
          IO.inspect(x, label: "Unknown")
      end

      do_run(transport, socket, state, server_state)
    else
      err ->
        IO.inspect(err, label: "Error")
        :ok
    end
  end

  defp start_listen(transport, socket, id, groups, server_state)

  defp start_listen(transport, socket, id, groups, _server_state = {%{source: source}, _client}) do
    groups
    |> Enum.reject(& &1.group)
    |> Enum.each(fn group ->
      source.listen(
        group.topic,
        fn event ->
          write(
            transport,
            socket,
            {:mqtt, :publish,
             id: id,
             content_type: "application/json",
             topic: group.topic,
             qos: 0,
             payload: Jason.encode!(event)}
          )

          :ok
        end,
        only: group.events,
        events: group.events
      )
    end)
  end

  defp subscribe(subscriptions, _server_state = {%{source: source}, _client}) do
    with {:ok, parsed} <-
           EnumX.reduce_while(subscriptions, %{}, fn {subscription, _opts}, acc ->
             with {:ok, sub} <- parse_subscription(subscription, source) do
               {:ok,
                Map.update(
                  acc,
                  {sub.topic, sub.group},
                  sub,
                  &%{&1 | events: Enum.uniq(&1.events ++ sub.events)}
                )}
             end
           end),
         groups <- Map.values(parsed) do
      {:ok, groups}
    end
  end

  defp parse_subscription(subscription, source)

  defp parse_subscription("$shared/" <> subscription, source) do
    with [group, topic | selector] <- String.split(subscription, "/"),
         t when t != nil <- source.__topics__()[topic],
         e when e != [] <- find_events(t.events, selector) do
      {:ok, %{group: group, topic: topic, events: e}}
    else
      _ -> {:error, :invalid_subscription}
    end
  end

  defp parse_subscription(subscription, source) do
    with [topic | selector] <- String.split(subscription, "/"),
         t when t != nil <- source.__topics__()[topic],
         e when e != [] <- find_events(t.events, selector) do
      {:ok, %{group: false, topic: topic, events: e}}
    else
      _ -> {:error, :invalid_subscription}
    end
  end

  defp find_events(events, selector)
  defp find_events(events, []), do: events

  defp find_events(events, selector) do
    matcher =
      selector
      |> clean_selector()
      |> merge_matcher()
      |> Regex.compile!()

    Enum.filter(events, &(&1.__event__(:type) =~ matcher))
  end

  defp clean_selector(selector, acc \\ [])
  defp clean_selector([], acc), do: :lists.reverse(acc)
  defp clean_selector(["*", "*" | rest], acc), do: clean_selector(["*" | rest], acc)
  defp clean_selector(["*", "+" | rest], acc), do: clean_selector(["*" | rest], acc)
  defp clean_selector([h | rest], acc), do: clean_selector(rest, [h | acc])

  defp merge_matcher(selector, acc \\ "^")
  defp merge_matcher([], acc), do: acc <> "$"
  defp merge_matcher(["*"], acc), do: acc
  defp merge_matcher(["*" | rest], acc), do: merge_matcher(rest, join_matcher(acc, "[a-z\_\.]*"))
  defp merge_matcher(["+" | rest], acc), do: merge_matcher(rest, join_matcher(acc, "[a-z\_]+"))
  defp merge_matcher([a | rest], acc), do: merge_matcher(rest, join_matcher(acc, a))

  defp join_matcher(a, b)
  defp join_matcher("^", b), do: "^" <> b
  defp join_matcher(a, b), do: a <> "." <> b

  ### Actual Logic => MOVE IT ###
  @type control_type ::
          :connect
          | :connack
          | :publish
          | :puback
          | :pubrec
          | :pubrel
          | :pubcomp
          | :subscribe
          | :suback
          | :unsubscribe
          | :unsuback
          | :pingreq
          | :pingresp
          | :disconnect

  @type packet ::
          {:mqtt, type :: control_type} | {:mqtt, type :: control_type, payload :: Keyword.t()}

  @spec read(atom, port, timeout :: pos_integer) :: {:ok, MQTT.packet()} | {:error, atom}
  def read(transport, socket, timeout \\ 15_000) do
    with {:ok, <<header::binary-1, l::binary>>} <- transport.recv(socket, 2, timeout),
         length = decode_length(transport, socket, l),
         {:ok, payload} <-
           if(length == 0, do: {:ok, <<>>}, else: transport.recv(socket, length, 5_000)) do
      V5.decode(header, payload)
    end
  end

  def write(transport, socket, package) do
    {:ok, data} = V5.encode(package)
    transport.send(socket, data)
  end

  def pong(transport, socket) do
    transport.send(socket, <<13::4, 0::4, 0>>)
  end

  def publish(transport, socket, data) do
    dup = 0
    qos = 0
    retain = 0
    topic = "TEST/TIME"
    topic_length = byte_size(topic)
    payload = <<topic_length::unsigned-integer-16, topic::binary, 0, data::binary>>
    length = encode_length(byte_size(payload))

    transport.send(
      socket,
      <<3::4, dup::1, qos::2, retain::1, length::binary, payload::binary>>
    )
  end

  defp decode_length(transport, socket, length, position \\ 0, acc \\ 0)
  defp decode_length(_, _, <<0::1, l::7>>, position, acc), do: acc + :erlang.bsl(l, position)

  defp decode_length(transport, socket, <<1::1, l::7>>, position, acc) do
    {:ok, n} = transport.receive(socket, 1, 5_000)

    decode_length(transport, socket, n, position + 7, acc + :erlang.bsl(l, position))
  end

  defp encode_length(length) when length < 128, do: <<length>>

  defp encode_length(length) do
    <<1::1, :erlang.band(length, 127)::7, encode_length(:erlang.bsr(length, 7))::binary>>
  end
end