Introduction

The wide variety of molecular types and sizes poses numerous challenges in the computational modeling of molecular systems for drug discovery, structural biology, quantum chemistry, and others1. To address these challenges, recent advances in geometric deep learning (GDL) approaches have become increasingly important2, 3. Especially, Graph Neural Networks (GNNs) have demonstrated superior performance among various GDL approaches4,5,6. GNNs treat each molecule as a graph and perform message passing scheme on it7. By representing atoms or groups of atoms like functional groups as nodes, and chemical bonds or any pairwise interactions as edges, molecular graphs can naturally encode the structural information in molecules. In addition to this, GNNs can incorporate symmetry and achieve invariance or equivariance to transformations such as rotations, translations, and reflections8, which further contributes to their effectiveness in molecular science applications. To enhance their ability to capture molecular structures and increase the expressive power of their models, previous GNNs have utilized auxiliary information such as chemical properties9,10,11,12, atomic pairwise distances in Euclidean space7, 13, 14, angular information15,16,17,18, etc.

In spite of the success of GNNs, their application in molecular sciences is still in its early stages. One reason for this is that current GNNs often use targeted inductive bias for modeling a specific type of molecular system, and cannot be directly transferred to other contexts although all molecule structures and their interactions follow the same law of physics. For example, GNNs designed for modeling proteins may include operations that are specific to the structural characteristics of amino acids19, 20, which are not relevant for other types of molecules. Additionally, GNNs that incorporate comprehensive geometric information can be computationally expensive, making them difficult to scale to tasks involving a large number of molecules (e.g., high-throughput compound screening) or macromolecules (e.g., proteins and RNAs). For instance, incorporating angular information can significantly improve the performance of GNNs15,16,17,18, but also increases the complexity of the model, requiring at least \(O(Nk^2)\) messages to be computed where N and k denote the number of nodes and the average degree in a graph.

To tackle the limitations mentioned above, we propose a universal GNN framework, Physics-Aware Multiplex Graph Neural Network (PAMNet), for the accurate and efficient representation learning of 3D molecules ranging from small molecules to macromolecules in any molecular system. PAMNet induces a physics-informed bias inspired by molecular mechanics21, which separately models local and non-local interactions in molecules based on different geometric information. To achieve this, we represent each molecule as a two-layer multiplex graph, where one plex only contains local interactions, and the other plex contains additional non-local interactions. PAMNet takes the multiplex graphs as input and uses different operations to incorporate the geometric information for each type of interaction. This flexibility allows PAMNet to achieve efficiency by avoiding the use of computationally expensive operations on non-local interactions, which consist of the majority of interactions in a molecule. Additionally, a fusion module in PAMNet allows the contribution of each type of interaction to be learned and fused for the final feature or prediction. To preserve symmetry, PAMNet utilizes E(3)-invariant representations and operations when predicting scalar properties, and is extended to predict E(3)-equivariant vectorial properties by considering the geometric vectors in molecular structures that arise from quantum mechanics.

To demonstrate the effectiveness of PAMNet, we conduct a comprehensive set of experiments on a variety of tasks involving different molecular systems, including small molecules, RNAs, and protein-ligand complexes. These tasks include predicting small molecule properties, RNA 3D structures, and protein-ligand binding affinities. We compare PAMNet to state-of-the-art baselines in each task and the results show that PAMNet outperforms the baselines in terms of both accuracy and efficiency across all three tasks. Given the diversity of the tasks and the types of molecules involved, the superior performance of PAMNet shows its versatility to be applied in various real-world scenarios.

Overview of PAMNet

Multiplex graph representation

Given any 3D molecule or molecular system, we define a multiplex graph representation as the input of our PAMNet model based on the original 3D structure (Fig. 1a). The construction of multiplex graphs is inspired by molecular mechanics21, in which the molecular energy E is separately modeled based on local and non-local interactions (Fig. 1c). In detail, the local terms \(E_{\text{ bond }}+E_{\text{ angle }}+E_{\text{ dihedral }}\) model local, covalent interactions including \(E_{\text{ bond }}\) that depends on bond lengths, \(E_{\text{ angle }}\) on bond angles, and \(E_{\text{ dihedral }}\) on dihedral angles. The non-local terms \(E_{\text{ vdW }}+E_{\text{ electro }}\) model non-local, non-covalent interactions including van der Waals and electrostatic interactions which depend on interatomic distances. Motivated by this, we also decouple the modeling of these two types of interactions in PAMNet. For local interactions, we can define them either using chemical bonds or by finding the neighbors of each node within a relatively small cutoff distance, depending on the given task. For global interactions that contain both local and non-local ones, we define them by finding the neighbors of each node within a relatively large cutoff distance. For each type of interaction, we use a layer to represent all atoms as nodes and the interactions as edges. The resulting layers that share the same group of atoms form a two-layer multiplex graph \(G = \{G_{global}, G_{local}\}\) which represents the original 3D molecular structure (Fig. 1a).

Message passing modules

To update the node embeddings in the multiplex graph G, we design two message passing modules that incorporate geometric information: Global Message Passing and Local Message Passing for updating the node embeddings in \(G_{global}\) and \(G_{local}\), respectively (Fig. 1b). These message passing modules are inspired by physical principles from molecular mechanics (Fig. 1c): When modeling the molecular energy E, the terms for local interactions require geometric information including interatomic distances (bond lengths) and angles (bond angles and dihedral angles), while the terms for non-local interactions only require interatomic distances as geometric information. The message passing modules in PAMNet also use geometric information in this way when modeling these interactions (Fig. 1b,e). Specifically, we capture the pairwise distances and angles contained within up to two-hop neighborhoods (Fig. 1d). The Local Message Passing requires the related adjacency matrix \({\textbf {A}}_{local}\), pairwise distances \(d_{local}\) and angles \(\theta _{local}\), while the Global Message Passing only needs the related adjacency matrix \({\textbf {A}}_{global}\) and pairwise distances \(d_{global}\). Each message passing module then learns the node embeddings \({\varvec {z}}_g\) or \({\varvec {z}}_l\) in \(G_{global}\) and \(G_{local}\), respectively.

