掌桥专利:专业的专利平台
掌桥专利
首页

一种基于解耦特征和对抗特征的知识蒸馏方法

文献发布时间:2023-06-19 19:27:02


一种基于解耦特征和对抗特征的知识蒸馏方法

技术领域

本发明属于目标检测技术领域,具体涉及一种基于解耦特征和对抗特征的知识蒸馏方法。

背景技术

目前在电力巡检领域,通过无人机巡检的方式能够减少运检人员登杆检查操作的工作量,能够快速且准确地判断缺陷情况,无人机巡检已成为输电线路常态化的巡检方式。对于无人机拍摄的输电线路巡检图像,传统人工研判缺陷的方式难以适应电网发展及体制变革要求,探求一种可提高缺陷检测效率和准确性的输电线路无人机巡检图像处理方法是破解智能运检发展难题的必由之路。将深度学习技术引入到输电线路航拍图像处理中,并充分挖掘大数据中输电线路关键部件的先验专业知识,具有重要的实用价值,能够从技术层面上保障电网的安全稳定运行,提高输电线路巡检的效率。然而在实际环境中,用无人机采集的电力巡检图像分辨率较高,而传统的深度学习目标检测网络为了保证检测准确率,具有大量的权值参数,模型占用空间也随之变大,这就对现有的硬件设备提出了更高的要求。研究高效率的网络模型可以进一步精简网络,去除冗余的结构和参数,同时也可以进一步提升网络性能,加速深度网络的产业化发展。为减少网络计算量,提高目标检测效率,需要为无人机巡检缺陷智能识别设计网络参数较少的检测网络,同时不损失网络性能。

在基于深度学习的目标检测中,往往采用模型压缩的方法来提高模型速度。知识蒸馏技术是一种出色的模型压缩技术,可以把大型网络压缩成小型网络,在训练阶段训练一个高精度的大型网络,部署阶段利用大型网络蒸馏出的小型网络进行部署,计算代价小,但是网络精度可以媲美大型网络。

公开号为CN112200062B的中国发明专利,公开了一种基于神经网络的目标检测方法、装置、机器可读介质及设备,其提出一种基于神经网络的目标检测方法,包括:构建教师网络;通过样本图像集训练所述教师网络;构建学生网络,其中所述学生网络的参数量小于所述教师网络的参数量;在采用知识蒸馏提取所述教师网络训练获得的知识并迁移到所述学生网络的过程中,通过样本图像集对所述学生网络进行训练;通过训练后的学生网络,对输入的图像进行目标检测。该发明虽然通过知识蒸馏简化了用于目标检测的神经网络,但是其知识蒸馏仅正对目标区域进行蒸馏,这会导致简化后的学生网络对目标特征和背景特征的区分度不够,当目标特征和背景特征中的部分特征信息较接近时,不能做到有效区分;而且对待检测图像的的全局特征把握也不够,检测过程中可能会漏掉有用信息,这些都会降低目标检测的精度。

发明内容

本发明的目的是为了解决背景技术中提及的问题,提供一种基于解耦特征和对抗特征的知识蒸馏方法,在模型中同时并有侧重的使用图片中的目标区域和背景区域的信息,帮助模型蒸馏学到更多有用的信息,同时,使用对抗网络使得简单网络学习复杂网络输出的特征图值的分布尽量趋近一致,以此达到在提高缺陷检测速度的同时,保证检测精度的效果。

为实现上述技术目的,本发明采取的技术方案为:

一种基于解耦特征和对抗特征的知识蒸馏方法,包括以下步骤:

S1、构建教师网络和学生网络;使用样本图像对教师网络训练,直至教师网络达到设定要求;

S2、将样本图像输入步骤S1训练完成的教师网络,以及学生网络,得到两个网络的骨干网络输出的第一特征图;

S3、将步骤S2得到的第一特征图作为解耦特征模块的输入,根据解耦特征模块的结果进行梯度回传更新学生网络参数;所述解耦特征模块用于让学生网络同时分别学习教师网络的目标特征信息和背景特征信息;

S4、将步骤S2得到的第一特征图作为对抗特征模块的输入,训练对抗特征模块,同时根据对抗特征模块结果进行梯度回传更新学生网络参数;所述对抗特征模块用于让学生网络学习教师网络输出的特征图的全局分布特性;

S5、将样本图像输入步骤S1训练完成的教师网络和步骤S4训练后的学生网络,将两个网络的骨干网络输出的第一特征图输入到各自的区域建议网络,得到包含分类和回归的候选框的第二特征图;将上述第二特征图作为候选框特征学习模块的输入,根据候选框特征学习模块的结果进行梯度回传并更新学生网络的参数;所述候选框特征学习模块用于让学生网络学习教师网络候选框的特征信息;

S6、检测头网络根据第二特征图中的候选框截取的信息,做进一步的分类和回归,得到最终的检测结果;并根据结果进行梯度回传更新学生网络的参数;

S7、重复步骤S2-S6,直至学生网络达到设定要求。

