掌桥专利:专业的专利平台
掌桥专利
首页

技术领域

本发明涉及人工智能技术领域,尤其是涉及一种云边端协同个性化联邦学习方法、系统、设备及介质。

背景技术

近年来,联邦学习在工业界和学术界发展非常迅猛,它是一种分布式机器学习范式,在保护用户隐私安全的前提下通过联合多个参与方实现共同建模。在当前流行的联邦学习框架中,参与训练的客户端可以从移动边缘设备到大型企业组织,它们将原始训练数据存在终端设备中。在中央参数服务器的协调下,各个客户端互相协作进行模型训练。每个训练参与方将本地训练的模型参数上传到中心服务器,中心服务器通过聚合各个用户的模型参数更新,然后将聚合更新后的全局模型下发给每个训练参与方。联邦学习的目标是训练一个在大多数客户端上表现良好的全局模型,实现用户之间的知识共享。

当前主流的联邦学习侧重训练模型的通用性能,然而由于用户数据不平衡和非IID(Independent and identically distributed)分布,导致在特定客户端场景下的个性化性能不佳。现有的联邦学习方法在non-IID(非IID)数据上学习时,一般收敛性较差而且单一的全局模型无法适用差异明显的客户端分布情况。现有的个性化FL(联邦学习)方法一般通过训练一个性能较好的全局模型,然后再在每个客户端上进行本地个性化处理。现阶段,很多联邦优化算法在缓解non-IID数据问题过程中间会增加大量通信成本,同时无法有效提升模型性能。

在联邦学习环境下,不同客户端中局部数据集的高度异构性和不同任务导致的统计异构性会降低整体模型训练效率。针对不同客户端的差异,单一的全局模型无法满足所有客户端的需求,而现有的个性化联邦学习机制一般只侧重解决数据集异构性问题,而忽略了客户端之间的联系,使得训练的模型的准确度不高。

发明内容

本发明旨在至少解决现有技术中存在的技术问题之一。为此,本发明提出一种云边端协同个性化联邦学习方法、系统、设备及介质,通过考虑客户端的特征信息以及客户端之间的关系图结构,能够有效提升模型的推理精确度和训练速度。

第一方面,本发明实施例提供了一种云边端协同个性化联邦学习方法,所述云边端协同个性化联邦学习方法包括:

获取多个客户端的终端设备产生的数据样本和所有客户端构成的图结构;

获取所述数据样本的特征矩阵表示和邻接矩阵表示,并根据所述特征矩阵表示和所述邻接矩阵表示,获得所述数据样本的特征张量和邻接张量;

根据所述邻接张量和所述特征张量,对所述多个客户端的终端设备进行聚类,获得多组客户端终端设备;

构建云边端协同个性化联邦学习的目标函数和特征图,并根据所述目标函数和特征图,对每组客户端终端设备进行个性化模型训练,更新每组客户端终端设备的个性化模型的参数;

根据所述个性化模型的参数、所述特征图和所述图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型。

与现有技术相比,本发明第一方面具有以下有益效果:

本方法获取多个客户端的终端设备产生的数据样本和所有客户端构成的图结构,获取数据样本的特征矩阵表示和邻接矩阵表示,并根据特征矩阵表示和邻接矩阵表示,获得数据样本的特征张量和邻接张量,通过特征张量和邻接张量能够更加完整和准确地描述特征丰富的数据样本,并能够在缓解non-IID数据问题过程中降低通信成本;根据邻接张量和特征张量,对多个客户端的终端设备进行聚类,获得多组客户端终端设备,通过利用客户端的特征张量和邻接张量可以精确表现终端设备原始数据的特征,无需在云端收集用户原始数据,能够在降低数据隐私泄露风险的同时提升聚类算法效率;构建云边端协同个性化联邦学习的目标函数和特征图,并根据目标函数和特征图,对每组客户端终端设备进行个性化模型训练,更新每组客户端终端设备的个性化模型的参数,根据个性化模型的参数、特征图和图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型;其中,特征图包含了特征信息,因此,本方法考虑了客户端的特征信息以及客户端之间的关系图结构,能够有效提升模型的推理精确度和训练速度。

