Skip to content

API reference

vcf

vcf.CoverageModel

Bases: ABC

Source code in vcf/__init__.py
class CoverageModel(abc.ABC):
    """Abstract interface shared by all coverage forecasting models."""

    @abc.abstractmethod
    def __init__(
        self,
        data: pl.DataFrame,
        forecast_date: datetime.date,
        params: dict[str, Any],
        season: dict[str, Any],
        quantiles: list[float],
    ):
        """Set up a coverage forecasting model.

        Args:
            data: Observations available for fitting and/or prediction.
            forecast_date: Cutoff date separating training data from the
                forecast horizon.
            params: Model-specific configuration values.
            season: Start/end month and day describing season boundaries.
            quantiles: Quantile levels to report from prediction.
        """
        ...

    @abc.abstractmethod
    def fit(self) -> Self:
        """Estimate model parameters from the available training data.

        Returns:
            The fitted model (self), enabling fluent chaining.
        """
        ...

    @abc.abstractmethod
    def predict(self) -> pl.DataFrame:
        """Produce forecasts at the configured quantile levels.

        Returns:
            Data frame of forecasts plus identifying metadata columns.
        """
        ...

vcf.CoverageModel.__init__(data, forecast_date, params, season, quantiles) abstractmethod

Initialize a coverage forecasting model.

Parameters:

Name Type Description Default
data DataFrame

Input observations used for fitting and/or prediction.

required
forecast_date date

Training cutoff date.

required
params dict[str, Any]

Model-specific parameter configuration.

required
season dict[str, Any]

Seasonal boundary configuration (start/end month/day).

required
quantiles list[float]

Quantiles to return during prediction.

required
Source code in vcf/__init__.py
@abc.abstractmethod
def __init__(
    self,
    data: pl.DataFrame,
    forecast_date: datetime.date,
    params: dict[str, Any],
    season: dict[str, Any],
    quantiles: list[float],
):
    """Set up a coverage forecasting model.

    Args:
        data: Observations available for fitting and/or prediction.
        forecast_date: Cutoff date separating training data from the
            forecast horizon.
        params: Model-specific configuration values.
        season: Start/end month and day describing season boundaries.
        quantiles: Quantile levels to report from prediction.
    """
    ...

vcf.CoverageModel.fit() abstractmethod

Fit model parameters using available training data.

Returns:

Type Description
Self

Fitted model instance.

Source code in vcf/__init__.py
@abc.abstractmethod
def fit(self) -> Self:
    """Fit model parameters using available training data.

    Returns:
        Fitted model instance.
    """
    pass

vcf.CoverageModel.predict() abstractmethod

Generate predictions for configured quantiles.

Returns:

Type Description
DataFrame

Data frame containing forecasts and identifying metadata.

Source code in vcf/__init__.py
@abc.abstractmethod
def predict(self) -> pl.DataFrame:
    """Produce forecasts at the configured quantile levels.

    Returns:
        Data frame of forecasts plus identifying metadata columns.
    """
    ...

vcf.LPLModel

Bases: CoverageModel

Subclass of CoverageModel for a mixed Logistic Plus Linear model. For details, see the online docs.