作为优选,所述教师网络和学生网络都采用目标检测网络Faster RCNN,其中教师网络使用骨干网络为ResNet101的Faster RCNN,学生网络使用骨干网络为ResNet18的Faster RCNN。

作为优选,所述样本图像在输入教师网络和学生网络前,短边缩放至设定长度,长边按照原图的长宽比进行相应的缩放;所述梯度回传采用随机梯度下降法。

作为优选,所述教师网络和学生网络对输入的样本图像具有相同的下采样倍数;所述学生网络最后一层设置为1*1卷积层用以调整其输出特征图的通道数等于教师网络输出特征图的通道数。

作为优选,所述样本图像人工标注有若干标注框,根据标注框为教师网络和学生网络输出的第一特征图设置一个二值掩码M,M取值为1或0,当M取值为1表示第一特征图的该区域包含的是目标特征信息;当M取值为0表示第一特征图的该区域包含的是背景特征信息;所述解耦特征模块的输出结果为解耦特征损失函数L

其中,F

作为优选,所述对抗特征模块包括鉴别器模型,鉴别器模型的输出结果是输入特征图鉴别为学生网络输出的概率P

其中,Y

作为优选,所述对抗特征模块在每次训练鉴别器模块之后,计算学生网络的骨干网络生成损失函数L

上式即为对抗特征模块的输出结果。

作为优选,所述学生网络的骨干网络的蒸馏损失函数L

L

其中,超参数λ

作为优选,所述候选框特征学习模块通过下采样,将输入的第二特征图中的候选框对应区域转变为α*α大小的第三特征图,然后根据如下公式计算学生网络的候选框特征的蒸馏损失函数L

上式即为候选框特征学习模块的输出结果;其中,N′表示教师网络或学生网络的区域建议网络输出的第二特征图像包含的候选框总数量,两者的候选框总数量相等;C′表示教师网络对应的第三特征图F

作为优选,所述步骤S7根据以下损失函数L

L

其中,λ

L

其中L

步骤S2-S6的训练过程中交替地优化学生网络中骨干网络的蒸馏损失函数L

本发明的有益效果是:

1、通过解耦特征模型提高了学生网络对目标特征信息和背景特征信息的识别能力,通过对抗特征模型提高了学生网络对全局特征信息的识别能力,这使得训练后的学生网络在简化的同时尽可能的保留了教师网络的检测性能,既能有效区分目标特征信息和背景特征信息,还能提取出尽可能多的全局特征信息,这些都保证的学生网络的检测精度,从而使得学生网络既能做到快速检测,又能保证检测结果的高精度。

2、候选框特征学习模块对学生网络的区域建议网络进行了训练,使得学生网络对候选框的分类和回归结果更加逼近教师网络,也即使得学生网络在保证高精度的识别待检测目标的各个特征时,还能准确的对候选框实现分类和回归,从而使得学生网络能够快速且高精度的实现对目标的检测。

附图说明

图1是本发明整体流程图;

图2是第一阶段学习框图。

具体实施方式

以下结合附图对本发明的实施例作进一步详细描述。

需要注意的是,发明中所引用的如“上”、“下”、“左”、“右”、“前”、“后”等的用语,亦仅为便于叙述的明了,而非用以限定本发明可实施的范围,其相对关系的改变或调整,在无实质变更技术内容下,当亦视为本发明可实施的范畴。

如图1、2所示,本发明提供一种基于解耦特征和对抗特征的知识蒸馏方法,利用知识蒸馏技术对用于复杂的深度学习模型进行压缩,通过捕捉复杂网络的知识,提高轻量化网络的性能。将包含更多知识的、复杂的网络称为教师网络,将需要学习的轻量级网络称为学生网络。主要方法是先用教师网络训练一个包含更多知识的高精度大模型,再用学生网络同时学习大模型中的知识和数据的真实标签,通过这种方法,可以将教师网络的知识转移到学生网络,从而得到一个既有小模型的速度,又有大模型的精度的模型。

本发明具体包括以下步骤:

S1、构建教师网络和学生网络,所述教师网络和学生网络都采用目标检测网络Faster RCNN,具体的,残差网络ResNet101比ResNet18具有更深的深度,能提取更加丰富的特征,所以教师网络使用骨干网络为ResNet101的Faster RCNN,学生网络使用骨干网络为ResNet18的Faster RCNN;在残差网络中,前四层作为特征提取模块,第五层连接两个全连接层用于分类和回归;

使用样本图像对教师网络训练,直至教师网络能精准检测目标信息;需注意样本图像在输入到网络之前(包括教师网络和后续步骤中的学生网络),短边缩放到600,长边按照原图的长宽比进行相应的缩放;网络中采用随机梯度下降法(也即SGD)进行优化,其中动量参数设置为0.9,权重衰减参数设置为0.0005,一次训练所使用的样本数设置为4,学习率设置为0.001。

S2、将样本图像输入步骤S1训练完成的教师网络,以及学生网络,得到两个网络的骨干网络输出的第一特征图,以供步骤S3和S4使用;所述教师网络和学生网络对输入的样本图像具有相同的下采样倍数;所述学生网络最后一层设置为1*1卷积层用以调整其输出特征图的通道数等于教师网络输出特征图的通道数。

