API reference
Download the Nextstrain metadata file, preprocess it, and export it.
Two datasets are exported: one for model fitting and one for evaluation. The model dataset contains sequences collected and reported by a specified forecast date, while the evaluation dataset extends the horizon into the future.
To change default behaviors, create a YAML configuration file with a top-level data key
and pass it in the call to this script. For a list of configurable sub-keys, see the
DEFAULT_CONFIG dictionary.
The output is given in Apache Parquet format, with columns date, fd_offset,
division, lineage, count. Rows are uniquely identified by
(date, division, lineage). date and fd_offset can be computed from each other,
given the forecast date; the fd_offset column is the number of days between the
forecast date and the date column, such that, for example, 0 is the forecast date,
-1 the day before, and 1 the day after.
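As a minimal sketch of the relationship between date and fd_offset described above, the conversion in both directions can be written with the standard library. The function names and the example forecast date are hypothetical and are not part of this package.

```python
from datetime import date, timedelta


def fd_offset(obs_date: date, forecast_date: date) -> int:
    """Days from the forecast date to the observation date (negative = earlier)."""
    return (obs_date - forecast_date).days


def date_from_offset(offset: int, forecast_date: date) -> date:
    """Inverse of fd_offset: recover the calendar date from an offset."""
    return forecast_date + timedelta(days=offset)


forecast_date = date(2024, 3, 15)  # hypothetical forecast date
days = [date(2024, 3, 14), date(2024, 3, 15), date(2024, 3, 16)]
offsets = [fd_offset(d, forecast_date) for d in days]
round_trip = [date_from_offset(o, forecast_date) for o in offsets]
```

Here offsets holds the day before, the forecast date itself, and the day after, and round_trip recovers the original dates.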
Note that observations without a recorded date are removed, and only observations from human hosts are included.
DEFAULT_CONFIG = {'data': {'nextstrain_source': 'https://data.nextstrain.org/files/ncov/open/metadata.tsv.zst', 'use_usher': True, 'usher_lag': 168, 'usher_root': 'https://hgdownload.soe.ucsc.edu/goldenPath/wuhCor1/UShER_SARS-CoV-2/', 'use_cladecombiner_as_of': True, 'custom_relabelings': None, 'cache_dir': '.cache/', 'save_file': {'model': 'data/metadata-model.parquet', 'eval': 'data/metadata-eval.parquet'}, 'redownload': False, 'nextstrain_lineage_column_name': 'clade_nextstrain', 'usher_lineage_column_name': 'Nextstrain_clade', 'forecast_date': {'year': datetime.now().year, 'month': datetime.now().month, 'day': datetime.now().day}, 'horizon': {'lower': -90, 'upper': 14}, 'included_divisions': ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', 'Connecticut', 'Delaware', 'Florida', 'Georgia', 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland', 'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey', 'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania', 'Puerto Rico', 'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 'Virginia', 'Washington', 'Washington DC', 'West Virginia', 'Wisconsin', 'Wyoming'], 'lineages': []}}
module-attribute
Default configuration for data download, preprocessing, and export.
The configuration dictionary expects all of the following entries in a
data key.
hhs_regions = {'Connecticut': 'HHS-1', 'Maine': 'HHS-1', 'Massachusetts': 'HHS-1', 'New Hampshire': 'HHS-1', 'Rhode Island': 'HHS-1', 'Vermont': 'HHS-1', 'New Jersey': 'HHS-2', 'New York': 'HHS-2', 'Puerto Rico': 'HHS-2', 'Virgin Islands': 'HHS-2', 'Delaware': 'HHS-3', 'Washington DC': 'HHS-3', 'Maryland': 'HHS-3', 'Pennsylvania': 'HHS-3', 'Virginia': 'HHS-3', 'West Virginia': 'HHS-3', 'Alabama': 'HHS-4', 'Florida': 'HHS-4', 'Georgia': 'HHS-4', 'Kentucky': 'HHS-4', 'Mississippi': 'HHS-4', 'North Carolina': 'HHS-4', 'South Carolina': 'HHS-4', 'Tennessee': 'HHS-4', 'Illinois': 'HHS-5', 'Indiana': 'HHS-5', 'Michigan': 'HHS-5', 'Minnesota': 'HHS-5', 'Ohio': 'HHS-5', 'Wisconsin': 'HHS-5', 'Arkansas': 'HHS-6', 'Louisiana': 'HHS-6', 'New Mexico': 'HHS-6', 'Oklahoma': 'HHS-6', 'Texas': 'HHS-6', 'Iowa': 'HHS-7', 'Kansas': 'HHS-7', 'Missouri': 'HHS-7', 'Nebraska': 'HHS-7', 'Colorado': 'HHS-8', 'Montana': 'HHS-8', 'North Dakota': 'HHS-8', 'South Dakota': 'HHS-8', 'Utah': 'HHS-8', 'Wyoming': 'HHS-8', 'Arizona': 'HHS-9', 'California': 'HHS-9', 'Hawaii': 'HHS-9', 'Nevada': 'HHS-9', 'American Samoa': 'HHS-9', 'Commonwealth of the Northern Mariana Islands': 'HHS-9', 'Federated States of Micronesia': 'HHS-9', 'Guam': 'HHS-9', 'Marshall Islands': 'HHS-9', 'Republic of Palau': 'HHS-9', 'Alaska': 'HHS-10', 'Idaho': 'HHS-10', 'Oregon': 'HHS-10', 'Washington': 'HHS-10'}
module-attribute
Dictionary form of https://www.hhs.gov/about/agencies/iea/regional-offices/index.html, except that DC is listed as "Washington DC" rather than "District of Columbia".
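A common use of a division-to-region mapping like this is inverting it to list the divisions in each region. The sketch below uses a small excerpt of the dictionary shown above rather than importing it; the variable names are hypothetical.

```python
from collections import defaultdict

# Excerpt of the hhs_regions mapping shown above (not the full dictionary).
hhs_regions_excerpt = {
    "Connecticut": "HHS-1",
    "Maine": "HHS-1",
    "New Jersey": "HHS-2",
    "Washington DC": "HHS-3",
    "Alaska": "HHS-10",
}

# Invert division -> region into region -> list of divisions.
divisions_by_region = defaultdict(list)
for division, region in hhs_regions_excerpt.items():
    divisions_by_region[region].append(division)
```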
CountsFrame
Bases: DataFrame
A polars.DataFrame which enforces a format for observed counts of lineages.
See REQUIRED_COLUMNS for the expected columns.
combine_clades(df, as_of, lineage_col='lineage')
Uses CDCGov/cladecombiner to recode the stated "lineage" such that any Nextstrain clade which was not recognized (had not yet been named) by the as-of date is put in its ancestor which was recognized.
Useful when the taxonomy shifts (a new clade is named) within a few
months of the desired forecast_date, so that a retrospective run of the
pipeline better approximates what would have been seen at the time.
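The idea of recoding a clade to an ancestor that was recognized by the as-of date can be illustrated with a toy taxonomy. This is only a sketch of the concept: the clade names, naming dates, and the recode_as_of helper are hypothetical, and it does not use or reproduce the cladecombiner API.

```python
from datetime import date

# Hypothetical toy taxonomy: clade -> (parent clade, date the clade was named).
TAXONOMY = {
    "root": (None, date(2020, 1, 1)),
    "X": ("root", date(2023, 6, 1)),
    "X.1": ("X", date(2024, 1, 10)),
    "X.1.1": ("X.1", date(2024, 2, 20)),
}


def recode_as_of(clade: str, as_of: date) -> str:
    """Walk up the tree until reaching a clade that had been named by as_of."""
    current = clade
    while TAXONOMY[current][1] > as_of:
        parent = TAXONOMY[current][0]
        if parent is None:
            raise ValueError(f"No ancestor of {clade} was named by {as_of}")
        current = parent
    return current


recoded_late_january = recode_as_of("X.1.1", date(2024, 1, 31))
recoded_december = recode_as_of("X.1.1", date(2023, 12, 1))
```

With an as-of date in late January 2024, the toy clade X.1.1 had not yet been named, so it is placed in X.1; with an as-of date in December 2023 it falls back further to X.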
process_nextstrain(fp, forecast_date, config)
Reads Nextstrain data from an (uncompressed) Nextstrain metadata file and performs basic filtering and date wrangling.
recode_clades_using_usher(ns, usher_path, usher_lineage_from, lineage_to='lineage')
Replaces the "lineage" column in the input ns (Nextstrain) dataframe using clades called by UShER, as read from an UShER metadata file.
Performs matching based on Genbank accessions. Unmatched entries are dropped.
Useful for better approximating a retrospective run of an analysis, since UShER metadata, including Nextstrain clade calls, are archived as far back as mid-2021.
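The matching step can be pictured as an inner join on accession followed by overwriting the lineage. The sketch below uses plain Python dictionaries instead of data frames; the accession values, the genbank_accession column name, and the clade labels are all hypothetical.

```python
# Hypothetical Nextstrain rows, each with an accession and a current lineage call.
ns_rows = [
    {"genbank_accession": "ACC1", "lineage": "old-A"},
    {"genbank_accession": "ACC2", "lineage": "old-A"},
    {"genbank_accession": "ACC3", "lineage": "old-B"},
]

# Hypothetical UShER clade calls keyed by accession; ACC2 has no match.
usher_calls = {"ACC1": "new-A", "ACC3": "new-B"}

# Keep only matched rows and replace their lineage with the UShER call.
recoded = [
    {**row, "lineage": usher_calls[row["genbank_accession"]]}
    for row in ns_rows
    if row["genbank_accession"] in usher_calls
]
```

The unmatched row (ACC2) is dropped, mirroring the behavior described above.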
BaselineModel
Bases: MultinomialModel
Multinomial model assuming independence between divisions and specifying a constant proportion over time for each division-lineage.
Observations are counts of lineages for each division-day (see https://doi.org/10.1101/2023.01.02.23284123). No parameters are constrained here, so specific coefficients are not identifiable.
data: A DataFrame with the standard model input format.
N: A Series of total counts (across lineages) for each observation.
Only required if a generative model is desired; lineage counts will
then be ignored.
num_lineages: The number of lineages. Only required if a generative model is
desired; lineage counts will then be ignored.
dense_mass()
For use by numpyro.infer.NUTS, specification of structure of mass matrix.
Defaults to diagonal mass matrix.
ignore_nan_in()
For use by get_convergence, list of model "sites" where NaN convergence results are expected, such as a constant value.
Defaults to assuming NaNs are unexpected.
CorrelatedDeviationsModel
Bases: HierarchicalDivisionsModel
An extension of the simple hierarchical model that adds correlations among the lineages in how they deviate across sites from the global mean slopes, as in https://doi.org/10.1101/2023.01.02.23284123
ignore_nan_in()
For use by get_convergence, list of model "sites" where NaN convergence results are expected, such as a constant value.
Defaults to assuming NaNs are unexpected.
ForecastFrame
Bases: DataFrame
A polars.DataFrame which enforces a format for probabilistic forecast samples of
population-level lineage proportions.
See REQUIRED_COLUMNS for the expected columns.
HierarchicalDivisionsModel
Bases: MultinomialModel
Multinomial regression model with information sharing over divisions.
Observations are counts of lineages for each division-day. No parameters are constrained here, so specific coefficients are not identifiable.
data: A DataFrame with the standard model input format.
N: A Series of total counts (across lineages) for each observation.
Only required if a generative model is desired; lineage counts will
then be ignored.
num_lineages: The number of lineages. Only required if a generative model is
desired; lineage counts will then be ignored.
pool_intercepts: A float in [0,1] which determines how strongly the intercepts are pooled.
Determines the proportion of prior variance on per-division intercepts which
comes from the prior variance on the shared hierarchical mean.
pool_slopes: Equivalent of pool_intercepts for slopes.
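One way to read the pooling parameters is as a split of a total prior variance into a shared part and a per-division part. The sketch below assumes the shared mean and the per-division deviation are independent, so their variances add; the split_prior_sd helper and the total_sd argument are hypothetical, and the model's actual parameterization may differ.

```python
import math


def split_prior_sd(total_sd: float, pool: float) -> tuple[float, float]:
    """Split a total prior SD into (shared-mean SD, per-division deviation SD)."""
    if not 0.0 <= pool <= 1.0:
        raise ValueError("pool must lie in [0, 1]")
    total_var = total_sd**2
    shared_sd = math.sqrt(pool * total_var)  # variance from the shared mean
    deviation_sd = math.sqrt((1.0 - pool) * total_var)  # division-specific variance
    return shared_sd, deviation_sd


shared_sd, deviation_sd = split_prior_sd(total_sd=2.0, pool=0.25)
```

Under this reading, a pool value of 1 makes every division share the same coefficient, while 0 makes divisions independent.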
ignore_nan_in()
For use by get_convergence, list of model "sites" where NaN convergence results are expected, such as a constant value.
Defaults to assuming NaNs are unexpected.
IndependentDivisionsModel
Bases: MultinomialModel
Multinomial regression model assuming independence between divisions and specifying an intercept and slope on time for each division-lineage.
Observations are counts of lineages for each division-day (see https://doi.org/10.1101/2023.01.02.23284123). No parameters are constrained here, so specific coefficients are not identifiable.
data: A DataFrame with the standard model input format.
N: A Series of total counts (across lineages) for each observation.
Only required if a generative model is desired; lineage counts will
then be ignored.
num_lineages: The number of lineages. Only required if a generative model is
desired; lineage counts will then be ignored.
ignore_nan_in()
For use by get_convergence, list of model "sites" where NaN convergence results are expected, such as a constant value.
Defaults to assuming NaNs are unexpected.
MultinomialModel
Bases: ABC
create_forecasts(mcmc, fd_offsets)
abstractmethod
Generate a data frame of forecasted population-level proportions.
mcmc: The MCMC.
fd_offsets: The (relative) days on which to generate forecasted proportions.
dense_mass()
For use by numpyro.infer.NUTS, specification of structure of mass matrix.
Defaults to diagonal mass matrix.
ignore_nan_in()
For use by get_convergence, list of model "sites" where NaN convergence results are expected, such as a constant value.
Defaults to assuming NaNs are unexpected.
numpyro_model()
abstractmethod
A NumPyro model suitable for use as model argument to numpyro.infer.NUTS.
expand_grid(**columns)
Create a DataFrame from all combinations of given columns.
Operates like the R function tidyr::expand_grid.
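To illustrate what "all combinations of given columns" means, here is a dependency-free sketch that returns a dictionary of columns rather than a DataFrame. It is not the package's implementation; the expand_grid_sketch name is hypothetical.

```python
from itertools import product


def expand_grid_sketch(**columns):
    """Return a dict of equal-length columns covering every combination of inputs."""
    names = list(columns)
    rows = list(product(*columns.values()))  # Cartesian product, last column varies fastest
    return {name: [row[i] for row in rows] for i, name in enumerate(names)}


grid = expand_grid_sketch(division=["Alabama", "Alaska"], fd_offset=[0, 1, 2])
```

Two divisions and three offsets yield six rows, one per (division, fd_offset) pair.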
multinomial_likelihood(beta_0, beta_1, divisions, time, N=None)
Distribution of observations for multinomial regression model for a single human population. Observations are counts of lineages for each time.
beta_0 (np.ndarray): Intercept, shape (num_divisions, num_lineages)
beta_1 (np.ndarray): Slope on time, shape (num_divisions, num_lineages)
divisions (np.ndarray): Division index for each observation, length (num_observations)
time (np.ndarray): Times, length (num_observations)
N (np.ndarray): Total counts across lineages, length (num_observations)
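The sketch below shows how intercepts and slopes of these shapes could map to lineage proportions for one observation, assuming the usual softmax link for multinomial regression (the link is an assumption here, as are the helper names and example values). It uses nested lists in place of arrays to stay dependency-free.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def lineage_proportions(beta_0, beta_1, division, t):
    """Proportions for one observation: softmax(beta_0[d] + beta_1[d] * t)."""
    logits = [b0 + b1 * t for b0, b1 in zip(beta_0[division], beta_1[division])]
    return softmax(logits)


# Hypothetical coefficients, shape (num_divisions=2, num_lineages=2).
beta_0 = [[0.0, 0.0], [1.0, -1.0]]
beta_1 = [[0.0, 0.1], [0.0, 0.0]]
props_start = lineage_proportions(beta_0, beta_1, division=0, t=0.0)
props_later = lineage_proportions(beta_0, beta_1, division=0, t=10.0)
```

Because softmax is unchanged when the same constant is added to every logit, unconstrained coefficients like these are not individually identifiable, only their differences across lineages.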
pl_softmax(pl_expr)
Computes the softmax of the column pl_expr.
CountsEvaluator
__init__(samples, data, count_sampler='multinomial', seed=None)
Evaluates count forecasts \(\hat{Y}\) sampled from a specified observation model given model proportion forecasts.
count_sampler should be one of the keys in CountsEvaluator._count_samplers.
seed is an optional random seed for the count sampler.
energy_score(filters=None, p=2)
The energy score of count forecasts, summed over all divisions and days.
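For orientation, the standard sample-based energy score for one division-day is the mean distance from forecast samples to the observation, minus half the mean distance between pairs of forecast samples. The sketch below implements that textbook form, assuming p is the order of the vector norm; the estimator this class actually uses is not shown here and may differ, and the helper names are hypothetical.

```python
def p_norm(vector, p=2):
    """The p-norm of a list of numbers."""
    return sum(abs(x) ** p for x in vector) ** (1.0 / p)


def energy_score_single(samples, observed, p=2):
    """Sample-based energy score for one division-day (lower is better)."""
    n = len(samples)
    # Mean distance between each forecast sample and the observation.
    term_obs = sum(
        p_norm([s - o for s, o in zip(sample, observed)], p) for sample in samples
    ) / n
    # Mean distance between all pairs of forecast samples (including self-pairs).
    term_spread = sum(
        p_norm([a - b for a, b in zip(s1, s2)], p) for s1 in samples for s2 in samples
    ) / (n * n)
    return term_obs - 0.5 * term_spread


perfect = energy_score_single([[3, 7], [3, 7]], observed=[3, 7])
point_miss = energy_score_single([[0, 0]], observed=[3, 4])
```

Summing such per-division-day scores over all divisions and days gives a total of the kind described above.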
mean_norm(filters=None, p=1)
The expected norm of count forecast error, summed over all divisions and days.
\(\sum_{t, g} E[ || \hat{Y}_{tg} - Y_{tg} ||_p ]\)
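The sum above can be approximated by replacing the expectation with an average over forecast samples. The following sketch does this with plain dictionaries keyed by (day, division); the data layout, names, and values are hypothetical and do not reflect the evaluator's internal representation.

```python
def p_norm(vector, p=1):
    """The p-norm of a list of numbers."""
    return sum(abs(x) ** p for x in vector) ** (1.0 / p)


def mean_norm_sketch(forecast_samples, observed, p=1):
    """Sum over (t, g) of the sample average of ||Y_hat - Y||_p."""
    total = 0.0
    for key, obs in observed.items():
        samples = forecast_samples[key]
        errors = [p_norm([yh - y for yh, y in zip(s, obs)], p) for s in samples]
        total += sum(errors) / len(errors)
    return total


# Hypothetical lineage counts for two division-days, two forecast samples each.
observed = {(0, "Alabama"): [3, 1], (1, "Alabama"): [2, 2]}
forecast_samples = {
    (0, "Alabama"): [[3, 1], [2, 2]],
    (1, "Alabama"): [[2, 2], [2, 2]],
}
score = mean_norm_sketch(forecast_samples, observed, p=1)
```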
uncovered_proportion(filters=None, alpha=0.11)
Proportion of all lineage observation counts on all division-days not covered by the (central) 1 - alpha prediction interval.
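A central 1 - alpha interval runs from the alpha/2 quantile to the 1 - alpha/2 quantile of the forecast samples. The sketch below computes the share of observations falling outside that interval; the linear-interpolation quantile rule and all names are assumptions made for illustration, not the evaluator's implementation.

```python
def quantile(sorted_values, q):
    """Empirical quantile by linear interpolation between order statistics."""
    pos = q * (len(sorted_values) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_values) - 1)
    frac = pos - lo
    return sorted_values[lo] * (1 - frac) + sorted_values[hi] * frac


def uncovered_proportion_sketch(samples_per_obs, observed, alpha=0.11):
    """Share of observations outside the central 1 - alpha sample interval."""
    uncovered = 0
    for samples, y in zip(samples_per_obs, observed):
        ordered = sorted(samples)
        lower = quantile(ordered, alpha / 2)
        upper = quantile(ordered, 1 - alpha / 2)
        if y < lower or y > upper:
            uncovered += 1
    return uncovered / len(observed)


# Hypothetical forecast samples 0..100 for each of two observed counts.
samples_per_obs = [list(range(101)), list(range(101))]
result = uncovered_proportion_sketch(samples_per_obs, observed=[50, 200])
```

For well-calibrated forecasts, the uncovered proportion should be close to alpha.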
ProportionsEvaluator
mean_norm(filters=None, p=1)
The expected norm of proportion forecast error, summed over all divisions and days.
\(\sum_{t, g} E[ || f_{tg} - \phi_{tg} ||_p ]\)
multinomial_count_sampler(n, p, rng)
Samples from multinomial for multiple rows in one call.
n: 1-D array-like of total counts for each draw.
p: 2-D array-like where each row is a probability vector for that draw.
rng: a numpy.random.Generator used for reproducible sampling.
Returns an (n_rows, n_lineages) ndarray of integer counts.
optional_filter(df, filters)
Filters the data based on a dict of column name to allowable values.
None is interpreted as all values allowed.
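The filtering behavior can be sketched on a list of row dictionaries instead of a data frame. Since the description does not say whether None applies to the whole filters argument or to a single column's entry, this sketch treats both as "no restriction"; that choice and the function name are assumptions.

```python
def optional_filter_sketch(rows, filters):
    """Keep rows whose values are allowed for every filtered column."""
    if filters is None:
        return list(rows)  # no filters: everything is allowed
    return [
        row
        for row in rows
        if all(
            allowed is None or row[column] in allowed  # None: any value allowed
            for column, allowed in filters.items()
        )
    ]


# Hypothetical rows with the division and lineage columns.
rows = [
    {"division": "Alabama", "lineage": "X"},
    {"division": "Alaska", "lineage": "X"},
    {"division": "Alaska", "lineage": "Y"},
]
alaska_only = optional_filter_sketch(rows, {"division": ["Alaska"]})
alaska_y = optional_filter_sketch(rows, {"division": ["Alaska"], "lineage": ["Y"]})
```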