lib/beaver/dsl/pattern.ex

defmodule Beaver.Pattern.Env do
  @moduledoc false

  # Handles that `Beaver.Pattern` threads through its PDL emission helpers:
  #   * `:ctx`   - passed as `ctx:` to the `mlir` builder
  #   * `:block` - passed as `block:` to the `mlir` builder
  #   * `:loc`   - passed as the location when creating a `pdl.operation`
  # Every field defaults to `nil`.
  defstruct [:ctx, :block, :loc]
end

defmodule Beaver.Pattern do
  @moduledoc """
  Beaver pattern DSL for MLIR, a PDL frontend in Elixir.

  `defpat/2` turns a block of Beaver SSA expressions into a `pdl.pattern` op
  wrapped in a module. Inside that block, `attribute/1`, `type/0`, `type/1`,
  `value/0`, `rewrite/2` and `replace/2` emit the corresponding PDL ops.
  """
  alias Beaver.MLIR
  alias Beaver.MLIR.Dialect.PDL
  alias Beaver.MLIR.{Attribute, Type}
  import MLIR.Sigils
  import Beaver
  require Beaver.MLIR
  require Beaver.Env
  require Beaver.MLIR.CAPI
  alias Beaver.Pattern.Env

  @doc """
  Defines a public function that builds a PDL pattern module.

  `defpat my_pattern() do ... end` defines `my_pattern/0,1`. The function takes
  an optional keyword list `opts` (default `[benefit: 1]`) and hands it, together
  with a builder function of the MLIR context, to `Beaver.Deferred.from_opts/2`.

  The builder creates a module holding a single `pdl.pattern` op. Its `sym_name`
  is the pattern name, and its `benefit` is the `:benefit` option (default `1`)
  as an `i16` attribute. Every SSA expression in the body is emitted through
  `eval_rewrite/2`.

  The generated op is verified. On failure, the Elixir code of the builder is
  logged at error level and a `RuntimeError` is raised.

  Raises `ArgumentError` at compile time if `call` cannot be decomposed into a
  local function name.
  """
  defmacro defpat(call, do: block) do
    name =
      case Macro.decompose_call(call) do
        {name, _args} ->
          name

        _ ->
          raise ArgumentError,
                "defpat expects a local function name such as `my_pattern()`, got: " <>
                  Macro.to_string(call)
      end

    block_ast = Beaver.SSA.prewalk(block, &__MODULE__.eval_rewrite/2)

    pdl_pattern_module_op =
      quote do
        mlir ctx: ctx do
          module do
            benefit = Keyword.get(opts, :benefit, 1)

            PDL.pattern benefit: Attribute.integer(Type.i16(), benefit),
                        sym_name: "\"#{unquote(name)}\"" do
              region do
                block _pattern_block() do
                  unquote(block_ast)
                end
              end
            end >>> []
          end
        end
      end

    # Kept as source text so a verification failure can show what was generated.
    pdl_pattern_module_op_ast_dump = Macro.to_string(pdl_pattern_module_op)

    quote do
      def unquote(name)(opts \\ [benefit: 1]) do
        Beaver.Deferred.from_opts(
          opts,
          fn ctx ->
            alias Beaver.Pattern
            pdl_pattern_module_op = unquote(pdl_pattern_module_op)

            case Beaver.MLIR.Operation.verify(pdl_pattern_module_op, debug: true) do
              {:ok, op} ->
                op

              :fail ->
                require Logger

                # Logged at :error (not :info) so this diagnostic for the raise
                # below is not dropped when the log level is above :info.
                Logger.error("""
                Here is the code generated by the #{unquote(name)}:
                #{unquote(pdl_pattern_module_op_ast_dump)}
                """)

                raise "fail to verify generated pattern"
            end
          end
        )
      end
    end
  end

  @doc """
  Emits a `pdl.attribute` op with `value: a`, typed `!pdl.attribute`.
  """
  defmacro attribute(a) do
    # NOTE(review): unlike type/0, type/1 and value/0, this expansion is not
    # wrapped in `mlir do` — confirm that is intentional.
    quote do
      Beaver.MLIR.Dialect.PDL.attribute(value: unquote(a)) >>>
        ~t{!pdl.attribute}
    end
  end

  @doc """
  Emits a `pdl.type` op without a constant type, typed `!pdl.type`.
  """
  defmacro type() do
    quote do
      mlir do
        Beaver.MLIR.Dialect.PDL.type() >>> ~t{!pdl.type}
      end
    end
  end

  @doc """
  Emits a `pdl.type` op with `constantType: t`, typed `!pdl.type`.
  """
  defmacro type(t) do
    quote do
      mlir do
        Beaver.MLIR.Dialect.PDL.type(constantType: unquote(t)) >>> ~t{!pdl.type}
      end
    end
  end

  @doc """
  Emits a `pdl.operand` op, typed `!pdl.value`.
  """
  defmacro value() do
    quote do
      mlir do
        Beaver.MLIR.Dialect.PDL.operand() >>> ~t{!pdl.value}
      end
    end
  end

  @doc """
  Emits a `pdl.rewrite` op rooted at `root`, with no external arguments.

  The rewrite body is built from `block`, and every SSA expression in it is
  emitted through `eval_rewrite/2`.
  """
  defmacro rewrite(root, do: block) do
    rewrite_block_ast = block |> Beaver.SSA.prewalk(&__MODULE__.eval_rewrite/2)

    quote do
      mlir do
        Beaver.MLIR.Dialect.PDL.rewrite [
          unquote(root),
          # one root, zero external arguments
          operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 0])
        ] do
          region do
            block _rewrite_block() do
              unquote(rewrite_block_ast)
            end
          end
        end >>> []
      end
    end
  end

  @doc """
  Emits a `pdl.replace` of `root` with the handle given in `opts[:with]`.

  - A handle produced by `pdl.operation` is passed as the replacement operation.
  - A handle produced by `pdl.result`, `pdl.results`, `pdl.operand` or
    `pdl.operands` is passed as a replacement value.

  Raises if `MLIR.Value.owner/1` does not return `{:ok, owner}`, or if the owner
  op is none of the above.
  """
  defmacro replace(root, opts) do
    quote do
      mlir do
        opts = unquote(opts)
        repl = opts |> Keyword.fetch!(:with)

        pdl_handler =
          case MLIR.Value.owner(repl) do
            {:ok, owner} ->
              owner |> Beaver.MLIR.Operation.name()

            _ ->
              raise "not a pdl handler"
          end

        # Segment sizes are [opValue, replOperation, replValues].
        case pdl_handler do
          "pdl.operation" ->
            Beaver.MLIR.Dialect.PDL.replace([
              unquote(root),
              repl,
              operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 1, 0])
            ]) >>> []

          value_handler
          when value_handler in ["pdl.result", "pdl.results", "pdl.operand", "pdl.operands"] ->
            Beaver.MLIR.Dialect.PDL.replace([
              unquote(root),
              repl,
              operand_segment_sizes: Beaver.MLIR.ODS.operand_segment_sizes([1, 0, 1])
            ]) >>> []

          other ->
            raise "unsupported replacement handle produced by #{other}, expected one " <>
                    "produced by pdl.operation, pdl.result, pdl.results, pdl.operand or pdl.operands"
        end
      end
    end
  end

  @doc false
  # Turns a DSL element into a PDL handle:
  # - an MLIR type becomes a `pdl.type` op with that constant type;
  # - an MLIR attribute becomes a `pdl.attribute` op with that value;
  # - an MLIR value is returned unchanged;
  # - a 1-arity function is applied to the context and its result converted again.
  def gen_pdl(%Env{} = env, %MLIR.Type{} = type) do
    mlir block: env.block, ctx: env.ctx do
      Beaver.MLIR.Dialect.PDL.type(constantType: type) >>> ~t{!pdl.type}
    end
  end

  def gen_pdl(%Env{} = env, %MLIR.Attribute{} = attribute) do
    mlir block: env.block, ctx: env.ctx do
      Beaver.MLIR.Dialect.PDL.attribute(value: attribute) >>>
        ~t{!pdl.attribute}
    end
  end

  def gen_pdl(_env, %MLIR.Value{} = value) do
    value
  end

  def gen_pdl(%Env{ctx: ctx} = env, f) when is_function(f, 1) do
    gen_pdl(env, f.(ctx))
  end

  def gen_pdl(_env, element) do
    raise "fail to generate pdl handle for element: #{inspect(element)}"
  end

  @doc """
  Emits a `pdl.operation` named `op_name`, using the block, context and location
  in `env`, and returns its `!pdl.operation` handle.

  `operands` are passed through as-is. The values of the `attributes` keyword
  list and every entry of `results` are converted with `gen_pdl/2` first. The
  attribute keys become `attributeValueNames`.

  The difference between a `pdl.operation` created in a match body and one created in a rewrite body:
  - in a match body, `pdl.attribute`/`pdl.operand`/`pdl.result` will be generated for unbound variables
  - in a rewrite body, all variables are considered bound before the PDL ops are created
  """
  def create_operation(
        %Env{ctx: ctx, block: block, loc: loc} = env,
        op_name,
        operands,
        attributes,
        results
      )
      when is_list(attributes) do
    mlir block: block, ctx: ctx do
      results = results |> Enum.map(&gen_pdl(env, &1))

      attribute_names =
        for {k, _} <- attributes do
          k |> Atom.to_string() |> MLIR.Attribute.string()
        end

      attributes =
        for {_, a} <- attributes do
          a
        end
        |> Enum.map(&gen_pdl(env, &1))

      Beaver.MLIR.Dialect.PDL.operation(
        loc,
        operands,
        attributes,
        results,
        opName: Beaver.MLIR.Attribute.string(op_name),
        attributeValueNames: Beaver.MLIR.Attribute.array(attribute_names),
        operand_segment_sizes:
          Beaver.MLIR.ODS.operand_segment_sizes([
            length(operands),
            length(attributes),
            length(results)
          ])
      ) >>> ~t{!pdl.operation}
    end
  end

  @doc """
  Evaluates an SSA expression as PDL ops. This function is only public so that
  it can be referenced from generated AST.

  Keyword-pair arguments become attributes and `MLIR.Value` arguments become
  operands of a `pdl.operation` named `op_name`. One `pdl.result` is emitted per
  result type.

  Returns the single result when there is exactly one. Otherwise it returns the
  list of results, which is empty for `:infer`. When the results are given as
  `{:op, types}`, it returns `{op, results}` instead.
  """
  def eval_rewrite(
        op_name,
        %Beaver.SSA{
          arguments: arguments,
          results: result_types,
          ctx: ctx,
          block: block,
          loc: loc
        }
      ) do
    attributes = for {_k, _a} = a <- arguments, do: a
    operands = for %MLIR.Value{} = o <- arguments, do: o
    env = %Env{ctx: ctx, block: block, loc: loc}

    result_types_unwrap =
      case result_types do
        [:infer] -> []
        [{:op, types}] -> List.wrap(types)
        _ -> result_types
      end
      |> Enum.map(&gen_pdl(env, &1))

    op =
      create_operation(
        env,
        op_name,
        operands,
        attributes,
        result_types_unwrap
      )

    results =
      for {_type, index} <- Enum.with_index(result_types_unwrap), do: result(env, op, index)

    # A single result is handed back bare rather than as a one-element list.
    results =
      case results do
        [single] -> single
        _ -> results
      end

    case result_types do
      [{:op, _types}] ->
        {op, results}

      _ ->
        results
    end
  end

  # Emits `pdl.result` #i of the operation handle `v`, typed `!pdl.value`.
  defp result(%Env{block: block, ctx: ctx}, %Beaver.MLIR.Value{} = v, i)
       when is_integer(i) do
    mlir block: block, ctx: ctx do
      PDL.result(v, index: Beaver.MLIR.Attribute.integer(Beaver.MLIR.Type.i32(), i)) >>>
        ~t{!pdl.value}
    end
  end
end