根据本发明的一些实施例,所述根据所述邻接张量和所述特征张量,对所述多个客户端的终端设备进行聚类,获得多组客户端终端设备,包括:

根据注意力机制和图卷积神级网络,构建基于深度注意力机制的自编码器;

采用所述自编码器获取所述客户端的隐藏嵌入表征;

将所述隐藏嵌入表征、所述特征矩阵表示和所述邻接矩阵表示输入至所述图卷积神级网络中,获得重构图;

根据所述图结构和所述重构图,构建重构损失函数

基于所述隐藏嵌入表征,采用K-means方法获取聚类中心,并获取各个客户端的软标签分配,所述软标签分配是各个客户端归于某个类的概率分布;

根据所述聚类中心,通过增强节点特征提升所述软标签分配的置信度,获得新的目标分布:

其中,

根据所述软标签分配和所述新的目标分布之间相对熵值构造聚类损失函数

根据所述聚类损失函数

根据本发明的一些实施例,在所述对每组客户端终端设备进行个性化模型训练之前,所述云边端协同个性化联邦学习方法还包括:

采用ISMOTE方法对少样本数量的每组客户端终端设备进行样本分析和构造,获得可预测标签。

根据本发明的一些实施例,所述采用ISMOTE方法对少样本数量的每组客户端终端设备进行样本分析和构造,获得可预测标签,包括:

通过编码器将每个终端设备中的高维数据映射为低维数据;

采用欧式距离计算所述低维数据中的当前样本与其余样本的距离,获得多个邻近样本;

基于整体数据不平衡的比例设置采样倍率,在所述当前样本邻近的所述多个邻近样本中选择预设数量的样本,并在所述预设数量的样本上构建新样本;

对所述新样本采用多层感知机输出可预测标签。

根据本发明的一些实施例,所述个性化联邦学习的目标函数包括:

其中,

根据本发明的一些实施例,在根据所述个性化模型的参数、所述特征图和所述图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型之前,所述云边端协同个性化联邦学习方法还包括:

基于客户端特征信息,构建无监督学习的特征图自动编码器;

通过所述无监督学习的特征图自动编码器将高维稀疏数据表征为特征图形式。

根据本发明的一些实施例,构建云边端协同个性化联邦学习的目标函数,包括:

将所述特征图引入所述个性化联邦学习的目标函数中,获得云边端协同个性化联邦学习的目标函数;其中,所述特征图包括特征矩阵;所述云边端协同个性化联邦学习的目标函数通过如下公式表示:

其中,

第二方面,本发明实施例还提供了一种云边端协同个性化联邦学习系统,所述云边端协同个性化联邦学习系统包括:

数据获取单元,用于获取多个客户端的终端设备产生的数据样本和所有客户端构成的图结构;

张量获取单元,用于获取所述数据样本的特征矩阵表示和邻接矩阵表示,并根据所述特征矩阵表示和所述邻接矩阵表示,获得所述数据样本的特征张量和邻接张量;

设备聚类单元,用于根据所述邻接张量和所述特征张量,对所述多个客户端的终端设备进行聚类,获得多组客户端终端设备;

参数更新单元,用于构建云边端协同个性化联邦学习的目标函数和特征图,并根据所述目标函数和特征图,对每组客户端终端设备进行个性化模型训练,更新每组客户端终端设备的个性化模型的参数;

联邦学习单元,用于根据所述个性化模型的参数、所述特征图和所述图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型。

第三方面,本发明实施例还提供了一种云边端协同个性化联邦学习设备,包括至少一个控制处理器和用于与所述至少一个控制处理器通信连接的存储器;所述存储器存储有可被所述至少一个控制处理器执行的指令,所述指令被所述至少一个控制处理器执行,以使所述至少一个控制处理器能够执行如上所述的一种云边端协同个性化联邦学习方法。

