5x times slower than pytorch (if you see any easy optimizations, please let me know.)
Optimization with tf.function
tf.function looks so hard to use with models having different input shapes per iteration since TF retraces the graph for each input. As far as I see from the official examples, workarounds take the advantage of using eager mode. (Correct me if I am wrong.)