# lib/manx/nx/compiler.ex

defmodule Manx.Compiler do
  @moduledoc """
  An `Nx.Defn.Compiler` implementation that JIT-compiles `defn` expression
  trees to native code through MLIR, using the Beaver bindings.

  The pipeline implemented by `__jit__/5` is:

    1. Call the compiled function to build the Nx expression tree.
    2. Emit an MLIR module containing a single `func.func` for the tree.
    3. Lower the module for CPU or Vulkan, depending on the backend/device
       of the first runtime argument.
    4. Run the lowered module through the MLIR execution engine and repack
       the returned memref descriptors into Nx tensors.
  """
  use Beaver
  alias Beaver.MLIR
  import MLIR.Sigils
  import Beaver, only: :macros
  require Beaver.MLIR
  alias Beaver.MLIR.Dialect.{Func}
  require Func
  @behaviour Nx.Defn.Compiler

  # Arguments may arrive as zero-arity thunks (lazily-evaluated values);
  # force them to concrete values before use.
  defp eval_arg(f) when is_function(f), do: f.()
  defp eval_arg(a), do: a

  # Shared libraries the CPU execution engine loads at runtime.
  # Returns [] when the LLVM lib dir cannot be located (best-effort).
  # NOTE(review): the ".dylib" suffix is macOS-specific — confirm whether
  # Linux (".so") is meant to be supported here.
  defp runtime_libs() do
    case LLVMConfig.lib_dir() do
      {:ok, llvm_lib_dir} ->
        [
          llvm_lib_dir |> Path.join("libmlir_c_runner_utils.dylib")
        ]

      _ ->
        []
    end
  end

  # Shared libraries required when executing through the Vulkan runner.
  # Same best-effort/platform caveats as runtime_libs/0.
  defp vulkan_runtime_libs() do
    case LLVMConfig.lib_dir() do
      {:ok, llvm_lib_dir} ->
        [
          llvm_lib_dir |> Path.join("libvulkan-runtime-wrappers.dylib")
        ]

      _ ->
        []
    end
  end

  # Nx.Defn.Compiler callback: JIT-compile `fun` for the template `vars`,
  # invoke the compiled code with the single batch in `args`, and return
  # the result wrapped in a one-element list (one entry per batch).
  @impl true
  def __jit__(key, vars, fun, [args], _options) do
    # call fun to generate expression tree
    tree = fun.(vars)

    # Derive a unique MLIR symbol name from the compiled function's
    # module/name/uniq metadata.
    # NOTE(review): Module.concat/1 with an interpolated string creates an
    # atom per JIT'd function — acceptable if compilation is bounded, but
    # worth confirming.
    info = Function.info(key)
    uniq = info |> Keyword.get(:uniq)
    module = info |> Keyword.get(:module)
    name = info |> Keyword.get(:name)
    symbol = Module.concat([module, name, "#{uniq}"]) |> Atom.to_string()
    args = args |> Enum.map(&eval_arg/1)

    # generate ir
    # Flatten tuple-shaped vars so the entry block gets one MLIR type
    # per leaf tensor.
    entry_types =
      Enum.reduce(vars, [], fn
        tuple, acc when is_tuple(tuple) ->
          acc ++ Enum.map(Tuple.to_list(tuple), &Manx.Defn.gen_type/1)

        t, acc ->
          acc ++ [Manx.Defn.gen_type(t)]
      end)

    # When the first argument lives on the Vulkan device, attach the
    # SPIR-V target-environment attribute the GPU lowering pipeline needs.
    module_attrs =
      case args |> List.first() do
        arg0 when not is_nil(arg0) ->
          with %Manx{device: :vulkan} <- arg0.data do
            [
              "spirv.target_env":
                ~a"#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>"
            ]
          else
            _ -> []
          end

        _ ->
          []
      end

    ctx = MLIR.Context.create()

    # Build an MLIR module holding one func.func whose single entry block
    # is populated from the Nx expression tree by Manx.Defn.gen_op/2.
    ir =
      mlir ctx: ctx do
        module(module_attrs) do
          function_type =
            Type.function(
              entry_types,
              Manx.Defn.gen_root_types(tree)
            )

          Func.func manx_main(
                      sym_name: "\"#{symbol}\"",
                      function_type: function_type
                    ) do
            region do
              locs = List.duplicate(MLIR.Location.unknown(), length(entry_types))

              entry =
                MLIR.Block.create(
                  entry_types |> Enum.map(&Beaver.Deferred.create(&1, Beaver.Env.context())),
                  locs |> Enum.map(&Beaver.Deferred.create(&1, Beaver.Env.context()))
                )

              root = Manx.Defn.gen_op(%Manx.Defn.Env{block: entry, ctx: ctx}, tree)

              # Terminate the function: a single value or a tuple of
              # values becomes one func.return.
              mlir block: entry do
                case root do
                  ret = %Beaver.MLIR.Value{} ->
                    Func.return(ret) >>> []

                  tuple_ret when is_tuple(tuple_ret) ->
                    Func.return(Tuple.to_list(tuple_ret)) >>> []
                end
              end

              Beaver.Env.region()
              |> Beaver.MLIR.CAPI.mlirRegionAppendOwnedBlock(entry)
            end
          end
        end
      end

    # Pick the lowering pipeline and runtime libraries from the backend of
    # the first argument; with no args, default to the CPU pipeline.
    {llvm_ir, libs} =
      case args |> List.first() do
        arg0 when not is_nil(arg0) ->
          case arg0.data do
            %Nx.BinaryBackend{} ->
              {Manx.Lowering.CPU.lower(ir), runtime_libs()}

            %Manx{device: device} ->
              case device do
                :host ->
                  {Manx.Lowering.CPU.lower(ir), runtime_libs()}

                :vulkan ->
                  {Manx.Lowering.Vulkan.lower(ir), vulkan_runtime_libs()}
              end
          end

        _ ->
          {Manx.Lowering.CPU.lower(ir), []}
      end

    # Unwrap the lowering result; destroy the context before raising so the
    # error path does not leak it.
    llvm_ir =
      case llvm_ir do
        {:ok, op} ->
          op

        {:error, msg} ->
          MLIR.Context.destroy(ctx)
          raise msg
      end

    jit =
      llvm_ir
      |> MLIR.ExecutionEngine.create!(shared_lib_paths: libs)

    # invoke jit and setting return for tree
    # tensor_of_null_memref/1 builds a result skeleton whose memrefs the
    # JIT'd code fills in.
    tree_return =
      tree
      |> Manx.tensor_of_null_memref()
      |> invoke(args, jit, symbol)

    # NOTE(review): this success path uses the raw CAPI destroy while the
    # error path above uses MLIR.Context.destroy/1 — presumably equivalent;
    # confirm and unify.
    MLIR.CAPI.mlirContextDestroy(ctx)
    [tree_return]
  end

  @doc """
  Invoke MLIR JIT with Nx tensors. If there are tuples their memrefs will be packed into a single C struct.
  """

  def invoke(return, args, jit, symbol) do
    # pack the tensor tuples into a C struct
    # The return slot is passed first, matching the generated function's
    # calling convention.
    jit_args =
      [return_struct | _] =
      [return | args]
      |> Enum.map(&memref_from_tensor/1)

    if List.improper?(jit_args), do: raise("jit arguments is not a proper list")

    MLIR.ExecutionEngine.invoke!(
      jit,
      symbol,
      jit_args |> Enum.map(&Beaver.Native.Memory.descriptor_ptr/1)
    )

    # unpack the C struct into tensor tuples
    populate_tensor_from_memref(return, return_struct)
    |> Manx.add_allocated_memory()
  end

  @doc """
  Extract the native memory to hand to the JIT'd code.

  - If it is a tensor, return a memref.
  - If it is a tuple, recursively pack the members' memrefs into one array
    struct.
  """
  # Lazy argument: force the thunk first.
  def memref_from_tensor(f) when is_function(f), do: f.() |> memref_from_tensor
  def memref_from_tensor(%Nx.Tensor{data: %Manx{memory: memory}}), do: memory

  # BinaryBackend tensors are first copied into Manx-managed memory.
  def memref_from_tensor(
        %Nx.Tensor{
          data: %Nx.BinaryBackend{state: binary}
        } = tensor
      ) do
    Manx.from_binary(tensor, binary, []) |> memref_from_tensor
  end

  # Template tensors carry no data; allocate a placeholder memref of the
  # right shape for the JIT'd code to fill.
  def memref_from_tensor(%Nx.Tensor{shape: shape, data: %Nx.TemplateBackend{}}) do
    # TODO: generate a magical deadbeef pointer for this
    Beaver.Native.Memory.new(nil, sizes: shape |> Tuple.to_list(), type: Beaver.Native.F32)
  end

  def memref_from_tensor({}) do
    raise "can't extract memref from an empty tuple"
  end

  # Tuple of tensors: pack every member's descriptor ref into one native
  # array so it crosses the C ABI as a single struct argument.
  def memref_from_tensor(tuple) when is_tuple(tuple) do
    mems =
      Tuple.to_list(tuple)
      |> Enum.map(&memref_from_tensor/1)

    # TODO: support array of memref descriptor of different kinds
    # All members are assumed to share the first member's descriptor kind.
    first = mems |> List.first()
    kind = first.descriptor.descriptor_kind

    refs =
      mems
      |> Enum.map(fn %Beaver.Native.Memory{descriptor: %Beaver.Native.Memory.Descriptor{ref: ref}} ->
        ref
      end)

    # TODO: add a raw NIF beaver_raw_create_heterogeneous_array, using union maybe
    mut_array = Beaver.Native.forward(kind, :mut_array, [refs])

    struct!(Beaver.Native.Array,
      element_kind: kind,
      ref: mut_array
    )
  end

  @doc """
  Rebuild the result tensors from the memory the JIT'd code produced.

  - If it is a tensor, attach the memref as its `Manx` backend data.
  - If it is a tuple, recursively unpack each member from the nested
    native array struct.
  """
  def populate_tensor_from_memref(%Nx.Tensor{data: %Manx{}} = tensor, memory) do
    %{tensor | data: %Manx{memory: memory}}
  end

  # Walk the packed array in memory order, slicing one descriptor per
  # tuple member at its running byte offset.
  def populate_tensor_from_memref(
        tuple,
        %Beaver.Native.Array{element_kind: element_kind} = nested_struct
      )
      when is_tuple(tuple) do
    nested_struct_ptr = nested_struct |> Beaver.Native.Memory.descriptor_ptr()

    {tensors, _offset} =
      Enum.reduce(tuple |> Tuple.to_list(), {[], 0}, fn x, {acc, offset} ->
        {ref, size} =
          Beaver.Native.OpaquePtr.to_resource(
            element_kind,
            nested_struct_ptr,
            offset
          )

        mem = %Beaver.Native.Memory{
          descriptor: %Beaver.Native.Memory.Descriptor{
            ref: ref,
            descriptor_kind: element_kind
          }
        }

        {acc ++ [populate_tensor_from_memref(x, mem)], offset + size}
      end)

    tensors |> List.to_tuple()
  end
end