# lib/tasks/classify_image.ex

defmodule Mix.Tasks.ClassifyImage do
  @moduledoc """
  Image classification Mix task; run `mix help classify_image` for this
  documentation. An example invocation follows the option list below.

  Command line arguments:

  - `-m`, `--model`: *Required*. File path of the `.tflite` model file.
  - `-i`, `--input`: *Required*. File path of the image to classify.
  - `-l`, `--labels`: File path of the labels file.
  - `-k`, `--top`: Defaults to `1`. Max number of classification results.
  - `-t`, `--threshold`: Defaults to `0.0`. Classification score threshold;
    results scoring below it are not reported.
  - `-c`, `--count`: Defaults to `1`. Number of times to run inference.
  - `-a`, `--mean`: Defaults to `128.0`. Mean value for input normalization.
  - `-s`, `--std`: Defaults to `128.0`. STD value for input normalization.
  - `-j`, `--jobs`: Number of threads for the interpreter (CPU only).
    Defaults to the number of online schedulers.
  - `--use-tpu`: Defaults to `false`. Pass this flag to run inference on a
    Coral Edge TPU device.
  - `--tpu`: Defaults to `""`. Coral device name.
    - `""`      -- any TPU device
    - `"usb"`   -- any TPU device on USB bus
    - `"pci"`   -- any TPU device on PCIe bus
    - `":N"`    -- N-th TPU device, e.g. `":0"`
    - `"usb:N"` -- N-th TPU device on USB bus, e.g. `"usb:0"`
    - `"pci:N"` -- N-th TPU device on PCIe bus, e.g. `"pci:0"`

  Code based on [classify_image.py](https://github.com/google-coral/pycoral/blob/master/examples/classify_image.py)
  """

  use Mix.Task

  alias TFLiteBEAM.Interpreter
  alias TFLiteBEAM.InterpreterBuilder
  alias TFLiteBEAM.TFLiteTensor
  alias TFLiteBEAM.FlatBufferModel

  @shortdoc "Image Classification"
  def run(argv) do
    {args, _, _} =
      OptionParser.parse(argv,
        strict: [
          model: :string,
          input: :string,
          labels: :string,
          top: :integer,
          threshold: :float,
          count: :integer,
          mean: :float,
          std: :float,
          use_tpu: :boolean,
          tpu: :string,
          jobs: :integer
        ],
        aliases: [
          m: :model,
          i: :input,
          l: :labels,
          k: :top,
          t: :threshold,
          c: :count,
          a: :mean,
          s: :std,
          j: :jobs
        ]
      )

    default_values = [
      top: 1,
      threshold: 0.0,
      count: 1,
      mean: 128.0,
      std: 128.0,
      jobs: System.schedulers_online(),
      use_tpu: false,
      tpu: ""
    ]

    # Fill in defaults for any option the user did not supply;
    # user-supplied values take precedence.
    args = Keyword.merge(default_values, args)

    model = load_model(args[:model])
    input_image = load_input(args[:input])
    labels = load_labels(args[:labels])

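    # Per the --tpu option docs above, an empty device string matches any
    # available Edge TPU.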
    tpu_context =
      if args[:use_tpu] do
        TFLiteBEAM.Coral.get_edge_tpu_context!(device: args[:tpu])
      else
        nil
      end

    interpreter = make_interpreter(model, args[:jobs], args[:use_tpu], tpu_context)
    :ok = Interpreter.allocate_tensors(interpreter)

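    # Assumes a single-input, single-output classification model; only the
    # first input and output tensor indices are used.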
    [input_tensor_number | _] = Interpreter.inputs!(interpreter)
    [output_tensor_number | _] = Interpreter.outputs!(interpreter)
    %TFLiteTensor{} = input_tensor = Interpreter.tensor(interpreter, input_tensor_number)

    if input_tensor.type != {:u, 8} do
      raise ArgumentError, "only uint8 input type is supported, got: #{inspect(input_tensor.type)}"
    end

    {h, w} =
      case input_tensor.shape do
        {_n, h, w, _c} ->
          {h, w}

        {_n, h, w} ->
          {h, w}

        shape ->
          raise RuntimeError, "unexpected input shape: #{inspect(shape)}"
      end

    input_image = StbImage.resize(input_image, h, w)

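    # TFLite quantization maps raw bytes to real values:
    # real_value = scale * (quantized_value - zero_point).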
    [scale] = input_tensor.quantization_params.scale
    [zero_point] = input_tensor.quantization_params.zero_point
    mean = args[:mean]
    std = args[:std]

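    # If scale * std ≈ 1 and mean ≈ zero_point, quantizing the normalized
    # input maps every pixel back to itself, so preprocessing can be skipped.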
    if abs(scale * std - 1) < 0.00001 and abs(mean - zero_point) < 0.00001 do
      # Input data does not require preprocessing.
      %StbImage{data: input_data} = input_image
      input_data
    else
      # Input data requires preprocessing
      StbImage.to_nx(input_image)
      |> Nx.subtract(mean)
      |> Nx.divide(std * scale)
      |> Nx.add(zero_point)
      |> Nx.clip(0, 255)
      |> Nx.as_type(:u8)
      |> Nx.to_binary()
    end
    |> then(&TFLiteTensor.set_data(input_tensor, &1))

    IO.puts("----INFERENCE TIME----")

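    # Note: the first inference can be noticeably slower; on the Edge TPU it
    # includes loading the model into accelerator memory.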
    for _ <- 1..args[:count] do
      start_time = :os.system_time(:microsecond)
      Interpreter.invoke!(interpreter)
      end_time = :os.system_time(:microsecond)
      inference_time = (end_time - start_time) / 1000.0
      IO.puts("#{Float.round(inference_time, 1)}ms")
    end

    output_data = Interpreter.output_tensor!(interpreter, 0)
    %TFLiteTensor{} = output_tensor = Interpreter.tensor(interpreter, output_tensor_number)
    scores = get_scores(output_data, output_tensor)
    sorted_indices = Nx.argsort(scores, direction: :desc)
    top_k = Nx.take(sorted_indices, Nx.iota({args[:top]}))
    scores = Nx.to_flat_list(Nx.take(scores, top_k))
    top_k = Nx.to_flat_list(top_k)

    IO.puts("-------RESULTS--------")

    # Report only results at or above the score threshold.
    Enum.zip(top_k, scores)
    |> Enum.filter(fn {_class_id, score} -> score >= args[:threshold] end)
    |> Enum.each(fn {class_id, score} ->
      label = if labels, do: Enum.at(labels, class_id), else: class_id
      IO.puts("#{label}: #{Float.round(score, 5)}")
    end)
  end

  defp load_model(nil) do
    raise ArgumentError, "empty value for argument '--model'"
  end

  defp load_model(model_path) do
    FlatBufferModel.build_from_buffer(File.read!(model_path))
  end

  defp load_input(nil) do
    raise ArgumentError, "empty value for argument '--input'"
  end

  defp load_input(input_path) do
    with {:ok, input_image} <- StbImage.read_file(input_path) do
      input_image
    else
      {:error, error} ->
        raise RuntimeError, error
    end
  end

  defp load_labels(nil), do: nil

  defp load_labels(label_file_path) do
    # Trim the trailing newline so the final entry is not an empty label.
    label_file_path
    |> File.read!()
    |> String.trim_trailing()
    |> String.split("\n")
  end

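  # CPU path: resolve builtin ops, then apply the requested thread count to
  # both the builder and the built interpreter.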
  defp make_interpreter(model, num_jobs, false, _tpu_context) do
    resolver = TFLiteBEAM.Ops.Builtin.BuiltinResolver.new!()
    builder = InterpreterBuilder.new!(model, resolver)
    interpreter = Interpreter.new!()
    InterpreterBuilder.set_num_threads!(builder, num_jobs)
    :ok = InterpreterBuilder.build!(builder, interpreter)
    Interpreter.set_num_threads!(interpreter, num_jobs)
    interpreter
  end

  defp make_interpreter(model, _num_jobs, true, tpu_context) do
    TFLiteBEAM.Coral.make_edge_tpu_interpreter!(model, tpu_context)
  end

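  # Quantized (unsigned/signed integer) outputs are dequantized back into
  # real-valued scores using the tensor's quantization parameters.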
  defp get_scores(output_data, %TFLiteTensor{type: {t, _} = dtype} = output_tensor)
       when t in [:u, :s] do
    [scale] = output_tensor.quantization_params.scale
    [zero_point] = output_tensor.quantization_params.zero_point

    output_data
    |> Nx.from_binary(dtype)
    |> Nx.as_type({:s, 64})
    |> Nx.subtract(zero_point)
    |> Nx.multiply(scale)
  end

  defp get_scores(output_data, %TFLiteTensor{type: dtype}) do
    Nx.from_binary(output_data, dtype)
  end
end