For the operations in our message passing modules, they can preserve different symmetries: E(3)-invariance and E(3)-equivariance, which contain essential inductive bias incorporated by GNNs when dealing with graphs with geometric information8. E(3)-invariance is preserved when predicting E(3)-invariant scalar quantities like energies, which remain unchanged when the original molecular structure undergoes any E(3) transformation including rotation, translation, and reflection. To preserve E(3)-invariance, the input node embeddings \({{\varvec{h}}}\) and geometric features are all E(3)-invariant. To update these features, PAMNet utilizes operations that can preserve the invariance. In contrast, E(3)-equivariance is preserved when predicting E(3)-equivariant vectorial quantities like dipole moment, which will change according to the same transformation applied to the original molecular structure through E(3) transformation. To preserve E(3)-equivariance, an extra associated geometric vector \(\varvec{v} \in \mathbb {R}^3\) is defined for each node. These geometric vectors are updated by operations inspired by quantum mechanics22, allowing for the learning of E(3)-equivariant vectorial representations. More details about the explanations of E(3)-invariance, E(3)-equivariance, and our operations can be found in Methods.

Fusion module

After updating the node embeddings \({\varvec {z}}_g\) or \({\varvec {z}}_l\) of the two layers in the multiplex graph G, we design a fusion module with a two-step pooling process to combine \({\varvec {z}}_g\) and \({\varvec {z}}_l\) for downstream tasks (Fig. 1b). In the first step, we design an attention pooling module based on attention mechanism23 for each hidden layer t in PAMNet. Since \(G_{global}\) and \(G_{local}\) contains the same set of nodes \(\{\textit{N}\}\), we apply the attention mechanism to each node \(n \in \{\textit{N}\}\) to learn the attention weights (\(\alpha _g^t\) and \(\alpha _l^t\)) between the node embeddings of n in \(G_{global}\) and \(G_{local}\), which are \({\varvec {z}}_g^t\) and \({\varvec {z}}_l^t\). Then the attention weights are treated as the importance of \({\varvec {z}}_g^t\) and \({\varvec {z}}_l^t\) to compute the combined node embedding \({\varvec {z}}^t\) in each hidden layer t based on a weighted summation (Fig. 1e). In the second step, the \({\varvec {z}}^t\) of all hidden layers are summed together to compute the node embeddings of the original input. If a graph embedding is desired, we compute it using an average or a summation of the node embeddings.

Results and discussion

In this section, we will demonstrate the performance of our proposed PAMNet regarding two aspects: accuracy and efficiency. Accuracy denotes how well the model performs measured by the metrics corresponding to a given task. Efficiency denotes the memory consumed and the inference time spent by the model.

Figure 1
figure 1

Overview of PAMNet. (a), Based on the 3D structure of any molecule or molecular system, a two-layer multiplex graph \(G = \{G_{global}, G_{local}\}\) is constructed to separate the modeling of global and local interactions. (b), PAMNet takes G as input and learns node-level or graph-level representation for downstream tasks. PAMNet contains stacked message passing modules that update the node embeddings \({\varvec {z}}\) in G, and a fusion module that learns to combine the updated embeddings. In each message passing module, two message passing schemes are designed to encode the different geometric information in G’s two layers. In the fusion module, a two-step pooling process is proposed. (c), Calculation of molecular energy E in molecular mechanics. (d), An example of the geometric information in G. By considering the one-hop neighbors \(\{j\}\) and two-hop neighbors \(\{k\}\) of atom i, we can define the pairwise distances d and the related angles \(\theta\). (e), Detailed architecture of the message passing module and the attention pooling in PAMNet.

Performance of PAMNet regarding accuracy

Small molecule property prediction

To evaluate the accuracy of PAMNet in learning representations of small 3D molecules, we choose QM9, which is a widely used benchmark for the prediction of 12 molecular properties of around 130k small organic molecules with up to 9 non-hydrogen atoms24. Mean absolute error (MAE) and mean standardized MAE (std. MAE)15 are used for quantitative evaluation of the target properties. Besides evaluating the original PAMNet which captures geometric information within two-hop neighborhoods of each node, we also develop a “simple” PAMNet, called PAMNet-s, that utilizes only the geometric information within one-hop neighborhoods. The PAMNet models are compared with several state-of-the-art models including SchNet13, PhysNet14, MGCN25, PaiNN26, DimeNet++16, and SphereNet27. More details of the experiments can be found in Methods and Supplementary Information.

Table 1 Performance comparison on QM9.

We compare the performance of PAMNet with those of the baseline models mentioned above on QM9, as shown in Table 1. PAMNet achieves 4 best and 6 second-best results among all 12 properties, while PAMNet-s achieves 3 second-best results. When evaluating the overall performance using the std. MAE across all properties, PAMNet and PAMNet-s rank 1 and 2 among all models with 10\(\%\) and 5\(\%\) better std. MAE than the third-best model (SphereNet), respectively. From the results, we can observe that the models incorporating only atomic pairwise distance d as geometric information like SchNet, PhysNet, and MGCN generally perform worse than those models incorporating more geometric information like PaiNN, DimeNet++, SphereNet, and our PAMNet. Besides, PAMNet-s which captures geometric information only within one-hop neighborhoods performs worse than PAMNet which considers two-hop neighborhoods. These show the importance of capturing rich geometric information when representing 3D small molecules. The superior performance of PAMNet models demonstrates the power of our separate modeling of different interactions in molecules and the effectiveness of the message passing modules designed.

