Introduction

Access to large-scale datasets is a primary driver of advancement in machine learning, with well-known datasets such as ImageNet1 in computer vision, or SQuAD2 in natural language processing leading to remarkable achievements. Other domains such as healthcare and finance face restrictions on sharing data, due to regulations and privacy concerns. It is impossible for institutions in these domains to pool and disseminate their data, which limits the progress of research and model development. The ability to share information between institutions while respecting the data privacy of individuals would lead to more robust and accurate models.

In the healthcare domain, for example, histopathology has seen widespread adoption of digitization, offering unique opportunities to increase objectivity and accuracy of diagnostic interpretations through machine learning3. Digital images of tissue specimens exhibit significant heterogeneity from the preparation, fixation, and staining protocols used at the preparation site, among other factors. Without careful regularization, deep models may excessively focus on imaging artifacts and hence fail to generalize on data collected from new sources4. Additionally, the need to serve a diverse population including minority or rare groups5, and mitigate bias6, requires diverse and multi-centric datasets for model training. Because of specializations at institutions and variability across local populations the integration of medical data across multiple institutions is essential.

However, centralization of medical data faces regulatory obstacles, as well as workflow and technical challenges including managing and distributing the data. The latter is particularly relevant for digital pathology since each histopathology image is generally a gigapixel file, often one or more gigabytes in size. Distributed machine learning on decentralized data could be a solution to overcome these challenges, and promote the adoption of machine learning in healthcare and similar highly regulated domains.

Federated learning (FL) is a distributed learning framework that was designed to train a model on data that could not be centralized7. It trains a model in a distributed manner directly on client devices where data is generated, and gradient updates are communicated back to the centralized server for aggregation. However, the centralized FL setting is not suited to the multi-institutional collaboration problem, as it involves a centralized third party that controls a single model. Considering a collaboration between hospitals, creating one central model may be undesirable. Each hospital may seek autonomy over its own model for regulatory compliance and tailoring to its own specialty. As a result, decentralized FL frameworks8 are preferred under such settings.

While it is often claimed that FL provides improved privacy since raw data never leaves the client’s device, it does not provide the guarantee of security that regulated institutions require. FL involves each client sending unaudited gradient updates to the central server, which is problematic since deep neural networks are capable of memorizing individual training examples, which may completely breach the client’s privacy9.

In contrast, meaningful and quantitative guarantees of privacy are provided by the differential privacy (DP) framework10. In DP, access to a database is only permitted through randomized queries in a way that obscures the presence of individual datapoints. More formally, let \({{{{{{{\mathcal{D}}}}}}}}\) represent a set of datapoints, and M a probabilistic function, or mechanism, acting on databases. We say that the mechanism is (ϵ, δ)-differentially private if for all subsets of possible outputs \({{{{{{{\mathcal{S}}}}}}}}\subset \,{{\mbox{Range}}}\,(M)\), and for all pairs of databases \({{{{{{{\mathcal{D}}}}}}}}\) and \({{{{{{{{\mathcal{D}}}}}}}}}^{{\prime} }\) that differ by one element,

$$\,{{\mbox{Pr}}}\,[M({{{{{{{\mathcal{D}}}}}}}})\in {{{{{{{\mathcal{S}}}}}}}}]\le \exp (\epsilon )\,\,{{\mbox{Pr}}}\,[M({{{{{{{{\mathcal{D}}}}}}}}}^{{\prime} })\in {{{{{{{\mathcal{S}}}}}}}}]+\delta .$$
(1)

The spirit of this definition is that when one individual’s data is added or removed from the database, the outcomes of a private mechanism should be largely unchanged in distribution. This will hold when ϵ and δ are small positive numbers. In this case, an adversary would not be able to learn about the individual’s data by observing the mechanism’s output, hence, privacy is preserved. DP mechanisms satisfy several useful properties, including strong guarantees of privacy under composition, and post-processing11, 12. However, if an individual contributes several datapoints to a dataset, then their privacy guarantees may be weaker than expected due to correlations between their datapoints. Still, group differential privacy11 shows that privacy guarantees degrade in a controlled manner as the number of datapoints contributed increases. These properties make DP a suitable solution for ensuring data privacy in a collaborative FL setting.

