dynode.utils.utils.identify_distribution_indexes
- dynode.utils.utils.identify_distribution_indexes(parameters: dict[str, Any]) → dict[str, dict[str, str | tuple | None]]
Identify the locations and site names of numpyro samples.
The inverse of sample_if_distribution(): identifies which parameters are numpyro distributions and returns a mapping from each sample site name to the parameter name it came from and, if the distribution sits inside a list, its index within that list.
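To make the mapping concrete, here is a minimal sketch of the kind of traversal described above. It is not dynode's implementation. The function name sketch_identify_distribution_indexes, the is_distribution predicate, the FakeDist stand-in class and the naming scheme for nested lists (indexes joined with underscores) are all assumptions, chosen only so that the flat-list case reproduces the test_1 site name shown in the Examples section.

```python
from collections.abc import Callable
from typing import Any


def sketch_identify_distribution_indexes(
    parameters: dict[str, Any],
    is_distribution: Callable[[Any], bool],
) -> dict[str, dict[str, Any]]:
    """Hypothetical re-implementation of the documented mapping.

    is_distribution decides what counts as a sampled value. With numpyro
    installed you might pass
    lambda v: isinstance(v, numpyro.distributions.Distribution).
    """
    result: dict[str, dict[str, Any]] = {}

    def walk(key: str, value: Any, idx: tuple[int, ...]) -> None:
        if is_distribution(value):
            # Assumed naming scheme: list positions are appended with "_",
            # so key "test" at index (1,) becomes site "test_1".
            site = "_".join([key, *map(str, idx)])
            result[site] = {
                "sample_name": key,
                # A top-level distribution has no index, so report None.
                "sample_idx": idx if idx else None,
            }
        elif isinstance(value, (list, tuple)):
            # Recurse into containers, extending the index tuple.
            for i, item in enumerate(value):
                walk(key, item, idx + (i,))

    for key, value in parameters.items():
        walk(key, value, ())
    return result


class FakeDist:
    """Hypothetical stand-in for a numpyro distribution such as dist.Normal()."""


params = {"test": [0, FakeDist(), 2], "example": FakeDist()}
print(sketch_identify_distribution_indexes(params, lambda v: isinstance(v, FakeDist)))
# {'test_1': {'sample_name': 'test', 'sample_idx': (1,)}, 'example': {'sample_name': 'example', 'sample_idx': None}}
```

Passing the predicate in, rather than importing numpyro, keeps the sketch runnable without numpyro installed; plain values such as 0 and 2 are simply skipped.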
Parameters
- parameters : dict[str, Any]
A dictionary mapping parameter names to values of any type, some of which may be numpyro distributions or lists containing them.
Returns
- dict[str, dict[str, str | tuple[int] | None]]
A dictionary mapping each sample site name to the key within parameters that it came from. If the sampled parameter sits inside a larger list, a tuple of indexes locating it is also returned; otherwise the index is None.
- key: str
Sampled parameter name as produced by sample_if_distribution().
- value: dict[str, str | tuple | None]
“sample_name” maps to the key within parameters, and “sample_idx” gives the indexes of the distribution within its list, or None if the distribution is not inside a list.
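Because each entry records both the parameter key and the list position, the mapping is enough to locate a sampled value in the original dictionary. The helper below, write_back, is hypothetical and not part of dynode; it assumes distributions only sit in plain Python lists (tuples cannot be assigned into), and its index_map literal is copied from the documented example.

```python
import copy
from typing import Any


def write_back(
    parameters: dict[str, Any],
    index_map: dict[str, dict[str, Any]],
    site: str,
    value: Any,
) -> dict[str, Any]:
    """Hypothetical helper: return a copy of parameters in which the entry
    behind sample site `site` is replaced by `value`."""
    out = copy.deepcopy(parameters)  # leave the caller's dict untouched
    entry = index_map[site]
    key, idx = entry["sample_name"], entry["sample_idx"]
    if idx is None:
        # Top-level distribution: replace the whole value.
        out[key] = value
        return out
    # Walk down all but the last index, then assign at the final position.
    container = out[key]
    for i in idx[:-1]:
        container = container[i]
    container[idx[-1]] = value
    return out


index_map = {
    "test_1": {"sample_name": "test", "sample_idx": (1,)},
    "example": {"sample_name": "example", "sample_idx": None},
}
params = {"test": [0, "Normal()", 2], "example": "Normal()"}
params = write_back(params, index_map, "test_1", 1.5)
params = write_back(params, index_map, "example", -0.3)
print(params)  # {'test': [0, 1.5, 2], 'example': -0.3}
```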
Examples
>>> import numpyro.distributions as dist
>>> from dynode.utils.utils import identify_distribution_indexes
>>> parameters = {"test": [0, dist.Normal(), 2], "example": dist.Normal()}
>>> identify_distribution_indexes(parameters)
{'test_1': {'sample_name': 'test', 'sample_idx': (1,)}, 'example': {'sample_name': 'example', 'sample_idx': None}}