As technology continues to drive biomedical research forward, new challenges arise from the surge of high-volume, information-dense and multivariate data being generated. Extracting critical information from such data remains an open problem in biomedical research, one that can be significantly aided by machine learning techniques. In particular, unsupervised learning methods have the potential to uncover the underlying structure in biomedical data and therefore to propel research on biological processes and diseases that are not yet fully understood.

This paper aims to build unsupervised neural methods that can be applied to understand cell differentiation using gene expression data. Recent single-cell RNA sequencing (scRNA-seq) technology enables high-throughput experiments that measure gene expression levels for individual cells in a population, achieving a granularity not previously possible. In-depth analysis of these data can lead to important biomedical discoveries about the factors influencing the differentiation process. However, such analysis is challenging because the data are high dimensional, with thousands of gene expression measurements for each cell, and highly complex.

We propose using a disentangled generative probabilistic framework to model single-cell RNA sequencing gene expression data and to build a low-dimensional representation that can help us discover latent biological mechanisms. Within this framework, we develop a novel methodology for identifying the different cell types in single-cell RNA-seq datasets from the learned latent representations. We show that we can correctly identify the different cell types in two datasets: one consisting of hematopoietic stem and differentiated cells in zebrafish1, obtained using Smart-Seq2, and another consisting of human pancreatic cells2, obtained using CEL-Seq2. We also show and discuss some limitations of our methods on a dataset of human hematopoietic cells3. In addition, we explore perturbing the latent representation to study how stem or progenitor cells can be changed into differentiated cells. Moreover, we propose a graph representation learning method, based on an autoencoder with graph convolutional layers, that can be used to analyse links between single cells.

Current methods for analysing single-cell RNA-seq data combine dimensionality reduction techniques with clustering algorithms, at either the gene or the cell level. The identification of cell lineages and trajectories is one of the main areas where scRNA-seq has had a great influence. The most widespread computational tools include Waterfall and Wishbone4,5, which are based on principal component analysis (PCA). Monocle uses independent component analysis (ICA), and SCUBA pseudotime relies on t-distributed stochastic neighbour embedding (t-SNE)6,7. However, some of these methods, particularly those based on linear approaches such as PCA, cannot capture complex non-linear relationships between the input dimensions and can discard meaningful information in the data. In addition, Yeung and Ruzzo8 showed that applying PCA before clustering gene expression data can degrade the quality of the clusters. Despite these findings, much research in gene expression analysis9,10,11 still applies PCA before clustering cells to identify their types.

Autoencoders can be used both to perform non-linear dimensionality reduction and to extract biologically relevant latent features from transcriptomics data. Related work shows the effectiveness of these models in analysing gene expression data. Way and Greene12 trained a variational autoencoder on pan-cancer RNA-seq data from The Cancer Genome Atlas13 to explore the biological relevance of the latent space produced by the autoencoder. Tan et al.14 built a denoising autoencoder capable of modelling the response of Pseudomonas aeruginosa cells to low oxygen and of finding differences in gene expression between strains. Eraslan et al.15 used autoencoders for denoising, developing a method that scales linearly with the number of cells and outperforms existing methods for data imputation. Talwar et al.16 proposed an autoencoder-based method for gene expression imputation, while Wang and Gu17 used variational autoencoders for dimensionality reduction and visualization of single-cell data. Finally, Rashid et al.18 used a variational autoencoder to identify tumour subpopulations, marker genes and differentiation trajectories of malignant cells from scRNA-seq data. Unlike our proposed models, these methods either do not use a training objective that enforces disentanglement in the latent representation12,18 or focus on different applications, such as denoising14,15, visualization17 and missing data imputation16.

To the best of our knowledge, this work represents the first application of disentanglement, perturbation and graph-based methods for variational autoencoders to the analysis of cell differentiation using single-cell RNA-seq data. We emphasise the importance of building interpretable models by analysing the relationship between the embedding and gene expression spaces. We also explore the robustness and variability of the latent space by introducing perturbations. Graph representation learning is a powerful, recent family of methods for learning from graph-structured data, and we show how predicting links between cells can provide insights into differentiation trajectories.

Disentangled generative probabilistic framework

We propose using a generative probabilistic framework19 to model the biological processes that lead to the changes in the observed gene expression of cells at different stages of the differentiation process. Let \({\mathscr{D}}={\{{{\bf{x}}}^{(i)}\}}_{i=1}^{N}\) be a high-dimensional single-cell RNA-seq dataset consisting of the gene expression of N i.i.d. cells. Each gene expression vector x(i) is an observation of a continuous random variable x with distribution pdata(x). The gene expression data are assumed to be generated by some random process, modelled by an unobserved continuous random variable z with parametrised prior distribution pθ(z). The marginal likelihood pθ(x), also known as the evidence, is computed by integrating over the possible latent representations:

$${p}_{{\boldsymbol{\theta }}}({\bf{x}})={\int }_{{\bf{z}}\in {\mathscr{Z}}}\,{p}_{{\boldsymbol{\theta }}}({\bf{x}},{\bf{z}}){\rm{d}}{\bf{z}}={\int }_{{\bf{z}}\in {\mathscr{Z}}}\,{p}_{\theta }({\bf{x}}|{\bf{z}}){p}_{\theta }({\bf{z}}){\rm{d}}{\bf{z}}.$$

Computing this integral requires integrating over all possible values of z, which is often intractable. Inference requires the posterior \({p}_{{\boldsymbol{\theta }}}({\bf{z}}|{\bf{x}})=({p}_{{\boldsymbol{\theta }}}({\bf{x}}|{\bf{z}}){p}_{\theta }({\bf{z}}))/{p}_{{\boldsymbol{\theta }}}({\bf{x}})\), which is also intractable because it depends on the marginal likelihood.

To learn in such a framework, we use variational inference and approximate the posterior with the variational distribution qϕ(z|x). We thus build a variational autoencoder model19 and approximate qϕ(z|x) with a multivariate Gaussian \({\mathscr{N}}({\bf{z}};\mu ,{\rm{diag}}({{\boldsymbol{\sigma }}}^{2}))\) with mean μ and variance σ2. An encoder neural network is trained to output the parameters μ and σ of qϕ(z|x). In addition, an isotropic multivariate Gaussian prior is assigned to the latent representation: \({p}_{{\boldsymbol{\theta }}}({\bf{z}})={\mathscr{N}}({\bf{z}};{\bf{0}},{\bf{I}})\). The decoder neural network is trained to reconstruct (generate) the gene expression data from the latent representation and thus estimate pθ(x|z). See Fig. 1a for a graphical illustration of the model. The training objective of the standard variational autoencoder model19 penalises the mutual information between the input gene expression and the latent representation20, and it does not encourage disentanglement in the latent representation21. Disentanglement is desirable in our case because, ideally, the latent representation z should separate the biological factors that have led to the development of the various cell types.
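For reference, the objective of the standard variational autoencoder19 is the evidence lower bound (ELBO) in the first line below. Averaging its KL term over the data distribution gives the decomposition in the second line (a known identity), where \({q}_{{\boldsymbol{\phi }}}({\bf{z}})\) is the aggregated posterior and \({I}_{q}({\bf{x}};{\bf{z}})\) is the mutual information between x and z under \({p}_{{\rm{data}}}({\bf{x}}){q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})\). It makes explicit the sense in which the KL term penalises mutual information:

$$\log {p}_{{\boldsymbol{\theta }}}({\bf{x}})\ge {{\mathbb{E}}}_{{q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})}[\log \,{p}_{{\boldsymbol{\theta }}}({\bf{x}}|{\bf{z}})]-{\rm{KL}}({q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})\Vert {p}_{{\boldsymbol{\theta }}}({\bf{z}})),$$

$${{\mathbb{E}}}_{{p}_{{\rm{data}}}({\bf{x}})}[{\rm{KL}}({q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})\Vert {p}_{{\boldsymbol{\theta }}}({\bf{z}}))]={I}_{q}({\bf{x}};{\bf{z}})+{\rm{KL}}({q}_{{\boldsymbol{\phi }}}({\bf{z}})\Vert {p}_{{\boldsymbol{\theta }}}({\bf{z}})),\quad {q}_{{\boldsymbol{\phi }}}({\bf{z}})={{\mathbb{E}}}_{{p}_{{\rm{data}}}({\bf{x}})}[{q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})].$$

An objective that keeps only a divergence between the aggregated posterior and the prior, as the DiffVAE objective introduced below does, contains no explicit mutual information penalty.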

Figure 1

Pipeline for identifying the cell types in a dataset using DiffVAE, illustrated on the zebrafish dataset. (a) Train DiffVAE to map the gene expression measurements for each cell to an m-dimensional latent representation z. (b) Apply t-SNE to the latent representation z and cluster the result to find the different cell clusters in the dataset. (c) Identify which latent dimensions in z encode the differentiation of the cells in each cluster. (d) Find the high weight genes for the relevant latent dimensions. (e) Map the clusters to cell types based on the high weight genes for each cluster.

We introduce DiffVAE, a variational autoencoder that can be used to model and study the differentiation of cells using gene expression data. DiffVAE is an MMD-VAE, a member of the InfoVAE family of autoencoders21, and is trained to maximize the following objective:

$${{\mathscr{L}}}_{{\rm{DiffVAE}}}({\boldsymbol{\theta }},{\boldsymbol{\phi }};{\bf{x}})={{\mathbb{E}}}_{{q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})}[\log \,{p}_{{\boldsymbol{\theta }}}({\bf{x}}|{\bf{z}})]-{\rm{MMD}}({q}_{{\boldsymbol{\phi }}}({\bf{z}})\Vert {p}_{{\boldsymbol{\theta }}}({\bf{z}})),$$

where \({{\mathbb{E}}}_{{q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})}[\log \,{p}_{{\boldsymbol{\theta }}}({\bf{x}}|{\bf{z}})]\) measures reconstruction accuracy, and the maximum mean discrepancy (MMD)22,23,24 between the aggregated posterior \({q}_{{\boldsymbol{\phi }}}({\bf{z}})={{\mathbb{E}}}_{{p}_{{\rm{data}}}({\bf{x}})}[{q}_{{\boldsymbol{\phi }}}({\bf{z}}|{\bf{x}})]\) and the prior pθ(z) measures how different the two distributions are. The intuition behind the MMD is that it compares all moments of the two distributions at once through a kernel; with a characteristic kernel such as the Gaussian kernel, the MMD is zero if and only if the two distributions are equal. Zhao et al.21 show that this training objective always prefers to maximize the mutual information between the input and the latent representation. Moreover, minimising MMD(qϕ(z)||pθ(z)) encourages qϕ(z) to be similar to the prior \({p}_{{\boldsymbol{\theta }}}({\bf{z}})={\mathscr{N}}({\bf{z}};{\bf{0}},{\bf{I}})\), whose covariance matrix is diagonal, which encourages disentanglement across the latent dimensions. Using the MMD to measure the discrepancy between distributions achieves the best performance in the InfoVAE model family21, and the MMD has also given good results in the training objectives of other autoencoder models23,25,26. For further analysis, we use the mean μ of the distribution q(z|x) learnt by DiffVAE as the latent representation z of each cell. Because q(z|x) is Gaussian, its mean is also its mode, that is, the most probable latent value for the cell under the approximate posterior learnt by DiffVAE.
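As an illustration, the MMD term can be estimated from a minibatch of latent codes and an equally sized sample from the prior N(0, I). The NumPy sketch below uses the biased (V-statistic) estimator with a Gaussian kernel; the kernel choice and the bandwidth heuristic (square root of the latent dimensionality) are our assumptions, not necessarily the settings used in DiffVAE.

```python
import numpy as np


def gaussian_kernel(a, b, bandwidth):
    """Kernel matrix with entries exp(-||a_i - b_j||^2 / (2 * bandwidth^2))."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))


def mmd(z_q, z_p, bandwidth=None):
    """Biased estimate of the squared MMD between samples z_q and z_p (rows are samples)."""
    if bandwidth is None:
        bandwidth = np.sqrt(z_q.shape[1])  # heuristic: scale with latent dimensionality
    k_qq = gaussian_kernel(z_q, z_q, bandwidth).mean()
    k_pp = gaussian_kernel(z_p, z_p, bandwidth).mean()
    k_qp = gaussian_kernel(z_q, z_p, bandwidth).mean()
    return k_qq + k_pp - 2.0 * k_qp


rng = np.random.default_rng(0)
prior = rng.standard_normal((500, 10))          # samples from N(0, I)
matched = rng.standard_normal((500, 10))        # codes distributed like the prior
shifted = rng.standard_normal((500, 10)) + 2.0  # codes far from the prior
print(mmd(matched, prior), mmd(shifted, prior))
```

In a training loop, z_q would be the encoder outputs for a minibatch and the MMD value would be subtracted from the reconstruction term. Here the shifted codes give a much larger MMD than the codes that already follow the prior.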

The encoder and decoder networks of DiffVAE each consist of two fully connected layers. See the Methods section for more details about the DiffVAE model. The models in this paper were implemented in Python using Keras27.

Identifying cell types using DiffVAE

Data details and pre-processing

The unsupervised models and the methodology developed in this paper are used to analyse single-cell gene expression data from hematopoietic stem and differentiated cells in zebrafish1 and in humans3, and from human pancreatic cells2. Let a scRNA-seq dataset be denoted as \({\mathscr{D}}={\{{{\bf{x}}}^{(i)}\}}_{i=1}^{N}\), where \({{\bf{x}}}^{(i)}={[{x}_{1}^{(i)}{x}_{2}^{(i)}\ldots {x}_{k}^{(i)}]}^{T}\) consists of the transcriptomics data for cell (i). The zebrafish dataset consists of k = 1845 gene expression measurements from N = 1422 cells; we used the same 1845 genes identified by Athanasiadis et al.1 as the most highly variable ones among the 1422 zebrafish single cells. The dataset with human pancreatic cells consists of N = 2285 cells with measurements of the k = 4000 most highly variable genes. The dataset with human hematopoietic cells contains N = 1034 cells with measurements of the k = 700 most highly variable genes. In all cases, we consider the cell states to be initially unknown and show how the methodology developed in this paper can be used to identify them. Note that some of the results on these datasets are presented in the supplementary materials.

The transcriptomics data used are log-normalized. To use these data as input to DiffVAE, we additionally applied Min-Max scaling so that the expression values of the genes lie in the range [0, 1]. This allows us to model the gene expression of each cell as a multivariate Bernoulli distribution in our probabilistic framework.
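A minimal NumPy sketch of this preprocessing is given below, assuming that Min-Max scaling is applied per gene (column) and that the decoder output x_hat lies in (0, 1), e.g. after a sigmoid. The Bernoulli log-likelihood is the negative of the binary cross-entropy and plays the role of log pθ(x|z) in the objective; the toy Poisson counts stand in for real data.

```python
import numpy as np


def min_max_scale(expr, eps=1e-12):
    """Scale each gene (column) of a cells x genes matrix to the range [0, 1]."""
    lo = expr.min(axis=0)
    hi = expr.max(axis=0)
    return (expr - lo) / np.maximum(hi - lo, eps)  # constant genes map to 0


def bernoulli_log_likelihood(x, x_hat, eps=1e-7):
    """Per-cell sum over genes of x*log(x_hat) + (1 - x)*log(1 - x_hat)."""
    x_hat = np.clip(x_hat, eps, 1.0 - eps)  # avoid log(0)
    return np.sum(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat), axis=1)


rng = np.random.default_rng(1)
log_expr = np.log1p(rng.poisson(3.0, size=(8, 5)).astype(float))  # toy log-normalized data
x = min_max_scale(log_expr)
ll_perfect = bernoulli_log_likelihood(x, x)
ll_uninformative = bernoulli_log_likelihood(x, np.full_like(x, 0.5))
```

For each gene, the log-likelihood is maximised when the reconstruction equals the scaled input, so a perfect reconstruction scores at least as well as an uninformative output of 0.5 everywhere.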

Pipeline for identifying the cell types

In this section, we describe how DiffVAE can be used to find the different cell types in each dataset. Figure 1 shows the methodological pipeline for this process, and the following subsections describe each step in detail.

Using DiffVAE to obtain cell clusters

DiffVAE was trained to map the gene expression data of the single cells to a latent representation with m dimensions (Fig. 1a). For the datasets with hematopoietic cells (both zebrafish and human), we used m = 50 latent dimensions, while for the dataset with human pancreatic cells we used m = 100. We use a relatively large number of latent dimensions in order to capture the complex biological processes influencing cell differentiation. To visualize the data and identify the different cell populations, we then use t-distributed stochastic neighbour embedding (t-SNE)28 to obtain a 2-dimensional embedding for each cell. In the zebrafish dataset, k-means clustering is applied to the t-SNE embedding to obtain 5 cell clusters (Fig. 1b). For the datasets with human cells, we used DBSCAN clustering. The following subsections describe the methodology for mapping each cluster to a cell type.
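The clustering step in Fig. 1b can be sketched with scikit-learn, assuming the latent means μ are available as a cells × m array. The perplexity, the DBSCAN parameters and the synthetic blobs used as stand-in latent codes are illustrative assumptions, not the settings used in this paper.

```python
from sklearn.cluster import DBSCAN, KMeans
from sklearn.datasets import make_blobs
from sklearn.manifold import TSNE
from sklearn.metrics import adjusted_rand_score


def cluster_latent(z, n_clusters=5, method="kmeans", eps=3.0, min_samples=10, seed=0):
    """Embed latent codes z (cells x m) in 2-D with t-SNE, then cluster the embedding."""
    emb = TSNE(n_components=2, perplexity=30.0, random_state=seed).fit_transform(z)
    if method == "kmeans":
        labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(emb)
    elif method == "dbscan":
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(emb)  # -1 marks noise
    else:
        raise ValueError(f"unknown method: {method}")
    return emb, labels


# Stand-in for DiffVAE latent means: 200 cells, m = 50, five well-separated groups.
z, true_groups = make_blobs(n_samples=200, n_features=50, centers=5, random_state=0)
emb, labels = cluster_latent(z, n_clusters=5)
print(adjusted_rand_score(true_groups, labels))
```

k-means needs the number of clusters in advance, whereas DBSCAN infers it but depends on eps, whose sensible value depends on the scale of the t-SNE embedding.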

Latent dimensions encoding cell differentiation

DiffVAE was designed to model the data generating process that gives rise to the observations in our dataset \({\mathscr{D}}\). Thus, it should be able to capture the biological mechanisms that result in the observed gene expression values of our cells. Consider a latent dimension k. Let \({{\bf{z}}}_{k}={[{z}_{k}^{(1)}{z}_{k}^{(2)}\ldots {z}_{k}^{(N)}]}^{T}\) be the vector of encoder outputs for zk across all N cells in the dataset, and let μk and σk be the mean and standard deviation of zk. We define:

$${{\mathscr{D}}}_{k}=\{{{\bf{x}}}^{(i)}\in {\mathscr{D}}|{z}_{k}^{(i)}\ge {\mu }_{k}+{\sigma }_{k}\vee {z}_{k}^{(i)}\le {\mu }_{k}-{\sigma }_{k}\}$$

as the set of cells at least one standard deviation away from the mean in latent dimension k. By computing the percentage distribution of the cells in \({{\mathscr{D}}}_{k}\) across the cell clusters found in the dataset, we can evaluate how well the latent dimension encodes the differentiation of the cells in a particular cluster (Fig. 1c). Thus, for each cluster C, we compute the percentage of cells in each \({{\mathscr{D}}}_{k},k\in \{1,2,\ldots ,m\}\) that belong to cluster C. The latent dimensions relevant for the differentiation of cells in cluster C are the 10 dimensions with the highest percentage of cells from cluster C in \({{\mathscr{D}}}_{k}\).
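The selection of relevant latent dimensions can be sketched in NumPy as follows, assuming z holds the latent means (cells × m) and clusters holds one integer cluster label per cell. The function names are hypothetical, and the toy data deliberately make dimension 3 extreme for cluster 1.

```python
import numpy as np


def extreme_cells(z, k):
    """Mask of cells at least one standard deviation from the mean in latent dimension k (D_k)."""
    mu, sigma = z[:, k].mean(), z[:, k].std()
    return (z[:, k] >= mu + sigma) | (z[:, k] <= mu - sigma)


def relevant_dimensions(z, clusters, cluster_id, top=10):
    """Rank latent dimensions by the percentage of cells in D_k that belong to cluster_id."""
    pct = np.zeros(z.shape[1])
    for k in range(z.shape[1]):
        mask = extreme_cells(z, k)
        if mask.any():
            pct[k] = 100.0 * np.mean(clusters[mask] == cluster_id)
    order = np.argsort(-pct, kind="stable")  # highest percentage first
    return order[:top], pct


rng = np.random.default_rng(0)
z = rng.standard_normal((300, 20))
clusters = np.repeat([0, 1, 2], 100)
z[clusters == 1, 3] += 5.0  # dimension 3 separates cluster 1
dims, pct = relevant_dimensions(z, clusters, cluster_id=1)
print(dims[:3], pct[3])
```

With three equally sized clusters, an uninformative dimension contains roughly one third of cells from each cluster, so a percentage well above that level flags a dimension as relevant.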

Identifying high weight genes

The decoder in DiffVAE learns to reconstruct the original gene expression data, and therefore its weights indicate how strongly each latent dimension contributes to the reconstructed expression of each gene. By finding the high weight connections between the latent dimensions relevant for each cell cluster and the reconstructed gene expression, we can identify marker genes for each cluster, which in turn helps us identify the cell types.

The high weight connections can be obtained from the weight matrices of the decoder, which has two fully connected hidden layers. Let \({\bf{z}}\in {{\mathbb{R}}}^{m},{{\bf{h}}}^{\mathrm{(1)}}\in {{\mathbb{R}}}^{{n}_{1}},{{\bf{h}}}^{\mathrm{(2)}}\in {{\mathbb{R}}}^{{n}_{2}},{{\bf{x}}}^{\prime }\in {{\mathbb{R}}}^{n}\) be the sequence of layer activations in the decoder, where the latent representation z is the input, h(1) and h(2) are the hidden layers and x′ is the output. The weight matrices for the connections between the layers in the decoder are \({{\bf{W}}}^{\mathrm{(0)}}\in {{\mathbb{R}}}^{m\times {n}_{1}},{{\bf{W}}}^{\mathrm{(1)}}\in {{\mathbb{R}}}^{{n}_{1}\times {n}_{2}},{{\bf{W}}}^{\mathrm{(2)}}\in {{\mathbb{R}}}^{{n}_{2}\times n}\). Let \({\boldsymbol{\omega }}\in {{\mathbb{R}}}^{m\times n}\) be the weight matrix for the connections between the latent representation and the output. ω can be computed by multiplying the weight matrices of the individual fully connected layers: \({\boldsymbol{\omega }}={{\bf{W}}}^{\mathrm{(0)}}\cdot {{\bf{W}}}^{\mathrm{(1)}}\cdot {{\bf{W}}}^{\mathrm{(2)}}\), where the matrix element ωij indicates the weight of the connection between latent dimension i and gene j. Since this product ignores the biases and non-linear activations, ω is a linear summary of the decoder rather than an exact description of its mapping. For each latent dimension, the genes are sorted by the absolute value of their weight, and the genes with the highest absolute weights are referred to as the high weight genes (Fig. 1d).
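Given the three decoder weight matrices, ω and the high weight genes can be computed as in the NumPy sketch below. If the decoder is a stack of Keras Dense layers, each layer's kernel (the first array returned by layer.get_weights()) has shape (inputs, units), which matches the orientation of W(0), W(1) and W(2) here; the small matrices and gene names are made up for illustration.

```python
import numpy as np


def latent_to_gene_weights(W0, W1, W2):
    """omega = W0 @ W1 @ W2: an (m x n) linear summary of the decoder (biases and activations ignored)."""
    return W0 @ W1 @ W2


def high_weight_genes(omega, dim, gene_names, top=10):
    """Names of the genes with the largest |omega[dim, j]| for latent dimension dim."""
    order = np.argsort(-np.abs(omega[dim]), kind="stable")[:top]
    return [gene_names[j] for j in order]


W0 = np.array([[1.0, 0.0], [0.0, 1.0]])              # m = 2 -> n1 = 2
W1 = np.array([[2.0, 0.0], [0.0, 1.0]])              # n1 = 2 -> n2 = 2
W2 = np.array([[0.1, -3.0, 0.5], [1.0, 0.0, -0.2]])  # n2 = 2 -> n = 3 genes
omega = latent_to_gene_weights(W0, W1, W2)
print(high_weight_genes(omega, 0, ["geneA", "geneB", "geneC"], top=2))  # ['geneB', 'geneC']
```

Sorting by absolute value keeps genes with large negative weights, which the latent dimension pushes down rather than up.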

For each cluster, we selected the latent dimensions that best distinguished the cells in that cluster and then computed the high weight genes. The high weight genes found for the clusters in the zebrafish dataset are given in Table 1. Using knowledge from the biomedical literature about marker genes for blood cells, we mapped each cluster to a cell type: Cluster 1 corresponds to HSPCs, Cluster 2 to Neutrophils, Cluster 3 to Monocytes, Cluster 4 to Erythrocytes and Cluster 5 to Thrombocytes. The same process was used to map the clusters to cell types in the dataset with human pancreatic cells; see Supplementary Table 1 for the high weight genes found for the clusters in this dataset and their mapping to cell types.

Table 1 Zebrafish.

Our results for identifying the different cell types in the zebrafish dataset are consistent with those of Athanasiadis et al.1, who computationally reconstructed the differentiation trajectories using the Monocle2 algorithm29 and found the same cellular states. In particular, there is an 89.9% overlap between the cell types identified using DiffVAE and those obtained by Athanasiadis et al.1. Similarly, for the dataset with human pancreatic cells, there is a 96.2% overlap between the cell types obtained using DiffVAE and those reported by Muraro et al.2. DiffVAE identified all the cell types in this dataset except for the epsilon cells. However, there are only 4 epsilon cells in the dataset, and Muraro et al.2 also did not identify them computationally, but rather based on the expression of the GHRL gene. See Supplementary Fig. 1 for the clusters found using DiffVAE on the dataset with human pancreatic cells.

The representations built by DiffVAE for the human hematopoietic cells do not display separable clusters, which makes it difficult to identify all of the cell types. Velten et al.3 also report that the hematopoietic stem cells, multipotent progenitors and multilymphoid progenitors form a single continuous group when clustering methods are applied to the dataset. Refer to Supplementary Fig. 2 and the corresponding section for a discussion of the limitations of DiffVAE in this case and directions for future work.

Characterization of cell states

For the zebrafish dataset, we also explored the possibility of changing the state of cells through perturbations of the latent representation. This could help us learn more about the biological changes in gene expression that cause a less specialised cell, such as an HSPC, to differentiate into a more specialised cell, such as a Monocyte. For this purpose, we trained a neural network classifier that labels Monocytes, Neutrophils, Erythrocytes and Thrombocytes from the full gene expression data with 99.5% accuracy. See the Methods section for more details.

Assume that we have identified that latent dimension j encodes the differentiation of a type of mature blood cell, such as Monocytes. Let μj and σj be the mean and standard deviation of \({{\bf{z}}}_{j}={[{z}_{j}^{\mathrm{(1)}}{z}_{j}^{\mathrm{(2)}}\ldots {z}_{j}^{(N)}]}^{T}\), the vector of encoder outputs for zj across all of the cells in the dataset. If latent dimension j identifies Monocytes, then the proportion of Monocytes in \({{\mathscr{D}}}_{j}=\{{{\bf{x}}}^{(i)}\in {\mathscr{D}}|{z}_{j}^{(i)}\ge {\mu }_{j}+{\sigma }_{j}\vee {z}_{j}^{(i)}\le {\mu }_{j}-{\sigma }_{j}\}\) is larger than that of the other cell types. This suggests that shifting \({z}_{j}^{(i)}\) by the standard deviation σj of latent dimension j could change the label of cell x(i) to Monocyte.

The proposed method for changing a less specialised cell (an HSPC) into a Monocyte involves shifting several of the latent dimensions encoding the differentiation of Monocytes in proportion to their standard deviations, with proportionality factor λ. The method is illustrated in Fig. 2 and can be generalised to any of the mature cell types that the embedding can separate.
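The perturbation step can be sketched as below, assuming sigma holds the per-dimension standard deviations computed over all cells and dims lists the latent dimensions selected for the target cell type. decode and classify are hypothetical stand-ins for the DiffVAE decoder and the cell-type classifier, and the optional signs argument is our addition for dimensions in which the target cells lie in the lower tail.

```python
import numpy as np


def perturb_latent(z_cells, sigma, dims, lam, signs=None):
    """Shift latent dimensions `dims` of every cell by lam * sigma[dims] (optionally signed)."""
    dims = np.asarray(dims)
    signs = np.ones(len(dims)) if signs is None else np.asarray(signs, dtype=float)
    z_new = np.array(z_cells, dtype=float, copy=True)
    z_new[:, dims] += lam * signs * sigma[dims]
    return z_new


def fraction_converted(z_hspc, sigma, dims, lam, decode, classify, target):
    """Fraction of perturbed HSPC codes whose decoded expression is classified as `target`."""
    y = decode(perturb_latent(z_hspc, sigma, dims, lam))
    return float(np.mean(classify(y) == target))


# Toy stand-ins: identity decoder and a classifier that looks at a single feature.
decode = lambda z: z
classify = lambda x: np.where(x[:, 2] > 1.0, "Monocyte", "HSPC")
z_hspc = np.zeros((4, 3))
sigma = np.array([1.0, 2.0, 3.0])
print(fraction_converted(z_hspc, sigma, [0, 2], 0.2, decode, classify, "Monocyte"))
print(fraction_converted(z_hspc, sigma, [0, 2], 0.5, decode, classify, "Monocyte"))
```

Sweeping lam, for example over 0.5 and 1.0 as in Fig. 3, shows how the converted fraction grows with the size of the shift.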

Figure 2

Methodology proposed for changing cellular states: HSPCs can be converted into Monocytes by shifting the latent dimensions differentiating Monocytes by a factor λ multiplied by their standard deviation. Increasing the shifting parameter λ results in more HSPCs being subsequently classified as Monocytes.

Figure 3 shows the results of performing such perturbations to change HSPCs into each of the mature blood cell types in our dataset. For each cell type, we shifted the top 5 latent dimensions encoding its differentiation, and we illustrate the results for both λ = 0.5 and λ = 1.0. Some mature blood cell types are easier to obtain from HSPCs than others, which suggests that there is some heterogeneity among the HSPCs. In this context, we can also estimate the minimum number of genes that need to be modified to change the cell type and, in particular, determine which genes get upregulated and which get downregulated in this process.

Figure 3

Results obtained after performing cell perturbations. In each subfigure, the cells of interest are shown in colour and the rest of the cells in grey. Each subfigure indicates how many of the HSPCs were converted into each type of mature blood cell after perturbing the latent representations of DiffVAE. Increasing the shifting parameter λ in the perturbations results in more cells being changed.

Let x(i) be the input gene expression measurements for cell (i). After perturbing the latent representation z(i) of cell (i) and passing the result through the decoder, we obtain the reconstructed gene expression measurements y(i). Assume that y(i) is then classified by the neural network as a mature blood cell. The difference y(i) − x(i) shows which genes changed the most as a result of perturbing the latent representation. Then, by changing only the expression of these genes in x(i) and leaving the other genes unchanged, we can compute the minimum number of genes that need to be changed to reprogram each HSPC into a mature blood cell.
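Under the same assumptions, the search for the smallest number of genes can be sketched as a linear scan over genes ordered by |y(i) − x(i)|. classify is again a hypothetical stand-in for the trained classifier, taking a cells × genes array; a linear scan is used because the predicted label need not change monotonically with the number of copied genes.

```python
import numpy as np


def min_genes_to_convert(x, y, classify, target):
    """Smallest g such that copying the g most-changed genes of y into x is classified as target.

    x: original expression of one cell, shape (n,); y: decoded expression after perturbation.
    Returns None if even copying every gene does not produce the target label.
    """
    order = np.argsort(-np.abs(y - x), kind="stable")  # genes sorted by size of change
    for g in range(1, len(x) + 1):
        x_mod = np.array(x, dtype=float, copy=True)
        x_mod[order[:g]] = y[order[:g]]
        if classify(x_mod[None, :])[0] == target:
            return g
    return None


classify = lambda X: np.where(X[:, 1] + X[:, 3] > 1.0, "Erythrocyte", "HSPC")
x = np.zeros(5)
y = np.array([0.1, 0.9, 0.0, 0.5, 0.2])
print(min_genes_to_convert(x, y, classify, "Erythrocyte"))  # 2
```

Comparing the signs of y − x on the selected genes then indicates which of them are upregulated and which are downregulated.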

For the shifting parameter λ = 1, we analysed the HSPCs that were classified as mature blood cells after the perturbation of their latent representation. For the 175 HSPCs that were converted into Erythrocytes, only a relatively small number of genes (up to 25) needed to be changed for each HSPC. In contrast, for the 70 HSPCs changed into Monocytes, ~70% of the genes needed to be perturbed to change the cell type. For the 70 HSPCs that were classified as Thrombocytes after the perturbations, 50% of the cells were changed by modifying only 20 genes, whereas changing all 70 of these HSPCs into Thrombocytes required perturbing almost all of the genes. Finally, only 3 HSPCs were changed into Neutrophils, so we did not perform any further analysis in this case. We emphasise that these results are entirely computational and show how perturbing the latent representation obtained from DiffVAE allows us to explore changes in cell state. Biological experiments are required to validate the hypotheses about cell reprogramming generated by DiffVAE.

Comparison of DiffVAE with other dimensionality reduction methods

After dimensionality reduction, standard single-cell RNA-seq workflows for identifying cell types cluster the lower-dimensional representation of the gene expression data30. Using the zebrafish cell types found by Athanasiadis et al.1 and the human pancreatic cell types found by Muraro et al.2 as true labels, we compare the clustering performance obtained with DiffVAE, a standard variational autoencoder (VAE), a simple autoencoder (AE) and principal component analysis (PCA). We use two clustering algorithms that define clusters in different ways: k-means and DBSCAN. We cluster both the representations obtained through dimensionality reduction with \(m\in \{20,50,100\}\) latent dimensions and the 2-dimensional embeddings produced from them using t-SNE.

For each setting of m (the size of the latent representation), the clustering algorithms (including the computation of the t-SNE embedding) were run 50 times, and each time the adjusted Rand index (ARI) between the true labels and the cluster labels was computed. Table 2 reports the mean ARI obtained on the zebrafish dataset; see Supplementary Table 2 for the results on the dataset with human pancreatic cells. For both datasets, the representation built by DiffVAE gives the best overall clustering performance. In addition, computing the t-SNE embedding on top of the latent representation improves the clustering results.
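The evaluation protocol can be sketched with scikit-learn's adjusted_rand_score, repeating k-means with different random seeds and averaging. For brevity, this sketch clusters a fixed representation and does not recompute the t-SNE embedding in each run; the synthetic data stand in for the real representations and labels.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score


def mean_ari(representation, true_labels, n_clusters, n_runs=50):
    """Mean and standard deviation of the ARI over n_runs k-means runs with different seeds."""
    scores = []
    for seed in range(n_runs):
        pred = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(representation)
        scores.append(adjusted_rand_score(true_labels, pred))
    return float(np.mean(scores)), float(np.std(scores))


X, y = make_blobs(n_samples=150, n_features=5, centers=3, cluster_std=0.5, random_state=1)
print(mean_ari(X, y, n_clusters=3, n_runs=5))
```

The ARI is invariant to permutations of the cluster labels, so cluster IDs do not need to be matched to cell types before scoring; it equals 1 for identical partitions and is close to 0 in expectation for random labellings.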

Table 2 Zebrafish.

Exploring links between cells

In this section, we move beyond modelling the stochastic behaviour of gene expression across cell types and also explore modelling the relations between different cell types. For this purpose, we propose Graph-DiffVAE, a graph variational autoencoder whose encoder and decoder networks are graph convolutional networks. Graph-DiffVAE is based on the graph variational autoencoder proposed by Kipf and Welling31 and on the Graphite model developed by Grover et al.32.

In this context, we consider the cells in the zebrafish dataset as nodes in a graph, represented by the adjacency matrix A. The gene expression measurements for each cell form the node features X. The encoder of Graph-DiffVAE takes as input an initial graph structure for the cells together with the node features and computes a latent representation for each cell, \({q}_{{\boldsymbol{\phi }}}({\bf{Z}}|{\bf{A}},{\bf{X}})\); we refer to these as latent node features. The decoder uses these latent node features and the initial adjacency matrix to predict additional links between the cells that are similar to those in the input graph.
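To make the encoder concrete, the NumPy sketch below shows a forward pass of a two-layer graph convolutional encoder with the symmetric normalization of Kipf and Welling31, followed by the inner-product decoder sigmoid(ZZᵀ) of their graph variational autoencoder. This is an illustrative sketch with random, untrained weights and toy dimensions; Graph-DiffVAE's actual layers and its Graphite-style decoder32, which also uses the adjacency matrix, may differ from it.

```python
import numpy as np


def normalize_adjacency(A):
    """Symmetric GCN normalization D^{-1/2} (A + I) D^{-1/2}."""
    A_hat = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    return A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def gcn_encoder(A, X, W_hidden, W_mu, W_logvar, rng):
    """Two graph convolutions giving mean/log-variance of q(Z | A, X), plus a reparameterized sample."""
    A_norm = normalize_adjacency(A)
    H = np.maximum(A_norm @ X @ W_hidden, 0.0)  # first graph convolution + ReLU
    mu = A_norm @ H @ W_mu
    logvar = A_norm @ H @ W_logvar
    Z = mu + np.exp(0.5 * logvar) * rng.standard_normal(mu.shape)
    return Z, mu, logvar


def inner_product_decoder(Z):
    """Predicted edge probabilities sigmoid(Z Z^T)."""
    return 1.0 / (1.0 + np.exp(-(Z @ Z.T)))


rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)  # 4 cells
X = rng.random((4, 6))  # toy node features (6 genes)
W_hidden, W_mu, W_logvar = (rng.normal(0.0, 0.5, s) for s in [(6, 8), (8, 2), (8, 2)])
Z, mu, logvar = gcn_encoder(A, X, W_hidden, W_mu, W_logvar, rng)
P = inner_product_decoder(Z)
```

Each graph convolution mixes a cell's features with those of its neighbours, which is why cells connected in the input graph tend to receive similar latent node features.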

The choice of input graph depends on the application. One option is to incorporate biological knowledge into the graph, where, for instance, edges can represent potential differentiation trajectories for the cells in the dataset. The proposed architecture and training objective lead Graph-DiffVAE to predict additional edges between cells in the output adjacency matrix, and these predicted relationships are similar to those in the initial graph given as input to the model.

In this paper, we aim to provide a proof of concept for using Graph-DiffVAE with single-cell gene expression data. Thus, we build an initial graph in which each cell is connected by an edge to the cell most similar to it, using the Pearson correlation coefficient to measure the similarity between cells. For each cell in the dataset, we computed the Pearson correlation coefficient between its gene expression vector and the gene expression vectors of all other cells, and added an edge connecting it to the most positively correlated cell. The resulting graph is undirected and is represented by a binary adjacency matrix, in which 1 indicates an edge between two nodes (cells).
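The construction of the initial graph can be sketched in NumPy as follows, assuming X is a cells × genes expression matrix with no constant rows (np.corrcoef returns NaN for those). Dropping the edge when a cell has no positively correlated neighbour is our assumption for an edge case the text does not discuss.

```python
import numpy as np


def nearest_correlation_graph(X):
    """Binary, undirected adjacency linking each cell to its most positively correlated other cell."""
    n = X.shape[0]
    C = np.corrcoef(X)             # Pearson correlation between rows (cells)
    np.fill_diagonal(C, -np.inf)   # a cell is not its own neighbour
    nearest = np.argmax(C, axis=1)
    rows = np.arange(n)
    keep = C[rows, nearest] > 0    # only positively correlated neighbours
    A = np.zeros((n, n), dtype=int)
    A[rows[keep], nearest[keep]] = 1
    return np.maximum(A, A.T)      # symmetrize: the graph is undirected


X = np.array([[1, 2, 3, 4], [2, 4, 6, 8.1], [4, 3, 2, 1], [4, 3, 2, 0.5]])
print(nearest_correlation_graph(X))
```

Symmetrizing means that a cell can end up with more than one edge when it is the nearest neighbour of several other cells.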

Figure 4a illustrates the pipeline for using Graph-DiffVAE, and Fig. 4b shows the 2-dimensional t-SNE embedding of the latent node features learnt by Graph-DiffVAE, as well as specific links (both initial and predicted) between the HSPCs and differentiated cells. Figure 4c shows the adjacency matrix for the zebrafish dataset used as input to Graph-DiffVAE, and Fig. 4d shows the predicted adjacency matrix. In Fig. 4b, the latent representation built by the encoder exhibits a clustering structure between the different types of cells. This is consistent with the expected behaviour of the model, as cells that are highly correlated with each other are more likely to belong to the same cluster. We can also see that an initial edge between two types of cells encourages the prediction of similar edges. Moreover, in Fig. 4d, the clustering behaviour of the encoder is emphasised in the output of the decoder, which predicts relatively well-defined clusters for the Monocytes, Neutrophils, Erythrocytes and Thrombocytes.

Figure 4

Methodology proposed for analyzing links between cells. (a) Graph-DiffVAE uses an initial adjacency matrix and individual node features to predict more links between cells. (b) Projection of cells onto 2-dimensional t-SNE embedding of the latent node features learnt by Graph-DiffVAE and illustration of initial and predicted links between HSPCs and differentiated cells. (c) Adjacency matrix with input links between cells (the colour white indicates an input edge); each cell is connected to the highest positively correlated cell. (d) Adjacency matrix with predicted links between all cells by Graph-DiffVAE (the colour white indicates a predicted edge). (e) Co-expression matrix between all cells; each entry represents the absolute value of the Pearson correlation coefficient.

An interesting aspect of the predicted graph in Fig. 4d is that the HSPCs do not cluster together well. In particular, there are clear links between several HSPCs and all of the other cells in a cluster of mature blood cells. This suggests that some of the HSPCs have already started differentiating towards a specific type of mature cell. Additionally, Graph-DiffVAE predicted more edges between HSPCs and Erythrocytes than between HSPCs and the other differentiated cells.

For comparison, Fig. 4e shows a co-expression matrix built by computing the absolute value of the Pearson correlation between all pairs of cells. While the co-expression matrix also shows clustering behaviour, the clusters are less well defined. Moreover, selecting edges from the co-expression matrix requires choosing a correlation threshold, and different thresholds result in different connections between cells. Note also that such a co-expression matrix only accounts for linear relationships between cells, while Graph-DiffVAE can model non-linear relationships.
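The sensitivity to the threshold can be seen in a small NumPy sketch that builds edges from the absolute Pearson correlation between cells; the three toy cells and the two thresholds are made up for illustration.

```python
import numpy as np


def coexpression_edges(X, threshold):
    """Undirected edges (i, j), i < j, whose absolute Pearson correlation exceeds threshold."""
    C = np.abs(np.corrcoef(X))  # co-expression matrix between cells (rows of X)
    i, j = np.triu_indices_from(C, k=1)
    keep = C[i, j] > threshold
    return list(zip(i[keep].tolist(), j[keep].tolist()))


X = np.array([[1, 2, 3, 4], [1, 2, 3, 4.5], [1, 3, 2, 4]], dtype=float)
print(coexpression_edges(X, 0.9))  # [(0, 1)]
print(coexpression_edges(X, 0.5))  # [(0, 1), (0, 2), (1, 2)]
```

The pairwise correlations here are about 0.99, 0.80 and 0.82, so lowering the threshold from 0.9 to 0.5 turns a single edge into a fully connected graph.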

Another important difference is that the predictions of Graph-DiffVAE depend strongly on the input graph. If prior biological knowledge about existing links between cells is available, it can be incorporated into the input graph. Graph-DiffVAE can then generate hypotheses about other links between cells that share the same biological meaning as the input ones.


In this paper, we explored unsupervised generative and graph representation learning methods for modelling single-cell gene expression data and understanding cell differentiation, developing the DiffVAE and Graph-DiffVAE models. Both models succeed in characterising different states of cell differentiation from single-cell RNA-seq data. We illustrated how to identify cell types with DiffVAE through a pipeline that clusters the latent representation, detects important genes for each cluster and maps clusters to cell types. Many of the high-weight genes found by DiffVAE are well known in the literature as key haematopoietic genes. In addition to these “usual suspects”, our method identified a number of novel genes whose role during cell differentiation can be explored further. We have also shown that the pipeline applies to datasets of different types, providing insight into the information hidden in noisy single-cell gene expression data.

The embeddings obtained from DiffVAE can be used to generate artificial samples, allowing further exploration and expansion of the current datasets. Moreover, we explored perturbations of the latent space and analysed their effect on gene expression and on cellular state. These computational perturbation experiments can help us better understand how easy or difficult it is for hematopoietic stem and progenitor cells to change into differentiated cells. They also indicate how many genes need to be up-regulated or down-regulated to change cellular state, which can motivate future studies on the stability of cellular states and their robustness to genetic stochasticity.

Through Graph-DiffVAE we explored a way of understanding the connections between cells, and in particular between HSPCs and differentiated cells. Investigating the links that Graph-DiffVAE predicts between HSPCs and differentiated cells could reveal which HSPCs have already chosen a lineage and started differentiating. For instance, if in the adjacency matrix predicted by Graph-DiffVAE an HSPC is strongly connected to differentiated cells of a single type, such as Erythrocytes, we can hypothesize that this HSPC may also differentiate into an Erythrocyte. Similarly, if an HSPC is connected to two types of differentiated cells, this might indicate that it has the potential to become either of them. Analyzing the gene expression patterns of such cells may allow us to distinguish true stem cells from cells that have already started the differentiation process. Further analysis in this direction can also help us better understand cell differentiation trajectories. From a methodological perspective, future work could involve combining DiffVAE and Graph-DiffVAE into a single multitask learning framework33 and using Graph Attentional Layers34 in Graph-DiffVAE to better quantify the importance of the links between cells.
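One way to operationalise this reasoning is to count, for each HSPC, its predicted links to each differentiated cell type. The sketch below is hypothetical: the 0.5 probability threshold, the label strings and the function names are assumptions, and a single linked type is only a hypothesis of lineage commitment, not evidence of it.

```python
import numpy as np

def lineage_link_counts(A_hat, labels, hspc_label="HSPC", threshold=0.5):
    """For each HSPC, count predicted links to each differentiated cell type.

    A_hat: (N, N) matrix of predicted edge probabilities.
    labels: length-N sequence of cell-type labels.
    threshold: probability above which a predicted edge is kept (assumed value).
    """
    labels = np.asarray(labels)
    edges = A_hat > threshold
    types = [str(t) for t in np.unique(labels) if t != hspc_label]
    counts = {}
    for i in np.flatnonzero(labels == hspc_label):
        counts[int(i)] = {t: int(edges[i, labels == t].sum()) for t in types}
    return counts

def candidate_lineages(count_row):
    """Cell types an HSPC is linked to; a single type suggests one committed lineage."""
    return sorted(t for t, c in count_row.items() if c > 0)

labels = np.array(["HSPC", "HSPC", "Erythrocyte", "Erythrocyte", "Monocyte"])
A_hat = np.array([
    [1.0, 0.2, 0.9, 0.8, 0.1],
    [0.2, 1.0, 0.7, 0.1, 0.6],
    [0.9, 0.7, 1.0, 0.9, 0.0],
    [0.8, 0.1, 0.9, 1.0, 0.0],
    [0.1, 0.6, 0.0, 0.0, 1.0],
])
counts = lineage_link_counts(A_hat, labels)
for cell, row in counts.items():
    print(cell, row, candidate_lineages(row))
```

In this toy example, HSPC 0 is linked only to Erythrocytes, whereas HSPC 1 is linked to both Erythrocytes and Monocytes.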



DiffVAE receives as input the expression levels of k genes. Both the encoder and the decoder consist of two fully connected hidden layers with h1 and h2 neurons. Using multiple hidden layers builds a hierarchical representation of features and therefore a more expressive model. The hidden layer sizes are symmetric between the encoder and decoder: the encoder maps the input through layers of h2 and then h1 neurons, and the decoder mirrors this with layers of h1 and then h2 neurons. The latent representation z has m dimensions. The ReLU activation was applied in the hidden layers of both the encoder and decoder to introduce non-linearity in the network. The specific operations performed by DiffVAE are as follows:

Encoder (Inference model): \(q_{\phi}(\mathbf{z}|\mathbf{x})=\mathscr{N}(\mathbf{z};\boldsymbol{\mu},\mathrm{diag}(\boldsymbol{\sigma}^{2}))\)

The encoder consists of fully connected layers and has a Gaussian output. For numerical stability, the encoder network learns \(\log\boldsymbol{\sigma}^{2}\) instead of \(\boldsymbol{\sigma}^{2}\). The input to the encoder is \(\mathbf{x}\in\mathbb{R}^{1\times k}\), which in our case represents the expression levels of k = 1845 genes. The operations performed by the encoder network are summarised by:

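$$\mathbf{x}_{\mathrm{enc}}^{(1)}=\mathrm{ReLU}(\mathbf{W}_{\mathrm{enc}}^{(0)}\mathbf{x}+\mathbf{b}_{\mathrm{enc}}^{(1)}),$$
$$\mathbf{x}_{\mathrm{enc}}^{(2)}=\mathrm{ReLU}(\mathbf{W}_{\mathrm{enc}}^{(1)}\mathbf{x}_{\mathrm{enc}}^{(1)}+\mathbf{b}_{\mathrm{enc}}^{(2)}),$$
$$\boldsymbol{\mu}=\mathrm{ReLU}(\mathbf{W}_{\mu}\mathbf{x}_{\mathrm{enc}}^{(2)}+\mathbf{b}_{\mu}),$$
$$\log\boldsymbol{\sigma}^{2}=\mathrm{ReLU}(\mathbf{W}_{\sigma}\mathbf{x}_{\mathrm{enc}}^{(2)}+\mathbf{b}_{\sigma}),$$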

where \(\mathbf{W}_{\mathrm{enc}}^{(0)}\in\mathbb{R}^{k\times h_2}\), \(\mathbf{b}_{\mathrm{enc}}^{(1)}\in\mathbb{R}^{1\times h_2}\), \(\mathbf{W}_{\mathrm{enc}}^{(1)}\in\mathbb{R}^{h_2\times h_1}\), \(\mathbf{b}_{\mathrm{enc}}^{(2)}\in\mathbb{R}^{1\times h_1}\), \(\mathbf{W}_{\mu}\in\mathbb{R}^{h_1\times m}\), \(\mathbf{b}_{\mu}\in\mathbb{R}^{1\times m}\), \(\mathbf{W}_{\sigma}\in\mathbb{R}^{h_1\times m}\) and \(\mathbf{b}_{\sigma}\in\mathbb{R}^{1\times m}\) are the trainable parameters of the encoder. The encoder also uses batch normalization35 to reduce internal covariate shift.

Sampling the latent representation z directly causes problems for standard gradient-based training, because gradients cannot be computed through the random sampling of z. To overcome this, Kingma and Welling19 proposed the reparameterisation trick, which expresses the latent code as a deterministic function of μ, σ and an independent noise variable:

$$\boldsymbol{\varepsilon}\sim\mathscr{N}(\mathbf{0},\mathbf{I}),\quad\mathbf{z}=\boldsymbol{\mu}+\boldsymbol{\varepsilon}\odot\boldsymbol{\sigma}.$$
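The encoder and reparameterisation steps can be sketched in NumPy as below. This is an illustrative forward pass only, assuming the row-vector convention implied by the parameter shapes (x W + b), small random weights in place of trained ones, and no batch normalization. In an autodiff framework, gradients would flow through `mu` and `log_var`, because all of the randomness sits in `eps`.

```python
import numpy as np

rng = np.random.default_rng(0)
k, h1, h2, m = 1845, 256, 512, 50   # sizes from the text (h2 = 2 * h1)

def relu(a):
    return np.maximum(a, 0.0)

def init(n_in, n_out):
    # Small random weights and zero biases; this initialisation is an assumption
    return rng.normal(scale=0.01, size=(n_in, n_out)), np.zeros((1, n_out))

W0, b1 = init(k, h2)       # W_enc^(0), b_enc^(1)
W1, b2 = init(h2, h1)      # W_enc^(1), b_enc^(2)
W_mu, b_mu = init(h1, m)
W_sg, b_sg = init(h1, m)

def encode(x):
    """x: (B, k) batch of expression vectors -> (mu, log_var), each (B, m)."""
    x1 = relu(x @ W0 + b1)
    x2 = relu(x1 @ W1 + b2)
    mu = relu(x2 @ W_mu + b_mu)          # ReLU output, as in the equations above
    log_var = relu(x2 @ W_sg + b_sg)
    return mu, log_var

def reparameterise(mu, log_var):
    """z = mu + eps * sigma with eps ~ N(0, I); sigma = exp(log_var / 2)."""
    eps = rng.standard_normal(mu.shape)
    return mu + eps * np.exp(0.5 * log_var)

x = rng.random((4, k))                   # toy batch of 4 cells
mu, log_var = encode(x)
z = reparameterise(mu, log_var)
print(z.shape)                           # (4, 50)
```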

Decoder (generative model): \({p}_{\theta }({\bf{x}}|{\bf{z}})\)

The output of the decoder must define the likelihood of the data we want to generate with this model. In our case, the gene expression values of each data point are modelled as samples from a multivariate Bernoulli distribution: each input gene is a Bernoulli random variable, and a sample from this distribution indicates whether the gene is expressed or not. To obtain a decoder with Bernoulli output, we apply the logistic activation function to the output layer, because its values lie in the range (0, 1) and can therefore be interpreted as Bernoulli probabilities.

The input to the decoder is the latent representation z. The decoder performs the following operations to obtain the reconstructed input \(\mathbf{x}'\):

$$\mathbf{x}_{\mathrm{dec}}^{(1)}=\mathrm{ReLU}(\mathbf{W}_{\mathrm{dec}}^{(0)}\mathbf{z}+\mathbf{b}_{\mathrm{dec}}^{(1)}),$$
$$\mathbf{x}_{\mathrm{dec}}^{(2)}=\mathrm{ReLU}(\mathbf{W}_{\mathrm{dec}}^{(1)}\mathbf{x}_{\mathrm{dec}}^{(1)}+\mathbf{b}_{\mathrm{dec}}^{(2)}),$$
$$\mathbf{x}'=\sigma(\mathbf{W}_{\mathrm{out}}\mathbf{x}_{\mathrm{dec}}^{(2)}+\mathbf{b}_{\mathrm{out}}),$$

where σ is the logistic activation function and \(\mathbf{W}_{\mathrm{dec}}^{(0)}\in\mathbb{R}^{m\times h_1}\), \(\mathbf{b}_{\mathrm{dec}}^{(1)}\in\mathbb{R}^{1\times h_1}\), \(\mathbf{W}_{\mathrm{dec}}^{(1)}\in\mathbb{R}^{h_1\times h_2}\), \(\mathbf{b}_{\mathrm{dec}}^{(2)}\in\mathbb{R}^{1\times h_2}\), \(\mathbf{W}_{\mathrm{out}}\in\mathbb{R}^{h_2\times k}\) and \(\mathbf{b}_{\mathrm{out}}\in\mathbb{R}^{1\times k}\) are the trainable parameters of the decoder. Because \(\mathbf{x}'\) is taken directly from the decoder output rather than sampled, it provides a maximum likelihood estimate of the reconstruction.
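A matching sketch of the decoder forward pass and of the Bernoulli log-likelihood term \(\log p_{\theta}(\mathbf{x}|\mathbf{z})\) follows. The random weights and the binarised toy input are assumptions for illustration; the log-likelihood computed here is the negative of the familiar binary cross-entropy.

```python
import numpy as np

rng = np.random.default_rng(1)
k, h1, h2, m = 1845, 256, 512, 50

def relu(a):
    return np.maximum(a, 0.0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Decoder parameters with the shapes listed above (random values for illustration)
W0, b1 = rng.normal(scale=0.01, size=(m, h1)), np.zeros((1, h1))
W1, b2 = rng.normal(scale=0.01, size=(h1, h2)), np.zeros((1, h2))
W_out, b_out = rng.normal(scale=0.01, size=(h2, k)), np.zeros((1, k))

def decode(z):
    """z: (B, m) latent codes -> x_rec: (B, k) Bernoulli means in (0, 1)."""
    x1 = relu(z @ W0 + b1)
    x2 = relu(x1 @ W1 + b2)
    return sigmoid(x2 @ W_out + b_out)

def bernoulli_log_likelihood(x, x_rec, eps=1e-7):
    """Sum over genes of log p(x | z) under a factorised Bernoulli; one value per cell."""
    p = np.clip(x_rec, eps, 1 - eps)      # avoid log(0)
    return np.sum(x * np.log(p) + (1 - x) * np.log(1 - p), axis=1)

z = rng.standard_normal((4, m))
x_rec = decode(z)
x = (rng.random((4, k)) > 0.5).astype(float)   # toy binarised expression
print(x_rec.shape, bernoulli_log_likelihood(x, x_rec).shape)
```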

DiffVAE was trained using minibatch stochastic gradient descent to minimize \(-{{\mathscr{L}}}_{{\rm{DiffVAE}}}\):

$$\mathscr{L}_{\mathrm{DiffVAE}}(\boldsymbol{\theta},\boldsymbol{\phi};\mathbf{x})=\mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[\log p_{\theta}(\mathbf{x}|\mathbf{z})]-\mathrm{MMD}(q_{\phi}(\mathbf{z})\,\Vert\,p_{\theta}(\mathbf{z})).$$
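The text does not state which kernel the MMD term uses. The sketch below assumes a Gaussian kernel with bandwidth equal to the latent dimensionality, a common choice in MMD-based autoencoders, and estimates the squared MMD between a batch of codes from \(q_{\phi}(\mathbf{z})\) and samples from a standard normal prior (also an assumption). Unlike the per-sample KL term of a standard VAE, it compares the aggregate distribution of a whole batch of codes with the prior.

```python
import numpy as np

def rbf_kernel(a, b):
    """Gaussian kernel exp(-||a - b||^2 / dim); the bandwidth choice is an assumption."""
    dim = a.shape[1]
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / dim)

def mmd(z_q, z_p):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    return (rbf_kernel(z_q, z_q).mean()
            + rbf_kernel(z_p, z_p).mean()
            - 2 * rbf_kernel(z_q, z_p).mean())

rng = np.random.default_rng(0)
m = 50
z_prior = rng.standard_normal((200, m))            # samples from p(z) = N(0, I)
z_close = rng.standard_normal((200, m))            # codes that match the prior
z_shifted = rng.standard_normal((200, m)) + 2.0    # codes far from the prior
print(mmd(z_close, z_prior) < mmd(z_shifted, z_prior))  # True
```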

We used the Adam Optimizer36 and we trained DiffVAE for 100 epochs. The learning rate and batch size were selected as part of the hyperparameter optimization process.

Hyperparameter selection for DiffVAE

DiffVAE has the following hyperparameters that need to be optimized before applying it to a new dataset: the number of neurons in the hidden layers (h1 and h2), the latent representation size (m), the learning rate (α) and the batch size (B). The original dataset was split such that 80% of the data points were used for training and 20% for validation. Each hyperparameter was searched over a set of candidate values, and the final values were chosen based on the validation loss. A similar approach to hyperparameter optimization has been used by other methods that model single-cell gene expression data with deep generative architectures37.

The candidate values for each hyperparameter are: number of neurons h1 ∈ {256, 128}, latent representation size m ∈ {50, 100}, learning rate α ∈ {0.01, 0.001, 0.0001} and batch size B ∈ {64, 128, 256}. We set h2 = 2 · h1. The final hyperparameter values used for the zebrafish dataset are h1 = 256, h2 = 512, m = 50, α = 0.001 and B = 128. For the dataset with human pancreatic cells we used h1 = 256, h2 = 512, m = 100, α = 0.001 and B = 256. Finally, for the dataset with human hematopoietic cells we used h1 = 256, h2 = 512, m = 50, α = 0.001 and B = 128.
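The search space above contains 2 × 2 × 3 × 3 = 36 configurations. Below is a minimal sketch of enumerating them and picking the one with the lowest validation loss; `validation_loss` is a hypothetical callable standing in for training on the 80% split and evaluating on the 20% split.

```python
from itertools import product

# Candidate values listed in the text; h2 is tied to h1 (h2 = 2 * h1)
grid = {
    "h1": [256, 128],
    "m": [50, 100],
    "alpha": [0.01, 0.001, 0.0001],
    "B": [64, 128, 256],
}

configs = [
    {"h1": h1, "h2": 2 * h1, "m": m, "alpha": a, "B": b}
    for h1, m, a, b in product(grid["h1"], grid["m"], grid["alpha"], grid["B"])
]
print(len(configs))  # 36

def select_best(configs, validation_loss):
    """Return the configuration with the lowest validation loss.

    validation_loss: hypothetical callable that trains DiffVAE with a given
    configuration and returns its loss on the validation split.
    """
    return min(configs, key=validation_loss)
```

A full grid search is affordable here because each hyperparameter has only a few candidates; with larger spaces, random search would be a common alternative.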

To apply DiffVAE to a new dataset, the approach described in this section can be used to select appropriate hyperparameters.

Additional implementation details for identifying cell types

Identifying cell types with DiffVAE involves two additional design choices: the number of latent dimensions and the number of high-weight genes to select for each cluster. In the dataset with zebrafish cells, we used the top 10 latent dimensions encoding the differentiation of cells in each cluster, and for each latent dimension we investigated the top 3 high-weight genes. For the dataset with human pancreatic cells, on the other hand, the large number of clusters led us to use only the top 5 latent dimensions with the top 5 high-weight genes. In Table 1 and Supplementary Table 1, we report the high-weight genes that were common among the selected latent dimensions.
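The selection procedure can be sketched as follows, under two clearly stated assumptions: latent dimensions are ranked by the absolute difference between a cluster's mean code and the mean code of all other cells, and a gene-weight matrix `G` (one row per latent dimension) is already available. The text does not specify how either quantity is computed, so the scoring rule, the `min_count=2` reading of "common", and the function names are all hypothetical.

```python
import numpy as np
from collections import Counter

def top_latent_dims(Z, clusters, c, n_dims):
    """Rank latent dimensions by how strongly they separate cluster c from the rest.

    Score (an assumption): |mean of Z over cluster c - mean over other cells|.
    """
    in_c = clusters == c
    score = np.abs(Z[in_c].mean(axis=0) - Z[~in_c].mean(axis=0))
    return np.argsort(score)[::-1][:n_dims]

def common_high_weight_genes(G, dims, gene_names, n_genes, min_count=2):
    """Genes in the top-n_genes list of at least min_count selected dimensions.

    G: (m, k) gene-weight matrix, one row per latent dimension; how it is
    derived from the trained model is left as an assumption here.
    """
    counter = Counter()
    for d in dims:
        top = np.argsort(np.abs(G[d]))[::-1][:n_genes]
        counter.update(str(gene_names[j]) for j in top)
    return sorted(g for g, n in counter.items() if n >= min_count)

# Toy example: 100 cells, 8 latent dimensions, 2 clusters, 5 genes
rng = np.random.default_rng(0)
Z = rng.standard_normal((100, 8))
clusters = np.repeat([0, 1], 50)
Z[clusters == 0, 3] += 5.0        # dimension 3 separates cluster 0
Z[clusters == 0, 6] -= 4.0        # dimension 6 also separates it
dims = top_latent_dims(Z, clusters, 0, n_dims=2)

G = rng.normal(scale=0.1, size=(8, 5))
G[[3, 6], 2] = 3.0                # gene "g2" has a large weight on both dimensions
genes = np.array([f"g{j}" for j in range(5)])
print(sorted(dims.tolist()), common_high_weight_genes(G, dims, genes, n_genes=1))
```

For the zebrafish setting described above, one would call these functions with `n_dims=10` and `n_genes=3`.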

The main purpose of this process is to identify the cell types in the dataset. In practice, the number of latent dimensions to analyze should depend on how well they distinguish between the cells in the different clusters, as well as on the total number of clusters. Similarly, the number of high-weight genes selected should provide enough information to map clusters to cell types, while also allowing new marker genes important for the differentiation process to be uncovered.

Additional models used as benchmarks

To assess the performance of DiffVAE on clustering, we compare it against the following benchmarks: standard VAE, standard autoencoder and PCA. For the standard VAE and the standard autoencoder, we used the same number of layers and neurons as we used for DiffVAE. For both models, we also used the Adam Optimizer with a learning rate of 0.001, a batch size of 128 and we trained them for 100 epochs.

The objective function maximised by the variational autoencoder is:

$$\mathscr{L}_{\mathrm{VAE}}(\boldsymbol{\theta},\boldsymbol{\phi};\mathbf{x})=\mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[\log p_{\theta}(\mathbf{x}|\mathbf{z})]-D_{KL}(q_{\phi}(\mathbf{z}|\mathbf{x})\,\Vert\,p_{\theta}(\mathbf{z})).$$
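When the encoder is a diagonal Gaussian and the prior is a standard normal \(\mathscr{N}(\mathbf{0},\mathbf{I})\), as is standard for VAEs and assumed here for the benchmark, the KL term has the closed form \(-\tfrac{1}{2}\sum_j(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2)\). A short sketch, where `vae_objective` is a hypothetical helper taking a precomputed reconstruction log-likelihood:

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)) for each row, in closed form."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=1)

def vae_objective(log_px_given_z, mu, log_var):
    """Per-sample ELBO: reconstruction log-likelihood minus the KL term."""
    return log_px_given_z - kl_to_standard_normal(mu, log_var)

# One-dimensional check: KL(N(1, 1) || N(0, 1)) = 0.5 * 1^2 = 0.5
print(kl_to_standard_normal(np.array([[1.0]]), np.array([[0.0]])))  # [0.5]
```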

The standard autoencoder model is trained to minimise the reconstruction error, which can be measured by the mean squared error loss function \(l(\mathbf{x},\omega)=\Vert\mathbf{x}-\mathrm{dec}(\mathrm{enc}(\mathbf{x}))\Vert^{2}\), where ω consists of all of the parameters in the encoder and decoder networks.

Neural network for characterizing cell states

The neural network trained for classifying the mature cell types takes as input the gene expression data for the cell and outputs the probability that the particular cell is a Monocyte, Neutrophil, Thrombocyte or Erythrocyte. The model consists of three hidden layers of sizes 256, 512, 256 neurons with ReLU activation and an output layer with 4 neurons and softmax activation. The model was also trained using the Adam Optimizer for 300 epochs with a learning rate of 0.001 and a batch size of 128. The model architecture was chosen through hyperparameter optimization similarly to DiffVAE.
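A NumPy sketch of the classifier's forward pass with the stated layer sizes (256, 512 and 256 ReLU units, then 4 softmax outputs). The input size of 1845 genes and the random weights are assumptions for illustration; training with Adam is not shown.

```python
import numpy as np

rng = np.random.default_rng(0)
CELL_TYPES = ["Monocyte", "Neutrophil", "Thrombocyte", "Erythrocyte"]
sizes = [1845, 256, 512, 256, len(CELL_TYPES)]   # input size of 1845 genes is assumed

# One (weights, bias) pair per layer, randomly initialised for illustration
params = [(rng.normal(scale=0.01, size=(a, b)), np.zeros(b))
          for a, b in zip(sizes[:-1], sizes[1:])]

def softmax(a):
    a = a - a.max(axis=1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=1, keepdims=True)

def classify(x):
    """x: (B, genes) -> (B, 4) class probabilities."""
    h = x
    for W, b in params[:-1]:
        h = np.maximum(h @ W + b, 0.0)     # ReLU hidden layers
    W, b = params[-1]
    return softmax(h @ W + b)

probs = classify(rng.random((3, sizes[0])))
print([CELL_TYPES[i] for i in probs.argmax(axis=1)])
```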


Let \(\mathscr{G}=(\mathscr{V},\mathcal{E})\), with \(|\mathscr{V}|=N\), be an undirected and unweighted initial graph built from the cells, defined by the binary adjacency matrix \(\mathbf{A}\in\{0,1\}^{N\times N}\). Such an initial graph may already be available, or it can be built artificially using the Pearson correlation between cells.
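A sketch of building such an initial graph as described in the caption of Fig. 4c: each cell is linked to the other cell with which it has the highest positive Pearson correlation, the links are symmetrised so that the graph is undirected, and self-loops are added (\(A_{ii}=1\)). The function name is hypothetical, and symmetrising with an element-wise maximum is our assumption about how one-directional choices become undirected edges.

```python
import numpy as np

def initial_graph(X):
    """Connect each cell to its most positively correlated other cell.

    X: (N, F) expression matrix, rows are cells. Returns a symmetric binary
    adjacency matrix with self-loops (A_ii = 1).
    """
    N = X.shape[0]
    C = np.corrcoef(X)                   # Pearson correlation between cells
    np.fill_diagonal(C, -np.inf)         # exclude each cell's correlation with itself
    best = C.argmax(axis=1)              # most correlated partner for each cell
    rows = np.arange(N)
    keep = C[rows, best] > 0             # only link if the correlation is positive
    A = np.zeros((N, N), dtype=int)
    A[rows[keep], best[keep]] = 1
    A = np.maximum(A, A.T)               # undirected: keep an edge if either cell chose it
    np.fill_diagonal(A, 1)
    return A

X_toy = np.array([[1, 2, 3, 4],
                  [1, 2, 3, 5],
                  [4, 3, 2, 1],
                  [5, 3, 2, 1]], dtype=float)
print(initial_graph(X_toy))
```

In the toy data, cells 0 and 1 rise with the gene index and cells 2 and 3 fall, so the graph splits into two connected pairs.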

Let X be an N × F matrix consisting of node features, where F is the number of features for each node. In our case, the nodes are the different cells in the dataset, and the features are represented by the gene expression for each cell. Assume that each node is connected to itself, so that A has diagonal entries Aii = 1. Let D be the diagonal degree matrix of A: \({D}_{ii}={\sum }_{j}\,{A}_{ij}\).

Graph convolutional networks (GCN) were proposed by Kipf and Welling38 with the following layer-wise propagation rule:

$$\mathbf{X}^{(l+1)}=\tau(\tilde{\mathbf{A}}\mathbf{X}^{(l)}\mathbf{W}^{(l)})=\mathrm{GCN}_{\tau,l}(\mathbf{A},\mathbf{X}^{(l)}),$$

where \(\tilde{\mathbf{A}}=\mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2}\), τ is the activation function, X(l) and X(l+1) are the activations of layers (l) and (l + 1) respectively, and W(l) are the weights of layer (l). X(0) is the input feature matrix X. The size n_l of layer (l) is the number of node features computed at that layer.
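The propagation rule can be sketched in a few lines of NumPy, assuming A already contains self-loops so that every degree is positive. The toy example uses a three-node path graph with a single feature that is non-zero only on node 2: after one layer, node 1 (a 1-hop neighbour) receives information from node 2, while node 0 (two hops away) does not.

```python
import numpy as np

def normalise_adjacency(A):
    """A_tilde = D^{-1/2} A D^{-1/2}; assumes A already contains self-loops."""
    d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
    return A * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(A_tilde, X, W, activation=lambda a: np.maximum(a, 0.0)):
    """One propagation step X^(l+1) = tau(A_tilde X^(l) W^(l)); ReLU by default."""
    return activation(A_tilde @ X @ W)

# Path graph 0 - 1 - 2 with self-loops; degrees are 2, 3, 2
A = np.array([[1, 1, 0],
              [1, 1, 1],
              [0, 1, 1]])
A_tilde = normalise_adjacency(A)

X = np.array([[0.0], [0.0], [1.0]])   # a single feature, non-zero only on node 2
W = np.ones((1, 1))
out = gcn_layer(A_tilde, X, W, activation=lambda a: a)   # identity activation
print(out.ravel())   # node 0 stays 0: it is two hops from node 2
```

Stacking l such layers lets information travel l hops, which is why the encoder and decoder below each use only a few layers.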

Through this layer-wise propagation rule, the graph convolutional network performs a first-order approximation of spectral graph convolutions. The model can be regarded as a differentiable, generalised version of the Weisfeiler-Lehman algorithm on graphs39. In particular, the propagation rule can be viewed as message passing over the graph structure: in each layer, every node passes information about its local structure to its 1-hop neighbours and then updates its own features based on the information it receives.

Graph-DiffVAE is a graph variational autoencoder in which both the encoder and the decoder are graph convolutional networks applying this layer-wise propagation rule. Its architecture is based on the graph variational autoencoder proposed by Kipf and Welling31 and on the Graphite model developed by Grover et al.32.

Inference model (encoder): \(q_{\phi}(\mathbf{Z}|\mathbf{A},\mathbf{X})\)

The encoder in Graph-DiffVAE is represented by a graph convolutional network with multiple layers and with Gaussian output. The input to the encoder consists of the matrix with node features X and of the graph adjacency matrix A. The layers in the encoder network perform the following operations:

$${{\bf{X}}}_{{\rm{enc}}}^{\mathrm{(1)}}=GC{N}_{{\tau }_{1}\mathrm{,1}}({\bf{A}},{\bf{X}}),$$
$${\boldsymbol{\mu }}=GC{N}_{{\tau }_{2},\mu }({\bf{A}},{{\bf{X}}}_{{\rm{enc}}}^{\mathrm{(1)}}),$$
$$\log \,{{\boldsymbol{\sigma }}}^{2}=GC{N}_{{\tau }_{2},\sigma }({\bf{A}},{{\bf{X}}}_{{\rm{enc}}}^{\mathrm{(1)}}).$$

where τ1 is the ReLU activation function and τ2 is the linear activation function. The number of node features computed in \({{\bf{X}}}_{{\rm{enc}}}^{\mathrm{(1)}}\) is 512 and the number of node features in μ and σ2 is 50, thus forming a latent representation Z where each node has M = 50 features.

The encoder represents a factorised multivariate Gaussian distribution, such that:

$$q_{\phi}(\mathbf{Z}|\mathbf{A},\mathbf{X})=\prod_{i=1}^{N}q_{\phi}(\mathbf{z}_i|\mathbf{X},\mathbf{A}),\quad q_{\phi}(\mathbf{z}_i|\mathbf{X},\mathbf{A})=\mathscr{N}(\mathbf{z}_i|\boldsymbol{\mu}_i,\mathrm{diag}(\boldsymbol{\sigma}_i^{2})).$$

The reparameterisation trick is used again to sample each zi:

$${\boldsymbol{\varepsilon }} \sim {\mathscr{N}}({\bf{0}},{\bf{I}}),\,{{\bf{z}}}_{i}={{\boldsymbol{\mu }}}_{i}+{\boldsymbol{\varepsilon }}\odot {{\boldsymbol{\sigma }}}_{i}.$$

Note that the latent representation Z built by the Graph-DiffVAE encoder contains information from both the graph structure and the node features; Z comprises the latent representations of all nodes in the graph.
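A forward-pass sketch of this encoder, assuming bias-free GCN layers (as in the propagation rule) and random weights in place of trained ones. The hidden size (512) and latent size (M = 50) follow the text, while N and F are toy values.

```python
import numpy as np

rng = np.random.default_rng(0)

def normalise_adjacency(A):
    """D^{-1/2} A D^{-1/2}; A is assumed to include self-loops."""
    d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
    return A * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def graph_encode(A, X, W1, W_mu, W_sigma):
    """Two-layer GCN encoder returning mu, log_var (each N x M) and a sample Z."""
    A_t = normalise_adjacency(A)
    H = np.maximum(A_t @ X @ W1, 0.0)      # X_enc^(1), tau_1 = ReLU
    mu = A_t @ H @ W_mu                    # tau_2 = identity (linear)
    log_var = A_t @ H @ W_sigma
    eps = rng.standard_normal(mu.shape)
    Z = mu + eps * np.exp(0.5 * log_var)   # reparameterisation for every node
    return mu, log_var, Z

N, F, hidden, M = 5, 30, 512, 50           # toy N and F; hidden and M from the text
A = np.eye(N, dtype=int)
A[0, 1] = A[1, 0] = 1                      # one edge plus self-loops
X = rng.random((N, F))
W1 = rng.normal(scale=0.01, size=(F, hidden))
W_mu = rng.normal(scale=0.01, size=(hidden, M))
W_sigma = rng.normal(scale=0.01, size=(hidden, M))
mu, log_var, Z = graph_encode(A, X, W1, W_mu, W_sigma)
print(Z.shape)  # (5, 50)
```

Node 4 has only a self-loop, so its latent mean depends on its own features alone; nodes 0 and 1 mix information from each other.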

Generative model (decoder): \(p_{\theta}(\mathbf{A}|\mathbf{Z},\mathbf{X})\)

The output of the decoder is an adjacency matrix \(\hat{\mathbf{A}}\) representing an undirected and unweighted graph with predicted edges between nodes. Such an adjacency matrix can be modelled by a factorised Bernoulli distribution, with one Bernoulli variable per pair of nodes.

The decoder network takes as input the initial adjacency matrix A and the concatenation of the input node features X with the latent node features Z computed by the encoder, denoted [Z|X]. The layers in the decoder network perform the following operations:

$$\mathbf{X}_{\mathrm{dec}}^{(1)}=GCN_{\tau_1,1}(\mathbf{A},[\mathbf{Z}|\mathbf{X}]),\quad\mathbf{Z}'=GCN_{\tau_1,2}(\mathbf{A},\mathbf{X}_{\mathrm{dec}}^{(1)}),$$
$$\mathbf{Z}^{\ast}=\tfrac{1}{2}(\mathbf{Z}'+\mathbf{Z}),\quad\hat{A}_{ij}=\sigma(\mathbf{z}_i^{\ast\top}\mathbf{z}_j^{\ast}),$$

where τ1 is the ReLU activation function and σ is the logistic function. The number of node features computed in \(\mathbf{X}_{\mathrm{dec}}^{(1)}\) is 512. The decoder builds its own latent representation Z′ with 50 features per node, which it averages with the representation Z constructed by the encoder to obtain Z*.
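The decoder can be sketched in the same style. The weights and toy sizes are illustrative; the sketch propagates [Z|X] through two ReLU GCN layers, averages the result Z′ with Z, and takes each edge probability as the logistic function of an inner product between node representations.

```python
import numpy as np

def normalise_adjacency(A):
    """D^{-1/2} A D^{-1/2}; A is assumed to include self-loops."""
    d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
    return A * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def graph_decode(A, Z, X, W1, W2):
    """Return predicted edge probabilities A_hat (N x N) from A, Z and X."""
    A_t = normalise_adjacency(A)
    ZX = np.concatenate([Z, X], axis=1)        # [Z | X]
    H = np.maximum(A_t @ ZX @ W1, 0.0)         # X_dec^(1): ReLU GCN layer, 512 features
    Z_prime = np.maximum(A_t @ H @ W2, 0.0)    # Z': ReLU GCN layer, M features
    Z_star = 0.5 * (Z_prime + Z)               # average with the encoder's Z
    return sigmoid(Z_star @ Z_star.T)          # A_hat_ij = sigmoid(z*_i . z*_j)

rng = np.random.default_rng(0)
N, F, M = 4, 10, 50
A = np.eye(N) + np.eye(N, k=1) + np.eye(N, k=-1)   # path graph with self-loops
Z = 0.1 * rng.standard_normal((N, M))
X = rng.random((N, F))
W1 = rng.normal(scale=0.1, size=(M + F, 512))
W2 = rng.normal(scale=0.1, size=(512, M))
A_hat = graph_decode(A, Z, X, W1, W2)
print(np.round(A_hat, 2))
```

The inner-product decoder makes the predicted matrix symmetric by construction, matching an undirected graph.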

As in the standard variational autoencoder framework, we optimize the following objective for Graph-DiffVAE:

$$\mathscr{L}_{\text{Graph-DiffVAE}}(\boldsymbol{\theta},\boldsymbol{\phi};\mathbf{X},\mathbf{A})=\mathbb{E}_{q_{\phi}(\mathbf{Z}|\mathbf{A},\mathbf{X})}[\log p_{\theta}(\mathbf{A}|\mathbf{Z})]-\mathrm{KL}(q_{\phi}(\mathbf{Z}|\mathbf{X},\mathbf{A})\,\Vert\,p_{\theta}(\mathbf{Z})).$$

The model was trained for 200 epochs with a learning rate of 0.0001. With this architecture and training objective, Graph-DiffVAE predicts additional edges between cells in the output adjacency matrix \(\hat{\mathbf{A}}\). The predicted relationships between cells are similar in nature to those in the initial graph given as input to the model.
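Finally, a sketch of the training objective and of reading new links off \(\hat{\mathbf{A}}\). It assumes that the reconstruction term is a binary cross-entropy between the input adjacency matrix and \(\hat{\mathbf{A}}\), that the prior in the KL term is a standard normal, and that an edge counts as predicted above a 0.5 threshold; none of these details is fixed by the text, and the function names are hypothetical.

```python
import numpy as np

def graph_vae_loss(A, A_hat, mu, log_var, eps=1e-7):
    """Negative Graph-DiffVAE objective: edge-wise BCE plus KL to N(0, I).

    A: (N, N) input adjacency used as the reconstruction target (assumption).
    A_hat: (N, N) predicted edge probabilities.
    mu, log_var: (N, M) encoder outputs.
    """
    p = np.clip(A_hat, eps, 1 - eps)            # avoid log(0)
    recon = -np.sum(A * np.log(p) + (1 - A) * np.log(1 - p))
    kl = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))
    return recon + kl

def new_edges(A, A_hat, threshold=0.5):
    """Pairs (i, j) with i < j predicted above threshold but absent from the input graph."""
    pred = A_hat > threshold
    i, j = np.nonzero(np.triu(pred & (A == 0), k=1))
    return list(zip(i.tolist(), j.tolist()))

A = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 1]])
A_hat = np.array([[0.9, 0.9, 0.8], [0.9, 0.9, 0.2], [0.8, 0.2, 0.9]])
mu, log_var = np.zeros((3, 2)), np.zeros((3, 2))
print(new_edges(A, A_hat))   # [(0, 2)]
print(graph_vae_loss(A, A_hat, mu, log_var))
```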


The code for DiffVAE and Graph-DiffVAE is publicly available in the GitHub repository: