lib/bumblebee/multimodal/layout_lm.ex

defmodule Bumblebee.Multimodal.LayoutLm do
  alias Bumblebee.Shared

  options =
    [
      vocab_size: [
        default: 30522,
        doc: """
        the vocabulary size of the token embedding. This corresponds to the number of distinct
        tokens that can be represented in model input and output
        """
      ],
      max_positions: [
        default: 512,
        doc: """
        the vocabulary size of the position embedding. This corresponds to the maximum sequence
        length that this model can process. Typically this is set to a large value just in case,
        such as 512, 1024 or 2048
        """
      ],
      max_spatial_positions: [
        default: 1024,
        doc: """
        the maximum value of the spatial position embedding. Typically this is set to a large value
        just in case, such as 512, 1024, or 2048
        """
      ],
      max_token_types: [
        default: 2,
        doc: """
        the vocabulary size of the token type embedding (also referred to as segment embedding).
        This corresponds to how many different token groups can be distinguished in the input
        """
      ],
      hidden_size: [
        default: 768,
        doc: "the dimensionality of hidden layers"
      ],
      num_blocks: [
        default: 12,
        doc: "the number of Transformer blocks in the encoder"
      ],
      num_attention_heads: [
        default: 12,
        doc: "the number of attention heads for each attention layer in the decoder"
      ],
      intermediate_size: [
        default: 3072,
        doc: """
        the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder
        """
      ],
      activation: [
        default: :gelu,
        doc: "the activation function"
      ],
      dropout_rate: [
        default: 0.1,
        doc: "the dropout rate for embedding and encoder"
      ],
      attention_dropout_rate: [
        default: 0.1,
        doc: "the dropout rate for attention weights"
      ],
      classifier_dropout_rate: [
        default: nil,
        doc:
          "the dropout rate for the classification head. If not specified, the value of `:dropout_rate` is used instead"
      ],
      initializer_scale: [
        default: 0.02,
        doc:
          "the standard deviation of the normal initializer used for initializing kernel parameters"
      ],
      layer_norm_epsilon: [
        default: 1.0e-12,
        doc: "the epsilon used by the layer normalization layers"
      ]
    ] ++
      Shared.common_options([
        :output_hidden_states,
        :output_attentions,
        :num_labels,
        :id_to_label
      ]) ++
      Shared.token_options(pad_token_id: 0)

  @moduledoc """
  LayoutLM Model family.

  ## Architectures

    * `:base` - plain LayoutLM without any head on top

    * `:for_masked_language_modeling` - LayoutLM with a language modeling
      head. The head returns logits for each token in the original
      sequence

    * `:for_sequence_classification` - LayoutLM with a sequence
      classification head. The head returns logits corresponding to
      possible classes

    * `:for_token_classification` - LayoutLM with a token classification
      head. The head returns logits for each token in the original
      sequence

    * `:for_question_answering` - LayoutLM with a span classification head.
      The head returns logits for the span start and end positions
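
  A spec with one of these heads can be selected when loading a model, for
  example (the repository name below is only an illustrative checkpoint):

      {:ok, model_info} =
        Bumblebee.load_model({:hf, "microsoft/layoutlm-base-uncased"},
          architecture: :for_token_classification
        )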

  ## Inputs

    * `"input_ids"` - `{batch_size, sequence_length}`

      Indices of input sequence tokens in the vocabulary.

    * `"attention_mask"` - `{batch_size, sequence_length}`

      Mask indicating which tokens to attend to. This is used to ignore
      padding tokens, which are added when processing a batch of sequences
      with different lengths.

    * `"token_type_ids"` - `{batch_size, sequence_length}`

      Mask distinguishing groups in the input sequence. This is used
      when the input sequence is semantically a pair of sequences.

    * `"position_ids"` - `{batch_size, sequence_length}`

      Indices of positions of each input sequence token in the position
      embeddings.

    * `"attention_head_mask"` - `{num_blocks, num_attention_heads}`

      Mask to nullify selected heads of the self-attention blocks in
      the encoder.

    * `"bounding_box"` - `{batch_size, sequence_length, 4}`

      Bounding boxes of each input sequence token. Each bounding box is
      `{x0, y0, x1, y1}` where `{x0, y0}` is the upper left corner and
      `{x1, y1}` is the lower right corner.
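
  As an illustration (the token ids, shapes, and box values below are arbitrary),
  the optional inputs can be omitted and a minimal input map assembled by hand:

      inputs = %{
        "input_ids" => Nx.tensor([[101, 2023, 2003, 102]]),
        "bounding_box" => Nx.broadcast(0, {1, 4, 4}),
        "attention_mask" => Nx.broadcast(1, {1, 4})
      }

  Such a map can then be passed to `Axon.predict/3` together with the loaded
  model and params.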

  ## Configuration

  #{Shared.options_doc(options)}
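
  Any of these options can be overridden after loading the spec, for example
  (the repository name and label count below are only illustrative):

      {:ok, spec} =
        Bumblebee.load_spec({:hf, "microsoft/layoutlm-base-uncased"},
          architecture: :for_token_classification
        )

      spec = Bumblebee.configure(spec, num_labels: 7)

      {:ok, model_info} =
        Bumblebee.load_model({:hf, "microsoft/layoutlm-base-uncased"}, spec: spec)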

  ## References

    * [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318)
  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.Layers

  @impl true
  def architectures(),
    do: [
      :base,
      :for_masked_language_modeling,
      :for_sequence_classification,
      :for_token_classification,
      :for_question_answering
    ]

  @impl true
  def config(spec, opts \\ []) do
    spec
    |> Shared.put_config_attrs(opts)
    |> Shared.validate_label_options()
  end

  @impl true
  def input_template(%{architecture: :for_multiple_choice}) do
    %{"input_ids" => Nx.template({1, 1, 1}, :s64)}
  end

  def input_template(_spec) do
    %{"input_ids" => Nx.template({1, 1}, :s64)}
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    inputs = inputs(spec)

    inputs
    |> core(spec)
    |> Layers.output()
  end

  def model(%__MODULE__{architecture: :for_masked_language_modeling} = spec) do
    inputs = inputs(spec)
    outputs = core(inputs, spec)

    logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")

    Layers.output(%{
      logits: logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    })
  end

  def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do
    inputs = inputs(spec)
    outputs = core(inputs, spec)

    logits =
      outputs.pooled_state
      |> Axon.dropout(
        rate: classifier_dropout_rate(spec),
        name: "sequence_classification_head.dropout"
      )
      |> Axon.dense(spec.num_labels,
        kernel_initializer: kernel_initializer(spec),
        name: "sequence_classification_head.output"
      )

    Layers.output(%{
      logits: logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    })
  end

  def model(%__MODULE__{architecture: :for_token_classification} = spec) do
    inputs = inputs(spec)
    outputs = core(inputs, spec)

    logits =
      outputs.hidden_state
      |> Axon.dropout(
        rate: classifier_dropout_rate(spec),
        name: "token_classification_head.dropout"
      )
      |> Axon.dense(spec.num_labels,
        kernel_initializer: kernel_initializer(spec),
        name: "token_classification_head.output"
      )

    Layers.output(%{
      logits: logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    })
  end

  def model(%__MODULE__{architecture: :for_question_answering} = spec) do
    inputs = inputs(spec)
    outputs = core(inputs, spec)

    logits =
      outputs.hidden_state
      |> Axon.dropout(
        rate: classifier_dropout_rate(spec),
        name: "question_answering_head.dropout"
      )
      |> Axon.dense(2,
        kernel_initializer: kernel_initializer(spec),
        name: "question_answering_head.output"
      )

    {start_logits, end_logits} = Layers.split_pair(logits)

    Layers.output(%{
      start_logits: start_logits,
      end_logits: end_logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    })
  end

  defp inputs(spec, opts \\ []) do
    shape = Keyword.get(opts, :shape, {nil, nil})

    attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads}
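    # Each token has one bounding box given as {x0, y0, x1, y1}, hence the trailing axis of size 4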
    bounding_box_shape = Tuple.append(shape, 4)

    Bumblebee.Utils.Model.inputs_to_map([
      Axon.input("input_ids", shape: shape),
      Axon.input("bounding_box", optional: true, shape: bounding_box_shape),
      Axon.input("attention_mask", optional: true, shape: shape),
      Axon.input("token_type_ids", optional: true, shape: shape),
      Axon.input("position_ids", optional: true, shape: shape),
      Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape)
    ])
  end

  defp core(inputs, spec) do
    embeddings =
      embedder(
        inputs["input_ids"],
        inputs["bounding_box"],
        inputs["position_ids"],
        inputs["token_type_ids"],
        spec,
        name: "embedder"
      )

    encoder_outputs =
      encoder(embeddings, inputs["attention_mask"], inputs["attention_head_mask"], spec,
        name: "encoder"
      )

    pooled_state = pooler(encoder_outputs.hidden_state, spec, name: "pooler")

    %{
      hidden_state: encoder_outputs.hidden_state,
      pooled_state: pooled_state,
      hidden_states: encoder_outputs.hidden_states,
      attentions: encoder_outputs.attentions
    }
  end

  defp embedder(input_ids, bounding_box, position_ids, token_type_ids, spec, opts) do
    name = opts[:name]

    bounding_box =
      Layers.default bounding_box do
        Layers.default_bounding_box(input_ids)
      end

    position_ids =
      Layers.default position_ids do
        Layers.default_position_ids(input_ids)
      end

    token_type_ids =
      Layers.default token_type_ids do
        Layers.default_token_type_ids(input_ids)
      end

    inputs_embeddings =
      Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size,
        name: join(name, "token_embedding")
      )

    position_embeddings =
      Axon.embedding(position_ids, spec.max_positions, spec.hidden_size,
        name: join(name, "position_embedding")
      )

    token_type_embeddings =
      Axon.embedding(token_type_ids, spec.max_token_types, spec.hidden_size,
        name: join(name, "token_type_embedding")
      )

    # TODO: Explicitly tie these weights
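    #
    # The left/right x-coordinates reuse the "x_position_embedding" layer name and
    # the upper/lower y-coordinates reuse the "y_position_embedding" layer name, so
    # each coordinate axis shares a single embedding parameter table, matching the
    # x/y position embeddings in the upstream checkpoint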

    left_position_embeddings =
      bounding_box
      |> Axon.nx(& &1[[0..-1//1, 0..-1//1, 0]])
      |> Axon.embedding(spec.max_spatial_positions, spec.hidden_size,
        name: join(name, "x_position_embedding")
      )

    right_position_embeddings =
      bounding_box
      |> Axon.nx(& &1[[0..-1//1, 0..-1//1, 2]])
      |> Axon.embedding(spec.max_spatial_positions, spec.hidden_size,
        name: join(name, "x_position_embedding")
      )

    upper_position_embeddings =
      bounding_box
      |> Axon.nx(& &1[[0..-1//1, 0..-1//1, 1]])
      |> Axon.embedding(spec.max_spatial_positions, spec.hidden_size,
        name: join(name, "y_position_embedding")
      )

    lower_position_embeddings =
      bounding_box
      |> Axon.nx(& &1[[0..-1//1, 0..-1//1, 3]])
      |> Axon.embedding(spec.max_spatial_positions, spec.hidden_size,
        name: join(name, "y_position_embedding")
      )

    h_position_embeddings =
      bounding_box
      |> Axon.nx(fn x -> Nx.subtract(x[[0..-1//1, 0..-1//1, 3]], x[[0..-1//1, 0..-1//1, 1]]) end)
      |> Axon.embedding(spec.max_spatial_positions, spec.hidden_size,
        name: join(name, "h_position_embedding")
      )

    w_position_embeddings =
      bounding_box
      |> Axon.nx(fn x -> Nx.subtract(x[[0..-1//1, 0..-1//1, 2]], x[[0..-1//1, 0..-1//1, 0]]) end)
      |> Axon.embedding(spec.max_spatial_positions, spec.hidden_size,
        name: join(name, "w_position_embedding")
      )

    embeddings =
      Axon.add([
        inputs_embeddings,
        position_embeddings,
        token_type_embeddings,
        left_position_embeddings,
        right_position_embeddings,
        upper_position_embeddings,
        lower_position_embeddings,
        h_position_embeddings,
        w_position_embeddings
      ])

    embeddings
    |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm"))
    |> Axon.dropout(rate: spec.dropout_rate)
  end

  defp encoder(hidden_state, attention_mask, attention_head_mask, spec, opts) do
    name = opts[:name]

    Layers.Transformer.blocks(hidden_state,
      attention_mask: attention_mask,
      attention_head_mask: attention_head_mask,
      num_blocks: spec.num_blocks,
      num_attention_heads: spec.num_attention_heads,
      hidden_size: spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      dropout_rate: spec.dropout_rate,
      attention_dropout_rate: spec.attention_dropout_rate,
      layer_norm: [
        epsilon: spec.layer_norm_epsilon
      ],
      ffn: [
        intermediate_size: spec.intermediate_size,
        activation: spec.activation
      ],
      output_hidden_states: spec.output_hidden_states,
      output_attentions: spec.output_attentions,
      name: join(name, "blocks")
    )
  end

  defp pooler(hidden_state, spec, opts) do
    name = opts[:name]

    hidden_state
    |> Layers.take_token(index: 0, axis: 1, name: join(name, "head"))
    |> Axon.dense(spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      name: join(name, "output")
    )
    |> Axon.tanh()
  end

  defp language_modeling_head(hidden_state, spec, opts) do
    name = opts[:name]

    # TODO: use a shared parameter with embeddings.word_embeddings.kernel
    # if spec.tie_word_embeddings is true (relevant for training)

    hidden_state
    |> Axon.dense(spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      name: join(name, "dense")
    )
    |> Layers.activation(spec.activation, name: join(name, "activation"))
    |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm"))
    # We reuse the kernel of input embeddings and add bias for each token
    |> Layers.dense_transposed(spec.vocab_size,
      kernel_initializer: kernel_initializer(spec),
      name: join(name, "output")
    )
    |> Axon.bias(name: join(name, "bias"))
  end

  defp classifier_dropout_rate(spec) do
    spec.classifier_dropout_rate || spec.dropout_rate
  end

  defp kernel_initializer(spec) do
    Axon.Initializers.normal(scale: spec.initializer_scale)
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      import Shared.Converters

      opts =
        convert!(data,
          vocab_size: {"vocab_size", number()},
          max_positions: {"max_position_embeddings", number()},
          max_token_types: {"type_vocab_size", number()},
          hidden_size: {"hidden_size", number()},
          num_blocks: {"num_hidden_layers", number()},
          num_attention_heads: {"num_attention_heads", number()},
          intermediate_size: {"intermediate_size", number()},
          activation: {"hidden_act", atom()},
          dropout_rate: {"hidden_dropout_prob", number()},
          attention_dropout_rate: {"attention_probs_dropout_prob", number()},
          classifier_dropout_rate: {"classifier_dropout", optional(number())},
          layer_norm_epsilon: {"layer_norm_eps", number()},
          initializer_scale: {"initializer_range", number()}
        ) ++ Shared.common_options_from_transformers(data, spec)

      @for.config(spec, opts)
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    def params_mapping(_spec) do
      %{
        "embedder.token_embedding" => "layoutlm.embeddings.word_embeddings",
        "embedder.position_embedding" => "layoutlm.embeddings.position_embeddings",
        "embedder.token_type_embedding" => "layoutlm.embeddings.token_type_embeddings",
        "embedder.x_position_embedding" => "layoutlm.embeddings.x_position_embeddings",
        "embedder.y_position_embedding" => "layoutlm.embeddings.y_position_embeddings",
        "embedder.h_position_embedding" => "layoutlm.embeddings.h_position_embeddings",
        "embedder.w_position_embedding" => "layoutlm.embeddings.w_position_embeddings",
        "embedder.norm" => "layoutlm.embeddings.LayerNorm",
        "encoder.blocks.{n}.self_attention.query" =>
          "layoutlm.encoder.layer.{n}.attention.self.query",
        "encoder.blocks.{n}.self_attention.key" =>
          "layoutlm.encoder.layer.{n}.attention.self.key",
        "encoder.blocks.{n}.self_attention.value" =>
          "layoutlm.encoder.layer.{n}.attention.self.value",
        "encoder.blocks.{n}.self_attention.output" =>
          "layoutlm.encoder.layer.{n}.attention.output.dense",
        "encoder.blocks.{n}.self_attention_norm" =>
          "layoutlm.encoder.layer.{n}.attention.output.LayerNorm",
        "encoder.blocks.{n}.cross_attention.query" =>
          "layoutlm.encoder.layer.{n}.attention.self.query",
        "encoder.blocks.{n}.cross_attention.key" =>
          "layoutlm.encoder.layer.{n}.attention.self.key",
        "encoder.blocks.{n}.cross_attention.value" =>
          "layoutlm.encoder.layer.{n}.attention.self.value",
        "encoder.blocks.{n}.cross_attention.output" =>
          "layoutlm.encoder.layer.{n}.attention.output.dense",
        "encoder.blocks.{n}.cross_attention_norm" =>
          "layoutlm.encoder.layer.{n}.attention.output.LayerNorm",
        "encoder.blocks.{n}.ffn.intermediate" => "layoutlm.encoder.layer.{n}.intermediate.dense",
        "encoder.blocks.{n}.ffn.output" => "layoutlm.encoder.layer.{n}.output.dense",
        "encoder.blocks.{n}.output_norm" => "layoutlm.encoder.layer.{n}.output.LayerNorm",
        "pooler.output" => "layoutlm.pooler.dense",
        "language_modeling_head.dense" => "cls.predictions.transform.dense",
        "language_modeling_head.norm" => "cls.predictions.transform.LayerNorm",
        "language_modeling_head.output" => "cls.predictions.decoder",
        "language_modeling_head.bias" => "cls.predictions",
        "sequence_classification_head.output" => "cls.seq_relationship",
        "token_classification_head.output" => "classifier",
        "multiple_choice_head.output" => "classifier",
        "question_answering_head.output" => "qa_outputs"
      }
    end
  end
end