When predicting dipole moment \(\mu\) as a scalar value, which is originally an E(3)-equivariant vectorial property \(\varvec {\mu }\), PAMNet preserves the E(3)-equivariance to directly predict \(\varvec {\mu }\) first and then takes the magnitude of \(\varvec {\mu }\) as the final prediction. As a result, PAMNet and PAMNet-s all get lower MAE (10.8 mD and 11.3 mD) than the previous best result (12 mD) achieved by PaiNN, which is a GNN with equivariant operations for predicting vectorial properties. Note that the remaining baselines all directly predict dipole moment as a scalar property by preserving invariance. We also examine that by preserving invariance in PAMNet and directly predicting dipole moment as a scalar property, the MAE (24.0 mD) is much higher than the equivariant version. These results demonstrate that preserving equivariance is more helpful than preserving invariance for predicting dipole moments.

RNA 3D structure prediction

Besides small molecules, we further apply PAMNet to predict RNA 3D structures for evaluating the accuracy of PAMNet in learning representations of 3D macromolecules. Following the previous works28,29,30, we refer the prediction to be the task of identifying accurate structural models of RNA from less accurate ones: Given a group of candidate 3D structural models generated based on an RNA sequence, a desired model that serves as a scoring function needs to distinguish accurate structural models among all candidates. We use the same datasets as those used in30, which include a dataset for training and a benchmark for evaluation. The training dataset contains 18 relatively older and smaller RNA molecules experimentally determined31. Each RNA is used to generate 1000 structural models via the Rosetta FARFAR2 sampling method29. The benchmark for evaluation contains relatively newer and larger RNAs, which are the first 21 RNAs in the RNA-Puzzles structure prediction challenge32. Each RNA is used to generate at least 1500 structural models using FARFAR2, where 1\(\%\) of the models are near-native (i.e., within a 2 Å RMSD of the experimentally determined native structure). In practice, each scoring function predicts the root mean square deviation (RMSD) from the unknown true structure for each structural model. A lower RMSD would suggest a more accurate structural model predicted. We compare PAMNet with four state-of-the-art baselines: ARES30, Rosetta (2020 version)29, RASP33, and 3dRNAscore28. Among the baselines, only ARES is a deep learning-based method, and is a GNN using equivariant operations. More details of the experiments are introduced in Methods and Supplementary Information.

Figure 2
figure 2

Performance comparison on RNA-Puzzles. Given a group of candidate structural models for each RNA, we rank the models using PAMNet and the other four leading scoring functions for comparison. Each cross in the figures corresponds to one RNA. (a) The best-scoring structural model of each RNA predicted by the scoring functions is compared. PAMNet in general identifies more accurate models (with lower RMSDs from the native structure) than those decided by the other scoring functions. (b) Comparison of the 10 best-scoring structural models. The identifications of PAMNet contain accurate models more frequently than those from other scoring functions. (c) The rank of the best-scoring near-native structural model for each RNA is used for comparison. PAMNet usually performs better than the other scoring functions by having a lower rank.

On the RNA-Puzzles benchmark for evaluation, PAMNet significantly outperforms all other four scoring functions as shown in Fig. 2. When comparing the best-scoring structural model of each RNA (Fig. 2a), the probability of the model to be near-native (\(<2\) Å RMSD from the native structure) is 90\(\%\) when using PAMNet, compared with 62, 43, 33, and 5\(\%\) for ARES, Rosetta, RASP, and 3dRNAscore, respectively. As for the 10 best-scoring structural models of each RNA (Fig. 2b), the probability of the models to include at least one near-native model is 90\(\%\) when using PAMNet, compared with 81, 48, 48, and 33\(\%\) for ARES, Rosetta, RASP, and 3dRNAscore, respectively. When comparing the rank of the best-scoring near-native structural model of each RNA (Fig. 2c), the geometric mean of the ranks across all RNAs is 1.7 for PAMNet, compared with 3.6, 73.0, 26.4, and 127.7 for ARES, Rosetta, RASP, and 3dRNAscore, respectively. The lower mean rank of PAMNet indicates that less effort is needed to go down the ranked list of PAMNet to include one near-native structural model. A more detailed analysis of the near-native ranking task can be found in Supplementary Figure S1.

Protein-ligand binding affinity prediction

In this experiment, we evaluate the accuracy of PAMNet in representing the complexes that contain both small molecules and macromolecules. We use PDBbind, which is a well-known public database of experimentally measured binding affinities for protein-ligand complexes34. The goal is to predict the binding affinity of each complex based on its 3D structure. We use the PDBbind v2016 dataset and preprocess each original complex to a structure that contains around 300 nonhydrogen atoms on average with only the ligand and the protein residues within 6 Å around it. To comprehensively evaluate the performance, we use Root Mean Square Error (RMSE), Mean Absolute Error (MAE), Pearson’s correlation coefficient (R) and the standard deviation (SD) in regression following18. PAMNet is compared with various comparative methods including machine learning-based methods (LR, SVR, and RF-Score35), CNN-based methods (Pafnucy36 and OnionNet37), and GNN-based methods (GraphDTA38, SGCN39, GNN-DTI40, D-MPNN12, MAT41, DimeNet15, CMPNN42, and SIGN18). More details of the experiments are provided in Methods and Supplementary Information.

Table 2 Performance comparison on PDBbind.

We list the results of all models and compare their performance in Table 2 and Supplementary Table S1. PAMNet achieves the best performance regarding all 4 evaluation metrics in our experiment. When compared with the second-best model, SIGN, our PAMNet performs significantly better with p-value < 0.05. These results clearly demonstrate the accuracy of our model when learning representations of 3D macromolecule complexes.

