defmodule Hub.Channel do
@moduledoc """
GenServer that handles a single channel. This serializes publishes, subscribes and unsubscribes on that channel, and
makes sure no race condition can occur.
"""
alias Hub.ChannelRegistry
alias Hub.Subscriber
use GenServer
# Keyword list of options accepted by `subscribe_quoted/3`.
@type subscribe_options :: [subscribe_option]
# `:pid`   - process the subscription is registered for (the channel falls back to the caller when omitted).
# `:count` - number of deliveries before the subscription is removed (falls back to `:infinity`).
# `:multi` - when `true` the pattern must be a list of quoted patterns (falls back to `false`).
@type subscribe_option :: {:pid, pid} | {:count, count} | {:multi, boolean}
# Remaining number of deliveries for a subscription; `:infinity` subscriptions are never decremented.
@type count :: pos_integer | :infinity
# NOTE(review): not referenced by any spec in this module - presumably the pattern a subscription matches on.
@type pattern :: any
# Handle returned on subscribe: the channel pid paired with the subscriber's reference.
@type subscription_ref :: {pid, reference}
# Public API
@doc """
Starts a channel process for `channel_name`.

The name is handed to `init/1`, which registers it with the channel registry;
when the name is already registered `init/1` returns `:ignore`, and so does
this function.
"""
@spec start_link(String.t()) :: GenServer.on_start()
def start_link(channel_name), do: GenServer.start_link(__MODULE__, channel_name)
@doc """
Subscribes to `channel` with a quoted pattern.

## Options

  * `:pid` - process that receives matching messages (defaults to the caller)
  * `:count` - number of messages to deliver before the subscription is
    removed automatically (defaults to `:infinity`)
  * `:multi` - when `true`, `quoted_pattern` must be a list of quoted
    patterns and a message is delivered when any of them matches
    (defaults to `false`)

Returns `{:ok, subscription_ref}` on success, where the reference can be
passed to `unsubscribe/1` or `unsubscribe_and_flush/1`. Returns
`{:error, reason}` when `multi: true` is combined with a non-list pattern.

The quoted pattern is evaluated with `Code.eval_quoted/1`, so it must never
be built from untrusted input.
"""
@spec subscribe_quoted(pid, any, subscribe_options) :: {:ok, subscription_ref} | {:error, reason :: String.t()}
def subscribe_quoted(channel, quoted_pattern, options \\ []) when is_pid(channel) do
  # The helper clauses pattern-match on a map, so convert the keyword list once here.
  do_subscribe_quoted(channel, quoted_pattern, Map.new(options))
end
@doc """
Delivers `message` to every subscriber of the channel whose pattern matches it.

The call is synchronous; the return value is the number of subscribers the
message was sent to.
"""
@spec publish(pid, any) :: non_neg_integer
def publish(channel, message) when is_pid(channel) do
  request = {:publish, message}
  GenServer.call(channel, request)
end
@doc """
Returns every subscriber currently registered on the channel.
"""
@spec subscribers(pid) :: [Subscriber.t()]
def subscribers(channel) when is_pid(channel), do: GenServer.call(channel, :subscribers)
@doc """
Removes the subscription identified by the reference returned on subscribe.

The request is sent asynchronously (`GenServer.cast/2`), so `:ok` is returned
immediately - also when the channel process has already terminated or the
reference is unknown to the channel.
"""
@spec unsubscribe(subscription_ref) :: :ok
def unsubscribe({channel, ref}) when is_pid(channel) and is_reference(ref) do
  # No liveness lookup is needed: `GenServer.whereis/1` returns a pid argument
  # unchanged (it can never be `nil` here), and casting to a dead pid still
  # returns `:ok`.
  GenServer.cast(channel, {:unsubscribe, ref})
