lib/lockstep/erlang_rewriter.ex

defmodule Lockstep.ErlangRewriter do
  @moduledoc """
  Rewrites Erlang source (`.erl`) so that vanilla OTP calls
  (`gen_server:call`, `erlang:spawn`, `Pid ! Msg`, bare `receive`)
  go through Lockstep's controller. Mirrors `Lockstep.Rewriter` but
  works on Erlang's abstract format instead of Elixir's macro AST.

  ## Why

  Many of the BEAM ecosystem's foundational libraries are pure Erlang:
  `:pg`, `gen_statem`, `:dets`, `sleeplocks`, the Erlang internals of
  `:gen_server`, `:supervisor`, etc. Without an Erlang rewriter,
  libraries that depend on these (Phoenix.PubSub, GenStage, Cachex's
  transaction layer, libcluster, ...) fall outside Lockstep. This
  module is the bridge.

  ## Mappings

      vanilla Erlang                    rewritten
      --------------                    ---------
      gen_server:call(S, M)             'Elixir.Lockstep.GenServer':call(S, M)
      gen_server:cast(S, M)             'Elixir.Lockstep.GenServer':cast(S, M)
      gen_server:start(M, A, O)         'Elixir.Lockstep.GenServer':start(M, A, O)
      gen_server:start_link(M, A, O)    'Elixir.Lockstep.GenServer':start_link(M, A, O)
      gen_server:reply(F, R)            'Elixir.Lockstep.GenServer':reply(F, R)
      gen_server:stop(S, R, T)          'Elixir.Lockstep.GenServer':stop(S, R, T)
      ets:F(...)                        'Elixir.Lockstep.ETS':F(...)             (wrapped subset)
      persistent_term:F(...)            'Elixir.Lockstep.PersistentTerm':F(...)  (put/get/erase/info)
      erlang:spawn(F)                   'Elixir.Lockstep':spawn(F)
      erlang:spawn_link(F)              'Elixir.Lockstep':spawn_link(F)
      erlang:send(D, M)                 'Elixir.Lockstep':send(D, M)
      erlang:monitor(process, P)        'Elixir.Lockstep':monitor(P)
      erlang:demonitor(R)               'Elixir.Lockstep':demonitor(R, [])
      erlang:demonitor(R, O)            'Elixir.Lockstep':demonitor(R, O)
      erlang:link(P)                    'Elixir.Lockstep':link(P)
      erlang:unlink(P)                  'Elixir.Lockstep':unlink(P)
      erlang:process_flag(F, V)         'Elixir.Lockstep':flag(F, V)
      erlang:is_process_alive(P)        'Elixir.Lockstep':'alive?'(P)
      erlang:send_after(T, D, M)        'Elixir.Lockstep':send_after(D, M, T)    ** arg reorder! **
      erlang:cancel_timer(R)            'Elixir.Lockstep':cancel_timer(R)
      erlang:register(N, P)             'Elixir.Lockstep.Process':register(P, N) ** arg swap! **
      erlang:unregister(N)              'Elixir.Lockstep.Process':unregister(N)
      erlang:whereis(N)                 'Elixir.Lockstep.Process':whereis(N)
      erlang:registered()               'Elixir.Lockstep.Process':registered()
      Pid ! Msg                         'Elixir.Lockstep':send(Pid, Msg)
      receive Cls end                   case 'Elixir.Lockstep':recv_first(Matcher) of Cls end
      receive Cls after infinity -> B   same as `receive Cls end` (B can never run)
      receive Cls after T -> B end      timer + recv_first inside a local fun (see below)

  Bare calls to the auto-imported BIFs among these (`spawn(F)` without
  the `erlang:` prefix, `link(P)`, `register(N, P)`, ...) are rewritten
  too, unless the module defines or `-import`s a function with the same
  name and arity, in which case the bare call refers to that function.

  ## Receive

  `Lockstep.recv_first/1` is given a matcher fun whose case clauses are
  the receive clauses' patterns *and guards*, copied verbatim. Because
  the matcher is a closure, variables bound before the `receive` keep
  their "pin" semantics exactly as in a real `receive`. The original
  clauses then run against the selected message in a `case`.

  A timed `receive ... after T -> B end` becomes a call to a local fun
  that creates a fresh reference, arms
  `Lockstep.send_after(self(), {'__lockstep_after', Ref}, T)`, waits for
  a matching message or that exact timeout message, and cancels the
  timer when a real message wins. The fun returns `'__lockstep_timeout'`
  or `{'__lockstep_msg', Msg}`, which an outer `case` dispatches to `B`
  or to the original clauses. Keeping the helper variables inside the
  fun lets several timed receives share one function clause.

  ## Limitations (v0.1)

    * `gen_server:start_link({local, Name}, M, A, O)` 4-arg form not
      yet supported. Use the 3-arg form (no name) and pass the pid
      around explicitly.
    * `spawn/3` and `spawn_link/3` (`M, F, A`) are not rewritten; only
      the 1-arg fun forms are.
    * If a timed receive's timeout message has already been delivered
      when a regular message is taken, whether `Lockstep.cancel_timer/1`
      removes it is up to Lockstep. If it stays in the mailbox, its
      unique reference keeps it from ending a later `after`, but a
      catch-all `receive` could still observe it.
    * `gen_statem`, `supervisor`, `:pg` are not yet wrapper-rewritten
      individually. They could be added by extending the call mappings.
  """

  # Tag of the message the timer of a timed `receive` delivers; the full
  # message is `{@after_tag, Ref}` with a reference unique to that receive.
  @after_tag :__lockstep_after

  # Results of the local waiter fun of a timed `receive`: either the
  # timer won, or `{@msg_tag, Msg}` carries the message that was taken.
  @timeout_tag :__lockstep_timeout
  @msg_tag :__lockstep_msg

  # Comprehension qualifiers shaped `{Tag, Anno, Pattern, Expr}`.
  @generator_tags [
    :generate,
    :b_generate,
    :m_generate,
    :generate_strict,
    :b_generate_strict,
    :m_generate_strict
  ]

  # Auto-imported BIFs rewritten when called without the `erlang:` prefix.
  # Each one is rewritten exactly like its `erlang:`-qualified form.
  @rewritten_auto_imports [
    spawn: 1,
    spawn_link: 1,
    link: 1,
    unlink: 1,
    process_flag: 2,
    is_process_alive: 1,
    monitor: 2,
    demonitor: 1,
    demonitor: 2,
    register: 2,
    unregister: 1,
    whereis: 1,
    registered: 0
  ]

  # `erlang:` functions forwarded unchanged to a same-named `Lockstep` function.
  @lockstep_same_name [
    send: 2,
    spawn: 1,
    spawn_link: 1,
    link: 1,
    unlink: 1,
    demonitor: 2,
    cancel_timer: 1
  ]

  # `ets` functions wrapped by `Lockstep.ETS` so atom-named tables get
  # per-node namespacing. Anything else in `ets` passes through untouched.
  @ets_wrapped [
    new: 2,
    delete: 1,
    delete: 2,
    delete_all_objects: 1,
    insert: 2,
    insert_new: 2,
    lookup: 2,
    lookup_element: 3,
    lookup_element: 4,
    member: 2,
    match_delete: 2,
    update_counter: 3,
    update_counter: 4,
    update_element: 3,
    select: 1,
    select: 2,
    select: 3,
    match: 2,
    match_object: 2,
    tab2list: 1,
    info: 1,
    info: 2,
    first: 1,
    last: 1,
    next: 2,
    prev: 2,
    take: 2,
    select_count: 2,
    safe_fixtable: 2
  ]

  # `persistent_term` functions wrapped by `Lockstep.PersistentTerm` so
  # keys are per-node-namespaced.
  @persistent_term_wrapped [put: 2, get: 1, get: 2, erase: 1, info: 0]

  # ============================================================
  # Public API
  # ============================================================

  @doc """
  Read an `.erl` file, parse it with `:epp`, rewrite it, and write the
  pretty-printed result to `output_path` (creating its directory).

  ## Options

    * `:include_dirs` - extra directories searched for `-include` files.

  Returns `{:ok, output_path}` on success. Returns
  `{:error, {:parse_error, reason}}` when the preprocessor fails or the
  file contains malformed forms; nothing is written in that case.
  """
  @spec rewrite_file(Path.t(), Path.t(), keyword()) ::
          {:ok, Path.t()} | {:error, term()}
  def rewrite_file(input_path, output_path, opts \\ []) do
    with {:ok, forms} <- parse(input_path, opts),
         :ok <- ensure_no_form_errors(forms) do
      # `:erl_pp` returns chardata (a deep list of Unicode code points),
      # not bytes, so it must be encoded as UTF-8 -- the default `.erl`
      # encoding -- rather than flattened as iodata.
      source =
        forms
        |> rewrite_forms()
        |> Enum.map(&:erl_pp.form/1)
        |> IO.chardata_to_string()

      File.mkdir_p!(Path.dirname(output_path))
      File.write!(output_path, source)
      {:ok, output_path}
    else
      {:error, reason} -> {:error, {:parse_error, reason}}
    end
  end

  @doc """
  Rewrite a `.erl` file's forms and compile them directly to a `.beam`
  binary, skipping the `.erl` round-trip.

  Accepts the same `:include_dirs` option as `rewrite_file/3`. Returns
  `{:ok, module, binary}` or `{:error, errors, warnings}` as reported by
  `:compile.forms/2`; a preprocessor failure is reported as
  `{:error, [{:parse_error, reason}], []}`.
  """
  @spec rewrite_and_compile(Path.t(), keyword()) ::
          {:ok, atom(), binary()} | {:error, list(), list()}
  def rewrite_and_compile(input_path, opts \\ []) do
    case parse(input_path, opts) do
      {:ok, forms} ->
        forms
        |> rewrite_forms()
        |> :compile.forms([:binary, :return_errors])

      {:error, reason} ->
        {:error, [{:parse_error, reason}], []}
    end
  end

  @doc """
  Walk a list of Erlang abstract-format forms and rewrite every function
  body. Non-function forms are returned unchanged.
  """
  @spec rewrite_forms([tuple()]) :: [tuple()]
  def rewrite_forms(forms) when is_list(forms) do
    shadowed = shadowed_functions(forms)
    Enum.map(forms, &rewrite_form(&1, shadowed))
  end

  # Runs the Erlang preprocessor on `input_path` with the caller's include dirs.
  defp parse(input_path, opts) do
    include_path =
      opts
      |> Keyword.get(:include_dirs, [])
      |> Enum.map(&String.to_charlist/1)

    :epp.parse_file(String.to_charlist(input_path), include_path, [])
  end

  # `:epp` reports forms it could not parse as `{:error, info}` entries
  # inside an otherwise successful result.
  defp ensure_no_form_errors(forms) do
    case Enum.filter(forms, &match?({:error, _}, &1)) do
      [] -> :ok
      errors -> {:error, errors}
    end
  end

  # Name/arity pairs for which a bare call in this module does not reach
  # the auto-imported BIF: functions the module defines itself, and
  # functions it `-import`s. Bare calls to these must not be rewritten.
  defp shadowed_functions(forms) do
    for form <- forms, name_arity <- defined_or_imported(form), into: MapSet.new() do
      name_arity
    end
  end

  defp defined_or_imported({:function, _, name, arity, _}), do: [{name, arity}]

  defp defined_or_imported({:attribute, _, :import, {_mod, funs}}) when is_list(funs),
    do: funs

  defp defined_or_imported(_form), do: []

  # ============================================================
  # Form-level rewrite
  # ============================================================

  defp rewrite_form({:function, anno, name, arity, clauses}, shadowed),
    do: {:function, anno, name, arity, rewrite_clauses(clauses, shadowed)}

  defp rewrite_form(form, _shadowed), do: form

  defp rewrite_clauses(clauses, shadowed), do: Enum.map(clauses, &rewrite_clause(&1, shadowed))

  # Patterns and guards cannot contain any call we rewrite; only bodies can.
  defp rewrite_clause({:clause, anno, patterns, guards, body}, shadowed),
    do: {:clause, anno, patterns, guards, rewrite_exprs(body, shadowed)}

  defp rewrite_exprs(exprs, shadowed), do: Enum.map(exprs, &rewrite_expr(&1, shadowed))

  # ============================================================
  # Expression rewrite
  # ============================================================

  # `Mod:Fun(Args)` with a literal module and function -- the main target.
  defp rewrite_expr(
         {:call, l, {:remote, _, {:atom, _, mod}, {:atom, _, fun}} = callee, args},
         shadowed
       ) do
    args = rewrite_exprs(args, shadowed)

    case rewrite_remote_call(mod, fun, length(args), args, l) do
      :nochange -> {:call, l, callee, args}
      replacement -> replacement
    end
  end

  # Bare `fun(Args)`: an auto-imported BIF unless the module shadows it.
  defp rewrite_expr({:call, l, {:atom, _, fun} = callee, args}, shadowed) do
    args = rewrite_exprs(args, shadowed)
    arity = length(args)

    replacement =
      if MapSet.member?(shadowed, {fun, arity}),
        do: :nochange,
        else: rewrite_bif_call(fun, arity, args, l)

    case replacement do
      :nochange -> {:call, l, callee, args}
      new_call -> new_call
    end
  end

  # Any other call (`M:F(...)` with computed parts, `Fun(...)`,
  # `(fun ...)(...)`): nothing to map, but the callee and the arguments
  # can still contain rewritable expressions.
  defp rewrite_expr({:call, l, callee, args}, shadowed),
    do: {:call, l, rewrite_expr(callee, shadowed), rewrite_exprs(args, shadowed)}

  defp rewrite_expr({:remote, l, mod, fun}, shadowed),
    do: {:remote, l, rewrite_expr(mod, shadowed), rewrite_expr(fun, shadowed)}

  # Pid ! Msg
  defp rewrite_expr({:op, l, :!, target, msg}, shadowed) do
    remote_call(l, :"Elixir.Lockstep", :send, [
      rewrite_expr(target, shadowed),
      rewrite_expr(msg, shadowed)
    ])
  end

  # receive ... end
  defp rewrite_expr({:receive, l, clauses}, shadowed),
    do: rewrite_receive(l, rewrite_clauses(clauses, shadowed))

  # receive ... after Timeout -> Body end
  defp rewrite_expr({:receive, l, clauses, timeout, after_body}, shadowed) do
    clauses = rewrite_clauses(clauses, shadowed)

    case rewrite_expr(timeout, shadowed) do
      # A literal `infinity` never fires, so the `after` body is dead code.
      {:atom, _, :infinity} ->
        rewrite_receive(l, clauses)

      timeout ->
        rewrite_receive_after(l, clauses, timeout, rewrite_exprs(after_body, shadowed))
    end
  end

  # Recurse into compound expressions.
  defp rewrite_expr({:case, l, expr, clauses}, shadowed),
    do: {:case, l, rewrite_expr(expr, shadowed), rewrite_clauses(clauses, shadowed)}

  defp rewrite_expr({:if, l, clauses}, shadowed), do: {:if, l, rewrite_clauses(clauses, shadowed)}

  defp rewrite_expr({:try, l, body, of_clauses, catch_clauses, after_body}, shadowed) do
    {:try, l, rewrite_exprs(body, shadowed), rewrite_clauses(of_clauses, shadowed),
     rewrite_clauses(catch_clauses, shadowed), rewrite_exprs(after_body, shadowed)}
  end

  defp rewrite_expr({:catch, l, expr}, shadowed), do: {:catch, l, rewrite_expr(expr, shadowed)}

  defp rewrite_expr({:match, l, pattern, expr}, shadowed),
    do: {:match, l, pattern, rewrite_expr(expr, shadowed)}

  defp rewrite_expr({:maybe_match, l, pattern, expr}, shadowed),
    do: {:maybe_match, l, pattern, rewrite_expr(expr, shadowed)}

  defp rewrite_expr({:maybe, l, body}, shadowed), do: {:maybe, l, rewrite_exprs(body, shadowed)}

  defp rewrite_expr({:maybe, l, body, {:else, le, clauses}}, shadowed),
    do: {:maybe, l, rewrite_exprs(body, shadowed), {:else, le, rewrite_clauses(clauses, shadowed)}}

  defp rewrite_expr({:op, l, op, lhs, rhs}, shadowed),
    do: {:op, l, op, rewrite_expr(lhs, shadowed), rewrite_expr(rhs, shadowed)}

  defp rewrite_expr({:op, l, op, expr}, shadowed), do: {:op, l, op, rewrite_expr(expr, shadowed)}

  defp rewrite_expr({:tuple, l, elems}, shadowed), do: {:tuple, l, rewrite_exprs(elems, shadowed)}

  defp rewrite_expr({:cons, l, head, tail}, shadowed),
    do: {:cons, l, rewrite_expr(head, shadowed), rewrite_expr(tail, shadowed)}

  defp rewrite_expr({:map, l, fields}, shadowed), do: {:map, l, rewrite_exprs(fields, shadowed)}

  defp rewrite_expr({:map, l, expr, fields}, shadowed),
    do: {:map, l, rewrite_expr(expr, shadowed), rewrite_exprs(fields, shadowed)}

  defp rewrite_expr({:map_field_assoc, l, key, value}, shadowed),
    do: {:map_field_assoc, l, rewrite_expr(key, shadowed), rewrite_expr(value, shadowed)}

  defp rewrite_expr({:map_field_exact, l, key, value}, shadowed),
    do: {:map_field_exact, l, rewrite_expr(key, shadowed), rewrite_expr(value, shadowed)}

  # `:epp` leaves records unexpanded, so `#r{f = Expr}`, `Expr#r{...}`
  # and `Expr#r.f` still appear as record nodes.
  defp rewrite_expr({:record, l, name, fields}, shadowed),
    do: {:record, l, name, Enum.map(fields, &rewrite_record_field(&1, shadowed))}

  defp rewrite_expr({:record, l, expr, name, fields}, shadowed) do
    {:record, l, rewrite_expr(expr, shadowed), name,
     Enum.map(fields, &rewrite_record_field(&1, shadowed))}
  end

  defp rewrite_expr({:record_field, l, expr, name, field}, shadowed),
    do: {:record_field, l, rewrite_expr(expr, shadowed), name, field}

  defp rewrite_expr({:fun, l, {:clauses, clauses}}, shadowed),
    do: {:fun, l, {:clauses, rewrite_clauses(clauses, shadowed)}}

  defp rewrite_expr({:named_fun, l, name, clauses}, shadowed),
    do: {:named_fun, l, name, rewrite_clauses(clauses, shadowed)}

  defp rewrite_expr({comprehension, l, expr, qualifiers}, shadowed)
       when comprehension in [:lc, :bc, :mc] do
    {comprehension, l, rewrite_expr(expr, shadowed),
     Enum.map(qualifiers, &rewrite_qualifier(&1, shadowed))}
  end

  defp rewrite_expr({:bin, l, segments}, shadowed),
    do: {:bin, l, Enum.map(segments, &rewrite_bin_element(&1, shadowed))}

  defp rewrite_expr({:block, l, exprs}, shadowed), do: {:block, l, rewrite_exprs(exprs, shadowed)}

  defp rewrite_expr(other, _shadowed), do: other

  # Generators rewrite only their source expression; anything else is a filter.
  defp rewrite_qualifier({tag, l, pattern, expr}, shadowed) when tag in @generator_tags,
    do: {tag, l, pattern, rewrite_expr(expr, shadowed)}

  defp rewrite_qualifier(filter, shadowed), do: rewrite_expr(filter, shadowed)

  defp rewrite_bin_element({:bin_element, l, value, size, types}, shadowed),
    do: {:bin_element, l, rewrite_expr(value, shadowed), size, types}

  defp rewrite_record_field({:record_field, l, key, value}, shadowed),
    do: {:record_field, l, key, rewrite_expr(value, shadowed)}

  defp rewrite_record_field(other, _shadowed), do: other

  # ============================================================
  # Specific rewrite rules
  # ============================================================

  defp rewrite_remote_call(:gen_server, fun, _arity, args, l)
       when fun in [:call, :cast, :start_link, :start, :stop, :reply],
       do: remote_call(l, :"Elixir.Lockstep.GenServer", fun, args)

  defp rewrite_remote_call(:ets, fun, arity, args, l) when {fun, arity} in @ets_wrapped,
    do: remote_call(l, :"Elixir.Lockstep.ETS", fun, args)

  defp rewrite_remote_call(:persistent_term, fun, arity, args, l)
       when {fun, arity} in @persistent_term_wrapped,
       do: remote_call(l, :"Elixir.Lockstep.PersistentTerm", fun, args)

  defp rewrite_remote_call(:erlang, fun, arity, args, l)
       when {fun, arity} in @lockstep_same_name,
       do: remote_call(l, :"Elixir.Lockstep", fun, args)

  defp rewrite_remote_call(:erlang, :process_flag, 2, args, l),
    do: remote_call(l, :"Elixir.Lockstep", :flag, args)

  defp rewrite_remote_call(:erlang, :is_process_alive, 1, args, l),
    do: remote_call(l, :"Elixir.Lockstep", :alive?, args)

  # Only process monitors are modeled; port/time_offset monitors pass through.
  defp rewrite_remote_call(:erlang, :monitor, 2, [{:atom, _, :process}, pid], l),
    do: remote_call(l, :"Elixir.Lockstep", :monitor, [pid])

  # `{nil, l}` is the abstract form of `[]` (no demonitor options).
  defp rewrite_remote_call(:erlang, :demonitor, 1, [ref], l),
    do: remote_call(l, :"Elixir.Lockstep", :demonitor, [ref, {nil, l}])

  # ** Argument reorder! ** `erlang:send_after(Time, Dest, Msg)`.
  defp rewrite_remote_call(:erlang, :send_after, 3, [time, dest, msg], l),
    do: remote_call(l, :"Elixir.Lockstep", :send_after, [dest, msg, time])

  # `erlang:register(Name, Pid)`; Lockstep.Process.register/2 takes
  # `(pid, name)` like the Elixir API.
  defp rewrite_remote_call(:erlang, :register, 2, [name, pid], l),
    do: remote_call(l, :"Elixir.Lockstep.Process", :register, [pid, name])

  defp rewrite_remote_call(:erlang, fun, arity, args, l)
       when {fun, arity} in [unregister: 1, whereis: 1, registered: 0],
       do: remote_call(l, :"Elixir.Lockstep.Process", fun, args)

  defp rewrite_remote_call(_mod, _fun, _arity, _args, _l), do: :nochange

  # ============================================================
  # Bare BIFs (auto-imported from the erlang module)
  # ============================================================

  defp rewrite_bif_call(fun, arity, args, l) when {fun, arity} in @rewritten_auto_imports,
    do: rewrite_remote_call(:erlang, fun, arity, args, l)

  defp rewrite_bif_call(_fun, _arity, _args, _l), do: :nochange

  # ============================================================
  # Receive rewrite
  # ============================================================
  #
  # receive Cls end
  #
  # =>
  #
  # case 'Elixir.Lockstep':recv_first(fun(LockstepArg) ->
  #        case LockstepArg of
  #          <Cl1 pattern> when <Cl1 guards> -> true;
  #          ...
  #          _ -> false
  #        end
  #      end) of
  #   Cls
  # end

  # `receive` with no clauses only arises from `receive after infinity -> _ end`:
  # wait on a matcher that accepts no message.
  defp rewrite_receive(l, []),
    do: remote_call(l, :"Elixir.Lockstep", :recv_first, [matcher_fun(l, [], [])])

  defp rewrite_receive(l, clauses) do
    recv = remote_call(l, :"Elixir.Lockstep", :recv_first, [matcher_fun(l, [], clauses)])
    {:case, l, recv, clauses}
  end

  # receive Cls after T -> B end
  #
  # =>
  #
  # case (fun() ->
  #         LockstepRef = erlang:make_ref(),
  #         LockstepTimer = 'Elixir.Lockstep':send_after(erlang:self(),
  #                                                      {'__lockstep_after', LockstepRef}, T),
  #         case 'Elixir.Lockstep':recv_first(<matcher also accepting the timeout msg>) of
  #           {'__lockstep_after', LockstepRef} -> '__lockstep_timeout';
  #           LockstepMsg -> 'Elixir.Lockstep':cancel_timer(LockstepTimer),
  #                          {'__lockstep_msg', LockstepMsg}
  #         end
  #       end)() of
  #   '__lockstep_timeout' -> B;
  #   {'__lockstep_msg', Pat} when Guards -> Body;   % one per original clause
  #   ...
  # end
  #
  # The helper variables are bound inside the fun, so they never leak into
  # the enclosing clause; the user clauses stay in the enclosing scope, so
  # their bindings are exported exactly as a real `receive` would export them.
  defp rewrite_receive_after(l, clauses, timeout, after_body) do
    # Not underscore-prefixed: `ref` is matched again while bound.
    ref = {:var, l, :LockstepRef}
    timer = {:var, l, :LockstepTimer}
    msg = {:var, l, :LockstepMsg}
    expired = {:tuple, l, [{:atom, l, @after_tag}, ref]}

    arm_timer =
      remote_call(l, :"Elixir.Lockstep", :send_after, [
        remote_call(l, :erlang, :self, []),
        expired,
        timeout
      ])

    recv = remote_call(l, :"Elixir.Lockstep", :recv_first, [matcher_fun(l, [expired], clauses)])

    wait =
      {:case, l, recv,
       [
         {:clause, l, [expired], [], [{:atom, l, @timeout_tag}]},
         {:clause, l, [msg], [],
          [
            remote_call(l, :"Elixir.Lockstep", :cancel_timer, [timer]),
            {:tuple, l, [{:atom, l, @msg_tag}, msg]}
          ]}
       ]}

    waiter_body = [
      {:match, l, ref, remote_call(l, :erlang, :make_ref, [])},
      {:match, l, timer, arm_timer},
      wait
    ]

    outcome = {:call, l, {:fun, l, {:clauses, [{:clause, l, [], [], waiter_body}]}}, []}

    tagged_clauses =
      Enum.map(clauses, fn {:clause, cl, [pattern], guards, body} ->
        {:clause, cl, [{:tuple, cl, [{:atom, cl, @msg_tag}, pattern]}], guards, body}
      end)

    timeout_clause = {:clause, l, [{:atom, l, @timeout_tag}], [], after_body}
    {:case, l, outcome, [timeout_clause | tagged_clauses]}
  end

  # Builds `fun(Arg) -> case Arg of ... -> true; _ -> false end end`.
  # `extra` patterns are accepted first, then every user clause
  # (pattern + guards), then anything else is rejected.
  defp matcher_fun(l, extra, clauses) do
    arg = {:var, l, :LockstepArg}
    accept_extra = Enum.map(extra, &{:clause, l, [&1], [], [{:atom, l, true}]})
    accept_user = Enum.map(clauses, &matcher_clause/1)
    reject = {:clause, l, [{:var, l, :_}], [], [{:atom, l, false}]}
    cases = accept_extra ++ accept_user ++ [reject]

    {:fun, l, {:clauses, [{:clause, l, [arg], [], [{:case, l, arg, cases}]}]}}
  end

  # Keeps the user's pattern and guards verbatim: inside the closure,
  # variables already bound before the `receive` are matched against
  # their values, and unbound ones are fresh. The body references every
  # named pattern variable (`_ = {V1, ...}`) only so the compiler does not
  # report them as unused -- the real clause body runs later in the `case`.
  defp matcher_clause({:clause, cl, patterns, guards, _body}) do
    accept =
      case pattern_vars(patterns) do
        [] ->
          [{:atom, cl, true}]

        vars ->
          used = {:tuple, cl, Enum.map(vars, &{:var, cl, &1})}
          [{:match, cl, {:var, cl, :_}, used}, {:atom, cl, true}]
      end

    {:clause, cl, patterns, guards, accept}
  end

  # Distinct non-underscore variable names in a pattern, in source order.
  defp pattern_vars(pattern) do
    pattern
    |> collect_vars([])
    |> Enum.reverse()
    |> Enum.uniq()
  end

  defp collect_vars({:var, _, name}, acc) when is_atom(name) do
    if String.starts_with?(Atom.to_string(name), "_"), do: acc, else: [name | acc]
  end

  defp collect_vars(node, acc) when is_tuple(node), do: collect_vars(Tuple.to_list(node), acc)
  defp collect_vars(nodes, acc) when is_list(nodes), do: Enum.reduce(nodes, acc, &collect_vars/2)
  defp collect_vars(_leaf, acc), do: acc

  # ============================================================
  # AST construction helpers
  # ============================================================

  # Abstract form of `Mod:Fun(Args...)`, annotated with `l`.
  defp remote_call(l, mod, fun, args),
    do: {:call, l, {:remote, l, {:atom, l, mod}, {:atom, l, fun}}, args}
end