I wish JAX had everything I need to experiment with DL models built in natively, like PyTorch does. Instead there are many third-party libraries (Flax, Trax, Haiku, this one, etc.), and I have no idea which one to use. That was the case when I first played with JAX 5 years ago, and it's still the case today (even worse, it seems). This makes it a non-starter for me.