掌桥专利:专业的专利平台
掌桥专利
首页

一种基于注意力机制的领域自适应方法

文献发布时间:2023-06-19 11:49:09


一种基于注意力机制的领域自适应方法

技术领域

本发明涉及深度神经网络领域,尤其涉及一种基于注意力机制的领域自适应方法。

背景技术

使用深度神经网络训练任务往往需要数据集的人工标注作为监督信息。然而,人工标注数据集费时费力。一种可行的方法是使用相关任务领域已经标注好的数据。但是,不同任务领域的数据分布往往相差较大,在进行模型迁移时,模型的性能会大幅下降。因此,无监督领域自适应方法近些年受到了广泛关注。

经典无监督领域自适应学习场景的定义为当源域有大量有标注的样本,而目标域没有或仅有少量标注好的样本时,我们希望在源域训练出的模型能够在目标域上也有较好的泛化性能,这就形成了一个。这是一种利用信息丰富的源域样本训练模型,并且该模型能够很好的适应目标域的样本分布的迁移学习方法。

目前无监督领域自适应的方法可以简要分为两类:

(1)基于非对抗性学习的方法

这类方法通常使用一个度量准则来衡量并最小化源域和目标域高阶特征的差异。例如使最大均值差异(MMD,MaximumMeanDiscrepancy)作为准则,找到一个核函数后将源域和目标域特征都映射到一个再生核的希尔伯特(Hilbert)空间上,在该空间中拉近分布距离,以抽取到域不变(domaininvariant)的特征。对于简单的分布,这种方法可以很好的对齐两个领域分布,但是对于复杂的分布往往效果不佳。

(2)基于对抗性学习的方法。

这类方法通常在源域和目标域分布差异较大时使用。基于对抗性的学习方法通过学习域分布来最小化源域和目标域的差异。

例如Ganin等人提出域对抗迁移网络(DANN,Domain-AdversarialNeuralNetwork)不仅包含对于标签预测的损失,还包含领域分类的损失,使得网络同时学习到分类目标和领域信息。Saito等人提出的最大分类差异(MCD,MaximumClassifierDiscrepancy)方法建立了两个分类器用于对源域样本的分类。对于目标样本,这两个分类器必须有不同的针对任务的决策边界。Lee等人提出了Drop-to-Adapt(DTA),它利用对抗性Dropout通过实施集群假设来学习强歧视性特征和鲁棒性。

然而,这些方法的性能仍然有待提高,因为他们仅仅直接对齐两个域分布,但是没有利用好两个领域直接的潜在分布信息。并且,这些方法对齐过程中只使用了两个领域全局特征,而全局特征网络包含大量冗余信息且计算低效。从全局的角度出发,传统的基于全局特征的对齐两个域分布的一致性学习方法效率不高。对于一些细粒化分类任务,具有识别性的特征对输出预测往往起到决定的作用,其他部分可能会传达多余的噪声。

发明内容

本发明的目的是引入基于注意力机制的域适应方法,更多关注有识别性的区域,提高领域自适应模型的性能。

为了解决上述技术问题,本发明提出了一种基于注意力机制的领域自适应方法,包括以下步骤:

S1.基于无监督领域自适应模型的第一领域的第一源域的第一样本和第一领域的第一目标域的第二样本,通过无监督领域自适应模型的转化器,获得无监督领域自适应模型的第二领域的第二源域的第三样本和第二领域的第二目标域的第四样本;

S2.基于第一样本、第二样本、第三样本、第四样本,通过神经网络模型和注意力获得机制,获得无监督领域自适应模型的转化预测结果;

S3.基于转化预测结果,通过损失函数模型,获得预测损失模型,用于通过最小化基于注意力的域内一致性函数,实现跨领域基于注意力机制的对齐,提升无监督领域自适应模型的性能。

优选地,损失函数模型包括输入损失函数模型、域内损失函数模型、域间损失函数模型;

