%%% @doc Tiny natural-language -> SQL agent over a caller-supplied database.
%%%
%%% Given a question (or, in write mode, a requested change) and a query
%%% function, this drives a short tool-use loop with the LLM (`llm.erl'): the
%%% model writes SQL, we run it through `QueryFun', feed the rows back, and
%%% repeat until the model stops calling the tool and gives a plain-English
%%% answer. The accumulated SQL is returned alongside the answer so the caller
%%% can show exactly what was run.
%%%
%%% `answer/2,3' is read-only and should be paired with a `QueryFun' the caller
%%% pins read-only; `edit/2,3' additionally lets the model run
%%% INSERT/UPDATE/DELETE and needs a write-capable `QueryFun'.
%%%
%%% The agent never touches the database directly — it only ever calls the
%%% caller-supplied `QueryFun', so it inherits whatever guarantees (read-only or
%%% otherwise) the caller enforces (see `cli:read_query/1').
-module(db_agent).
-export([answer/2, answer/3, edit/2, edit/3]).
%% A query runner: takes a SQL string, returns column names + JSON-safe row
%% maps, or an error the model can read and recover from.
-type query_fun() :: fun((binary()) ->
{ok, [binary()], [map()]} | {error, term()}).
-export_type([query_fun/0]).
-define(PROVIDER, openai).
-define(SIZE, big).
%% Safety rail on the tool-use loop so a confused model can't run forever.
-define(MAX_STEPS, 8).
%% Cap rows fed back to the model per query to keep the context bounded; the
%% true row_count is always reported so it knows when results were trimmed.
-define(MAX_ROWS_TO_MODEL, 100).
%% Read-only mode: the model may only SELECT. Pair with a `QueryFun' the caller
%% has pinned read-only (see `db:run_readonly/2').
%%
%% `answer/2' is `answer/3' with an empty option map. `Opts' may set
%% `provider', `size' and `on_event'; absent keys fall back to the module
%% defaults. On success the result carries the model's final answer and every
%% model-issued SQL statement, in the order it was run.
-spec answer(query_fun(), binary() | list()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
answer(QueryFun, Question) ->
    run(read, QueryFun, Question, #{}).

-spec answer(query_fun(), binary() | list(), map()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
answer(QueryFun, Question, Opts) ->
    run(read, QueryFun, Question, Opts).
%% Write mode: the model may run INSERT/UPDATE/DELETE in addition to SELECT, to
%% carry out a requested change. Pair with a write-capable `QueryFun' (see
%% `db:run_write/2'). The prompt instructs the model to inspect first, scope
%% every UPDATE/DELETE with a WHERE clause, and stop rather than guess on
%% ambiguity.
%%
%% `edit/2' is `edit/3' with an empty option map; `Opts' takes the same keys as
%% `answer/3'.
-spec edit(query_fun(), binary() | list()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
edit(QueryFun, Instruction) ->
    run(write, QueryFun, Instruction, #{}).

-spec edit(query_fun(), binary() | list(), map()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
edit(QueryFun, Instruction, Opts) ->
    run(write, QueryFun, Instruction, Opts).
%% Shared driver behind `answer/3' (Mode = read) and `edit/3' (Mode = write).
%% Normalises the request to a binary, resolves options, seeds the conversation
%% with the mode's system prompt (including a live schema snapshot) and the
%% user's request, then enters the tool-use loop with ?MAX_STEPS steps.
%%
%% `Opts' keys, all optional:
%%   provider - passed to `llm:chat/4' (default ?PROVIDER)
%%   size     - passed to `llm:chat/4' (default ?SIZE)
%%   on_event - live progress sink (default: a no-op fun). It receives:
%%       {thinking, binary()}              - the model's interim reasoning
%%       {query, binary()}                 - a SQL statement about to run
%%       {result, {ok, non_neg_integer()}} - row count returned
%%       {result, {error, binary()}}       - query error the model will see
run(Mode, QueryFun, Question0, Opts) ->
    Question = to_bin(Question0),
    Defaults = #{provider => ?PROVIDER, size => ?SIZE, on_event => fun(_) -> ok end},
    #{provider := Provider, size := Size, on_event := Report} = maps:merge(Defaults, Opts),
    Schema = fetch_schema(QueryFun),
    Seed = [
        #{role => system, content => system_prompt(Schema, Mode)},
        #{role => user, content => Question}
    ],
    loop(QueryFun, Report, Provider, Size, sql_tool(Mode), Seed, [], ?MAX_STEPS).
%%--- Agent loop -----------------------------------------------------
%% Drive the tool-use conversation until the model answers in plain text or
%% the step budget runs out.
%%
%% Each iteration sends the whole `Messages' history to `llm:chat/4' with the
%% single `Tool'. If the response requests tool calls, they are executed by
%% `run_calls/4'; the assistant turn and its tool results are appended to the
%% history, and we recurse with one fewer step. A response without tool calls
%% (an empty `tool_calls' list, or no `tool_calls' key at all) is the final
%% answer. `Queries' accumulates executed SQL newest-first and is reversed on
%% return. Errors from `llm:chat/4' are returned unchanged.
loop(_QueryFun, _Report, _Provider, _Size, _Tool, _Messages, Queries, 0) ->
    {ok, #{
        answer =>
            <<"I hit the maximum number of query steps before reaching a "
              "confident answer. The SQL I ran is listed below.">>,
        queries => lists:reverse(Queries)
    }};
loop(QueryFun, Report, Provider, Size, Tool, Messages, Queries, Steps) ->
    case llm:chat(Provider, Size, Messages, [Tool]) of
        {ok, #{tool_calls := [_ | _] = Calls} = Resp} ->
            Content = maps:get(content, Resp, <<>>),
            maybe_report_thinking(Report, Content),
            Assistant = #{role => assistant, content => Content, tool_calls => Calls},
            {ResultMsgs, Queries1} = run_calls(QueryFun, Report, Calls, Queries),
            loop(
                QueryFun,
                Report,
                Provider,
                Size,
                Tool,
                Messages ++ [Assistant | ResultMsgs],
                Queries1,
                Steps - 1
            );
        {ok, Resp} ->
            %% No tool calls requested. The key may be an empty list or absent,
            %% depending on the provider response, so the model is done and its
            %% content is the answer.
            {ok, #{
                answer => maps:get(content, Resp, <<>>),
                queries => lists:reverse(Queries)
            }};
        {error, _Reason} = Error ->
            Error
    end.
%% Forward the model's interim text to the progress sink as a `thinking'
%% event. Empty content is skipped, so the sink never sees blank events.
%% Returns `ok' when skipped, otherwise whatever the sink returns.
maybe_report_thinking(_Sink, Text) when Text =:= <<>> ->
    ok;
maybe_report_thinking(Sink, Text) ->
    Sink({thinking, Text}).
%% Execute the tool calls from one assistant turn, in order, and build the
%% matching `tool_result' messages.
%%
%% Returns `{ResultMsgs, Queries1}': one result message per call, in the same
%% order as `Calls' and tagged with that call's `id' so results can be matched
%% back to their calls, plus the query accumulator extended with any SQL run.
run_calls(QueryFun, Report, Calls, Queries) ->
    lists:mapfoldl(
        fun(Call, QAcc) ->
            %% Read the id before running anything, so a malformed call fails
            %% without having executed its SQL (this matters in write mode).
            Id = maps:get(id, Call),
            {Content, QAcc1} = handle_call(QueryFun, Report, Call, QAcc),
            {#{role => tool_result, tool_use_id => Id, content => Content}, QAcc1}
        end,
        Queries,
        Calls
    ).
%% Run the SQL carried by a single tool call, reporting progress to the sink.
%%
%% Returns `{ContentForModel, Queries1}'. A call without a binary `sql'
%% argument gets a JSON error (so the model can correct itself) and is not
%% recorded. Otherwise the statement is reported, executed via `run_one/2',
%% its outcome reported, and it is prepended to `Queries'.
handle_call(QueryFun, Report, Call, Queries) ->
    case call_sql(Call) of
        Sql when is_binary(Sql) ->
            Report({query, Sql}),
            {ForModel, Outcome} = run_one(QueryFun, Sql),
            Report({result, Outcome}),
            {ForModel, [Sql | Queries]};
        undefined ->
            Missing = #{<<"error">> => <<"missing required 'sql' argument">>},
            {json_util:encode(Missing), Queries}
    end.
%% Extract the `sql' argument from a tool call's decoded `input' map.
%% Returns the SQL binary, or `undefined' when `input' is absent or not a map,
%% has no `sql' key, or its value is not a binary.
call_sql(#{input := #{<<"sql">> := Sql}}) when is_binary(Sql) ->
    Sql;
call_sql(Call) when is_map(Call) ->
    undefined.
%% Run one statement and shape its outcome twice: a JSON payload for the model
%% and a compact summary for the progress sink.
%%
%% Returns {ContentForModel, SummaryForReport}:
%%   success - the payload has `columns', the true `row_count', at most
%%             ?MAX_ROWS_TO_MODEL `rows', and `truncated => true' when rows
%%             were dropped; the summary is {ok, RowCount}.
%%   failure - the payload is {"error": Text} and the summary is
%%             {error, Text}, where Text is the reason pretty-printed as text.
run_one(QueryFun, Sql) ->
    case QueryFun(Sql) of
        {ok, Columns, Rows} ->
            RowCount = length(Rows),
            {Shown, Truncated} = cap_rows(Rows, ?MAX_ROWS_TO_MODEL),
            Base = #{
                <<"columns">> => Columns,
                <<"row_count">> => RowCount,
                <<"rows">> => Shown
            },
            Payload =
                case Truncated of
                    true -> Base#{<<"truncated">> => true};
                    false -> Base
                end,
            {json_util:encode(Payload), {ok, RowCount}};
        {error, Reason} ->
            %% Use `~tp', not `~p'. With plain `~p', a UTF-8 binary inside the
            %% reason (e.g. an error message quoting a non-ASCII value) is
            %% printed byte-by-byte as Latin-1 and re-encoded as mojibake.
            Bin = to_bin(io_lib:format("~tp", [Reason])),
            {json_util:encode(#{<<"error">> => Bin}), {error, Bin}}
    end.
%% Keep at most `Max' rows. Returns `{Kept, WasTruncated}', where the flag is
%% true only when rows were actually dropped.
cap_rows(Rows, Max) ->
    case length(Rows) > Max of
        true -> {lists:sublist(Rows, Max), true};
        false -> {Rows, false}
    end.
%%--- Schema introspection -------------------------------------------
%% Pull a compact, live view of the `public' schema so the model starts
%% grounded. It reads `information_schema' on every run, so tables added later
%% appear automatically. If the query returns an error or no rows, we return a
%% hint telling the model to explore `information_schema' itself. Exceptions
%% raised by `QueryFun' are not caught here.
fetch_schema(QueryFun) ->
    Introspection =
        <<"SELECT table_name, column_name, data_type "
          "FROM information_schema.columns "
          "WHERE table_schema = 'public' "
          "ORDER BY table_name, ordinal_position">>,
    case QueryFun(Introspection) of
        {ok, _Columns, [_ | _] = Rows} ->
            render_schema(Rows);
        _NoSchema ->
            <<"(schema unavailable; use information_schema.columns to explore tables)">>
    end.
%% Render information_schema rows as one line per table, sorted by table name:
%%   - table(column type, column type, ...)
%% Within a table, columns keep the order in which their rows arrive. Any
%% missing field renders as "?".
render_schema(Rows) ->
    %% foldr + prepend leaves each table's columns in their original row order.
    ByTable = lists:foldr(
        fun(Row, Acc) ->
            Field = fun(Key) -> maps:get(Key, Row, <<"?">>) end,
            Entry = <<(Field(<<"column_name">>))/binary, " ", (Field(<<"data_type">>))/binary>>,
            maps:update_with(
                Field(<<"table_name">>), fun(Cols) -> [Entry | Cols] end, [Entry], Acc
            )
        end,
        #{},
        Rows
    ),
    join(<<"\n">>, [
        <<"- ", Table/binary, "(", (join(<<", ">>, Cols))/binary, ")">>
     || {Table, Cols} <- lists:sort(maps:to_list(ByTable))
    ]).
%%--- Prompt + tool --------------------------------------------------
%% Build the system prompt for `Mode' (`read' or `write'). Both variants are a
%% role line, a task line and a "Rules:" list, followed by the rendered
%% `Schema' under a shared heading. `read' describes a read-only connection.
%% `write' allows SELECT/INSERT/UPDATE/DELETE, but tells the model to inspect
%% first, WHERE-scope every UPDATE or DELETE, avoid DDL and stop on ambiguity.
system_prompt(Schema, Mode) ->
    Preamble =
        case Mode of
            read ->
                <<"You are a careful, read-only data analyst for a PostgreSQL database.\n"
                  "Answer the user's question by running SQL SELECT queries with the "
                  "run_sql tool, then summarising the result in plain English.\n\n"
                  "Rules:\n"
                  "- The connection is strictly read-only; only SELECT/read queries work.\n"
                  "- Run one focused query per tool call; iterate if you need more.\n"
                  "- Always use LIMIT on potentially large result sets.\n"
                  "- created_at, updated_at and expires_at are unix epoch SECONDS (bigint). "
                  "Compare against extract(epoch from now()) when you need 'currently'.\n"
                  "- If the schema below is not enough, query information_schema to explore.\n"
                  "- When confident, STOP calling tools and give a concise, direct answer.\n\n">>;
            write ->
                <<"You are a careful database operator for a PostgreSQL database.\n"
                  "Carry out the user's requested change by running SQL with the run_sql "
                  "tool, then summarise exactly what you changed in plain English.\n\n"
                  "Rules:\n"
                  "- You may run SELECT, INSERT, UPDATE and DELETE statements.\n"
                  "- Inspect first: SELECT the rows you intend to change so you can confirm "
                  "you are touching exactly what was asked for, then make the change.\n"
                  "- Always scope an UPDATE or DELETE with a precise WHERE clause (e.g. by "
                  "id). Never run an unscoped UPDATE or DELETE.\n"
                  "- Never DROP or TRUNCATE tables and never alter the schema (no DDL).\n"
                  "- Make the smallest change that satisfies the request. If it is ambiguous "
                  "or would affect more rows than expected, STOP and explain instead of "
                  "guessing.\n"
                  "- Prefer adding RETURNING to writes so you can confirm which rows changed.\n"
                  "- created_at, updated_at and expires_at are unix epoch SECONDS (bigint). "
                  "Compare against extract(epoch from now()) when you need 'currently'.\n"
                  "- When done, STOP calling tools and give a concise summary of exactly what "
                  "changed (including row counts).\n\n">>
        end,
    <<Preamble/binary, "Database schema (table(column type, ...)):\n", Schema/binary>>.
%% Build the single `run_sql' tool definition passed to `llm:chat/4'. The
%% parameter schema is identical in both modes (one required string `sql', no
%% other properties allowed). Only the descriptions differ, so the model is
%% told which statements it may run.
sql_tool(read) ->
    sql_tool_spec(
        <<"Run a single read-only PostgreSQL SELECT statement and get the "
          "rows back as JSON. Writes are rejected by the server.">>,
        <<"A single SQL SELECT statement.">>
    );
sql_tool(write) ->
    %% NOTE(review): run_one/2 forwards columns/row_count/rows only; whether a
    %% write-capable QueryFun exposes an affected_rows count as a row is up to
    %% the caller. Confirm this description matches it.
    sql_tool_spec(
        <<"Run a single PostgreSQL statement (SELECT/INSERT/UPDATE/DELETE) and "
          "get the resulting rows (or affected_rows count) back as JSON.">>,
        %% INSERT takes no WHERE clause; only UPDATE/DELETE are WHERE-scoped,
        %% which matches the write-mode system prompt's rules.
        <<"A single SQL statement (SELECT, INSERT, or a WHERE-scoped UPDATE/DELETE, "
          "optionally with RETURNING).">>
    ).

%% Wrap the tool and `sql'-argument descriptions in the tool map: name
%% `run_sql' plus a JSON-Schema object with a single required string `sql'.
sql_tool_spec(Description, SqlDescription) ->
    #{
        name => <<"run_sql">>,
        description => Description,
        parameters => #{
            <<"type">> => <<"object">>,
            <<"properties">> => #{
                <<"sql">> => #{
                    <<"type">> => <<"string">>,
                    <<"description">> => SqlDescription
                }
            },
            <<"required">> => [<<"sql">>],
            <<"additionalProperties">> => false
        }
    }.
%%--- Helpers --------------------------------------------------------
%% Concatenate binaries, placing `Sep' between consecutive elements.
%% An empty list yields an empty binary.
join(Sep, Parts) ->
    iolist_to_binary(lists:join(Sep, Parts)).
%% Normalise text to a UTF-8 binary. Binaries pass through unchanged; lists are
%% treated as Unicode chardata (code points, possibly nested with binaries).
%% Raises `{badarg, L}' for chardata that cannot be encoded (invalid or
%% incomplete). Without this check, unicode:characters_to_binary/1's
%% `{error, _, _}' / `{incomplete, _, _}' tuple would be returned as if it
%% were text.
to_bin(B) when is_binary(B) ->
    B;
to_bin(L) when is_list(L) ->
    case unicode:characters_to_binary(L) of
        Bin when is_binary(Bin) -> Bin;
        _Invalid -> error({badarg, L})
    end.