lib/nx/random.ex

defmodule Nx.Random do
  @moduledoc """
  Pseudo-random number generators.

  Unlike the stateful pseudo-random number generators (PRNGs)
  that users of most programming languages and numerical libraries
  may be accustomed to, Nx random functions require an explicit
  PRNG key to be passed as a first argument. That key is defined by
  an `Nx.Tensor` composed of 2 unsigned 32-bit integers, usually
  generated by the `Nx.Random.key/1` function:

      iex> Nx.Random.key(12)
      #Nx.Tensor<
        u32[2]
        [0, 12]
      >

  This key can then be used in any of Nx’s random number generation
  routines:

      iex> key = Nx.Random.key(12)
      iex> {uniform, _new_key} = Nx.Random.uniform(key)
      iex> uniform
      #Nx.Tensor<
        f32
        0.7691127061843872
      >

  Now, when generating a new random number, you pass the `new_key`
  to get a different number.

  ## Design and Context

  In short, Nx's PRNGs are based on a Threefry counter PRNG
  associated with a functional array-oriented splitting model.
  Among other requirements, Nx's PRNG aims to:

  1. Ensure reproducibility

  2. Parallelize well, both in terms of vectorization
     (generating array values) and multi-replica, multi-core
     computation. In particular it should not use sequencing
     constraints between random function calls.
  """

  import Nx.Defn, only: [deftransformp: 2, defn: 2, defnp: 2]

  @nbits 32

  @doc """
  Builds a pseudo-random number generator (PRNG) key from an integer seed.

  The seed is split into two words: the seed shifted right by 32 bits
  and the seed masked with `0xFFFFFFFF`. Both words are stacked into a
  tensor of shape `{2}` and converted to `:u32`.

  ## Examples

      iex> Nx.Random.key(12)
      #Nx.Tensor<
        u32[2]
        [0, 12]
      >

      iex> Nx.Random.key(999999999999)
      #Nx.Tensor<
        u32[2]
        [232, 3567587327]
      >
  """
  defn key(seed) do
    upper_word = seed >>> 32
    lower_word = seed &&& 0xFFFFFFFF

    [upper_word, lower_word]
    |> Nx.stack()
    |> Nx.as_type(:u32)
  end

  @doc """
  Splits a PRNG key into several new keys stacked along a new leading axis.

  The result is a `u32` tensor of shape `{parts, 2}`, where each row
  is a key.

  ## Options

    * `:parts` - the number of keys to generate. Defaults to `2`

  ## Examples

      iex> key = Nx.Random.key(1701)
      iex> Nx.Random.split(key)
      #Nx.Tensor<
        u32[2][2]
        [
          [56197195, 1801093307],
          [961309823, 1704866707]
        ]
      >

      iex> key = Nx.Random.key(999999999999)
      iex> Nx.Random.split(key, parts: 4)
      #Nx.Tensor<
        u32[4][2]
        [
          [3959978897, 4079927650],
          [3769699049, 3585271160],
          [3182829676, 333122445],
          [3185556048, 1258545461]
        ]
      >
  """
  defn split(key, opts \\ []) do
    assert_key!(key)
    opts = keyword!(opts, parts: 2)

    # One row of two u32 words per requested key.
    new_keys_shape = {opts[:parts], 2}
    threefry2x32(key, new_keys_shape)
  end

  @doc """
  Folds in new data to a PRNG key.

  Each element of `data` is split into two words (the value shifted
  right by 32 bits and the value masked with `0xFFFFFFFF`), which are
  then mixed with `key`. The result has the shape of `data` with an
  extra trailing axis of size 2, so every element of `data` yields
  its own key.

  ## Examples

      iex> key = Nx.Random.key(42)
      iex> Nx.Random.fold_in(key, 99)
      #Nx.Tensor<
        u32[2]
        [2015327502, 1351855566]
      >

      iex> key = Nx.Random.key(42)
      iex> Nx.Random.fold_in(key, 1234)
      #Nx.Tensor<
        u32[2]
        [1356445167, 2917756949]
      >

      iex> key = Nx.Random.key(42)
      iex> Nx.Random.fold_in(key, Nx.tensor([[1, 99], [1234, 13]]))
      #Nx.Tensor<
        u32[2][2][2]
        [
          [
            [64467757, 2916123636],
            [2015327502, 1351855566]
          ],
          [
            [1356445167, 2917756949],
            [3514951389, 229662949]
          ]
        ]
      >
  """
  defn fold_in(key, data) do
    assert_key!(key)

    upper_words = data >>> 32
    lower_words = data &&& 0xFFFFFFFF

    # Row 0 holds every upper word, row 1 every lower word.
    words =
      [upper_words, lower_words]
      |> Nx.stack()
      |> Nx.reshape({2, :auto})
      |> Nx.as_type(:u32)

    {word1, word2} = threefry2x32_20_pair(words, key)

    folded = Nx.stack([word1, word2], axis: -1)
    Nx.reshape(folded, fold_shape(Nx.shape(data)))
  end

  # Returns `shape` with an extra trailing axis of size 2.
  deftransformp fold_shape(shape) do
    rank = tuple_size(shape)
    Tuple.insert_at(shape, rank, 2)
  end

  # Produces a u32 tensor of the given `shape` by running the Threefry
  # rounds over an iota counter tensor laid out as two rows.
  #
  # When the number of requested elements is odd, the counters are padded
  # with a single zero so they split evenly in two rows, and the extra
  # output element is dropped afterwards.
  defnp threefry2x32(key, shape) do
    size = Nx.size(shape)

    case rem(size, 2) do
      0 ->
        counters = Nx.iota({2, div(size, 2)}, type: :u32)
        bits = threefry2x32_20_concat(counters, key)
        Nx.reshape(bits, shape)

      1 ->
        padding = Nx.tensor([0], type: :u32)

        counters =
          [Nx.iota({size}, type: :u32), padding]
          |> Nx.concatenate()
          |> Nx.reshape({2, :auto})

        bits = threefry2x32_20_concat(counters, key)

        # Discard the element generated for the padding counter.
        Nx.reshape(bits[0..-2//1], shape)
    end
  end

  # Runs `threefry2x32_20_pair/2` and joins both resulting halves
  # along the first axis.
  defn threefry2x32_20_concat(xs, ks) do
    {first_half, second_half} = threefry2x32_20_pair(xs, ks)

    [first_half, second_half]
    |> Nx.concatenate(axis: 0)
  end

  # Core of the generator: mixes the two rows of `xs` with the two key
  # words in `ks` and returns the pair `{nx1, nx2}` of mixed rows.
  #
  # `xs` is indexed as `xs[0]` and `xs[1]`, and `ks` as `ks[0]` and `ks[1]`.
  # The outer loop below runs for x = 1..5, and each step applies the four
  # rotations of one rotation tensor (see `rolled_loop_step/2`), which
  # amounts to 20 rounds in total.
  defnp threefry2x32_20_pair(xs, ks) do
    # Two groups of four rotation amounts; the loop steps alternate between them.
    rotations = {Nx.tensor([13, 15, 26, 6], type: :u8), Nx.tensor([17, 29, 16, 24], type: :u8)}

    key1 = ks[0]
    key2 = ks[1]
    # Initial key injection: add one key word to each row.
    xs = {xs[0] + key1, xs[1] + key2}

    # Key schedule made of both key words plus a third word derived by
    # xor-ing them together with the constant 0x1BD11BDA. The tuple is
    # rotated on every loop step.
    ks = {
      key2,
      Nx.bitwise_xor(key1, key2)
      |> Nx.bitwise_xor(0x1BD11BDA),
      key1
    }

    state = {xs, ks, rotations}

    # The counter `x` is passed to every step, where it is added to the
    # second row together with a key word.
    {_, {{nx1, nx2}, _, _}} =
      while {x = Nx.tensor(1, type: :u32), state}, x < 6 do
        {x + Nx.tensor(1, type: :u32), rolled_loop_step(x, state)}
      end

    {nx1, nx2}
  end

  # A single mixing round: the first word becomes the sum of both words,
  # the second becomes the rotated second word xor-ed with that sum.
  defnp apply_round({left, right}, rot) do
    sum = left + right
    rotated = rotate_left(right, rot)

    {sum, Nx.bitwise_xor(rotated, sum)}
  end

  # One step of the outer loop: applies a round for every rotation in `r1`,
  # injects the key words (plus the step counter `i` on the second word),
  # then rotates the key schedule and swaps the rotation groups so the
  # next step uses the other group.
  defnp rolled_loop_step(i, {{_xs1, _xs2} = xs, {k1, k2, k3}, {r1, r2}}) do
    {mixed1, mixed2} =
      while xs, r <- r1 do
        apply_round(xs, r)
      end

    injected = {k1 + mixed1, k2 + i + mixed2}

    {injected, {k2, k3, k1}, {r2, r1}}
  end

  # Bitwise left rotation of `x` by `rot` bits, for @nbits-wide words:
  # the bits shifted out on the left are brought back in on the right.
  defnp rotate_left(x, rot) do
    complement = Nx.tensor(@nbits, type: :u32) - rot

    shifted_left = Nx.left_shift(x, rot)
    shifted_right = Nx.right_shift(x, complement)

    Nx.bitwise_or(shifted_left, shifted_right)
  end

  # Generates random unsigned integers of `:bit_width` bits.
  #
  # Options:
  #
  #   * `:shape` - shape of the returned tensor. Defaults to `{}`
  #   * `:bit_width` - bits per element. Defaults to `32`
  #
  # 32-bit values come straight out of `threefry2x32/2`. 64-bit values are
  # assembled from two 32-bit words each. Any other width is obtained by
  # converting the 32-bit output with `Nx.as_type/2`.
  defnp random_bits(key, opts \\ []) do
    assert_key!(key)
    opts = keyword!(opts, shape: {}, bit_width: 32)
    shape = opts[:shape]
    bit_width = opts[:bit_width]

    case bit_width do
      32 ->
        threefry2x32(key, shape)

      64 ->
        words = Nx.as_type(threefry2x32(key, {2, Nx.size(shape)}), {:u, 64})

        # Row 0 supplies the upper 32 bits, row 1 the lower 32 bits.
        combined = Nx.bitwise_or(Nx.left_shift(words[0], 32), words[1])
        Nx.reshape(combined, shape)

      _ ->
        bits = threefry2x32(key, shape)
        Nx.as_type(bits, {:u, bit_width})
    end
  end

  # Number of mantissa bits used for each supported floating-point type.
  # `uniform_split/4` uses it to decide how many random bits are kept
  # when building a float.
  #
  # Only the types listed below are handled; any other type has no
  # matching clause.
  defnp mantissa(type) do
    case type do
      {:bf, 16} -> 7
      {:f, 16} -> 10
      {:f, 32} -> 23
      {:f, 64} -> 52
    end
  end

  @doc """
  Sample uniform random integer values in `[min_val, max_val)`.

  The given `key` is split in two: one half is used for sampling and
  the other is returned, so the result is a `{tensor, new_key}` tuple.

  ## Options

    * `:type` - the integer type for the returned tensor. When not
      given, it is inferred from `min_val` and `max_val`
    * `:shape` - shape of the returned tensor. Defaults to `{}`
    * `:names` - the names of the returned tensor

  ## Examples

      iex> key = Nx.Random.key(1701)
      iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100)
      iex> randint
      #Nx.Tensor<
        s64
        66
      >

      iex> key = Nx.Random.key(1701)
      iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100, shape: {3, 2}, type: :u32)
      iex> randint
      #Nx.Tensor<
        u32[3][2]
        [
          [9, 20],
          [19, 6],
          [71, 15]
        ]
      >

  """
  defn randint(key, min_val, max_val, opts \\ []) do
    split_keys = split(key)
    sample = randint_split(split_keys[1], min_val, max_val, opts)

    {sample, split_keys[0]}
  end

  @doc """
  Same as `randint/4` but assumes the key has already been split.

  Returns only the sampled tensor. Raises `ArgumentError` if the
  resulting type is not an integer type.
  """
  defn randint_split(key, min_val, max_val, opts \\ []) do
    opts = keyword!(opts, [:names, :type, shape: {}])
    assert_key!(key)

    shape = opts[:shape]
    type = {_, nbits} = infer_type(min_val, max_val, opts)

    case type do
      {:u, _} -> :ok
      {:s, _} -> :ok
      _ -> raise ArgumentError, "expected integer type, got type #{inspect(type)}"
    end

    lower_bound = Nx.broadcast(min_val, shape)
    upper_bound = Nx.broadcast(max_val, shape)
    span = upper_bound - lower_bound

    # Two independent draws per output element, stacked on a leading axis.
    bits = random_bits(key, shape: randint_random_bits_shape(shape), bit_width: nbits)
    higher_bits = bits[0]
    lower_bits = bits[1]

    # (2 ^ (nbits / 2)) ^ 2 modulo span, reduced at each step.
    half_word = Nx.power(2, Nx.quotient(nbits, 2))
    half_word_mod = Nx.remainder(half_word, span)
    multiplier = Nx.remainder(Nx.power(half_word_mod, 2), span)

    # Combine both draws modulo span to obtain the offset from the lower bound.
    scaled_higher = Nx.multiply(Nx.remainder(higher_bits, span), multiplier)
    combined = Nx.add(scaled_higher, Nx.remainder(lower_bits, span))
    offset = Nx.remainder(combined, span)

    result = Nx.as_type(lower_bound + offset, type)
    Nx.reshape(result, shape, take_names(opts))
  end

  # Prepends an axis of size 2 to `shape`, so two sets of random bits
  # are generated for every output element.
  deftransformp randint_random_bits_shape(shape) do
    Tuple.insert_at(shape, 0, 2)
  end

  @doc """
  Shortcut for `uniform(key, 0.0, 1.0, opts)`.
  """
  defn uniform(key, opts \\ []) do
    lower = 0.0
    upper = 1.0

    uniform(key, lower, upper, opts)
  end

  @doc """
  Sample uniform float values in `[min_val, max_val)`.

  The given `key` is split in two: one half is used for sampling and
  the other is returned, so the result is a `{tensor, new_key}` tuple.

  ## Options

    * `:type` - a float or complex type for the returned tensor

    * `:shape` - shape of the returned tensor

    * `:names` - the names of the returned tensor

  ## Examples

      iex> key = Nx.Random.key(1701)
      iex> {uniform, _new_key} = Nx.Random.uniform(key)
      iex> uniform
      #Nx.Tensor<
        f32
        0.9728643894195557
      >

      iex> key = Nx.Random.key(1701)
      iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {3, 2}, type: :f16)
      iex> uniform
      #Nx.Tensor<
        f16[3][2]
        [
          [0.75390625, 0.6484375],
          [0.7294921875, 0.21484375],
          [0.09765625, 0.0693359375]
        ]
      >

      iex> key = Nx.Random.key(1701)
      iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {2, 2}, type: :c64)
      iex> uniform
      #Nx.Tensor<
        c64[2][2]
        [
          [0.18404805660247803+0.6546461582183838i, 0.5525915622711182+0.11568140983581543i],
          [0.6074584722518921+0.8104375600814819i, 0.247686505317688+0.21975469589233398i]
        ]
      >
  """
  defn uniform(key, min_val, max_val, opts \\ []) do
    split_keys = split(key)
    sample = uniform_split(split_keys[1], min_val, max_val, opts)

    {sample, split_keys[0]}
  end

  @doc """
  Same as `uniform/4` but assumes the key has already been split.

  Returns only the sampled tensor. It accepts the same options as
  `uniform/4`. Raises `ArgumentError` if the resulting type is neither
  a float nor a complex type.
  """
  defn uniform_split(key, min_value, max_value, opts \\ []) do
    assert_key!(key)
    opts = keyword!(opts, [:names, :type, shape: {}])
    type = infer_float_type(min_value, max_value, opts)

    float_or_complex(key, type, opts[:shape], fn key, {_type, nbits} = type, shape ->
      # Bit pattern of 1.0, used to turn random mantissa bits into a float in [1, 2).
      u_one = Nx.tensor(1.0, type: type) |> Nx.bitcast({:u, nbits})

      min_value = Nx.as_type(min_value, type)
      max_value = Nx.as_type(max_value, type)

      random_bits(key, shape: shape, bit_width: nbits)
      # Keep only as many random bits as the mantissa can hold.
      |> Nx.right_shift(Nx.tensor(nbits - mantissa(type), type: {:u, nbits}))
      |> Nx.bitwise_or(u_one)
      |> Nx.bitcast(type)
      # Move from [1, 2) to [0, 1) and then scale to the requested interval.
      |> Nx.subtract(Nx.tensor(1.0, type: type))
      |> Nx.multiply(max_value - min_value)
      |> Nx.add(min_value)
      |> Nx.max(min_value)
      |> Nx.reshape(shape)
    end)
    # Names are applied to the final tensor only. For complex types the
    # callback above runs with an extra trailing axis of size 2, so applying
    # the user-given names there would fail due to the rank mismatch.
    |> Nx.reshape(opts[:shape], take_names(opts))
  end

  @doc """
  Shortcut for `normal(key, 0.0, 1.0, opts)`.
  """
  defn normal(key, opts \\ []) do
    mean = 0.0
    standard_deviation = 1.0

    normal(key, mean, standard_deviation, opts)
  end

  @doc """
  Samples values from a normal distribution with the given `mean`
  and `standard_deviation`.

  The given `key` is split in two: one half is used for sampling and
  the other is returned, so the result is a `{tensor, new_key}` tuple.

  ## Options

    * `:type` - a float or complex type for the returned tensor

    * `:shape` - shape of the returned tensor

    * `:names` - the names of the returned tensor

  ## Examples

      iex> key = Nx.Random.key(42)
      iex> {normal, _new_key} = Nx.Random.normal(key)
      iex> normal
      #Nx.Tensor<
        f32
        1.3694692850112915
      >

      iex> key = Nx.Random.key(42)
      iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {3, 2}, type: :f16)
      iex> normal
      #Nx.Tensor<
        f16[3][2]
        [
          [-0.326416015625, -0.7734375],
          [0.3916015625, 0.533203125],
          [0.27001953125, -2.083984375]
        ]
      >

      iex> key = Nx.Random.key(42)
      iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {2, 2}, type: :c64)
      iex> normal
      #Nx.Tensor<
        c64[2][2]
        [
          [-0.763276219367981+0.8661126494407654i, -0.14282897114753723-0.7384797930717468i],
          [0.6784614324569702+0.4118310213088989i, -2.2695391178131104-0.3689095377922058i]
        ]
      >

      iex> key = Nx.Random.key(1337)
      iex> {normal, _new_key} = Nx.Random.normal(key, 10, 5, shape: {1_000})
      iex> Nx.mean(normal)
      #Nx.Tensor<
        f32
        9.70021915435791
      >
      iex> Nx.standard_deviation(normal)
      #Nx.Tensor<
        f32
        5.051421642303467
      >
  """
  defn normal(key, mean, standard_deviation, opts \\ []) do
    split_keys = split(key)
    sample = normal_split(split_keys[1], mean, standard_deviation, opts)

    {sample, split_keys[0]}
  end

  @doc """
  Same as `normal/4` but assumes the key has already been split.

  Returns only the sampled tensor. It accepts the same options as
  `normal/4`. Raises `ArgumentError` if the resulting type is neither
  a float nor a complex type.
  """
  defn normal_split(key, mean, standard_deviation, opts \\ []) do
    assert_key!(key)
    opts = keyword!(opts, [:names, :type, shape: {}])
    type = infer_float_type(mean, standard_deviation, opts)

    sample =
      float_or_complex(key, type, opts[:shape], fn key, type, shape ->
        # NOTE(review): the smallest positive normal is presumably far below the
        # spacing of floats around -1, so this sum likely rounds back to exactly
        # -1 and `erf_inv/1` can then yield -infinity on rare draws. Confirm, and
        # consider using the next representable float after -1. Changing it
        # would alter the values shown in the `normal/4` doctests.
        min_value = -1 + Nx.Constants.smallest_positive_normal(type)

        # For complex types `shape` has an extra trailing axis of size 2, so the
        # user-given names cannot be forwarded here: the rank would not match.
        uniform_opts = opts |> drop_names() |> put_type(type) |> put_shape(shape)
        u = uniform_split(key, min_value, 1, uniform_opts)

        normal = Nx.sqrt(Nx.tensor(2, type: type)) * Nx.erf_inv(u)
        Nx.as_type(standard_deviation, type) * normal + Nx.as_type(mean, type)
      end)

    apply_names(sample, opts)
  end

  # Removes the `:names` option so it is not forwarded to inner calls.
  deftransformp drop_names(opts), do: Keyword.delete(opts, :names)

  # Applies the `:names` option, when given, to the final tensor.
  deftransformp apply_names(tensor, opts) do
    case opts[:names] do
      nil -> tensor
      names -> Nx.reshape(tensor, Nx.shape(tensor), names: names)
    end
  end

  # Dispatches sampling on the requested `type`.
  #
  #   * float types (`:f` and `:bf`): `fun` is invoked with `type` and
  #     `shape` unchanged
  #   * complex types: `fun` is invoked with the matching real type and
  #     with `shape` extended by a trailing axis of size 2. Those two
  #     values are then combined into a real and an imaginary part
  #   * anything else raises `ArgumentError`
  deftransformp float_or_complex(key, type, shape, fun) do
    case type do
      {:c, _} ->
        type = Nx.Type.to_real(type)

        # Tuple.insert_at/3 is used instead of Tuple.append/2, which is
        # deprecated in recent Elixir versions; it also matches fold_shape/1.
        data = fun.(key, type, Tuple.insert_at(shape, tuple_size(shape), 2))

        # Dot product with [1, i] over the trailing axis: a + b * i.
        to_complex = Nx.stack([1, Nx.Constants.i()])
        Nx.dot(data, to_complex)

      {t, _} when t == :f or t == :bf ->
        fun.(key, type, shape)

      _ ->
        raise ArgumentError, "expected float or complex type, got type #{inspect(type)}"
    end
  end

  # Keeps only the `:names` entry of `opts`, in the form expected by `Nx.reshape/3`.
  deftransformp take_names(opts) do
    Keyword.take(opts, [:names])
  end

  # Returns the normalized `:type` option when given; otherwise merges
  # the types of `left` and `right`.
  deftransformp infer_type(left, right, opts) do
    case opts[:type] do
      unset when unset in [nil, false] ->
        Nx.Type.merge(Nx.type(left), Nx.type(right))

      requested ->
        Nx.Type.normalize!(requested)
    end
  end

  # Returns the normalized `:type` option when given; otherwise merges
  # the types of `left` and `right` and converts the result to a
  # floating-point type.
  deftransformp infer_float_type(left, right, opts) do
    case opts[:type] do
      unset when unset in [nil, false] ->
        merged = Nx.Type.merge(Nx.type(left), Nx.type(right))
        Nx.Type.to_floating(merged)

      requested ->
        Nx.Type.normalize!(requested)
    end
  end

  # Sets the `:type` option, replacing any previous value.
  deftransformp put_type(opts, type) do
    Keyword.put(opts, :type, type)
  end
  # Sets the `:shape` option, replacing any previous value.
  deftransformp put_shape(opts, shape) do
    Keyword.put(opts, :shape, shape)
  end

  # Validates that `tensor` is a PRNG key: shape `{2}` and type `{:u, 32}`.
  # The shape is checked first. Raises `ArgumentError` on a mismatch.
  defnp assert_key!(tensor) do
    %{shape: key_shape} = tensor

    case key_shape do
      {2} ->
        :ok

      _ ->
        raise ArgumentError,
              "expected key to have shape {2}, got tensor with shape #{inspect(key_shape)}"
    end

    %{type: key_type} = tensor

    case key_type do
      {:u, 32} ->
        :ok

      _ ->
        raise ArgumentError,
              "expected key with 32-bit unsigned integer type, got key with type #{inspect(key_type)}"
    end
  end
end