defmodule Bumblebee.Diffusion.StableDiffusion.SafetyChecker do
  alias Bumblebee.Shared

  options = [
    clip_spec: [
      default: nil,
      doc: "the specification of the CLIP model. See `Bumblebee.Multimodal.Clip` for details"
    ]
  ]
@moduledoc """
A CLIP-based model for detecting unsafe image content.
This model is designed primarily to check images generated using
Stable Diffusion.
## Architectures
* `:base` - the base safety detection model
## Inputs
* `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`
Featurized image pixel values.
## Configuration
#{Shared.options_doc(options)}
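
  ## Usage

  This model is meant to be passed via the `:safety_checker` option of
  `Bumblebee.Diffusion.StableDiffusion.text_to_image/6`, which runs it
  on the generated images. A minimal sketch of loading and invoking it
  standalone (repository layout as found on the Hugging Face Hub):

      repository_id = "CompVis/stable-diffusion-v1-4"

      {:ok, safety_checker} =
        Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})

      {:ok, featurizer} =
        Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})

      # images is a batch of images, e.g. StbImage structs or Nx tensors
      inputs = Bumblebee.apply_featurizer(featurizer, images)
      outputs = Axon.predict(safety_checker.model, safety_checker.params, inputs)
      outputs.is_unsafe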

  ## References

    * [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4#safety-module)
  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable

  import Nx.Defn

  alias Bumblebee.Layers

  @impl true
  def architectures(), do: [:base]

  @impl true
  def config(spec, opts) do
    Shared.put_config_attrs(spec, opts)
  end

  @impl true
  def input_template(%{clip_spec: %{vision_spec: vision_spec}}) do
    vision_shape = {1, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}

    %{"pixel_values" => Nx.template(vision_shape, :f32)}
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    %{clip_spec: %{vision_spec: vision_spec}} = spec
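
    # The batch dimension is `nil`, so the compiled model accepts any batch size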
    vision_shape = {nil, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}

    inputs =
      Bumblebee.Utils.Model.inputs_to_map([
        Axon.input("pixel_values", shape: vision_shape)
      ])
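
    # Build the CLIP vision tower and plug in our input. Prefixing layer
    # names keeps them aligned with the params mapping below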
    vision_model =
      vision_spec
      |> Bumblebee.build_model()
      |> Bumblebee.Utils.Axon.prefix_names("vision_model.")
      |> Bumblebee.Utils.Axon.plug_inputs(%{
        "pixel_values" => inputs["pixel_values"]
      })
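
    # Project the pooled vision state into the joint CLIP embedding space,
    # where the precomputed concept embeddings live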
    image_embeddings =
      vision_model
      |> Axon.nx(& &1.pooled_state)
      |> Axon.dense(spec.clip_spec.projection_size, use_bias: false, name: "visual_projection")

    is_unsafe = unsafe_detection(image_embeddings, spec, name: "unsafe_detection")

    Layers.output(%{
      is_unsafe: is_unsafe
    })
  end

  defp unsafe_detection(image_embeddings, spec, opts) do
    name = opts[:name]

    # The embeddings are precomputed using the CLIP text model and
    # represent sensitive/unsafe concepts in the latent space. We then
    # check whether an image is close to any of those concepts in the
    # latent space (using a hand-engineered threshold for each concept)
    num_sensitive_concepts = 3
    num_unsafe_concepts = 17

    sensitive_concept_embeddings =
      Axon.param("sensitive_concept_embeddings", fn _ ->
        {num_sensitive_concepts, spec.clip_spec.projection_size}
      end)

    unsafe_concept_embeddings =
      Axon.param("unsafe_concept_embeddings", fn _ ->
        {num_unsafe_concepts, spec.clip_spec.projection_size}
      end)
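
    # One hand-tuned threshold per concept, loaded from the checkpoint
    # rather than trained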
    sensitive_concept_thresholds =
      Axon.param("sensitive_concept_thresholds", fn _ -> {num_sensitive_concepts} end)

    unsafe_concept_thresholds =
      Axon.param("unsafe_concept_thresholds", fn _ -> {num_unsafe_concepts} end)
    Axon.layer(
      &unsafe_detection_impl/6,
      [
        image_embeddings,
        sensitive_concept_embeddings,
        unsafe_concept_embeddings,
        sensitive_concept_thresholds,
        unsafe_concept_thresholds
      ],
      name: name
    )
  end

  defnp unsafe_detection_impl(
          image_embeddings,
          sensitive_concept_embeddings,
          unsafe_concept_embeddings,
          sensitive_concept_thresholds,
          unsafe_concept_thresholds,
          _opts \\ []
        ) do
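    # Note: despite the `_distances` naming (kept from the reference
    # implementation), these are cosine similarities, so larger values
    # mean the image is closer to a concept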
    sensitive_concept_distances =
      Bumblebee.Utils.Nx.cosine_similarity(image_embeddings, sensitive_concept_embeddings)

    unsafe_concept_distances =
      Bumblebee.Utils.Nx.cosine_similarity(image_embeddings, unsafe_concept_embeddings)

    sensitive_concept_thresholds = Nx.new_axis(sensitive_concept_thresholds, 0)
    unsafe_concept_thresholds = Nx.new_axis(unsafe_concept_thresholds, 0)
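
    # A positive score means the similarity exceeds the concept's
    # threshold, that is, the concept is considered matched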
    sensitive_concept_scores = sensitive_concept_distances - sensitive_concept_thresholds
    sensitive? = Nx.any(sensitive_concept_scores > 0, axes: [1], keep_axes: true)

    # If the image matches any sensitive concept, bump all unsafe scores
    # by 0.01, which effectively lowers the thresholds and makes the check
    # stricter
    unsafe_threshold_adjustment = Nx.select(sensitive?, 0.01, 0.0)

    unsafe_concept_scores =
      unsafe_concept_distances - unsafe_concept_thresholds + unsafe_threshold_adjustment

    Nx.any(unsafe_concept_scores > 0, axes: [1])
  end
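
  # Shape sketch (assuming, for illustration, a batch of 2 images and a
  # CLIP projection_size of 768):
  #
  #   image_embeddings           {2, 768}
  #   unsafe_concept_embeddings  {17, 768}
  #   unsafe_concept_distances   {2, 17}  (pairwise cosine similarities)
  #   is_unsafe (output)         {2}      (one flag per image)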

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      clip_spec =
        Bumblebee.Multimodal.Clip
        |> Bumblebee.configure()
        |> Bumblebee.HuggingFace.Transformers.Config.load(data)

      @for.config(spec, clip_spec: clip_spec)
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    alias Bumblebee.HuggingFace.Transformers

    def params_mapping(spec) do
      vision_mapping =
        spec.clip_spec.vision_spec
        |> Transformers.Model.params_mapping()
        |> Transformers.Utils.prefix_params_mapping("vision_model", "vision_model")
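
      # Map the custom detection parameters to the corresponding buffer
      # names in the original PyTorch checkpoint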
      %{
        "visual_projection" => "visual_projection",
        "unsafe_detection" => %{
          "sensitive_concept_embeddings" => {
            [{"unsafe_detection", "special_care_embeds"}],
            fn [value] -> value end
          },
          "unsafe_concept_embeddings" => {
            [{"unsafe_detection", "concept_embeds"}],
            fn [value] -> value end
          },
          "sensitive_concept_thresholds" => {
            [{"unsafe_detection", "special_care_embeds_weights"}],
            fn [value] -> value end
          },
          "unsafe_concept_thresholds" => {
            [{"unsafe_detection", "concept_embeds_weights"}],
            fn [value] -> value end
          }
        }
      }
      |> Map.merge(vision_mapping)
    end
  end
end