掌桥专利:专业的专利平台
掌桥专利
首页

基于注意力指导轻量级网络的人脸关键点检测方法

文献发布时间:2023-06-19 19:28:50


基于注意力指导轻量级网络的人脸关键点检测方法

技术领域

本发明涉及图像识别领域,特别涉及一种基于注意力指导轻量级网络的人脸关键点检测方法。

背景技术

人脸关键点检测在计算机视觉中是非常关键并且重要的任务,尤其在人脸特效、人脸生成,人脸渲染等领域有广泛的应用。其检测任务包括人脸的中心和周围、眉毛、眼睛、鼻子、嘴巴、下巴等关键区域特征点。预测出人脸在图像中的坐标点,可以为人脸识别、人脸情绪识别、人脸姿态估计、美颜等应用提供关键性的依赖。因此,准确快速地检测到人脸关键点受到研究人员的关注。人脸关键点算法应用最广、精度最高的是基于深度学习的方法。早期,研究人员采用CNN的方式来获取关键点,但是效果较差。为改善检测精度和速度,优化提出了级联回归方法,采用逐步获取目标的方式,多次通过不同CNN进行特征提取用于解决局部最优问题,从而获得较精确的关键点检测。同时,不断有新的大型骨干网络在人脸关键点检测任务中应用,如沙漏网络提出同时使用多层特征、残差网络(ResNet152、ResNet101、ResNet50)和Densenet网络采用更深结构和提取方式以此提高CNN提取特征的能力。但是这些方法过于臃肿,在检测实践任务中效率较低。SimplePose网络针对检测效率偏低问题进行优化,这是一个非常轻量级的CNN检测网络,去掉复杂的级联过程和复杂的网络结构,通过ResNet和反卷积结构生成高分辨率特征图。ShuffleNet从网络结构入手,提出逐点组卷积(pointwise group convulution)和通道混洗(channel shuffle)在保障精确率损失不大的同时可以大大减少了计算成本,这些轻量级的模型是在牺牲了部分精度的情况下提高计算速率。因此,如何在保留精度的同时提高效率是一个需要权衡的问题。

发明内容

为了克服目前人脸关键点检测所使用的模型存在精度和轻量化不可兼得的技术问题,本发明提供一种基于注意力指导轻量级网络的人脸关键点检测方法。

为了实现上述技术目的,本发明的技术方案是,

一种基于注意力指导轻量级网络的人脸关键点检测方法,包括以下步骤:

步骤一,构建由教师网络和学生网络组成的训练模型;

其中教师网络包括由多个bottleneck块和CBAM注意力模块组成的编码器,以及由多层反卷积层组成的解码器;

学生网络包括由多个bottleneck块组成的编码器,以及由多层反卷积层组成的解码器;

步骤二,将用于训练的人脸图像输入至教师网络中,并基于教师网络的损失函数对教师网络进行循环迭代训练,直至达到终止训练条件;然后再将用于训练的人脸图像分别输入至学生网络和训练完成的教师网络中,并基于学生网络的损失函数来对学生网络进行循环迭代训练,直至达到终止训练条件;

步骤三,将需要进行人脸关键点检测的图像输入训练完毕后的学生网络中,从而获得人脸关键点检测结果。

所述的方法,所述的步骤一中,教师网络和学生网络中的bottleneck块包括用于将输入分为两个分支的channel split单元,且其中一个分支经第一1×1卷积、深度可分离卷积和第二1×1卷积后,与另一个不经处理的分支输入Channel Shuffle单元处理后输出特征图F∈R

所述的方法,所述的步骤一中,教师网络中的CBAM注意力模块的输入为F∈R

F

F

其中,

所述的方法,所述的步骤一中,教师网络中的反卷积层包括256个2×2的卷积核,并在最后设有一个1×1的卷积核;以通过扩大输入图像的尺寸、旋转卷积核和正向卷积来输出预测点数。

所述的方法,所述的步骤二中,教师网络的损失函数为

L

学生网络的损失函数为:

其中L

所述的方法,孤立点损失函数L

其中w是将非线性部分的范围限制在[-w,w]区间内;∈是约束非线性区域的曲率;C=w-wln(1+x/∈),为一个常数,以平滑连接分段的线性和非线性部分。

所述的方法,检测点中心点损失函数L

所述的方法,蒸馏损失函数是基于逐步像素损失函数

其中

本发明的技术效果在于,本文采用优化的深度残差结构作为教师主干网络,使用注意力机制、逐点组卷积(pointwise group convulution)和通道混洗(channel shuffle)在保障精确率损失不多的同时可以大大减少了计算成本,这些轻量级的模型是在牺牲了部分精度的情况下提高计算速率,然后进一步通过复杂但准确性高的教师网络来指导精简的学生网络进行训练,从而得到计算量少,参数少的较精确的网络模型。

附图说明

图1为本发明训练模型的总体结构图;

图2为bottleneck块结构图,其中(a)为现有bottleneck块结构示意图,(b)为本发明中bottleneck块结构示意图。

具体实施方式

本实施例采用主流的300W和WFLW数据集进行实验。其中300W数据集由HELEN、LFPW、AFWIBUG数据集组成,广泛应用于人脸关键点检测任务,其中HELEN、LFPW和AFW数据集的图像均在自然环境中采集,存在姿态变化、表情变化和部分遮挡的情况,更适用于多变的自然环境。在300W数据集中,每张人脸图像有68个标注的人脸关键点。

模型训练使用PyTorch框架,Adam优化器。学习率设置为0.002,体重衰减为0.1。在Nvidia 3090GPU上训练需要约10小时。

参见图1,本实施例中的教师网络由编码器与解码器构成。其中编码器以ResNet101结构为主体,去掉原ResNet101后面的全局平均池化和全连接层,仅保留用于特征提取的卷积结构。针对ResNet101模型在实际检测点的效果和效率问题,引入优化后的CBAM模块,加强空间和通道的信息交互,增加主体特征的注意力,使重要特征获得更高的权重表达。同时,对ResNet101和CBAM(Convolutional Block Attention Module)结构采用shuffle操作和分组、深度卷积等优化操作实现跨通道的交互和信息整合,增加网络的非线性,进一步提高网络的特征拟合能力,提高网络的表达能力并且减少整体的计算量,并使用通道数随机截取提高模型泛化能力,避免过拟合。编码器主要用于提取人脸特征。解码器包括三层反卷积层。反卷积主要应用在深度学习的计算机视觉领域,广泛用于特征图的上采样阶段,由于输入图像在通过卷积神经网络(CNN)提取特征后出于计算量考虑,输出特征图尺寸通常会变小,当任务需要将图像恢复到原来的尺寸进行计算时,实现图像由小分辨率图到大分辨率图的映射的过程。反卷积操作(Transposed Convolution)就是上采样中常见方法之一。

编码器主要由33个bottleneck块加CBAM注意力模块构成,其中图2(a)的bottleneck块是现有ResNet的核心思想残差块,设计为恒等映射的结构(1):

H(x)=F(x)+x (1)

残差结构从原来的找到输入至输出的映射变成找到输出减输入的映射,从而至少确保模型不会因为加大深度导致的退化问题。ResNet进行反向传播时,只求解链式法则前面的部分,残差支路的梯度始终为1,从而解决梯度消失的问题,进一步避免了模型训练后期由于反向传播梯度不稳定造成的精度下降。图2(b)是本发明使用的bottleneck块。相较于传统bottleneck块,加入channel split和shuffle操作,以及替换常规卷积层,替换为深度可分离卷积。以此达到随机跨通道信息交互,增强非线性的同时减少参数量。其中图2(a)参数计算量计算公式如(2)所示

M=33×D

图2(b)参数计算量计算公式如(3)所示,其中D

M=33×D

