MCMCProcess
- class dynode.infer.inference.MCMCProcess(*, numpyro_model: Callable[[SimulationConfig, Tuple[Array, ...] | Array | None], Solution], inference_prngkey: Array = Array([ 0, 8675314], dtype=uint32), num_samples: Annotated[int, Gt(gt=0)], num_warmup: Annotated[int, Gt(gt=0)], num_chains: Annotated[int, Gt(gt=0)], nuts_max_tree_depth: Annotated[int, Gt(gt=0)], nuts_init_strategy: Callable = <function init_to_median>, mcmc_kwargs: dict = <factory>, nuts_kwargs: dict = <factory>, progress_bar: bool = True)
Bases: InferenceProcess
Inference process for fitting a numpyro_model to data using MCMC.
- __init__(**data: Any) -> None
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
Methods
- get_samples([group_by_chain, ...]): Get the posterior samples from the inference process.
- infer(**kwargs): Fit the numpyro_model to data using MCMC.
- model_post_init(context, /): This function is meant to behave like a BaseModel method to initialise private attributes.
- to_arviz(): Return the results of a fit as an arviz InferenceData object.
Attributes
- model_config: Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- numpyro_model
- inference_prngkey
- num_samples: Annotated[int, Gt(gt=0)]
- num_warmup: Annotated[int, Gt(gt=0)]
- num_chains: Annotated[int, Gt(gt=0)]
- nuts_max_tree_depth: Annotated[int, Gt(gt=0)]
- nuts_init_strategy: Callable
- mcmc_kwargs: dict
- nuts_kwargs: dict
- progress_bar: bool
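The fields num_samples, num_warmup, num_chains, and nuts_max_tree_depth share the type Annotated[int, Gt(gt=0)], so each must be an integer strictly greater than zero. The following is a minimal standard-library sketch of how such a constraint can be read from an annotation and enforced. It is not how dynode or pydantic implement validation: Gt here is a local stand-in for annotated_types.Gt, and SamplerSettings and check_settings are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass
from typing import Annotated, get_type_hints


@dataclass(frozen=True)
class Gt:
    """Local stand-in for annotated_types.Gt: the value must be > gt."""

    gt: int


class SamplerSettings:
    """Hypothetical container mirroring the four constrained fields."""

    num_samples: Annotated[int, Gt(gt=0)]
    num_warmup: Annotated[int, Gt(gt=0)]
    num_chains: Annotated[int, Gt(gt=0)]
    nuts_max_tree_depth: Annotated[int, Gt(gt=0)]


def check_settings(**values: int) -> dict[str, int]:
    """Return values unchanged if every field satisfies its Gt constraint."""
    # include_extras=True keeps the Annotated metadata on each hint.
    hints = get_type_hints(SamplerSettings, include_extras=True)
    for name, hint in hints.items():
        if name not in values:
            raise ValueError(f"{name} is required")
        value = values[name]
        for constraint in hint.__metadata__:
            if isinstance(constraint, Gt) and not value > constraint.gt:
                raise ValueError(
                    f"{name} must be greater than {constraint.gt}, got {value}"
                )
    return values


settings = check_settings(
    num_samples=1000, num_warmup=500, num_chains=4, nuts_max_tree_depth=10
)
```

In the real class, a violation is reported by pydantic as a ValidationError rather than the plain ValueError used in this sketch.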
- infer(**kwargs) -> MCMC
Fit the numpyro_model to data using MCMC.
Additional keyword arguments are passed to the numpyro_model.
Returns
- MCMC
The MCMC object used for inference.
- get_samples(group_by_chain=False, exclude_deterministic=True) -> dict[str, Array]
Get the posterior samples from the inference process.
Parameters
- group_by_chain : bool
Whether to group posterior samples by chain, by default False. If True, a leading chain dimension is added to each value in the returned dict.
- exclude_deterministic : bool
Whether to exclude sites generated by numpyro.deterministic from the keys of the returned dictionary, by default True.
Returns
- dict[str, Array]
A dictionary of posterior samples, where keys are parameter sites and values are the corresponding samples, with shape (num_chains * num_samples,) if group_by_chain=False and (num_chains, num_samples) otherwise.
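The relationship between the two return layouts can be illustrated without JAX. The sketch below uses plain Python lists in place of arrays. The function arrange_samples and its deterministic_sites argument are hypothetical, and the chain-by-chain flattening order is an assumption rather than something this page states.

```python
def arrange_samples(
    draws_by_chain,
    deterministic_sites=(),
    group_by_chain=False,
    exclude_deterministic=True,
):
    """Mimic the documented layouts using nested lists instead of arrays.

    draws_by_chain maps a site name to a list of chains, where each chain
    is a list of draws (conceptually shape (num_chains, num_samples)).
    """
    result = {}
    for site, chains in draws_by_chain.items():
        # Drop deterministic sites when requested.
        if exclude_deterministic and site in deterministic_sites:
            continue
        if group_by_chain:
            # Keep the leading chain dimension.
            result[site] = [list(chain) for chain in chains]
        else:
            # Assumed order: all of chain 0, then all of chain 1, and so on.
            result[site] = [draw for chain in chains for draw in chain]
    return result


# Two chains of three draws each; "r0" plays the role of a deterministic site.
draws = {
    "beta": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    "r0": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
}

flat = arrange_samples(draws, deterministic_sites={"r0"})
grouped = arrange_samples(
    draws,
    deterministic_sites={"r0"},
    group_by_chain=True,
    exclude_deterministic=False,
)
```

Here flat["beta"] holds num_chains * num_samples = 6 draws in a single list, while grouped["beta"] keeps 2 chains of 3 draws each and also retains the "r0" site.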
- to_arviz() -> InferenceData
Return the results of a fit as an arviz InferenceData object.
Returns
- arviz.InferenceData
arviz InferenceData object containing both priors and posterior_predictive.
Raises
- AssertionError
If fitting has not yet been run via infer().
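The ordering requirement (fit first, then export) can be sketched with a small stand-in class. FitRecord, its fields, and the use of a bare assert statement are assumptions made for illustration. Only the documented behaviour, an AssertionError when infer() has not yet run, comes from this page.

```python
class FitRecord:
    """Hypothetical stand-in for a process that must be fit before export."""

    def __init__(self) -> None:
        self._fit = None  # nothing is stored until infer() runs

    def infer(self) -> dict:
        # Placeholder "fit": a real process would run MCMC here.
        self._fit = {"posterior": {"beta": [0.1, 0.2, 0.3]}}
        return self._fit

    def to_export(self) -> dict:
        # Guard: exporting before fitting is a usage error.
        assert self._fit is not None, "call infer() before exporting results"
        return self._fit


record = FitRecord()

# Exporting too early trips the guard.
try:
    record.to_export()
    exported_early = True
except AssertionError:
    exported_early = False

# After fitting, the export succeeds.
record.infer()
exported = record.to_export()
```

Callers that cannot guarantee the call order can catch AssertionError as shown, or simply call infer() before requesting the export.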
- _abc_impl = <_abc._abc_data object>
- model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- model_post_init(context: Any, /) -> None
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
- Args:
self: The BaseModel instance.
context: The context.