掌桥专利:专业的专利平台
掌桥专利
首页

一种基于批标准化层参数修正联邦学习的图像分类方法

文献发布时间:2023-06-19 19:27:02


一种基于批标准化层参数修正联邦学习的图像分类方法

技术领域

本发明涉及图像分类,特别是涉及一种基于批标准化层参数修正联邦学习的图像分类方法。

背景技术

随着物联网的兴起和用户数据隐私保护意识的增强,联邦学习(federatedlearning,FL)框架被提出,其通过联合多个边缘客户端训练深度神经网络(deepneuralnetwork,DNN)模型,并且无需访问客户端的原始数据。与此同时,许多DNN模型采用批标准化(batch normalization,BN)来提升模型的训练速度与泛化能力。

在实际训练中,不同客户端的本地数据集往往是异构的,这会导致联邦学习在训练含有BN的DNN模型时,性能出现大幅下降。因此,需要针对含有BN层的DNN模型,设计一种联邦学习算法,使得DNN模型在不同数据分布下都可以取得好的训练性能。

当不同客户端具有不同的本地数据分布时,当前的联邦学习算法无法达到好的训练效果,无法保证模型参数能够收敛到一个好的解。

例如,利用联邦学习框架训练ResNet-20模型,对CIFAR-10数据库中的数据进行分类。当不同客户端的本地数据集具有相同的分布时,训练所得的模型在测试数据集的分类精度可以达到90%左右;而当不同客户端具有不同分布的本地数据集时,训练所得的模型的分类精度会出现大幅度的下降,可能只能达到40%左右。

发明内容

本发明的目的在于克服现有技术的不足,提供一种基于批标准化层参数修正联邦学习的图像分类方法,通过修正本地模型训练时的批均值与批方差、以及批均值与批方差的梯度,降低本地模型的训练偏差,提高图像分类的准确性。

本发明的目的是通过以下技术方案来实现的:一种基于批标准化层参数修正联邦学习的图像分类方法,包括以下步骤:

S1.构建联邦学习场景,包括1个云服务器和N个分布于不同位置的客户端,所述云服务器通过网络分别与每一个客户端连接;

S2.各个客户端持续采集图像样本,将采集到的图像样本表示为RGB矩阵,并利用图像类别作为标签对RGB矩阵进行标记,在每个客户端形成本地数据库;

S3.云服务器构建一个用于图像分类的包含BN层的DNN模型,初始化DNN模型的参数,并设定联邦学习参数;

S4.在任一轮迭代过程中,首先初始化本地模型参数,然后每个客户端利用本地数据库进行本地模型的更新,然后将更新结果上传到服务器,由服务器进行全局模型的更新;

S5.重复执行步骤S4,对全局模型进行R轮迭代训练,得到最终的DNN模型并分发给每一个客户端,客户端根据得到的DNN模型对待识别的图像进行分类。

本发明的有益效果是:本发明通过修正本地模型训练时的批均值与批方差、以及批均值与批方差的梯度,降低本地模型的训练偏差,从而使得联邦学习算法达到好的训练效果,提高图像分类的准确性。

附图说明

图1为本发明的方法流程图;

图2为不同客户端具有相同的本地数据库分布时迭代次数与测试精度的关系示意图;

图3为不同客户端具有不同的本地数据库分布时迭代次数与测试精度的关系示意图。

具体实施方式

下面结合附图进一步详细描述本发明的技术方案,但本发明的保护范围不局限于以下所述。

如图1所示,一种基于批标准化层参数修正联邦学习的图像分类方法,其特征在于:包括以下步骤:

S1.构建联邦学习场景,包括1个云服务器和N个分布于不同位置的客户端,所述云服务器通过网络分别与每一个客户端连接;

S2.各个客户端持续采集图像样本,将采集到的图像样本表示为RGB矩阵,并利用图像类别作为标签对RGB矩阵进行标记,在每个客户端形成本地数据库;

S201.对于任一客户端i,首先在本地持续采集图像样本,然后将将每张新的图像样本表示为一个维度为M×M×3的RGB矩阵ξ

