掌桥专利:专业的专利平台
掌桥专利
首页

一种基于通道剪枝的无人机图像实时目标检测方法

文献发布时间:2023-06-19 11:52:33


一种基于通道剪枝的无人机图像实时目标检测方法

技术领域

本发明属于无人机图像目标识别技术领域,具体是涉及一种基于通道剪枝的无人机图像实时目标检测方法。

背景技术

无人机执行安保防护、集会监控、自然勘探等户外飞行任务时,需要地面站配合实时识别目标来进行监测。部署到具体的硬件资源时要考虑到硬件的资源情况,要进一步提高识别速度才能在性能有限的笔记本设备上实现实时识别。无人机地面站的计算量相对手机等嵌入式设备有剩余,因此我们可以将改进的YOLO网络进行模型压缩来降低参数量,提升识别速度。通道剪枝既可以对各种网络模型实现较好的泛化能力,又不依赖特殊的计算资源,可以对卷积层和全连接层直接操作。

发明内容

发明目的:为了克服现有的目标识别模型对硬件资源的要求高、识别速度慢,难以实时识别无人机所处场景的问题,本发明提供了一种基于通道剪枝的无人机图像实时目标检测方法。

技术方案:本发明所述的一种基于通道剪枝的无人机图像实时目标检测方法,包括以下步骤:

(1)使用改进的YOLO网络对无人机数据集进行基础训练,对基础训练完的网络基于批次标准化层的缩放因子重新进行稀疏化训练,产生稀疏化的缩放因子;

(2)为了残差模块的输入输出特征通道匹配,采取保守剪枝策略和全网络剪枝策略,以BN层的尺度缩放因子作为剪枝通道的选择标准进行通道剪枝;

(3)剪枝后采用知识蒸馏策略对剪枝网络和模型进行微调,使得模型的目标识别精度恢复;

(4)从模型压缩效果和目标识别效果两个维度综合分析模型,得到无人机图像实时多目标识别的最优实现模型。

进一步地,所述步骤(1)包括以下步骤:

(11)为每个通道引入一个缩放因子γ,用缩放因子乘以该通道的输出;

(12)共同训练改进的YOLO网络权重和缩放因子,并对缩放因子进行稀疏正则:基于YOLO算法BN层γ系数的通道剪枝方法的损失函数如下:

L

其中,(x,y)表示训练的输入和目标,W表示用于训练的权重,∑

进一步地,所述步骤(2)实现过程如下:

(21)对有直连操作的残差块进行保守剪枝,即不进行通道剪枝,避免直连层维度不一致;

(22)对一般特征图做通道剪枝操作,最后再对残差块关联的特征图进行剪枝,即进行全网络剪枝;

(23)对直连的特征张量进行通道剪枝,需要将同样位置的通道的γ因子相加再排序;

(24)以缩放因子阈值为依据对此特征图的通道做剪枝。

进一步地,步骤(2)所述的通道剪枝利用掩码来标记卷积层通道,需要进行剪枝的通道掩码为1,保留的通道掩码为0;逐层地对网络层进行剪枝,根据掩码判断是否删除该通道相连的输入、输出、卷积核以及批次标准化层的参数,将待剪枝通道操作完成后生成新的模型参数文件。

进一步地,所述步骤(3)通过以下公式实现:

其中,p表示真实标签的概率分布,z和r代表学生网络和教师网络的预测输出,T是温度超参数,以使softmax分类器的输出更加平滑,从教师网络的输出中提炼出标签分布的知识。

有益效果:与现有技术相比,本发明的有益效果:1、以微调后精度不下降为条件,0.35比例的全网络通道剪枝的模型压缩效果最好,剪枝后参数量减少了2928万,计算量减少26.4BFLOPs,目标识别模型的大小减少了111.74M,网络运行内存减少了0.67G,网络前向推理时间减少3ms;2、得到的最好的剪枝模型的参数量、计算量、模型内存大小、前向推理时间都比保守通道剪枝的要少,在台式电脑的训练环境下可达到33FPS的识别速度,在地面站使用的笔记本电脑环境下可以达到25FPS的识别速度。

附图说明

图1为本发明的流程图;

图2为基于BN层缩放因子的通道剪枝效果图;

图3为稀疏训练中不同平衡因子下所有BN层中尺度缩放因子的数量分布的实验结果;

图4为通道剪枝原理图。

具体实施方式

下面结合附图对本发明作进一步详细说明。

本发明提供一种基于通道剪枝的无人机图像实时目标检测方法,如图1所示,具体包括以下步骤:

步骤1:使用改进的YOLO网络对无人机数据集进行基础训练,对基础训练完的网络基于批次标准化层的缩放因子重新进行稀疏化训练,产生稀疏化的缩放因子。

如图2所示,首先要为每个通道引入一个缩放因子γ,用缩放因子乘以该通道的输出。然后共同训练网络权重和缩放因子,并对缩放因子进行稀疏正则。具体来说,基于BN层γ系数的通道剪枝方法的损失函数可用下式(1)表示:

L

其中,(x,y)表示训练的输入和目标,W表示用于训练的权重,前面的求和项对应卷积神经网络正常训练的损失值;g函数是缩放因子的稀疏性惩罚项,λ是平衡这两项的系数。可以选择L1正则项或者L2正则项作为缩放因子的惩罚项,这两种正则化方法广泛用于实现稀疏性,其中L1正则项为g(γ)=|γ|,L2正则项为g(γ)=γ

步骤2:为了残差模块的输入输出特征通道匹配,采取保守剪枝策略和全网络剪枝策略,以BN层的尺度缩放因子作为剪枝通道的选择标准进行通道剪枝。

如图4所示,YOLO网络结构可以分为主干网络和检测部分,在对主干部分进行通道剪枝时,残差模块是需要特别注意的结构。残差模块中为了解决梯度发散问题引入了直连操作,将两相同维度的特征张量对应逐位相加。如果直接进行通道剪枝,会导致残差模块的输入输出特征图无法匹配通道数,所以对这些特征张量不能直接删除通道,要留下相同位置的通道。第一种策略是对有直连操作的残差块直接不进行通道剪枝,从而避免直连层维度不一致问题,也就是保守剪枝。第二种策略是先对一般特征图做通道剪枝操作,最后再对残差块关联的特征图进行剪枝,称为全网络剪枝。对直连的特征张量进行通道剪枝,需要将同样位置的通道的γ因子相加再排序,最后以缩放因子阈值为依据对此特征图的通道做剪枝。

通道剪枝过程中利用掩码来标记卷积层通道,需要进行剪枝的通道掩码为1,保留的通道掩码为0;逐层地对网络层进行剪枝,根据掩码判断是否删除该通道相连的输入、输出、卷积核以及批次标准化层的参数,将待剪枝通道操作完成后生成新的模型参数文件。

步骤3:剪枝后采用知识蒸馏策略对剪枝网络和模型进行微调,使得模型的目标识别精度恢复。

使用知识蒸馏策略,知识蒸馏适用于模型大小结构相似的网络,在剪枝后的微调使用效果显著。知识蒸馏策略是利用教师网络的特征图等来训练一个紧凑的学生网络。剪枝微调阶段,教师网络是剪枝操作前的预训练模型,剪枝后的网络通过模仿预训练模型不断提高目标识别准确度,同时保持剪枝后模型的复杂度不变。具体实施方法是在训练时增加蒸馏损失,以此惩罚两个网络的softmax分类器输出的不一致。原本使用负交叉熵损失l(p,softmax(z))来度量网络的预测输出和真实标签之间的差异,现在添加蒸馏部分的损失函数,因此知识蒸馏的损失函数变为下式(2)所示:

其中p表示真实标签的概率分布,z和r代表学生网络和教师网络的预测输出,T是温度超参数,以使softmax分类器的输出更加平滑,从教师网络的输出中提炼出标签分布的知识。

步骤4:从模型压缩效果和目标识别效果两个维度综合分析模型,得到无人机图像实时多目标识别的最优实现模型。使用的全网络剪枝策略是以全局阈值为指标找出卷积层的掩码信息,对每组直连操作,将相连的各个卷积层的剪枝掩码求并集,用融合后的掩码来决定剪枝。这种方法对每一层保留通道都进行制约。在执行时对激活偏移值添加处理可以降低剪枝带来的精度损失。分别以0.3、0.35、0.4、0.45、0.5、0.55的比例来进行剪枝,剪枝后的网络和模型的模型压缩指标和目标识别性能指标分别如表1和表2所示。

表1全网络剪枝策略的模型压缩指标实验结果

表2全网络剪枝策略的目标识别指标实验结果

从表1、表2看出以微调后精度不下降为条件,0.35比例的全网络通道剪枝的模型压缩效果最好,剪枝后参数量减少了2928万,计算量减少26.4BFLOPs,目标识别模型的大小减少了111.74M,网络运行内存减少了0.67G,网络前向推理时间减少3ms。

相关技术
  • 一种基于通道剪枝的无人机图像实时目标检测方法
  • 一种基于无人机平台的图像动目标实时检测方法
技术分类

06120113083146