掌桥专利:专业的专利平台
掌桥专利
首页

联邦学习模型训练方法、客户端、服务器及存储介质

文献发布时间:2023-06-19 12:19:35


联邦学习模型训练方法、客户端、服务器及存储介质

技术领域

本发明涉及人工智能技术领域,具体涉及一种基于知识蒸馏 的联邦学习模型训练方法、客户端、服务器及计算机可读存储介质。

背景技术

目前多家单位若合作利用人工智能算法在某一业务场景中 进行落地,会遇到一些问题,例如,由于数据安全与数据隐私要求,各 家单位的数据不能在各单位之间进行有效流通和使用,从而造成数据孤 岛问题。传统的算法训练框架强调数据的多样性和完整性,从而进一步 放大数据孤岛问题给算法能力带来的影响。因此传统的算法训练框架和数据孤岛问题会使得人工智能算法能力陷入瓶颈,并进一步限制算法在 实际应用场景中的使用和落地。

因此,本领域仍然需要一种新的方法来解决由于数据孤岛无 法提高算法能力,限制算法落地应用的问题。

发明内容

为了解决现有技术中的上述问题,即,为了解决现有方案由 于数据孤岛无法提高算法能力,限制算法落地应用的问题,一方面,一 种基于知识蒸馏的联邦学习模型训练方法,包括:接收服务器获取的用 于模型训练的控制参数;根据所述控制参数以及本地数据样本对初始的 第一神经网络模型进行训练,得到第一模型参数;将所述第一模型参数 发送至所述服务器;接收所述服务器获取的第二神经网络模型的第二模 型参数;利用知识蒸馏方法使所述第一神经网络模型学习到所述第二神 经网络模型的知识,训练得到更新的第一神经网络模型,其中,所述第 一神经网络模型是学生网络模型。

在上述联邦学习模型训练方法的优选实施方式中,所述控 制参数至少包括训练次数n,n大于等于2,还可以包括:将所述更新的 第一神经网络模型的第一模型参数发送至所述服务器;接收所述服务 器获取的更新后的第二模型参数;根据所述更新后的第二模型参数得 到更新后的第二神经网络模型;利用知识蒸馏方法使所述更新后的第 二神经网络模型学习到所述更新的第一神经网络模型的知识,训练得 到二次训练后的第二神经网络模型,其中,所述更新后的第二神经网 络模型是学生网络模型;将所述二次训练后的第二神经网络模型的第 二模型参数发送至所述服务器;如此循环,直至所述n次训练结束。

在上述联邦学习模型训练方法的优选实施方式中,所述第 一模型参数和所述第二模型参数至少包括神经网络的权重参数。

在上述联邦学习模型训练方法的优选实施方式中,还可以 包括:所述知识蒸馏方法所采用的损失函数包括以下任意一种:均方 误差损失函数、平均绝对误差损失函数。

根据本发明的另一方面,还提供了一种联邦学习模型训练 方法,包括:向第一客户端和第二客户端分别发送用于模型训练的控 制参数;接收来自第一客户端的经过更新的第一神经网络模型的第一 模型参数;接收来自第二客户端的经过更新的第二神经网络模型的第 二模型参数;将所述第一模型参数发送至所述第二客户端,将所述第 二模型参数发送至所述第一客户端;保存所述第一模型参数和所述第 二模型参数。

在上述联邦学习模型训练方法的优选实施方式中,包括: 所述控制参数至少包括训练次数n,n大于等于2;将更新后的第一模型 参数发送至所述第二客户端;将更新后的第二模型参数发送至所述第 一客户端;接收并保存来自所述第二客户端的经过二次训练的第一模 型参数;接收并保存来自所述第一客户端的经过二次训练的第二模型 参数;如此循环,直至所述n次训练结束。

在上述联邦学习模型训练方法的优选实施方式中,在保存 第一模型参数和第二模型参数时,利用指标评估方法选取一个或多个 第一模型参数和第二模型参数进行保存更新;将选取保存的模型参数 发送至对应客户端。

