Random Variables

DistributionalVariable(name, distribution, reparam=None)

Factory function to generate Distributional RandomVariables, either static or dynamic.

Parameters:
  • name (str) – Name of the random variable.

  • distribution (numpyro.distributions.Distribution | Callable) – Either a numpyro.distributions.Distribution instance giving the static distribution of the random variable, or a callable that returns a parameterized numpyro.distributions.Distribution when called. The callable form allows for dynamically parameterized DistributionalVariables, e.g. a Normal distribution with an inferred location and scale.

  • reparam (numpyro.infer.reparam.Reparam, optional) – If not None, reparameterize sampling from the distribution according to the given numpyro reparameterizer. Default None.

Return type:

RandomVariable

Returns:

DynamicDistributionalVariable | StaticDistributionalVariable. Raises a ValueError if a distribution cannot be constructed.
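The static-versus-dynamic dispatch described above can be sketched without the library. This is a minimal illustration, not the library's implementation: `ToyDistribution`, `ToyStatic`, `ToyDynamic`, and `toy_distributional_variable` are hypothetical stand-ins, and the order of the checks is an assumption.

```python
from dataclasses import dataclass
from typing import Callable


class ToyDistribution:
    """Hypothetical stand-in for numpyro.distributions.Distribution."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale


@dataclass
class ToyStatic:
    """Hypothetical stand-in for StaticDistributionalVariable."""
    name: str
    distribution: ToyDistribution


@dataclass
class ToyDynamic:
    """Hypothetical stand-in for DynamicDistributionalVariable."""
    name: str
    distribution_constructor: Callable


def toy_distributional_variable(name, distribution):
    """Illustrative dispatch only (assumed logic)."""
    # An already-built distribution instance is fixed at construction time.
    if isinstance(distribution, ToyDistribution):
        return ToyStatic(name, distribution)
    # A callable is kept and only invoked later, at sample time.
    if callable(distribution):
        return ToyDynamic(name, distribution)
    raise ValueError(f"Cannot construct a distribution from {distribution!r}")


static_rv = toy_distributional_variable("x", ToyDistribution(0.0, 1.0))
dynamic_rv = toy_distributional_variable("y", ToyDistribution)
```

Passing the class itself (or any other callable) yields the dynamic variant, because the parameters are not yet known; passing an instance yields the static variant.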

class DynamicDistributionalVariable(name, distribution_constructor, reparam=None, expand_by_shape=None)

Bases: RandomVariable

Wrapper class for random variables that sample from a single numpyro.distributions.Distribution that is parameterized / instantiated at sample() time (rather than at RandomVariable instantiation time).

expand_by(sample_shape)

Expand the distribution by a given sample_shape, if possible. Returns a new DynamicDistributionalVariable whose underlying distribution will be expanded by the given shape at sample() time.

Parameters:

sample_shape (tuple) – Sample shape by which to expand the distribution. Passed to the expand_by() method of numpyro.distributions.Distribution after the distribution is instantiated.

Returns:

A new DynamicDistributionalVariable whose underlying distribution will be expanded by the given sample shape at sampling time.

Return type:

DynamicDistributionalVariable
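The deferred expansion can be pictured as storing the shape now and applying it only once the distribution exists. The sketch below uses hypothetical stand-ins (`ToyDist`, `ToyDynamicVariable`, and its `build` helper are not library names). It assumes that expand_by() prepends the sample shape to the batch shape, as numpyro's Distribution.expand_by() does.

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional


@dataclass(frozen=True)
class ToyDist:
    """Hypothetical stand-in for a numpyro distribution."""
    loc: float
    batch_shape: tuple = ()

    def expand_by(self, sample_shape):
        # Prepend the sample shape to the existing batch shape.
        return replace(self, batch_shape=tuple(sample_shape) + self.batch_shape)


@dataclass(frozen=True)
class ToyDynamicVariable:
    """Hypothetical stand-in for DynamicDistributionalVariable."""
    name: str
    distribution_constructor: Callable
    expand_by_shape: Optional[tuple] = None

    def expand_by(self, sample_shape):
        # Return a new variable; nothing is expanded yet.
        return replace(self, expand_by_shape=tuple(sample_shape))

    def build(self, *args, **kwargs):
        # Instantiate first, then apply the stored expansion, if any.
        dist = self.distribution_constructor(*args, **kwargs)
        if self.expand_by_shape is not None:
            dist = dist.expand_by(self.expand_by_shape)
        return dist


base = ToyDynamicVariable("x", ToyDist)
expanded = base.expand_by((3,))
built = expanded.build(loc=0.5)
```

The original `base` variable is left unchanged, which matches the documented behavior of returning a new variable rather than mutating in place.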

sample(*args, obs=None, **kwargs)

Sample from the distributional random variable.

Parameters:
  • *args – Positional arguments passed to self.distribution_constructor

  • obs (ArrayLike, optional) – Observations passed as the obs argument to numpyro.sample(). Default None.

  • **kwargs (dict, optional) – Keyword arguments passed to self.distribution_constructor

Returns:

A sample from the distribution.

Return type:

ArrayLike
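As a rough illustration of how the arguments flow, the stdlib-only sketch below forwards `*args` and `**kwargs` to a constructor at sample time and treats `obs` separately. The names `toy_sample` and `normal_params` are hypothetical, and returning `obs` unchanged is a simplification of how numpyro.sample() conditions on an observed value instead of drawing a new one.

```python
import random


def normal_params(loc, scale=1.0):
    """Hypothetical constructor: returns the parameters of a normal."""
    return (loc, scale)


def toy_sample(distribution_constructor, *args, obs=None, rng=None, **kwargs):
    """Illustrative only: parameterize at sample time, then draw."""
    # Positional and keyword arguments go to the constructor, not the sampler.
    loc, scale = distribution_constructor(*args, **kwargs)
    if obs is not None:
        # Observed values are returned as-is rather than drawn.
        return obs
    rng = rng or random.Random(0)
    return rng.gauss(loc, scale)


draw = toy_sample(normal_params, 2.0, scale=0.5, rng=random.Random(42))
observed = toy_sample(normal_params, 2.0, obs=1.7)
```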

static validate(distribution_constructor)

Confirm that the distribution_constructor is callable.

Parameters:

distribution_constructor (any) – Putative distribution_constructor to validate.

Return type:

None. Raises a ValueError if distribution_constructor is not callable.

class StaticDistributionalVariable(name, distribution, reparam=None)

Bases: RandomVariable

Wrapper class for random variables that sample from a single numpyro.distributions.Distribution that is parameterized / instantiated at RandomVariable instantiation time (rather than at sample() time).

expand_by(sample_shape)

Expand the distribution by the given sample_shape, if possible. Returns a new StaticDistributionalVariable whose underlying distribution has been expanded by the given sample_shape via expand_by().

Parameters:

sample_shape (tuple) – Sample shape for the expansion. Passed to the expand_by() method of numpyro.distributions.Distribution.

Returns:

A new StaticDistributionalVariable whose underlying distribution has been expanded by the given sample shape.

Return type:

StaticDistributionalVariable
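In contrast to the dynamic case, a static variable already holds a concrete distribution, so the expansion can happen immediately. A minimal sketch with hypothetical stand-ins (`ToyDist` and `ToyStaticVariable` are not library names), again assuming that expand_by() prepends the sample shape to the batch shape:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyDist:
    """Hypothetical stand-in for a numpyro distribution."""
    loc: float
    batch_shape: tuple = ()

    def expand_by(self, sample_shape):
        # Prepend the sample shape to the existing batch shape.
        return replace(self, batch_shape=tuple(sample_shape) + self.batch_shape)


@dataclass(frozen=True)
class ToyStaticVariable:
    """Hypothetical stand-in for StaticDistributionalVariable."""
    name: str
    distribution: ToyDist

    def expand_by(self, sample_shape):
        # Expand right away and wrap the result in a new variable.
        return replace(self, distribution=self.distribution.expand_by(sample_shape))


rv = ToyStaticVariable("x", ToyDist(loc=0.0, batch_shape=(2,)))
expanded = rv.expand_by((3,))
```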

sample(obs=None, **kwargs)

Sample from the distribution.

Parameters:
  • obs (ArrayLike, optional) – Observations passed as the obs argument to numpyro.sample(). Default None.

  • **kwargs (dict, optional) – Additional keyword arguments passed through to internal sample calls, should there be any.

Returns:

A sample from the distribution.

Return type:

ArrayLike

static validate(distribution)

Validate the distribution.

Return type:

None

class TransformedVariable(name, base_rv, transforms)

Bases: RandomVariable

Class to represent RandomVariables defined by taking the output of another RV’s RandomVariable.sample() method and transforming it by a given transformation (typically a Transform).

sample(record=False, **kwargs)

Sample method. Call self.base_rv.sample() and then apply the transforms specified in self.transforms.

Parameters:
  • record (bool, optional) – Whether to record the value of the deterministic RandomVariable. Defaults to False.

  • **kwargs – Keyword arguments passed to self.base_rv.sample()

Return type:

tuple

Returns:

Tuple of the same length as the tuple returned by self.base_rv.sample().
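The sample-then-transform step can be sketched as pairing each element of the base sample with its own transform. This is an assumption-laden illustration: `toy_transformed_sample` is a hypothetical helper, and plain Python callables stand in for numpyro Transform objects.

```python
import math


def toy_transformed_sample(base_sample, transforms):
    """Illustrative only: apply the i-th transform to the i-th element."""
    # The output tuple has the same length as the base sample.
    return tuple(t(x) for t, x in zip(transforms, base_sample))


base_sample = (0.0, 2.0)                    # pretend output of base_rv.sample()
transforms = (math.exp, lambda x: x * 10)   # one transform per element
result = toy_transformed_sample(base_sample, transforms)
```

Because the pairing is one-to-one, the number of transforms has to match the length of the base sample, which is what the validation described for this class checks.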

sample_length()

Sample length for a transformed random variable must be equal to the length of self.transforms or validation will fail.

Returns:

Equal to the length of self.transforms.

Return type:

int

validate()

Perform validation checks on a TransformedVariable instance, confirming that all transformations are callable and that the number of transformations is equal to the sample length of the base random variable.

Returns:

None on successful validation. Raises a ValueError otherwise.

Return type:

None
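The two checks described for validate() can be sketched as follows. `toy_validate` is a hypothetical helper, not the library's code, and the error messages are placeholders.

```python
def toy_validate(transforms, base_sample_length):
    """Illustrative only: mirror the two documented checks."""
    # Check 1: every transformation must be callable.
    for i, transform in enumerate(transforms):
        if not callable(transform):
            raise ValueError(f"Transform at index {i} is not callable")
    # Check 2: one transformation per element of the base sample.
    if len(transforms) != base_sample_length:
        raise ValueError(
            f"Got {len(transforms)} transforms for a base sample "
            f"of length {base_sample_length}"
        )
    return None


ok = toy_validate((abs, float), base_sample_length=2)
```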