#pragma once
#ifndef SPIRIT_CORE_ENGINE_BACKEND_PAR_VULKAN_HPP
#define SPIRIT_CORE_ENGINE_BACKEND_PAR_VULKAN_HPP
/* Backend_par_vulkan.hpp — Vulkan compute backend for Spirit.
*
* Third backend alongside CUDA (Backend_par.hpp) and sequential
* (Backend_seq.hpp). Uses Vulkan compute shaders + VkFFT for
* GPU-accelerated spin simulations on any Vulkan-capable GPU
* (NVIDIA, AMD, Intel).
*
* Pattern follows Backend_par.hpp: apply(N, f), reduce(N, ...),
* set(vf1, vf2, f). Instead of CUDA kernels or OpenMP pragmas,
* dispatches pre-compiled SPIR-V compute shaders via vkCmdDispatch.
*
* Requires: Vulkan 1.1+, VkFFT (header-only, in thirdparty/).
*/
#include <vulkan/vulkan.h>
#include <cstdint>
/* Scalar precision used by the declarations in this header.
 * If SPIRIT_SCALAR_TYPE is already defined when this header is
 * included it is used as-is; otherwise it falls back to float.
 * NOTE(review): `scalar` is introduced at global scope, before the
 * Engine namespace is opened — confirm it agrees with any other
 * `scalar` declaration visible to files that include this header. */
#if !defined(SPIRIT_SCALAR_TYPE)
#    define SPIRIT_SCALAR_TYPE float
#endif
using scalar = SPIRIT_SCALAR_TYPE;
#include <cstring>
#include <string>
#include <vector>
#include <unordered_map>
namespace Engine
{
namespace Backend
{
namespace vulkan
{
/* ----------------------------------------------------------------
* Vulkan compute context — one per Spirit simulation instance.
* Initialized once via vk_init(), destroyed via vk_destroy().
* ---------------------------------------------------------------- */
struct VkContext
{
    /* Instance / device level handles. Every handle starts out null. */
    VkInstance instance{ VK_NULL_HANDLE };
    VkPhysicalDevice physical_device{ VK_NULL_HANDLE };
    VkDevice device{ VK_NULL_HANDLE };
    VkQueue compute_queue{ VK_NULL_HANDLE };
    uint32_t queue_family_index{ 0 };
    VkCommandPool command_pool{ VK_NULL_HANDLE };

    /* Properties of the physical device; value-initialized until filled in. */
    VkPhysicalDeviceProperties device_props{};
    VkPhysicalDeviceMemoryProperties mem_props{};
    bool has_float64{ false };

    /* Shader cache, keyed by .spv path, holding the corresponding
     * VkShaderModule. */
    std::unordered_map<std::string, VkShaderModule> shader_cache;

    /* Single fence reused for synchronous dispatch. */
    VkFence sync_fence{ VK_NULL_HANDLE };

    /* Command buffers kept around so that a dispatch does not need to
     * allocate and free one each time.
     *   dispatch_cmd — compute dispatches (bind, push, dispatch)
     *   xfer_cmd     — upload/download copy commands */
    VkCommandBuffer dispatch_cmd{ VK_NULL_HANDLE };
    VkCommandBuffer xfer_cmd{ VK_NULL_HANDLE };

