defmodule Beaver.MLIR.Operation do
@moduledoc """
This module defines functions working with MLIR #{__MODULE__ |> Module.split() |> List.last()}.
"""
alias Beaver.MLIR
alias Beaver.MLIR.CAPI
import Beaver.MLIR.CAPI
require Logger
use Kinda.ResourceKind,
forward_module: Beaver.Native
@doc false
# Creates an operation from an operation state.
#
# A `%MLIR.Operation.State{}` struct is first converted with
# `MLIR.Operation.State.create/1`; the converted value is then handled by the
# catch-all clause below.
def create(%MLIR.Operation.State{} = state) do
  converted_state = MLIR.Operation.State.create(state)
  create(converted_state)
end

# Any other value: obtain a pointer with `Beaver.Native.ptr/1`, hand the pointer
# together with the original value to `Beaver.Native.bag/2`, and pass the result
# to `mlirOperationCreate`.
# NOTE(review): `bag/2` presumably keeps `state` referenced while the pointer is
# in use — confirm against `Beaver.Native`.
def create(state) do
  state_ptr = Beaver.Native.ptr(state)

  state_ptr
  |> Beaver.Native.bag(state)
  |> MLIR.CAPI.mlirOperationCreate()
end
# Creates the operation named `op_name` from a `%Beaver.SSA{}` description and
# appends it to the SSA's block through `create_and_append/5`.
defp create(op_name, %Beaver.SSA{block: %MLIR.Block{} = block} = ssa) do
  # A zero-arity `filler` is forwarded as the `:regions` argument; any other
  # value contributes nothing.
  region_arguments =
    case ssa.filler do
      filler when is_function(filler, 0) -> [regions: filler]
      _ -> []
    end

  all_arguments = ssa.arguments ++ [result_types: ssa.results] ++ region_arguments

  create_and_append(ssa.ctx, op_name, all_arguments, block, ssa.loc)
end

# one single value, usually a terminator
# NOTE(review): no `create/2` clause in this module accepts a list, so the call
# below would raise `FunctionClauseError` if this clause were ever reached —
# confirm whether it is still needed.
defp create(op_name, %MLIR.Value{} = value) do
  create(op_name, [value])
end
@doc false
# Creates the operation `op_name` in `ctx` from the list `arguments` (using `loc`
# when given), appends it to `block` with `mlirBlockAppendOwnedOperation`, and
# returns the created operation.
def create_and_append(
      %MLIR.Context{} = ctx,
      op_name,
      arguments,
      %MLIR.Block{} = block,
      loc \\ nil
    )
    when is_list(arguments) do
  ctx
  |> do_create(op_name, arguments, loc)
  |> tap(&Beaver.MLIR.CAPI.mlirBlockAppendOwnedOperation(block, &1))
