defmodule Scholar.ModelSelection do
@moduledoc """
Model selection utilities: K-fold data splitting, cross validation, and grid
search over hyperparameters (the latter two also come in sample-weighted variants).
"""
@doc """
Perform K-Fold split on the given data.

The first axis of `x` is divided into `k` contiguous folds of
`div(Nx.axis_size(x, 0), k)` entries each. Trailing entries that do not fill a
complete fold are not part of any split (in the example below the 7th row is
never used). The result is a lazy stream that emits one `{train, test}` tuple
per fold: `test` is the `i`-th fold and `train` is the concatenation of all the
other folds, in their original order.

`x` may be anything accepted by `Nx.tensor/1`. `k` must be an integer between
`2` and the size of the first axis of `x`; otherwise an `ArgumentError` is
raised when the stream is enumerated.

## Examples

    iex> x = Nx.iota({7, 2})
    iex> Scholar.ModelSelection.k_fold_split(x, 2) |> Enum.to_list()
    [
      {
        Nx.tensor(
          [
            [6, 7],
            [8, 9],
            [10, 11]
          ]
        ),
        Nx.tensor(
          [
            [0, 1],
            [2, 3],
            [4, 5]
          ]
        )
      },
      {
        Nx.tensor(
          [
            [0, 1],
            [2, 3],
            [4, 5]
          ]
        ),
        Nx.tensor(
          [
            [6, 7],
            [8, 9],
            [10, 11]
          ]
        )
      }
    ]
"""
def k_fold_split(x, k) do
  Stream.resource(
    fn ->
      # Convert once and carry the tensor in the accumulator: rebinding `x` here
      # is local to this function, so the next function must not rely on the
      # raw (possibly non-tensor) argument captured by its closure.
      x = Nx.tensor(x)
      num_rows = Nx.axis_size(x, 0)

      # k = 1 leaves no training data and k > num_rows gives empty folds.
      if not (is_integer(k) and k >= 2 and k <= num_rows) do
        raise ArgumentError,
              "expected k to be an integer between 2 and the size of the first axis " <>
                "(#{num_rows}), got: #{inspect(k)}"
      end

      fold_size = div(num_rows, k)
      folds = for i <- 0..(k - 1), do: (i * fold_size)..((i + 1) * fold_size - 1)
      {x, folds, 0}
    end,
    fn
      {x, folds, current} when current == k ->
        {:halt, {x, folds, current}}

      {x, folds, current} ->
        {left, [test | right]} = Enum.split(folds, current)

        # Training data is every fold except the current one; no concatenation
        # is needed when the test fold sits at either end.
        train =
          case {left, right} do
            {[], _} ->
              x[concat_ranges(right)]

            {_, []} ->
              x[concat_ranges(left)]

            {[_ | _], [_ | _]} ->
              Nx.concatenate([x[concat_ranges(left)], x[concat_ranges(right)]])
          end

        {[{train, x[test]}], {x, folds, current + 1}}
    end,
    fn _ -> :ok end
  )
end
# Given a non-empty list of contiguous, ascending ranges, builds one range that
# starts where the first range starts and ends where the last range ends.
defp concat_ranges([%Range{first: first} | _] = ranges), do: first..last_last(ranges)

# Returns the upper bound (`last`) of the final range in a non-empty list of ranges.
defp last_last(ranges) do
  %Range{last: last} = List.last(ranges)
  last
