# lib/bumblebee/diffusion/stable_diffusion.ex

defmodule Bumblebee.Diffusion.StableDiffusion do
  @moduledoc """
  High-level tasks based on Stable Diffusion.
  """

  import Nx.Defn

  alias Bumblebee.Shared

  @type text_to_image_input ::
          String.t() | %{:prompt => String.t(), optional(:negative_prompt) => String.t()}
  @type text_to_image_output :: %{results: list(text_to_image_result())}
  @type text_to_image_result :: %{:image => Nx.Tensor.t(), optional(:is_safe) => boolean()}

  @doc ~S"""
  Builds serving for prompt-driven image generation.

  The serving accepts `t:text_to_image_input/0` and returns `t:text_to_image_output/0`.
  A list of inputs is also supported.
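
  In addition to a plain prompt string, the input may be a map that also
  carries a negative prompt, steering the generation away from the
  negative description:

      Nx.Serving.run(serving, %{
        prompt: "numbat in forest, detailed, digital art",
        negative_prompt: "blurry, low quality"
      })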

  You can specify a `:safety_checker` model to automatically detect
  when a generated image is offensive or harmful and filter it out.

  ## Options

    * `:safety_checker` - the safety checker model info map. When a
      safety checker is used, each output entry has an additional
      `:is_safe` property and unsafe images are automatically zeroed.
      Make sure to also set `:safety_checker_featurizer`

    * `:safety_checker_featurizer` - the featurizer to use to preprocess
      the safety checker input images

    * `:num_steps` - the number of denoising steps. More denoising
      steps usually lead to higher image quality at the expense of
      slower inference. Defaults to `50`

    * `:num_images_per_prompt` - the number of images to generate for
      each prompt. Defaults to `1`

    * `:guidance_scale` - the scale used for classifier-free diffusion
      guidance. Higher guidance scale makes the generated images more
      closely reflect the text prompt. This parameter corresponds to
      $\omega$ in Equation (2) of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf).
      Defaults to `7.5`

    * `:seed` - a seed for the random number generator. Defaults to `0`

    * `:compile` - compiles all computations for predefined input shapes
      during serving initialization. Should be a keyword list with the
      following keys:

        * `:batch_size` - the maximum batch size of the input. Inputs
          are optionally padded to always match this batch size

        * `:sequence_length` - the maximum input sequence length. Input
          sequences are always padded/truncated to match that length

      It is advised to set this option in production and also configure
      a defn compiler using `:defn_options` to maximally reduce inference
      time.

    * `:defn_options` - the options for JIT compilation. Defaults to `[]`

  ## Examples

      repository_id = "CompVis/stable-diffusion-v1-4"

      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})

      {:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})

      {:ok, unet} =
        Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
          params_filename: "diffusion_pytorch_model.bin"
        )

      {:ok, vae} =
        Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
          architecture: :decoder,
          params_filename: "diffusion_pytorch_model.bin"
        )

      {:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
      {:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
      {:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})

      serving =
        Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
          num_steps: 20,
          num_images_per_prompt: 2,
          safety_checker: safety_checker,
          safety_checker_featurizer: featurizer,
          compile: [batch_size: 1, sequence_length: 60],
          defn_options: [compiler: EXLA]
        )

      prompt = "numbat in forest, detailed, digital art"
      Nx.Serving.run(serving, prompt)
      #=> %{
      #=>   results: [
      #=>     %{
      #=>       image: #Nx.Tensor<
      #=>         u8[512][512][3]
      #=>         ...
      #=>       >,
      #=>       is_safe: true
      #=>     },
      #=>     %{
      #=>       image: #Nx.Tensor<
      #=>         u8[512][512][3]
      #=>         ...
      #=>       >,
      #=>       is_safe: true
      #=>     }
      #=>   ]
      #=> }
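
  The resulting image tensors can be displayed or saved with an image
  library of your choice. For example, assuming the `:stb_image` package
  is added as a dependency:

      %{results: [%{image: image} | _]} = Nx.Serving.run(serving, prompt)
      image |> StbImage.from_nx() |> StbImage.write_file("numbat.png")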

  """
  @spec text_to_image(
          Bumblebee.model_info(),
          Bumblebee.model_info(),
          Bumblebee.model_info(),
          Bumblebee.Tokenizer.t(),
          Bumblebee.Scheduler.t(),
          keyword()
        ) :: Nx.Serving.t()
  def text_to_image(encoder, unet, vae, tokenizer, scheduler, opts \\ []) do
    opts =
      Keyword.validate!(opts, [
        :safety_checker,
        :safety_checker_featurizer,
        :compile,
        num_steps: 50,
        num_images_per_prompt: 1,
        guidance_scale: 7.5,
        seed: 0,
        defn_options: []
      ])

    safety_checker = opts[:safety_checker]
    safety_checker_featurizer = opts[:safety_checker_featurizer]
    num_steps = opts[:num_steps]
    num_images_per_prompt = opts[:num_images_per_prompt]
    compile = opts[:compile]
    defn_options = opts[:defn_options]

    if safety_checker != nil and safety_checker_featurizer == nil do
      raise ArgumentError, "got :safety_checker but no :safety_checker_featurizer was specified"
    end

    safety_checker? = safety_checker != nil

    batch_size = compile[:batch_size]
    sequence_length = compile[:sequence_length]

    if compile != nil and (batch_size == nil or sequence_length == nil) do
      raise ArgumentError,
            "expected :compile to be a keyword list specifying :batch_size and :sequence_length, got: #{inspect(compile)}"
    end

    {_, encoder_predict} = Axon.build(encoder.model)
    {_, vae_predict} = Axon.build(vae.model)
    {_, unet_predict} = Axon.build(unet.model)

    scheduler_init = fn latents_shape ->
      Bumblebee.scheduler_init(scheduler, num_steps, latents_shape)
    end

    scheduler_step = &Bumblebee.scheduler_step(scheduler, &1, &2, &3)

    image_fun =
      &text_to_image_impl(
        encoder_predict,
        &1,
        unet_predict,
        &2,
        vae_predict,
        &3,
        scheduler_init,
        scheduler_step,
        &4,
        num_images_per_prompt: num_images_per_prompt,
        latents_sample_size: unet.spec.sample_size,
        latents_channels: unet.spec.in_channels,
        seed: opts[:seed],
        guidance_scale: opts[:guidance_scale]
      )

    safety_checker_fun =
      if safety_checker do
        {_, predict_fun} = Axon.build(safety_checker.model)
        predict_fun
      end

    # Note that all of these are copied when using serving as a process
    init_args = [
      {image_fun, safety_checker_fun},
      encoder.params,
      unet.params,
      vae.params,
      {safety_checker?, safety_checker[:spec], safety_checker[:params]},
      safety_checker_featurizer,
      {compile != nil, batch_size, sequence_length},
      num_images_per_prompt
    ]

    Nx.Serving.new(
      fn defn_options -> apply(&init/9, init_args ++ [defn_options]) end,
      defn_options
    )
    |> Nx.Serving.process_options(batch_size: batch_size)
    |> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer, sequence_length))
    |> Nx.Serving.client_postprocessing(&client_postprocessing(&1, &2, &3, safety_checker?))
  end

  defp init(
         {image_fun, safety_checker_fun},
         encoder_params,
         unet_params,
         vae_params,
         {safety_checker?, safety_checker_spec, safety_checker_params},
         safety_checker_featurizer,
         {compile?, batch_size, sequence_length},
         num_images_per_prompt,
         defn_options
       ) do
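    # When :compile is set, compile the generation function against
    # input templates upfront, otherwise JIT-compile it on first call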
    image_fun =
      Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
        text_inputs = %{
          "input_ids" => Nx.template({batch_size, sequence_length}, :s64)
        }

        inputs = %{"unconditional" => text_inputs, "conditional" => text_inputs}

        [encoder_params, unet_params, vae_params, inputs]
      end)

    safety_checker_fun =
      safety_checker_fun &&
        Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn ->
          inputs = %{
            "pixel_values" =>
              Shared.input_template(safety_checker_spec, "pixel_values", [
                batch_size * num_images_per_prompt
              ])
          }

          [safety_checker_params, inputs]
        end)

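    # The batch function invoked by the serving for each batch of inputs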
    fn inputs ->
      inputs = Shared.maybe_pad(inputs, batch_size)

      image = image_fun.(encoder_params, unet_params, vae_params, inputs)

      output =
        if safety_checker? do
          inputs = Bumblebee.apply_featurizer(safety_checker_featurizer, image)
          outputs = safety_checker_fun.(safety_checker_params, inputs)
          %{image: image, is_unsafe: outputs.is_unsafe}
        else
          %{image: image}
        end

      Bumblebee.Utils.Nx.composite_unflatten_batch(output, inputs.size)
    end
  end

  defp client_preprocessing(input, tokenizer, sequence_length) do
    {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input/1)

    prompts = Enum.map(inputs, & &1.prompt)
    negative_prompts = Enum.map(inputs, & &1.negative_prompt)

    conditional =
      Bumblebee.apply_tokenizer(tokenizer, prompts,
        length: sequence_length,
        return_token_type_ids: false,
        return_attention_mask: false
      )

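    # Tokenize the negative prompts to the same length as the conditional
    # input, so the two can later be stacked along the batch axis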
    unconditional =
      Bumblebee.apply_tokenizer(tokenizer, negative_prompts,
        length: Nx.axis_size(conditional["input_ids"], 1),
        return_attention_mask: false,
        return_token_type_ids: false
      )

    inputs = %{"unconditional" => unconditional, "conditional" => conditional}

    {Nx.Batch.concatenate([inputs]), multi?}
  end

  defp client_postprocessing(outputs, _metadata, multi?, safety_checker?) do
    for outputs <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
      results =
        for outputs = %{image: image} <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
          if safety_checker? do
            if Nx.to_number(outputs.is_unsafe) == 1 do
              %{image: zeroed(image), is_safe: false}
            else
              %{image: image, is_safe: true}
            end
          else
            %{image: image}
          end
        end

      %{results: results}
    end
    |> Shared.normalize_output(multi?)
  end

  defp zeroed(tensor) do
    0
    |> Nx.tensor(type: Nx.type(tensor))
    |> Nx.broadcast(Nx.shape(tensor))
  end

  defnp text_to_image_impl(
          encoder_predict,
          encoder_params,
          unet_predict,
          unet_params,
          vae_predict,
          vae_params,
          scheduler_init,
          scheduler_step,
          inputs,
          opts \\ []
        ) do
    num_images_per_prompt = opts[:num_images_per_prompt]
    latents_sample_size = opts[:latents_sample_size]
    latents_in_channels = opts[:latents_channels]
    seed = opts[:seed]
    guidance_scale = opts[:guidance_scale]

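    # Stack the unconditional (negative prompt) and conditional inputs
    # along the batch axis, so both halves are encoded in one pass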
    inputs =
      Bumblebee.Utils.Nx.composite_concatenate(inputs["unconditional"], inputs["conditional"])

    %{hidden_state: text_embeddings} = encoder_predict.(encoder_params, inputs)

    {twice_batch_size, sequence_length, hidden_size} = Nx.shape(text_embeddings)
    batch_size = div(twice_batch_size, 2)

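    # Repeat the text embeddings, so that each prompt produces
    # num_images_per_prompt images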
    text_embeddings =
      text_embeddings
      |> Nx.new_axis(1)
      |> Nx.tile([1, num_images_per_prompt, 1, 1])
      |> Nx.reshape({:auto, sequence_length, hidden_size})

    latents_shape =
      {batch_size * num_images_per_prompt, latents_sample_size, latents_sample_size,
       latents_in_channels}

    {scheduler_state, timesteps} = scheduler_init.(latents_shape)

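    # Sample the initial latents from a seeded normal distribution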
    key = Nx.Random.key(seed)
    {latents, _key} = Nx.Random.normal(key, shape: latents_shape)

    {_, latents, _, _} =
      while {scheduler_state, latents, text_embeddings, unet_params}, timestep <- timesteps do
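        # Duplicate the latents to match the stacked unconditional and
        # conditional embeddings, so a single U-Net pass predicts both
        # noise estimates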
        unet_inputs = %{
          "sample" => Nx.concatenate([latents, latents]),
          "timestep" => timestep,
          "encoder_hidden_state" => text_embeddings
        }

        %{sample: noise_pred} = unet_predict.(unet_params, unet_inputs)

        {noise_pred_unconditional, noise_pred_text} = split_in_half(noise_pred)

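        # Classifier-free guidance: shift the noise prediction away from
        # the unconditional estimate, scaled by the guidance scale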
        noise_pred =
          noise_pred_unconditional + guidance_scale * (noise_pred_text - noise_pred_unconditional)

        {scheduler_state, latents} = scheduler_step.(scheduler_state, latents, noise_pred)

        {scheduler_state, latents, text_embeddings, unet_params}
      end

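    # Undo the latent scaling factor (0.18215) used by the Stable
    # Diffusion VAE before decoding the latents into an image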
    latents = latents * (1 / 0.18215)

    %{sample: image} = vae_predict.(vae_params, latents)

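    # Rescale the decoded sample from the [-1, 1] range to u8 pixel values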
    NxImage.from_continuous(image, -1, 1)
  end

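  # Splits a batched tensor into its unconditional and conditional halves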
  defnp split_in_half(tensor) do
    batch_size = Nx.axis_size(tensor, 0)
    half_size = div(batch_size, 2)
    {tensor[0..(half_size - 1)//1], tensor[half_size..-1//1]}
  end

  defp validate_input(prompt) when is_binary(prompt), do: validate_input(%{prompt: prompt})

  defp validate_input(%{prompt: prompt} = input) do
    {:ok, %{prompt: prompt, negative_prompt: input[:negative_prompt] || ""}}
  end

  defp validate_input(%{} = input) do
    {:error, "expected the input map to have :prompt key, got: #{inspect(input)}"}
  end

  defp validate_input(input) do
    {:error, "expected either a string or a map, got: #{inspect(input)}"}
  end
end