# lib/bumblebee/diffusion/unet_2d_conditional.ex

defmodule Bumblebee.Diffusion.UNet2DConditional do
  alias Bumblebee.Shared

  options = [
    sample_size: [
      default: 32,
      doc: "the size of the input spatial dimensions"
    ],
    in_channels: [
      default: 4,
      doc: "the number of channels in the input"
    ],
    out_channels: [
      default: 4,
      doc: "the number of channels in the output"
    ],
    center_input_sample: [
      default: false,
      doc: "whether to center the input sample"
    ],
    embedding_flip_sin_to_cos: [
      default: true,
      doc: "whether to flip the sin to cos in the sinusoidal timestep embedding"
    ],
    embedding_frequency_correction_term: [
      default: 0,
      doc: ~S"""
      controls the frequency formula in the timestep sinusoidal embedding. The frequency is computed
      as $\omega_i = \frac{1}{10000^{\frac{i}{n - s}}}$, for $i \in \{0, ..., n-1\}$, where $n$
      is half of the embedding size and $s$ is the shift. Historically, certain implementations of
      sinusoidal embedding used $s=0$, while others used $s=1$.
      """
    ],
    hidden_sizes: [
      default: [320, 640, 1280, 1280],
      doc: "the dimensionality of hidden layers in each upsample/downsample block"
    ],
    depth: [
      default: 2,
      doc: "the number of residual blocks in each upsample/downsample block"
    ],
    down_block_types: [
      default: [
        :cross_attention_down_block,
        :cross_attention_down_block,
        :cross_attention_down_block,
        :down_block
      ],
      doc:
        "a list of downsample block types. The supported blocks are: `:down_block`, `:cross_attention_down_block`"
    ],
    up_block_types: [
      default: [
        :up_block,
        :cross_attention_up_block,
        :cross_attention_up_block,
        :cross_attention_up_block
      ],
      doc:
        "a list of upsample block types. The supported blocks are: `:up_block`, `:cross_attention_up_block`"
    ],
    downsample_padding: [
      default: [{1, 1}, {1, 1}],
      doc: "the padding to use in the downsample convolution"
    ],
    mid_block_scale_factor: [
      default: 1,
      doc: "the scale factor to use for the mid block"
    ],
    num_attention_heads: [
      default: 8,
      doc:
        "the number of attention heads for each attention layer. Optionally can be a list with one number per block"
    ],
    cross_attention_size: [
      default: 1280,
      doc: "the dimensionality of the cross attention features"
    ],
    use_linear_projection: [
      default: false,
      doc:
        "whether the input/output projection of the transformer block should be linear or convolutional"
    ],
    activation: [
      default: :silu,
      doc: "the activation function"
    ],
    group_norm_num_groups: [
      default: 32,
      doc: "the number of groups used by the group normalization layers"
    ],
    group_norm_epsilon: [
      default: 1.0e-5,
      doc: "the epsilon used by the group normalization layers"
    ]
  ]

  @moduledoc """
  U-Net model with two spatial dimensions and conditional state.

  ## Architectures

    * `:base` - the U-Net model

  ## Inputs

    * `"sample"` - `{batch_size, sample_size, sample_size, in_channels}`

      Sample input with two spatial dimensions.

    * `"timestep"` - `{}`

      The timestep used to parameterize model behaviour in a multi-step
      process, such as diffusion.

    * `"encoder_hidden_state"` - `{batch_size, sequence_length, hidden_size}`

      The conditional state (context) to use with cross-attention.

  ## Configuration

  #{Shared.options_doc(options)}
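
  ## Usage

  A minimal inference sketch. The checkpoint reference and parameters
  filename below are illustrative; any UNet checkpoint in the `diffusers`
  format should load the same way:

      {:ok, %{model: model, params: params, spec: spec}} =
        Bumblebee.load_model({:hf, "CompVis/stable-diffusion-v1-4", subdir: "unet"},
          params_filename: "diffusion_pytorch_model.bin"
        )

      inputs = %{
        "sample" =>
          Nx.broadcast(0.0, {1, spec.sample_size, spec.sample_size, spec.in_channels}),
        "timestep" => Nx.tensor(50),
        "encoder_hidden_state" => Nx.broadcast(0.0, {1, 1, spec.cross_attention_size})
      }

      # The result is a map with the model prediction under :sample
      %{sample: prediction} = Axon.predict(model, params, inputs)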
  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.Layers
  alias Bumblebee.Diffusion

  @impl true
  def architectures(), do: [:base]

  @impl true
  def config(spec, opts \\ []) do
    Shared.put_config_attrs(spec, opts)
  end

  @impl true
  def input_template(spec) do
    sample_shape = {1, spec.sample_size, spec.sample_size, spec.in_channels}
    timestep_shape = {}
    encoder_hidden_state_shape = {1, 1, spec.cross_attention_size}

    %{
      "sample" => Nx.template(sample_shape, :f32),
      "timestep" => Nx.template(timestep_shape, :s64),
      "encoder_hidden_state" => Nx.template(encoder_hidden_state_shape, :f32)
    }
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    inputs = inputs(spec)
    sample = core(inputs, spec)
    Layers.output(%{sample: sample})
  end

  defp inputs(spec) do
    sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels}

    Bumblebee.Utils.Model.inputs_to_map([
      Axon.input("sample", shape: sample_shape),
      Axon.input("timestep", shape: {}),
      Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size})
    ])
  end

  defp core(inputs, spec) do
    sample = inputs["sample"]
    timestep = inputs["timestep"]
    encoder_hidden_state = inputs["encoder_hidden_state"]

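    # When enabled, rescale the sample from the [0, 1] range to [-1, 1]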
    sample =
      if spec.center_input_sample do
        Axon.nx(sample, fn sample -> 2 * sample - 1.0 end, op_name: :center)
      else
        sample
      end

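    # Broadcast the scalar timestep into a vector with one entry per
    # batch item, so the embedding is computed batch-wise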
    timestep =
      Axon.layer(
        fn sample, timestep, _opts ->
          Nx.broadcast(timestep, {Nx.axis_size(sample, 0)})
        end,
        [sample, timestep],
        op_name: :broadcast
      )

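    # Compute a sinusoidal embedding of the timestep and project it
    # through an MLP to four times the first hidden size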
    timestep_embedding =
      timestep
      |> Diffusion.Layers.timestep_sinusoidal_embedding(hd(spec.hidden_sizes),
        flip_sin_to_cos: spec.embedding_flip_sin_to_cos,
        frequency_correction_term: spec.embedding_frequency_correction_term
      )
      |> Diffusion.Layers.UNet.timestep_embedding_mlp(hd(spec.hidden_sizes) * 4,
        name: "time_embedding"
      )

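    # Lift the input channels to the first hidden size with a 3x3 convolution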
    sample =
      Axon.conv(sample, hd(spec.hidden_sizes),
        kernel_size: 3,
        padding: [{1, 1}, {1, 1}],
        name: "input_conv"
      )

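    # Downsampling path. The intermediate activations are collected as
    # residuals for the skip connections in the upsampling path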
    {sample, down_block_residuals} =
      down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, name: "down_blocks")

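    # Bottleneck, upsampling path and final projection back to out_channels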
    sample
    |> mid_block(timestep_embedding, encoder_hidden_state, spec, name: "mid_block")
    |> up_blocks(timestep_embedding, down_block_residuals, encoder_hidden_state, spec,
      name: "up_blocks"
    )
    |> Axon.group_norm(spec.group_norm_num_groups,
      epsilon: spec.group_norm_epsilon,
      name: "output_norm"
    )
    |> Axon.activation(:silu)
    |> Axon.conv(spec.out_channels,
      kernel_size: 3,
      padding: [{1, 1}, {1, 1}],
      name: "output_conv"
    )
  end

  defp down_blocks(sample, timestep_embedding, encoder_hidden_state, spec, opts) do
    name = opts[:name]

    blocks =
      Enum.zip([spec.hidden_sizes, spec.down_block_types, num_attention_heads_per_block(spec)])

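    # The input convolution output is the first skip-connection residual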
    in_channels = hd(spec.hidden_sizes)
    down_block_residuals = [{sample, in_channels}]

    state = {sample, down_block_residuals, in_channels}

    {sample, down_block_residuals, _} =
      for {{out_channels, block_type, num_attention_heads}, idx} <- Enum.with_index(blocks),
          reduce: state do
        {sample, down_block_residuals, in_channels} ->
          last_block? = idx == length(spec.hidden_sizes) - 1

          {sample, residuals} =
            Diffusion.Layers.UNet.down_block_2d(
              block_type,
              sample,
              timestep_embedding,
              encoder_hidden_state,
              depth: spec.depth,
              in_channels: in_channels,
              out_channels: out_channels,
              add_downsample: not last_block?,
              downsample_padding: spec.downsample_padding,
              activation: spec.activation,
              norm_epsilon: spec.group_norm_epsilon,
              norm_num_groups: spec.group_norm_num_groups,
              num_attention_heads: num_attention_heads,
              use_linear_projection: spec.use_linear_projection,
              name: join(name, idx)
            )

          {sample, down_block_residuals ++ Tuple.to_list(residuals), out_channels}
      end

    {sample, List.to_tuple(down_block_residuals)}
  end

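  # Applies the bottleneck: residual blocks interleaved with cross-attention
  # transformers at the lowest spatial resolution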
  defp mid_block(hidden_state, timestep_embedding, encoder_hidden_state, spec, opts) do
    Diffusion.Layers.UNet.mid_cross_attention_block_2d(
      hidden_state,
      timestep_embedding,
      encoder_hidden_state,
      channels: List.last(spec.hidden_sizes),
      activation: spec.activation,
      norm_epsilon: spec.group_norm_epsilon,
      norm_num_groups: spec.group_norm_num_groups,
      output_scale_factor: spec.mid_block_scale_factor,
      num_attention_heads: spec |> num_attention_heads_per_block() |> List.last(),
      use_linear_projection: spec.use_linear_projection,
      name: opts[:name]
    )
  end

  defp up_blocks(
         sample,
         timestep_embedding,
         down_block_residuals,
         encoder_hidden_state,
         spec,
         opts
       ) do
    name = opts[:name]

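    # Blocks with a downsample contribute depth + 1 residuals each and the
    # last block contributes depth, so together with the input convolution
    # output there are exactly depth + 1 skip connections per up block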
    down_block_residuals =
      down_block_residuals
      |> Tuple.to_list()
      |> Enum.reverse()
      |> Enum.chunk_every(spec.depth + 1)

    reversed_hidden_sizes = Enum.reverse(spec.hidden_sizes)
    in_channels = hd(reversed_hidden_sizes)

    num_attention_heads_per_block =
      spec
      |> num_attention_heads_per_block()
      |> Enum.reverse()

    blocks_and_chunks =
      [
        reversed_hidden_sizes,
        spec.up_block_types,
        num_attention_heads_per_block,
        down_block_residuals
      ]
      |> Enum.zip()
      |> Enum.with_index()

    {sample, _} =
      for {{out_channels, block_type, num_attention_heads, residuals}, idx} <- blocks_and_chunks,
          reduce: {sample, in_channels} do
        {sample, in_channels} ->
          last_block? = idx == length(spec.hidden_sizes) - 1

          sample =
            Diffusion.Layers.UNet.up_block_2d(
              block_type,
              sample,
              timestep_embedding,
              residuals,
              encoder_hidden_state,
              depth: spec.depth + 1,
              in_channels: in_channels,
              out_channels: out_channels,
              add_upsample: not last_block?,
              norm_epsilon: spec.group_norm_epsilon,
              norm_num_groups: spec.group_norm_num_groups,
              activation: spec.activation,
              num_attention_heads: num_attention_heads,
              use_linear_projection: spec.use_linear_projection,
              name: join(name, idx)
            )

          {sample, out_channels}
      end

    sample
  end

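  # Normalizes the num_attention_heads option to a list with one entry per block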
  defp num_attention_heads_per_block(spec) when is_list(spec.num_attention_heads) do
    spec.num_attention_heads
  end

  defp num_attention_heads_per_block(spec) when is_integer(spec.num_attention_heads) do
    num_blocks = length(spec.down_block_types)
    List.duplicate(spec.num_attention_heads, num_blocks)
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      import Shared.Converters

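      # Translate the diffusers configuration attributes to spec options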
      opts =
        convert!(data,
          in_channels: {"in_channels", number()},
          out_channels: {"out_channels", number()},
          sample_size: {"sample_size", number()},
          center_input_sample: {"center_input_sample", boolean()},
          embedding_flip_sin_to_cos: {"flip_sin_to_cos", boolean()},
          embedding_frequency_correction_term: {"freq_shift", number()},
          hidden_sizes: {"block_out_channels", list(number())},
          depth: {"layers_per_block", number()},
          down_block_types: {
            "down_block_types",
            list(
              mapping(%{
                "DownBlock2D" => :down_block,
                "CrossAttnDownBlock2D" => :cross_attention_down_block
              })
            )
          },
          up_block_types: {
            "up_block_types",
            list(
              mapping(%{
                "UpBlock2D" => :up_block,
                "CrossAttnUpBlock2D" => :cross_attention_up_block
              })
            )
          },
          downsample_padding: {"downsample_padding", padding(2)},
          mid_block_scale_factor: {"mid_block_scale_factor", number()},
          num_attention_heads: {"attention_head_dim", one_of([number(), list(number())])},
          cross_attention_size: {"cross_attention_dim", number()},
          use_linear_projection: {"use_linear_projection", boolean()},
          activation: {"act_fn", atom()},
          group_norm_num_groups: {"norm_num_groups", number()},
          group_norm_epsilon: {"norm_eps", number()}
        )

      @for.config(spec, opts)
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    def params_mapping(_spec) do
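      # Placeholders such as {n}, {m} and {l} are matched against numeric
      # layer indices on both sides of the mapping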
      block_mapping = %{
        "transformers.{m}.norm" => "attentions.{m}.norm",
        "transformers.{m}.input_projection" => "attentions.{m}.proj_in",
        "transformers.{m}.output_projection" => "attentions.{m}.proj_out",
        "transformers.{m}.blocks.{l}.self_attention.query" =>
          "attentions.{m}.transformer_blocks.{l}.attn1.to_q",
        "transformers.{m}.blocks.{l}.self_attention.key" =>
          "attentions.{m}.transformer_blocks.{l}.attn1.to_k",
        "transformers.{m}.blocks.{l}.self_attention.value" =>
          "attentions.{m}.transformer_blocks.{l}.attn1.to_v",
        "transformers.{m}.blocks.{l}.self_attention.output" =>
          "attentions.{m}.transformer_blocks.{l}.attn1.to_out.0",
        "transformers.{m}.blocks.{l}.cross_attention.query" =>
          "attentions.{m}.transformer_blocks.{l}.attn2.to_q",
        "transformers.{m}.blocks.{l}.cross_attention.key" =>
          "attentions.{m}.transformer_blocks.{l}.attn2.to_k",
        "transformers.{m}.blocks.{l}.cross_attention.value" =>
          "attentions.{m}.transformer_blocks.{l}.attn2.to_v",
        "transformers.{m}.blocks.{l}.cross_attention.output" =>
          "attentions.{m}.transformer_blocks.{l}.attn2.to_out.0",
        "transformers.{m}.blocks.{l}.ffn.intermediate" =>
          "attentions.{m}.transformer_blocks.{l}.ff.net.0.proj",
        "transformers.{m}.blocks.{l}.ffn.output" =>
          "attentions.{m}.transformer_blocks.{l}.ff.net.2",
        "transformers.{m}.blocks.{l}.self_attention_norm" =>
          "attentions.{m}.transformer_blocks.{l}.norm1",
        "transformers.{m}.blocks.{l}.cross_attention_norm" =>
          "attentions.{m}.transformer_blocks.{l}.norm2",
        "transformers.{m}.blocks.{l}.output_norm" =>
          "attentions.{m}.transformer_blocks.{l}.norm3",
        "residual_blocks.{m}.timestep_projection" => "resnets.{m}.time_emb_proj",
        "residual_blocks.{m}.norm_1" => "resnets.{m}.norm1",
        "residual_blocks.{m}.conv_1" => "resnets.{m}.conv1",
        "residual_blocks.{m}.norm_2" => "resnets.{m}.norm2",
        "residual_blocks.{m}.conv_2" => "resnets.{m}.conv2",
        "residual_blocks.{m}.shortcut.projection" => "resnets.{m}.conv_shortcut",
        "downsamples.{m}.conv" => "downsamplers.{m}.conv",
        "upsamples.{m}.conv" => "upsamplers.{m}.conv"
      }

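      # Replicate the shared per-block mapping under each block prefix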
      blocks_mapping =
        for {target, source} <- block_mapping,
            prefix <- ["down_blocks.{n}", "mid_block", "up_blocks.{n}"],
            do: {prefix <> "." <> target, prefix <> "." <> source},
            into: %{}

      %{
        "time_embedding.intermediate" => "time_embedding.linear_1",
        "time_embedding.output" => "time_embedding.linear_2",
        "input_conv" => "conv_in",
        "output_norm" => "conv_norm_out",
        "output_conv" => "conv_out"
      }
      |> Map.merge(blocks_mapping)
    end
  end
end