Additional modules#

Metaclass Module#

pyrenew helper classes

class Model(**kwargs)[source]#

Bases: object

Abstract base class for models

kernel = None#
mcmc = None#
model(**kwargs)[source]#

Alias for the sample method.

Parameters:

**kwargs (dict, optional) – Additional keyword arguments passed through to internal sample() calls, should there be any.

Return type:

tuple

posterior_predictive(rng_key=None, numpyro_predictive_args={}, **kwargs)[source]#

A wrapper of numpyro.infer.util.Predictive to generate posterior predictive samples.

Parameters:
  • rng_key (ArrayLike, optional) – Random key for the predictive call. Defaults to None.

  • numpyro_predictive_args (dict, optional) – Additional keyword arguments passed to numpyro.infer.util.Predictive (e.g. num_samples). Defaults to an empty dict.

  • **kwargs (dict, optional) – Additional keyword arguments passed along when generating the predictive samples.

Return type:

dict

print_summary(prob=0.9, exclude_deterministic=True)[source]#

A wrapper of MCMC.print_summary().

Parameters:
  • prob (float, optional) – The width of the credible interval to show. Defaults to 0.9.

  • exclude_deterministic (bool, optional) – Whether to exclude deterministic sites from the summary. Defaults to True.

Return type:

None

prior_predictive(rng_key=None, numpyro_predictive_args={}, **kwargs)[source]#

A wrapper for numpyro.infer.util.Predictive to generate prior predictive samples.

Parameters:
  • rng_key (ArrayLike, optional) – Random key for the predictive call. Defaults to None.

  • numpyro_predictive_args (dict, optional) – Additional keyword arguments passed to numpyro.infer.util.Predictive (e.g. num_samples). Defaults to an empty dict.

  • **kwargs (dict, optional) – Additional keyword arguments passed along when generating the predictive samples.

Return type:

dict

run(num_warmup, num_samples, rng_key=None, nuts_args=None, mcmc_args=None, **kwargs)[source]#

Runs the model via MCMC, using the NUTS kernel.

Parameters:
  • num_warmup (int) – Number of warmup (tuning) iterations.

  • num_samples (int) – Number of posterior samples to draw.

  • rng_key (ArrayLike, optional) – Random key for the MCMC run. Defaults to None.

  • nuts_args (dict, optional) – Additional keyword arguments passed to the NUTS kernel. Defaults to None.

  • mcmc_args (dict, optional) – Additional keyword arguments passed to numpyro.infer.MCMC. Defaults to None.

  • **kwargs (dict, optional) – Additional keyword arguments passed through to the model’s sample() call.

Return type:

None

abstract sample(**kwargs)[source]#

Sample method of the model.

The method signature in concrete subclasses should accept at least **kwargs.

Parameters:

**kwargs (dict, optional) – Additional keyword arguments passed through to internal sample() calls, should there be any.

Return type:

tuple

abstract static validate()[source]#

Validation to be implemented in concrete subclasses.

Return type:

None
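
Example

A minimal end-to-end sketch of the fitting workflow. The toy NormalMeanModel below is not part of pyrenew, the import path pyrenew.metaclass is an assumption, and it is likewise assumed that run() wraps the subclass's sample() method in a NUTS kernel and forwards extra keyword arguments to it (consistent with the nuts_args/mcmc_args parameters documented above):

import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist

from pyrenew.metaclass import Model  # assumed import path


class NormalMeanModel(Model):
    """Toy concrete Model: infer the mean of normally distributed data."""

    def __init__(self, sd=1.0):
        self.sd = sd

    @staticmethod
    def validate():
        return None

    def sample(self, obs=None, **kwargs):
        # Define the numpyro model: one latent site, one observed site.
        mu = numpyro.sample("mu", dist.Normal(0.0, 5.0))
        numpyro.sample("obs", dist.Normal(mu, self.sd), obs=obs)
        return (mu,)


model = NormalMeanModel(sd=1.0)
data = jnp.array([1.2, 0.8, 1.1, 0.9])

# Assumed: run() forwards extra keyword arguments (here `obs`) to sample().
model.run(
    num_warmup=500,
    num_samples=1000,
    rng_key=jr.PRNGKey(0),
    mcmc_args={"progress_bar": False},
    obs=data,
)
model.print_summary(prob=0.9)

# Posterior predictive draws, returned as a dict of site name -> samples.
post_pred = model.posterior_predictive(rng_key=jr.PRNGKey(1))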

class RandomVariable(**kwargs)[source]#

Bases: object

Abstract base class for latent and observed random variables.

abstract sample(**kwargs)[source]#

Sample method of the process.

The method signature in concrete subclasses should accept at least **kwargs.

Parameters:

**kwargs (dict, optional) – Additional keyword arguments passed through to internal sample() calls, should there be any.

Return type:

tuple

abstract static validate(**kwargs)[source]#

Validation of kwargs to be implemented in subclasses.

Return type:

None
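
Example

A minimal sketch of a concrete RandomVariable, assuming the import path pyrenew.metaclass; the class name, site name, and default below are illustrative and not part of pyrenew:

import numpyro
import numpyro.distributions as dist

from pyrenew.metaclass import RandomVariable  # assumed import path


class NormalStep(RandomVariable):
    """Illustrative latent variable: one Normal increment of a random walk."""

    def __init__(self, step_sd=0.025):
        self.step_sd = step_sd

    @staticmethod
    def validate(**kwargs):
        return None

    def sample(self, **kwargs):
        step = numpyro.sample("normal_step", dist.Normal(0.0, self.step_sd))
        return (step,)  # sample() is documented to return a tuple


# Assumed: outside of a Model, sampling needs a numpyro seed handler.
with numpyro.handlers.seed(rng_seed=0):
    (step,) = NormalStep().sample()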

Convolution Utility Module#

convolve

Factory functions for calculating convolutions of timeseries with discrete distributions of times-to-event using jax.lax.scan(). Factories generate functions that can be passed to jax.lax.scan() or numpyro.contrib.control_flow.scan() with an appropriate array to scan along.

compute_delay_ascertained_incidence(latent_incidence, delay_incidence_to_observation_pmf, p_observed_given_incident=1)[source]#

Computes the timeseries of observed incidence, given a latent incidence timeseries, a discrete delay distribution from incidence to observation, and an observation probability.

Parameters:
  • p_observed_given_incident (ArrayLike) – The rate at which latent incident counts translate into observed counts. For example, setting p_observed_given_incident=0.001 when the incident counts are infections and the observed counts are reported hospital admissions could be used to model a disease and population for which the probability of a latent infection leading to a reported hospital admission is 0.001.

  • latent_incidence (ArrayLike) – Incidence values based on the true underlying process.

  • delay_incidence_to_observation_pmf (ArrayLike) – Probability mass function of the delay interval from incidence to observation, where the \(i\)th entry represents a delay of \(i\) time units, i.e. delay_incidence_to_observation_pmf[0] represents the fraction of observations that are delayed 0 time units, delay_incidence_to_observation_pmf[1] represents the fraction delayed 1 time unit, et cetera.

Returns:

The predicted timeseries of delayed observations.

Return type:

ArrayLike
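
Example

An illustrative call with made-up numbers, assuming the import path pyrenew.convolve:

import jax.numpy as jnp

from pyrenew.convolve import compute_delay_ascertained_incidence  # assumed path

latent_incidence = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
# Delay pmf: 20% observed same day, 50% after one day, 30% after two days.
delay_pmf = jnp.array([0.2, 0.5, 0.3])

observed = compute_delay_ascertained_incidence(
    latent_incidence=latent_incidence,
    delay_incidence_to_observation_pmf=delay_pmf,
    p_observed_given_incident=0.1,
)
# Each observation is 0.1 times a delay-weighted sum of recent latent
# incidence; e.g. the value aligned with day 2 of the latent series is
# 0.1 * (0.2 * 30 + 0.5 * 20 + 0.3 * 10) = 1.9.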

new_convolve_scanner(array_to_convolve, transform)[source]#

Factory function to create a “scanner” function that can be used with jax.lax.scan() or numpyro.contrib.control_flow.scan() to construct an array via backward-looking iterative convolution.

Parameters:
  • array_to_convolve (ArrayLike) – A 1D jax array to convolve with subsets of the iteratively constructed history array.

  • transform (Callable) – A transformation to apply to the result of the dot product and multiplication.

Returns:

A scanner function that can be used with jax.lax.scan() or numpyro.contrib.control_flow.scan() for convolution. This function takes a history subset array and a scalar, computes the dot product of the supplied convolution array with the history subset array, multiplies by the scalar, and returns the resulting value and a new history subset array formed by the 2nd-through-last entries of the old history subset array followed by that same resulting value.

Return type:

Callable

Notes

The following iterative operation is found often in renewal processes:

\[X(t) = f\left(m(t) \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ \vdots{} \\ X(t - 1)\end{bmatrix} \cdot{} \mathbf{d} \right) \]

Where \(\mathbf{d}\) is a vector of length \(n\), \(m(t)\) is a scalar for each value of time \(t\), and \(f\) is a scalar-valued function.

Given \(\mathbf{d}\), and optionally \(f\), this factory function returns a new function that performs one step of this process while scanning along an array of multipliers (i.e. an array giving the values of \(m(t)\)) using jax.lax.scan().
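
Example

A sketch of that renewal-process step, assuming the import path pyrenew.convolve; the generation interval, initial history, and R(t) values are made up, and the transform is the identity:

import jax
import jax.numpy as jnp

from pyrenew.convolve import new_convolve_scanner  # assumed path

# Reversed generation-interval pmf, so the dot product with the history
# vector [X(t-3), X(t-2), X(t-1)] gives most weight to the most recent value.
reversed_gen_int = jnp.array([0.1, 0.3, 0.6])
scanner = new_convolve_scanner(reversed_gen_int, lambda x: x)

init_history = jnp.array([1.0, 1.0, 1.0])  # X(t-3), X(t-2), X(t-1)
rt = jnp.full(10, 1.2)                     # the multipliers m(t), here R(t)

final_history, incidence = jax.lax.scan(scanner, init_history, rt)
# `incidence` stacks the per-step values X(t) for the 10 scanned time steps.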

new_double_convolve_scanner(arrays_to_convolve, transforms)[source]#

Factory function to create a scanner function that iteratively constructs arrays by applying the dot-product/multiply/transform operation twice per history subset, with the first operation yielding an additional scalar multiplier for the second.

Parameters:
  • arrays_to_convolve (tuple[ArrayLike, ArrayLike]) – A tuple of two 1D jax arrays, one for each of the two stages of convolution. The first entry in the arrays_to_convolve tuple will be convolved with the current history subset array first, and the second entry will be convolved with it second.

  • transforms (tuple[Callable, Callable]) – A tuple of two functions, each transforming the output of the dot product at each convolution stage. The first entry in the transforms tuple will be applied first, then the second will be applied.

Returns:

A scanner function that applies two sets of convolution, multiply, and transform operations in sequence to construct a new array by scanning along a pair of input arrays that are equal in length to each other.

Return type:

Callable

Notes

Using the same notation as in the documentation for new_convolve_scanner(), this function aids in applying the iterative operation:

\[\begin{aligned} Y(t) &= f_1 \left(m_1(t) \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ \vdots{} \\ X(t - 1) \end{bmatrix} \cdot{} \mathbf{d}_1 \right) \\ \\ X(t) &= f_2 \left( m_2(t) Y(t) \begin{bmatrix} X(t - n) \\ X(t - n + 1) \\ \vdots{} \\ X(t - 1)\end{bmatrix} \cdot{} \mathbf{d}_2 \right) \end{aligned} \]

Where \(\mathbf{d}_1\) and \(\mathbf{d}_2\) are vectors of length \(n\), \(m_1(t)\) and \(m_2(t)\) are scalars for each value of time \(t\), and \(f_1\) and \(f_2\) are scalar-valued functions.

Mathematics Utilities Module#

Helper functions for doing analytical and/or numerical calculations.

get_asymptotic_growth_rate(R, generation_interval_pmf)[source]#

Get the asymptotic per timestep growth rate for a renewal process with a given value of \(\mathcal{R}\) and a given discrete generation interval probability mass vector.

This function computes that growth rate by finding the dominant eigenvalue of the renewal process’s Leslie matrix.

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

The asymptotic growth rate of the renewal process, as a jax float.

Return type:

float
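
Example

An illustrative call, assuming the import path pyrenew.math:

import jax.numpy as jnp

from pyrenew.math import get_asymptotic_growth_rate  # assumed path

gen_int_pmf = jnp.array([0.25, 0.5, 0.25])  # discrete generation interval pmf
rate = get_asymptotic_growth_rate(1.5, gen_int_pmf)
# With R > 1 the per-timestep growth rate exceeds 1; it equals 1 when R = 1
# and falls below 1 when R < 1.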

get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[source]#

Get the asymptotic per-timestep growth rate of the renewal process (the dominant eigenvalue of its Leslie matrix) and the associated stable age distribution (a normalized eigenvector associated to that eigenvalue).

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

A tuple consisting of the asymptotic growth rate of the process, as a jax float, and the stable age distribution of the process, as a jax array probability vector of the same shape as the generation interval probability vector.

Return type:

tuple[float, ArrayLike]

Raises:

ValueError – If an age distribution vector with non-zero imaginary part is produced.

get_leslie_matrix(R, generation_interval_pmf)[source]#

Create the Leslie matrix corresponding to a basic renewal process with the given \(\mathcal{R}\) value and discrete generation interval pmf vector.

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

The Leslie matrix for the renewal process, as a jax array.

Return type:

ArrayLike
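
For reference, given a generation interval pmf \(\mathbf{g} = (g_1, g_2, \dots, g_n)\) and reproduction number \(\mathcal{R}\), the Leslie matrix of the corresponding renewal process takes the standard form, with expected new infections along the first row and aging along the subdiagonal:

\[L = \begin{bmatrix} \mathcal{R} g_1 & \mathcal{R} g_2 & \cdots & \mathcal{R} g_{n-1} & \mathcal{R} g_n \\ 1 & 0 & \cdots & 0 & 0 \\ 0 & 1 & \cdots & 0 & 0 \\ \vdots{} & & \ddots{} & & \vdots{} \\ 0 & 0 & \cdots & 1 & 0 \end{bmatrix} \]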

get_stable_age_distribution(R, generation_interval_pmf)[source]#

Get the stable age distribution for a renewal process with a given value of \(\mathcal{R}\) and a given discrete generation interval probability mass vector.

This function computes that stable age distribution by finding and then normalizing an eigenvector associated to the dominant eigenvalue of the renewal process’s Leslie matrix.

Parameters:
  • R (float) – The reproduction number of the renewal process

  • generation_interval_pmf (ArrayLike) – The discrete generation interval probability mass vector of the renewal process

Returns:

The stable age distribution for the process, as a jax array probability vector of the same shape as the generation interval probability vector.

Return type:

ArrayLike

integrate_discrete(init_diff_vals, highest_order_diff_vals)[source]#

Integrate (de-difference) a differenced process, obtaining the process values \(X(t=0), X(t=1), ... X(t)\) from its \(n^{th}\)-order differences and a set of initial process / difference values \(X(t=0), X^1(t=1), X^2(t=2), ... X^{(n-1)}(t=n-1)\), where \(X^k(t)\) is the value of the \(k^{th}\) difference at index \(t\) of the process. The result is a sequence whose length equals the length of the provided highest_order_diff_vals vector plus the order of the process.

Parameters:
  • init_diff_vals (ArrayLike) – Values of \(X(t=0), X^1(t=1), X^2(t=2) ... X^{(n-1)}(t=n-1)\).

  • highest_order_diff_vals (ArrayLike) – Array of differences at the highest order of differencing, i.e. the order of the overall process, starting with \(X^{n}(t=n)\)

Returns:

The integrated (de-differenced) sequence of values, of length n_diffs + order, where n_diffs is the number of entries in highest_order_diff_vals and order is the order of the process.

Return type:

ArrayLike
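
Example

An illustrative first-order (\(n = 1\)) call, assuming the import path pyrenew.math; in this case de-differencing reduces to a cumulative sum anchored at the initial value:

import jax.numpy as jnp

from pyrenew.math import integrate_discrete  # assumed path

init_vals = jnp.array([2.0])                # X(t=0)
first_diffs = jnp.array([1.0, 1.0, -0.5])   # X^1(t=1), X^1(t=2), X^1(t=3)

x = integrate_discrete(init_vals, first_diffs)
# Expected sequence of length n_diffs + order = 4: [2.0, 3.0, 4.0, 3.5],
# i.e. jnp.concatenate([init_vals, init_vals + jnp.cumsum(first_diffs)])
# for this first-order case.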

Transformations Module#

This module exposes numpyro’s transformations module to the user and defines additional custom transformations.

class AbsTransform[source]#

Bases: ParameterFreeTransform

codomain = Positive(lower_bound=0.0)#
domain = Real()#
class AffineTransform(loc, scale, domain=Real())[source]#

Bases: Transform

Note

When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

property codomain#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
property sign#

Sign of the derivative of the transform if it is bijective.

tree_flatten()[source]#
class CholeskyTransform[source]#

Bases: ParameterFreeTransform

Transform via the mapping \(y = cholesky(x)\), where x is a positive definite matrix.

codomain = LowerCholesky()#
domain = PositiveDefinite()#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class ComposeTransform(parts)[source]#

Bases: Transform

call_with_intermediates(x)[source]#
property codomain#
property domain#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
property sign#

Sign of the derivative of the transform if it is bijective.

tree_flatten()[source]#
class CorrCholeskyTransform[source]#

Bases: ParameterFreeTransform

Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:

  1. First we convert \(x\) into a lower triangular matrix with the following order:

\[\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix} \]

2. For each row \(X_i\) of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:

  1. Scales into the interval \((-1, 1)\) domain: \(r_i = \tanh(X_i)\).

  2. Transforms into an unsigned domain: \(z_i = r_i^2\).

  3. Applies \(s_i = StickBreakingTransform(z_i)\).

  4. Transforms back into signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).

codomain = CorrCholesky()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class CorrMatrixCholeskyTransform[source]#

Bases: CholeskyTransform

Transform via the mapping \(y = cholesky(x)\), where x is a correlation matrix.

codomain = CorrCholesky()#
domain = CorrMatrix()#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class ExpTransform(domain=Real())[source]#

Bases: Transform

property codomain#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
sign = 1#
tree_flatten()[source]#
class IdentityTransform[source]#

Bases: ParameterFreeTransform

log_abs_det_jacobian(x, y, intermediates=None)[source]#
sign = 1#
class L1BallTransform[source]#

Bases: ParameterFreeTransform

Transforms an unconstrained real vector \(x\) into the unit L1 ball.

codomain = L1Ball()#
domain = RealVector(Real(), 1)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class LowerCholeskyAffine(loc, scale_tril)[source]#

Bases: Transform

Transform via the mapping \(y = loc + scale\_tril\ @\ x\).

Parameters:
  • loc – a real vector.

  • scale_tril – a lower triangular matrix with positive diagonal.

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import LowerCholeskyAffine
>>> base = jnp.ones(2)
>>> loc = jnp.zeros(2)
>>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
>>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
>>> affine(base)
Array([0.3, 1.5], dtype=float32)
codomain = RealVector(Real(), 1)#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class LowerCholeskyTransform[source]#

Bases: ParameterFreeTransform

Transform a real vector to a lower triangular cholesky factor, where the strictly lower triangular submatrix is unconstrained and the diagonal is parameterized with an exponential transform.

codomain = LowerCholesky()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class PermuteTransform(permutation)[source]#

Bases: Transform

codomain = RealVector(Real(), 1)#
domain = RealVector(Real(), 1)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class PowerTransform(exponent)[source]#

Bases: Transform

codomain = Positive(lower_bound=0.0)#
domain = Positive(lower_bound=0.0)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
property sign#

Sign of the derivative of the transform if it is bijective.

tree_flatten()[source]#
class RealFastFourierTransform(transform_shape=None, transform_ndims=1)[source]#

Bases: Transform

N-dimensional discrete fast Fourier transform for real input.

Parameters:
  • transform_shape – Length of each transformed axis to use from the input, defaults to the input size.

  • transform_ndims – Number of trailing dimensions to transform.

property codomain: Constraint#
property domain: Constraint#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

Return type:

tuple

Parameters:

shape (tuple)

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

Return type:

tuple

Parameters:

shape (tuple)

log_abs_det_jacobian(x, y, intermediates=None)[source]#
Return type:

Array

tree_flatten()[source]#
class ReshapeTransform(forward_shape, inverse_shape)[source]#

Bases: Transform

Reshape a sample, leaving batch dimensions unchanged.

Parameters:
  • forward_shape – Shape to transform the sample to.

  • inverse_shape – Shape of the sample for the inverse transform.

codomain = Real()#
domain = Real()#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
sign = 1#
tree_flatten()[source]#
ScaledLogitTransform(x_max)[source]#

Scaled logistic transformation from the interval (0, x_max) to the interval (-infinity, +infinity).

Parameters:

x_max (float) – Maximum value of the untransformed scale (will be transformed to +infinity).

Returns:

A composition of the following transformations:

  • numpyro.distributions.transforms.AffineTransform(0.0, 1.0/x_max)

  • numpyro.distributions.transforms.SigmoidTransform().inv

Return type:

nt.ComposeTransform
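
Example

An illustrative use, assuming the import path pyrenew.transformation:

import jax.numpy as jnp

import pyrenew.transformation as t  # assumed import path

scaled_logit = t.ScaledLogitTransform(x_max=10.0)

x = jnp.array([1.0, 5.0, 9.0])
y = scaled_logit(x)           # logit(x / 10); the midpoint 5.0 maps to 0.0
x_back = scaled_logit.inv(y)  # the inverse transform recovers x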

class ScaledUnitLowerCholeskyTransform[source]#

Bases: LowerCholeskyTransform

Like LowerCholeskyTransform, this Transform transforms a real vector to a lower triangular cholesky factor. However, it does so via the decomposition

\(y = loc + unit\_scale\_tril\ @\ scale\_diag\ @\ x\).

where \(unit\_scale\_tril\) has ones along the diagonal and \(scale\_diag\) is a diagonal matrix with all positive entries that is parameterized with a softplus transform.

codomain = ScaledUnitLowerCholesky()#
domain = RealVector(Real(), 1)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
class SigmoidTransform[source]#

Bases: ParameterFreeTransform

codomain = UnitInterval(lower_bound=0.0, upper_bound=1.0)#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
sign = 1#
class SimplexToOrderedTransform(anchor_point=0.0)[source]#

Bases: Transform

Transform a simplex into an ordered vector (via differences in the logistic CDF between cutpoints). Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.

Parameters:

anchor_point – Anchor point is a nuisance parameter to improve the identifiability of the transform. For simplicity, we assume it is a scalar value, but it is broadcastable to x.shape[:-1]. For more details please refer to Section 2.2 in [1].

References:

  1. Ordinal Regression Case Study, section 2.2, M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import SimplexToOrderedTransform
>>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
>>> transform = SimplexToOrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
codomain = OrderedVector()#
domain = Simplex()#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class SoftplusLowerCholeskyTransform[source]#

Bases: ParameterFreeTransform

Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

codomain = SoftplusLowerCholesky()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class SoftplusTransform[source]#

Bases: ParameterFreeTransform

Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).

codomain = SoftplusPositive(lower_bound=0.0)#
domain = Real()#
log_abs_det_jacobian(x, y, intermediates=None)[source]#
sign = 1#
class StickBreakingTransform[source]#

Bases: ParameterFreeTransform

codomain = Simplex()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
class Transform[source]#

Bases: object

call_with_intermediates(x)[source]#
codomain = Real()#
domain = Real()#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

property inv#
inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
property sign#

Sign of the derivative of the transform if it is bijective.

classmethod tree_unflatten(aux_data, params)[source]#
class UnpackTransform(unpack_fn, pack_fn=None)[source]#

Bases: Transform

Transforms a contiguous array to a pytree of subarrays.

Parameters:
  • unpack_fn – callable used to unpack a contiguous array.

  • pack_fn – callable used to pack a pytree into a contiguous array.

codomain = Dependent()#
domain = RealVector(Real(), 1)#
forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x, y, intermediates=None)[source]#
tree_flatten()[source]#
class ZeroSumTransform(transform_ndims=1)[source]#

Bases: Transform

A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]

Parameters:

transform_ndims (int) – Number of trailing dimensions to transform.

References:

  [1] pymc-devs/pymc

  [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html

  [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/

property codomain: Constraint#
property domain: Constraint#
extend_axis(array, axis)[source]#
Return type:

Array

extend_axis_rev(array, axis)[source]#
Return type:

Array

forward_shape(shape)[source]#

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

Return type:

tuple

Parameters:

shape (tuple)

inverse_shape(shape)[source]#

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

Return type:

tuple

Parameters:

shape (tuple)

log_abs_det_jacobian(x, y, intermediates=None)[source]#
Return type:

Array

tree_flatten()[source]#

Regression Module#

Helper classes for regression problems

class AbstractRegressionPrediction[source]#

Bases: object

abstract predict()[source]#

Make a regression prediction

abstract sample(obs=None)[source]#

Observe or sample from the regression problem according to the specified distributions

class GLMPrediction(name, intercept_prior, coefficient_priors, transform=None, intercept_suffix='_intercept', coefficient_suffix='_coefficients')[source]#

Bases: AbstractRegressionPrediction

Generalized linear model regression predictions

Parameters:
  • name (str)

  • intercept_prior (dist.Distribution)

  • coefficient_priors (dist.Distribution)

  • transform (t.Transform)

predict(intercept, coefficients, predictor_values)[source]#

Generates a transformed prediction with intercept, coefficients, and predictor values

Parameters:
  • intercept (ArrayLike) – Sampled intercept value(s), drawn from the intercept prior.

  • coefficients (ArrayLike) – Sampled regression coefficients, drawn from the coefficient priors.

  • predictor_values (ArrayLike(n_observations, n_predictors)) – Matrix of predictor variables (covariates) for the regression problem. Each row should represent the predictor values corresponding to an observation; each column should represent a predictor variable. Do not include values of 1 for the intercept; these will be added automatically.

Returns:

Array of transformed predictions.

Return type:

ArrayLike

sample(predictor_values)[source]#

Sample generalized linear model

Parameters:

predictor_values (ArrayLike(n_observations, n_predictors)) – Matrix of predictor variables (covariates) for the regression problem. Each row should represent the predictor values corresponding to an observation; each column should represent a predictor variable. Do not include values of 1 for the intercept; these will be added automatically. Passed as the predictor_values argument to GLMPrediction.predict().

Return type:

GLMPredictionSample
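
Example

A minimal sketch, assuming the import path pyrenew.regression and that, outside of a Model, sampling must run inside a numpyro seed handler; the priors and predictor values are made up:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from pyrenew.regression import GLMPrediction  # assumed path

glm = GLMPrediction(
    name="admissions_rate",
    intercept_prior=dist.Normal(0.0, 1.0),
    coefficient_priors=dist.Normal(0.0, 1.0).expand([2]),  # two predictors
)

# Rows are observations, columns are predictors; no column of ones needed.
predictor_values = jnp.array(
    [[0.5, 1.0],
     [1.5, 0.0],
     [2.0, 2.0]]
)

# Assumed: a seed handler supplies randomness when sampling outside Model.run().
with numpyro.handlers.seed(rng_seed=0):
    result = glm.sample(predictor_values)

# result is a GLMPredictionSample with .prediction, .intercept, .coefficients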

class GLMPredictionSample(prediction=None, intercept=None, coefficients=None)[source]#

Bases: NamedTuple

A container for holding the output from GLMPrediction.sample().

Parameters:
  • prediction (ArrayLike | None)

  • intercept (ArrayLike | None)

  • coefficients (ArrayLike | None)

prediction#

Transformed predictions. Defaults to None.

Type:

ArrayLike | None, optional

intercept#

Sampled intercept from intercept priors. Defaults to None.

Type:

ArrayLike | None, optional

coefficients#

Prediction coefficients generated from coefficients priors. Defaults to None.

Type:

ArrayLike | None, optional

coefficients: ArrayLike | None#

Alias for field number 2

intercept: ArrayLike | None#

Alias for field number 1

prediction: ArrayLike | None#

Alias for field number 0

MCMC Utilities Module#

Distributions Utility Module#

distutil

Utilities for working with commonly encountered probability distributions found in renewal equation modeling, such as discrete time-to-event distributions.

reverse_discrete_dist_vector(dist)[source]#

Reverse a discrete distribution vector (useful for discrete time-to-event distributions).

Parameters:

dist (ArrayLike) – A discrete distribution vector (likely discrete time-to-event distribution)

Returns:

A reversed (jnp.flip) discrete distribution vector

Return type:

ArrayLike

validate_discrete_dist_vector(discrete_dist, tol=1e-05)[source]#

Validate that a vector represents a discrete probability distribution to within a specified tolerance, raising a ValueError if not.

Parameters:
  • discrete_dist (ArrayLike) – A jax array containing non-negative values that represent a discrete probability distribution. The values must sum to 1 within the specified tolerance.

  • tol (float, optional) – The tolerance within which the sum of the distribution must be 1. Defaults to 1e-5.

Returns:

The normalized distribution array if the input is valid.

Return type:

ArrayLike

Raises:

ValueError – If any value in discrete_dist is negative or if the sum of the distribution does not equal 1 within the specified tolerance.
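
Example

An illustrative use of both helpers, assuming the import path pyrenew.distutil:

import jax.numpy as jnp

from pyrenew.distutil import (  # assumed path
    reverse_discrete_dist_vector,
    validate_discrete_dist_vector,
)

gen_int_pmf = jnp.array([0.1, 0.3, 0.6])

# Raises ValueError if entries are negative or the sum is not 1 within tol;
# otherwise returns the (normalized) distribution vector.
gen_int_pmf = validate_discrete_dist_vector(gen_int_pmf, tol=1e-5)

# Reverse for use in backward-looking convolutions:
# [0.1, 0.3, 0.6] -> [0.6, 0.3, 0.1]
reversed_pmf = reverse_discrete_dist_vector(gen_int_pmf)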