defmodule Lockstep.Rewriter do
@moduledoc """
Compile-time AST rewriter that converts vanilla OTP calls into their
`Lockstep.*` equivalents inside a `ctest` body. Lets test bodies
read like ordinary Elixir without sacrificing controlled scheduling.
Enabled with `use Lockstep.Test, rewrite: true`. With `rewrite: true`
the linter is also disabled (since the rewritten body is correct by
construction).
## Mappings
vanilla rewritten
------- ---------
GenServer.call(s, m) Lockstep.GenServer.call(s, m)
GenServer.cast(s, m) Lockstep.GenServer.cast(s, m)
GenServer.start_link(M, A) Lockstep.GenServer.start_link(M, A)
Task.async(fn -> ... end) Lockstep.Task.async(fn -> ... end)
Task.await(t) Lockstep.Task.await(t)
Task.await_many(ts) Lockstep.Task.await_many(ts)
spawn(fun) Lockstep.spawn(fun)
spawn_link(fun) Lockstep.spawn_link(fun)
send(p, m) Lockstep.send(p, m)
Kernel.send(p, m) Lockstep.send(p, m)
:erlang.send(p, m) Lockstep.send(p, m)
Process.send_after(p, m, ms) Lockstep.send_after(p, m, ms)
Process.cancel_timer(ref) Lockstep.cancel_timer(ref)
Process.monitor(p) Lockstep.monitor(p)
Process.demonitor(ref[, opts]) Lockstep.demonitor(ref[, opts])
Process.alive?(p) Lockstep.alive?(p)
Process.sleep(ms) Lockstep.sleep(ms)
Process.link(p) Lockstep.link(p)
Process.unlink(p) Lockstep.unlink(p)
Process.flag(:trap_exit, b) Lockstep.flag(:trap_exit, b)
receive do clauses end rewritten to Lockstep.recv_first/1
## Limitations
* Only the *body* of `ctest` is rewritten, not helper functions
defined elsewhere. Inline race code into the test body, or call
Lockstep wrappers directly in helpers.
* `receive ... after t -> ... end` is rewritten by arming a Lockstep
timer that delivers the sentinel message `:__lockstep_after__`, so
test code must not send that atom itself.
* `:erlang.spawn/3` (the 3-arg MFA form) is not rewritten.
* Functions defined locally that happen to be named `send/2` etc.
will be rewritten too. Suppress by qualifying the call (e.g.
`MyMod.send(...)` is left alone).
"""
@doc """
Walks `ast` and returns the rewritten AST.

Supported vanilla OTP calls are replaced by their `Lockstep.*`
counterparts; definition heads and module attributes are returned
untouched.
"""
@spec rewrite(Macro.t()) :: Macro.t()
def rewrite(ast) do
  walk(ast)
end
# Custom walker that skips constructs whose call-shaped AST nodes
# are not actual function calls:
#
# * `def send(a, b), do: ...` -- the head is a pattern, not a call
# (Mutex literally has this).
# * `@spec send(t) :: result` / `@type` / `@callback` -- module
# attributes use the same AST shape but represent declarations.
#
# Both subtrees are returned unchanged.
# Definition with a body: the head is a pattern, not a call, so it is
# kept verbatim and only the body is walked.
defp walk({op, meta, [head, body]}) when op in [:def, :defp, :defmacro, :defmacrop] do
  {op, meta, [head, walk(body)]}
end

# Bodiless definition head, e.g. `def send(pid, msg \\ nil)`. Its AST
# carries a single argument (the head), so it would otherwise reach the
# generic walker and have the head rewritten into a `Lockstep.*` call,
# which is not a valid definition. Return it unchanged.
defp walk({op, _meta, [_head]} = node) when op in [:def, :defp, :defmacro, :defmacrop] do
  node
end

