lib/bumblebee.ex

defmodule Bumblebee do
  # Declared as an external resource so the module is recompiled whenever
  # README.md changes: part of the @moduledoc below is read from it at
  # compile time.
  @external_resource "README.md"

  # README.md must contain exactly two "<!-- Docs -->" markers; the text
  # between them is interpolated at the end of @moduledoc. Any other number
  # of markers makes this match fail at compile time.
  [_, readme_docs, _] =
    "README.md"
    |> File.read!()
    |> String.split("<!-- Docs -->")

  @moduledoc """
  Pre-trained `Axon` models for easy inference and boosted training.

  Bumblebee provides state-of-the-art, configurable `Axon` models. On
  top of that, it streamlines the process of loading pre-trained models
  by integrating with Hugging Face Hub and [🤗 Transformers](https://github.com/huggingface/transformers).

  ## Usage

  You can load one of the supported models by specifying the model repository:

      {:ok, model_info} = Bumblebee.load_model({:hf, "bert-base-uncased"})
      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

  Then you are ready to make predictions:

      inputs = Bumblebee.apply_tokenizer(tokenizer, "Hello Bumblebee!")
      outputs = Axon.predict(model_info.model, model_info.params, inputs)

  ### Tasks

  On top of bare models, Bumblebee provides a number of **"servings"**
  that act as end-to-end pipelines for specific tasks.

      serving = Bumblebee.Text.fill_mask(model_info, tokenizer)
      Nx.Serving.run(serving, "The capital of [MASK] is Paris.")
      #=> %{
      #=>   predictions: [
      #=>     %{score: 0.9279842972755432, token: "france"},
      #=>     %{score: 0.008412551134824753, token: "brittany"},
      #=>     %{score: 0.007433671969920397, token: "algeria"},
      #=>     %{score: 0.004957548808306456, token: "department"},
      #=>     %{score: 0.004369721747934818, token: "reunion"}
      #=>   ]
      #=> }

  As you can see, the **serving** takes care of pre-processing the
  text input, runs the model and also post-processes its output into
  more structured data. In the above example we `run` the serving on
  the fly; however, for production usage you can start the serving as
  a process and it will automatically batch requests from multiple
  clients. Processing inputs in batches is usually much more efficient,
  since it can take advantage of the parallel capabilities of the target
  device, which is particularly relevant in the case of a GPU. For more
  details read the `Nx.Serving` docs.

  For more examples see the [Examples](examples.livemd) notebook.

  > #### Note {: .info}
  >
  > The models are generally large, so make sure to configure an efficient
  > `Nx` backend, such as `EXLA` or `Torchx`.

  #{readme_docs}
  """

  alias Bumblebee.HuggingFace

  # Well-known file names looked up in a model repository's file listing.

  # Model configuration, also consulted to infer featurizer/tokenizer types.
  @config_filename "config.json"
  @featurizer_filename "preprocessor_config.json"
  # Only the Rust-compatible tokenizer.json format is loadable; the presence
  # of tokenizer_config.json alone produces a dedicated error message.
  @tokenizer_filename "tokenizer.json"
  @tokenizer_config_filename "tokenizer_config.json"
  @tokenizer_special_tokens_filename "special_tokens_map.json"
  @generation_filename "generation_config.json"
  @scheduler_filename "scheduler_config.json"
  # Parameter files, in order of preference. Either may instead be sharded,
  # in which case the repository contains "<name>.index.json".
  @pytorch_params_filename "pytorch_model.bin"
  @safetensors_params_filename "model.safetensors"

  # Maps the class name found in a model config (the single "architectures"
  # entry, or "_class_name", see infer_model_type/1) to the Bumblebee module
  # implementing it and the architecture to build. Keys must match the
  # upstream class names exactly, otherwise the lookup silently fails.
  @transformers_class_to_model %{
    "AlbertForMaskedLM" => {Bumblebee.Text.Albert, :for_masked_language_modeling},
    "AlbertForMultipleChoice" => {Bumblebee.Text.Albert, :for_multiple_choice},
    "AlbertForPreTraining" => {Bumblebee.Text.Albert, :for_pre_training},
    "AlbertForQuestionAnswering" => {Bumblebee.Text.Albert, :for_question_answering},
    "AlbertForSequenceClassification" => {Bumblebee.Text.Albert, :for_sequence_classification},
    "AlbertForTokenClassification" => {Bumblebee.Text.Albert, :for_token_classification},
    "AlbertModel" => {Bumblebee.Text.Albert, :base},
    "BartForCausalLM" => {Bumblebee.Text.Bart, :for_causal_language_modeling},
    "BartForConditionalGeneration" => {Bumblebee.Text.Bart, :for_conditional_generation},
    "BartForQuestionAnswering" => {Bumblebee.Text.Bart, :for_question_answering},
    "BartForSequenceClassification" => {Bumblebee.Text.Bart, :for_sequence_classification},
    "BartModel" => {Bumblebee.Text.Bart, :base},
    "BertForMaskedLM" => {Bumblebee.Text.Bert, :for_masked_language_modeling},
    "BertForMultipleChoice" => {Bumblebee.Text.Bert, :for_multiple_choice},
    "BertForNextSentencePrediction" => {Bumblebee.Text.Bert, :for_next_sentence_prediction},
    "BertForPreTraining" => {Bumblebee.Text.Bert, :for_pre_training},
    "BertForQuestionAnswering" => {Bumblebee.Text.Bert, :for_question_answering},
    "BertForSequenceClassification" => {Bumblebee.Text.Bert, :for_sequence_classification},
    "BertForTokenClassification" => {Bumblebee.Text.Bert, :for_token_classification},
    "BertLMHeadModel" => {Bumblebee.Text.Bert, :for_causal_language_modeling},
    "BertModel" => {Bumblebee.Text.Bert, :base},
    "BlenderbotForConditionalGeneration" =>
      {Bumblebee.Text.Blenderbot, :for_conditional_generation},
    "BlenderbotModel" => {Bumblebee.Text.Blenderbot, :base},
    "BlipForConditionalGeneration" => {Bumblebee.Multimodal.Blip, :for_conditional_generation},
    # These models are just RoBERTa models, but the config will list them as CamemBERT
    "CamembertModel" => {Bumblebee.Text.Roberta, :base},
    "CamembertForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
    "CamembertForSequenceClassification" =>
      {Bumblebee.Text.Roberta, :for_sequence_classification},
    "CamembertForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice},
    "CamembertForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
    "CamembertForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering},
    "CLIPModel" => {Bumblebee.Multimodal.Clip, :base},
    "CLIPTextModel" => {Bumblebee.Text.ClipText, :base},
    "CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base},
    "ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification},
    "ConvNextModel" => {Bumblebee.Vision.ConvNext, :base},
    "DeiTForImageClassification" => {Bumblebee.Vision.Deit, :for_image_classification},
    "DeiTForImageClassificationWithTeacher" =>
      {Bumblebee.Vision.Deit, :for_image_classification_with_teacher},
    "DeiTForMaskedImageModeling" => {Bumblebee.Vision.Deit, :for_masked_image_modeling},
    "DeiTModel" => {Bumblebee.Vision.Deit, :base},
    "DistilBertModel" => {Bumblebee.Text.Distilbert, :base},
    "DistilBertForMaskedLM" => {Bumblebee.Text.Distilbert, :for_masked_language_modeling},
    "DistilBertForSequenceClassification" =>
      {Bumblebee.Text.Distilbert, :for_sequence_classification},
    "DistilBertForQuestionAnswering" => {Bumblebee.Text.Distilbert, :for_question_answering},
    "DistilBertForTokenClassification" => {Bumblebee.Text.Distilbert, :for_token_classification},
    "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
    "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
    "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
    # Module name previously misspelled as BumbleBee.Text.Gpt2 (non-existent module)
    "GPT2Model" => {Bumblebee.Text.Gpt2, :base},
    "GPTNeoXModel" => {Bumblebee.Text.GptNeoX, :base},
    "GPTNeoXForCausalLM" => {Bumblebee.Text.GptNeoX, :for_causal_language_modeling},
    "GPTNeoXForSequenceClassification" => {Bumblebee.Text.GptNeoX, :for_sequence_classification},
    "GPTNeoXForTokenClassification" => {Bumblebee.Text.GptNeoX, :for_token_classification},
    # The upstream class is LayoutLMForMaskedLM (not ...ForMaskedLanguageModeling)
    "LayoutLMForMaskedLM" => {Bumblebee.Multimodal.LayoutLm, :for_masked_language_modeling},
    "LayoutLMForQuestionAnswering" => {Bumblebee.Multimodal.LayoutLm, :for_question_answering},
    "LayoutLMForSequenceClassification" =>
      {Bumblebee.Multimodal.LayoutLm, :for_sequence_classification},
    "LayoutLMForTokenClassification" =>
      {Bumblebee.Multimodal.LayoutLm, :for_token_classification},
    "LayoutLMModel" => {Bumblebee.Multimodal.LayoutLm, :base},
    "LlamaModel" => {Bumblebee.Text.Llama, :base},
    "LlamaForCausalLM" => {Bumblebee.Text.Llama, :for_causal_language_modeling},
    "LlamaForSequenceClassification" => {Bumblebee.Text.Llama, :for_sequence_classification},
    "MBartForCausalLM" => {Bumblebee.Text.Mbart, :for_causal_language_modeling},
    "MBartForConditionalGeneration" => {Bumblebee.Text.Mbart, :for_conditional_generation},
    "MBartForQuestionAnswering" => {Bumblebee.Text.Mbart, :for_question_answering},
    "MBartForSequenceClassification" => {Bumblebee.Text.Mbart, :for_sequence_classification},
    "MBartModel" => {Bumblebee.Text.Mbart, :base},
    "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
    "ResNetModel" => {Bumblebee.Vision.ResNet, :base},
    "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
    "RobertaForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice},
    "RobertaForPreTraining" => {Bumblebee.Text.Roberta, :for_pre_training},
    "RobertaForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering},
    "RobertaForSequenceClassification" => {Bumblebee.Text.Roberta, :for_sequence_classification},
    "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
    "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
    "RobertaModel" => {Bumblebee.Text.Roberta, :base},
    "T5Model" => {Bumblebee.Text.T5, :base},
    "T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation},
    "T5EncoderModel" => {Bumblebee.Text.T5, :encoder},
    "ViTForImageClassification" => {Bumblebee.Vision.Vit, :for_image_classification},
    "ViTForMaskedImageModeling" => {Bumblebee.Vision.Vit, :for_masked_image_modeling},
    "ViTModel" => {Bumblebee.Vision.Vit, :base},
    "WhisperModel" => {Bumblebee.Audio.Whisper, :base},
    "WhisperForConditionalGeneration" => {Bumblebee.Audio.Whisper, :for_conditional_generation},
    # These models are just RoBERTa models, but the config will list them as XLM-RoBERTa
    "XLMRobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
    "XLMRobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
    "XLMRobertaForMultipleChoice" => {Bumblebee.Text.Roberta, :for_multiple_choice},
    "XLMRobertaForQuestionAnswering" => {Bumblebee.Text.Roberta, :for_question_answering},
    "XLMRobertaForSequenceClassification" =>
      {Bumblebee.Text.Roberta, :for_sequence_classification},
    "XLMRobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
    "XLMRobertaModel" => {Bumblebee.Text.Roberta, :base},
    # Diffusers
    "AutoencoderKL" => {Bumblebee.Diffusion.VaeKl, :base},
    "StableDiffusionSafetyChecker" => {Bumblebee.Diffusion.StableDiffusion.SafetyChecker, :base},
    "UNet2DConditionModel" => {Bumblebee.Diffusion.UNet2DConditional, :base}
  }

  # Featurizer modules keyed by the "feature_extractor_type" entry of the
  # preprocessor config (see infer_featurizer_type/3).
  @transformers_class_to_featurizer %{
    "CLIPFeatureExtractor" => Bumblebee.Vision.ClipFeaturizer,
    "ConvNextFeatureExtractor" => Bumblebee.Vision.ConvNextFeaturizer,
    "DeiTFeatureExtractor" => Bumblebee.Vision.DeitFeaturizer,
    "ViTFeatureExtractor" => Bumblebee.Vision.VitFeaturizer,
    "WhisperFeatureExtractor" => Bumblebee.Audio.WhisperFeaturizer
  }

  # Featurizer modules keyed by the "image_processor_type" entry of the
  # preprocessor config, consulted when "feature_extractor_type" is absent.
  @transformers_image_processor_type_to_featurizer %{
    "BlipImageProcessor" => Bumblebee.Vision.BlipFeaturizer
  }

  # Fallback featurizer lookup keyed by the "model_type" entry of the model
  # config.json, used when the preprocessor config names no featurizer class.
  @model_type_to_featurizer %{
    "convnext" => Bumblebee.Vision.ConvNextFeaturizer,
    "deit" => Bumblebee.Vision.DeitFeaturizer,
    # ResNet checkpoints are handled by the ConvNeXt featurizer
    "resnet" => Bumblebee.Vision.ConvNextFeaturizer,
    "vit" => Bumblebee.Vision.VitFeaturizer,
    "whisper" => Bumblebee.Audio.WhisperFeaturizer
  }

  # Tokenizer modules keyed by the "model_type" entry of the model
  # config.json (see infer_tokenizer_type/2). Some model types share an
  # implementation, e.g. "blip" uses the BERT tokenizer.
  @model_type_to_tokenizer %{
    "albert" => Bumblebee.Text.AlbertTokenizer,
    "bart" => Bumblebee.Text.BartTokenizer,
    "bert" => Bumblebee.Text.BertTokenizer,
    "blenderbot" => Bumblebee.Text.BlenderbotTokenizer,
    "blip" => Bumblebee.Text.BertTokenizer,
    "distilbert" => Bumblebee.Text.DistilbertTokenizer,
    "camembert" => Bumblebee.Text.CamembertTokenizer,
    "clip" => Bumblebee.Text.ClipTokenizer,
    "gpt_neox" => Bumblebee.Text.GptNeoXTokenizer,
    "gpt2" => Bumblebee.Text.Gpt2Tokenizer,
    "layoutlm" => Bumblebee.Text.LayoutLmTokenizer,
    "llama" => Bumblebee.Text.LlamaTokenizer,
    "mbart" => Bumblebee.Text.MbartTokenizer,
    "roberta" => Bumblebee.Text.RobertaTokenizer,
    "t5" => Bumblebee.Text.T5Tokenizer,
    "whisper" => Bumblebee.Text.WhisperTokenizer,
    "xlm-roberta" => Bumblebee.Text.XlmRobertaTokenizer
  }

  # Scheduler modules keyed by diffusers scheduler class name. Not used in
  # this section of the file; presumably matched against a scheduler
  # config's class name elsewhere (TODO confirm).
  @diffusers_class_to_scheduler %{
    "DDIMScheduler" => Bumblebee.Diffusion.DdimScheduler,
    "PNDMScheduler" => Bumblebee.Diffusion.PndmScheduler
  }

  # The public loaders in this file pass a repository through
  # normalize_repository!/1 (not shown in this section) before using it.
  @typedoc """
  A location to fetch model files from.

  Can be either:

    * `{:hf, repository_id}` - the repository on Hugging Face. Options
      may be passed as the third element:

        * `:revision` - the specific model version to use, it can be
          any valid git identifier, such as branch name, tag name, or
          a commit hash

        * `:cache_dir` - the directory to store the downloaded files
          in. Defaults to the standard cache location for the given
          operating system. You can also configure it globally by
          setting the `BUMBLEBEE_CACHE_DIR` environment variable

        * `:offline` - if `true`, only cached files are accessed and
          missing files result in an error. You can also configure it
          globally by setting the `BUMBLEBEE_OFFLINE` environment
          variable to `true`

        * `:auth_token` - the token to use as HTTP bearer authorization
          for remote files

        * `:subdir` - the directory within the repository where the
          files are located

    * `{:local, directory}` - the directory containing model files

  """
  @type repository :: {:hf, String.t()} | {:hf, String.t(), keyword()} | {:local, Path.t()}

  @typedoc """
  A model together with its state and metadata.

  Holds the `Axon` model under `:model`, its parameters under `:params`
  and the specification the model was built from under `:spec`.
  """
  @type model_info :: %{
          model: Axon.t(),
          params: map(),
          spec: Bumblebee.ModelSpec.t()
        }

  @doc """
  Builds a new configuration, or updates an existing one, with the given options.

  `config` is either a module that supports configuration or an already
  built configuration struct. Configurable entities usually include:

    * model specification (`Bumblebee.ModelSpec`)

    * featurizer (`Bumblebee.Featurizer`)

    * scheduler (`Bumblebee.Scheduler`)

  ## Examples

  Passing a module builds a fresh configuration:

      featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer)
      spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :for_image_classification)

  Passing an existing configuration updates it:

      featurizer = Bumblebee.configure(featurizer, resize_method: :bilinear)
      spec = Bumblebee.configure(spec, embedding_size: 128)

  """
  @spec configure(module() | Bumblebee.Configurable.t(), keyword()) :: Bumblebee.Configurable.t()
  def configure(config, options \\ []) do
    # struct!/1 accepts either a module (building its default struct) or an
    # existing struct, so both call styles funnel into the module's config/2.
    configurable = struct!(config)
    %config_module{} = configurable
    config_module.config(configurable, options)
  end

  @doc """
  Builds an `Axon` model from the given model specification.

  The model is produced by the `model/1` function of the specification's
  own module.

  ## Example

      spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128)
      model = Bumblebee.build_model(spec)

  """
  @doc type: :model
  @spec build_model(Bumblebee.ModelSpec.t()) :: Axon.t()
  def build_model(%spec_module{} = spec), do: spec_module.model(spec)

  @doc """
  Loads a model specification from a model repository.

  The specification is built from the repository's `config.json`.

  ## Options

    * `:module` - the model specification module. By default it is
      inferred from the configuration file, if that is not possible,
      it must be specified explicitly

    * `:architecture` - the model architecture, must be supported by
      `:module`. By default it is inferred from the configuration file

  ## Examples

      {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"})

  You can explicitly specify a different architecture:

      {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"}, architecture: :base)

  """
  @doc type: :model
  @spec load_spec(repository(), keyword()) ::
          {:ok, Bumblebee.ModelSpec.t()} | {:error, String.t()}
  def load_spec(repository, opts \\ []) do
    repository = normalize_repository!(repository)
    opts = Keyword.validate!(opts, [:module, :architecture])

    # Errors from listing the repository files are returned unchanged.
    with {:ok, repo_files} <- get_repo_files(repository) do
      do_load_spec(repository, repo_files, opts[:module], opts[:architecture])
    end
  end

  # Builds a model specification from the repository's config file.
  #
  # An explicitly given `module`/`architecture` takes precedence over the
  # values inferred from the config. Raises ArgumentError when no module
  # can be determined, when the architecture is not supported by the module,
  # or when the repository has no config file at all.
  defp do_load_spec(repository, %{@config_filename => etag}, module, architecture) do
    with {:ok, path} <- download(repository, @config_filename, etag),
         {:ok, spec_data} <- decode_config(path) do
      {inferred_module, inferred_architecture, inference_error} =
        case infer_model_type(spec_data) do
          {:ok, found_module, found_architecture} -> {found_module, found_architecture, nil}
          {:error, error} -> {nil, nil, error}
        end

      architecture = architecture || inferred_architecture

      module =
        module || inferred_module ||
          raise(
            ArgumentError,
            "#{inference_error}, please specify the :module and :architecture options"
          )

      supported = module.architectures()

      if architecture && architecture not in supported do
        raise ArgumentError,
              "expected architecture to be one of: #{Enum.map_join(supported, ", ", &inspect/1)}, but got: #{inspect(architecture)}"
      end

      spec_options = if architecture, do: [architecture: architecture], else: []

      spec =
        module
        |> configure(spec_options)
        |> HuggingFace.Transformers.Config.load(spec_data)

      {:ok, spec}
    end
  end

  defp do_load_spec(_repository, %{}, _module, _architecture) do
    raise ArgumentError,
          "no config file found in the given repository. Please refer to Bumblebee" <>
            " README to learn about repositories and supported models"
  end

  # Reads and JSON-decodes the file at `path`. A decoding failure becomes an
  # `{:error, message}` tuple; a failure to read the file still raises.
  defp decode_config(path) do
    case Jason.decode(File.read!(path)) do
      {:ok, _data} = decoded -> decoded
      _ -> {:error, "failed to parse the config file, it is not a valid JSON"}
    end
  end

  # Resolves a config's class name to `{:ok, module, architecture}`. The class
  # name is taken from a single-element "architectures" list or, failing
  # that, from "_class_name".
  defp infer_model_type(%{"architectures" => [class_name]}) do
    case Map.fetch(@transformers_class_to_model, class_name) do
      {:ok, {module, architecture}} ->
        {:ok, module, architecture}

      :error ->
        {:error,
         "could not match the class name #{inspect(class_name)} to any of the supported models"}
    end
  end

  defp infer_model_type(%{"_class_name" => class_name}),
    do: infer_model_type(%{"architectures" => [class_name]})

  defp infer_model_type(_spec_data),
    do: {:error, "could not infer model type from the configuration"}

  @doc """
  Loads a pre-trained model from a model repository.

  Returns the built `Axon` model together with its parameters and
  specification, see `t:model_info/0`.

  ## Options

    * `:spec` - the model specification to use when building the model.
      By default the specification is loaded using `load_spec/2`

    * `:module` - the model specification module. By default it is
      inferred from the configuration file, if that is not possible,
      it must be specified explicitly

    * `:architecture` - the model architecture, must be supported by
      `:module`. By default it is inferred from the configuration file

    * `:params_filename` - the file with the model parameters to be loaded

    * `:log_params_diff` - whether to log missing, mismatched and unused
      parameters. By default diff is logged only if some parameters
      cannot be loaded

    * `:backend` - the backend to allocate the tensors on. It is either
      an atom or a tuple in the shape `{backend, options}`

  ## Examples

  By default the model type is inferred from configuration, so loading
  is as simple as:

      {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"})
      %{model: model, params: params, spec: spec} = resnet

  You can explicitly specify a different architecture, in which case
  matching parameters are still loaded:

      {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, architecture: :base)

  To further customize the model, you can also pass the specification:

      {:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"})
      spec = Bumblebee.configure(spec, num_labels: 10)
      {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)

  """
  @doc type: :model
  @spec load_model(repository(), keyword()) :: {:ok, model_info()} | {:error, String.t()}
  def load_model(repository, opts \\ []) do
    repository = normalize_repository!(repository)

    allowed_options = [:spec, :module, :architecture, :params_filename, :backend, :log_params_diff]
    opts = Keyword.validate!(opts, allowed_options)

    # Building the model cannot fail with an error tuple, so it is a plain
    # binding rather than a `<-` clause.
    with {:ok, repo_files} <- get_repo_files(repository),
         {:ok, spec} <- maybe_load_model_spec(opts, repository, repo_files),
         model = build_model(spec),
         {:ok, params} <- load_params(spec, model, repository, repo_files, opts) do
      {:ok, %{model: model, params: params, spec: spec}}
    end
  end

  # Uses the caller-supplied `:spec` when it is truthy; otherwise loads the
  # specification from the repository config, honouring `:module` and
  # `:architecture`.
  defp maybe_load_model_spec(opts, repository, repo_files) do
    case opts[:spec] do
      falsy when falsy in [nil, false] ->
        do_load_spec(repository, repo_files, opts[:module], opts[:architecture])

      spec ->
        {:ok, spec}
    end
  end

  # Downloads the parameter file(s) for `model` and loads them, using the
  # Transformers params mapping derived from `spec`. Only `:backend` and
  # `:log_params_diff` from `opts` are forwarded to the loader.
  defp load_params(%spec_module{} = spec, model, repository, repo_files, opts) do
    input_template = spec_module.input_template(spec)
    params_mapping = Bumblebee.HuggingFace.Transformers.Model.params_mapping(spec)

    {params_filename, sharded?} = infer_params_filename(repo_files, opts[:params_filename])
    # The loader is chosen by extension: safetensors, or PyTorch for anything else.
    loader_fun = params_file_loader_fun(Path.extname(params_filename))

    with {:ok, paths} <- download_params_files(repository, repo_files, params_filename, sharded?) do
      loader_opts =
        [params_mapping: params_mapping, loader_fun: loader_fun] ++
          Keyword.take(opts, [:backend, :log_params_diff])

      {:ok, Bumblebee.Conversion.PyTorch.load_params!(model, input_template, paths, loader_opts)}
    end
  end

  # Resolves which parameters file to load and returns `{filename, sharded?}`.
  # `sharded?` is true when only "<filename>.index.json" is present, meaning
  # the parameters are split across the files listed in that index.
  #
  # Without an explicit filename, PyTorch files take precedence over
  # safetensors. For each format, a single file wins over a sharded one.
  defp infer_params_filename(repo_files, nil = _filename) do
    candidates = [@pytorch_params_filename, @safetensors_params_filename]

    Enum.find_value(candidates, &locate_params_file(repo_files, &1)) ||
      raise(
        ArgumentError,
        "none of the expected parameters files found in the repository." <>
          " If the file exists under an unusual name, try specifying :params_filename"
      )
  end

  defp infer_params_filename(repo_files, filename) do
    locate_params_file(repo_files, filename) ||
      raise(ArgumentError, "could not find file #{inspect(filename)} in the repository")
  end

  # Returns `{filename, false}` if the file itself is listed, `{filename, true}`
  # if only its ".index.json" is listed, and `nil` otherwise.
  defp locate_params_file(repo_files, filename) do
    cond do
      Map.has_key?(repo_files, filename) -> {filename, false}
      Map.has_key?(repo_files, filename <> ".index.json") -> {filename, true}
      true -> nil
    end
  end

  # Downloads the parameter file(s) and returns `{:ok, paths}`, or the first
  # download/decode error unchanged.
  #
  # For a sharded checkpoint, the ".index.json" file is fetched first and
  # every distinct file referenced in its "weight_map" is then downloaded.
  defp download_params_files(repository, repo_files, filename, false = _sharded?) do
    case download(repository, filename, repo_files[filename]) do
      {:ok, path} -> {:ok, [path]}
      other -> other
    end
  end

  defp download_params_files(repository, repo_files, filename, true = _sharded?) do
    index_filename = filename <> ".index.json"

    with {:ok, index_path} <- download(repository, index_filename, repo_files[index_filename]),
         {:ok, index} <- decode_config(index_path) do
      shard_filenames = for {_param_name, shard} <- index["weight_map"], uniq: true, do: shard
      download_shards(repository, repo_files, shard_filenames, [])
    end
  end

  # Downloads shards one by one, stopping at the first failure. Paths are
  # prepended, so they come back in reverse shard order.
  defp download_shards(_repository, _repo_files, [], paths), do: {:ok, paths}

  defp download_shards(repository, repo_files, [shard | rest], paths) do
    case download(repository, shard, repo_files[shard]) do
      {:ok, path} -> download_shards(repository, repo_files, rest, [path | paths])
      error -> error
    end
  end

  # Picks the reader for a parameters file from its extension: safetensors
  # files use Safetensors, anything else is treated as a PyTorch checkpoint.
  defp params_file_loader_fun(extension) do
    case extension do
      ".safetensors" -> &Safetensors.read!/1
      _other -> &Bumblebee.Conversion.PyTorch.Loader.load!/1
    end
  end

  @doc """
  Featurizes `input` with the given featurizer.

  The input is first turned into a batch by the featurizer's
  `process_input/2`. If the featurizer module also implements
  `process_batch/2`, the batch is then passed through it under
  `Nx.Defn.jit_apply/3`.

  ## Options

    * `:defn_options` - the options for JIT compilation. Note that
      this is only relevant for featurizers implemented with Nx.
      Defaults to `[]`

  ## Examples

      featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer)
      {:ok, img} = StbImage.read_file(path)
      inputs = Bumblebee.apply_featurizer(featurizer, [img])

  """
  @doc type: :featurizer
  @spec apply_featurizer(Bumblebee.Featurizer.t(), any(), keyword()) :: any()
  def apply_featurizer(%module{} = featurizer, input, opts \\ []) do
    opts = Keyword.validate!(opts, defn_options: [])
    batch = module.process_input(featurizer, input)

    # process_batch/2 is optional. function_exported?/3 does not load
    # modules by itself, hence the explicit Code.ensure_loaded?/1.
    batch_step? = Code.ensure_loaded?(module) and function_exported?(module, :process_batch, 2)

    if batch_step? do
      batch_fun = fn prepared -> module.process_batch(featurizer, prepared) end
      Nx.Defn.jit_apply(batch_fun, [batch], opts[:defn_options])
    else
      batch
    end
  end

  @doc """
  Loads a featurizer from a model repository.

  The featurizer is configured from the repository's
  `preprocessor_config.json`.

  ## Options

    * `:module` - the featurizer module. By default it is inferred
      from the preprocessor configuration file, if that is not possible,
      it must be specified explicitly

  ## Examples

      {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"})

  """
  @doc type: :featurizer
  @spec load_featurizer(repository(), keyword()) ::
          {:ok, Bumblebee.Featurizer.t()} | {:error, String.t()}
  def load_featurizer(repository, opts \\ []) do
    repository = normalize_repository!(repository)
    opts = Keyword.validate!(opts, [:module])

    case get_repo_files(repository) do
      {:ok, %{@featurizer_filename => etag} = repo_files} ->
        with {:ok, path} <- download(repository, @featurizer_filename, etag),
             {:ok, featurizer_data} <- decode_config(path) do
          module =
            opts[:module] || infer_featurizer_module!(featurizer_data, repository, repo_files)

          featurizer =
            module
            |> configure()
            |> HuggingFace.Transformers.Config.load(featurizer_data)

          {:ok, featurizer}
        end

      {:ok, %{}} ->
        raise ArgumentError, "no featurizer found in the given repository"

      {:error, message} ->
        {:error, message}
    end
  end

  # Infers the featurizer module, raising when inference fails.
  defp infer_featurizer_module!(featurizer_data, repository, repo_files) do
    case infer_featurizer_type(featurizer_data, repository, repo_files) do
      {:ok, module} -> module
      {:error, error} -> raise ArgumentError, "#{error}, please specify the :module option"
    end
  end

  # Infers the featurizer module. The preprocessor config's
  # "feature_extractor_type" is checked first, then "image_processor_type".
  # If neither is present, the "model_type" from the repository's
  # config.json is used.
  defp infer_featurizer_type(%{"feature_extractor_type" => class_name}, _repository, _repo_files) do
    featurizer_for_class(@transformers_class_to_featurizer, class_name)
  end

  defp infer_featurizer_type(%{"image_processor_type" => class_name}, _repository, _repo_files) do
    featurizer_for_class(@transformers_image_processor_type_to_featurizer, class_name)
  end

  defp infer_featurizer_type(_featurizer_data, repository, repo_files) do
    with {:ok, path} <- download(repository, @config_filename, repo_files[@config_filename]),
         {:ok, model_config} <- decode_config(path) do
      featurizer_for_model_config(model_config)
    end
  end

  # Looks up a featurizer class name in the given mapping.
  defp featurizer_for_class(mapping, class_name) do
    case Map.fetch(mapping, class_name) do
      {:ok, module} ->
        {:ok, module}

      :error ->
        {:error,
         "could not match the class name #{inspect(class_name)} to any of the supported featurizers"}
    end
  end

  # Looks up a featurizer by the model config's "model_type".
  defp featurizer_for_model_config(%{"model_type" => model_type}) do
    case Map.fetch(@model_type_to_featurizer, model_type) do
      {:ok, module} ->
        {:ok, module}

      :error ->
        {:error,
         "could not match model type #{inspect(model_type)} to any of the supported featurizers"}
    end
  end

  defp featurizer_for_model_config(_model_config) do
    {:error, "could not infer featurizer type from the configuration"}
  end

  @doc """
  Tokenizes and encodes `input` with the given tokenizer.

  `input` may be a single input or a list of inputs.

  ## Options

    * `:add_special_tokens` - whether to add special tokens. Defaults
      to `true`

    * `:pad_direction` - the padding direction, either `:right` or
      `:left`. Defaults to `:right`

    * `:truncate_direction` - the truncation direction. Defaults to
      `:right`

    * `:return_attention_mask` - whether to return attention mask for
      encoded sequence. Defaults to `true`

    * `:return_token_type_ids` - whether to return token type ids for
      encoded sequence. Defaults to `true`

    * `:return_special_tokens_mask` - whether to return special tokens
      mask for encoded sequence. Defaults to `false`

    * `:return_offsets` - whether to return token offsets for encoded
      sequence. Defaults to `false`

    * `:length` - applies fixed length padding or truncation to the
      given input if set. Can be either a specific number or a list
      of numbers. When a list is given, the smallest number that
      exceeds all input lengths is used as the padding length

  ## Examples

      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
      inputs = Bumblebee.apply_tokenizer(tokenizer, ["The capital of France is [MASK]."])

  """
  @doc type: :tokenizer
  @spec apply_tokenizer(
          Bumblebee.Tokenizer.t(),
          Bumblebee.Tokenizer.input() | list(Bumblebee.Tokenizer.input()),
          keyword()
        ) :: any()
  def apply_tokenizer(%module{} = tokenizer, input, opts \\ []) do
    # Keyword.validate!/2 rejects unknown options and fills in every default,
    # so tokenizer implementations always receive the complete option set.
    opts =
      Keyword.validate!(opts,
        add_special_tokens: true,
        pad_direction: :right,
        truncate_direction: :right,
        length: nil,
        return_attention_mask: true,
        return_token_type_ids: true,
        return_special_tokens_mask: false,
        return_offsets: false
      )

    module.apply(tokenizer, input, opts)
  end

  @doc """
  Loads a tokenizer from a model repository.

  The repository must contain a Rust-compatible `tokenizer.json` file.
  A `special_tokens_map.json` file is also read when present.

  ## Options

    * `:module` - the tokenizer module. By default it is inferred from
      the configuration files, if that is not possible, it must be
      specified explicitly

  ## Examples

      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

  """
  @doc type: :tokenizer
  @spec load_tokenizer(repository(), keyword()) ::
          {:ok, Bumblebee.Tokenizer.t()} | {:error, String.t()}
  def load_tokenizer(repository, opts \\ []) do
    repository = normalize_repository!(repository)
    opts = Keyword.validate!(opts, [:module])

    case get_repo_files(repository) do
      {:ok, %{@tokenizer_filename => etag} = repo_files} ->
        with {:ok, tokenizer_path} <- download(repository, @tokenizer_filename, etag),
             module = opts[:module] || infer_tokenizer_module!(repository, repo_files),
             {:ok, special_tokens_map} <- read_special_tokens_map(repository, repo_files) do
          tokenizer =
            HuggingFace.Transformers.Config.load(struct!(module), %{
              "tokenizer_file" => tokenizer_path,
              "special_tokens_map" => special_tokens_map
            })

          {:ok, tokenizer}
        end

      {:ok, %{@tokenizer_config_filename => _}} ->
        raise ArgumentError,
              "expected a Rust-compatible tokenizer.json file, however the repository" <>
                " includes tokenizer in a different format. Please refer to Bumblebee" <>
                " README to see the possible steps you can take"

      {:ok, %{}} ->
        raise ArgumentError, "no tokenizer found in the given repository"

      {:error, message} ->
        {:error, message}
    end
  end

  # Infers the tokenizer module, raising when inference fails.
  defp infer_tokenizer_module!(repository, repo_files) do
    case infer_tokenizer_type(repository, repo_files) do
      {:ok, module} -> module
      {:error, error} -> raise ArgumentError, "#{error}, please specify the :module option"
    end
  end

  # Reads the special tokens map if the repository has one, otherwise
  # returns an empty map.
  defp read_special_tokens_map(repository, repo_files) do
    case Map.fetch(repo_files, @tokenizer_special_tokens_filename) do
      {:ok, etag} ->
        with {:ok, path} <- download(repository, @tokenizer_special_tokens_filename, etag) do
          decode_config(path)
        end

      :error ->
        {:ok, %{}}
    end
  end

  # Picks the tokenizer module by looking up the "model_type" entry of the
  # repository's model configuration file in @model_type_to_tokenizer.
  # Download/decode errors are passed through unchanged; a missing or
  # unsupported model type yields {:error, message}.
  defp infer_tokenizer_type(repository, repo_files) do
    config_etag = repo_files[@config_filename]

    with {:ok, config_path} <- download(repository, @config_filename, config_etag),
         {:ok, config_data} <- decode_config(config_path) do
      case config_data do
        %{"model_type" => model_type} ->
          if tokenizer_module = @model_type_to_tokenizer[model_type] do
            {:ok, tokenizer_module}
          else
            {:error,
             "could not match model type #{inspect(model_type)} to any of the supported tokenizers"}
          end

        _other ->
          {:error, "could not infer tokenizer type from the model configuration"}
      end
    end
  end

  @doc """
  Loads generation config from a model repository.

  Generation config includes a number of model-specific properties,
  so it is usually best to load the config and further configure,
  rather than building from scratch.

  See `Bumblebee.Text.GenerationConfig` for all the available options.

  Raises `ArgumentError` when the repository has no model configuration
  file, or when `:spec_module` is not given and cannot be inferred.

  ## Options

    * `:spec_module` - the model specification module. By default it
      is inferred from the configuration file, if that is not possible,
      it must be specified explicitly. Some models have extra options
      related to generation and those are loaded into a separate
      struct, stored under the `:extra_config` attribute

  ## Examples

      {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

      generation_config = Bumblebee.configure(generation_config, max_new_tokens: 10)

  """
  @spec load_generation_config(repository(), keyword()) ::
          {:ok, Bumblebee.Text.GenerationConfig.t()} | {:error, String.t()}
  def load_generation_config(repository, opts \\ []) do
    opts = Keyword.validate!(opts, [:spec_module])

    repository = normalize_repository!(repository)

    case get_repo_files(repository) do
      {:ok, %{@config_filename => etag} = repo_files} ->
        with {:ok, path} <- download(repository, @config_filename, etag),
             {:ok, spec_data} <- decode_config(path) do
          # Same resolution as the other loaders: an explicit module wins,
          # otherwise it is inferred from the config and failure raises
          spec_module =
            opts[:spec_module] ||
              case infer_model_type(spec_data) do
                {:ok, module, _architecture} ->
                  module

                {:error, error} ->
                  raise ArgumentError, "#{error}, please specify the :spec_module option"
              end

          generation_data_result =
            if Map.has_key?(repo_files, @generation_filename) do
              etag = repo_files[@generation_filename]

              with {:ok, path} <- download(repository, @generation_filename, etag) do
                decode_config(path)
              end
            else
              # Fallback to the spec data, since it used to include
              # generation attributes
              {:ok, spec_data}
            end

          with {:ok, generation_data} <- generation_data_result do
            config = struct!(Bumblebee.Text.GenerationConfig)
            config = HuggingFace.Transformers.Config.load(config, generation_data)

            extra_config_module =
              Bumblebee.Text.Generation.extra_config_module(struct!(spec_module))

            # Only models that declare an extra config module get one;
            # otherwise :extra_config is nil
            extra_config =
              if extra_config_module do
                extra_config = struct!(extra_config_module)
                HuggingFace.Transformers.Config.load(extra_config, generation_data)
              end

            config = %{config | extra_config: extra_config}

            {:ok, config}
          end
        end

      # Previously this case was missing, so a repository without the
      # config file crashed with CaseClauseError instead of a clear error
      {:ok, %{}} ->
        raise ArgumentError, "no config file found in the given repository"

      {:error, message} ->
        {:error, message}
    end
  end

  @doc """
  Initializes state for a new scheduler loop.

  Returns `{state, timesteps}`. The `state` is an opaque container to be
  passed on to `scheduler_step/4`, and `timesteps` is the sequence of
  timesteps at which the model forward pass is executed.

  The length of `timesteps` is not necessarily equal to `num_steps`.
  `num_steps` determines the sampling points, but depending on the
  method a single sampling point may take several forward passes of
  the model, and each element in `timesteps` stands for exactly one
  forward pass.
  """
  @doc type: :scheduler
  @spec scheduler_init(
          Bumblebee.Scheduler.t(),
          non_neg_integer(),
          tuple()
        ) :: {Bumblebee.Scheduler.state(), Nx.Tensor.t()}
  def scheduler_init(%scheduler_module{} = scheduler, num_steps, sample_shape),
    do: scheduler_module.init(scheduler, num_steps, sample_shape)

  @doc """
  Predicts sample at the previous timestep using the given scheduler.

  Given the current `sample` and the `prediction` (usually noise) produced
  by the model at the current timestep, returns `{state, prev_sample}`:
  the updated scheduler loop state together with the sample predicted for
  the previous timestep.

  Some schedulers need several forward passes of the model, and thus
  several calls to this function, before they produce an actual
  prediction of the previous sample.
  """
  @doc type: :scheduler
  @spec scheduler_step(
          Bumblebee.Scheduler.t(),
          Bumblebee.Scheduler.state(),
          Nx.Tensor.t(),
          Nx.Tensor.t()
        ) :: {Bumblebee.Scheduler.state(), Nx.Tensor.t()}
  def scheduler_step(%scheduler_module{} = scheduler, state, sample, prediction),
    do: scheduler_module.step(scheduler, state, sample, prediction)

  @doc """
  Loads scheduler from a model repository.

  ## Options

    * `:module` - the scheduler module. When omitted, it is inferred
      from the scheduler configuration file; if inference is not
      possible, the module has to be given explicitly

  ## Examples

      {:ok, scheduler} =
        Bumblebee.load_scheduler({:hf, "CompVis/stable-diffusion-v1-4", subdir: "scheduler"})

  """
  @doc type: :scheduler
  @spec load_scheduler(repository(), keyword()) ::
          {:ok, Bumblebee.Scheduler.t()} | {:error, String.t()}
  def load_scheduler(repository, opts \\ []) do
    repository = normalize_repository!(repository)
    opts = Keyword.validate!(opts, [:module])

    case get_repo_files(repository) do
      {:ok, %{@scheduler_filename => scheduler_etag}} ->
        with {:ok, scheduler_path} <- download(repository, @scheduler_filename, scheduler_etag),
             {:ok, scheduler_data} <- decode_config(scheduler_path) do
          # An explicit :module wins; a failed inference is a caller error
          scheduler_module =
            opts[:module] ||
              case infer_scheduler_type(scheduler_data) do
                {:ok, inferred_module} ->
                  inferred_module

                {:error, error} ->
                  raise ArgumentError, "#{error}, please specify the :module option"
              end

          # Start from the module defaults, then overlay the loaded config
          defaults = configure(scheduler_module)
          {:ok, HuggingFace.Transformers.Config.load(defaults, scheduler_data)}
        end

      {:ok, %{}} ->
        raise ArgumentError, "no scheduler found in the given repository"

      {:error, message} ->
        {:error, message}
    end
  end

  # Resolves the scheduler module from the "_class_name" entry of a
  # scheduler configuration via @diffusers_class_to_scheduler.
  #
  # Returns {:ok, module} on a match, or {:error, message} when the class
  # name is unsupported or the configuration has no "_class_name" entry.
  defp infer_scheduler_type(%{"_class_name" => class_name}) do
    case @diffusers_class_to_scheduler[class_name] do
      nil ->
        {:error,
         "could not match the class name #{inspect(class_name)} to any of the supported schedulers"}

      module ->
        {:ok, module}
    end
  end

  defp infer_scheduler_type(_scheduler_data) do
    # Previously said "featurizer type" (copy-paste slip); this error is
    # surfaced to users of load_scheduler, so name the right thing
    {:error, "could not infer scheduler type from the configuration"}
  end

  # Lists the files available in a repository as a map of filename => etag.
  #
  # For {:local, dir}, only regular files directly inside dir are included
  # and every etag is nil. For {:hf, repository_id, opts}, the file listing
  # is fetched via HuggingFace.Hub.cached_download/2. Only entries of type
  # "file" are kept, and the :subdir prefix (if any) is stripped from their
  # paths. Errors are returned as {:error, message}.
  defp get_repo_files({:local, dir}) do
    case File.ls(dir) do
      {:ok, filenames} ->
        repo_files =
          for filename <- filenames,
              path = Path.join(dir, filename),
              File.regular?(path),
              into: %{},
              do: {filename, nil}

        {:ok, repo_files}

      {:error, reason} ->
        {:error, "could not read #{dir}, reason: #{:file.format_error(reason)}"}
    end
  end

  defp get_repo_files({:hf, repository_id, opts}) do
    subdir = opts[:subdir]
    url = HuggingFace.Hub.file_listing_url(repository_id, subdir, opts[:revision])

    result =
      HuggingFace.Hub.cached_download(
        url,
        Keyword.take(opts, [:cache_dir, :offline, :auth_token])
      )

    with {:ok, path} <- result,
         {:ok, data} <- decode_config(path) do
      repo_files =
        for entry <- data, entry["type"] == "file", into: %{} do
          path = entry["path"]

          # Strip the "subdir/" prefix exactly once. String.replace_leading/3
          # removes every repetition, which turned "sub/sub/file" into "file"
          # instead of "sub/file".
          name =
            if subdir do
              String.replace_prefix(path, subdir <> "/", "")
            else
              path
            end

          # Prefer the oid under "lfs" when the entry has one, otherwise
          # use the entry's own oid
          etag_content = entry["lfs"]["oid"] || entry["oid"]
          # Wrapped in double quotes, presumably to match the quoted ETag
          # format expected by download/3 callers — confirm in HuggingFace.Hub
          etag = <<?", etag_content::binary, ?">>
          {name, etag}
        end

      {:ok, repo_files}
    end
  end

  # Resolves `filename` within the repository to a path on disk.
  #
  # Local repositories only check that the file exists (the etag is
  # ignored). Hugging Face repositories go through
  # HuggingFace.Hub.cached_download/2, with the etag forwarded and the
  # optional :subdir prepended to the remote filename.
  defp download({:local, dir}, filename, _etag) do
    path = Path.join(dir, filename)

    case File.exists?(path) do
      true -> {:ok, path}
      false -> {:error, "local file #{inspect(path)} does not exist"}
    end
  end

  defp download({:hf, repository_id, opts}, filename, etag) do
    remote_filename =
      if subdir = opts[:subdir], do: subdir <> "/" <> filename, else: filename

    download_opts = [{:etag, etag} | Keyword.take(opts, [:cache_dir, :offline, :auth_token])]

    repository_id
    |> HuggingFace.Hub.file_url(remote_filename, opts[:revision])
    |> HuggingFace.Hub.cached_download(download_opts)
  end

  # Converts the accepted public repository forms into the internal
  # representation: {:hf, repository_id, opts} with validated opts, or
  # {:local, dir} unchanged. Any other value raises ArgumentError.
  defp normalize_repository!({:hf, repository_id}) when is_binary(repository_id),
    do: {:hf, repository_id, []}

  defp normalize_repository!({:hf, repository_id, opts}) when is_binary(repository_id) do
    allowed_keys = [:revision, :cache_dir, :offline, :auth_token, :subdir]
    {:hf, repository_id, Keyword.validate!(opts, allowed_keys)}
  end

  defp normalize_repository!({:local, dir} = repository) when is_binary(dir), do: repository

  defp normalize_repository!(other) do
    message =
      "expected repository to be either {:hf, repository_id}, {:hf, repository_id, options}" <>
        " or {:local, directory}, got: #{inspect(other)}"

    raise ArgumentError, message
  end

  @doc """
  Returns the directory where downloaded files are stored.

  When the `BUMBLEBEE_CACHE_DIR` environment variable is set, its value
  (expanded to an absolute path) is used. Otherwise this falls back to the
  OS-specific user cache directory for `bumblebee`.
  """
  @spec cache_dir() :: String.t()
  def cache_dir() do
    case System.get_env("BUMBLEBEE_CACHE_DIR") do
      nil -> :filename.basedir(:user_cache, "bumblebee")
      dir -> Path.expand(dir)
    end
  end
end