Unlike centralized FL7, 8 where federated clients coordinate to train a centralized model that can be utilized by everyone as a service, decentralized FL is more suitable for multi-institutional collaborations due to regulatory constraints. The main challenge of decentralized FL is to develop a protocol that allows information passing in a peer-to-peer manner. Gossip protocols13 can be used for efficient communication and information sharing14, 15. There are different forms of information being exchanged in the literature, including model weights16, 17, knowledge representations18, or model outputs19, 20. However, unlike our method, none of these protocols provides a quantitative theoretical guarantee of privacy for participants, and therefore are not suitable for highly regulated domains.

Several other methods have been proposed to achieve decentralization with varying secondary objectives. In Cyclical Weight Transfer (CWT)21 each client trains a model on local data, then passes that model to the next client in a cyclical fashion. Like standard FL, CWT avoids the need to centralize data and can achieve good performance when strict privacy (à la DP) is not a concern. Split learning22 enables multiple parties to jointly train a single model with a server such that no party controls the entire model. In our context, the additional reliance on a central party for inference is undesirable. Finally, swarm learning23 applies blockchain technology to promote a decentralized, secure network for collaborative training, with one client voted in to act as a central authority each round. Swarm learning does not change the core learning algorithm of FL7 so it inherits relatively poor model performance when strict privacy measures are applied, and requires homogeneous model architectures.

Each client in our ProxyFL has two models that serve different purposes. They are trained using a DP variant of deep mutual learning (DML)24 which is an approach for mutual knowledge transfer. DML compares favorably to knowledge distillation between a pre-trained teacher and a typically smaller student25 since it allows training both models simultaneously from scratch, and provides beneficial information to both models. Federated Mutual Learning (FML)26 introduces a meme model that resembles our proxy model, which is also trained mutually with each client’s private model, but is aggregated at a central server. However, FML is not well-suited to the multi-institutional collaboration setting as it is centralized and provides no privacy guarantee to clients.

Although raw data never leaves client devices, FL is still susceptible to breaches of privacy27, 28. DP has been combined with FL to train centralized models with a guarantee of privacy for all clients that participate29. By ensuring that gradient updates are not overly reliant on the information in any single training example, gradients can be aggregated centrally with a DP guarantee30. We take inspiration from these ideas for ProxyFL.

The main application domain considered in this work is computational pathology. Various articles have emphasized the need for privacy-preserving FL when facing large-scale computational pathology workloads. Li et al.31 and Ke et al.32 used FL for medical image augmentation and segmentation. Their method used a centralized server to aggregate selective weight updates that were treated in a DP framework, but they did not account for the total privacy budget expended over the training procedure. Li et al.33 and Lu et al.34 built medical image classification models with FL, and added noise to model weights for privacy. However, model weights have unbounded sensitivity, so no meaningful DP guarantee is achieved with these techniques.

In this work, we propose proxy-based federated learning, or ProxyFL (Fig. 1), for decentralized collaboration between institutions which enables training of high-performance and robust models, without sacrificing data privacy or communication efficiency. Our contributions are: (i) a method for decentralized FL in multi-institutional collaborations that is adapted to heterogeneous data sources, and preserves model autonomy for each participant; (ii) incorporation of DP for rigorous privacy guarantees; (iii) analysis and improvement of the communication overhead required to collaborate.

Fig. 1: Overall view of the proposed ProxyFL.
figure 1

ProxyFL is a communication-efficient, decentralized federated learning method where each client (e.g., hospital) maintains a private model, a proxy model, and private data. During distributed training, the client communicates with others only by exchanging their proxy model which enables data and model autonomy. After training, a client’s private model can be used for inference.

Results and discussions

In this section, we demonstrate the effectiveness and efficiency of ProxyFL on popular image classification datasets as well as a cancer diagnostic problem. In our experiments, we use the exponential communication protocol of Assran et al.35, as illustrated in Fig. 2. The communication graph P(t) is a permutation matrix so that each client only receives and sends one proxy per communication round. Each client communicates with its peer that is \({2}^{0},{2}^{1},\cdots \,,{2}^{\lfloor {\log }_{2}(|{{{{{{{\mathcal{K}}}}}}}}|-1)\rfloor }\) steps away periodically. Sending only one proxy per communication round enables ProxyFL to scale to large numbers of clients. The communication and synchronization is handled via the OpenMPI library36. Our code is available at https://github.com/layer6ai-labs/ProxyFL.

