
defmodule Scholar.Cluster.AffinityPropagation do
  @moduledoc """
  Model representing affinity propagation clustering.

  Exemplars (cluster centers) are chosen among the samples themselves by
  iteratively updating responsibility and availability matrices, so the
  number of clusters does not have to be given in advance.

  The first dimension of `:cluster_centers` is set to the number of samples
  in the dataset. The rows of samples that are not exemplars are filled with
  `:infinity` values. To filter them out, use the `prune/1` function.
  """

  import Nx.Defn
  import Scholar.Shared

  # Every field holds a tensor, so all of them are containers that defn traverses.
  @derive {Nx.Container,
           containers: [
             :labels,
             :cluster_centers_indices,
             :affinity_matrix,
             :cluster_centers,
             :num_clusters
           ]}
  defstruct [
    :labels,
    :cluster_centers_indices,
    :affinity_matrix,
    :cluster_centers,
    :num_clusters
  ]

  @opts_schema [
    iterations: [
      type: :pos_integer,
      default: 300,
      doc: "Number of iterations of the algorithm."
    ],
    damping_factor: [
      type: :float,
      default: 0.5,
      doc: """
      Damping factor in the range [0.5, 1.0): the extent to which the
      current value is maintained relative to incoming values (weighted 1 - damping).
      """
    ],
    self_preference: [
      type: {:or, [:float, :boolean, :integer]},
      doc: """
      Self preference: the value written on the diagonal of the similarity matrix.
      If it is not given (or is `false`), the median of the pairwise similarities is used.
      """
    ],
    key: [
      type: {:custom, Scholar.Options, :key, []},
      doc: """
      Determines random number generation for centroid initialization.
      If the key is not provided, it is set to `Nx.Random.key(System.system_time())`.
      """
    ],
    # NOTE(review): this option is validated, but the learning loop in fit_n/3 does not
    # currently read it.
    learning_loop_unroll: [
      type: :boolean,
      default: false,
      doc: ~S"""
      If `true`, the learning loop is unrolled.
      """
    ]
  ]

  @doc """
  Cluster the dataset using affinity propagation.

  ## Options

  #{NimbleOptions.docs(@opts_schema)}

  ## Return Values

  The function returns a struct with the following fields:

    * `:affinity_matrix` - Affinity matrix. It is a negated squared euclidean distance of each pair of points.

    * `:cluster_centers` - Cluster centers from the initial data. Rows of samples that
      are not exemplars are filled with `:infinity`.

    * `:cluster_centers_indices` - Indices of cluster centers. Samples that are not
      exemplars hold `-1`.

    * `:labels` - For each sample, the index of the exemplar it is assigned to
      (exemplars are labeled with their own index).

    * `:num_clusters` - Number of clusters.

  ## Examples

      iex> key = Nx.Random.key(42)
      iex> x = Nx.tensor([[12,5,78,2], [1,-5,7,32], [-1,3,6,1], [1,-2,5,2]])
      iex> Scholar.Cluster.AffinityPropagation.fit(x, key: key)
      %Scholar.Cluster.AffinityPropagation{
        labels: Nx.tensor([0, 3, 3, 3]),
        cluster_centers_indices: Nx.tensor([0, -1, -1, 3]),
        affinity_matrix: Nx.tensor(
          [
            [-0.0, -6162.0, -5358.0, -5499.0],
            [-6162.0, -0.0, -1030.0, -913.0],
            [-5358.0, -1030.0, -0.0, -31.0],
            [-5499.0, -913.0, -31.0, -0.0]
          ]
        ),
        cluster_centers: Nx.tensor(
          [
            [12.0, 5.0, 78.0, 2.0],
            [:infinity, :infinity, :infinity, :infinity],
            [:infinity, :infinity, :infinity, :infinity],
            [1.0, -2.0, 5.0, 2.0]
          ]
        ),
        num_clusters: Nx.tensor(2, type: :u64)
      }
  """
  deftransform fit(data, opts \\ []) do
    opts = NimbleOptions.validate!(opts, @opts_schema)
    # `false` makes initialize_similarities/2 fall back to the median similarity.
    opts = Keyword.put_new(opts, :self_preference, false)
    # The PRNG key is passed positionally, so it does not need to ride along in the
    # (compile-time) defn options.
    {key, opts} = Keyword.pop_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)
    fit_n(data, key, opts)
  end

  defnp fit_n(data, key, opts) do
    data = to_float(data)
    iterations = opts[:iterations]
    damping_factor = opts[:damping_factor]
    self_preference = opts[:self_preference]

    # a = availabilities, r = responsibilities (both zero-initialized),
    # s = similarities with the self preference on the diagonal,
    # affinity_matrix = raw negated squared distances.
    {initial_a, initial_r, s, affinity_matrix} =
      initialize_matrices(data, self_preference: self_preference)

    {n, _} = Nx.shape(initial_a)
    {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {n, n}, type: Nx.type(s))

    # Perturb the similarities with tiny Gaussian noise (scaled by the dtype's smallest
    # normal and epsilon), presumably to remove ties between equally good exemplars.
    s =
      s +
        normal *
          (Nx.Constants.smallest_positive_normal(Nx.type(s)) * 100 +
             Nx.Constants.epsilon(Nx.type(data)) / 10 * s)

    range = Nx.iota({n})

    # NOTE(review): the `:learning_loop_unroll` option accepted by fit/2 is not passed
    # to this loop.
    {{a, r}, _} =
      while {{a = initial_a, r = initial_r}, {s, range, i = 0}},
            i < iterations do
        # Responsibility update:
        # r(i, k) = s(i, k) - max over k' != k of (a(i, k') + s(i, k')).
        temp = a + s
        indices = Nx.argmax(temp, axis: 1)
        y = Nx.reduce_max(temp, axes: [1])

        # Hide each row's maximum to obtain the second largest value y2, which is the
        # competitor for the arg-max column itself.
        neg_inf = Nx.Constants.neg_infinity(to_float_type(a))
        neg_infinities = Nx.broadcast(neg_inf, {n})
        max_indices = Nx.stack([range, indices], axis: 1)
        temp = Nx.indexed_put(temp, max_indices, neg_infinities)
        y2 = Nx.reduce_max(temp, axes: [1])

        temp = s - Nx.new_axis(y, -1)
        temp = Nx.indexed_put(temp, max_indices, Nx.gather(s, max_indices) - y2)
        temp = temp * (1 - damping_factor)
        r = r * damping_factor + temp

        # Availability update. After the column-sum subtraction, temp holds the negated
        # new availabilities; off-diagonal entries are clipped, the diagonal is kept.
        temp = Nx.max(r, 0)
        temp = Nx.put_diagonal(temp, Nx.take_diagonal(r))
        temp = temp - Nx.sum(temp, axes: [0])
        a_change = Nx.take_diagonal(temp)

        temp = Nx.max(temp, 0)
        temp = Nx.put_diagonal(temp, a_change)
        temp = temp * (1 - damping_factor)
        a = a * damping_factor - temp

        {{a, r}, {s, range, i + 1}}
      end

    # A sample is an exemplar when its self-availability plus self-responsibility is positive.
    diagonals = Nx.take_diagonal(a) + Nx.take_diagonal(r) > 0

    k = Nx.sum(diagonals, axes: [0])
    {n, _} = shape = Nx.shape(data)

    {cluster_centers, cluster_centers_indices, labels} =
      if k > 0 do
        mask = diagonals != 0

        # Exemplars keep their own index; every other position holds -1.
        indices =
          Nx.select(mask, Nx.iota(Nx.shape(diagonals)), -1)
          |> Nx.as_type({:s, 64})

        # Rows of non-exemplars are filled with infinity (removed by prune/1).
        cluster_centers =
          Nx.select(
            Nx.broadcast(Nx.new_axis(mask, -1), shape),
            data,
            Nx.Constants.infinity(Nx.type(data))
          )

        # Assign each sample to its most similar exemplar column...
        labels =
          Nx.broadcast(mask, Nx.shape(s))
          |> Nx.select(s, Nx.Constants.neg_infinity(Nx.type(s)))
          |> Nx.argmax(axis: 1)
          |> Nx.as_type({:s, 64})

        # ...and every exemplar to itself.
        labels = Nx.select(mask, Nx.iota(Nx.shape(labels)), labels)

        {cluster_centers, indices, labels}
      else
        # No exemplar found: -1 everywhere, with the same shapes as the other branch.
        {Nx.broadcast(Nx.tensor(-1, type: Nx.type(data)), shape),
         Nx.broadcast(Nx.tensor(-1, type: :s64), {n}),
         Nx.broadcast(Nx.tensor(-1, type: :s64), {n})}
      end

    %__MODULE__{
      affinity_matrix: affinity_matrix,
      cluster_centers_indices: cluster_centers_indices,
      cluster_centers: cluster_centers,
      labels: labels,
      num_clusters: k
    }
  end

  @doc """
  Optionally prune clusters, indices, and labels to only valid entries.

  It returns an updated and pruned model:

    * `:cluster_centers_indices` keeps only the non-negative (exemplar) indices,
    * `:cluster_centers` keeps only the rows of those exemplars,
    * `:labels` are renumbered to the position of their exemplar in the pruned indices.

  If the model has no valid cluster (every index is `-1`), it is returned unchanged.

  ## Examples

      iex> key = Nx.Random.key(42)
      iex> x = Nx.tensor([[12,5,78,2], [1,-5,7,32], [-1,3,6,1], [1,-2,5,2]])
      iex> model = Scholar.Cluster.AffinityPropagation.fit(x, key: key)
      iex> Scholar.Cluster.AffinityPropagation.prune(model)
      %Scholar.Cluster.AffinityPropagation{
        labels: Nx.tensor([0, 1, 1, 1]),
        cluster_centers_indices: Nx.tensor([0, 3]),
        affinity_matrix: Nx.tensor(
          [
            [-0.0, -6162.0, -5358.0, -5499.0],
            [-6162.0, -0.0, -1030.0, -913.0],
            [-5358.0, -1030.0, -0.0, -31.0],
            [-5499.0, -913.0, -31.0, -0.0]
          ]
        ),
        cluster_centers: Nx.tensor(
          [
            [12.0, 5.0, 78.0, 2.0],
            [1.0, -2.0, 5.0, 2.0]
          ]
        ),
        num_clusters: Nx.tensor(2, type: :u64)
      }
  """
  def prune(
        %__MODULE__{
          cluster_centers_indices: cluster_centers_indices,
          cluster_centers: cluster_centers,
          labels: labels
        } = model
      ) do
    # Exemplar indices in their original order; non-exemplar slots hold -1.
    kept =
      cluster_centers_indices
      |> Nx.to_flat_list()
      |> Enum.filter(&(&1 >= 0))

    case kept do
      [] ->
        # No exemplar was found: the labels are all -1 and have no new position, so there
        # is nothing to prune.
        model

      _ ->
        # Original exemplar index => compacted cluster number.
        mapping = kept |> Enum.with_index() |> Map.new()
        kept_indices = Nx.tensor(kept)

        %__MODULE__{
          model
          | cluster_centers_indices: kept_indices,
            cluster_centers: Nx.take(cluster_centers, kept_indices),
            labels:
              labels
              |> Nx.to_flat_list()
              |> Enum.map(&Map.fetch!(mapping, &1))
              |> Nx.tensor()
        }
    end
  end

  @doc """
  Predict the closest cluster each sample in `x` belongs to.

  For every row of `x`, returns the row index in `:cluster_centers` of the nearest
  center by Euclidean distance. On a model returned by `fit/2`, that is the sample
  index of the exemplar, because the infinite rows are never the closest. On a model
  returned by `prune/1`, it is the compacted cluster number. `:cluster_centers` must
  be a 2-D tensor.

  ## Examples

      iex> key = Nx.Random.key(42)
      iex> x = Nx.tensor([[12,5,78,2], [1,5,7,32], [1,3,6,1], [1,2,5,2]])
      iex> model = Scholar.Cluster.AffinityPropagation.fit(x, key: key)
      iex> model = Scholar.Cluster.AffinityPropagation.prune(model)
      iex> Scholar.Cluster.AffinityPropagation.predict(model, Nx.tensor([[1,6,2,6], [8,3,8,2]]))
      #Nx.Tensor<
        s64[2]
        [1, 1]
      >
  """
  defn predict(%__MODULE__{cluster_centers: cluster_centers} = _model, x) do
    {num_clusters, num_features} = Nx.shape(cluster_centers)
    {num_samples, _} = Nx.shape(x)
    broadcast_shape = {num_samples, num_clusters, num_features}

    # Distance from every sample (axis 0) to every center (axis 1). The squared distance
    # has the same argmin as the Euclidean one.
    Scholar.Metrics.Distance.squared_euclidean(
      Nx.new_axis(x, 1) |> Nx.broadcast(broadcast_shape),
      Nx.new_axis(cluster_centers, 0) |> Nx.broadcast(broadcast_shape),
      axes: [-1]
    )
    |> Nx.argmin(axis: 1)
  end

  # Builds the starting state of the message-passing loop:
  # {availabilities, responsibilities, similarities, affinity_matrix}.
  defnp initialize_matrices(data, opts \\ []) do
    {n, _} = Nx.shape(data)
    self_preference = opts[:self_preference]

    {similarity_matrix, affinity_matrix} =
      initialize_similarities(data, self_preference: self_preference)

    # Availabilities and responsibilities both start as n x n zeros in the similarity dtype.
    zero = Nx.tensor(0, type: Nx.type(similarity_matrix))
    availability_matrix = Nx.broadcast(zero, {n, n})
    responsibility_matrix = Nx.broadcast(zero, {n, n})

    {availability_matrix, responsibility_matrix, similarity_matrix, affinity_matrix}
  end

  # Returns {similarities with the self preference on the diagonal,
  # raw negated squared euclidean distances}.
  defnp initialize_similarities(data, opts \\ []) do
    {n, dims} = Nx.shape(data)
    self_preference = opts[:self_preference]

    # Broadcast every pair of rows: t1[i, j] = data[j] and t2[i, j] = data[i].
    t1 = Nx.reshape(data, {1, n, dims}) |> Nx.broadcast({n, n, dims})
    t2 = Nx.reshape(data, {n, 1, dims}) |> Nx.broadcast({n, n, dims})

    dist =
      (-1 * Scholar.Metrics.Distance.squared_euclidean(t1, t2, axes: [-1]))
      |> Nx.as_type(to_float_type(data))

    # Diagonal values: the median similarity by default (`false`), otherwise the given
    # preference broadcast to every sample.
    fill_in =
      cond do
        self_preference == false ->
          Nx.broadcast(Nx.median(dist), {n})

        true ->
          if Nx.size(self_preference) == 1,
            do: Nx.broadcast(self_preference, {n}),
            else: self_preference
      end

    s_modified = dist |> Nx.put_diagonal(fill_in)
    {s_modified, dist}
  end