end
@doc """
Returns the results of an operation.

The shape of the return value depends on the number of results:

  * no results - the operation itself is returned
  * one result - that single result is returned
  * several results - a list of the results, in index order

A `{:deferred, {func_name, arguments}}` tuple is returned unchanged.
"""
def results(%MLIR.Operation{} = op) do
  result_count =
    op
    |> CAPI.mlirOperationGetNumResults()
    |> Beaver.Native.to_term()

  case result_count do
    0 -> op
    1 -> CAPI.mlirOperationGetResult(op, 0)
    n when n > 1 -> Enum.map(0..(n - 1)//1, &CAPI.mlirOperationGetResult(op, &1))
  end
end

def results({:deferred, {_func_name, _arguments}} = deferred), do: deferred
# Builds an operation state for `op_name` in `ctx`, folds every entry of
# `arguments` into it with `MLIR.Operation.State.add_argument/2`, and creates the
# operation from the resulting state. When `loc` is `nil` or `false`,
# `MLIR.Location.unknown/0` is used as the location.
defp do_create(ctx, op_name, arguments, loc) when is_binary(op_name) and is_list(arguments) do
  initial_state = %MLIR.Operation.State{
    name: op_name,
    location: loc || MLIR.Location.unknown(),
    context: ctx
  }

  arguments
  |> Enum.reduce(initial_state, fn argument, acc ->
    MLIR.Operation.State.add_argument(acc, argument)
  end)
  |> MLIR.Operation.State.create()
  |> create()
end
@default_verify_opts [debug: false]

@doc """
Verifies `op` with `verify/2` and returns it when verification succeeds.

Raises a `RuntimeError` when `verify/2` reports that the operation is null or
that verification failed. `opts` are forwarded to `verify/2`.
"""
def verify!(op, opts \\ @default_verify_opts) do
  outcome = verify(op, opts ++ [should_raise: true])

  case outcome do
    {:ok, verified_op} ->
      verified_op

    :null ->
      raise "MLIR operation verification failed because the operation is null. Maybe it is parsed from an ill-formed text format?"

    :fail ->
      raise "MLIR operation verification failed"
  end
end
@doc """
Verifies `op`, which may be anything `from_module/1` accepts.

Returns:

  * `:null` - when `MLIR.is_null/1` reports the value as null
  * `{:ok, op}` - when `mlirOperationVerify` succeeds
  * `:fail` - otherwise

## Options

  * `:debug` - when truthy, the failing operation is logged at info level
    before `:fail` is returned. Defaults to `false`.
"""
def verify(op, opts \\ @default_verify_opts) do
  debug = Keyword.get(opts, :debug, false)

  if MLIR.is_null(op) do
    :null
  else
    passed =
      op
      |> from_module()
      |> MLIR.CAPI.mlirOperationVerify()
      |> Beaver.Native.to_term()

    cond do
      passed ->
        {:ok, op}

      debug ->
        # Printing is only attempted on request, as the log message itself warns.
        Logger.info("Start printing op failed to pass the verification. This might crash.")
        Logger.info(MLIR.to_string(op))
        :fail

      true ->
        :fail
    end
  end
end
@doc """
Dumps `op` with `mlirOperationDump` and returns `op` unchanged.

`op` is first passed through `from_module/1`, so a module is dumped through its
operation.
"""
def dump(op) do
  operation = from_module(op)
  mlirOperationDump(operation)
  op
end
@doc """
Verifies the operation with `verify!/1`, then dumps it and returns it.

Raises if the verification fails; nothing is dumped in that case.
"""
def dump!(%MLIR.Operation{} = op) do
  # `verify!/1` returns the operation it was given when verification succeeds.
  op
  |> verify!()
  |> mlirOperationDump()

  op
end
@doc """
Returns the name of `operation`.

The name identifier is fetched with `mlirOperationGetName`, turned into a string
reference with `mlirIdentifierStr`, and converted with
`MLIR.StringRef.to_string/1`.
"""
def name(%MLIR.Operation{} = operation) do
  identifier = MLIR.CAPI.mlirOperationGetName(operation)
  string_ref = MLIR.CAPI.mlirIdentifierStr(identifier)
  MLIR.StringRef.to_string(string_ref)
end
@doc """
Normalizes a module or an operation into an operation.

A `%MLIR.Module{}` is converted with `mlirModuleGetOperation`; a
`%MLIR.Operation{}` is returned as is.
"""
def from_module(%MLIR.Module{} = module), do: CAPI.mlirModuleGetOperation(module)
def from_module(%MLIR.Operation{} = op), do: op
@doc false
# Creates the operation `full_name` from `ssa` and returns its results (see
# `results/1`).
#
# When the SSA's results are exactly `[{:op, types}]`, the operation is created
# with `List.wrap(types)` as its result types and `{op, results}` is returned
# instead of the bare results.
def eval_ssa(full_name, %Beaver.SSA{results: result_types} = ssa) do
  {effective_ssa, return_op?} =
    case result_types do
      [{:op, types}] -> {%Beaver.SSA{ssa | results: List.wrap(types)}, true}
      _ -> {ssa, false}
    end

  op = create(full_name, effective_ssa)
  op_results = results(op)

  if return_op?, do: {op, op_results}, else: op_results
end
end