end
@doc """
Unsubscribes and then removes every message matching the subscribed
pattern(s) from the mailbox of the calling process.

Returns `:ok` when the subscription is unknown to the channel or when the
channel process no longer exists.

The flush runs a `receive` in the calling process, so it only has an effect
when called from the process whose mailbox received the messages.
"""
@spec unsubscribe_and_flush(subscription_ref) :: :ok
def unsubscribe_and_flush({channel, ref} = subscription_ref) when is_pid(channel) and is_reference(ref) do
  # `GenServer.whereis/1` returns a pid argument unchanged without checking
  # liveness, so it cannot detect a terminated channel. Handle the `:noproc`
  # exit of the call instead and treat it as "nothing subscribed".
  registered =
    try do
      subscribers(channel)
    catch
      :exit, {:noproc, _call} -> []
    end

  case Enum.find(registered, &(&1.ref == ref)) do
    nil ->
      :ok

    subscriber ->
      # NOTE(review): the unsubscribe is a cast, so a message published before
      # the channel processes it can still arrive after the flush below.
      unsubscribe(subscription_ref)
      flush(subscriber)
  end
end
# GenServer callbacks
# Registers the channel name and builds the initial (empty) state.
# Returns `:ignore` when the registry reports the name as already taken.
def init(channel_name) do
  channel_name
  |> ChannelRegistry.register()
  |> case do
    {:duplicate_key, _pid} ->
      :ignore

    :ok ->
      # subscriber_by_ref:  ref => subscriber
      # subscribers_by_pid: pid => %{ref => subscriber}
      {:ok, %{subscriber_by_ref: %{}, subscribers_by_pid: %{}}}
  end
end
# Sends `message` to every subscriber whose pattern matches and replies with
# the number of recipients. Limited-count subscriptions are decremented or
# removed as part of the delivery.
def handle_call({:publish, message}, _from, state) do
  recipients =
    for subscriber <- Map.values(state.subscriber_by_ref),
        publish_to_subscriber?(message, subscriber),
        do: subscriber

  new_state =
    Enum.reduce(recipients, state, fn subscriber, acc ->
      publish_to_subscriber(acc, message, subscriber)
    end)

  {:reply, length(recipients), new_state}
end
# Registers a new subscriber and replies with its subscription reference
# (`{channel_pid, ref}`). The target process is monitored so its
# subscriptions can be dropped when it goes down.
def handle_call({:subscribe_quoted, quoted_pattern, options, caller}, _from, state) do
  target = Map.get(options, :pid, caller)
  limit = Map.get(options, :count, :infinity)
  multi = Map.get(options, :multi, false)

  subscriber = Subscriber.new(target, quoted_pattern, limit, multi)
  Process.monitor(target)

  subscription_ref = {self(), subscriber.ref}
  {:reply, {:ok, subscription_ref}, add_subscriber(state, subscriber)}
end
# Replies with the list of all registered subscribers; state is unchanged.
def handle_call(:subscribers, _from, %{subscriber_by_ref: by_ref} = state) do
  {:reply, Map.values(by_ref), state}
end
# Drops the subscription registered under `ref`; unknown references are ignored.
def handle_cast({:unsubscribe, ref}, state) do
  case state.subscriber_by_ref do
    %{^ref => subscriber} -> {:noreply, remove_subscriber(state, subscriber)}
    _unknown -> {:noreply, state}
  end
end
# A monitored subscriber process went down: remove all of its subscriptions.
def handle_info({:DOWN, _monitor, :process, pid, _reason}, state) do
  orphaned =
    state.subscribers_by_pid
    |> Map.get(pid, %{})
    |> Map.values()

  {:noreply, Enum.reduce(orphaned, state, &remove_subscriber(&2, &1))}
end
# Helpers
# Decides whether `term` should be delivered to the subscriber.
# Multi subscribers hold a list of patterns and match when any of them does.
defp publish_to_subscriber?(term, %{multi: true} = subscriber) do
  Enum.any?(subscriber.pattern, fn pattern -> pattern_match?(pattern, term) end)
end

# Regular subscribers hold a single quoted pattern.
defp publish_to_subscriber?(term, subscriber), do: pattern_match?(subscriber.pattern, term)
# Accounts for one delivery in the state, sends `term` to the subscriber's
# process, and returns the updated state.
defp publish_to_subscriber(state, term, subscriber) do
  new_state = update_subscriber(state, subscriber)
  send(subscriber.pid, term)
  new_state
end
# Accounts for one delivered message on the subscriber's remaining count.
# Unlimited subscriptions are left untouched.
defp update_subscriber(state, %{count: :infinity}), do: state

