
Defining Statistical Models in Jax?

(statmodeling.stat.columbia.edu)
106 points | hackandthink | 1 comment
JHonaker No.41869481
I'm very excited by the work being put in to make Bayesian inference more manageable. It's in a spot that feels very similar to deep learning circa mid-2010s when Caffe, Torch, and hand-written gradients were the options. We can do it, but doing anything more complicated than common model structures like hierarchical Gaussian linear models requires dropping out of the nice places and into the guts.
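
To make the "hand-written" era concrete: before PPLs, even a simple conjugate model meant coding the sampler yourself. This is an illustrative sketch (not from the thread) of a hand-rolled random-walk Metropolis sampler for a toy normal-normal model, chosen to be conjugate so the answer is checkable against the analytic posterior.

```python
import numpy as np

# Conjugate toy problem so the result is checkable:
# y_i ~ N(mu, 1), mu ~ N(0, 1)  =>  posterior is N(sum(y)/(n+1), 1/(n+1)).
rng = np.random.default_rng(0)
y = rng.normal(2.0, 1.0, size=50)

def log_post(mu):
    # Unnormalized log posterior: log prior + log likelihood.
    return -0.5 * mu**2 - 0.5 * np.sum((y - mu) ** 2)

# Random-walk Metropolis: propose a jitter, accept with prob min(1, p'/p).
mu, lp, samples = 0.0, log_post(0.0), []
for _ in range(20000):
    prop = mu + rng.normal(0.0, 0.5)
    lp_prop = log_post(prop)
    if np.log(rng.uniform()) < lp_prop - lp:
        mu, lp = prop, lp_prop
    samples.append(mu)

post_mean = np.mean(samples[5000:])          # discard burn-in
exact_mean = y.sum() / (len(y) + 1)          # analytic posterior mean
```

Even this minimal version forces you to handle tuning (the 0.5 proposal scale), burn-in, and convergence by hand, which is exactly the bookkeeping a PPL hides.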

I've had a lot of success with Numpyro (a JAX library), and used quite a lot of tools that are simpler interfaces to Stan. I've also had to write quite a few model-specific things from scratch by hand (more for sequential Monte Carlo than MCMC). I'm very excited for a world where PPLs become scalable and easier to use/customize.
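
The "model-specific sequential Monte Carlo" code the comment mentions typically looks something like this illustrative bootstrap particle filter (my sketch, not the commenter's code), written in plain numpy for a toy linear-Gaussian state-space model:

```python
import numpy as np

# Toy state-space model (assumed for illustration):
#   x_t = 0.9 * x_{t-1} + N(0, 0.5^2)   (latent state)
#   y_t = x_t + N(0, 0.5^2)             (observation)
rng = np.random.default_rng(1)
T, N = 100, 1000
x_true = np.zeros(T)
for t in range(1, T):
    x_true[t] = 0.9 * x_true[t - 1] + rng.normal(0, 0.5)
y = x_true + rng.normal(0, 0.5, size=T)

# Bootstrap particle filter: propagate through the dynamics, weight by the
# observation likelihood, resample to avoid weight degeneracy.
particles = rng.normal(0, 1, size=N)
filtered = np.zeros(T)
for t in range(T):
    particles = 0.9 * particles + rng.normal(0, 0.5, size=N)
    logw = -0.5 * ((y[t] - particles) / 0.5) ** 2
    w = np.exp(logw - logw.max())
    w /= w.sum()
    filtered[t] = np.sum(w * particles)   # posterior mean estimate at t
    particles = particles[rng.choice(N, size=N, p=w)]  # multinomial resampling
```

Every piece here (proposal, weighting, resampling scheme) is a model-specific decision, which is why SMC has historically meant dropping out of the high-level libraries.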

> I think there is a good chance that normalizing flow-based variational inference will displace MCMC as the go-to method for Bayesian posterior inference as soon as everyone gets access to good GPUs.

Wow. This is incredibly surprising. I'm only tangentially aware of normalizing flows, but apparently I need to look at the intersection of them and Bayesian statistics now! Any sources from anyone would be most appreciated!
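
For anyone else catching up: the core idea behind normalizing-flow VI is just the change-of-variables formula. If x = f(z) with f invertible, then log p_x(x) = log p_z(f⁻¹(x)) − log|det ∂f/∂z|, so a flexible invertible map turns a simple base density into a tractable approximate posterior. A minimal numerical check with a (hypothetical) 1-D affine flow, where the pushforward of N(0, 1) is known exactly:

```python
import numpy as np

# Affine flow x = a*z + b applied to z ~ N(0, 1). The pushforward density
# is exactly N(b, a^2), so the change-of-variables formula is checkable.
a, b = 2.0, 1.0

def log_pz(z):
    # Standard normal log-density (the base distribution).
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi)

def flow_logpdf(x):
    z = (x - b) / a                      # inverse transform f^{-1}(x)
    return log_pz(z) - np.log(abs(a))    # minus log|Jacobian| of f

def normal_logpdf(x, mu, sigma):
    # Reference: known N(mu, sigma^2) log-density.
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

xs = np.linspace(-4.0, 6.0, 11)
err = np.max(np.abs(flow_logpdf(xs) - normal_logpdf(xs, b, a)))
```

Real flows stack many learnable invertible layers in place of this single affine map; the VI objective then maximizes the ELBO through the flow's parameters, which is where GPUs and JAX-style autodiff pay off.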

replies(3): >>41869632 #>>41869638 #>>41874672 #
1. szvsw No.41874672
> make Bayesian inference more manageable

Discovering PyMC and the excellent accompanying textbook was game-changing for me! Being able to write full hierarchical models in a handful of lines of code, hooked up to pandas DataFrames, is so wonderful.

The more tools for this the better!