第四方面,本发明实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行如上所述的一种云边端协同个性化联邦学习方法。

可以理解的是,上述第二方面至第四方面与相关技术相比存在的有益效果与上述第一方面与相关技术相比存在的有益效果相同,可以参见上述第一方面中的相关描述,在此不再赘述。

附图说明

本发明的上述和/或附加的方面和优点从结合下面附图对实施例的描述中将变得明显和容易理解,其中:

图1是本发明一实施例的一种云边端协同个性化联邦学习方法的流程图;

图2是本发明一实施例的构造多维特征张量的示意图;

图3是本发明一实施例的云边端协同个性化联邦学习框架的示意图;

图4是本发明一实施例的客户端的终端设备分组的流程图;

图5是本发明一实施例的ISMOTE方法的流程图;

图6是本发明一实施例的一种云边端协同个性化联邦学习系统的流程图。

具体实施方式

下面详细描述本发明的实施例,所述实施例的示例在附图中示出,其中自始至终相同或类似的标号表示相同或类似的元件或具有相同或类似功能的元件。下面通过参考附图描述的实施例是示例性的,仅用于解释本发明,而不能理解为对本发明的限制。

在本发明的描述中,如果有描述到第一、第二等只是用于区分技术特征为目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量或者隐含指明所指示的技术特征的先后关系。

在本发明的描述中,需要理解的是,涉及到方位描述,例如上、下等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。

本发明的描述中,需要说明的是,除非另有明确的限定,设置、安装、连接等词语应做广义理解,所属技术领域技术人员可以结合技术方案的具体内容合理确定上述词语在本发明中的具体含义。

当前主流的联邦学习侧重训练模型的通用性能,然而由于用户数据不平衡和非IID(Independent and identically distributed)分布,导致在特定客户端场景下的个性化性能不佳。现有的联邦学习方法在non-IID(非IID)数据上学习时,一般收敛性较差而且单一的全局模型无法适用差异明显的客户端分布情况。现有的个性化FL(联邦学习)方法一般通过训练一个性能较好的全局模型,然后再在每个客户端上进行本地个性化处理。现阶段,很多联邦优化算法在缓解non-IID数据问题过程中间会增加大量通信成本,同时无法有效提升模型性能。

在联邦学习环境下,不同客户端中局部数据集的高度异构性和不同任务导致的统计异构性会降低整体模型训练效率。针对不同客户端的差异,单一的全局模型无法满足所有客户端的需求,而现有的个性化联邦学习机制一般只侧重解决数据集异构性问题,而忽略了客户端之间的联系,使得训练的模型的准确度不高。

为解决上述问题,本发明通过获取多个客户端的终端设备产生的数据样本和所有客户端构成的图结构,获取数据样本的特征矩阵表示和邻接矩阵表示,并根据特征矩阵表示和邻接矩阵表示,获得数据样本的特征张量和邻接张量,通过特征张量和邻接张量能够更加完整和准确地描述特征丰富的数据样本,并能够在缓解non-IID数据问题过程中降低通信成本;根据邻接张量和特征张量,对多个客户端的终端设备进行聚类,获得多组客户端终端设备,通过利用客户端的特征张量和邻接张量可以精确表现终端设备原始数据的特征,无需在云端收集用户原始数据,能够在降低数据隐私泄露风险的同时提升聚类算法效率;构建云边端协同个性化联邦学习的目标函数和特征图,并根据目标函数和特征图,对每组客户端终端设备进行个性化模型训练,更新每组客户端终端设备的个性化模型的参数,根据个性化模型的参数、特征图和图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型;其中,特征图包含了特征信息,因此,本方法考虑了客户端的特征信息以及客户端之间的关系图结构,能够有效提升模型的推理精确度和训练速度。

参照图1,本发明实施例提供了一种云边端协同个性化联邦学习方法,本云边端协同个性化联邦学习方法包括但不限于步骤S100至步骤S500:

步骤S100、获取多个客户端的终端设备产生的数据样本和所有客户端构成的图结构;