# Module attributes (`@spec`, `@type`, `@callback`, ...) are declarations
# that merely share the call shape; never descend into them.
defp walk({:@, _meta, _args} = node), do: node
# Rewrite `{:via, Registry, args}` → `{:via, Lockstep.Registry, args}`.
#
# OTP's `:via` contract dispatches to the via module's `register_name/2`
# (and friends) inside `gen_server` / `proc_lib` internals — code we
# can't rewrite. So we have to fix it at the call site: change the
# tuple itself so OTP calls `Lockstep.Registry.register_name/2`, which
# we control. `Lockstep.Registry` already implements the full :via
# callback contract (`register_name/2`, `unregister_name/1`,
# `whereis_name/1`, `send/2`).
#
# The 3-tuple `{:via, Registry, term}` lives in AST as
# `{:{}, meta, [:via, {:__aliases__, _, [:Registry]}, term_ast]}`.
# Literal `{:via, Registry, name}` tuple: swap the registry alias for
# `Lockstep.Registry` (keeping the alias metadata) and walk the name term.
defp walk({:{}, tuple_meta, [:via, {:__aliases__, alias_meta, [:Registry]}, name_ast]}) do
  lockstep_registry = {:__aliases__, alias_meta, [:Lockstep, :Registry]}
  {:{}, tuple_meta, [:via, lockstep_registry, walk(name_ast)]}
end
# Local call / operator / special form: walk the arguments first, then
# give the rebuilt node to rewrite_node/1.
defp walk({op, meta, args}) when is_atom(op) and is_list(args) do
  rewrite_node({op, meta, walk(args)})
end

# Call whose callee is itself an AST node (e.g. `Mod.fun(...)`): walk the
# callee, then the arguments, then try to rewrite the whole call.
defp walk({callee, meta, args}) when is_list(args) do
  walked_callee = walk(callee)
  rewrite_node({walked_callee, meta, walk(args)})
end

# Two-element tuples (keyword pairs, literal pairs) are walked element-wise.
defp walk({left, right}), do: {walk(left), walk(right)}

# Lists are walked element by element.
defp walk(nodes) when is_list(nodes), do: Enum.map(nodes, fn node -> walk(node) end)

