Introduction

Machine Learning (ML)-based predictive models have made rapid strides in computational chemistry. Owing to their computational efficiency and accuracy, these methods enable faster high-throughput screening than classical physics-based models1,2. This capability has roots in both novel learning algorithms and improved hardware. Even though ML models can offer faster predictions, their accuracy is highly correlated with the availability of clean labeled data3. In general, it is difficult to develop accurate and robust ML models without sufficiently large labeled datasets4. Moreover, acquiring labeled data is expensive, as it involves performing Density Functional Theory (DFT) simulations or experiments to characterize materials5,6. On the other hand, gigantic databases containing structures and compositions of materials without labels (properties) are available. These databases cannot be used in supervised learning tasks due to the lack of labels. Given the availability of large unlabeled datasets, two interesting questions arise: (1) can we develop more efficient ML models that are capable of learning the underlying structural chemistry from unlabeled data, and (2) can these models be used to make supervised learning tasks more accurate?

In this work, we aim to address these questions by leveraging Self-Supervised Learning (SSL) for material property prediction. Unlike supervised learning, which uses labels for supervision, SSL uses the large unlabeled data itself for supervision to learn robust and generalizable representations that can be used for various tasks. Recently, SSL frameworks such as SimCLR7, Barlow Twins8, BYOL9, SwAV10, MoCo11, SimSiam12, ALBERT13, and self-supervised dialog learning14 have been successfully applied to computer vision and natural language processing tasks. The success of these SSL methods has inspired many works in molecular ML, leading to the development of highly accurate frameworks such as MolCLR15, dual-view molecule pre-training16, 3D Infomax17, and numerous other popular works18,19,20,21,22,23,24,25. It should be noted that these SSL-based methods have been developed for molecules, which have finite structures. Periodic crystalline materials, however, differ from molecules: they are composed of infinitely repeating unit cells of atoms, ions, or molecules, and they can contain non-covalent bonds that are distinct from the covalent bonds in molecules. Given these differences, specialized deep learning architectures that explicitly model crystals are required.

Most of the promising approaches developed for material property prediction use graph neural networks (GNNs). GNNs exploit non-Euclidean topology to construct a graph representation that can be learned and modified according to the task26,27,28. In general, GNNs developed for material property prediction take as input the 3D coordinates of the crystal and construct the graph by modeling atoms as nodes and the interactions between atoms as edges. GNNs developed for material property prediction include CGCNN29, OGCNN30, SchNet31, MegNet32, and other models33,34,35,36,37,38,39,40,41,42,43,44. Developments have also been made in tasks such as material structure generation and prediction45,46,47,48,49,50 as well as in identifying new materials with specific properties51. Despite the progress made in developing self-supervised ML architectures for molecular ML, there is a noticeable lack of research implementing such techniques for property prediction of periodic crystalline systems.

In this work, we introduce Crystal Twins (CT): an SSL framework for crystalline material property prediction with GNNs (Fig. 1). During pre-training, the models in the CT framework do not make use of any labeled data to learn crystalline representations; instead, they are trained in a self-supervised manner. In the CT framework, we use CGCNN29 as the encoder to learn expressive representations of crystalline systems. We adapt two different SSL pre-training methods based on the Barlow Twins8 and SimSiam12 loss functions. In CTBarlow, which uses the Barlow Twins loss for pre-training, the GNN encoder generates representations of two augmented instances of the same crystal, and the pre-training objective is to make the cross-correlation matrix of the two embeddings as close as possible to the identity matrix (Fig. 1A). In the other model, CTSimSiam, which uses the SimSiam12 loss function for pre-training, the objective is to maximize the cosine similarity between the embeddings generated by the graph encoder (CGCNN) for the two augmented instances. Additionally, in CTSimSiam, one branch has a stop-gradient operation and the other has a predictor head after the graph encoder (Fig. 1B). To create augmented instances, we introduce a combination of three different augmentation techniques: random perturbations, atom masking, and edge masking (Fig. 1C). The representations learned by the encoder are later used for downstream material property prediction tasks in the fine-tuning stage (Fig. 1D). In the pre-training stage, the graph encoder learns representations from unlabeled data. Using the pre-trained weights to initialize the graph encoder for fine-tuning, both CTBarlow and CTSimSiam demonstrate superior prediction performance on 14 challenging datasets. We also compare the performance of the CT models with other competitive supervised learning baselines. Overall, we demonstrate the successful use of self-supervised learning for crystalline material property prediction.

Fig. 1: Overview of the crystal twins (CT) framework.
figure 1

We propose two methodologies for SSL pre-training, based on the Barlow Twins and SimSiam loss functions. The CT framework takes the structural file (CIF) as input and then augments the structure to create two different augmented instances. (A) In CTBarlow, each instance is passed to the CGCNN graph encoder followed by a projector to generate an embedding. The pre-training objective drives the cross-correlation matrix of the two embeddings toward the identity matrix. (B) In CTSimSiam, each instance is passed through the same CGCNN encoder to generate embeddings. One branch has a predictor MLP head after the encoder and the other branch has a stop-gradient operation. The pre-training objective is to maximize the similarity between the embeddings. (C) To create augmented instances, three augmentation techniques are used in this work: random perturbations, atom masking, and edge masking. (D) In the pre-training stage, the encoder is trained using SSL. In the fine-tuning stage, the pre-trained weights are shared with the encoder (CGCNN), which is trained to predict the material property.

Results

To comprehensively evaluate the CT framework, we test its models on 13 challenging regression benchmark datasets and 1 classification dataset. The capabilities of the models in the CT framework are tested on a wide variety of properties, including exfoliation energy, frequency of the highest-frequency optical phonon mode peak, band gap, formation energy, refractive index, bulk modulus, shear modulus, Fermi energy, and metallicity. An overview of the datasets used for benchmarking the performance of the models in the CT framework is shown in Table 1. Of the 14 datasets, 9 (Table 2) come from the MatBench suite and the remaining 5 (Table 3) follow the datasets used in the previously published works on CGCNN29 and OGCNN30. More detailed descriptions of these datasets are available in the Supplementary Information.

Table 1 Overview of the datasets used for benchmarking the performance of the CT framework.
Table 2 Mean and standard deviation of test MAE of Crystal Twins (CT) in comparison to the supervised baselines on MatBench42 regression benchmarks.
Table 3 Mean and standard deviation of test MAE of Crystal Twins (CT) in comparison to the supervised baselines on 5 regression benchmarks.

Benchmarking the models on the MatBench Suite

The MatBench42 suite consists of multiple material property prediction datasets. In this work, we consider 9 datasets that have crystal structures as input for benchmarking our self-supervised learning models CTBarlow and CTSimSiam. We compare the results of our framework with the previously published supervised learning baselines available on MatBench. The protocols for benchmarking the performance of CTBarlow and CTSimSiam are exactly the same as those introduced in MatBench; we use nested 5-fold cross-validation to generate the results in Table 2. The detailed hyperparameters used for fine-tuning the models are listed in the Supplementary Information (Supplementary Table 3). We observe that the models trained using the SSL-based approach consistently outperform the supervised CGCNN baseline: improved results for the models in the CT framework over the CGCNN baseline are observed for 7 out of the 9 datasets, and for the Is Metal dataset the performance of the CT models is within the standard deviation of the supervised model. We also compare the performance of our SSL models with the AMMExpress42 model in the MatBench suite and observe that the models in the CT framework outperform AMMExpress on 6 out of the 9 datasets. Additionally, we benchmark our models against the state-of-the-art model for material property prediction, ALIGNN43. We observe that our model performs better than ALIGNN only for the classification task. It must be noted that ALIGNN achieves its high performance by modeling three-body interactions, whereas CGCNN models two-body interactions. Explicitly modeling three-body interactions gives ALIGNN more expressive power than CGCNN, making it a stronger baseline. Since we use CGCNN as the graph encoder in the CT framework, CTBarlow and CTSimSiam essentially model two-body interactions and are unable to compete with ALIGNN. The improvements demonstrated in our results over the supervised learning baselines CGCNN29 and AMMExpress42 show the promise of using SSL for learning representations of crystalline materials.

Benchmarking the models on additional datasets

Apart from benchmarking the performance on datasets from the MatBench suite, we also benchmark the performance of our models on additional datasets, similar to the previously published works OGCNN30 and CGCNN29. These datasets cover properties such as formation energy, band gap, and Fermi energy. As we pre-trained the model with the CGCNN encoder, the comparison with the CGCNN model is the most direct and fair, and it offers insights into how self-supervised learning methods can help predict crystalline material properties with a high degree of accuracy. We also compare the performance of the models in the CT framework with other popular supervised GNN models, namely GIN20 and OGCNN30, for the datasets in Table 3. We note that all the models used for comparison in Table 3 are trained with the same hyperparameters as suggested in their publicly available codes. The train/validation/test split for all the datasets is the same and set to 0.6/0.2/0.2, following previous standard benchmarking protocols; the splitting is performed randomly, as in the previously published works (a minimal example of such a split is sketched below). The test Mean Absolute Errors (MAEs) for the supervised training baselines and the models in the CT framework are shown in Table 3. The detailed hyperparameters used for the supervised models are listed in Supplementary Table 4.
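For readers who wish to reproduce this protocol, the snippet below is a minimal sketch of a 0.6/0.2/0.2 random split; the `dataset` list is a placeholder standing in for the actual crystal/label pairs and is not part of the original code.

```python
from sklearn.model_selection import train_test_split

# Placeholder standing in for the list of (crystal, property) pairs.
dataset = list(range(1000))

# 60% train, then split the remaining 40% equally into validation and test.
train, temp = train_test_split(dataset, test_size=0.4, random_state=0)
val, test = train_test_split(temp, test_size=0.5, random_state=0)
```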

We observe that the CT models outperform all supervised learning baselines on all 5 regression tasks. The performance improvements (Supplementary Table 1) achieved by the CT models over the baseline CGCNN model are non-trivial: an average improvement of 17.09% for CTBarlow and 21.83% for CTSimSiam compared to CGCNN. The results in Table 3 clearly demonstrate the merit of using self-supervised learning frameworks for periodic crystal property prediction. To test the generic nature of our SSL framework, we also implement GIN20 pre-trained via the Barlow Twins loss. We observe impressive performance gains for GINBarlow over the supervised GIN model, with an average improvement of 36.97%. The improvement in the case of GINBarlow indicates that the CT framework can be applied with other graph encoder architectures, and performance gains may be expected for those GNN models compared to their supervised counterparts.

Ablation study

To compare the effectiveness of the different augmentation techniques, we pre-train three CTBarlow models: (1) using only random perturbation augmentations (RP), (2) using only atom masking and edge masking augmentations (AM+EM), and (3) using all three augmentations, i.e., random perturbation, atom masking, and edge masking (RP+AM+EM). We report the MAE of each model on different fine-tuning datasets to determine the effectiveness of the augmentation techniques (Fig. 2).

Fig. 2: Ablation study of three augmentation techniques, random perturbation (RP), atom masking (AM), and edge masking (EM), for CTBarlow model.
figure 2

(A) Evaluating the effect of different augmentation techniques on the Band Gap and HOIP datasets, where the label is the band gap. (B) Evaluating the effect of augmentation techniques on the FE, Lanthanides, and Perovskites datasets, for which the label is the formation energy. (C) Evaluating the effect of different augmentation strategies on the prediction of the Fermi energy and the \({\log }_{10}\) VRH shear modulus of the structures. The error bars indicate the variation in MAE over 3 different runs.

The AM+EM augmentation performs better than RP for the perovskites, BG, and GVRH datasets, whereas the RP augmentation performs better than AM+EM for the Fermi energy, lanthanides, and HOIP datasets. For the FE dataset, the performance of the RP and AM+EM augmentation techniques is the same. It must be noted that the performance of the models trained with different augmentation techniques is almost identical, making it difficult to conclusively ascertain which augmentation technique is better. Moreover, the effectiveness of the augmentation techniques is dataset dependent. We also note that the standard deviation of the MAE is always lower when using the model pre-trained with all augmentation techniques. Therefore, using a combination of all three augmentation techniques is most effective.

Understanding the CT representations

To understand the CT representations, we visualize the representations from the pre-trained and fine-tuned CTBarlow framework in comparison to those of the CGCNN model in 2D using t-SNE52. The t-SNE projection maps the embeddings to 2D space based on their similarity. The comparison between the representations of the CGCNN model and the CTBarlow model for the perovskites dataset is shown in Fig. 3. Each point is colored by the formation energy of the perovskite, which is the label the model is trained on in the fine-tuning stage.

Fig. 3: Visualizing the embeddings space for the perovskites dataset using t-SNE.
figure 3

Every point on the t-SNE plot is colored according to the formation energy of the crystalline system. (A) t-SNE plot of the embeddings generated from the CGCNN model. (B) t-SNE plot of the embeddings generated from the graph encoder of the CT model after fine-tuning.

We observe that the t-SNE projection from the CTBarlow model exhibits better clustering than that of CGCNN (Fig. 3A): crystalline materials with higher formation energy are clustered at the top left of the t-SNE projection (Fig. 3B), and materials with lower formation energy are clustered at the bottom of the plot (Fig. 3B). For example, the perovskites InOsO3 and LaReO3, with relatively low formation energies of −0.58 and −0.64 eV/atom, respectively, are clustered more closely together in the t-SNE projection from CTBarlow than in that from CGCNN. This demonstrates the generalizability of the representations learned by the CTBarlow model compared to supervised learning. Such representations learned with the CT framework can also be used to characterize and understand the large chemical space of materials.
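As an illustration of how such a visualization can be produced, the sketch below projects encoder embeddings to 2D with scikit-learn's t-SNE and colors the points by formation energy. The `embeddings` and `formation_energy` arrays are random placeholders standing in for the quantities exported from the fine-tuned model; only the plotting recipe is shown.

```python
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Placeholders: (num_crystals, embed_dim) encoder embeddings and per-crystal labels.
embeddings = np.random.rand(500, 64)
formation_energy = np.random.uniform(-3.0, 1.0, size=500)

# Project the embeddings to 2D based on pairwise similarity.
coords = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(embeddings)

plt.scatter(coords[:, 0], coords[:, 1], c=formation_energy, cmap="viridis", s=8)
plt.colorbar(label="Formation energy (eV/atom)")
plt.title("t-SNE of CT embeddings (sketch)")
plt.show()
```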

Discussion

In this work, we develop Crystal Twins (CT), a generic SSL framework for crystalline material property prediction. In this framework, we propose two SSL strategies using twin graph neural networks that learn representations by leveraging the Barlow Twins and SimSiam losses during pre-training. The models in the CT framework (CTBarlow and CTSimSiam) achieve superior performance compared to other competitive supervised learning baselines. The models in the CT framework demonstrate high generalizability and robustness by learning representations that can be used to predict a variety of properties such as formation energy, band gap, Fermi energy, shear modulus, bulk modulus, and refractive index of different crystalline materials. The pre-training of models in the CT framework has been performed on significantly less data than SSL models in other domains such as molecular machine learning, computer vision, and natural language processing. In general, SSL models are known to perform better with larger unlabeled datasets, as these allow them to learn more generalizable representations. We therefore expect the models in the CT framework to demonstrate superior performance with larger pre-training data compared to our current results. The representations learned by the models in the CT framework are of great promise and can open up avenues for exciting research in understanding the chemical space and designing materials with desired properties.

Methods

In this section, we describe the components of the CT framework (Fig. 1). In general, SSL frameworks exploit correlations in the input itself to learn robust and generalizable representations from unlabeled data53. As part of the CT framework, we propose two different SSL pre-training models, namely CTBarlow and CTSimSiam. For CTBarlow, the goal during pre-training is to force the empirical cross-correlation matrix, computed from the encoder embeddings of two different augmentations generated from the same crystal, toward the identity matrix. All the elements in the cross-correlation matrix lie between −1 and 1, with 1 representing maximum correlation. Intuitively, since the embeddings are generated from augmentations of the same crystalline system, the cross-correlation matrix should be close to the identity matrix. For CTSimSiam, the pre-training objective is to maximize the cosine similarity between the encoder embeddings of augmented instances generated from the same crystalline system. To avoid model collapse, CTSimSiam adds an extra predictor head on one side of the twin networks and applies the stop-gradient technique on the other side during training. Using such objectives during pre-training allows the graph encoder to learn robust representations. To create the augmented instances, we use three augmentation techniques: atom masking, edge masking, and random perturbation (see Supplementary Information). The embeddings for the augmented instances of the crystalline system are generated via the CGCNN graph encoder. We pre-train the CGCNN model with the two SSL strategies using the Barlow Twins and SimSiam loss functions. The weights of the pre-trained self-supervised model are used to initialize the graph encoder during the fine-tuning stage for material property prediction.

Graph neural network encoder

Most recent successful deep learning approaches for crystalline material property prediction are based on GNNs because of their ability to capture structural geometry and chemistry. In a crystal graph G, the atoms are considered as nodes (V) and the interactions between them are modeled as edges (E). In general, GNNs aggregate information from the neighborhood of a node to construct embeddings that are updated iteratively. The GNN update can be described as in Eq. (1).

$${{{{\boldsymbol{h}}}}}_{v}^{(k)}={{{{\rm{COMBINE}}}}}^{(k)}\left({{{{\boldsymbol{h}}}}}_{v}^{(k-1)},\,{{{{\rm{AGGREGATE}}}}}^{(k)}\left(\{{{{{\boldsymbol{h}}}}}_{u}^{(k-1)}\,|\,u\in {{{\mathcal{N}}}}(v)\}\right)\right),$$
(1)

where \({{{{\boldsymbol{h}}}}}_{v}^{(k)}\) is the feature of node v at the k-th layer and \({{{{\boldsymbol{h}}}}}_{v}^{(0)}\) is initialized by the node feature xv. \({{{\mathcal{N}}}}(v)\) denotes the set of all neighbors of node v. The aggregation operation collects the features of the neighboring nodes into \({{{{\boldsymbol{a}}}}}_{v}^{(k)}\), the aggregated message at the k-th layer, and the combination operation combines the original node feature with the aggregated features. To extract the feature of the entire crystal system, hG, a readout operation integrates all the node features over the graph G, as given in Eq. (2).

$${{{{\boldsymbol{h}}}}}_{G}={{{\rm{READOUT}}}}\left(\{{{{{\boldsymbol{h}}}}}_{v}^{(k)}| v\in G\}\right).$$
(2)

Readout operations such as summation, averaging, and max pooling are the most commonly used54.
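The following is a minimal PyTorch sketch of the generic AGGREGATE/COMBINE update in Eq. (1) and the mean-pooling readout in Eq. (2). It illustrates the message-passing scheme only; it is not the actual CGCNN convolution, and all class and function names here are ours.

```python
import torch
import torch.nn as nn


class SimpleMessagePassingLayer(nn.Module):
    """One generic AGGREGATE/COMBINE step (Eq. 1); illustrative only."""

    def __init__(self, dim):
        super().__init__()
        self.combine = nn.Linear(2 * dim, dim)

    def forward(self, h, edge_index):
        # h: (num_nodes, dim) node features; edge_index: (2, num_edges) with
        # rows (source, target) listing neighbor pairs u -> v.
        src, dst = edge_index
        # AGGREGATE: sum neighbor features h_u for each target node v.
        agg = torch.zeros_like(h)
        agg.index_add_(0, dst, h[src])
        # COMBINE: merge the previous node feature with the aggregated message.
        return torch.relu(self.combine(torch.cat([h, agg], dim=-1)))


def readout(h, batch, num_graphs):
    """Mean-pool node features into one embedding per crystal graph (Eq. 2)."""
    hg = torch.zeros(num_graphs, h.size(-1), dtype=h.dtype, device=h.device)
    hg.index_add_(0, batch, h)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).unsqueeze(-1)
    return hg / counts
```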

In this work, we implement the CGCNN29 architecture as the GNN encoder. We choose CGCNN because of its competitive performance and computational efficiency compared to other GNN baselines. Moreover, CGCNN is one of the most widely benchmarked baseline models for material property prediction, allowing us to compare the performance of our SSL framework with CGCNN and other baselines. To encode crystal features and obtain an embedding, we use mean pooling to generate a latent representation with a dimension of 64. After the GNN encoder, a projection head with 2 MLP layers is attached to generate the final embedding on which the SSL loss functions are applied for pre-training. Additionally, we implement a general-purpose graph neural network, GIN20, to test the fidelity of the SSL methods on an architecture other than CGCNN.

To learn self-supervised representations, we need to construct different augmentations of the crystalline system. Inspired by AugLiChem55, we devise three different augmentation techniques (Fig. 1C), namely random perturbation, atom masking, and edge masking. The random perturbation augmentation perturbs each atom by a distance drawn from the uniform distribution between 0 Å and 0.05 Å. For atom masking, we randomly mask 10% of the atoms in the crystal; similarly, for edge masking, we randomly mask 10% of the edge features between two neighboring atoms. More details on atom masking and edge masking are provided in the Supplementary Information (Supplementary Fig. 1). These augmentations are applied to the crystalline systems, and two augmented instances are generated randomly on the fly at each epoch during pre-training. The augmented instances are fed into the GNN encoder to generate the embeddings on which the SSL loss functions are applied.
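The sketch below illustrates the three augmentations on plain NumPy arrays. It is a simplified re-implementation under our own assumptions: the mask-token value and the zeroing of edge features are chosen for illustration and are not taken from the CT codebase.

```python
import numpy as np

RNG = np.random.default_rng(0)


def random_perturbation(cart_coords, max_dist=0.05):
    """Displace every atom by a random vector whose length is drawn
    uniformly from [0, max_dist] Angstrom."""
    n = len(cart_coords)
    directions = RNG.normal(size=(n, 3))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    lengths = RNG.uniform(0.0, max_dist, size=(n, 1))
    return cart_coords + directions * lengths


def atom_masking(node_feats, mask_token, ratio=0.10):
    """Replace the features of a random 10% of atoms with a mask token
    (mask_token is an assumed placeholder value)."""
    feats = node_feats.copy()
    idx = RNG.choice(len(feats), size=max(1, int(ratio * len(feats))), replace=False)
    feats[idx] = mask_token
    return feats


def edge_masking(edge_feats, ratio=0.10):
    """Zero out the features of a random 10% of edges (assumed masking scheme)."""
    feats = edge_feats.copy()
    idx = RNG.choice(len(feats), size=max(1, int(ratio * len(feats))), replace=False)
    feats[idx] = 0.0
    return feats
```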

Barlow Twins loss

In the pre-training stage of CTBarlow, we use the Barlow Twins loss function to learn graph representations of crystals. This loss is based on the redundancy-reduction principle proposed by the neuroscientist H. Barlow56,57 and was introduced to SSL by Zbontar et al.8. We use the Barlow Twins loss function in CT because of its high performance and ease of implementation. Moreover, the Barlow Twins loss, unlike other contrastive loss functions, does not explicitly require positive and negative pairs for pre-training. The loss is applied to the cross-correlation matrix computed from the encoder-generated embeddings of two different augmentations of the same crystalline system. The Barlow Twins loss function is given by Eq. (3),

$${L}_{{{{\rm{BT}}}}}\,{\triangleq}\mathop{\sum}\limits_{i}{(1-{C}_{ii})}^{2}+\lambda \mathop{\sum}\limits_{i}\mathop{\sum}\limits_{j\ne i}{C}_{ij}^{2},$$
(3)

where C is the cross-correlation matrix of the embeddings from the two augmented instances, given by Eq. (4). The λ used in this work is 0.0051, the same as in the original paper.

$${C}_{ij}\,{\triangleq}\,\frac{{\sum }_{b}{{{{\boldsymbol{Z}}}}}_{b,i}^{A}{{{{\boldsymbol{Z}}}}}_{b,j}^{B}}{\sqrt{{\sum }_{b}{({{{{\boldsymbol{Z}}}}}_{b,i}^{A})}^{2}}\sqrt{{\sum }_{b}{({{{{\boldsymbol{Z}}}}}_{b,j}^{B})}^{2}}},$$
(4)

where b indexes the data in the batch and i, j index the vector dimensions of the projector outputs (ZA and ZB) for the two augmented instances A and B of the same crystalline material.
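A direct PyTorch transcription of Eqs. (3) and (4) is sketched below. Note, as an aside, that the reference Barlow Twins implementation additionally batch-normalizes the embeddings before computing C; this sketch follows the equations as written.

```python
import torch


def cross_correlation(z_a, z_b, eps=1e-9):
    """C_ij of Eq. (4): correlation between embedding dimensions i and j,
    computed over the batch dimension b. z_a, z_b: (batch, dim) tensors."""
    num = z_a.T @ z_b
    denom = torch.sqrt((z_a ** 2).sum(0)).unsqueeze(1) * torch.sqrt((z_b ** 2).sum(0)).unsqueeze(0)
    return num / (denom + eps)


def barlow_twins_loss(z_a, z_b, lam=0.0051):
    """L_BT of Eq. (3): pull diagonal terms toward 1, push off-diagonal terms toward 0."""
    c = cross_correlation(z_a, z_b)
    on_diag = (1.0 - torch.diagonal(c)).pow(2).sum()
    off_diag = c.pow(2).sum() - torch.diagonal(c).pow(2).sum()
    return on_diag + lam * off_diag
```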

SimSiam loss

We develop another variant of CT that uses the SimSiam12 loss, denoted CTSimSiam. In this case, an extra MLP head f( ⋅ ) is added after the GNN backbone to map the latent vector Z to P, namely P = f(Z). The distance between two vectors is defined in Eq. (5).

$${{{\mathcal{D}}}}({{{{\boldsymbol{P}}}}}_{b}^{A},{{{{\boldsymbol{Z}}}}}_{b}^{B})=-\frac{{{{{\boldsymbol{P}}}}}_{b}^{A}}{\parallel {{{{\boldsymbol{P}}}}}_{b}^{A}{\parallel }_{2}}\cdot \frac{{{{{\boldsymbol{Z}}}}}_{b}^{B}}{\parallel {{{{\boldsymbol{Z}}}}}_{b}^{B}{\parallel }_{2}},$$
(5)

where b denotes the index of the data in a batch and A, B denote the two augmented instances from the same crystalline material. The objective minimized over a batch is given in Eq. (6).

$${L}_{{{{\rm{SimSiam}}}}}\,{\triangleq}\,\frac{1}{2}\mathop{\sum}\limits_{b}\left({{{\mathcal{D}}}}({{{{\boldsymbol{P}}}}}_{b}^{A},{\mathtt{stopgrad}}({{{{\boldsymbol{Z}}}}}_{b}^{B}))+{{{\mathcal{D}}}}({{{{\boldsymbol{P}}}}}_{b}^{B},{\mathtt{stopgrad}}({{{{\boldsymbol{Z}}}}}_{b}^{A}))\right),$$
(6)

where \({\mathtt{stopgrad}}(\cdot )\) means the gradient is not back-propagated through that branch of the CTSimSiam model. This asymmetric architecture and the stop-gradient operation prevent the collapse of the learned representations.
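A minimal PyTorch sketch of Eqs. (5) and (6) follows; `.detach()` plays the role of the stop-gradient, and the projector outputs `z_*` and predictor outputs `p_*` are assumed to be precomputed tensors of shape (batch, dim).

```python
import torch
import torch.nn.functional as F


def negative_cosine(p, z):
    """D(P, Z) in Eq. (5): negative cosine similarity, one value per sample."""
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(p * z).sum(dim=-1)


def simsiam_loss(p_a, z_a, p_b, z_b):
    """Symmetrized SimSiam objective (Eq. 6); .detach() implements stopgrad."""
    loss = 0.5 * (negative_cosine(p_a, z_b.detach()) + negative_cosine(p_b, z_a.detach()))
    return loss.sum()  # Eq. (6) sums over the batch; a mean works equally well in practice
```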

Training details

In the pre-training stage, the embedding dimension of the CGCNN encoder is set to 128 for the CTBarlow model and 256 for the CTSimSiam model. We use the Adam optimizer58 with a learning rate of 0.00001 and a batch size of 64 and pre-train the models for 15 epochs. The other hyperparameters of the CGCNN (graph encoder) model are kept the same as in the original paper. The train/validation ratio for the pre-training data is 95%/5%. For pre-training, we combine the datasets from the Matminer database59 and the hypothetical Metal-Organic Framework dataset60, aggregating a total of 428,275 samples; the labels in these datasets are not used during CT pre-training. Additional details about the hyperparameters used during the pre-training stage are available in Supplementary Table 2. In the fine-tuning stage, we add a randomly initialized MLP head with two fully connected layers to generate the final property prediction. CTBarlow and CTSimSiam are tested on a variety of datasets, including datasets from the previously published work OGCNN30 and the MatBench suite42. Additional details about the hyperparameters used during the fine-tuning stage are given in Supplementary Table 3.
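The sketch below strings these settings together under our own assumptions: set up Adam with the quoted hyperparameters for SSL pre-training of an encoder and projector, save the encoder weights, and re-load them before attaching a fresh two-layer MLP head for fine-tuning. The tiny `PlaceholderEncoder` is a stand-in for CGCNN, and the checkpoint file name is arbitrary.

```python
import torch
import torch.nn as nn

# Hyperparameters quoted in the text; everything else below is a placeholder.
LR, BATCH_SIZE, EPOCHS = 1e-5, 64, 15
EMBED_DIM = 128  # CTBarlow setting (256 for CTSimSiam)


class PlaceholderEncoder(nn.Module):
    """Stand-in for the CGCNN graph encoder, used only to show weight transfer."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


# Pre-training: encoder + projector trained with an SSL loss on unlabeled crystals.
encoder = PlaceholderEncoder(EMBED_DIM)
projector = nn.Sequential(nn.Linear(EMBED_DIM, EMBED_DIM), nn.ReLU(), nn.Linear(EMBED_DIM, EMBED_DIM))
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(projector.parameters()), lr=LR)
# ... SSL pre-training loop over the unlabeled crystals would go here ...
torch.save(encoder.state_dict(), "ct_pretrained_encoder.pt")

# Fine-tuning: load pre-trained encoder weights and attach a fresh 2-layer MLP head.
ft_encoder = PlaceholderEncoder(EMBED_DIM)
ft_encoder.load_state_dict(torch.load("ct_pretrained_encoder.pt"))
head = nn.Sequential(nn.Linear(EMBED_DIM, 64), nn.ReLU(), nn.Linear(64, 1))
ft_optimizer = torch.optim.Adam(list(ft_encoder.parameters()) + list(head.parameters()), lr=LR)
```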