Skip to main content

src/db_agent.erl

%%% @doc Tiny natural-language -> SQL agent over a read-only database.
%%%
%%% Given a question and a read-only query function, this drives a short
%%% tool-use loop with the LLM (`llm.erl'): the model writes SQL, we run it
%%% through `QueryFun' (which the caller pins read-only), feed the rows back,
%%% and repeat until the model stops calling the tool and gives a plain-English
%%% answer. The accumulated SQL is returned alongside the answer so the caller
%%% can show exactly what was run.
%%%
%%% The agent never touches the database directly — it only ever calls the
%%% caller-supplied `QueryFun', so it inherits whatever read-only guarantees the
%%% caller enforces (see `cli:read_query/1').
-module(db_agent).

-export([answer/2, answer/3, edit/2, edit/3]).

%% A query runner: takes a SQL string, returns column names + JSON-safe row
%% maps, or an error the model can read and recover from.
-type query_fun() :: fun((binary()) ->
    {ok, [binary()], [map()]} | {error, term()}).

-export_type([query_fun/0]).

-define(PROVIDER, openai).
-define(SIZE, big).
%% Safety rail on the tool-use loop so a confused model can't run forever.
-define(MAX_STEPS, 8).
%% Cap rows fed back to the model per query to keep the context bounded; the
%% true row_count is always reported so it knows when results were trimmed.
-define(MAX_ROWS_TO_MODEL, 100).

%% Read-only entry point with default options.
%%
%% Runs the agent in `read' mode, where the prompt and tool description tell
%% the model that only SELECT statements work. The agent itself does not
%% enforce this: pass a `QueryFun' that the caller has pinned read-only (see
%% `db:run_readonly/2').
%%
%% Equivalent to `answer(QueryFun, Question, #{})'.
-spec answer(query_fun(), binary() | list()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
answer(QueryFun, Question) ->
    run(read, QueryFun, Question, #{}).

%% Read-only entry point with explicit options.
%%
%% `Opts' is forwarded unchanged to the shared driver; the keys it reads are
%% `provider', `size' and `on_event'.
-spec answer(query_fun(), binary() | list(), map()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
answer(QueryFun, Question, Opts) ->
    Mode = read,
    run(Mode, QueryFun, Question, Opts).

%% Write entry point with default options.
%%
%% Runs the agent in `write' mode: the prompt allows INSERT/UPDATE/DELETE as
%% well as SELECT so the model can carry out a requested change. It is told to
%% inspect first, scope every write with a WHERE clause, and stop rather than
%% guess when the request is ambiguous. Pair with a write-capable `QueryFun'
%% (see `db:run_write/2').
%%
%% Equivalent to `edit(QueryFun, Instruction, #{})'.
-spec edit(query_fun(), binary() | list()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
edit(QueryFun, Instruction) ->
    run(write, QueryFun, Instruction, #{}).

%% Write entry point with explicit options.
%%
%% `Opts' is forwarded unchanged to the shared driver; the keys it reads are
%% `provider', `size' and `on_event'.
-spec edit(query_fun(), binary() | list(), map()) ->
    {ok, #{answer := binary(), queries := [binary()]}} | {error, term()}.
edit(QueryFun, Instruction, Opts) ->
    Mode = write,
    run(Mode, QueryFun, Instruction, Opts).

%% Shared driver behind answer/2,3 and edit/2,3.
%%
%% Normalises the question to a binary, resolves options, grounds the model
%% with a schema snapshot fetched through `QueryFun', and starts the tool-use
%% loop with a system + user message pair.
%%
%% Options (all optional):
%%   provider - LLM provider, defaults to ?PROVIDER
%%   size     - model size, defaults to ?SIZE
%%   on_event - live progress sink, defaults to a no-op. Receives:
%%                {thinking, binary()}               - the model's interim reasoning
%%                {query, binary()}                  - a SQL statement about to run
%%                {result, {ok, non_neg_integer()}}  - row count returned
%%                {result, {error, binary()}}        - query error the model will see
run(Mode, QueryFun, Question0, Opts) ->
    UserText = to_bin(Question0),
    Defaults = #{
        provider => ?PROVIDER,
        size => ?SIZE,
        on_event => fun(_) -> ok end
    },
    %% Caller-supplied keys win over the defaults.
    #{provider := Provider, size := Size, on_event := Report} =
        maps:merge(Defaults, Opts),
    SchemaText = fetch_schema(QueryFun),
    Tool = sql_tool(Mode),
    SystemMsg = #{role => system, content => system_prompt(SchemaText, Mode)},
    UserMsg = #{role => user, content => UserText},
    loop(QueryFun, Report, Provider, Size, Tool, [SystemMsg, UserMsg], [], ?MAX_STEPS).

%%--- Agent loop -----------------------------------------------------

%% Tool-use loop.
%%
%% Each iteration sends the conversation to the LLM. A response with no tool
%% calls is the final answer. Otherwise every requested call is executed, the
%% assistant turn and its tool results are appended to the conversation, and
%% the loop continues with one fewer step. `Queries' accumulates executed SQL
%% newest-first and is reversed on the way out. An `{error, _}' from the LLM
%% is returned unchanged.
%%
%% When the step budget reaches zero the LLM is not called again; a canned
%% answer is returned together with the SQL run so far.
loop(_QueryFun, _Report, _Provider, _Size, _Tool, _Messages, Queries, 0) ->
    OutOfSteps =
        <<"I hit the maximum number of query steps before reaching a "
            "confident answer. The SQL I ran is listed below.">>,
    {ok, #{answer => OutOfSteps, queries => lists:reverse(Queries)}};
loop(QueryFun, Report, Provider, Size, Tool, Messages, Queries, Steps) ->
    case llm:chat(Provider, Size, Messages, [Tool]) of
        {error, _Reason} = Failure ->
            Failure;
        {ok, #{tool_calls := []} = Final} ->
            FinalText = maps:get(content, Final, <<>>),
            {ok, #{answer => FinalText, queries => lists:reverse(Queries)}};
        {ok, #{tool_calls := Calls} = Turn} ->
            Thoughts = maps:get(content, Turn, <<>>),
            maybe_report_thinking(Report, Thoughts),
            AssistantMsg = #{role => assistant, content => Thoughts, tool_calls => Calls},
            {ToolMsgs, QueriesSoFar} = run_calls(QueryFun, Report, Calls, Queries),
            Conversation = Messages ++ [AssistantMsg | ToolMsgs],
            loop(QueryFun, Report, Provider, Size, Tool, Conversation, QueriesSoFar, Steps - 1)
    end.

%% Forward the model's interim text to the progress sink as
%% `{thinking, Content}', unless it is the empty binary, in which case nothing
%% is reported and `ok' is returned.
maybe_report_thinking(Report, Content) ->
    case Content of
        <<>> -> ok;
        _NonEmpty -> Report({thinking, Content})
    end.

%% Execute every tool call of one assistant turn, in order.
%%
%% Returns `{ResultMsgs, Queries1}': one `tool_result' message per call (same
%% order as `Calls', each tagged with the call's `id') and the query
%% accumulator extended with the SQL that was run (newest first).
%%
%% Uses lists:mapfoldl/3 so the result messages are built in a single pass
%% rather than by appending to the accumulator on every iteration.
%% A call without an `id' key crashes with `{badkey, id}'.
run_calls(QueryFun, Report, Calls, Queries) ->
    lists:mapfoldl(
        fun(Call, QAcc) ->
            {Content, QAcc1} = handle_call(QueryFun, Report, Call, QAcc),
            ResultMsg = #{
                role => tool_result,
                tool_use_id => maps:get(id, Call),
                content => Content
            },
            {ResultMsg, QAcc1}
        end,
        Queries,
        Calls
    ).

%% Execute one tool call (internal helper; not a gen_server callback).
%%
%% Returns `{ContentForModel, Queries1}'. When the call carries no usable
%% `sql' argument, a JSON error payload is returned, nothing is run or
%% reported, and `Queries' is unchanged. Otherwise the statement is announced
%% to the progress sink, executed, its summary reported, and the SQL is
%% prepended to `Queries'.
handle_call(QueryFun, Report, Call, Queries) ->
    execute_sql(call_sql(Call), QueryFun, Report, Queries).

%% Clause order matters: the `undefined' marker from call_sql/1 must be
%% matched before the general "run this SQL" clause.
execute_sql(undefined, _QueryFun, _Report, Queries) ->
    Complaint = json_util:encode(#{<<"error">> => <<"missing required 'sql' argument">>}),
    {Complaint, Queries};
execute_sql(Sql, QueryFun, Report, Queries) ->
    Report({query, Sql}),
    {Content, Summary} = run_one(QueryFun, Sql),
    Report({result, Summary}),
    {Content, [Sql | Queries]}.

%% Extract the SQL text from a tool call.
%%
%% Looks under the call's `input' key (treated as an empty map when absent)
%% for a binary `<<"sql">>' value. Returns that binary, or `undefined' when
%% the input has no such key or the value is not a binary.
call_sql(Call) ->
    sql_argument(maps:get(input, Call, #{})).

sql_argument(#{<<"sql">> := Sql}) when is_binary(Sql) -> Sql;
sql_argument(_Other) -> undefined.

%% Run one SQL statement through `QueryFun'.
%%
%% Returns `{ContentForModel, SummaryForReport}':
%%   - on `{ok, Columns, Rows}' the content is a JSON object with `columns',
%%     the true `row_count', and at most ?MAX_ROWS_TO_MODEL `rows' (plus
%%     `truncated: true' when rows were dropped); the summary is
%%     `{ok, RowCount}'.
%%   - on `{error, Reason}' the content is a JSON object with the printed
%%     reason under `error'; the summary is `{error, PrintedReason}'.
run_one(QueryFun, Sql) ->
    case QueryFun(Sql) of
        {ok, Columns, Rows} ->
            %% Count once; the same figure goes to the model and the reporter.
            RowCount = length(Rows),
            {Shown, Truncated} = cap_rows(Rows, ?MAX_ROWS_TO_MODEL),
            Base = #{
                <<"columns">> => Columns,
                <<"row_count">> => RowCount,
                <<"rows">> => Shown
            },
            Payload =
                case Truncated of
                    true -> Base#{<<"truncated">> => true};
                    false -> Base
                end,
            {json_util:encode(Payload), {ok, RowCount}};
        {error, Reason} ->
            %% `~tp' (not `~p'): with plain `~p' a UTF-8 binary inside Reason
            %% is printed as Latin-1 characters, which to_bin/1 then
            %% re-encodes as UTF-8, garbling non-ASCII error text.
            Bin = to_bin(io_lib:format("~tp", [Reason])),
            {json_util:encode(#{<<"error">> => Bin}), {error, Bin}}
    end.

%% Limit a row list to at most `Max' entries.
%%
%% Returns `{Rows1, Truncated}' where `Truncated' is `true' only when rows
%% were actually dropped. The length test is a guard, so an input for which
%% `length/1' fails falls through to the untouched branch.
cap_rows(Rows, Max) ->
    if
        length(Rows) > Max -> {lists:sublist(Rows, Max), true};
        true -> {Rows, false}
    end.

%%--- Schema introspection -------------------------------------------

%% Build the schema text that is embedded in the system prompt.
%%
%% Queries `information_schema.columns' for the `public' schema through
%% `QueryFun', so tables added later are picked up without code changes. Any
%% outcome other than a non-empty successful result (an error, an empty row
%% list, an unexpected shape) degrades to a hint telling the model to explore
%% the schema itself.
fetch_schema(QueryFun) ->
    Introspection =
        <<"SELECT table_name, column_name, data_type "
            "FROM information_schema.columns "
            "WHERE table_schema = 'public' "
            "ORDER BY table_name, ordinal_position">>,
    schema_or_hint(QueryFun(Introspection)).

schema_or_hint({ok, _Columns, Rows}) when Rows =/= [] ->
    render_schema(Rows);
schema_or_hint(_Unavailable) ->
    <<"(schema unavailable; use information_schema.columns to explore tables)">>.

%% Render introspection rows as one line per table:
%%
%%   - table(col type, col type, ...)
%%
%% Rows are maps keyed by `<<"table_name">>', `<<"column_name">>' and
%% `<<"data_type">>'; a missing key is shown as `?'. Tables are sorted by name
%% and columns keep the order in which the rows arrived.
render_schema(Rows) ->
    %% Folding from the right while prepending leaves each table's columns in
    %% their original row order, so no reverse is needed afterwards.
    ByTable = lists:foldr(fun add_column/2, #{}, Rows),
    Lines = [table_line(Table, Cols) || {Table, Cols} <- lists:sort(maps:to_list(ByTable))],
    join(<<"\n">>, Lines).

add_column(Row, ByTable) ->
    Table = maps:get(<<"table_name">>, Row, <<"?">>),
    Col = maps:get(<<"column_name">>, Row, <<"?">>),
    Type = maps:get(<<"data_type">>, Row, <<"?">>),
    Entry = <<Col/binary, " ", Type/binary>>,
    maps:update_with(Table, fun(Cols) -> [Entry | Cols] end, [Entry], ByTable).

table_line(Table, Cols) ->
    ColumnList = join(<<", ">>, Cols),
    <<"- ", Table/binary, "(", ColumnList/binary, ")">>.

%%--- Prompt + tool --------------------------------------------------

%% Build the system prompt for the given mode (`read' or `write').
%%
%% The prompt is the mode-specific role description and rules, followed by a
%% shared schema header and the rendered `Schema' binary.
system_prompt(Schema, Mode) when Mode =:= read; Mode =:= write ->
    Preamble = prompt_preamble(Mode),
    <<Preamble/binary, "Database schema (table(column type, ...)):\n", Schema/binary>>.

%% Role description and rules for each mode, ending in a blank line.
prompt_preamble(read) ->
    <<"You are a careful, read-only data analyst for a PostgreSQL database.\n"
        "Answer the user's question by running SQL SELECT queries with the "
        "run_sql tool, then summarising the result in plain English.\n\n"
        "Rules:\n"
        "- The connection is strictly read-only; only SELECT/read queries work.\n"
        "- Run one focused query per tool call; iterate if you need more.\n"
        "- Always use LIMIT on potentially large result sets.\n"
        "- created_at, updated_at and expires_at are unix epoch SECONDS (bigint). "
        "Compare against extract(epoch from now()) when you need 'currently'.\n"
        "- If the schema below is not enough, query information_schema to explore.\n"
        "- When confident, STOP calling tools and give a concise, direct answer.\n\n">>;
prompt_preamble(write) ->
    <<"You are a careful database operator for a PostgreSQL database.\n"
        "Carry out the user's requested change by running SQL with the run_sql "
        "tool, then summarise exactly what you changed in plain English.\n\n"
        "Rules:\n"
        "- You may run SELECT, INSERT, UPDATE and DELETE statements.\n"
        "- Inspect first: SELECT the rows you intend to change so you can confirm "
        "you are touching exactly what was asked for, then make the change.\n"
        "- Always scope an UPDATE or DELETE with a precise WHERE clause (e.g. by "
        "id). Never run an unscoped UPDATE or DELETE.\n"
        "- Never DROP or TRUNCATE tables and never alter the schema (no DDL).\n"
        "- Make the smallest change that satisfies the request. If it is ambiguous "
        "or would affect more rows than expected, STOP and explain instead of "
        "guessing.\n"
        "- Prefer adding RETURNING to writes so you can confirm which rows changed.\n"
        "- created_at, updated_at and expires_at are unix epoch SECONDS (bigint). "
        "Compare against extract(epoch from now()) when you need 'currently'.\n"
        "- When done, STOP calling tools and give a concise summary of exactly what "
        "changed (including row counts).\n\n">>.

%% Tool definition offered to the model for the given mode. Both modes expose
%% the same `run_sql' tool; only the tool description and the description of
%% its `sql' argument differ.
sql_tool(read) ->
    ToolDoc =
        <<"Run a single read-only PostgreSQL SELECT statement and get the "
            "rows back as JSON. Writes are rejected by the server.">>,
    ArgDoc = <<"A single SQL SELECT statement.">>,
    sql_tool_spec(ToolDoc, ArgDoc);
sql_tool(write) ->
    ToolDoc =
        <<"Run a single PostgreSQL statement (SELECT/INSERT/UPDATE/DELETE) and "
            "get the resulting rows (or affected_rows count) back as JSON.">>,
    ArgDoc =
        <<"A single SQL statement (SELECT, or a WHERE-scoped INSERT/UPDATE/DELETE, "
            "optionally with RETURNING).">>,
    sql_tool_spec(ToolDoc, ArgDoc).

%% Assemble the `run_sql' tool map: a name, a description, and a JSON-Schema
%% style `parameters' object with a single required string property `sql' and
%% no additional properties allowed.
sql_tool_spec(Description, SqlDescription) ->
    SqlProperty = #{
        <<"type">> => <<"string">>,
        <<"description">> => SqlDescription
    },
    Parameters = #{
        <<"type">> => <<"object">>,
        <<"properties">> => #{<<"sql">> => SqlProperty},
        <<"required">> => [<<"sql">>],
        <<"additionalProperties">> => false
    },
    #{
        name => <<"run_sql">>,
        description => Description,
        parameters => Parameters
    }.

%%--- Helpers --------------------------------------------------------

%% Concatenate a list of binaries with `Sep' between consecutive elements.
%% An empty list yields `<<>>'; a single-element list yields that element
%% unchanged.
join(_Sep, []) ->
    <<>>;
join(Sep, [First | Rest]) ->
    join(Sep, Rest, First).

%% Tail-recursive worker: `Joined' is the text assembled so far.
join(_Sep, [], Joined) ->
    Joined;
join(Sep, [Next | Rest], Joined) ->
    join(Sep, Rest, <<Joined/binary, Sep/binary, Next/binary>>).

%% Normalise text to a UTF-8 binary.
%%
%% Binaries are returned unchanged. Lists are treated as Unicode chardata and
%% encoded as UTF-8.
%%
%% Raises `error({badarg, Input})' when the list is not valid chardata:
%% unicode:characters_to_binary/1 reports that case by returning an
%% `{error, _, _}' or `{incomplete, _, _}' tuple, which must not be passed on
%% as if it were a binary.
to_bin(B) when is_binary(B) -> B;
to_bin(L) when is_list(L) ->
    case unicode:characters_to_binary(L) of
        Bin when is_binary(Bin) -> Bin;
        _Invalid -> error({badarg, L})
    end.