surv_bart
- class bart_survival.surv_bart.BartSurvModel(model_config: Dict = None, sampler_config: Dict = None)
BART Survival Model
- Returns:
BartSurvModel
- Return type:
BartSurvModel
- bart_predict(X_pred: ndarray, coords: ndarray, size: int = None, rng: RandomState = None, **kwargs) DataArray
Derives posterior predictions for an updated dataset. An alternative prediction method intended for re-loaded models.
- Parameters:
X_pred (np.ndarray) – Covariate matrix in long-time format.
coords (np.ndarray) – Coordinates associated with long-time format.
size (int, optional) – Number of posterior draws to sample. Defaults to None.
rng (RandomState, optional) – Random number generator for reproducible results. Defaults to None.
- Returns:
DataArray containing the predicted outputs from the model and covariate matrix.
- Return type:
xr.DataArray
- build_model(X: ndarray, y: ndarray, weights: ndarray, coords: ndarray, predictor_names: list, **kwargs)
Builds the PyMC base model.
- Parameters:
X (np.ndarray) – Covariate matrix in long-time form.
y (np.ndarray) – Event status in long-time form.
weights (np.ndarray) – Array of weights.
coords (np.ndarray) – Array of coordinates associated with long-time form.
predictor_names (list) – List of names for variables.
- Returns:
BartSurvModel
- Return type:
BartSurvModel
- fit(y: ndarray, X: ndarray, weights: ndarray = None, coords: ndarray = None, progressbar: bool = True, predictor_names: List[str] = None, random_seed: RandomState = None, **kwargs: Any) InferenceData
Builds the model and trains it on the data.
- Parameters:
y (np.ndarray) – Event status in long-time form.
X (np.ndarray) – Covariate matrix in long-time form.
weights (np.ndarray, optional) – Weights associated with each observation. Defaults to None.
coords (np.ndarray, optional) – Coordinates associated with long-time form. Defaults to None.
progressbar (bool, optional) – Displays training progress. Defaults to True.
predictor_names (List[str], optional) – Names of covariates in X matrix. Defaults to None.
random_seed (RandomState, optional) – Seed for the random number generator. Defaults to None.
- Returns:
Posterior data collected from training the model.
- Return type:
az.InferenceData
- static get_default_sampler_config() Dict
Default Sampler Configuration.
- Returns:
Default Sampler Configuration.
- Return type:
Dict
- property id: str
Generate a unique hash value for the model.
The hash value is created using the last 16 characters of the SHA256 hash encoding, based on the model configuration, version, and model type.
- Returns:
A string of length 16 characters containing a unique hash of the model.
- Return type:
str
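The hashing scheme described for id can be sketched as follows. This is a minimal illustration and not the library's code: the helper name sketch_model_id, the JSON serialization, the configuration keys, and the version and model-type strings are all assumptions made for the example.

```python
import hashlib
import json


def sketch_model_id(model_config: dict, version: str, model_type: str) -> str:
    """Hypothetical helper: hash config, version, and model type into a 16-character id."""
    # Serialize with sorted keys so that equal configurations give equal ids.
    payload = json.dumps(model_config, sort_keys=True) + version + model_type
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    # Keep only the last 16 hexadecimal characters of the SHA256 digest.
    return digest[-16:]


# Hypothetical configuration keys and identifiers, for illustration only.
example_id = sketch_model_id({"trees": 50, "split_rules": None}, "0.1.0", "BART_Survival")
```

Because the id depends only on the configuration, version, and model type, two models built with the same settings share an id, while changing any of the three inputs changes it.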
- classmethod load(fname: str, treename: str)
Loads a saved model.
- Parameters:
fname (str) – Path to saved model object.
treename (str) – Path to saved tree object.
- Returns:
Instance of the BartSurvModel class.
- Return type:
BartSurvModel
- sample_model(**kwargs) InferenceData
Initiates training/sampling of the model.
- Returns:
Posterior data collected from training the model.
- Return type:
az.InferenceData
- sample_posterior_predictive(X_pred: ndarray, coords: ndarray, extend_idata: bool = False, **kwargs) Dataset
Derives posterior predictions on updated datasets.
- Parameters:
X_pred (np.ndarray) – Covariate matrix in long-time format.
coords (np.ndarray) – Coordinates associated with long-time format.
extend_idata (bool, optional) – Adds results to the existing idata object. Defaults to False.
- Returns:
Dataset containing the predicted outputs from the model and covariate matrix.
- Return type:
xr.Dataset
- save(idata_name: str, all_tree_name: str) None
Saves a trained model and tree object.
- Parameters:
idata_name (str) – Path at which to save the model object.
all_tree_name (str) – Path at which to save the tree object.
- set_idata_attrs(idata: InferenceData | None = None) InferenceData
Sets the additional information in the idata object.
- Parameters:
idata (Optional[az.InferenceData], optional) – InferenceData object to annotate. Defaults to None.
- Returns:
InferenceData object with the additional attributes set.
- Return type:
az.InferenceData
- bart_survival.surv_bart.get_pdp(x: ndarray, var_col: List[int] = [], values: List[int | float] = [], qt: List[float] = [0.25, 0.5, 0.75], sample_n=None) ndarray | Dict
Generates data for Partial Dependency Plots.
- Parameters:
x (np.ndarray) – Covariate matrix.
var_col (List[int], optional) – Covariate columns for which to generate PDP values. Defaults to [].
values (List[Union[int,float]], optional) – Values to test for each covariate. Defaults to [].
qt (List[float], optional) – Quantiles used to generate values if none are given. Defaults to [0.25,0.5,0.75].
sample_n (int, optional) – Number of rows to subsample from x. Useful for working with large datasets. Defaults to None.
- Returns:
PDP covariate matrix and dictionary for tracking PDP components.
- Return type:
Union[np.ndarray, Dict]
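The general idea behind a PDP covariate matrix can be sketched with NumPy. This is an illustration of the standard technique under assumptions, not the implementation of get_pdp: the helper sketch_pdp_matrix and the shape of its index output are hypothetical, and the actual function returns its bookkeeping as a dictionary.

```python
import numpy as np


def sketch_pdp_matrix(x, col, values):
    """Illustrative PDP construction: one full copy of x per test value."""
    blocks, idx = [], []
    for j, v in enumerate(values):
        x_v = x.copy()
        x_v[:, col] = v  # hold the chosen covariate fixed at v for every row
        blocks.append(x_v)
        idx.append(np.full(x.shape[0], j))  # record which value each row belongs to
    return np.vstack(blocks), np.concatenate(idx)


x = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
# When no values are supplied, fall back to quantiles of the column.
values = np.quantile(x[:, 0], [0.25, 0.5, 0.75])
x_pdp, pdp_idx = sketch_pdp_matrix(x, 0, values)
```

Each block of rows differs from the original data only in the chosen column, so averaging predictions within a block isolates the marginal effect of that value.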
- bart_survival.surv_bart.get_posterior_test(y_time: ndarray, y_status: ndarray, x: ndarray, time_scale=None) Dict
Generates long-time format for posterior distribution testing.
To analyze the posterior distribution, posterior predictive estimates need to be generated. Similar to the training data, the data for posterior predictions must also be in a long-time format.
- Parameters:
y_time (np.ndarray) – Event time.
y_status (np.ndarray) – Event status indicator (1 if event occurs, 0 if censored).
x (np.ndarray) – Covariate matrix.
time_scale (Optional[int]) – Time scaling factor. Defaults to None.
- Returns:
Covariate matrix in long-time format and associated coordinates for the long-time format.
- Return type:
Dict
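One way to picture a long-time prediction set is a grid that pairs every observation with every distinct time. The sketch below assumes exactly that layout, with time prepended as the first column; both choices are assumptions for illustration, and sketch_prediction_grid is a hypothetical helper rather than part of the package.

```python
import numpy as np


def sketch_prediction_grid(y_time, x):
    """Illustrative only: pair every observation with every distinct time."""
    times = np.unique(y_time)
    n, k = x.shape[0], times.shape[0]
    # Repeat each covariate row k times and cycle the time grid n times.
    x_rep = np.repeat(x, k, axis=0)
    t_col = np.tile(times, n).reshape(-1, 1)
    # Index of the source observation for each long-format row.
    grid_coords = np.repeat(np.arange(n), k)
    return np.hstack([t_col, x_rep]), grid_coords


x_small = np.array([[0.5, 1.0], [1.5, 0.0]])
x_grid, grid_coords = sketch_prediction_grid(np.array([2, 5]), x_small)
```

The coordinate array makes it possible to regroup long-format predictions by their original observation afterwards.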
- bart_survival.surv_bart.get_surv_pre_train(y_time: ndarray, y_status: ndarray, x: ndarray, weight: ndarray | None = None, time_scale=None) Dict
Generates long-time formatted event status and covariate matrix.
The BartSurvModel operates on a discrete-time format. This means each observation is represented by a series of rows, one for each time point up to the event time.
- Parameters:
y_time (np.ndarray) – Event time.
y_status (np.ndarray) – Event status indicator (1 if event occurs, 0 if censored).
x (np.ndarray) – Covariate matrix.
weight (Optional[np.ndarray], optional) – Weights associated with each observation. If none are provided, each observation is given a weight of 1. Defaults to None.
time_scale (Optional[int]) – Time scaling factor. Defaults to None.
- Returns:
Dictionary containing all of the training data, including an event status array, covariate matrix, weights, and coordinates associated with the long-time format.
- Return type:
Dict
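The discrete-time expansion described above can be sketched as follows. This is a minimal illustration, not the body of get_surv_pre_train: the helper sketch_long_format, the choice to prepend time as the first covariate column, and the meaning of the coordinate array are assumptions, and weights and time scaling are left out.

```python
import numpy as np


def sketch_long_format(y_time, y_status, x):
    """Illustrative expansion to long-time form (not the library's code)."""
    times = np.unique(y_time)  # distinct observed times, sorted
    rows, status, coords = [], [], []
    for i, (t_i, d_i) in enumerate(zip(y_time, y_status)):
        # One row per time point up to and including this observation's time.
        for t in times[times <= t_i]:
            rows.append(np.concatenate(([t], x[i])))
            # Status is 1 only in the interval where the event occurs.
            status.append(int(d_i == 1 and t == t_i))
            coords.append(i)
    return np.array(rows), np.array(status), np.array(coords)


y_time = np.array([1, 3, 2])
y_status = np.array([1, 0, 1])
x = np.array([[0.5], [1.5], [2.5]])
x_long, y_long, coords = sketch_long_format(y_time, y_status, x)
```

In this example the censored observation (time 3, status 0) contributes three rows that all carry status 0, while each observation with an event contributes a single 1 in its final row.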
- bart_survival.surv_bart.get_sv_mean_quant(sv: ndarray, msk: ndarray, draws: bool = True, qntile: List[float] = [0.025, 0.975]) Dict
Generates mean and quantile estimates of survival probabilities.
- Parameters:
sv (np.ndarray) – Survival probabilities.
msk (np.ndarray) – Mask for selecting values to average (class of a covariate).
draws (bool, optional) – If true, will average over the draws of the posterior, as well as masked values. Defaults to True.
qntile (List[float], optional) – Quantiles to compute. Defaults to [0.025, 0.975].
- Returns:
Mask True mean, Mask True quantiles, Mask False mean, Mask False quantiles.
- Return type:
Dict
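A masked mean-and-quantile summary of this kind might look like the sketch below. The array layout (draws, observations, times), the order of averaging, and the helper sketch_mean_quant are assumptions for illustration; the actual function returns its results in a dictionary.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical survival probabilities with shape (draws, observations, times).
sv = rng.uniform(size=(100, 6, 4))
# Mask selecting one class of a binary covariate.
msk = np.array([True, True, False, False, True, False])


def sketch_mean_quant(sv, msk, qntile=(0.025, 0.975)):
    # Average over the masked observations, keeping draws: (draws, times).
    grp = sv[:, msk, :].mean(axis=1)
    # Summarize across posterior draws with a mean and quantiles per time.
    return grp.mean(axis=0), np.quantile(grp, qntile, axis=0)


mean_true, quant_true = sketch_mean_quant(sv, msk)
mean_false, quant_false = sketch_mean_quant(sv, ~msk)
```

Calling the helper once with the mask and once with its negation gives the two groups that the four returned quantities describe.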
- bart_survival.surv_bart.get_sv_prob(post: DataArray | Dataset) Dict
Generates the Survival Probability estimates for time points.
- Parameters:
post (Union[xr.DataArray, xr.Dataset]) – Posterior output.
- Returns:
Risk Probability (Hazard) and Survival Probability for each draw, observation and time.
- Return type:
Dict
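In a discrete-time model, survival follows from the hazard as the running product of (1 - hazard) over time. The sketch below shows only that step, assuming the hazards have already been arranged as (draws, observations, times); that layout and the example values are assumptions, not the output of get_sv_prob.

```python
import numpy as np

# Hypothetical hazards with shape (draws, observations, times).
hazard = np.array([[[0.1, 0.2, 0.5],
                    [0.0, 0.5, 0.5]]])

# Survival at each time is the cumulative product of (1 - hazard) up to that time.
survival = np.cumprod(1.0 - hazard, axis=-1)
```

For the first observation this gives 0.9, then 0.9 × 0.8 = 0.72, then 0.72 × 0.5 = 0.36, so survival can only stay flat or decrease over time.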
- bart_survival.surv_bart.pdp_diff_metric(pdp_val: Dict, idx: ndarray, qntile: List[float] = [0.025, 0.975]) Dict
Generates an estimate of the marginal difference from the PDP posterior.
- Parameters:
pdp_val (Dict) – Posterior of PDP predictions.
idx (np.ndarray) – Index of PDP sets.
qntile (List[float], optional) – Quantile values for Credible Interval. Defaults to [0.025, 0.975].
- Returns:
Survival Probability Difference Mean, Survival Probability Difference Quantiles.
- Return type:
Dict
- bart_survival.surv_bart.pdp_rr_metric(pdp_val: Dict, idx: ndarray, qntile: list = [0.025, 0.975]) Dict
Generates a Risk Ratio (Hazard Ratio) from PDP values.
- Parameters:
pdp_val (Dict) – Posterior of PDP predictions.
idx (np.ndarray) – Index of PDP sets.
qntile (list, optional) – Quantile values for Credible Interval. Defaults to [0.025, 0.975].
- Returns:
Risk Ratio Mean, Risk Ratio Quantiles.
- Return type:
Dict
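The two PDP metrics can be illustrated on simulated draws. This sketch is not the implementation of pdp_diff_metric or pdp_rr_metric: the array layout is hypothetical, and risk is taken here as 1 - survival, which is an assumption (the library may instead work from the hazard).

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical survival draws with shape (draws, rows of the PDP matrix, times).
sv_pdp = rng.uniform(0.2, 0.9, size=(200, 6, 3))
# PDP set that each row belongs to (two test values, three rows each).
idx = np.array([0, 0, 0, 1, 1, 1])

# Average within each PDP set, keeping the posterior draws: (draws, times).
s0 = sv_pdp[:, idx == 0, :].mean(axis=1)
s1 = sv_pdp[:, idx == 1, :].mean(axis=1)

diff = s1 - s0                # survival probability difference per draw
rr = (1.0 - s1) / (1.0 - s0)  # ratio of risks per draw (assumed risk = 1 - survival)

qntile = [0.025, 0.975]
diff_mean, diff_q = diff.mean(axis=0), np.quantile(diff, qntile, axis=0)
rr_mean, rr_q = rr.mean(axis=0), np.quantile(rr, qntile, axis=0)
```

Computing the difference and ratio per draw before summarizing keeps the posterior uncertainty in the resulting credible intervals.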