
Datasets

load_generation_interval

load_generation_interval() -> DataFrame

Load the generation interval dataset

This dataset contains the generation interval distribution for COVID-19.

Returns:

    DataFrame
        The generation interval dataset

Notes

This dataset was downloaded directly from: https://raw.githubusercontent.com/CDCgov/wastewater-informed-covid-forecasting/0962c5d1652787479ac72caebf076ab55fe4e10c/input/saved_pmfs/generation_interval.csv

The dataset contains the following columns:

- timepoint
- probability_mass

Source code in pyrenew/datasets/generation_interval.py
def load_generation_interval() -> pl.DataFrame:
    """
    Load the generation interval dataset

    This dataset contains the generation interval distribution for COVID-19.

    Returns
    -------
    pl.DataFrame
        The generation interval dataset

    Notes
    -----
    This dataset was downloaded directly from:
    https://raw.githubusercontent.com/CDCgov/wastewater-informed-covid-forecasting/0962c5d1652787479ac72caebf076ab55fe4e10c/input/saved_pmfs/generation_interval.csv

    The dataset contains the following columns:
        - `timepoint`
        - `probability_mass`
    """

    # Load the dataset
    return pl.read_csv(
        source=files("pyrenew.datasets") / "generation_interval.tsv",
        separator="\t",
    )
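
A minimal usage sketch (assuming the loader is re-exported from pyrenew.datasets, as this page's layout suggests): load the PMF and convert it to a JAX array for use downstream.

import jax.numpy as jnp
from pyrenew.datasets import load_generation_interval  # assumed re-export

gen_int = load_generation_interval()
print(gen_int.columns)  # ['timepoint', 'probability_mass']
# Generation-interval PMF in the file's row order
gen_int_pmf = jnp.array(gen_int["probability_mass"].to_numpy())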

load_hospital_data_for_state

load_hospital_data_for_state(
    state_abbr: str = "CA", filename: str = "2023-11-06.csv"
) -> dict

Load hospital admissions data for a specific state.

Parameters:

    state_abbr : str, default 'CA'
        State abbreviation (e.g., "CA").
    filename : str, default '2023-11-06.csv'
        CSV filename.

Returns:

    dict
        Dictionary containing:

        - daily_admits: JAX array of daily hospital admissions
        - population: Population size (scalar)
        - dates: List of datetime.date objects
        - n_days: Number of days

Notes

Data source: CDC cfa-forecast-renewal-ww repository. License: Public Domain (CC0 1.0 Universal) - U.S. Government work.

Source code in pyrenew/datasets/hospital_admissions.py
def load_hospital_data_for_state(
    state_abbr: str = "CA",
    filename: str = "2023-11-06.csv",
) -> dict:
    """
    Load hospital admissions data for a specific state.

    Parameters
    ----------
    state_abbr : str
        State abbreviation (e.g., "CA"). Default is "CA".
    filename : str
        CSV filename. Default is "2023-11-06.csv".

    Returns
    -------
    dict
        Dictionary containing:

        - daily_admits: JAX array of daily hospital admissions
        - population: Population size (scalar)
        - dates: List of datetime.date objects
        - n_days: Number of days

    Notes
    -----
    Data source: CDC cfa-forecast-renewal-ww repository.
    License: Public Domain (CC0 1.0 Universal) - U.S. Government work.
    """
    data_path = files("pyrenew.datasets.hospital_admissions_data") / filename
    df = pl.read_csv(source=data_path)

    df = (
        df.with_columns(pl.col("date").str.to_date())
        .filter(pl.col("location") == state_abbr)
        .sort("date")
    )

    if len(df) == 0:
        raise ValueError(f"No data found for state {state_abbr} in {filename}")

    daily_admits = jnp.array(df["daily_hosp_admits"].to_numpy())
    population = int(df["pop"][0])
    dates = df["date"].to_list()

    return {
        "daily_admits": daily_admits,
        "population": population,
        "dates": dates,
        "n_days": len(daily_admits),
    }
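
A short usage sketch, again assuming the loader is re-exported from pyrenew.datasets: load the default California file and inspect the returned dictionary.

from pyrenew.datasets import load_hospital_data_for_state  # assumed re-export

hosp = load_hospital_data_for_state(state_abbr="CA", filename="2023-11-06.csv")
print(hosp["n_days"], "days of admissions; population =", hosp["population"])
admits = hosp["daily_admits"]  # JAX array aligned with hosp["dates"]
first_day = hosp["dates"][0]   # datetime.date of the earliest observation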

load_infection_admission_interval

load_infection_admission_interval() -> DataFrame

Load the infection to admission interval

This dataset contains the infection to admission interval distribution for COVID-19.

Returns:

    DataFrame
        The infection to admission interval dataset

Notes

This dataset was downloaded directly from: https://raw.githubusercontent.com/CDCgov/wastewater-informed-covid-forecasting/0962c5d1652787479ac72caebf076ab55fe4e10c/input/saved_pmfs/inf_to_hosp.csv

The dataset contains the following columns:

- timepoint
- probability_mass

Source code in pyrenew/datasets/infection_admission_interval.py
def load_infection_admission_interval() -> pl.DataFrame:
    """
    Load the infection to admission interval

    This dataset contains the infection to admission interval distribution for
    COVID-19.

    Returns
    -------
    pl.DataFrame
        The infection to admission interval dataset

    Notes
    -----
    This dataset was downloaded directly from:
    https://raw.githubusercontent.com/CDCgov/wastewater-informed-covid-forecasting/0962c5d1652787479ac72caebf076ab55fe4e10c/input/saved_pmfs/inf_to_hosp.csv

    The dataset contains the following columns:
        - `timepoint`
        - `probability_mass`
    """

    # Load the dataset
    return pl.read_csv(
        source=files("pyrenew.datasets") / "infection_admission_interval.tsv",
        separator="\t",
    )
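
A minimal usage sketch under the same re-export assumption: load the infection-to-admission delay PMF and convert it to a JAX array.

import jax.numpy as jnp
from pyrenew.datasets import load_infection_admission_interval  # assumed re-export

inf_to_hosp = load_infection_admission_interval()
delay_pmf = jnp.array(inf_to_hosp["probability_mass"].to_numpy())
print(delay_pmf.shape, float(delay_pmf.sum()))  # PMF length and total probability mass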

load_wastewater

load_wastewater() -> DataFrame

Load the wastewater dataset. This dataset contains simulated entries of COVID-19 wastewater concentration data. The dataset is used to demonstrate the use of the wastewater-informed COVID-19 forecasting model.

Returns:

    DataFrame
        The wastewater dataset.

Notes

This dataset was downloaded directly from: https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/292526383ece582f10823fc939c7e590ca349c6d/cfaforecastrenewalww/data/example_df.rda

The dataset contains the following columns:

- lab_wwtp_unique_id
- log_conc
- date
- lod_sewage
- below_lod
- daily_hosp_admits
- daily_hosp_admits_for_eval
- pop
- forecast_date
- hosp_calibration_time
- site
- ww_pop
- inf_per_capita

Source code in pyrenew/datasets/wastewater.py
def load_wastewater() -> pl.DataFrame:  # numpydoc ignore=SS06,SA01,EX01
    """
    Load the wastewater dataset. This dataset
    contains simulated entries of
    COVID-19 wastewater concentration data.
    The dataset is used to demonstrate the use of
    the wastewater-informed COVID-19 forecasting model.

    Returns
    -------
    pl.DataFrame
        The wastewater dataset.

    Notes
    -----
    This dataset was downloaded directly from:
    https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/292526383ece582f10823fc939c7e590ca349c6d/cfaforecastrenewalww/data/example_df.rda

    The dataset contains the following columns:
        - `lab_wwtp_unique_id`
        - `log_conc`
        - `date`
        - `lod_sewage`
        - `below_lod`
        - `daily_hosp_admits`
        - `daily_hosp_admits_for_eval`
        - `pop`
        - `forecast_date`
        - `hosp_calibration_time`
        - `site`
        - `ww_pop`
        - `inf_per_capita`
    """

    # Load the dataset
    return pl.read_csv(
        source=files("pyrenew.datasets") / "wastewater.tsv",
        separator="\t",
        try_parse_dates=True,
    )
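
A usage sketch, assuming the loader is re-exported from pyrenew.datasets and that below_lod is a 0/1 indicator: extract one site's time series of log concentrations above the limit of detection.

import polars as pl
from pyrenew.datasets import load_wastewater  # assumed re-export

ww = load_wastewater()
site_id = ww["lab_wwtp_unique_id"][0]
site_ts = (
    ww.filter(
        (pl.col("lab_wwtp_unique_id") == site_id)
        & (pl.col("below_lod") == 0)  # assumes a 0/1 below-LOD flag
    )
    .select(["date", "log_conc"])
    .sort("date")
)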

load_wastewater_data_for_state

load_wastewater_data_for_state(
    state_abbr: str = "CA", filename: str = "fake_nwss.csv"
) -> dict

Load wastewater data for a specific state.

Parameters:

    state_abbr : str, default 'CA'
        State abbreviation (e.g., "CA").
    filename : str, default 'fake_nwss.csv'
        CSV filename.

Returns:

    dict
        Dictionary containing:

        - observed_conc: JAX array of log concentrations (log copies/mL)
        - observed_conc_linear: JAX array of linear concentrations (copies/mL)
        - site_ids: JAX array of site indices
        - time_indices: JAX array of time indices (days from start)
        - wwtp_names: List of unique WWTP names
        - dates: List of unique dates
        - n_sites: Number of unique sites
        - n_obs: Number of observations
        - raw_df: Polars DataFrame (for debugging)

Notes

Data source: CDC cfa-forecast-renewal-ww repository. License: Public Domain (CC0 1.0 Universal) - U.S. Government work.

The data is synthetic and contains deliberately added noise for public release. Concentrations are in copies/L and are converted to copies/mL (divided by 1000).

Source code in pyrenew/datasets/wastewater_nwss.py
def load_wastewater_data_for_state(
    state_abbr: str = "CA",
    filename: str = "fake_nwss.csv",
) -> dict:
    """
    Load wastewater data for a specific state.

    Parameters
    ----------
    state_abbr : str
        State abbreviation (e.g., "CA"). Default is "CA".
    filename : str
        CSV filename. Default is "fake_nwss.csv".

    Returns
    -------
    dict
        Dictionary containing:

        - observed_conc: JAX array of log concentrations (log copies/mL)
        - observed_conc_linear: JAX array of linear concentrations (copies/mL)
        - site_ids: JAX array of site indices
        - time_indices: JAX array of time indices (days from start)
        - wwtp_names: List of unique WWTP names
        - dates: List of unique dates
        - n_sites: Number of unique sites
        - n_obs: Number of observations
        - raw_df: Polars DataFrame (for debugging)

    Notes
    -----
    Data source: CDC cfa-forecast-renewal-ww repository.
    License: Public Domain (CC0 1.0 Universal) - U.S. Government work.

    The data is synthetic and contains deliberately added noise for
    public release. Concentrations are in copies/L and are converted
    to copies/mL (divided by 1000).
    """
    data_path = files("pyrenew.datasets.wastewater_nwss_data") / filename
    df = pl.read_csv(
        source=data_path,
        schema_overrides={"county_names": pl.String},
    )
    df = df.with_columns(pl.col("sample_collect_date").str.to_date())

    # Filter to requested state
    df = df.filter(pl.col("wwtp_jurisdiction") == state_abbr)
    if len(df) == 0:
        raise ValueError(f"No wastewater data found for state {state_abbr}")

    # Convert copies/L to copies/mL
    df = df.with_columns(
        (pl.col("pcr_target_avg_conc") / 1000).alias("conc_linear"),
    )

    df = df.sort("sample_collect_date")
    unique_sites = sorted(df["wwtp_name"].unique().to_list())
    site_to_idx = {site: idx for idx, site in enumerate(unique_sites)}
    min_date = df["sample_collect_date"].min()
    df = df.with_columns(
        ((pl.col("sample_collect_date") - min_date).dt.total_days()).alias("time_idx")
    )
    df = df.with_columns(
        pl.col("wwtp_name").replace_strict(site_to_idx, default=None).alias("site_idx")
    )

    observed_conc_linear = jnp.array(df["conc_linear"].to_numpy())
    observed_log_conc = jnp.log(observed_conc_linear + 1e-8)
    site_ids = jnp.array(df["site_idx"].to_numpy(), dtype=jnp.int32)
    time_indices = jnp.array(df["time_idx"].to_numpy(), dtype=jnp.int32)

    return {
        "observed_conc": observed_log_conc,
        "observed_conc_linear": observed_conc_linear,
        "site_ids": site_ids,
        "time_indices": time_indices,
        "wwtp_names": unique_sites,
        "dates": sorted(df["sample_collect_date"].unique().to_list()),
        "n_sites": len(unique_sites),
        "n_obs": len(df),
        "raw_df": df,
    }
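
A usage sketch under the same re-export assumption: load the synthetic NWSS-style data for one state and count observations per site using the returned integer site indices.

import jax.numpy as jnp
from pyrenew.datasets import load_wastewater_data_for_state  # assumed re-export

ww = load_wastewater_data_for_state(state_abbr="CA", filename="fake_nwss.csv")
print(f"{ww['n_obs']} samples across {ww['n_sites']} sites")
for idx, name in enumerate(ww["wwtp_names"]):
    n_site = int(jnp.sum(ww["site_ids"] == idx))  # observations attributed to this WWTP
    print(f"  {name}: {n_site} observations")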