Demonstration
The following is a brief demonstration of how to use the BART-Survival library. The demonstration uses the rossi survival dataset from the lifelines library. The dataset contains one year of follow-up observation on 432 convicts who were released from Maryland state prisons in the 1970s. The primary event of interest is a “new arrest” within that year, and the associated “time to arrest” is recorded in weeks [@rossi1980].
```python
from lifelines.datasets import load_rossi
from bart_survival import surv_bart as sb
import numpy as np
######################################
# Load rossi dataset from lifelines
rossi = load_rossi()
names = rossi.columns.to_numpy()
rossi = rossi.to_numpy()
```
After loading the libraries and data, the first step is to generate the TAD and PAD datasets. In this step, time (originally in weeks) is downscaled by a factor of 4 (`time_scale=4`), so that it is measured in four-week intervals of roughly one month each.
```python
######################################
# Transform data into 'augmented' dataset
# Requires creation of the training dataset and a predictive dataset for inference
# TAD
trn = sb.get_surv_pre_train(
    y_time=rossi[:, 0],
    y_status=rossi[:, 1],
    x=rossi[:, 2:],
    time_scale=4,
)
# PAD
post_test = sb.get_posterior_test(
    y_time=rossi[:, 0],
    y_status=rossi[:, 1],
    x=rossi[:, 2:],
    time_scale=4,
)
```
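To make the TAD concrete, the sketch below shows a simplified discrete-time (person-period) expansion. It is not the `get_surv_pre_train` implementation: the helper `expand_person_period` is hypothetical, and it assumes that times are mapped to intervals by ceiling division and that each observation contributes one row per interval up to its own event or censoring interval. Under these assumptions, a subject censored at week 52 with `time_scale=4` contributes 13 rows.

```python
import numpy as np


def expand_person_period(y_time, y_status, x, time_scale=1.0):
    """Hypothetical sketch of a discrete-time (person-period) expansion.

    Observation i with scaled interval t_i contributes one row per interval
    j = 1..t_i. The outcome is 1 only in the last row of an observation whose
    event was observed, and 0 otherwise.
    """
    # Assumption: times are mapped to integer intervals by ceiling division.
    t_scaled = np.ceil(np.asarray(y_time, dtype=float) / time_scale).astype(int)
    status = np.asarray(y_status, dtype=int)
    x = np.asarray(x, dtype=float)

    rows_y, rows_x, coord = [], [], []
    for i, (t_i, d_i) in enumerate(zip(t_scaled, status)):
        for j in range(1, t_i + 1):
            rows_y.append(1.0 if (j == t_i and d_i == 1) else 0.0)
            # The time interval is prepended as the first covariate column.
            rows_x.append(np.concatenate(([j], x[i])))
            coord.append(i)  # observation identifier for this row
    return {
        "y": np.array(rows_y).reshape(-1, 1),
        "x": np.vstack(rows_x),
        "w": np.ones((len(rows_y), 1)),  # uniform weights
        "coord": np.array(coord),
    }


# Toy data: subject 0 has an event in week 6, subject 1 is censored in week 9.
toy = expand_person_period([6, 9], [1, 0], [[0.0], [1.0]], time_scale=4)
print(toy["y"].ravel())  # [0. 1. 0. 0. 0.]
print(toy["coord"])      # [0 0 1 1 1]
```

With `time_scale=4`, week 6 falls in interval 2, so subject 0 contributes two rows and the last one carries \(y=1\); the censored subject 1 (week 9, interval 3) contributes three rows, all with \(y=0\).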
The `trn` object is displayed below. It is a dictionary of arrays containing the TAD components. The `y` and `x` components are the outcome and covariates of the TAD, respectively; the first column of `x` is the time interval. The `w` component is an array of weight values generated by the `get_surv_pre_train` function. By default all weights are set to \(1\), so every row contributes equally to the likelihood and the weights have no effect on model training. For general use, the weighting functionality can be ignored. In more complex study designs, observation-level weights can be provided, allowing weighted contributions to the likelihood function during model training. Weighting the likelihood function is currently experimental, and we plan to evaluate this utility further in future work. Finally, the `coord` component contains the observation identifier for each row of the TAD, making it easy to identify the rows associated with a single observation or a set of observations.
```python
# {'y': array([[0.],
#              [0.],
#              [0.],
#              ...,
#              [0.],
#              [0.],
#              [0.]]),
#  'x': array([[ 1.,  0., 27., ...,  0.,  1.,  3.],
#              [ 2.,  0., 27., ...,  0.,  1.,  3.],
#              [ 3.,  0., 27., ...,  0.,  1.,  3.],
#              ...,
#              [11.,  1., 24., ...,  0.,  1.,  1.],
#              [12.,  1., 24., ...,  0.,  1.,  1.],
#              [13.,  1., 24., ...,  0.,  1.,  1.]]),
#  'w': array([[1.],
#              [1.],
#              [1.],
#              ...,
#              [1.],
#              [1.],
#              [1.]]),
#  'coord': array([  0,   0,   0, ..., 431, 431, 431])}
```
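The role of `w` can be illustrated with a generic weighted Bernoulli log-likelihood, in which each TAD row's contribution is multiplied by its weight. This is a sketch of the general idea only: how the library applies weights internally is not shown here, and `weighted_bernoulli_loglik` is a hypothetical helper. It does show why weights of \(1\) leave training unchanged, since the weighted and unweighted log-likelihoods then coincide.

```python
import numpy as np


def weighted_bernoulli_loglik(y, p, w):
    """Generic weighted log-likelihood of binary outcomes y with probabilities p.

    Each row's log-likelihood contribution is multiplied by its weight w.
    With all weights equal to 1 this is the ordinary Bernoulli log-likelihood.
    """
    y = np.asarray(y, dtype=float).ravel()
    p = np.asarray(p, dtype=float).ravel()
    w = np.asarray(w, dtype=float).ravel()
    return float(np.sum(w * (y * np.log(p) + (1.0 - y) * np.log1p(-p))))


y = np.array([0.0, 1.0, 0.0])
p = np.array([0.1, 0.3, 0.2])
unweighted = float(np.sum(y * np.log(p) + (1 - y) * np.log(1 - p)))
print(np.isclose(weighted_bernoulli_loglik(y, p, np.ones(3)), unweighted))  # True
```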
The next step is to initialize the model, which involves setting several parameter values. The key considerations when initializing the model are the number of trees and the split rules. The number of trees controls how many regression trees are used; 50 trees is typically a good default, but the number can be adjusted to improve model performance. The split rules are a PyMC-BART-specific parameter that designates how the regression trees split on each covariate. The one requirement is that the time covariate must be assigned `pmb.ContinuousSplitRule()`. Otherwise, continuous variables can be assigned `pmb.ContinuousSplitRule()` and categorical variables `pmb.OneHotSplitRule()`. We recommend reviewing the PyMC-BART literature for more information on parameterizing the models.
```python
######################################
# Instantiate the BART models
# model_dict defines specific model parameters
model_dict = {
    "trees": 50,
    "split_rules": [
        "pmb.ContinuousSplitRule()",  # time
        "pmb.OneHotSplitRule()",      # fin
        "pmb.ContinuousSplitRule()",  # age
        "pmb.OneHotSplitRule()",      # race
        "pmb.OneHotSplitRule()",      # wexp
        "pmb.OneHotSplitRule()",      # mar
        "pmb.OneHotSplitRule()",      # paro
        "pmb.ContinuousSplitRule()",  # prio
    ],
}
# sampler_dict defines specific sampling parameters
sampler_dict = {
    "draws": 200,  # draws per chain (8 chains x 200 = 1600 in total)
    "tune": 200,
    "cores": 8,
    "chains": 8,
    "compute_convergence_checks": False,
}
BSM = sb.BartSurvModel(
    model_config=model_dict,
    sampler_config=sampler_dict,
)
```
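Because the split-rule list has one entry per column of `trn["x"]` (time first, then the covariates), it can be convenient to build it from the covariate names. The helper below is hypothetical and not part of BART-Survival; it assumes that the binary covariates are the categorical ones and reproduces the list used in `model_dict`. In the rossi example these covariate names correspond to `names[2:]`.

```python
def build_split_rules(covariate_names, categorical):
    """Hypothetical helper: build a split-rule list for model_dict.

    The time column always comes first and is always continuous; each
    remaining covariate gets a one-hot rule if it is listed as categorical.
    """
    rules = ["pmb.ContinuousSplitRule()"]  # time (required to be continuous)
    for name in covariate_names:
        if name in categorical:
            rules.append("pmb.OneHotSplitRule()")
        else:
            rules.append("pmb.ContinuousSplitRule()")
    return rules


covariates = ["fin", "age", "race", "wexp", "mar", "paro", "prio"]
binary = {"fin", "race", "wexp", "mar", "paro"}
split_rules = build_split_rules(covariates, binary)
print(len(split_rules))  # 8
```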
The model can then be trained on the TAD, and predicted \(p_{ij}\) values can be obtained from the PAD.
```python
#####################################
# Fit model with TAD
BSM.fit(
    y=trn["y"],
    X=trn["x"],
    weights=trn["w"],
    coords=trn["coord"],
    random_seed=5,
)
# Get posterior predictive for evaluation using the PAD
post1 = BSM.sample_posterior_predictive(
    X_pred=post_test["post_x"],
    coords=post_test["coords"],
)
```
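Unlike the TAD, the PAD evaluates every observation at every time interval, regardless of its observed follow-up, so that \(p_{ij}\) can be predicted on a common time grid. A minimal sketch of that layout follows; `expand_all_intervals` is hypothetical, its keys mirror those of `post_test`, and the real `get_posterior_test` may order rows differently. Under this layout the rossi PAD with 13 intervals would have 432 × 13 = 5616 rows.

```python
import numpy as np


def expand_all_intervals(x, n_intervals):
    """Hypothetical sketch of a prediction dataset.

    Every observation is paired with every time interval 1..n_intervals,
    with the time interval prepended as the first column.
    """
    x = np.asarray(x, dtype=float)
    n_obs = x.shape[0]
    times = np.arange(1, n_intervals + 1, dtype=float)
    # Row order: observation-major (obs 0 at times 1..J, then obs 1, ...).
    time_col = np.tile(times, n_obs)
    x_rep = np.repeat(x, n_intervals, axis=0)
    coords = np.repeat(np.arange(n_obs), n_intervals)
    return {"post_x": np.column_stack((time_col, x_rep)), "coords": coords}


pad = expand_all_intervals([[0.0], [1.0]], n_intervals=3)
print(pad["post_x"].shape)  # (6, 2)
print(pad["coords"])        # [0 0 0 1 1 1]
```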
Finally, the survival probability can be derived from the \(p_{ij}\) estimates.
```python
# Convert p_ij to survival probabilities s_ij
sv_prob = sb.get_sv_prob(post1)
print(sv_prob["sv"].shape)
# (1600, 432, 13)
```
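The conversion from \(p_{ij}\) to \(s_{ij}\) can be understood through the standard discrete-time relation \(s_{ij} = \prod_{k=1}^{j}(1 - p_{ik})\): surviving to the end of interval \(j\) requires not having the event in any interval up to and including \(j\). The sketch below applies this relation along the time axis with `np.cumprod`. It assumes, rather than confirms, that `get_sv_prob` computes the same quantity, and `survival_from_hazard` is a hypothetical name.

```python
import numpy as np


def survival_from_hazard(prob):
    """Convert discrete-time event probabilities p_ij into survival s_ij.

    prob has shape (draws, observations, times); survival to the end of
    interval j is the running product of (1 - p_ik) over k <= j.
    """
    prob = np.asarray(prob, dtype=float)
    return np.cumprod(1.0 - prob, axis=-1)


# One draw, one observation, three intervals with p = 0.1, 0.2, 0.5
p = np.array([[[0.1, 0.2, 0.5]]])
s = survival_from_hazard(p)
print(s)  # approximately [[[0.9, 0.72, 0.36]]]
```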
The `sv_prob` object above is a dictionary containing numpy arrays of both the \(p_{ij}\) and \(s_{ij}\) estimates, labeled “prob” and “sv” respectively. The \(p\) and \(s\) arrays are three-dimensional, with axes:

- axis 0: draws of the posterior predictive distribution (1600)
- axis 1: observations \(i\) (432)
- axis 2: times \(j\) (13)
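The sizes of the first and last axes follow from the settings used above: the 8 chains of 200 retained draws are pooled into 1600 posterior draws, and 52 weeks of follow-up split into four-week intervals gives 13 time points. The sketch below checks that arithmetic and shows how slices are taken from an array with the same axis order; the small random array is only a stand-in for `sv_prob["sv"]`.

```python
import numpy as np

# Settings taken from sampler_dict and the time scaling used above
chains, draws_per_chain = 8, 200
weeks_of_follow_up, time_scale = 52, 4

total_draws = chains * draws_per_chain          # pooled posterior draws
n_intervals = weeks_of_follow_up // time_scale  # four-week intervals
print(total_draws, n_intervals)                 # 1600 13

# Small random stand-in for sv_prob["sv"]: (draws, observations, times).
# The values are arbitrary; only the axis order matters here.
rng = np.random.default_rng(0)
sv = rng.uniform(size=(4, 3, 5))

curve_draws = sv[:, 1, :]  # posterior draws of observation 1's survival curve
at_time = sv[:, :, 2]      # posterior draws of survival at time index 2, all observations
print(curve_draws.shape, at_time.shape)  # (4, 5) (4, 3)
```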
These arrays can easily be reduced to point estimates and credible intervals using basic numpy methods. For example, to estimate the mean survival over all observations, first take the mean over the observations (axis 1) and then the mean over the posterior draws (axis 0). The result is the estimated mean survival at each of the 13 time intervals.
Similarly, the 90% credible interval (0.05 and 0.95 quantiles) for the estimated mean survival is obtained by evaluating these quantiles along axis 0 of the same observation-averaged array. This yields a (2, 13) array with the lower and upper bounds of the credible interval in the rows and the time points in the columns.
```python
# Mean across observations for each time within each draw of the posterior predictive distribution
ave_obs = sv_prob["sv"].mean(axis=1)
print(ave_obs.shape)
# (1600, 13)
# Average across the posterior draws
ave_obs_draws = ave_obs.mean(0)
print(ave_obs_draws)
# [0.98282867 0.96457594 0.94578149
#  0.92613075 0.90503023 0.88429523
#  0.86382304 0.84376736 0.82281631
#  0.80169965 0.78080675 0.75948555
#  0.73816541]
# 0.05 and 0.95 quantiles of the mean across posterior draws
ci_obs_draws = np.quantile(ave_obs, [0.05, 0.95], axis=0)
print(ci_obs_draws)
# lower bound
# [0.97813492 0.95597036 0.93449701
#  0.91273325 0.88996273 0.86803963
#  0.84529644 0.82373944 0.8013158
#  0.77941446 0.75844137 0.73615453
#  0.71387   ]
# upper bound
# [0.9879464  0.97269846 0.95684153
#  0.93988969 0.91999613 0.90054844
#  0.8815908  0.86294746 0.84339772
#  0.82276706 0.80264632 0.78084766
#  0.76145384]
```
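One subtlety in these reductions is that the order of averaging does not affect the point estimate, but it does matter for the interval. The sketch below uses a synthetic array with the same axis order as `sv_prob["sv"]` to show that both routes to the mean agree, while quantiles taken over draws and observations pooled together describe the spread of individual survival probabilities rather than the uncertainty in the mean curve, and are typically much wider.

```python
import numpy as np

# Synthetic stand-in for sv_prob["sv"] with shape (draws, observations, times);
# the values are arbitrary and only illustrate the reductions.
rng = np.random.default_rng(1)
sv = rng.uniform(0.5, 1.0, size=(200, 30, 13))

ave_obs = sv.mean(axis=1)     # (draws, times): mean curve within each draw
point = ave_obs.mean(axis=0)  # (times,): posterior mean of the mean curve

# For the point estimate the order of averaging does not matter.
print(np.allclose(point, sv.mean(axis=(0, 1))))  # True

# Credible interval for the mean curve: quantiles over draws of ave_obs.
ci_mean = np.quantile(ave_obs, [0.05, 0.95], axis=0)

# Pooling draws and observations instead describes spread across individuals,
# a different and much wider quantity.
pooled = np.quantile(sv.reshape(-1, sv.shape[2]), [0.05, 0.95], axis=0)
print(ci_mean.shape, pooled.shape)  # (2, 13) (2, 13)
```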
Examples of generating marginal effect estimates can be found in the example notebooks provided in the repository documentation.