根据本发明的再一方面,还提供了一种基于知识蒸馏的联 邦学习模型训练客户端,包括:通讯模块,接收来自服务器的用于模 型训练的控制参数,以及接收来自所述服务器的第二神经网络模型的 第二模型参数;算法训练模块,与所述通讯模块连接,根据所述控制 参数以及本地数据样本对初始的第一神经网络模型进行训练,得到第 一模型参数,根据所述第二模型参数得到所述第二神经网络模型,以 及利用知识蒸馏方法使所述第一神经网络模型学习到所述第二神经网 络模型的知识,训练得到更新的第一神经网络模型,其中,所述第一 神经网络模型是学生网络模型。

在上述客户端的优选实施方式中,所述控制参数至少包括 训练次数n,n大于等于2,所述通讯模块还将得到的第一神经网络模型 的模型参数发送至所述服务器进行更新,以及接收来自所述服务器的 更新后的第二模型参数,以及将所二次训练后的第二神经网络模型的 第二模型参数发送至所述服务器;所述算法训练模块还根据所述更新 后的第二模型参数得到更新后的第二神经网络模型,以及利用知识蒸 馏方法使所述更新后的第二神经网络模型学习到所述更新的第一神经 网络模型的知识,训练得到二次训练后的第二神经网络模型,其中, 所述更新后的第二神经网络模型是学生网络模型,如此循环,直至所述n次训练结束。

根据本发明的又一方面,还提供了一种服务器,包括:训 练控制模块,生成用于模型训练的控制参数;通讯模块,向第一客户 端和第二客户端分别发送所述控制参数,接收来自第一客户端的经过 更新的第一神经网络模型的第一模型参数,以及接收来自第二客户端 的经过更新的第二神经网络模型的第二模型参数;参数更新模块,保 存所述第一模型参数和所述第二模型参数。

在上述服务器的优选实施方式中,包括:所述控制参数至 少包括训练次数n,n大于等于2;所述通讯模块还用于将更新后的第一 模型参数发送至所述第二客户端,将更新后的第二模型参数发送至所 述第一客户端,接收并保存来自所述第二客户端的经过二次训练的第 一模型参数,以及接收来自所述第一客户端的经过二次训练的第二模 型参数,如此循环,直至所述n次训练结束;所述参数更新模块保存所 述经过二次训练的第二模型参数。

在上述服务器的优选实施方式中,还可以包括:模型优选 模块,与所述参数更新模块和通讯模块连接,在保存第一模型参数和 第二模型参数时,利用指标评估方法选取一个或多个第一模型参数和 第二模型参数进行保存更新;所述通讯模块将选取保存的模型参数发 送至对应客户端。

本发明进一步还提供了一种基于知识蒸馏的联邦学习模 型训练系统,包括多个如上述任一技术方案中所述的基于知识蒸馏的 联邦学习模型训练客户端和如上述任一技术方案中所述的服务器。

本发明进一步还提供了一种计算机可读存储介质,所述存 储介质中存储有多条程序代码,所述程序代码适用于由处理器加载并 运行以执行如上述任一技术方案中所述的基于知识蒸馏的联邦学习模 型训练方法和上述任一技术方案中所述的联邦学习模型训练方法。

本发明将模型训练设置在本地,并通过中心服务器完成模 型参数交互,解决了数据孤岛和数据隐私的问题,可在数据不离本地 的情况下完成算法模型的训练及优化。能够支持双模型之间的相互蒸 馏,充分利用全部数据的知识,提高算法模型在联邦框架下的训练效 果,同时一次训练过程可以完成两个神经网络模型的训练,打破数据 孤岛和传统训练框架对算法能力造成的瓶颈。

附图说明

下面结合附图来描述本发明的优选实施方式,附图中:

图1为根据本发明实施例的基于知识蒸馏的联邦学习模型 训练方法的流程图;

图2为根据本发明实施例的联邦学习模型训练方法的流程 图;

图3为根据本发明一个实施例的基于知识蒸馏的联邦学习 模型训练的结构示意图。

具体实施方式

为了便于理解本发明,下文将结合说明书附图和实施例对 本发明作更全面、细致的描述,但本领域技术人员应当理解的是,这 些实施方式仅仅用于解释本发明的技术原理,并非旨在限制本发明的 保护范围。

