掌桥专利:专业的专利平台
掌桥专利
首页

分类预测模型训练方法、分类预测方法、设备及存储介质

文献发布时间:2023-06-19 19:28:50


分类预测模型训练方法、分类预测方法、设备及存储介质

技术领域

本发明属于隐私计算技术领域,尤其涉及一种基于联邦知识蒸馏算法的分类预测模型训练方法、分类预测方法、电子设备及存储介质。

背景技术

由于移动设备(例如手机、手表、电脑等)的高度发展和传感技术的进步,大量的数据(为用户私有数据,例如个人图片)被边缘端的移动设备收集,如今人工智能发展迅速,这种隐私数据通常会被聚合并且存储在云端,配合机器学习或者是深度学习模型,实现各种智能应用。然而,在将敏感的原始数据通过网络上传到云端,在云端集中处理私有数据对于数据捐献者来说会出现严重的数据隐私泄露问题,基于保护数据隐私安全的推动力,联邦学习的概念应运而生。与集中式学习模式不同,联邦学习支持在使用本地数据的分布式计算节点上对全局模型进行协作学习,不将原始数据发送到云端,只将学习好的全局模型更新提交到云端进行聚合;然后,更新云端上的全局模型,并将其发送回分布式计算节点进行下一轮迭代。通过这种迭代方式,可以在不损害用户隐私的情况下学习全局模型,除了改善数据隐私问题,联邦学习还带来了许多其他好处,比如提高了安全性,自主权和效率等。

随着联邦学习的发展,也出现了许多新的挑战。最主要的挑战来自两个方面:

(1)传统的联邦学习算法在每次迭代时共享模型参数,这意味着通信开销会过大。由于现有的深度学习模型可能会有数百万个参数,例如MobileBRET是一种自然语言处理任务的深度学习模型结构,有2500万个参数,对应96MB的内存大小,而边缘端的移动设备经常会受到带宽限制,每轮通信都需要交换96MB的信息,对于移动设备来说是具有挑战性的,这导致许多移动设备无法参与到需要进行大参数交互的联邦学习任务当中。

(2)异构性问题对于想要在现实场景中部署联邦学习系统造成了巨大挑战。一方面是模型异构问题,大部分参与联邦学习任务的移动设备之间的计算资源与带宽资源都不相同,移动设备没有足够的带宽或者计算能力来训练大型的深度学习模型,这意味着不同的参与者可能需要不同架构的模型进行训练,而基于模型参数交互的联邦学习架构满足不了参与者使用不同架构模型的需求;另一方面是数据异构问题,每个参与联邦学习任务的移动设备的本地数据分布在全局上呈现出非独立同分布的特点,单纯聚合移动设备客户端的模型参数可能会阻碍模型的收敛。

基于上述约束,通过联邦学习任务训练出来的全局模型在现实实践中可能不具备很高的精度。

发明内容

本发明的目的在于提供一种基于联邦知识蒸馏算法的分类预测模型训练方法、分类预测方法、设备及存储介质,以解决传统联邦学习算法通信开销过大,传统联邦学习算法无法满足参与者使用不同架构模型的需求以及数据异构导致模型精度无法提升的问题。

本发明是通过如下的技术方案来解决上述技术问题的:一种基于联邦知识蒸馏算法的分类预测模型训练方法,包括以下步骤:

步骤1:构建由中央服务器端和N个客户端C={C

步骤2:每个所述客户端C

步骤3:每个所述客户端C

每个所述客户端C

步骤4:所有客户端C将各自计算出的原型

步骤5:所述中央服务器端将接收到的每个类所有原型和所有软决策分别进行聚合,得到聚合后的各类原型和聚合后的软决策;利用聚合后的各类原型和聚合后的软决策构建优化目标函数,利用公共数据集D

利用训练后的所述全局分类预测模型计算出未带批注的公共数据集D

步骤6:所述中央服务器端将所述软决策

步骤7:每个所述客户端C

步骤8:判断循环轮次t是否等于设定轮次,如果是,则得到训练好的各本地分类预测模型X

进一步地,所述本地分类预测模型和全局分类预测模型均采用深度残差网络模型。

进一步地,对于所述客户端C

其中,D

进一步地,对类k的所有原型进行聚合的聚合公式为:

其中,N

对所有软决策进行聚合的聚合公式为:

其中,

进一步地,利用聚合后的各类原型和聚合后的软决策构建的优化目标函数的具体表达式为:

其中,(x

基于同一发明构思,本发明还提供一种基于分类预测模型的分类预测方法,所述分类预测模型包括全局分类预测模型和N个本地分类预测模型,所述全局分类预测模型和本地分类预测模型是由上述任一项所述的基于联邦知识蒸馏算法的分类预测模型训练方法训练得到,所述分类预测方法包括以下步骤:

获取待分类数据;

利用所述分类预测模型对所述待分类数据进行分类预测,得到所述待分类数据的类别。

基于同一发明构思,本发明还提供一种电子设备,所述设备包括:

存储器,用于存储计算机程序;

处理器,用于执行所述计算机程序时实现上述任一项所述的基于联邦知识蒸馏算法的分类预测模型训练方法的步骤,或实现上述基于分类预测模型的分类预测方法的步骤。

基于同一发明构思,本发明还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现上述任一项所述的基于联邦知识蒸馏算法的分类预测模型训练方法的步骤,或实现上述基于分类预测模型的分类预测方法的步骤。

有益效果

与现有技术相比,本发明的优点在于:

本发明所提供的一种分类预测模型训练方法、分类预测方法、电子设备及存储介质,客户端的私有数据和本地分类预测模型均存储在客户端本地,保证了私有数据的隐私安全;利用知识蒸馏将基于模型参数交互的传统联邦学习改进为基于模型输出软决策交互,大大地减少了服务器与客户端之间的通信开销,同时允许客户端和服务器端根据自身的带宽资源和计算资源选择合适架构的模型,实现了模型架构的个性化。

同时,本发明还通过原型网络缓解了由于客户端私有数据高度异构化所导致的模型精度难以提高的问题,大大提高了模型精度,使用本发明方法的联邦学习框架具有稳定性与高效性。

附图说明

为了更清楚地说明本发明的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一个实施例,对于本领域普通技术人员来说,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。

图1是本发明实施例中知识蒸馏流程图;

图2是本发明实施例中联邦知识蒸馏算法流程图;

图3是本发明实施例中基于联邦知识蒸馏算法的分类预测模型训练方法流程图。

具体实施方式

下面结合本发明实施例中的附图,对本发明中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动的前提下所获得的所有其他实施例,都属于本发明保护的范围。

下面以具体地实施例对本申请的技术方案进行详细说明。下面这几个具体的实施例可以相互结合,对于相同或相似的概念或过程可能在某些实施例不再赘述。

为了解决基于模型参数交互的传统联邦学习算法通信开销大的问题以及传统联邦学习算法无法满足参与者使用不同架构模型的需求以及数据异构导致模型精度无法提升的问题,本发明将知识蒸馏和原型网络运用到联邦学习中。

知识蒸馏是机器学习算法中用于模型压缩的一种方法,知识蒸馏的目的是利用预先训练好的教师模型,将教师模型的知识注入到一个未训练的学生模型。知识蒸馏不同于标准模型训练,标准模型训练试图让模型的预测输出去匹配每个样本的真实标签值(例如[猫,狗]=[0,1]),但知识蒸馏是对于同一个样本,试图让学生模型的预测输出去匹配教师模型的预测输出,也就是logits(软决策),例如[猫,狗]=[0.3,0.7],该logits(软决策)比真实标签值包含了更多的信息,相比标准模型,可以更快的训练学生模型。

考虑一个简单的神经网络,其包括输入层、隐藏层与输出层,设F

其中,(x

基于上述所描述的知识蒸馏,将其运用在联邦学习中,此时联邦学习框架中服务器与客户端交换的不再是模型参数,而是模型的软决策,而此时全局也部署了一个公共数据集来进行模型知识的传递,但是由于联邦学习对数据隐私保护的要求,公共数据集是无标签的数据。联邦知识蒸馏算法的基本流程如下:

(1)各客户端在本地利用私有数据训练模型,在每轮通信时将客户端模型对公共数据集的logtis(软决策)上传到服务器端;

(2)服务器端接收到所有参与联邦学习的客户端上传的logits(软决策)后,将其进行聚合得到最终的logits(软决策),此时服务器端模型为作为学生模型,所有客户端模型作为教师模型来进行知识蒸馏训练;

(3)知识蒸馏训练结束后,服务器端将服务器模型(学生模型)对公共数据集的logits(软决策)发送到各个客户端上;

(4)客户端接收到来自服务器端的logits(软决策)后,此时客户端模型为学生模型,服务器端模型为教师模型,再次进行知识蒸馏训练,训练完后回到步骤(1)开启下一轮联邦学习训练直至全局收敛。

由于公共数据集是无标签的数据,所以服务器模型在进行知识蒸馏训练时的优化目标变为:

其中,x

将知识蒸馏运用到联邦学习中,模型之间交互的不再是模型参数而是模型输出logits(软决策),这大大减少了通信开销,并且在使用知识蒸馏之后,服务器与客户端,包括客户端之间都不需要使用相同架构的模型,这也大大增加了联邦学习框架的个性化能力,各个设备(客户端和服务器端)可以根据自身情况选择不同架构的模型。

但是,在联邦学习框架中使用知识蒸馏时,通常情况下客户端之间的私有数据分布会呈现non-iid(非独立同分布),这会导致模型训练收敛慢,模型精度难以提高。究其原因,是因为客户端的私有数据具有高度异构性时,客户端模型的输出logits(软决策)会有over-confident(过自信)的现象,这就导致服务器端聚合后的logits(软决策)包含了错误的知识,用此logits(软决策)训练服务器模型必然会导致模型训练收敛慢,模型精度较低的问题。

示例性的:有两个客户端在参与联邦学习,客户端A的私有数据中包含了狗、猫、飞机的图像数据;另一个客户端B的私有数据中包含了猫、青蛙、飞机的图像数据。客户端A和客户端B各自使用私有数据进行模型训练后,公共数据集中出现了一张狗的图像数据,客户端A的私有数据集中有狗的图像数据,所以客户端A对此图像数据的输出logits(软决策)的分布是会正确的倾向于狗这个类的;但是由于客户端B的私有数据集中没有狗的图像数据,但是它有猫的私有数据,很有可能客户端B模型就会把这张狗的图像数据错误的预测为猫,那么客户端B模型的输出logits(软决策)的分布就会错误的倾向于猫这个类。将客户端A和客户端B的模型输出logits(软决策)聚合,聚合后得到的logits(软决策)的分布可能会在猫与狗这两个类上比较均匀,这就导致服务器模型训练收敛慢,模型精度无法提升。

基于上述问题,本发明提出使用原型网络来缓解这样的现象。原型网络是指某一类数据在特征空间上的嵌入式向量的平均表示,它是一类数据的抽象特征表示。考虑一个简单的神经网络,其包括输入层、隐藏层与输出层,设F

其中,D

原型网络能够从特征空间层面区别出每个类的不同之处,将其运用在联邦学习当中,服务器端通过聚合客户端基于私有数据的类原型网络,尽管客户端由于数据分布的异构性导致其原型网络略有不同,但是服务器端可以获得更多关于类的抽象特征表示,并通过学习这些抽象特征表示来缓解由于logits(软决策)的over-confident(过自信)导致的模型训练问题。

因此,本发明实施例所提供的一种基于联邦知识蒸馏算法的分类预测模型训练方法,包括以下步骤:

步骤1:构建由中央服务器端和N个客户端C={C

本实施例中,N=100,客户端为手表、手机、电脑或IPAD等。

步骤2:每个所述客户端C

本地训练数据集D

本实施例中,本地分类预测模型X

步骤3:每个所述客户端C

在每个客户端C

其中,D

每个所述客户端C

将公共数据集D

本实施例中,公共数据集D

将公共数据集D

步骤4:所有客户端C将各自计算出的原型

步骤5:所述中央服务器端将接收到的每个类所有原型和所有软决策分别进行聚合,得到聚合后的各类原型和聚合后的软决策;利用聚合后的各类原型和聚合后的软决策构建优化目标函数,利用公共数据集D

聚合包括平均聚合和加权聚合,本实施例采用平均聚合,对类k的所有原型进行聚合的聚合公式为:

其中,N

对所有软决策进行聚合的聚合公式为:

其中,

利用聚合后的各类原型和聚合后的软决策构建的优化目标函数的具体表达式为:

其中,(x

本实施例中,全局分类预测模型采用深度残差网络模型,例如ResNet56模型。

利用训练后的所述全局分类预测模型计算出未带批注的公共数据集D

步骤6:所述中央服务器端将所述软决策

步骤7:每个所述客户端C

步骤8:判断循环轮次t是否等于设定轮次,如果是,则得到训练好的各本地分类预测模型X

基于同一发明构思,本发明还提供一种基于分类预测模型的分类预测方法,所述分类预测模型包括全局分类预测模型和N个本地分类预测模型,所述全局分类预测模型和本地分类预测模型是由上述基于联邦知识蒸馏算法的分类预测模型训练方法训练得到,所述分类预测方法包括以下步骤:

步骤1:获取待分类数据;

步骤2:利用所述分类预测模型对所述待分类数据进行分类预测,得到所述待分类数据的类别。

本发明中,客户端的私有数据和本地分类预测模型均存储在客户端本地,保证了私有数据的隐私安全;利用知识蒸馏将基于模型参数交互的传统联邦学习改进为基于模型输出软决策交互,大大地减少了服务器与客户端之间的通信开销,同时允许客户端和服务器端根据自身的带宽资源和计算资源选择合适架构的模型,实现了模型架构的个性化。

同时,本发明还通过原型网络缓解了由于客户端私有数据高度异构化所导致的模型精度难以提高的问题,大大提高了模型精度,使用本发明方法的联邦学习框架具有稳定性与高效性。

以上所揭露的仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到变化或变型,都应涵盖在本发明的保护范围之内。

相关技术
  • MR图像预测模型的训练方法、装置、设备及存储介质
  • CT图像预测模型的训练方法、装置、设备及存储介质
  • 相似度预测模型训练方法、设备及计算机可读存储介质
  • 分类器训练方法、装置、设备和计算机可读存储介质
  • 文本分类模型的训练方法、装置及可读存储介质
  • 分类预测模型的训练方法、分类预测方法、装置和设备
  • 脉象分类方法、分类模型训练方法、分类设备和存储介质
技术分类

06120115925206