Fig. 2: An illustration of the directed exponential graphs used in our experiments.
figure 2

Dark arrows indicate the communication path of the bottom node at each round, while light arrows show the communication path of others. Solid nodes indicate clients who have received information from the bottom client. After only \(\lceil {\log }_{2}(|{{{{{{{\mathcal{K}}}}}}}}|)\rceil\) rounds, all nodes have access to that information.

Benchmark image classification

Datasets and settings

We conducted experiments on MNIST37, Fashion-MNIST (FaMNIST)38, and CIFAR-1039. MNIST and FaMNIST have 60k training images of size 28 × 28, while CIFAR-10 has 50k RGB training images of size 32 × 32. Each dataset has 10k test images, which are used to evaluate the model performance. Experiments were conducted on a server with 8 V100 GPUs, which correspond to 8 clients. In each run, every client had 1k (MNIST and FaMNIST) or 3k (CIFAR-10) non-overlapping private images sampled from the training set, meaning only a subset of the available training data was used overall which increases the difficulty of the classification task. To test robustness on non-IID data, clients were given a skewed private data distribution. For each client, a randomly chosen class was assigned and a fraction pmajor (0.8 for MNIST and FaMNIST; 0.3 for CIFAR-10) of that client’s private data was drawn from that class. The remaining data was randomly drawn from all other classes in an IID manner. Hence, clients must learn from collaborators to generalize well on the IID test set.

Baselines

We compare our method to FedAvg7, AvgPush, CWT21, FML26, Regular, and Joint training. FedAvg and FML are centralized schemes that average models with identical structure. FML is similar to ProxyFL in that every client has two models, except FML does centralized averaging and originally did not incorporate DP training. AvgPush is a decentralized version of FedAvg using PushSum for aggregation. CWT is similar to AvgPush, but uses cyclical model passing instead of aggregation. In line with prior work40, we also compare federated schemes with the Regular and Joint settings. Regular training uses the local private datasets without any collaboration. Joint training mimics a scenario without constraints on data centralization by combining data from all clients and training a single model. Regular, Joint, FedAvg, AvgPush, and CWT use DP-SGD for training their models, while ProxyFL and FML use it for their proxies.

Implementation details

Following Shen et al.26, the private/proxy models are LeNet5/MLP for MNIST and FaMNIST, and CNN2/CNN1 for CIFAR-10. All methods use the Adam optimizer41 with learning rate of 0.001, weight decay of 1e-4, mini-batch size of 250, clipping C = 1.0, and noise level σ = 1.0. Each round of local training takes a number of gradient steps equivalent to one epoch over the private data. For proper DP accounting, mini-batches are sampled from the training set independently with replacement by including each training example with a fixed probability42. The DML parameters α, β are set at 0.5 for FML and ProxyFL. We report means and standard deviations based on 5 random seeds. Additional details and results can be found in Sections A and B of the Supplementary Information.

Results and discussions

Figure 3 shows the performance on the test datasets. There are a few observations: (i) The private models of ProxyFL achieve the best overall performance on all datasets, even better than the centralized counterpart FML. The improvements of ProxyFL-private over other methods are statistically significant for every dataset (p value < 1e-5). Note that the Joint method serves as an upper bound of the problem when private datasets are combined. (ii) As a decentralized scheme, ProxyFL has a much lower communication cost compared to FML, as shown in Fig. 4. The exponential protocol has a constant time complexity per round regardless of the number of clients, which makes ProxyFL much more scalable. (iii) Decentralized schemes seem to be more robust to DP training, as AvgPush outperforms FedAvg and ProxyFL outperforms FML consistently.

Fig. 3: Test performance with differential privacy (DP) training.
figure 3

Datasets are a MNIST, b Fashion-MNIST, and c CIFAR-10. Each figure reports mean and standard deviation over eight clients for each of five independent runs. Source data are provided as a Source Data file.

Fig. 4: Communication time required as number of clients increases.
figure 4

Source data are provided as a Source Data file.

Ablation

We ablated ProxyFL to see how different factors affect its performance on the MNIST dataset. Unless specified otherwise, all models, including the private ones in FML and ProxyFL, have the same MLP structure. Additional ablation results can be found in Section B of the Supplementary Information.

