========================================
Fitting a hospital admissions-only model
========================================

This document illustrates how to fit a hospital admissions-only model to
data from the Pyrenew package, specifically the wastewater dataset. The CFA
wastewater team created this dataset, which contains simulated data.

We begin by loading ``numpyro`` and setting the number of host devices to 2
so that MCMC chains can run in parallel. By default, XLA (which JAX uses for
compilation) treats all CPU cores as a single device. Depending on your
system's configuration, we recommend using numpyro's
``set_host_device_count()`` function to set the number of devices available
for parallel computation. This call only takes effect at the beginning of
the program.

.. container:: cell
   :name: numpyro-setup

   .. code:: python

      import numpyro

      numpyro.set_host_device_count(2)

Model definition
================

In this section, we provide the formal definition of the model. The hospital
admissions model is a semi-mechanistic model that describes the number of
observed hospital admissions as a function of a set of latent variables. In
particular, the observed number of hospital admissions follows a discrete
distribution centered on the number of latent hospital admissions:

.. math::

   h(t) \sim \text{HospDist}\left(H(t)\right)

where :math:`h(t)` is the observed number of hospital admissions at time
:math:`t`, and :math:`H(t)` is the number of latent hospital admissions at
time :math:`t`. The distribution :math:`\text{HospDist}` is discrete. For
this example, we use a negative binomial distribution with an inferred
concentration :math:`k`:

.. math::

   \begin{align*}
   h(t) & \sim \mathrm{NegativeBinomial}\left(\mathrm{mean} = H(t), \mathrm{concentration} = k\right) \\
   H(t) & = p_\mathrm{hosp} \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau) \\
   \log[p_\mathrm{hosp}] & \sim \mathrm{Normal}(\mu=\log(0.05), \sigma=\log(1.1)) \\
   k^{-1/2} & \sim \mathrm{TruncatedNormal}(\text{loc}=0, \text{scale}=1, \text{min}=0.01)
   \end{align*}

where :math:`d(\tau)` is the infection-to-hospital-admission interval
distribution with maximum delay :math:`T_d`, :math:`I(t)` is the number of
latent infections at time :math:`t`, and :math:`p_\mathrm{hosp}` is the
infection-to-admission rate. The number of latent hospital admissions at
time :math:`t` is therefore the infection-to-admission rate multiplied by a
weighted sum of the latent infections on days :math:`t - T_d, \dots, t`,
with weights given by the delay distribution.

The latent infections are modeled as a renewal process, where :math:`g` is
the generation interval distribution:

.. math::

   \begin{align*}
   I(t) &= \mathcal{R}(t) \times \sum_{\tau < t} I(\tau) g(t - \tau) \\
   \log[I(0)] &\sim \text{Normal}(\mu=\log(100), \sigma=\log(1.75))
   \end{align*}

The reproduction number :math:`\mathcal{R}(t)` is modeled as a random walk
in logarithmic space, i.e.:

.. math::

   \begin{align*}
   \log[\mathcal{R}(t)] & = \log[\mathcal{R}(t-1)] + \epsilon \\
   \epsilon & \sim \text{Normal}(\mu=0, \sigma=0.025) \\
   \log[\mathcal{R}(0)] & \sim \text{Normal}(\mu=0, \sigma=0.2)
   \end{align*}

Data processing
===============

We start by loading the data and inspecting the first five rows.

.. container:: cell

   .. code:: python

      import polars as pl
      from pyrenew import datasets

      dat = datasets.load_wastewater()
      dat.head(5)

   .. container:: cell-output cell-output-display
      :name: data-inspect

      shape: (5, 14)

      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
      t   lab_wwtp_unique_id log_conc date       lod_sewage below_lod daily_hosp_admits daily_hosp_admits_for_eval pop forecast_date hosp_calibration_time site ww_pop   inf_per_capita
      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
      i64 i64                f64      date       f64        i64       i64               i64                        f64 date          i64                   i64  f64      f64
      1   1                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    1    400000.0 0.000663
      1   2                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    1    400000.0 0.000663
      1   3                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    2    200000.0 0.000663
      1   4                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    3    100000.0 0.000663
      1   5                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    4    50000.0  0.000663
      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============

The data contain one entry per site per day, but because of the way the data
were simulated, the number of hospital admissions is the same across sites.
We therefore keep only the first observation for each date.

.. container:: cell

   .. code:: python

      # Keep the first observation of each date
      dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"])

      # group_by does not guarantee row order, so sort by date
      dat = dat.sort("date")

      # Keep the first 90 days
      dat = dat.head(90)

      dat.head(5)

   .. container:: cell-output cell-output-display
      :name: aggregation

      shape: (5, 2)

      ========== =================
      date       daily_hosp_admits
      ========== =================
      date       i64
      2023-10-30 6
      2023-10-31 8
      2023-11-01 4
      2023-11-02 8
      2023-11-03 4
      ========== =================
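
Before building the model with Pyrenew, it may help to see the generative
process from the model definition written out directly. The following is a
minimal NumPy sketch, not Pyrenew's implementation. The hypothetical helper
``sample_rt`` draws :math:`\mathcal{R}(t)` as a random walk on the log scale,
using the same scales as ``MyRt`` below. The hypothetical helper
``simulate_admissions`` applies the renewal equation and then the
infection-to-admission delay to obtain :math:`H(t)`. The sketch omits the
negative binomial observation step. It assumes that ``gen_int[0]`` is the
probability of a one-day generation interval and that ``inf_hosp_int[0]`` is
the probability of a zero-day delay. The PMFs in the usage lines are made up.

.. code:: python

   import numpy as np


   def sample_rt(n, rng, init_sd=0.2, step_sd=0.025):
       """Draw R(t) for n days as a Gaussian random walk on the log scale."""
       log_r0 = rng.normal(0.0, init_sd)
       steps = rng.normal(0.0, step_sd, size=n - 1)
       # log R(t) = log R(t-1) + eps, i.e. a cumulative sum of the steps
       log_rt = log_r0 + np.concatenate([[0.0], np.cumsum(steps)])
       return np.exp(log_rt)


   def simulate_admissions(rt, gen_int, inf_hosp_int, i0_history, p_hosp):
       """Renewal equation for I(t), then delay convolution for H(t)."""
       infections = list(i0_history)
       for r in rt:
           # Most recent infections first: I(t-1), I(t-2), ...
           recent = infections[::-1][: len(gen_int)]
           # I(t) = R(t) * sum_s I(t - s) g(s)
           infections.append(r * float(np.dot(recent, gen_int[: len(recent)])))
       infections = np.asarray(infections)
       # H(t) = p_hosp * sum_tau d(tau) I(t - tau), truncated to be causal
       hosp = p_hosp * np.convolve(infections, inf_hosp_int)[: infections.size]
       # Return only the days after the initialization period
       n_init = len(i0_history)
       return infections[n_init:], hosp[n_init:]


   # Usage with hypothetical PMFs and 100 infections per initial day
   sketch_rt = sample_rt(60, np.random.default_rng(42))
   sketch_inf, sketch_hosp = simulate_admissions(
       sketch_rt,
       gen_int=np.array([0.2, 0.5, 0.3]),
       inf_hosp_int=np.array([0.1, 0.6, 0.3]),
       i0_history=np.full(3, 100.0),
       p_hosp=0.05,
   )

Suppose :math:`\mathcal{R}(t) = 1` and the initial history is flat and at
least as long as both PMFs. Because each PMF sums to one, the infections then
stay flat, and :math:`H(t)` equals :math:`p_\mathrm{hosp}` times the
infection count.
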

Let's plot the daily number of hospital admissions.

.. container:: cell

   .. code:: python

      import matplotlib.pyplot as plt
      import matplotlib.dates as mdates

      daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
      dates = dat["date"].to_numpy()

      ax = plt.gca()
      ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
      ax.xaxis.set_major_locator(mdates.DayLocator(interval=7))
      ax.set_xlim(dates[0], dates[-1])
      plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
      plt.plot(dates, daily_hosp_admits, "-o")
      plt.xlabel("Date")
      plt.ylabel("Admissions")
      plt.show()

   .. container:: cell-output cell-output-display

      |image1|

Building the model
==================

First, we load two datasets that we will use as deterministic quantities:
the generation interval and the infection-to-hospital-admission interval.

.. container:: cell

   .. code:: python

      gen_int = datasets.load_generation_interval()
      inf_hosp_int = datasets.load_infection_admission_interval()

      # We only need the probability_mass column of each dataset
      gen_int_array = gen_int["probability_mass"].to_numpy()
      gen_int = gen_int_array
      inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()

      # Taking a peek at the first 5 elements of each
      gen_int[:5], inf_hosp_int_array[:5]

      # Visualizing both quantities side by side
      fig, axs = plt.subplots(1, 2)

      axs[0].plot(gen_int)
      axs[0].set_title("Generation interval")
      axs[1].plot(inf_hosp_int_array)
      axs[1].set_title("Infection to hospital admission interval")
      plt.show()

   .. container:: cell-output cell-output-display

      |image2|

With these two in hand, we can start building the model. First, we define
the latent hospital admissions:

.. container:: cell
   :name: latent-hosp

   .. code:: python

      from pyrenew import latent, deterministic, metaclass, randomvariable
      import jax.numpy as jnp
      import numpyro.distributions as dist

      inf_hosp_int = deterministic.DeterministicPMF(
          name="inf_hosp_int", value=inf_hosp_int_array
      )

      hosp_rate = randomvariable.DistributionalVariable(
          name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
      )

      latent_hosp = latent.HospitalAdmissions(
          infection_to_admission_interval_rv=inf_hosp_int,
          infection_hospitalization_ratio_rv=hosp_rate,
      )

``inf_hosp_int`` is a ``DeterministicPMF`` object that takes the
infection-to-hospital-admission interval as input. ``hosp_rate`` is a
``DistributionalVariable`` object that wraps a numpyro distribution
representing the infection-to-hospital-admission rate. The
``HospitalAdmissions`` class is a ``RandomVariable`` that takes two random
variables as inputs: the infection-to-admission interval and the
infection-to-hospital-admission rate.

Now we can define the remaining components:

.. container:: cell
   :name: initializing-rest-of-model

   .. code:: python

      from pyrenew import model, process, observation, metaclass, transformation
      from pyrenew.latent import (
          InfectionInitializationProcess,
          InitializeInfectionsExponentialGrowth,
      )

      # Infection process
      latent_inf = latent.Infections()
      n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
      I0 = InfectionInitializationProcess(
          "I0_initialization",
          randomvariable.DistributionalVariable(
              name="I0",
              distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
          ),
          InitializeInfectionsExponentialGrowth(
              n_initialization_points,
              deterministic.DeterministicVariable(name="rate", value=0.05),
          ),
      )

      # Generation interval and Rt
      gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)


      class MyRt(metaclass.RandomVariable):

          def validate(self):
              pass

          def sample(self, n: int, **kwargs) -> tuple:
              # Note: sd_rt is sampled but not used below; the random-walk
              # step standard deviation is fixed at 0.025.
              sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))

              # Random walk on log(Rt), exponentiated to give Rt
              rt_rv = randomvariable.TransformedVariable(
                  name="log_rt_random_walk",
                  base_rv=process.RandomWalk(
                      name="log_rt",
                      step_rv=randomvariable.DistributionalVariable(
                          name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                      ),
                  ),
                  transforms=transformation.ExpTransform(),
              )
              # Initial value of log(Rt)
              rt_init_rv = randomvariable.DistributionalVariable(
                  name="init_log_rt", distribution=dist.Normal(0, 0.2)
              )
              init_rt = rt_init_rv.sample()

              return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)


      rtproc = MyRt()

      # The observation model

      # We place a truncated-normal prior on 1/sqrt(concentration):
      # PowerTransform(-2) maps concentration_raw to concentration_raw ** -2.
      nb_conc_rv = randomvariable.TransformedVariable(
          "concentration",
          randomvariable.DistributionalVariable(
              name="concentration_raw",
              distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
          ),
          transformation.PowerTransform(-2),
      )

      # Now we define the observation process
      obs = observation.NegativeBinomialObservation(
          "negbinom_rv",
          concentration_rv=nb_conc_rv,
      )

Notice that all the components are ``RandomVariable`` instances. We can now
build the model:

.. container:: cell
   :name: init-model

   .. code:: python

      hosp_model = model.HospitalAdmissionsModel(
          latent_infections_rv=latent_inf,
          latent_hosp_admissions_rv=latent_hosp,
          I0_rv=I0,
          gen_int_rv=gen_int,
          Rt_process_rv=rtproc,
          hosp_admission_obs_process_rv=obs,
      )

Let's simulate from the prior predictive distribution to check that the
model works:

.. container:: cell
   :name: simulation

   .. code:: python

      timeframe = 120

      with numpyro.handlers.seed(rng_seed=223):
          simulated_data = hosp_model.sample(n_datapoints=timeframe)

.. container:: cell

   .. code:: python

      import matplotlib.pyplot as plt

      fig, axs = plt.subplots(1, 2)

      # Rt plot
      axs[0].plot(simulated_data.Rt)
      axs[0].set_ylabel("Simulated Rt")

      # Admissions plot
      axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
      axs[1].set_ylabel("Simulated Admissions")

      fig.suptitle("Basic renewal model")
      fig.supxlabel("Time")
      plt.tight_layout()
      plt.show()

   .. container:: cell-output cell-output-display

      |image3|

Fitting the model
=================

We now fit the model to the dataset loaded above, rather than to these
simulated data, using the ``run`` method of the ``Model`` object:

.. container:: cell
   :name: model-fit

   .. code:: python

      import jax

      hosp_model.run(
          num_samples=1000,
          num_warmup=1000,
          data_observed_hosp_admissions=daily_hosp_admits,
          rng_key=jax.random.key(54),
          mcmc_args=dict(progress_bar=False, num_chains=2),
      )

We can use the ``Model`` object's ``plot_posterior`` method to visualize the
model fit. Here, we plot the observed values against the inferred latent
values (i.e., the mean of the negative binomial observation process) [1]_:

.. container:: cell

   .. code:: python

      out = hosp_model.plot_posterior(
          var="latent_hospital_admissions",
          ylab="Hospital Admissions",
          obs_signal=daily_hosp_admits.astype(float),
      )

   .. container:: cell-output cell-output-display

      |image4|

Results exploration and MCMC diagnostics
========================================

To explore further, we can use ArviZ to visualize the results.
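
ArviZ's diagnostics include ``r_hat``, which compares the spread of draws
between chains with the spread within each chain. Values close to 1 suggest
that the chains are sampling the same distribution. To illustrate the idea,
here is a minimal NumPy sketch of the classic split-chain :math:`\hat{R}`
statistic for a single scalar quantity. The helper name ``split_rhat`` is
hypothetical. ArviZ reports a rank-normalized variant by default, so its
values can differ slightly from this sketch.

.. code:: python

   import numpy as np


   def split_rhat(draws):
       """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
       draws = np.asarray(draws, dtype=float)
       half = draws.shape[1] // 2
       # Split each chain in half so that drift within a chain also inflates R-hat
       chains = np.concatenate([draws[:, :half], draws[:, half : 2 * half]], axis=0)
       n = chains.shape[1]
       # W: average within-chain variance
       w = chains.var(axis=1, ddof=1).mean()
       # B: variance of the chain means, scaled by the chain length
       b = n * chains.mean(axis=1).var(ddof=1)
       # Pooled estimate of the posterior variance
       var_plus = (n - 1) / n * w + b / n
       return float(np.sqrt(var_plus / w))


   rng = np.random.default_rng(0)
   mixed = rng.normal(size=(2, 1000))  # two chains sampling the same distribution
   stuck = mixed + np.array([[0.0], [3.0]])  # second chain shifted away
   print(split_rhat(mixed), split_rhat(stuck))

The well-mixed chains give a value very close to 1. Shifting one chain by 3
pushes the statistic well above 1, which is the kind of disagreement
``r_hat`` is designed to flag.
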

Let's start by importing ArviZ and converting the fitted model to an ArviZ
``InferenceData`` object:

.. container:: cell
   :name: convert-inferencedata

   .. code:: python

      import arviz as az

      idata = az.from_numpyro(hosp_model.mcmc)

We obtain the summary of model diagnostics and print the diagnostics for
``latent_hospital_admissions[1]``:

.. container:: cell
   :name: diagnostics

   .. code:: python

      diagnostic_stats_summary = az.summary(
          idata.posterior,
          kind="diagnostics",
      )

      print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"])

   .. container:: cell-output cell-output-stdout

      ::

         mcse_mean       0.013
         mcse_sd         0.009
         ess_bulk     2118.000
         ess_tail     1652.000
         r_hat           1.000
         Name: latent_hospital_admissions[1], dtype: float64

Below we plot the 90% and 50% highest density intervals for latent hospital
admissions using ArviZ's ``plot_hdi`` function:

.. container:: cell

   .. code:: python

      x_data = idata.posterior["latent_hospital_admissions_dim_0"]
      y_data = idata.posterior["latent_hospital_admissions"]

      fig, axes = plt.subplots(figsize=(6, 5))
      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.9,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.3},
          ax=axes,
      )

      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.5,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.6},
          ax=axes,
      )

      # Add the posterior median to the figure
      median_ts = y_data.median(dim=["chain", "draw"])
      axes.plot(x_data, median_ts, color="C0", label="Median")
      axes.legend()
      axes.set_title("Posterior Hospital Admissions", fontsize=10)
      axes.set_xlabel("Time", fontsize=10)
      axes.set_ylabel("Hospital Admissions", fontsize=10)
      plt.show()

   .. container:: cell-output cell-output-display

      |image5|

We can also look at credible intervals for the posterior distribution of
latent infections:

.. container:: cell

   .. code:: python

      # Shift the time index so that the initialization period has negative
      # times and the first day after it is t = 0
      x_data = (
          idata.posterior["all_latent_infections_dim_0"] - n_initialization_points
      )
      y_data = idata.posterior["all_latent_infections"]

      fig, axes = plt.subplots(figsize=(6, 5))
      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.9,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.3},
          ax=axes,
      )

      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.5,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.6},
          ax=axes,
      )

      # Add the posterior median to the figure
      median_ts = y_data.median(dim=["chain", "draw"])
      axes.plot(x_data, median_ts, color="C0", label="Median")
      axes.legend()
      plt.show()

   .. container:: cell-output cell-output-display

      |image6|

Predictive checks and forecasting
=================================

We can use the ``posterior_predictive`` and ``prior_predictive`` methods of
the ``Model`` object to generate posterior and prior predictive samples for
observed admissions.

.. container:: cell
   :name: demonstrate-use-of-predictive-methods

   .. code:: python

      import arviz as az

      idata = az.from_numpyro(
          hosp_model.mcmc,
          posterior_predictive=hosp_model.posterior_predictive(
              n_datapoints=len(daily_hosp_admits)
          ),
          prior=hosp_model.prior_predictive(
              n_datapoints=len(daily_hosp_admits),
              numpyro_predictive_args={"num_samples": 1000},
          ),
      )

We use ArviZ's ``plot_lm`` function to plot the posterior predictive
distribution against the observed data:

.. container:: cell

   .. code:: python

      fig, ax = plt.subplots()
      az.plot_lm(
          "negbinom_rv",
          idata=idata,
          kind_pp="hdi",
          y_kwargs={"color": "black"},
          y_hat_fill_kwargs={"color": "C0"},
          axes=ax,
      )

      ax.set_title("Posterior Predictive Plot")
      ax.set_ylabel("Hospital Admissions")
      ax.set_xlabel("Days")
      plt.show()

   .. container:: cell-output cell-output-display

      |image7|

By setting ``n_datapoints`` beyond the number of observed days, we can use
the posterior predictive distribution to forecast.

.. container:: cell
   :name: posterior-predictive-distribution

   .. code:: python

      # Forecast 28 days beyond the last observation
      n_forecast_points = 28
      idata = az.from_numpyro(
          hosp_model.mcmc,
          posterior_predictive=hosp_model.posterior_predictive(
              n_datapoints=len(daily_hosp_admits) + n_forecast_points,
          ),
          prior=hosp_model.prior_predictive(
              n_datapoints=len(daily_hosp_admits),
              numpyro_predictive_args={"num_samples": 1000},
          ),
      )

Below we plot the prior predictive distribution using equal-tailed Bayesian
credible intervals:

.. container:: cell

   .. code:: python

      def compute_eti(dataset, eti_prob):
          # Quantiles that leave (1 - eti_prob) / 2 of the mass in each tail
          eti_bdry = dataset.quantile(
              ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw")
          )
          # Transpose to (time, 2), the layout passed to plot_hdi's hdi_data
          return eti_bdry.T


      fig, axes = plt.subplots(figsize=(6, 5))

      az.plot_hdi(
          idata.prior_predictive["negbinom_rv_dim_0"],
          hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9),
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.3},
          ax=axes,
      )

      az.plot_hdi(
          idata.prior_predictive["negbinom_rv_dim_0"],
          hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5),
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.6},
          ax=axes,
      )

      plt.scatter(
          idata.observed_data["negbinom_rv_dim_0"],
          idata.observed_data["negbinom_rv"],
          color="black",
      )

      axes.set_title("Prior Predictive Admissions", fontsize=10)
      axes.set_xlabel("Time", fontsize=10)
      axes.set_ylabel("Observed Admissions", fontsize=10)
      plt.yscale("log")
      plt.show()

   .. container:: cell-output cell-output-display

      |image8|

And now we plot the posterior predictive distribution with a 28-day-ahead
forecast:

.. container:: cell

   .. code:: python

      x_data = idata.posterior_predictive["negbinom_rv_dim_0"]
      y_data = idata.posterior_predictive["negbinom_rv"]

      fig, axes = plt.subplots(figsize=(6, 5))
      az.plot_hdi(
          x_data,
          hdi_data=compute_eti(y_data, 0.9),
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.3},
          ax=axes,
      )

      az.plot_hdi(
          x_data,
          hdi_data=compute_eti(y_data, 0.5),
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.6},
          ax=axes,
      )

      # Add median of the posterior to the figure
      median_ts = y_data.median(dim=["chain", "draw"])

      plt.plot(
          x_data,
          median_ts,
          color="C0",
          label="Median",
      )
      plt.scatter(
          idata.observed_data["negbinom_rv_dim_0"],
          idata.observed_data["negbinom_rv"],
          color="black",
      )
      axes.legend()
      axes.set_title(
          "Posterior Predictive Admissions, including a forecast", fontsize=10
      )
      axes.set_xlabel("Time", fontsize=10)
      axes.set_ylabel("Hospital Admissions", fontsize=10)
      plt.show()

   .. container:: cell-output cell-output-display

      |image9|

.. [1] The output is assigned to ``out`` to prevent ``quarto`` from
   displaying it twice.

.. |image1| image:: hospital_admissions_model_files/figure-rst/fig-plot-hospital-admissions-output-1.png
.. |image2| image:: hospital_admissions_model_files/figure-rst/fig-data-extract-output-1.png
.. |image3| image:: hospital_admissions_model_files/figure-rst/fig-basic-output-1.png
.. |image4| image:: hospital_admissions_model_files/figure-rst/fig-output-hospital-admissions-output-1.png
.. |image5| image:: hospital_admissions_model_files/figure-rst/fig-output-admission-distribution-output-1.png
.. |image6| image:: hospital_admissions_model_files/figure-rst/fig-output-infections-distribution-output-1.png
.. |image7| image:: hospital_admissions_model_files/figure-rst/fig-posterior-predictive-output-1.png
.. |image8| image:: hospital_admissions_model_files/figure-rst/fig-output-prior-predictive-output-1.png
.. |image9| image:: hospital_admissions_model_files/figure-rst/fig-output-posterior-predictive-forecast-output-1.png
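
The plots in this document use two kinds of credible intervals.
``az.plot_hdi`` computes highest density intervals (HDIs) from the samples,
while ``compute_eti`` supplies equal-tailed intervals (ETIs) through the
``hdi_data`` argument. For a symmetric, unimodal distribution the two nearly
coincide. For a skewed distribution, the HDI is narrower and sits closer to
the mode. The following minimal NumPy sketch makes the difference concrete.
The helpers ``eti`` and ``hdi`` are hypothetical and use a simplified
sample-based approach: quantiles for the ETI, and the shortest window of
sorted samples for the HDI. This is not ArviZ's implementation.

.. code:: python

   import numpy as np


   def eti(samples, prob):
       """Equal-tailed interval: drop (1 - prob) / 2 of the mass from each tail."""
       lo, hi = np.quantile(samples, [(1 - prob) / 2, (1 + prob) / 2])
       return float(lo), float(hi)


   def hdi(samples, prob):
       """Shortest interval containing a fraction `prob` of the samples."""
       x = np.sort(np.asarray(samples, dtype=float))
       n_in = int(np.ceil(prob * x.size))
       # Width of every window of n_in consecutive sorted samples
       widths = x[n_in - 1 :] - x[: x.size - n_in + 1]
       start = int(np.argmin(widths))
       return float(x[start]), float(x[start + n_in - 1])


   rng = np.random.default_rng(1)
   skewed = rng.exponential(scale=1.0, size=20_000)
   print("ETI:", eti(skewed, 0.9))
   print("HDI:", hdi(skewed, 0.9))

For these exponential draws, the HDI starts near zero, where the density is
highest. The ETI excludes the lowest 5% of the draws and extends further into
the right tail.
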