defmodule Beaver.MLIR.Dialect do
  @moduledoc """
  Macros that generate an Elixir module for an MLIR dialect.

  `use Beaver.MLIR.Dialect, dialect: dialect, ops: ops` defines, in the calling
  module, one single-argument function per op. Every generated function
  forwards its argument to `eval_ssa/2` together with the op's full name
  (`"<dialect>.<op>"`).
  """
  alias Beaver.MLIR.Dialect
  require Logger

  @doc """
  Evaluates the op identified by `full_name` (`"<dialect>.<op>"`) for the given
  `Beaver.SSA` struct.

  The default implementation injected by `use Beaver.MLIR.Dialect` calls the
  two-arity function stored in the struct's `:evaluator` field with the full
  name and the struct itself. It is overridable.
  """
  @callback eval_ssa(String.t(), Beaver.SSA.t()) :: any()

  @doc """
  Injects the dialect functions into the calling module.

  ## Options

    * `:dialect` - (required) the dialect name, used as the prefix of every
      op's full name.
    * `:ops` - (required) the list of op names of the dialect.

  For every op a function named after
  `Beaver.MLIR.Dialect.Registry.normalize_op_name/1` is defined and marked
  overridable. Raises at compile time of the calling module when two ops
  normalize to the same function name.
  """
  defmacro __using__(opts) do
    dialect = Keyword.fetch!(opts, :dialect)
    ops = Keyword.fetch!(opts, :ops)

    quote bind_quoted: [dialect: dialect, ops: ops] do
      @behaviour Beaver.MLIR.Dialect
      require Logger

      def eval_ssa(full_name, %Beaver.SSA{evaluator: evaluator} = ssa)
          when is_function(evaluator, 2) do
        evaluator.(full_name, ssa)
      end

      defoverridable eval_ssa: 2

      dialect_module_name = Beaver.MLIR.Dialect.Registry.normalize_dialect_name(dialect)

      Logger.debug(
        "[Beaver] building Elixir module for dialect #{dialect} => #{dialect_module_name} (#{length(ops)})"
      )

      # `unquote/1` below is an unquote fragment: `bind_quoted` disables regular
      # unquoting, so the values are injected while the module body is evaluated.
      func_names =
        for op <- ops do
          func_name = Beaver.MLIR.Dialect.Registry.normalize_op_name(op)
          full_name = Enum.join([dialect, op], ".")

          def unquote(func_name)(ssa) do
            eval_ssa(unquote(full_name), ssa)
          end

          defoverridable [{func_name, 1}]
          func_name
        end

      # Colliding names cannot both be exposed as functions, so fail the build
      # and list the offending names to make the collision easy to track down.
      duplicates = Enum.uniq(func_names -- Enum.uniq(func_names))

      if duplicates != [] do
        raise "duplicate op name found in dialect: #{dialect} (#{inspect(duplicates)})"
      end
    end
  end

  @doc """
  Returns the module names of all dialects reported by
  `Beaver.MLIR.Dialect.Registry.dialects/0`.

  Each name is built by concatenating this module's name with the normalized
  dialect name; the modules themselves are not checked for existence here.
  """
  @spec dialects() :: [module()]
  def dialects do
    for dialect <- Dialect.Registry.dialects() do
      Module.concat([__MODULE__, Dialect.Registry.normalize_dialect_name(dialect)])
    end
  end

  @doc """
  Defines the module for the dialect `name`.

  The module is named `Beaver.MLIR.Dialect.<Normalized>`, where `<Normalized>`
  is the result of `Beaver.MLIR.Dialect.Registry.normalize_dialect_name/1`, and
  it uses `Beaver.MLIR.Dialect` with the ops returned by
  `Beaver.MLIR.Dialect.Registry.ops/1` for that dialect.
  """
  defmacro define_modules(name) do
    quote bind_quoted: [dialect: name] do
      module_name =
        Module.concat([
          Beaver.MLIR.Dialect,
          Beaver.MLIR.Dialect.Registry.normalize_dialect_name(dialect)
        ])

      ops = Beaver.MLIR.Dialect.Registry.ops(dialect)

      defmodule module_name do
        use Beaver.MLIR.Dialect, dialect: dialect, ops: ops
      end
    end
  end
end