lib/pca.ex

defmodule PCA do
  import Nx.Defn

  # this corresponds to aggregate average reconstruction error
  def metric(z, u_truncated_t, x_centered) do
    elem_num = elem(Nx.shape(x_centered), 0) * elem(Nx.shape(x_centered), 1)

    x_reconstructed = Nx.dot(z, Nx.transpose(u_truncated_t))

    x_diff = Nx.subtract(x_centered, x_reconstructed)


    Nx.divide(Nx.sum(x_diff), elem_num)

  end

  # choose only the first gamma entries of U
  def fit(x, gamma) do

    k = elem(Nx.shape(x), 1)

    mu = Nx.mean(x, axes: [1])
    mu_square = Nx.outer(mu, Nx.transpose(mu))

    x_centered = Nx.subtract(Nx.transpose(x), Nx.broadcast(mu, Nx.transpose(x).shape))

    x_square = Nx.dot(x, Nx.transpose(x))


    cov_mat = Nx.subtract(Nx.divide(x_square, (k - 1)), Nx.multiply(mu_square, (k/(k - 1))))

    {u, s, v} = Nx.LinAlg.svd(cov_mat)



    u_t = Nx.transpose(u)


    u_truncated_t = Nx.slice(u_t, [0, 0], [elem(u_t.shape, 0), gamma])

    z = Nx.dot(x_centered, u_truncated_t)

    # ret value:
    {z, u_truncated_t, x_centered}
  end
end

x = Nx.tensor([[1,2],[3,4], [6,4]])
k = 2

{z, u_truncated_t, x_centered} = PCA.fit(x, k)

IO.inspect(z)


# m = PCA.metric(z, u_truncated_t, x_centered)

# IO.inspect(m) # 0 avg aggregate reconstruction error