# lib/gpt_agent.ex

defmodule GptAgent do
  @moduledoc """
  Provides a GPT conversation agent
  """

  use GenServer
  use GptAgent.Types
  use Knigge, otp_app: :gpt_agent, default: __MODULE__.Impl

  require Logger

  alias GptAgent.Types

  alias Types.UserMessage

  alias GptAgent.Events.{
    AssistantMessageAdded,
    OrganizationQuotaExceeded,
    RateLimited,
    RateLimitRetriesExhuasted,
    RunCompleted,
    RunFailed,
    RunStarted,
    ToolCallOutputRecorded,
    ToolCallOutputSubmissionFailed,
    ToolCallRequested,
    UserMessageAdded
  }

  # Default idle timeout for the agent process: two minutes, in milliseconds.
  # Used as the default value of the struct's `timeout_ms` field.
  @timeout_ms 120_000

  # Rate-limit retry policy, read from application config at compile time:
  # the maximum number of retries (default 10) and the delay in milliseconds
  # before a rate-limited run is started again (default 30_000).
  @rate_limit_max_retries Application.compile_env(:gpt_agent, :rate_limit_max_retries, 10)
  @rate_limit_retry_delay Application.compile_env(:gpt_agent, :rate_limit_retry_delay, 30_000)

  # Pause in milliseconds between attempts to submit tool outputs (default 1000).
  @tool_output_retry_delay Application.compile_env(:gpt_agent, :tool_output_retry_delay, 1000)

  # State held by each agent process; it is the argument of `init/1`.
  typedstruct do
    # Sent as "assistant_id" when a run is created.
    field :assistant_id, Types.assistant_id(), enforce: true
    # Thread served by this agent; used in API paths, as the Registry key and
    # in the PubSub channel name.
    field :thread_id, Types.thread_id(), enforce: true
    # ID of the last assistant message published; used as the `after` cursor
    # when reading messages. `nil` reads the thread from the beginning.
    field :last_message_id, Types.message_id() | nil, enforce: true
    # Set to true when a run is started or found queued / in progress /
    # requiring action; set to false once a run completes or fails.
    field :running?, boolean(), default: false
    # ID of the most recently started or discovered run, if any.
    field :run_id, Types.run_id() | nil
    # Tool calls requested by the run that have no recorded output yet.
    field :tool_calls, [ToolCallRequested.t()], default: []
    # Recorded tool outputs that have not been submitted to OpenAI yet.
    field :tool_outputs, [ToolCallOutputRecorded.t()], default: []
    # GenServer timeout (ms) attached by the reply helpers; also an upper
    # bound for the HTTP receive timeout.
    field :timeout_ms, non_neg_integer(), default: @timeout_ms
    # Number of rate-limit retries so far; reset to 0 when a user message is added.
    field :rate_limit_retry_attempt, non_neg_integer(), default: 0
  end

  @typedoc """
  A single option accepted by `connect/1`.

  `:thread_id`, `:last_message_id` and `:assistant_id` must all be given
  (`:last_message_id` may be `nil`). `:subscribe` defaults to `true` and
  `:timeout_ms` defaults to `nil`, in which case the configured default
  timeout is used.
  """
  @type connect_opt() ::
          {:subscribe, boolean()}
          | {:thread_id, Types.thread_id()}
          | {:last_message_id, Types.message_id() | nil}
          | {:assistant_id, Types.assistant_id()}
          | {:timeout_ms, non_neg_integer() | nil}

  @typedoc "Keyword list of `t:connect_opt/0` values passed to `connect/1`."
  @type connect_opts() :: list(connect_opt())

  # Public API contract; the default implementation is `GptAgent.Impl`.

  # Creates a new OpenAI thread and returns its ID.
  @callback create_thread() :: {:ok, Types.thread_id()}
  # Starts an agent process with the given initial state.
  @callback start_link(t()) :: Types.result(pid(), term())
  # Returns the agent for the given thread, starting one if none is registered;
  # optionally subscribes the caller to the agent's events.
  @callback connect(connect_opts()) :: Types.result(pid(), :invalid_thread_id)
  # Terminates the agent process if it is alive.
  @callback shutdown(pid()) :: Types.result({:process_not_alive, pid()})
  # Queues a user message for the agent; a run is started once it is added.
  @callback add_user_message(pid(), Types.nonblank_string()) ::
              Types.result(:run_in_progress | {:process_not_alive, pid()})
  # Records the output for a pending tool call.
  # NOTE(review): the second argument is typed as a tool name, but the default
  # implementation passes it on as the tool call ID — confirm the intended type.
  @callback submit_tool_output(pid(), Types.tool_name(), Types.tool_output()) ::
              Types.result(:invalid_tool_call_id | {:process_not_alive, pid()})
  # Reports whether the agent currently considers a run to be in progress.
  @callback run_in_progress?(pid()) :: boolean() | Types.error({:process_not_alive, pid()})
  # Changes the assistant used for subsequent runs.
  @callback set_assistant_id(pid(), Types.assistant_id()) ::
              Types.result({:process_not_alive, pid()})

  # Builders for GenServer callback return tuples. The one-argument `noreply`
  # and `reply` attach the state's `timeout_ms` as the GenServer timeout.
  defp noreply(%__MODULE__{timeout_ms: timeout_ms} = state) do
    {:noreply, state, timeout_ms}
  end

  # `next` is passed through as the third element (e.g. a `{:continue, term}`).
  defp noreply(%__MODULE__{} = state, next) do
    {:noreply, state, next}
  end

  defp reply(%__MODULE__{timeout_ms: timeout_ms} = state, reply) do
    {:reply, reply, state, timeout_ms}
  end

  defp stop(%__MODULE__{} = state) do
    {:stop, :normal, state}
  end

  # Logs `message` at `level` (default `:debug`), prefixed with the pid of the
  # calling process.
  defp log(message, level \\ :debug) when is_binary(message) do
    prefix = "[GptAgent (#{inspect(self())})] "
    Logger.log(level, prefix <> message)
  end

  # Broadcasts `{self(), callback}` on the thread's PubSub channel and returns
  # the state unchanged so the call can sit inside a pipeline.
  defp publish_event(%__MODULE__{thread_id: thread_id} = state, callback) do
    topic = "gpt_agent:#{thread_id}"
    log("Publishing event on channel #{topic}: #{inspect(callback)}")

    :ok = Phoenix.PubSub.broadcast(GptAgent.PubSub, topic, {self(), callback})

    state
  end

  # Validates the initial state, registers the process for its thread, looks up
  # the thread's latest run and then continues with a run-status check.
  @impl true
  def init(%__MODULE__{} = state) do
    ensure_type!(state)

    log("Initializing with #{inspect(state)}")

    initialized =
      state
      |> register()
      |> retrieve_current_run_status()

    {:ok, initialized, {:continue, {:check_run_status, initialized.run_id}}}
  end

  # Registers this process in GptAgent.Registry under its thread ID.
  # A state without a thread ID is returned untouched.
  defp register(%__MODULE__{thread_id: nil} = state), do: state

  defp register(%__MODULE__{thread_id: thread_id} = state) do
    {:ok, _owner} = Registry.register(GptAgent.Registry, thread_id, :gpt_agent)

    log("Registered in GptAgent.Registry as #{inspect(thread_id)}")

    state
  end

  # HTTP receive timeout: the smaller of the configured `:receive_timeout_ms`
  # and the agent's own `timeout_ms`. When the config value is unset (nil),
  # term ordering ranks numbers below atoms, so `timeout_ms` is returned.
  defp receive_timeout_ms(%__MODULE__{timeout_ms: timeout_ms}) do
    :gpt_agent
    |> Application.get_env(:receive_timeout_ms)
    |> min(timeout_ms)
  end

  # Fetches the most recent run of the thread. If that run is still queued, in
  # progress or waiting for tool outputs, the state is marked as running with
  # that run's ID; otherwise the state is returned as-is.
  defp retrieve_current_run_status(%__MODULE__{} = state) do
    latest_run_path = "/v1/threads/#{state.thread_id}/runs?limit=1&order=desc"

    {:ok, %{body: %{"object" => "list", "data" => runs}}} =
      OpenAiClient.get(latest_run_path, receive_timeout: receive_timeout_ms(state))

    case runs do
      [%{"id" => active_run_id, "status" => status} | _older]
      when status in ~w(queued in_progress requires_action) ->
        %{state | running?: true, run_id: active_run_id}

      _no_active_run ->
        state
    end
  end

  # Creates a run on the thread for the current assistant, schedules the first
  # status check and announces the run with a RunStarted event.
  @impl true
  def handle_continue(:run, %__MODULE__{} = state) do
    log("Starting run")

    request_opts = [
      json: %{"assistant_id" => state.assistant_id},
      receive_timeout: receive_timeout_ms(state)
    ]

    {:ok, %{body: %{"id" => run_id}}} =
      OpenAiClient.post("/v1/threads/#{state.thread_id}/runs", request_opts)

    Process.send_after(self(), {:check_run_status, run_id}, heartbeat_interval_ms())
    log("Will check run status in #{heartbeat_interval_ms()} ms")

    run_started =
      RunStarted.new!(
        id: run_id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id
      )

    %{state | running?: true, run_id: run_id}
    |> publish_event(run_started)
    |> noreply()
  end

  # Continuation used after init and on timeout: with no run ID there is
  # nothing to check; otherwise the check is delegated to the matching
  # `handle_info/2` clause.
  @impl true
  def handle_continue({:check_run_status, run_id}, state) do
    if is_nil(run_id) do
      log("No run in progress, not checking run status")
      noreply(state)
    else
      handle_info({:check_run_status, run_id}, state)
    end
  end

  # Reads the thread's messages in ascending order, starting after the last
  # message already seen (when there is one), and processes them.
  @impl true
  def handle_continue(:read_messages, %__MODULE__{} = state) do
    after_param =
      if state.last_message_id, do: "&after=#{state.last_message_id}", else: ""

    url = "/v1/threads/#{state.thread_id}/messages?order=asc" <> after_param

    log("Reading messages with request to #{url}")

    {:ok, %{body: %{"object" => "list", "data" => messages}}} =
      OpenAiClient.get(url, receive_timeout: receive_timeout_ms(state))

    state
    |> process_messages(messages)
    |> noreply()
  end

  # Handles one message from the thread listing. Assistant messages whose
  # first content item carries text are published as AssistantMessageAdded and
  # become the new `last_message_id`; text messages from other roles are
  # ignored; anything else is logged and skipped.
  defp process_message(message, %__MODULE__{} = state) do
    case {message["role"], message["content"]} do
      {"assistant", [%{"text" => %{"value" => text}} | _more]} ->
        event =
          AssistantMessageAdded.new!(
            message_id: message["id"],
            thread_id: message["thread_id"],
            run_id: message["run_id"],
            assistant_id: message["assistant_id"],
            content: text
          )

        publish_event(state, event)

        log("Updating last message ID to #{message["id"]}")
        %{state | last_message_id: message["id"]}

      {_other_role, [%{"text" => %{"value" => _text}} | _more]} ->
        state

      _no_text_content ->
        log("Skipping message with no content: #{inspect(message)}")
        state
    end
  end

  # Folds every message into the state, in the order received.
  defp process_messages(%__MODULE__{} = state, messages) do
    log("Processing messages: #{inspect(messages)}")

    Enum.reduce(messages, state, fn message, acc -> process_message(message, acc) end)
  end

  # Delay between run-status checks, read from config at runtime (default 1000 ms).
  defp heartbeat_interval_ms do
    Application.get_env(:gpt_agent, :heartbeat_interval_ms, 1000)
  end

  # Replaces the assistant used for subsequent runs.
  #
  # Returns through `noreply/1` so the idle timeout is re-armed. A bare
  # `{:noreply, state}` carries no timeout, which would leave an idle agent
  # alive indefinitely after this cast.
  @impl true
  def handle_cast({:set_assistant_id, assistant_id}, %__MODULE__{} = state) do
    log("Setting default assistant ID to #{assistant_id}")
    noreply(%{state | assistant_id: assistant_id})
  end

  # Replaces the cursor used when reading messages. Also returns through
  # `noreply/1` so the idle timeout stays in effect.
  def handle_cast({:set_last_message_id, last_message_id}, %__MODULE__{} = state) do
    log("Setting last message ID to #{last_message_id}")
    noreply(%{state | last_message_id: last_message_id})
  end

  # A run is in progress, so the message cannot be added yet. The same request
  # is cast back to this process, which places it behind any messages already
  # in the mailbox; it is re-examined each time it reaches the front.
  def handle_cast({:add_user_message, message} = request, %__MODULE__{running?: true} = state) do
    log(
      "Attempting to add user message, but run in progress, cannot add user message: #{inspect(message)}"
    )

    GenServer.cast(self(), request)

    noreply(state)
  end

  # Posts the user message to the thread, resets the rate-limit retry counter,
  # publishes UserMessageAdded and continues by starting a run.
  def handle_cast({:add_user_message, %UserMessage{} = message}, %__MODULE__{} = state) do
    log("Adding user message #{inspect(message)}")

    {:ok, %{body: %{"id" => message_id}}} =
      OpenAiClient.post(
        "/v1/threads/#{state.thread_id}/messages",
        json: message,
        receive_timeout: receive_timeout_ms(state)
      )

    message_added =
      UserMessageAdded.new!(
        id: message_id,
        thread_id: state.thread_id,
        content: message
      )

    %{state | rate_limit_retry_attempt: 0}
    |> publish_event(message_added)
    |> noreply({:continue, :run})
  end

  # Tool output submitted while no run is in progress: logged and dropped.
  def handle_cast(
        {:submit_tool_output, tool_call_id, tool_output},
        %__MODULE__{running?: false} = state
      ) do
    dropped_notice =
      "Attempting to submit tool output, but no run in progress, cannot submit tool output for call #{inspect(tool_call_id)}: #{inspect(tool_output)}"

    log(dropped_notice)

    noreply(state)
  end

  # Records the output for a pending tool call. The call is removed from
  # `tool_calls`, its JSON-encoded output is published and added to
  # `tool_outputs`, and the outputs are sent to OpenAI if nothing is pending.
  # Unknown tool call IDs are logged and ignored.
  def handle_cast(
        {:submit_tool_output, tool_call_id, tool_output},
        %__MODULE__{} = state
      ) do
    log("Submitting tool output #{inspect(tool_output)}")

    matches_id? = fn %ToolCallRequested{id: id} -> id == tool_call_id end

    case Enum.find_index(state.tool_calls, matches_id?) do
      nil ->
        log("Tool call ID #{inspect(tool_call_id)} not found")
        noreply(state)

      position ->
        log("Tool call ID #{inspect(tool_call_id)} found at index #{inspect(position)}")

        requested = Enum.at(state.tool_calls, position)
        still_pending = List.delete_at(state.tool_calls, position)

        recorded =
          ToolCallOutputRecorded.new!(
            id: tool_call_id,
            thread_id: state.thread_id,
            run_id: requested.run_id,
            name: requested.name,
            output: Jason.encode!(tool_output)
          )

        # The event is published before the state is updated.
        publish_event(state, recorded)

        %{state | tool_calls: still_pending, tool_outputs: [recorded | state.tool_outputs]}
        |> possibly_send_outputs_to_openai()
        |> noreply()
    end
  end

  # Submits the recorded tool outputs to OpenAI once a run is in progress, no
  # tool calls are still pending and at least one output has been recorded.
  # In any other situation the state is returned unchanged.
  #
  # A failed submission is retried after `@tool_output_retry_delay` ms, up to
  # three attempts in total. After the third failure a
  # ToolCallOutputSubmissionFailed event is published.
  #
  # Whether the submission succeeds or is given up on, exactly one
  # `:check_run_status` message is scheduled and `tool_outputs` is cleared.
  # Previously the scheduling ran once per attempt, because the code following
  # the `try` executed at every recursion level, which started duplicate
  # status-polling loops.
  defp possibly_send_outputs_to_openai(state, failure_count \\ 0)

  defp possibly_send_outputs_to_openai(state, failure_count) when failure_count >= 3 do
    log("Failed to send tool outputs to OpenAI after 3 attempts, giving up", :warning)

    state
    |> publish_event(
      ToolCallOutputSubmissionFailed.new!(
        thread_id: state.thread_id,
        run_id: state.run_id
      )
    )
    |> conclude_tool_output_submission()
  end

  defp possibly_send_outputs_to_openai(
         %__MODULE__{running?: true, tool_calls: [], tool_outputs: [_ | _]} = state,
         failure_count
       ) do
    log("Sending tool outputs to OpenAI")

    # Only the HTTP call and the assertion on its response are guarded, so an
    # error in the follow-up bookkeeping can never be mistaken for a failed
    # submission.
    submitted? =
      try do
        {:ok, %{body: %{"object" => "thread.run", "cancelled_at" => nil, "failed_at" => nil}}} =
          OpenAiClient.post(
            "/v1/threads/#{state.thread_id}/runs/#{state.run_id}/submit_tool_outputs",
            json: %{tool_outputs: state.tool_outputs},
            receive_timeout: receive_timeout_ms(state)
          )

        true
      rescue
        exception ->
          log("Failed to send tool outputs to OpenAI: #{inspect(exception)}", :warning)
          false
      end

    if submitted? do
      conclude_tool_output_submission(state)
    else
      :timer.sleep(@tool_output_retry_delay)
      possibly_send_outputs_to_openai(state, failure_count + 1)
    end
  end

  defp possibly_send_outputs_to_openai(%__MODULE__{} = state, _failure_count), do: state

  # Schedules the next run-status check and drops the submitted outputs.
  defp conclude_tool_output_submission(%__MODULE__{} = state) do
    Process.send_after(self(), {:check_run_status, state.run_id}, heartbeat_interval_ms())

    %{state | tool_outputs: []}
  end

  # Replies with the current value of the `running?` flag.
  @impl true
  def handle_call(:run_in_progress?, _caller, %__MODULE__{running?: running?} = state) do
    reply(state, running?)
  end

  # Unregisters the agent from the registry and stops it normally.
  #
  # The caller is answered with `:ok` before the process stops. Stopping from
  # `handle_call/3` without a reply makes the pending `GenServer.call/2` exit
  # in the caller instead of returning.
  def handle_call(:shutdown, _caller, %__MODULE__{} = state) do
    log("Shutting down")
    Registry.unregister(GptAgent.Registry, state.thread_id)
    {:stop, :normal, :ok, state}
  end

  # Replies with `{:ok, thread_id}` for the thread this agent serves.
  def handle_call(:thread_id, _caller, %__MODULE__{thread_id: thread_id} = state) do
    log("Returning thread ID #{inspect(thread_id)}")
    reply(state, {:ok, thread_id})
  end

  # Replies with `{:ok, assistant_id}` for the assistant currently configured.
  def handle_call(:assistant_id, _caller, %__MODULE__{assistant_id: assistant_id} = state) do
    log("Returning default assistant ID #{inspect(assistant_id)}")
    reply(state, {:ok, assistant_id})
  end

  # Idle timeout. While a run is in progress the agent checks the run's status
  # instead of stopping; otherwise it shuts itself down.
  @impl true
  def handle_info(:timeout, %__MODULE__{} = state) do
    log("Timeout Received")

    case state.running? do
      idle when idle in [false, nil] ->
        log("Shutting down.")
        stop(state)

      _running ->
        log("Run in progress, checking run status")
        noreply(state, {:continue, {:check_run_status, state.run_id}})
    end
  end

  # Timer message sent after a rate-limit delay: start the run again.
  def handle_info(:run, %__MODULE__{} = state), do: noreply(state, {:continue, :run})

  # Fetches the run from OpenAI and dispatches on its status.
  def handle_info({:check_run_status, id}, %__MODULE__{} = state) do
    log("Checking run status for run ID #{inspect(id)}")

    run_path = "/v1/threads/#{state.thread_id}/runs/#{id}"

    {:ok, %{body: %{"status" => status} = run}} =
      OpenAiClient.get(run_path, receive_timeout: receive_timeout_ms(state))

    handle_run_status(status, id, run, state)
  end

  # Run finished: clear the running flag, publish RunCompleted with the token
  # usage reported in the response (0 for any missing counter) and continue by
  # reading the new messages.
  defp handle_run_status("completed", id, response, %__MODULE__{} = state) do
    log("Run ID #{inspect(id)} completed")

    usage = Map.get(response, "usage", %{})

    run_completed =
      RunCompleted.new!(
        id: id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id,
        prompt_tokens: Map.get(usage, "prompt_tokens", 0),
        completion_tokens: Map.get(usage, "completion_tokens", 0),
        total_tokens: Map.get(usage, "total_tokens", 0)
      )

    %{state | running?: false}
    |> publish_event(run_completed)
    |> noreply({:continue, :read_messages})
  end

  # The run is waiting for tool outputs. Each requested call whose arguments
  # decode as JSON is added to `tool_calls` and published as ToolCallRequested.
  # A call with undecodable arguments is answered immediately with an error
  # output, which is published and added to `tool_outputs`. Afterwards the
  # outputs are submitted if no call is left pending.
  defp handle_run_status("requires_action", id, response, %__MODULE__{} = state) do
    log("Run ID #{inspect(id)} requires action")
    %{"required_action" => %{"submit_tool_outputs" => %{"tool_calls" => tool_calls}}} = response
    log("Tool calls: #{inspect(tool_calls)}")

    with_tool_calls =
      for raw_call <- tool_calls, reduce: state do
        acc ->
          function = raw_call["function"]

          case Jason.decode(function["arguments"]) do
            {:ok, arguments} ->
              requested =
                ToolCallRequested.new!(
                  id: raw_call["id"],
                  thread_id: acc.thread_id,
                  run_id: id,
                  name: function["name"],
                  arguments: arguments
                )

              publish_event(%{acc | tool_calls: [requested | acc.tool_calls]}, requested)

            {:error, %Jason.DecodeError{}} ->
              log("Failed to decode tool call arguments: #{inspect(raw_call)}", :warning)

              error_output =
                ToolCallOutputRecorded.new!(
                  id: raw_call["id"],
                  thread_id: acc.thread_id,
                  run_id: id,
                  name: function["name"],
                  output:
                    Jason.encode!(%{
                      error: "Failed to decode arguments, invalid JSON in tool call."
                    })
                )

              publish_event(acc, error_output)
              %{acc | tool_outputs: [error_output | acc.tool_outputs]}
          end
      end

    with_tool_calls
    |> possibly_send_outputs_to_openai()
    |> noreply()
  end

  # The run has not finished yet: keep the running flag set and schedule
  # another status check after the heartbeat interval.
  defp handle_run_status(status, id, _response, %__MODULE__{} = state)
       when status in ~w(queued in_progress) do
    log("Run ID #{inspect(id)} not completed")

    interval = heartbeat_interval_ms()
    Process.send_after(self(), {:check_run_status, id}, interval)
    log("Will check run status in #{heartbeat_interval_ms()} ms")

    noreply(%{state | running?: true})
  end

  # The run failed because of rate limiting and retries remain: schedule a new
  # run after the retry delay, bump the attempt counter and publish RateLimited.
  defp handle_run_status(
         "failed",
         id,
         %{
           "last_error" => %{
             "code" => "rate_limit_exceeded",
             "message" => "Rate limit reached" <> _
           }
         },
         %__MODULE__{rate_limit_retry_attempt: attempts} = state
       )
       when attempts < @rate_limit_max_retries do
    log(
      "Run ID #{inspect(id)} failed due to rate limiting. Will retry run in #{@rate_limit_retry_delay}ms."
    )

    Process.send_after(self(), :run, @rate_limit_retry_delay)

    rate_limited =
      RateLimited.new!(
        run_id: id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id,
        retries_remaining: @rate_limit_max_retries - attempts
      )

    %{state | rate_limit_retry_attempt: attempts + 1}
    |> publish_event(rate_limited)
    |> noreply()
  end

  # The run failed because of rate limiting and no retries remain: clear the
  # running flag and publish RateLimited followed by RateLimitRetriesExhuasted.
  defp handle_run_status(
         "failed",
         id,
         %{
           "last_error" => %{
             "code" => "rate_limit_exceeded",
             "message" => "Rate limit reached" <> _
           }
         },
         %__MODULE__{rate_limit_retry_attempt: attempts} = state
       )
       when attempts >= @rate_limit_max_retries do
    log("Run ID #{inspect(id)} failed due to rate limiting. Retries expired")

    rate_limited =
      RateLimited.new!(
        run_id: id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id,
        retries_remaining: @rate_limit_max_retries - attempts
      )

    retries_exhausted =
      RateLimitRetriesExhuasted.new!(
        run_id: id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id
      )

    %{state | rate_limit_retry_attempt: attempts + 1, running?: false}
    |> publish_event(rate_limited)
    |> publish_event(retries_exhausted)
    |> noreply()
  end

  # The run failed because the OpenAI account quota is used up: clear the
  # running flag and publish OrganizationQuotaExceeded, then RunFailed.
  defp handle_run_status(
         "failed",
         id,
         %{
           "last_error" => %{
             "code" => "rate_limit_exceeded",
             "message" =>
               "You exceeded your current quota, please check your plan and billing details." <>
                 _ = quota_message
           }
         },
         %__MODULE__{} = state
       ) do
    log("Run ID #{inspect(id)} failed due to OpenAI account quota reached.")

    quota_exceeded =
      OrganizationQuotaExceeded.new!(
        run_id: id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id
      )

    run_failed =
      RunFailed.new!(
        id: id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id,
        code: "rate_limit_exceeded",
        message: quota_message
      )

    %{state | running?: false}
    |> publish_event(quota_exceeded)
    # DEPRECATED: remove this second publish_event call on major version bump to 10.0.0
    |> publish_event(run_failed)
    |> noreply()
  end

  # Fallback for every status not handled by the clauses above: the run is
  # treated as failed, the running flag is cleared and RunFailed is published.
  #
  # The error code and message are read with `get_in/2`, so a response whose
  # "last_error" is absent *or present with a null value* yields "unknown"
  # for both. `Map.get(response, "last_error", %{})` returned `nil` for an
  # explicit null, and the following `Map.get(nil, ...)` raised BadMapError,
  # crashing the agent.
  defp handle_run_status(status, id, response, %__MODULE__{} = state) do
    log("Run ID #{inspect(id)} failed with status #{inspect(status)}", :warning)
    log("Response: #{inspect(response)}")
    log("State: #{inspect(state)}")

    code = get_in(response, ["last_error", "code"]) || "unknown"
    message = get_in(response, ["last_error", "message"]) || "unknown"

    run_failed =
      RunFailed.new!(
        id: id,
        thread_id: state.thread_id,
        assistant_id: state.assistant_id,
        code: code,
        message: message
      )

    %{state | running?: false}
    |> publish_event(run_failed)
    |> noreply()
  end

  defmodule Impl do
    @moduledoc """
    Provides the implementation of the GptAgent public API
    """

    @behaviour GptAgent

    # Logs `message` at `level` (default `:debug`), prefixed with the pid of
    # the calling process.
    defp log(message, level \\ :debug) when is_binary(message) do
      prefix = "[GptAgent (#{inspect(self())})] "
      Logger.log(level, prefix <> message)
    end

    # Wraps any value in an `{:ok, value}` tuple.
    defp ok(data) do
      {:ok, data}
    end

    # Creates a new thread through the OpenAI API and returns `{:ok, thread_id}`.
    @impl true
    def create_thread do
      log("Creating thread")

      request_opts = [
        json: "",
        receive_timeout: Application.get_env(:gpt_agent, :receive_timeout_ms)
      ]

      {:ok, %{body: %{"id" => new_thread_id, "object" => "thread"}}} =
        OpenAiClient.post("/v1/threads", request_opts)

      log("Created thread with ID #{inspect(new_thread_id)}")

      {:ok, new_thread_id}
    end

    # Starts a GptAgent GenServer linked to the caller, with `state` as its
    # init argument.
    @impl true
    def start_link(%GptAgent{} = state), do: GenServer.start_link(GptAgent, state)

    # Validates the options, finds or starts the agent for the thread, and
    # subscribes the caller to its events when requested. Invalid options fail
    # the `{:ok, _}` match and raise.
    @impl true
    def connect(opts) when is_list(opts) do
      {:ok, validated} = validate_and_convert_opts(opts)

      connection = connect_to_new_or_existing_agent(validated)
      maybe_subscribe(connection, validated)
    end

    # Looks the thread up in the registry: an already-running agent is updated
    # with the given IDs, otherwise a new agent is started.
    defp connect_to_new_or_existing_agent(opts) do
      log("Connecting to thread ID #{inspect(opts.thread_id)}")

      registered = Registry.lookup(GptAgent.Registry, opts.thread_id)

      case registered do
        [{agent_pid, :gpt_agent}] ->
          handle_existing_agent(agent_pid, opts.last_message_id, opts.assistant_id)

        [] ->
          handle_no_existing_agent(
            opts.thread_id,
            opts.last_message_id,
            opts.assistant_id,
            opts.timeout_ms
          )
      end
    end

    # Rejects unknown option keys, fills in defaults for `:subscribe` and
    # `:timeout_ms`, converts the options to a map and checks that the three
    # required IDs were supplied.
    defp validate_and_convert_opts(opts) do
      allowed = [
        :thread_id,
        :last_message_id,
        :assistant_id,
        subscribe: true,
        timeout_ms: nil
      ]

      opts_map =
        opts
        |> Keyword.validate!(allowed)
        |> Map.new()

      {:ok, opts_map}
      |> validate_thread_id()
      |> validate_last_message_id()
      |> validate_assistant_id()
    end

    # Passes the options through when `:thread_id` is present, otherwise
    # returns `{:error, :missing_thread_id}`.
    defp validate_thread_id({:ok, opts}) do
      case opts do
        %{thread_id: _present} -> {:ok, opts}
        _missing -> {:error, :missing_thread_id}
      end
    end

    # Passes the options through when `:last_message_id` is present, otherwise
    # returns `{:error, :missing_last_message_id}`. An earlier error is
    # forwarded untouched.
    defp validate_last_message_id({:error, _reason} = error), do: error

    defp validate_last_message_id({:ok, opts}) do
      case opts do
        %{last_message_id: _present} -> {:ok, opts}
        _missing -> {:error, :missing_last_message_id}
      end
    end

    # Passes the options through when `:assistant_id` is present, otherwise
    # returns `{:error, :missing_assistant_id}`. An earlier error is forwarded
    # untouched.
    defp validate_assistant_id({:error, _reason} = error), do: error

    defp validate_assistant_id({:ok, opts}) do
      case opts do
        %{assistant_id: _present} -> {:ok, opts}
        _missing -> {:error, :missing_assistant_id}
      end
    end

    # After a successful connect, subscribes the calling process to the
    # thread's PubSub channel when `opts.subscribe` is truthy. The connect
    # result is always returned unchanged.
    defp maybe_subscribe(result, opts) do
      case result do
        {:ok, _pid} ->
          if opts.subscribe do
            Phoenix.PubSub.subscribe(GptAgent.PubSub, "gpt_agent:#{opts.thread_id}")
          end

          result

        _not_connected ->
          result
      end
    end

    # HTTP receive timeout: the smaller of the configured `:receive_timeout_ms`
    # and the agent's own `timeout_ms`. When the config value is unset (nil),
    # term ordering ranks numbers below atoms, so `timeout_ms` is returned.
    defp receive_timeout_ms(%GptAgent{timeout_ms: timeout_ms}) do
      :gpt_agent
      |> Application.get_env(:receive_timeout_ms)
      |> min(timeout_ms)
    end

    # Reuses a running agent: asynchronously updates its last message ID and
    # assistant ID, then returns `{:ok, pid}`.
    defp handle_existing_agent(pid, last_message_id, assistant_id) do
      log("Found existing GPT Agent with PID #{inspect(pid)}")
      log("Updating last message ID to #{inspect(last_message_id)}")

      updates = [
        {:set_last_message_id, last_message_id},
        {:set_assistant_id, assistant_id}
      ]

      Enum.each(updates, fn update -> GenServer.cast(pid, update) end)

      {:ok, pid}
    end

    # Starts a new agent for the thread. The thread is first fetched from
    # OpenAI; a 404 yields `{:error, :invalid_thread_id}`, any other successful
    # response starts a temporary child under GptAgent.Supervisor.
    defp handle_no_existing_agent(thread_id, last_message_id, assistant_id, timeout_ms) do
      log("No existing GPT Agent found, starting new one")

      state =
        GptAgent.new!(
          thread_id: thread_id,
          last_message_id: last_message_id,
          assistant_id: assistant_id,
          timeout_ms: timeout_ms || default_timeout_ms()
        )

      thread_lookup =
        OpenAiClient.get("/v1/threads/#{thread_id}", receive_timeout: receive_timeout_ms(state))

      case thread_lookup do
        {:ok, %{status: 404}} ->
          log("Thread ID #{inspect(thread_id)} not found")
          {:error, :invalid_thread_id}

        {:ok, _} ->
          log("Thread ID #{inspect(thread_id)} found")

          start_result =
            DynamicSupervisor.start_child(GptAgent.Supervisor, %{
              id: thread_id,
              start: {__MODULE__, :start_link, [state]},
              restart: :temporary
            })

          log("Started GPT Agent with result #{inspect(start_result)}")
          start_result
      end
    end

    # Agent idle timeout from config, read at runtime (default 120_000 ms).
    defp default_timeout_ms do
      Application.get_env(:gpt_agent, :timeout_ms, 120_000)
    end

    # Logs a warning and builds the error returned when the agent is not alive.
    defp handle_dead_process(pid) do
      warning = "GPT Agent with PID #{inspect(pid)} is not alive"
      log(warning, :warning)

      {:error, {:process_not_alive, pid}}
    end

    # Terminates the agent through its supervisor when it is alive; otherwise
    # returns the process-not-alive error.
    @impl true
    def shutdown(pid) do
      log("Shutting down GPT Agent with PID #{inspect(pid)}")

      case Process.alive?(pid) do
        true ->
          log("GPT Agent with PID #{inspect(pid)} is alive, terminating")
          DynamicSupervisor.terminate_child(GptAgent.Supervisor, pid)

        false ->
          handle_dead_process(pid)
      end
    end

    # Wraps the text in a UserMessage and casts it to the agent when it is
    # alive; otherwise returns the process-not-alive error.
    @impl true
    def add_user_message(pid, message) do
      case Process.alive?(pid) do
        true -> GenServer.cast(pid, {:add_user_message, %UserMessage{content: message}})
        false -> handle_dead_process(pid)
      end
    end

    # Casts the tool output for the given tool call ID to the agent when it is
    # alive; otherwise returns the process-not-alive error.
    @impl true
    def submit_tool_output(pid, tool_call_id, tool_output) do
      case Process.alive?(pid) do
        true -> GenServer.cast(pid, {:submit_tool_output, tool_call_id, tool_output})
        false -> handle_dead_process(pid)
      end
    end

    # Asks the agent whether a run is in progress when it is alive; otherwise
    # returns the process-not-alive error.
    @impl true
    def run_in_progress?(pid) do
      case Process.alive?(pid) do
        true -> GenServer.call(pid, :run_in_progress?)
        false -> handle_dead_process(pid)
      end
    end

    # Casts the new assistant ID to the agent when it is alive; otherwise
    # returns the process-not-alive error.
    @impl true
    def set_assistant_id(pid, assistant_id) do
      case Process.alive?(pid) do
        true -> GenServer.cast(pid, {:set_assistant_id, assistant_id})
        false -> handle_dead_process(pid)
      end
    end
  end
end