Skip to content

Random Variable

DynamicDistributionalVariable

DynamicDistributionalVariable(
    name: str,
    distribution_constructor: Callable,
    reparam: Reparam = None,
    expand_by_shape: tuple = None,
)

Bases: RandomVariable

Wrapper class for random variables that sample from a single numpyro.distributions.distribution.Distribution that is parameterized / instantiated at sample() time (rather than at RandomVariable instantiation time).

Default constructor for DynamicDistributionalVariable.

Parameters:

Name Type Description Default
name str

Name of the random variable.

required
distribution_constructor Callable

Callable that returns a concrete, parametrized numpyro.distributions.Distribution instance.

required
reparam Reparam

If not None, reparameterize sampling from the distribution according to the given numpyro reparameterizer

None
expand_by_shape tuple

If not None, call numpyro.distributions.distribution.Distribution.expand_by on the underlying distribution once it is instantiated with the given expand_by_shape. Default None.

None

Returns:

Type Description
None
Source code in pyrenew/randomvariable/distributionalvariable.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def __init__(
    self,
    name: str,
    distribution_constructor: Callable,
    reparam: Reparam = None,
    expand_by_shape: tuple = None,
) -> None:
    """
    Default constructor for DynamicDistributionalVariable.

    Parameters
    ----------
    name
        Name of the random variable.
    distribution_constructor
        Callable that returns a concrete parametrized
        numpyro.distributions.Distribution instance.
    reparam
        If not None, reparameterize sampling
        from the distribution according to the
        given numpyro reparameterizer
    expand_by_shape
        If not None, call [`numpyro.distributions.distribution.Distribution.expand_by`][] on the
        underlying distribution once it is instantiated
        with the given `expand_by_shape`.
        Default None.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `distribution_constructor` is not callable, or if
        `expand_by_shape` is neither None nor a tuple.
    """

    self.name = name
    self.validate(distribution_constructor)
    self.distribution_constructor = distribution_constructor
    # Keyed by this variable's name, which is also the name of the
    # numpyro sample site created in sample().
    if reparam is not None:
        self.reparam_dict = {self.name: reparam}
    else:
        self.reparam_dict = {}
    if not (expand_by_shape is None or isinstance(expand_by_shape, tuple)):
        # One concatenated message: passing the pieces as separate
        # positional arguments would make ValueError.args a tuple of
        # fragments and render the message as a tuple.
        raise ValueError(
            "expand_by_shape must be a tuple or be None. "
            f"Got {type(expand_by_shape)}"
        )
    self.expand_by_shape = expand_by_shape

    return None

expand_by

expand_by(sample_shape) -> Self

Expand the distribution by a given sample_shape, if possible. Returns a new DynamicDistributionalVariable whose underlying distribution will be expanded by the given shape at sample() time.

Parameters:

Name Type Description Default
sample_shape

Sample shape by which to expand the distribution. Passed to the expand_by() method of numpyro.distributions.distribution.Distribution after the distribution is instantiated.

required

Returns:

Type Description
DynamicDistributionalVariable

Whose underlying distribution will be expanded by the given sample shape at sampling time.

Source code in pyrenew/randomvariable/distributionalvariable.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def expand_by(self, sample_shape) -> Self:
    """
    Expand the distribution by a given
    sample_shape, if possible. Returns a
    new DynamicDistributionalVariable whose underlying
    distribution will be expanded by the given shape
    at sample() time.

    If this variable already carries an `expand_by_shape`,
    the new shape is prepended to it. Repeated calls therefore
    compose the same way repeated calls to
    [`numpyro.distributions.distribution.Distribution.expand_by`][]
    do, instead of discarding the earlier expansion.

    Parameters
    ----------
    sample_shape
        Sample shape by which to expand the distribution.
        Passed to the expand_by() method of
        [`numpyro.distributions.distribution.Distribution`][]
        after the distribution is instantiated.

    Returns
    -------
    DynamicDistributionalVariable
        Whose underlying distribution will be expanded by
        the given sample shape at sampling time.

    Raises
    ------
    ValueError
        If `sample_shape` is neither a tuple nor None
        (raised by the constructor's validation).
    """
    new_shape = sample_shape
    if self.expand_by_shape is not None and isinstance(sample_shape, tuple):
        # Distribution.expand_by prepends its argument to the batch
        # shape, so expanding by A and then by B equals expanding by B + A.
        new_shape = sample_shape + self.expand_by_shape
    # The constructor validates new_shape (must be a tuple or None).
    return DynamicDistributionalVariable(
        name=self.name,
        distribution_constructor=self.distribution_constructor,
        reparam=self.reparam_dict.get(self.name, None),
        expand_by_shape=new_shape,
    )

