lib/beaver.ex

defmodule Beaver do
  @moduledoc """
  This module contains the top-level functions and macros of the Beaver DSL for MLIR.

  Here are some examples of the most common forms of the Beaver DSL:
  - single result
  ```
  res = TOSA.add(a, b) >>> res_t
  ```
  - multiple results
  ```
  [res1, res2] = TOSA.add(a, b) >>> [res_t, res_t]
  [res1, res2] = TOSA.add(a, b) >>> res_t_list
  res_list = TOSA.add(a, b) >>> res_t_list
  ```
  - inferred results
  ```
  TOSA.add(a, b) >>> :infer
  ```
  - with the op
  ```
  {op, res} = TOSA.add(a, b) >>> {:op, res_t}
  ```
  - with no result
  ```
  TOSA.add(a, b) >>> []
  ```
  """

  alias Beaver.MLIR
  require Beaver.MLIR.CAPI

  @doc """
  Sets up the calling module to use the Beaver DSL.

  Requires `Beaver.MLIR.CAPI` and `Beaver.Env`, imports `Beaver` and
  `Beaver.MLIR.Sigils`, and aliases `Beaver.MLIR`.
  """
  defmacro __using__(_) do
    quote do
      require Beaver.MLIR.CAPI
      require Beaver.Env
      import Beaver
      alias Beaver.MLIR
      import MLIR.Sigils
    end
  end

  @doc """
  Transforms Beaver's MLIR DSL expressions into MLIR API calls.

  The transformation applies to any expression of this form, so it is also possible to call functions or macros other than op creation functions. The operator `>>>` annotates the result type(s) of an SSA expression or the type of a block argument; it works much like `::` in Elixir typespecs.

  ## How it works under the hood
  ```
  mlir do
    [res0, res1] = TestDialect.some_op(operand0, operand1, attr0: ~a{11 : i32}) >>> ~t{f32}
  end
  # will be transformed into:
  [res0, res1] =
    %DSL.SSA{}
    |> DSL.SSA.arguments([operand0, operand1, attr0: ~a{11 : i32}])
    |> DSL.SSA.results([~t{f32}])
    |> TestDialect.some_op()
  ```
  The SSA form returns:
  - for an op with multiple results: the results of the op in a list;
  - for an op with a single result: that result;
  - for an op with no result: the op itself (for instance, module, func, and terminators).

  If the op has no results, add `>>> []` to make the transformation take effect:
  ```
  TestDialect.some_op(operand0) >>> []
  ```
  To defer the creation of a terminator whose successor block has not been created yet, refer to the successor by its block name as an atom (or as a `{name, arguments}` tuple, as with `:bb2` below):
  ```
  CF.cond_br(cond0, :bb1, {:bb2, [v0]}) >>> []
  ```
  To create a region, call the op with a do block. The `block` macro works like a function definition in Elixir: inside the do block of `block` you can reference its arguments by name. One caveat: for an op with regions, all arguments must be passed in a single list so that the macro version of the op creation function is called.
  ```
  TestDialect.op_with_region [operand0, attr0: ~a{1}i32] do
    region do
      block bb_entry(arg >>> ~t{f32}) do
        TestDialect.some_op(arg) >>> ~t{f32}
      end
    end
  end >>> ~t{f32}
  ```
  """
  defmacro mlir(do: dsl_block) do
    quote do
      Beaver.mlir [] do
        unquote(dsl_block)
      end
    end
  end

  @doc """
  Same as `mlir/1`, but first binds parts of the DSL environment from `opts`.

  ## Options
  - `:ctx` - bound as the DSL environment's context; must be an `%MLIR.Context{}`.
  - `:block` - bound as the DSL environment's current block; must be an `%MLIR.Block{}`.

  `opts` is inspected at macro-expansion time, so it must be a literal keyword list.
  Each option expression is evaluated exactly once at runtime, `:ctx` before `:block`.
  """
  defmacro mlir(opts, do: dsl_block) do
    dsl_block_ast = Beaver.SSA.prewalk(dsl_block, &MLIR.Operation.eval_ssa/2)

    # Pull each option's AST out at expansion time and unquote only that value.
    # Unquoting the whole `opts` list and looking the key up at runtime would
    # evaluate every option expression (e.g. `block: MLIR.Block.create()`) once
    # per option present.
    ctx_ast =
      case Keyword.fetch(opts, :ctx) do
        {:ok, ctx} ->
          quote do
            Kernel.var!(beaver_internal_env_ctx) = unquote(ctx)
            %MLIR.Context{} = Kernel.var!(beaver_internal_env_ctx)
          end

        :error ->
          nil
      end

    block_ast =
      case Keyword.fetch(opts, :block) do
        {:ok, block} ->
          quote do
            Kernel.var!(beaver_internal_env_block) = unquote(block)
            %MLIR.Block{} = Kernel.var!(beaver_internal_env_block)
          end

        :error ->
          nil
      end

    quote do
      require Beaver.Env
      alias Beaver.MLIR
      require Beaver.MLIR
      alias Beaver.MLIR.Type
      alias Beaver.MLIR.Attribute
      alias Beaver.MLIR.ODS
      import Beaver.MLIR.Sigils
      import Beaver.MLIR.Dialect.Builtin

      unquote(ctx_ast)
      unquote(block_ast)
      unquote(dsl_block_ast)
    end
  end

  @doc """
  Declares a block and evaluates `body` with it as the current block.

  `call` is a local call whose name names the block, for example `bb_entry(arg >>> ~t{f32})`.
  Its arguments are processed by `Beaver.BlockDSL.transform_call/1`.

  - If a variable with the block's name already exists in the caller's scope, the
    declared arguments are added to that block with `Beaver.MLIR.Block.add_arg!/3`.
  - Otherwise a new block is created.

  If `Beaver.Env.region()` returns a region, the block is appended to it.

  The block is returned.

  Raises `ArgumentError` at expansion time if `call` is not a local call.
  """
  defmacro block(call, do: body) do
    bb_name =
      case Macro.decompose_call(call) do
        {name, _args} when is_atom(name) ->
          name

        # remote calls decompose to a 3-tuple, non-calls to :error
        _ ->
          raise ArgumentError,
                "block name must be an atom (expected a local call like `bb_name(...)`), got: " <>
                  Macro.to_string(call)
      end

    {
      _block_args,
      _block_opts,
      args_type_ast,
      args_var_ast,
      locations_var_ast,
      block_arg_var_ast
    } = Beaver.BlockDSL.transform_call(call)

    # the caller-scope variable named after the block, i.e. `{bb_name, [], nil}`
    bb_var = Macro.var(bb_name, nil)

    region_insert_ast =
      quote do
        if region = Beaver.Env.region() do
          # insert the block to region
          Beaver.MLIR.CAPI.mlirRegionAppendOwnedBlock(region, Beaver.Env.block())
        end
      end

    block_creation_ast =
      if Macro.Env.has_var?(__CALLER__, {bb_name, nil}) do
        # the block was created earlier (e.g. so a terminator could reference it);
        # attach the declared arguments to it instead of creating a new one
        quote do
          _args =
            Kernel.var!(unquote(bb_var))
            |> Beaver.MLIR.Block.add_arg!(
              Beaver.Env.context(),
              Enum.zip(block_arg_types, block_arg_locs)
            )

          Kernel.var!(unquote(bb_var))
        end
      else
        quote do
          Beaver.MLIR.Block.create(
            block_arg_types |> Enum.map(&Beaver.Deferred.create(&1, Beaver.Env.context())),
            block_arg_locs |> Enum.map(&Beaver.Deferred.create(&1, Beaver.Env.context()))
          )
        end
      end

    quote do
      require Beaver.Env
      unquote_splicing(args_type_ast)
      block_arg_types = [unquote_splicing(args_var_ast)]
      block_arg_locs = [unquote_splicing(locations_var_ast)]

      # can't put code here inside a function like Region.under, because we need to support uses across blocks

      Kernel.var!(beaver_internal_env_block) = unquote(block_creation_ast)

      unquote(region_insert_ast)

      %MLIR.Block{} = Kernel.var!(beaver_internal_env_block)
      unquote_splicing(block_arg_var_ast)
      unquote(body)

      Kernel.var!(beaver_internal_env_block)
    end
  end

  @doc """
  Creates a region and evaluates `body` inside it.

  The region is created with `Beaver.MLIR.CAPI.mlirRegionCreate/0`. `body` runs
  inside `Beaver.MLIR.Region.under/2` with the region bound as the environment's
  region. Because `body` runs inside a function there, its bindings do not leak out.

  The region is then appended to the caller's `beaver_internal_env_regions` list,
  which is initialized to `[]` if it does not exist yet. The accumulated list is
  returned.
  """
  defmacro region(do: body) do
    regions =
      if Macro.Env.has_var?(__CALLER__, {:beaver_internal_env_regions, nil}) do
        quote do
          Kernel.var!(beaver_internal_env_regions)
        end
      else
        quote do
          Kernel.var!(beaver_internal_env_regions) = []
        end
      end

    quote do
      require Beaver.Env
      region = Beaver.MLIR.CAPI.mlirRegionCreate()
      unquote(regions)

      Beaver.MLIR.Region.under(region, fn ->
        Kernel.var!(beaver_env_region) = region
        %Beaver.MLIR.Region{} = Kernel.var!(beaver_env_region)
        unquote(body)
      end)

      Kernel.var!(beaver_internal_env_regions) =
        Kernel.var!(beaver_internal_env_regions) ++
          [region]

      Kernel.var!(beaver_internal_env_regions)
    end
  end

  @doc """
  Placeholder for the DSL's result-typing operator.

  `>>>` is only meaningful inside `mlir/1` or `mlir/2`, which rewrite it away at
  expansion time. Calling it at runtime always raises.
  """
  def _call >>> _results do
    raise(
      "`>>>` operator is expected to be transformed away. Maybe you forget to put the expression inside the Beaver.mlir/1 macro's do block?"
    )
  end
end