Main

The pre-training (PT)/fine-tuning (FT) learning paradigm (also known as transfer learning) has had a tremendous impact on natural language processing (NLP) and related domains1,2,3. PT/FT methods have produced models capable of providing free-text answers to natural language questions4, predicting properties of proteins from sequences5 and enabling reaction synthesis prediction from simplified molecular-input line-entry system (SMILES) strings6, among other advancements.

In NLP or NLP-derived PT/FT, for a given pre-training data modality \({{{\mathcal{X}}}}\), we are given a dataset \({{{\bf{X}}}}\in {{{{\mathcal{X}}}}}^{{N}_{{{{\rm{PT}}}}}}\) of size \({N}_{{{{\rm{PT}}}}}\in {\mathbb{Z}}\) and pre-train an encoder \({f}_{{{{\bf{\uptheta }}}}}:{{{\mathcal{X}}}}\to {{{\mathcal{Z}}}}\) parametrized by \({{{\bf{\uptheta }}}}\), which maps \({{{\mathcal{X}}}}\) into a latent space \({{{\mathcal{Z}}}}\). This encoder fθ is then transferred for use in various FT tasks (which are not known during PT). We evaluate PT/FT systems via the performance of fθ on said FT tasks.

In this Article, we are concerned primarily with the efficacy of PT/FT for downstream tasks that operate at a per-sample level. For example, in NLP, evaluating the sentiment of a full restaurant review is a per-sample task, in contrast to identifying a named entity token within a sentence, which is an intra-sample, per-token task. One aspect of PT that drives such eventual FT performance is the induced geometry of the pre-trained, per-sample latent space \({{{\mathcal{Z}}}}\) (formally defined in Methods). For example, it is well documented that the sentence embeddings produced by pre-trained language models in NLP can be non-smooth and anisotropic, which harms downstream task performance7 (note that our use of the term language model refers to methods designed to produce embeddings or enable FT off of pre-trained language models, not to autoregressive language models for generation). In other domains, such as biomedical modalities, where per-sample tasks are even more prevalent relative to intra-sample tasks than in NLP, the importance of this geometry only increases. Despite this importance, research into mechanisms to induce explicit, deep structural constraints in \({{{\mathcal{Z}}}}\) is limited. For example, many methods ignore the geometry of \({{{\mathcal{Z}}}}\) by imposing no PT loss over the whole-sample embeddings3,8,9. Other methods impose either only shallow constraints, such as through an auxiliary classification PT objective1,10,11, or deeper structural constraints, but in an implicit manner, such as through data augmentation-based12,13,14,15,16,17 or noising-based18,19 contrastive losses. While such methods can be powerful and have been successful in many areas, we argue that the lack of a clear framework to design PT methods that impose structural constraints on \({{{\mathcal{Z}}}}\) that are simultaneously explicit (similar to supervised classification losses) and deep (similar to noising-based and augmentation-based contrastive losses) is a substantial weakness.

On the basis of this observation, we develop a framework under which the PT objective is subdivided into two components: first, a language model imputation or denoising objective that leverages intra-sample relationships, and second, a loss term designed to regularize the geometry of the per-sample latent space \({{{\mathcal{Z}}}}\) to reflect the connectivity patterns of a user-specified graph GPT. By relying on graphs to capture the structure we wish to induce in \({{{\mathcal{Z}}}}\), this framework allows us to specify PT methods that induce deep structure in an explicit manner, filling exactly the gap identified above. In addition, this paradigm can capture diverse relationships, such as those motivated by external knowledge (for example, ref. 20), self-supervised constraints (for example, refs. 21,22) or distances between samples in an alternative modality (for example, ref. 23). Moreover, this PT framework is simultaneously specific enough to allow us to make theoretical guarantees about how different PT graphs impact FT performance, general enough to encompass a variety of PT methods and sufficiently expressive to motivate new PT methods that have not been previously studied. In addition to theoretical analysis, we demonstrate empirically that defining new methods according to our framework, using explicit forms of real-world structure, yields significant benefits over competitive PT baselines across three modalities and ten FT tasks.

Our work advances PT/FT research through three contributions. First, through a comprehensive review and detailed commentary, we show that existing PT methods do not induce structural constraints over \({{{\mathcal{Z}}}}\) that are simultaneously deep and explicit. Second, we establish a framework for describing PT methods, which provides a mechanism to design PT methods that explicitly induce deep structural constraints in \({{{\mathcal{Z}}}}\) via a user-specified PT graph GPT. We further support this framework with theoretical results quantifying how the graph’s structure relates to FT task performance. Crucially, this formalization in our new PT paradigm offers insight into when PT does or does not add value over supervised learning alone. Third, we show that structure-inducing PT methods designed through our framework perform at or above the level of existing PT methods across three data modalities and ten FT tasks.

Results

General PT problem formulation

Given a dataset \({{{{\bf{X}}}}}_{{{{\rm{PT}}}}}\in {{{{\mathcal{X}}}}}^{{N}_{{{{\rm{PT}}}}}}\), a PT method aims to learn an encoder \({f}_{{{{\bf{\uptheta }}}}}:{{{\mathcal{X}}}}\to {{{\mathcal{Z}}}}\) such that fθ can be transferred to FT tasks that are unknown at PT time. While we can leverage additional information at PT time to inform the training of fθ (for example, PT-specific labels YPT), the encoder fθ must take only samples from \({{{\mathcal{X}}}}\) as inputs so that it can be used for FT. PT methods typically solve this problem by training fθ to minimize a PT loss \({{{{\mathcal{L}}}}}_{{{{\rm{PT}}}}}\) over XPT. For example, in the model Bidirectional Encoder Representations from Transformers (BERT), \({{{\mathcal{X}}}}\) consists of free-text samples, fθ is a transformer model and \({{{{\mathcal{L}}}}}_{{{{\rm{PT}}}}}\) consists of both a masked language modelling per-token loss and the next-sentence-prediction (NSP) per-sample loss1.
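To make this formulation concrete, the following minimal PyTorch sketch illustrates the interface it implies: an encoder fθ that consumes only samples from \({{{\mathcal{X}}}}\) and is transferred, unchanged, to an FT task via a fresh task head. The `Encoder` class, its dimensions and the task head are illustrative assumptions, not the architectures used in this work.

```python
import torch
import torch.nn as nn

# Illustrative toy encoder f_theta : X -> Z (not the architecture used in
# this work): it consumes only raw samples (token ids), so the same
# encoder can later be reused for any FT task over X.
class Encoder(nn.Module):
    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
        self.body = nn.TransformerEncoder(layer, num_layers=2)

    def forward(self, x):                  # x: (batch, seq_len) token ids
        h = self.body(self.embed(x))       # (batch, seq_len, dim)
        return h.mean(dim=1)               # per-sample embedding in Z

encoder = Encoder()
# ... pre-train `encoder` by minimizing some L_PT over X_PT ...
# At FT time only the encoder transfers; a fresh task head is attached:
task_head = nn.Linear(32, 5)               # e.g. a hypothetical 5-class FT task
logits = task_head(encoder(torch.randint(0, 100, (8, 16))))
```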

Our definition of PT ignores secondary applications of the PT objective; for example, autoregressive language models (for example, Generative Pre-trained Transformer (GPT)-3 (ref. 3)) are often used directly for generation and less commonly to acquire embeddings or for transfer learning. Accordingly, we are primarily interested in PT methods derived from NLP PT methods. This area is of particular interest because such methods have been successful within NLP1,3,24, have motivated a large number of derived methods in non-language, biomedical modalities25,26,27,28 and are not yet fully technically understood7,29,30.

Defining explicit and deep structural constraints

Central to our hypothesis is the claim that most NLP-derived PT methods today do not impose explicit, deep constraints on the (per-sample) latent space geometry of \({{{\mathcal{Z}}}}\). To justify this claim, we define explicit and deep structural constraints through the following definitions.

Definition 1 explicit versus implicit structural constraints

A PT objective \({{{{\mathcal{L}}}}}_{{{{\rm{PT}}}}}\) imposes a structural constraint that is explicit (versus implicit) to the degree that it (as fθ approaches optimality) permits us to reason directly about the relationship (in particular, the distance) between any two samples zi and zj in the latent space \({{{\mathcal{Z}}}}\), where subscripts i and j are merely used to differentiate between these two samples in \({{\mathcal Z}}\).

Definition 2 deep versus shallow structural constraints

A PT objective \({{{{\mathcal{L}}}}}_{{{{\rm{PT}}}}}\) imposes a structural constraint that is deep (versus shallow) based on how much information (for example, how many dimensions) would be required to fully satisfy the constraint.

For example, consider a classification PT loss with labels in the set \({{\mathcal Y}}\), with sample i having label \({y}_{i}\in {{{\mathcal{Y}}}}\), and using a logit layer that maps the induced representation of sample i to a predicted score: \({{{{\mathbf{z}}}}}_{i}\mapsto \tilde{{y}_{i}}\). This method produces an explicit structural constraint because, near optimality, we can infer that the relative (cosine) distance between two samples zi and zj is small if and only if yi = yj. However, this constraint is also shallow because to fully satisfy this constraint, we need only embed each class \(c\in {{{\mathcal{Y}}}}\) with a unique position \({{{{\bf{p}}}}}_{c}\in {{{\mathcal{Z}}}}\), then compress all samples zi near their class prototype \({{{{\bf{p}}}}}_{{y}_{i}}\). Moreover, this distance-based constraint can be accomplished in a very-low-dimensional space \({{{\mathcal{Z}}}}\) (for example, we can distribute each pc uniformly about a two-dimensional unit circle, then compress all zi to appear at a minimal cosine distance from their class prototypes), illustrating that this constraint is very shallow.
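The following toy numpy construction (our own, purely for illustration) makes this shallowness concrete: a classification constraint over ten classes can be fully satisfied by placing one prototype per class on a two-dimensional unit circle and collapsing every sample onto its class prototype.

```python
import numpy as np

# Toy demonstration: a classification-style constraint is fully satisfiable
# in two dimensions by placing one prototype per class on the unit circle
# and collapsing samples onto it.
n_classes = 10
angles = 2 * np.pi * np.arange(n_classes) / n_classes
prototypes = np.stack([np.cos(angles), np.sin(angles)], axis=1)   # (10, 2)

rng = np.random.default_rng(0)
labels = rng.integers(0, n_classes, size=1000)
z = prototypes[labels]        # every sample sits exactly at its class prototype

def cosine_dist(a, b):
    return 1 - (a * b).sum() / (np.linalg.norm(a) * np.linalg.norm(b))

# Near optimality, the distance is ~0 iff the labels agree (an explicit
# constraint), yet two dimensions suffice, which is what makes it shallow.
i, j = 3, 7
print(labels[i] == labels[j], cosine_dist(z[i], z[j]))
```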

In contrast, consider a contrastive method that asserts that zi = fθ(xi) should be close to \({{{{\bf{z}}}}}_{i}^{{\prime} }={f}_{{{{\bf{\uptheta }}}}}(\widetilde{{{{{\bf{x}}}}}_{i}})\), where \(\widetilde{{{{{\bf{x}}}}}_{i}}\) is a perturbed version of \({{{{\bf{x}}}}}_{i}\) under some noising or augmentation procedure \({{{{\bf{x}}}}}_{i}\mapsto \widetilde{{{{{\bf{x}}}}}_{i}}\), but simultaneously far from other samples zj. While this method constrains the latent space to be smooth with respect to the noising process, it offers only an implicit constraint on \({{{\mathcal{Z}}}}\) as it is generally not possible to infer how the distance between distinct samples zi and zj is constrained. However, it imposes a deeper constraint than the classification objective because the implicit connections between samples induced by the noising procedure reflect relationships that cannot necessarily be captured in a low-dimensional space (dependent on dataset size and density).
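For concreteness, the sketch below shows one generic form such a noising-based contrastive objective could take; this InfoNCE-style loss is our own illustrative stand-in, not a reproduction of any particular cited method.

```python
import torch
import torch.nn.functional as F

def noising_contrastive_loss(f, x, noise_fn, temperature=0.1):
    """Generic InfoNCE-style objective (illustrative stand-in): z_i is
    pulled towards the embedding of its perturbed view and pushed away
    from other in-batch samples."""
    z = F.normalize(f(x), dim=-1)                    # (batch, dim)
    z_pos = F.normalize(f(noise_fn(x)), dim=-1)      # embeddings of noised views
    logits = z @ z_pos.t() / temperature             # (batch, batch) similarities
    targets = torch.arange(z.size(0), device=z.device)  # positives on the diagonal
    return F.cross_entropy(logits, targets)
```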

Existing PT method constraints

To show that existing methods broadly do not provide means to impose structural constraints that are simultaneously deep and explicit, we survey over 90 existing PT methods based on how their objective functions constrain \({{{\mathcal{Z}}}}\) (Extended Data Fig. 1 and Supplementary Information). For full details on our review findings, see Methods. Throughout all examined methods, we find that deep, explicit structural constraints are rarely employed. Instead, most methods either (1) impose no per-sample PT objectives at all (for example, text-generation models, which are often not used for embeddings at all but rather for prompting or generative applications3,8,9,31), (2) use explicit, but shallow, supervised PT objectives (for example, BERT’s NSP objective, A Lite BERT’s (ALBERT’s) sentence-order prediction (SOP) objective or various multi-task objectives1,10,11), or (3) use implicit, but deep, unsupervised or self-supervised contrastive PT objectives (for example, contrastive sentence embedding losses12,13,18,19,32 or other noising-based or augmentation-based approaches14,15,16,17).

Across all surveyed methods, we find that only four methods impose simultaneously explicit and deep constraints: Knowledge Embedding and Pre-trained LanguagE Representation (KEPLER)33, Contrastive Knowledge-aware GNN (CK-GNN)23, XLM-K34 and WebFormer35. All four can be described as some form of per-sample graph alignment, in which an external, PT knowledge graph GPT or connectivity algorithm is employed over a subset of PT samples, and the output embeddings of pairs of samples zi = fθ(xi) and zj = fθ(xj) are constrained to reflect their relationships in the PT graph. This form of constraint is explicit, as the graph GPT contains explicit relationships that will be induced in the output latent space, but also deep, as the geometry of the graph GPT can be arbitrarily complex.

However, all these methods have major limitations. In KEPLER and XLM-K, the per-sample embeddings are only constrained to a restricted set of samples corresponding to entity descriptions from a knowledge graph. As such, no constraints are implied on the general domain free-text samples in \({{{\mathcal{X}}}}\) alone33,34. In CK-GNN, the graph connectivity is derived from a cluster-restricted one-nearest-neighbour graph in an alternative modality’s distance space, which may offer only limited higher-order structure. Unlike the NLP approaches, this method has no intra-sample (for example, per-token) PT task23. Finally, in WebFormer, the graph used is inferred from the structure of the HyperText Markup Language (HTML) underlying web pages, and relationships are only constrained at the per-sample level for limited structural relationships within the HTML. Furthermore, WebFormer is a specialized model specifically for processing web content (text and HTML elements), so this approach cannot be directly generalized to other domains35. Moreover, these methods are each explored only in the particular contexts of their models. They offer no general framework for realizing these deep, explicit per-sample constraints in other contexts and do not explore any theory on how these constraints relate to performance for FT tasks23,33,34,35.

Overall, our review of PT methods establishes unequivocally that PT methods capable of providing explicit, deep structural constraints are significantly under-explored. Across all the methods we reviewed, only four leverage constraints that are both explicit and deep, all of which have significant limitations, and there is no consensus on how to constrain \({{{\mathcal{Z}}}}\) explicitly and deeply. These findings motivate our framework, which offers insight into realizing deep, explicit structural constraints in PT models across diverse contexts and provides theoretical guidance on how structural constraints relate to FT performance. As we show in our results, inducing deep, explicit constraints through our framework yields significant benefits over existing PT methodologies across three diverse biomedical domains.

Structure-inducing PT

Our PT problem framework includes two small but important differences from the standard formulation (Fig. 1).

Fig. 1: Our PT framework.
figure 1

We re-cast the PT formulation by taking a PT graph GPT as an auxiliary input. GPT is used to define a structure-inducing objective \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\), which pushes a PT encoder fθ to embed samples such that samples are close in the latent space if and only if they are linked in GPT.

First, we assume that we have as an additional input to the PT problem a graph GPT = (V, E), where vertices (V) denote PT samples within XPT (that is, each sample \({{{{\bf{x}}}}}_{{{{\rm{PT}}}}}\in {{{{\bf{X}}}}}_{{{{\rm{PT}}}}}\) corresponds to a vertex in V) and edges (E) represent user-specified relationships. Notably, while we take the graph GPT as input to the PT problem, we cannot use it as a direct input to fθ. Just as in traditional PT, fθ must take as input only samples from \({{{\mathcal{X}}}}\); otherwise, we could not apply fθ to the same general class of FT tasks over domain \({{{\mathcal{X}}}}\).
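Concretely, GPT can be materialized as a plain adjacency structure over PT sample indices, populated from whatever user-specified relationships are available (citation links, documented interactions and so on). The helper below is an illustrative sketch of this construction, not our released implementation.

```python
# Illustrative sketch: G_PT as a plain adjacency map over PT sample
# indices, populated from user-specified relationships such as citation
# links or documented protein interactions.
def build_pt_graph(n_samples, relationships):
    adj = {i: set() for i in range(n_samples)}
    for i, j in relationships:       # treat edges as undirected
        adj[i].add(j)
        adj[j].add(i)
    return adj

# G_PT informs the PT loss only; f_theta itself never consumes the graph,
# so the encoder remains applicable to FT tasks over X alone.
g_pt = build_pt_graph(5, [(0, 1), (1, 2), (3, 4)])
```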

Second, we decompose the PT loss \({{{{\mathcal{L}}}}}_{{{{\rm{PT}}}}}\) into two components, weighted with hyperparameter 0 ≤ λSI ≤ 1:

$${{{{\mathcal{L}}}}}_{{{{\rm{PT}}}}}=(1-{\lambda }_{{{{\rm{SI}}}}}){{{{\mathcal{L}}}}}_{{{{\rm{M}}}}}+{\lambda }_{{{{\rm{SI}}}}}{{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}.$$

\({{{{\mathcal{L}}}}}_{{{{\rm{M}}}}}\) is a traditional, intra-sample objective (for example, a language model), and \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) is a new, structure-inducing objective designed to regularize the per-sample latent space geometry according to the relationships (edges) in GPT. Under our framework, \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) is only admissible for a given GPT, fθ and \({{{\mathcal{Z}}}}\) if it permits some stable optima at which a radius nearest-neighbour connectivity algorithm under some distance function in \({{{\mathcal{Z}}}}\) will recover GPT (the formal constraint is in Methods). Note that this constraint strikes a connection between our framework and the wealth of existing research focused on graph representation learning36,37,38,39,40,41. These techniques offer valuable insights into how to sample minibatches over graph-structured data and how to devise losses for graph embeddings. However, many methods for directly modelling graph-structured data, including deep attributed graph embeddings and graph convolutional neural networks, are not replacements for our approach: they typically require the graph to be known at inference time and thus cannot be used in our PT setting, where fθ must take inputs from \({{{\mathcal{X}}}}\) alone.
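A minimal sketch of one PT step under this decomposition is given below; `lm_loss_fn` and `si_loss_fn` are hypothetical stand-ins for the intra-sample and structure-inducing objectives, respectively.

```python
def pt_step(encoder, batch, graph_batch, lm_loss_fn, si_loss_fn, lambda_si=0.1):
    """One PT step under the decomposed objective
    L_PT = (1 - lambda_SI) * L_M + lambda_SI * L_SI.
    `lm_loss_fn` and `si_loss_fn` are hypothetical stand-ins for an
    intra-sample (e.g. masked language modelling) loss and a
    structure-inducing loss over in-batch G_PT edges, respectively."""
    loss_m = lm_loss_fn(encoder, batch)                 # intra-sample objective
    loss_si = si_loss_fn(encoder, batch, graph_batch)   # per-sample, graph-driven
    return (1 - lambda_si) * loss_m + lambda_si * loss_si
```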

As the added loss term \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) is explicitly designed to induce the structure of GPT in \({{{\mathcal{Z}}}}\), we call methods (in particular methods leveraging deep, explicit structural constraints) trained under our framework structure-inducing pre-training (SIPT) methods. Many existing PT approaches can be re-realized as SIPT methods, including classification-based PT objectives such as NSP or SOP, contrastive methods, or existing graph alignment methods (Methods). Although SIPT is designed to make it easier to induce deep, explicit structural constraints, it is also flexible enough to capture implicit or shallow structural constraints.

Theoretical analyses

Under our framework, one can link the structure of the PT graph GPT to eventual FT task performance. In particular, as an SIPT embedder f over graph GPT approaches optimality under the loss \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\), it produces an embedding space such that nearest-neighbour performance for any downstream task is lower bounded by the performance that could be obtained via the nearest-neighbour algorithm over graph GPT (Theorem 1). This fact directly connects the geometry of the graph GPT with the eventual FT performance of an SIPT embedder f. Furthermore, it demonstrates the advantage of employing an explicit constraint rather than an implicit one; by controlling the structure of GPT, users can directly choose to add different inductive biases to the PT process in a manner that has a provable impact on the eventual suitability for downstream FT tasks.

Theorem 1

Let XPT be a PT dataset, let GPT be a PT graph and let \({f}_{{{{{\bf{\uptheta }}}}}^{* }}\) be an encoder pre-trained under a PT objective permissible under our framing that realizes an \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) value no more than \({\epsilon }^{* }\). Then, under the embedder \({f}_{{{{{\bf{\uptheta }}}}}^{* }}\), the nearest-neighbour accuracy for an FT task y converges as dataset size increases to at least the local consistency (Supplementary Definition 3) of y over GPT.

We establish two corollaries of Theorem 1 that illustrate the importance of choosing graphs GPT that impose deep structural constraints.

Corollary 1

Let \({{{{\bf{X}}}}}_{{{{\rm{PT}}}}}\in {{{{\mathcal{X}}}}}^{N}\) be a PT dataset with corresponding labels \({{{\bf{y}}}}\in {{{{\mathcal{Y}}}}}_{{{{\rm{PT}}}}}^{N}\). Define GPT = (XPT, E) such that \(({{{{\bf{x}}}}}_{i},{{{{\bf{x}}}}}_{j})\in E\) if and only if yi = yj.

Then, the local consistency for a given FT task y(FT) over GPT (and thus by Theorem 1, the nearest-neighbour accuracy for any optimized SIPT embedder) is upper bounded by the probability that a sample xi’s FT label \({y}_{i}^{({{{\rm{FT}}}})}\) agrees with the majority class label for task y(FT) over the clique consisting of all nodes with the same PT label yi as xi.

Corollary 2

Let XPT be a PT dataset that can be realized over a valid manifold \({{{\mathcal{M}}}}\). Assume XPT is sampled with full support over \({{{\mathcal{M}}}}\). Let GPT = (XPT, E) be an r-nearest-neighbour graph over \({{{\mathcal{M}}}}\) (that is, \(({{{{\bf{x}}}}}_{i},{{{{\bf{x}}}}}_{j})\in E\) if and only if the geodesic distance between the two points on \({{{\mathcal{M}}}}\) is less than r: \({{{{\mathcal{D}}}}}_{{{{\mathcal{M}}}}}({{{{\bf{x}}}}}_{i},{{{{\bf{x}}}}}_{j}) < r\)). Let y(FT) be an FT classification task that is almost everywhere smooth on the manifold.

Then, as the PT dataset size (and thus the size of GPT) tends to ∞, and r tends to zero, the local consistency of y(FT) over GPT (and thus by Theorem 1 the nearest-neighbour accuracy of an SIPT embedder) will likewise tend to one.

Informally, these corollaries establish that when a shallow structural constraint is used (for example, a supervised classification objective), then the associated SIPT-equivalent model permits only minimal guarantees for FT performance, driven by the extent to which an FT task label is consistent within the classes under the supervised PT objective. In contrast, if a deep structural constraint is used, realized in Corollary 2 via GPT being a nearest-neighbour graph over an arbitrary manifold \({{{\mathcal{M}}}}\), then an SIPT model permits a theoretical guarantee for FT performance that approaches unity as the PT dataset size grows for any FT task that is smooth over \({{{\mathcal{M}}}}\).
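As a rough empirical companion to these results, the sketch below estimates a simplified proxy for local consistency, assuming (for illustration only; Supplementary Definition 3 is not reproduced here) that it can be approximated by the fraction of nodes whose FT label matches the majority label among their GPT neighbours.

```python
from collections import Counter

def local_consistency_proxy(adj, labels):
    """Rough illustrative proxy only: the fraction of nodes whose FT label
    matches the majority label among their G_PT neighbours. The formal
    quantity is given by Supplementary Definition 3, not reproduced here."""
    hits, total = 0, 0
    for node, neighbours in adj.items():
        if not neighbours:
            continue
        majority = Counter(labels[n] for n in neighbours).most_common(1)[0][0]
        hits += int(labels[node] == majority)
        total += 1
    return hits / total if total else 0.0
```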

This theoretical analysis shows that we can directly connect the structure induced in \({{{\mathcal{Z}}}}\) to downstream FT performance. As such, new PT methods that leverage graphs GPT with deeper structural constraints can markedly improve performance, as we will demonstrate on real-world datasets in our experiments. Complete proofs for all theoretical results and semi-synthetic experiments validating our theoretical findings in practice are in Methods.

Datasets and tasks

We examine three data modalities for our experiments: ‘Proteins’, containing protein sequences; ‘Abstracts’, containing free-text biomedical abstracts; and ‘Networks’, containing subgraphs of protein–protein interaction (PPI) networks.

In each data modality, we use different PT datasets and leverage different kinds of PT graphs GPT, test on publicly available benchmarks for FT tasks and compare our SIPT methods with compelling baselines spanning both per-sample and per-token methods (Tables 1–3). Further details on these aspects are in Methods.

Table 1 A summary of our datasets, tasks and benchmarks
Table 2 Mean (± standard deviation) relative reduction of error (defined as ([baseline error] − [SIPT model error])/[baseline error]) of models trained under our framework versus published per-token or per-sample baselines
Table 3 FT tasks

\({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) and training procedures

As discussed in the definition of our framework, an SIPT method differs from a standard PT method by (1) the choice of graph GPT (Table 1) and (2) the design of the structure-inducing loss \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\). To define \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) in our experiments, we leverage ideas from structure-preserving metric learning42,43,44. Structure-preserving metric learning is a form of metric learning where positive relationships are defined by edges in a graph rather than a shared supervised label. We adapt two losses, a traditional contrastive loss45 and a multi-similarity loss46, from supervised metric learning to the graph-based, structure-preserving context of \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) terms in SIPT.

In addition to these losses, in the Abstracts and Proteins domains, we use a warm-start procedure to initialize PT from existing language models rather than beginning from scratch. This saves significant computational time and allows for a powerful ablation study to isolate performance improvements to the introduction of our \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) term. We also perform extensive hyperparameter tuning studies on these two domains to identify appropriate values for λSI and adapt those findings to the Networks domain. Further details about the experimental set-up, including formal statements of our contrastive and multi-similarity losses, are in Methods. Note that, as is standard in PT applications, for each PT algorithm and data modality, we pre-train a single model on the PT dataset, then fine-tune that one pre-trained model on each FT task independently; in other words, in no setting do we need to pre-train a separate model per FT task.

SIPT matches or outperforms all baselines

To analyse our experiments, we compute the relative reduction of error of the best-performing SIPT model versus the per-token or per-sample baselines across all FT tasks (Table 2). In 10 out of 15 cases, SIPT improves over existing methods; in no case does it do worse than either baseline. In some cases, the gains in performance are significant, with improvements of approximately 17% (0.05 macro-F1 raw change) on ACL-ARC (AA), 6% on SciERC relation extraction (SRE) (0.01 macro-F1 absolute change) and 4% on remote homology (RH; 2% absolute accuracy change). SIPT models further establish a new state-of-the-art performance on AA and RH and match state-of-the-art performance on fluorescence (FL), stability (ST) and paper field (PF). See Table 3 and Supplementary Information for details on these tasks, and recall that the F1 metric is the harmonic mean of precision and recall.

Figure 2 shows how performance evolves over FT iterations for the Networks dataset to determine whether the improvements observed at the final converged values are present throughout training. We see that SIPT methods converge faster to better performance than both baselines. Raw results across all settings are presented in Extended Data Tables 3 and 4.

Fig. 2: FT performance over Networks.
figure 2

Mean ± standard deviation FT AUROC as a function of FT iteration for the Networks dataset. Differences in variance scale result from different runs triggering early stop at different iterations. The SIPT method converges faster and performs better than intra-sample (masked node modelling) or per-sample (multi-task classification) PT. MT-PT indicates using traditional, supervised, multi-task pre-training alone. Mask-PT represents performing mask-imputation pre-training alone, whereas SIPT indicates the combination of the two approaches through our SIPT framework.

SIPT performance gains are robust

SIPT performance gains persist across all three data modalities and all different GPT types. This shows that explicitly regularizing the per-sample latent space geometry offers value across NLP, non-language sequences and non-sequential domains. Furthermore, leveraging graphs defined by external knowledge, by self-supervised signals in the data directly, and by nearest-neighbour methods over multi-task label spaces is beneficial. Finally, these improvements hold both against standard language modelling approaches and against existing methods that impose per-sample PT objectives, including single- and multi-task classification objectives.

Gains are attributable to SIPT loss \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\)

As outlined in Methods, our experimental design permits us to determine how much of the observed gains in Table 2 are due to the SIPT loss component, as opposed to, for example, continued training, new PT data or the batch selection procedures used in our method, which also indirectly leverage the knowledge inherent in GPT. Unsurprisingly, some gains are observed due to these other factors, and performance gains shrink when considering these ablation studies. However, even when comparing against the maximal performance baseline or ablation study overall, neither the direction of observed relationships nor the statistical significance of observed comparisons changes. Therefore, we can conclusively state that the performance improvements observed here are uniquely attributable to the structure-inducing components introduced by our framework. Full ablation study results can be found in Extended Data Tables 3 and 4.

Discussion

Despite the breadth of research into PT methods, methods for imposing explicit and deep structural constraints over the per-sample, PT latent space \({{{\mathcal{Z}}}}\) are under-explored (Extended Data Fig. 1). Our theoretical and empirical analyses show that this deficit matters. In particular, we define a PT framework, SIPT, under which the PT loss is subdivided into two components: one that is designed to capture intra-sample (for example, per-token) relationships and one that is intended to constrain the per-sample latent space to capture relationships between samples given by a user-specified PT graph GPT. Under our framework, we show theoretically and via experiments that the structure induced in \({{{\mathcal{Z}}}}\) can be directly connected to eventual FT performance. Empirically, we show that SIPT methods leveraging a variety of PT graphs can consistently outperform existing PT methods across three real-world domains.

Our work highlights several important directions for future research. For example, are there losses better suited than metric learning losses for PT graphs—for example, can we leverage the graph distance alongside the intra-batch distance to improve negative sampling strategies? In addition, can we produce theoretical results on the convergence of pre-trained models? For example, can we advance the understanding of when and how pre-trained models converge to solutions that recover GPT? In a different direction, can pre-trained models reflect forms of structure beyond nearest-neighbour relationships—for example, by leveraging higher-order topological considerations or by matching a distance function rather than a discrete graph? In addition, further exploring the structure-inducing objective’s impact on the underlying models’ internal mechanisms, as explored via explainable artificial intelligence techniques, would be an exciting avenue for future work. We anticipate that further analyses of these and other questions will lead to new PT methods and enable PT to be successful across diverse domains.

Methods

Structure-inducing losses

We use a multi-similarity loss46, parameterized by a positive pair weight, w+, a negative pair weight, w−, and a fixed hyperparameter, t, given below:

$$\begin{array}{l}{{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}=\frac{1}{N{w}_{+}}\log \left(1+\mathop{\sum}\limits_{(i,\,j)\in E}{{\mathrm{e}}}^{-{w}_{+}(\langle {f}_{{{{\bf{\uptheta }}}}}({{{{\bf{x}}}}}_{i}),\,{f}_{{{{\bf{\uptheta }}}}}({{{{\bf{x}}}}}_{j})\rangle -t)}\right)\\ +\frac{1}{N{w}_{-}}\log \left(1+\mathop{\sum}\limits_{(i,\,j)\notin E}{{\mathrm{e}}}^{{w}_{-}(\langle {f}_{{{{\bf{\uptheta }}}}}({{{{\bf{x}}}}}_{i}),\,{f}_{{{{\bf{\uptheta }}}}}({{{{\bf{x}}}}}_{j})\rangle -t)}\right).\end{array}$$
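A direct transcription of this loss into PyTorch, restricted to the pairs present in a minibatch, might look as follows; the in-batch edge-mask representation and restriction to batch-level pairs are our own simplifications.

```python
import torch

def multi_similarity_si_loss(z, edge_mask, w_pos, w_neg, t):
    """Transcription of the multi-similarity L_SI above, restricted to one
    minibatch. `z` holds (batch, dim) embeddings f_theta(x_i); `edge_mask[i, j]`
    is True iff (i, j) is an edge of G_PT between in-batch samples."""
    n = z.size(0)
    sim = z @ z.t()                                   # <f(x_i), f(x_j)>
    off_diag = ~torch.eye(n, dtype=torch.bool, device=z.device)
    pos, neg = edge_mask & off_diag, (~edge_mask) & off_diag
    pos_term = torch.log1p(torch.exp(-w_pos * (sim[pos] - t)).sum()) / (n * w_pos)
    neg_term = torch.log1p(torch.exp(w_neg * (sim[neg] - t)).sum()) / (n * w_neg)
    return pos_term + neg_term
```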

We also leverage a contrastive loss modelled after the version in ref. 45. For this loss, we assume we are given the following mappings: ‘pos’, which maps x into a positive node (that is, linked to x in GPT), and ‘neg’, which maps x into a negative node (that is, not linked to x in GPT). The union of a seed minibatch B of points XB and its images under ‘pos’ and ‘neg’ mappings form a full minibatch. This loss is specified by the positive and negative margin parameters μ+ and μ as:

$$\begin{array}{l}{{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}^{({{{\rm{CL}}}})}=\frac{1}{N}\mathop{\sum}\limits_{{{{{\bf{x}}}}}_{i}\in {{{\bf{X}}}}}\max \left({{{\mathcal{D}}}}({{{{\bf{x}}}}}_{i},{{{\rm{pos}}}}({{{{\bf{x}}}}}_{i}))-{\mu }_{+},0\right)\\ +\frac{1}{N}\mathop{\sum}\limits_{{{{{\bf{x}}}}}_{i}\in {{{\bf{X}}}}}\max \left({\mu }_{-}-{{{\mathcal{D}}}}({{{{\bf{x}}}}}_{i},{{{\rm{neg}}}}({{{{\bf{x}}}}}_{i})),0\right).\end{array}$$
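An analogous PyTorch sketch of this contrastive variant is below, assuming one sampled positive and one sampled negative per anchor and taking \({{{\mathcal{D}}}}\) to be Euclidean distance between embeddings (an assumption for illustration).

```python
import torch

def contrastive_si_loss(z_anchor, z_pos, z_neg, mu_pos, mu_neg):
    """Transcription of the contrastive L_SI above, with one sampled positive
    (linked in G_PT) and one sampled negative (not linked) per anchor. D is
    taken to be Euclidean distance between embeddings (an assumption)."""
    d_pos = (z_anchor - z_pos).norm(dim=-1)
    d_neg = (z_anchor - z_neg).norm(dim=-1)
    loss_pos = torch.clamp(d_pos - mu_pos, min=0).mean()  # pull positives within mu_+
    loss_neg = torch.clamp(mu_neg - d_neg, min=0).mean()  # push negatives beyond mu_-
    return loss_pos + loss_neg
```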

The Proteins dataset and FT tasks

We use a dataset of ~1.5 million protein sequences from the Stanford tree-of-life dataset20 (https://snap.stanford.edu/tree-of-life/data.html). The associated GitHub repository for this resource lists a Massachusetts Institute of Technology (MIT) license.

Two proteins are linked in GPT for this dataset if and only if they are documented in the scientific literature to interact, according to the tree-of-life interaction dataset. This is an external knowledge graph.

For FT, we use the Tasks Assessing Protein Embeddings (TAPE) FT benchmark tasks5, including remote homology (RH), a per-sequence classification task to predict protein fold category (metric: accuracy); secondary structure (SS), a per-token classification task to predict amino acid structural properties (metric: accuracy); stability (ST) and fluorescence (FL), per-sequence, regression tasks to predict a protein’s stability and fluorescence, respectively (metric: Spearman’s ρ); and contact prediction (CP), an intra-sequence classification task to predict which pairs of amino acids are in contact in the protein’s three-dimensional conformation (metric: precision at L/5 where L is protein length). All of these tasks are from publicly available datasets that can be obtained directly on TAPE’s GitHub (https://github.com/songlab-cal/tape#data), which lists no licences for these datasets though the overall GitHub is released under a BSD 3-Clause ‘New’ or ‘Revised’ License. RH tasks a model to predict a protein fold category at a per-sequence level. This task’s dataset contains 12,312/736/718 train/validation/test proteins and is originally sourced from ref. 47. SS is a per-token, multi-class classification problem, evaluated using accuracy, which tasks a model to predict the structural properties of each amino acid in the final, folded protein. This task’s dataset contains 8,678/2,170/513 train/validation/test proteins and is sourced from ref. 48. ST tasks a model to predict the protein’s stability in response to environmental conditions. This task’s dataset contains 53,679/2,447/12,839 train/validation/test proteins, originally sourced from ref. 49. FL requires a model to predict how brightly a protein will fluoresce. This task’s dataset contains 21,446/5,362/27,217 train/validation/test proteins and is originally sourced from ref. 50. Finally, CP requires a model to predict whether any given pair of amino acids from a protein are less than 8 Å apart or not. This task’s dataset is sourced from ProteinNet51.

In these experiments, we compare against the published TAPE model5, which uses a language modelling task alone, as our per-token comparison point, and the Protein sequence representations Learned Using Structural information (PLUS)52 model, which optimizes for LM and supervised classification jointly, as our per-sample comparison point.

The Abstracts dataset and FT tasks

We use a dataset of ~650,000 free-text scientific article abstracts from the Microsoft Academic Graph (MAG) dataset21,22. The Abstracts PT data (the MAG dataset) is licensed with an Open Data Commons Attribution License (ODC-By) v1.0 license.

Two abstracts are linked in GPT for this dataset if and only if their corresponding papers cite one another. This is a self-supervised graph.

Here, we use a subset of the FT tasks used in the SciBERT paper53, including paper field (PF), SciCite (SC), ACL-ARC (AA) and SciERC relation extraction (SRE), all of which are per-sentence classification problems (metric: macro-F1). PF tasks models to predict a paper’s area of study from its title, SC and AA both task models to predict an ‘intent’ label for citations, and SRE is a relation extraction task. All FT datasets can be obtained from the SciBERT GitHub (https://github.com/allenai/scibert), which lists no dataset-specific licences but is released with an Apache-2.0 license. The PF task asks models to predict a paper’s area of study given its title. This task’s dataset contains 84,000/5,599/22,399 train/validation/test sentences. Although the original dataset is derived from the MAG21, it was formulated into this task format by SciBERT directly53. The SC task challenges models to predict an ‘intent’ label for sentences that cite other scientific works within academic articles. This task’s dataset contains 7,320/916/1,861 train/validation/test sentences and is originally sourced from ref. 54. The AA task likewise requires models to predict an ‘intent’ label for sentences that cite other scientific works within academic articles. This task’s dataset contains 1,688/114/139 train/validation/test sentences and is originally sourced from ref. 55.

We compare against the published SciBERT model53 as our per-token comparison and the BioLinkBERT model56 as our per-sample comparison. BioLinkBERT augments language modelling with a classification task to predict whether the input text consists of two sentences from the same document, linked documents (where linkage is determined via a citation graph) or unlinked documents. In this way, it uses similar information to that used to build our PT graph, but via a single-task classification loss rather than the more general structure-inducing losses we use here. Recently, more successful base language models have been proposed beyond the SciBERT model (such as PubMedBERT57), and switching to one of those to initialize our SIPT models in the warm-start procedure would probably further improve performance across all models. However, given the computational expense of model PT, we retain the use of SciBERT for our initialization model (and accordingly for our corresponding per-token baseline) and leave the investigation of PubMedBERT for future work.

The Networks dataset and FT tasks

We use a dataset of ~70,000 PPI ego networks here, sourced from ref. 26. Each sample here describes a single protein, realized as a biological network (that is, an attributed graph) corresponding to the ego network about that protein (that is, a small subgraph containing the target protein and the nodes in its immediate neighbourhood) in a broader PPI graph. Unlike our other domains, this domain does not contain sequences. The Networks PT dataset releases its code and dataset files under an MIT license.

This dataset is labelled with the presence or absence of any of 4,000 protein Gene Ontology terms associated with the central protein in each PPI ego network. Leveraging these labels, two PPI ego networks are linked in GPT if and only if the Hamming distance between their observed label vectors is no more than nine. This is an alternative-representation nearest-neighbour graph.
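For illustration, the following brute-force numpy sketch constructs this kind of Hamming-distance graph from a binary label matrix; our actual preprocessing pipeline is not reproduced here.

```python
import numpy as np

def hamming_graph(label_matrix, threshold=9):
    """Link two ego networks iff the Hamming distance between their binary
    label vectors is at most `threshold`. Brute-force O(N^2) construction,
    shown for clarity only."""
    labels = np.asarray(label_matrix, dtype=bool)     # (n_samples, n_terms)
    edges = []
    for i in range(len(labels)):
        dists = (labels[i] != labels[i + 1:]).sum(axis=1)
        for offset in np.nonzero(dists <= threshold)[0]:
            edges.append((i, i + 1 + offset))
    return edges
```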

We study only one FT task in this setting, which is the multi-label binary classification of the 40 Gene Ontology term annotations (metric: macro area under the receiver operating characteristic curve (AUROC)) used in ref. 26. We use the PT set for FT training and evaluate the model on a held-out random 10% split.

We compare against both attribute-masking26 and multi-task supervised PT.

Experimental set-up

To minimize computational burden, we do not pre-train a structure-inducing model from scratch for Proteins and Abstracts datasets. Instead, we initialize a model from the per-token baseline directly, then perform additional PT for only a small number of epochs under the SIPT loss subdivision. We assess both multi-similarity and contrastive \({{{{\mathcal{L}}}}}_{{{{\rm{SI}}}}}\) variants in these domains. On the Networks dataset, we pre-train all models (including baselines) from scratch, and based on early experimental results, we only assess the contrastive loss variant.

Ablation analyses

Note that the warm-start procedure described above on the Proteins and Abstracts domains allows a powerful ablation study: by additionally training a PT model from the per-token baseline with λSI = 0, we can uniquely assess the impact of the new loss term, rather than simply additional training or the different PT dataset. We perform this ablation study for all relevant datasets. For the Networks dataset, no other ablation studies are needed to assess the impact of the loss term, given all models are trained from scratch with the same early-stop procedures.

Selection of λSI model parameter

For the Proteins and Abstracts datasets, to choose the optimal value of λSI for use at PT time, we pre-trained several models and evaluated their efficacy in a link-retrieval task on GPT = (V, E). In particular, we score a node embedder f by embedding all nodes \(n\in V\) as f(n), then rank all other nodes n′ by the Euclidean distance between f(n) and f(n′), and assess this ranked list via label ranking average precision, normalized discounted cumulative gain, average precision and mean reciprocal rank, where a node n′ is deemed to be a ‘successful’ retrieval for n if \((n,{n}^{{\prime} })\in E\). In this way, note that we choose λSI in a manner that is independent of the FT task and can be determined solely based on the PT data. The final results for these experiments are shown in Extended Data Table 5 for the proteins dataset and Extended Data Table 6 for scientific articles. Ultimately, this process suggests that λSI = 0.1 is a robust setting, and as such, 0.1 was used directly for the Networks task without further optimization.
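As an illustration of this link-retrieval scoring, the sketch below computes one of the four metrics listed (mean reciprocal rank) for a matrix of node embeddings and an adjacency map; it is a simplified stand-in for our full evaluation.

```python
import numpy as np

def mean_reciprocal_rank(z, adj):
    """Link-retrieval scoring sketch for selecting lambda_SI: rank all other
    nodes by Euclidean distance to each node n and take the reciprocal rank
    of its nearest true G_PT neighbour."""
    rr = []
    for n, neighbours in adj.items():
        if not neighbours:
            continue
        d = np.linalg.norm(z - z[n], axis=1)
        d[n] = np.inf                     # exclude the node itself
        order = np.argsort(d)
        rank_of = {int(m): r + 1 for r, m in enumerate(order)}
        rr.append(1.0 / min(rank_of[m] for m in neighbours))
    return float(np.mean(rr))
```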

Model architecture and other model parameters

The architectures of our encoders for the Proteins and Abstracts domains are entirely determined by our source models in TAPE5 and SciBERT53. In particular, for proteins and scientific articles, we use a 12-layer transformer with a hidden size of 768, an intermediate size of 3,072 and 12 attention heads. The provided TAPE and SciBERT tokenizers are also used. A single linear layer to the output dimensionality of each task is used as the prediction head, taking as input the output of the final layer’s [CLS] token as a whole-sequence embedding. We also tested PT for either one or four additional epochs, selecting on the basis of validation set performance. We ultimately used a single epoch for proteins and four for scientific articles.

For the Networks domain, we match the architecture used in the original source26 for the mask model runs, save that, for computational efficiency, we scale the batch size up as high as possible and then proportionally scale up the learning rate to account for the larger batch size. This corresponds to a batch size of 1,024, a learning rate of 0.01, a graph convolutional neural network (GCNN) with a Graph Isomorphism Network (GIN) encoder, embedding dimensions of 300, 5 layers, 10% dropout, mean pooling and a node feature combination strategy (JK) of ‘last’.

FT hyperparameters (learning rate, batch size and the number of epochs) were determined based on a combination of existing results, hyperparameter tuning and machine limitations. On Proteins, most hyperparameters were set to follow those reported for an LM PT model in ref. 58, although additional limited hyperparameter searches were performed to validate that these choices were adequate. As the original source for these hyperparameters was an LM PT model, any bias here should be against SIPT, meaning this is a conservative choice. Early stopping (based on the number of epochs without observing improvement in the validation set performance) was employed, and batch size was set as large as possible considering the underlying machine. For the PLUS reproduction, we compared hyperparameters analogous to those reported for PLUS on other tasks with hyperparameters analogous to our own settings on other tasks, and used those that performed best on the validation set. For scientific articles, we performed a grid search to optimize downstream task performance on the validation set, with the learning rate varying between 5 × 10−6 and 5 × 10−5 and the number of epochs between 2 and 5. The same grid search was used in the original SciBERT method. We additionally match the SciBERT benchmark by applying a dropout of 0.1, using the Adam optimizer with linear warm-up and decay, a batch size of 32, and no early stopping. For Networks, FT hyperparameters were again chosen to match the original source model26, save for the increase in batch size and learning rate described above. No additional hyperparameter search was performed.

Final hyperparameters for each downstream task are shown in Extended Data Table 1 for proteins and Extended Data Table 2 for scientific articles.

Implementation and compute environments

We leverage PyTorch for our codebase. FT experiments and Networks PT were run over various Ubuntu machines (versions ranged from 16.04 to 20.04) with various NVIDIA graphics processing units. Proteins and Abstracts PT runs were performed on a Power 9 system, each run using four NVIDIA 32 GB V100 graphics processing units with InfiniBand at half precision.

Systematic review of PT methods

Papers were selected via a manual search of NLP and NLP-derived PT methods (that is, methods focused primarily on other domains or multi-modal domains were excluded) via Google Scholar and by crawling through references of papers already included. Citation counts for each work were obtained via Google Scholar on 2 August 2022. Publication date (used to calculate citations per month since publication date) was computed as the earlier of either (1) the paper’s venue-specific date of publication or (2) the first submission date to the arXiv or bioRxiv platforms, as referenced via an exact title match. A manual review was done to classify how PT methods constrain latent space geometry and assign subjective, numerical ‘shallow–deep’ and ‘explicit–implicit’ axes scores. In total, over 90 methods were examined, of which 74 were suitable for inclusion in numerical review results (Extended Data Fig. 1). Supplementary Information summarizes and categorizes all methods considered (and reasons for exclusions are given). Note that our framework focuses on NLP-derived PT methods, and we do not examine generative PT methodology focused on high-dimensional continuous distributions, such as diffusion models59. However, these methods have succeeded in other domains, such as computer vision.