sample

sample(*args, obs: ArrayLike = None, **kwargs) -> ArrayLike

Sample from the distributional rv.

Parameters:

Name Type Description Default
*args

Positional arguments passed to self.distribution_constructor

()
obs ArrayLike

Observations passed as the obs argument to numpyro.primitives.sample. Default None.

None
**kwargs

Keyword arguments passed to self.distribution_constructor

{}

Returns:

Type Description
ArrayLike

a sample from the distribution.

Source code in pyrenew/randomvariable/distributionalvariable.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def sample(
    self,
    *args,
    obs: ArrayLike = None,
    **kwargs,
) -> ArrayLike:
    """
    Build the distribution for this call and sample (or observe) from it.

    Parameters
    ----------
    *args
        Positional arguments forwarded to `self.distribution_constructor`.
    obs
        Observed value handed to [`numpyro.primitives.sample`][]
        as its `obs` argument. Default `None`.
    **kwargs
        Keyword arguments forwarded to `self.distribution_constructor`.

    Returns
    -------
    ArrayLike
        The value produced by the numpyro sample site.
    """
    fn = self.distribution_constructor(*args, **kwargs)
    shape = self.expand_by_shape
    if shape is not None:
        # Expansion is applied only now, after the distribution exists.
        fn = fn.expand_by(shape)
    with numpyro.handlers.reparam(config=self.reparam_dict):
        value = numpyro.sample(name=self.name, fn=fn, obs=obs)
    return value

validate staticmethod

validate(distribution_constructor: any) -> None

Confirm that the distribution_constructor is callable.

Parameters:

Name Type Description Default
distribution_constructor any

Putative distribution_constructor to validate.

required

Returns:

Type Description
None or raises a ValueError
Source code in pyrenew/randomvariable/distributionalvariable.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@staticmethod
def validate(distribution_constructor: any) -> None:
    """
    Confirm that the distribution_constructor is
    callable.

    Parameters
    ----------
    distribution_constructor
        Putative distribution_constructor to validate.

    Returns
    -------
    None or raises a ValueError
    """
    if not callable(distribution_constructor):
        raise ValueError(
            "To instantiate a DynamicDistributionalVariable, ",
            "one must provide a Callable that returns a "
            "numpyro.distributions.Distribution as the "
            "distribution_constructor argument. "
            f"Got {type(distribution_constructor)}, which "
            "does not appear to be callable",
        )
    return None

StaticDistributionalVariable

StaticDistributionalVariable(
    name: str, distribution: Distribution, reparam: Reparam = None
)

Bases: RandomVariable

Wrapper class for random variables that sample from a single numpyro.distributions.distribution.Distribution that is parameterized / instantiated at RandomVariable instantiation time (rather than at sample()-ing time).

Default constructor for StaticDistributionalVariable.

Parameters:

Name Type Description Default
name str

Name of the random variable.

required
distribution Distribution

Distribution of the random variable.

required
reparam Reparam

If not None, reparameterize sampling from the distribution according to the given numpyro reparameterizer

None

Returns:

Type Description
None
Source code in pyrenew/randomvariable/distributionalvariable.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def __init__(
    self,
    name: str,
    distribution: numpyro.distributions.Distribution,
    reparam: Reparam = None,
) -> None:
    """
    Default constructor for StaticDistributionalVariable.

    Parameters
    ----------
    name
        Name of the random variable.
    distribution
        Already-parameterized distribution of the random variable.
    reparam
        Optional numpyro reparameterizer. When given, sampling
        from the distribution is reparameterized with it.

    Returns
    -------
    None
    """
    self.name = name
    self.validate(distribution)
    self.distribution = distribution
    # An empty config means the reparam handler in sample() is a no-op.
    self.reparam_dict = {} if reparam is None else {self.name: reparam}
    return None

expand_by

expand_by(sample_shape) -> Self

Expand the distribution by the given sample_shape, if possible. Returns a new StaticDistributionalVariable whose underlying distribution has been expanded by the given sample_shape via numpyro.distributions.distribution.Distribution.expand_by.

Parameters:

Name Type Description Default
sample_shape

Sample shape for the expansion. Passed to numpyro.distributions.distribution.Distribution.expand_by.

required

Returns:

Type Description
StaticDistributionalVariable

Whose underlying distribution has been expanded by the given sample shape.

Source code in pyrenew/randomvariable/distributionalvariable.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
def expand_by(self, sample_shape) -> Self:
    """
    Return a new StaticDistributionalVariable whose distribution is
    expanded by `sample_shape` via
    [`numpyro.distributions.distribution.Distribution.expand_by`][].

    Parameters
    ----------
    sample_shape
        Tuple of extra leading batch dimensions, forwarded to
        [`numpyro.distributions.distribution.Distribution.expand_by`][].

    Returns
    -------
    StaticDistributionalVariable
        Copy of this variable (same name and reparameterizer)
        wrapping the expanded distribution.

    Raises
    ------
    ValueError
        If `sample_shape` is not a tuple.
    """
    if isinstance(sample_shape, tuple):
        expanded = self.distribution.expand_by(sample_shape)
        return StaticDistributionalVariable(
            name=self.name,
            distribution=expanded,
            reparam=self.reparam_dict.get(self.name),
        )
    raise ValueError(
        "sample_shape for expand()-ing "
        "a DistributionalVariable must be a "
        f"tuple. Got {type(sample_shape)}"
    )

sample

sample(obs: ArrayLike | None = None, **kwargs) -> ArrayLike

Sample from the distribution.

Parameters:

Name Type Description Default
obs ArrayLike | None

Observations passed as the obs argument to numpyro.primitives.sample. Default None.

None
**kwargs

Additional keyword arguments passed through to internal sample calls, should there be any.

{}

Returns:

Type Description
ArrayLike

Containing a sample from the distribution.

Source code in pyrenew/randomvariable/distributionalvariable.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def sample(
    self,
    obs: ArrayLike | None = None,
    **kwargs,
) -> ArrayLike:
    """
    Sample (or observe) from the fixed distribution.

    Parameters
    ----------
    obs
        Observed value handed to [`numpyro.primitives.sample`][]
        as its `obs` argument. Default `None`.
    **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    ArrayLike
        The value produced by the numpyro sample site.
    """
    config = self.reparam_dict
    with numpyro.handlers.reparam(config=config):
        drawn = numpyro.sample(name=self.name, fn=self.distribution, obs=obs)
    return drawn

validate staticmethod

validate(distribution: any) -> None

Validation of the distribution.

Source code in pyrenew/randomvariable/distributionalvariable.py
203
204
205
206
207
208
209
210
211
212
213
214
215
@staticmethod
def validate(distribution: any) -> None:
    """
    Confirm that `distribution` is a numpyro Distribution instance.

    Parameters
    ----------
    distribution
        Putative distribution to validate.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `distribution` is not an instance of
        numpyro.distributions.Distribution.
    """
    if not isinstance(distribution, numpyro.distributions.Distribution):
        # The f-prefix is required so the offending type is interpolated
        # instead of printing the literal text "{type(distribution)}".
        raise ValueError(
            "distribution should be an instance of "
            "numpyro.distributions.Distribution, got "
            f"{type(distribution)}"
        )

    return None

TransformedVariable

TransformedVariable(
    name: str, base_rv: RandomVariable, transforms: Transform | tuple[Transform]
)

Bases: RandomVariable

Class to represent RandomVariables defined by taking the output of another RV's pyrenew.metaclass.RandomVariable.sample method and transforming it by a given transformation (typically a numpyro.distributions.transforms.Transform).

Default constructor

Parameters:

Name Type Description Default
name str

A name for the random variable instance.

required
base_rv RandomVariable

The underlying (untransformed) RandomVariable.

required
transforms Transform | tuple[Transform]

Transformation or tuple of transformations to apply to the output of base_rv.sample(); single values will be coerced to a length-one tuple. If a tuple, should be the same length as the tuple returned by base_rv.sample().

required

Returns:

Type Description
None
Source code in pyrenew/randomvariable/transformedvariable.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(
    self,
    name: str,
    base_rv: RandomVariable,
    transforms: Transform | tuple[Transform],
):
    """
    Default constructor.

    Parameters
    ----------
    name
        Name for this random variable instance.
    base_rv
        The underlying (untransformed) RandomVariable.
    transforms
        A single transformation, or a tuple of transformations,
        applied to the output of `base_rv.sample()`. A non-tuple
        value is wrapped into a length-one tuple. A tuple should
        have as many entries as the tuple returned by
        `base_rv.sample()`.

    Returns
    -------
    None
    """
    self.name = name
    self.base_rv = base_rv
    self.transforms = transforms if isinstance(transforms, tuple) else (transforms,)
    self.validate()

sample

sample(record=False, **kwargs) -> tuple

Sample method. Call self.base_rv.sample() and then apply the transforms specified in self.transforms.

Parameters:

Name Type Description Default
record

Whether to record the value of the deterministic RandomVariable. Defaults to False.

False
**kwargs

Keyword arguments passed to self.base_rv.sample()

{}

Returns:

Type Description
tuple of the same length as the tuple returned by
self.base_rv.sample()
Source code in pyrenew/randomvariable/transformedvariable.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def sample(self, record=False, **kwargs) -> tuple:
    """
    Sample method. Call self.base_rv.sample()
    and then apply the transforms specified
    in self.transforms.

    Parameters
    ----------
    record
        Whether to record the transformed value(s) as numpyro
        deterministic site(s). Defaults to False.
    **kwargs
        Keyword arguments passed to self.base_rv.sample()

    Returns
    -------
    tuple of the same length as the tuple returned by
    self.base_rv.sample(). If there is exactly one transformed
    value, that value is returned unwrapped.
    """

    untransformed_values = self.base_rv.sample(**kwargs)

    # Normalize to a tuple so single- and multi-output base RVs
    # share one code path.
    if not isinstance(untransformed_values, tuple):
        untransformed_values = (untransformed_values,)

    transformed_values = tuple(
        t(uv) for t, uv in zip(self.transforms, untransformed_values)
    )

    if record:
        if len(untransformed_values) == 1:
            # Record the lone value itself, not the 1-tuple wrapping it,
            # so the deterministic site matches what this method returns.
            numpyro.deterministic(self.name, transformed_values[0])
        else:
            # Prefer namedtuple field names as site suffixes; fall back
            # to positional indices for plain tuples.
            suffixes = (
                untransformed_values._fields
                if hasattr(untransformed_values, "_fields")
                else range(len(transformed_values))
            )
            for suffix, tv in zip(suffixes, transformed_values):
                numpyro.deterministic(f"{self.name}_{suffix}", tv)

    if len(transformed_values) == 1:
        transformed_values = transformed_values[0]

    return transformed_values

sample_length

sample_length()

Sample length for a transformed random variable must be equal to the length of self.transforms or validation will fail.

Returns:

Type Description
int

Equal to the length of self.transforms

Source code in pyrenew/randomvariable/transformedvariable.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def sample_length(self):
    """
    Report how many values this transformed random variable yields.

    This is the number of entries in `self.transforms`. Validation
    requires it to match the base variable's sample length.

    Returns
    -------
    int
        The number of transformations in `self.transforms`.
    """
    n_transforms = len(self.transforms)
    return n_transforms