在本发明的描述中,“模块”、“处理器”可以包括硬件、软 件或者两者的组合。一个模块可以包括硬件电路、各种合适的感应器、 通信端口、存储器,也可以包括软件部分,比如程序代码,也可以是 软件和硬件的组合。处理器可以是中央处理器、微处理器、图像处理器、数字信号处理器或者其他任何合适的处理器。处理器具有数据和/ 或信号处理功能。处理器可以以软件方式实现、硬件方式实现或者二 者结合的方式实现。非暂时性的计算机可读存储介质包括任何合适的 可存储程序代码的介质,比如磁碟、硬盘、光碟、闪存、只读存储器、 随机存取存储器等等。术语“A和/或B”表示所有可能的A与B的组合, 比如只是A、只是B或者A和B。术语“至少一个A或B”或者“A和B 中的至少一个”含义与“A和/或B”类似,可以包括只是A、只是B或者 A和B。单数形式的术语“一个”、“这个”也可以包含复数形式。

首先参阅图1,在客户端侧,根据本发明实施例的一种基 于知识蒸馏的联邦学习模型训练方法,包括:

S1,接收来自服务器的用于模型训练的控制参数。该控制 参数可以包括学习率、训练次数等模型训练需要的参数,不仅限于此。

S2,根据控制参数以及本地数据样本对初始的第一神经网 络模型进行训练,得到第一模型参数。每个客户端可以部署在各场景 的本地,直接利用本地的数据样本进行训练,这样就无需将本地数据 传输至外部,保护数据隐私。

S3,将第一模型参数发送至服务器。训练完成后,将模型 参数发送给服务器进行保存更新,以便服务器将第一模型参数发送给 其他客户端。

S4,接收来自服务器的第二神经网络模型的第二模型参 数。除了将自己的模型参数通过服务器发送给其他客户端之外,还接 收其他客户端的第二模型参数,便于后面的知识蒸馏学习。

S5,利用知识蒸馏方法使第一神经网络模型学习到第二神 经网络模型的知识,训练得到更新的第一神经网络模型,其中,第一 神经网络模型是学生网络模型。通过双模型之间的相互蒸馏,就可以 充分利用各客户端的全部数据的知识,提高算法模型在联邦框架下的 训练效果。

在上述联邦学习模型训练方法的优选实施方式中,完成了 一次训练,为了进一步提升训练效果,控制参数还可以包括训练次数n, n大于等于2,即按照同样的思路训练多次,直到符合预设效果。将更 新的第一神经网络模型的第一模型参数发送至服务器;接收来自服务 器的更新后的第二模型参数;根据更新后的第二模型参数得到更新后 的第二神经网络模型;利用知识蒸馏方法使更新后的第二神经网络模 型学习到更新的第一神经网络模型的知识,训练得到二次训练后的第 二神经网络模型,其中,更新后的第二神经网络模型是学生网络模型; 将二次训练后的第二神经网络模型的第二模型参数发送至服务器;如 此循环,直至n次训练结束。在第二次训练时,利用第一次训练更新 后的双模型,并且将学习网络和教师网络进行替换,进一步充分学习 了两个客户端数据的知识,同时汇聚了两个模型的优点。以此类推, 进行多次训练之后,可以获得更好的算法模型。

需要说明的是,第一神经网络模型和第二神经网络模型可 以是相同的模型,也可以不相同的模型。第一模型参数和第二模型参 数至少包括神经网络的权重参数。

在上述联邦学习模型训练方法的优选实施方式中,还可以 包括:知识蒸馏方法所采用的损失函数包括以下任意一种:均方误差 损失函数、平均绝对误差损失函数。

传统的联邦学习技术框架会由于参数更新方式、训练策略 以及数据孤立等问题,造成算法模型的训练效果以及性能要差于传统 的训练框架在全部数据上进行训练的效果。而通过上述实施方案,能 够解决数据孤岛和数据隐私的问题,可在数据不离本地的情况下完成 算法模型的训练及优化。并且能够支持双模型之间的相互蒸馏,充分 利用全部数据的知识,提高算法模型在联邦框架下的训练效果,同时 一次训练过程可以完成两个神经网络模型的训练。打破数据孤岛和传 统训练框架对算法能力造成的瓶颈。

下面结合图2和图3,详细说明根据本发明的另一实施例。

