SVIProcess

class dynode.infer.inference.SVIProcess(
    *,
    numpyro_model: typing.Callable[[dynode.config.simulation_config.SimulationConfig, typing.Tuple[jax.Array, ...] | jax.Array | None], diffrax._solution.Solution],
    inference_prngkey: jax.Array = Array([0, 8675314], dtype=uint32),
    num_iterations: typing.Annotated[int, annotated_types.Gt(gt=0)],
    num_samples: typing.Annotated[int, annotated_types.Gt(gt=0)],
    guide_class: typing.Type[numpyro.infer.autoguide.AutoContinuous] = <class 'numpyro.infer.autoguide.AutoMultivariateNormal'>,
    guide_init_strategy: typing.Callable = <function init_to_median>,
    optimizer: numpyro.optim._NumPyroOptim = <factory>,
    progress_bar: bool = True,
    guide_kwargs: dict = <factory>,
)

Bases: InferenceProcess

Inference process that fits a numpyro_model to data using stochastic variational inference (SVI).

__init__(**data: Any) -> None

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.
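num_iterations and num_samples carry a Gt(gt=0) constraint, so construction fails validation when either is zero or negative. The hypothetical IterationSettings model below reproduces only those two constraints so the behaviour can be run without DynODE; it assumes pydantic v2 and is not part of the library.

```python
from typing import Annotated

from annotated_types import Gt
from pydantic import BaseModel, ValidationError


class IterationSettings(BaseModel):
    """Hypothetical stand-in that copies two SVIProcess field constraints."""

    num_iterations: Annotated[int, Gt(gt=0)]
    num_samples: Annotated[int, Gt(gt=0)]


def validation_errors(**data):
    """Return (location, error type) pairs raised while validating data."""
    try:
        IterationSettings(**data)
    except ValidationError as err:
        return [(e["loc"], e["type"]) for e in err.errors()]
    return []
```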

Methods

get_samples([_, exclude_deterministic])

Get the posterior samples from the inference process.

infer(**kwargs)

Fit the numpyro_model to data using SVI.

model_post_init(context, /)

This function is meant to behave like a BaseModel method to initialise private attributes.

to_arviz()

Return the results of a fit as an arviz InferenceData object.

Attributes

model_config

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.

num_iterations

num_samples

guide_class

guide_init_strategy

optimizer

progress_bar

guide_kwargs

numpyro_model

inference_prngkey

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True}

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
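Fields such as inference_prngkey (a jax.Array) and optimizer (a numpyro _NumPyroOptim) are not types pydantic knows how to validate, which is likely why arbitrary_types_allowed is set. The sketch below, assuming pydantic v2, shows the general effect with a hypothetical OptimizerLike class: without the setting, class creation fails; with it, the field is checked with isinstance().

```python
from pydantic import (
    BaseModel,
    ConfigDict,
    PydanticSchemaGenerationError,
    ValidationError,
)


class OptimizerLike:
    """Hypothetical third-party class that pydantic has no schema for."""


class WithArbitraryTypes(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # With arbitrary_types_allowed, this field is validated by isinstance().
    optimizer: OptimizerLike


# Without the setting, pydantic refuses to build the class at all.
try:
    class WithoutArbitraryTypes(BaseModel):
        optimizer: OptimizerLike

    schema_error = None
except PydanticSchemaGenerationError as err:
    schema_error = err

accepted = WithArbitraryTypes(optimizer=OptimizerLike())

# A value of the wrong type fails the isinstance() check.
try:
    WithArbitraryTypes(optimizer="adam")
    rejected = False
except ValidationError:
    rejected = True
```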

num_iterations: Annotated[int, Gt(gt=0)]
num_samples: Annotated[int, Gt(gt=0)]
guide_class: Type[AutoContinuous]
guide_init_strategy: Callable
optimizer: _NumPyroOptim
progress_bar: bool
guide_kwargs: dict

infer(**kwargs) -> SVI

Fit the numpyro_model to data using SVI.

Additional keyword arguments are passed to the numpyro_model.

Returns

SVI

The SVI object used for inference.

get_samples(_: bool = False, exclude_deterministic: bool = True) -> dict[str, Array]

Get the posterior samples from the inference process.

Parameters

_ : bool

Unused. Whether to group posterior samples by chain; SVI has no chains, so this argument is ignored.

exclude_deterministic : bool

Whether to exclude sites created with numpyro.deterministic from the keys of the returned dictionary, by default True.

Returns

dict[str, Array]

A dictionary of posterior samples, where keys are parameter sites and values are the corresponding samples.

Notes

For SVI, posterior samples are generated after fitting completes, and they are not arranged by chain and draw as MCMC samples are.
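Tools that expect MCMC-shaped output, such as ArviZ, index arrays by (chain, draw, ...). A minimal sketch of adapting SVI-style samples, assuming each array's first axis indexes samples, is to add a single pseudo-chain axis; the add_chain_axis helper is hypothetical and is not necessarily what to_arviz does.

```python
import numpy as np


def add_chain_axis(samples):
    """Reshape {site: (num_samples, ...)} into {site: (1, num_samples, ...)}.

    SVI draws do not come from a Markov chain, so the single leading axis is
    a labelling convention for MCMC-oriented tools, not a real chain.
    """
    return {name: np.asarray(value)[np.newaxis, ...] for name, value in samples.items()}


# Placeholder SVI-style samples: a scalar site and a length-3 vector site.
svi_like = {"mu": np.zeros(500), "theta": np.zeros((500, 3))}
chained = add_chain_axis(svi_like)
```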

to_arviz() -> InferenceData

Return the results of a fit as an arviz InferenceData object.

Returns

arviz.InferenceData

An arviz InferenceData object containing both the prior and the posterior_predictive groups.

Raises

AssertionError

If fitting has not yet been run via infer().

model_post_init(context: Any, /) -> None

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that’s what pydantic-core passes when calling it.

Args:

self: The BaseModel instance.
context: The context.
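As a sketch of the pattern this docstring describes, the hypothetical FitTracker model below sets private attributes in model_post_init, which pydantic calls after field validation. It assumes pydantic v2; whether SVIProcess stores its fitted state this way is not documented here.

```python
from typing import Any, Optional

from pydantic import BaseModel, PrivateAttr


class FitTracker(BaseModel):
    """Hypothetical model, not part of DynODE."""

    num_iterations: int
    # Private attributes are not fields: they are neither validated nor dumped.
    _losses: Optional[list] = PrivateAttr(default=None)
    _fitted: bool = PrivateAttr(default=False)

    def model_post_init(self, context: Any, /) -> None:
        # Runs after field validation, so validated fields are available here.
        self._losses = [0.0] * self.num_iterations


tracker = FitTracker(num_iterations=3)
```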