defmodule Axon.Metrics do
@moduledoc """
Metric functions.

Metrics are used to measure and compare the performance of
models in easy-to-understand terms. Oftentimes, neural
networks use surrogate loss functions such as negative
log-likelihood to indirectly optimize a certain performance
metric. Metrics such as accuracy, also called the 0-1 loss,
do not have useful derivatives (i.e. they are information
sparse), and are often intractable even with low input
dimensions.

Despite not being able to train specifically for certain
metrics, it's still useful to track these metrics to
monitor the performance of a neural network during training.
Metrics such as accuracy provide useful feedback during
training, whereas loss can sometimes be difficult to interpret.

You can attach any of these functions as metrics within the
`Axon.Loop` API using `Axon.Loop.metric/3`.

All of the functions in this module are implemented as
numerical functions and can be JIT or AOT compiled with
any supported `Nx` compiler.
"""
import Nx.Defn
# Standard Metrics
@doc ~S"""
Measures the fraction of predictions that agree with the targets.

When the last axis of `y_pred` has size 1, the predictions are
binarized with a threshold of 0.5 and compared element-wise
against `y_true` (binary accuracy). For any other size, the index
of the largest value along the last axis of each tensor is compared
instead (categorical accuracy).

## Argument Shapes

  * `y_true` - $\(d_0, d_1, ..., d_n\)$
  * `y_pred` - $\(d_0, d_1, ..., d_n\)$

## Examples

    iex> Axon.Metrics.accuracy(Nx.tensor([[1], [0], [0]]), Nx.tensor([[1], [1], [1]]))
    #Nx.Tensor<
      f32
      0.3333333432674408
    >

    iex> Axon.Metrics.accuracy(Nx.tensor([[0, 1], [1, 0], [1, 0]]), Nx.tensor([[0, 1], [1, 0], [0, 1]]))
    #Nx.Tensor<
      f32
      0.6666666865348816
    >

    iex> Axon.Metrics.accuracy(Nx.tensor([[0, 1, 0], [1, 0, 0]]), Nx.tensor([[0, 1, 0], [0, 1, 0]]))
    #Nx.Tensor<
      f32
      0.5
    >

"""
defn accuracy(y_true, y_pred) do
  if elem(Nx.shape(y_pred), Nx.rank(y_pred) - 1) == 1 do
    # Binary case: a single score per example, cut at 0.5.
    binarized = Nx.greater(y_pred, 0.5)
    Nx.mean(Nx.equal(binarized, y_true))
  else
    # Categorical case: compare the winning index along the last axis.
    true_classes = Nx.argmax(y_true, axis: -1)
    predicted_classes = Nx.argmax(y_pred, axis: -1)
    Nx.mean(Nx.equal(true_classes, predicted_classes))
  end
end
@doc ~S"""
Computes the precision of the predictions against the targets:
true positives divided by the sum of true positives and false
positives.

## Argument Shapes

  * `y_true` - $\(d_0, d_1, ..., d_n\)$
  * `y_pred` - $\(d_0, d_1, ..., d_n\)$

## Options

  * `:threshold` - predictions strictly greater than this value
    count as positive. Defaults to `0.5`

## Examples

    iex> Axon.Metrics.precision(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
    #Nx.Tensor<
      f32
      0.6666666865348816
    >

"""
defn precision(y_true, y_pred, opts \\ []) do
  hits = true_positives(y_true, y_pred, opts)
  false_alarms = false_positives(y_true, y_pred, opts)

  # The small constant keeps the denominator non-zero when no
  # prediction is positive.
  Nx.divide(hits, hits + false_alarms + 1.0e-16)
end
@doc ~S"""
Computes the recall of the predictions against the targets:
true positives divided by the sum of false negatives and true
positives.

## Argument Shapes

  * `y_true` - $\(d_0, d_1, ..., d_n\)$
  * `y_pred` - $\(d_0, d_1, ..., d_n\)$

## Options

  * `:threshold` - predictions strictly greater than this value
    count as positive. Defaults to `0.5`

## Examples

    iex> Axon.Metrics.recall(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
    #Nx.Tensor<
      f32
      0.6666666865348816
    >

"""
defn recall(y_true, y_pred, opts \\ []) do
  hits = true_positives(y_true, y_pred, opts)
  misses = false_negatives(y_true, y_pred, opts)

  # The small constant keeps the denominator non-zero when there
  # are neither true positives nor false negatives.
  denominator = misses + hits + 1.0e-16
  Nx.divide(hits, denominator)
