Fitting distributions using EpiAware and Turing PPL
Introduction
What are we going to do in this Vignette
In this vignette, we'll demonstrate how to use the CDF function for censored delay distributions, EpiAwareUtils.∫F, which underlies EpiAwareUtils.censored_pmf, in conjunction with the Turing PPL for Bayesian inference of epidemiological delay distributions. We'll cover the following key points:
Simulating censored delay distribution data
Fitting a naive model using Turing
Evaluating the naive model's performance
Fitting an improved model using censored delay functionality from EpiAware
Comparing the censored delay model's performance to the naive model
What might I need to know before starting
This note builds on the concepts introduced in the R/stan package primarycensoreddist, especially the Fitting distributions using primarycensoreddist and cmdstan vignette, and assumes familiarity with using Turing tools as covered in the Turing documentation.
This note is generated using the EpiAware package locally via Pkg.develop, in the EpiAware/docs environment. It is also possible to install EpiAware using

Pkg.add(url="https://github.com/CDCgov/Rt-without-renewal", subdir="EpiAware")
Packages used in this vignette
As well as EpiAware and Turing, we will use Makie ecosystem packages for plotting and DataFramesMeta for data manipulation.
let
    docs_dir = dirname(dirname(dirname(@__DIR__)))
    using Pkg: Pkg
    Pkg.activate(docs_dir)
    Pkg.instantiate()
end
The other dependencies are as follows:
begin
    using EpiAware.EpiAwareUtils: censored_pmf, censored_cdf, ∫F
    using Random, Distributions, StatsBase # utilities for random events
    using DataFramesMeta # data wrangling
    using CairoMakie, PairPlots # plotting
    using Turing # PPL
end
Simulating censored and truncated delay distribution data
We'll start by simulating some censored and truncated delay distribution data. We'll define a function, rpcens, for generating the data.
Random.seed!(123) # For reproducibility
TaskLocalRNG()
Define the true distribution parameters
n = 2000
2000
meanlog = 1.5
1.5
sdlog = 0.75
0.75
true_dist = LogNormal(meanlog, sdlog)
Distributions.LogNormal{Float64}(μ=1.5, σ=0.75)
Generate varying pwindow, swindow, and obs_time lengths
pwindows = rand(1:2, n)
2000-element Vector{Int64}: 2 2 2 1 2 1 1 ⋮ 1 2 1 2 2 2
swindows = rand(1:2, n)
2000-element Vector{Int64}: 1 2 2 1 2 1 1 ⋮ 2 2 2 1 1 2
obs_times = rand(8:10, n)
2000-element Vector{Int64}: 10 9 9 10 9 8 8 ⋮ 8 9 9 10 8 8
We recreate the primary censored sampling function from primarycensoreddist; cf. the documentation here.
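Concretely, rpcens draws a delay X from dist and a primary event time U uniformly on [0, pwindow], accepts T = X + U only if it occurs before the observation time D (right truncation), and floors the accepted value to the secondary censoring grid:

$$T = X + U, \quad X \sim \text{dist}, \quad U \sim \text{Uniform}(0, \text{pwindow}), \qquad T_{\text{obs}} = \left\lfloor T / \text{swindow} \right\rfloor \cdot \text{swindow} \quad \text{given } T < D.$$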
"""
function rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000)
Does a truncated censored sample from `dist` with a uniform primary time on `[0, pwindow]`.
"""
function rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000)
T = zero(eltype(dist))
invalid_sample = true
attempts = 1
while (invalid_sample && attempts <= max_tries)
X = rand(dist)
U = rand() * pwindow
T = X + U
attempts += 1
if X + U < D
invalid_sample = false
end
end
@assert !invalid_sample "censored value not found in $max_tries attempts"
return (T ÷ swindow) * swindow
end
#Sample secondary time relative to beginning of primary censor window respecting the right-truncation
samples = map(pwindows, swindows, obs_times) do pw, sw, ot
    rpcens(true_dist; pwindow = pw, swindow = sw, D = ot)
end
2000-element Vector{Float64}: 4.0 2.0 2.0 2.0 4.0 3.0 6.0 ⋮ 4.0 6.0 2.0 6.0 4.0 4.0
Aggregate to unique combinations and count occurrences
delay_counts = mapreduce(vcat, pwindows, swindows, obs_times, samples) do pw, sw, ot, s
    DataFrame(
        pwindow = pw,
        swindow = sw,
        obs_time = ot,
        observed_delay = s,
        observed_delay_upper = s + sw
    )
end |>
    df -> @groupby(df, :pwindow, :swindow, :obs_time, :observed_delay,
        :observed_delay_upper) |>
    gd -> @combine(gd, :n = length(:pwindow))
| | pwindow | swindow | obs_time | observed_delay | observed_delay_upper | n |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 1 | 1 | 8 | 0.0 | 1.0 | 1 |
| 2 | 1 | 1 | 8 | 1.0 | 2.0 | 13 |
| 3 | 1 | 1 | 8 | 2.0 | 3.0 | 32 |
| 4 | 1 | 1 | 8 | 3.0 | 4.0 | 29 |
| 5 | 1 | 1 | 8 | 4.0 | 5.0 | 34 |
| 6 | 1 | 1 | 8 | 5.0 | 6.0 | 26 |
| 7 | 1 | 1 | 8 | 6.0 | 7.0 | 19 |
| 8 | 1 | 1 | 8 | 7.0 | 8.0 | 14 |
| 9 | 1 | 1 | 9 | 0.0 | 1.0 | 2 |
| 10 | 1 | 1 | 9 | 1.0 | 2.0 | 5 |
| ... | ... | ... | ... | ... | ... | ... |
| 80 | 2 | 2 | 10 | 8.0 | 10.0 | 22 |
Compare the samples with and without secondary censoring to the true distribution and calculate empirical CDF
empirical_cdf = ecdf(samples)
ECDF{Vector{Float64}, Weights{Float64, Float64, Vector{Float64}}}([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0], Float64[])
empirical_cdf_obs = ecdf(delay_counts.observed_delay, weights = delay_counts.n)
ECDF{Vector{Float64}, Weights{Int64, Int64, Vector{Int64}}}([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0], [1, 2, 2, 13, 16, 21, 1, 13, 13, 9 … 9, 10, 9, 13, 17, 15, 12, 22, 8, 7])
x_seq = range(minimum(samples), maximum(samples), 100)
0.0:0.09090909090909091:9.0
theoretical_cdf = x_seq |> x -> cdf(true_dist, x)
100-element Vector{Float64}: 0.0 1.011597608751049e-7 9.643132895117507e-6 9.484054524759167e-5 0.0004058100212574347 0.0011393531997368723 0.0024911102275376566 ⋮ 0.8052522612515658 0.8091156793117527 0.8128920005554523 0.8165833494282897 0.8201917991805499 0.8237193727611859
let
    f = Figure()
    ax = Axis(f[1, 1],
        title = "Comparison of Observed vs Theoretical CDF",
        ylabel = "Cumulative Probability",
        xlabel = "Delay"
    )
    lines!(ax, x_seq, empirical_cdf_obs.(x_seq), label = "Empirical CDF",
        color = :blue, linewidth = 2)
    lines!(ax, x_seq, theoretical_cdf, label = "Theoretical CDF",
        color = :black, linewidth = 2)
    vlines!(ax, [mean(samples)], color = :blue, linestyle = :dash,
        label = "Empirical mean", linewidth = 2)
    vlines!(ax, [mean(true_dist)], linestyle = :dash,
        label = "Theoretical mean", color = :black, linewidth = 2)
    axislegend(position = :rb)
    f
end
We've aggregated the data to unique combinations of pwindow, swindow, and obs_time and counted the number of occurrences of each observed_delay for each combination. This is the data we will use to fit our model.
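Note that each row of the aggregated data carries a count n of identical observations, so both models below weight the per-row log-density by that count; that is, they target the log-likelihood

$$\ell(\mu, \sigma) = \sum_{i} n_i \log f_{\mu, \sigma}(y_i),$$

where f is the (naive or censored) observation density for row i. This is what the n[i] * factor inside Turing.@addlogprob! implements in the models below.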
Fitting a naive model using Turing
We'll start by fitting a naive model using NUTS from Turing. We define the model in the Turing PPL.
@model function naive_model(N, y, n)
    mu ~ Normal(1.0, 1.0)
    sigma ~ truncated(Normal(0.5, 1.0); lower = 0.0)
    d = LogNormal(mu, sigma)
    for i in eachindex(y)
        # weight the log-density by the count of identical observations
        Turing.@addlogprob! n[i] * logpdf(d, y[i])
    end
end
naive_model (generic function with 2 methods)
Now let's instantiate this model with data
naive_mdl = naive_model(
    size(delay_counts, 1),
    delay_counts.observed_delay .+ 1e-6, # add a small constant to avoid log(0)
    delay_counts.n)
DynamicPPL.Model{typeof(naive_model), (:N, :y, :n), (), (), Tuple{Int64, Vector{Float64}, Vector{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(naive_model, (N = 80, y = [1.0e-6, 1.000001, 2.000001, 3.000001, 4.000001, 5.000001, 6.000001, 7.000001, 1.0e-6, 1.000001 … 1.0e-6, 2.000001, 4.000001, 6.000001, 8.000001, 1.0e-6, 2.000001, 4.000001, 6.000001, 8.000001], n = [1, 13, 32, 29, 34, 26, 19, 14, 2, 5 … 13, 69, 59, 30, 12, 9, 69, 48, 29, 22]), NamedTuple(), DynamicPPL.DefaultContext())
and now let's fit the compiled model.
naive_fit = sample(naive_mdl, NUTS(), MCMCThreads(), 500, 4)
| | iteration | chain | mu | sigma | lp | n_steps | is_accept | acceptance_rate | ... |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 251 | 1 | 0.569828 | 3.16687 | -6326.42 | 3.0 | 1.0 | 0.889234 | |
| 2 | 252 | 1 | 0.51554 | 3.21602 | -6327.17 | 3.0 | 1.0 | 0.748919 | |
| 3 | 253 | 1 | 0.699124 | 3.15787 | -6327.72 | 7.0 | 1.0 | 0.895135 | |
| 4 | 254 | 1 | 0.523035 | 3.16482 | -6326.8 | 3.0 | 1.0 | 0.728358 | |
| 5 | 255 | 1 | 0.497583 | 3.17839 | -6327.16 | 3.0 | 1.0 | 0.840027 | |
| 6 | 256 | 1 | 0.548956 | 3.15474 | -6326.61 | 3.0 | 1.0 | 0.792753 | |
| 7 | 257 | 1 | 0.569335 | 3.16453 | -6326.43 | 3.0 | 1.0 | 1.0 | |
| 8 | 258 | 1 | 0.606234 | 3.2018 | -6326.53 | 3.0 | 1.0 | 0.977859 | |
| 9 | 259 | 1 | 0.530658 | 3.17006 | -6326.69 | 3.0 | 1.0 | 0.900204 | |
| 10 | 260 | 1 | 0.627117 | 3.11173 | -6327.41 | 7.0 | 1.0 | 0.761508 | |
| ... | | | | | | | | | |
summarize(naive_fit)
| | parameters | mean | std | mcse | ess_bulk | ess_tail | rhat | ess_per_sec |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | :mu | 0.584129 | 0.0704067 | 0.00151449 | 2153.02 | 1487.54 | 1.00003 | 318.212 |
| 2 | :sigma | 3.17766 | 0.0495775 | 0.00113542 | 1905.04 | 1306.94 | 1.00127 | 281.561 |
let
    f = pairplot(naive_fit)
    vlines!(f[1, 1], [meanlog], linewidth = 4)
    vlines!(f[2, 2], [sdlog], linewidth = 4)
    f
end
We see that the model has converged and the diagnostics look good. However, just from the model posterior summary we see that we might not be very happy with the fit: mu is smaller than the target 1.5 and sigma is larger than the target 0.75. This is unsurprising, as the naive model ignores both the interval censoring of the delays (in particular, delays recorded as 0 are treated as near-exact values of 1e-6) and the right truncation at the observation time.
Fitting an improved model using censoring utilities
We'll now fit an improved model using the ∫F function from EpiAware.EpiAwareUtils for calculating the CDF of the total delay from the beginning of the primary window to the secondary event time. This includes both the delay distribution we are making inference on and the time between the start of the primary censor window and the primary event. The ∫F function underlies the censored_pmf function from the EpiAware.EpiAwareUtils submodule.
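That is, writing F for the CDF of the delay distribution and assuming, as in the simulation above, a primary event time distributed uniformly over a primary window of length w,

$$\text{∫F}(\text{dist}, t, w) = \Pr(X + U \le t) = \frac{1}{w} \int_0^{w} F(t - u) \, \mathrm{d}u, \quad X \sim \text{dist}, \quad U \sim \text{Uniform}(0, w).$$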
Using the ∫F function we can write a log-pmf function, primary_censored_dist_lpmf, that accounts for:
The primary and secondary censoring windows, which can vary in length.
The effect of right truncation in biasing our observations.
This is the analogue of the function of the same name in primarycensoreddist: it calculates the log-probability of the secondary event occurring in the secondary censoring window, conditional on the primary event occurring in the primary censoring window, by calculating the increase in the CDF over the secondary window and rescaling by the probability of the secondary event occurring within the maximum observation time D.
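In this notation, for an observed delay interval [y, y_upper), primary window length w, and maximum observation time D, the function below computes

$$\log \Pr(y \le T < y_{\text{upper}} \mid T < D) = \log\left[\text{∫F}(y_{\text{upper}}, w) - \text{∫F}(y, w)\right] - \log \text{∫F}(D, w),$$

where the y = 0 branch simply uses ∫F(0, w) = 0.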
function primary_censored_dist_lpmf(dist, y, pwindow, y_upper, D)
    if y == 0.0
        # ∫F(dist, 0, pwindow) is zero, so the lower CDF term drops out
        return log(∫F(dist, y_upper, pwindow)) - log(∫F(dist, D, pwindow))
    else
        # increase in the censored CDF over the secondary window,
        # rescaled by the probability of observing the delay before D
        return log(∫F(dist, y_upper, pwindow) - ∫F(dist, y, pwindow)) -
               log(∫F(dist, D, pwindow))
    end
end
primary_censored_dist_lpmf (generic function with 1 method)
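As a quick, illustrative sanity check (not part of the original vignette), the probabilities implied by primary_censored_dist_lpmf should sum to one over all possible secondary windows up to D, since dividing by ∫F(dist, D, pwindow) normalises the right-truncated pmf. With unit secondary windows the sum telescopes:

let dist = LogNormal(1.5, 0.75), pwindow = 1, D = 8
    # log-probability of the delay landing in each one-day window [y, y + 1)
    lps = [primary_censored_dist_lpmf(dist, y, pwindow, y + 1.0, D)
           for y in 0.0:1.0:(D - 1.0)]
    sum(exp.(lps)) # ≈ 1.0 up to numerical error
end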
We make a new Turing model that now uses primary_censored_dist_lpmf rather than the naive uncensored and untruncated logpdf.
@model function primarycensoreddist_model(y, y_upper, n, pws, Ds)
    mu ~ Normal(1.0, 1.0)
    sigma ~ truncated(Normal(0.5, 0.5); lower = 0.0)
    dist = LogNormal(mu, sigma)
    for i in eachindex(y)
        # weight the censored log-pmf by the count of identical observations
        Turing.@addlogprob! n[i] * primary_censored_dist_lpmf(
            dist, y[i], pws[i], y_upper[i], Ds[i])
    end
end
primarycensoreddist_model (generic function with 2 methods)
Let's instantiate this model with data
primarycensoreddist_mdl = primarycensoreddist_model(
    delay_counts.observed_delay,
    delay_counts.observed_delay_upper,
    delay_counts.n,
    delay_counts.pwindow,
    delay_counts.obs_time
)
DynamicPPL.Model{typeof(primarycensoreddist_model), (:y, :y_upper, :n, :pws, :Ds), (), (), Tuple{Vector{Float64}, Vector{Float64}, Vector{Int64}, Vector{Int64}, Vector{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(primarycensoreddist_model, (y = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0 … 0.0, 2.0, 4.0, 6.0, 8.0, 0.0, 2.0, 4.0, 6.0, 8.0], y_upper = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0 … 2.0, 4.0, 6.0, 8.0, 10.0, 2.0, 4.0, 6.0, 8.0, 10.0], n = [1, 13, 32, 29, 34, 26, 19, 14, 2, 5 … 13, 69, 59, 30, 12, 9, 69, 48, 29, 22], pws = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1 … 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], Ds = [8, 8, 8, 8, 8, 8, 8, 8, 9, 9 … 9, 9, 9, 9, 9, 10, 10, 10, 10, 10]), NamedTuple(), DynamicPPL.DefaultContext())
Now let’s fit the compiled model.
primarycensoreddist_fit = sample(
    primarycensoreddist_mdl, NUTS(), MCMCThreads(), 1000, 4)
| | iteration | chain | mu | sigma | lp | n_steps | is_accept | acceptance_rate | ... |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 501 | 1 | 1.46819 | 0.771804 | -3376.41 | 3.0 | 1.0 | 1.0 | |
| 2 | 502 | 1 | 1.46877 | 0.738944 | -3375.1 | 3.0 | 1.0 | 0.999202 | |
| 3 | 503 | 1 | 1.49735 | 0.740651 | -3376.39 | 3.0 | 1.0 | 0.741842 | |
| 4 | 504 | 1 | 1.47618 | 0.762895 | -3375.61 | 3.0 | 1.0 | 0.983406 | |
| 5 | 505 | 1 | 1.48132 | 0.74067 | -3375.47 | 3.0 | 1.0 | 0.852127 | |
| 6 | 506 | 1 | 1.40746 | 0.711968 | -3375.63 | 7.0 | 1.0 | 0.914995 | |
| 7 | 507 | 1 | 1.44329 | 0.747661 | -3375.55 | 7.0 | 1.0 | 0.89498 | |
| 8 | 508 | 1 | 1.43568 | 0.734698 | -3375.17 | 3.0 | 1.0 | 0.977821 | |
| 9 | 509 | 1 | 1.42456 | 0.696408 | -3375.79 | 3.0 | 1.0 | 0.941795 | |
| 10 | 510 | 1 | 1.46966 | 0.758485 | -3375.46 | 5.0 | 1.0 | 1.0 | |
| ... | | | | | | | | | |
summarize(primarycensoreddist_fit)
| | parameters | mean | std | mcse | ess_bulk | ess_tail | rhat | ess_per_sec |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | :mu | 1.45168 | 0.0355514 | 0.00108636 | 1087.88 | 1420.31 | 1.00169 | 61.0142 |
| 2 | :sigma | 0.733008 | 0.0274925 | 0.00082729 | 1110.27 | 1643.86 | 1.00173 | 62.2695 |
let
    f = pairplot(primarycensoreddist_fit)
    CairoMakie.vlines!(f[1, 1], [meanlog], linewidth = 3)
    CairoMakie.vlines!(f[2, 2], [sdlog], linewidth = 3)
    f
end
We see that the model has converged and the diagnostics look good. We also see that the posterior means are very near the true parameters and the 90% credible intervals include the true parameters.
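As a final, illustrative comparison (a sketch, not part of the original vignette), we can contrast the mean delay implied by each posterior, which for a LogNormal is exp(mu + sigma^2 / 2), with the true mean. The naive posterior implies a far longer mean delay than the censored-model posterior, which sits close to the truth:

# posterior mean of the implied mean delay under each fit, versus the truth
let
    implied_mean(fit) = mean(exp.(vec(fit[:mu]) .+ vec(fit[:sigma]) .^ 2 ./ 2))
    (naive = implied_mean(naive_fit),
        censored = implied_mean(primarycensoreddist_fit),
        truth = mean(true_dist))
end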