Enhanced survival prediction using explainable artificial intelligence in heart transplantation

The most limiting factor in heart transplantation is the lack of donor organs. With enhanced prediction of outcome, it may be possible to increase the life-years gained from the organs that become available. Applications of machine learning to tabular data, typical of clinical decision support, pose the practical question of interpretation, which has technical and potential ethical implications. In particular, there is an issue of principle about the predictability of complex data and whether this is inherent in the data or strongly dependent on the choice of machine learning model, leading to the so-called accuracy-interpretability trade-off. We model 1-year mortality in heart transplantation data with a self-explaining neural network, which is benchmarked against a deep learning model on the same development data, in an external validation study with two data sets: (1) UNOS transplants in 2017–2018 (n = 4750), for which the self-explaining and deep learning models are comparable in their AUROC, 0.628 [0.602, 0.654] compared with 0.635 [0.609, 0.662]; and (2) Scandinavian transplants during 1997–2018 (n = 2293), showing good calibration with AUROCs of 0.626 [0.588, 0.665] and 0.634 [0.570, 0.698], respectively, with and without missing data (n = 982). This shows that, for tabular data, predictive models can be transparent and capture important nonlinearities while retaining full predictive performance.


The Partial Response Network (PRN) and PRN-Lasso model
The PRN is built in a stepwise manner as follows.

Pre-trained MLP
The first step is to train an MLP to estimate the posterior distribution of membership of class C, P(C|x). The predictions of this model are then transformed by inverting the sigmoid function at the output node, to obtain logit(P(C|x)), which is the logarithm of the odds of the posterior probability.
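As a concrete illustration of this step, the following is a minimal sketch in Python using scikit-learn; it assumes standardised training arrays X_train and y_train (median mapped to zero), and the architecture and hyperparameters are placeholders, not those used in the study.

```python
import numpy as np
from sklearn.neural_network import MLPClassifier

# Step 1: fit an MLP to estimate P(C|x). Architecture and
# hyperparameters here are illustrative only.
mlp = MLPClassifier(hidden_layer_sizes=(10,), activation="logistic",
                    max_iter=2000, random_state=0)
mlp.fit(X_train, y_train)  # X_train standardised so the median maps to 0

def mlp_logit(X):
    """Invert the output sigmoid to recover logit(P(C|x)),
    i.e. the log-odds of the positive class."""
    p = mlp.predict_proba(X)[:, 1]
    p = np.clip(p, 1e-12, 1 - 1e-12)  # guard against log(0)
    return np.log(p / (1.0 - p))
```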

ANOVA decomposition
The key step in the proposed method is to express the multivariate function consisting of the model predictions, transformed into the logit(P(C|x)), by application of the ANOVA decomposition shown in (1), which is anchored at the overall median of the data:

$$f(\mathbf{x}) = \varphi_0 + \sum_{i=1}^{P} \varphi_i(x_i) + \sum_{i<j} \varphi_{ij}(x_i, x_j) + \cdots + \varphi_{1 \ldots P}(x_1, \ldots, x_P) \qquad (1)$$

where $f(\mathbf{x}) = \mathrm{logit}(P(C|\mathbf{x}))$. The terms in the summation are given by:

$$\varphi_0 = f(\mathbf{0}) \qquad (2)$$

$$\varphi_i(x_i) = f(0, \ldots, x_i, \ldots, 0) - \varphi_0 \qquad (3)$$

$$\varphi_{ij}(x_i, x_j) = f(0, \ldots, x_i, \ldots, x_j, \ldots, 0) - \varphi_i(x_i) - \varphi_j(x_j) - \varphi_0 \qquad (4)$$

The general form of the terms in (1) is a recursive function of nested subsets of the covariate indices $\{i_1, \ldots, i_k\}$. Recall that the standardisation of the data is such that $x = 0 \mapsto \mathrm{median}(x)$, which sets the overall median of the data as the reference point for the calculation of the logit function, that is to say, the point at which all of the variables take the value 0 and where, by construction, all of the partial responses also vanish, leaving only the constant term $\varphi_0$.
Two things to note: expression (1) is not an equation but an identity, involving a finite number of terms, $2^P$, where P is the number of input variables in the MLP; and the component functions are mutually orthogonal with respect to the chosen measure. While the main effects, which are the univariate terms in (3), are numerically identical to the partial dependence plots often used to visualise the operation of non-linear multivariate models, the bivariate effects in (4) are not simply cuts through the model response function, but have the univariate terms removed, so that

$$\varphi_{ij}(x_i, 0) = 0 = \varphi_{ij}(0, x_j).$$

This denotes the orthogonality property under the Dirac measure, which is a special case of the general formulation of the terms of the ANOVA decomposition for a given function f(x), namely

$$\varphi_S(x_S) = \int f(\mathbf{x}) \, d\mu(\mathbf{x}, \mathbf{x}') - \sum_{T \subset S} \varphi_T(x_T),$$

where S indicates the set of variables in term $\varphi_S$ and $\mu(\mathbf{x}, \mathbf{x}')$ is the chosen measure. In the case of the Dirac measure, $d\mu(\mathbf{x}, \mathbf{x}') = \delta(\mathbf{x} - \mathbf{x}_a)$, meaning that the function f(x) is evaluated at the anchor point $\mathbf{x} = \mathbf{x}_a$.
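The sketch below makes the Dirac-anchored terms (2)–(4) concrete for any black-box logit function, such as mlp_logit above; the helper name and array conventions are our own illustrative choices.

```python
import numpy as np

def partial_responses(f, X, i, j=None):
    """Anchored ANOVA terms of a logit function f, per (2)-(4).
    The anchor is the zero vector (the overall median after
    standardisation). X is an (n, P) array of evaluation points."""
    n, P = X.shape
    anchor = np.zeros((n, P))
    phi0 = f(anchor)                      # (2): constant term f(0)
    Xi = anchor.copy(); Xi[:, i] = X[:, i]
    phi_i = f(Xi) - phi0                  # (3): main effect of x_i
    if j is None:
        return phi_i
    Xj = anchor.copy(); Xj[:, j] = X[:, j]
    phi_j = f(Xj) - phi0
    Xij = Xi.copy(); Xij[:, j] = X[:, j]
    # (4): bivariate effect with both main effects removed, so it
    # vanishes whenever x_i or x_j sits at the anchor value 0.
    return f(Xij) - phi_i - phi_j - phi0
```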

Lasso optimisation
The third step in the proposed method is to retain, from the ANOVA decomposition of the logit of the pre-trained classifier, only the terms that involve just one or two variables. This amounts to truncating the decomposition (1) followed by a re-calibration.
The terms defined by (3) and (4) become the input variables in a linear model, namely logistic regression regularised by the Least Absolute Shrinkage and Selection Operator (Lasso)10. This is a powerful feature selection method that scales well to a large number of variables. It uses L1 regularisation to carry out feature selection, shrinking the coefficients and setting the least informative ones exactly to zero, resulting in a sparse model.
This optimisation is computationally feasible because there are in total $P(P+1)/2 \ll 2^P$ univariate and bivariate effects, a number that is quadratic in the input dimension of the original data rather than exponential, as is the full number of terms in (1).
The resulting model has a logit function consisting of partial responses added together. This is the equivalent of the score index in logistic regression, which is the sum $\beta_0 + \sum_{i=1}^{P} \beta_i x_i$ taken over the inputs $i = 1, \ldots, P$. In our case, the input values have been replaced by non-linear functions of the variables, each function involving one or two variables at a time in order to ensure interpretability. As these functions sum together to make the model prediction, in a truncated ANOVA decomposition (1), the variables directly relate to the predicted outcome and the final model can be represented as a nomogram.
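A sketch of this step, under the same assumptions as above (the illustrative helper partial_responses and a placeholder regularisation strength C), assembles the $P(P+1)/2$ candidate terms and fits the L1-penalised logistic regression:

```python
from itertools import combinations
import numpy as np
from sklearn.linear_model import LogisticRegression

P = X_train.shape[1]
feats, names = [], []
for i in range(P):                       # P univariate terms, eq. (3)
    feats.append(partial_responses(mlp_logit, X_train, i))
    names.append(f"phi_{i}")
for i, j in combinations(range(P), 2):   # P(P-1)/2 bivariate terms, eq. (4)
    feats.append(partial_responses(mlp_logit, X_train, i, j))
    names.append(f"phi_{i}_{j}")
F = np.column_stack(feats)               # shape (n, P*(P+1)/2)

# The L1 penalty drives the coefficients of uninformative terms to
# zero, selecting a sparse, re-calibrated subset of partial responses.
lasso = LogisticRegression(penalty="l1", solver="liblinear", C=0.1)
lasso.fit(F, y_train)
selected = [n for n, c in zip(names, lasso.coef_[0]) if c != 0.0]
```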

The PRN model
The application of the Lasso to generate a sparse model has the effect of removing the least informative input variables, improving the signal-to-noise ratio of the truncated ANOVA decomposition compared with the multivariate function defined by the logit of the predictions of the original MLP. It is now possible to refine the form of the partial responses by mapping the Lasso model into the form of a GANN, initialised to exactly replicate the output of the Lasso model. This mapping enables training to continue by back-propagation, fully optimising the partial responses.
This new network is shown in Fig. 6. It implements univariate and bivariate functions with a modular structure. The weight mappings that reproduce the Lasso outputs are given below, following the notation in Fig. 6, i.e. hidden node weights denoted by $w_{ij}$ and output weights by $v_j$:
i. Univariate partial response corresponding to input $X_i$
This is shown for input $X_1$. The hidden layer weights $w_{1j}$ connected to node $X_1$ are the same as in the original MLP, but the weights and bias to the output node need to be adjusted so that the module computes the partial response scaled by its Lasso coefficient, $\beta_i \varphi_i(X_i)$: the output weights are scaled as $v_j \rightarrow \beta_i v_j$, and the constant $-\beta_i \sum_j v_j \sigma(b_j)$, where $b_j$ denotes the bias of hidden node j and $\sigma$ its activation, is absorbed into the output node bias.

ii. Bivariate partial response for input pair $\{X_k, X_l\}$
This is illustrated in Fig. 6 for inputs $X_2$ and $X_3$. This time, in order to replicate the partial response multiplied by the Lasso coefficient, $\beta_{kl} \varphi_{kl}(X_k, X_l)$, it is necessary to add three elements to the structure, namely a univariate partial response for each of the inputs involved and a coupled network that both inputs feed into together. We use the generic input indices k and l to avoid confusion with the hidden node index j. The hidden layer weights once again remain unchanged from the original MLP. The output layer weights for the network structure representing the univariate term associated with input k (and similarly for input l) are adjusted by $v_j \rightarrow -\beta_{kl} v_j$, whereas the weights for the coupled network are changed according to $v_j \rightarrow \beta_{kl} v_j$, with the residual constant terms again absorbed into the output node bias.
iii. Finally, the intercept of the logistic Lasso, $\beta_0$, is also added to the output node bias $v_0$.
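The sketch below spells out such an initialisation for a single-hidden-layer MLP with logit $f(\mathbf{x}) = \sum_j v_j \sigma(\sum_i w_{ij} x_i + b_j) + v_0$; the single hidden layer, function names and data structures are simplifying assumptions for illustration, and note that $v_0$ cancels in every partial response, so it does not enter the mapping.

```python
import numpy as np

def sigm(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_gann(W, b, v, beta_uni, beta_bi, beta0):
    """Initialise the GANN modules so that their summed output replicates
    the Lasso model exactly. W: (P, H) hidden weights, b: (H,) hidden
    biases, v: (H,) output weights of the pre-trained MLP. beta_uni maps
    input index -> Lasso coefficient, beta_bi maps (k, l) -> coefficient,
    beta0 is the Lasso intercept (added to the output bias, item iii)."""
    modules, out_bias = [], beta0
    base = v @ sigm(b)  # hidden-layer contribution of f at the anchor x = 0
    for i, beta in beta_uni.items():
        # Univariate module: output weights scaled by beta_i (item i);
        # the constant -beta_i * base is folded into the output bias.
        modules.append(("uni", i, W[i], b, beta * v))
        out_bias -= beta * base
    for (k, l), beta in beta_bi.items():
        # Bivariate module: a coupled net plus two negated univariate
        # nets (item ii); the residual constant is +beta_kl * base.
        modules.append(("bi", (k, l), W[[k, l]], b, beta * v))
        modules.append(("uni", k, W[k], b, -beta * v))
        modules.append(("uni", l, W[l], b, -beta * v))
        out_bias += beta * base
    return modules, out_bias

def gann_logit(modules, out_bias, X):
    """Forward pass: the module outputs sum to the Lasso logit."""
    z = np.full(X.shape[0], out_bias, dtype=float)
    for kind, idx, Wm, bm, vm in modules:
        if kind == "uni":
            z += sigm(np.outer(X[:, idx], Wm) + bm) @ vm
        else:
            z += sigm(X[:, list(idx)] @ Wm + bm) @ vm
    return z
```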
An interesting property of this mapping is that the module for a bivariate term involves a combination of univariate and bivariate functions. This means that, following further training, what started as an interaction term may resolve into two independent univariate effects, even if these were not selected by the Lasso.

PRN-Lasso
Following optimisation of the PRN, the orthogonality property of the partial response terms will have been broken. It is possible to retrieve this property by repeating steps 2–3, the ANOVA decomposition and Lasso re-calibration, but applied to the outputs of the PRN rather than those of the original MLP. This results in a set of refined partial responses, suitably re-calibrated by the Lasso, which constitutes the PRN-Lasso model.
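Reusing the illustrative helpers from the earlier sketches, the PRN-Lasso step then amounts to swapping the MLP logit for the PRN logit and refitting the Lasso; prn_logit below is a placeholder name for the trained PRN's score function.

```python
# Steps 2-3 repeated on the trained PRN: decompose its logit and
# re-calibrate with the Lasso, restoring orthogonality of the terms.
def prn_logit(X):
    return gann_logit(modules, out_bias, X)  # score of the trained PRN

F_prn = np.column_stack(
    [partial_responses(prn_logit, X_train, i) for i in range(P)]
    + [partial_responses(prn_logit, X_train, i, j)
       for i, j in combinations(range(P), 2)])
prn_lasso = LogisticRegression(penalty="l1", solver="liblinear", C=0.1)
prn_lasso.fit(F_prn, y_train)  # the PRN-Lasso model
```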
In all of the figures, partial responses derived from the MLP are shown as dashed lines and those from the PRN as solid lines.