IID versus non-IID

Figure 5 (Left) shows the performance of the methods with different levels of non-IID dataset skew. pmajor = 0.1 is the IID setting because there are 10 classes. We can see that as the setting deviates from IID, most methods degrade except for Joint training since it unifies the datasets. The private model of ProxyFL is the most robust to the degree of non-IID dataset skew. Note that the proxy model of ProxyFL achieves similar performance to the private model of FML, which is trained without DP guarantees. This indicates that ProxyFL is a scheme that is robust to distribution shifts among the clients.

Fig. 5: Ablations of our method.
figure 5

a Accuracy as the non-IID skew of the dataset increases. b Accuracy when clients have heterogeneous model architectures, and c accuracy with and without differentially private training. Each figure reports mean and standard deviation over eight clients for each of five independent runs. Source data are provided as a Source Data file.

Different private architectures

One important aspect of ProxyFL is that the private models can be heterogeneous—customized to meet the special needs of the individual clients. On the MNIST task we use all four model architectures, one for every two clients (CNN1 and CNN2 are slightly adapted to fit MNIST images). Results are shown in Fig. 5 (Center). We see that different models can achieve very diverse and sub-optimal performances with individual Regular training, while ProxyFL can improve all architectures’ performance. The improvements for weaker models are more significant than for stronger models.

DP versus non-DP

Next we investigate the effect of DP-SGD on the algorithms. Figure 5 (Right) shows the test accuracies of the different training methods with and without DP-SGD’s gradient clipping and noise addition. Clearly, all methods can outperform Regular training when there is no privacy constraint. However, with DP-SGD, centralized methods like FedAvg and FML-proxy perform poorly, even worse than Regular training. ProxyFL-private shows the smallest decrease in performance when DP-SGD is included and remains closest to the upper bound of Joint training.

Gastrointestinal disease detection

Dataset and setup

We applied our method on the Kvasir dataset43, which is a multi-class image dataset for gastrointestinal disease detection. It consists of 8000 endoscopic images from eight classes such as anatomical landmarks or pathological findings. Each class has 1000 images and each image is resized to 100 × 80 pixels. Following Yang et al.44, the dataset is partitioned into 6000 training and 2000 test images in each run, and the training set is further distributed into 8 clients by sampling from a Dirichlet distribution with concentration 0.545. The proxy and private models are both the VGG model from Yang et al.44 with minor adjustments for the resized images. The parameter settings are the same as in “Benchmark image classification” except that the batch size is now 128.

Results and discussion

The results are shown in Fig. 6. We can see that centralized schemes like FedAvg and FML-proxy do not learn much during the process, as opposed to their decentralized counterparts AvgPush and ProxyFL-proxy. This demonstrates the effectiveness of the PushSum scheme in this setting. Additionally, ProxyFL-private consistently outperforms FML-private during training, showing that the ProxyFL model is able to provide better learning signal to the private model compared with FML.

Fig. 6: Test accuracy on the Kvasir dataset.
figure 6

This figure reports mean and standard deviation over eight clients for each of five independent runs. Source data are provided as a Source Data file.

Histopathology image analysis

In this experiment, we evaluated ProxyFL for classifying the presence of lymph node metastases in a tissue sample. The presence of lymph node metastases is one of the most important factors in breast cancer prognosis. The sentinel lymph node is the most likely lymph node to contain metastasized cancer cells and is excised, histopathologically processed, and examined by the pathologist. This tedious examination process is time-consuming and can lead to small metastases being missed46.

Dataset

We considered a large public archive of whole-slide images (WSIs), namely the Camelyon-17 challenge dataset46, which has previously been used for federated learning in47. The dataset is derived from 1399 annotated whole-slide images of lymph nodes, both with and without metastases. Slides were collected from five different medical centers to cover a broad range of image appearance and staining variations. A total of 209 WSIs contain detailed hand-drawn contours for all metastases. The client data for this study was created from Camelyon-17 by choosing WSIs out of these 209 annotated WSIs from four of the institutions (see sample patches in Fig. 7). We used the provided annotations to extract both normal and tumor-containing patches from WSIs. We extracted 512 × 512 pixel patches from WSIs and assigned each patch a binary label (healthy/tumor-containing) based on the provided annotation. A similar number of patches were sampled from the WSIs originating at each of the four institutions, with balanced labels. We then split patches into training and test sets in a rough 80/20 ratio, ensuring that patches in each set originated from non-overlapping patients. The distribution of patches in each split is described in Table 1.

