lib/image/classification.ex

if Image.bumblebee_configured?() do
  defmodule Image.Classification do
    @moduledoc """
    Implements image classification functions using [Axon](https://hex.pm/packages/axon)
    machine learning models managed by [Bumblebee](https://hex.pm/packages/bumblebee).

    Image classification refers to the task of extracting information from an image.
    In this module, the information extracted is one or more labels describing the
    image. Typically something like "sports car" or "Blenheim spaniel". The labels
    returned depend on the machine learning model used.

    ### Configuration

    The machine learning model to be used is configurable.
    The `:model` and `:featurizer` may be any model supported by Bumblebee. The
    `:name` is the name given to the classification service process.

    The default configuration is:

        # runtime.exs
        config :image, :classifier,
          model: {:hf, "microsoft/resnet-50"},
          featurizer: {:hf, "microsoft/resnet-50"},
          name: Image.Classification.Server,
          autostart: true

    ### Autostart

    If `autostart: true` is configured (the default) then a process
    is started under a supervisor to execute the classification
    requests. If running the process under an application
    supervision tree is desired, set `autostart: false`. In that
    case the function `classifier/1` can be used to return a
    `t:Supervisor.child_spec/0`.

    ### Adding a classification server to an application supervision tree

    To add image classification to an application supervision tree,
    use `Image.Classification.classifier/1` to return a child spec.
    For example:

        # application.ex
        def start(_type, _args) do
          children = [
            # default classifier configuration
            Image.Classification.classifier(),

            # custom classifier configuration. When more than one
            # classifier is started, each needs a distinct `:name`.
            Image.Classification.classifier(
              model: {:hf, "google/vit-base-patch16-224"},
              featurizer: {:hf, "google/vit-base-patch16-224"},
              name: MyApp.VitClassifier
            )
          ]

          Supervisor.start_link(
            children,
            strategy: :one_for_one
          )
        end

    ### Note

    This module is only available if the optional dependency
    [Bumblebee](https://hex.pm/packages/bumblebee) is configured in
    `mix.exs`.

    """

    alias Vix.Vips.Image, as: Vimage

    # Default minimum score a prediction must reach to be returned by `labels/2`.
    @min_score 0.5

    @default_classifier [
      model: {:hf, "microsoft/resnet-50"},
      featurizer: {:hf, "microsoft/resnet-50"},
      name: Image.Classification.Server,
      autostart: true
    ]

    @default_classifier_name @default_classifier[:name]

    @doc """
    Returns a child spec suitable for starting an image classification
    process as part of a supervision tree.

    ### Arguments

    * `configuration` is a keyword list. The default is
      `Application.get_env(:image, :classifier, [])`. It is merged
      over the default configuration, so any key that is omitted
      takes its default value.

    ### Configuration keys

    * `:model` is any supported machine learning model for image
      classification supported by Bumblebee.

    * `:featurizer` is any supported machine learning model for image
      featurization supported by Bumblebee.

    * `:name` is the name given to the classification process when
      it is started.

    ### Default configuration

    The default configuration is:
    ```elixir
    [
      model: {:hf, "microsoft/resnet-50"},
      featurizer: {:hf, "microsoft/resnet-50"},
      name: Image.Classification.Server,
      autostart: true
    ]
    ```

    ### Returns

    * `{Nx.Serving, options}`, a child spec for an `Nx.Serving`
      process registered under the configured `:name`, or

    * `{:error, reason}` if the model or the featurizer could not
      be loaded.

    """
    @spec classifier(configuration :: Keyword.t()) ::
            {Nx.Serving, Keyword.t()} | {:error, term()}
    def classifier(classifier \\ Application.get_env(:image, :classifier, [])) do
      # The serving built by `serving/2` is compiled with EXLA, so make
      # sure the EXLA application is started before building it.
      Application.ensure_all_started(:exla)
      classifier = Keyword.merge(@default_classifier, classifier)

      case serving(classifier[:model], classifier[:featurizer]) do
        {:error, _reason} = error ->
          error

        serving ->
          {Nx.Serving, serving: serving, name: classifier[:name], batch_timeout: 100}
      end
    end

    # Builds the Bumblebee image classification serving from a model and a
    # featurizer specification. Returns `{:error, reason}` when either the
    # model or the featurizer fails to load. Both clauses use `<-` so that a
    # load failure is returned to the caller rather than raising a MatchError.
    @doc false
    def serving(model, featurizer) do
      with {:ok, model_info} <- Bumblebee.load_model(model),
           {:ok, featurizer} <- Bumblebee.load_featurizer(featurizer) do
        Bumblebee.Vision.image_classification(model_info, featurizer,
          compile: [batch_size: 10],
          defn_options: [compiler: EXLA]
        )
      end
    end

    @doc """
    Classify an image using a machine learning
    model.

    ### Arguments

    * `image` is any `t:Vix.Vips.Image.t/0`.

    * `options` is a keyword list of options.

    ### Options

    * `:backend` is any valid `Nx` backend. The default is
      `Nx.default_backend/0`.

    * `:server` is the name of the process performing the
      classification service. The default is `#{inspect(@default_classifier_name)}`.

    ### Returns

    * A map containing the predictions of the image
      classification, or

    * `{:error, reason}` if the image could not be flattened,
      converted to sRGB or converted to a tensor.

    ### Example

        iex> puppy = Image.open!("./test/support/images/puppy.webp")
        iex> %{predictions: [%{label: "Blenheim spaniel", score: _} | _rest]} =
        ...>   Image.Classification.classify(puppy)

    """

    @dialyzer {:nowarn_function, {:classify, 1}}
    @dialyzer {:nowarn_function, {:classify, 2}}

    @doc since: "0.18.0"

    @spec classify(image :: Vimage.t(), Keyword.t()) ::
            %{predictions: [%{label: String.t(), score: float()}]} | {:error, Image.error_message()}

    def classify(%Vimage{} = image, options \\ []) do
      backend = Keyword.get(options, :backend, Nx.default_backend())
      server = Keyword.get(options, :server, @default_classifier_name)

      # The model expects a 3-band sRGB image without alpha, laid out
      # as height x width x channels.
      with {:ok, flattened} <- Image.flatten(image),
           {:ok, srgb} <- Image.to_colorspace(flattened, :srgb),
           {:ok, tensor} <- Image.to_nx(srgb, shape: :hwc, backend: backend) do
        Nx.Serving.batched_run(server, tensor)
      end
    end

    @doc """
    Classify an image using a machine learning
    model and return the labels that meet a minimum
    score.

    ### Arguments

    * `image` is any `t:Vix.Vips.Image.t/0`.

    * `options` is a keyword list of options.

    ### Options

    * `:backend` is any valid `Nx` backend. The default is
      `Nx.default_backend/0`.

    * `:server` is the name of the process performing the
      classification service. The default is `#{inspect(@default_classifier_name)}`.

    * `:min_score` is the minimum score, a float between `0`
      and `1`, which a label must match in order to be
      returned. The default is `#{@min_score}`.

    ### Returns

    * A list of labels. The list may be empty if there
      are no predictions that meet the `:min_score`. A
      prediction whose label contains several comma-separated
      names contributes each name as a separate label.

    * `{:error, reason}`

    ### Example

        iex> {:ok, image} = Image.open("./test/support/images/lamborghini-forsennato-concept.jpg")
        iex> Image.Classification.labels(image)
        ["sports car", "sport car"]

    """
    @dialyzer {:nowarn_function, {:labels, 1}}
    @dialyzer {:nowarn_function, {:labels, 2}}

    @doc since: "0.18.0"

    @spec labels(image :: Vimage.t(), options :: Keyword.t()) ::
            [String.t()] | {:error, Image.error_message()}

    def labels(%Vimage{} = image, options \\ []) do
      # `:min_score` is consumed here; remaining options go to `classify/2`.
      {min_score, options} = Keyword.pop(options, :min_score, @min_score)

      with %{predictions: predictions} <- classify(image, options) do
        predictions
        |> Enum.filter(fn %{score: score} -> score >= min_score end)
        |> Enum.flat_map(fn %{label: label} -> String.split(label, ", ") end)
      end
    end
  end
end