# lib/vnext_genai/thread/session.ex

defmodule GenAI.Thread.Session do
  @moduledoc """
  More advanced VNext Thread Implementation
  """
  @vsn 1.0
  require GenAI.Records.Link
  require GenAI.Records.Node
  #  require GenAI.Records.Session
  #  alias GenAI.Records.Session, as: Node
  alias GenAI.Session.Runtime
  # alias GenAI.Session.State

  # Session fields:
  #   state   - session state; `new/1` builds it with `GenAI.Session.State.new/1`
  #   root    - `GenAI.Graph.Root` wrapping the session graph
  #   runtime - session runtime; `new/1` builds it with `GenAI.Session.Runtime.new/1`
  #   vsn     - struct version, defaults to `@vsn`
  defstruct state: nil,
            root: nil,
            runtime: nil,
            vsn: @vsn

  @doc """
  Build a new session.

  `options` may be `nil` or any container supporting `options[key]` access. The `:state`,
  `:graph` and `:runtime` entries are forwarded to `GenAI.Session.State.new/1`,
  `GenAI.VNext.Graph.new/1` and `GenAI.Session.Runtime.new/1` respectively.
  """
  def new(options \\ nil)

  def new(options) do
    # Components are built in the same order as the struct fields.
    state = GenAI.Session.State.new(options[:state])
    graph = GenAI.VNext.Graph.new(options[:graph])
    runtime = GenAI.Session.Runtime.new(options[:runtime])

    %__MODULE__{state: state, root: %GenAI.Graph.Root{graph: graph}, runtime: runtime, vsn: @vsn}
  end

  # ------------------------
  # append_directive/4
  # ------------------------
  @doc """
  Queue `directive` on the session.

  The directive is prepended to `session.state.directives`, so the most recently added
  directive sits at the head of the list. `context` and `options` are ignored.

  Returns `{:ok, session}`.
  """
  def append_directive(session, directive, context, options \\ nil)

  def append_directive(session, directive, _context, _options) do
    directives_path = [Access.key(:state), Access.key(:directives)]
    {:ok, update_in(session, directives_path, fn queued -> [directive | queued] end)}
  end

  # ------------------------
  # pending_directives?/1
  # ------------------------
  @doc """
  Returns `true` when `session.state.directive_position` is smaller than the number of
  entries in `session.state.directives`.
  """
  def pending_directives?(session) do
    state = session.state
    applied = state.directive_position
    applied < length(state.directives)
  end

  # ------------------------
  # apply_directives/3
  # ------------------------
  @doc """
  Apply the directives that have not been applied yet.

  Directives are stored newest-first, so the unapplied ones are the leading
  `length(directives) - directive_position` entries. They are reversed and handed to
  `GenAI.Session.State.Directive.apply_directive/4` one at a time; each call must return
  `{:ok, session}` (anything else raises a `MatchError`).

  Afterwards `directive_position` is set to the number of directives the session held when
  this function was called. Returns `{:ok, session}`.
  """
  def apply_directives(session, context, options) do
    queued = session.state.directives
    last_index = -(session.state.directive_position + 1)

    pending =
      queued
      |> Enum.slice(0..last_index//1)
      |> Enum.reverse()

    applied =
      Enum.reduce(pending, session, fn directive, acc ->
        {:ok, acc} =
          GenAI.Session.State.Directive.apply_directive(directive, acc, context, options)

        acc
      end)

    # The position is taken from the incoming session, not from `applied`.
    position = length(queued)
    {:ok, put_in(applied, [Access.key(:state), Access.key(:directive_position)], position)}
  end

  # -------------------------------------------
  # GenAI.ThreadProtocol
  # ---------------------------------------------
  defimpl GenAI.ThreadProtocol do
    require Logger

    import GenAI.Helpers, only: [apply_label: 2]

    # Attach `node` to the graph stored under `thread_context.root.graph` by replacing that
    # graph with the result of `GenAI.VNext.Graph.attach_node/3`.
    defp append_node(thread_context, node, options \\ nil)

    defp append_node(thread_context, node, options) do
      graph_path = [Access.key(:root), Access.key(:graph)]

      update_in(thread_context, graph_path, fn graph ->
        GenAI.VNext.Graph.attach_node(graph, node, options)
      end)
    end

    # ------------------------
    # append_directive/4
    # ------------------------
    # Prepends `directive` to `session.state.directives` and returns `{:ok, session}`.
    # `context` and `options` are ignored.
    def append_directive(session, directive, context, options \\ nil)

    def append_directive(session, directive, _context, _options) do
      directives_path = [Access.key(:state), Access.key(:directives)]
      {:ok, update_in(session, directives_path, fn queued -> [directive | queued] end)}
    end

    # -------------------------------------
    # with_model/2
    # -------------------------------------
    @doc """
    Specify a specific model or model picker.

    `model` defines what is used for inference. It can be a concrete model, such as
    `Model.smartest()`, or a model picker that selects the best model based on the context
    and the available providers.

    ## Examples

      * `Model.smartest()` - selects the "smartest" available model at inference time, based
        on factors like performance and capabilities.
      * `Model.cheapest(params: :best_effort)` - selects the cheapest available model that can
        handle the given parameters and context size.
      * `CustomProvider.custom_model` - uses a custom model from a user-defined provider.

    Raises `ArgumentError` when `GenAI.Graph.Node.is_node?(model, GenAI.Model)` is falsy.
    """
    def with_model(thread_context, model) do
      if !GenAI.Graph.Node.is_node?(model, GenAI.Model) do
        raise ArgumentError,
          message: "Model must be a GenAI.Model node, got #{inspect(model)}"
      end

      {:ok, identified} = GenAI.Graph.NodeProtocol.with_id(model)
      append_node(thread_context, identified)
    end

    # -------------------------------------
    #
    # -------------------------------------

    # Attach a tool node to the thread graph.
    # Raises ArgumentError when `tool` is not accepted by
    # `GenAI.Graph.Node.is_node?(tool, GenAI.Tool)`.
    def with_tool(thread_context, tool) do
      cond do
        GenAI.Graph.Node.is_node?(tool, GenAI.Tool) ->
          {:ok, identified} = GenAI.Graph.NodeProtocol.with_id(tool)
          append_node(thread_context, identified)

        true ->
          raise ArgumentError,
            message: "Tool must be a GenAI.Tool node, got #{inspect(tool)}"
      end
    end

    # -------------------------------------
    #
    # -------------------------------------
    # Attach every tool in `tools` (in order) via `with_tool/2`; `nil` is treated as no tools.
    def with_tools(context, tools) do
      Enum.reduce(tools || [], context, &with_tool(&2, &1))
    end

    # -------------------------------------
    #
    # -------------------------------------

    @doc """
    Specify an API key for a provider.

    The key is attached to the thread graph as a `GenAI.Setting.ProviderSetting` node whose
    setting is `:api_key`.
    """
    def with_api_key(thread_context, provider, api_key) do
      fields = [provider: provider, setting: :api_key, value: api_key]
      key_node = GenAI.Setting.ProviderSetting.new(fields)
      append_node(thread_context, key_node)
    end

    # -------------------------------------
    #
    # -------------------------------------

    @doc """
    Specify an API org for a provider.

    The org is attached to the thread graph as a `GenAI.Setting.ProviderSetting` node whose
    setting is `:api_org`.
    """
    def with_api_org(thread_context, provider, api_org) do
      fields = [provider: provider, setting: :api_org, value: api_org]
      org_node = GenAI.Setting.ProviderSetting.new(fields)
      append_node(thread_context, org_node)
    end

    # -------------------------------------
    #
    # -------------------------------------

    @doc """
    Set a hyperparameter option.

    Some options are model-specific. `value` can be a literal or a picker that determines the
    best value based on the context and model. The pair is attached to the thread graph as a
    `GenAI.Setting` node.

    ## Examples

      * `Parameter.required(name, value)` - sets a required parameter with the specified name
        and value.
      * `Gemini.best_temperature_for(:chain_of_thought)` - uses a picker function to determine
        the best temperature for the Gemini provider when using the "chain of thought"
        prompting technique.
    """
    def with_setting(thread_context, setting, value) do
      setting_node = GenAI.Setting.new([setting: setting, value: value])
      append_node(thread_context, setting_node)
    end

    # -------------------------------------
    #
    # -------------------------------------
    # Attach an already-built setting node.
    # Raises ArgumentError when `setting` is not accepted by
    # `GenAI.Graph.Node.is_node?(setting, GenAI.Setting)`.
    def with_setting(thread_context, setting) do
      GenAI.Graph.Node.is_node?(setting, GenAI.Setting) ||
        raise(ArgumentError,
          message: "Setting must be a GenAI.Setting node, got #{inspect(setting)}"
        )

      {:ok, identified} = GenAI.Graph.NodeProtocol.with_id(setting)
      append_node(thread_context, identified)
    end

    # -------------------------------------
    #
    # -------------------------------------
    # Apply a collection of settings in order. Each entry is either a `{setting, value}`
    # pair (see `with_setting/3`) or a struct (see `with_setting/2`). `nil` means no settings.
    def with_settings(thread_context, settings) do
      apply_entry = fn
        {name, value}, acc -> with_setting(acc, name, value)
        setting_node, acc when is_struct(setting_node) -> with_setting(acc, setting_node)
      end

      Enum.reduce(settings || [], thread_context, apply_entry)
    end

    # -------------------------------------
    #
    # -------------------------------------
    # Attach a `GenAI.Setting.SafetySetting` node built from `category` and `threshold`.
    def with_safety_setting(thread_context, category, threshold) do
      fields = [category: category, threshold: threshold]
      safety_node = GenAI.Setting.SafetySetting.new(fields)
      append_node(thread_context, safety_node)
    end

    # -------------------------------------
    #
    # -------------------------------------
    # Attach an already-built safety setting node.
    # Raises ArgumentError when `safety_setting` is not accepted by
    # `GenAI.Graph.Node.is_node?(safety_setting, GenAI.Setting)`.
    def with_safety_setting(thread_context, safety_setting) do
      if !GenAI.Graph.Node.is_node?(safety_setting, GenAI.Setting) do
        raise ArgumentError,
          message: "SafetySetting must be a GenAI.Setting node, got #{inspect(safety_setting)}"
      end

      {:ok, identified} = GenAI.Graph.NodeProtocol.with_id(safety_setting)
      append_node(thread_context, identified)
    end

    # -------------------------------------
    #
    # -------------------------------------
    # Apply a collection of safety settings in order. Each entry is either a
    # `{category, threshold}` pair or a struct; `nil` means no safety settings.
    def with_safety_settings(context, safety_settings) do
      apply_entry = fn
        {category, threshold}, acc -> with_safety_setting(acc, category, threshold)
        safety_node, acc when is_struct(safety_node) -> with_safety_setting(acc, safety_node)
      end

      Enum.reduce(safety_settings || [], context, apply_entry)
    end

    # -------------------------------------
    #
    # -------------------------------------

    # Attach a `GenAI.Setting.ProviderSetting` node built from `provider`, `setting`, `value`.
    def with_provider_setting(thread_context, provider, setting, value) do
      fields = [provider: provider, setting: setting, value: value]
      append_node(thread_context, GenAI.Setting.ProviderSetting.new(fields))
    end

    # Attach an already-built provider setting node.
    # Raises ArgumentError when `setting_node` is not accepted by
    # `GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting)`.
    # NOTE(review): unlike `with_setting/2`, the node is attached without going through
    # `GenAI.Graph.NodeProtocol.with_id/1` - confirm this is intended.
    def with_provider_setting(thread_context, setting_node) do
      cond do
        GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting) ->
          append_node(thread_context, setting_node)

        true ->
          raise ArgumentError,
            message: "Setting must be a GenAI.Setting node, got #{inspect(setting_node)}"
      end
    end

    # Apply a collection of settings for `provider` in order. `{setting, value}` pairs go
    # through `with_provider_setting/4`; any other entry is passed to
    # `with_provider_setting/2`. `nil` means no settings.
    def with_provider_settings(thread_context, provider, settings) do
      apply_entry = fn
        {name, value}, acc -> with_provider_setting(acc, provider, name, value)
        setting_node, acc -> with_provider_setting(acc, setting_node)
      end

      Enum.reduce(settings || [], thread_context, apply_entry)
    end

    # Attach each pre-built provider setting node in order; `nil` means no settings.
    def with_provider_settings(thread_context, settings) do
      Enum.reduce(settings || [], thread_context, &with_provider_setting(&2, &1))
    end

    # Attach a `GenAI.Setting.ModelSetting` node built from `model`, `setting`, `value`.
    def with_model_setting(thread_context, model, setting, value) do
      fields = [model: model, setting: setting, value: value]
      append_node(thread_context, GenAI.Setting.ModelSetting.new(fields))
    end

    # Attach an already-built model setting node.
    # Raises ArgumentError when `setting_node` is not accepted by
    # `GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting)`.
    # NOTE(review): unlike `with_setting/2`, the node is attached without going through
    # `GenAI.Graph.NodeProtocol.with_id/1` - confirm this is intended.
    def with_model_setting(thread_context, setting_node) do
      GenAI.Graph.Node.is_node?(setting_node, GenAI.Setting) ||
        raise(ArgumentError,
          message: "Setting must be a GenAI.Setting node, got #{inspect(setting_node)}"
        )

      append_node(thread_context, setting_node)
    end

    # Apply a collection of settings for `model` in order. `{setting, value}` pairs go
    # through `with_model_setting/4`; any other entry is passed to `with_model_setting/2`.
    # `nil` means no settings.
    def with_model_settings(thread_context, model, settings) do
      apply_entry = fn
        {name, value}, acc -> with_model_setting(acc, model, name, value)
        setting_node, acc -> with_model_setting(acc, setting_node)
      end

      Enum.reduce(settings || [], thread_context, apply_entry)
    end

    # Attach each pre-built model setting node in order via `with_model_setting/2`.
    # `nil` is treated as an empty collection, matching the other `with_*s` functions in
    # this implementation (previously `nil` raised `Protocol.UndefinedError`).
    def with_model_settings(thread_context, settings) do
      Enum.reduce(settings || [], thread_context, fn
        setting_object, thread_context -> with_model_setting(thread_context, setting_object)
      end)
    end

    @doc """
    Add a message to the conversation.

    `message` must be accepted by `GenAI.Graph.Node.is_node?(message, GenAI.Message)`;
    otherwise an `ArgumentError` is raised. `options` is ignored.
    """
    def with_message(thread_context, message, options \\ nil)

    def with_message(thread_context, message, _options) do
      if !GenAI.Graph.Node.is_node?(message, GenAI.Message) do
        raise ArgumentError,
          message: "Message must be a GenAI.Message node, got #{inspect(message)}"
      end

      {:ok, identified} = GenAI.Graph.NodeProtocol.with_id(message)
      append_node(thread_context, identified)
    end

    # -------------------------------------
    #
    # -------------------------------------
    @doc """
    Add a list of messages to the conversation.

    Messages are added in order through `with_message/3`, each receiving the same `options`.
    `nil` is treated as an empty list.
    """
    def with_messages(context, messages, options) do
      Enum.reduce(messages || [], context, &with_message(&2, &1, options))
    end

    # -------------------------------------
    #
    # -------------------------------------
    @doc """
    Specify or override the default stream handler.

    The handler is attached to the thread graph as a `GenAI.Setting` node whose setting is
    `:stream_handler`. `options` is ignored.
    """
    def with_stream_handler(thread_context, handler, options \\ nil)

    def with_stream_handler(thread_context, handler, _options) do
      handler_node = GenAI.Setting.new([setting: :stream_handler, value: handler])
      append_node(thread_context, handler_node)
    end

    # -------------------------------------
    #
    # -------------------------------------
    # Prepare the session's runtime and state for `command`:
    #   1. `Runtime.command/4` on the current runtime,
    #   2. `GenAI.Session.State.initialize/4` on the current state,
    #   3. `GenAI.Session.State.monitor/4`, which yields the final `{state, runtime}` pair.
    # Returns `{:ok, session}` with both fields replaced; the first step that does not match
    # its expected shape is returned unchanged.
    defp initialize_session(command, session, context, options) do
      with %{state: current_state, runtime: current_runtime} <- session,
           {:ok, commanded_runtime} <-
             Runtime.command(current_runtime, command, context, options),
           {:ok, initialized_state} <-
             GenAI.Session.State.initialize(current_state, commanded_runtime, context, options),
           {:ok, {monitored_state, monitored_runtime}} <-
             GenAI.Session.State.monitor(initialized_state, commanded_runtime, context, options) do
        {:ok,
         %GenAI.Thread.Session{session | state: monitored_state, runtime: monitored_runtime}}
      end
    end

    # -------------------------------------
    #
    # -------------------------------------
    @doc """
    Run inference.

    `command` is `:run` or `:stream`. When `context` is `nil`, `Noizu.Context.system()` is
    used. Both commands perform the same preparation:

      * initialize the session's runtime and state for the command,
      * process the session's root graph node, which must finish with a `process_end` record,
      * apply the pending directives via `GenAI.Thread.Session.apply_directives/3`,
      * resolve the effective model and, from it, the provider module.

    `:run` then returns `provider.run(session, context, options)`. `:stream` first sets the
    `:stream` setting to `true` through `GenAI.with_setting/3` and returns
    `provider.stream(session, context, options)`.

    If a preparation step yields a value that does not match its expected shape, that value
    is returned unchanged.
    """
    def execute(thread_context, command, context, options \\ nil)

    def execute(session, :run, context, options) do
      context = context || Noizu.Context.system()

      case prepare_execution(:run, session, context, options) do
        {:prepared, provider, session} -> provider.run(session, context, options)
        halted -> halted
      end
    end

    def execute(session, :stream, context, options) do
      context = context || Noizu.Context.system()

      case prepare_execution(:stream, session, context, options) do
        {:prepared, provider, session} ->
          # TODO - handle existing stream options
          session = GenAI.with_setting(session, :stream, true)
          provider.stream(session, context, options)

        halted ->
          halted
      end
    end

    # Preparation shared by every `execute/4` command: initialize the session, process the
    # graph, apply directives, then resolve the model and its provider.
    # Returns `{:prepared, provider, session}` on success; otherwise the first non-matching
    # step result is returned as-is so callers can pass it through.
    defp prepare_execution(command, session, context, options) do
      with {:ok, session} <-
             initialize_session(command, session, context, options),
           GenAI.Records.Node.process_end(session: session) <-
             GenAI.Graph.Root.process_node(session.root, nil, nil, session, context, options),
           {:ok, session} <- GenAI.Thread.Session.apply_directives(session, context, options),
           {:effective_model, {model, session}} <-
             session
             |> GenAI.ThreadProtocol.effective_model(context, options)
             |> apply_label(:effective_model),
           {:effective_provider, provider} <-
             model
             |> GenAI.ModelProtocol.provider()
             |> apply_label(:effective_provider) do
        {:prepared, provider, session}
      end
    end
    
    
    # Returns `{:ok, {model, thread_context}}` where `model` is read from
    # `thread_context.state.model.value`. `context` and `options` are ignored.
    def effective_model(thread_context, _context, _options) do
      selected = thread_context.state.model.value
      {:ok, {selected, thread_context}}
    end

    # Returns `{:ok, {settings, thread_context}}` where `settings` is a list of
    # `{setting, value}` pairs unwrapped from `thread_context.state.settings`.
    def effective_settings(thread_context, _context, _options) do
      thread_context.state.settings
      |> Enum.map(fn {name, entry} -> {name, entry.value} end)
      |> then(&{:ok, {&1, thread_context}})
    end

    # Returns `{:ok, {settings, thread_context}}` where `settings` is a list of
    # `{setting, value}` pairs unwrapped from `thread_context.state.safety_settings`.
    def effective_safety_settings(thread_context, _context, _options) do
      thread_context.state.safety_settings
      |> Enum.map(fn {name, entry} -> {name, entry.value} end)
      |> then(&{:ok, {&1, thread_context}})
    end

    # Returns `{:ok, {settings, thread_context}}` with the unwrapped `{setting, value}` pairs
    # stored in `thread_context.state.model_settings` under the key `{name, provider}` of
    # `model` (empty list when nothing is stored). A non-`{:ok, _}` result from
    # `GenAI.ModelProtocol.name/1` or `provider/1` is returned unchanged.
    def effective_model_settings(thread_context, model, _context, _options) do
      with {:ok, name} <- GenAI.ModelProtocol.name(model),
           {:ok, provider} <- GenAI.ModelProtocol.provider(model) do
        entries = thread_context.state.model_settings[{name, provider}] || []
        resolved = Enum.map(entries, fn {setting, entry} -> {setting, entry.value} end)
        {:ok, {resolved, thread_context}}
      end
    end

    # Returns `{:ok, {settings, thread_context}}` with the unwrapped `{setting, value}` pairs
    # stored in `thread_context.state.provider_settings` for the provider of `model` (empty
    # list when nothing is stored). A non-`{:ok, _}` result from
    # `GenAI.ModelProtocol.provider/1` is returned unchanged.
    def effective_provider_settings(thread_context, model, _context, _options) do
      with {:ok, provider} <- GenAI.ModelProtocol.provider(model) do
        entries = thread_context.state.provider_settings[provider] || []
        resolved = Enum.map(entries, fn {setting, entry} -> {setting, entry.value} end)
        {:ok, {resolved, thread_context}}
      end
    end

    # Encode the thread's messages with the encoder reported by
    # `GenAI.ModelProtocol.encoder/1`.
    #
    # `state.thread` is reversed and each entry is looked up in `state.thread_messages`,
    # then encoded while the session is threaded through as the accumulator. If any
    # message failed to encode, `{:error, {:format_messages, errors}}` is returned;
    # otherwise the result of `encoder.normalize_messages/5`.
    def effective_messages(thread_context, model, context, options) do
      with {:ok, encoder} <- GenAI.ModelProtocol.encoder(model) do
        {encoded, updated_context} =
          thread_context.state.thread
          |> Enum.reverse()
          |> Enum.map_reduce(thread_context, fn key, acc ->
            # The lookup reads the incoming thread_context, not the accumulator.
            stored = thread_context.state.thread_messages[key]
            encode_message(encoder, stored, model, acc, context, options)
          end)

        case errors(encoded) do
          [] -> encoder.normalize_messages(encoded, model, updated_context, context, options)
          failures -> {:error, {:format_messages, failures}}
        end
      end
    end

    # Encode one stored message (`message.value`) for `Enum.map_reduce/3`.
    # Returns `{encoded, session}` on success, or `{{:error, reason}, session}` with the
    # session untouched on failure.
    defp encode_message(encoder, message, model, session, context, options) do
      outcome = encoder.encode_message(message.value, model, session, context, options)

      case outcome do
        {:error, _} = failure -> {failure, session}
        {:ok, {encoded, updated_session}} -> {encoded, updated_session}
      end
    end

    # Encode the entries of `state.tools` with the encoder reported by
    # `GenAI.ModelProtocol.encoder/1`, threading the session through as the accumulator.
    #
    # Returns `{:error, {:format_tools, errors}}` if any tool failed to encode,
    # `{:ok, {nil, thread_context}}` when there are no tools, and
    # `{:ok, {tools, thread_context}}` otherwise.
    def effective_tools(thread_context, model, context, options) do
      with {:ok, encoder} <- GenAI.ModelProtocol.encoder(model) do
        {encoded, updated_context} =
          Enum.map_reduce(thread_context.state.tools, thread_context, fn entry, acc ->
            encode_tool(encoder, entry, model, acc, context, options)
          end)

        case {errors(encoded), encoded} do
          {[], []} -> {:ok, {nil, updated_context}}
          {[], _} -> {:ok, {encoded, updated_context}}
          {failures, _} -> {:error, {:format_tools, failures}}
        end
      end
    end

    # Encode one `{key, tool}` entry (`tool.value`) for `Enum.map_reduce/3`.
    # Returns `{encoded, thread_context}` on success, or `{{:error, reason}, thread_context}`
    # with the context untouched on failure.
    defp encode_tool(encoder, {_key, tool}, model, thread_context, context, options) do
      outcome = encoder.encode_tool(tool.value, model, thread_context, context, options)

      case outcome do
        {:error, _} = failure -> {failure, thread_context}
        {:ok, {encoded, updated_context}} -> {encoded, updated_context}
      end
    end

    # Store `value` under `artifact` in `thread_context.state.artifacts`.
    # Returns `{:ok, thread_context}`.
    def set_artifact(thread_context, artifact, value) do
      artifact_path = [Access.key(:state), Access.key(:artifacts), artifact]
      ok(put_in(thread_context, artifact_path, value))
    end

    # Read `artifact` from `thread_context.state.artifacts`.
    # Returns `{:ok, value}`, where `value` is whatever the lookup yields.
    def get_artifact(thread_context, artifact) do
      artifact_path = [Access.key(:state), Access.key(:artifacts), artifact]
      ok(get_in(thread_context, artifact_path))
    end

    # Keep only the `{:error, _}` entries of `messages`, preserving their order.
    defp errors(messages) do
      for {:error, _} = failure <- messages, do: failure
    end

    # Wrap a value in an `:ok` tuple.
    defp ok(response) do
      {:ok, response}
    end
  end
end