掌桥专利:专业的专利平台
掌桥专利
首页

一种基于知识蒸馏的模型压缩方法、系统及计算机设备

文献发布时间:2024-04-18 19:58:53


一种基于知识蒸馏的模型压缩方法、系统及计算机设备

技术领域

本发明涉及模型压缩领域,尤其涉及一种基于知识蒸馏的模型压缩方法、系统及计算机设备。

背景技术

深度学习网络因对目标的变化具有很好的鲁棒性而受到广泛关注并在许多领域上得到了迅速发展。然而,深度学习的网络经常需要庞大的计算资源,而这些资源在边缘计算和移动计算等场景中是难以实现的。为了使深度学习网络能够高效地部署在资源受限的设备上以满足各种应用场景的需求,研究人员开始考虑研究获得高效网络的方法。其中,知识蒸馏作为一种模型压缩方法,已经成为深度学习领域的研究热点。

知识蒸馏是利用一个复杂的教师网络来指导一个简单的学生网络学习知识的方法。通过知识蒸馏,学生网络可以在保持较小模型大小和较低计算复杂度的同时,提高其在目标任务上的性能。现有的知识蒸馏方法主要集中在三个关键技术上来提高学生网络的性能。第一个技术是将教师网络中的哪些类型的知识传授给学生网络。在知识蒸馏中,知识是一个抽象的概念。网络的梯度、输出和中间特征都可以归类为知识。不同类型的知识包含了不同层次和不同形式的信息,对于学生网络的提升有不同程度和不同方面的影响。因此,如何从教师网络中提取和传递有效和有用的知识是一个重要且具有挑战性的问题。第二个技术是将教师网络在哪些网络层上的知识传授给学生网络。教师网络的网络层数通常比学生网络的多。因此,有必要从教师网络筛选出最具代表性和最适合学生网络学习的网络层,然后提取它们的知识。不同层次上的知识反映了不同抽象程度和不同维度上的特征信息,对于学生网络在目标任务上表现出不同特性和不同能力有着重要作用。因此,如何选择合适的层次来进行知识蒸馏是一个值得研究和探讨的问题。第三个技术是如何将教师网络的知识传授给学生网络。知识蒸馏是通过最小化特征之间或输出之间或梯度之间等等之间某种距离或相似度或相关性等等之间某种指标来实现知识从教师网络迁移到学生网络中。因此,如何设计合适且有效地损失函数来衡量这些指标是一个关键且困难的问题。此外,如何设计合适且有效的蒸馏策略来实现这些指标的优化也是一个重要且复杂的问题。

然而,这些方法都是以教师为中心的知识蒸馏方法,在人类的教育中也被称为知识“填鸭式”教育,即教师直接向学生灌输知识。换言之,以教师为中心的教育方法中的学生只能充当一个观众,但是,这会导致该学生在学习过程中失去了自主性和主动性。同时,以教师为中心的知识蒸馏方法存在以下几个缺点:(1)忽略了学生网络在不同阶段、不同样本和不同任务上可能存在的知识差距和知识遗忘问题,即学生网络需要补充或巩固哪些知识。例如,学生网络在某个时期可能需要重点学习某些样本或某些任务上的知识,而在另一个时期可能需要重点复习某些样本或某些任务上的知识。如果教师网络不能及时发现和满足学生网络的这些需求,那么学生网络就会出现知识缺失或知识遗忘的问题,从而影响其在目标任务上的性能。(2)它们没有充分利用教师网络的容量空间来提高学生网络的性能,即教师网络可以根据学生网络的反馈来调整自己的学习和传授策略。例如,教师网络可以根据学生网络在验证集上的表现来判断哪些知识是学生网络已经掌握了的,哪些知识是学生网络还没有掌握了的,然后根据这些信息来调整自己的输出或特征或梯度等等,以便更好地传递给学生网络所需要的知识。如果教师网络不能根据学生网络的反馈来自适应地改变自己,那么教师网络就会浪费自己的容量空间来传递一些无用或冗余的知识,从而降低了知识蒸馏的效率。(3)它们没有考虑到学生网络在掌握一定知识后可以主动地去探索更具挑战性和丰富性的知识,即学生网络可以根据自己的进步情况来调整自己的学习难度和速度。例如,当学生网络已经掌握了一些基础和简单的知识后,它可以尝试去学习一些更高级和复杂的知识,从而提高自己在目标任务上的泛化能力和鲁棒性。如果教师网络不能给予学生网络足够地挑战和刺激,那么学生网络就会陷入一个舒适区而停止进步,从而限制了其在目标任务上的潜力。因此,以教师为中心的知识蒸馏方法不能有效地模仿人类教育中以学生为中心的教育理念,即教师无法根据学生的需求和特点来设计和实施教学活动,从而无法提高学生的主动性和自主性,无法充分利用教师网络的容量空间来提高学生网络的性能,降低知识蒸馏的效率和模型压缩的准确性。

