Abstract
Causal learning is a key challenge in scientific artificial intelligence as it allows researchers to go beyond purely correlative or predictive analyses towards learning underlying cause-and-effect relationships, which are important for scientific understanding as well as for a wide range of downstream tasks. Here, motivated by emerging biomedical questions, we propose a deep neural architecture for learning causal relationships between variables from a combination of high-dimensional data and prior causal knowledge. We combine convolutional and graph neural networks within a causal risk framework to provide an approach that is demonstrably effective under the conditions of high dimensionality, noise and data limitations that are characteristic of many applications, including in large-scale biology. In experiments, we find that the proposed learners can effectively identify novel causal relationships across thousands of variables. Results include extensive (linear and nonlinear) simulations (where the ground truth is known and can be directly compared against), as well as real biological examples where the models are applied to high-dimensional molecular data and their outputs compared against entirely unseen validation experiments. These results support the notion that deep learning approaches can be used to learn causal networks at large scale.
Main
Causality remains an important open area in artificial intelligence (AI) research^{1,2}, and the task of identifying causal relationships between variables is key in many scientific domains^{3}. The rich body of work in learning causal structures includes methods such as PC^{4}, LiNGAM^{5}, IDA^{6}, GIES^{7}, RFCI^{8}, ICP^{9} and MRCL^{10}. Scaling causal structure learning to larger problems has been facilitated by reformulation as a continuous optimization problem^{11}, and recent neural approaches, such as SDI^{12}, DCDI^{13}, DCDFG^{14} and ENCO^{15}, have demonstrated state-of-the-art performance (Supplementary section 1 provides a detailed discussion). However, learning causal structures from data remains non-trivial and continues to pose challenges, particularly under the conditions (high dimensionality, limited data sizes and hidden variables, for example) seen in many real-world problems.
In biomedicine, causal networks representing the interplay between entities such as genes or proteins play a central conceptual and practical role. Such networks are increasingly understood to be context-dependent, and are thought to underpin aspects of disease heterogeneity and the variation in therapeutic response (for example, refs. ^{16,17,18,19}). A key bottleneck in characterizing such heterogeneity lies in the challenging nature of learning causal structures at scale, because of general methodological issues as well as aspects relevant in the biological domain such as high dimensionality, complex underlying events, the presence of hidden/unmeasured variables, limited data and noise levels.
In this Article, we propose a deep architecture for causal learning that is particularly motivated by high-dimensional biomedical problems. The approach we put forward operates within an emerging causal risk paradigm (Methods and Supplementary section 2) that allows us to leverage AI tools and scale to very high-dimensional problems involving thousands of variables. The learners proposed allow for the integration of partial knowledge concerning a subset of causal relationships and then seek to generalize beyond what is initially known to learn relationships between all variables. This corresponds to a common scientific use-case in which some prior knowledge is available at the outset—from previous experiments or scientific background knowledge—but it is desired to go beyond what is known to learn a model spanning all available variables.
A large part of the causal structure learning literature involves learning models that allow an explicit description of the relevant data-generating model (including both observational and interventional distributions) and are in that sense ‘generative’ (see, for example, ref. ^{3} and references therein). Taking a different approach, a line of recent work, including refs. ^{10,20,21,22}, has considered learning indicators of causal relationships between variables (without necessarily learning full details of the underlying data-generating models), and this can be viewed as being related to notions of causal risk^{23}. Such indicators may encode, for example, whether, for a pair of variables A and B, A has a causal influence on B, B on A, or neither.
The approach we propose, called ‘deep discriminative causal learning’ (D^{2}CL), is in the latter vein. We consider a version of the causal structure learning problem in which the desired output consists of binary indicators of causal relationships between observed variables^{10,23}, that is, a directed graph with nodes identified with the variables. Available multivariate data X are transformed to provide inputs to a neural network (NN), whose outputs are estimates of the causal indicators. D^{2}CL differs from classical causal structure learning approaches both in terms of the underlying framework (based on causal risk rather than generative causal models) and in leveraging NNs. The assumptions underlying the approach are also different in nature from those in classical causal structure learning and concern higher-level regularities in the data-generating processes (Methods). A number of recent papers, including refs. ^{12,13,14,15}, also leverage neural approaches for learning causal structures and share a basis in the continuous optimization framework introduced in ref. ^{11} based on a directed acyclic graph (DAG) framework. D^{2}CL, in contrast, uses a risk-based approach that is not based on DAGs. Eigenmann et al. ^{23} studied causal risk for the assessment of existing learners; instead, we leverage the notion of causal risk to propose a new learner. In common with D^{2}CL, the recently proposed CSIvA method^{24} seeks to directly map input data to a graph output. The key difference is that, while CSIvA uses a meta-learning scheme based on large-scale synthetic data, D^{2}CL is based on supervised learning using data from a specific system of interest (for example, a biological system; see Supplementary section 1 for a more detailed overview and comparison). We show that context-specific training allows D^{2}CL to successfully learn structures in a range of scenarios, including challenging real-world experimental data (as detailed in the following).
Furthermore, D^{2}CL is demonstrably scalable to large numbers of variables (we show examples ranging up to p = 50,000 nodes) and applicable in regimes where very large sample data or strong simulation engines are not available.
Framework overview
We propose an end-to-end neural approach to learn causal networks from a combination of empirical data X and prior causal knowledge Π. The general D^{2}CL workflow and its application to biomolecular problems are summarized in Fig. 1. Here we provide a very brief, high-level summary of the main ideas. A detailed presentation of the methodology and associated discussion (including of causal semantics and assumptions) are provided in the Methods and Supplementary section 2.
Suppose X_{1}, …, X_{p} is a set of variables whose mutual causal relationships are of interest. Let G* denote an (unknown) graph whose directed edges encode these causal relationships. D^{2}CL seeks to learn G* from two inputs: (1) empirical data X containing measurements on each of the variables of interest and (2) prior causal knowledge Π concerning a subset of causal relationships. This corresponds to a common paradigm in real-world scientific settings, where some data are measured on variables of interest, but only limited knowledge about causal relationships is available at the outset (for example, from prior scientific knowledge or specific experiments).
We formalize the task in the following way. For each ordered pair of variables with indices (i, j) whose causal status is not known via Π, our goal is to learn an indicator of whether or not X_{i} has a causal influence on X_{j}. D^{2}CL treats these causal indicators as ‘labels’ in a machine learning sense, using the available inputs to learn a suitable mapping. The goal of the mapping is to minimize discrepancy with respect to the true, unknown causal status; this learning task can be viewed through the lens of causal risk^{23}. In all experiments, the learner never has access to data in which the parent node of an unknown edge was intervened on. This makes learning challenging, as we require generalization to interventional regimes/distributions that are entirely unseen.
Learning is carried out using a flexible, neural model F_{θ} with a set of trainable parameters θ. The model is trained in a specific fashion that leverages the input information Π as a supervision/training signal to allow the model to learn representations suitable for generalization to novel causal relationships (the Methods provides details and a discussion of the assumptions). The network F_{θ} combines a convolutional neural network (CNN) and a graph neural network (GNN) to resolve distributional and graph structural regularities (Fig. 2). In image processing, CNNs make use of certain properties, such as spatial invariance, that exploit the notion of an image as a function on the plane. Here we leverage the CNN toolkit to capture distributional information in data X, represented as images. We create these visual representations for two-tuples of nodes. Specifically, for a variable pair (i, j) we use the n × 2 submatrix X_{(⋅, [ij])}, to form a bivariate kernel density estimate f_{ij} = KDE(X_{(⋅, [ij])}) that is treated as an image input. Note that this is in general asymmetric in the sense that f_{ij} ≠ f_{ji}. This is important, as we want to learn ordered/directed relationships (symmetry here would imply an inability to distinguish the causal direction). The GNN is aimed at capturing graph structural regularities and to this end learns a state embedding h_{j} that contains the information of the neighbourhood for each node j. The GNN requires a graph as input; we provide an initial input graph \({\hat{G}}_{0}\) via computationally lightweight routines solely based on the available data, X (Methods).
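The image construction for a variable pair can be sketched as follows. This is a minimal illustration using a SciPy Gaussian kernel density estimate; the grid resolution, bandwidth and value ranges here are illustrative assumptions, not the exact rasterization used by D^{2}CL.

```python
import numpy as np
from scipy.stats import gaussian_kde

def pair_image(X, i, j, res=32):
    """Bivariate KDE of the (i, j) column pair of X, rasterized as a res x res image.

    Image rows vary X_i and columns vary X_j, so pair_image(X, i, j) and
    pair_image(X, j, i) are transposes of each other: the representation is
    asymmetric in the ordered pair, which is what allows a (non-transpose-
    invariant) CNN to learn edge direction.
    """
    sub = X[:, [i, j]].T                                  # shape (2, n)
    kde = gaussian_kde(sub)                               # bivariate KDE
    lo, hi = sub.min(axis=1), sub.max(axis=1)
    gi = np.linspace(lo[0], hi[0], res)                   # grid for X_i
    gj = np.linspace(lo[1], hi[1], res)                   # grid for X_j
    grid = np.array(np.meshgrid(gi, gj, indexing="ij"))   # (2, res, res)
    img = kde(grid.reshape(2, -1)).reshape(res, res)
    return img / img.sum()                                # discrete density

# Example with a simple nonlinear dependence of X_1 on X_0:
rng = np.random.default_rng(0)
x0 = rng.normal(size=500)
X = np.column_stack([x0, np.tanh(x0) + 0.1 * rng.normal(size=500)])
f_01 = pair_image(X, 0, 1)
f_10 = pair_image(X, 1, 0)
```

In this sketch the two orderings differ only by transposition; a CNN can still distinguish them because convolutional feature extraction is not invariant to transposing the input image.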
Finally, following training, the model F—with parameters now fixed as a function of inputs X and Π—can be used to assign causal status to any pair via an inference step. In the experiments described in the following, the global model output is tested systematically at large scale against either the true graph G* (in simulations) or against entirely unseen interventional experiments (for real biological examples).
Our focus is on causal learning for real-world, high-dimensional problems with thousands of nodes and limited data, motivated by large-scale biomedical problems. Within the causal risk paradigm^{10,23} we use here, acyclicity (of the directed graphs to be learned) is not assumed, nor is availability of any standard factorization of the joint probability distribution. It is not required that data samples in X are drawn from a single distribution; instead, data can be drawn from, for example, a mix of observational and interventional distributions, and the causal characteristics of these regimes (for example, which node(s) or latents were intervened on) need not be known in advance. Nor is it required that we have interventional data or prior information concerning all nodes. On the contrary, in all experiments, the learner never has access to data in which the parent node of an unknown edge was intervened on nor prior information concerning the unknown edge. This is a common real-world setup, in particular for emerging experimental designs in biology (examples are described in the following). We emphasize that the NNs used are not rotation-invariant and hence can break symmetries and allow inference of causal direction.
Results
We use both simulated data and real biological data to assess performance. In all cases, the model output is tested with respect to causal relationships that are entirely unseen in the sense that (1) causal relationships on which the model output is tested are disjoint from those provided as inputs during training and (2) no data used to define causal relationships against which the model output is tested appear in inputs to the models. Additional results, as well as details of the experimental protocols, are provided in Supplementary sections 3 and 4.
Simulation benchmarks
We tested the methods using data generated from a (linear or nonlinear) structural equation model (SEM) with noise, based on a known underlying causal graph G*. The protocol is outlined in Fig. 3a, with further details provided in Supplementary section 3. In brief, data were generated via structural equations of the form \({X}_{i}={f}_{i}{(\text{Pa}_{{G}^{* }}({X}_{i}),\,{U}_{{X}_{i}})}\), for i = 1, …, p, where p is the total number of variables, \({\text{Pa}_{{G}^{* }}({X}_{i})}\) is the set of parents for node i in the true graph G*, and \({U}_{{X}_{i}}\) are noise variables (exogenous and jointly independent). The functions f_{i} are unknown to the learners. Varying the noise magnitude allows us to control the signal-to-noise ratio (SNR), and varying p allows us to understand the effect of dimensionality. The output was tested against the true, gold-standard causal structure G* and hence assessed in causal (and not correlational or predictive) terms.
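A simulation of this kind can be sketched as below. This is a hedged illustration, not the paper's exact protocol: the weight range, noise scaling and additive-noise form are assumptions, and the graph is taken to be acyclic here so that nodes can be generated in topological order.

```python
import numpy as np

def topological_order(G):
    """Kahn's algorithm on a binary adjacency matrix (G[j, i] = 1 iff j -> i)."""
    G = G.copy().astype(int)
    indeg = G.sum(axis=0)
    order, stack = [], [int(v) for v in np.flatnonzero(indeg == 0)]
    while stack:
        v = stack.pop()
        order.append(v)
        for w in np.flatnonzero(G[v]):
            G[v, w] = 0
            indeg[w] -= 1
            if indeg[w] == 0:
                stack.append(int(w))
    return order

def simulate_sem(G, n, f=np.tanh, snr=2.0, seed=0):
    """Draw n samples from X_i = f(weighted sum of parents of X_i) + U_i.

    snr controls the ratio of signal to noise standard deviation, i.e. the
    difficulty of the learning problem.
    """
    rng = np.random.default_rng(seed)
    p = G.shape[0]
    W = G * rng.uniform(0.5, 1.5, size=(p, p))        # random edge weights
    X = np.zeros((n, p))
    for i in topological_order(G):
        if G[:, i].any():
            signal = f(X @ W[:, i])
            X[:, i] = signal + rng.normal(scale=signal.std() / snr, size=n)
        else:
            X[:, i] = rng.normal(size=n)              # root node: pure noise
    return X

# A three-node chain 0 -> 1 -> 2:
G_true = np.zeros((3, 3))
G_true[0, 1] = G_true[1, 2] = 1
X = simulate_sem(G_true, n=500)
```

Varying `snr` and the dimension `p` then reproduces, in miniature, the two axes explored in the benchmarks.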
In-system, out-of-distribution evaluation
Here, model training uses (limited) prior knowledge and data from a given system, and assessment is carried out with respect to unknown edges within the same system (test and training edges are always entirely disjoint). This is out-of-distribution in the sense that the learner never has access to samples from the test interventional distributions, but in-system, because all data are from the same overall data-generating system. This corresponds to a common scientific use-case where the goal is to learn a model for a specific system of interest given available data on that system. Figure 3c shows results for a problem of dimension p = 1,500 using a nonlinear transition function (the hyperbolic tangent; other functions and configurations are shown in Supplementary Tables 2 (area under the curve, AUC) and 3 (area under the precision-recall curve, AUPRC)) and varying SNR. (For these first results, we restricted the dimension of the problem to facilitate comparison to existing approaches that are less scalable than D^{2}CL; higher-dimensional examples appear in the following.) Note that pairwise correlations between the variables (‘Pearson’) are ineffective; this is expected due to the presence of latent variables in all experiments and the fundamental difference between correlational and causal relationships. Overall, D^{2}CL remains effective across a broad range of SNRs, as well as for a range of linear and nonlinear problems and problem sizes (Supplementary Table 1). We also compared D^{2}CL to DCDFG^{14} and ENCO^{15}, two recently proposed, scalable neural causal learners. Owing to computational considerations, we restricted this comparison to a subset of the simulations. Illustrative results are provided in Fig. 3b. We find that neither approach is effective in this case, possibly due to the limited data and the presence of latent variables.
In addition, we tested the effectiveness of D^{2}CL for additive and multiplicative Gaussian noise with varying SNRs under settings with hard deterministic and stochastic interventions. We refer the interested reader to Supplementary section 3 for a definition of an intervention and the types used. The test results (AUC and AUPRC values) are summarized in Supplementary Tables 8 and 9 and support the notion that D^{2}CL is robust to different types of noise.
The graph G* in the above examples encodes direct causal relationships as there is an edge from one node to another if the former appears in the equation for the latter. However, in many real-world examples, interest focuses also on indirect effects, which may be mediated by other nodes. For example, if node A has a direct effect on B, and B on C, intervention on A may change C, even though A does not itself appear in the equation for C. To test the ability to learn indirect edges, we proceeded as above, but with the inputs Π being indirect edges and the output tested against the true indirect graph. Results are presented in Fig. 3d. D^{2}CL outperforms existing methods across a range of SNRs and also in other linear/nonlinear problem configurations (Supplementary Tables 4 and 5). IDA performs well in the case of a linear SEM, but not for functions based on nonlinear multilayer perceptrons. D^{2}CL appears to be the most noise-robust of the methods tested. These results show that D^{2}CL can learn indirect causal edges over many variables under conditions of noise and nonlinearity.
Out-of-system, out-of-distribution evaluation
D^{2}CL is trainable using (limited) data from a specific system (for example, a specific biological system, such as cells of a particular kind, or a disease state). However, it is interesting to see whether it is possible to generalize to different systems. To this end, we trained D^{2}CL on a dataset from a certain system and cross-evaluated the trained model on data from another system (a different simulation regime). The results are provided in Supplementary Tables 10 and 11. Some generalization appears possible, suggesting that D^{2}CL can find signals that are causally informative in a cross-system sense, although performance is always worse relative to in-system training (this is expected in our framework, and we emphasize that we do not claim any general ability to achieve out-of-system generalization). Nevertheless, these results broadly support the notion of large-scale meta-learning for causal structures^{24}.
Large-scale evaluation
Finally, to test the scalability of D^{2}CL to high-dimensional problems, we considered a problem with p = 50,000 variables (that is, p = 50,000 nodes in the ground-truth graph; note that none of the compared methods can practically scale to this setting). We considered learning of direct causal relationships; the results are shown in Supplementary Table 6 and support the notion that D^{2}CL can scale to problems spanning many thousands of variables.
Large-scale biological data
To study performance in the context of real biological data, we leveraged a large set of gene deletion experiments in yeast^{25}, which have previously been used for causal learning^{9,10,26}. The experiments involve measuring gene expression in yeast cells under each of a large number of interventions (gene deletions; Supplementary section 3 provides further details).
In biological experiments, causal effects may be indirect, and we sought to learn a directed graph with nodes corresponding to p observed genes and edges representing (possibly indirect) causal influences. Such edges are scientifically interesting and amenable to experimental verification, as noted in refs. ^{22,27}. Cycles can arise in systems biology (see, for example, ref. ^{28}) and we do not enforce acyclicity (see ref. ^{29} and references therein for a discussion of cyclic causality). A fuller discussion of the causal interpretation of laboratory experiments is beyond the scope of this Article, but relevant work includes refs. ^{29,30,31}, and we direct the interested reader to these references for further discussion.
Because causal background knowledge is an input for our approach, it is relevant to consider performance as a function of the amount of such input. To this end, we fixed the problem size to p = 1,000 and varied the number of interventions m whose effects were available to the learner (Supplementary section 3 provides details). As each experiment involves only a subset of the entire yeast genome, latent variables are present by design. The input prior knowledge Π is derived from the causal status, but, as in all experiments, is strictly disjoint with respect to any test edges.
Results are presented in Fig. 4a–c, including the area under the receiver operating characteristic (ROC) curve (AUC; computed with respect to an experimentally determined gold standard; Supplementary section 3). Overall, the proposed methods perform well, achieving good results in this high-dimensional, limited-data regime. Next, to shed light on data efficiency, we varied the sample size n of the data matrix X (Fig. 4d–f).
Finally, we tested the performance in a higher-dimensional example spanning all p = 5,535 available genes (Fig. 4g–k) and found that D^{2}CL remains effective at the genome scale. Interestingly, although the CNN tower performs particularly well, the GNN tower degrades more. This may be because larger p leads to a larger number of variable pairs (which is helpful for the CNN), but also to a (rapid) increase in the number of nodes and edges in the GNN subgraphs and hence a harder GNN learning task in practice.
D^{2}CL leverages prior causal knowledge. However, in practice, the available causal inputs Π may be incorrect, for example, due to flawed initial experiments or errors in the known science. To study sensitivity to flawed causal inputs, we introduced errors into Π. This was done by perturbing 10% of the inputs (that is, labelling causal pairs as non-causal and vice versa) at the outset. The results are shown in Fig. 5a and demonstrate a level of robustness to such perturbation. We also see a benefit of the dual network variants; this is investigated further in Fig. 5b. For the latter, the embedding of one tower is modified immediately before the fusion layer. We considered several different modifications: setting the embedding of one tower to zero, and hence effectively removing all information from this tower, or applying Gaussian noise with magnitude σ = 1.0, σ = 2.0 and σ = 5.0.
Causal relations are in general directed and asymmetric, so it is interesting to explore model behaviour with respect to causal direction. Given an image representation, the CNN tower is designed to extract feature maps that are unique for ordered node pairs, that is, such that in general features differ depending on edge direction. To empirically study learning of causal direction, we constructed additional test data as follows: for each truly causal edge k → l in the test set, we also included the reverse direction l → k. This means that any learner estimating undirected links would have an AUC score of 0.5 (because the output k → l entails also l → k, one of which is a false positive). Supplementary Table 4 shows that D^{2}CL is indeed capable of accurately identifying causal direction. In addition, Fig. 5c shows a low-dimensional representation of the feature maps of the converged CNN tower. These feature maps differ by causal direction (k → l versus l → k) throughout the forward pass, supporting the foregoing arguments.
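The direction-augmented evaluation above can be sketched as follows; the scorers here are hypothetical placeholders, and the AUC is computed via the Mann-Whitney rank formulation.

```python
import numpy as np

def direction_augmented_auc(score, true_edges):
    """AUC on a test set pairing every true edge k -> l with its reversal l -> k.

    score(k, l) is the model's confidence that k causally influences l. Any
    symmetric scorer (score(k, l) == score(l, k)) gets exactly 0.5 here,
    because each true edge contributes one positive and one negative example
    carrying the same score.
    """
    pairs = [(k, l, 1) for k, l in true_edges] + [(l, k, 0) for k, l in true_edges]
    s = np.array([score(k, l) for k, l, _ in pairs])
    y = np.array([lab for _, _, lab in pairs])
    pos, neg = s[y == 1], s[y == 0]
    # Mann-Whitney formulation of AUC; ties count as 0.5
    wins = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    return wins / (len(pos) * len(neg))

true_edges = [(0, 1), (2, 3), (4, 5)]
sym = lambda k, l: abs(k - l)                    # symmetric (undirected) score
asym = lambda k, l: float((k, l) in true_edges)  # perfectly directed score
auc_sym = direction_augmented_auc(sym, true_edges)    # chance level: 0.5
auc_asym = direction_augmented_auc(asym, true_edges)  # perfect: 1.0
```

This makes concrete why an undirected learner cannot exceed 0.5 on this test, whereas a direction-sensitive learner can.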
High-dimensional CRISPR-based perturbations
Finally, we used recent, single-cell clustered regularly interspaced short palindromic repeats (CRISPR)-based interventional experiments^{32} to illustrate the use of the proposed approaches in very high-dimensional data from human cells. The experimental protocol (see ref. ^{32} for full details) includes a large number of interventions in a leukaemia cell line (K562) and in retinal pigment epithelial (RPE) cells. The K562 and RPE experiments include gene-expression levels for a total of, respectively, p = 8,552 and p = 8,833 genes (Supplementary section 3 provides details). This is a challenging setting due to the known complexity of regulatory events in human cells and high levels of variability and noise in single-cell protocols. The results are presented in Fig. 6 and demonstrate good performance for RPE, and slightly worse performance, but still non-trivial consistency with the experimental gold standard, for K562. Additional plots in Fig. 6 and Supplementary Fig. 3 show the performance and runtime for a set of baseline algorithms. These results demonstrate two key points. First, the runtime for many available algorithms grows so rapidly with increasing number of variables as to render them unsuitable for problems at this scale. Second, for existing methods that are at all able to scale to larger problems, performance is considerably less effective than D^{2}CL in this setting.
Conclusions
Emerging experimental protocols, involving combinations of perturbations and high-dimensional readouts, are allowing for new, scalable ways to query molecular networks in a context-specific fashion. Combined with scalable causal learning tools, these approaches have the potential to strongly impact disease biology by allowing global networks, spanning thousands or tens of thousands of variables, to be investigated across many different contexts.
Networks learned in this way could, in the future, be leveraged to allow for the prediction of disease phenotypes or drug response under novel perturbations (this is a different task from standard supervised learning, because the test case involves an unseen perturbation to the system). Furthermore, in the area of personalized medicine, such an approach could even allow for rational optimization over potential therapeutic strategies, because the latter are often interventions targeted at molecular nodes.
Our model leverages deep learning tools to learn causal relationships between variables at large scale. However, and in contrast to well-established approaches based on causal graphical models, it provides only a structural output rather than a probability model of the underlying system. It is also interesting to contrast D^{2}CL with the recently proposed CSIvA^{24}. Both approaches pursue, in a sense, a ‘direct’ mapping of data inputs to graph outputs, with a key difference being that CSIvA uses meta-learning and seeks to generalize across systems, whereas D^{2}CL uses supervised learning to generalize to new interventions on a given system (for example, a biological system of interest). An interesting direction for future work may be to combine both approaches, for example by using CSIvA to provide the initial input to D^{2}CL; this would combine general, simulation-based learning and data-efficient, system-specific training.
At present, rigorous theory and an understanding of the theoretical properties of the kind of approach studied here remain lacking. A key direction for future theoretical work will be to understand the precise conditions for the underlying system that are needed to ensure that direct mapping approaches can guarantee the recovery of specific causal structures. An interesting observation is that the proposed approach may benefit from a ‘blessing of dimensionality’, because the learning problem will typically enjoy a larger number of examples as dimension p grows. Conversely, and in contrast to established statistical causal models, our approach (at the current stage) cannot be used in the small-p regime, because the number of examples will be too small for deep learning.
Methods
In this section, we provide information on the causal interpretation of our learning scheme, as well as a more detailed presentation of the architecture and associated implementation.
Notation
Observed variables with index set V = {1, …, p} are denoted X_{1}, …, X_{p}. The variables will be identified with vertices in a directed graph G whose vertex and edge sets are denoted V(G) and E(G), respectively. We occasionally overload G to refer also to the corresponding binary adjacency matrix, using G_{ij} to refer to the entry (i, j) of the adjacency matrix, as will be clear from context. We use linear indexing of variable pairs to aid formulation as a machine learning problem. Specifically, an ordered pair (i, j) ∈ V × V has an associated linear index \({k\in {{{\mathcal{K}}}}}=\{1,\,\ldots ,\,K\,\}\), where K is the total number of variable pairs of interest. Where useful, we make the mapping explicit, denoting the linear index corresponding to a pair (i, j) as k(i, j) and the variable pair corresponding to a linear index k as (i(k), j(k)). The linear indices of pairs whose causal relationships are unknown and of interest are \({{{\mathcal{U}}}}\subset {{{\mathcal{K}}}}\), and those pairs known in advance via input knowledge Π are \({{{\mathcal{T}}}}{(\varPi )}\subset {{{\mathcal{K}}}}\). In all experiments, \({{{\mathcal{T}}}}{(\varPi )}\) and \({{{\mathcal{U}}}}\) are disjoint; that is, no prior causal information is available on the pairs \({{{\mathcal{U}}}}\) of interest.
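The linear indexing of ordered pairs can be made concrete with a small sketch; the specific bijection below (skipping the diagonal row by row) is an illustrative choice, since the text does not fix a particular enumeration.

```python
def pair_indexing(p):
    """Bijection between ordered pairs (i, j), i != j, and linear indices k.

    For p variables there are K = p * (p - 1) ordered pairs; each row i
    contributes p - 1 off-diagonal entries, so the diagonal is skipped.
    Returns the two maps k(i, j) and (i(k), j(k)) used in the text.
    """
    def k_of(i, j):
        return i * (p - 1) + (j if j < i else j - 1)

    def pair_of(k):
        i, r = divmod(k, p - 1)
        j = r if r < i else r + 1
        return i, j

    return k_of, pair_of

k_of, pair_of = pair_indexing(4)   # K = 12 ordered pairs for p = 4
```

With such an indexing, the sets of unknown and known pairs are simply subsets of {0, …, K − 1}, which makes the formulation as a supervised learning problem over 'labelled' pairs straightforward.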
Problem statement
We focus on the setting in which the available inputs are

(I1) Empirical data: an n × p data matrix X whose columns correspond to variables X_{1}, …, X_{p}.

(I2) Causal background knowledge Π providing information on a subset \({{{\mathcal{T}}}}{(\varPi )}\subset {{{\mathcal{K}}}}\) of causal relationships.
For (I2), we assume that information is available concerning the causal status of a subset of variable pairs. That is, for some variable pairs (X_{i}, X_{j}) the correct binary indicator \({G}_{ij}^{* }\), representing the presence/absence of an edge in the target graphical object, is provided as an input. In terms of linear indexing, these can be viewed as available ‘labels’ of causal status for the pairs \({{{\mathcal{T}}}}{(\varPi )}\subset {{{\mathcal{K}}}}\). No specific assumption is made on the data X, but, in line with our focus on generalizing to unseen causal relationships, it is assumed that X does not contain interventional data corresponding to the pairs in \({{{\mathcal{U}}}}\). Furthermore, in all experiments, not only are the sets \({{{\mathcal{T}}}}\) and \({{{\mathcal{U}}}}\) disjoint, but we enforce the stronger requirement that \({u\in {{{\mathcal{U}}}}\ \Rightarrow \ \nexists {j:k(i(u),\,j)}\in {{{\mathcal{T}}}}}\). This means that all interventions on which models are tested are entirely novel, that is, unrepresented in the inputs to the learner, either as data or prior input. This also means that the learner has no access whatsoever to samples from the test interventional distributions, and all experiments are out-of-distribution in this sense.
The learning task can thus be formulated as follows: given inputs (I1) and (I2), the goal is to estimate, for each ordered pair of variables (X_{i}, X_{j}) with unknown causal relationship, whether or not X_{i} has a causal influence on X_{j}.
Summary of the learning scheme
With the notation above, our goal is to learn a graph whose nodes correspond to the variables X_{1}, …, X_{p} and whose edges represent causal relationships. To this end, we train a parameterized network F_{θ}, that is, a nonlinear function F with a set of unknown, trainable parameters θ. This is possible, because we know for each pair \({k\in {{{\mathcal{T}}}}}\) the causal status \({G}_{ij}^{* }\) based on input information Π. The architecture we use as F_{θ} is detailed below, but for now assume this has been specified. Then, given the data X and the training labels \({Y}_{k}={G}_{i(k),\,j(k)}^{* }\) for all pairs \({k\in {{{\mathcal{T}}}}(\varPi )}\), we train the set of parameters \({\hat{\theta }(X,\,\varPi )}\) under a loss that is supervised by the (causal) labels Y_{k}.
At this stage, the trained network \({F}_{\hat{\theta }(X,\,\varPi )}\) allows assignment of causal status to any pair, because it gives an estimate of the entire graph including those pairs whose causal status was unknown. The output is given by

\({\hat{G}}_{ij}={F}_{\hat{\theta }(X,\,\varPi )}{(i,\,j)},\)
where (i, j) are ordered variable pairs. Note that the overall estimate depends solely on the data X and causal information Π. By default, no change is made for pairs \({{{\mathcal{T}}}}\) whose status was known at the outset. Reference ^{23} studied causal notions of risk based on loss functions that compare a graph estimate \({\hat{G}}\) with ground truth G*. In our setting, we consider a classification-type loss on the variable pairs k, where the causal status of known pairs \({{{\mathcal{T}}}}{(\varPi )}\) provides the training ‘labels’. We therefore use the corresponding binary cross-entropy loss, augmented by additional terms that, for example, prevent exploding weights.
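The supervised objective can be sketched as below. This is a minimal NumPy illustration, not the production training code: the L2 coefficient is an illustrative stand-in for the unspecified 'additional terms' that prevent exploding weights.

```python
import numpy as np

def causal_bce(y_true, y_pred, weights=None, l2=1e-4, eps=1e-7):
    """Binary cross-entropy over the labelled pairs T(Pi), plus an L2 penalty.

    y_true : causal labels Y_k (0/1) for pairs whose status is known via Pi.
    y_pred : network outputs for the same pairs, probabilities in (0, 1).
    weights: optional flat parameter vector; the L2 term is one example of a
             regularizer keeping weights bounded (coefficient illustrative).
    """
    y_pred = np.clip(y_pred, eps, 1 - eps)   # numerical stability near 0/1
    bce = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    reg = l2 * float(np.sum(weights ** 2)) if weights is not None else 0.0
    return bce + reg

# Two labelled pairs: one causal (label 1), one non-causal (label 0).
loss = causal_bce(np.array([1.0, 0.0]), np.array([0.9, 0.1]))
```

Minimizing this loss over the known pairs is what drives the learned representations; the hope, borne out empirically, is that these generalize to the unknown pairs.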
Causal interpretation of the learning scheme
D^{2}CL outputs a directed graph. The discriminative nature of D^{2}CL means that the notion of causal influence encoded by the edges is rooted in the application setting and input information Π, because causal semantics are inherited via the problem setting rather than specified by a generative model (see ref. ^{10} for related discussions). Indeed, in the experiments we showed that D^{2}CL could be used to successfully learn either direct or indirect/ancestral causal relationships.
Here we provide some intuition as to why discriminative learning can be effective in this setting. We note, however, that the following arguments are not intended as a rigorous theory at this stage, but rather to clarify the conditions under which discriminative causal structure learning may be expected to be effective.
We start with a general causal framework and then introduce assumptions for D^{2}CL (the meta-generator assumption (MGA) and the dominant cause under single interventions (DCSI) assumption, described in the following sections). Following refs. ^{1,33}, we assume decomposition of the underlying system into modular and independent mechanisms:
Independent causal mechanisms (ICMs)
The causal generative process of a system’s variables is composed of autonomous modules that do not inform or influence each other.
For variables X_{i} assume a structural causal model with equations \({X}_{i}={{f}_{i}{(\text{Pa}_{{G}^{* }}({X}_{i}),\,{U}_{{X}_{i}})},\,{i=1},\,\ldots ,\;p}\), where \({\text{Pa}_{{G}^{* }}({X}_{i})}\) denotes the set of parents of node i in the ground-truth graph G*, and f_{i} is a node-specific function. Exogenous noise terms \({U}_{{X}_{i}}\) are assumed jointly independent and distributed as \({U}_{{X}_{i}} \sim {p}_{i}\), where p_{i} is a node-specific density.
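As a toy illustration of a structural causal model of this form (the three-node graph and the specific functions below are invented for illustration and are not one of the paper's simulation settings):

```python
import numpy as np

rng = np.random.default_rng(0)

def simulate_scm(n):
    """Draw n samples from a toy SCM with graph X1 -> X2 -> X3."""
    u1, u2, u3 = rng.normal(size=(3, n))      # jointly independent exogenous noise
    x1 = u1                                    # X1 = f1(U1): a root node
    x2 = np.tanh(2.0 * x1) + 0.3 * u2          # X2 = f2(Pa(X2), U2) = f2(X1, U2)
    x3 = 0.8 * x2 ** 2 + 0.3 * u3              # X3 = f3(Pa(X3), U3) = f3(X2, U3)
    return np.column_stack([x1, x2, x3])

X = simulate_scm(1000)                         # n x p data matrix, p = 3
```

Each assignment is an autonomous module in the sense of the ICM principle: intervening on one mechanism leaves the others unchanged.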
Our approach treats the f_{i} and p_{i} as unknown but assumes that they are related at a higher level. This can be formalized as a meta-generator assumption as follows.
Meta-generator assumption (MGA)
For a specific system W, the functions f_{i} and noise distributions p_{i} are (independently) generated as \({f}_{i} \sim {{{{\mathcal{F}}}}}_{W}\) and \({p}_{i} \sim {{{{\mathcal{P}}}}}_{W}\), where \({{{{\mathcal{F}}}}}_{W}\) denotes a function generator and \({{{{\mathcal{P}}}}}_{W}\) a stochastic generator, that are specific to the applied problem setting W.
MGA is motivated by the notion that, in any particular real-world system, underlying (biological, physical, social and so on) processes tend to share some functional and stochastic aspects, which impart a higher-level regularity. That is, MGA states that, in a given applied context, the functions f_{i} and (ICM-consistent) noise terms \({U}_{{X}_{i}}\), while unknown, varied and potentially complex, are nonetheless related at a ‘meta’ level. The generators \({{{{\mathcal{F}}}}}_{W},\,{{{{\mathcal{P}}}}}_{W}\) are random processes, representing, respectively, a ‘distribution over functions’ and a ‘distribution over distributions’, whose role here is to capture the notion of relatedness among the f_{i} (respectively, the p_{i}) in a given setting W. Note that \({{{{\mathcal{F}}}}}_{W},{{{{\mathcal{P}}}}}_{W}\) are treated as unknown and are never directly estimated.
As mentioned in the problem statement, we focus on the causal status of variable pairs (X_{i}, X_{j}) (rather than general tuples), which represents the simplest possible case under MGA. Furthermore, in both our work and the majority of interventional studies in applications such as biology, single interventions (rather than joint interventions on multiple nodes) are the norm. This motivates the additional assumption, DCSI.
Dominant cause under single interventions (DCSI)
A sufficiently large change in one of potentially multiple causes leads to a change in the effect. Hence, single interventions are sufficient to drive variation in the child distribution.
From MGA and DCSI to discriminative causal structure learning
Consider an applied problem W with underlying causal graph \({G}_{W}^{* }\), treated as fixed but unknown. The associated functions and noise terms are also unknown but are assumed to follow MGA. Then, under DCSI, all causal pairs (X_{i}, X_{j}) have underlying relationships of the form \({X}_{j}={f}_{j}({X}_{i},\,{U}_{{X}_{j}})\), with components following MGA (that is, drawn from the generators \({{{{\mathcal{F}}}}}_{W},\,{{{{\mathcal{P}}}}}_{W}\)). This in turn suggests that, within the setting W, identification of causal pairs can be treated as a classification problem, because all pairs share the same generators. In other words, MGA restricts the distribution over relations of variables and noise terms to system-specific generators.
Note that no particular assumption is made on the individual functions f_{j}, only that they are mutually related at a higher level. Furthermore, the generators themselves need not be known nor directly estimated; it is only important that they are shared across the applied setting W. Note that a model learned for setting W will not, in general, be able to classify pairs in an entirely different applied setting W′ (because the generators may then differ strongly); that is, we do not seek to learn ‘universal’ patterns that apply to all causal relations in any system whatsoever. The classification task of D^{2}CL aims to tell apart causal relationships (assumed drawn from the system-specific generators) from non-causal ones. We note that, in real systems, the f_{i} may be coupled via constraints on global functionality, and are thus non-independent; however, the good performance seen in the results empirically justifies the approach.

Despite the initial theoretical ideas described above, rigorous theory for the kind of approach studied here remains to be developed, in particular the precise conditions on the underlying system needed to ensure that the classification-type approach can guarantee recovery of specific causal structures. We emphasize also that, in contrast to classical causal learning schemes (for example, those based on causal DAGs), we cannot at this stage make theoretical statements concerning underlying multivariate distributions and their link to the edges estimated by our models. Our goal is good performance in an edge-wise sense (as detailed above), and the core assumptions (formalized above) concern a limited notion of classifiability. We note also that our models at present learn edges separately and do not impose any particular wider/global constraints (such as acyclicity or path constraints), although this could in principle be done within the causal risk framework.
Architecture details
CNN tower
To capture distributional information from the empirical data X, a preprocessing step is required. In principle, this could be done via a variety of multidimensional transformations of X. We consider the simplest possible case: for a pair (i, j), we consider only the corresponding columns i and j in the data matrix X. Specifically, we use the n × 2 submatrix X_{(⋅, [ij])} to form a bivariate kernel density estimate f_{ij} = KDE(X_{(⋅, [ij])}). Note that this is, in general, asymmetric in the sense that f_{ij} ≠ f_{ji}, which is important because we want to learn ordered/directed relationships. In other words, this ensures that, in general, the CNN tower can output different probabilities for the edges A → B and B → A (for any two nodes A and B). Evaluations of the KDE at equally spaced grid points on the plane (that is, numerical values from the induced density function) are treated as the input to the CNN. The KDE itself is a standard bivariate approach using automated bandwidth selection following refs. ^{34,35}. This provides an ‘image’ of the data and allows us to leverage existing image-analysis ideas. Furthermore, we concatenate channel-wise the numerical KDE values on the regularly spaced grid with a positional encoding of the grid points.
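The KDE ‘image’ construction can be sketched as follows. This is a minimal illustration, not the authors' implementation: SciPy's default Scott's-rule bandwidth stands in for the automated selection of refs. ^{34,35}, and the grid size and coordinate-based positional encoding are illustrative assumptions.

```python
import numpy as np
from scipy.stats import gaussian_kde

def kde_image(X, i, j, grid_size=32):
    """Bivariate KDE of columns (i, j) of X, evaluated on a regular grid
    and concatenated channel-wise with a positional encoding of the grid."""
    xy = X[:, [i, j]].T                                   # 2 x n submatrix X_(., [ij])
    kde = gaussian_kde(xy)                                # Scott's rule bandwidth (stand-in)
    gx = np.linspace(xy[0].min(), xy[0].max(), grid_size)
    gy = np.linspace(xy[1].min(), xy[1].max(), grid_size)
    gxx, gyy = np.meshgrid(gx, gy, indexing="ij")
    density = kde(np.vstack([gxx.ravel(), gyy.ravel()]))  # KDE values on the grid
    density = density.reshape(grid_size, grid_size)
    # channels: density 'image' plus x/y grid-coordinate encodings
    return np.stack([density, gxx, gyy], axis=0)

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
img = kde_image(X, 0, 1)          # 3-channel image for the ordered pair (0, 1)
```

Because the pair is ordered, `kde_image(X, i, j)` and `kde_image(X, j, i)` produce different inputs, matching the asymmetry discussed above.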
The concrete network architecture of our CNN tower is inspired by a ResNet54 architecture^{36}. From a high-level perspective, it consists of a stem, five stages with [3, 4, 6, 3, 3] ResNet blocks and multiple fully connected layers that transform the high-level feature maps into a latent space that is merged with the output of the GNN tower. The first ResNet block at each stage downsamples the spatial dimensions of the output of the previous stage by a factor of two. To enhance the computational efficiency of the bottleneck layers in each ResBlock, channel down- and upsampling using 1 × 1 convolutions is performed before and after each feature-extraction CNN layer^{37}. We replaced ReLU activations with the parametric counterpart PReLU^{38}, allowing the slope of the negative part to be learned at negligible additional computational cost, which addresses the problem of dying neurons. Following ref. ^{39}, we chose a full pre-activation ordering of the convolutional layers (normalization–activation–convolution).
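A pre-activation bottleneck block of this kind could be sketched in PyTorch as follows; the channel widths, bottleneck ratio and skip-connection handling are illustrative assumptions, not the authors' exact configuration:

```python
import torch
import torch.nn as nn

class PreActBottleneck(nn.Module):
    """Bottleneck ResNet block in full pre-activation order
    (normalization-activation-convolution) with PReLU activations."""

    def __init__(self, channels, bottleneck=4, stride=1):
        super().__init__()
        mid = channels // bottleneck
        self.body = nn.Sequential(
            nn.BatchNorm2d(channels), nn.PReLU(),
            nn.Conv2d(channels, mid, 1, bias=False),      # 1x1 channel downsampling
            nn.BatchNorm2d(mid), nn.PReLU(),
            nn.Conv2d(mid, mid, 3, stride=stride,
                      padding=1, bias=False),             # feature-extraction layer
            nn.BatchNorm2d(mid), nn.PReLU(),
            nn.Conv2d(mid, channels, 1, bias=False),      # 1x1 channel upsampling
        )
        # match spatial dims on the skip path when the block downsamples
        self.skip = (nn.Identity() if stride == 1 else
                     nn.Conv2d(channels, channels, 1, stride=stride, bias=False))

    def forward(self, x):
        return self.skip(x) + self.body(x)

block = PreActBottleneck(64, stride=2)       # first block of a stage: halves H and W
y = block(torch.randn(1, 64, 32, 32))
```

Stacking such blocks in five stages of [3, 4, 6, 3, 3], with the first block of each stage using `stride=2`, reproduces the downsampling pattern described above.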
GNN tower
Our GNN tower builds on the SEAL architecture of ref. ^{40} and the resulting graph convolutional neural network (GCNN) for link prediction. The underlying notion is that a heuristic function predicts scores for the existence of a link. However, instead of employing predefined heuristics (such as the Katz coefficient or PageRank), an adaptive function is learned in an end-to-end fashion, formulated as a graph classification problem on enclosing subgraphs. Reference ^{40} showed that a γ-decaying heuristic can be approximated from an h-hop neighbourhood with an approximation error that decreases at least exponentially in h. These findings suggest that it is possible to learn high-order graph structure features from local enclosing subgraphs instead of the entire graph, which can be exploited for link prediction. Consider the pair of nodes of interest (i, j); the GNN tower is intended to infer causally relevant node features and state embeddings from a local 1-hop enclosing subgraph extracted from the approximated input graph \({\hat{G}}_{0}\). For node pair (i, j), we first extract the set of nodes \({{{\mathcal{N}}}}\) containing all nodes connected to either node i or node j according to the adjacency matrix of \({\hat{G}}_{0}\). The edge structure within the subgraph G_{i, j} is then reconstructed by pulling out all edges from \({\hat{G}}_{0}\) for which both the parent and child node are in \({{{\mathcal{N}}}}\). The order of the nodes is shuffled for each subgraph. The node features in every input subgraph consist of structural node labels assigned by a double-radius node labelling (DRNL) heuristic^{40} together with the individual data features. In a first step, the distances between node i and all other nodes of the local subgraph except node j are computed. The same is repeated for node j.
A hashing function then transforms the two distance labels into a DRNL label that assigns the same label to nodes that lie on the same ‘orbit’ around the centre nodes i and j. During training, the DRNL label is transformed into a one-hot encoded vector and passed to the first graph convolutional layer. In contrast to traditional CNNs, GCNNs do not benefit strongly from very deep architectures^{41,42}. Therefore, our GNN tower consists of only four sequentially stacked graph convolutional layers. The activation function is the hyperbolic tangent. Because the number of nodes in the enclosing subgraph differs between variable pairs (i, j), a SortPooling layer^{43} is applied to select the top k nodes according to their structural role within the graph. Afterwards, one-dimensional convolutions extract features from the selected state embeddings.
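The DRNL computation can be sketched as follows, using the hashing function proposed for SEAL in ref. ^{40}; the BFS helper and the adjacency-dict graph representation are illustrative, not the authors' implementation:

```python
from collections import deque

def bfs_dist(adj, src, banned):
    """Shortest-path distances from `src` in graph `adj`, treating
    node `banned` as removed (the other centre node)."""
    dist = {src: 0}
    q = deque([src])
    while q:
        u = q.popleft()
        for v in adj.get(u, ()):
            if v != banned and v not in dist:
                dist[v] = dist[u] + 1
                q.append(v)
    return dist

def drnl_labels(adj, i, j):
    """Double-radius node labels for the enclosing subgraph around (i, j):
    nodes on the same 'orbit' around the centre pair receive the same label."""
    di = bfs_dist(adj, i, banned=j)          # distances to i, with j removed
    dj = bfs_dist(adj, j, banned=i)          # distances to j, with i removed
    labels = {i: 1, j: 1}                    # centre nodes get label 1
    for v in set(adj) - {i, j}:
        if v not in di or v not in dj:
            labels[v] = 0                    # unreachable from a centre node
            continue
        dx, dy = di[v], dj[v]
        d = dx + dy
        # hashing of the distance pair (dx, dy) into a single label
        labels[v] = 1 + min(dx, dy) + (d // 2) * ((d // 2) + d % 2 - 1)
    return labels

# cycle 0-1-2-3-0 with centre pair (0, 2): nodes 1 and 3 share an orbit
adj = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
labels = drnl_labels(adj, 0, 2)
```

In the example, nodes 1 and 3 are symmetric with respect to the centre pair and therefore receive the same label, which is exactly the structural information the first graph convolutional layer consumes (after one-hot encoding).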
Embedding fusion
Each tower outputs a high-dimensional embedding of the individual features found. These embeddings are concatenated and further processed by multiple fully connected layers. Finally, the last layers output the log-likelihood of a directed edge from node i to node j.
Implementation details
All network architectures are implemented in the open-source framework PyTorch^{44}. The GNN is implemented using the Deep Graph Library^{45}. All modules are initialized from scratch with random weights. During training, we use the Adam optimizer^{46}, starting at an initial learning rate of ϵ_{0} = 0.0001. The learning rate is reduced by a factor of five once the evaluation metric stops improving for 15 consecutive epochs. The minimum learning rate is set to ϵ_{min} = 10^{−8}. The training predictions are supervised using the binary cross-entropy loss between estimated and ground-truth edge labels. The evaluation metric is the (held-out) area under the ROC curve. Every network architecture is trained for 100 epochs. All computations are run on multiple graphics processing unit (GPU) nodes simultaneously, each equipped with eight Nvidia Tesla V100 GPUs.
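A minimal PyTorch sketch of the optimization settings reported above (Adam at 10^{−4}, reduce-on-plateau by a factor of five with patience 15, minimum learning rate 10^{−8}, binary cross-entropy on edge labels); the model and data here are toy stand-ins for the two-tower network and pair embeddings:

```python
import torch
from torch import nn

model = nn.Linear(128, 1)                    # toy stand-in for the fused two-tower network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",       # the monitored metric is held-out AUROC (higher is better)
    factor=0.2,       # 'reduced by a factor of five'
    patience=15,      # after 15 epochs without improvement
    min_lr=1e-8,
)
criterion = nn.BCEWithLogitsLoss()           # binary cross-entropy on edge labels

features = torch.randn(32, 128)              # illustrative pair embeddings
targets = torch.randint(0, 2, (32, 1)).float()

logits = model(features)                     # one training step
loss = criterion(logits, targets)
loss.backward()
optimizer.step()
scheduler.step(0.5)                          # step the schedule on the epoch's AUROC
```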
Data availability
Data files are publicly available as follows. Yeast gene deletion data are from ref. ^{25}. CRISPR perturbation data are from ref. ^{32}. The pseudocode for data simulation is provided in Supplementary section 5.
Code availability
A Code Ocean compute capsule, which contains a prebuilt compute environment and the source code of D^{2}CL, is available at https://codeocean.com/capsule/4465854/tree/v1 ref. ^{47}.
References
Peters, J., Janzing, D. & Schölkopf, B. Elements of Causal Inference: Foundations and Learning Algorithms (MIT Press, 2017).
Arjovsky, M., Bottou, L., Gulrajani, I. & Lopez-Paz, D. Invariant risk minimization. Preprint at https://arxiv.org/abs/1907.02893 (2019).
Heinze-Deml, C., Maathuis, M. H. & Meinshausen, N. Causal structure learning. Annu. Rev. Stat. Appl. 5, 371–391 (2018).
Spirtes, P., Glymour, C. & Scheines, R. Causation, Prediction and Search (MIT Press, 2000).
Shimizu, S., Hoyer, P. O., Hyvärinen, A. & Kerminen, A. A linear non-Gaussian acyclic model for causal discovery. J. Mach. Learn. Res. 7, 2003–2030 (2006).
Maathuis, M. H., Kalisch, M. & Bühlmann, P. Estimating high-dimensional intervention effects from observational data. Ann. Stat. 37, 3133–3164 (2009).
Hauser, A. & Bühlmann, P. Characterization and greedy learning of interventional Markov equivalence classes of directed acyclic graphs. J. Mach. Learn. Res. 13, 2409–2464 (2012).
Colombo, D., Maathuis, M. H., Kalisch, M. & Richardson, T. S. Learning high-dimensional directed acyclic graphs with latent and selection variables. Ann. Stat. 40, 294–321 (2012).
Peters, J., Bühlmann, P. & Meinshausen, N. Causal inference using invariant prediction: identification and confidence intervals. J. R. Stat. Soc. 78, 947–1012 (2016).
Hill, S. M., Oates, C. J., Blythe, D. A. & Mukherjee, S. Causal learning via manifold regularization. J. Mach. Learn. Res. 20, 127 (2019).
Zheng, X., Aragam, B., Ravikumar, P. K. & Xing, E. P. DAGs with no tears: continuous optimization for structure learning. In Proc. Advances in Neural Information Processing Systems Vol. 31, 9472–9483 (eds Bengio, S. et al.) (Curran Associates, 2018).
Ke, N. R. et al. Learning neural causal models from unknown interventions. Preprint at https://arxiv.org/abs/1910.01075 (2019).
Brouillard, P., Lachapelle, S., Lacoste, A., Lacoste-Julien, S. & Drouin, A. Differentiable causal discovery from interventional data. Adv. Neural Inf. Process. Syst. 33, 21865–21877 (2020).
Lopez, R., Hütter, J.-C., Pritchard, J. & Regev, A. Large-scale differentiable causal discovery of factor graphs. Adv. Neural Inf. Process. Syst. 35, 19290–19303 (2022).
Lippe, P., Cohen, T. & Gavves, E. Efficient neural causal discovery without acyclicity constraints. In International Conference on Learning Representations (2022).
Ideker, T. & Krogan, N. J. Differential network biology. Mol. Syst. Biol. 8, 565 (2012).
Hill, S. M. et al. Inferring causal molecular networks: empirical assessment through a community-based effort. Nat. Methods 13, 310–318 (2016).
Hill, S. M. et al. Context specificity in causal signaling networks revealed by phosphoprotein profiling. Cell Syst. 4, 73–83 (2017).
Kuenzi, B. M. & Ideker, T. A census of pathway maps in cancer systems biology. Nat. Rev. Cancer 20, 233–246 (2020).
Lopez-Paz, D., Muandet, K., Schölkopf, B. & Tolstikhin, I. Towards a learning theory of cause-effect inference. In Proc. 32nd International Conference on Machine Learning Vol. 37, 1452–1461 (eds Bach, F. et al.) (PMLR, 2015).
Mooij, J. M., Peters, J., Janzing, D., Zscheischler, J. & Schölkopf, B. Distinguishing cause from effect using observational data: methods and benchmarks. J. Mach. Learn. Res. 17, 1–102 (2016).
Noè, U., Taschler, B., Täger, J., Heutink, P. & Mukherjee, S. Ancestral causal learning in high dimensions with a human genome-wide application. Preprint at https://arxiv.org/abs/1905.11506 (2019).
Eigenmann, M., Mukherjee, S. & Maathuis, M. Evaluation of causal structure learning algorithms via risk estimation. In Proc. 36th Conference on Uncertainty in Artificial Intelligence (UAI 2020) Vol. 124, 151–160 (eds Peters, J. et al.) (PMLR, 2020).
Ke, N. R. et al. Learning to induce causal structure. Preprint at https://arxiv.org/abs/2204.04875 (2022).
Kemmeren, P. et al. Large-scale genetic perturbations reveal regulatory networks and an abundance of gene-specific repressors. Cell 157, 740–752 (2014).
Meinshausen, N. et al. Methods for causal inference from gene perturbation experiments and validation. Proc. Natl Acad. Sci. USA 113, 7361–7368 (2016).
Zhang, J. Causal reasoning with ancestral graphs. J. Mach. Learn. Res. 9, 1437–1474 (2008).
Alon, U. An Introduction to Systems Biology: Design Principles of Biological Circuits (CRC Press, 2019).
Hyttinen, A., Eberhardt, F. & Hoyer, P. O. Learning linear cyclic causal models with latent variables. J. Mach. Learn. Res. 13, 3387–3439 (2012).
Eberhardt, F. & Scheines, R. Interventions and causal inference. Philos. Sci. 74, 981–995 (2007).
Kocaoglu, M., Shanmugam, K. & Bareinboim, E. Experimental design for learning causal graphs with latent variables. In Proc. Advances in Neural Information Processing Systems Vol. 30, 7018–7028 (eds Guyon, I. et al.) (Curran Associates, 2017).
Replogle, J. M. et al. Mapping information-rich genotype–phenotype landscapes with genome-scale Perturb-seq. Cell 185, 2559–2575 (2022).
Schölkopf, B. et al. On causal and anticausal learning. In Proc. 29th International Conference on Machine Learning, ICML 2012 459–466 (eds Langford, J. et al.) (icml.cc/Omnipress, 2012).
Silverman, B. W. Density Estimation for Statistics and Data Analysis (Chapman & Hall, 1986).
Turlach, B. Bandwidth selection in kernel density estimation: a review. Technical Report (1999).
He, K., Zhang, X., Ren, S. & Sun, J. Deep residual learning for image recognition. In Proc. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 770–778 (IEEE, 2016).
Szegedy, C. et al. Going deeper with convolutions. In Proc. 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 1–9 (IEEE, 2015).
He, K., Zhang, X., Ren, S. & Sun, J. Delving deep into rectifiers: surpassing human-level performance on ImageNet classification. In Proc. 2015 IEEE International Conference on Computer Vision (ICCV) 1026–1034 (IEEE, 2015).
Xie, S., Girshick, R., Dollár, P., Tu, Z. & He, K. Aggregated residual transformations for deep neural networks. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 5998–5995 (IEEE, 2017).
Zhang, M. & Chen, Y. Link prediction based on graph neural networks. In Proc. Advances in Neural Information Processing Systems 2018 Vol. 31, 5165–5175 (eds Bengio, S. et al.) (Curran Associates, 2018).
Chen, D. et al. Measuring and relieving the over-smoothing problem for graph neural networks from the topological view. In Proc. 34th AAAI Conference on Artificial Intelligence https://doi.org/10.1609/aaai.v34i04.5747 (AAAI, 2020).
Li, Q., Han, Z. & Wu, X.-M. Deeper insights into graph convolutional networks for semi-supervised learning. In Proc. 32nd AAAI Conference on Artificial Intelligence 3538–3545 (eds McIlraith, S. et al.) (AAAI, 2018).
Zhang, M., Cui, Z., Neumann, M. & Chen, Y. An end-to-end deep learning architecture for graph classification. In Proc. 32nd AAAI Conference on Artificial Intelligence 4438–4445 (eds McIlraith, S. et al.) (AAAI, 2018).
Paszke, A. et al. PyTorch: an imperative style, high-performance deep learning library. In Proc. Advances in Neural Information Processing Systems Vol. 32, 8026–8037 (eds Wallach, H. et al.) (Curran Associates, 2019).
Wang, M. et al. Deep Graph Library: a graph-centric, highly-performant package for graph neural networks. Preprint at https://arxiv.org/abs/1909.01315 (2019).
Kingma, D. P. & Ba, J. Adam: a method for stochastic optimization. In 3rd International Conference on Learning Representations (2015).
Lagemann, K., Lagemann, C., Taschler, B. & Mukherjee, S. Deep learning of causal structures in high dimensions under data limitations. Code Ocean https://codeocean.com/capsule/4465854/tree/v1 (2023).
Acknowledgements
This work was partly supported by the German Federal Ministry of Education and Research (BMBF) project ‘LODE’, the UK Medical Research Council (MCUU00002/17) and the National Institute for Health Research (Cambridge Biomedical Research Centre at the Cambridge University Hospitals NHS Foundation Trust).
Funding
Open access funding provided by Deutsches Zentrum für Neurodegenerative Erkrankungen e.V. (DZNE) in der HelmholtzGemeinschaft.
Author information
Contributions
Methods were developed by K.L. and S.M. Implementation and experiments were performed by K.L., supported by C.L. B.T. contributed to the design and implementation of experiments using the baseline algorithms. The manuscript was written by K.L. and S.M., with input from C.L. and B.T. The research was supervised by S.M.
Ethics declarations
Competing interests
The authors declare no competing interests.
Peer review information
Nature Machine Intelligence thanks the anonymous reviewers for their contribution to the peer review of this work. Primary Handling Editor: Liesbeth Venema, in collaboration with the Nature Machine Intelligence team.
Additional information
Publisher’s note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Supplementary information
Supplementary Information
Supplementary discussion, Figs. 1–3 and Tables 1–12.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons license, and indicate if changes were made. The images or other third party material in this article are included in the article’s Creative Commons license, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons license and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Lagemann, K., Lagemann, C., Taschler, B. et al. Deep learning of causal structures in high dimensions under data limitations. Nat. Mach. Intell. 5, 1306–1316 (2023). https://doi.org/10.1038/s42256-023-00744-z