defmodule Bumblebee.Text.Generation do
@moduledoc """
An interface for language models supporting sequence generation.
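
  A higher-level entry point built on top of this interface is
  `Bumblebee.Text.generation/4`, which returns an `Nx.Serving` for text
  generation:

      # A usage sketch (the GPT-2 checkpoint is illustrative)
      {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
      {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

      serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
      Nx.Serving.run(serving, "The capital of France is")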
"""
@type cache :: Nx.Tensor.t() | Nx.Container.t()
@doc """
Initializes an opaque cache input for iterative inference.
"""
@callback init_cache(
spec :: Bumblebee.ModelSpec.t(),
batch_size :: pos_integer(),
max_length :: pos_integer(),
inputs :: map()
) :: cache()
@doc """
Traverses all batched tensors in the cache.
This function is used when the cache needs to be inflated or
deflated for a different batch size.
"""
@callback traverse_cache(
spec :: Bumblebee.ModelSpec.t(),
cache(),
(Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()
import Nx.Defn
alias Bumblebee.Shared
alias Bumblebee.Utils
alias Bumblebee.Text
@doc """
Initializes an opaque cache input for iterative inference.
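
  ## Examples

      # A minimal sketch: `spec` is a loaded model spec and `inputs` are
      # the corresponding model inputs (both names are illustrative)
      cache = Bumblebee.Text.Generation.init_cache(spec, 1, 100, inputs)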
"""
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
def init_cache(%module{} = spec, batch_size, max_length, inputs) do
module.init_cache(spec, batch_size, max_length, inputs)
end
@doc """
Calls `fun` for every batched tensor in the cache.
"""
@spec traverse_cache(Bumblebee.ModelSpec.t(), cache, (Nx.Tensor.t() -> Nx.Tensor.t())) ::
cache()
def traverse_cache(%module{} = spec, cache, fun) do
module.traverse_cache(spec, cache, fun)
end
@doc """
Builds a numerical definition that generates sequences of tokens using
the given language model.
The model should be either a decoder or an encoder-decoder. The tokens
are generated by iterative inference using the decoder (autoregression),
until the termination criteria are met.
  In the case of encoder-decoder models, the encoder is run only once
  and its intermediate state is reused across all iterations.

  The generation is controlled by a number of options given as
  `%Bumblebee.Text.GenerationConfig{}`; see the corresponding docs
  for more details.

## Options

    * `:seed` - random seed to use when sampling. By default the current
      timestamp is used.
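
  ## Examples

      # A sketch, assuming `model_info` was obtained with
      # `Bumblebee.load_model/1`, `generation_config` with
      # `Bumblebee.load_generation_config/1`, and `inputs` are tokenized
      # model inputs
      generate =
        Bumblebee.Text.Generation.build_generate(
          model_info.model,
          model_info.spec,
          generation_config
        )

      token_ids = generate.(model_info.params, inputs)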
"""
@spec build_generate(
Axon.t(),
Bumblebee.ModelSpec.t(),
Bumblebee.Text.GenerationConfig.t(),
keyword()
) :: (params :: map(), inputs :: map() -> Nx.t())
def build_generate(model, spec, config, opts \\ []) do
opts = Keyword.validate!(opts, [:seed])
seed = Keyword.get_lazy(opts, :seed, &:erlang.system_time/0)
decoder_start_token_id = config.decoder_start_token_id || config.bos_token_id
eos_token_id = config.eos_token_id
pad_token_id = config.pad_token_id || config.eos_token_id
{max_length_fun, min_length_fun} = lazy_lengths_from_opts(config)
{prepare_inputs_fun, update_inputs_fun} =
input_callbacks(model, spec, max_length_fun, decoder_start_token_id)
traverse_cache_fun = &traverse_cache(spec, &1, &2)
model =
if not spec.output_hidden_states and config.strategy.type == :contrastive_search do
spec
|> Bumblebee.configure(output_hidden_states: true)
|> Bumblebee.build_model()
else
model
end
{_init_fun, predict_fun} = Axon.build(model)
logits_processor_fun = get_logits_processor(min_length_fun, eos_token_id, config)
&generate_impl(
&2,
predict_fun,
&1,
logits_processor_fun,
prepare_inputs_fun,
update_inputs_fun,
traverse_cache_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id,
seed: seed,
strategy: config.strategy
)
end
defp lazy_lengths_from_opts(opts) do
max_length_fun =
case {opts.max_new_tokens, opts.max_length} do
{nil, nil} ->
raise ArgumentError,
"expected either :max_new_tokens or :max_length option, but neither was given"
{max_new_tokens, nil} ->
fn input_length -> input_length + max_new_tokens end
{nil, max_length} ->
fn _ -> max_length end
end
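    # As an illustration: with max_new_tokens: 20, a 10-token prompt
    # gives an effective max length of 30, whereas max_length: 30 caps
    # the total length at 30 regardless of the prompt length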
min_length_fun =
case {opts.min_new_tokens, opts.min_length} do
{nil, nil} ->
nil
{min_new_tokens, nil} ->
fn input_length -> input_length + min_new_tokens end
{nil, min_length} ->
fn _ -> min_length end
end
{max_length_fun, min_length_fun}
end
defp encoder_from_encoder_decoder(model) do
    # We cherry-pick the encoder outputs from the encoder-decoder
    # outputs. The expanded expression will have no decoder parts, so it
    # will effectively be the same as an encoder built from scratch.
Axon.nx(model, fn outputs ->
case outputs do
%{
encoder_hidden_state: hidden_state,
encoder_hidden_states: hidden_states,
encoder_attentions: attentions
} ->
%{
hidden_state: hidden_state,
hidden_states: hidden_states,
attentions: attentions
}
_ ->
raise ArgumentError,
"expected an encoder-decoder model, but it does not have the expected outputs"
end
end)
end
defp input_callbacks(model, spec, max_length_fun, decoder_start_token_id) do
if encoder_decoder?(model) do
encoder = encoder_from_encoder_decoder(model)
{_encoder_init_fun, encoder_predict_fun} = Axon.build(encoder)
prepare_inputs_fun = fn inputs, params ->
encoder_outputs = encoder_predict_fun.(params, inputs)
batch_size = Nx.axis_size(encoder_input(inputs), 0)
decoder_input_ids = Nx.broadcast(decoder_start_token_id, {batch_size, 1})
inputs =
Map.merge(inputs, %{
"encoder_hidden_state" => encoder_outputs.hidden_state,
"decoder_input_ids" => decoder_input_ids
})
max_length = max_length_fun.(1)
inputs = prepare_decoder_inputs(inputs, "decoder_", spec, max_length)
{inputs, inputs["decoder_input_ids"], max_length}
end
update_inputs_fun = &update_decoder_inputs("decoder_", &1, &2, &3)
{prepare_inputs_fun, update_inputs_fun}
else
prepare_inputs_fun = fn inputs, _params ->
sequence_length = Nx.axis_size(inputs["input_ids"], 1)
max_length = max_length_fun.(sequence_length)
inputs = prepare_decoder_inputs(inputs, "", spec, max_length)
{inputs, inputs["input_ids"], max_length}
end
update_inputs_fun = &update_decoder_inputs("", &1, &2, &3)
{prepare_inputs_fun, update_inputs_fun}
end
end
defp encoder_decoder?(model) do
inputs = Axon.get_inputs(model)
encoder_input(inputs) != nil and Map.has_key?(inputs, "decoder_input_ids")
end
defp encoder_input(inputs) do
inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
end
defp prepare_decoder_inputs(inputs, prefix, spec, max_length) do
input_ids = inputs[prefix <> "input_ids"]
attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1.0, input_ids)
position_ids =
attention_mask
|> Nx.cumulative_sum(axis: 1)
|> Nx.subtract(1)
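    # As an illustration: for a left-padded attention mask [[0, 1, 1]]
    # the cumulative sum gives [[0, 1, 2]] and subtracting 1 yields
    # position ids [[-1, 0, 1]], so the first real token gets position 0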
inputs =
inputs
|> Map.put(prefix <> "attention_mask", attention_mask)
|> Map.put(prefix <> "position_ids", position_ids)
batch_size = Nx.axis_size(input_ids, 0)
cache = init_cache(spec, batch_size, max_length, inputs)
Map.put(inputs, "cache", cache)
end
defp update_decoder_inputs(prefix, inputs, cache, token_ids) do
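    # After the initial pass the generation loop feeds the model one new
    # token per running sequence, so token_ids has a single column and
    # the position ids continue from the last position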
inputs
|> Map.replace!(prefix <> "input_ids", token_ids)
|> Map.replace!(prefix <> "attention_mask", Nx.broadcast(1.0, token_ids))
|> Map.update!(prefix <> "position_ids", fn position_ids ->
position_ids
|> Nx.slice_along_axis(Nx.axis_size(position_ids, -1) - 1, 1, axis: -1)
|> Nx.add(1)
end)
|> Map.replace!("cache", cache)
end
defp get_logits_processor(min_length_fun, eos_token_id, config) do
import Bumblebee.Text.Generation.LogitsProcessing
processors =
[
if config.no_repeat_ngram_length && config.no_repeat_ngram_length > 0 do
&no_repeat_ngram_processor(&1, &2, ngram_length: config.no_repeat_ngram_length)
end,
if min_length_fun && eos_token_id do
&min_length_processor(&1, &2,
min_length_fun: min_length_fun,
eos_token_id: eos_token_id
)
end,
if config.forced_bos_token_id do
&bos_token_processor(&1, &2, bos_token_id: config.forced_bos_token_id)
end,
if config.forced_eos_token_id do
&eos_token_processor(&1, &2, eos_token_id: config.forced_eos_token_id)
end,
if config.forced_token_ids do
&forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
end
] ++
if config.strategy.type == :multinomial_sampling do
[
if top_k = config.strategy[:top_k] do
&top_k_processor(&1, &2, top_k: top_k)
end,
if top_p = config.strategy[:top_p] do
&top_p_processor(&1, &2, top_p: top_p)
end
]
else
[]
end
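    # Compose the configured processors into a single function, skipping
    # nil entries. The logits flow through the processors in order, so
    # for example top-k filtering runs before top-p when both are set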
fn logits, context ->
for processor <- processors, processor, reduce: logits do
logits -> processor.(logits, context)
end
end
end
deftransformp generate_impl(
inputs,
predict_fun,
params,
logits_processor_fun,
prepare_inputs_fun,
update_inputs_fun,
traverse_cache_fun,
opts \\ []
) do
{decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)
length = Nx.axis_size(decoder_input_ids, 1)
if length >= max_length do
      raise ArgumentError,
            "the input sequence has #{length} tokens, but max_length is set to #{max_length}." <>
              " Consider increasing :max_new_tokens (or :max_length), or truncating the input to a shorter length"
end
strategy = opts[:strategy]
seed = opts[:seed]
case strategy.type do
:greedy_search ->
greedy(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length] ++ opts
)
:contrastive_search ->
contrastive(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
traverse_cache_fun,
[max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha] ++ opts
)
:multinomial_sampling ->
prng_key = Nx.Random.key(seed)
sampling(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length, prng_key: prng_key] ++ opts
)
end
end
# Greedy search
defnp greedy(
inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
opts \\ []
) do
max_length = opts[:max_length]
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]
{sequences, length = input_length, finished?} =
init_sequences(decoder_input_ids, max_length, pad_token_id)
# The loop works with inputs of length 1, so if the initial input
# is longer, we make the initial pass outside
{sequences, length, finished?, inputs} =
if length > 1 do
greedy_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id
)
else
{sequences, length, finished?, inputs}
end
{sequences, _length, _finished?, _inputs, _params} =
while {sequences, length, finished?, inputs, params},
continue?(finished?, length, max_length) do
{sequences, length, finished?, inputs} =
greedy_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id
)
{sequences, length, finished?, inputs, params}
end
sequences
end
defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
{batch_size, length} = Nx.shape(decoder_input_ids)
if length > max_length do
raise ArgumentError, "expected the input to be at most #{max_length} tokens, got: #{length}"
end
sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)
finished? = Nx.broadcast(Nx.tensor(0, type: :u8), {batch_size})
{sequences, length, finished?}
end
defnp continue?(finished?, length, max_length) do
not Nx.all(finished?) and length < max_length
end
defnp greedy_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
opts
) do
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]
outputs = predict_fun.(params, inputs)
logits = outputs.logits[[.., -1]]
logits =
logits_processor_fun.(logits, %{
sequences: sequences,
length: length,
input_length: input_length
})
token_id = Nx.argmax(logits, axis: -1)
{sequences, length, finished?} =
update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))
{sequences, length, finished?, inputs}
end
defnp update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id) do
token_id = Nx.select(finished?, pad_token_id, token_id)
finished? =
case eos_token_id do
nil -> finished?
eos_token_id -> finished? or token_id == eos_token_id
end
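    # The :token hook is what powers token streaming in the serving
    # (see maybe_stream/3)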
{token_id, finished?} = hook({token_id, finished?}, :token)
token_ids = Nx.new_axis(token_id, -1)
sequences = Nx.put_slice(sequences, [0, length], token_ids)
{sequences, length + 1, finished?}
end
# Contrastive search
defnp contrastive(
inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
traverse_cache_fun,
opts \\ []
) do
max_length = opts[:max_length]
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]
top_k = opts[:top_k]
penalty_alpha = opts[:penalty_alpha]
{sequences, length = input_length, finished?} =
init_sequences(decoder_input_ids, max_length, pad_token_id)
# Step (1)
# Initial pass to obtain hidden state and expand inputs to top-k
outputs = predict_fun.(params, inputs)
    # Later, we feed the model a single token at a time and reuse
    # previous results via the cache. Here we need the final hidden
    # state, so we keep track of it in a similar way.
initial_hidden_state = decoder_hidden_state(outputs)
batch_size = Nx.axis_size(initial_hidden_state, 0)
hidden_size = Nx.axis_size(initial_hidden_state, -1)
joint_hidden_state = Nx.broadcast(0.0, {batch_size, max_length, hidden_size})
joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state)
logits = outputs.logits[[.., -1]]
logits =
logits_processor_fun.(logits, %{
sequences: sequences,
length: length,
input_length: input_length
})
scores = Axon.Activations.softmax(logits, axis: -1)
{top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
# For subsequent model passes we consider several (top-k) paths
# for each batch item, so we duplicate inputs and cache accordingly
inputs = expand_inputs(inputs, top_k)
cache = expand_cache(outputs.cache, top_k, traverse_cache_fun)
inputs = update_inputs_fun.(inputs, cache, Nx.reshape(top_k_token_ids, {:auto, 1}))
# Step (2)
    # In the loop we run the model for the top-k continuation tokens and
    # pick the best one using the contrastive rank. The same model pass
    # also gives us the next top-k continuation tokens.
{sequences, _length, _finished?, _inputs, _params, _joint_hidden_state, _top_k_values} =
while {sequences, length, finished?, inputs, params, joint_hidden_state,
{top_k_scores, top_k_token_ids}},
continue?(finished?, length, max_length) do
outputs = predict_fun.(params, inputs)
hidden_state = decoder_hidden_state(outputs)
context_hidden_state = Utils.Nx.repeat_interleave(joint_hidden_state, top_k)
selected_idx =
contrastive_rank(
context_hidden_state,
hidden_state,
length,
top_k_scores,
penalty_alpha,
top_k
)
hidden_state = Utils.Nx.chunked_take(hidden_state, top_k, selected_idx)
joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, length, 0], hidden_state)
token_id = top_k_token_ids |> Nx.flatten() |> Utils.Nx.chunked_take(top_k, selected_idx)
{sequences, length, finished?} =
update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
logits = outputs.logits[[.., -1]]
logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
logits =
logits_processor_fun.(logits, %{
sequences: sequences,
length: length,
input_length: input_length
})
scores = Axon.Activations.softmax(logits, axis: -1)
{top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
# Mirror the selected idx to other entries within each chunk
cache = reflect_cache(outputs.cache, top_k, selected_idx, traverse_cache_fun)
inputs = update_inputs_fun.(inputs, cache, Nx.reshape(top_k_token_ids, {:auto, 1}))
{sequences, length, finished?, inputs, params, joint_hidden_state,
{top_k_scores, top_k_token_ids}}
end
sequences
end
deftransformp decoder_hidden_state(outputs) do
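    # Encoder-decoder models expose the decoder hidden states under
    # :decoder_hidden_states, while decoder-only models use :hidden_states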
hidden_states =
case outputs do
%{decoder_hidden_states: hidden_states} -> hidden_states
%{hidden_states: hidden_states} -> hidden_states
end
elem(hidden_states, tuple_size(hidden_states) - 1)
end
deftransformp expand_inputs(inputs, times) do
Map.new(inputs, fn
{key, value} when key in ["cache"] ->
{key, value}
{key, %Nx.Tensor{} = value} ->
{key, Utils.Nx.repeat_interleave(value, times)}
end)
end
deftransformp expand_cache(cache, times, traverse_cache_fun) do
traverse_cache_fun.(cache, &Utils.Nx.repeat_interleave(&1, times))
end
deftransformp reflect_cache(cache, times, idx, traverse_cache_fun) do
traverse_cache_fun.(
cache,
&(&1
|> Utils.Nx.chunked_take(times, idx)
|> Utils.Nx.repeat_interleave(times))
)
end
defnp contrastive_rank(
context_hidden_state,
hidden_state,
length,
top_k_scores,
penalty_alpha,
top_k
) do
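    # Contrastive search scores each of the top-k candidate tokens as
    #
    #     (1 - penalty_alpha) * probability - penalty_alpha * degeneration_penalty
    #
    # where the degeneration penalty is the maximum cosine similarity
    # between the candidate hidden state and the hidden states of the
    # previous tokens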
similarity_matrix =
context_hidden_state
      |> Utils.Nx.cosine_similarity(hidden_state, batched?: true)
# hidden_state has sequence length of 1, so the batch of similarity
# matrices has shape {batch_size * top_k, max_length, 1} and we
# flatten out the last dimension
|> Nx.squeeze(axes: [-1])
# context_hidden_state includes placeholder values for tokens up
# to max_length, so we need to ignore these
current_sequence? = Nx.iota(Nx.shape(similarity_matrix), axis: -1) < length
degeneration_penalty =
current_sequence?
|> Nx.select(similarity_matrix, Nx.Constants.neg_infinity())
|> Nx.reduce_max(axes: [-1])
contrastive_score =
(1.0 - penalty_alpha) * Nx.flatten(top_k_scores) - penalty_alpha * degeneration_penalty
contrastive_score
|> Nx.reshape({:auto, top_k})
|> Nx.argmax(axis: -1)
end
# Multinomial sampling
defnp sampling(
inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
opts \\ []
) do
max_length = opts[:max_length]
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]
prng_key = opts[:prng_key]
{sequences, length = input_length, finished?} =
init_sequences(decoder_input_ids, max_length, pad_token_id)
# The loop works with inputs of length 1, so if the initial input
# is longer, we make the initial pass outside
{sequences, length, finished?, inputs, prng_key} =
if length > 1 do
sampling_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
prng_key,
logits_processor_fun,
update_inputs_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id
)
else
{sequences, length, finished?, inputs, prng_key}
end
{sequences, _length, _finished?, _inputs, _params, _key} =
while {sequences, length, finished?, inputs, params, prng_key},
continue?(finished?, length, max_length) do
{sequences, length, finished?, inputs, prng_key} =
sampling_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
prng_key,
logits_processor_fun,
update_inputs_fun,
pad_token_id: pad_token_id,
eos_token_id: eos_token_id
)
{sequences, length, finished?, inputs, params, prng_key}
end
sequences
end
defnp sampling_step(
sequences,
length,
finished?,
inputs,
input_length,
predict_fun,
params,
prng_key,
logits_processor_fun,
update_inputs_fun,
opts \\ []
) do
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]
key = Nx.Random.split(prng_key)
{key, prng_key} = {key[1], key[0]}
outputs = predict_fun.(params, inputs)
logits = outputs.logits[[.., -1]]
logits =
logits_processor_fun.(logits, %{
sequences: sequences,
length: length,
input_length: input_length
})
    scores = Axon.Activations.softmax(logits, axis: -1)
token_id = batched_choice(key, scores)
{sequences, length, finished?} =
update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))
{sequences, length, finished?, inputs, prng_key}
end
deftransformp batched_choice(key, scores) do
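    # Nx.Random.choice/4 samples from a single distribution, so we
    # vectorize the keys and the score rows over a :batch axis to draw
    # one token per batch item in a single call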
{batch_size, vocab_size} = Nx.shape(scores)
vocab = Nx.iota({vocab_size})
keys = Nx.Random.split(key, parts: batch_size)
key = Nx.vectorize(keys, :batch)
probabilities = Nx.vectorize(scores, :batch)
{tokens, _} = Nx.Random.choice(key, vocab, probabilities, samples: 1)
tokens
|> Nx.squeeze()
|> Nx.devectorize()
end
# Serving
@doc false
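  # Builds a serving for sequence generation. This is the implementation
  # behind higher-level functions such as `Bumblebee.Text.generation/4`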
def generation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do
opts =
Keyword.validate!(opts, [
:seed,
:compile,
defn_options: [],
preallocate_params: false,
stream: false
])
%{model: model, params: params, spec: spec} = model_info
Shared.validate_architecture!(spec, [
:for_conditional_generation,
:for_causal_language_modeling
])
preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]
compile =
if compile = opts[:compile] do
compile
|> Keyword.validate!([:batch_size, :sequence_length])
|> Shared.require_options!([:batch_size, :sequence_length])
end
batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]
generate_fun = build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))
batch_keys = Shared.sequence_batch_keys(sequence_length)
Nx.Serving.new(
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)
generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
}
[params, inputs]
end)
fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
end
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
if opts[:stream] do
Shared.validate_input_for_stream!(input)
end
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)
inputs =
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
return_token_type_ids: false
)
end)
batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
{batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? ->
decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
decoded
|> Enum.map(&%{results: [%{text: &1}]})
|> Shared.normalize_output(multi?)
end)
|> maybe_stream(opts[:stream], tokenizer)
end
defp maybe_stream(serving, false, _tokenizer), do: serving
defp maybe_stream(serving, true, tokenizer) do
serving
|> Nx.Serving.streaming(hooks: [:token])
|> Nx.Serving.client_postprocessing(fn stream, false = _multi? ->
Stream.transform(stream, %{tokens: [], consumed_size: 0, finished?: false}, fn
_event, %{finished?: true} = state ->
{:halt, state}
{:token, {token_id, finished?}}, state ->
token_id = Nx.to_number(token_id[0])
finished? = Nx.to_number(finished?[0]) == 1
state = %{state | tokens: state.tokens ++ [token_id], finished?: finished?}
chunk = pending_chunk(tokenizer, state)
cond do
            # When the sequence finishes early or we reach a newline, we
            # flush the pending tokens
finished? or String.ends_with?(chunk, "\n") ->
{[chunk], %{state | tokens: [], consumed_size: 0}}
# CJK characters are tokenized atomically, so we can emit
# the chunk
chunk != "" and cjk_codepoint?(last_codepoint(chunk)) ->
state = update_in(state.consumed_size, &(&1 + byte_size(chunk)))
{[chunk], state}
            # Emit the chunk up to the last space. We need to keep the
            # tokens, because certain tokenizers do not encode whitespace
            # in tokens and add a space based on the previous tokens
space_idx = find_last_occurrence(chunk, " ") ->
if space_idx > 0 do
chunk = binary_slice(chunk, 0, space_idx)
state = update_in(state.consumed_size, &(&1 + space_idx))
{[chunk], state}
else
{[], state}
end
true ->
{[], state}
end
{:done, _, _}, state ->
chunk = pending_chunk(tokenizer, state)
if chunk == "" do
{:halt, state}
else
{[chunk], %{state | tokens: [], consumed_size: 0}}
end
end)
end)
end
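  # Decodes all pending tokens and returns the not-yet-emitted suffix.
  # We always decode the full token list, since for some tokenizers the
  # decoded text depends on the preceding tokens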
defp pending_chunk(tokenizer, state) do
text = Bumblebee.Tokenizer.decode(tokenizer, state.tokens)
binary_slice(text, state.consumed_size..-1//1)
end
defp find_last_occurrence(string, pattern) do
case :binary.matches(string, pattern) do
[] -> nil
matches -> matches |> List.last() |> elem(0)
end
end
defp last_codepoint(<<codepoint::utf8>>), do: codepoint
defp last_codepoint(<<_::utf8, rest::binary>>), do: last_codepoint(rest)
defp cjk_codepoint?(codepoint) do
# The specific ranges originated in [1] and are generally mirrored
# in other tokenizers using WordPiece. Also see [2].
#
# [1]: https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/tokenization.py#L264-L284
# [2]: https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/multilingual.md#tokenization
codepoint in 0x4E00..0x9FFF or
codepoint in 0x3400..0x4DBF or
codepoint in 0x20000..0x2A6DF or
codepoint in 0x2A700..0x2B73F or
codepoint in 0x2B740..0x2B81F or
codepoint in 0x2B820..0x2CEAF or
codepoint in 0xF900..0xFAFF or
codepoint in 0x2F800..0x2FA1F
end
end