Model
MultiSignalModel
MultiSignalModel(
latent_process: BaseLatentInfectionProcess,
observations: dict[str, BaseObservationProcess],
ascertainment_models: dict[str, AscertainmentModel] | None = None,
)
Bases: Model
Multi-signal renewal model.
Combines a latent infection process (e.g., SubpopulationInfections) with multiple observation processes (e.g., CountObservation, WastewaterObservation).
Typically built via PyrenewBuilder, which ensures that n_initialization_points is computed correctly from all components. It can also be constructed manually for advanced use cases.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| latent_process | BaseLatentInfectionProcess | Latent infection process generating infections at jurisdiction and/or subpopulation levels | required |
| observations | dict[str, BaseObservationProcess] | Dictionary mapping names to observation process instances. Names are used when passing observation data to sample(). | required |
| ascertainment_models | dict[str, AscertainmentModel] \| None | Optional dictionary mapping names to ascertainment model instances. Each ascertainment model is sampled once per model execution, before the observation processes run. | None |
Notes
The model automatically routes latent infections to each observation based on that observation's infection_resolution() method:
- "aggregate" → receives aggregate infections from the latent process
- "subpop" → receives all subpopulation infections; the observation selects the ones it needs via indices
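The routing rule above can be sketched as follows. This is a minimal illustration, not pyrenew's implementation: `route_infections` is a hypothetical helper, and the infection arrays are toy values with assumed shapes (time × subpopulation for the subpop array).

```python
import numpy as np


# Hypothetical helper; not part of the pyrenew API.
def route_infections(resolution: str, aggregate: np.ndarray, subpop: np.ndarray) -> np.ndarray:
    """Return the latent infection array an observation with `resolution` receives."""
    if resolution == "aggregate":
        # Jurisdiction-level signal: one value per time step.
        return aggregate
    if resolution == "subpop":
        # All subpopulations are passed through; the observation subsets them itself.
        return subpop
    raise ValueError(f"Unsupported infection resolution: {resolution!r}")


n_total, n_subpops = 5, 3
aggregate = np.arange(n_total, dtype=float)  # toy values, shape (5,)
subpop = np.ones((n_total, n_subpops))       # toy values, shape (5, 3)

hospital_infections = route_infections("aggregate", aggregate, subpop)
ww_all = route_infections("subpop", aggregate, subpop)
ww_sites = ww_all[:, np.array([0, 2])]       # observation keeps subpops 0 and 2
```

Passing the full subpopulation array and letting each observation index into it keeps the model agnostic about which subpopulations a given signal covers.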
Initialize multi-signal model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| latent_process | BaseLatentInfectionProcess | Configured latent infection process | required |
| observations | dict[str, BaseObservationProcess] | Dictionary mapping observation names to observation process instances | required |
| ascertainment_models | dict[str, AscertainmentModel] \| None | Optional dictionary mapping names to ascertainment model instances | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If validation fails (e.g., an observation requires subpopulations but the latent process doesn't support them) |
Source code in pyrenew/model/multisignal_model.py
pad_observations
Pad observations with NaN for the initialization period.
Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. This method prepends n_init NaN values to align user data (starting at day 0 of observations) with the shared axis.
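A minimal NumPy sketch of the padding step, assuming `n_init` is passed explicitly (the documented method takes only `obs` and `axis`, so it presumably reads `n_init` from the model). The standalone `pad_observations` function here is illustrative only.

```python
import numpy as np


def pad_observations(obs, n_init: int, axis: int = 0) -> np.ndarray:
    """Prepend `n_init` NaN entries to `obs` along `axis`."""
    # Integer data cannot hold NaN, so convert to float first.
    obs = np.asarray(obs, dtype=float)
    pad_width = [(0, 0)] * obs.ndim
    pad_width[axis] = (n_init, 0)  # pad before, nothing after
    return np.pad(obs, pad_width, mode="constant", constant_values=np.nan)


daily_counts = np.array([3, 5, 2])  # observation days 0, 1, 2
padded = pad_observations(daily_counts, n_init=2)
# padded is [nan, nan, 3.0, 5.0, 2.0]; index 2 on the shared axis is observation day 0
```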
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | ndarray | Observations in natural coordinates (index 0 = first observation day). Integer arrays are converted to float (required for NaN). | required |
| axis | int | Axis along which to pad (typically 0 for time axis). | 0 |
Returns:
| Type | Description |
|---|---|
| ndarray | Padded observations. First n_init values are NaN. |
Source code in pyrenew/model/multisignal_model.py
sample
sample(
n_days_post_init: int,
population_size: float,
*,
subpop_fractions: ArrayLike | None = None,
obs_start_date: date | datetime | datetime64 | None = None,
**observation_data: dict[str, object],
) -> None
Sample from the joint generative model.
This is the model function called by NumPyro during inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_days_post_init | int | Number of days to simulate after the initialization period | required |
| population_size | float | Total population size. Used to convert infection proportions (from the latent process) to infection counts (for observation processes). | required |
| subpop_fractions | ArrayLike \| None | Population fractions for all subpopulations. Shape: (n_subpops,). | None |
| obs_start_date | date \| datetime \| datetime64 \| None | Date of the first observation day. Converted once to the axis-origin | None |
| **observation_data | dict[str, object] | Data for each observation process, keyed by observation name (the | {} |
Returns:
| Type | Description |
|---|---|
| None | All quantities are recorded as NumPyro deterministic sites ( |
Source code in pyrenew/model/multisignal_model.py
shift_times
Shift time indices from natural coordinates to shared time axis.
Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. User-provided times in natural coordinates (0 = first observation day) must be shifted by n_init to align with the shared axis.
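The shift is a constant offset of `n_init`. A minimal sketch, again passing `n_init` explicitly as an assumption (the documented method takes only `times`):

```python
import numpy as np


def shift_times(times, n_init: int) -> np.ndarray:
    """Map natural-coordinate time indices onto the shared time axis."""
    return np.asarray(times) + n_init


n_init, n_days = 7, 30
n_total = n_init + n_days             # shared axis is [0, 37)
natural_times = np.array([0, 3, 29])  # sparse observation days
shared_times = shift_times(natural_times, n_init)
# shared_times is [7, 10, 36], all inside [n_init, n_total)
```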
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| times | ndarray | Time indices in natural coordinates (0 = first observation day). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Time indices on the shared axis [n_init, n_total). |
Source code in pyrenew/model/multisignal_model.py
validate
validate() -> None
Validate that observation processes are compatible with latent process.
Checks that each observation implements infection_resolution() and returns a supported resolution.
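A sketch of the check described above, assuming the supported resolutions are exactly "aggregate" and "subpop" (the two listed in the class Notes). `validate_resolutions`, `ToyHospital`, and `ToyBroken` are hypothetical names for illustration, not pyrenew code.

```python
SUPPORTED_RESOLUTIONS = {"aggregate", "subpop"}


def validate_resolutions(observations: dict) -> None:
    """Raise ValueError if any observation lacks or misreports infection_resolution()."""
    for name, obs in observations.items():
        method = getattr(obs, "infection_resolution", None)
        if not callable(method):
            raise ValueError(f"Observation {name!r} does not implement infection_resolution()")
        resolution = method()
        if resolution not in SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"Observation {name!r} returned unsupported resolution {resolution!r}; "
                f"expected one of {sorted(SUPPORTED_RESOLUTIONS)}"
            )


class ToyHospital:  # hypothetical stand-in for an observation process
    def infection_resolution(self) -> str:
        return "aggregate"


class ToyBroken:  # hypothetical: does not implement infection_resolution()
    pass


validate_resolutions({"hospital": ToyHospital()})  # passes silently
```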
Raises:
| Type | Description |
|---|---|
| ValueError | If an observation doesn't implement infection_resolution() or returns an unsupported resolution. |
Source code in pyrenew/model/multisignal_model.py
validate_data
validate_data(
n_days_post_init: int,
*,
subpop_fractions: ArrayLike | None = None,
obs_start_date: date | datetime | datetime64 | None = None,
**observation_data: dict[str, object],
) -> None
Validate observation data before running MCMC.
All observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days_post_init. Dense observations must have length n_total with NaN padding for the initialization period. Sparse observations provide times indices on this shared axis.
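The two shape rules in the paragraph above (dense length equals n_total; sparse indices lie in [0, n_total)) can be sketched as below. The helpers `check_dense` and `check_sparse` are hypothetical; the real method also checks other things (e.g., shape consistency) whose details are not shown here.

```python
import numpy as np


def check_dense(name: str, obs, n_total: int) -> None:
    """Dense observations must cover the full shared axis, NaN-padded at the start."""
    obs = np.asarray(obs)
    if obs.shape[0] != n_total:
        raise ValueError(f"{name}: dense obs length {obs.shape[0]} != n_total {n_total}")


def check_sparse(name: str, times, n_total: int) -> None:
    """Sparse time indices must fall inside [0, n_total) on the shared axis."""
    times = np.asarray(times)
    if times.size and (times.min() < 0 or times.max() >= n_total):
        raise ValueError(f"{name}: times indices must lie in [0, {n_total})")


n_init, n_days_post_init = 7, 30
n_total = n_init + n_days_post_init
dense = np.concatenate([np.full(n_init, np.nan), np.zeros(n_days_post_init)])
check_dense("hospital", dense, n_total)                     # OK: length 37
check_sparse("wastewater", np.array([7, 20, 36]), n_total)  # OK: all indices < 37
```

Because these checks use Python `if`/`raise`, they only work on concrete arrays, which is why validation must run before inference rather than inside the traced model.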
This method must be called with concrete (non-traced) values before running inference. Validation using Python control flow (if/raise) cannot be done during JAX tracing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_days_post_init | int | Number of days to simulate after the initialization period. | required |
| subpop_fractions | ArrayLike \| None | Population fractions for all subpopulations. Shape: (n_subpops,). | None |
| obs_start_date | date \| datetime \| datetime64 \| None | Date of the first observation day. Required when any observation uses | None |
| **observation_data | dict[str, object] | Data for each observation process, keyed by observation name. Each value should be a dict of kwargs for that observation's sample(). | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If times indices are out of bounds or negative, if dense obs length doesn't match n_total, if data shapes are inconsistent, or if |
Source code in pyrenew/model/multisignal_model.py
PyrenewBuilder
PyrenewBuilder()
Builder for multi-signal renewal models.
Automatically computes n_initialization_points from observation processes and constructs a properly configured model.
The builder pattern ensures that:
1. n_initialization_points is computed correctly from all components
2. Validation happens at build time (fail-fast)
3. Latent infections are routed to the correct observations
4. The API is clean and easy to use
Population structure (subpop_fractions) is passed at sample/fit time, not at configure time. This allows a single model to be fit to multiple jurisdictions with different population structures.
Initialize a new model builder.
The builder starts empty and is configured by calling:
1. configure_latent(): set up the latent infection process
2. add_observation(): add one or more observation processes
3. build(): construct the final model
Source code in pyrenew/model/pyrenew_builder.py
add_ascertainment
add_ascertainment(ascertainment_model: AscertainmentModel) -> PyrenewBuilder
Add shared ascertainment structure to the model.
Use this method when observation probabilities are related across
signals. Independent scalar ascertainment rates do not require an
ascertainment model; those can be passed directly to an observation
process as ordinary RandomVariable objects.
A registered ascertainment model is sampled once per model execution,
before observation processes run. Observation processes receive
signal-specific accessors from ascertainment_model.for_signal(...):
```python
ascertainment = JointAscertainment(
    name="he_ascertainment",
    signals=("hospital", "ed_visits"),
    baseline_rates=...,
    scale_tril=...,
)
builder.add_ascertainment(ascertainment)
builder.add_observation(
    PopulationCounts(
        name="hospital",
        ascertainment_rate_rv=ascertainment.for_signal("hospital"),
        # ... other PopulationCounts arguments
    )
)
```
The ascertainment model's name attribute is used as the unique
identifier in the built MultiSignalModel.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ascertainment_model | AscertainmentModel | Configured ascertainment model instance, such as | required |
Returns:
| Type | Description |
|---|---|
| PyrenewBuilder | Self, for method chaining. |
Raises:
| Type | Description |
|---|---|
| TypeError | If |
| ValueError | If an ascertainment model with this name already exists. |
Source code in pyrenew/model/pyrenew_builder.py
add_observation
add_observation(obs_process: BaseObservationProcess) -> PyrenewBuilder
Add an observation process to the model.
The observation process's name attribute is used as the unique
identifier for this observation. This name is used when passing
observation data to model.sample() and model.fit(), and also
prefixes all numpyro sample sites created by the process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs_process | BaseObservationProcess | Configured observation process instance (e.g., PopulationCounts, Wastewater, SubpopulationCounts). Must have a | required |
Returns:
| Type | Description |
|---|---|
| PyrenewBuilder | Self, for method chaining |
Raises:
| Type | Description |
|---|---|
| ValueError | If an observation with this name already exists |
Source code in pyrenew/model/pyrenew_builder.py
build
build() -> MultiSignalModel
Build the multi-signal model with computed n_initialization_points.
This method:
1. Computes n_initialization_points from all components
2. Constructs the latent process with the computed value
3. Creates a MultiSignalModel with automatic infection routing
Can be called multiple times to create multiple model instances.
Returns:
| Type | Description |
|---|---|
| MultiSignalModel | Configured model ready for sampling |
Raises:
| Type | Description |
|---|---|
| ValueError | If the latent process is not configured. |
Source code in pyrenew/model/pyrenew_builder.py
compute_n_initialization_points
compute_n_initialization_points() -> int
Compute required n_initialization_points from all components.
Formula: n_initialization_points = max(all lookbacks)
Where lookbacks include:
- Generation interval length (from the latent process)
- All observation delay/shedding PMF lengths
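A minimal sketch of the max-of-lookbacks formula, assuming each lookback is simply the length of the corresponding PMF as listed above. The real builder asks each observation for its value via get_required_lookback(), which this toy function does not model; the PMFs are made-up values.

```python
def compute_n_initialization_points(gen_int_pmf, observation_pmfs) -> int:
    """Return max(all lookbacks), taking each lookback as a PMF length."""
    lookbacks = [len(gen_int_pmf)]
    lookbacks += [len(pmf) for pmf in observation_pmfs]
    return max(lookbacks)


gen_int = [0.2, 0.5, 0.3]  # toy generation interval, length 3
hosp_delay = [0.1] * 10    # toy delay PMF, length 10
ww_shedding = [0.25] * 4   # toy shedding kernel, length 4
n_init = compute_n_initialization_points(gen_int, [hosp_delay, ww_shedding])
# n_init == 10: the hospital delay has the longest lookback
```

Taking the maximum guarantees that every convolution (renewal step, delay, or shedding) has enough history on the shared axis before the first observation day.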
Useful for inspection before building the model.
Returns:
| Type | Description |
|---|---|
| int | Minimum n_initialization_points needed to satisfy all components |
Raises:
| Type | Description |
|---|---|
| ValueError | If the latent process is not configured or gen_int_rv is missing, or if any observation process doesn't implement get_required_lookback() |
Source code in pyrenew/model/pyrenew_builder.py
configure_latent
configure_latent(
latent_class: type[BaseLatentInfectionProcess], **params: dict[str, object]
) -> PyrenewBuilder
Configure the latent infection process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| latent_class | type[BaseLatentInfectionProcess] | Class to use for latent infections (e.g., PopulationInfections, SubpopulationInfections, or a custom implementation) | required |
| **params | dict[str, object] | Parameters for the latent class constructor (model structure). DO NOT include n_initialization_points; it is computed automatically from the observation processes. DO NOT include population structure params (subpop_fractions); these are passed at sample/fit time to allow fitting to multiple jurisdictions. | {} |
Returns:
| Type | Description |
|---|---|
| PyrenewBuilder | Self, for method chaining |
Raises:
| Type | Description |
|---|---|
| ValueError | If n_initialization_points or population structure params are included |
| RuntimeError | If the latent process has already been configured |
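The documented guard rails can be mimicked with a small stand-in. `ToyBuilder` and `ToyLatent` are hypothetical classes, not pyrenew's implementation, and the order in which the real method runs its two checks is an assumption.

```python
class ToyBuilder:
    """Hypothetical stand-in that mimics configure_latent()'s documented checks."""

    # Computed by the builder or supplied at sample time, so rejected here.
    FORBIDDEN_PARAMS = ("n_initialization_points", "subpop_fractions")

    def __init__(self) -> None:
        self._latent = None

    def configure_latent(self, latent_class: type, **params: object) -> "ToyBuilder":
        if self._latent is not None:
            raise RuntimeError("Latent process has already been configured")
        bad = [p for p in self.FORBIDDEN_PARAMS if p in params]
        if bad:
            raise ValueError(f"Do not pass {bad} to configure_latent()")
        # Store only; the latent process is instantiated later, at build time.
        self._latent = (latent_class, params)
        return self  # enables method chaining


class ToyLatent:  # hypothetical latent class
    pass


builder = ToyBuilder().configure_latent(ToyLatent, gen_int_rv=None)
```

Deferring instantiation is what lets the builder inject the computed n_initialization_points once all observations are known.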
Source code in pyrenew/model/pyrenew_builder.py