Source code in vcf/__init__.py
class LPLModel(CoverageModel):
    """
    Subclass of CoverageModel for a mixed Logistic Plus Linear model.
    For details, see the online docs.
    """

    def __init__(
        self,
        data: pl.DataFrame,
        forecast_date: datetime.date,
        params: dict[str, Any],
        season: dict[str, Any],
        quantiles: list[float],
        date_column: str = "time_end",
    ):
        """Initialize Logistic Plus Linear model.

        Args:
            data: Cumulative coverage data for fitting and prediction.
            forecast_date: Date used to split training and forecast periods.
            params: Model and sampler settings. Includes prior hyperparameters
                consumed by _logistic_plus_linear and MCMC controls such as
                num_warmup, num_samples, num_chains, progress_bar, and seed.
            season: Seasonal settings with start_month, start_day, end_month,
                and end_day.
            quantiles: Posterior quantiles to report.
            date_column: Name of the date column in the data. Defaults to "time_end".

        Raises:
            KeyError: If params does not contain a "seed" entry.
            AssertionError: If data lacks the required columns.
        """
        self.raw_data = data
        self.date_column = date_column
        self.quantiles = quantiles
        self.forecast_date = forecast_date
        self.season = season

        # use parameters, separating MCMC and model fitting parameters
        self.params = params

        mcmc_keys = {"num_warmup", "num_samples", "num_chains", "progress_bar"}
        self.mcmc_params = {k: v for k, v in params.items() if k in mcmc_keys}
        # model params are matched by name to the _logistic_plus_linear signature
        self.model_params = {
            k: v
            for k, v in params.items()
            if k in inspect.signature(self._logistic_plus_linear).parameters
        }
        # independent PRNG keys for fitting and prediction
        self.fit_key, self.pred_key = random.split(random.key(self.params["seed"]), 2)

        # input data validation (N_tot is additionally required by the
        # preprocessing below but is not asserted here)
        assert {self.date_column, "estimate", "season", "geography"}.issubset(
            self.raw_data.columns
        )

        # preprocess data
        self.data = (
            self.raw_data
            # prepare observation data
            .with_columns(N_vax=(pl.col("N_tot") * pl.col("estimate")).round(0))
            # add interaction term
            .with_columns(
                season_geo=pl.concat_str(["season", "geography"], separator="_")
            )
            # t = fraction of a 365-day year elapsed since season start
            # (leap years are not special-cased)
            .with_columns(
                t=self._days_in_season(
                    pl.col(date_column),
                    season_start_month=self.season["start_month"],
                    season_start_day=self.season["start_day"],
                )
                / 365
            )
        )

        # set up encoder
        self.features = ("season", "geography", "season_geo")
        self.enc = OrdinalEncoder(dtype=np.int64).fit(
            self.data.select(self.features).to_numpy()
        )
        self.n_feature_levels = [len(x) for x in self.enc.categories_]  # type: ignore

        # initialize MCMC. `None` is a placeholder indicating fitting has not occurred
        self.mcmc = None

    @staticmethod
    def _days_in_season(
        date: pl.Expr, season_start_month: int, season_start_day: int
    ) -> pl.Expr:
        """Extract a time elapsed column from a date column, as polars expressions.

        Args:
            date: Dates
            season_start_month: First month of the overwinter disease season.
            season_start_day: First day of the first month of the overwinter disease season.

        Returns:
            number of days elapsed since the most recent season start
        """
        # for every date, figure out the season breakpoint in that year
        season_start = pl.date(date.dt.year(), season_start_month, season_start_day)

        # for dates before the season breakpoint in year, subtract a year
        year = date.dt.year()
        season_start_year = pl.when(date < season_start).then(year - 1).otherwise(year)

        # rewrite the season breakpoints to that immediately before each date
        season_start = pl.date(season_start_year, season_start_month, season_start_day)

        # return the number of days from season start to each date
        return (date - season_start).dt.total_days()

    def model(self, data: pl.DataFrame):
        """Build the NumPyro model call for a given data slice.

        Args:
            data: Preprocessed data with columns t, N_tot, feature columns, and
                optionally N_vax for observed likelihood evaluation.
        """
        # when N_vax is absent the likelihood is left unobserved (prior predictive)
        if "N_vax" in data.columns:
            N_vax = jnp.array(data["N_vax"])
        else:
            N_vax = None

        return self._logistic_plus_linear(
            N_vax=N_vax,
            t=jnp.array(data["t"]),
            # jax runs into a problem if you don't specify this type
            N_tot=jnp.array(data["N_tot"], dtype=jnp.int32),
            feature_levels=jnp.array(
                self.enc.transform(data.select(self.features).to_numpy())
            ),
            **self.model_params,
        )

    def _logistic_plus_linear(
        self,
        N_vax: jnp.ndarray | None,
        t: jnp.ndarray,
        N_tot: jnp.ndarray,
        feature_levels: jnp.ndarray,
        muA_shape1: float,
        muA_shape2: float,
        sigmaA_rate: float,
        tau_shape1: float,
        tau_shape2: float,
        K_shape: float,
        K_rate: float,
        muM_shape: float,
        muM_rate: float,
        sigmaM_rate: float,
        D_shape: float,
        D_rate: float,
    ):
        """
        Logistic Plus Linear model

        Args:
            t: Fraction of a year elapsed since the start of season at each data point.
            N_vax: Number of people vaccinated at each data point, or `None`.
            N_tot: Total number of people in the population at each data point.
            feature_levels: Numeric codes for feature levels: row = data point, col = feature.
            muA_shape1: Beta distribution shape1 parameter for muA prior.
            muA_shape2: Beta distribution shape2 parameter for muA prior.
            sigmaA_rate: Exponential distribution rate parameter for sigmaA prior.
            tau_shape1: Beta distribution shape1 parameter for tau prior.
            tau_shape2: Beta distribution shape2 parameter for tau prior.
            K_shape: Gamma distribution shape parameter for K prior.
            K_rate: Gamma distribution rate parameter for K prior.
            muM_shape: Gamma distribution shape parameter for muM prior.
            muM_rate: Gamma distribution rate parameter for muM prior.
            sigmaM_rate: Exponential distribution rate parameter for sigmaM prior.
            D_shape: Gamma distribution shape parameter for D prior.
            D_rate: Gamma distribution rate parameter for D prior.
        """
        # Sample the overall average value for each parameter
        muA = numpyro.sample("muA", dist.Beta(muA_shape1, muA_shape2))
        muM = numpyro.sample("muM", dist.Gamma(muM_shape, muM_rate))
        tau = numpyro.sample("tau", dist.Beta(tau_shape1, tau_shape2))
        K = numpyro.sample("K", dist.Gamma(K_shape, K_rate))
        D = numpyro.sample("D", dist.Gamma(D_shape, D_rate))

        # non-centered parameterization: per-feature scales plus standard-normal offsets
        sigmaA = numpyro.sample(
            "sigmaA", dist.Exponential(sigmaA_rate), sample_shape=(len(self.features),)
        )
        sigmaM = numpyro.sample(
            "sigmaM", dist.Exponential(sigmaM_rate), sample_shape=(len(self.features),)
        )
        zA = numpyro.sample(
            "zA", dist.Normal(0, 1), sample_shape=(sum(self.n_feature_levels),)
        )
        zM = numpyro.sample(
            "zM", dist.Normal(0, 1), sample_shape=(sum(self.n_feature_levels),)
        )

        v = self._vgt(
            t=t,
            feature_levels=feature_levels,
            muA=muA,
            sigmaA=sigmaA,
            zA=zA,
            muM=muM,
            sigmaM=sigmaM,
            zM=zM,
            K=K,
            tau=tau,
        )

        numpyro.sample("obs", dist.BetaBinomial(v * D, (1 - v) * D, N_tot), obs=N_vax)  # type: ignore

    def _vgt(self, t, feature_levels, muA, sigmaA, zA, muM, sigmaM, zM, K, tau):
        """Compute latent coverage trajectory v_g(t) for each row.

        Args:
            t: Time since season start in years.
            feature_levels: Encoded feature-level indices for each row.
            muA: Global intercept mean.
            sigmaA: Feature-level intercept scales.
            zA: Standard-normal offsets for intercept effects.
            muM: Global linear slope mean.
            sigmaM: Feature-level slope scales.
            zM: Standard-normal offsets for slope effects.
            K: Logistic growth rate.
            tau: Logistic midpoint.

        Returns:
            Vector of latent coverage means for each row.
        """
        # use jnp (not np) so this function stays traceable under jax
        # transformations when called from the NumPyro model; the repeat
        # counts (n_feature_levels) are static Python ints, so results are
        # numerically identical to the numpy versions
        deltaA = zA * jnp.repeat(sigmaA, jnp.array(self.n_feature_levels))
        deltaM = zM * jnp.repeat(sigmaM, jnp.array(self.n_feature_levels))

        A = muA + jnp.sum(deltaA[feature_levels], axis=1)
        M = muM + jnp.sum(deltaM[feature_levels], axis=1)

        return A / (1 + jnp.exp(-K * (t - tau))) + (M * t)  # type: ignore

    def fit(self) -> Self:
        """Fit a mixed Logistic Plus Linear model on training data.

        A hierarchical model is built with feature-level effects for logistic
        maximum and linear slope, which induce group-specific trajectories.
        Other parameters are non-hierarchical.

        Uses the data, features, model_params, and mcmc_params specified during
        initialization.

        Returns:
            Self with the fitted model stored in the mcmc attribute.
        """
        self.kernel = NUTS(self.model, init_strategy=init_to_sample)
        self.mcmc = MCMC(self.kernel, **self.mcmc_params)
        # train only on observations at or before the forecast date
        self.mcmc.run(
            self.fit_key,
            self.data.filter(pl.col(self.date_column) <= self.forecast_date),
        )

        if self.mcmc_params.get("progress_bar"):
            self.mcmc.print_summary()

        return self

    def predict(self) -> pl.DataFrame:
        """
        Make projections from a fit Logistic Plus Linear model.

        Returns:
            Forecast data frame with one row per input row and quantile. Rows
            are produced for every date in the input data; callers filter to
            dates after forecast_date as needed.

        Raises:
            AssertionError: If called before fit().
        """

        assert self.mcmc is not None, "Need to fit() first"

        quantile_preds = [
            pl.DataFrame({"quantile": q, "estimate": self._predict_quantile(q)})
            for q in self.quantiles
        ]

        # use the configured date column rather than hard-coding "time_end",
        # so non-default date_column values round-trip through prediction
        out_data = self.data.select("geography", "season", self.date_column).with_columns(
            forecast_date=self.forecast_date
        )

        return pl.concat(
            [pl.concat([out_data, qp], how="horizontal") for qp in quantile_preds],
            how="vertical",
        )

    def _predict_quantile(self, q: float) -> np.ndarray:
        """Evaluate the latent trajectory per posterior draw and take quantile q."""
        assert self.mcmc is not None
        # pull posterior samples for all parameters (except D) and get the desired quantile
        samples = self.mcmc.get_samples()
        n_samples = len(samples["K"])

        preds = np.stack(
            [
                self._vgt(
                    t=self.data["t"].to_numpy(),
                    feature_levels=self.enc.transform(
                        self.data.select(self.features).to_numpy()
                    ),
                    **{k: v[i,] for k, v in samples.items() if k != "D"},
                )
                for i in range(n_samples)
            ]
        )

        return np.quantile(preds, q=q, axis=0).astype(np.float64)

