lib/nx_image.ex

defmodule NxImage do
  @moduledoc """
  Image processing in `Nx`.

  All functions expect images to be tensors in either HWC or CHW order,
  with an arbitrary number of leading batch axes.

  All transformations preserve the input type, rounding if necessary.
  When the input type is an integer type, results are also clipped to
  the range of that type, so that overshooting values (for example from
  the `:bicubic` or lanczos kernels) saturate instead of wrapping around.
  For higher precision, cast the input to floating-point beforehand.
  """

  import Nx.Defn

  @doc """
  Crops an image at the center.

  If the image is too small to be cropped to the desired size, it gets
  padded with zeros. When the difference between the current and the
  desired size is odd, the extra row/column is cropped from (or padded
  at) the end of the axis.

  ## Options

    * `:channels` - channels location, either `:first` or `:last`.
      Defaults to `:last`

  ## Examples

      iex> image = Nx.iota({4, 4, 1}, type: :u8)
      iex> NxImage.center_crop(image, {2, 2})
      #Nx.Tensor<
        u8[2][2][1]
        [
          [
            [5],
            [6]
          ],
          [
            [9],
            [10]
          ]
        ]
      >

      iex> image = Nx.iota({2, 2, 1}, type: :u8)
      iex> NxImage.center_crop(image, {1, 4})
      #Nx.Tensor<
        u8[1][4][1]
        [
          [
            [0],
            [0],
            [1],
            [0]
          ]
        ]
      >

  """
  @doc type: :transformation
  deftransform center_crop(input, size, opts \\ []) when is_tuple(size) do
    opts = Keyword.validate!(opts, channels: :last)
    validate_image!(input)

    # Start from a no-op config and fill in the two spatial axes. Negative
    # padding removes elements, so a single config handles both cropping
    # and zero-padding.
    pad_config =
      for {axis, in_size, out_size} <- spatial_axes_with_sizes(input, size, opts[:channels]),
          reduce: List.duplicate({0, 0, 0}, Nx.rank(input)) do
        pad_config ->
          # `low` is the index of the first kept element; it is negative
          # when the axis needs padding rather than cropping
          low = div(in_size - out_size, 2)
          high = low + out_size
          List.replace_at(pad_config, axis, {-low, high - in_size, 0})
      end

    Nx.pad(input, 0, pad_config)
  end

  # Returns `[{axis, current_size, target_size}]` for the height and the
  # width axis, where `size` is the target `{height, width}`.
  deftransformp spatial_axes_with_sizes(input, size, channels) do
    {height_axis, width_axis} = spatial_axes(input, channels)
    {height, width} = size(input, channels)
    {out_height, out_width} = size
    [{height_axis, height, out_height}, {width_axis, width, out_width}]
  end

  # Returns the image size as `{height, width}`.
  deftransformp size(input, channels) do
    {height_axis, width_axis} = spatial_axes(input, channels)
    {Nx.axis_size(input, height_axis), Nx.axis_size(input, width_axis)}
  end

  @doc """
  Resizes an image.

  ## Options

    * `:method` - the resizing method to use, either of `:nearest`,
      `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`. Defaults to
      `:bilinear`

    * `:channels` - channels location, either `:first` or `:last`.
      Defaults to `:last`

  ## Examples

      iex> image = Nx.iota({2, 2, 1}, type: :u8)
      iex> NxImage.resize(image, {3, 3}, method: :nearest)
      #Nx.Tensor<
        u8[3][3][1]
        [
          [
            [0],
            [1],
            [1]
          ],
          [
            [2],
            [3],
            [3]
          ],
          [
            [2],
            [3],
            [3]
          ]
        ]
      >

      iex> image = Nx.iota({2, 2, 1}, type: :f32)
      iex> NxImage.resize(image, {3, 3}, method: :bilinear)
      #Nx.Tensor<
        f32[3][3][1]
        [
          [
            [0.0],
            [0.5],
            [1.0]
          ],
          [
            [1.0],
            [1.5],
            [2.0]
          ],
          [
            [2.0],
            [2.5],
            [3.0]
          ]
        ]
      >

  """
  @doc type: :transformation
  deftransform resize(input, size, opts \\ []) when is_tuple(size) do
    opts = Keyword.validate!(opts, channels: :last, method: :bilinear)
    validate_image!(input)

    # Only the axes whose size actually changes are resized; `out_shape`
    # is the input shape with those axes replaced by the target sizes
    {spatial_axes, out_shape} =
      input
      |> spatial_axes_with_sizes(size, opts[:channels])
      |> Enum.reject(fn {_axis, size, out_size} -> Elixir.Kernel.==(size, out_size) end)
      |> Enum.map_reduce(Nx.shape(input), fn {axis, _size, out_size}, out_shape ->
        {axis, put_elem(out_shape, axis, out_size)}
      end)

    resized_input =
      case opts[:method] do
        :nearest ->
          resize_nearest(input, out_shape, spatial_axes)

        :bilinear ->
          resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1)

        :bicubic ->
          resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1)

        :lanczos3 ->
          resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1))

        :lanczos5 ->
          resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1))

        method ->
          raise ArgumentError,
                "expected :method to be either of :nearest, :bilinear, :bicubic, " <>
                  ":lanczos3, :lanczos5, got: #{inspect(method)}"
      end

    cast_to(resized_input, input)
  end

  # Returns the `{height_axis, width_axis}` indices for the given channels
  # location. Raises ArgumentError for anything other than :first/:last.
  deftransformp spatial_axes(input, channels) do
    axes =
      case channels do
        :first -> [-2, -1]
        :last -> [-3, -2]
        other -> raise_invalid_channels!(other)
      end

    axes
    |> Enum.map(&Nx.axis_index(input, &1))
    |> List.to_tuple()
  end

  defp raise_invalid_channels!(channels) do
    raise ArgumentError,
          "expected :channels to be either :first or :last, got: #{inspect(channels)}"
  end

  # Casts `left` to the type (and names) of `right`. Float results going
  # into an integer type are rounded and clipped to the integer range,
  # because an out-of-range float -> integer cast wraps around (or is
  # backend-dependent) rather than saturating.
  defnp cast_to(left, right) do
    left_type = Nx.type(left)
    right_type = Nx.type(right)

    left =
      if Nx.Type.float?(left_type) and Nx.Type.integer?(right_type) do
        left
        |> Nx.round()
        |> Nx.clip(Nx.Constants.min_finite(right_type), Nx.Constants.max_finite(right_type))
      else
        left
      end

    # Reshaping to the same shape is only used to restore the axis names
    left
    |> Nx.as_type(right_type)
    |> Nx.reshape(left, names: Nx.names(right))
  end

  # Nearest-neighbour resize: for every output pixel along each resized
  # axis, gathers the input pixel whose center is closest to the
  # output pixel's center.
  deftransformp resize_nearest(input, out_shape, spatial_axes) do
    singular_shape = List.duplicate(1, Nx.rank(input)) |> List.to_tuple()

    for axis <- spatial_axes, reduce: input do
      input ->
        input_shape = Nx.shape(input)
        input_size = elem(input_shape, axis)
        output_size = elem(out_shape, axis)
        inv_scale = input_size / output_size
        offset = Nx.iota({output_size}) |> Nx.add(0.5) |> Nx.multiply(inv_scale)

        # Mathematically the index never exceeds input_size - 1; the clip
        # guards against floating-point rounding at the last position
        offset =
          offset
          |> Nx.floor()
          |> Nx.clip(0, input_size - 1)
          |> Nx.as_type({:s, 32})

        offset =
          offset
          |> Nx.reshape(put_elem(singular_shape, axis, output_size))
          |> Nx.broadcast(put_elem(input_shape, axis, output_size))

        Nx.take_along_axis(input, offset, axis: axis)
    end
  end

  @f32_eps :math.pow(2, -23)

  # Resizes each axis in `spatial_axes` in turn using a separable kernel.
  deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
    for axis <- spatial_axes, reduce: input do
      input ->
        resize_axis_with_kernel(input,
          axis: axis,
          output_size: elem(out_shape, axis),
          kernel_fun: kernel_fun
        )
    end
  end

  # Resizes a single axis by contracting it with an
  # `{input_size, output_size}` weight matrix built from `kernel_fun`.
  # Mirrors the weight computation of `jax.image.scale_and_translate`.
  defnp resize_axis_with_kernel(input, opts) do
    axis = opts[:axis]
    output_size = opts[:output_size]
    kernel_fun = opts[:kernel_fun]

    input_size = Nx.axis_size(input, axis)

    inv_scale = input_size / output_size
    # When downsampling, the kernel is widened so that it covers all the
    # input pixels mapping onto a single output pixel (antialiasing)
    kernel_scale = max(1, inv_scale)

    # Output pixel centers expressed in input coordinates, shape {1, output_size}
    sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
    # Scaled distance of every input pixel to every sample, shape {input_size, output_size}
    x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale
    weights = kernel_fun.(x)

    # Normalize every column so the weights of each output pixel sum to 1.
    # Columns whose total weight is (nearly) zero are zeroed out instead of
    # being divided by a vanishing sum. The predicate is broadcast because
    # Nx.select expects it to have the full output shape.
    weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)
    nonzero_sum = Nx.broadcast(Nx.abs(weights_sum) > 1000 * @f32_eps, Nx.shape(weights))
    weights = Nx.select(nonzero_sum, safe_divide(weights, weights_sum), 0)

    input = Nx.dot(input, [axis], weights, [0])
    # The transformed axis is moved to the end, so we transpose back
    reorder_axis(input, -1, axis)
  end

  # Triangle (tent) kernel: 1 - |x|, zero for |x| >= 1.
  defnp fill_linear_kernel(x) do
    Nx.max(0, 1 - x)
  end

  # Keys cubic convolution kernel with a = -0.5, zero for |x| >= 2.
  defnp fill_cubic_kernel(x) do
    # See https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
    out = (1.5 * x - 2.5) * x * x + 1
    out = Nx.select(x >= 1, ((-0.5 * x + 2.5) * x - 4) * x + 2, out)
    Nx.select(x >= 2, 0, out)
  end

  @pi :math.pi()

  # Lanczos kernel: sinc(x) * sinc(x / radius), zero for |x| > radius.
  defnp fill_lanczos_kernel(radius, x) do
    y = radius * Nx.sin(@pi * x) * Nx.sin(@pi * x / radius)
    # Near zero the expression is 0/0, so use the limit value 1 instead
    out = Nx.select(x > 1.0e-3, safe_divide(y, @pi ** 2 * x ** 2), 1)
    Nx.select(x > radius, 0, out)
  end

  # Divides `x` by `y`, replacing zero divisors with 1 to avoid NaN/Inf.
  defnp safe_divide(x, y) do
    x / Nx.select(y != 0, y, 1)
  end

  # Moves the axis at position `axis` to position `target_axis`, keeping
  # the relative order of the remaining axes.
  deftransformp reorder_axis(tensor, axis, target_axis) do
    axes = Nx.axes(tensor)
    {source_axis, axes} = List.pop_at(axes, axis)
    axes = List.insert_at(axes, target_axis, source_axis)
    Nx.transpose(tensor, axes: axes)
  end

  @doc """
  Scales an image such that the short edge matches the given size.

  The aspect ratio is preserved, with the long edge rounded down.

  ## Options

    * `:method` - the resizing method to use, same as `resize/3`

    * `:channels` - channels location, either `:first` or `:last`.
      Defaults to `:last`

  ## Examples

      iex> image = Nx.iota({2, 4, 1}, type: :u8)
      iex> resized_image = NxImage.resize_short(image, 3, method: :nearest)
      iex> Nx.shape(resized_image)
      {3, 6, 1}

      iex> image = Nx.iota({4, 2, 1}, type: :u8)
      iex> resized_image = NxImage.resize_short(image, 3, method: :nearest)
      iex> Nx.shape(resized_image)
      {6, 3, 1}

  """
  @doc type: :transformation
  deftransform resize_short(input, size, opts \\ []) when is_integer(size) do
    opts = Keyword.validate!(opts, channels: :last, method: :bilinear)
    validate_image!(input)
    resize_short_n(input, [size: size] ++ opts)
  end

  defnp resize_short_n(input, opts) do
    size = opts[:size]
    method = opts[:method]
    channels = opts[:channels]

    {height, width} = size(input, channels)
    {out_height, out_width} = resize_short_size(height, width, size)

    resize(input, {out_height, out_width}, method: method, channels: channels)
  end

  # Computes the `{height, width}` whose shorter edge equals `size`, with
  # the longer edge scaled proportionally and rounded down.
  deftransformp resize_short_size(height, width, size) do
    {short, long} = if height < width, do: {height, width}, else: {width, height}

    out_short = size
    out_long = floor(size * long / short)

    if height < width, do: {out_short, out_long}, else: {out_long, out_short}
  end

  @doc """
  Normalizes an image according to the given per-channel mean and
  standard deviation.

  Computes `(input - mean) / std`, where `mean` and `std` are
  1-dimensional tensors with one entry per channel.

  ## Options

    * `:channels` - channels location, either `:first` or `:last`.
      Defaults to `:last`

  ## Examples

      iex> image = Nx.iota({2, 2, 3}, type: :f32)
      iex> mean = Nx.tensor([0.485, 0.456, 0.406])
      iex> std = Nx.tensor([0.229, 0.224, 0.225])
      iex> NxImage.normalize(image, mean, std)
      #Nx.Tensor<
        f32[2][2][3]
        [
          [
            [-2.1179039478302, 2.4285714626312256, 7.084444522857666],
            [10.982532501220703, 15.821427345275879, 20.41777801513672]
          ],
          [
            [24.08296775817871, 29.214284896850586, 33.7511100769043],
            [37.183406829833984, 42.607139587402344, 47.08444595336914]
          ]
        ]
      >

  """
  @doc type: :transformation
  defn normalize(input, mean, std, opts \\ []) do
    opts = keyword!(opts, channels: :last)
    validate_image!(input)

    mean = broadcast_channel_info(mean, input, opts[:channels], "mean")
    std = broadcast_channel_info(std, input, opts[:channels], "std")

    normalized_input = (input - mean) / std

    cast_to(normalized_input, input)
  end

  # Validates that `tensor` has shape `{num_channels}` and reshapes it so
  # that it broadcasts along the channels axis of `input`.
  deftransformp broadcast_channel_info(tensor, input, channels, name) do
    rank = Nx.rank(input)

    channels_axis =
      case channels do
        :first -> rank - 3
        :last -> rank - 1
        other -> raise_invalid_channels!(other)
      end

    num_channels = Nx.axis_size(input, channels_axis)

    case Nx.shape(tensor) do
      {^num_channels} ->
        :ok

      shape ->
        raise ArgumentError,
              "expected #{name} to have shape {#{num_channels}}, got: #{inspect(shape)}"
    end

    shape = 1 |> Tuple.duplicate(rank) |> put_elem(channels_axis, :auto)
    Nx.reshape(tensor, shape)
  end

  @doc """
  Converts pixel values (0-255) into the continuous range `[min, max]`.

  ## Examples

      iex> image = Nx.tensor([[[0], [128]], [[191], [255]]])
      iex> NxImage.to_continuous(image, 0.0, 1.0)
      #Nx.Tensor<
        f32[2][2][1]
        [
          [
            [0.0],
            [0.501960813999176]
          ],
          [
            [0.7490196228027344],
            [1.0]
          ]
        ]
      >

      iex> image = Nx.tensor([[[0], [128]], [[191], [255]]])
      iex> NxImage.to_continuous(image, -1.0, 1.0)
      #Nx.Tensor<
        f32[2][2][1]
        [
          [
            [-1.0],
            [0.003921627998352051]
          ],
          [
            [0.49803924560546875],
            [1.0]
          ]
        ]
      >

  """
  @doc type: :conversion
  defn to_continuous(input, min, max) do
    validate_image!(input)

    input / 255.0 * (max - min) + min
  end

  @doc """
  Converts values from the continuous range `[min, max]` into pixel
  values (0-255).

  The result is rounded, clipped to `[0, 255]` and returned as `:u8`.

  ## Examples

      iex> image = Nx.tensor([[[0.0], [0.5]], [[0.75], [1.0]]])
      iex> NxImage.from_continuous(image, 0.0, 1.0)
      #Nx.Tensor<
        u8[2][2][1]
        [
          [
            [0],
            [128]
          ],
          [
            [191],
            [255]
          ]
        ]
      >

      iex> image = Nx.tensor([[[-1.0], [0.0]], [[0.5], [1.0]]])
      iex> NxImage.from_continuous(image, -1.0, 1.0)
      #Nx.Tensor<
        u8[2][2][1]
        [
          [
            [0],
            [128]
          ],
          [
            [191],
            [255]
          ]
        ]
      >

  """
  @doc type: :conversion
  defn from_continuous(input, min, max) do
    validate_image!(input)

    input = (input - min) / (max - min) * 255.0

    input
    |> Nx.round()
    |> Nx.clip(0, 255)
    |> Nx.as_type(:u8)
  end

  # Raises ArgumentError unless the input has at least the 3 image axes.
  deftransformp validate_image!(input) do
    rank = Nx.rank(input)

    if rank < 3 do
      raise ArgumentError,
            "expected the image input to have rank 3 or higher, got: #{inspect(rank)}"
    end
  end
end