掌桥专利:专业的专利平台
掌桥专利
首页

基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质

文献发布时间:2024-04-18 19:58:26


基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质

技术领域

本发明属于数据安全领域,尤其涉及一种基于自适应学习率的异步联邦学习参数更新方法、设备及系统。

背景技术

近年来,随着移动和边缘设备已经广泛采用,并为各种应用生成了大量有价值的数据。这些设备也增加了机器学习的需求,以实现个性化和低延迟的AI应用。然而,由于隐私和带宽限制,集中式数据收集和模型训练是不可行的。因此,针对这些问题,谷歌提出了联邦学习用于解决机器学习模型训练的数据需求与用户数据隐私保护之间的矛盾。联邦学习已经成为一种新的范式,使得在大量边缘设备(客户端)之间进行协作机器学习成为可能,而无需共享他们的数据。联邦学习也可以用于用户数据必须保密或不能离开其原始环境的场景,例如在医疗和金融领域。

经典的联邦学习方法大多数是运行在同步的系统中,在每一次迭代中,中心服务器会随机抽取一些工作节点基于本地数据完成本地训练,工作节点将训练好的模型上传至中心服务器,随后中心服务器将收集到的模型参数进行聚合,再向每个工作节点发送更新后的模型。但是在设备异质性和网络异质性的情景中,经典的联邦学习方法面临着拖延者效应,会导致每一轮迭代的运行时间变长,所以联邦学习每一轮迭代的运行时间由最慢的学习者决定。

部分学者已经提出异步联合学习来解决这个问题,每个客户端独立地更新全局模型,这显示了更大的灵活性和可扩展性。在每一轮迭代中,完成本地训练的工作节点上传其模型参数,当中心服务器收到K个更新之后,中心服务器开始进行参数聚合。没有参加本轮聚合的工作节点继续完成其本地训练,等待参与下一轮的更新。异步联邦学习可以降低下一轮迭代中本地训练消耗的时间,从而缓解拖延者效应。

尽管K异步联邦学习方法具有以上优点,但是在实践中经常面临以下两个问题:1)延迟的模型梯度更新是基于陈旧模型进行的,因此延迟梯度相较于当前最新梯度具有一定的方向误差;2)由于多个工作节点上的数据类别分布通常不能服从独立同分布,这会造成不同工作节点的本地梯度更新方向均与中心服务器不一致,从而降低了模型的效应性,甚至会导致不收敛的问题。为了解决上述问题,现有的工作提出了基于两阶段训练策略的异步联邦学习方法,以加速训练并降低数据异质性的影响。但是该工作没有考虑到两阶段训练带来的巨大计算量和通信成本,同时现有工作中衡量梯度陈旧度的方法的策略是通过迭代滞后轮次或本地训练时间。显然,只有少数低延时的梯度会被聚合,大部分高延时的梯度将被过滤掉。

因此现有的技术需要一种能够既能有效缓解数据不平衡问题,又能解决延时梯度的异步联邦学习方法。

发明内容

发明目的:该发明旨在解决异步联邦学习中由于数据不平衡和延迟梯度导致的模型效用降低的问题。为此,本发明提出了一种基于自适应学习率的异步联邦学习参数更新方法,以解决异步联邦学习中的不平衡问题和陈旧问题。

在第一方面上,根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,用于中心服务器,包括

S110.中心服务器接收更新,更新包括由工作节点发出的梯度;

S120.中心服务器根据同步梯度估计全局无偏梯度;

S130.中心服务器根据全局无偏梯度计算延迟梯度的陈旧度;

S140.中心服务器根据陈旧度为延迟梯度调整学习率;

S150.中心服务器根据学习率更新全局神经网络模型;

S160.中心服务器将更新的全局神经网络模型的参数发出,更新的全局神经网络模型的参数由工作节点接收。