同时,加入卷积块注意力模块CBAM是一种结合空间和通道的注意力机制。在分类和检测模型上得到广泛使用,其优势在于能够序列化地在通道和空间两个维度上产生注意力特征图权重信息,然后两种特征图权重信息在与输入特征图进行相乘进行自适应特征修正,产生最后的特征图,但是缺乏通道与通道之间的交互,导致通道之间的非线性关系缺失,限制了模型的学习能力,过于关注特征表达从而导致过拟合。为解决CBAM存在的这些不足,加入channel split和shuffle操作以及替换传统卷积块为深度可分离卷积、多层1×1卷积对特征降维处理,实现跨通道的交互和信息整合,增加网络的非线性能力,提高网络的表达能力并且减少计算量,进行通道数随机截取以此提高泛化能力,最后输出的是特征注意力图A。将从ResNet101网络中提取图像I的特征图F∈R

F

F

其中,A表示特征注意力图;C、H、W分别表示特征图的通道数、高度、宽度;

解码器主要包括三层反卷积层,反卷积是一种特殊的正向卷积,先按照一定的比例通过补充来扩大输入图像的尺寸,接着旋转卷积核,再进行正向卷积。反卷积可以将低分辨率特征恢复到高分辨率特征,每层反卷积层有256个卷积核,每个卷积核的大小为2×2,步长为2。因此单次反卷积计算可以将特征注意力图A放大4倍。最后加入1×1卷积核,输出预测点数。

本实施例中的Student网络的解码器与Teacher网络一致,编码器在设计上主要考虑了效果和效率问题,主体采用Resnet50结构,包含16个bottleneck块,去掉CBAM。大大的减少了参数计算量,提高检测效率。

下面对训练中使用的损失函数作出说明。

孤立点损失函数

在人脸关键点检测任务中,不同位置的关键点回归难度不同,在训练开始阶段,所有的点误差都非常大,当训练到中后期时,大部分检测点都已经基本准确,但仍不足以满足需要,想让它回归结果更加准确,现有损失已经无法继续优化的情况下,就必须放大它的损失,而wing loss中采用对数损失能够满足孤立检测点问题。Wing loss采用分段函数方式,保证前中期大部分检测点损失的减少,同时满足后期部分孤立检测点的训练,使其不会影响到其他检测点的训练。计算公式如(10)所示

其中w是将非线性部分的范围限制在[-w,w]区间内;∈是约束非线性区域的曲率,且C=w-wln(1+x/∈)是一个常数,可与平滑的来连接分段的线性和非线性部分。∈的取值应为较小的数值,因为它可能导致网络训练变得不稳定,即可能会因为很小的误差导致梯度爆炸问题。在本实施例中,将wing loss的参数设置为w=10和∈=2。

检测点中心点损失函数

在检测任务过程中,通过损失函数计算预测结果与实际数据的差异程度来约束模型收敛,因此,损失函数至关重要。在检测任务中的检测点集合必然存在一个中心点,通过中心点的运用可以增加模型的稳定性。因此加入中心损失促进模型学习到的特征具有更好的泛化性和检测能力,通过惩罚每个预测检测点和实际检测中心的偏移,使得同一检测点的数据尽量聚合在一起。为了解决局部优化过拟合这一问题,增强模型的泛化能力,对属于同一检测点中心的特征的方差进行惩罚,即检测点中心特征P

蒸馏损失函数

教师网络学习到的知识迁移到学生网络中,将关键点检测问题看作一系列像素分类问题的集合,加入逐步像素损失函数,将教师网络的输出

/>

其中

故综合上述损失函数,本实施例中的教师网络的损失函数如公式(13)所示,学生网络的损失函数公式如(14)所示

L

。/>

相关技术
  • 用于人脸关键点网络检测模型的训练方法、人脸关键点检测方法、装置
  • 基于卷积网络的轻量级人脸关键点检测方法、系统及存储介质
技术分类

06120115924662