掌桥专利:专业的专利平台
掌桥专利
首页

基于生成对抗网络原型修正的少样本图像分类方法及系统

文献发布时间:2024-04-18 19:58:53


基于生成对抗网络原型修正的少样本图像分类方法及系统

技术领域

本文件涉及计算机技术领域,尤其涉及一种基于生成对抗网络原型修正的少样本图像分类方法及系统。

背景技术

近年来,深度学习技术在计算机视觉领域得到了广泛应用,包括视觉图像处理、目标检测、图像分割等。在这些应用场景中,性能优异的深度学习模型往往需要足够数量的标注图像数据进行模型的训练及更新,但是,在许多情况下,由于样本收集难度大、人工标注成本高等因素,导致无法获取足够多的标注图像用于模型训练,因此,利用少量的样本数据来对深度学习模型进行有效的训练成为了近年来热门的研究问题,这一研究工作也称少样本学习,然而,通过少量标注图像数据来训练深度学习模型往往具有较大的挑战性。为了应对这一挑战,研究人员提出了基于度量学习的方法,该方法旨在学习一个度量空间,在这一空间中属于同类的样本距离较近,不同类的样本距离较远。在基于度量学习的方法中,原型网络是一种常用的少样本图像分类算法,其通过学习每个类别中样本的特征,并计算这些特征的均值来构建类别原型,通过比较未知样本和每个类别原型的欧氏距离,实现对未知样本的分类,但由于每个类别中训练样本的不足,简单求均值得到的类别原型可能会偏离真实的类别中心,导致训练得到的模型泛化性能往往不佳,针对这一问题,研究者们通常采用数据增强的技术生成更多的训练样本,以提高模型的泛化能力。虽然该方法可以增加训练样本的数量,但并不能确保生成样本的判别性,并且,在增强的过程很可能会引入噪声,这些因素都会对模型的训练产生不利的影响。

发明内容

本发明的目的在于提供一种基于生成对抗网络原型修正的少样本图像分类方法及系统,旨在解决现有技术中的上述问题。

本发明实施例提供一种基于生成对抗网络原型修正的少样本图像分类方法,包括:

将图像数据集分为训练集、验证集和测试集,根据所述训练集进行训练得到特征嵌入网络以及生成对抗网络;

利用特征嵌入网络提取训练集的样本特征,对样本特征进行提前修正得到提前修正的类别原型,将噪声和提前修正的类别原型输入生成对抗网络,生成每个类别的伪样本特征,基于每个类别的伪样本特征得到每个类别的伪类别原型;

通过融合伪类别原型和提前修正的类别原型,得到二次修正后的类别原型,利用二次修正后的类别原型对从训练集采样的训练任务中的查询集样本进行相似性度量,得到训练集样本的分类损失,利用分类损失微调特征嵌入网络,基于验证集生成多个少样本验证任务,利用少样本验证任务对微调后的特征嵌入网络进行性能验证,获取效果最优的特征嵌入网络;

利用最优的特征嵌入网络进行少样本任务测试,对少样本任务测试中的每个类别原型进行修正,基于修正后的类别原型实现对测试任务中查询样本的分类。

本发明实施例提供一种基于生成对抗网络原型修正的少样本图像分类系统,包括:

网络模块,用于将图像数据集分为训练集、验证集和测试集,根据所述训练集进行训练得到特征嵌入网络以及生成对抗网络;

修正模块,用于利用特征嵌入网络提取所述训练集的样本特征,对样本特征进行提前修正得到提前修正的类别原型,将噪声和提前修正的类别原型输入生成对抗网络,生成每个类别的伪样本特征,基于每个类别的伪样本特征得到每个类别的伪类别原型;

调整模块,用于融合伪类别原型和提前修正的类别原型,得到二次修正后的类别原型,通过二次修正后的类别原型,对从训练集采样的训练任务中的查询集样本进行相似性度量,得到训练集样本的分类损失,利用分类损失微调特征嵌入网络,基于验证集生成多个少样本验证任务,利用少样本验证任务对微调后的特征嵌入网络进行性能验证,获取效果最优的特征嵌入网络;

分类模块,用于利用最优的特征嵌入网络进行少样本任务测试,对少样本任务测试中的每个类别原型进行修正,基于修正后的类别原型实现对测试任务中查询样本的分类。

本发明实施例还提供一种电子设备,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现上述基于生成对抗网络原型修正的少样本图像分类方法的步骤。

