lib/exla.ex

defmodule EXLA do
  @moduledoc """
  [Google's XLA](https://www.tensorflow.org/xla/) (Accelerated Linear Algebra) compiler/backend for Nx.

  It supports just-in-time (JIT) compilation to GPUs (both CUDA and ROCm) and TPUs.

  ## Configuration

  ### In projects

  EXLA works both as a backend for `Nx` tensors and an optimized `Nx.Defn` compiler.
  To enable both globally, add a `config/config.exs` (or `config/ENV.exs`) with the following:

      import Config
      config :nx, :default_backend, EXLA.Backend
      config :nx, :default_defn_options, [compiler: EXLA]

  Now you can use `Nx` as usual and it will use `EXLA` by default.

  You can also target CUDA/ROCm/TPU by setting the `:client` option
  in both configurations:

      import Config
      config :nx, :default_backend, {EXLA.Backend, client: :cuda}
      config :nx, :default_defn_options, [compiler: EXLA, client: :cuda]

  To use GPUs/TPUs, you must also set the appropriate value for the
  [`XLA_TARGET`](https://github.com/elixir-nx/xla#xla_target) environment
  variable. For CUDA, setting `ELIXIR_ERL_OPTIONS="+sssdio 128"` is also
  required for more complex operations, to increase the stack size of the
  dirty IO threads CUDA runs on (see the "GPU Runtime Issues" section below).
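
  For example, in a shell or Dockerfile you might set (a sketch; the exact
  `XLA_TARGET` value depends on your hardware and CUDA/ROCm version, see
  the linked page):

      # pick the XLA_TARGET value matching your setup (see link above)
      export XLA_TARGET=cuda
      export ELIXIR_ERL_OPTIONS="+sssdio 128"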

  ### In scripts/notebooks

  The simplest way to configure EXLA in notebooks is by calling:

  ```elixir
  EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
  ```

  Then EXLA will pick the first platform available for the current
  notebook user. From then on, you can use `Nx` as usual and it will
  use `EXLA` by default.

  As in the project configuration above, you must also set the appropriate
  value for the [`XLA_TARGET`](https://github.com/elixir-nx/xla#xla_target)
  environment variable if you intend to use GPU/TPUs.

  ### Options

  The options accepted by EXLA configuration are:

    * `:client` - an atom representing the client to use. Defaults
      to `:host`. See the "Clients" section below

    * `:device_id` - the default device id to run the computation
      on. Defaults to the `:default_device_id` on the client
      (see the example below)
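
  For example, to run on the second CUDA device by default (assuming a
  machine with more than one GPU):

      import Config
      config :nx, :default_backend, {EXLA.Backend, client: :cuda, device_id: 1}
      config :nx, :default_defn_options, [compiler: EXLA, client: :cuda, device_id: 1]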

  ## Clients

  The `EXLA` library uses a client for compiling and executing code.
  Those clients are typically bound to a platform, such as CPU or
  GPU.

  These clients are singleton resources in Google's XLA library,
  therefore they are treated as singletons in this library too.
  EXLA ships with a client configuration for each supported
  platform, equivalent to the following:

      config :exla, :clients,
        host: [platform: :host],
        cuda: [platform: :cuda],
        rocm: [platform: :rocm],
        tpu: [platform: :tpu]

  > **Important!** You should avoid using multiple clients for the
  > same platform. If you have multiple clients per platform, they
  > can race each other and fight for resources, such as memory.
  > Therefore, we recommend sticking with the default clients above.

  ### Client options

  Each client configuration accepts the following options:

    * `:platform` - the platform the client runs on. It can be
      `:host` (CPU), `:cuda`, `:rocm`, or `:tpu`.

    * `:default_device_id` - the default device ID to run on.
      For example, if you have two GPUs, you can choose a different
      one as the default. Defaults to device 0 (the first device).

    * `:preallocate` - whether memory should be preallocated on
      GPU devices. Defaults to `true`.

    * `:memory_fraction` - how much memory of a GPU device to
      allocate. Defaults to `0.9`.
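
  For instance, a sketch of a configuration that disables preallocation and
  limits the CUDA client to half of the GPU memory (values are illustrative,
  not a recommendation):

      config :exla, :clients,
        host: [platform: :host],
        # illustrative values, tune for your GPU
        cuda: [platform: :cuda, preallocate: false, memory_fraction: 0.5]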

  ### GPU Runtime Issues

  GPU executions run in dirty IO threads, which have a considerably smaller
  stack size than regular scheduler threads. This may lead to problems with
  certain CUDA or cuDNN versions, resulting in segmentation faults. In a
  development environment, it is suggested to set:

      ELIXIR_ERL_OPTIONS="+sssdio 128"

  to increase the stack size of dirty IO threads from 40 kilowords to
  128 kilowords. In a release, you can set this flag in your `vm.args`.
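
  For example, with `mix release`, you could add the flag to your release's
  `vm.args.eex` (generated by `mix release.init`; adjust the path to your
  own setup):

      ## Increase the stack size of dirty IO threads (default is 40 kilowords)
      +sssdio 128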

  ## Device allocation

  EXLA also ships with `EXLA.Backend`, which allows data to be explicitly
  allocated on the EXLA device. You can create tensors with `EXLA.Backend`
  directly:

      Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)

  or you can configure `EXLA.Backend` as the default backend, so that
  all tensors are allocated on the EXLA device by default.
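
  If you prefer not to touch configuration files (for example in a one-off
  script), you can also set it at runtime, which changes the default backend
  for the calling process:

      Nx.default_backend(EXLA.Backend)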

  In some cases you may want to explicitly move an existing tensor to
  the device:

      tensor = Nx.tensor([1, 2, 3, 4], backend: Nx.BinaryBackend)
      Nx.backend_transfer(tensor, EXLA.Backend)

  Note that you can use regular `Nx` operations, so the following works:

      tensor = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
      Nx.sum(tensor)

  Under the hood, EXLA will create a computation for the sum operation
  and invoke it on the device. This is essentially an "eager mode"
  that provides acceleration during prototyping. However, eventually
  you should wrap your computations in a `defn` to utilize the full
  performance of JIT.
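
  For instance, a hypothetical `softmax` written with `defn` and invoked
  through `EXLA.jit/3` (see below) is compiled once and reused on
  subsequent calls with compatible tensors:

      # hypothetical module for illustration
      defmodule MyOps do
        import Nx.Defn

        defn softmax(t) do
          Nx.exp(t) / Nx.sum(Nx.exp(t))
        end
      end

      EXLA.jit(&MyOps.softmax/1, [Nx.tensor([1.0, 2.0, 3.0])])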

  ## Docker considerations

  EXLA should run fine on Docker with one important consideration:
  you must not start the Erlang VM as the root process in Docker.
  That's because when the Erlang VM runs as root, it has to manage
  all child programs.

  At the same time, Google's XLA shells out to child programs during
  compilation, and it must retain control over how those child programs
  terminate.

  To address this, simply make sure you wrap the Erlang VM in
  another process, such as a shell. In other words, if you
  are using releases, instead of this:

      CMD path/to/release start

  do this:

      RUN sh -c "path/to/release start"

  If you are using Mix inside your Docker containers, instead of this:

      CMD mix run

  do this:

      RUN sh -c "mix run"

  Alternatively, you can pass the `--init` flag to `docker run`, so
  it runs an `init` inside the container that forwards signals and
  reaps processes.
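
  For example (image name hypothetical):

      docker run --init my_app:latest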
  """

  @behaviour Nx.Defn.Compiler

  @doc """
  Sets `EXLA.Backend` as the default backend and the EXLA compiler as the
  global defn options, picking the preferred client based on availability.

  This function is typically invoked at the top of scripts and code
  notebooks which might potentially be executed on multiple platforms.
  Avoid invoking this function at runtime, as it changes `Nx.Defn`
  options globally. If there is a specific client you want to use
  throughout your project, use configuration files instead:

      import Config
      config :nx, :default_backend, {EXLA.Backend, client: :cuda}
      config :nx, :default_defn_options, [compiler: EXLA, client: :cuda]

  ## Examples

      EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])

  The above will try to find the first available client and set `EXLA`,
  with that client, as the default compiler for `Nx.Defn`.
  If no client is found, `EXLA` is not set as the compiler at all,
  therefore it is common to add `:host` as the last option.

  If additional options are given, they are passed along as compiler
  options:

      EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host], device_id: 1)

  To use the GPU or TPUs, don't forget to also set the appropriate value
  for the [`XLA_TARGET`](https://github.com/elixir-nx/xla#xla_target)
  environment variable.
  """
  def set_as_nx_default(clients, opts \\ []) do
    supported_platforms = EXLA.Client.get_supported_platforms()
    all_clients = Application.fetch_env!(:exla, :clients)

    chosen =
      Enum.find(clients, fn client ->
        client_config = all_clients[client]
        client_platform = client_config[:platform] || :host
        client_config && Map.has_key?(supported_platforms, client_platform)
      end)

    if chosen do
      opts = Keyword.put(opts, :client, chosen)
      Nx.default_backend({EXLA.Backend, opts})
      Nx.Defn.global_default_options([compiler: EXLA] ++ opts)
      chosen
    end
  end

  @doc false
  @deprecated "Use set_as_nx_default/2 instead"
  def set_preferred_defn_options(clients, opts \\ []) do
    set_as_nx_default(clients, opts)
  end

  @doc """
  A shortcut for `Nx.Defn.jit/3` with the EXLA compiler.

      iex> EXLA.jit(&Nx.add(&1, &1), [Nx.tensor([1, 2, 3])])
      #Nx.Tensor<
        s64[3]
        [2, 4, 6]
      >

  See the moduledoc for options.
  """
  def jit(function, args, options \\ []) do
    Nx.Defn.jit(function, args, Keyword.put(options, :compiler, EXLA))
  end

  @doc """
  Starts streaming the given anonymous function with just-in-time
  compilation.

  At least two arguments are expected:

    1. The first argument is a tensor template of the data to
       be streamed in

    2. The second argument is a tensor with the stream initial state

  The streaming function must return a two-element tuple: the
  first element is the data to be sent and the second is the
  accumulator.

  For each streamed chunk, you must call `Nx.Stream.send/2` and
  `Nx.Stream.recv/1`. You don't need to call `recv` immediately
  after `send`, but doing so can be a useful mechanism to provide
  backpressure. Once all chunks are sent, you must use `Nx.Stream.done/1`
  to receive the accumulated result. Let's see an example:

      defmodule Streamed do
        import Nx.Defn

        defn sum(tensor, acc) do
          {acc, tensor + acc}
        end
      end

  Now let's invoke it:

      stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 64}), 0])

      for i <- 1..5 do
        Nx.Stream.send(stream, i)
        IO.inspect {:chunk, Nx.Stream.recv(stream)}
      end

      IO.inspect {:result, Nx.Stream.done(stream)}

  It will print:

      {:chunk, 0}
      {:chunk, 1}
      {:chunk, 2}
      {:chunk, 3}
      {:chunk, 4}
      {:result, 5}

  **Note:** While any process can call `Nx.Stream.send/2`, EXLA
  expects the process that starts the streaming to be the one
  calling `Nx.Stream.recv/1` and `Nx.Stream.done/1`.
  """
  def stream(function, args, options \\ []) do
    Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA))
  end

  @doc """
  Checks if the JIT compilation of `function` with
  `args` is cached.

  Note that hooks are part of the cache, and
  therefore they must be included in the options.

  ## Examples

      iex> fun = fn a, b -> Nx.add(a, b) end
      iex> left = Nx.tensor(1, type: {:u, 8})
      iex> right = Nx.tensor([1, 2, 3], type: {:u, 16})
      iex> EXLA.jit(fun, [left, right])
      iex> EXLA.jit_cached?(fun, [left, right])
      true
      iex> EXLA.jit_cached?(fun, [left, Nx.tensor([1, 2, 3, 4], type: {:u, 16})])
      false

  """
  def jit_cached?(function, args, options \\ []) do
    jit(function, args, [{EXLA, cached_check()} | options])
  catch
    {:cached?, bool} -> bool
  end

  @doc """
  Checks if the JIT compilation of the stream for `function`
  with `args` is cached.

  Note that hooks are part of the cache, and
  therefore they must be included in the options.

  ## Examples

      iex> left = Nx.tensor(1, type: {:u, 8})
      iex> right = Nx.tensor([1, 2, 3], type: {:u, 16})
      iex> fun = fn x, acc -> {acc, Nx.add(x, acc)} end
      iex> stream = EXLA.stream(fun, [left, right])
      iex> Nx.Stream.done(stream)
      iex> EXLA.stream_cached?(fun, [left, right])
      true
      iex> EXLA.stream_cached?(fun, [left, Nx.tensor([1, 2, 3, 4], type: {:u, 16})])
      false
  """
  def stream_cached?(function, args, options \\ []) do
    stream(function, args, [{EXLA, cached_check()} | options])
  catch
    {:cached?, bool} -> bool
  end

  defp cached_check do
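    # These hooks are used by jit_cached?/3 and stream_cached?/3: instead of
    # compiling on a cache miss, they throw {:cached?, boolean}, which is
    # caught above to report whether the computation is already in
    # EXLA.Defn.LockedCache.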
    expr_cache_fun = fn key, _callback ->
      if res = EXLA.Defn.LockedCache.get(key) do
        {nil, res}
      else
        throw({:cached?, false})
      end
    end

    comp_cache_fun = fn key, _callback ->
      throw({:cached?, EXLA.Defn.LockedCache.get(key) != nil})
    end

    {expr_cache_fun, comp_cache_fun}
  end

  @impl true
  defdelegate __jit__(key, vars, fun, args, opts), to: EXLA.Defn

  @impl true
  defdelegate __stream__(key, input, acc, vars, fun, args, opts), to: EXLA.Defn
end