# Anything else (literals, variables) is a leaf.
defp walk(leaf), do: leaf
# ============================================================
# Qualified calls
# ============================================================
# All qualified Module.fn calls are dispatched via one cond. Earlier
# we had multiple `rewrite_node({{:., _, ...}, ...})` clauses that
# all matched the same outer pattern; whichever was first won. That
# made it impossible to add new mappings without rewriting an
# earlier clause's "if module else node" guard. The cond approach
# collapses them all.
# Remote call `callee.fun(args)`. Returns the equivalent `Lockstep.*`
# call when the module/function (and, where checked, the arity) has a
# mapping; otherwise returns `node` unchanged.
defp rewrite_node({{:., dot_meta, [callee, fun]}, call_meta, args} = node) do
  arity = length(args)

  # Elixir modules are matched through module_matches?/2; Erlang modules
  # are compared against the bare callee atom.
  elixir_mod? = fn module -> module_matches?(callee, module) end
  erlang_mod? = fn module -> callee == module end

  # Every mapping except Process.send/3 keeps the function name and the
  # argument list, and only swaps the module prefix.
  to = fn prefix -> qualify(dot_meta, call_meta, prefix, fun, args) end

  cond do
    # GenServer.{call, cast, start_link, stop, start, reply}
    elixir_mod?.(GenServer) and fun in [:call, :cast, :start_link, :stop, :start, :reply] ->
      to.([:Lockstep, :GenServer])

    # Registry.{start_link, register, unregister, lookup, count,
    #           keys, dispatch, meta, put_meta, select}
    elixir_mod?.(Registry) and
        fun in [
          :start_link,
          :register,
          :unregister,
          :lookup,
          :count,
          :keys,
          :dispatch,
          :meta,
          :put_meta,
          :select
        ] ->
      to.([:Lockstep, :Registry])

    # Supervisor.{start_link, init, which_children, count_children,
    #             start_child, terminate_child, restart_child}
    elixir_mod?.(Supervisor) and
        fun in [
          :start_link,
          :init,
          :which_children,
          :count_children,
          :start_child,
          :terminate_child,
          :restart_child
        ] ->
      to.([:Lockstep, :Supervisor])

    # Task.{async, await, await_many, async_stream}
    elixir_mod?.(Task) and fun in [:async, :await, :await_many, :async_stream] ->
      to.([:Lockstep, :Task])

    # Task.Supervisor.{start_link, async, start_child, async_stream}
    elixir_mod?.(Task.Supervisor) and
        fun in [:start_link, :async, :start_child, :async_stream] ->
      to.([:Lockstep, :Task, :Supervisor])

    # Agent.{start_link, start, get, update, get_and_update, cast, stop}
    elixir_mod?.(Agent) and
        fun in [:start_link, :start, :get, :update, :get_and_update, :cast, :stop] ->
      to.([:Lockstep, :Agent])

    # :ets.* operations: insert sync points so read-modify-write
    # races can be interleaved by the strategy. The full set of
    # wrappable ops is in `Lockstep.ETS`.
    erlang_mod?.(:ets) and
        fun in [
          :new,
          :insert,
          :insert_new,
          :lookup,
          :lookup_element,
          :member,
          :delete,
          :delete_all_objects,
          :match_delete,
          :update_counter,
          :update_element,
          :select,
          :match,
          :match_object,
          :tab2list,
          :info,
          :first,
          :last,
          :next,
          :prev,
          :take,
          :select_count,
          :safe_fixtable
        ] ->
      to.([:Lockstep, :ETS])

    # :atomics.* operations: same idea as ETS. Atomics are atomic
    # individually but compositions (read-then-update) aren't.
    erlang_mod?.(:atomics) and
        fun in [
          :new,
          :get,
          :put,
          :add,
          :sub,
          :add_get,
          :sub_get,
          :exchange,
          :compare_exchange,
          :info
        ] ->
      to.([:Lockstep, :Atomics])

    # :persistent_term.* operations.
    erlang_mod?.(:persistent_term) and fun in [:get, :put, :erase, :info] ->
      to.([:Lockstep, :PersistentTerm])

    # :rpc.* -- cross-node call, multicall, cast, abcast.
    erlang_mod?.(:rpc) and fun in [:call, :multicall, :cast, :abcast] ->
      to.([:Lockstep, :RPC])

    # :global.* -- cluster-wide name registry.
    erlang_mod?.(:global) and
        fun in [:register_name, :unregister_name, :whereis_name, :sync, :send] ->
      to.([:Lockstep, :Global])

    # Task.{start_link, start}/1,3 -> Lockstep.Task.{start_link, start}
    # which return {:ok, pid} matching the OTP shape.
    elixir_mod?.(Task) and fun in [:start_link, :start] and arity in [1, 3] ->
      to.([:Lockstep, :Task])

    # Process.send_after/3
    elixir_mod?.(Process) and fun == :send_after and arity == 3 ->
      to.([:Lockstep])

    # Process.{cancel_timer, monitor, alive?, sleep, link, unlink}/1
    elixir_mod?.(Process) and arity == 1 and
        fun in [:cancel_timer, :monitor, :alive?, :sleep, :link, :unlink] ->
      to.([:Lockstep])

    # Process.demonitor/{1,2}
    elixir_mod?.(Process) and fun == :demonitor and arity in [1, 2] ->
      to.([:Lockstep])

    # Process.flag/2
    elixir_mod?.(Process) and fun == :flag and arity == 2 ->
      to.([:Lockstep])

    # Process.{whereis, unregister, set_label}/1 -- per-node name lookup
    # and removal; set_label is OTP 27+ and is routed through
    # Lockstep.Process, which no-ops on older OTPs.
    elixir_mod?.(Process) and arity == 1 and fun in [:whereis, :unregister, :set_label] ->
      to.([:Lockstep, :Process])

    # Process.register/2 -- per-node name registration
    elixir_mod?.(Process) and fun == :register and arity == 2 ->
      to.([:Lockstep, :Process])

    # Process.registered/0
    elixir_mod?.(Process) and fun == :registered and arity == 0 ->
      to.([:Lockstep, :Process])

    # Process.send/3 (drops the opts arg)
    elixir_mod?.(Process) and fun == :send and arity == 3 ->
      [pid, msg, _opts] = args
      qualify(dot_meta, call_meta, [:Lockstep], :send, [pid, msg])

    # Kernel.send/2 / :erlang.send/2
    fun == :send and arity == 2 and (elixir_mod?.(Kernel) or erlang_mod?.(:erlang)) ->
      to.([:Lockstep])

    true ->
      node
  end