In general, we find that the models with explicitly encoded 3D geometric information like DimeNet, SIGN, and our PAMNet outperform the other models without the information directly encoded. An exception is that DimeNet cannot beat CMPNN. This might be because DimeNet is domain-specific and is originally designed for small molecules rather than macromolecule complexes. In contrast, our proposed PAMNet is more flexible to learn representations for various types of molecular systems. The superior performance of PAMNet for predicting binding affinity relies on the separate modeling of local and non-local interactions. For protein-ligand complexes, the local interactions mainly capture the interactions inside the protein and the ligand, while the non-local interactions can capture the interactions between protein and ligand. Thus PAMNet is able to effectively handle diverse interactions and achieve accurate results.

Performance of PAMNet regarding efficiency

Table 3 Results of efficiency evaluation.
Table 4 Comparison of the average attention weights

To evaluate the efficiency of PAMNet, we compare it to the best-performed baselines in each task regarding memory consumption and inference time and summarize the results in Table 3. Theoretically, DimeNet++, SphereNet, and SIGN all require \(O(Nk^2)\) messages in message passing, while our PAMNet requires \(O(N(k_g+{k_l}^2))\) messages instead, where N is the number of nodes, k is the average degree in a graph, \(k_g\) and \(k_l\) denotes the average degree in \(G_g\) and \(G_l\) in the corresponding multiplex graph G. When \(k_g \sim k\) and \(k_l \ll k_g\), PAMNet is much more efficient regarding the number of messages involved. A more detailed analysis of computational complexity is included in Methods. Based on the results in Table 3 empirically, we find PAMNet models all require less memory consumption and inference time than the best-performed baselines in all three tasks, which matches our theoretical analysis. We also compare the memory consumption when using a different largest cutoff distance d of the related models in Fig. 3. From the results, we observe that the memory consumed by DimeNet and SIGN increases much faster than PAMNet when d increases. When fixing \(d=5\) Å as an example, PAMNet requires 80\(\%\) and 71\(\%\) less memory than DimeNet and SIGN, respectively. Thus PAMNet is much more memory-efficient and is able to capture longer-range interactions than these baselines with restricted resources. The efficiency of PAMNet models comes from the separate modeling of local and non-local interactions in 3D molecular structures. By doing so, when modeling the non-local interactions, which make up the majority of all interactions, we utilize a relatively efficient message passing scheme that only encodes pairwise distances d as the geometric information. Thus when compared with the models that require more comprehensive geometric information when modeling all interactions, PAMNet significantly reduces the computationally expensive operations. More information about the details of experimental settings is included in Methods.

All components in PAMNet contribute to the performance

To figure out whether all of the components in PAMNet, including the fusion module and the message passing modules, contribute to the performance of PAMNet, we conduct an ablation study by designing PAMNet variants. Without the attention pooling, we use the averaged results from the message passing modules in each hidden layer to build a variant. We also remove either the Local Message Passing or the Global Message Passing for investigation. The performances of all PAMNet variants are evaluated on the three benchmarks. Specifically, the std. MAE across all properties on QM9, the geometric mean of the ranks across all RNAs on RNA-Puzzles, and the four metrics used in the experiment on PDBbind are computed for comparison. The results in Fig. 4 show that all variants decrease the performance of PAMNet in the evaluations, which clearly validates the contributions of all those components. Detailed results of the properties on QM9 can be found in Supplementary Table S2.

Analysis of the contribution of local and global interactions

A salient property of PAMNet is the incorporation of the attention mechanism in the fusion module, which takes the importance of node embeddings in \(G_{local}\) and \(G_{global}\) of G into consideration in learning combined node embeddings. Recall that for each node n in the set of nodes \(\{N\}\) in G, the attention pooling in the fusion module learns the attention weights \(\alpha _l\) and \(\alpha _g\) between n’s node embedding \({\varvec {z}}_{l}\) in \(G_{local}\) and n’s node embedding \({\varvec {z}}_{g}\) in \(G_{global}\). \(\alpha _l\) and \(\alpha _g\) serve as the importance of \({\varvec {z}}_{l}\) and \({\varvec {z}}_{g}\) when computing the combined node embedding \({\varvec {z}}\). To better understand the contribution of \({\varvec {z}}_{l}\) and \({\varvec {z}}_{g}\), we conduct a detailed analysis of the learned attention weights \(\alpha _l\) and \(\alpha _g\) in the three tasks we experimented with. Since the node embeddings are directly related to the involved interactions, such analysis can also reveal the contribution of local and global interactions on the predictions in different tasks. In each task, we take an average of all \(\alpha _l\) or \(\alpha _g\) to be the overall importance of the corresponding group of interactions. Then we compare the computed average attention weights \(\overline{\alpha _l}\) and \(\overline{\alpha _g}\) and list the results in Table 4. A higher attention weight in each task indicates a stronger contribution of the corresponding interactions on solving the task.

For the targets being predicted in QM9, we find that all of them have \(\overline{\alpha _l} \ge \overline{\alpha _g}\) except the electronic spatial extent \(\left\langle R^{2}\right\rangle\), indicating a stronger contribution of the local interactions, which are defined by chemical bonds in this task. This may be because QM9 contains small molecules with only up to 9 non-hydrogen atoms, local interactions can capture a considerable portion of all atomic interactions. However, when predicting electronic spatial extent \(\left\langle R^{2}\right\rangle\), we notice that \(\overline{\alpha _l} < \overline{\alpha _g}\), which suggests that \(\left\langle R^{2}\right\rangle\) is mainly affected by the global interactions that are the pairwise interactions within \(10\) Å in this case. This is not surprising since \(\left\langle R^{2}\right\rangle\) is the electric field area affected by the ions in the molecule, and is directly related to the diameter or radius of the molecule. Besides, previous study43 has demonstrated that graph properties like diameter and radius cannot be computed by message passing-based GNNs that rely entirely on local information, and additional global information is needed. Thus it is expected that global interactions have a stronger contribution than local interactions on predicting electronic spatial extent.