Fig. 7: Example histopathology images for model training.
figure 7

Patches extracted from Camelyon-17 whole-slide images (WSIs): a, b healthy tissue, c, d from WSIs containing tumors.

Table 1 Distribution of WSI patches across four different institutions

Generalization of models to data collected from diverse sources has become a well-known challenge for applying deep learning to medical applications48. The standard method for testing generalization is to evaluate models on external test data which should originate from entirely different institutions than those used for training49,50,51. FL can potentially mitigate generalization problems by increasing the diversity of data that models are exposed to during training. Hence, we merged all four client test sets into a single multi-centric test set used in all model evaluations. From the perspective of a given client, a majority of the test set is external data which requires the trained model to generalize beyond the internal training data in order to show strong performance.

Models

For all methods, including both private and proxy models, we used the standard ResNet-18 neural network architecture52, as implemented in the torchvision package53, with randomly initialized weights. ResNet architectures use BatchNorm layers54, which are problematic for the analysis of differential privacy guarantees in DP-SGD training because batch normalization causes each sample’s gradient to depend on all datapoints in the batch. As a substitute, we replaced BatchNorm layers with GroupNorm layers55, and did so across all models for consistency. Finally, we modified the dimensions of the output layer of ResNet-18 to fit our patch size and binary classification requirements.

Experimental setup

The experiments were conducted using four V100 GPUs, each representing a client. In each scenario, training was conducted for 30 epochs with a mini-batch size of 32. All methods used the DP settings of σ = 1.4, C = 0.7, and δ = 10−5. Values for ϵ were then computed per-client based on the training set sizes in Table 1, and are provided in Table 2 (since several datapoints can be associated to the same patient, the formalism of group differential privacy11 could be applied to account for the correlations between such datapoints. For purposes of demonstration, and because the dataset we use is publicly available, we report ϵ as if each datapoint came from a distinct patient). FedAvg, AvgPush, CWT, Regular, Joint, and the proxy models of FML and ProxyFL used a DP-Adam optimizer with learning rate 0.001, whereas the private models of FML and ProxyFL used Adam with the same learning rate. The DML parameters α and β were set at 0.3 for both FML and ProxyFL. Finally, for FML and ProxyFL the private models are used to evaluate performance, the central model is used in the case of FedAvg, and for other methods the local models are used.

Table 2 (Left) Final performance metrics on the histopathology dataset

Results

The binary classification results for accuracy and macro-averaged accuracy are reported in Fig. 8 and Table 2. ProxyFL and FML achieve overall higher accuracy throughout training compared to other approaches, due to their private model’s ability to focus on local data while extracting useful information about other institutions through proxy models. Notably, FML’s performance peaks early and begins to degrade, while ProxyFL continues to improve marginally to the end of training.

Fig. 8: Performance of methods on the histopathology dataset involving four institutions.
figure 8

For accuracy (a) and macro-averaged accuracy (b) the mean and standard deviation of performance on the test set is recorded at the end of each epoch. Each figure reports mean and standard deviation over 4 clients for each of 15 independent runs. Source data are provided as a Source Data file.

Methods

Our research complies with all relevant ethical regulations; as we have used publicly available data from the Camelyon-17 challenge dataset, no institutional approval was necessary.

ProxyFL, or proxy-based federated learning is our proposed approach for decentralized federated learning. It is designed for multi-institutional collaborations in highly regulated domains, and as such incorporates quantitative privacy guarantees with efficient communication.

Problem formulation and overview

We consider the decentralized FL setting involving a set of clients \({{{{{{{\mathcal{K}}}}}}}}\), each with a local data distribution \({{{{{{{{\mathcal{D}}}}}}}}}_{k},\forall k\in {{{{{{{\mathcal{K}}}}}}}}\). Every client maintains a private model \({f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}:{{{{{{{\mathcal{X}}}}}}}}\to {{{{{{{\mathcal{Y}}}}}}}}\) with parameters ϕk, where \({{{{{{{\mathcal{X}}}}}}}},{{{{{{{\mathcal{Y}}}}}}}}\) are the input/output spaces respectively. In this work, we assume that all private models have the same input/output specifications, but may have different structures (this can be further relaxed by including client-specific input/output adaptation layers). The goal is to train the private models collectively so that each generalizes well on the joint data distribution.

