Fitting distributions using EpiAware and Turing PPL

Introduction

What are we going to do in this vignette

In this vignette, we'll demonstrate how to use EpiAwareUtils.∫F together with the Turing PPL for Bayesian inference of epidemiological delay distributions. ∫F is the CDF function for censored delay distributions, and it underlies EpiAwareUtils.censored_pmf. We'll cover the following key points:

  1. Simulating censored delay distribution data

  2. Fitting a naive model using Turing

  3. Evaluating the naive model's performance

  4. Fitting an improved model using censored delay functionality from EpiAware

  5. Comparing the censored delay model's performance to the naive model

What might I need to know before starting

This note builds on the concepts introduced in the R/Stan package primarycensoreddist, especially its "Fitting distributions using primarycensoreddist and cmdstan" vignette. It also assumes familiarity with the Turing tools covered in the Turing documentation.

This note is generated using the EpiAware package locally via Pkg.develop, in the EpiAware/docs environment. It is also possible to install EpiAware using

Pkg.add(url="https://github.com/CDCgov/Rt-without-renewal", subdir="EpiAware")

Packages used in this vignette

As well as EpiAware and Turing we will use Makie ecosystem packages for plotting and DataFramesMeta for data manipulation.

begin
    # `using` must appear at top level, so use `begin` rather than `let` here
    using Pkg: Pkg
    docs_dir = dirname(dirname(dirname(@__DIR__)))
    Pkg.activate(docs_dir)
    Pkg.instantiate()
end

The other dependencies are as follows:

begin
    using EpiAware.EpiAwareUtils: censored_pmf, censored_cdf, ∫F
    using Random, Distributions, StatsBase #utilities for random events
    using DataFramesMeta #Data wrangling
    using CairoMakie, PairPlots #plotting
    using Turing #PPL
end

Simulating censored and truncated delay distribution data

We'll start by simulating some censored and truncated delay distribution data. To generate the data, we'll define an rpcens function.

Random.seed!(123) # For reproducibility
TaskLocalRNG()

Define the true distribution parameters

n = 2000
2000
meanlog = 1.5
1.5
sdlog = 0.75
0.75
true_dist = LogNormal(meanlog, sdlog)
Distributions.LogNormal{Float64}(μ=1.5, σ=0.75)

Generate varying pwindow, swindow, and obs_time lengths

pwindows = rand(1:2, n)
2000-element Vector{Int64}:
 2
 2
 2
 1
 2
 1
 1
 ⋮
 1
 2
 1
 2
 2
 2
swindows = rand(1:2, n)
2000-element Vector{Int64}:
 1
 2
 2
 1
 2
 1
 1
 ⋮
 2
 2
 2
 1
 1
 2
obs_times = rand(8:10, n)
2000-element Vector{Int64}:
 10
  9
  9
 10
  9
  8
  8
  ⋮
  8
  9
  9
 10
  8
  8

We recreate the primary censored sampling function rpcens from primarycensoreddist (cf. the primarycensoreddist documentation).

"""
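    rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000)

Draw a primary-censored, right-truncated sample from `dist`. The primary event time is
uniform on `[0, pwindow)`, and total times `T ≥ D` are rejected and redrawn. The accepted
`T` is then floored to the start of its secondary censoring window of width `swindow`.
"""
function rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000)
    T = zero(eltype(dist))
    invalid_sample = true
    attempts = 1
    # Rejection sampling: redraw until the total time falls before the truncation time D
    while (invalid_sample && attempts <= max_tries)
        X = rand(dist) # delay from primary event to secondary event
        U = rand() * pwindow # primary event time within its censoring window
        T = X + U # secondary event time relative to the start of the primary window
        attempts += 1
        if T < D
            invalid_sample = false
        end
    end

    @assert !invalid_sample "censored value not found in $max_tries attempts"

    # Secondary censoring: floor T to the start of its secondary window
    return (T ÷ swindow) * swindow
end
# Sample the secondary event time relative to the start of the primary censoring window, respecting right truncation at obs_time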
samples = map(pwindows, swindows, obs_times) do pw, sw, ot
    rpcens(true_dist; pwindow = pw, swindow = sw, D = ot)
end
2000-element Vector{Float64}:
 4.0
 2.0
 2.0
 2.0
 4.0
 3.0
 6.0
 ⋮
 4.0
 6.0
 2.0
 6.0
 4.0
 4.0

Aggregate to unique combinations and count occurrences

delay_counts = mapreduce(vcat, pwindows, swindows, obs_times, samples) do pw, sw, ot, s
    DataFrame(
        pwindow = pw,
        swindow = sw,
        obs_time = ot,
        observed_delay = s,
        observed_delay_upper = s + sw
    )
end |>
    df -> @groupby(df, :pwindow, :swindow, :obs_time, :observed_delay,
        :observed_delay_upper) |>
    gd -> @combine(gd, :n = length(:pwindow))
 row  pwindow  swindow  obs_time  observed_delay  observed_delay_upper   n
   1        1        1         8             0.0                   1.0   1
   2        1        1         8             1.0                   2.0  13
   3        1        1         8             2.0                   3.0  32
   4        1        1         8             3.0                   4.0  29
   5        1        1         8             4.0                   5.0  34
   6        1        1         8             5.0                   6.0  26
   7        1        1         8             6.0                   7.0  19
   8        1        1         8             7.0                   8.0  14
   9        1        1         9             0.0                   1.0   2
  10        1        1         9             1.0                   2.0   5
 ...
  80        2        2        10             8.0                  10.0  22

Compute empirical CDFs of the simulated samples, both directly and from the aggregated counts, and compare the observed delays with the CDF of the true distribution

empirical_cdf = ecdf(samples)
ECDF{Vector{Float64}, Weights{Float64, Float64, Vector{Float64}}}([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0], Float64[])
empirical_cdf_obs = ecdf(delay_counts.observed_delay, weights = delay_counts.n)
ECDF{Vector{Float64}, Weights{Int64, Int64, Vector{Int64}}}([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0], [1, 2, 2, 13, 16, 21, 1, 13, 13, 9  …  9, 10, 9, 13, 17, 15, 12, 22, 8, 7])
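The weighted ECDF above rests on one fact: weighting each aggregated row by its count n reproduces the ECDF of the individual samples. The Base-Julia sketch below illustrates this with made-up toy data. The helpers weighted_ecdf_at and raw_ecdf_at are hypothetical and are not part of StatsBase, so this is an illustration of the idea rather than the StatsBase implementation.

```julia
# Hypothetical helper (not StatsBase): count-weighted empirical CDF at `x`,
# i.e. the share of the total count sitting on values <= x.
function weighted_ecdf_at(values::AbstractVector{<:Real},
        counts::AbstractVector{<:Integer}, x::Real)
    below = 0
    for (v, c) in zip(values, counts)
        v <= x && (below += c)
    end
    return below / sum(counts)
end

# Plain ECDF of individual (unaggregated) observations, for comparison.
raw_ecdf_at(obs::AbstractVector{<:Real}, x::Real) = count(<=(x), obs) / length(obs)

# Made-up aggregated data: unique delays and their counts, plus the expanded form.
ecdf_toy_values = [0.0, 1.0, 2.0, 4.0]
ecdf_toy_counts = [2, 5, 3, 1]
ecdf_toy_samples = reduce(vcat, [fill(v, c) for (v, c) in zip(ecdf_toy_values, ecdf_toy_counts)])

for x in (0.0, 1.5, 2.0, 3.0, 4.0)
    println(x, " => ", weighted_ecdf_at(ecdf_toy_values, ecdf_toy_counts, x),
        " vs ", raw_ecdf_at(ecdf_toy_samples, x))
end
```

Both columns agree at every evaluation point, which is why the aggregated delay_counts can stand in for the raw samples.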
x_seq = range(minimum(samples), maximum(samples), 100)
0.0:0.09090909090909091:9.0
theoretical_cdf = x_seq |> x -> cdf(true_dist, x)
100-element Vector{Float64}:
 0.0
 1.011597608751049e-7
 9.643132895117507e-6
 9.484054524759167e-5
 0.0004058100212574347
 0.0011393531997368723
 0.0024911102275376566
 ⋮
 0.8052522612515658
 0.8091156793117527
 0.8128920005554523
 0.8165833494282897
 0.8201917991805499
 0.8237193727611859
let
    f = Figure()
    ax = Axis(f[1, 1],
        title = "Comparison of Observed vs Theoretical CDF",
        ylabel = "Cumulative Probability",
        xlabel = "Delay"
    )
    lines!(
        ax, x_seq, empirical_cdf_obs, label = "Empirical CDF", color = :blue, linewidth = 2)
    lines!(ax, x_seq, theoretical_cdf, label = "Theoretical CDF",
        color = :black, linewidth = 2)
    vlines!(ax, [mean(samples)], color = :blue, linestyle = :dash,
        label = "Empirical mean", linewidth = 2)
    vlines!(ax, [mean(true_dist)], linestyle = :dash,
        label = "Theoretical mean", color = :black, linewidth = 2)
    axislegend(position = :rb)

    f
end

We've aggregated the data to unique combinations of pwindow, swindow, and obs_time and counted the number of occurrences of each observed_delay for each combination. This is the data we will use to fit our model.
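For readers less familiar with DataFramesMeta, the same aggregation can be sketched in Base Julia with a Dict keyed on (pwindow, swindow, obs_time, observed_delay) tuples. The function aggregate_delays and its toy inputs are hypothetical. As in the @groupby/@combine step, observed_delay_upper is derived as observed_delay + swindow.

```julia
# Hypothetical helper: count each unique (pwindow, swindow, obs_time, observed_delay)
# combination and return one NamedTuple row per combination, sorted by key.
function aggregate_delays(pwindows, swindows, obs_times, delays)
    counts = Dict{Tuple{Int, Int, Int, Float64}, Int}()
    for key in zip(pwindows, swindows, obs_times, delays)
        counts[key] = get(counts, key, 0) + 1
    end
    return [(pwindow = pw, swindow = sw, obs_time = ot, observed_delay = d,
                observed_delay_upper = d + sw, n = n)
            for ((pw, sw, ot, d), n) in sort!(collect(counts); by = first)]
end

agg_rows = aggregate_delays([1, 1, 2, 1], [1, 1, 2, 1], [8, 8, 10, 8], [2.0, 2.0, 4.0, 3.0])
foreach(println, agg_rows)
```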

Fitting a naive model using Turing

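We'll start by fitting a naive model using NUTS from Turing. This model ignores both censoring and truncation. It treats each observed (floored) delay as an exact draw from a lognormal distribution and weights each row's log-density by its count n. We define the model in the Turing PPL.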

@model function naive_model(N, y, n)
    mu ~ Normal(1.0, 1.0)
    sigma ~ truncated(Normal(0.5, 1.0); lower = 0.0)
    d = LogNormal(mu, sigma)

    for i in eachindex(y)
        Turing.@addlogprob! n[i] * logpdf(d, y[i])
    end
end
naive_model (generic function with 2 methods)
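Each row of delay_counts stands for n[i] identical observations, so the naive model multiplies each log-density by n[i] via Turing.@addlogprob!. The Base-Julia sketch below checks that this count-weighted sum equals the log-likelihood of the expanded, unaggregated data. Two parts are assumptions: the hand-written lognormal_logpdf stands in for Distributions' logpdf(LogNormal(mu, sigma), y), and the toy data are made up.

```julia
# Stand-in for logpdf(LogNormal(mu, sigma), y), written out so the sketch
# runs with Base Julia only.
lognormal_logpdf(y, mu, sigma) =
    -log(y * sigma * sqrt(2π)) - (log(y) - mu)^2 / (2 * sigma^2)

# Count-weighted log-likelihood over aggregated rows, as accumulated by the
# naive model's Turing.@addlogprob! n[i] * logpdf(d, y[i]).
weighted_loglik(y, n, mu, sigma) =
    sum(n[i] * lognormal_logpdf(y[i], mu, sigma) for i in eachindex(y, n))

# Log-likelihood over the expanded, one-entry-per-observation data.
expanded_loglik(obs, mu, sigma) = sum(lognormal_logpdf(x, mu, sigma) for x in obs)

# Made-up aggregated data: unique delays and how often each was observed.
ll_y = [1.0, 2.0, 3.5]
ll_n = [4, 2, 1]
ll_obs = reduce(vcat, [fill(yi, ni) for (yi, ni) in zip(ll_y, ll_n)])

println(weighted_loglik(ll_y, ll_n, 1.5, 0.75), " vs ", expanded_loglik(ll_obs, 1.5, 0.75))
```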

Now let's instantiate this model with the data.

naive_mdl = naive_model(
    size(delay_counts, 1),
    delay_counts.observed_delay .+ 1e-6, # Add a small constant to avoid log(0)
    delay_counts.n)
DynamicPPL.Model{typeof(naive_model), (:N, :y, :n), (), (), Tuple{Int64, Vector{Float64}, Vector{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(naive_model, (N = 80, y = [1.0e-6, 1.000001, 2.000001, 3.000001, 4.000001, 5.000001, 6.000001, 7.000001, 1.0e-6, 1.000001  …  1.0e-6, 2.000001, 4.000001, 6.000001, 8.000001, 1.0e-6, 2.000001, 4.000001, 6.000001, 8.000001], n = [1, 13, 32, 29, 34, 26, 19, 14, 2, 5  …  13, 69, 59, 30, 12, 9, 69, 48, 29, 22]), NamedTuple(), DynamicPPL.DefaultContext())

Next, let's fit the compiled model.

naive_fit = sample(naive_mdl, NUTS(), MCMCThreads(), 500, 4)
      iteration  chain        mu     sigma        lp  n_steps  is_accept  acceptance_rate  ...
   1        251      1  0.569828   3.16687  -6326.42      3.0        1.0         0.889234
   2        252      1  0.51554    3.21602  -6327.17      3.0        1.0         0.748919
   3        253      1  0.699124   3.15787  -6327.72      7.0        1.0         0.895135
   4        254      1  0.523035   3.16482  -6326.8       3.0        1.0         0.728358
   5        255      1  0.497583   3.17839  -6327.16      3.0        1.0         0.840027
   6        256      1  0.548956   3.15474  -6326.61      3.0        1.0         0.792753
   7        257      1  0.569335   3.16453  -6326.43      3.0        1.0         1.0
   8        258      1  0.606234   3.2018   -6326.53      3.0        1.0         0.977859
   9        259      1  0.530658   3.17006  -6326.69      3.0        1.0         0.900204
  10        260      1  0.627117   3.11173  -6327.41      7.0        1.0         0.761508
 ...
summarize(naive_fit)
    parameters  mean      std        mcse        ess_bulk  ess_tail  rhat     ess_per_sec
 1  :mu         0.584129  0.0704067  0.00151449  2153.02   1487.54   1.00003  295.42
 2  :sigma      3.17766   0.0495775  0.00113542  1905.04   1306.94   1.00127  261.394
let
    f = pairplot(naive_fit)
    vlines!(f[1, 1], [meanlog], linewidth = 4)
    vlines!(f[2, 2], [sdlog], linewidth = 4)
    f
end

We see that the model has converged and the diagnostics look good: R-hat ≈ 1.00, and bulk ESS is above 1900 for both parameters. However, the posterior summary alone shows that the fit is poor. The posterior mean of mu (≈ 0.58) is well below the true value of 1.5, and the posterior mean of sigma (≈ 3.18) is well above the true value of 0.75.
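One plausible contributor to this bias is the 1e-6 offset used for zero delays. This is a hypothesis, not something the fit above diagnoses, and treating floored delays as exact values is another likely factor. Under the true parameters, a delay of 1e-6 is extremely unlikely, so a likelihood containing many such values favours a much wider distribution. The Base-Julia sketch below evaluates the lognormal log-density at 1e-6 under the true parameters and under roughly the naive posterior means (mu ≈ 0.58, sigma ≈ 3.18). lognormal_lp is a hand-written stand-in for Distributions' logpdf(LogNormal(mu, sigma), y).

```julia
# Hand-written stand-in for logpdf(LogNormal(mu, sigma), y).
lognormal_lp(y, mu, sigma) =
    -log(y * sigma * sqrt(2π)) - (log(y) - mu)^2 / (2 * sigma^2)

y_offset = 1e-6  # the value the naive model uses in place of a zero delay
lp_true = lognormal_lp(y_offset, 1.5, 0.75)    # true parameters
lp_naive = lognormal_lp(y_offset, 0.58, 3.18)  # approx. naive posterior means
println("log-density at 1e-6, true: ", lp_true, "  naive: ", lp_naive)
```

The gap is large for every zero-delay observation, so each one pushes the naive posterior towards a smaller mu and a larger sigma.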

Fitting an improved model using censoring utilities

We'll now fit an improved model using the ∫F function from EpiAware.EpiAwareUtils. ∫F calculates the CDF of the total delay from the beginning of the primary censoring window to the secondary event time. This total delay combines two parts: the delay distribution we are making inference on, and the time between the start of the primary censoring window and the primary event. The ∫F function underlies the censored_pmf function from the EpiAware.EpiAwareUtils submodule.
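Written out, the quantity ∫F computes is the CDF of T = X + U. This assumes, in line with the description above, that ∫F(dist, t, pwindow) returns that CDF, where X ~ dist is the delay and U is the primary event time, uniform on a window of width w_P = pwindow:

```latex
F_T(t) \;=\; P(X + U \le t) \;=\; \frac{1}{w_P} \int_0^{w_P} F_X(t - u)\,\mathrm{d}u,
\qquad U \sim \mathrm{Uniform}(0, w_P),
```

Here F_X is the CDF of dist, with F_X(x) = 0 for x < 0. Averaging the delay CDF over all possible primary event times in the window is what accounts for primary censoring.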

Using the ∫F function we can write a log-pmf function primary_censored_dist_lpmf that accounts for:

  • The primary and secondary censoring windows, which can vary in length.

  • The effect of right truncation in biasing our observations.

This is the analogue of the function of the same name in primarycensoreddist. It calculates the log-probability of the secondary event occurring in the secondary censoring window, conditional on the primary event occurring in the primary censoring window. It does so in two steps: it takes the increase in the CDF over the secondary window, then rescales by the probability of the secondary event occurring within the maximum observation time D.

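function primary_censored_dist_lpmf(dist, y, pwindow, y_upper, D)
    # Log-probability of being observed at all (right truncation at D)
    log_norm = log(∫F(dist, D, pwindow))
    if y == 0.0
        # The lower CDF term vanishes at y = 0, so only the upper bound contributes
        return log(∫F(dist, y_upper, pwindow)) - log_norm
    else
        # Increase in the primary-censored CDF over the secondary window [y, y_upper)
        return log(∫F(dist, y_upper, pwindow) - ∫F(dist, y, pwindow)) - log_norm
    end
end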
primary_censored_dist_lpmf (generic function with 1 method)
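To make the calculation concrete, here is a self-contained Base-Julia sketch of the same log-pmf. It makes two assumptions that are not in the original. First, the delay is exponential, so its CDF has a closed form without Distributions.jl. Second, the primary-censored CDF is approximated with a midpoint rule rather than computed with EpiAware's ∫F. The helper names exp_cdf, pc_cdf and pc_lpmf are hypothetical. In this sketch pc_cdf at 0 is exactly 0, so no y == 0 branch is needed.

```julia
# Assumed delay distribution for this sketch: exponential with mean θ.
exp_cdf(x; θ = 3.0) = x <= 0 ? 0.0 : 1 - exp(-x / θ)

# Midpoint-rule approximation of the primary-censored CDF
#   F_T(t) = (1 / pwindow) ∫_0^pwindow F_X(t - u) du,
# i.e. the CDF of X + U with U ~ Uniform(0, pwindow).
function pc_cdf(F, t, pwindow; nsteps = 10_000)
    h = pwindow / nsteps
    return sum(F(t - (k - 0.5) * h) for k in 1:nsteps) * h / pwindow
end

# log P(y <= T < y_upper | T < D): increase in the CDF over the secondary
# window, renormalised by the probability of being observed before D.
function pc_lpmf(F, y, pwindow, y_upper, D)
    return log(pc_cdf(F, y_upper, pwindow) - pc_cdf(F, y, pwindow)) -
           log(pc_cdf(F, D, pwindow))
end

# With swindow = 1 and D = 8, the pmf over windows [0,1), ..., [7,8) sums to 1.
pmf_total = sum(exp(pc_lpmf(exp_cdf, y, 1.0, y + 1.0, 8.0)) for y in 0.0:1.0:7.0)
println("sum of pmf over observable windows: ", pmf_total)
```

The sum telescopes to F_T(D) / F_T(D), which is a quick sanity check that the truncation normalisation is doing its job.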

We make a new Turing model that now uses primary_censored_dist_lpmf rather than the naive uncensored and untruncated logpdf.

@model function primarycensoreddist_model(y, y_upper, n, pws, Ds)
    mu ~ Normal(1.0, 1.0)
    sigma ~ truncated(Normal(0.5, 0.5); lower = 0.0)
    dist = LogNormal(mu, sigma)

    for i in eachindex(y)
        Turing.@addlogprob! n[i] * primary_censored_dist_lpmf(
            dist, y[i], pws[i], y_upper[i], Ds[i])
    end
end
primarycensoreddist_model (generic function with 2 methods)

Let's instantiate this model with the data.

primarycensoreddist_mdl = primarycensoreddist_model(
    delay_counts.observed_delay,
    delay_counts.observed_delay_upper,
    delay_counts.n,
    delay_counts.pwindow,
    delay_counts.obs_time
)
DynamicPPL.Model{typeof(primarycensoreddist_model), (:y, :y_upper, :n, :pws, :Ds), (), (), Tuple{Vector{Float64}, Vector{Float64}, Vector{Int64}, Vector{Int64}, Vector{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(primarycensoreddist_model, (y = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0  …  0.0, 2.0, 4.0, 6.0, 8.0, 0.0, 2.0, 4.0, 6.0, 8.0], y_upper = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0  …  2.0, 4.0, 6.0, 8.0, 10.0, 2.0, 4.0, 6.0, 8.0, 10.0], n = [1, 13, 32, 29, 34, 26, 19, 14, 2, 5  …  13, 69, 59, 30, 12, 9, 69, 48, 29, 22], pws = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1  …  2, 2, 2, 2, 2, 2, 2, 2, 2, 2], Ds = [8, 8, 8, 8, 8, 8, 8, 8, 9, 9  …  9, 9, 9, 9, 9, 10, 10, 10, 10, 10]), NamedTuple(), DynamicPPL.DefaultContext())

Now let's fit the compiled model.

primarycensoreddist_fit = sample(
    primarycensoreddist_mdl, NUTS(), MCMCThreads(), 1000, 4)
      iteration  chain        mu     sigma        lp  n_steps  is_accept  acceptance_rate  ...
   1        501      1  1.46819   0.771804  -3376.41      3.0        1.0         1.0
   2        502      1  1.46877   0.738944  -3375.1       3.0        1.0         0.999202
   3        503      1  1.49735   0.740651  -3376.39      3.0        1.0         0.741842
   4        504      1  1.47618   0.762895  -3375.61      3.0        1.0         0.983406
   5        505      1  1.48132   0.74067   -3375.47      3.0        1.0         0.852127
   6        506      1  1.40746   0.711968  -3375.63      7.0        1.0         0.914995
   7        507      1  1.44329   0.747661  -3375.55      7.0        1.0         0.89498
   8        508      1  1.43568   0.734698  -3375.17      3.0        1.0         0.977821
   9        509      1  1.42456   0.696408  -3375.79      3.0        1.0         0.941795
  10        510      1  1.46966   0.758485  -3375.46      5.0        1.0         1.0
 ...
summarize(primarycensoreddist_fit)
    parameters  mean      std        mcse        ess_bulk  ess_tail  rhat     ess_per_sec
 1  :mu         1.45168   0.0355514  0.00108636  1087.88   1420.31   1.00169  60.5793
 2  :sigma      0.733008  0.0274925  0.00082729  1110.27   1643.86   1.00173  61.8257
let
    f = pairplot(primarycensoreddist_fit)
    CairoMakie.vlines!(f[1, 1], [meanlog], linewidth = 3)
    CairoMakie.vlines!(f[2, 2], [sdlog], linewidth = 3)
    f
end

We see that the model has converged and the diagnostics look good, with R-hat ≈ 1.002 for both parameters. The posterior means (mu ≈ 1.45, sigma ≈ 0.73) are also very close to the true parameters (1.5 and 0.75), and the 90% credible intervals include the true parameters.
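As a rough check on the interval claim, the sketch below builds approximate central 90% intervals from the posterior means and standard deviations reported in the two summaries above. It uses a normal approximation (mean ± 1.645 × sd), which is swapped in here for illustration only. Exact credible intervals would come from the posterior draws themselves, for example via MCMCChains' quantile(chain; q = [0.05, 0.95]).

```julia
# Approximate central 90% interval from a posterior mean and sd, assuming the
# marginal posterior is roughly normal (z ≈ 1.645 for the 5% and 95% quantiles).
approx_interval90(m, s) = (m - 1.645 * s, m + 1.645 * s)

truth = (mu = 1.5, sigma = 0.75)
# Posterior (mean, sd) pairs copied from the summarize() outputs above.
fits = (
    naive = (mu = (0.584129, 0.0704067), sigma = (3.17766, 0.0495775)),
    censored = (mu = (1.45168, 0.0355514), sigma = (0.733008, 0.0274925)),
)

for (name, fit) in pairs(fits), par in (:mu, :sigma)
    lo, hi = approx_interval90(fit[par]...)
    println(rpad(string(name), 9), rpad(string(par), 6),
        "[", round(lo; digits = 3), ", ", round(hi; digits = 3), "] covers truth: ",
        lo <= truth[par] <= hi)
end
```

Under this approximation, only the censored model's intervals cover the true values.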