| Title: | Gradient-Based Feature Attribution for Survival Neural Networks |
|---|---|
| Description: | This package implements model-specific, gradient-based feature attribution methods for deep survival neural networks, including DeepHit, CoxTime, and DeepSurv. It accompanies the ICML 2025 paper "Gradient-based Explanations for Deep Learning Survival Models". |
| Authors: | Niklas Koenen [aut, cre] (ORCID: <https://orcid.org/0000-0002-4623-8271>), Sophie Hanna Langbein [aut] (ORCID: <https://orcid.org/0000-0001-5629-2055>) |
| Maintainer: | Niklas Koenen |
| License: | MIT + file LICENSE |
| Version: | 0.1.0 |
| Built: | 2026-05-10 08:19:05 UTC |
| Source: | https://github.com/bips-hb/survinng |
This function converts survival attribution results (objects of class `surv_result`) into a data frame; `as.data.table.surv_result()` does the same for a data.table. Both stacked and non-stacked layouts are supported.
```r
## S3 method for class 'surv_result'
as.data.frame(x, ..., stacked = FALSE)

as.data.table.surv_result(x, ..., stacked = FALSE)
```
| Argument | Description |
|---|---|
| `x` | An object of class `surv_result`. |
| `...` | Unused arguments. |
| `stacked` | Logical indicating whether to convert to a stacked data frame, i.e., one in which the attributions are stacked on top of each other. Default is `FALSE`. |
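The stacked layout can be pictured with base R's `stack()`. The sketch below is hypothetical and does not call the package: it assumes the attributions for one instance sit in a wide table with one column per feature and one row per time point, which may not match the internal layout of a `surv_result` object. The attribution values are made up.

```r
# Hypothetical wide layout: one row per time point, one column per feature.
wide <- data.frame(
  time = c(1, 5, 10),
  x1   = c(-0.02, -0.05, -0.07),
  x2   = c( 0.01,  0.03,  0.04),
  x3   = c( 0.00, -0.01, -0.01)
)

# Stacked layout: the feature columns are placed on top of each other, so each
# row holds a single (time, attribution, feature) triple.
feature_cols <- c("x1", "x2", "x3")
long <- cbind(time = wide$time, stack(wide[, feature_cols]))
names(long) <- c("time", "attribution", "feature")
long
```

A stacked table like this is the usual input format for ggplot2, since each feature becomes a level of a single grouping column.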
`explain()` is a generic function that dispatches to the appropriate method based on the class of `model`.
```r
explain(
  model,
  data = NULL,
  model_type = NULL,
  baseline_hazard = NULL,
  labtrans = NULL,
  time_bins = NULL,
  preprocess_fun = NULL,
  postprocess_fun = NULL,
  predict_fun = NULL
)

## S3 method for class 'nn_module'
explain(
  model,
  data,
  model_type,
  baseline_hazard = NULL,
  labtrans = NULL,
  time_bins = NULL,
  preprocess_fun = NULL,
  postprocess_fun = NULL,
  predict_fun = NULL
)

## S3 method for class 'extracted_survivalmodels_coxtime'
explain(model, data, ...)

## S3 method for class 'extracted_survivalmodels_deephit'
explain(model, data, ...)

## S3 method for class 'extracted_survivalmodels_deepsurv'
explain(model, data, ...)
```
| Argument | Description |
|---|---|
| `model` | A model object. |
| `data` | A data frame or matrix containing the data to be explained. |
| `model_type` | A string specifying the type of the survival model. Possible values are `"coxtime"`, `"deephit"`, or `"deepsurv"`. |
| `baseline_hazard` | A data frame containing the baseline hazard, with the two columns `"time"` and `"hazard"`. Only used for `"coxtime"` and `"deepsurv"` models. |
| `labtrans` | A list containing the transformation functions for the time variable, with the two elements `"transform"` and `"transform_inv"`. This argument is highly experimental and not yet fully supported. |
| `time_bins` | A numeric vector specifying the time bins for the `"deephit"` model, e.g., … |
| `preprocess_fun` | A function to preprocess the data before making predictions, e.g., adding a time variable for a `"coxtime"` model. |
| `postprocess_fun` | A function to postprocess the predictions after they are made. This argument is highly experimental; the default values should work. |
| `predict_fun` | A function that can be used to make predictions from the model. If … |
| `...` | Unused arguments. |
An object of class explainer that contains the model, the
data, and the prediction function.
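To show what the model-specific arguments feed into, the following base-R sketch turns raw model outputs into survival curves. It is not the package's prediction code: the baseline-hazard convention (the `"hazard"` column holding increments rather than cumulative values), the risk scores, and the probability masses are all assumptions made for illustration.

```r
# Hypothetical inputs shaped like the documented arguments (values are made up).
baseline_hazard <- data.frame(time   = c(1, 2, 4, 7),
                              hazard = c(0.05, 0.08, 0.10, 0.20))
time_bins <- c(1, 3, 6, 12)

# Cox-type models ("deepsurv", "coxtime"). ASSUMPTION: the "hazard" column
# holds baseline hazard increments h0(t_k) at the listed times.
# `risk` is either one score g(x) (DeepSurv) or one score g(t_k, x) per time
# point (CoxTime, whose network also takes time as an input).
cox_survival <- function(risk, bh) {
  # S(t_k | x) = exp(-sum_{j <= k} h0(t_j) * exp(g))
  exp(-cumsum(bh$hazard * exp(risk)))
}

# Discrete-time model ("deephit"): the network outputs a probability mass
# over time_bins; survival is one minus the cumulative mass.
deephit_survival <- function(pmf) 1 - cumsum(pmf)

s_deepsurv <- cox_survival(risk = 0.5, bh = baseline_hazard)
s_coxtime  <- cox_survival(risk = c(0.4, 0.5, 0.6, 0.7), bh = baseline_hazard)
curve_dh   <- data.frame(time = time_bins,
                         surv = deephit_survival(c(0.10, 0.20, 0.25, 0.15)))
```

Writing the Cox-type case with a per-time risk vector covers both DeepSurv (a single score, recycled) and CoxTime (one score per time point) with the same formula.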
This function extracts model information from a survivalmodels object, i.e., a neural network trained with the survivalmodels package.
```r
extract_model(x, path = NULL, num_basehazard = 200L)

## S3 method for class 'coxtime'
extract_model(x, path = NULL, num_basehazard = 200L)

## S3 method for class 'deephit'
extract_model(x, path = NULL, ...)

## S3 method for class 'deepsurv'
extract_model(x, path = NULL, num_basehazard = 200L)
```
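| Argument | Description |
|---|---|
| `x` | A model fitted with the survivalmodels package, i.e., an object of class `coxtime`, `deephit`, or `deepsurv`. |
| `path` | A string specifying the path to save the extracted model. If … |
| `num_basehazard` | An integer specifying the number of points of the baseline hazard to compute. Default is `200L`. Not used by the `deephit` method. |
| `...` | Unused arguments. |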
Visualize survival predictions, feature attributions, contribution percentages, and force plots for survival attribution results. The last two are only available for the GradSHAP(t) and IntGrad(t) methods.
```r
## S3 method for class 'surv_result'
plot(x, ..., type = "attr")

plot_force(x, num_bars = 10)

plot_pred(x)

plot_attr(x, normalize = "none", add_comp = NULL)

plot_contr(x, aggregate = FALSE)
```
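| Argument | Description |
|---|---|
| `x` | An object of class `surv_result`. |
| `...` | Unused arguments. |
| `type` | Type of plot to generate when using the generic `plot()` method; … Default is `"attr"`. |
| `num_bars` | Number of bars to show in the force plot. Default is `10`. |
| `normalize` | Normalization method for `plot_attr()`; … Default is `"none"`. |
| `add_comp` | Optional vector of comparison curves to add to the attribution plot (…). Default is `NULL`. |
| `aggregate` | Logical; if `TRUE`, contributions are aggregated across instances. Default is `FALSE`. |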
These functions provide a convenient way to visualize the results of survival attribution methods:

- `plot()` is a generic wrapper that dispatches to the appropriate plot type based on the `type` argument.
- `plot_pred()` visualizes survival predictions over time for the selected instances.
- `plot_attr()` displays time-resolved attributions for each instance.

The following functions are only available for GradSHAP(t) and IntGrad(t):

- `plot_contr()` visualizes the relative contribution of features over time, optionally aggregated across instances for global insights.
- `plot_force()` generates force plots showing each feature's effect on the prediction over time.
A ggplot2 object.
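As an illustration of what "relative contribution" can mean, the following base-R sketch converts a hypothetical attribution array (instances by features by time points) into percentage shares. The definition used here (absolute attribution divided by the sum of absolute attributions at each time point, averaged over instances as the analogue of `aggregate = TRUE`) is an assumption and may differ from what `plot_contr()` computes.

```r
# Hypothetical attribution array: 2 instances x 3 features x 4 time points.
set.seed(1)
attribs <- array(rnorm(2 * 3 * 4), dim = c(2, 3, 4),
                 dimnames = list(NULL, c("x1", "x2", "x3"), paste0("t", 1:4)))

# ASSUMED definition: share of |attribution| among all features at each time.
contr_pct <- function(a) {
  abs_a  <- abs(a)
  totals <- apply(abs_a, c(1, 3), sum)        # instance-by-time totals
  sweep(abs_a, c(1, 3), totals, `/`) * 100    # shares in percent
}

pct        <- contr_pct(attribs)              # per-instance shares
pct_global <- apply(pct, c(2, 3), mean)       # aggregate = TRUE analogue
round(pct_global, 1)
```

Averaging shares (rather than raw attributions) keeps each time point summing to 100 percent, so instances with large attributions do not dominate the global view.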
This function prints a summary of the explainer object.
```r
## S3 method for class 'explainer_coxtime'
print(x, ...)

## S3 method for class 'explainer_deepsurv'
print(x, ...)

## S3 method for class 'explainer_deephit'
print(x, ...)
```
| Argument | Description |
|---|---|
| `x` | An object of class `explainer_coxtime`, `explainer_deepsurv`, or `explainer_deephit`. |
| `...` | Additional arguments (not used). |
Print method for extracted pycox survival model
```r
## S3 method for class 'extracted_survivalmodels_coxtime'
print(x, ...)

## S3 method for class 'extracted_survivalmodels_deepsurv'
print(x, ...)

## S3 method for class 'extracted_survivalmodels_deephit'
print(x, ...)
```
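| Argument | Description |
|---|---|
| `x` | An object of class `extracted_survivalmodels_coxtime`, `extracted_survivalmodels_deepsurv`, or `extracted_survivalmodels_deephit`. |
| `...` | Additional arguments (not used). |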
Print function for surv_result objects
```r
## S3 method for class 'surv_result'
print(x, ...)
```
| Argument | Description |
|---|---|
| `x` | An object of class `surv_result`. |
| `...` | Additional arguments (not used). |
This function calculates the gradient of the survival function with respect to the input features and time points for a given instance. In the paper, this is referred to as the "Grad(t)" method. It shows the sensitivity of the survival function to changes in the input features at a specific time point.
```r
surv_grad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 50,
  dtype = "float",
  include_time = FALSE
)

## S3 method for class 'explainer_deepsurv'
surv_grad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 1000,
  dtype = "float",
  ...
)

## S3 method for class 'explainer_coxtime'
surv_grad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 1000,
  dtype = "float",
  include_time = FALSE
)

## S3 method for class 'explainer_deephit'
surv_grad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 1000,
  dtype = "float",
  ...
)
```
| Argument | Description |
|---|---|
| `exp` | An explainer object as returned by `explain()`, i.e., an object of class `explainer_deepsurv`, `explainer_coxtime`, or `explainer_deephit`. |
| `target` | A character string indicating the target output. For … The default is `"survival"`. |
| `instance` | An integer specifying the instance for which the gradient is calculated. It should be between 1 and the number of instances in the dataset. |
| `times_input` | A logical value indicating whether the gradient should be multiplied by the input. In the paper, this variant is referred to as "GxI(t)". Default is `FALSE`. |
| `batch_size` | An integer specifying the batch size for processing. The default for the model-specific methods is 1000. This value is the number of instances per batch, not the final number of rows in the batch. For example, … |
| `dtype` | A character string indicating the data type for the tensors. It can be either … The default is `"float"`. |
| `include_time` | A logical value indicating whether to include the time points in the output. This is only relevant for `"coxtime"` models. |
| `...` | Unused arguments. |
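The distinction between instances and rows matters when inputs are expanded internally. The sketch below is hypothetical bookkeeping, not the package's batching code: it assumes that a CoxTime-style model needs one input row per (instance, time point) pair and that sampling-based methods such as SmoothGrad(t) or IntGrad(t) add one row per sample.

```r
# Hypothetical sizes, chosen only to illustrate the bookkeeping.
n_instances <- 2500   # instances to explain
batch_size  <- 1000   # instances per batch (the documented meaning)
n_times     <- 50     # ASSUMED: one input row per time point (CoxTime-style)
n_samples   <- 10     # ASSUMED: one row per noise/path sample; 1 for surv_grad()

# Split the instance indices into consecutive batches of at most batch_size.
ids     <- seq_len(n_instances)
batches <- split(ids, ceiling(ids / batch_size))

# Rows that would actually pass through the network in each batch.
rows_per_batch <- lengths(batches) * n_times * n_samples
print(unname(rows_per_batch))
```

Under these assumptions a batch of 1000 instances already means 500,000 network rows, which is why `batch_size` is the knob to lower when memory runs out.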
Other Attribution Methods:
surv_gradSHAP(),
surv_intgrad(),
surv_smoothgrad()
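To make the Grad(t) and GxI(t) definitions concrete, here is a small base-R sketch. It does not call the package: the DeepSurv-style toy model, its weights `w`, the baseline cumulative hazard `H0()`, and the helper `grad_fd()` are all hypothetical, and central finite differences stand in for the torch gradients the package computes.

```r
# Hypothetical DeepSurv-style toy model:
# S(t | x) = exp(-H0(t) * exp(f(x))), with risk score f(x) = sum(w * tanh(x)).
w    <- c(0.8, -0.5, 0.3)
H0   <- function(t) 0.1 * t                  # assumed baseline cumulative hazard
surv <- function(x, t) exp(-H0(t) * exp(sum(w * tanh(x))))

# Central finite differences as a stand-in for automatic differentiation.
grad_fd <- function(fun, x, eps = 1e-5) {
  vapply(seq_along(x), function(j) {
    e <- replace(numeric(length(x)), j, eps)
    (fun(x + e) - fun(x - e)) / (2 * eps)
  }, numeric(1))
}

x     <- c(1.0, 0.5, -2.0)                   # the instance to explain
times <- c(1, 5, 10)                         # evaluation time points

# Grad(t): one gradient vector per time point (rows = times, cols = features).
grad_t <- t(sapply(times, function(tt) grad_fd(function(z) surv(z, tt), x)))

# GxI(t) (times_input = TRUE): gradient multiplied elementwise by the input.
gxi_t <- sweep(grad_t, 2, x, `*`)
```

Because survival decreases as the risk score grows, a feature with a positive weight receives a negative gradient at every time point, and the magnitude changes with t through H0(t) and S(t | x).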
This function calculates the GradSHAP values of the survival function with respect to the input features and time points for a given instance. In the paper, this is referred to as the "GradSHAP(t)" method. It is a fast, model-specific method for approximating the Shapley values of a deep survival model.
```r
surv_gradSHAP(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 1000,
  n = 50,
  num_samples = 10,
  data_ref = NULL,
  dtype = "float",
  replace = TRUE,
  include_time = FALSE
)

## S3 method for class 'explainer_deepsurv'
surv_gradSHAP(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 1000,
  n = 50,
  num_samples = 10,
  data_ref = NULL,
  dtype = "float",
  replace = TRUE,
  ...
)

## S3 method for class 'explainer_coxtime'
surv_gradSHAP(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 1000,
  n = 50,
  num_samples = 10,
  data_ref = NULL,
  dtype = "float",
  replace = TRUE,
  include_time = FALSE
)

## S3 method for class 'explainer_deephit'
surv_gradSHAP(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 1000,
  n = 50,
  num_samples = 10,
  data_ref = NULL,
  dtype = "float",
  replace = TRUE,
  ...
)
```
| Argument | Description |
|---|---|
| `exp` | An explainer object as returned by `explain()`, i.e., an object of class `explainer_deepsurv`, `explainer_coxtime`, or `explainer_deephit`. |
| `target` | A character string indicating the target output. For … The default is `"survival"`. |
| `instance` | An integer specifying the instance for which the GradSHAP values are calculated. It should be between 1 and the number of instances in the dataset. |
| `times_input` | A logical value indicating whether the GradSHAP values should be multiplied by the input. Default is `TRUE`. |
| `batch_size` | An integer specifying the batch size for processing. The default is 1000. This value is the number of instances per batch, not the final number of rows in the batch. For example, … |
| `n` | An integer specifying the number of samples used to approximate the integral. The default is 50. |
| `num_samples` | An integer specifying the number of samples drawn from the baseline distribution. The default is 10. |
| `data_ref` | A reference dataset for sampling. If … |
| `dtype` | A character string indicating the data type for the tensors. It can be either … The default is `"float"`. |
| `replace` | A logical value indicating whether to sample from the baseline distribution with replacement. The default is `TRUE`. |
| `include_time` | A logical value indicating whether to also calculate GradSHAP values for each time point. This is only relevant for `"coxtime"` models. |
| `...` | Unused arguments. |

Returns an object of class `surv_result`.
Other Attribution Methods:
surv_grad(),
surv_intgrad(),
surv_smoothgrad()
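GradSHAP(t) can be read as an expected-gradients estimator (the same idea as Captum's GradientShap): gradients are averaged at random points on the straight path from sampled reference inputs to the instance, then multiplied by the difference to the reference. The base-R sketch below uses a hypothetical DeepSurv-style toy model and a finite-difference helper `grad_fd()` in place of the package's torch gradients; how the package combines `n` and `num_samples` is an assumption here (`n` random path points per reference draw).

```r
# Hypothetical DeepSurv-style toy model: S(t | x) = exp(-0.1 t exp(f(x))).
w    <- c(0.8, -0.5, 0.3)
surv <- function(x, t) exp(-0.1 * t * exp(sum(w * tanh(x))))

# Central finite differences as a stand-in for automatic differentiation.
grad_fd <- function(fun, x, eps = 1e-5) {
  vapply(seq_along(x), function(j) {
    e <- replace(numeric(length(x)), j, eps)
    (fun(x + e) - fun(x - e)) / (2 * eps)
  }, numeric(1))
}

# Expected-gradients style GradSHAP(t) sketch for a single time point t.
grad_shap <- function(x, t, data_ref, n = 50, num_samples = 10, replace = TRUE) {
  idx <- sample(nrow(data_ref), num_samples, replace = replace)  # reference draws
  phis <- sapply(idx, function(i) {
    x_ref  <- data_ref[i, ]
    alphas <- runif(n)                       # random points on the path x_ref -> x
    g <- sapply(alphas, function(a) grad_fd(function(z) surv(z, t), x_ref + a * (x - x_ref)))
    rowMeans(g) * (x - x_ref)                # times_input = TRUE
  })
  rowMeans(phis)                             # average over reference draws
}

set.seed(42)
x        <- c(1.0, 0.5, -2.0)
data_ref <- matrix(rnorm(30), ncol = 3)      # hypothetical reference data
phi      <- grad_shap(x, t = 5, data_ref = data_ref)
```

In expectation, these attributions sum to S(t | x) minus the average of S(t | x_ref) over the reference distribution; larger `n` and `num_samples` reduce the Monte Carlo error at the cost of more gradient evaluations.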
This function calculates the integrated gradients of the survival function with respect to the input features and time points for a given instance. In the paper, this is referred to as the "IntGrad(t)" method. It shows the attributions of the input features to the target function with respect to a reference input.
```r
surv_intgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 50,
  n = 10,
  x_ref = NULL,
  dtype = "float",
  include_time = FALSE
)

## S3 method for class 'explainer_deepsurv'
surv_intgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 1000,
  n = 10,
  x_ref = NULL,
  dtype = "float",
  ...
)

## S3 method for class 'explainer_coxtime'
surv_intgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 1000,
  n = 10,
  x_ref = NULL,
  dtype = "float",
  include_time = FALSE
)

## S3 method for class 'explainer_deephit'
surv_intgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = TRUE,
  batch_size = 1000,
  n = 10,
  x_ref = NULL,
  dtype = "float",
  ...
)
```
| Argument | Description |
|---|---|
| `exp` | An explainer object as returned by `explain()`, i.e., an object of class `explainer_deepsurv`, `explainer_coxtime`, or `explainer_deephit`. |
| `target` | A character string indicating the target output. For … The default is `"survival"`. |
| `instance` | An integer specifying the instance for which the integrated gradients are calculated. It should be between 1 and the number of instances in the dataset. |
| `times_input` | A logical value indicating whether the integrated gradients should be multiplied by the input. Default is `TRUE`. |
| `batch_size` | An integer specifying the batch size for processing. The default for the model-specific methods is 1000. This value is the number of instances per batch, not the final number of rows in the batch. For example, … |
| `n` | An integer specifying the number of approximation points for the integral. Default is 10. |
| `x_ref` | A reference input for the integrated gradients. If … |
| `dtype` | A character string indicating the data type for the tensors. It can be either … The default is `"float"`. |
| `include_time` | A logical value indicating whether to include attributions for the time points. This is only relevant for `"coxtime"` models. |
| `...` | Unused arguments. |
Other Attribution Methods:
surv_gradSHAP(),
surv_grad(),
surv_smoothgrad()
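A minimal base-R sketch of IntGrad(t) for one time point, using a midpoint Riemann sum with `n` steps along the straight path from the reference to the instance. The toy model and `grad_fd()` are hypothetical stand-ins for a real network and torch gradients, and the all-zero reference used when `x_ref` is omitted is an assumption of this sketch, not necessarily the package default.

```r
# Hypothetical DeepSurv-style toy model: S(t | x) = exp(-0.1 t exp(f(x))).
w    <- c(0.8, -0.5, 0.3)
surv <- function(x, t) exp(-0.1 * t * exp(sum(w * tanh(x))))

# Central finite differences as a stand-in for automatic differentiation.
grad_fd <- function(fun, x, eps = 1e-5) {
  vapply(seq_along(x), function(j) {
    e <- replace(numeric(length(x)), j, eps)
    (fun(x + e) - fun(x - e)) / (2 * eps)
  }, numeric(1))
}

# IntGrad(t): average gradient along the path x_ref -> x (midpoint rule),
# multiplied by (x - x_ref). ASSUMED default reference: all zeros.
int_grad <- function(x, t, x_ref = rep(0, length(x)), n = 10) {
  alphas <- (seq_len(n) - 0.5) / n
  g <- sapply(alphas, function(a) grad_fd(function(z) surv(z, t), x_ref + a * (x - x_ref)))
  rowMeans(g) * (x - x_ref)
}

x   <- c(1.0, 0.5, -2.0)
phi <- int_grad(x, t = 5, n = 100)   # more steps than the default 10

# Completeness: attributions should add up to S(t | x) - S(t | x_ref).
c(sum_attr = sum(phi), delta_surv = surv(x, 5) - surv(rep(0, 3), 5))
```

The completeness check is a cheap way to choose `n`: if the two printed numbers disagree noticeably, the integral approximation needs more steps.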
This function calculates the SmoothGrad values of the survival function with respect to the input features and time points for a given instance. In the paper, this is referred to as the "SG(t)" method. It shows the smoothed sensitivity of the survival function to changes in the input features at a specific time point.
```r
surv_smoothgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 50,
  n = 10,
  noise_level = 0.1,
  dtype = "float",
  include_time = FALSE
)

## S3 method for class 'explainer_deepsurv'
surv_smoothgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 1000,
  n = 10,
  noise_level = 0.1,
  dtype = "float",
  ...
)

## S3 method for class 'explainer_coxtime'
surv_smoothgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 1000,
  n = 10,
  noise_level = 0.1,
  dtype = "float",
  include_time = FALSE
)

## S3 method for class 'explainer_deephit'
surv_smoothgrad(
  exp,
  target = "survival",
  instance = 1,
  times_input = FALSE,
  batch_size = 1000,
  n = 10,
  noise_level = 0.1,
  dtype = "float",
  ...
)
```
| Argument | Description |
|---|---|
| `exp` | An explainer object as returned by `explain()`, i.e., an object of class `explainer_deepsurv`, `explainer_coxtime`, or `explainer_deephit`. |
| `target` | A character string indicating the target output. For … The default is `"survival"`. |
| `instance` | An integer specifying the instance for which SmoothGrad is calculated. It should be between 1 and the number of instances in the dataset. |
| `times_input` | A logical value indicating whether the SmoothGrad values should be multiplied by the input. In the paper, this variant is referred to as … Default is `FALSE`. |
| `batch_size` | An integer specifying the batch size for processing. The default for the model-specific methods is 1000. This value is the number of instances per batch, not the final number of rows in the batch. For example, … |
| `n` | An integer specifying the number of noisy samples of the input features. The default is 10. |
| `noise_level` | A numeric value specifying the level of Gaussian noise added to the input features. The default is 0.1. |
| `dtype` | A character string indicating the data type for the tensors. It can be either … The default is `"float"`. |
| `include_time` | A logical value indicating whether to also calculate the gradients with respect to time. This is only relevant for `"coxtime"` models. |
| `...` | Unused arguments. |
Other Attribution Methods:
surv_gradSHAP(),
surv_grad(),
surv_intgrad()
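SmoothGrad(t) averages plain gradients over noisy copies of the instance. The base-R sketch below shows that idea for one time point; the toy model and `grad_fd()` are hypothetical, and treating `noise_level` directly as the standard deviation of the Gaussian noise is an assumption (the package may scale it, for example, by the feature range).

```r
# Hypothetical DeepSurv-style toy model: S(t | x) = exp(-0.1 t exp(f(x))).
w    <- c(0.8, -0.5, 0.3)
surv <- function(x, t) exp(-0.1 * t * exp(sum(w * tanh(x))))

# Central finite differences as a stand-in for automatic differentiation.
grad_fd <- function(fun, x, eps = 1e-5) {
  vapply(seq_along(x), function(j) {
    e <- replace(numeric(length(x)), j, eps)
    (fun(x + e) - fun(x - e)) / (2 * eps)
  }, numeric(1))
}

# SmoothGrad(t): mean gradient over n noisy copies of x.
smooth_grad <- function(x, t, n = 10, noise_level = 0.1, times_input = FALSE) {
  g <- sapply(seq_len(n), function(i) {
    x_noisy <- x + rnorm(length(x), sd = noise_level)  # ASSUMED: sd = noise_level
    grad_fd(function(z) surv(z, t), x_noisy)
  })
  sg <- rowMeans(g)
  if (times_input) sg * x else sg   # times_input = TRUE multiplies by the input
}

set.seed(1)
x  <- c(1.0, 0.5, -2.0)
sg <- smooth_grad(x, t = 5, n = 10, noise_level = 0.1)
```

With `noise_level = 0` the sketch reduces exactly to the plain Grad(t) gradient, which makes the role of the noise easy to isolate when comparing methods.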