%% src/adk_llm_gemini.erl

-module(adk_llm_gemini).
-behaviour(adk_llm).

-export([generate/3]).

%% Sends the conversation to the Gemini `generateContent' endpoint.
%%
%% Config - map of options. `model' selects the model (defaults to
%%          <<"gemini-3.5-flash">>), `api_key' supplies the key (otherwise
%%          the GEMINI_API_KEY environment variable is consulted), and
%%          temperature / max_tokens / top_p / top_k are forwarded as the
%%          generation config.
%% Memory - message history, converted by build_contents/3.
%% Tools  - list of modules; schema/0 is called on each one and the results
%%          are sent as the function declarations.
%%
%% Returns the result of parse_response/1 on HTTP 200,
%% {error, {StatusCode, ReasonPhrase, Body}} for any other status, or
%% {error, Reason} when the request itself fails. Throws
%% {error, missing_api_key} (from get_api_key/1) when no key is available.
generate(Config, Memory, Tools) ->
    %% Default to gemini-3.5-flash if model not provided
    Model = maps:get(model, Config, <<"gemini-3.5-flash">>),
    ApiKey = get_api_key(Config),

    %% to_binary/1 lets callers pass the model / key as a binary, a string
    %% or an atom; a plain binary_to_list/1 would crash on anything but a
    %% binary.
    Url =
        "https://generativelanguage.googleapis.com/v1beta/models/" ++
            binary_to_list(to_binary(Model)) ++ ":generateContent?key=" ++
            binary_to_list(to_binary(ApiKey)),

    %% Build the payload
    {SystemInstruction, Contents} = build_contents(Memory, <<>>, []),

    Payload0 = #{<<"contents">> => Contents},
    Payload1 =
        case SystemInstruction of
            <<>> ->
                Payload0;
            Sys ->
                %% `parts' is a list here, like every other `parts' value
                %% this module produces.
                Payload0#{
                    <<"system_instruction">> => #{<<"parts">> => [#{<<"text">> => Sys}]}
                }
        end,

    %% Add generation config only when at least one option was supplied
    GenConfig = build_gen_config(Config),
    Payload2 =
        case maps:size(GenConfig) of
            0 -> Payload1;
            _ -> Payload1#{<<"generationConfig">> => GenConfig}
        end,

    %% Add tools if present
    Payload3 =
        case Tools of
            [] ->
                Payload2;
            _ ->
                Declarations = [Mod:schema() || Mod <- Tools],
                Payload2#{<<"tools">> => [#{<<"function_declarations">> => Declarations}]}
        end,

    JsonBody = jsx:encode(Payload3),

    %% HTTP Request
    Headers = [{"Content-Type", "application/json"}],
    Request = {Url, Headers, "application/json", JsonBody},
    %% Verify the server certificate against the OS trust store. The request
    %% carries the API key, so the peer must be authenticated; the https
    %% match fun is needed for wildcard certificate names.
    SslOptions = [
        {verify, verify_peer},
        {cacerts, public_key:cacerts_get()},
        {customize_hostname_check, [
            {match_fun, public_key:pkix_verify_hostname_match_fun(https)}
        ]}
    ],
    HttpOptions = [{ssl, SslOptions}],
    Options = [{body_format, binary}],

    case httpc:request(post, Request, HttpOptions, Options) of
        {ok, {{_Version, 200, _ReasonPhrase}, _Headers, Body}} ->
            ResponseMap = jsx:decode(Body, [return_maps]),
            parse_response(ResponseMap);
        {ok, {{_Version, StatusCode, ReasonPhrase}, _Headers, Body}} ->
            {error, {StatusCode, ReasonPhrase, Body}};
        {error, Reason} ->
            {error, Reason}
    end.

