Example: Statistical inference for ODE-based infectious disease models
Introduction
What are we going to do in this vignette
In this vignette, we'll demonstrate how to use EpiAware in conjunction with the SciML ecosystem for Bayesian inference of infectious disease dynamics. The model and data are heavily based on Contemporary statistical inference for infectious disease models using Stan, Chatzilena et al. (2019).
We'll cover the following key points:

Defining the deterministic ODE model from Chatzilena et al section 2.2.2 using SciML ODE functionality and an EpiAware observation model.
Building on this to define the stochastic ODE model from Chatzilena et al section 2.2.3 using an EpiAware observation model.
Fitting the deterministic ODE model to data from an influenza outbreak in an English boarding school.
Fitting the stochastic ODE model to data from an influenza outbreak in an English boarding school.
What might I need to know before starting
This vignette builds on concepts from EpiAware observation models. Familiarity with the SciML and Turing ecosystems would be useful but is not essential.
Packages used in this vignette
Alongside the EpiAware package we will use the OrdinaryDiffEq and SciMLSensitivity packages to interface with the SciML ecosystem; this is a lower-dependency usage of DifferentialEquations.jl that exposes, respectively, ODE solvers and adjoint methods for ODE solves, that is, methods for propagating parameter derivatives through functions containing ODE solutions. Bayesian inference will be done with NUTS from the Turing ecosystem. We will also use the CairoMakie package for plotting and DataFramesMeta for data manipulation.
using EpiAware
using Turing
using OrdinaryDiffEq, SciMLSensitivity #ODE solvers and adjoint methods
using Distributions, Statistics, LogExpFunctions #Statistics and special func packages
using CSV, DataFramesMeta #Data wrangling
using CairoMakie, PairPlots
using ReverseDiff #Automatic differentiation backend
begin #Date utility and set Random seed
using Dates
using Random
Random.seed!(1234)
end
TaskLocalRNG()
Single population SIR model
As mentioned in Chatzilena et al, disease spread is frequently modelled in terms of ODE-based models. The study population is divided into compartments, each representing a specific stage of epidemic status; in this case: susceptible, infected, and recovered individuals.
$$\begin{aligned} {dS \over dt} &= - \beta \frac{I(t)}{N} S(t) \\ {dI \over dt} &= \beta \frac{I(t)}{N} S(t) - \gamma I(t) \\ {dR \over dt} &= \gamma I(t). \\ \end{aligned}$$
where S(t) represents the number of susceptible, I(t) the number of infected and R(t) the number of recovered individuals at time t. The total population size is denoted by N (with N = S(t) + I(t) + R(t)), β denotes the transmission rate and γ denotes the recovery rate.
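Both models fitted below also return the basic reproduction number as a generated quantity; for this SIR model it is the ratio of the transmission and recovery rates:
$$\mathcal{R}_0 = \frac{\beta}{\gamma}.$$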
We can interface to the SciML ecosystem by writing a function with the signature:

(du, u, p, t) -> nothing

Where:

du is the vector field of the ODE problem, e.g. \({dS \over dt}\), \({dI \over dt}\) etc. This is calculated in-place (commonly denoted using ! in function names in Julia).
u is the state of the ODE problem, e.g. \(S\), \(I\), etc.
p is an object that represents the parameters of the ODE problem, e.g. \(\beta\), \(\gamma\).
t is the time of the ODE problem.
We do this for the SIR model described above in a function called sir!:
function sir!(du, u, p, t)
    S, I, R = u # susceptible, infected, recovered
    β, γ = p # transmission rate, recovery rate
    # NB: no explicit /N factor here; in the inference below the state is
    # expressed as proportions of N, for which the /N in the equations cancels
    du[1] = -β * I * S
    du[2] = β * I * S - γ * I
    du[3] = γ * I
    return nothing
end
sir! (generic function with 1 method)
We combine the vector field function sir! with an initial condition u0 and the integration period tspan to make an ODEProblem. We do not define the parameters; these will be defined within the inference approach.
sir_prob = ODEProblem(
sir!,
N .* [0.99, 0.01, 0.0],
(0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1)
)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true timespan: (0.0, 14.0) u0: 3-element Vector{Float64}: 755.37 7.63 0.0
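As a quick sanity check (not part of the Chatzilena et al analysis), we can solve the problem once with hypothetical rates β = 2.0 and γ = 0.5, using proportions of N for the initial condition as the inference below does:

test_sol = solve(
    remake(sir_prob; u0 = [0.99, 0.01, 0.0], p = [2.0, 0.5]),
    AutoTsit5(Rosenbrock23());
    saveat = 1.0
)
N .* test_sol[2, :] # expected number of infected children at daily save points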
Note that this is analogous to the EpiProblem approach we expose from EpiAware, as used in the Mishra et al replication. The difference is that here we are going to use ODE solvers from the SciML ecosystem to generate the dynamics of the underlying infections, whereas in the linked example the underlying dynamics are driven by a latent process generated with EpiAware.
Data for inference
There was a brief, but intense, outbreak of influenza within the (semi-)closed community of a boarding school, reported to the British Medical Journal in 1978. The outbreak lasted from 22nd January to 4th February; it is reported that one infected child started the epidemic, which then spread rapidly. Of the 763 children at the boarding school, 512 became ill.
We downloaded the data of this outbreak using the R package outbreaks, which is maintained as part of the R Epidemics Consortium (RECON).
data = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/main/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" |>
url -> CSV.read(download(url), DataFrame) |>
df -> @transform(df,
:ts=(:date .- minimum(:date)) .|> d -> d.value + 1.0,)
| | Column1 | date | in_bed | convalescent | ts |
|---|---|---|---|---|---|
| 1 | 1 | 1978-01-22 | 3 | 0 | 1.0 |
| 2 | 2 | 1978-01-23 | 8 | 0 | 2.0 |
| 3 | 3 | 1978-01-24 | 26 | 0 | 3.0 |
| 4 | 4 | 1978-01-25 | 76 | 0 | 4.0 |
| 5 | 5 | 1978-01-26 | 225 | 9 | 5.0 |
| 6 | 6 | 1978-01-27 | 298 | 17 | 6.0 |
| 7 | 7 | 1978-01-28 | 258 | 105 | 7.0 |
| 8 | 8 | 1978-01-29 | 233 | 162 | 8.0 |
| 9 | 9 | 1978-01-30 | 189 | 176 | 9.0 |
| 10 | 10 | 1978-01-31 | 128 | 166 | 10.0 |
| 11 | 11 | 1978-02-01 | 68 | 150 | 11.0 |
| 12 | 12 | 1978-02-02 | 29 | 85 | 12.0 |
| 13 | 13 | 1978-02-03 | 14 | 47 | 13.0 |
| 14 | 14 | 1978-02-04 | 4 | 20 | 14.0 |
N = 763;
Inference for the deterministic SIR model
The boarding school data gives the number of children "in bed" and "convalescent" on each of 14 days from 22nd Jan to 4th Feb 1978. We follow Chatzilena et al and treat the number "in bed" as a proxy for the number of children in the infectious (I) compartment in the ODE model.
The full observation model is:
$$\begin{aligned} Y_t &\sim \text{Poisson}(\lambda_t)\\ \lambda_t &= I(t)\\ \beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ \gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 500)\\ S(0) /N &\sim \text{Beta}(0.5, 0.5). \end{aligned}$$
NB: Chatzilena et al give \(\lambda_t = \int_0^t \beta \frac{I(s)}{N} S(s) - \gamma I(s)\,ds = I(t) - I(0)\); the integral evaluates this way because the integrand is exactly \({dI \over dt}\). However, this doesn't match their underlying Stan code.
From EpiAware, we have the PoissonError struct which defines the probabilistic structure of this observation error model.
obs = PoissonError()
PoissonError()
Now we can write the probabilistic model using the Turing PPL. Note that instead of using \(I(t)\) directly we apply the softplus transform to \(I(t)\), implemented by LogExpFunctions.log1pexp. The reason is that the solver can return small negative numbers; the softplus transform smoothly maintains positivity while being very close to \(I(t)\) when \(I(t) > 2\).
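A quick illustration of the softplus behaviour (values approximate):

log1pexp(-1.0) # ≈ 0.313: a small negative solver output maps to a small positive value
log1pexp(2.0)  # ≈ 2.127: already close to the identity
log1pexp(10.0) # ≈ 10.00005: essentially the identity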
@model function deterministic_ode_mdl(y_t, ts, obs, prob, N;
solver = AutoTsit5(Rosenbrock23())
)
##Priors##
β ~ LogNormal(0.0, 1.0)
γ ~ Gamma(0.004, 1 / 0.002) # shape-scale parameterisation: scale = 1/rate = 500
S₀ ~ Beta(0.5, 0.5)
##remake ODE model##
_prob = remake(prob;
u0 = [S₀, 1 - S₀, 0.0],
p = [β, γ]
)
##Solve remade ODE model##
sol = solve(_prob, solver;
saveat = ts,
verbose = false)
##log-like accumulation using obs##
λt = log1pexp.(N * sol[2, :]) # expected number of infected, I(t)
@submodel generated_y_t = generate_observations(obs, y_t, λt)
##Generated quantities##
return (; sol, generated_y_t, R0 = β / γ)
end
deterministic_ode_mdl (generic function with 2 methods)
We instantiate the model in two ways:

deterministic_mdl: This conditions the generative model on the data observation. We can sample from this model to find the posterior distribution of the parameters.
deterministic_uncond_mdl: This doesn't condition on the data. This is useful for prior and posterior predictive modelling.
Here we construct the Turing model directly; in the Mishra et al replication we use the EpiProblem functionality to build a Turing model under the hood. Because in this note we are using a mix of functionality from SciML and EpiAware, we construct the model to sample from directly.
deterministic_mdl = deterministic_ode_mdl(data.in_bed, data.ts, obs, sir_prob, N);
deterministic_uncond_mdl = deterministic_ode_mdl(
fill(missing, length(data.in_bed)), data.ts, obs, sir_prob, N);
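As a usage note, calling the unconditioned model as a function draws once from the prior and returns the NamedTuple from the model's return statement:

draw = deterministic_uncond_mdl() # a single prior draw
draw.R0 # prior sample of the basic reproduction number β / γ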
We add a useful plotting utility.
function plot_predYt(data, gens; title::String, ylabel::String)
fig = Figure()
ga = fig[1, 1:2] = GridLayout()
ax = Axis(ga[1, 1];
title = title,
xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string),
ylabel = ylabel
)
pred_Yt = mapreduce(hcat, gens) do gen
gen.generated_y_t
end |> X -> mapreduce(vcat, eachrow(X)) do row
quantile(row, [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75])'
end
lines!(ax, data.ts, pred_Yt[:, 1]; linewidth = 3, color = :green, label = "Median")
band!(
ax, data.ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.2), label = "95% CI")
band!(
ax, data.ts, pred_Yt[:, 4], pred_Yt[:, 5], color = (:green, 0.4), label = "80% CI")
band!(
ax, data.ts, pred_Yt[:, 6], pred_Yt[:, 7], color = (:green, 0.6), label = "50% CI")
scatter!(ax, data.in_bed, label = "data")
leg = Legend(ga[1, 2], ax; framevisible = false)
hidespines!(ax)
fig
end
plot_predYt (generic function with 1 method)
Prior predictive sampling
let
prior_chn = sample(deterministic_uncond_mdl, Prior(), 2000)
gens = generated_quantities(deterministic_uncond_mdl, prior_chn)
plot_predYt(data, gens;
title = "Prior predictive: deterministic model",
ylabel = "Number of Infected students"
)
end
The prior predictive checking suggests that a priori our parameter beliefs are very far from the data. Approaching the inference naively can lead to poor fits.
We do two things to mitigate this:

We choose a switching ODE solver which switches between explicit (Tsit5) and implicit (Rosenbrock23) solvers. This helps avoid the ODE solver failing when the sampler tries extreme parameter values. This is the default solver = AutoTsit5(Rosenbrock23()) above.
We locate the maximum likelihood point, that is we ignore the influence of the priors, as a useful starting point for NUTS.
nmle_tries = 100
100
mle_fit = map(1:nmle_tries) do _
    fit = try
        maximum_likelihood(deterministic_mdl)
    catch # optimisation can fail from a bad starting point; record lp = -Inf
        (lp = -Inf,)
    end
end |>
    fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> # index of the best fit
    max_and_fits -> max_and_fits[2][max_and_fits[1]]   # select the best fit
ModeResult with maximized lp of -67.36 [1.8991528341217605, 0.4808836287362608, 0.9995360155493858]
mle_fit.optim_result.retcode
ReturnCode.Success = 1
Note that we choose the best out of 100 tries for the MLE estimate.
Now, we sample aiming at 1000 samples for each of 4 chains.
chn = sample(
deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4;
initial_params = fill(mle_fit.values.array, 4)
)
| | iteration | chain | β | γ | S₀ | lp | n_steps | is_accept | … |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 501 | 1 | 1.92453 | 0.498762 | 0.999601 | -80.6922 | 7.0 | 1.0 | |
| 2 | 502 | 1 | 1.91393 | 0.510477 | 0.999503 | -83.0619 | 7.0 | 1.0 | |
| 3 | 503 | 1 | 1.82063 | 0.454811 | 0.999413 | -83.098 | 15.0 | 1.0 | |
| 4 | 504 | 1 | 2.00822 | 0.502552 | 0.999739 | -83.5956 | 31.0 | 1.0 | |
| 5 | 505 | 1 | 2.02514 | 0.461803 | 0.999763 | -83.666 | 15.0 | 1.0 | |
| 6 | 506 | 1 | 1.99927 | 0.465928 | 0.999722 | -81.8424 | 7.0 | 1.0 | |
| 7 | 507 | 1 | 1.79381 | 0.488809 | 0.999205 | -81.2625 | 63.0 | 1.0 | |
| 8 | 508 | 1 | 1.79029 | 0.490384 | 0.999199 | -81.466 | 3.0 | 1.0 | |
| 9 | 509 | 1 | 1.79489 | 0.471104 | 0.999239 | -80.8799 | 15.0 | 1.0 | |
| 10 | 510 | 1 | 1.89717 | 0.474568 | 0.999502 | -79.587 | 15.0 | 1.0 | |
| … | | | | | | | | | |
describe(chn)
2-element Vector{ChainDataFrame}: Summary Statistics (3 x 8) Quantiles (3 x 6)
pairplot(chn)
Posterior predictive plotting
let
gens = generated_quantities(deterministic_uncond_mdl, chn)
plot_predYt(data, gens;
title = "Fitted deterministic model",
ylabel = "Number of Infected students"
)
end
Inference for the stochastic SIR model
Chatzilena et al present an auto-regressive model for connecting the outcome of the ODE model to illness observations. The argument is that the stochastic component of the model can absorb noise generated by possible mis-specification of the model.
In their approach they consider \(\kappa_t = \log \lambda_t\) where \(\kappa_t\) evolves according to an Ornstein-Uhlenbeck process:
$$d\kappa_t = \phi(\mu_t - \kappa_t) dt + \sigma dB_t.$$
Which has transition density:
$$\kappa_{t+1} | \kappa_t \sim N\Big(\mu_t + \left(\kappa_t - \mu_t\right)e^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big).$$
Where \(\mu_t = \log(I(t))\).
We modify this approach since it implies that \(\mu_t\) is treated as constant between observation times.
Instead we redefine \(\kappa_t\) as the log-residual:
$$\kappa_t = \log(\lambda_t / I(t)).$$
With the transition density:
$$\kappa_{t+1} | \kappa_t \sim N\Big(\kappa_te^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big).$$
This is an AR(1) process.
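Written in standard AR(1) form, with the damping (autocorrelation) parameter and innovation standard deviation made explicit, this is:
$$\kappa_{t+1} = e^{-\phi} \kappa_t + \epsilon_t, \qquad \epsilon_t \sim N\left(0,\ {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\right),$$
which is the mapping used in the prior-matching calculation below.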
The stochastic model is completed by:
$$\begin{aligned} Y_t &\sim \text{Poisson}(\lambda_t)\\ \lambda_t &= I(t)\exp(\kappa_t)\\ \beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ \gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 500)\\ S(0) /N &\sim \text{Beta}(0.5, 0.5)\\ \phi & \sim \text{HalfNormal}(0, 100) \\ 1 / \sigma^2 & \sim \text{InvGamma}(0.1,0.1). \end{aligned}$$
We will use the AR struct from EpiAware to define the auto-regressive process in this model; this struct has a direct parameterisation of the AR model.
To convert from the formulation above, we sample from the priors and define HalfNormal priors based on the sampled prior means of \(e^{-\phi}\) and \({\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\). We also add a strong prior that \(\kappa_1 \approx 0\).
ϕs = rand(truncated(Normal(0, 100), lower = 0.0), 1000)
1000-element Vector{Float64}: 84.27394515942191 13.516491690956862 51.07348186961277 37.941468070981934 128.41727813505105 43.06012859066134 62.31804897315879 ⋮ 56.57116875489856 158.33706887743045 42.72304061442974 7.423694327684998 155.60429115685992 22.802727733585563
σ²s = rand(InverseGamma(0.1, 0.1), 1000) .|> x -> 1 / x
1000-element Vector{Float64}: 0.0016224742151858818 6.79221353591839e-9 6.207746413070522e-7 0.18882277475797452 0.0001662633660039789 0.1923483831345634 0.14764829136880042 ⋮ 0.06624877782984823 0.14836794638364514 0.00021895942825830565 2.209773387224151 0.06613574232694587 0.0026714312973339926
sampled_AR_damps = ϕs .|> ϕ -> exp(-ϕ)
1000-element Vector{Float64}: 2.5135680594819346e-37 1.3485350660539842e-6 6.592781044298219e-23 3.3283560716429985e-17 1.6946683748176592e-56 1.991699264693254e-19 8.622142732783223e-28 ⋮ 2.7005584094809084e-25 1.7182434846473966e-69 2.7900964146464195e-19 0.0005969397758191972 2.641891576222659e-68 1.249974556559806e-10
sampled_AR_stds = map(ϕs, σ²s) do ϕ, σ²
(1 - exp(-2 * ϕ)) * σ² / (2 * ϕ)
end
1000-element Vector{Float64}: 9.626191179946722e-6 2.5125652762581625e-10 6.0772696376159436e-9 0.00248834302358464 6.473559026423481e-7 0.002233485935017376 0.001184635059999989 ⋮ 0.0005855348164793897 0.00046851930326718863 2.562545000417783e-6 0.14883240757631075 0.00021251259150776393 5.857701167477672e-5
We define the AR(1) process by matching the means of HalfNormal prior distributions for the damping and standard deviation parameters to the prior means calculated from the Chatzilena et al definition.
ar = AR(
damp_priors = [HalfNormal(mean(sampled_AR_damps))],
std_prior = HalfNormal(mean(sampled_AR_stds)),
init_priors = [Normal(0, 0.001)]
)
AR{…}(Distributions.Product(v=Fill(HalfNormal{Float64}(μ=0.004725237126863895), 1)), HalfNormal{Float64}(μ=0.0184303247003225), DistributionsAD.TuringScalMvNormal(m=[0.0], σ=0.001), 1)
We can sample directly from the behaviour specified by the ar struct to do prior predictive checking on the AR(1) process.
let
nobs = size(data, 1)
ar_mdl = generate_latent(ar, nobs)
fig = Figure()
ax = Axis(fig[1, 1],
xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string),
ylabel = "exp(kt)",
title = "Prior predictive sampling for relative residual in mean pred."
)
for i in 1:500
lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15))
end
fig
end
We see that the choice of priors implies an a priori belief that the extra observation noise on the mean prediction of the ODE model is fairly small, approximately 10% relative to the mean prediction.
We can now define the probabilistic model. The stochastic model assumes a (random) time-varying ascertainment, which we implement using the Ascertainment struct from EpiAware. Note that instead of implementing the ascertainment factor exp.(κₜ) directly, which can be unstable for large primal values, by default Ascertainment uses the LogExpFunctions.xexpy function, which implements \(x\exp(y)\) stably for a wide range of values.
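A small illustration (values approximate):

# xexpy(x, y) computes x * exp(y) in a numerically careful way; in this model it
# scales the ODE mean prediction by the multiplicative noise exp(κt)
xexpy(100.0, 0.1)  # ≈ 110.52: roughly 10% multiplicative noise on an expectation of 100
xexpy(100.0, -0.1) # ≈ 90.48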
To distinguish random variables sampled by various sub-processes, EpiAware process types create prefixes. The default for Ascertainment is just the string "Ascertainment", but in this case we use the less verbose "va" for "varying ascertainment". With this prefix, the latent process parameters appear in the sampled chains under names such as va.σ_AR and va.damp_AR[1], as seen in the outputs further below.
mdl_prefix = "va"
"va"
Now we can construct our time-varying ascertainment model. The main keyword arguments here are model and latent_model. model sets the connection between the expected observation and the actual observation; in this case, we reuse our PoissonError model from above. latent_model sets the modification model on the expected values; in this case, we use the AR process we defined above.
varying_ascertainment = Ascertainment(
model = obs,
latent_model = ar,
latent_prefix = mdl_prefix
)
Ascertainment{PoissonError, AbstractTuringLatentModel, EpiAware.EpiObsModels.var"#10#16", String}(PoissonError(), PrefixLatentModel{AR{…}, String}(AR{…}(…), "va"), EpiAware.EpiObsModels.var"#10#16"(), "va")
Now we can declare the full model in the Turing PPL.
@model function stochastic_ode_mdl(y_t, ts, obs, prob, N;
solver = AutoTsit5(Rosenbrock23())
)
##Priors##
β ~ LogNormal(0.0, 1.0)
γ ~ Gamma(0.004, 1 / 0.002)
S₀ ~ Beta(0.5, 0.5)
##Remake ODE model##
_prob = remake(prob;
u0 = [S₀, 1 - S₀, 0.0],
p = [β, γ]
)
##Solve ODE model##
sol = solve(_prob, solver;
saveat = ts,
verbose = false
)
λt = log1pexp.(N * sol[2, :]) # expected I(t), as in the deterministic model
##Observation##
@submodel generated_y_t = generate_observations(obs, y_t, λt)
##Generated quantities##
return (; sol, generated_y_t, R0 = β / γ)
end
stochastic_ode_mdl (generic function with 2 methods)
stochastic_mdl = stochastic_ode_mdl(
data.in_bed,
data.ts,
varying_ascertainment,
sir_prob,
N
)
DynamicPPL.Model{typeof(stochastic_ode_mdl), (:y_t, :ts, :obs, :prob, :N), (:solver,), (), …}(stochastic_ode_mdl, (y_t = [3, 8, 26, 76, 225, 298, 258, 233, 189, 128, 68, 29, 14, 4], ts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0], obs = Ascertainment{PoissonError, …}(…), prob = ODEProblem{…}(…), N = 763), (solver = CompositeAlgorithm{…}(…),), DynamicPPL.DefaultContext())
stochastic_uncond_mdl = stochastic_ode_mdl(
fill(missing, length(data.in_bed)),
data.ts,
varying_ascertainment,
sir_prob,
N
)
DynamicPPL.Model{typeof(stochastic_ode_mdl), (:y_t, :ts, :obs, :prob, :N), (:solver,), (), …}(stochastic_ode_mdl, (y_t = [missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing, missing], ts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0], obs = Ascertainment{PoissonError, …}(…), prob = ODEProblem{…}(…), N = 763), (solver = CompositeAlgorithm{…}(…),), DynamicPPL.DefaultContext())
Prior predictive checking
let
prior_chn = sample(stochastic_uncond_mdl, Prior(), 2000)
gens = generated_quantities(stochastic_uncond_mdl, prior_chn)
plot_predYt(data, gens;
title = "Prior predictive: stochastic model",
ylabel = "Number of Infected students"
)
end
The prior predictive checking again shows misaligned prior beliefs; for example, a priori and without data, we would not expect the median prediction for the number of ill children to be about 600 out of 763 after 1 day.
Maximum likelihood estimation doesn't make much sense for the latent log-residual process \(\kappa_t\), since without priors it is unconstrained; instead, we look for a reasonable MAP point from which to start NUTS. We do this by first making an initial guess which is a mixture of:

The posterior averages from the deterministic model.
The prior averages of the structure parameters of the AR(1) process.
Zero for the time-varying noise underlying the AR(1) process.
rand(stochastic_mdl)
(β = 1.4733099145592605, γ = 2.903750758256854e-123, S₀ = 0.29861836011258897, var"va.σ_AR" = 0.04830504386163741, var"va.ar_init" = [-0.00024863122657975786], var"va.damp_AR" = [0.0032571734979405884], var"va.ϵ_t" = [0.3072398792156006, -1.183649965567883, 2.771050948892893, -0.6366422192999562, 1.6191332959597484, 0.24589190588482895, 1.4615005554123257, 0.025353011915720307, 0.16407599045634794, 0.2628599221133207, -1.0048884450877293, 1.96700665270484, -0.7501415436101209])
initial_guess = [[mean(chn[:β]),
    mean(chn[:γ]),
    mean(chn[:S₀]),
    mean(ar.std_prior),
    mean(ar.init_prior)[1],
    mean(ar.damp_prior)[1]]
    zeros(13)] # ϵ_t noise terms set to zero; ordering matches rand(stochastic_mdl) above
19-element Vector{Float64}: 1.8942148283773665 0.48062141906187955 0.9995061985155343 0.0184303247003225 0.0 0.004725237126863895 0.0 ⋮ 0.0 0.0 0.0 0.0 0.0 0.0
Starting from the initial guess, the MAP point is calculated rapidly in one pass.
map_fit_stoch_mdl = maximum_a_posteriori(stochastic_mdl;
adtype = AutoReverseDiff(),
initial_params = initial_guess
)
ModeResult with maximized lp of -69.56 [1.9168299393361845, 0.48970414611353336, 0.9995563465737589, 0.06675569399815419, 1.3740571956377935e-6, 0.0001575540787629065, 0.14269438141615298, 0.17055297942269643, -0.2985981800136858, 0.6377161610468507, -0.008381874345876414, -0.5911576922746032, 0.7987402092625825, 1.7391572458510847, 1.438270023561078, 0.2451580311597817, -0.6799723083106257, -0.7437116269695498, -0.8064297334547238]
Now we can run NUTS, sampling 1000 posterior draws per chain for 4 chains.
chn2 = sample(
stochastic_mdl,
NUTS(; adtype = AutoReverseDiff(true)),
MCMCThreads(), 1000, 4;
initial_params = fill(map_fit_stoch_mdl.values.array, 4)
)
| | iteration | chain | β | γ | S₀ | va.σ_AR | va.ar_init[1] | va.damp_AR[1] | … |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 501 | 1 | 1.83773 | 0.471008 | 0.999432 | 0.0492726 | 0.000776343 | 0.0106267 | |
| 2 | 502 | 1 | 1.90809 | 0.493836 | 0.999513 | 0.0416465 | -0.000522202 | 0.00238174 | |
| 3 | 503 | 1 | 1.83978 | 0.484794 | 0.9994 | 0.0329526 | 0.00010738 | 0.00806697 | |
| 4 | 504 | 1 | 1.83873 | 0.482094 | 0.999383 | 0.028583 | -0.000843394 | 0.0110785 | |
| 5 | 505 | 1 | 1.86047 | 0.481308 | 0.999422 | 0.0841 | -0.000560455 | 0.00169988 | |
| 6 | 506 | 1 | 1.88082 | 0.504821 | 0.999495 | 0.0357006 | -0.000644581 | 0.0140381 | |
| 7 | 507 | 1 | 1.85906 | 0.492325 | 0.999433 | 0.0699084 | -0.00108052 | 0.000536058 | |
| 8 | 508 | 1 | 1.9207 | 0.487275 | 0.999568 | 0.0541636 | -0.00114944 | 0.0120675 | |
| 9 | 509 | 1 | 2.02692 | 0.507493 | 0.999699 | 0.0512195 | -0.000289786 | 0.00205155 | |
| 10 | 510 | 1 | 1.93871 | 0.50499 | 0.999638 | 0.0592997 | -0.000668149 | 0.0108368 | |
| … | | | | | | | | | |
describe(chn2)
2-element Vector{ChainDataFrame}: Summary Statistics (19 x 8) Quantiles (19 x 6)
pairplot(chn2[[:β, :γ, :S₀, Symbol(mdl_prefix * ".σ_AR"),
Symbol(mdl_prefix * ".ar_init[1]"), Symbol(mdl_prefix * ".damp_AR[1]")]])
let
vars = mapreduce(vcat, 1:13) do i
Symbol(mdl_prefix * ".ϵ_t[$i]")
end
pairplot(chn2[vars])
end
let
gens = generated_quantities(stochastic_uncond_mdl, chn2)
plot_predYt(data, gens;
title = "Fitted stochastic model",
ylabel = "Number of Infected students"
)
end
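As a final usage sketch (not part of the original analysis), the generated quantities can also be post-processed directly, for example to summarise the posterior of the basic reproduction number returned by the stochastic model:

let
    gens = generated_quantities(stochastic_uncond_mdl, chn2)
    R0_samples = [gen.R0 for gen in gens] |> vec # flatten the draws × chains matrix
    (mean = mean(R0_samples), CI_95 = quantile(R0_samples, [0.025, 0.975]))
end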