defmodule Maelstrom do
@moduledoc """
Allows you to create servers which implement the [Maelstrom protocol](https://github.com/jepsen-io/maelstrom/blob/main/doc/protocol.md).
[Maelstrom](https://github.com/jepsen-io/maelstrom) is a workbench for learning distributed systems by writing your own.
## Usage
To implement a server:
1. Create a module and `use Maelstrom`
2. Implement one or more `handle_message` heads.
3. Call `MyModule.run_forever()`.
The simplest way to do this is in an `.exs` script. Example:
```
defmodule Echo do
use Maelstrom
def handle_message(_src, _dest, %{"echo" => echo}, state, _) do
{:reply, %{"type" => "echo_ok", "echo" => echo}, state}
end
end
Echo.run_forever()
```
You could then run this with `mix run` e.g. `mix run echo.exs`.
> #### Tip {: .tip}
> Maelstrom expects a single binary with no arguments to call for testing.
> In order to accomplish this, wrap your mix run command in a shell script (see [demos](https://github.com/prehnRA/maelstrom_ex/tree/main/demo) for examples).
## Examples
For more examples, see [demos](https://github.com/prehnRA/maelstrom_ex/tree/main/demo).
"""
@typedoc "Identifier of a Maelstrom node, as found in the `src` and `dest` fields of a message."
@type node_id() :: String.t()

@typedoc "Identifier of a message, as found in the `msg_id` and `in_reply_to` body fields."
@type msg_id() :: non_neg_integer()

@typedoc "Numeric error code carried in the `code` field of an error body."
@type error_code() :: non_neg_integer()

# NOTE: these three were previously declared as `Map`, which in a typespec is the
# literal atom `Map` (a module name), not "any map". `map()` is the built-in map type.
@typedoc "Decoded JSON body of a message."
@type body() :: map()

@typedoc "State owned by the server module that implements `handle_message/5`."
@type state() :: map()

@typedoc """
Bookkeeping kept by the library for a node. It starts as `%{msg_id: 1, rpcs: %{}}`
and gains `:node_id` and `:node_ids` once the `init` message has been handled.
"""
@type node_state() :: map()

@typedoc """
Value returned by `handle_message/5` and by RPC callbacks:

  * `{:reply, body, state}` - send `body` back to the sender, keep `state`
  * `{:noreply, state}` - send nothing, keep `state`
  * `{:error, code, text, state}` - send an error reply, keep `state`
"""
@type handler_result() ::
        {:reply, body(), state()}
        | {:noreply, state()}
        | {:error, error_code(), String.t(), state()}

@typedoc """
Callback given to `send_rpc/4`. Receives the body of the original request, the body
of the reply, and the current server state.
"""
@type rpc_callback() :: (body(), body(), state() -> handler_result())

@doc """
Handles one incoming message.

Invoked for every message whose body carries a `msg_id`, except the `init`
handshake and replies to RPCs (bodies carrying `in_reply_to`), which the library
handles itself. Arguments, in order: the sender (`src`), the recipient (`dest`),
the message body, the server's own state, and the library's node state.

The returned `t:handler_result/0` decides whether a reply or an error is sent back
to the sender and what the new server state is.
"""
@callback handle_message(
            src :: node_id(),
            dest :: node_id(),
            body :: body(),
            state :: state(),
            node_state :: node_state()
          ) :: handler_result()
@doc """
Starts the server and blocks while it reads input, keeping the script alive.

This function is generated inside your own module when you `use Maelstrom`, and that
generated version is the one to call: `MyModule.run_forever()`. The definition on
`Maelstrom` itself exists only so the function shows up in the documentation; calling
it raises a `RuntimeError`.
"""
@spec run_forever(state()) :: :ok
def run_forever(_initial_state \\ %{}) do
  # Documentation-only stub: the working implementation is injected by `use Maelstrom`.
  raise RuntimeError,
        "Call this function directly on your server module e.g. MyModule.run_forever()."
end
@doc false
# Reads standard input line by line in a background task, decodes each line as JSON
# and forwards the decoded message to the server process `pid` as a GenServer cast.
# Returns the task so the caller can await it (it finishes when the input stream ends).
# A line that is not valid JSON makes `Jason.decode!/1` raise and crashes the task.
@spec take_input_forever(GenServer.server()) :: Task.t()
def take_input_forever(pid) do
  stream = IO.stream()

  Task.async(fn ->
    Enum.each(stream, fn line ->
      message = Jason.decode!(line)

      # Do not use `log/1` here: it casts to `self()`, which inside this task is the
      # task process rather than the server, so the entry would never be printed and
      # would pile up in the task's mailbox. Address the server explicitly instead.
      GenServer.cast(pid, {:log, "Received: #{inspect(message)}"})
      GenServer.cast(pid, message)
    end)
  end)
