lib/beaver/mlir/execution_engine.ex

defmodule Beaver.MLIR.ExecutionEngine do
  @moduledoc """
  This module defines functions working with MLIR #{__MODULE__ |> Module.split() |> List.last()}.
  """
  alias Beaver.MLIR
  alias Beaver.MLIR.Pass.Composer
  import Beaver.MLIR.CAPI

  @doc """
  Checks whether the JIT handle `jit` is null.

  The check is delegated to `beaverMlirExecutionEngineIsNull/1` and the native
  result is converted with `Beaver.Native.to_term/1`. `create!/2` treats a
  truthy result as an engine-creation failure.
  """
  def is_null(jit) do
    jit
    |> beaverMlirExecutionEngineIsNull()
    |> Beaver.Native.to_term()
  end

  @doc """
  Create a MLIR JIT engine for a module and check if successful. Usually this module should be of LLVM dialect.

  `module_or_composer` is either a module operation or a
  `Beaver.MLIR.Pass.Composer`. A composer is first run with `Composer.run!/1`
  and the resulting operation is JIT-compiled with the same `opts`.

  ## Options

    * `:shared_lib_paths` - list of shared library paths handed to the engine
      (default `[]`).
    * `:opt_level` - optimization level passed to `mlirExecutionEngineCreate`
      (default `2`).
    * `:enable_object_dump` - boolean forwarded as the engine's object-dump flag
      (default `false`).

  Raises if the created engine handle is null.
  """
  # A bodiless head is required so the default applies to both clauses; without
  # it the default would generate a second `create!/1` that clashes with the
  # composer clause.
  def create!(module_or_composer, opts \\ [])

  def create!(%Composer{} = composer, opts) do
    composer
    |> Composer.run!()
    |> create!(opts)
  end

  def create!(module, opts) do
    shared_lib_paths = Keyword.get(opts, :shared_lib_paths, [])
    opt_level = Keyword.get(opts, :opt_level, 2)
    enable_object_dump = Keyword.get(opts, :enable_object_dump, false)

    shared_lib_paths_ptr =
      shared_lib_paths
      |> Enum.map(&MLIR.StringRef.create/1)
      |> Beaver.Native.array(MLIR.StringRef)

    jit =
      mlirExecutionEngineCreate(
        module,
        opt_level,
        length(shared_lib_paths),
        shared_lib_paths_ptr,
        enable_object_dump
      )

    if is_null(jit) do
      raise "Execution engine creation failed"
    end

    jit
  end

  # Packs the argument pointers into a mutable opaque-pointer array, invokes
  # `symbol` through `mlirExecutionEngineInvokePacked`, and raises unless the
  # returned logical result is a success.
  defp do_invoke!(jit, symbol, arg_ptr_list) do
    result =
      mlirExecutionEngineInvokePacked(
        jit,
        MLIR.StringRef.create(symbol),
        Beaver.Native.array(arg_ptr_list, Beaver.Native.OpaquePtr, mut: true)
      )

    if MLIR.LogicalResult.success?(result) do
      :ok
    else
      raise "Execution engine invoke failed"
    end
  end

  @doc """
  invoke a function by symbol name.

  Each element of `args` is converted with `Beaver.Native.opaque_ptr/1`.
  `return` is converted the same way and appended as the last pointer, so the
  invoked function can write its result into it. Returns `return` on success.
  Raises if the invocation fails.
  """
  def invoke!(jit, symbol, args, return) when is_list(args) do
    arg_ptr_list = Enum.map(args, &Beaver.Native.opaque_ptr/1)
    return_ptr = Beaver.Native.opaque_ptr(return)
    :ok = do_invoke!(jit, symbol, arg_ptr_list ++ [return_ptr])
    return
  end

  @doc """
  invoke a void function by symbol name.

  Each element of `args` is converted with `Beaver.Native.opaque_ptr/1`.
  Returns `:ok` on success. Raises if the invocation fails.
  """
  def invoke!(jit, symbol, args) when is_list(args) do
    arg_ptr_list = Enum.map(args, &Beaver.Native.opaque_ptr/1)
    do_invoke!(jit, symbol, arg_ptr_list)
  end

  @doc """
  Destroys the execution engine `jit` via `mlirExecutionEngineDestroy/1`.
  """
  def destroy(jit) do
    mlirExecutionEngineDestroy(jit)
  end
end