defmodule GptAgent do
@moduledoc """
Provides a GPT conversation agent backed by an OpenAI Assistants thread.

Each agent is a `GenServer` that tracks one thread. When started with a
`thread_id` it registers itself in `GptAgent.Registry` under that ID. The agent:

  * posts user messages to the thread and then starts a run with its default
    assistant;
  * polls the run's status every `:heartbeat_interval_ms` milliseconds
    (`:gpt_agent` application env, default 1000);
  * records the tool calls a run requests and, once every pending call has an
    output, submits the outputs back to OpenAI;
  * reads the thread's new messages after a run completes.

Progress is broadcast on `GptAgent.PubSub`, channel `"gpt_agent:<thread_id>"`,
as `{agent_pid, event}` tuples where `event` is a `GptAgent.Events.*` struct.

The public API is declared by the `@callback`s in this module and provided via
`Knigge`, with `GptAgent.Impl` as the default implementation.
"""
# The agent is a GenServer; its callbacks are defined below.
use GenServer
# Provides the `typedstruct` macro used for the agent state.
use TypedStruct
# Public API functions come from Knigge; the implementation is looked up under the
# :gpt_agent app, falling back to GptAgent.Impl (see Knigge docs for lookup details).
use Knigge, otp_app: :gpt_agent, default: __MODULE__.Impl
require Logger
alias GptAgent.Events.{
AssistantMessageAdded,
RunCompleted,
RunStarted,
ToolCallOutputRecorded,
ToolCallRequested,
UserMessageAdded
}
alias GptAgent.Values.NonblankString
# State of a single agent process.
typedstruct do
# Assistant sent as "assistant_id" when a run is started; replaced by the
# :set_default_assistant_id cast. Not enforced, so presumably nil until set — confirm.
field :default_assistant_id, binary()
# OpenAI thread this agent drives; also the Registry key and the PubSub channel
# suffix. When nil the agent is not registered.
field :thread_id, binary() | nil
# True from the start of a run until the run is reported "completed"; while true,
# new user messages are rejected with {:error, :run_in_progress}.
field :running?, boolean(), default: false
# ID of the most recently started run (used when submitting tool outputs).
field :run_id, binary() | nil
# Tool calls requested by the current run that have not received an output yet.
field :tool_calls, [ToolCallRequested.t()], default: []
# Recorded tool outputs not yet sent; sent once :tool_calls becomes empty.
field :tool_outputs, [ToolCallOutputRecorded.t()], default: []
# ID of the last message processed; used as the `after` cursor when reading messages.
field :last_message_id, binary() | nil
end
# Builders for GenServer callback results. They take the state first so that a
# callback can end a state pipeline with `|> noreply()` or `|> reply(value)`.
defp ok(state), do: {:ok, state}

defp noreply(state), do: {:noreply, state}

defp noreply(state, next_action), do: {:noreply, state, next_action}

defp reply(state, value), do: {:reply, value, state}

defp reply(state, value, next_action), do: {:reply, value, state, next_action}

# Logs `message` at `level`, prefixed with this process's pid so that output from
# concurrently running agents can be told apart.
defp log(message, level \\ :debug) do
  Logger.log(level, "[GptAgent (#{inspect(self())})] " <> message)
end

# Broadcasts `{self(), event}` on the agent's "gpt_agent:<thread_id>" channel.
# Returns `state` untouched so the call can sit in the middle of a pipeline.
defp publish_event(state, event) do
  topic = "gpt_agent:#{state.thread_id}"
  log("Publishing event on channel #{topic}: #{inspect(event)}")
  :ok = Phoenix.PubSub.broadcast(GptAgent.PubSub, topic, {self(), event})
  state
end
@impl true
# Builds the agent state from the keyword init args (struct!/2 raises on keys
# that are not state fields), then registers the agent under its thread ID.
def init(init_arg) do
  log("Initializing with #{inspect(init_arg)}")

  __MODULE__
  |> struct!(init_arg)
  |> register()
  |> ok()
end

# An agent without a thread ID has nothing to be looked up by, so it stays unregistered.
defp register(%{thread_id: nil} = state), do: state

