lib/axon_interp.ex

defmodule AxonInterp do
  @moduledoc """
  A thin wrapper for Axon inference.
  It lets you run Deep Learning inference with Axon through the same interface as my other *Interp modules.
  """

  @basic_usage """
  ## Basic Usage
  Obtain a trained Axon model or an ONNX model and save it in a directory that your application can read.
  "your-app/priv" may be a good choice.

  ```
  $ cp your-trained-model.axon ./priv
  ```

  Next, create a module that interfaces with the deep learning model.
  In addition to the inference itself, the module needs pre-processing and
  post-processing, as in the example below. AxonInterp provides only the
  inference processing.

  Put `use AxonInterp` at the beginning of your module, specify the model path
  as an option, and describe the model's input/output specifications. For
  inference, you chain data input to the model (`AxonInterp.set_input_tensor/3`),
  inference execution (`AxonInterp.invoke/1`), and retrieval of the inference
  result (`AxonInterp.get_output_tensor/2`).

  ```elixir:your_model.ex
  defmodule YourApp.YourModel do
    use AxonInterp,
      model: "priv/your-trained-model.onnx",
      inputs: [f32: {1, 3, 224, 224}],
      outputs: [f32: {1, 1000}]

    def predict(data) do
      # preprocess
      #  to convert the data to be inferred to the input format of the model.
      input_bin = convert_to_float32_binary(data)

      # inference
      #  typical I/O data for ONNX models is a serialized 32-bit float tensor.
      output_bin = session()
        |> AxonInterp.set_input_tensor(0, input_bin, [:binary])
        |> AxonInterp.invoke()
        |> AxonInterp.get_output_tensor(0, [:binary])

      # postprocess
      #  add your post-processing here.
      #  you may need to reshape output_bin into a tensor first.
      tensor = output_bin
        |> Nx.from_binary({:f, 32})
        |> Nx.reshape({size_x, size_y, :auto})

      # ... your post-processing on `tensor` goes here ...
    end
  end
  ```
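
  The module defined this way runs as a GenServer, so it must be started before
  inference. A minimal sketch, assuming the hypothetical `YourApp.YourModel`
  module above and default options:

  ```elixir
  # start the interpreter process (or add YourApp.YourModel to your supervision tree)
  {:ok, _pid} = YourApp.YourModel.start_link([])

  # run a prediction
  result = YourApp.YourModel.predict(data)

  # inspect the model meta information
  {:ok, info} = AxonInterp.info(YourApp.YourModel)

  # stop the interpreter when it is no longer needed
  AxonInterp.stop(YourApp.YourModel)
  ```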
  """

  @timeout 300_000

  @framework "axon"

  # the suffix expected for the model
  suffix = %{
    "axon" => [".axon", ".onnx"]
  }
  @model_suffix suffix[String.downcase(@framework)]
  
  # session record
  defstruct module: nil, inputs: %{}, outputs: %{}

  defmacro __using__(opts) do
    quote generated: true, location: :keep do
      use GenServer

      def start_link(opts) do
        GenServer.start_link(__MODULE__, opts, name: __MODULE__)
      end

      def init(opts) do
        opts = Keyword.merge(unquote(opts), opts)
        nn_model   = AxonInterp.validate_model(Keyword.get(opts, :model), Keyword.get(opts, :url))
        nn_inputs  = Keyword.get(opts, :inputs, [])
        nn_outputs = Keyword.get(opts, :outputs, [])

        # load the model: deserialize a saved Axon model or import an ONNX graph
        {model, params} = case Path.extname(nn_model) do
          ".axon" -> File.read!(nn_model) |> Axon.deserialize()
          ".onnx" -> AxonOnnx.import(nn_model)
        end
        # build the inference function once at startup
        {_, predict_fn} = Axon.build(model, mode: :inference)

        {:ok, %{model: predict_fn, params: params, path: nn_model, itempl: nn_inputs, otempl: nn_outputs}}
      end

      def session() do
        %AxonInterp{module: __MODULE__}
      end

      def handle_call({:info}, _from, state) do
        info = %{
          "model"   => state.path,
          "inputs"  => state.itempl,
          "outputs" => state.otempl,
        }
        {:reply, {:ok, info}, state}
      end

      require Nx

      def handle_call({:invoke, inputs}, _from, %{model: model, params: params, itempl: template}=state) do
        n = Enum.count(template)

        # a single-input model takes its tensor directly,
        # a multi-input model takes a tuple of tensors
        inputs =
          if n == 1 do
            inputs[0]
          else
            Enum.map(0..n-1, &inputs[&1]) |> List.to_tuple()
          end

        # normalize the result(s) into a map of index => output tensor
        results = case model.(params, inputs) do
          single when Nx.is_tensor(single) ->
            %{0 => single}
          multiple when is_tuple(multiple) ->
            Tuple.to_list(multiple)
            |> Enum.with_index()
            |> Enum.into(%{}, fn {result, index} -> {index, result} end)
        end
        
        {:reply, {:ok, results}, state}
      end

      def handle_call({:itempl, index}, _from, %{itempl: template}=state) do
        {:reply, {:ok, Enum.at(template, index)}, state}
      end

      def handle_call({:otempl, index}, _from, %{otempl: template}=state) do
        {:reply, {:ok, Enum.at(template, index)}, state}
      end

      def terminate(_reason, _state) do
        :ok
      end
    end
  end


  @doc """
  Get name of backend NN framework.
  """
  def framework() do
    @framework
  end

  @doc """
  Ensure that the back-end framework is as expected.
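
  Raises if the given name does not match. For example:

    ```elixir
      AxonInterp.framework?("axon")  # passes silently
      AxonInterp.framework?("onnx")  # raises RuntimeError
    ```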
  """
  def framework?(name) do
    unless String.downcase(name) == String.downcase(@framework),
      do: raise "Error: backend NN framework is \"#{@framework}\", not \"#{name}\"."
  end

  @doc """
  Ensure that the model matches the back-end framework.
  
  ## Parameters
    * model - path of the model file
    * url - URL to download the model from when it is not found locally
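
  ## Examples

    A sketch with a hypothetical download URL; when the model file does not exist
    locally, it is downloaded from `url` into the model's directory:

    ```elixir
      use AxonInterp,
        model: "priv/your-trained-model.onnx",
        url: "https://example.com/models/your-trained-model.onnx",
        inputs: [f32: {1, 3, 224, 224}],
        outputs: [f32: {1, 1000}]
    ```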
  """
  def validate_model(nil, _), do: raise ArgumentError, "need a model file #{inspect(@model_suffix)}."
  def validate_model(model, url) do
    validate_extname!(model)

    abs_path = Path.expand(model)
    unless File.exists?(abs_path) do
      IO.puts("#{model}:")
      {:ok, _} = AxonInterp.URL.download(url, Path.dirname(abs_path), Path.basename(abs_path))
    end
    model
  end

  defp validate_extname!(model) do
    actual_ext = Path.extname(model)
    unless actual_ext in @model_suffix,
      do: raise ArgumentError, "expected a model file with extension #{inspect(@model_suffix)}, not \"#{actual_ext}\"."

    actual_ext
  end

  @doc """
  Get the properties (meta information) of the model.

  ## Parameters

    * mod - name of the module that wraps the model
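
  ## Examples

    A sketch, assuming a model module `YourApp.YourModel` defined with `use AxonInterp`:

    ```elixir
      {:ok, %{"framework" => "axon", "model" => path, "inputs" => inputs, "outputs" => outputs}} =
        AxonInterp.info(YourApp.YourModel)
    ```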
  """
  def info(mod) do
    case GenServer.call(mod, {:info}, @timeout) do
      {:ok, result} ->  {:ok, Map.put(result, "framework", @framework)}
      any -> any
    end
  end

  @doc """
  Stop the Axon interpreter.

  ## Parameters

    * mod - name of the module that wraps the model
  """
  def stop(mod) do
    GenServer.stop(mod)
  end

  @doc """
  Put input data into the input tensor of the session.

  ## Parameters

    * session - session record
    * index - index of the input tensor in the model
    * data  - input data; an Nx tensor, or a flat binary (a serialized tensor) when the `:binary` option is given
    * opts
      * [:binary] - `data` is a flat binary and is reshaped according to the model's input template
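
  ## Examples

    A sketch, assuming a model module `YourApp.YourModel` defined with `use AxonInterp`
    and a float32 binary `input_bin` matching the model's first input:

    ```elixir
      session = YourApp.YourModel.session()
        |> AxonInterp.set_input_tensor(0, input_bin, [:binary])
    ```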
  """
  def set_input_tensor(%AxonInterp{module: mod, inputs: inputs}=session, index, data, opts \\ []) do
    data = cond do
      :binary in opts ->
        with {:ok, {dtype, shape}} <- GenServer.call(mod, {:itempl, index}, @timeout) do
          Nx.from_binary(data, dtype) |> Nx.reshape(shape)
        end

      true -> data
    end

    %AxonInterp{session | inputs: Map.put(inputs, index, data)}
  end

  @doc """
  Get the output tensor of the session, optionally as a flat binary.

  ## Parameters

    * session - session record
    * index - index of the output tensor in the model
    * opts
      * [:binary] - convert the output tensor to a flat binary
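
  ## Examples

    A sketch continuing from an invoked session (see `invoke/1`):

    ```elixir
      output_bin = AxonInterp.get_output_tensor(session, 0, [:binary])
    ```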
  """
  def get_output_tensor(%AxonInterp{outputs: outputs}, index, opts \\ []) do
    output = outputs[index]

    cond do
      :binary in opts -> Nx.to_binary(output)
      true -> output
    end
  end

  @doc """
  Execute the inference session. With the session record, data input, inference
  execution, and retrieval of the results from the DL model can be chained in a
  single pipeline.

  ## Parameters

    * session - session

  ## Examples

    ```elixir
      output_bin0 = session()
        |> AxonInterp.set_input_tensor(0, input_bin0, [:binary])
        |> AxonInterp.invoke()
        |> AxonInterp.get_output_tensor(0, [:binary])
    ```
  """
  def invoke(%AxonInterp{module: mod, inputs: inputs}=session) do
    case GenServer.call(mod, {:invoke, inputs}, @timeout) do
      {:ok, results} -> %AxonInterp{session | outputs: results}
      any -> any
    end
  end
end