end
@doc """
Counts the true positive predictions with respect to the
given targets.

A prediction is counted when its thresholded value is 1 and
equals the corresponding entry of `y_true`.

## Options

  * `:threshold` - predictions strictly greater than this value
    count as positive. Defaults to `0.5`.

## Examples

    iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
    iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
    iex> Axon.Metrics.true_positives(y_true, y_pred)
    #Nx.Tensor<
      u64
      1
    >

"""
defn true_positives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  predicted = Nx.greater(y_pred, opts[:threshold])
  matches_target = Nx.equal(predicted, y_true)
  predicted_positive = Nx.equal(predicted, 1)

  Nx.sum(Nx.logical_and(matches_target, predicted_positive))
end
@doc """
Counts the false negative predictions with respect to the
given targets.

A prediction is counted when its thresholded value is 0 and
differs from the corresponding entry of `y_true`.

## Options

  * `:threshold` - predictions strictly greater than this value
    count as positive. Defaults to `0.5`.

## Examples

    iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
    iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
    iex> Axon.Metrics.false_negatives(y_true, y_pred)
    #Nx.Tensor<
      u64
      3
    >

"""
defn false_negatives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  predicted = Nx.greater(y_pred, opts[:threshold])
  misses_target = Nx.not_equal(predicted, y_true)
  predicted_negative = Nx.equal(predicted, 0)

  Nx.sum(Nx.logical_and(misses_target, predicted_negative))
end
@doc """
Counts the true negative predictions with respect to the
given targets.

A prediction is counted when its thresholded value is 0 and
equals the corresponding entry of `y_true`.

## Options

  * `:threshold` - predictions strictly greater than this value
    count as positive. Defaults to `0.5`.

## Examples

    iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
    iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
    iex> Axon.Metrics.true_negatives(y_true, y_pred)
    #Nx.Tensor<
      u64
      1
    >

"""
defn true_negatives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  predicted = Nx.greater(y_pred, opts[:threshold])
  matches_target = Nx.equal(predicted, y_true)
  predicted_negative = Nx.equal(predicted, 0)

  Nx.sum(Nx.logical_and(matches_target, predicted_negative))