# Registers the current process in GptAgent.Registry under its thread ID, which is
# how Impl.connect/1 finds an existing agent for a thread.
defp register(%{thread_id: thread_id} = state) do
  {:ok, _registry_pid} = Registry.register(GptAgent.Registry, thread_id, :gpt_agent)
  log("Registered in GptAgent.Registry as #{inspect(thread_id)}")
  state
end
@impl true
# :run — creates a run of the default assistant on the thread, marks the agent as
# running, schedules the first status check and publishes RunStarted.
def handle_continue(:run, state) do
  log("Starting run")

  {:ok, %{body: %{"id" => id}}} =
    OpenAiClient.post("/v1/threads/#{state.thread_id}/runs",
      json: %{"assistant_id" => state.default_assistant_id}
    )

  # Read the interval once so the delay that gets logged is the one actually scheduled.
  interval = heartbeat_interval_ms()
  Process.send_after(self(), {:check_run_status, id}, interval)
  log("Will check run status in #{interval} ms")

  %{state | running?: true, run_id: id}
  |> publish_event(%RunStarted{
    id: id,
    thread_id: state.thread_id,
    assistant_id: state.default_assistant_id
  })
  |> noreply()
end

@impl true
# :read_messages — fetches the thread's messages in ascending order, starting after
# the last one already processed, and hands them to process_messages/2.
# NOTE(review): only a single page of the list endpoint is requested. Anything beyond
# that page is only picked up on the next read (the cursor still advances in order).
# Confirm whether `has_more` needs to be followed.
def handle_continue(:read_messages, state) do
  cursor = if state.last_message_id, do: "&after=#{state.last_message_id}", else: ""
  url = "/v1/threads/#{state.thread_id}/messages?order=asc" <> cursor

  log("Reading messages with request to #{url}")
  {:ok, %{body: %{"object" => "list", "data" => messages}}} = OpenAiClient.get(url)

  state
  |> process_messages(messages)
  |> noreply()
end
# Publishes an AssistantMessageAdded event for each assistant message in
# `messages` (in the order given). Advances `last_message_id` past every message,
# assistant or not.
defp process_messages(state, messages) do
  log("Processing messages: #{inspect(messages)}")

  Enum.reduce(messages, state, fn message, acc ->
    acc = maybe_publish_assistant_message(acc, message)
    log("Updating last message ID to #{message["id"]}")
    %{acc | last_message_id: message["id"]}
  end)
end

# Only assistant messages are published, so only their content is inspected.
# User messages, which may carry non-text parts, are never destructured. This keeps
# them from crashing the agent.
defp maybe_publish_assistant_message(state, %{"role" => "assistant"} = message) do
  case first_text_value(message["content"]) do
    nil ->
      log(
        "Assistant message #{inspect(message["id"])} has no text content; not publishing",
        :warning
      )

      state

    content ->
      publish_event(state, %AssistantMessageAdded{
        message_id: message["id"],
        thread_id: message["thread_id"],
        run_id: message["run_id"],
        assistant_id: message["assistant_id"],
        content: content
      })
  end
end

defp maybe_publish_assistant_message(state, _message), do: state

# Value of the first content part shaped %{"text" => %{"value" => value}}, or nil.
defp first_text_value(parts) when is_list(parts) do
  Enum.find_value(parts, fn
    %{"text" => %{"value" => value}} -> value
    _other -> nil
  end)
end

defp first_text_value(_parts), do: nil

# Delay between run-status polls: the :gpt_agent :heartbeat_interval_ms application
# env, defaulting to 1000 ms.
defp heartbeat_interval_ms, do: Application.get_env(:gpt_agent, :heartbeat_interval_ms, 1000)
@impl true
# Replaces the assistant used for subsequent runs. A run that has already been
# created keeps its assistant, because the ID is only sent when a run is started.
def handle_cast({:set_default_assistant_id, assistant_id}, state) do
  # inspect/1 (as in every other log line) so non-binary IDs cannot raise
  # Protocol.UndefinedError during interpolation.
  log("Setting default assistant ID to #{inspect(assistant_id)}")
  noreply(%{state | default_assistant_id: assistant_id})
