Skip to main content

lib/proto_channel.ex

defmodule ProtoChannel do
  @moduledoc """
  A typed Protobuf layer over `Phoenix.Channel`.

  A `use ProtoChannel` channel declares its event ⇄ Protobuf-message pairs at
  compile time and exchanges typed structs with handlers instead of raw
  `{:binary, bytes}` payloads. Pattern-matching the structs at the boundary
  gives compile-time field-name safety, and the `c:handle_proto/3` callback
  spec lets dialyzer check the shape of every handler's return value.

  ## Example

      defmodule MyAppWeb.MyChannel do
        use ProtoChannel

        alias MyApp.{Request, Response, Notice}

        proto_message "ping", request: Request, reply: Response
        proto_push "notice", Notice
        proto_broadcast "notice", Notice

        @impl Phoenix.Channel
        def join("room:" <> _, _payload, socket), do: {:ok, socket}

        @impl ProtoChannel
        def handle_proto("ping", %Request{} = req, socket) do
          push(socket, "notice", %Notice{text: req.text})
          broadcast(socket, "notice", %Notice{text: req.text})
          {:reply, {:ok, %Response{text: req.text}}, socket}
        end
      end

  ## What the macros generate

    * `proto_message/2` — one `handle_in/3` clause per declared event that
      decodes the inbound bytes into the request struct, dispatches to
      `c:handle_proto/3`, and encodes the reply struct back to bytes. The
      clauses are injected at the end of the module, so any `handle_in/3`
      clause you write yourself is tried first.
    * `proto_push/2` and `proto_broadcast/2` — typed wrappers around
      `Phoenix.Channel.push/3`, `broadcast/3`, `broadcast!/3`,
      `broadcast_from/3`, and `broadcast_from!/3`. A declared event accepts
      only its declared struct; any other payload for that event is a
      function-clause mismatch at the call site. Events that were never
      declared are forwarded unchanged to the matching `Phoenix.Channel`
      function.

  The unqualified `push/3`, `broadcast/3`, ... names inside your channel
  resolve to the generated wrappers — the macro imports `Phoenix.Channel` with
  those five names excluded and always defines all five wrappers, even when no
  event is declared for them. To bypass the wrappers entirely, call
  `Phoenix.Channel.push/3` etc. directly.

  ## Compile-time validation

    * Duplicate event names within the same macro family (`proto_message`,
      `proto_push`, or `proto_broadcast`) raise `ArgumentError` at compile
      time.
    * Every referenced module must `use Protobuf` — the macro checks for
      `__message_props__/0` and raises `ArgumentError` if it is missing, so
      typos and stray plain structs are caught up-front.
    * Declaring any `proto_message/2` without defining `c:handle_proto/3`
      raises `ArgumentError`.

  ## Wire format

  The macro only produces `{:binary, bytes}` payloads. To frame those over
  the socket as protobuf, pair this with `ProtoChannel.Serializer`.
  """

  @typedoc """
  A reply payload from `c:handle_proto/3`.

  Either a bare status atom, which is handed to Phoenix unchanged, or the
  status atom plus a struct of the event's declared reply module, which the
  macro encodes to `{:binary, bytes}`.
  """
  @type reply :: :ok | :error | {:ok | :error, struct()}

  @typedoc """
  Return shapes accepted by `c:handle_proto/3`.

  The reply form mirrors `c:Phoenix.Channel.handle_in/3`, but a reply struct
  is encoded to `{:binary, bytes}` by the macro before being handed back to
  Phoenix. The reply struct must be an instance of the event's declared
  `:reply` module; anything else raises `ArgumentError`. `:noreply` and bare
  `:stop` variants pass through unchanged.
  """
  @type handle_proto_result ::
          {:reply, reply(), Phoenix.Socket.t()}
          | {:noreply, Phoenix.Socket.t()}
          | {:noreply, Phoenix.Socket.t(), timeout() | :hibernate}
          | {:stop, reason :: term(), Phoenix.Socket.t()}
          | {:stop, reason :: term(), reply(), Phoenix.Socket.t()}

  @doc """
  Handles a decoded request struct for a `proto_message/2`-declared event.

  Invoked from the generated `handle_in/3` clause after the inbound bytes
  have been decoded into the declared request struct. The returned reply
  struct (if any) is encoded back to bytes before being handed to Phoenix.

  See `t:handle_proto_result/0` for the supported return shapes.
  """
  @callback handle_proto(event :: String.t(), request :: struct(), socket :: Phoenix.Socket.t()) ::
              handle_proto_result()

  @optional_callbacks handle_proto: 3

  defmacro __using__(opts) do
    quote do
      use Phoenix.Channel, unquote(opts)

      # The generated wrappers take over these five names; importing them as
      # well would conflict with the local definitions.
      import Phoenix.Channel,
        except: [push: 3, broadcast: 3, broadcast!: 3, broadcast_from: 3, broadcast_from!: 3]

      @behaviour ProtoChannel

      import ProtoChannel, only: [proto_message: 2, proto_push: 2, proto_broadcast: 2]
      Module.register_attribute(__MODULE__, :proto_messages, accumulate: true)
      Module.register_attribute(__MODULE__, :proto_pushes, accumulate: true)
      Module.register_attribute(__MODULE__, :proto_broadcasts, accumulate: true)

      @before_compile ProtoChannel
    end
  end

  @doc """
  Declares an RPC-style inbound event.

  The channel receives `event` as a `{:binary, bytes}` payload, decodes the
  bytes into a `request` struct, dispatches to `c:handle_proto/3`, and encodes
  the returned reply struct back to bytes. Exceptions raised by the request
  module's `decode/1` are not rescued and propagate out of `handle_in/3`.

  ## Options

    * `:request` — module returned by `use Protobuf`; the request struct type.
    * `:reply` — module returned by `use Protobuf`; the reply struct type.

  Both keys are required (a missing key raises `KeyError`). Each module is
  resolved and checked for `__message_props__/0` at compile time.

  ## Example

      proto_message "ping", request: MyApp.Ping, reply: MyApp.Pong

  """
  defmacro proto_message(event, opts) do
    request = Keyword.fetch!(opts, :request)
    reply = Keyword.fetch!(opts, :reply)

    quote do
      @proto_messages {unquote(event), unquote(request), unquote(reply)}
    end
  end

  @doc """
  Generates a typed `push/3` wrapper for an outbound event.

  Inside the channel, `push(socket, event, %module{} = msg)` encodes `msg` and
  forwards it to `Phoenix.Channel.push/3` as `{:binary, bytes}`. Any other
  payload for `event` raises `FunctionClauseError`. Event names without a
  `proto_push/2` declaration are forwarded to `Phoenix.Channel.push/3`
  unchanged.

  ## Example

      proto_push "notice", MyApp.Notice
      # ...
      push(socket, "notice", %MyApp.Notice{text: "hi"})

  """
  defmacro proto_push(event, module) do
    quote do
      @proto_pushes {unquote(event), unquote(module)}
    end
  end

  @doc """
  Generates typed wrappers for an outbound broadcast event.

  Generates clauses for `broadcast/3`, `broadcast!/3`, `broadcast_from/3`, and
  `broadcast_from!/3`. Each encodes the struct and forwards to the matching
  `Phoenix.Channel` function as `{:binary, bytes}`. Event names without a
  `proto_broadcast/2` declaration are forwarded unchanged.

  ## Example

      proto_broadcast "notice", MyApp.Notice
      # ...
      broadcast(socket, "notice", %MyApp.Notice{text: "hi"})

  """
  defmacro proto_broadcast(event, module) do
    quote do
      @proto_broadcasts {unquote(event), unquote(module)}
    end
  end

  @doc false
  defmacro __before_compile__(env) do
    channel = env.module

    messages = declarations(channel, :proto_messages)
    pushes = declarations(channel, :proto_pushes)
    broadcasts = declarations(channel, :proto_broadcasts)

    check_duplicates!(channel, "proto_message", messages)
    check_duplicates!(channel, "proto_push", pushes)
    check_duplicates!(channel, "proto_broadcast", broadcasts)

    handle_in_clauses = Enum.map(messages, &handle_in_clause(channel, &1))

    if messages != [] and not Module.defines?(channel, {:handle_proto, 3}) do
      raise ArgumentError,
            "ProtoChannel in #{inspect(channel)}: " <>
              "handle_proto/3 must be implemented when proto_message/2 is declared."
    end

    for {event, mod} <- pushes, do: validate_proto_module!(channel, event, :push, mod)
    for {event, mod} <- broadcasts, do: validate_proto_module!(channel, event, :broadcast, mod)

    # Clauses of one function must stay contiguous, so every wrapper's clauses
    # (typed ones, then its fallback) are emitted before the next wrapper's.
    wrappers =
      wrapper_clauses(:push, pushes) ++
        Enum.flat_map(
          [:broadcast, :broadcast!, :broadcast_from, :broadcast_from!],
          &wrapper_clauses(&1, broadcasts)
        )

    quote do
      (unquote_splicing(handle_in_clauses ++ wrappers))
    end
  end

  @doc false
  @spec __encode_result__(handle_proto_result(), module()) :: term()
  def __encode_result__({:reply, reply, socket}, reply_mod),
    do: {:reply, encode_reply!(reply, reply_mod), socket}

  def __encode_result__({:stop, reason, reply, socket}, reply_mod),
    do: {:stop, reason, encode_reply!(reply, reply_mod), socket}

  def __encode_result__({:noreply, _socket} = passthrough, _reply_mod), do: passthrough
  def __encode_result__({:noreply, _socket, _timeout} = passthrough, _reply_mod), do: passthrough
  def __encode_result__({:stop, _reason, _socket} = passthrough, _reply_mod), do: passthrough

  @doc false
  @spec check_duplicates!(module(), String.t(), [tuple()]) :: :ok
  def check_duplicates!(channel_mod, macro_name, entries) do
    duplicates =
      for {event, count} <- Enum.frequencies_by(entries, &elem(&1, 0)), count > 1, do: event

    if duplicates != [] do
      raise ArgumentError,
            "ProtoChannel in #{inspect(channel_mod)}: duplicate #{macro_name} event(s) #{inspect(duplicates)}"
    end

    :ok
  end

  @doc false
  @spec validate_proto_module!(module(), term(), atom(), term()) :: :ok
  def validate_proto_module!(channel_mod, event, role, module) when not is_atom(module) do
    # Code.ensure_compiled/1 only accepts atoms; fail with a readable message instead.
    raise ArgumentError,
          "ProtoChannel in #{inspect(channel_mod)}: " <>
            "event #{inspect(event)} (#{role}: #{inspect(module)}) " <>
            "must be a module name."
  end

  def validate_proto_module!(channel_mod, event, role, module) do
    case Code.ensure_compiled(module) do
      {:module, ^module} ->
        :ok

      {:error, reason} ->
        raise ArgumentError,
              "ProtoChannel in #{inspect(channel_mod)}: " <>
                "event #{inspect(event)} (#{role}: #{inspect(module)}) " <>
                "could not be compiled (#{inspect(reason)})."
    end

    if not function_exported?(module, :__message_props__, 0) do
      raise ArgumentError,
            "ProtoChannel in #{inspect(channel_mod)}: " <>
              "event #{inspect(event)} (#{role}: #{inspect(module)}) " <>
              "must `use Protobuf` — got a module without __message_props__/0."
    end

    :ok
  end

  # Accumulated attributes come back most-recent-first; reverse them so the
  # generated clauses follow declaration order.
  defp declarations(channel, attribute) do
    channel
    |> Module.get_attribute(attribute)
    |> Enum.reverse()
  end

  # Validates both message modules of a proto_message/2 declaration and builds
  # the handle_in/3 clause that decodes, dispatches, and encodes the reply.
  defp handle_in_clause(channel, {event, request_mod, reply_mod}) do
    validate_proto_module!(channel, event, :request, request_mod)
    validate_proto_module!(channel, event, :reply, reply_mod)

    quote do
      def handle_in(unquote(event), {:binary, bin}, socket) do
        request = unquote(request_mod).decode(bin)
        result = handle_proto(unquote(event), request, socket)
        ProtoChannel.__encode_result__(result, unquote(reply_mod))
      end
    end
  end

  # Builds all clauses of one wrapper (`fun`/3): a typed clause per declared
  # event, followed by a fallback that forwards undeclared events.
  defp wrapper_clauses(fun, declared) do
    typed =
      for {event, mod} <- declared do
        quote do
          def unquote(fun)(socket, unquote(event), %unquote(mod){} = msg) do
            Phoenix.Channel.unquote(fun)(
              socket,
              unquote(event),
              {:binary, unquote(mod).encode(msg)}
            )
          end
        end
      end

    typed ++ [fallback_clause(fun, Enum.map(declared, &elem(&1, 0)))]
  end

  # Forwards every call to the same-named Phoenix.Channel function. When some
  # events are declared, the guard excludes them so that a declared event
  # called with the wrong payload still ends in a FunctionClauseError.
  defp fallback_clause(fun, []) do
    quote do
      def unquote(fun)(socket, event, payload) do
        Phoenix.Channel.unquote(fun)(socket, event, payload)
      end
    end
  end

  defp fallback_clause(fun, declared_events) do
    quote do
      def unquote(fun)(socket, event, payload) when event not in unquote(declared_events) do
        Phoenix.Channel.unquote(fun)(socket, event, payload)
      end
    end
  end

  # Bare status atoms pass through; `{status, struct}` is encoded only when the
  # struct belongs to the declared reply module (the repeated `reply_mod`
  # variable in the head enforces that).
  defp encode_reply!(status, _reply_mod) when status in [:ok, :error], do: status

  defp encode_reply!({status, %reply_mod{} = reply}, reply_mod) when status in [:ok, :error] do
    {status, {:binary, reply_mod.encode(reply)}}
  end

  defp encode_reply!(other, reply_mod) do
    raise ArgumentError,
          "ProtoChannel: handle_proto/3 must reply with :ok, :error, or " <>
            "{:ok | :error, %#{inspect(reply_mod)}{}}, got: #{inspect(other)}"
  end
end