vcf.LPLModel.__init__(data, forecast_date, params, season, quantiles, date_column='time_end')

Initialize Logistic Plus Linear model.

Parameters:

Name Type Description Default
data DataFrame

Cumulative coverage data for fitting and prediction.

required
forecast_date date

Date used to split training and forecast periods.

required
params dict[str, Any]

Model and sampler settings. Includes prior hyperparameters consumed by _logistic_plus_linear and MCMC controls such as num_warmup, num_samples, num_chains, progress_bar, and seed.

required
season dict[str, Any]

Seasonal settings with start_month, start_day, end_month, and end_day.

required
quantiles list[float]

Posterior quantiles to report.

required
date_column str

Name of the date column in the data. Defaults to "time_end".

'time_end'
Source code in vcf/__init__.py
def __init__(
    self,
    data: pl.DataFrame,
    forecast_date: datetime.date,
    params: dict[str, Any],
    season: dict[str, Any],
    quantiles: list[float],
    date_column: str = "time_end",
):
    """Initialize Logistic Plus Linear model.

    Args:
        data: Cumulative coverage data for fitting and prediction.
        forecast_date: Date used to split training and forecast periods.
        params: Model and sampler settings. Includes prior hyperparameters
            consumed by _logistic_plus_linear and MCMC controls such as
            num_warmup, num_samples, num_chains, progress_bar, and seed.
        season: Seasonal settings with start_month, start_day, end_month,
            and end_day.
        quantiles: Posterior quantiles to report.
        date_column: Name of the date column in the data. Defaults to "time_end".

    Raises:
        KeyError: If params does not contain a "seed" entry.
        AssertionError: If data lacks the required columns.
    """
    self.raw_data = data
    self.date_column = date_column
    self.quantiles = quantiles
    self.forecast_date = forecast_date
    self.season = season

    # use parameters, separating MCMC and model fitting parameters
    self.params = params

    # only these keys are forwarded to MCMC; model params are matched by
    # name to the _logistic_plus_linear signature; any other keys in
    # params (except "seed", used below) are silently ignored
    mcmc_keys = {"num_warmup", "num_samples", "num_chains", "progress_bar"}
    self.mcmc_params = {k: v for k, v in params.items() if k in mcmc_keys}
    self.model_params = {
        k: v
        for k, v in params.items()
        if k in inspect.signature(self._logistic_plus_linear).parameters
    }
    # independent PRNG keys for fitting and prediction; params["seed"] is required
    self.fit_key, self.pred_key = random.split(random.key(self.params["seed"]), 2)

    # input data validation
    # NOTE(review): "N_tot" is read by the preprocessing below but is not
    # asserted here — confirm the caller guarantees its presence
    assert {self.date_column, "estimate", "season", "geography"}.issubset(
        self.raw_data.columns
    )

    # preprocess data
    self.data = (
        self.raw_data
        # prepare observation data
        .with_columns(N_vax=(pl.col("N_tot") * pl.col("estimate")).round(0))
        # add interaction term
        .with_columns(
            season_geo=pl.concat_str(["season", "geography"], separator="_")
        )
        # t = fraction of a 365-day year elapsed since season start
        # (leap years are not special-cased)
        .with_columns(
            t=self._days_in_season(
                pl.col(date_column),
                season_start_month=self.season["start_month"],
                season_start_day=self.season["start_day"],
            )
            / 365
        )
    )

    # set up encoder
    self.features = ("season", "geography", "season_geo")
    self.enc = OrdinalEncoder(dtype=np.int64).fit(
        self.data.select(self.features).to_numpy()
    )
    self.n_feature_levels = [len(x) for x in self.enc.categories_]  # type: ignore

    # initialize MCMC. `None` is a placeholder indicating fitting has not occurred
    self.mcmc = None

