BiT-jax2tf

Apache-2.0 License

This repository hosts the code to port NumPy model weights of BiT-ResNets [1] to the TensorFlow SavedModel format. These models are the results of [2]. The original model weights come from [3].

Huge thanks to Willi Gierke (of Google) for helping with the porting.

The TensorFlow SavedModels are available on TensorFlow Hub as a collection: https://tfhub.dev/sayakpaul/collections/bit-resnet/1. A total of 8 models are available, a classifier and a feature extractor for each of the four configurations below:

| Model Name | Input Resolution | Classifier | Feature Extractor |
| --- | --- | --- | --- |
| BiT-ResNet152x2 | 384 | Link | Link |
| BiT-ResNet152x2 | 224 | Link | Link |
| BiT-ResNet50x1 | 224 | Link | Link |
| BiT-ResNet50x1 | 160 | Link | Link |

You can use the convert_jax_weights_tf.ipynb notebook to understand how model porting works between JAX and TensorFlow. There is also an experimental tool from the JAX team called jax2tf.
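
To give a rough feel for what porting NumPy weights involves, here is a minimal NumPy-only sketch. It is not the code from the notebook. It assumes the source checkpoint is a flat name-to-array mapping (as in an `.npz` file) and that the target model's variable names and shapes are known. The `port_weights` helper, the parameter names, and the name mapping are all hypothetical and chosen only for illustration.

```python
import io

import numpy as np


def port_weights(source, name_map, target_shapes):
    """Map source arrays onto target variable names, checking shapes.

    source:        dict of source parameter name -> NumPy array
    name_map:      dict of source name -> target variable name (hypothetical)
    target_shapes: dict of target variable name -> expected shape
    """
    ported = {}
    for src_name, tgt_name in name_map.items():
        if src_name not in source:
            raise KeyError(f"missing source parameter: {src_name}")
        value = np.asarray(source[src_name])
        expected = tuple(target_shapes[tgt_name])
        if value.shape != expected:
            raise ValueError(
                f"{src_name} -> {tgt_name}: shape {value.shape} != {expected}"
            )
        ported[tgt_name] = value
    # Fail loudly if any target variable was never filled.
    missing = sorted(set(target_shapes) - set(ported))
    if missing:
        raise ValueError(f"unfilled target variables: {missing}")
    return ported


# Toy in-memory checkpoint standing in for a real .npz weights file.
# All names and shapes below are made up for this example.
buffer = io.BytesIO()
np.savez(buffer, **{
    "stem/conv/kernel": np.zeros((7, 7, 3, 64), dtype=np.float32),
    "head/dense/kernel": np.ones((2048, 10), dtype=np.float32),
})
buffer.seek(0)
with np.load(buffer) as npz:
    checkpoint = {name: npz[name] for name in npz.files}

name_map = {
    "stem/conv/kernel": "conv_root/kernel:0",
    "head/dense/kernel": "classifier/kernel:0",
}
target_shapes = {
    "conv_root/kernel:0": (7, 7, 3, 64),
    "classifier/kernel:0": (2048, 10),
}
ported = port_weights(checkpoint, name_map, target_shapes)
```

In a TensorFlow model, each ported array would then typically be copied into the matching variable with `tf.Variable.assign` before the model is exported as a SavedModel. Checking shapes up front makes a wrong name mapping fail early instead of producing a silently broken model.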

References

[1] Big Transfer (BiT): General Visual Representation Learning by Kolesnikov et al.

[2] Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

[3] BiT GitHub
