掌桥专利:专业的专利平台
掌桥专利
首页

一种基于互学习知识蒸馏的ResNet模型优化算法

文献发布时间:2024-04-18 20:02:18


一种基于互学习知识蒸馏的ResNet模型优化算法

技术领域

本发明涉及计算机视觉技术领域,具体为一种基于互学习知识蒸馏的ResNet模型优化算法。

背景技术

随着卷积神经网络深度的不断加深,残差网络在图像分类任务中获得了较高的分类准确率,但在相同的网络速度的条件下,ResNet网络模型分类精度仍有一定的提升空间。

1.现有的技术方案

在传统知识蒸馏算法中,遵照“教师-学生”的训练模式。由教师网络T和学生网络S两部分组成。教师网络负责传授知识,学生网络的任务是尽可能多地学习知识。

2.存在的缺点

(1)随着网络深度的加深,ResNet模型的精度提升越来越困难。

(2)传统知识蒸馏中当学生网络太小时,很难使教师网络蒸馏成功。

(3)传统知识蒸馏通常不能满足一些多任务、多领域学习场景的需求,其原因在于多个任务/领域模型都需要参与训练过程,并且从其余任务/领域中学习有利知识,这就要求多个模型必须同时在线参与蒸馏学习。

(4)传统知识蒸馏不能保证教师模型与学生模型的学习过程相匹配,也不能根据学生模型的学习状态实时调整教师模型的知识提炼过程,如果训练完备的教师模型和学生模型的预测性能差距很大,则会影响学生模型在初始阶段的学习。

发明内容

本发明提供了一种可增强ResNet网络的鲁棒性,并通过两个学生网络互相学习,从而提升整体分类精度的基于互学习知识蒸馏的ResNet模型优化算法,来解决上述现有技术中存在的问题。

为实现上述目的,本发明提供如下技术方案:一种基于互学习知识蒸馏的ResNet模型优化算法,包括以下步骤:

步骤1、输入图像训练集以及标签集到原始ResNet网络,并引入ACNet,即利用ACB模块优化原始的ResNet网络,构成AC-ResNet学生网络模型;

步骤2、并在相同的条件下,初始化学生网络S1和学生网络S2,并让学生网络S1和学生网络S2相互学习;

步骤3、从训练集中随机抽取数据,分别计算出学生网络S1和学生网络S2的预测概率;

步骤4、采用KL散度来衡量两个学生网络的损失函数,并通过学生网络S1和学生网络S2的损失函数定义公式对学生网络S1和学生网络S2进行更新:

步骤5、然后对学生网络S1和学生网络S2的预测概率进行更新;

步骤6、重复以上步骤,直到学生网络S1和学生网络S2收敛,最后输出标签。

优选的,所述步骤1中,引入ACNet通过用非对称卷积ACB模块优化原始ResNet网络中的标准3×3卷积核,对新构成的网络参数进行训练直至收敛,具体包括:

1)ACB模块采用3×3,1×3和3×1三个卷积块并行操作替换ResNet网络中的每一个残差块3×3卷积,并训练到网络收敛;

2)训练完成之后,将ACB模块中的非对称卷积核添加进ResNet网络中3×3卷积相应的位置上,从而与ResNet保持相同的网络结构。

优选的,其中ACB等效融合为标准方形核包括BN融合过程和分支融合过程:

1)BN融合:为减少过拟合,加快网络的训练速度,对于非对称卷积块ACB,当其中3×3,1×3和3×1的卷积块进行普通卷积操作之后,再进行BN操作:

其中上式表示,在尺寸U×V,通道数为C的输入特征图M∈R

BN操作后,利用二维卷积核之间的可加性原理,如下式中,将尺寸大小不一的卷积核融合转换,从而产生一个具有相同输出的等效卷积核;

其中,I代表一个被裁剪或者填充的矩阵;K

2)分支融合:非对称卷积核添加到方形核对应的位置之后,将三个BN融合分支合并,形成一个新的标准方形核,达到与原始ResNet网络相同的输出。

优选的,对于第j个卷积核,F'

其中,O

优选的,所述步骤3中,具体步骤包括在第i个样本数据上,两个学生网络S1和S2的预测输出分别表示为

优选的,其中,为了衡量两个学生网络S1和S2预测概率

同样的,S2对S1的KL损失表示为:

其中,

优选的,在每个学生网络在训练过程中不仅模仿真实标签,还拟合另一个学生网络的预测输出P

S1网络的损失函数定义为:

同理,S2网络的损失函数定义为:

其中,λ为调节两个损失函数的超参数,y

与现有技术相比,本发明的有益效果:

1、本发明中,优化了ResNet网络结构,把ACB模块嵌入到ResNet网络中,来优化ResNet中的残差结构;用非对称卷积ACB模块优化原始ResNet网络中的标准3×3卷积核,对新网络参数进行训练直至收敛。从而更好地获取输入图像的具体特征,有效的增强了ResNet网络的鲁棒性,使得ResNet在图像分类中取得更高的分类精度。

