掌桥专利:专业的专利平台
掌桥专利
首页

一种MIMO信道下空中计算联邦学习的梯度上传方法

文献发布时间:2024-04-18 19:52:40


一种MIMO信道下空中计算联邦学习的梯度上传方法

技术领域

本发明属于信息与通信技术领域,涉及一种MIMO信道下空中计算联邦学习的梯度上传方法。

背景技术

第六代(6G)无线通信将支持每平方公里数百万个无线设备的连接密度。这将为实现泛在智能的愿景提供坚实的基础。开发强大的智能模型需要利用在大量边缘设备上的数据多样性。一个简单的范例是要求边缘设备将本地数据上传到中心参数服务器(PS)以进行集中式模型训练。然而,上传原始数据会产生巨大的通信开销,并可能威胁到用户隐私。为了避免这些缺点,联邦学习(FL)是一种很有前景的替代方法,它使边缘设备能够联合训练机器学习(ML)模型,同时保持用户本地数据。与上传原始数据不同,在联邦学习训练中,每个边缘设备将其梯度更新发送到中心服务器,服务器聚合局部梯度,更新全局模型并将全局模型发送回边缘设备。

由于有限的通信资源(例如时间,带宽和空间)难以支持大规模边缘设备的通信需求,梯度上传成为了FL部署在无线网络上的关键瓶颈。近年来,人工智能模型参数数量日益增长,例如,Resnet152具有6000万个参数,而GPT-3具有1750亿个参数。然而,由于带宽和延迟限制,可用的无线通信带宽通常较小,例如1个LTE帧包含5MHz带宽和10ms相干时间只能携带50000个符号。幸运的是,在联邦学习中,相比于每个设备的局部梯度,服务器更关心局部梯度聚合后的梯度。聚合的梯度通常是所有局部梯度的平均值。基于联邦学习的这个特性,空中计算联邦学习(OA-FL)被提出,其中边缘设备通过共享无线资源来传输局部梯度。利用电磁波的模拟叠加,局部梯度在无线传输中完成聚合。与传统的正交多址接入(OMA)方法相比,空中计算联邦学习所需的通信资源并不随着设备的数量而增加,这在很大程度上缓解通信对于联邦学习的瓶颈效应。

由于空中计算联邦学习的广阔前景,许多研究工作已致力于设计更为高效的空中计算联邦学习系统。现有技术提出,局部梯度可以稀疏,压缩和量化后上传以减少通信开销而不会造成明显的学习准确率损失。使用部分正交压缩矩阵和Turbs-CS,可实现低复杂度的梯度稀疏压缩编码的方案。采用上述方案的空中计算联邦学习系统具有较低的通信开销和更快的收敛速率。

然而,现有的梯度压缩编码方案全部是基于单输入单输出(SISO)系统。具有阵列信号处理的多输入多输出(MIMO)已被广泛认为是增强系统容量的强大技术。MIMO多路复用通过天线阵列并行传输多个数据流,可显着减少信道使用的数量。但是,MIMO多路复用会导致数据流间干扰,从而破坏了OA-FL的聚合梯度和测试精度。通过合理设计设备端的预编码矩阵和服务器端的后处理矩阵,可以抑制流间干扰的影响。现有技术使用信道矩阵的伪逆矩阵作为预编码矩阵,并使用微分几何优化技术得出了闭式的后处理矩阵,或根据接收天线选择部分数据流。然而,上述方法均基于信道矩阵求逆,这可能会显着扩大噪声并因此加剧了梯度的聚集误差。尤其是当某些设备处于深度衰落时,上述方案将产生巨大的性能损失。

发明内容

本发明基于MIMO技术提出一种空中计算联邦学习(OA-FL)系统的上行通信设计方案。该方案包括一种新颖的稀疏编码多路复用(SCoM,Sparse-Coded Multiplexing)方案。该方案集成了稀疏压缩编码和MIMO多路复用技术,旨在解决上述提到的空中计算联邦学习通信开销较大的问题和学习性能损失问题。

本发明考虑一个由1个参数服务器(PS,Parameter Server)和M个边缘设备组成的OA-FL系统,其中PS上有N

本发明采用的技术方案包括以下步骤:

S1、如图1所示,OA-FL系统由1个PS和M个边缘设备组成,全局损失函数定义为

式中,

