lib/bumblebee/vision/dino_v2.ex

defmodule Bumblebee.Vision.DinoV2 do
  alias Bumblebee.Shared

  options =
    [
      image_size: [
        default: 518,
        doc: """
        the size of the input spatial dimensions. The model is trained for this size; however,
        it supports other input sizes by interpolating the position embeddings
        """
      ],
      num_channels: [
        default: 3,
        doc: "the number of channels in the input"
      ],
      patch_size: [
        default: 14,
        doc: "the size of the patch spatial dimensions"
      ],
      hidden_size: [
        default: 384,
        doc: "the dimensionality of hidden layers"
      ],
      num_blocks: [
        default: 12,
        doc: "the number of Transformer blocks in the encoder"
      ],
      num_attention_heads: [
        default: 12,
        doc: "the number of attention heads for each attention layer in the encoder"
      ],
      intermediate_size_ratio: [
        default: 4,
        doc: """
        the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder,
        expressed as a multiplier of `:hidden_size`
        """
      ],
      use_qkv_bias: [
        default: true,
        doc: "whether to use bias in query, key, and value projections"
      ],
      activation: [
        default: :gelu,
        doc: "the activation function"
      ],
      ffn_swiglu_activation: [
        default: false,
        doc:
          "whether to use the gated SwiGLU activation function in the feed-forward network (FFN)"
      ],
      scale_initial_value: [
        default: 1.0,
        doc: "the initial value for scaling layers"
      ],
      dropout_rate: [
        default: 0.0,
        doc: "the dropout rate for encoder and decoder"
      ],
      attention_dropout_rate: [
        default: 0.0,
        doc: "the dropout rate for attention weights"
      ],
      layer_norm_epsilon: [
        default: 1.0e-6,
        doc: "the epsilon used by the layer normalization layers"
      ],
      initializer_scale: [
        default: 0.02,
        doc:
          "the standard deviation of the normal initializer used for initializing kernel parameters"
      ],
      backbone_output_indices: [
        default: nil,
        doc: """
        list of indices indicating which feature maps to include in the output. If not specified, only
        the last feature map is included
        """
      ],
      backbone_use_norm: [
        default: true,
        doc:
          "whether to add layer normalization layer to each of the feature maps returned by the backbone"
      ]
    ] ++
      Shared.common_options([
        :output_hidden_states,
        :output_attentions,
        :num_labels,
        :id_to_label
      ])

  @moduledoc """
  DINOv2 model family.

  ## Architectures

    * `:base` - plain DINOv2 without any head on top

    * `:for_image_classification` - DINOv2 with a head for image classification

    * `:backbone` - DINOv2 outputting spatial feature maps for use as a backbone in downstream vision tasks

  ## Inputs

    * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`

      Featurized image pixel values.

    * `"patch_mask"` - `{batch_size, num_patches}`

      Mask to nullify selected embedded patches.

  ## Configuration

  #{Shared.options_doc(options)}
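
  ## Usage

  A minimal sketch of running the `:base` architecture, assuming the
  `facebook/dinov2-base` checkpoint from the Hugging Face Hub:

      {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/dinov2-base"})
      {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "facebook/dinov2-base"})

      # any {height, width, channels} image tensor can be featurized
      image = Nx.iota({518, 518, 3}, type: :u8)

      inputs = Bumblebee.apply_featurizer(featurizer, image)
      outputs = Axon.predict(model_info.model, model_info.params, inputs)

      # outputs.hidden_state - {batch_size, sequence_length, hidden_size}
      # outputs.pooled_state - {batch_size, hidden_size} (the final class token)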

  ## References

    * [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)

  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.Layers

  @impl true
  def architectures(), do: [:base, :backbone, :for_image_classification]

  @impl true
  def config(spec, opts) do
    spec
    |> Shared.put_config_attrs(opts)
    |> Shared.validate_label_options()
  end

  @impl true
  def input_template(spec) do
    %{
      "pixel_values" =>
        Nx.template({1, spec.image_size, spec.image_size, spec.num_channels}, :f32)
    }
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    inputs = inputs(spec)
    outputs = core(inputs, spec)

    hidden_state =
      Axon.layer_norm(outputs.hidden_state, epsilon: spec.layer_norm_epsilon, name: "norm")

    pooled_state = Layers.take_token(hidden_state, index: 0, axis: 1)

    Layers.output(%{
      hidden_state: hidden_state,
      pooled_state: pooled_state,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    })
  end

  def model(%__MODULE__{architecture: :for_image_classification} = spec) do
    inputs = inputs(spec)
    outputs = core(inputs, spec)

    hidden_state =
      Axon.layer_norm(outputs.hidden_state, epsilon: spec.layer_norm_epsilon, name: "norm")

    class_token = Layers.take_token(hidden_state, index: 0, axis: 1)

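    # The classifier operates on the class token concatenated with the mean of
    # the patch tokens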
    patch_embeddings_mean =
      Axon.nx(hidden_state, fn hidden_state ->
        patch_embeddings = hidden_state[[.., 1..-1//1, ..]]
        Nx.mean(patch_embeddings, axes: [1])
      end)

    logits =
      Axon.concatenate(class_token, patch_embeddings_mean)
      |> Axon.dense(spec.num_labels,
        kernel_initializer: kernel_initializer(spec),
        name: "image_classification_head.output"
      )

    Layers.output(%{
      logits: logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    })
  end

  def model(%__MODULE__{architecture: :backbone} = spec) do
    inputs = inputs(spec)
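    # The feature maps are derived from intermediate hidden states, so we force
    # the core to return all of them regardless of the user option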
    outputs = core(inputs, %{spec | output_hidden_states: true})
    feature_maps = feature_maps(outputs.hidden_states, inputs["pixel_values"], spec)

    Layers.output(%{
      feature_maps: feature_maps,
      hidden_states:
        if(spec.output_hidden_states, do: outputs.hidden_states, else: Layers.none()),
      attentions: outputs.attentions
    })
  end

  defp inputs(spec) do
    shape = {nil, nil, nil, spec.num_channels}

    Bumblebee.Utils.Model.inputs_to_map([
      Axon.input("pixel_values", shape: shape),
      Axon.input("patch_mask", shape: {nil, nil}, optional: true)
    ])
  end

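  # The shared trunk: patch and position embeddings followed by the Transformer
  # encoder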
  defp core(inputs, spec, opts \\ []) do
    name = opts[:name]

    embeddings =
      embedder(inputs["pixel_values"], inputs["patch_mask"], spec, name: join(name, "embedder"))

    encoder_outputs = encoder(embeddings, spec, name: join(name, "encoder"))

    %{
      hidden_state: encoder_outputs.hidden_state,
      hidden_states: encoder_outputs.hidden_states,
      attentions: encoder_outputs.attentions
    }
  end

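  # Builds spatial feature maps from the selected hidden states. There are
  # `num_blocks + 1` hidden states (the embedding output plus one per block).
  # For each selected state we drop the class token and reshape the patch
  # tokens back into a {height / patch_size, width / patch_size} grid.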
  defp feature_maps(hidden_states, pixel_values, spec, opts \\ []) do
    name = opts[:name]

    num_feature_maps = spec.num_blocks + 1
    output_indices = spec.backbone_output_indices || [num_feature_maps - 1]

    for index <- output_indices do
      hidden_state = Axon.nx(hidden_states, &elem(&1, index))

      hidden_state =
        if spec.backbone_use_norm do
          Axon.layer_norm(hidden_state,
            epsilon: spec.layer_norm_epsilon,
            name: join(name, "norm")
          )
        else
          hidden_state
        end

      Axon.layer(
        fn hidden_state, pixel_values, _opts ->
          {batch_size, height, width, _channels} = Nx.shape(pixel_values)

          hidden_state = hidden_state[[.., 1..-1//1, ..]]

          Nx.reshape(
            hidden_state,
            {batch_size, div(height, spec.patch_size), div(width, spec.patch_size), :auto}
          )
        end,
        [hidden_state, pixel_values]
      )
    end
    |> List.to_tuple()
    |> Axon.container()
  end

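  # Embeds the image into patch tokens, prepends the learned class token and
  # adds the (possibly interpolated) position embeddings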
  defp embedder(pixel_values, patch_mask, spec, opts) do
    name = opts[:name]

    patch_embeddings =
      pixel_values
      |> patch_embedding(spec, name: join(name, "patch_embedding"))
      |> Layers.apply_vision_patch_mask(patch_mask, name: join(name, "mask_tokens"))

    class_embedding =
      Layers.learned_embeddings(1, spec.hidden_size, name: join(name, "class_embedding"))

    input_embeddings = Layers.concatenate_embeddings([class_embedding, patch_embeddings])

    num_patches = div(spec.image_size, spec.patch_size) ** 2

    position_embeddings =
      Layers.learned_embeddings(num_patches + 1, spec.hidden_size,
        initializer: :zeros,
        name: join(name, "position_embedding")
      )
      |> interpolate_position_embeddings(pixel_values, spec)

    Axon.add(input_embeddings, position_embeddings)
    |> Axon.dropout(rate: spec.dropout_rate, name: join(name, "dropout"))
  end

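  # Splits the image into non-overlapping patches using a strided convolution
  # and projects each patch to the hidden size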
  defp patch_embedding(pixel_values, spec, opts) do
    name = opts[:name]

    pixel_values
    |> Axon.conv(spec.hidden_size,
      kernel_size: spec.patch_size,
      strides: spec.patch_size,
      padding: :valid,
      kernel_initializer: kernel_initializer(spec),
      name: join(name, "projection")
    )
    |> Axon.reshape({:batch, :auto, spec.hidden_size}, name: join(name, "reshape"))
  end

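  # The position embeddings are learned for `:image_size` inputs. For other
  # input sizes the patch position grid is resized with bicubic interpolation,
  # while the class token embedding is kept as is.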
  defp interpolate_position_embeddings(position_embeddings, pixel_values, spec) do
    Axon.layer(
      fn position_embeddings, pixel_values, _opts ->
        original_positions = div(spec.image_size, spec.patch_size)
        {batch_size, height, width, _channels} = Nx.shape(pixel_values)
        resized_height = div(height, spec.patch_size)
        resized_width = div(width, spec.patch_size)

        class_position_embedding = position_embeddings[[.., 0..0//1, ..]]
        input_position_embeddings = position_embeddings[[.., 1..-1//1, ..]]

        interpolated_position_embeddings =
          input_position_embeddings
          |> Nx.reshape({batch_size, original_positions, original_positions, spec.hidden_size})
          |> Axon.Layers.resize(
            size: {resized_height, resized_width},
            method: :bicubic,
            antialias: false
          )
          |> Nx.reshape({batch_size, :auto, spec.hidden_size})

        Nx.concatenate([class_position_embedding, interpolated_position_embeddings], axis: 1)
      end,
      [position_embeddings, pixel_values],
      op_name: :interpolate_position_embeddings
    )
  end

  defp encoder(hidden_state, spec, opts) do
    name = opts[:name]

    ffn =
      if spec.ffn_swiglu_activation do
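        # SwiGLU uses 2/3 of the nominal intermediate size, rounded up to the
        # nearest multiple of 8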
        intermediate_size =
          div(floor(floor(spec.hidden_size * spec.intermediate_size_ratio) * 2 / 3 + 7), 8) * 8

        &ffn_swiglu(&1, intermediate_size, spec.hidden_size, name: &2)
      else
        intermediate_size = floor(spec.hidden_size * spec.intermediate_size_ratio)

        [
          intermediate_size: intermediate_size,
          activation: spec.activation
        ]
      end

    Layers.Transformer.blocks(hidden_state,
      num_blocks: spec.num_blocks,
      num_attention_heads: spec.num_attention_heads,
      hidden_size: spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      dropout_rate: spec.dropout_rate,
      attention_dropout_rate: spec.attention_dropout_rate,
      query_use_bias: spec.use_qkv_bias,
      key_use_bias: spec.use_qkv_bias,
      value_use_bias: spec.use_qkv_bias,
      layer_norm: [
        epsilon: spec.layer_norm_epsilon
      ],
      ffn: ffn,
      block_type: &block_impl(&1, &2, &3, spec),
      output_hidden_states: spec.output_hidden_states,
      output_attentions: spec.output_attentions,
      name: join(name, "blocks")
    )
  end

  # A feed-forward network with SwiGLU nonlinearity as in https://arxiv.org/abs/2002.05202
  defp ffn_swiglu(x, intermediate_size, output_size, opts) do
    name = opts[:name]
    dropout = opts[:dropout] || 0.0

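    # A single dense layer produces both the gate and the value, which are then
    # combined as value * silu(gate)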
    {gate, x} =
      x
      |> Axon.dense(intermediate_size * 2, name: join(name, "intermediate"))
      |> Axon.split(2, axis: -1)

    x = Axon.multiply(x, Axon.silu(gate))

    x
    |> Axon.dropout(rate: dropout, name: join(name, "dropout"))
    |> Axon.dense(output_size, name: join(name, "output"))
  end

  # A :norm_first block with additional LayerScale scaling layers after the
  # attention and FFN branches
  defp block_impl(hidden_state, steps, name, spec) do
    shortcut = hidden_state

    {hidden_state, attention_info} =
      hidden_state
      |> steps.self_attention_norm.()
      |> steps.self_attention.()

    hidden_state =
      hidden_state
      |> Bumblebee.Layers.scale(
        scale_initializer: Axon.Initializers.full(spec.scale_initial_value),
        name: join(name, "self_attention_scale")
      )
      |> Axon.add(shortcut)

    {_hidden_state, cross_attention_info} =
      steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
        raise "cross attention not supported"
      end)

    shortcut = hidden_state

    hidden_state =
      hidden_state
      |> steps.output_norm.()
      |> steps.ffn.()
      |> Bumblebee.Layers.scale(
        scale_initializer: Axon.Initializers.full(spec.scale_initial_value),
        name: join(name, "output_scale")
      )
      |> Axon.add(shortcut)

    {hidden_state, attention_info, cross_attention_info}
  end

  defp kernel_initializer(spec) do
    Axon.Initializers.normal(scale: spec.initializer_scale)
  end

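  # Converts configuration attributes from a huggingface/transformers config
  # into this spec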
  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      import Shared.Converters

      opts =
        convert!(data,
          image_size: {"image_size", number()},
          num_channels: {"num_channels", number()},
          patch_size: {"patch_size", number()},
          hidden_size: {"hidden_size", number()},
          num_blocks: {"num_hidden_layers", number()},
          num_attention_heads: {"num_attention_heads", number()},
          intermediate_size_ratio: {"mlp_ratio", number()},
          activation: {"hidden_act", activation()},
          use_qkv_bias: {"qkv_bias", boolean()},
          dropout_rate: {"hidden_dropout_prob", number()},
          attention_dropout_rate: {"attention_probs_dropout_prob", number()},
          layer_norm_epsilon: {"layer_norm_eps", number()},
          initializer_scale: {"initializer_range", number()},
          scale_initial_value: {"layerscale_value", number()},
          ffn_swiglu_activation: {"use_swiglu_ffn", boolean()},
          backbone_output_indices: {"out_indices", list(number())},
          backbone_use_norm: {"use_backbone_norm", boolean()}
        ) ++ Shared.common_options_from_transformers(data, spec)

      @for.config(spec, opts)
    end
  end

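  # Maps parameter names in this model to the parameter names used in the
  # huggingface/transformers checkpoints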
  defimpl Bumblebee.HuggingFace.Transformers.Model do
    def params_mapping(spec) do
      %{
        "embedder.patch_embedding.projection" => "dinov2.embeddings.patch_embeddings.projection",
        "embedder.class_embedding" => %{
          "embeddings" => {
            [{"dinov2.embeddings", "cls_token"}],
            fn [value] -> Nx.squeeze(value, axes: [0]) end
          }
        },
        "embedder.position_embedding" => %{
          "embeddings" => {
            [{"dinov2.embeddings", "position_embeddings"}],
            fn [value] -> Nx.squeeze(value, axes: [0]) end
          }
        },
        "encoder.blocks.{n}.self_attention_norm" => "dinov2.encoder.layer.{n}.norm1",
        "encoder.blocks.{n}.self_attention.key" =>
          "dinov2.encoder.layer.{n}.attention.attention.key",
        "encoder.blocks.{n}.self_attention.query" =>
          "dinov2.encoder.layer.{n}.attention.attention.query",
        "encoder.blocks.{n}.self_attention.value" =>
          "dinov2.encoder.layer.{n}.attention.attention.value",
        "encoder.blocks.{n}.self_attention.output" =>
          "dinov2.encoder.layer.{n}.attention.output.dense",
        "encoder.blocks.{n}.self_attention_scale" => %{
          "scale" => {
            [{"dinov2.encoder.layer.{n}.layer_scale1", "lambda1"}],
            fn [lambda1] -> lambda1 end
          }
        },
        "encoder.blocks.{n}.ffn.intermediate" =>
          if(spec.ffn_swiglu_activation,
            do: "dinov2.encoder.layer.{n}.mlp.weights_in",
            else: "dinov2.encoder.layer.{n}.mlp.fc1"
          ),
        "encoder.blocks.{n}.ffn.output" =>
          if(spec.ffn_swiglu_activation,
            do: "dinov2.encoder.layer.{n}.mlp.weights_out",
            else: "dinov2.encoder.layer.{n}.mlp.fc2"
          ),
        "encoder.blocks.{n}.output_norm" => "dinov2.encoder.layer.{n}.norm2",
        "encoder.blocks.{n}.output_scale" => %{
          "scale" => {
            [{"dinov2.encoder.layer.{n}.layer_scale2", "lambda1"}],
            fn [lambda1] -> lambda1 end
          }
        },
        "norm" => "dinov2.layernorm",
        "image_classification_head.output" => "classifier"
      }
    end
  end
end