Getting started with wwinference
Kaitlyn Johnson
2024-06-27
Source: vignettes/wwinference.Rmd
Quick start
In this quick start, we demonstrate using wwinference
to specify and fit a
minimal model using daily COVID-19 hospital admissions from a “global” population and
viral concentrations in wastewater from a few “local” wastewater treatment plants,
which come from subsets of the larger population.
In this context, when we say “global”, we are referring to a larger
population e.g. a state, and when we say “local” we are referring to a smaller
subset of that population, e.g. a municipality within that state.
This is intended to be used as a reference for those
interested in fitting the wwinference
model to their own data.
Packages
In this quick start, we also use the dplyr,
tidybayes,
and ggplot2
packages.
These are installed as dependencies when wwinference
is installed.
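A minimal sketch of loading these packages at the start of a session is shown below. The guard around `library()` is an addition for illustration and is not part of the package workflow; it only reports which of the four packages are not yet installed.

```r
# Packages used in this quick start
pkgs <- c("wwinference", "dplyr", "ggplot2", "tidybayes")

# Check which packages are installed without attaching them
is_installed <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
missing_pkgs <- pkgs[!is_installed]

if (length(missing_pkgs) > 0) {
  message("Install before continuing: ", paste(missing_pkgs, collapse = ", "))
} else {
  # Attach all four packages so their functions can be called unqualified
  invisible(lapply(pkgs, library, character.only = TRUE))
}
```

If the message is printed, the remaining code in this quick start will not run until the listed packages are installed.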
Data
The model expects two types of data: daily counts of hospital admissions from the larger “global” population, and wastewater concentration data from wastewater treatment plants whose catchment areas are contained within the larger “global” population. For this quick start, we will use simulated data, modeled after a hypothetical US state with 4 wastewater treatment plants (also referred to as sites) reporting log-scale viral concentrations of SARS-CoV-2, processed in 3 different labs, covering about 25% of the state’s population. The simulated data contain daily counts of the total hospital admissions in a hypothetical US state from September 1, 2023 to November 29, 2023, and wastewater log genome concentration data from September 1, 2023 to December 1, 2023, with varying sampling frequencies. We will use these data to produce a forecast of COVID-19 hospital admissions as of December 6, 2023. These data are provided as part of the package data.
These data are already in a format that can be used for the wwinference
package.
The hospital admissions data contain:
- a date (column date): the date of the observation, in this case, the date the hospital admissions occurred
- a count (column daily_hosp_admits): the number of hospital admissions observed on that day
- a population size (column state_pop): the population size covered by the hospital admissions data, in this case, the size of the theoretical state.
Additionally, we provide the hosp_data_eval
dataset which contains the
simulated hospital admissions 28 days ahead of the forecast date, which can be
used to evaluate the model.
For the wastewater data, the expected format is a table of observations with the
following columns. The wastewater data should not contain NA
values for days with
missing observations; instead, these days should be excluded:
- a date (column date
): the date the sample was collected
- a site indicator (column site
): the unique identifier for the wastewater treatment plant
that the sample was collected from
- a lab indicator (column lab
): the unique identifier for the lab where the sample was processed
- a concentration (column log_genome_copies_per_ml
): the measured
log genome copies per mL for the given sample. This column should not
contain NA
values, even if the observation for that sample is below the limit of
detection.
- a limit of detection (column log_lod
): the natural log of the limit
of detection of the assay used to process the sample. Units should be the same
units as the concentration column.
- a site population size (column site_pop
): the population size covered by the
wastewater catchment area of that site
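Before passing your own wastewater table to the package, it can help to check it against the column list above. The sketch below uses a hypothetical helper, `check_ww_format()`, which is not part of wwinference; the toy values are taken from the first rows of the example data printed below.

```r
# Hypothetical helper (not part of wwinference): checks that a wastewater
# table has the expected columns and no NA concentrations
check_ww_format <- function(ww,
                            conc_col = "log_genome_copies_per_ml",
                            lod_col = "log_lod") {
  required <- c("date", "site", "lab", conc_col, lod_col, "site_pop")
  missing_cols <- setdiff(required, names(ww))
  if (length(missing_cols) > 0) {
    stop("Missing columns: ", paste(missing_cols, collapse = ", "))
  }
  if (anyNA(ww[[conc_col]])) {
    stop("Concentration column contains NA values")
  }
  invisible(TRUE)
}

# Toy data in the expected format
toy_ww <- data.frame(
  date = as.Date(c("2023-09-01", "2023-09-11")),
  site = c(1, 1),
  lab = c(1, 1),
  log_genome_copies_per_ml = c(7.82, 8.27),
  log_lod = c(5.01, 5.01),
  site_pop = c(400000, 400000)
)
ok <- check_ww_format(toy_ww)
```

The helper stops with an informative error when a column is absent or when the concentration column contains NA, which are the two requirements stated in the list.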
Code
hosp_data <- wwinference::hosp_data
hosp_data_eval <- wwinference::hosp_data_eval
ww_data <- wwinference::ww_data
head(ww_data)
## # A tibble: 6 × 6
## date site lab log_genome_copies_per_ml log_lod site_pop
## <date> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 2023-09-01 1 1 7.82 5.01 400000
## 2 2023-09-11 1 1 8.27 5.01 400000
## 3 2023-09-14 1 1 8.64 5.01 400000
## 4 2023-09-18 1 1 8.87 5.01 400000
## 5 2023-09-21 1 1 8.91 5.01 400000
## 6 2023-09-26 1 1 9.17 5.01 400000
Code
head(hosp_data)
## # A tibble: 6 × 4
## date daily_hosp_admits state_pop location
## <date> <dbl> <dbl> <chr>
## 1 2023-09-01 25 3000000 example state
## 2 2023-09-02 17 3000000 example state
## 3 2023-09-03 25 3000000 example state
## 4 2023-09-04 24 3000000 example state
## 5 2023-09-05 26 3000000 example state
## 6 2023-09-06 25 3000000 example state
Pre-processing
The user will need to provide data that is in a similar format to the package data, as described above. This represents the bare minimum required data for a single location and a single forecast date. We will need to do some pre-processing to add some additional variables that the model will need to be able to apply features such as outlier exclusion and censoring of values below the limit of detection.
Parameters
Get the example parameters from the package, which we will use here. Note that some of these are COVID-specific while others are more general to the model, as indicated in the .toml file.
Code
params <- get_params(
system.file("extdata", "example_params.toml",
package = "wwinference"
)
)
Wastewater data pre-processing
The preprocess_ww_data()
function adds the following variables to the original
dataset. First, it assigns a unique identifier to each
unique combination of lab and site, since this is the unit we will
use for estimating the observation error in the reported measurements.
Second, it adds a column below_lod
which is an indicator of whether the
reported concentration is above or below the limit of detection (LOD). If the
observation is below the LOD, the model will treat this observation as censored.
Third, it adds a column flag_as_ww_outlier
that indicates whether the
measurement is identified as an outlier by our algorithm and the default
thresholds. While the default choice will be to exclude the measurements flagged
as outliers, the user can still choose to include these if they’d like later on.
The user must specify the name of the column containing the
concentration measurements (presumed to be in genome copies per mL) and the
name of the column containing the limit of detection for each measurement. The
function assumes that the original data contains the columns date
, site
,
and lab
, and will return a dataframe with the column names needed to
pass to the downstream model fitting functions.
Code
ww_data_preprocessed <- preprocess_ww_data(
ww_data,
conc_col_name = "log_genome_copies_per_ml",
lod_col_name = "log_lod"
)
Note that this function assumes that there are no missing values in the concentration column. The package expects that observations below the LOD have been replaced with a numeric value below the LOD. If your dataset contains NAs for observations below the LOD, we suggest replacing them with a value below the LOD in upstream pre-processing.
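As a minimal sketch of that upstream step, the base R code below replaces NA concentrations with a value below the LOD and then derives a below-LOD indicator. The choice of half the LOD (that is, `log_lod - log(2)` on the natural log scale) is an assumption made for illustration, as is the `<=` comparison, which mirrors the filter used in the plotting code later in this vignette rather than the internals of preprocess_ww_data().

```r
# Toy data: the second sample was reported as NA because it fell below the LOD
toy_ww <- data.frame(
  log_genome_copies_per_ml = c(7.82, NA, 8.64),
  log_lod = c(5.01, 5.01, 5.01)
)

# Replace NA with half the LOD on the natural scale, i.e. log(LOD / 2)
# (LOD / 2 is an illustrative assumption, not a package requirement)
is_missing <- is.na(toy_ww$log_genome_copies_per_ml)
toy_ww$log_genome_copies_per_ml[is_missing] <-
  toy_ww$log_lod[is_missing] - log(2)

# Indicator for observations at or below the limit of detection
toy_ww$below_lod <- as.integer(
  toy_ww$log_genome_copies_per_ml <= toy_ww$log_lod
)
```

Only NAs that are known to represent below-LOD results should be filled this way; days with no sample should be dropped from the table instead.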
Hospital admissions data pre-processing
The preprocess_count_data()
function standardizes the column names of the
resulting dataframe. The user must specify the name of the column containing
the daily hospital admissions counts and the population size that the hospital
admissions are coming from (in this case, a hypothetical US state). The
function assumes that the original data contains the column date
, and will
return a dataframe with the column names needed to pass to the downstream model
fitting functions.
Code
hosp_data_preprocessed <- preprocess_count_data(
hosp_data,
count_col_name = "daily_hosp_admits",
pop_size_col_name = "state_pop"
)
We’ll make some plots of the data just to make sure it looks like what we’d expect:
Code
ggplot(ww_data_preprocessed) +
geom_point(
aes(
x = date, y = log_genome_copies_per_ml,
color = as.factor(lab_site_name)
),
show.legend = FALSE,
size = 0.5
) +
geom_point(
data = ww_data_preprocessed |> filter(
log_genome_copies_per_ml <= log_lod
),
# Set the colour outside aes() so the points are drawn in red
aes(x = date, y = log_genome_copies_per_ml),
color = "red", show.legend = FALSE, size = 0.5
) +
scale_x_date(
date_breaks = "2 weeks",
labels = scales::date_format("%Y-%m-%d")
) +
geom_hline(aes(yintercept = log_lod), linetype = "dashed") +
facet_wrap(~lab_site_name, scales = "free") +
xlab("") +
ylab("Log genome copies/mL") +
ggtitle("Lab-site level wastewater concentration") +
theme_bw() +
theme(
axis.text.x = element_text(
size = 5, vjust = 1,
hjust = 1, angle = 45
),
axis.title.x = element_text(size = 12),
axis.text.y = element_text(size = 5),
strip.text = element_text(size = 5),
axis.title.y = element_text(size = 12),
plot.title = element_text(
size = 10,
vjust = 0.5, hjust = 0.5
)
)
Code
ggplot(hosp_data_preprocessed) +
# Plot the hospital admissions data that we will evaluate against in white
geom_point(
data = hosp_data_eval, aes(
x = date,
y = daily_hosp_admits_for_eval
),
shape = 21, color = "black", fill = "white"
) +
# Plot the data we will calibrate to
geom_point(aes(x = date, y = count)) +
scale_x_date(
date_breaks = "2 weeks",
labels = scales::date_format("%Y-%m-%d")
) +
xlab("") +
ylab("Daily hospital admissions") +
ggtitle("State level hospital admissions") +
theme_bw() +
theme(
axis.text.x = element_text(
size = 8, vjust = 1,
hjust = 1, angle = 45
),
axis.title.x = element_text(size = 12),
axis.title.y = element_text(size = 12),
plot.title = element_text(
size = 10,
vjust = 0.5, hjust = 0.5
)
)
The closed circles indicate the data the model will be calibrated to, while the open circles indicate data we later observe after the forecast date.
Data exclusion
As an optional additional pre-processing step, the user can decide to exclude
certain data points in the model fit procedure. For example,
we recommend excluding the flagged wastewater concentration outliers. To do so
we will use the indicate_ww_exclusions()
function, which will add the
flagged outliers to the exclude column where indicated.
Code
ww_data_to_fit <- indicate_ww_exclusions(
ww_data_preprocessed,
outlier_col_name = "flag_as_ww_outlier",
remove_outliers = TRUE
)
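To make the effect of this step concrete, here is a toy sketch in base R of folding an outlier flag into an exclusion column. The 0/1 encoding of both columns and the pre-existing `exclude` column are assumptions made for illustration; this is not the package's implementation of indicate_ww_exclusions().

```r
# Toy data: the second observation has been flagged as an outlier
toy <- data.frame(
  log_genome_copies_per_ml = c(7.82, 12.9, 8.64),
  flag_as_ww_outlier = c(0, 1, 0),
  exclude = c(0, 0, 0)
)

remove_outliers <- TRUE
if (remove_outliers) {
  # Mark a row for exclusion if it was already excluded or is a flagged outlier
  toy$exclude <- as.integer(toy$exclude == 1 | toy$flag_as_ww_outlier == 1)
}

# Rows that would inform the model fit under this sketch
toy_to_fit <- toy[toy$exclude == 0, ]
```

Setting `remove_outliers <- FALSE` in this sketch leaves the `exclude` column untouched, which corresponds to keeping the flagged points in the fit.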
Model specification
We will need to set some metadata to facilitate model specification. This includes:
- forecast date (the date we are making a forecast)
- number of days to calibrate the model for
- number of days to forecast beyond the forecast date
- specification of the generation interval, in this case for COVID-19
- specification of the delay from infection to the count data, in this case from infection to COVID-19 hospital admission
Calibration time and forecast time
The calibration time represents the number of days to calibrate the count data
to. This must be less than or equal to the number of rows in hosp_data
. The
forecast horizon represents the number of days from the forecast date to
generate forecasted hospital admissions for. Typically, the hospital admissions
data will not be complete up until the forecast date, and we will refer to the
time between the last hospital admissions data point and the forecast date as
the nowcast time. The model will “forecast” this period, in addition to the
specified forecast horizon.
Code
forecast_date <- "2023-12-06"
calibration_time <- 90
forecast_horizon <- 28
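The relationship between these quantities can be checked with a little date arithmetic. The sketch below rebuilds the date range of the example hospital admissions data described in the Data section (September 1 to November 29, 2023) rather than reading the package data, so it runs on its own; the variable names other than those set above are illustrative.

```r
forecast_date <- as.Date("2023-12-06")

# Dates covered by the example hospital admissions data
hosp_dates <- seq(as.Date("2023-09-01"), as.Date("2023-11-29"), by = "day")
n_obs <- length(hosp_dates)

calibration_time <- 90
forecast_horizon <- 28

# The calibration time cannot exceed the number of observed days
stopifnot(calibration_time <= n_obs)

# Nowcast time: days between the last observation and the forecast date
nowcast_time <- as.integer(forecast_date - max(hosp_dates))

# Total number of days the model produces estimates for
total_days <- calibration_time + nowcast_time + forecast_horizon
```

Here `n_obs` is 90, `nowcast_time` is 7 and `total_days` is 125, which agrees with the 125 days reported when the posterior draws are printed later in this vignette.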
Delay distributions
We will pass in probability mass functions (PMFs) that are specific to COVID, and to the delay from infections to hospital admissions, the count data we are using to fit the model. If using a different pathogen or a different count dataset, these PMFs need to be replaced. We provide them as package data here. The model expects that these are discrete daily PMFs.
Additionally, the model requires specifying a delay distribution for the infection feedback term, which essentially describes the delay at which high incident infections result in negative feedback on future infections (due to susceptibility, behavior changes, policies to reduce transmission, etc.). By default, we set this to the generation interval, but it can be replaced with any discrete daily PMF.
Code
generation_interval <- wwinference::default_covid_gi
inf_to_hosp <- wwinference::default_covid_inf_to_hosp
# Assign infection feedback equal to the generation interval
infection_feedback_pmf <- generation_interval
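If you substitute your own delay distributions, a quick sanity check that each one is a valid discrete PMF may be worthwhile. The helper `is_valid_pmf()` below is hypothetical (it is not part of wwinference), and the weights are made up for illustration.

```r
# Hypothetical helper (not part of wwinference): TRUE if the vector is
# numeric, has no NA, is non-negative, and sums to 1 within a tolerance
is_valid_pmf <- function(pmf, tol = 1e-6) {
  is.numeric(pmf) && !anyNA(pmf) && all(pmf >= 0) && abs(sum(pmf) - 1) < tol
}

# Made-up daily weights, normalised so that they sum to 1
raw_weights <- c(0, 2, 5, 4, 2, 1)
toy_pmf <- raw_weights / sum(raw_weights)

valid_after_normalising <- is_valid_pmf(toy_pmf)
valid_before_normalising <- is_valid_pmf(raw_weights)
```

The same check could be applied to `generation_interval` and `inf_to_hosp` once the package data are loaded.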
We will pass these to the get_model_spec()
function, whose output is then passed to wwinference()
along with the other parameters specified above.
Precompiling the model
As wwinference
uses cmdstan
to fit its models, it is necessary to first
compile the model. This can be done using the compile_model()
function.
Code
model <- wwinference::compile_model()
## Using model source file:
## /home/runner/work/_temp/Library/wwinference/stan/wwinference.stan
## Using include paths: /home/runner/work/_temp/Library/wwinference/stan
## Model compiled or loaded successfully; model executable binary located at:
## /tmp/RtmpQe00iW/wwinference
Fitting the model
We’re now ready to fit the model using the “No-U-Turn Sampler Markov chain
Monte Carlo” method. This is a type of Hamiltonian Monte Carlo (HMC) algorithm
and is the core fitting method used by cmdstan
. The user can adjust the MCMC
settings (see the documentation for get_mcmc_options()
),
however this vignette will use
the default parameter settings, which include running 4 parallel chains with
750 warm up iterations, 500 sampling iterations for each chain, a target average
acceptance probability of 0.95 and a maximum tree depth of 12. The default is
not to set the seed for the random number generator for the MCMC model runs
(so results will differ each time the model is run), but for
reproducibility we will set the seed of the Stan PRNG to 123
in this vignette.
When applying the model to real data, experimenting with these MCMC settings may make it possible to achieve improved model convergence and/or faster model fitting times. See the Stan User’s Guide for an introduction to No-U-Turn sampler convergence diagnostics and configuration parameters.
We also pass our preprocessed datasets (ww_data_to_fit
and
hosp_data_preprocessed
), specify our model using get_model_spec()
,
set the MCMC settings by passing a list of arguments to fit_opts
that will be passed to the cmdstanr::sample()
function, and pass in our
pre-compiled model (model
) to wwinference()
where they are combined and
used to fit the model.
Code
ww_fit <- wwinference(
ww_data = ww_data_to_fit,
count_data = hosp_data_preprocessed,
forecast_date = forecast_date,
calibration_time = calibration_time,
forecast_horizon = forecast_horizon,
model_spec = get_model_spec(
generation_interval = generation_interval,
inf_to_count_delay = inf_to_hosp,
infection_feedback_pmf = infection_feedback_pmf,
params = params
),
fit_opts = list(seed = 123),
compiled_model = model
)
## Running MCMC with 4 parallel chains...
##
## Chain 1 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 3 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 4 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 1 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 3 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 2 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 4 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 1 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 3 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 2 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 4 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 1 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 3 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 2 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 4 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 1 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 3 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 2 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 4 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 1 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 3 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 2 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 4 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 1 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 3 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 2 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 4 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 3 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 1 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 4 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 2 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 1 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 3 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 4 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 2 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 1 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 3 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 4 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 2 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 1 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 3 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 4 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 2 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 1 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 3 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 4 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 2 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 1 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 3 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 3 finished in 323.0 seconds.
## Chain 2 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 2 finished in 327.4 seconds.
## Chain 4 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 4 finished in 328.8 seconds.
## Chain 1 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 1 finished in 330.2 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 327.3 seconds.
## Total execution time: 330.3 seconds.
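The iteration counts in the log above follow directly from the default settings. The sketch below writes those defaults out as a list using the argument names of cmdstanr's `$sample()` method; this vignette only passes `seed` through `fit_opts`, so whether wwinference() forwards the other entries unchanged is an assumption.

```r
# Default settings described above, written out explicitly
fit_opts <- list(
  seed = 123,
  iter_warmup = 750,
  iter_sampling = 500,
  adapt_delta = 0.95,
  max_treedepth = 12
)
n_chains <- 4

# Iterations reported per chain in the sampler log (warmup + sampling)
iters_per_chain <- fit_opts$iter_warmup + fit_opts$iter_sampling

# Posterior draws retained across all chains (sampling iterations only)
total_draws <- n_chains * fit_opts$iter_sampling
```

`iters_per_chain` is 1250, matching the "1250 / 1250" shown for each chain, and `total_draws` is 2000.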
The wwinference_fit
object
The wwinference()
function returns a wwinference_fit
object which includes
the underlying
CmdStanModel
object
(fit
), a list of the two sources of input
data (raw_input_data
), the list of the arguments passed to stan
(stan_data_list
), and the list of the MCMC options (fit_opts
) passed to
stan. We show how to generate downstream elements from a wwinference_fit
object.
wwinference_fit
objects currently have the following methods available:
Code
methods(class = "wwinference_fit")
## [1] get_draws get_model_diagnostic_flags
## [3] print summary
## see '?methods' for accessing help and source code
The print
and summary
methods can provide some information about the model. In particular, the summary
method is a wrapper for cmdstanr::summary()
:
Code
print(ww_fit)
## wwinference_fit object
## N of WW sites : 4
## N of unique lab-site pairs : 5
## Total population : 3000000
## N of weeks : 18
## --------------------
## For more details, you can access the following:
## - `$fit` for the CmdStan object
## - `$raw_input_data` for the input data
## - `$stan_data_list` for the stan data arguments
## - `$fit_opts` for the fitting options
Code
summary(ww_fit)
## # A tibble: 4,282 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -430. -430. 10.0 9.98 -447. -415. 1.00 622.
## 2 w[1] 0.171 0.180 0.865 0.844 -1.26 1.57 1.00 2052.
## 3 w[2] 0.452 0.465 0.790 0.787 -0.835 1.76 1.00 1895.
## 4 w[3] 0.354 0.358 0.801 0.802 -0.933 1.70 1.00 1964.
## 5 w[4] 0.0112 0.0284 0.797 0.791 -1.28 1.32 0.999 1723.
## 6 w[5] -1.22 -1.22 0.737 0.765 -2.43 -0.0315 1.00 1644.
## 7 w[6] -1.62 -1.62 0.778 0.777 -2.91 -0.365 1.00 1509.
## 8 w[7] -1.53 -1.51 0.796 0.797 -2.85 -0.256 1.00 2017.
## 9 w[8] -1.34 -1.33 0.843 0.886 -2.68 0.0203 1.00 1831.
## 10 w[9] -0.172 -0.156 0.822 0.806 -1.54 1.19 1.00 1782.
## # ℹ 4,272 more rows
## # ℹ 1 more variable: ess_tail <dbl>
Extracting the posterior predictions
Working with the posterior predictions alongside the input data can be useful to check that your model is fitting the data well and that the nowcasted/forecast quantities look reasonable.
We can use the get_draws()
function to generate dataframes that contain
the posterior draws of the estimated, nowcasted, and forecasted quantities,
joined to the relevant data.
We can generate this directly on the output of wwinference()
using:
Code
draws <- get_draws(ww_fit)
print(draws)
## Draws from the model featuring 2000 draws across 125 days in the following datasets:
## - `$predicted_counts` with 250000 rows
## - `$predicted_ww` with 1250000 rows across 5 sites.
## - `$global_rt` with 250000 rows
## - `$subpop_rt` with 1250000 rows across 5 subpopulations
## You can use $ to access the datasets.
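The row counts in this printout can be reproduced by hand, which is a useful check that the draws object has the shape you expect. The numbers below are read off the output above; the variable names are illustrative.

```r
n_draws <- 2000   # posterior draws reported in the printout
n_days <- 125     # calibration, nowcast and forecast days combined
n_groups <- 5     # lab-site pairs (predicted_ww) or subpopulations (subpop_rt)

# One row per draw per day for the global quantities
rows_global <- n_draws * n_days

# One row per draw per day per group for the wastewater and subpopulation tables
rows_grouped <- rows_global * n_groups
```

`rows_global` is 250000, matching `$predicted_counts` and `$global_rt`, and `rows_grouped` is 1250000, matching `$predicted_ww` and `$subpop_rt`.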
Note that by default the get_draws()
function will return a list of class wwinference_fit_draws
which contains separate dataframes of the posterior draws for predicted counts ("predicted_counts"
),
wastewater concentrations ("predicted_ww"
), global \(\mathcal{R}(t)\) ("global_rt"
) estimates, and
subpopulation-level \(\mathcal{R}(t)\) estimates ("subpop_rt"
).
To examine a particular variable (e.g. "predicted_counts"
for posterior
predicted hospital admissions in this case), access the corresponding tibble using the $
operator.
You can also specify which outputs to return using the what
argument.
Code
head(draws$predicted_counts)
## # A tibble: 6 × 5
## date draw observed_value pred_value total_pop
## <date> <int> <dbl> <dbl> <dbl>
## 1 2023-09-01 1 25 37 3000000
## 2 2023-09-01 2 25 21 3000000
## 3 2023-09-01 3 25 23 3000000
## 4 2023-09-01 4 25 25 3000000
## 5 2023-09-01 5 25 22 3000000
## 6 2023-09-01 6 25 20 3000000
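Because the table holds one row per draw per date, it is often condensed into a median and an interval for each date before plotting or reporting. The base R sketch below does this on toy data shaped like the output above; the first six `pred_value` entries are copied from that output, while the second date and the helper `summarise_one_date()` are made up for illustration.

```r
# Toy draws with the same key columns as the predicted counts table
toy_draws <- data.frame(
  date = rep(as.Date(c("2023-09-01", "2023-09-02")), each = 6),
  draw = rep(1:6, times = 2),
  pred_value = c(37, 21, 23, 25, 22, 20, 30, 18, 26, 24, 21, 27)
)

# Summarise one date across draws: median and central 95% interval
summarise_one_date <- function(d) {
  data.frame(
    date = d$date[1],
    median = median(d$pred_value),
    lower = quantile(d$pred_value, 0.025, names = FALSE),
    upper = quantile(d$pred_value, 0.975, names = FALSE)
  )
}

# Apply the summary to each date and stack the results
pred_summary <- do.call(
  rbind,
  lapply(split(toy_draws, toy_draws$date), summarise_one_date)
)
```

With only six draws per date the interval is of course very rough; on the full draws the same summary would use all 2000 draws per date.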
Using explicitly passed arguments rather than S3 methods
Rather than using S3 methods supplied for wwinference()
, the elements in the
wwinference_fit
object can also be used directly to create this dataframe.
This is demonstrated below:
Code
draws_explicit <- get_draws(
x = ww_fit$raw_input_data$input_ww_data,
count_data = ww_fit$raw_input_data$input_count_data,
date_time_spine = ww_fit$raw_input_data$date_time_spine,
site_subpop_spine = ww_fit$raw_input_data$site_subpop_spine,
lab_site_subpop_spine = ww_fit$raw_input_data$lab_site_subpop_spine,
stan_data_list = ww_fit$stan_data_list,
fit_obj = ww_fit$fit
)
Plotting the outputs
We can create plots of the outputs using the corresponding dataframes in the draws
object and the plotting wrapper functions. Note that by default, these plots
will not include outliers that were flagged for exclusion. Data points
that are below the LOD will be plotted in blue.
Code
plot_hosp_with_eval <- get_plot_forecasted_counts(
draws = draws$predicted_counts,
forecast_date = forecast_date,
count_data_eval = hosp_data_eval,
count_data_eval_col_name = "daily_hosp_admits_for_eval"
)
plot_hosp_with_eval
## Warning: Removed 3500 rows containing missing values or values outside the scale range
## (`geom_point()`).
Code
plot_ww <- get_plot_ww_conc(draws$predicted_ww, forecast_date)
plot_ww
## Warning: Removed 53100 rows containing missing values or values outside the scale range
## (`geom_point()`).
Code
plot_state_rt <- get_plot_global_rt(draws$global_rt, forecast_date)
plot_state_rt
Code
plot_subpop_rt <- get_plot_subpop_rt(draws$subpop_rt, forecast_date)
plot_subpop_rt
To plot the forecasts without the retrospectively observed hospital admissions, simply don’t pass them to the plotting function.
Code
plot_hosp <- get_plot_forecasted_counts(
draws = draws$predicted_counts,
forecast_date = forecast_date
)
plot_hosp
## Warning: Removed 3500 rows containing missing values or values outside the scale range
## (`geom_point()`).
The previous plots can also be produced by calling the plot
method of wwinference_fit_draws
using the what
argument:
Code
plot(
x = draws,
what = "predicted_counts",
count_data_eval = hosp_data_eval,
count_data_eval_col_name = "daily_hosp_admits_for_eval",
forecast_date = forecast_date
)
## Warning: Removed 3500 rows containing missing values or values outside the scale range
## (`geom_point()`).
Code
plot(draws, what = "predicted_ww", forecast_date = forecast_date)
## Warning: Removed 53100 rows containing missing values or values outside the scale range
## (`geom_point()`).
Code
plot(draws, what = "global_rt", forecast_date = forecast_date)
Code
plot(draws, what = "subpop_rt", forecast_date = forecast_date)
Diagnostics
We strongly recommend running diagnostics as a post-processing step on the model outputs.
This can be done by passing the output of
wwinference()
into the get_model_diagnostic_flags()
, summary_diagnostics()
and parameter_diagnostics()
functions.
get_model_diagnostic_flags()
will print out a table of any flags; if any of
these are TRUE, it will print out a warning.
We have set default thresholds on the model diagnostics for production-level
runs; we recommend adjusting these as needed (see below).
To further troubleshoot, you can look at
the summary diagnostics using the summary_diagnostics()
function
and the diagnostics of the individual parameters using
the parameter_diagnostics()
function.
For further information on troubleshooting the model diagnostics, we recommend the bayesplot tutorial (https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html).
You can access the CmdStan object directly using ww_fit$fit$result.
Code
convergence_flag_df <- get_model_diagnostic_flags(ww_fit)
print(convergence_flag_df)
## flag_high_max_treedepth flag_too_many_divergences flag_high_rhat
## 1 FALSE FALSE FALSE
## flag_low_embfi
## 1 FALSE
Code
summary_diagnostics(ww_fit)
## $num_divergent
## [1] 0 0 0 0
##
## $num_max_treedepth
## [1] 0 0 0 0
##
## $ebfmi
## [1] 1.0554791 0.9719766 0.8871317 0.8494908
Code
param_diagnostics <- parameter_diagnostics(ww_fit)
head(param_diagnostics)
## # A tibble: 6 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -430. -430. 10.0 9.98 -447. -415. 1.00 622.
## 2 w[1] 0.171 0.180 0.865 0.844 -1.26 1.57 1.00 2052.
## 3 w[2] 0.452 0.465 0.790 0.787 -0.835 1.76 1.00 1895.
## 4 w[3] 0.354 0.358 0.801 0.802 -0.933 1.70 1.00 1964.
## 5 w[4] 0.0112 0.0284 0.797 0.791 -1.28 1.32 0.999 1723.
## 6 w[5] -1.22 -1.22 0.737 0.765 -2.43 -0.0315 1.00 1644.
## # ℹ 1 more variable: ess_tail <dbl>
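One way to use this table is to screen for parameters with a high R-hat or a low effective sample size. The sketch below does so on a small data frame that copies the `rhat` and `ess_bulk` values printed above; the 1.05 cutoff is the `rhat_tolerance` value used later in this vignette, and the object names are illustrative.

```r
# Values copied from the first six rows of the parameter diagnostics table
param_tbl <- data.frame(
  variable = c("lp__", "w[1]", "w[2]", "w[3]", "w[4]", "w[5]"),
  rhat = c(1.00, 1.00, 1.00, 1.00, 0.999, 1.00),
  ess_bulk = c(622, 2052, 1895, 1964, 1723, 1644),
  stringsAsFactors = FALSE
)

# Parameters whose R-hat exceeds the tolerance
rhat_tolerance <- 1.05
high_rhat <- param_tbl[param_tbl$rhat > rhat_tolerance, ]

# The three parameters with the smallest bulk effective sample size
lowest_ess <- param_tbl[order(param_tbl$ess_bulk), ][1:3, ]
```

In this subset no parameter exceeds the R-hat tolerance, and `lp__` has the smallest bulk effective sample size.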
This can also be done explicitly by passing the elements of the
wwinference_fit
object into the custom functions we built or by directly calling
CmdStan
’s built-in functions.
Start by passing the stan fit object (ww_fit$fit$result
) into the
get_model_diagnostic_flags()
and adjusting the thresholds if desired.
Then, we recommend looking at the diagnostics summary provided by CmdStan
,
which we had wrapped into the parameter_diagnostics()
call above. Lastly,
we recommend looking at the individual model parameters provided by CmdStan
to identify which components of the model might be driving the convergence
issues.
For further information on troubleshooting the model diagnostics, we recommend the bayesplot tutorial.
Code
convergence_flag_df <- get_model_diagnostic_flags(
x = ww_fit$fit$result,
ebmfi_tolerance = 0.2,
divergences_tolerance = 0.01,
frac_high_rhat_tolerance = 0.05,
rhat_tolerance = 1.05,
max_tree_depth_tol = 0.01
)
# Get the tables using the CmdStan functions via wrappers
summary(ww_fit)
## # A tibble: 4,282 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -430. -430. 10.0 9.98 -447. -415. 1.00 622.
## 2 w[1] 0.171 0.180 0.865 0.844 -1.26 1.57 1.00 2052.
## 3 w[2] 0.452 0.465 0.790 0.787 -0.835 1.76 1.00 1895.
## 4 w[3] 0.354 0.358 0.801 0.802 -0.933 1.70 1.00 1964.
## 5 w[4] 0.0112 0.0284 0.797 0.791 -1.28 1.32 0.999 1723.
## 6 w[5] -1.22 -1.22 0.737 0.765 -2.43 -0.0315 1.00 1644.
## 7 w[6] -1.62 -1.62 0.778 0.777 -2.91 -0.365 1.00 1509.
## 8 w[7] -1.53 -1.51 0.796 0.797 -2.85 -0.256 1.00 2017.
## 9 w[8] -1.34 -1.33 0.843 0.886 -2.68 0.0203 1.00 1831.
## 10 w[9] -0.172 -0.156 0.822 0.806 -1.54 1.19 1.00 1782.
## # ℹ 4,272 more rows
## # ℹ 1 more variable: ess_tail <dbl>
Code
parameter_diagnostics(ww_fit, quiet = TRUE)
## # A tibble: 4,282 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -430. -430. 10.0 9.98 -447. -415. 1.00 622.
## 2 w[1] 0.171 0.180 0.865 0.844 -1.26 1.57 1.00 2052.
## 3 w[2] 0.452 0.465 0.790 0.787 -0.835 1.76 1.00 1895.
## 4 w[3] 0.354 0.358 0.801 0.802 -0.933 1.70 1.00 1964.
## 5 w[4] 0.0112 0.0284 0.797 0.791 -1.28 1.32 0.999 1723.
## 6 w[5] -1.22 -1.22 0.737 0.765 -2.43 -0.0315 1.00 1644.
## 7 w[6] -1.62 -1.62 0.778 0.777 -2.91 -0.365 1.00 1509.
## 8 w[7] -1.53 -1.51 0.796 0.797 -2.85 -0.256 1.00 2017.
## 9 w[8] -1.34 -1.33 0.843 0.886 -2.68 0.0203 1.00 1831.
## 10 w[9] -0.172 -0.156 0.822 0.806 -1.54 1.19 1.00 1782.
## # ℹ 4,272 more rows
## # ℹ 1 more variable: ess_tail <dbl>
Code
head(convergence_flag_df)
## flag_high_max_treedepth flag_too_many_divergences flag_high_rhat
## 1 FALSE FALSE FALSE
## flag_low_embfi
## 1 FALSE
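To build intuition for how such thresholds relate to the summary diagnostics, the sketch below compares the values printed by summary_diagnostics() against the tolerances used above. The exact comparison rules (for example, `<=` versus `<`, and dividing by the total number of post-warmup draws) are assumptions made for illustration and may differ from what get_model_diagnostic_flags() does internally.

```r
# Values copied from the summary_diagnostics() output above
ebfmi <- c(1.0554791, 0.9719766, 0.8871317, 0.8494908)
num_divergent <- c(0, 0, 0, 0)
num_max_treedepth <- c(0, 0, 0, 0)

# Post-warmup draws: 4 chains with 500 sampling iterations each
n_post_warmup <- 4 * 500

# Tolerances as passed in the call above
ebfmi_tol <- 0.2
divergences_tol <- 0.01
max_tree_depth_tol <- 0.01

# Flag if any chain has low E-BFMI, or if the fraction of divergent or
# max-treedepth transitions reaches the tolerance
flag_low_ebfmi <- any(ebfmi <= ebfmi_tol)
flag_too_many_divergences <-
  sum(num_divergent) / n_post_warmup >= divergences_tol
flag_high_max_treedepth <-
  sum(num_max_treedepth) / n_post_warmup >= max_tree_depth_tol
```

All three flags are FALSE under these assumptions, which is consistent with the flag table printed above.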
Fit to only hospital admissions data
The package also has functionality to fit the model without wastewater data. This can be useful when comparing the impact the wastewater data has on the forecast, or as part of a pipeline where one might choose to rely on the admissions-only model if there are convergence or known data issues with the wastewater data.
Code
fit_hosp_only <- wwinference(
ww_data = ww_data_to_fit,
count_data = hosp_data_preprocessed,
forecast_date = forecast_date,
calibration_time = calibration_time,
forecast_horizon = forecast_horizon,
model_spec = get_model_spec(
generation_interval = generation_interval,
inf_to_count_delay = inf_to_hosp,
infection_feedback_pmf = infection_feedback_pmf,
include_ww = FALSE,
params = params
),
fit_opts = list(seed = 123),
compiled_model = model
)
## Running MCMC with 4 parallel chains...
##
## Chain 1 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1250 [ 0%] (Warmup)
## Chain 4 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 3 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 1 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 2 Iteration: 100 / 1250 [ 8%] (Warmup)
## Chain 4 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 3 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 3 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 4 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 2 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 1 Iteration: 200 / 1250 [ 16%] (Warmup)
## Chain 3 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 4 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 2 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 3 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 1 Iteration: 300 / 1250 [ 24%] (Warmup)
## Chain 4 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 2 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 3 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 4 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 1 Iteration: 400 / 1250 [ 32%] (Warmup)
## Chain 2 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 3 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 1 Iteration: 500 / 1250 [ 40%] (Warmup)
## Chain 4 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 2 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 3 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 4 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 1 Iteration: 600 / 1250 [ 48%] (Warmup)
## Chain 2 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 3 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 2 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 1 Iteration: 700 / 1250 [ 56%] (Warmup)
## Chain 3 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 4 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 2 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 1 Iteration: 751 / 1250 [ 60%] (Sampling)
## Chain 3 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 4 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 2 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 3 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 1 Iteration: 850 / 1250 [ 68%] (Sampling)
## Chain 2 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 3 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 3 finished in 69.6 seconds.
## Chain 1 Iteration: 950 / 1250 [ 76%] (Sampling)
## Chain 4 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 2 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 2 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 2 finished in 72.4 seconds.
## Chain 1 Iteration: 1050 / 1250 [ 84%] (Sampling)
## Chain 4 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 1 Iteration: 1150 / 1250 [ 92%] (Sampling)
## Chain 1 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 1 finished in 75.6 seconds.
## Chain 4 Iteration: 1250 / 1250 [100%] (Sampling)
## Chain 4 finished in 75.7 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 73.3 seconds.
## Total execution time: 75.9 seconds.
Code
draws_hosp_only <- get_draws(fit_hosp_only)
## Warning in get_draws.data.frame(x = x$raw_input_data$input_ww_data, count_data = x$raw_input_data$input_count_data, : Model wasn't fit to wastewater data. Predicted wastewater concentrations and subpopulation R(t)s
## estimates will not be returned in the `wwinference_fit_draws` object
Code
plot(draws_hosp_only,
what = "predicted_counts",
count_data_eval = hosp_data_eval,
count_data_eval_col_name = "daily_hosp_admits_for_eval",
forecast_date = forecast_date
)
## Warning: Removed 3500 rows containing missing values or values outside the scale range
## (`geom_point()`).
Code
plot(draws_hosp_only, what = "global_rt", forecast_date = forecast_date)