式中,f(θ;ξ

S2、服务器(PS)生成压缩矩阵

在通信轮次t内,执行如下步骤:

S3、服务器与设备进行信道信息(CSI)的估计,假设在每一通信轮次中梯度上传时信道保持不变,并且服务器拥有全局的信道信息矩阵

S4、服务器通过交替优化方法设计发端预编码矩阵

其中,,m和m′分别指示设备m和设备m′,q

S5、收端后处理矩阵F

其中,I为单位阵。

S6、每个发端预编码矩阵

P

V

其中,

S7、通过迭代S5和S6,直至S4中优化问题的目标函数值收敛。此时得到最优的

S8、各个设备在本地进行梯度下降,计算局部梯度

S9、如图2所示,为SCoM在一个通信轮次内本地梯度上传到PS的信号流图。设备使用优化后的预编码矩阵将本地梯度上传。轮次t内,边缘设备将梯度映射为复数版本,如以下公式给出

式中,

式中,

其中,λ∈[0,1]表示稀疏度。sp(·)保留累计梯度

随后设备基于稀疏梯度

其中,⊙为逐元素乘积,

其中,C为压缩后梯度的长度,

S10、为传输多流数据,设备将压缩梯度

其中,N

设备将

S11、在PS端,采用后处理矩阵F

PS对处理后的矩阵

其中,vec(·)为向量化操作,

S12、如图3所示,服务器采用TurboCS算法求解S11中压缩感知问题。TurboCS算法迭代求解聚合后的梯度

其中,

其中,

然后,在模块B,根据先验信息

其中,

其中,

S13、PS依照如下公式得到估计的聚合梯度

其中,

式中,η为学习率。

S14、服务器将更新后的全局模型通过无差错广播信道回传给每个设备。

S15、若达到通信轮次t>T则结束,否则转S3。

本发明的改进可作如下总结:首先,本发明提出了一种新颖的空中计算联邦学习本地梯度上传方案SCoM,利用了MIMO多路复用技术和压缩编码技术组成。所提出的方案在达到相同的学习准确率时显著减少了上传梯度的通信开销。其次,本发明开发了一种基于交替优化(AO)和交替方向乘数法(ADMM)的低复杂度算法来优化预编码和后处理矩阵,从而避免了现有方案中信道反转而导致梯度聚合误差显著增大的问题。最后,本发明给出了最小化梯度聚合误差的最优多路复用数据流数,即发端和收端天线数的最小值。

附图说明

图1:系统模型

图2:SCoM方案中的Turbo-CS算法示意图

图3:SCoM方案中设备上传梯度的流程图

图4:仿真设备分布示意图

图5:SCoM方案中使用学习准确率随不同复用数据流数变化的曲线

图6:使用不同传输方案达到相同学习准确率时信道使用次数的曲线

具体实施方式

下面结合附图和实施例,对本发明的具体实施方式作进一步详细描述。

具体方法的参数设置如下:

考虑一个由20个设备和一个中心服务器组成的FL系统。设备在以基站为中心、半径为100m的圆内均匀分布,如图4所示。基站高度为10米。每个任务数据集大小为60000个样本,每个设备上有3000个样本。本发明的实验训练了两个FL任务,分别基于MNIST和FMNIST两个数据集。每个数据集有两种数据分布,分别为1)独立同分布(i.i.d.),其中所有的数据都被打乱,然后平均分配给20个设备;2)非独立同分布(non-i.i.d.),其中每个设备随机选择4个类别,然后从每个所选类别中随机抽取750个样本。FL任务的模型由一个2层卷积层(每层包含5x5卷积核、2x2最大池化、ReLU激活函数和batchnorm层),1层全连接层和1层softmax输出层组成。学习率设置为0.001。训练通信轮次设置为T=500.

根据以上参数设置,该仿真的具体步骤如下:

S1、如图1所示,OA-FL系统由1个PS和M个边缘设备组成,全局损失函数定义为

式中,

式中,f(θ;ξ

S2、服务器(PS)生成压缩矩阵

S3、服务器与设备进行信道信息(CSI)的估计,假设在每一通信轮次中梯度上传时信道保持不变,并且服务器拥有全局的信道信息矩阵

S4、服务器通过交替优化方法设计发端预编码矩阵

其中,,m和m′分别指示设备m和设备m′,q

S5、收端后处理矩阵F

其中,I为单位阵。

S6、每个发端预编码矩阵

P

V

其中,

S7、通过迭代S5和S6,直至S4中优化问题的目标函数值收敛。此时得到最优的

S8、各个设备在本地进行梯度下降,计算局部梯度

S9、如图2所示,给SCoM在一个通信轮次内本地梯度上传到PS的信号流图。设备使用优化后的预编码矩阵将本地梯度上传。轮次t内,边缘设备将梯度映射为复数版本,如以下公式给出

式中,

式中,

其中,γ∈[0,1]表示稀疏度。sp(·)保留累计梯度

随后设备基于稀疏梯度

其中,⊙为逐元素乘积,

其中,C为压缩后梯度的长度,

S10、为传输多流数据,设备将压缩梯度

其中,N

设备将

S11、在PS端,采用后处理矩阵F

PS对处理后的矩阵

其中,vec(·)为向量化操作,

S12、如图3所示,服务器采用TurboCS算法求解S11中压缩感知问题。TurboCS算法迭代求解聚合后的梯度

其中,

其中,

然后,在模块B,根据先验信息

其中,

/>

其中,

S13、PS依照如下公式得到估计的聚合梯度

其中,

式中,η为学习率。

S14、服务器将更新后的全局模型通过无差错广播信道回传给每个设备。

S15、若达到通信轮次t>T则结束,否则转S3。

在图5中,研究了多路复用数据流的数量N

在图6中,本发明展示了多种传输方案所需的信道使用总数与相对学习准确率的关系。如图6所示,在四种数据分布中,本发明所提出的算法在达到相同的学习准确率时消耗了最少的通信开销,并且明显优于所有基线,这清楚地证明了本发明提出的方案的优越性。

相关技术
  • MIMO干扰信道下空中计算多任务联邦学习方法
  • MIMO干扰信道下空中计算多任务联邦学习方法
技术分类

06120116331490