Skip to main content

lib/rustler/match_spec/compiler.ex

defmodule Rustler.MatchSpec.Compiler do
  @moduledoc false

  # Translates quoted `pattern [when guard] -> body` clauses into
  # `{head, guards, body}` triples. Every variable bound in a clause head is
  # replaced by a positional `:"$N"` atom, guard/body calls become
  # `{op, arg, ...}` tuples, and Elixir operator spellings are mapped to their
  # Erlang equivalents (see `match_spec_op/1`).

  # Variable names that are never turned into `:"$N"` variables; they are
  # passed through as `{name, [], context}` nodes instead (`:_` in a head
  # becomes the bare `:_` wildcard).
  @reserved_vars MapSet.new([:_, :__MODULE__, :__ENV__, :__CALLER__])

  # Index -> `:"$N"` atom for 1..255, precomputed at compile time.
  @variables Map.new(1..255, &{&1, :"$#{&1}"})

  # Calls that `compile_expr/3` turns into `{op, arg, ...}` tuples.
  @expr_ops [
    :is_atom,
    :is_binary,
    :is_boolean,
    :is_float,
    :is_integer,
    :is_list,
    :is_map,
    :is_number,
    :is_tuple,
    :not,
    :and,
    :or,
    :andalso,
    :orelse,
    :xor,
    :==,
    :===,
    :!=,
    :!==,
    :>,
    :>=,
    :<,
    :<=,
    :+,
    :-,
    :*,
    :div,
    :rem
  ]

  # Compiles a quoted match_spec block into a list of `{head, guards, body}`
  # triples, one per clause. Accepts a `__block__` node, a list of `->`
  # clauses, or a single clause. `guards` holds zero or one compiled guard.
  # `caller` is threaded through to the expression compiler but is not read.
  #
  # Raises ArgumentError for malformed clauses, heads binding more than 255
  # variables, or guard/body references to variables not bound in the head.
  @spec compile_block(Macro.t(), term()) :: [{term(), [term()], [term()]}]
  def compile_block({:__block__, _meta, clauses}, caller),
    do: Enum.map(clauses, &compile_clause(&1, caller))

  def compile_block(clauses, caller) when is_list(clauses),
    do: Enum.map(clauses, &compile_clause(&1, caller))

  def compile_block(clause, caller), do: [compile_clause(clause, caller)]

  # A clause with a guard: the guard is compiled against the bindings
  # introduced by the head and becomes the single entry of the guard list.
  defp compile_clause({:->, _meta, [[{:when, _when_meta, [head, guard]}], body]}, caller) do
    {compiled_head, bindings} = compile_pattern(head, %{})
    compiled_guard = compile_expr(guard, bindings, caller)
    {compiled_head, [compiled_guard], compile_body(body, bindings, caller)}
  end

  defp compile_clause({:->, _meta, [[head], body]}, caller) do
    {compiled_head, bindings} = compile_pattern(head, %{})
    {compiled_head, [], compile_body(body, bindings, caller)}
  end

  defp compile_clause(other, _caller) do
    raise ArgumentError,
          "expected match_spec clauses in the form `pattern [when guard] -> body`, got: #{Macro.to_string(other)}"
  end

  # Compiles a clause body into the match-spec body list.
  # A multi-expression body (`__block__`) or a list written as the body is a
  # sequence of body expressions. Each is compiled on its own because
  # `compile_expr/3` has no `__block__` clause and would otherwise pass the
  # raw AST through untouched. Any other body is a single expression.
  defp compile_body({:__block__, _meta, [_ | _] = exprs}, bindings, caller),
    do: Enum.map(exprs, &compile_expr(&1, bindings, caller))

  defp compile_body([_ | _] = exprs, bindings, caller),
    do: Enum.map(exprs, &compile_expr(&1, bindings, caller))

  defp compile_body(expr, bindings, caller), do: [compile_expr(expr, bindings, caller)]

  # compile_pattern/2 walks a head pattern, returning `{compiled, bindings}`
  # where `bindings` maps each Elixir variable name to its `:"$N"` atom.

  defp compile_pattern({:_, _meta, context}, bindings) when is_atom(context), do: {:_, bindings}

  defp compile_pattern({name, _meta, context}, bindings)
       when is_atom(name) and is_atom(context) do
    if reserved_var?(name), do: {{name, [], context}, bindings}, else: bind_var(name, bindings)
  end

  defp compile_pattern({left, right}, bindings) do
    {left, bindings} = compile_pattern(left, bindings)
    {right, bindings} = compile_pattern(right, bindings)
    {{left, right}, bindings}
  end

  defp compile_pattern({:{}, _meta, elements}, bindings) do
    compile_tuple(elements, bindings)
  end

  # Map keys are kept as written; only the values are compiled as patterns.
  defp compile_pattern({:%{}, _meta, pairs}, bindings) do
    pairs
    |> Enum.map_reduce(bindings, fn {key, value}, acc ->
      {value, acc} = compile_pattern(value, acc)
      {{key, value}, acc}
    end)
    |> then(fn {pairs, bindings} -> {Map.new(pairs), bindings} end)
  end

  # Any other call-shaped node `name(args...)` becomes `{name, args...}`.
  defp compile_pattern({name, _meta, args}, bindings) when is_atom(name) and is_list(args) do
    {args, bindings} = Enum.map_reduce(args, bindings, &compile_pattern/2)
    {List.to_tuple([name | args]), bindings}
  end

  defp compile_pattern(list, bindings) when is_list(list) do
    Enum.map_reduce(list, bindings, &compile_pattern/2)
  end

  defp compile_pattern(literal, bindings), do: {literal, bindings}

  defp compile_tuple(elements, bindings) do
    {elements, bindings} = Enum.map_reduce(elements, bindings, &compile_pattern/2)
    {List.to_tuple(elements), bindings}
  end

  # Returns the `:"$N"` atom for `name`, allocating the next index on first
  # sight. A repeated name reuses its atom, so `{x, x}` compiles to
  # `{:"$1", :"$1"}`.
  defp bind_var(name, bindings) do
    case Map.fetch(bindings, name) do
      {:ok, variable} ->
        {variable, bindings}

      :error ->
        variable = variable_for(map_size(bindings) + 1)
        {variable, Map.put(bindings, name, variable)}
    end
  end

  # compile_expr/3 compiles guard and body expressions. Variables must already
  # be bound by the head; unsupported calls fall through as literals.

  defp compile_expr({name, _meta, context}, bindings, _caller)
       when is_atom(name) and is_atom(context) do
    if reserved_var?(name), do: {name, [], context}, else: fetch_binding!(name, bindings)
  end

  defp compile_expr({:in, _meta, [left, right]}, bindings, caller) do
    {:member, compile_expr(left, bindings, caller), compile_expr(right, bindings, caller)}
  end

  defp compile_expr({op, _meta, args}, bindings, caller)
       when op in @expr_ops and is_list(args) do
    compiled_args = Enum.map(args, &compile_expr(&1, bindings, caller))
    List.to_tuple([match_spec_op(op) | compiled_args])
  end

  defp compile_expr({left, right}, bindings, caller) do
    {compile_expr(left, bindings, caller), compile_expr(right, bindings, caller)}
  end

  defp compile_expr({:{}, _meta, elements}, bindings, caller) do
    elements
    |> Enum.map(&compile_expr(&1, bindings, caller))
    |> List.to_tuple()
  end

  defp compile_expr({:%{}, _meta, pairs}, bindings, caller) do
    Map.new(pairs, fn {key, value} -> {key, compile_expr(value, bindings, caller)} end)
  end

  defp compile_expr(list, bindings, caller) when is_list(list) do
    Enum.map(list, &compile_expr(&1, bindings, caller))
  end

  defp compile_expr(literal, _bindings, _caller), do: literal

  # Looks up a head-bound variable. Raises a descriptive ArgumentError instead
  # of a bare KeyError when a guard/body refers to a name the head never bound.
  defp fetch_binding!(name, bindings) do
    case Map.fetch(bindings, name) do
      {:ok, variable} ->
        variable

      :error ->
        raise ArgumentError,
              "match_spec variable `#{name}` is used in a guard or body but is not bound in the clause head"
    end
  end

  defp reserved_var?(name), do: MapSet.member?(@reserved_vars, name)

  defp variable_for(index) when index in 1..255, do: @variables[index]

  defp variable_for(index) do
    raise ArgumentError, "match_spec supports at most 255 variables, got variable #{index}"
  end

  # Elixir operator spelling -> Erlang spelling; unmapped names pass through.
  defp match_spec_op(:===), do: :"=:="
  defp match_spec_op(:!==), do: :"=/="
  defp match_spec_op(:==), do: :==
  defp match_spec_op(:!=), do: :"/="
  defp match_spec_op(:<=), do: :"=<"
  defp match_spec_op(op), do: op
end