输入损失函数模型用于获得第一样本、第二样本分别通过转化器时的输入损失;

域内损失函数模型用于获得通过注意力获得机制处理后的第一样本和第三样本之间的第一域内损失,以及第二样本和第四样本之间的第二域内损失;

域间损失函数模型用于获得第一样本和第三样本之间的第一域间损失,以及第二样本和第四样本之间的第二域间损失。

优选地,第一样本和第三样本的第一标签一致;

第二样本和第四样本的第二标签一致。

优选地,输入损失函数模型包括但不限于交叉熵损失函数。

优选地,域间损失函数模型包括输出损失函数模型;

基于第一目标域和第二目标域的输出预测结果一致性,构建输出损失函数模型,用于限制转换前后语义信息一致性;

基于输出损失函数模型和输入损失函数模型,通过加入正则化项,构建域间损失函数模型。

优选地,输入损失函数模型包括源域输入损失函数模型和目标域输入损失函数模型;

源域输入损失函数模型用于获得第一样本的第一输入损失;

目标域输入损失函数模型用于获得第二样本的第二输入损失。

优选地,域内损失函数模型的构建方法包括以下步骤:

S2.1.基于第一样本、第二样本、第三样本、第四样本,通过神经网络模型,获得第一样本对应的第一特征图、第二样本对应的第二特征图、第三样本对应的第三特征图、第四样本对应的第四特征图;

S2.2.基于注意力获得机制,将第一特征图、第二特征图、第三特征图、第四特征图分别转化为第一特征图对应的第一注意力图、第二特征图对应的第二注意力图、第三特征图对应的第三注意力图、第四特征图对应的第四注意力图;

S2.3.基于第一注意力图和第四注意力图,通过设置第一判别器,构建第一域内损失函数;

S2.4.基于第二注意力图和第三注意力图,通过设置第二判别器,构建第二域内损失函数;

S2.5.基于对抗学习环境,通过第一域内损失函数和第二域内损失函数,构建基于注意力一致性的域内损失函数模型。

优选地,第一域内损失函数为基于源域的第一域内注意力一致性损失函数;

第二域内损失函数为基于目标域的第二域内注意力一致性损失函数。

优选地,基于注意力一致性的域内损失函数模型

