Example: Statistical inference for ODE-based infectious disease models

Introduction

What are we going to do in this vignette

In this vignette, we'll demonstrate how to use EpiAware together with the SciML ecosystem for Bayesian inference of infectious disease dynamics. The model and data are heavily based on Contemporary statistical inference for infectious disease models using Stan (Chatzilena et al. 2019).

We'll cover the following key points:

  1. Define the deterministic ODE model from Chatzilena et al. section 2.2.2 using SciML ODE functionality and an EpiAware observation model.

  2. Build on this to define the stochastic ODE model from Chatzilena et al. section 2.2.3 using an EpiAware observation model.

  3. Fit the deterministic ODE model to data from an influenza outbreak in an English boarding school.

  4. Fit the stochastic ODE model to data from an influenza outbreak in an English boarding school.

What might I need to know before starting

This vignette builds on concepts from EpiAware observation models. Familiarity with the SciML and Turing ecosystems would be useful but is not essential.

Packages used in this vignette

Alongside the EpiAware package, we will use the OrdinaryDiffEq and SciMLSensitivity packages to interface with the SciML ecosystem. Using these two packages is a lower-dependency alternative to loading all of DifferentialEquations.jl. They provide, respectively, ODE solvers and adjoint methods for ODE solves; that is, methods for propagating parameter derivatives through functions containing ODE solutions. Bayesian inference will be done with NUTS from the Turing ecosystem, with ReverseDiff as an automatic differentiation backend. We will also use CairoMakie and PairPlots for plotting, and CSV and DataFramesMeta for data loading and manipulation.

using EpiAware
using Turing
using OrdinaryDiffEq, SciMLSensitivity #ODE solvers and adjoint methods
using Distributions, Statistics, LogExpFunctions #Statistics and special func packages
using CSV, DataFramesMeta #Data wrangling
using CairoMakie, PairPlots
using ReverseDiff #Automatic differentiation backend
begin #Date utility and set Random seed
    using Dates
    using Random
    Random.seed!(1234)
end
TaskLocalRNG()

Single population SIR model

As noted by Chatzilena et al., disease spread is frequently modelled using ODE-based models. The study population is divided into compartments, each representing a specific stage of the epidemic. In this case, the compartments are susceptible, infected, and recovered individuals:

$$\begin{aligned} {dS \over dt} &= - \beta \frac{I(t)}{N} S(t) \\ {dI \over dt} &= \beta \frac{I(t)}{N} S(t) - \gamma I(t) \\ {dR \over dt} &= \gamma I(t). \\ \end{aligned}$$

where S(t), I(t) and R(t) are the numbers of susceptible, infected and recovered individuals at time t. The total population size is denoted by N (with N = S(t) + I(t) + R(t)), β denotes the transmission rate and γ the recovery rate.

We can interface to the SciML ecosystem by writing a function with the signature:

(du, u, p, t) -> nothing

Where:

  • du is the vector field of the ODE problem, i.e. the time derivatives \({dS \over dt}\), \({dI \over dt}\) etc. It is written in place, which Julia conventionally signals with a ! at the end of the function name.

  • u is the state of the ODE problem, e.g. \(S\), \(I\), etc.

  • p is an object that represents the parameters of the ODE problem, e.g. \(\beta\), \(\gamma\).

  • t is the time of the ODE problem.

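We do this for the SIR model described above in a function called sir!. The state is expressed as proportions of the population, so N = 1 and the division by N in the equations above drops out: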

function sir!(du, u, p, t)
    S, I, R = u # state as proportions of the population
    β, γ = p    # transmission and recovery rates
    du[1] = -β * I * S         # dS/dt
    du[2] = β * I * S - γ * I  # dI/dt
    du[3] = γ * I              # dR/dt

    return nothing
end
sir! (generic function with 1 method)
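The sketch below makes the (du, u, p, t) signature concrete by showing how a solver repeatedly calls an in-place function such as sir!. It hand-rolls a fixed-step fourth-order Runge-Kutta loop purely for illustration. OrdinaryDiffEq's solvers use adaptive step sizes and considerably more machinery.

A few assumptions apply:

  • The names sir_demo! and rk4_solve are hypothetical. sir_demo! is a copy of sir! so that the block is self-contained.

  • The state is in population proportions.

  • The parameter values β = 1.9 and γ = 0.48 are chosen to be close to the maximum likelihood estimates reported later in this vignette.

```julia
# Copy of sir! above (state as population proportions), so this block is self-contained.
function sir_demo!(du, u, p, t)
    S, I, R = u
    β, γ = p
    du[1] = -β * I * S
    du[2] = β * I * S - γ * I
    du[3] = γ * I
    return nothing
end

# Hypothetical fixed-step RK4 integrator: each stage calls f!(du, u, p, t),
# writing the derivatives into a preallocated buffer.
function rk4_solve(f!, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    k1, k2, k3, k4 = similar(u), similar(u), similar(u), similar(u)
    traj = [copy(u)]  # store the state after every step
    t = t0
    for _ in 1:nsteps
        f!(k1, u, p, t)
        f!(k2, u .+ (dt / 2) .* k1, p, t + dt / 2)
        f!(k3, u .+ (dt / 2) .* k2, p, t + dt / 2)
        f!(k4, u .+ dt .* k3, p, t + dt)
        u .+= (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        push!(traj, copy(u))
    end
    return traj
end

# Integrate over the same 14-day span as sir_prob, starting with 1% infected.
traj = rk4_solve(sir_demo!, [0.99, 0.01, 0.0], (1.9, 0.48), (0.0, 14.0), 0.01)
```

The derivatives at each stage sum to zero, so S + I + R stays equal to 1 up to rounding error. This gives a quick sanity check on any SIR right-hand side.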

We combine the vector field function sir! with an initial condition u0 and the integration period tspan to make an ODEProblem. We do not define the parameters here. Both the parameters and the initial condition are replaced within the inference models below using remake.

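N = 763 # total number of children at the boarding school
sir_prob = ODEProblem(
    sir!,
    N .* [0.99, 0.01, 0.0], # placeholder initial condition; replaced via remake in the models below
    (0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1) # 14 days of observations
)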
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 14.0)
u0: 3-element Vector{Float64}:
 755.37
   7.63
   0.0

Note that this is analogous to the EpiProblem approach we expose from EpiAware, as used in the Mishra et al. replication. The difference is that here we use ODE solvers from the SciML ecosystem to generate the underlying infection dynamics. The Mishra et al. replication instead uses latent processes exposed by EpiAware as the generative process for those dynamics.

Data for inference

There was a brief but intense outbreak of influenza within the (semi-)closed community of a boarding school, reported to the British Medical Journal in 1978. The outbreak lasted from 22nd January to 4th February. It is reported that one infected child started the epidemic, which then spread rapidly. Of the 763 children at the boarding school, 512 became ill.

We downloaded the data for this outbreak using the R package outbreaks, which is maintained as part of the R Epidemics Consortium (RECON).

data = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/main/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" |>
       url -> CSV.read(download(url), DataFrame) |>
              df -> @transform(df,
    :ts=(:date .- minimum(:date)) .|> d -> d.value + 1.0,)
Column1  date        in_bed  convalescent  ts
1        1978-01-22  3       0             1.0
2        1978-01-23  8       0             2.0
3        1978-01-24  26      0             3.0
4        1978-01-25  76      0             4.0
5        1978-01-26  225     9             5.0
6        1978-01-27  298     17            6.0
7        1978-01-28  258     105           7.0
8        1978-01-29  233     162           8.0
9        1978-01-30  189     176           9.0
10       1978-01-31  128     166           10.0
11       1978-02-01  68      150           11.0
12       1978-02-02  29      85            12.0
13       1978-02-03  14      47            13.0
14       1978-02-04  4       20            14.0
N = 763;
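The time index ts and the 14-day integration period used above can be cross-checked with a short, self-contained sketch using the Dates standard library. The counts are copied from the table above. The names in_bed, N_school, peak_day, peak_fraction and attack_fraction are illustrative and are not part of the vignette.

```julia
using Dates

# Number of children "in bed" on each day, copied from the data table above.
in_bed = [3, 8, 26, 76, 225, 298, 258, 233, 189, 128, 68, 29, 14, 4]
N_school = 763

# Time index as in the vignette: days since 22 January 1978, starting at 1.
first_day = Date(1978, 1, 22)
ts = [Dates.value(d - first_day) + 1.0 for d in first_day:Day(1):Date(1978, 2, 4)]

peak_day = first_day + Day(argmax(in_bed) - 1)  # date with the most children in bed
peak_fraction = maximum(in_bed) / N_school       # share of the school in bed at the peak
attack_fraction = 512 / N_school                 # share reported ill over the whole outbreak
```

The peak of 298 children in bed on 27 January is about 39% of the school, while about 67% of the children were reported ill over the whole outbreak.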

Inference for the deterministic SIR model

The boarding school data gives the number of children "in bed" and "convalescent" on each of 14 days from 22nd Jan to 4th Feb 1978. We follow Chatzilena et al and treat the number "in bed" as a proxy for the number of children in the infectious (I) compartment in the ODE model.

The full observation model is:

$$\begin{aligned} Y_t &\sim \text{Poisson}(\lambda_t)\\ \lambda_t &= I(t)\\ \beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ \gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 50)\\ S(0) /N &\sim \text{Beta}(0.5, 0.5). \end{aligned}$$

NB: Chatzilena et al. give \(\lambda_t = \int_0^t \left(\beta \frac{I(s)}{N} S(s) - \gamma I(s)\right) ds = I(t) - I(0).\) However, this doesn't match their underlying Stan code, so we use \(\lambda_t = I(t)\) as above.

From EpiAware, we have the PoissonError struct which defines the probabilistic structure of this observation error model.

obs = PoissonError()
PoissonError()
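The sketch below computes the log-likelihood that the observation model \(Y_t \sim \text{Poisson}(\lambda_t)\) contributes for a vector of counts. It assumes nothing about how PoissonError is implemented internally, and poisson_loglik is a hypothetical name. It assumes every \(\lambda_t > 0\), which is what the softplus transform in the model below guarantees.

```julia
# Log-likelihood of counts y under independent Poisson(λ) observations:
# the sum over t of  y_t * log(λ_t) - λ_t - log(y_t!).
function poisson_loglik(y::AbstractVector{<:Integer}, λ::AbstractVector{<:Real})
    length(y) == length(λ) || throw(DimensionMismatch("y and λ must have the same length"))
    ll = 0.0
    for (yt, λt) in zip(y, λ)
        logfact = yt == 0 ? 0.0 : sum(log, 1:yt)  # log(y_t!) as a sum of logs
        ll += yt * log(λt) - λt - logfact
    end
    return ll
end

# Example: first three observed counts against some candidate expected values.
poisson_loglik([3, 8, 26], [2.5, 9.0, 24.0])
```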

Now we can write the probabilistic model using the Turing PPL. Rather than using \(I(t)\) directly, we apply the softplus transform, implemented by LogExpFunctions.log1pexp, to the expected number of infected children. This is because the ODE solver can return small negative numbers. The softplus transform smoothly enforces positivity while staying very close to its input once the input exceeds about 2.
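The following sketch shows numerically how softplus behaves on solver-like inputs. The function softplus is a hypothetical stand-in written from the definition \(\log(1 + e^x)\); the vignette itself uses LogExpFunctions.log1pexp. Branching on the sign of x is one standard way to avoid overflow in exp(x).

```julia
# Numerically stable softplus, log(1 + exp(x)). For large positive x, exp(x) would
# overflow, so we use the equivalent form x + log1p(exp(-x)) there.
softplus(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

softplus(-1e-3)   # a slightly negative value maps to roughly log(2) ≈ 0.69, which is positive
softplus(5.0)     # about 5.0067, already very close to the input
softplus(1000.0)  # 1000.0, with no overflow
```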

@model function deterministic_ode_mdl(y_t, ts, obs, prob, N;
        solver = AutoTsit5(Rosenbrock23())
)
    ##Priors##
    β ~ LogNormal(0.0, 1.0)
    γ ~ Gamma(0.004, 1 / 0.002)
    S₀ ~ Beta(0.5, 0.5)

    ##remake ODE model##
    _prob = remake(prob;
        u0 = [S₀, 1 - S₀, 0.0],
        p = [β, γ]
    )

    ##Solve remade ODE model##

    sol = solve(_prob, solver;
        saveat = ts,
        verbose = false)

    ##log-like accumulation using obs##
    λt = log1pexp.(N * sol[2, :]) # expected I(t) as a count, kept positive by softplus
    @submodel generated_y_t = generate_observations(obs, y_t, λt)

    ##Generated quantities##
    return (; sol, generated_y_t, R0 = β / γ)
end
deterministic_ode_mdl (generic function with 2 methods)

We instantiate the model in two ways:

  1. deterministic_mdl: This conditions the generative model on the data observation. We can sample from this model to find the posterior distribution of the parameters.

  2. deterministic_uncond_mdl: This doesn't condition on the data. This is useful for prior and posterior predictive modelling.

Here we construct the Turing model directly. In the Mishra et al. replication, we used the EpiProblem functionality to build a Turing model under the hood. Because this note mixes functionality from SciML and EpiAware, we construct the model to sample from ourselves.

deterministic_mdl = deterministic_ode_mdl(data.in_bed, data.ts, obs, sir_prob, N);
deterministic_uncond_mdl = deterministic_ode_mdl(
    fill(missing, length(data.in_bed)), data.ts, obs, sir_prob, N);

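We also define a plotting utility that shows the predictive median and the 50%, 80% and 95% intervals of the generated observations alongside the data.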

function plot_predYt(data, gens; title::String, ylabel::String)
    fig = Figure()
    ga = fig[1, 1:2] = GridLayout()

    ax = Axis(ga[1, 1];
        title = title,
        xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string),
        ylabel = ylabel
    )
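    # Stack generated observations into a (time × draws) matrix, then take per-day
    # quantiles: the median, then the 95%, 80% and 50% interval bounds.
    pred_Yt = mapreduce(hcat, gens) do gen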
        gen.generated_y_t
    end |> X -> mapreduce(vcat, eachrow(X)) do row
        quantile(row, [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75])'
    end

    lines!(ax, data.ts, pred_Yt[:, 1]; linewidth = 3, color = :green, label = "Median")
    band!(
        ax, data.ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.2), label = "95% CI")
    band!(
        ax, data.ts, pred_Yt[:, 4], pred_Yt[:, 5], color = (:green, 0.4), label = "80% CI")
    band!(
        ax, data.ts, pred_Yt[:, 6], pred_Yt[:, 7], color = (:green, 0.6), label = "50% CI")
    scatter!(ax, data.in_bed, label = "data")
    leg = Legend(ga[1, 2], ax; framevisible = false)
    hidespines!(ax)

    fig
end
plot_predYt (generic function with 1 method)
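The core of plot_predYt is a per-day quantile summary of the predictive draws. The sketch below reproduces just that step with the Statistics standard library. It uses randomly generated stand-in draws rather than vignette output, and the names draws, Y and pred_summary are hypothetical.

```julia
using Statistics, Random

# Hypothetical stand-in for `gens`: each draw carries generated observations,
# one entry per observation day.
rng = Xoshiro(1)
draws = [(; generated_y_t = rand(rng, 0:20, 14)) for _ in 1:500]

# Stack the draws column-wise into a (days × draws) matrix.
Y = reduce(hcat, [d.generated_y_t for d in draws])

# For each day: median, then 95%, 80% and 50% interval bounds (same order as plot_predYt).
probs = [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75]
pred_summary = reduce(vcat, [quantile(row, probs)' for row in eachrow(Y)])
```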

Prior predictive sampling

let
    prior_chn = sample(deterministic_uncond_mdl, Prior(), 2000)
    gens = generated_quantities(deterministic_uncond_mdl, prior_chn)
    plot_predYt(data, gens;
        title = "Prior predictive: deterministic model",
        ylabel = "Number of Infected students"
    )
end

The prior predictive check suggests that, a priori, our parameter beliefs are very far from the data, so approaching inference naively can lead to poor fits.

We do two things to mitigate this:

  1. We choose a switching ODE solver which switches between explicit (Tsit5) and implicit (Rosenbrock23) solvers. This helps avoid the ODE solver failing when the sampler tries extreme parameter values. This is the default solver = AutoTsit5(Rosenbrock23()) above.

  2. We locate the maximum likelihood point (that is, ignoring the influence of the priors) and use it as a starting point for NUTS.

nmle_tries = 100
100
mle_fit = map(1:nmle_tries) do _
    # Repeat the optimisation; a failed attempt is recorded with lp = -Inf
    # so that it can never be selected as the best fit.
    try
        maximum_likelihood(deterministic_mdl)
    catch
        (lp = -Inf,)
    end
end |> fits -> argmax(fit -> fit.lp, fits) # keep the attempt with the highest log-likelihood
ModeResult with maximized lp of -67.36
[1.8991528341217605, 0.4808836287362608, 0.9995360155493858]
mle_fit.optim_result.retcode
ReturnCode.Success = 1

Note that we keep the best of the 100 MLE attempts, that is, the one with the highest log-likelihood.
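The sketch below isolates the multi-start pattern on a toy problem: run several attempts, treat failures as lp = -Inf, and keep the best. The function toy_fit is a hypothetical stand-in for one optimisation attempt. It either returns a result carrying a log-likelihood lp or throws, mimicking an ODE solve that fails from a bad starting point.

```julia
using Random

# Hypothetical stand-in for one optimisation attempt: draws a starting point,
# throws for "bad" starts, and otherwise returns a result with a log-likelihood lp.
function toy_fit(rng)
    x0 = randn(rng)
    abs(x0) > 2 && error("optimiser failed from start $x0")
    return (lp = -(x0 - 0.5)^2, x = x0)
end

rng = Xoshiro(1234)
fits = map(1:100) do _
    try
        toy_fit(rng)
    catch
        (lp = -Inf, x = NaN)  # a failed attempt can never be selected
    end
end

best = argmax(fit -> fit.lp, fits)  # the attempt with the highest lp (Julia ≥ 1.7)
```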

Now we run NUTS, drawing 1000 samples for each of 4 chains, with every chain initialised at the MLE.

chn = sample(
    deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4;
    initial_params = fill(mle_fit.values.array, 4)
)
iteration  chain  β        γ         S₀        ...
501        1      1.92453  0.498762  0.999601  ...
502        1      1.91393  0.510477  0.999503  ...
503        1      1.82063  0.454811  0.999413  ...
504        1      2.00822  0.502552  0.999739  ...
505        1      2.02514  0.461803  0.999763  ...
506        1      1.99927  0.465928  0.999722  ...
507        1      1.79381  0.488809  0.999205  ...
508        1      1.79029  0.490384  0.999199  ...
509        1      1.79489  0.471104  0.999239  ...
510        1      1.89717  0.474568  0.999502  ...
...
describe(chn)
2-element Vector{ChainDataFrame}:
 Summary Statistics (3 x 8)
 Quantiles (3 x 6)
pairplot(chn)

Posterior predictive plotting

let
    gens = generated_quantities(deterministic_uncond_mdl, chn)
    plot_predYt(data, gens;
        title = "Fitted deterministic model",
        ylabel = "Number of Infected students"
    )
end

Inference for the stochastic SIR model

Chatzilena et al. present an autoregressive model for connecting the output of the ODE model to the illness observations. They argue that the stochastic component of the model can absorb noise generated by a possible mis-specification of the model.

In their approach they consider \(\kappa_t = \log \lambda_t\) where \(\kappa_t\) evolves according to an Ornstein-Uhlenbeck process:

$$d\kappa_t = \phi(\mu_t - \kappa_t) dt + \sigma dB_t.$$

This has the transition density:

$$\kappa_{t+1} | \kappa_t \sim N\Big(\mu_t + \left(\kappa_t - \mu_t\right)e^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big).$$

Where \(\mu_t = \log(I(t))\).

We modify this approach because it implies that \(\mu_t\) is treated as constant between observation times.
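The constant-\(\mu_t\) assumption can be seen by solving the Ornstein-Uhlenbeck SDE over one step. Assume that \(\mu_t\) is held fixed over the interval \([t, t+1]\) and that observations are one time unit apart. Then:

$$\begin{aligned} \kappa_{t+1} &= \mu_t + \left(\kappa_t - \mu_t\right)e^{-\phi} + \sigma \int_t^{t+1} e^{-\phi(t+1-s)}\, dB_s, \\ \text{Var}\left(\sigma \int_t^{t+1} e^{-\phi(t+1-s)}\, dB_s\right) &= \sigma^2 \int_0^1 e^{-2\phi u}\, du = {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right). \end{aligned}$$

The stochastic integral is Gaussian with mean zero. Its variance follows from the Itô isometry with the substitution \(u = t + 1 - s\), which recovers the mean and variance of the transition density above.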

Instead we redefine \(\kappa_t\) as the log-residual:

$$\kappa_t = \log(\lambda_t / I(t)).$$

With the transition density:

$$\kappa_{t+1} | \kappa_t \sim N\Big(\kappa_te^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big).$$

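This is an AR(1) process with damping coefficient \(e^{-\phi}\) and innovation variance \({\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\).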
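The sketch below simulates this AR(1) recursion for the log-residuals directly. It is not EpiAware's AR implementation, and the function name simulate_log_residuals and the values φ = 0.5, σ = 0.2 are assumptions for illustration. It also checks numerically that the recursion has stationary variance \(\sigma^2 / 2\phi\), the same as the continuous-time OU process.

```julia
using Random, Statistics

"""
Simulate κ_1, …, κ_n from the AR(1) transition
κ_{t+1} | κ_t ~ N(κ_t e^{-φ}, σ²(1 - e^{-2φ}) / (2φ)), starting from κ_1 = κ1.
"""
function simulate_log_residuals(rng, n, φ, σ; κ1 = 0.0)
    damp = exp(-φ)                                # AR(1) damping coefficient
    innov_sd = sqrt(σ^2 * (1 - exp(-2φ)) / (2φ))  # innovation standard deviation
    κ = Vector{Float64}(undef, n)
    κ[1] = κ1
    for t in 1:(n - 1)
        κ[t + 1] = damp * κ[t] + innov_sd * randn(rng)
    end
    return κ
end

κ = simulate_log_residuals(Xoshiro(42), 200_000, 0.5, 0.2)
var(κ)  # close to the stationary variance σ² / (2φ) = 0.04
```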

The full stochastic model is:

$$\begin{aligned} Y_t &\sim \text{Poisson}(\lambda_t)\\ \lambda_t &= I(t)\exp(\kappa_t)\\ \beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ \gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 50)\\ S(0) /N &\sim \text{Beta}(0.5, 0.5)\\ \phi & \sim \text{HalfNormal}(0, 100) \\ 1 / \sigma^2 & \sim \text{InvGamma}(0.1,0.1). \end{aligned}$$

We will use the AR struct from EpiAware to define the autoregressive process in this model. AR parameterises the AR model directly, in terms of its damping coefficients and innovation standard deviation.

To convert from the formulation above, we sample from the priors on \(\phi\) and \(\sigma^2\). We then define HalfNormal priors based on the sampled prior means of the damping coefficient \(e^{-\phi}\) and the innovation variance \({\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\). We also add a strong prior that \(\kappa_1 \approx 0\).

ϕs = rand(truncated(Normal(0, 100), lower = 0.0), 1000)
1000-element Vector{Float64}:
  84.27394515942191
  13.516491690956862
  51.07348186961277
  37.941468070981934
 128.41727813505105
  43.06012859066134
  62.31804897315879
   ⋮
  56.57116875489856
 158.33706887743045
  42.72304061442974
   7.423694327684998
 155.60429115685992
  22.802727733585563
σ²s = rand(InverseGamma(0.1, 0.1), 1000) .|> x -> 1 / x
1000-element Vector{Float64}:
 0.0016224742151858818
 6.79221353591839e-9
 6.207746413070522e-7
 0.18882277475797452
 0.0001662633660039789
 0.1923483831345634
 0.14764829136880042
 ⋮
 0.06624877782984823
 0.14836794638364514
 0.00021895942825830565
 2.209773387224151
 0.06613574232694587
 0.0026714312973339926
sampled_AR_damps = ϕs .|> ϕ -> exp(-ϕ)
1000-element Vector{Float64}:
 2.5135680594819346e-37
 1.3485350660539842e-6
 6.592781044298219e-23
 3.3283560716429985e-17
 1.6946683748176592e-56
 1.991699264693254e-19
 8.622142732783223e-28
 ⋮
 2.7005584094809084e-25
 1.7182434846473966e-69
 2.7900964146464195e-19
 0.0005969397758191972
 2.641891576222659e-68
 1.249974556559806e-10
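sampled_AR_stds = map(ϕs, σ²s) do ϕ, σ²
    # Innovation variance implied by (ϕ, σ²). Note that this is a variance; its sample
    # mean is used below to set the HalfNormal prior on the AR standard deviation.
    (1 - exp(-2 * ϕ)) * σ² / (2 * ϕ)
end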
1000-element Vector{Float64}:
 9.626191179946722e-6
 2.5125652762581625e-10
 6.0772696376159436e-9
 0.00248834302358464
 6.473559026423481e-7
 0.002233485935017376
 0.001184635059999989
 ⋮
 0.0005855348164793897
 0.00046851930326718863
 2.562545000417783e-6
 0.14883240757631075
 0.00021251259150776393
 5.857701167477672e-5

We define the AR(1) process by matching the means of HalfNormal priors for the damping and standard deviation parameters to the prior means calculated above from the Chatzilena et al. definition.

ar = AR(
    damp_priors = [HalfNormal(mean(sampled_AR_damps))],
    std_prior = HalfNormal(mean(sampled_AR_stds)),
    init_priors = [Normal(0, 0.001)]
)
AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, Int64}(Distributions.Product{Distributions.Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(HalfNormal{Float64}(μ=0.004725237126863895), 1)), HalfNormal{Float64}(μ=0.0184303247003225), DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}(m=[0.0], σ=0.001), 1)

We can sample directly from the behaviour specified by the ar struct to do prior predictive checking on the AR(1) process.

let
    nobs = size(data, 1)
    ar_mdl = generate_latent(ar, nobs)
    fig = Figure()
    ax = Axis(fig[1, 1],
        xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string),
        ylabel = "exp(kt)",
        title = "Prior predictive sampling for relative residual in mean pred."
    )
    for i in 1:500
        lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15))
    end
    fig
end

We see that the choice of priors implies an a priori belief that the extra observation noise on the mean prediction of the ODE model is fairly small, approximately 10% relative to the mean prediction.

We can now define the probabilistic model. The stochastic model assumes a (random) time-varying ascertainment, which we implement using the Ascertainment struct from EpiAware. Applying the ascertainment factor exp.(κₜ) directly can be numerically unstable for large primal values. Ascertainment therefore uses the LogExpFunctions.xexpy function by default, which computes \(x\exp(y)\) stably over a wide range of values.
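The sketch below illustrates one edge case a stable \(x\exp(y)\) must handle. When exp(y) overflows to Inf and x is zero, the naive product gives NaN. The function xexpy_sketch is hypothetical and is not the LogExpFunctions implementation; it only shows why a dedicated function is preferable to writing x * exp(y) inline.

```julia
# Hypothetical sketch of x * exp(y) that avoids the NaN a naive product gives
# when exp(y) overflows to Inf while x == 0.
function xexpy_sketch(x::Real, y::Real)
    result = x * exp(y)
    # 0 * Inf is NaN in floating point; for x == 0 and finite y the exact
    # value of x * exp(y) is zero, so return zero in that case.
    return iszero(x) && isfinite(y) ? zero(result) : result
end

0.0 * exp(1000.0)          # naive product: exp(1000.0) overflows to Inf, giving NaN
xexpy_sketch(0.0, 1000.0)  # the sketch returns 0.0 instead
```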

To distinguish random variables sampled by different sub-processes, EpiAware process types add prefixes to their variable names. The default prefix for Ascertainment is the string "Ascertainment", but here we use the less verbose "va" for "varying ascertainment".

mdl_prefix = "va"
"va"

Now we can construct our time-varying ascertainment model. The main keyword arguments here are model and latent_model:

  • model sets the connection between the expected observation and the actual observation. Here we reuse our PoissonError model from above.

  • latent_model sets the model that modifies the expected values. Here we use the AR process defined above.

  • latent_prefix sets the variable-name prefix chosen above.

varying_ascertainment = Ascertainment(
    model = obs,
    latent_model = ar,
    latent_prefix = mdl_prefix
)
Ascertainment{PoissonError, AbstractTuringLatentModel, EpiAware.EpiObsModels.var"#10#16", String}(PoissonError(), PrefixLatentModel{AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, Int64}, String}(AR{Product{Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}, HalfNormal{Float64}, DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}, Int64}(Distributions.Product{Distributions.Continuous, HalfNormal{Float64}, FillArrays.Fill{HalfNormal{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(HalfNormal{Float64}(μ=0.004725237126863895), 1)), HalfNormal{Float64}(μ=0.0184303247003225), DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}(m=[0.0], σ=0.001), 1), "va"), EpiAware.EpiObsModels.var"#10#16"(), "va")
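The prefix becomes part of each variable name, so the latent parameters appear under keys such as "va.σ_AR", as in the rand(stochastic_mdl) output further down. The sketch below builds those keys the same way this vignette does when indexing the posterior chain. The helper prefixed_name and the variables chain_keys and noise_keys are hypothetical names for illustration.

```julia
# Hypothetical helper mirroring how this vignette builds prefixed chain keys,
# e.g. Symbol(mdl_prefix * ".σ_AR").
prefixed_name(prefix::AbstractString, name::AbstractString) = Symbol(prefix * "." * name)

mdl_prefix = "va"

# Keys for the AR(1) structure parameters ...
chain_keys = [prefixed_name(mdl_prefix, n) for n in ("σ_AR", "ar_init[1]", "damp_AR[1]")]
# ... and for the 13 noise terms driving the AR(1) process.
noise_keys = [prefixed_name(mdl_prefix, "ϵ_t[$i]") for i in 1:13]
```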

Now we can declare the full model in the Turing PPL.

@model function stochastic_ode_mdl(y_t, ts, obs, prob, N;
        solver = AutoTsit5(Rosenbrock23())
)

    ##Priors##
    β ~ LogNormal(0.0, 1.0)
    γ ~ Gamma(0.004, 1 / 0.002)
    S₀ ~ Beta(0.5, 0.5)

    ##Remake ODE model##
    _prob = remake(prob;
        u0 = [S₀, 1 - S₀, 0.0],
        p = [β, γ]
    )

    ##Solve ODE model##
    sol = solve(_prob, solver;
        saveat = ts,
        verbose = false
    )
    λt = log1pexp.(N * sol[2, :])

    ##Observation##
    @submodel generated_y_t = generate_observations(obs, y_t, λt)

    ##Generated quantities##
    return (; sol, generated_y_t, R0 = β / γ)
end
stochastic_ode_mdl (generic function with 2 methods)
stochastic_mdl = stochastic_ode_mdl(
    data.in_bed,
    data.ts,
    varying_ascertainment,
    sir_prob,
    N
)
stochastic_uncond_mdl = stochastic_ode_mdl(
    fill(missing, length(data.in_bed)),
    data.ts,
    varying_ascertainment,
    sir_prob,
    N
)

Prior predictive checking

let
    prior_chn = sample(stochastic_uncond_mdl, Prior(), 2000)
    gens = generated_quantities(stochastic_uncond_mdl, prior_chn)
    plot_predYt(data, gens;
        title = "Prior predictive: stochastic model",
        ylabel = "Number of Infected students"
    )
end

The prior predictive check again shows misaligned prior beliefs. For example, without data we would not expect the median predicted number of ill children to be about 600 out of 763 after one day.

Without priors, the latent process for the log-residuals \(\kappa_t\) is not well determined, so maximum likelihood is not a sensible starting point here. Instead, we look for a reasonable maximum a posteriori (MAP) point to start NUTS from. We do this by first making an initial guess which is a mixture of:

  1. The posterior averages from the deterministic model.

  2. The prior averages of the structure parameters of the AR(1) process.

  3. Zero for the time-varying noise underlying the AR(1) process.

rand(stochastic_mdl)
(β = 1.4733099145592605, γ = 2.903750758256854e-123, S₀ = 0.29861836011258897, var"va.σ_AR" = 0.04830504386163741, var"va.ar_init" = [-0.00024863122657975786], var"va.damp_AR" = [0.0032571734979405884], var"va.ϵ_t" = [0.3072398792156006, -1.183649965567883, 2.771050948892893, -0.6366422192999562, 1.6191332959597484, 0.24589190588482895, 1.4615005554123257, 0.025353011915720307, 0.16407599045634794, 0.2628599221133207, -1.0048884450877293, 1.96700665270484, -0.7501415436101209])
initial_guess = [[mean(chn[:β]),
                     mean(chn[:γ]),
                     mean(chn[:S₀]),
                     mean(ar.std_prior),
                     mean(ar.init_prior)[1],
                     mean(ar.damp_prior)[1]]
                 zeros(13)]
19-element Vector{Float64}:
 1.8942148283773665
 0.48062141906187955
 0.9995061985155343
 0.0184303247003225
 0.0
 0.004725237126863895
 0.0
 ⋮
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

Starting from the initial guess, the MAP point is calculated rapidly in one pass.

map_fit_stoch_mdl = maximum_a_posteriori(stochastic_mdl;
    adtype = AutoReverseDiff(),
    initial_params = initial_guess
)
ModeResult with maximized lp of -69.56
[1.9168299393361845, 0.48970414611353336, 0.9995563465737589, 0.06675569399815419, 1.3740571956377935e-6, 0.0001575540787629065, 0.14269438141615298, 0.17055297942269643, -0.2985981800136858, 0.6377161610468507, -0.008381874345876414, -0.5911576922746032, 0.7987402092625825, 1.7391572458510847, 1.438270023561078, 0.2451580311597817, -0.6799723083106257, -0.7437116269695498, -0.8064297334547238]

Now we can run NUTS, sampling 1000 posterior draws per chain for 4 chains.

chn2 = sample(
    stochastic_mdl,
    NUTS(; adtype = AutoReverseDiff(true)),
    MCMCThreads(), 1000, 4;
    initial_params = fill(map_fit_stoch_mdl.values.array, 4)
)
iteration  chain  β        γ         S₀        va.σ_AR    va.ar_init[1]  va.damp_AR[1]  ...
501        1      1.83773  0.471008  0.999432  0.0492726  0.000776343    0.0106267      ...
502        1      1.90809  0.493836  0.999513  0.0416465  -0.000522202   0.00238174     ...
503        1      1.83978  0.484794  0.9994    0.0329526  0.00010738     0.00806697     ...
504        1      1.83873  0.482094  0.999383  0.028583   -0.000843394   0.0110785      ...
505        1      1.86047  0.481308  0.999422  0.0841     -0.000560455   0.00169988     ...
506        1      1.88082  0.504821  0.999495  0.0357006  -0.000644581   0.0140381      ...
507        1      1.85906  0.492325  0.999433  0.0699084  -0.00108052    0.000536058    ...
508        1      1.9207   0.487275  0.999568  0.0541636  -0.00114944    0.0120675      ...
509        1      2.02692  0.507493  0.999699  0.0512195  -0.000289786   0.00205155     ...
510        1      1.93871  0.50499   0.999638  0.0592997  -0.000668149   0.0108368      ...
...
describe(chn2)
2-element Vector{ChainDataFrame}:
 Summary Statistics (19 x 8)
 Quantiles (19 x 6)
pairplot(chn2[[:β, :γ, :S₀, Symbol(mdl_prefix * ".σ_AR"),
    Symbol(mdl_prefix * ".ar_init[1]"), Symbol(mdl_prefix * ".damp_AR[1]")]])
let
    vars = mapreduce(vcat, 1:13) do i
        Symbol(mdl_prefix * ".ϵ_t[$i]")
    end
    pairplot(chn2[vars])
end
let
    gens = generated_quantities(stochastic_uncond_mdl, chn2)
    plot_predYt(data, gens;
        title = "Fitted stochastic model",
        ylabel = "Number of Infected students"
    )
end