
Defining Statistical Models in Jax?

(statmodeling.stat.columbia.edu)
106 points hackandthink | 2 comments
1. gnulinux No.41871680
Reading this post and reviewing the NumPyro/Pyro documentation, I don't think I follow the crucial difference between NumPyro and Pyro. I understand that Pyro uses PyTorch as its backend and NumPyro uses JAX, but beyond that I'm not sure what the critical differences are. If their frontends are about the same (which seems to be the case here), why is JAX mentioned in this post? Could we simply not replace Pyro with Stan for statistical modelling (whether with PyTorch or JAX backend)?
replies(1): >>41871844 #
2. nextos No.41871844
> Could we simply not replace Pyro with Stan for statistical modelling (whether with PyTorch or JAX backend)?

Stan has a fantastic NUTS Monte Carlo implementation. Pyro & NumPyro are more focused on variational inference. For a third alternative that IMHO doesn't get the attention it deserves, take a look at Infer.NET, which excels at expectation propagation and uses factor graphs underneath. These three offer very different tradeoffs.

Stan is less expressive than Pyro/NumPyro. But for the models it can deal with (generally medium-sized multi-level models), I find it extremely easy to work with. In other words, it's much easier to diagnose model and sampling issues.
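
To make "multi-level model" concrete, here is a minimal sketch in plain JAX (not Stan, Pyro, or NumPyro code) of the kind of object a gradient-based sampler such as NUTS consumes: a log joint density plus its gradient. Everything in it is an assumption made for illustration: the toy data, the priors, the fixed unit noise scale, and the names `log_joint`, `params`, `group` and `y`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Toy data, made up for illustration: 3 groups with 2 observations each.
group = jnp.array([0, 0, 1, 1, 2, 2])
y = jnp.array([1.2, 0.8, -0.3, 0.1, 2.0, 1.7])


def log_joint(params):
    """Unnormalized log posterior of a varying-intercept model.

    The group-level scale is parameterized as log_tau, with the prior
    placed directly on log_tau, so every parameter is unconstrained.
    """
    mu = params["mu"]
    log_tau = params["log_tau"]
    alpha = params["alpha"]
    tau = jnp.exp(log_tau)

    lp = norm.logpdf(mu, 0.0, 5.0)                    # prior on the population mean
    lp += norm.logpdf(log_tau, 0.0, 1.0)              # prior on the log group-level scale
    lp += jnp.sum(norm.logpdf(alpha, mu, tau))        # group intercepts around mu
    lp += jnp.sum(norm.logpdf(y, alpha[group], 1.0))  # likelihood with unit noise scale
    return lp


# Starting point: all parameters at zero (so tau = 1).
params = {
    "mu": jnp.array(0.0),
    "log_tau": jnp.array(0.0),
    "alpha": jnp.zeros(3),
}

# One jitted call returns the log density and its gradient for every parameter.
value, grads = jax.jit(jax.value_and_grad(log_joint))(params)
```

`grads` is a dict with the same keys and shapes as `params`. Stan derives the equivalent function and gradient from a model block, and NumPyro builds it from a Python model function, so the differences discussed above are mostly about the frontend and the inference algorithms layered on top.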