end
# ============================================================
# Unqualified calls (Kernel-imported)
# ============================================================
# send(p, m)
# send(p, m) -> Lockstep.send(p, m). Only the two-argument form matches.
defp rewrite_node({:send, meta, [_dest, _msg] = args}) do
  lockstep = {:__aliases__, [], [:Lockstep]}
  {{:., meta, [lockstep, :send]}, meta, args}
end

# spawn(fun) / spawn_link(fun) -> Lockstep.spawn(fun) / Lockstep.spawn_link(fun).
# Only the single-argument forms match.
defp rewrite_node({spawner, meta, [fun]}) when spawner in [:spawn, :spawn_link] do
  lockstep = {:__aliases__, [], [:Lockstep]}
  {{:., meta, [lockstep, spawner]}, meta, [fun]}
end
# ============================================================
# receive ... end
# ============================================================
# Bare receive without `after`. Rewrite to either
# Lockstep.recv() (if the only clause is `_ -> body`)
# or
# case Lockstep.recv_first(matcher) do clauses end
# for multi-clause receives.
# Bare `receive` (no `after`). A single clause whose pattern is `_` or a
# plain variable becomes a direct `Lockstep.recv()`; every other shape is
# delegated to rewrite_multi_clause/2.
defp rewrite_node({:receive, meta, [[do: clauses]]}) when is_list(clauses) do
  case clauses do
    [{:->, _, [[pattern], body]}] ->
      cond do
        # `receive do _ -> body end` -- catch-all, message is discarded.
        match?({:_, _, _}, pattern) ->
          quote do
            _ = Lockstep.recv()
            unquote(body)
          end

        # `receive do msg -> body end` -- bind the variable directly.
        simple_var?(pattern) ->
          quote do
            unquote(pattern) = Lockstep.recv()
            unquote(body)
          end

        # Single clause with a structured or guarded pattern.
        true ->
          rewrite_multi_clause(clauses, meta)
      end

    _ ->
      rewrite_multi_clause(clauses, meta)
  end
end
# `receive ... after timeout -> ... end`. Rewrite to schedule a
# Lockstep timer, recv_first across both the user patterns and the
# timeout sentinel, then dispatch via case. If the timeout doesn't
# fire, cancel it.
# `receive ... after timeout -> ... end` with exactly one `after` clause.
#
# The generated code arms a Lockstep timer that sends the sentinel atom
# to `self()`, waits with `Lockstep.recv_first/1` for either a message
# matching a user pattern or the sentinel, then dispatches through a
# `case`: the sentinel runs the `after` body, while every user clause
# first cancels the timer and then runs its own body.
defp rewrite_node(
       {:receive, _meta, [[do: clauses, after: [{:->, _, [[timeout_expr], after_body]}]]]}
     )
     when is_list(clauses) do
  sentinel = :__lockstep_after__

  # One `pattern -> true` matcher clause per user clause. Pattern
  # variables are underscored unless the clause's guard refers to them.
  user_matchers =
    Enum.map(clauses, fn
      {:->, m, [[{:when, when_meta, [pat, guard]}], _body]} ->
        kept = collect_vars(guard)
        {:->, m, [[{:when, when_meta, [anonymize_pattern(pat, kept), guard]}], true]}

      {:->, m, [[head], _body]} ->
        {:->, m, [[anonymize_pattern(head, MapSet.new())], true]}
    end)

  sentinel_matcher = {:->, [], [[sentinel], true]}
  reject_rest = {:->, [], [[{:_, [], nil}], false]}
  matcher_clauses = user_matchers ++ [sentinel_matcher, reject_rest]

  matcher =
    quote do
      fn lockstep_match_arg ->
        case lockstep_match_arg do
          unquote(matcher_clauses)
        end
      end
    end

  # The sentinel clause goes first so the timeout wins even when a user
  # clause is a catch-all.
  timeout_clause = {:->, [], [[sentinel], after_body]}

  user_clauses =
    Enum.map(clauses, fn {:->, m, [head, body]} ->
      cancelling_body =
        quote do
          _ = Lockstep.cancel_timer(lockstep_after_timer)
          unquote(body)
        end

      {:->, m, [head, cancelling_body]}
    end)

  case_clauses = [timeout_clause | user_clauses]

  quote do
    lockstep_after_timer =
      Lockstep.send_after(self(), unquote(sentinel), unquote(timeout_expr))

    lockstep_after_msg = Lockstep.recv_first(unquote(matcher))

    case lockstep_after_msg do
      unquote(case_clauses)
    end
  end