validate

validate()

Perform validation checks on a TransformedVariable instance, confirming that all transformations are callable and that the number of transformations is equal to the sample length of the base random variable.

Returns:

Type Description
None

on successful validation, or raise a ValueError

Source code in pyrenew/randomvariable/transformedvariable.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def validate(self):
    """
    Perform validation checks on a
    TransformedVariable instance,
    confirming that all transformations
    are callable and that the number of
    transformations is equal to the sample
    length of the base random variable.

    The length check is only performed when `self.base_rv`
    exposes a `sample_length` method.

    Returns
    -------
    None
        on successful validation, or raise a [`ValueError`][]

    Raises
    ------
    ValueError
        If any transformation is not callable, or if the number of
        transformations differs from `self.base_rv.sample_length()`.
    """
    for t in self.transforms:
        if not callable(t):
            raise ValueError("All entries in self.transforms must be callable")
    if hasattr(self.base_rv, "sample_length"):
        n_transforms = len(self.transforms)
        n_entries = self.base_rv.sample_length()
        if n_transforms != n_entries:
            raise ValueError(
                "There must be exactly as many transformations "
                "specified in self.transforms as there are "
                "entries in the tuple returned by "
                "self.base_rv.sample(). "
                f"Got {n_transforms} transforms and {n_entries} "
                "entries"
            )
    return None

DistributionalVariable

DistributionalVariable(
    name: str, distribution: Distribution | Callable, reparam: Reparam = None
) -> RandomVariable

Factory function to generate Distributional RandomVariables, either static or dynamic.

Parameters:

Name Type Description Default
name str

Name of the random variable.

required
distribution Distribution | Callable

Either a numpyro.distributions.Distribution instance giving the static distribution of the random variable, or a callable that returns a parameterized numpyro.distributions.Distribution when called. The callable form allows for dynamically-parameterized DistributionalVariables, e.g. a Normal distribution with an inferred location and scale.

required
reparam Reparam

If not None, reparameterize sampling from the distribution according to the given numpyro reparameterizer

None

Returns:

Type Description
DynamicDistributionalVariable | StaticDistributionalVariable or
raises a ValueError if a distribution cannot be constructed.
Source code in pyrenew/randomvariable/distributionalvariable.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
def DistributionalVariable(
    name: str,
    distribution: numpyro.distributions.Distribution | Callable,
    reparam: Reparam = None,
) -> RandomVariable:
    """
    Factory that builds either a static or a dynamic distributional
    RandomVariable, depending on what `distribution` is.

    Parameters
    ----------
    name
        Name of the random variable.

    distribution
        Either a numpyro.distributions.Distribution instance giving
        the fixed distribution of the random variable, or a callable
        that returns a parameterized numpyro.distributions.Distribution
        when called. The callable form allows dynamically-parameterized
        DistributionalVariables, e.g. a Normal distribution with an
        inferred location and scale.

    reparam
        If not None, reparameterize sampling
        from the distribution according to the
        given numpyro reparameterizer

    Returns
    -------
    DynamicDistributionalVariable | StaticDistributionalVariable or
    raises a ValueError if a distribution cannot be constructed.
    """
    # The Distribution check must come first: the callable branch
    # is only the fallback for non-Distribution inputs.
    if isinstance(distribution, dist.Distribution):
        return StaticDistributionalVariable(
            name=name, distribution=distribution, reparam=reparam
        )
    if callable(distribution):
        return DynamicDistributionalVariable(
            name=name, distribution_constructor=distribution, reparam=reparam
        )
    raise ValueError(
        "distribution argument to DistributionalVariable "
        "must be either a numpyro.distributions.Distribution "
        "(for instantiating a static DistributionalVariable) "
        "or a callable that returns a "
        "numpyro.distributions.Distribution (for "
        "a dynamic DistributionalVariable)."
    )