%% Resolves the Gemini API key.
%%
%% An `api_key' entry in Config wins and is returned untouched. Otherwise the
%% GEMINI_API_KEY environment variable is read and returned as a binary.
%% Throws {error, missing_api_key} when neither source provides a key.
get_api_key(Config) ->
    case maps:is_key(api_key, Config) of
        true -> maps:get(api_key, Config);
        false -> api_key_from_env()
    end.

%% Reads GEMINI_API_KEY from the environment, throwing when it is unset.
api_key_from_env() ->
    case os:getenv("GEMINI_API_KEY") of
        false -> throw({error, missing_api_key});
        Value -> list_to_binary(Value)
    end.

%% Converts the message history into the pieces of a Gemini request.
%%
%% Walks the history once and returns {SystemInstruction, Contents}:
%%   * `system' messages are joined with newlines into SystemInstruction
%%     (<<>> when there are none);
%%   * consecutive `tool' messages are grouped into a single "user" entry
%%     of functionResponse parts;
%%   * `agent' messages holding {tool_calls, Calls} become a "model" entry
%%     of functionCall parts;
%%   * remaining `user' / `agent' messages become plain text entries.
%%
%% SysAcc and ContentsAcc are accumulators; callers start with <<>> and [].
%% Raises error({invalid_tool_message, Msg}) for a `tool' message whose
%% content is not a tool_response tuple.
build_contents([], SysAcc, ContentsAcc) ->
    {SysAcc, lists:reverse(ContentsAcc)};
build_contents([#{role := system, content := Content} | Rest], SysAcc, ContentsAcc) ->
    ContentBin = to_binary(Content),
    NewSys =
        case SysAcc of
            <<>> -> ContentBin;
            _ -> <<SysAcc/binary, "\n", ContentBin/binary>>
        end,
    build_contents(Rest, NewSys, ContentsAcc);
build_contents([#{role := tool} = Head | _] = History, SysAcc, ContentsAcc) ->
    %% Group consecutive tool responses into a single 'user' message
    case consume_tool_responses(History, []) of
        {[], _Unconsumed} ->
            %% Nothing was consumed, so recursing would hand back the very
            %% same list and loop forever. Fail loudly instead.
            error({invalid_tool_message, Head});
        {ToolParts, Rest} ->
            Msg = #{<<"role">> => <<"user">>, <<"parts">> => ToolParts},
            build_contents(Rest, SysAcc, [Msg | ContentsAcc])
    end;
build_contents([#{role := agent, content := {tool_calls, Calls}} | Rest], SysAcc, ContentsAcc) ->
    Parts = [tool_call_part(Call) || Call <- Calls],
    Msg = #{<<"role">> => <<"model">>, <<"parts">> => Parts},
    build_contents(Rest, SysAcc, [Msg | ContentsAcc]);
build_contents([#{role := Role, content := Content} | Rest], SysAcc, ContentsAcc) ->
    GeminiRole =
        case Role of
            user -> <<"user">>;
            agent -> <<"model">>
        end,
    Part = #{<<"text">> => to_binary(Content)},
    Msg = #{<<"role">> => GeminiRole, <<"parts">> => [Part]},
    build_contents(Rest, SysAcc, [Msg | ContentsAcc]).

%% Builds one functionCall part from {Name, Args} or {Name, Args, Sig}.
%% The signature is attached under "thoughtSignature", the same key
%% parse_response/1 reads and consume_tool_responses/2 writes.
tool_call_part({Name, Args}) ->
    tool_call_part({Name, Args, undefined});
tool_call_part({Name, Args, undefined}) ->
    #{<<"functionCall">> => #{<<"name">> => to_binary(Name), <<"args">> => Args}};
tool_call_part({Name, Args, Sig}) ->
    Part = tool_call_part({Name, Args, undefined}),
    Part#{<<"thoughtSignature">> => Sig}.

%% Takes the leading run of `tool' messages off a history list.
%%
%% Each message whose content is {tool_response, Name, Response} or
%% {tool_response, Name, Response, Sig} becomes a functionResponse part.
%% Returns {Parts, Remaining}: Parts in original order, Remaining being the
%% history from the first message that is not such a tool response.
consume_tool_responses(
    [#{role := tool, content := {tool_response, Name, Response}} | Tail], Parts
) ->
    Part = function_response_part(Name, Response, undefined),
    consume_tool_responses(Tail, [Part | Parts]);
consume_tool_responses(
    [#{role := tool, content := {tool_response, Name, Response, Sig}} | Tail], Parts
) ->
    Part = function_response_part(Name, Response, Sig),
    consume_tool_responses(Tail, [Part | Parts]);
consume_tool_responses(Remaining, Parts) ->
    {lists:reverse(Parts), Remaining}.

%% Builds a single functionResponse part; a signature other than
%% `undefined' is attached under "thoughtSignature".
function_response_part(Name, Response, undefined) ->
    #{
        <<"functionResponse">> => #{
            <<"name">> => to_binary(Name),
            <<"response">> => Response
        }
    };
function_response_part(Name, Response, Sig) ->
    Base = function_response_part(Name, Response, undefined),
    Base#{<<"thoughtSignature">> => Sig}.

%% Translates the sampling options present in Config into Gemini's
%% generationConfig map.
%%
%% Only temperature, max_tokens, top_p and top_k are considered; options
%% missing from Config are left out, so the result may be an empty map.
build_gen_config(Config) ->
    Mapping = [
        {temperature, <<"temperature">>},
        {max_tokens, <<"maxOutputTokens">>},
        {top_p, <<"topP">>},
        {top_k, <<"topK">>}
    ],
    %% The {ok, Value} generator pattern silently skips options that
    %% maps:find/2 reports as absent.
    maps:from_list([
        {GeminiKey, Value}
     || {Option, GeminiKey} <- Mapping,
        {ok, Value} <- [maps:find(Option, Config)]
    ]).

%% Interprets a decoded `generateContent' response.
%%
%% Only the first candidate is inspected. Returns:
%%   {tool_calls, [{Name, Args, Sig}]} when any part is a functionCall;
%%       every call carries the same signature (the last thoughtSignature
%%       found on a functionCall part, or `undefined');
%%   {ok, Text} when there are no function calls; Text is the
%%       concatenation of all text parts, in order;
%%   {error, empty_response} when the parts hold neither calls nor text;
%%   {error, invalid_response} when the map has no candidate with
%%       content parts.
parse_response(#{<<"candidates">> := [#{<<"content">> := #{<<"parts">> := Parts}} | _]}) ->
    %% Look for function calls first
    ToolCallsWithSigs = [
        {
            maps:get(<<"name">>, FuncCall),
            maps:get(<<"args">>, FuncCall, #{}),
            maps:get(<<"thoughtSignature">>, Part, undefined)
        }
     || #{<<"functionCall">> := FuncCall} = Part <- Parts
    ],

    %% Propagate the thought signature to all tool calls in a parallel array
    GlobalSig = lists:foldl(
        fun
            ({_, _, S}, _Acc) when S =/= undefined -> S;
            (_, Acc) -> Acc
        end,
        undefined,
        ToolCallsWithSigs
    ),

    case [{N, A, GlobalSig} || {N, A, _} <- ToolCallsWithSigs] of
        [] ->
            %% A reply may be split over several text parts; join them all
            %% instead of returning only the first and dropping the rest.
            case [Text || #{<<"text">> := Text} <- Parts, is_binary(Text)] of
                [] -> {error, empty_response};
                Texts -> {ok, iolist_to_binary(Texts)}
            end;
        ToolCalls ->
            {tool_calls, ToolCalls}
    end;
parse_response(_) ->
    {error, invalid_response}.

%% Normalises a binary, a string (character list) or an atom to a binary.
%% Binaries pass through unchanged; lists are encoded with
%% unicode:characters_to_binary/1; atoms are rendered as UTF-8.
%% Any other term fails with function_clause.
to_binary(Value) when is_binary(Value) ->
    Value;
to_binary(Chars) when is_list(Chars) ->
    unicode:characters_to_binary(Chars);
to_binary(Name) when is_atom(Name) ->
    erlang:atom_to_binary(Name, utf8).