Additional modules#

Metaclass Module#

pyrenew helper classes

class DistributionalRV(dist, name)[source]#

Bases: RandomVariable

Wrapper class for random variables that sample from a single numpyro.distributions.Distribution.

sample(obs=None, **kwargs)[source]#

Sample from the distribution.

Parameters:
  • obs (ArrayLike, optional) – Observations passed as the obs argument to numpyro.sample(). Default None.

  • **kwargs (dict, optional) – Additional keyword arguments passed through to internal sample calls, should there be any.

Return type:

DistributionalRVSample

static validate(dist)[source]#

Validation of the distribution to be implemented in subclasses.

Return type:

None

class DistributionalRVSample(value: ArrayLike | None = None)[source]#

Bases: NamedTuple

Named tuple returned by the sample() method of DistributionalRV.

value#

Sampled value from the distribution.

Type:

ArrayLike

value: ArrayLike | None#

Alias for field number 0
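
A minimal usage sketch, not taken from the pyrenew docs: the import path pyrenew.metaclass, the LogNormal prior, and the variable name "I0" are assumptions made for illustration.

import jax.random as jr
import numpyro
import numpyro.distributions as dist
from pyrenew.metaclass import DistributionalRV  # assumed import path

# wrap a numpyro distribution as a named random variable
i0 = DistributionalRV(dist=dist.LogNormal(2.0, 0.5), name="I0")

# sampling requires a seed handler (or an enclosing numpyro model)
with numpyro.handlers.seed(rng_seed=jr.PRNGKey(0)):
    sampled = i0.sample()     # returns a DistributionalRVSample
    print(sampled.value)      # the sampled value (ArrayLike)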

class Model(**kwargs)[source]#

Bases: object

Abstract base class for models

kernel = None#
mcmc = None#
plot_posterior(var, obs_signal=None, xlab=None, ylab='Signal', samples=50, figsize=[4, 5], draws_col='darkblue', obs_col='black')[source]#

A wrapper of pyrenew.mcmcutils.plot_posterior

Return type:

Figure

posterior_predictive(rng_key=None, numpyro_predictive_args={}, **kwargs)[source]#

A wrapper for numpyro.infer.Predictive to generate posterior predictive samples.

Parameters:
  • rng_key (ArrayLike, optional) – Random key for the Predictive function call. Defaults to None.

  • numpyro_predictive_args (dict, optional) – Dictionary of arguments to be passed to the numpyro.infer.Predictive constructor.

  • **kwargs – Additional named arguments passed to the __call__() method of numpyro.infer.Predictive

Return type:

dict

print_summary(prob=0.9, exclude_deterministic=True)[source]#

A wrapper of MCMC.print_summary

Parameters:
  • prob (float, optional) – The probability mass of the credible intervals reported by print_summary. Defaults to 0.9.

  • exclude_deterministic (bool, optional) – Whether to exclude deterministic variables from the summary. Defaults to True.

Return type:

None

prior_predictive(rng_key=None, numpyro_predictive_args={}, **kwargs)[source]#

A wrapper for numpyro.infer.Predictive to generate prior predictive samples.

Parameters:
  • rng_key (ArrayLike, optional) – Random key for the Predictive function call. Defaults to None.

  • numpyro_predictive_args (dict, optional) – Dictionary of arguments to be passed to the numpyro.infer.Predictive constructor.

  • **kwargs – Additional named arguments passed to the __call__() method of numpyro.infer.Predictive

Return type:

dict

run(num_warmup, num_samples, rng_key=None, nuts_args=None, mcmc_args=None, **kwargs)[source]#

Runs the model

Parameters:
  • nuts_args (dict, optional) – Dictionary of arguments passed to the NUTS kernel. Defaults to None.

  • mcmc_args (dict, optional) – Dictionary of arguments passed to the MCMC sampler. Defaults to None.

Return type:

None

abstract sample(**kwargs)[source]#

Sample method of the model

Subclass implementations of this method should accept at least **kwargs.

Parameters:

**kwargs (dict, optional) – Additional keyword arguments passed through to internal sample() calls, should there be any.

Return type:

tuple

spread_draws(variables_names)[source]#

A wrapper of mcmcutils.spread_draws

Parameters:

variables_names (list) – A list of variable names to create a table of samples.

Return type:

pl.DataFrame

abstract static validate()[source]#
Return type:

None
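
A hedged sketch of the typical Model workflow. Here my_model stands for an instance of a hypothetical concrete Model subclass and "Rt" is a hypothetical variable name; the keyword arguments forwarded through nuts_args and mcmc_args are standard numpyro NUTS/MCMC options.

import jax.random as jr

# my_model: a fitted-to-be instance of a concrete Model subclass (hypothetical)
my_model.run(
    num_warmup=500,
    num_samples=1000,
    rng_key=jr.PRNGKey(42),
    nuts_args=dict(max_tree_depth=10),   # forwarded to the NUTS kernel
    mcmc_args=dict(num_chains=2),        # forwarded to the MCMC sampler
)
my_model.print_summary()                          # wraps MCMC.print_summary
draws = my_model.spread_draws(["Rt"])             # long-form polars DataFrame
preds = my_model.posterior_predictive(rng_key=jr.PRNGKey(1))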

class RandomVariable(**kwargs)[source]#

Bases: object

Abstract base class for latent and observed random variables.

abstract sample(**kwargs)[source]#

Sample method of the process

Subclass implementations of this method should accept at least **kwargs.

Parameters:

**kwargs (dict, optional) – Additional keyword arguments passed through to internal sample() calls, should there be any.

Return type:

tuple

abstract static validate(**kwargs)[source]#

Validation of kwargs to be implemented in subclasses.

Return type:

None
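
A hedged sketch of a custom RandomVariable subclass following the abstract interface above (a sample() returning a tuple and a static validate()); the class, its HalfNormal prior, and the import path pyrenew.metaclass are illustrative assumptions.

import numpyro
import numpyro.distributions as dist
from pyrenew.metaclass import RandomVariable  # assumed import path


class ScaleRV(RandomVariable):
    """Hypothetical latent scale parameter with a HalfNormal prior."""

    def __init__(self, name: str, prior_scale: float) -> None:
        self.validate(prior_scale)
        self.name = name
        self.prior_scale = prior_scale

    @staticmethod
    def validate(prior_scale: float) -> None:
        if prior_scale <= 0:
            raise ValueError("prior_scale must be positive")

    def sample(self, **kwargs) -> tuple:
        # draw the value inside a numpyro model or seed handler
        return (numpyro.sample(self.name, dist.HalfNormal(self.prior_scale)),)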

Convolution Utility Module#

convolve

Factory functions for calculating convolutions of timeseries with discrete distributions of times-to-event using jax.lax.scan(). Factories generate functions that can be passed to jax.lax.scan() with an appropriate array to scan along.

new_convolve_scanner(array_to_convolve, transform)[source]#

Factory function to create a “scanner” function that can be used with jax.lax.scan() to construct an array via backward-looking iterative convolution.

Parameters:
  • array_to_convolve (ArrayLike) – A 1D jax array to convolve with subsets of the iteratively constructed history array.

  • transform (Callable) – A transformation to apply to the result of the dot product and multiplication.

Returns:

A scanner function that can be used with jax.lax.scan() for convolution. This function takes a history subset array and a scalar, computes the dot product of the supplied convolution array with the history subset array, multiplies by the scalar, and returns the resulting value and a new history subset array formed by the 2nd-through-last entries of the old history subset array followed by that same resulting value.

Return type:

Callable

Notes

The following iterative operation is found often in renewal processes:

\[\begin{split}X(t) = f\left(m(t) \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ \vdots{} \\ X(t - 1)\end{bmatrix} \cdot{} \mathbf{d} \right)\end{split}\]

Where \(\mathbf{d}\) is a vector of length \(n\), \(m(t)\) is a scalar for each value of time \(t\), and \(f\) is a scalar-valued function.

Given \(\mathbf{d}\), and optionally \(f\), this factory function returns a new function that performs one step of this process while scanning along an array of multipliers (i.e. an array giving the values of \(m(t)\)) using jax.lax.scan().
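
A short usage sketch: the import path pyrenew.convolve, the identity transform, and the numeric values are assumptions for illustration; the scan itself follows the jax.lax.scan compatibility stated above.

import jax
import jax.numpy as jnp
from pyrenew.convolve import new_convolve_scanner  # assumed import path

d = jnp.array([0.25, 0.5, 0.25])           # e.g. a reversed time-to-event pmf (illustrative)
scanner = new_convolve_scanner(d, transform=lambda x: x)  # f = identity

init_history = jnp.array([1.0, 1.2, 1.5])  # X(t - n), ..., X(t - 1)
multipliers = jnp.full(10, 1.1)            # values of m(t) to scan along

final_history, x_new = jax.lax.scan(scanner, init_history, multipliers)
# x_new holds X(t) for the 10 scanned time steps;
# final_history holds the last n values of the constructed series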

new_double_convolve_scanner(arrays_to_convolve, transforms)[source]#

Factory function to create a scanner function that iteratively constructs arrays by applying the dot-product/multiply/transform operation twice per history subset, with the first operation yielding an additional scalar multiplier for the second.

Parameters:
  • arrays_to_convolve (tuple[ArrayLike, ArrayLike]) – A tuple of two 1D jax arrays, one for each of the two stages of convolution. The first entry in the arrays_to_convolve tuple will be convolved with the current history subset array first, and the second entry will be convolved with it second.

  • transforms (tuple[Callable, Callable]) – A tuple of two functions, each transforming the output of the dot product at each convolution stage. The first entry in the transforms tuple will be applied first, then the second will be applied.

Returns:

A scanner function that applies two sets of convolution, multiply, and transform operations in sequence to construct a new array by scanning along a pair of input arrays that are equal in length to each other.

Return type:

Callable

Notes

Using the same notation as in the documentation for new_convolve_scanner(), this function aids in applying the iterative operation:

\[\begin{split}\begin{aligned} Y(t) &= f_1 \left(m_1(t) \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ \vdots{} \\ X(t - 1) \end{bmatrix} \cdot{} \mathbf{d}_1 \right) \\ \\ X(t) &= f_2 \left( m_2(t) Y(t) \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ \vdots{} \\ X(t - 1)\end{bmatrix} \cdot{} \mathbf{d}_2 \right) \end{aligned}\end{split}\]

Where \(\mathbf{d}_1\) and \(\mathbf{d}_2\) are vectors of length \(n\), \(m_1(t)\) and \(m_2(t)\) are scalars for each value of time \(t\), and \(f_1\) and \(f_2\) are scalar-valued functions.

Mathematics Utilities Module#

Helper functions for doing analytical and/or numerical calculations about a given renewal process.

get_asymptotic_growth_rate(R, generation_interval_pmf)[source]#

Get the asymptotic per timestep growth rate for a renewal process with a given value of R and a given discrete generation interval probability mass vector.

This function computes that growth rate by finding the dominant eigenvalue of the renewal process’s Leslie matrix.

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

The asymptotic growth rate of the renewal process, as a jax float.

Return type:

float

get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[source]#

Get the asymptotic per-timestep growth rate of the renewal process (the dominant eigenvalue of its Leslie matrix) and the associated stable age distribution (a normalized eigenvector associated to that eigenvalue).

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

A tuple consisting of the asymptotic growth rate of the process, as a jax float, and the stable age distribution of the process, as a jax array probability vector of the same shape as the generation interval probability vector.

Return type:

tuple[float, ArrayLike]

Raises:

ValueError – If an age distribution vector with non-zero imaginary part is produced.

get_leslie_matrix(R, generation_interval_pmf)[source]#

Create the Leslie matrix corresponding to a basic renewal process with the given R value and discrete generation interval pmf vector.

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

The Leslie matrix for the renewal process, as a jax array.

Return type:

ArrayLike

get_stable_age_distribution(R, generation_interval_pmf)[source]#

Get the stable age distribution for a renewal process with a given value of R and a given discrete generation interval probability mass vector.

This function computes that stable age distribution by finding and then normalizing an eigenvector associated to the dominant eigenvalue of the renewal process’s Leslie matrix.

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

The stable age distribution for the process, as a jax array probability vector of the same shape as the generation interval probability vector.

Return type:

ArrayLike
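
A short cross-checking sketch (the import path pyrenew.math and the values of R and the generation interval are assumptions): the asymptotic growth rate should equal the dominant eigenvalue of the Leslie matrix, and the stable age distribution should be a probability vector of the same shape as the generation interval pmf.

import jax.numpy as jnp
from pyrenew.math import (  # assumed import path
    get_asymptotic_growth_rate,
    get_leslie_matrix,
    get_stable_age_distribution,
)

R = 1.2
gen_int = jnp.array([0.2, 0.5, 0.3])   # discrete generation interval pmf

leslie = get_leslie_matrix(R, gen_int)
growth_rate = get_asymptotic_growth_rate(R, gen_int)
age_dist = get_stable_age_distribution(R, gen_int)

# the growth rate is the dominant eigenvalue of the Leslie matrix
dominant = jnp.max(jnp.abs(jnp.linalg.eigvals(leslie)))
assert jnp.isclose(growth_rate, dominant, atol=1e-5)

# the stable age distribution is normalized and matches gen_int's shape
assert jnp.isclose(age_dist.sum(), 1.0) and age_dist.shape == gen_int.shape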

Transformations Module#

This module exposes numpyro’s transformations module to the user and defines additional custom transformations.

class AbsTransform[source]#

Bases: ParameterFreeTransform

codomain = Positive(lower_bound=0.0)#
domain = Real()#
class AffineTransform(loc, scale, domain=Real())[source]#

Bases: Transform

Note

When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

property codomain#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class CholeskyTransform[source]#

Bases: ParameterFreeTransform

Transform via the mapping \(y = cholesky(x)\), where x is a positive definite matrix.

codomain = LowerCholesky()#
domain = PositiveDefinite()#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class ComposeTransform(parts)[source]#

Bases: Transform

call_with_intermediates(x)[source]#
property codomain#
property domain#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class CorrCholeskyTransform[source]#

Bases: ParameterFreeTransform

Transforms an unconstrained real vector \(x\) of length \(D*(D-1)/2\) into the Cholesky factor of a \(D\)-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:

  1. First we convert \(x\) into a lower triangular matrix with the following order:

\[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]

2. For each row \(X_i\) of the lower triangular part, we apply a signed version of StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:

  1. Scales into the interval \((-1, 1)\) domain: \(r_i = \tanh(X_i)\).

  2. Transforms into an unsigned domain: \(z_i = r_i^2\).

  3. Applies \(s_i = StickBreakingTransform(z_i)\).

  4. Transforms back into signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).

codomain = CorrCholesky()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class CorrMatrixCholeskyTransform[source]#

Bases: CholeskyTransform

Transform via the mapping \(y = cholesky(x)\), where x is a correlation matrix.

codomain = CorrCholesky()#
domain = CorrMatrix()#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class ExpTransform(domain=Real())[source]#

Bases: Transform

property codomain#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class IdentityTransform[source]#

Bases: ParameterFreeTransform

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class L1BallTransform[source]#

Bases: ParameterFreeTransform

Transforms an unconstrained real vector \(x\) into the unit L1 ball.

codomain = L1Ball()#
domain = RealVector(Real(), 1)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class LowerCholeskyAffine(loc, scale_tril)[source]#

Bases: Transform

Transform via the mapping \(y = loc + scale\_tril\ @\ x\).

Parameters:
  • loc – a real vector.

  • scale_tril – a lower triangular matrix with positive diagonal.

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import LowerCholeskyAffine
>>> base = jnp.ones(2)
>>> loc = jnp.zeros(2)
>>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
>>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
>>> affine(base)
Array([0.3, 1.5], dtype=float32)
codomain = RealVector(Real(), 1)#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class LowerCholeskyTransform[source]#

Bases: ParameterFreeTransform

Transform a real vector to a lower triangular cholesky factor, where the strictly lower triangular submatrix is unconstrained and the diagonal is parameterized with an exponential transform.

codomain = LowerCholesky()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class PermuteTransform(permutation)[source]#

Bases: Transform

codomain = RealVector(Real(), 1)#
domain = RealVector(Real(), 1)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class PowerTransform(exponent)[source]#

Bases: Transform

codomain = Positive(lower_bound=0.0)#
domain = Positive(lower_bound=0.0)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class RealFastFourierTransform(transform_shape=None, transform_ndims=1)[source]#

Bases: Transform

N-dimensional discrete fast Fourier transform for real input.

Parameters:
  • transform_shape – Length of each transformed axis to use from the input, defaults to the input size.

  • transform_ndims – Number of trailing dimensions to transform.

property codomain: Constraint#
property domain: Constraint#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

Return type:

tuple

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

Return type:

tuple

log_abs_det_jacobian(x, y, intermediates=None)[source]#
Return type:

Array

tree_flatten()[source]#
class ReshapeTransform(forward_shape, inverse_shape)[source]#

Bases: Transform

Reshape a sample, leaving batch dimensions unchanged.

Parameters:
  • forward_shape – Shape to transform the sample to.

  • inverse_shape – Shape of the sample for the inverse transform.

codomain = Real()#
domain = Real()#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
ScaledLogitTransform(x_max)[source]#

Scaled logistic transformation from the interval (0, x_max) to the interval (-infinity, +infinity).

Parameters:

x_max (float) – Maximum value of the untransformed scale (will be transformed to +infinity).

Returns:

A composition of the following transformations:

  • numpyro.distributions.transforms.AffineTransform(0.0, 1.0/x_max)

  • numpyro.distributions.transforms.SigmoidTransform().inv

Return type:

nt.ComposeTransform
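
An illustrative sketch (the import path pyrenew.transformation and the value x_max=50 are assumptions): the composed transform maps (0, x_max) onto the real line, with x_max / 2 mapping to zero.

import jax.numpy as jnp
from pyrenew.transformation import ScaledLogitTransform  # assumed import path

t = ScaledLogitTransform(x_max=50.0)
t(jnp.array([25.0, 45.0]))   # scaled logit: 0.0 and log(0.9 / 0.1) ≈ 2.197
t.inv(jnp.array([0.0]))      # back-transforms to 25.0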

class ScaledUnitLowerCholeskyTransform[source]#

Bases: LowerCholeskyTransform

Like LowerCholeskyTransform, this Transform transforms a real vector to a lower triangular cholesky factor. However, it does so via the decomposition

\(y = loc + unit\_scale\_tril\ @\ scale\_diag\ @\ x\).

where \(unit\_scale\_tril\) has ones along the diagonal and \(scale\_diag\) is a diagonal matrix with all positive entries that is parameterized with a softplus transform.

codomain = ScaledUnitLowerCholesky()#
domain = RealVector(Real(), 1)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class SigmoidTransform[source]#

Bases: ParameterFreeTransform

codomain = UnitInterval(lower_bound=0.0, upper_bound=1.0)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class SimplexToOrderedTransform(anchor_point=0.0)[source]#

Bases: Transform

Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints). Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.

Parameters:

anchor_point – Anchor point is a nuisance parameter to improve the identifiability of the transform. For simplicity, we assume it is a scalar value, but it is broadcastable to x.shape[:-1]. For more details please refer to Section 2.2 in [1].

References:

  1. Ordinal Regression Case Study, section 2.2, M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import SimplexToOrderedTransform
>>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
>>> transform = SimplexToOrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
codomain = OrderedVector()#
domain = Simplex()#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class SoftplusLowerCholeskyTransform[source]#

Bases: ParameterFreeTransform

Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

codomain = SoftplusLowerCholesky()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class SoftplusTransform[source]#

Bases: ParameterFreeTransform

Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).

codomain = SoftplusPositive(lower_bound=0.0)#
domain = Real()#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class StickBreakingTransform[source]#

Bases: ParameterFreeTransform

codomain = Simplex()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class Transform[source]#

Bases: object

call_with_intermediates(x)[source]#
codomain = Real()#
domain = Real()#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

property inv#
inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
classmethod tree_unflatten(aux_data, params)[source]#
class UnpackTransform(unpack_fn)[source]#

Bases: Transform

Transforms a contiguous array to a pytree of subarrays.

Parameters:

unpack_fn – callable used to unpack a contiguous array.

codomain = Dependent()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class ZeroSumTransform(transform_ndims=1)[source]#

Bases: Transform

A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]

Parameters:

transform_ndims (int) – Number of trailing dimensions to transform.

References:

  1. pymc-devs/pymc

  2. https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html

  3. https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/

property codomain: Constraint#
property domain: Constraint#
extend_axis(array, axis)[source]#
Return type:

Array

extend_axis_rev(array, axis)[source]#
Return type:

Array

forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

Return type:

tuple

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

Return type:

tuple

log_abs_det_jacobian(x, y, intermediates=None)[source]#
Return type:

Array

tree_flatten()[source]#

Regression Module#

Helper classes for regression problems

class AbstractRegressionPrediction[source]#

Bases: object

abstract predict()[source]#

Make a regression prediction

abstract sample(obs=None)[source]#

Observe or sample from the regression problem according to the specified distributions

class GLMPrediction(name, fixed_predictor_values, intercept_prior, coefficient_priors, transform=None, intercept_suffix='_intercept', coefficient_suffix='_coefficients')[source]#

Bases: AbstractRegressionPrediction

Generalized linear model regression predictions

predict(intercept, coefficients)[source]#

Generates a transformed prediction with intercept, coefficients, and fixed predictor values.

Parameters:
  • intercept (ArrayLike) – Sampled intercept value, generated from the intercept prior.

  • coefficients (ArrayLike) – Sampled regression coefficients, generated from the coefficient priors.

Returns:

Array of transformed predictions.

Return type:

ArrayLike

sample()[source]#

Sample generalized linear model

Returns:

A dictionary containing the transformed predictions and the sampled intercept and coefficients.

Return type:

dict
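
A hedged sketch of GLMPrediction: the import path pyrenew.regression, the predictor matrix, the Normal priors, and the name "rt_effect" are illustrative assumptions.

import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from pyrenew.regression import GLMPrediction  # assumed import path

# five observations of two fixed predictors (hypothetical design matrix)
X = jnp.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.5, 0.5], [0.2, 0.8]])

glm = GLMPrediction(
    name="rt_effect",
    fixed_predictor_values=X,
    intercept_prior=dist.Normal(0.0, 1.0),
    coefficient_priors=dist.Normal(0.0, 0.5).expand((2,)),
)

with numpyro.handlers.seed(rng_seed=jr.PRNGKey(0)):
    out = glm.sample()   # dict with the prediction, intercept, and coefficients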

MCMC Utilities Module#

Utilities to deal with MCMC outputs

plot_posterior(var, draws, obs_signal=None, ylab=None, xlab='Time', samples=50, figsize=[4, 5], draws_col='darkblue', obs_col='black')[source]#

Plot the posterior distribution of a variable

Parameters:
  • var (str) – Name of the variable to plot

  • draws (pl.DataFrame) – Long-form dataframe of posterior draws of the variable (e.g. as returned by spread_draws())

  • obs_signal (ArrayLike, optional) – Observed signal to plot as reference

  • ylab (str, optional) – Label for the y-axis

  • xlab (str, optional) – Label for the x-axis

  • samples (int, optional) – Number of samples to plot

  • figsize (list, optional) – Size of the figure

  • draws_col (str, optional) – Color of the draws

  • obs_col (str, optional) – Color of the observed signal.

Return type:

plt.Figure

spread_draws(posteriors, variables_names)[source]#

Get nicely shaped draws from the posterior

Given a dictionary of posteriors, return a long-form polars dataframe indexed by draw, with variable values (equivalent of tidybayes spread_draws() function).

Parameters:
  • posteriors (dict) – A dictionary of posteriors with variable names as keys and numpy ndarrays as values (with the first axis corresponding to the posterior draw number).

  • variables_names (list[str] | list[tuple]) – list of strings or of tuples identifying which variables to retrieve.

Returns:

A long-form dataframe of draw-indexed posterior samples.

Return type:

pl.DataFrame
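
A hedged usage sketch: my_model stands for a fitted pyrenew Model instance and "Rt" for an illustrative variable name; the import path pyrenew.mcmcutils is assumed.

from pyrenew.mcmcutils import plot_posterior, spread_draws  # assumed import path

posterior = my_model.mcmc.get_samples()   # dict of variable name -> ndarray
draws = spread_draws(posterior, ["Rt"])   # long-form polars DataFrame
fig = plot_posterior("Rt", draws, ylab="Rt", samples=100)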

Distributions Utility Module#

distutil

Utilities for working with commonly encountered probability distributions in renewal equation modeling, such as discrete time-to-event distributions.

reverse_discrete_dist_vector(dist)[source]#

Reverse a discrete distribution vector (useful for discrete time-to-event distributions).

Parameters:

dist (ArrayLike) – A discrete distribution vector (likely discrete time-to-event distribution)

Returns:

A reversed (jnp.flip) discrete distribution vector

Return type:

ArrayLike

validate_discrete_dist_vector(discrete_dist, tol=1e-05)[source]#

Validate that a vector represents a discrete probability distribution to within a specified tolerance, raising a ValueError if not.

Parameters:
  • discrete_dist (ArrayLike) – A jax array containing non-negative values that represent a discrete probability distribution. The values must sum to 1 within the specified tolerance.

  • tol (float, optional) – The tolerance within which the sum of the distribution must be 1. Defaults to 1e-5.

Returns:

The normalized distribution array if the input is valid.

Return type:

ArrayLike

Raises:

ValueError – If any value in discrete_dist is negative or if the sum of the distribution does not equal 1 within the specified tolerance.
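
A short sketch of the two utilities together (the import path pyrenew.distutil and the example pmf are assumptions):

import jax.numpy as jnp
from pyrenew.distutil import (  # assumed import path
    reverse_discrete_dist_vector,
    validate_discrete_dist_vector,
)

gen_int = jnp.array([0.25, 0.5, 0.25])
gen_int = validate_discrete_dist_vector(gen_int)          # returns the normalized vector
reversed_gen_int = reverse_discrete_dist_vector(gen_int)  # jnp.flip of the vector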