掌桥专利:专业的专利平台
掌桥专利
首页

特征解耦的少量样本预训练模型鲁棒性微调方法及装置

文献发布时间:2024-04-18 19:57:31


特征解耦的少量样本预训练模型鲁棒性微调方法及装置

技术领域

本发明涉及数据处理技术领域,尤其涉及一种特征解耦的少量样本预训练模型鲁棒性微调方法及装置。

背景技术

深度学习是人工智能领域的重要技术之一。近年来,深度学习技术取得了突飞猛进的进展。残差网络和Transformer架构的提出,以及大规模数据集的提出,使得深度学习在各个领域取得了显著的成果,提升了分类、检测、分割等人物的性能。CLIP(ContrastiveLanguage-Image Pre-Training,对比视觉语言预训练模型)等视觉语言跨模态预训练模型的提出,进一步的联通了视觉和语言的模态信息,使得各种跨模态任务成为可能。

然而,相对的,随着模型的参数增多,训练数据量的增加,使得模型的训练难度和训练成本显著提升。另外,在医学图像等场景中,数据获取难度高,不同医院的数据格式不统一,标注成本高,这使得从零开始训练一个大规模模型是一件极其困难,成本极高的事情。因此,如何能够在少量样本的下游任务中提升模型性能是一个很重要的问题。

小样本学习旨在降低深度学习算法对数据的依赖,探索如何能够更高效利用现有数据,利用更少的数据来训练准确率更高的模型。为了能够充分的挖掘少量样本情况下的类别信息,小样本学习有一套特有的特征提取,神经网络训练的方法,从而能够更快速的适配到少量样本数据。

另一方面,随着对比学习等自监督学习的方法的出现,使用大规模无标注数据进行模型训练成为可能。近年来,CLIP、GPT(Generative Pre-Trained Transformer,生成式预训练Transformer模型)等大规模预训练模型在各种任务上都展现出了优越的性能。然而,当将这些预训练模型应用到新的下游任务时,获取足够多的相关数据重新训练大规模预训练模型是极其困难甚至不可能的。因此,最直接的方式就是使用少量数据对现有的预训练模型进行微调。

对于一个小样本分类任务,整体的网络结构往往由预训练模型的特征提取器作为特征提取器提取图像特征,提取的特征将被输入到一个分类器中进行分类。由于预训练模型的参数量极大,使用少量数据直接对特征提取器进行微调往往会出现过拟合的情况。因此,现有的技术往往会更关注于数据量足够情况下的微调。比如,将预训练模型和微调后的模型参数直接加权平均,得到新的模型;或者选择首先微调分类头,随后微调特征提取器;也有方法在微调时引入额外的正则项来约束微调过程。当数据量足够少时,现有的技术大多关注于如何更好的利用现有的特征而避免修改特征提取器部分。但是这样并没有充分挖掘大模型的特征提取能力。然而,对预训练模型进行微调会破坏其在分布外数据的鲁棒性,会导致在分布外数据的性能下降。也有其他的工作发现,对分布外数据的鲁棒性和模型对虚假关联性的识别能力有关,直接微调会导致模型识别虚假关联性的能力下降。

用更少的样本来训练性能更高的模型,是小样本学习的主要目标。MAML(Model-Agnostic Meta-Learning,模型不可知元学习)等方法通过元学习的方式,在基类数据上训练一个基学习器,随后在少量的新类数据上微调得到适用于新类数据的模型;PN(Prototypical Network,原型网络)提出使用度量学习的方式来进行训练;这些方法都关注于从头开始对模型进行训练。也有一些按照传统的预训练-微调范式的小样本学习,例如:Meta-Baseline。随着CLIP等预训练模型的出现,当前可以借助于预训练模型获取更准确的图像特征。

然而,由于小样本学习中的数据量不足,而预训练模型的参数极多,直接使用少量数据对预训练模型进行微调极其容易过拟合,从而导致性能下降同时也会失去对分布外数据的鲁棒性。