end
@impl true
# Removes the agent from GptAgent.Registry, replies :ok to the caller and stops the
# server normally. A handle_call/3 callback must return a :reply/:noreply/:stop
# tuple; returning {:ok, state} would crash the agent with :bad_return_value.
# NOTE(review): if this process runs as a :permanent child of a supervisor (the
# `use GenServer` default), a :normal exit is restarted. Impl.shutdown/1 uses
# DynamicSupervisor.terminate_child/2 instead — confirm which path callers should use.
def handle_call(:shutdown, _caller, state) do
  log("Shutting down")
  Registry.unregister(GptAgent.Registry, state.thread_id)
  {:stop, :normal, :ok, state}
end
@impl true
# Replies {:ok, thread_id}; the ID is nil for an agent started without a thread.
def handle_call(:thread_id, _from, %__MODULE__{thread_id: thread_id} = state) do
  log("Returning thread ID #{inspect(thread_id)}")
  reply(state, {:ok, thread_id})
end

@impl true
# Replies {:ok, assistant_id} with the assistant used for new runs.
def handle_call(
      :default_assistant_id,
      _from,
      %__MODULE__{default_assistant_id: assistant_id} = state
    ) do
  log("Returning default assistant ID #{inspect(assistant_id)}")
  reply(state, {:ok, assistant_id})
end
@impl true
# A thread with an active run cannot take another user message, so reject it.
def handle_call({:add_user_message, message}, _from, %__MODULE__{running?: true} = state) do
  log(
    "Attempting to add user message, but run in progress, cannot add user message: #{inspect(message)}"
  )

  reply(state, {:error, :run_in_progress})
end

@impl true
# Posts the message to the thread and publishes UserMessageAdded. Replies :ok, then
# continues with :run so a run is started on the thread.
def handle_call({:add_user_message, message}, _from, state) do
  log("Adding user message #{inspect(message)}")

  # NOTE(review): input rejected by NonblankString.new/1 (presumably blank strings)
  # fails this match and crashes the agent instead of returning an error.
  {:ok, content} = NonblankString.new(message)

  # NOTE(review): the NonblankString value is sent as the entire JSON body. Confirm
  # it encodes to the shape the messages endpoint expects (role + content).
  {:ok, %{body: %{"id" => message_id}}} =
    OpenAiClient.post("/v1/threads/#{state.thread_id}/messages", json: content)

  event = %UserMessageAdded{id: message_id, thread_id: state.thread_id, content: content}

  state
  |> publish_event(event)
  |> reply(:ok, {:continue, :run})
end
@impl true
# Tool output only makes sense while a run is waiting for it.
def handle_call(
      {:submit_tool_output, tool_call_id, tool_output},
      _from,
      %__MODULE__{running?: false} = state
    ) do
  log(
    "Attempting to submit tool output, but no run in progress, cannot submit tool output for call #{inspect(tool_call_id)}: #{inspect(tool_output)}"
  )

  reply(state, {:error, :run_not_in_progress})
end

@impl true
# Matches the output against the pending tool calls. An unknown call ID is rejected;
# otherwise the output is recorded (see record_tool_output/4).
def handle_call({:submit_tool_output, tool_call_id, tool_output}, _from, state) do
  log("Submitting tool output #{inspect(tool_output)}")

  index =
    Enum.find_index(state.tool_calls, fn %ToolCallRequested{id: id} ->
      id == tool_call_id
    end)

  if index == nil do
    log("Tool call ID #{inspect(tool_call_id)} not found")
    reply(state, {:error, :invalid_tool_call_id})
  else
    record_tool_output(state, index, tool_call_id, tool_output)
  end
end