步骤21,server端(即服务器,可以是任意的节点)进行 训练初始化。在server配置神经网络结构模型、启动参数及训练参数, 进行训练初始化。将训练相关参数发送至各client端(即客户端)。在 实际应用时,客户端可以是银行系统的客户端,社保系统的客户端。 由于银行和社保系统都是需要高度安全和隐私,因此两端的数据无法 实现互通,通过本发明可以解决该问题。

本领域人员应该理解,这里的神经网络模型包括但不限于 YOLOv3,YOLOv4。

步骤22,client端启动并开始训练。Client端接收server 相关训练控制参数并进行训练启动,开始模型的训练。控制参数中可 以包括训练次数,例如epoch值。一个epoch就是使用训练集中的全部 样本训练一次。通俗的讲,Epoch的值就是整个训练数据集被反复使用 几次。Epoch数是一个超参数,它定义了学习算法在整个训练数据集中 的工作次数。

步骤S23,每个clinet端完成1个epoch训练后,将训练 完成的模型参数返回至server端。如图3所示,在客户端1中,利用 本地样本数据data1对model1进行训练,得到model1的模型参数,把 该模型参数发送至服务器进行更新。同理,在客户端2中,利用本地样本数据data2对model2进行训练,得到model2的模型参数,把该模 型参数发送至服务器进行更新。

步骤24,server端更新从client端获取的模型参数,并将 更新后的模型参数交换分发至对应的client端。服务器将model2的模 型参数发送给客客户端1,将model1的模型参数发送给客户端2。

步骤25,client端双模型相互蒸馏训练并返回模型参数至 server端。

在epoch 2训练阶段,每个client端都会存在一个学生网 络模型,并通过另一个模型(教师网络模型)进行知识蒸馏。每个客 户端在完成1个epoch的训练后将学生网络模型返回至server端。如图 3中所示,在epoch 2训练阶段,在客户端1中model1是学生网络模型, model2教师网络模型。将训练更新的model1的模型参数反馈至服务器。 同理,在客户端2中,model1是教师网络模型,model2是学生网络模 型。经过知识蒸馏训练后,将model2反馈至服务器。

在知识蒸馏过程中,学生网络通过优化的损失函数进行训 练。优化的损失函数为:

loss=loss

这里loss

loss

其中M

步骤26,server端更新参数,并进行模型选优保存。

server端在保存参数时,利用指标评估方法选取一个或多 个第一模型参数和第二模型参数进行保存更新。在评估时可以采用 mAP(mean Average Precision,不同召回率上的正确率的平均值),loss 等评估指标。

步骤27,server端发送模型至client端。

将server端将model2发送至客户端1,将model1发送给 客户端2。在每个客户端中,交换学生网络模型和教师网络模型,进行 知识蒸馏训练,例如在客户端1中,以model2为学生网络,model1为 教师网络,在客户端2中,以model1为学生网络,model2为教师网络。然后重复25,26步骤直至达到设计的训练epoch数据。通过交换的方 式,可以同时学习两个模型的精华知识。

在服务器侧,根据本发明的实施例的联邦学习模型训练方 法,包括:向第一客户端和第二客户端分别发送用于模型训练的控制 参数;接收来自第一客户端的经过更新的第一神经网络模型的第一模 型参数;接收来自第二客户端的经过更新的第二神经网络模型的第二 模型参数;将第一模型参数发送至第二客户端,将第二模型参数发送 至第一客户端;保存第一模型参数和第二模型参数。

服务器可以用于管理和交换多个客户端之间的模型参数, 在不传输数据的情况下,完成具备多端知识的模型训练,打破数据孤 岛的问题。

在上述联邦学习模型训练方法的优选实施方式中,控制参 数至少包括训练次数n,n大于等于2;将更新后的第一模型参数发送 至第二客户端;将更新后的第二模型参数发送至第一客户端;接收并 保存来自第二客户端的经过二次训练的第一模型参数;接收并保存来 自第一客户端的经过二次训练的第二模型参数;如此循环,直至n次 训练结束。

在上述联邦学习模型训练方法的优选实施方式中,在保存 第一模型参数和第二模型参数时,利用指标评估方法选取一个或多个 第一模型参数和第二模型参数进行保存更新;将选取保存的模型参数 发送至对应客户端。

