lib/bumblebee/diffusion/vae_kl.ex

defmodule Bumblebee.Diffusion.VaeKl do
  alias Bumblebee.Shared

  options = [
    sample_size: [
      default: 32,
      doc: "the size of the input spatial dimensions"
    ],
    in_channels: [
      default: 3,
      doc: "the number of channels in the input"
    ],
    out_channels: [
      default: 3,
      doc: "the number of channels in the output"
    ],
    latent_channels: [
      default: 4,
      doc: "the number of channels in the latent space"
    ],
    hidden_sizes: [
      default: [64],
      doc: "the dimensionality of hidden layers in each upsample/downsample block"
    ],
    depth: [
      default: 1,
      doc: "the number of residual blocks in each upsample/downsample block"
    ],
    down_block_types: [
      default: [:down_block],
      doc: "a list of downsample block types. Currently the only supported type is `:down_block`"
    ],
    up_block_types: [
      default: [:up_block],
      doc: "a list of upsample block types. Currently the only supported type is `:up_block`"
    ],
    activation: [
      default: :silu,
      doc: "the activation function"
    ]
  ]

  @moduledoc """
  Variational autoencoder (VAE) with Kullback–Leibler divergence (KL) loss.

  ## Architectures

    * `:base` - the entire VAE model

    * `:encoder` - just the encoder part of the base model

    * `:decoder` - just the decoder part of the base model

  ## Inputs

    * `"sample"` - `{batch_size, sample_size, sample_size, in_channels}`

      Sample input with two spatial dimensions. Note that in case of
      the `:decoder` model, the input usually has lower dimensionality.

    * `"sample_posterior"` - `{}`

      When `true`, the decoder input is sampled from the encoder output
      distribution. Otherwise the distribution mode value is used instead.
      This input is only relevant for the `:base` model. Defaults to `false`.

  ## Configuration

  #{Shared.options_doc(options)}

  ## References

    * [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114)

  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.Layers
  alias Bumblebee.Diffusion

  @impl true
  def architectures(), do: [:base, :encoder, :decoder]

  @impl true
  def config(spec, opts \\ []) do
    Shared.put_config_attrs(spec, opts)
  end

  @impl true
  def input_template(spec) do
    sample_shape = spec |> sample_shape() |> put_elem(0, 1)
    %{"sample" => Nx.template(sample_shape, :f32)}
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    inputs = inputs(spec)
    sample = core(inputs, spec)
    Layers.output(%{sample: sample})
  end

  def model(%__MODULE__{architecture: :encoder} = spec) do
    inputs = inputs(spec)
    posterior = encoder(inputs["sample"], spec, name: "encoder")
    Layers.output(%{latent_dist: Axon.container(posterior)})
  end

  def model(%__MODULE__{architecture: :decoder} = spec) do
    inputs = inputs(spec)
    sample = decoder(inputs["sample"], spec, name: "decoder")
    Layers.output(%{sample: sample})
  end

  defp inputs(%__MODULE__{architecture: :base} = spec) do
    Bumblebee.Utils.Model.inputs_to_map([
      Axon.input("sample", shape: sample_shape(spec)),
      Axon.input("sample_posterior", shape: {}, optional: true)
    ])
  end

  defp inputs(spec) do
    Bumblebee.Utils.Model.inputs_to_map([
      Axon.input("sample", shape: sample_shape(spec))
    ])
  end

  defp sample_shape(%__MODULE__{architecture: :decoder} = spec) do
    downsample_rate = (length(spec.hidden_sizes) - 1) * 2
    size = div(spec.sample_size, downsample_rate)
    {nil, size, size, spec.latent_channels}
  end

  defp sample_shape(spec) do
    {nil, spec.sample_size, spec.sample_size, spec.in_channels}
  end

  defp core(inputs, spec) do
    x = inputs["sample"]

    sample_posterior =
      Layers.default inputs["sample_posterior"] do
        Axon.constant(Nx.tensor(0, type: {:u, 8}))
      end

    posterior = encoder(x, spec, name: "encoder")

    z =
      Axon.cond(
        sample_posterior,
        &Nx.equal(&1, Nx.tensor(1)),
        sample(posterior, name: "sample"),
        mode(posterior)
      )

    decoder(z, spec, name: "decoder")
  end

  defp encoder(x, spec, opts) do
    name = opts[:name]

    x
    |> Axon.conv(hd(spec.hidden_sizes),
      kernel_size: 3,
      strides: 1,
      padding: [{1, 1}, {1, 1}],
      name: join(name, "input_conv")
    )
    |> down_blocks(spec, name: join(name, "down_blocks"))
    |> mid_block(spec, name: join(name, "mid_block"))
    |> Axon.group_norm(32, epsilon: 1.0e-6, name: join(name, "output_norm"))
    |> Axon.activation(:silu)
    |> Axon.conv(2 * spec.latent_channels,
      kernel_size: 3,
      padding: [{1, 1}, {1, 1}],
      name: join(name, "output_conv")
    )
    # Note: plain VAE doesn't employ vector quantization (VQ), however
    # the reference implementation uses the pre and post quantization
    # convolutions as in VQ-VAE, so we reflect this naming here too
    |> Axon.conv(2 * spec.latent_channels,
      kernel_size: 1,
      name: join(name, "pre_quantization_conv")
    )
    |> diagonal_gaussian_distribution()
  end

  defp decoder(z, spec, opts) do
    name = opts[:name]

    z
    |> Axon.conv(spec.latent_channels,
      kernel_size: 1,
      name: join(name, "post_quantization_conv")
    )
    |> Axon.conv(List.last(spec.hidden_sizes),
      kernel_size: 3,
      strides: 1,
      padding: [{1, 1}, {1, 1}],
      name: join(name, "input_conv")
    )
    |> mid_block(spec, name: join(name, "mid_block"))
    |> up_blocks(spec, name: join(name, "up_blocks"))
    |> Axon.group_norm(32, epsilon: 1.0e-6, name: join(name, "output_norm"))
    |> Axon.activation(:silu)
    |> Axon.conv(spec.out_channels,
      kernel_size: 3,
      padding: [{1, 1}, {1, 1}],
      name: join(name, "output_conv")
    )
  end

  defp down_blocks(sample, spec, opts) do
    name = opts[:name]
    blocks = Enum.zip(spec.hidden_sizes, spec.down_block_types)

    acc = {sample, hd(spec.hidden_sizes)}

    {sample, _} =
      for {{output_channel, down_block_type}, idx} <- Enum.with_index(blocks), reduce: acc do
        {sample, in_channels} ->
          last_block? = idx == length(spec.hidden_sizes) - 1

          block_opts = [
            depth: spec.depth,
            in_channels: in_channels,
            out_channels: output_channel,
            add_downsample: not last_block?,
            activation: spec.activation,
            name: join(name, idx)
          ]

          {do_down_block(down_block_type, sample, block_opts), output_channel}
      end

    sample
  end

  defp do_down_block(:down_block, sample, opts),
    do: down_block(sample, opts)

  defp down_block(hidden_state, opts) do
    in_channels = opts[:in_channels]
    out_channels = opts[:out_channels]
    depth = opts[:depth]
    activation = opts[:activation]
    add_downsample = opts[:add_downsample]
    name = opts[:name]

    hidden_state =
      for idx <- 0..(depth - 1), reduce: hidden_state do
        hidden_state ->
          in_channels = if(idx == 0, do: in_channels, else: out_channels)

          Diffusion.Layers.residual_block(hidden_state, in_channels, out_channels,
            activation: activation,
            name: name |> join("residual_blocks") |> join(idx)
          )
      end

    if add_downsample do
      Diffusion.Layers.downsample_2d(hidden_state, out_channels,
        padding: 0,
        name: join(name, "downsamples.0")
      )
    else
      hidden_state
    end
  end

  defp up_blocks(sample, spec, opts) do
    name = opts[:name]
    reversed_hidden_sizes = Enum.reverse(spec.hidden_sizes)
    blocks = Enum.zip(reversed_hidden_sizes, spec.up_block_types)

    acc = {sample, hd(reversed_hidden_sizes)}

    {sample, _} =
      for {{output_channel, up_block_type}, idx} <- Enum.with_index(blocks), reduce: acc do
        {sample, in_channels} ->
          last_block? = idx == length(spec.hidden_sizes) - 1

          block_opts = [
            depth: spec.depth + 1,
            in_channels: in_channels,
            out_channels: output_channel,
            add_upsample: not last_block?,
            activation: spec.activation,
            name: join(name, idx)
          ]

          {do_up_block(up_block_type, sample, block_opts), output_channel}
      end

    sample
  end

  defp do_up_block(:up_block, sample, opts),
    do: up_block(sample, opts)

  defp up_block(hidden_state, opts) do
    in_channels = opts[:in_channels]
    out_channels = opts[:out_channels]
    depth = opts[:depth]
    activation = opts[:activation]
    add_upsample = opts[:add_upsample]
    name = opts[:name]

    hidden_state =
      for idx <- 0..(depth - 1), reduce: hidden_state do
        hidden_state ->
          in_channels = if(idx == 0, do: in_channels, else: out_channels)

          Diffusion.Layers.residual_block(hidden_state, in_channels, out_channels,
            activation: activation,
            name: name |> join("residual_blocks") |> join(idx)
          )
      end

    if add_upsample do
      Diffusion.Layers.upsample_2d(hidden_state, out_channels, name: join(name, "upsamples.0"))
    else
      hidden_state
    end
  end

  defp mid_block(hidden_state, spec, opts) do
    name = opts[:name]

    in_channels = List.last(spec.hidden_sizes)

    hidden_state
    |> Diffusion.Layers.residual_block(in_channels, in_channels,
      activation: spec.activation,
      name: join(name, "residual_blocks.0")
    )
    |> attention_block(in_channels, num_heads: 1, name: join(name, "attentions.0"))
    |> Diffusion.Layers.residual_block(in_channels, in_channels,
      activation: spec.activation,
      name: join(name, "residual_blocks.1")
    )
  end

  defp attention_block(hidden_state, channels, opts) do
    num_heads = opts[:num_heads]
    name = opts[:name]

    shortcut = hidden_state

    hidden_state =
      hidden_state
      |> Axon.group_norm(32, epsilon: 1.0e-6, name: join(name, "norm"))
      |> Axon.reshape({:batch, :auto, channels})

    {hidden_state, _attention, _self_attention_cache, _position_bias} =
      Layers.Transformer.multi_head_attention(hidden_state, hidden_state, hidden_state,
        num_heads: num_heads,
        hidden_size: channels,
        name: name
      )

    hidden_state =
      Axon.layer(
        fn hidden_state, shortcut, _opts ->
          Nx.reshape(hidden_state, Nx.shape(shortcut))
        end,
        [hidden_state, shortcut]
      )

    Axon.add(hidden_state, shortcut)
  end

  defp diagonal_gaussian_distribution(x) do
    {mean, logvar} = Axon.split(x, 2, axis: -1)
    logvar = Axon.nx(logvar, &Nx.clip(&1, -30.0, 20.0))
    std = Axon.nx(logvar, &Nx.exp(Nx.multiply(0.5, &1)))
    var = Axon.nx(logvar, &Nx.exp/1)
    %{mean: mean, logvar: logvar, std: std, var: var}
  end

  defp sample(posterior, opts) do
    name = opts[:name]

    key_state =
      Axon.param("key", fn _ -> {2} end,
        type: {:u, 32},
        initializer: fn _, _ -> Nx.Random.key(:erlang.system_time()) end
      )

    z =
      Axon.layer(
        fn mean, prng_key, opts ->
          mode = opts[:mode]

          case mode do
            :train ->
              {rand, next_key} = Nx.Random.normal(prng_key, shape: Nx.shape(mean))
              %Axon.StatefulOutput{output: rand, state: %{"key" => next_key}}

            :inference ->
              {rand, _next_key} = Nx.Random.normal(prng_key, shape: Nx.shape(mean))
              rand
          end
        end,
        [posterior.mean, key_state],
        name: name
      )

    z
    |> Axon.multiply(posterior.std)
    |> Axon.add(posterior.mean)
  end

  defp mode(%{mean: mean}) do
    mean
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      import Shared.Converters

      opts =
        convert!(data,
          sample_size: {"sample_size", number()},
          in_channels: {"in_channels", number()},
          out_channels: {"out_channels", number()},
          latent_channels: {"latent_channels", number()},
          hidden_sizes: {"block_out_channels", list(number())},
          depth: {"layers_per_block", number()},
          down_block_types:
            {"down_block_types", list(mapping(%{"DownEncoderBlock2D" => :down_block}))},
          up_block_types: {"up_block_types", list(mapping(%{"UpDecoderBlock2D" => :up_block}))},
          activation: {"act_fn", atom()}
        )

      @for.config(spec, opts)
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    def params_mapping(_spec) do
      block_mapping = %{
        "attentions.{m}.norm" => "attentions.{m}.group_norm",
        "attentions.{m}.query" => "attentions.{m}.query",
        "attentions.{m}.key" => "attentions.{m}.key",
        "attentions.{m}.value" => "attentions.{m}.value",
        "attentions.{m}.output" => "attentions.{m}.proj_attn",
        "residual_blocks.{m}.norm_1" => "resnets.{m}.norm1",
        "residual_blocks.{m}.conv_1" => "resnets.{m}.conv1",
        "residual_blocks.{m}.norm_2" => "resnets.{m}.norm2",
        "residual_blocks.{m}.conv_2" => "resnets.{m}.conv2",
        "residual_blocks.{m}.shortcut.projection" => "resnets.{m}.conv_shortcut",
        "downsamples.{m}.conv" => "downsamplers.{m}.conv",
        "upsamples.{m}.conv" => "upsamplers.{m}.conv"
      }

      blocks_mapping =
        for {target, source} <- block_mapping,
            prefix <- [
              "encoder.down_blocks.{n}",
              "encoder.mid_block",
              "decoder.mid_block",
              "decoder.up_blocks.{n}"
            ],
            do: {prefix <> "." <> target, prefix <> "." <> source},
            into: %{}

      %{
        "encoder.input_conv" => "encoder.conv_in",
        "encoder.output_norm" => "encoder.conv_norm_out",
        "encoder.output_conv" => "encoder.conv_out",
        "encoder.pre_quantization_conv" => "quant_conv",
        "decoder.post_quantization_conv" => "post_quant_conv",
        "decoder.input_conv" => "decoder.conv_in",
        "decoder.output_norm" => "decoder.conv_norm_out",
        "decoder.output_conv" => "decoder.conv_out",
        # TODO: this is never loaded, add support for :ignored source name
        "sample" => "sample"
      }
      |> Map.merge(blocks_mapping)
    end
  end
end