# Removes the pending call at `index`, stores a JSON-encoded ToolCallOutputRecorded
# for it and publishes that event. Sends all outputs to OpenAI once nothing is
# pending, then replies :ok.
defp record_tool_output(state, index, tool_call_id, raw_output) do
  log("Tool call ID #{inspect(tool_call_id)} found at index #{inspect(index)}")
  {tool_call, remaining_calls} = List.pop_at(state.tool_calls, index)

  recorded = %ToolCallOutputRecorded{
    id: tool_call_id,
    thread_id: state.thread_id,
    run_id: tool_call.run_id,
    name: tool_call.name,
    output: Jason.encode!(raw_output)
  }

  updated = %{
    state
    | tool_calls: remaining_calls,
      tool_outputs: [recorded | state.tool_outputs]
  }

  updated
  |> publish_event(recorded)
  |> possibly_send_outputs_to_openai()
  |> reply(:ok)
end
# Once a running agent has no pending tool calls but does hold recorded outputs,
# submits them all for the current run and resumes polling it. Any other state is
# returned unchanged.
defp possibly_send_outputs_to_openai(
       %{running?: true, tool_calls: [], tool_outputs: [_ | _]} = state
     ) do
  log("Sending tool outputs to OpenAI")
  path = "/v1/threads/#{state.thread_id}/runs/#{state.run_id}/submit_tool_outputs"

  # NOTE(review): outputs are sent as ToolCallOutputRecorded structs. Confirm their
  # JSON encoding yields the tool_call_id/output pairs the endpoint expects.
  {:ok, %{body: %{"object" => "thread.run", "cancelled_at" => nil, "failed_at" => nil}}} =
    OpenAiClient.post(path, json: %{tool_outputs: state.tool_outputs})

  Process.send_after(self(), {:check_run_status, state.run_id}, heartbeat_interval_ms())
  %{state | tool_outputs: []}
end

defp possibly_send_outputs_to_openai(state), do: state
@impl true
# Fetches run `id` and dispatches on its status. These messages are scheduled with
# Process.send_after/3 in three cases: when a run starts, after tool outputs are
# submitted, and while a run is still in a non-terminal state.
def handle_info({:check_run_status, id}, state) do
  log("Checking run status for run ID #{inspect(id)}")

  {:ok, %{body: %{"status" => status} = response}} =
    OpenAiClient.get("/v1/threads/#{state.thread_id}/runs/#{id}", [])

  handle_run_status(status, id, response, state)
end

# Terminal run statuses other than "completed". Re-polling such a run can never make
# progress. It used to loop forever with running? stuck at true, so the agent
# rejected every later user message.
# NOTE(review): status names taken from the OpenAI Assistants run-object docs.
@failed_run_statuses ["failed", "cancelled", "expired", "incomplete"]

# Run finished: clear running?, publish RunCompleted, then read the new messages.
defp handle_run_status("completed", id, response, state) do
  log("Run ID #{inspect(id)} completed")

  %{state | running?: false}
  |> publish_event(%RunCompleted{
    id: id,
    thread_id: state.thread_id,
    # Prefer the assistant the run reports. The default may have been changed by a
    # :set_default_assistant_id cast after this run was started.
    assistant_id: response["assistant_id"] || state.default_assistant_id
  })
  |> noreply({:continue, :read_messages})
end

# Run is waiting on tools: record and publish each requested call. Polling is not
# rescheduled here; it resumes after all outputs have been submitted.
defp handle_run_status("requires_action", id, response, state) do
  log("Run ID #{inspect(id)} requires action")
  %{"required_action" => %{"submit_tool_outputs" => %{"tool_calls" => tool_calls}}} = response
  log("Tool calls: #{inspect(tool_calls)}")

  tool_calls
  |> Enum.reduce(state, fn raw_call, acc ->
    tool_call = %ToolCallRequested{
      id: raw_call["id"],
      thread_id: acc.thread_id,
      run_id: id,
      name: raw_call["function"]["name"],
      arguments: Jason.decode!(raw_call["function"]["arguments"])
    }

    %{acc | tool_calls: [tool_call | acc.tool_calls]}
    |> publish_event(tool_call)
  end)
  |> noreply()
end

# Run ended without completing: stop polling and accept new messages again.
# Pending tool state belonged to the dead run, so it is dropped.
# NOTE(review): no event is published here; subscribers only see that RunCompleted
# never arrives. Consider a dedicated failure event.
defp handle_run_status(status, id, response, state) when status in @failed_run_statuses do
  log(
    "Run ID #{inspect(id)} ended with status #{inspect(status)}: #{inspect(response)}",
    :warning
  )

  noreply(%{state | running?: false, tool_calls: [], tool_outputs: []})
