Skip to content

Model

HospitalAdmissionsModel

HospitalAdmissionsModel(
    latent_hosp_admissions_rv: RandomVariable,
    latent_infections_rv: RandomVariable,
    gen_int_rv: RandomVariable,
    I0_rv: RandomVariable,
    Rt_process_rv: RandomVariable,
    hosp_admission_obs_process_rv: RandomVariable,
)

Bases: Model

Hospital Admissions Model (BasicRenewal + HospitalAdmissions)

This class inherits from pyrenew.models.Model. It extends the basic renewal model by adding a hospital admissions module, e.g., pyrenew.observations.HospitalAdmissions.

Default constructor

Parameters:

Name Type Description Default
latent_hosp_admissions_rv RandomVariable

Latent process for the hospital admissions.

required
latent_infections_rv RandomVariable

The infections latent process (passed to RtInfectionsRenewalModel).

required
gen_int_rv RandomVariable

Generation time (passed to RtInfectionsRenewalModel)

required
I0_rv RandomVariable

Initial infections (passed to RtInfectionsRenewalModel)

required
Rt_process_rv RandomVariable

Rt process (passed to RtInfectionsRenewalModel).

required
hosp_admission_obs_process_rv RandomVariable

Observation process for the hospital admissions.

required

Returns:

Type Description
None
Source code in pyrenew/model/admissionsmodel.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def __init__(
    self,
    latent_hosp_admissions_rv: RandomVariable,
    latent_infections_rv: RandomVariable,
    gen_int_rv: RandomVariable,
    I0_rv: RandomVariable,
    Rt_process_rv: RandomVariable,
    hosp_admission_obs_process_rv: RandomVariable | None,
) -> None:  # numpydoc ignore=PR04
    """
    Default constructor

    Parameters
    ----------
    latent_hosp_admissions_rv
        Latent process for the hospital admissions.
    latent_infections_rv
        The infections latent process (passed to RtInfectionsRenewalModel).
    gen_int_rv
        Generation time (passed to RtInfectionsRenewalModel).
    I0_rv
        Initial infections (passed to RtInfectionsRenewalModel).
    Rt_process_rv
        Rt process (passed to RtInfectionsRenewalModel).
    hosp_admission_obs_process_rv
        Observation process for the hospital admissions. If None, a
        NullObservation is stored instead.

    Returns
    -------
    None
    """
    # Hospital admissions are the only observed quantity in this model
    # (sample() never hands observed infections to the renewal model), so
    # no infection observation process is supplied here;
    # RtInfectionsRenewalModel substitutes a NullObservation for None.
    self.basic_renewal = RtInfectionsRenewalModel(
        gen_int_rv=gen_int_rv,
        I0_rv=I0_rv,
        latent_infections_rv=latent_infections_rv,
        infection_obs_process_rv=None,
        Rt_process_rv=Rt_process_rv,
    )

    HospitalAdmissionsModel.validate(
        latent_hosp_admissions_rv, hosp_admission_obs_process_rv
    )

    self.latent_hosp_admissions_rv = latent_hosp_admissions_rv
    # A missing observation process becomes a NullObservation so that
    # sample() can always call it.
    if hosp_admission_obs_process_rv is None:
        hosp_admission_obs_process_rv = NullObservation()

    self.hosp_admission_obs_process_rv = hosp_admission_obs_process_rv

sample

sample(
    n_datapoints: int | None = None,
    data_observed_hosp_admissions: ArrayLike | None = None,
    padding: int = 0,
    **kwargs,
) -> HospModelSample

Sample from the HospitalAdmissions model

Parameters:

Name Type Description Default
n_datapoints int | None

Number of timepoints to sample (passed to the basic renewal model).

None
data_observed_hosp_admissions ArrayLike | None

The observed hospital admissions data (passed as obs to the hospital admissions observation process; the basic renewal model receives no observed data). Defaults to None (simulation, rather than fit).

None
padding int

Number of padding timepoints to add to the beginning of the simulation. Defaults to 0.

0
**kwargs

Additional keyword arguments passed through to internal sample() calls, should there be any.

{}