vcf.LPLModel.fit()

Fit a mixed Logistic Plus Linear model on training data.

A hierarchical model is built with feature-level effects for logistic maximum and linear slope, which induce group-specific trajectories. Other parameters are non-hierarchical.

Uses the data, features, model_params, and mcmc_params specified during initialization.

Returns:

Type Description
Self

Self with the fitted model stored in the mcmc attribute.

Source code in vcf/__init__.py
def fit(self) -> Self:
    """Fit a mixed Logistic Plus Linear model on training data.

    Builds a hierarchical NUTS sampler in which feature-level effects on
    the logistic maximum and the linear slope induce group-specific
    trajectories; the remaining parameters are non-hierarchical.

    Uses the data, features, model_params, and mcmc_params captured at
    initialization.

    Returns:
        Self, with the fitted sampler stored in the mcmc attribute.
    """
    # train only on observations at or before the forecast date
    training_data = self.data.filter(pl.col(self.date_column) <= self.forecast_date)

    self.kernel = NUTS(self.model, init_strategy=init_to_sample)
    self.mcmc = MCMC(self.kernel, **self.mcmc_params)
    self.mcmc.run(self.fit_key, training_data)

    # echo a posterior summary when the caller asked for progress output
    if self.mcmc_params.get("progress_bar"):
        self.mcmc.print_summary()

    return self