其中,X

例如,对于CIFAR-10数据集,每张图片是一个维度为32×32×3的RGB彩色图片。CIFAR-10数据集一共包含10个类别:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。在此场景下,图像所属的类别为这些物体的名称,即飞机、汽车等。图像标签可以被定义为0至9,数字0对应飞机,数字9对应卡车。

S202.客户端i将新采集的图像样本数据ξ

S203.在i=1,…,N时,对于各个客户端重复执行步骤S201~S202,在每一个客户端中均得到一个本地数据库。

S3.云服务器构建一个用于图像分类的包含BN层的DNN模型,初始化DNN模型的参数,并设定联邦学习参数;在实际执行过程中,所选DNN模型可以为一些成熟的图像分类DNN模型,如ResNet等。

所述步骤S3中,设定的联邦学习参数包括:总体迭代轮数R,客户端在每一轮的本地模型更新次数E,用户i的权重p

初始化DNN模型参数为

其中,BN(batch normalization)是指批标准化:即将DNN中间层输入数据做标准化处理,使得输出服从正态分布,从而避免变量分布偏移的问题。具体来说,记BN层l的输入数据为Y

其中,γ

其中,

S4.在任一轮迭代过程中,首先初始化本地模型参数,然后每个客户端利用本地数据库进行本地模型的更新,然后将更新结果上传到服务器,由服务器进行全局模型的更新;

S401.设在第1轮迭代中,云服务器将初始化的DNN模型参数

S402.每个客户端i(i=1,…,N)将本地DNN模型的参数

S403.每个客户端进行E次本地模型更新。

所述步骤S403包括:

A1、在第1次本地更新时,每个客户端i,逐层对BN层的批均值与批方差进行修正,得到临时统计参数

所述步骤A1包括:

A101:每个客户端i从最新的本地数据库

A102:每个客户端i将

当网络层为非BN层时,直接计算该网络层的输出

否则,进行步骤A103至A112;

A103:对于客户端i,当网络层为BN层时,记

A104:每个客户端i计算

A105:每个客户端i将本地批均值

A106:云服务器计算全局批均值为

A107:每个客户端i将本地批均值

A108:每个客户端i计算

A109:每个客户端i将本地批方差

A110:服务器计算全局批均值为

A111:每个客户端i将本地批方差

A112:每个客户端i将BN层l的输入

其中,γ

A113:对于DNN模型中包含的每一个BN层,重复步骤A103至A112,修正所有BN层的批均值和批方差,完成前向传播过程,得到统计参数

A2、每个客户端i(i=1,…,N)逐层地对BN层的批均值梯度与批方差梯度进行修正,计算梯度参数的梯度

所述步骤A2包括以下子步骤:

A201:客户端i根据DNN模型输出

A202:客户端i从输出层到输入层,逐层计算每个网络层参数的梯度,该过程为反向传播过程;当网络层为非BN层时,计算该网络层梯度参数

否则,进行步骤A203至步骤A209;

A203:对于客户端i,当网络层为BN层时,记

A204:每个客户端i计算本地批方差梯度为

A205:每个客户端i计算本地批均值梯度为

A206:每个客户端i将本地批均值梯度

A207:云服务器计算全局批均值梯度为

A208:每个客户端i将本地批均值梯度

A209:每个客户端i计算尺度参数的梯度为

A210:对于DNN模型中包含的每一个BN层,重复步骤A203至A209,修正所有BN层的批均值梯度和批方差梯度,完成反向传播过程,得到梯度参数的梯度

A3、每个客户端i(i=1,…,N)更新本地统计参数为

A4、在完成第1次本地更新后,每个客户端i(i=1,…,N)接着利用梯度下降方法进行E-1次本地DNN模型更新;

所述步骤A4包括以下子步骤:

A401:在第t步本地更新时,每个客户端i从最新的本地数据库

A402:客户端i将

A403:对于客户端i,当网络层为BN层时,记

A404:每个客户端i计算

