掌桥专利:专业的专利平台
掌桥专利
首页

一种深度卷积网络的轻量化压缩方法

文献发布时间:2023-06-19 19:28:50


一种深度卷积网络的轻量化压缩方法

技术领域

本发明涉及计算机视觉技术领域,特别是涉及一种深度卷积网络的轻量化压缩方法。

背景技术

尽管基于深度学习的网络模型相比传统基于图像处理的方法在检测精度上有着显著优势,但在移动端或芯片中进行部署时,却面临着参数量过大、运算速度过慢等问题。因此,对深度卷积网络模型的轻量化压缩具有重要应用价值。

目前,知识蒸馏是较常用的深度卷积网络压缩方法,其利用结构复杂但预测性能更好的网络作为教师网络,去训练结构简单的学生网络,最终使学生网络达到与教师网络相近的精度与鲁棒性。目前主流的知识蒸馏算法主要面向网络输出与输入相同的教师与学生网络,集中在如何设计目标损失函数,达到从教师网络向学生网络进行知识传递的目的。深度卷积神经网络在面向移动端或芯片端进行实际部署时,受计算能力的约束,除了需要对网络主干部分的权重参数进行压缩,往往也需要对模型输出进行精简以及降低模型输入图像的分辨率。因此,直接利用现有的知识蒸馏技术,无法较好的实现知识转移。

发明内容

本发明所要解决的技术问题是提供一种深度卷积网络的轻量化压缩方法,能够在获得低分辨率输入、精简输出、结构轻量的网络模型的同时保证该网络的检测精度与鲁棒性。

本发明解决其技术问题所采用的技术方案是:提供一种深度卷积网络的轻量化压缩方法,包括以下步骤:

(1)对第一网络模型的输出进行精简,并对第一网络模型的主干网络部分进行剪枝,得到第二网络模型;所述第一网络模型为训练好的深度卷积网络模型;

(2)对所述第二网络模型进行基于知识蒸馏和多分辨图像的微调训练;

(3)对所述第二网络模型进行模型剪枝,并生成第三网络模型,并对所述第三网络模型进行基于特征空间对齐和知识蒸馏的微调训练,并将训练好的第三网络模型作为最终模型。

所述步骤(1)具体包括:

(11)保留所述第一网络模型中负责图像特征提取的主干网络部分和必要的检测输出分支部分;

(12)对所述主干网络部分进行剪枝,删除掉冗余的网络模型参数,得到所述第二网络模型。

所述步骤(2)具体包括:

(21)将所述第一网络模型设置为第一教师网络,将所述第二网络模型设置为第一学生网络;

(22)利用所述第一教师网络对所述第一学生网络进行训练,并在每轮训练时,混合使用带真值的训练集图像和没有真值标注的图像;

(23)计算第一目标损失函数,驱动所述第一学生网络进行反向传播,调整所述第一学生网络的各层权值。

所述步骤(22)中利用所述第一教师网络对所述第一学生网络进行训练时,所述第一教师网络的输入为高为H,宽为W的原始图像;所述第一学生网络的输入为按照预设比例分布的与所述第一教师网络的输入相同的原始图像和经过处理的图像;所述经过处理的图像为由输入所述第一教师网络的图像进行N倍下采样,再通过填充操作生成的高为H、宽为W的图像。

所述步骤(23)中第一目标损失函数E的数学表达式为:

所述步骤(3)具体包括:

(31)对所述第二网络模型进行模型剪枝,生成第三网络模型;

(32)将所述第二网络模型作为第二教师模型,将所述第三网络模型作为第二学生模型;

(33)利用所述第二教师网络对所述第二学生网络进行训练,并在每轮训练时,混合使用带真值的训练集图像和没有真值标注的图像;

(34)在所述第二教师网络和所述第二学生网络中选取对应的特征图,分别利用像素重组机制对第二学生网络的特征图进行上采样处理,对所述第二教师网络的特征图进行下采样处理,得到两组尺寸对应的特征图集合;

(35)针对得到的两组特征图集合中每个特征图,分别计算包含空间和通道上下依赖信息的特征图;

(36)计算第二目标损失函数,驱动所述第二学生网络进行反向传播,调整第二学生网络各层权值。

所述步骤(33)中利用所述第二教师网络对所述第二学生网络进行训练时,所述第二教师网络的输入为按照预设比例分布的原始图像和经过处理的图像;所述经过处理的图像是对原始图像进行N倍下采样,再通过填充操作生成的高为H、宽为W的图像;所述第二学生网络的输入为对原始图进行N倍下采样的图像。

所述步骤(34)具体包括:

