Communication-efficient federated learning via knowledge distillation

Federated learning is a privacy-preserving machine learning technique for training intelligent models from decentralized data, which exploits private data by communicating local model updates in each iteration of model learning rather than the raw data. However, the model updates can be extremely large if the model contains numerous parameters, and many rounds of communication are needed for model training. The huge communication cost of federated learning leads to heavy overheads on clients and high environmental burdens. Here, we present FedKD, a federated learning method that is both communication-efficient and effective, based on adaptive mutual knowledge distillation and dynamic gradient compression. FedKD is validated on three different scenarios that need privacy protection, showing that it can reduce the communication cost by up to 94.89% while achieving results competitive with centralized model learning. FedKD offers the potential to efficiently deploy privacy-preserving intelligent systems in many scenarios, such as intelligent healthcare and personalization.


Introduction
Privacy protection of user data has become an increasingly important issue (Shokri and Shmatikov, 2015). Federated learning is a well-known technique for learning intelligent models from decentralized user data (McMahan et al., 2017). It has been widely used in various applications such as intelligent keyboards (Hard et al., 2018), personalized recommendation (Lin et al., 2020a) and topic modeling (Jiang et al., 2019).
In federated learning, the private data is locally stored on different clients. Each client keeps a local model and computes model updates from its local data. In each iteration, the model updates from a number of clients are uploaded to a server, which aggregates the local model updates into a global one to update the global model it maintains. The server then distributes the global update to each client to conduct a local model update. This process is repeated for many rounds until the model converges. In this framework, the server and clients need to communicate the model updates intensively. However, the communication cost is enormous if the model is large, which hinders the application of many powerful but large-scale models like BERT (Devlin et al., 2019) to federated learning.
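To make this communication pattern concrete, the following sketch outlines one round of the generic federated averaging loop described above. It is a minimal illustration rather than an actual implementation: the client interface (compute_update, apply_update), the flattened-parameter representation, and the plain averaging aggregator are assumptions made for exposition.

```python
import numpy as np

def federated_round(global_params, clients):
    # Each client computes a local update from its private data;
    # only the update (never the raw data) leaves the device.
    local_updates = [c.compute_update(global_params) for c in clients]
    # The server aggregates the local updates into a global one
    # (here a plain average, as in FedAvg; parameters are assumed
    # to be flattened into a single vector).
    global_update = np.mean(local_updates, axis=0)
    # The server distributes the global update and every client
    # applies it to its local model copy.
    for c in clients:
        c.apply_update(global_update)
    return global_params + global_update
```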
In this paper, we propose a communication-efficient federated learning method based on knowledge distillation (FedKD). Instead of directly communicating the large models between the clients and the server, in FedKD a small student model and a large teacher model are distilled from each other, and only the student model is shared by different clients and learned collaboratively, which effectively reduces the communication cost. More specifically, each client maintains a large local teacher model and a local copy of a small student model that is shared among different clients. We propose an adaptive knowledge distillation method that enables the local teacher and student to learn both from the local data on their client and from the knowledge distilled from each other, where their distillation intensities are controlled by the correctness of their predictions. The local teacher model on each client is updated locally, while the local updates of the student models from different clients are uploaded to a central server, which aggregates these local updates into a global one. The server further distributes the global update to different clients to update their local student models. This process is repeated until the student model converges. In addition, to further reduce the communication cost of exchanging the student model updates, we propose a dynamic gradient approximation method based on singular value decomposition (SVD) that compresses the communicated gradients with dynamic precision. Extensive experiments on benchmark datasets for different tasks validate that our approach can effectively reduce the communication cost of federated learning while achieving competitive performance.
The contributions of this paper are as follows:
• We propose a communication-efficient federated learning approach based on knowledge distillation, which can achieve competitive results with much less communication cost.
• We propose an adaptive mutual knowledge distillation method that encourages the teacher and student to learn from each other while being aware of their prediction correctness.
• We propose a dynamic gradient approximation method based on SVD for gradient compression with dynamic precision to further reduce the communication cost.
• We conduct extensive experiments on benchmark datasets for different tasks to verify the effectiveness and efficiency of our approach.

Related Work

Federated Learning
Federated learning is a privacy-aware technique for learning intelligent models from decentralized data storage, where the raw user data never leaves the place where it is stored. It has been widely used in many applications such as intelligent keyboards (Hard et al., 2018), personalized recommendation (Lin et al., 2020a), topic modeling (Jiang et al., 2019) and medical natural language processing (Ge et al., 2020). In federated learning, there are usually a number of user devices that locally keep privacy-sensitive user data, and a server that coordinates these devices for collaborative model learning. Each user device holds a local model copy and computes a model update based on the local data. The model updates from a certain number of user devices are uploaded to the server, which aggregates the local updates into a global one for updating the global model it maintains. The updated global model is further distributed to the user devices to update their local model copies. This process is repeated until the model is fully trained. Since model updates usually contain much less private information than raw data (McMahan et al., 2017), federated learning can exploit decentralized data for model learning while significantly reducing privacy and security risks. However, since model updates are communicated between the server and user clients for many rounds, the communication cost can be huge if the model is large. To remedy this issue, we propose a communication-efficient federated learning method with knowledge distillation, which reduces the parameters to be communicated while keeping competitive model performance.

Knowledge Distillation
Knowledge distillation is a technique for transferring knowledge from a large teacher model (e.g., BERT) to a small student model (Hinton et al., 2015), and is widely used for model compression (Sanh et al., 2019; Sun et al., 2019; Jiao et al., 2020; Wang et al., 2020b). For example, Sanh et al. (2019) proposed DistilBERT, which distills useful knowledge from the teacher's output via a distillation loss and from its hidden states via a cosine loss. Sun et al. (2019) proposed BERT-PKD, which aligns the hidden states of the student model with those of the teacher using a mean squared error loss. Jiao et al. (2020) proposed TinyBERT, which can additionally transfer useful knowledge from the attention matrices of the teacher model. However, these methods usually require centralized data storage, which may pose privacy issues during data collection.

Communication Efficient FL
Generally, the communication efficiency of federated learning can be improved by gradient compression (Konečnỳ et al., 2016; Caldas et al., 2018; Rothchild et al., 2020) and knowledge distillation (Sui et al., 2020). The two genres of methods are orthogonal ways of reducing the communication cost of federated learning and are usually compatible with each other. A core technique used by existing knowledge distillation-based federated learning methods is codistillation (Anil et al., 2018). In codistillation, the models on different clients are learned on the same dataset, and the output of each model is regularized to be similar to the ensemble of predictions from all models via a distillation loss. The idea of codistillation has been used by several methods to reduce the communication cost of federated learning (Sui et al., 2020; Li and Wang, 2019; Seo et al., 2020; Lin et al., 2020b; Sun and Lyu, 2020). For example, Sui et al. (2020) proposed a federated ensemble distillation approach for medical relation extraction. It first learns student models locally on each client, then uses the student models to generate predictions on a shared dataset and uploads them to a server. The server ensembles the predictions from different clients as a virtual teacher and computes the distillation loss between the teacher and the students. In this way, the model parameters do not need to be uploaded, and only the predictions on the shared dataset are communicated, which reduces the communication cost. However, these methods require a dataset that is shared among different clients to conduct ensemble distillation. Unfortunately, in many scenarios such as personalized recommendation, the data (e.g., user behavior logs) is highly privacy-sensitive and cannot be shared or exchanged among different clients, so these methods cannot be applied in such settings. By contrast, our approach circumvents the need for a shared dataset because the teacher models in our approach are stored locally on different clients. Our approach also effectively reduces the communication cost by communicating a distilled tiny student model instead of the original large model and by using SVD to reduce the gradient size.

FedKD
In this section, we introduce our communication-efficient federated learning approach based on knowledge distillation (FedKD). We first present a definition of the problem studied in this paper, then introduce the details of our approach, and finally discuss the computation and communication complexity of our approach.

Problem Definition
In our approach, we assume that there are $N$ clients that locally store their private data, where the raw data never leaves the client on which it is stored. We denote the dataset on the $i$-th client as $D_i$. Each client keeps a large local teacher model $T_i$ with parameter set $\Theta_i^t$ and a local copy of a smaller shared student model $S$ with parameter set $\Theta^s$. In addition, a central server coordinates these clients for collaborative model learning. The goal is to learn a strong model in a privacy-preserving way with less communication cost.

Federated Knowledge Distillation
Next, we introduce the details of our federated knowledge distillation framework, shown in Figure 1. In each iteration, each client simultaneously computes the updates of the local teacher model and the student model based on the supervision of the labeled local data and the knowledge distilled from each other. The teacher models are updated locally, while the student model is shared among different clients and learned collaboratively. Since the local teacher models have more sophisticated architectures than the student model, the useful knowledge encoded by the teacher model can help teach the student model. In addition, since the teacher model can only learn from the local data while the student model can see the data on all clients, the teacher can also benefit from the knowledge distilled from the student model.
In our approach, we use three loss functions to learn the student and teacher models locally: an adaptive mutual distillation loss to transfer knowledge from the output soft labels, an adaptive hidden loss to distill knowledge from the hidden states and self-attention heatmaps, and a task loss to directly provide task-specific supervision for learning the teacher and student models. We denote the soft probabilities of a sample $x_i$ predicted by the local teacher and student on the $i$-th client as $y_i^t$ and $y_i^s$, respectively. Since incorrect predictions from the teacher/student model may mislead the other one during knowledge transfer, we propose an adaptive method that weights the distillation loss according to the quality of the predicted soft labels. We first use the task labels to compute the task losses for the teacher and student models (denoted as $\mathcal{L}_{t,i}^t$ and $\mathcal{L}_{s,i}^s$). Denoting the gold label of $x_i$ as $y_i$, the task losses are formulated as follows:
$$\mathcal{L}_{t,i}^t = \mathrm{CE}(y_i^t, y_i), \quad \mathcal{L}_{s,i}^s = \mathrm{CE}(y_i^s, y_i),$$
where CE stands for cross-entropy. The adaptive distillation losses for the teacher and student models (denoted as $\mathcal{L}_{t,i}^d$ and $\mathcal{L}_{s,i}^d$) are formulated as follows:
$$\mathcal{L}_{t,i}^d = \frac{\mathrm{KL}(y_i^s, y_i^t)}{\mathcal{L}_{t,i}^t + \mathcal{L}_{s,i}^s}, \quad \mathcal{L}_{s,i}^d = \frac{\mathrm{KL}(y_i^t, y_i^s)}{\mathcal{L}_{t,i}^t + \mathcal{L}_{s,i}^s},$$
where KL means the Kullback-Leibler divergence. In this way, the distillation intensity is weak if the predictions of the teacher and student are not reliable, and the distillation loss becomes dominant once the student and teacher are well tuned, which has the potential to mitigate the risk of overfitting.

In addition, previous works have validated that transferring knowledge between the hidden states (Sun et al., 2019) and the attention matrices (Jiao et al., 2020) (if available) is beneficial for student teaching. Thus, taking language model distillation as an example, we also introduce additional adaptive hidden losses to align the hidden states and attention heatmaps of the student and the local teachers. The losses for the teacher and student models (denoted as $\mathcal{L}_{t,i}^h$ and $\mathcal{L}_{s,i}^h$) are formulated as follows:
$$\mathcal{L}_{t,i}^h = \mathcal{L}_{s,i}^h = \frac{\mathrm{MSE}(H_i^t, H^s W_i^h) + \mathrm{MSE}(A_i^t, A^s)}{\mathcal{L}_{t,i}^t + \mathcal{L}_{s,i}^s},$$
where MSE stands for the mean squared error, $H_i^t$, $A_i^t$, $H^s$ and $A^s$ respectively denote the hidden states and attention heatmaps of the $i$-th local teacher and the student, and $W_i^h$ is a learnable linear transformation matrix. Here we again control the intensity of the adaptive hidden loss by the prediction correctness of the student and teacher. Besides, motivated by the task-specific distillation framework of Tang et al. (2019), we also learn the student model from the task-specific labels on each client. Thus, on each client the unified loss functions for computing the local updates of the teacher and student models (denoted as $\mathcal{L}_{t,i}$ and $\mathcal{L}_{s,i}$) are formulated as follows:
$$\mathcal{L}_{t,i} = \mathcal{L}_{t,i}^t + \mathcal{L}_{t,i}^d + \mathcal{L}_{t,i}^h, \quad \mathcal{L}_{s,i} = \mathcal{L}_{s,i}^s + \mathcal{L}_{s,i}^d + \mathcal{L}_{s,i}^h.$$
The student model gradients $g_i$ on the $i$-th client are derived from $\mathcal{L}_{s,i}$ via $g_i = \partial \mathcal{L}_{s,i} / \partial \Theta^s$, where $\Theta^s$ is the parameter set of the student model. The local teacher model on each client is immediately updated with the local gradients derived from the loss function $\mathcal{L}_{t,i}$. Afterwards, the local gradients $g_i$ on each client are uploaded to the central server for global aggregation. Since the raw model gradients may still contain some private information (Zhu and Han, 2020), we encrypt the local gradients before uploading. The server receives the local student model gradients from different clients and uses a gradient aggregator to synthesize the local gradients into a global one (denoted as $g$). The server then delivers the aggregated global gradients to each client for a local update, and each client decrypts the global gradients to update its local copy of the student model.
This process is repeated until both the student and teacher models converge. Note that in the test phase, the teacher model is used for label inference.
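As a concrete illustration of the adaptive losses above, the following PyTorch sketch computes the per-client teacher and student losses. It is a simplified rendering under our own assumptions: the tensor shapes, the detaching of distillation targets and of the reliability weight, and the single-layer hidden/attention alignment are implementation choices for exposition, not prescriptions from the method itself.

```python
import torch
import torch.nn.functional as F

def fedkd_losses(logits_t, logits_s, labels, hid_t, hid_s, att_t, att_s, W_h):
    # Task losses for the teacher and student (cross-entropy with gold labels).
    task_t = F.cross_entropy(logits_t, labels)
    task_s = F.cross_entropy(logits_s, labels)
    # Reliability weight: large task losses -> unreliable predictions
    # -> weak distillation intensity (detached to act as a scalar weight).
    w = 1.0 / (task_t + task_s).detach()
    # Adaptive mutual distillation: each model matches the other's soft labels.
    kd_t = w * F.kl_div(F.log_softmax(logits_t, dim=-1),
                        F.softmax(logits_s, dim=-1).detach(),
                        reduction="batchmean")
    kd_s = w * F.kl_div(F.log_softmax(logits_s, dim=-1),
                        F.softmax(logits_t, dim=-1).detach(),
                        reduction="batchmean")
    # Adaptive hidden loss: align transformed student hidden states and
    # attention heatmaps with the teacher's (shared by both models).
    hidden = w * (F.mse_loss(hid_s @ W_h, hid_t) + F.mse_loss(att_s, att_t))
    loss_t = task_t + kd_t + hidden
    loss_s = task_s + kd_s + hidden
    return loss_t, loss_s
```

Detaching the distillation targets means each KL term only updates the model being matched to the other; whether gradients should also flow through the weighting factor is left open in this sketch.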

Dynamic Gradients Approximation
In our FedKD framework, although the student model updates are smaller than the teacher models, the communication cost can still be relatively high if the student model is not tiny. Thus, we propose a dynamic gradient approximation method that compresses the gradients exchanged between the server and clients to further reduce the communication cost. As shown in Fig. 1, we first factorize the local gradients into smaller matrices before uploading them. The server reconstructs the local gradients by multiplying the factorized matrices before aggregation. The aggregated global gradients are factorized again and distributed to the clients for reconstruction and model update. More specifically, we denote the gradient $g_i \in \mathbb{R}^{P \times Q}$ as a matrix with $P$ rows and $Q$ columns (we assume $P \geq Q$). It is approximately factorized into the product of three matrices, i.e., $g_i \approx U_i \Sigma_i V_i$, where $U_i \in \mathbb{R}^{P \times K}$, $\Sigma_i \in \mathbb{R}^{K \times K}$ and $V_i \in \mathbb{R}^{K \times Q}$ are the factorized matrices and $K$ is the number of retained singular values. If the value of $K$ satisfies $PK + K^2 + KQ < PQ$, the size of the uploaded and downloaded gradients is reduced. We denote the singular values of $g_i$ as $[\sigma_1, \sigma_2, ..., \sigma_Q]$ (sorted in descending order). To control the approximation error, we use an energy threshold $T$ to decide how many singular values are kept, i.e., we choose the smallest $K$ that satisfies
$$\frac{\sum_{j=1}^{K} \sigma_j^2}{\sum_{j=1}^{Q} \sigma_j^2} \geq T.$$
To better help the model converge, we use a dynamic value of $T$. The threshold $T$ as a function of the percentage of completed training steps $t$ is formulated as follows:
$$T = T_{start} + (T_{end} - T_{start}) \cdot t,$$
where $T_{start}$ and $T_{end}$ are two hyperparameters that control the start and end values of $T$. In this way, the student model is learned on roughly approximated gradients at the beginning of training and on more accurately approximated gradients as the model approaches convergence, which helps learn a more accurate student model. To help readers better understand how FedKD works, we summarize the entire workflow of FedKD in Algorithm 1 in the Appendix.
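A minimal numpy sketch of the compression step is given below, assuming gradients arrive as 2-D matrices and that the linear threshold schedule above is used; the function names are ours.

```python
import numpy as np

def dynamic_threshold(t, T_start=0.95, T_end=0.98):
    # Linear schedule: t is the fraction of completed training steps in [0, 1].
    return T_start + (T_end - T_start) * t

def compress_gradient(g, T):
    # Truncated SVD of a P x Q gradient matrix (P >= Q); keep the smallest K
    # whose cumulative squared singular-value energy reaches the threshold T.
    U, s, Vt = np.linalg.svd(g, full_matrices=False)  # s is sorted descending
    energy = np.cumsum(s ** 2) / np.sum(s ** 2)
    K = int(np.searchsorted(energy, T)) + 1
    return U[:, :K], s[:K], Vt[:K, :]

def reconstruct_gradient(U, s, Vt):
    # The receiver multiplies the factors back into an approximate gradient.
    return (U * s) @ Vt
```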

Complexity Analysis
In this section, we analyze the complexity of our FedKD approach in terms of computation and communication cost. We denote the number of communication rounds as $R$ and the average dataset size on each client as $D$. The computational cost of directly learning a large model (with parameter set $\Theta^t$) in a federated way is $O(RD|\Theta^t|)$, and the communication cost is $O(R|\Theta^t|)$. In FedKD, the communication cost is $O(R|\Theta^s|/\rho)$, where $\rho > 1$ is the gradient compression ratio, which is much smaller because $|\Theta^s| \ll |\Theta^t|$. The computational cost consists of three parts, i.e., local teacher model learning, student model learning and gradient compression/reconstruction, which are $O(RD|\Theta^t|)$, $O(RD|\Theta^s|)$ and $O(RPQ^2)$, respectively. The total computational cost of FedKD is therefore $O(RD|\Theta^t| + RD|\Theta^s| + RPQ^2)$. In practice, compared with the standard FedAvg method (McMahan et al., 2017), the extra computational cost of learning the student model in FedKD is much smaller than that of learning the large teacher model, and the SVD can be computed very efficiently in parallel. Thus, FedKD is efficient in terms of both communication and computation.
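As a quick sanity check of the savings, the snippet below plugs illustrative Transformer-like shapes (our own example numbers, not measurements from the paper) into the $PK + K^2 + KQ < PQ$ condition:

```python
# Uploading the factors U (P x K), Sigma (K x K) and V (K x Q) beats
# uploading the dense P x Q gradient whenever P*K + K**2 + K*Q < P*Q.
P, Q, K = 3072, 768, 64           # illustrative shapes, not from the paper
dense = P * Q                     # 2,359,296 values
factored = P * K + K**2 + K * Q   # 249,856 values
print(f"compression ratio: {dense / factored:.1f}x")  # about 9.4x
```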

Datasets and Experimental Settings
Our experiments are conducted on two tasks that involve user data. The first is personalized news recommendation, which requires predicting whether a user will click a candidate news article based on the user interest inferred from historical news click behaviors. For this task we use the MIND dataset, which contains the news impression logs of 1 million users on the Microsoft News platform over 6 weeks. The logs from the last week are used for testing, and the rest are used for training and validation. The second is adverse drug reaction (ADR) mentioning tweet detection, which is a binary classification task. We use the dataset released by the 3rd shared task of the SMM4H 2018 workshop (Weissenbacher et al., 2018), which we denote as SMM4H. The original SMM4H dataset contains 25,678 tweet IDs; however, since many tweet texts in this dataset are no longer available, we could only crawl 16,694 tweets for our experiments. Following (Wu et al., 2019d), we use 80% of the dataset for training, 10% for validation and 10% for testing. The detailed statistics of the two datasets are summarized in Table 1. To simulate the scenario where private data is decentralized on different clients, we randomly divide the training data into 4 folds and assume that each fold is locally stored on a different client. On each client we use the UniLM-Base model (Bao et al., 2020) as the local teacher (we take pre-trained language model distillation as a representative example in our experiments). We use its submodels with the first 4 or 2 Transformer layers as the student models. On the MIND dataset, we incorporate the language model as the news encoder of NAML. On the SMM4H dataset we apply attentive pooling and a dense layer after the language model for text classification. The energy thresholds $T_{start}$ and $T_{end}$ are 0.95 and 0.98, respectively. We use the Adam optimizer (Kingma and Ba, 2015); the detailed hyperparameter settings of our approach and the baselines are in the Appendix. On the MIND dataset, we use AUC, MRR, nDCG@5 and nDCG@10 as the metrics. On the SMM4H dataset, we use the precision, recall and F-score of the positive class as the metrics (Wu et al., 2019d). We repeat each experiment 5 times to mitigate randomness.
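For reproducibility, one possible way to simulate this decentralized setting is sketched below; the helper name and the fixed seed are our own choices, not details from the paper.

```python
import numpy as np

def split_into_clients(num_samples, num_clients=4, seed=42):
    # Randomly partition the training indices into equally sized folds,
    # one per simulated client.
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(num_samples), num_clients)
```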

Performance Evaluation
First, we compare the performance and communication cost of FedKD with several baselines, including: (1) UniLM (Local), learning the full UniLM model with the local data on a single client; (2) UniLM (Cen), learning the full UniLM model on centralized data; (3) UniLM (Fed), learning the full UniLM model in the standard federated framework; (4) DistilBERT (Sanh et al., 2019), finetuning the DistilBERT model in federated learning; (5) BERT-PKD (Sun et al., 2019), finetuning BERT-PKD in a federated manner; (6) TinyBERT (Jiao et al., 2020), finetuning TinyBERT in a federated way; (7) UniLM 4/2, using the first 4 or 2 layers of UniLM in federated learning; (8) FetchSGD (Rothchild et al., 2020), a count-sketch-based communication-efficient federated learning method; (9) FedDropout (Caldas et al., 2018), a federated dropout method that reduces the number of exchanged parameters. For methods (4)-(6), we compare the performance of their officially released 6-layer and 4-layer models. For methods (8) and (9), we use the full UniLM model. Note that the communication costs on the two datasets differ slightly due to the different numbers of updated token embeddings. The results on the MIND and SMM4H datasets are shown in Tables 2 and 3, respectively. From the results, we have the following findings. First, compared with UniLM (Local), the other methods achieve better performance. This is because the local data on a single client may not be sufficient to learn a strong model, while federated learning can exploit the data decentralized on multiple clients to facilitate model training. Second, although the full UniLM model achieves the best performance, its communication cost for model learning is huge (e.g., over 2GB for each client on the MIND dataset), so it may be difficult to use in real-world applications. Third, our FedKD approach outperforms off-the-shelf distilled models like DistilBERT, BERT-PKD and TinyBERT. This is because those models are distilled in a task-agnostic manner, which may be suboptimal in downstream tasks without further task-specific distillation. Fourth, FedKD also outperforms UniLM 4 and UniLM 2. This is because FedKD can learn useful knowledge from the output and intermediate results of the sophisticated local teacher models, while UniLM 4 and UniLM 2 cannot. Fifth, FedKD achieves better performance and lower communication cost than other communication-efficient methods like FetchSGD and FedDropout. This is because FedKD can transfer rich knowledge between the teacher and student models to improve performance, and it reduces the communication cost by exchanging the updates of a small student model and further compressing the gradients with SVD. Sixth, the communication cost of FedKD is much lower than that of the original UniLM model, and the performance of FedKD is comparable with UniLM (Fed) and UniLM (Cen). These results show that FedKD can effectively reduce the communication cost of federated learning while keeping good performance.

Effectiveness of Adaptive Mutual Distillation
We also verify the effectiveness of our proposed adaptive mutual distillation method. We first compare the performance of FedKD models trained with or without mutual distillation (i.e., the teacher model is only learned on the local data), as shown in Fig. 2 (we only include the results on MIND due to space limits; the results on SMM4H are in the Appendix). We observe that mutual distillation can effectively improve the performance of both teacher and student models of different sizes, especially the teacher model. This is because the useful knowledge transferred between the teachers and the student helps the student better imitate the complicated teacher models, and helps the teachers break the limitation of the amount of local labeled data.
In addition, we observe that the local teachers slightly outperform the student. Thus, we choose to use the teacher models for inference in the test stage. We further compare FedKD with its variants with the adaptive mutual distillation loss, the adaptive hidden loss or the adaptive loss weighting method removed, as shown in Fig. 3 (we report the performance of the teacher models). We can see that both the adaptive mutual distillation and adaptive hidden losses are useful for improving model performance. In addition, the performance is suboptimal when the adaptive loss weighting method is removed (this variant is similar to standard mutual distillation). This is because weighting the distillation and hidden losses by the correctness of model predictions may help distill higher-quality knowledge and meanwhile mitigate the risk of overfitting.

Influence of Client Number
We study the influence of the number of clients on the model performance in this section. We divide the full training data into different numbers of folds to simulate scenarios with different amounts of labeled data on each client. Figure 4 shows the performance of FedKD and UniLM 4/2 under different numbers of clients. We find that the performance of FedKD remains similar and can even be slightly improved when more clients are involved. This is because FedKD can learn from multiple teacher models on different clients, which encode richer knowledge when more teacher models participate. On the contrary, the performance of UniLM 4/2 (Fed) slightly declines as the number of clients increases. This may be because the vanilla FedAvg method sacrifices some performance by learning models for multiple epochs on limited local data.

Impact of Energy Threshold
We then study the influence of the energy thresholds $T_{start}$ and $T_{end}$ on the performance and communication cost of our approach. We first vary $T_{start}$ with $T_{end} = 1$, and the results are shown in Fig. 5(a). We find that the communication cost is smaller when $T_{start}$ is smaller, while the performance starts to drop quickly when $T_{start} < 0.95$. Thus, we choose $T_{start} = 0.95$ to balance the communication cost and the model performance. With $T_{start} = 0.95$, we then vary $T_{end}$ to compare the performance and communication cost, as shown in Fig. 5(b). In a similar way, we choose $T_{end} = 0.98$ to achieve a good tradeoff between model accuracy and communication cost.

Analysis of Dynamic Gradient Approximation
Finally, we present some analysis of our proposed SVD-based gradient compression method. We show the cumulative energy distributions of the singular values of different parameter gradient matrices in the UniLM model in Fig. 6, which reveals several interesting findings. First, all kinds of parameter matrices in UniLM are low-rank, especially the parameters in the feed-forward network. Thus, the communication cost can be greatly reduced by compressing the low-rank gradient matrices. In addition, we find that the singular value energy is more concentrated at the beginning of training than at the end. This may be because when the model is not yet well tuned, the gradients have more low-frequency components that push the model to converge quickly; when the model approaches convergence, the updates of the model parameters are usually subtle, which yields more high-frequency components. The evolution of the number of required singular values under $T = 0.95$ is shown in Fig. 7.
We can see that as training proceeds, more singular values need to be retained to achieve the same energy threshold. Therefore, to ensure the model accuracy of FedKD, we set a higher energy threshold as training continues in order to learn a more accurate model.
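The cumulative energy curves in this analysis can be reproduced with a few lines of numpy; the synthetic low-rank matrix below is only an illustration of the effect, not an actual UniLM gradient.

```python
import numpy as np

def cumulative_energy(grad):
    # Cumulative fraction of squared singular-value energy of a 2-D matrix.
    s = np.linalg.svd(grad, compute_uv=False)
    return np.cumsum(s ** 2) / np.sum(s ** 2)

# A low-rank matrix concentrates its energy in few singular values:
g = np.random.randn(768, 64) @ np.random.randn(64, 768)  # rank <= 64
print(cumulative_energy(g)[63])  # ~1.0: 64 values carry all the energy
```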

Conclusion
In this paper, we propose FedKD, a communication-efficient federated learning method based on knowledge distillation. In our approach, we propose an adaptive mutual distillation method to reciprocally learn a teacher model and a student model on each client, where the distillation intensity is controlled by their prediction correctness. The large teacher model is updated locally, while the small student model is shared among different clients and learned collaboratively, which effectively reduces the communication cost. In addition, we propose a dynamic gradient approximation method to further reduce the communication cost. Extensive experiments on two benchmark datasets for different tasks validate that FedKD can largely reduce the communication cost of federated learning while maintaining promising model performance.

A.3 Algorithm Workflow
The workflow of FedKD is summarized in Algorithm 1.

A.4 Experimental Environment
Our experimental environment is built on a Linux server running the Ubuntu 16.04 operating system with Python 3.6. The server has 4 Tesla V100 GPUs with 32GB memory each. The CPU is an Intel(R) Xeon(R) Platinum 8168 CPU @ 2.70GHz, and the total memory is 128GB. We use the horovod framework for parallel model training on the 4 GPUs, each of which simulates a client.

A.5 Model Initialization
In our approach, we use the token embedding layer and the first 4 or 2 Transformer layers of UniLM to initialize the student model. We do not change the hidden dimension of the model because UniLMv2 models with other hidden dimensions have not been released. Note that our approach does not place any limitation on the hidden dimension of the student model.

A.6 Running Time
On the MIND dataset, the total training times of FedKD 4 and FedKD 2 are around 66 and 57 hours, respectively. On the SMM4H dataset, their total training times are about 12 minutes and 10.5 minutes, respectively.