end

# Anything without a mapping is returned as-is.
defp rewrite_node(unchanged), do: unchanged
# Builds the AST of the remote call `Prefix.fun(args)`, where the alias is
# made of `prefix_parts`. The call metadata `m2` is used for both the dot
# node and the call node; the original dot metadata (`_m`) is ignored.
defp qualify(_m, m2, prefix_parts, fun, args) do
  target = {:__aliases__, [], prefix_parts}
  remote = {:., m2, [target, fun]}
  {remote, m2, args}
end
# ============================================================
# Helpers
# ============================================================
# True when `ast` has the shape of a plain variable node: a 3-tuple whose
# name and context are both atoms.
defp simple_var?(ast) do
  match?({name, _meta, ctx} when is_atom(name) and is_atom(ctx), ast)
end
# Alias AST (`{:__aliases__, _, parts}`): compare the concatenated alias
# with `target`. Parts that cannot be concatenated (e.g. an embedded
# `__MODULE__` node) raise, which is treated as "no match".
defp module_matches?({:__aliases__, _, segments}, target) when is_atom(target) do
  Module.concat(segments) == target
rescue
  _ -> false
end

# Callee already given as a plain atom: it matches only the identical atom.
defp module_matches?(callee, target) when is_atom(callee), do: callee == target

# Any other callee shape (variable, call, ...) never matches.
defp module_matches?(_callee, _target), do: false
# Generic receive rewrite: wait for the first message accepted by the
# matcher built from `clauses`, then dispatch on it with the user's own
# clauses. The receive metadata is not used.
defp rewrite_multi_clause(clauses, _meta) do
  accepts = build_matcher(clauses)

  quote do
    lockstep_msg = Lockstep.recv_first(unquote(accepts))

    case lockstep_msg do
      unquote(clauses)
    end
  end
end
# Build `fn msg -> case msg do <patterns -> true>; _ -> false end end`
# where each pattern is the user's pattern (with vars underscored to
# silence "unused variable" warnings in the matcher).
#
# When the original clause has a guard (`pattern when guard`), we
# keep the pattern's variables as-is *only* for those vars referenced
# in the guard -- since the guard captures them from the matcher fun's
# scope and any rename would break the closure reference. Other
# pattern vars are still anonymized.
# Returns the AST of a one-argument function that answers `true` when its
# argument matches any of the user's receive patterns and `false`
# otherwise. Guards are preserved; variables a guard refers to keep their
# names, all other pattern variables are underscored.
defp build_matcher(clauses) do
  to_matcher_clause = fn
    {:->, m, [[{:when, when_meta, [pat, guard]}], _body]} ->
      kept = collect_vars(guard)
      {:->, m, [[{:when, when_meta, [anonymize_pattern(pat, kept), guard]}], true]}

    {:->, m, [[head], _body]} ->
      {:->, m, [[anonymize_pattern(head, MapSet.new())], true]}
  end

  # Trailing catch-all so non-matching messages yield `false`.
  catch_all = {:->, [], [[{:_, [], nil}], false]}
  all_clauses = Enum.map(clauses, to_matcher_clause) ++ [catch_all]

  quote do
    fn lockstep_match_arg ->
      case lockstep_match_arg do
        unquote(all_clauses)
      end
    end
  end