end
@doc """
General interface of cross validation.

`folding_fun` is called once with `x` and once with `y`. Each call returns an
enumerable of splits, for example the `{train, test}` tuples emitted by
`k_fold_split/2`. The splits of `x` and `y` are paired up in order (pairing
stops at the shorter enumerable). Each pair is scored with
`scoring_fun.(x_split, y_split)`, which returns a list of scores that is
combined with `Nx.stack/1`.

The per-fold scores are stacked along axis 1. For scalar scores, the result
therefore has shape `{num_scores, num_folds}`.

## Examples

    iex> folding_fun = fn x -> Scholar.ModelSelection.k_fold_split(x, 3) end
    iex> scoring_fun = fn x, y ->
    ...>   {x_train, x_test} = x
    ...>   {y_train, y_test} = y
    ...>   model = Scholar.Linear.LinearRegression.fit(x_train, y_train, fit_intercept?: true)
    ...>   y_pred = Scholar.Linear.LinearRegression.predict(model, x_test)
    ...>   mse = Scholar.Metrics.Regression.mean_square_error(y_test, y_pred)
    ...>   mae = Scholar.Metrics.Regression.mean_absolute_error(y_test, y_pred)
    ...>   [mse, mae]
    ...> end
    iex> x = Nx.iota({7, 2})
    iex> y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])
    iex> Scholar.ModelSelection.cross_validate(x, y, folding_fun, scoring_fun)
    #Nx.Tensor<
      f32[2][3]
      [
        [1.5700000524520874, 1.2149654626846313, 0.004999990575015545],
        [1.100000023841858, 1.0735294818878174, 0.04999995231628418]
      ]
    >
"""
def cross_validate(x, y, folding_fun, scoring_fun)
    when is_function(folding_fun, 1) and is_function(scoring_fun, 2) do
  x_folds = folding_fun.(x)
  y_folds = folding_fun.(y)

  fold_scores =
    for {x_split, y_split} <- Enum.zip(x_folds, y_folds) do
      Nx.stack(scoring_fun.(x_split, y_split))
    end

  Nx.stack(fold_scores, axis: 1)
end
@doc """
General interface of weighted cross validation.

Works like `cross_validate/4`, but `folding_fun` is also applied to `weights`.
The splits of `x`, `y`, and `weights` are paired up in order (pairing stops at
the shortest enumerable). Each triple is scored with
`scoring_fun.(x_split, y_split, weights_split)`, which returns a list of scores
that is combined with `Nx.stack/1`.

The per-fold scores are stacked along axis 1. For scalar scores, the result
therefore has shape `{num_scores, num_folds}`.

## Examples

    iex> folding_fun = fn x -> Scholar.ModelSelection.k_fold_split(x, 3) end
    iex> scoring_fun = fn x, y, weights ->
    ...>   {x_train, x_test} = x
    ...>   {y_train, y_test} = y
    ...>   {weights_train, _weights_test} = weights
    ...>   model = Scholar.Linear.LinearRegression.fit(x_train, y_train, fit_intercept?: true, sample_weights: weights_train)
    ...>   y_pred = Scholar.Linear.LinearRegression.predict(model, x_test)
    ...>   mse = Scholar.Metrics.Regression.mean_square_error(y_test, y_pred)
    ...>   mae = Scholar.Metrics.Regression.mean_absolute_error(y_test, y_pred)
    ...>   [mse, mae]
    ...> end
    iex> x = Nx.iota({7, 2})
    iex> y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])
    iex> weights = Nx.tensor([1, 2, 1, 2, 1, 2, 1])
    iex> Scholar.ModelSelection.weighted_cross_validate(x, y, weights, folding_fun, scoring_fun)
    #Nx.Tensor<
      f32[2][3]
      [
        [0.5010331869125366, 1.1419668197631836, 0.35123956203460693],
        [0.5227273106575012, 1.0526316165924072, 0.5909090042114258]
      ]
    >
"""
def weighted_cross_validate(x, y, weights, folding_fun, scoring_fun)
    when is_function(folding_fun, 1) and is_function(scoring_fun, 3) do
  fold_triples = Enum.zip([folding_fun.(x), folding_fun.(y), folding_fun.(weights)])

  fold_scores =
    for {x_split, y_split, weights_split} <- fold_triples do
      Nx.stack(scoring_fun.(x_split, y_split, weights_split))
    end

  Nx.stack(fold_scores, axis: 1)
end
# Expands a keyword list of `{name, candidate_values}` pairs into the Cartesian
# product of the candidates. Every combination is a keyword list with the keys in
# the same order as `opts`, and the first key varies fastest. An empty `opts`
# yields a single empty combination.
defp combinations(opts) do
  List.foldr(opts, [[]], fn {name, values}, tails ->
    for tail <- tails, value <- values, do: [{name, value} | tail]
  end)