现有的工作更多的关注如何利用CLIP提取出来的特征,而不对特征提取器进行微调,从而更好的提升小样本学习的性能。例如,CoOp(Context Optimization,上下文优化)用可学习向量对提示的上下文单词进行建模,使整个预先训练的参数保持固定。Tip-Adapter(Training-free CLIP-Adapter,无需训练的CLIP适应期)和APE(Adaptive PriorRefinement,自适应先验微调)不需要任何反向传播来训练适配器,而是通过从少量训练样本构建的键值缓存模型来创建权重。VPT(Visual Prompt Tuning,视觉提示微调技术)将额外的可学习参数引入到输入空间中。然而,所有这些方法都是在主干冻结的情况下处理的,并不能更好的挖掘特征提取器的潜力。

发明内容

本发明针对现有技术对预训练模型的潜力挖掘不足、对分布外数据的鲁棒性不足以及依赖特定架构,复用性差的问题,提出了本发明。

为解决上述技术问题,本发明提供如下技术方案:

一方面,本发明提供了一种特征解耦的少量样本预训练模型鲁棒性微调方法,该方法由电子设备实现,该方法包括:

S1、获取待分类的图像数据。

S2、采用基于特征解耦对齐的少量样本微调的方法,对预训练模型进行微调,得到微调后的预训练模型。

S3、将图像数据输入到微调后的预训练模型。

S4、根据图像数据以及微调后的预训练模型,得到图像数据的分类结果。

可选地,S2中的采用基于特征解耦对齐的少量样本微调的方法,对预训练模型进行微调,得到微调后的预训练模型,包括:

S21、获取样本数据以及预训练模型。

S22、根据样本数据,对预训练模型进行初始化。

S23、对初始化后的预训练模型进行微调,得到微调后的预训练模型。

可选地,样本数据,包括:样本图像数据、样本图像数据的类别以及提示模板。

S22中的根据样本数据,对预训练模型进行初始化,包括:

S221、构建预训练模型的特征提取网络,采用预训练模型的参数对特征提取网络进行初始化。

S222、构建分类器,采用类别的文本特征对分类器进行初始化。

S223、构建虚假信息提取器,采用提示模板的文本特征对虚假信息提取器进行初始化。

可选地,S221中的特征提取网络,包括:视觉特征提取网络、文本特征提取网络以及用于微调的视觉特征提取网络。

视觉特征提取网络采用卷积神经网络或者视觉自注意力模型ViT。

文本特征提取网络采用自注意力模型Transformer。

用于微调的视觉特征提取网络采用卷积神经网络或者ViT,并使用视觉特征提取网络的参数进行初始化。

可选地,S222中的采用类别的文本特征对分类器进行初始化,包括:

S2221、对任一类别,将类别名称和提示模板进行组合,得到任一类别的提示词。

S2222、将所有类别的提示词输入到文本特征提取网络,得到所有类别的提示词的文本特征。

其中,任一类别的提示词的文本特征

(1)

其中,

S2223、采用所有类别的提示词的文本特征作为分类器参数,对分类器进行初始化,得到初始化分类器。

可选地,S223中的采用提示模板的文本特征对虚假信息提取器进行初始化,包括:

S2231、对任一提示模板,将提示模板和类别名称进行组合,得到任一提示模板的提示词。

S2232、将提示词的文本输入到文本特征提取网络,得到所有提示词的文本特征,根据所有提示词的文本特征,计算所有提示模板的文本特征。

其中,根据所有提示词的文本特征,计算所有提示模板的文本特征

(2)

其中,

S2233、使用离散森林算法,对所有提示模板的提示词的文本特征进行异常值删除。

S2234、使用聚类算法,对异常值删除后的文本特征进行聚类,得到去除冗余的提示模板的文本特征。

S2235、采用去除冗余的提示模板的文本特征作为虚假信息提取器参数,对虚假信息提取器进行初始化,得到初始化虚假信息提取器。

可选地,S23中的对初始化后的预训练模型进行微调,得到微调后的预训练模型,包括:

S231、对样本图像数据进行预处理,得到预处理后的样本图像数据。

S232、对于任一预处理后的样本图像数据,获取样本图像数据的微调前的图像特征以及微调后的图像特征,根据微调后的图像特征的对应类别计算交叉熵损失,其中,微调前的图像特征根据视觉特征提取网络进行特征提取得到,微调后的图像特征根据用于微调的视觉特征提取网络进行特征提取得到。

S233、根据微调前的图像特征以及虚假信息提取器,得到微调前的虚假信息;根据微调后的图像特征以及虚假信息提取器,得到微调后的虚假信息。

S234、根据微调前的虚假信息以及微调后的虚假信息,计算相对熵损失。

S235、使用交叉熵损失以及相对熵损失的权重参数平衡用于微调的视觉特征提取器网络的准确性和鲁棒性;使用反向传播优化用于微调的视觉特征提取器网络的准确性和鲁棒性,得到微调后的视觉特征提取网络即微调后的预训练模型。

可选地,S232中的交叉熵损失

(3)

其中,

可选地,S234中的相对熵损失

(4)

其中,

另一方面,本发明提供了一种特征解耦的少量样本预训练模型鲁棒性微调装置,该装置应用于实现特征解耦的少量样本预训练模型鲁棒性微调方法,该装置包括:

获取模块,用于获取待分类的图像数据。

微调模块,用于采用基于特征解耦对齐的少量样本微调的方法,对预训练模型进行微调,得到微调后的预训练模型。

输入模块,用于将图像数据输入到微调后的预训练模型。

输出模块,用于根据图像数据以及微调后的预训练模型,得到图像数据的分类结果。

可选地,微调模块,进一步用于:

S21、获取样本数据以及预训练模型。

S22、根据样本数据,对预训练模型进行初始化。

S23、对初始化后的预训练模型进行微调,得到微调后的预训练模型。

可选地,样本数据,包括:样本图像数据、样本图像数据的类别以及提示模板。

微调模块,进一步用于:

S221、构建预训练模型的特征提取网络,采用预训练模型的参数对特征提取网络进行初始化。

S222、构建分类器,采用类别的文本特征对分类器进行初始化。

S223、构建虚假信息提取器,采用提示模板的文本特征对虚假信息提取器进行初始化。

可选地,特征提取网络,包括:视觉特征提取网络、文本特征提取网络以及用于微调的视觉特征提取网络。

视觉特征提取网络采用卷积神经网络或者视觉自注意力模型ViT。

文本特征提取网络采用自注意力模型Transformer。

用于微调的视觉特征提取网络采用卷积神经网络或者ViT,并使用视觉特征提取网络的参数进行初始化。

可选地,微调模块,进一步用于:

S2221、对任一类别,将类别名称和提示模板进行组合,得到任一类别的提示词。

S2222、将所有类别的提示词输入到文本特征提取网络,得到所有类别的提示词的文本特征。

其中,任一类别的提示词的文本特征

(1)

其中,

S2223、采用所有类别的提示词的文本特征作为分类器参数,对分类器进行初始化,得到初始化分类器。

可选地,微调模块,进一步用于:

S2231、对任一提示模板,将提示模板和类别名称进行组合,得到任一提示模板的提示词。

S2232、将提示词的文本输入到文本特征提取网络,得到所有提示词的文本特征,根据所有提示词的文本特征,计算所有提示模板的文本特征。

其中,根据所有提示词的文本特征,计算所有提示模板的文本特征

(2)

其中,

S2233、使用离散森林算法,对所有提示模板的提示词的文本特征进行异常值删除。

S2234、使用聚类算法,对异常值删除后的文本特征进行聚类,得到去除冗余的提示模板的文本特征。

S2235、采用去除冗余的提示模板的文本特征作为虚假信息提取器参数,对虚假信息提取器进行初始化,得到初始化虚假信息提取器。

可选地,微调模块,进一步用于:

S231、对样本图像数据进行预处理,得到预处理后的样本图像数据。

