lib/bumblebee.ex

defmodule Bumblebee do
  @moduledoc """
  Pre-trained `Axon` models for easy inference and boosted training.

  Bumblebee provides state-of-the-art, configurable `Axon` models. On
  top of that, it streamlines the process of loading pre-trained models
  by integrating with Hugging Face Hub and [🤗 Transformers](https://github.com/huggingface/transformers).

  ## Usage

  You can load one of the supported models by specifying the model repository:

      {:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

  Then you are ready to make predictions:

      inputs = Bumblebee.apply_tokenizer(tokenizer, "Hello Bumblebee!")
      outputs = Axon.predict(bert.model, bert.params, inputs)

  For complete examples see the [Examples](examples.livemd) notebook.

  > #### Note {: .info}
  >
  > The models are generally large, so make sure to configure an efficient
  > `Nx` backend, such as `EXLA` or `Torchx`.
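  >
  > For example, assuming `EXLA` is added as a dependency, you can set it
  > as the default backend for all tensors:
  >
  >     Nx.global_default_backend(EXLA.Backend)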
  """

  alias Bumblebee.HuggingFace

  @config_filename "config.json"
  @featurizer_filename "preprocessor_config.json"
  @tokenizer_filename "tokenizer.json"
  @tokenizer_special_tokens_filename "special_tokens_map.json"
  @scheduler_filename "scheduler_config.json"
  @params_filename %{pytorch: "pytorch_model.bin"}

  @transformers_class_to_model %{
    "AlbertForMaskedLM" => {Bumblebee.Text.Albert, :for_masked_language_modeling},
    "AlbertForMultipleChoice" => {Bumblebee.Text.Albert, :for_multiple_choice},
    "AlbertForPreTraining" => {Bumblebee.Text.Albert, :for_pre_training},
    "AlbertForQuestionAnswering" => {Bumblebee.Text.Albert, :for_question_answering},
    "AlbertForSequenceClassification" => {Bumblebee.Text.Albert, :for_sequence_classification},
    "AlbertForTokenClassification" => {Bumblebee.Text.Albert, :for_token_classification},
    "AlbertModel" => {Bumblebee.Text.Albert, :base},
    "BartForCausalLM" => {Bumblebee.Text.Bart, :for_causal_language_modeling},
    "BartForConditionalGeneration" => {Bumblebee.Text.Bart, :for_conditional_generation},
    "BartForQuestionAnswering" => {Bumblebee.Text.Bart, :for_question_answering},
    "BartForSequenceClassification" => {Bumblebee.Text.Bart, :for_sequence_classification},
    "BartModel" => {Bumblebee.Text.Bart, :base},
    "BertForMaskedLM" => {Bumblebee.Text.Bert, :for_masked_language_modeling},
    "BertForMultipleChoice" => {Bumblebee.Text.Bert, :for_multiple_choice},
    "BertForNextSentencePrediction" => {Bumblebee.Text.Bert, :for_next_sentence_prediction},
    "BertForPreTraining" => {Bumblebee.Text.Bert, :for_pre_training},
    "BertForQuestionAnswering" => {Bumblebee.Text.Bert, :for_question_answering},
    "BertForSequenceClassification" => {Bumblebee.Text.Bert, :for_sequence_classification},
    "BertForTokenClassification" => {Bumblebee.Text.Bert, :for_token_classification},
    "BertLMHeadModel" => {Bumblebee.Text.Bert, :for_causal_language_modeling},
    "BertModel" => {Bumblebee.Text.Bert, :base},
    # These models are just RoBERTa models, but the config will list them as CamemBERT
    "CamembertModel" => {Bumblebee.Text.Roberta, :base},
    "CamembertForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
    "CamembertForSequenceClassification" =>
      {Bumblebee.Text.Roberta, :for_sequence_classification},
    "CamembertForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice},
    "CamembertForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
    "CamembertForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering},
    "CLIPModel" => {Bumblebee.Multimodal.Clip, :base},
    "CLIPTextModel" => {Bumblebee.Text.ClipText, :base},
    "CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base},
    "ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification},
    "ConvNextModel" => {Bumblebee.Vision.ConvNext, :base},
    "DeiTForImageClassification" => {Bumblebee.Vision.Deit, :for_image_classification},
    "DeiTForImageClassificationWithTeacher" =>
      {Bumblebee.Vision.Deit, :for_image_classification_with_teacher},
    "DeiTForMaskedImageModeling" => {Bumblebee.Vision.Deit, :for_masked_image_modeling},
    "DeiTModel" => {Bumblebee.Vision.Deit, :base},
    "DistilBertModel" => {Bumblebee.Text.Distilbert, :base},
    "DistilBertForMaskedLM" => {Bumblebee.Text.Distilbert, :for_masked_language_modeling},
    "DistilBertForSequenceClassification" =>
      {Bumblebee.Text.Distilbert, :for_sequence_classification},
    "DistilBertForQuestionAnswering" => {Bumblebee.Text.Distilbert, :for_question_answering},
    "DistilBertForTokenClassification" => {Bumblebee.Text.Distilbert, :for_token_classification},
    "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
    "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
    "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
    "GPT2Model" => {BumbleBee.Text.Gpt2, :base},
    "LayoutLMForMaskedLanguageModeling" =>
      {Bumblebee.Multimodal.LayoutLm, :for_masked_language_modeling},
    "LayoutLMForQuestionAnswering" => {Bumblebee.Multimodal.LayoutLm, :for_question_answering},
    "LayoutLMForSequenceClassification" =>
      {Bumblebee.Multimodal.LayoutLm, :for_sequence_classification},
    "LayoutLMForTokenClassification" =>
      {Bumblebee.Multimodal.LayoutLm, :for_token_classification},
    "LayoutLMModel" => {Bumblebee.Multimodal.LayoutLm, :base},
    "MBartForCausalLM" => {Bumblebee.Text.Mbart, :for_causal_language_modeling},
    "MBartForConditionalGeneration" => {Bumblebee.Text.Mbart, :for_conditional_generation},
    "MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering},
    "MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification},
    "MBartModel" => {Bumblebee.Text.Mbart, :base},
    "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
    "ResNetModel" => {Bumblebee.Vision.ResNet, :base},
    "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
    "RobertaForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice},
    "RobertaForPreTraining" => {Bumblebee.Text.Roberta, :for_pre_training},
    "RobertaForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering},
    "RobertaForSequenceClassification" => {Bumblebee.Text.Roberta, :for_sequence_classification},
    "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
    "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
    "RobertaModel" => {Bumblebee.Text.Roberta, :base},
    "T5Model" => {Bumblebee.Text.T5, :base},
    "T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation},
    "ViTForImageClassification" => {Bumblebee.Vision.Vit, :for_image_classification},
    "ViTForMaskedImageModeling" => {Bumblebee.Vision.Vit, :for_masked_image_modeling},
    "ViTModel" => {Bumblebee.Vision.Vit, :base},
    # These models are just RoBERTa models, but the config will list them as XLM-RoBERTa
    "XLMRobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
    "XLMRobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
    "XLMRobertaForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice},
    "XLMRobertaForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering},
    "XLMRobertaForSequenceClassification" =>
      {Bumblebee.Text.Roberta, :for_sequence_classification},
    "XLMRobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
    "XLMRobertaModel" => {Bumblebee.Text.Roberta, :base},
    "WhisperModel" => {Bumblebee.Audio.Whisper, :base},
    "WhisperForConditionalGeneration" => {Bumblebee.Audio.Whisper, :for_conditional_generation},
    # Diffusers
    "AutoencoderKL" => {Bumblebee.Diffusion.VaeKl, :base},
    "StableDiffusionSafetyChecker" => {Bumblebee.Diffusion.StableDiffusion.SafetyChecker, :base},
    "UNet2DConditionModel" => {Bumblebee.Diffusion.UNet2DConditional, :base}
  }

  @transformers_class_to_featurizer %{
    "CLIPFeatureExtractor" => Bumblebee.Vision.ClipFeaturizer,
    "ConvNextFeatureExtractor" => Bumblebee.Vision.ConvNextFeaturizer,
    "DeiTFeatureExtractor" => Bumblebee.Vision.DeitFeaturizer,
    "ViTFeatureExtractor" => Bumblebee.Vision.VitFeaturizer,
    "WhisperFeatureExtractor" => Bumblebee.Audio.WhisperFeaturizer
  }

  @model_type_to_featurizer %{
    "convnext" => Bumblebee.Vision.ConvNextFeaturizer,
    "deit" => Bumblebee.Vision.DeitFeaturizer,
    "resnet" => Bumblebee.Vision.ConvNextFeaturizer,
    "vit" => Bumblebee.Vision.VitFeaturizer,
    "whisper" => Bumblebee.Audio.WhisperFeaturizer
  }

  @model_type_to_tokenizer %{
    "albert" => Bumblebee.Text.AlbertTokenizer,
    "bart" => Bumblebee.Text.BartTokenizer,
    "bert" => Bumblebee.Text.BertTokenizer,
    "distilbert" => Bumblebee.Text.DistilbertTokenizer,
    "camembert" => Bumblebee.Text.CamembertTokenizer,
    "clip" => Bumblebee.Text.ClipTokenizer,
    "gpt2" => Bumblebee.Text.Gpt2Tokenizer,
    "layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
    "mbart" => Bumblebee.Text.MbartTokenizer,
    "roberta" => Bumblebee.Text.RobertaTokenizer,
    "t5" => Bumblebee.Text.T5Tokenizer,
    "whisper" => Bumblebee.Text.WhisperTokenizer,
    "xlm-roberta" => Bumblebee.Text.XlmRobertaTokenizer
  }

  @diffusers_class_to_scheduler %{
    "DDIMScheduler" => Bumblebee.Diffusion.DdimScheduler,
    "PNDMScheduler" => Bumblebee.Diffusion.PndmScheduler
  }

  @typedoc """
  A location to fetch model files from.

  Can be either:

    * `{:hf, repository_id}` - the repository on Hugging Face. Options
      may be passed as the third element:

        * `:revision` - the specific model version to use; it can be
          any valid git identifier, such as a branch name, tag name, or
          commit hash

        * `:cache_dir` - the directory to store the downloaded files
          in. Defaults to the standard cache location for the given
          operating system. You can also configure it globally by
          setting the `BUMBLEBEE_CACHE_DIR` environment variable

        * `:auth_token` - the token to use as HTTP bearer authorization
          for remote files

        * `:subdir` - the directory within the repository where the
          files are located

    * `{:local, directory}` - the directory containing model files

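  For example:

      {:hf, "bert-base-uncased"}
      {:hf, "bert-base-uncased", revision: "main"}
      {:local, "/path/to/model"}
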
  """
  @type repository :: {:hf, String.t()} | {:hf, String.t(), keyword()} | {:local, Path.t()}

  @typedoc """
  A model together with its state and metadata.
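
  For example, after `load_model/2` you can run the model directly:

      outputs = Axon.predict(model_info.model, model_info.params, inputs)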
  """
  @type model_info :: %{
          model: Axon.t(),
          params: map(),
          spec: Bumblebee.ModelSpec.t()
        }

  @doc """
  Builds or updates a configuration object with the given options.

  Expects a configuration struct or a module supporting configuration.
  The following are typically configurable:

    * model specification (`Bumblebee.ModelSpec`)

    * featurizer (`Bumblebee.Featurizer`)

    * scheduler (`Bumblebee.Scheduler`)

  ## Examples

  To build a new configuration, pass a module:

      featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer)
      spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :for_image_classification)

  Similarly, you can update an existing configuration:

      featurizer = Bumblebee.configure(featurizer, resize_method: :bilinear)
      spec = Bumblebee.configure(spec, embedding_size: 128)

  """
  @spec configure(module() | Bumblebee.Configurable.t(), keyword()) :: Bumblebee.Configurable.t()
  def configure(config, options \\ []) do
    %module{} = config = struct!(config)
    module.config(config, options)
  end

  @doc """
  Builds an `Axon` model according to the given specification.

  ## Example

      spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128)
      model = Bumblebee.build_model(spec)

  """
  @doc type: :model
  @spec build_model(Bumblebee.ModelSpec.t()) :: Axon.t()
  def build_model(%module{} = spec) do
    module.model(spec)
  end

  @doc """
  Loads a model specification from a model repository.

  ## Options

    * `:module` - the model specification module. By default it is
      inferred from the configuration file; if that is not possible,
      it must be specified explicitly

    * `:architecture` - the model architecture; it must be supported by
      `:module`. By default it is inferred from the configuration file

  ## Examples

      {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"})

  You can explicitly specify a different architecture:

      {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"}, architecture: :base)

  """
  @doc type: :model
  @spec load_spec(repository(), keyword()) ::
          {:ok, Bumblebee.ModelSpec.t()} | {:error, String.t()}
  def load_spec(repository, opts \\ []) do
    repository = normalize_repository!(repository)

    opts = Keyword.validate!(opts, [:module, :architecture])

    module = opts[:module]
    architecture = opts[:architecture]

    with {:ok, path} <- download(repository, @config_filename),
         {:ok, spec_data} <- decode_config(path) do
      {inferred_module, inferred_architecture, inference_error} =
        case infer_model_type(spec_data) do
          {:ok, module, architecture} -> {module, architecture, nil}
          {:error, error} -> {nil, nil, error}
        end

      module = module || inferred_module
      architecture = architecture || inferred_architecture

      unless module do
        raise "#{inference_error}, please specify the :module and :architecture options"
      end

      architectures = module.architectures()

      if architecture && architecture not in architectures do
        raise ArgumentError,
              "expected architecture to be one of: #{Enum.map_join(architectures, ", ", &inspect/1)}, but got: #{inspect(architecture)}"
      end

      spec =
        if architecture do
          configure(module, architecture: architecture)
        else
          configure(module)
        end

      spec = HuggingFace.Transformers.Config.load(spec, spec_data)

      {:ok, spec}
    end
  end

  defp decode_config(path) do
    path
    |> File.read!()
    |> Jason.decode()
    |> case do
      {:ok, data} -> {:ok, data}
      _ -> {:error, "failed to parse the config file, it is not a valid JSON"}
    end
  end

  defp infer_model_type(%{"architectures" => [class_name]}) do
    case @transformers_class_to_model[class_name] do
      nil ->
        {:error,
         "could not match the class name #{inspect(class_name)} to any of the supported models"}

      {module, architecture} ->
        {:ok, module, architecture}
    end
  end

  defp infer_model_type(%{"_class_name" => class_name}) do
    infer_model_type(%{"architectures" => [class_name]})
  end

  defp infer_model_type(_spec_data) do
    {:error, "could not infer model type from the configuration"}
  end

  @doc """
  Loads a pre-trained model from a model repository.

  ## Options

    * `:spec` - the model specification to use when building the model.
      By default the specification is loaded using `load_spec/2`

    * `:module` - the model specification module. By default it is
      inferred from the configuration file; if that is not possible,
      it must be specified explicitly

    * `:architecture` - the model architecture; it must be supported by
      `:module`. By default it is inferred from the configuration file

    * `:params_filename` - the file with the model parameters to be loaded

    * `:log_params_diff` - whether to log missing, mismatched and unused
      parameters. Defaults to `true`

    * `:backend` - the backend to allocate the tensors on. It is either
      an atom or a tuple in the shape `{backend, options}`

  ## Examples

  By default the model type is inferred from the configuration, so loading
  is as simple as:

      {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"})
      %{model: model, params: params, spec: spec} = resnet

  You can explicitly specify a different architecture, in which case
  matching parameters are still loaded:

      {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, architecture: :base)

  To further customize the model, you can also pass the specification:

      {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"})
      spec = Bumblebee.configure(spec, num_labels: 10)
      {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)
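
  To allocate the parameters on a specific backend, pass the `:backend` option,
  here with `EXLA.Backend` (assuming `EXLA` is added as a dependency):

      {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, backend: EXLA.Backend)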

  """
  @doc type: :model
  @spec load_model(repository(), keyword()) :: {:ok, model_info()} | {:error, String.t()}
  def load_model(repository, opts \\ []) do
    repository = normalize_repository!(repository)

    opts =
      Keyword.validate!(opts, [
        :spec,
        :module,
        :architecture,
        :params_filename,
        :backend,
        log_params_diff: true
      ])

    spec_response =
      if spec = opts[:spec] do
        {:ok, spec}
      else
        load_spec(repository, Keyword.take(opts, [:module, :architecture]))
      end

    with {:ok, spec} <- spec_response,
         model <- build_model(spec),
         {:ok, params} <-
           load_params(
             spec,
             model,
             repository,
             opts
             |> Keyword.take([:params_filename, :log_params_diff, :backend])
           ) do
      {:ok, %{model: model, params: params, spec: spec}}
    end
  end

  defp load_params(%module{} = spec, model, repository, opts) do
    # TODO: support format: :auto | :axon | :pytorch
    format = :pytorch
    filename = opts[:params_filename] || @params_filename[format]
    log_params_diff = opts[:log_params_diff]
    backend = opts[:backend]

    input_template = module.input_template(spec)

    params_mapping = Bumblebee.HuggingFace.Transformers.Model.params_mapping(spec)

    with {:ok, path} <- download(repository, filename) do
      params =
        Bumblebee.Conversion.PyTorch.load_params!(model, input_template, path,
          log_params_diff: log_params_diff,
          backend: backend,
          params_mapping: params_mapping
        )

      {:ok, params}
    end
  end

  @doc """
  Featurizes `input` with the given featurizer.

  ## Options

    * `:defn_options` - the options for JIT compilation. Note that
      this is only relevant for featurizers implemented with Nx.
      Defaults to `[]`

  ## Examples

      featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer)
      {:ok, img} = StbImage.read_file(path)
      inputs = Bumblebee.apply_featurizer(featurizer, [img])
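
  To JIT-compile the featurization (relevant only for featurizers implemented
  with Nx), pass `:defn_options`, for example with the `EXLA` compiler
  (assuming `EXLA` is available):

      inputs = Bumblebee.apply_featurizer(featurizer, [img], defn_options: [compiler: EXLA])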

  """
  @doc type: :featurizer
  @spec apply_featurizer(Bumblebee.Featurizer.t(), any(), keyword()) :: any()
  def apply_featurizer(%module{} = featurizer, input, opts \\ []) do
    opts = Keyword.validate!(opts, defn_options: [])
    module.apply(featurizer, input, opts[:defn_options])
  end

  @doc """
  Loads a featurizer from a model repository.

  ## Options

    * `:module` - the featurizer module. By default it is inferred
      from the preprocessor configuration file; if that is not possible,
      it must be specified explicitly

  ## Examples

      {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"})

  """
  @doc type: :featurizer
  @spec load_featurizer(repository(), keyword()) ::
          {:ok, Bumblebee.Featurizer.t()} | {:error, String.t()}
  def load_featurizer(repository, opts \\ []) do
    repository = normalize_repository!(repository)
    opts = Keyword.validate!(opts, [:module])
    module = opts[:module]

    with {:ok, path} <- download(repository, @featurizer_filename),
         {:ok, featurizer_data} <- decode_config(path) do
      module =
        module ||
          case infer_featurizer_type(featurizer_data, repository) do
            {:ok, module} -> module
            {:error, error} -> raise "#{error}, please specify the :module option"
          end

      featurizer = configure(module)
      featurizer = HuggingFace.Transformers.Config.load(featurizer, featurizer_data)
      {:ok, featurizer}
    end
  end

  defp infer_featurizer_type(%{"feature_extractor_type" => class_name}, _repository) do
    case @transformers_class_to_featurizer[class_name] do
      nil ->
        {:error,
         "could not match the class name #{inspect(class_name)} to any of the supported featurizers"}

      module ->
        {:ok, module}
    end
  end

  defp infer_featurizer_type(_featurizer_data, repository) do
    with {:ok, path} <- download(repository, @config_filename),
         {:ok, featurizer_data} <- decode_config(path) do
      case featurizer_data do
        %{"model_type" => model_type} ->
          case @model_type_to_featurizer[model_type] do
            nil ->
              {:error,
               "could not match model type #{inspect(model_type)} to any of the supported featurizers"}

            module ->
              {:ok, module}
          end

        _ ->
          {:error, "could not infer featurizer type from the configuration"}
      end
    end
  end

  @doc """
  Tokenizes and encodes `input` with the given tokenizer.

  ## Options

    * `:add_special_tokens` - whether to add special tokens. Defaults
      to `true`

    * `:pad_direction` - the padding direction, either `:right` or
      `:left`. Defaults to `:right`

    * `:truncate_direction` - the truncation direction, either `:right`
      or `:left`. Defaults to `:right`

    * `:return_attention_mask` - whether to return the attention mask for
      the encoded sequence. Defaults to `true`

    * `:return_token_type_ids` - whether to return token type ids for
      the encoded sequence. Defaults to `true`

    * `:return_special_tokens_mask` - whether to return the special tokens
      mask for the encoded sequence. Defaults to `false`

    * `:return_offsets` - whether to return token offsets for the encoded
      sequence. Defaults to `false`

    * `:length` - applies fixed length padding or truncation to the given
      input if set

  ## Examples

      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
      inputs = Bumblebee.apply_tokenizer(tokenizer, ["The capital of France is [MASK]."])
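
  You can also pass options, for example `:length` for fixed length padding
  or truncation:

      inputs = Bumblebee.apply_tokenizer(tokenizer, "Hello Bumblebee!", length: 16)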

  """
  @doc type: :tokenizer
  @spec apply_tokenizer(
          Bumblebee.Tokenizer.t(),
          Bumblebee.Tokenizer.input() | list(Bumblebee.Tokenizer.input()),
          keyword()
        ) :: any()
  def apply_tokenizer(%module{} = tokenizer, input, opts \\ []) do
    opts =
      Keyword.validate!(opts,
        add_special_tokens: true,
        pad_direction: :right,
        truncate_direction: :right,
        length: nil,
        return_attention_mask: true,
        return_token_type_ids: true,
        return_special_tokens_mask: false,
        return_offsets: false
      )

    module.apply(tokenizer, input, opts)
  end

  @doc """
  Loads a tokenizer from a model repository.

  ## Options

    * `:module` - the tokenizer module. By default it is inferred from
      the configuration files; if that is not possible, it must be
      specified explicitly

  ## Examples

      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

  """
  @doc type: :tokenizer
  @spec load_tokenizer(repository(), keyword()) ::
          {:ok, Bumblebee.Tokenizer.t()} | {:error, String.t()}
  def load_tokenizer(repository, opts \\ []) do
    repository = normalize_repository!(repository)
    opts = Keyword.validate!(opts, [:module])
    module = opts[:module]

    with {:ok, path} <- download(repository, @tokenizer_filename) do
      module =
        module ||
          case infer_tokenizer_type(repository) do
            {:ok, module} -> module
            {:error, error} -> raise "#{error}, please specify the :module option"
          end

      special_tokens_map =
        with {:ok, path} <- download(repository, @tokenizer_special_tokens_filename),
             {:ok, special_tokens_map} <- decode_config(path) do
          special_tokens_map
        else
          _ -> %{}
        end

      tokenizer = struct!(module)

      tokenizer =
        HuggingFace.Transformers.Config.load(tokenizer, %{
          "tokenizer_file" => path,
          "special_tokens_map" => special_tokens_map
        })

      {:ok, tokenizer}
    end
  end

  defp infer_tokenizer_type(repository) do
    with {:ok, path} <- download(repository, @config_filename),
         {:ok, tokenizer_data} <- decode_config(path) do
      case tokenizer_data do
        %{"model_type" => model_type} ->
          case @model_type_to_tokenizer[model_type] do
            nil ->
              {:error,
               "could not match model type #{inspect(model_type)} to any of the supported tokenizers"}

            module ->
              {:ok, module}
          end

        _ ->
          {:error, "could not infer tokenizer type from the model configuration"}
      end
    end
  end

  @doc """
  Initializes state for a new scheduler loop.

  Returns a pair of `{state, timesteps}`, where `state` is an opaque
  container expected by `scheduler_step/4` and `timesteps` is a sequence
  of subsequent timesteps for model forward pass.

  Note that the number of `timesteps` may not match `num_steps` exactly.
  `num_steps` parameterizes sampling points, however depending on the
  method, sampling certain points may require multiple forward passes
  of the model and each element in `timesteps` corresponds to a single
  forward pass.
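
  ## Examples

  For example, with 20 sampling steps and an illustrative sample shape:

      scheduler = Bumblebee.configure(Bumblebee.Diffusion.DdimScheduler)
      {state, timesteps} = Bumblebee.scheduler_init(scheduler, 20, {1, 64, 64, 4})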
  """
  @doc type: :scheduler
  @spec scheduler_init(
          Bumblebee.Scheduler.t(),
          non_neg_integer(),
          tuple()
        ) :: {Bumblebee.Scheduler.state(), Nx.Tensor.t()}
  def scheduler_init(%module{} = scheduler, num_steps, sample_shape) do
    module.init(scheduler, num_steps, sample_shape)
  end

  @doc """
  Predicts sample at the previous timestep using the given scheduler.

  Takes the current `sample` and `prediction` (usually noise) returned
  by the model at the current timestep. Returns `{state, prev_sample}`,
  where `state` is the updated scheduler loop state and `prev_sample`
  is the predicted sample at the previous timestep.

  Note that some schedulers require several forward passes of the model
  (and a couple calls to this function) to make an actual prediction for
  the previous sample.
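
  ## Examples

  A sketch of a scheduler loop, where `predict_noise/2` is a hypothetical
  function standing in for the model forward pass:

      {state, timesteps} = Bumblebee.scheduler_init(scheduler, num_steps, Nx.shape(sample))

      {_state, sample} =
        for timestep <- Nx.to_flat_list(timesteps), reduce: {state, sample} do
          {state, sample} ->
            # The model predicts noise for the current sample and timestep
            prediction = predict_noise(sample, timestep)
            Bumblebee.scheduler_step(scheduler, state, sample, prediction)
        end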
  """
  @doc type: :scheduler
  @spec scheduler_step(
          Bumblebee.Scheduler.t(),
          Bumblebee.Scheduler.state(),
          Nx.Tensor.t(),
          Nx.Tensor.t()
        ) :: {Bumblebee.Scheduler.state(), Nx.Tensor.t()}
  def scheduler_step(%module{} = scheduler, state, sample, prediction) do
    module.step(scheduler, state, sample, prediction)
  end

  @doc """
  Loads a scheduler from a model repository.

  ## Options

    * `:module` - the scheduler module. By default it is inferred
      from the scheduler configuration file; if that is not possible,
      it must be specified explicitly

  ## Examples

      {:ok, scheduler} =
        Bumblebee.load_scheduler({:hf, "CompVis/stable-diffusion-v1-4", subdir: "scheduler"})

  """
  @doc type: :scheduler
  @spec load_scheduler(repository(), keyword()) ::
          {:ok, Bumblebee.Scheduler.t()} | {:error, String.t()}
  def load_scheduler(repository, opts \\ []) do
    repository = normalize_repository!(repository)
    opts = Keyword.validate!(opts, [:module])
    module = opts[:module]

    with {:ok, path} <- download(repository, @scheduler_filename),
         {:ok, scheduler_data} <- decode_config(path) do
      module =
        module ||
          case infer_scheduler_type(scheduler_data) do
            {:ok, module} -> module
            {:error, error} -> raise "#{error}, please specify the :module option"
          end

      scheduler = configure(module)
      scheduler = HuggingFace.Transformers.Config.load(scheduler, scheduler_data)
      {:ok, scheduler}
    end
  end

  defp infer_scheduler_type(%{"_class_name" => class_name}) do
    case @diffusers_class_to_scheduler[class_name] do
      nil ->
        {:error,
         "could not match the class name #{inspect(class_name)} to any of the supported schedulers"}

      module ->
        {:ok, module}
    end
  end

  defp infer_scheduler_type(_scheduler_data) do
    {:error, "could not infer featurizer type from the configuration"}
  end

  defp download({:local, dir}, filename) do
    path = Path.join(dir, filename)

    if File.exists?(path) do
      {:ok, path}
    else
      {:error, "local file #{inspect(path)} does not exist"}
    end
  end

  defp download({:hf, repository_id, opts}, filename) do
    revision = opts[:revision]
    cache_dir = opts[:cache_dir]
    auth_token = opts[:auth_token]
    subdir = opts[:subdir]

    filename = if subdir, do: subdir <> "/" <> filename, else: filename

    url = HuggingFace.Hub.file_url(repository_id, filename, revision)

    HuggingFace.Hub.cached_download(url, cache_dir: cache_dir, auth_token: auth_token)
  end

  defp normalize_repository!({:hf, repository_id}) when is_binary(repository_id) do
    {:hf, repository_id, []}
  end

  defp normalize_repository!({:hf, repository_id, opts}) when is_binary(repository_id) do
    opts = Keyword.validate!(opts, [:revision, :cache_dir, :auth_token, :subdir])
    {:hf, repository_id, opts}
  end

  defp normalize_repository!({:local, dir}) when is_binary(dir) do
    {:local, dir}
  end

  defp normalize_repository!(other) do
    raise ArgumentError,
          "expected repository to be either {:hf, repository_id}, {:hf, repository_id, options}" <>
            " or {:local, directory}, got: #{inspect(other)}"
  end
end