end
@doc """
General interface of grid search.

`opts` must be a keyword list whose values are lists of candidates. Every
combination of candidates (one value per key) is evaluated. For each
combination, `cross_validate/4` is run with a scoring function that calls
`scoring_fun.(x_split, y_split, combination)`.

Returns a list with one map per combination, containing:

  * `:hyperparameters` - the combination, as a keyword list
  * `:score` - the cross-validation scores averaged over the folds

## Examples

    iex> folding_fun = fn x -> Scholar.ModelSelection.k_fold_split(x, 3) end
    iex> scoring_fun = fn x, y, opts ->
    ...>   {x_train, x_test} = x
    ...>   {y_train, y_test} = y
    ...>   model = Scholar.Linear.LogisticRegression.fit(x_train, y_train, opts)
    ...>   y_pred = Scholar.Linear.LogisticRegression.predict(model, x_test)
    ...>   mse = Scholar.Metrics.Regression.mean_square_error(y_test, y_pred)
    ...>   mae = Scholar.Metrics.Regression.mean_absolute_error(y_test, y_pred)
    ...>   [mse, mae]
    ...> end
    iex> x = Nx.iota({7, 2})
    iex> y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])
    iex> opts = [
    ...>   num_classes: [3],
    ...>   iterations: [10, 20, 50],
    ...>   optimizer: [Polaris.Optimizers.adam(learning_rate: 0.005), Polaris.Optimizers.adam(learning_rate: 0.01)],
    ...> ]
    iex> Scholar.ModelSelection.grid_search(x, y, folding_fun, scoring_fun, opts)
"""
def grid_search(x, y, folding_fun, scoring_fun, opts)
    when is_list(opts) and is_function(folding_fun, 1) and is_function(scoring_fun, 3) do
  opts
  |> combinations()
  |> Enum.map(fn param ->
    fold_scoring_fun = fn x_split, y_split -> scoring_fun.(x_split, y_split, param) end
    scores = cross_validate(x, y, folding_fun, fold_scoring_fun)

    %{hyperparameters: param, score: Nx.mean(scores, axes: [1])}
  end)
end
@doc """
General interface of weighted grid search.

`weights` is a list of candidate sample-weight tensors. `opts` is a keyword
list whose values are lists of candidates. Every weight candidate is paired
with every combination of `opts` values. Each pair is evaluated with
`weighted_cross_validate/5`, whose scoring function calls
`scoring_fun.(x_split, y_split, weights_split, combination)`.

Because the combination is passed as the fourth argument, `scoring_fun` can
forward it to the model being fitted, as in the example below.

Returns a list with one map per `{weight, combination}` pair (weights in the
outer loop), containing:

  * `:weights` - the weight candidate used
  * `:hyperparameters` - the combination, as a keyword list
  * `:score` - the cross-validation scores averaged over the folds

## Examples

    iex> folding_fun = fn x -> Scholar.ModelSelection.k_fold_split(x, 3) end
    iex> scoring_fun = fn x, y, weights, opts ->
    ...>   {x_train, x_test} = x
    ...>   {y_train, y_test} = y
    ...>   {weights_train, _weights_test} = weights
    ...>   opts = Keyword.put(opts, :sample_weights, weights_train)
    ...>   model = Scholar.Linear.RidgeRegression.fit(x_train, y_train, opts)
    ...>   y_pred = Scholar.Linear.RidgeRegression.predict(model, x_test)
    ...>   mse = Scholar.Metrics.Regression.mean_square_error(y_test, y_pred)
    ...>   mae = Scholar.Metrics.Regression.mean_absolute_error(y_test, y_pred)
    ...>   [mse, mae]
    ...> end
    iex> x = Nx.iota({7, 2})
    iex> y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])
    iex> weights = [Nx.tensor([1, 2, 1, 2, 1, 2, 1]), Nx.tensor([2, 1, 2, 1, 2, 1, 2])]
    iex> opts = [
    ...>   alpha: [0, 1, 5],
    ...>   fit_intercept?: [true, false],
    ...> ]
    iex> Scholar.ModelSelection.weighted_grid_search(x, y, weights, folding_fun, scoring_fun, opts)
"""
def weighted_grid_search(x, y, weights, folding_fun, scoring_fun, opts)
    when is_list(weights) and is_list(opts) and is_function(folding_fun, 1) and
           is_function(scoring_fun, 4) do
  params = combinations(opts)

  Enum.flat_map(weights, fn weight ->
    Enum.map(params, fn param ->
      fold_scoring_fun = fn x_split, y_split, weights_split ->
        scoring_fun.(x_split, y_split, weights_split, param)
      end

      scores = weighted_cross_validate(x, y, weight, folding_fun, fold_scoring_fun)

      %{weights: weight, hyperparameters: param, score: Nx.mean(scores, axes: [1])}
    end)
  end)
end