S232、对于任一预处理后的样本图像数据,获取样本图像数据的微调前的图像特征以及微调后的图像特征,根据微调后的图像特征的对应类别计算交叉熵损失,其中,微调前的图像特征根据视觉特征提取网络进行特征提取得到,微调后的图像特征根据用于微调的视觉特征提取网络进行特征提取得到。

S233、根据微调前的图像特征以及虚假信息提取器,得到微调前的虚假信息;根据微调后的图像特征以及虚假信息提取器,得到微调后的虚假信息。

S234、根据微调前的虚假信息以及微调后的虚假信息,计算相对熵损失。

S235、使用交叉熵损失以及相对熵损失的权重参数平衡用于微调的视觉特征提取器网络的准确性和鲁棒性;使用反向传播优化用于微调的视觉特征提取器网络的准确性和鲁棒性,得到微调后的视觉特征提取网络即微调后的预训练模型。

可选地,交叉熵损失

(3)

其中,

可选地,相对熵损失

(4)

其中,

一方面,提供了一种电子设备,所述电子设备包括处理器和存储器,所述存储器中存储有至少一条指令,所述至少一条指令由所述处理器加载并执行以实现上述特征解耦的少量样本预训练模型鲁棒性微调方法。

一方面,提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令,所述至少一条指令由处理器加载并执行以实现上述特征解耦的少量样本预训练模型鲁棒性微调方法。

上述技术方案,与现有技术相比至少具有如下有益效果:

上述方案,提出了一种基于特征解耦对齐的少量样本预训练模型鲁棒性微调方法。通过借助于预训练模型优秀的文本视觉对齐能力,在微调时保持了预训练模型对虚假关联性的识别能力,从而避免在微调时出现过拟合,提升了预训练模型对新的数据的性能以及对分布外数据的鲁棒性。本发明并无额外的结构依赖,可以即插即用到现有的预训练模型的微调方法中并提升性能。

附图说明

为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。

图1是本发明实施例提供的特征解耦的少量样本预训练模型鲁棒性微调方法流程示意图;

图2是本发明实施例提供的特征解耦对齐的少量样本预训练模型鲁棒性微调方法图;

图3是本发明实施例提供的特征解耦的少量样本预训练模型鲁棒性微调装置框图;

图4是本发明实施例提供的一种电子设备的结构示意图。

具体实施方式

为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例的附图,对本发明实施例的技术方案进行清楚、完整地描述。显然,所描述的实施例是本发明的一部分实施例,而不是全部的实施例。基于所描述的本发明的实施例,本领域普通技术人员在无需创造性劳动的前提下所获得的所有其他实施例,都属于本发明保护的范围。

如图1所示,本发明实施例提供了一种特征解耦的少量样本预训练模型鲁棒性微调方法,该方法可以由电子设备实现。如图1所示的特征解耦的少量样本预训练模型鲁棒性微调方法流程图,该方法的处理流程可以包括如下的步骤:

S1、获取待分类的图像数据。

一种可行的实施方式中,本发明不仅可以用于图像分类,还可以用于各种下游任务,例如目标检测、图像分割等等。因此,上述步骤S1可以是采用现有技术获取待分类的图像数据,也可以是获取待分割的数据等。

S2、采用基于特征解耦对齐的少量样本微调的方法,对预训练模型进行微调,得到微调后的预训练模型。

可选地,上述步骤S2可以包括如下步骤S21- S23:

S21、获取样本数据以及预训练模型。

可选地,样本数据可以包括:样本图像数据、样本图像数据的类别以及提示模板。

一种可行的实施方式中,获取样本数据可以是选取数据类别、提示模板和对应的图像数据。对于训练数据主要包含两部分,一部分是用于微调的少量样本数据本身,其中包含了用于微调的图像和对应的类别标签;另一部分包含了当前下游任务所有的类别文本和提示模板,例如:“一张{类别}的照片”,“一张{类别}的画”等。