For the RNA 3D structure prediction on RNA-Puzzles and the protein-ligand binding affinity prediction on PDBbind, we find \(\overline{\alpha _l} < \overline{\alpha _g}\) in both cases, which indicates that global interactions play a more important role than local interactions. It is because the goals of these two tasks highly rely on global interactions, which are necessary for representing the global structure of RNA when predicting RNA 3D structure, and are crucial for capturing the relationships between protein and ligand when predicting binding affinity.

Conclusion

In this work, we tackle the limitations of previous GNNs regarding their limited applicability and inefficiency for representation learning of molecular systems with 3D structures and propose a universal framework, PAMNet, to accurately and efficiently learn the representations of 3D molecules in any molecular system. PAMNet explicitly models local and non-local interaction as well as their combined effects inspired by molecular mechanics. The resulting framework incorporates rich geometric information like distances and angles when modeling local interactions, and avoids using expensive operations on modeling non-local interactions. Besides, PAMNet learns the contribution of different interactions to combine the updated node embeddings for the final output. When designing the aforementioned operations in PAMNet, we preserve E(3)-invariance for scalar output and preserve E(3)-equivariance for vectorial output to enable more applicable cases. In our experiments, we evaluate the performance of PAMNet with state-of-the-art baselines on various tasks involving different molecular systems, including small molecules, RNAs, and protein-ligand complexes. In each task, PAMNet outperforms the corresponding baselines in terms of both accuracy and efficiency. These results clearly demonstrate the generalization power of PAMNet even though non-local interactions in molecules are modeled with only pairwise distances as geometric information.

An under-investigated aspect of our proposed PAMNet is that PAMNet preserves E(3)-invariance in operations when predicting scalar properties while requiring additional representations and operations to preserve E(3)-equivariance for vectorial properties. Considering that various equivariant GNNs have been proposed for predicting either scalar or vectorial properties solely by preserving equivariance, it would be worth extending the idea in PAMNet to equivariant GNNs with a potential to further improve both accuracy and efficiency. Another interesting direction is that although we only experiment PAMNet on single-task learning, PAMNet is promising to be used in multi-task learning across diverse tasks that involve molecules of varying sizes and types to gain better generalization. Besides using PAMNet for predicting physiochemical properties of molecules, PAMNet can be used as a universal building block for the representation learning of molecular systems in various molecular science problems. Another promising application of PAMNet is self-supervised learning for molecular systems with few labeled data (e.g., RNA structures). For example, we can use the features in one graph layer to learn properties in another graph layer by utilizing the multiplex nature of PAMNet.

Methods

Details of PAMNet

In this section, we will describe PAMNet in detail, including the involved features, embeddings, and operations.

Figure 3
figure 3

Memory consumption vs. the largest cutoff distance d on PDBbind. We compare PAMNet with the GNN baselines that also explicitly incorporate the 3D molecular geometric information like pairwise distances and angles.

Figure 4
figure 4

Ablation study of PAMNet. We compare the variants with the original PAMNet and report the differences.

Input features

The input features of PAMNet include atomic features and geometric information as shown in Fig. 1b. For atomic features, we use only atomic numbers Z for the tasks on QM9 and RNA-Puzzles following13,14,15,16, 30, and use 18 chemical features like atomic numbers, hybridization, aromaticity, partial charge, etc., for the task on PDBbind following18, 36. The atomic numbers Z are represented by randomly initialized, trainable embeddings according to13,14,15,16. For geometric information, we capture the needed pairwise distances and angles in the multiplex molecular graph G as shown in Fig. 1d. The features (d, \(\theta\)) for the distances and angles are computed with the basis functions in15 to reduce correlations. For the prediction of vectorial properties, we use the atomic position \(\varvec{r}\) to be the initial associated geometric vector \(\varvec{v}\) of each atom.

Message embeddings

In the message passing scheme7, the update of node embeddings \(\varvec{h}\) relies on the passing of the related messages \(\varvec{m}\) between nodes. In PAMNet, we define the input message embeddings \(\varvec{m}\) of message passing schemes with the following way:

$$\begin{aligned} \varvec{m}_{ji}&= \textrm{MLP}_m([\varvec{h}_{j} | \varvec{h}_{i} | \varvec{e}_{ji}]), \end{aligned}$$
(1)

where \(i, j \in G_{global}\) or \(G_{local}\) are connected nodes that can define a message embedding, \(\textrm{MLP}\) denotes the multi-layer perceptron, | denotes the concatenation operation. The edge embedding \(\varvec{e}_{ji}\) encodes the corresponding pairwise distance d between node ij.

Global message passing

As depicted in Fig. 1e, the Global Message Passing in each hidden layer of PAMNet, which consists of a message block and an update block, updates the node embeddings \(\varvec{h}\) in \(G_{global}\) by using the related adjacency matrix \({\textbf {A}}_{global}\) and pairwise distances \(d_{global}\). The message block is defined as below to perform the message passing operation:

$$\begin{aligned} \varvec{h}_{i}^{t}&= \varvec{h}_{i}^{t-1} + \sum \nolimits _{j \in \mathcal {N}(i)} \varvec{m}_{ji}^{t-1}\odot \phi _{d}(\varvec{e}_{j i}), \end{aligned}$$
(2)