# The last allowed delivery: the subscription is removed.
defp update_subscriber(state, %{count: 1} = subscriber), do: remove_subscriber(state, subscriber)

# More deliveries remain: store the subscriber again with the count decremented.
defp update_subscriber(state, %{count: remaining} = subscriber) when remaining > 1 do
  decremented = %{subscriber | count: remaining - 1}
  add_subscriber(remove_subscriber(state, subscriber), decremented)
end
# Returns `true` when `term` matches the quoted `pattern`, `false` otherwise.
# The check is done by building a `case` expression around the pattern and
# evaluating it, so an invalid pattern raises from `Code.eval_quoted/1`.
defp pattern_match?(pattern, term) do
  escaped = Macro.escape(term)

  {matched, _bindings} =
    Code.eval_quoted(
      quote do
        case unquote(escaped) do
          unquote(pattern) -> true
          _ -> false
        end
      end
    )

  matched
end
# Validates the quoted pattern(s) in the calling process and then registers
# the subscription with the channel.

# `multi: true` requires a list of patterns.
defp do_subscribe_quoted(_channel, quoted_pattern, %{multi: true}) when not is_list(quoted_pattern) do
  {:error, "Must subscribe with a list of patterns when using multi: true"}
end

defp do_subscribe_quoted(channel, quoted_patterns, %{multi: true} = options) do
  # Validate every pattern on its own, exactly as publishing matches them.
  # Evaluating the list as one list pattern would reject valid elements that
  # carry a `when` guard, because guards are not allowed inside a list pattern.
  Enum.each(quoted_patterns, &pattern_match?(&1, nil))
  GenServer.call(channel, {:subscribe_quoted, quoted_patterns, options, self()})
end

defp do_subscribe_quoted(channel, quoted_pattern, options) do
  # Try to pattern match to catch syntax errors before publishing
  pattern_match?(quoted_pattern, nil)
  GenServer.call(channel, {:subscribe_quoted, quoted_pattern, options, self()})
end
# Stores the subscriber in both indexes: by its reference and, nested, by the
# pid it belongs to.
defp add_subscriber(state, subscriber) do
  by_ref = Map.put(state.subscriber_by_ref, subscriber.ref, subscriber)

  for_pid =
    state.subscribers_by_pid
    |> Map.get(subscriber.pid, %{})
    |> Map.put(subscriber.ref, subscriber)

  by_pid = Map.put(state.subscribers_by_pid, subscriber.pid, for_pid)

  %{state | subscriber_by_ref: by_ref, subscribers_by_pid: by_pid}
end
# Removes the subscriber from both indexes. When it was the last subscription
# of its pid, the pid entry itself is dropped so that `subscribers_by_pid`
# does not accumulate empty maps for processes that unsubscribed or died.
defp remove_subscriber(state, subscriber) do
  remaining =
    state.subscribers_by_pid
    |> Map.get(subscriber.pid, %{})
    |> Map.delete(subscriber.ref)

  subscribers_by_pid =
    if map_size(remaining) == 0 do
      Map.delete(state.subscribers_by_pid, subscriber.pid)
    else
      Map.put(state.subscribers_by_pid, subscriber.pid, remaining)
    end

  %{
    state
    | subscriber_by_ref: Map.delete(state.subscriber_by_ref, subscriber.ref),
      subscribers_by_pid: subscribers_by_pid
  }
end
# Flushes the calling process's mailbox of messages matching the subscriber's
# pattern - or each of its patterns when it is a multi subscriber.
defp flush(subscriber) do
  patterns = if subscriber.multi, do: subscriber.pattern, else: [subscriber.pattern]
  Enum.each(patterns, &do_flush/1)
end
# Removes messages matching `quoted_pattern` from the current process's
# mailbox, one per evaluation, until a zero-timeout `receive` finds no more.
defp do_flush(quoted_pattern) do
  receive_once =
    quote do
      receive do
        unquote(quoted_pattern) -> true
      after
        0 -> false
      end
    end

  case Code.eval_quoted(receive_once) do
    {true, _bindings} -> do_flush(quoted_pattern)
    {false, _bindings} -> :ok
  end
end
end