    /* Phase 2 W5 — persistent pipeline cache (compiled SPIR-V → device
     * ISA). Created at init, with pInitialData restored from disk when
     * the stored header matches the device UUID. It is handed to every
     * vkCreateComputePipelines call and written back to disk via
     * pipeline_cache_persist(). */
    VkPipelineCache pipeline_cache{ VK_NULL_HANDLE };
};
/* Global context — matches Spirit's pattern of global state in
 * Backend::par (CUDA uses global streams/handles).
 *
 * This line is only a declaration (`extern`); the object itself must
 * be defined in exactly one translation unit, which is not part of
 * this header.
 * NOTE(review): no synchronization around this object is visible
 * here — confirm the expected threading model with the implementation. */
extern VkContext g_vk_ctx;
/* ----------------------------------------------------------------
* GPU buffer — wraps VkBuffer + VkDeviceMemory.
* Used as the backing store for field<T> on Vulkan.
*
* USAGE PATTERN: persistent device-resident buffers
* --------------------------------------------------
* Allocate once, dispatch many, download once. Per-op alloc + xfer
* is ~50 ms for a 1M-element f32 buffer on an RTX 3060 Ti; the
* dispatch itself is ~70 us. The 99.9% gap is the alloc + xfer
* overhead — eliminated by holding the VkBuf across operations.
*
* Anti-pattern (don't do this in a hot loop):
*
* for (int i = 0; i < iters; i++) {
* buf_alloc(&a); buf_alloc(&b); buf_alloc(&c); // ~10 ms
* upload(&a, ha, sz); upload(&b, hb, sz); // ~20 ms
* dispatch(...); // ~0.07 ms
* download(&c, hc, sz); // ~10 ms
* buf_free(&a); buf_free(&b); buf_free(&c); // ~10 ms
* }
*
* Persistent pattern (the right shape):
*
* // once: alloc + initial upload
* VkBuf a, b, c;
* buf_alloc(&a, sz, ...); buf_alloc(&b, sz, ...); buf_alloc(&c, sz, ...);
* upload(&a, ha, sz);
* upload(&b, hb, sz);
*
* // many: dispatch only
* for (int i = 0; i < iters; i++)
* dispatch(pipe, bufs, 3, groups, sizeof(uint32_t), &n);
*
* // once: download + free
* download(&c, hc, sz);
* buf_free(&a); buf_free(&b); buf_free(&c);
*
* For Spirit's simulation loop and Nx-style tensor lifecycles, the
* persistent pattern is the only viable shape; benchmarks show
* ~700x speedup over the anti-pattern at 1M elements.
* See PERSISTENT_BUFFERS_PLAN.md for the optimization roadmap.
* ---------------------------------------------------------------- */
/* Pairing of a VkBuffer with the VkDeviceMemory behind it, plus a size.
 * A default-constructed VkBuf holds null handles and size 0. */
struct VkBuf
{
    VkBuffer buffer{ VK_NULL_HANDLE };
    VkDeviceMemory memory{ VK_NULL_HANDLE };
    /* NOTE(review): presumably the size in bytes — confirm with buf_alloc(). */
    VkDeviceSize size{ 0 };
};
/* ----------------------------------------------------------------
* Pipeline — shader + descriptor set + pipeline layout.
* Cached per (shader_path, specialization_constant) pair.
* ---------------------------------------------------------------- */
/* Bundle of the Vulkan objects that make up one compute pipeline.
 * All handles default to null. */
struct VkPipe
{
    /* Descriptor side: pool, set layout, and the set itself. */
    VkDescriptorPool descriptor_pool{ VK_NULL_HANDLE };
    VkDescriptorSetLayout descriptor_layout{ VK_NULL_HANDLE };
    VkDescriptorSet descriptor_set{ VK_NULL_HANDLE };

