lib/lockstep/rewriter.ex

defmodule Lockstep.Rewriter do
  @moduledoc """
  Compile-time AST rewriter that converts vanilla OTP calls into their
  `Lockstep.*` equivalents inside a `ctest` body, so test bodies can read
  like ordinary Elixir without sacrificing controlled scheduling.

  Enabled with `use Lockstep.Test, rewrite: true`. With `rewrite: true`
  the linter is also disabled (the rewritten body is meant to be correct
  by construction), so a call the rewriter does not recognise runs
  unscheduled without any warning -- see *Limitations* below.

  Pipelines are expanded before matching, the same way `Kernel.|>/2`
  expands them, so `pid |> send(msg)` and `pid |> Process.monitor()` are
  rewritten just like `send(pid, msg)` and `Process.monitor(pid)`.

  ## Mappings

      vanilla                              rewritten
      -------                              ---------
      GenServer.call(s, m)                 Lockstep.GenServer.call(s, m)
      GenServer.cast(s, m)                 Lockstep.GenServer.cast(s, m)
      GenServer.start_link(M, A)           Lockstep.GenServer.start_link(M, A)
        (also GenServer.start, stop, reply)
      Task.async(fn -> ... end)            Lockstep.Task.async(fn -> ... end)
      Task.await(t)                        Lockstep.Task.await(t)
      Task.await_many(ts)                  Lockstep.Task.await_many(ts)
        (also Task.async_stream, and Task.start / start_link at arity 1 or 3)
      spawn(fun)                           Lockstep.spawn(fun)
      spawn_link(fun)                      Lockstep.spawn_link(fun)
        (also the Kernel.* and :erlang.* forms at arity 1)
      send(p, m)                           Lockstep.send(p, m)
      Kernel.send(p, m)                    Lockstep.send(p, m)
      :erlang.send(p, m)                   Lockstep.send(p, m)
      Process.send(p, m, opts)             Lockstep.send(p, m)  (opts dropped)
      Process.send_after(p, m, ms)         Lockstep.send_after(p, m, ms)
      Process.cancel_timer(ref)            Lockstep.cancel_timer(ref)
      Process.monitor(p)                   Lockstep.monitor(p)
      Process.demonitor(ref[, opts])       Lockstep.demonitor(ref[, opts])
      Process.alive?(p)                    Lockstep.alive?(p)
      Process.sleep(ms)                    Lockstep.sleep(ms)
      Process.link(p)                      Lockstep.link(p)
      Process.unlink(p)                    Lockstep.unlink(p)
      Process.flag(flag, value)            Lockstep.flag(flag, value)
      Process.whereis/1, register/2,       Lockstep.Process.<same function>
        unregister/1, registered/0,
        set_label/1
      Registry, Supervisor,                Lockstep.Registry, Lockstep.Supervisor,
        Task.Supervisor, Agent               Lockstep.Task.Supervisor, Lockstep.Agent
      :ets, :atomics, :persistent_term,    Lockstep.ETS, Lockstep.Atomics,
        :rpc, :global                        Lockstep.PersistentTerm, Lockstep.RPC,
                                             Lockstep.Global
      {:via, Registry, key}                {:via, Lockstep.Registry, key}
      receive do clauses end               case Lockstep.recv() / Lockstep.recv_first/1
      receive ... after t -> ... end       Lockstep.send_after/3 timer + recv_first/1

  For the Registry, Supervisor, Task.Supervisor, Agent and Erlang-module
  rows, only the function names listed in `@families` are rewritten; other
  functions of those modules are left alone.

  ## receive

  A `receive` whose only clause head is a plain variable (or `_`) becomes
  `case Lockstep.recv() do ... end`. Any other `receive` becomes
  `case Lockstep.recv_first(matcher) do ... end`, where `matcher` accepts
  exactly the messages one of the original clauses accepts. Either way the
  original clauses run as `case` clauses, so variables they bind stay
  local to the clause, as in a real `receive`.

  With an `after` clause, a `Lockstep.send_after/3` timer is armed with a
  sentinel message tagged with a fresh reference. If the sentinel is
  received first, the `after` body runs; otherwise the timer is cancelled
  before the matching clause body runs. An `:infinity` timeout arms no
  timer. A `receive` with only an `after` clause is rewritten the same way.

  ## Limitations

    * Only the *body* of `ctest` is rewritten, not helper functions
      defined elsewhere. Inline race code into the test body, or call
      Lockstep wrappers directly in helpers.
    * Captures of unqualified or arity-restricted functions (e.g.
      `&send/2`, `&Process.monitor/1`), `apply/3` and calls through a
      variable receiver are not rewritten.
    * The MFA forms of `spawn/3` and `spawn_link/3` (including
      `:erlang.spawn/3`) are not rewritten.
    * If a `receive ... after` clause matches after its timer has already
      fired, the rewrite does not remove the delivered sentinel. Its fresh
      reference keeps it from satisfying a later `after`, but a later
      catch-all `receive` could still observe it.
    * Module attributes (`@...`) and `def` heads are left untouched.

  Calls to locally defined functions that happen to be named `send/2`,
  `spawn/1` or `spawn_link/1` are rewritten too. Suppress this by
  qualifying the call (e.g. `MyMod.send(...)` is left alone).
  """

  # `def`-family macros: their first argument is a function head -- a
  # pattern, not a call -- so it must never be rewritten.
  @def_kinds [:def, :defp, :defmacro, :defmacrop]

  # Remote calls rewritten at any arity, as
  # {receiver module, function names, alias parts of the replacement}.
  # The function name itself is kept. Rows are checked in order.
  @families [
    {GenServer, ~w(call cast start_link stop start reply)a, [:Lockstep, :GenServer]},
    {Registry,
     ~w(start_link register unregister lookup count keys dispatch meta put_meta select)a,
     [:Lockstep, :Registry]},
    {Supervisor,
     ~w(start_link init which_children count_children start_child terminate_child
        restart_child)a, [:Lockstep, :Supervisor]},
    {Task, ~w(async await await_many async_stream)a, [:Lockstep, :Task]},
    {Task.Supervisor, ~w(start_link async start_child async_stream)a,
     [:Lockstep, :Task, :Supervisor]},
    {Agent, ~w(start_link start get update get_and_update cast stop)a, [:Lockstep, :Agent]},
    # :ets / :atomics ops are atomic individually but read-modify-write
    # compositions are not; the wrappers insert sync points so such races
    # can be interleaved by the strategy (full op set: `Lockstep.ETS`).
    {:ets,
     ~w(new insert insert_new lookup lookup_element member delete delete_all_objects
        match_delete update_counter update_element select match match_object tab2list
        info first last next prev take select_count safe_fixtable)a, [:Lockstep, :ETS]},
    {:atomics, ~w(new get put add sub add_get sub_get exchange compare_exchange info)a,
     [:Lockstep, :Atomics]},
    {:persistent_term, ~w(get put erase info)a, [:Lockstep, :PersistentTerm]},
    # Cross-node call / multicall / cast / abcast.
    {:rpc, ~w(call multicall cast abcast)a, [:Lockstep, :RPC]},
    # Cluster-wide name registry.
    {:global, ~w(register_name unregister_name whereis_name sync send)a, [:Lockstep, :Global]}
  ]

  # Remote calls rewritten only at the listed arities, as
  # {receiver module, function name, arities, alias parts of the replacement}.
  @exact_calls [
    # Lockstep.Task.{start_link, start} return {:ok, pid}, matching OTP.
    {Task, :start_link, [1, 3], [:Lockstep, :Task]},
    {Task, :start, [1, 3], [:Lockstep, :Task]},
    {Process, :send_after, [3], [:Lockstep]},
    {Process, :cancel_timer, [1], [:Lockstep]},
    {Process, :monitor, [1], [:Lockstep]},
    {Process, :demonitor, [1, 2], [:Lockstep]},
    {Process, :alive?, [1], [:Lockstep]},
    {Process, :sleep, [1], [:Lockstep]},
    {Process, :link, [1], [:Lockstep]},
    {Process, :unlink, [1], [:Lockstep]},
    {Process, :flag, [2], [:Lockstep]},
    # Per-node name registration and lookup.
    {Process, :whereis, [1], [:Lockstep, :Process]},
    {Process, :register, [2], [:Lockstep, :Process]},
    {Process, :unregister, [1], [:Lockstep, :Process]},
    {Process, :registered, [0], [:Lockstep, :Process]},
    # OTP 27+; Lockstep.Process.set_label no-ops on older OTPs.
    {Process, :set_label, [1], [:Lockstep, :Process]},
    {Kernel, :send, [2], [:Lockstep]},
    {:erlang, :send, [2], [:Lockstep]},
    {Kernel, :spawn, [1], [:Lockstep]},
    {Kernel, :spawn_link, [1], [:Lockstep]},
    {:erlang, :spawn, [1], [:Lockstep]},
    {:erlang, :spawn_link, [1], [:Lockstep]}
  ]

  # Unqualified (Kernel-imported) calls rewritten to `Lockstep.<name>`,
  # as name: arity.
  @local_calls [send: 2, spawn: 1, spawn_link: 1]

  # Tag of the sentinel message armed by a rewritten `receive ... after`.
  @after_tag :__lockstep_after__

  @doc "Walk `ast` and return the rewritten AST."
  @spec rewrite(Macro.t()) :: Macro.t()
  def rewrite(ast), do: walk(ast)

  # Custom walker that skips constructs whose call-shaped AST nodes are
  # not actual function calls:
  #
  #   * `def send(a, b), do: ...` -- the head is a pattern, not a call
  #     (Mutex literally has this); only the body is walked. A bodiless
  #     head such as `def send(a, b)` is left alone entirely.
  #   * `@spec send(t) :: result` / `@type` / `@callback` -- module
  #     attributes use the same AST shape but represent declarations.
  defp walk({kind, meta, [head, body]}) when kind in @def_kinds do
    {kind, meta, [head, walk(body)]}
  end

  defp walk({kind, _meta, [_head]} = node) when kind in @def_kinds, do: node

  defp walk({:@, _meta, _args} = node), do: node

  # Rewrite `{:via, Registry, key}` -> `{:via, Lockstep.Registry, key}`.
  #
  # OTP's `:via` contract dispatches to the via module's `register_name/2`
  # (and friends) inside `gen_server` / `proc_lib` internals -- code we
  # can't rewrite. So the tuple itself is changed at the call site, and
  # OTP then calls `Lockstep.Registry.register_name/2`, which implements
  # the full :via callback contract (`register_name/2`,
  # `unregister_name/1`, `whereis_name/1`, `send/2`).
  #
  # The 3-tuple `{:via, Registry, key}` lives in AST as
  # `{:{}, meta, [:via, {:__aliases__, _, [:Registry]}, key_ast]}`.
  defp walk({:{}, meta, [:via, {:__aliases__, alias_meta, [:Registry]}, key]}) do
    {:{}, meta, [:via, alias_node([:Lockstep, :Registry], alias_meta), walk(key)]}
  end

  # Expand pipelines first so the rewrite rules see the real call with its
  # full argument list (`pid |> send(m)` is `send(pid, m)`).
  defp walk({:|>, _meta, [_left, _right]} = pipe) do
    case expand_pipe(pipe) do
      {:ok, call} -> walk(call)
      :error -> walk_call(pipe)
    end
  end

  defp walk({_form, _meta, args} = call) when is_list(args), do: walk_call(call)
  defp walk({left, right}), do: {walk(left), walk(right)}
  defp walk(list) when is_list(list), do: Enum.map(list, &walk/1)
  defp walk(other), do: other

  # Walk a call's callee and arguments, then try to rewrite the call itself.
  defp walk_call({form, meta, args}), do: rewrite_node({walk(form), meta, walk(args)})

  # Expand a pipeline the same way `Kernel.|>/2` does: unpipe it, then pipe
  # each stage into the next. An invalid pipeline (e.g. piping into a map
  # literal) makes `Macro.pipe/3` raise; the node is then kept so that
  # `Kernel.|>/2` reports the error itself.
  defp expand_pipe(pipe) do
    [{head, _position} | stages] = Macro.unpipe(pipe)

    expanded =
      Enum.reduce(stages, head, fn {stage, position}, acc ->
        Macro.pipe(acc, stage, position)
      end)

    {:ok, expanded}
  rescue
    ArgumentError -> :error
  end

  # ============================================================
  # Qualified calls
  # ============================================================

  # Every qualified `Module.fn` call goes through one table lookup
  # (`@families`, then `@exact_calls`), so adding a mapping means adding a
  # row rather than another clause competing for the same pattern.
  defp rewrite_node({{:., _dot_meta, [callee, fun]}, meta, args} = node) when is_atom(fun) do
    receiver = receiver_module(callee)
    arity = length(args)

    cond do
      prefix = family_prefix(receiver, fun) || exact_prefix(receiver, fun, arity) ->
        qualify(meta, prefix, fun, args)

      receiver == Process and fun == :send and arity == 3 ->
        # Lockstep.send/2 takes no options: `opts` is dropped and its
        # expression is never evaluated.
        qualify(meta, [:Lockstep], :send, Enum.take(args, 2))

      true ->
        node
    end
  end

  # ============================================================
  # Unqualified calls (Kernel-imported)
  # ============================================================

  defp rewrite_node({fun, meta, args})
       when is_atom(fun) and is_list(args) and {fun, length(args)} in @local_calls do
    qualify(meta, [:Lockstep], fun, args)
  end

  # ============================================================
  # receive ... end
  # ============================================================

  defp rewrite_node({:receive, _meta, [[do: clauses]]}) when is_list(clauses) do
    rewrite_receive(clauses)
  end

  defp rewrite_node(
         {:receive, _meta, [[do: do_block, after: [{:->, _, [[timeout], after_body]}]]]} = node
       ) do
    case receive_clauses(do_block) do
      {:ok, clauses} -> rewrite_receive_after(clauses, timeout, after_body)
      :error -> node
    end
  end

  defp rewrite_node(node), do: node

  # Look up the Lockstep alias parts for `receiver.fun/_` in `@families`.
  defp family_prefix(receiver, fun) do
    Enum.find_value(@families, fn {module, funs, prefix} ->
      if module == receiver and fun in funs, do: prefix
    end)
  end

  # Look up the Lockstep alias parts for `receiver.fun/arity` in `@exact_calls`.
  defp exact_prefix(receiver, fun, arity) do
    Enum.find_value(@exact_calls, fn {module, name, arities, prefix} ->
      if module == receiver and name == fun and arity in arities, do: prefix
    end)
  end

  # Build `<prefix>.<fun>(args)`, reusing the original call's metadata.
  defp qualify(meta, prefix, fun, args) do
    {{:., meta, [alias_node(prefix), fun]}, meta, args}
  end

  # `alias: false` makes the compiler take the alias literally, so a user
  # alias in the test module cannot redirect the generated reference.
  defp alias_node(parts, meta \\ []) do
    {:__aliases__, Keyword.put(meta, :alias, false), parts}
  end

  # ============================================================
  # receive helpers
  # ============================================================

  # Clauses of a `receive` `do` block; a timeout-only receive has an empty
  # `do` block, which the parser represents as `{:__block__, _, []}`.
  defp receive_clauses(clauses) when is_list(clauses), do: {:ok, clauses}
  defp receive_clauses({:__block__, _meta, []}), do: {:ok, []}
  defp receive_clauses(_other), do: :error

  # A single clause whose head is a plain variable (or `_`) accepts any
  # message, so `Lockstep.recv/0` suffices; anything else needs a matcher.
  # The original clauses are reused verbatim as `case` clauses, which keeps
  # their bindings scoped to the clause, as in a real `receive`.
  defp rewrite_receive([{:->, _meta, [[head], _body]}] = clauses) do
    if simple_var?(head) do
      quote do
        case Lockstep.recv() do
          unquote(clauses)
        end
      end
    else
      recv_first_case(clauses)
    end
  end

  defp rewrite_receive(clauses), do: recv_first_case(clauses)

  defp recv_first_case(clauses) do
    matcher = build_matcher(clauses, [])

    quote do
      case Lockstep.recv_first(unquote(matcher)) do
        unquote(clauses)
      end
    end
  end

  # `receive ... after timeout -> ... end`: arm a Lockstep timer with a
  # sentinel tagged by a fresh reference, recv_first across both the user
  # patterns and that sentinel, then dispatch via `case`. The sentinel
  # clause comes first so a user catch-all cannot swallow it. The fresh
  # reference means a sentinel left over from an earlier receive can never
  # satisfy this one. `:infinity` arms no timer.
  defp rewrite_receive_after(clauses, timeout, after_body) do
    sentinel = quote(do: {unquote(@after_tag), ^lockstep_after_ref})
    matcher = build_matcher(clauses, [{:->, [generated: true], [[sentinel], true]}])

    case_clauses = [
      {:->, [generated: true], [[sentinel], after_body]}
      | Enum.map(clauses, &cancel_timer_first/1)
    ]

    # With no user clauses nothing ever cancels the timer, so its handle is
    # discarded instead of bound (avoiding an unused-variable warning).
    timer_var = if clauses == [], do: quote(do: _), else: quote(do: lockstep_after_timer)

    quote generated: true do
      lockstep_after_ref = make_ref()
      lockstep_after_timeout = unquote(timeout)

      unquote(timer_var) =
        if lockstep_after_timeout != :infinity do
          Lockstep.send_after(
            self(),
            {unquote(@after_tag), lockstep_after_ref},
            lockstep_after_timeout
          )
        end

      case Lockstep.recv_first(unquote(matcher)) do
        unquote(case_clauses)
      end
    end
  end

  # Prefix a user clause's body with cancelling the pending timer. The
  # timer is `nil` when the timeout was `:infinity`.
  defp cancel_timer_first({:->, meta, [head, body]}) do
    cancelled_body =
      quote generated: true do
        _ = if lockstep_after_timer, do: Lockstep.cancel_timer(lockstep_after_timer)
        unquote(body)
      end

    {:->, meta, [head, cancelled_body]}
  end

  # ============================================================
  # Matcher helpers
  # ============================================================

  # Build `fn msg -> case msg do <patterns -> true>; _ -> false end end`.
  # Each pattern is the user's pattern with variables underscored (see
  # `anonymize_pattern/2`); `extra_clauses` are appended before the
  # fallback. Everything is marked `generated` so the compiler can skip
  # warnings such as an unreachable fallback after a user catch-all.
  defp build_matcher(clauses, extra_clauses) do
    fallback = {:->, [generated: true], [[{:_, [], nil}], false]}
    matcher_clauses = Enum.map(clauses, &matcher_clause/1) ++ extra_clauses ++ [fallback]

    quote generated: true do
      fn lockstep_match_arg ->
        case lockstep_match_arg do
          unquote(matcher_clauses)
        end
      end
    end
  end

  # Turn one receive clause into a `pattern -> true` matcher clause. With a
  # guard, the guard's variables keep their names, because the guard
  # refers to them.
  defp matcher_clause({:->, meta, [[{:when, when_meta, [pattern, guard]}], _body]}) do
    anonymized = anonymize_pattern(pattern, collect_vars(guard))
    {:->, meta, [[{:when, when_meta, [anonymized, guard]}], true]}
  end

  defp matcher_clause({:->, meta, [[pattern], _body]}) do
    {:->, meta, [[anonymize_pattern(pattern, MapSet.new())], true]}
  end

  defp simple_var?({name, _meta, ctx}) when is_atom(name) and is_atom(ctx), do: true
  defp simple_var?(_), do: false

  # Module named by a call receiver: an alias node (`GenServer`), or an
  # atom (`:ets`, or an already-expanded alias). `nil` when it is neither
  # (e.g. a variable or `__MODULE__.Foo`).
  defp receiver_module({:__aliases__, _meta, parts}) do
    if Enum.all?(parts, &is_atom/1), do: Module.concat(parts)
  end

  defp receiver_module(module) when is_atom(module), do: module
  defp receiver_module(_callee), do: nil

  # Collect the names of every non-underscored variable referenced in `ast`.
  defp collect_vars(ast) do
    {_ast, vars} =
      Macro.prewalk(ast, MapSet.new(), fn
        {name, _meta, ctx} = node, acc when is_atom(name) and is_atom(ctx) ->
          if underscored?(name), do: {node, acc}, else: {node, MapSet.put(acc, name)}

        node, acc ->
          {node, acc}
      end)

    vars
  end

  # Underscore every variable in `pattern` so the matcher does not trigger
  # "unused variable" warnings, except:
  #
  #   * names in `keep` (variables referenced by the clause's guard);
  #   * variables referenced in a bitstring type spec, such as `len` in
  #     `<<len::8, data::binary-size(len)>>`;
  #   * pinned (`^x`) and module-attribute (`@x`) subtrees, which refer to
  #     values from outside the pattern;
  #   * the type spec to the right of `::`, whose bare words (`binary`,
  #     `big`, ...) are specifiers, not variables.
  defp anonymize_pattern(pattern, keep) do
    anonymize(pattern, MapSet.union(keep, bitstring_spec_vars(pattern)))
  end

  defp anonymize({op, _meta, _args} = node, _keep) when op in [:^, :@], do: node

  defp anonymize({:"::", meta, [value, spec]}, keep) do
    {:"::", meta, [anonymize(value, keep), spec]}
  end

  defp anonymize({name, meta, ctx} = var, keep) when is_atom(name) and is_atom(ctx) do
    if underscored?(name) or MapSet.member?(keep, name) do
      var
    else
      {:"_#{name}", meta, ctx}
    end
  end

  defp anonymize({form, meta, args}, keep) when is_list(args) do
    {anonymize(form, keep), meta, anonymize(args, keep)}
  end

  defp anonymize({left, right}, keep), do: {anonymize(left, keep), anonymize(right, keep)}
  defp anonymize(list, keep) when is_list(list), do: Enum.map(list, &anonymize(&1, keep))
  defp anonymize(other, _keep), do: other

  # Names referenced on the right-hand side of any `::` in `pattern`.
  defp bitstring_spec_vars(pattern) do
    {_pattern, vars} =
      Macro.prewalk(pattern, MapSet.new(), fn
        {:"::", _meta, [_value, spec]} = node, acc ->
          {node, MapSet.union(acc, collect_vars(spec))}

        node, acc ->
          {node, acc}
      end)

    vars
  end

  defp underscored?(name), do: String.starts_with?(Atom.to_string(name), "_")
end