lib/helpers.ex

# SPDX-FileCopyrightText: 2024 Isaak Tsalicoglou <isaak@waseigo.com>
# SPDX-License-Identifier: Apache-2.0

defmodule Ollamex.Helpers do
  @moduledoc """
  Module containing helper functions used by the other modules of Ollamex.
  """
  @moduledoc since: "0.1.0"

  alias Ollamex.LLMResponse

  defp at_least_one(key, r) when is_list(r) and is_bitstring(key) do
    r
    |> Enum.filter(fn x -> key in Map.keys(x) end)
    |> length()
    |> Kernel.>(0)
  end

  @doc """
  Because the body of the responses of the `/generate`, `/chat` and `/embeddings` endpoints of the Ollama REST API return different data structures, this helper function detects the endpoint from which a list of responses originated.
  """
  @doc since: "0.2.0"
  def detect_endpoint(r) when is_list(r) do
    lookup = %{
      "response" => :generate,
      "message" => :chat,
      "embedding" => :embeddings
    }

    lookup
    |> Map.keys()
    |> Enum.map(&at_least_one(&1, r))
    |> Enum.find_index(&(&1 == true))
    |> (&Enum.at(Map.values(lookup), &1)).()
  end

  def handle_response(body) do
    r =
      body
      |> Enum.map(&Jason.decode!(&1))
      |> Enum.to_list()

    case detect_endpoint(r) do
      :generate ->
        r |> consolidate_responses()

      :chat ->
        r |> consolidate_responses()

      :embeddings ->
        %{%LLMResponse{} | embedding: hd(r)["embedding"]}
    end
  end

  @doc """
  Consolidates a list of `%Ollamex.LLMResponse{}` structs into a single response, dealing with `stream: false` and `stream: true`, regardless of the origin of the responses (`/chat` or `/generate`).
  """
  @doc since: "0.1.0"
  def consolidate_responses(r) when is_list(r) do
    rs = r |> Enum.map(&map_to_struct(&1, LLMResponse))

    response_done =
      rs
      |> Enum.filter(fn %LLMResponse{done: done} -> done end)

    case Enum.empty?(response_done) do
      true ->
        {:error, "Streamed response was never done."}

      false ->
        if detect_endpoint(r) == :chat do
          %{
            hd(response_done)
            | message: %{role: "assistant", content: extract_messages(rs)}
          }
        else
          %{hd(response_done) | response: extract_responses(rs)}
        end
    end
  end

  @doc """
  Extracts the chat message fragments from a list of `%Ollamex.LLMResponse{}` from the `/chat` endpoint and concatenates them into a string.
  """
  @doc since: "0.1.0"
  def extract_messages(r) when is_list(r) do
    r
    |> Enum.map(fn r ->
      if is_nil(r.message) do
        ""
      else
        r.message["content"]
      end
    end)
    |> List.to_string()
  end

  @doc """
  Extracts the response fragments from a list of `%Ollamex.LLMResponse{}` from the `/generate` endpoint and concatenates them into a string.
  """
  @doc since: "0.1.0"
  def extract_responses(r) when is_list(r) do
    r
    |> Enum.map(fn r -> r.response end)
    |> List.to_string()
  end

  @doc """
  Converts a map to a target struct, so that API responses can be converted to Ollamex structs.
  """
  @doc since: "0.1.0"
  def map_to_struct(source, target) when is_map(source) do
    source
    |> Map.keys()
    |> Enum.map(&String.to_atom(&1))
    |> Enum.zip(Map.values(source))
    |> (&struct(target, &1)).()
  end

  @doc """
  Creates an async task from the provided function `fun` with arguments `args`.
  """
  @doc since: "0.1.0"
  def create_task(fun, args) do
    Task.async(fn -> apply(fun, args) end)
  end

  @doc """
  Yields the result of a task if within the provided `timeout` (in milliseconds, default value `120_000`), or shuts the task down.

  Utilized by the `Ollamex.generate_with_timeout/2`, `Ollamex.generate_with_timeout/3`, `Ollamex.chat_with_timeout/2` and `Ollamex.chat_with_timeout/3` functions to deal with cases where the LLM keeps going on and on with repetition.
  """
  @doc since: "0.1.0"
  def yield_or_timeout_and_shutdown(task, timeout \\ 120_000) do
    case Task.yield(task, timeout) || Task.shutdown(task) do
      {:ok, result} ->
        {:ok, result}

      nil ->
        {:error, :timeout}

      {:exit, :noproc} ->
        {:error, :noproc}
    end
  end
end