Skip to main content

lib/ex_burn/tensor.ex

defmodule ExBurn.Tensor do
  @moduledoc """
  Tensor conversion utilities between Nx and Burn formats.

  Handles marshaling of tensor data between Elixir's Nx tensor
  representation and the Rust/Burn tensor references used by the NIF.

  ## Type Mapping

  | Nx Type    | Burn Type |
  |------------|-----------|
  | `{:f, 32}` | `:f32`    |
  | `{:f, 64}` | `:f64`    |
  | `{:f, 16}` | `:f32`    |
  | `{:bf, 16}`| `:f32`    |
  | `{:s, 32}` | `:i32`    |
  | `{:s, 64}` | `:i64`    |
  | `{:s, 16}` | `:i32`    |
  | `{:s, 8}`  | `:i32`    |
  | `{:u, 8}`  | `:f32`    |
  | any other  | `:f32`    |

  When an Nx type is not the exact counterpart of its Burn type, `from_nx/1`
  converts the tensor with `Nx.as_type/2` before handing its data to the NIF,
  so the binary's element width always matches the Burn type tag. As a
  consequence, a round trip through `from_nx/1` and `to_nx/1` yields the
  Burn type's Nx counterpart (e.g. a `{:u, 8}` tensor comes back as `{:f, 32}`).
  """

  @typedoc "Burn element type tag"
  @type burn_type :: :f32 | :f64 | :i32 | :i64

  @typedoc "NIF tensor reference"
  @type t :: %__MODULE__{
          ref: reference(),
          shape: [non_neg_integer()],
          type: burn_type()
        }

  defstruct [:ref, :shape, :type]

  # ── Type Conversions ─────────────────────────────────────────────

  @doc """
  Converts an Nx type tuple to a Burn type tag.

  Types without an exact Burn counterpart are widened (see the module docs);
  any type not listed falls back to `:f32`.
  """
  @spec nx_type_to_burn(Nx.Type.t()) :: burn_type()
  def nx_type_to_burn(type)
  def nx_type_to_burn({:f, 32}), do: :f32
  def nx_type_to_burn({:f, 64}), do: :f64
  def nx_type_to_burn({:f, 16}), do: :f32
  def nx_type_to_burn({:bf, 16}), do: :f32
  def nx_type_to_burn({:s, 32}), do: :i32
  def nx_type_to_burn({:s, 64}), do: :i64
  def nx_type_to_burn({:s, 16}), do: :i32
  def nx_type_to_burn({:s, 8}), do: :i32
  def nx_type_to_burn({:u, 8}), do: :f32
  def nx_type_to_burn(_), do: :f32

  @doc """
  Converts a Burn type tag back to an Nx type tuple.

  Unknown tags fall back to `{:f, 32}`.
  """
  @spec burn_type_to_nx(burn_type()) :: Nx.Type.t()
  def burn_type_to_nx(type)
  def burn_type_to_nx(:f32), do: {:f, 32}
  def burn_type_to_nx(:f64), do: {:f, 64}
  def burn_type_to_nx(:i32), do: {:s, 32}
  def burn_type_to_nx(:i64), do: {:s, 64}
  def burn_type_to_nx(_), do: {:f, 32}

  # ── Nx ↔ Burn Conversion ─────────────────────────────────────────

  @doc """
  Converts an `Nx.Tensor.t()` into an `ExBurn.Tensor.t()`.

  The tensor is first converted to the Nx type that exactly matches its Burn
  type tag, then sent to the Rust NIF layer as a flat binary together with its
  shape. Returns `{:ok, tensor}` or `{:error, reason}`, where `reason` is the
  message of any exception raised by the NIF call.
  """
  @spec from_nx(Nx.Tensor.t()) :: {:ok, t()} | {:error, String.t()}
  def from_nx(%Nx.Tensor{} = tensor) do
    type = nx_type_to_burn(Nx.type(tensor))
    shape = tensor |> Nx.shape() |> Tuple.to_list()

    # The NIF only receives the bytes, the shape and the Burn type tag, so the
    # bytes must already be encoded in that type. Without this cast, e.g. a
    # {:u, 8} tensor would ship 1 byte per element while being tagged f32.
    # `Nx.as_type/2` returns the tensor unchanged when the types already match.
    data =
      tensor
      |> Nx.as_type(burn_type_to_nx(type))
      |> Nx.to_binary()

    from_binary(data, shape, type)
  end

  @doc """
  Converts an `ExBurn.Tensor.t()` back into an `Nx.Tensor.t()`.

  Reads the raw data from the Rust NIF layer, decodes it as the Nx counterpart
  of the tensor's Burn type and reshapes it to the stored shape.
  Returns `{:ok, tensor}` or `{:error, reason}`; exceptions from the NIF call or
  from decoding/reshaping (e.g. a size/shape mismatch) are returned as errors.
  """
  @spec to_nx(t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
  def to_nx(%__MODULE__{ref: ref, shape: shape, type: type}) do
    nx_type = burn_type_to_nx(type)

    try do
      tensor =
        ref
        |> ExBurn.Nif.tensor_to_binary()
        |> Nx.from_binary(nx_type)
        |> Nx.reshape(List.to_tuple(shape))

      {:ok, tensor}
    rescue
      e -> {:error, Exception.message(e)}
    end
  end

  @doc """
  Batch converts a list of Nx tensors to Burn tensors.

  Converts each tensor with `from_nx/1`, preserving order. Stops at the first
  failure and returns its `{:error, reason}`; tensors converted before the
  failure are not freed by this function.
  """
  @spec from_nx_batch([Nx.Tensor.t()]) :: {:ok, [t()]} | {:error, String.t()}
  def from_nx_batch(tensors), do: map_ok(tensors, &from_nx/1)

  @doc """
  Batch converts a list of Burn tensors to Nx tensors.

  Converts each tensor with `to_nx/1`, preserving order. Stops at the first
  failure and returns its `{:error, reason}`.
  """
  @spec to_nx_batch([t()]) :: {:ok, [Nx.Tensor.t()]} | {:error, String.t()}
  def to_nx_batch(tensors), do: map_ok(tensors, &to_nx/1)

  # ── Direct Construction ──────────────────────────────────────────

  @doc """
  Creates a Burn tensor from raw binary data, shape, and type.

  `data` is passed to the NIF unchanged, so it must already be encoded in the
  element type given by `type`. Returns `{:ok, tensor}` or `{:error, reason}`,
  where `reason` is the message of any exception raised by the NIF call.
  """
  @spec from_binary(binary(), [non_neg_integer()], burn_type()) ::
          {:ok, t()} | {:error, String.t()}
  def from_binary(data, shape, type) do
    try do
      ref = ExBurn.Nif.new_tensor(data, shape, Atom.to_string(type))
      {:ok, %__MODULE__{ref: ref, shape: shape, type: type}}
    rescue
      e -> {:error, Exception.message(e)}
    end
  end

  # ── Inspection ───────────────────────────────────────────────────

  @doc "Returns the shape of a Burn tensor."
  @spec shape(t()) :: [non_neg_integer()]
  def shape(%__MODULE__{shape: shape}), do: shape

  @doc "Returns the Burn element type of a tensor."
  @spec type(t()) :: burn_type()
  def type(%__MODULE__{type: type}), do: type

  @doc "Returns the NIF reference for a tensor."
  @spec ref(t()) :: reference()
  def ref(%__MODULE__{ref: ref}), do: ref

  @doc "Returns the total number of elements (a scalar, shape `[]`, has one)."
  @spec numel(t()) :: non_neg_integer()
  def numel(%__MODULE__{shape: shape}), do: Enum.product(shape)

  @doc "Returns the rank (number of dimensions)."
  @spec rank(t()) :: non_neg_integer()
  def rank(%__MODULE__{shape: shape}), do: length(shape)

  # ── Memory ───────────────────────────────────────────────────────

  @doc "Frees the underlying Rust tensor."
  @spec free(t()) :: :ok
  def free(%__MODULE__{ref: ref}), do: ExBurn.Nif.free_tensor(ref)

  # ── Helpers ──────────────────────────────────────────────────────

  # Maps `fun` (returning {:ok, v} | {:error, r}) over `items`, halting on the
  # first error; on success returns {:ok, values} in the original order.
  defp map_ok(items, fun) do
    items
    |> Enum.reduce_while({:ok, []}, fn item, {:ok, acc} ->
      case fun.(item) do
        {:ok, value} -> {:cont, {:ok, [value | acc]}}
        {:error, _reason} = error -> {:halt, error}
      end
    end)
    |> case do
      {:ok, values} -> {:ok, Enum.reverse(values)}
      error -> error
    end
  end
end