| Field | Value |
|---|---|
| Title | Artificial neural networks for survival analysis |
| Description | Artificial neural networks for survival analysis |
| Authors | Marvin N. Wright |
| Maintainer | Marvin N. Wright |
| License | MIT + file LICENSE |
| Version | 0.0.5 |
| Built | 2024-12-31 04:17:19 UTC |
| Source | https://github.com/bips-hb/survnet |

Creates a matrix with at-risk and event information. Format: (S_1, ..., S_K, E_1, ..., E_K). Dimensions: obs × (2 × causes × time).
convert_surv_cens(time, status, breaks, num_causes)

| Argument | Description |
|---|---|
| time | Survival time. |
| status | Censoring indicator: 0 for censored observations, positive values for events. |
| breaks | Right interval limits for discrete survival time. |
| num_causes | Number of competing risks. |

Value: Binary response matrix.
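To make the (S, E) layout concrete, here is a minimal base-R sketch. The function name `convert_surv_cens_sketch` is hypothetical and this is not the package's internal code. It assumes that interval j is (breaks[j - 1], breaks[j]] with a lower limit of 0 for the first interval, that "at risk" means event-free at the start of an interval, and that columns run over time fastest within each cause (mirroring the y_pred layout of the loss function). Times beyond the last break get no event column.

```r
# Hypothetical sketch of the at-risk/event encoding; conventions are assumptions.
convert_surv_cens_sketch <- function(time, status, breaks, num_causes) {
  n <- length(time)
  num_intervals <- length(breaks)
  # Interval index j such that breaks[j - 1] < time <= breaks[j] (breaks[0] = 0)
  interval <- findInterval(time, breaks, left.open = TRUE) + 1
  S <- matrix(0, n, num_causes * num_intervals)
  E <- matrix(0, n, num_causes * num_intervals)
  for (k in seq_len(num_causes)) {
    for (j in seq_len(num_intervals)) {
      col <- (k - 1) * num_intervals + j      # time index runs fastest within a cause
      S[, col] <- as.numeric(interval >= j)   # still at risk at the start of interval j
      E[, col] <- as.numeric(status == k & interval == j)  # event of cause k in interval j
    }
  }
  cbind(S, E)  # obs x (2 * causes * time)
}

y <- convert_surv_cens_sketch(time = c(30, 120, 700), status = c(1, 0, 2),
                              breaks = c(50, 100, 200, 500, 1000), num_causes = 2)
dim(y)  # 3 20
```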
Likelihood for direct parametric inference on the cumulative incidence functions, as defined by Jeong & Fine (2006) and also used by Lee et al. (2018).
loss_cif_loglik(num_intervals, num_causes = 1)

| Argument | Description |
|---|---|
| num_intervals | Number of time intervals. |
| num_causes | Number of causes for competing risks. |

Data structure:

| Input | Description |
|---|---|
| y_true | True survival: matrix with at-risk and event information. Format: (S_1, ..., S_K, E_1, ..., E_K). Dimensions: obs × (2 × causes × time). |
| y_pred | Network output: one probability for each time and cause, where y_tk refers to interval t and cause k. Format: (y_11, ..., y_T1, ..., y_TK). Dimensions: obs × (causes × time). |

Value: Negative log-likelihood.
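The following base-R sketch shows a discrete-time negative log-likelihood in the spirit of Jeong & Fine (2006) and the likelihood term of DeepHit (Lee et al. 2018), using the y_true and y_pred layouts above. It is not the package's Keras implementation; the name `cif_negloglik_sketch`, the rule that a censored observation contributes 1 minus the summed cause-specific CIFs up to and including its last at-risk interval, the averaging over observations, and the `eps` guard are all assumptions.

```r
# Hypothetical sketch; the package's own loss works on Keras tensors and may differ.
cif_negloglik_sketch <- function(y_true, y_pred, num_intervals, num_causes = 1,
                                 eps = 1e-7) {
  K <- num_intervals * num_causes
  S <- y_true[, seq_len(K), drop = FALSE]      # at-risk block
  E <- y_true[, K + seq_len(K), drop = FALSE]  # event block
  loglik <- numeric(nrow(y_true))
  for (i in seq_len(nrow(y_true))) {
    if (sum(E[i, ]) > 0) {
      # Event: predicted probability mass at the observed cause and interval
      loglik[i] <- log(sum(y_pred[i, ] * E[i, ]) + eps)
    } else {
      # Censored: last interval in which the observation is at risk
      t_c <- max(which(S[i, seq_len(num_intervals)] == 1))
      # Sum of all cause-specific CIFs up to and including interval t_c
      cif_total <- 0
      for (k in seq_len(num_causes)) {
        cif_total <- cif_total + sum(y_pred[i, (k - 1) * num_intervals + seq_len(t_c)])
      }
      loglik[i] <- log(1 - cif_total + eps)
    }
  }
  -mean(loglik)
}
```

With one cause and two intervals, an event in interval 1 predicted with probability 0.5 and a censoring in interval 2 with predicted masses 0.1 and 0.2 give a loss of -(log(0.5) + log(0.7)) / 2.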
Jeong, J. & Fine, J. (2006). Direct parametric inference for the cumulative incidence function. J R Stat Soc Ser C Appl Stat 55:187-200. https://doi.org/10.1111/j.1467-9876.2006.00532.x.
Lee, C., Zame, W.R., Yoon, J. & van der Schaar, M. (2018). DeepHit: A deep learning approach to survival analysis with competing risks. AAAI 2018. http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit.
survnet prediction
## S3 method for class 'survnet'
predict(object, newdata, cause = NULL, ...)
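
| Argument | Description |
|---|---|
| object | Fitted survnet object. |
| newdata | New data predictors. |
| cause | Select cause for competing risks; NULL (default) returns all causes. |
| ... | Further arguments passed to or from other methods. |
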
Value: Cumulative incidence function of the selected cause or of all causes.
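One way to see how per-interval probabilities relate to a cumulative incidence function is a running sum over intervals within each cause. The sketch below assumes the y_pred layout described for `loss_cif_loglik` (time fastest within each cause); the helper `cif_from_probs_sketch` is hypothetical, and whether `predict.survnet` builds its output exactly this way is an assumption.

```r
# Hypothetical helper: turn per-interval probabilities into CIFs per cause.
cif_from_probs_sketch <- function(y_pred, num_intervals, num_causes = 1, cause = NULL) {
  causes <- if (is.null(cause)) seq_len(num_causes) else cause
  cif <- lapply(causes, function(k) {
    m <- y_pred[, (k - 1) * num_intervals + seq_len(num_intervals), drop = FALSE]
    # Running sum over intervals turns probability masses into a CIF
    for (j in seq_len(num_intervals)[-1]) m[, j] <- m[, j - 1] + m[, j]
    m
  })
  # One cause selected: return a matrix; otherwise a list with one matrix per cause
  if (length(cif) == 1) cif[[1]] else cif
}
```

Selecting a single cause returns an obs × time matrix, while `cause = NULL` returns one such matrix per cause.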
Artificial neural networks for survival analysis
survnet(y, x, breaks, units = c(3, 5), units_rnn = c(4, 6), units_causes = c(3, 2), epochs = 100, batch_size = 16, validation_split = 0.2, loss = loss_cif_loglik, activation = "tanh", rnn_type = "LSTM", skip = TRUE, dropout = rep(0, length(units)), dropout_rnn = rep(0, length(units_rnn)), dropout_causes = rep(0, length(units_causes)), l2 = rep(0, length(units)), l2_rnn = rep(0, length(units_rnn)), l2_causes = rep(0, length(units_causes)), optimizer = optimizer_rmsprop(lr = 0.001), verbose = 2)

| Argument | Description |
|---|---|
| y | Survival outcome. |
| x | Predictors. |
| breaks | Right interval limits for discrete survival time. |
| units | Vector giving the number of units in each hidden layer. |
| units_rnn | Vector of units for recurrent layers. |
| units_causes | Vector of units for cause-specific layers (competing risks only). Either a vector (repeated for each cause) or a list of vectors with the layers for each cause. |
| epochs | Number of epochs to train the model. |
| batch_size | Number of samples per gradient update. |
| validation_split | Fraction in [0, 1] of the training data to be used as validation data. |
| loss | Loss function. |
| activation | Activation function. |
| rnn_type | Type of RNN layers (default "LSTM"). |
| skip | Add skip connections from the input and RNN layers to the cause-specific layers. |
| dropout | Vector of dropout rates after each hidden layer. Use 0 for no dropout (default). |
| dropout_rnn | Vector of dropout rates after each recurrent layer. Use 0 for no dropout (default). |
| dropout_causes | Vector of dropout rates after each cause-specific layer. Use 0 for no dropout (default). |
| l2 | Vector of L2 regularization factors for each hidden layer. Use 0 for no regularization (default). |
| l2_rnn | Vector of L2 regularization factors for each recurrent layer. Use 0 for no regularization (default). |
| l2_causes | Vector of L2 regularization factors for each cause-specific layer. Use 0 for no regularization (default). |
| optimizer | Name of optimizer or optimizer instance. |
| verbose | Verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch). |

Value: Fitted model.
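The `units_causes` argument accepts either a single vector, reused for every cause, or a list with one layer specification per cause. The sketch below normalizes both forms to a list; the helper name `expand_units_causes` is hypothetical (not part of survnet), and the length check for the list form is an assumption about what a valid input looks like.

```r
# Hypothetical helper illustrating the documented units_causes rule.
expand_units_causes <- function(units_causes, num_causes) {
  if (is.list(units_causes)) {
    # Assumed: a list must already give one layer specification per cause
    stopifnot(length(units_causes) == num_causes)
    units_causes
  } else {
    # A plain vector is repeated for each cause
    rep(list(units_causes), num_causes)
  }
}

str(expand_units_causes(c(3, 2), num_causes = 2))
```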

```r
library(survival)
library(survnet)

# Survival data: time and status columns of the veteran data set
y <- veteran[, c(3, 4)]
# Predictors: all remaining columns except celltype, scaled
x <- veteran[, c(-2, -3, -4)]
x <- data.frame(lapply(x, scale))
breaks <- c(1, 50, 100, 200, 500, 1000)

# Fit simple model
fit <- survnet(y = y, x = x, breaks = breaks)
plot(fit$history)
```