end

# Still queued / in progress: check again after the heartbeat interval.
defp handle_run_status(_status, id, _response, state) do
  log("Run ID #{inspect(id)} not completed")
  interval = heartbeat_interval_ms()
  Process.send_after(self(), {:check_run_status, id}, interval)
  log("Will check run status in #{interval} ms")
  noreply(state)
end
# Public API contract. `use Knigge` (above) provides the public functions for these
# callbacks. The return types mirror what GptAgent.Impl and the GenServer callbacks
# in this module actually return: getters reply {:ok, value}, and mutating calls
# reply :ok.
@callback create_thread() :: {:ok, binary()}
@callback start_link(keyword()) :: GenServer.on_start()
@callback connect(binary()) :: {:ok, pid()} | {:error, :invalid_thread_id}
@callback connect(binary(), binary()) :: {:ok, pid()} | {:error, :invalid_thread_id}
@callback shutdown(pid()) :: :ok
@callback thread_id(pid()) :: {:ok, binary() | nil}
@callback default_assistant(pid()) :: {:ok, binary() | nil}
@callback set_default_assistant(pid(), binary()) :: :ok
@callback add_user_message(pid(), binary()) :: :ok | {:error, :run_in_progress}
@callback submit_tool_output(pid(), binary(), map()) ::
            :ok | {:error, :invalid_tool_call_id | :run_not_in_progress}
