This project aims to build a cookbook for neural network training and lightweighting, covering three kinds of lightweighting solutions: knowledge distillation, filter pruning, and quantization.
Knowledge distillation | Filter pruning
```bash
# ResNet-56 on CIFAR10
python train.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --train_path ~/test
python test.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --trained_param pretrained/res56_c10
```
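For reference, here is a minimal sketch of how flags like these could be declared with `argparse`. This is an assumption for illustration only, not the repository's actual `train.py`:

```python
import argparse

# Hypothetical flag declarations mirroring the commands above;
# not the repository's actual argument parser.
parser = argparse.ArgumentParser(description="Train a model (sketch)")
parser.add_argument("--gpu_id", type=int, default=0)
parser.add_argument("--arch", type=str, default="ResNet-56")
parser.add_argument("--dataset", type=str, default="CIFAR10")
parser.add_argument("--train_path", type=str, required=True)
args = parser.parse_args()
print(args)
```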
- Hardware: GTX 1080 Ti
- TensorFlow implementation [link]
- PyTorch implementation [link]
To measure training time alone, excluding model and data preparation, time is measured from the second epoch through the last epoch.
Note that accuracy on the CIFAR datasets has a fairly large variance, so focus on the other metric, i.e., training time.
As you can see, JAX and TF are much faster than PyTorch because of JIT compilation (a minimal sketch follows the table).
Library | Accuracy (%) | Time (min) |
---|---|---|
JAX | 93.98 | 54 |
TF | 93.91 | 53 |
PyTorch | 93.80 | 69 |
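Why the first epoch is excluded from timing, and why JIT-compiled frameworks win: the first call to a JIT-compiled step pays the compilation cost, so starting the clock at the second epoch measures steady-state speed. Below is a minimal, self-contained sketch; the toy linear-regression step, learning rate, and epoch count are assumptions for illustration, not this repository's training loop:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit  # compiled by XLA on the first call; later calls reuse the compiled code
def train_step(w, x, y):
    pred = x @ w
    # One SGD step on mean-squared error (assumed toy objective).
    return w - 0.01 * (2.0 / x.shape[0]) * x.T @ (pred - y)

w = jnp.zeros((8, 1))
x, y = jnp.ones((32, 8)), jnp.ones((32, 1))

start = None
for epoch in range(5):  # hypothetical epoch count
    if epoch == 1:
        # Start the clock at the second epoch so JIT warm-up is excluded,
        # mirroring how the table above was measured.
        start = time.time()
    w = train_step(w, x, y)
jax.block_until_ready(w)  # wait for async dispatch before reading the clock
print(f"steady-state time: {time.time() - start:.4f}s")
```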
- Basic training and test framework
- Knowledge distillation framework (minimal loss sketch below)
- Filter pruning framework (minimal criterion sketch below)
- Quantization framework (minimal quantizer sketch below)
- Tools for handy usage
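The sketches below illustrate the three lightweighting techniques named in the list above. They are assumptions for illustration, not the repository's implementations; every function name and hyperparameter (`temperature`, `keep_ratio`, `num_bits`) is made up for the example.

Knowledge distillation (Hinton et al., 2015): the student is trained against the teacher's temperature-softened output distribution. The loss below is the cross-entropy between the two, which equals the KL divergence up to a constant that does not affect gradients:

```python
import jax.numpy as jnp
from jax.nn import log_softmax, softmax

def kd_loss(student_logits, teacher_logits, temperature=4.0):
    # Soften both distributions with temperature T, then penalize the
    # student for diverging from the teacher; the T^2 factor keeps
    # gradient magnitudes comparable across temperatures.
    t = temperature
    p_teacher = softmax(teacher_logits / t, axis=-1)
    log_p_student = log_softmax(student_logits / t, axis=-1)
    return -(t * t) * jnp.mean(jnp.sum(p_teacher * log_p_student, axis=-1))
```

Filter pruning: one common criterion (Li et al., 2017) scores each convolution filter by its L1 norm and drops the lowest-scoring filters:

```python
import jax.numpy as jnp

def l1_prune_mask(conv_weight, keep_ratio=0.7):
    # conv_weight: (out_ch, in_ch, kh, kw). Keep the keep_ratio fraction
    # of output filters with the largest L1 norm; ties may keep extras.
    scores = jnp.sum(jnp.abs(conv_weight), axis=(1, 2, 3))
    k = max(1, int(keep_ratio * scores.shape[0]))
    threshold = jnp.sort(scores)[-k]  # k-th largest score
    return scores >= threshold  # boolean mask over output filters
```

Quantization: a uniform affine quantize-dequantize ("fake quantization") of the kind used in quantization-aware training:

```python
import jax.numpy as jnp

def fake_quantize(x, num_bits=8):
    # Map x onto 2**num_bits integer levels and back to float,
    # simulating quantization error while staying in floating point.
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (jnp.max(x) - jnp.min(x)) / (qmax - qmin)
    zero_point = jnp.round(-jnp.min(x) / scale)
    q = jnp.clip(jnp.round(x / scale) + zero_point, qmin, qmax)
    return scale * (q - zero_point)
```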