lib/beaver/dsl/ssa.ex

defmodule Beaver.SSA do
  @moduledoc """
  Stores MLIR IR structure in an Elixir struct. Macros like `Beaver.mlir/1` will generate SSA structs defined by this module.
  """
  alias Beaver.MLIR
  require Beaver.MLIR.CAPI

  # `block` is set to an arbitrary value by `put_block/2` and `evaluator`
  # defaults to `nil`, so neither can be typed more narrowly than this.
  @type t() :: %__MODULE__{
          arguments: any(),
          results: any(),
          filler: any(),
          block: any(),
          ctx: any(),
          loc: any(),
          evaluator: function() | nil
        }
  defstruct arguments: [],
            results: [],
            filler: nil,
            block: nil,
            ctx: nil,
            loc: nil,
            evaluator: nil

  @doc """
  Appends `additional_arguments` (a list) to the arguments already stored in `ssa`.
  """
  @spec put_arguments(t(), list()) :: t()
  def put_arguments(%__MODULE__{arguments: arguments} = ssa, additional_arguments)
      when is_list(additional_arguments) do
    %{ssa | arguments: arguments ++ additional_arguments}
  end

  @doc """
  Adds results to `ssa`.

  Accepted values:

    * a 1-arity function - appended as a single entry
    * a `Beaver.MLIR.Type` struct - appended as a single entry
    * a list - concatenated to the existing results
    * `:infer` - replaces the results field with the atom `:infer`
  """
  @spec put_results(t(), (any() -> any()) | %MLIR.Type{} | list() | :infer) :: t()
  def put_results(%__MODULE__{results: results} = ssa, f) when is_function(f, 1) do
    %{ssa | results: results ++ [f]}
  end

  def put_results(%__MODULE__{results: results} = ssa, %MLIR.Type{} = single_result) do
    %{ssa | results: results ++ [single_result]}
  end

  def put_results(%__MODULE__{results: results} = ssa, additional_results)
      when is_list(additional_results) do
    %{ssa | results: results ++ additional_results}
  end

  def put_results(%__MODULE__{} = ssa, :infer) do
    # Unlike the other clauses this overwrites the field instead of appending.
    %{ssa | results: :infer}
  end

  @doc """
  Sets the location of `ssa` to the given `Beaver.MLIR.Location` struct.
  """
  @spec put_location(t(), %MLIR.Location{}) :: t()
  def put_location(%__MODULE__{} = ssa, %MLIR.Location{} = loc) do
    %{ssa | loc: loc}
  end

  @doc """
  Sets the filler of `ssa`. `filler` must be a 0-arity function.
  """
  @spec put_filler(t(), (-> any())) :: t()
  def put_filler(%__MODULE__{} = ssa, filler) when is_function(filler, 0) do
    %{ssa | filler: filler}
  end

  @doc """
  Sets the block of `ssa`. The value is stored as-is without validation.
  """
  @spec put_block(t(), any()) :: t()
  def put_block(%__MODULE__{} = ssa, block) do
    %{ssa | block: block}
  end

  @doc """
  Sets the context of `ssa` to the given `Beaver.MLIR.Context` struct.
  """
  @spec put_ctx(t(), %MLIR.Context{}) :: t()
  def put_ctx(%__MODULE__{} = ssa, %MLIR.Context{} = ctx) do
    %{ssa | ctx: ctx}
  end

  # Builds the quoted expression that constructs a `%Beaver.SSA{}` for a
  # `call(args) >>> results` node, injecting the MLIR context and block from
  # the environment (`Beaver.Env.context/0` and `Beaver.Env.block/0`).
  #
  # In the generated code:
  #   * the call arguments are flattened; any `%MLIR.Location{}` among them is
  #     removed from the operand list and the last one found is used as the
  #     location
  #   * when no location is passed, one is created from the caller's file and
  #     the line of the call
  #   * `results` is wrapped into a list and flattened
  defp construct_ssa(
         {:>>>, _meta, [{_call, call_meta, args}, results]},
         evaluator
       ) do
    # Resolved while expanding so that only the line number (or nil) is
    # embedded in the generated code instead of the whole metadata list.
    line_number = Keyword.get(call_meta, :line)

    quote do
      # Single pass that keeps operand order; avoids repeated list appends.
      {locations, args} =
        [unquote_splicing(args)]
        |> List.flatten()
        |> Enum.split_with(&match?(%MLIR.Location{}, &1))

      results = unquote(results) |> List.wrap() |> List.flatten()

      loc =
        List.last(locations) ||
          Beaver.MLIR.Location.file(
            name: __ENV__.file,
            line: unquote(line_number),
            ctx: Beaver.Env.context()
          )

      %Beaver.SSA{evaluator: unquote(evaluator)}
      |> Beaver.SSA.put_location(loc)
      |> Beaver.SSA.put_arguments(args)
      |> Beaver.SSA.put_results(results)
      |> Beaver.SSA.put_block(Beaver.Env.block())
      |> Beaver.SSA.put_ctx(Beaver.Env.context())
    end
  end

  # Block arguments: `var >>> type` where the left-hand side is a plain
  # variable is rewritten to the tuple `{var, type}`.
  defp do_transform(
         {:>>>, _, [var = {_var_name, _, nil}, type]},
         _evaluator
       ) do
    quote do
      {unquote(var), unquote(type)}
    end
  end

  # Op creation with a do block: `call(args) do ... end >>> results`.
  # The do block is wrapped in a 0-arity function and stored as the filler,
  # then the SSA struct is piped into `call` as its only argument.
  # NOTE(review): `args` here is the single first argument of the call and is
  # spliced by `construct_ssa/2`, so it is presumably expected to be a list
  # literal (e.g. `[a, b]` or a keyword list) - confirm against callers.
  defp do_transform(
         {:>>>, line0, [{call, line, [args, [do: ast_block]]}, results]},
         evaluator
       ) do
    quote do
      unquote(
        construct_ssa(
          {:>>>, line0, [{call, line, args}, results]},
          evaluator
        )
      )
      |> Beaver.SSA.put_filler(fn -> unquote(ast_block) end)
      |> unquote({call, line, []})
    end
  end

  # Op creation: `call(args) >>> results` builds the SSA struct and pipes it
  # into `call` as its only argument.
  defp do_transform(
         {:>>>, _line, [{call, line, _args}, _results]} = ast,
         evaluator
       ) do
    quote do
      unquote(construct_ssa(ast, evaluator))
      |> unquote({call, line, []})
    end
  end

  # Every other node is left untouched.
  defp do_transform(ast, _evaluator), do: ast

  @doc """
  Rewrites every `>>>` expression found in `ast` using a pre-order traversal
  (`Macro.prewalk/2`).

  `evaluator` must be a 2-arity function; it is embedded in the `evaluator`
  field of each generated `Beaver.SSA` struct.
  """
  @spec prewalk(Macro.t(), (any(), any() -> any())) :: Macro.t()
  def prewalk(ast, evaluator) when is_function(evaluator, 2) do
    Macro.prewalk(ast, &do_transform(&1, evaluator))
  end

  @doc """
  Rewrites every `>>>` expression found in `ast` using a post-order traversal
  (`Macro.postwalk/2`).

  `evaluator` must be a 2-arity function; it is embedded in the `evaluator`
  field of each generated `Beaver.SSA` struct.
  """
  @spec postwalk(Macro.t(), (any(), any() -> any())) :: Macro.t()
  def postwalk(ast, evaluator) when is_function(evaluator, 2) do
    Macro.postwalk(ast, &do_transform(&1, evaluator))
  end
end