S22、根据样本数据,对预训练模型的神经网络进行初始化。

可选地,上述步骤S22可以包括如下步骤S221- S223:

S221、构建预训练模型的特征提取网络,采用预训练模型的参数对特征提取网络进行初始化。

其中,特征提取网络,可以包括:视觉特征提取网络、文本特征提取网络以及用于微调的视觉特征提取网络。

一种可行的实施方式中,预训练模型的视觉特征提取网络,在预训练模型中的作用为提取图像的特征,将图像映射到视觉特征空间中。该网络可以采用卷积神经网络或者ViT(Vision Transformer,视觉自注意力模型)构成。在微调时,本特征提取网络无需训练,记作

进一步地,预训练模型的文本特征提取网络,在预训练模型中的作用为提取文本的特征,将文本映射到文本特征空间中。该网络可以采用Transformer构成。在微调时,本特征提取网络无需训练,记作

进一步地,用于微调的视觉特征提取网络,继承于预训练模型的视觉特征提取网络,和预训练模型的视觉特征提取网络具有相同的结构以及初始化参数。该部分也是微调的主要模块,参数可学习,记作

S222、构建分类器,采用类别的文本特征对分类器进行初始化。

其中,分类器主要由一层线性层构成。由于预训练模型的图像特征和文本特征有很好的特征对齐性质。分类头参数是依据文本特征进行初始化。

可选地,S222中的采用类别的文本特征对分类器进行初始化,包括S2221- S2223:

S2221、对任一类别,将类别名称和提示模板进行组合,得到任一类别的提示词。

一种可行的实施方式中,对于第

S2222、将所有类别的提示词输入到文本特征提取网络,得到所有类别的提示词的文本特征。

一种可行的实施方式中,将所有的提示词输入到预训练模型的文本特征提取网络中,提取当前类别的所有提示词的特征。对于每个类别,其文本特征如下式(1)所示:

(1)

其中,

进一步地,按照上述流程,计算出所有的

S2223、采用所有类别的提示词的文本特征作为分类器参数,对分类器进行初始化,得到初始化分类器

(2)

S223、构建虚假信息提取器,采用提示模板的文本特征对虚假信息提取器进行初始化。

一种可行的实施方式中,为了能够保证微调后的模型对虚假相关性的识别能力,需要将虚假特征从图像特征中提取出来。

可选地,S223中的采用提示模板的提示特征对虚假信息提取器进行初始化,包括S2231- S2235:

S2231、对任一提示模板,将提示模板和类别名称进行组合,得到任一提示模板的提示词。

一种可行的实施方式中,借助于预训练模型的视觉特征和文本特征的对齐性质,同样通过文本特征计算各类虚假特征。对于每个提示模板,将其和类别名两两组合,得到提示词,例如:“一张狗的照片”“一张猫的照片”等等。

S2232、将提示词的文本输入到文本特征提取网络,得到所有提示词的文本特征,根据所有提示词的文本特征,计算所有提示模板的文本特征。

将所有的提示词输入到预训练模型的文本特征提取网络中,提取当前类别的所有提示词的特征。通过公式(3)使用提示词的文本特征计算提示模板的文本特征,其文本特征如下式(3)所示:

(3)

进一步地,按照上述流程,计算出所有的

S2233、使用离散森林算法,对所有提示模板的提示词的文本特征进行异常值删除。

一种可行的实施方式中,由于提示模板中存在不合理的提示,因此,使用离散森林算法来删除异常值。

S2234、使用聚类算法,对异常值删除后的文本特征进行聚类,得到去除冗余的提示模板的文本特征。

一种可行的实施方式中,去除异常值之后的文本特征,依旧会存在一些相似或者冗杂的文本特征,可以使用k-means聚类算法对齐进行聚类,最终得到

S2235、采用去除冗余的提示模板的文本特征作为虚假信息提取器参数,对虚假信息提取器进行初始化,得到初始化虚假信息提取器,初始化虚假信息提取器

(4)

