掌桥专利:专业的专利平台
掌桥专利
首页

基于联邦学习的模型训练方法、装置及联邦学习系统

文献发布时间:2023-06-19 11:08:20


基于联邦学习的模型训练方法、装置及联邦学习系统

技术领域

本发明涉及人工智能技术领域,具体而言,涉及一种基于联邦学习的模型训练方法、装置及一种联邦学习系统。

背景技术

随着居民财富积累的高速增长,从居民财富的结构角度,金融资产占未来居民财富的比重和需求会越来越高。金融资产包括库存现金、银行存款、债权投资、股权投资、基金投资、衍生金融资产等。如何根据个人风险偏好来合理的进行金融资产配置就显得尤为重要。而利用机器算法来进行个人风险偏好预测主要面临着两大重要挑战:其一,数据大都以孤岛的形式存在;其二,数据隐私和安全。例如,由于各种因素的影响,不同卡组织和支付机构所拥有的大量用户画像数据是非共享的。为了保护数据隐私以及解决数据孤岛问题,现有技术提出了采用联邦学习进行模型训练。联邦学习是一种新兴的人工智能技术,相较于传统机器学习方法需要将训练数据集中到一台机器或者一个数据中心里,联邦学习利用分散的成千上万不同卡组织和支付机构的数据集协同训练机器学习模型,而所有的训练数据仍保留在各自的机构手中,从而保护了用户的隐私。

但传统的联邦学习算法在通讯过程中仅进行本地模型参数的传递和聚合,无法实现多元数据的融合,导致训练出的模型的预测准确性不够理想。现有技术缺少一种对现有联邦学习进行改进的方法,以提高模型预测的准确性。

发明内容

本发明为了解决现有的联邦学习无法实现多元数据的融合导致训练出的模型的预测准确性不够理想的技术问题,提出了一种基于联邦学习的模型训练方法及装置。

为了实现上述目的,根据本发明的一个方面,提供了一种基于联邦学习的模型训练方法,该方法包括:

在联邦学习的第i次迭代中,根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,根据训练出的模型提取本地数据集的数据特征,将所述数据特征以及所述预测精度发送到中央服务器,以使所述中央服务器根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集;

在联邦学习的第R次迭代中,从所述中央服务器获取所述全局聚合特征,并根据所述全局聚合特征以及所述本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到最终训练出的模型。

为了实现上述目的,根据本发明的一个方面,提供了另一种基于联邦学习的模型训练方法,该方法包括:

在联邦学习的第i次迭代中,接收各客户端发送的数据特征以及预测精度,其中,每个所述客户端在第i次迭代中根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,以及根据训练出的模型提取本地数据集的数据特征,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集;

根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数;

在联邦学习的第R次迭代中,将所述全局聚合特征传输到各客户端,以使各客户端根据所述全局聚合特征以及本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到每个客户端各自对应的最终训练出的模型。

为了实现上述目的,根据本发明的另一方面,提供了一种联邦学习系统,该系统包括:中央服务器以及多个客户端;

每个所述客户端,用于在联邦学习的第i次迭代中,根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,根据训练出的模型提取本地数据集的数据特征,将所述数据特征以及所述预测精度发送到所述中央服务器,其中,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集;

所述中央服务器,用于根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数;

每个所述客户端,还用于在联邦学习的第R次迭代中,从所述中央服务器获取所述全局聚合特征,并根据所述全局聚合特征以及所述本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到最终训练出的模型。

为了实现上述目的,根据本发明的另一方面,提供了一种基于联邦学习的模型训练装置,该装置包括:

第一模型训练处理模块,用于在联邦学习的第i次迭代中,根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,根据训练出的模型提取本地数据集的数据特征,将所述数据特征以及所述预测精度发送到中央服务器,以使所述中央服务器根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集;

第二模型训练处理模块,用于在联邦学习的第R次迭代中,从所述中央服务器获取所述全局聚合特征,并根据所述全局聚合特征以及所述本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到最终训练出的模型。

为了实现上述目的,根据本发明的另一方面,提供了另一种基于联邦学习的模型训练装置,该装置包括:

数据接收模块,用于在联邦学习的第i次迭代中,接收各客户端发送的数据特征以及预测精度,其中,每个所述客户端在第i次迭代中根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,以及根据训练出的模型提取本地数据集的数据特征,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集;

特征聚合模块,用于根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数;

全局聚合特征传输模块,用于在联邦学习的第R次迭代中,将所述全局聚合特征传输到各客户端,以使各客户端根据所述全局聚合特征以及本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到每个客户端各自对应的最终训练出的模型。

为了实现上述目的,根据本发明的另一方面,还提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述基于联邦学习的模型训练方法中的步骤。

为了实现上述目的,根据本发明的另一方面,还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序在计算机处理器中执行时实现上述基于联邦学习的模型训练方法中的步骤。

本发明的有益效果为:本发明实施例将联邦学习前R-1次迭代中每个客户端发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,并在第R次迭代时根据所述全局聚合特征以及本地训练数据集中各训练数据的标签进行模型训练,得到最终训练出的模型,实现了对联邦学习中多元数据的融合,解决了现有的联邦学习由于无法实现多元数据的融合导致训练出的模型的预测准确性不够理想的技术问题。

附图说明

为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。在附图中:

图1是本发明实施例基于联邦学习的模型训练方法的第一流程图;

图2是本发明实施例基于联邦学习的模型训练方法的第二流程图;

图3是本发明实施例提取数据特征的示意图;

图4是本发明实施例联邦学习系统第一示意图;

图5是本发明实施例ESN网络结构示意图;

图6是本发明实施例联邦学习系统第二示意图;

图7是本发明实施例基于联邦学习的模型训练装置的第一结构框图;

图8是本发明实施例基于联邦学习的模型训练装置的第二结构框图;

图9是本发明实施例计算机设备示意图。

具体实施方式

为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。

本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。

需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。

需要说明的是,在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本发明。

本发明对传统联邦学习进行改进,提供一种基于改进联邦学习的模型训练方法,本发明方法旨在解决传统联邦学习算法在通讯过程中仅进行本地模型参数的传递,无法实现多元数据融合且通讯代价过高的问题。本发明在考虑隐私数据保护的前提下,通过关键参数特征提取及融合算法来进行数据特征融合,减小通讯代价,提升预测的精度以及可行性。

本发明一个方面提出了一种联邦学习系统,如图4和图6所示,本发明的联邦学习系统包含中央服务器以及K个客户端,K为大于1的整数。

每个所述客户端,用于在联邦学习的第i次迭代中,根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,根据训练出的模型提取本地数据集的数据特征,将所述数据特征以及所述预测精度发送到所述中央服务器,其中,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集。

在本发明一个实施例中,本发明每个所述客户端训练出的模型均包含特征提取层,本发明将本地数据集输入到训练出的模型中,数据在特征提取层中的输出即为本地数据集对应的数据特征。图3是本发明一个实施例提取数据特征的示意图,如图3所示,第一神经网络采用LSTM网络,将本地训练数据集以及本地测试数据集放入训练好的LSTM网络中,并将数据在特征提取层(LSTM层)的输出看作整个数据集的数据特征。

在本发明一个实施例中,第一神经网络可以采用现有技术任意一种神经网络。每个所述客户端在前R-1次迭代中均采用同一种神经网络进行模型训练

在本发明一个实施例中,每个所述客户端采用局部随机梯度下降法同步进行模型训练。

所述中央服务器,用于根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数。

每个所述客户端,还用于在联邦学习的第R次迭代中,从所述中央服务器获取所述全局聚合特征,并根据所述全局聚合特征以及所述本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到最终训练出的模型。

本发明的模型训练方法可以应用于多种场景,对多种数据进行预测。在本发明一个具体实施例中,所述最终训练出的模型用于预测客户的风险偏好分类;所述本地训练数据集和所述本地验证数据集均包含多个客户数据,每个所述客户数据包含:多个客户特征以及客户风险偏好分类标签。

