JAX port of Persistent Independent Particles (PIPs), a model for tracking point trajectories, described in:
```bibtex
@inproceedings{harley2022particle,
  title={Particle Video Revisited: Tracking Through Occlusions Using Point Trajectories},
  author={Adam W Harley and Zhaoyuan Fang and Katerina Fragkiadaki},
  booktitle={ECCV},
  year={2022}
}
```
We currently include:
Clone and install (you may want to install JAX with GPU support first):
```bash
git clone https://github.com/brentyi/pips-jax.git
cd pips-jax
pip install -e .
```
Reassemble the reference checkpoint:

```bash
# Full checkpoints exceed GitHub's maximum file size, so we split the reference
# checkpoint into several parts.
cat checkpoints/reference_model/checkpoint_200000.* > checkpoints/reference_model/checkpoint_200000
```
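On systems without `cat`, or to reassemble the parts from Python, the same concatenation can be sketched with the standard library. This assumes the parts sort lexicographically into the right order, which is also what the shell glob in the `cat` command relies on (true for suffixes such as `.aa`, `.ab`, ... produced by `split`; the actual part names in the repository may differ). `join_checkpoint_parts` is a hypothetical helper, not part of this repository:

```python
import glob
import shutil


def join_checkpoint_parts(output_path: str) -> int:
    """Concatenate `<output_path>.*` into `output_path`; return the part count.

    Hypothetical helper mirroring `cat <output_path>.* > <output_path>`.
    Assumes lexicographic order of the part names is the correct order.
    """
    # glob.escape guards against glob metacharacters in the directory name.
    parts = sorted(glob.glob(glob.escape(output_path) + ".*"))
    if not parts:
        raise FileNotFoundError(f"no parts matching {output_path}.*")
    with open(output_path, "wb") as out:
        for part in parts:
            with open(part, "rb") as src:
                # Stream each part so a large checkpoint never sits fully in memory.
                shutil.copyfileobj(src, out)
    return len(parts)


if __name__ == "__main__":
    n = join_checkpoint_parts("checkpoints/reference_model/checkpoint_200000")
    print(f"joined {n} parts")
```

The output file itself does not match the `.*` pattern, so rerunning the helper overwrites it rather than appending it to itself.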
Runnable scripts:
`python convert_checkpoint.py --help`: Converts the reference PIPs PyTorch checkpoint for use in Flax.
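A core step in any PyTorch-to-Flax conversion is reordering weight axes, because the two libraries store kernels in different layouts: PyTorch `nn.Linear` weights are `(out, in)` while Flax `nn.Dense` kernels are `(in, out)`, and PyTorch convolution weights are `(out, in, *spatial)` while Flax `nn.Conv` kernels are `(*spatial, in, out)`. The sketch below shows only that layout change, assuming NumPy arrays taken from a PyTorch `state_dict`. `torch_weight_to_flax_kernel` is a hypothetical name, this is not necessarily how `convert_checkpoint.py` works, and it ignores parameter renaming (e.g. `weight` to `kernel`) as well as grouped and transposed convolutions:

```python
import numpy as np


def torch_weight_to_flax_kernel(weight: np.ndarray) -> np.ndarray:
    """Reorder a PyTorch weight array into Flax's kernel layout (sketch).

    - 1D (biases, norm scales): unchanged.
    - nn.Linear (out, in)            -> nn.Dense kernel (in, out)
    - nn.ConvNd (out, in, *spatial)  -> nn.Conv kernel (*spatial, in, out)
    """
    if weight.ndim == 1:
        return weight
    if weight.ndim == 2:
        return weight.T
    # Spatial axes first, then input channels, then output channels.
    spatial_axes = tuple(range(2, weight.ndim))
    return np.transpose(weight, spatial_axes + (1, 0))


if __name__ == "__main__":
    conv_weight = np.zeros((64, 3, 7, 7), dtype=np.float32)  # PyTorch Conv2d: (out, in, kH, kW)
    print(torch_weight_to_flax_kernel(conv_weight).shape)  # (7, 7, 3, 64)
```

Handling any number of spatial axes with one transpose covers 1D, 2D, and 3D convolutions with the same code path.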
`python demo.py --help`: Loose reproduction of the original PIPs demo script. Loads images and writes GIFs.
`python benchmark.py --help`: Benchmarks the JAX model's forward pass. Runtimes in seconds for a single forward pass[^1], compared with PyTorch:
| GPU | JAX 0.4.1 | PyTorch 1.13 | PyTorch 2.0 | PyTorch 2.0 + `torch.compile()` |
|---|---|---|---|---|
| RTX 4090 | 0.03111 ± 0.00 | 0.09892 ± 0.02 / 0.07652 ± 0.02 | 0.09922 ± 0.02 / 0.08653 ± 0.03 | (probably fast, but ran into CUDA errors!) |
| RTX 2080 Ti | 0.10610 ± 0.00 | 0.17770 ± 0.01 / 0.15659 ± 0.02 | 0.19143 ± 0.02 / 0.15634 ± 0.02 | 0.12979 ± 0.00 / 0.11968 ± 0.00 |
To generate the PyTorch timings, see this script. Each PyTorch cell has two timings: the first is for the PIPs code as released, and the second is for the PIPs code with the logic for computing `fcp` commented out, since `fcp` is only used for training and visualization.
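A minimal sketch of how per-forward-pass runtimes like these can be collected, assuming untimed warmup calls (to exclude one-time costs such as JIT compilation) followed by repeated timed calls summarized as a mean and standard deviation. `time_forward` is a hypothetical helper, not the code in `benchmark.py` or in the PyTorch timing script:

```python
import statistics
import time
from typing import Any, Callable, Tuple


def time_forward(
    fn: Callable[[], Any],
    sync: Callable[[Any], Any] = lambda out: out,
    warmup: int = 3,
    repeats: int = 10,
) -> Tuple[float, float]:
    """Time `fn()` and return (mean, standard deviation) in seconds per call.

    `sync` must block until the result of `fn()` is actually computed. JAX
    dispatches work asynchronously, so for a JAX model pass something that
    calls `.block_until_ready()` on the output arrays; for PyTorch on CUDA,
    call `torch.cuda.synchronize()`.
    """
    if repeats < 2:
        raise ValueError("need at least 2 repeats to compute a standard deviation")
    # Untimed warmup calls absorb one-time costs such as JIT compilation.
    for _ in range(warmup):
        sync(fn())
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn())
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations), statistics.stdev(durations)


if __name__ == "__main__":
    # Toy workload standing in for a model forward pass.
    mean, std = time_forward(lambda: sum(i * i for i in range(100_000)))
    print(f"{mean:.5f} ± {std:.2f} s")
```

Without the `sync` step, timing a JAX call would mostly measure asynchronous dispatch rather than the computation itself.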
[^1]: 8-image subsequence at 640×360, 256 points, `stride=4`, `iters=6`.
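For a sense of scale, the benchmark configuration can be turned into tensor sizes. This sketch assumes, as in the original PyTorch PIPs code, that `stride` is the spatial downsampling factor of the feature encoder and that the model outputs one (x, y) position per point per frame; `pips_workload` is a hypothetical helper:

```python
def pips_workload(num_frames: int, height: int, width: int, num_points: int, stride: int) -> dict:
    """Sizes implied by a PIPs benchmark configuration (hypothetical helper).

    Assumes `stride` is the encoder's spatial downsampling factor, so each
    frame's feature map is (height // stride, width // stride).
    """
    return {
        # Spatial size of each frame's feature map.
        "feature_map_hw": (height // stride, width // stride),
        # One (x, y) position per point per frame.
        "trajectory_shape": (num_frames, num_points, 2),
    }


# Benchmark configuration: 8 frames, 640x360 (width x height), 256 points, stride=4.
print(pips_workload(num_frames=8, height=360, width=640, num_points=256, stride=4))
# {'feature_map_hw': (90, 160), 'trajectory_shape': (8, 256, 2)}
```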