There are three major challenges in this setting: (i) The clients may not want to reveal their private model’s structure and parameters to others. Revealing model structure can expose proprietary information, increase the risk of adversarial attacks56, and can leak private information about the local datasets57. (ii) In addition to model heterogeneity, the clients may not want to rely on a third party to manage a shared model, which precludes centralized model averaging schemes. (iii) Information sharing must be efficient, robust, and peer-to-peer. To address the above challenges, we introduce an additional proxy model \({h}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}:{{{{{{{\mathcal{X}}}}}}}}\to {{{{{{{\mathcal{Y}}}}}}}}\) for each client with parameters θk. It serves as an interface between the client and the outside world. As part of the communication protocol, all clients agree on a common proxy model architecture for compatibility. Proxy models enable architectural independence among federated clients that may enhance security against white-box adversarial attacks in which the attacker has knowledge of the model architecture58. The proxy model is generally smaller than the private model thus reducing communication overhead among clients. Furthermore, diverting all shared information through the proxy allows the private model to be trained without DP. Since DP incurs a loss of utility, the private model, which is ultimately used for inference by the client locally, tends to have much higher accuracy than otherwise.

In every round of ProxyFL, each client trains its private and proxy models jointly so that they can benefit from one another. With differentially private training, the proxy can extract useful information from private data, ready to be shared with other clients without violating privacy constraints. Then, each client sends its proxy to its out-neighbors and receives new proxies from its in-neighbors according to a communication graph, specified by an adjacency matrix P and de-biasing weights w. Finally, each client aggregates the proxies they received, and replaces their current proxy. The overall procedure is shown in Fig. 1 and Algorithm 1. We discuss each step in detail in the subsequent subsections.

Algorithm 1

ProxyFL

Training objectives

For concreteness, we consider classification tasks. To train the private and proxy models at the start of each round of training, we apply a variant of DML24. Specifically, when training the private model for client k, in addition to the cross-entropy loss (CE)

$${{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{\rm{CE}}}}}}}}}({f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}):={{\mathbb{E}}}_{({{{{{{{\bf{x}}}}}}}},y) \sim {{{{{{{{\mathcal{D}}}}}}}}}_{k}}{{{{{{{\rm{CE}}}}}}}}[{f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}({{{{{{{\bf{x}}}}}}}})\parallel y],$$
(2)

DML adds a KL divergence loss (KL)

$${{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{\rm{KL}}}}}}}}}({f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}};{h}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}):={{\mathbb{E}}}_{({{{{{{{\bf{x}}}}}}}},y) \sim {{{{{{{{\mathcal{D}}}}}}}}}_{k}}{{{{{{{\rm{KL}}}}}}}}[{f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}({{{{{{{\bf{x}}}}}}}})\parallel {h}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}({{{{{{{\bf{x}}}}}}}})],$$
(3)

so that the private model can also learn from the current proxy model. The objective for learning the private model is given by

$${{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}:=(1-\alpha )\cdot {{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{\rm{CE}}}}}}}}}({f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}})+\alpha \cdot {{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{\rm{KL}}}}}}}}}({f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}};{h}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}),$$
(4)

where α (0, 1) balances between the two losses. The objective for the proxy model is similarly defined as

$${{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}:=(1-\beta )\cdot {{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{\rm{CE}}}}}}}}}({h}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}})+\beta \cdot {{{{{{{{\mathcal{L}}}}}}}}}_{{{{{{{{\rm{KL}}}}}}}}}({h}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}};{f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}).$$
(5)

where β (0, 1). As in DML, we alternate stochastic gradient steps between the private and proxy models.