vcf.LPLModel.model(data)

Build the NumPyro model call for a given data slice.

Parameters:

Name Type Description Default
data DataFrame

Preprocessed data with columns t, N_tot, feature columns, and optionally N_vax for observed likelihood evaluation.

required
Source code in vcf/__init__.py
def model(self, data: pl.DataFrame):
    """Build the NumPyro model call for one slice of preprocessed data.

    Args:
        data: Data frame carrying t, N_tot, and the feature columns; when
            an N_vax column is present it is used as the observed
            likelihood target, otherwise the likelihood is unobserved.
    """
    observed = jnp.array(data["N_vax"]) if "N_vax" in data.columns else None
    encoded_levels = jnp.array(
        self.enc.transform(data.select(self.features).to_numpy())
    )

    return self._logistic_plus_linear(
        N_vax=observed,
        t=jnp.array(data["t"]),
        # jax needs an explicit integer dtype here
        N_tot=jnp.array(data["N_tot"], dtype=jnp.int32),
        feature_levels=encoded_levels,
        **self.model_params,
    )

vcf.LPLModel.predict()

Make projections from a fit Logistic Plus Linear model.

Returns:

Type Description
DataFrame

Sample forecast data frame with predictions for dates after forecast_date.

Source code in vcf/__init__.py
def predict(self) -> pl.DataFrame:
    """
    Make projections from a fit Logistic Plus Linear model.

    Returns:
        Forecast data frame with one row per input row and quantile. Rows
        are produced for every date in the input data; callers filter to
        dates after forecast_date as needed.

    Raises:
        AssertionError: If called before fit().
    """

    assert self.mcmc is not None, "Need to fit() first"

    quantile_preds = [
        pl.DataFrame({"quantile": q, "estimate": self._predict_quantile(q)})
        for q in self.quantiles
    ]

    # use the configured date column rather than hard-coding "time_end",
    # so non-default date_column values round-trip through prediction
    out_data = self.data.select("geography", "season", self.date_column).with_columns(
        forecast_date=self.forecast_date
    )

    return pl.concat(
        [pl.concat([out_data, qp], how="horizontal") for qp in quantile_preds],
        how="vertical",
    )

vcf.RFModel

Bases: CoverageModel

