Introduction

Machine learning is a widely used tool for predicting material properties: trained on large materials databases, it is orders of magnitude faster than ab initio calculations while achieving comparable prediction accuracy1,2,3,4. Traditional machine learning methods5,6,7,8,9,10,11,12,13 have been developed for various property prediction tasks, such as formation energy, band gap, and thermal conductivity. Beyond traditional machine learning, many deep networks, such as equivariant 3D convolutional networks14,15,16,17 and graph neural networks (GNNs), have been proposed recently. As a relatively new machine learning model, a GNN can naturally represent a material’s topological structure as a graph18,19. GNNs are designed to learn complex features automatically from nearly raw structural data, which may contain only the atomic numbers and positions of the atoms. GNNs have already predicted material properties with higher accuracy than traditional machine learning has ever achieved19,20,21,22.

The topological structure can be embedded effectively in the graph's adjacency matrix \(A\) for material property prediction. For example, the Crystal Graph Neural Network (CGNN)21 achieved a mean absolute error (MAE) for formation energy (\(E_{\mathrm{f}}\)) of 0.0346 eV/atom with a single model that used only topological information in the aggregation process. In materials prediction, geometric information such as interatomic distance and direction is also important, because the relative spatial positions of atoms in the microstructure are closely related to the charge interactions between them and thus affect the macroscopic properties of the material23. Many works have already introduced geometric information into their models. The Crystal Graph Convolutional Neural Network (CGCNN)19 uses the distance between atoms to represent the edges in the crystal graph. The Materials Graph Network (MEGNet)24 introduced manual features that include topological distance and spatial distance. The directional message passing neural network (DimeNet)25 encoded directional information into GNN models for molecules26,27. However, previous GNN-based works have used incomplete spatial geometric information, such as only the distance19,22,24 or one additional angle25, which prevents the models from learning the complete local geometric relationships between atoms.

There are two ways to embed geometric information into a GNN model. One is to embed the geometric information directly into node or edge features; the other is to encode the geometric information first and then use the encoded information to act on the message passing process. A proper encoder can transform discrete, unnormalized geometric information into a set of normalized values that is easier to learn. In addition, a specific form of encoder can introduce physical meaning related to the task. For example, PhysNet28 introduced a Gaussian attention mask that filters messages between nodes, based on the physical knowledge that bound-state wave functions in two-body systems decay exponentially. DimeNet25 introduced a 2D Fourier-Bessel basis as an attention mask under the hypothesis that each atom sits in an infinitely deep spherical well. Attention masks are an excellent way to encode the geometric structure among atoms, and since their use in crystal property prediction has only just begun, there is still room for development. In this work, we propose a GNN model that accurately predicts properties of any crystalline material and is invariant to global 3D rotations, translations, and node permutations. Our model achieves unprecedented prediction accuracy by introducing complete local spatial geometric information. First, we construct a directed multi-graph and define each edge feature as the distance vector between an atom and its neighbor, which incorporates the complete geometric information in the crystal lattice coordinate system. Unlike the distances and angles used in previous works19,25, the distance vector records the complete local spatial geometric structure.
Second, to help the model learn from the discrete and unnormalized distance vectors, we propose an encoder, used as an attention mask, that transforms each distance vector into two sets of orthogonal basis functions, inspired by the mixed basis29,30 in the solution space of the Schrödinger equation. Just as Bloch's theorem uses plane waves to describe the periodic structure of crystals, we introduce plane waves into the mask function for crystal property prediction. Experimental results show that these basis functions help our model learn crystal structures better. Finally, to ensure generality without compromising accuracy, the initial node and edge features in our model contain only atomic numbers and atomic positions.

Our main contributions can be summarized as follows: (1) we propose a GNN architecture based on the message passing neural network (MPNN)31 that achieves high prediction accuracy for formation energy and band gap; and (2) we provide an effective way to encode local geometric information in the aggregation step, namely an attention mask composed of a Gaussian radial basis and plane waves.

Results

Crystal graph definition and the introduction of geometric information

In this section, we construct a crystal graph representation suitable for any stoichiometric crystalline material. The graph retains both the topological and the geometric structure of the crystal. It also records the periodicity and key crystal information, such as the lattice vectors and the cell volume. In addition, the graph representation satisfies translation invariance and node permutation invariance.

In a molecular graph, nodes usually represent atoms and edges represent chemical bonds between atoms18,31,32. In crystals, however, there are no clearly defined chemical bonds, so the adjacency relationship among atoms must be defined first. Similar to CGCNN19, we define the neighbors of each atom as its nearest \(k\) atoms within the cutoff radius \(c\); in this work, \(k=12\) and \(c=8\AA\). Because of periodicity, this is a multi-graph: \(u_{(ij),k}\) denotes the \(k\)th edge between nodes \(i\) and \(j\). We implement this with the open-source Python library pymatgen33. Defined this way, the adjacency relationship embeds both the topological structure and the periodicity of the crystal. We then define the node and edge representations on the graph. To construct a general crystal graph appropriate for all kinds of crystals, we introduce as few manual features as possible. Node features \(v_i^0\) are defined by a one-hot encoding of the atomic number, as shown in Eq. (1), where the matrix \(W\) adjusts the feature dimension. To introduce geometric information, we keep the distance vector between atoms as the edge feature, that is, \(u_{(ij),k}=\mathbf{r}_{(ij),k}\) (we omit \(k\) in the rest of the paper where there is no ambiguity). The graph is also directed, because the edge \(u_{(ij)}\) records the distance vector from node \(i\) to node \(j\), which differs from \(u_{(ji)}\); \(u_{(ji)}\) may not even exist. In addition, we record the lattice vectors \(\mathbf{a},\mathbf{b},\mathbf{c}\) and the cell volume \(\Omega\) as the crystal parameters \(P_{\mathrm{c}}\), which are used by the attention mask \(M\) introduced below. The full crystal graph is shown in Fig. 1.

$$v_i^0=\mathrm{Embedding}(z_i)=W\,\mathrm{Onehot}(z_i);\quad u_{ij}=\mathbf{r}_{ij}$$
(1)
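The graph construction and Eq. (1) can be sketched as below. The authors use pymatgen for the neighbor search; to stay dependency-free, this sketch instead enumerates periodic images by brute force with NumPy. The function names `build_crystal_graph` and `embed_atoms`, the convention that lattice rows are the vectors a, b, c, and the one-hot vocabulary size of 100 elements are our own assumptions, not the authors' code.

```python
import numpy as np

NUM_ELEMENTS = 100  # hypothetical one-hot vocabulary size (Z = 1..100)

def build_crystal_graph(lattice, frac_coords, k=12, cutoff=8.0):
    """Directed multi-graph: each atom i gets edges to its k nearest periodic
    neighbours j within `cutoff`, with edge feature r_ij = x_j(image) - x_i.

    lattice: (3, 3) array whose rows are the lattice vectors (Angstrom).
    frac_coords: (N, 3) fractional coordinates of the atoms in the cell.
    Returns (src, dst, vec): edge i -> j and its Cartesian distance vector.
    """
    cart = frac_coords @ lattice
    volume = abs(np.linalg.det(lattice))
    # Images needed per axis: cutoff divided by the lattice-plane spacing
    # Omega / |a_b x a_c| along that axis.
    n_img = [
        int(np.ceil(cutoff * np.linalg.norm(np.cross(lattice[(a + 1) % 3], lattice[(a + 2) % 3])) / volume))
        for a in range(3)
    ]
    ranges = [np.arange(-n, n + 1) for n in n_img]
    shifts = np.array(np.meshgrid(*ranges, indexing="ij")).reshape(3, -1).T
    offsets = shifts @ lattice  # Cartesian translations of the periodic images
    n_atoms = len(cart)
    j_index = np.tile(np.arange(n_atoms), len(offsets))

    src, dst, vec = [], [], []
    for i in range(n_atoms):
        # Vectors from atom i to every atom j in every periodic image.
        d = (cart[None, :, :] + offsets[:, None, :] - cart[i]).reshape(-1, 3)
        dist = np.linalg.norm(d, axis=1)
        keep = np.flatnonzero((dist > 1e-8) & (dist <= cutoff))  # drop atom i itself
        nearest = keep[np.argsort(dist[keep], kind="stable")[:k]]
        src.extend([i] * len(nearest))
        dst.extend(j_index[nearest].tolist())
        vec.extend(d[nearest])
    return np.array(src), np.array(dst), np.array(vec)

def embed_atoms(atomic_numbers, W):
    """Eq. (1): v_i^0 = W onehot(z_i), with W of shape (d, NUM_ELEMENTS)."""
    z = np.asarray(atomic_numbers)
    onehot = np.zeros((len(z), NUM_ELEMENTS))
    onehot[np.arange(len(z)), z - 1] = 1.0  # Z = 1 maps to column 0
    return onehot @ W.T
```

For a single atom in a simple cubic cell, all 12 edges run from atom 0 to periodic images of atom 0, which is exactly why the representation has to be a multi-graph. Multiplying a one-hot vector by \(W\) just selects one column of \(W\), so Eq. (1) is equivalent to a learnable lookup table.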
Fig. 1: The crystal graph.

a NaCl crystal structure. b Crystal graph with number of neighbors = 8; each arrow represents a directed edge from one atom node to another. In this graph \(G=(P_{\mathrm{c}},\{v\},\{u\})\). The global descriptor is \(P_{\mathrm{c}}=(\Omega,\mathbf{a},\mathbf{b},\mathbf{c})\), the node set is \(\{v\}=(v_i,v_j)\), and the edge set is \(\{u\}=(u_{(ii),l},u_{(ij),m},u_{(jj),n},u_{(ji),k}\mid l,m,n,k\in[1,4])\).

An MPNN-based GNN for crystal property prediction

As described in the “Methods” section, the forward propagation of a GNN consists of two steps: node updating and target output. Note that in this work the edges are not updated across iterations. The node updating process in our model can be written as Eq. (2).

$$v_i^{t}=f_{\mathrm{update}}\left(v_i^{t-1},\,f_{\mathrm{agg}}\left(v_j^{t-1},\mathbf{r}_{ij},P_{\mathrm{c}}\mid j\in N_i\right)\right)$$
(2)

Similar to CGCNN19, we use the Message Passing Neural Network31 (MPNN) architecture with gated convolution (Gated Conv) to implement Eq. (2). In an MPNN, the message passed from neighbor \(j\) to node \(i\) is a function of \(v_j\oplus v_i\), where \(\oplus\) denotes concatenation. In this work, we add further information to the message (Eq. 3): the difference quotient of node features \(\nabla v_{ij}=(v_j-v_i)/|\mathbf{r}_{ij}|\), which describes how node features change with the distance between nodes. In our experiments, this term improved model performance by 5%. The distance vector \(\mathbf{r}_{ij}\) represents the relationship between atoms, but it is hard for the model to learn directly because it is discrete and unnormalized. We therefore embed \(\mathbf{r}_{ij}\) and \(P_{\mathrm{c}}\) into a learnable attention mask34 \(M\) that filters the messages passed from the neighbors; \(M\) is discussed in detail in the next section. The mathematical expressions are given in Eqs. (3)–(5), where \(W\) denotes a learnable weight matrix and \(\odot\) denotes element-wise multiplication. \(\sigma\) denotes a nonlinear activation function, ELU in this work, and \(g\) denotes a sigmoid function that gates the message passed from node \(j\) to node \(i\). Note that \(g\) and \(M\) differ significantly: although both filter messages, \(g\) is computed only from node features (topological information), whereas \(M\) is computed from geometric information. The concatenated node feature \(\tilde{v}_{ij}\) first passes into the aggregation function \(f_{\mathrm{agg}}\) with the attention mask \(M\), giving the aggregated feature \(\omega_i^t\) for the \(i\)th node at the \(t\)th layer.
Then \(\omega_i^t\) passes into the update function \(f_{\mathrm{update}}\), giving the updated node feature \(v_i^{t+1}\). This completes one layer of node updating, and the next layer begins.

$$\tilde{v}_{ij}=v_j\oplus v_i\oplus\nabla v_{ij}$$
(3)
$$\omega_i^t=f_{\mathrm{agg}}=v_i^t+\sum_{j\in N_i}\sigma(W_{\sigma}^t\tilde{v}_{ij})\odot g(W_{\mathrm{g}}^t\tilde{v}_{ij})\odot M(\mathbf{r}_{ij},P_{\mathrm{c}})$$
(4)
$${v}_{i}^{t+1}={f}_{{{{{{\rm{update}}}}}}}={{{{{\rm{MLP}}}}}}({\omega }_{i}^{t})$$
(5)
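One layer of Eqs. (3)–(5) can be sketched in NumPy as follows. This is an illustrative reading of the equations, not the authors' implementation: the function name `gated_conv_layer`, the edge-list arrays `src` (center atom \(i\)) and `dst` (neighbor \(j\)), and the row-vector convention `x @ W` (the text writes \(W\tilde{v}_{ij}\)) are our assumptions, and the mask values \(M(\mathbf{r}_{ij},P_{\mathrm{c}})\) are passed in precomputed.

```python
import numpy as np

def elu(x):
    # ELU with alpha = 1; expm1 on the clipped input avoids overflow warnings.
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_conv_layer(v, src, dst, r, mask, W_sigma, W_g, f_update):
    """One node-updating layer following Eqs. (3)-(5).

    v:        (N, d) node features v^t
    src, dst: (E,) centre atom i and neighbour j of each edge
    r:        (E, 3) distance vectors r_ij
    mask:     (E, d) precomputed attention-mask values M(r_ij, P_c)
    W_sigma, W_g: (3d, d) weights, applied as x @ W
    f_update: callable (N, d) -> (N, d), e.g. an MLP
    """
    v_i, v_j = v[src], v[dst]
    grad = (v_j - v_i) / np.linalg.norm(r, axis=1, keepdims=True)  # nabla v_ij
    v_tilde = np.concatenate([v_j, v_i, grad], axis=1)               # Eq. (3)
    msg = elu(v_tilde @ W_sigma) * sigmoid(v_tilde @ W_g) * mask     # Eq. (4) summands
    omega = v.copy()
    np.add.at(omega, src, msg)  # omega_i = v_i + sum of messages into i
    return f_update(omega), omega                                    # Eq. (5)
```

The scatter-add uses `np.add.at` rather than `omega[src] += msg`, because buffered fancy-index assignment keeps only one contribution per repeated index, and in a multi-graph nearly every node receives many edges.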

To let the model learn information at different scales, we place a gated pooling layer (Eq. 6) after each gated convolution layer to obtain a layer vector \(\gamma^t\) for each layer \(t\), where \(g'\) is a tanh function. The layer vectors are summed into a graph vector \(\sum_t\gamma^t\), which is finally fed into a multilayer perceptron to obtain a real-valued prediction \(P\) (Eq. 7).

$$\gamma^t=\frac{1}{N}\sum_i g'(W_{\mathrm{g}'}^t\omega_i^t)\odot W_{\omega}^t\omega_i^t$$
(6)
$$P={{{{{\rm{MLP}}}}}}(\mathop{\sum} _{t}{\gamma }^{t})$$
(7)
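The readout of Eqs. (6) and (7) is short enough to sketch directly. The names `gated_pool` and `readout`, the row-vector weight convention, and passing the final MLP as a callable are illustrative assumptions.

```python
import numpy as np

def gated_pool(omega, W_gp, W_omega):
    """Eq. (6): gamma^t = (1/N) sum_i tanh(W_g' omega_i) * (W_omega omega_i)."""
    return np.mean(np.tanh(omega @ W_gp) * (omega @ W_omega), axis=0)

def readout(omegas, pool_weights, mlp):
    """Eq. (7): P = MLP(sum_t gamma^t), one (W_g', W_omega) pair per layer t."""
    graph_vec = sum(gated_pool(om, Wg, Wo) for om, (Wg, Wo) in zip(omegas, pool_weights))
    return mlp(graph_vec)
```

Averaging over nodes makes each \(\gamma^t\) independent of node ordering, which is where the node permutation invariance of the whole model comes from.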

Training the model can be viewed as an optimization problem. As Eq. (8) shows, the goal is to optimize all the weight matrices \(W\) so as to minimize the loss function, which in this work is the mean squared error (MSE) between model predictions and DFT calculations.

$$\min_{W}\,\mathrm{MSE}(P(W),y)$$
(8)

where \(y\) is the DFT calculated data and \(P\) is the model predictions. \(W\) denotes all the learnable matrices in the model. This optimization problem can be solved by back propagation and gradient descent.
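To make Eq. (8) concrete, the sketch below minimizes the MSE by plain gradient descent for a stand-in linear model \(P(W)=XW\). This is only an illustration of the optimization problem: the real model computes its gradients by backpropagation through all layers (normally with an autodiff framework), and the names `mse` and `train_linear`, the learning rate, and the step count are assumptions.

```python
import numpy as np

def mse(pred, y):
    return float(np.mean((pred - y) ** 2))

def train_linear(X, y, lr=0.1, steps=500):
    """Minimise MSE(P(W), y) for a stand-in linear model P(W) = X @ W."""
    W = np.zeros(X.shape[1])
    for _ in range(steps):
        grad = 2.0 / len(y) * X.T @ (X @ W - y)  # analytic dMSE/dW
        W -= lr * grad                            # gradient-descent step
    return W
```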

Utilize an attention mask to encode the crystal geometric structure

In 2017, DTNN35 introduced quantum-chemical insights into GNNs and encoded the distance between atoms into the node updating function. More recently, PhysNet28 and DimeNet25 refined this idea and proposed attention masks based on a Gaussian radial function and a Fourier-Bessel basis function, respectively, to encode spatial and chemical information. DimeNet25 showed that even the simplest wave functions (solutions of the Schrödinger equation for an infinitely deep spherical potential well), used as an attention mask, can significantly improve the performance of a GNN model. In this section, we introduce an attention mask \(M\) with a more precise physical meaning. It effectively encodes the geometric structure information and is given by

$$M=W_{\mathrm{R}}\{a_{\mathrm{RBF}}(|\mathbf{r}_{ij}|)\}+W_{\mathrm{P}}\{a_{\mathrm{PW}}(\mathbf{r}_{ij},\Omega,\mathbf{a},\mathbf{b},\mathbf{c})\otimes G\}$$
(9)

Plane waves are eigenfunctions of the Schrödinger equation with a constant potential, so they are the natural basis in the nearly-free-electron approximation. Local orbitals, on the other hand, can be used as a basis for a fully self-consistent solution of independent-particle equations, and analytic forms, especially Gaussians, are used extensively in chemistry23,30. Based on these ideas, the mixed basis29,30 was developed to extend traditional approaches, and it provides a convenient way to describe complex electronic systems such as transition metals29. Equation (9) takes the form of a mixed basis: the Gaussian radial basis \(\{a_{\mathrm{RBF}}\}\) and the plane waves \(\{a_{\mathrm{PW}}\}\) with gate function \(G\). \(W\) denotes learnable weight matrices. Note that in Eq. (9), two tensor product operations make the dimensions of the two basis sets match so that they can be added element-wise. We use no nonlinear function other than the gate function \(G\), and we add no bias term after the linear combination of the basis sets.

We choose the Gaussian basis set proposed by PhysNet28 as \(\{{a}_{{{{{{\rm{RBF}}}}}}}\,\}\) in Eq. (9), which is given by

$$\{a_{\mathrm{RBF}}\}=\phi(|\mathbf{r}_{ij}|)\cdot\exp\left(-\beta_n\left(\exp(-|\mathbf{r}_{ij}|)-\mu_n\right)^2\right)$$
(10)

where \(|\mathbf{r}_{ij}|\) denotes the distance between atoms, \(\phi(|\mathbf{r}_{ij}|)\) denotes a smooth cutoff function36 that ensures continuous behavior when an atom enters or leaves the cutoff sphere, and \(\beta_n\) and \(\mu_n\) denote the fixed parameters of the \(n\)th basis function.
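Eq. (10) can be sketched as below. The text does not spell out the cutoff polynomial or the values of \(\beta_n\) and \(\mu_n\), so this sketch assumes the choices we understand PhysNet to use: a quintic cutoff that falls to zero with zero slope at the cutoff, centers \(\mu_n\) evenly spaced between \(\exp(-c)\) and 1, and one shared width. The 16 basis functions and \(c=8\) Å follow Fig. 2b; treat all other constants as assumptions.

```python
import numpy as np

def cutoff_fn(r, rc=8.0):
    """Assumed smooth cutoff phi(r): 1 at r = 0, 0 with zero slope at r = rc."""
    x = np.clip(r / rc, 0.0, 1.0)
    return np.where(r < rc, 1 - 6 * x**5 + 15 * x**4 - 10 * x**3, 0.0)

def rbf(r, n_basis=16, rc=8.0):
    """Eq. (10): phi(r) * exp(-beta_n (exp(-r) - mu_n)^2), shape (..., n_basis)."""
    mu = np.linspace(np.exp(-rc), 1.0, n_basis)           # assumed centres mu_n
    beta = (2.0 / n_basis * (1.0 - np.exp(-rc))) ** -2    # assumed shared width
    r = np.asarray(r, dtype=float)[..., None]
    return cutoff_fn(r, rc) * np.exp(-beta * (np.exp(-r) - mu) ** 2)
```

Placing the Gaussians on \(\exp(-r)\) rather than on \(r\) concentrates resolution at short distances, and the cutoff factor makes every basis function vanish smoothly as a neighbor leaves the 8 Å sphere.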

\(\{a_{\mathrm{PW}}\}\) in Eq. (9) is the plane wave basis set, where \(\varOmega\) denotes the volume of the crystal cell and \(\mathbf{k}\) denotes points in reciprocal space (k-space). We employ a \(q\times q\times q\) Monkhorst-Pack k-point mesh37, a widely used method for sampling the first Brillouin zone. The definition of \(\{a_{\mathrm{PW}}\}\) and the k-point sampling method are given by

$$\begin{array}{c}\{{a}_{{{{{{\rm{PW}}}}}}}\}={{{{{\rm{real}}}}}}(\frac{1}{\sqrt{\varOmega }}{{{{{{\rm{e}}}}}}}^{{{{{{\rm{i}}}}}}{{{{{\bf{k}}}}}}\cdot {{{{{{\bf{r}}}}}}}_{ij}})\end{array}$$
(11)
$$\begin{array}{c}u=\frac{2r-q-1}{2q}(r=1,2,3,\ldots ,q)\end{array}$$
(12)
$$\begin{array}{c}{{{{{{\bf{k}}}}}}}_{prs}={u}_{p}{{{{{{\bf{b}}}}}}}_{1}+{u}_{r}{{{{{{\bf{b}}}}}}}_{2}+{u}_{s}{{{{{{\bf{b}}}}}}}_{3};{u}_{p},{u}_{r},{u}_{s}\in u\end{array}$$
(13)

where the integer \(q\) determines the number of sampling points along each reciprocal direction, and \(u\) denotes the set of \(q\) real coefficients from which \(u_p,u_r,u_s\) in Eq. (13) are drawn as weights on the reciprocal basis vectors. \(\mathbf{b}_i\) denotes the reciprocal basis vector corresponding to the crystal lattice vector \(\mathbf{a}_i\); specifically, \(\mathbf{b}_1=2\pi\frac{\mathbf{a}_2\times\mathbf{a}_3}{\varOmega}\), \(\mathbf{b}_2=2\pi\frac{\mathbf{a}_3\times\mathbf{a}_1}{\varOmega}\), and \(\mathbf{b}_3=2\pi\frac{\mathbf{a}_1\times\mathbf{a}_2}{\varOmega}\). Note that in some previous works37,38, the special points in k-space were chosen according to the specific characteristics of the given crystal system.
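Eqs. (11)–(13) translate almost line for line into NumPy. In this sketch (function names are ours, lattice rows are taken to be \(\mathbf{a}_1,\mathbf{a}_2,\mathbf{a}_3\)), the real part of \(e^{i\mathbf{k}\cdot\mathbf{r}}\) is written as \(\cos(\mathbf{k}\cdot\mathbf{r})\), and each distance vector yields one value per k-point, \(q^3\) values in total.

```python
import numpy as np

def reciprocal_vectors(lattice):
    """b_1 = 2 pi (a_2 x a_3) / Omega, and cyclically; rows are b_1, b_2, b_3."""
    a1, a2, a3 = lattice
    omega = np.dot(a1, np.cross(a2, a3))  # signed cell volume
    return 2 * np.pi * np.array([np.cross(a2, a3), np.cross(a3, a1), np.cross(a1, a2)]) / omega

def monkhorst_pack(lattice, q=4):
    """Eqs. (12)-(13): the q^3 k-points u_p b_1 + u_r b_2 + u_s b_3."""
    u = (2 * np.arange(1, q + 1) - q - 1) / (2 * q)
    grid = np.array(np.meshgrid(u, u, u, indexing="ij")).reshape(3, -1).T
    return grid @ reciprocal_vectors(lattice)

def plane_wave_basis(r, lattice, q=4):
    """Eq. (11): real(exp(i k.r) / sqrt(Omega)) = cos(k.r) / sqrt(Omega), shape (E, q^3)."""
    volume = abs(np.linalg.det(lattice))
    k = monkhorst_pack(lattice, q)
    return np.cos(np.asarray(r) @ k.T) / np.sqrt(volume)
```

A quick consistency check is \(\mathbf{a}_i\cdot\mathbf{b}_j=2\pi\delta_{ij}\), and because the coefficients \(u\) are symmetric about zero, the Monkhorst-Pack points sum to the zero vector.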

In this work, instead of manually selecting special k-points for each crystal system, we use a learnable gate \(G\) to filter \(\{a_{\mathrm{PW}}\}\) automatically. \(G\) has the same dimension \(q^3\) as \(\{a_{\mathrm{PW}}\}\), and each of its components ranges from 0 to 1. We tried two forms of \(G\), \(g(W\{SG\})\) and \(g(W\{a_{\mathrm{PW}}\})\), where \(g\) is the sigmoid function, \(\{SG\}\) denotes the one-hot vector of the current space group, and \(W\) is a learnable matrix. Our experiments showed that the latter performs better, that is, the self-adapting gate outperforms the space-group-based gate, probably because the values of the plane waves \(\{a_{\mathrm{PW}}\}\) at the k-points already reflect the spatial information of a given crystal structure. All results in this paper use \(G=g(W\{a_{\mathrm{PW}}\})\).
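Putting Eq. (9) together with the self-adapting gate gives the sketch below. Because \(G\) has the same dimension \(q^3\) as \(\{a_{\mathrm{PW}}\}\) and acts as a filter, this sketch assumes the gate is applied element-wise to the plane-wave values before the projection \(W_{\mathrm{P}}\); that reading, the function name `attention_mask`, and the weight shapes are our assumptions. As in the text, there are no bias terms and no nonlinearity other than the gate.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def attention_mask(a_rbf, a_pw, W_R, W_P, W_G):
    """Sketch of Eq. (9) with the gate G = sigmoid(W_G a_PW).

    a_rbf: (E, n_rbf) Gaussian radial basis values
    a_pw:  (E, q^3) plane-wave values
    W_R: (n_rbf, d), W_P: (q^3, d), W_G: (q^3, q^3); returns M of shape (E, d).
    """
    G = sigmoid(a_pw @ W_G)                 # one gate value in (0, 1) per k-point
    return a_rbf @ W_R + (a_pw * G) @ W_P   # assumed element-wise gating, then linear maps
```

The two linear maps bring both basis sets to the message dimension \(d\), so the sum can be multiplied element-wise into the message of Eq. (4).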

Because the mask function \(M\) is rotationally invariant, our model is invariant with respect to rotations of the input crystal. The rotational invariance of \(\{a_{\mathrm{RBF}}\}\) in \(M\) is obvious, since it depends only on distances. Figure 2(c) gives a simple example of why \(\{a_{\mathrm{PW}}\}\) is also rotationally invariant: the sampled k-points \(\mathbf{k}_i\) and the distance vector \(\mathbf{r}_{ij}\) both change under rotation, but their dot product remains the same because reciprocal space rotates together with real space.
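The argument can be checked numerically. The sketch below (a self-contained re-implementation of Eqs. (11)–(13); the helper names are ours) rotates both the lattice and the distance vectors by the same proper rotation and confirms that \(\{a_{\mathrm{PW}}\}\) is unchanged, while rotating only \(\mathbf{r}_{ij}\) is not.

```python
import numpy as np

def plane_waves(r, lattice, q=4):
    """a_PW = cos(k . r) / sqrt(Omega) on a q x q x q Monkhorst-Pack grid (Eqs. 11-13)."""
    a1, a2, a3 = lattice
    vol = np.dot(a1, np.cross(a2, a3))
    B = 2 * np.pi * np.array([np.cross(a2, a3), np.cross(a3, a1), np.cross(a1, a2)]) / vol
    u = (2 * np.arange(1, q + 1) - q - 1) / (2 * q)
    grid = np.array(np.meshgrid(u, u, u, indexing="ij")).reshape(3, -1).T
    return np.cos(np.asarray(r) @ (grid @ B).T) / np.sqrt(abs(vol))

def random_rotation(seed=0):
    """Random proper rotation (orthogonal, det = +1) from a QR decomposition."""
    Q, R = np.linalg.qr(np.random.default_rng(seed).normal(size=(3, 3)))
    Q = Q * np.sign(np.diag(R))  # make the decomposition unique
    if np.linalg.det(Q) < 0:
        Q[:, 0] = -Q[:, 0]       # turn a reflection into a rotation
    return Q
```

Rows of the lattice and of \(\mathbf{r}_{ij}\) are rotated as `x @ Rot.T`. The invariance holds only when the lattice is rotated together with the atoms, which is why the input crystals need consistently defined coordinate systems.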

Fig. 2: Several features of the attention mask function M.

a \(\{a_{\mathrm{PW}}\}\) versus \(\mathbf{r}_{ij}\) (Å). 15 plane waves \(\{a_{\mathrm{PW}}\}\) in the plane \(z=0\), determined by different k-points; \(q=4\) (the crystal structure used here is CaTiO3). b \(\{a_{\mathrm{RBF}}\}\) versus \(|\mathbf{r}_{ij}|\) (Å). The smooth cutoff function \(\phi\) (red line) with a set of radial bases \(\{a_{\mathrm{RBF}}\}\) (black lines, \(n\in[1,16]\), cutoff = 8 Å). c A simple example showing the rotational invariance of \(\{a_{\mathrm{PW}}\}\). \(\mathbf{a}_1,\mathbf{a}_2,\mathbf{a}_3\) define a crystal lattice and \(\mathbf{b}_1,\mathbf{b}_2,\mathbf{b}_3\) define the corresponding reciprocal space. The distance vector \(\mathbf{r}_{ij}\) and the sampled k-point \(\mathbf{k}_i\) change with the rotation of the whole crystal, but their dot product stays the same.

As shown in Fig. 3, our model contains three main blocks: embedding, gated convolution, and output. The embedding block generates an initial one-hot vector for each atom from its atomic number and passes the resulting embedding to the first gated convolution block. Each gated convolution block updates the node embeddings via gated graph convolution over the neighbors of each node and passes each layer’s node embeddings to the output block. The output block aggregates the node features of each crystal graph into one vector per layer and then passes these layer vectors into the multilayer perceptron to obtain the final prediction.

Fig. 3: Model framework.

a Model architecture. \(z\) denotes the atomic number of each atom; the Embedding block generates a one-hot encoding and applies a linear transformation to adjust the embedding dimension; \(\{a_{\mathrm{RBF}}\}\) and \(\{a_{\mathrm{PW}}\}\) denote the Gaussian radial basis (Eq. 10) and the plane wave basis (Eq. 11), respectively; \(t\) is the layer index. b The detailed Gated Graph Convolution block, including the aggregation and updating functions (Eqs. 4 and 5). c The detailed Output block (Eqs. 6 and 7). MLP denotes a multilayer perceptron whose last layer has dimension 1.

Model evaluation

Model performance depends on the dataset, yet many previous works ignore this when evaluating their models. To evaluate our model thoroughly, we compared it against three other models on the same datasets (training, validation, and testing sets), which contain different amounts of data from the Materials Project (MP)2 and the Open Quantum Materials Database (OQMD)1. MP and OQMD contain over 130K and 560K computed crystal structures, respectively. The targets are formation energy per atom (\(E_{\mathrm{f}}\)), band gap (\(E_{\mathrm{g}}\)), and two elastic properties: bulk modulus \(K_{\mathrm{VRH}}\) and shear modulus \(G_{\mathrm{VRH}}\). Note that our data differ slightly from those used by CGCNN because the MP database has since been updated. In addition, we trained our model on the largest dataset available to us (over 560K entries from OQMD) to obtain the best performance currently achievable.

Figure 4 shows that our model performs well over the entire range of \(E_{\mathrm{f}}\) and \(E_{\mathrm{g}}\). For reference, the MAEs of DFT calculations against experimental measurements are 0.081–0.136 eV/atom for \(E_{\mathrm{f}}\) and 0.6 eV for \(E_{\mathrm{g}}\)1,39. Table 1 summarizes the MAEs of four models (CGCNN19, MEGNet24, iCGCNN22, and this work) on different datasets; our model achieves much better performance. For example, for \(E_{\mathrm{f}}\), our model outperforms CGCNN, MEGNet, and iCGCNN by 25.6%, 14.3%, and 35.7%, respectively. For \(E_{\mathrm{g}}\), it outperforms CGCNN and MEGNet by 27.6% and 12.4%, respectively. For the elastic properties \(K_{\mathrm{VRH}}\) and \(G_{\mathrm{VRH}}\), our model reaches errors comparable to those of MEGNet. The good performance on diverse datasets indicates that our model has both high accuracy and good generalization.

Fig. 4: Comparison between DFT calculations and predictions.

a 56K predicted \(E_{\mathrm{f}}\) values versus the corresponding DFT-calculated values; the model was trained on 449K DFT entries from the OQMD database. b 3K predicted \(E_{\mathrm{g}}\) values versus the corresponding DFT-calculated values; the model was trained on 24K DFT entries from the OQMD database.

Table 1 Comparison with other approaches recently reported in the literature.

We note that the size of the training set strongly affects performance. To better understand this relationship, we ran experiments with training sets of different sizes, using the same test set of 9274 DFT entries from the Materials Project. Figure 5(a) shows that the accuracy improves significantly as the amount of training data increases. Our model reaches an accuracy comparable to that of DFT calculations against experiment when the training set reaches ~\(10^3\) entries.

Fig. 5: Experiments on 9274 DFT entries from the Materials Project.

a MAE of \(E_{\mathrm{f}}\) versus training set size. b MAE of \(E_{\mathrm{f}}\) for six different aggregation functions.

To assess the contribution of each component of our model, we performed ablation experiments on 27,824 randomly selected training samples from the MP database. To test the validity of the attention mask \(M\), we ran the model with a set of different aggregation functions (Eq. 4). First, we evaluated the model without \(M\), in which case the aggregation function becomes \(f_{\mathrm{agg}1}\) in Eq. (14). Compared with our full model, the MAE increases by 53%, which shows that the attention mask \(M\) is the main contributor to our model's accuracy. To further validate the two parts of \(M\), we ran additional experiments on the same dataset, comparing the aggregation functions \(f_{\mathrm{agg}2}\) in Eq. (15) and \(f_{\mathrm{agg}3}\) in Eq. (16) to show that both the Gaussian radial basis and the plane waves are valuable. Figure 5(b) shows that either basis alone (Gaussian radial basis or plane waves) offers limited improvement, but combined they significantly improve the model’s performance. We also evaluated alternative implementations of \(\tilde{v}_{ij}\) and \(G\). Starting from our full model, we replaced \(\tilde{v}_{ij}\) with \(\tilde{v}_{ij}'=v_j\oplus v_i\) (Eq. 17) and, separately, \(G\) with \(G'=g(W\{SG\})\) (Eq. 18), where \(\{SG\}\) denotes the one-hot vector of the space group.

$$f_{\mathrm{agg}1}=\sum_{j\in N_i}\sigma(W_{\sigma}^t\tilde{v}_{ij})\odot g(W_{\mathrm{g}}^t\tilde{v}_{ij})$$
(14)
$$f_{\mathrm{agg}2}=\sum_{j\in N_i}\sigma(W_{\sigma}^t\tilde{v}_{ij})\odot g(W_{\mathrm{g}}^t\tilde{v}_{ij})\odot W_{\mathrm{R}}\{a_{\mathrm{RBF}}\}$$
(15)
$$f_{\mathrm{agg}3}=\sum_{j\in N_i}\sigma(W_{\sigma}^t\tilde{v}_{ij})\odot g(W_{\mathrm{g}}^t\tilde{v}_{ij})\odot W_{\mathrm{P}}\{a_{\mathrm{PW}}\otimes G\}$$
(16)
$$f_{\mathrm{agg}4}=\sum_{j\in N_i}\sigma(W_{\sigma}^t\tilde{v}_{ij}')\odot g(W_{\mathrm{g}}^t\tilde{v}_{ij}')\odot M,\quad\tilde{v}_{ij}'=v_j\oplus v_i$$
(17)
$$f_{\mathrm{agg}5}=\sum_{j\in N_i}\sigma(W_{\sigma}^t\tilde{v}_{ij})\odot g(W_{\mathrm{g}}^t\tilde{v}_{ij})\odot M(G'),\quad G'=g(W\{SG\})$$
(18)
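For readers reproducing the mask ablations, the variants in Eqs. (14)–(16) and the full mask of Eq. (9) can be switched with a single flag. The helper `mask_variant` below is hypothetical, and it assumes element-wise gating of the plane waves as in the rest of this sketch series; returning an all-ones mask for `"none"` lets the same layer code reproduce \(f_{\mathrm{agg}1}\). The variants of Eqs. (17) and (18) change \(\tilde{v}_{ij}\) and \(G\) rather than the mask structure, so they are not covered here.

```python
import numpy as np

def mask_variant(kind, a_rbf, a_pw, W_R, W_P, G):
    """Multiplicative mask for the ablations of Eqs. (14)-(16) and the full model.

    'none' -> f_agg1 (no mask), 'rbf' -> f_agg2, 'pw' -> f_agg3, 'full' -> Eq. (9).
    G holds gate values with the same shape as a_pw (element-wise gating assumed).
    """
    n_edges, d = a_rbf.shape[0], W_R.shape[1]
    if kind == "none":
        return np.ones((n_edges, d))
    if kind == "rbf":
        return a_rbf @ W_R
    if kind == "pw":
        return (a_pw * G) @ W_P
    if kind == "full":
        return a_rbf @ W_R + (a_pw * G) @ W_P
    raise ValueError(f"unknown mask variant: {kind}")
```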

Self-learned supercell invariance

In theory, static properties such as \(E_{\mathrm{g}}\) and \(E_{\mathrm{f}}\) do not change under a supercell transformation (representing multiple primitive cells as one larger supercell), so a machine learning model should satisfy this property. The output of our model is based mainly on the relationships between atoms and the local structures of the input crystal. As Eq. (11) shows, however, the plane wave mask function imports information about the crystal cell, which changes under a supercell transformation, so our model has no built-in a priori supercell invariance.

This does not mean, however, that our model cannot capture supercell invariance by learning from crystal structure data at multiple scales. MP and OQMD contain many supercell entries, and the good test results on these databases suggest that our model can learn supercell invariance. From a machine learning point of view, our model can learn similar representations of the mask function (Eq. 9) for similar crystal structures, where the notion of “similarity” is determined automatically by the model according to the prediction target. For properties such as \(E_{\mathrm{g}}\) or \(E_{\mathrm{f}}\), the model can learn this kind of “similarity” as well; in other words, supercell transformation invariance is one kind of similarity that our model learns from large amounts of training data at different scales. Specifically, \(W_{\mathrm{P}}\) and \(G\) in Eq. (9) enable our model to learn the similarities among crystals: for different supercells of one primitive cell, \(a_{\mathrm{PW}}\) differs in each dimension, but the values of \(M\) should be very close.
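The lack of built-in supercell invariance is easy to see numerically. The sketch below (self-contained; the 4 Å cubic cell and the neighbor vector are illustrative choices, not data from this work) evaluates \(\{a_{\mathrm{PW}}\}\) for the same physical distance vector in a primitive cell and in a 2 × 1 × 1 supercell: the cell volume doubles and the k-grid along \(\mathbf{b}_1\) shrinks, so the \(q^3\) values differ, and it is left to \(W_{\mathrm{P}}\) and \(G\) to map them to similar masks.

```python
import numpy as np

def plane_waves(r, lattice, q=4):
    """a_PW = cos(k . r) / sqrt(Omega) on a q x q x q Monkhorst-Pack grid (Eqs. 11-13)."""
    a1, a2, a3 = lattice
    vol = np.dot(a1, np.cross(a2, a3))
    B = 2 * np.pi * np.array([np.cross(a2, a3), np.cross(a3, a1), np.cross(a1, a2)]) / vol
    u = (2 * np.arange(1, q + 1) - q - 1) / (2 * q)
    grid = np.array(np.meshgrid(u, u, u, indexing="ij")).reshape(3, -1).T
    return np.cos(np.asarray(r) @ (grid @ B).T) / np.sqrt(abs(vol))

prim = 4.0 * np.eye(3)            # illustrative cubic primitive cell
sup = np.diag([8.0, 4.0, 4.0])    # 2 x 1 x 1 supercell of the same crystal
r = np.array([[4.0, 0.0, 0.0]])   # the same physical neighbour vector in both cells
a_prim, a_sup = plane_waves(r, prim), plane_waves(r, sup)
```

Even for \(\mathbf{r}=0\), the normalization \(1/\sqrt{\Omega}\) alone shrinks every value by a factor of \(\sqrt{2}\) in the supercell.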

To demonstrate supercell transformation invariance more clearly, we selected three crystals, perovskite CaTiO3, rutile TiO2, and perovskite SrTiO3, each in eight different supercell configurations. Figure 6 illustrates three supercell configurations of TiO2. Here \(\text{mp-XXX}_{ijk}\) means that the lengths of the first, second, and third basis vectors are multiplied by \(i\), \(j\), and \(k\), respectively, and mp-XXX is the material ID in the Materials Project database. Taking formation energy as an example, Table 2 shows the DFT calculations (target values) and our model's predictions for the three crystal structures in eight different supercell configurations. The prediction errors for the different supercells are comparable, which demonstrates supercell transformation invariance. The model used here was trained on 69K DFT formation energies that do not include the three selected crystals.

Fig. 6: Supercell transformation invariance.

a Crystal cell of TiO2 (\(\text{mp-2657}_{111}\)). b Crystal supercell of TiO2 (\(\text{mp-2657}_{112}\)). c Crystal supercell of TiO2 (\(\text{mp-2657}_{222}\)).

Table 2 Predictions (\({E}_{{{{{{\rm{f}}}}}}}\)) of eight different supercells for three crystal structures.

Discussion

Deep neural networks, including GNNs, nearly always outperform traditional machine learning, but their main drawback is poor interpretability, owing to their many nonlinear transformation layers and large number of neurons. Like all other prediction models, our model is evaluated by its prediction error. Even when prediction accuracy is high, the features a model learns may have nothing to do with physical properties and may merely fit a limited dataset. For some crystalline materials, changing just one atom can make a big difference in their physical properties. Experts can understand such differences, but they are difficult for a machine learning model to recognize, because the input is just a set of one-hot vectors with similar distances between atoms. If our model can still predict properties accurately when the input changes only slightly, we can be more confident that it has learned features with physical meaning. In this section, we present a case study of band gap predictions for perovskite materials to show the plausibility of the model.

The data used in this section are a group of halide perovskites with the formula \(\mathrm{AMX}_3\) (A = \(\mathrm{CH}_3\mathrm{NH}_3\), Cs; M = Pb, Sn; X = I, Br, Cl) in the cubic or orthorhombic system. As Fig. 7 shows, \(\mathrm{AMX}_3\) compounds in the same crystal system have similar structures, especially in the relative positions of their atoms, which makes satisfactory prediction a big challenge for a machine learning model. Specifically, we selected five different perovskites, three cubic and two orthorhombic, to evaluate the sensitivity of our model to slight differences in atom types and crystal structures. The prediction target is the band gap, an important parameter for solar cell materials.

Fig. 7: Crystal structures of halide perovskites in the cubic system.

a Crystal structure of \({\rm CsPbI}_3\). b Crystal structure of \({\rm CsSnI}_3\).

The model used in this section was trained on 69,000 DFT-calculated band gaps \(E_{\rm g}\) from the open Materials Project database. As shown in Table 3, the errors between the model predictions and the experimental results are comparable with those between DFT and experiment. In detail, the predicted band gaps of \({\rm CH}_3{\rm NH}_3{\rm PbCl}_3\), \({\rm CH}_3{\rm NH}_3{\rm PbBr}_3\) and \({\rm CH}_3{\rm NH}_3{\rm PbI}_3\) decrease in turn because of the different intrinsic properties of Cl, Br and I, consistent with both the DFT calculations and the experimental results. From \({\rm CsPbI}_3\) to \({\rm CsSnI}_3\), the predicted band gap decreases sharply from 2.180 eV to 1.381 eV, which also matches the DFT calculations and experimental results. Note that the DFT calculations in Table 3 were conducted with the recent approximate quasiparticle DFT-1/2 method40. We expect the model's performance to improve as more open databases become available.

Table 3 Band gap predictions vs. experimental results vs. DFT calculations (eV). \(\alpha\) and \(\gamma\) denote the cubic and orthorhombic systems, respectively.

Conclusion

In summary, GeoCGNN is a machine learning approach for crystal property prediction. Using an attention mask that combines a Gaussian basis and a plane wave basis, we demonstrate that encoding spatial geometrical information in a proper way can significantly improve prediction accuracy. With a training set of order \(10^4\) samples, the model's predictions can be very close to DFT-calculated results, which indicates that our model could serve as a potential substitute for DFT calculations, especially in high-throughput screening scenarios. Finally, introducing the distance vector between atoms helps our model learn complete geometrical information while preserving rotational invariance. Hence, our model can be used and transferred across databases whose crystal coordinate systems are defined in a unified way.

Methods

Graph Neural Networks

This section demonstrates how a GNN learns structural information during each step of forward propagation. Figure 8 uses different colors to represent the information of different nodes; after one step of node updating (from layer \(t\) to layer \(t+1\)), each node has gathered the information of its neighbors. Equation (19) and the right part of Fig. 8 give the details of each step. Equation (20) shows the simplest output function, where \(P\) is the target. The output function usually contains several pooling layers, which compress all nodes and edges into a single vector, and a final fully connected layer that transforms this vector into a real number \(P\). The graph \(G\) is defined by Eq. (21), where \(v\), \(u\) and \(A\) denote the nodes, the edges and the adjacency matrix, respectively. Nodes at different positions in this topological graph therefore gather different information from one another, so each node in the \(t\)-th layer encodes the \(t\)-order topological information around itself. The nodes can encode geometric information in the same way, simply by adding it to the edge representation \(u\).

$$v_{i}^{t}=f_{\rm update}\left(v_{i}^{t-1},\,f_{\rm aggregate}\left(\{(v_{j}^{t-1},u_{ij}^{t-1})\mid j\in N_{i}\}\right)\right)$$
(19)
$$P=f_{\rm output}\left(\{v_{i}^{t}\},\{u_{ij}^{t}\}\mid i,j\in G\right)$$
(20)
$$G=(v,u,A)$$
(21)
Fig. 8: A demonstration of node aggregation and updating in a GNN.

a The full picture of one step of forward propagation in a GNN. b The detailed process for two nodes. \(v_{i}^{t}\) denotes the \(i\)-th node in the \(t\)-th layer, \(u_{ij}^{t}\) denotes the edge between nodes \(i\) and \(j\) in the \(t\)-th layer (in this work, \(u_{ij}^{t}\) does not vary with \(t\)), \(W\) is a learnable parameter matrix, and \(f\) and \(g\) are nonlinear functions.
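A minimal NumPy sketch of Eqs. (19)-(21) can make the "\(t\)-order" claim concrete. The specific functions are assumptions for illustration, not GeoCGNN's actual layers: messages are `tanh(W_msg @ [v_j; u_ij])`, \(f_{\rm aggregate}\) is a sum over neighbours taken from \(A\), \(f_{\rm update}\) is another `tanh` layer, and \(f_{\rm output}\) is mean pooling followed by a linear layer. On a 4-node path graph, perturbing node 3 changes node 0's representation only after three layers.

```python
import numpy as np

def gnn_layer(v, u, A, W_msg, W_upd):
    """One node-update step in the spirit of Eq. (19); the functional forms are assumptions."""
    n = v.shape[0]
    m = np.zeros((n, W_msg.shape[0]))
    for i in range(n):
        for j in np.flatnonzero(A[i]):  # neighbours N_i read from the adjacency matrix
            m[i] += np.tanh(W_msg @ np.concatenate([v[j], u[i, j]]))  # sum aggregation
    return np.tanh(np.concatenate([v, m], axis=1) @ W_upd.T)  # f_update

def forward(v, u, A, W_msg, W_upd, n_layers):
    for _ in range(n_layers):
        v = gnn_layer(v, u, A, W_msg, W_upd)  # u_ij is kept fixed across layers
    return v

def readout(v, w_out):
    """Simplest f_output (Eq. (20)): mean pooling over nodes, then a linear layer -> scalar P."""
    return float(v.mean(axis=0) @ w_out)

rng = np.random.default_rng(0)
n, d_v, d_e, d_m = 4, 8, 4, 8
A = np.zeros((n, n), dtype=bool)  # path graph 0-1-2-3
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = True
v0 = rng.normal(size=(n, d_v))
u = rng.normal(size=(n, n, d_e))
W_msg = 0.3 * rng.normal(size=(d_m, d_v + d_e))
W_upd = 0.3 * rng.normal(size=(d_v, d_v + d_m))
w_out = rng.normal(size=d_v)

v0_perturbed = v0.copy()
v0_perturbed[3] += 1.0  # change only node 3, three hops away from node 0
node0_changed = {}
for t in (1, 2, 3):
    a = forward(v0, u, A, W_msg, W_upd, t)[0]
    b = forward(v0_perturbed, u, A, W_msg, W_upd, t)[0]
    node0_changed[t] = not np.allclose(a, b)
print(node0_changed)  # node 0 is affected only once t >= 3
print(readout(forward(v0, u, A, W_msg, W_upd, 3), w_out))  # scalar prediction P
```

Geometric information enters only through `u`, so swapping in distance- or vector-based edge features changes what is propagated but not how far it travels per layer.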

Hyperparameters

In this part, we discuss the main hyperparameters of our model. The results in this part are based on the \(E_{\rm f}\) data in the MP dataset, split into training/validation/test sets of 60K/4.6K/4.6K. All models presented in this work were trained for 300 epochs with the Adam optimizer49. Figure 9 shows that 300 epochs are enough for our model to converge.
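The text does not say how the split was drawn; the sketch below assumes a uniform random permutation, and the exact counts (60,000/4,600/4,600) are assumptions rounded from "60K/4.6K/4.6K". `split_indices` is a hypothetical helper.

```python
import numpy as np

def split_indices(n_total, n_val, n_test, seed=0):
    """Hypothetical helper: randomly partition sample indices into train/val/test."""
    perm = np.random.default_rng(seed).permutation(n_total)
    test = perm[:n_test]
    val = perm[n_test:n_test + n_val]
    train = perm[n_test + n_val:]
    return train, val, test

# Sizes assumed from the rounded "60K/4.6K/4.6K" split
train, val, test = split_indices(60_000 + 4_600 + 4_600, 4_600, 4_600)
print(len(train), len(val), len(test))  # 60000 4600 4600
```

Fixing the seed keeps the three subsets identical across hyperparameter runs, so differences in MAE reflect the hyperparameters rather than the split.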

Fig. 9: Learning history on the full OQMD database.

a The red line shows the training-set MAE during learning on 449,506 \(E_{\rm f}\) training samples; the blue line shows the validation-set MAE on 56,188 \(E_{\rm f}\) validation samples. b The learning history for 24,296 \(E_{\rm g}\) training samples and 3,036 \(E_{\rm g}\) validation samples.

Learning rate

The learning rate is one of the most important hyperparameters. We tried 1e−4, 5e−4, 1e−3 and 5e−3 and found that the model is sensitive to the learning rate: its performance varies by up to 23% across these values. The best performance was obtained at 1e−3. In addition, we reduced the learning rate to 1e−4 for the last 50 epochs to obtain tighter convergence.
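The schedule described above is a single step decay. A minimal sketch, assuming 0-based epoch indices and the 300-epoch budget stated earlier (the function name is hypothetical):

```python
def learning_rate(epoch: int, n_epochs: int = 300, base_lr: float = 1e-3,
                  final_lr: float = 1e-4, n_final: int = 50) -> float:
    """Step schedule: base_lr for the first n_epochs - n_final epochs, final_lr afterwards.

    With 0-based epochs, n_epochs=300 and n_final=50, epochs 250..299 use final_lr.
    """
    return final_lr if epoch >= n_epochs - n_final else base_lr

schedule = [learning_rate(e) for e in range(300)]
print(schedule[249], schedule[250])  # 0.001 0.0001
```

In a PyTorch training loop, the same step could be expressed with `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1)`, stepped once per epoch; whether the authors used this scheduler is not stated.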

Batch size

Experiments show that our model is not sensitive to batch size. We tried 128, 256, 300, 396 and 512, and the performance varied by at most 5%. The best result was obtained with a batch size of 300.
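For crystal graphs, the batch size counts crystals, and crystals have different numbers of atoms. A common approach in GNN libraries, assumed here rather than taken from this work, is to merge a batch into one disconnected graph: concatenate node features, shift each graph's edge indices by the number of atoms before it, and keep a `batch` vector so that pooling in \(f_{\rm output}\) stays per crystal. `collate_graphs` and `mean_pool` are hypothetical helpers.

```python
import numpy as np

def collate_graphs(graphs):
    """Merge (node_features, edge_index) pairs into one disjoint graph.

    node_features: (n_atoms, d) array; edge_index: (2, n_edges) int array of local
    (source, target) atom indices. Returns features, offset edges and a batch vector.
    """
    feats, edges, batch = [], [], []
    offset = 0
    for g, (x, e) in enumerate(graphs):
        feats.append(x)
        edges.append(e + offset)  # shift local atom indices into the merged graph
        batch.append(np.full(len(x), g))
        offset += len(x)
    return np.concatenate(feats), np.concatenate(edges, axis=1), np.concatenate(batch)

def mean_pool(x, batch, n_graphs):
    """Average node features per graph, giving one vector per crystal."""
    sums = np.zeros((n_graphs, x.shape[1]))
    np.add.at(sums, batch, x)  # unbuffered scatter-add of node rows into their graph
    counts = np.bincount(batch, minlength=n_graphs)[:, None]
    return sums / counts

# Two toy "crystals" with 2 and 3 atoms
g1 = (np.ones((2, 4)), np.array([[0, 1], [1, 0]]))
g2 = (np.full((3, 4), 3.0), np.array([[0, 1, 2], [1, 2, 0]]))
x, edge_index, batch = collate_graphs([g1, g2])
pooled = mean_pool(x, batch, 2)
print(edge_index)    # sources [0 1 2 3 4], targets [1 0 3 4 2]
print(pooled[:, 0])  # [1. 3.]
```

Because the merged graph has no edges between crystals, message passing over it is equivalent to processing each crystal separately.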

The number of convolution layers \(\boldsymbol{N}_{\mathbf{layer}}\)

Because of oversmoothing, in which node representations become increasingly similar as more layers are stacked, we only tried \(N_{\rm layer}\) from 2 to 7. The model's performance is stable for \(N_{\rm layer}\ge 3\), and the best performance was obtained with \(N_{\rm layer}=5\).
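Oversmoothing can be seen in a deliberately simplified setting. The sketch below is not GeoCGNN's convolution: it assumes a linear layer that replaces each node's feature with the mean over itself and its neighbours, on a hypothetical 6-node ring. The spread between node features shrinks with every layer, which illustrates why stacking many layers can wash out the distinctions the model needs.

```python
import numpy as np

def mean_aggregate(x, A):
    """One linear smoothing step: each node takes the mean of itself and its neighbours."""
    A_hat = A + np.eye(len(A))  # add self-loops
    return (A_hat / A_hat.sum(axis=1, keepdims=True)) @ x

n = 6
A = np.zeros((n, n))  # 6-node ring
for i in range(n):
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1.0
x = np.arange(n, dtype=float)[:, None]  # one scalar feature per node

spread = {}
for n_layers in (1, 3, 7, 30):
    h = x
    for _ in range(n_layers):
        h = mean_aggregate(h, A)
    spread[n_layers] = float(h.std())  # how distinguishable the nodes still are
print(spread)  # shrinks towards 0 as layers are stacked
```

Real GNN layers add learnable weights, nonlinearities and edge features, which slow this collapse but do not remove the tendency.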

Dimension of node (\(\mathrm{Dim}_{\rm node}\)) and graph (\(\mathrm{Dim}_{\rm graph}\))

We tried 64, 128, 192 and 256 for both \(\mathrm{Dim}_{\rm node}\) and \(\mathrm{Dim}_{\rm graph}\) and found that \(\mathrm{Dim}_{\rm node}=192\) and \(\mathrm{Dim}_{\rm graph}=192\) give the best performance. The performance shows no obvious change once both dimensions exceed 128.

The number of sampled k-points \(\boldsymbol{N}_{\mathbf{k}}\) and the number of Gaussian radial basis functions \(\boldsymbol{N}_{\mathbf{Gaussian}}\)

We tried 27, 64 and 125 for \(N_{\rm k}\) and 32, 64 and 128 for \(N_{\rm Gaussian}\), and found that \(N_{\rm k}=64\) and \(N_{\rm Gaussian}=64\) give the best performance.
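The tested \(N_{\rm k}\) values are \(3^3\), \(4^3\) and \(5^3\), which suggests uniform cubic grids of k-points. The sketch below shows one way to build both bases, under several stated assumptions that this section does not confirm: Gaussian centres evenly spaced up to a hypothetical `r_max` of 8 Å with width equal to the spacing; k-points on a uniform fractional grid of the crystal's own reciprocal lattice; and a real-valued `cos(k·r)` standing in for the plane wave. With k-points tied to the reciprocal lattice, rotating the lattice and the distance vectors together leaves the plane-wave features unchanged.

```python
import numpy as np

def gaussian_basis(d, n_gaussian=64, r_max=8.0):
    """Expand distances d (shape (m,)) on n_gaussian evenly spaced Gaussian centres (assumed form)."""
    centres = np.linspace(0.0, r_max, n_gaussian)
    width = centres[1] - centres[0]
    return np.exp(-((d[:, None] - centres[None, :]) ** 2) / width ** 2)

def kpoint_grid(n_per_axis):
    """Uniform n x n x n grid of fractional k-points in [0, 1)^3 (27, 64, 125 for n = 3, 4, 5)."""
    f = np.arange(n_per_axis) / n_per_axis
    return np.stack(np.meshgrid(f, f, f, indexing="ij"), axis=-1).reshape(-1, 3)

def plane_wave_basis(r_vec, lattice, k_frac):
    """cos(k . r) for Cartesian distance vectors r_vec (m, 3) and fractional k-points.

    Reciprocal vectors (rows of B) satisfy a_i . b_j = 2*pi*delta_ij, i.e. B = 2*pi*inv(L).T.
    """
    B = 2 * np.pi * np.linalg.inv(lattice).T
    k_cart = k_frac @ B
    return np.cos(r_vec @ k_cart.T)

rng = np.random.default_rng(0)
lattice = np.diag([5.0, 5.0, 5.0])  # hypothetical cubic cell, lattice vectors as rows (Å)
r_vec = rng.normal(size=(10, 3))    # hypothetical interatomic distance vectors
k_frac = kpoint_grid(4)             # N_k = 64

g = gaussian_basis(np.linalg.norm(r_vec, axis=1))  # (10, N_Gaussian)
pw = plane_wave_basis(r_vec, lattice, k_frac)      # (10, N_k)

# Rotate the whole crystal: lattice vectors and distance vectors together
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta), np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])
pw_rot = plane_wave_basis(r_vec @ R.T, lattice @ R.T, k_frac)
print(g.shape, pw.shape, np.allclose(pw, pw_rot))  # (10, 64) (10, 64) True
```

The invariance holds because rotating the lattice rotates the reciprocal vectors by the same matrix, so each dot product \(k\cdot r\) is unchanged. If k-points were instead fixed in a global Cartesian frame, the features would depend on how each crystal is oriented in its database.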