lib/axon/activations.ex

defmodule Axon.Activations do
  @moduledoc """
  Activation functions.

  Activation functions are element-wise, (typically) non-linear
  functions called on the output of another layer, such as
  a dense layer:

      x
      |> dense(weight, bias)
      |> relu()

  Activation functions output the "activation" of a layer, a measure
  of how active the layer's neurons are in learning a representation
  of the data-generating distribution.

  Some activations are commonly used as output activations. For
  example `softmax` is often used as the output in multiclass
  classification problems because it returns a categorical
  probability distribution:

      iex> Axon.Activations.softmax(Nx.tensor([[1, 2, 3]], type: {:f, 32}))
      #Nx.Tensor<
        f32[1][3]
        [
          [0.09003057330846786, 0.2447284758090973, 0.6652409434318542]
        ]
      >

  Other activations such as `tanh` or `sigmoid` are used because
  they have desirable properties, such as keeping the output
  tensor constrained within a certain range.

  Generally, the choice of activation function is somewhat arbitrary,
  although some activations work better than others in certain
  problem domains. For example, ReLU (rectified linear unit)
  activation is a widely-accepted default. You can see
  a list of activation functions and implementations
  [here](https://paperswithcode.com/methods/category/activation-functions).

  All of the functions in this module are implemented as
  numerical functions and can be JIT or AOT compiled with
  any supported `Nx` compiler.
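
  For example, a minimal sketch of JIT compiling an activation with
  `Nx.Defn.jit/2` (the compiler used depends on your `Nx` configuration,
  which is assumed here to be set up separately):

      jit_relu = Nx.Defn.jit(&Axon.Activations.relu/1)
      jit_relu.(Nx.tensor([-1.0, 0.0, 1.0]))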
  """

  import Nx.Defn
  import Axon.Shared

  @doc ~S"""
  Continuously-differentiable exponential linear unit activation.

  $$f(x_i) = \max(0, x_i) + \min(0, \alpha * (e^{\frac{x_i}{\alpha}} - 1))$$

  ## Options

    * `:alpha` - $\alpha$ in CELU formulation. Must be non-zero.
      Defaults to `1.0`

  ## Examples

      iex> Axon.Activations.celu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]))
      #Nx.Tensor<
        f32[7]
        [-0.9502129554748535, -0.8646647334098816, -0.6321205496788025, 0.0, 1.0, 2.0, 3.0]
      >

      iex> Axon.Activations.celu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}))
      #Nx.Tensor<
        bf16[2][3]
        [
          [-0.62890625, -0.86328125, -0.94921875],
          [1.0, 2.0, 3.0]
        ]
      >

  ### Error cases

      iex> Axon.Activations.celu(Nx.tensor([0.0, 1.0, 2.0], type: {:f, 32}), alpha: 0.0)
      ** (ArgumentError) :alpha must be non-zero in CELU activation

  ## References

    * [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf)

  """
  defn celu(x, opts \\ []) do
    opts = keyword!(opts, alpha: 1.0)

    transform(
      opts[:alpha],
      fn alpha ->
        if Elixir.Kernel.==(alpha, 0),
          do: raise(ArgumentError, ":alpha must be non-zero in CELU activation")
      end
    )

    Nx.select(Nx.greater(x, 0), x, opts[:alpha] * Nx.expm1(x / opts[:alpha]))
  end

  @doc ~S"""
  Exponential linear unit activation.

  Equivalent to `celu` for $\alpha = 1$.

  $$f(x_i) = \begin{cases}x_i & x_i > 0 \newline \alpha * (e^{x_i} - 1) & x_i \leq 0 \end{cases}$$

  ## Options

    * `:alpha` - $\alpha$ in ELU formulation. Defaults to `1.0`

  ## Examples

      iex> Axon.Activations.elu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]))
      #Nx.Tensor<
        f32[7]
        [-0.9502129554748535, -0.8646647334098816, -0.6321205496788025, 0.0, 1.0, 2.0, 3.0]
      >

      iex> Axon.Activations.elu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}))
      #Nx.Tensor<
        bf16[2][3]
        [
          [-0.62890625, -0.86328125, -0.94921875],
          [1.0, 2.0, 3.0]
        ]
      >

  ## References

    * [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)](https://arxiv.org/abs/1511.07289)

  """
  defn elu(x, opts \\ []) do
    opts = keyword!(opts, alpha: 1.0)
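    # Zero out positive entries before expm1 so e^x is never evaluated for
    # large positive inputs; those lanes are discarded by the select below,
    # but computing them could overflow and poison the gradient.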
    x_hat = Nx.select(Nx.greater(x, 0), 0, x)
    Nx.select(Nx.greater(x, 0), x, opts[:alpha] * Nx.expm1(x_hat))
  end

  @doc ~S"""
  Exponential activation.

  $$f(x_i) = e^{x_i}$$

  ## Examples

      iex> Axon.Activations.exp(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [0.049787066876888275, 0.1353352814912796, 0.3678794503211975, 1.0, 2.7182817459106445, 7.389056205749512, 20.08553695678711]
      >

      iex> Axon.Activations.exp(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [0.3671875, 0.134765625, 0.049560546875],
          [2.703125, 7.375, 20.0]
        ]
      >

  """
  defn exp(x) do
    Nx.exp(x)
  end

  @doc ~S"""
  Gaussian error linear unit activation.

  $$f(x_i) = \frac{x_i}{2}(1 + {erf}(\frac{x_i}{\sqrt{2}}))$$

  ## Examples

      iex> Axon.Activations.gelu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-0.0040496885776519775, -0.04550027847290039, -0.15865525603294373, 0.0, 0.8413447141647339, 1.9544997215270996, 2.995950222015381]
      >

      iex> Axon.Activations.gelu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-0.16015625, -0.046875, -0.005859375],
          [0.83984375, 1.953125, 2.984375]
        ]
      >

  ## References

    * [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)

  """
  defn gelu(x) do
    sqrt2 = Nx.sqrt(Nx.tensor(2, type: Nx.type(x)))

    x
    |> Nx.divide(sqrt2)
    |> Nx.erf()
    |> Nx.add(1)
    |> Nx.multiply(x)
    |> Nx.divide(2)
  end

  @doc ~S"""
  Hard sigmoid activation.

  $$f(x_i) = \min(1, \max(0, \alpha x_i + \beta))$$

  ## Options

    * `:alpha` - $\alpha$ in hard sigmoid formulation. Defaults to `0.2`

    * `:beta` - $\beta$ in hard sigmoid formulation. Defaults to `0.2`

  ## Examples

      iex> Axon.Activations.hard_sigmoid(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [0.0, 0.0, 0.0, 0.20000000298023224, 0.4000000059604645, 0.6000000238418579, 0.800000011920929]
      >

      iex> Axon.Activations.hard_sigmoid(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [7.781982421875e-4, 0.0, 0.0],
          [0.3984375, 0.59765625, 0.796875]
        ]
      >

  """
  defn hard_sigmoid(x, opts \\ []) do
    opts = keyword!(opts, alpha: 0.2, beta: 0.2)

    x
    |> Nx.multiply(opts[:alpha])
    |> Nx.add(opts[:beta])
    |> Nx.max(0)
    |> Nx.min(1)
  end

  @doc ~S"""
  Hard sigmoid weighted linear unit activation.

  $$f(x_i) = x_i \cdot \min(1, \max(0, \alpha x_i + \beta))$$

  In other words, `hard_silu(x)` is `x * hard_sigmoid(x)`. Accepts the
  same `:alpha` and `:beta` options as `hard_sigmoid/2`.

  ## Examples

      iex> Axon.Activations.hard_silu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-0.0, -0.0, -0.0, 0.0, 0.4000000059604645, 1.2000000476837158, 2.4000000953674316]
      >

      iex> Axon.Activations.hard_silu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-7.781982421875e-4, -0.0, -0.0],
          [0.3984375, 1.1953125, 2.390625]
        ]
      >

  """
  defn hard_silu(x, opts \\ []) do
    x
    |> hard_sigmoid(opts)
    |> Nx.multiply(x)
  end

  @doc ~S"""
  Hard hyperbolic tangent activation.

  $$f(x_i) = \begin{cases} 1 & x_i > 1 \newline -1 & x_i < -1 \newline x_i & otherwise \end{cases}$$

  ## Examples

      iex> Axon.Activations.hard_tanh(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-1.0, -1.0, -1.0, 0.0, 1.0, 1.0, 1.0]
      >

      iex> Axon.Activations.hard_tanh(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-1.0, -1.0, -1.0],
          [1.0, 1.0, 1.0]
        ]
      >

  """
  defn hard_tanh(x) do
    Nx.select(
      Nx.greater(x, 1),
      1,
      Nx.select(Nx.less(x, -1), -1, x)
    )
  end

  @doc ~S"""
  Leaky rectified linear unit activation.

  $$f(x_i) = \begin{cases} x_i & x_i \geq 0 \newline \alpha * x_i & otherwise \end{cases}$$

  ## Options

    * `:alpha` - $\alpha$ in Leaky ReLU formulation. Defaults to `1.0e-2`

  ## Examples

      iex> Axon.Activations.leaky_relu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]), alpha: 0.5)
      #Nx.Tensor<
        f32[data: 7]
        [-1.5, -1.0, -0.5, 0.0, 1.0, 2.0, 3.0]
      >

      iex> Axon.Activations.leaky_relu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], names: [:batch, :data]), alpha: 0.5)
      #Nx.Tensor<
        f32[batch: 2][data: 3]
        [
          [-0.5, -1.0, -1.5],
          [1.0, 2.0, 3.0]
        ]
      >

  """
  defn leaky_relu(x, opts \\ []) do
    opts = keyword!(opts, alpha: 1.0e-2)
    Nx.select(Nx.greater(x, 0), x, x * opts[:alpha])
  end

  @doc ~S"""
  Linear activation.

  $$f(x_i) = x_i$$

  ## Examples

      iex> Axon.Activations.linear(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
      >

      iex> Axon.Activations.linear(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-1.0, -2.0, -3.0],
          [1.0, 2.0, 3.0]
        ]
      >

  """
  defn linear(x), do: x

  @doc ~S"""
  Log-sum-exp activation.

  $$f(x) = \log\left(\sum_i e^{x_i}\right)$$

  ## Options

    * `:axis` - axis or axes along which to sum. Defaults to `-1`

  ## Examples

      iex> Axon.Activations.log_sumexp(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 1]
        [0.45776283740997314]
      >

      iex> Axon.Activations.log_sumexp(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 1]
        [
          [0.404296875],
          [0.404296875]
        ]
      >

  """
  defn log_sumexp(x, opts \\ []) do
    opts = keyword!(opts, axis: -1)
    axes = transform(opts[:axis], &List.wrap/1)

    # This is a scaling term designed to prevent over/under flow when x is very
    # large. For large positive x, the intermediate value e^x tends towards
    # infinity, which poisons the rest of the calculation. Subtracting the max
    # value in the tensor guarantees all exponents are no greater than 0, so
    # e^(x - C) stays within a representable range.
    #
    # We are essentially treating the max value as a constant term, C. Thus there
    # is no need to differentiate through the max. See also: https://github.com/google/jax/pull/2260
    # for a note on performance.
    max_val = stop_grad(Nx.reduce_max(x, axes: axes, keep_axes: true))

    stable_exp =
      x
      |> Nx.subtract(max_val)
      |> Nx.exp()

    stable_exp
    |> Nx.sum(axes: axes, keep_axes: true)
    |> Nx.log()
  end

  @doc ~S"""
  Log-sigmoid activation.

  $$f(x_i) = \log(\mathrm{sigmoid}(x_i))$$

  ## Examples

      iex> Axon.Activations.log_sigmoid(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], type: {:f, 32}, names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-3.0485873222351074, -2.1269280910491943, -1.3132617473602295, -0.6931471824645996, -0.3132616877555847, -0.12692801654338837, -0.04858734831213951]
      >

      iex> Axon.Activations.log_sigmoid(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-1.3125, -2.125, -3.046875],
          [-0.3125, -0.1259765625, -0.04833984375]
        ]
      >

  """
  defn log_sigmoid(x), do: -softplus(-x)

  @doc """
  Log-softmax activation.

  $$f(x_i) = -\log(\sum{e^x_i})$$

  ## Examples

      iex> Axon.Activations.log_softmax(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], type: {:f, 32}, names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-6.457762718200684, -5.457762718200684, -4.457762718200684, -3.4577627182006836, -2.4577627182006836, -1.4577628374099731, -0.45776283740997314]
      >

      iex> Axon.Activations.log_softmax(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-0.404296875, -1.3984375, -2.390625],
          [-2.390625, -1.3984375, -0.404296875]
        ]
      >
  """
  defn log_softmax(x, opts \\ []) do
    opts = keyword!(opts, axis: -1)

    transform({x, opts}, fn {x, opts} ->
      if Elixir.Kernel.<=(Nx.rank(x), opts[:axis]) do
        raise ArgumentError, "log_softmax axis must be within rank of tensor"
      end
    end)
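
    # Shift by the max along the axis for numerical stability. The max is
    # treated as a constant, so there is no need to differentiate through it.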

    shifted = x - stop_grad(Nx.reduce_max(x, axes: [opts[:axis]], keep_axes: true))

    shifted
    |> Nx.exp()
    |> Nx.sum(axes: [opts[:axis]], keep_axes: true)
    |> Nx.log()
    |> Nx.negate()
    |> Nx.add(shifted)
  end

  @doc ~S"""
  Mish activation.

  $$f(x_i) = x_i \cdot \tanh(\log(1 + e^{x_i}))$$

  ## Examples

      iex> Axon.Activations.mish(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], type: {:f, 32}, names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-0.14564745128154755, -0.2525014877319336, -0.30340147018432617, 0.0, 0.8650984168052673, 1.9439589977264404, 2.98653507232666]
      >

      iex> Axon.Activations.mish(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-0.30078125, -0.25, -0.1435546875],
          [0.86328125, 1.9375, 2.96875]
        ]
      >
  """
  defn mish(x) do
    x * tanh(softplus(x))
  end

  @doc ~S"""
  Rectified linear unit activation.

  $$f(x_i) = \max(x_i, 0)$$

  ## Examples

      iex> Axon.Activations.relu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0]
      >

      iex> Axon.Activations.relu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [0.0, 0.0, 0.0],
          [1.0, 2.0, 3.0]
        ]
      >

  """
  defn relu(x) do
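    # Define the gradient explicitly: pass the upstream gradient through where
    # x > 0 and zero it elsewhere, so the gradient at x = 0 is exactly 0.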
    custom_grad(
      Nx.max(x, 0),
      fn _ans, g -> [{x, Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))}] end
    )
  end

  @doc ~S"""
  Rectified linear unit 6 activation.

  $$f(x_i) = \min(\max(x_i, 0), 6)$$

  ## Examples

      iex> Axon.Activations.relu6(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]))
      #Nx.Tensor<
        f32[7]
        [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0]
      >

      iex> Axon.Activations.relu6(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [0.0, 0.0, 0.0],
          [1.0, 2.0, 3.0]
        ]
      >

  ## References

    * [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861v1)

  """
  defn relu6(x) do
    x
    |> Nx.max(0)
    |> Nx.min(6)
  end

  @doc ~S"""
  Sigmoid activation.

  $$f(x_i) = \frac{1}{1 + e^{-x_i}}$$

  **Implementation Note: Sigmoid logits are cached as metadata
  in the expression and can be used in calculations later on.
  For example, they are used in cross-entropy calculations for
  better stability.**

  ## Examples

      iex> Axon.Activations.sigmoid(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [0.04742587357759476, 0.11920291930437088, 0.2689414322376251, 0.5, 0.7310585975646973, 0.8807970881462097, 0.9525741338729858]
      >

      iex> Axon.Activations.sigmoid(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [0.267578125, 0.119140625, 0.04736328125],
          [0.73046875, 0.87890625, 0.94921875]
        ]
      >

  """
  defn sigmoid(x) do
    # Cache logits so they are available in certain calculations,
    # e.g. binary_cross_entropy and categorical_cross_entropy
    transform(Nx.sigmoid(x), &Nx.Defn.Expr.metadata(&1, %{logits: x}))
  end

  @doc ~S"""
  Sigmoid weighted linear unit activation.

  $$f(x_i) = x_i \cdot \mathrm{sigmoid}(x_i)$$

  ## Examples

      iex> Axon.Activations.silu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-0.14227762818336487, -0.23840583860874176, -0.2689414322376251, 0.0, 0.7310585975646973, 1.7615941762924194, 2.857722282409668]
      >

      iex> Axon.Activations.silu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-0.267578125, -0.23828125, -0.1416015625],
          [0.73046875, 1.7578125, 2.84375]
        ]
      >

  ## References

    * [Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning](https://arxiv.org/abs/1702.03118v3)

  """
  defn silu(x) do
    x
    |> Nx.sigmoid()
    |> Nx.multiply(x)
  end

  @doc ~S"""
  Scaled exponential linear unit activation.

  $$f(x_i) = \begin{cases} \lambda x_i & x_i \geq 0 \newline
  \lambda \alpha(e^{x_i} - 1) & x_i < 0 \end{cases}$$

  $$\alpha \approx 1.6733$$
  $$\lambda \approx 1.0507$$

  ## Options

    * `:alpha` - $\alpha$ in SELU formulation. Defaults to `1.6732632423543772848170429916717`

    * `:gamma` - $\lambda$ in SELU formulation. Defaults to `1.0507009873554804934193349852946`

  ## Examples

      iex> Axon.Activations.selu(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-1.670568823814392, -1.5201665163040161, -1.1113307476043701, 0.0, 1.0507010221481323, 2.1014020442962646, 3.1521029472351074]
      >

      iex> Axon.Activations.selu(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-1.09375, -1.5078125, -1.6640625],
          [1.046875, 2.09375, 3.140625]
        ]
      >

  ## References

    * [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515v5)

  """
  defn selu(x, opts \\ []) do
    opts =
      keyword!(opts,
        alpha: 1.6732632423543772848170429916717,
        gamma: 1.0507009873554804934193349852946
      )

    opts[:gamma] * elu(x, alpha: opts[:alpha])
  end

  @doc ~S"""
  Softmax activation along an axis.

  $$f(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

  **Implementation Note: Softmax logits are cached as metadata
  in the expression and can be used in calculations later on.
  For example, they are used in cross-entropy calculations for
  better stability.**

  ## Options

    * `:axis` - softmax axis along which to calculate distribution.
      Defaults to `-1`

  ## Examples

      iex> Axon.Activations.softmax(Nx.tensor([[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]], names: [:batch, :data]))
      #Nx.Tensor<
        f32[batch: 1][data: 7]
        [
          [0.0015683004166930914, 0.004263082519173622, 0.011588259600102901, 0.03150015324354172, 0.08562629669904709, 0.23275642096996307, 0.6326975226402283]
        ]
      >

      iex> Axon.Activations.softmax(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [0.6640625, 0.2431640625, 0.08935546875],
          [0.08935546875, 0.2431640625, 0.6640625]
        ]
      >

  """
  defn softmax(x, opts \\ []) do
    opts = keyword!(opts, axis: -1)
    axes = transform(opts[:axis], &List.wrap/1)

    transform({x, axes}, fn {x, axes} ->
      Enum.each(axes, fn axis ->
        Nx.Shape.normalize_axis(Nx.shape(x), axis, Nx.names(x))
      end)
    end)

    # This is a scaling term designed to prevent over/under flow when x is very
    # large. For large positive x, the intermediate value e^x tends towards
    # infinity, which poisons the rest of the calculation (even though the
    # result would otherwise be normalized by the division by sum(e^x)).
    # Subtracting the max value in the tensor guarantees all exponents are
    # no greater than 0, so e^(x - C) stays within a representable range.
    #
    # Given the expression is essentially:
    #
    # e^(x - C) / sum(e^(x - C))
    #
    # We are essentially treating the max value as a constant term, C. Thus there
    # is no need to differentiate through the max. See also: https://github.com/google/jax/pull/2260
    # for a note on performance.
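    #
    # For example, softmax([1000.0, 1000.0]) computed naively yields
    # inf / inf = NaN, whereas e^(x - 1000) = [1.0, 1.0] produces the
    # correct result of [0.5, 0.5].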
    max_val = stop_grad(Nx.reduce_max(x, axes: axes, keep_axes: true))

    stable_exp =
      x
      |> Nx.subtract(max_val)
      |> Nx.exp()

    res =
      stable_exp
      |> Nx.sum(axes: axes, keep_axes: true)
      |> reciprocal()
      |> Nx.multiply(stable_exp)

    # Cache logits so they are available in certain calculations,
    # e.g. binary_cross_entropy and categorical_cross_entropy
    transform(res, &Nx.Defn.Expr.metadata(&1, %{logits: x}))
  end

  @doc ~S"""
  Softplus activation.

  $$f(x_i) = \log(1 + e^{x_i})$$

  ## Examples

      iex> Axon.Activations.softplus(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [0.04858734831213951, 0.12692801654338837, 0.3132616877555847, 0.6931471824645996, 1.3132617473602295, 2.1269280910491943, 3.0485873222351074]
      >

      iex> Axon.Activations.softplus(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [0.3125, 0.1259765625, 0.04833984375],
          [1.3125, 2.125, 3.046875]
        ]
      >

  """
  defn softplus(x) do
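    # Numerically stable formulation: softplus(x) = max(x, 0) + log1p(e^-|x|),
    # which avoids overflowing e^x for large positive inputs.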
    stable = Nx.max(0.0, x)

    x
    |> Nx.abs()
    |> Nx.negate()
    |> Nx.exp()
    |> Nx.log1p()
    |> Nx.add(stable)
  end

  @doc ~S"""
  Softsign activation.

  $$f(x_i) = \frac{x_i}{|x_i| + 1}$$

  ## Examples

      iex> Axon.Activations.softsign(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-0.75, -0.6666666865348816, -0.5, 0.0, 0.5, 0.6666666865348816, 0.75]
      >

      iex> Axon.Activations.softsign(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-0.5, -0.6640625, -0.75],
          [0.5, 0.6640625, 0.75]
        ]
      >

  """
  defn softsign(x) do
    x
    |> Nx.abs()
    |> Nx.add(1)
    |> reciprocal()
    |> Nx.multiply(x)
  end

  @doc ~S"""
  Hyperbolic tangent activation.

  $$f(x_i) = \tanh(x_i)$$

  ## Examples

      iex> Axon.Activations.tanh(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data]))
      #Nx.Tensor<
        f32[data: 7]
        [-0.9950547814369202, -0.9640275835990906, -0.7615941762924194, 0.0, 0.7615941762924194, 0.9640275835990906, 0.9950547814369202]
      >

      iex> Axon.Activations.tanh(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data]))
      #Nx.Tensor<
        bf16[batch: 2][data: 3]
        [
          [-0.7578125, -0.9609375, -0.9921875],
          [0.7578125, 0.9609375, 0.9921875]
        ]
      >

  """
  defn tanh(x), do: Nx.tanh(x)
end