Returns:

Type Description
HospModelSample
See Also

basic_renewal.sample : For sampling the basic renewal model
sample_observed_admissions : For sampling observed hospital admissions

Source code in pyrenew/model/admissionsmodel.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def sample(
    self,
    n_datapoints: int | None = None,
    data_observed_hosp_admissions: ArrayLike | None = None,
    padding: int = 0,
    **kwargs,
) -> HospModelSample:
    """
    Sample from the HospitalAdmissions model

    Parameters
    ----------
    n_datapoints
        Number of timepoints to sample (passed to the basic renewal
        model). Mutually exclusive with ``data_observed_hosp_admissions``.
    data_observed_hosp_admissions
        The observed hospital admissions data, passed as ``obs`` to the
        hospital admissions observation process (the basic renewal model
        receives no observed data). When given, its length sets the
        number of datapoints. Defaults to None (simulation, rather than
        fit).
    padding
        Number of padding timepoints to add to the beginning of the
        simulation. Defaults to 0.
    **kwargs
        Additional keyword arguments passed through to internal sample()
        calls, should there be any.

    Returns
    -------
    HospModelSample

    Raises
    ------
    ValueError
        If neither or both of ``n_datapoints`` and
        ``data_observed_hosp_admissions`` are given.

    See Also
    --------
    basic_renewal.sample : For sampling the basic renewal model
    sample_observed_admissions : For sampling observed hospital admissions
    """
    if n_datapoints is None and data_observed_hosp_admissions is None:
        raise ValueError(
            "Either n_datapoints or data_observed_hosp_admissions must be passed."
        )
    if n_datapoints is not None and data_observed_hosp_admissions is not None:
        raise ValueError(
            "Cannot pass both n_datapoints and data_observed_hosp_admissions."
        )
    if n_datapoints is None:
        n_datapoints = len(data_observed_hosp_admissions)

    # Getting the initial quantities from the basic model. Only hospital
    # admissions are observed, so no infection data is passed down.
    basic_model = self.basic_renewal.sample(
        n_datapoints=n_datapoints,
        data_observed_infections=None,
        padding=padding,
        **kwargs,
    )

    # Sampling the latent hospital admissions
    (
        infection_hosp_rate,
        latent_hosp_admissions,
        *_,
    ) = self.latent_hosp_admissions_rv(
        latent_infections=basic_model.latent_infections,
        **kwargs,
    )

    # The observations are matched against the last n_datapoints latent
    # values. Slice from an explicit, non-negative start index: the
    # previous ``x[-n_datapoints:]`` returned the *whole* array when
    # n_datapoints == 0 (because -0 == 0) instead of an empty one.
    n_latent = len(latent_hosp_admissions)
    window_start = max(n_latent - n_datapoints, 0)
    observed_hosp_admissions = self.hosp_admission_obs_process_rv(
        mu=latent_hosp_admissions[window_start:],
        obs=data_observed_hosp_admissions,
        **kwargs,
    )

    return HospModelSample(
        Rt=basic_model.Rt,
        latent_infections=basic_model.latent_infections,
        infection_hosp_rate=infection_hosp_rate,
        latent_hosp_admissions=latent_hosp_admissions,
        observed_hosp_admissions=observed_hosp_admissions,
    )

validate staticmethod

validate(latent_hosp_admissions_rv, hosp_admission_obs_process_rv) -> None

Verifies types and status (RV) of latent and observed hospital admissions

Parameters:

Name Type Description Default
latent_hosp_admissions_rv

The latent process for the hospital admissions.

required
hosp_admission_obs_process_rv

The observed hospital admissions.

required

Returns:

Type Description
None
Source code in pyrenew/model/admissionsmodel.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@staticmethod
def validate(latent_hosp_admissions_rv, hosp_admission_obs_process_rv) -> None:
    """
    Check that the hospital admissions components are RandomVariables.

    Parameters
    ----------
    latent_hosp_admissions_rv
        The latent process for the hospital admissions; must be a
        RandomVariable.
    hosp_admission_obs_process_rv
        The observed hospital admissions process; must be a
        RandomVariable or None.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If either argument fails its type check.
    """
    assert isinstance(latent_hosp_admissions_rv, RandomVariable)
    # None is tolerated here; the isinstance check only runs otherwise.
    assert hosp_admission_obs_process_rv is None or isinstance(
        hosp_admission_obs_process_rv, RandomVariable
    )
    return None

