Introduction

Cell state atlases constructed with single-cell RNA-seq, ATAC-seq and multimodal technologies reveal a multiplicity of stable states and interconnected differentiation trajectories1,2,3,4,5,6. Combined with perturbations, including gene knockouts7,8,9, drug treatments10,11 and mutations12,13, atlases can yield deep insights into the roles of transcription factors and signaling pathways in developmental processes, cancer plasticity, and other cell state transitions8,14,15. Perturbations can alter the structure of cell state atlases: cell population distributions can shift, certain cell states can be depleted or enriched, and wholly new cell states and differentiation trajectories can emerge. Comparing these changes in cell population structure with unperturbed biology enables the mechanistic dissection of perturbation effects. However, perturbation experiments often require single-cell assays to be carried out in multiple batches, which can introduce technical distortions to the data16. This is especially problematic when batches contain different cell state population distributions or capture new and different cell types, because technical effects then become confounded with the underlying batch-specific biological variation. Because current batch correction methods perform poorly when presented with batch-confounded cell states, cells in perturbation atlases are typically not compared directly but are instead analyzed via projections onto reference wild-type cell state maps8,15. This practice precludes directly observing alternative cell states and differentiation trajectories related to perturbations. Thus, batch correction methods that remove technical artifacts while preserving biological cell states, especially batch-confounded ones, would greatly enhance the power of comparative single-cell analysis.

To mitigate confounding technical effects, several batch correction methods have been proposed17,18,19, which fall broadly into the categories of generative and latent space merging models. Most generative models for batch effect removal are derivatives of factor analysis or the variational autoencoder (VAE) framework. This includes the VAE-based scVI20, which has been shown to be one of the most effective approaches in single-cell RNA-seq atlas benchmarking tests21. scVI parametrizes the distribution of observed counts using a deep neural network conditioned on the joint distribution of latent variables and cell batch labels, but makes no attempt to identify or separate the biological and technical components of those observations. Other generative model-based methods include the semi-supervised VAE scANVI18 and the factor analysis-based ZINB-WaVE22. In addition, methods that use the generative adversarial network and maximum mean discrepancy frameworks23 have been proposed, but these models require that each batch have the same cell population distribution. This is both difficult to assess a priori and inappropriate to assume when integrating perturbation datasets such as those considered here.

Alternatively, the most effective latent space merging models use mutual nearest neighbors (MNNs) between datasets to find shared cell states between batches. The MNNs are then used to calculate a nonlinear projection that reduces the distance between batches in some latent space. This class includes popular methods such as Seurat v324, FastMNN17, and Scanorama25. Finally, the Harmony19 method removes technical effects from data by using cross-dataset fuzzy clustering to iteratively merge clusters of cells predicted to be in similar states. Neither current generative models nor latent space merging methods provide a direct explanation of how technical distortions influence the data. Unregularized estimates of technical effects can lead to over-correction or to the misidentification of biological signals as technical effects. This may be why these methods struggle to detect batch-confounded cell states and are insensitive to differences in cell population distributions across batches.

Another important aspect of comparative atlas analysis is the deconvolution of the effects of different perturbations. Typically, analysis of single-cell RNA-seq and ATAC-seq atlases involves the representation of high-dimensional data in terms of low-dimensional latent spaces. Topic models26,27,28 and matrix factorization methods29,30 infer interpretable latent space representations that correspond to modules of co-regulated genes or co-accessible peaks. These modules may be useful for capturing and explaining the influence of different perturbations on gene regulatory programs. When considering batched data, however, differences in technical factors induced by single-cell experimental protocols impose a layer of batch-dependent distortion over the modular biological effects. Naively applied, matrix factorization methods cannot distinguish between biological and technical sources of variation, which results in the discovery of modules that may be contaminated by technical effects.

To mitigate current challenges in perturbation atlas analysis, we propose a variational autoencoder31-based statistical model and a novel parameter inference procedure that extends interpretable topic modeling to batched single-cell data. This method, called CODAL (COvariate Disentangling Augmented Loss), explicitly disentangles factors related to technical and biological effects, decomposes biological effects into interpretable modules, detects batch-confounded cell states, and represents cells in a batch-corrected, yet cell-type discriminative, latent space. Our approach can be applied to single-cell RNA-seq, ATAC-seq, and each modality within true multimodal RNA-seq plus ATAC-seq data. We benchmark the method using rigorously defined standards32 and demonstrate its capacity for batch-confounded cell type discovery on simulated datasets and on embryonic development atlases with gene knockouts8. In the integrated regulatory analysis of true multimodal data21, we show that CODAL batch correction improves the representation of RNA-seq and ATAC-seq modalities and enables the generalization of other count-based generative models to multi-batch data33. Overall, CODAL delivers robust technical effect correction and representation for datasets with varying degrees of confounded cell population differences, dataset size, and technical or biological complexity (Supplementary Table 1).

The CODAL model architecture extends MIRA28, our earlier method for variational topic modeling of single-cell RNA-seq and ATAC-seq data. CODAL includes additional modules to facilitate batch effect correction using our new objective function. Furthermore, we designed new, highly scalable and parallelizable procedures for automated hyperparameter selection and model training, which tailor the CODAL method to the properties of the dataset at hand. For a fixed model size, CODAL training time scales linearly with the number of cells modeled, while memory usage is kept constant using an efficient minibatch stochastic gradient descent algorithm that streams data from a fast-loading on-disk cache (Supplementary Fig. 1). CODAL is available as an open-source Python package at https://mira-multiome.readthedocs.io.
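The constant-memory behavior described above comes from streaming minibatches from an on-disk cache. As a minimal sketch of that general pattern (not CODAL's actual loader), the code below writes a count matrix to a NumPy memory map and yields shuffled contiguous blocks; the file layout and the helper names `write_cache` and `stream_minibatches` are hypothetical.

```python
import os
import tempfile

import numpy as np


def write_cache(counts, path):
    """Write a dense (cells x features) count matrix to an on-disk memory-mapped file (hypothetical helper)."""
    mm = np.memmap(path, dtype=np.float32, mode="w+", shape=counts.shape)
    mm[:] = counts
    mm.flush()
    return counts.shape


def stream_minibatches(path, shape, batch_size, rng):
    """Yield shuffled minibatches; only one minibatch is copied into RAM at a time (hypothetical helper)."""
    mm = np.memmap(path, dtype=np.float32, mode="r", shape=shape)
    starts = np.arange(0, shape[0], batch_size)
    rng.shuffle(starts)                         # shuffle block order so each read stays contiguous
    for s in starts:
        yield np.array(mm[s:s + batch_size])    # copy this block out of the memory map


rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(1000, 50)).astype(np.float32)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cache.dat")
    shape = write_cache(counts, path)
    n_seen, total = 0, 0.0
    for minibatch in stream_minibatches(path, shape, batch_size=128, rng=rng):
        n_seen += minibatch.shape[0]            # a gradient step would consume `minibatch` here
        total += float(minibatch.sum())
print(n_seen)  # 1000: every cell is visited once per epoch
```

Shuffling block start positions rather than individual rows is a design choice that keeps disk reads sequential; a production loader might additionally shuffle rows within each block.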

Results

Framework for disentangling biological and technical effects

In the analysis of a multi-batch single-cell RNA-seq or ATAC-seq experiment, we explicitly decompose the variation in observed read counts into biological and technical components (Fig. 1a). We use a variational autoencoder-based implementation of Latent Dirichlet Allocation34 to further factorize biological variation into latent variables, or “topics”, Z, and a matrix of linear feature associations, β. The topic compositions form a low-dimensional latent space that represents the cell states observable in the given population of cells. The linear association matrix encodes modules of covarying biological quantities (gene expression or chromatin accessibility) that are evident in the data.

Fig. 1: CODAL disentanglement model overview and comparison to standard variational autoencoder.

a Decomposition of observed gene expression counts into biological and technical components. Gene expression follows some low-dimensional manifold, while technical effects are of an unknown functional form. b CODAL model description. \(\mathcal{D}\) is the negative binomial distribution, λ is the biological quantity (gene expression or chromatin accessibility), which is the product of cell latent topics Z and linear gene associations \(\beta\). The technical effect vector t is given by a neural network h with weights \(\phi\), dependent on Z and the cell covariates C (batch of origin, quality control metrics, etc.). Counts are drawn from a distribution parameterized by the sum of the biological quantity and technical effects. c (left) CODAL model structure. Dependencies between random variables are indicated by solid arrows. The red dashed arrow indicates an association between λ and t that is implied by the dependence of t on Z, confounding the direct estimation of the biological quantities. (right) The CODAL objective function maximizes the marginal log-likelihood of the data minus the mutual information (\(I\)) between biological quantities λ and technical effects t. d Simulated bifurcating differentiation system with batch-confounded cell types. The treatment batch (blue) has elevated expression of gene A relative to the wild-type batch (red) after a bifurcation in state. e Expression rates for gene A in the simulated system estimated using models trained with the ELBO versus the CODAL objective. (top) The standard VAE objective (ELBO, marginal likelihood maximization only) yields poor estimates of gene A’s expression level; changes in expression are entangled with changes in technical effects. (bottom) Mutual information regularization disentangles the expression rate of gene A from the technical effects, yielding expression estimates that match the data generation procedure. Under the CODAL objective, the mutual information between estimated gene expression and technical effects is minimal. Source data are provided as a Source Data file.

Technical effects are well known to confound the interpretation and comparison of single-cell datasets35. Although various types of technical distortions in RNA-seq and ATAC-seq have been described36,37, each step in long single-cell protocols can contribute to such biases. The overall effect of all technical artifacts has therefore not been systematically characterized. Nevertheless, cells from the same batch, sharing protocol conditions and reagents, are subject to more similar technical effects than cells from different batches. Therefore, we assume that variance in technical effects is driven primarily by hidden batch-specific factors that systematically alter the read counts observed in each batch of cells. Although we do not precisely know the identity or effects of the technical factors, the cells’ batch of origin can be used as a proxy variable to indicate which cells were subjected to a protocol using common reagents and under similar conditions.

We also observe that some technical effects appear to depend on both batch and cell state. In other words, some cell types exhibit different degrees of technical effect, even within the same batch. This can arise from cell state-dependent differences in cell size, cytoplasmic or nuclear chemistry, or cell state abundance within a population. Controlling for these state-dependent confounders would render cell states independent of technical effects. Typically, however, these technical factors are unknown and exert their influence on observed counts as contextual interactions between cell state and batch. Taken together, we estimate the distribution of technical effects in each gene in each cell as a function of the cell state random variable and the batch covariate proxies (Fig. 1b). This function is implemented as a neural network because the functional form of technical effects is unknown.
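The generative structure described above (and in Fig. 1b) can be illustrated with a small simulation. This is a sketch under stated assumptions: we treat λ = Zβ and t as log-scale quantities whose sum sets the log-mean of the negative binomial, and we stand in for the trained network h with a randomly initialized one-hidden-layer network; the dimensions, dispersion and baseline depth are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_topics, n_genes, n_batches = 200, 5, 30, 3

# Biological component: topic compositions Z (rows on the simplex) times linear associations beta.
Z = rng.dirichlet(alpha=np.full(n_topics, 0.5), size=n_cells)
beta = rng.normal(scale=1.0, size=(n_topics, n_genes))
lam = Z @ beta                                   # assumed log-scale biological quantity per cell and gene

# Covariates C: one-hot batch of origin (continuous QC metrics could be appended as extra columns).
batch = rng.integers(n_batches, size=n_cells)
C = np.eye(n_batches)[batch]

# Technical effect t = h(Z, C; phi): stand-in network with random, untrained weights.
W1 = rng.normal(scale=0.5, size=(n_topics + n_batches, 16))
W2 = rng.normal(scale=0.5, size=(16, n_genes))
t = np.tanh(np.hstack([Z, C]) @ W1) @ W2

# Counts: negative binomial with mean mu = exp(log_depth + lam + t) and dispersion r.
# NumPy parameterizes NB by (n, p); p = r / (r + mu) gives mean mu when n = r.
log_depth, dispersion = 1.0, 5.0
mu = np.exp(log_depth + lam + t)
counts = rng.negative_binomial(dispersion, dispersion / (dispersion + mu))
```

Because t receives Z as an input, nothing in this likelihood alone prevents t from absorbing biological variation, which is the identifiability problem the next section addresses.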

Mutual information-based disentanglement of technical and biological phenomena

We aim to learn the distribution of expression rates for cells in state Z without confounding technical effects. However, the dependency diagram of the model (Fig. 1c) shows that expression rates (λ) and technical effects (t) are not independent, because both depend on cell state (Z). To estimate gene expression under these conditions, we therefore need to make an additional assumption about the relationship between technical effects and biological quantities. Here we propose to approximate the unconfounded distribution of biological quantities by penalizing the dependence between biological quantities (λ) and technical effects (t) through regularization of the mutual information between their distributions. To implement this, we augment the evidence lower bound (ELBO) objective function31 with a lower bound approximation of the mutual information38,39,40. The result is a novel objective function we call the COvariate Disentangling Augmented Loss, or CODAL, which is a differentiable approximation of the marginal likelihood of the data minus the mutual information between λ and t. Optimizing with this objective yields a generative distribution that explicitly estimates unconfounded biological quantities in cells across batches and explains the influence of technical effects on observed counts in single-cell genomics experiments.

Penalizing mutual information between biological quantities and technical effects encourages the model to learn a function for technical effects that is largely independent of the cell state random variable Z, while still allowing it to model those technical effects that do appear state-dependent. The CODAL objective therefore produces a generative explanation of the data that is a compromise between two extremes: an idealized representation in which technical effects are assumed to be completely independent of the biological variation, and current methods in which technical effect estimates are unconstrained and vary freely with cell state. When confronted with batch-confounded cell types (Fig. 1d, e), the practical implications of mutual information regularization are apparent: the CODAL objective enforces a distribution for technical effects that is not highly dependent on cell state, effectively disentangling the distributions of biological and technical effects, whereas marginal likelihood maximization (implemented using the “vanilla” ELBO objective) finds a complex technical effect function that confounds the coherent estimation of biological quantities.
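The objective described above combines the ELBO with a lower-bound estimate of the mutual information between λ and t. The sketch below shows one such estimator, the Donsker-Varadhan bound used by MINE-style methods; whether CODAL uses this particular bound is our assumption, and the one-parameter bilinear critic, the grid search that replaces critic training, and the helper names are hypothetical simplifications.

```python
import numpy as np


def dv_lower_bound(x, y, c, rng):
    """Donsker-Varadhan bound I(X;Y) >= E_P[T] - log E_Q[exp T] for the critic T(x, y) = c * x * y."""
    joint = c * x * y                          # critic on paired samples, drawn from P(X, Y)
    marginal = c * x * rng.permutation(y)      # shuffling y breaks the pairing, approximating P(X)P(Y)
    m = marginal.max()                         # log-mean-exp computed stably
    log_mean_exp = m + np.log(np.mean(np.exp(marginal - m)))
    return joint.mean() - log_mean_exp


def mi_estimate(x, y, rng, grid=np.linspace(-0.45, 0.45, 19)):
    """Tighten the bound by picking the best critic from a one-parameter family.
    MINE-style estimators instead train a neural-network critic by gradient ascent."""
    return max(dv_lower_bound(x, y, c, rng) for c in grid)


def codal_style_objective(elbo, mi, weight):
    """Quantity to maximize: ELBO (marginal-likelihood surrogate) minus a weighted MI penalty."""
    return elbo - weight * mi


rng = np.random.default_rng(0)
n = 20000
lam = rng.normal(size=n)                              # stand-in for estimated expression rates
t_entangled = 0.8 * lam + 0.6 * rng.normal(size=n)    # technical effect correlated with lam (r = 0.8)
t_disentangled = rng.normal(size=n)                   # technical effect independent of lam
mi_entangled = mi_estimate(lam, t_entangled, rng)
mi_disentangled = mi_estimate(lam, t_disentangled, rng)
```

For the entangled pair the estimate is clearly positive but stays below the true value of about 0.51 nats (that is, -0.5 log(1 - 0.8²)), as a lower bound should; for the independent pair it is close to zero, so the penalty would only bite on the entangled solution.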

To assess the effect of mutual information regularization on gene expression estimates in a batched scRNA-seq dataset, we compared a model with parameters estimated using the standard ELBO objective maximization to one estimated using the CODAL objective. The batched dataset, created for the 2021 NEURIPS Multimodal Single-Cell Data Integration challenge33, was generated by distributing bone marrow samples from multiple donors to multiple laboratories for single-cell RNA-seq and ATAC-seq analysis. The resulting dataset has a hierarchical batch structure, where multiple donors were assayed in different batches at more than one site. Without mutual information regularization, we find that technical effects and expression changes are frequently correlated or anti-correlated. With increasing observed counts, the model confounds batch effects and biological changes (Fig. 2a). For the Ccl5 gene, a marker for NK and CD8+ T cells33,41, the unregularized model erroneously predicted expression to be highest in erythroblast cells. The monocyte marker Tcf7l233,42 likewise shows a pattern of high expression in some cell types that is not supported by the observed counts. Importantly, the solutions produced from optimizing the ELBO objective yield incoherent and entangled explanations for the observed data.

Fig. 2: Mutual information regularization disentangles the influence of biological and technical effects.

a Predicted gene expression rates versus predicted technical effects for naïve B-cell, NK/CD8+ T-cell, and monocyte marker genes estimated using a topic model trained with the marginal likelihood maximization objective (\(\mathcal{V}_{\mathrm{ELBO}}\)). Colored by observed expression counts, expert-annotated cell types, and batches. The batches are labeled according to single-cell assay site (s1-s4) and sample donor (d1-d10). b Predicted expression rates versus technical effects for the same genes, estimated using a topic model trained with mutual information regularization (\(\mathcal{V}_{\mathrm{CODAL}}\)). Source data are provided as a Source Data file.

In contrast, mutual information regularization yields uncorrelated and disentangled expression rates and technical effects for each marker gene (Fig. 2b), where the known associated cell types are predicted as having the highest expression. Notably, the marginal distribution of expression rates within each cell type is similar across batches, so the cell types are arrayed vertically according to the relative expression levels of genes within these cell types. Batches, meanwhile, are arrayed horizontally according to the basal levels of counts observed within them. The CODAL objective therefore disentangles biological and technical contributions to observed counts and represents those contributions as independent factors.
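A simple diagnostic for the pattern in Fig. 2 is the per-gene correlation, across cells, between estimated expression rates and technical effects. The helper below (`per_gene_correlation`, a hypothetical name) assumes both estimates are available as cells × genes arrays; the toy data contrast an entangled gene with a disentangled one.

```python
import numpy as np


def per_gene_correlation(lam, t):
    """Pearson correlation across cells between expression rate and technical effect, one value per gene.
    lam, t: arrays of shape (n_cells, n_genes)."""
    lc = lam - lam.mean(axis=0)
    tc = t - t.mean(axis=0)
    num = (lc * tc).sum(axis=0)
    den = np.sqrt((lc ** 2).sum(axis=0) * (tc ** 2).sum(axis=0))
    return num / den


rng = np.random.default_rng(0)
lam = rng.normal(size=(5000, 2))
t = np.column_stack([
    2.0 * lam[:, 0] + 0.1 * rng.normal(size=5000),   # gene 0: technical effect tracks expression (entangled)
    rng.normal(size=5000),                           # gene 1: technical effect independent (disentangled)
])
r = per_gene_correlation(lam, t)
```

Values near ±1 flag genes whose expression and technical estimates are entangled, as in Fig. 2a; values near 0 correspond to the disentangled picture in Fig. 2b.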

CODAL latent space shows strong performance on cell-type discrimination and batch correction benchmarks

Applied to the 2021 NEURIPS dataset, CODAL successfully merges batches, finds shared cell types, and distributes those cell types along known paths of hematopoietic differentiation (Fig. 3a, b). Evaluating the quality of the latent space using silhouette widths computed with expert-annotated cell types, we found that cells with the same label tend to be closer together in the CODAL-derived latent space than in the space derived by scVI20, a more traditional VAE model that uses likelihood maximization with unconstrained technical effect modeling (Fig. 3c). Next, we benchmarked CODAL against popular batch correction methods shown to be effective in an extensive benchmarking study21 (Fig. 3d, Supplementary Figs. 2 and 3). Notably, CODAL demonstrated both batch mixing and cell type colocalization comparable to scANVI, which previously demonstrated best-in-class atlas-level integration21 and was fully supervised on cell type labels for this test. Through regularization of the technical effect function, CODAL yields more discriminative biological latent space descriptors than VAE and geometry-based models without loss of capacity for technical effect correction.
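scikit-learn provides `silhouette_score` and `silhouette_samples`; to stay dependency-free, the sketch below reimplements the per-cell silhouette width in NumPy. It assumes the batch ASW is the plain silhouette computed with batch labels, as the Fig. 3d caption implies (lower values mean better mixing); some benchmarking suites rescale this score differently. The toy latent space is hypothetical.

```python
import numpy as np


def silhouette_samples_np(X, labels):
    """Per-cell silhouette width s = (b - a) / max(a, b) with Euclidean distances (needs >= 2 labels)."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    uniq = np.unique(labels)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_same = same.sum() - 1                 # exclude the cell itself
        if n_same == 0:
            continue                            # singleton clusters get s = 0 by convention
        a = D[i, same].sum() / n_same           # D[i, i] = 0, so summing over the cell itself is harmless
        b = min(D[i, labels == u].mean() for u in uniq if u != labels[i])
        s[i] = (b - a) / max(a, b)
    return s


def average_silhouette_width(X, labels):
    return silhouette_samples_np(X, labels).mean()


rng = np.random.default_rng(0)
cell_type = np.repeat([0, 1], 100)
latent = rng.normal(size=(200, 2)) + 6.0 * cell_type[:, None]   # two well-separated cell types
batch = rng.integers(2, size=200)                                # batches assigned at random (well mixed)
asw_cell_type = average_silhouette_width(latent, cell_type)      # high: cell types are separated
asw_batch = average_silhouette_width(latent, batch)              # near zero: batches are mixed
```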

Fig. 3: Benchmarking demonstrates CODAL’s strong performance relative to popular batch effect correction algorithms.

a UMAP representations of NEURIPS bone marrow gene expression data based on CODAL and PCA latent spaces, colored by batch. The batches are labeled according to single-cell assay site (s1-s4) and sample donor (d1-d10). Principal component analysis (PCA) does not correct for technical effects. b UMAPs colored according to expert-annotated cell types. c Cell type silhouette per cell comparing CODAL to scVI latent spaces. Higher cell silhouette widths indicate that cells are closer to other cells with the same cell type label. d Benchmarks of cell type and batch average silhouette widths (ASW) calculated from latent spaces using multiple methods. Increasing cell type ASWs correspond to greater similarities between cells with the same annotations. Decreasing batch ASWs correspond to greater mixing of cells from different batches. scANVI is a semi-supervised method that makes use of cell-type label inputs. MIRA is a baseline topic modeling algorithm that does not correct for technical effects. e Selected topic composition shown on the CODAL latent space-based UMAP. f Topic 14 associations (from \(\beta\) matrix) versus gene expression log-fold changes for CD16+ monocytes relative to other cell types in the dataset. Genes are colored by inclusion in the HuBMAP CD16+ monocyte ontology (black) versus not (gray). The marginal distribution of \(\beta\) associations is shown as a density plot on the y-axis, with a rug plot marking the genes in the CD16+ ontology. Genes in the CD16+ ontology were significantly enriched in the top 200 topic 14-associated genes (p-value=\(1.5\times {10}^{-19}\), one-sided Fisher’s exact test). g Topic 14 composition across CD16+ monocytes, CD14+ monocytes, myeloid progenitors, and other cell types. Source data are provided as a Source Data file.

In addition to yielding a well-separated latent space, CODAL is designed to have interpretable latent dimensions, unlike typical deep latent variable models. CODAL’s use of a sparsity-inducing hierarchical Dirichlet prior results in latent dimensions that coherently convey changes in cell-type identity and potentially deconvolve contributions of gene regulatory programs (Fig. 3e, Supplementary Fig. 4). Each latent variable, or topic, is linearly associated with changes in gene expression through the β matrix. These sets of associations capture some covarying element, or module, of gene expression. Topic 14, for instance, precisely described the CD16+ monocyte identity program (Fig. 3f, g), and the captured associations were well correlated with the log-fold change of those genes’ expression in CD16+ monocytes relative to the other cell types in the dataset43. While differential gene expression and log-fold change are usually defined by investigator-driven clustering after latent space construction, CODAL topics and their relationships with gene expression changes are learned jointly and directly from the data.
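The enrichment test reported in the Fig. 3 caption (top 200 topic-associated genes versus an ontology gene set, one-sided Fisher's exact test) can be sketched with the hypergeometric upper tail, which is equivalent to the one-sided Fisher test for over-representation. The helper names and toy gene universe are hypothetical, and the choice of background universe is an assumption that affects the p-value.

```python
import math

import numpy as np


def top_k_genes(beta_row, genes, k=200):
    """Genes with the k largest associations for one topic (one row of the beta matrix)."""
    order = np.argsort(beta_row)[::-1][:k]
    return [genes[i] for i in order]


def fisher_enrichment_pvalue(selected, gene_set, universe):
    """One-sided Fisher's exact test (hypergeometric upper tail) for over-representation of gene_set in selected."""
    universe = set(universe)
    selected = set(selected) & universe
    gene_set = set(gene_set) & universe
    N, K, n = len(universe), len(gene_set), len(selected)
    k_obs = len(selected & gene_set)
    tail = sum(math.comb(K, x) * math.comb(N - K, n - x) for x in range(k_obs, min(K, n) + 1))
    return tail / math.comb(N, n)


genes = [f"gene{i}" for i in range(20)]
beta_row = np.arange(20, dtype=float)[::-1]          # gene0 has the strongest association
top = top_k_genes(beta_row, genes, k=5)              # ['gene0', 'gene1', 'gene2', 'gene3', 'gene4']
p = fisher_enrichment_pvalue(top, genes[:5], genes)  # all 5 ontology genes recovered: p = 1 / C(20, 5)
```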

Since the bone marrow dataset is multimodal, assaying gene expression and chromatin accessibility in the same single cells, we also carried out the benchmarking analysis on the ATAC-seq data (Supplementary Figs. 5 and 6a, b), finding CODAL to perform well by several metrics. In addition to categorical indicators of batch, CODAL allows for the inclusion of continuous proxies for technical effects (Supplementary Figs. 6c, d and 7). One possible proxy is the FRiP score (fraction of reads in peaks), a commonly used ATAC-seq quality control metric44. scATAC-seq data are typically analyzed by first calling “peaks”, or frequently accessible loci, from the aggregate profile of reads sequenced across all cells. Individual cells are encoded as vectors of binary variables indicating the presence or absence of fragments in peaks. The FRiP score is used to identify and remove cells whose observed reads come primarily from noisy background genomic regions and do not contribute to the aggregate read peaks. A peak set derived from the bulk signature of a batched sample will tend to be biased toward the most common cell types and the largest batches, which can influence the distribution of peaks observed in each cell in a manner that depends jointly on batch and cell state. The FRiP score is therefore an example of a cell state-dependent technical factor in single-cell ATAC-seq analysis. Including the FRiP score in the CODAL analysis, in addition to batch indicators, resulted in a slight improvement in performance (Supplementary Fig. 6a, b) and reduced separation of cell type subpopulations that were affected by large differences in FRiP (Supplementary Fig. 6c). This suggests that further improvements to technical effect disentanglement might be attainable through careful characterization of sources of technical bias in single-cell analysis protocols.
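As a sketch of how categorical and continuous technical proxies might be combined into the covariate matrix C, the code below computes FRiP as reads in peaks divided by total reads and appends a standardized FRiP column to one-hot batch indicators. Standardizing the continuous covariate is our assumption, and the helper names and toy read counts are hypothetical.

```python
import numpy as np


def frip_score(reads_in_peaks, total_reads):
    """Fraction of each cell's reads that fall inside called peaks."""
    return np.asarray(reads_in_peaks, dtype=float) / np.asarray(total_reads, dtype=float)


def build_covariates(batch_labels, frip):
    """One-hot batch indicators plus a standardized FRiP column as technical-effect proxies."""
    batches, idx = np.unique(np.asarray(batch_labels), return_inverse=True)
    onehot = np.eye(len(batches))[idx]
    z = (frip - frip.mean()) / frip.std()
    return np.hstack([onehot, z[:, None]]), [str(b) for b in batches]


frip = frip_score([300, 150, 400, 100], [1000, 1000, 800, 500])   # [0.3, 0.15, 0.5, 0.2]
C, batch_names = build_covariates(["s1", "s1", "s2", "s2"], frip)  # C has 2 batch columns + 1 FRiP column
```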

CODAL-enhanced refinement of regulatory models connecting chromatin accessibility and gene expression

While CODAL’s batch-corrected latent space estimation is useful for constructing and merging cell state atlases, its estimates of gene expression, chromatin accessibility and technical effects can be used to improve other types of single-cell analysis. Since mutual information regularization renders CODAL technical effect estimates effectively independent of the biological variation underlying changes in gene expression, we hypothesized these technical effect estimates are “transferable”. This means CODAL-estimated technical effects can be fixed in a second generative model that incorporates explanatory features that were not included in the primary CODAL analysis. The resultant generative distribution will adjust for the confounding technical effects, and its parameters can be estimated using standard likelihood maximization. Therefore, CODAL technical effect estimates could be used to extend other generative models to batched data.
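The transfer idea amounts to treating pre-computed technical effects as a fixed offset in a new model's log-mean, so that only the new explanatory parameters are estimated. The sketch below fits a Poisson GLM by Newton-Raphson with such an offset; the Poisson (rather than negative binomial) likelihood, the single stand-in feature in `X`, and the simulated per-batch offsets are simplifying assumptions, not the MIRA implementation.

```python
import numpy as np


def fit_poisson_with_offset(X, y, offset, n_iter=50, tol=1e-10):
    """Newton-Raphson for a Poisson GLM with log E[y] = X @ w + offset; the offset is fixed, not estimated."""
    w = np.zeros(X.shape[1])
    for _ in range(n_iter):
        mu = np.exp(X @ w + offset)
        grad = X.T @ (y - mu)                   # score of the Poisson log-likelihood
        info = X.T @ (X * mu[:, None])          # Fisher information (negative Hessian)
        step = np.linalg.solve(info, grad)
        w = w + step
        if np.max(np.abs(step)) < tol:
            break
    return w


rng = np.random.default_rng(0)
n = 5000
x = rng.normal(size=n)
X = np.column_stack([np.ones(n), x])            # intercept plus one new explanatory feature
batch = rng.integers(2, size=n)
t = np.log(np.where(batch == 0, 0.5, 2.0))      # stand-in for pre-computed CODAL technical effects
w_true = np.array([0.5, -0.3])
y = rng.poisson(np.exp(X @ w_true + t))
w_hat = fit_poisson_with_offset(X, y, offset=t)  # close to w_true
```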

For example, the MIRA software for the integrative analysis of multimodal single-cell RNA-seq and ATAC-seq includes a method for relating gene expression changes with cis-regulatory chromatin accessibility28. This method models the cis-regulatory environment of a gene as a regulatory potential, in which the influence a chromatin-accessible genomic interval has on a gene’s expression decays exponentially with the genomic distance between the transcription start site of the gene and the region itself. For each gene, MIRA estimates the decay rates upstream and downstream of the gene along with the relative activities of the upstream, downstream and promoter regions (Fig. 4a). In the previous implementation of this method, the model parameters could be learned from only one batch of single-cell multimodal data, because technical effect confounders from both the RNA-seq and ATAC-seq batches would obscure the regulatory relationship or preclude its estimation. We augmented the regulatory potential model by adding CODAL technical effect vectors as a fixed aspect of its generative distribution, then applied the regulatory potential analysis to batches of 10x Genomics Multiome data comprising the bone marrow hematopoiesis dataset33. In principle, the regulatory ranges and activations estimated by the augmented model should be more accurate than the parameters estimated from non-batch-corrected data.
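A simplified version of the regulatory potential calculation can make the distance-decay idea concrete. In this sketch, which is not MIRA's implementation, peaks within a hypothetical 1.5 kb promoter window count at full weight, while upstream and downstream peaks are down-weighted by exp(-distance / decay); in MIRA the decay rates and region activities are learned per gene, whereas here they are fixed inputs.

```python
import numpy as np


def regulatory_potential(accessibility, peak_offsets, upstream_decay, downstream_decay,
                         weights=(1.0, 1.0, 1.0), promoter_width=1500):
    """accessibility: (n_cells, n_peaks); peak_offsets: signed bp distance from the TSS (negative = upstream).
    weights: hypothetical activities of the (upstream, downstream, promoter) regions."""
    d = np.asarray(peak_offsets, dtype=float)
    acc = np.asarray(accessibility, dtype=float)
    promoter = np.abs(d) <= promoter_width
    upstream = (d < 0) & ~promoter
    downstream = (d > 0) & ~promoter
    decay = np.zeros_like(d)
    decay[upstream] = np.exp(-np.abs(d[upstream]) / upstream_decay)
    decay[downstream] = np.exp(-d[downstream] / downstream_decay)
    w_up, w_down, w_prom = weights
    return (w_up * acc[:, upstream] @ decay[upstream]
            + w_down * acc[:, downstream] @ decay[downstream]
            + w_prom * acc[:, promoter].sum(axis=1))


acc = np.eye(3)                         # three cells, each with exactly one accessible peak
offsets = [-10_000, 0, 10_000]          # upstream, promoter and downstream peaks (bp from TSS)
rp = regulatory_potential(acc, offsets, upstream_decay=10_000, downstream_decay=5_000)
# approximately [0.368, 1.0, 0.135]
```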

Fig. 4: Disentanglement improves the performance of other count-based single-cell models.

a Regulatory potential (RP) model relating Slc25a37 proximal chromatin accessibility with gene expression. Model was trained while adjusting for both chromatin accessibility and gene expression technical effects using pre-trained technical effect vectors from the CODAL model. Fragment plot shows chromatin accessibility dynamics during differentiation from hematopoietic stem cells (HSCs) to normoblasts. Regulatory potential (the expression prediction given by the RP model) correlates with increasing accessibility of nearby loci and with observed expression counts. b Analysis of log-likelihood (ℓ) of RP models trained with and without augmented technical effect correction. Each batch of the bone marrow dataset was split into training and testing datasets. For all highly variable genes, RP models were trained on the training sets of each batch individually, on all batches combined (“All”), or on all batches with technical effect correction (“All + correction”). The likelihood of each subsequent set of RP models was evaluated across each batch’s test set. The cell colors in the heatmap are column-normalized. (right) The “Total ℓ” gives the likelihood of that model across all held-out cells. (bottom) For each batch, the increase in likelihood afforded by technical effect correction was compared against the likelihood of the uncorrected model trained only on that batch. Technical effects explain the difference in likelihood between models trained and tested within a batch versus across all batches. c Comparison of the likelihood of RP models trained with and without technical effect correction for each gene. d For the top 200 genes associated with the proerythroblast-specific topic 4, enrichment of transcription factor motifs within open chromatin regions predicted to be influential on expression of each of those genes by regulatory potential models. Enrichment was calculated using the probabilistic in silico-deletion test (one-sided). The enrichment p-value is shown for each transcription factor when tested against regulatory potential models trained on all batches with technical effect correction (“All + correction”) versus without (“All”). Significantly enriched motifs (\(\alpha=0.05,\) Bonferroni adjusted) are labeled. e Gene expression of each factor whose motif was significantly enriched in (d), for each cell type in the erythroid lineage. Source data are provided as a Source Data file.

We trained uncorrected regulatory potential models on data from each batch alone and across all batches combined. To compare these with CODAL-corrected models trained across all batches, we computed the likelihoods of the models on held-out cells from each batch (Fig. 4b). Differences in technical effects confounded the application of models trained on one batch when tested on another. When trained across all batches without correction, the models suffered from high bias because variation in technical effects between samples was left unexplained. Finally, training across all batches with fixed CODAL technical effect correction vectors gave the most likely model across all batches and broadly across all genes (Fig. 4c), despite having the same trainable parameters as the uncorrected model, since the correction vectors were held fixed.
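The bias described above can be reproduced in a toy setting: a single shared rate fit across two batches with different technical scaling predicts held-out counts worse than the same rate fit with the batch offsets held fixed. This is an illustrative Poisson simulation in which known offsets stand in for CODAL estimates; it is not the RP model, and the batch names and scaling factors are hypothetical.

```python
import math

import numpy as np


def poisson_loglik(y, mu):
    """Total Poisson log-likelihood of counts y under means mu."""
    return float(np.sum(y * np.log(mu) - mu) - sum(math.lgamma(v + 1) for v in y))


rng = np.random.default_rng(1)
true_rate = 5.0
tech = {"batch_a": math.log(0.5), "batch_b": math.log(2.0)}   # hypothetical per-batch technical offsets


def simulate(n):
    t = np.concatenate([np.full(n, tech["batch_a"]), np.full(n, tech["batch_b"])])
    return rng.poisson(true_rate * np.exp(t)), t


y_train, t_train = simulate(2000)
y_test, t_test = simulate(2000)

# Both models have exactly one trainable parameter (the shared rate).
rate_uncorrected = y_train.mean()                        # MLE ignoring batch
rate_corrected = y_train.sum() / np.exp(t_train).sum()   # MLE with offsets held fixed

ll_uncorrected = poisson_loglik(y_test, np.full(len(y_test), rate_uncorrected))
ll_corrected = poisson_loglik(y_test, rate_corrected * np.exp(t_test))
```

The uncorrected rate lands between the two batch-specific means and fits neither batch well, while the offset-corrected model recovers the underlying rate and achieves a higher held-out likelihood without adding any trainable parameters.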

To determine whether the improvements in likelihood translate to a better understanding of the biology of the system, we used a procedure called probabilistic in silico deletion to find JASPAR motifs45 enriched within the influential chromatin regions prescribed by regulatory potential functions for 200 genes upregulated in proerythroblast cells (Fig. 4d, e). Regulatory potential models that better explain the true cis-regulation of these genes should contain more relevant motif hits. Without technical effect correction, the only significant motif enrichment (Bonferroni-corrected p-value < 0.05, n = 1600) was for TAL1-GATA1 complex motifs, which represent a key transcription factor complex promoting early erythropoiesis46,47. Regulatory potential models trained with CODAL technical effect correction yielded a three-order-of-magnitude increase in the significance of TAL1-GATA1 enrichment, in addition to significant enrichment for relevant GATA family and TAL1-TCF3 complex motifs46.

These results suggest that the technical effects learned as part of CODAL modeling remove independent and confounding sources of technical variation in the data. Their incorporation into the generative distributions of other models enhances biological signal recovery in multi-batch modeling, improving the performance and interpretability of these models even over single-batch applications.

CODAL-enabled identification of batch-confounded cell types on Frankencell benchmarks

One of the most valuable uses of CODAL is its application to single-cell analyses in which cell types are confounded with batch. In studies of cell state plasticity and differentiation, for example, it is of interest to compare perturbations of biological systems, such as gene knockouts, with each other and with wild-type cells. Such projects are likely to be carried out in a series of batches, with discoveries made in early batches informing later experiments. We constructed gold-standard benchmarking datasets using Frankencell, a program that synthesizes cell differentiation trajectories by sampling reads and cells from real multimodal datasets (Fig. 5a). In these benchmarks we simulated datasets in which wild-type and perturbed cells were collected in different batches. The batch data were derived from real 10x Genomics Multiome batches, so the data sampled from these batches are affected by real technical effects. We modulated the difficulty of the tests by varying parameters defining the synthetic datasets, including the degree to which each of two terminally differentiated cell types is enriched or depleted in each batch, and the mixing proportions determining the states of the synthesized cells. In this way, we varied the extent to which known biological variation was confounded by technical batch effects.
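The read-mixing step can be sketched as drawing reads without replacement from parent cells in specified proportions, so that a synthetic cell inherits both the biology and the technical profile of its parents' batch. This is a hypothetical simplification for illustration, not Frankencell's code; `mix_cells`, the rounding of read allocations, and the toy parent counts are our assumptions.

```python
import numpy as np


def mix_cells(parent_counts, proportions, depth, rng):
    """Synthesize one cell by drawing reads without replacement from parent cells in the given proportions.
    parent_counts: (n_parents, n_genes) integer counts from real cells of the same batch."""
    parent_counts = np.asarray(parent_counts, dtype=np.int64)
    proportions = np.asarray(proportions, dtype=float)
    n_reads = np.round(proportions / proportions.sum() * depth).astype(np.int64)
    mixed = np.zeros(parent_counts.shape[1], dtype=np.int64)
    for counts, n in zip(parent_counts, n_reads):
        if n > counts.sum():
            raise ValueError("requested more reads than the parent cell contains")
        mixed += rng.multivariate_hypergeometric(counts, int(n))
    return mixed


rng = np.random.default_rng(0)
parents = rng.poisson(5.0, size=(2, 100))   # two parent cells, e.g. one per terminal cell type
cell = mix_cells(parents, proportions=[0.3, 0.7], depth=100, rng=rng)
```

Varying `proportions` along a path between parent cell types yields a simulated trajectory, and choosing which parent types are available in each batch controls how strongly cell type is confounded with batch.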

Fig. 5: Mutual information-based disentanglement enables the discovery of batch-confounded cell types.

a “Frankencell” synthetic dataset generation and evaluation algorithm. Starting from either of two batches in the NEURIPS bone marrow dataset which exhibit different technical effects, reads from annotated cell clusters were mixed according to a construction plan to create a simulated differentiation trajectory interpolating between cell types. Datasets were composed of a trajectory constructed from both batches to represent known sources of biological and technical variation. By controlling terminal cell states present in each trajectory, we introduced batch-confounded cell types. We also varied the base cell similarity to measure method robustness. The trajectories were then integrated and evaluated against the construction plan using established trajectory comparison metrics. b Results from F1 branch score metric across all tests, colored by method. F1 branch score measures the similarity of predicted cell branch assignments to the construction plan. All branch assignments were calculated using MIRA pseudotime trajectory inference on the integrated trajectory. c Example UMAPs of CODAL and Harmony latent spaces, colored by true cell type and batch of origin. Cells from different batches were slightly offset for readability. Each example was taken from the “Medium” base cell similarity test. Source data are provided as a Source Data file.

We used CODAL, PCA corrected with Harmony, and scVI to construct latent space representations of the inferred biological states, then compared the resulting spaces to ground truth trajectories using established trajectory metrics implemented in the dynverse benchmarking package48 (Fig. 5b, Supplementary Fig. 8a). The most difficult tests are the “completely confounded” ones, in which the perturbation batch alone produces a new terminal cell type that is absent from the simulated wild-type batch. As expected given the limitations of likelihood-based models, scVI’s performance deteriorates as the level of batch-confounded biology in the test increases, while Harmony, which performs particularly well on smaller problems21, has difficulties with the completely confounded samples (Fig. 5b, c). CODAL, meanwhile, is robust to all types of confounding while still merging shared cell types. On these benchmarks, CODAL never failed to recover the correct trajectory.

Next, we used a “completely confounded” Frankencell dataset to investigate how the mutual information regularization term of the CODAL objective function affects representation quality and repeatability in batch-confounded cell systems (Supplementary Fig. 8b). We varied the mutual information regularization strength by introducing a weight on that term of the objective function. For each regularizer strength, we modeled the same dataset ten times with different random seeds, then assessed the quality of the reconstructed trajectory and the variance of the resulting technical effect estimates. We found that modeling the dataset using CODAL’s default regularization strength resulted in a 15-fold reduction in the variance of technical effect estimation versus unconstrained marginal likelihood maximization, while producing the best-reconstructed trajectories (Supplementary Fig. 8c). Of the total variance describing biological and technical effects across ten CODAL models, only 0.14% was attributable to repeated technical effect estimation.

Integrated analysis of wild-type and Tal1 knockout embryonic developmental atlases

Pijuan-Sala et al.8 constructed an extensive single-cell RNA-seq atlas of mouse embryonic development spanning stages E6.5 to E8.5. In addition to the wild-type mouse, this study included an analysis of the impact of knocking out the transcription factor Tal1, which is necessary for embryonic erythropoiesis. The complete knockout of Tal1 is lethal in the mouse embryo, so they created a chimeric mouse embryo combining wild-type and Tal1 knockout cells (Tal1-/-). Wild-type and chimeric embryo data were generated in separate batches, which exhibited severe batch effects in the uncorrected data, precluding meaningful direct comparisons between chimeric and wild-type manifolds (Supplementary Fig. 9a). They therefore presented an analysis based on the projection of the chimeric cells onto the wild-type manifold. In the resulting projection, no erythroid cells were identified with a Tal1-/- genotype, as expected49, and the mapping resulted in an overabundance of cells mapped as hemato-endothelial progenitors, suggesting the Tal1-/- cells were arrested in a state preceding Tal1 activation. Their analysis of the projected hemato-endothelial progenitor cells revealed transcriptional differences from the wild-type cells with which they were co-embedded. Pijuan-Sala et al. had originally identified three Tal1-/--specific hemato-endothelial progenitor subtypes which also expressed marker genes typically indicative of major mesodermal populations. In particular, they found a mesenchyme-like population expressing marker Tdo2; an allantois-like population expressing Pcolce; and a population expressing cardiac-related genes Nkx2-5, Mef2c, and Tnnt2. The core assumption in the projection-based analysis is that the reference population comprises all cell types present in the projecting population. However, since the Tal1-/- and wild-type hemato-endothelial progenitor cells showed transcriptional differences, this suggests that the Tal1-/- cells contain cell states that are rare or nonexistent in the wild-type.

Reanalyzing this dataset with CODAL, we constructed a new latent space where the wild-type and chimeric Tal1-/- batches were modeled together (Fig. 6a). Whereas severe batch effects are present in the PCA-based analysis of the combined batches (Fig. 6b), the batches are well integrated in the CODAL analysis. Consistent with Pijuan-Sala et al., we found an abundance of Tal1-/- cells arrested in hemato-endothelial progenitor-like states (Fig. 6c). The CODAL latent space, however, revealed that while these cells were most similar to wild-type hemato-endothelial progenitor cells, they crucially occupied their own latent subspace and showed transcriptional signatures unseen in wild-type cells. After using the Leiden algorithm to cluster Tal1-/- hemato-endothelial progenitors in the CODAL latent space, we identified five distinct subtypes (Fig. 6d) whose numbers were induced or promoted in the chimeric perturbation. We matched expression in these subtypes to wild-type mesodermal and endothelial subtypes through shared marker gene expression and topic composition (Fig. 6e, f, Supplementary Fig. 9b). Populations B, C and E expressed the aforementioned mesenchymal, cardiomyocyte, and allantois markers, respectively, corroborating the subtypes discovered in Pijuan-Sala et al. Furthermore, we found previously undescribed subtype A, with high expression of the endothelial marker Emcn, and ExE mesoderm-like subtype D, with high expression of homeobox transcription factors Hoxb9, Hoxc8, and Cdx4. Crucially, although each subcluster expressed distinct gene sets associated with different mesodermal cell populations, all subclusters expressed high levels of hemato-endothelial progenitor-specific markers Kdr, Sox17, and Esam8. Thus, the perturbation-induced subclusters exhibit combinations of cell identity programs that were not sampled in wild-type development.

Fig. 6: Re-analysis of mouse embryo differentiation using CODAL reveals new cell types in Tal1 knockout chimera.

a UMAP of CODAL latent space calculated from the mouse embryo differentiation dataset. From left to right, UMAP colored by cell batch of origin, day of differentiation at which cells were collected, and cell type annotations provided by the authors of the original study. Batches WT1-3 contained only WT cells. Chimera-Control embryos were injected with Tomato-expressing but otherwise unaltered embryonic stem cells (ESCs) early in differentiation. Chimera-Tal1-/- embryos were injected with Tomato-expressing Tal1-/- ESCs. b UMAP of PCA latent space. c CODAL UMAP colored by Tomato presence in Chimera-Tal1-/- batch. Tomato-expressing (Tom+) cells have Tal1-/- genotype; Tom- cells have WT genotype. All other batches are shown in gray. d Zoom-in on hemato-endothelial progenitor (HE-prog) cell type population. New cell types present only in Tom+ Tal1-/- perturbed population were subclustered using the Leiden algorithm and shown in shades of red. Previously uncharacterized subclusters are labeled in bold. From the original projection-based analysis, the closest WT cells to each Tal1-/- cell in clusters A-E are shown in blue. e Marker gene expression, column-normalized, comparing gene expression in WT populations with Tal1-/- HE-prog subclusters. f Selected topic compositions shown on UMAP. Shared topics indicate covarying gene expression between Tal1-/- subclusters and specific mesodermal cell types. g (left) UMAP of CODAL latent space calculated from embryo dataset with contrived batch structure maximizing confounded biology and technical variation. (right) Zoom-in on HE-prog cell type population, colored by Tomato expression in Chimera-Tal1-/- batch. All other batches are shown in gray. Tom- cells were removed to increase confounding effects. h Marker gene expression shown on HE-prog cell type subset of UMAP from (g). Source data are provided as a Source Data file.

The chimeric embryo system is a well-controlled system for studying perturbations because wild-type and perturbed cells can be compared within the same animal and batch. However, this approach cannot be applied to most types of perturbation experiments. As a further test of CODAL’s ability to resolve batch-confounded biology, we reanalyzed this dataset after excluding the control Tal1+/+ wild-type cells from the chimeric mouse batch and removing samples that connected batches across time (Supplementary Fig. 10). Even in this dataset with contrived batch effect confounding, we obtained results similar to those of the full analysis (Fig. 6g, h).

Notably, we can reproduce the original ad hoc projection and differential expression analysis using only the CODAL model’s inferred latent space and gene expression modules. In addition to CODAL latent spaces themselves being powerful representations of biological states, they are also useful for refining downstream analyses such as differential abundance testing and pseudotime trajectory analysis. Single-cell atlas construction using CODAL’s batch correction method is therefore a powerful new strategy for making biological discoveries through the comparison of wild-type and perturbed systems.

Discussion

We describe CODAL, a method for correcting technical batch effects in single-cell RNA-seq, ATAC-seq, and multimodal data while allowing for differences in biological states between batches. CODAL’s variational autoencoder-based framework uses topic modeling to infer interpretable gene co-regulation and transcription factor binding modules in respective single-cell RNA-seq and ATAC-seq datasets. We introduce a novel mutual information-based regularization scheme that disentangles the contribution of biological modules and technical distortions to observed read counts and enables the estimation of gene expression and chromatin accessibility levels independent of a cell’s batch of origin. Empirically, we demonstrate that CODAL produces latent representations of cell state which are both cell-type discriminative and batch-corrected, on par with current state-of-the-art supervised models.

CODAL is particularly well suited for identifying batch-confounded cell states that escape detection by current methodologies. In the comparison of cell state atlases in perturbed model systems, such as gene knockouts and drug treatments, this facilitates the discovery of different biological states and trajectories even when confounded by batch effects. In practice, CODAL enables the integrated analysis of experiments involving a series of batches, where observations from initial batches inform subsequent gene perturbations. This approach is a powerful and practical strategy for analyzing complex cell systems.

CODAL can robustly disentangle technical and biological effects in perturbation experiments when the effect of that perturbation varies according to cell state. To design experiments that satisfy this condition, a simple guideline is to include a control group of cells that do not experience or are known not to respond to the perturbation in both treatment and control batches. In differentiating systems, this may take the form of genetic perturbations which only affect cells in more differentiated states, as demonstrated in the Pijuan-Sala et al. and Frankencell analyses. Drug treatments may be integrated by multiplexing treatment samples with controls or again by measuring a group of cells that are known to be unaffected.

While this enables increased flexibility in experimental design, we caution that the existence of new batch-confounded or perturbation-induced cell states should be validated using orthogonal biological evidence to rule out the under-correction of technical effects. For example, bona fide cell states may have expression patterns that reflect known gene ontologies or demonstrate state-specific accessible chromatin which is enriched for transcription factor motifs. Although CODAL can overcome some of the challenges imposed by confounding batch effects, experimentalists nevertheless must practice good experimental design principles and include appropriate controls50,51 as outlined above.

CODAL models variation in technical effects in a cell state-dependent manner. In addition to using batch identifiers as covariates for batch correction, CODAL can use other indicators of technical effects including single-cell RNA-seq and ATAC-seq quality control metrics, which could enable more accurate inference of chromatin accessibility or gene expression levels. In scATAC-seq data we discovered one such jointly cell-state and batch-dependent technical covariate, the FRiP score. When introduced into the CODAL model, the FRiP score improved technical effect removal in chromatin accessibility data. Further improvements in the single-cell analysis might be achieved through better characterization of the technical sources of bias in conjunction with the CODAL model.

Finally, CODAL explicitly estimates the contribution of technical effects in the generative distribution of read counts. Because the inferred technical effects are effectively independent of expression levels, one may use them to augment other generative models of read counts with batch-corrective power. CODAL’s correction of technical artifacts through multi-batch integration can therefore be used as a preprocessing step to improve the performance of other single-cell analysis methods. This could expand the application of generative models that are typically limited to single-batch analysis to include multi-batch datasets. Taken together, CODAL is a useful and much-needed tool for the study of comparative single-cell atlases and for the mechanistic dissection of dynamic cell systems.

Methods

We describe the COvariate Disentangling Augmented Loss (CODAL) method for the integration of batched single-cell RNA-seq or single-cell ATAC-seq experiments. CODAL specifies a generative model of observed counts which explicitly represents the effects of biological and technical factors in the data. The contributions of these components are disentangled using a novel objective function that regularizes mutual information between the distributions of biological and technical effects. CODAL thereby uncovers an explanation for the influence of technical effects which is maximally independent of biological variation, enabling decomposition of biological effects into interpretable modules, detection of batch-confounded cell states, and representation of cells in a batch-corrected, yet cell type discriminative, latent space.

Because CODAL can detect and represent batch-confounded cell states, this method is useful for the analysis of single-cell perturbation atlases. We infer the parameters of the generative model using fast and memory-efficient streamed gradient ascent over a differentiable approximation of the objective function. The hyperparameters of the model are tuned to a given dataset using highly parallelizable Bayesian optimization. Thus, the method is scalable to large datasets and can leverage more compute resources for faster inference.

In the section “Generative model”, we describe the CODAL generative model of observed counts and outline the assumptions made to estimate the rates of expression of genes or chromatin accessibility of genomic loci without technical effect confounding. In “Parameter inference”, we define the objective function used to infer the generative model parameters. We then describe differentiable approximations of the terms of the objective function, which enable fast inference using stochastic gradient ascent, and summarize the complete algorithm used to train a CODAL model. In “CODAL Bayesian hyperparameter optimization”, we describe the Bayesian hyperparameter optimization scheme which tunes the model hyperparameters (chiefly the number of topics) to best represent a particular dataset. In “CODAL technical effect augmentation of regulatory potential model”, we show how the technical effect estimates of a CODAL model may be used to augment other generative models of observed counts with technical effect correction. “Analysis of NEURIPS bone marrow dataset” details the NEURIPS bone marrow multimodal batch effect benchmark dataset and analyses. In “Frankencell batch-confounded cell type tests”, we describe the Frankencell benchmark test for batch-confounded cell types. In “Analysis of mouse embryo differentiation and perturbation”, we provide details on the analysis of the mouse embryo differentiation and perturbation dataset. Finally, “Computational resources benchmarking” describes the benchmarking of CODAL’s computational resource requirements.

The CODAL generative model is based on the earlier MIRA28 method for variational autoencoder-based inference of topics from scRNA-seq and scATAC-seq data. CODAL uses the same nonlinear encoder and linear decoder model architectures, latent space reparameterization, and priors as MIRA (sections “Generative model with interpretable biological states and technical effects” and “Lower bound on marginal likelihood”). Unique to CODAL are model components that facilitate technical effect disentanglement, a new parameter inference algorithm and objective function, and a redesigned hyperparameter tuning scheme which is faster and more parallelizable.

Generative model

The CODAL topic model, an extension of the MIRA topic model, accounts for technical effect confounding biological signal in the analysis of multi-batch scRNA-seq, scATAC-seq, or multimodal data. The model specifies a generative explanation of the cell’s observed features (RNA-seq or ATAC-seq reads) as the sum of biological and technical effects. We assume biological effects are purely a reflection of cell state and follow a low-dimensional manifold prescribed by concerted changes in gene regulatory programs in individual cells. Therefore, we model the biological state as a mixture of hidden latent variables which are linearly related to changes in the cell’s underlying gene expression or chromatin accessibility state. Each latent variable, or “topic”, is associated with a set of linear weights describing the gene expression or genomic loci accessibility changes that are linked to the topic. These topics capture axes of association and covariation evident in the data and suggest the influence of some shared underlying facet of cell state.

Generative model with interpretable biological states and technical effects

The technical factors acting on a given batch of cells and their subsequent effects on observed counts are typically unknown, but cell state and technical covariates like batch of origin may serve as proxy variables indicating cells that are influenced by similar factors. We approximate the true distribution of technical effects as a function of these proxies, implemented as a neural network to reflect the unknown functional form of technical effects.

To sample from the generative distribution of transcript counts, \({X}^{{{{{\rm{RNA}}}}}}\in {{\mathbb{Z}}}_{[0,\infty )}^{{N}_{{{{{\rm{cells}}}}}}\times {N}_{{{{{\rm{genes}}}}}}}\) (each observation is an element of the set of integers between 0 and infinity) in a multi-batch scRNA-seq dataset with associated cell technical covariates \(C\in {{\mathbb{R}}}^{{N}_{{{{{\rm{cells}}}}}}\times {N}_{{{{{\rm{covariates}}}}}}}\), we first draw from the distribution of the cell state latent random variables \(Z\in {{\mathbb{R}}}_{[{{{{\mathrm{0,1}}}}}]}^{{{N}_{{{{{\rm{cells}}}}}}\times N}_{{{{{\rm{topics}}}}}}}\), where Zi is a mixture over the latent variables for cell i:

$$\mathop {\sum }\limits_{d=1}^{{N}_{{{{{\rm{topics}}}}}}}{Z}_{{id}}=1,\forall i\in \{1,\, \ldots,\,{N}_{{{{{\rm{cells}}}}}}\}$$
(1)

Like Latent Dirichlet Allocation (LDA)26, we suppose that the cell state mixture of latent variables is Dirichlet-distributed and sparse, such that only a few hidden factors are active in defining the state of each cell. We specify a hierarchical prior that controls the pseudocounts, \(\alpha \in {{\mathbb{R}}}_{(0,\infty )}^{{N}_{{\rm{topics}}}}\), allotted to each latent variable, where \({\mathcal{I}}\in {{\mathbb{Z}}}_{(0,\infty )}\) is the total number of pseudocounts allotted to the Dirichlet distribution:

$${Z}_{i\cdot } \sim {{{{{\rm{Dirichlet}}}}}}({\alpha }_{1},\,\ldots,\,{\alpha }_{{N}_{{{{{{\rm{topics}}}}}}}}),\,\forall i\in \{1,\ldots,\, {N}_{{{{{{\rm{cells}}}}}}}\}\\ {\alpha }_{d} \sim {{{{{\rm{Gamma}}}}}}\left(2,\,\frac{2{N}_{{{{{{\rm{topics}}}}}}}}{ {\mathcal I} }\right),\,\forall d\in \{1,\,\ldots,\,{N}_{{{{{{\rm{topics}}}}}}}\}$$
(2)

This allows for data-driven tuning of the prior’s sparsity to fit diverse trends and modalities. We fix the total pseudocounts hyperparameter at \({{{{{\mathcal{I}}}}}}=50\) for all Ntopics so that the sparsity of the Dirichlet hyperprior is not influenced by the dimensionality of the latent space. Next, we estimate biological effects, or expression rates \(\lambda \in {{\mathbb{R}}}^{{{N}_{{{{{\rm{cells}}}}}}\times N}_{{{{{\rm{genes}}}}}}}\), as a function of cell state Z and topic matrix \(\beta \in {{\mathbb{R}}}^{{N}_{{{{{\rm{topics}}}}}}\times {N}_{{{{{\rm{genes}}}}}}}\), which linearly links the influence of topics to changes in expression space:

$$\begin{array}{c}{\lambda }_{{ij}}={{{{{\rm{batchnorm}}}}}}\big({{{{{{\rm{dropout}}}}}}}(Z_{i\cdot })\times {\beta }_{\cdot j}\big),\\ \forall i\in \{1,\ldots,\,{N}_{{{\mbox{cells}}}}\},\forall j\in \left\{1,\ldots,\,{N}_{{{{{{\rm{genes}}}}}}}\right\}.\end{array}$$
(3)

The dropout function52 regularizes the expression model and increases stability during parameter inference. We use the PyTorch53 implementation of dropout, which sets units in the input matrix to zero at rate p, then rescales the matrix by \(\frac{1}{1-p}\). The dropout rate parameter is set within a range of [0.05, 0.1] during hyperparameter tuning. The batchnorm function54 performs the following affine transformation:

$${{\mbox{batchnorm}}}\big({Z}_{i\cdot }{\beta }_{\cdot j}\big)={\gamma }_{{bj}}\frac{{Z}_{i\cdot }{\beta }_{\cdot j}-{\mu }_{Z{\beta }_{\cdot j}}}{{\sigma }_{Z{\beta }_{\cdot j}}}+{b}_{{bj}},$$
(4)

which standardizes the product \({Z}_{i\cdot }{\beta }_{\cdot j}\) by the mean and standard deviation of the product across all cells for that gene (\({\mu }_{Z{\beta }_{\cdot j}}\) and \({\sigma }_{Z{\beta }_{\cdot j}}\)), then rescales by gene-specific biological effect variance and bias parameters \({\gamma }_{b}\in {{\mathbb{R}}}^{{N}_{{{{{\rm{genes}}}}}}}\) and \({b}_{b}\in {{\mathbb{R}}}^{{N}_{{{{{\rm{genes}}}}}}}\).
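To make Eqs. 1-4 concrete, the following is a minimal NumPy sketch (not the authors' PyTorch implementation) that draws topic pseudocounts from the Gamma hyperprior, samples cell-state mixtures Z from the Dirichlet, and computes expression rates λ with inverted dropout and training-mode batch normalization. The topic matrix β and the batchnorm parameters \(\gamma_b\), \(b_b\) are placeholder values, and the function name `sample_prior_and_rates` is hypothetical. Because Gamma(2, rate \(2N_{\rm topics}/\mathcal{I}\)) has mean \(\mathcal{I}/N_{\rm topics}\), the expected total pseudocount \(\sum_d \alpha_d\) equals \(\mathcal{I}\) for any number of topics, which is why fixing \(\mathcal{I}=50\) keeps the prior's sparsity independent of the latent dimensionality.

```python
import numpy as np


def sample_prior_and_rates(n_cells, n_topics, n_genes, total_pseudocounts=50.0,
                           dropout_p=0.05, eps=1e-5, seed=0):
    """Hypothetical helper: draw alpha, Z and lambda following Eqs. 1-4."""
    rng = np.random.default_rng(seed)
    # Eq. 2 hyperprior: alpha_d ~ Gamma(shape=2, rate=2 * N_topics / I).
    # NumPy parameterizes the Gamma by scale, so scale = 1 / rate.
    rate = 2.0 * n_topics / total_pseudocounts
    alpha = rng.gamma(shape=2.0, scale=1.0 / rate, size=n_topics)
    # Eq. 2: Z_i ~ Dirichlet(alpha), so every row lies on the simplex (Eq. 1).
    Z = rng.dirichlet(alpha, size=n_cells)
    # Placeholder parameters (assumed): random topic matrix, identity batchnorm affine.
    beta = rng.normal(size=(n_topics, n_genes))
    gamma_b = np.ones(n_genes)
    b_b = np.zeros(n_genes)
    # Inverted dropout (as in PyTorch): zero entries with probability p,
    # then rescale the surviving entries by 1 / (1 - p).
    keep = rng.random(Z.shape) >= dropout_p
    Z_drop = Z * keep / (1.0 - dropout_p)
    # Eqs. 3-4: standardize Z beta per gene across cells, then rescale and shift.
    prod = Z_drop @ beta
    mu = prod.mean(axis=0)
    var = prod.var(axis=0)
    lam = gamma_b * (prod - mu) / np.sqrt(var + eps) + b_b
    return alpha, Z, lam
```

For example, `alpha, Z, lam = sample_prior_and_rates(200, 10, 50)` returns 10 pseudocounts, a 200 × 10 matrix of topic mixtures, and a 200 × 50 matrix of expression rates whose per-gene mean is zero before the \(b_b\) shift.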

Next, technical effects \(t\in {{\mathbb{R}}}^{{N}_{{{{{\rm{cells}}}}}}\times {N}_{{{{{\rm{genes}}}}}}}\) are estimated as a function of cell state Z and the cell-level covariate proxies C. C can include one-hot encoded batch indicator variables and continuous quality control metrics, which could be useful for technical effect disentanglement. The technical effect function is implemented using the neural network hϕ with weights ϕ, where the output is zero-centered, then rescaled by the technical effect variance parameter \({\gamma }_{t}\in {{\mathbb{R}}}^{{N}_{{{{{\rm{genes}}}}}}}\), and the technical effect predictions are Bernoulli-corrupted with probability 0.05. Bernoulli corruption of technical effects stabilizes cyclical training of the topic model and the mutual information regularizer (see “Parameter inference”). The hϕ neural network has one hidden layer with 32 nodes followed by an output layer. For inputs Zi· and Ci· for cell i; weights \({W}_{h}^{0}\in {{\mathbb{R}}}^{\left({N}_{{{{{{\rm{topics}}}}}}}+{N}_{{{{{{\rm{covariates}}}}}}}\right)\times 32}\) and \({W}_{h}^{1}\in {{\mathbb{R}}}^{32\times {N}_{{{{{{\rm{genes}}}}}}}}\); biases \({b}_{h}^{0}\in {{\mathbb{R}}}^{32}\) and \({b}_{h}^{1}\in {{\mathbb{R}}}^{{N}_{{{{{{\rm{genes}}}}}}}}\); and layer intermediate \({\nu }_{h}^{0}\), the technical effect for gene j is given by:

$$\begin{array}{c}{t}_{ij}={\left({h}_{\phi }\left({Z}_{i\cdot },{C}_{i\cdot }\right)\right)}_{j}=\left\{\begin{array}{ll}{\gamma }_{tj}\frac{{t}_{ij}^{\prime }-{\mu }_{{t}_{\cdot j}^{\prime }}}{{\sigma }_{{t}_{\cdot j}^{\prime }}} & {\rm{if}}\,c=0\\ 0 & {\rm{otherwise}}\end{array}\right.,\quad c \sim {\rm{Bernoulli}}\left(\frac{1}{20}\right),\\ {t}_{i\cdot }^{\prime }={\nu }_{h}^{0}{W}_{h}^{1}+{b}_{h}^{1},\\ {\nu }_{h}^{0}=\left({\rm{dropout}}\circ {\rm{ReLU}}\circ {\rm{batchnorm}}\right)\left[\left({\rm{dropout}}({Z}_{i\cdot })\bigoplus {C}_{i\cdot }\right){W}_{h}^{0}+{b}_{h}^{0}\right],\\ \forall i\in \left\{1,\ldots,\,{N}_{\rm{cells}}\right\},\,\forall j\in \left\{1,\ldots,\,{N}_{\rm{genes}}\right\},\\ \phi=\{{\gamma }_{t},\,{W}_{h}^{1},\,{b}_{h}^{1},\,{W}_{h}^{0},\,{b}_{h}^{0}\}.\end{array}$$
(5)

Above, \(\bigoplus\) indicates concatenation and the dropout rate is set to 1/20.
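The forward pass of \(h_\phi\) in Eq. 5 can be sketched in NumPy as follows. This is an illustration rather than the authors' code, under stated assumptions: the caller supplies the weights, the batchnorm inside the hidden layer is assumed to have no learnable affine parameters, and the Bernoulli corruption is assumed to be drawn independently for each (cell, gene) entry. The function name `technical_effects` is hypothetical.

```python
import numpy as np


def technical_effects(Z, C, W0, b0, W1, b1, gamma_t, p=0.05, eps=1e-5, rng=None):
    """Hypothetical NumPy forward pass of h_phi (Eq. 5)."""
    rng = np.random.default_rng(0) if rng is None else rng

    def dropout(x):
        # Inverted dropout with rate p (1/20 in the text).
        keep = rng.random(x.shape) >= p
        return x * keep / (1.0 - p)

    def batchnorm(x):
        # Standardize each unit across cells (assumed: no learnable affine here).
        return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

    # Concatenate dropout(Z_i) with the covariates C_i (the concatenation operator).
    h_in = np.concatenate([dropout(Z), C], axis=1)
    # Hidden layer of 32 units: affine -> batchnorm -> ReLU -> dropout.
    nu0 = dropout(np.maximum(batchnorm(h_in @ W0 + b0), 0.0))
    # Output layer t' with one value per gene.
    t_prime = nu0 @ W1 + b1
    # Zero-center and standardize each gene, then rescale by gamma_t.
    t = gamma_t * (t_prime - t_prime.mean(axis=0)) / (t_prime.std(axis=0) + eps)
    # Bernoulli corruption: c ~ Bernoulli(1/20); the effect is zeroed when c = 1.
    c = rng.random(t.shape) < p
    return np.where(c, 0.0, t)
```

Because the output is standardized per gene before rescaling by \(\gamma_t\), the network can only shift genes relative to one another within the batch of cells, which is consistent with the zero-centering constraint on technical effects used later in Eq. 10.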

Finally, counts \({X}^{\rm{RNA}}\) from a scRNA-seq experiment are drawn from a Negative Binomial noise distribution parameterized by the sum of the biological and technical effects and a gene-level dispersion parameter \(\vartheta \in {{\mathbb{R}}}^{{N}_{{{{{\rm{genes}}}}}}}\). For each cell, the sums of biological and technical effects across all genes are first transformed into a composition \(\rho \in {{\mathbb{R}}}^{{N}_{{{{{\rm{cells}}}}}}\times {N}_{{{{{\rm{genes}}}}}}}\), where \({\sum }_{j=1}^{{N}_{{{{{\rm{genes}}}}}}}{\rho }_{{ij}}=1,\, \forall i\in \{1,\ldots,\,{N}_{{{{{\rm{cells}}}}}}\}\), describing the underlying categorical distribution over transcript counts in each cell. That composition is scaled by the learned effective read depth, or “size factor”, random variable \({n}_{i}\) for cell i to give the rate parameter of the Negative Binomial distribution:

$$\begin{array}{c}{X}_{ij}^{\rm{RNA}} \sim {\rm{NegativeBinomial}}\left({n}_{i}{\rho }_{ij},\,{\vartheta }_{j}\right),\,\forall i\in \{1,\ldots,\,{N}_{\rm{cells}}\},\,\forall j\in \{1,\ldots,\,{N}_{\rm{genes}}\},\\ {n}_{i} \sim {\rm{LogNormal}}\left(\log \mathop{\sum }\limits_{j=1}^{{N}_{\rm{genes}}}{X}_{ij}^{\rm{RNA}},\,1\right),\,\forall i\in \{1,\ldots,\,{N}_{\rm{cells}}\},\\ {\rho }_{ij}=\frac{\exp ({\lambda }_{ij}+{t}_{ij})}{\mathop{\sum }\nolimits_{l=1}^{{N}_{\rm{genes}}}\exp ({\lambda }_{il}+{t}_{il})},\,\forall i\in \{1,\ldots,\,{N}_{\rm{cells}}\},\,\forall j\in \{1,\ldots,\,{N}_{\rm{genes}}\}.\end{array}$$
(6)

The size factor random variable is given a weak LogNormal prior centered at the observed number of counts measured in that cell. For scATAC-seq counts \({X}^{{{{{\rm{ATAC}}}}}}\in {{\mathbb{Z}}}_{\left[{0,1}\right]}^{{N}_{{{{{\rm{cells}}}}}} \times {N}_{{{{{\rm{peaks}}}}}}}\), we model observations of accessibility across all regions in each cell using the multinomial distribution:

$$\begin{array}{c}{X}_{i\cdot }^{\rm{ATAC}} \sim {\rm{Multinomial}}\left({\rho }_{i\cdot },\,{\hat{n}}_{i}^{\rm{ATAC}}\right),\,\forall i\in \{1,\ldots,\,{N}_{\rm{cells}}\},\\ {\hat{n}}_{i}^{\rm{ATAC}}=\mathop{\sum }\limits_{k=1}^{{N}_{\rm{peaks}}}{X}_{ik}^{\rm{ATAC}},\,\forall i\in \{1,\ldots,\,{N}_{\rm{cells}}\},\\ {\rho }_{ik}=\frac{\exp ({\lambda }_{ik}+{t}_{ik})}{\mathop{\sum }\nolimits_{l=1}^{{N}_{\rm{peaks}}}\exp ({\lambda }_{il}+{t}_{il})},\,\forall i\in \{1,\ldots,\,{N}_{\rm{cells}}\},\,\forall k\in \{1,\ldots,\,{N}_{\rm{peaks}}\}.\end{array}$$
(7)
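The two noise models in Eqs. 6 and 7 can be sketched with NumPy as below. The sketch assumes the Negative Binomial is parameterized by its mean \(n_i\rho_{ij}\) and an inverse-dispersion \(\vartheta_j\) (so the variance is \(\mu + \mu^2/\vartheta\)) and draws it as a Gamma-Poisson mixture; that parameterization and the helper names are assumptions for illustration, not taken from the source.

```python
import numpy as np


def softmax_rows(x):
    # Row-wise softmax; subtracting the row max avoids overflow in exp.
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)


def sample_rna_counts(lam, t, theta, observed_depth, rng):
    """Hypothetical draw from Eq. 6 given lambda, t, dispersions and depths."""
    rho = softmax_rows(lam + t)  # composition over genes for each cell
    # Size factor n_i ~ LogNormal(log sum_j X_ij, 1): weak prior at observed depth.
    n = rng.lognormal(mean=np.log(observed_depth), sigma=1.0)
    mu = n[:, None] * rho  # Negative Binomial mean (assumed parameterization)
    # Gamma-Poisson mixture: rate ~ Gamma(theta, mu / theta), X ~ Poisson(rate).
    rate = rng.gamma(shape=theta, scale=mu / theta)
    return rng.poisson(rate), rho


def sample_atac_counts(lam, t, depth, rng):
    """Hypothetical draw from Eq. 7: one multinomial per cell over peaks."""
    rho = softmax_rows(lam + t)
    return np.stack([rng.multinomial(int(d), p) for d, p in zip(depth, rho)])
```

The ATAC draw conditions on each cell's observed total \(\hat{n}_i^{\rm ATAC}\), so every simulated cell keeps its read depth exactly, whereas the RNA model lets the size factor vary around the observed depth.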

Besides the noise distribution, all other aspects of the generative distribution are the same for models of both modalities. In sum, the CODAL topic model is a generative probabilistic latent variable model which describes variation in single-cell count data as the sum of biological and technical effects. The latent topics are modeled as sampled from a sparse Dirichlet-distributed prior and are linearly related to changes in gene expression or peak accessibility. The noise distribution is adapted to fit the specific properties of whichever modality is modeled.

Finally, the unknown distribution of technical effects is approximated using a neural network conditioned on cell state and cell technical covariate proxies. Technical covariates can be provided as either categorical features, which are represented as one-hot vectors, or continuous features, which are standardized. Any number of batches or number of features per batch may be provided.
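A covariate matrix C of this kind could be assembled as in the following sketch, which one-hot encodes each categorical covariate (for example, batch of origin) and standardizes each continuous covariate (for example, the FRiP score). The helper name `encode_covariates` and the guard for constant columns are assumptions for illustration.

```python
import numpy as np


def encode_covariates(categorical, continuous):
    """Hypothetical helper building C from categorical and continuous covariates.

    categorical: list of 1-D label arrays, one per categorical covariate.
    continuous:  array of shape (n_cells, n_continuous); may have zero columns.
    """
    blocks = []
    for labels in categorical:
        # One column per observed level; each row contains exactly one 1.
        levels, idx = np.unique(np.asarray(labels), return_inverse=True)
        blocks.append(np.eye(len(levels))[idx.reshape(-1)])
    cont = np.asarray(continuous, dtype=float)
    if cont.ndim == 2 and cont.shape[1] > 0:
        sd = cont.std(axis=0)
        sd[sd == 0] = 1.0  # assumed guard: constant columns become all zeros
        blocks.append((cont - cont.mean(axis=0)) / sd)
    return np.concatenate(blocks, axis=1)
```

Any number of categorical levels or continuous features simply widens C, which matches the statement that any number of batches or features may be provided.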

CODAL disentangled expression and accessibility predictions

In using the generative model to estimate gene expression in some cell i without confounding technical effects, we make three assumptions. The first is that technical effects t and expression rates λ are mostly independent. We infer a generative model of the data in which this is effectively satisfied using our mutual information regularization scheme (described in the section “Lower bound on mutual information”). The second is that the expectation over observed technical effects is an unbiased estimator of the true mean effect, and the third is that this mean effect has no influence on the gene. Consequently, for many batches measuring cells in the same state, we assume that observed counts will be distributed about the true biological measurement of counts.

First, we calculate the expected value of the latent variable composition in some cell, \({\bar{Z}}_{i\cdot }\), using the posterior distribution:

$${\bar{Z}}_{i\cdot }={\rm{E}}_{z \sim p\left(\cdot \mid {X}_{i},{C}_{i}\right)}\left[z\right],\,\forall i\in \left\{1,\ldots,\,{N}_{\rm{cells}}\right\},$$
(8)

Then, the unconfounded compositional expression of gene j in that cell, \({\bar{\rho }}_{{ij}}\), is given by:

$$\begin{array}{c}{\bar{\rho }}_{ij}=\frac{\exp ({\bar{\lambda }}_{ij})}{\mathop{\sum }\nolimits_{l=1}^{{N}_{\rm{genes}}}\exp ({\bar{\lambda }}_{il})},\\ {\bar{\lambda }}_{ij}={\gamma }_{b,j}\frac{{\bar{Z}}_{i\cdot }{\beta }_{\cdot j}-{\mu }_{Z{\beta }_{\cdot j}}}{{\sigma }_{Z{\beta }_{\cdot j}}}+{b}_{b,j},\\ \forall i\in \{1,\ldots,\,{N}_{\rm{cells}}\},\,\forall j\in \{1,\ldots,\,{N}_{\rm{genes}}\}.\end{array}$$
(9)

Recall that we constrain the distribution of sample technical effects to be zero-centered. Thus, \({\bar{\lambda }}_{ij}\) above equals the expected value, over technical effects, of the expression rate plus the technical effect:

$${\rm{E}}_{t}\left[{\bar{\lambda }}_{ij}+{t}_{ij}\right]={\bar{\lambda }}_{ij}+{\rm{E}}\left[{t}_{ij}\right]={\bar{\lambda }}_{ij},\quad {\rm{E}}\left[{t}_{ij}\right]=0.$$
(10)
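The computation in Eqs. (8) and (9) can be sketched in NumPy as follows, under several assumptions. `Z_bar` stands in for posterior-mean topic compositions and is drawn at random here rather than taken from a trained encoder. `mu` and `sigma` are treated as fixed normalization statistics of Zβ. All array sizes are hypothetical. Subtracting the row-wise maximum before exponentiating leaves the softmax unchanged but avoids overflow.

```python
import numpy as np


def unconfounded_expression(Z_bar, beta, mu, sigma, gamma, b):
    """Eq. (9): compositional expression with technical effects at their zero mean.

    Z_bar: (cells, topics) expected topic compositions
    beta:  (topics, genes) topic-gene associations
    mu, sigma, gamma, b: (genes,) normalization statistics and affine parameters
    """
    lam_bar = gamma * (Z_bar @ beta - mu) / sigma + b        # expected rates, no t term
    lam_bar = lam_bar - lam_bar.max(axis=1, keepdims=True)   # numerical stability
    expd = np.exp(lam_bar)
    return expd / expd.sum(axis=1, keepdims=True)            # softmax over genes


rng = np.random.default_rng(0)
n_cells, n_topics, n_genes = 3, 5, 10
Z_bar = rng.dirichlet(np.ones(n_topics), size=n_cells)      # toy compositions
beta = rng.normal(size=(n_topics, n_genes))
ZB = Z_bar @ beta
rho_bar = unconfounded_expression(
    Z_bar, beta, mu=ZB.mean(axis=0), sigma=ZB.std(axis=0) + 1e-6,
    gamma=np.ones(n_genes), b=np.zeros(n_genes),
)
```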

CODAL latent space and nearest neighbor graph

The expected value of the posterior distribution of latent topics for cell i, \({\bar{Z}}_{i\cdot }\), lies on the (\({N}_{\rm{topics}}-1\))-simplex. To analyze distances between cells’ state representations on the simplex, we transform the topic compositions to Euclidean vector space using MIRA’s28 isometric log-ratio transformation of \({\bar{Z}}_{i\cdot }\). We then calculate distances between transformed cell latent variables using the Manhattan distance to create a k-nearest neighbors (k-NN) graph representing cells in similar states. This k-NN graph is used in downstream clustering and UMAP analysis.
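A minimal sketch of this pipeline follows. It assumes NumPy and a Helmert-type orthonormal basis for the isometric log-ratio (ILR) transform. MIRA's implementation may use a different basis. Manhattan distance, unlike Euclidean distance, is not invariant to rotations of that basis, so this sketch will not reproduce MIRA's graph exactly. The brute-force all-pairs distance computation only suits small examples. `k` and the toy compositions are hypothetical.

```python
import numpy as np


def helmert_basis(D):
    """Orthonormal D x (D-1) basis of the subspace orthogonal to the ones vector."""
    V = np.zeros((D, D - 1))
    for k in range(1, D):
        V[:k, k - 1] = 1.0 / k
        V[k, k - 1] = -1.0
        V[:, k - 1] *= np.sqrt(k / (k + 1.0))   # normalize column to unit length
    return V


def ilr(Z):
    """Isometric log-ratio transform of compositions (rows of Z sum to 1)."""
    logZ = np.log(Z)
    clr = logZ - logZ.mean(axis=1, keepdims=True)   # centered log-ratio
    return clr @ helmert_basis(Z.shape[1])


def manhattan_knn(X, k):
    """Indices and L1 distances of each row's k nearest neighbors (excluding itself)."""
    dist = np.abs(X[:, None, :] - X[None, :, :]).sum(axis=-1)
    np.fill_diagonal(dist, np.inf)
    idx = np.argsort(dist, axis=1)[:, :k]
    return idx, np.take_along_axis(dist, idx, axis=1)


rng = np.random.default_rng(0)
Z_bar = rng.dirichlet(np.ones(6), size=20)   # 20 toy cells, 6 topics
X = ilr(Z_bar)                               # 20 x 5 Euclidean coordinates
neighbors, distances = manhattan_knn(X, k=3)
```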

Parameter inference

We aim to learn the distribution of expression rates for cells in state Z without technical effect distortion. From the dependency diagram of the generative model (Fig. 1c), λ and t both depend on Z, so they are not, in general, independent. Consequently, to learn an unconfounded distribution of expression rates with our generative model, we created a novel objective function, \({\mathcal{V}}^{\rm{CODAL}}\), that uses mutual information55 to penalize dependence between predicted expression rates, λ, and technical effects, t.

Given a dataset of independent count observations from cells \({X}^{{{{{{\rm{RNA}}}}}}}\) or \({X}^{{{{{{\rm{ATAC}}}}}}}\) and technical covariates C, CODAL learns the parameters of the generative distribution (associations β, technical effect function parameters ϕ, dispersions ϑ, and batch normalization parameters γb, γt, and bb) as well as the posterior distribution of cell latent variables Z which maximize:

$$\begin{array}{c}{\varphi }_{{{\max }}}={{{\mbox{argmax}}}}_{\varphi }\mathop{\sum }\limits_{i=1}^{{N}_{{{{{{\rm{cells}}}}}}}}{{{{{{\mathcal{V}}}}}}}_{\varphi }^{{{{{{\rm{CODAL}}}}}}}\left({X}_{i},\,{C}_{i}\right)\\ {{{{{{\mathcal{V}}}}}}}_{\varphi }^{{{{{{\rm{CODAL}}}}}}}\left({X}_{i},\,{C}_{i}\right)={{\log }}\int {p}_{\varphi }\left({X}_{i}{{|}}Z,\,{C}_{i}\right){p}_{\varphi }\left(Z\right){dZ}-I\left(\lambda,\,t \right),\\ \varphi=(\beta,\,\phi \,,\, \vartheta \,,\,{\gamma }_{b},\, {b}_{b}).\end{array}$$
(11)

This objective finds parameters that both maximize the marginal likelihood of the observations and minimize the mutual information between estimated biological and technical effects across all cells in the analysis. However, neither term can be computed analytically, so we approximate both with differentiable lower bounds.

Lower bound on mutual information

Mutual information neural estimator

The CODAL objective function requires that we regularize the mutual information between the distributions of expression and technical effect predictions. Here, we describe a lower bound on mutual information which can be estimated using gradient ascent and neural networks. The mutual information between the distributions, \({\mathbb{P}}\), of two continuous random variables, in our case λ and t, is defined by:

$$I\left(\lambda,\,t\right):\!\!={{{{{{\rm{D}}}}}}}_{{{{{{\rm{KL}}}}}}}({{\mathbb{P}}}_{\lambda,t}{{||}}{{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t}),$$
(12)

where \({{\mathbb{P}}}_{\lambda,t}\) is the joint distribution and \({{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t}\) is the product of the marginal distributions of the random variables. Mutual information is therefore zero if and only if the two variables are independent. The Donsker-Varadhan56 dual representation of this KL divergence is:

$${{{{{{\rm{D}}}}}}}_{{{{{{\rm{KL}}}}}}}({{\mathbb{P}}}_{\lambda,t}{{||}}{{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t})={\sup }_{T{{{{{\mathscr{\in }}}}}}{{{{{\mathcal{F}}}}}}}\,{E}_{{{\mathbb{P}}}_{\lambda,t}}\left[T\right]-{{\log }}{E}_{{{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t}}[{e}^{T}],$$
(13)

where the supremum is taken over all functions T in the set \({\mathcal{F}}\) of functions for which both expectations are finite. Belghazi et al.39 approximate the set of functions \({\mathcal{F}}\) with the set of possible weights, Θ, of a neural network. This yields a lower bound on mutual information called the Mutual Information Neural Estimator (MINE), where Tθ is a neural network with weights θ ∈ Θ:

$${{{{{{\rm{D}}}}}}}_{{{{{{\rm{KL}}}}}}}({{\mathbb{P}}}_{\lambda,t}{{||}}{{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t})\ge {\sup }_{\theta \in \Theta }\,{E}_{{{\mathbb{P}}}_{\lambda,t}}\left[{T}_{\theta }\right]-{{\log }}\,\,{E}_{{{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t}}\left[{e}^{{T}_{\theta }}\right].$$
(14)

In practice, for a minibatch of m pairs of λ and t, corresponding to a sample of m cells, the lower-bound estimate of mutual information is calculated as:

$$\begin{array}{c}I\left(\lambda,\,t\right)\ge {\sup }_{\theta \in \Theta }{{{\mbox{MINE}}}}\left(\lambda,\,t{{;}}\,\theta,\,{m}\right),\\ {{{\mbox{MINE}}}}\left(\lambda,\,t{{;}}\,\theta,\,{m}\right)=\frac{1}{m}\mathop{\sum }\limits_{i=1}^{m}{T}_{\theta }({\lambda }_{i\cdot },\,{t}_{i\cdot })-{{\log }}\frac{1}{{m}^{2}}\mathop{\sum }\limits_{i=1}^{m}\mathop{\sum }\limits_{j=1}^{m}{e}^{{T}_{\theta }\left({\lambda }_{i\cdot },\,{t}_{j\cdot }\right)},\\ \left({\lambda }_{1\cdot },\,{t}_{1\cdot }\right),\ldots,\,\left({\lambda }_{m\cdot },\,{t}_{m\cdot }\right) \sim {{\mathbb{P}}}_{\lambda,t},\end{array}$$
(15)

To search the space of Θ, weights θ at step s are updated by gradient ascent with respect to the mutual information estimate to give a tighter lower bound:

$${\theta }^{s+1}\leftarrow {\theta }^{s}+{\nabla }_{{\theta }^{s}}{{{\mbox{MINE}}}}\left(\lambda,\,t{{;}}\,{\theta }^{s},\,m\right).$$
(16)

In summary, the MINE framework trains a neural network to maximize the score difference between paired and unpaired samples of λ and t, and takes gradient ascent steps with respect to the score difference to tighten the resulting lower bound on mutual information.
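For concreteness, Eq. (15) can be evaluated from a matrix of critic scores `S[i, j] = T_θ(λ_i, t_j)`. Its diagonal holds the paired (joint) samples, and its full set of m² entries approximates the product of marginals. The sketch below assumes NumPy and takes the score matrix as given, with toy values. The log-mean-exp uses the usual max-shift for numerical stability.

```python
import numpy as np


def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along `axis` (all entries if None)."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True)) + a_max
    return out.item() if axis is None else np.squeeze(out, axis=axis)


def mine_estimate(scores):
    """Eq. (15): mean critic score on paired samples minus the
    log-mean-exp of critic scores over all m^2 (lambda_i, t_j) pairs."""
    m = scores.shape[0]
    joint_term = np.mean(np.diag(scores))
    marginal_term = logsumexp(scores) - np.log(m * m)
    return joint_term - marginal_term
```

With a constant critic, where every score is equal, the estimate is exactly zero. This is the bound's value when the critic cannot distinguish paired from unpaired samples.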

We found that using MINE mutual information estimation in the CODAL objective function yielded unstable estimates, likely due to properties of the KL divergence explored by Arjovsky et al.57 In particular, the KL divergence is not informative for distributions with non-identical support. Therefore, we investigated the Earth Mover’s, or Wasserstein, distance58 as an alternative measure of the difference between the joint distribution and the product of the marginals of two random variables.

Wasserstein dependency measure

The Wasserstein dependency measure38, IW, substitutes the KL divergence in the formulation of mutual information with the Wasserstein distance, W, yielding:

$${I}_{W}(\lambda,\,t):\!\!=W({{\mathbb{P}}}_{\lambda,t},\,{{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t}).$$
(17)

The Wasserstein distance is defined by the minimum cost of “transporting” the probability density of one distribution onto a second distribution. Typically, this metric is interpreted as the cost of transforming a pile of earth at one location into another at a different location, where the cost incurred is the distance the dirt was moved times the amount. In this scenario, the two piles of dirt are taken to be two distributions, \({{\mathbb{P}}}_{p}\) and \({{\mathbb{P}}}_{q}\), and the Wasserstein distance is the total cost of the most efficient coupling γ, or transport plan, to convert between them:

$$W\left({\mathbb{P}}_{p},{\mathbb{P}}_{q}\right):={\inf }_{\gamma \in \Gamma \left({\mathbb{P}}_{p},{\mathbb{P}}_{q}\right)}\int \left\Vert x-y\right\Vert \,d\gamma \left(x,\,y\right).$$
(18)

Here, x and y are elements of some metric space, the coupling γ is a joint distribution of \({{\mathbb{P}}}_{p}\) and \({{\mathbb{P}}}_{q}\), and \(\Gamma ({{\mathbb{P}}}_{p},\,{{\mathbb{P}}}_{q})\) is the set of all couplings between the distributions.
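As a small worked illustration of Eq. (18), consider two empirical distributions on the real line with the same number of equally weighted points. In one dimension, the optimal coupling simply matches sorted samples, so W is the mean absolute difference between sorted values. The NumPy sketch below checks this against a brute-force search over all one-to-one transport plans on toy data. It illustrates the definition only; CODAL works with the dual form instead.

```python
import itertools

import numpy as np


def wasserstein_1d(x, y):
    """W1 between two equal-size, equally weighted 1-D samples."""
    x, y = np.sort(np.asarray(x, float)), np.sort(np.asarray(y, float))
    if x.shape != y.shape:
        raise ValueError("this shortcut assumes equal sample sizes")
    return np.mean(np.abs(x - y))


def wasserstein_bruteforce(x, y):
    """Minimum average transport cost over all permutation couplings."""
    x, y = np.asarray(x, float), np.asarray(y, float)
    return min(np.mean(np.abs(x - y[list(p)]))
               for p in itertools.permutations(range(len(y))))


x = np.array([0.0, 1.0, 4.0, 2.0, 7.0])
y = np.array([3.0, -1.0, 5.0, 5.0, 2.0])
```

Shifting every point by a constant c moves each unit of mass by c, so W equals c.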

Taking \({{\mathbb{P}}}_{p}\) and \({{\mathbb{P}}}_{q}\) to be our distributions of interest, \({{\mathbb{P}}}_{\lambda,t}\) and \({{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t}\), the Kantorovich-Rubinstein dual form58 of the Wasserstein distance admits:

$$W\left({{\mathbb{P}}}_{\lambda,t},\,{{\mathbb{P}}}_{\lambda }\otimes {{\mathbb{P}}}_{t}\right)={\sup }_{T{{{{{\mathscr{\in }}}}}}{{{{{\mathcal{L}}}}}}}\,{E}_{{{\mathbb{P}}}_{ \lambda,t}}\left[T\right]-{E}_{{{\mathbb{P}}}_{ \lambda }\otimes {{\mathbb{P}}}_{t}}\left[T\right].$$
(19)

where the function T belongs to the set of 1-Lipschitz continuous functions where the expectation is finite: \({{{{{\mathcal{L}}}}}}\). We approximate the set of 1-Lipschitz continuous functions using a neural network with weights constrained such that the function defined by the neural network is 1-Lipschitz continuous: \(\theta \in {\Theta }_{{{{{{\mathcal{L}}}}}}},{\Theta }_{{{{{{\mathcal{L}}}}}}}=\{\theta|\theta \in \Theta,{T}_{\theta }\,{{{{{\mathscr{\in }}}}}}\,{{{{{\mathcal{L}}}}}}\}\). In this way, Ozair et al.38 proposed a lower bound on the Wasserstein dependency measure, w, for a minibatch of m pairs of λ and t:

$$\begin{array}{c}{I}_{W}\left(\lambda,\,t\right)\ge {\sup }_{\theta \in {\Theta }_{\mathcal{L}}}w\left(\lambda,\,t;\,\theta,\,m\right),\\ w\left(\lambda,\,t;\,\theta,\,m\right)=\frac{1}{m}\mathop{\sum }\limits_{i=1}^{m}{T}_{\theta }({\lambda }_{i\cdot },\,{t}_{i\cdot })-\frac{1}{m}\mathop{\sum }\limits_{i=1}^{m}\log \mathop{\sum }\limits_{j=1}^{m}\exp {T}_{\theta }({\lambda }_{i\cdot },\,{t}_{j\cdot }),\\ \left({\lambda }_{1\cdot },\,{t}_{1\cdot }\right),\,\ldots,\,\left({\lambda }_{m\cdot },\,{t}_{m\cdot }\right) \sim {\mathbb{P}}_{\lambda,t}.\end{array}$$
(20)

These authors found the inclusion of the log∑exp term stabilized training of the estimator. This expression for the lower bound on the Wasserstein dependency metric is a lower bound on MINE plus a constant40:

$$w\left(\lambda,\,t{{;}}\, \theta,\,m\right)\le {{{\mbox{MINE}}}}\left(\lambda,\,t{{;}}\,\theta,\,m\right)+{{\log }}m.$$
(21)
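Eq. (20) differs from Eq. (15) only in its contrastive term. Each λ_i is compared against every t_j in the minibatch, with the inner sum running over j, and the log-sum-exp is taken row by row before averaging. The NumPy sketch below works on a precomputed matrix of toy critic scores. It does not enforce the 1-Lipschitz constraint on T_θ, which the bound requires.

```python
import numpy as np


def row_logsumexp(a):
    """Stable log(sum_j exp(a[i, j])) for each row i."""
    a_max = a.max(axis=1, keepdims=True)
    return (np.log(np.exp(a - a_max).sum(axis=1, keepdims=True)) + a_max)[:, 0]


def wasserstein_dependency_bound(scores):
    """Eq. (20): mean paired score minus the mean over i of
    log sum_j exp T(lambda_i, t_j); scores[i, j] = T(lambda_i, t_j)."""
    joint_term = np.mean(np.diag(scores))
    contrastive_term = np.mean(row_logsumexp(scores))
    return joint_term - contrastive_term
```

For an uninformative, all-zero critic on m = 4 samples, the expression equals −log 4, since each row contributes log 4.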

CODAL mutual information regularization

Given the above relationship between mutual information, the MINE estimator, and the Wasserstein dependency measure, we regularize the mutual information between λ and t via a lower bound approximation based on the 1-Lipschitz continuous MINE estimator:

$$I\left(\lambda,t\right)\ge {{{\sup }}}_{\theta \in {\Theta }_{{{{{{\mathcal{L}}}}}}}}{{{{{\rm{MINE}}}}}}\left(\lambda,\,t{{;}}\,\theta,\,m\right).$$
(22)

We use a neural network \({T}_{\theta }:{\mathbb{R}}^{2{N}_{\rm{gene}}}\to {\mathbb{R}}\) to approximate the function T that satisfies the supremum condition in the dual form of mutual information. We use gradient ascent optimization to determine the network parameters θ that maximize the function MINE(λ, t; θ, m). Without 1-Lipschitz constraints on Tθ, the network can maximize this objective by exaggerating small differences between unpaired samples, decoupling the objective score from the apparent strength of the dependence relationship between the variables. Constraining the network to be 1-Lipschitz continuous, as required by Wasserstein distance-based metrics, remedies this. The 1-Lipschitz constraint also dramatically improves the stability of the gradient with respect to the network parameters, eliminating catastrophic gradient overflows during training.

We employ spectral normalization59 to enforce 1-Lipschitz continuity of the neural network estimator Tθ, adjusting the neural network weights θ after each gradient step. The Tθ network, consisting of an input layer and a hidden layer, each with 64 nodes, and an output layer that returns a scalar score, is defined as follows:

$$\begin{array}{c}{T}_{\theta }\left(\lambda,\,t\right)={\upsilon }_{T}^{1}{W}_{T}^{2}+{b}_{T}^{2},\\ {\upsilon }_{T}^{1}={\rm{ReLU}}\left({\upsilon }_{T}^{0}{W}_{T}^{1}+{b}_{T}^{1}\right),\\ {\upsilon }_{T}^{0}={\rm{ReLU}}\left({\rm{dropout}}\left(\lambda \oplus t\right){W}_{T}^{0}+{b}_{T}^{0}\right),\\ \theta =\left\{{b}_{T}^{2},\,{b}_{T}^{1},\,{b}_{T}^{0},\,{W}_{T}^{2},\,{W}_{T}^{1},\,{W}_{T}^{0}\right\},\end{array}$$
(23)

with weights \({W}_{T}^{0}\in {{\mathbb{R}}}^{2{N}_{{{{{{{{\rm{gene}}}}}}}}}\times 64}\), \({W}_{T}^{1}\in {{\mathbb{R}}}^{64\times 64}\), and \({W}_{T}^{2}\in {{\mathbb{R}}}^{64\times 1}\); and biases \({b}_{T}^{0}\in {{\mathbb{R}}}^{64}\), \({b}_{T}^{1}\in {{\mathbb{R}}}^{64}\), and \({b}_{T}^{2}\in {{\mathbb{R}}}^{1}\).
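The architecture in Eq. (23), together with the spectral normalization step, can be sketched in NumPy under these assumptions:

- Dropout is omitted (evaluation mode).
- Weights are initialized from a hypothetical small Gaussian.
- `N_GENES` is a placeholder.
- The spectral norm is computed exactly with `np.linalg.norm(W, 2)`, rather than with the power-iteration approximation usually used in practice.

ReLU is 1-Lipschitz and each weight matrix is rescaled to unit spectral norm. The composed network is therefore 1-Lipschitz with respect to the Euclidean norm, which can be checked empirically on random input pairs.

```python
import numpy as np

rng = np.random.default_rng(0)
N_GENES = 50  # hypothetical; the input to T_theta has 2 * N_GENES features


def init_params(n_in, hidden=64):
    """Weights W^0 (n_in x 64), W^1 (64 x 64), W^2 (64 x 1) and zero biases."""
    shapes = [(n_in, hidden), (hidden, hidden), (hidden, 1)]
    return [(rng.normal(scale=0.1, size=s), np.zeros(s[1])) for s in shapes]


def spectral_normalize(params):
    """Divide each weight matrix by its largest singular value."""
    return [(W / np.linalg.norm(W, 2), b) for W, b in params]


def critic(params, lam, t):
    """T_theta(lambda, t): concatenate inputs, two ReLU layers, linear output."""
    h = np.concatenate([lam, t], axis=1)
    for layer, (W, b) in enumerate(params):
        h = h @ W + b
        if layer < len(params) - 1:
            h = np.maximum(h, 0.0)   # ReLU on hidden layers only
    return h[:, 0]


params = spectral_normalize(init_params(2 * N_GENES))
lam, t = rng.normal(size=(8, N_GENES)), rng.normal(size=(8, N_GENES))
scores = critic(params, lam, t)   # one score per (lambda_i, t_i) pair
```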

Lower bound on marginal likelihood

To compute a differentiable lower bound on the marginal likelihood of the data, we use the variational autoencoder31 approach, which approximates the posterior distribution of Z with the variational distribution \({q}_{\psi }({Z|X},\,C)\). The variational distribution, then, is parameterized by a neural network called the “encoder” with weights ψ, which stochastically maps data samples \((X,C)\) to the latent space \(Z\):

$$Z \sim {q}_{\psi }(\cdot|X,C)={{{\mbox{Encoder}}}}_{\psi }(X,C).$$
(24)

This approximation admits the lower bound on marginal likelihood called the “evidence lower bound”, or ELBO31. For a sample \(({X}_{i},\,{C}_{i})\) from the dataset:

$$\begin{array}{c}\log {p}_{\varphi }\left({X}_{i}\right)=\log \int {p}_{\varphi }\left({X}_{i}\,|\,Z,\,{C}_{i}\right){p}_{\varphi }\left(Z\right)dZ\\ \ge {E}_{Z \sim {q}_{\psi }(\cdot \,|\,{X}_{i},{C}_{i})}\left[\log {p}_{\varphi }\left({X}_{i}\,|\,Z,\,{C}_{i}\right)\right]-{D}_{KL}\left({q}_{\psi }\left(Z\,|\,{X}_{i},{C}_{i}\right)\,||\,{p}_{\varphi }(Z)\right).\end{array}$$
(25)

We parameterize the variational distribution \({q}_{\psi }(Z|X,\,C)\) such that samples are approximately Dirichlet-distributed, using the same method as MIRA. Briefly, for each latent variable in each cell, the encoder neural network outputs mean and variance parameters, which specify a multivariate normal distribution with a diagonal covariance matrix. Samples from this distribution are transformed into compositions following a Laplace approximation to the Dirichlet distribution34.
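For reference, the Laplace approximation to a Dirichlet(α) distribution in the softmax basis, as used by ProdLDA-style topic models, gives a diagonal Gaussian. Its mean is μ_k = log α_k − (1/K) Σ_l log α_l and its variance is σ²_k = (1/α_k)(1 − 2/K) + (1/K²) Σ_l 1/α_l. Samples are mapped onto the simplex with a softmax. The NumPy sketch below applies these formulas to a hypothetical symmetric α. In the encoder, the per-cell means and variances come directly from the network. Mapping those encoder outputs through a softmax is an assumption based on this construction.

```python
import numpy as np


def dirichlet_laplace(alpha):
    """Mean and variance of the diagonal Gaussian approximating Dirichlet(alpha)
    in the softmax basis."""
    alpha = np.asarray(alpha, dtype=float)
    K = alpha.size
    mu = np.log(alpha) - np.log(alpha).mean()
    var = (1.0 / alpha) * (1.0 - 2.0 / K) + (1.0 / alpha).sum() / K**2
    return mu, var


def sample_compositions(mu, var, n, rng):
    """Reparameterized samples z = mu + sigma * eps, mapped to the simplex by softmax."""
    z = mu + np.sqrt(var) * rng.standard_normal((n, mu.size))
    z = z - z.max(axis=1, keepdims=True)   # stabilize the softmax
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
mu, var = dirichlet_laplace(np.ones(5))   # symmetric prior over K = 5 topics
theta = sample_compositions(mu, var, n=1000, rng=rng)
```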

We adapt the architecture of the encoder neural network to the modality being modeled. The expression encoder neural network has an input layer and one hidden layer, each consisting of a fully connected feed-forward layer, batch normalization, ReLU activation60, and dropout. The input and hidden layers each have 512 nodes, and the dropout probability is set to 0.05. Raw expression counts are transformed into deviance residuals61 before being passed to the encoder. The output layer is a fully connected feed-forward layer followed by batch normalization, with mean and variance heads for each latent dimension, plus additional mean and variance heads that parameterize the posterior LogNormal distribution from which the cell-specific size factor, \({n}_{i}\), is drawn.
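The text does not spell out the residual formula. A common choice, and the assumption made in this sketch, is the binomial approximation to multinomial deviance residuals. In this model, gene j's expected count in cell i is n_i π_j, where n_i is the cell's total count and π_j is the gene's overall share of counts. The NumPy sketch below uses toy counts and defines terms of the form 0 · log 0 as zero.

```python
import numpy as np


def _xlogx_ratio(a, b):
    """a * log(a / b), defined as 0 wherever a == 0."""
    with np.errstate(divide="ignore", invalid="ignore"):
        out = a * np.log(a / b)
    return np.where(a > 0, out, 0.0)


def binomial_deviance_residuals(Y):
    """Signed square-root binomial deviance of counts Y (cells x genes)
    against a null model with gene proportions shared by all cells."""
    Y = np.asarray(Y, dtype=float)
    n = Y.sum(axis=1, keepdims=True)              # total counts per cell
    pi = Y.sum(axis=0, keepdims=True) / Y.sum()   # gene proportions under the null
    mu = n * pi                                   # expected counts
    dev = 2.0 * (_xlogx_ratio(Y, mu) + _xlogx_ratio(n - Y, n - mu))
    return np.sign(Y - mu) * np.sqrt(np.maximum(dev, 0.0))


Y = np.array([[5, 0, 5], [1, 1, 8], [0, 0, 3]])   # toy count matrix
R = binomial_deviance_residuals(Y)
```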

The chromatin accessibility encoder neural network is implemented using the Deep Averaging Network architecture62. Each cell is represented as a “bag of peaks” and encoded as the average of the embedding vectors for all peaks in that cell. The average embedding vector is passed through a hidden layer and an output layer, and a skip connection links the embedding layer to the output layer. Thus, for cell \(i\), given embedding matrix \({W}^{0}\in {\mathbb{R}}^{{N}_{\rm{peaks}}\times 256}\); feed-forward weights \({W}^{1}\in {\mathbb{R}}^{(256+{N}_{\rm{covariates}})\times 256}\) and \({W}^{2}\in {\mathbb{R}}^{256\times 2{N}_{\rm{topics}}}\); and biases \({b}^{1}\in {\mathbb{R}}^{256}\) and \({b}^{2}\in {\mathbb{R}}^{2{N}_{\rm{topics}}}\), the accessibility encoder network gives:

$$\begin{array}{c}{{{\mbox{Encoder}}}}_{\psi }(X,\,C)={{{\mbox{batchnorm}}}}\left(({v}_{i\cdot }^{1}+{v}_{i\cdot }^{0}){W}^{2}+{b}^{2}\right)\\ {v}_{i\cdot }^{1}={{{{{{{\rm{dropout}}}}}}}}\,\circ \,{{{{{{{\rm{ReLU}}}}}}}}\, \circ \,{{{{{{{\rm{batchnorm}}}}}}}}\,\circ \,\left[\left({v}_{i\cdot }^{0}\oplus {C}_{i\cdot }\right){W}^{1}+{b}^{1}\right]\\ \begin{array}{c}{v}_{i\cdot }^{0}=\frac{1}{\left|{\Omega }_{i}\right|}\mathop{\sum}\limits_{k\in {\Omega }_{i}}{W}_{k\cdot }^{0}\\ {\Omega }_{i}=\left\{k\right|k\in \left\{1,\, \ldots,\,{N}_{{{{peaks}}}}\right\},\,{X}_{{ik}}^{{{{{{{{\rm{ATAC}}}}}}}}}\, > \,0,\,{{\mbox{Bernoulli}}}\left(\frac{1}{20}\right)=0\left.\right\},\end{array}\end{array}$$
(26)

where Ωi is the set of accessible peaks in cell \(i\), corrupted by leaving out each peak with probability \(p=\frac{1}{20}\) (independent Bernoulli trials). The embedding and hidden layers each have 256 nodes.
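The following NumPy sketch illustrates the forward pass in Eq. (26). It is simplified under stated assumptions: batch normalization and dropout are omitted, the weights are random placeholders, and all dimensions other than the 256-unit embedding are hypothetical. The `drop_prob` argument implements the 1/20 peak corruption. Averaging via a masked matrix product avoids looping over cells.

```python
import numpy as np

rng = np.random.default_rng(0)
N_PEAKS, N_COV, N_TOPICS, D = 1000, 3, 10, 256   # hypothetical sizes (D from the text)

W0 = rng.normal(scale=0.1, size=(N_PEAKS, D))       # peak embeddings
W1 = rng.normal(scale=0.1, size=(D + N_COV, D))
b1 = np.zeros(D)
W2 = rng.normal(scale=0.1, size=(D, 2 * N_TOPICS))
b2 = np.zeros(2 * N_TOPICS)


def bag_of_peaks(X, drop_prob=1 / 20):
    """v^0: average embedding of accessible peaks, each kept with prob 1 - drop_prob."""
    keep = (X > 0) & (rng.random(X.shape) >= drop_prob)
    counts = keep.sum(axis=1, keepdims=True)
    return (keep.astype(float) @ W0) / np.maximum(counts, 1)   # empty bag -> zeros


def encode(X, C, drop_prob=1 / 20):
    v0 = bag_of_peaks(X, drop_prob)
    v1 = np.maximum(np.concatenate([v0, C], axis=1) @ W1 + b1, 0.0)   # hidden layer
    out = (v1 + v0) @ W2 + b2                                          # skip connection
    return out[:, :N_TOPICS], out[:, N_TOPICS:]   # mean head, variance-parameter head


X = (rng.random((5, N_PEAKS)) < 0.02).astype(int)   # sparse toy accessibility matrix
C = rng.normal(size=(5, N_COV))
mean_head, var_head = encode(X, C)
```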

Training procedure

In sections “Lower bound on mutual information” and “Lower bound on marginal likelihood”, we specify differentiable lower bounds of the terms of the CODAL objective function. Using those lower bounds, CODAL employs a cyclical stochastic minibatch gradient ascent optimization strategy to learn parameters governing the latent space and generative distribution. First, for the CODAL objective function:

$${{{{{{{{\mathcal{V}}}}}}}}}_{\varphi }^{{{{{{{{\rm{CODAL}}}}}}}}}\left({X}_{i},\,{C}_{i}\right)={{\log }}\int {p}_{\varphi }\left({X}_{i}{{|}}Z,\,{C}_{i}\right){p}_{\varphi }\left(Z\right)\,{dZ}-I\left(\lambda,\,t \right),$$
(27)

we use the “ELBO” lower bound on the marginal likelihood:

$$ {{\log }}\,{p}_{\varphi }\left({X}_{i}\right)={{\log }}\int {p}_{\varphi }\left({X}_{i}{{|}}Z,\,{C}_{i}\right){p}_{\varphi }\left(Z\right){dZ}\\ \ge {E}_{Z \sim {q}_{\psi }(\cdot|{X}_{i},{C}_{i})}\left[{{{\log }}\,p}_{\varphi }\left({X}_{i}|Z,\,{C}_{i}\right)\right]-{D}_{{KL}}({q}_{\psi }\left(Z|{X}_{i},\,{C}_{i}\right){{||}}{p}_{\varphi }(Z)),$$
(28)

and the 1-Lipschitz regularized “MINE” lower bound on mutual information between expression rates and technical effects across the whole dataset:

$$I\left(\lambda,\,t\right)\ge {\sup }_{\theta \in {\Theta }_{{{{{{{{\mathcal{L}}}}}}}}}}{{{\mbox{MINE}}}}\left(\lambda,\,t{{;}}\,\theta,\,m\right).$$
(29)

To estimate mutual information, we draw \(m\) samples from the joint distribution of \(\lambda\) and \(t\) using Monte Carlo sampling from the dataset and the variational distribution of the latent variable Z:

$$\begin{array}{c}{\lambda }_{m'\cdot }={Z}_{m'}\beta,\,{t}_{m'\cdot }={h}_{\phi }\left({Z}_{m'},\,{C}_{m'}\right),\\ {Z}_{m'} \sim {q}_{\psi }(\cdot \,|\,{X}_{m'},\,{C}_{m'}),\\ {X}_{m'},\,{C}_{m'} \sim {\mathbb{P}}_{X,C},\\ \forall m' \in \{1,\ldots,\,m\}.\end{array}$$
(30)

Putting together the two lower bounds and introducing term weights \({\varepsilon }_{1}\) and \({\varepsilon }_{2}\), which are varied over the course of parameter optimization, we define the differentiable approximation of the CODAL objective function \(\hat{{{{{{{{\mathcal{V}}}}}}}}}\):

$$\hat{{{{{{{{\mathcal{V}}}}}}}}}\left({X}_{i},\,{C}_{i}{{;}}\,{\varepsilon }_{1},\,{\varepsilon }_{2},\,m\right)={E}_{Z \sim {q}_{\psi }(\cdot|{X}_{i},{C}_{i})}\left[{{{\log }}\,p}_{\varphi }\left({X}_{i}|Z,\,{C}_{i}\right)\right]\\ -{\varepsilon }_{1}{D}_{{KL}}({q}_{\psi }\left(Z|{X}_{i},\,{C}_{i}\right){{||}}{p}_{\varphi }(Z))-{\varepsilon }_{2}{{{{{{{\rm{MINE}}}}}}}}\left(\lambda,\,t{{;}}\,\theta,\,m\right).$$
(31)

Intuitively, the \({D}_{{KL}}\) and MINE terms serve to regularize how the model reduces the reconstruction loss term of the ELBO: \({E}_{Z \sim {q}_{\psi }(\cdot|{X}_{i},{C}_{i})}[{{{\log }}\,p}_{\varphi }({X}_{i}|Z,\,{C}_{i})]\). The \(\varepsilon\) coefficients scale the influence of the loss components relative to one another. We anneal \({\varepsilon }_{1}\) and \({\varepsilon }_{2}\) at every step during training according to cyclically increasing and decreasing schedules, which helps to prevent mode collapse and over-regularization. At step \(s\) of training out of a total of training steps \({s}_{{{{{{{{\rm{total}}}}}}}}}\), \({\varepsilon }_{2}\) is set following the cyclic KL annealing schedule63, \(r\), with three cycles:

$$r(s)={{\min }}\left(1,\frac{2\times {{{{{{{\rm{mod}}}}}}}}(s,{s}_{{{{{{{{\rm{total}}}}}}}}}/3)}{{s}_{{{{{{{{\rm{total}}}}}}}}}/3}\right).$$
(32)

For \({\varepsilon }_{1},\) we implement an annealing schedule we call “step-up cyclic” annealing, \({r}_{{{{{{{{\rm{stepup}}}}}}}}}\), which linearly increases the maximum value of \({\varepsilon }_{1}\) at each cycle:

$${r}_{{{{{{{{\rm{stepup}}}}}}}}}(s)=\frac{1}{3}{{{{{{{\rm{ceil}}}}}}}}\left(\frac{s}{{s}_{{{{{{{{\rm{total}}}}}}}}}/3}\right)\times r(s)$$
(33)
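Both schedules, and the default training length given at the end of this section, can be written in a few lines of plain Python. The `n_cycles` argument generalizes the three cycles in Eqs. (32) and (33). The 10,000-cell dataset size is a hypothetical example.

```python
import math


def cyclic(s, s_total, n_cycles=3):
    """Eq. (32): ramps linearly from 0 to 1 over the first half of each cycle, then holds at 1."""
    period = s_total / n_cycles
    return min(1.0, 2.0 * math.fmod(s, period) / period)


def step_up_cyclic(s, s_total, n_cycles=3):
    """Eq. (33): the cyclic schedule with its ceiling raised by 1/n_cycles per cycle."""
    period = s_total / n_cycles
    return math.ceil(s / period) / n_cycles * cyclic(s, s_total, n_cycles)


def total_steps(n_cells, epochs=24, batch_size=128):
    """Default training length: epochs x ceil(N_cells / batch_size)."""
    return epochs * math.ceil(n_cells / batch_size)


s_total = total_steps(10_000)   # 24 x ceil(10000 / 128) = 24 x 79 = 1896 steps
eps1 = [step_up_cyclic(s, s_total) for s in range(s_total)]   # KL term weight
eps2 = [cyclic(s, s_total) for s in range(s_total)]           # MINE term weight
```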

Finally, the CODAL objective across the whole dataset may be estimated using subsampled minibatches. Therefore, we use the AdamW64 optimizer to update the topic model weights by minibatch stochastic gradient ascent. We anneal the learning rate, \({\omega }_{1}\), and momentum, \({\omega }_{2}\), parameters of the optimizer using the one-cycle learning rate policy during training. The minimum and maximum bounds of the learning rate parameter are determined by the learning rate range test before training. The momentum parameter is annealed between a minimum and maximum of 0.85 and 0.95 during training.
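The one-cycle policy can be sketched in plain Python. In this simplified version, the learning rate rises from a lower to an upper bound and falls back, while momentum moves inversely between 0.95 and 0.85. For Adam-type optimizers, "momentum" corresponds to the first-moment coefficient. The cosine shape and the 30% warm-up fraction mirror the defaults of PyTorch's `torch.optim.lr_scheduler.OneCycleLR`, but CODAL's exact phase lengths are not stated here. Unlike PyTorch, this sketch returns to `lr_min` rather than to a much smaller final value. `lr_min` and `lr_max` are hypothetical stand-ins for the learning rate range test output.

```python
import math


def _cos_interp(start, end, frac):
    """Cosine interpolation: returns `start` at frac=0 and `end` at frac=1."""
    return end + (start - end) * (1 + math.cos(math.pi * frac)) / 2


def one_cycle(step, total_steps, lr_min, lr_max,
              mom_min=0.85, mom_max=0.95, pct_up=0.3):
    """Return (learning_rate, momentum) at `step` under a simplified one-cycle policy.

    Phase 1 (first pct_up of training): learning rate rises lr_min -> lr_max
    while momentum falls mom_max -> mom_min. Phase 2 reverses both.
    """
    up = max(1, int(pct_up * total_steps))
    if step < up:
        frac = step / up
        return _cos_interp(lr_min, lr_max, frac), _cos_interp(mom_max, mom_min, frac)
    frac = min(1.0, (step - up) / max(1, total_steps - up))
    return _cos_interp(lr_max, lr_min, frac), _cos_interp(mom_min, mom_max, frac)


# Hypothetical bounds standing in for the output of a learning rate range test.
schedule = [one_cycle(s, 1000, lr_min=1e-3, lr_max=1e-1) for s in range(1000)]
```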

The gradient with respect to the topic model parameters depends on the \({{{{{{{\rm{MINE}}}}}}}}\) mutual information estimator. To ensure this model gives reliable estimates of mutual information, we also optimize its parameters by taking gradient steps with respect to \(\theta\) at each step of training using the Adam65 optimizer (no weight decay) with a learning rate of \({10}^{-4}\). Parameter updates are followed by spectral normalization to ensure 1-Lipschitz continuity of the estimator.

In sum, we iterate cyclic gradient ascent steps using the following procedure:

\(\mathbf{init}\ \varphi,\,\psi,\,\theta,\,{\omega }_{1},\,{\omega }_{2}\)

\(\mathbf{for}\ \text{step}\ s=1,\ldots,\,{s}_{\rm{total}}\ \mathbf{do}:\)

\(\quad {\omega }_{1},\,{\omega }_{2}\leftarrow \text{anneal learning rate and momentum of AdamW optimizer}\)

\(\quad {\varepsilon }_{1},\,{\varepsilon }_{2}\leftarrow {r}_{\rm{stepup}}\left(s\right),\,r\left(s\right)\ (\text{anneal objective term weights})\)

\(\quad \left({X}^{m},\,{C}^{m}\right)\leftarrow \text{randomly sample minibatch of size}\ m\ \text{from dataset}\)

\(\quad {\boldsymbol{g}}_{\hat{\mathcal{V}}}\leftarrow {\nabla }_{\varphi,\psi }\sum_{i=1}^{m}\hat{\mathcal{V}}({X}_{i}^{m},{C}_{i}^{m};{\varepsilon }_{1},{\varepsilon }_{2},m)\ (\text{gradients of topic model weights})\)

\(\quad \varphi,\psi \leftarrow \text{update parameters using}\ {\boldsymbol{g}}_{\hat{\mathcal{V}}}\ \text{and AdamW optimizer}\)

\(\quad {\boldsymbol{g}}_{\rm{MINE}}\leftarrow {\nabla }_{\theta }{\rm{MINE}}\left(\lambda,\,t;\theta,\,m\right)\ (\text{gradients of MINE estimator})\)

\(\quad \theta \leftarrow \text{update parameters using}\ {\boldsymbol{g}}_{\rm{MINE}}\ \text{and Adam optimizer}\)

\(\quad \theta \leftarrow \theta /{\rm{spectralnorm}}(\theta )\ (\text{constrain to 1-Lipschitz})\)

\(\mathbf{end\ for}\)

By default, we use a minibatch size of 128 and train for 24 epochs (full passes over the dataset), or \(24\times {\rm{ceil}}\left(\frac{{N}_{\rm{cells}}}{128}\right)\) total steps. Gradient backpropagation is handled by the PyTorch53 Python package, and we use Pyro’s66 implementation of the ELBO objective. In practice, we estimate mutual information from \(m\) samples once per step and use that estimate in the gradient calculation for each sample in the minibatch. The same mutual information estimate is used to calculate the gradient with respect to the parameters \(\theta\).

CODAL Bayesian hyperparameter optimization

Single-cell sequencing experiments vary in the number of cells assayed and the biological complexity of the samples. Consequently, how well a CODAL topic model fits a given dataset depends on the value of the hyperparameter \({N}_{\rm{topics}}\), which sets the dimensionality of the latent space describing biological variation and determines the representational capacity of the model. An appropriate value for \({N}_{\rm{topics}}\) will adequately capture relevant covariance structures in the data and represent all detectable cell types without overfitting. With respect to the number of topics, the CODAL objective score appears to correlate with the analytical quality of the model: hyperparameters that give higher CODAL objective scores also give better latent representations of the dataset67. The \({p}_{\rm{dropout}}\) hyperparameter controls the amount of regularization, which must be set appropriately to produce a good fit to a given dataset. The amount of regularization required correlates with the number of topics and tends to scale inversely with dataset size.

Bayesian hyperparameter optimization scheme

Given a dataset \(D\), the CODAL objective function \({{{{{{{{\mathcal{V}}}}}}}}}^{{{{{{{{\rm{CODAL}}}}}}}}}\), and parameters \(\varphi_H\) resulting from training this model with hyperparameters \(H\), we aim to find the set of hyperparameters \({H}_{*}\) in search space \({\mathbb{H}}\) which maximizes:

$${H}_{*}={{{\mbox{argmax}}}}_{H{\mathbb{\in }}{\mathbb{H}}}{{{{{{{{\mathcal{V}}}}}}}}}_{{\varphi }_{H}}^{{{{{{{{\rm{CODAL}}}}}}}}}\left(D\right).$$
(34)

A set of hyperparameters \(H\) is defined as a tuple \(({N}_{\rm{topics}}\in {\mathbb{Z}}_{[1,\infty )},\,{p}_{\rm{dropout}}\in {\mathbb{R}}_{[0.05,0.1]})\), whose elements determine the dimensionality of the latent variable and the dropout noise applied to the linear biological decoder model, respectively. The range of possible values for \({N}_{\rm{topics}}\) is set by the user. For example, the optimal number of topics for a typical PBMC dataset ranges from 10 to 20. The models selected for the NeurIPS bone marrow gene expression dataset and the embryo differentiation dataset used 22 and 65 topics, respectively.

To find the hyperparameters that give the best fit on a dataset, we use a Bayesian hyperparameter optimization scheme. Sequentially, a set of hyperparameters is sampled, a model is trained with those parameters, and the model is evaluated on a held-out set of cells. Using the joint distribution of hyperparameters and objective scores from past trials, one predicts scores for untested sets of hyperparameters using a faster-to-compute function. Therefore, one can find and evaluate sets of hyperparameters from subspaces of \({\mathbb{H}}\) which we expect to yield better models, potentially finding the best set of hyperparameters \({H}_{*}\) in fewer steps than by random search.

The CODAL hyperparameter optimization algorithm starts by partitioning the cells in a dataset into training and test portions in a 4:1 ratio, by default, stratified by any covariates supplied by the user. Those partitions are written to disk in fast-loading chunks. For the remainder of the tuning process, data are streamed batch by batch during training to reduce memory overhead. Initially, the tuning algorithm runs 15 startup trials using randomly sampled sets of hyperparameters. In a trial, a CODAL model instantiated with some set of hyperparameters is trained on the training portion of the dataset, and the objective score is evaluated on the testing portion after every epoch. We place checkpoints, or “rungs”, after the 8th and 16th epochs. At each rung, the algorithm compares the current trial’s objective score to those of all previously completed trials that reached that rung. In-progress trials whose scores fall below the 50th percentile of scores collected at that rung are “pruned”, or their training discontinued. This ensures that computational resources are not wasted on models that are unlikely to improve on trials that have already been computed.
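The pruning rule can be sketched as a single check, assuming NumPy and that a higher objective score is better. `should_prune`, the rung bookkeeping dictionary and the toy scores are hypothetical. This is not Optuna's pruner API; `np.percentile` interpolates linearly between recorded scores.

```python
import numpy as np

RUNGS = (8, 16)  # epochs at which in-progress trials are compared


def should_prune(epoch, score, rung_history, percentile=50.0):
    """Return True if a trial's score at a rung falls below the given percentile
    of scores recorded at that rung by earlier trials (higher score is better).

    rung_history: dict mapping rung epoch -> list of earlier trials' scores there.
    """
    previous = rung_history.get(epoch, [])
    if epoch not in RUNGS or len(previous) == 0:
        return False
    return score < np.percentile(previous, percentile)


history = {8: [1.0, 2.0, 3.0, 4.0]}   # toy objective scores recorded at the epoch-8 rung
```

With these toy scores, the 50th percentile at the epoch-8 rung is 2.5. A trial scoring 2.0 there would be pruned, while one scoring 3.0 would continue.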

Next, the tuning algorithm switches to Bayesian hyperparameter selection using the algorithm described in the following section (“Bayesian optimization acquisition function with trial pruning”). We use a Gaussian Process model with a Matern kernel68 to predict the distribution of scores given hyperparameters. The \(\nu\) parameter of the Matern kernel is set to 5/2. Tuning iterates for a minimum of 48 trials, after which tuning stops if no improvement in the objective score is recorded in the 12 most recent trials, or if a maximum of 128 trials is reached. The set of hyperparameters yielding the model with the highest objective score is taken to best represent the dataset, and that model is returned to the user.
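Two pieces of this paragraph can be sketched in plain Python: the Matern kernel with ν = 5/2 and the stopping rule. The kernel uses the standard closed form k(r) = σ²(1 + √5 r/ℓ + 5r²/(3ℓ²)) exp(−√5 r/ℓ). Its `length_scale` and `variance` are hypothetical defaults; in practice, hyperparameters on different scales (N_topics versus p_dropout) are usually rescaled before computing distances. `should_stop` is a hypothetical helper that treats pruned trials, recorded as `None`, as having no final score.

```python
import math


def matern52(x1, x2, length_scale=1.0, variance=1.0):
    """Matern kernel with nu = 5/2 between two hyperparameter vectors."""
    r = math.dist(x1, x2) / length_scale
    return variance * (1 + math.sqrt(5) * r + 5 * r**2 / 3) * math.exp(-math.sqrt(5) * r)


def should_stop(scores, min_trials=48, patience=12, max_trials=128):
    """Stopping rule described above. `scores` holds final objective scores in
    trial order, with None for pruned trials (higher is better)."""
    n = len(scores)
    if n >= max_trials:
        return True
    if n < min_trials:
        return False

    def best(values):
        finite = [v for v in values if v is not None]
        return max(finite) if finite else -math.inf

    # Stop if the most recent `patience` trials never beat the earlier best.
    return best(scores[-patience:]) <= best(scores[:-patience])
```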

In addition to trial pruning and Bayesian hyperparameter selection, the tuning algorithm runs trials in parallel to further speed up tuning. By default, five concurrent trials may be executed using the SQLite database backend as a message broker, but the optional Redis database backend enables parallelization across as many cores as desired. We employ the constant liar mean strategy for Bayesian hyperparameter selection during the parallel evaluation of trials. When selecting the next set of hyperparameters to evaluate, there may be many running trials whose objective scores are not yet known. The constant liar mean strategy assumes that those trials will achieve the mean score of previously completed trials and uses that hypothetical result when selecting the next set of hyperparameters to evaluate.

This tuning scheme was implemented using the Optuna69 python package.
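The constant liar mean strategy amounts to augmenting the surrogate model's training data before each selection. The sketch below, in plain Python, shows that augmentation step only. The helper name, the (N_topics, p_dropout) tuples and the scores are hypothetical toy values.

```python
from statistics import mean


def constant_liar_mean(completed, running):
    """Augment surrogate-model training data for parallel trials.

    completed: list of (hyperparameters, score) for finished trials
    running:   list of hyperparameters for trials still in progress
    Returns the completed observations plus each running trial paired with the
    mean completed score (the "lie").
    """
    lie = mean(score for _, score in completed)
    return completed + [(h, lie) for h in running]


completed = [((10, 0.05), -1200.0), ((20, 0.10), -1100.0), ((15, 0.07), -1150.0)]
running = [(12, 0.06), (25, 0.08)]
observations = constant_liar_mean(completed, running)
```

Because the running trials are assigned an average rather than an optimistic score, the acquisition function is discouraged from proposing hyperparameters right next to trials that are already being evaluated.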

Bayesian optimization acquisition function with trial pruning

Bayesian optimization is a powerful method for quickly finding sets of hyperparameters that improve model performance on some objective. Briefly, by using a faster-to-compute Bayesian approximation of the objective function, one can sequentially evaluate hyperparameter sets that are expected to perform well based on past trials. At each step, the algorithm chooses the set of hyperparameters which maximizes an “acquisition function” for evaluation in the next trial.

Another popular approach for faster hyperparameter tuning is trial “pruning”70. For each set of hyperparameters, a model trains for some number of epochs before the objective score is evaluated. For many problems, the relative performance of models can be assessed even before training is complete. Thus, at checkpoints during training called “rungs”, models are scored using the objective function and the worst \(q\) percent of models are pruned, or their training discontinued. Pruning poorly performing trials ensures that computational resources are not wasted on training models that are unlikely to be competitive.

A tuning algorithm that combines Bayesian optimization with trial pruning could potentially speed up hyperparameter optimization more than either algorithm alone. However, most Bayesian optimization algorithms do not account for the possibility that some trials may not be completed due to pruning. Therefore, we implemented a novel Gaussian Process-based acquisition function that extends the popular “expected improvement” function71 for use in conjunction with trial pruning.
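For reference, the standard expected improvement that this acquisition function extends has a closed form when the surrogate's prediction for a hyperparameter set is Gaussian with mean μ and standard deviation σ. In the maximization setting, EI = (μ − f_best)Φ(z) + σφ(z), with z = (μ − f_best)/σ. The plain-Python sketch below shows this textbook function only, not the pruning-aware variant defined next.

```python
import math


def expected_improvement(mu, sigma, f_best):
    """Closed-form EI for a Gaussian predictive distribution N(mu, sigma^2)
    when maximizing: E[max(f - f_best, 0)]."""
    if sigma <= 0:
        return max(mu - f_best, 0.0)   # deterministic prediction
    z = (mu - f_best) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)   # standard normal density
    cdf = 0.5 * (1 + math.erf(z / math.sqrt(2)))            # standard normal CDF
    return (mu - f_best) * cdf + sigma * pdf
```

EI is never smaller than max(μ − f_best, 0) and grows with σ. It therefore balances exploiting promising regions against exploring uncertain ones.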

Given a history of hyperparameter samples \([H_{1},\ldots,H_{N_{\mathrm{trials}}}]\in\mathbb{R}^{N_{\mathrm{trials}}\times N_{\mathrm{hyperparameters}}}\) and associated objective scores \(f\in\{\mathbb{R},\varnothing\}^{N_{\mathrm{trials}}\times N_{\mathrm{rungs}}}\), where each trial has a score for each rung that it reached (and a null score, \(\varnothing\), for each rung that it did not), we select the set of hyperparameters that maximizes \(\mathrm{EI}_{\mathrm{p}}\), the expected improvement with pruning, to evaluate in the next trial:

$${H}_{{N}_{{{{{{{\rm{trials}}}}}}}}+1}={{{\mbox{argmax}}}}_{H\in {\mathbb{H}}}{{{{{{{\rm{E}}}}}}}}{{{{{{{{\rm{I}}}}}}}}}_{{{{{{{{\rm{p}}}}}}}}}\left(H\right).$$
(35)

Here, “improvement” refers to the improvement of a score \(f\), recorded at rung \(r\), over the best final-rung score recorded so far, \(f_{*}\):

$$\begin{array}{c}\mathrm{Im}(f,\,r)=\left\{\begin{array}{ll}f-f_{*} & \text{if } f>f_{*}\text{ and } r=N_{\mathrm{rungs}}\\ 0 & \text{otherwise}\end{array}\right.,\\ f_{*}=\max\left\{f_{\tau,N_{\mathrm{rungs}}}\,|\,\tau\in\{1,\ldots,N_{\mathrm{trials}}\},\,f_{\tau,N_{\mathrm{rungs}}}\ne\varnothing\right\}.\end{array}$$
(36)

The improvement of a score is zero if it falls below the previous best score or if that score was recorded before the model completed training (reaching the final rung). Expected improvement, then, is the expected value of the improvement function over the joint distribution of scores and rungs for some set of hyperparameters \(H\):

$$\mathrm{EI}_{\mathrm{p}}(H)=\mathbb{E}\left[\mathrm{Im}(f,\,r)\right]=\sum_{r=1}^{N_{\mathrm{rungs}}}p_{r}(r\,|\,H)\int_{-\infty}^{\infty}\mathrm{Im}(f,\,r)\,p_{f}(f\,|\,r,\,H)\,df.$$
(37)

The distribution over scores is estimated by a Gaussian Process approximation of the objective. Using the scikit-learn Python package’s Gaussian Process regressor, with kernel hyperparameters fit by maximum likelihood, we fit a Gaussian Process (GP) to all (score, hyperparameters, rung) tuples from the training history:

$$\left\{\left({f}_{\tau,r},\,{H}_{\tau },\,r\right){{|}}\tau \in \left\{1,\ldots,\,{N}_{{{{{{{{\rm{trials}}}}}}}}}\right\},\,r \in \big\{1,\, \ldots,\,{N}_{{{{{{{{\rm{rungs}}}}}}}}}\big\},\,{f}_{\tau,r}\,\ne \,{{\varnothing }}\right\},$$
(38)

such that the distribution of scores conditioned on the hyperparameters and the rung is:

$$f{{|}}H,\,r \sim {{{\mbox{GP}}}}\left(H,\,r\right).$$
(39)

Because the improvement function, \({{{{{{{\rm{Im}}}}}}}}\), is zero unless the score \(f\) is greater than the previous best objective score and was recorded on the last rung, in other words \(r={N}_{{{{{{{\rm{rungs}}}}}}}}\), we can simplify the \({{{{{{{\rm{E}}}}}}}}{{{{{{{{\rm{I}}}}}}}}}_{{{{{{{{\rm{p}}}}}}}}}\) function:

$${{{{{{{\rm{E}}}}}}}}{{{{{{{{\rm{I}}}}}}}}}_{{{{{{{{\rm{p}}}}}}}}}\left(H\right)={p}_{r}\big(r={N}_{{{{{{{\rm{rungs}}}}}}}}|H\big){\int }_{{f}_{*}}^{{{\infty }}}{{{{{{{\rm{Im}}}}}}}}\big(f,\,{N}_{{{{{{{\rm{rungs}}}}}}}}\big){p}_{f}\big(f|r={N}_{{{{{{{\rm{rungs}}}}}}}},\, H\big){df}.$$
(40)

Next, we reparametrize the distribution for \(f\) in terms of the standard normal random variable \(z\), where the Gaussian Process function, \({{\mbox{gp}}}\), gives the mean and variance of the score distribution:

$$\begin{array}{c}f\,|\,H,\,r\sim\mathrm{GP}(H,\,r)=N\big(\mu_{H,r},\,\sigma_{H,r}^{2}\big),\quad\text{i.e.}\quad f=\sigma_{H,r}z+\mu_{H,r},\\ \mu_{H,r},\,\sigma_{H,r}^{2}=\mathrm{gp}(H,\,r),\\ z\sim N(0,\,1).\end{array}$$
(41)

We can then transform from \(f\)-distributed space to \(z\)-distributed space by standardization:

$${s}_{H,r}\left(f\right)=\frac{f-{\mu }_{H,r}}{{\sigma }_{H,r}}.$$
(42)

Letting \(\mathrm{cdf}_{z}\) and \(\mathrm{pdf}_{z}\) denote the cumulative distribution function and probability density function of the standard normal distribution, respectively, the integral term in the \(\mathrm{EI}_{\mathrm{p}}(H)\) formula has the closed-form solution71:

$$\begin{array}{c}\int_{f_{*}}^{\infty}\mathrm{Im}(f,\,N_{\mathrm{rungs}})\,p_{f}(f\,|\,r=N_{\mathrm{rungs}},\,H)\,df=\left(\mu_{H,N_{\mathrm{rungs}}}-f_{*}\right)\mathrm{cdf}_{z}(z_{*})+\sigma_{H,N_{\mathrm{rungs}}}\,\mathrm{pdf}_{z}(z_{*}),\\ z_{*}=-s_{H,N_{\mathrm{rungs}}}(f_{*})=\frac{\mu_{H,N_{\mathrm{rungs}}}-f_{*}}{\sigma_{H,N_{\mathrm{rungs}}}}.\end{array}$$
(43)
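To make the closed form concrete, the sketch below compares it with a direct numerical integral for a single Gaussian prediction. It assumes only a normal predictive distribution with mean `mu` and standard deviation `sigma`; the standard normal CDF is evaluated at \((\mu-f_*)/\sigma=-s(f_*)\), and the function names are illustrative.

```python
import math


def pdf_z(z):
    """Standard normal probability density function."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def cdf_z(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def expected_improvement(mu, sigma, f_best):
    """Closed-form E[max(f - f_best, 0)] for f ~ N(mu, sigma**2)."""
    z = (mu - f_best) / sigma  # equals -s(f_best)
    return (mu - f_best) * cdf_z(z) + sigma * pdf_z(z)


def expected_improvement_numeric(mu, sigma, f_best, n=200_000, width=12.0):
    """Midpoint-rule integral of (f - f_best) * p(f) from f_best upward.

    The range is truncated at mu + width * sigma, where the Gaussian tail
    is negligible.
    """
    upper = mu + width * sigma
    if upper <= f_best:
        return 0.0
    h = (upper - f_best) / n
    total = 0.0
    for k in range(n):
        f = f_best + (k + 0.5) * h
        total += (f - f_best) * pdf_z((f - mu) / sigma) / sigma
    return total * h
```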

Our contribution is an expression for the probability that a trial terminates at a given rung, given the hyperparameter set \(H\). Typically, pruning algorithms work by setting a threshold score, \(f_{*r}\), that a trial must exceed at each rung. If the threshold is not met at any rung, the trial is terminated. The threshold can be the \(q\)th percentile of previous scores recorded at rung \(r\):

$${f}_{*r}={{{\mbox{percentile}}}}_{q}\left\{{f}_{i,r}{{|}}i\in \left\{1,\ldots,\,{N}_{{{{{{{\rm{trials}}}}}}}}\right\},\,{f}_{i,r}\,\ne \,{{\varnothing }}\right\},\, \forall {{\mbox{r}}}\in \left\{1,\, \ldots,\,{N}_{{{{{{{\rm{rungs}}}}}}}}-1\right\}.$$
(44)

Treating scores at different rungs as conditionally independent given \(H\), the probability that a trial terminates at the last rung (that is, is not pruned at any previous rung) is:

$$\begin{aligned}p_{r}\big(r=N_{\mathrm{rungs}}\,|\,H\big)&=\prod_{r=1}^{N_{\mathrm{rungs}}-1}p_{f}\big(f\ge f_{*r}\,|\,r,\,H\big)\\&=\prod_{r=1}^{N_{\mathrm{rungs}}-1}\Big[1-p_{z}\big(z\le s_{H,r}(f_{*r})\big)\Big]\\&=\prod_{r=1}^{N_{\mathrm{rungs}}-1}\mathrm{cdf}_{z}\big(-s_{H,r}(f_{*r})\big).\end{aligned}$$
(45)
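The survival product can be evaluated directly from per-rung GP predictions. The sketch assumes the predicted mean and standard deviation at each non-final rung and the pruning thresholds are already available as lists; `prob_reach_final_rung` is a hypothetical name.

```python
import math


def cdf_z(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def prob_reach_final_rung(mus, sigmas, thresholds):
    """Probability that a trial survives every pruning rung.

    mus[r], sigmas[r]: predicted score mean and standard deviation at rung r;
    thresholds[r]: pruning threshold f_{*r}. One entry per rung 1..N_rungs-1.
    Scores at different rungs are treated as independent given H.
    """
    if not (len(mus) == len(sigmas) == len(thresholds)):
        raise ValueError("inputs must have one entry per non-final rung")
    p = 1.0
    for mu, sigma, threshold in zip(mus, sigmas, thresholds):
        s = (threshold - mu) / sigma  # standardized threshold s_{H,r}(f_{*r})
        p *= cdf_z(-s)                # P(f >= f_{*r} | r, H)
    return p
```

A trial predicted to sit exactly at the threshold on each of two rungs survives each with probability 0.5, giving 0.25 overall.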

In summary, given a history of trials and associated scores, we execute the next trial using the set of hyperparameters that maximize the expected improvement acquisition function:

$$\begin{array}{c}\mathrm{EI}_{\mathrm{p}}(H)=\prod_{r=1}^{N_{\mathrm{rungs}}-1}\mathrm{cdf}_{z}\big(-s_{H,r}(f_{*r})\big)\Big[\big(\mu_{H,N_{\mathrm{rungs}}}-f_{*}-\xi f_{*}\big)\,\mathrm{cdf}_{z}(z_{*})+\sigma_{H,N_{\mathrm{rungs}}}\,\mathrm{pdf}_{z}(z_{*})\Big],\\ z_{*}=-s_{H,N_{\mathrm{rungs}}}\big((1+\xi)f_{*}\big)=\frac{\mu_{H,N_{\mathrm{rungs}}}-f_{*}-\xi f_{*}}{\sigma_{H,N_{\mathrm{rungs}}}}.\end{array}$$
(46)

Above, we add the \(\xi\) parameter, which we set to \(0.1\) by default. This has the effect of overestimating the previous best score when calculating expected improvement, leading to more exploratory hyperparameter choices. For each trial, we evaluate \(\mathrm{EI}_{\mathrm{p}}\) for 300 randomly sampled sets of hyperparameters and select the set with the highest value.
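Putting the pieces together, the sketch below scores random candidates with the pruning-aware acquisition and keeps the best one. Several assumptions are not from the source: the fitted Gaussian Process is replaced by a hypothetical closed-form surrogate `toy_gp`, the two hyperparameters (`log10_lr`, `n_topics`) and their ranges are illustrative, and the inflated target \((1+\xi)f_*\) is used in both the improvement term and \(z_*\). This follows the description of \(\xi\) as overestimating the previous best score and keeps the acquisition value non-negative up to rounding.

```python
import math
import random


def pdf_z(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def cdf_z(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def toy_gp(H, r):
    """Hypothetical stand-in for the fitted GP: (mean, std) of the score for
    hyperparameters H at rung r. Scores rise with rung and peak at
    log10_lr = -2.5, n_topics = 25."""
    mu = (0.5 + 0.1 * r
          - 0.05 * (H["log10_lr"] + 2.5) ** 2
          - 0.001 * (H["n_topics"] - 25) ** 2)
    return mu, 0.05


def ei_pruned(H, gp, thresholds, f_best, xi=0.1):
    """Expected improvement with pruning; thresholds cover rungs 1..N_rungs-1."""
    n_rungs = len(thresholds) + 1
    p_reach = 1.0
    for r, threshold in enumerate(thresholds, start=1):
        mu, sigma = gp(H, r)
        p_reach *= cdf_z((mu - threshold) / sigma)  # cdf_z(-s_{H,r}(f_{*r}))
    target = (1.0 + xi) * f_best  # inflated previous best score
    mu, sigma = gp(H, n_rungs)
    z = (mu - target) / sigma
    return p_reach * ((mu - target) * cdf_z(z) + sigma * pdf_z(z))


def propose_next(gp, thresholds, f_best, n_candidates=300, seed=0):
    """Maximize the acquisition over randomly sampled candidate sets."""
    rng = random.Random(seed)
    candidates = [
        {"log10_lr": rng.uniform(-4.0, -1.0), "n_topics": rng.randint(15, 40)}
        for _ in range(n_candidates)
    ]
    return max(candidates, key=lambda H: ei_pruned(H, gp, thresholds, f_best))


best_H = propose_next(toy_gp, thresholds=[0.55, 0.65, 0.75], f_best=0.85)
```

Random candidate sampling keeps the maximization trivial to parallelize and avoids gradient-based optimization of a product of CDFs, at the cost of resolution in high-dimensional search spaces.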

CODAL technical effect augmentation of gene regulation models

Because the technical effects estimated by a CODAL model are effectively independent of the estimates of biological quantities, the technical effect predictions are transferable: they may be used to augment other scRNA-seq or scATAC-seq read count generative models that do not adjust for technical effects. For example, if the technical effect vectors learned by CODAL are fixed as an additive and independent component of the generative distribution for counts, the biological quantities \(\lambda\) may be re-estimated using some other set of features to give \(\lambda^{\mathrm{new}}\). The independence assumption of this generative distribution requires \(\lambda^{\mathrm{new}}\perp t\), which we may reasonably expect to hold if \(\lambda\perp t\) and \(\lambda\approx\lambda^{\mathrm{new}}\).

The resulting general generative distribution for the counts \(X_{\cdot j}\) of gene \(j\), with noise distribution \(\mathcal{D}\) and re-estimated expression rates \(\lambda^{\mathrm{new}}\), accounts for technical effect differences between cells. The parameters of this generative distribution may be inferred without mutual information regularization:

$$\begin{array}{c}X_{ij}\sim\mathcal{D}\left(\rho_{ij}^{\mathrm{new}}\right),\\ \rho_{ij}^{\mathrm{new}}=\frac{\exp\left(\lambda_{ij}^{\mathrm{new}}+t_{ij}\right)}{\kappa_{i}},\\ \lambda_{ij}^{\mathrm{new}}=f(\ldots),\\ \forall i\in\{1,\ldots,N_{\mathrm{cells}}\}.\end{array}$$
(47)

Above, \(t_{\cdot j}\) is a fixed (non-trainable) vector of the technical effects on gene \(j\) in each cell, and \(\kappa_{i}\) is a scalar representing the denominator of the softmax transformation, \(\sum_{l=1}^{N_{\mathrm{genes}}}\exp\left(\lambda_{il}^{\mathrm{new}}+t_{il}\right)\). The CODAL generative model estimates this denominator, which can be used to approximate \(\kappa_{i}\):

$${\kappa }_{i}\approx \mathop{\sum }\limits_{l=1}^{{N}_{{{{{{{\rm{genes}}}}}}}}}{{\exp }}({\lambda }_{{il}}+{t}_{{il}}).$$
(48)
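A minimal sketch of the technical-effect-augmented rates for a single cell, assuming \(\lambda^{\mathrm{new}}\) and the fixed technical effects \(t\) are given as per-gene lists; `augmented_rates` and `approx_kappa` are hypothetical helpers. With the exact denominator the rates form a softmax, so adding the same constant to every gene's technical effect leaves them unchanged, whereas the CODAL-based approximation of \(\kappa_i\) makes them sum only approximately to one.

```python
import math


def augmented_rates(lam_new, t, kappa=None):
    """Rates rho_ij = exp(lam_new_ij + t_ij) / kappa_i for one cell i.

    lam_new: re-estimated log-rates per gene; t: fixed (non-trainable)
    technical effects per gene in this cell. If kappa is None, the exact
    softmax denominator is used; otherwise the supplied value is used as-is.
    """
    logits = [l + tt for l, tt in zip(lam_new, t)]
    if kappa is None:
        m = max(logits)  # subtract the max for numerical stability
        kappa = math.exp(m) * sum(math.exp(x - m) for x in logits)
    return [math.exp(x) / kappa for x in logits]


def approx_kappa(lam_codal, t):
    """kappa_i approximated as sum_l exp(lam_il + t_il) from CODAL's estimates."""
    return sum(math.exp(l + tt) for l, tt in zip(lam_codal, t))
```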

In summary, fixing CODAL technical effects as a component of the generative distribution outlined above enables re-estimation of expression rates \({\lambda }_{\cdot j}^{{{{{{{{\rm{new}}}}}}}}}=f(\ldots )\) while still accounting for differences in technical effects between cells. The parameters of the \(f\) function may be inferred without using further mutual information regularization since technical effect confounders have already been disentangled.

CODAL technical effect augmentation of regulatory potential model

We applied technical effect augmentation to the MIRA regulatory potential (RP) model28, which relates changes in local chromatin accessibility to gene expression by estimating upstream and downstream exponential decay rates of apparent regulatory influence. The generative model of observed gene expression counts is the same as that used in the CODAL topic model, except that expression rates \(\lambda_{ij}^{\mathrm{RP}}\) are estimated from the local chromatin accessibility states in cells instead of as a linear function of latent variables. Multimodal scRNA-seq and scATAC-seq measurements from the same cells are needed to learn this cis-regulatory relationship, so this model is subject to batch effects from both modalities.

Below, \(n_{i}\) is the read depth estimated by the CODAL topic model; \(\theta_{j}^{\mathrm{RP}}\in\mathbb{R}_{(0,\infty)}\) is the dispersion parameter of the Negative Binomial noise distribution; \(\mathfrak{D}_{j\eta}\) for \(\eta\in\{\mathrm{U},\mathrm{D},\mathrm{P}\}\) are the genomic interval sets containing the peaks upstream of the TSS (\(\mathrm{U}\)), downstream of the TSS (\(\mathrm{D}\)), and proximal to the TSS (the promoter, \(\mathrm{P}\)); \(\delta\in\mathbb{Z}_{[0,\infty)}^{N_{\mathrm{genes}}\times N_{\mathrm{peaks}}}\) is the genomic distance in kilobases between every gene and every peak; \(a_{j}\in\mathbb{R}_{(0,\infty)}^{3}\) are the upstream, downstream, and promoter effect coefficients; \(\Delta_{j}\in\mathbb{R}_{(1,\infty)}^{3}\) are the upstream, downstream, and promoter decay distances of influence (the promoter has \(\Delta_{j\mathrm{P}}=\infty\)); and \(A\in\mathbb{R}_{[0,1)}^{N_{\mathrm{cells}}\times N_{\mathrm{peaks}}}\) is the compositional accessibility rate of each peak in each cell. To sample from the RP model generative distribution of observed scRNA-seq counts, \(X_{ij}^{\mathrm{RNA}}\), for gene \(j\) in cell \(i\):

$$\begin{array}{c}X_{ij}^{\mathrm{RNA}}\sim\mathrm{NegativeBinomial}\left(n_{i}\rho_{ij}^{\mathrm{RP}},\,\theta_{j}^{\mathrm{RP}}\right),\\ \rho_{ij}^{\mathrm{RP}}=\frac{\exp\left(\lambda_{ij}^{\mathrm{RP}}+t_{ij}\right)}{\kappa_{i}},\\ \lambda_{ij}^{\mathrm{RP}}=\gamma_{j}^{\mathrm{RP}}\left(\frac{c_{ij}-\mu_{c_{\cdot j}}}{\sigma_{c_{\cdot j}}}\right)+b_{j}^{\mathrm{RP}},\\ c_{ij}=\mathrm{RP}\left(\mathfrak{D}_{j\cdot},\,A_{i\cdot},\,a_{j\cdot},\,\delta_{j\cdot},\,\Delta_{j\cdot}\right)=\sum_{\eta\in\{\mathrm{U},\mathrm{D},\mathrm{P}\}}a_{j\eta}\sum_{k\in\mathfrak{D}_{j\eta}}A_{ik}\,2^{-\delta_{jk}/\Delta_{j\eta}},\\ A_{ik}=\bar{\rho}_{ik}^{\mathrm{ATAC}},\\ \forall i\in\{1,\ldots,N_{\mathrm{cells}}\}.\end{array}$$
(49)

In extending this model to account for scRNA-seq technical effects, we add the technical effects \(t_{\cdot j}\) for gene \(j\) only when calculating \(\rho_{ij}^{\mathrm{RP}}\). To adjust for scATAC-seq technical effects, we set \(A_{ik}\) equal to \(\bar{\rho}_{ik}^{\mathrm{ATAC}}\) from the CODAL model, which is the compositional distribution of accessibility without technical effect confounding. The parameters of the technical-effect-augmented model are estimated using the same MAP inference procedure as MIRA.
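The regulatory potential term \(c_{ij}\) for one gene and one cell can be computed directly from its definition. In this sketch, peaks, domains, and distances are plain dictionaries, the standardization uses the population standard deviation across cells (an assumption), and all names and numbers are hypothetical.

```python
import math
from statistics import fmean, pstdev


def rp_score(A_i, a_j, delta_j, Delta_j, domains_j):
    """c_ij = sum_eta a_j[eta] * sum_{k in D_j[eta]} A_ik * 2**(-delta_jk / Delta_j[eta]).

    A_i: peak -> accessibility in cell i; delta_j: peak -> distance (kb) to
    gene j; a_j, Delta_j: coefficient and decay distance per domain
    ("U", "D", "P"); domains_j: domain -> peaks assigned to it for gene j.
    """
    c = 0.0
    for eta, peaks in domains_j.items():
        for k in peaks:
            # An infinite decay distance (the promoter) means no distance penalty.
            if math.isinf(Delta_j[eta]):
                decay = 1.0
            else:
                decay = 2.0 ** (-delta_j[k] / Delta_j[eta])
            c += a_j[eta] * A_i[k] * decay
    return c


def rp_log_rates(c_column, gamma_j, b_j):
    """lambda_ij = gamma_j * standardized(c_ij) + b_j across cells for gene j."""
    mu, sd = fmean(c_column), pstdev(c_column)
    return [gamma_j * (c - mu) / sd + b_j for c in c_column]
```

A peak located exactly one decay distance from the TSS contributes half of its accessibility, which is what the base-2 exponent encodes.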

CODAL probabilistic in silico deletion analysis

We used probabilistic in silico deletion as described in MIRA28, except that we replaced the MIRA regulatory potential generative model with the generative model outlined above for tests with technical effect augmentation. The strength of association between a transcription factor motif and a gene is then the difference in likelihood of the generative parameters given all proximal accessible chromatin versus when accessible sites containing that motif are masked, or in silico deleted. We use a Wilcoxon test to assess whether the association between that motif and a set of genes is enriched relative to background levels of association across all other genes.
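The enrichment step can be illustrated with a self-contained rank-sum test. This is a sketch under assumptions: each gene's association score is taken to be the log-likelihood difference described above, the test is assumed to be a one-sided Wilcoxon rank-sum (Mann-Whitney U) test with a normal approximation, and the scores are hypothetical. MIRA's own implementation may differ in test variant and tie handling.

```python
import math


def rank_sum_pvalue(query, background):
    """One-sided Wilcoxon rank-sum (Mann-Whitney U) p-value for the hypothesis
    that `query` scores tend to exceed `background` scores.

    Uses average ranks for ties and the normal approximation without a tie
    correction to the variance.
    """
    values = sorted([(v, True) for v in query] + [(v, False) for v in background],
                    key=lambda pair: pair[0])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(values):
        j = i
        while j + 1 < len(values) and values[j + 1][0] == values[i][0]:
            j += 1
        average_rank = (i + j) / 2 + 1  # 1-based average rank of the tie block
        for k in range(i, j + 1):
            ranks[k] = average_rank
        i = j + 1
    n1, n2 = len(query), len(background)
    rank_sum = sum(r for r, (_, is_query) in zip(ranks, values) if is_query)
    u = rank_sum - n1 * (n1 + 1) / 2
    z = (u - n1 * n2 / 2) / math.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    return 0.5 * math.erfc(z / math.sqrt(2.0))  # upper tail, P(Z >= z)


# Hypothetical per-gene association scores: log-likelihood with all peaks
# minus log-likelihood with motif-containing peaks masked.
query_scores = [2.1, 1.7, 3.0, 2.4]
background_scores = [0.1, -0.3, 0.5, 0.2, 0.0, 0.4]
p_value = rank_sum_pvalue(query_scores, background_scores)
```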

Analysis of NEURIPS bone marrow dataset

Latent space benchmarking

We preprocessed gene expression data for the NEURIPS bone marrow dataset33 using the standard workflow of the scanpy72 python package. Batches were concatenated together, cells with fewer than 400 counts and genes with fewer than 30 counts were filtered out, cell read depths were normalized using the normalize_total function with target_sum set to 10000, counts were log+1 transformed, and highly variable genes were identified with a minimum dispersion threshold of 0.3, yielding 3500 highly variable genes. Latent spaces for each method were calculated based on the expression of these 3500 genes across all cells. We used the sequencing site and donor attributes of cells as technical covariates. We benchmarked CODAL against the MIRA topic model28, scanpy principal component analysis (PCA), scVI20, scANVI18, scanpy PCA+Harmony19, and scanorama25. For scVI and scANVI, we used the same hyperparameters as the scib21 python package. Harmony and scanorama were tested with default parameters.

For the NEURIPS bone marrow ATAC-seq data, cells with fewer than 400 peaks and peaks found in fewer than 30 cells were filtered out. Latent spaces for each method were calculated based on all remaining peaks. We used the sequencing site and donor attributes of cells as technical covariates, in addition to the FRiP score for the CODAL model. We benchmarked CODAL against the MIRA topic model, PEAKVI73, the scikit-learn74 implementation of latent semantic indexing (LSI)75, and LSI+Harmony. For both GEX and ATAC CODAL topic models, we performed hyperparameter tuning for 32 iterations with the range of possible \(N_{\mathrm{topics}}\) set to 15–40.

UMAPs76 for each latent space were calculated using the umap-learn python package, with the min_dist parameter set to 0.1 and negative_sample_rate set to 3. Silhouette widths77 for each cell were calculated using the scikit-learn74 python package. For the cell type label silhouette, we calculated the silhouette width for each cell with respect to the expert-annotated cell type labels provided with the dataset; a higher silhouette width means a cell is more closely grouped with cells of the same label. For the batch silhouette width, we calculated the silhouette width with respect to the joint label of batch and cell type; in this case, a lower score means a cell is more intermixed with cells from different batches. The average silhouette width was calculated as the mean of silhouette widths across all cells. Both cell type and integration Local Inverse Simpson’s Index (cLISI and iLISI)19 were calculated using the scib package.
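To show exactly what the silhouette width measures, here is a plain-Python sketch with Euclidean distances. It follows the textbook definition and assigns 0 to cells in singleton clusters (a common convention), but it is not the scikit-learn implementation used in the analysis, and the function names are hypothetical. For the batch silhouette described above, the labels passed in would be the joint batch and cell type labels.

```python
import math


def silhouette_widths(X, labels):
    """Per-cell silhouette width with Euclidean distances.

    s_i = (b_i - a_i) / max(a_i, b_i), where a_i is the mean distance to
    other cells with the same label and b_i is the smallest mean distance to
    the cells of any other label.
    """
    widths = []
    for i, x in enumerate(X):
        dists_by_label = {}
        for j, y in enumerate(X):
            if j != i:
                dists_by_label.setdefault(labels[j], []).append(math.dist(x, y))
        same = dists_by_label.pop(labels[i], [])
        if not same or not dists_by_label:
            widths.append(0.0)  # singleton cluster or only one label present
            continue
        a = sum(same) / len(same)
        b = min(sum(d) / len(d) for d in dists_by_label.values())
        denom = max(a, b)
        widths.append(0.0 if denom == 0 else (b - a) / denom)
    return widths


def average_silhouette_width(X, labels):
    w = silhouette_widths(X, labels)
    return sum(w) / len(w)
```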

Entangled gene representations

To assess the influence of disentanglement on expression versus technical effect predictions, we trained a topic model using the same parameters as the CODAL model, but with the weight of the mutual information regularization term of the objective function set to zero.

Regulatory potential analysis

We trained regulatory potential models on the multimodal NEURIPS bone marrow dataset for 2841 genes which were both highly variable and had UCSC TSS annotations78. We used the same MAP parameter inference procedure as MIRA. To assess the effect of CODAL technical effect augmentation on the likelihood of RP models, we split each of the 13 batches in the dataset into training and test sets in a 4:1 ratio. Then, for each batch, we trained RP models without technical effect augmentation on only the training data of that batch. We also trained RP models with and without augmentation on the combined training set across all batches. Finally, we evaluated the likelihood of the RP models trained under each condition on the test set of each batch.
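A per-batch 4:1 split can be sketched with the standard library. The function name, the seed handling, and the rounding rule for the training fraction are assumptions for illustration, not the authors' code.

```python
import random
from collections import defaultdict


def split_by_batch(batch_labels, train_frac=0.8, seed=0):
    """Return {batch: (train_indices, test_indices)} with a train_frac split
    inside each batch (0.8 corresponds to the 4:1 ratio)."""
    rng = random.Random(seed)
    by_batch = defaultdict(list)
    for idx, batch in enumerate(batch_labels):
        by_batch[batch].append(idx)
    splits = {}
    for batch, idxs in by_batch.items():
        idxs = idxs[:]  # copy before shuffling
        rng.shuffle(idxs)
        n_train = round(train_frac * len(idxs))
        splits[batch] = (sorted(idxs[:n_train]), sorted(idxs[n_train:]))
    return splits
```

The combined training set across all batches is then the union of the per-batch training indices, so every test cell stays unseen under both training conditions.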

Probabilistic in silico deletion analysis

We selected the top 200 genes from topic 4 to represent a set of genes upregulated in proerythroblasts in the NEURIPS bone marrow dataset. We evaluated those genes for motif enrichment versus the remaining 2641 highly variable genes using the JASPAR79 2020 vertebrate position weight matrix (PWM) collection and probabilistic in silico deletion28. For each of the 1641 PWMs tested, we compared p-values of enrichment assessed with regulatory potential models trained across all batches versus models trained across all batches and augmented with CODAL technical effects. We used Bonferroni-corrected p-values to assess the significance of enrichment.

Frankencell batch-confounded cell type tests

We performed benchmarking to compare CODAL’s ability to integrate differentiation trajectories with batch-confounded cell types against popular methods. Inspired by the scenario where a wild-type cell atlas suggests future knockout or perturbation experiments, we created a synthetic dataset generation system based on the “Frankencell” python program, in which we varied cell abundances along pre-defined trajectory states to introduce varying levels of batch-confounded biology.

Frankencell benchmark generation

Frankencell generates synthetic differentiation trajectories by mixing reads from individual cells sampled from distinct, well-defined cell populations in real single-cell RNA-seq or ATAC-seq data. By defining a construction plan in which cell state trajectories interpolate between cell types, Frankencell creates continuous cell state transitions that maintain the statistical properties of counts from real data, but for which the ground truth state of every cell is known. Frankencell can also simulate batch effects by mixing reads only from cells of one batch in a multi-batch dataset. In this way, the synthetically mixed cells have counts biased by the same technical effects as the reference sample.

For this test, we constructed synthetic datasets using reads sampled from the NEURIPS bone marrow gene expression dataset. We used the expert-annotated cell types as the pure cell type clusters from which reads were mixed. To simulate batch effects, we constructed two trajectories per dataset: one composed of reads from the “site 3, donor 9” batch and one composed of reads from the “site 4, donor 1” batch.

We defined the construction plan as a tree-structured graph, with the root node composed of reads from the HSC cell type and populations that branch to form “lymphoid” and “monocyte” trajectories. The node tree structure and the mixing weights of each node are shown below:

| NodeID | Cell type | Parent node | \(\pi_{\mathrm{HSC}}\) | \(\pi_{\mathrm{CD16+\,Mono}}\) | \(\pi_{\mathrm{B1\,B}}\) | \(\pi_{\mathrm{NK}}\) | \(\pi_{\mathrm{CD8+\,T}}\) | \(\pi_{\mathrm{pDC}}\) | \(\pi_{\mathrm{cDC2}}\) |
|---|---|---|---|---|---|---|---|---|---|
| root | “HSC” | | 1 | | | | | | |
| 1 | | root | 0.5 | 0.125 | 0.125 | 0.125 | | 0.125 | |
| 2 | | 1 | 0.2+k | 0.4-k/2 | | | | 0.2-k/4 | 0.2-k/4 |
| 4 | “Mono” | 2 | k | 1-2k | | | | k/2 | k/2 |
| 5 | “Dendritic” | 2 | k | k | | | | 0.5-k | 0.5-k |
| 3 | | 1 | 0.2+k | | 0.4-k/2 | 0.4-k/2 | | | |
| 6 | “B-cell” | 3 | k | | 1-2k | k/2 | k/2 | | |
| 7 | “T-cell” | 3 | k | | k | 0.5-k | 0.5-k | | |

Mixing weights for each node sum to one (zero weights are left blank), and the parameter \(k\) governs the baseline similarity among all cell types. The edges between nodes encode the continuous paths that cells follow through the trajectory structure. To create a smooth transition between states, the mixing weights of a cell that has progressed a fraction \(p\) along an edge were calculated by interpolating between the mixing weights of the edge’s start and end nodes, \(\pi_{\mathrm{start}}\) and \(\pi_{\mathrm{end}}\), using a sigmoidal transformation, \(\sigma\):

$$\begin{aligned}\pi_{\mathrm{cell}}&=\pi_{\mathrm{start}}(1-\delta)+\pi_{\mathrm{end}}\,\delta,\\ \delta&=\sigma(2p-1).\end{aligned}$$
(50)
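The interpolation can be sketched as below, assuming \(\sigma\) is the standard logistic function. The source does not state its scale; with the logistic function, \(\delta\) ranges from about 0.27 at \(p=0\) to about 0.73 at \(p=1\). The example uses the node 2 and node 4 weights from the construction plan with \(k=0.05\); the function name is hypothetical.

```python
import math


def logistic(x):
    return 1.0 / (1.0 + math.exp(-x))


def interpolate_weights(pi_start, pi_end, p):
    """Mixing weights for a cell a fraction p along an edge.

    pi_start / pi_end: population -> mixing weight at the edge's start and end
    nodes (missing populations have weight 0). Assumes sigma is the logistic
    function.
    """
    delta = logistic(2.0 * p - 1.0)
    populations = set(pi_start) | set(pi_end)
    return {pop: pi_start.get(pop, 0.0) * (1.0 - delta) + pi_end.get(pop, 0.0) * delta
            for pop in populations}


k = 0.05
node_2 = {"HSC": 0.2 + k, "CD16+ Mono": 0.4 - k / 2, "pDC": 0.2 - k / 4, "cDC2": 0.2 - k / 4}
node_4 = {"HSC": k, "CD16+ Mono": 1 - 2 * k, "pDC": k / 2, "cDC2": k / 2}
midway = interpolate_weights(node_2, node_4, p=0.5)
```

Because the result is a convex combination of two weight vectors that each sum to one, the interpolated weights also sum to one for any \(p\).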

Next, we defined the Markov transition matrix for a cell progressing through the trajectory starting from the root node:

| From node | To node | Transition probability |
|---|---|---|
| Root: “HSC” | 1 | 1 |
| 1 | 2 | 0.4 |
| 1 | 3 | 0.6 |
| 2 | 4: “Mono” | 0.5 |
| 2 | 5: “Dendritic” | 0.5 |
| 3 | 6: “B-cell” | \(1-P_{\text{T-cell}}\) |
| 3 | 7: “T-cell” | \(P_{\text{T-cell}}\) |

The parameter \(P_{\text{T-cell}}\) controls the proportion of cells that transition to the “T-cell” versus “B-cell” terminal states. To generate a synthetic cell, we sampled a path through the trajectory graph starting from the root node and following the Markov transition matrix above until reaching a terminal state. Then, we sampled a progress value from a Beta(0.5, 1) distribution to place the cell on an edge of the cell state tree along that path and calculated the cell’s read mixing weights by sigmoidal interpolation of the node-defined mixing proportions. Finally, to sample reads representing each synthetic cell, we randomly selected one real single cell from each population in the reference dataset and hypergeometrically sampled reads from those cells to fulfill their respective contributions according to the mixing weights. Read depths for each cell were sampled from a LogNormal distribution.
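The sampling procedure can be sketched with the Python standard library. Several details are assumptions for illustration: the progress value is spread evenly across the edges of the sampled path, reference cells are represented as `Counter` objects of gene read counts, and the LogNormal parameters are arbitrary. Drawing reads without replacement from a real cell's reads is what makes the per-cell sampling hypergeometric.

```python
import random
from collections import Counter


def sample_path(rng, p_tcell):
    """Walk the construction-plan Markov chain from the root to a terminal node."""
    transitions = {
        "root": [("1", 1.0)],
        "1": [("2", 0.4), ("3", 0.6)],
        "2": [("4", 0.5), ("5", 0.5)],                # Mono, Dendritic
        "3": [("6", 1.0 - p_tcell), ("7", p_tcell)],  # B-cell, T-cell
    }
    path = ["root"]
    while path[-1] in transitions:
        nodes, probs = zip(*transitions[path[-1]])
        path.append(rng.choices(nodes, weights=probs, k=1)[0])
    return path


def place_on_path(rng, path):
    """Hypothetical mapping of a Beta(0.5, 1) progress value onto the path:
    the progress is spread evenly over the path's edges."""
    n_edges = len(path) - 1
    position = min(rng.betavariate(0.5, 1.0) * n_edges, n_edges - 1e-9)
    edge = int(position)
    return path[edge], path[edge + 1], position - edge  # start, end, fraction p


def mix_reads(rng, weights, reference_cells, depth):
    """Sample reads without replacement from one real cell per population.

    weights: population -> mixing weight; reference_cells: population -> list
    of Counters (gene -> read count) for real cells of that population.
    """
    counts = Counter()
    for population, w in weights.items():
        if w <= 0:
            continue
        donor = rng.choice(reference_cells[population])
        reads = list(donor.elements())
        counts.update(rng.sample(reads, min(len(reads), round(w * depth))))
    return counts


rng = random.Random(0)
path = sample_path(rng, p_tcell=1.0)
start, end, p = place_on_path(rng, path)
depth = round(rng.lognormvariate(7.0, 0.5))  # hypothetical LogNormal parameters
```

In a full generator, `start`, `end`, and `p` would feed the sigmoidal interpolation of node weights, and the resulting weights and `depth` would feed `mix_reads`.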

In summary, each Frankencell dataset was composed of two trajectories of 2000 cells, where each trajectory was composed exclusively of reads from a single batch and generated according to a construction plan with parameters \(k\) and \(P_{\text{T-cell}}\). In this way, the ground truth cell state and batch identity were known for each cell, and we controlled the difficulty of the test by increasing the baseline cell-cell similarity and introducing batch-confounded cell types. We generated batch-confounded cell types by first incrementally increasing \(P_{\text{T-cell}}\) for batch 1, then fixing \(P_{\text{T-cell}}\) at 1 and incrementally decreasing \(P_{\text{T-cell}}\) for batch 2. We repeated this depletion process for \(k\in\{0,\,0.05,\,0.1\}\).

Performance evaluation on the Frankencell benchmark

To evaluate the performance of different batch correction methods on these datasets, we calculated latent spaces using counts of highly variable genes and the batch of origin for each cell. We used the MIRA pseudotime trajectory inference algorithm to infer the trajectory structure of each resulting latent space and compared it to the ground truth using established metrics implemented in the dynverse package. Importantly, the trajectory inference algorithm was batch-unaware, so unintegrated batches would score poorly even if the trajectories within each batch were coherent.

We benchmarked CODAL against PCA+Harmony and scVI. For CODAL, we used standard parameters for tuning, with a topic range of 3–10. We applied the Harmony algorithm to the first 10 principal components calculated using scanpy PCA on log+1 transformed highly variable gene expression counts. For scVI, we used the default parameters of the scib benchmarking package but found that the number of latent dimensions allocated to the model strongly influenced the quality of the latent space. On this dataset, larger latent dimensions improved marginal likelihoods even while giving worse trajectory solutions. Therefore, we trained models with 3 to 6 latent dimensions and evaluated each for trajectory quality. For each test, we compared the other methods to whichever scVI model scored highest.

Using the dynverse48 package, we evaluated integrated trajectories on edge flip, branch F1 score, and pseudotime correlation. Edge flip measures the minimal number of edge additions or subtractions needed to convert the test model’s inferred trajectory graph into the ground truth graph, divided by the total number of edges in both graphs (and normalized so that 1 is a perfect score). Pseudotime correlation measures the correlation of temporal geodesic distances between cells in the test versus ground truth trajectories. Branch F1 score quantifies the closeness of the predicted cell state assignment compared to the ground truth. Finally, we calculated iLISI on the Frankencell datasets using the scib package, excluding the “T-cell” and “B-cell” states which are sometimes batch-confounded. The overall score was calculated as the geometric mean of all metrics for each test.
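The overall score is a geometric mean, so a low value on any single metric pulls it down more strongly than it would pull down an arithmetic mean. A minimal sketch with hypothetical metric values and key names, assuming each metric lies in \((0,1]\):

```python
from statistics import geometric_mean


def overall_score(metrics):
    """Geometric mean of benchmark metrics, each assumed to lie in (0, 1]."""
    return geometric_mean(list(metrics.values()))


scores = {"edge_flip": 0.8, "branch_f1": 0.6, "pseudotime_cor": 0.9, "ilisi": 0.5}
overall = overall_score(scores)
```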

Evaluation of variance of technical effect estimation

To analyze the effect of the mutual information regularizer on the repeatability and quality of CODAL representations, we modulated its weight in the CODAL objective by introducing a multiplier weight \(\rho\):

$${{{{{{{{\mathcal{V}}}}}}}}}_{\varphi,\rho }^{{{{{{{{\rm{CODAL}}}}}}}}}\left({X}_{i},\,{C}_{i}\right)={{\log }}\int {p}_{\varphi }\left({X}_{i}\,{{|}}\,Z,\,{C}_{i}\right){p}_{\varphi }\left(Z\right){dZ}-\rho I\left(\lambda,\,t\right).$$
(51)

When this weight is 0, the CODAL objective reduces to marginal likelihood maximization. For each weight \(\rho\in\{0,\,\frac{1}{2},\,1,\,2,\,4,\,8\}\), we trained ten CODAL models with different initial seeds on the same “completely confounded” Frankencell dataset with difficulty \(k=0.05\). For each model, we collected estimates of biological and technical effects and evaluated the trajectory reconstruction quality using the branch F1 score implemented by dynverse.

Thus, for each weight, we collected data \((\boldsymbol{\lambda}^{\rho},\,\boldsymbol{t}^{\rho})\), where \(\boldsymbol{\lambda}^{\rho}\) and \(\boldsymbol{t}^{\rho}\) are tensors of size \((N_{\mathrm{cells}}\times N_{\mathrm{genes}}\times N_{\mathrm{models}})\). To assess the repeatability of CODAL representations, we used variance decomposition across the ten estimates collected for a given mutual information regularization weight \(\rho\):

$$\mathrm{var}\left(\boldsymbol{\lambda}_{\cdot j\cdot}^{\rho}\right)=\mathbb{E}_{i}\left[\mathbb{E}\left[\left(\boldsymbol{\lambda}_{ij\cdot}^{\rho}-\mathbb{E}\left[\boldsymbol{\lambda}_{ij\cdot}^{\rho}\,|\,\mathrm{cell}=i\right]\right)^{2}\,\middle|\,\mathrm{cell}=i\right]\right]+\mathrm{var}_{i}\left(\mathbb{E}\left[\boldsymbol{\lambda}_{ij\cdot}^{\rho}\,|\,\mathrm{cell}=i\right]\right),$$
(52)

where the first term is the average over cells of the squared difference between each model’s estimate for gene \(j\) in cell \(i\) and the mean estimate across models, and the second term is the variance across cells of that mean estimate. Intuitively, the first term measures how much each model’s estimate varies about the average estimate, so it quantifies repeatability: smaller values indicate more repeatable estimates. We repeated this analysis for each gene and each \(\rho\).
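The decomposition for one gene and one \(\rho\) can be computed with the standard library, assuming the estimates are arranged as one row per cell and one column per model; `decompose_variance` and the example values are hypothetical. With the same number of models for every cell and population variances, the two terms add up exactly to the variance of all entries.

```python
from statistics import fmean, pvariance


def decompose_variance(estimates):
    """Split the variance of one gene's estimates into two terms.

    estimates[i][m] is the estimate for cell i from model m. Returns
    (within, between): `within` is the mean over cells of the across-model
    variance (lower means more repeatable); `between` is the variance over
    cells of the across-model mean estimate.
    """
    within = fmean([pvariance(row) for row in estimates])
    between = pvariance([fmean(row) for row in estimates])
    return within, between


estimates = [[1.0, 1.2, 0.8], [3.0, 3.1, 2.9], [2.0, 2.0, 2.0]]
within, between = decompose_variance(estimates)
```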

Analysis of mouse embryo differentiation and perturbation

The mouse embryonic differentiation dataset8 was preprocessed using the standard scanpy workflow. All batches were concatenated together, cells with fewer than 400 counts and genes with fewer than 30 counts were filtered out, cell read depths were normalized using the normalize_total function with target_sum set to 10000, counts were log+1 transformed, and highly variable genes were identified with a minimum dispersion threshold of 0.5. CODAL and scanpy PCA latent spaces were calculated from the expression of highly variable genes. The CODAL model covariates were defined as the following four categorical variables: the sequencing batch, whether or not the batch was chimeric, whether or not the batch was created by mixing embryos at different points in gastrulation, and the day of development at which the batch was collected. We performed 60 iterations of hyperparameter tuning over a number of topics (\(N_{\mathrm{topics}}\)) ranging from 35 to 80. The final model had 63 topics. For the “confounded batch” experiment, we used the same highly variable genes and range of possible \(N_{\mathrm{topics}}\). UMAPs for each latent space were calculated using the umap-learn python package, with the min_dist parameter set to 0.1 and negative_sample_rate set to 3.

To subcluster chimeric hemato-endothelial progenitor (HE-prog) cells, we performed rough clustering across the whole dataset using the Leiden algorithm and selected a cluster that primarily contained HE-prog cells. Then, we subset again to include only Tom+ Tal1−/− cells and ran the Leiden algorithm to obtain high-resolution subclusters. We matched subclusters to mesodermal cell type populations using the expression of marker genes provided in Pijuan-Sala et al.8 and shared topic latent variable composition.

Computational resources benchmarking

We benchmarked the computational resources required to train CODAL models using publicly available scRNA-seq and scATAC-seq datasets published by 10x Genomics. For both modalities, we simulated datasets of different sizes by downsampling the number of cells, the number of features, or both. For the gene expression model, we trained models on a CPU with each of 4, 8, 16, 32, 64, or 128 topics on datasets with each combination of 1000, 2000, or 4000 features and 2000, 4000, 8000, or 16000 cells. For the chromatin accessibility model, we trained models on an RTX2070 Super GPU with each of 4, 8, 16, 32, 64, or 128 topics on datasets with each combination of 50, 100, or 150 thousand features and 1000, 2000, 4000, or 8000 cells. We tracked the total time elapsed and maximum memory used during training.
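The benchmark grid can be enumerated as below; each tuple is (number of topics, number of features, number of cells), and the variable names are hypothetical. This only lists the configurations and does not reproduce the timing or memory harness.

```python
from itertools import product

TOPICS = [4, 8, 16, 32, 64, 128]

# Gene expression model (CPU): features x cells.
rna_grid = list(product(TOPICS, [1000, 2000, 4000], [2000, 4000, 8000, 16000]))

# Chromatin accessibility model (GPU): features x cells.
atac_grid = list(product(TOPICS, [50_000, 100_000, 150_000], [1000, 2000, 4000, 8000]))
```

Each grid contains 6 × 3 × 4 = 72 training runs.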

Statistics and reproducibility

No data were excluded from the analyses, and the investigators were not blinded to allocation during experiments and outcome assessment.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.