本发明所提出的基于提示模板文本特征设计的虚假信息提取器,是首创利用文本来辅助设计虚假特征提取器,避免了直接使用特征提取网络从图像中提取特征,能够有效的将虚假信息从图像特征中提取出来,为之后保留模型识别虚假关联性提供了基础。

进一步地,本发明提出的二阶段文本特征修正方法,能够删除提示中的不合理的提示,并降低相似提示带来的冗余性,从而能够更准确的识别提取虚假信息。

S23、对初始化后的预训练模型进行微调,得到微调后的预训练模型。

可选地,上述步骤S23可以包括如下步骤S231- S236:

S231、对样本图像数据进行预处理,得到预处理后的样本图像数据。

一种可行的实施方式中,读取数据集中的图像,将图像经过一些随机裁剪翻转等变换,来增加数据的多样性。

S232、对于任一预处理后的样本图像数据,获取样本图像数据的微调前的图像特征以及微调后的图像特征,根据微调后的图像特征的对应类别计算交叉熵损失,其中,微调前的图像特征根据视觉特征提取网络进行特征提取得到,微调后的图像特征根据用于微调的视觉特征提取网络进行特征提取得到。

交叉熵损失

(5)

其中,

S233、根据微调前的图像特征以及虚假信息提取器,得到微调前的虚假信息;根据微调后的图像特征以及虚假信息提取器,得到微调后的虚假信息。

S24、根据微调前的虚假信息以及微调后的虚假信息,计算相对熵损失。使用相对熵来约束微调前后的特征提取器提取虚假特征的能力。

相对熵损失

(6)

其中,

S235、使用交叉熵损失以及相对熵损失的权重参数平衡用于微调的视觉特征提取器网络的准确性和鲁棒性;使用反向传播优化用于微调的视觉特征提取器网络的准确性和鲁棒性,得到微调后的视觉特征提取网络即微调后的预训练模型如下式(7)所示:

(7)

其中,

一种可行的实施方式中,交叉熵损失优化的是模型的准确性,相对熵损失优化的是模型的鲁棒性。然后通过

进一步地,使用反向传播是优化模型更新模型参数的过程,然后重复上述步骤S231-S235微调模型直到满足条件(准确率,微调轮次等等)。

S3、将图像数据输入到微调后的预训练模型。

S4、根据图像数据以及微调后的预训练模型,得到图像数据的分类结果。

一种可行的实施方式中,对于给定的一个图像

本发明首先用特征提取网络提取图像特征,之后将其输入到分类器中进行分类或者微调。在微调过程中,通过约束微调前后的模型识别虚假特征的能力来保证模型的鲁棒性。仅需将神经网络中的用于微的视觉特征提取网络替换到其他方法的视觉提取特征提取网络,微调后的模型可以直接用于分类或者迁移应用到其他方法中。

本发明提出的基于特征解耦对齐的少量样本微调的方法,借助于预训练模型极其强大的文本视觉对齐能力,在得到类别分类的原型用于分类的同时,也获得场景的原型,通过保证微调前后提取的模型对场景上分类的一致性,来保证模型对场景这类分类无关因素的识别一致性,从而保证了微调前后的模型对分布外数据的鲁棒性。本发明于不依赖于任何特殊架构,能够在现有方法上即插即用提升分类准确率,在使用少量数据对预训练模型微调提升分类准确率的同时,充分保留了模型的对分布外数据的鲁棒性。

举例来说,本发明在国际公开数据集ImageNet上,对预训练模型进行微调,每个类别仅使用16个样本的数据,并在ImageNet的验证集和其他由分布差异的数据集上进行性能测试。结果如表1(直接微调和鲁棒性微调方法在不同ImageNet变种数据集上的分类准确率(%))所示:

表1

在ImageNet上,本发明相对于直接微调提升了1.48%。在其他具有分布差异的数据集上,本发明分别在ImageNet-A上提升了1.75%,在ImageNet-R上提升了1.00%,在ImageNet-S上提升了1.26%,在ImageNet-V2上提升了0.10%。