其中,同步梯度是工作节点依据最新的全局神经网络模型计算的梯度,延迟梯度是工作节点依据非最新的全局神经网络模型计算的梯度。

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,步骤S10中,中心服务器接收的更新还包括迭代次数,中心服务器根据迭代次数达到预先定义的次数停止更新。

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,步骤S10中,中心服务器将接收的更新加入队列,当队列的长度达到设定阈值,中心服务器执行步骤S20。

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,所述步骤S110中还包括中心服务器将当前的迭代次数和当前的全局神经网络模型参数广播,广播由本地节点接收。

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,估计全局无偏梯度由如下公式表示:

式中,g(w

延迟梯度的陈旧度由如下公式表示:

式中,cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度的余弦相似性,用于表示梯度下降中的方向相似性,Gt表示全局无偏梯度g(w

为延迟梯度调整学习率由如下公式表示:

式中,η

更新全局神经网络模型由如下公式表示:

w

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,工作节点的梯度更新的损失函数由如下公式表示:

式中,g

在第二方面上,根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,用于工作节点,包括

S210.工作节点接收全局神经网络模型的参数,全局神经网络模型的参数由中心服务器发出;

S220.工作节点依据全局神经网络模型的参数训练工作节点的本地模型;

S230.工作节点的本地模型进行梯度下降得到更新的参数;

S240.工作节点发出更新,更新包括梯度,工作节点依据最新的全局神经网络模型计算的梯度是同步梯度,工作节点依据非最新的全局神经网络模型计算的梯度是延迟梯度,其中,工作节点发出的同步梯度是由中心服务器接收的同步梯度,用于中心服务器估计全局无偏梯度,根据全局无偏梯度计算延迟梯度的陈旧度,根据陈旧度为延迟梯度调整学习率,根据学习率更新全局神经网络模型。

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,所述步骤S230还包括工作节点根据更新次数达到预先定义的次数执行步骤S240,否则执行步骤S220。

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,估计全局无偏梯度由如下公式表示:

式中,g(w

延迟梯度的陈旧度由如下公式表示:

式中,cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度的余弦相似性,用于表示梯度下降中的方向相似性,Gt表示全局无偏梯度g(w

为延迟梯度调整学习率由如下公式表示:

式中,η

更新全局神经网络模型由如下公式表示:

w

根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,工作节点的梯度由如下公式表示:

式中,g

本发明与现有技术相比,具有如下有益效果:

在第一方面上,本发明通过搭建基于不平衡数据分布的异步联邦学习算法,一方面可以协助多方共同学习一个准确且通用的神经网络模型,而无需公开和共享他们的本地用户数据集;另一方面本发明采取了一种双管齐下的方法,旨在分别解决客户端和服务器端的数据集不平衡和梯度陈旧性问题。本发明整合了一种新的评估方法,采用余弦相似度来衡量延迟梯度的陈旧性,进一步优化了服务器上的聚合算法,以提高异步联邦学习的性能。此外,还加入了一个类平衡损失函数来克服数据集不平衡问题,可以处理数据异质性的问题。这使得工作节点能够以一致的目标训练一个通用的分类器,而不考虑具体的类分布。从而提高了异步联邦学习训练速度的稳定度。

在第二方面上,本发明从梯度下降方向性的角度将延迟梯度的陈旧度进行了新的定义,现有计算延迟梯度陈旧度的方法认为延迟梯度的陈旧度和版本延迟呈正相关,本发明通过实验验证延迟梯度和同步梯度具有方向误差,但方向误差和版本延迟并不呈绝对的正相关,为此,本发明从梯度下降方向性的角度对陈旧度进行了新的定义,考虑了方向误差和版本延迟并不呈绝对的正相关,因此,本发明能够更好的利用延迟梯度促进模型收敛。

在第三方面,本发明解决了异步联邦学习面临着双重挑战:陈旧性问题和数据集不平衡问题,中心服务器接收完K个梯度后,首先进行无偏梯度估计,并实施一种基于余弦相似度的新型评估方法,以衡量延迟梯度的陈旧度;同时进一步调整学习速率,更新并广播模型参数和迭代次数。对于数据集不平衡问题,工作节点引入了一个类平衡损失函数,可以处理异质性数据对于模型训练的影响,本发明根据延时程度自适应调整学习速率,提高了模型的预测精度。

附图说明

图1为本发明实施例提供的一种基于自适应学习率的异步联邦学习参数更新方法流程图。

图2为本发明实施例提供的中心服务器端流程图。

图3为本发明实施例提供的工作节点端流程图。

图4为本发明实施例提供的基于加权聚合联邦学习的网络流量分类架构图。

图5为本发明实施例提供的不同联邦学习策略的实验对比图。

具体实施方式

下面将结合附图和技术方案,对本发明的实施过程进行详细描述。

实施例1:本发明涉及一种基于自适应学习率的异步联邦学习参数更新方法,还提出实现所述方法的相应的电子设备及可读存储介质。

本实施例的基于自适应学习率的异步联邦学习参数更新方法,用于中心服务器,所述方法包括以下步骤:

中心服务器初始化全局神经网络模型w

各种参数初始化完成后,中心服务器向工作节点分发神经网络模型,等待最快的K个工作节点发来梯度更新;

在第j轮全局迭代中,中心服务器接收到K个梯度更新。具体地说,接收到来自第i个节点的id和来自第i个节点的梯度g(w

中心服务器根据同步梯度的本地样本量计算全局无偏梯度估计,具体为:

其中,g(w

对于延迟梯度中心服务器计算其陈旧度,并且根据各个梯度的陈旧度为其赋予不同的学习率;

其中,中心服务器依据如下公式计算当前延迟梯度的陈旧度:

式中,Gt指的是上一步中计算得到的全局无偏梯度估计,Gt-τ是指陈旧梯度。cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度估计的余弦相似性,也即梯度下降中的方向相似性;∈表示超参数,可以根据不同的数据集或者训练任务进行调节,s(T)表示当前延迟梯度的陈旧性。

中心服务器根据K个梯度的陈旧度自适应调整学习率;

其中,中心服务器依据如下公式调整其学习率:

式中,η

在全局无偏梯度估计、计算陈旧度和调整学习率完成后,更新当前的全局模型w

其中,依据如下公式进行全局模型的更新:

w

本实施例的基于自适应学习率的异步联邦学习参数更新方法,用于工作节点,所述方法包括以下步骤:

工作节点接收来自中心服务器的发送的初始模型参数,模型版本version;

在本地使用类平衡损失函数进行训练,以克服本地数据集不平衡带来的负面影响;

其中,工作节点使用如下类平衡损失函数进行训练:

其中,式中,g

在本地训练t轮过后,工作节点将其训练得到的梯度g(w

利用更新后的权重进行下一轮训练。

本实施例的基于自适应学习率的异步联邦学习参数更新系统,包括中心服务器以及与中心服务器通信相连的多个工作节点,中心服务器与工作节点基于异步联邦学习机制进行参数聚合更新,所述中心服务器按照上述方法进行参数聚合更新,所述工作节点按照上述方法完成参数更新。

基于上述方法,一种用于在参数服务器端进行基于异步联邦学习的参数聚合更新的设备,所述设备包括:

存储器,存储有一个或多个计算机程序,所述一个或多个计算机程序被一个或多个处理器执行时,致使所述一个或多个处理器执行如本发明第一方面所述的参数聚合更新方法。

基于上述方法,一种用于在工作节点端进行基于异步联邦学习的参数聚合更新的设备,所述设备包括:

存储器,存储有一个或多个计算机程序,所述一个或多个计算机程序被一个或多个处理器执行时,致使所述一个或多个处理器执行如本发明第二方面所述的参数聚合更新方法。

本发明的有益效果:本发明通过搭建基于不平衡数据分布的异步联邦学习算法,一方面可以协助多方共同学习一个准确且通用的神经网络模型,而无需公开和共享他们的本地用户数据集;另一方面本发明采取了一种双管齐下的方法,旨在分别解决客户端和服务器端的数据集不平衡和梯度陈旧性问题。本发明整合了一种新的评估方法,采用余弦相似度来衡量延迟梯度的陈旧性,进一步优化了服务器上的聚合算法,以提高异步联邦学习的性能。此外,还加入了一个类平衡损失函数来克服数据集不平衡问题,可以处理数据异质性的问题。这使得工作节点能够以一致的目标训练一个通用的分类器,而不考虑具体的类分布。从而提高了异步联邦学习训练速度的稳定度。

实施例2:为了解决异步联邦学习中因为数据不平衡和延迟梯度导致的模型效用降低的问题,本发明提出一种基于自适应学习率的异步联邦学习参数更新方法,解决异步联邦学习中的不平衡问题和陈旧问题。该方法的中心服务器端包括如下步骤:

S1、中心服务器初始化模型,模型参数w

S2、中心服务器向连接的工作节点广播当前的全局通信轮次、当前的模型版本和最新的模型参数w

S3、中心服务器与工作节点保持网络连接,接收来自工作节点的更新梯度g(w

S4、判断当前队列中是否接收到K个梯度更新,若当前接收到的更新数少于K个,继续接收来自工作节点的更新,等待更新最快的K个节点发送工作更新。如果队列中的更新数等于K,进行下一步;

S5、中心服务器在接收到的K个梯度中选择基于最新模型更新的梯度为同步梯度,则其余梯度则为陈旧梯度,中心服务器依据以下规则进行全局无偏梯度估计:

式中,g(w

S6、中心服务器计根据余弦相似性算延迟梯度的陈旧性,并且根据各个梯度的陈旧度为其赋予不同的学习率;

到目前为止,衡量局部梯度过时程度的现有策略是通过迭代滞后τ的数量或通过局部训练时间。这些策略在解决实验中的陈旧模型问题方面已经证明了一定的有效性。然而,它们在实际场景中有明显的局限性。例如,一些具有低延迟的梯度可以与当前最新梯度有较高的方向一致性,而一些具有高延迟的梯度可能不会与当前最优梯度方向偏离太多。如果通过迭代滞后或局部训练时间测量的这些梯度的陈旧性,延迟梯度的陈旧度一旦超过某个阈值,则可能会错误地丢弃这些梯度。这会对训练模型的收敛产生不利影响,并减慢训练过程。在实践中,这种方法不能准确地测量陈旧的梯度是否有助于全局模型的收敛。因此本发明设计了基于余弦相似性的延迟梯度陈旧性的衡量方法;

S6.1、中心服务器依据如下公式计算当前延迟梯度的陈旧度:

式中,Gt指的是上一步中计算得到的全局无偏梯度估计,Gt-τ是指陈旧梯度。cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度估计的余弦相似性,也即梯度下降中的方向相似性;∈表示超参数,可以根据不同的数据集或者训练任务进行调节,s(τ)表示当前延迟梯度的陈旧性。

从梯度下降方向性的角度将延迟梯度的陈旧度进行了新的定义,现有计算延迟梯度陈旧度的方法认为延迟梯度的陈旧度和版本延迟呈正相关,本发明通过实验验证延迟梯度和同步梯度具有方向误差,但方向误差和版本延迟并不呈绝对的正相关,为此,本发明从梯度下降方向性的角度对陈旧度进行了新的定义,考虑了方向误差和版本延迟并不呈绝对的正相关,因此,本发明能够更好的利用延迟梯度促进模型收敛。

S7、学习率衰减。在异步联邦学习中,不同客户端的梯度关系可能存在陈旧性,即与最新的全局梯度相比有一定的延迟。这种陈旧性会影响全局模型的更新和收敛性能。为了减少陈旧性的影响,一种常用的方法是对陈旧客户端的权重进行学习率衰减,即降低其在全局权重更新中的贡献。学习率衰减的原则是,陈旧性越大,学习率越小。一种常见的学习率衰减策略是根据客户端的陈旧度τ来调整其权重的更新系数η

式中,η

S8、中心服务器更新模型。在全局无偏梯度估计、计算陈旧度和调整学习率完成后,更新当前的全局模型w

其中,依据如下公式进行全局模型的更新:

w

S9、一轮更新结束之后,中心服务器判断当前轮次是否等于预先定义的总沟通轮次,若无,继续执行当前循环,若完成了T轮训练,则代表全局模型已经训练完成,因此程序训练结束。

在一种实施实例中,本发明的若干工作节点包含如下步骤:

S10、工作节点初始化模型和本地轮次t;

S11、工作节点从中心服务器接收最新的全局模型权重w

S12、工作节点用自己的数据集D

S13、本地模型进行梯度下降算法,得到更新的参数。工作节点根据梯度的方向和一个预设的学习率,更新参数向量,使目标函数沿着梯度下降的方向移动一小步,工作节点的梯度更新的损失函数由如下公式表示:

式中,g

S14、工作节点判断本地更新次数是否等于预先定义的t轮,若小于t轮,继续循环S12-S13,直到达到预设的训练轮数,若达到预定轮次,则本地训练结束;

S15、将梯度g(w

根据本发明的另一实施例,提供一种用于在工作节点端进行基于异步联邦学习的中心聚合更新的设备,设备包括:存储器,存储有一个或多个计算机程序,所述一个或多个计算机程序被一个或多个处理器执行时,所述一个或多个处理器执行上述方法实施例中的步骤。

本发明实例提供了基于异步联邦学习的聚合更新方法的实施步骤,需要说明的是,虽然在流程图中给出了逻辑流程顺序,但是在某些情况下,可以以不同的执行顺序所示或描述的步骤。

本发明还提供一种基于异步联邦学习的参数聚合更新系统,包括中心服务器以及与中心服务器通信相连的多个工作节点,参数服务器与工作节点基于异步联邦学习机制进行参数聚合更新,中心服务器根据步骤S1-S9所述的方法进行参数聚合更新;工作节点根据步骤S10-S15所述的方法完成参数更新。

本发明公开了一种基于自适应学习率的异步联邦学习参数更新方法、设备及系统。为了解决异步联邦学习面临着双重挑战:陈旧性问题和数据集不平衡问题,本方法分别在中心服务器和工作节点解决如上问题。中心服务器接收完K个梯度后,首先进行无偏梯度估计,并实施一种基于余弦相似度的新型评估方法,以衡量延迟梯度的陈旧度;同时进一步调整学习速率,更新并广播模型参数和迭代次数。对于数据集不平衡问题,工作节点引入了一个类平衡损失函数,可以处理异质性数据对于模型训练的影响。本发明根据延时程度自适应调整学习速率,提高了模型的预测精度。

本领域内的技术人员应明白,本发明的实施例可提供为方法、设备、装置、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等,上实施的计算机程序产品的形式。

本发明是参照根据本发明实施例的方法、设备(系统,、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。

这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。

这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。最后应当说明的是:以上实施例仅用以说明本发明的技术方案而非对其限制,尽管参照上述实施例对本发明进行了详细的说明,所属领域的普通技术人员应当理解:依然可以对本发明的具体实施方式进行修改或者等同替换,而未脱离本发明精神和范围的任何修改或者等同替换,其均应涵盖在本发明的权利要求保护范围之内。

相关技术
  • 基于联邦学习的模型参数获取方法、系统及可读存储介质
  • 基于深度学习的遥感影像建筑物提取方法及系统、存储介质、电子设备
  • 一种基于图像识别的语言学习方法、电子设备及存储介质
  • 基于半监督学习的联邦建模方法、设备及可读存储介质
  • 基于梯度选择和自适应学习率的加权K异步联邦学习方法、系统及装置
  • 一种自适应客户端参数更新的联邦学习方法、系统及存储介质
技术分类

06120116489324