本发明实施例还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有信息传递的实现程序,所述程序被处理器执行时实现上述基于生成对抗网络原型修正的少样本图像分类方法的步骤。

采用本发明实施例可以包括以下有益效果:本发明实施例能够解决少样本学习中训练样本不足、初始类别原型偏离真实类别中心以及生成的伪样本特征中存在噪声干扰的问题,能够让训练得到的模型拥有更好的泛化能力,有利于改善少样本图像分类任务的准确性。

附图说明

为了更清楚地说明本说明书一个或多个实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本说明书中记载的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。

图1是本发明实施例的基于生成对抗网络原型修正的少样本图像分类方法流程图;

图2是本发明实施例的WGAN模块训练框图;

图3是本发明实施例的基于生成对抗网络原型修正的少样本图像分类方法的模型训练框图;

图4是本发明实施例的基于生成对抗网络原型修正的少样本图像分类系统示意图。

具体实施方式

为了使本技术领域的人员更好地理解本说明书一个或多个实施例中的技术方案,下面将结合本说明书一个或多个实施例中的附图,对本说明书一个或多个实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本说明书的一部分实施例,而不是全部的实施例。基于本说明书一个或多个实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都应当属于本文件的保护范围。

方法实施例

根据本发明实施例,提供了一种基于生成对抗网络原型修正的少样本图像分类方法,图1是本发明实施例的基于生成对抗网络原型修正的少样本图像分类方法流程图,如图1所示,根据本发明实施例的基于生成对抗网络原型修正的少样本图像分类方法具体包括:

步骤S101:将图像数据集分为训练集、验证集和测试集,根据所述训练集进行训练得到特征嵌入网络以及生成对抗网络,具体包括:

将图像数据集分为训练集、验证集和测试集,根据所述训练集通过小批量数据采样方式并基于公式1所示的训练损失函数训练一个特征嵌入网络,再通过Episodictraining训练策略,根据如公式2所示的对抗训练损失函数和如公式3所示的判别性正则化项,得到如公式4所示的Wasserstein生成对抗网络(Wasserstein GenerativeAdversarial Network,WGAN)的最终训练目标:

其中,p(y=j|x

其中,L

其中,L

其中,G为WGAN中的生成器,D为判别器,L

步骤S102:利用特征嵌入网络提取训练集的样本特征,对样本特征进行提前修正得到提前修正的类别原型,将噪声和提前修正的类别原型输入生成对抗网络,生成每个类别的伪样本特征,基于每个类别的伪样本特征得到每个类别的伪类别原型,具体包括:

利用特征嵌入网络提取训练集的样本特征,对提取的样本特征进行提前修正,基于修正后的样本特征得到提前修正的类别原型,将噪声和提前修正的类别原型输入WGAN,生成每个类别的伪样本特征,计算每个类别的伪样本特征和提前修正的类别原型之间的余弦相似度,基于余弦相似度,通过softmax函数获取每个类别的伪样本特征的权重系数,将权重系数与每个类别的伪样本特征相乘,得到加权后的每个类别的伪样本特征,并对加权后的每个类别的伪样本特征求和得到每个类别的伪类别原型;

其中,利用特征嵌入网络根据公式5和公式6提取训练集的样本特征:

f

f

其中,F为特征嵌入网络,x

使用分别带有可学习参数φ、δ和ε的1×1卷积层对样本特征f

A={a

其中,exp表示自然指数函数,

将值向量

其中,f′

基于修正后的样本特征f′

其中,p

将噪声和提前修正的类别原型输入WGAN,生成如公式11所示的每个类别的伪样本特征:

其中,Z~N(0,1)表示随机采样的噪声,p

计算每个类别的伪样本特征和提前修正的类别原型之间的余弦相似度,基于余弦相似度,通过softmax函数获取如公式12所示的每个类别的伪样本特征的权重系数:

其中,cos为余弦相似度,

将权重系数与每个类别的伪样本特征相乘得到如公式13所示的加权后的每个类别的伪样本特征:

其中,

对加权后的每个类别的伪样本特征求和得到如公式14所示的每个类别的伪类别原型:

其中,J表示生成器为类别n生成伪样本特征的数量,

步骤S103:通过融合伪类别原型和提前修正的类别原型,得到二次修正后的类别原型,利用二次修正后的类别原型对从训练集采样的训练任务中的查询集样本进行相似性度量,得到训练集样本的分类损失,利用分类损失微调特征嵌入网络,基于验证集生成多个少样本验证任务,利用少样本验证任务对微调后的特征嵌入网络进行性能验证,获取效果最优的特征嵌入网络;

步骤S104:利用最优的特征嵌入网络进行少样本任务测试,对少样本任务测试中的每个类别原型进行修正,基于修正后的类别原型实现对测试任务中查询样本的分类。

以下结合本发明实施例的少样本图像分类方法的训练框图,如图2-3所示,对本发明实施例的上述技术方案进行详细说明。本发明实施例的基于生成对抗网络原型修正的少样本图像分类方法,具体包括以下步骤:

步骤1:准备当前任务的图像数据集,将该数据集划分为训练集、验证集和测试集,从训练集中随机采样N个类别,再从这N个类别中每个类别都随机采样K个样本来构成支持集,另外,从这N个类别的剩余样本中每个类别随机采样Q个样本来构成查询集,支持集和查询集构成一个少样本学习任务,这个任务通常称为N-way,K-shot任务;

步骤2:利用训练集,采用小批量(Mini-batch)的数据采样方式预训练一个特征嵌入网络,用作后续任务样本的特征提取器,再以跨任务的训练策略(Episodic training)对Wasserstein生成对抗网络(Wasserstein Generative Adversarial Network,WGAN)进行训练,得到训练好的WGAN;

步骤3:利用预训练好的特征嵌入网络提取少样本学习任务中的样本图像特征;

步骤4:使用自注意力模块对提取到的样本特征进行提前修正,获得更具判别性的样本特征,基于该样本特征,得到提前修正的类别原型;

步骤5:将随机采样的噪声和提前修正的类别原型输入进WGAN,为每个类别生成伪样本特征;

步骤6:计算每个类别的伪样本特征与提前修正的类别原型之间的余弦相似度,然后经过softmax函数获取伪样本特征的权重系数,将该权重系数与伪样本特征相乘,从而得到加权后的伪样本特征;

步骤7:对每个类别中加权后的伪样本特征进行求和,得到每个类别的伪类别原型;

步骤8:将伪类别原型和提前修正的类别原型进行融合,对提前修正的原型进行二次修正;

步骤9:利用修正后的原型,对查询集样本进行相似性度量,求得查询集样本的分类损失,微调特征嵌入网络;

步骤10:在测试阶段,对少样本学习任务中支持集里的每个类别原型都进行修正,基于修正后的原型,实现对查询集样本的分类。

本发明实施例中步骤2所述特征嵌入网络的训练损失函数为:

其中,p(y=j|x

WGAN是一种通过引入Wasserstein距离和优化训练过程的改进GAN模型,旨在解决传统GAN中的训练不稳定和模式崩溃问题,从而提高生成器的生成能力,本发明实施例中WGAN的训练损失函数分为两部分,第一部分为对抗训练损失L

其中,D

此外,为保证生成的伪样本特征的判别性,本发明实施例添加了一个判别性正则化项,以明确鼓励生成的伪样本特征与来自同一类的查询集样本特征有着高度的相关性,并将其作为WGAN训练损失函数的第二部分,其表达式如下:

其中,

从上述公式可知,生成的伪样本特征必须具有与查询样本相同类别的信息,从而确保伪样本特征的判别性,WGAN最后的训练目标如下:

本发明实施例中步骤3所述利用特征嵌入网络提取出少样本任务中的样本特征,其表达式如下:

f

f

其中,F为特征提取网络,x

本发明实施例中步骤4所述使用自注意力模块对提取到的样本特征进行提前修正,基于提前修正的样本特征,得到提前修正的类别原型,其步骤为:

首先,使用三个分别带有可学习参数φ,δ和ε的1×1卷积层对样本特征f

其中,

将修正后的样本特征f′

其中,p

本发明实施例中步骤5所述为每个类别生成伪样本特征,其表达式为:

其中,Z~N(0,1)表示随机采样的噪声,p

本发明实施例中步骤6所述伪样本特征的权重系数表达式为:

其中,cos为余弦相似度,N为类别数量,

其中,

本发明实施例中步骤7所述每个类别的伪类别原型表达式为:

其中,J表示生成器为类别n生成伪样本特征的数量,

本发明实施例中步骤8所述对类别原型进行二次修正表达式为:

其中,α、β为超参数,用于控制提前修正的原型和伪类别原型的占比权重,

本发明实施例中步骤9所述查询集样本的分类损失函数及预测概率表达式分别为:

其中,M代表度量函数,这里选用欧氏距离,N为类别数量;P(y=n|x

本发明实施例中步骤10所述对测试集中的查询样本实现分类,其表达式为:

其中,f

综上所述,本发明实施例首先利用训练集,采用Mini-batch的数据采样方式预训练一个特征嵌入网络,用作后续任务样本的特征提取器,接下来采用Episodic training的训练策略对WGAN进行训练,得到训练好的WGAN,接着,使用自注意力模块对提取到的图像样本特征进行提前修正,获得更具判别性的样本特征及其对应的类别原型,在此基础上生成器基于类别原型,为对应的类别生成伪样本特征,然后通过计算每个类别的伪样本特征与提前修正的类别原型之间的余弦相似度,并经过softmax函数获取伪样本特征的权重系数,再将该权重系数与伪样本特征相乘,得到加权后的伪样本特征,随后对每个类别中加权后的伪样本特征进行求和,得到每个类别的伪类别原型,将伪类别原型与对应的提前修正的类别原型进行融合,得到二次修正后的类别原型,最后采用度量的方式,在少样本任务中利用修正后的类别原型完成对查询样本的分类。

本发明实施例提出的基于生成对抗网络原型修正的少样本图像分类方法,其采用的基于生成对抗网络原型修正的训练算法,能够解决少样本学习中训练样本不足、初始类别原型偏离真实类别中心以及生成的伪样本特征中存在噪声样本的问题,该方法能够让训练得到的模型拥有更好的泛化能力,有利于改善少样本图像分类任务的准确性。

系统实施例

根据本发明实施例,提供了一种基于生成对抗网络原型修正的少样本图像分类系统,图4是本发明实施例的基于生成对抗网络原型修正的少样本图像分类系统示意图,如图4所示,根据本发明实施例的基于生成对抗网络原型修正的少样本图像分类系统具体包括:

网络模块40,用于将图像数据集分为训练集、验证集和测试集,根据所述训练集进行训练得到特征嵌入网络以及生成对抗网络,具体用于:

将图像数据集分为训练集、验证集和测试集,根据所述训练集通过小批量数据采样方式并基于公式1所示的训练损失函数训练一个特征嵌入网络,再通过Episodictraining训练策略,根据如公式2所示的对抗训练损失函数和如公式3所示的判别性正则化项,得到如公式4所示的Wasserstein生成对抗网络(Wasserstein GenerativeAdversarial Network,WGAN)的最终训练目标:

其中,p(y=j|x

其中,L

其中,L

其中,G为WGAN中的生成器,D为判别器,L

修正模块42,用于利用特征嵌入网络提取所述训练集的样本特征,对样本特征进行提前修正得到提前修正的类别原型,将噪声和提前修正的类别原型输入生成对抗网络,生成每个类别的伪样本特征,基于每个类别的伪样本特征得到每个类别的伪类别原型,具体用于:

利用特征嵌入网络提取训练集的样本特征,对提取的样本特征进行提前修正,基于修正后的样本特征得到提前修正的类别原型,将噪声和提前修正的类别原型输入WGAN,生成每个类别的伪样本特征,计算每个类别的伪样本特征和提前修正的类别原型之间的余弦相似度,基于余弦相似度,通过softmax函数获取每个类别的伪样本特征的权重系数,将权重系数与每个类别的伪样本特征相乘,得到加权后的每个类别的伪样本特征,并对加权后的每个类别的伪样本特征求和得到每个类别的伪类别原型;

其中,利用特征嵌入网络根据公式5和公式6提取训练集的样本特征:

f

f

其中,F为特征嵌入网络,x

使用分别带有可学习参数φ、δ和ε的1×1卷积层对样本特征f

其中,exp表示自然指数函数,

将值向量

其中,f′

基于修正后的样本特征f′

其中,p

将噪声和提前修正的类别原型输入WGAN,生成如公式11所示的每个类别的伪样本特征:

其中,Z~N(0,1)表示随机采样的噪声,p

计算每个类别的伪样本特征和提前修正的类别原型之间的余弦相似度,基于余弦相似度,通过softmax函数获取如公式12所示的每个类别的伪样本特征的权重系数:

其中,cos为余弦相似度,

将权重系数与每个类别的伪样本特征相乘得到如公式13所示的加权后的每个类别的伪样本特征:

其中,

对加权后的每个类别的伪样本特征求和得到如公式14所示的每个类别的伪类别原型:

其中,J表示生成器为类别n生成伪样本特征的数量,

调整模块44,用于融合伪类别原型和提前修正的类别原型,得到二次修正后的类别原型,通过二次修正后的类别原型,对从训练集采样的训练任务中的查询集样本进行相似性度量,得到训练集样本的分类损失,利用分类损失微调特征嵌入网络,基于验证集生成多个少样本验证任务,利用少样本验证任务对微调后的特征嵌入网络进行性能验证,获取效果最优的特征嵌入网络;

分类模块46,用于利用最优的特征嵌入网络进行少样本任务测试,对少样本任务测试中的每个类别原型进行修正,基于修正后的类别原型实现对测试任务中查询样本的分类。

本发明实施例是与上述方法实施例对应的系统实施例,各个模块的具体操作可以参照方法实施例的描述进行理解,在此不再赘述。

本发明实施例除了上述模块划分,作为优选实施例,还可以采用如下模块划分的方式:

1、数据预处理模块,用于读取图像数据及标签,把RGB的图像转换成特征张量的形式;

2、预训练网络模块,用于预训练特征嵌入网络和WGAN,通过利用训练集,采用Mini-batch的数据采样方式预训练特征嵌入网络,用作后续任务样本的特征提取器,再采用Episodic training的训练策略对WGAN进行训练,用作后续任务中类别的伪样本特征生成器;

3、自注意力模块,用于将提取到的图像样本特征进行提前修正,得到更具判别性的样本特征和类别原型,后续基于该类别原型生成更具判别性的伪样本特征;

4、伪样本特征生成模块,用于为每个类别生成伪样本特征;

5、权重生成模块,用于给每个生成的伪样本特征赋予不同的权重系数,减小生成的伪样本特征中噪声样本特征的干扰;

6、修正原型模块,用于原型的二次修正,修正后的原型相比于少样本情形下的原始原型更具判别性,更能代表其所属的类别;

7、度量模块,用于度量查询样本和修正后原型的相似度,实现对少样本任务中查询样本的分类。

综上所述,本发明实施例包括以下有益效果:

1、本发明实施例提出的基于生成对抗网络原型修正的少样本图像分类方法能够保证生成的伪样本特征的判别性,在本发明实施例中利用自注意力来对原始的图像特征进行提前修正,与直接基于原始的图像样本特征生成伪样本特征的方法相比,利用自注意力进行特征提前修正的方法可以得到更具判别性的原型,生成更具判别性的伪样本特征;

2、本发明实施例通过设计一项生成器判别性正则化项,以明确鼓励生成器生成的伪样本特征与来自同一类的查询集样本特征有着高度的相关性,确保生成的伪样本特征的判别性;

3、本发明实施例通过考虑生成的伪样本特征和类别原型的相似度关系,给每个生成的伪样本特征给予不同的权重系数,减少生成的伪样本特征中噪声样本特征的干扰;

4、本发明实施例通过将伪类别原型和提前修正的原型进行融合,实现二次修正原型,相比于少样本情况下的原始原型,修正后的原型更具判别性,更能代表其所属的类别;

5、本发明实施例提出的基于生成对抗网络原型修正的少样本图像分类方法能够在现有的技术上进一步解决由于训练标记样本不足而引起的少样本学习问题,并且能够让训练得到的模型拥有更好的泛化能力,有利于改善少样本图像分类任务的准确性。

装置实施例一

本发明实施例提供一种电子设备,包括:存储器、处理器及存储在所述存储器上并可在所述处理上运行的计算机程序,所述计算机程序被所述处理器执行时实现如方法实施例中所述的步骤。

装置实施例二

本发明实施例提供一种计算机可读存储介质,所述计算机可读存储介质上存储有信息传输的实现程序,所述程序被处理器执行时实现如方法实施例中所述的步骤。

本实施例所述计算机可读存储介质包括但不限于为:ROM、RAM、磁盘或光盘等。

最后应说明的是:以上各实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述各实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的范围。

相关技术
  • 一种基于原型网络少样本学习的图像分类器构建、图像识别方法及系统
  • 一种基于原型网络少样本学习的图像分类器构建、图像识别方法及系统
技术分类

06120116507395