本发明在国际公开数据集miniImageNet上对预训练模型进行微调,并测试了其在不同的分布外数据集的性能。如表2(直接微调和鲁棒性微调方法在不同分布外数据集上的分类准确率(%))所示:

表2

相对于CLIP和WiSE-FT,本发明能获得更高的分布外分类准确率。

同样,本发明能够提升分布内的性能,如表3(不同方法在不同数据集不同微调样本数上的平均分类准确率(%))所示:

表3

本发明在CoOp共11个数据集上,分别测试了不同微调样本数的情况下的分类准确率,可以看出,在不同的微调样本数,本发明都能达到最优分类准确率。

另外,本发明微调后的模型能够直接应用到现有的方法上提升分类准确率。如表4(不同方法使用本发明在ImageNet不同变种数据集的平均分类准确率(%))所示,本发明微调后的模型提升了现有方法在分布外数据的性能。如表5(不同方法使用本发明在分布内数据的平均分类准确率(%))所示,本发明能提升现有方法在分布内数据的性能。

表4

表5

针对现有技术对预训练模型的潜力挖掘不足,本发明注重于对预训练模型本身进行微调,充分挖掘预训练模型在下游少量数据上的潜力。

针对现有技术对分布外数据的鲁棒性不足。本发明提出的基于特征解耦对齐模块能够充分保留预训练模型对虚假相关性的识别能力,进而保证微调后预训练模型对分布外数据的鲁棒性。

针对现有的技术依赖特定架构,复用性差。本发明提出的特征解耦对齐模块不依赖于特定的结构,能够直接将本发明微调后的模型应用到现有的方法中,提升分布内和分布外数据的分类准确率。

本发明实施例中,提出了一种基于特征解耦对齐的少量样本预训练模型鲁棒性微调方法。通过借助于预训练模型优秀的文本视觉对齐能力,在微调时保持了预训练模型对虚假关联性的识别能力,从而避免在微调时出现过拟合,提升了预训练模型对新的数据的性能以及对分布外数据的鲁棒性。本发明并无额外的结构依赖,可以即插即用到现有的预训练模型的微调方法中并提升性能。

如图3所示,本发明实施例提供了一种特征解耦的少量样本预训练模型鲁棒性微调装置300,该装置300应用于实现特征解耦的少量样本预训练模型鲁棒性微调方法,该装置300包括:

获取模块310,用于获取待分类的图像数据。

微调模块320,用于采用基于特征解耦对齐的少量样本微调的方法,对预训练模型进行微调,得到微调后的预训练模型。

输入模块330,用于将图像数据输入到微调后的预训练模型。

输出模块340,用于根据图像数据以及微调后的预训练模型,得到图像数据的分类结果。

可选地,微调模块320,进一步用于:

S21、获取样本数据以及预训练模型。

S22、根据样本数据,对预训练模型进行初始化。

S23、对初始化后的预训练模型进行微调,得到微调后的预训练模型。

可选地,样本数据,包括:样本图像数据、样本图像数据的类别以及提示模板。

微调模块320,进一步用于:

S221、构建预训练模型的特征提取网络,采用预训练模型的参数对特征提取网络进行初始化。

S222、构建分类器,采用类别的文本特征对分类器进行初始化。

S223、构建虚假信息提取器,采用提示模板的文本特征对虚假信息提取器进行初始化。

可选地,特征提取网络,包括:视觉特征提取网络、文本特征提取网络以及用于微调的视觉特征提取网络。

视觉特征提取网络采用卷积神经网络或者视觉自注意力模型ViT。

文本特征提取网络采用自注意力模型Transformer。

用于微调的视觉特征提取网络采用卷积神经网络或者ViT,并使用视觉特征提取网络的参数进行初始化。

可选地,微调模块320,进一步用于:

S2221、对任一类别,将类别名称和提示模板进行组合,得到任一类别的提示词。

S2222、将所有类别的提示词输入到文本特征提取网络,得到所有类别的提示词的文本特征。