In our context, mini-batches are sampled from the client’s private dataset. Releasing the proxy model to other clients risks revealing that private information. Therefore, each client uses differentially private stochastic gradient descent (DP-SGD)30 when training the proxy (but not the private model). Let \({{{{{{{{\mathcal{B}}}}}}}}}_{k}={\{({{{{{{{{\bf{x}}}}}}}}}_{i},{y}_{i})\}}_{i=1}^{B}\) denote a mini-batch sampled from \({{{{{{{{\mathcal{D}}}}}}}}}_{k}\). The stochastic gradient is \(\nabla {\widehat{{{{{{{{\mathcal{L}}}}}}}}}}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}({{{{{{{{\mathcal{B}}}}}}}}}_{k}):=\frac{1}{B}\mathop{\sum }\nolimits_{i=1}^{B}{{{{{{{{\bf{g}}}}}}}}}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}^{(i)}\) where

$${{{{{{{{\bf{g}}}}}}}}}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}^{(i)}:=(1-\alpha ){\nabla }_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}{{{{{{{\rm{CE}}}}}}}}[{f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}({{{{{{{{\bf{x}}}}}}}}}_{i})\!\parallel \!{y}_{i}]+\alpha {\nabla }_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}{{{{{{{\rm{KL}}}}}}}}[{f}_{{{{{{{{{\boldsymbol{\phi }}}}}}}}}_{k}}({{{{{{{{\bf{x}}}}}}}}}_{i})\!\parallel \!{h}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}({{{{{{{{\bf{x}}}}}}}}}_{i})].$$
(6)

\(\nabla {\widehat{{{{{{{{\mathcal{L}}}}}}}}}}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}({{{{{{{{\mathcal{B}}}}}}}}}_{k})\) and \({{{{{{{{\bf{g}}}}}}}}}_{{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}}^{(i)}\) are similarly defined for the proxy. To perform DP training for the proxy, the per-example gradient is clipped, then aggregated over the mini-batch, and finally Gaussian noise is added30:

$${\overline{{{{{{\bf{g}}}}}}}}_{{{{{{{\boldsymbol{\theta }}}}}}}_{k}}^{(i)}: = \,{{{{{{\bf{g}}}}}}}_{{{{{{{\boldsymbol{\theta }}}}}}}_{k}}^{(i)}/\max \left(1,\parallel {\overline{{{{{{\bf{g}}}}}}}}_{{{{{{{\boldsymbol{\theta }}}}}}}_{k}}^{(i)}{\parallel }_{2}/C\right),\\ \widetilde{\nabla }{\widehat{{{{{{\mathcal{L}}}}}}}}_{{{{{{{\boldsymbol{\theta }}}}}}}_{k}}({{{{{{\mathcal{B}}}}}}}_{k}): = \frac{1}{B}\left(\mathop{\sum }\limits_{i=1}^{B}{\overline{{{{{{\bf{g}}}}}}}}_{{{{{{{\boldsymbol{\theta }}}}}}}_{k}}^{(i)}+{{{{{\mathcal{N}}}}}}(0,{\sigma }^{2}{C}^{2}I)\right),$$
(7)

where C > 0 is the clipping threshold and σ > 0 is the noise level (see Lines 2–5 in Algorithm 1).

Privacy guarantee

The proxy model is the only entity that a client reveals, so each client must ensure this sharing does not compromise the privacy of their data. Since arbitrary post-processing on a DP-mechanism does not weaken its (ϵ, δ) guarantee11, it is safe to release the proxy as long as it was trained via a DP-mechanism. DP-SGD as defined in Eq. (7) is based on the Gaussian mechanism59 which meets the requirement of Eq. (1) by adding Gaussian noise to the outputs of a function f with bounded sensitivity C in L2 norm,

$$M(x)=f(x)+{{{{{{{\mathcal{N}}}}}}}}(0,{\sigma }^{2}{C}^{2}I).$$
(8)

DP-SGD simply takes f(x) to be the stochastic gradient update, with clipping to ensure bounded sensitivity.

Every application of the DP-SGD step incurs a privacy cost related to the clipping threshold C, and noise level σ. A strong bound on the total privacy cost over many applications of DP-SGD is obtained by using the framework of Rényi differential privacy60, 61 to track privacy under compositions of DP-SGD, then convert the result to the language of (ϵ, δ)-DP as in Eq. (1)62.

Finally, privacy guarantees are tracked on a per-client basis. In a multi-institutional collaboration, every client has an obligation to protect the privacy of the data it has collected. Hence, each client individually tracks the parameters (ϵ, δ) for its own proxy model training, and can drop out of the protocol when its prespecified privacy budget is reached. Throughout the paper, we specify δ based on the dataset size, and compute ϵ.