Source code in vcf/__init__.py
class RFModel(CoverageModel):
    """Random forest coverage model.

    Pivots coverage estimates into one column per month-in-season, trains a
    random forest to predict later months from earlier ones, and reports
    quantiles of the per-tree predictions.
    """

    def __init__(
        self,
        data: pl.DataFrame,
        params: dict[str, Any],
        season: dict[str, Any],
        forecast_date: datetime.date,
        quantiles: list[float],
        date_column: str = "time_end",
    ):
        """Initialize random-forest forecasting model and feature matrices.

        Args:
            data: Input coverage data with season, geography, date, and estimate.
            params: Random forest settings and auxiliary configuration.
            season: Seasonal settings used to compute month index.
            forecast_date: Cutoff date for train/predict split. Must be the
                first day of a month (enforced by _month_in_season).
            quantiles: Quantiles to compute from tree-level predictions.
            date_column: Date column name. Defaults to "time_end".
        """
        self.raw_data = data
        self.date_column = date_column
        self.forecast_date = forecast_date
        self.quantiles = quantiles
        self.season = season
        self.params = params

        # forward the recognized random forest hyperparameters when present;
        # previously only n_estimators was forwarded even though the other
        # keys were documented here
        rf_keys = {"n_estimators", "max_depth", "min_samples_split", "min_samples_leaf"}
        self.rf_params = {k: v for k, v in params.items() if k in rf_keys}

        # map each date to its zero-based month index within the season;
        # return_dtype makes the element-wise call's output type explicit
        data_t = self.raw_data.with_columns(
            t=pl.col(self.date_column).map_elements(
                self._month_in_season, return_dtype=pl.Int64
            )
        ).sort(["season", "geography", "t"])

        # preprocessing: crosswalk from (season, t) back to the calendar date
        self.date_crosswalk = data_t.select("season", date_column, "t").unique()

        # wide layout: one row per (season, geography), one column per month
        # index, with missing cells imputed
        self.data = (
            data_t.select(["season", "geography", "t", "estimate"])
            .pivot(on="t", values="estimate", sort_columns=True)
            .sort(["season", "geography"])
            .pipe(self._impute)
        )

        # season label containing the forecast date, via the to_season helper
        self.forecast_season = pl.select(
            to_season(
                pl.lit(self.forecast_date),
                season_start_month=self.season["start_month"],
                season_end_month=self.season["end_month"],
                season_end_day=self.season["end_day"],
                season_start_day=self.season["start_day"],
            )
        ).item()
        self.forecast_month = self._month_in_season(self.forecast_date)

    @staticmethod
    def _impute(
        df: pl.DataFrame, index_cols: tuple[str, ...] = ("season", "geography")
    ):
        """Impute missing estimates using nearest-neighbor imputation.

        Args:
            df: Wide data frame to impute.
            index_cols: Non-numeric identifying columns to preserve.

        Returns:
            Data frame with missing numeric values imputed.

        Raises:
            AssertionError: If any nulls remain after imputation.
        """
        to_impute_df = df.drop(index_cols)
        imputed_np = KNNImputer(n_neighbors=2).fit_transform(to_impute_df.to_numpy())
        imputed_df = pl.concat(
            [
                df.select(index_cols),
                pl.DataFrame(imputed_np, schema=to_impute_df.columns),
            ],
            how="horizontal",
        )
        assert imputed_df.null_count().sum_horizontal().item() == 0, (
            "Null remaining in data"
        )
        return imputed_df

    def _month_in_season(self, date: datetime.date) -> int:
        """Convert a date into zero-based month index within the season.

        Args:
            date: First day of a month.

        Returns:
            Month offset from season start.

        Raises:
            AssertionError: If date is not the first day of a month.
        """
        assert date.day == 1
        year = date.year
        # start of a season that's in this year
        ssiy = datetime.date(year, self.season["start_month"], self.season["start_day"])

        # season start year: dates before this year's breakpoint belong to
        # the season that started the previous year
        if date < ssiy:
            ssy = year - 1
        else:
            ssy = year

        return (year - ssy) * 12 + (date.month - self.season["start_month"])

    def fit(self) -> Self:
        """Fit random forest on seasons before the forecast season.

        Returns:
            Self with fitted encoder and random forest model.
        """
        self.enc = _RFEncoder().fit(self.data)

        # predictors: identifiers plus observed months up to and including
        # the forecast month; targets: the remaining months of the season
        self.X_features = ["season", "geography"] + [
            str(t)
            for t in range(0, self.forecast_month + 1)
            if str(t) in self.data.columns
        ]
        self.y_features = [
            str(t)
            for t in range(self.forecast_month + 1, 12)
            if str(t) in self.data.columns
        ]

        # fit the model
        data_fit = self.data.filter(pl.col("season") < self.forecast_season)
        X_fit = self.enc.encode(data_fit.select(self.X_features))
        y_fit = data_fit.select(self.y_features).to_numpy()

        # sklearn complains if you pass a column vector rather than a 1d array
        if y_fit.shape[1] == 1:
            y_fit = y_fit.ravel()

        self.model = RandomForestRegressor(**self.rf_params).fit(X_fit, y_fit)

        return self

    def predict(self) -> pl.DataFrame:
        """Generate quantile forecasts using fitted random forest trees.

        Returns:
            Long-format forecast data frame with quantile-specific estimates.

        Raises:
            AssertionError: If no rows fall in the forecast season.
        """
        # make the forecast
        data_pred = self.data.filter(pl.col("season") >= self.forecast_season)

        X_data = data_pred.select(self.X_features)
        assert X_data.shape[0] > 0, f"RF prediction for {self.forecast_date} failed."
        X_pred = self.enc.encode(X_data)

        # per-tree predictions form an empirical distribution per cell,
        # from which quantiles are read off
        y_tree = np.stack([tree.predict(X_pred) for tree in self.model.estimators_])

        return pl.concat(
            [
                self._postprocess(
                    data_pred=data_pred,
                    y_pred=np.quantile(y_tree, q=q, axis=0),
                    quantile=q,
                )
                for q in self.quantiles
            ]
        )

    def _postprocess(
        self, data_pred: pl.DataFrame, y_pred: np.ndarray, quantile: float
    ) -> pl.DataFrame:
        """Reshape RF output into standardized forecast schema.

        Args:
            data_pred: Input rows being forecast.
            y_pred: Predicted month values in wide numeric array form.
            quantile: Quantile label attached to the prediction.

        Returns:
            Long-format forecast data frame with season/date metadata.
        """
        # single-target models return a 1d array; restore the column shape
        if y_pred.ndim == 1:
            y_pred = y_pred.reshape(-1, 1)

        return (
            data_pred.select(["season", "geography"])
            .hstack(pl.DataFrame(y_pred, schema=self.y_features))
            .unpivot(
                on=self.y_features,
                index=["season", "geography"],
                variable_name="t",
                value_name="estimate",
            )
            .with_columns(pl.col("t").cast(pl.Int64))
            .join(self.date_crosswalk, on=["season", "t"], how="left")
            .drop("t")
            .with_columns(forecast_date=self.forecast_date, quantile=quantile)
        )