2、本发明中,在借鉴传统的知识蒸馏的基础上,转变了教师—学生训练架构,引入了基于互学习知识蒸馏的算法对ResNet模型进行精度上的优化,选择交叉熵损失函数和KL损失作为ResNet图像分类中的损失函数,让两个学生网络在没有教师网络的条件下自己学习,利用其结构之间的差异性互相指导,优势互补,提升了模型的分类精度。

附图说明

附图用来提供对本发明的进一步理解,并且构成说明书的一部分,与本发明的实施例一起用于解释本发明,并不构成对本发明的限制。

在附图中:

图1是传统知识蒸馏流程示意图;

图2是本发明互学习框架的示意图;

图3是本发明互学习知识蒸馏算法的示意图。

具体实施方式

以下结合附图对本发明的优选实施例进行说明,应当理解,此处所描述的优选实施例仅用于说明和解释本发明,并不用于限定本发明。

如图1所示,在传统知识蒸馏算法中,遵照“教师-学生”的训练模式,由教师网络T和学生网络S两部分组成。教师网络负责传授知识,学生网络的任务是尽可能多地学习知识,知识蒸馏过程主要分为两步:

1)训练原始网络,即教师网络T。对于给定的数据集Dataset,首先对教师网络T进行训练,经过softmax层之后得到每一个样本数据X

2)训练小网络,即学生网络S。目的是让学生网络S学习到教师网络T强大的泛化能力,尽可能的将学生网络的结果接近教师网络。学生网络S的对应软标签为:

那么传统知识蒸馏方法中,学生网络的损失函数定义如下:

L

其中,L

而本发明中则在基干网络中引入了ACNet;如图2-图3所示,一种基于互学习知识蒸馏的ResNet模型优化算法,包括以下步骤:

步骤1、输入图像训练集train以及标签集y到原始ResNet网络,并引入ACNet,即利用ACB模块优化原始的ResNet网络,构成AC-ResNet学生网络模型;

引入ACNet通过用非对称卷积ACB模块优化原始ResNet网络中的标准3×3卷积核,对新构成的网络参数进行训练直至收敛,具体包括:

1)ACB模块采用3×3,1×3和3×1三个卷积块并行操作替换ResNet网络中的每一个残差块3×3卷积,并训练到网络收敛;

2)训练完成之后,将ACB模块中的非对称卷积核添加进ResNet网络中3×3卷积相应的位置上,从而与ResNet保持相同的网络结构。

其中ACB等效融合为标准方形核包括BN融合过程和分支融合过程:

1)BN融合:为减少过拟合,加快网络的训练速度,对于非对称卷积块ACB,当其中3×3,1×3和3×1的卷积块进行普通卷积操作之后,再进行BN操作:

其中上式表示,在尺寸U×V,通道数为C的输入特征图M∈R

BN操作后,利用二维卷积核之间的可加性原理,如下式中,将尺寸大小不一的卷积核融合转换,从而产生一个具有相同输出的等效卷积核;

其中,I代表一个被裁剪或者填充的矩阵;K

2)分支融合:非对称卷积核添加到方形核对应的位置之后,将三个BN融合分支合并,形成一个新的标准方形核,达到与原始ResNet网络相同的输出;

对于第j个卷积核,F'

其中,O

步骤2、并在相同的条件下,初始化学生网络S1和学生网络S2,并让学生网络S1和学生网络S2相互学习;

步骤3、从训练集train中随机抽取数据x,分别计算出学生网络S1和学生网络S2的预测概率;

具体步骤包括在第i个样本数据上,两个学生网络S1和S2的预测输出分别表示为

其中,为了衡量两个学生网络S1和S2预测概率

同样的,S2对S1的KL损失表示为:

其中,

步骤4、采用KL散度来衡量两个学生网络的损失函数,并通过学生网络S1和学生网络S2的损失函数定义公式对学生网络S1和学生网络S2进行更新:

在每个学生网络在训练过程中不仅模仿真实标签,还拟合另一个学生网络的预测输出P

S1网络的损失函数定义为:

同理,S2网络的损失函数定义为:

其中,λ为调节两个损失函数的超参数,y

步骤5、然后对学生网络S1和学生网络S2的预测概率进行更新;

步骤6、重复以上步骤,直到学生网络S1和学生网络S2收敛,最后输出标签。

其中,本发明还可以通过基于注意力机制的知识蒸馏方法,或者利用GhostNet模块增强卷积结构的特征提取能力,去除模型的冗余信息;利用激活注意力蒸馏的方法强化学生网络的特征提取能力,改善学生网络的分类精度,从而得到最终的分类效果。

最后应说明的是:以上所述仅为本发明的优选实例而已,并不用于限制本发明,尽管参照前述实施例对本发明进行了详细的说明,对于本领域的技术人员来说,其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

技术分类

06120116576034