lib/beaver/native/memory.ex

defmodule Beaver.Native.Memory do
  @moduledoc """
  A piece of memory managed by the BEAM that can be addressed by a generated native
  function as an MLIR MemRef descriptor.
  """

  alias Beaver.Native

  @enforce_keys [:descriptor]
  defstruct storage: nil, descriptor: nil

  @typedoc "A MemRef descriptor plus the BEAM-side storage (if any) backing it."
  @type t :: %__MODULE__{storage: term(), descriptor: term()}

  @doc """
  Returns the strides of a densely packed, row-major buffer with the given `shape`.

  The innermost dimension has stride `1` and every other stride is the product of all
  dimensions to its right. An empty shape yields no strides.

      iex> Beaver.Native.Memory.dense_strides([2, 3, 4])
      [12, 4, 1]
  """
  @spec dense_strides([integer()]) :: [integer()]
  def dense_strides([]), do: []

  def dense_strides([_outermost | inner]) do
    # Running products of the trailing dimensions, computed right-to-left in one pass;
    # `Enum.reverse/2` restores the order and appends the innermost stride of 1.
    inner
    |> Enum.reverse()
    |> Enum.scan(&*/2)
    |> Enum.reverse([1])
  end

  @doc """
  Creates a memory from `data` and `opts`.

  `data` may be:

    * a list or binary: copied into a new mutable `Beaver.Native.Array` whose element
      module is resolved from `:type`;
    * a `Beaver.Native.Array`: used as the storage. Its `element_kind` selects the
      descriptor module and `:type` is not consulted;
    * `nil`: no storage is attached and both descriptor pointers are `nil`.

  ## Options

    * `:sizes` (required): the shape. Its length selects the descriptor rank.
    * `:type`: required for list, binary and `nil` data. Either an element module or a
      `{kind, bits}` pair such as `{:f, 32}`, `{:s, 64}` or `{:u, 8}`.
    * `:offset`: defaults to `0`.
    * `:strides`: defaults to `dense_strides(sizes)`.

  Because `:sizes` is required, calling `new/1` with the default options raises
  `KeyError`.
  """
  def new(data, opts \\ [offset: 0])

  def new(data, opts) when is_list(data) or is_binary(data) do
    mod = extract_mod_from_opts(opts)

    data
    |> Native.array(mod, mut: true)
    |> new(opts)
  end

  def new(%Native.Array{ref: ref, element_kind: mod} = array, opts) do
    build(array, mod, ref, opts)
  end

  def new(nil, opts) do
    build(nil, extract_mod_from_opts(opts), nil, opts)
  end

  @doc """
  Returns an opaque pointer to the memory, taken from the descriptor's aligned pointer.

  When the memory has BEAM-side storage, that storage is attached to the returned
  pointer with `Beaver.Native.bag/2`.
  """
  def aligned(%__MODULE__{descriptor: d, storage: storage}) do
    d
    |> __MODULE__.Descriptor.aligned()
    |> maybe_bag(storage)
  end

  @doc """
  Returns an opaque pointer taken from the descriptor's allocated pointer.

  When the memory has BEAM-side storage, that storage is attached to the returned
  pointer with `Beaver.Native.bag/2`.
  """
  def allocated(%__MODULE__{descriptor: d, storage: storage}) do
    d
    |> __MODULE__.Descriptor.allocated()
    |> maybe_bag(storage)
  end

  @doc """
  Returns an opaque pointer to the memory descriptor, usually passed when invoking a
  generated function.

  When given a `Beaver.Native.Array`, returns a pointer to the array itself. This mimics
  a struct of packed memory descriptors.
  """
  def descriptor_ptr(%__MODULE__{descriptor: d, storage: storage}) do
    # Same nil handling as aligned/1 and allocated/1: only real storage is bagged.
    d
    |> __MODULE__.Descriptor.opaque_ptr()
    |> maybe_bag(storage)
  end

  # An array here is expected to hold packed memory descriptors (e.g. for a tuple), so a
  # pointer to its contents stands in for a pointer to a struct of descriptors.
  def descriptor_ptr(%Native.Array{ref: ref, element_kind: element_kind} = array) do
    opaque_ref = Native.forward(element_kind, :ptr_to_opaque, [ref])

    Native.OpaquePtr
    |> struct!(ref: opaque_ref)
    |> Native.bag(array)
  end

  @doc """
  Takes ownership of the memory the descriptor's `allocated` field points to.

  Only accepts memory without BEAM-side storage. Any other input raises
  `FunctionClauseError`. The allocated pointer is wrapped with
  `Beaver.Native.PtrOwner.new/1` and stored as the new `:storage`.
  """
  def own_allocated(%__MODULE__{storage: nil} = memory) do
    owner =
      memory
      |> allocated()
      |> Native.PtrOwner.new()

    %{memory | storage: owner}
  end

  # Builds the struct and its descriptor. The same `ref` fills both pointer slots of the
  # descriptor tuple.
  defp build(storage, mod, ref, opts) do
    sizes = Keyword.fetch!(opts, :sizes)
    offset = Keyword.get(opts, :offset, 0)
    # Dense strides are only computed when the caller did not supply explicit ones.
    strides = Keyword.get_lazy(opts, :strides, fn -> dense_strides(sizes) end)

    descriptor =
      __MODULE__.Descriptor.make(
        shape_to_descriptor_kind(mod, sizes),
        {ref, ref, offset, sizes, strides}
      )

    %__MODULE__{storage: storage, descriptor: descriptor}
  end

  # Attaches `storage` to `ptr` unless there is no storage to attach.
  defp maybe_bag(ptr, nil), do: ptr
  defp maybe_bag(ptr, storage), do: Native.bag(ptr, storage)

  # Resolves the descriptor module for an element module and a shape, e.g.
  # `Native.F32` with `[2, 3]` gives `Native.F32.MemRef.Descriptor2D`.
  # NOTE(review): an empty shape maps to `MemRef.DescriptorUnranked`, although in MLIR a
  # rank-0 memref is distinct from an unranked one. Confirm this is intended.
  defp shape_to_descriptor_kind(type, []) do
    Module.concat([type, MemRef.DescriptorUnranked])
  end

  defp shape_to_descriptor_kind(type, sizes) when is_list(sizes) do
    Module.concat([type, MemRef, "Descriptor#{length(sizes)}D"])
  end

  defp extract_mod_from_opts(opts) do
    opts
    |> Keyword.fetch!(:type)
    |> pair_to_mod()
  end

  # Maps a `{kind, bits}` element type to its native element module; modules pass through.
  defp pair_to_mod({:u, 8}), do: Native.U8
  defp pair_to_mod({:u, 16}), do: Native.U16
  defp pair_to_mod({:u, 32}), do: Native.U32
  defp pair_to_mod({:u, 64}), do: Native.U64
  defp pair_to_mod({:s, 8}), do: Native.I8
  defp pair_to_mod({:s, 16}), do: Native.I16
  defp pair_to_mod({:s, 32}), do: Native.I32
  defp pair_to_mod({:s, 64}), do: Native.I64
  defp pair_to_mod({:f, 32}), do: Native.F32
  defp pair_to_mod({:c, 64}), do: Native.Complex.F32
  defp pair_to_mod(mod) when is_atom(mod), do: mod

  defp pair_to_mod(other) do
    raise ArgumentError, "unsupported memory element type: #{inspect(other)}"
  end
end