    /* Pipeline side: layout and the pipeline object. */
    VkPipelineLayout pipeline_layout{ VK_NULL_HANDLE };
    VkPipeline pipeline{ VK_NULL_HANDLE };
};
/* ----------------------------------------------------------------
* Lifecycle
* ---------------------------------------------------------------- */
/* Initialize the global Vulkan context. Call once at startup.
 * device_id: index into vkEnumeratePhysicalDevices result
 *            (defaults to 0).
 * Returns an int status.
 * NOTE(review): presumably 0 on success, like the other int-returning
 * functions in this header that document "Returns 0 on success" —
 * confirm against the implementation. */
int vk_init(int device_id = 0);
/* Counterpart of vk_init(): destroys the global Vulkan context. */
void vk_destroy();
/* ----------------------------------------------------------------
* Memory helpers
* ---------------------------------------------------------------- */
/* Look up a memory type index.
 * type_filter: NOTE(review): presumably a bitmask of acceptable memory
 *              type indices (as in VkMemoryRequirements::memoryTypeBits)
 *              — confirm.
 * props:       NOTE(review): presumably the property flags the selected
 *              type must provide — confirm.
 * NOTE(review): the value returned when no type matches is not visible
 * from this header. */
uint32_t find_memory_type(uint32_t type_filter, VkMemoryPropertyFlags props);
/* Allocate a buffer and its backing memory into *b.
 * size:      requested buffer size.
 * usage:     Vulkan buffer usage flags.
 * mem_flags: Vulkan memory property flags for the backing memory.
 * Returns an int status (NOTE(review): presumably 0 on success — confirm). */
int buf_alloc(VkBuf* b, VkDeviceSize size, VkBufferUsageFlags usage,
VkMemoryPropertyFlags mem_flags);
/* Release a buffer obtained from buf_alloc().
 * NOTE(review): whether *b is reset to null handles afterwards is not
 * visible here. */
void buf_free(VkBuf* b);
/* Host ↔ Device transfer via staging buffer.
 * upload:   copies `size` units from host pointer `data` into `dst`.
 * download: copies `size` units from `src` into host pointer `data`.
 * NOTE(review): `size` is presumably in bytes, and the return value
 * presumably 0 on success — confirm. */
int upload(VkBuf* dst, const void* data, VkDeviceSize size);
int download(VkBuf* src, void* data, VkDeviceSize size);
/* ----------------------------------------------------------------
* Shader / pipeline management
* ---------------------------------------------------------------- */
/* Obtain a VkShaderModule for the SPIR-V file at spv_path.
 * NOTE(review): VkContext carries a path → module shader_cache, so
 * repeated calls presumably reuse the cached module; the value returned
 * on failure is not visible here — confirm both. */
VkShaderModule load_shader(const std::string& spv_path);
/* Fill *p with a compute pipeline built around `shader`.
 * n_buffers:          number of buffers the pipeline is set up for
 *                     (NOTE(review): presumably one storage-buffer
 *                     binding each — confirm).
 * push_constant_size: size of the push-constant range
 *                     (NOTE(review): presumably bytes — confirm).
 * spec_constant:      specialization constant value; defaults to 0.
 * Returns an int status (NOTE(review): presumably 0 on success — confirm). */
int create_pipeline(VkPipe* p, VkShaderModule shader,
uint32_t n_buffers, uint32_t push_constant_size,
int32_t spec_constant = 0);
/* Destroy the objects held by *p. */
void destroy_pipeline(VkPipe* p);
/* ----------------------------------------------------------------
* Dispatch — record + submit + wait
* ---------------------------------------------------------------- */
/* Run one compute dispatch of pipeline *p.
 * buffers / n_buffers: array of buffers for the shader and its length.
 * group_count_x:       NOTE(review): presumably the workgroup count in
 *                      X passed to vkCmdDispatch — confirm.
 * push_size/push_data: optional push-constant payload; the defaults
 *                      (0 / nullptr) mean no payload is supplied.
 * Returns an int status (NOTE(review): presumably 0 on success — confirm). */
int dispatch(VkPipe* p, VkBuffer* buffers, uint32_t n_buffers,
uint32_t group_count_x,
uint32_t push_size = 0, const void* push_data = nullptr);
/* H3 dispatch timing accumulators. All times in nanoseconds; count is the
 * number of dispatch() calls since the last reset.
 *
 * timing_reset() starts a new accumulation window.
 * timing_get() reports the accumulators through its out-pointers.
 * NOTE(review): whether null out-pointers are tolerated, and exactly
 * which phase each *_ns value covers, is not visible from this header. */
void timing_reset();
void timing_get(uint64_t* count, uint64_t* dispatch_ns, uint64_t* submit_ns,
uint64_t* wait_ns, uint64_t* record_ns);
/* Batched host→device upload: packs N host source pointers into one
 * staging buffer + issues N copies in one submit_and_wait. Saves N-1
 * fence waits. Returns 0 on success.
 *
 * dsts, data and sizes are parallel arrays with n_buffers entries each.
 * NOTE(review): presumably sizes[i] bytes are copied from data[i] into
 * dsts[i], mirroring download_batch — confirm. */
int upload_batch(VkBuf** dsts, const void** data,
const VkDeviceSize* sizes, uint32_t n_buffers);
/* Phase 2 W5 — pipeline cache lifecycle.
 *
 * pipeline_cache_create() is called once during vk_init after the
 * VkDevice is up. If `init_data` is non-null and the embedded header
 * matches the current device's pipelineCacheUUID, the cache is
 * restored from it. Otherwise a fresh empty cache is created.
 * `init_size` is the length of `init_data`.
 * Returns 0 on success.
 *
 * pipeline_cache_get_data() serializes the current cache into a
 * caller-allocated buffer. Pass `out_buf=nullptr` to query required
 * size; then allocate and call again with the buffer.
 * NOTE(review): `size_inout` presumably carries the buffer capacity in
 * and the required/written byte count out (as vkGetPipelineCacheData
 * does), and the return value is presumably 0 on success — confirm.
 *
 * pipeline_cache_destroy() tears down the cache (called from vk_destroy). */
int pipeline_cache_create(const void* init_data, size_t init_size);
int pipeline_cache_get_data(void* out_buf, size_t* size_inout);
void pipeline_cache_destroy();
/* Batched device→host download: copies N device buffers into one staging
 * region with a single command buffer + single submit_and_wait. Sources
 * may differ in size; out_data[i] receives sizes[i] bytes from srcs[i].
 * Returns 0 on success.
 *
 * srcs, out_data and sizes are parallel arrays of n_buffers entries;
 * each out_data[i] must point to caller-owned storage of at least
 * sizes[i] bytes. */
int download_batch(VkBuf** srcs, void** out_data, const VkDeviceSize* sizes,
uint32_t n_buffers);
/* ----------------------------------------------------------------
* Parallel primitives — matching Backend::par interface
* ---------------------------------------------------------------- */
/* Apply a compute shader to N elements.
 * N:         number of elements.
 * pipe:      pipeline to run; the shader is selected by this pipeline
 *            (there is no shader-id parameter).
 * buffers:   buffers the shader reads/writes through its storage
 *            buffer bindings.
 * n_buffers: number of entries in `buffers`.
 * NOTE(review): this returns void, so any failure is not reported to
 * the caller through the return value — confirm how errors surface. */
void apply(int N, VkPipe* pipe, VkBuffer* buffers, uint32_t n_buffers);
/* Selector for the reduction operator. The numeric values match the
 * specialization constants in reduce.comp, so they must not be
 * renumbered independently of that shader. */
enum ReduceOp
{
    REDUCE_SUM = 0,
    REDUCE_MIN = 1,
    REDUCE_MAX = 2
};
/* GPU reduction via two-pass tree reduce shader.
 * Returns the scalar result on the host.
 * input:           device buffer holding the values to reduce.
 * N:               number of elements in `input`.
 * op:              which reduction to perform (see ReduceOp).
 * reduce_spv_path: path to compiled reduce.spv shader.
 * NOTE(review): the result for N <= 0, and how a GPU failure is
 * reported given the scalar return type, are not visible here. */
scalar reduce(VkBuf* input, int N, ReduceOp op, const std::string& reduce_spv_path);
/* Sum reduction with the same parameters as reduce(), minus `op`.
 * NOTE(review): presumably equivalent to reduce(..., REDUCE_SUM, ...)
 * — confirm. */
scalar reduce_sum(VkBuf* input, int N, const std::string& reduce_spv_path);
/* Scale all elements: buf[i] *= alpha
 * N is the number of elements in buf. */
void scale(VkBuf* buf, int N, scalar alpha);
/* Element-wise add: out[i] = a[i] + b[i]
 * N is the number of elements in each of out, a and b. */
void add(VkBuf* out, VkBuf* a, VkBuf* b, int N);
/* Dot product: sum(a[i] * b[i])
 * N is the number of elements in each of a and b; the result is
 * returned on the host.
 * NOTE(review): unlike reduce(), none of these three take a shader
 * path — how their shaders are located is not visible here. */
scalar dot(VkBuf* a, VkBuf* b, int N);
} // namespace vulkan
} // namespace Backend
} // namespace Engine
#endif /* SPIRIT_CORE_ENGINE_BACKEND_PAR_VULKAN_HPP */