PyData NYC 2022

JAX for Bayes
11-09, 10:15–11:00 (America/New_York), Central Park East (6th floor)

See how to implement and fit a Bayesian model using a number of cutting-edge Python libraries. We take a tour of the composable and ergonomic ecosystem for doing Bayesian statistics with JAX: we'll show how to write a model with any of PyMC, TensorFlow Probability (TFP), distrax, NumPyro, or pure JAX (or a combination of these!), and then how to do inference on any of those models using any of PyMC, TFP, NumPyro, or BlackJAX.


JAX is a high-performance library for doing machine learning (with gradients!) on GPU/TPU. There is a thriving ecosystem of composable libraries for doing Bayesian inference with JAX. Specifically, some libraries are designed for expressing a Bayesian model and others for fitting those models, and they are easy to mix and match.
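
One common meeting point between "model" libraries and "inference" libraries is a function that maps parameters to an unnormalized log density, which JAX can then differentiate. The sketch below is not code from the talk: it assumes a logistic regression with standard normal priors and made-up synthetic data, and the name `log_density` is hypothetical. It shows a model written in pure JAX and a single `jax.value_and_grad` call returning both the log density and its gradient.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical synthetic data for a binary classification problem.
key = jax.random.PRNGKey(0)
x_key, w_key, y_key = jax.random.split(key, 3)
num_points, num_features = 100, 3
X = jax.random.normal(x_key, (num_points, num_features))
true_w = jax.random.normal(w_key, (num_features,))
y = jax.random.bernoulli(y_key, jax.nn.sigmoid(X @ true_w)).astype(jnp.float32)


def log_density(params):
    """Unnormalized log posterior of a Bayesian logistic regression.

    `params` is a pytree (here a dict) holding weights "w" and intercept "b".
    """
    w, b = params["w"], params["b"]
    # Assumed standard normal priors on every parameter.
    log_prior = norm.logpdf(w).sum() + norm.logpdf(b)
    # Bernoulli likelihood with logits X @ w + b, in a numerically stable form.
    logits = X @ w + b
    log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits)
                      + (1.0 - y) * jax.nn.log_sigmoid(-logits))
    return log_prior + log_lik


params = {"w": jnp.zeros(num_features), "b": jnp.array(0.0)}
# One call returns the log density and its gradient w.r.t. every leaf of `params`.
value, grads = jax.value_and_grad(log_density)(params)
print(value, grads["w"].shape, grads["b"].shape)
```

Because the gradient comes back with the same dict structure as `params`, a gradient-based sampler can consume it without caring which library (if any) was used to write the model.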

This talk walks through doing exactly that: implementing a regression model for classification with five different libraries, and then running inference with four libraries. It is primarily an API tour of the Bayesian ecosystem, but it also offers some intuition for the MCMC algorithms being used, and for how JAX's autodiff and vmap enable them.
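
To give a flavor of how autodiff and vmap enable MCMC, here is a minimal sketch of the Metropolis-adjusted Langevin algorithm (MALA), a gradient-based sampler chosen here for brevity; it is not necessarily one of the algorithms covered in the talk. `jax.grad` supplies the gradient the proposal needs, `jax.lax.scan` runs one chain, and `jax.vmap` runs several independent chains at once. The synthetic data, the untuned step size, and all function names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical synthetic data: 200 points, 2 features, known true weights.
key = jax.random.PRNGKey(0)
x_key, y_key, sample_key = jax.random.split(key, 3)
X = jax.random.normal(x_key, (200, 2))
true_w = jnp.array([1.5, -2.0])
y = jax.random.bernoulli(y_key, jax.nn.sigmoid(X @ true_w)).astype(jnp.float32)


def log_density(w):
    # Assumed standard normal prior plus a Bernoulli likelihood with logits X @ w.
    logits = X @ w
    log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits)
                      + (1.0 - y) * jax.nn.log_sigmoid(-logits))
    return norm.logpdf(w).sum() + log_lik


# Autodiff: the gradient that drives the Langevin proposal comes for free.
grad_log_density = jax.grad(log_density)


def mala_step(w, rng_key, step_size):
    """One Metropolis-adjusted Langevin (MALA) step for a single chain."""
    noise_key, accept_key = jax.random.split(rng_key)

    def proposal_mean(x):
        return x + 0.5 * step_size**2 * grad_log_density(x)

    def log_q(to, frm):
        # Gaussian proposal log density, up to a constant that cancels.
        return -jnp.sum((to - proposal_mean(frm)) ** 2) / (2.0 * step_size**2)

    proposal = proposal_mean(w) + step_size * jax.random.normal(noise_key, w.shape)
    # Metropolis-Hastings correction for the asymmetric proposal.
    log_accept = (log_density(proposal) - log_density(w)
                  + log_q(w, proposal) - log_q(proposal, w))
    accept = jnp.log(jax.random.uniform(accept_key)) < log_accept
    return jnp.where(accept, proposal, w)


def run_chain(rng_key, init_w, num_steps=1000, step_size=0.1):
    """Run one chain with lax.scan and return all of its samples."""
    def body(w, step_key):
        w = mala_step(w, step_key, step_size)
        return w, w

    _, samples = jax.lax.scan(body, init_w, jax.random.split(rng_key, num_steps))
    return samples


# vmap: run four independent chains in parallel without hand-written batching.
num_chains = 4
chain_keys = jax.random.split(sample_key, num_chains)
init_ws = jnp.zeros((num_chains, 2))
samples = jax.jit(jax.vmap(run_chain))(chain_keys, init_ws)
print(samples.shape)  # (4, 1000, 2)
print(samples[:, 500:].mean(axis=(0, 1)))  # posterior mean after discarding warmup
```

The step size of 0.1 is hand-picked; library samplers typically adapt it during warmup instead.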

Whether you are looking for a simple way to implement and fit well-understood models, experimenting with new bespoke models, keeping up with library design, or actively doing research in Bayesian computation, there is something for you here.


Prior Knowledge Expected

Previous knowledge expected

Colin Carroll is a software engineer at Google Research. In this role he focuses on Bayesian computation and research, and contributes to a number of open-source libraries, including TensorFlow Probability, PyMC[3], and ArviZ. He received his PhD in mathematics from Rice University, where he researched geometric measure theory.