步骤S200、获取数据样本的特征矩阵表示和邻接矩阵表示,并根据特征矩阵表示和邻接矩阵表示,获得数据样本的特征张量和邻接张量;

步骤S300、根据邻接张量和特征张量,对多个客户端的终端设备进行聚类,获得多组客户端终端设备;

步骤S400、构建云边端协同个性化联邦学习的目标函数和特征图,并根据目标函数和特征图,对每组客户端终端设备进行个性化模型训练,更新每组客户端终端设备的个性化模型的参数;

步骤S500、根据个性化模型的参数、特征图和图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型。

在一些实施例的步骤S100至步骤S500中,为了能够更加完整和准确地描述特征丰富的数据样本,并能够在缓解non-IID数据问题过程中降低通信成本,通过获取多个客户端的终端设备产生的数据样本和所有客户端构成的图结构,获取数据样本的特征矩阵表示和邻接矩阵表示,并根据特征矩阵表示和邻接矩阵表示,获得数据样本的特征张量和邻接张量;为了精确表现终端设备原始数据的特征,无需在云端收集用户原始数据,并在降低数据隐私泄露风险的同时提升聚类算法效率,通过根据邻接张量和特征张量,对多个客户端的终端设备进行聚类,获得多组客户端终端设备;为了提升模型的推理精确度和训练速度,通过构建云边端协同个性化联邦学习的目标函数和特征图,并根据目标函数和特征图,对每组客户端终端设备进行个性化模型训练,更新每组客户端终端设备的个性化模型的参数,根据个性化模型的参数、特征图和图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型。

在一些实施例中,根据邻接张量和特征张量,对多个客户端的终端设备进行聚类,获得多组客户端终端设备,包括:

根据注意力机制和图卷积神级网络,构建基于深度注意力机制的自编码器;

采用自编码器获取客户端的隐藏嵌入表征;

将隐藏嵌入表征、特征矩阵表示和邻接矩阵表示输入至图卷积神级网络中,获得重构图;

根据图结构和重构图,构建重构损失函数

基于隐藏嵌入表征,采用K-means方法获取聚类中心,并获取各个客户端的软标签分配,软标签分配是各个客户端归于某个类的概率分布;

根据聚类中心,通过增强节点特征提升软标签分配的置信度,获得新的目标分布:

其中,

根据软标签分配和新的目标分布之间相对熵值构造聚类损失函数

根据聚类损失函数

在本实施例中,通过利用客户端节点的特征张量和邻接张量可以精确表现终端设备原始数据的特征,无需在云端收集用户原始数据,在降低数据隐私泄露风险的同时提升聚类算法效率。

在一些实施例中,在对每组客户端终端设备进行个性化模型训练之前,云边端协同个性化联邦学习方法还包括:

采用ISMOTE方法对少样本数量的每组客户端终端设备进行样本分析和构造,获得可预测标签。

在本实施例中,ISMOTE方法基于传统SMOTE算法进行的改进,通过ISMOTE方法能够解决易产生过拟合的问题,改进后产生的数据会减少异常数据出现频次、模型泛化性能较低和数据边缘化的问题。

在一些实施例中,采用ISMOTE方法对少样本数量的每组客户端终端设备进行样本分析和构造,获得可预测标签,包括:

通过编码器将每个终端设备中的高维数据映射为低维数据;

采用欧式距离计算低维数据中的当前样本与其余样本的距离,获得多个邻近样本;

基于整体数据不平衡的比例设置采样倍率,在当前样本邻近的多个邻近样本中选择预设数量的样本,并在预设数量的样本上构建新样本;

对新样本采用多层感知机输出可预测标签。

在本实施例中,通过多层感知机输出可预测标签是生成少数类样本流程中的一部分,利用多层感知机生成的可预测标签,可以用于再次进行训练从而微调本地模型参数,最终可以提升个性化模型的准确性和鲁棒性。

在一些实施例中,个性化联邦学习的目标函数包括:

其中,

在一些实施例中,在根据个性化模型的参数、特征图和图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型之前,云边端协同个性化联邦学习方法还包括:

基于客户端特征信息,构建无监督学习的特征图自动编码器;

通过无监督学习的特征图自动编码器将高维稀疏数据表征为特征图形式。

在本实施例中,构建的无监督学习的特征图自动编码器能够减少特征图信息传输到云端时造成隐私暴露以及高额通信成本的问题。

在一些实施例中,构建云边端协同个性化联邦学习的目标函数,包括:

将特征图引入个性化联邦学习的目标函数中,获得云边端协同个性化联邦学习的目标函数;其中,特征图包括特征矩阵;云边端协同个性化联邦学习的目标函数通过如下公式表示:

其中,

为方便本领域人员理解,以下提供一组最佳实施例:

随着人工智能、物联网和大数据等现代科学的发展,终端设备的数据海量式增长,数据的维度也在快速增加。对这些高维度多源异构数据使用矩阵和矢量进行标识无法完整呈现样本的完整特征,基于张量的高维数据特征提取可以高效地进行特征表示。

张量作为标量、向量和矩阵的一种高阶泛化形式,其表示的物理意义是从多个维度表示数据特征。假设在常见的三维空间中,零阶张量叫做标量,可以简单理解是一个数值;一阶张量叫做向量,当在n维空间中,一个一阶张量可以理解为

由于客户端节点之间的拓扑结构和数据特征之间都可能存在联系,本实施例提出使用邻接张量来表示节点的联系。每种关系对应一个邻接矩阵,邻接矩阵的集合可以构成邻接张量,同时邻接矩阵也是邻接张量的一个切片。本实施例提出的邻接张量可以充分利用客户端节点之间存在的多种联系。定义节点特征

面向联邦训练系统中多源异构数据的场景,本实施例在算法中引入高维特征张量,将传统的特征矩阵

参照图2,在本实施例中,定义

本实施例提出一种新颖的云边端协同个性化联邦学习框架,如图3所示。本实施例的云边端协同个性化联邦学习框架面向云边端系统,先在边缘联邦模式下,对终端设备中的无标签数据进行整合训练,再通过中心云服务器的统一协调下,从分散的边缘服务器训练出一个全局通用的模型。在云边端协同框架下,终端用户通过本地数据对通用模型进行个性化训练,保障了用户隐私和模型个性化的实现。

1、客户端的终端设备选择阶段。

参照图4,本实施例将总体流程分成四个步骤:数据收集、提取特征、生成张量和设备分组。在本地数据收集步骤中,终端设备会产生大量无标签数据,数据样本呈现non-IID分布。为了保护用户隐私安全和满足本地模型训练需求,终端设备产生的数据样本都保存在设备内存中。由于终端设备存在海量的数据样本,若采用基于所有数据的聚类算法会消耗大量计算资源和时间成本。本实施例的设备选择方法利用客户端本地原始数据特征和节点图结构特征,基于邻接张量和特征张量进行客户端相似度聚类。通过利用客户端节点的特征张量和邻接张量可以精确表现终端设备原始数据的特征,无需在云端收集用户原始数据,在降低数据隐私泄露风险的同时提升聚类算法效率。

将所有参与训练的客户端构建的图结构定义为A,客户端聚类目标是将A中的所有节点划分为k个互不相交的小组

(1)基于注意力机制和图卷积神级网络(GCN),设计一个可以融合特征信息和结构信息的基于深度注意力机制的自编码器,通过基于深度注意力机制的自编码器实现总体数据的重建;

(2)在所有客户端构成的图A中,特征矩阵X表示节点的特征信息,邻接矩阵S表示节点的结构信息。使用上述自编码器获得节点的隐藏嵌入表征H,将隐藏嵌入表征H、邻接矩阵S和特征矩阵X作为图卷积神经网络(GCN)的输入;

(3)将隐藏嵌入表征H和邻接矩阵S相乘得到重构图G,通过原始图A和重构图G构建损失函数

