defmodule Bumblebee.Diffusion.DdimScheduler do
options = [
num_train_steps: [
default: 1000,
doc: "the number of diffusion steps used to train the model"
],
beta_schedule: [
default: :linear,
doc: """
      the beta schedule type, a mapping from a beta range to a sequence of betas
      for stepping the model. One of `:linear`, `:quadratic`, or `:squared_cosine`
"""
],
beta_start: [
default: 0.0001,
doc: "the start value for the beta schedule"
],
beta_end: [
default: 0.02,
doc: "the end value for the beta schedule"
],
prediction_type: [
default: :noise,
doc: """
prediction type of the denoising model. Either of:
* `:noise` (default) - the model predicts the noise of the diffusion process
* `:angular_velocity` - the model predicts velocity in angular parameterization.
        See Section 2.4 in [Imagen Video: High Definition Video Generation with Diffusion Models](https://imagen.research.google/video/paper.pdf),
        and Section 4 and Appendix D in [Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/pdf/2202.00512.pdf)
"""
],
alpha_clip_strategy: [
default: :one,
doc: ~S"""
each step $t$ uses the values of $\bar{\alpha}\_t$ and $\bar{\alpha}\_{t-1}$,
however for $t = 0$ there is no previous alpha. The strategy can be either
`:one` ($\bar{\alpha}\_{t-1} = 1$) or `:alpha_zero` ($\bar{\alpha}\_{t-1} = \bar{\alpha}\_0$)
"""
],
timesteps_offset: [
default: 0,
doc: ~S"""
an offset added to the inference steps. You can use a combination of `timesteps_offset: 1` and
`alpha_clip_strategy: :alpha_zero`, so that the last step $t = 1$ uses $\bar{\alpha}\_1$
      and $\bar{\alpha}\_0$, as done in Stable Diffusion
"""
],
clip_denoised_sample: [
default: true,
doc: """
whether to clip the predicted denoised sample ($x_0$ in Equation (12)) into $[-1, 1]$
for numerical stability.
"""
],
rederive_noise: [
default: false,
doc: """
whether the noise (output of the denoising model) should be re-derived at each step based on the
predicted denoised sample ($x_0$) and the current sample. This technique is used in OpenAI GLIDE
"""
],
eta: [
default: 0.0,
doc: """
a weight for the noise added in a denoising diffusion step. This scales the value of $\\sigma_t$
in Equation (12) in the original paper, as per Equation (16)
"""
]
]
@moduledoc """
Denoising diffusion implicit models (DDIMs).
This sampling method was proposed as a follow up to the original
denoising diffusion probabilistic models (DDPMs) in order to heavily
reduce the number of steps during inference. DDPMs model the diffusion
process as a Markov chain; DDIMs generalize this considering
non-Markovian diffusion processes that lead to the same objective.
This enables a reverse process with many less samples, as compared
to DDPMs, while using the same denoising model.
DDIMs were shown to be a simple variant of pseudo numerical methods
for diffusion models (PNDMs), see `Bumblebee.Diffusion.PndmScheduler`
and the corresponding paper for more details.
## Configuration
#{Bumblebee.Shared.options_doc(options)}
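
  ## Usage

  A minimal sketch of a sampling loop. Here `denoise_fn` and `noisy_sample`
  are placeholders (they stand in for the denoising model call and the
  initial Gaussian noise respectively), and the sample shape is arbitrary:

      scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler)

      {state, timesteps} =
        Bumblebee.Diffusion.DdimScheduler.init(scheduler, 50, {1, 64, 64, 3})

      {_state, sample} =
        for timestep <- Nx.to_list(timesteps), reduce: {state, noisy_sample} do
          {state, sample} ->
            # The model predicts noise (or velocity) for the current timestep
            prediction = denoise_fn.(sample, timestep)
            Bumblebee.Diffusion.DdimScheduler.step(scheduler, state, sample, prediction)
        end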

  ## References

    * [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)

  """
defstruct Bumblebee.Shared.option_defaults(options)
@behaviour Bumblebee.Scheduler
@behaviour Bumblebee.Configurable
import Nx.Defn
alias Bumblebee.Diffusion.SchedulerUtils
@impl true
def config(scheduler, opts) do
Bumblebee.Shared.put_config_attrs(scheduler, opts)
end
@impl true
def init(scheduler, num_steps, _sample_shape, opts \\ []) do
opts = Keyword.validate!(opts, [:seed])
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
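    # The inference timesteps are spread evenly across the training range;
    # e.g. with num_train_steps: 1000, num_steps: 50 and timesteps_offset: 0
    # this yields 980, 960, ..., 20, 0 (an illustrative sequence, see
    # Bumblebee.Diffusion.SchedulerUtils.ddim_timesteps/3)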
timesteps =
SchedulerUtils.ddim_timesteps(
scheduler.num_train_steps,
num_steps,
scheduler.timesteps_offset
)
{alpha_bars, prng_key} = init_parameters(scheduler: scheduler, seed: seed)
state = %{
timesteps: timesteps,
timestep_gap: div(scheduler.num_train_steps, num_steps),
alpha_bars: alpha_bars,
iteration: 0,
prng_key: prng_key
}
{state, timesteps}
end
defnp init_parameters(opts \\ []) do
%{
beta_start: beta_start,
beta_end: beta_end,
beta_schedule: beta_schedule,
num_train_steps: num_train_steps
} = opts[:scheduler]
seed = opts[:seed]
prng_key = Nx.Random.key(seed)
betas =
SchedulerUtils.beta_schedule(beta_schedule, num_train_steps,
start: beta_start,
end: beta_end
)
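    # alpha_bar_t is the cumulative product of (1 - beta_s) over the training
    # steps, i.e. the overall signal scale retained after t noising steps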
alphas = 1 - betas
{Nx.cumulative_product(alphas), prng_key}
end
@impl true
def step(scheduler, state, sample, prediction) do
do_step(state, sample, prediction, scheduler: scheduler)
end
defnp do_step(state, sample, prediction, opts) do
scheduler = opts[:scheduler]
# See Equation (12)
    # Note that in the paper alpha_t represents a cumulative product,
    # often denoted as alpha_t with a bar on top. We use the explicit
    # alpha_bar_t name for consistency
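    #
    # Equation (12), with z ~ N(0, 1) and sigma_t from Equation (16):
    #
    #   x_{t-1} = sqrt(alpha_bar_{t-1}) * x_0
    #             + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * noise
    #             + sigma_t * z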
timestep = state.timesteps[state.iteration]
prev_timestep = timestep - state.timestep_gap
alpha_bar_t = state.alpha_bars[timestep]
alpha_bar_t_prev =
if prev_timestep >= 0 do
state.alpha_bars[prev_timestep]
else
case scheduler.alpha_clip_strategy do
:one -> 1.0
:alpha_zero -> state.alpha_bars[0]
end
end
{pred_denoised_sample, noise} =
case scheduler.prediction_type do
:noise ->
pred_denoised_sample =
(sample - Nx.sqrt(1 - alpha_bar_t) * prediction) / Nx.sqrt(alpha_bar_t)
{pred_denoised_sample, prediction}
:angular_velocity ->
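          # In the angular (v-prediction) parameterization the model predicts
          # v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0,
          # which we invert below to recover both x_0 and the noise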
pred_denoised_sample =
Nx.sqrt(alpha_bar_t) * sample - Nx.sqrt(1 - alpha_bar_t) * prediction
noise = Nx.sqrt(alpha_bar_t) * prediction + Nx.sqrt(1 - alpha_bar_t) * sample
{pred_denoised_sample, noise}
end
pred_denoised_sample =
if scheduler.clip_denoised_sample do
Nx.clip(pred_denoised_sample, -1, 1)
else
pred_denoised_sample
end
# See Equation (16)
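    #
    #   sigma_t = eta * sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t))
    #                 * sqrt(1 - alpha_bar_t / alpha_bar_{t-1})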
sigma_t =
scheduler.eta *
Nx.sqrt((1 - alpha_bar_t_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_t_prev))
noise =
if scheduler.rederive_noise do
# Re-derive the noise as in GLIDE
(sample - Nx.sqrt(alpha_bar_t) * pred_denoised_sample) / Nx.sqrt(1 - alpha_bar_t)
else
noise
end
pred_sample_direction = Nx.sqrt(1 - alpha_bar_t_prev - Nx.pow(sigma_t, 2)) * noise
prev_sample = Nx.sqrt(alpha_bar_t_prev) * pred_denoised_sample + pred_sample_direction
    {prev_sample, next_key} =
      if scheduler.eta > 0 do
        # When eta > 0 the step is stochastic, so we add fresh standard
        # Gaussian noise scaled by sigma_t
        {z, next_key} =
          Nx.Random.normal(state.prng_key,
            shape: Nx.shape(prev_sample),
            type: Nx.type(prev_sample)
          )

        {prev_sample + sigma_t * z, next_key}
      else
        {prev_sample, state.prng_key}
      end
state = %{state | iteration: state.iteration + 1, prng_key: next_key}
{state, prev_sample}
end
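
  # A Hugging Face diffusers scheduler is configured via scheduler_config.json.
  # For illustration, a Stable Diffusion style configuration maps onto the
  # options above roughly as follows (values are an assumption, not tied to a
  # specific checkpoint):
  #
  #   %{
  #     "num_train_timesteps" => 1000,
  #     "beta_schedule" => "scaled_linear",
  #     "beta_start" => 0.00085,
  #     "beta_end" => 0.012,
  #     "set_alpha_to_one" => false,
  #     "steps_offset" => 1,
  #     "clip_sample" => false
  #   }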
defimpl Bumblebee.HuggingFace.Transformers.Config do
def load(scheduler, data) do
import Bumblebee.Shared.Converters
opts =
convert!(data,
num_train_steps: {"num_train_timesteps", number()},
beta_schedule: {
"beta_schedule",
mapping(%{
"linear" => :linear,
"scaled_linear" => :quadratic,
"squaredcos_cap_v2" => :squared_cosine
})
},
beta_start: {"beta_start", number()},
beta_end: {"beta_end", number()},
prediction_type:
{"prediction_type",
mapping(%{"epsilon" => :noise, "v_prediction" => :angular_velocity})},
alpha_clip_strategy: {
"set_alpha_to_one",
mapping(%{true => :one, false => :alpha_zero})
},
timesteps_offset: {"steps_offset", number()},
clip_denoised_sample: {"clip_sample", boolean()},
rederive_noise: {"use_clipped_model_output", boolean()}
)
@for.config(scheduler, opts)
end
end
end