S3、将步骤S2得到的第一特征图作为解耦特征模块的输入,根据解耦特征模块的结果进行梯度回传更新学生网络参数;所述解耦特征模块用于让学生网络同时分别学习教师网络的目标特征信息和背景特征信息,也即用于提升学生网络对目标特征信息和背景特征信息的识别能力;

所述样本图像人工标注有若干标注框,标注框标明框内是目前特征信息或背景特征信息;根据标注框为教师网络和学生网络输出的第一特征图设置一个二值掩码M,M取值为1或0,当M取值为1表示第一特征图的该区域包含的是目标特征信息;当M取值为0表示第一特征图的该区域包含的是背景特征信息;所述解耦特征模块的输出结果为解耦特征损失函数L

其中,F

S4、将步骤S2得到的第一特征图作为对抗特征模块的输入,训练对抗特征模块,所述对抗特征模块包括鉴别器模型,训练对抗特征模块就是对鉴别器模型进行训练;所述鉴别器模型的网络,在结构设计上,采用三个步长为2的卷积层进行下采样,卷积核的尺寸分别为3*3*1024、3*3*512、3*3*1,前两个卷积层后面设置一个衰减率为0.2的Leaky ReLU激活函数,为了使最后的输出为0到1的概率值,最后一个卷积层后面设置一个sigmoid激活函数。

鉴别器模型的输出结果是输入特征图(此处即为第一特征图)鉴别为学生网络输出的特征图概率P

其中,Y

同时在每次迭代训练完鉴别模型后,计算学生网络的骨干网络生成损失函数L

根据上式的计算结果(也即对抗特征模块的输出结果)进行梯度回传更新学生网络参数,来使学生网络输出的特征图越来越难以与教师网络输出的特征图区分,从而达到模拟教师网络特征图全局分布特性的目的。

需注意,步骤S3和S4不分先后,可调换顺序或者同时进行。

步骤S2-S4为学生网络的第一阶段学习,主要是对学生网络的骨干网络进行训练,使其尽可能的将教师网络骨干网络的知识迁移过来;

所述学生网络的骨干网络的蒸馏损失函数L

L

其中,超参数λ

S5、本步骤为学生网络的第二阶段学习,学生网络除了学习教师骨干网络的特征,还需要学习教师网络候选框的特征信息,称为候选框特征学习模块,以加强学生网络对于目标特征信息的学习;将样本图像输入到经步骤S4训练后的学生网络,通过其骨干网络得到输出的第一特征图F

候选框特征学习模块得到第三特征图的过程及其对应的输出结果具体如下:

所述候选框特征学习模块将两个网络(学生网络和教师网络)输入的第二特征图中的候选框对应区域,通过下采样得到大小为α*α的第三特征图,然后根据如下公式计算学生网络的候选框特征的蒸馏损失函数L

其中,N′表示教师网络或学生网络第一输出图像的候选框总数量,两者的候选框总数量相等;C′表示特征图F

通过上式的计算结果,进行梯度回传更新学生网络的区域建议网络的参数

S6、检测头网络根据第二特征图中的候选框截取的信息,做进一步的分类和回归,得到最终的检测结果;并根据结果进行梯度回传更新学生网络的检测头网络的参数;

S7、重复步骤S2-S6,直至学生网络达到设定要求:

该步骤根据以下损失函数L

L

其中,λ

L

其中L

在重复步骤S2-S6的训练,实际就是反复交替地优化学生网络中骨干网络的蒸馏损失函数L

经过以上七步训练学习,学生网络即可完成训练。

为证明本发明提出的基于解耦特征和对抗特征的知识蒸馏方法具有可行性和有效性,使用真实拍摄的输电线路无人机巡检缺陷图像进行训练和测试。样本共29124张,检测的类别包含61类,按照7:3的比例划分训练集和测试集,采用mAP值和模型大小作为模型的最终评价。实验结果如下表:

表1基于输电线路无人机巡检缺陷数据集的实验结果

从实验结果可以看出,在没有知识蒸馏的情况下(也即类型一和三),很难兼顾模型的精度和速度,当使用更深的网络ResNet101作为Faster RCNN的骨干网络时,模型的平均精度为53.1%,浮点运算数为7.6G,推理速度为13.2FPS;当我们使用更浅的ResNet18作为Faster RCNN的骨干网络时,模型的平均精度为30.3%,下降了约23%,但是模型的浮点运算数为1.8G,推理速度为76.3FPS,比大模型的推理速度提高近5倍。

当使用本发明提出的知识蒸馏模型时(也即类型二),学生网络模型在保持浮点运算数和推理速度方面的性能不变的情况下,模型的平均精度从30.3%提高到了48.3%,提高了18.4%,接近于教师网络的平均精度,达到了预期的效果,证明了本发明使用的方法的有效性。

以上仅是本发明的优选实施方式,本发明的保护范围并不仅局限于上述实施例,凡属于本发明思路下的技术方案均属于本发明的保护范围。应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理前提下的若干改进和润饰,应视为本发明的保护范围。

技术分类

06120115918595