Learning scan with jax and numpyro
I'm learning scan because it's all over the numpyro time series examples, so if I want to implement time series models in numpyro from scratch, I'll need to understand it better.
In the context of time series models, scan is helpful because it allows us to iterate over sequences (without a cumbersome for loop) while maintaining state across each iteration.
Elements of a scan
First, let's break down an example scan function into its components:
Trivial Example - Cumulative Sum
We'll start with the cumsum example from The Differential Programming Tutorial - Loopy Carry, by Eric Ma. We can implement a cumsum with a for loop like:
from jax import random
import jax.numpy as jnp

# Input sequence
xs = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])

result = []
cumsum = 0
for el in xs:
    cumsum += el
    result.append(cumsum)

jnp.array(result)
Array([ 1, 3, 6, 11, 18, 29, 42, 59], dtype=int32)
Now, instead of a loop, we rewrite it with jax.lax.scan:
from jax import lax

def cumsum(cumsum, el):
    """
    Parameters
    ----------
    cumsum
        The result from the previous loop, aka the 'carried over' result
    el
        The current element in the sequence in `xs`
    """
    cumsum += el
    return cumsum, cumsum  # ('carryover', 'accumulated')

init_sum = 0
final, result = lax.scan(cumsum, init_sum, xs)
print(final)
print(result)
59
[ 1 3 6 11 18 29 42 59]
We see that lax.scan returned two things:
- The first item is the final value of the carried value, cumsum.
- The second item is an array in which each iteration appended the latest value of cumsum to the end.
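As a quick sanity check on those two return values, the sketch below compares the scan output with jnp.cumsum. The name cumsum_step is only illustrative and is not part of the original example.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])

def cumsum_step(carry, el):
    # carry: running total so far; el: current element of xs
    carry = carry + el
    return carry, carry  # (new carry, value stacked into the output array)

final, result = lax.scan(cumsum_step, 0, xs)

# The stacked outputs should agree with the built-in cumulative sum,
# and the final carry should equal the last stacked value.
matches_builtin = bool(jnp.array_equal(result, jnp.cumsum(xs)))
carry_is_last = bool(final == result[-1])
print(matches_builtin, carry_is_last)
```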
To make it a little more concrete, here’s a GIF that steps through each iteration:
Notice that the scanned function returns two items: the first goes back into the loop as the carry, while the second is appended to an array that acts as a ‘running tally’ of values from each iteration. If you don't need that per-step output, the function can return None in the second position.
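Here is a minimal sketch of the None case, assuming we only care about the final total. The function name running_total is made up for illustration.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])

def running_total(total, el):
    # Return None in the second slot: nothing is collected per step
    return total + el, None

final_total, per_step = lax.scan(running_total, 0, xs)
print(final_total)  # the carry after the last step
print(per_step)     # no stacked output was requested
```

Only the carry survives here, which is handy when the per-step history would just be thrown away.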
Trivial Example 2: A little more of the scan function with compounding interest
The first two parameters in the input function can have arbitrary names but their order matters. The first item is the ‘carried over value’, aka the carry. The second item is the current element in the xs input array.
Similarly, the input function must have two outputs. The first output becomes the new ‘carried over value’ that will feed back into the input function, while the second output gets appended to an output jax array.
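The contract described above can be sketched as an ordinary Python loop. This is only a rough mental model (the helper scan_reference is hypothetical, uses NumPy instead of JAX, and ignores compilation, pytrees, and the optional arguments of the real lax.scan):

```python
import numpy as np

def scan_reference(f, init, xs):
    """Rough pure-Python model of the scan contract (not the JAX implementation)."""
    carry = init
    ys = []
    for x in xs:
        # f takes (carry, current element) and returns (new carry, per-step output)
        carry, y = f(carry, x)
        ys.append(y)
    # Per-step outputs are stacked into one array
    return carry, np.stack(ys)

def add(carry, x):
    total = carry + x
    return total, total

final, ys = scan_reference(add, 0, np.array([1, 2, 3, 5, 7, 11, 13, 17]))
print(final)
print(ys)
```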
Let’s have a concrete example in which we calculate compounded interest for a $100 investment with a 4% interest rate:
# Define the input function to apply at each time step
def interest_growth(value, x):
    value = value * 1.04
    return value, value  # ('carried over value', 'accumulated array')

# Define the initial state and input sequence
init_state = 100.0
T = jnp.arange(1, 30)  # 29 time steps

# Apply the scan function
final_state, results = lax.scan(f=interest_growth, init=init_state, xs=T)

print(final_state)
print(results)
311.8648
[104. 108.159996 112.48639 116.98584 121.66527 126.531876
131.59314 136.85686 142.33113 148.02437 153.94534 160.10315
166.50726 173.16754 180.09424 187.298 194.78992 202.58151
210.68477 219.11215 227.87663 236.99168 246.47134 256.3302
266.5834 277.24673 288.33658 299.87003 311.8648 ]
And we see the effect of the interest compounding over time.
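One way to double-check the scan is against the closed-form compound interest formula, value_t = 100 * 1.04^t. The sketch below repeats the function so it runs on its own; the tolerance is an arbitrary choice to absorb float32 rounding.

```python
import jax.numpy as jnp
from jax import lax

def interest_growth(value, x):
    value = value * 1.04
    return value, value

T = jnp.arange(1, 30)
_, results = lax.scan(interest_growth, 100.0, T)

# Closed form: after t periods the balance is 100 * 1.04 ** t
closed_form = 100.0 * 1.04 ** T

agrees = bool(jnp.allclose(results, closed_form, rtol=1e-4))
print(agrees)
```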
Trivial Example 3: scan in numpyro to sample from a distribution
What I'm really interested in is learning scan in the context of doing time series in numpyro.
numpyro has its own (experimental) implementation of scan, which allows us to iteratively sample from numpyro distributions (via primitives such as numpyro.sample) at each step.
Here’s a simple example of a random walk:
\[ y_{t+1} = y_t + \epsilon, \quad \epsilon \sim \mathcal{N}(0,1) \]
Or, equivalently:
\[ \begin{aligned} y_t &= \mu_t\\ \mu_{t+1} &= \mu_t + \epsilon_\mu, \quad \epsilon_\mu \sim \mathcal{N}(0, 1)\\ \end{aligned} \]
We need to use scan because each subsequent value of \(\mu_{t+1}\) depends on its previous value \(\mu_{t}\).
To do this within a scan we'll also need to introduce an effect handler, numpyro.handlers.seed, and run the scan function within that context.
# Import libraries
import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
import matplotlib.pyplot as plt

# Define our random walk function as an input into `scan`
def random_walk(value, x):
    value = numpyro.sample('x', dist.Normal(value, 1.))
    # Which is the same as if we were very explicit:
    # value = value + numpyro.sample('x', dist.Normal(0, 1.))
    return value, value

init_0 = 0
T = jnp.arange(100)

# Need the `seed` effect handler in a context manager because of
# the need for a PRNG key within `numpyro.sample`
with numpyro.handlers.seed(rng_seed=42):
    final_location, locations = scan(random_walk, init=init_0, xs=T)

plt.plot(T, locations)
plt.title('Simulated Gaussian Random Walk');
Aside: What if I didn't use numpyro.handlers.seed there?
Note the addition of numpyro.handlers.seed. This is required because numpyro needs an explicit PRNG key whenever you sample from a distribution, so the seed effect handler automatically splits the key within each loop of the scan function.
Otherwise, if you pass a single explicit key to numpyro.sample inside scan, every step reuses that same key and therefore draws the same value, so this is what you'd get:
# What happens if you don't use `numpyro.handlers.seed`?
prng_key = random.PRNGKey(42)

def random_walk_nokey(value, x):
    # Being really explicit here
    value = value + numpyro.sample('x', dist.Normal(0, 1.), rng_key=prng_key)
    return value, value

init_0 = 0
n_steps = 100

state, final_result = scan(random_walk_nokey, init=init_0, xs=jnp.arange(n_steps))

plt.plot(jnp.arange(n_steps), final_result);
So that’s why you need to use the handler!
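To get a feel for what the handler saves us from, here is a rough sketch in plain jax (no numpyro) that carries the PRNG key through lax.scan and splits it by hand at every step. This is not how numpyro implements seed internally; the function names are made up, and it is only meant to contrast a fresh key per step with a single reused key.

```python
import jax.numpy as jnp
from jax import lax, random

def random_walk_manual(carry, x):
    value, key = carry
    # Split the key so that every step gets fresh randomness
    key, subkey = random.split(key)
    value = value + random.normal(subkey)
    return (value, key), value

init = (0.0, random.PRNGKey(42))
(final_value, _), locations = lax.scan(random_walk_manual, init, jnp.arange(100))

fixed_key = random.PRNGKey(42)

def random_walk_fixed_key(value, x):
    # Same key at every step, so the same draw is added every time
    value = value + random.normal(fixed_key)
    return value, value

_, straight_line = lax.scan(random_walk_fixed_key, 0.0, jnp.arange(100))

# Step sizes vary when the key is split, and are constant when it is reused
print(jnp.std(jnp.diff(locations)))
print(jnp.std(jnp.diff(straight_line)))
```

With the split key the increments vary from step to step; with the reused key they are all the same, which is why the path degenerates into a straight line.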