end
@doc """
Counts the false positive predictions with respect to the
given targets.

A prediction is counted when its thresholded value is 1 and
differs from the corresponding entry of `y_true`.

## Options

  * `:threshold` - predictions strictly greater than this value
    count as positive. Defaults to `0.5`.

## Examples

    iex> y_true = Nx.tensor([1, 0, 1, 1, 0, 1, 0])
    iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
    iex> Axon.Metrics.false_positives(y_true, y_pred)
    #Nx.Tensor<
      u64
      2
    >

"""
defn false_positives(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  predicted = Nx.greater(y_pred, opts[:threshold])
  misses_target = Nx.not_equal(predicted, y_true)
  predicted_positive = Nx.equal(predicted, 1)

  Nx.sum(Nx.logical_and(misses_target, predicted_positive))
end
@doc ~S"""
Computes the sensitivity of the predictions against the targets.

The options are validated here and the computation is delegated
to `recall/3`.

## Argument Shapes

  * `y_true` - $\(d_0, d_1, ..., d_n\)$
  * `y_pred` - $\(d_0, d_1, ..., d_n\)$

## Options

  * `:threshold` - predictions strictly greater than this value
    count as positive. Defaults to `0.5`

## Examples

    iex> Axon.Metrics.sensitivity(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
    #Nx.Tensor<
      f32
      0.6666666865348816
    >

"""
defn sensitivity(y_true, y_pred, opts \\ []) do
  validated_opts = keyword!(opts, threshold: 0.5)

  recall(y_true, y_pred, validated_opts)
end
@doc ~S"""
Computes the specificity of the given predictions
with respect to the given targets.

Specificity is the number of true negatives divided by the sum
of false positives and true negatives. The counts come from
`true_negatives/3` and `false_positives/3`, so the thresholding
rule is the same one those functions apply.

## Argument Shapes

  * `y_true` - $\(d_0, d_1, ..., d_n\)$
  * `y_pred` - $\(d_0, d_1, ..., d_n\)$

## Options

  * `:threshold` - threshold for truth value of the predictions.
    Defaults to `0.5`

## Examples

    iex> Axon.Metrics.specificity(Nx.tensor([0, 1, 1, 1]), Nx.tensor([1, 0, 1, 1]))
    #Nx.Tensor<
      f32
      0.0
    >

"""
defn specificity(y_true, y_pred, opts \\ []) do
  opts = keyword!(opts, threshold: 0.5)

  # Reuse the shared counting helpers, as precision/3 and recall/3 do,
  # instead of repeating their thresholding logic here.
  true_negatives = true_negatives(y_true, y_pred, opts)
  false_positives = false_positives(y_true, y_pred, opts)

  # The small constant keeps the denominator non-zero when there
  # are neither false positives nor true negatives.
  Nx.divide(true_negatives, false_positives + true_negatives + 1.0e-16)
end
@doc ~S"""
Calculates the mean absolute error of the predictions with
respect to the targets, averaged over every element.

$$\frac{1}{n} \sum_i |\hat{y_i} - y_i|$$

## Argument Shapes

  * `y_true` - $\(d_0, d_1, ..., d_n\)$
  * `y_pred` - $\(d_0, d_1, ..., d_n\)$

## Examples

    iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
    iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
    iex> Axon.Metrics.mean_absolute_error(y_true, y_pred)
    #Nx.Tensor<
      f32
      0.5
    >

"""
defn mean_absolute_error(y_true, y_pred) do
  absolute_errors = Nx.abs(Nx.subtract(y_true, y_pred))

  Nx.mean(absolute_errors)
end
# Combinators
@doc """
Wraps `metric` in a function that maintains a running average.

The returned function takes the current average, a list of
arguments to apply `metric` to, and the current iteration. It
returns `(avg * i + observation) / (i + 1)`, where the
observation is the result of applying `metric` to the arguments.

## Examples

    iex> cur_avg = 0.5
    iex> iteration = 1
    iex> y_true = Nx.tensor([[0, 1], [1, 0], [1, 0]])
    iex> y_pred = Nx.tensor([[0, 1], [1, 0], [1, 0]])
    iex> avg_acc = Axon.Metrics.running_average(&Axon.Metrics.accuracy/2)
    iex> avg_acc.(cur_avg, [y_true, y_pred], iteration)
    #Nx.Tensor<
      f32
      0.75
    >

"""
def running_average(metric) do
  fn avg, metric_args, i ->
    running_average_impl(avg, apply(metric, metric_args), i)
  end
end

# Folds one new observation into an average built from `i` earlier ones.
defnp running_average_impl(avg, obs, i) do
  weighted_total = Nx.add(Nx.multiply(avg, i), obs)

  Nx.divide(weighted_total, Nx.add(i, 1))
end
@doc """
Wraps `metric` in a function that maintains a running sum.

The returned function takes the current sum, a list of arguments
to apply `metric` to, and the current iteration. It returns the
current sum plus the result of applying `metric` to the
arguments; the iteration is accepted but not used.

## Examples

    iex> cur_sum = 12
    iex> iteration = 2
    iex> y_true = Nx.tensor([0, 1, 0, 1])
    iex> y_pred = Nx.tensor([1, 1, 0, 1])
    iex> fps = Axon.Metrics.running_sum(&Axon.Metrics.false_positives/2)
    iex> fps.(cur_sum, [y_true, y_pred], iteration)
    #Nx.Tensor<
      s64
      13
    >

"""
def running_sum(metric) do
  fn sum, metric_args, i ->
    running_sum_impl(sum, apply(metric, metric_args), i)
  end
end

# Adds the new observation to the accumulated sum; the iteration is ignored.
defnp running_sum_impl(sum, obs, _iteration) do
  sum + obs
end
end