## lib/axon/metrics.ex

defmodule Axon.Metrics do
@moduledoc """
Metric functions.

Metrics are used to measure and compare the performance
of models in easy-to-understand terms. Oftentimes,
neural networks use surrogate loss functions such
as negative log-likelihood to indirectly optimize a certain
performance metric. Metrics such as accuracy, also called
the 0-1 loss, do not have useful derivatives (e.g. they
are information sparse), and are often intractable even
with low input dimensions.

Despite not being able to train specifically for certain
metrics, it's still useful to track these metrics to
monitor the performance of a neural network during training.
Metrics such as accuracy provide useful feedback during
training, whereas loss can sometimes be difficult to interpret.

You can attach any of these functions as metrics within the
Axon.Loop API using Axon.Loop.metric/3.

All of the functions in this module are implemented as
numerical functions and can be JIT or AOT compiled with
any supported Nx compiler.
"""

import Nx.Defn

# Standard Metrics

@doc ~S"""
Computes the accuracy of the given predictions.

If the size of the last axis is 1, it performs a binary
accuracy computation with a threshold of 0.5. Otherwise,
computes categorical accuracy.

## Argument Shapes

* y_true - $$$d_0, d_1, ..., d_n$$$
* y_pred - $$$d_0, d_1, ..., d_n$$$

## Examples

iex> Axon.Metrics.accuracy(Nx.tensor([[1], [0], [0]]), Nx.tensor([[1], [1], [1]]))
#Nx.Tensor<
f32
0.3333333432674408
>

iex> Axon.Metrics.accuracy(Nx.tensor([[0, 1], [1, 0], [1, 0]]), Nx.tensor([[0, 1], [1, 0], [0, 1]]))
#Nx.Tensor<
f32
0.6666666865348816
>

iex> Axon.Metrics.accuracy(Nx.tensor([[0, 1, 0], [1, 0, 0]]), Nx.tensor([[0, 1, 0], [0, 1, 0]]))
#Nx.Tensor<
f32
0.5
>

"""
defn accuracy(y_true, y_pred) do
  # The size of the trailing axis selects the binary or the categorical path.
  last_axis_size = elem(Nx.shape(y_pred), Nx.rank(y_pred) - 1)

  if last_axis_size == 1 do
    # Binary case: threshold predictions at 0.5 and compare element-wise.
    Nx.mean(Nx.equal(Nx.greater(y_pred, 0.5), y_true))
  else
    # Categorical case: compare arg-max indices taken along the last axis.
    target_classes = Nx.argmax(y_true, axis: -1)
    predicted_classes = Nx.argmax(y_pred, axis: -1)

    Nx.mean(Nx.equal(target_classes, predicted_classes))
  end
end

@doc ~S"""
Computes the precision of the given predictions with
respect to the given targets.

## Argument Shapes

* y_true - $$$d_0, d_1, ..., d_n$$$
* y_pred - $$$d_0, d_1, ..., d_n$$$

## Options

* :threshold - threshold for truth value of the predictions.
Defaults to 0.5

## Examples

iex> Axon.Metrics.precision(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
#Nx.Tensor<
f32
0.6666666865348816
>

"""
defn precision(y_true, y_pred, opts \\ []) do
  tp = true_positives(y_true, y_pred, opts)
  fp = false_positives(y_true, y_pred, opts)

  # The tiny epsilon keeps the denominator non-zero when tp + fp is 0.
  Nx.divide(tp, tp + fp + 1.0e-16)
end

@doc ~S"""
Computes the recall of the given predictions with
respect to the given targets.

## Argument Shapes

* y_true - $$$d_0, d_1, ..., d_n$$$
* y_pred - $$$d_0, d_1, ..., d_n$$$

## Options

* :threshold - threshold for truth value of the predictions.
Defaults to 0.5

## Examples

iex> Axon.Metrics.recall(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
#Nx.Tensor<
f32
0.6666666865348816
>

"""
defn recall(y_true, y_pred, opts \\ []) do
  tp = true_positives(y_true, y_pred, opts)
  missed = false_negatives(y_true, y_pred, opts)

  # The tiny epsilon keeps the denominator non-zero when missed + tp is 0.
  denominator = missed + tp + 1.0e-16

  Nx.divide(tp, denominator)
end

@doc """
Computes the number of true positive predictions with respect
to given targets.

## Options

* :threshold - threshold for truth value of predictions.
Defaults to 0.5.

## Examples

iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.true_positives(y_true, y_pred)
#Nx.Tensor<
u64
1
>
"""
defn true_positives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  # 1 where the prediction exceeds the threshold, 0 otherwise.
  predicted = Nx.greater(y_pred, opts[:threshold])

  matches_target = Nx.equal(predicted, y_true)
  predicted_positive = Nx.equal(predicted, 1)

  # Count entries that agree with the target and were predicted positive.
  Nx.sum(Nx.logical_and(matches_target, predicted_positive))
end

@doc """
Computes the number of false negative predictions with respect
to given targets.

## Options

* :threshold - threshold for truth value of predictions.
Defaults to 0.5.

## Examples

iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.false_negatives(y_true, y_pred)
#Nx.Tensor<
u64
3
>
"""
defn false_negatives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  # 1 where the prediction exceeds the threshold, 0 otherwise.
  predicted = Nx.greater(y_pred, opts[:threshold])

  differs_from_target = Nx.not_equal(predicted, y_true)
  predicted_negative = Nx.equal(predicted, 0)

  # Count entries that disagree with the target and were predicted negative.
  Nx.sum(Nx.logical_and(differs_from_target, predicted_negative))
end