从所述第二教师网络的主干网络部分选取K个特征图T={T

利用尺寸为M

利用尺寸为M

分别利用卷积对特征图T

根据特征图

所述步骤(35)具体包括:

将特征图

将特征图

基于特征图

所述步骤(36)中的第二目标损失函数E

有益效果

由于采用了上述的技术方案,本发明与现有技术相比,具有以下的优点和积极效果:本发明利用渐进式的模型剪枝与知识蒸馏技术,可同时对模型的结构、输出、以及输入图像分辨率进行压缩,可获得低分辨率输入、精简输出、结构轻量的网络模型,并同时保证该网络的检测精度与鲁棒性。

附图说明

图1是本发明实施方式的深度卷积网络的轻量化压缩方法的流程图;

图2是本发明实施方式的深度卷积网络的模型压缩示意图。

具体实施方式

下面结合具体实施例,进一步阐述本发明。应理解,这些实施例仅用于说明本发明而不用于限制本发明的范围。此外应理解,在阅读了本发明讲授的内容之后,本领域技术人员可以对本发明作各种改动或修改,这些等价形式同样落于本申请所附权利要求书所限定的范围。

本发明的实施方式涉及一种深度卷积网络的轻量化压缩方法,该方法利用渐进式的模型剪枝与知识蒸馏技术,获得低分辨率输入、精简输出、结构轻量的网络模型,并同时保证该网络的检测精度与鲁棒性。整体流程如图1所示,模型压缩过程如图2所示,具体包括以下主要步骤:

步骤1,基于已训练好的深度卷积网络模型(简称为第一网络模型),对其输出进行精简并对其主干网络部分进行剪枝,得到新的更轻量的网络模型(简称为第二网络模型)。具体子步骤如下:

1)根据在移动端或芯片端应用的实际需求及计算能力,保留负责图像特征提取的主干网络和必要的检测输出分支,对非必要的网络检测输出分支进行剔除。如在手势检测应用中,模型首先通过主干网络对输入图像信息进行特征提取,随后通过不同的检测分支分别对手势目标在图中的二维位置、类别、关键点信息以及在三维世界中的位姿信息等进行检测。模型在移动端或芯片端进行实际部署时,可根据应用需求与计算能力,对关键点信息、位姿信息等输出分支进行剔除,在实际应用中虽失去了更细粒度的信息但仍能保证基本的手势识别功能。

2)随后对第一网络模型的主干网络部分进行剪枝,删除掉冗余的网络模型参数,得到第二网络模型。

步骤2,由于模型剪枝后,一般会造成精度下降,且第一网络模型的输入为大分辨率图像,因此由第一网络模型剪枝后得到的第二网络模型对于低分辨率图像存在兼容性较差的问题。为提高第二网络模型精度以及增强网络对低分辨率图像的兼容性,对由步骤1得到的第二网络模型进行基于知识蒸馏和多分辨图像的微调训练。具体子步骤如下:

1)将第一网络模型设置为教师网络,将第二网络模型设置为学生网络,两个网络的输入图像高均为H、宽均为W。

2)为增加第二网络模型的泛化性与鲁棒性,在知识蒸馏中增加更多训练数据集之外的图像数据,即在每轮训练时,混合使用带真值的训练集图像和没有真值标注的图像。

①第一网络模型的输入为高为H、宽为W的原始图像。

②第二网络模型的输入以一定比例在与第一网络模型输入相同图像和经过处理的图像之间进行选择。其中,经过处理的图像为由输入第一网络模型的图像进行N倍下采样,随后通过填充(padding)操作,生成高为H、宽为W的图像。

3)计算目标损失函数E,驱动第二网络模型进行反向传播,调整网络各层权值。其中,根据网络输入图像的不同对目标损失函数E进行分别加权,目标损失函数E的具体数学表达为:

其中,E

步骤3,对第二网络模型进行模型剪枝,生成第三网络模型,并对第三网络模型进行基于特征空间对齐和知识蒸馏的微调训练。具体步骤如下:

1)对第二网络模型进行模型剪枝,生成第三网络模型。

2)设置第二网络模型作为教师网络,第三网络模型作为学生网络。其中,第二网络模型的输入为高H、宽W的图像,第三网络模型的输入为高H/N、宽W/N的图像。

3)为增加第三网络模型的泛化性与鲁棒性,同样在知识蒸馏中增加更多训练数据集之外的图像数据,即在每轮训练时,混合使用带真值的训练集图像和没有真值标注的图像,对第三网络模型进行知识蒸馏训练。优选的,具体操作为:

①第二网络模型的输入为:在原始图像和经过处理的图像之间进行选择。其中,经过处理的图像为对原始图像进行N倍下采样,随后通过填充(padding)操作,生成高为H、宽为W的图像。

②第三网络模型的输入为:对原始图像进行N倍下采样的图像,即分辨率为(1/N)H×(1/N)W的图像。

4)在第二网络模型和第三网络模型中选取对应特征图,分别利用像素重组机制对第三网络模型特征图进行上采样处理,对第二网络模型特征图进行下采样处理,得到两组尺寸对齐的特征图集合,具体操作步骤如下:

①分别选取第二网络模型中主干网络部分的K个特征图T={T

②利用像素重组机制对从第三网络模型选取的各特征图S

③利用像素重组机制对从第二网络模型选取的各特征图T

④分别利用卷积对特征图T

⑤基于上述步骤,根据特征图

5)基于3)中得到的两个特征图集合中每个特征图,分别计算包含空间和通道上下文依赖信息的特征图。具体操作如下:

①将特征图

②将特征图

③基于特征图

H(x)=softmax(x

6)计算目标损失函数E

其中,

E′

E

其中,δ

不难发现,本发明利用渐进式的模型剪枝与知识蒸馏技术,可同时对模型的结构、输出、以及输入图像分辨率进行压缩,可获得低分辨率输入、精简输出、结构轻量的网络模型,并同时保证该网络的检测精度与鲁棒性。

技术分类

06120115922077