RtInfectionsRenewalModel

RtInfectionsRenewalModel(
    latent_infections_rv: RandomVariable,
    gen_int_rv: RandomVariable,
    I0_rv: RandomVariable,
    Rt_process_rv: RandomVariable,
    infection_obs_process_rv: RandomVariable = None,
)

Bases: Model

Basic Renewal Model (Infections + Rt)

The basic renewal model consists of a two-step sampler: sample Rt, then use it to sample the infections.

Default constructor

Parameters:

Name Type Description Default
latent_infections_rv RandomVariable

Infections latent process (e.g., pyrenew.latent.Infections).

required
gen_int_rv RandomVariable

The generation interval.

required
I0_rv RandomVariable

The initial infections.

required
Rt_process_rv RandomVariable

The sample function of the process should return a tuple where the first element is the drawn Rt.

required
infection_obs_process_rv RandomVariable

Infections observation process (e.g., pyrenew.observations.Poisson).

None

Returns:

Type Description
None
Source code in pyrenew/model/rtinfectionsrenewalmodel.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def __init__(
    self,
    latent_infections_rv: RandomVariable,
    gen_int_rv: RandomVariable,
    I0_rv: RandomVariable,
    Rt_process_rv: RandomVariable,
    infection_obs_process_rv: RandomVariable | None = None,
) -> None:
    """
    Default constructor

    Parameters
    ----------
    latent_infections_rv
        Infections latent process (e.g.,
        pyrenew.latent.Infections).
    gen_int_rv
        The generation interval.
    I0_rv
        The initial infections.
    Rt_process_rv
        The sample function of the process should return a tuple where the
        first element is the drawn Rt.
    infection_obs_process_rv
        Infections observation process (e.g.,
        pyrenew.observations.Poisson). Defaults to None, in which case a
        NullObservation is used.

    Returns
    -------
    None
    """
    # Replace a missing observation process with a NullObservation so that
    # sample() can call it unconditionally and validate() sees a
    # RandomVariable.
    if infection_obs_process_rv is None:
        infection_obs_process_rv = NullObservation()

    RtInfectionsRenewalModel.validate(
        gen_int_rv=gen_int_rv,
        I0_rv=I0_rv,
        latent_infections_rv=latent_infections_rv,
        infection_obs_process_rv=infection_obs_process_rv,
        Rt_process_rv=Rt_process_rv,
    )

    self.gen_int_rv = gen_int_rv
    self.I0_rv = I0_rv
    self.latent_infections_rv = latent_infections_rv
    self.infection_obs_process_rv = infection_obs_process_rv
    self.Rt_process_rv = Rt_process_rv

sample

sample(
    n_datapoints: int | None = None,
    data_observed_infections: ArrayLike | None = None,
    padding: int = 0,
    **kwargs,
) -> RtInfectionsRenewalSample

Sample from the Basic Renewal Model

Parameters:

Name Type Description Default
n_datapoints int | None

Number of timepoints to sample.

None
data_observed_infections ArrayLike | None

Observed infections. Defaults to None.

None
padding int

Number of padding timepoints to add to the beginning of the simulation. Defaults to 0.

0
**kwargs

Additional keyword arguments passed through to internal sample() calls, if any

{}
Notes

Either data_observed_infections or n_datapoints must be specified, not both.

Returns:

Type Description
RtInfectionsRenewalSample
Source code in pyrenew/model/rtinfectionsrenewalmodel.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def sample(
    self,
    n_datapoints: int | None = None,
    data_observed_infections: ArrayLike | None = None,
    padding: int = 0,
    **kwargs,
) -> RtInfectionsRenewalSample:
    """
    Draw Rt, latent infections and (optionally observed) infections from
    the basic renewal model.

    Parameters
    ----------
    n_datapoints
        Number of timepoints to sample.
    data_observed_infections
        Observed infections; its length sets the number of datapoints.
        Defaults to None.
    padding
        Number of padding timepoints prepended to the simulation.
        Defaults to 0.
    **kwargs
        Extra keyword arguments forwarded to every internal sample() call.

    Notes
    -----
    Exactly one of `data_observed_infections` and `n_datapoints` must be
    given.

    Returns
    -------
    RtInfectionsRenewalSample

    Raises
    ------
    ValueError
        If neither or both of `n_datapoints` and
        `data_observed_infections` are given.
    """
    has_count = n_datapoints is not None
    has_data = data_observed_infections is not None

    if not (has_count or has_data):
        raise ValueError(
            "Either n_datapoints or data_observed_infections must be passed."
        )
    if has_count and has_data:
        raise ValueError(
            "Cannot pass both n_datapoints and data_observed_infections."
        )

    n_observed = n_datapoints if has_count else len(data_observed_infections)
    n_timepoints = n_observed + padding

    # Rt for the full (padded) horizon; what is drawn depends on the
    # Rt_process RandomVariable supplied at construction.
    Rt = self.Rt_process_rv(n=n_timepoints, **kwargs)

    gen_int = self.gen_int_rv(**kwargs)
    I0 = self.I0_rv(**kwargs)

    latent_draw = self.latent_infections_rv(
        Rt=Rt,
        gen_int=gen_int,
        I0=I0,
        **kwargs,
    )
    post_init_infections = latent_draw.post_initialization_infections

    # The padding period is excluded from the observation window.
    observed_infections = self.infection_obs_process_rv(
        mu=post_init_infections[padding:],
        obs=data_observed_infections,
        **kwargs,
    )

    all_latent_infections = jnp.hstack([I0, post_init_infections])
    numpyro.deterministic("all_latent_infections", all_latent_infections)
    numpyro.deterministic("Rt", Rt)

    return RtInfectionsRenewalSample(
        Rt=Rt,
        latent_infections=all_latent_infections,
        observed_infections=observed_infections,
    )

validate staticmethod

validate(
    gen_int_rv: any,
    I0_rv: any,
    latent_infections_rv: any,
    infection_obs_process_rv: any,
    Rt_process_rv: any,
) -> None

Verifies types and status (RV) of the generation interval, initial infections, latent and observed infections, and the Rt process.

Parameters:

Name Type Description Default
gen_int_rv any

The generation interval. Expects RandomVariable.

required
I0_rv any

The initial infections. Expects RandomVariable.

required
latent_infections_rv any

Infections latent process. Expects RandomVariable.

required
infection_obs_process_rv any

Infections observation process. Expects RandomVariable.

required
Rt_process_rv any

The sample function of the process should return a tuple where the first element is the drawn Rt. Expects RandomVariable.

required

Returns:

Type Description
None
Source code in pyrenew/model/rtinfectionsrenewalmodel.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
@staticmethod
def validate(
    gen_int_rv: object,
    I0_rv: object,
    latent_infections_rv: object,
    infection_obs_process_rv: object,
    Rt_process_rv: object,
) -> None:
    """
    Verifies types and status (RV) of the generation interval, initial
    infections, latent and observed infections, and the Rt process.

    Parameters
    ----------
    gen_int_rv
        The generation interval. Expects RandomVariable.
    I0_rv
        The initial infections. Expects RandomVariable.
    latent_infections_rv
        Infections latent process. Expects RandomVariable.
    infection_obs_process_rv
        Infections observation process. Expects RandomVariable.
    Rt_process_rv
        The sample function of the process should return a tuple where the
        first element is the drawn Rt. Expects RandomVariable.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If any argument is not a RandomVariable. Arguments are checked in
        the order listed above, and the first failure raises.
    """
    # Annotations use `object` ("anything; narrowed below") rather than the
    # builtin function `any`, which is not a valid type annotation.
    for component in (
        gen_int_rv,
        I0_rv,
        latent_infections_rv,
        infection_obs_process_rv,
        Rt_process_rv,
    ):
        assert isinstance(component, RandomVariable)
    return None