其中,任一类别的提示词的文本特征

(1)

其中,

S2223、采用所有类别的提示词的文本特征作为分类器参数,对分类器进行初始化,得到初始化分类器。

可选地,微调模块320,进一步用于:

S2231、对任一提示模板,将提示模板和类别名称进行组合,得到任一提示模板的提示词。

S2232、将提示词的文本输入到文本特征提取网络,得到所有提示词的文本特征,根据所有提示词的文本特征,计算所有提示模板的文本特征。

其中,根据所有提示词的文本特征,计算所有提示模板的文本特征

(2)

其中,

S2233、使用离散森林算法,对所有提示模板的提示词的文本特征进行异常值删除。

S2234、使用聚类算法,对异常值删除后的文本特征进行聚类,得到去除冗余的提示模板的文本特征。

S2235、采用去除冗余的提示模板的文本特征作为虚假信息提取器参数,对虚假信息提取器进行初始化,得到初始化虚假信息提取器。

可选地,微调模块320,进一步用于:

S231、对样本图像数据进行预处理,得到预处理后的样本图像数据。

S232、对于任一预处理后的样本图像数据,获取样本图像数据的微调前的图像特征以及微调后的图像特征,根据微调后的图像特征的对应类别计算交叉熵损失,其中,微调前的图像特征根据视觉特征提取网络进行特征提取得到,微调后的图像特征根据用于微调的视觉特征提取网络进行特征提取得到。

S233、根据微调前的图像特征以及虚假信息提取器,得到微调前的虚假信息;根据微调后的图像特征以及虚假信息提取器,得到微调后的虚假信息。

S234、根据微调前的虚假信息以及微调后的虚假信息,计算相对熵损失。

S235、使用交叉熵损失以及相对熵损失的权重参数平衡用于微调的视觉特征提取器网络的准确性和鲁棒性;使用反向传播优化用于微调的视觉特征提取器网络的准确性和鲁棒性,得到微调后的视觉特征提取网络即微调后的预训练模型。

可选地,交叉熵损失

(3)

其中,

可选地,相对熵损失

(4)

其中,

本发明实施例中,提出了一种基于特征解耦对齐的少量样本预训练模型鲁棒性微调方法。通过借助于预训练模型优秀的文本视觉对齐能力,在微调时保持了预训练模型对虚假关联性的识别能力,从而避免在微调时出现过拟合,提升了预训练模型对新的数据的性能以及对分布外数据的鲁棒性。本发明并无额外的结构依赖,可以即插即用到现有的预训练模型的微调方法中并提升性能。

图4是本发明实施例提供的一种电子设备400的结构示意图,该电子设备400可因配置或性能不同而产生比较大的差异,可以包括一个或一个以上处理器(centralprocessing units,CPU)401和一个或一个以上的存储器402,其中,存储器402中存储有至少一条指令,至少一条指令由处理器401加载并执行以实现下述特征解耦的少量样本预训练模型鲁棒性微调方法:

S1、获取待分类的图像数据。

S2、采用基于特征解耦对齐的少量样本微调的方法,对预训练模型进行微调,得到微调后的预训练模型。

S3、将图像数据输入到微调后的预训练模型。

S4、根据图像数据以及微调后的预训练模型,得到图像数据的分类结果。

在示例性实施例中,还提供了一种计算机可读存储介质,例如包括指令的存储器,上述指令可由终端中的处理器执行以完成上述特征解耦的少量样本预训练模型鲁棒性微调方法。例如,计算机可读存储介质可以是ROM、随机存取存储器(RAM)、CD-ROM、磁带、软盘和光数据存储设备等。

本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。

以上所述仅为本发明的较佳实施例,并不用以限制本发明,凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

相关技术
  • 负例训练样本采集方法、装置及模型训练方法、装置
  • 基于对抗训练鲁棒的内容-风格解耦模型训练方法及系统
  • 基于预训练模型指导微调的半监督少样本时间序列异常检测与分类的系统和方法
技术分类

06120116459195