defmodule Bumblebee.Diffusion.PndmScheduler do
options = [
num_train_steps: [
default: 1000,
doc: "the number of diffusion steps used to train the model"
],
beta_schedule: [
default: :linear,
doc: """
the beta schedule type, a mapping from a beta range to a sequence of betas for stepping the model.
One of `:linear`, `:quadratic`, or `:squared_cosine`
"""
],
beta_start: [
default: 0.0001,
doc: "the start value for the beta schedule"
],
beta_end: [
default: 0.02,
doc: "the end value for the beta schedule"
],
alpha_clip_strategy: [
default: :one,
doc: ~S"""
each step $t$ uses the values of $\bar{\alpha}\_t$ and $\bar{\alpha}\_{t-1}$;
however, for $t = 0$ there is no previous alpha. The strategy can be either
`:one` ($\bar{\alpha}\_{t-1} = 1$) or `:alpha_zero` ($\bar{\alpha}\_{t-1} = \bar{\alpha}\_0$)
"""
],
timesteps_offset: [
default: 0,
doc: ~S"""
an offset added to the inference steps. You can use a combination of `timesteps_offset: 1` and
`alpha_clip_strategy: :alpha_zero`, so that the last step $t = 1$ uses $\bar{\alpha}\_1$
and $\bar{\alpha}\_0$, as done in Stable Diffusion
"""
],
reduce_warmup: [
default: false,
doc: """
when `true`, the first few samples are computed using a lower-order linear multi-step method,
rather than the Runge-Kutta method. This results in fewer forward passes of the model
"""
]
]
@moduledoc """
Pseudo numerical methods for diffusion models (PNDMs).

The sampling is based on two numerical methods for solving ODEs: the
Runge-Kutta method (RK) and the linear multi-step method (LMS). The
gradient at each step is computed according to either of these methods;
however, the transfer part (approximating the next sample based on the
current sample and gradient) is non-linear. Because of this property,
the authors of the paper refer to these methods as pseudo numerical
methods, denoted PRK and PLMS respectively.
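
## Usage

A minimal sketch of the sampling loop, where `predict_noise/2` is a
hypothetical stand-in for a forward pass of the denoising model and
`{1, 64, 64, 4}` is an example latent shape:

    scheduler =
      Bumblebee.configure(Bumblebee.Diffusion.PndmScheduler, reduce_warmup: true)

    {state, timesteps} =
      Bumblebee.Diffusion.PndmScheduler.init(scheduler, 50, {1, 64, 64, 4})

    # Start from Gaussian noise
    key = Nx.Random.key(0)
    {sample, _key} = Nx.Random.normal(key, shape: {1, 64, 64, 4})

    {_state, sample} =
      for timestep <- Nx.to_flat_list(timesteps), reduce: {state, sample} do
        {state, sample} ->
          # predict_noise/2 stands in for the denoising model forward pass
          noise = predict_noise(sample, timestep)
          Bumblebee.Diffusion.PndmScheduler.step(scheduler, state, sample, noise)
      end
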
## Configuration

#{Bumblebee.Shared.options_doc(options)}

## References

* [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778)
"""
defstruct Bumblebee.Shared.option_defaults(options)
@behaviour Bumblebee.Scheduler
@behaviour Bumblebee.Configurable
import Nx.Defn
alias Bumblebee.Diffusion.SchedulerUtils
@impl true
def config(scheduler, opts) do
Bumblebee.Shared.put_config_attrs(scheduler, opts)
end
@impl true
def init(scheduler, num_steps, sample_shape) do
timesteps =
timesteps(
scheduler.num_train_steps,
num_steps,
scheduler.timesteps_offset,
scheduler.reduce_warmup
)
alpha_bars = init_parameters(scheduler: scheduler)
empty = Nx.broadcast(0.0, sample_shape)
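# The state tracks the four most recent noise predictions (needed by
# the multi-step formulas), plus the current sample and the accumulated
# noise used by the multi-iteration warmup steps.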
state = %{
timesteps: timesteps,
timestep_gap: div(scheduler.num_train_steps, num_steps),
alpha_bars: alpha_bars,
iteration: 0,
recent_noise: empty |> List.duplicate(4) |> List.to_tuple(),
current_sample: empty,
noise_prime: empty
}
{state, timesteps}
end
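# Computes the cumulative alpha products, alpha_bar_t = prod(1 - beta_i)
# for i <= t, which the sampling formulas are expressed in terms of.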
defnp init_parameters(opts \\ []) do
%{
beta_start: beta_start,
beta_end: beta_end,
beta_schedule: beta_schedule,
num_train_steps: num_train_steps
} = opts[:scheduler]
betas =
SchedulerUtils.beta_schedule(beta_schedule, num_train_steps,
start: beta_start,
end: beta_end
)
alphas = 1 - betas
Nx.cumulative_product(alphas)
end
deftransformp timesteps(num_train_steps, num_steps, offset, reduce_warmup) do
# Note that there are more timesteps than `num_steps`.
# That's because each timestep corresponds to a single forward pass
# of the denoising model and the first few steps require multiple such
# passes. In other words, the timesteps list is used for subsequent
# model calls, while the actual sampling happens at the same timesteps
# as with DDIM.
timestep_gap = div(num_train_steps, num_steps)
ddim_timesteps = SchedulerUtils.ddim_timesteps(num_train_steps, num_steps, offset)
if reduce_warmup do
if num_steps < 2 do
raise ArgumentError,
"expected at least 2 steps when using :reduce_warmup, got: #{inspect(num_steps)}"
end
just_plms_timesteps(ddim_timesteps, timestep_gap)
else
if num_steps < 4 do
raise ArgumentError, "expected at least 4 steps, got: #{inspect(num_steps)}"
end
prk_plms_timesteps(ddim_timesteps, timestep_gap)
end
end
defnp prk_plms_timesteps(ddim_timesteps, timestep_gap) do
if Nx.size(ddim_timesteps) < 4 do
prk_timesteps(ddim_timesteps, timestep_gap)
else
prk_timesteps = prk_timesteps(ddim_timesteps[0..2//1], timestep_gap)
plms_timesteps = ddim_timesteps[3..-1//1]
Nx.concatenate([prk_timesteps, plms_timesteps])
end
end
defnp prk_timesteps(timesteps, timestep_gap) do
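# Expand each timestep t into the four Runge-Kutta sub-timesteps
# t, t - gap/2, t - gap/2, t - gap using a broadcasted subtraction.
# For example, with a gap of 100 the timestep 900 expands to
# [900, 850, 850, 800].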
deltas = Nx.stack([0, div(timestep_gap, 2), div(timestep_gap, 2), timestep_gap])
timesteps
|> Nx.reshape({:auto, 1})
|> Nx.subtract(Nx.reshape(deltas, {1, :auto}))
|> Nx.flatten()
end
defnp just_plms_timesteps(ddim_timesteps, timestep_gap) do
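# The first result takes two model forward passes, at t and t - gap,
# so the first DDIM timestep is followed by an extra t - gap entry.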
leading_timesteps = Nx.stack([ddim_timesteps[0], ddim_timesteps[0] - timestep_gap])
if Nx.size(ddim_timesteps) < 2 do
leading_timesteps
else
Nx.concatenate([leading_timesteps, ddim_timesteps[1..-1//1]])
end
end
@impl true
def step(scheduler, state, sample, prediction) do
do_step(state, sample, prediction, scheduler: scheduler)
end
defnp do_step(state, sample, noise, opts) do
scheduler = opts[:scheduler]
{state, prev} =
if scheduler.reduce_warmup do
step_just_plms(scheduler, state, sample, noise)
else
step_prk_plms(scheduler, state, sample, noise)
end
state = %{state | iteration: state.iteration + 1}
{state, prev}
end
defnp step_prk_plms(scheduler, state, sample, noise) do
# This is the version from the original paper [1], specifically F-PNDM.
# It uses the Runge-Kutta method to compute the first 3 results (each
# requiring 4 iterations).
#
# [1]: https://arxiv.org/abs/2202.09778
if state.iteration < 12 do
step_prk(scheduler, state, sample, noise)
else
step_plms(scheduler, state, sample, noise)
end
end
defnp step_just_plms(scheduler, state, sample, noise) do
# This alternative version is based on the paper; however, instead of
# the Runge-Kutta method, it uses lower-order linear multi-step methods
# for computing the first 3 results (2, 1, 1 iterations respectively).
# For the original implementation see [1].
#
# [1]: https://github.com/CompVis/latent-diffusion/pull/51
if state.iteration < 4 do
step_warmup_plms(scheduler, state, sample, noise)
else
step_plms(scheduler, state, sample, noise)
end
end
# Note on notation
#
# The paper denotes the sample as x_t, the noise as e_t, the model as
# eps, and the prev_sample function as phi. The superscripts on x_t and
# e_t translate to consecutive iterations, since we have one iteration
# per model forward pass (the eps function). We keep track of x_t as
# current_sample, and noise_prime corresponds to e_t prime.
defnp step_prk(scheduler, state, sample, noise) do
# See Equation (13)
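#
# Each result takes four model calls. The noise predictions are
# accumulated with the classic Runge-Kutta weights, that is
# noise_prime = e1 / 6 + e2 / 3 + e3 / 3 + e4 / 6, and the final
# sub-step applies the transfer using this combined noise.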
%{noise_prime: noise_prime, current_sample: current_sample} = state
rk_step_number = rem(state.iteration, 4)
state =
if rk_step_number == 0 do
store_noise(state, noise)
else
state
end
{noise_prime, current_sample, noise} =
cond do
rk_step_number == 0 ->
noise_prime = noise_prime + noise / 6
{noise_prime, sample, noise}
rk_step_number == 1 ->
noise_prime = noise_prime + noise / 3
{noise_prime, current_sample, noise}
rk_step_number == 2 ->
noise_prime = noise_prime + noise / 3
{noise_prime, current_sample, noise}
true ->
noise_prime = noise_prime + noise / 6
{Nx.broadcast(0.0, noise_prime), current_sample, noise_prime}
end
state = %{state | current_sample: current_sample, noise_prime: noise_prime}
timestep = state.timesteps[state.iteration - rk_step_number]
diff = if(rk_step_number < 2, do: div(state.timestep_gap, 2), else: state.timestep_gap)
prev_timestep = timestep - diff
prev_sample = prev_sample(scheduler, state, current_sample, noise, timestep, prev_timestep)
{state, prev_sample}
end
defnp step_warmup_plms(scheduler, state, sample, noise) do
# The first two iterations use Equation (22), the third iteration uses
# Equation (23), and the fourth iteration uses third-order LMS in the
# same spirit.
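#
# Iteration 1 averages the noise predictions from the two model calls,
# while the (3, -1) / 2 and (23, -16, 5) / 12 weights below are the
# second- and third-order Adams-Bashforth coefficients.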
%{current_sample: current_sample} = state
state =
if state.iteration != 1 do
store_noise(state, noise)
else
state
end
{current_sample, noise} =
cond do
state.iteration == 0 ->
{sample, noise}
state.iteration == 1 ->
noise_prime = (noise + elem(state.recent_noise, 0)) / 2
{current_sample, noise_prime}
state.iteration == 2 ->
noise_prime = (3 * elem(state.recent_noise, 0) - elem(state.recent_noise, 1)) / 2
{sample, noise_prime}
true ->
noise_prime =
(23 * elem(state.recent_noise, 0) - 16 * elem(state.recent_noise, 1) +
5 * elem(state.recent_noise, 2)) / 12
{sample, noise_prime}
end
state = %{state | current_sample: current_sample}
timestep =
if state.iteration == 1 do
state.timesteps[state.iteration - 1]
else
state.timesteps[state.iteration]
end
prev_timestep = timestep - state.timestep_gap
prev_sample = prev_sample(scheduler, state, current_sample, noise, timestep, prev_timestep)
{state, prev_sample}
end
defnp step_plms(scheduler, state, sample, noise) do
# See Equation (12)
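#
# The (55, -59, 37, -9) / 24 weights are the classic fourth-order
# Adams-Bashforth coefficients, applied to the four most recent noise
# predictions.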
state = store_noise(state, noise)
noise =
(55 * elem(state.recent_noise, 0) - 59 * elem(state.recent_noise, 1) +
37 * elem(state.recent_noise, 2) - 9 * elem(state.recent_noise, 3)) / 24
timestep = state.timesteps[state.iteration]
prev_timestep = timestep - state.timestep_gap
prev_sample = prev_sample(scheduler, state, sample, noise, timestep, prev_timestep)
{state, prev_sample}
end
defnp prev_sample(scheduler, state, sample, noise, timestep, prev_timestep) do
# See Equation (11)
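#
# The transfer is
#
#   x_prev = sqrt(alpha_bar_prev / alpha_bar_t) * x
#            - (alpha_bar_prev - alpha_bar_t) * noise / denom
#
# with denom = alpha_bar_t * sqrt(1 - alpha_bar_prev) +
#   sqrt(alpha_bar_t * (1 - alpha_bar_t) * alpha_bar_prev)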
alpha_bar_t = state.alpha_bars[timestep]
alpha_bar_t_prev =
if prev_timestep >= 0 do
state.alpha_bars[prev_timestep]
else
case scheduler.alpha_clip_strategy do
:one -> 1.0
:alpha_zero -> state.alpha_bars[0]
end
end
sample_coeff = (alpha_bar_t_prev / alpha_bar_t) ** 0.5
noise_coeff = alpha_bar_t_prev - alpha_bar_t
noise_denom_coeff =
alpha_bar_t * (1 - alpha_bar_t_prev) ** 0.5 +
(alpha_bar_t * (1 - alpha_bar_t) * alpha_bar_t_prev) ** 0.5
sample_coeff * sample - noise_coeff * noise / noise_denom_coeff
end
deftransformp store_noise(state, noise) do
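# Keeps the four most recent noise predictions in a fixed-size tuple,
# newest first, dropping the oldest entry.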
recent_noise =
state.recent_noise
|> Tuple.delete_at(tuple_size(state.recent_noise) - 1)
|> Tuple.insert_at(0, noise)
%{state | recent_noise: recent_noise}
end
defimpl Bumblebee.HuggingFace.Transformers.Config do
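# Maps the Hugging Face diffusers scheduler configuration attributes
# onto the options of this scheduler.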
def load(scheduler, data) do
import Bumblebee.Shared.Converters
opts =
convert!(data,
num_train_steps: {"num_train_timesteps", number()},
beta_schedule: {
"beta_schedule",
mapping(%{
"linear" => :linear,
"scaled_linear" => :quadratic,
"squaredcos_cap_v2" => :squared_cosine
})
},
beta_start: {"beta_start", number()},
beta_end: {"beta_end", number()},
alpha_clip_strategy: {
"set_alpha_to_one",
mapping(%{true => :one, false => :alpha_zero})
},
timesteps_offset: {"steps_offset", number()},
reduce_warmup: {"skip_prk_steps", boolean()}
)
@for.config(scheduler, opts)
end
end
end