# lib/onnx_runtime/model.ex

defmodule OnnxRuntime.Model do
  @moduledoc """
  A loaded ONNX Runtime inference session.

  The struct wraps whatever handle `OnnxRuntime.Native.init/3` returns; that
  handle is passed back to `OnnxRuntime.Native.run/2` on every inference call.
  """

  @enforce_keys [:reference]
  defstruct [:reference]

  # Builds a session from the model at `path`. The execution providers
  # (default `[:cpu]`) and the optimization level (default `3`) are forwarded
  # to the native layer unchanged.
  @doc false
  def load(path, execution_providers \\ [:cpu], optimization_level \\ 3) do
    handle = OnnxRuntime.Native.init(path, execution_providers, optimization_level)
    %__MODULE__{reference: handle}
  end

  # Runs the session on a tuple of inputs. All inputs are first moved onto
  # `OnnxRuntime.Backend`, then their native refs are extracted and handed to
  # the native layer. Each native output is wrapped back into an
  # `Nx.Tensor` with no axis names, and the outputs are returned as a tuple.
  #
  # NOTE(review): per the Nx docs, `Nx.backend_transfer/2` behaves like a copy
  # followed by a deallocation of the source tensor, so callers presumably
  # should not reuse inputs that lived on another backend. Confirm this is intended.
  @doc false
  def run(%__MODULE__{reference: handle}, tensors) when is_tuple(tensors) do
    transferred =
      for tensor <- Tuple.to_list(tensors) do
        Nx.backend_transfer(tensor, OnnxRuntime.Backend)
      end

    input_refs = Enum.map(transferred, &native_ref/1)

    handle
    |> OnnxRuntime.Native.run(input_refs)
    |> Enum.map(&to_output_tensor/1)
    |> List.to_tuple()
  end

  # A single (non-tuple) input is treated as a one-element input tuple.
  def run(%__MODULE__{} = model, tensor) when not is_tuple(tensor) do
    run(model, {tensor})
  end

  # Pulls the native ref out of a tensor whose data lives on `OnnxRuntime.Backend`.
  defp native_ref(%Nx.Tensor{data: %OnnxRuntime.Backend{ref: ref}}), do: ref

  # Wraps one `{ref, shape, kind, bits}` output from the native layer in an
  # `Nx.Tensor`. The shape arrives as a list and becomes a tuple, and
  # `{kind, bits}` becomes the tensor type. Every axis is left unnamed.
  defp to_output_tensor({ref, dims, kind, bits}) do
    shape = List.to_tuple(dims)

    %Nx.Tensor{
      data: %OnnxRuntime.Backend{ref: ref},
      shape: shape,
      type: {kind, bits},
      names: List.duplicate(nil, tuple_size(shape))
    }
  end
end

defimpl Inspect, for: OnnxRuntime.Model do
  import Inspect.Algebra

  # Renders the model as `#OnnxRuntime.Model<`, then one line describing the
  # session inputs and one line describing its outputs, with the closing `>`
  # directly after the outputs. Both descriptions come from
  # `OnnxRuntime.Native.show_session/1`. `force_unfit/1` keeps the hard
  # line breaks from being collapsed.
  def inspect(%OnnxRuntime.Model{reference: session}, opts) do
    {input_info, output_info} = OnnxRuntime.Native.show_session(session)

    parts = [
      color("#OnnxRuntime.Model<", :map, opts),
      line(),
      labelled("  inputs: ", input_info, opts),
      line(),
      labelled("  outputs: ", output_info, opts),
      color(">", :map, opts)
    ]

    parts
    |> concat()
    |> force_unfit()
  end

  # Prefixes the inspected `value` with `label`. Any line breaks inside the
  # value's document are nested by two columns.
  defp labelled(label, value, opts) do
    [label, to_doc(value, opts)]
    |> concat()
    |> nest(2)
  end
end