Introduction

Our capacity to prevent or contain outbreaks of infectious diseases is directly linked to our ability to accurately model contagion dynamics. Since the seminal work of Kermack and McKendrick almost a century ago1, a variety of models incorporating ever more sophisticated contagion mechanisms have been proposed2,3,4,5. These mechanistic models have provided invaluable insights into how infectious diseases spread, and have thereby contributed to the design of better public health policies. However, several challenges remain unresolved, which call for contributions from new modeling approaches6,7,8.

For instance, many complex contagion processes involve the nontrivial interaction of several pathogens9,10,11,12, and some social contagion phenomena, like the spread of misinformation, require going beyond pairwise interactions between individuals13,14,15. Also, while qualitatively informative, the forecasts of most mechanistic models lack quantitative accuracy16. Indeed, most models are built from a handful of mechanisms that can hardly reproduce the intricacies of real complex contagion dynamics. One way to address these challenges is to make the models more complex by adding more detailed and sophisticated mechanisms. However, mechanistic models rapidly become intractable as new mechanisms are added. Moreover, models of higher complexity require the specification of a large number of parameters whose values can be difficult to infer from limited data.

Interest has recently grown in using machine learning to address the often-limiting complexity of mechanistic models12,17,18,19,20,21,22,23. This kind of approach aims to train predictive models directly from observational time series data. These data-driven models are then used for various tasks, such as making accurate predictions19,21, gaining useful intuitions about complex phenomena12, and discovering new patterns from which better mechanisms can be designed17,18. Although these approaches were originally designed for regularly structured data, this new paradigm is now being applied to epidemics spreading on networked systems24,25 and, more generally, to dynamical systems26,27,28. Meanwhile, the machine learning community has dedicated considerable attention to deep learning on networks, structure learning and graph neural networks (GNN)29,30,31. Recent works have shown great promise for GNN in the context of community detection32, link prediction33 and network inference34, as well as for the discovery of new materials and drugs35,36. Yet, others have pointed out the inherent limitations of a majority of GNN architectures in distinguishing certain network structures31, which in turn limits their learning capabilities. Hence, while recent advances and results37,38,39,40 suggest that GNN could be prime candidates for building effective data-driven dynamical models on networks, it remains to be shown if, how and when GNN can be applied to dynamics learning problems.

In this paper, we show how GNN, usually used for structure learning, can also be used to model contagion dynamics on complex networks. Our contribution is threefold. First, we design a training procedure and an appropriate GNN architecture capable of representing a wide range of dynamics with very few assumptions. Second, we demonstrate the validity of our approach using various contagion dynamics of increasing complexity on networks of different natures, as well as real epidemiological data. Finally, we show how our approach can provide predictions for previously unseen network structures, therefore allowing the exploration of the properties of the learned dynamics beyond the training data. Our work generalizes the idea of constructing dynamical models from regularly structured data to arbitrary network structures, and suggests that our approach could be accurately extended to many other classes of dynamical processes.

Results

In our approach, we assume that an unknown dynamical process, denoted \({{{{{\mathcal{M}}}}}}\), takes place on a known network structure (or ensemble of networks), denoted \(G=({{{{{\mathcal{V}}}}}},{{{{{\mathcal{E}}}}}};{{{{{\boldsymbol{\Phi }}}}}},{{{{{\boldsymbol{\Omega }}}}}})\), where \({{{{{\mathcal{V}}}}}}=\left\{{v}_{1},\cdots \ ,{v}_{N}\right\}\) is the node set and \({{{{{\mathcal{E}}}}}}=\{{e}_{ij}| {v}_{j}\,{{\mbox{is connected to}}}\;{v}_{i}{{\mbox{}}}\wedge ({v}_{i},{v}_{j})\in {{{{{{\mathcal{V}}}}}}}^{2}\}\) is the edge set. We also assume that the network(s) contain metadata, taking the form of node and edge attributes denoted Φi = (ϕ1(vi), ⋯, ϕQ(vi)) for node vi and Ωij = (ω1(eij), ⋯, ωP(eij)) for edge eij, respectively, where \({\phi }_{q}:{{{{{\mathcal{V}}}}}}\to {\mathbb{R}}\) and \({\omega }_{p}:{{{{{\mathcal{E}}}}}}\to {\mathbb{R}}\). These metadata can take various forms, such as node characteristics or edge weights. We also denote the node and edge attribute matrices \({{{{{\boldsymbol{\Phi }}}}}}=({{{\Phi }}}_{i}| {v}_{i}\in {{{{{\mathcal{V}}}}}})\) and \({{{{{\boldsymbol{\Omega }}}}}}=({{{\Omega }}}_{ij}| {e}_{ij}\in {{{{{\mathcal{E}}}}}})\), respectively.
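To fix ideas, the objects above can be collected in a minimal container for an attributed network G = (V, E; Φ, Ω). This is only an illustrative sketch; the class and attribute names are not taken from the paper's implementation.

```python
import numpy as np

class AttributedNetwork:
    """Minimal container for G = (V, E; Phi, Omega): adjacency lists plus
    node attributes Phi_i and edge attributes Omega_ij (illustrative only)."""

    def __init__(self, n_nodes):
        self.n = n_nodes
        self.neighbors = {i: [] for i in range(n_nodes)}   # N_i for each node
        self.node_attr = {}                                # Phi_i
        self.edge_attr = {}                                # Omega_ij

    def add_edge(self, i, j, omega=None):
        # store the edge e_ij ("v_j is connected to v_i"), with its attributes
        self.neighbors[i].append(j)
        if omega is not None:
            self.edge_attr[(i, j)] = np.asarray(omega, dtype=float)

    def set_node_attr(self, i, phi):
        self.node_attr[i] = np.asarray(phi, dtype=float)

# a tiny weighted, undirected example: store both directions explicitly
g = AttributedNetwork(3)
g.add_edge(0, 1, omega=[0.5])
g.add_edge(1, 0, omega=[0.5])
g.set_node_attr(0, [1.0])
```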

Next, we assume that the dynamics \({{{{{\mathcal{M}}}}}}\) has generated a time series D on the network G. This time series takes the form of a pair of snapshot sequences D = (X, Y), with \({{{{{\boldsymbol{X}}}}}}=\left({X}_{1},\cdots \ ,{X}_{T}\right)\) and \({{{{{\boldsymbol{Y}}}}}}=\left({Y}_{1},\cdots \ ,{Y}_{T}\right)\), where \({X}_{t}\in {{{{{{\mathcal{S}}}}}}}^{| {{{{{\mathcal{V}}}}}}| }\) is the state of the nodes at time t and \({Y}_{t}\in {{{{{{\mathcal{R}}}}}}}^{| {{{{{\mathcal{V}}}}}}| }\) is the outcome of \({{{{{\mathcal{M}}}}}}\), defined as

$${Y}_{t}={{{{{\mathcal{M}}}}}}\left({X}_{t},G\right)\ ,$$
(1)

\({{{{{\mathcal{S}}}}}}\) is the set of possible node states, and \({{{{{\mathcal{R}}}}}}\) is the set of possible node outcomes. This way of defining D allows us to formally concatenate multiple realizations of the dynamics in a single dataset. Additionally, the elements \({x}_{i}(t) = {({X}_{t})}_{i}\) and \({y}_{i}(t) = {({Y}_{t})}_{i}\) correspond to the state of node vi at time t and its outcome, respectively. Typically, we consider that the outcome yi(t) is simply the state of node vi after transitioning from state xi(t). In this case, we have \({{{{{\mathcal{S}}}}}}={{{{{\mathcal{R}}}}}}\) and xi(t + Δt) = yi(t), where Δt is the length of the time steps. However, if \({{{{{\mathcal{S}}}}}}\) is a discrete set—i.e. finite and countable—yi(t) is a transition probability vector conditioned on xi(t), from which the following state, xi(t + Δt), is sampled. The element \({\left({y}_{i}(t)\right)}_{m}\) corresponds to the probability that node vi evolves to state \(m\in {{{{{\mathcal{S}}}}}}\) given that it was previously in state xi(t)—i.e. \({{{{{\mathcal{R}}}}}}={[0,1]}^{| {{{{{\mathcal{S}}}}}}| }\). When \({{{{{\mathcal{M}}}}}}\) is stochastic, we typically do not have access to the transition probabilities yi(t) directly, but rather to the observed outcome state—e.g. xi(t + Δt) when X is temporally ordered. We therefore define the observed outcome \({\tilde{y}}_{i}(t)\) as

$${({\tilde{y}}_{i}(t))}_{m}=\delta \left({x}_{i}(t+{{\Delta }}t),m\right)\ ,\ \ \forall m\in {{{{{\mathcal{S}}}}}}$$
(2)

where δ(x, y) is the Kronecker delta. Finally, we assume that \({{{{{\mathcal{M}}}}}}\) acts on Xt locally and identically at all times, according to the structure of G. In other words, knowing the state xi as well as the states of all the neighbors of vi, the outcome yi is computed using a time-independent function f, identical for all nodes,

$${y}_{i} = f\left({x}_{i},{{{\Phi }}}_{i},{x}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Phi }}}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Omega }}}_{i{{{{{{\mathcal{N}}}}}}}_{i}}\right)\ ,$$
(3)

where \({x}_{{{{{{{\mathcal{N}}}}}}}_{i}}=({x}_{j}| {v}_{j}\in {{{{{{\mathcal{N}}}}}}}_{i})\) denotes the states of the neighbors, \({{{{{{\mathcal{N}}}}}}}_{i} = \{{v}_{j}| {e}_{ij}\in {{{{{\mathcal{E}}}}}}\}\) is the set of the neighbors, \({{{\Phi }}}_{{{{{{{\mathcal{N}}}}}}}_{i}} = \{{{{\Phi }}}_{j}| {v}_{j}\in {{{{{{\mathcal{N}}}}}}}_{i}\}\) and \({{{\Omega }}}_{i{{{{{{\mathcal{N}}}}}}}_{i}} = \{{{{\Omega }}}_{ij}| {v}_{j}\in {{{{{{\mathcal{N}}}}}}}_{i}\}\). As a result, we impose a notion of locality where the underlying dynamics is time invariant and invariant under the permutation of the node labels in G, under the assumption that the node and edge attributes are left invariant.
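A minimal sketch of Eqs. (2) and (3) may help fix the notation: the observed outcome is a one-hot vector over the state space, and a local rule is applied node by node over an adjacency list. The toy rule below (a node becomes infected if any neighbor is) and the simplified signature of `local_outcome`, which omits the attributes Φ and Ω, are illustrative assumptions, not the paper's dynamics.

```python
import numpy as np

def observed_outcome(x_next, states):
    """Eq. (2): one-hot vector with (y~_i)_m = delta(x_i(t + dt), m)."""
    return np.array([1.0 if x_next == m else 0.0 for m in states])

def local_outcome(f, x, adj):
    """Apply a time-independent local rule f (a simplified Eq. (3),
    without node/edge attributes) to every node of a network given as
    an adjacency list {i: [neighbors of i]}."""
    return {i: f(x[i], [x[j] for j in neigh]) for i, neigh in adj.items()}

# toy rule: a node is (or becomes) infected (1) if it or any neighbor is infected
f = lambda x_i, x_neigh: 1 if x_i == 1 or any(x_neigh) else 0
y = local_outcome(f, x={0: 0, 1: 1, 2: 0}, adj={0: [1], 1: [0, 2], 2: []})
```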

Our objective is to build a model \(\hat{{{{{{\mathcal{M}}}}}}}\), parametrized by a GNN with a set of tunable parameters Θ and trained on the observed dataset D to mimic \({{{{{\mathcal{M}}}}}}\) given G, such that

$$\hat{{{{{{\mathcal{M}}}}}}}({X}_{t}^{\prime},G^{\prime} ;{{{{{\boldsymbol{\Theta }}}}}})\approx {{{{{\mathcal{M}}}}}}({X}_{t}^{\prime},G^{\prime} )\ ,$$
(4)

for all states \({X}_{t}^{\prime}\) and all networks \(G^{\prime}\). The architecture of \(\hat{{{{{{\mathcal{M}}}}}}}\), detailed in Section “Graph neural network and training details”, is designed to act locally, similarly to \({{{{{\mathcal{M}}}}}}\). Here, locality is imposed by a modified attention mechanism inspired by ref. 41. Imposing locality makes our architecture inductive: if the GNN is trained on a wide range of local structures—i.e. nodes with different neighborhood sizes (or degrees) and states—it can then be used on any other network within that range. This suggests that the topology of G will have a strong impact on the quality of the trained models, an intuition that is confirmed below. Similarly to Eq. (3), we can write each individual node outcome computed by the GNN using a function \(\hat{f}\) such that

$${\hat{y}}_{i} = \hat{f}\left({x}_{i},{{{\Phi }}}_{i},{x}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Phi }}}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Omega }}}_{i{{{{{{\mathcal{N}}}}}}}_{i}};{{{\boldsymbol{\Theta }}}}\right)$$
(5)

where \({\hat{y}}_{i}\) is the outcome of node vi predicted by \(\hat{{{{{{\mathcal{M}}}}}}}\).
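The locality and permutation invariance required of \(\hat{f}\) can be illustrated with a toy attention-style aggregation, loosely inspired by (but not identical to) the mechanism of ref. 41; the weight vector `w_att` and the feature dimensions are arbitrary stand-ins.

```python
import numpy as np

def softmax(z):
    z = z - z.max()           # numerically stable softmax
    e = np.exp(z)
    return e / e.sum()

def attention_aggregate(h_i, h_neigh, w_att):
    """Toy local attention step: each neighbor j gets a score a_ij computed
    from the pair (h_i, h_j), and the aggregate is sum_j a_ij * h_j.
    The output depends only on v_i's neighborhood, so the layer is local
    and invariant under relabeling (permutation) of the neighbors."""
    scores = np.array([w_att @ np.concatenate([h_i, h_j]) for h_j in h_neigh])
    a = softmax(scores)
    return (a[:, None] * np.asarray(h_neigh)).sum(axis=0)

rng = np.random.default_rng(0)
h_i = rng.normal(size=4)                      # hidden features of node v_i
h_neigh = [rng.normal(size=4) for _ in range(3)]   # features of its neighbors
w_att = rng.normal(size=8)                    # arbitrary attention weights
out = attention_aggregate(h_i, h_neigh, w_att)
```

Shuffling the neighbor list leaves the output unchanged (up to floating-point summation order), which is the invariance property imposed on the architecture.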

The objective described by Eq. (4) must be encoded into a global loss function, denoted \({{{{{\mathcal{L}}}}}}({{{{{\boldsymbol{\Theta }}}}}})\). Like the outcome functions, \({{{{{\mathcal{L}}}}}}({{{{{\boldsymbol{\Theta }}}}}})\) can be decomposed locally: the local losses \(L({y}_{i},{\hat{y}}_{i})\) of each node, where yi and \({\hat{y}}_{i}\) are given by Eqs. (3) and (5), respectively, are arithmetically averaged over all possible node inputs \(({x}_{i},{{{\Phi }}}_{i},{x}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Phi }}}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Omega }}}_{i{{{{{{\mathcal{N}}}}}}}_{i}})\). By using an arithmetic mean in the evaluation of \({{{{{\mathcal{L}}}}}}({{{{{\boldsymbol{\Theta }}}}}})\), we assume that the node inputs are equally important and uniformly distributed. Consequently, the model should be trained equally well on all of them. This consideration is critical because, in practice, we only have access to a finite number of inputs in D and G, for which the node input distribution, denoted \(\rho(k_i, x_i, \Phi_i, x_{{{{\mathcal{N}}}}_i}, \Phi_{{{{\mathcal{N}}}}_i}, \Omega_{i{{{\mathcal{N}}}}_i})\), is typically far from uniform. Hence, in order to train effective models, we recalibrate the inputs using the following global loss

$${{{{{\mathcal{L}}}}}}({{{{{\boldsymbol{\Theta }}}}}})=\mathop{\sum}\limits_{t\in {{{{{\mathcal{T}}}}}}^{\prime} }\ \mathop{\sum}\limits_{{v}_{i}\in {{{{{\mathcal{V}}}}}}^{\prime} (t)}\frac{{w}_{i}(t)}{Z^{\prime} }L\left({y}_{i}(t),{\hat{y}}_{i}(t)\right)$$
(6)

where wi(t) is a weight assigned to node vi at time t, and \(Z^{\prime} ={\sum }_{t\in {{{{{\mathcal{T}}}}}}^{\prime} }{\sum }_{{v}_{i}\in {{{{{\mathcal{V}}}}}}^{\prime} (t)}{w}_{i}(t)\) is a normalization factor. Here, the training node set \({{{{{\mathcal{V}}}}}}^{\prime} (t)\subseteq {{{{{\mathcal{V}}}}}}\) and the training time set \({{{{{\mathcal{T}}}}}}^{\prime} \subseteq [1,T]\) allow us to partition the training dataset for validation and testing when required.

The choice of weights needs to reflect the importance of each node at each time. Because we wish to lower the influence of overrepresented inputs and increase that of rare inputs, a sound choice of weights is

$${w}_{i}(t)\propto \rho {\left({k}_{i},{x}_{i},{{{\Phi }}}_{i},{x}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Phi }}}_{{{{{{{\mathcal{N}}}}}}}_{i}},{{{\Omega }}}_{i{{{{{{\mathcal{N}}}}}}}_{i}}\right)}^{-\lambda }$$
(7)

where ki is the degree of node vi in G, and 0 ≤ λ ≤ 1 is a hyperparameter. Equation (7) is an ideal choice because it corresponds to a principled importance sampling approximation of Eq. (6)42, which is relaxed via the exponent λ. We obtain a pure importance sampling scheme when λ = 1. Note that the weights can rarely be computed exactly using Eq. (7), because the distribution ρ is typically computationally intensive to obtain from data, especially for continuous \({{{{{\mathcal{S}}}}}}\) with metadata. We illustrate various ways to evaluate the weights in Sec. “Importance weights” and Sec. I of the Supplementary Information.
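Equations (6) and (7) can be sketched as follows for a discrete summary of each node input. Here an input is reduced to the hashable pair (ki, xi), which is an illustrative simplification: the true ρ also involves the attributes and neighbor states.

```python
import numpy as np
from collections import Counter

def importance_weights(inputs, lam=0.5):
    """Eq. (7) for discrete input summaries: w_i ∝ rho(input_i)^(-lam),
    with rho estimated by the empirical frequency of each summary."""
    counts = Counter(inputs)
    n = len(inputs)
    return np.array([(counts[s] / n) ** (-lam) for s in inputs])

def weighted_loss(y, y_hat, w):
    """Eq. (6) with a squared-error local loss L and normalization Z'."""
    w = np.asarray(w, dtype=float)
    L = (np.asarray(y, dtype=float) - np.asarray(y_hat, dtype=float)) ** 2
    return float((w * L).sum() / w.sum())

# three copies of an overrepresented input (k=2, x=0) and one rare input (k=5, x=1)
inputs = [(2, 0), (2, 0), (2, 0), (5, 1)]
w = importance_weights(inputs, lam=1.0)   # pure importance sampling
```

With λ = 1 the rare input receives three times the weight of each overrepresented one, which is exactly what recalibrates the arithmetic mean toward a uniform input distribution.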

We now illustrate the accuracy of our approach by applying it to four types of synthetic dynamics of various natures (see Sec. “Dynamics” for details). We first consider a simple contagion dynamics: the discrete-time susceptible-infected-susceptible (SIS) dynamics. In this dynamics, nodes are either susceptible (S) or infected (I) by some disease, i.e. \({{{{{\mathcal{S}}}}}}=\left\{S,I\right\}=\left\{0,1\right\}\), and transition between these states stochastically according to an infection probability function α(ℓ), where ℓ is the number of infected neighbors of a node, and a constant recovery probability β. A notable feature of simple contagion dynamics is that susceptible nodes get infected by the disease through their infected neighbors independently. This reflects the assumption that disease transmission behaves identically whether a person has a large number of infected neighbors or not.
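As a concrete sketch, the outcome vector yi of the SIS dynamics can be written down directly, assuming independent per-edge transmissions so that α(ℓ) = 1 − (1 − γ)^ℓ, where ℓ is the number of infected neighbors; the values of γ and β below are illustrative, not those used in the experiments.

```python
import numpy as np

def sis_outcome(x_i, n_inf, gamma=0.04, beta=0.08):
    """Transition probability vector y_i = (P(S), P(I)) for discrete-time
    SIS. gamma: per-edge infection probability; beta: recovery probability.
    Independent transmissions give alpha(l) = 1 - (1 - gamma)^l."""
    if x_i == 0:                                  # susceptible node
        alpha = 1.0 - (1.0 - gamma) ** n_inf
        return np.array([1.0 - alpha, alpha])
    else:                                         # infected node
        return np.array([beta, 1.0 - beta])
```

Sampling the next state from this vector, as in Eq. (2), reproduces one stochastic step of the dynamics.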

Second, we relax this assumption by considering a complex contagion dynamics with a nonmonotonic infection function α(ℓ), where the aforementioned transmission events are no longer independent14. This contagion dynamics has an interesting interpretation in the context of the propagation of a social behavior, where the local popularity of a behavior (large ℓ) hinders its adoption. The independent transmission assumption can also be lifted when multiple diseases interact10. Thus, we also consider an asymmetric interacting contagion dynamics with two diseases. In this case, \({{{{{\mathcal{S}}}}}}=\left\{{S}_{1}{S}_{2},{I}_{1}{S}_{2},{S}_{1}{I}_{2},{I}_{1}{I}_{2}\right\}=\left\{0,1,2,3\right\}\), where U1V2 corresponds to a state where a node is in state U with respect to the first disease and in state V with respect to the second. The interaction between the diseases happens via a coupling that is active only when a node is infected by at least one disease; otherwise, each disease behaves identically to the simple contagion dynamics. This coupling may increase or decrease the virulence of the other disease.

Whereas the previously presented dynamics capture various features of contagion phenomena, real datasets containing this level of detail about the interactions among individuals are rare43,44,45. A class of dynamics for which datasets are easier to find is that of mass-action metapopulation dynamics46,47,48,49, where the statuses of individuals are aggregated by geographical region. These dynamics typically evolve on weighted networks encoding the mobility of individuals between regions, and the state of a region consists of the number of people in each health state. As a fourth case study, we consider a type of deterministic metapopulation dynamics where the population size is constant and where people can be susceptible (S), infected (I) or recovered from the disease (R). As a result, we define the state of a node as a three-dimensional vector specifying the fraction of people in each state—i.e. \({{{{{\mathcal{S}}}}}}={{{{{\mathcal{R}}}}}}={[0,1]}^{3}\).
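A minimal deterministic SIR metapopulation step might look as follows. This is a generic mass-action sketch under simple assumptions (row-stochastic mobility coupling, illustrative rates), not the specific model defined in Sec. “Dynamics”.

```python
import numpy as np

def metapop_sir_step(state, flow, infect=0.3, recover=0.1):
    """One step of a toy deterministic SIR metapopulation dynamics.
    state: (N, 3) array of (S, I, R) fractions per region.
    flow:  (N, N) row-stochastic mobility matrix; flow[i, j] is the
           fraction of region i's contacts occurring with region j."""
    s, i, r = state[:, 0], state[:, 1], state[:, 2]
    exposure = flow @ i                  # effective infected contact per region
    new_inf = infect * s * exposure      # mass-action infections
    new_rec = recover * i                # recoveries
    return np.stack([s - new_inf, i + new_inf - new_rec, r + new_rec], axis=1)

# two coupled regions: an outbreak in region 0 spills over into region 1
state = np.array([[0.99, 0.01, 0.0],
                  [1.00, 0.00, 0.0]])
flow = np.array([[0.9, 0.1],
                 [0.1, 0.9]])
out = metapop_sir_step(state, flow)
```

By construction each row of the state still sums to one after a step, reflecting the constant population size.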

Figure 1 shows the GNN predictions for the infection and recovery probabilities of the simple and complex contagion dynamics as a function of the number of infected neighbors ℓ. We then compare these probabilities with their ground truths, i.e. Eq. (20) using Eqs. (18)–(22) for the infection functions. We also show the maximum likelihood estimators (MLE) of the transition probabilities, computed from the fraction of nodes in state x with ℓ infected neighbors that transitioned to state y in the complete dataset D. The MLE, which are typically used in this kind of inference problem50, stand as a reference to benchmark the performance of our approach.
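The MLE described here is simply a normalized count over disjoint subsets of D; a sketch, assuming each observation is reduced to a triple (x, ℓ, y):

```python
import numpy as np
from collections import defaultdict

def mle_transition(samples, n_states=2):
    """MLE of P(y | x, l): the fraction of observed transitions to each
    state y among nodes seen in state x with l infected neighbors.
    samples: iterable of (x, l, y) triples taken from the dataset D."""
    counts = defaultdict(lambda: np.zeros(n_states))
    for x, l, y in samples:
        counts[(x, l)][y] += 1.0
    return {key: c / c.sum() for key, c in counts.items()}

# four observations of susceptible nodes (x=0) with one infected neighbor,
# one of which got infected, plus one recovery event
samples = [(0, 1, 1), (0, 1, 0), (0, 1, 0), (0, 1, 0), (1, 1, 0)]
p = mle_transition(samples)
```

Because every pair (x, ℓ) is estimated from its own disjoint subset of samples, rare pairs (e.g. high-degree configurations) get noisy or undefined estimates, which is the weakness discussed below.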

Fig. 1: Predictions of GNN trained on a Barabási-Albert random network 72 (BA).
figure 1

a Transition probabilities of the simple contagion dynamics. b Transition probabilities of the complex contagion dynamics. The solid and dashed lines correspond to the transition probabilities of the dynamics used to generate the training data (labeled GT for "ground truth") and predicted by the GNN, respectively. Symbols correspond to the maximum likelihood estimation (MLE) of the transition probabilities computed from the dataset D. The colors indicate the type of transition: infection (S → I) in blue and recovery (I → S) in red. The standard deviations, resulting from averaging the outcomes given ℓ, are shown using a colored area around the lines (typically narrower than the width of the lines) and using vertical bars for the symbols.

We find that the GNN learns the transition probabilities of the simple and complex contagion dynamics remarkably well. In fact, the predictions of the GNN are systematically smoother than the ones provided by the MLE. This is because the MLE is computed for each individual pair (x, ℓ) from disjoint subsets of the training dataset. This implies that a large number of samples of each pair (x, ℓ) is needed for the MLE to be accurate; a condition rarely met in realistic settings, especially for high degree nodes. This also means that the MLE cannot be used directly to interpolate beyond the pairs (x, ℓ) present in the training dataset, in sharp contrast with the GNN which, by definition, can interpolate within the dataset D. Furthermore, all of its parameters are hierarchically involved during training, meaning that the GNN benefits from any sample to improve all of its predictions, which are then smoother and more consistent. Further still, we found that not all GNN architectures can reproduce the level of accuracy obtained in Fig. 1 (see Sec. III F of the Supplementary Information). In fact, we showed that many standard GNN aggregation mechanisms are ineffective at learning the simple and complex contagion dynamics, most likely because they were specifically designed with structure learning in mind rather than dynamics learning.

It is worth mentioning that the GNN is neither specifically designed nor trained to compute the transition probabilities as a function of a single variable, namely the number of infected neighbors ℓ. In reality, the GNN computes its outcome from the complete multivariate state of the neighbors of a node. The interacting contagion and the metapopulation dynamics, unlike the simple and complex contagions, are examples of such multivariate cases. Their outcome is thus harder to visualize in a representation similar to Fig. 1. Figure 2a–h addresses this issue by comparing each GNN prediction \({\hat{y}}_{i}(t)\) with its corresponding target yi(t) in the dataset D. We quantify the global performance of the models in different scenarios, for the different dynamics and underlying network structures, using the Pearson correlation coefficient r between the predictions and targets (see Sec. “Graph neural network and training details”). We also compute the error, defined from the Pearson coefficient as 1 − r, for each degree class k (i.e. between the predictions and targets of only the nodes of degree k). This allows us to quantify the GNN performance for every local structure.
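The per-degree error 1 − r described above can be sketched directly: group the prediction-target pairs by node degree and compute the Pearson coefficient within each group.

```python
import numpy as np

def per_degree_error(y, y_hat, degrees):
    """Error 1 - r per degree class k: the Pearson correlation r is
    computed over only the prediction-target pairs of nodes of degree k.
    Classes with fewer than two samples are skipped (r is undefined)."""
    y, y_hat, degrees = map(np.asarray, (y, y_hat, degrees))
    errors = {}
    for k in np.unique(degrees):
        mask = degrees == k
        if mask.sum() > 1:
            r = np.corrcoef(y[mask], y_hat[mask])[0, 1]
            errors[int(k)] = 1.0 - r
    return errors

# perfect predictions give zero error in every degree class
errors = per_degree_error(y=[0.1, 0.2, 0.3, 0.5],
                          y_hat=[0.1, 0.2, 0.3, 0.5],
                          degrees=[2, 2, 3, 3])
```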

Fig. 2: GNN learning performance on different dynamics and structures.
figure 2

a, b, c, d GNN trained on Erdős–Rényi networks (ER). e, f, g, h GNN trained on Barabási-Albert networks (BA). i, j, k, l Error as a function of the number of neighbors. Each point shown on panels a–h corresponds to a different pair \(({y}_{i}(t),{\hat{y}}_{i}(t))\) in the complete dataset D. We also indicate the Pearson coefficient r on each panel to measure the correlation between the predictions and the targets, and use it as a global performance measure. Panels i–l show the errors (1 − r) as a function of the number of neighbors for GNN trained on ER and BA networks, and those of the corresponding MLE. These errors are obtained from the Pearson coefficients computed from subsets of the prediction-target pairs where all nodes have degree k.

Figure 2i–l confirms that the GNN provides more accurate predictions than the MLE in general and across all degrees. This is especially true in the case of the interacting contagion, where the accuracy of the MLE deteriorates rapidly for large degree nodes. This is a consequence of how scarce the inputs of this dynamics are compared to those of the simple and complex contagion dynamics for training datasets of the same size, and of how fast the size of the set of possible inputs scales, quickly rendering the MLE completely ineffective for small training datasets. The GNN, on the other hand, is less affected by the scarcity of the data, since any sample improves its global performance, as discussed above.

Figure 2 also exposes the crucial role played in the global performance of the GNN by the network G on which the dynamics evolves. Namely, the heterogeneous degree distributions of Barabási-Albert networks (BA)—or any heterogeneous degree distribution—offer a wider range of degrees than those of homogeneous Erdős-Rényi networks (ER). We can take advantage of this heterogeneity to train GNN models that generalize well across a larger range of local structures, as seen in Fig. 2i–l (see also Sec. III of the Supplementary Information). However, the predictions on BA networks are not systematically better for low degrees than those on ER networks, as seen in the interacting and metapopulation cases. These results nonetheless suggest a wide applicability of our approach to real complex systems, whose underlying network structures recurrently exhibit a heterogeneous degree distribution51.

We now test the trained GNN on unseen network structures by recovering the bifurcation diagrams of the four dynamics. In the infinite-size limit \(| {{{{{\mathcal{V}}}}}}| \to \infty\), these dynamics have two possible long-term outcomes: the absorbing state, where the diseases quickly die out, and the endemic/epidemic state, in which a macroscopic fraction of nodes remains (endemic) or has been infected over time (epidemic)4,10,52. These possible long-term outcomes exchange stability during a phase transition, which is continuous for the simple contagion and metapopulation dynamics, and discontinuous for the complex and interacting contagion dynamics. The position of the phase transition depends on the parameters of the dynamics as well as on the topology of the network. Note that for the interacting contagion dynamics, the absorbing and endemic states do not change stability at the same point, giving rise to a bistable regime in which both states are stable.

Figure 3 shows the different bifurcation diagrams obtained by performing numerical simulations with the trained GNN models [using Eq. (5)] while varying the average degree of networks on which the GNN had not been trained. Quantitatively, the predictions are again strikingly accurate—essentially perfect for the simple and complex contagion dynamics—which is remarkable given that the bifurcation diagrams were obtained on networks the GNN had never seen before. These results illustrate how insights can be gained about the underlying process concerning, among other things, the existence of phase transitions and their order. They also suggest how the GNN can be used for diverse applications, such as predicting the dynamics under various network structures (e.g. designing intervention strategies that affect the way individuals interact and are therefore connected).
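The sweep behind such a diagram can be sketched for the simple (SIS) contagion: simulate the dynamics on Poisson networks of varying average degree and record the late-time prevalence. This minimal NumPy version uses the ground-truth rule rather than a trained GNN, with illustrative parameters and network sizes much smaller than in the paper.

```python
import numpy as np

def simulate_prevalence(n, avg_k, gamma=0.04, beta=0.08, t_max=200, seed=0):
    """One point of a toy SIS bifurcation diagram: simulate on a G(n, p)
    network with average degree <k> = avg_k and return the late-time
    fraction of infected nodes (illustrative parameters)."""
    rng = np.random.default_rng(seed)
    p = avg_k / (n - 1)
    upper = np.triu(rng.random((n, n)) < p, 1)
    adj = (upper | upper.T).astype(int)        # undirected Poisson network
    x = (rng.random(n) < 0.5).astype(int)      # start with ~half infected
    for _ in range(t_max):
        n_inf = adj @ x                        # infected neighbors per node
        alpha = 1.0 - (1.0 - gamma) ** n_inf   # infection probability alpha(l)
        u = rng.random(n)
        x = np.where(x == 1,
                     (u >= beta).astype(int),  # infected: recover w.p. beta
                     (u < alpha).astype(int))  # susceptible: infect w.p. alpha(l)
    return x.mean()

# sweeping <k> traces the diagram: sparse networks fall into the absorbing
# state, dense networks sustain the endemic state
low = simulate_prevalence(200, avg_k=1.0)
high = simulate_prevalence(200, avg_k=12.0)
```

Replacing the ground-truth update inside the loop by the trained GNN's outcome [Eq. (5)] yields the orange curves of Fig. 3.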

Fig. 3: Bifurcation diagrams of the trained GNN.
figure 3

a Simple contagion dynamics. b Complex contagion dynamics. c Interacting contagion dynamics. d Metapopulation dynamics. In these experiments, we used Poisson networks composed of \(|{{{\mathcal{V}}}}| = 2000\) nodes with different average degrees \(\langle k \rangle\). The prevalence is defined as the average fraction of nodes that are asymptotically infected by at least one disease, and the outbreak size corresponds to the average fraction of nodes that have recovered. These quantities are obtained from numerical simulations using the "ground truth" (GT) dynamics (blue circles) and the GNN trained on Barabási-Albert networks (orange triangles). The error bars correspond to the standard deviations of these numerical simulations. The trained GNN are the same as those used in Fig. 2. As a reference, we also indicate with dashed lines the value(s) of the average degree 〈k〉 of the network(s) on which the GNN were trained. On d, more than one value of 〈k〉 appears because multiple networks with different average degrees were used to train the GNN (see Sec. "Graph neural network and training details").

Finally, we illustrate the applicability of our approach by training our GNN model on the evolution of COVID-19 in Spain between January 1st 2020 and March 27th 2021 (see Fig. 4). This dataset consists of the daily number of new cases (i.e. the incidence) for each of the 50 provinces of Spain as well as Ceuta and Melilla53. We also use a network of the mobility flow recorded in 201854 as a proxy for the interaction network between these 52 regions. This network is multiplex (each layer corresponds to a different mode of transportation), directed and weighted (average daily mobility flow).

Fig. 4: Spain COVID-19 dataset.
figure 4

a Spain mobility multiplex network54. The thickness of the edges is proportional to the average number of people transitioning between all connected node pairs. The size of the nodes is proportional to the population Ni living in the province. b Time series of the incidence for the 52 provinces of Spain between January 2020 and March 202153. Each province is identified by its corresponding ISO code. Each incidence time series has been rescaled by its maximum value for the purpose of visualization. The shaded area indicates the training and validation datasets (in-sample) from January 1st 2020 to December 1st 2020. The remaining dataset is used for testing.

We compare the performance of our approach with that of different baselines: four data-driven techniques—three competing neural network architectures37,55 and a linear vector autoregressive model (VAR)56,57—and an equivalent mechanistic metapopulation model (Metapop.) driven by a simple contagion mechanism58. Among the three neural network architectures, we used the model of ref. 37 (KP-GNN), which has been used to predict the evolution of COVID-19 in the US. By way of an ablation study, the other two GNN architectures embody the assumptions that the nodes of the network are mutually independent (IND), or that the nodes are arbitrarily codependent (FC) in a way that is learned by the neural network. Note that there exists a wide variety of GNN architectures designed to handle dynamics of networks38—networks whose topology evolves over time—but these architectures are not typically adapted to learning dynamics on networks (see Sec. III F of the Supplementary Information). Finally, we used the parameters of ref. 59 for the metapopulation model. Section “COVID-19 outbreak in Spain” provides more details on the baselines.
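In its simplest first-order form, the VAR baseline amounts to fitting a linear map x(t+1) ≈ A x(t) by least squares. The sketch below illustrates this idea on synthetic data; it is not the exact specification of refs. 56,57.

```python
import numpy as np

def fit_var1(X):
    """Least-squares fit of a first-order VAR model x(t+1) = A x(t).
    X: (T, N) array of node states over time. Returns the (N, N) matrix A."""
    past, future = X[:-1], X[1:]
    B, *_ = np.linalg.lstsq(past, future, rcond=None)  # past @ B ≈ future
    return B.T                                         # so x(t+1) = A @ x(t)

# synthetic two-node linear system with small process noise
rng = np.random.default_rng(1)
A_true = np.array([[0.9, 0.05],
                   [0.1, 0.80]])
X = [rng.normal(size=2)]
for _ in range(500):
    X.append(A_true @ X[-1] + 0.01 * rng.normal(size=2))
A_hat = fit_var1(np.array(X))
```

Being linear and global, such a model has no mechanism to represent the local, nonlinear update rule of Eq. (3), which is consistent with its tendency to overfit out-of-sample reported below.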

Figure 5 shows that all data-driven models can generate highly accurate in-sample predictions, with the exception of the KP-GNN model, which struggles to learn the dynamics, possibly because of its aggregation mechanism (see Sec. III F of the Supplementary Information). This further substantiates the idea that many GNN architectures designed for structure learning, like the graph convolutional network60 at the core of the KP-GNN model, are suboptimal for dynamics learning problems. However, the other architectures do not all generalize the dynamics equally well out-of-sample: the FC and the VAR models, especially, tend to overfit more than the GNN and the IND models. While this was expected for the linear VAR model, the FC model overfits because it is granted too much freedom in the way it learns how the nodes interact with one another. Interestingly, the IND and the GNN models perform similarly, which hints at the possibility that the specifics of the mobility network might not have contributed significantly to the dynamics. This is perhaps not surprising, since social distancing and confinement measures were in place during the period covered by the dataset. Indeed, our results indicate that the global effective dynamics was mostly driven by internal processes within each individual province rather than by the interaction between them. This last observation suggests that our GNN model is robust to spurious connections in the interaction network.

Fig. 5: Learning the Spain COVID-19 dataset.
figure 5

a, b Comparison between the targets and the predictions in the in-sample and the out-of-sample datasets for our GNN model (blue) and for other models (KP-GNN in orange, IND in pink, FC in purple and VAR in green; see main text). The accuracy of the predictions is quantified by the Pearson correlation coefficient provided in the legend. c Forecasts by our GNN model for individual time series of the provincial daily incidence compared with the ground truth. Underestimation and overestimation are respectively shown in blue and red. Each time series has been rescaled as in Fig. 4b and are ordered according to mean square error of the GNN’s predictions. d Forecasts for the global incidence (sum of the daily incidence in every province). The solid gray line indicates the ground truth (GT); the dashed blue line, the dashed orange line and dotted green line show the forecast of our GNN model, of KP-GNN and of VAR, respectively. We also show the forecast of an equivalent metapopulation model (red dash-dotted line) which has its own scale (red axis on the right) to improve the visualization; the other lines share the same axis on the left. Similarly to Fig. 4, we differentiate the in-sample from the out-of-sample forecasts using a shaded background.

Finally, Fig. 5d shows that the metapopulation model systematically overestimates the incidence by over an order of magnitude. Again, this is likely due to the confinement measures in place during that period, which were not reflected in the original parameters of the dynamics59. Additional mechanisms accounting for this interplay between social restrictions and the prevalence of the disease—e.g. complex contagion mechanisms12 or time-dependent parameters61—would therefore be needed to extend the validity of the metapopulation model to the full length of the dataset. Interestingly, a signature of this interplay is encoded in the daily incidence data, and our GNN model appears to be able to capture it to some extent.

Discussion

We introduced a data-driven approach that learns effective mechanisms governing the propagation of diverse dynamics on complex networks. We proposed a reliable training protocol, and we validated the projections of our GNN architecture on simple, complex and interacting contagion dynamics, as well as on metapopulation dynamics, using synthetic networks. Interestingly, we found that many standard GNN architectures do not correctly handle the problem of learning contagion dynamics from time series. Also, we found that our approach performs better when trained on data whose underlying network structure is heterogeneous, which could prove useful in real-world applications of our method given the ubiquity of scale-free networks62.

By recovering the bifurcation diagram of various dynamics, we illustrated how our approach can leverage time series from an unknown dynamical process to gain insights about its properties—e.g. the existence of a phase transition and its order. We have also shown how to use this framework on real datasets, which could then be used to help build better effective models. In a way, we see this approach as the equivalent of a numerical Petri dish—offering a new way to experiment and gain insights about an unknown dynamics—that complements traditional mechanistic modeling in designing better intervention procedures and containment countermeasures, and in performing model selection.

Although we focused the presentation of our method on contagion dynamics, its potential applicability reaches many other realms of complex systems modeling where intricate mechanisms are at play. We believe this work establishes solid foundations for the use of deep learning in the design of realistic effective models of complex systems.

Gathering detailed epidemiological datasets is a complex and labor-intensive process, meaning that datasets suitable for our approach are currently the exception rather than the norm. The current COVID-19 pandemic has, however, shown how an adequate international reaction to an emerging infectious pathogen critically depends on the free flow of information. New initiatives like Global health63 are good examples of how the international epidemiological community is coming together to share data more openly and to make comprehensive datasets available to all researchers. Thanks to such initiatives, it is likely that future pandemics will see larger amounts of data made available to the scientific community in real time. It is therefore crucial for the community to start developing tools, such as the one presented here, to leverage these datasets so that we are ready for the next pandemic.

Methods

Graph neural network and training details

In this section, we briefly present our GNN architecture, the training settings, the synthetic data generation procedure and the hyperparameters used in our experiments.

Architecture

We use the GNN architecture shown in Fig. 6 and detailed in Table 1. First, we transform the state xi of every node with a shared multilayer perceptron (MLP), denoted \(\hat{f}_{\mathrm{in}}:\mathcal{S}\to\mathbb{R}^{d}\), where d is the resulting number of node features, such that

$$\xi_i = \hat{f}_{\mathrm{in}}(x_i).$$
(8)

We concatenate the node attributes Φi to xi, when these attributes are available, in which case \(\hat{f}_{\mathrm{in}}:\mathcal{S}\times\mathbb{R}^{Q}\to\mathbb{R}^{d}\). At this point, ξi is a vector of features representing the state (and attributes) of node vi. Then, we aggregate the features of the first neighbors using a modified attention mechanism \(\hat{f}_{\mathrm{att}}\), inspired by ref. 41 (see Section “Attention mechanism”),

$$\nu_i = \hat{f}_{\mathrm{att}}(\xi_i, \xi_{\mathcal{N}_i}),$$
(9)

where we recall that \(\mathcal{N}_i = \{v_j \,|\, e_{ij} \in \mathcal{E}\}\) is the set of nodes connected to node vi. We also include the edge attributes Ωij in the attention mechanism, when they are available. To do so, we transform the edge attributes Ωij into abstract edge features, such that \(\psi_{ij} = \hat{f}_{\mathrm{edge}}(\Omega_{ij})\), where \(\hat{f}_{\mathrm{edge}}:\mathbb{R}^{P}\to\mathbb{R}^{d_{\mathrm{edge}}}\) is also an MLP, before they are used in the aggregation. Finally, we compute the outcome \(\hat{y}_i\) of each node vi with another MLP \(\hat{f}_{\mathrm{out}}:\mathbb{R}^{d}\to\mathcal{R}\) such that

$$\hat{y}_i = \hat{f}_{\mathrm{out}}(\nu_i)\ .$$
(10)
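Putting Eqs. (8)–(10) together, the forward pass can be sketched as follows. This is a minimal NumPy illustration rather than the trained implementation: the weights are random, the MLPs are stand-ins, and a uniform placeholder value replaces the learned attention coefficients described in Section “Attention mechanism”; all names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(sizes):
    """Random-weight MLP with ReLU hidden layers (stand-in for a trained network)."""
    Ws = [rng.normal(0, 0.1, (m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
    def f(x):
        for W in Ws[:-1]:
            x = np.maximum(x @ W, 0)
        return x @ Ws[-1]
    return f

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def gnn_forward(x, adj, f_in, f_out):
    xi = f_in(x)                      # Eq. (8): shared node-wise transformation
    a = 0.5 * adj                     # placeholder attention coefficients on edges
    nu = xi + a @ xi                  # Eq. (9), as a weighted sum over neighbors
    return softmax(f_out(nu))         # Eq. (10): per-node outcome probability vectors

N, S, d = 6, 2, 8                     # nodes, states, node features
adj = (rng.random((N, N)) < 0.4).astype(float)
np.fill_diagonal(adj, 0)
x = np.eye(S)[rng.integers(0, S, N)]  # one-hot node states
y_hat = gnn_forward(x, adj, mlp([S, d, d]), mlp([d, S]))
```

Each row of `y_hat` is a probability vector over the possible next-step states of the corresponding node.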
Fig. 6: Visualization of the GNN architecture.
figure 6

The blocks of different colors represent mathematical operations. The red blocks correspond to trainable affine transformations parametrized by weights and biases. The purple blocks represent activation functions between each layer. The core of the model is the attention module41, which is represented in blue. The orange block at the end is an activation function that transforms the output into the proper outcomes.

Table 1 Layer-by-layer description of the GNN models for each dynamics.

Attention mechanism

We use an attention mechanism inspired by the graph attention network architecture (GAT)41. The attention mechanism consists of three trainable functions \(\mathcal{A}:\mathbb{R}^{d}\to\mathbb{R}\), \(\mathcal{B}:\mathbb{R}^{d}\to\mathbb{R}\) and \(\mathcal{C}:\mathbb{R}^{d_{\mathrm{edge}}}\to\mathbb{R}\), which combine the feature vectors ξi, ξj and ψij of a connected pair of nodes vi and vj, where we recall that d and dedge are the numbers of node and edge features, respectively. Then, the attention coefficient aij is computed as follows

$$a_{ij} = \sigma\left[\mathcal{A}(\xi_i) + \mathcal{B}(\xi_j) + \mathcal{C}(\psi_{ij})\right]$$
(11)

where \(\sigma(x) = [1 + e^{-x}]^{-1}\) is the logistic function. Notice that, by using this logistic function, the value of the attention coefficients is constrained to the open interval (0, 1), where aij = 0 implies that the feature ξj does not change the value of νi, and aij = 1 implies that it maximally changes the value of νi. In principle, aij quantifies the influence of the state of node vj over the outcome of node vi. In practice, however, the representation learned by the GNN can be non-sparse, meaning that the neighbor features \(\xi_{\mathcal{N}_i}\) can be combined in such a way that their noncontributing parts cancel out without aij necessarily being zero. This can cause this interpretation of the attention coefficients to fail (see the Supplementary Information for further details). Nevertheless, the attention coefficients can be used to assess how connected nodes interact together.

We compute the aggregated feature vector \(\nu_i\) of node \(v_i\) as

$$\nu_i = \hat{f}_{\mathrm{att}}(\xi_i, \xi_{\mathcal{N}_i}) = \xi_i + \sum_{v_j \in \mathcal{N}_i} a_{ij}\, \xi_j.$$
(12)

It is important to stress that, at this point, \(\nu_i\) contains some information about \(v_i\) and all of its neighbors in a pairwise manner. In all our experiments, we fix \(\mathcal{A}\), \(\mathcal{B}\) and \(\mathcal{C}\) to be affine transformations with trainable weight matrices and bias vectors. Also, we use multiple attention modules in parallel to increase the expressive power of the GNN architecture, as suggested by ref. 41.
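In code, Eqs. (11) and (12) reduce to a few lines. The sketch below assumes NumPy and, for brevity, takes \(\mathcal{A}\), \(\mathcal{B}\) and \(\mathcal{C}\) to be linear maps given as plain weight vectors (the model uses trainable affine maps, and multiple such modules in parallel); the variable names are illustrative.

```python
import numpy as np

def attention_aggregate(xi, psi, edges, A, B, C):
    """Eq. (11): a_ij = sigma(A(xi_i) + B(xi_j) + C(psi_ij)), then
    Eq. (12): nu_i = xi_i + sum_{v_j in N_i} a_ij * xi_j."""
    nu = xi.astype(float).copy()
    for (i, j) in edges:                 # edge (i, j): v_j is a neighbor of v_i
        s = A @ xi[i] + B @ xi[j] + C @ psi[(i, j)]
        a_ij = 1.0 / (1.0 + np.exp(-s))  # logistic function: a_ij in (0, 1)
        nu[i] += a_ij * xi[j]
    return nu

# With all-zero weights, every a_ij = sigma(0) = 1/2, so each node adds half of
# every neighbor's feature vector to its own.
d, d_edge = 3, 2
xi = np.arange(6.0).reshape(2, d)        # features of two connected nodes
psi = {(0, 1): np.ones(d_edge), (1, 0): np.ones(d_edge)}
nu = attention_aggregate(xi, psi, [(0, 1), (1, 0)],
                         np.zeros(d), np.zeros(d), np.zeros(d_edge))
```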

The attention mechanism described by Eq. (11) is slightly different from the vanilla version of ref. 41. Similarly to other well-known GNN architectures33,60,64, the aggregation scheme of the vanilla GAT is designed as an average of the feature vectors of the neighbors—where, by definition, \(\sum_{v_j \in \mathcal{N}_i} a_{ij} = 1\) for all \(v_i\)—rather than as a general weighted sum like Eq. (12). This is often reasonable in the context of structure learning, where the node features represent coordinates in a metric space in which connected nodes are likely to be close33. Yet, in the general case, this type of constraint was shown to dramatically lessen the expressive power of the GNN architecture31. We reached the same conclusion while using average-like GNN architectures (see Sec. III F of the Supplementary Information). By contrast, the aggregation scheme described by Eq. (12) allows our architecture to represent various dynamical processes on networks accurately.

Training settings

In all experiments on synthetic data, we use the cross entropy loss as the local loss function,

$$L\left(y_i, \hat{y}_i\right) = -\sum_{m} y_{i,m} \log \hat{y}_{i,m},$$
(13)

where yi,m corresponds to the m-th element of the outcome vector of node vi, which is either a transition probability for the stochastic contagion dynamics or a fraction of people for the metapopulation dynamics. For the simple, complex and interacting contagion dynamics, we used the observed outcomes as targets in the loss function, i.e. we substituted \(y_i \to \tilde{y}_i\) in Eq. (13), where \(\tilde{y}_i\) corresponds to the stochastic state of node vi at the next time step. While we noticed a diminished performance when using the observed outcomes as opposed to the true transition probabilities (see Sec. III E of the Supplementary Information), this setting is more realistic and shows what happens when the targets are noisy. The effect of noise can be tempered by increasing the size of the dataset (see the Supplementary Information). For the metapopulation dynamics, since this model is deterministic, we used the true targets without adding noise.
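As an illustration, the local cross-entropy loss of Eq. (13) can be written as follows (a minimal sketch assuming NumPy; the small `eps` guard against log 0 is our addition):

```python
import numpy as np

def local_loss(y, y_hat, eps=1e-12):
    """Eq. (13): cross entropy between a target outcome vector y (a one-hot
    observed state, or a true probability vector) and the prediction y_hat."""
    y, y_hat = np.asarray(y, float), np.asarray(y_hat, float)
    return -np.sum(y * np.log(y_hat + eps))
```

With a one-hot observed outcome \(\tilde{y}_i\) as the target, the loss reduces to the negative log-probability that the model assigns to the state that actually occurred.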

Performance measures

We use the Pearson correlation coefficient r as a global performance measure defined on a set of targets Y and predictions \(\hat{Y}\) as

$$r = \frac{\mathbb{E}\left[(Y - \mathbb{E}[Y])(\hat{Y} - \mathbb{E}[\hat{Y}])\right]}{\sqrt{\mathbb{E}\left[(Y - \mathbb{E}[Y])^{2}\right]\ \mathbb{E}\left[(\hat{Y} - \mathbb{E}[\hat{Y}])^{2}\right]}}$$
(14)

where \(\mathbb{E}[W]\) denotes the expectation of W. Because the maximum correlation occurs at r = 1, we also define 1 − r as the global error on the set of target-prediction pairs.
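For instance, the global error can be computed as follows (a minimal NumPy sketch; the targets and predictions are flattened so that the correlation is taken over all pooled pairs):

```python
import numpy as np

def global_error(Y, Y_hat):
    """1 - r, with r the Pearson coefficient of Eq. (14)."""
    y, z = np.ravel(Y), np.ravel(Y_hat)
    r = np.mean((y - y.mean()) * (z - z.mean())) / (y.std() * z.std())
    return 1.0 - r

# Any increasing affine relation gives r = 1, i.e. zero global error.
err = global_error(np.arange(10.0), 2 * np.arange(10.0) + 1)
```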

Synthetic data generation

We generate data from each dynamics using the following algorithm:

  1. Sample a network G from a given generative model (e.g. the Erdős-Rényi G(N, M) or the Barabási-Albert network models).

  2. Initialize the state of the system \(X(0) = (x_i(0))_{i=1..N}\). For the simple, complex and interacting contagion dynamics, sample uniformly the number of nodes in each state. For the metapopulation dynamics, sample the population size of each node from a Poisson distribution of average \(10^4\), and then sample the number of infected people within each node from a binomial distribution with parameter \(10^{-5}\). For instance, a network of \(|\mathcal{V}| = 10^3\) nodes will be initialized with a total of 100 infected people, on average, distributed among the nodes.

  3. At time t, compute the observed outcome—Yt for the metapopulation dynamics, and \(\tilde{Y}_t\) for the three stochastic dynamics. Then, record the states Xt and Yt (or \(\tilde{Y}_t\)).

  4. Repeat step 3 until \((t \bmod t_s) = 0\), where ts is a resampling time. At this moment, apply step 2 to reinitialize the states Xt and repeat step 3.

  5. Stop when t = T, where T is the targeted number of samples.

The resampling step parametrized by ts indirectly controls the diversity of the training dataset. We allow ts to be small for the contagion dynamics (ts = 2) and larger for the metapopulation dynamics (ts = 100) to emphasize the performance of the GNN rather than the quality of the training dataset, while acknowledging that different values of ts could lead to poor training (see Sec. III D of the Supplementary Information).
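The generation loop of steps 2–5 can be sketched generically as follows (illustrative Python; `init_state` and `step` stand for the dynamics-specific routines of steps 2 and 3):

```python
def generate_dataset(init_state, step, T, t_s):
    """Evolve the dynamics for T samples, recording (X_t, Y_t) pairs and
    reinitializing the state every t_s steps (step 4).  `init_state()`
    returns a fresh state; `step(x)` returns (next_state, observed_outcome)."""
    X, Y = [], []
    x = init_state()
    for t in range(1, T + 1):
        x_next, y = step(x)
        X.append(x)
        Y.append(y)
        x = init_state() if t % t_s == 0 else x_next
    return X, Y

# Toy deterministic dynamics to make the resampling visible: an integer state
# that increments until it is reset every t_s = 3 steps.
X, Y = generate_dataset(lambda: 0, lambda x: (x + 1, x + 1), T=6, t_s=3)
```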

We trained the simple, complex and interacting contagion GNN models on networks of \(|\mathcal{V}| = 10^3\) nodes and on time series of length \(T = 10^4\). To generate the networks, we used either Erdős-Rényi (ER) random networks G(N, M) or Barabási-Albert (BA) random networks. In both cases, the parameters of the generative network models are chosen such that the average degree is fixed to 〈k〉 = 4.

To train our models on the metapopulation dynamics, we generated 10 networks of \(|\mathcal{V}| = 100\) nodes and, for each of them, generated time series of ts = 100 time steps. This number of time steps roughly corresponds to the moment where the epidemic dies out. Similarly to the previous experiments, we used the ER and the BA models to generate the networks, with parameters chosen such that 〈k〉 = 4. However, because this dynamics is not stochastic, we varied the average degree of the networks to increase the variability in the time series. This was done by randomly removing a fraction \(p = 1 - \ln\left(1 - \mu + e\mu\right)\) of their edges, where μ was sampled for each network uniformly between 0 and 1. In this scenario, the networks were directed and weighted, with the weight of each edge eij being uniformly distributed between 0 and 1.

Hyperparameters

The optimization of the parameters was performed using the rectified Adam algorithm65, with hyperparameters b1 = 0.9 and b2 = 0.999, as suggested in ref. 65.

To build a validation dataset, we selected a fraction of the node states randomly for each time step. More specifically, we randomly chose node vi at time t proportionally to its importance weight wi(t). For all experiments on synthetic dynamics, we randomly selected 10 nodes per time step, on average, to be part of the validation set. For all experiments, the learning rate ϵ was reduced by a factor of 2 every 10 epochs, with initial value ϵ0 = 0.001. A weight decay of \(10^{-4}\) was also used to help regularize the training. We trained all models for 30 epochs, and selected the GNN model with the lowest loss on the validation dataset. We fixed the importance sampling bias exponent to λ = 0.5 for the simple, complex and interacting contagion cases, and to λ = 1 for the metapopulation case.

Importance weights

In this section, we show how to implement the importance weights in the different cases. Other versions of the importance weights are also available in Sec. I of the Supplementary Information.

Discrete state stochastic dynamics

When \({{{{{\mathcal{S}}}}}}\) is a finite countable set, the importance weights can be computed exactly using Eq. (7),

$$w_i(t) \propto \left[\rho\left(k_i, x_i(t), x_{\mathcal{N}_i}(t)\right)\right]^{-\lambda}$$
(15)

where \(\rho\left(k, x, x_{\mathcal{N}}\right)\) is the probability to observe a node of degree k in state x with a neighborhood in state \(x_{\mathcal{N}}\) in the complete dataset D. The inputs can be simplified from \((k, x, x_{\mathcal{N}})\) to \((k, x, \boldsymbol{\ell})\) without loss of generality, where \(\boldsymbol{\ell}\) is a vector whose entries are the number of neighbors in each state. The distribution is then estimated from the complete dataset D by computing the fraction of inputs that are in every configuration

$$\rho(k, x, \boldsymbol{\ell}) = \frac{1}{|\mathcal{V}|\,T} \sum_{i=1}^{|\mathcal{V}|} I\left(k_i = k\right) \sum_{t=1}^{T} I\left(x_i(t) = x\right)\ I\left(\boldsymbol{\ell}_i(t) = \boldsymbol{\ell}\right)$$
(16)

where \(I(\cdot)\) is the indicator function.
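Since the configurations are discrete, ρ can be estimated by simple counting. A minimal sketch (Python; the tuple encoding of the inputs is ours):

```python
from collections import Counter

def importance_weights(inputs, lam):
    """Eqs. (15)-(16): `inputs` holds one (k_i, x_i, ell_i) tuple per node-time
    pair, with ell_i a tuple counting the neighbors in each state; rho is the
    empirical frequency of each configuration and w is proportional to
    rho**(-lam), up-weighting rare configurations."""
    counts = Counter(inputs)
    n = len(inputs)
    return [(counts[c] / n) ** (-lam) for c in inputs]

# Three occurrences of one configuration and one rare configuration, with
# lambda = 0.5 as in the training settings: the rare input is up-weighted.
w = importance_weights([(2, 0, (1, 1))] * 3 + [(2, 1, (0, 2))], lam=0.5)
```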

Continuous state deterministic dynamics

The case of continuous states—e.g. for metapopulation dynamics—is more challenging than its discrete counterpart, especially if the node and edge attributes, Φi and Ωij, need to be accounted for. One of the challenges is that we cannot count the inputs as in the discrete case. As a result, the distribution ρ cannot be estimated directly using Eq. (16), and we use instead

$$w_i(t) = \left[P(k_i)\ \Sigma(\Phi_i, \Omega_i\,|\,k_i)\ \Pi\left(\bar{x}(t)\right)\right]^{-\lambda}$$
(17)

where P(ki) is the fraction of nodes with degree ki, \(\Sigma(\Phi_i, \Omega_i\,|\,k_i)\) is the joint probability density function (pdf), conditioned on the degree ki, for the node attributes Φi and the sum of the edge attributes \(\Omega_i \equiv \sum_{v_j \in \mathcal{N}_i} \Omega_{ij}\), and \(\Pi\left(\bar{x}(t)\right)\) is the pdf for the average of the node states at time t, \(\bar{x}(t) = \frac{1}{|\mathcal{V}|} \sum_{v_i \in \mathcal{V}} x_i(t)\). The pdfs are obtained using nonparametric Gaussian kernel density estimators (KDE)66. Because the density values of a KDE are unbounded above, we normalize each pdf such that the densities evaluated at the samples used to construct the KDE sum to one. Further details on how we designed the importance weights are provided in Sec. I of the Supplementary Information.
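The normalization described above can be sketched with a minimal one-dimensional Gaussian KDE (an illustrative NumPy stand-in for the estimators of ref. 66; the bandwidth value is arbitrary):

```python
import numpy as np

def normalized_kde(samples, bandwidth=0.05):
    """Gaussian KDE whose values at the construction samples sum to one."""
    s = np.asarray(samples, float)
    def raw(x):
        z = (np.atleast_1d(x)[:, None] - s[None, :]) / bandwidth
        return np.exp(-0.5 * z**2).sum(axis=1) / (len(s) * bandwidth * np.sqrt(2 * np.pi))
    norm = raw(s).sum()
    return lambda x: raw(x) / norm

# e.g. the factor Pi(x_bar(t)) of Eq. (17), estimated from the averaged states
x_bar = np.array([0.01, 0.02, 0.02, 0.05, 0.30])
Pi = normalized_kde(x_bar)
```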

Dynamics

In what follows, we describe in detail the contagion dynamics used for our experiments. We specify the node outcome function f introduced in Eq. (3) and the parameters of the dynamics.

Simple contagion

We consider the simple contagion dynamics called the susceptible-infected-susceptible (SIS) dynamics, for which \(\mathcal{S} = \{S, I\} = \{0, 1\}\)—we use these two representations of \(\mathcal{S}\) interchangeably. Because this dynamics is stochastic, we let \(\mathcal{R} = [0, 1]^2\). We define the infection function α(ℓ) as the probability that a susceptible node becomes infected given its number ℓ of infected neighbors

$$\Pr\left(S \to I\,|\,\ell\right) = \alpha(\ell) = 1 - \left(1 - \gamma\right)^{\ell}\,,$$
(18)

where γ ∈ [0, 1] is the disease transmission probability. In other words, a node can be infected by any of its infected neighbors independently with probability γ. We also define the constant recovery probability as

$$\Pr\left(I \to S\right) = \beta\ .$$
(19)

The node outcome function for the SIS dynamics is therefore

$$f(x_i, x_{\mathcal{N}_i}) = \begin{cases} \left(1 - \alpha(\ell_i),\ \alpha(\ell_i)\right) & \text{if } x_i = 0\,,\\ \left(\beta,\ 1 - \beta\right) & \text{if } x_i = 1\,,\end{cases}$$
(20)

where

$$\ell_i = \sum_{v_j \in \mathcal{N}_i} \delta(x_j, 1)$$
(21)

is the number of infected neighbors of vi and δ(x, y) is the Kronecker delta. Note that for each case in Eq. (20), the outcome is a two-dimensional probability vector, where the first entry is the probability that node vi becomes/remains susceptible at the following time step, and the second entry is the probability that it becomes/remains infected. We used (γ, β) = (0.04, 0.08) in all experiments involving this simple contagion dynamics.
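The node outcome function of Eqs. (18)–(20) translates directly into code (a sketch in plain Python with the parameter values above):

```python
GAMMA, BETA = 0.04, 0.08            # (gamma, beta) used in the experiments

def sis_alpha(ell, gamma=GAMMA):
    """Eq. (18): infection probability given ell infected neighbors."""
    return 1 - (1 - gamma) ** ell

def sis_outcome(x_i, ell_i):
    """Eq. (20): (P(S at t+1), P(I at t+1)) for a node in state x_i
    (0 = S, 1 = I) with ell_i infected neighbors."""
    if x_i == 0:
        a = sis_alpha(ell_i)
        return (1 - a, a)
    return (BETA, 1 - BETA)         # Eq. (19): constant recovery probability
```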

Complex contagion

To lift the independent transmission assumption of the SIS dynamics, we consider a complex contagion dynamics for which the node outcome function has a similar form to Eq. (20), but where the infection function α(ℓ) has the nonmonotonic form

$$\alpha(\ell) = \frac{1}{z(\eta)} \frac{\ell^3}{e^{\ell/\eta} - 1}$$
(22)

where z(η) normalizes the infection function such that α(ℓ*) = 1 at its global maximum ℓ*, and η > 0 is a parameter controlling the position of ℓ*. This function is inspired by the Planck distribution for black-body radiation, although it was chosen for its general shape rather than for any physical meaning. We used (η, β) = (8, 0.06) in all experiments involving this complex contagion dynamics.
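A sketch of Eq. (22) in NumPy; since a closed form of z(η) is not needed in practice, it is obtained here numerically as the maximum of ℓ³/(e^{ℓ/η} − 1) over a wide grid (for η = 8 this maximum sits near ℓ ≈ 23):

```python
import numpy as np

ETA = 8.0                              # eta used in the experiments

def complex_alpha(ell, eta=ETA):
    """Eq. (22), with z(eta) computed numerically so that alpha = 1 at the
    global maximum ell*."""
    grid = np.arange(1, 1000, dtype=float)
    z = np.max(grid**3 / np.expm1(grid / eta))
    ell = np.atleast_1d(np.asarray(ell, dtype=float))
    out = np.zeros_like(ell)           # alpha(0) = 0
    pos = ell > 0
    out[pos] = ell[pos]**3 / np.expm1(ell[pos] / eta) / z
    return out

alphas = complex_alpha(np.arange(0, 200))   # rises to 1, then decays: nonmonotonic
```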

Interacting contagion

We define the interacting contagion as two interacting SIS dynamics and denote it as the SIS-SIS dynamics. In this case, we have \(\mathcal{S} = \{S_1S_2, I_1S_2, S_1I_2, I_1I_2\} = \{0, 1, 2, 3\}\). Similarly to the SIS dynamics, we have \(\mathcal{R} = [0, 1]^4\), and we define the infection probability functions

$$\alpha_g(\ell_g) = 1 - (1 - \gamma_g)^{\ell_g}\ \ \text{if}\ x = 0$$
(23a)
$$\alpha_g^{*}(\ell_g) = 1 - (1 - \zeta\gamma_g)^{\ell_g}\ \ \text{if}\ x = 1, 2\,,$$
(23b)

where ζ ≥ 0 is a coupling constant and ℓg is the number of neighbors infected by disease g, and we also define the recovery probabilities βg for each disease (g = 1, 2). The case where ζ > 1 corresponds to the situation in which the diseases are synergistic (i.e. being infected by one increases the probability of getting infected by the other), whereas competition is introduced if ζ < 1 (being already infected by one decreases the probability of getting infected by the other). The case ζ = 1 falls back on two independent SIS dynamics that evolve simultaneously on the network. The outcome function is composed of 16 entries that are expressed as follows

$$f(x_i, x_{\mathcal{N}_i}) = \begin{cases} \left([1-\alpha_1(\ell_{i,1})][1-\alpha_2(\ell_{i,2})],\ \alpha_1(\ell_{i,1})[1-\alpha_2(\ell_{i,2})],\ [1-\alpha_1(\ell_{i,1})]\,\alpha_2(\ell_{i,2}),\ \alpha_1(\ell_{i,1})\,\alpha_2(\ell_{i,2})\right) & \text{if } x_i = 0,\\ \left(\beta_1[1-\alpha_2^{*}(\ell_{i,2})],\ [1-\beta_1][1-\alpha_2^{*}(\ell_{i,2})],\ \beta_1\,\alpha_2^{*}(\ell_{i,2}),\ [1-\beta_1]\,\alpha_2^{*}(\ell_{i,2})\right) & \text{if } x_i = 1,\\ \left([1-\alpha_1^{*}(\ell_{i,1})]\,\beta_2,\ \alpha_1^{*}(\ell_{i,1})\,\beta_2,\ [1-\alpha_1^{*}(\ell_{i,1})][1-\beta_2],\ \alpha_1^{*}(\ell_{i,1})[1-\beta_2]\right) & \text{if } x_i = 2,\\ \left(\beta_1\beta_2,\ [1-\beta_1]\beta_2,\ \beta_1[1-\beta_2],\ [1-\beta_1][1-\beta_2]\right) & \text{if } x_i = 3.\end{cases}$$
(24)

where we define ℓi,g as the number of neighbors of vi that are infected by disease g. We used (γ1, γ2, β1, β2, ζ) = (0.01, 0.012, 0.19, 0.22, 50) in all experiments involving this interacting contagion dynamics.

Metapopulation

The metapopulation dynamics considered is a deterministic version of the susceptible-infected-recovered (SIR) metapopulation model46,47,48,49. We consider that the nodes are populated by a fixed number of people Ni, each of whom can be in one of three states—susceptible (S), infected (I) or recovered (R). We therefore track the number of people in every state at each time. Furthermore, we let the network G be weighted, with the weights describing the mobility flow of people between regions. In this case, \(\Omega_{ij} \in \mathbb{R}\) is the average number of people traveling from node vj to node vi. Finally, because we assume that the population size is on average steady, we let Φi = Ni be a node attribute and work with the fraction of people in every epidemiological state. More precisely, we define the state of node vj by xj = (sj, ij, rj), where sj, ij and rj are the fractions of susceptible, infected and recovered people, respectively. From these definitions, we define the node outcome function of this dynamics as

$$f(x_j, x_{\mathcal{N}_j}, G) = \begin{pmatrix} s_j - s_j\tilde{\alpha}_j \\ i_j - \frac{i_j}{\tau_r} + s_j\tilde{\alpha}_j \\ r_j + \frac{i_j}{\tau_r} \end{pmatrix}$$
(25)

where

$$\tilde{\alpha}_j = \alpha(i_j, N_j) + \sum_{v_l \in \mathcal{N}_j} \frac{k_j\, \Omega_{jl}\, \alpha(i_l, N_l)}{\sum_{v_n \in \mathcal{N}_j} \Omega_{jn}}\ ,$$
(26)

and kj is the degree of node vj. The function α(i, N) corresponds to the infection rate, per day, at which an individual is infected by someone visiting from a neighboring region with iN infected people in it, and is equal to

$$\alpha(i, N) = 1 - \left(1 - \frac{R_0}{\tau_r N}\right)^{iN} \approx 1 - e^{-\frac{R_0}{\tau_r} i}\ .$$
(27)

where R0 is the reproduction number and τr is the average recovery time in days. In all experiments with this metapopulation dynamics, we used (R0, τr) = (8.31, 7.5).
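One step of Eqs. (25)–(27) can be sketched as follows (NumPy; the array layout, with Omega[j, l] the flow from node l to node j, is our convention):

```python
import numpy as np

R0, TAU_R = 8.31, 7.5                  # parameters used in the experiments

def infection_rate(i_frac, N):
    """Eq. (27): per-day infection rate with a fraction i_frac of N infected."""
    return 1 - (1 - R0 / (TAU_R * N)) ** (i_frac * N)

def metapop_step(s, i, r, N, Omega):
    """Eq. (25) for all nodes at once; s, i, r are per-node fractions and
    Omega[j, l] is the mobility flow from v_l to v_j (zero if no edge)."""
    k = (Omega > 0).sum(axis=1)                          # degrees k_j
    local = infection_rate(i, N)
    row_sum = np.maximum(Omega.sum(axis=1), 1e-12)       # avoids 0/0 for isolated nodes
    alpha_tilde = local + k * (Omega @ local) / row_sum  # Eq. (26)
    new_inf = s * alpha_tilde
    return s - new_inf, i + new_inf - i / TAU_R, r + i / TAU_R

# Three regions, infection seeded in the first one only.
N = np.full(3, 10_000.0)
Omega = np.array([[0.0, 5.0, 0.0], [5.0, 0.0, 2.0], [0.0, 2.0, 0.0]])
s, i, r = np.array([0.999, 1.0, 1.0]), np.array([0.001, 0.0, 0.0]), np.zeros(3)
s2, i2, r2 = metapop_step(s, i, r, N, Omega)
```

The update conserves the population fractions (s + i + r stays equal to 1 in every node), and mobility seeds the infection in the second region.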

COVID-19 outbreak in Spain

Dataset

The dataset is composed of the daily incidence of the 52 Spanish provinces (including Ceuta and Melilla) monitored for 450 days between January 1st 2020 and March 27th 202153. The dataset is augmented with the origin-destination (OD) network of individual mobility54. This mobility network is multiplex, directed and weighted, where the weight of each edge \(e_{ij}^{\nu}\) represents the mobility flow from province vj to province vi using transportation mode \(\nu\). The metadata associated with each node is the population of province vi67, denoted Φi = Ni. The metadata associated with each edge, \(\Omega_{ij}^{\nu}\), corresponds to the average number of people that moved from vj to vi using \(\nu\) as their main means of transportation.

Models

The GNN model used in Fig. 5 is very similar to the metapopulation GNN model—with node and edge attributes—with the exception that different attention modules are used to model the different OD edge types (plane, car, coach, train and boat; see Table 1). To combine the features of each layer of the multiplex network, we average-pooled the output features of the attention modules. We also generalized our model to take as input a sequence of L states of the system, that is

$$\hat{Y}_t = \hat{\mathcal{M}}(X_{t:t-L+1}, G; \boldsymbol{\Theta})$$
(28)

where \(X_{t:t-L+1} = (X_t, X_{t-1}, \ldots, X_{t-L+1})\) and L is the lag. At the local level, it reads

$$\hat{y}_i(t) = \hat{f}\left(x_i(t:t-L+1),\ \Phi_i,\ x_{\mathcal{N}_i}(t:t-L+1),\ \Phi_{\mathcal{N}_i},\ \Omega_{i\mathcal{N}_i};\ \boldsymbol{\Theta}\right)$$
(29)

where \(x_i(t:t-L+1)\) corresponds to the L previous states of node i from time t to time t − L + 1. As we now feed sequences of node states to the GNN, we use Elman recurrent neural networks68, instead of linear layers, to transform these sequences of states before aggregating them, as shown in Fig. 6 and Table 1. Additionally, because the outputs of the model are not probability vectors, as for the dynamics of Sec. “Dynamics”, but real numbers, we use the mean square error (MSE) loss to train the model:

$$L(y_i, \hat{y}_i) = (y_i - \hat{y}_i)^2\ .$$
(30)

We use five different baseline models to compare with the performance of our GNN: three additional neural network architectures, a vector autoregressive model (VAR)56 and an equivalent metapopulation model driven by a simple contagion mechanism. The first neural network architecture, denoted the KP-GNN model, was used in ref. 37 to forecast the evolution of COVID-19 in the US using a strategy similar to ours with respect to the mobility network. As described in ref. 37, we used a single-layered MLP with 64 hidden units to transform the input, and then two graph convolutional networks (GCN) in series, each with 32 hidden units, to perform the feature aggregation. Finally, we computed the output of the model using another single-layered MLP with 32 hidden units. The layers of this model are separated by ReLU activation functions and are trained with a dropout rate of 0.5. Because this model is not directly adapted to multiplex networks, we merged all layers together into a single network and summed the weights of the edges. Then, as prescribed in ref. 37, we thresholded the merged network by keeping, for each node, at most the 32 neighbors with the highest edge weights. We did not use our importance sampling procedure to train the KP-GNN model—letting λ = 0—to remain as close as possible to the original model.

The other two neural network architectures are very similar to the GNN model presented in Table 1: they differ only in their aggregation mechanism. The IND model, where the nodes are assumed to be mutually independent, does not aggregate the features of the neighbors. It therefore acts as a univariate model, where the time series of each node are processed like different elements of a minibatch. In the FC model, the nodes interact via a single-layered MLP connecting all nodes together. The parameters of this MLP are learnable, which effectively allows the model to express any interaction pattern. Because the number of parameters of this MLP scales with \(d|\mathcal{V}|^2\), where d is the number of node features after the input layers, we introduce another layer of 8 hidden units to compress the input features before aggregating them.

The VAR model is a linear generative model adapted for multivariate time series forecasting

$$\hat{Y}_t = \hat{\mathcal{M}}(X_t, X_{t-1}, \cdots, X_{t-L+1}) = \sum_{l=0}^{L-1} \boldsymbol{A}_l X_{t-l} + \boldsymbol{b} + \boldsymbol{\epsilon}_t$$
(31)

where \(\boldsymbol{A}_l \in \mathbb{R}^{|\mathcal{V}| \times |\mathcal{V}|}\) are weight matrices, \(\boldsymbol{b} \in \mathbb{R}^{|\mathcal{V}|}\) is a trend vector and ϵt is an error term with \(\mathbb{E}[\boldsymbol{\epsilon}_t] = \boldsymbol{0}\) and \(\mathbb{E}[\boldsymbol{\epsilon}_t \boldsymbol{\epsilon}_s] = \delta_{t,s}\boldsymbol{\Sigma}\), with Σ being a positive-semidefinite covariance matrix. While autoregressive models are often used to predict stock markets69, they have also been used recently to forecast diverse COVID-19 outbreaks57. This model is also fitted to the COVID-19 time series dataset by minimizing the MSE.
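Since Eq. (31) is linear in the parameters, the fit reduces to ordinary least squares. A minimal NumPy sketch (not the actual implementation; the function names are ours):

```python
import numpy as np

def fit_var(X, L):
    """Least-squares fit of the VAR model of Eq. (31).  X has shape (T, n),
    one row per time step: X[t+1] is regressed on the lagged states
    X[t], ..., X[t-L+1] plus a trend term."""
    T, n = X.shape
    Z = np.array([np.concatenate([X[t - l] for l in range(L)] + [[1.0]])
                  for t in range(L - 1, T - 1)])
    coef, *_ = np.linalg.lstsq(Z, X[L:], rcond=None)
    return coef                     # stacked lag matrices and trend, transposed

def var_predict(coef, recent):
    """`recent` = [X_t, X_{t-1}, ..., X_{t-L+1}]; returns the forecast of X_{t+1}."""
    return np.concatenate(list(recent) + [[1.0]]) @ coef

# Noise-free linear process X_{t+1} = A X_t + b: the lag-1 fit reproduces it.
A = np.array([[0.5, 0.1], [0.0, 0.3]])
b = np.array([1.0, -1.0])
X = np.zeros((12, 2))
X[0] = [2.0, 3.0]
for t in range(11):
    X[t + 1] = X[t] @ A.T + b
coef = fit_var(X, L=1)
pred = var_predict(coef, [X[-1]])
```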

The metapopulation model is essentially identical to the model presented in Sec. “Dynamics”. However, because we track the incidence, i.e. the number of newly infectious cases χi(t) in each province i, instead of the complete state (Si, Ii, Ri) representing the number of individuals in each state, we allow the model to internally track the complete state based on the ground truth. At first, the whole population is susceptible, i.e. Si(1) = Ni. Then, at each time step, we subtract the number of newly infectious cases in each node from Si(t), and add it to Ii. Finally, the model allows a fraction \(\frac{1}{{\tau }_{r}}\) of its infected people to recover. The evolution equations of this model are as follows

$$S_i(t+1) = S_i(t) - \chi_i(t)\ ,$$
(32a)
$$I_i(t+1) = I_i(t) + \chi_i(t) - \frac{1}{\tau_r} I_i(t)\ ,$$
(32b)
$$R_i(t+1) = R_i(t) + \frac{1}{\tau_r} I_i(t)\ .$$
(32c)

Finally, we computed the incidence \({\hat{\chi }}_{i}(t)\) predicted by the metapopulation model using the current internal state as follows:

$$\hat{\chi}_i(t) = S_i\, \tilde{\alpha}_i\ ,$$
(33)

where \(\tilde{\alpha}_i\) is given by Eq. (26), using the mobility network G, Eq. (27) for α(i, N) and \(i_j = \frac{I_j}{N_j}\). Since the mobility weights \(\Omega_{ij}^{\nu}\) represent the average number of people traveling from province vj to province vi, we assumed all layers to be equivalent and aggregated them into a single typeless network where \(\Omega_{ij} = \sum_{\nu} \Omega_{ij}^{\nu}\). We fixed the parameters of the model to R0 = 2.5 and τr = 7.5, as these values were used in other contexts for modeling the propagation of COVID-19 in Spain59.
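The incidence-driven bookkeeping of Eqs. (32)–(33) can be sketched as follows (NumPy; `alpha_tilde` is a user-supplied callable standing in for Eq. (26) on the mobility network, and a trivial placeholder is used in the example):

```python
import numpy as np

TAU_R = 7.5

def rollout_incidence(chi, N, alpha_tilde):
    """Eqs. (32)-(33): the internal (S, I, R) state is driven by the observed
    incidence chi (shape (T, n)), while the model emits its own prediction
    chi_hat[t] = S_i(t) * alpha_tilde_i(t)."""
    S = N.astype(float).copy()          # S_i(1) = N_i: everyone starts susceptible
    I = np.zeros_like(S)
    R = np.zeros_like(S)
    chi_hat = np.zeros_like(chi, dtype=float)
    for t in range(chi.shape[0]):
        chi_hat[t] = S * alpha_tilde(S, I, N)   # Eq. (33)
        recovered = I / TAU_R                   # computed before updating I
        S = S - chi[t]                          # Eq. (32a)
        I = I + chi[t] - recovered              # Eq. (32b)
        R = R + recovered                       # Eq. (32c)
    return chi_hat, (S, I, R)

# Two provinces, two days of observed incidence, placeholder alpha_tilde = I/N.
chi = np.array([[10.0, 0.0], [5.0, 2.0]])
N = np.array([1000.0, 2000.0])
chi_hat, (S, I, R) = rollout_incidence(chi, N, lambda S, I, N: I / N)
```

By construction the internal update conserves the total population, S + I + R = N, at every step.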

Training

We trained the GNN and the other neural networks for 200 epochs, decreasing the learning rate by a factor of 2 every 20 epochs from an initial value of \(10^{-3}\). For our GNN, the IND and the FC models, we fixed the importance sampling bias exponent to λ = 0.5 and, like for the models trained on synthetic data, used a weight decay of \(10^{-4}\) (see Sec. “Training settings”). We fixed the lag of these models, including the VAR model, to L = 5. The KP-GNN model was trained using a weight decay of \(10^{-5}\), following ref. 37, and we chose a lag of L = 7. For all models, we constructed the validation dataset by randomly selecting a partition of the nodes at each time step proportionally to their importance weights wi(t); 20% of the nodes were used for validation in this case. The test dataset was constructed by selecting the last 100 time steps of the time series of all nodes, which roughly corresponds to the third wave of the outbreak in Spain.