defmodule Pi.Session.Worker do
@moduledoc "Server-owned Pi session process with subscribers and LLM-backed runs."
use GenServer
alias Pi.Agent.Messages
alias Pi.Protocol.LLM.Message
alias Pi.Protocol.Session.Snapshot
alias Pi.Session.Event
alias Pi.Session.State
@timeout 60_000
def start_link(opts \\ []) do
GenServer.start_link(__MODULE__, opts)
end
def id(pid), do: state(pid).id
def state(pid), do: GenServer.call(pid, :state)
def snapshot(pid), do: GenServer.call(pid, :snapshot)
def subscribe(pid, subscriber \\ self()), do: GenServer.call(pid, {:subscribe, subscriber})
def detach(pid, subscriber \\ self()), do: GenServer.call(pid, {:detach, subscriber})
def run(pid, prompt, opts \\ []) when is_binary(prompt) do
GenServer.call(pid, {:run, prompt, opts}, Keyword.get(opts, :timeout, @timeout) + 1_000)
end
def complete(pid, opts \\ []) do
GenServer.call(pid, {:complete, opts}, Keyword.get(opts, :timeout, @timeout) + 1_000)
end
def append(pid, message), do: GenServer.call(pid, {:append, message})
def emit_event(pid, %Event{} = event), do: GenServer.call(pid, {:emit_event, event})
def cancel(pid), do: GenServer.call(pid, :cancel)
def rerun(pid, opts \\ []),
do: GenServer.call(pid, {:rerun, opts}, Keyword.get(opts, :timeout, @timeout) + 1_000)
@impl true
def init(opts) do
{:ok,
%{
state: State.new(opts),
ask_fun: Keyword.get(opts, :ask_fun, &Pi.LLM.complete_with_usage/2),
stream_fun: Keyword.get(opts, :stream_fun, &Pi.LLM.stream/2),
subscribers: %{},
task: nil,
task_ref: nil,
caller: nil
}}
end
@impl true
def handle_call(:state, _from, data), do: {:reply, data.state, data}
def handle_call(:snapshot, _from, data), do: {:reply, to_snapshot(data.state), data}
def handle_call({:subscribe, subscriber}, _from, data) when is_pid(subscriber) do
ref = Process.monitor(subscriber)
{:reply, {:ok, data.state}, %{data | subscribers: Map.put(data.subscribers, ref, subscriber)}}
end
def handle_call({:detach, subscriber}, _from, data) when is_pid(subscriber) do
{removed, subscribers} =
Enum.split_with(data.subscribers, fn {_ref, pid} -> pid == subscriber end)
Enum.each(removed, fn {ref, _pid} -> Process.demonitor(ref, [:flush]) end)
{:reply, :ok, %{data | subscribers: Map.new(subscribers)}}
end
def handle_call({:append, message}, _from, data) do
data = update_state(data, fn state -> append_message(state, message) end)
{:reply, :ok, data}
end
def handle_call({:emit_event, %Event{} = event}, _from, data) do
data = append_event(data, event)
{:reply, :ok, data}
end
def handle_call(:cancel, _from, %{task: nil} = data) do
data = transition(data, :cancelled, Event.new(:cancelled))
{:reply, :ok, data}
end
def handle_call(:cancel, _from, %{task: task} = data) do
Task.shutdown(task, :brutal_kill)
event = Event.new(:cancelled)
data =
data
|> Map.merge(%{task: nil, task_ref: nil})
|> finish(:cancelled, nil, nil, event)
|> reply({:error, :cancelled})
{:reply, :ok, data}
end
def handle_call({:run, prompt, opts}, from, %{task: nil} = data) do
data = data |> update_state(&append_message(&1, %Message{role: :user, content: prompt}))
start_completion(data, from, opts)
end
def handle_call({:run, _prompt, _opts}, _from, data) do
{:reply, {:error, :busy}, data}
end
def handle_call({:rerun, opts}, from, %{task: nil} = data) do
case last_user_message(data.state) do
nil -> {:reply, {:error, :no_user_message}, data}
prompt -> handle_call({:run, prompt, opts}, from, data)
end
end
def handle_call({:rerun, _opts}, _from, data) do
{:reply, {:error, :busy}, data}
end
def handle_call({:complete, opts}, from, %{task: nil} = data) do
start_completion(data, from, opts)
end
def handle_call({:complete, _opts}, _from, data) do
{:reply, {:error, :busy}, data}
end
@impl true
def handle_info({:session_delta, delta}, data) when is_binary(delta) do
data =
data
|> update_state(&put_recent_output(&1, delta))
|> transition(:running, Event.new(:delta, %{delta: delta}))
{:noreply, data}
end
def handle_info({ref, {:ok, result}}, %{task_ref: ref} = data) do
Process.demonitor(ref, [:flush])
{text, usage} = completion_result(result)
data =
data
|> update_state(&append_message(&1, %Message{role: :assistant, content: text}))
|> put_usage(usage)
|> complete(:done, text, nil, Event.new(:done, %{result: text}))
|> reply({:ok, text})
{:noreply, data}
end
def handle_info({ref, {:error, reason}}, %{task_ref: ref} = data) do
Process.demonitor(ref, [:flush])
data =
data
|> complete(:failed, nil, reason, Event.new(:failed, %{error: inspect(reason)}))
|> reply({:error, reason})
{:noreply, data}
end
def handle_info({:DOWN, ref, :process, _pid, _reason}, %{task_ref: ref} = data) do
data =
data
|> complete(:failed, nil, :down, Event.new(:failed, %{error: "down"}))
|> reply({:error, :down})
{:noreply, data}
end
def handle_info({:DOWN, ref, :process, _pid, _reason}, data) do
{:noreply, %{data | subscribers: Map.delete(data.subscribers, ref)}}
end
defp start_completion(data, from, opts) do
data =
data
|> update_state(&begin_run/1)
|> transition(:running, Event.new(:started))
messages = messages(data.state)
ask_fun = data.ask_fun
stream_fun = data.stream_fun
timeout = Keyword.get(opts, :timeout, @timeout)
owner = self()
task =
Task.async(fn ->
if Keyword.get(opts, :stream, false) do
safe_stream(stream_fun, messages, Keyword.put(opts, :timeout, timeout), owner)
else
safe_ask(ask_fun, messages, Keyword.put(opts, :timeout, timeout))
end
end)
data =
transition(
%{data | task: task, task_ref: task.ref, caller: from},
:running,
Event.new(:llm)
)
{:noreply, data}
end
defp last_user_message(%State{messages: messages}) do
messages
|> Enum.reverse()
|> Enum.find_value(fn
%Message{role: :user, content: content} -> content
_message -> nil
end)
end
defp safe_ask(ask_fun, messages, opts) do
ask_fun.(messages, opts)
rescue
exception in [
RuntimeError,
ArgumentError,
FunctionClauseError,
MatchError,
UndefinedFunctionError,
ErlangError
] ->
{:error, Exception.message(exception)}
catch
kind, reason -> {:error, {kind, reason}}
end
defp completion_result(%{text: text, usage: usage}) when is_binary(text), do: {text, usage}
defp completion_result(%{"text" => text, "usage" => usage}) when is_binary(text),
do: {text, usage}
defp completion_result(text) when is_binary(text), do: {text, nil}
defp completion_result(result), do: {inspect(result), nil}
defp put_usage(data, nil), do: data
defp put_usage(data, usage) when is_map(usage) do
update_state(data, fn %State{metadata: metadata} = state ->
%{state | metadata: Map.put(metadata, :usage, usage)}
end)
end
defp safe_stream(stream_fun, messages, opts, owner) do
stream = stream_fun.(messages, opts)
text =
Enum.map_join(stream.stream, fn delta ->
text = to_string(delta)
send(owner, {:session_delta, text})
text
end)
{:ok, text}
rescue
exception in [
RuntimeError,
ArgumentError,
FunctionClauseError,
MatchError,
UndefinedFunctionError,
ErlangError
] ->
{:error, Exception.message(exception)}
catch
kind, reason -> {:error, {kind, reason}}
end
defp messages(%State{system: nil, messages: messages}), do: messages
defp messages(%State{system: system, messages: messages}) do
[%Message{role: :system, content: system} | messages]
end
defp append_message(%State{messages: messages} = state, message) do
%{
state
| messages: messages ++ [Messages.normalize(message)],
updated_at: DateTime.utc_now()
}
end
defp begin_run(%State{metadata: metadata} = state) do
metadata =
metadata
|> Map.update(:run_count, 1, &(&1 + 1))
|> Map.put(:recent_output, [])
|> put_current("llm")
|> Map.delete(:completed_at)
%{state | metadata: metadata}
end
defp put_recent_output(%State{metadata: metadata} = state, delta) do
recent_output =
metadata
|> Map.get(:recent_output, [])
|> Kernel.++([delta])
|> Enum.take(-5)
%{
state
| metadata: metadata |> Map.put(:recent_output, recent_output) |> put_current("streaming")
}
end
defp put_current(metadata, current) do
metadata
|> Map.put(:current, current)
|> Map.put(:current_started_at, current_started_at(metadata, current))
end
defp current_started_at(
%{current: current, current_started_at: %DateTime{} = started_at},
current
),
do: started_at
defp current_started_at(_metadata, _current), do: DateTime.utc_now()
defp append_event(data, event) do
update_state(data, fn state ->
%{state | events: state.events ++ [event], updated_at: event.at}
end)
end
defp transition(data, status, event) do
update_state(data, fn state ->
%{state | status: status, events: state.events ++ [event], updated_at: event.at}
end)
end
defp complete(data, status, result, error, event) do
data
|> Map.merge(%{task: nil, task_ref: nil})
|> finish(status, result, error, event)
end
defp finish(data, status, result, error, event) do
update_state(data, fn state ->
metadata =
state.metadata
|> Map.put(:completed_at, event.at)
|> Map.delete(:current)
|> Map.delete(:current_started_at)
%{
state
| status: status,
result: result,
error: error,
events: state.events ++ [event],
updated_at: event.at,
metadata: metadata
}
end)
end
defp reply(%{caller: nil} = data, _result), do: data
defp reply(%{caller: caller} = data, result) do
GenServer.reply(caller, result)
%{data | caller: nil}
end
defp update_state(data, fun) do
state = fun.(data.state)
broadcast(data.subscribers, state)
%{data | state: state}
end
defp broadcast(subscribers, state) do
Enum.each(subscribers, fn {_ref, pid} -> send(pid, {:pi_session, state.id, state}) end)
Pi.Plugin.Event.emit("pi_session", %{session: to_snapshot(state)})
end
defp to_snapshot(%State{} = state) do
Code.ensure_loaded?(Snapshot)
%Snapshot{
id: state.id,
parent_id: state.parent_id,
name: name(state.name),
status: Atom.to_string(state.status),
result: state.result,
error: error(state.error),
started_at: datetime(state.started_at),
updated_at: datetime(state.updated_at),
last_activity_at: datetime(state.updated_at),
completed_at: datetime(Map.get(state.metadata, :completed_at)),
current_started_at: datetime(Map.get(state.metadata, :current_started_at)),
duration_ms: duration_ms(state),
prompt: prompt_text(state),
response: response_text(state),
message_count: length(state.messages),
latest: latest_text(state),
current: Map.get(state.metadata, :current),
usage: Map.get(state.metadata, :usage),
run_count: Map.get(state.metadata, :run_count, 0),
turn_count: turn_count(state.messages),
recent_output: Map.get(state.metadata, :recent_output, []),
events: Enum.map(state.events, &event/1)
}
end
defp name(nil), do: nil
defp name(value), do: to_string(value)
defp error(nil), do: nil
defp error(value) when is_binary(value), do: value
defp error(value), do: inspect(value)
defp datetime(nil), do: nil
defp datetime(%DateTime{} = datetime), do: DateTime.to_iso8601(datetime)
defp duration_ms(%State{
started_at: %DateTime{} = started_at,
updated_at: %DateTime{} = updated_at
}) do
DateTime.diff(updated_at, started_at, :millisecond)
end
defp duration_ms(_state), do: nil
defp prompt_text(%State{messages: messages}), do: last_text_for_role(messages, :user)
defp response_text(%State{messages: messages}), do: last_text_for_role(messages, :assistant)
defp turn_count(messages), do: Enum.count(messages, &(&1.role == :assistant))
defp latest_text(%State{result: result}) when is_binary(result) and result != "", do: result
defp latest_text(%State{messages: messages}) do
messages
|> Enum.reverse()
|> Enum.find_value(fn
%Message{content: content} when is_binary(content) and content != "" -> content
_message -> nil
end)
end
defp last_text_for_role(messages, role) do
messages
|> Enum.reverse()
|> Enum.find_value(fn
%Message{role: ^role, content: content} when is_binary(content) and content != "" -> content
_message -> nil
end)
end
defp event(%Event{} = event) do
%Pi.Protocol.Session.Event{
type: Atom.to_string(event.type),
at: datetime(event.at),
data: event.data
}
end
end