# lib/huggingface_client/inference/tasks/task_runner.ex

defmodule HuggingfaceClient.Inference.TaskRunner do
  @moduledoc """
  Common plumbing shared by the inference task modules.

  Exposes `run/3` for one-shot requests, `stream/3` for streaming requests, and
  `resolve_provider/2`, which both of them rely on to pick a provider.
  """

  alias HuggingfaceClient.Client
  alias HuggingfaceClient.Inference.ResponseHandler
  alias HuggingfaceClient.ProviderMapping
  alias HuggingfaceClient.ProviderRegistry
  alias HuggingfaceClient.Request

  @doc """
  Runs `task` as a single, non-streaming request.

  The client options are merged with `args`, the provider module for `task` is
  looked up, and the request is executed with `:task` added to the merged
  options. The raw response is passed through the provider module's
  `get_response/2` and then handed to `ResponseHandler.resolve/2`.

  If any step returns something other than the expected `{:ok, _}` shape, that
  value is returned unchanged.
  """
  @spec run(Client.t(), map(), String.t()) :: {:ok, map()} | {:error, Exception.t()}
  def run(client, args, task) do
    opts = Client.merge_opts(client, args)

    with {:ok, provider_mod} <- fetch_provider_module(args, opts, task),
         {:ok, {raw, _ctx}} <-
           Request.execute(args, provider_mod, Map.put(opts, :task, task)) do
      raw
      |> provider_mod.get_response(%{task: task})
      |> ResponseHandler.resolve(opts)
    end
  end

  @doc """
  Runs `task` as a streaming request.

  Performs the same option merging and provider module lookup as `run/3`, then
  returns whatever `Request.stream/3` returns. A failed lookup is returned
  unchanged.
  """
  @spec stream(Client.t(), map(), String.t()) ::
          {:ok, Enumerable.t()} | {:error, Exception.t()}
  def stream(client, args, task) do
    opts = Client.merge_opts(client, args)

    with {:ok, provider_mod} <- fetch_provider_module(args, opts, task) do
      Request.stream(args, provider_mod, Map.put(opts, :task, task))
    end
  end

  @doc false
  def resolve_provider(args, opts) do
    # `:provider` and `:endpoint_url` fall back to `opts` when `args` has no
    # truthy value for them; `:model` is read from `args` only.
    ProviderMapping.resolve_provider(
      Map.get(args, :provider) || opts[:provider],
      Map.get(args, :model),
      Map.get(args, :endpoint_url) || opts[:endpoint_url]
    )
  end

  # Resolves the provider for this call, then asks the registry for the module
  # registered under that provider and task. A result from provider resolution
  # that is not `{:ok, provider}` is returned unchanged.
  defp fetch_provider_module(args, opts, task) do
    with {:ok, provider} <- resolve_provider(args, opts) do
      ProviderRegistry.get(provider, task)
    end
  end
end