Introduction

In recent times several machine learning techniques1,2,3,4,5,6,7,8 have been proposed to enable fast and accurate prediction of different properties of crystalline materials, thus facilitating the rapid screening of large material search spaces9,10,11. The existing techniques either use handcrafted feature-based descriptors1,2,3,4,5 or deep graph neural networks (GNNs)6,7,8,12,13,14,15,16,17 to generate a representation from the 3D conformation of crystal structures. Generating handcrafted features requires specific domain knowledge and human intervention, which makes such methods inherently restricted. Deep learning methods, on the other hand, do not depend on careful feature curation and can automatically learn the structure-property relationships of materials, which makes them an attractive alternative.

Graph neural network-based approaches have recently gained popularity for their ability to encode graph information in an enriched representation space. Orbital-based GNNs16,17 use symmetry-adapted atomic orbital features to predict different molecular properties. Though orbital-based GNNs predict molecular properties well, they are not well suited for capturing complicated periodic structures such as crystals, since they describe the nature of the electron distribution particularly close to atoms. On the other hand, motif-centric GNNs14,15 convert motif sub-structures of a crystal into nodes and encode their interconnections for a large set of crystalline compounds using an unsupervised learning algorithm. Though they show improvements on property prediction tasks for metal oxides, their applicability is restricted because they ignore the atomic configuration inside the motif sub-structure, which is also very important.

In a different direction, CGCNN6 and MT-CGCNN7 build a convolutional neural network directly on a 2D crystal graph derived from the 3D crystal structure. GATGNN12 incorporates the idea of a graph attention network on crystal graphs to learn the importance of the different bonds between the atoms, whereas MEGNet8 introduces global state attributes for quantitative structure-state-property relationship prediction in materials. As this class of methods aims to capture the information of any crystal graph just from its connectivity and atomic features, we contribute in this promising direction.

Like any large deep neural network-based model, GNN-based architectures also introduce a large number of trainable parameters. Hence, to estimate these parameters correctly for better accuracy, a huge amount of tagged training data is required, which is not always available for all crystal properties. Developing a deep learning-based model that can be trained on a small amount of tagged data would therefore be extremely useful for inferring varied properties of crystal materials. Also, as the available experimental data for the various properties are limited and less diverse18,19,20, these models are trained using data gathered from DFT calculations21,22,23. As DFT data often differ from the experimental ground truth, owing to its inefficiency in describing the many-body ground state, especially for properties such as the band gap24, or its treatment of van der Waals interactions25, training with DFT data alone may incorporate the inaccuracies of DFT into the prediction. Moreover, in most cases, the existing property predictors are trained to predict a specific property, so the generated descriptors or embeddings of any crystal are specific to a given property. This prevents them from sharing common structural information relevant to multiple properties. Though a multi-task learning setup achieves information sharing across properties7, it works well only for properties that are correlated with each other. Last but not least, the existing neural network-based methods6,7,8,9,10,12,13,14,15,16,17,26,27 hardly provide any explanation for their results. The lack of interpretability and algorithmic transparency allows little use of them in the field of materials science. Therefore, it is necessary to explore and provide the reasons behind a prediction for any given property.

In this paper, we propose an explainable deep property predictor, CrysXPP. It is built upon CrysAE, an auto-encoder-based architecture that is trained with a large amount of easily available crystal data, namely the property-agnostic structural information of the corresponding crystal graphs. This leads to the deep encoding module capturing all the important structural and basic chemical information of the constituent atoms (nodes) of the crystal graph. The learned information is leveraged to build the property predictor, CrysXPP, where the knowledgeable encoder helps to produce a high-quality representation of a candidate crystal. Consequently, the property predictor provides superior performance (better than all the competing baselines) even when trained with a small amount of property-tagged data, thus largely mitigating the need for a huge dataset tagged with a specific property. The structural information learned in the encoding module of the auto-encoder is robust and can remove the error bias introduced by DFT by fine-tuning the system with a small amount of experimental data, whenever available. Further, we introduce a feature selector that helps to provide an explanation by highlighting the subset of the atomic features responsible for the manifestation of a chemical property of the given crystal.

Through extensive analysis of an inorganic crystal dataset across seven properties, we show that our method achieves the lowest error compared to the alternative baselines; the improvement is particularly significant when only a small amount of tagged data is available for training. We further show that CrysXPP is effective in removing the error bias due to DFT-tagged data by incorporating a small amount of experimental data in the training set, for both formation energy and band gap. Finally, with appropriate case studies, we show that the feature selection module can effectively provide explanations of the importance of different features towards prediction, which are in sync with domain knowledge.

Results and discussions

Model architecture

In this section, we discuss the key technical contributions towards this goal in more detail, followed by the training process and implementation details.

Overview

We propose the Crystal eXplainable Property Predictor (CrysXPP), which realizes a crystalline material as a graph structure (say \({{{\mathcal{G}}}}\)) and predicts the value of a property (e.g., formation energy) given the crystal graph structure. As depicted in Fig. 1, CrysXPP comprises two building blocks: (a) a property prediction module and (b) a graph embedding module. In the graph embedding module we have a crystal graph encoder based on a graph convolutional neural network (GCNN)6, which takes a crystal graph structure along with node and edge feature information as input and returns an embedding corresponding to each node as output. We consider nine different atomic properties (Table 1) as node features, and the weights of those node features are determined by a feature selector layer. Moreover, since the graph embedding module needs to capture the structural and chemical properties of the underlying crystal, one can use the huge amount of available crystal information (irrespective of the property) to train the graph convolutional network. For this, we first separately train the GCNN as a part (the encoder) of CrysAE (Fig. 2), and the weights thereby obtained are used as the initialization of the GCNN of CrysXPP. The structural information learned in the encoding module of CrysAE and duly transferred to the GCNN of CrysXPP makes CrysXPP more robust.

Fig. 1: Architecture of Crystal eXplainable Property Predictor (CrysXPP). It comprises two building blocks.

Graph embedding module and property prediction module. Given the graph structure and node feature information, the graph embedding module produces an embedding corresponding to each graph. The property predictor is a deep regressor module, which takes the graph embedding as input and predicts the property value.

Table 1 Description of different atomic properties used as node features and their dimensions.
Fig. 2: Architecture of the Crystal Auto Encoder (CrysAE) module.

It comprises a multilayer graph convolution network as the encoder and a corresponding decoder module for reconstructing different local and global features.

Our overall model architecture is essentially composed of the following two modules:

  • Auto encoder (CrysAE):\({q}_{{{{\boldsymbol{\theta }}}}}:({{{\boldsymbol{{{{\mathcal{V}}}}}}}},{{{\boldsymbol{{{{\mathcal{E}}}}}}}},{{{\boldsymbol{{{{\mathcal{X}}}}}}}},{{{\boldsymbol{{{{\mathcal{F}}}}}}}})\to {{{\boldsymbol{{{{\mathcal{Z}}}}}}}}\); \({p}_{{{{\boldsymbol{\phi }}}}}:{{{\boldsymbol{{{{\mathcal{Z}}}}}}}}\to ({{{\boldsymbol{{{{\mathcal{V}}}}}}}},{{{\boldsymbol{{{{\mathcal{E}}}}}}}},{{{\boldsymbol{{{{\mathcal{X}}}}}}}},{{{\boldsymbol{{{{\mathcal{F}}}}}}}})\)

  • Property predictor (CrysXPP):\({p}_{{{{\boldsymbol{\zeta }}}},{{{{\boldsymbol{\theta }}}}}^{\prime},{{{\boldsymbol{\psi }}}}}:{{{\boldsymbol{{{{\mathcal{X}}}}}}}}{\to }_{{{{\boldsymbol{\zeta }}}}}{{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}^{\prime};({{{\boldsymbol{{{{\mathcal{V}}}}}}}},{{{\boldsymbol{{{{\mathcal{E}}}}}}}},{{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}^{\prime},{{{\boldsymbol{{{{\mathcal{F}}}}}}}}){\to }_{{{{{\boldsymbol{\theta }}}}}^{\prime}}{{{\boldsymbol{{{{\mathcal{Z}}}}}}}};{{{\boldsymbol{{{{\mathcal{Z}}}}}}}}{\to }_{{{{\boldsymbol{\psi }}}}}{{{\mathcal{P}}}}\)

In the above characterization, \({{{\boldsymbol{\theta }}}},{{{\boldsymbol{\phi }}}},{{{\boldsymbol{\zeta }}}},{{{{\boldsymbol{\theta }}}}}^{\prime}\), and ψ are the trainable parameters of the respective modules. Here θ and ϕ are the parameters of the encoder and decoder of CrysAE, respectively. ζ is the trainable parameter of the feature selector \({{{\mathcal{S}}}}\), \({{{{\boldsymbol{\theta }}}}}^{\prime}\) is the parameter of the encoder, and ψ is the parameter of the multilayer perceptron of the CrysXPP model. We initialize \({{{{\boldsymbol{\theta }}}}}^{\prime}:= {{{\boldsymbol{\theta }}}}\), i.e., we first train the autoencoder and then transfer the parameters of the encoder of CrysAE to CrysXPP.
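The parameter transfer \({{{{\boldsymbol{\theta }}}}}^{\prime}:= {{{\boldsymbol{\theta }}}}\) amounts to copying the trained encoder weights of CrysAE into the encoder of CrysXPP before fine-tuning. A minimal PyTorch sketch of this step is shown below; the toy encoder architecture is a placeholder, not the graph encoder defined later.

```python
import torch.nn as nn

# Toy illustration of the transfer theta' := theta between two encoders with identical
# architecture; the layer sizes here are placeholders, not the paper's graph encoder.
def make_encoder(dim=64):
    return nn.Sequential(nn.Linear(dim, dim), nn.Softplus(), nn.Linear(dim, dim))

ae_encoder = make_encoder()     # encoder q_theta, trained inside CrysAE
xpp_encoder = make_encoder()    # encoder of CrysXPP (parameters theta')

# Initialize theta' := theta; xpp_encoder is subsequently fine-tuned with property-tagged data.
xpp_encoder.load_state_dict(ae_encoder.state_dict())
```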

Crystal representation

Our model realizes crystalline materials as crystal graph structures \({{{\boldsymbol{{{{\mathcal{D}}}}}}}}=\{{{{{\mathcal{G}}}}}_{i}=({{{{\boldsymbol{{{{\mathcal{V}}}}}}}}}_{i},{{{{\boldsymbol{{{{\mathcal{E}}}}}}}}}_{i},{{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}_{i},{{{{\boldsymbol{{{{\mathcal{F}}}}}}}}}_{i})\}\) as proposed in6. Crystals have a repeating structure, as depicted in Fig. 2, where a unit cell gets repeated across all three dimensions. Hence, unlike a simple graph, \({{{{\mathcal{G}}}}}_{i}\) is an undirected weighted multi-graph where \({{{{\boldsymbol{{{{\mathcal{V}}}}}}}}}_{i}\) denotes the set of nodes (atoms) present in a unit cell of the crystal structure and \({{{{\boldsymbol{{{{\mathcal{E}}}}}}}}}_{i}=\{(u,v,{k}_{uv})\}\) denotes a multi-set of node pairs and the number of edges between them. kuv edges between a pair of nodes (u, v) indicate that v is present in kuv repeating cells within a radius r from u (r is a hyper-parameter). \({{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}_{i}\) represents the node features, i.e., features that uniquely identify the chemical properties, such as atomic volume, electron affinity, etc., of an atom, as described in Table 1. Lastly, \({{{{\boldsymbol{{{{\mathcal{F}}}}}}}}}_{i}\) corresponds to a multi-set of edge weights. We denote \({{{{\boldsymbol{{{{\mathcal{F}}}}}}}}}_{i}=\{{\{{s}^{k}\}}_{(u,v)}| (u,v)\in {{{{\boldsymbol{{{{\mathcal{E}}}}}}}}}_{i}\}\), where sk denotes the kth bond length between the node pair (u, v). Between any pair of nodes, a maximum of K edges are possible, where K is empirically determined. The bond length helps to specify the relative distance of an atom from its neighboring atoms. We use this graphical abstraction of a crystal as it can effectively embed the periodicity (indicated by the number of bonds) along with the relative positioning of each atom in a simple way, which would otherwise be difficult to capture. For ease of reference, we drop the index i from the notation. Next, we formally define the auto encoder (CrysAE) and the property predictor (CrysXPP).
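To make the multi-graph construction concrete, the following sketch builds the edge multiplicities kuv and bond-length multi-sets \({\{{s}^{k}\}}_{(u,v)}\) from fractional coordinates and a lattice matrix. It only scans the 3×3×3 block of neighbouring cells, a simplification that is adequate when the cutoff radius r is smaller than the cell dimensions; the function name and this truncation are ours, not part of the original implementation.

```python
import numpy as np

def crystal_multigraph(frac_coords, lattice, r=8.0, K=12):
    """Build (E, F): for each ordered pair (u, v), collect up to K bond lengths s^k to
    periodic images of v within radius r. Only the 27 cells around the unit cell are
    scanned here (a simplification); r and K are hyper-parameters."""
    cart = frac_coords @ lattice                                           # Cartesian positions of unit-cell atoms
    offsets = [(np.array(o) - 1) @ lattice for o in np.ndindex(3, 3, 3)]   # image offsets for cells in {-1,0,1}^3
    edges, bond_lengths = {}, {}
    for u in range(len(cart)):
        for v in range(len(cart)):
            dists = sorted(
                d
                for off in offsets
                for d in [np.linalg.norm(cart[v] + off - cart[u])]
                if d <= r and (u != v or d > 1e-8)                         # skip the atom's own zero-distance image
            )[:K]
            if dists:
                edges[(u, v)] = len(dists)                                 # k_uv: number of parallel edges
                bond_lengths[(u, v)] = dists                               # multi-set {s^k}_(u,v)
    return edges, bond_lengths
```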

Auto encoder (CrysAE)

We build the Crystal Auto Encoder (CrysAE), which is composed of a simple encoder followed by an appropriate decoder; the decoder facilitates the overall training so that the necessary information is learned by the encoding mechanism.

Encoder

We extend the crystal graph encoder proposed by Xie et al.6 to encode the chemical and structural information of a crystal graph \({{{\mathcal{G}}}}\). Specifically, we encode the L-hop neighborhood information of each node as:

$$\begin{array}{ll}&{{{\boldsymbol{h}}}}_{{(u,v)}_{k}}^{l}={{{\boldsymbol{z}}}}_{u}^{l}\oplus {{{\boldsymbol{z}}}}_{v}^{l}\oplus {s}_{(u,v)}^{k}\\ &{{{\boldsymbol{z}}}}_{u}^{l+1}={{{\boldsymbol{z}}}}_{u}^{l}+\mathop{\sum}\limits_{v,k}\sigma \left({{{\boldsymbol{h}}}}_{{(u,v)}_{k}}^{l}{{{\boldsymbol{W}}}}_{c}^{(l)}+{{{\boldsymbol{b}}}}_{c}^{(l)}\right)\odot g\left({{{\boldsymbol{h}}}}_{{(u,v)}_{k}}^{l}{{{\boldsymbol{W}}}}_{s}^{(l)}+{{{\boldsymbol{b}}}}_{s}^{(l)}\right)\end{array}$$
(1)

where \({{{{\boldsymbol{z}}}}}_{u}^{l}\) denotes the embedding of node u after l hops of neighbor information aggregation. The embedding of a node u is initialized to a transformed node feature vector, i.e., it is a function of atom u’s chemical features, \({{{{\boldsymbol{z}}}}}_{u}^{0}:= {{{{\boldsymbol{x}}}}}_{u}{{{{\boldsymbol{W}}}}}_{x}\), where Wx is the trainable parameter of the transformation network and xu is the input node feature vector. \({s}_{(u,v)}^{k}\in {{{{\boldsymbol{{{{\mathcal{F}}}}}}}}}_{u}\) represents the length of the kth edge between nodes u and v. The ⊕ operator denotes concatenation and ⊙ denotes element-wise multiplication. \({{{{\boldsymbol{W}}}}}_{c}^{(l)},{{{{\boldsymbol{W}}}}}_{s}^{(l)},{{{{\boldsymbol{b}}}}}_{c}^{(l)},{{{{\boldsymbol{b}}}}}_{s}^{(l)}\) are the convolution weight matrix, self weight matrix, convolution bias, and self bias of the lth hop convolution, respectively. σ is a non-linear transformation function used to generate a squeezed real value in [0, 1] indicating the edge importance, and g is a feed-forward network. After neighborhood aggregation we accumulate the local information at each node, represented as \({{{{\boldsymbol{z}}}}}_{u}:= {{{{\boldsymbol{z}}}}}_{u}^{L}\). Subsequently, we generate graph-level global information \({{{\boldsymbol{{{{\mathcal{Z}}}}}}}}=\{{{{{\boldsymbol{z}}}}}_{1},...,{{{{\boldsymbol{z}}}}}_{| {{{\boldsymbol{{{{\mathcal{V}}}}}}}}| }\}\). We do not aggregate the node embeddings further, to prevent information loss in the autoencoder. We denote the set of trainable parameters of this encoder as θ for future reference.
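A compact PyTorch sketch of one convolution step of Eq. 1 is given below, with σ taken as a sigmoid gate and g as a softplus network (a common choice in CGCNN-style models); the edge-list tensor layout is our assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedCrystalConv(nn.Module):
    """One message-passing step of Eq. (1): z_u^{l+1} = z_u^l + sum_{v,k} sigma(h W_c + b_c) * g(h W_s + b_s)."""
    def __init__(self, dim):
        super().__init__()
        self.W_c = nn.Linear(2 * dim + 1, dim)   # convolution ("gate") weights and bias b_c
        self.W_s = nn.Linear(2 * dim + 1, dim)   # self ("core") weights and bias b_s

    def forward(self, z, edge_index, bond_len):
        # z: (N, d) node embeddings; edge_index: (2, E), one column per parallel edge (u, v);
        # bond_len: (E, 1) bond lengths s^k_(u,v)
        u, v = edge_index
        h = torch.cat([z[u], z[v], bond_len], dim=-1)             # h^l_(u,v)_k = z_u ⊕ z_v ⊕ s^k
        msg = torch.sigmoid(self.W_c(h)) * F.softplus(self.W_s(h))
        out = z.clone()
        out.index_add_(0, u, msg)                                 # residual aggregation over neighbors
        return out
```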

Decoder

We design an effective decoder that helps the encoder to encode the desired information in the representation vector space \({{{\boldsymbol{{{{\mathcal{Z}}}}}}}}\). The decoder plays an essential role in learning the local and global structure as well as the chemical features, which are extremely useful. As mentioned earlier, the global chemical features, i.e., the crystal properties, are a function of the local chemical environment and the overall conformation of the repeating crystal cell structure; hence, we carefully design the decoder so that it can reconstruct two important features that induce the local chemical environment. They are (a) the node features, i.e., the chemical properties of the individual atoms, and (b) the local connectivity, i.e., the relative position of the nodes with respect to their neighbors. Precisely, we reconstruct this information as below:

$${z}_{uv}={{{{\boldsymbol{z}}}}}_{u}^{T}{{{{\boldsymbol{W}}}}}_{f}{{{{\boldsymbol{z}}}}}_{v}+{b}_{f}$$
(2)
$$\hat{{s}_{(u,v)}^{k}}=\left\{\begin{array}{ll}{\gamma }_{s}({z}_{uv}\odot k)&\,{{\mathrm{if}}}\,{\gamma }_{s}(.)\, > \,0\\ 0&\,{{\mathrm{otherwise}}}\,\end{array}\right.$$
(3)
$${\hat{{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}}_{u}={{{{\boldsymbol{W}}}}}_{x}^{T}{{{{\boldsymbol{z}}}}}_{u}+{{{{\boldsymbol{b}}}}}_{x}$$
(4)

Equations 4 and 3 correspond to reconstructing the node properties, i.e., the atoms’ chemical properties, and a node’s position relative to its neighbors, as we intend to achieve in (a) and (b), respectively. zuv is a combined transformed embedding of nodes u and v, and γs is a feed-forward network that generates a real number corresponding to the length of the bonds.

Further, we reconstruct the global structure, i.e., (c) the connectivity and periodicity of the crystal structure, as below:

$$(u,v) \sim p(e=(u,v))=\sigma ({{{{\boldsymbol{z}}}}}_{u}^{T}{{{{\boldsymbol{W}}}}}_{e}{{{{\boldsymbol{z}}}}}_{v}+{b}_{e})$$
(5)
$${k}_{(u,v)}=\arg \mathop{\max }\limits_{k}\frac{{\rm{e}}^{{\gamma }_{k}({z}_{uv},k)}}{{\sum }_{k}{\rm{e}}^{{\gamma }_{k}({{{{\boldsymbol{z}}}}}_{uv},k)}}$$
(6)

Here, We and be are the trainable weight and bias associated with the bilinear edge reconstruction module, respectively. σ is a squashing function that provides a value in [0, 1] denoting the edge probability. Similarly, Wf and bf are the trainable weight and bias parameters associated with the intermediate bilinear transformation module, respectively. γk represents a feed-forward neural network that generates a K-length logit vector, over which we use a softmax to determine the exact number of edges present. Please note that though Eqs. 6 and 3 correspond to global and local information respectively, they are heavily dependent upon each other, i.e., the number of bonds and the bond lengths both depend on the information of the two end nodes. Hence, we design a coupled embedding zuv that is shared by both modules. We denote the set of parameters in the decoder as ϕ.
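The decoder heads of Eqs. 2–6 can be sketched as below; the layer sizes are illustrative and, for brevity, the k-dependence of the bond-length head γs is folded into a single scalar output.

```python
import torch
import torch.nn as nn

class CrysDecoder(nn.Module):
    """Reconstruction heads of Eqs. (2)-(6); sizes are illustrative, not the released architecture."""
    def __init__(self, dim, feat_dim, K):
        super().__init__()
        self.W_f = nn.Bilinear(dim, dim, dim)         # z_uv = z_u^T W_f z_v + b_f   (Eq. 2)
        self.W_e = nn.Bilinear(dim, dim, 1)           # edge-probability head        (Eq. 5)
        self.gamma_s = nn.Linear(dim, 1)              # bond-length head             (Eq. 3)
        self.node_head = nn.Linear(dim, feat_dim)     # node-feature reconstruction  (Eq. 4)
        self.gamma_k = nn.Linear(dim, K)              # logits over #edges k_(u,v)   (Eq. 6)

    def forward(self, z_u, z_v):
        z_uv = self.W_f(z_u, z_v)
        edge_prob = torch.sigmoid(self.W_e(z_u, z_v))   # p(e = (u, v))
        bond_len = torch.relu(self.gamma_s(z_uv))       # zero if gamma_s(.) <= 0, as in Eq. (3)
        k_logits = self.gamma_k(z_uv)                   # softmax/argmax applied at read-out
        x_hat = self.node_head(z_u)                     # \hat{X}_u = W_x^T z_u + b_x
        return edge_prob, bond_len, k_logits, x_hat
```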

Training of auto-encoder

We learn the trainable parameters of both the encoder and the decoder by minimizing the reconstruction loss of the different global and local structural and chemical features defined in Eqs. 3–6. We minimize the cross-entropy loss of the predicted global features and node features, along with the mean squared loss of the edge weights or bond lengths, in the following objective:

$$\begin{array}{ll}&{{\mathbb{E}}}_{{{{\mathcal{G}}}} \sim {{{\boldsymbol{{{{\mathcal{D}}}}}}}}}-\mathop{\sum}\limits_{(u,v)\in {{{\mathcal{E}}}}}\left[\log p(e=(u,v))+\log \,p({k}_{(u,v)})\right]\\ &-\mathop{\sum}\limits_{u\in {{{\mathcal{V}}}}}\log p({\hat{{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}}_{u})+\mathop{\sum}\limits_{(u,v)\in {{{\mathcal{E}}}}}\mathop{\sum}\limits_{k\in [1,\ldots ,K]}{({s}_{(u,v)}^{k}-\hat{{s}_{(u,v)}^{k}})}^{2}\end{array}$$
(7)

where p(.) denotes the probability of any event. Thus, by minimizing the reconstruction loss we not only fine-tune the parameters of the decoder but also efficiently train the encoder to generate a rich \({{{\boldsymbol{{{{\mathcal{Z}}}}}}}}\), which facilitates the decoder operations.
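The combined objective of Eq. 7 mixes cross-entropy terms for the discrete targets with a squared error on the bond lengths. A sketch of the loss computation on the observed (positive) edges is given below; negative-edge sampling for the term \(\log p(e=(u,v))\) is omitted and the tensor shapes are our assumptions.

```python
import torch
import torch.nn.functional as F

def crysae_loss(edge_prob, k_logits, x_logits, s_hat, k_true, x_true, s_true):
    """Reconstruction objective of Eq. (7), restricted to observed edges for brevity."""
    l_edge = -torch.log(edge_prob + 1e-9).mean()   # -log p(e=(u,v)) on positive edges
    l_k = F.cross_entropy(k_logits, k_true)        # -log p(k_(u,v)), number of parallel edges
    l_x = F.cross_entropy(x_logits, x_true)        # -log p(X_u) for the categorical node features
    l_s = F.mse_loss(s_hat, s_true)                # (s^k - s_hat^k)^2, bond lengths
    return l_edge + l_k + l_x + l_s
```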

Property predictor (CrysXPP)

Next, we design a property predictor, specific to a property, that can take advantage of the structural information learned by the encoder as described above. We generate a graph-level representation using the same graph encoder module as described in Eq. 1, thus in effect transferring the rich encoded knowledge to the property predictor. Next, we use a symmetric aggregation function to generate a single vector as the graph representation \({{{{\boldsymbol{{{{\mathcal{Z}}}}}}}}}_{g}\), so that the obtained representation of the graph is invariant to node orderings. The obtained representation is then fed to a multilayer perceptron which predicts the value of the property. More formally, the property predictor can be characterized as:

$${{{{\boldsymbol{{{{\mathcal{Z}}}}}}}}}_{g}={{\Lambda }}({{{{\boldsymbol{z}}}}}_{1}\ldots ,{{{{\boldsymbol{z}}}}}_{| {{{\mathcal{V}}}}| })$$
(8)
$${{{\mathcal{P}}}}={{{{\mathcal{M}}}}}_{{{{\boldsymbol{\psi }}}}}({{{{\boldsymbol{{{{\mathcal{Z}}}}}}}}}_{g})$$
(9)

Here, Λ is the aggregation function, which is symmetric, and \({{{\mathcal{M}}}}\) denotes a multilayer perceptron with trainable parameter set ψ.
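Equations 8 and 9 amount to a symmetric pooling over the node embeddings followed by an MLP regressor, as sketched below; the hidden size is illustrative.

```python
import torch.nn as nn

class PropertyHead(nn.Module):
    """Λ (average pooling) followed by the MLP M_psi of Eqs. (8)-(9); hidden size is illustrative."""
    def __init__(self, dim=64, hidden=128):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.Softplus(), nn.Linear(hidden, 1))

    def forward(self, z_nodes):            # z_nodes: (N, d) node embeddings of one crystal
        z_g = z_nodes.mean(dim=0)          # order-invariant aggregation Λ
        return self.mlp(z_g)               # predicted property value P
```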

Feature selection

The node features are first passed through a feature selector, a trainable weight vector that selects a weighted subset of important node-level features \({{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}^{\prime}\) for a given property of interest \({{{\mathcal{P}}}}\). \({{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}^{\prime}\) forms the input to the encoder.

$${{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}^{\prime}={{{{\mathcal{S}}}}}_{{{{\boldsymbol{\zeta }}}}}({{{\boldsymbol{{{{\mathcal{X}}}}}}}});{{{{\boldsymbol{{{{\mathcal{Z}}}}}}}}}_{g}={{\Lambda }}({{{\mbox{Encoder}}}}_{{{{{\boldsymbol{\theta }}}}}^{\prime}}({{{\boldsymbol{{{{\mathcal{V}}}}}}}},{{{\boldsymbol{{{{\mathcal{E}}}}}}}},{{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}^{\prime},{{{\boldsymbol{{{{\mathcal{F}}}}}}}}))$$

In the above set of characteristic equations, \({{{\mathcal{S}}}}\) is the feature selector and ζ is its trainable weight. We will show how the weights chosen by the feature selection layer help us to explain the role of a node feature in the manifestation of a particular property (viz. formation energy) of a crystal.
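The feature selector \({{{{\mathcal{S}}}}}_{{{{\boldsymbol{\zeta }}}}}\) is simply a trainable weight vector applied elementwise to the node features, as in the following sketch.

```python
import torch
import torch.nn as nn

class FeatureSelector(nn.Module):
    """X' = S_zeta(X): one trainable weight per atomic feature, shared by all nodes."""
    def __init__(self, feat_dim):
        super().__init__()
        self.zeta = nn.Parameter(torch.ones(feat_dim))

    def forward(self, x):                  # x: (N, feat_dim) node features
        return x * self.zeta               # weighted features X' fed to the encoder
```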

Training of CrysXPP

We train the property predictor after the autoencoder. We initialize the trainable parameter \({{{{\boldsymbol{\theta }}}}}^{\prime}:= {{{\boldsymbol{\theta }}}}\) where θ is trained in the autoencoding module. Thus we first transfer the trained information such that the property predictor benefits from the inductive bias already learned by training the autoencoder.

We use LASSO28 regression to impose sparsity on the feature selector layer. Intuitively, if some atomic features (\({{{{\boldsymbol{{{{\mathcal{X}}}}}}}}}^{\prime}\)) are crucial for predicting a chemical property of the crystal, the corresponding feature selector values will be high; conversely, if some feature is not so important, the corresponding feature selector value will be negligible. Hence, along with the property prediction loss, we also consider the LASSO regression loss, as formally represented below:

$$\mathop{\min }\limits_{{{{\boldsymbol{\zeta }}}},{{{{\boldsymbol{\theta }}}}}^{\prime},{{{\boldsymbol{\psi }}}}}{(\hat{{{{\mathcal{P}}}}}-{{{\mathcal{P}}}})}^{2}+{\lambda }_{1}* | {{{\boldsymbol{\zeta }}}}{| }_{{L}_{1}}$$
(10)

where ζ denotes the trainable parameters of the feature selector \({{{\mathcal{S}}}}\) and λ1 is a hyperparameter that controls the degree of regularization imposed.
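Putting the pieces together, the loss of Eq. 10 adds an L1 penalty on the feature-selector weights ζ to the squared prediction error, as in the sketch below.

```python
import torch.nn.functional as F

def crysxpp_loss(p_hat, p_true, zeta, lambda_1):
    """Property-prediction loss of Eq. (10): MSE plus a LASSO penalty on the selector weights."""
    return F.mse_loss(p_hat, p_true) + lambda_1 * zeta.abs().sum()
```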

Before reporting the results, we briefly discuss the dataset and the baselines used for comparison.

Dataset

We have used the Materials Project database for our experiments, which consists of ~36,835 crystalline materials and is structurally diverse, with materials comprising 87 different types of atoms, seven different lattice systems, and 216 space groups. The unit cell of any crystal can have a maximum of 200 atoms. We consider nine properties for each atom, which are used to construct the feature vector of each node6. The details of the properties are given in Table 1. We convert them to categorical values if they are not already in that form. The dataset also provides DFT-calculated target property values for the crystal structures. Experiments were done on a smaller training set than in the original baseline papers.

Comparison with similar baseline algorithms

We compare the performance of CrysXPP with four state-of-the-art algorithms for crystal property prediction. These selected competing methods are varied in terms of input data processing and working paradigms as described below:

  1. (a)

    CGCNN6: This method generates crystal graphs from inorganic crystal materials and builds a graph convolution-based supervised model for predicting various properties of the crystals.

  2. (b)

    MT-CGCNN7: This model uses the graph convolution-based encoding as proposed in the previous model. Moreover, it incorporates multitask learning to jointly predict multiple properties of a single material.

  3. (c)

    MEGNET8: Here the authors improved the CGCNN model further by introducing global state attributes, including temperature, pressure, entropy, etc., for quantitative structure-state-property relationship prediction in materials. In doing so, they found that the crystal embeddings in the MEGNet model encode periodic chemical trends. Further, to address the issue of data limitation, the embeddings from a MEGNet model trained on formation energies are transferred and used to improve the accuracy of ML models for the band gap and elastic moduli.

  4. (d)

    ELEMNET29: This work does not specifically consider any structural properties of the crystal graph; rather, it considers only the compositional atoms. It uses deep feed-forward networks to implicitly capture the effect of atoms on each other, and uses transfer learning to mitigate the error bias of DFT-tagged data.

  5. (e)

    GATGNN12: In this work, the authors incorporate a graph neural network with multiple graph-attention (GAT) layers and a global attention layer, which can efficiently learn the importance of the different complex bonds shared among the atoms within each atom’s local neighborhood.

For all the baselines we have used the hyperparameters mentioned in the original papers.

Evaluation criteria

We predict seven different properties of crystals in our experiments. Out of these, four are crystal-state properties, namely, (a) Formation Energy, (b) Band Gap, (c) Fermi Energy, and (d) Magnetic Moment, and three are elastic properties, namely, (e) Bulk Moduli, (f) Shear Moduli, and (g) Poisson Ratio. All of these properties depend significantly on the details of the crystal structure, except the Magnetic Moment, which depends more on the atomic/node specifications, as the magnetic moment arises from the unpaired d or f electrons of an atom. Also, the size of the moment depends on the local environment30,31. Moreover, we have very little DFT-tagged data for the Magnetic Moment and Band Gap.

We focus on three different evaluation criteria as described below:

  1. 1.

    How effective is the property predictor? Here we inspect the performance of the property predictor especially when it functions with a small amount of DFT tagged data.

  2. 2.

    How robust is the structural encoding? Here we investigate whether the structural encoding helps us to mitigate the noise introduced by DFT calculated properties.

  3. 3.

    How effective is the explanation? We cross-validate the obtained explanation with domain knowledge.

Effectiveness of property predictor

We first train the autoencoder with all the untagged crystal graphs present in the dataset, which captures all the structural information of the crystal graphs. Next, for a given property of interest, we train the property predictor with 20% of the available DFT-tagged data and test on the rest. We report 10-fold cross-validation results.

Metric

We report the mean absolute error (MAE) to compare the performance of the participating methods. MAE is defined as \(\frac{1}{| {{{\boldsymbol{{{{\mathcal{D}}}}}}}}| }{\sum }_{{{{\mathcal{G}}}}\in {{{\boldsymbol{{{{\mathcal{D}}}}}}}}}\left|{{{{\mathcal{P}}}}}_{{{{\mathcal{G}}}}}-{\hat{{{{\mathcal{P}}}}}}_{{{{\mathcal{G}}}}}\right|\), where \({{{{\mathcal{P}}}}}_{{{{\mathcal{G}}}}}\) is the property value calculated by DFT and \({\hat{{{{\mathcal{P}}}}}}_{{{{\mathcal{G}}}}}\) is the predicted value for a graph \({{{\mathcal{G}}}}\).
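For completeness, the MAE metric reduces to the following one-liner.

```python
import numpy as np

def mae(p_dft, p_pred):
    """Mean absolute error between DFT-calculated and predicted property values."""
    return float(np.mean(np.abs(np.asarray(p_dft) - np.asarray(p_pred))))
```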

Results

In Table 2 we report the MAE of CrysXPP as well as the other alternatives on the seven property values. We observe that CrysXPP outperforms every baseline across all the properties except the magnetic moment. For MT-CGCNN we report two values: the MAE obtained while jointly predicting the most correlated property, and the average MAE across all possible combinations (in brackets). It is interesting to note that its performance significantly degrades if the other property is not correlated with the current property of interest. A careful inspection reveals that for the elastic properties, the graph neural network-based methods perform better than ElemNet. ElemNet only considers the composition of the crystal and ignores the global structural information, whereas these properties heavily depend on the crystal structure. In contrast, for the Magnetic Moment the local information is important and hence ElemNet performs the best, with CrysXPP the second-best method. For the rest of the crystal-state properties, there is no consistent second-best method. However, CrysXPP is a clear winner by a considerable margin, which is due to the fact that the property predictor benefits from the structural knowledge transferred from the autoencoder.

Table 2 Summary of the prediction performance (MAE) of different properties trained on 20% data and evaluated on 80% of the data. The best performance is highlighted in bold and second-best with *. We report MAE jointly training most correlated property (average on all property pairs) for MTCGCNN.

Behavior with increase in tagged data

Further, we check the robustness of CrysXPP by increasing the percentage of tagged training data for property prediction. We report the behavior of CrysXPP as well as the other baselines in Fig. 3 for all the properties. We observe a monotonic decrease of the MAE between the predicted and DFT-calculated values for most of the models, where CrysXPP yields consistently smaller MAE and maintains the leading position for all the properties except the Magnetic Moment. This shows the robustness of our model in being able to perform consistently across a diverse set of properties with varied numbers of training instances. The MAE margin between CrysXPP and the closest competitor (which varies across properties), however, reduces as the training size increases. For the Magnetic Moment, the local chemical information is more vital; hence ElemNet, which concentrates more on local chemical information, shows the best performance.

Fig. 3: Variation of MAE with the increase in training instances from 20 to 80%.

CrysXPP outperforms all the baselines consistently. a Formation energy. b Band Gap. c Fermi Energy. d Magnetic Moment. e Shear Moduli. f Bulk Moduli. g Poisson Ratio.

Removal of DFT error bias

An important aspect of the prediction is that, since we rely only on DFT data for training, we would be limited by the inaccuracies of DFT. In this section, we investigate a setting where we further fine-tune the model with a small amount of available experimental data and check whether the system can remove the error propagated due to DFT.

Calculation setup

We consider a property predictor (as explained before) which has been trained with crystals whose particular property (e.g., Band Gap) values have been theoretically derived using DFT. We then fine-tune the parameters with a limited amount of experimental data; we perform this for two different properties, namely, Band Gap and Formation Energy. For Formation Energy, we use the 1500 instances available at22 and use different percentages of the data to fine-tune the model parameters. For Band Gap, we collect 20 experimental instances (http://matprop.ru), out of which we randomly pick 10 instances to fine-tune the parameters and report the prediction values for the rest.

Results (Formation Energy)

We report the MAE of Formation Energy achieved by different methods in Table 3. The DFT prediction of the Formation Energy on the 1500 crystals has an MAE of 0.21 with respect to the experimental data, and by training our model with DFT data we perform close to the performance achieved by DFT. The results have a consistent trend for all the methods, whereby we observe that increasing the amount of training data, even if that is error-prone DFT data, helps in minimizing the MAE. CrysXPP performs consistently better, by a large margin, than CGCNN and MT-CGCNN, which take the graph structure as an explicit input. However, it is interesting to observe that ElemNet performs very close to CrysXPP, as Formation Energy depends more on the composition than on the explicit connections of the atoms. Further, we conduct an experiment where we replace the experimental data with the same amount of DFT data to train our model. We then evaluate the performance of the model using the experimental data as test data and find an inferior performance. We report the results in the last column of Table 3 (in brackets).

Table 3 MAE of predicting experimental values after fine tuning different methods with different percentages of experimental data for Formation Energy. MAE of the experiment where we replace the experimental data with the same amount of DFT data to train CrysXPP, is provided in the bracket. The closest prediction is marked in bold and second-best with *.

Results (Band Gap)

In Table 4 we report the experimental value of the Band Gap for 10 test instances along with the values predicted by DFT and the other machine learning methods. The error margin of DFT with respect to the actual experimental values is quite high. It is interesting to see that, other than a few cases, the DFT prediction is far from the experimental data and in most cases underestimates the experimental values. After fine-tuning the DFT-trained machine learning models with experimental data, the predictions become closer to the experimental values. However, CrysXPP performs closest to the experimental result in almost all the cases in comparison to the other alternatives. ElemNet, although second on average when trained only on DFT (row 2 of Table 2), cannot consistently maintain that position, whereas CGCNN performs better. We have also provided results (in brackets) when we do not do any fine-tuning. It can be seen that even in such a scenario, in many of the cases the performance is better than that of DFT. Further, the ability of CrysXPP to quickly mitigate the bias of DFT when fine-tuned on minuscule data shows the usefulness of modeling explicit structural information.

Table 4 Experiment (Exp) and predicted value for Band Gap for 10 crystals calculated by DFT and other machine learning models after fine-tuned by experimental data. The predicted value without fine tuning by experimental data is provided in the bracket. The closest prediction is marked in bold and the second-best with *. CrysXPP predicts closest to the ground truth after fine tuning with experimental data.

Ablation studies

We demonstrate the effectiveness of architectural choices and training strategies for CrysXPP, by designing the following set of ablation studies:

  1. 1.

    The importance of explicitly capturing global and local features and understanding their effect on property prediction

  2. 2.

    The impact of sparse feature selection on property prediction, and

  3. 3.

    The choice of GNN models in the autoencoder CrysAE

In the following subsections, we will thoroughly discuss these.

Importance of local and global feature understanding

Here we investigate the importance of the different reconstruction loss components in CrysAE training and, eventually, their effect on property prediction. (a) Without Global + Local effect: in this scenario, we do not train CrysAE and only train CrysXPP. (b) Global effect: we train CrysAE by minimizing the reconstruction loss of only the global features, ignoring the local feature losses. (c) Local effect: here we focus only on minimizing the local feature losses. We report the performance of the model (MAE) in Table 5 on different train-test splits across different properties. We observe that the performance of the model in the setting Without Global + Local effect is the worst. We also notice that for all the properties, the Local effect individually leads to better performance than the Global effect, except for the Poisson Ratio where the effect is similar in both cases. However, we find that the impacts of the local and global effects are somewhat complementary; hence simultaneous reconstruction of the global and local features (CrysXPP) results in the best performance. The only exception is the formation energy, where the addition of global features leads to a performance deterioration.

Table 5 Summary of experiments of ablation study on the importance of different reconstruction loss components on CrysAE training and eventually its effect on CrysXPP (MAE for property prediction).

Impact of sparse feature selection

We perform an ablation study to analyze the impact of sparse feature selection on property prediction. This is done by removing the L1 regularizer term from the CrysXPP loss function in Eq. 10. We evaluate the performance of the model and report the results (MAE) in Table 6. We observe a discernible improvement due to the introduction of sparse feature selection using the L1 regularizer.

Table 6 Summary of experiments of ablation study on sparse feature selection using L1 regularizer, performed on different train test splits across different properties (MAE).

Effect of other GNN variants as graph encoder

To explore the effectiveness of other GNN variants as graph encoders, we conduct an experiment where we replace the CGCNN encoder with one of the popular GNN variants, the GCN32 encoder, and evaluate the performance of the model. GCN only considers the graph structural information and atom features to learn the graph representation and, unlike CGCNN, it does not consider the weights of the individual edges in the multi-graph representing a crystal. We report the model performance (MAE) in Table 7. We observe that the model performance degrades when trained with GCN. The edge weight calculation, which is a major contribution of CGCNN, is extremely helpful for capturing the local structure of the crystal.

Table 7 Summary of experiments (MAE) of ablation study on the effect of GCN as graph encoder in CrysAE and CrysXPP.

Explanation through feature selection

We have introduced a feature selector that is trained along with the property prediction parameters using the available tagged data. The feature selector helps to select the subset of the atomic features contributing to the chemical properties of the crystal, which makes the model explainable by design. To demonstrate the effectiveness of the feature selector, we have selected a few case studies and provide the feature explanations for formation energy, band gap, and magnetic moment.

Formation energy

Here we report case studies corresponding to two crystals, BaEr2F8 and AuC, illustrating the important role of the feature selector in providing an explanation. We report the feature selector values corresponding to the categorical atomic properties after being trained on Formation Energy tagged data in Fig. 4. The bars represent the weights assigned by the feature selector to the categorical values of the atomic features, and different colors indicate different atoms. A higher category denotes a higher value of the feature. Figure 4 depicts the importance of the atomic features in two extreme cases. One is BaEr2F8, whose Formation Energy is predicted as −4.41 eV/atom, indicating its stability, while the other is AuC, with a predicted Formation Energy of 2.2 eV/atom, denoting that the material is quite unstable. In both cases, we see that the Period Number is the most important atomic feature as it has the maximum weight. The Period and Group Numbers provide the information to distinguish each element. As the Group Number and the number of Valence Electrons are closely related, we see that the feature selector only selected the former, thus avoiding duplicity. Electronegativity and Covalent Radius are two other important features (with non-zero weights), as is evident from the figure. A non-zero difference in the Electronegativity of the atoms indicates stability of the structure. Both Au and C have the same Electronegativity (category value 5), and the feature selector gives it the same weight; as a result, the difference in Electronegativity is zero in the case of unstable AuC. In contrast, for stable BaEr2F8, the feature selector assigns different non-zero weights to the less electronegative elements and zero weight to the most electronegative atom, F. The Covalent Radius determines the extent of overlap of the electron densities of the constituents and therefore appears as another important feature. A higher radius means a weaker bond. It is interesting to note that, for stable BaEr2F8, the trend of the weights is the reverse of that of the radius itself (Ba has the largest radius, 215 pm, and the smallest weight). The scenario is reversed for unstable AuC. Ionization Energy plays a similar role to Electronegativity, and we observe the same behavior of the feature selector. As can be seen from this example, the feature selector provides elaborate cues for domain experts to reason about the results.

Fig. 4: Feature selector values corresponding to atom features after trained on Formation Energy tagged data.

The top bar chart represents the feature weights of BaEr2F8 and the one below represents the feature weights of AuC.

Band gap

In Fig. 5, we show the important features that appear in the case of the band gap. It is interesting to see that the Electron Affinity comes out as the most important atomic feature for the band gap, as it determines the location of the conduction band minimum with respect to the vacuum. Again, as the number of Valence Electrons and the Group Number are collinear properties, only one (valence electrons) is found to have a significant weight. The conduction band is composed of Ga states while the valence band is composed of P states with a small admixture of Ga states, which gives the Ga valence electrons more weight. The situation is reversed for the Ionization Energy, which determines the location of the valence band with respect to the vacuum; as the valence band is mostly formed by the P atom, we see that P has more weight.

Fig. 5: Feature selector values corresponding to atom features after trained on Band Gap tagged data.

The bar chart represents the feature weights of GaP (Band Gap 2.26 eV).

Magnetic moment

In order to understand the feature importance in the case of the magnetic moment, we compare the results obtained for two Co-based alloys, namely, CoPt and CoNi (Fig. 6). In both cases, the Atomic Volume, Period Number, and Electronegativity appear to be the three most important features, while in the case of CoNi, the Electron Affinity of Ni also appears as an additional important feature. It can be seen that in the case of CoPt, the Atomic Volume of Co has a higher weight, while for CoNi, the Atomic Volumes of both species have the same weight. This is quite intuitive, as for CoPt the magnetic moment is mostly carried by the Co atom, while in the case of CoNi, both atoms have a significant contribution. Electronegativity plays an important role in the context of the magnetic moment. For example, the magnetic moment of Co in CoPt is slightly higher than its corresponding value in pure Co. The electronegativity difference between Co and Pt causes electron transfer from the Co minority spin band to Pt, which in turn enhances its magnetic moment31,33. In the case of the Period Number, again we see that for CoPt it is only the period number of the magnetic atom, i.e., the Co atom, that is given visible weight, while in the case of CoNi the period numbers of both atoms appear to be important.

Fig. 6: Feature selector values corresponding to atom features after trained on Magnetic Moment tagged data.

The top bar chart represents the feature weights of CoPt and the bottom chart represents the feature weights of CoNi.

It is evident from the above analysis that CrysXPP is effectively constructing models where the important node features are physically intuitive.

In conclusion, we propose an explainable property predictor for crystalline materials, CrysXPP, to predict different crystal-state and elastic properties with high accuracy using a small amount of property-tagged data. We address the issue of the limited crystal data for which the value of a particular property is known using transfer learning from an encoding module, CrysAE, which we train in a property-agnostic way with a large amount of untagged crystal data to capture all the important structural and chemical information useful to a specific property predictor. We further find that the encoder knowledge is extremely useful in de-biasing the DFT error using meagre amounts of experimental results. CrysXPP outperforms all the baselines across a diverse set of seven properties. With appropriate case studies, we show that the explanations provided by the feature selection module are in sync with domain knowledge. We release the large pretrained model CrysAE so that the research community can fine-tune it using a small amount of tagged data for various applications with restricted data sources.

Methods

Hyperparameters

We trained our model with a varying number of convolution layers in the encoder module and obtained the best results with three convolution layers. We kept the embedding dimension of each node at 64 and the batch size at 512, and used average pooling to obtain \({{{{\boldsymbol{{{{\mathcal{Z}}}}}}}}}_{g}\). We selected λ1 = 0.01 for the feature selector regularization (Eq. 10). We varied the learning rate on a logarithmic scale and selected 0.03, which yields faster convergence. We trained the auto-encoder for 200 epochs and the property predictor for 200 epochs.
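For reference, the hyperparameters reported above can be collected into a single configuration, as in the illustrative dictionary below (the key names are ours, not from the released code).

```python
# Illustrative consolidation of the reported hyperparameters; key names are ours.
config = {
    "num_conv_layers": 3,       # convolution layers in the encoder
    "node_embedding_dim": 64,
    "batch_size": 512,
    "pooling": "average",       # aggregation Λ used to obtain Z_g
    "lambda_1": 0.01,           # L1 weight on the feature selector (Eq. 10)
    "learning_rate": 0.03,
    "autoencoder_epochs": 200,
    "predictor_epochs": 200,
}
```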