(4)基于节点的隐藏嵌入表征H,通过K-means算法获取最初的聚类中心m,基于节点的隐藏嵌入表征,将各个节点具体归于某个类的概率分布定义为软标签分配,用Q表示,

(5)为了提升聚类效率,实现内类节点距离最小化,通过增强节点特征提升Q的置信度,定义新的目标分布P:

(6)通过P和Q之间相对熵值构造聚类损失函数

(7)基于聚类损失函数

2、本地模型训练阶段。

在每组客户端终端设备进行个性化模型训练之前,采用ISMOTE方法对少样本数量的每组客户端终端设备进行样本分析和构造,获得可预测标签,具体为:

将智能终端中的样本数据定义为

针对终端设备中少数类样本数量不足、分布不平衡问题。本实施例引入ISMOTE方法,通过使用过采样方法实现生成类平衡的数据。经典的SMOTE算法通过随机采取简单的复制样本用于增加少数类样本,用于解决数据不平衡问题。ISMOTE算法基于传统SMOTE算法进行改进,解决易产生过拟合的问题,改进后的数据会减少异常数据出现频次、模型泛化性能较低和数据边缘化的问题。传统KNN算法处理异常数据和不平衡分布的情况下算法效率低,本文使用VKNN进行数据分类操作进行优化。ISMOTE算法没有采用简单复制样本的方法增加少数类样本,而是对少数类样本分析并构造,ISMOTE算法的总体流程分为以下几个步骤:1)通过编码器网络将终端设备中的高维数据映射为低维数据。2)在低维数据中,对少数类中每一个样本d,基于欧式距离计算其余样本和d的距离,获得k个邻近样本。3)基于整体数据不平衡的比例设置采样倍率r,在每一个少数类样本d邻近的k个样本中选择部分样本,假设选中的样本为c。4)在原来样本的基础上构建新样本y。5)最后通过多层感知机输出可预测标签。ISMOTE算法流程如图5所示。

3、模型聚合阶段。

本实施例通过利用节点特征信息和图结构进行个性化联邦模型聚合来探索损失函数的最优解。本实施例引入邻接矩阵E(邻接矩阵包含图结构信息),将个性化联邦学习的目标函数定义为以下双层优化问题:

其中,每个客户端都有一个相应的本地模型

当参与训练的客户端之间没有明显的图结构信息时,用于表示邻接矩阵的E就可能不存在,无法利用结构化的方法对本地模型进行表征和推理。在结构信息较少的情况下,可以通过增强模型聚合过程中客户端节点特征来提升模型性能。本实施例提出基于客户端节点特征信息构建特征图GF(特征图里面包含特征节点F),用来约束节点特征相似度的接近关系。在构建客户端设备的特征图的流程中,为了减少特征图信息传输到云端时造成隐私暴露以及高额通信成本的问题,本实施例设计了基于无监督学习的特征图自动编码器(FMAE)。编码器网络通过将原始输入映射为一个表示向量,然后解码器网络将潜在空间还原成原始空间。FMAE基于无监督学习方法将输入的高维稀疏数据表征为紧凑有效的特征图形式,本地用户数据通过编码器网络压缩成低维的特征表示,这些数据样本和特征表征无须上传到云端,由本地客户端进行维护。由于CNN在目标检测、图像分类和自然语言处理等领域取得了很大成功,本实施例使用CNN作为编码器和解码器网络的主要架构。F表示联邦训练参与者的特征信息,T(F)表示添加特征矩阵的正则化项,引入特征矩阵F后的云边端协同个性化联邦学习的目标函数:

在模型聚合过程中,本实施例通过计算

本实施例的模型聚合策略整体流程包括如下步骤:

(1)该算法的输入包括:客户端与服务端通信次数T;分布式客户端总数N;学习率r;本地迭代次数L;客户端集群CG;聚类数目K;客户端节点特征信息构建特征图GF;客户端拓扑图结构A;特征矩阵X。

(2)在客户端选择阶段,在每一轮训练轮次开始之前,基于节点张量的图聚类算法TNGC(A,K,X)获得客户端集群CG。

