defmodule Replicate.Predictions do
  @moduledoc """
  Functions for working with Replicate predictions: creating, fetching,
  listing, waiting on and canceling them.

  HTTP requests go through the module configured under
  `config :replicate, :replicate_client` (read at compile time), which
  defaults to `Replicate.Client`.
  """
  @behaviour Replicate.Predictions.Behaviour

  alias Replicate.Models.Version
  alias Replicate.Predictions.Prediction

  @replicate_client Application.compile_env(:replicate, :replicate_client, Replicate.Client)

  @doc """
  Gets a prediction by id.

  Returns `{:ok, %Prediction{}}` on success, or the client's `{:error, message}`
  tuple unchanged on failure.

  ## Examples

  ```
  iex> {:ok, prediction} = Replicate.Predictions.get("1234")
  iex> prediction.status
  "succeeded"
  iex> Replicate.Predictions.get("not_a_real_id")
  {:error, "Not found"}
  ```
  """
  @spec get(String.t()) :: {:ok, %Prediction{}} | {:error, term()}
  def get(id) do
    @replicate_client.request(:get, "/v1/predictions/#{id}")
    |> parse_response()
  end

  @doc """
  Gets a prediction by id, raising if the request fails.

  The client's error message is re-raised; a binary message becomes a
  `RuntimeError`.

  ## Examples

  ```
  iex> prediction = Replicate.Predictions.get!("1234")
  iex> prediction.id
  "1234"
  iex> Replicate.Predictions.get!("not_a_real_id")
  ** (RuntimeError) Not found
  ```
  """
  @spec get!(String.t()) :: %Prediction{}
  def get!(id) do
    case get(id) do
      {:ok, prediction} -> prediction
      {:error, message} -> raise message
    end
  end

  @doc """
  Cancels a prediction given its id or a `%Prediction{}` struct.

  Returns `{:ok, %Prediction{}}` with the prediction as reported by the API
  after the cancel request, or the client's `{:error, message}` tuple.

  ## Examples

  ```
  iex> {:ok, prediction} = Replicate.Predictions.cancel("1234")
  iex> prediction.status
  "canceled"
  iex> model = Replicate.Models.get!("stability-ai/stable-diffusion")
  iex> version = Replicate.Models.get_version!(model, "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")
  iex> {:ok, prediction} = Replicate.Predictions.create(version, %{prompt: "a 19th century portrait of a wombat gentleman"})
  iex> {:ok, prediction} = Replicate.Predictions.cancel(prediction)
  iex> prediction.status
  "canceled"
  ```

  A prediction that has already completed cannot be canceled.

  ```
  iex> model = Replicate.Models.get!("stability-ai/stable-diffusion")
  iex> version = Replicate.Models.get_version!(model, "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")
  iex> {:ok, prediction} = Replicate.Predictions.create(version, %{prompt: "a 19th century portrait of a wombat gentleman"})
  iex> prediction.status
  "starting"
  iex> {:ok, prediction} = Replicate.Predictions.wait(prediction)
  iex> prediction.status
  "succeeded"
  # iex> {:ok, prediction} = Replicate.Predictions.cancel(prediction.id)
  # iex> prediction.status
  # "succeeded"
  ```
  """
  @spec cancel(%Prediction{} | String.t()) :: {:ok, %Prediction{}} | {:error, term()}
  def cancel(%Prediction{id: id}), do: cancel(id)

  def cancel(id) when is_binary(id) do
    @replicate_client.request(:post, "/v1/predictions/#{id}/cancel")
    |> parse_response()
  end

  @doc """
  Creates a prediction.

  `model` is either an `"owner/name"` model identifier (for official models) or
  a `%Replicate.Models.Version{}` struct. `input` is a map (or any enumerable of
  key/value pairs, such as a keyword list) of the model inputs.

  The optional `webhook`, `webhook_completed`, `webhook_event_filter` and
  `stream` arguments are included in the request body only when they are not
  `nil`. Use `webhook` to be notified when the prediction completes.

  Raises `ArgumentError` if `model` is a string that is not of the form
  `"owner/name"`.

  ## Examples

  If you're calling an official model, you can pass its `"owner/name"` identifier:

  ```
  iex> {:ok, prediction} = Replicate.Predictions.create("stability-ai/stable-diffusion-3", %{prompt: "a 19th century portrait of a wombat gentleman"})
  iex> prediction.status
  "starting"
  ```

  Otherwise, provide a `%Replicate.Models.Version{}` struct:

  ```
  iex> model = Replicate.Models.get!("stability-ai/stable-diffusion")
  iex> version = Replicate.Models.get_version!(model, "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")
  iex> {:ok, prediction} = Replicate.Predictions.create(version, %{prompt: "a 19th century portrait of a wombat gentleman"})
  iex> prediction.status
  "starting"
  iex> {:ok, prediction} = Replicate.Predictions.create(version, %{prompt: "a 19th century portrait of a wombat gentleman"}, "https://example.com/webhook")
  iex> prediction.status
  "starting"
  ```
  """
  @spec create(
          %Version{} | String.t(),
          map() | keyword(),
          term(),
          term(),
          term(),
          term()
        ) :: {:ok, %Prediction{}} | {:error, term()}
  def create(
        model,
        input,
        webhook \\ nil,
        webhook_completed \\ nil,
        webhook_event_filter \\ nil,
        stream \\ nil
      ) do
    # Only send the optional parameters the caller actually supplied.
    optional_parameters =
      for {key, value} <- [
            {"webhook", webhook},
            {"webhook_completed", webhook_completed},
            {"webhook_event_filter", webhook_event_filter},
            {"stream", stream}
          ],
          not is_nil(value),
          into: %{},
          do: {key, value}

    send_to_replicate(model, input, optional_parameters)
  end

  # A pinned version is posted to the generic predictions endpoint.
  defp send_to_replicate(%Version{id: version_id}, input, optional_parameters) do
    post_prediction("/v1/predictions", %{"version" => version_id}, input, optional_parameters)
  end

  # An "owner/name" string is posted to that model's own predictions endpoint.
  defp send_to_replicate(model, input, optional_parameters) when is_binary(model) do
    case String.split(model, "/") do
      [owner, name] ->
        post_prediction(
          "/v1/models/#{owner}/#{name}/predictions",
          %{},
          input,
          optional_parameters
        )

      _ ->
        raise ArgumentError,
              "expected model to be a %Replicate.Models.Version{} or a string of the " <>
                "form \"owner/name\", got: #{inspect(model)}"
    end
  end

  # Builds the JSON body (base fields + "input" + optional parameters), POSTs it
  # to `path` and parses the response into a prediction result.
  defp post_prediction(path, base_body, input, optional_parameters) do
    body =
      base_body
      |> Map.put("input", Map.new(input))
      |> Map.merge(optional_parameters)
      |> Jason.encode!()

    response = @replicate_client.request(:post, path, body)
    parse_response(response)
  end

  @doc """
  Waits for a prediction to complete.

  Delegates to the configured client's `wait/1`, passing `{:ok, prediction}`.

  ## Examples

  ```
  iex> model = Replicate.Models.get!("stability-ai/stable-diffusion")
  iex> version = Replicate.Models.get_version!(model, "db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")
  iex> {:ok, prediction} = Replicate.Predictions.create(version, %{prompt: "a 19th century portrait of a wombat gentleman"})
  iex> prediction.status
  "starting"
  iex> {:ok, prediction} = Replicate.Predictions.wait(prediction)
  iex> prediction.status
  "succeeded"
  ```
  """
  def wait(%Prediction{} = prediction), do: @replicate_client.wait({:ok, prediction})

  @doc """
  Lists the predictions you've run.

  Returns the `"results"` of a single `GET /v1/predictions` response as
  `%Prediction{}` structs; further pages (if any) are not fetched. Raises the
  client's error message if the request fails.

  ## Examples

  ```
  iex> Replicate.Predictions.list()
  [%Prediction{
    id: "1234",
    status: "starting",
    input: %{"prompt" => "a 19th century portrait of a wombat gentleman"},
    version: "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
    output: ["https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png"],
    urls: %{
      "get" => "https://api.replicate.com/v1/predictions/1234",
      "cancel" => "https://api.replicate.com/v1/predictions/1234/cancel",
    }
  },
  %Prediction{
    id: "1235",
    status: "starting",
    input: %{"prompt" => "a 19th century portrait of a wombat gentleman"},
    version: "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
    output: ["https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png"],
    urls: %{
      "get" => "https://api.replicate.com/v1/predictions/1235",
      "cancel" => "https://api.replicate.com/v1/predictions/1235/cancel"
    }
  }]
  ```
  """
  @spec list() :: [%Prediction{}]
  def list do
    case @replicate_client.request(:get, "/v1/predictions") do
      {:ok, json_body} ->
        %{"results" => predictions} = Jason.decode!(json_body)
        Enum.map(predictions, &to_prediction/1)

      {:error, message} ->
        raise message
    end
  end

  # Decodes a successful JSON response into {:ok, %Prediction{}};
  # error tuples from the client pass through unchanged.
  defp parse_response({:ok, json_body}) do
    prediction =
      json_body
      |> Jason.decode!()
      |> to_prediction()

    {:ok, prediction}
  end

  defp parse_response({:error, _message} = error), do: error

  # Builds a %Prediction{} from a decoded JSON object (string keys).
  # We walk the struct's own fields instead of atomizing every incoming key:
  # the payload comes from a remote API and atoms are never garbage-collected,
  # so converting arbitrary keys with String.to_atom/1 could exhaust the atom
  # table. Keys that are not struct fields were discarded by struct/2 anyway.
  defp to_prediction(attrs) do
    fields =
      Enum.flat_map(prediction_fields(), fn field ->
        case Map.fetch(attrs, Atom.to_string(field)) do
          {:ok, value} -> [{field, value}]
          :error -> []
        end
      end)

    struct(Prediction, fields)
  end

  # Field names of %Prediction{}, read at runtime (struct/1 also ensures the
  # module is loaded, so its field atoms exist).
  defp prediction_fields do
    Prediction
    |> struct()
    |> Map.from_struct()
    |> Map.keys()
  end
end