Paper Title
Personalized Federated Learning with Exact Stochastic Gradient Descent
Paper Authors
Paper Abstract
In Federated Learning (FL), datasets across clients tend to be heterogeneous or personalized, which poses challenges to the convergence of standard FL schemes that do not account for personalization. To address this, we present a new approach to personalized FL that achieves exact stochastic gradient descent (SGD) minimization. We start from the FedPer (Arivazhagan et al., 2019) neural network (NN) architecture for personalization, in which the NN has two types of layers: the first layers are common across clients, while the final few are client-specific and provide the personalization. We propose a novel SGD-type scheme in which, at each optimization round, randomly selected clients perform gradient-descent updates over their client-specific weights to optimize the loss function on their own datasets, without updating the common weights. In the final update, each client computes the joint gradient over both the client-specific and the common weights and returns the gradient of the common parameters to the server. This allows an exact and unbiased SGD step over the full set of parameters to be performed in a distributed manner, i.e., the updates of the personalized parameters are carried out by the clients and those of the common ones by the server. Our method outperforms the FedAvg and FedPer baselines on multi-class classification benchmarks such as Omniglot, CIFAR-10, MNIST, Fashion-MNIST, and EMNIST, and has much lower computational complexity per round.
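To make the round structure concrete, below is a minimal Python/NumPy sketch of one such round with a two-layer linear model: selected clients take gradient steps on their personal head only, then compute one joint gradient, apply its personal part locally, and send its common part to the server. The data generator, dimensions, learning rate, number of local steps, and the server-side averaging rule are illustrative assumptions of this sketch, not the paper's implementation.

    # Minimal sketch of one personalized-FL round with client-specific heads
    # and common weights; all names and hyperparameters here are assumptions.
    import numpy as np

    rng = np.random.default_rng(0)
    D, H, K = 8, 4, 5          # input dim, shared feature dim, number of clients
    LR, LOCAL_STEPS = 0.1, 3   # step size and client-specific steps per round

    W_common = rng.normal(size=(H, D)) * 0.1                        # shared layer (held by server)
    W_personal = [rng.normal(size=(1, H)) * 0.1 for _ in range(K)]  # per-client heads

    def make_client_data(k, n=32):
        """Synthetic regression data; each client has its own target mapping."""
        X = rng.normal(size=(n, D))
        y = X @ rng.normal(size=D) + 0.5 * k
        return X, y

    data = [make_client_data(k) for k in range(K)]

    def grads(Wc, Wp, X, y):
        """Gradients of a squared-error loss w.r.t. common and personal weights."""
        feats = X @ Wc.T                    # shared representation, shape (n, H)
        pred = (feats @ Wp.T).ravel()       # client-specific prediction
        err = (pred - y)[:, None] / len(y)
        g_p = err.T @ feats                 # dL/dWp, shape (1, H)
        g_c = (err * Wp).T @ X              # dL/dWc, shape (H, D)
        return g_c, g_p

    # One optimization round over a randomly selected subset of clients.
    selected = rng.choice(K, size=3, replace=False)
    common_grads = []
    for k in selected:
        X, y = data[k]
        # Client-specific steps: update only the personal head, common weights frozen.
        for _ in range(LOCAL_STEPS):
            _, g_p = grads(W_common, W_personal[k], X, y)
            W_personal[k] -= LR * g_p
        # Final update: joint gradient; personal part applied locally,
        # common part returned to the server.
        g_c, g_p = grads(W_common, W_personal[k], X, y)
        W_personal[k] -= LR * g_p
        common_grads.append(g_c)

    # Server-side step on the common weights.
    W_common -= LR * np.mean(common_grads, axis=0)

In this sketch the server simply averages the returned common-parameter gradients before taking its step; the precise aggregation used to obtain the exact, unbiased SGD step over the full parameter vector is described in the paper, and the averaging here is only a placeholder.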