继续参考图3,根据本发明的实施例的基于知识蒸馏的联 邦学习模型训练客户端31或者32,包括:通讯模块33,接收来自服 务器的用于模型训练的控制参数,以及接收来自服务器的第二神经网 络模型的第二模型参数;算法训练模块32,与通讯模块33连接,根据控制参数以及本地数据样本对初始的第一神经网络模型进行训练,得 到第一模型参数,根据第二模型参数得到第二神经网络模型,以及利 用知识蒸馏方法使第一神经网络模型学习到第二神经网络模型的知 识,训练得到更新的第一神经网络模型,其中,第一神经网络模型是 学生网络模型。

在上述优选实施方式中,控制参数至少包括训练次数n,n 大于等于2,通讯模块33还将得到的第一神经网络模型的模型参数发 送至服务器进行更新,以及接收来自服务器的更新后的第二模型参数, 以及将所二次训练后的第二神经网络模型的第二模型参数发送至服务 器;算法训练模块32还根据更新后的第二模型参数得到更新后的第二 神经网络模型,以及利用知识蒸馏方法使更新后的第二神经网络模型 学习到更新的第一神经网络模型的知识,训练得到二次训练后的第二 神经网络模型,其中,更新后的第二神经网络模型是学生网络模型, 如此循环,直至n次训练结束。图3中的数据加载模块可用于在训练 模型时加载样本数据data1。

继续参考图3,根据本发明的实施例的服务器(server端) 300,可以包括:训练控制模块36,生成用于模型训练的控制参数;通 讯模块39,向第一客户端和第二客户端分别发送控制参数,接收来自 第一客户端的经过更新的第一神经网络模型的第一模型参数,以及接 收来自第二客户端的经过更新的第二神经网络模型的第二模型参数; 参数更新模块37,保存所述第一模型参数和所述第二模型参数。日志 管理模块用于保存运行日志。

在上述优选实施方式中,包括:所述控制参数至少包括训 练次数n,n大于等于2;所述通讯模块还用于将更新后的第一模型参 数发送至所述第二客户端,将更新后的第二模型参数发送至所述第一 客户端,接收并保存来自所述第二客户端的经过二次训练的第一模型 参数,以及接收来自所述第一客户端的经过二次训练的第二模型参数, 如此循环,直至所述n次训练结束;所述参数更新模块保存所述经过 二次训练的第二模型参数。

在上述优选实施方式中,还可以包括:模型优选模块38, 与所述参数更新模块37和通讯模块39连接,在保存第一模型参数和 第二模型参数时,利用指标评估方法选取一个或多个第一模型参数和 第二模型参数进行保存更新;所述通讯模块将选取保存的模型参数发 送至对应客户端。

如图3所示,是根据本发明的实施例的基于知识蒸馏的联 邦学习模型训练系统,包括多个如上述任一技术方案中所述的基于知 识蒸馏的联邦学习模型训练客户端(客户端31、客户端32)和如上述 任一技术方案中所述的服务器300。

本发明进一步还提供了一种计算机可读存储介质,所述存 储介质中存储有多条程序代码,所述程序代码适用于由处理器加载并 运行以执行基于知识蒸馏的联邦学习模型训练方法和联邦学习模型训 练方法。

本发明将模型训练设置在本地,并通过中心服务器完成模 型参数交互,解决了数据孤岛和数据隐私的问题,可在数据不离本地 的情况下完成算法模型的训练及优化。能够支持双模型之间的相互蒸 馏,充分利用全部数据的知识,提高算法模型在联邦框架下的训练效 果,同时一次训练过程可以完成两个神经网络模型的训练,打破数据 孤岛和传统训练框架对算法能力造成的瓶颈。

至此,已经结合附图所示的一个实施方式描述了本发明的 技术方案,但是,本领域技术人员容易理解的是,本发明的保护范围 显然不局限于这些具体实施方式。在不偏离本发明的原理的前提下, 本领域技术人员可以对相关技术特征作出等同的更改或替换,这些更 改或替换之后的技术方案都将落入本发明的保护范围之内。

相关技术
  • 联邦学习模型训练方法、客户端、服务器及存储介质
  • 纵向联邦学习模型的训练方法、装置、设备及存储介质
技术分类

06120113254770