@doc """
Computes the number of true negative predictions with respect
to given targets.

## Options

* :threshold - threshold for truth value of predictions.
Defaults to 0.5.

## Examples

iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.true_negatives(y_true, y_pred)
#Nx.Tensor<
u64
1
>
"""
defn true_negatives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  # 1 where the prediction exceeds the threshold, 0 otherwise.
  predicted = Nx.greater(y_pred, opts[:threshold])

  matches_target = Nx.equal(predicted, y_true)
  predicted_negative = Nx.equal(predicted, 0)

  # Count entries that agree with the target and were predicted negative.
  Nx.sum(Nx.logical_and(matches_target, predicted_negative))
end

@doc """
Computes the number of false positive predictions with respect
to given targets.

## Options

* :threshold - threshold for truth value of predictions.
Defaults to 0.5.

## Examples

iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.false_positives(y_true, y_pred)
#Nx.Tensor<
u64
2
>
"""
defn false_positives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  # 1 where the prediction exceeds the threshold, 0 otherwise.
  predicted = Nx.greater(y_pred, opts[:threshold])

  differs_from_target = Nx.not_equal(predicted, y_true)
  predicted_positive = Nx.equal(predicted, 1)

  # Count entries that disagree with the target and were predicted positive.
  Nx.sum(Nx.logical_and(differs_from_target, predicted_positive))
end

@doc ~S"""
Computes the sensitivity of the given predictions
with respect to the given targets.

## Argument Shapes

* y_true - $$$d_0, d_1, ..., d_n$$$
* y_pred - $$$d_0, d_1, ..., d_n$$$

## Options

* :threshold - threshold for truth value of the predictions.
Defaults to 0.5

## Examples

iex> Axon.Metrics.sensitivity(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
#Nx.Tensor<
f32
0.6666666865348816
>

"""
defn sensitivity(y_true, y_pred, opts \\ []) do
  # Sensitivity is computed by delegating to recall/3 with validated options.
  validated_opts = keyword!(opts, threshold: 0.5)

  recall(y_true, y_pred, validated_opts)
end

@doc ~S"""
Computes the specificity of the given predictions
with respect to the given targets.

## Argument Shapes

* y_true - $$$d_0, d_1, ..., d_n$$$
* y_pred - $$$d_0, d_1, ..., d_n$$$

## Options

* :threshold - threshold for truth value of the predictions.
Defaults to 0.5

## Examples

iex> Axon.Metrics.specificity(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
#Nx.Tensor<
f32
0.0
>

"""
defn specificity(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  # Reuse the public counting helpers (as precision/3 and recall/3 do)
  # instead of duplicating the thresholding and masking logic here.
  tn = true_negatives(y_true, y_pred, opts)
  fp = false_positives(y_true, y_pred, opts)

  # specificity = TN / (FP + TN); the tiny epsilon keeps the denominator
  # non-zero when both counts are 0.
  Nx.divide(tn, fp + tn + 1.0e-16)
end

@doc ~S"""
Calculates the mean absolute error of predictions
with respect to targets.

$$l = \frac{1}{n} \sum_{i=1}^{n} |\hat{y_i} - y_i|$$

## Argument Shapes

* y_true - $$$d_0, d_1, ..., d_n$$$
* y_pred - $$$d_0, d_1, ..., d_n$$$

## Examples

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Metrics.mean_absolute_error(y_true, y_pred)
#Nx.Tensor<
f32
0.5
>
"""
defn mean_absolute_error(y_true, y_pred) do
  # Element-wise difference between targets and predictions.
  residuals = Nx.subtract(y_true, y_pred)

  # Average of the absolute residuals over every element.
  Nx.mean(Nx.abs(residuals))
end

# Combinators

@doc """
Returns a function which computes a running average given current average,
new observation, and current iteration.

## Examples

iex> cur_avg = 0.5
iex> iteration = 1
iex> y_true = Nx.tensor([[0, 1], [1, 0], [1, 0]])
iex> y_pred = Nx.tensor([[0, 1], [1, 0], [1, 0]])
iex> avg_acc = Axon.Metrics.running_average(&Axon.Metrics.accuracy/2)
iex> avg_acc.(cur_avg, [y_true, y_pred], iteration)
#Nx.Tensor<
f32
0.75
>
"""
def running_average(metric) do
  # The returned closure applies `metric` to the list of inputs and folds
  # the resulting observation into the running average.
  fn avg, inputs, iteration ->
    observation = apply(metric, inputs)
    running_average_impl(avg, observation, iteration)
  end
end

# Folds the new observation `obs` into the running average `avg`.
#
# `i` is treated as the number of observations already reflected in `avg`,
# so the updated average is (avg * i + obs) / (i + 1). For example,
# avg = 0.5, obs = 1.0 and i = 1 give 0.75.
defnp running_average_impl(avg, obs, i) do
  avg
  |> Nx.multiply(i)
  |> Nx.add(obs)
  |> Nx.divide(Nx.add(i, 1))
end

@doc """
Returns a function which computes a running sum given current sum,
new observation, and current iteration.

## Examples

iex> cur_sum = 12
iex> iteration = 2
iex> y_true = Nx.tensor([0, 1, 0, 1])
iex> y_pred = Nx.tensor([1, 1, 0, 1])
iex> fps = Axon.Metrics.running_sum(&Axon.Metrics.false_positives/2)
iex> fps.(cur_sum, [y_true, y_pred], iteration)
#Nx.Tensor<
s64
13
>
"""
def running_sum(metric) do
  # The returned closure applies `metric` to the list of inputs and hands
  # the resulting observation to the running-sum accumulator.
  fn sum, inputs, iteration ->
    observation = apply(metric, inputs)
    running_sum_impl(sum, observation, iteration)
  end
end

defnp running_sum_impl(sum, obs, _) do