vcf.RFModel.__init__(data, params, season, forecast_date, quantiles, date_column='time_end')

Initialize random-forest forecasting model and feature matrices.

Parameters:

Name Type Description Default
data DataFrame

Input coverage data with season, geography, date, and estimate.

required
params dict[str, Any]

Random forest settings and auxiliary configuration.

required
season dict[str, Any]

Seasonal settings used to compute month index.

required
forecast_date date

Cutoff date for train/predict split.

required
quantiles list[float]

Quantiles to compute from tree-level predictions.

required
date_column str

Date column name. Defaults to "time_end".

'time_end'
Source code in vcf/__init__.py
def __init__(
    self,
    data: pl.DataFrame,
    params: dict[str, Any],
    season: dict[str, Any],
    forecast_date: datetime.date,
    quantiles: list[float],
    date_column: str = "time_end",
):
    """Initialize random-forest forecasting model and feature matrices.

    Args:
        data: Input coverage data with season, geography, date, and estimate.
        params: Random forest settings and auxiliary configuration.
        season: Seasonal settings used to compute month index.
        forecast_date: Cutoff date for train/predict split. Must be the first
            day of a month (enforced by _month_in_season).
        quantiles: Quantiles to compute from tree-level predictions.
        date_column: Date column name. Defaults to "time_end".
    """
    self.raw_data = data
    self.date_column = date_column
    self.forecast_date = forecast_date
    self.quantiles = quantiles
    self.season = season
    self.params = params

    # other params include max_depth, min_samples_split, min_samples_leaf
    # NOTE(review): only n_estimators is actually forwarded to the random
    # forest; the keys listed above are ignored — confirm this is intended
    rf_keys = {"n_estimators"}
    self.rf_params = {k: v for k, v in params.items() if k in rf_keys}

    # map each date to its zero-based month index within the season
    # (element-wise Python call; no return_dtype is given, so polars infers it)
    data_t = self.raw_data.with_columns(
        t=pl.col(self.date_column).map_elements(self._month_in_season)
    ).sort(["season", "geography", "t"])

    # preprocessing
    # crosswalk from (season, t) back to the calendar date column
    self.date_crosswalk = data_t.select("season", date_column, "t").unique()

    # wide layout: one row per (season, geography), one column per month
    # index, with missing cells imputed
    self.data = (
        data_t.select(["season", "geography", "t", "estimate"])
        .pivot(on="t", values="estimate", sort_columns=True)
        .sort(["season", "geography"])
        .pipe(self._impute)
    )

    # season label containing the forecast date, via the to_season helper
    self.forecast_season = pl.select(
        to_season(
            pl.lit(self.forecast_date),
            season_start_month=self.season["start_month"],
            season_end_month=self.season["end_month"],
            season_end_day=self.season["end_day"],
            season_start_day=self.season["start_day"],
        )
    ).item()
    self.forecast_month = self._month_in_season(self.forecast_date)

vcf.RFModel.fit()

Fit random forest on seasons before the forecast season.

Returns:

Type Description
Self

Self with fitted encoder and random forest model.

Source code in vcf/__init__.py
def fit(self) -> Self:
    """Fit random forest on seasons before the forecast season.

    Returns:
        Self, with the fitted encoder (enc) and random forest (model) attached.
    """
    self.enc = _RFEncoder().fit(self.data)

    # predictors: identifiers plus every observed month up to and including
    # the forecast month; targets: the remaining months of the season
    available = set(self.data.columns)
    self.X_features = ["season", "geography"] + [
        str(month)
        for month in range(0, self.forecast_month + 1)
        if str(month) in available
    ]
    self.y_features = [
        str(month)
        for month in range(self.forecast_month + 1, 12)
        if str(month) in available
    ]

    # fit the model on seasons strictly before the forecast season
    data_fit = self.data.filter(pl.col("season") < self.forecast_season)
    X_fit = self.enc.encode(data_fit.select(self.X_features))
    y_fit = data_fit.select(self.y_features).to_numpy()

    # sklearn complains if you pass a column vector rather than a 1d array
    if y_fit.shape[1] == 1:
        y_fit = y_fit.ravel()

    self.model = RandomForestRegressor(**self.rf_params).fit(X_fit, y_fit)

    return self

vcf.RFModel.predict()

Generate quantile forecasts using fitted random forest trees.

Returns:

Type Description
DataFrame

Long-format forecast data frame with quantile-specific estimates.

