APACHE-2.0 License
This repository hosts the code to port the NumPy model weights of BiT-ResNets [1] to the TensorFlow SavedModel format. These models come from [2], and the original model weights come from [3].
Huge thanks to Willi Gierke (of Google) for helping with the porting.
The TensorFlow SavedModels are available on TensorFlow Hub as a collection: https://tfhub.dev/sayakpaul/collections/bit-resnet/1. A total of 8 models are available: a classifier and a feature extractor for each of the four configurations below.
Model Name | Input Resolution | Classifier | Feature Extractor |
---|---|---|---|
BiT-ResNet152x2 | 384 | Link | Link |
BiT-ResNet152x2 | 224 | Link | Link |
BiT-ResNet50x1 | 224 | Link | Link |
BiT-ResNet50x1 | 160 | Link | Link |
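A minimal sketch of running one of these models follows. It assumes the `tensorflow` and `tensorflow_hub` packages are installed and that each model takes a float32 batch of shape `(batch, resolution, resolution, 3)`, with `resolution` taken from the table above and pixel values scaled to [0, 1] as the original BiT models on TensorFlow Hub expect (check each model page to confirm). The helper names `make_dummy_batch` and `run_hub_model` are hypothetical, and the model handle is a placeholder to be copied from the collection page.

```python
import numpy as np


def make_dummy_batch(resolution: int, batch_size: int = 2, seed: int = 0) -> np.ndarray:
    """Random float32 images in [0, 1) with shape (batch, resolution, resolution, 3).

    Assumption: the models expect pixel values scaled to [0, 1].
    """
    rng = np.random.default_rng(seed)
    return rng.random((batch_size, resolution, resolution, 3), dtype=np.float32)


def run_hub_model(handle: str, images: np.ndarray) -> np.ndarray:
    """Load a TF Hub model by handle and run it on a batch of images.

    Assumes the model returns a single tensor (class scores or features).
    """
    # Deferred imports so the preprocessing helper works without TensorFlow.
    import tensorflow as tf
    import tensorflow_hub as hub

    layer = hub.KerasLayer(handle)
    outputs = layer(tf.convert_to_tensor(images))
    return outputs.numpy()


if __name__ == "__main__":
    batch = make_dummy_batch(224)
    print(batch.shape, batch.dtype)
    # HANDLE is a placeholder: copy a real model URL from the TF Hub collection.
    # outputs = run_hub_model(HANDLE, batch)
```

With a classifier handle the output should hold one row of class scores per image, and with a feature-extractor handle one feature vector per image; the model page is the authority on the exact output shape.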
You can use the `convert_jax_weights_tf.ipynb` notebook to understand how model porting works between JAX and TensorFlow. The JAX team also provides an experimental tool, jax2tf (`jax.experimental.jax2tf`), for converting JAX functions to TensorFlow.
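At its core, porting a NumPy checkpoint means matching each stored array to the TensorFlow variable it belongs to, by name, and checking that the shapes agree. The sketch below illustrates that idea with NumPy only; it is not the notebook's code, and the key names (`resnet/root_block/...`) and the helpers `load_npz` and `port_weights` are hypothetical. Because Flax and Keras both store 2-D convolution kernels as `(height, width, in_channels, out_channels)`, a shape check is often enough for convolutions; a source with a different layout would need an explicit transpose.

```python
import io

import numpy as np


def load_npz(fileobj) -> dict:
    """Read every array in an .npz archive into a plain {name: array} dict."""
    with np.load(fileobj) as data:
        return {name: data[name] for name in data.files}


def port_weights(source: dict, expected_shapes: dict, prefix: str = "") -> dict:
    """Pick the source array for each target variable and verify its shape.

    `expected_shapes` maps a target variable name to the shape the TensorFlow
    model expects; the matching source key is `prefix + target_name`.
    """
    ported = {}
    for target_name, shape in expected_shapes.items():
        key = prefix + target_name
        if key not in source:
            raise KeyError(f"no source weight named {key!r}")
        array = source[key]
        if array.shape != tuple(shape):
            raise ValueError(f"{key!r}: expected {tuple(shape)}, got {array.shape}")
        ported[target_name] = array
    return ported


if __name__ == "__main__":
    # Hypothetical BiT-style checkpoint built in memory for illustration.
    buffer = io.BytesIO()
    np.savez(buffer, **{
        "resnet/root_block/standardized_conv2d/kernel": np.zeros((7, 7, 3, 64), np.float32),
        "resnet/head/conv2d/bias": np.zeros((1000,), np.float32),
    })
    buffer.seek(0)
    weights = load_npz(buffer)
    ported = port_weights(
        weights,
        {"root_block/standardized_conv2d/kernel": (7, 7, 3, 64)},
        prefix="resnet/",
    )
    print({name: arr.shape for name, arr in ported.items()})
```

In a real port, each returned array would then be assigned to its Keras variable (for example with `variable.assign(array)`) before the model is exported with `tf.saved_model.save`.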
[1] Big Transfer (BiT): General Visual Representation Learning by Kolesnikov et al.
[2] Knowledge distillation: A good teacher is patient and consistent by Beyer et al.
[3] BiT GitHub repository (google-research/big_transfer)