A405:每个客户端i将

其中,γ

A406:对于DNN模型中包含的每一个BN层,重复步骤A403至A405,修正所有BN层的批均值和批方差,完成前向传播过程,得到统计参数

A407:客户端i根据DNN模型输出

A408:客户端i从输出层到输入层,逐层计算每个网络层参数的梯度,该过程为反向传播过程;当网络层为非BN层时,计算该网络层梯度参数

A409:对于客户端i,当网络层为BN层时,记

A410:每个客户端i计算本地批方差梯度为

A411:每个客户端i计算尺度参数的梯度为

A412:对于DNN模型中包含的每一个BN层,重复步骤A409至A411,修正所有BN层的批均值梯度和批方差梯度,完成反向传播过程,得到梯度参数的梯度

A413:每个客户端i更新统计参数为

A414:在t=2,…,E时,重复步骤A401至A403,完成对本地DNN模型的E-1步更新。

A5、进行全局模型的更新:

每个客户端i(i=1,…,N)将更新后的本地DNN模型

S5.重复执行步骤S4,对全局模型进行R轮迭代训练,得到最终的DNN模型并分发给每一个客户端,客户端根据得到的DNN模型对待识别的图像进行分类。

所述步骤S5包括以下子步骤:

S501.对全局DNN模型按照步骤S4进行R轮迭代训练,得到最终DNN模型

S502.云服务器将联邦训练所得的DNN模型

S503.对于客户端i(i=1,…,N)中新采集的一张图像,按照步骤S2获得图像样本数据ξ

在本申请的实施例中,通过修正本地DNN模型在训练过程中的批均值与批方差及其梯度,可以有效地提升联邦学习在训练含有批标准化层的DNN模型时,在不同数据分布下都可以取得好的训练性能,设置如下实验参数,进行仿真实验:

训练数据集:CIFAR-10,其中,训练数据集与测试数据都包含了10种类别的数据样本;

训练DNN模型:ResNet-20;

客户端数量N=5;

本地DNN模型更新次数E=5;

学习速率γ:第1次到第6000次迭代,γ=0.5;第6001次到第10000次迭代,γ=0.05;

衰减系数ρ=0.1;

当不同客户端具有相同的本地数据库分布时,每个客户端包含10种类别的数据样本,且每种类别的数据样本所占的比例相同;当不同客户端具有不同的本地数据库分布时,每个客户端只包含2种类别的数据样本;

对于每一种联邦学习方法,进行5次独立的实验,取这5次实验结果的平均值作为最终结果。

得到不同客户端具有相同的本地数据库分布时迭代次数与测试精度的关系如图2所示,得到不同客户端具有不同的本地数据库分布时迭代次数与测试精度的关系如图3所示,从图中可以看出,相比于FedAvg与FedBN算法,所提算法可以在不同数据分布下,都能取得好的训练效果。其中,当不同客户端具有相同的本地数据库分布时,FedAvg和FedBN算法在测试数据集的分类精度(即测试精度)分别可以达到89.42%和88.83%,而当不同客户端具有不同的本地数据库分布时,这两种算法的测试精度分别只能达到36.65%和19.24%。对于该专利所提的图像分类方法,其在相同和不同的本地数据库分布下,测试精度分别为89.32%和86.69%,均取得了较好的图像分类模型训练效果。

上述说明示出并描述了本发明的一个优选实施例,但如前所述,应当理解本发明并非局限于本文所披露的形式,不应看作是对其他实施例的排除,而可用于各种其他组合、修改和环境,并能够在本文所述发明构想范围内,通过上述教导或相关领域的技术或知识进行改动。而本领域人员所进行的改动和变化不脱离本发明的精神和范围,则都应在本发明所附权利要求的保护范围内。

相关技术
  • 基于联邦学习的模型参数获取方法、系统及可读存储介质
  • 在安全联邦学习的逻辑回归模型中进行批标准化的方法
  • 一种基于联邦学习的差分隐私图像分类方法及装置
技术分类

06120115918879