This module contains the base class used to train models on the data.

General Model

class GeneralModel

GeneralModel(base:Module, loss_fn:Loss, lr:float=0.001) :: LightningModule

A LightningModule that wraps a base network (base) together with its loss function (loss_fn) and learning rate (lr, default 0.001).

train_model

train_model(model, train_dl, valid_dl, epochs)

Data for Demo

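# Assumption: df is the Rossi recidivism data from lifelines, with week/arrest renamed to t/e
from lifelines.datasets import load_rossi
df = load_rossi().rename(columns={"week": "t", "arrest": "e"})
print(df.shape)
df.head()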
(432, 9)
    t  e  fin  age  race  wexp  mar  paro  prio
0  20  1    0   27     1     0    0     1     3
1  17  1    0   18     1     0    0     1     8
2  25  1    0   19     0     1    0     1    13
3  52  0    1   23     1     1    1     1     1
4  52  0    0   19     0     1    0     1     3

Hazard Model

class ModelHazard

ModelHazard(model:str, percentiles=[20, 40, 60, 80], h:tuple=(), bs:int=128, epochs:int=20, lr:float=1.0, beta:float=0)

Models the instantaneous hazard λ. Parameters:

  • model (str): 'ph' or 'cox', selecting a piecewise hazard or a Cox proportional hazard model respectively
  • percentiles: percentiles of the observed times at which the time axis is split into intervals
  • h: sizes of the hidden layers (the input layer is not included)
  • bs: batch size
  • epochs: number of training epochs
  • lr: learning rate
  • beta: L2 penalty on the weights
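The two model choices can be illustrated without a neural network. A piecewise hazard cuts the time axis at percentiles of the observed times and holds the hazard constant within each interval; a Cox-style proportional hazard multiplies a baseline hazard λ0(t) by exp(β·x). The pure-Python sketch below assumes a piecewise-constant baseline for the Cox case and linear interpolation for the percentiles; the function names are hypothetical and this is not ModelHazard's implementation.

```python
import math
from bisect import bisect_right


def percentile(values, q):
    """q-th percentile (0-100) using linear interpolation between sorted values."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def breakpoints(times, percentiles=(20, 40, 60, 80)):
    # Cut the time axis at percentiles of the observed times:
    # 4 cut points give 5 intervals, each with its own constant hazard.
    return [percentile(times, q) for q in percentiles]


def baseline_hazard(t, cuts, rates):
    # λ0(t): the constant rate of the interval containing t.
    # Requires len(rates) == len(cuts) + 1 and cuts sorted ascending.
    return rates[bisect_right(cuts, t)]


def baseline_cum_hazard(t, cuts, rates):
    # Λ0(t) = integral of λ0 from 0 to t: each rate times the time spent in its interval.
    edges = [0.0, *cuts, math.inf]
    total = 0.0
    for k, rate in enumerate(rates):
        start, end = edges[k], edges[k + 1]
        if t <= start:
            break
        total += rate * (min(t, end) - start)
    return total


def proportional_hazard(t, x, beta, cuts, rates):
    # Cox-style proportional hazards: λ(t | x) = λ0(t) * exp(β·x),
    # and therefore Λ(t | x) = Λ0(t) * exp(β·x).
    risk = math.exp(sum(b * xi for b, xi in zip(beta, x)))
    return baseline_hazard(t, cuts, rates) * risk, baseline_cum_hazard(t, cuts, rates) * risk
```

With the default four percentiles there are five intervals and therefore five baseline rates. Because exp(β·x) multiplies λ and Λ alike, the ratio of hazards between two subjects does not change over time, which is the proportional-hazards assumption.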

Cox Model Demo

model = ModelHazard('cox')
model.fit(df)
GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type               | Params
--------------------------------------------
0 | base | ProportionalHazard | 12    
--------------------------------------------
12        Trainable params
0         Non-trainable params
12        Total params
Epoch 0: 100%|██████████| 4/4 [00:00<00:00, 13.67it/s, loss=nan, v_num=47]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 13.37it/s, loss=nan, v_num=47]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 13.82it/s, loss=nan, v_num=47]
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 13.30it/s, loss=nan, v_num=47]
Saving latest checkpoint...
# %tensorboard --logdir ./lightning_logs/
λ, Λ = model.predict(df)
df.shape, λ.shape, Λ.shape
((432, 9), torch.Size([432, 1]), torch.Size([432, 1]))
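model.predict returns one λ and one Λ per row. Assuming both are evaluated at each row's own time t (the output shape suggests this, but the source does not say so), they are enough for the survival probability S(t) = exp(−Λ(t)) and for the standard right-censored log-likelihood Σ e·log λ(t) − Λ(t): an event contributes log f = log λ − Λ and a censored row contributes log S = −Λ. A minimal pure-Python sketch, with hypothetical function names:

```python
import math


def survival_from_cum_hazard(cum_hazards):
    # S(t) = exp(-Λ(t)): probability of surviving past t.
    return [math.exp(-c) for c in cum_hazards]


def censored_nll(hazards, cum_hazards, events):
    """Mean negative log-likelihood for right-censored data.

    Each row contributes e * log λ(t) - Λ(t): observed events (e = 1) add the
    log hazard at their time, and every row pays the cumulative hazard up to t.
    """
    total = 0.0
    for lam, cum, e in zip(hazards, cum_hazards, events):
        if e:  # only events need log λ, so λ may be 0 for censored rows
            total += math.log(lam)
        total -= cum
    return -total / len(events)
```

To feed the tensors from the demo into these helpers, something like Λ.squeeze(1).tolist() would turn the (432, 1) tensor into a plain list.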

Modelling the Distribution with Accelerated Failure Time (AFT) Models

class ModelAFT

ModelAFT(dist:str, h:tuple=(), bs:int=128, epochs:int=20, lr:float=0.1, beta:float=0)

Models the distribution of the error term given inputs x. Parameters:

  • dist (str): univariate distribution of the error term, e.g. 'Gumbel' as in the demo below
  • h: sizes of the hidden layers (the input layer is not included)
  • bs: batch size
  • epochs: number of training epochs
  • lr: learning rate
  • beta: L2 penalty on the weights
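ModelAFT's internals are not shown here. The usual accelerated failure time formulation writes log T = μ(x) + σε, where ε follows the chosen dist, so the inputs only shift the distribution along the log-time axis. Under that formulation an event row contributes the log density of T and a censored row the log survival probability. The sketch below assumes a linear μ(x) = w·x + b and a standard maximum-type Gumbel error with density exp(−(z + exp(−z))); all names are hypothetical, not the library's API.

```python
import math


def gumbel_log_pdf(z):
    # Standard (maximum) Gumbel: log f(z) = -(z + exp(-z))
    return -(z + math.exp(-z))


def gumbel_log_sf(z):
    # log P(eps > z) = log(1 - exp(-exp(-z))), computed with expm1 for accuracy
    return math.log(-math.expm1(-math.exp(-z)))


def aft_z(t, x, weights, bias, sigma):
    # Standardised residual on the log-time scale: z = (log t - (w·x + b)) / sigma, sigma > 0
    mu = sum(w * xi for w, xi in zip(weights, x)) + bias
    return (math.log(t) - mu) / sigma


def aft_survival(t, x, weights, bias, sigma):
    # P(T > t | x): since log T = mu + sigma * eps, T > t exactly when eps > z.
    return math.exp(gumbel_log_sf(aft_z(t, x, weights, bias, sigma)))


def aft_log_likelihood(t, e, x, weights, bias, sigma):
    z = aft_z(t, x, weights, bias, sigma)
    if e:
        # Observed event: density of T by change of variables, f_T(t) = f_eps(z) / (sigma * t).
        return gumbel_log_pdf(z) - math.log(sigma * t)
    # Censored row: only known to survive past t.
    return gumbel_log_sf(z)
```

A larger σ spreads the same μ(x) over a wider range of survival times, and swapping dist would only change gumbel_log_pdf and gumbel_log_sf.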

AFT Model Demo

model = ModelAFT('Gumbel')
model.fit(df)
GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type     | Params
----------------------------------
0 | base | AFTModel | 9     
----------------------------------
9         Trainable params
0         Non-trainable params
9         Total params
Epoch 0: 100%|██████████| 4/4 [00:00<00:00, 12.84it/s, loss=nan, v_num=48]
Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 13.26it/s, loss=nan, v_num=48]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 13.38it/s, loss=nan, v_num=48]
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 13.83it/s, loss=nan, v_num=48]
Saving latest checkpoint...
surv_prob = model.predict(df)
mode_time = model.predict_time(df)
df["surv_prob"] = surv_prob
df["mode_time"] = mode_time
df
       t    e  fin  age  race  wexp  mar  paro  prio  surv_prob  mode_time
  0   20    1    0   27     1     0    0     1     3   0.954641  30.349995
  1   17    1    0   18     1     0    0     1     8   0.857569  21.751127
  2   25    1    0   19     0     1    0     1    13   0.049757   8.331078
  3   52    0    1   23     1     1    1     1     1   0.013023  10.488750
  4   52    0    0   19     0     1    0     1     3   0.021093  12.552552
...  ...  ...  ...  ...   ...   ...  ...   ...   ...        ...        ...
427   52    0    1   31     0     1    0     1     3   0.029531  14.236526
428   52    0    0   20     1     0    0     1     1   0.198965  29.817591
429   52    0    1   20     1     1    1     1     1   0.011608  10.050090
430   52    0    0   29     1     1    0     1     3   0.076160  20.383711
431   52    0    1   24     1     1    0     1     1   0.069662  19.698050

432 rows × 11 columns
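How predict_time arrives at mode_time is not shown. If the Gumbel AFT model has the usual form log T = μ + σε, with ε a standard (maximum) Gumbel variable, the mode of T has a closed form: writing the density of T in terms of z = (log t − μ)/σ gives log f_T = −z − exp(−z) − σz + const, which is maximised where exp(−z) = 1 + σ, i.e. at t* = e^μ (1 + σ)^(−σ). The sketch below checks this numerically; μ, σ and both function names are assumptions, not the library's API.

```python
import math


def gumbel_aft_density(t, mu, sigma):
    # f_T(t) = f_eps(z) / (sigma * t), with z = (log t - mu) / sigma
    # and f_eps(z) = exp(-(z + exp(-z))) for the standard maximum Gumbel.
    z = (math.log(t) - mu) / sigma
    return math.exp(-(z + math.exp(-z))) / (sigma * t)


def gumbel_aft_mode(mu, sigma):
    # Setting d/dz log f_T = -1 + exp(-z) - sigma = 0 gives z* = -log(1 + sigma),
    # so t* = exp(mu + sigma * z*) = exp(mu) * (1 + sigma) ** (-sigma).
    return math.exp(mu) * (1.0 + sigma) ** (-sigma)
```

Note that this is the mode of T itself; the mode of log T would sit at z = 0 instead, so the two summaries differ whenever σ > 0.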

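import matplotlib.pyplot as plt
# Compare predicted survival probabilities for rows with an observed event (e == 1) vs. censored rows (e == 0)
plt.hist(df[df["e"] == 1]["surv_prob"].values, bins=30, alpha=0.5, density=True, label="death")
plt.hist(df[df["e"] == 0]["surv_prob"].values, bins=30, alpha=0.5, density=True, label="censored")
plt.legend()
plt.show()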
# %tensorboard --logdir ./lightning_logs/