defmodule HuggingfaceClient.Inference.TaskRunner do
  @moduledoc """
  Shared run/stream/resolve_provider logic used by all inference task modules.

  Task modules delegate here so that provider resolution, request execution,
  and response handling live in one place.
  """

  alias HuggingfaceClient.Inference.ResponseHandler
  alias HuggingfaceClient.ProviderMapping
  alias HuggingfaceClient.ProviderRegistry
  alias HuggingfaceClient.Request

  @doc """
  Executes a non-streaming inference task.

  Resolves the provider from `args`/client opts, looks up the provider
  module registered for `task`, executes the request, and normalizes the
  raw response through `ResponseHandler.resolve/2`. Any failing step's
  `{:error, reason}` tuple falls through the `with` and is returned as-is.
  """
  @spec run(HuggingfaceClient.Client.t(), map(), String.t()) ::
          {:ok, map()} | {:error, Exception.t()}
  def run(client, args, task) do
    opts = HuggingfaceClient.Client.merge_opts(client, args)

    with {:ok, provider} <- resolve_provider(args, opts),
         {:ok, provider_mod} <- ProviderRegistry.get(provider, task),
         {:ok, {raw, _ctx}} <- Request.execute(args, provider_mod, Map.put(opts, :task, task)) do
      # `get_response/2` cannot fail the `with`, so it belongs in the body
      # rather than as an infallible `<-` clause (which the compiler warns
      # "will always match").
      raw
      |> provider_mod.get_response(%{task: task})
      |> ResponseHandler.resolve(opts)
    end
  end

  @doc """
  Executes a streaming inference task.

  Same resolution steps as `run/3`, but delegates to `Request.stream/3`,
  returning a lazy enumerable of streamed events on success.
  """
  @spec stream(HuggingfaceClient.Client.t(), map(), String.t()) ::
          {:ok, Enumerable.t()} | {:error, Exception.t()}
  def stream(client, args, task) do
    opts = HuggingfaceClient.Client.merge_opts(client, args)

    with {:ok, provider} <- resolve_provider(args, opts),
         {:ok, provider_mod} <- ProviderRegistry.get(provider, task) do
      Request.stream(args, provider_mod, Map.put(opts, :task, task))
    end
  end

  # Resolves the inference provider, preferring per-call `args` over
  # client-level `opts` for both the provider and the endpoint URL.
  @doc false
  def resolve_provider(args, opts) do
    provider = Map.get(args, :provider) || opts[:provider]
    model = Map.get(args, :model)
    endpoint = Map.get(args, :endpoint_url) || opts[:endpoint_url]

    ProviderMapping.resolve_provider(provider, model, endpoint)
  end
end