Example: Early COVID-19 case data in South Korea
In this example we use EpiAware functionality to largely recreate an epidemiological model presented in On the derivation of the renewal equation from an age-dependent branching process: an epidemic modelling perspective, Mishra et al (2020). Mishra et al consider test-confirmed cases of COVID-19 in South Korea between January and July 2020. The components of the epidemiological model they consider are:
The time-varying reproductive number, modelled as an AR(2) process on the log scale, \(\log R_t \sim \text{AR(2)}\).
The latent infection (\(I_t\)) generating process is a renewal model (note that we leave out external infections in this note):
$$I_t = R_t \sum_{s\geq 1} I_{t-s} g_s.$$
The discrete generation interval \(g_t\) is a daily discretisation of an estimated serial interval distribution for SARS-CoV-2, a Gamma distribution with shape 6.5 and scale 0.62:
$$G \sim \text{Gamma}(6.5,0.62).$$
Observed cases \(C_t\) are distributed around latent infections with negative binomial errors:
$$C_t \sim \text{NegBin}(\text{mean} = I_t,~ \text{overdispersion} = \phi).$$
In the examples below we largely recreate the Mishra et al model, while emphasising that each component of the overall epidemiological model is itself a standalone model that can be sampled from.
Dependencies for this notebook
Now we import these dependencies into scope. If you evaluate these lines in the REPL, it will offer to install any missing dependencies. Alternatively, you can add them to your active environment using Pkg.add.
using EpiAware
using Turing, DynamicPPL #Underlying Turing ecosystem packages to interact with models
using Distributions, Statistics #Statistics packages
using CSV, DataFramesMeta #Data wrangling
using CairoMakie, PairPlots, TimeSeries #Plotting backend
using ReverseDiff #Automatic differentiation backend
begin #Date utility and set Random seed
using Dates
using Random
Random.seed!(1)
end
Load early SARS-CoV-2 case data for South Korea
First, we make sure that the data we want to analyse is in scope by downloading it from the copy saved in the EpiAware repository.
NB: The case data is curated by the covidregionaldata package. We accessed the South Korean case data using a short R script. It is possible to call R directly from a Julia session using the RCall.jl package, but we do not do this here to reduce the number of dependencies required to run this notebook.
url = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/main/EpiAware/docs/src/showcase/replications/mishra-2020/south_korea_data.csv2"
"https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/main/EpiAware/docs/src/showcase/replications/mishra-2020/south_korea_data.csv2"
data = CSV.read(download(url), DataFrame)
| Row | Column1 | date | cases_new | deaths_new |
|---|---|---|---|---|
| 1 | 1 | 2019-12-31 | 0 | 0 |
| 2 | 2 | 2020-01-01 | 0 | 0 |
| 3 | 3 | 2020-01-02 | 0 | 0 |
| 4 | 4 | 2020-01-03 | 0 | 0 |
| 5 | 5 | 2020-01-04 | 0 | 0 |
| 6 | 6 | 2020-01-05 | 0 | 0 |
| 7 | 7 | 2020-01-06 | 0 | 0 |
| 8 | 8 | 2020-01-07 | 0 | 0 |
| 9 | 9 | 2020-01-08 | 0 | 0 |
| 10 | 10 | 2020-01-09 | 0 | 0 |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 214 | 214 | 2020-07-31 | 36 | 1 |
Time-varying reproduction number as an AbstractLatentModel type
EpiAware exposes an AbstractLatentModel abstract type, which groups stochastic processes that can be interpreted as generating time-varying parameters or quantities of interest. We call these latent process models.
In the Mishra et al model the log of the time-varying reproductive number, \(Z_t = \log R_t\), is assumed to evolve as an auto-regressive process, AR(2):
$$\begin{align} R_t &= \exp Z_t, \\ Z_t &= \rho_1 Z_{t-1} + \rho_2 Z_{t-2} + \epsilon_t, \\ \epsilon_t &\sim \text{Normal}(0, \sigma^*). \end{align}$$
where \(\rho_1,\rho_2\) are the parameters of the AR process and \(\epsilon_t\) is a white noise process with standard deviation \(\sigma^*\).
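As a minimal sketch of what this latent process does (plain Julia, not EpiAware's implementation), the AR(2) recursion can be simulated directly. The function name `simulate_log_rt` and the parameter values (\(\rho_1 = 0.8\), \(\rho_2 = 0.1\), \(\sigma^* = 0.5\), \(Z_1 = Z_2 = -1\)) are illustrative assumptions, chosen to sit near the prior centres used below.

```julia
using Random

# Illustrative sketch (not EpiAware's implementation): simulate the log-scale
# AR(2) process Z_t = ρ1 Z_{t-1} + ρ2 Z_{t-2} + ϵ_t and return Z_1, …, Z_n.
function simulate_log_rt(ρ1, ρ2, σstar, z1, z2, n; rng = Random.default_rng())
    @assert n >= 2 "need at least the two initial values"
    Z = Vector{Float64}(undef, n)
    Z[1], Z[2] = z1, z2                        # initial values Z_1, Z_2
    for t in 3:n
        ϵ = σstar * randn(rng)                 # white noise ϵ_t ~ Normal(0, σ*)
        Z[t] = ρ1 * Z[t-1] + ρ2 * Z[t-2] + ϵ   # AR(2) recursion
    end
    return Z
end

# Assumed illustrative values, close to the prior centres used in this notebook.
Z = simulate_log_rt(0.8, 0.1, 0.5, -1.0, -1.0, 50; rng = Xoshiro(1))
R = exp.(Z)                                    # R_t = exp(Z_t)
```

Setting `σstar = 0` removes the noise and leaves the deterministic decay of the initial values towards zero, which is a quick way to see the role of \(\rho_1, \rho_2\) on their own.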
In EpiAware we determine the behaviour of a latent process by choosing a concrete subtype (i.e. a struct) of AbstractLatentModel, which has fields that set the priors of the various parameters required for the latent process.
The AR process has the struct AR <: AbstractLatentModel. The user can supply the priors for \(\rho_1,\rho_2\) in the field damp_priors, for \(\sigma^*\) in the field std_prior, and for the initial values \(Z_1, Z_2\) in the field init_priors.
We choose priors based on Mishra et al using the Distributions.jl interface to probability distributions. Note that, as in Mishra et al, we restrict the AR parameters to \([0,1]\) using the truncated function.
In Mishra et al the standard deviation \(\sigma\) of the stationary distribution of \(Z_t\) is given a standard normal prior conditioned to be positive, \(\sigma \sim \mathcal{N}^+(0,1)\). The value \(\sigma^*\) was then determined as a nonlinear function of the sampled \(\sigma, ~\rho_1, ~\rho_2\) values. Since Mishra et al give sharply informative priors for \(\rho_1,~\rho_2\) (see below), we simplify by calculating \(\sigma^*\) at the prior mode of \(\rho_1,~\rho_2\). This results in a \(\sigma^* \sim \mathcal{N}^+(0, 0.5)\) prior.
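This simplification can be checked with the standard stationary-variance formula for an AR(2) process, assuming that this is the relation between \(\sigma\) and \(\sigma^*\) that Mishra et al use. At the prior modes \(\rho_1 = 0.8, \rho_2 = 0.1\) the stationary standard deviation is about 2.19 times \(\sigma^*\), so \(\sigma^* \approx 0.456\,\sigma\). Scaling a \(\mathcal{N}^+(0,1)\) variable by 0.456 gives \(\mathcal{N}^+(0, 0.456)\), which rounds to the HalfNormal(0.5) prior used below. The helper names are hypothetical.

```julia
# Stationary standard deviation of an AR(2) process with innovation sd σstar,
# using the standard result
#   Var(Z) = σ*^2 (1 - ρ2) / ((1 + ρ2) ((1 - ρ2)^2 - ρ1^2)),
# which holds only inside the stationary region |ρ2| < 1, |ρ1| < 1 - ρ2.
function stationary_sd(ρ1, ρ2, σstar)
    @assert abs(ρ2) < 1 && abs(ρ1) < 1 - ρ2 "(ρ1, ρ2) outside the stationary region"
    return σstar * sqrt((1 - ρ2) / ((1 + ρ2) * ((1 - ρ2)^2 - ρ1^2)))
end

# Factor k with σ* = k σ: the innovation sd per unit of stationary sd.
innovation_scale(ρ1, ρ2) = 1 / stationary_sd(ρ1, ρ2, 1.0)

k = innovation_scale(0.8, 0.1)   # prior modes assumed: ρ1 = 0.8, ρ2 = 0.1
println("σ* ≈ ", round(k; digits = 3), " σ")   # prints "σ* ≈ 0.456 σ"
```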
ar = AR(
    damp_priors = reverse([truncated(Normal(0.8, 0.05), 0, 1),
        truncated(Normal(0.1, 0.05), 0, 1)]),
    std_prior = HalfNormal(0.5),
    init_priors = [Normal(-1.0, 0.1), Normal(-1.0, 0.5)]
)
Turing model interface to the AR process
As mentioned above, we can use this instance of the AR latent model to construct a Turing model object which implements the probabilistic behaviour determined by ar. We do this with the constructor function exposed by EpiAware, generate_latent, which combines an AbstractLatentModel subtype struct with the number of time steps for which we want to generate the latent process.
As a refresher, recall that a Turing.Model object has the following properties:
The model object parameters are sampleable using rand; that is, we can generate parameters from the specified priors, e.g. θ = rand(mdl).
The model object is generative as a callable; that is, we can sample instances of \(Z_t\), e.g. Z_t = mdl().
The model object can construct new model objects by conditioning parameters using the DynamicPPL.jl syntax, e.g. conditional_mdl = mdl | (σ_AR = 1.0, ).
As a concrete example we create a model object for the AR(2) process we specified above for 50 time steps:
ar_mdl = generate_latent(ar, 50)
Ultimately, this will be only one component of the full epidemiological model. However, it is useful to visualise its probabilistic behaviour for model diagnostics and prior predictive checking.
We can spaghetti plot generative samples from the AR(2) process with the priors specified above.
plt_ar_sample = let
    n_samples = 100
    ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
        ar_mdl() .|> exp # Sample a Z_t trajectory and transform it to R_t = exp(Z_t)
    end
    fig = Figure()
    ax = Axis(fig[1, 1];
        yscale = log10,
        ylabel = "Time varying Rₜ",
        title = "$(n_samples) draws from the prior Rₜ model"
    )
    for col in eachcol(ar_mdl_samples)
        lines!(ax, col, color = (:grey, 0.1))
    end
    fig
end
This suggests that, a priori, we assign a chance of a few percent to very high \(R_t\) values; that is, values of \(R_t\) between 10 and 1000 are not excluded by our priors.
The Renewal model as an AbstractEpiModel type
EpiAware exposes the abstract type AbstractEpiModel for models that generate infections. As with latent models, different concrete subtypes of AbstractEpiModel define different classes of infection-generating process. In this case we want to implement a renewal model.
The Renewal <: AbstractEpiModel struct needs two fields:
Data about the generation interval of the infectious disease, so that it can construct \(g_t\).
A prior for the initial number of infections.
Mishra et al use an estimate of the serial interval of SARS-CoV-2 as an estimate of the generation interval.
truth_GI = Gamma(6.5, 0.62)
Distributions.Gamma{Float64}(α=6.5, θ=0.62)
This represents the generation interval distribution as continuous, whereas the infection process is formulated in discrete daily time steps. By default, EpiAware performs double interval censoring to convert our continuous estimate of the generation interval into a discretised version \(g_t\), while also applying left truncation such that \(g_0 = 0\) and normalising so that \(\sum_t g_t = 1\).
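A minimal sketch of double interval censoring, assuming the primary event time is uniform within its day: the probability of a daily delay \(d\) is then the continuous density weighted by a triangular kernel on \([d-1, d+1]\). This is a plain-Julia illustration with hypothetical helper names and an assumed maximum delay of 8 days (chosen to match the length of the EpiData output); EpiAware's own discretisation may differ in truncation and numerical details, so the values need not match model_data.gen_int exactly.

```julia
# Unnormalised Gamma(shape α, scale θ) density. The normalising constant cancels
# when the discrete pmf is renormalised, so no gamma function is needed.
gamma_kernel(x, α, θ) = x > 0 ? x^(α - 1) * exp(-x / θ) : 0.0

# Weight for a daily delay d under double interval censoring with a uniform
# primary event time: ∫ f(x) (1 - |x - d|) dx over x ∈ [d - 1, d + 1],
# approximated with the midpoint rule on `npts` sub-intervals.
function censored_weight(d, α, θ; npts = 2_000)
    h = 2 / npts
    total = 0.0
    for i in 1:npts
        x = d - 1 + (i - 0.5) * h          # midpoint of the i-th sub-interval
        total += gamma_kernel(x, α, θ) * (1 - abs(x - d)) * h
    end
    return total
end

# Left truncate (drop d = 0), cut off at a maximum delay D and renormalise.
function discretise_gi(α, θ, D)
    w = [censored_weight(d, α, θ) for d in 1:D]
    return w ./ sum(w)
end

g = discretise_gi(6.5, 0.62, 8)   # D = 8 is an assumption matching EpiData's length
```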
The constructor for converting a continuous estimate of the generation interval distribution into a usable discrete-time estimate is EpiData.
model_data = EpiData(gen_distribution = truth_GI)
EpiData{Float64, typeof(exp)}([0.026663134095601056, 0.14059778064943784, 0.2502660305615846, 0.24789569560506844, 0.1731751163417783, 0.09635404000022223, 0.04573437575216367, 0.019313826994143808], 8, exp)
We can compare the discretized generation interval with the continuous estimate, which in this example is the serial interval estimate.
let
    fig = Figure()
    ax = Axis(fig[1, 1];
        xticks = 0:14,
        xlabel = "Days",
        title = "Continuous and discrete generation intervals"
    )
    # Discretised pmf g_t (index t = 1, 2, … days)
    barplot!(ax, model_data.gen_int;
        label = "Discretized next gen pmf"
    )
    # Continuous serial interval density, drawn on the same axis
    lines!(ax, truth_GI;
        label = "Continuous serial interval",
        color = :green
    )
    axislegend(ax)
    fig
end
The user also needs to specify a prior for the log incidence at time zero, \(\log I_0\). The initial history of latent infections \(I_{-1}, I_{-2},\dots\) is constructed as
$$I_t = e^{rt} I_0,\qquad t = 0, -1, -2, \dots$$
where the exponential growth rate \(r\) is determined by the initial reproductive number \(R_1\) via the solution to the implicit equation
$$R_1 = 1 \Big/ \sum_{t\geq 1} e^{-rt} g_t.$$
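The implicit equation for \(r\) has no closed form, but its right-hand side increases monotonically in \(r\), so it can be solved by bisection. The sketch below is plain Julia with hypothetical helper names, reuses the discretised generation interval printed by EpiData, and assumes \(R_1 = 3\) and \(I_0 = 1\) for illustration; it is not EpiAware's implementation.

```julia
# Discretised generation interval g_1, …, g_8 as printed by EpiData.
gen_int = [0.026663134095601056, 0.14059778064943784, 0.2502660305615846,
    0.24789569560506844, 0.1731751163417783, 0.09635404000022223,
    0.04573437575216367, 0.019313826994143808]

# Right-hand side of the implicit equation: R(r) = 1 / Σ_{t≥1} exp(-r t) g_t.
implied_R(r, g) = 1 / sum(exp(-r * t) * g[t] for t in eachindex(g))

# Solve implied_R(r, g) = R1 by bisection. implied_R increases with r because
# every term exp(-r t) g_t decreases with r when t ≥ 1 and g_t ≥ 0.
function growth_rate(R1, g; lo = -2.0, hi = 2.0, tol = 1e-12)
    f(r) = implied_R(r, g) - R1
    @assert f(lo) < 0 < f(hi) "bracket [lo, hi] does not contain the root"
    while hi - lo > tol
        mid = (lo + hi) / 2
        if f(mid) < 0
            lo = mid
        else
            hi = mid
        end
    end
    return (lo + hi) / 2
end

# Initial history I_t = exp(r t) I_0 for t = -(n - 1), …, 0 (oldest first).
initial_history(I0, r, n) = [I0 * exp(r * t) for t in -(n - 1):0]

r = growth_rate(3.0, gen_int)                        # assumed R_1 = 3
init_hist = initial_history(1.0, r, length(gen_int)) # assumed I_0 = 1
```

Because \(\sum_t g_t = 1\), the case \(R_1 = 1\) gives \(r = 0\) and a flat initial history, which is a convenient sanity check.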
log_I0_prior = Normal(log(1.0), 1.0)
Distributions.Normal{Float64}(μ=0.0, σ=1.0)
epi = Renewal(model_data; initialisation_prior = log_I0_prior)
Renewal{EpiData{Float64, typeof(exp)}, Normal{Float64}, EpiAware.EpiInfModels.ConstantRenewalStep{Float64}}(EpiData{Float64, typeof(exp)}([0.026663134095601056, 0.14059778064943784, 0.2502660305615846, 0.24789569560506844, 0.1731751163417783, 0.09635404000022223, 0.04573437575216367, 0.019313826994143808], 8, exp), Distributions.Normal{Float64}(μ=0.0, σ=1.0), EpiAware.EpiInfModels.ConstantRenewalStep{Float64}([0.019313826994143808, 0.04573437575216367, 0.09635404000022223, 0.1731751163417783, 0.24789569560506844, 0.2502660305615846, 0.14059778064943784, 0.026663134095601056]))
NB: We don't implement a background infection rate in this model.
Turing model interface to the Renewal process
As mentioned above, we can use this instance of the Renewal latent infection model to construct a Turing Model which implements the probabilistic behaviour determined by epi. We do this with the constructor function generate_latent_infs, which combines epi with a provided \(\log R_t\) time series.
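Underneath the Turing model, the renewal equation is a simple recursion. The sketch below is plain Julia with a hypothetical `renewal` function (not EpiAware's code); it uses the discretised generation interval from EpiData and the same falling \(R_t\) curve as the example that follows. For brevity it seeds the recursion with a flat history of one infection per day, an assumption that differs from the exponential-growth initialisation described above.

```julia
# Discretised generation interval g_1, …, g_8 as printed by EpiData.
gen_int = [0.026663134095601056, 0.14059778064943784, 0.2502660305615846,
    0.24789569560506844, 0.1731751163417783, 0.09635404000022223,
    0.04573437575216367, 0.019313826994143808]

# Renewal recursion I_t = R_t Σ_{s=1}^{G} I_{t-s} g_s for t = 1, …, n.
# `init` holds the G infections before t = 1, oldest first.
function renewal(R, g, init)
    n, G = length(R), length(g)
    @assert length(init) == G "initial history must have one entry per lag"
    infs = vcat(float.(init), zeros(n))   # history followed by the n new days
    for t in 1:n
        i = G + t                         # position of day t in `infs`
        infs[i] = R[t] * sum(infs[i - s] * g[s] for s in 1:G)
    end
    return infs[G+1:end]                  # return only the new days
end

R_t_fixed = [0.5 + 2.5 / (1 + exp(t - 15)) for t in 1:50]
# Simplifying assumption: one infection per day before t = 1.
I_t = renewal(R_t_fixed, gen_int, ones(length(gen_int)))
```

With \(R_t \equiv 1\) and a flat history every new value stays at one (up to rounding), since \(\sum_s g_s = 1\).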
Here we choose an example where \(R_t\) decreases from \(R_t = 3\) to \(R_t = 0.5\) over the course of 50 days.
R_t_fixed = [0.5 + 2.5 / (1 + exp(t - 15)) for t in 1:50]
50-element Vector{Float64}: 2.9999979211799306 2.9999943491892553 2.9999846395634946 2.99995824644538 2.9998865053282437 2.9996915135600344 2.999161624673834 ⋮ 0.5000000000002339 0.500000000000086 0.5000000000000316 0.5000000000000117 0.5000000000000043 0.5000000000000016
latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed))
plt_epi = let
    n_samples = 100
    #Sample unconditionally the underlying parameters of the model
    epi_mdl_samples = mapreduce(hcat, 1:n_samples) do _
        latent_inf_mdl()
    end
    fig = Figure()
    ax1 = Axis(fig[1, 1];
        title = "$(n_samples) draws from renewal model with chosen Rt",
        ylabel = "Latent infections"
    )
    ax2 = Axis(fig[2, 1];
        ylabel = "Rt"
    )
    for col in eachcol(epi_mdl_samples)
        lines!(ax1, col;
            color = (:grey, 0.1)
        )
    end
    lines!(ax2, R_t_fixed;
        linewidth = 2
    )
    fig
end
Negative Binomial Observations as an ObservationModel type
In Mishra et al, latent infections were assumed to be observed on the day they occur, with negative binomial errors. This motivates using the serial interval (the time between symptom onset in a primary and a secondary case) rather than the generation interval (the time between infection of a primary and a secondary case).
Observation models are set in EpiAware as concrete subtypes of an ObservationModel. The negative binomial error model without observation delays is set with a NegativeBinomialError struct. In Mishra et al the overdispersion parameter \(\phi\) sets the relationship between the mean and variance of the negative binomial errors,
$$\text{var} = \text{mean} + {\text{mean}^2 \over \phi}.$$
In EpiAware, we default to a prior on \(\sqrt{1/\phi}\) because, for large expected counts, this quantity is approximately the coefficient of variation of the observation noise and is therefore easier to form prior beliefs about. We call this quantity the cluster factor.
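The link between the cluster factor and the observation noise can be checked with a little arithmetic. The sketch below (hypothetical helper names) uses the variance formula above with \(\phi = 1/c^2\) and shows that the coefficient of variation, \(\sqrt{1/\mu + 1/\phi}\), approaches the cluster factor \(c\) as the mean grows. It also records a conversion to the \((r, p)\) parameterisation of Distributions.jl's NegativeBinomial, which is an assumption about how one might sample this error model outside EpiAware.

```julia
# Negative binomial with mean μ and overdispersion φ has var = μ + μ^2 / φ.
# The cluster factor is c = sqrt(1/φ), so φ = 1 / c^2.
overdispersion(c) = 1 / c^2
nb_var(μ, φ) = μ + μ^2 / φ
# Coefficient of variation sqrt(1/μ + 1/φ); it tends to c as μ grows.
nb_cv(μ, φ) = sqrt(nb_var(μ, φ)) / μ

# (r, p) for the failures-before-r-successes form used by
# Distributions.NegativeBinomial(r, p), whose mean is r(1 - p)/p.
nb_rp(μ, φ) = (φ, φ / (φ + μ))

c = 0.25                     # the cluster factor value fixed later in this notebook
φ = overdispersion(c)        # 16.0
for μ in (10.0, 100.0, 1000.0)
    println("μ = ", μ, ", CV = ", round(nb_cv(μ, φ); digits = 3))
end
```

With \(c = 0.25\) the loop prints a coefficient of variation of 0.403, 0.269 and 0.252 for means of 10, 100 and 1000, so the cluster factor is a good guide to the relative noise once counts are in the hundreds.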
Mishra et al did not specify a prior for \(\phi\). We select one below, but in the analysis we will condition on a fixed value.
obs = NegativeBinomialError(cluster_factor_prior = HalfNormal(0.1))
NegativeBinomialError{HalfNormal{Float64}}(HalfNormal{Float64}(μ=0.1))
Turing model interface to the NegativeBinomialError model
We can construct a NegativeBinomialError model implementation as a Turing Model using the EpiAware generate_observations function.
Turing uses missing arguments to indicate variables that are to be sampled. We use this to create a forward model that samples observations, conditional on an underlying expected observation time series.
First, we set an artificial expected cases curve.
expected_cases = [1000 * exp(-(t - 15)^2 / (2 * 4)) for t in 1:30]
30-element Vector{Float64}: 2.289734845645553e-8 6.691586091292782e-7 1.5229979744712628e-5 0.0002699578503363014 0.003726653172078671 0.04006529739295107 0.33546262790251186 ⋮ 0.003726653172078671 0.0002699578503363014 1.5229979744712628e-5 6.691586091292782e-7 2.289734845645553e-8 6.101936677605324e-10
obs_mdl = generate_observations(obs, missing, expected_cases)
plt_obs = let
    n_samples = 100
    obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _
        obs_mdl() # Sample a cluster factor and an observed case series from the prior
    end
    fig = Figure()
    ax = Axis(fig[1, 1];
        title = "$(n_samples) draws from neg. bin. obs model",
        ylabel = "Observed cases"
    )
    for col in eachcol(obs_mdl_samples)
        scatter!(ax, col;
            color = (:grey, 0.2)
        )
    end
    lines!(ax, expected_cases;
        color = :red,
        linewidth = 3,
        label = "Expected cases"
    )
    axislegend(ax)
    fig
end
Composing models into an EpiProblem
The Mishra et al model follows a common pattern: an infection-generating process driven by a latent process, with an observation model that links the infection process to a discrete-valued time series of incidence data.
In EpiAware we provide an EpiProblem constructor for this common epidemiological model pattern.
The constructor for an EpiProblem requires:
An epi_model.
A latent_model.
An observation_model.
A tspan.
The tspan sets the range of the time index for the models.
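Here the tspan acts as an inclusive range of row indices into the case data. Assuming the rows are consecutive days starting at 2019-12-31 (as in the table above), a quick check with the Dates standard library recovers 36 time steps covering 2020-02-13 to 2020-03-19. The helper row_to_date is hypothetical.

```julia
using Dates

# Assumption: row k of the case data corresponds to 2019-12-31 + (k - 1) days.
first_date = Date(2019, 12, 31)
row_to_date(k) = first_date + Day(k - 1)   # hypothetical helper

tspan = (45, 80)
n_steps = tspan[2] - tspan[1] + 1          # inclusive range of rows
window = row_to_date(tspan[1]):Day(1):row_to_date(tspan[2])
println(n_steps, " days: ", first(window), " to ", last(window))
# prints "36 days: 2020-02-13 to 2020-03-19"
```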
epi_prob = EpiProblem(epi_model = epi,
latent_model = ar,
observation_model = obs,
tspan = (45, 80))
Inference Methods
We make inferences on unobserved quantities, such as \(R_t\), by sampling from the model conditioned on the observed data. We generate the posterior samples using the No-U-Turn Sampler (NUTS).
To make NUTS more robust we provide manypathfinder, which is built on pathfinder variational inference from Pathfinder.jl. manypathfinder runs nruns pathfinder processes on the inference problem and returns the pathfinder run with the maximum estimated ELBO.
The EpiMethod struct defines this composition: a sequence of pre-sampler steps (here, variational inference) whose output is passed to NUTS for initialisation.
EpiMethod also allows the specification of NUTS parameters, such as the type of automatic differentiation, the type of parallelism and the number of parallel chains to sample.
num_threads = min(10, Threads.nthreads())
1
inference_method = EpiMethod(
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 100)],
sampler = NUTSampler(
adtype = AutoReverseDiff(compile = true),
ndraws = 2000,
nchains = num_threads,
mcmc_parallel = MCMCThreads())
)
EpiMethod{ManyPathfinder, NUTSampler{AutoReverseDiff{true}, MCMCThreads, UnionAll}}(ManyPathfinder[ManyPathfinder(10, 4, 100, 100)], NUTSampler{AutoReverseDiff{true}, MCMCThreads, UnionAll}(0.8, AutoReverseDiff(compile=true), MCMCThreads(), 1, 10, 1000.0, 0.0, 2000, AdvancedHMC.DiagEuclideanMetric, -1))
Inference and analysis
We supply the data as a NamedTuple with the y_t field containing the observed data, shortened to fit the chosen tspan of epi_prob.
south_korea_data = (y_t = data.cases_new[epi_prob.tspan[1]:epi_prob.tspan[2]],
dates = data.date[epi_prob.tspan[1]:epi_prob.tspan[2]])
(y_t = [0, 0, 0, 1, 1, 1, 15, 34, 75, 190 … 131, 242, 114, 110, 107, 76, 74, 84, 93, 152], dates = [Date("2020-02-13"), Date("2020-02-14"), Date("2020-02-15"), Date("2020-02-16"), Date("2020-02-17"), Date("2020-02-18"), Date("2020-02-19"), Date("2020-02-20"), Date("2020-02-21"), Date("2020-02-22") … Date("2020-03-10"), Date("2020-03-11"), Date("2020-03-12"), Date("2020-03-13"), Date("2020-03-14"), Date("2020-03-15"), Date("2020-03-16"), Date("2020-03-17"), Date("2020-03-18"), Date("2020-03-19")])
In this epidemiological model it is hard to distinguish between the AR parameters, such as the standard deviation of the AR process, and the cluster factor of the negative binomial observation model. This identifiability problem arises because the model assumes no delay between infection and observation. Therefore, on any day the data could be explained either by \(R_t\) changing or by observation noise, and it is not easy to disentangle greater volatility in \(R_t\) from higher noise in the observations.
In models with latent delays, changes in \(R_t\) affect the observed cases over several days, which makes it easier to disentangle trend effects from observation-to-observation fluctuations.
To counteract this problem we condition the model on a fixed cluster factor value.
fixed_cluster_factor = 0.25
0.25
EpiAware has the generate_epiaware function, which joins an EpiProblem object with the data to produce a Turing model. This Turing model composes the three component Turing models defined above: the Renewal infection-generating process, the AR latent process for \(\log R_t\), and the negative binomial observation model. Therefore, we can condition on variables as with any other Turing model.
mdl = generate_epiaware(epi_prob, south_korea_data) |
(var"obs.cluster_factor" = fixed_cluster_factor,)
Sampling with apply_method
The apply_method function combines the elements above:
An EpiProblem object or Turing model.
An EpiMethod object.
Data to condition the model upon.
It returns a collection of results:
The epidemiological model as a Turing Model.
Samples from MCMC.
Generated quantities of the model.
inference_results = apply_method(mdl,
inference_method,
south_korea_data
)
EpiAwareObservables(Model{typeof(generate_epiaware), …}(…), (y_t = [0, 0, 0, 1, 1, 1, 15, 34, 75, 190 … 131, 242, 114, 110, 107, 76, 74, 84, 93, 152], dates = [Date("2020-02-13") … Date("2020-03-19")]), MCMC chain (2000×52×1 Array{Float64, 3}), @NamedTuple{generated_y_t::Vector{Int64}, I_t::Vector{Float64}, Z_t::Vector{Float64}}[…])
Results and predictive plotting
To assess the quality of the inference visually, we can plot predictive quantiles of generated case data. These come from a version of the model that is not conditioned on case data, evaluated at the posterior parameter draws inferred from the version that is conditioned on the observed data. For this purpose, we add a generated_quantiles utility function. This kind of visualisation is known as posterior predictive checking and is a useful diagnostic tool for Bayesian inference (see here).
We also plot the inferred \(R_t\) estimates from the model. We find that the EpiAware model recovers the main finding of Mishra et al: \(R_t\) in South Korea peaked at a very high value (\(R_t \sim 10\) at peak) before dropping rapidly below 1 in early March 2020.
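A complementary way to summarise when \(R_t\) drops below 1 is the posterior probability \(P(R_t < 1)\) at each time step, i.e. the fraction of posterior draws below 1. The following is a minimal sketch on a small, made-up matrix of log-scale draws (`Z_draws` is hypothetical); with the notebook objects, an equivalent matrix could be built by stacking the `Z_t` field of each generated quantity, as `generated_quantiles` does.

```julia
using Statistics

# Hypothetical stand-in for posterior draws of Z_t = log R_t:
# rows are time steps, columns are posterior draws.
Z_draws = [
    2.3 2.1 2.5;     # early: R_t well above 1 in every draw
    0.1 -0.2 0.3;    # transition: mixed
    -0.5 -0.4 -0.1   # later: R_t below 1 in every draw
]

# Posterior probability that R_t = exp(Z_t) is below 1 at each time step.
# exp is monotone, so exp(z) < 1 exactly when z < 0.
prob_R_below_1 = vec(mean(exp.(Z_draws) .< 1; dims = 2))
```

Reporting this probability alongside the median avoids reading too much into a single crossing of the posterior median.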
Note that, in reality, the peak \(R_t\) found here and in Mishra et al is unrealistically high. This might be due to a combination of:
A mis-estimated generation interval/serial interval distribution.
An ascertainment rate that was, in reality, changing over time.
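The first of these points can be made concrete. Substituting exponential growth \(I_t = I_0 e^{rt}\) with a constant \(R\) into the renewal equation \(I_t = R \sum_{s\geq 1} I_{t-s} g_s\) gives the discrete Euler-Lotka relation

$$R = \frac{1}{\sum_{s\geq 1} g_s e^{-rs}},$$

so for the same observed growth rate \(r\), a generation interval with more mass at longer delays implies a larger \(R\). The sketch below uses two made-up pmfs (`g_short`, `g_long`) and a hypothetical growth rate, not the Gamma(6.5, 0.62) serial interval used in this notebook.

```julia
# Implied reproduction number for exponential growth at daily rate r,
# given a discrete generation-interval pmf g where g[s] is the probability
# of a generation interval of s days (s = 1, 2, ...).
function implied_R(r, g)
    denom = sum(g[s] * exp(-r * s) for s in eachindex(g))
    return 1 / denom
end

# Two hypothetical generation-interval pmfs (illustrative only)
g_short = [0.5, 0.5]                # mean 1.5 days
g_long = [0.25, 0.25, 0.25, 0.25]   # mean 2.5 days

r = 0.3                             # hypothetical daily growth rate
R_short = implied_R(r, g_short)     # ≈ 1.55
R_long = implied_R(r, g_long)       # ≈ 2.00
```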
In a future note, we will demonstrate a model with a time-varying ascertainment rate.
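As an illustration of the second point above, a rising ascertainment rate can look like faster epidemic growth. If reported cases are roughly \(C_t \approx a_t I_t\) for an ascertainment rate \(a_t\), the observed log-growth rate is the true rate plus the rate of change of \(\log a_t\). Because the observation model here has mean \(I_t\) with no ascertainment term, such a trend would be absorbed into the inferred infections and hence into \(R_t\). The numbers below (10% daily growth, ascertainment rising from 20% to 60%) are hypothetical.

```julia
# True infections growing exponentially at a constant (hypothetical) rate
r_true = 0.1
t = 0:20
infections = 10 .* exp.(r_true .* t)

# Hypothetical ascertainment rate rising linearly from 20% to 60%
ascertainment = range(0.2, 0.6; length = length(t))
expected_cases = ascertainment .* infections

# Average daily log-growth rate between the first and last time points
growth(x) = (log(x[end]) - log(x[1])) / (length(x) - 1)
r_obs = growth(expected_cases)  # exceeds r_true because reporting improved
```

Here the apparent growth rate is inflated by exactly \(\log(0.6/0.2)/20\) per day, even though transmission is unchanged.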
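function generated_quantiles(gens, quantity, qs; transformation = x -> x)
    # Stack the chosen quantity (optionally transformed) from each posterior draw
    # as a column, giving a matrix with one row per time step
    mat = mapreduce(hcat, gens) do gen
        getfield(gen, quantity) |> transformation
    end
    # For each quantile q, summarise every row (time step) across draws;
    # the result has one column per entry of qs
    return mapreduce(hcat, qs) do q
        map(eachrow(mat)) do row
            # Propagate missing values instead of computing a quantile
            any(ismissing, row) ? missing : quantile(row, q)
        end
    end
end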
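The generated_quantiles function stacks one generated quantity per posterior draw into a matrix (rows are time steps, columns are draws) and then summarises each row. The toy example below sketches that layout with three made-up draws using a compact broadcast form; it skips the missing-value check, and the NamedTuple field name simply mirrors the one used in this notebook.

```julia
using Statistics

# Three hypothetical posterior draws, each a NamedTuple with one field
gens = [
    (generated_y_t = [1, 10, 100],),
    (generated_y_t = [2, 20, 200],),
    (generated_y_t = [3, 30, 300],),
]

# Stack draws as columns: rows are time steps, columns are draws
mat = mapreduce(gen -> gen.generated_y_t, hcat, gens)

# One column per requested quantile, one row per time step
qs = [0.25, 0.5, 0.75]
q_summary = reduce(hcat, [quantile.(eachrow(mat), q) for q in qs])
```

Column 2 of `q_summary` is the per-time-step median, which is why the plotting code below uses column 3 (the median of five quantiles) for the central line.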
let
    C = south_korea_data.y_t
    D = south_korea_data.dates
    # Model with case data set to missing, so y_t is sampled (posterior predictive)
    mdl_unconditional = generate_epiaware(epi_prob,
        (y_t = fill(missing, length(C)),)
    ) | (var"obs.cluster_factor" = fixed_cluster_factor,)
    # Generated quantities of the unconditional model at the posterior draws
    posterior_gens = generated_quantities(mdl_unconditional, inference_results.samples)
    # Quantiles to plot: median plus 50% and 95% intervals
    qs = [0.025, 0.25, 0.5, 0.75, 0.975]
    # Prediction quantiles (rows: days, columns: entries of qs)
    predicted_y_t = generated_quantiles(posterior_gens, :generated_y_t, qs)
    predicted_R_t = generated_quantiles(
        posterior_gens, :Z_t, qs; transformation = x -> exp.(x))
    # Day index (1, 2, ...) relative to the first date
    ts = [(d - minimum(D)).value + 1 for d in D]
    t_ticks = string.(D)
    fig = Figure()
    ax1 = Axis(fig[1, 1];
        ylabel = "Daily cases",
        xticks = (ts[1:14:end], t_ticks[1:14:end]),
        title = "Posterior predictive: Cases"
    )
    ax2 = Axis(fig[2, 1];
        yscale = log10,
        title = "Prediction: Reproduction number",
        xticks = (ts[1:14:end], t_ticks[1:14:end])
    )
    linkxaxes!(ax1, ax2)
    lines!(ax1, ts, predicted_y_t[:, 3];
        color = :purple,
        linewidth = 2,
        label = "Post. median"
    )
    band!(ax1, ts, predicted_y_t[:, 2], predicted_y_t[:, 4];
        color = (:purple, 0.4),
        label = "50%"
    )
    band!(ax1, ts, predicted_y_t[:, 1], predicted_y_t[:, 5];
        color = (:purple, 0.2),
        label = "95%"
    )
    scatter!(ax1, ts, C;
        color = :black,
        label = "Actual cases")
    axislegend(ax1)
    lines!(ax2, ts, predicted_R_t[:, 3];
        color = :green,
        linewidth = 2,
        label = "Post. median"
    )
    band!(ax2, ts, predicted_R_t[:, 2], predicted_R_t[:, 4];
        color = (:green, 0.4),
        label = "50%"
    )
    band!(ax2, ts, predicted_R_t[:, 1], predicted_R_t[:, 5];
        color = (:green, 0.2),
        label = "95%"
    )
    axislegend(ax2)
    fig
end
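The visual check can be complemented by a simple numerical one: the fraction of observed days that fall inside the 95% posterior predictive interval. This is a rough sketch with hypothetical numbers (`obs_cases`, `lower95`, `upper95` are made up); in the plotting block, the bounds would correspond to `predicted_y_t[:, 1]` and `predicted_y_t[:, 5]` and the observations to `C`.

```julia
# Hypothetical observed counts and 95% posterior predictive bounds per day
obs_cases = [5, 12, 30, 28, 60]
lower95 = [2, 8, 20, 30, 40]
upper95 = [9, 18, 45, 50, 90]

# Flag days whose observation lies within its interval (inclusive)
inside = (lower95 .<= obs_cases) .& (obs_cases .<= upper95)
coverage = sum(inside) / length(obs_cases)
```

For a well-calibrated model one would expect coverage near 95%, but with a short, autocorrelated series this is only an indicative check, not a formal test.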
Parameter inference
We can interrogate the sampled chains directly from the samples field of the inference_results object.
let
    chn = inference_results.samples
    # Select the first five sampled parameters and the last one
    sub_chn = chn[chn.name_map.parameters[[1:5; end]]]
    fig = pairplot(sub_chn)
    # Overlay prior densities on the diagonal panels: AR(2) standard deviation,
    # initial values, damping coefficients, then the renewal initialisation prior
    lines!(fig[1, 1], ar.std_prior, label = "Prior")
    lines!(fig[2, 2], ar.init_prior.v[1], label = "Prior")
    lines!(fig[3, 3], ar.init_prior.v[2], label = "Prior")
    lines!(fig[4, 4], ar.damp_prior.v[1], label = "Prior")
    lines!(fig[5, 5], ar.damp_prior.v[2], label = "Prior")
    lines!(fig[6, 6], epi.initialisation_prior, label = "Prior")
    fig
end
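To see how the AR(2) parameters compared against their priors shape \(\log R_t\), the sketch below uses a generic AR(2) recursion \(Z_t = \rho_1 Z_{t-1} + \rho_2 Z_{t-2} + \sigma \epsilon_t\). It is not EpiAware's implementation; which lag each damping coefficient multiplies is an assumption here. It takes the prior means of `ar.damp_prior` (0.1 and 0.8) and `ar.init_prior` (-1 and -1) used in this notebook, and sets \(\sigma = 0\) to show the noise-free path.

```julia
using Random

# Generic AR(2) on the log scale (assumption: ρ[1] multiplies lag 1,
# ρ[2] multiplies lag 2); z_init supplies the first two values.
function simulate_ar2(ρ, z_init, σ, n; rng = Random.default_rng())
    Z = zeros(n)
    Z[1:2] .= z_init
    for t in 3:n
        Z[t] = ρ[1] * Z[t - 1] + ρ[2] * Z[t - 2] + σ * randn(rng)
    end
    return Z
end

# Standard weak-stationarity conditions for an AR(2) process
is_stationary(ρ) = (ρ[1] + ρ[2] < 1) && (ρ[2] - ρ[1] < 1) && (abs(ρ[2]) < 1)

ρ = [0.1, 0.8]          # prior means of the damping coefficients
z_init = [-1.0, -1.0]   # prior means of the initial values
Z = simulate_ar2(ρ, z_init, 0.0, 10)  # σ = 0: noise-free path
R = exp.(Z)             # corresponding reproduction numbers
```

With these values the process is stationary and the noise-free path decays slowly towards \(\log R_t = 0\); the large early \(R_t\) in the posterior therefore has to come from the innovation terms driven by the data.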