Communication efficiency and robustness

The proxies serve as interfaces for information transfer and must be locally aggregated in a way that facilitates efficient learning among clients. One may use a central parameter server to compute the average of the proxies, similar to Shen et al.26. However, this will incur a communication cost that grows linearly in the number of clients, and is not decentralized. We propose to apply the PushSum scheme13, 15 to exchange proxies among clients that significantly reduces the communication overhead.

Let \({\Theta }^{(t)}\in {{\mathbb{R}}}^{|{{{{{{{\mathcal{K}}}}}}}}|\times {d}_{\theta }}\) represent the stacked proxies at round t, where the rows are the proxy parameters \({{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}^{(t)},\forall k\in {{{{{{{\mathcal{K}}}}}}}}\). We use \({P}^{(t)}\in {{\mathbb{R}}}^{|{{{{{{{\mathcal{K}}}}}}}}|\times|{{{{{{{\mathcal{K}}}}}}}}|}\) to denote the weighted adjacency matrix representing the graph topology at round t, where \({P}_{k,{k}^{{\prime} }}^{(t)}\, \ne \, 0\) indicates that client k receives the proxy from client \({k}^{{\prime} }\). Note that P(t) needs to be column-stochastic, but need not be symmetric (bidirectional communication) nor time-invariant (across rounds). Such a P(t) will ensure efficient communication when it is sparse. The communication can also handle asymmetrical connections such as different upload/download speeds, and can adapt to clients joining or dropping out since it is time-varying.

With these notations, every round of communication can be concisely written as Θ(t+1) = P(t)Θ(t). Under certain mixing conditions63, it can be shown that \({\lim }_{T\to \infty }\mathop{\prod }\nolimits_{t=0}^{T}{P}^{(t)}={{{{{{{\boldsymbol{\pi }}}}}}}}{{{{{{{{\bf{1}}}}}}}}}^{\top }\), where π is the limiting distribution of the Markov chain and 1 is a vector of all ones. Suppose for now that there is no training for the proxies between rounds, i.e., updates to the proxies are due to communication and replacement only. In the limit, θk will converge to \({{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}^{(\infty )}={\pi }_{k}{\sum }_{{k}^{{\prime} }\in {{{{{{{\mathcal{K}}}}}}}}}{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{{k}^{{\prime} }}^{(0)}\). To mimic model averaging (i.e., computing \(\frac{1}{|{{{{{{{\mathcal{K}}}}}}}}|}{\sum }_{{k}^{{\prime} }\in {{{{{{{\mathcal{K}}}}}}}}}{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{{k}^{{\prime} }}^{(0)}\)), the bias introduced by πk must be corrected. This can be achieved by having the clients maintain another set of weights \({{{{{{{\bf{w}}}}}}}}\in {{\mathbb{R}}}^{|{{{{{{{\mathcal{K}}}}}}}}|}\) with initial values w(0) = 1. By communicating w(t+1) = P(t)w(t), we can see that \({{{{{{{{\bf{w}}}}}}}}}^{(\infty )}={{{{{{{\boldsymbol{\pi }}}}}}}}{{{{{{{{\bf{1}}}}}}}}}^{\top }{{{{{{{{\bf{w}}}}}}}}}^{(0)}=|{{{{{{{\mathcal{K}}}}}}}}|{{{{{{{\boldsymbol{\pi }}}}}}}}\). As a result, the de-biased average is given by \({{{{{{{{\boldsymbol{\theta }}}}}}}}}_{k}^{(\infty )}/{w}_{k}^{(\infty )}=\frac{1}{|{{{{{{{\mathcal{K}}}}}}}}|}{\sum }_{{k}^{{\prime} }\in {{{{{{{\mathcal{K}}}}}}}}}{{{{{{{{\boldsymbol{\theta }}}}}}}}}_{{k}^{{\prime} }}^{(0)}\).

Finally, recall that the proxies are trained locally in each round. Instead of running the communication to convergence for proxy averaging, we alternate between training (Lines 2–5 in Algorithm 1) and communicating (Lines 7–10) proxies, similar to Assran et al.35.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.