发明内容

本发明提供了一种基于知识蒸馏的模型压缩方法、系统及计算机设备,实现以学生为中心的知识蒸馏,提高学生网络的性能,提升模型压缩的准确性和压缩效果。

为了解决上述技术问题,本发明实施例提供了一种基于知识蒸馏的模型压缩方法,包括:

基于验证集和各验证样本的知识对当前学生网络的重要性,训练更新教师网络模型;各所述验证样本的知识对当前学生网络的重要性通过P ID参数优化以及模糊策略分析所述学生网络模型在所述验证集上的表现来确定;

基于训练集、更新后的教师网络模型和各训练样本的知识对当前学生网络的重要性,训练更新所述教师网络模型和所述学生网络模型,获得压缩后的学生网络模型;其中,各所述训练样本的知识对当前学生网络的重要性根据各所述训练样本的教师学习差异而得到。

实施本发明实施例,基于验证集和各验证样本的知识对当前学生网络的重要性,训练更新教师网络模型;各所述验证样本的知识对当前学生网络的重要性通过P ID参数优化以及模糊策略分析所述学生网络模型在所述验证集上的表现来确定;基于训练集、更新后的教师网络模型和各训练样本的知识对当前学生网络的重要性,训练更新所述教师网络模型和所述学生网络模型,获得压缩后的学生网络模型;其中,各所述训练样本的知识对当前学生网络的重要性根据各所述训练样本的教师学习差异而得到。与以教师为中心的知识蒸馏方法不同,本发明以学生为中心的知识蒸馏,使得教师网络学习和传授给学生网络的都是学生网络所需要的知识,使得学生网络在教师网络的帮助下可以主动地学习需要的知识,从而提高了知识蒸馏的效果和效率,即提升模型压缩的准确性和压缩效果,能够最大程度地利用教师网络的容量空间来提高学生网络的性能,提高压缩后的学生网络模型的性能。

作为优选方案,所述基于验证集和各验证样本的知识对当前学生网络的重要性,训练更新教师网络模型,具体为:

根据第一训练任务需求,确定若干个第一时期,在每个所述第一时期中,基于所述验证集和各所述验证样本的知识对当前学生网络的重要性,将所述验证集的各所述验证样本输入所述教师网络模型进行训练更新,直至满足第一预设训练结束条件,停止训练所述教师网络模型;

其中,在每个所述第一时期中,所述验证集中的每一个样本有且仅有一次被用于所述教师网络模型的训练;每个所述第一时期中包括若干次迭代。

作为优选方案,所述在每个所述第一时期中,基于所述验证集和各所述验证样本的知识对当前学生网络的重要性,将所述验证集的各所述验证样本输入所述教师网络模型进行训练更新,具体为:

在当前的第一时期对应的当前迭代中,从所述验证集中随机筛选出若干个验证样本,确定各当前的验证样本,将各所述当前的验证样本输入所述教师网络模型,通过交叉熵损失法,计算所述教师网络模型在各所述当前的验证样本下对应的第一损失值,具体为:

其中,L

根据各所述当前的验证样本的知识对当前学生网络的重要性和各所述第一损失值,获得在所述当前的第一时期的当前迭代下所述教师网络模型的第二损失值,公式为:

其中,L为所述教师网络模型的第二损失值,w

