Tpuddim

A JAX implementation of Denoising Diffusion Implicit Models (DDIM) for TPUs. The network architecture is based on https://github.com/openai/guided-diffusion , and the code is compatible with its pretrained weights.
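For readers new to DDIM, the deterministic (eta = 0) DDIM sampling update can be sketched in a few lines of JAX. This is an illustrative sketch, not the code in this repository: the names `ddim_step` and `ddim_sample` and the `eps_fn(x, i)` model interface are hypothetical, and `alpha_bars` is assumed to be the cumulative products of (1 - beta) at the chosen sampling timesteps, ordered from noisiest to cleanest.

```python
import jax
import jax.numpy as jnp


def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """One deterministic (eta = 0) DDIM update from timestep t to an earlier timestep.

    Hypothetical helper for illustration only.
    x_t: current noisy sample
    eps: the model's noise prediction for x_t
    alpha_bar_t, alpha_bar_prev: cumulative alpha products at the two timesteps
    """
    # Recover the predicted clean sample x_0 from the noise prediction.
    x0_pred = (x_t - jnp.sqrt(1.0 - alpha_bar_t) * eps) / jnp.sqrt(alpha_bar_t)
    # Re-noise x_0 to the earlier timestep along the same predicted noise direction.
    return jnp.sqrt(alpha_bar_prev) * x0_pred + jnp.sqrt(1.0 - alpha_bar_prev) * eps


def ddim_sample(eps_fn, x_T, alpha_bars):
    """Run DDIM over alpha_bars, ordered from noisiest (index 0) to cleanest.

    eps_fn(x, i) is a hypothetical model wrapper returning the noise prediction
    for x at sampling step i.
    """
    x = x_T
    for i in range(len(alpha_bars) - 1):
        eps = eps_fn(x, i)
        x = ddim_step(x, eps, alpha_bars[i], alpha_bars[i + 1])
    return x
```

Ending the schedule at an `alpha_bar` of `1.0` makes the final step return the predicted clean sample directly. Because no fresh noise is injected, the same starting noise always maps to the same output, which is what allows DDIM to sample with far fewer steps than ancestral DDPM sampling.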

This is still a work in progress, though less so than before. The repository currently includes code for inference and basic training, weights for the MNIST example, an initial training run on Danbooru2019Faces, and a convenient Colab notebook for running inference with the Danbooru2019Faces model.
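As a rough picture of what "basic training" for a diffusion model involves, the standard DDPM-style noise-prediction objective is sketched below. This is an assumption about the general technique, not a description of this repository's training code: `noise_prediction_loss` and the `eps_fn(params, x_t, t)` interface are hypothetical, and guided-diffusion models can additionally learn the variance with a hybrid loss, which this sketch omits.

```python
import jax
import jax.numpy as jnp


def noise_prediction_loss(eps_fn, params, x0, alpha_bars, key):
    """Simple epsilon-prediction MSE loss (hypothetical sketch).

    eps_fn(params, x_t, t): model returning a noise prediction shaped like x_t
    x0: batch of clean samples, batch dimension first
    alpha_bars: 1-D array of cumulative alpha products, one per training timestep
    """
    t_key, n_key = jax.random.split(key)
    batch = x0.shape[0]
    # Sample one timestep per example and the Gaussian noise to add.
    t = jax.random.randint(t_key, (batch,), 0, alpha_bars.shape[0])
    eps = jax.random.normal(n_key, x0.shape)
    # Broadcast alpha_bar_t over the non-batch dimensions.
    abar = alpha_bars[t].reshape((batch,) + (1,) * (x0.ndim - 1))
    # Forward (noising) process: x_t = sqrt(abar) * x0 + sqrt(1 - abar) * eps.
    x_t = jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * eps
    # The model is trained to recover the noise that was added.
    pred = eps_fn(params, x_t, t)
    return jnp.mean((pred - eps) ** 2)
```

In a training loop the gradient would be taken with respect to `params` only, for example `jax.grad(lambda p: noise_prediction_loss(model_fn, p, x0, alpha_bars, key))(params)`, where `model_fn` is again a hypothetical model wrapper.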

There is also an experimental attempt at using a mixer network as the diffusion model, which does not quite work yet. Patches are very welcome, especially ones that address any of the items marked TODO in the code.

(Note that training on Colab will not work: this code was developed on, and the models were trained on, TRC TPUs.)

MNIST example output:

Danbooru2019Faces example output:

Acknowledgements

This work would not have been possible without a TPU access grant by the Google TPU Research Cloud.