tensorf-jax

Unofficial implementation of Tensorial Radiance Fields (Chen et al., 2022)

JAX implementation of Tensorial Radiance Fields, written as an exercise.

@misc{TensoRF,
      title={TensoRF: Tensorial Radiance Fields},
      author={Anpei Chen and Zexiang Xu and Andreas Geiger and Jingyi Yu and Hao Su},
      year={2022},
      eprint={2203.09517},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

We don't attempt to reproduce the original paper exactly, but we can achieve decent results after 5–10 minutes of training.

As proposed, TensoRF only supports scenes that fit in a fixed-size bounding box. We've also added basic support for unbounded "real" scenes via mip-NeRF 360-inspired scene contraction[^1]. From nerfstudio's "dozer" dataset:

[^1]: Same as the original, but with an $L^\infty$ norm instead of an $L^2$ norm.
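The contraction described above can be sketched in a few lines of JAX. This is a hand-written illustration (the function name `contract` is our own, not necessarily what the repo uses): points inside the unit cube are left alone, and points outside are squashed into the $[-2, 2]$ cube, using the $L^\infty$ norm as in the footnote.

```python
import jax.numpy as jnp


def contract(x: jnp.ndarray) -> jnp.ndarray:
    """mip-NeRF 360-style scene contraction with an L-infinity norm.

    x: (..., 3) world-space points. Returns contracted points in [-2, 2]^3.
    """
    # Per-point L-infinity norm, shape (..., 1) for broadcasting.
    norm = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
    # Clamp to avoid divide-by-zero; the clamped value is only used
    # on the branch where norm > 1 anyway.
    safe = jnp.maximum(norm, 1.0)
    # Identity inside the unit cube; outside, radius r maps to 2 - 1/r.
    return jnp.where(norm <= 1.0, x, (2.0 - 1.0 / safe) * x / safe)
```

Because the output norm is bounded by 2, an unbounded scene always fits in a fixed-size bounding box, which is what lets a bounded model like TensoRF handle "real" captures.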

Instructions

  1. Download the nerf_synthetic dataset: Google Drive. With the default training script arguments, we expect this to be extracted to ./data, e.g. ./data/nerf_synthetic/lego.

  2. Install dependencies. Probably you want the GPU version of JAX; see the official instructions. Then:

    pip install -r requirements.txt
    
  3. To print training options:

    python ./train_lego.py --help
    
  4. To monitor training, we use Tensorboard:

    tensorboard --logdir=./runs/
    
  5. To render:

    python ./render_360.py --help
    

Differences from the PyTorch implementation

Things aren't totally matched to the official implementation:

  • The official implementation relies heavily on masking operations to improve
    runtime (for example, by using a weight threshold for sampled points). These
    require dynamic shapes and are currently difficult to implement in JAX, so we
    replace them with workarounds like weighted sampling.
  • Several training details that would likely improve performance are not yet
    implemented: bounding box refinement, ray filtering, regularization, etc.
  • We include mixed-precision training, which can significantly increase
    training throughput. (Whether this translates into faster wall-clock
    convergence is unclear.)
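To make the first bullet concrete, here is a hand-written sketch (not code from this repo) of the kind of static-shape workaround JAX pushes you toward: instead of masking out low-weight samples with a boolean filter (which would produce a dynamically shaped array), keep a fixed `k` samples per ray, selected by rendering weight.

```python
import jax
import jax.numpy as jnp


def select_topk_samples(weights: jnp.ndarray, positions: jnp.ndarray, k: int):
    """Keep the k highest-weight samples on each ray (static output shapes).

    weights:   (num_rays, num_samples) rendering weights.
    positions: (num_rays, num_samples, 3) sampled points.
    Returns (num_rays, k) weights and (num_rays, k, 3) positions.
    """
    # jax.lax.top_k returns (values, indices) along the last axis.
    top_w, idx = jax.lax.top_k(weights, k)
    # Gather the matching positions; idx broadcasts over the xyz axis.
    top_p = jnp.take_along_axis(positions, idx[..., None], axis=1)
    return top_w, top_p
```

Unlike a weight-threshold mask, the output shape here depends only on `k`, so the function can be `jit`-compiled once and reused, at the cost of sometimes keeping (or dropping) samples a thresholding scheme would treat differently.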

References

Implementation details are based loosely on the original PyTorch implementation, apchenstu/TensoRF.

unixpickle/learn-nerf and google-research/jaxnerf were also really helpful for understanding core NeRF concepts + connecting them to JAX!

To-do

  • Main implementation
    • Point sampling
    • Feature MLP
    • Rendering
    • VM decomposition
      • Basic implementation
      • Vectorized
    • Dataloading
      • Blender
      • nerfstudio
        • Basics
        • Fisheye support
        • Compute samples without undistorting images (throws away a lot of
          pixels)
    • Tricks for real data
      • Scene contraction (~mip-NeRF 360)
      • Camera embeddings
  • Training
    • Learning rate scheduler
      • ADAM + grouped LR
      • Exponential decay
      • Reset decay after upsampling
    • Running
    • Checkpointing
    • Logging
      • Loss
      • PSNR
      • Test metrics
      • Test images
      • Render previews
    • Ray filtering
    • Bounding box refinement
    • Incremental upsampling
    • Regularization terms
  • Performance
    • Weight thresholding for computing appearance features
      • per ray top-k
      • global top-k (bad & deleted)
    • Mixed-precision
      • implemented
      • stable
    • Multi-GPU (should be quick)
  • Rendering
    • RGB
    • Depth (median)
    • Depth (mean)
    • Batching
    • Generate some GIFs
  • Misc engineering
    • Actions
    • Understand vmap performance differences
      (details)