llvm-jit-ptx

Rust device functions fused into kernels at run-time


This repository demonstrates run-time fusion of a kernel that calls a device function. The device functions live in the dfunc crate, which is no_std. It is used as a regular dependency on the host (where it is easy to test) and is also compiled to LLVM IR targeting the device.
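As a rough illustration of what such a crate can look like, here is a minimal sketch. The function name `dfunc_poly` and its body are hypothetical and not taken from dfunc/src/lib.rs; the point is only that code restricted to `core` arithmetic can be unit-tested on the host and also built for the device.

```rust
// Sketch of a device-function crate (for example dfunc/src/lib.rs).
// In a real no_std crate, `#![no_std]` would be the first line of lib.rs.
// It is left as a comment so this snippet also compiles as a standalone file.
// #![no_std]

/// Hypothetical device function: x^3 + 2x, using only `core` arithmetic.
/// To be callable by name from separately written kernel IR, a real crate
/// would also need an unmangled symbol (for example via `no_mangle`).
pub extern "C" fn dfunc_poly(x: f32) -> f32 {
    x * x * x + 2.0 * x
}

// Host-side unit tests: these run with a plain `cargo test` on the host,
// with no GPU involved.
#[cfg(test)]
mod tests {
    use super::dfunc_poly;

    #[test]
    fn matches_reference_values() {
        assert_eq!(dfunc_poly(0.0), 0.0);
        assert_eq!(dfunc_poly(2.0), 12.0);
    }
}
```

Keeping the function free of `std` is what allows the same source to be compiled for both targets.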

There are two modes. In the first, build.rs compiles the kernel to PTX at build time. In the second, linking and compilation to PTX happen at run-time. Either way, the PTX is loaded and run using the CUDA driver API.

Demo

You must use nightly for -Z build-std (and feature(linkage) to link libdevice). We currently assume that /opt/cuda/nvvm/libdevice/libdevice.10.bc exists.

$ cargo run

This will run three kernels, all using a C++ launcher (src/main.cpp) based on the LLVM NVPTX example.

  1. A single file src/kernel.ll is compiled to PTX in build.rs and only launched in main.rs. This mode requires that the build-time compute capability detected by build.rs matches the target device.
  2. A kernel src/kernel_only.ll is paired with a device function defined in dfunc/src/lib.rs. The dfunc crate is built for target nvptx64-nvidia-cuda by build.rs. Linking occurs in main.rs. In a modest extension, kernel_only.ll could be generated from CUDA .cu source, either AoT or using NVRTC.
  3. This is the same execution mode as (2), but runs a derivative, either hand-coded or computed using Enzyme.
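The build-time path in (1) ties the PTX to a compute capability. The sketch below shows one way a build script could turn a detected capability into arguments for `llc`. It is an assumption that `llc` is the tool used here; the helper names `sm_arch` and `llc_ptx_args` are hypothetical and do not come from this repository's build.rs.

```rust
/// Format a compute capability such as (8, 6) as the NVPTX CPU name "sm_86".
pub fn sm_arch(major: u32, minor: u32) -> String {
    format!("sm_{major}{minor}")
}

/// Build the argument list for an `llc` invocation that lowers an LLVM IR
/// file to PTX text for the given compute capability.
/// Assumption: an `llc` with the NVPTX backend enabled is on the PATH.
pub fn llc_ptx_args(major: u32, minor: u32, input: &str, output: &str) -> Vec<String> {
    vec![
        "-mtriple=nvptx64-nvidia-cuda".to_string(),
        format!("-mcpu={}", sm_arch(major, minor)),
        "-filetype=asm".to_string(), // PTX is emitted as text assembly
        "-o".to_string(),
        output.to_string(),
        input.to_string(),
    ]
}
```

A build script could pass the result to `std::process::Command::new("llc").args(...)`. Keeping the argument construction in a pure function makes it testable without a CUDA toolchain installed.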

Enzyme (broken)

If you have a Rust Enzyme toolchain, you can run:

$ cargo +enzyme run --features enzyme-host,enzyme-device

At the time of writing, this is broken when used across multiple crates. That is, dfunc unit tests work, but integration tests (and calls from a separate crate) see derivatives of 0.0. I believe this needs fixing in Rust-Enzyme.
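A silent 0.0 derivative is easy to miss, so a finite-difference cross-check is useful in integration tests. The following is a minimal sketch with a hypothetical primal `f` and hand-coded derivative `df`; neither is taken from the dfunc crate, and the step size and tolerance are illustrative choices.

```rust
/// Hypothetical primal function: x^3 + 2x.
pub fn f(x: f64) -> f64 {
    x * x * x + 2.0 * x
}

/// Hand-coded derivative of `f`: 3x^2 + 2.
pub fn df(x: f64) -> f64 {
    3.0 * x * x + 2.0
}

/// Central finite-difference approximation of the derivative of `func` at `x`.
pub fn central_diff(func: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (func(x + h) - func(x - h)) / (2.0 * h)
}

/// Returns true if `candidate` agrees with the finite-difference estimate
/// of the derivative of `func` at `x` to within `tol`.
pub fn derivative_agrees(candidate: f64, func: impl Fn(f64) -> f64, x: f64, tol: f64) -> bool {
    (candidate - central_diff(func, x, 1e-4)).abs() <= tol
}
```

In a test, `derivative_agrees(df(1.5), f, 1.5, 1e-5)` should hold, while a candidate of 0.0 (the symptom described above) should fail the same check.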

Notes

  • nvJitLink provides a way to link together artifacts from different sources into a fused kernel. It works with LTO-IR, which must match the major version of libnvJitLink. LTO-IR can be generated by NVCC and NVRTC but not, to my knowledge, by upstream LLVM.