
Model

MultiSignalModel

MultiSignalModel(
    latent_process: BaseLatentInfectionProcess,
    observations: dict[str, BaseObservationProcess],
)

Bases: Model

Multi-signal renewal model.

Combines a latent infection process (e.g., HierarchicalInfections, PartitionedInfections) with multiple observation processes (e.g., CountObservation, WastewaterObservation).

Typically built via PyrenewBuilder, which ensures that n_initialization_points is computed correctly from all components. It can also be constructed manually for advanced use cases.

Parameters:

Name Type Description Default
latent_process BaseLatentInfectionProcess

Latent infection process generating infections at jurisdiction and/or subpopulation levels

required
observations dict[str, BaseObservationProcess]

Dictionary mapping names to observation process instances. Names are used when passing observation data to sample().

required
Notes

The model automatically routes latent infections to observations based on each observation's infection_resolution() method:

- "aggregate" → receives aggregate infections from latent process
- "subpop" → receives all subpopulation infections; observation selects via indices
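The routing rule above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: `AggregateObs`, `SubpopObs`, and `route` are hypothetical stand-ins, and plain lists stand in for the JAX arrays the real model uses.

```python
# Hypothetical stand-ins for observation processes (not the real pyrenew classes).
class AggregateObs:
    def infection_resolution(self) -> str:
        return "aggregate"


class SubpopObs:
    def infection_resolution(self) -> str:
        return "subpop"


def route(observations: dict, inf_aggregate, inf_all) -> dict:
    """Return, per observation name, the infection array it would receive."""
    latent_map = {"aggregate": inf_aggregate, "subpop": inf_all}
    routed = {}
    for name, obs in observations.items():
        resolution = obs.infection_resolution()
        if resolution not in latent_map:
            raise ValueError(
                f"Observation '{name}' returned invalid "
                f"infection_resolution '{resolution}'."
            )
        routed[name] = latent_map[resolution]
    return routed


inf_aggregate = [10.0, 12.0, 15.0]               # one value per day
inf_all = [[6.0, 4.0], [7.0, 5.0], [9.0, 6.0]]   # days x subpopulations
routed = route(
    {"hospital": AggregateObs(), "wastewater": SubpopObs()},
    inf_aggregate,
    inf_all,
)
```

Here `routed["hospital"]` is the aggregate series and `routed["wastewater"]` is the full subpopulation array, from which a subpopulation-level observation would select its own rows or columns via indices.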

Initialize multi-signal model.

Parameters:

Name Type Description Default
latent_process BaseLatentInfectionProcess

Configured latent infection process

required
observations dict[str, BaseObservationProcess]

Dictionary mapping observation names to observation process instances

required

Raises:

Type Description
ValueError

If validation fails (e.g., observation requires subpopulations but latent process doesn't support them)

Source code in pyrenew/model/multisignal_model.py
def __init__(
    self,
    latent_process: BaseLatentInfectionProcess,
    observations: dict[str, BaseObservationProcess],
):
    """
    Initialize multi-signal model.

    Parameters
    ----------
    latent_process
        Configured latent infection process
    observations
        Dictionary mapping observation names to observation process instances

    Raises
    ------
    ValueError
        If validation fails (e.g., observation requires subpopulations
        but latent process doesn't support them)
    """
    self.latent = latent_process
    self.observations = observations
    self.validate()

pad_observations

pad_observations(obs: ndarray, axis: int = 0) -> ndarray

Pad observations with NaN for the initialization period.

Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. This method prepends n_init NaN values to align user data (starting at day 0 of observations) with the shared axis.

Parameters:

Name Type Description Default
obs ArrayLike

Observations in natural coordinates (index 0 = first observation day). Integer arrays are converted to float (required for NaN).

required
axis int

Axis along which to pad (typically 0 for time axis).

0

Returns:

Type Description
ndarray

Padded observations. First n_init values are NaN.
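As a rough illustration of the padding described above, the following sketch uses plain Python lists in place of JAX arrays and handles only the one-dimensional case. `pad_with_nan` and the value of `n_init` are assumptions for the example; the real method reads `n_init` from the latent process and pads along a chosen axis.

```python
import math


def pad_with_nan(obs: list, n_init: int) -> list:
    """Prepend n_init NaN values so obs aligns with the shared time axis."""
    # Values are converted to float because NaN has no integer representation.
    return [math.nan] * n_init + [float(x) for x in obs]


n_init = 3                # hypothetical initialization length
obs = [5, 8, 13, 21]      # index 0 = first observation day
padded = pad_with_nan(obs, n_init)
```

The result has length `n_init + len(obs)`: the first three entries are NaN and the user's data starts at index `n_init` on the shared axis.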

Source code in pyrenew/model/multisignal_model.py
def pad_observations(
    self,
    obs: jnp.ndarray,
    axis: int = 0,
) -> jnp.ndarray:
    """
    Pad observations with NaN for the initialization period.

    Observation data uses a shared time axis [0, n_total) where
    n_total = n_init + n_days. This method prepends n_init NaN values
    to align user data (starting at day 0 of observations) with the
    shared axis.

    Parameters
    ----------
    obs : ArrayLike
        Observations in natural coordinates (index 0 = first observation day).
        Integer arrays are converted to float (required for NaN).
    axis : int, default 0
        Axis along which to pad (typically 0 for time axis).

    Returns
    -------
    jnp.ndarray
        Padded observations. First n_init values are NaN.
    """
    n_init = self.latent.n_initialization_points
    obs = jnp.asarray(obs, dtype=float)
    pad_shape = list(obs.shape)
    pad_shape[axis] = n_init
    padding = jnp.full(pad_shape, jnp.nan)
    return jnp.concatenate([padding, obs], axis=axis)

sample

sample(
    n_days_post_init: int,
    population_size: float,
    *,
    subpop_fractions=None,
    **observation_data,
)

Sample from the joint generative model.

This is the model function called by NumPyro during inference.

Parameters:

Name Type Description Default
n_days_post_init int

Number of days to simulate after initialization period

required
population_size float

Total population size. Used to convert infection proportions (from latent process) to infection counts (for observation processes).

required
subpop_fractions ArrayLike

Population fractions for all subpopulations. Shape: (n_subpops,).

None
**observation_data

Data for each observation process, keyed by observation name (the name attribute of each observation process). Each value should be a dict of kwargs for that observation's sample().

{}

Returns:

Type Description
None

All quantities are recorded as NumPyro deterministic sites (latent_infections, latent_infections_by_subpop) and observation sites. Use numpyro.infer.Predictive for forward sampling.
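The two mechanics described above, scaling proportions to counts and looking up per-observation kwargs by name, can be sketched without NumPyro. Everything here is illustrative: `dispatch` is a hypothetical helper, and the `"obs"` key is an assumed kwarg name, since the actual kwargs depend on each observation process's sample() signature.

```python
def dispatch(observation_names, population_size, proportions, **observation_data):
    """Scale infection proportions to counts and collect per-observation kwargs."""
    counts = [p * population_size for p in proportions]
    # An observation with no entry in observation_data receives empty kwargs,
    # mirroring observation_data.get(name, {}) in the model.
    kwargs_by_name = {name: observation_data.get(name, {}) for name in observation_names}
    return counts, kwargs_by_name


counts, kwargs_by_name = dispatch(
    ["hospital", "wastewater"],
    1000,                          # population_size
    [0.25, 0.5],                   # infection proportions from the latent process
    hospital={"obs": [3.0, 4.0]},  # "obs" is an assumed kwarg name
)
```

Because no data was passed under the `wastewater` key, that observation gets an empty kwargs dict, which is the typical situation when sampling forward without observed data for a signal.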

Source code in pyrenew/model/multisignal_model.py
def sample(
    self,
    n_days_post_init: int,
    population_size: float,
    *,
    subpop_fractions=None,
    **observation_data,
):
    """
    Sample from the joint generative model.

    This is the model function called by NumPyro during inference.

    Parameters
    ----------
    n_days_post_init : int
        Number of days to simulate after initialization period
    population_size : float
        Total population size. Used to convert infection proportions
        (from latent process) to infection counts (for observation processes).
    subpop_fractions : ArrayLike
        Population fractions for all subpopulations. Shape: (n_subpops,).
    **observation_data
        Data for each observation process, keyed by observation name
        (the ``name`` attribute of each observation process).
        Each value should be a dict of kwargs for that observation's sample().

    Returns
    -------
    None
        All quantities are recorded as NumPyro deterministic sites
        (``latent_infections``, ``latent_infections_by_subpop``) and
        observation sites. Use ``numpyro.infer.Predictive`` for forward
        sampling.
    """
    # Generate latent infections (proportions)
    latent_sample = self.latent.sample(
        n_days_post_init=n_days_post_init,
        subpop_fractions=subpop_fractions,
    )

    # Scale from proportions to counts
    inf_aggregate = latent_sample.aggregate * population_size
    inf_all = latent_sample.all_subpops * population_size

    # Record scaled infections for posterior analysis
    numpyro.deterministic("latent_infections", inf_aggregate)
    numpyro.deterministic("latent_infections_by_subpop", inf_all)

    # Map infection resolution to infection arrays
    latent_map = {
        "aggregate": inf_aggregate,
        "subpop": inf_all,
    }

    # Apply each observation process
    for name, obs_process in self.observations.items():
        # Get the appropriate latent infections based on observation type
        resolution = obs_process.infection_resolution()
        if resolution not in latent_map:
            raise ValueError(
                f"Observation '{name}' returned invalid infection_resolution "
                f"'{resolution}'. Expected one of {self._SUPPORTED_RESOLUTIONS}."
            )
        latent_infections = latent_map[resolution]

        # Get observation-specific data
        obs_data = observation_data.get(name, {})

        # Sample from observation process
        obs_process.sample(
            infections=latent_infections,
            **obs_data,
        )

    return None

shift_times

shift_times(times: ndarray) -> ndarray

Shift time indices from natural coordinates to shared time axis.

Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. User-provided times in natural coordinates (0 = first observation day) must be shifted by n_init to align with the shared axis.

Parameters:

Name Type Description Default
times ArrayLike

Time indices in natural coordinates (0 = first observation day).

required

Returns:

Type Description
ndarray

Time indices on the shared axis [n_init, n_total).
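A small worked example of the shift, using plain Python integers instead of JAX arrays. The values of `n_init` and `n_days` are assumptions chosen for illustration, and `shift` is a hypothetical helper.

```python
def shift(times: list, n_init: int) -> list:
    """Move natural-coordinate time indices onto the shared axis."""
    return [t + n_init for t in times]


n_init = 3     # hypothetical initialization length
n_days = 10    # hypothetical number of observed days
shifted = shift([0, 4, 9], n_init)
```

Natural day 0 maps to shared index `n_init`, and the last natural day `n_days - 1` maps to `n_init + n_days - 1`, so every shifted index falls in `[n_init, n_total)`.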

Source code in pyrenew/model/multisignal_model.py
def shift_times(self, times: jnp.ndarray) -> jnp.ndarray:
    """
    Shift time indices from natural coordinates to shared time axis.

    Observation data uses a shared time axis [0, n_total) where
    n_total = n_init + n_days. User-provided times in natural coordinates
    (0 = first observation day) must be shifted by n_init to align with
    the shared axis.

    Parameters
    ----------
    times : ArrayLike
        Time indices in natural coordinates (0 = first observation day).

    Returns
    -------
    jnp.ndarray
        Time indices on the shared axis [n_init, n_total).
    """
    n_init = self.latent.n_initialization_points
    return jnp.asarray(times) + n_init

validate

validate() -> None

Validate that observation processes are compatible with latent process.

Checks that each observation implements infection_resolution() and returns a supported resolution.

Raises:

Type Description
ValueError

If an observation doesn't implement infection_resolution() or returns an unsupported resolution.

Source code in pyrenew/model/multisignal_model.py
def validate(self) -> None:
    """
    Validate that observation processes are compatible with latent process.

    Checks that each observation implements infection_resolution()
    and returns a supported resolution.

    Raises
    ------
    ValueError
        If an observation doesn't implement infection_resolution()
        or returns an unsupported resolution.
    """
    for name, obs in self.observations.items():
        if not hasattr(obs, "infection_resolution"):
            raise ValueError(
                f"Observation '{name}' must implement infection_resolution()"
            )
        resolution = obs.infection_resolution()
        if resolution not in self._SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"Observation '{name}' returned invalid infection_resolution "
                f"'{resolution}'. Expected one of {self._SUPPORTED_RESOLUTIONS}."
            )

validate_data

validate_data(
    n_days_post_init: int,
    subpop_fractions=None,
    **observation_data: dict[str, Any],
) -> None

Validate observation data before running MCMC.

All observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days_post_init. Dense observations must have length n_total with NaN padding for the initialization period. Sparse observations provide times indices on this shared axis.

This method must be called with concrete (non-traced) values before running inference. Validation using Python control flow (if/raise) cannot be done during JAX tracing.

Parameters:

Name Type Description Default
n_days_post_init int

Number of days to simulate after initialization period

required
subpop_fractions ArrayLike

Population fractions for all subpopulations. Shape: (n_subpops,).

None
**observation_data dict[str, Any]

Data for each observation process, keyed by observation name. Each value should be a dict of kwargs for that observation's sample().

{}

Raises:

Type Description
ValueError

- If times indices are out of bounds or negative
- If dense obs length doesn't match n_total
- If data shapes are inconsistent
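The per-observation checks are delegated to each observation process and are not shown in this section, so the following is only a sketch of the kind of checks the description implies. `check_dense` and `check_sparse_times` are hypothetical helpers operating on plain Python lists.

```python
def check_dense(obs: list, n_total: int) -> None:
    """Dense observations must span the whole shared axis."""
    if len(obs) != n_total:
        raise ValueError(f"Dense obs has length {len(obs)}, expected {n_total}.")


def check_sparse_times(times: list, n_total: int) -> None:
    """Sparse time indices must lie in [0, n_total)."""
    for t in times:
        if t < 0 or t >= n_total:
            raise ValueError(f"Time index {t} is outside [0, {n_total}).")


n_init, n_days_post_init = 3, 4          # hypothetical sizes
n_total = n_init + n_days_post_init

# Both calls pass: the dense series is NaN-padded to n_total, the times are in range.
check_dense([float("nan")] * n_init + [1.0, 2.0, 3.0, 4.0], n_total)
check_sparse_times([3, 6], n_total)

# An index equal to n_total is out of bounds and should be rejected.
try:
    check_sparse_times([n_total], n_total)
    out_of_bounds_rejected = False
except ValueError:
    out_of_bounds_rejected = True
```

These checks use ordinary Python control flow, which is why they need concrete values and cannot run inside a JAX-traced model function.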

Source code in pyrenew/model/multisignal_model.py
def validate_data(
    self,
    n_days_post_init: int,
    subpop_fractions=None,
    **observation_data: dict[str, Any],
) -> None:
    """
    Validate observation data before running MCMC.

    All observation data uses a shared time axis [0, n_total) where
    n_total = n_init + n_days_post_init. Dense observations must have
    length n_total with NaN padding for the initialization period.
    Sparse observations provide times indices on this shared axis.

    This method must be called with concrete (non-traced) values
    before running inference. Validation using Python control flow
    (if/raise) cannot be done during JAX tracing.

    Parameters
    ----------
    n_days_post_init : int
        Number of days to simulate after initialization period
    subpop_fractions : ArrayLike
        Population fractions for all subpopulations. Shape: (n_subpops,).
    **observation_data
        Data for each observation process, keyed by observation name.
        Each value should be a dict of kwargs for that observation's sample().

    Raises
    ------
    ValueError
        If times indices are out of bounds or negative
        If dense obs length doesn't match n_total
        If data shapes are inconsistent
    """
    pop = BaseLatentInfectionProcess._parse_and_validate_fractions(
        subpop_fractions=subpop_fractions,
    )

    n_init = self.latent.n_initialization_points
    n_total = n_init + n_days_post_init

    for name, obs_data in observation_data.items():
        if name not in self.observations:
            raise ValueError(
                f"Unknown observation '{name}'. "
                f"Available: {list(self.observations.keys())}"
            )

        self.observations[name].validate_data(
            n_total=n_total,
            n_subpops=pop.n_subpops,
            **obs_data,
        )

PyrenewBuilder

PyrenewBuilder()

Builder for multi-signal renewal models.

Automatically computes n_initialization_points from observation processes and constructs a properly configured model.

The builder pattern ensures that:

1. n_initialization_points is computed correctly from all components
2. Validation happens at build time (fail-fast)
3. Latent infections are routed to the correct observations
4. The API is clean and easy to use

Population structure (subpop_fractions) is passed at sample/fit time, not at configure time. This allows a single model to be fit to multiple jurisdictions with different population structures.

Initialize a new model builder.

The builder starts empty and is configured by calling:

1. configure_latent() - set up the latent infection process
2. add_observation() - add one or more observation processes
3. build() - construct the final model
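The call order can be illustrated with a stripped-down builder. This is a sketch under stated assumptions, not the real class: `MiniBuilder`, `StubLatent`, and `StubObservation` are hypothetical stand-ins, and `build()` here returns the latent stub directly rather than a MultiSignalModel.

```python
class StubLatent:
    """Stand-in latent process that records how it was constructed."""

    def __init__(self, gen_int_rv, n_initialization_points):
        self.gen_int_rv = gen_int_rv
        self.n_initialization_points = n_initialization_points


class StubObservation:
    """Stand-in observation process with a name and a fixed lookback."""

    def __init__(self, name, lookback):
        self.name = name
        self._lookback = lookback

    def lookback_days(self):
        return self._lookback


class MiniBuilder:
    def __init__(self):
        self.latent_class = None
        self.latent_params = {}
        self.observations = {}

    def configure_latent(self, latent_class, **params):
        self.latent_class = latent_class
        self.latent_params = params
        return self  # returning self enables method chaining

    def add_observation(self, obs):
        if obs.name in self.observations:
            raise ValueError(f"Observation '{obs.name}' already added.")
        self.observations[obs.name] = obs
        return self

    def build(self):
        if self.latent_class is None:
            raise ValueError("Must call configure_latent() before build()")
        lookbacks = [len(self.latent_params["gen_int_rv"]())]
        lookbacks += [o.lookback_days() for o in self.observations.values()]
        return self.latent_class(
            **self.latent_params, n_initialization_points=max(lookbacks)
        )


latent = (
    MiniBuilder()
    .configure_latent(StubLatent, gen_int_rv=lambda: [0.2, 0.5, 0.3])
    .add_observation(StubObservation("hospital", lookback=14))
    .add_observation(StubObservation("wastewater", lookback=21))
    .build()
)

# Adding two observations with the same name should fail.
try:
    MiniBuilder().add_observation(StubObservation("hospital", 1)).add_observation(
        StubObservation("hospital", 2)
    )
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```

The point of the pattern is that the latent process is only constructed inside `build()`, once every observation's lookback is known, so the caller never has to supply n_initialization_points by hand.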

Source code in pyrenew/model/pyrenew_builder.py
def __init__(self):
    """
    Initialize a new model builder.

    The builder starts empty and is configured by calling:
    1. configure_latent() - set up the latent infection process
    2. add_observation() - add one or more observation processes
    3. build() - construct the final model
    """
    self.latent_class: type[BaseLatentInfectionProcess] | None = None
    self.latent_params: dict[str, Any] = {}
    self.observations: dict[str, BaseObservationProcess] = {}

add_observation

add_observation(obs_process: BaseObservationProcess) -> PyrenewBuilder

Add an observation process to the model.

The observation process's name attribute is used as the unique identifier for this observation. This name is used when passing observation data to model.sample() and model.fit(), and also prefixes all numpyro sample sites created by the process.

Parameters:

Name Type Description Default
obs_process BaseObservationProcess

Configured observation process instance (e.g., Counts, Wastewater, CountsBySubpop). Must have a name attribute.

required

Returns:

Type Description
PyrenewBuilder

Self, for method chaining

Raises:

Type Description
ValueError

If an observation with this name already exists

Source code in pyrenew/model/pyrenew_builder.py
def add_observation(
    self,
    obs_process: BaseObservationProcess,
) -> PyrenewBuilder:
    """
    Add an observation process to the model.

    The observation process's ``name`` attribute is used as the unique
    identifier for this observation. This name is used when passing
    observation data to ``model.sample()`` and ``model.fit()``, and also
    prefixes all numpyro sample sites created by the process.

    Parameters
    ----------
    obs_process : BaseObservationProcess
        Configured observation process instance (e.g., Counts,
        Wastewater, CountsBySubpop). Must have a ``name`` attribute.

    Returns
    -------
    PyrenewBuilder
        Self, for method chaining

    Raises
    ------
    ValueError
        If an observation with this name already exists
    """
    name = obs_process.name
    if name in self.observations:
        raise ValueError(
            f"Observation '{name}' already added. "
            f"Each observation must have a unique name."
        )

    self.observations[name] = obs_process
    return self

build

build()

Build the multi-signal model with computed n_initialization_points.

This method:

1. Computes n_initialization_points from all components
2. Constructs the latent process with the computed value
3. Creates a MultiSignalModel with automatic infection routing
4. Validates that observation/latent types are compatible

Can be called multiple times to create multiple model instances.

Returns:

Type Description
MultiSignalModel

Configured model ready for sampling

Raises:

Type Description
ValueError

If latent process not configured

Source code in pyrenew/model/pyrenew_builder.py
def build(self):
    """
    Build the multi-signal model with computed n_initialization_points.

    This method:
    1. Computes n_initialization_points from all components
    2. Constructs the latent process with the computed value
    3. Creates a MultiSignalModel with automatic infection routing
    4. Validates that observation/latent types are compatible

    Can be called multiple times to create multiple model instances.

    Returns
    -------
    MultiSignalModel
        Configured model ready for sampling

    Raises
    ------
    ValueError
        If latent process not configured
    """
    if self.latent_class is None:
        raise ValueError("Must call configure_latent() before build()")

    # Compute n_initialization_points
    n_init = self.compute_n_initialization_points()

    # Construct latent process with computed n_initialization_points
    latent_params = {**self.latent_params, "n_initialization_points": n_init}

    latent_process = self.latent_class(**latent_params)

    # Build model
    model = MultiSignalModel(
        latent_process=latent_process,
        observations=self.observations,
    )

    return model

compute_n_initialization_points

compute_n_initialization_points() -> int

Compute required n_initialization_points from all components.

Formula: n_initialization_points = max(all lookbacks)

Where lookbacks include:

- Generation interval length (from latent process)
- All observation delay/shedding PMF lengths

Useful for inspection before building the model.

Returns:

Type Description
int

Minimum n_initialization_points needed to satisfy all components

Raises:

Type Description
ValueError

- If latent process not configured or gen_int_rv missing
- If any observation process doesn't implement lookback_days()
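A worked example of the formula, with hypothetical numbers. The generation interval PMF and the two observation lookbacks below are assumptions for illustration; the second half checks that the resulting n_init leaves enough history for the renewal equation on the first post-initialization day.

```python
gen_int_pmf = [0.1, 0.4, 0.3, 0.2]          # hypothetical PMF covering lags 1..4 days
lookbacks = [len(gen_int_pmf), 10, 25]      # 10 and 25: hypothetical observation lookbacks
n_init = max(lookbacks)

# On the first post-initialization day (index n_init on the shared axis),
# the renewal equation reaches back over lags 1..L, where L = len(gen_int_pmf).
first_day = n_init
needed = [first_day - lag for lag in range(1, len(gen_int_pmf) + 1)]
```

Every index in `needed` is non-negative, so the initialization window covers the generation interval; the larger observation lookback is what actually determines n_init in this example.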

Source code in pyrenew/model/pyrenew_builder.py
def compute_n_initialization_points(self) -> int:
    """
    Compute required n_initialization_points from all components.

    Formula: n_initialization_points = max(all lookbacks)

    Where lookbacks include:
    - Generation interval length (from latent process)
    - All observation delay/shedding PMF lengths

    Useful for inspection before building the model.

    Returns
    -------
    int
        Minimum n_initialization_points needed to satisfy all components

    Raises
    ------
    ValueError
        If latent process not configured or gen_int_rv missing
        If any observation process doesn't implement lookback_days()
    """
    if self.latent_class is None:
        raise ValueError(
            "Must call configure_latent() before computing n_initialization_points"
        )

    # Get generation interval length from latent params
    gen_int_rv = self.latent_params.get("gen_int_rv")
    if gen_int_rv is None:
        raise ValueError("gen_int_rv is required in latent process parameters")

    # Start with generation interval lookback
    lookbacks = [len(gen_int_rv())]

    # Add lookback from each observation process
    for obs_process in self.observations.values():
        lookbacks.append(obs_process.lookback_days())

    # Formula: max(all lookbacks)
    # For generation interval (1-indexed): L-element PMF has max lag L days → need L init points
    # For delay distributions (0-indexed): L-element PMF has max delay L-1 days
    # We need at least max(all lookbacks) to satisfy the renewal equation extraction
    n_init = max(lookbacks)

    return n_init

configure_latent

configure_latent(
    latent_class: type[BaseLatentInfectionProcess], **params
) -> PyrenewBuilder

Configure the latent infection process.

Parameters:

Name Type Description Default
latent_class Type[BaseLatentInfectionProcess]

Class to use for latent infections (e.g., HierarchicalInfections, PartitionedInfections, or a custom implementation)

required
**params

Parameters for latent class constructor (model structure). DO NOT include n_initialization_points - it will be computed automatically from observation processes. DO NOT include population structure params (subpop_fractions) - these are passed at sample/fit time to allow fitting to multiple jurisdictions.

{}

Returns:

Type Description
PyrenewBuilder

Self, for method chaining

Raises:

Type Description
ValueError

If n_initialization_points or population structure params are included

RuntimeError

If latent has already been configured
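The parameter checks can be sketched as follows. The contents of the module's `_SAMPLE_TIME_PARAMS` set are not shown in this section, so `SAMPLE_TIME_PARAMS` below is an assumed value containing only `subpop_fractions`; `check_latent_params` and `is_rejected` are hypothetical helpers.

```python
# Assumed contents; the real _SAMPLE_TIME_PARAMS set is not shown here.
SAMPLE_TIME_PARAMS = frozenset({"subpop_fractions"})


def check_latent_params(params: dict) -> None:
    """Reject parameters that must not be supplied at configure time."""
    if "n_initialization_points" in params:
        raise ValueError("n_initialization_points is computed automatically.")
    found = SAMPLE_TIME_PARAMS.intersection(params.keys())
    if found:
        raise ValueError(f"Do not specify {sorted(found)} at configure time.")


def is_rejected(params: dict) -> bool:
    try:
        check_latent_params(params)
    except ValueError:
        return True
    return False


results = {
    "structure only": is_rejected({"gen_int_rv": None}),
    "with n_initialization_points": is_rejected({"n_initialization_points": 7}),
    "with subpop_fractions": is_rejected({"subpop_fractions": [0.6, 0.4]}),
}
```

Keeping population structure out of the configured parameters is what lets one built model be reused across jurisdictions: only the arguments passed at sample/fit time change.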

Source code in pyrenew/model/pyrenew_builder.py
def configure_latent(
    self,
    latent_class: type[BaseLatentInfectionProcess],
    **params,
) -> PyrenewBuilder:
    """
    Configure the latent infection process.

    Parameters
    ----------
    latent_class : Type[BaseLatentInfectionProcess]
        Class to use for latent infections (e.g., HierarchicalInfections,
        PartitionedInfections, or a custom implementation)
    **params
        Parameters for latent class constructor (model structure).
        DO NOT include n_initialization_points - it will be computed
        automatically from observation processes.
        DO NOT include population structure params (subpop_fractions) -
        these are passed at sample/fit time to allow fitting to multiple
        jurisdictions.

    Returns
    -------
    PyrenewBuilder
        Self, for method chaining

    Raises
    ------
    ValueError
        If n_initialization_points or population structure params are included
    RuntimeError
        If latent has already been configured
    """
    if self.latent_class is not None:
        raise RuntimeError(
            "Latent process already configured. Create a new builder for "
            "a different configuration."
        )

    if "n_initialization_points" in params:
        raise ValueError(
            "Do not specify n_initialization_points - it will be computed "
            "automatically from observation processes and generation interval. "
            "Use PyrenewBuilder.build() to create the model with the correct value."
        )

    # Check for population structure params that should be at sample time
    sample_time_found = _SAMPLE_TIME_PARAMS.intersection(params.keys())
    if sample_time_found:
        raise ValueError(
            f"Do not specify {sample_time_found} at configure time. "
            f"Population structure is passed at sample/fit time to allow "
            f"fitting the same model to multiple jurisdictions."
        )

    self.latent_class = latent_class
    self.latent_params = params
    return self