# examples/gpu_node_demo.exs

# Nx.Vulkan.Node end-to-end demo.
#
# Boots the long-lived GPU node, synthesizes a Beta(2, 5) chain
# shader from a runtime spec, dispatches it under the node's
# serialized GenServer, and reports timings + the produced
# trajectory's first logp value (verified against the analytic
# expected value).
#
# Demonstrates the full Phase 2 contract:
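#   1. Nx.Vulkan.PipelineCache.load    — restore vkPipelineCache from disk
#      (not called explicitly in this script; presumably handled while
#      the node boots — confirm against Nx.Vulkan.Node)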
#   2. Nx.Vulkan.Node.start_link       — boot the GenServer
#   3. Nx.Vulkan.ChainShaderSpecs.beta — fetch a family spec
#   4. Nx.Vulkan.Synthesis.compile     — render GLSL + glslangValidator
#   5. Nx.Vulkan.Node.with_node        — dispatch inside the node
#   6. Nx.Vulkan.PipelineCache.persist — write the cache back to disk
#
# Run from the nx_vulkan repo root:
#   mix run examples/gpu_node_demo.exs

defmodule GPUNodeDemo do
  @moduledoc """
  End-to-end demo of dispatching a synthesized Beta(α, β) chain shader
  through `Nx.Vulkan.Node`.

  `run/0` boots the node, compiles the shader from a runtime spec,
  dispatches one leapfrog chain, compares the first recorded logp with
  the analytic value at `q_uc = 0`, and persists the pipeline cache.
  """

  # Beta shape parameters. log B(α, β) is computed exactly from integer
  # factorials (see log_beta/2), so both must be positive whole numbers.
  @alpha 2.0
  @beta 5.0
  # Passed to beta_push/6 and used as the multiplier of the analytic logp.
  @n 1
  # Passed to beta_push/6 and to the dispatch; the logp download reads
  # @k float32 values (@k * 4 bytes).
  @k 32
  # ε handed to beta_push/6 (the leapfrog step referred to in run/0).
  @eps 0.05

  @doc """
  Runs the demo, printing device info, timings, and the logp check.

  Returns `:ok` (the result of the final `IO.puts/1`). Every step is
  matched assertively, so a call that does not return the expected shape
  raises `MatchError` instead of continuing with bad data.
  """
  @spec run() :: :ok
  def run do
    IO.puts("\n=== Nx.Vulkan.Node demo ===\n")

    # 1. Boot.
    Nx.Vulkan.Native.init()
    {:ok, _pid} = Nx.Vulkan.Node.start_link()

    IO.puts("device:  #{Nx.Vulkan.device_name()}")
    IO.puts("f64?     #{Nx.Vulkan.has_f64?()}")
    IO.puts("uuid:    #{Nx.Vulkan.PipelineCache.device_uuid_hex()}")

    # 2. Synthesize a Beta(α, β) chain shader at runtime.
    spec = Nx.Vulkan.ChainShaderSpecs.beta()
    t_synth = System.monotonic_time(:millisecond)
    {:ok, spv_path} = Nx.Vulkan.Synthesis.compile(spec)
    synth_ms = System.monotonic_time(:millisecond) - t_synth

    IO.puts("\nsynthesized Beta SPV in #{synth_ms} ms")
    IO.puts("  → #{spv_path}")
    IO.puts("  → #{File.stat!(spv_path).size} bytes")

    # 3. Build push-constants blob. The caller computes lgamma-derived
    # constants (nx_vulkan stays out of the math library business).
    logp_const = log_beta_neg(@alpha, @beta)
    push = Nx.Vulkan.ChainShaderSpecs.beta_push(@n, @k, @eps, @alpha, @beta, logp_const)

    # 4. Upload the initial state to GPU buffers.
    {:ok, q_ref} = Nx.Vulkan.upload_binary(<<0.0::little-float-32>>)
    {:ok, p_ref} = Nx.Vulkan.upload_binary(<<0.5::little-float-32>>)
    {:ok, m_ref} = Nx.Vulkan.upload_binary(<<1.0::little-float-32>>)

    # 5. Dispatch INSIDE Nx.Vulkan.Node. This is the API that any
    # client (MCMC sampler, distributed worker, etc.) uses to share
    # the pipeline cache and watchdog protection.
    t_dispatch = System.monotonic_time(:microsecond)

    {:ok, {_q_chain, _p_chain, _grad_chain, logp_chain}} =
      Nx.Vulkan.Node.with_node(fn ->
        Nx.Vulkan.Native.leapfrog_chain_synth(q_ref, p_ref, m_ref, push, @k, spv_path)
      end)

    dispatch_us = System.monotonic_time(:microsecond) - t_dispatch
    IO.puts("\nfirst dispatch via with_node: #{dispatch_us} µs")

    # 6. Read the trajectory. Only the first float32 is reported, so it
    # is matched directly instead of decoding the whole chain. A first
    # value that is not a finite float (NaN/Inf bit pattern) cannot be
    # decoded and fails this match with MatchError.
    {:ok, logp_bin} = Nx.Vulkan.Native.download_binary(logp_chain, @k * 4)
    <<first_logp::little-float-32, _rest::binary>> = logp_bin

    expected = analytic_logp_at_q_uc_zero(@alpha, @beta, @n)
    delta = abs(first_logp - expected)
    mark = if delta < 0.1, do: "✓", else: "✗"

    # logp[0] is recorded AFTER the first leapfrog step, so the chain
    # has moved slightly from q_uc=0. A delta of ~0.05 is expected for
    # a single half-step + position-step + half-step at this ε.
    IO.puts("\nlogp[0]:   #{Float.round(first_logp, 4)}")
    IO.puts("logp at q_uc=0 (analytic): #{Float.round(expected, 4)}")
    IO.puts("delta after 1 leapfrog:    #{Float.round(delta, 5)} #{mark}")

    # 7. Status check.
    status = Nx.Vulkan.Node.status()
    IO.puts("\nnode uptime:       #{status.uptime_ms} ms")
    IO.puts("with_node calls:   #{status.exec_count}")

    # 8. Persist the pipeline cache for the next run.
    :ok = Nx.Vulkan.PipelineCache.persist()
    IO.puts("\npipeline cache persisted to disk:")
    IO.puts("  → #{Nx.Vulkan.PipelineCache.default_path()}")
    IO.puts("  → #{File.stat!(Nx.Vulkan.PipelineCache.default_path()).size} bytes")

    IO.puts("\ndone.\n")
  end

  # log p(q_uc=0 | Beta(α, β)) on logit-uc, scaled by n:
  #   q = sigmoid(0) = 0.5
  #   = α·log(0.5) + β·log(0.5) - log B(α, β)
  defp analytic_logp_at_q_uc_zero(alpha, beta, n) do
    log_half = :math.log(0.5)
    n * (alpha * log_half + beta * log_half - log_beta(alpha, beta))
  end

  # -log B(α, β): the normalising constant passed to beta_push/6.
  defp log_beta_neg(alpha, beta), do: -log_beta(alpha, beta)

  # log B(α, β) for positive whole-number shapes, without lgamma
  # (Erlang/OTP's :math lacks it; nx_vulkan stays out of the math
  # library business — see ChainShaderSpecs.beta_push/6 docs).
  #
  #   B(a, b) = (a-1)!·(b-1)! / (a+b-1)!
  #
  # The reciprocal (a+b-1)! / ((a-1)!·(b-1)!) = (a+b-1)·C(a+b-2, a-1) is
  # always an integer, so it is formed with exact integer arithmetic and
  # only then passed to :math.log/1. For the default shapes:
  #
  #   log_beta(2, 5) = -log(6! / (1!·4!)) = -log(30) ≈ -3.4012
  #
  # Intended for the small shapes a demo uses; raises ArgumentError for
  # shapes that are not positive whole numbers.
  defp log_beta(alpha, beta) do
    a = whole_shape!(alpha)
    b = whole_shape!(beta)
    inverse_beta = div(factorial(a + b - 1), factorial(a - 1) * factorial(b - 1))
    -:math.log(inverse_beta)
  end

  # Converts a shape such as 2.0 or 2 to the integer 2, rejecting
  # anything that is not a positive whole number.
  defp whole_shape!(x) when is_number(x) and x >= 1 and trunc(x) == x, do: trunc(x)

  defp whole_shape!(x) do
    raise ArgumentError,
          "demo only supports positive whole-number Beta shapes, got: #{inspect(x)}"
  end

  # n! for n >= 0; the explicit //1 step makes 1..0 an empty range so 0! = 1.
  defp factorial(n) when is_integer(n) and n >= 0 do
    Enum.reduce(1..n//1, 1, &(&1 * &2))
  end
end

# Script entry point: `mix run examples/gpu_node_demo.exs` evaluates this
# file top to bottom, so the demo starts once the module above is defined.
GPUNodeDemo.run()