Implementation of CaiT models in TensorFlow and ImageNet-1k checkpoints. Includes code for inference and fine-tuning.
Apache-2.0 License
This repository provides TensorFlow / Keras implementations of different CaiT [1] variants from Touvron et al. It also provides the TensorFlow / Keras models that have been populated with the original CaiT pre-trained params available from [2]. These models are not blackbox SavedModels, i.e., they can be fully expanded into `tf.keras.Model` objects and one can call all the utility functions on them (example: `.summary()`).
As of today, all the TensorFlow / Keras variants of the CaiT models listed here are available in this repository.
Refer to the "Using the models" section to get started.
Update (Oct 8, 2022): This project received the Kaggle ML Research Spotlight Prize (September 2022).
Update (Sept 25, 2022): Blog post on CaiT.
TensorFlow / Keras implementations are available in `cait/models.py`. Conversion utilities are in `convert.py`.
Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/cait/1. You can fully inspect the architecture of the TF-Hub models like so:
```python
import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
model.summary(expand_nested=True)
```
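As a quick sanity check, here is a hedged inference sketch on a single image using the `model` loaded above. The ImageNet mean / std normalization and the `elephant.jpg` file are assumptions made for illustration; verify the exact preprocessing against the evaluation code in `i1k_eval`.

```python
import numpy as np
import tensorflow as tf

# Hedged inference sketch. The ImageNet mean / std normalization below is an
# assumption -- verify it against the preprocessing used in `i1k_eval`.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype="float32")
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype="float32")

def preprocess(image_path, image_size=224):
    image = tf.io.decode_image(
        tf.io.read_file(image_path), channels=3, expand_animations=False
    )
    image = tf.image.resize(image, (image_size, image_size))
    image = (image / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    return image[None, ...]  # Add a batch dimension.

# `elephant.jpg` is a hypothetical local image file.
outputs = model(preprocess("elephant.jpg"), training=False)
logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
print("Predicted ImageNet-1k class index:", int(np.argmax(logits, axis=-1)[0]))
```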
Results are on the ImageNet-1k validation set (top-1 and top-5 accuracies).
| model_name | top1_acc(%) | top5_acc(%) |
| --- | --- | --- |
| cait_s24_224 | 83.368 | 96.576 |
| cait_xxs24_224 | 78.524 | 94.212 |
| cait_xxs36_224 | 79.76 | 94.876 |
| cait_xxs36_384 | 81.976 | 96.064 |
| cait_xxs24_384 | 80.648 | 95.516 |
| cait_xs24_384 | 83.738 | 96.756 |
| cait_s24_384 | 84.944 | 97.212 |
| cait_s36_384 | 85.192 | 97.372 |
| cait_m36_384 | 85.924 | 97.598 |
| cait_m48_448 | 86.066 | 97.590 |
Results can be verified with the code in `i1k_eval`. They are in line with [1]. Slight differences in the results stem from the fact that I used a different set of augmentation transformations; the original transformations suggested by the authors can be found here.
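For reference, below is a hedged sketch of the kind of top-1 / top-5 computation involved, not the exact `i1k_eval` script. It assumes `model` is the loaded model from above and that `val_dataset` yields batches of preprocessed ImageNet-1k validation images with integer labels.

```python
import tensorflow as tf

# Hedged evaluation sketch; `val_dataset` is assumed to yield
# (preprocessed_images, integer_labels) batches from the ImageNet-1k
# validation set, and `model` is the loaded model from the snippet above.
top1 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name="top1")
top5 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top5")

for images, labels in val_dataset:
    outputs = model(images, training=False)
    logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
    top1.update_state(labels, logits)
    top5.update_state(labels, logits)

print(f"top1: {float(top1.result()):.3%}, top5: {float(top5.result()):.3%}")
```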
Pre-trained models:
These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to visualize the attention maps for a given image (following figures 6 and 7 of the original paper).
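As a rough illustration (not the notebook's code), the sketch below probes the outputs of the `model` loaded in the TF-Hub snippet above. That the attention weights come back alongside the logits is an assumption here; check the notebook for the exact output structure.

```python
# Minimal sketch using the `model` loaded in the TF-Hub snippet above. It
# assumes the forward pass may return attention weights alongside the class
# logits; see the notebook for the exact output structure and the actual
# attention-map visualization.
outputs = model(tf.ones((1, 224, 224, 3)), training=False)
if isinstance(outputs, (tuple, list)):
    logits, *attention_outputs = outputs
    print("Logits shape:", logits.shape)
    print("Additional (assumed attention) outputs:", len(attention_outputs))
else:
    print("Single output of shape:", outputs.shape)
```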
Example visualizations: original image, class attention maps, and class saliency map.
For the best quality, refer to the assets directory. You can also generate these plots using the interactive demos on Hugging Face Spaces.
Randomly initialized models:
```python
from cait.model_configs import base_config
from cait.models import CaiT
import tensorflow as tf

config = base_config.get_config(model_name="cait_xxs24_224")
cait_xxs24_224 = CaiT(config)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = cait_xxs24_224(dummy_inputs)
cait_xxs24_224.summary(expand_nested=True)
```
To initialize a network with, say, 5 classes, do:
```python
config = base_config.get_config(model_name="cait_xxs24_224")

with config.unlocked():
    config.num_classes = 5

cait_xxs24_224 = CaiT(config)
```
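To then fine-tune such a model on those 5 classes, a hedged sketch follows. The wrapping step (in case the forward pass also returns attention maps), the optimizer settings, and `train_dataset` (batches of 224x224 images with integer labels in [0, 5)) are all assumptions rather than the authors' recipe.

```python
import tensorflow as tf

# Hedged fine-tuning sketch, not the authors' training recipe. The forward
# pass is assumed to possibly return attention maps alongside the logits, so
# the model is wrapped to expose only the logits; `train_dataset` is assumed
# to yield (224x224 image, integer label) batches for the 5 classes.
inputs = tf.keras.Input((224, 224, 3))
outputs = cait_xxs24_224(inputs)
logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
classifier = tf.keras.Model(inputs, logits)

classifier.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
classifier.fit(train_dataset, epochs=5)
```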
To view different model configurations, refer to `convert_all_models.py`.
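For instance, the variant names from the results table above can be passed to `base_config.get_config` to inspect each configuration (a small sketch, assuming those names are the valid `model_name` strings):

```python
from cait.model_configs import base_config

# Print the configuration of a few variants; the names are assumed to match
# the `model_name` column of the results table above.
for name in ["cait_xxs24_224", "cait_xs24_384", "cait_m48_448"]:
    config = base_config.get_config(model_name=name)
    print(name)
    print(config)
```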
[1] CaiT paper: https://arxiv.org/abs/2103.17239
[2] Official CaiT code: https://github.com/facebookresearch/deit
timm library source code