lib/lorax/params.ex

defmodule Lorax.Params do
  @moduledoc """
  Helper module for loading, downloading, filtering, and calculating the size of Axon parameters.
  """

  @doc """
  Returns the LoRA-only params from a merged param map.

  Every top-level key of `lora_merged_params` that also appears in `original_params`
  is dropped. What remains are the entries that exist only in the merged map. Only
  top-level keys are compared; nested values are not inspected.
  """
  @spec filter(map(), map()) :: map()
  def filter(lora_merged_params, original_params) do
    Map.drop(lora_merged_params, Map.keys(original_params))
  end

  @doc """
  Loads parameters from the serialized file at `params_path`.

  The raw file contents are passed to `Nx.deserialize/1`. Raises `File.Error` if the
  file cannot be read.
  """
  @spec file_load!(Path.t()) :: term()
  def file_load!(params_path) do
    params_path
    |> File.read!()
    |> Nx.deserialize()
  end

  @doc """
  Reads the params file uploaded through `kino_input` and returns the deserialized params.

  `kino_input` is expected to be a file input whose value carries a `:file_ref`
  (for example, one created with `Kino.Input.file/2`).

  Raises a `RuntimeError` with message "No param file uploaded" when no file has
  been uploaded yet. Raises "Invalid param file" when deserialization raises an
  `ArgumentError`.
  """
  def kino_file_load!(kino_input) do
    case Kino.Input.read(kino_input) do
      nil ->
        raise "No param file uploaded"

      %{file_ref: file_ref} ->
        path = Kino.Input.file_path(file_ref)

        # NOTE(review): the uploaded bytes are user-supplied; confirm that
        # Nx.deserialize/1 is safe to run on untrusted input.
        try do
          file_load!(path)
        rescue
          ArgumentError -> raise "Invalid param file"
        end
    end
  end

  @doc """
  Creates a Kino widget for downloading the params map.

  The params are serialized with `Nx.serialize/1` once, when the widget is built. The
  download callback then returns that same prebuilt binary on every click.
  Must be placed in the last line of a Livebook cell.

  `filename` defaults to `"params.lorax"` and `label` (the button text) defaults
  to `"Download Params"`.
  """
  def kino_download(params, filename \\ "params.lorax", label \\ "Download Params") do
    # Serialize eagerly so the download callback only hands back an existing binary.
    binary =
      params
      |> Nx.serialize()
      |> IO.iodata_to_binary()

    Kino.Download.new(fn -> binary end, filename: filename, label: label)
  end

  @doc """
  Calculates the total number of bytes of all the tensors inside a parameter map.

  Params are typically laid out as `%{layer_name => %{param_name => tensor}}`, but
  nested maps of any depth are walked, including tensors stored directly at the top
  level. Each tensor contributes `Nx.byte_size/1`, which is its data size in bytes,
  not its element count.
  """
  @spec size(map()) :: non_neg_integer()
  def size(%{} = params), do: sum_bytes(params)

  # Byte size of one tensor, or the recursive byte total of a (nested) map of tensors.
  # The tensor clause must come first: `Nx.Tensor` structs are maps too.
  defp sum_bytes(value) when is_struct(value, Nx.Tensor), do: Nx.byte_size(value)

  defp sum_bytes(%{} = nested) do
    nested
    |> Map.values()
    |> Enum.reduce(0, fn value, total -> total + sum_bytes(value) end)
  end
end