(3)在本地训练阶段,对于各个分组中的客户端

(4)在云端聚合阶段,利用客户端拓扑图结构A和特征矩阵X,基于GCN进行模型细粒度梯度聚合,获得全局模型

(5)最终输出全局模型

为了更好的说明,本实施例进行了如下实验:

在本实施例中,为了验证本实施例的技术方案的有效性,实验的硬件环境设置为Intel i7-9700k(3.6GHz,8cores)、32GB DDR4内存、NVIDIA GeForce RTX2080Ti(32 GB)。软件环境设置为Ubuntu 18.04.1和CUDA11.4。本实施例采用联邦学习架构的通用设置,基于Pytorch框架搭建一个可以在中心服务器协调下支持若干个客户端联合训练的联邦学习框架。在联邦训练过程中,可以通过设置学习率、本地更新步骤等多个超参数进行细粒度优化。本实施例将本实施例的技术方案和最先进的方法进行比较,包括FedAvg、Per-FedAvg、FedProx、FedEnsemble、FedDistill、FED-ROD。为了本实施例的技术方案进行评估,本实施例选用CIFAR-10、CIFAR-100、MNIST、FashionMNIST和CELEBA等五个公开通用的数据集。在实验总结部分,本实施例通过联邦训练过程中可视化呈现、实验结果以及分析。最终,通过消融实验验证本实施例的技术方案的有效性和鲁棒性。

需要说明的是,本实施例的FedAvg、Per-FedAvg、FedProx、FedEnsemble、FedDistill、FED-ROD都为现有技术,本实施例不做具体描述。

1、实验准备。

数据集:本实施例选用五个真实世界的图像数据集用于全面准确评估模型性能:CIFAR-10(50000条训练数据和10000条测试数据)、CIFAR-100(包含100个类的彩色图片,每个类有600条数据,500张作为训练集,100条作为训练集)、MNIST(60000条训练数据和10000条测试数据)、EMNIST(60000条训练数据,10000条测试数据)和CELEBA(202,599张人脸图片,5个人脸特征点和40个属性标记)。CIFAR-10、CIFAR-100、MNIST和EMNIST用于图像和数字分类任务,CELEBA数据集用于预测名人是否处于微笑状态的二分类任务。

基线:将本实施例的技术方案和其他基线进行比较,用于验证算法的有效性。FedAvg算法通过随机选择若干客户端进行采样,对选择客户端的梯度更新求平均值后形成全局更新,最终使用全局更新模型代替剩余未被选择的客户端。Per-FedAvg在FedAvg基础上引入元学习思想,设计了一个基于所有客户端元函数平均的优化函数,客户端元函数是本地损失函数通过梯度下降后得到的模型。FedEnsemble对FedAvg进行扩展研究,引入模型集成的方法,先利用随机排列更新模型,再通过平均模型集成获得模型更新。FedProx改进了本地损失函数,通过修正项约束本地模型不会偏离全局模型,同时通过动态调整本地更新次数,提升在异构系统中的性能。FedDistill的思想是交换本地模型输出,无需交换网络参数,在聚合模型参数方面具有一定健壮性。FED-ROD引入对不同类分布具有鲁棒性的损失族和设计一个自适应个性化预测器,使得模型可以同时实现最先进的通用和个性化性能。

实验设置:在模型训练和测试过程中,所有实验模型均通过相同的参数设置方法和使用基于Pytorch进行设置。在模型共享参数配置中,联邦训练一般设置200轮次全局通信,指定20个训练参与方,活跃用户比例占50%。在本地训练中,本地更新步长设置T为10,每个步长的批量大小为32,采用随机梯度下降作为优化器。CIFAR-10和CIFAR-100学习率设置为0.01,MNIST和FashionMNIST的学习率设置为0.005,CELEBA的学习率设置为0.0001。

2、性能评估。

表1

从表1中可知本实施例的技术方案相较于其他基线方法有较大优势。在面对不同用户数据异质性情况下,本实施例的技术方案通过基于张量的多视图聚类方法、引入ISMOTE算法实现少数类样本生成和基于样本特征和图结构实现模型细粒度聚合过程等方法降低数据异构性带来的影响。