Source code in vcf/__init__.py
def predict(self) -> pl.DataFrame:
    """Generate quantile forecasts using fitted random forest trees.

    Returns:
        Long-format forecast data frame with quantile-specific estimates.
    """
    # rows belonging to the forecast season (and any later ones)
    pred_rows = self.data.filter(pl.col("season") >= self.forecast_season)

    feature_df = pred_rows.select(self.X_features)
    assert feature_df.shape[0] > 0, f"RF prediction for {self.forecast_date} failed."
    encoded = self.enc.encode(feature_df)

    # one prediction per tree, stacked along a new leading axis so that
    # quantiles can be taken across trees
    per_tree = np.stack([est.predict(encoded) for est in self.model.estimators_])

    frames = [
        self._postprocess(
            data_pred=pred_rows,
            y_pred=np.quantile(per_tree, q=q, axis=0),
            quantile=q,
        )
        for q in self.quantiles
    ]
    return pl.concat(frames)

vcf.eos_abs_diff(obs, pred, features)

Calculate the absolute difference between observed data and prediction for the last date in a season.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `obs` | `DataFrame` | Observed data. | required |
| `pred` | `DataFrame` | Predicted data. | required |
| `features` | `list[str]` | Feature column names (must include 'season'). | required |

Returns:

Type Description
DataFrame

Data frame with absolute difference scores for end-of-season dates.

Source code in vcf/__init__.py
def eos_abs_diff(
    obs: pl.DataFrame, pred: pl.DataFrame, features: list[str]
) -> pl.DataFrame:
    """Calculate the absolute difference between observed data and prediction for the last date in a season.

    Args:
        obs: Observed data.
        pred: Predicted data.
        features: Feature column names (must include 'season').

    Returns:
        Data frame with absolute difference scores for end-of-season dates.
    """
    assert "season" in features

    # last observation within each feature group (the end-of-season date)
    last_obs = obs.filter(
        (pl.col("time_end") == pl.col("time_end").max()).over(features)
    )
    # score only the median forecast
    median_pred = pred.filter(pl.col("quantile") == 0.5)

    joined = median_pred.join(last_obs, on=["time_end"] + features, how="right")

    return (
        joined.rename({"estimate_right": "obs", "estimate": "pred"})
        .with_columns(
            score_value=(pl.col("pred") - pl.col("obs")).abs(),
            score_fun=pl.lit("eos_abs_diff"),
        )
        .select(features + ["model", "forecast_date", "score_fun", "score_value"])
        .drop_nulls()
    )

vcf.to_season(date, season_start_month, season_end_month, season_start_day=1, season_end_day=1)

Identify the overwinter season from a date.

Every year, there is a season end (e.g., May 1) and a season start (e.g., Sep 1). Dates before the season end are associated with the prior season (e.g., Feb 1, 2020 belongs to 2019/2020 season). Dates after the season start are associated with the next season (e.g., Oct 1, 2020 belongs to 2020/2021). Dates between the season end and season start are not in any season (e.g., June 1).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `date` | `Expr` | dates | required |
| `season_start_month` | `int` | first month | required |
| `season_end_month` | `int` | last month | required |
| `season_start_day` | `int` | first day | `1` |
| `season_end_day` | `int` | last day | `1` |

Returns:

Type Description
Expr

season like "2020/2021"

Source code in vcf/__init__.py
def to_season(
    date: pl.Expr,
    season_start_month: int,
    season_end_month: int,
    season_start_day: int = 1,
    season_end_day: int = 1,
) -> pl.Expr:
    """
    Identify the overwinter season from a date.

    Every year, there is a season end (e.g., May 1) and a season start (e.g., Sep 1).
    Dates before the season end are associated with the prior season (e.g., Feb 1, 2020
    belongs to 2019/2020 season). Dates after the season start are associated with the
    next season (e.g., Oct 1, 2020 belongs to 2020/2021). Dates between the season end
    and season start are not in any season (e.g., June 1).

    Args:
        date: dates
        season_start_month: first month
        season_end_month: last month
        season_start_day: first day
        season_end_day: last day

    Returns:
        season like "2020/2021"
    """
    # the start must fall later in the calendar year than the end, i.e. the
    # season wraps over the new year
    assert (season_start_month, season_start_day) > (
        season_end_month,
        season_end_day,
    ), "Only overwinter seasons are supported"

    year = date.dt.year()
    # this calendar year's season boundaries
    season_end = pl.date(year, season_end_month, season_end_day)
    season_start = pl.date(year, season_start_month, season_start_day)

    # first of the two years spanned by the season; null for off-season dates
    first_year = (
        pl.when(date <= season_end)
        .then(year - 1)
        .when(date >= season_start)
        .then(year)
        .otherwise(None)
    )

    return (
        pl.when(first_year.is_null())
        .then(None)
        .otherwise(pl.format("{}/{}", first_year, first_year + 1))
    )