其中,A表示注意力图,s表示源域,t表示目标域,DAS表示第一判别器,DAt表示第二判别器,At→s为第三注意力图,At表示第二注意力图,Y`是维度与注意力图A相同维度的表示样本来源领域的标签。

优选地,神经网络模型至少包括Resnet50神经网络模型;

注意力计算机制至少包括CAM注意力计算机制。

本发明公开了以下技术效果:

与现有方法相比,本发明所述的方法添加了基于注意力的域适应,提高了模型的效率和跨领域适应的准确性,可以定位到有区分性的区域,对于这些区域的微小变化有很高的敏感性,而对于无关区域给予较少的关注,因此准确度更高。

鉴别器使得网络对于两个领域之间不一致区域给予更多注意力以消除微小的差异。

基于以上两点,该方法尤其适合细粒度高的任务,大幅提高准确性。

该发明可以利用现有数据集实现对于无标签数据集的自动批量标注,不依赖特定人类先验知识,因此具有很好的泛化性。

附图说明

为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还根据这些附图获得其他的附图。

图1为本发明实施例所述的方法流程图。

具体实施方式

下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。

如图1所示,本发明公开了一种基于注意力机制的领域自适应方法,包括以下步骤:

S1.基于无监督领域自适应模型的第一领域的第一源域的第一样本和第一领域的第一目标域的第二样本,通过无监督领域自适应模型的转化器,获得无监督领域自适应模型的第二领域的第二源域的第三样本和第二领域的第二目标域的第四样本;

S2.基于第一样本、第二样本、第三样本、第四样本,通过神经网络模型和注意力获得机制,获得无监督领域自适应模型的转化预测结果;

S3.基于转化预测结果,通过损失函数模型,获得预测损失模型,用于通过最小化基于注意力的域内一致性函数,实现跨领域基于注意力机制的对齐,提升无监督领域自适应模型的性能。

损失函数模型包括输入损失函数模型、域内损失函数模型、域间损失函数模型;输入损失函数模型用于获得第一样本、第二样本分别通过转化器时的输入损失;域内损失函数模型用于获得通过注意力获得机制处理后的第一样本和第三样本之间的第一域内损失,以及第二样本和第四样本之间的第二域内损失;域间损失函数模型用于获得第一样本和第三样本之间的第一域间损失,以及第二样本和第四样本之间的第二域间损失。

第一样本和第三样本的第一标签一致;第二样本和第四样本的第二标签一致。

输入损失函数模型包括但不限于交叉熵损失函数。

域间损失函数模型包括输出损失函数模型;基于第一目标域和第二目标域的输出预测结果一致性,构建输出损失函数模型,用于限制转换前后语义信息一致性;基于输出损失函数模型和输入损失函数模型,通过加入正则化项,构建域间损失函数模型。

输入损失函数模型包括源域输入损失函数模型和目标域输入损失函数模型;源域输入损失函数模型用于获得第一样本的第一输入损失;目标域输入损失函数模型用于获得第二样本的第二输入损失。

域内损失函数模型的构建方法包括以下步骤:

S2.1.基于第一样本、第二样本、第三样本、第四样本,通过神经网络模型,获得第一样本对应的第一特征图、第二样本对应的第二特征图、第三样本对应的第三特征图、第四样本对应的第四特征图;

S2.2.基于注意力获得机制,将第一特征图、第二特征图、第三特征图、第四特征图分别转化为第一特征图对应的第一注意力图、第二特征图对应的第二注意力图、第三特征图对应的第三注意力图、第四特征图对应的第四注意力图;

S2.3.基于第一注意力图和第四注意力图,通过设置第一判别器,构建第一域内损失函数;

S2.4.基于第二注意力图和第三注意力图,通过设置第二判别器,构建第二域内损失函数;

S2.5.基于对抗学习环境,通过第一域内损失函数和第二域内损失函数,构建基于注意力一致性的域内损失函数模型。

第一域内损失函数为基于源域的第一域内注意力一致性损失函数;第二域内损失函数为基于目标域的第二域内注意力一致性损失函数。

基于注意力一致性的域内损失函数模型

其中,A表示注意力图,s表示源域,t表示目标域,DAS表示第一判别器,DAt表示第二判别器,At→s为第三注意力图,At表示第二注意力图,Y`是维度与注意力图A相同维度的表示样本来源领域的标签。

神经网络模型至少包括Resnet50神经网络模型;

注意力计算机制至少包括CAM注意力计算机制。

实施例1:基于注意力、特征和输出的域适应

1、令源域为S,目标域为T,以单组样本输入为例,从领域S,T中各随机抽取一个样本X

被转换样本与其原始样本标签一致,因为转化器只改变样本风格等知识,但基本保持语义的一致;即X

2、将样本经过处理输入Resnet50等神经网络,得到预测结果。

3、使用损失函数计算输入X

其中C为神经网络中的分类器,S′和T′为转化生成的图像的特征分布。

4、基于3.1,X

5、域间一致性主要包含跨领域样本的标签预测一致性,加入λ为正则化项,因此域间一致性损失函数为:

6、在训练过程中将域间一致性损失函数用于参数更新和模型优化,通常情况下源域和目标域在此过程中会共享神经网络的参数。

7、训练过程中试图最小化域间损失函数以缩小源域和目标域特征分布差异,达到限制源域和目标域样本语义信息和风格一致的目的,实现基于输出的域适应。

8、对于细粒化的任务,仅基于输出的域适应对于细小区域欠缺敏感度。

9、仅用输出一致的方法缺乏对于域内样本的约束,即X

基于8和9,本发明在此基础上结合基于特征和注意力的域适应以提高分类任务准确性。具体方法如下:

