A JAX wrapper for cudaKDTree
MIT License
A JAX wrapper for cudaKDTree, a library for building and querying left-balanced (point-)kd-trees in CUDA. See Wald 2022 for details of the algorithm.
This is still very experimental code; contributions and improvements are most welcome.
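To show what "left-balanced" means, here is a minimal serial sketch in NumPy. It assumes round-robin split dimensions (x, y, z, x, ...) and the common implicit layout where node i has children 2i+1 and 2i+2. It is not the parallel GPU builder used by cudaKDTree or described in Wald 2022, and `left_subtree_size` and `build_left_balanced` are hypothetical helpers written only for illustration.

```python
import numpy as np


def left_subtree_size(n):
    """Nodes in the left subtree of a complete (left-balanced) binary tree of n nodes."""
    if n <= 1:
        return 0
    levels = n.bit_length()                  # floor(log2(n)) + 1
    full_above = (1 << (levels - 1)) - 1     # nodes in all levels above the last one
    last_level = n - full_above              # nodes actually present on the last level
    half_last = 1 << (levels - 2)            # last-level slots belonging to the left subtree
    return (half_last - 1) + min(last_level, half_last)


def build_left_balanced(points):
    """Return perm such that points[perm] is a left-balanced kd-tree in heap order.

    Hypothetical serial sketch: split dimension = depth % dims (round-robin assumption).
    """
    points = np.asarray(points)
    n, dims = points.shape
    perm = np.full(n, -1, dtype=np.int64)

    def build(subset, node, depth):
        if len(subset) == 0:
            return
        dim = depth % dims
        # Sort this subset along the split dimension
        order = subset[np.argsort(points[subset, dim], kind="stable")]
        # Pick the pivot so that the left subtree has exactly the "complete tree" size
        n_left = left_subtree_size(len(order))
        perm[node] = order[n_left]
        build(order[:n_left], 2 * node + 1, depth + 1)
        build(order[n_left + 1:], 2 * node + 2, depth + 1)

    build(np.arange(n), 0, 0)
    return perm
```

Because the tree is complete and stored in heap order, no child pointers are needed: every node's subtree can be located from its array index alone, which is what makes this layout attractive on a GPU.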
import jax
import jax.numpy as jnp
import jaxkdtree
# Define an array of positions of shape [N, 3]
pos = jax.random.normal(jax.random.PRNGKey(0), (64**3, 3))
# Compute the indices of the k=8 nearest neighbors of each point
nn_inds = jaxkdtree.kNN(pos, k=8, max_radius=1.0)
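For small inputs it can be handy to cross-check results against an O(N²) brute-force search. The sketch below is a hypothetical helper, `knn_bruteforce`, written with plain `jax.numpy` and `jax.lax.top_k`. It assumes each point counts as its own nearest neighbor and it ignores `max_radius`, so it is not confirmed to match `jaxkdtree.kNN` exactly on those two points.

```python
import jax
import jax.numpy as jnp


def knn_bruteforce(pos, k):
    """Hypothetical O(N^2) reference: indices and squared distances of the k nearest points.

    Each point is its own nearest neighbor (distance 0); no radius cut is applied.
    """
    # Pairwise squared distances via broadcasting: shape [N, N]
    diff = pos[:, None, :] - pos[None, :, :]
    d2 = jnp.sum(diff ** 2, axis=-1)
    # top_k returns the largest values along the last axis, so negate to get the smallest
    neg_d2, inds = jax.lax.top_k(-d2, k)
    return inds, -neg_d2
```

Only run it on a few thousand points: for the 64**3 example above, the [N, N] distance matrix alone would need roughly 275 GB in float32. When comparing against the kd-tree output, compare neighbor sets per point rather than exact ordering, since ties can be broken differently.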
On an A100, as an indication, it returns the k=8 nearest neighbors for:
Check out the demo notebook:
This assumes that you have JAX, CUDA, and CMake installed on your system:
$ pip install git+https://github.com/EiffL/JaxKDTree.git