Model
MultiSignalModel
MultiSignalModel(
latent_process: BaseLatentInfectionProcess,
observations: dict[str, BaseObservationProcess],
)
Bases: Model
Multi-signal renewal model.
Combines a latent infection process (e.g., HierarchicalInfections, PartitionedInfections) with multiple observation processes (e.g., CountObservation, WastewaterObservation).
Typically built via PyrenewBuilder, which computes n_initialization_points correctly from all components. It can also be constructed manually for advanced use cases.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| latent_process | BaseLatentInfectionProcess | Latent infection process generating infections at jurisdiction and/or subpopulation levels | required |
| observations | dict[str, BaseObservationProcess] | Dictionary mapping names to observation process instances. Names are used when passing observation data to sample(). | required |
Notes
The model automatically routes latent infections to observations based on each observation's infection_resolution() method:
- "aggregate" → receives aggregate infections from the latent process
- "subpop" → receives all subpopulation infections; the observation selects via indices
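A minimal sketch of the routing rule above, assuming latent infections are NumPy arrays with shape (n_total,) for the aggregate series and (n_total, n_subpops) for subpopulations. The helper route_infections and these shapes are hypothetical illustrations, not the pyrenew implementation.

```python
import numpy as np


def route_infections(resolution: str, aggregate: np.ndarray, all_subpops: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return the latent infections an observation should receive."""
    if resolution == "aggregate":
        # Jurisdiction-level series, shape (n_total,)
        return aggregate
    if resolution == "subpop":
        # Every subpopulation, shape (n_total, n_subpops); the observation
        # picks out its own columns afterwards.
        return all_subpops
    raise ValueError(f"unsupported infection resolution: {resolution!r}")


# Toy latent output: 4 time points, 3 subpopulations (arbitrary values).
aggregate_infections = np.array([10.0, 12.0, 15.0, 20.0])
subpop_infections = np.arange(12.0).reshape(4, 3)

routed = route_infections("subpop", aggregate_infections, subpop_infections)
selected = routed[:, [0, 2]]  # e.g., an observation covering subpops 0 and 2 only
```

Passing every subpopulation and letting the observation index into it keeps the routing rule independent of which subpopulations a given signal covers.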
Initialize multi-signal model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| latent_process | BaseLatentInfectionProcess | Configured latent infection process | required |
| observations | dict[str, BaseObservationProcess] | Dictionary mapping observation names to observation process instances | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If validation fails (e.g., an observation requires subpopulations but the latent process doesn't support them) |
Source code in pyrenew/model/multisignal_model.py
pad_observations
Pad observations with NaN for the initialization period.
Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. This method prepends n_init NaN values to align user data (starting at day 0 of observations) with the shared axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | ArrayLike | Observations in natural coordinates (index 0 = first observation day). Integer arrays are converted to float (required for NaN). | required |
| axis | int | Axis along which to pad (typically 0 for time axis). | 0 |
Returns:
| Type | Description |
|---|---|
| ndarray | Padded observations. The first n_init values along axis are NaN. |
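The padding step can be sketched with numpy.pad. This is a hypothetical stand-in, not the library method: the documented method takes only obs and axis, while here n_init is passed explicitly as an assumption.

```python
import numpy as np


def pad_observations_sketch(obs, n_init: int, axis: int = 0) -> np.ndarray:
    """Prepend n_init NaNs along `axis` (hypothetical stand-in for pad_observations)."""
    # Cast to float first: integer arrays cannot hold NaN.
    obs = np.asarray(obs, dtype=float)
    pad_width = [(0, 0)] * obs.ndim
    pad_width[axis] = (n_init, 0)  # pad before the data, nothing after
    return np.pad(obs, pad_width, mode="constant", constant_values=np.nan)


daily_counts = [3, 5, 7]  # natural coordinates: index 0 = first observation day
padded = pad_observations_sketch(daily_counts, n_init=2)
# padded has length n_init + n_days = 5; entries 0 and 1 are NaN
```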
Source code in pyrenew/model/multisignal_model.py
sample
sample(
n_days_post_init: int,
population_size: float,
*,
subpop_fractions=None,
**observation_data,
)
Sample from the joint generative model.
This is the model function called by NumPyro during inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_days_post_init | int | Number of days to simulate after initialization period | required |
| population_size | float | Total population size. Used to convert infection proportions (from latent process) to infection counts (for observation processes). | required |
| subpop_fractions | ArrayLike | Population fractions for all subpopulations. Shape: (n_subpops,). | None |
| **observation_data | | Data for each observation process, keyed by observation name (the … | {} |
Returns:
| Type | Description |
|---|---|
| None | All quantities are recorded as NumPyro deterministic sites (… |
Source code in pyrenew/model/multisignal_model.py
shift_times
Shift time indices from natural coordinates to shared time axis.
Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. User-provided times in natural coordinates (0 = first observation day) must be shifted by n_init to align with the shared axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| times | ArrayLike | Time indices in natural coordinates (0 = first observation day). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Time indices on the shared axis [n_init, n_total). |
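The shift is a constant offset of n_init. A minimal sketch, assuming integer indices and an explicitly passed n_init (the documented method takes only times); shift_times_sketch is a hypothetical name.

```python
import numpy as np


def shift_times_sketch(times, n_init: int) -> np.ndarray:
    """Map natural-coordinate indices (0 = first observation day) onto [n_init, n_total)."""
    return np.asarray(times, dtype=int) + n_init


n_init, n_days = 7, 30
n_total = n_init + n_days  # 37
shifted = shift_times_sketch([0, 3, 29], n_init)
# shifted -> [7, 10, 36], all inside [n_init, n_total) = [7, 37)
```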
Source code in pyrenew/model/multisignal_model.py
validate
validate() -> None
Validate that observation processes are compatible with latent process.
Checks that each observation implements infection_resolution() and returns a supported resolution.
Raises:
| Type | Description |
|---|---|
| ValueError | If an observation doesn't implement infection_resolution() or returns an unsupported resolution. |
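A minimal sketch of these two checks, assuming the supported resolutions are exactly "aggregate" and "subpop" (the values listed in the class notes). validate_resolutions, ToyCounts, and ToyBroken are hypothetical; pyrenew's own check may differ in detail.

```python
SUPPORTED_RESOLUTIONS = {"aggregate", "subpop"}


def validate_resolutions(observations: dict) -> None:
    """Hypothetical check in the spirit of validate()."""
    for name, obs in observations.items():
        method = getattr(obs, "infection_resolution", None)
        if not callable(method):
            raise ValueError(f"observation {name!r} does not implement infection_resolution()")
        resolution = method()
        if resolution not in SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"observation {name!r} returned unsupported resolution {resolution!r}; "
                f"expected one of {sorted(SUPPORTED_RESOLUTIONS)}"
            )


class ToyCounts:
    def infection_resolution(self) -> str:
        return "aggregate"


class ToyBroken:
    def infection_resolution(self) -> str:
        return "county"  # not a supported resolution


validate_resolutions({"hospital": ToyCounts()})  # passes silently
```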
Source code in pyrenew/model/multisignal_model.py
validate_data
validate_data(
n_days_post_init: int,
subpop_fractions=None,
**observation_data: dict[str, Any],
) -> None
Validate observation data before running MCMC.
All observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days_post_init. Dense observations must have length n_total, with NaN padding for the initialization period. Sparse observations instead supply time indices (times) on this shared axis.
This method must be called with concrete (non-traced) values before running inference. Validation using Python control flow (if/raise) cannot be done during JAX tracing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_days_post_init | int | Number of days to simulate after initialization period | required |
| subpop_fractions | ArrayLike | Population fractions for all subpopulations. Shape: (n_subpops,). | None |
| **observation_data | dict[str, Any] | Data for each observation process, keyed by observation name. Each value should be a dict of kwargs for that observation's sample(). | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If times indices are out of bounds or negative; if a dense observation's length doesn't match n_total; or if data shapes are inconsistent |
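The checks above rely on plain if/raise control flow, which is why they must run on concrete values before inference. A minimal sketch, assuming n_init is passed explicitly (the documented method gets it from the model) and using hypothetical argument names dense and times; check_observation_data is not a pyrenew function.

```python
import numpy as np


def check_observation_data(n_init: int, n_days_post_init: int, dense=None, times=None) -> None:
    """Hypothetical concrete-value checks in the spirit of validate_data()."""
    n_total = n_init + n_days_post_init
    if dense is not None and len(dense) != n_total:
        raise ValueError(f"dense observations have length {len(dense)}, expected n_total={n_total}")
    if times is not None:
        times = np.asarray(times)
        # Sparse indices live on the shared axis [0, n_total).
        if times.size and (times.min() < 0 or times.max() >= n_total):
            raise ValueError(f"times indices must lie in [0, {n_total})")


n_init, n_days = 5, 10
dense_obs = np.concatenate([np.full(n_init, np.nan), np.ones(n_days)])  # NaN-padded, length 15
check_observation_data(n_init, n_days, dense=dense_obs, times=[5, 9, 14])  # passes
```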
Source code in pyrenew/model/multisignal_model.py
PyrenewBuilder
PyrenewBuilder()
Builder for multi-signal renewal models.
Automatically computes n_initialization_points from the latent process and all observation processes, and constructs a properly configured model.
The builder pattern ensures that:
1. n_initialization_points is computed correctly from all components
2. Validation happens at build time (fail-fast)
3. Latent infections are routed to the correct observations
4. The API is clean and easy to use
Population structure (subpop_fractions) is passed at sample/fit time, not at configure time. This allows a single model to be fit to multiple jurisdictions with different population structures.
Initialize a new model builder.
The builder starts empty and is configured by calling:
1. configure_latent(): set up the latent infection process
2. add_observation(): add one or more observation processes
3. build(): construct the final model
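To make the three-step sequence concrete, here is a toy builder that mimics the documented call order and error types. ToyBuilder, ToyLatent, and ToyObs are hypothetical stand-ins, not pyrenew classes; the sketch skips computing n_initialization_points and returns a plain dict rather than a MultiSignalModel.

```python
class ToyLatent:
    """Hypothetical latent process: just stores its constructor params."""

    def __init__(self, **params):
        self.params = params


class ToyObs:
    """Hypothetical observation process identified by its name attribute."""

    def __init__(self, name: str):
        self.name = name


class ToyBuilder:
    """Hypothetical builder mimicking configure -> add -> build."""

    def __init__(self):
        self._latent = None      # (class, params) once configured
        self._observations = {}  # name -> observation process

    def configure_latent(self, latent_class, **params):
        if self._latent is not None:
            raise RuntimeError("latent process already configured")
        self._latent = (latent_class, params)
        return self  # returning self enables method chaining

    def add_observation(self, obs_process):
        if obs_process.name in self._observations:
            raise ValueError(f"duplicate observation name: {obs_process.name!r}")
        self._observations[obs_process.name] = obs_process
        return self

    def build(self):
        if self._latent is None:
            raise ValueError("configure_latent() must be called before build()")
        latent_class, params = self._latent
        # A fresh latent instance per call, so build() can be repeated.
        return {"latent": latent_class(**params), "observations": dict(self._observations)}


builder = (
    ToyBuilder()
    .configure_latent(ToyLatent)
    .add_observation(ToyObs("hospital"))
    .add_observation(ToyObs("wastewater"))
)
model_a = builder.build()
model_b = builder.build()  # a second, independent model instance
```

Deferring construction to build() is what lets a builder inspect every component before fixing values such as the initialization length.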
Source code in pyrenew/model/pyrenew_builder.py
add_observation
add_observation(obs_process: BaseObservationProcess) -> PyrenewBuilder
Add an observation process to the model.
The observation process's name attribute is used as its unique identifier. This name is used when passing observation data to model.sample() and model.fit(), and it also prefixes all NumPyro sample sites created by the process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs_process | BaseObservationProcess | Configured observation process instance (e.g., Counts, Wastewater, CountsBySubpop). Must have a … | required |
Returns:
| Type | Description |
|---|---|
| PyrenewBuilder | Self, for method chaining |
Raises:
| Type | Description |
|---|---|
| ValueError | If an observation with this name already exists |
Source code in pyrenew/model/pyrenew_builder.py
build
build()
Build the multi-signal model with computed n_initialization_points.
This method:
1. Computes n_initialization_points from all components
2. Constructs the latent process with the computed value
3. Creates a MultiSignalModel with automatic infection routing
4. Validates that observation/latent types are compatible
Can be called multiple times to create multiple model instances.
Returns:
| Type | Description |
|---|---|
| MultiSignalModel | Configured model ready for sampling |
Raises:
| Type | Description |
|---|---|
| ValueError | If the latent process has not been configured |
Source code in pyrenew/model/pyrenew_builder.py
compute_n_initialization_points
compute_n_initialization_points() -> int
Compute required n_initialization_points from all components.
Formula: n_initialization_points = max(all lookbacks)
Where lookbacks include:
- Generation interval length (from latent process)
- All observation delay/shedding PMF lengths
Useful for inspection before building the model.
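A worked instance of the formula, as a minimal sketch. It assumes each lookback equals the length of the corresponding PMF; in pyrenew each observation reports its own value via get_required_lookback(), which may differ. compute_n_init_sketch and the PMFs below are hypothetical toy values.

```python
def compute_n_init_sketch(gen_int_pmf, observation_pmfs) -> int:
    """Hypothetical restatement of n_initialization_points = max(all lookbacks)."""
    if gen_int_pmf is None:
        raise ValueError("a generation interval is required")
    lookbacks = [len(gen_int_pmf)] + [len(pmf) for pmf in observation_pmfs]
    return max(lookbacks)


gen_int = [0.2, 0.5, 0.3]          # toy 3-day generation interval
delay_pmfs = [
    [0.1, 0.3, 0.3, 0.2, 0.1],     # toy 5-day admission delay
    [0.25] * 4,                    # toy 4-day shedding profile
]
n_init = compute_n_init_sketch(gen_int, delay_pmfs)  # max(3, 5, 4) = 5
```

Taking the maximum means the longest-reaching component dictates the initialization window, so every convolution has enough history.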
Returns:
| Type | Description |
|---|---|
| int | Minimum n_initialization_points needed to satisfy all components |
Raises:
| Type | Description |
|---|---|
| ValueError | If the latent process is not configured or gen_int_rv is missing; or if any observation process doesn't implement get_required_lookback() |
Source code in pyrenew/model/pyrenew_builder.py
configure_latent
configure_latent(
latent_class: type[BaseLatentInfectionProcess], **params
) -> PyrenewBuilder
Configure the latent infection process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| latent_class | type[BaseLatentInfectionProcess] | Class to use for latent infections (e.g., HierarchicalInfections, PartitionedInfections, or a custom implementation) | required |
| **params | | Parameters for the latent class constructor (model structure). DO NOT include n_initialization_points; it is computed automatically from all components. DO NOT include population structure params (subpop_fractions); these are passed at sample/fit time to allow fitting to multiple jurisdictions. | {} |
Returns:
| Type | Description |
|---|---|
| PyrenewBuilder | Self, for method chaining |
Raises:
| Type | Description |
|---|---|
| ValueError | If n_initialization_points or population structure params are included |
| RuntimeError | If the latent process has already been configured |
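The reserved-parameter check can be sketched as a set intersection. This assumes the only forbidden keys are n_initialization_points and subpop_fractions, the two named above; pyrenew's real check may cover more population-structure keys. check_latent_params is hypothetical.

```python
FORBIDDEN_LATENT_PARAMS = {"n_initialization_points", "subpop_fractions"}


def check_latent_params(params: dict) -> dict:
    """Hypothetical check: reject keys the builder computes or defers to sample/fit time."""
    bad = FORBIDDEN_LATENT_PARAMS.intersection(params)  # iterating a dict yields its keys
    if bad:
        raise ValueError(
            f"do not pass {sorted(bad)} to configure_latent(): "
            "n_initialization_points is computed by the builder and "
            "subpop_fractions is supplied at sample/fit time"
        )
    return params
```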
Source code in pyrenew/model/pyrenew_builder.py