lib/vnext_genai/thread/standard.ex

defmodule GenAI.Thread.Standard do
  @moduledoc """
  Standard chat thread. Defines the struct used to manage a conversation with a
  generative AI model: a `state` plus a `graph` of nodes (model selection,
  settings, tools, messages) that is applied when the thread is executed.
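
  A minimal usage sketch (the `:openai` provider atom and environment variable
  are illustrative, and `context` stands in for whatever your pipeline passes
  through to the provider):

      thread =
        GenAI.Thread.Standard.new()
        |> GenAI.ThreadProtocol.with_api_key(:openai, System.get_env("OPENAI_API_KEY"))
        |> GenAI.ThreadProtocol.with_setting(:temperature, 0.7)

      GenAI.ThreadProtocol.execute(thread, :run, context, nil)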
  """

  @vsn 1.0

  defstruct state: %GenAI.Thread.State{},
            graph: %GenAI.Graph{},
            vsn: @vsn

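  @doc """
  Creates a new standard thread. The `options` argument is accepted for
  interface compatibility and is currently ignored.
  """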
  def new(options \\ nil)

  def new(_) do
    %__MODULE__{}
  end

  defimpl GenAI.ThreadProtocol do
    @moduledoc """
    Implementation of `GenAI.ThreadProtocol` for the standard thread: it lets a
    chat context be configured (model, settings, tools, messages) and then
    executed against the resolved provider.
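
    Each `with_*` call appends a corresponding node to the thread's graph. For
    example (a sketch; `model` is assumed to already be a `GenAI.Model` struct):

        thread = GenAI.ThreadProtocol.with_model(thread, model)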
    """

    import GenAI.Helpers, only: [apply_label: 2]

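    # Append a node (model, setting, message, tool, ...) to the thread's graph.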
    defp append_node(thread_context, node) do
      update_in(thread_context, [Access.key(:graph)], &GenAI.Graph.append_node(&1, node))
    end

    # ------------------------
    # append_directive/4
    # ------------------------
    def append_directive(session, directive, context, options \\ nil)

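    # Directives are not applied by the standard thread; the session is returned unchanged.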
    def append_directive(session, _, _, _) do
      {:ok, session}
    end

    def with_model(thread_context, model) do
      if GenAI.Graph.Node.is_node?(model, GenAI.Model) do
        append_node(thread_context, model)
      else
        raise ArgumentError,
          message: "Model must be a GenAI.Model node, got #{inspect(model)}"
      end
    end

    def with_tool(thread_context, tool) do
      if GenAI.Graph.Node.is_node?(tool, GenAI.Tool) do
        append_node(thread_context, tool)
      else
        raise ArgumentError,
          message: "Tool must be a GenAI.Tool node, got #{inspect(tool)}"
      end
    end

    def with_tools(thread_context, tools)

    def with_tools(thread_context, nil),
      do: thread_context

    def with_tools(thread_context, tools) do
      Enum.reduce(tools, thread_context, fn tool, thread_context ->
        with_tool(thread_context, tool)
      end)
    end

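    # API credentials are stored as provider-scoped settings.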
    def with_api_key(thread_context, provider, api_key) do
      node =
        GenAI.Setting.ProviderSetting.new(
          provider: provider,
          setting: :api_key,
          value: api_key
        )

      append_node(thread_context, node)
    end

    def with_api_org(thread_context, provider, api_org) do
      node =
        GenAI.Setting.ProviderSetting.new(
          provider: provider,
          setting: :api_org,
          value: api_org
        )

      append_node(thread_context, node)
    end

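    # Unscoped settings; see with_provider_setting/4 and with_model_setting/4 for scoped variants.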
    def with_setting(thread_context, setting, value) do
      node =
        GenAI.Setting.new(
          setting: setting,
          value: value
        )

      append_node(thread_context, node)
    end

    def with_setting(thread_context, setting_node) do
      if GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting) do
        append_node(thread_context, setting_node)
      else
        raise ArgumentError,
          message: "Setting must be a GenAI.Setting node, got #{inspect(setting_node)}"
      end
    end

    def with_settings(thread_context, nil), do: thread_context

    def with_settings(thread_context, settings) do
      Enum.reduce(settings, thread_context, fn
        {setting, value}, thread_context -> with_setting(thread_context, setting, value)
        setting_object, thread_context -> with_setting(thread_context, setting_object)
      end)
    end

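    # Safety settings pair a harm category with a blocking threshold.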
    def with_safety_setting(thread_context, category, threshold) do
      node =
        GenAI.Setting.SafetySetting.new(
          category: category,
          threshold: threshold
        )

      append_node(thread_context, node)
    end

    def with_safety_setting(thread_context, setting_node) do
      if GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting) do
        append_node(thread_context, setting_node)
      else
        raise ArgumentError,
          message: "Setting must be a GenAI.Setting node, got #{inspect(setting_node)}"
      end
    end

    def with_safety_settings(thread_context, nil), do: thread_context

    def with_safety_settings(thread_context, entries) do
      Enum.reduce(entries, thread_context, fn
        {category, threshold}, thread_context ->
          with_safety_setting(thread_context, category, threshold)

        setting_object, thread_context ->
          with_safety_setting(thread_context, setting_object)
      end)
    end

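    # Provider-scoped setting: stored alongside the provider it targets.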
    def with_provider_setting(thread_context, provider, setting, value) do
      node =
        GenAI.Setting.ProviderSetting.new(
          provider: provider,
          setting: setting,
          value: value
        )

      append_node(thread_context, node)
    end

    def with_provider_setting(thread_context, setting_node) do
      if GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting) do
        append_node(thread_context, setting_node)
      else
        raise ArgumentError,
          message: "Setting must be a GenAI.Setting node, got #{inspect(setting_node)}"
      end
    end

    def with_provider_settings(thread_context, provider, settings)
    def with_provider_settings(thread_context, _, nil), do: thread_context

    def with_provider_settings(thread_context, provider, settings) do
      Enum.reduce(settings, thread_context, fn
        {setting, value}, thread_context ->
          with_provider_setting(thread_context, provider, setting, value)

        setting_object, thread_context ->
          with_provider_setting(thread_context, setting_object)
      end)
    end

    def with_provider_settings(thread_context, settings)
    def with_provider_settings(thread_context, nil), do: thread_context

    def with_provider_settings(thread_context, settings) do
      Enum.reduce(settings, thread_context, fn
        setting_object, thread_context -> with_provider_setting(thread_context, setting_object)
      end)
    end

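    # Model-scoped setting: stored alongside the model it targets.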
    def with_model_setting(thread_context, model, setting, value) do
      node =
        GenAI.Setting.ModelSetting.new(
          model: model,
          setting: setting,
          value: value
        )

      append_node(thread_context, node)
    end

    def with_model_setting(thread_context, setting_node) do
      if GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting) do
        append_node(thread_context, setting_node)
      else
        raise ArgumentError,
          message: "Setting must be a GenAI.Setting node, got #{inspect(setting_node)}"
      end
    end

    def with_model_settings(thread_context, model, settings) do
      Enum.reduce(settings, thread_context, fn
        {setting, value}, thread_context ->
          with_model_setting(thread_context, model, setting, value)

        setting_object, thread_context ->
          with_model_setting(thread_context, setting_object)
      end)
    end

    def with_model_settings(thread_context, settings) do
      Enum.reduce(settings, thread_context, fn
        setting_object, thread_context -> with_model_setting(thread_context, setting_object)
      end)
    end

    def with_message(thread_context, message, options)

    def with_message(thread_context, message_node, _) do
      if GenAI.Graph.Node.is_node?(message_node, GenAI.Message) do
        append_node(thread_context, message_node)
      else
        raise ArgumentError,
          message: "Message must be a GenAI.Message node, got #{inspect(message_node)}"
      end
    end

    def with_messages(thread_context, messages, options) do
      Enum.reduce(messages, thread_context, fn message, thread_context ->
        with_message(thread_context, message, options)
      end)
    end

    def with_stream_handler(thread_context, handler, options \\ nil)

    def with_stream_handler(thread_context, handler, _) do
      node =
        GenAI.Setting.new(
          setting: :stream_handler,
          value: handler
        )

      append_node(thread_context, node)
    end

    @doc """
    Runs inference on the chat context.

    This function applies the thread's graph, resolves the effective model and its
    provider, and then delegates execution to the provider's `run/3` (or `stream/3`
    when the `:stream` command is used).
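
    A sketch of typical calls (the shapes of `context` and `options` depend on
    the caller; `nil` options are accepted):

        GenAI.ThreadProtocol.execute(thread, :run, context, nil)
        GenAI.ThreadProtocol.execute(thread, :stream, context, nil)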
    """
    def execute(thread_context, command, context, options \\ nil)

    def execute(thread_context, :run, context, options) do
      with {:prep, thread_context} <-
             thread_context.graph
             |> GenAI.Legacy.NodeProtocol.apply(thread_context)
             |> apply_label(:prep),
           {:effective_model, {model, thread_context}} <-
             thread_context
             |> GenAI.Thread.LegacyStateProtocol.effective_model(context, options)
             |> apply_label(:effective_model),
           {:effective_provider, provider} <-
             model
             |> GenAI.ModelProtocol.provider()
             |> apply_label(:effective_provider) do
        provider.run(thread_context, context, options)
      end
    end

    def execute(thread_context, :stream, context, options) do
      with {:prep, thread_context} <-
             thread_context.graph
             |> GenAI.Legacy.NodeProtocol.apply(thread_context)
             |> apply_label(:prep),
           {:effective_model, {model, thread_context}} <-
             thread_context
             |> GenAI.Thread.LegacyStateProtocol.effective_model(context, options)
             |> apply_label(:effective_model),
           {:effective_provider, provider} <-
             model
             |> GenAI.ModelProtocol.provider()
             |> apply_label(:effective_provider) do
        # TODO - handle existing stream options
        thread_context = GenAI.with_setting(thread_context, :stream, true)
        provider.stream(thread_context, context, options)
      end
    end

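    # Resolution of effective model, settings, messages, tools, and artifacts is
    # delegated to the legacy state protocol.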
    defdelegate effective_model(thread_context, context, options),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate effective_settings(thread_context, context, options),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate effective_safety_settings(thread_context, context, options),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate effective_model_settings(thread_context, model, context, options),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate effective_provider_settings(thread_context, model, context, options),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate effective_messages(thread_context, model, context, options),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate effective_tools(thread_context, model, context, options),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate set_artifact(thread_context, artifact, value),
      to: GenAI.Thread.LegacyStateProtocol

    defdelegate get_artifact(thread_context, artifact), to: GenAI.Thread.LegacyStateProtocol
  end
end