参照图6,本发明实施例还提供了一种云边端协同个性化联邦学习系统,本云边端协同个性化联邦学习系统包括数据获取单元100、张量获取单元200、设备聚类单元300、参数更新单元400和联邦学习单元500,其中:

数据获取单元100,用于获取多个客户端的终端设备产生的数据样本和所有客户端构成的图结构;

张量获取单元200,用于获取数据样本的特征矩阵表示和邻接矩阵表示,并根据特征矩阵表示和邻接矩阵表示,获得数据样本的特征张量和邻接张量;

设备聚类单元300,用于根据邻接张量和特征张量,对多个客户端的终端设备进行聚类,获得多组客户端终端设备;

参数更新单元400,用于构建云边端协同个性化联邦学习的目标函数和特征图,并根据目标函数和特征图,对每组客户端终端设备进行个性化模型训练,更新每组客户端终端设备的个性化模型的参数;

联邦学习单元500,用于根据个性化模型的参数、特征图和图结构,通过图卷积神经网络进行聚合,获得每个客户端的全局模型和个性化模型。

需要说明的是,由于本实施例中的一种云边端协同个性化联邦学习系统与上述的一种云边端协同个性化联邦学习方法基于相同的发明构思,因此,方法实施例中的相应内容同样适用于本系统实施例,此处不再详述。

本发明实施例还提供了一种云边端协同个性化联邦学习设备,包括:至少一个控制处理器和用于与至少一个控制处理器通信连接的存储器。

存储器作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序以及非暂态性计算机可执行程序。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施方式中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至该处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。

实现上述实施例的一种云边端协同个性化联邦学习方法所需的非暂态软件程序以及指令存储在存储器中,当被处理器执行时,执行上述实施例中的一种云边端协同个性化联邦学习方法,例如,执行以上描述的图1中的方法步骤S100至步骤S500。

以上所描述的系统实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。

本发明实施例还提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机可执行指令,该计算机可执行指令被一个或多个控制处理器执行,可使得上述一个或多个控制处理器执行上述方法实施例中的一种云边端协同个性化联邦学习方法,例如,执行以上描述的图1中的方法步骤S100至步骤S500的功能。

本领域普通技术人员可以理解,上文中所公开方法中的全部或某些步骤、系统可以被实施为软件、固件、硬件及其适当的组合。某些物理组件或所有物理组件可以被实施为由处理器,如中央处理器、数字信号处理器或微处理器执行的软件,或者被实施为硬件,或者被实施为集成电路,如专用集成电路。这样的软件可以分布在计算机可读介质上,计算机可读介质可以包括计算机存储介质(或非暂时性介质)和通信介质(或暂时性介质)。如本领域普通技术人员公知的,术语计算机存储介质包括在用于存储信息(诸如计算机可读指令、数据结构、程序模块或其他数据)的任何方法或技术中实施的易失性和非易失性、可移除和不可移除介质。计算机存储介质包括但不限于RAM、ROM、EEPROM、闪存或其他存储器技术、CD-ROM、数字多功能盘(DVD)或其他光盘存储、磁盒、磁带、磁盘存储或其他磁存储装置、或者可以用于存储期望的信息并且可以被计算机访问的任何其他的介质。此外,本领域普通技术人员公知的是,通信介质通常包含计算机可读指令、数据结构、程序模块或者诸如载波或其他传输机制之类的调制数据信号中的其他数据,并且可包括任何信息递送介质。

上面结合附图对本发明实施例作了详细说明,但本发明不限于上述实施例,在所属技术领域普通技术人员所具备的知识范围内,还可以在不脱离本发明宗旨的前提下作出各种变化。

相关技术
  • 联邦学习方法、系统及可读存储介质
  • 基于端、边及云协同的机器学习方法、系统及介质
  • 基于联邦学习的云边端协同方法、控制装置、及协同系统
技术分类

06120115601103