Source code for pyrenew.process.ar

"""
This file defines a RandomVariable subclass for
autoregressive (AR) processes
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike
from numpyro.contrib.control_flow import scan
from numpyro.infer.reparam import LocScaleReparam

from pyrenew.metaclass import RandomVariable


class ARProcess(RandomVariable):
    """
    RandomVariable representing an
    AR(p) process.
    """

    def sample(
        self,
        noise_name: str,
        n: int,
        autoreg: ArrayLike,
        init_vals: ArrayLike,
        noise_sd: float | ArrayLike,
    ) -> ArrayLike:
        """
        Sample from the AR process

        Parameters
        ----------
        noise_name: str
            A name for the sample site holding the
            Normal(`0`, `noise_sd`) noise for the AR process.
            Passed to :func:`numpyro.sample()
            <numpyro.primitives.sample>`.
        n: int
            Length of the sequence. Must be non-negative.
        autoreg: ArrayLike
            Autoregressive coefficients.
            The length of the array's first
            dimension determines the order :math:`p`
            of the AR process. ``autoreg[0]`` multiplies
            the most recent value, ``autoreg[1]`` the value
            before it, and so on.
        init_vals : ArrayLike
            Array of initial values. Must have the
            same first dimension size as the order.
            Values are ordered from oldest to most recent
            (``init_vals[-1]`` is the value immediately
            preceding the first newly generated value).
        noise_sd : float | ArrayLike
            Standard deviation of the AR
            process Normal noise, which by
            definition has mean 0.

        Returns
        -------
        ArrayLike
            with first dimension of length `n`
            and additional dimensions as inferred
            from the shapes of `autoreg`,
            `init_vals`, and `noise_sd`. The leading
            entries are the (broadcast) initial values;
            if `n` does not exceed the order, only the
            first `n` initial values are returned and no
            noise is sampled.

        Raises
        ------
        ValueError
            If `n` is negative, if the trailing shapes of
            `init_vals`, `autoreg`, and `noise_sd` cannot be
            broadcast together, or if the first dimension of
            `init_vals` does not equal the order of the process.

        Notes
        -----
        The first dimension of the return value will be of length
        `n` and represents time. Trailing dimensions follow standard
        numpy broadcasting rules and are determined from the second
        through last dimensions, if any, of `autoreg` and `init_vals`,
        as well as all dimensions of `noise_sd` (i.e.
        :code:`jax.numpy.shape(autoreg)[1:]`,
        :code:`jax.numpy.shape(init_vals)[1:]`, and
        :code:`jax.numpy.shape(noise_sd)`).

        Those shapes must be broadcastable together
        via :func:`jax.lax.broadcast_shapes`. This can
        be used to produce multiple AR processes of the
        same order but with either shared or different initial
        values, AR coefficient vectors, and/or noise standard
        deviation values.
        """
        # A negative length would make the final `[:n]` slice count
        # from the end and silently return a truncated history.
        if n < 0:
            raise ValueError(f"n must be a non-negative integer; got {n}")

        autoreg = jnp.atleast_1d(autoreg)
        init_vals = jnp.atleast_1d(init_vals)
        noise_sd = jnp.asarray(noise_sd)
        # noise_sd can be a scalar, but
        # autoreg and init_vals must have
        # a first dimension (time),
        # as the order of the process is
        # inferred from that first dimension
        order = autoreg.shape[0]
        n_inits = init_vals.shape[0]

        try:
            noise_shape = jax.lax.broadcast_shapes(
                init_vals.shape[1:],
                autoreg.shape[1:],
                noise_sd.shape,
            )
        except Exception as e:
            raise ValueError(
                "Could not determine a "
                "valid shape for the AR process noise "
                "from the shapes of the init_vals, "
                "autoreg, and noise_sd arrays. "
                "See ARProcess.sample() documentation "
                "for details."
            ) from e

        if n_inits != order:
            raise ValueError(
                "Initial values array must have the same "
                "first dimension length as the order p of "
                "the AR process. The order is given by "
                "the first dimension length of the array "
                "of autoregressive coefficients. Got an initial "
                f"value array with first dimension {n_inits} for "
                f"a process of order {order}"
            )

        history_shape = (order,) + noise_shape

        try:
            inits_broadcast = jnp.broadcast_to(init_vals, history_shape)
        except Exception as e:
            raise ValueError(
                "Could not broadcast init_vals "
                f"(shape {init_vals.shape}) "
                "to the expected shape of the process "
                f"history (shape {history_shape}). "
                "History shape is determined by the "
                "shapes of the init_vals, autoreg, and "
                "noise_sd arrays. See ARProcess "
                "documentation for details"
            ) from e

        # The scan carry keeps the most recent value at index 0 so that
        # autoreg[0] pairs with lag 1, autoreg[1] with lag 2, etc.
        inits_flipped = jnp.flip(inits_broadcast, axis=0)

        def transition(recent_vals, _):  # numpydoc ignore=GL08
            with numpyro.handlers.reparam(
                config={noise_name: LocScaleReparam(0)}
            ):
                next_noise = numpyro.sample(
                    noise_name,
                    numpyro.distributions.Normal(
                        loc=jnp.zeros(noise_shape), scale=noise_sd
                    ),
                )

            # Sum over the lag axis; trailing (batch) axes broadcast.
            dot_prod = jnp.einsum("i...,i...->...", autoreg, recent_vals)
            new_term = dot_prod + next_noise
            # Prepend the new value and drop the oldest to keep
            # exactly `order` values of history in the carry.
            new_recent_vals = jnp.concatenate(
                [
                    new_term[jnp.newaxis, ...],
                    # concatenate as (1 time unit,) + noise_shape
                    # array
                    recent_vals,
                ],
                axis=0,
            )[:order]

            return new_recent_vals, new_term

        if n > order:
            _, ts = scan(
                f=transition,
                init=inits_flipped,
                xs=None,
                length=(n - order),
            )

            ts_with_inits = jnp.concatenate(
                [inits_broadcast, ts],
                axis=0,
            )
        else:
            # Nothing to generate: the initial values already cover
            # the requested length.
            ts_with_inits = inits_broadcast
        return ts_with_inits[:n]

    @staticmethod
    def validate():  # numpydoc ignore=RT01
        """
        Validates input parameters, implementation pending.
        """
        return None