基于所述教师网络模型的第二损失值,在所述验证集上将所述教师网络模型的独立训练参数进行梯度下降更新,更新所述教师网络模型,公式为:

其中,θ

作为优选方案,所述各所述验证样本的知识对当前学生网络的重要性通过PID参数优化以及模糊策略分析所述学生网络模型在所述验证集上的表现来确定,具体为:

在所述当前的第一时期中,将各所述验证样本输入所述学生网络模型,通过所述交叉熵损失法,计算在所述当前第一时期下所述学生网络模型的第三损失值;

根据在当前第一时期的前一时期下所述学生网络模型的第三损失值,通过基于CL的模糊策略动态调整PID超参数,得到PID参数;其中,所述PID参数包括比例参数、积分参数和微分参数;

根据在所述当前第一时期下所述学生网络模型的第三损失值和所述优化的PID参数,通过PID控制算法,计算在所述当前第一时期下所述学生网络模型的反馈值,公式为:

其中,u(t)为在所述当前第一时期下所述学生网络模型的反馈值,e(t)为在所述当前第一时期下每次迭代中每个所述验证样本的第三损失值,σ(·)为放缩函数,N

其中,所述放缩函数根据当前的验证样本的累积误差的平均值获得,具体为:

式中,σ(x)为所述放缩函数,x为所述当前的验证样本的累积误差的平均值,a为所述平均值的归一化形式;

根据所述当前第一时期下所述学生网络模型的反馈值,计算所述当前的验证样本的知识对当前学生网络的重要性,具体为:

其中,w

作为优选方案,所述根据在当前第一时期的前一时期下所述学生网络模型的第三损失值,通过基于CL的模糊策略动态调整PID超参数,得到PID参数,具体为:

根据在当前第一时期的前一时期下所述学生网络模型的第三损失值,量化所述学生网络模型在所述当前第一时期的前一时期中的平均表现,公式为:

其中,

根据所述学生网络模型在所述验证集上的最大损失值,划分所述平均表现的隶属函数,具体为:

m=[m

式中,m为所述隶属函数,m

将所述平均表现输入到所述隶属函数中,计算当前的隶属函数值;

根据第一PID超参数,设置所述比例参数的变量范围,将所述比例参数的变量范围的中值作为所述比例参数的初始值,获得初始比例参数,具体为:

kp=[0,Δp,2Δp,3Δp,4Δp,5Δp,6Δp]∈R

K

其中,kp为所述比例参数的变量范围,Δp为所述第一PID超参数,K

根据所述比例参数的变量范围和所述当前的隶属函数值,进行去模糊化计算,获得所述比例参数,具体为:

其中,K

根据所述比例参数和所述初始比例参数,将第二PID超参数和第三PID超参数进行调整,获得所述积分参数和所述微分参数,具体为:

K

K

其中,K

作为优选方案,所述基于训练集、更新后的教师网络模型和各训练样本的知识对当前学生网络的重要性,训练更新所述教师网络模型和所述学生网络模型,获得压缩后的学生网络模型,具体为:

根据所述教师网络模型在所述验证集上最近一次更新前和更新后的模型参数,确定更新前的教师网络模型和所述更新后的教师网络模型;

根据第二训练任务需求,确定若干个第二时期;在每个所述第二时期中,将所述训练集的各所述训练样本输入所述更新前的教师网络模型和所述更新后的教师网络模型,并根据各所述当前的训练样本的教师学习差异,确定各所述训练样本的知识对当前学生网络的重要性,并基于所述训练集和各所述训练样本的知识对当前学生网络的重要性,将所述训练集的各所述训练样本输入所述教师网络模型和所述学生网络模型进行训练更新,直至满足第二预设训练结束条件,停止训练所述教师网络模型和所述学生网络模型,得到训练好的学生网络模型,将所述训练好的学生网络模型作为所述压缩后的学生网络模型。

作为优选方案,所述在每个所述第二时期中,将所述训练集的各所述训练样本输入所述更新前的教师网络模型和所述更新后的教师网络模型,并根据各所述当前的训练样本的教师学习差异,确定各所述训练样本的知识对当前学生网络的重要性,具体为:

在当前的第二时期对应的当前迭代中,从所述训练集中随机筛选出若干个训练样本,确定各当前的训练样本;

将所述训练集的各所述当前的训练样本分别输入更新前的教师网络模型和所述更新后的教师网络模型,通过交叉熵损失法,分别计算所述更新前的教师网络模型在各所述当前的训练样本下对应的第四损失值和所述更新后的教师网络模型在各所述当前的训练样本下对应的第五损失值;

将在各所述当前的训练样本下对应的所述第四损失值与所述第五损失值进行差值计算,获得各所述当前的训练样本的教师学习差异,具体为:

其中,Δl

根据各所述当前的训练样本的教师学习差异,确定各所述当前的训练样本的知识对当前学生网络的重要性,公式为:

式中,

作为优选方案,所述基于所述训练集和各所述训练样本的知识对当前学生网络的重要性,将所述训练集的各所述训练样本输入所述教师网络模型和所述学生网络模型进行训练更新,具体为:

将各所述当前的训练样本输入所述教师网络模型和所述学生网络模型,通过所述交叉熵损失法,得到所述教师网络模型在各所述当前的训练样本下对应的第六损失值,并根据各所述当前的训练样本的知识对当前学生网络的重要性和各所述第六损失值,获得在所述当前的第二时期的当前迭代下所述教师网络模型的第七损失值,公式为:

式中,

基于所述教师网络模型的第七损失值,在所述训练集上将所述教师网络模型的独立训练参数进行梯度下降更新,更新所述教师网络模型,公式为:

其中,θ

将所述训练好的教师网络模型向前传播至所述学生网络模型,并基于KL散度,最小化教师网络和学生网络之间的输出分布,获得各所述训练样本在所述学生网络模型的知识蒸馏的第八损失值;

其中,

根据各所述当前的训练样本的知识对当前学生网络的重要性和各所述训练样本在所述学生网络模型的知识蒸馏的第八损失值,获得在所述当前的第二时期的当前迭代下所述学生网络模型的第九损失值,公式为:

其中,

基于所述学生网络模型的第九损失值,在所述训练集上将所述学生网络模型的独立训练参数进行梯度下降更新,更新所述学生网络模型,公式为:

其中,θ

为了解决相同的技术问题,本发明实施例还提供了一种基于知识蒸馏的模型压缩系统,包括:教师网络训练更新模块和学生网络训练更新模块;

其中,所述教师网络训练更新模块用于基于验证集和各验证样本的知识对当前学生网络的重要性,训练更新教师网络模型;各所述验证样本的知识对当前学生网络的重要性通过PID参数优化以及模糊策略分析所述学生网络模型在所述验证集上的表现来确定;

所述学生网络训练更新模块用于基于训练集、更新后的教师网络模型和各训练样本的知识对当前学生网络的重要性,训练更新所述教师网络模型和所述学生网络模型,获得压缩后的学生网络模型;其中,各所述训练样本的知识对当前学生网络的重要性根据各所述训练样本的教师学习差异而得到。

为了解决相同的技术问题,本发明实施例还提供一种计算机设备,包括处理器和存储器,存储器用于存储计算机程序,计算机程序被处理器执行时实现基于知识蒸馏的模型压缩方法。

附图说明

图1:为本发明提供的一种基于知识蒸馏的模型压缩方法的一种实施例的流程示意图;

图2:为本发明提供的一种基于知识蒸馏的模型压缩方法的一种实施例的知识蒸馏的完整训练和测试的流程图;

图3:为本发明提供的一种基于知识蒸馏的模型压缩方法的一种实施例的基于验证集教师网络模型训练更新示意图;

图4:为本发明提供的一种基于知识蒸馏的模型压缩方法的一种实施例的教师网络通过训练集传授学生网络需要的知识示意图;

图5:为本发明提供的一种基于知识蒸馏的模型压缩系统的一种实施例的结构示意图。

具体实施方式

下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。

实施例一

请参照图1,为本发明实施例提供的一种基于知识蒸馏的模型压缩方法的流程示意图。本实施例的模型压缩方法适用于教师网络模型的压缩,本实施例通过样本的知识对当前学生网络的重要性,进行以学生为中心的知识蒸馏,提高学生网络的性能,提升模型压缩的准确性和压缩效果。该模型压缩方法包括步骤101至步骤102,各步骤具体如下:

步骤101:基于验证集和各验证样本的知识对当前学生网络的重要性,训练更新教师网络模型;各验证样本的知识对当前学生网络的重要性通过PID参数优化以及模糊策略分析学生网络模型在验证集上的表现来确定。

在本实施例中,教师网络模型是未压缩前较复杂的深度学习网络,学生网络模型是教师网络模型对应的压缩后轻量级深度学习网络,学生网络模型可以在保持较小模型大小和较低计算复杂度的同时,提高其在目标任务上的性能。生成轻量级的学生网络模型,可以部署于资源有限的终端设备上,如手机,个人电脑,树莓派等计算资源和储存资源比较有限的设备。以学生为中心的知识蒸馏的模型压缩,通过验证集和训练集进行模型训练,实现模型压缩,再通过测试集测试模型压缩效果,知识蒸馏的完整训练和测试的流程,如图2所示。教师网络通过分析学生网络在验证集上的表现来确定样本

需要说明的是,对所使用符号的概要进行说明,主要符号和符号列表,如下表1所示。

表1主要符号和符号列表

可选的,基于验证集和各验证样本的知识对当前学生网络的重要性,训练更新教师网络模型,具体为:

根据第一训练任务需求,确定若干个第一时期,在每个第一时期中,基于验证集和各验证样本的知识对当前学生网络的重要性,将验证集的各验证样本输入教师网络模型进行训练更新,直至满足第一预设训练结束条件,停止训练教师网络模型;

其中,在每个第一时期中,验证集中的每一个样本有且仅有一次被用于教师网络模型的训练;每个第一时期中包括若干次迭代。

在本实施例中,教师网络通过验证集学习学生网络需要的知识,基于验证集教师网络模型训练更新示意图,如图3所示。优化一个深度学习网络的参数需要经过多次迭代/时期(Iterations/Epochs)。实际训练过程中执行多次迭代,一个时期包含多次迭代。在每一个时期中,每一个样本有且仅有一次被用于训练;当每一个样本都被用于训练一次时,则这个时期的训练结束了。训练的时期次数是根据任务来确定,通常是240个时期,或者300个时期。预设的训练结束条件根据实际训练情况而确定,一般情况下,损失函数的损失值越小越好,如果长时间的损失值不再变化时,那就不需要再继续训练了,如30个迭代内的损失函数都不变,而且学习率足够小,则可以就停止训练。

可选的,在每个第一时期中,基于验证集和各验证样本的知识对当前学生网络的重要性,将验证集的各验证样本输入教师网络模型进行训练更新,具体包括步骤S11-S11,具体如下:

S11:在当前的第一时期对应的当前迭代中,从验证集中随机筛选出若干个验证样本,确定各当前的验证样本,将各当前的验证样本输入教师网络模型,通过交叉熵损失法,计算教师网络模型在各当前的验证样本下对应的第一损失值,具体为:

其中,L

在本实施例中,在每次迭代中,从验证集

/>

其中,y

S12:根据各当前的验证样本的知识对当前学生网络的重要性和各第一损失值,获得在当前的第一时期的当前迭代下教师网络模型的第二损失值,公式为:

其中,L为教师网络模型的第二损失值,w

在本实施例中,假设验证样本

S13:基于教师网络模型的第二损失值,在验证集上将教师网络模型的独立训练参数进行梯度下降更新,更新教师网络模型,公式为:

其中,θ

在本实施例中,通过损失函数使用一个梯度下降步骤,在验证集上更新教师网络。

需要说明的是,为了让教师网络准确地学习学生网络需要的知识,需要有效地评估验证样本的知识对当前学生网络的重要性w

其中N是在一个迭代(Iteration)中的样本数量,u(t)和e(t)分别是N个样本的反馈值和损失大小。σ(·)是一个使小的值变得更小,而大的值变得更大的放缩函数,n是正整数和

可选的,各验证样本的知识对当前学生网络的重要性通过PID参数优化以及模糊策略分析学生网络模型在验证集上的表现来确定,具体包括步骤S21-S24,具体如下:

S21:在当前的第一时期中,将各验证样本输入学生网络模型,通过交叉熵损失法,计算在当前第一时期下学生网络模型的第三损失值;

在本实施例中,PID参数优化以及模糊策略分析主要是对PID控制算法的比例、积分和微分单元对应的PID参数进行改进,使得PID参数能够实现以学生为中心的知识蒸馏方法。对于来自

实施本发明实施例,基于PID控制算法的知识蒸馏过程,能够有效地解决学生网络在当前薄弱的知识、困难样本学习和知识遗忘等问题,从而提高了学生网络在目标任务上的性能和鲁棒性。

S22:根据在当前第一时期的前一时期下学生网络模型的第三损失值,通过基于CL的模糊策略动态调整PID超参数,得到PID参数;其中,PID参数包括比例参数、积分参数和微分参数;

在本实施例中,基于CL的模糊策略是根据学生网络的学习状态(即第三损失值

实施本发明实施例,基于CL的模糊策略,能够自适应地调整比例、积分和微分单元的比例,从而允许学生网络在掌握一定知识后能够主动地去探索更具挑战性和丰富性的知识,从而提高学生网络的主动性和自主性。

可选的,步骤S22包括步骤S221-SS226,具体如下:

S221:根据在当前第一时期的前一时期下学生网络模型的第三损失值,量化学生网络模型在当前第一时期的前一时期中的平均表现,公式为:

其中,

在本实施例中,量化学生网络在最近一次测试中的平均表现,即e(t-1)中所有验证样本的平均损失,然后再确定当前状态的K

需要说明的是,学生网络是现在测试集上进行测试,只计算它们在测试集上的性能,不进行网络更新。在训练阶段,那么最近一次的测试就是训练中的上一个环节,也就是上一次训练结果阶段。

S222:根据学生网络模型在验证集上的最大损失值,划分平均表现的隶属函数,具体为:

/>

m=[m

式中,m为隶属函数,m

在本实施例中,根据实践经验,将

需要说明的是,在自动化学科中,模糊集通常由七个语言变量(NB、NM、NS、NS、ZO、PS、PS、PM、PB)的隶属函数(Membership Function)组成,分别代表负大(negative big)、负中(negative medium)、负小(negative small)、零(zero)、正小(positive small)、正中(positive medium)、正大(positive big)。

S223:将平均表现输入到隶属函数中,计算当前的隶属函数值;

在本实施例中,将

S224:根据第一PID超参数,设置比例参数的变量范围,将比例参数的变量范围的中值作为比例参数的初始值,获得初始比例参数,具体为:

kp=[0,Δp,2Δp,3Δp,4Δp,5Δp,6Δp]∈R

K

其中,kp为比例参数的变量范围,Δp为第一PID超参数,K

在本实施例中,根据相同的区间Δp设置K

S225:根据比例参数的变量范围和当前的隶属函数值,进行去模糊化计算,获得比例参数,具体为:

其中,K

在本实施例中,根据比例参数的变量范围和隶属函数值,执行去模糊化(Defuzzification)。去模糊化是使用中心法导出精确的K

S226:根据比例参数和初始比例参数,将第二PID超参数和第三PID超参数进行调整,获得积分参数和微分参数,具体为:

K

K

其中,K

需要说明的是,K

S23:根据在当前第一时期下学生网络模型的第三损失值和优化的PID参数,通过PID控制算法,计算在当前第一时期下学生网络模型的反馈值,公式为:

其中,u(t)为在当前第一时期下学生网络模型的反馈值,e(t)为在当前第一时期下每次迭代中每个验证样本的第三损失值,σ(·)为放缩函数,N

其中,放缩函数根据当前的验证样本的累积误差的平均值获得,具体为:

式中,σ(x)为放缩函数,x为当前的验证样本的累积误差的平均值,a为平均值的归一化形式;

在本实施例中,样本

其中,e(t)包含一个迭代中每个样本的损失大小(第三损失值)。σ(·)是一个让小的值变小,大的值变大的放缩函数,可以清晰地区分困难和容易样本的反馈值。具体地,σ(·)定义如下:

其中,x表示样本

需要说明的是,

u(t)=K

/>

S24:根据当前第一时期下学生网络模型的反馈值,计算当前的验证样本的知识对当前学生网络的重要性,具体为:

其中,w

在本实施例中,得到学生网络的反馈u(t)后,利用softmax函数计算

其中,u

步骤102:基于训练集、更新后的教师网络模型和各训练样本的知识对当前学生网络的重要性,训练更新教师网络模型和学生网络模型,获得压缩后的学生网络模型;其中,各训练样本的知识对当前学生网络的重要性根据各训练样本的教师学习差异而得到。

可选的,步骤102具体为:根据教师网络模型在验证集上最近一次更新前和更新后的模型参数,确定更新前的教师网络模型和更新后的教师网络模型;根据第二训练任务需求,确定若干个第二时期;在每个第二时期中,将训练集的各训练样本输入更新前的教师网络模型和更新后的教师网络模型,并根据各当前的训练样本的教师学习差异,确定各训练样本的知识对当前学生网络的重要性,并基于训练集和各训练样本的知识对当前学生网络的重要性,将训练集的各训练样本输入教师网络模型和学生网络模型进行训练更新,直至满足第二预设训练结束条件,停止训练教师网络模型和学生网络模型,得到训练好的学生网络模型,将训练好的学生网络模型作为压缩后的学生网络模型。

在本实施例中,教师网络通过训练集传授学生网络需要的知识示意,如图4所示,

可选的,在每个第二时期中,将训练集的各训练样本输入更新前的教师网络模型和更新后的教师网络模型,并根据各当前的训练样本的教师学习差异,确定各训练样本的知识对当前学生网络的重要性,具体包括步骤S31-S34:

S31:在当前的第二时期对应的当前迭代中,从训练集中随机筛选出若干个训练样本,确定各当前的训练样本;

在本实施例中,在每次迭代中,从训练集

S32:将训练集的各当前的训练样本分别输入更新前的教师网络模型和更新后的教师网络模型,通过交叉熵损失法,分别计算更新前的教师网络模型在各当前的训练样本下对应的第四损失值和更新后的教师网络模型在各当前的训练样本下对应的第五损失值;

S33:将在各当前的训练样本下对应的第四损失值与第五损失值进行差值计算,获得各当前的训练样本的教师学习差异,具体为:

其中,Δl

在本实施例中,

S34:根据各当前的训练样本的教师学习差异,确定各当前的训练样本的知识对当前学生网络的重要性,公式为:

式中,

在本实施例中,样本

可选的,基于训练集和各训练样本的知识对当前学生网络的重要性,将训练集的各训练样本输入教师网络模型和学生网络模型进行训练更新,具体包括步骤S41-S45:

S41:将各当前的训练样本输入教师网络模型和学生网络模型,通过交叉熵损失法,得到教师网络模型在各当前的训练样本下对应的第六损失值,并根据各当前的训练样本的知识对当前学生网络的重要性和各第六损失值,获得在当前的第二时期的当前迭代下教师网络模型的第七损失值,公式为:

式中,

在本实施例中,使用交叉熵来确定训练样本上教师网络的损失。设ΔLi为样本

S42:基于教师网络模型的第七损失值,在训练集上将教师网络模型的独立训练参数进行梯度下降更新,更新教师网络模型,公式为:

其中,θ

在本实施例中,使用S41中的损失函数采取一个梯度下降步骤,在训练集上更新教师网络。

S43:将训练好的教师网络模型向前传播至学生网络模型,并基于KL散度,最小化教师网络和学生网络之间的输出分布,获得各训练样本在学生网络模型的知识蒸馏的第八损失值;

其中,

在本实施例中,教师网络在样本

S44:根据各当前的训练样本的知识对当前学生网络的重要性和各训练样本在学生网络模型的知识蒸馏的第八损失值,获得在当前的第二时期的当前迭代下学生网络模型的第九损失值,公式为:

其中,

S45:基于学生网络模型的第九损失值,在训练集上将学生网络模型的独立训练参数进行梯度下降更新,更新学生网络模型,公式为:

其中,θ

在本实施例中,通过步骤S44的损失函数使用一个梯度下降步骤,在训练集上更新学生网络,β是学生网络的学习率。公式(27)和公式(32)中的学习率相同,以确保教师网络和学生网络之间的学习速度是一致的。执行多次迭代,在一个时期的训练集中的每一个训练都有且仅有一次被用于训练,以便使一个时期中的每个训练样本都可用于更新教师网络和学生网络。

本发明所解决的技术问题包括(1)根据学生网络在不同阶段、不同样本和不同任务上的表现,及时发现和满足学生网络的知识需求,从而避免或减少学生网络的知识差距和知识遗忘问题。(2)根据学生网络的反馈,自适应地调整教师网络的学习和传授策略,从而充分利用教师网络的容量空间来提高学生网络的性能。(3)根据学生网络的进步情况,适当地给予学生网络更具挑战性和丰富性的知识,从而激发学生网络的主动性和自主性。

实施本发明实施例,与以教师为中心的知识蒸馏方法不同,本发明以学生为中心的知识蒸馏,使得教师网络学习和传授给学生网络的都是学生网络所需要的知识,使得学生网络在教师网络的帮助下可以主动地学习需要的知识,从而提高了知识蒸馏的效果和效率,即提升模型压缩的准确性和压缩效果,能够最大程度地利用教师网络的容量空间来提高学生网络的性能,提高压缩后的学生网络模型的性能。

实施例二

相应地,参见图5,图5是本发明提供的基于知识蒸馏的模型压缩系统的实施例二的结构示意图。如图5所示,基于知识蒸馏的模型压缩系统包括教师网络训练更新模块501和学生网络训练更新模块502;

其中,教师网络训练更新模块501用于基于验证集和各验证样本的知识对当前学生网络的重要性,训练更新教师网络模型;各验证样本的知识对当前学生网络的重要性通过PID参数优化以及模糊策略分析学生网络模型在验证集上的表现来确定;

学生网络训练更新模块502用于基于训练集、更新后的教师网络模型和各训练样本的知识对当前学生网络的重要性,训练更新教师网络模型和学生网络模型,获得压缩后的学生网络模型;其中,各训练样本的知识对当前学生网络的重要性根据各训练样本的教师学习差异而得到。

实施本发明实施例,提出以学生为中心的知识蒸馏方法,使得学生网络在教师网络的帮助下可以主动地学习需要的知识。能够最大程度地利用教师网络的容量空间来提高学生网络的性能。基于PID控制算法的知识蒸馏,使得教师网络学习和传授给学生网络的都是学生网络所需要的知识,能够有效地解决学生网络在当前薄弱的知识、困难样本学习和知识遗忘等问题。基于CL的模糊策略,能够自适应地调整比例、积分和微分单元的比例,从而允许学生网络在掌握一定知识后能够主动地去探索更具挑战性和丰富性的知识。

另外,本申请实施例还提供一种计算机设备,计算机设备包括处理器和存储器,存储器用于存储计算机程序,计算机程序被处理器执行时实现上述任意方法实施例中的步骤。

上述的一种基于知识蒸馏的模型压缩系统可实施上述方法实施例的一种基于知识蒸馏的模型压缩方法。上述方法实施例中的可选项也适用于本实施例,这里不再详述。本申请实施例的其余内容可参照上述方法实施例的内容,在本实施例中,不再进行赘述。

以上的具体实施例,对本发明的目的、技术方案和有益效果进行了进一步的详细说明,应当理解,以上仅为本发明的具体实施例而已,并不用于限定本发明的保护范围。特别指出,对于本领域技术人员来说,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

相关技术
  • 基于知识蒸馏的检测模型的压缩方法、系统和计算设备
  • 一种基于知识蒸馏实现的无数据细粒度分类模型压缩系统及方法
技术分类

06120116513349