defmodule Bumblebee.Text.GenerationConfig do
  alias Bumblebee.Shared

  length_options = [
    max_new_tokens: [
      default: 20,
      doc:
        "the maximum number of tokens to be generated, ignoring the number of tokens in the prompt"
    ],
    min_new_tokens: [
      default: nil,
      doc:
        "the minimum number of tokens to be generated, ignoring the number of tokens in the prompt"
    ],
    max_length: [
      default: nil,
      doc: """
      the maximum length of the sequence to be generated. Note that this length includes the
      length of the input prompt (including padding). In general, prefer `:max_new_tokens`,
      which ignores the number of tokens in the prompt
      """
    ],
    min_length: [
      default: nil,
      doc: """
      the minimum length of the sequence to be generated. Note that this length includes the
      length of the input prompt (including padding). In general, prefer `:min_new_tokens`,
      which ignores the number of tokens in the prompt
      """
    ]
  ]

  strategy_options = [
    strategy: [
      default: %{type: :greedy_search},
      doc: """
      the method deciding how tokens are selected; it has a significant impact on the quality
      of the generated sequence. Should be a map with `:type` and strategy-specific options.

        * `:greedy_search` - the most straightforward approach, where in
          every iteration the most probable token (as given by the model)
          is taken.

          Example: `%{type: :greedy_search}`.

        * `:contrastive_search` - a state-of-the-art decoding method, capable
          of producing high-quality, coherent sequences. The results are
          deterministic. See [this article](https://huggingface.co/blog/introducing-csearch)
          for more details.

            * `:top_k` (required) - the number of highest probability vocabulary tokens
              considered as a continuation

            * `:alpha` (required) - the weight of the degeneration penalty. It balances
              the model confidence and the penalty

          Example: `%{type: :contrastive_search, top_k: 4, alpha: 0.6}`.

        * `:multinomial_sampling` - this method samples tokens according to the probability
          distribution given by the model. The results are nondeterministic, unless a seed
          is specified.

            * `:top_k` (optional) - when specified, restricts sampling to the top-k most
              probable candidates

            * `:top_p` (optional) - when specified, restricts sampling to the smallest set
              of tokens whose probabilities add up to at least top-p
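
          Example: `%{type: :multinomial_sampling, top_k: 50}`.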
"""
]
]

  token_options = [
    decoder_start_token_id: [
      default: nil,
      doc:
        "the id of the initial token when generating from scratch, in the case of encoder-decoder models"
    ],
    forced_bos_token_id: [
      default: nil,
      doc: "the id of the token to force as the first generated token"
    ],
    forced_eos_token_id: [
      default: nil,
      doc:
        "the id of the token to force as the last generated token when `:max_length` is reached"
    ],
    forced_token_ids: [
      default: [],
      doc:
        "a list of `{index, token_id}` pairs forcing `token_id` to appear at `index` in the generated sequence"
    ],
    suppressed_token_ids: [
      default: [],
      doc: "a list of token ids to suppress during generation"
    ],
    no_repeat_ngram_length: [
      default: nil,
      doc:
        "when set, prevents any n-gram of the given length from occurring more than once in the generated sequence"
    ]
  ]

  special_token_options = [
    bos_token_id: [
      default: nil,
      doc: "the id of the beginning-of-sequence token"
    ],
    eos_token_id: [
      default: nil,
      doc: "the id of the end-of-sequence token"
    ],
    pad_token_id: [
      default: nil,
      doc: "the id of the padding token"
    ]
  ]

  other_options = [
    extra_config: [
      default: nil,
      doc: "additional configuration specific to the given model"
    ]
  ]

  options =
    length_options ++ strategy_options ++ token_options ++ special_token_options ++ other_options

  @moduledoc """
  A set of configuration options controlling text generation.

  This struct is expected by `Bumblebee.Text.Generation.build_generate/3`.

  ## Configuration

  ### Options controlling length

  #{Shared.options_doc(length_options)}

  ### Options controlling strategy

  #{Shared.options_doc(strategy_options)}

  ### Options controlling generated tokens

  #{Shared.options_doc(token_options)}

  ### Special tokens used during generation

  #{Shared.options_doc(special_token_options)}
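
  ## Usage

  A typical flow loads the generation config from a checkpoint and then overrides
  selected options. A minimal sketch, assuming the standard `Bumblebee.load_generation_config/1`
  and `Bumblebee.configure/2` entry points (the repository name and option values below are
  purely illustrative):

      {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2"})

      generation_config =
        Bumblebee.configure(generation_config,
          max_new_tokens: 100,
          strategy: %{type: :multinomial_sampling, top_p: 0.9}
        )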
"""
defstruct Shared.option_defaults(options)
@behaviour Bumblebee.Configurable
@type t :: %__MODULE__{}
@impl true
def config(config, opts \\ []) do
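    # :max_new_tokens and :max_length (and likewise :min_new_tokens and
    # :min_length) are mutually exclusive, so when one is given explicitly
    # we reset the other to nil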
    opts =
      case {opts[:max_new_tokens], opts[:max_length]} do
        {nil, nil} ->
          opts

        {_, nil} ->
          put_in(opts[:max_length], nil)

        {nil, _} ->
          put_in(opts[:max_new_tokens], nil)

        _ ->
          raise ArgumentError,
                "only one of :max_new_tokens and :max_length may be given, but got both"
      end

    opts =
      case {opts[:min_new_tokens], opts[:min_length]} do
        {nil, nil} ->
          opts

        {_, nil} ->
          put_in(opts[:min_length], nil)

        {nil, _} ->
          put_in(opts[:min_new_tokens], nil)

        _ ->
          raise ArgumentError,
                "only one of :min_new_tokens and :min_length may be given, but got both"
      end
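
    # Validate the strategy only when one is explicitly given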
    with {:ok, strategy} <- Keyword.fetch(opts, :strategy) do
      validate_strategy!(strategy)
    end

    Shared.put_config_attrs(config, opts)
  end

  defp validate_strategy!(%{type: :greedy_search} = strategy) do
    validate_strategy_keys!(strategy, [:type], [])
  end

  defp validate_strategy!(%{type: :contrastive_search} = strategy) do
    validate_strategy_keys!(strategy, [:type, :top_k, :alpha], [])
  end

  defp validate_strategy!(%{type: :multinomial_sampling} = strategy) do
    validate_strategy_keys!(strategy, [:type], [:top_k, :top_p])
  end

  defp validate_strategy!(%{type: type}) do
    raise ArgumentError,
          "expected strategy type to be one of :greedy_search, :contrastive_search or" <>
            " :multinomial_sampling, but got: #{inspect(type)}"
  end

  defp validate_strategy!(%{} = other) do
    raise ArgumentError,
          "expected strategy to have :type, but it was not present in #{inspect(other)}"
  end

  defp validate_strategy!(other) do
    raise ArgumentError, "expected strategy to be a map, but got: #{inspect(other)}"
  end

  defp validate_strategy_keys!(strategy, required_keys, optional_keys) do
    actual = strategy |> Map.keys() |> Enum.sort()

    missing_keys = Enum.sort(required_keys -- actual)

    if missing_keys != [] do
      raise ArgumentError,
            "missing keys #{inspect(missing_keys)} for strategy #{inspect(strategy.type)}"
    end

    extra_keys = Enum.sort((actual -- required_keys) -- optional_keys)

    if extra_keys != [] do
      raise ArgumentError,
            "unexpected keys #{inspect(extra_keys)} for strategy #{inspect(strategy.type)}"
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(config, data) do
      import Shared.Converters

      # Special case joint configurations
      data =
        case data do
          %{"model_type" => "blip", "text_config" => data} -> data
          data -> data
        end

      data =
        case data do
          # During generation BLIP uses SEP token as the EOS token
          %{"model_type" => "blip_text_model", "sep_token_id" => sep_token_id} ->
            put_in(data["eos_token_id"], sep_token_id)

          data ->
            data
        end

      data =
        case data do
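          # Drop pairs with a nil token id, which indicate that no token is
          # forced at the given position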
%{"forced_decoder_ids" => ids} ->
ids = Enum.reject(ids, &match?([_idx, nil], &1))
put_in(data["forced_decoder_ids"], ids)
data ->
data
end
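
      # Some checkpoints set "suppress_tokens" to null, which we treat the
      # same as the key being absent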
      data =
        case data do
          %{"suppress_tokens" => nil} -> Map.delete(data, "suppress_tokens")
          data -> data
        end

      opts =
        convert!(data,
          max_new_tokens: {"max_new_tokens", number()},
          min_new_tokens: {"min_new_tokens", number()},
          max_length: {"max_length", number()},
          min_length: {"min_length", number()},
          decoder_start_token_id: {"decoder_start_token_id", optional(number())},
          bos_token_id: {"bos_token_id", optional(number())},
          eos_token_id: {"eos_token_id", optional(number())},
          pad_token_id: {"pad_token_id", optional(number())},
          forced_bos_token_id: {"forced_bos_token_id", optional(number())},
          forced_eos_token_id: {"forced_eos_token_id", optional(number())},
          forced_token_ids: {"forced_decoder_ids", list(tuple([number(), number()]))},
          suppressed_token_ids: {"suppress_tokens", list(number())},
          no_repeat_ngram_length: {"no_repeat_ngram_size", number()}
        )
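
      # Infer the generation strategy from the Hugging Face flags: "do_sample"
      # selects multinomial sampling, while "top_k" combined with
      # "penalty_alpha" indicates contrastive search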
      strategy_opts =
        data
        |> convert!(
          sample: {"do_sample", boolean()},
          top_k: {"top_k", number()},
          top_p: {"top_p", number()},
          alpha: {"penalty_alpha", number()}
        )
        |> Map.new()
        |> case do
          %{sample: true} = opts ->
            options =
              Map.filter(opts, fn
                {:top_k, k} when k > 0 -> true
                {:top_p, p} when p < 1.0 -> true
                _ -> false
              end)

            [strategy: Map.merge(%{type: :multinomial_sampling}, options)]

          %{top_k: top_k, alpha: alpha} when top_k > 1 and alpha > 0 ->
            [strategy: %{type: :contrastive_search, top_k: top_k, alpha: alpha}]

          _ ->
            []
        end

      @for.config(config, opts ++ strategy_opts)
    end
  end
end