A light-weight implementation of the ICCV 2023 paper "Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement," with pretrained checkpoints and reinforced datasets.

Faghri, F., Pouransari, H., Mehta, S., Farajtabar, M., Farhadi, A., Rastegari, M., & Tuzel, O. Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement. Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
Update 2023/09/22: the "Average" column of Table 7 has been corrected in arXiv v3. Correct numbers: 30.4, 37.1, 37.9, 43.7, 39.6, 51.1.
Reinforced ImageNet (ImageNet+) improves accuracy at similar iterations/wall-clock time.

ImageNet validation accuracy of ResNet-50 as a function of training duration with (1) the ImageNet dataset, (2) knowledge distillation (KD), and (3) the ImageNet+ dataset (ours). Each point is a full training run with epochs varying from 50 to 1000. An epoch has the same number of iterations for ImageNet and ImageNet+.
Illustration of Dataset Reinforcement
Data augmentation and knowledge distillation are common approaches to improving accuracy. Dataset reinforcement combines the benefits of both, bringing the advantages of large models trained on large datasets to other datasets and models. Training a new model on a reinforced dataset is as fast as training on the original dataset for the same total iterations. Creating a reinforced dataset is a one-time process (e.g., ImageNet to ImageNet+) whose cost is amortized over repeated uses.
Install the requirements using:

```shell
pip install -r requirements.txt
```
We support loading models from the timm and CVNets libraries. To install CVNets, follow their installation instructions.
The following is a list of reinforcements for ImageNet/CIFAR-100/Food-101/Flowers-102. We recommend ImageNet+-RA/RE based on the analysis in the paper. In the comments, NS is the number of reinforcement samples stored per image, and K is the number of top teacher probabilities kept per sample.
Reinforced Dataset | Task ID | Size (GB) | Comments |
---|---|---|---|
ImageNet+-RRC | rdata | 33.4 | [NS=400] |
ImageNet+-+M* | rdata | 46.3 | [NS=400] |
ImageNet+-+RA/RE | rdata | 37.5 | [NS=400] |
ImageNet+-+M*+R* | rdata | 53.3 | [NS=400] |
ImageNet+-RRC-Small | rdata | 4.7 | [NS=100, K=5] |
ImageNet+-+M*-Small | rdata | 7.8 | [NS=100, K=5] |
ImageNet+-+RA/RE-Small | rdata | 5.6 | [NS=100, K=5] |
ImageNet+-+M*+R*-Small | rdata | 9.4 | [NS=100, K=5] |
ImageNet+-RRC-Mini | rdata | 4.4 | [NS=50] |
ImageNet+-+M*-Mini | rdata | 6.1 | [NS=50] |
ImageNet+-+RA/RE-Mini | rdata | 4.9 | [NS=50] |
ImageNet+-+M*+R*-Mini | rdata | 7.0 | [NS=50] |
CIFAR-100 | rdata | 2.5 | [NS=800] |
Food-101 | rdata | 4.2 | [NS=800] |
Flowers-102 | rdata | 0.5 | [NS=8000] |
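The `K=5` entries above store only the teacher's top-K class probabilities per sample, which keeps the reinforced dataset small. A dense soft-label vector can be rebuilt from such sparse storage roughly as follows (an illustrative sketch, not the repository's exact on-disk format):

```python
import torch

def dense_from_topk(indices, values, num_classes=1000):
    """Rebuild a dense soft-label vector from stored top-K (index, prob) pairs.

    The probability mass not covered by the top-K entries is spread
    uniformly over the remaining classes (one simple choice; the actual
    storage format may handle the remainder differently).
    """
    dense = torch.zeros(num_classes)
    dense[indices] = values
    remainder = 1.0 - values.sum()
    mask = torch.ones(num_classes, dtype=torch.bool)
    mask[indices] = False
    dense[mask] = remainder / mask.sum()
    return dense

# Example: top-5 teacher probabilities stored for one reinforced sample.
idx = torch.tensor([3, 17, 42, 101, 999])
val = torch.tensor([0.6, 0.2, 0.1, 0.05, 0.02])
soft = dense_from_topk(idx, val)
assert abs(soft.sum().item() - 1.0) < 1e-4  # a valid probability vector
```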
We provide pretrained checkpoints for various models in CVNets. The accuracies can be verified using the CVNets library.
Selected results trained for 1000 epochs:
Name | Mode | Params | ImageNet | ImageNet+ | ImageNet (EMA) | ImageNet+ (EMA) | Links |
---|---|---|---|---|---|---|---|
MobileNetV3 | large | 5.5M | 74.8 | 77.9 (+3.1) | 75.8 | 77.9 (+2.1) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
ResNet | 50 | 25.6M | 80.0 | 82.0 (+2.0) | 80.1 | 82.0 (+1.9) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
ViT | base | 86.7M | 76.8 | 85.1 (+8.3) | 80.8 | 85.1 (+4.3) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
ViT-384 | base | 86.7M | 79.4 | 85.4 (+6.0) | 83.1 | 85.5 (+2.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin | tiny | 28.3M | 81.3 | 84.0 (+2.7) | 80.5 | 83.5 (+3.0) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin | small | 49.7M | 81.3 | 85.0 (+3.7) | 81.9 | 84.5 (+2.6) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin | base | 87.8M | 81.5 | 85.4 (+3.9) | 81.8 | 85.2 (+3.4) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
Swin-384 | base | 87.8M | 83.6 | 85.8 (+2.2) | 83.8 | 85.5 (+1.7) | [best.pt] [ema_best.pt] [config.yaml] [metrics.jb] |
We provide pretrained checkpoints for ResNet50d from the timm library, trained for 150 epochs using various reinforced datasets:
Model | Reinforced Dataset | Accuracy | Links |
---|---|---|---|
ResNet50d [ERM] | N/A | 78.9 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-RRC | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M* | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+RA/RE | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*+R* | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-RRC-Small | 80.0 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*-Small | 80.6 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+RA/RE-Small | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*+R*-Small | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-RRC-Mini | 80.1 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*-Mini | 80.5 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+RA/RE-Mini | 80.4 | [best.pt] [config.yaml] [metrics.jb] |
ResNet50d | ImageNet+-+M*+R*-Mini | 80.2 | [best.pt] [config.yaml] [metrics.jb] |
We provide YAML configurations for training ResNet-50 at `CFG_FILE=configs/${DATASET}/${TRAINER}.yaml`, with the following options:

- `DATASET`: `imagenet`, `cifar100`, `flowers102`, and `food101`.
- `TRAINER`: standard training (`erm`), knowledge distillation (`kd`), and training with reinforced data (`plus`).

Follow these steps:

1. Set `data_path` in `$CFG_FILE`.
2. Set `reinforce.data_path` in `$CFG_FILE`.
3. Run one of:

```shell
python train.py --config configs/imagenet/erm.yaml   # ImageNet training without reinforcements (ERM)
python train.py --config configs/imagenet/kd.yaml    # Knowledge distillation
python train.py --config configs/imagenet/plus.yaml  # ImageNet+ training with reinforcements
```
Hyperparameters such as batch size for ImageNet training are optimized for running on a single node with 8xA100 40GB GPUs. For CIFAR-100/Flowers-102/Food-101, the configurations are optimized for training on a single GPU.
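Conceptually, a `plus` training step replaces the one-hot classification loss with a cross-entropy against the stored teacher probabilities. A minimal sketch (function and variable names are illustrative, not the repository's API):

```python
import torch
import torch.nn.functional as F

def reinforced_loss(student_logits, teacher_probs):
    """Cross-entropy between student predictions and stored soft labels.

    Equal to the KL divergence up to the (constant) teacher entropy, so
    minimizing it matches the student to the stored teacher outputs.
    """
    log_p = F.log_softmax(student_logits, dim=-1)
    return -(teacher_probs * log_p).sum(dim=-1).mean()

# Example: a batch of 4 samples over 1000 classes.
logits = torch.randn(4, 1000, requires_grad=True)
probs = torch.softmax(torch.randn(4, 1000), dim=-1)  # stored reinforcement targets
loss = reinforced_loss(logits, probs)
loss.backward()  # gradients flow to the student exactly as in standard training
```

Because the targets are precomputed, this step costs the same as standard training; no teacher forward pass happens at training time.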
To generate a new reinforced dataset, follow these steps:

1. Set `data_path` in `$CFG_FILE`.
2. Optionally, change the teacher model in `$CFG_FILE` to a smaller architecture to reduce generation cost.
3. Run:

```shell
python reinforce.py --config configs/imagenet/reinforce/randaug.yaml
```
If you find this code useful, please cite the following paper:

```bibtex
@InProceedings{faghri2023reinforce,
  author    = {Faghri, Fartash and Pouransari, Hadi and Mehta, Sachin and Farajtabar, Mehrdad and Farhadi, Ali and Rastegari, Mohammad and Tuzel, Oncel},
  title     = {Reinforce Data, Multiply Impact: Improved Model Accuracy and Robustness with Dataset Reinforcement},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  month     = {October},
  year      = {2023},
}
```
This sample code is released under the LICENSE terms.