
# db_agent

A tiny natural-language → SQL agent for PostgreSQL. Give it a question and a
function that runs SQL, and it drives a short LLM tool-use loop: the model
writes SQL, your function runs it, the rows are fed back, and it repeats until
the model has a plain-English answer. The exact SQL it ran is returned
alongside the answer.

## The idea

`db_agent` owns the *logic* ("what SQL should I run to answer this?"). It does
**not** own the *plumbing* ("how do I connect, enforce read-only, format
rows?"). Those two responsibilities meet at a single function argument — the
`QueryFun`.

```
your CLI  ───names & calls──▶  db_agent:answer/3
                                    │ calls back the fun value you passed
your CLI's QueryFun  ◀──────────  QueryFun(Sql)
  (runs SQL via epgsql, returns rows)
```

The dependency arrow only points one way: **you depend on `db_agent`;
`db_agent` never names your code.** It only ever calls the function value you
hand it (exactly like `lists:map/2` calls the fun you pass it). That is what
makes it reusable across any project and any database client — you supply the
`QueryFun`, the agent stays untouched.

This is plain inversion of control, not an Erlang `behaviour`: there is no
callback module to implement, just one function value to provide. The caller is
in charge of wiring; the agent is in charge of the loop.
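To make that division of labour concrete, here is a toy version of such a loop. It is a sketch, not `db_agent`'s implementation: the module name `toy_loop` and the `Model` fun (a stand-in for the LLM that returns `{sql, Sql}` or `{answer, Text}`) are hypothetical. What it does show is the one-way dependency: `run/3` only ever applies the `QueryFun` value it is handed.

```erlang
-module(toy_loop).
-export([run/3]).

%% Toy agent loop, for illustration only (hypothetical, not db_agent's code).
%% `Model` stands in for the LLM: given the history so far it returns either
%% {sql, Sql} or {answer, Text}. `QueryFun` is the caller-supplied fun; this
%% module never names the caller's code, it only applies the fun value.
run(Model, QueryFun, MaxSteps) ->
    loop(Model, QueryFun, MaxSteps, [], []).

%% Give up once the step budget is spent.
loop(_Model, _QueryFun, 0, _History, _Queries) ->
    {error, max_steps};
loop(Model, QueryFun, Steps, History, Queries) ->
    case Model(lists:reverse(History)) of
        {answer, Text} ->
            %% Return the answer together with every SQL string that was run.
            {ok, #{answer => Text, queries => lists:reverse(Queries)}};
        {sql, Sql} ->
            Result = QueryFun(Sql),   % the callback: the caller's code runs here
            loop(Model, QueryFun, Steps - 1,
                 [{Sql, Result} | History], [Sql | Queries])
    end.
```

Because the caller decides what `QueryFun` does, the same loop runs unchanged against a live connection, a read-only replica, or a canned fake.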

Read-only vs. write is decided *entirely* by which `QueryFun` you pass:

- **`db_agent:answer/3`** — read-only analysis. Pair it with a `QueryFun` you
  have pinned read-only at the DB level; the model can only `SELECT`.
- **`db_agent:edit/3`** — write-capable changes. Pair it with a write-capable
  `QueryFun`; the model may run `INSERT`s and `WHERE`-scoped `UPDATE`/`DELETE`s, and
  reports exactly what it changed.

## The contract

```erlang
-type query_fun() :: fun((binary()) ->
    {ok, [binary()], [map()]} | {error, term()}).
```

A `QueryFun` takes one SQL string and returns either:

- `{ok, ColumnNames, RowMaps}`, where `ColumnNames` is a list of binaries and
  each row is a `#{ColumnName => JsonSafeValue}` map, or
- `{error, Reason}` — the agent shows the error to the model so it can recover.
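Because the contract is just a fun, a stand-in that never touches a database takes only a few lines. The sketch below (the `fake_query_fun` module is hypothetical, not part of `db_agent`) serves canned results keyed by the exact SQL binary and returns `{error, ...}` for anything else.

```erlang
-module(fake_query_fun).
-export([new/1]).

%% Hypothetical test helper: build a QueryFun that satisfies the contract
%% without a database. SQL is looked up verbatim in a map of canned
%% {Columns, Rows} results; unknown SQL yields {error, {unknown_query, Sql}},
%% which an agent would show to the model.
-spec new(#{binary() => {[binary()], [map()]}}) ->
          fun((binary()) -> {ok, [binary()], [map()]} | {error, term()}).
new(Canned) ->
    fun(Sql) ->
        case maps:find(Sql, Canned) of
            {ok, {Columns, Rows}} -> {ok, Columns, Rows};
            error                 -> {error, {unknown_query, Sql}}
        end
    end.
```

Exact string matching is deliberately naive: it suits tests of your own code around a `QueryFun` better than open-ended agent runs, since a live model may phrase its SQL differently each time.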

## Modules

| Module       | Role                                                           |
|--------------|----------------------------------------------------------------|
| `db_agent`   | the tool-use loop and prompts (`answer/2,3`, `edit/2,3`)       |

The LLM client (`llm`) and JSON codec (`json_util`) live in the separate
[`erlangchain`](https://github.com/abhavk/erlangchain) library, which `db_agent`
depends on. `erlangchain` has no third-party dependencies — only OTP `inets` +
`ssl` for HTTP.

## Install

```erlang
%% rebar.config — db_agent pulls in erlangchain transitively
{deps, [
    {db_agent, {git, "https://github.com/abhavk/db_agent.git", {tag, "0.1.0"}}}
]}.
```

Set one of `OPENAI_API_KEY` (default provider) or `ANTHROPIC_API_KEY` in the
environment. A `.env` file in the working directory is also loaded
automatically if present.
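If only `ANTHROPIC_API_KEY` is set, you will presumably want to pass `provider => anthropic` in `Opts` (see the options table below). One way to do that is to derive `Opts` from the environment. A minimal sketch, assuming the default provider is an acceptable fallback; `provider_opts` is a hypothetical helper, and it only sees variables already in the OS environment, not ones that exist only in `.env`.

```erlang
-module(provider_opts).
-export([from_env/0]).

%% Hypothetical helper: choose the `provider` option from whichever API key
%% is set in the OS environment. When the OpenAI key is set (alone or with
%% the Anthropic one), or neither is set, return #{} and keep the default
%% provider. Keys that live only in a .env file are not visible here.
-spec from_env() -> #{provider => openai | anthropic}.
from_env() ->
    case {os:getenv("OPENAI_API_KEY"), os:getenv("ANTHROPIC_API_KEY")} of
        {false, Anthropic} when Anthropic =/= false -> #{provider => anthropic};
        _                                          -> #{}
    end.
```

Usage: `db_agent:answer(QueryFun, Question, provider_opts:from_env())`.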

## Quick start

```erlang
QueryFun = fun(Sql) -> my_run_readonly(Conn, Sql) end,

{ok, #{answer := Answer, queries := Queries}} =
    db_agent:answer(QueryFun, <<"how many users signed up today?">>, #{}).
```

`edit/3` is identical except you pass a write-capable `QueryFun`:

```erlang
WriteFun = fun(Sql) -> my_run_write(Conn, Sql) end,

{ok, #{answer := Summary, queries := Queries}} =
    db_agent:edit(WriteFun, <<"delete the frame_listener with id 8fd97600...">>, #{}).
```

### Options (`Opts`)

| Key        | Meaning                                                      |
|------------|-------------------------------------------------------------|
| `provider` | `openai` (default) or `anthropic`                           |
| `size`     | `big` (default) or `small`                                  |
| `on_event` | `fun((Event) -> ok)` live progress sink (see below)         |

`on_event` receives `{thinking, Bin}`, `{query, Sql}`,
`{result, {ok, RowCount}}`, and `{result, {error, Bin}}`.
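A sink does not have to print. The sketch below (a hypothetical `event_log` module) forwards each event to a chosen process as a `{db_agent_event, Event}` message, which is handy in tests. Which process `db_agent` calls the sink from is not stated here, so the sink sends to an explicit pid captured up front rather than relying on `self()` at call time.

```erlang
-module(event_log).
-export([sink/1, drain/0]).

%% Hypothetical on_event sink: forward every progress event to Pid as a
%% {db_agent_event, Event} message and return ok, as the sink type requires.
-spec sink(pid()) -> fun((term()) -> ok).
sink(Pid) ->
    fun(Event) -> Pid ! {db_agent_event, Event}, ok end.

%% Collect the events already in the calling process's mailbox, oldest
%% first, without blocking once the mailbox holds no more of them.
-spec drain() -> [term()].
drain() ->
    receive
        {db_agent_event, Event} -> [Event | drain()]
    after 0 ->
        []
    end.
```

For example, pass `#{on_event => event_log:sink(self())}` as `Opts`, then call `event_log:drain()` once `db_agent:answer/3` returns.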

## Reference implementation (an `ask` / `edit` CLI over Postgres)

This is the glue layer — it lives in *your* project, not the library. It builds
the read-only and write-capable `QueryFun`s around `epgsql` and wires up a
`./db ask "..."` / `./db edit "..."` command. It works as-is against any
PostgreSQL database; connection settings come from environment variables.
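The escript also needs `epgsql` on its code path. `db_agent` never talks to the database itself, so it presumably does not bring `epgsql` along. A sketch of the hosting project's `rebar.config`, assuming the latest Hex release of `epgsql` is acceptable (pin a version if you need reproducible builds):

```erlang
%% rebar.config for the project that hosts db.erl (a sketch).
%% The bare `epgsql` atom asks rebar3 for the latest release on Hex.
{deps, [
    {db_agent, {git, "https://github.com/abhavk/db_agent.git", {tag, "0.1.0"}}},
    epgsql
]}.
```

After `rebar3 compile`, each dependency's `ebin` directory lands under `_build/default/lib/`, which is where `db.erl` looks for code.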

`db.erl` (an escript):

```erlang
#!/usr/bin/env escript
-module(db).
-export([main/1, ask/1, edit/1]).

main(Args) ->
    add_escript_code_paths(),
    halt(run([normalize_arg(A) || A <- Args])).

run(["ask"  | Q]) when Q =/= [] -> ask(join_text(Q));
run(["edit" | I]) when I =/= [] -> edit(join_text(I));
run(_)                          -> usage(), 2.

%% Read-only: QueryFun is pinned read-only, so the agent can only SELECT.
ask(Question) ->
    with_conn(fun(Conn) ->
        QueryFun = fun(Sql) -> run_readonly(Conn, Sql) end,
        report(db_agent:answer(QueryFun, Question, #{on_event => fun report_event/1}))
    end).

%% Write-capable: QueryFun lets writes through, so edits take effect.
edit(Instruction) ->
    with_conn(fun(Conn) ->
        QueryFun = fun(Sql) -> run_write(Conn, Sql) end,
        report(db_agent:edit(QueryFun, Instruction, #{on_event => fun report_event/1}))
    end).

%% --- the QueryFun implementations -------------------------------------

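%% Pin the session read-only at the server so writes are rejected.
%% Caveat: squery uses PostgreSQL's simple-query protocol, which accepts
%% several ;-separated statements in one string, so crafted SQL could turn
%% the setting off and COMMIT before writing. For a hard guarantee, connect
%% with a role that holds only SELECT privileges.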
run_readonly(Conn, Sql) ->
    _ = epgsql:squery(Conn, "SET default_transaction_read_only = on"),
    _ = epgsql:squery(Conn, "SET statement_timeout = '15000'"),
    interpret_result(epgsql:squery(Conn, Sql)).

%% Write-capable counterpart: timeout still guards, session stays read-write.
run_write(Conn, Sql) ->
    _ = epgsql:squery(Conn, "SET statement_timeout = '15000'"),
    interpret_result(epgsql:squery(Conn, Sql)).

%% Normalize epgsql results into the {ok, Cols, RowMaps} the agent expects.
%% A non-row command (INSERT/UPDATE/DELETE without RETURNING) reports its
%% affected-row count as a single `affected_rows` row.
interpret_result({ok, Columns, Rows}) ->
    Names = [column_name(C) || C <- Columns],
    {ok, Names, [row_object(Names, R) || R <- Rows]};
interpret_result({ok, Count}) when is_integer(Count) ->
    {ok, [<<"affected_rows">>], [#{<<"affected_rows">> => Count}]};
interpret_result({ok, _Count, Columns, Rows}) ->
    Names = [column_name(C) || C <- Columns],
    {ok, Names, [row_object(Names, R) || R <- Rows]};
interpret_result({error, _} = Error) ->
    Error;
interpret_result(Results) when is_list(Results) ->
    interpret_result(pick_result(Results)).

pick_result(Results) ->
    case [R || R <- Results, is_row_result(R)] of
        []   -> lists:last(Results);
        Rows -> lists:last(Rows)
    end.

is_row_result({ok, _, _})    -> true;
is_row_result({ok, _, _, _}) -> true;
is_row_result(_)             -> false.

%% epgsql columns are #column{} records; read the name positionally to avoid
%% pulling in the epgsql header.
column_name(C) when is_tuple(C) -> element(2, C);
column_name(C)                  -> C.

row_object(Names, Row) when is_tuple(Row) ->
    maps:from_list(lists:zip(Names, [json_value(V) || V <- tuple_to_list(Row)])).

%% Coerce raw epgsql values into JSON-encodable terms.
json_value(null)                 -> null;
json_value(V) when is_binary(V)  -> V;
json_value(V) when is_integer(V) -> V;
json_value(V) when is_float(V)   -> V;
json_value(true)                 -> true;
json_value(false)                -> false;
json_value(V) -> iolist_to_binary(io_lib:format("~p", [V])).

%% --- connection + output ----------------------------------------------

with_conn(Fun) ->
    {ok, Conn} = epgsql:connect(
        os:getenv("DB_HOST", "localhost"),
        os:getenv("DB_USER", "postgres"),
        os:getenv("DB_PASSWORD", ""),
        [{database, os:getenv("DB_NAME", "postgres")},
         {port, list_to_integer(os:getenv("DB_PORT", "5432"))}]),
    try Fun(Conn) after epgsql:close(Conn) end.

report({ok, #{answer := Answer, queries := Queries}}) ->
    io:format("~ts~n", [Answer]),
    case Queries of
        [] -> ok;
        _  ->
            io:format("~n--- SQL run (~b) ---~n", [length(Queries)]),
            lists:foreach(fun(Q) -> io:format("~ts~n", [Q]) end, Queries)
    end,
    0;
report({error, Reason}) ->
    io:format(standard_error, "error: ~p~n", [Reason]),
    1.

%% Live progress to stderr, keeping stdout clean for the answer + SQL.
report_event({thinking, T})        -> io:format(standard_error, "~n[thinking] ~ts~n", [T]);
report_event({query, Sql})         -> io:format(standard_error, "[query] ~ts~n", [Sql]);
report_event({result, {ok, N}})    -> io:format(standard_error, "[result] ~b row(s)~n", [N]);
report_event({result, {error, R}}) -> io:format(standard_error, "[result] error: ~ts~n", [R]).

usage() ->
    io:format(standard_error,
        "Usage:~n  ./db ask  <question...>~n  ./db edit <instruction...>~n", []).

join_text(Parts)        -> unicode:characters_to_binary(lists:join(" ", Parts)).
normalize_arg(A) when is_atom(A) -> atom_to_list(A);
normalize_arg(A)                 -> A.

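%% Make the rebar3-built deps (db_agent, erlangchain, epgsql) loadable,
%% resolving _build relative to this script, not the current directory.
add_escript_code_paths() ->
    Dir = filename:dirname(filename:absname(escript:script_name())),
    Pattern = filename:join([Dir, "_build", "default", "lib", "*", "ebin"]),
    code:add_pathsa(filelib:wildcard(Pattern)),
    ok.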
```

The `./db` wrapper:

```bash
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
exec escript "$SCRIPT_DIR/db.erl" "$@"
```

Usage:

```bash
./db ask  "what frame listeners are currently active?"
./db edit "delete the frame_listener with id 8fd976005666a1975a6b6f3f26239677"
```

## License

MIT — see [LICENSE](LICENSE).