Main

Deep learning is a successful strategy in which a highly parameterized model makes human-like predictions across many fields1,2,3,4. Yet challenges in both interpretation and generalization often prevent deep learning from being used in practice5,6. Deep-learned models and their specific prediction mechanisms are difficult to assess directly because of the large collection of model parameters. Inspection methods such as activation or saliency maps7,8 highlight only the results for a single prediction and have their own limitations9. Likewise, influence estimation techniques10 often produce a ranked list of samples, which tends to be most useful for understanding issues retrospectively, after they have been identified. In comparison, global data visualizations such as tSNE11 and UMAP12,13 offer the power to inspect the global space of predictions among large collections of data. In principle, these methods offer the ability to prospectively identify problematic data regions; however, the dimension reduction inherent to these methods may distort properties of the data.

Topological data analysis, on the other hand, excels at distilling representation-invariant information14,15,16,17 because it seeks to simplify the shape of data in its ambient space without reducing its dimension. Topological data analysis (TDA) of complex predictive models such as deep learning remains in its infancy18,19,20,21,22. Existing research focuses on assessing the topological properties of the network weights, assessing the topology of the features used by the network, initializing network weights with topologically consistent operators (GENEOs), or adding topological features to predictions. Our approach instead seeks to assess the topology of the neural network embeddings, the representations of the data, and how they interact with the predictions. By way of an anthropomorphic analogy, we seek to simplify the topological lens with which the neural network sees the data for predictions. Although we say deep learning, our methods are compatible with any mechanism that outputs a vector of class probability values, as discussed in the Supplementary Methods, including more classic techniques such as support vector machines or logistic regression.

Our GTDA method

We construct a Reeb network to assess the prediction landscape of a neural-network-like prediction method. Reeb networks are discretizations of topological structures called Reeb spaces, which generalize Reeb graphs17,23. An example of the differences among these concepts is illustrated in Extended Data Fig. 1 with further discussion in Supplementary Section 1.6. Reeb networks seek to simplify the data while respecting topology. We design a recursive splitting and merging procedure called graph-based topological data analysis (GTDA) to simplify the data.

Our GTDA method builds on the mapper algorithm15. Mapper, itself, builds a discrete approximation of a Reeb graph or Reeb space (see Supplementary Section 1.6 and Extended Data Fig. 1). It begins with a set of data points (x1, …, xn), along with a single or multivalued function sampled at each data point. The set of all these values {f1, …, fn} samples a map \(f:X\to {{\mathbb{R}}}^{k}\) on a topological space X. The map f is called a filter or lens. The idea is that when f is single-valued, a Reeb graph shows a quotient topology of X with respect to f and mapper discretizes this Reeb graph using the sampled values of f on points x1, …, xn. Algorithmically, mapper consists of the steps:

  1. Sort the values fi and split them into overlapping bins B1, …, Br of the same size.

  2. For each bin of values Bj, let Sj denote the set of data points with those values and cluster the data points in each Sj independently (that is, we run a clustering algorithm on each Sj as if it were the entire dataset).

  3. Create a node in the Reeb graph for each cluster found in the previous step.

  4. Connect the nodes of the Reeb graph if the clusters they represent share a common data point.

The resulting graph is a discrete approximation of the Reeb graph and represents a compressed view of the shape underlying the original dataset.
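For concreteness, the following is a minimal sketch of these four steps for a single-valued lens on point-cloud data. The equal-count bin construction, overlap fraction and clustering routine (DBSCAN from scikit-learn) are illustrative assumptions rather than the exact choices of any particular mapper implementation.

```python
import numpy as np
from sklearn.cluster import DBSCAN

def mapper_1d(X, f, n_bins=10, overlap=0.25):
    """Minimal mapper sketch for a single-valued lens f sampled on points X."""
    order = np.argsort(f)                       # step 1: sort the lens values
    size = int(np.ceil(len(f) / n_bins))
    pad = int(overlap * size)                   # extend each bin so neighbouring bins overlap
    members = []
    for j in range(n_bins):
        a, b = max(0, j * size - pad), min(len(f), (j + 1) * size + pad)
        idx = order[a:b]
        if len(idx) == 0:
            continue
        # step 2: cluster the points in this bin as if they were the entire dataset
        labels = DBSCAN(eps=0.5, min_samples=3).fit_predict(X[idx])
        for c in set(labels) - {-1}:
            members.append(set(idx[labels == c]))   # step 3: one Reeb node per cluster
    # step 4: connect Reeb nodes whose clusters share a data point
    edges = [(u, v) for u in range(len(members)) for v in range(u + 1, len(members))
             if members[u] & members[v]]
    return members, edges
```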

The input data for mapper is usually a point cloud in a high-dimensional space where the point coordinates are used only in the clustering step. In our methodology, we are interested in datasets that are even more general. Graph inputs provide this generality. Datasets not in graph format such as images or DNA sequences can be easily transformed into graphs by first extracting intermediate outputs of the model as embeddings and then building a nearest-neighbour graph from the embedding matrix. The resulting graph then facilitates easy clustering: for each subset of points, we extract the subgraph induced by those points and then use a parameter-free connected-components analysis to generate clusters. Our method could also work with point-cloud data and clustering directly through standard relationships between graph-based algorithms and point-cloud-based algorithms. We focus on the graph-based approach both for simplicity and because we found it the most helpful for these prediction applications.
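As a hedged sketch of this graph-based setup, the snippet below builds a symmetric k-nearest-neighbour graph from an embedding matrix and clusters any subset of points by taking connected components of the induced subgraph; the neighbour count and helper names are our own illustrative choices.

```python
import numpy as np
import scipy.sparse.csgraph as csgraph
from sklearn.neighbors import kneighbors_graph

def build_knn_graph(embeddings, k=5):
    """Symmetric k-nearest-neighbour graph over model embeddings."""
    A = kneighbors_graph(embeddings, k, mode='connectivity')
    return ((A + A.T) > 0).astype(float)        # make edges undirected

def cluster_by_components(A, subset):
    """Parameter-free clustering of a subset of points: connected components
    of the subgraph induced on those points."""
    sub = A[subset][:, subset]
    n_comp, labels = csgraph.connected_components(sub, directed=False)
    return [np.asarray(subset)[labels == c] for c in range(n_comp)]
```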

The GTDA method therefore begins with a graph representing relationships among data points and a set of values over each node called lenses (Fig. 1a,b); the terminology of lenses arises from a work by Lum and colleagues17. In the applications we consider, the lenses we use are the prediction matrix of a neural network model where Pij is the probability that sample i belongs to class j. Graph-based TDA uses a recursive splitting strategy to build the bins in the multidimensional space (Fig. 1c), instead of tensor product bin constructions as in multidimensional generalizations of mapper. Detailed pseudocode for this procedure can be found in Supplementary Algorithm 1. An animation of the method can be found in Supplementary Video 1. The fundamental idea is that the recursive splitting starts with the set of connected components in the input graph. This is a set of sets: \({\mathbb{S}}\). In the key recursive step, the method takes a set \({{\mathbb{S}}}_{i}\) from \({\mathbb{S}}\) and splits it into new (possibly) overlapping sets \({{\mathbb{T}}}_{1},\ldots ,{{\mathbb{T}}}_{h}\) on the basis of the lens with the maximum difference in values on \({{\mathbb{S}}}_{i}\), ensuring that each \({{\mathbb{T}}}_{i}\) is a single connected component in the graph. Each \({{\mathbb{T}}}_{i}\) is then added back to \({\mathbb{S}}\) if it is large enough (that is, it has more than K vertices) and there exists a lens with a maximum difference larger than d; otherwise, \({{\mathbb{T}}}_{i}\) goes into the set of finalized sets \({\mathbb{F}}\).
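A highly simplified sketch of this recursive splitting loop is shown below, assuming a split around the median of the chosen lens and a fixed overlap fraction; the full procedure, including its handling of overlap and the later merging steps, is given in Supplementary Algorithm 1. The parameter names K and d follow the text above, while everything else is an illustrative assumption.

```python
import numpy as np
import scipy.sparse.csgraph as csgraph

def gtda_split(A, lenses, K=50, d=0.1, overlap=0.05):
    """Simplified sketch of the GTDA recursive split (no merging steps).
    A: sparse adjacency matrix; lenses: (n_points, n_lenses) array."""
    _, comp = csgraph.connected_components(A, directed=False)
    S = [np.where(comp == c)[0] for c in np.unique(comp)]   # start from components
    F = []                                                   # finalized sets
    while S:
        Si = S.pop()
        spread = lenses[Si].max(axis=0) - lenses[Si].min(axis=0)
        if len(Si) <= K or spread.max() <= d:
            F.append(Si)                                     # small or flat: finalize
            continue
        l = int(spread.argmax())                             # lens with largest range on Si
        vals = lenses[Si, l]
        mid = np.median(vals)
        pad = overlap * spread[l]                            # overlapping halves around the median
        parts = [Si[vals <= mid + pad], Si[vals >= mid - pad]]
        if any(len(p) == len(Si) for p in parts):            # split made no progress: finalize
            F.append(Si)
            continue
        for part in parts:
            sub = A[part][:, part]
            ncomp, labels = csgraph.connected_components(sub, directed=False)
            for c in range(ncomp):
                S.append(part[labels == c])                  # each piece is a connected component
    return F
```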

Fig. 1: Overview of the GTDA method.

a,b, The GTDA construction of a Reeb network starts with an input graph (a) and a set of lenses that assign values to each node of the graph (b), where the values are indicated by the node colour. c, A Reeb network is a simplification built from overlapping subgroups or clusters in the original data with similar values for the lenses—GTDA builds these using a recursive splitting procedure. At each recursive step, a single lens is chosen and the data are split into parts based on the node values in that lens. The split is done so that nodes around the split boundary are shared between the parts. This continues until only small groups remain. d, These subgroups are assembled into a Reeb network by simplifying each subgroup to a single Reeb node and connecting Reeb nodes if they share any nodes from the overlapped splits. e, The GTDA method further combines and connects small and isolated Reeb nodes to produce the GTDA Reeb network from the graph and lenses.

After the graph has been split into overlapping subgroups and we have the final set of sets, \({\mathbb{F}}\), the initial Reeb network simplifies each subgroup into a single Reeb network node and connects these simplified nodes if they share any data points (Fig. 1d). This connection strategy may leave Reeb nodes isolated, which is not helpful for understanding predictions. We reduce this isolation by adding edges from a minimum spanning tree (Fig. 1e and Supplementary Algorithms 2 and 3) on the basis of the potential for overlap from alternative splits caused by the lenses. Alongside building the Reeb net, we then take two final merging steps: the first is to merge sets in \({\mathbb{F}}\) if they are too small (Supplementary Algorithm 2); the second is to add edges to the Reeb net to promote more connectivity (Supplementary Algorithm 3). In total, there are seven user-chosen parameters that control the method and these final merging steps, which are described in Extended Data Table 1.
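A minimal sketch of this initial assembly step, before the merging and extra-connectivity passes of Supplementary Algorithms 2 and 3, might look like the following; the networkx representation is our own choice for illustration.

```python
import networkx as nx

def assemble_reeb_network(F):
    """Build the initial Reeb network: one node per finalized set,
    an edge between two Reeb nodes when their sets share a data point."""
    R = nx.Graph()
    sets = [set(map(int, s)) for s in F]
    for i, s in enumerate(sets):
        R.add_node(i, size=len(s), members=s)
    for i in range(len(sets)):
        for j in range(i + 1, len(sets)):
            shared = sets[i] & sets[j]
            if shared:
                R.add_edge(i, j, weight=len(shared))
    return R
```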

Computing a Reeb network with GTDA for a complex prediction function or deep-learning method offers a number of opportunities to inspect the predictions (Fig. 2). In this example, GTDA offers more detail at the interface between prediction classes than what is possible with existing methods such as Mapper.

Fig. 2: Exploring prediction class interfaces with GTDA.

a, Consider a prediction scenario with three classes in a Swiss Roll structure and an underlying graph where graph neural network predictions show reasonable accuracy (0.88). The result of the neural network model is a set of three functions over the nodes of the graph that give the probability of prediction for each class, which we call lenses. b, The proposed GTDA method produces a simplified topological map of these lenses along with the graph structure, that is, a Reeb network. Each node in the Reeb network maps to a small cluster of points with similar lens values. Nodes are coloured by the fraction of points in each predicted class. The map is disconnected, and each connected piece maps to a limited piece of the original data, simplifying and specifically focusing inspection. c, This specificity enables exploration of the interface between the orange and purple classes, showing regions where training and validation data points might suggest alternative predictions. d, Results from the existing Mapper algorithm for TDA lack this boundary because they contain too many disconnected, isolated pieces.

To apply GTDA to prediction analysis, there must be a large set of data points with unknown labels beyond those used for training and validating the prediction model; this is common when gathering data is easy. There must be known relationships among all data points such as: (1) a given graph of relationships among all points (used in Fig. 2a); (2) a nearest-neighbour computation to create such a graph (used when analysing Enformer); or (3) a domain-relevant means of clustering related points. All of our examples use (1) and (2). We also need a real-valued guide to each prediction or predicted class, such as the output from the last layer of a neural network (Fig. 2a). The prediction from this layer provides the lenses. We found it helpful to first smooth the lens values over the relationship graph, using five or ten steps of an iterative, diffusion-like smoothing procedure, to avoid sharp gradients. Furthermore, there are two main parameters: the maximum size of a Reeb node or cluster, and the amount of overlap in Reeb nodes. The other parameters are less influential (see Supplementary Table 1 for a full list); useful results arise from a wide range of parameters (see Supplementary Section 7 for further discussion of parameter sensitivity).
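As a hedged illustration of such smoothing, the sketch below repeatedly averages each node's lens values with those of its graph neighbours; the mixing parameter and the exact update rule are assumptions for illustration, and the procedure actually used is described in the Supplementary Information.

```python
import numpy as np
import scipy.sparse as sp

def smooth_lenses(A, lenses, steps=5, alpha=0.5):
    """Diffusion-style smoothing of lens values over the relationship graph.
    A: sparse adjacency matrix; lenses: (n_points, n_lenses), e.g. class probabilities."""
    deg = np.asarray(A.sum(axis=1)).ravel()
    deg[deg == 0] = 1.0                              # avoid division by zero for isolated nodes
    P = sp.diags(1.0 / deg) @ A                      # row-stochastic transition matrix
    L = lenses.copy()
    for _ in range(steps):
        L = (1 - alpha) * lenses + alpha * (P @ L)   # mix original values with neighbour averages
    return L
```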

Constructing a Reeb net with GTDA is a scalable operation. Analysing the Enformer model of gene expression prediction below takes about 30 s, whereas running the Enformer model itself takes hours to generate the necessary data. Analysing 1.3 million images in ImageNet24 with 2,000 lenses for 1,000 classes in a comparison of ResNet25 and AlexNet26 takes 7.24 h (see Extended Data Table 2).

Understanding malignant gene mutation predictions

The Enformer model1 is a transformer-based model27 designed to estimate gene expression on the basis of DNA. It works by mapping a DNA sequence to an estimate of the expression level of that piece of DNA in each of 5,313 genomic tracks. Although Enformer has excellent predictive results, it remains a highly parameterized black box. Our GTDA methodology allows us to assess the topological landscape of the Enformer embeddings when they are used to predict harmful mutations of the BRCA1 gene in Homo sapiens (Fig. 3).

Fig. 3: To apply GTDA to study the Enformer model, we adapt the pipeline proposed by Avsec and colleagues1 to use Enformer to study harmful gene variants.

a, For each DNA variant of BRCA1 from ClinVar34, we run Enformer to generate the difference in expression levels in each of 5,313 genomic tracks. b,c, These differences are assembled into a 5,313 × 23,376 matrix of data (b) that we split into a 50/50 training and testing set for logistic regression against ClinVar’s evidence of harm (c). d, Four lenses are input into GTDA: two prediction probabilities from logistic regression and the two dominant vectors from Principal Component Analysis (PCA), along with a five-nearest-neighbours graph of a 128-dimensional reduction via PCA. e, The GTDA result shows 105 individual connected components placed on the basis of the mean of all median DNA variant starting positions for each Reeb net node.
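A rough sketch of the pipeline in Fig. 3b–d, assuming the variant-by-track difference matrix is available as a samples-by-features array and using scikit-learn throughout, might look like this; the solver settings and random seed are illustrative assumptions, not the published configuration.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import kneighbors_graph

# X: (n_variants, 5313) array of Enformer expression differences per genomic track
# y: binary ClinVar evidence-of-harm labels for the variants
def build_enformer_lenses_and_graph(X, y, seed=0):
    """Hedged sketch of the Fig. 3 pipeline: logistic-regression lenses,
    PCA lenses and a 5-nearest-neighbour graph on a 128-dim PCA reduction."""
    X_tr, _, y_tr, _ = train_test_split(X, y, test_size=0.5, random_state=seed)
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)   # 50/50 train/test split
    prob_lenses = clf.predict_proba(X)               # two class-probability lenses
    pca2 = PCA(n_components=2).fit_transform(X)      # two dominant PCA lenses
    lenses = np.hstack([prob_lenses, pca2])          # four lenses in total
    X128 = PCA(n_components=128).fit_transform(X)    # 128-dim reduction for the graph
    A = kneighbors_graph(X128, 5, mode='connectivity')
    A = ((A + A.T) > 0).astype(float)                # symmetrize the 5-NN graph
    return lenses, A
```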

Because GTDA simplifies the prediction landscape, it enables us to demonstrate biologically relevant features of Enformer's predictions. In particular, the GTDA map of Enformer shows that many regions of the predictions and embeddings are localized in the DNA sequence (Fig. 4a). Exceptions indicate potential long-distance interactions that Enformer uses to enhance its predictions. By contrast, neither the standard Mapper algorithm for TDA (Fig. 4b) nor the tSNE or UMAP embeddings (Fig. 4c,d) of the same points show nearly the same degree of location sensitivity. In another demonstration of how the GTDA framework highlights the known biology of DNA, we examine where mutations in the exons of the 1JNX region of BRCA1 are present in the final maps. Again, we see strong localization among the exons and the GTDA map (Fig. 4e). The results are again much weaker for Mapper, tSNE and UMAP (Fig. 4f–h).

Fig. 4: Demonstrating biologically relevant features of Enformer’s predictions.

a, The topological simplification identified by GTDA is highly correlated with DNA variant starting location. b–d, Alternative global visualizations, such as the simplification from Mapper (b)—or dimensionality reduction techniques UMAP (c) and tSNE (d)—show significantly less sensitivity to the locations of the variants (P < 0.001 in a Kolmogorov–Smirnov test; see Supplementary Table 6). e, Likewise, the GTDA results strongly localize the exons of the 1JNX structure within the BRCA1 gene. f–h, The results are significantly weaker for Mapper (f), UMAP (g) and tSNE (h) (P < 0.001; see Supplementary Table 6). These results demonstrate that the Enformer model is sensitive to these aspects of gene expression and that GTDA makes inspection possible.

If we study the 1JNX repeat region within BRCA1 and its associated 3D structure, then key pieces of the secondary structure of 1JNX are likewise localized in the Reeb components identified by GTDA (Fig. 5a). This greatly aids interpretation of the results. For one of the helix structures, this analysis reveals regions where insertions and deletions are harmful (pathogenic) and where single nucleotide variants lack evidence of harm (Fig. 5b).

Fig. 5: We use Reeb networks to visualize harmful (probably pathogenic) and potentially non-harmful (no evidence of pathogenicity) predictions of gene variants in BRCA1.

a, The topology indicates several secondary structures on part of the protein (1JNX). b, We further check the model predictions on variants targeting one secondary structure. Our error estimate shows a number of what are probably erroneous predictions, and we flip these expected errors (a final analysis showed that these errors were correctly identified). We continue to see variants with distinct predictions in a small region of a few amino acids. Close examination shows a strong association between mutation types and model predictions, where a deletion or insertion is more likely to be harmful than a single nucleotide variant. c, Further insights when using the full label set show that some estimated errors are completely wrong. These prediction mistakes involve gene mutation experiments with insignificant or conflicting results and indicate underlying uncertainty. These results show how GTDA enables detailed domain-specific inspection of Enformer's results (a,b) and highlights problems with the training and testing data (c).

Another tool that our framework provides is automatic error estimation. A similar tool is uncertainty in neural network predictions, which highlights places with less confident predictions. Automatic error estimation in GTDA applies a simple network diffusion analysis to the original data graph, but restricted to edges that are identified as important to the topological simplification. Full details of the procedure are given in Supplementary Section 1.5. This error estimation greatly outperforms model uncertainty for this study (area under the curve (AUC) of 0.9 from GTDA compared with an AUC of 0.66 for uncertainty; Extended Data Fig. 2). In binary classification problems, we can automatically correct mistakes if the probability of error from our estimate is higher than model confidence in the solution.
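The exact procedure is given in Supplementary Section 1.5; as a loose, illustrative sketch of what a diffusion-based error estimate of this general form could look like, the snippet below seeds a diffusion with known label mismatches over the restricted graph and then applies the binary-correction rule described above. The seeding, mixing parameter and helper names are our assumptions, not the published algorithm.

```python
import numpy as np
import scipy.sparse as sp

def estimate_errors(A_important, seed_error, steps=10, alpha=0.85):
    """Diffusion-style error estimate over the data graph restricted to edges
    deemed important by the topological simplification.
    seed_error: per-node indicator of known mismatches (0 on unlabelled nodes)."""
    deg = np.asarray(A_important.sum(axis=1)).ravel()
    deg[deg == 0] = 1.0
    P = sp.diags(1.0 / deg) @ A_important
    e = seed_error.astype(float).copy()
    for _ in range(steps):
        e = (1 - alpha) * seed_error + alpha * (P @ e)   # spread evidence of error to neighbours
    return e

def correct_binary_predictions(pred, confidence, error_estimate):
    """Flip a 0/1 prediction when the estimated error exceeds model confidence."""
    flip = error_estimate > confidence
    return np.where(flip, 1 - pred, pred)
```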

In comparing our error estimate to the underlying annotations on harmful DNA variants, we discovered a Reeb component with many harmful predictions (Fig. 5c). This component had many mutations where the framework incorrectly predicts errors after comparing with known labels. Detailed analysis showed that these errors are strongly associated with insignificant results in the underlying data that should not have been used as training or testing data (Extended Data Fig. 3 and Supplementary Section 5.5).

Additional findings

When the framework is applied to a pretrained ResNet50 model25 on the Imagenette dataset28, it produces a visual taxonomy of images suggesting what ResNet50 is using to categorize the images (Supplementary Section 3). This example also highlights a region where the ground-truth labels of the data points are incorrect and cars are erroneously labelled ‘cassette player’ (Fig. 6). To make these visual taxonomies easier to explore, we design a diagram that places images directly on the layout of the Reeb network (Extended Data Fig. 4). We were unable to determine how to use traditional TDA results to identify this set of erroneous examples (Extended Data Fig. 5), although we reliably do so with GTDA (Supplementary Section 3.3).

Fig. 6: Analysing ResNet50 predictions on Imagenette.

a, We take a pretrained ResNet50 model and retrain the last layer to predict ten classes in Imagenette. b,c, We zoom into the Reeb network group of ‘gas pump’ predictions (b) and display the images of different local regions (c), showing gas pump images with distinct visual features (the sparsity in coverage is due to the public domain images we show here). Examining these subgroups can provide a general idea on how the model will behave when predicting future images with similar features, as well as help us quickly identify potential labelling issues in the dataset. d, For instance, we find a group of images whose predicted labels are ‘cassette player’ even though they are actually images of ‘cars’ (or even a car at a gas pump); this arises due to errors in the original training data, where images of cars are labelled ‘cassette player’. Credit: the four gas pumps at lower centre are from the Library of Congress, Prints and Photographs Division, photograph by John Margolies.

When we apply the GTDA framework to a graph neural network that predicts the type of product on Amazon on the basis of reviews, the framework identifies an ambiguity in product categories that limits prediction accuracy (Extended Data Fig. 6 and Supplementary Section 2).

We compare the embeddings from tSNE, UMAP and GTDA for the four datasets (the simple Swiss Roll from Fig. 2, the product type Amazon data, ResNet50 on Imagenette and the malignant gene mutation data) in Extended Data Fig. 7.

Two further studies include an investigation of chest X-ray diagnostics and a comparison of deep-learning frameworks. In the investigation of chest X-rays, we show how the Reeb networks find incorrect diagnostic annotations in chest X-ray datasets used for deep learning29 (Extended Data Fig. 8). The GTDA methods give an AUC of 0.75 (Supplementary Section 6). The comparison of deep-learning frameworks is designed to test how well GTDA scales to larger problems with over a million points and 2,000 lenses. In this case, we use GTDA to analyse pairwise differences among ResNet25, AlexNet26 and VOLO-5030. Each lens consists of one set of predictions from each method. These results show that GTDA scales to large problems and does not take much more time than inference (Extended Data Table 2 and Supplementary Section 4).

Discussion

Our Reeb network construction extends recent analytical methods from topology15,17 to facilitate applications to the topology of complex prediction. In comparison with other proposed applications of topology to neural networks, GTDA focuses on simplifying the topology of the combined prediction and embedding space to aid in inspection of the prediction methods. The ideas underlying GTDA are loosely related to how Naitzat et al.18 study topological changes as data are passed through the layers of a neural network, whereas we focus only on the final embedding space. GTDA differs from methods such as TopoAct19, which studies the shape of activation space at a given layer of a neural network to provide insights on the learned representations, as well as those of Gabrielsson et al.20, who correlate topology in the weights with generalization performance. It also differs from methods that directly try to embed topology in training21,31. It is similar in spirit to methods that combine group-invariance and topological understanding of data with neural networks22, albeit with key differences regarding how the topological information is used. Further combinations of these ideas offer considerable potential for the use of topology in complex prediction methods.

Our work relates to interpretability32 and seeks to produce a comprehensible map of the prediction structure that aids navigation of a large space of predictions towards the most interesting areas. In this taxonomy of interpretability, our methods are most useful for global, dataset-level post-hoc interpretability. They are relevant because they provide insights into the model's behaviour in terms of the underlying domains (gene expression in the main text, and images and graph problems in the supplementary case studies). In terms of relevance to real-world problems, our methods highlight problematic data points in both the training and validation sets. These results also highlight weaknesses of dimension-reduction methods for similar uses. Beyond identifying that there is a problem, the insights from the topology suggest relationships to nearby data and thereby suggest mechanisms that could be addressed through future improvements.

Considering the ability of these topological inspection techniques to translate prediction models into actionable human-level insights, we expect them to be applicable to new models and predictions, broadly, as they are created and to be a critical early diagnostic of prediction models. The interaction of topology and prediction may provide a fertile ground for future improvements in prediction methods.

Methods

See the Supplementary Information for full details on the methods.