SVIProcess
- class dynode.infer.inference.SVIProcess(*, numpyro_model: typing.Callable[[dynode.config.simulation_config.SimulationConfig, typing.Tuple[jax.Array, ...] | jax.Array | None], diffrax._solution.Solution], inference_prngkey: jax.Array = Array([ 0, 8675314], dtype=uint32), num_iterations: typing.Annotated[int, annotated_types.Gt(gt=0)], num_samples: typing.Annotated[int, annotated_types.Gt(gt=0)], guide_class: typing.Type[numpyro.infer.autoguide.AutoContinuous] = <class 'numpyro.infer.autoguide.AutoMultivariateNormal'>, guide_init_strategy: typing.Callable = <function init_to_median>, optimizer: numpyro.optim._NumPyroOptim = <factory>, progress_bar: bool = True, guide_kwargs: dict = <factory>)
Bases: InferenceProcess
Inference process for fitting a numpyro_model to data using stochastic variational inference (SVI).
- __init__(**data: Any) → None
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
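The validation described above can be illustrated with a standalone pydantic model. This is a minimal sketch, not SVIProcess itself: HypotheticalSVISettings is a made-up class that copies only the num_iterations and num_samples annotations from the class signature, to show that a non-positive value is rejected with a ValidationError when the object is constructed.

```python
from typing import Annotated

from annotated_types import Gt
from pydantic import BaseModel, ValidationError


class HypotheticalSVISettings(BaseModel):
    """Hypothetical subset of SVIProcess fields, for illustration only."""

    num_iterations: Annotated[int, Gt(gt=0)]
    num_samples: Annotated[int, Gt(gt=0)]


# Valid keyword arguments construct the model normally.
ok = HypotheticalSVISettings(num_iterations=1000, num_samples=500)

# num_iterations=0 violates Gt(gt=0), so construction raises ValidationError.
try:
    HypotheticalSVISettings(num_iterations=0, num_samples=500)
except ValidationError as err:
    failed_fields = [e["loc"][0] for e in err.errors()]
else:
    failed_fields = []
```

Because the real class signature starts with `*`, SVIProcess fields are likewise passed as keyword arguments only.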
Methods
get_samples([_, exclude_deterministic]): Get the posterior samples from the inference process.
infer(**kwargs): Fit the numpyro_model to data using SVI.
model_post_init(context, /): This function is meant to behave like a BaseModel method to initialise private attributes.
to_arviz(): Return the results of a fit as an arviz InferenceData object.
Attributes
model_config: Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
numpyro_model
inference_prngkey
- model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True}
Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
- num_iterations: Annotated[int, Gt(gt=0)]
- num_samples: Annotated[int, Gt(gt=0)]
- guide_class: Type[AutoContinuous]
- guide_init_strategy: Callable
- optimizer: _NumPyroOptim
- progress_bar: bool
- guide_kwargs: dict
- infer(**kwargs) → SVI
Fit the numpyro_model to data using SVI.
Additional keyword arguments are passed to the numpyro_model.
Returns
- SVI
The SVI object used for inference.
- get_samples(_: bool = False, exclude_deterministic: bool = True) → dict[str, Array]
Get the posterior samples from the inference process.
Parameters
- _ : bool
Unused. Whether to group posterior samples by chain; SVI does not have chains, so this parameter has no effect.
- exclude_deterministic : bool
Whether to exclude sites created with numpyro.deterministic from the keys of the returned dictionary, by default True.
Returns
- dict[str, Array]
A dictionary of posterior samples, where keys are parameter sites and values are the corresponding samples.
Notes
For SVI, posterior samples are generated after the fitting process, so they are not arranged by chain and draw as they are in MCMC.
- to_arviz() → InferenceData
Return the results of a fit as an arviz InferenceData object.
Returns
- arviz.InferenceData
An arviz InferenceData object containing both the priors and the posterior_predictive group.
Raises
- AssertionError
If fitting has not yet been run via infer().
- model_post_init(context: Any, /) → None
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args:
- self: The BaseModel instance.
- context: The context.