end
# Collect the names of every variable referenced in `expr`.
# Returns a MapSet with the name of every variable-shaped node in `expr`.
# Names starting with an underscore (including `_` itself) are skipped.
defp collect_vars(expr) do
  gather = fn
    {name, _meta, ctx} = node, seen when is_atom(name) and is_atom(ctx) ->
      if String.starts_with?(Atom.to_string(name), "_") do
        {node, seen}
      else
        {node, MapSet.put(seen, name)}
      end

    node, seen ->
      {node, seen}
  end

  expr
  |> Macro.prewalk(MapSet.new(), gather)
  |> elem(1)
end
# Walk a pattern AST and underscore every unpinned variable, except
# those names listed in `keep` (typically vars referenced by a
# surrounding `when` guard).
# `^x` (pinned) is preserved literally so it still matches the outer
# `x`. Function calls (which look like `{atom, _, [..]}` with a list
# context) are also preserved.
# Returns `pattern` with every free variable renamed to `_name`, except:
#
#   * names in `keep` (a MapSet, typically the variables a surrounding
#     `when` guard refers to);
#   * names that already start with an underscore;
#   * anything under a pin (`^x`);
#   * anything in a bitstring segment specifier (the right-hand side of
#     `::`). Type names such as `binary` or `integer` are variable-shaped
#     AST nodes, so renaming them would yield an invalid pattern;
#   * variables that a segment specifier refers to (e.g. `n` in
#     `<<n::8, rest::binary-size(n)>>`), so the reference still resolves.
defp anonymize_pattern(pattern, keep) do
  keep = MapSet.union(keep, specifier_vars(pattern))

  # The accumulator is {protected_depth, keep}; a depth above zero means
  # the traversal is inside a pin or a segment specifier.
  {result, _} = Macro.traverse(pattern, {0, keep}, &enter_pattern/2, &exit_pattern/2)
  result
end

# Collects the names of all variable-shaped nodes found on the right-hand
# side of any `::` in `pattern`.
defp specifier_vars(pattern) do
  collect_names = fn
    {name, _, ctx} = node, names when is_atom(name) and is_atom(ctx) ->
      {node, MapSet.put(names, name)}

    node, names ->
      {node, names}
  end

  {_, names} =
    Macro.prewalk(pattern, MapSet.new(), fn
      {:"::", _, [_value, spec]} = node, acc ->
        {_, acc} = Macro.prewalk(spec, acc, collect_names)
        {node, acc}

      node, acc ->
        {node, acc}
    end)

  names
end

# Entering a pin: everything below is protected until the matching exit.
defp enter_pattern({:^, _meta, [_inner]} = pin, {depth, keep}) do
  {pin, {depth + 1, keep}}
end

# Entering a bitstring segment: wrap the specifier in a marker node so
# that only the specifier (not the segment value) is protected.
defp enter_pattern({:"::", meta, [value, spec]}, acc) do
  {{:"::", meta, [value, {:__lockstep_spec__, [], [spec]}]}, acc}
end

defp enter_pattern({:__lockstep_spec__, _meta, [_spec]} = marker, {depth, keep}) do
  {marker, {depth + 1, keep}}
end

defp enter_pattern(other, acc), do: {other, acc}

# Leaving a pin: drop one level of protection.
defp exit_pattern({:^, _meta, [_inner]} = pin, {depth, keep}) do
  {pin, {depth - 1, keep}}
end

# Leaving a specifier: unwrap the marker and drop one level of protection.
defp exit_pattern({:__lockstep_spec__, _meta, [spec]}, {depth, keep}) do
  {spec, {depth - 1, keep}}
end

defp exit_pattern({name, meta, ctx} = var, {depth, keep} = acc)
     when is_atom(name) and is_atom(ctx) do
  str = Atom.to_string(name)

  cond do
    # Inside a pin or a segment specifier -- don't touch.
    depth > 0 ->
      {var, acc}

    # Already underscored (covers `_` itself).
    String.starts_with?(str, "_") ->
      {var, acc}

    # Referenced by a guard or a segment specifier -- keep the original
    # name so that reference still resolves.
    MapSet.member?(keep, name) ->
      {var, acc}

    true ->
      {{:"_#{str}", meta, ctx}, acc}
  end
end

defp exit_pattern(other, acc), do: {other, acc}
end