lib/axon_onnx.ex

defmodule AxonOnnx do
  @moduledoc """
  Library for converting models between Axon and ONNX.

  [ONNX](https://github.com/onnx/onnx) (Open Neural Network Exchange) is an
  open specification for representing neural network models, supported by
  most popular deep learning frameworks such as PyTorch and TensorFlow.
  AxonOnnx allows you to convert to and from ONNX models via a simple
  import/export API.

  You can import supported ONNX models using `AxonOnnx.import/2`:

      {model, params} = AxonOnnx.import("model.onnx")

  `model` will be an Axon struct and `params` will be a compatible
  model state.

  You can export supported models using `AxonOnnx.export/4`:

      AxonOnnx.export(model, templates, params)
  """

  @doc """
  Imports an ONNX model from the given path.

  Some models declare symbolic dimensions via ONNX `dim_params`. You may
  pin these to concrete sizes by providing dimension names and values as
  a keyword list:

      AxonOnnx.import("model.onnx", batch: 1)

  The imported model will be in the form:

      {model, params} = AxonOnnx.import("model.onnx")
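
  The returned model behaves like any other Axon model. For example,
  assuming the imported graph has a single input named "input" that
  expects `{1, 784}` `:f32` tensors (names and shapes will vary by model):

      {model, params} = AxonOnnx.import("model.onnx", batch: 1)
      input = Nx.broadcast(0.0, {1, 784})
      Axon.predict(model, params, %{"input" => input})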
  """
  def import(path, dimensions \\ []) do
    path
    |> File.read!()
    |> AxonOnnx.Deserialize.__load__(dimensions)
  end

  @doc """
  Loads an Axon model from the given ONNX binary.

  Some models declare symbolic dimensions via ONNX `dim_params`. You may
  pin these to concrete sizes by providing dimension names and values as
  a keyword list:

      onnx = File.read!("model.onnx")
      AxonOnnx.load(onnx, batch: 1)

  The loaded model will be in the form:

      {model, params} = AxonOnnx.load(onnx)
  """
  def load(onnx, dimensions \\ []), do: AxonOnnx.Deserialize.__load__(onnx, dimensions)

  @doc """
  Exports an Axon model and parameters to an ONNX model
  with the given input templates.

  You may optionally specify `:path` to write the exported model to a
  specific file. When `:path` is not given, the file is named after the
  model's output with a `.onnx` extension:

      AxonOnnx.export(model, templates, params, path: "resnet.onnx")
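
  `templates` describes the shapes and types of the model's inputs and is
  typically built with `Nx.template/2`. For example, assuming the
  serializer accepts a map of input names to templates and the model has
  a single `{1, 784}` `:f32` input named "input" (names and shapes will
  vary by model):

      templates = %{"input" => Nx.template({1, 784}, :f32)}
      AxonOnnx.export(model, templates, params)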
  """
  def export(%Axon{} = model, templates, params, opts \\ []) do
    {encoded, output_name} = AxonOnnx.Serialize.__dump__(model, templates, params, opts)

    fname = opts[:path] || output_name <> ".onnx"

    File.write!(fname, encoded)
  end

  @doc """
  Dumps an Axon model and parameters into a binary representation of
  an ONNX model.
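
  This is useful when you want to handle the encoded model yourself
  rather than write it to disk. For example, assuming `templates` built
  as in `export/4`:

      onnx = AxonOnnx.dump(model, templates, params)
      File.write!("model.onnx", onnx)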
  """
  def dump(%Axon{} = model, templates, params, opts \\ []) do
    {encoded, _} = AxonOnnx.Serialize.__dump__(model, templates, params, opts)
    encoded
  end
end