defmodule Impl do
  @moduledoc """
  Provides the default implementation of the `GptAgent` public API.

  Agents created by `connect/1` run under `GptAgent.Supervisor` and are found
  again through `GptAgent.Registry` by thread ID. All OpenAI traffic goes
  through `OpenAiClient`.
  """
  require Logger

  defp log(message, level \\ :debug),
    do: Logger.log(level, "[GptAgent (#{inspect(self())})] " <> message)

  @doc """
  Creates a new OpenAI thread and returns `{:ok, thread_id}`.
  """
  @spec create_thread() :: {:ok, binary()}
  def create_thread do
    log("Creating thread")

    {:ok, %{body: %{"id" => thread_id, "object" => "thread"}}} =
      OpenAiClient.post("/v1/threads", json: "")

    log("Created thread with ID #{inspect(thread_id)}")
    {:ok, thread_id}
  end

  @doc """
  Starts a GPT Agent process linked to the caller.

  `init_arg` is a keyword list of agent state fields, such as `thread_id:` and
  `default_assistant_id:`. Keys that are not state fields make the agent's
  `init/1` raise.
  """
  @spec start_link(keyword()) :: GenServer.on_start()
  def start_link(init_arg) do
    GenServer.start_link(GptAgent, init_arg)
  end

  @doc """
  Connects to the GPT Agent for `thread_id`, subscribing the caller to the
  agent's `"gpt_agent:<thread_id>"` PubSub channel.

  An agent already registered for the thread is reused. Otherwise the thread is
  looked up on OpenAI:

    * a 404 response returns `{:error, :invalid_thread_id}`;
    * any other response is treated as the thread existing, and a new agent is
      started under `GptAgent.Supervisor`.
  """
  # NOTE(review): two concurrent connects for the same unregistered thread can both
  # start an agent. If GptAgent.Registry uses unique keys (presumably), the second
  # agent's registration match fails in init and its start_child error is returned as-is.
  @spec connect(binary()) :: {:ok, pid()} | {:error, :invalid_thread_id}
  def connect(thread_id) do
    log("Connecting to thread ID #{inspect(thread_id)}")

    case Registry.lookup(GptAgent.Registry, thread_id) do
      [{pid, :gpt_agent}] ->
        log("Found existing GPT Agent with PID #{inspect(pid)}")
        Phoenix.PubSub.subscribe(GptAgent.PubSub, "gpt_agent:#{thread_id}")
        {:ok, pid}

      [] ->
        log("No existing GPT Agent found, starting new one")

        case OpenAiClient.get("/v1/threads/#{thread_id}") do
          {:ok, %{status: 404}} ->
            log("Thread ID #{inspect(thread_id)} not found")
            {:error, :invalid_thread_id}

          {:ok, _} ->
            log("Thread ID #{inspect(thread_id)} found")
            Phoenix.PubSub.subscribe(GptAgent.PubSub, "gpt_agent:#{thread_id}")

            DynamicSupervisor.start_child(
              GptAgent.Supervisor,
              {GptAgent, [thread_id: thread_id]}
            )
            |> tap(&log("Started GPT Agent with result #{inspect(&1)}"))
        end
    end
  end

  @doc """
  Connects to the GPT Agent for `thread_id` (see `connect/1`), then sets its
  default assistant to `assistant_id`.
  """
  @spec connect(binary(), binary()) :: {:ok, pid()} | {:error, :invalid_thread_id}
  def connect(thread_id, assistant_id) do
    log(
      "Connecting to thread ID #{inspect(thread_id)} and setting default assistant ID to #{inspect(assistant_id)}"
    )

    case connect(thread_id) do
      {:ok, pid} ->
        :ok = set_default_assistant(pid, assistant_id)
        {:ok, pid}

      {:error, reason} ->
        {:error, reason}
    end
  end

  @doc """
  Stops the GPT Agent `pid` by terminating it under `GptAgent.Supervisor`.

  Always returns `:ok`, including when the process is already gone or was not
  started under that supervisor.
  """
  @spec shutdown(pid()) :: :ok
  def shutdown(pid) do
    log("Shutting down GPT Agent with PID #{inspect(pid)}")

    if Process.alive?(pid) do
      log("GPT Agent with PID #{inspect(pid)} is alive, terminating")

      case DynamicSupervisor.terminate_child(GptAgent.Supervisor, pid) do
        :ok ->
          :ok

        # The agent exited between the alive? check and termination, or it was
        # started with start_link/1 rather than under GptAgent.Supervisor.
        {:error, :not_found} ->
          log("GPT Agent with PID #{inspect(pid)} not found under GptAgent.Supervisor")
      end
    else
      log("GPT Agent with PID #{inspect(pid)} is not alive")
    end

    :ok
  end

  @doc """
  Returns `{:ok, thread_id}` for the agent. The ID is `nil` if the agent was
  started without one.
  """
  @spec thread_id(pid()) :: {:ok, binary() | nil}
  def thread_id(pid) do
    GenServer.call(pid, :thread_id)
  end

  @doc """
  Returns `{:ok, assistant_id}` with the assistant the agent uses for new runs.
  """
  @spec default_assistant(pid()) :: {:ok, binary() | nil}
  def default_assistant(pid) do
    GenServer.call(pid, :default_assistant_id)
  end

  @doc """
  Sets the assistant used for the agent's future runs.

  This is asynchronous (a cast), so it always returns `:ok`.
  """
  @spec set_default_assistant(pid(), binary()) :: :ok
  def set_default_assistant(pid, assistant_id) do
    GenServer.cast(pid, {:set_default_assistant_id, assistant_id})
  end

  @doc """
  Adds a user message to the agent's thread and starts a run.

  Returns `{:error, :run_in_progress}` while a previous run is still active.
  """
  @spec add_user_message(pid(), binary()) :: :ok | {:error, :run_in_progress}
  def add_user_message(pid, message) do
    GenServer.call(pid, {:add_user_message, message})
  end

  @doc """
  Submits the (JSON-encodable) output for a pending tool call.

  Once every pending call has an output, the outputs are sent to OpenAI.

  Returns:

    * `{:error, :run_not_in_progress}` when no run is active;
    * `{:error, :invalid_tool_call_id}` when `tool_call_id` is not pending.
  """
  @spec submit_tool_output(pid(), binary(), map()) ::
          :ok | {:error, :invalid_tool_call_id | :run_not_in_progress}
  def submit_tool_output(pid, tool_call_id, tool_output) do
    GenServer.call(pid, {:submit_tool_output, tool_call_id, tool_output})
  end
end
end