end
@doc """
Queues `log` to be written to standard error by the current server process.

The request is cast to `self()`, so this must be called from inside the server
process (for example from `handle_message/5` or an RPC callback). `adapter` is the
module whose `cast/2` delivers the request; it defaults to `GenServer`.
"""
@spec log(String.t(), atom()) :: :ok
def log(log, adapter \\ GenServer) do
  request = {:log, log}
  adapter.cast(self(), request)
end
@doc """
Sends `body` to `dest` as an RPC and registers `callback` to run when `dest` replies.

When the reply arrives, `callback` is called with three arguments:

  1. the body of the original request,
  2. the body of the reply,
  3. the current state of your server.

It must return a handler result, exactly like `handle_message/5`:

  * `{:reply, reply_body, new_server_state}`
  * `{:noreply, new_server_state}`
  * `{:error, error_code, error_message, new_server_state}`

The request is cast to `self()`, so call this from inside the server process.
`adapter` is the module whose `cast/2` delivers the request (default `GenServer`).
"""
@spec send_rpc(node_id(), body(), rpc_callback(), atom()) :: :ok
def send_rpc(dest, body, callback, adapter \\ GenServer) do
  request = {:send_rpc, dest, body, callback}
  adapter.cast(self(), request)
end
@doc """
Sends a message carrying `body` to the node `dest`.

The request is cast to `self()`, so call this from inside the server process.
`adapter` is the module whose `cast/2` delivers the request (default `GenServer`).
"""
@spec send_message(node_id(), body(), atom()) :: :ok
def send_message(dest, body, adapter \\ GenServer) do
  request = {:send_message, dest, body}
  adapter.cast(self(), request)
end
@doc """
Replies to `replyee` with `body`, marking it as the answer to message `in_reply_to`.

The body is sent through `send_message/3` with its `"in_reply_to"` field set to
`in_reply_to` (replacing any value already present under that key).
"""
@spec send_reply(node_id(), msg_id(), body(), atom()) :: :ok
def send_reply(replyee, in_reply_to, body, adapter \\ GenServer) do
  reply_body = Map.put(body, "in_reply_to", in_reply_to)
  send_message(replyee, reply_body, adapter)
end
@doc """
Replies to `replyee` with an error for the message identified by `in_reply_to`.

The reply body has `"type"` set to `"error"`, `"code"` set to `error_code` (see
[the Maelstrom docs](https://github.com/jepsen-io/maelstrom/blob/main/doc/protocol.md#errors)
for the list of codes) and `"text"` set to the human-readable `error_text`.
"""
@spec send_error(node_id(), msg_id(), error_code(), String.t(), atom()) :: :ok
def send_error(replyee, in_reply_to, error_code, error_text, adapter \\ GenServer) do
  send_reply(
    replyee,
    in_reply_to,
    %{"type" => "error", "code" => error_code, "text" => error_text},
    adapter
  )
end
# Turns the calling module into a Maelstrom server: it becomes a GenServer, gets the
# helpers of this module imported (`send_message/3`, `send_reply/4`, `log/2`, ...),
# and receives the `run_forever/1`, `init/1` and `handle_cast/2` definitions built by
# `setup_cast_handlers/0`. Options are currently ignored.
defmacro __using__(_opts \\ []) do
  handlers = setup_cast_handlers()

  quote do
    use GenServer
    require Logger
    import Maelstrom
    unquote(handlers)
  end