在本发明一个实施例中,每个所述客户端,还用于在联邦学习的第i次迭代中,将训练出的模型的模型参数发送到所述中央服务器,以使所述中央服务器对在本次迭代中每个客户端发送的模型参数进行聚合得到聚合后的模型参数;所述中央服务器,还用于在联邦学习的第i次迭代中,将所述聚合后的模型参数传输到每个所述客户端,以使每个所述客户端将所述聚合后的模型参数作为第i+1次迭代时进行模型训练的初始参数。

传统的联邦学习的流程包含若干个通讯轮次,即若干次迭代,本发明的每个通讯轮次即为一次迭代。由于通讯质量限制,每一通讯轮次中只有处于良好通讯状态的客户端可以参与,参与通讯客户端记为Ck。本发明的客户端为参与联邦学习的节点,例如可以为卡组织节点或者支付机构节点。通讯过程中客户端将本地模型的模型参数共享至中央服务器中进行聚合,得到聚合后的模型参数。中央服务器对各客户端的模型参数进行聚合可以采用以下公式:

其中,

传统的联邦学习框架仅上传本地客户端模型参数至中央服务器内进行聚合,也就是说,在该框架下本地客户端仅上传模型参数至服务器进行聚合得到中央模型,而无法实现数据融合的功能。基于此,本发明进一步提出改进方法,如图4所示。

