
Quick start

In this quick start, we demonstrate using wwinference to specify and fit a minimal model using daily COVID-19 hospital admissions from a “global” population and viral concentrations in wastewater from a few “local” wastewater treatment plants serving subsets of that larger population. In this context, “global” refers to a larger population, e.g. a state, and “local” refers to a smaller subset of that population, e.g. a municipality within that state. This quick start is intended as a reference for those interested in fitting the wwinference model to their own data.

Packages

In this quick start, we also use the dplyr, tidybayes, and ggplot2 packages. These are installed as dependencies when wwinference is installed.

Code
library(wwinference)
library(dplyr)
library(ggplot2)
library(tidybayes)

Data

The model expects two types of data: daily counts of hospital admissions from the larger “global” population, and wastewater concentration data from wastewater treatment plants whose catchment areas are contained within that larger “global” population. For this quick start, we will use simulated data modeled after a hypothetical US state with 4 wastewater treatment plants (also referred to as sites) reporting log-scale viral concentrations of SARS-CoV-2, processed in 3 different labs, covering about 25% of the state’s population. The simulated data contain daily counts of the total hospital admissions in the hypothetical state from September 1, 2023 to November 29, 2023, and wastewater log genome concentration data from September 1, 2023 to December 1, 2023, with varying sampling frequencies. We will use these data to produce a forecast of COVID-19 hospital admissions as of December 6, 2023. These data are provided as part of the package data.

These data are already in the format expected by wwinference. The hospital admissions data contain:

  • a date (column date): the date of the observation, in this case, the date the hospital admissions occurred
  • a count (column daily_hosp_admits): the number of hospital admissions observed on that day
  • a population size (column state_pop): the population size covered by the hospital admissions data, in this case, the size of the theoretical state.

Additionally, we provide the hosp_data_eval dataset which contains the simulated hospital admissions 28 days ahead of the forecast date, which can be used to evaluate the model.

For the wastewater data, the expected format is a table of observations with the following columns. The wastewater data should not contain NA values for days with missing observations; instead, these days should be excluded:

  • a date (column date): the date the sample was collected
  • a site indicator (column site): the unique identifier for the wastewater treatment plant that the sample was collected from
  • a lab indicator (column lab): the unique identifier for the lab where the sample was processed
  • a concentration (column log_genome_copies_per_ml): the measured log genome copies per mL for the given sample. This column should not contain NA values, even if the observation for that sample is below the limit of detection.
  • a limit of detection (column log_lod): the natural log of the limit of detection of the assay used to process the sample, in the same units as the concentration column
  • a site population size (column site_pop): the population size covered by the wastewater catchment area of that site
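As a concrete sketch, a minimal table in this format could be constructed as below. All values are made up for illustration; real data should come from your surveillance system.

```r
# Minimal illustrative wastewater table with the required columns.
# All values here are made up for illustration.
ww_example <- data.frame(
  date = as.Date(c("2023-09-05", "2023-09-05", "2023-09-06")),
  site = c(1, 2, 1),
  lab = c(1, 1, 1),
  log_genome_copies_per_ml = c(8.1, 7.4, 8.0),
  log_lod = c(5.1, 5.1, 5.1),
  site_pop = c(50000, 75000, 50000)
)
# Days with no sample are simply absent; there are no NA rows.
```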

Code
hosp_data <- wwinference::hosp_data
hosp_data_eval <- wwinference::hosp_data_eval
ww_data <- wwinference::ww_data

head(ww_data)
## # A tibble: 6 × 7
##   date        site   lab log_genome_copies_per_ml log_lod site_pop location     
##   <date>     <dbl> <dbl>                    <dbl>   <dbl>    <dbl> <chr>        
## 1 2023-09-05     5     3                     8.14    5.11    50000 example state
## 2 2023-09-06     5     3                     8.00    5.11    50000 example state
## 3 2023-09-13     5     3                     9.04    5.11    50000 example state
## 4 2023-09-14     5     3                     8.36    5.11    50000 example state
## 5 2023-09-15     5     3                     8.29    5.11    50000 example state
## 6 2023-09-18     5     3                     8.07    5.11    50000 example state
Code
head(hosp_data)
## # A tibble: 6 × 4
##   date       daily_hosp_admits state_pop location     
##   <date>                 <dbl>     <dbl> <chr>        
## 1 2023-09-01                11   3000000 example state
## 2 2023-09-02                13   3000000 example state
## 3 2023-09-03                20   3000000 example state
## 4 2023-09-04                 5   3000000 example state
## 5 2023-09-05                16   3000000 example state
## 6 2023-09-06                10   3000000 example state

Pre-processing

The user will need to provide data in a format similar to the package data, as described above. This represents the bare minimum required data for a single location and a single forecast date. We will need to do some pre-processing to add additional variables that the model needs in order to apply features such as outlier exclusion and censoring of values below the limit of detection.

Parameters

Get the example parameters from the package, which we will use here. Note that some of these are COVID specific, others are more general to the model, as indicated in the .toml file.

Code
params <- get_params(
  system.file("extdata", "example_params.toml",
    package = "wwinference"
  )
)

Wastewater data pre-processing

The preprocess_ww_data function adds the following variables to the original dataset. First, it assigns a unique identifier to each unique combination of lab and site, since this is the unit we will use for estimating the observation error in the reported measurements. Second, it adds a column below_lod, an indicator of whether the reported concentration is above or below the limit of detection (LOD); if an observation is below the LOD, the model will treat it as censored. Third, it adds a column flag_as_ww_outlier that indicates whether the measurement is identified as an outlier by our algorithm with the default thresholds. While the default is to exclude the measurements flagged as outliers, the user can still choose to include them later on. The user must specify the name of the column containing the concentration measurements (presumed to be in log genome copies per mL) and the name of the column containing the limit of detection for each measurement. The function assumes that the original data contain the columns date, site, and lab, and will return a dataframe with the column names needed by the downstream model fitting functions.

Code
ww_data_preprocessed <- wwinference::preprocess_ww_data(
  ww_data,
  conc_col_name = "log_genome_copies_per_ml",
  lod_col_name = "log_lod"
)

Note that this function assumes that there are no missing values in the concentration column. The package expects that observations below the LOD have been replaced with a numeric value below the LOD. If your dataset contains NAs where observations are below the LOD, we suggest replacing them with a value below the LOD in upstream pre-processing.
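As an illustrative sketch of that upstream step (the replacement value log_lod - 1 is an arbitrary choice for this example, not a package requirement):

```r
# Replace missing concentrations known to be below the LOD with a
# placeholder value below the LOD (here log_lod - 1, an arbitrary
# illustrative choice, not a package requirement).
ww_raw <- data.frame(
  log_genome_copies_per_ml = c(8.1, NA, 7.9),
  log_lod = c(5.1, 5.1, 5.1)
)
ww_raw$log_genome_copies_per_ml <- ifelse(
  is.na(ww_raw$log_genome_copies_per_ml),
  ww_raw$log_lod - 1,
  ww_raw$log_genome_copies_per_ml
)
```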

Hospital admissions data pre-processing

The preprocess_count_data function standardizes the column names of the resulting dataframe. The user must specify the name of the column containing the daily hospital admissions counts and the name of the column containing the population size the admissions come from (in this case, a hypothetical US state). The function assumes that the original data contain the column date, and will return a dataframe with the column names needed by the downstream model fitting functions.

Code
hosp_data_preprocessed <- wwinference::preprocess_count_data(
  hosp_data,
  count_col_name = "daily_hosp_admits",
  pop_size_col_name = "state_pop"
)

We’ll make some plots of the data just to make sure it looks like what we’d expect:

Code
ggplot(ww_data_preprocessed) +
  geom_point(
    aes(
      x = date, y = log_genome_copies_per_ml,
      color = as.factor(lab_site_name)
    ),
    show.legend = FALSE,
    size = 0.5
  ) +
  geom_point(
    data = ww_data_preprocessed |> filter(
      log_genome_copies_per_ml <= log_lod
    ),
    aes(x = date, y = log_genome_copies_per_ml, color = "red"),
    show.legend = FALSE, size = 0.5
  ) +
  scale_x_date(
    date_breaks = "2 weeks",
    labels = scales::date_format("%Y-%m-%d")
  ) +
  geom_hline(aes(yintercept = log_lod), linetype = "dashed") +
  facet_wrap(~lab_site_name, scales = "free") +
  xlab("") +
  ylab("Genome copies/mL") +
  ggtitle("Lab-site level wastewater concentration") +
  theme_bw() +
  theme(
    axis.text.x = element_text(
      size = 5, vjust = 1,
      hjust = 1, angle = 45
    ),
    axis.title.x = element_text(size = 12),
    axis.text.y = element_text(size = 5),
    strip.text = element_text(size = 5),
    axis.title.y = element_text(size = 12),
    plot.title = element_text(
      size = 10,
      vjust = 0.5, hjust = 0.5
    )
  )

Code
ggplot(hosp_data_preprocessed) +
  # Plot the hospital admissions data that we will evaluate against in white
  geom_point(
    data = hosp_data_eval, aes(
      x = date,
      y = daily_hosp_admits_for_eval
    ),
    shape = 21, color = "black", fill = "white"
  ) +
  # Plot the data we will calibrate to
  geom_point(aes(x = date, y = count)) +
  scale_x_date(
    date_breaks = "2 weeks",
    labels = scales::date_format("%Y-%m-%d")
  ) +
  xlab("") +
  ylab("Daily hospital admissions") +
  ggtitle("State level hospital admissions") +
  theme_bw() +
  theme(
    axis.text.x = element_text(
      size = 8, vjust = 1,
      hjust = 1, angle = 45
    ),
    axis.title.x = element_text(size = 12),
    axis.title.y = element_text(size = 12),
    plot.title = element_text(
      size = 10,
      vjust = 0.5, hjust = 0.5
    )
  )

The closed circles indicate the data the model will be calibrated to, while the open circles indicate data we later observe after the forecast date.

Data exclusion

As an optional additional pre-processing step, the user can decide to exclude certain data points in the model fit procedure. For example, we recommend excluding the flagged wastewater concentration outliers. To do so we will use the indicate_ww_exclusions() function, which will add the flagged outliers to the exclude column where indicated.

Code
ww_data_to_fit <- wwinference::indicate_ww_exclusions(
  ww_data_preprocessed,
  outlier_col_name = "flag_as_ww_outlier",
  remove_outliers = TRUE
)

Model specification

We will need to set some metadata to facilitate model specification. This includes:

  • the forecast date (the date we are making a forecast)
  • the number of days to calibrate the model for
  • the number of days to forecast beyond the forecast date
  • the specification of the generation interval, in this case for COVID-19
  • the specification of the delay from infection to the count data, in this case from infection to COVID-19 hospital admission

Calibration time and forecast time

The calibration time represents the number of days to calibrate the count data to. This must be less than or equal to the number of rows in hosp_data. The forecast horizon represents the number of days from the forecast date to generate forecasted hospital admissions for. Typically, the hospital admissions data will not be complete up until the forecast date, and we will refer to the time between the last hospital admissions data point and the forecast date as the nowcast time. The model will “forecast” this period, in addition to the specified forecast horizon.

Code
forecast_date <- "2023-12-06"
calibration_time <- 90
forecast_horizon <- 28
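To make the nowcast period concrete: the package admissions data run through 2023-11-29, so with a forecast date of 2023-12-06 the model must nowcast 7 days before the forecast horizon even begins. A minimal sketch of that arithmetic (using local variable names so as not to overwrite the objects above):

```r
# Length of the nowcast window: the gap between the last observed
# admissions date and the forecast date.
fc_date <- as.Date("2023-12-06")
last_admissions_date <- as.Date("2023-11-29")  # last observation in hosp_data
nowcast_days <- as.integer(fc_date - last_admissions_date)
nowcast_days  # the model "forecasts" these days plus forecast_horizon
```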

Delay distributions

We will pass in probability mass functions (PMFs) that are specific to COVID, and to the delay from infections to hospital admissions, the count data we are using to fit the model. If using a different pathogen or a different count dataset, these PMFs need to be replaced. We provide them as package data here. The model expects that these are discrete daily PMFs.

Additionally, the model requires specifying a delay distribution for the infection feedback term, which essentially describes the delay at which high incident infections result in negative feedback on future infections (due to depletion of susceptibles, behavior changes, policies to reduce transmission, etc.). By default we set this to the generation interval, but it can be replaced with any discrete daily PMF.

Code
generation_interval <- wwinference::default_covid_gi
inf_to_hosp <- wwinference::default_covid_inf_to_hosp

# Assign infection feedback equal to the generation interval
infection_feedback_pmf <- generation_interval
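If you need a different discrete daily PMF (for example, a custom infection feedback delay), one way to construct one is to discretize a continuous distribution and normalize. The gamma parameters below are purely illustrative, not COVID-specific estimates:

```r
# Build a custom discrete daily PMF by discretizing a gamma distribution
# over days 1..14 and normalizing so the probabilities sum to 1.
# shape/rate values are illustrative only.
support <- 1:14
daily_mass <- pgamma(support, shape = 2, rate = 0.5) -
  pgamma(support - 1, shape = 2, rate = 0.5)
custom_pmf <- daily_mass / sum(daily_mass)
```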

We will pass these to the get_model_spec() function, along with the other parameters specified above, when calling wwinference().

Precompiling the model

As wwinference uses CmdStan to fit its models, it is necessary to first compile the model. This can be done using the compile_model() function.

Code
model <- wwinference::compile_model()
## Using model source file:
## /home/runner/work/_temp/Library/wwinference/stan/wwinference.stan
## Using include paths: /home/runner/work/_temp/Library/wwinference/stan
## Model compiled or loaded successfully; model executable binary located at:
## /tmp/RtmpeMpoW9/wwinference

Fitting the model

We’re now ready to fit the model using the No-U-Turn Sampler (NUTS), a type of Hamiltonian Monte Carlo (HMC) algorithm and the core fitting method used by CmdStan. The user can adjust the MCMC settings (see the documentation for get_mcmc_options()); however, this vignette uses the default settings, which run 4 parallel chains with 750 warmup iterations and 500 sampling iterations per chain, a target average acceptance probability of 0.95, and a maximum tree depth of 12. The default is not to set the seed for the random number generator for the MCMC model runs (which produces stochastic results each time the model is run), but for reproducibility we will set the seed of the Stan PRNG to 123 in this vignette.

When applying the model to real data, experimenting with these MCMC settings may make it possible to achieve improved model convergence and/or faster model fitting times. See the Stan User’s Guide for an introduction to No-U-Turn sampler convergence diagnostics and configuration parameters.

We also pass our preprocessed datasets (ww_data_to_fit and hosp_data_preprocessed), specify our model using get_model_spec(), set the MCMC settings using get_mcmc_options(), and pass in our pre-compiled model (model) to wwinference(), where they are combined and used to fit the model.

Code
ww_fit <- wwinference::wwinference(
  ww_data = ww_data_to_fit,
  count_data = hosp_data_preprocessed,
  forecast_date = forecast_date,
  calibration_time = calibration_time,
  forecast_horizon = forecast_horizon,
  model_spec = get_model_spec(
    generation_interval = generation_interval,
    inf_to_count_delay = inf_to_hosp,
    infection_feedback_pmf = infection_feedback_pmf,
    params = params
  ),
  fit_opts = get_mcmc_options(seed = 123),
  compiled_model = model
)
## Running MCMC with 4 parallel chains...
## 
## Chain 1 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 2 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 3 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 4 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 2 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 3 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 1 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 4 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 2 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 1 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 3 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 2 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 1 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 4 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 3 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 2 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 1 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 4 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 2 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 3 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 1 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 4 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 2 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 3 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 4 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 1 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 2 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 3 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 4 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 1 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 2 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 3 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 1 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 4 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 3 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 1 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 4 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 2 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 3 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 1 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 3 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 4 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 1 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 2 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 3 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 1 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 3 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 4 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 1 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 1 finished in 286.1 seconds.
## Chain 2 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 3 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 3 finished in 289.4 seconds.
## Chain 4 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 2 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 4 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 2 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 2 finished in 320.8 seconds.
## Chain 4 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 4 finished in 335.6 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 308.0 seconds.
## Total execution time: 335.8 seconds.

The wwinference_fit object

The wwinference() function returns a wwinference_fit object, which includes the underlying CmdStanModel object (fit), a list of the two sources of input data (raw_input_data), the list of the arguments passed to Stan (stan_data_list), and the list of the MCMC options passed to Stan (fit_opts). Below, we show how to generate downstream elements from a wwinference_fit object.

wwinference_fit objects currently have the following methods available:

Code
methods(class = "wwinference_fit")
## [1] get_draws                  get_model_diagnostic_flags
## [3] print                      summary                   
## see '?methods' for accessing help and source code

The print and summary methods can provide some information about the model. In particular, the summary method is a wrapper for cmdstanr::summary():

Code
print(ww_fit)
## wwinference_fit object
## N of WW sites              : 4 
## N of unique lab-site pairs : 5 
## State population           : 3000000 
## N of weeks                 : 18 
## --------------------
## For more details, you can access the following:
##  - `$fit` for the CmdStan object
##  - `$raw_input_data` for the input data
##  - `$stan_data_list` for the stan data arguments
##  - `$fit_opts` for the fitting options
Code
summary(ww_fit)
## # A tibble: 4,413 × 10
##    variable      mean    median     sd    mad      q5     q95  rhat ess_bulk
##    <chr>        <dbl>     <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>
##  1 lp__     -365.     -365.     11.4   11.0   -384.   -347.    1.00     546.
##  2 w[1]       -0.0727   -0.0243  1.01   1.00    -1.81    1.57  1.01    2469.
##  3 w[2]        0.0139    0.0332  0.989  0.934   -1.62    1.69  1.01    3138.
##  4 w[3]       -0.168    -0.149   0.954  0.932   -1.75    1.44  1.00    2359.
##  5 w[4]       -0.299    -0.283   0.974  0.970   -1.86    1.29  1.00    2227.
##  6 w[5]       -0.512    -0.525   1.04   1.06    -2.21    1.20  1.00    1603.
##  7 w[6]       -0.733    -0.777   1.09   1.10    -2.50    1.07  1.00    1021.
##  8 w[7]       -0.853    -0.864   1.12   1.16    -2.60    1.02  1.00    1083.
##  9 w[8]       -0.434    -0.452   1.02   1.03    -2.14    1.23  1.00    1550.
## 10 w[9]       -0.326    -0.318   0.940  0.914   -1.90    1.26  1.00    2150.
## # ℹ 4,403 more rows
## # ℹ 1 more variable: ess_tail <dbl>

Extracting the posterior predictions

Working with the posterior predictions alongside the input data can be useful to check that your model is fitting the data well and that the nowcasted/forecast quantities look reasonable.

We will generate a dataframe, draws_df, containing the posterior draws of the estimated, nowcasted, and forecasted expected observed hospital admissions and wastewater concentrations, as well as the latent variables of interest, including the site-level \(\mathcal{R}(t)\) estimates and the state-level \(\mathcal{R}(t)\) estimate.

We can generate this directly from the output of wwinference::wwinference() using:

Code
draws_df <- get_draws(ww_fit)

print(draws_df)
## Draws from the model featuring 2000 draws across 124 days  in the following datasets:
##  - `$predicted_counts` with 250000 rows
##  - `$predicted_ww` with 1250000 rows across 5 sites.
##  - `$global_rt` with 250000 rows
##  - `$subpop_rt` with 1250000 rows across 5 sub-populations
## You can use $ to access the datasets.

Note that by default the get_draws() function will return a list of class wwinference_fit_draws containing all of the posterior draws for the predicted hospital admissions, wastewater concentrations, and global and site-level \(\mathcal{R}(t)\) estimates. To examine a particular variable (e.g. predicted_counts for posterior predicted hospital admissions), access the corresponding tibble using the $ operator.
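A common downstream step is to collapse these draws into quantiles by date. The sketch below uses a toy data frame standing in for one of these tibbles; the column names (date, draw, pred_value) are assumptions for illustration, so check names(draws_df$predicted_counts) on your own fit.

```r
# Sketch: summarise per-draw predictions into median and 90% interval by
# date. toy_draws is a made-up stand-in for draws_df$predicted_counts.
library(dplyr)

set.seed(123)
toy_draws <- data.frame(
  date = rep(as.Date("2023-12-07") + 0:1, each = 100),
  draw = rep(1:100, times = 2),
  pred_value = c(rpois(100, 20), rpois(100, 25))
)

forecast_summary <- toy_draws |>
  group_by(date) |>
  summarise(
    median = median(pred_value),
    q5 = quantile(pred_value, 0.05),
    q95 = quantile(pred_value, 0.95),
    .groups = "drop"
  )
```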

Using explicit passed arguments rather than S3 methods

Rather than using S3 methods supplied for wwinference(), the elements in the wwinference_fit object can also be used directly to create this dataframe. This is demonstrated below:

Code
draws_df_explicit <- get_draws(
  x = ww_fit$raw_input_data$input_ww_data,
  count_data = ww_fit$raw_input_data$input_count_data,
  stan_data_list = ww_fit$stan_data_list,
  fit_obj = ww_fit$fit
)

Plotting the outputs

We can create plots of the outputs using draws_df and the plotting wrapper functions. Note that by default, these plots will not visualize data that was below the LOD (even though the fit incorporated those observations via the censored observation process).

Code
plot_hosp <- get_plot_forecasted_counts(
  draws = draws_df$predicted_counts,
  count_data_eval = hosp_data_eval,
  count_data_eval_col_name = "daily_hosp_admits_for_eval",
  forecast_date = forecast_date
)
plot_hosp
## Warning: Removed 3500 rows containing missing values or values outside the scale range
## (`geom_point()`).

Code
plot_ww <- get_plot_ww_conc(draws_df$predicted_ww, forecast_date)
plot_ww
## Warning: Removed 52700 rows containing missing values or values outside the scale range
## (`geom_point()`).

Code
plot_state_rt <- get_plot_global_rt(draws_df$global_rt, forecast_date)
plot_state_rt

Code
plot_subpop_rt <- get_plot_subpop_rt(draws_df$subpop_rt, forecast_date)
plot_subpop_rt

The previous four plots are equivalent to calling the plot method of wwinference_fit_draws using the what argument:

Code
plot(
  x = draws_df,
  what = "predicted_counts",
  count_data_eval = hosp_data_eval,
  count_data_eval_col_name = "daily_hosp_admits_for_eval",
  forecast_date = forecast_date
)
## Warning: Removed 3500 rows containing missing values or values outside the scale range
## (`geom_point()`).

Code
plot(draws_df, what = "predicted_ww", forecast_date = forecast_date)
## Warning: Removed 52700 rows containing missing values or values outside the scale range
## (`geom_point()`).

Code
plot(draws_df, what = "global_rt", forecast_date = forecast_date)

Code
plot(draws_df, what = "subpop_rt", forecast_date = forecast_date)

Diagnostics

We strongly recommend running diagnostics as a post-processing step on the model outputs.

This can be done by passing the output of wwinference() into the get_model_diagnostic_flags() and parameter_diagnostics() functions.

get_model_diagnostic_flags() will print out a table of flags; if any of these are TRUE, it will print a warning. We have set default thresholds on the model diagnostics for production-level runs; we recommend adjusting them as needed (see below).

To further troubleshoot, you can look at the diagnostic summary and the diagnostics of the individual parameters using the parameter_diagnostics() function.

Code
convergence_flag_df <- get_model_diagnostic_flags(ww_fit)
print(convergence_flag_df)
##   flag_high_max_treedepth flag_too_many_divergences flag_high_rhat
## 1                   FALSE                     FALSE          FALSE
##   flag_low_embfi
## 1          FALSE
Code
parameter_diagnostics(ww_fit)
## $num_divergent
## [1] 13  3  7  3
## 
## $num_max_treedepth
## [1] 0 0 0 0
## 
## $ebfmi
## [1] 0.8063454 0.7588019 0.7639581 0.9149501

This can also be done explicitly by passing the elements of the wwinference_fit object into the custom functions we built, or by directly calling CmdStan’s built-in functions.

Start by passing the Stan fit object (ww_fit$fit$result) into get_model_diagnostic_flags(), adjusting the thresholds if desired.

Then, we recommend looking at the diagnostics summary provided by CmdStan, which we had wrapped into the parameter_diagnostics() call above. Lastly, we recommend looking at the individual model parameters provided by CmdStan to identify which components of the model might be driving the convergence issues.

For further information on troubleshooting the model diagnostics, we recommend the [bayesplot tutorial](https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html).

Code
convergence_flag_df <- get_model_diagnostic_flags(
  x = ww_fit$fit$result,
  ebmfi_tolerance = 0.2,
  divergences_tolerance = 0.01,
  frac_high_rhat_tolerance = 0.05,
  rhat_tolerance = 1.05,
  max_tree_depth_tol = 0.01
)
# Get the tables using the CmdStan functions via wrappers
summary(ww_fit)
## # A tibble: 4,413 × 10
##    variable      mean    median     sd    mad      q5     q95  rhat ess_bulk
##    <chr>        <dbl>     <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>
##  1 lp__     -365.     -365.     11.4   11.0   -384.   -347.    1.00     546.
##  2 w[1]       -0.0727   -0.0243  1.01   1.00    -1.81    1.57  1.01    2469.
##  3 w[2]        0.0139    0.0332  0.989  0.934   -1.62    1.69  1.01    3138.
##  4 w[3]       -0.168    -0.149   0.954  0.932   -1.75    1.44  1.00    2359.
##  5 w[4]       -0.299    -0.283   0.974  0.970   -1.86    1.29  1.00    2227.
##  6 w[5]       -0.512    -0.525   1.04   1.06    -2.21    1.20  1.00    1603.
##  7 w[6]       -0.733    -0.777   1.09   1.10    -2.50    1.07  1.00    1021.
##  8 w[7]       -0.853    -0.864   1.12   1.16    -2.60    1.02  1.00    1083.
##  9 w[8]       -0.434    -0.452   1.02   1.03    -2.14    1.23  1.00    1550.
## 10 w[9]       -0.326    -0.318   0.940  0.914   -1.90    1.26  1.00    2150.
## # ℹ 4,403 more rows
## # ℹ 1 more variable: ess_tail <dbl>
Code
parameter_diagnostics(ww_fit, quiet = TRUE)
## $num_divergent
## [1] 13  3  7  3
## 
## $num_max_treedepth
## [1] 0 0 0 0
## 
## $ebfmi
## [1] 0.8063454 0.7588019 0.7639581 0.9149501
Code
head(convergence_flag_df)
##   flag_high_max_treedepth flag_too_many_divergences flag_high_rhat
## 1                   FALSE                     FALSE          FALSE
##   flag_low_embfi
## 1          FALSE

Fit to only hospital admissions data

The package also has functionality to fit the model without wastewater data. This can be useful when comparing the impact of the wastewater data on the forecast, or as part of a pipeline where one might choose to fall back on the admissions-only model if there are convergence or known data issues with the wastewater data.

Code
fit_hosp_only <- wwinference::wwinference(
  ww_data = ww_data_to_fit,
  count_data = hosp_data_preprocessed,
  forecast_date = forecast_date,
  calibration_time = calibration_time,
  forecast_horizon = forecast_horizon,
  model_spec = get_model_spec(
    generation_interval = generation_interval,
    inf_to_count_delay = inf_to_hosp,
    infection_feedback_pmf = infection_feedback_pmf,
    include_ww = FALSE,
    params = params
  ),
  fit_opts = get_mcmc_options(seed = 123),
  compiled_model = model
)
## Running MCMC with 4 parallel chains...
## 
## Chain 1 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 2 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 3 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 4 Iteration:    1 / 1250 [  0%]  (Warmup) 
## Chain 2 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 3 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 1 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 4 Iteration:  100 / 1250 [  8%]  (Warmup) 
## Chain 2 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 3 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 1 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 4 Iteration:  200 / 1250 [ 16%]  (Warmup) 
## Chain 2 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 3 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 1 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 2 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 4 Iteration:  300 / 1250 [ 24%]  (Warmup) 
## Chain 3 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 2 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 1 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 4 Iteration:  400 / 1250 [ 32%]  (Warmup) 
## Chain 3 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 2 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 1 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 4 Iteration:  500 / 1250 [ 40%]  (Warmup) 
## Chain 2 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 3 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 1 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 4 Iteration:  600 / 1250 [ 48%]  (Warmup) 
## Chain 2 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 3 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 1 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 4 Iteration:  700 / 1250 [ 56%]  (Warmup) 
## Chain 3 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 1 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 2 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 4 Iteration:  751 / 1250 [ 60%]  (Sampling) 
## Chain 1 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 3 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 2 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 4 Iteration:  850 / 1250 [ 68%]  (Sampling) 
## Chain 3 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 1 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 2 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 4 Iteration:  950 / 1250 [ 76%]  (Sampling) 
## Chain 3 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 1 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 2 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 4 Iteration: 1050 / 1250 [ 84%]  (Sampling) 
## Chain 3 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 1 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 4 Iteration: 1150 / 1250 [ 92%]  (Sampling) 
## Chain 3 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 3 finished in 90.8 seconds.
## Chain 2 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 2 finished in 91.0 seconds.
## Chain 1 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 1 finished in 92.2 seconds.
## Chain 4 Iteration: 1250 / 1250 [100%]  (Sampling) 
## Chain 4 finished in 93.6 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 91.9 seconds.
## Total execution time: 93.7 seconds.
Code
draws_df_hosp_only <- get_draws(fit_hosp_only)
plot_hosp_hosp_only <- get_plot_forecasted_counts(
  draws = draws_df_hosp_only$predicted_counts,
  count_data_eval = hosp_data_eval,
  count_data_eval_col_name = "daily_hosp_admits_for_eval",
  forecast_date = forecast_date
)
plot_hosp_hosp_only
## Warning: Removed 3500 rows containing missing values or values outside the scale range
## (`geom_point()`).
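Finally, a sketch of scoring either fit against the held-out data: here we compute a mean absolute error between hypothetical posterior-median forecasts and hosp_data_eval-style observations. The forecast_medians table is something you would build yourself from draws_df$predicted_counts; all values below are made up.

```r
# Toy evaluation: mean absolute error between posterior-median forecasts
# and held-out admissions. Both tables here are made-up stand-ins.
forecast_medians <- data.frame(
  date = as.Date("2023-12-07") + 0:2,
  median = c(18, 21, 19)
)
eval_obs <- data.frame(
  date = as.Date("2023-12-07") + 0:2,
  daily_hosp_admits_for_eval = c(20, 23, 17)
)
scored <- merge(forecast_medians, eval_obs, by = "date")
mae <- mean(abs(scored$median - scored$daily_hosp_admits_for_eval))
mae  # 2 in this toy example
```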