Source code for pyrenew.process.ar

# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

from __future__ import annotations

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import lax
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable


class ARProcess(RandomVariable):
    """
    Object to represent an AR(p) process in Numpyro
    """

    def __init__(
        self,
        mean: float,
        autoreg: ArrayLike,
        noise_sd: float,
    ) -> None:
        """
        Default constructor

        Parameters
        ----------
        mean : float
            Mean parameter. The autoregression is applied to deviations
            from this value.
        autoreg : ArrayLike
            Autoregressive coefficients ordered by lag: ``autoreg[0]``
            multiplies the most recent value, ``autoreg[1]`` the value
            before it, and so on. The length determines the order.
        noise_sd : float
            Standard deviation of the Normal noise component.

        Returns
        -------
        None
        """
        self.mean = mean
        self.autoreg = autoreg
        self.noise_sd = noise_sd

    def sample(
        self,
        duration: int,
        inits: ArrayLike = None,
        name: str = "arprocess",
        **kwargs,
    ) -> tuple:
        """
        Sample from the AR process

        Parameters
        ----------
        duration : int
            Length of the returned sequence, including the initial points.
        inits : ArrayLike, optional
            Initial points in chronological order (the last element is the
            most recent). Its length must equal the order of the process.
            If None, they are sampled from ``Normal(0, noise_sd)``.
            Defaults to None.
        name : str, optional
            Prefix for the internal numpyro.sample site names
            (``<name>_sampled_inits`` and ``<name>_noise``).
            Defaults to "arprocess".
        **kwargs : dict, optional
            Additional keyword arguments. Currently ignored.

        Returns
        -------
        tuple
            With a single array of shape (duration,).
        """
        # Accept any array-like (lists, tuples) as well as jax arrays.
        autoreg = jnp.atleast_1d(jnp.asarray(self.autoreg))
        order = autoreg.shape[0]

        if inits is None:
            # NOTE(review): sampled inits are centred at 0 rather than at
            # self.mean; confirm this is intended when mean != 0.
            inits = numpyro.sample(
                name + "_sampled_inits",
                dist.Normal(0, self.noise_sd).expand((order,)),
            )
        inits = jnp.asarray(inits)

        def _ar_scanner(recent, noise_t):  # numpydoc ignore=GL08
            # `recent` holds the last `order` deviations from the mean,
            # most recent first, so recent[i] is lag i + 1 and pairs with
            # autoreg[i] in the dot product.
            new_term = (jnp.dot(autoreg, recent) + noise_t).flatten()
            # Push the new value to the front and drop the oldest one.
            new_recent = jnp.hstack([new_term, recent[: (order - 1)]])
            return new_recent, new_term

        noise = numpyro.sample(
            name + "_noise",
            dist.Normal(0, self.noise_sd).expand((duration - inits.size,)),
        )

        # inits is chronological (it is prepended to the output unchanged),
        # but the scan state is most-recent-first, so reverse it before
        # seeding the scan. Without this, for order > 1 autoreg[0] would
        # multiply the oldest initial value instead of the newest.
        init_state = jnp.flip(inits - self.mean, axis=0)
        _, ts = lax.scan(_ar_scanner, init_state, noise)

        return (jnp.hstack([inits, self.mean + ts.flatten()]),)

    @staticmethod
    def validate():  # numpydoc ignore=RT01
        """
        Validates inputted parameters, implementation pending.
        """
        return None