lib/bumblebee/text/generation_config.ex

defmodule Bumblebee.Text.GenerationConfig do
  alias Bumblebee.Shared

  length_options = [
    max_new_tokens: [
      default: 20,
      doc:
        "the maximum number of tokens to be generated, ignoring the number of tokens in the prompt"
    ],
    min_new_tokens: [
      default: nil,
      doc:
        "the minimum number of tokens to be generated, ignoring the number of tokens in the prompt"
    ],
    max_length: [
      default: nil,
      doc: """
      the maximum length of the sequence to be generated. Note that this length includes the
      length of the input prompt (including padding). In general, prefer `:max_new_tokens`,
      which ignores the number of tokens in the prompt
      """
    ],
    min_length: [
      default: nil,
      doc: """
      the minimum length of the sequence to be generated. Note that this length includes the
      length of the input prompt (including padding). In general, prefer `:min_new_tokens`,
      which ignores the number of tokens in the prompt
      """
    ]
  ]

  strategy_options = [
    strategy: [
      default: %{type: :greedy_search},
      doc: """
      the method deciding how tokens are selected; it has a significant impact on the quality
      of the generated sequence. Should be a map with `:type` and strategy-specific options.

        * `:greedy_search` - the most straightforward approach, where in
          every iteration the most probable token (as given by the model)
          is taken.

          Example: `%{type: :greedy_search}`.

        * `:contrastive_search` - a state-of-the-art decoding method, capable
          of producing high-quality, coherent sequences. The results are
          deterministic. See [this article](https://huggingface.co/blog/introducing-csearch)
          for more details.

            * `:top_k` (required) - the number of highest-probability vocabulary tokens considered
              as a continuation

            * `:alpha` (required) - the weight of the degeneration penalty. It balances model
              confidence against the penalty

          Example: `%{type: :contrastive_search, top_k: 4, alpha: 0.6}`.

        * `:multinomial_sampling` - this method samples tokens according to the probability
          distribution given by the model. The results are nondeterministic, unless a seed
          is specified.

            * `:top_k` (optional) - when specified, restricts sampling to top-k most probable
              candidates

            * `:top_p` (optional) - when specified, restricts sampling to tokens whose probabilities
              add up to top-p
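
          Example: `%{type: :multinomial_sampling, top_p: 0.8}`.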
      """
    ]
  ]

  token_options = [
    decoder_start_token_id: [
      default: nil,
      doc:
        "the id of the initial token when generating from scratch, in the case of encoder-decoder models"
    ],
    forced_bos_token_id: [
      default: nil,
      doc: "the id of the token to force as the first generated token"
    ],
    forced_eos_token_id: [
      default: nil,
      doc:
        "the id of the token to force as the last generated token when `:max_length` is reached"
    ],
    forced_token_ids: [
      default: [],
      doc:
        "a list of `{index, token_id}` pairs forcing `token_id` to appear at `index` in the generated sequence"
    ],
    suppressed_token_ids: [
      default: [],
      doc: "a list of token ids to suppress during generation"
    ],
    no_repeat_ngram_length: [
      default: nil,
      doc: "when set, n-grams of the given length can occur only once in the generated sequence"
    ]
  ]

  special_token_options = [
    bos_token_id: [
      default: nil,
      doc: "the id of the beginning-of-sequence token"
    ],
    eos_token_id: [
      default: nil,
      doc: "the id of the end-of-sequence token"
    ],
    pad_token_id: [
      default: nil,
      doc: "the id of the padding token"
    ]
  ]

  other_options = [
    extra_config: [
      default: nil,
      doc: "additional configuration specific to the given model"
    ]
  ]

  options =
    length_options ++ strategy_options ++ token_options ++ special_token_options ++ other_options

  @moduledoc """
  A set of configuration options controlling text generation.

  This struct is expected by `Bumblebee.Text.Generation.build_generate/3`.

  ## Configuration

  ### Options controlling length

  #{Shared.options_doc(length_options)}

  ### Options controlling strategy

  #{Shared.options_doc(strategy_options)}

  ### Options controlling generated tokens

  #{Shared.options_doc(token_options)}

  ### Special tokens used during generation

  #{Shared.options_doc(special_token_options)}
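
  ## Usage

  A config is usually loaded from a pretrained model with
  `Bumblebee.load_generation_config/1` and then customized with
  `Bumblebee.configure/2`. For example:

      {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

      generation_config =
        Bumblebee.configure(generation_config,
          max_new_tokens: 100,
          strategy: %{type: :multinomial_sampling, top_p: 0.8}
        )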
  """

  defstruct Shared.option_defaults(options)

  @behaviour Bumblebee.Configurable

  @type t :: %__MODULE__{}

  @impl true
  def config(config, opts \\ []) do
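    # :max_new_tokens/:max_length (and :min_new_tokens/:min_length) are
    # mutually exclusive, so when one is given we explicitly clear its
    # counterpart, which matters when updating an existing config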
    opts =
      case {opts[:max_new_tokens], opts[:max_length]} do
        {nil, nil} ->
          opts

        {_, nil} ->
          put_in(opts[:max_length], nil)

        {nil, _} ->
          put_in(opts[:max_new_tokens], nil)

        _ ->
          raise ArgumentError,
                "only one of :max_new_tokens or :max_length may be given, but got both"
      end

    opts =
      case {opts[:min_new_tokens], opts[:min_length]} do
        {nil, nil} ->
          opts

        {_, nil} ->
          put_in(opts[:min_length], nil)

        {nil, _} ->
          put_in(opts[:min_new_tokens], nil)

        _ ->
          raise ArgumentError,
                "only one of :min_new_tokens or :min_length may be given, but got both"
      end

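    # Validate the strategy only if it is being set in this call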
    with {:ok, strategy} <- Keyword.fetch(opts, :strategy) do
      validate_strategy!(strategy)
    end

    Shared.put_config_attrs(config, opts)
  end

  defp validate_strategy!(%{type: :greedy_search} = strategy) do
    validate_strategy_keys!(strategy, [:type], [])
  end

  defp validate_strategy!(%{type: :contrastive_search} = strategy) do
    validate_strategy_keys!(strategy, [:type, :top_k, :alpha], [])
  end

  defp validate_strategy!(%{type: :multinomial_sampling} = strategy) do
    validate_strategy_keys!(strategy, [:type], [:top_k, :top_p])
  end

  defp validate_strategy!(%{type: type}) do
    raise ArgumentError,
          "expected strategy type to be one of :greedy_search, :contrastive_search or :multinomial_sampling, got: #{inspect(type)}"
  end

  defp validate_strategy!(%{} = other) do
    raise ArgumentError,
          "expected strategy to have :type, but it was not present in #{inspect(other)}"
  end

  defp validate_strategy!(other) do
    raise ArgumentError, "expected strategy to be a map, but got: #{inspect(other)}"
  end

  defp validate_strategy_keys!(strategy, required_keys, optional_keys) do
    actual = strategy |> Map.keys() |> Enum.sort()

    missing_keys = Enum.sort(required_keys -- actual)

    if missing_keys != [] do
      raise ArgumentError,
            "missing keys #{inspect(missing_keys)} for strategy #{inspect(strategy.type)}"
    end

    extra_keys = Enum.sort((actual -- required_keys) -- optional_keys)

    if extra_keys != [] do
      raise ArgumentError,
            "unexpected keys #{inspect(extra_keys)} for strategy #{inspect(strategy.type)}"
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(config, data) do
      import Shared.Converters

      # Special case joint configurations
      data =
        case data do
          %{"model_type" => "blip", "text_config" => data} -> data
          data -> data
        end

      data =
        case data do
          # During generation BLIP uses SEP token as the EOS token
          %{"model_type" => "blip_text_model", "sep_token_id" => sep_token_id} ->
            put_in(data["eos_token_id"], sep_token_id)

          data ->
            data
        end

      data =
        case data do
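          # Drop pairs with a nil token id, which indicate that no token
          # is forced at the given index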
          %{"forced_decoder_ids" => ids} ->
            ids = Enum.reject(ids, &match?([_idx, nil], &1))
            put_in(data["forced_decoder_ids"], ids)

          data ->
            data
        end

      data =
        case data do
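          # An explicit nil means no suppressed tokens, so we drop the key
          # and fall back to the default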
          %{"suppress_tokens" => nil} -> Map.delete(data, "suppress_tokens")
          data -> data
        end

      opts =
        convert!(data,
          max_new_tokens: {"max_new_tokens", number()},
          min_new_tokens: {"min_new_tokens", number()},
          max_length: {"max_length", number()},
          min_length: {"min_length", number()},
          decoder_start_token_id: {"decoder_start_token_id", optional(number())},
          bos_token_id: {"bos_token_id", optional(number())},
          eos_token_id: {"eos_token_id", optional(number())},
          pad_token_id: {"pad_token_id", optional(number())},
          forced_bos_token_id: {"forced_bos_token_id", optional(number())},
          forced_eos_token_id: {"forced_eos_token_id", optional(number())},
          forced_token_ids: {"forced_decoder_ids", list(tuple([number(), number()]))},
          suppressed_token_ids: {"suppress_tokens", list(number())},
          no_repeat_ngram_length: {"no_repeat_ngram_size", number()}
        )

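      # Infer the generation strategy from HF flags: do_sample implies
      # multinomial sampling, penalty_alpha with top_k implies contrastive
      # search, and otherwise we keep the default greedy search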
      strategy_opts =
        data
        |> convert!(
          sample: {"do_sample", boolean()},
          top_k: {"top_k", number()},
          top_p: {"top_p", number()},
          alpha: {"penalty_alpha", number()}
        )
        |> Map.new()
        |> case do
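          # In HF configs top_k: 0 and top_p: 1.0 disable the respective
          # restriction, so we keep only meaningful values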
          %{sample: true} = params ->
            options =
              Map.filter(params, fn
                {:top_k, k} when k > 0 -> true
                {:top_p, p} when p < 1.0 -> true
                _ -> false
              end)

            [strategy: Map.merge(%{type: :multinomial_sampling}, options)]

          %{top_k: top_k, alpha: alpha} when top_k > 1 and alpha > 0 ->
            [strategy: %{type: :contrastive_search, top_k: top_k, alpha: alpha}]

          _ ->
            []
        end

      @for.config(config, opts ++ strategy_opts)
    end
  end
end