10、将被转换样本与其原始样本均通过特征提取器(可以通过Resnet50等神经网络模型实现)后得到特征图得到F

11、使用对抗生成学习策略使得源域和目标域之间的特征分布对齐。

12、为了实现基于特征的领域自适应方法,提出两个判别器D

具体实施方法如下:

13、在对抗学习环境下,损失函数为:

其中Y是维度与注意力图A相同维度的表示样本来源领域的标签。

14、由于两对跨领域特征被共同优化,因此最终的一致性损失是两个特征引起误差的累积。

15、在判别过程中,源域的域内特征一致性试图达到F

同理,目标域的特征对齐可被定义为:

16、由于特征图F会传递样本的全局信息,存在一定的冗余。而对有识别度的特征给予更多关注可以让分类预测等任务更简单且准确率更高。

17、尤其对于细粒化的任务,源域和目标域之间即便是微小的变化也会对模型预测结果造成很大的影响。

18、基于16和17,本发明加入对于源域样本和目标域转换为源域样本的注意力对齐,即对齐X

19、将被转换样本与其原始样本均通过特征提取器(可以通过Resnet50等神经网络模型实现)后得到的特征图F

以源域原始样本为例,以CAM注意力计算机制为例,其特征映射为

20、为了提高无监督领域自适应方法的性能,在以上基于输出和特征的领域自适应方法基础上,本发明实施结合注意力机制,提出了一种新的基于输出、特征和注意力机制的领域自适应方法以及相关的损失函数。

21、为了实现发明中基于注意力机制的领域自适应方法部分,提出两个判别器D

D

22、注意力一致性损失函数被定义为:

其中Y'是维度与注意力图A相同维度的表示样本来源领域的标签。

23、对于源域的基于注意力的域内一致性损失函数为:

24、对于目标域的基于注意力的域内一致性损失函数为:

基于以上,D

25、基于以上,域内一致性总损失函数为:

26、基于以上,总体的损失函数为:

本发明实施例中,可以使用Resnet50等神经网络作为特征提取的主干网络,将样本通过神经网络得到输出,并通过最小化对于输出的损失函数保持跨领域样本的语义信息一致性。将样本通过特征提取器得到特征图提供全局信息,再将特征图通过CAM等注意力计算机制,得到注意力图提供局部信息,经过不断的迭代和参数优化,特征抽取网络将抽取到注意力图分布在源域和目标域上具有一致性的特征,使得利用源域有监督信息训练的模型,在目标域上也能够获得很好的性能。并且通过注意力机制,使得模型更加注重具有识别性的局部区域,提升模型的预测性能。随着模型性能的提升,注意力机制的准确性也将不断提高,提取出更具有识别意义的部位。通过最小化基于注意力的域内一致性函数,实现跨领域基于注意力机制的对齐。通过这些方法,本发明有效提升无监督领域自适应模型的性能,即基于输出、特征和注意力机制的领域自适应。

本发明上述方案可以应用到多个具体的工作中,下面仅示例性的给出一些具体应用方向:对于无标签数据集的标注以用于机器学习模型训练。例如对于漫画人脸特征的分类,目前很少有带有标注的漫画人脸特征,但已存在大量带有标注的真实人脸样本数据集。该方法可将真实人脸样本数据集作为源域,批量预测漫画的人脸特征。

应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释,此外,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。

最后应说明的是:以上所述实施例,仅为本发明的具体实施方式,用以说明本发明的技术方案,而非对其限制,本发明的保护范围并不局限于此,尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,其依然可以对前述实施例所记载的技术方案进行修改或可轻易想到变化,或者对其中部分技术特征进行等同替换;而这些修改、变化或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的精神和范围。都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应所述以权利要求的保护范围为准。

相关技术
  • 一种基于注意力机制的领域自适应方法
  • 一种基于多流形嵌入式分布对齐的领域自适应方法
技术分类

06120113065900