dynode.utils.utils.identify_distribution_indexes

dynode.utils.utils.identify_distribution_indexes(parameters: dict[str, Any]) → dict[str, dict[str, str | tuple | None]]

Identify the locations and site names of numpyro samples.

The inverse of sample_if_distribution(): identifies which parameters are numpyro distributions and returns a mapping from each sample site name to the parameter name and index it came from.
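The traversal described above can be sketched in plain Python. This is a minimal sketch, not dynode's implementation: `find_indexes` and `Marker` are hypothetical names, the distribution check is passed in as a predicate so the sketch runs without numpyro, and the site naming for nested lists (key plus every index, joined by `_`) is an assumption extrapolated from the `test_1` case in the example.

```python
from typing import Any, Callable


def find_indexes(
    parameters: dict[str, Any],
    is_target: Callable[[Any], bool],
) -> dict[str, dict[str, Any]]:
    """Hypothetical helper: map site names to their key and list index."""
    found: dict[str, dict[str, Any]] = {}

    def walk(key: str, value: Any, idx: tuple[int, ...]) -> None:
        if is_target(value):
            # Assumed naming scheme: the key followed by each index, joined by "_".
            name = "_".join([key, *map(str, idx)])
            # An empty index path means the value sat directly under the key.
            found[name] = {"sample_name": key, "sample_idx": idx or None}
        elif isinstance(value, list):
            # Recurse into lists, extending the index path by one position.
            for i, item in enumerate(value):
                walk(key, item, idx + (i,))

    for key, value in parameters.items():
        walk(key, value, ())
    return found


class Marker:
    """Stand-in for a numpyro distribution in this demo."""


params = {"test": [0, Marker(), 2], "example": Marker()}
print(find_indexes(params, lambda v: isinstance(v, Marker)))
# {'test_1': {'sample_name': 'test', 'sample_idx': (1,)}, 'example': {'sample_name': 'example', 'sample_idx': None}}
```

With numpyro installed, `lambda v: isinstance(v, numpyro.distributions.Distribution)` would be the natural predicate; taking the check as an argument simply keeps the sketch dependency-free.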

Parameters

parameters : dict[str, Any]

A dictionary mapping parameter names to values of any type, including numpyro distributions or lists that contain them.

Returns

dict[str, dict[str, str | tuple[int] | None]]

A dictionary mapping each sample site name to its key within parameters. If the sampled parameter is an element of a list, the tuple of indexes locating it is returned as well; otherwise the index is None.

  • key: str

    Sampled parameter name as produced by sample_if_distribution().

  • value: dict[str, str | tuple | None]

    “sample_name” maps to the key within parameters, and “sample_idx” gives the indexes of the distribution if it is found in a list, otherwise None.
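One way to consume the returned mapping is to write sampled values back into a copy of parameters at the recorded locations. The helper below, `place_samples`, is hypothetical and not part of dynode; it assumes `samples` maps each site name to its drawn value, and the strings `"D"` stand in for distributions.

```python
import copy
from typing import Any


def place_samples(
    parameters: dict[str, Any],
    index_map: dict[str, dict[str, Any]],
    samples: dict[str, Any],
) -> dict[str, Any]:
    """Hypothetical helper: return a copy of parameters with samples in place."""
    # Deep copy so nested lists in the caller's dict are left untouched.
    out = copy.deepcopy(parameters)
    for site, loc in index_map.items():
        key, idx = loc["sample_name"], loc["sample_idx"]
        if idx is None:
            # The whole entry was a distribution: replace it outright.
            out[key] = samples[site]
        else:
            # Walk down to the innermost containing list, then assign.
            container = out[key]
            for i in idx[:-1]:
                container = container[i]
            container[idx[-1]] = samples[site]
    return out


params = {"test": [0, "D", 2], "example": "D"}
index_map = {
    "test_1": {"sample_name": "test", "sample_idx": (1,)},
    "example": {"sample_name": "example", "sample_idx": None},
}
print(place_samples(params, index_map, {"test_1": 1.5, "example": -0.3}))
# {'test': [0, 1.5, 2], 'example': -0.3}
```

The `sample_idx is None` branch is what lets a single loop handle both top-level distributions and ones nested inside lists.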

Examples

>>> import numpyro.distributions as dist
>>> from dynode.utils.utils import identify_distribution_indexes
>>> parameters = {"test": [0, dist.Normal(), 2], "example": dist.Normal()}
>>> identify_distribution_indexes(parameters)
{'test_1': {'sample_name': 'test', 'sample_idx': (1,)},
 'example': {'sample_name': 'example', 'sample_idx': None}}