如图4所示,与传统的联邦学习相似,本发明的改进的联邦学习同样主要包含两个部分,即客户端模型训练和中央服务器的模型参数聚合及其模型参数的上传/下载。但是,本发明的改进的联邦学习与传统联邦学习相比又存在两个明显不同点:(1)在每个通讯轮次中,所有参与更新的客户端都将基于本地数据集P利用局部随机梯度下降法进行同步训练。对于参与聚合的第Ck个客户端,首先提取本地数据集的数据特征(表示为fea

其中,

在本发明实施例中,本发明的通讯过程共进行R轮次,即本发明的联邦学习共进行R次迭代,R为预设值,每轮次中模型参数都将在参与客户端与中央服务器之间相互传递,客户端将根据下载的聚合后的模型参数进一步训练本地模型,并利用训练出的模型继续提取本地数据集的数据特征;而上传的数据特征一直在中央服务器内进行聚合更新,直至循环迭代结束为止,即在各通讯轮次中各客户端需要上传客户端模型参数、数据特征以及训练出的模型的预测精度,而从中央服务器下载的内容仅为聚合后的模型参数,直至第R轮通讯轮次,所有客户端才会下载聚合后的全局数据特征作为全局数据表示。

在本发明一个实施例中,模型参数具体为模型参数矩阵。每个所述客户端,还用于采用奇异值分解算法分解模型参数矩阵得到第一奇异值,并将所述第一奇异值发送到所述中央服务器;所述中央服务器,还用于通过对所述第一奇异值进行奇异值分解算法的逆运算得到模型参数矩阵。

由于在联邦学习中,需将参与聚合的模型参数上传至中央服务器进行聚合更新,当网络结构较复杂时,上传和下载的数据量极大,对网络带宽要求较高,通讯代价较大。本发明采用奇异值分解(Singular Value Decomposition,SVD)算法分解模型参数矩阵并上传矩阵分解后的奇异值代替原本模型参数矩阵进行聚合,减少上传数据的同时,提高融合操作的性能。

在本发明一个实施例中,本发明的模型参数矩阵包含:连接权矩阵W

例如,假设第i个客户端的模型参数矩阵包括特征提取阶段的连接权矩阵W

在本发明实施例中,每个所述客户端根据所述全局聚合特征以及所述本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到最终训练出的模型。其中,所述第二神经网络可以采用回声状态网络(Echo State Network,ESN)。ESN网络作为递归神经网络的一种变体,其结构可以如图5所示。ESN结构中主要包含输入层、动态储备池和输出层三部分,网络具有K个输入单元,M个内部单元以及L个输出单元,各层之间分别通过输入连接权值与输出连接权值进行连接,动态储备池位于输入层与输出层之间,由大量稀疏连接的神经元组成。网络的输入为全局聚合特征Fea及对应标签所形成的集合,输出为预测的标签值。

ESN网络的训练过程只需求解一个线性回归问题,大大简化了网络的训练过程。如图6所示,各个客户端下载最终通讯轮次中聚合形成的全局聚合特征Fea,并将聚合特征以及本地数据集内的真实数据值分别看作训练数据和对应标签,按照上节所述ESN网络的训练方法训练ESN网络,得到每个客户端各自对应的最终训练出的模型。在本发明一个实施例中,最终训练出的模型用于对客户的风险偏好分类进行预测,从而实现对应的金融产品推荐。

在本发明一个实施例中,所述中央服务器,还用于采用奇异值分解算法分解聚合后的模型参数矩阵得到的第二奇异值,并将所述第二奇异值发送到每个所述客户端;每个所述客户端,还用于对所述第二奇异值进行奇异值分解算法的逆运算得到聚合后的模型参数矩阵。

在本发明一个实施例中,数据特征具体为数据特征矩阵。每个所述客户端,还用于采用奇异值分解算法分解数据特征矩阵得到第三奇异值,并将所述第三奇异值发送到所述中央服务器;所述中央服务器,还用于通过对所述第三奇异值进行奇异值分解算法的逆运算得到数据特征矩阵。

在本发明一个实施例中,全局聚合特征具体为全局聚合特征矩阵。所述中央服务器,还用于采用奇异值分解算法分解全局聚合特征矩阵得到的第四奇异值,并将所述第四奇异值传输到每个所述客户端;每个所述客户端,还用于对所述第四奇异值进行奇异值分解算法的逆运算得到全局聚合特征矩阵。

本发明另一个方面提出了一种基于联邦学习的模型训练方法,由于基于联邦学习的模型训练方法解决问题的原理与上述联邦学习系统相似,因此基于联邦学习的模型训练方法的实施例可以参见上述联邦学习系统的实施例,重复之处不再赘述。图1是本发明实施例基于联邦学习的模型训练方法的第一流程图,应用于上述联邦学习系统中的客户端,如图1所示,本实施例的基于联邦学习的模型训练方法包括步骤S101和步骤S102。

步骤S101,在联邦学习的第i次迭代中,根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,根据训练出的模型提取本地数据集的数据特征,将所述数据特征以及所述预测精度发送到中央服务器,以使所述中央服务器根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集。

在本发明一个实施例中,所述根据本地训练数据集采用第一神经网络进行模型训练,具体包括:采用局部随机梯度下降法进行模型训练。

步骤S102,在联邦学习的第R次迭代中,从所述中央服务器获取所述全局聚合特征,并根据所述全局聚合特征以及所述本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到最终训练出的模型。

在本发明一个实施例中,本发明的基于联邦学习的模型训练方法,在联邦学习的第i次迭代中,还包括:

将训练出的模型的模型参数发送到所述中央服务器,以使所述中央服务器对在本次迭代中每个客户端发送的模型参数进行聚合得到聚合后的模型参数;

从所述中央服务器获取所述聚合后的模型参数,将所述聚合后的模型参数作为第i+1次迭代时进行模型训练的初始参数。

在本发明一个实施例中,本发明的模型参数具体可以为模型参数矩阵。在本发明一个实施例中,上述步骤中的将训练出的模型的模型参数发送到所述中央服务器,具体包括:

采用奇异值分解算法分解模型参数矩阵得到第一奇异值,将所述第一奇异值发送到所述中央服务器,以使所述中央服务器通过对所述第一奇异值进行奇异值分解算法的逆运算得到模型参数矩阵。

在本发明一个实施例中,上述步骤中的从所述中央服务器获取所述聚合后的模型参数,具体包括:

获取所述中央服务器采用奇异值分解算法分解聚合后的模型参数矩阵得到的第二奇异值;

对所述第二奇异值进行奇异值分解算法的逆运算得到聚合后的模型参数矩阵。

在本发明一个实施例中,本发明的数据特征具体可以为数据特征矩阵。在本发明一个实施例中,上述步骤S101中的所述将所述数据特征以及所述预测精度发送到中央服务器,具体包括:

采用奇异值分解算法分解数据特征矩阵得到第三奇异值,将所述第三奇异值发送到所述中央服务器,以使所述中央服务器通过对所述第三奇异值进行奇异值分解算法的逆运算得到数据特征矩阵。

在本发明一个实施例中,本发明的全局聚合特征具体可以为全局聚合特征矩阵。在本发明一个实施例中,上述步骤S102中的从所述中央服务器获取所述全局聚合特征,具体包括:

获取所述中央服务器采用奇异值分解算法分解全局聚合特征矩阵得到的第四奇异值;

对所述第四奇异值进行奇异值分解算法的逆运算得到全局聚合特征矩阵。

图2是本发明实施例基于联邦学习的模型训练方法的第二流程图,应用于联邦学习系统中的中央服务器,如图2所示,本实施例的基于联邦学习的模型训练方法包括步骤S201至步骤S203。

步骤S201,在联邦学习的第i次迭代中,接收各客户端发送的数据特征以及预测精度,其中,每个所述客户端在第i次迭代中根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,以及根据训练出的模型提取本地数据集的数据特征,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集。

步骤S202,根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数。

步骤S203,在联邦学习的第R次迭代中,将所述全局聚合特征传输到各客户端,以使各客户端根据所述全局聚合特征以及本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到每个客户端各自对应的最终训练出的模型。

在本发明一个实施例中,本发明的基于联邦学习的模型训练方法,在联邦学习的第i次迭代中,还包括:

获取每个客户端在本次迭代中发送的模型参数,并对所述模型参数进行聚合得到聚合后的模型参数;

将所述聚合后的模型参数传输到各客户端,以使各客户端将所述聚合后的模型参数作为第i+1次迭代时进行模型训练的初始参数。

在本发明一个实施例中,本发明的模型参数具体可以为模型参数矩阵。在本发明一个实施例中,所述获取每个客户端在本次迭代中发送的模型参数,具体包括:

获取每个客户端采用奇异值分解算法分解模型参数矩阵得到第一奇异值;

通过对所述第一奇异值进行奇异值分解算法的逆运算得到模型参数矩阵。

在本发明一个实施例中,所述将所述聚合后的模型参数传输到各客户端,具体包括:

采用奇异值分解算法分解聚合后的模型参数矩阵得到的第二奇异值;

将所述第二奇异值传输到各客户端,以使各客户端对所述第二奇异值进行奇异值分解算法的逆运算得到聚合后的模型参数矩阵。

在本发明一个实施例中,本发明的数据特征具体可以为数据特征矩阵。在本发明一个实施例中,上述步骤S201中的接收各客户端发送的数据特征以及预测精度,具体包括:

接收各客户端采用奇异值分解算法分解数据特征矩阵得到第三奇异值;

通过对所述第三奇异值进行奇异值分解算法的逆运算得到数据特征矩阵。

在本发明一个实施例中,本发明的全局聚合特征具体可以为全局聚合特征矩阵。在本发明一个实施例中,上述步骤S203中的将所述全局聚合特征传输到各客户端,具体包括:

采用奇异值分解算法分解全局聚合特征矩阵得到的第四奇异值;

将所述第四奇异值传输到各客户端,以使各客户端对所述第四奇异值进行奇异值分解算法的逆运算得到全局聚合特征矩阵。

由以上实施例可以看出,本发明实现了至少以下有益效果:

1、本发明在考虑隐私数据保护的前提下实现了数据特征融合。

2、在进行参数融合之前,本发明提出基于奇异值分解算法(Singular ValueDecomposition,SVD)的关键参数特征提取及融合算法以减小单通讯轮次所需通讯代价,提升预测的精度以及可行性。

需要说明的是,在附图的流程图示出的步骤可以在诸如一组计算机可执行指令的计算机系统中执行,并且,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。

基于同一发明构思,本发明实施例还提供了基于联邦学习的模型训练装置,可以用于实现上述实施例所描述的基于联邦学习的模型训练方法,如下面的实施例所述。由于基于联邦学习的模型训练装置解决问题的原理与基于联邦学习的模型训练方法相似,因此基于联邦学习的模型训练装置的实施例可以参见基于联邦学习的模型训练方法的实施例,重复之处不再赘述。以下所使用的,术语“模块”可以实现预定功能的软件和/或硬件的组合。尽管以下实施例所描述的装置较佳地以软件来实现,但是硬件,或者软件和硬件的组合的实现也是可能并被构想的。

图7是本发明实施例基于联邦学习的模型训练装置的第一结构框图,如图7所示,在本发明一个实施例中,本发明的基于联邦学习的模型训练装置包括:

第一模型训练处理模块1,用于在联邦学习的第i次迭代中,根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,根据训练出的模型提取本地数据集的数据特征,将所述数据特征以及所述预测精度发送到中央服务器,以使所述中央服务器根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集;

第二模型训练处理模块2,用于在联邦学习的第R次迭代中,从所述中央服务器获取所述全局聚合特征,并根据所述全局聚合特征以及所述本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到最终训练出的模型。

图8是本发明实施例基于联邦学习的模型训练装置的第二结构框图,如图8所示,在本发明另一个实施例中,本发明的基于联邦学习的模型训练装置包括:

数据接收模块3,用于在联邦学习的第i次迭代中,接收各客户端发送的数据特征以及预测精度,其中,每个所述客户端在第i次迭代中根据本地训练数据集采用第一神经网络进行模型训练,根据本地验证数据集确定训练出的模型的预测精度,以及根据训练出的模型提取本地数据集的数据特征,所述本地数据集包括:所述本地训练数据集和/或所述本地验证数据集;

特征聚合模块4,用于根据在前R-1次迭代中每个客户端在每次迭代时发送的数据特征以及预测精度进行特征聚合得到全局聚合特征,其中,i大于等于1且小于等于R-1,R为大于1的整数;

全局聚合特征传输模块5,用于在联邦学习的第R次迭代中,将所述全局聚合特征传输到各客户端,以使各客户端根据所述全局聚合特征以及本地训练数据集中各训练数据的标签采用第二神经网络进行模型训练,得到每个客户端各自对应的最终训练出的模型。

为了实现上述目的,根据本申请的另一方面,还提供了一种计算机设备。如图9所示,该计算机设备包括存储器、处理器、通信接口以及通信总线,在存储器上存储有可在处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述实施例方法中的步骤。

处理器可以为中央处理器(Central Processing Unit,CPU)。处理器还可以为其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等芯片,或者上述各类芯片的组合。

存储器作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序、非暂态计算机可执行程序以及单元,如本发明上述方法实施例中对应的程序单元。处理器通过运行存储在存储器中的非暂态软件程序、指令以及模块,从而执行处理器的各种功能应用以及作品数据处理,即实现上述方法实施例中的方法。

存储器可以包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需要的应用程序;存储数据区可存储处理器所创建的数据等。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施例中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。

所述一个或者多个单元存储在所述存储器中,当被所述处理器执行时,执行上述实施例中的方法。

上述计算机设备具体细节可以对应参阅上述实施例中对应的相关描述和效果进行理解,此处不再赘述。

为了实现上述目的,根据本申请的另一方面,还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序在计算机处理器中执行时实现上述基于联邦学习的模型训练方法中的步骤。本领域技术人员可以理解,实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,所述存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)、随机存储记忆体(RandomAccessMemory,RAM)、快闪存储器(Flash Memory)、硬盘(Hard DiskDrive,缩写:HDD)或固态硬盘(Solid-State Drive,SSD)等;所述存储介质还可以包括上述种类的存储器的组合。

显然,本领域的技术人员应该明白,上述的本发明的各模块或各步骤可以用通用的计算装置来实现,它们可以集中在单个的计算装置上,或者分布在多个计算装置所组成的网络上,可选地,它们可以用计算装置可执行的程序代码来实现,从而,可以将它们存储在存储装置中由计算装置来执行,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。这样,本发明不限制于任何特定的硬件和软件结合。

以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

相关技术
  • 基于联邦学习的模型训练方法、装置及联邦学习系统
  • 一种联邦学习模型训练方法、装置及联邦学习系统
技术分类

06120112809627