lib/mock_client.ex

defmodule Replicate.MockClient do
  @moduledoc """
  A mock version of `Replicate.Client`. Only used for testing.

  Each `request/2` / `request/3` clause matches an HTTP method atom and an
  exact API path and returns a canned result:

    * `{:ok, json}` with a JSON-encoded stub for known paths;
    * `{:error, message}` for the paths the tests treat as failures.

  `wait/1` skips any polling. It returns the stub prediction already in the
  `"succeeded"` state.
  """
  alias Replicate.Predictions.Prediction

  # Base prediction fixture. Clauses below override `status` ("succeeded",
  # "canceled") where the endpoint implies a different state.
  @stub_prediction %{
    id: "1234",
    status: "starting",
    input: %{"prompt" => "a 19th century portrait of a wombat gentleman"},
    version: "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
    output: [
      "https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png"
    ],
    urls: %{
      "get" => "https://api.replicate.com/v1/predictions/1234",
      "cancel" => "https://api.replicate.com/v1/predictions/1234/cancel"
    }
  }

  # Second prediction for the list endpoint. It is identical to the first
  # except for its id and the URLs derived from that id.
  @stub_prediction2 %{
    @stub_prediction
    | id: "1235",
      urls: %{
        "get" => "https://api.replicate.com/v1/predictions/1235",
        "cancel" => "https://api.replicate.com/v1/predictions/1235/cancel"
      }
  }

  @stub_version1 %{
    id: "v1",
    created_at: "2022-04-26T19:29:04.418669Z",
    cog_version: "0.3.0",
    openapi_schema: %{}
  }

  @stub_version2 %{
    id: "v2",
    created_at: "2022-03-21T13:01:04.418669Z",
    cog_version: "0.3.0",
    openapi_schema: %{}
  }

  @stub_model %{
    "url" => "https://replicate.com/replicate/hello-world",
    "owner" => "replicate",
    "name" => "hello-world",
    "description" => "A tiny model that says hello",
    "visibility" => "public",
    "github_url" => "https://github.com/replicate/cog-examples",
    "paper_url" => nil,
    "license_url" => nil,
    "run_count" => 12345,
    "cover_image_url" => nil,
    "latest_version" => @stub_version2
  }

  @stub_hardware [
    %{"name" => "CPU", "sku" => "cpu"},
    %{"name" => "Nvidia T4 GPU", "sku" => "gpu-t4"},
    %{"name" => "Nvidia A40 GPU", "sku" => "gpu-a40-small"},
    %{"name" => "Nvidia A40 (Large) GPU", "sku" => "gpu-a40-large"}
  ]

  @doc """
  Handles a body-less request.

  `:get` requests to a known path return `{:ok, json}`. A few known-missing
  resources return `{:error, "Not found"}`. Any other `:get` path returns
  `{:error, "Unexpected path in the mock client: <path>"}`.

  `:post` requests are forwarded to `request/3` with an empty (`[]`) body.
  """
  @spec request(atom(), String.t()) :: {:ok, term()} | {:error, String.t()}
  def request(:get, "/v1/predictions") do
    ok_json(%{"results" => [@stub_prediction, @stub_prediction2]})
  end

  def request(:get, "/v1/models/replicate/hello-world/versions") do
    ok_json(%{"results" => [@stub_version1, @stub_version2]})
  end

  def request(:get, "/v1/models/replicate/hello-world"), do: ok_json(@stub_model)

  def request(:get, "/v1/models/stability-ai/stable-diffusion/versions") do
    ok_json(%{"results" => [@stub_version1, @stub_version2]})
  end

  def request(:get, "/v1/models/replicate/hello-world/versions/v2"), do: ok_json(@stub_version2)

  # Deliberately serves the hello-world model fixture for this path as well.
  def request(:get, "/v1/models/stability-ai/stable-diffusion"), do: ok_json(@stub_model)

  def request(
        :get,
        "/v1/models/replicate/hello-world/versions/27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
      ) do
    ok_json(@stub_version1)
  end

  def request(
        :get,
        "/v1/models/replicate/hello-world/versions/db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
      ) do
    ok_json(@stub_version2)
  end

  def request(:get, "/v1/predictions/1234") do
    ok_json(%{@stub_prediction | status: "succeeded"})
  end

  def request(:get, "/v1/predictions/not_a_real_id"), do: {:error, "Not found"}
  def request(:get, "/v1/models/cbh123/babadook-diffusion"), do: {:error, "Not found"}

  # NOTE(review): "next" points at /v1/trainings although this is the models
  # listing. This looks like a copy-paste slip. It is kept verbatim because
  # tests may assert on the exact string; confirm before changing it.
  def request(:get, "/v1/models") do
    ok_json(%{
      "results" => List.duplicate(@stub_model, 25),
      "next" =>
        "https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
      "previous" => nil
    })
  end

  def request(:get, "/v1/hardware"), do: ok_json(@stub_hardware)

  def request(:get, path), do: {:error, "Unexpected path in the mock client: #{path}"}

  # A body-less POST is modelled as an empty (iodata) body.
  def request(:post, path), do: request(:post, path, [])

  @doc """
  Handles a request that carries a body.

    * `POST /v1/models` echoes the body back unchanged.
    * `POST /v1/predictions` requires a JSON object with a `"version"` key.
      Without one it returns `{:error, "Version is required"}`. This includes an
      empty body, such as the one sent by `request(:post, path)`. A malformed,
      non-empty body still raises from `Jason.decode!/1`.
    * Any other POST returns the stub prediction. Its status is `"canceled"`
      when the last path segment is `"cancel"`.
    * The `:fail` method always returns `{:error, "Failed"}`.
  """
  @spec request(atom(), String.t(), term()) :: {:ok, term()} | {:error, String.t()}
  def request(:post, "/v1/models", body), do: {:ok, body}

  # official models endpoint
  def request(:post, "/v1/models/replicate/hello-world/predictions", _body) do
    ok_json(@stub_prediction)
  end

  def request(:post, "/v1/predictions", body) do
    if has_version?(body) do
      ok_json(@stub_prediction)
    else
      {:error, "Version is required"}
    end
  end

  def request(:post, path, _body) do
    if Path.basename(path) == "cancel" do
      ok_json(%{@stub_prediction | status: "canceled"})
    else
      ok_json(@stub_prediction)
    end
  end

  def request(:fail, _path, _body), do: {:error, "Failed"}

  @doc """
  Returns the stub prediction as a `Prediction` struct with status `"succeeded"`.

  The given prediction is ignored. An `{:error, message}` tuple is passed
  through unchanged.
  """
  @spec wait({:ok, term()} | {:error, term()}) :: {:ok, struct()} | {:error, term()}
  def wait({:ok, _prediction}) do
    {:ok, struct(Prediction, %{@stub_prediction | status: "succeeded"})}
  end

  def wait({:error, message}), do: {:error, message}

  # Wraps the JSON encoding of `term` in an `:ok` tuple, the shape every
  # successful stub response uses.
  defp ok_json(term), do: {:ok, Jason.encode!(term)}

  # An empty body (`[]` from `request(:post, path)`, or `""`) cannot carry a
  # version. Handling it here avoids feeding it to `Jason.decode!/1`, which
  # raises on empty input. Non-empty bodies must be valid JSON, so a malformed
  # body still raises and surfaces an encoding bug in the caller. A non-object
  # JSON value simply has no version.
  defp has_version?(body) when body in [[], ""], do: false
  defp has_version?(body), do: match?(%{"version" => _}, Jason.decode!(body))
end