dynode.infer.sample.resolve_deterministic

dynode.infer.sample.resolve_deterministic(obj: Any, root_params: dict | BaseModel, _prefix: str = '')

Find every DeterministicParameter within obj (including inside nested dicts and lists) and replace it with the value it points to in root_params.
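As a rough illustration of what "find and resolve" means for nested containers, here is a minimal pure-Python sketch. It is not dynode's implementation: DeterministicRef and resolve_refs are hypothetical names, the reference is assumed to carry a top-level key plus an optional index (mirroring the DeterministicParameter("x_lst", index=1) call in the example below the parameter list), and only dicts, lists and tuples are walked.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class DeterministicRef:
    """Hypothetical stand-in for DeterministicParameter: names a top-level
    key and, optionally, an index into the value stored under that key."""
    depends_on: str
    index: Optional[int] = None


def resolve_refs(obj: Any, root: dict) -> Any:
    """Return a copy of obj with every DeterministicRef replaced by the value
    it points to in root (top-level keys only)."""
    if isinstance(obj, DeterministicRef):
        value = root[obj.depends_on]
        return value if obj.index is None else value[obj.index]
    if isinstance(obj, dict):
        return {key: resolve_refs(val, root) for key, val in obj.items()}
    if isinstance(obj, list):
        return [resolve_refs(item, root) for item in obj]
    if isinstance(obj, tuple):
        return tuple(resolve_refs(item, root) for item in obj)
    # Anything else (numbers, arrays, strings, ...) is returned unchanged.
    return obj


# Plain floats stand in for already-sampled values.
params = {
    "x": 1.5,
    "y": DeterministicRef("x"),
    "x_lst": [0, 3.0, 2],
    "y_lst": [0, DeterministicRef("x_lst", index=1), 2],
}
resolved = resolve_refs(params, root=params)
print(resolved)  # {'x': 1.5, 'y': 1.5, 'x_lst': [0, 3.0, 2], 'y_lst': [0, 3.0, 2]}
```

Because the sketch rebuilds containers instead of mutating them, the input dict keeps its references. It also does not follow a reference that points at another reference; a real implementation may handle that differently.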

Parameters

obj: Any

Python data structure that may or may not contain DeterministicParameter objects. Any that exist will be resolved.

root_params: dict | BaseModel

Dict or pydantic model used to resolve DeterministicParameter objects. Every parameter pointed to by a DeterministicParameter must be at the top level of this object.
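The top-level requirement on root_params amounts to one key or attribute lookup per reference. The sketch below assumes dicts are read by key and models by attribute; lookup_top_level is a hypothetical helper, not part of dynode, and types.SimpleNamespace stands in for a pydantic BaseModel so the example needs only the standard library.

```python
from types import SimpleNamespace
from typing import Any, Optional


def lookup_top_level(root: Any, name: str, index: Optional[int] = None) -> Any:
    """Fetch name from the top level of root, then optionally index into it.

    Dicts are read by key; any other object (e.g. a pydantic BaseModel) is
    read by attribute. Nested paths are deliberately not searched."""
    value = root[name] if isinstance(root, dict) else getattr(root, name)
    return value if index is None else value[index]


# SimpleNamespace stands in for a pydantic model: both expose fields as attributes.
as_dict = {"beta": 0.3, "rates": [0.1, 0.2]}
as_model = SimpleNamespace(beta=0.3, rates=[0.1, 0.2])

print(lookup_top_level(as_dict, "beta"))             # 0.3
print(lookup_top_level(as_model, "rates", index=1))  # 0.2

# A parameter buried one level down is not found: this raises KeyError.
# lookup_top_level({"outer": {"beta": 0.3}}, "beta")
```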

_prefix: str

Optional prefix added before site names. It affects downstream inference functionality, so it is best left at its default.

Returns

obj

The obj, with any DeterministicParameter objects within it resolved. If isinstance(obj, DeterministicParameter), the resolved value is returned instead.

Examples

>>> import numpyro.distributions as dist
>>> import numpyro.handlers as handlers
>>> from dynode.model_configuration.types import DeterministicParameter
>>> from dynode.infer.sample import sample_distributions, resolve_deterministic
>>> parameters = {"x": dist.Normal(),
...               "y": DeterministicParameter("x"),
...               "x_lst": [0, dist.Normal(), 2],
...               "y_lst": [0, DeterministicParameter("x_lst", index=1), 2]}
>>> with handlers.seed(rng_seed=1):
...     samples = sample_distributions(parameters)
...     resolved = resolve_deterministic(samples, root_params=samples)
>>> resolved
{'x': Array(-0.80760655, dtype=float64),
 'y': Array(-0.80760655, dtype=float64),
 'x_lst': Array([0.        , 0.57522288, 2.        ], dtype=float64),
 'y_lst': Array([0.        , 0.57522288, 2.        ], dtype=float64)}