defmodule Bumblebee.Diffusion.LcmScheduler do
options = [
num_train_steps: [
default: 1000,
doc: "the number of diffusion steps used to train the model"
],
beta_schedule: [
default: :quadratic,
doc: """
the beta schedule type, a mapping from a beta range to a sequence of betas for stepping the model.
Either of `:linear`, `:quadratic`, or `:squared_cosine`
"""
],
beta_start: [
default: 0.00085,
doc: "the start value for the beta schedule"
],
beta_end: [
default: 0.012,
doc: "the end value for the beta schedule"
],
prediction_type: [
default: :noise,
doc: """
prediction type of the denoising model. Either of:
* `:noise` (default) - the model predicts the noise of the diffusion process
* `:angular_velocity` - the model predicts velocity in angular parameterization.
See Section 2.4 in [Imagen Video: High Definition Video Generation with Diffusion Models](https://imagen.research.google/video/paper.pdf),
then Section 4 in [Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/pdf/2202.00512.pdf)
and Appendix D
"""
],
alpha_clip_strategy: [
default: :one,
doc: ~S"""
each step $t$ uses the values of $\bar{\alpha}\_t$ and $\bar{\alpha}\_{t-1}$,
however for $t = 0$ there is no previous alpha. The strategy can be either
`:one` ($\bar{\alpha}\_{t-1} = 1$) or `:alpha_zero` ($\bar{\alpha}\_{t-1} = \bar{\alpha}\_0$)
"""
],
clip_denoised_sample: [
default: false,
doc: """
whether to clip the predicted denoised sample ($x_0$ in Equation (12)) into $[-1, 1]$
for numerical stability
"""
],
num_original_steps: [
default: 50,
doc: ~S"""
the number of denoising steps used during Latent Consistency Distillation (LCD).
The LCD procedure distills a base diffusion model, but instead of sampling all
`:num_train_steps` it skips steps and uses another scheduler accordingly. See
Section 4.3
"""
],
boundary_condition_timestep_scale: [
default: 10.0,
doc: ~S"""
the scaling factor used in the consistency function coefficients. In the original
LCM implementation the authors use the formulation
$$
c_{skip}(t) = \frac{\sigma_{data}^2}{(st)^2 + \sigma_{data}^2}, \quad
c_{out}(t) = \frac{st}{\sqrt{(st)^2 + \sigma_{data}^2}}
$$
where $\sigma_{data} = 0.5$ and $s$ is the scaling factor. Increasing the scaling
factor will decrease approximation error, although the approximation error at the
default of `10.0` is already pretty small
"""
]
]
@moduledoc """
Latent Consistency Model (LCM) sampling.
This sampling method should be used in combination with LCM. LCM is
a model distilled from a regular diffusion model to predict the
final denoised sample in a single step. The sample quality can be
improved by alternating a couple denoising and noise injection
steps (multi-step sampling), as per Appendix B.
## Configuration
#{Bumblebee.Shared.options_doc(options)}
## References
* [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378)
* [Consistency Models](https://arxiv.org/pdf/2303.01469.pdf)
"""
defstruct Bumblebee.Shared.option_defaults(options)
@behaviour Bumblebee.Scheduler
@behaviour Bumblebee.Configurable
import Nx.Defn
alias Bumblebee.Diffusion.SchedulerUtils
@impl true
def config(scheduler, opts) do
Bumblebee.Shared.put_config_attrs(scheduler, opts)
end
@impl true
def init(scheduler, num_steps, _sample_template, prng_key, opts \\ []) do
opts = Keyword.validate!(opts, [:strength])
strength = Keyword.get(opts, :strength, 1.0)
timesteps =
timesteps(num_steps, scheduler.num_original_steps, scheduler.num_train_steps, strength)
{alpha_bars} = init_parameters(scheduler: scheduler)
state = %{
timesteps: timesteps,
alpha_bars: alpha_bars,
iteration: 0,
prng_key: prng_key
}
{state, timesteps}
end
deftransformp timesteps(num_steps, num_original_steps, num_train_steps, strength) do
skipping_step_k = div(num_train_steps, num_original_steps)
# Original steps used during Latent Consistency Distillation
original_timesteps =
{floor(num_original_steps * strength)}
|> Nx.iota()
|> Nx.add(1)
|> Nx.multiply(skipping_step_k)
|> Nx.subtract(1)
if num_steps > num_train_steps do
raise ArgumentError,
"expected the number of steps to be less or equal to the number of" <>
" training steps (#{num_train_steps}), got: #{num_steps}"
end
if num_steps > num_original_steps do
raise ArgumentError,
"expected the number of steps to be less or equal to the number of" <>
" original steps (#{num_original_steps}), got: #{num_steps}"
end
if num_steps > Nx.size(original_timesteps) do
raise ArgumentError,
"expected the number of steps to be less or equal to num_original_steps * strength" <>
" (#{num_original_steps} * #{strength}). Either reduce the number of steps or" <>
"increase the strength"
end
# We select evenly spaced indices from the original timesteps.
# See the discussion in https://github.com/huggingface/diffusers/pull/5836
indices =
Nx.linspace(0, Nx.size(original_timesteps), n: num_steps, endpoint: false)
|> Nx.floor()
|> Nx.as_type(:s64)
original_timesteps
|> Nx.reverse()
|> Nx.take(indices)
end
defnp init_parameters(opts \\ []) do
%{
beta_start: beta_start,
beta_end: beta_end,
beta_schedule: beta_schedule,
num_train_steps: num_train_steps
} = opts[:scheduler]
betas =
SchedulerUtils.beta_schedule(beta_schedule, num_train_steps,
start: beta_start,
end: beta_end
)
alphas = 1 - betas
{Nx.cumulative_product(alphas)}
end
@impl true
def step(scheduler, state, sample, prediction) do
do_step(state, sample, prediction, scheduler: scheduler)
end
defnp do_step(state, sample, prediction, opts) do
scheduler = opts[:scheduler]
step_index = state.iteration
prev_step_index = state.iteration + 1
timestep = state.timesteps[step_index]
prev_timestep =
if prev_step_index < Nx.size(state.timesteps) do
state.timesteps[prev_step_index]
else
timestep
end
# Note that in the paper alpha_bar_t is denoted as a(t) and
# beta_bar_t is denoted as sigma(t)^2
alpha_bar_t = state.alpha_bars[timestep]
alpha_bar_t_prev =
if prev_timestep >= 0 do
state.alpha_bars[prev_timestep]
else
case scheduler.alpha_clip_strategy do
:one -> 1.0
:alpha_zero -> state.alpha_bars[0]
end
end
beta_bar_t = 1 - alpha_bar_t
beta_bar_t_prev = 1 - alpha_bar_t_prev
# See Appendix D
pred_denoised_sample =
case scheduler.prediction_type do
:noise ->
(sample - Nx.sqrt(beta_bar_t) * prediction) / Nx.sqrt(alpha_bar_t)
:angular_velocity ->
Nx.sqrt(alpha_bar_t) * sample - Nx.sqrt(beta_bar_t) * prediction
end
pred_denoised_sample =
if scheduler.clip_denoised_sample do
Nx.clip(pred_denoised_sample, -1, 1)
else
pred_denoised_sample
end
{c_skip, c_out} =
consistency_model_coefficients(timestep,
boundary_condition_timestep_scale: scheduler.boundary_condition_timestep_scale
)
# See Equation (9)
denoised_sample = c_skip * sample + c_out * pred_denoised_sample
# See Appendix B
#
# We insert additional noise after each but last step. This also
# means no noise is used for one-step sampling
{prev_sample, next_key} =
if state.iteration < Nx.size(state.timesteps) - 1 do
{rand, next_key} = Nx.Random.normal(state.prng_key, shape: Nx.shape(denoised_sample))
out = Nx.sqrt(alpha_bar_t_prev) * denoised_sample + Nx.sqrt(beta_bar_t_prev) * rand
{out, next_key}
else
{denoised_sample, state.prng_key}
end
state = %{state | iteration: state.iteration + 1, prng_key: next_key}
{state, prev_sample}
end
defnp consistency_model_coefficients(timestep, opts) do
# See Appendix C in https://arxiv.org/pdf/2303.01469.pdf
#
# Note that LCM authors use different coefficients for the
# consistency function than the original CM paper. In their
# formulation the timestep is scaled by a constant factor.
boundary_condition_timestep_scale = opts[:boundary_condition_timestep_scale]
sigma_data = 0.5
scaled_timestep = timestep * boundary_condition_timestep_scale
c_skip = sigma_data ** 2 / (scaled_timestep ** 2 + sigma_data ** 2)
c_out = scaled_timestep / Nx.sqrt(scaled_timestep ** 2 + sigma_data ** 2)
{c_skip, c_out}
end
defimpl Bumblebee.HuggingFace.Transformers.Config do
def load(scheduler, data) do
import Bumblebee.Shared.Converters
opts =
convert!(data,
num_train_steps: {"num_train_timesteps", number()},
beta_schedule: {
"beta_schedule",
mapping(%{
"linear" => :linear,
"scaled_linear" => :quadratic,
"squaredcos_cap_v2" => :squared_cosine
})
},
beta_start: {"beta_start", number()},
beta_end: {"beta_end", number()},
prediction_type:
{"prediction_type",
mapping(%{"epsilon" => :noise, "v_prediction" => :angular_velocity})},
alpha_clip_strategy: {
"set_alpha_to_one",
mapping(%{true => :one, false => :alpha_zero})
},
clip_denoised_sample: {"clip_sample", boolean()},
num_original_steps: {"original_inference_steps", number()},
boundary_condition_timestep_scale: {"timestep_scaling", number()}
)
@for.config(scheduler, opts)
end
end
end