lib/nx/lin_alg.ex

defmodule Nx.LinAlg do
  @moduledoc """
  Nx conveniences for linear algebra.

  This module can be used in `defn`.
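
  For example, its functions can be invoked from within `defn` (a minimal
  sketch with a hypothetical `normalize/1`, assuming `import Nx.Defn`):

      defn normalize(t) do
        t / Nx.LinAlg.norm(t)
      end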
  """

  import Nx.Shared
  import Nx.Defn, only: [defn: 2, defnp: 2, deftransformp: 2]
  import Nx.Defn.Kernel, only: [keyword!: 2]

  alias Nx.Tensor, as: T

  @doc """
  Returns the adjoint of a given tensor.

  If the input tensor is real, it transposes its two inner-most axes.
  If the input tensor is complex, it additionally applies `Nx.conjugate/1` to it.

  ## Examples

      iex> Nx.LinAlg.adjoint(Nx.tensor([[1, 2], [3, 4]]))
      #Nx.Tensor<
        s64[2][2]
        [
          [1, 3],
          [2, 4]
        ]
      >

      iex> Nx.LinAlg.adjoint(Nx.tensor([[1, Complex.new(0, 2)], [3, Complex.new(0, -4)]]))
      #Nx.Tensor<
        c64[2][2]
        [
          [1.0-0.0i, 3.0-0.0i],
          [0.0-2.0i, 0.0+4.0i]
        ]
      >
  """
  defn adjoint(t) do
    tensor = Nx.to_tensor(t)
    opts = adjoint_opts(tensor.shape)

    case Nx.type(tensor) do
      {:c, _} ->
        tensor |> Nx.transpose(opts) |> Nx.conjugate()

      _ ->
        Nx.transpose(tensor, opts)
    end
  end

  deftransformp adjoint_opts(shape) do
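    # Build a permutation that swaps only the two innermost axes,
    # leaving any leading batch axes in place.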
    rank = tuple_size(shape)

    if rank > 2 do
      axes = Enum.concat(0..(rank - 3), [rank - 1, rank - 2])
      [axes: axes]
    else
      []
    end
  end

  @doc """
  Performs a Cholesky decomposition of a batch of square matrices.

  The matrices must be positive-definite and either Hermitian
  if complex or symmetric if real. An error is raised by the
  default backend if those conditions are not met. Other
  backends may exhibit undefined behaviour.

  ## Examples

      iex> Nx.LinAlg.cholesky(Nx.tensor([[20.0, 17.6], [17.6, 16.0]]))
      #Nx.Tensor<
        f32[2][2]
        [
          [4.4721360206604, 0.0],
          [3.9354796409606934, 0.7155418395996094]
        ]
      >

      iex> Nx.LinAlg.cholesky(Nx.tensor([[[2.0, 3.0], [3.0, 5.0]], [[1.0, 0.0], [0.0, 1.0]]]))
      #Nx.Tensor<
        f32[2][2][2]
        [
          [
            [1.4142135381698608, 0.0],
            [2.1213204860687256, 0.7071064710617065]
          ],
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ]
        ]
      >

      iex> t = Nx.tensor([
      ...>   [6.0, 3.0, 4.0, 8.0],
      ...>   [3.0, 6.0, 5.0, 1.0],
      ...>   [4.0, 5.0, 10.0, 7.0],
      ...>   [8.0, 1.0, 7.0, 25.0]
      ...> ])
      iex> Nx.LinAlg.cholesky(t)
      #Nx.Tensor<
        f32[4][4]
        [
          [2.4494898319244385, 0.0, 0.0, 0.0],
          [1.2247447967529297, 2.1213202476501465, 0.0, 0.0],
          [1.6329931020736694, 1.41421377658844, 2.309401035308838, 0.0],
          [3.265986204147339, -1.4142134189605713, 1.5877134799957275, 3.132491111755371]
        ]
      >

      iex> Nx.LinAlg.cholesky(Nx.tensor([[1.0, Complex.new(0, -2)], [Complex.new(0, 2), 5.0]]))
      #Nx.Tensor<
        c64[2][2]
        [
          [1.0+0.0i, 0.0+0.0i],
          [0.0+2.0i, 1.0+0.0i]
        ]
      >

      iex> t = Nx.tensor([[[2.0, 3.0], [3.0, 5.0]], [[1.0, 0.0], [0.0, 1.0]]]) |> Nx.vectorize(x: 2)
      iex> Nx.LinAlg.cholesky(t)
      #Nx.Tensor<
        vectorized[x: 2]
        f32[2][2]
        [
          [
            [1.4142135381698608, 0.0],
            [2.1213204860687256, 0.7071064710617065]
          ],
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ]
        ]
      >
  """
  def cholesky(tensor) do
    %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor)

    %T{type: type, shape: shape, names: names} =
      tensor = Nx.devectorize(tensor, keep_names: false)

    output_type = Nx.Type.to_floating(type)

    {output_shape, output_names} = Nx.Shape.cholesky(shape, names)

    out = %{tensor | type: output_type, shape: output_shape, names: output_names}

    :cholesky
    |> Nx.Shared.optional([tensor], out, &Nx.LinAlg.Cholesky.cholesky/1)
    |> Nx.vectorize(vectorized_axes)
  end

  @doc """
  Calculates the p-norm of a tensor.

  For the 0-norm, the norm is the number of non-zero elements in the tensor.

  ## Options

    * `:axes` - defines the axes upon which the norm will be calculated.
      Applies only to the 2-norm of 2-D tensors. Defaults to `nil`.
    * `:keep_axes` - whether the calculation axes should be kept with
      length 1. Defaults to `false`.
    * `:ord` - defines which norm will be calculated according to the table below. Defaults to `nil`.

  | ord          | 2-D                            | 1-D                               |
  | ------------ | -------------------------------| --------------------------------- |
  | `nil`        | Frobenius norm                 | 2-norm                            |
  | `:nuclear`   | Nuclear norm                   | -                                 |
  | `:frobenius` | Frobenius norm                 | -                                 |
  | `:inf`       | `max(sum(abs(x), axes: [1]))`  | `max(abs(x))`                     |
  | `:neg_inf`   | `min(sum(abs(x), axes: [1]))`  | `min(abs(x))`                     |
  | 0            | -                              | Number of non-zero elements       |
  | 1            | `max(sum(abs(x), axes: [0]))`  | as below                          |
  | -1           | `min(sum(abs(x), axes: [0]))`  | as below                          |
  | 2            | 2-norm                         | as below                          |
  | -2           | smallest singular value        | as below                          |
  | other        | -                              | `pow(sum(pow(abs(x), ord)), 1/ord)` |

  ## Examples

  ### Vector norms

      iex> Nx.LinAlg.norm(Nx.tensor([3, 4]))
      #Nx.Tensor<
        f32
        5.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([3, 4]), ord: 1)
      #Nx.Tensor<
        f32
        7.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([3, -4]), ord: :inf)
      #Nx.Tensor<
        f32
        4.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([3, -4]), ord: :neg_inf)
      #Nx.Tensor<
        f32
        3.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([3, -4, 0, 0]), ord: 0)
      #Nx.Tensor<
        f32
        2.0
      >

  ### Matrix norms

      iex> Nx.LinAlg.norm(Nx.tensor([[3, -1], [2, -4]]), ord: -1)
      #Nx.Tensor<
        f32
        5.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[3, -2], [2, -4]]), ord: 1)
      #Nx.Tensor<
        f32
        6.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[3, -2], [2, -4]]), ord: :neg_inf)
      #Nx.Tensor<
        f32
        5.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[3, -2], [2, -4]]), ord: :inf)
      #Nx.Tensor<
        f32
        6.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[3, 0], [0, -4]]), ord: :frobenius)
      #Nx.Tensor<
        f32
        5.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[1, 0, 0], [0, -4, 0], [0, 0, 9]]), ord: :nuclear)
      #Nx.Tensor<
        f32
        14.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[1, 0, 0], [0, -4, 0], [0, 0, 9]]), ord: -2)
      #Nx.Tensor<
        f32
        1.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[3, 0], [0, -4]]))
      #Nx.Tensor<
        f32
        5.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[3, 4], [0, -4]]), axes: [1])
      #Nx.Tensor<
        f32[2]
        [5.0, 4.0]
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[Complex.new(0, 3), 4], [4, 0]]), axes: [0])
      #Nx.Tensor<
        f32[2]
        [5.0, 4.0]
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[Complex.new(0, 3), 0], [4, 0]]), ord: :neg_inf)
      #Nx.Tensor<
        f32
        3.0
      >

      iex> Nx.LinAlg.norm(Nx.tensor([[0, 0], [0, 0]]))
      #Nx.Tensor<
        f32
        0.0
      >

  ## Error cases

      iex> Nx.LinAlg.norm(Nx.tensor([3, 4]), ord: :frobenius)
      ** (ArgumentError) expected a 2-D tensor for ord: :frobenius, got a 1-D tensor
  """
  @doc from_backend: false
  defn norm(tensor, opts \\ []) do
    opts = keyword!(opts, [:ord, :axes, :keep_axes])
    norm_transform(tensor, opts)
  end

  deftransformp norm_transform(t, opts) do
    rank = Nx.rank(t)

    unless rank == 1 or rank == 2 do
      raise ArgumentError, "expected 1-D or 2-D tensor, got tensor with shape #{inspect(t.shape)}"
    end

    axes_opts = Keyword.take(opts, [:axes, :keep_axes])

    case opts[:ord] do
      nil when rank == 1 -> norm_integer(t, 2, axes_opts)
      nil when rank == 2 -> norm_integer(t, 2, axes_opts)
      :frobenius -> norm_frobenius(t, axes_opts)
      :nuclear when rank == 2 -> norm_nuclear(t)
      :nuclear -> raise ArgumentError, "nuclear norm not supported for rank != 2"
      ord when ord in [:inf, :neg_inf] -> norm_inf(t, ord, axes_opts)
      ord when is_integer(ord) -> norm_integer(t, ord, axes_opts)
      ord -> raise ArgumentError, "unknown ord #{inspect(ord)}"
    end
  end

  defp norm_frobenius(%{shape: {_}}, _opts),
    do: raise(ArgumentError, "expected a 2-D tensor for ord: :frobenius, got a 1-D tensor")

  defp norm_frobenius(%{shape: {_, _}} = t, opts), do: norm_integer(t, 2, opts)

  defp norm_nuclear(%{shape: {_, _}} = t) do
    {_u, s, _v} = svd(t)
    Nx.sum(s)
  end

  defp norm_inf(%{shape: shape, type: type} = t, ord, opts) when ord in [:inf, :neg_inf] do
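    # Per the table in norm/2's docs: for 2-D tensors, reduce the row sums of
    # absolute values; for 1-D tensors, reduce the absolute values directly.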
    output_type = Nx.Type.to_real(type)
    aggregate_axes = if tuple_size(shape) == 2, do: &Nx.sum(&1, axes: [1]), else: & &1

    reduce =
      if ord == :inf,
        do: &Nx.reduce_max(&1, opts),
        else: &Nx.reduce_min(&1, opts)

    t
    |> Nx.abs()
    |> aggregate_axes.()
    |> reduce.()
    |> Nx.as_type(output_type)
  end

  defp norm_integer(%{shape: {_}, type: type} = t, 0, opts) do
    output_type = Nx.Type.to_real(type)

    t
    |> Nx.not_equal(0)
    |> Nx.sum(opts)
    |> Nx.as_type(output_type)
  end

  defp norm_integer(%{shape: {_, _}, type: type} = t, ord, opts) when ord in [1, -1] do
    output_type = Nx.Type.to_real(type)
    function = if ord == 1, do: &Nx.reduce_max(&1, opts), else: &Nx.reduce_min(&1, opts)

    t
    |> Nx.abs()
    |> Nx.sum(axes: [0])
    |> function.()
    |> Nx.as_type(output_type)
  end

  defp norm_integer(%{shape: {_, _}}, ord, _opts) when ord not in [-2, -1, 1, 2] do
    raise ArgumentError, "invalid :ord for 2-D tensor, got: #{inspect(ord)}"
  end

  defp norm_integer(%{shape: {_, _}} = t, -2, opts) do
    {_u, s, _v} = Nx.LinAlg.svd(t)
    Nx.reduce_min(s, opts)
  end

  defp norm_integer(%{type: type} = t, ord, opts) when is_integer(ord) do
    output_type = Nx.Type.to_real(type)
    inv_ord = Nx.tensor(1 / ord, type: output_type)

    # We extract this result to a variable because it's used both for
    # getting the normalization coefficient and for the main pipe chain
    abs_t = Nx.abs(t)

    # This coefficient is introduced for better numerical stability
    # The idea is that by dividing the tensor by it, large values of
    # tensor entries and large values of p are reduced, which in turn
    # avoids numerical overflow.
    numerical_stability_coefficient = Nx.reduce_max(abs_t, opts)

    # This code prevents from division by zero.
    numerical_stability_coefficient =
      Nx.select(
        Nx.greater(numerical_stability_coefficient, 0),
        numerical_stability_coefficient,
        1
      )

    abs_t
    |> Nx.divide(numerical_stability_coefficient)
    |> Nx.pow(ord)
    |> Nx.sum(opts)
    |> Nx.pow(inv_ord)
    |> Nx.multiply(numerical_stability_coefficient)
  end

  @doc """
  Solves the equation `a x = b` for `x`, assuming `a` is a batch of triangular matrices.
  Can also solve `x a = b` for `x`. See the `:left_side` option below.

  `b` must either be a batch of matrices with dimensions compatible with `a` or a batch
  of 1-D tensors with as many rows as `a`. Batch dimensions of `a` and `b` must be the same.

  ## Options

  The following options are defined in order of precedence:

  * `:transform_a` - Defines `op(a)`, depending on its value. Can be one of:
    * `:none` -> `op(a) = a`
    * `:transpose` -> `op(a) = transpose(a)`

    Defaults to `:none`.
  * `:lower` - When `true`, defines the `a` matrix as lower triangular.
    If `false`, `a` is upper triangular. Defaults to `true`.
  * `:left_side` - When `true`, solves the system as `op(A).X = B`.
    Otherwise, solves `X.op(A) = B`. Defaults to `true`.

  ## Examples

      iex> a = Nx.tensor([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
      iex> Nx.LinAlg.triangular_solve(a, Nx.tensor([4, 2, 4, 2]))
      #Nx.Tensor<
        f32[4]
        [1.3333333730697632, -0.6666666865348816, 2.6666667461395264, -1.3333333730697632]
      >

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]], type: :f64)
      iex> Nx.LinAlg.triangular_solve(a, Nx.tensor([1, 2, 1]))
      #Nx.Tensor<
        f64[3]
        [1.0, 1.0, -1.0]
      >

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [0, 1, 1]])
      iex> b = Nx.tensor([[1, 2, 3], [2, 2, 4], [2, 0, 1]])
      iex> Nx.LinAlg.triangular_solve(a, b)
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 2.0, 3.0],
          [1.0, 0.0, 1.0],
          [1.0, 0.0, 0.0]
        ]
      >

      iex> a = Nx.tensor([[1, 1, 1, 1], [0, 1, 0, 1], [0, 0, 1, 2], [0, 0, 0, 3]])
      iex> Nx.LinAlg.triangular_solve(a, Nx.tensor([2, 4, 2, 4]), lower: false)
      #Nx.Tensor<
        f32[4]
        [-1.3333333730697632, 2.6666667461395264, -0.6666666865348816, 1.3333333730697632]
      >

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [1, 2, 1]])
      iex> b = Nx.tensor([[0, 2, 1], [1, 1, 0], [3, 3, 1]])
      iex> Nx.LinAlg.triangular_solve(a, b, left_side: false)
      #Nx.Tensor<
        f32[3][3]
        [
          [-1.0, 0.0, 1.0],
          [0.0, 1.0, 0.0],
          [1.0, 1.0, 1.0]
        ]
      >

      iex> a = Nx.tensor([[1, 1, 1], [0, 1, 1], [0, 0, 1]], type: :f64)
      iex> Nx.LinAlg.triangular_solve(a, Nx.tensor([1, 2, 1]), transform_a: :transpose, lower: false)
      #Nx.Tensor<
        f64[3]
        [1.0, 1.0, -1.0]
      >

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]], type: :f64)
      iex> Nx.LinAlg.triangular_solve(a, Nx.tensor([1, 2, 1]), transform_a: :none)
      #Nx.Tensor<
        f64[3]
        [1.0, 1.0, -1.0]
      >

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [1, 2, 1]])
      iex> b = Nx.tensor([[0, 1, 3], [2, 1, 3]])
      iex> Nx.LinAlg.triangular_solve(a, b, left_side: false)
      #Nx.Tensor<
        f32[2][3]
        [
          [2.0, -5.0, 3.0],
          [4.0, -5.0, 3.0]
        ]
      >

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [1, 2, 1]])
      iex> b = Nx.tensor([[0, 2], [3, 0], [0, 0]])
      iex> Nx.LinAlg.triangular_solve(a, b, left_side: true)
      #Nx.Tensor<
        f32[3][2]
        [
          [0.0, 2.0],
          [3.0, -2.0],
          [-6.0, 2.0]
        ]
      >

      iex> a = Nx.tensor([
      ...> [1, 0, 0],
      ...> [1, Complex.new(0, 1), 0],
      ...> [Complex.new(0, 1), 1, 1]
      ...>])
      iex> b = Nx.tensor([1, -1, Complex.new(3, 3)])
      iex> Nx.LinAlg.triangular_solve(a, b)
      #Nx.Tensor<
        c64[3]
        [1.0+0.0i, 0.0+2.0i, 3.0+0.0i]
      >

      iex> a = Nx.tensor([[[1, 0], [2, 3]], [[4, 0], [5, 6]]])
      iex> b = Nx.tensor([[2, -1], [3, 7]])
      iex> Nx.LinAlg.triangular_solve(a, b)
      #Nx.Tensor<
        f32[2][2]
        [
          [2.0, -1.6666666269302368],
          [0.75, 0.5416666865348816]
        ]
      >

      iex> a = Nx.tensor([[[1, 1], [0, 1]], [[2, 0], [0, 2]]]) |> Nx.vectorize(x: 2)
      iex> b = Nx.tensor([[[2, 1], [5, -1]]]) |> Nx.vectorize(x: 1, y: 2)
      iex> Nx.LinAlg.triangular_solve(a, b, lower: false)
      #Nx.Tensor<
        vectorized[x: 2][y: 2]
        f32[2]
        [
          [
            [1.0, 1.0],
            [6.0, -1.0]
          ],
          [
            [1.0, 0.5],
            [2.5, -0.5]
          ]
        ]
      >

  ## Error cases

      iex> Nx.LinAlg.triangular_solve(Nx.tensor([[3, 0, 0, 0], [2, 1, 0, 0]]), Nx.tensor([4, 2, 4, 2]))
      ** (ArgumentError) triangular_solve/3 expected a square matrix or a batch of square matrices, got tensor with shape: {2, 4}

      iex> Nx.LinAlg.triangular_solve(Nx.tensor([[3, 0, 0, 0], [2, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]), Nx.tensor([4]))
      ** (ArgumentError) incompatible dimensions for a and b on triangular solve

      iex> Nx.LinAlg.triangular_solve(Nx.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1]]), Nx.tensor([4, 2, 4, 2]))
      ** (ArgumentError) can't solve for singular matrix

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]], type: :f64)
      iex> Nx.LinAlg.triangular_solve(a, Nx.tensor([1, 2, 1]), transform_a: :conjugate)
      ** (ArgumentError) complex numbers not supported yet

      iex> a = Nx.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]], type: :f64)
      iex> Nx.LinAlg.triangular_solve(a, Nx.tensor([1, 2, 1]), transform_a: :other)
      ** (ArgumentError) invalid value for :transform_a option, expected :none, :transpose, or :conjugate, got: :other

  """
  def triangular_solve(a, b, opts \\ []) do
    opts = keyword!(opts, lower: true, left_side: true, transform_a: :none)

    case opts[:transform_a] do
      t when t in [:none, :transpose] ->
        nil

      :conjugate ->
        raise ArgumentError, "complex numbers not supported yet"

      t ->
        raise ArgumentError,
              "invalid value for :transform_a option, expected :none, :transpose, or :conjugate, " <>
                "got: #{inspect(t)}"
    end

    [%T{vectorized_axes: vectorized_axes, shape: a_shape} = a, %T{shape: b_shape} = b] =
      Nx.broadcast_vectors([a, b])

    :ok = Nx.Shape.triangular_solve(a_shape, b_shape, opts[:left_side])
    output_type = binary_type(a, b) |> Nx.Type.to_floating()

    a = Nx.devectorize(a)
    b = Nx.devectorize(b)

    result = impl!(a, b).triangular_solve(%{b | type: output_type}, a, b, opts)

    Nx.vectorize(result, vectorized_axes)
  end

  @doc """
  Solves the system `AX = B`.

  `A` must have shape `{..., n, n}` and `B` must have shape `{..., n, m}` or `{..., n}`.
  `X` has the same shape as `B`.

  ## Examples

      iex> a = Nx.tensor([[1, 3, 2, 1], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
      iex> Nx.LinAlg.solve(a, Nx.tensor([-3, 0, 4, -2])) |> Nx.round()
      #Nx.Tensor<
        f32[4]
        [1.0, -2.0, 3.0, -4.0]
      >

      iex> a = Nx.tensor([[1, 0, 1], [1, 1, 0], [1, 1, 1]], type: :f64)
      iex> Nx.LinAlg.solve(a, Nx.tensor([0, 2, 1])) |> Nx.round()
      #Nx.Tensor<
        f64[3]
        [1.0, 1.0, -1.0]
      >

      iex> a = Nx.tensor([[1, 0, 1], [1, 1, 0], [0, 1, 1]])
      iex> b = Nx.tensor([[2, 2, 3], [2, 2, 4], [2, 0, 1]])
      iex> Nx.LinAlg.solve(a, b) |> Nx.round()
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 2.0, 3.0],
          [1.0, 0.0, 1.0],
          [1.0, 0.0, 0.0]
        ]
      >

      iex> a = Nx.tensor([[[14, 10], [9, 9]], [[4, 11], [2, 3]]])
      iex> b = Nx.tensor([[[2, 4], [3, 2]], [[1, 5], [-3, -1]]])
      iex> Nx.LinAlg.solve(a, b) |> Nx.round()
      #Nx.Tensor<
        f32[2][2][2]
        [
          [
            [0.0, 0.0],
            [1.0, 0.0]
          ],
          [
            [-4.0, -3.0],
            [1.0, 1.0]
          ]
        ]
      >

      iex> a = Nx.tensor([[[1, 1], [0, 1]], [[2, 0], [0, 2]]]) |> Nx.vectorize(x: 2)
      iex> b = Nx.tensor([[[2, 1], [5, -1]]]) |> Nx.vectorize(x: 1, y: 2)
      iex> Nx.LinAlg.solve(a, b)
      #Nx.Tensor<
        vectorized[x: 2][y: 2]
        f32[2]
        [
          [
            [1.0, 1.0],
            [6.0, -1.0]
          ],
          [
            [1.0, 0.5],
            [2.5, -0.5]
          ]
        ]
      >

  If the axes are named, their names are not preserved in the output:

      iex> a = Nx.tensor([[1, 0, 1], [1, 1, 0], [1, 1, 1]], names: [:x, :y])
      iex> Nx.LinAlg.solve(a, Nx.tensor([0, 2, 1], names: [:z])) |> Nx.round()
      #Nx.Tensor<
        f32[3]
        [1.0, 1.0, -1.0]
      >

  ## Error cases

      iex> Nx.LinAlg.solve(Nx.tensor([[1, 0], [0, 1]]), Nx.tensor([4, 2, 4, 2]))
      ** (ArgumentError) `b` tensor has incompatible dimensions, expected {2, 2} or {2}, got: {4}

      iex> Nx.LinAlg.solve(Nx.tensor([[3, 0, 0, 0], [2, 1, 0, 0], [1, 1, 1, 1]]), Nx.tensor([4]))
      ** (ArgumentError) `a` tensor has incompatible dimensions, expected a square matrix or a batch of square matrices, got: {3, 4}
  """
  # IMPORTANT: This function cannot be a defn because
  # optional needs to work on the actual backend.
  @doc from_backend: false
  def solve(a, b) do
    [%T{vectorized_axes: vectorized_axes} = a, b] = Nx.broadcast_vectors([a, b])

    a = Nx.devectorize(a)
    b = Nx.devectorize(b)

    %T{shape: a_shape, type: a_type} = a
    %T{shape: b_shape, type: b_type} = b

    output_shape = Nx.Shape.solve(a_shape, b_shape)
    output_type = a_type |> Nx.Type.merge(b_type) |> Nx.Type.to_floating()
    output = Nx.template(output_shape, output_type)

    result =
      Nx.Shared.optional(:solve, [a, b], output, fn a, b ->
        # Since we have triangular solve, which accepts upper
        # triangular matrices with the `lower: false` option,
        # we can solve a system as follows:

        # A.X = B -> QR.X = B -> R.X = adjoint(Q).B

        {q, r} = Nx.LinAlg.qr(a)
        q_rank = Nx.rank(q)
        batches = Enum.to_list(0..(q_rank - 3)//1)
        qb = Nx.dot(adjoint(q), [q_rank - 1], batches, b, [q_rank - 2], batches)
        triangular_solve(r, qb, lower: false)
      end)

    Nx.vectorize(result, vectorized_axes)
  end

  @doc """
  Inverts a batch of square matrices.

  For non-square matrices, use `pinv/2` for pseudo-inverse calculations.

  ## Examples

      iex> a = Nx.tensor([[1, 2, 1, 1], [0, 1, 0, 1], [0, 0, 1, 1], [0 , 0, 0, 1]])
      iex> a_inv = Nx.LinAlg.invert(a)
      #Nx.Tensor<
        f32[4][4]
        [
          [1.0, -2.0, -1.0, 2.0],
          [0.0, 1.0, 0.0, -1.0],
          [0.0, 0.0, 1.0, -1.0],
          [0.0, 0.0, 0.0, 1.0]
        ]
      >
      iex> Nx.dot(a, a_inv)
      #Nx.Tensor<
        f32[4][4]
        [
          [1.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 0.0, 0.0],
          [0.0, 0.0, 1.0, 0.0],
          [0.0, 0.0, 0.0, 1.0]
        ]
      >
      iex> Nx.dot(a_inv, a)
      #Nx.Tensor<
        f32[4][4]
        [
          [1.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 0.0, 0.0],
          [0.0, 0.0, 1.0, 0.0],
          [0.0, 0.0, 0.0, 1.0]
        ]
      >

      iex> a = Nx.tensor([[[1, 2], [0, 1]], [[1, 1], [0, 1]]])
      iex> a_inv = Nx.LinAlg.invert(a)
      #Nx.Tensor<
        f32[2][2][2]
        [
          [
            [1.0, -2.0],
            [0.0, 1.0]
          ],
          [
            [1.0, -1.0],
            [0.0, 1.0]
          ]
        ]
      >
      iex> Nx.dot(a, [2], [0], a_inv, [1], [0])
      #Nx.Tensor<
        f32[2][2][2]
        [
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ],
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ]
        ]
      >
      iex> Nx.dot(a_inv, [2], [0], a, [1], [0])
      #Nx.Tensor<
        f32[2][2][2]
        [
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ],
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ]
        ]
      >

  If a singular matrix is passed, the inversion fails silently, returning NaNs.

      iex> Nx.LinAlg.invert(Nx.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1]]))
      #Nx.Tensor<
        f32[4][4]
        [
          [NaN, NaN, NaN, NaN],
          [NaN, NaN, NaN, NaN],
          [NaN, NaN, NaN, NaN],
          [NaN, NaN, NaN, NaN]
        ]
      >

  ## Error cases

      iex> Nx.LinAlg.invert(Nx.tensor([[3, 0, 0, 0], [2, 1, 0, 0]]))
      ** (ArgumentError) invert/1 expects a square matrix or a batch of square matrices, got tensor with shape: {2, 4}

  """
  @doc from_backend: false
  defn invert(tensor) do
    ans =
      tensor
      |> invert_shape()
      |> invert_tensor()

    custom_grad(ans, [tensor], fn g ->
      # As defined in https://juliadiff.org/ChainRulesCore.jl/stable/maths/arrays.html#Matrix-inversion-2
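      # If B = A^-1, then dB = -B . dA . B, so the VJP with cotangent g
      # is -B^H . g . B^H (B^H being the conjugate transpose of B).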
      ans_h = adjoint(ans)
      [ans_h |> Nx.negate() |> Nx.dot(g) |> Nx.dot(ans_h)]
    end)
  end

  defnp invert_tensor(tensor) do
    m = Nx.axis_size(tensor, -2)
    n = Nx.axis_size(tensor, -1)
    vectorized_axes = tensor.vectorized_axes
    input_shape = Nx.shape(tensor)

    # Proof of equivalence for the row scaling applied below:
    # let S be the diagonal matrix of per-row scaling factors.
    # Solving (S.A).X = S yields X = (S.A)^-1 . S = A^-1 . S^-1 . S = A^-1,
    # so X is exactly the inverse of A.

    tensor =
      case input_shape do
        {_, _} ->
          # this avoids the need for creating a new
          # vectorized axis to collapse batched axes,
          # because we have no batch axes
          tensor

        _ ->
          Nx.revectorize(tensor, [collapsed_batch: :auto], target_shape: {m, n})
      end

    det = Nx.LinAlg.determinant(tensor)

    type = Nx.Type.to_real(Nx.type(tensor))
    eps = Nx.Constants.smallest_positive_normal(type) * 1.0e3

    inverse =
      if Nx.abs(det) <= eps do
        Nx.tensor(:nan)
      else
        # matrix is possibly invertible but ill-conditioned
        # we normalize it by the determinant

        scaling_matrix = Nx.reduce_max(Nx.abs(tensor), axes: [1], keep_axes: true)
        # don't rescale for 0-norm rows
        scaling_matrix = 1 / Nx.select(scaling_matrix == 0, 1, scaling_matrix)

        # We can think of the implementation as a system of equations.
        # Since we're scaling the left side by scaling_matrix[i] for each row i,
        # we need to also scale the right side.
        # This is achieved by scaling each row of an identity matrix, which is,
        # in fact, the same as putting the scaling values in the diagonal of the
        # right-side matrix.

        normalized_tensor = scaling_matrix * tensor

        Nx.LinAlg.solve(
          normalized_tensor,
          Nx.make_diagonal(Nx.squeeze(scaling_matrix, axes: [1]))
        )
      end

    Nx.revectorize(inverse, vectorized_axes, target_shape: input_shape)
  end

  deftransformp invert_shape(tensor) do
    shape = Nx.shape(tensor)

    shape
    |> Tuple.to_list()
    |> Enum.split(-2)
    |> case do
      {_, [n, n]} ->
        tensor

      _ ->
        raise ArgumentError,
              "invert/1 expects a square matrix or a batch of square matrices, got tensor with shape: #{inspect(shape)}"
    end
  end

  @doc """
  Calculates the QR decomposition of a tensor with shape `{..., M, N}`.

  ## Options

    * `:mode` - Can be one of `:reduced`, `:complete`. Defaults to `:reduced`.
      For the following, `K = min(M, N)`

      * `:reduced` - returns `q` and `r` with shapes `{..., M, K}` and `{..., K, N}`
      * `:complete` - returns `q` and `r` with shapes `{..., M, M}` and `{..., M, N}`

    * `:eps` - Rounding error threshold that can be applied during the triangularization. Defaults to `1.0e-10`.

  ## Examples

      iex> {q, r} = Nx.LinAlg.qr(Nx.tensor([[-3, 2, 1], [0, 1, 1], [0, 0, -1]]))
      iex> q
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, 1.0]
        ]
      >
      iex> r
      #Nx.Tensor<
        f32[3][3]
        [
          [-3.0, 2.0, 1.0],
          [0.0, 1.0, 1.0],
          [0.0, 0.0, -1.0]
        ]
      >

      iex> t = Nx.tensor([[3, 2, 1], [0, 1, 1], [0, 0, 1]])
      iex> {q, r} = Nx.LinAlg.qr(t)
      iex> q
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, 1.0]
        ]
      >
      iex> r
      #Nx.Tensor<
        f32[3][3]
        [
          [3.0, 2.0, 1.0],
          [0.0, 1.0, 1.0],
          [0.0, 0.0, 1.0]
        ]
      >

      iex> {qs, rs} = Nx.LinAlg.qr(Nx.tensor([[[-3, 2, 1], [0, 1, 1], [0, 0, -1]],[[3, 2, 1], [0, 1, 1], [0, 0, 1]]]))
      iex> qs
      #Nx.Tensor<
        f32[2][3][3]
        [
          [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]
          ],
          [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]
          ]
        ]
      >
      iex> rs
      #Nx.Tensor<
        f32[2][3][3]
        [
          [
            [-3.0, 2.0, 1.0],
            [0.0, 1.0, 1.0],
            [0.0, 0.0, -1.0]
          ],
          [
            [3.0, 2.0, 1.0],
            [0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0]
          ]
        ]
      >

      iex> t = Nx.tensor([[3, 2, 1], [0, 1, 1], [0, 0, 1], [0, 0, 1]], type: :f32)
      iex> {q, r} = Nx.LinAlg.qr(t, mode: :reduced)
      iex> q
      #Nx.Tensor<
        f32[4][3]
        [
          [1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, 0.7071068286895752],
          [0.0, 0.0, 0.7071067690849304]
        ]
      >
      iex> r
      #Nx.Tensor<
        f32[3][3]
        [
          [3.0, 2.0, 1.0],
          [0.0, 1.0, 1.0],
          [0.0, 0.0, 1.4142136573791504]
        ]
      >

      iex> t = Nx.tensor([[3, 2, 1], [0, 1, 1], [0, 0, 1], [0, 0, 0]], type: :f32)
      iex> {q, r} = Nx.LinAlg.qr(t, mode: :complete)
      iex> q
      #Nx.Tensor<
        f32[4][4]
        [
          [1.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 0.0, 0.0],
          [0.0, 0.0, 1.0, 0.0],
          [0.0, 0.0, 0.0, 1.0]
        ]
      >
      iex> r
      #Nx.Tensor<
        f32[4][3]
        [
          [3.0, 2.0, 1.0],
          [0.0, 1.0, 1.0],
          [0.0, 0.0, 1.0],
          [0.0, 0.0, 0.0]
        ]
      >

      iex> t = Nx.tensor([[[-3, 2, 1], [0, 1, 1], [0, 0, -1]],[[3, 2, 1], [0, 1, 1], [0, 0, 1]]]) |> Nx.vectorize(x: 2)
      iex> {qs, rs} = Nx.LinAlg.qr(t)
      iex> qs
      #Nx.Tensor<
        vectorized[x: 2]
        f32[3][3]
        [
          [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]
          ],
          [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]
          ]
        ]
      >
      iex> rs
      #Nx.Tensor<
        vectorized[x: 2]
        f32[3][3]
        [
          [
            [-3.0, 2.0, 1.0],
            [0.0, 1.0, 1.0],
            [0.0, 0.0, -1.0]
          ],
          [
            [3.0, 2.0, 1.0],
            [0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0]
          ]
        ]
      >

  ## Error cases

      iex> Nx.LinAlg.qr(Nx.tensor([1, 2, 3, 4, 5]))
      ** (ArgumentError) tensor must have at least rank 2, got rank 1 with shape {5}

      iex> t = Nx.tensor([[-3, 2, 1], [0, 1, 1], [0, 0, -1]])
      iex> Nx.LinAlg.qr(t, mode: :error_test)
      ** (ArgumentError) invalid :mode received. Expected one of [:reduced, :complete], received: :error_test
  """
  def qr(tensor, opts \\ []) do
    opts = keyword!(opts, mode: :reduced, eps: 1.0e-10)
    %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor)

    %T{type: type, shape: shape} = tensor = Nx.devectorize(tensor)

    mode = opts[:mode]
    valid_modes = [:reduced, :complete]

    unless mode in valid_modes do
      raise ArgumentError,
            "invalid :mode received. Expected one of #{inspect(valid_modes)}, received: #{inspect(mode)}"
    end

    output_type = Nx.Type.to_floating(type)
    {q_shape, r_shape} = Nx.Shape.qr(shape, opts)

    output =
      {%{
         tensor
         | type: output_type,
           shape: q_shape,
           names: List.duplicate(nil, tuple_size(q_shape))
       },
       %{
         tensor
         | type: output_type,
           shape: r_shape,
           names: List.duplicate(nil, tuple_size(r_shape))
       }}

    :qr
    |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.QR.qr/2)
    |> Nx.vectorize(vectorized_axes)
  end

  @doc """
  Calculates the Moore-Penrose inverse, or the pseudoinverse, of a matrix.

  ## Options

    * `:eps` - Rounding error threshold below which values are assumed to be 0. Defaults to `1.0e-10`.

  ## Examples

  Scalar case:

      iex> Nx.LinAlg.pinv(2)
      #Nx.Tensor<
        f32
        0.5
      >

      iex> Nx.LinAlg.pinv(0)
      #Nx.Tensor<
        f32
        0.0
      >

  Vector case:

      iex> Nx.LinAlg.pinv(Nx.tensor([0, 1, 2]))
      #Nx.Tensor<
        f32[3]
        [0.0, 0.20000000298023224, 0.4000000059604645]
      >

      iex> Nx.LinAlg.pinv(Nx.tensor([0, 0, 0]))
      #Nx.Tensor<
        f32[3]
        [0.0, 0.0, 0.0]
      >

  Matrix case:

      iex> Nx.LinAlg.pinv(Nx.tensor([[1, 1], [3, 4]]))
      #Nx.Tensor<
        f32[2][2]
        [
          [3.9924824237823486, -1.0052783489227295],
          [-3.0051186084747314, 1.0071179866790771]
        ]
      >

      iex> Nx.LinAlg.pinv(Nx.tensor([[0.5, 0], [0, 1], [0.5, 0]]))
      #Nx.Tensor<
        f32[2][3]
        [
          [0.9999999403953552, 0.0, 0.9999998807907104],
          [0.0, 1.0, 0.0]
        ]
      >
  """
  defn pinv(tensor, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-10)

    if Nx.all(Nx.abs(tensor) <= opts[:eps]) do
      pinv_zero(tensor)
    else
      pinv_non_zero(tensor, opts)
    end
  end

  defnp pinv_zero(tensor) do
    # the tensor is already zero and the pseudo-inverse
    # is defined to be zero in this case
    0
    |> Nx.tensor(type: Nx.type(tensor))
    |> Nx.broadcast(pinv_zero_shape(tensor))
  end

  deftransformp pinv_zero_shape(tensor) do
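    # The pseudo-inverse of an {..., m, n} tensor has shape {..., n, m},
    # so swap the two innermost dimensions of the input shape.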
    shape = Nx.shape(tensor)
    rank = tuple_size(shape)

    if rank < 2 do
      shape
    else
      [n, m | tl] =
        shape
        |> Tuple.to_list()
        |> Enum.reverse()

      tl
      |> Enum.reverse()
      |> List.to_tuple()
      |> Tuple.insert_at(rank - 2, n)
      |> Tuple.insert_at(rank - 1, m)
    end
  end

  defnp pinv_non_zero(tensor, opts \\ []) do
    case Nx.rank(tensor) do
      0 ->
        1 / tensor

      1 ->
        adjoint(tensor) / norm(tensor) ** 2

      _ ->
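        # Moore-Penrose inverse via SVD: pinv(A) = V . S+ . U^H, where S+
        # inverts the singular values above :eps and zeroes out the rest.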
        {u, s, vt} = Nx.LinAlg.svd(tensor, full_matrices?: false)
        v = adjoint(vt)
        ut = adjoint(u)

        s_idx = Nx.abs(s) < opts[:eps]
        adjusted_s = Nx.select(s_idx, 1, s)

        s_inv_matrix = Nx.select(s_idx, 0, 1 / adjusted_s)

        sut = Nx.new_axis(s_inv_matrix, -1) * ut
        Nx.dot(v, sut)
    end
  end

  @doc """
  Calculates the eigenvalues and eigenvectors of batched Hermitian 2-D matrices.

  It returns `{eigenvals, eigenvecs}`.

  ## Options

    * `:max_iter` - `integer`. The maximum number of iterations before
      stopping the decomposition. Defaults to `1_000`.

    * `:eps` - `float`. Tolerance applied during the decomposition.
      Defaults to `1.0e-4`.

  Note not all options apply to all backends, as backends may have
  specific optimizations that render these mechanisms unnecessary.

  ## Examples

      iex> {eigenvals, eigenvecs} = Nx.LinAlg.eigh(Nx.tensor([[1, 0], [0, 2]]))
      iex> Nx.round(eigenvals)
      #Nx.Tensor<
        f32[2]
        [1.0, 2.0]
      >
      iex> eigenvecs
      #Nx.Tensor<
        f32[2][2]
        [
          [1.0, 0.0],
          [0.0, 1.0]
        ]
      >

      iex> {eigenvals, eigenvecs} = Nx.LinAlg.eigh(Nx.tensor([[0, 1, 2], [1, 0, 2], [2, 2, 3]]))
      iex> Nx.round(eigenvals)
      #Nx.Tensor<
        f32[3]
        [5.0, -1.0, -1.0]
      >
      iex> eigenvecs
      #Nx.Tensor<
        f32[3][3]
        [
          [0.4075949788093567, 0.9131628274917603, 0.0],
          [0.40837883949279785, -0.18228201568126678, 0.8944271802902222],
          [0.8167576789855957, -0.36456403136253357, -0.4472135901451111]
        ]
      >

      iex> {eigenvals, eigenvecs} = Nx.LinAlg.eigh(Nx.tensor([[[2, 5],[5, 6]], [[1, 0], [0, 4]]]))
      iex> Nx.round(eigenvals)
      #Nx.Tensor<
        f32[2][2]
        [
          [9.0, -1.0],
          [1.0, 4.0]
        ]
      >
      iex> eigenvecs
      #Nx.Tensor<
        f32[2][2][2]
        [
          [
            [0.5612090229988098, -0.8276740908622742],
            [0.8276740908622742, 0.5612090229988098]
          ],
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ]
        ]
      >

      iex> t = Nx.tensor([[[2, 5],[5, 6]], [[1, 0], [0, 4]]]) |> Nx.vectorize(x: 2)
      iex> {eigenvals, eigenvecs} = Nx.LinAlg.eigh(t)
      iex> Nx.round(eigenvals)
      #Nx.Tensor<
        vectorized[x: 2]
        f32[2]
        [
          [9.0, -1.0],
          [1.0, 4.0]
        ]
      >
      iex> eigenvecs
      #Nx.Tensor<
        vectorized[x: 2]
        f32[2][2]
        [
          [
            [0.5612090229988098, -0.8276740908622742],
            [0.8276740908622742, 0.5612090229988098]
          ],
          [
            [1.0, 0.0],
            [0.0, 1.0]
          ]
        ]
      >

  ## Error cases

      iex> Nx.LinAlg.eigh(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
      ** (ArgumentError) tensor must be a square matrix or a batch of square matrices, got shape: {2, 3}
  """
  def eigh(tensor, opts \\ []) do
    opts = keyword!(opts, max_iter: 1_000, eps: 1.0e-4)
    %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor)
    %T{type: type, shape: shape} = tensor = Nx.devectorize(tensor)

    output_type = Nx.Type.to_floating(type)

    {eigenvals_shape, eigenvecs_shape} = Nx.Shape.eigh(shape)
    rank = tuple_size(shape)

    eigenvecs_name = List.duplicate(nil, rank)
    eigenvals_name = tl(eigenvecs_name)

    output =
      {%{tensor | names: eigenvals_name, type: output_type, shape: eigenvals_shape},
       %{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}}

    :eigh
    |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.Eigh.eigh/2)
    |> Nx.vectorize(vectorized_axes)
  end

  @doc """
  Calculates the Singular Value Decomposition of batched 2-D matrices.

  It returns `{u, s, vt}` where the elements of `s` are sorted
  from highest to lowest.

  ## Options

    * `:max_iter` - `integer`. The maximum number of iterations before
      stopping the decomposition. Defaults to `100`.

    * `:full_matrices?` - `boolean`. Defaults to `true`.
      If `true`, `u` and `vt` have shapes `{M, M}` and `{N, N}`. Otherwise,
      the shapes are `{M, K}` and `{K, N}`, where `K = min(M, N)`.

  Note not all options apply to all backends, as backends may have
  specific optimizations that render these mechanisms unnecessary.

  ## Examples

      iex> {u, s, vt} = Nx.LinAlg.svd(Nx.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]]))
      iex> u
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, -1.0]
        ]
      >
      iex> s
      #Nx.Tensor<
        f32[3]
        [1.0, 1.0, 1.0]
      >
      iex> vt
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, 1.0]
        ]
      >

      iex> {u, s, vt} = Nx.LinAlg.svd(Nx.tensor([[2, 0, 0], [0, 3, 0], [0, 0, -1], [0, 0, 0]]))
      iex> u
      #Nx.Tensor<
        f32[4][4]
        [
          [0.0, 0.9999999403953552, 0.0, 0.0],
          [1.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, -1.0, 0.0],
          [0.0, 0.0, 0.0, 1.0]
        ]
      >
      iex> s
      #Nx.Tensor<
        f32[3]
        [3.0, 1.9999998807907104, 1.0]
      >
      iex> vt
      #Nx.Tensor<
        f32[3][3]
        [
          [0.0, 1.0, 0.0],
          [1.0, 0.0, 0.0],
          [0.0, 0.0, 1.0]
        ]
      >

      iex> {u, s, vt} = Nx.LinAlg.svd(Nx.tensor([[2, 0, 0], [0, 3, 0], [0, 0, -1], [0, 0, 0]]), full_matrices?: false)
      iex> u
      #Nx.Tensor<
        f32[4][3]
        [
          [0.0, 0.9999999403953552, 0.0],
          [1.0, 0.0, 0.0],
          [0.0, 0.0, -1.0],
          [0.0, 0.0, 0.0]
        ]
      >
      iex> s
      #Nx.Tensor<
        f32[3]
        [3.0, 1.9999998807907104, 1.0]
      >
      iex> vt
      #Nx.Tensor<
        f32[3][3]
        [
          [0.0, 1.0, 0.0],
          [1.0, 0.0, 0.0],
          [0.0, 0.0, 1.0]
        ]
      >
  """
  def svd(tensor, opts \\ []) do
    opts = keyword!(opts, max_iter: 100, full_matrices?: true)
    %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor)

    %T{type: type, shape: shape} = tensor = Nx.devectorize(tensor)

    Nx.Shared.raise_complex_not_implemented_yet(type, "LinAlg.svd", 2)
    output_type = Nx.Type.to_floating(type)
    {u_shape, s_shape, v_shape} = Nx.Shape.svd(shape, opts)
    rank = tuple_size(shape)

    output =
      {%{tensor | names: List.duplicate(nil, rank), type: output_type, shape: u_shape},
       %{tensor | names: List.duplicate(nil, rank - 1), type: output_type, shape: s_shape},
       %{tensor | names: List.duplicate(nil, rank), type: output_type, shape: v_shape}}

    :svd
    |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.SVD.svd/2)
    |> Nx.vectorize(vectorized_axes)
  end

  @doc """
  Calculates the A = PLU decomposition of batched square 2-D matrices A.

  ## Options

    * `:eps` - Rounding error threshold that can be applied during the factorization. Defaults to `1.0e-10`.

  ## Examples

      iex> {p, l, u} = Nx.LinAlg.lu(Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
      iex> p
      #Nx.Tensor<
        s64[3][3]
        [
          [0, 0, 1],
          [0, 1, 0],
          [1, 0, 0]
        ]
      >
      iex> l
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 0.0],
          [0.5714285969734192, 1.0, 0.0],
          [0.1428571492433548, 2.0, 1.0]
        ]
      >
      iex> u
      #Nx.Tensor<
        f32[3][3]
        [
          [7.0, 8.0, 9.0],
          [0.0, 0.4285714328289032, 0.8571428656578064],
          [0.0, 0.0, 0.0]
        ]
      >
      iex> p |> Nx.dot(l) |> Nx.dot(u)
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0],
          [7.0, 8.0, 9.0]
        ]
      >

      iex> {p, l, u} = Nx.LinAlg.lu(Nx.tensor([[1, 0, 1], [-1, 0, -1], [1, 1, 1]]))
      iex> p
      #Nx.Tensor<
        s64[3][3]
        [
          [1, 0, 0],
          [0, 0, 1],
          [0, 1, 0]
        ]
      >
      iex> l
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 0.0],
          [1.0, 1.0, 0.0],
          [-1.0, 0.0, 1.0]
        ]
      >
      iex> u
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 1.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, 0.0]
        ]
      >
      iex> p |> Nx.dot(l) |> Nx.dot(u)
      #Nx.Tensor<
        f32[3][3]
        [
          [1.0, 0.0, 1.0],
          [-1.0, 0.0, -1.0],
          [1.0, 1.0, 1.0]
        ]
      >

      iex> {p, l, u} = Nx.LinAlg.lu(Nx.tensor([[[9, 8, 7], [6, 5, 4], [3, 2, 1]], [[-1, 0, -1], [1, 0, 1], [1, 1, 1]]]))
      iex> p
      #Nx.Tensor<
        s64[2][3][3]
        [
          [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
          ],
          [
            [1, 0, 0],
            [0, 0, 1],
            [0, 1, 0]
          ]
        ]
      >
      iex> l
      #Nx.Tensor<
        f32[2][3][3]
        [
          [
            [1.0, 0.0, 0.0],
            [0.6666666865348816, 1.0, 0.0],
            [0.3333333432674408, 2.0, 1.0]
          ],
          [
            [1.0, 0.0, 0.0],
            [-1.0, 1.0, 0.0],
            [-1.0, 0.0, 1.0]
          ]
        ]
      >
      iex> u
      #Nx.Tensor<
        f32[2][3][3]
        [
          [
            [9.0, 8.0, 7.0],
            [0.0, -0.3333333432674408, -0.6666666865348816],
            [0.0, 0.0, 0.0]
          ],
          [
            [-1.0, 0.0, -1.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0]
          ]
        ]
      >
      iex> p |> Nx.dot([2], [0], l, [1], [0]) |> Nx.dot([2], [0], u, [1], [0])
      #Nx.Tensor<
        f32[2][3][3]
        [
          [
            [9.0, 8.0, 7.0],
            [6.0, 5.0, 4.0],
            [3.0, 2.0, 1.0]
          ],
          [
            [-1.0, 0.0, -1.0],
            [1.0, 0.0, 1.0],
            [1.0, 1.0, 1.0]
          ]
        ]
      >

      iex> t = Nx.tensor([[[9, 8, 7], [6, 5, 4], [3, 2, 1]], [[-1, 0, -1], [1, 0, 1], [1, 1, 1]]]) |> Nx.vectorize(x: 2)
      iex> {p, l, u} = Nx.LinAlg.lu(t)
      iex> p
      #Nx.Tensor<
        vectorized[x: 2]
        s64[3][3]
        [
          [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
          ],
          [
            [1, 0, 0],
            [0, 0, 1],
            [0, 1, 0]
          ]
        ]
      >
      iex> l
      #Nx.Tensor<
        vectorized[x: 2]
        f32[3][3]
        [
          [
            [1.0, 0.0, 0.0],
            [0.6666666865348816, 1.0, 0.0],
            [0.3333333432674408, 2.0, 1.0]
          ],
          [
            [1.0, 0.0, 0.0],
            [-1.0, 1.0, 0.0],
            [-1.0, 0.0, 1.0]
          ]
        ]
      >
      iex> u
      #Nx.Tensor<
        vectorized[x: 2]
        f32[3][3]
        [
          [
            [9.0, 8.0, 7.0],
            [0.0, -0.3333333432674408, -0.6666666865348816],
            [0.0, 0.0, 0.0]
          ],
          [
            [-1.0, 0.0, -1.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0]
          ]
        ]
      >

  ## Error cases

      iex> Nx.LinAlg.lu(Nx.tensor([[1, 1, 1, 1], [-1, 4, 4, -1], [4, -2, 2, 0]]))
      ** (ArgumentError) tensor must be a square matrix or a batch of square matrices, got shape: {3, 4}
  """
  def lu(tensor, opts \\ []) do
    apply_vectorized(tensor, fn tensor ->
      opts = keyword!(opts, eps: 1.0e-10)
      %T{type: type, shape: shape} = tensor

      output_type = Nx.Type.to_floating(type)
      {p_shape, l_shape, u_shape} = Nx.Shape.lu(shape)
      names = List.duplicate(nil, tuple_size(shape))

      impl!(tensor).lu(
        {%{tensor | type: type, shape: p_shape, names: names},
         %{tensor | type: output_type, shape: l_shape, names: names},
         %{tensor | type: output_type, shape: u_shape, names: names}},
        tensor,
        opts
      )
    end)
  end

  @doc """
  Produces the tensor taken to the given power by matrix dot-product.

  The input must be a tensor of batched square matrices and an integer,
  and the output is a tensor of the same dimensions as the input tensor.

  For negative powers, the input is inverted first and then raised to the
  absolute value of the power. The dot-products are unrolled inside `defn`.

  ## Examples

      iex> Nx.LinAlg.matrix_power(Nx.tensor([[1, 2], [3, 4]]), 0)
      #Nx.Tensor<
        s64[2][2]
        [
          [1, 0],
          [0, 1]
        ]
      >

      iex> Nx.LinAlg.matrix_power(Nx.tensor([[1, 2], [3, 4]]), 6)
      #Nx.Tensor<
        s64[2][2]
        [
          [5743, 8370],
          [12555, 18298]
        ]
      >

      iex> Nx.LinAlg.matrix_power(Nx.eye(3), 65535)
      #Nx.Tensor<
        s64[3][3]
        [
          [1, 0, 0],
          [0, 1, 0],
          [0, 0, 1]
        ]
      >

      iex> Nx.LinAlg.matrix_power(Nx.tensor([[1, 2], [3, 4]]), -1)
      #Nx.Tensor<
        f32[2][2]
        [
          [-2.000000476837158, 1.0000003576278687],
          [1.5000004768371582, -0.5000002384185791]
        ]
      >

      iex> Nx.LinAlg.matrix_power(Nx.iota({2, 2, 2}), 3)
      #Nx.Tensor<
        s64[2][2][2]
        [
          [
            [6, 11],
            [22, 39]
          ],
          [
            [514, 615],
            [738, 883]
          ]
        ]
      >

      iex> Nx.LinAlg.matrix_power(Nx.iota({2, 2, 2}), -3)
      #Nx.Tensor<
        f32[2][2][2]
        [
          [
            [-4.875, 1.375],
            [2.75, -0.75]
          ],
          [
            [-110.37397766113281, 76.8742904663086],
            [92.24915313720703, -64.2494125366211]
          ]
        ]
      >

      iex> Nx.LinAlg.matrix_power(Nx.tensor([[1, 2], [3, 4], [5, 6]]), 1)
      ** (ArgumentError) matrix_power/2 expects a square matrix or a batch of square matrices, got tensor with shape: {3, 2}
  """
  @doc from_backend: false
  def matrix_power(tensor, power) when is_integer(power) and power < 0 do
    matrix_power(invert(tensor), abs(power))
  end

  # We need a special-case for 0 since the code below
  # is optimized to not compute an initial eye.
  def matrix_power(tensor, 0) do
    shape = Nx.shape(tensor)
    :ok = Nx.Shape.matrix_power(shape)

    Nx.eye(shape)
  end

  def matrix_power(tensor, power) when is_integer(power) do
    shape = Nx.shape(tensor)
    :ok = Nx.Shape.matrix_power(shape)

    rank = Nx.rank(tensor)
    batches = Enum.to_list(0..(rank - 3)//1)
    dot_product = &Nx.dot(&1, [rank - 1], batches, &2, [rank - 2], batches)
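
    # Exponentiation by squaring: Integer.digits/2 yields the bits of the
    # power from most to least significant. The leading 1 bit is dropped
    # (its contribution is applied by the final dot below) and the rest are
    # reversed, so the reduce walks the bits from least significant upwards
    # while exp_tensor holds successive squarings of the input.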

    power
    |> Integer.digits(2)
    |> tl()
    |> Enum.reverse()
    |> Enum.reduce({nil, tensor}, fn
      1, {nil, exp_tensor} ->
        {exp_tensor, dot_product.(exp_tensor, exp_tensor)}

      1, {result_tensor, exp_tensor} ->
        {dot_product.(result_tensor, exp_tensor), dot_product.(exp_tensor, exp_tensor)}

      0, {result_tensor, exp_tensor} ->
        {result_tensor, dot_product.(exp_tensor, exp_tensor)}
    end)
    |> then(fn
      {nil, exp_tensor} -> exp_tensor
      {result, exp_tensor} -> dot_product.(result, exp_tensor)
    end)
  end

  @doc """
  Calculates the determinant of batched square matrices.

  ## Examples

  For 2x2 and 3x3 matrices, the results are given by closed-form formulas:

      iex> Nx.LinAlg.determinant(Nx.tensor([[1, 2], [3, 4]]))
      #Nx.Tensor<
        f32
        -2.0
      >

      iex> Nx.LinAlg.determinant(Nx.tensor([[1.0, 2.0, 3.0], [1.0, -2.0, 3.0], [7.0, 8.0, 9.0]]))
      #Nx.Tensor<
        f32
        48.0
      >

  When there are linearly dependent rows or columns, the determinant is 0:

      iex> Nx.LinAlg.determinant(Nx.tensor([[1.0, 0.0], [3.0, 0.0]]))
      #Nx.Tensor<
        f32
        0.0
      >

      iex> Nx.LinAlg.determinant(Nx.tensor([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0], [4.0, 5.0, 6.0]]))
      #Nx.Tensor<
        f32
        0.0
      >

  The determinant can also be calculated for matrices bigger than 3x3:

      iex> Nx.LinAlg.determinant(Nx.tensor([
      ...> [1, 0, 0, 0],
      ...> [0, 1, 2, 3],
      ...> [0, 1, -2, 3],
      ...> [0, 7, 8, 9.0]
      ...> ]))
      #Nx.Tensor<
        f32
        48.0
      >

      iex> Nx.LinAlg.determinant(Nx.tensor([
      ...> [0, 0, 0, 0, -1],
      ...> [0, 1, 2, 3, 0],
      ...> [0, 1, -2, 3, 0],
      ...> [0, 7, 8, 9, 0],
      ...> [1, 0, 0, 0, 0]
      ...> ]))
      #Nx.Tensor<
        f32
        48.0
      >

      iex> Nx.LinAlg.determinant(Nx.tensor([
      ...> [[2, 4, 6, 7], [5, 1, 8, 8], [1, 7, 3, 1], [3, 9, 2, 4]],
      ...> [[2, 5, 1, 3], [4, 1, 7, 9], [6, 8, 3, 2], [7, 8, 1, 4]]
      ...> ]))
      #Nx.Tensor<
        f32[2]
        [630.0, 630.0]
      >

      iex> t = Nx.tensor([[[1, 0], [0, 2]], [[3, 0], [0, 4]]]) |> Nx.vectorize(x: 2)
      iex> Nx.LinAlg.determinant(t)
      #Nx.Tensor<
        vectorized[x: 2]
        f32
        [2.0, 12.0]
      >

  If the axes are named, their names are not preserved in the output:

      iex> two_by_two = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y])
      iex> Nx.LinAlg.determinant(two_by_two)
      #Nx.Tensor<
        f32
        -2.0
      >

      iex> three_by_three = Nx.tensor([[1.0, 2.0, 3.0], [1.0, -2.0, 3.0], [7.0, 8.0, 9.0]], names: [:x, :y])
      iex> Nx.LinAlg.determinant(three_by_three)
      #Nx.Tensor<
        f32
        48.0
      >

  Also supports complex inputs:

      iex> t = Nx.tensor([[1, 0, 0], [0, Complex.new(0, 2), 0], [0, 0, 3]])
      iex> Nx.LinAlg.determinant(t)
      #Nx.Tensor<
        c64
        0.0+6.0i
      >

      iex> t = Nx.tensor([[0, 0, 0, 1], [0, Complex.new(0, 2), 0, 0], [0, 0, 3, 0], [1, 0, 0, 0]])
      iex> Nx.LinAlg.determinant(t)
      #Nx.Tensor<
        c64
        -0.0-6.0i
      >

  """
  # IMPORTANT: This function cannot be a defn because
  # optional needs to work on the actual backend.
  def determinant(tensor) do
    apply_vectorized(tensor, fn tensor ->
      shape = Nx.shape(tensor)
      {batch_shape, matrix_shape} = shape |> Tuple.to_list() |> Enum.split(-2)
      output = Nx.template(List.to_tuple(batch_shape), Nx.Type.to_floating(tensor.type))

      case matrix_shape do
        [n, n] ->
          :ok

        shape ->
          raise ArgumentError,
                "determinant/1 expects a square tensor, got tensor with shape: #{inspect(shape)}"
      end

      Nx.Shared.optional(:determinant, [tensor], output, fn tensor ->
        case matrix_shape do
          [2, 2] ->
            determinant_2by2(tensor)

          [3, 3] ->
            determinant_3by3(tensor)

          [n, n] ->
            determinant_NbyN(tensor, batch_shape_n: List.to_tuple(batch_shape ++ [n]))
        end
      end)
    end)
  end

  defnp determinant_2by2(t) do
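    # Tiling the matrix side by side lets both terms be read as diagonals:
    # offset 0 yields a*d and offset 1 yields b*c.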
    t = Nx.tile(t, [1, 2])

    result = diagonal_product(t, 0) - diagonal_product(t, 1)

    # Ensure floating point result
    result * 1.0
  end

  defnp determinant_3by3(t) do
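    # Rule of Sarrus: tiling the matrix side by side exposes the three
    # "positive" diagonals at offsets 0, 1 and 2; reversing the columns
    # exposes the three "negative" diagonals.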
    rank = Nx.rank(t)
    pos_t = Nx.tile(t, [1, 2])

    neg_t = Nx.reverse(pos_t, axes: [rank - 1])

    result =
      diagonal_product(pos_t, 0) +
        diagonal_product(pos_t, 1) +
        diagonal_product(pos_t, 2) -
        diagonal_product(neg_t, 0) -
        diagonal_product(neg_t, 1) -
        diagonal_product(neg_t, 2)

    # Ensure floating point result
    result * 1.0
  end

  defnp determinant_NbyN(t, opts \\ []) do
    batch_shape_n = assert_keys(opts, [:batch_shape_n])[:batch_shape_n]
    rank = Nx.rank(t)
    shape = Nx.shape(t)

    # Taken from slogdet at https://github.com/google/jax/blob/a3a6afcd5b8bf3d60aba94054bb0001c0fcc50d7/jax/_src/numpy/linalg.py#L134
    {p, l, u} = Nx.LinAlg.lu(t)

    diag = Nx.take_diagonal(l) * Nx.take_diagonal(u)
    is_zero = Nx.any(diag == 0, axes: [-1])

    {batch_axes, transition_bcast_axes_1, transition_bcast_axes_2} = determinant_axes(rank)
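
    # `transitions` recovers the permutation encoded by p: entry i is the
    # index that row i of p maps to. Counting the inverted pairs (i < j with
    # transitions[i] > transitions[j]) below gives the permutation's parity,
    # which fixes the sign of the determinant.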

    transitions =
      Nx.dot(
        Nx.real(p),
        [rank - 1],
        batch_axes,
        Nx.iota(batch_shape_n, axis: -1),
        [rank - 2],
        batch_axes
      )

    upper_tri_mask = Nx.iota(shape, axis: -2) < Nx.iota(shape, axis: -1)

    transitions_gt =
      Nx.broadcast(transitions, shape, axes: transition_bcast_axes_1) >
        Nx.broadcast(transitions, shape, axes: transition_bcast_axes_2)

    parity = Nx.sum(transitions_gt * upper_tri_mask, axes: [-2, -1])

    sign = -2 * Nx.remainder(parity, 2) + 1

    Nx.select(is_zero, 0, sign * Nx.product(diag, axes: [-1]))
  end

  deftransformp determinant_axes(rank) do
    batch_axes = Enum.to_list(0..(rank - 3)//1)
    transition_bcast_axes_1 = Enum.to_list(0..(rank - 2))
    transition_bcast_axes_2 = batch_axes ++ [rank - 1]
    {batch_axes, transition_bcast_axes_1, transition_bcast_axes_2}
  end

  defnp diagonal_product(t, offset) do
    rank = Nx.rank(t)

    t
    |> Nx.take_diagonal(offset: offset)
    |> Nx.product(axes: [rank - 2])
  end

  @doc """
  Returns the matrix rank of an M × N matrix, computed via the Singular
  Value Decomposition method.

  The number of linearly independent rows is approximated as the number
  of singular values greater than `eps * max(singular values) * max(M, N)`.

  This criterion also appears in Numerical Recipes, in the discussion of SVD
  solutions for linear least squares [1].

  [1] W. H. Press, S. A. Teukolsky, W. T. Vetterling and B. P. Flannery,
  “Numerical Recipes (3rd edition)”, Cambridge University Press, 2007, page 795.

  ## Options

    * `:eps` - Rounding error threshold below which values are assumed to be 0. Defaults to `1.0e-7`.

  ## Examples

      iex> Nx.LinAlg.matrix_rank(Nx.tensor([[1, 2], [3, 4]]))
      #Nx.Tensor<
        u64
        2
      >

      iex> Nx.LinAlg.matrix_rank(Nx.tensor([[1, 1, 1, 1], [1, 1, 1, 1], [1, 2, 3, 4]]))
      #Nx.Tensor<
        u64
        2
      >

      iex> Nx.LinAlg.matrix_rank(Nx.tensor([[1, 1, 1], [2, 2, 2], [8, 9, 10], [-2, 1, 5]]))
      #Nx.Tensor<
        u64
        3
      >

  ## Error cases

      iex> Nx.LinAlg.matrix_rank(Nx.tensor([1, 2, 3]))
      ** (ArgumentError) tensor must have rank 2, got rank 1 with shape {3}

      iex> Nx.LinAlg.matrix_rank(Nx.tensor([[1, Complex.new(0, 2)], [3, Complex.new(0, -4)]]))
      ** (ArgumentError) Nx.LinAlg.matrix_rank/2 is not yet implemented for complex inputs
  """
  @doc from_backend: false
  defn matrix_rank(a, opts \\ []) do
    # TODO: support batching when SVD supports it too
    opts = keyword!(opts, eps: 1.0e-7)
    %T{type: type, shape: shape} = Nx.to_tensor(a)
    size = Nx.rank(shape)

    case type do
      {:c, _} ->
        raise ArgumentError, "Nx.LinAlg.matrix_rank/2 is not yet implemented for complex inputs"

      _ ->
        nil
    end

    if size != 2 do
      raise(
        ArgumentError,
        "tensor must have rank 2, got rank #{inspect(size)} with shape #{inspect(shape)}"
      )
    end

    # Calculate max dimension
    {row_dim, col_dim} = shape
    max_dim = if row_dim > col_dim, do: row_dim, else: col_dim

    # Calculate max singular value
    {_u, s, _v} = Nx.LinAlg.svd(a)

    s_max = Nx.reduce_max(s)

    # Set tolerance values
    tol = opts[:eps] * max_dim * s_max

    # Calculate matrix rank
    Nx.sum(s > tol)
  end

  @doc """
  Returns the least-squares solution to the linear matrix equation `Ax = b`.

  ## Examples

      iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2]))
      #Nx.Tensor<
        f32[2]
        [1.0000004768371582, -2.665601925855299e-7]
      >

      iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1]))
      #Nx.Tensor<
        f32[2]
        [0.9966151118278503, -0.947966456413269]
      >

      iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2, 3], [4, 5, 6]]), Nx.tensor([1, 2]))
      #Nx.Tensor<
        f32[3]
        [-0.05534052848815918, 0.1111316829919815, 0.27760395407676697]
      >

  ## Error cases

      iex> Nx.LinAlg.least_squares(Nx.tensor([1, 2, 3]), Nx.tensor([1, 2]))
      ** (ArgumentError) tensor of 1st argument must have rank 2, got rank 1 with shape {3}

      iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([[1, 2], [3, 4]]))
      ** (ArgumentError) tensor of 2nd argument must have rank 1, got rank 2 with shape {2, 2}

      iex> Nx.LinAlg.least_squares(Nx.tensor([[1, Complex.new(0, 2)], [3, Complex.new(0, -4)]]),  Nx.tensor([1, 2]))
      ** (ArgumentError) Nx.LinAlg.least_squares/2 is not yet implemented for complex inputs

      iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2, 3]))
      ** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3}
  """
  @doc from_backend: false
  defn least_squares(a, b) do
    %T{type: a_type, shape: a_shape} = Nx.to_tensor(a)
    a_size = Nx.rank(a_shape)
    %T{type: b_type, shape: b_shape} = Nx.to_tensor(b)
    b_size = Nx.rank(b_shape)

    case a_type do
      {:c, _} ->
        raise ArgumentError, "Nx.LinAlg.least_squares/2 is not yet implemented for complex inputs"

      _ ->
        nil
    end

    case b_type do
      {:c, _} ->
        raise ArgumentError, "Nx.LinAlg.least_squares/2 is not yet implemented for complex inputs"

      _ ->
        nil
    end

    if a_size != 2 do
      raise(
        ArgumentError,
        "tensor of 1st argument must have rank 2, got rank #{inspect(a_size)} with shape #{inspect(a_shape)}"
      )
    end

    if b_size != 1 do
      raise(
        ArgumentError,
        "tensor of 2nd argument must have rank 1, got rank #{inspect(b_size)} with shape #{inspect(b_shape)}"
      )
    end

    {a1, _a2} = a_shape
    {b1} = b_shape

    if a1 != b1 do
      raise(
        ArgumentError,
        "the number of rows of the matrix as the 1st argument and " <>
          "the number of columns of the vector as the 2nd argument must be the same, " <>
          "got 1st argument shape #{inspect(a_shape)} and 2nd argument shape #{inspect(b_shape)}"
      )
    end
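
    # Square systems are solved exactly; rectangular systems fall back to
    # the pseudo-inverse, which yields the least-squares solution.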

    case a_shape do
      {m, n} when m == n ->
        Nx.LinAlg.solve(a, b)

      _ ->
        a
        |> Nx.LinAlg.pinv(eps: 1.0e-15)
        |> Nx.dot(b)
    end
  end

  defp apply_vectorized(tensor, fun) when is_function(fun, 1) do
    # Same as Nx's private apply_vectorized/2, but written against the public API.
    %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor)

    tensor
    |> Nx.devectorize()
    |> then(fun)
    |> case do
      %T{} = t ->
        Nx.vectorize(t, vectorized_axes)

      {a, b} ->
        {Nx.vectorize(a, vectorized_axes), Nx.vectorize(b, vectorized_axes)}

      {a, b, c} ->
        {
          Nx.vectorize(a, vectorized_axes),
          Nx.vectorize(b, vectorized_axes),
          Nx.vectorize(c, vectorized_axes)
        }
    end
  end
end