defmodule Bumblebee.Diffusion.StableDiffusion do
@moduledoc """
High-level tasks based on Stable Diffusion.
"""
import Nx.Defn

  alias Bumblebee.Utils
alias Bumblebee.Shared

  @type text_to_image_input ::
String.t()
| %{
:prompt => String.t(),
optional(:negative_prompt) => String.t() | nil,
optional(:seed) => integer() | nil
}

  @type text_to_image_output :: %{results: list(text_to_image_result())}
@type text_to_image_result :: %{:image => Nx.Tensor.t(), optional(:is_safe) => boolean()}

  @doc ~S"""
  Builds serving for prompt-driven image generation.

  The serving accepts `t:text_to_image_input/0` and returns
  `t:text_to_image_output/0`. A list of inputs is also supported.

  You can specify a `:safety_checker` model to automatically detect
  when a generated image is offensive or harmful and filter it out.

  ## Options

    * `:safety_checker` - the safety checker model info map. When a
      safety checker is used, each output entry has an additional
      `:is_safe` property and unsafe images are automatically zeroed.
      Make sure to also set `:safety_checker_featurizer`

    * `:safety_checker_featurizer` - the featurizer to use to preprocess
      the safety checker input images

    * `:num_steps` - the number of denoising steps. More denoising
      steps usually lead to higher image quality at the expense of
      slower inference. Defaults to `50`

    * `:num_images_per_prompt` - the number of images to generate for
      each prompt. Defaults to `1`

    * `:guidance_scale` - the scale used for classifier-free diffusion
      guidance. Higher guidance scale makes the generated images more
      closely reflect the text prompt. This parameter corresponds to
      $\omega$ in Equation (2) of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf).
      Defaults to `7.5`

    * `:compile` - compiles all computations for predefined input shapes
      during serving initialization. Should be a keyword list with the
      following keys:

      * `:batch_size` - the maximum batch size of the input. Inputs
        are optionally padded to always match this batch size

      * `:sequence_length` - the maximum input sequence length. Input
        sequences are always padded/truncated to match that length

      It is advised to set this option in production and also configure
      a defn compiler using `:defn_options` to maximally reduce inference
      time.

    * `:defn_options` - the options for JIT compilation. Defaults to `[]`

    * `:preallocate_params` - when `true`, explicitly allocates params
      on the device configured by `:defn_options`. You may want to set
      this option when using partitioned serving, to allocate params
      on each of the devices. When using this option, you should first
      load the parameters onto the host. This can be done by passing
      `backend: {EXLA.Backend, client: :host}` to `load_model/1` and friends.
      Defaults to `false`

  ## Examples
      repository_id = "CompVis/stable-diffusion-v1-4"

      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
      {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
      {:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})
      {:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)
      {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
      {:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
      {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})

      serving =
        Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
          num_steps: 20,
          num_images_per_prompt: 2,
          safety_checker: safety_checker,
          safety_checker_featurizer: featurizer,
          compile: [batch_size: 1, sequence_length: 60],
          defn_options: [compiler: EXLA]
        )

      prompt = "numbat in forest, detailed, digital art"

      Nx.Serving.run(serving, prompt)
      #=> %{
      #=>   results: [
      #=>     %{
      #=>       image: #Nx.Tensor<
      #=>         u8[512][512][3]
      #=>         ...
      #=>       >,
      #=>       is_safe: true
      #=>     },
      #=>     %{
      #=>       image: #Nx.Tensor<
      #=>         u8[512][512][3]
      #=>         ...
      #=>       >,
      #=>       is_safe: true
      #=>     }
      #=>   ]
      #=> }
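
  To steer generation away from unwanted features and make runs
  reproducible, pass a map input with a negative prompt and a fixed
  seed (the prompts below are only illustrative):

      Nx.Serving.run(serving, %{
        prompt: "numbat in forest, detailed, digital art",
        negative_prompt: "blurry, low quality",
        seed: 0
      })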
"""
@spec text_to_image(
Bumblebee.model_info(),
Bumblebee.model_info(),
Bumblebee.model_info(),
Bumblebee.Tokenizer.t(),
Bumblebee.Scheduler.t(),
keyword()
) :: Nx.Serving.t()
def text_to_image(encoder, unet, vae, tokenizer, scheduler, opts \\ []) do
opts =
Keyword.validate!(opts, [
:safety_checker,
:safety_checker_featurizer,
:compile,
num_steps: 50,
num_images_per_prompt: 1,
guidance_scale: 7.5,
defn_options: [],
preallocate_params: false
])
    safety_checker = opts[:safety_checker]
    safety_checker_featurizer = opts[:safety_checker_featurizer]
    num_steps = opts[:num_steps]
    num_images_per_prompt = opts[:num_images_per_prompt]
    preallocate_params = opts[:preallocate_params]
    defn_options = opts[:defn_options]

    if safety_checker != nil and safety_checker_featurizer == nil do
      raise ArgumentError, "got :safety_checker but no :safety_checker_featurizer was specified"
    end

    safety_checker? = safety_checker != nil

    compile =
      if compile = opts[:compile] do
        compile
        |> Keyword.validate!([:batch_size, :sequence_length])
        |> Shared.require_options!([:batch_size, :sequence_length])
      end

    batch_size = compile[:batch_size]
    sequence_length = compile[:sequence_length]

    tokenizer =
      Bumblebee.configure(tokenizer,
        length: sequence_length,
        return_token_type_ids: false,
        return_attention_mask: false
      )

    {_, encoder_predict} = Axon.build(encoder.model)
    {_, vae_predict} = Axon.build(vae.model)
    {_, unet_predict} = Axon.build(unet.model)

    scheduler_init = &Bumblebee.scheduler_init(scheduler, num_steps, &1, &2)
    scheduler_step = &Bumblebee.scheduler_step(scheduler, &1, &2, &3)

    image_fun =
      &text_to_image_impl(
        encoder_predict,
        &1,
        unet_predict,
        &2,
        vae_predict,
        &3,
        scheduler_init,
        scheduler_step,
        &4,
        num_images_per_prompt: opts[:num_images_per_prompt],
        latents_sample_size: unet.spec.sample_size,
        latents_in_channels: unet.spec.in_channels,
        guidance_scale: opts[:guidance_scale]
      )

    safety_checker_fun =
      if safety_checker do
        {_, predict_fun} = Axon.build(safety_checker.model)

        fn params, input ->
          input = Bumblebee.Featurizer.process_batch(safety_checker_featurizer, input)
          predict_fun.(params, input)
        end
      end

    # Note that all of these are copied when using serving as a process
init_args = [
{image_fun, safety_checker_fun},
encoder.params,
unet.params,
vae.params,
{safety_checker?, safety_checker[:params]},
safety_checker_featurizer,
{compile != nil, batch_size, sequence_length},
num_images_per_prompt,
preallocate_params
]

    Nx.Serving.new(
fn defn_options -> apply(&init/10, init_args ++ [defn_options]) end,
defn_options
)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer))
    |> Nx.Serving.client_postprocessing(&client_postprocessing(&1, &2, safety_checker?))
end

  defp init(
{image_fun, safety_checker_fun},
encoder_params,
unet_params,
vae_params,
{safety_checker?, safety_checker_params},
safety_checker_featurizer,
{compile?, batch_size, sequence_length},
num_images_per_prompt,
preallocate_params,
defn_options
) do
encoder_params = Shared.maybe_preallocate(encoder_params, preallocate_params, defn_options)
unet_params = Shared.maybe_preallocate(unet_params, preallocate_params, defn_options)
vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options)

    image_fun =
Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
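        # Template inputs describing the compiled shapes: a conditional/
        # unconditional token id pair per prompt plus one seed per prompt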
inputs = %{
"conditional_and_unconditional" => %{
"input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32)
},
"seed" => Nx.template({batch_size}, :s64)
}
[encoder_params, unet_params, vae_params, inputs]
end)

    safety_checker_fun =
safety_checker_fun &&
Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn ->
inputs =
Bumblebee.Featurizer.batch_template(
safety_checker_featurizer,
batch_size * num_images_per_prompt
)
[safety_checker_params, inputs]
end)

    safety_checker_params =
safety_checker_params &&
Shared.maybe_preallocate(safety_checker_params, preallocate_params, defn_options)

    fn inputs ->
      inputs = Shared.maybe_pad(inputs, batch_size)

      image = image_fun.(encoder_params, unet_params, vae_params, inputs)

      output =
        if safety_checker? do
          inputs = Bumblebee.Featurizer.process_input(safety_checker_featurizer, image)
          outputs = safety_checker_fun.(safety_checker_params, inputs)
          %{image: image, is_unsafe: outputs.is_unsafe}
        else
          %{image: image}
        end

      output
      |> Utils.Nx.composite_unflatten_batch(Utils.Nx.batch_size(inputs))
      |> Shared.serving_post_computation()
    end
end

  defp client_preprocessing(input, tokenizer) do
{inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1)

    seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(backend: Nx.BinaryBackend)

    # Note: we need to tokenize all sequences together, so that they
    # are padded to the same length (when no :length is configured)
    prompts = Enum.flat_map(inputs, &[&1.prompt, &1.negative_prompt])

    prompt_pairs =
      Nx.with_default_backend(Nx.BinaryBackend, fn ->
        inputs = Bumblebee.apply_tokenizer(tokenizer, prompts)
        Utils.Nx.composite_unflatten_batch(inputs, Nx.axis_size(seed, 0))
      end)

    inputs = %{"conditional_and_unconditional" => prompt_pairs, "seed" => seed}

    {Nx.Batch.concatenate([inputs]), multi?}
end

  defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do
for outputs <- Utils.Nx.batch_to_list(outputs) do
results =
for outputs = %{image: image} <- Utils.Nx.batch_to_list(outputs) do
if safety_checker? do
if Nx.to_number(outputs.is_unsafe) == 1 do
%{image: zeroed(image), is_safe: false}
else
%{image: image, is_safe: true}
end
else
%{image: image}
end
end
%{results: results}
end
|> Shared.normalize_output(multi?)
end
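
  # Replaces an unsafe image with an all-zero (black) image of the same
  # shape and type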
defp zeroed(tensor) do
0
|> Nx.tensor(type: Nx.type(tensor), backend: Nx.BinaryBackend)
|> Nx.broadcast(Nx.shape(tensor))
end

  defnp text_to_image_impl(
encoder_predict,
encoder_params,
unet_predict,
unet_params,
vae_predict,
vae_params,
scheduler_init,
scheduler_step,
inputs,
opts \\ []
) do
    num_images_per_prompt = opts[:num_images_per_prompt]
    latents_sample_size = opts[:latents_sample_size]
    latents_in_channels = opts[:latents_in_channels]
    guidance_scale = opts[:guidance_scale]

seed = inputs["seed"]

    inputs =
inputs["conditional_and_unconditional"]
# Transpose conditional and unconditional to separate blocks
|> composite_transpose_leading()
|> Utils.Nx.composite_flatten_batch()

    %{hidden_state: text_embeddings} = encoder_predict.(encoder_params, inputs)
{_twice_batch_size, sequence_length, hidden_size} = Nx.shape(text_embeddings)
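
    # Replicate the text embeddings once per image generated for the prompt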
text_embeddings =
text_embeddings
|> Nx.new_axis(1)
|> Nx.tile([1, num_images_per_prompt, 1, 1])
|> Nx.reshape({:auto, sequence_length, hidden_size})
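
    # Turn each seed into a PRNG key and split it into one key per image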
prng_key =
seed
|> Nx.vectorize(:batch)
|> Nx.Random.key()
|> Nx.Random.split(parts: num_images_per_prompt)
|> Nx.devectorize()
|> Nx.flatten(axes: [0, 1])
|> Nx.vectorize(:batch)

    {latents, prng_key} =
Nx.Random.normal(prng_key,
shape: {latents_sample_size, latents_sample_size, latents_in_channels}
)

    {scheduler_state, timesteps} = scheduler_init.(Nx.to_template(latents), prng_key)
latents = Nx.devectorize(latents)
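
    # Iteratively denoise the latents over the scheduler timesteps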
{latents, _} =
while {latents, {scheduler_state, text_embeddings, unet_params}}, timestep <- timesteps do
unet_inputs = %{
"sample" => Nx.concatenate([latents, latents]),
"timestep" => timestep,
"encoder_hidden_state" => text_embeddings
}

        %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs)
{noise_pred_conditional, noise_pred_unconditional} =
split_conditional_and_unconditional(noise_pred)
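
        # Classifier-free guidance: e = e_uncond + w * (e_cond - e_uncond),
        # where w is the guidance scale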
noise_pred =
noise_pred_unconditional +
guidance_scale * (noise_pred_conditional - noise_pred_unconditional)

        {scheduler_state, latents} =
scheduler_step.(
scheduler_state,
Nx.vectorize(latents, :batch),
Nx.vectorize(noise_pred, :batch)
)
latents = Nx.devectorize(latents)
{latents, {scheduler_state, text_embeddings, unet_params}}
end
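
    # Rescale by the inverse of the VAE scaling factor (0.18215) used
    # during training, then decode the latents into an image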
latents = latents * (1 / 0.18215)

    %{sample: image} = vae_predict.(vae_params, latents)

    NxImage.from_continuous(image, -1, 1)
end
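
  # Swaps the two leading axes of each tensor in the container, moving the
  # conditional/unconditional axis in front of the batch axis, so that
  # flattening yields a conditional block followed by an unconditional block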
deftransformp composite_transpose_leading(container) do
Utils.Nx.map(container, fn tensor ->
[first, second | rest] = Nx.axes(tensor)
Nx.transpose(tensor, axes: [second, first | rest])
end)
end
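
  # The leading half of the batch holds conditional predictions and the
  # trailing half unconditional ones (see composite_transpose_leading/1)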
defnp split_conditional_and_unconditional(tensor) do
batch_size = Nx.axis_size(tensor, 0)
half_size = div(batch_size, 2)
{tensor[0..(half_size - 1)//1], tensor[half_size..-1//1]}
end

  defp validate_input(prompt) when is_binary(prompt), do: validate_input(%{prompt: prompt})

  defp validate_input(%{prompt: prompt} = input) do
{:ok,
%{
prompt: prompt,
negative_prompt: input[:negative_prompt] || "",
seed: input[:seed] || :erlang.system_time()
}}
end

  defp validate_input(%{} = input) do
{:error, "expected the input map to have :prompt key, got: #{inspect(input)}"}
end

  defp validate_input(input) do
{:error, "expected either a string or a map, got: #{inspect(input)}"}
end
end