lib/exla/lib.ex

defmodule EXLA.Lib do
  @moduledoc """
  High-level operations built on top of `EXLA.Op`.
  """

  alias EXLA.{Builder, Op, Shape}

  @doc """
  Element-wise tangent function.

  Computed as `sin(op) / cos(op)`.
  """
  def tan(%Op{} = op) do
    Op.divide(Op.sin(op), Op.cos(op))
  end

  @doc """
  Builds iota along axis.

  When `axis` is `nil`, every element of `shape` gets a distinct consecutive
  value: a one-dimensional iota of `Nx.size(shape.dims)` elements is built
  and reshaped to `shape.dims`. Otherwise the iota is built along `axis`.
  """
  def iota(builder, shape, nil) do
    total_elems = Nx.size(shape.dims)

    Op.reshape(
      Op.iota(builder, Shape.make_shape(shape.dtype, {total_elems}), 0),
      shape.dims
    )
  end

  def iota(builder, shape, axis) do
    Op.iota(builder, shape, axis)
  end

  @doc """
  Computes the argmax of the given operation.

  ## Options

    * `:axis` - the axis to reduce on. When not given, all axes are
      reduced and the result is a scalar index into the flattened operand

    * `:keep_axis` - whether or not to keep the reduced axis (with size 1).
      Only has an effect when `:axis` is given

    * `:tie_break` - how to break ties between equal values: `:low`
      (default) picks the smallest index, `:high` the largest

    * `:type` - the type of the returned indices. Defaults to `{:s, 64}`
  """
  def argmax(%Builder{} = builder, %Op{} = op, opts \\ []) do
    argmin_or_max(builder, op, false, opts)
  end

  @doc """
  Computes the argmin of the given operation.

  ## Options

    * `:axis` - the axis to reduce on. When not given, all axes are
      reduced and the result is a scalar index into the flattened operand

    * `:keep_axis` - whether or not to keep the reduced axis (with size 1).
      Only has an effect when `:axis` is given

    * `:tie_break` - how to break ties between equal values: `:low`
      (default) picks the smallest index, `:high` the largest

    * `:type` - the type of the returned indices. Defaults to `{:s, 64}`
  """
  def argmin(%Builder{} = builder, %Op{} = op, opts \\ []) do
    argmin_or_max(builder, op, true, opts)
  end

  # Shared implementation of argmin/argmax: a variadic reduce over
  # (value, index) pairs, returning only the index component.
  defp argmin_or_max(builder, op, is_min?, opts) do
    tie_break = opts[:tie_break] || :low
    keep_axis = opts[:keep_axis] || false
    axis = opts[:axis]

    # The iota, the index init value and the reducer's index parameters
    # must all share this type, otherwise the reduce shapes do not line up.
    index_type = opts[:type] || {:s, 64}

    op_shape = Op.get_shape(op)
    value_type = op_shape.dtype

    # Starting value for the running min/max: the largest representable
    # value for argmin and the smallest for argmax.
    init_value =
      if is_min?,
        do: max_number(builder, value_type),
        else: min_number(builder, value_type)

    index_init_value = Op.constant_r0(builder, 0, index_type)
    indices = iota(builder, Shape.make_shape(index_type, op_shape.dims), axis)

    reduction =
      create_min_max_computation(builder, value_type, index_type, is_min?, tie_break)

    reduce_axes =
      if axis, do: {axis}, else: List.to_tuple(Nx.axes(op_shape.dims))

    result =
      builder
      |> Op.variadic_reduce([op, indices], [init_value, index_init_value], reduction, reduce_axes)
      |> Op.get_tuple_element(1)

    # Without an axis every dimension was reduced, so there is no single
    # axis to keep and the scalar result is returned as is.
    if keep_axis && axis != nil do
      Op.reshape(result, put_elem(op_shape.dims, axis, 1))
    else
      result
    end
  end

  # Builds the reducer used by argmin/argmax. It combines two
  # (value, index) pairs into one, keeping the pair whose value compares
  # better and resolving equal values according to `tie_break`.
  defp create_min_max_computation(builder, value_type, index_type, is_min?, tie_break) do
    sub_builder = subbuilder(builder, "min-max")

    value_shape = Shape.make_shape(value_type, {})
    index_shape = Shape.make_shape(index_type, {})

    lhs_value = Op.parameter(sub_builder, 0, value_shape, "lhs_value")
    lhs_index = Op.parameter(sub_builder, 1, index_shape, "lhs_index")
    rhs_value = Op.parameter(sub_builder, 2, value_shape, "rhs_value")
    rhs_index = Op.parameter(sub_builder, 3, index_shape, "rhs_index")

    tie_index =
      case tie_break do
        :low ->
          Op.min(lhs_index, rhs_index)

        :high ->
          Op.max(lhs_index, rhs_index)

        other ->
          raise ArgumentError,
                "expected :tie_break to be :low or :high, got: #{inspect(other)}"
      end

    cmp =
      if is_min?,
        do: Op.less_equal(lhs_value, rhs_value),
        else: Op.greater_equal(lhs_value, rhs_value)

    best_value = Op.select(cmp, lhs_value, rhs_value)
    best_index = Op.select(cmp, lhs_index, rhs_index)

    # On equal values `cmp` always favours the lhs pair; pick by index
    # instead so the result follows `tie_break` whichever pair is the lhs.
    equal? = Op.equal(lhs_value, rhs_value)
    best_index = Op.select(equal?, tie_index, best_index)

    sub_builder
    |> Op.tuple([best_value, best_index])
    |> Builder.build()
  end

  @doc """
  Returns a minimum value scalar operator for the given type.

  Values come from `Nx.Type.min_binary/1`, so it will be negative infinity
  for floating point types. For `{:pred, 8}` the value is `0`.
  """
  def min_number(%Builder{} = builder, type) do
    Op.constant_from_binary(builder, min_binary(type), Shape.make_shape(type, {}))
  end

  @doc """
  Returns a maximum value scalar operator for the given type.

  Values come from `Nx.Type.max_binary/1`, so it will be positive infinity
  for floating point types. For `{:pred, 8}` the value is `1`.
  """
  def max_number(%Builder{} = builder, type) do
    Op.constant_from_binary(builder, max_binary(type), Shape.make_shape(type, {}))
  end

  # Creates a child builder whose name is derived from the parent's name,
  # `desc` and a unique positive integer, so repeated calls get distinct names.
  defp subbuilder(%Builder{name: name} = builder, desc) do
    suffix = System.unique_integer([:positive])
    Builder.new(builder, name <> "-" <> desc <> "-" <> Integer.to_string(suffix))
  end

  defp min_binary({:pred, 8}), do: <<0>>
  defp min_binary(type), do: Nx.Type.min_binary(type)

  defp max_binary({:pred, 8}), do: <<1>>
  defp max_binary(type), do: Nx.Type.max_binary(type)

  @doc """
  Sorts a tensor and returns the corresponding indices in the new positions.

  `operand` is sorted along `dimension` with `comparator` (the comparison
  computation passed to `Op.variadic_sort/4`), together with an iota of
  type `iota_type` along that same dimension. The permuted iota is returned:
  for each position of the sorted result, the index along `dimension` that
  the element had in `operand`.
  """
  def argsort(builder, operand, dimension, comparator, iota_type) do
    shape = Op.get_shape(operand)
    indices = iota(builder, Shape.make_shape(iota_type, shape.dims), dimension)

    builder
    |> Op.variadic_sort([operand, indices], comparator, dimension)
    |> Op.get_tuple_element(1)
  end
end