end
@doc false
# Node bookkeeping a freshly started server begins with: the id to use for the next
# outgoing message and an empty table of RPCs awaiting a reply.
def default_node_state, do: %{rpcs: %{}, msg_id: 1}
@doc false
# Builds the quoted definitions that `use Maelstrom` injects into a server module:
# `run_forever/1`, the GenServer `init/1`, and the `handle_cast/2` clauses that
# implement the protocol plumbing (init handshake, RPC replies, dispatch to the
# module's `handle_message/5`, and outgoing messages / logging).
#
# The generated server keeps its state as `{node_state, inner_state}`, where
# `node_state` is the library's bookkeeping (`:msg_id`, `:rpcs`, and after the init
# handshake `:node_id` / `:node_ids`) and `inner_state` belongs to the user module.
def setup_cast_handlers() do
  quote do
    # Starts the server process, then blocks on the task that reads standard input.
    def run_forever(initial_state \\ %{}) do
      {:ok, pid} = GenServer.start_link(__MODULE__, initial_state)
      take_input_forever(pid) |> Task.await(:infinity)
    end

    @impl true
    def init(inner_state \\ %{}) do
      {:ok, {Maelstrom.default_node_state(), inner_state}}
    end

    # Init handshake: remember who we are and acknowledge. The acknowledgement is
    # itself a cast to this process, so it is sent after `:node_id` has been stored.
    @impl true
    def handle_cast(
          %{
            "src" => src,
            "body" => %{
              "type" => "init",
              "msg_id" => msg_id,
              "node_id" => node_id,
              "node_ids" => node_ids
            }
          },
          {node_state, inner_state}
        ) do
      new_node_state = Map.merge(node_state, %{node_id: node_id, node_ids: node_ids})
      send_reply(src, msg_id, %{"type" => "init_ok"})
      {:noreply, {new_node_state, inner_state}}
    end

    # Reply to an RPC we sent earlier: run the callback registered by `send_rpc/4`.
    # The entry is popped so that answered RPCs do not accumulate in `:rpcs`; a reply
    # nobody is waiting for is routed to `warn_missing_rpc_callback/3`.
    def handle_cast(
          %{"src" => _src, "body" => %{"in_reply_to" => in_reply_to} = reply_body} = msg,
          {%{rpcs: rpcs} = node_state, inner_state}
        ) do
      {{original_body, callback}, remaining_rpcs} =
        Map.pop(rpcs, in_reply_to, {nil, &warn_missing_rpc_callback/3})

      result = callback.(original_body, reply_body, inner_state)
      process_callback(result, msg, {%{node_state | rpcs: remaining_rpcs}, inner_state})
    end

    # Any other message with a `msg_id` is handed to the user's `handle_message/5`.
    def handle_cast(
          %{"src" => src, "dest" => dest, "body" => %{"msg_id" => _msg_id} = body} = msg,
          {node_state, inner_state} = state
        ) do
      result = handle_message(src, dest, body, inner_state, node_state)
      process_callback(result, msg, state)
    end

    # Outgoing RPC: send the message, then remember the callback under the id the
    # message was actually sent with.
    def handle_cast(
          {:send_rpc, dest, body, callback},
          {%{rpcs: rpcs, msg_id: next_msg_id}, _inner_state} = state
        ) do
      {:noreply, {sent_node_state, inner_state}} =
        handle_cast({:send_message, dest, body}, state)

      # `:send_message` keeps a caller-supplied "msg_id", so key the RPC the same way.
      rpc_id = Map.get(body, "msg_id", next_msg_id)

      # Build on the post-send node state: it carries the incremented `:msg_id`.
      # Reusing the pre-send state would hand the same id to the next message.
      node_state = %{sent_node_state | rpcs: Map.put(rpcs, rpc_id, {body, callback})}
      {:noreply, {node_state, inner_state}}
    end

    # Outgoing message: wrap the body in an envelope, write it as one JSON line to
    # standard output and advance the message id counter. Requires `:node_id`, i.e.
    # the init handshake must already have happened.
    def handle_cast(
          {:send_message, dest, body},
          {%{node_id: src, msg_id: msg_id} = node_state, inner_state}
        ) do
      message = %{
        "src" => src,
        "dest" => dest,
        "body" => Map.put_new(body, "msg_id", msg_id)
      }

      json = Jason.encode!(message)
      IO.puts(json)
      new_node_state = Map.put(node_state, :msg_id, msg_id + 1)
      {:noreply, {new_node_state, inner_state}}
    end

    def handle_cast({:log, log}, state) do
      IO.warn(log)
      {:noreply, state}
    end

    def handle_cast(other, _state) do
      raise "Unexpected message format: #{inspect(other)}"
    end

    # Turns a handler result into side effects (reply / error reply to the sender of
    # `msg`) and the next GenServer state. The incoming "msg_id" is only looked up
    # when something is sent back, so a `{:noreply, _}` result also works for
    # messages without a "msg_id" (e.g. RPC replies, where the field is optional).
    defp process_callback(
           result,
           %{"src" => src, "body" => incoming_body},
           {node_state, _inner_state}
         ) do
      case result do
        {:reply, reply_body, new_inner_state} ->
          send_reply(src, Map.fetch!(incoming_body, "msg_id"), reply_body)
          {:noreply, {node_state, new_inner_state}}

        {:noreply, new_inner_state} ->
          {:noreply, {node_state, new_inner_state}}

        {:error, error_code, error_text, new_inner_state} ->
          send_error(src, Map.fetch!(incoming_body, "msg_id"), error_code, error_text)
          {:noreply, {node_state, new_inner_state}}
      end
    end

    # Fallback "callback" for a reply whose `in_reply_to` matches no pending RPC.
    defp warn_missing_rpc_callback(_message_body, reply_body, inner_state) do
      IO.warn("Reply for unknown RPC: reply_body=#{inspect(reply_body)}")
      {:noreply, inner_state}
    end
  end
end
end