where \(i, j \in G_{global}\), \(\phi _{d}\) is a learnable function, \(\varvec{e}_{ji}\) is the embedding of pairwise distance d between node ij, and \(\odot\) denotes the element-wise production. After the message block, an update block is used to compute the node embeddings \(\varvec{h}\) for the next layer as well as the output \({\varvec {z}}\) for this layer. We define the update block using a stack of three residual blocks, where each residual block consists of a two-layer MLP and a skip connection across the MLP. There is also a skip connection between the input of the message block and the output of the first residual block. After the residual blocks, the updated node embeddings \(\varvec{h}\) are passed to the next layer. For the output \({\varvec {z}}\) of this layer to be combined in the fusion module, we further use a three-layer MLP to get \({\varvec {z}}\) with desired dimension size.

Local message passing

For the updates of node embeddings \(\varvec{h}\) in \(G_{local}\), we incorporate both pairwise distances \(d_{local}\) and angles \(\theta _{local}\) as shown in Fig. 1e. To capture \(\theta _{local}\), we consider up to the two-hop neighbors of each node. In Fig. 1d, we show an example of the angles we considered: Some angles are between one-hop edges and two-hop edges (e.g. \(\angle i j_1 k_1\)), while the other angles are between one-hop edges (e.g. \(\angle j_1 i j_2\)). Compared to previous GNNs15,16,17 that incorporate only part of these angles, our PAMNet is able to encode the geometric information more comprehensively. In the Local Message Passing, we also use a message block and an update block following the design of the Global Message Passing as shown in Fig. 1e. However, the message block is defined differently as the one in the Global Message Passing to encode additional angular information:

$$\begin{aligned} \varvec{m}_{ji}^{'t-1}&= \varvec{m}_{ji}^{t-1} + \sum _{j' \in \mathcal {N}(i)\setminus \{j\}} \varvec{m}_{j'i}^{t-1} \odot \phi _{d}(\varvec{e}_{j'i}) \odot \phi _{\theta }(\varvec{a}_{j'i, j i}) + \sum _{k \in \mathcal {N}(j)\setminus \{i\}} \varvec{m}_{kj}^{t-1} \odot \phi _{d}(\varvec{e}_{kj}) \odot \phi _{\theta }(\varvec{a}_{k j, j i}), \end{aligned}$$
(3)
$$\begin{aligned} \varvec{h}_{i}^{t}&= \varvec{h}_{i}^{t-1} + \sum _{j \in \mathcal {N}(i)} \varvec{m}_{ji}^{'t-1}\odot \phi _{d}(\varvec{e}_{ji}), \end{aligned}$$
(4)

where \(i, j, k \in G_{local}\), \(\varvec{e}_{ji}\) is the embedding of pairwise distance d between node ij, \(\varvec{a}_{k j, j i}\) is the embedding of angle \(\theta _{k j, j i}=\angle kji\) defined by node ijk, and \(\phi _{d}, \phi _{\theta }\) are learnable functions. In Eq. (3), we use two summation terms to separately encode the angles in different hops with the associated pairwise distances to update \(\varvec{m}_{ji}\). Then in Eq. (4), the updated message embeddings \(\varvec{m}_{ji}'\) are used to perform message passing. After the message block, we use the same update block as the one used in the Global Message Passing for updating the learned node embeddings.

Fusion module

The fusion module consists of two steps of pooling as shown in Fig. 1b. In the first step, attention pooing is utilized to learn the combined embedding \({\varvec {z}}^{t}\) based on the output node embeddings \({\varvec {z}}_{g}^{t}\) and \({\varvec {z}}_{l}^{t}\) in each hidden layer t. The detailed architecture of attention pooling is illustrated in Fig. 1e. We first compute the attention weight \(\alpha _{{p},i}\) on node i that measures the contribution of the results from plex or graph layer \({p} \in \{g, l\}\) in multiplex graph G:

$$\begin{aligned} \alpha _{{p},i}^{t} = \frac{\exp \left( {\text {LeakyReLU}} \left( \textbf{W}_{{p}}^t {\varvec {z}}_{{p},i}^t\right) \right) }{\sum _{{p}}\exp \left( {\text {LeakyReLU}}\left( \textbf{W}_{{p}}^t {\varvec {z}}_{{p},i}^t\right) \right) }, \end{aligned}$$
(5)

where \(\textbf{W}_{{p}}^t \in \mathbb {R}^{1\times F}\) is a learnable weight matrix different for each hidden layer t and graph layer p, and F is the dimension size of \({\varvec {z}}_{{p},i}^t\). With \(\alpha _{{p},i}^{t}\), we can compute the combined node embedding \({\varvec {z}}_i^t\) of node i using a weighted summation:

$$\begin{aligned} {\varvec {z}}_i^t = \sum \nolimits _{{p}} \alpha _{{p},i}^{t} \left( \textbf{W}_{{p}}^{'t}{} {\varvec {z}}_{{p},i}^t\right) , \end{aligned}$$
(6)

where \(\textbf{W}_{{p}}^{'t} \in \mathbb {R}^{D\times F}\) is a learnable weight matrix different for each hidden layer t and graph layer p, D is the dimension size of \({\varvec {z}}_i^t\), and F is the dimension size of \({\varvec {z}}_{{p},i}^t\).

In the second step of the fusion module, we sum the combined node embedding \({\varvec {z}}\) of all hidden layers to compute the final node embeddings \(\varvec{y}\). If a graph-level embedding \(\varvec{y}\) is desired, we compute as follows:

$$\begin{aligned} \varvec{y} = \sum \nolimits _{i=1}^{N}\sum \nolimits _{t=1}^{T} {\varvec {z}}_i^t. \end{aligned}$$
(7)

Preservation of E(3)-invariance & E(3)-equivariance

For the operations described above, they preserve the E(3)-invariance of the input atomic features and geometric information and can predict E(3)-invariant scalar properties. To predict E(3)-equivariant vectorial property \(\varvec{u}\), we introduce an associated geometric vector \(\varvec{v}_i\) for each node i and extend PAMNet to preserve the E(3)-equivariance for learning \(\varvec{u}\). In detail, the associated geometric vector \(\varvec{v}_i^{t}\) of node i in hidden layer t is defined as:

$${\varvec{v}} _{i}^{t} = f_{v} (\{ \varvec{h}^{t} \} ,\{ {\varvec{r}} \} ),$$
(8)

where \(\{\varvec{h}^{t}\}\) denotes the set of learned node embeddings of all nodes in hidden layer t, \(\{\varvec{r}\}\) denotes the set of position vectors of all nodes in 3d coordinate space, and \(f_v\) is a function that preserves the E(3)-equivariance of \(\varvec{v}_i^{t}\) with respect to \(\{\varvec{r}\}\). Equation (8) is computed after each message passing module in PAMNet.

To predict a final vectorial property \(\varvec{u}\), we modify Eqs. (6) and (7) in the fusion module as the following operations:

$$\begin{aligned} \varvec{u}_{i}^{t}&= \sum \nolimits _{{p}} \alpha _{{p},i}^{t} \left( \textbf{W}_{{p}}^{'t}{} {\varvec {z}}_{{p},i}^t\right) \varvec{v}_{{p},i}^t, \end{aligned}$$
(9)
$$\begin{aligned} \varvec{u}&= \sum \nolimits _{i=1}^{N}\sum \nolimits _{t=1}^{T} \varvec{u}_{i}^{t}, \end{aligned}$$
(10)

where \(\varvec{v}_{{p},i}^t\) is the associated geometric vector of node i on graph layer p in hidden layer t, \(\varvec{u}_{i}^{t}\) is the learned vector of node i in hidden layer t, and \(\textbf{W}_{{p}}^{'t} \in \mathbb {R}^{1\times F}\) is a learnable weight matrix different for each hidden layer t and graph layer p. In Eq. (9), we multiply \(\varvec{v}_{{p},i}^t\) with the learned scalar node contributions. In Eq. (10), we sum all node-level vectors in all hidden layers to compute the final prediction \(\varvec{u}\).

For predicting dipole moment \(\varvec {\mu }\) , which is an E(3)-equivariant vectorial property that describes the net molecular polarity in electric field, we design \(f_v\) in Eq. (8) as motivated by quantum mechanics44. The conventional method to compute molecular dipole moment involves approximating electronic charge densities as concentrated at each atomic position, resulting in \(\varvec {\mu }=\sum \nolimits _{i}\varvec{r}_{c,i}q_i\), where \(q_i\) is the partial charge of node i, and \(\varvec{r}_{c,i}=\varvec{r}_i - \left( \sum \nolimits _{i}\varvec{r}_i\right) /N\) is the relative atomic position of node i. However, this approximation is not accurate enough. Instead, we use a more accurate approximation by adding dipoles onto atomic positions in the distributed multipole analysis (DMA) approach22. This results in the dipole moment equation: \(\varvec {\mu }=\sum \nolimits _{i}(\varvec{r}_{c,i}q_i+\varvec {\mu }_i)\), where \(\varvec {\mu }_i\) is the associated partial dipole of node i. The equation can be rewritten as \(\boldsymbol {\mu }=\sum \nolimits _{i}f_v(\varvec{r}_{i})q_i\), where \(q_i\) is the scalar atomic contribution that can be modeled by an invariant fashion. By treating \(f_v(\varvec{r}_{i})\) as \(\varvec{v}_{i}^t\), the equation has a similar format as a combination of Eqs. (9) and (10). We update \(\varvec{v}_{i}^t\) in the following way:

$$\begin{aligned} \varvec{v}_{i}^{t} =\sum \nolimits _{j \in \mathcal {N}(i)}( \varvec{r}_{i} - \varvec{r}_{j})\Vert \varvec{m}_{i j}^{t}\Vert , \end{aligned}$$
(11)

where \(\Vert \cdot \Vert\) denotes the L2 norm. Since \(\varvec{v}_{i}^{t}\) as well as \(\varvec {\mu }\) are computed by a linear combination of \(\{\varvec{r}\}\), our PAMNet can preserve E(3)-equivariance with respect to \(\{\varvec{r}\}\) when performing the prediction.

Computational complexity

We analyze the computational complexity of PAMNet by addressing the number of messages. We denote the cutoff distance when creating the edges as \(d_g\) and \(d_l\) in \(G_g\) and \(G_l\). The average degree is \(k_g\) in \(G_g\) and is \(k_l\) in \(G_l\). In each hidden layer of PAMNet, Global Message Passing needs \(O(Nk_g)\) messages because it requires one message for each pairwise distance between the central node and its one-hop neighbor. While Local Message Passing requires one message for each one-hop or two-hop angle around the central node. The number of angles can be estimated as follows: For k edges connected to a node, they can define \((k(k-1))/2\) angles which result in a complexity of \(O\left( Nk^2\right)\). The number of one-hop angles and two-hop angles all has such complexity. So that Local Message Passing needs \(O\left( 2N{k_l}^2\right)\) messages. In total, PAMNet requires the computation of \(O\left( Nk_g+2N{k_l}^2\right)\) messages in each hidden layer, while previous approaches15,16,17,18, 27 require \(O\left( N{k_g}^2\right)\) messages. For 3D molecules, we have \(k_g \propto {d_g}^3\) and \(k_l \propto {d_l}^3\). With proper choices of \(d_l\) and \(d_g\), we have \(k_l \ll k_g\). In such cases, our model is more efficient than the related GNNs. We here list the comparison of the number of messages needed in our experiments as an example: On QM9 with \(d_g=5\) Å, our model needs 0.5k messages/molecule on average, while DimeNet++ needs 4.3k messages. On PDBBind with \(d_l=2\) Å and \(d_g=6\) Å, our model needs only 12k messages/molecule on average, while DimeNet++ needs 264k messages.

Data collection and processing

QM9

For QM9, we use the source provided by24. Following the previous works15,16,17, we process QM9 by removing about 3k molecules that fail a geometric consistency check or are difficult to converge45. For properties \(U_0\), U, H, and G, only the atomization energies are used by subtracting the atomic reference energies as in15,16,17. For property \(\Delta \epsilon\), we follow the same way as the DFT calculation and predict it by calculating \(\epsilon _{\textrm{LUMO}}-\epsilon _{\textrm{HOMO}}\). For property \(\mu\), the final result is the magnitude of the predicted vectorial \(\varvec{\mu }\) when using our geometric vector-based approaches with PAMNet. The 3D molecular structures are processed using the RDKit library46. Following15, we randomly use 110000 molecules for training, 10000 for validation and 10831 for testing. In our multiplex molecular graphs, we use chemical bonds as the edges in the local layer, and a cutoff distance (5 or 10 Å) to create the edges in the global layer.

RNA-puzzles

RNA-Puzzles consists of the first 21 RNAs in the RNA-Puzzles structure prediction challenge32. Each RNA is used to generate at least 1500 structural models using FARFAR2, where 1\(\%\) of the models are near native (i.e., within a 2 Å RMSD of the experimentally determined native structure). Following30, we only use the carbon, nitrogen, and oxygen atoms in RNA structures. When building multiplex graphs for RNA structures, we use cutoff distance \(d_l=2.6\) Å for the local interactions in \(G_{local}\) and \(d_g=20\) Å for the global interactions in \(G_{global}\).

PDBBind

For PDBBind, we use PDBbind v2016 following18, 36. Besides, we use the same data splitting method according to18 for a fair comparison. In detail, we use the core subset which contains 290 complexes in PDBbind v2016 for testing. The difference between the refined and core subsets, which includes 3767 complexes, is split with a ratio of 9:1 for training and validation. We use log\(K_i\) as the target property being predicted, which is proportional to the binding free energy. In each complex, we exclude the protein residues that are more than 6 Å from the ligand and remove all hydrogen atoms. The resulting complexes contain around 300 atoms on average. In our multiplex molecular graphs, we use cutoff distance \(d_l=2\) Å in the local layer and \(d_g=6\) Å in the global layer.

Experimental settings

In our message passing operations, we define \(\phi _{d}(\varvec{e})=\textbf{W}_{e}\varvec{e}\) and \(\phi _{\theta }(\varvec{\alpha })=\textrm{MLP}_{\alpha }(\varvec{\alpha })\), where \(\textbf{W}_{e}\) is a weight matrix, \(\textrm{MLP}_{\alpha }\) is a multi-layer perceptron (MLP). All MLPs used in our model have two layers by taking advantage of the approximation capability of MLP47. For all activation functions, we use the self-gated Swish activation function48. For the basis functions, we use the same parameters as in15. To initialize all learnable parameters, we use the default settings used in PyTorch without assigning specific initializations except the initialization for the input node embeddings on QM9: \(\varvec{h}\) are initialized with random values uniformly distributed between \(-\sqrt{3}\) and \(\sqrt{3}\). In all experiments, we use the Adam optimizer49 to minimize the loss. In Supplementary Table S3, we list the typical hyperparameters used in our experiments. All of the experiments are done on an NVIDIA Tesla V100 GPU (32 GB).

Small molecule property prediction

In our experiment on QM9, we use the single-target training following15 by using a separate model for each target instead of training a single shared model for all targets. The models are optimized by minimizing the mean absolute error (MAE) loss. We use a linear learning rate warm-up over 1 epoch and an exponential decay with a ratio 0.1 every 600 epochs. The model parameter values for validation and testing are kept using an exponential moving average with a decay rate of 0.999. To prevent overfitting, we use early stopping on the validation loss. For properties ZPVE, \(U_0\), U, H, and G, we use the cutoff distance in the global layer \(d_g=5\) Å. For the other properties, we use \(d_g=10\) Å. We repeat our runs 3 times for each PAMNet variant following50.

RNA 3D structure prediction

PAMNet is optimized by minimizing the smooth L1 loss51 between the predicted value and the ground truth. An early-stopping strategy is adopted to decide the best epoch based on the validation loss.

Protein-ligand binding affinity prediction

We create three weight-sharing, replica networks, one each for predicting the target G of complex, protein pocket, and ligand following52. The final target is computed by \(\Delta G_{\text {complex}} = G_{\text {complex}} - G_{\text {pocket}} - G_{\text {ligand}}\). The full model is trained by minimizing the mean absolute error (MAE) loss between \(\Delta G_{\text {complex}}\) and the true values. The learning rate is dropped by a factor of 0.2 every 50 epochs. Moreover, we perform 5 independent runs according to18.

Efficiency comparison

In the experiment on investigating the efficiency of PAMNet, we use NVIDIA Tesla V100 GPU (32 GB) for a fair comparison. For small molecule property prediction, we use the related models for predicting property \(U_0\) of QM9 as an example. We use batch size=128 for all models and use the configurations reported in the corresponding papers. For RNA 3D structure prediction, we use PAMNet and ARES to predict the structural models of RNA in puzzle 5 of RNA-Puzzles challenge. The RNA being predicted has 6034 non-hydrogen atoms. The model settings of PAMNet and ARES are the same as those used for reproducing the best results. We use batch size=8 when performing the predictions. For protein-ligand binding affinity prediction, we use the configurations that can reproduce the best results for the related models.