surv_bart

class bart_survival.surv_bart.BartSurvModel(model_config: Dict = None, sampler_config: Dict = None)

BART Survival Model

Returns:

BartSurvModel

Return type:

BartSurvModel

bart_predict(X_pred: ndarray, coords: ndarray, size: int = None, rng: RandomState = None, **kwargs) DataArray

Derives posterior predictions on an updated dataset. Intended as an alternative prediction method for reloaded models.

Parameters:
  • X_pred (np.ndarray) – Covariate matrix in long-time format.

  • coords (np.ndarray) – Coordinates associated with long-time format.

  • size (int, optional) – Number of posterior draws to sample. Defaults to None.

  • rng (RandomState, optional) – Random number generator for reproducible results. Defaults to None.

Returns:

DataArray containing the predicted outputs from the model and covariate matrix.

Return type:

xr.DataArray

build_model(X: ndarray, y: ndarray, weights: ndarray, coords: ndarray, predictor_names: list, **kwargs)

Builds the PyMC base model.

Parameters:
  • X (np.ndarray) – Covariate matrix in long-time form.

  • y (np.ndarray) – Event status in long-time form.

  • weights (np.ndarray) – Array of weights.

  • coords (np.ndarray) – Array of coordinates associated with long-time form.

  • predictor_names (list) – List of names for variables.

Returns:

BartSurvModel

Return type:

BartSurvModel

fit(y: ndarray, X: ndarray, weights: ndarray = None, coords: ndarray = None, progressbar: bool = True, predictor_names: List[str] = None, random_seed: RandomState = None, **kwargs: Any) InferenceData

Builds the model and trains it on the data.

Parameters:
  • y (np.ndarray) – Event status in long-time form.

  • X (np.ndarray) – Covariate matrix in long-time form.

  • weights (np.ndarray) – Weights associated with each observation.

  • coords (np.ndarray) – Coordinates associated with long-time form.

  • progressbar (bool, optional) – Displays training progress. Defaults to True.

  • predictor_names (List[str], optional) – Names of covariates in X matrix. Defaults to None.

  • random_seed (RandomState, optional) – Seed. Defaults to None.

Returns:

Posterior data collected from training the model.

Return type:

az.InferenceData

static get_default_sampler_config() Dict

Default Sampler Configuration.

Returns:

Default Sampler Configuration.

Return type:

Dict

property id: str

Generate a unique hash value for the model.

The hash value is created using the last 16 characters of the SHA256 hash encoding, based on the model configuration, version, and model type.

Returns:

A string of length 16 characters containing a unique hash of the model.

Return type:

str
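
The hashing scheme described above can be sketched with the standard library. This is only an illustration, not the library's implementation: the function name make_model_id, the JSON serialization of the configuration, and the example configuration values are all assumptions.

```python
import hashlib
import json


def make_model_id(model_config: dict, version: str, model_type: str) -> str:
    """Hypothetical helper: 16-character id from config, version and model type."""
    # Serialize with sorted keys so that equal configurations hash identically.
    payload = json.dumps(model_config, sort_keys=True) + version + model_type
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    # Keep only the last 16 hexadecimal characters of the SHA256 digest.
    return digest[-16:]


# Example values are placeholders, not the package's defaults.
example_id = make_model_id({"trees": 50}, "0.1.0", "BartSurvModel")
```

Because the id depends only on the configuration, version, and model type, two models built with the same settings would share an id under this scheme.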

classmethod load(fname: str, treename: str)

Loads a saved model.

Parameters:
  • fname (str) – Path to saved model object.

  • treename (str) – Path to saved tree object.

Returns:

Object of the BartSurvModel class.

Return type:

BartSurvModel

sample_model(**kwargs) InferenceData

Initiates training/sampling of the model.

Returns:

Posterior data collected from training the model.

Return type:

az.InferenceData

sample_posterior_predictive(X_pred: ndarray, coords: ndarray, extend_idata: bool = False, **kwargs) Dataset

Derives posterior predictions on updated datasets.

Parameters:
  • X_pred (np.ndarray) – Covariate matrix in long-time format.

  • coords (np.ndarray) – Coordinates associated with long-time format.

  • extend_idata (bool, optional) – Adds results to the existing idata object. Defaults to False.

Returns:

Dataset containing the predicted outputs from the model and covariate matrix.

Return type:

xr.Dataset

save(idata_name: str, all_tree_name: str) None

Saves a trained model and tree object.

Parameters:
  • idata_name (str) – Path at which to save the model object.

  • all_tree_name (str) – Path at which to save the tree object.

set_idata_attrs(idata: InferenceData | None = None) InferenceData

Sets the additional information in the idata object.

Parameters:

idata (Optional[az.InferenceData], optional) – InferenceData object. Defaults to None.

Returns:

The InferenceData object.

Return type:

az.InferenceData

bart_survival.surv_bart.get_pdp(x: ndarray, var_col: List[int] = [], values: List[int | float] = [], qt: List[float] = [0.25, 0.5, 0.75], sample_n=None) ndarray | Dict

Generates data for Partial Dependency Plots.

Parameters:
  • x (np.ndarray) – Covariate matrix.

  • var_col (List[int], optional) – Covariate columns for which to generate PDP values. Defaults to [].

  • values (List[Union[int,float]], optional) – Values to test on for each covariate. Defaults to [].

  • qt (List[float], optional) – Quantiles to generate values on if none are given. Defaults to [0.25,0.5,0.75].

  • sample_n (_type_, optional) – Number of observations to sample from x. Useful for working with large datasets. Defaults to None.

Returns:

PDP covariate matrix and dictionary for tracking PDP components.

Return type:

Union[np.ndarray, Dict]
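
The general idea behind a partial dependence covariate matrix can be sketched in NumPy: the covariate matrix is copied once per test value, with the chosen column overwritten by that value. The sketch below handles a single column, and the function name build_pdp_matrix and the layout of the tracking dictionary are assumptions rather than what get_pdp actually returns.

```python
import numpy as np


def build_pdp_matrix(x, var_col, qt=(0.25, 0.5, 0.75)):
    """Hypothetical helper: stack one copy of x per quantile value of a column."""
    # Test values are taken from the quantiles of the selected column.
    values = np.quantile(x[:, var_col], qt)
    blocks = []
    for v in values:
        x_copy = x.copy()
        x_copy[:, var_col] = v  # hold the covariate fixed at the test value
        blocks.append(x_copy)
    n = x.shape[0]
    # Record which rows of the stacked matrix belong to each test value.
    index = {i: (i * n, (i + 1) * n) for i in range(len(values))}
    return np.vstack(blocks), {"values": values, "index": index}


x = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]])
pdp_x, pdp_info = build_pdp_matrix(x, var_col=0)
```

Predictions made on each block can then be averaged over observations to estimate the marginal effect of the covariate at each test value.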

bart_survival.surv_bart.get_posterior_test(y_time: ndarray, y_status: ndarray, x: ndarray, time_scale=None) Dict

Generates long-time format for posterior distribution testing.

To analyze the posterior distribution, posterior predictive estimates need to be generated. Similar to the training data, the data for posterior predictions must also be in a long-time format.

Parameters:
  • y_time (np.ndarray) – Event time.

  • y_status (np.ndarray) – Event status indicator (1 if event occurs, 0 if censored).

  • x (np.ndarray) – Covariate matrix.

  • time_scale (np.ndarray) – Time scaling factor.

Returns:

Covariate matrix in long-time format and associated coordinates for the long-time format.

Return type:

Dict
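
One common way to build a long-time prediction set is to repeat every observation once per unique event time, so that a full survival curve can be predicted for each observation. The NumPy sketch below assumes that layout and assumes time is stored as the first column; the function name to_long_time_predict is hypothetical, and get_posterior_test may arrange its output differently.

```python
import numpy as np


def to_long_time_predict(y_time, x):
    """Hypothetical helper: repeat each observation across all unique times."""
    times = np.unique(y_time)
    n, k = x.shape[0], times.shape[0]
    # The time grid cycles through every unique time for each observation.
    time_col = np.tile(times, n).reshape(-1, 1)
    # Each covariate row is repeated once per time point.
    x_rep = np.repeat(x, k, axis=0)
    # Coordinates map each long-format row back to its source observation.
    coords = np.repeat(np.arange(n), k)
    return np.hstack([time_col, x_rep]), coords


y_time = np.array([2, 3, 1])
x = np.array([[0.5], [1.5], [2.5]])
x_long, coords = to_long_time_predict(y_time, x)
```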

bart_survival.surv_bart.get_surv_pre_train(y_time: ndarray, y_status: ndarray, x: ndarray, weight: ndarray | None = None, time_scale=None) Dict

Generates long-time formatted event status and covariate matrix.

BartSurvModel operates on a discrete-time format: each observation is expanded into a series of rows, one for each time point up to its event time.

Parameters:
  • y_time (np.ndarray) – Event time.

  • y_status (np.ndarray) – Event status indicator (1 if event occurs, 0 if censored).

  • x (np.ndarray) – Covariate matrix.

  • weight (Optional[np.ndarray], optional) – Weights associated with each observation. If none are provided, each observation is given a weight of 1. Defaults to None.

  • time_scale (Optional[int]) – Time scaling factor.

Returns:

Dictionary containing all of the training data, including an event status array, covariate matrix, weights, and coordinates associated with the long-time format.

Return type:

Dict
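
The expansion into long-time format can be illustrated with a small NumPy sketch. It assumes that the time point is stored as the first covariate column and that the event indicator is 1 only in the row where the event occurs; the function name to_long_time is hypothetical, and the sketch ignores weights and time scaling.

```python
import numpy as np


def to_long_time(y_time, y_status, x):
    """Hypothetical helper: one row per observation per time point up to its event time."""
    times = np.unique(y_time)
    rows_y, rows_x, rows_id = [], [], []
    for i, (t_i, s_i) in enumerate(zip(y_time, y_status)):
        for t in times[times <= t_i]:
            # The indicator is 1 only at the event time of an uncensored observation.
            rows_y.append(int(s_i == 1 and t == t_i))
            rows_x.append(np.concatenate(([t], x[i])))
            rows_id.append(i)  # coordinate back to the original observation
    return np.array(rows_y), np.vstack(rows_x), np.array(rows_id)


y_time = np.array([2, 3, 1])
y_status = np.array([1, 0, 1])
x = np.array([[0.5], [1.5], [2.5]])
y_long, x_long, coords = to_long_time(y_time, y_status, x)
```

In this example the censored observation (event time 3, status 0) contributes three rows, all with an indicator of 0.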

bart_survival.surv_bart.get_sv_mean_quant(sv: ndarray, msk: ndarray, draws: bool = True, qntile: List[float] = [0.025, 0.975]) Dict

Generates mean and quantile estimates of survival probabilities.

Parameters:
  • sv (np.ndarray) – Survival probabilities.

  • msk (np.ndarray) – Mask for selecting values to average (class of a covariate).

  • draws (bool, optional) – If true, will average over the draws of the posterior, as well as masked values. Defaults to True.

  • qntile (List[float], optional) – Quantiles to average over. Defaults to [0.025, 0.975].

Returns:

Mask True mean, Mask True quantiles, Mask False mean, Mask False quantiles.

Return type:

Dict
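
A masked summary of this kind can be sketched in NumPy, assuming the survival array has shape (draws, observations, times) and the mask is a boolean array over observations. The function name and the dictionary keys below are hypothetical and need not match what get_sv_mean_quant returns.

```python
import numpy as np


def masked_mean_quantile(sv, msk, qntile=(0.025, 0.975)):
    """Hypothetical helper: mean and quantile curves for mask-True and mask-False groups."""
    out = {}
    for label, m in (("mt", msk), ("mf", ~msk)):
        # Average over the observations in the group: one curve per draw.
        per_draw = sv[:, m, :].mean(axis=1)
        # Summarize across posterior draws.
        out[label + "_mean"] = per_draw.mean(axis=0)
        out[label + "_quantile"] = np.quantile(per_draw, qntile, axis=0)
    return out


rng = np.random.default_rng(0)
sv = rng.uniform(0.0, 1.0, size=(100, 6, 4))  # synthetic values for illustration only
msk = np.array([True, True, False, False, True, False])
summary = masked_mean_quantile(sv, msk)
```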

bart_survival.surv_bart.get_sv_prob(post: DataArray | Dataset) Dict

Generates the Survival Probability estimates for time points.

Parameters:

post (Union[xr.DataArray, xr.Dataset]) – Posterior output.

Returns:

Risk Probability (Hazard) and Survival Probability for each draw, observation and time.

Return type:

Dict
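
In a discrete-time model, the survival probability at a time point is the product of one minus the hazard over all time points up to and including it. The sketch below shows that relationship in NumPy, assuming the hazards have already been reshaped to (draws, observations, times); the function name hazard_to_survival is hypothetical.

```python
import numpy as np


def hazard_to_survival(prob):
    """Hypothetical helper: cumulative product of (1 - hazard) along the time axis."""
    return np.cumprod(1.0 - prob, axis=-1)


# One draw, one observation, three time points.
prob = np.array([[[0.1, 0.2, 0.5]]])
sv = hazard_to_survival(prob)
```

Here the survival probabilities are 0.9, 0.9 × 0.8 = 0.72, and 0.72 × 0.5 = 0.36.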

bart_survival.surv_bart.pdp_diff_metric(pdp_val: Dict, idx: ndarray, qntile: List[float] = [0.025, 0.975]) Dict

Generates an estimate of the marginal difference from the PDP posterior.

Parameters:
  • pdp_val (Dict) – Posterior of PDP predictions.

  • idx (np.ndarray) – Index of PDP sets.

  • qntile (List[float], optional) – Quantile values for Credible Interval. Defaults to [0.025, 0.975].

Returns:

Survival Probability Difference Mean, Survival Probability Difference Quantiles.

Return type:

Dict
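
A marginal survival difference between two PDP sets can be sketched as follows, assuming each set is an array of survival probabilities with shape (draws, observations, times). The function name, argument names, and dictionary keys are hypothetical; pdp_diff_metric takes a dictionary and an index instead.

```python
import numpy as np


def survival_difference(sv_a, sv_b, qntile=(0.025, 0.975)):
    """Hypothetical helper: per-draw difference in mean survival between two PDP sets."""
    # Average over observations within each set, then difference draw by draw.
    diff = sv_b.mean(axis=1) - sv_a.mean(axis=1)  # shape (draws, times)
    return {
        "diff_mean": diff.mean(axis=0),
        "diff_quantile": np.quantile(diff, qntile, axis=0),
    }


# Constant synthetic inputs make the result easy to check by hand.
sv_a = np.full((50, 3, 2), 0.8)
sv_b = np.full((50, 3, 2), 0.6)
result = survival_difference(sv_a, sv_b)
```

Taking the difference within each draw before summarizing keeps the quantiles interpretable as a credible interval for the difference itself.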

bart_survival.surv_bart.pdp_rr_metric(pdp_val: Dict, idx: ndarray, qntile: list = [0.025, 0.975]) Dict

Generates a Risk Ratio (Hazard Ratio) from PDP values.

Parameters:
  • pdp_val (Dict) – Posterior of PDP predictions.

  • idx (np.ndarray) – Index of PDP sets.

  • qntile (list, optional) – Quantile values for Credible Interval. Defaults to [0.025, 0.975].

Returns:

Risk Ratio Mean, Risk Ratio Quantiles.

Return type:

Dict
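
A risk ratio between two PDP sets can be sketched in the same way, assuming each set is an array of risk probabilities with shape (draws, observations, times). The function name, argument names, and dictionary keys are hypothetical and are not taken from pdp_rr_metric.

```python
import numpy as np


def risk_ratio(prob_a, prob_b, qntile=(0.025, 0.975)):
    """Hypothetical helper: per-draw ratio of mean risk between two PDP sets."""
    # Average over observations within each set, then take the ratio draw by draw.
    rr = prob_b.mean(axis=1) / prob_a.mean(axis=1)  # shape (draws, times)
    return {
        "rr_mean": rr.mean(axis=0),
        "rr_quantile": np.quantile(rr, qntile, axis=0),
    }


# Constant synthetic inputs make the result easy to check by hand.
prob_a = np.full((50, 3, 2), 0.1)
prob_b = np.full((50, 3, 2), 0.2)
result = risk_ratio(prob_a, prob_b)
```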