掌桥专利:专业的专利平台
掌桥专利
首页

源域数据缺失下的小样本领域自适应图像分类方法

文献发布时间:2024-04-18 19:58:30


源域数据缺失下的小样本领域自适应图像分类方法

技术领域

本发明涉及计算机视觉和数据处理领域,特指一种针对源域数据缺失下小样本领域自适应图像分类方法。

背景技术

图像分类是根据各自在图像信息中所反映的不同特征,把不同类别的目标区分开来的图像数据处理方法。它利用计算机对图像进行定量分析,把图像或图像中的每个像元或区域划归为若干个类别中的一种,以代替人的视觉判读。图像分类是计算机视觉中最基础的一个任务,也是几乎所有基准模型进行比较的任务。

深度学习在图像分类中扮演着非常重要的角色。传统的图像分类方法通常依赖于特征提取器和分类器的组合进行分类,而深度学习则是通过训练神经网络来自动学习图像特征并进行分类。深度学习在图像分类中的作用是通过神经网络自动学习图像特征,从而实现高效准确的图像分类。例如,ImageNet图像分类挑战赛中,使用深度学习的方法在准确率上远远超过了传统的方法。

领域自适应(Domain Adaptation)是一种解决深度学习模型在不同领域中泛化能力差的问题的一类方法。在领域自适应中,存在源域(source domain)和目标域(targetdomain)。源域表示与待检测样本不同的领域,一般具有丰富的标签作为监督信息;目标域表示待检测样本所在的领域,往往无标签信息或者只有少量标签。源域和目标域往往属于同一类任务,但是分布不同。在图像分类任务中,领域自适应的作用是利用源域和目标域之间的共性和差异,来提高模型在目标域上的分类准确率。通过领域自适应,可以将在源域上训练好的模型迁移到目标域中进行分类。这里的源域模型包括源域特征提取器和源域分类器。一般而言,特征提取器和分类器由神经网络组成。其中,特征提取器的输入是图像,输出是图像特征。分类器的输入是图像特征,输出是预测的分类结果。源域模型的网络权重由源域数据所有者在源域数据基础上使用深度学习方法训练得到,在领域自适应中,用户提供源域模型的神经网络权重。

小样本领域自适应(Few-shot Domain Adaptation)是指在目标域的样本数量较少的情况下实现领域自适应的一类方法。与传统的领域自适应方法相比,小样本领域自适应更加具有挑战性,因为目标域样本数量较少,可能导致模型在目标域上出现过拟合的问题。小样本领域自适应仅需要极少的目标域数据,可用于很多特殊的领域,比如医学图像处理。在医学图像处理场景中,源域数据集是已经收集的医学图像训练集。使用这个训练集能训练出在源域上的医学图像处理系统。目标域数据集指使用这个医学图像系统的少数病人的图像数据集。目标域数据集的样本量十分小。对特定地区的病人的医学图像进行处理的时候,需使用目标域数据集对源域的医学图像处理系统进行领域自适应。小样本领域自适应的图像分类方法是指使用源域的带标注数据和目标域的少量标注数据将源域模型自适应到目标域的图像分类方法。Sun Q等人提出使用元学习方法来解决小样本领域自适应(详见文献“Sun Q,Liu Y,Chua T-S,et al.Meta-transfer learning for few-shotlearning[C].In CVPR.2019”Sun Q,Liu Y,Chua T-S等人的论文:针对小样本学习的元迁移学习)该方法初步解决了小样本领域自适应问题,但是使用了大量的源域数据。

然而,在特殊的敏感领域,由于源域包含大量隐私信息,直接使用源域数据进行源域模型的领域自适应是危险的。举个例子,将前文所述的医学图像处理系统部署到目标域的时候,现有的小样本领域自适应方法(详见文献“Liang J,Hu D,Feng J.Do We ReallyNeed to Access the Source Data?Source Hypothesis Transfer for UnsupervisedDomain Adaptation[C].In ICML.2020.”Liang J,Hu D,Feng J的论文:我们真的需要访问源数据吗?无监督领域自适应的源假设转移)需要大量的有标签源域图像数据,这个训练过程可能会泄露大量病人的隐私信息,造成巨大的损失(详见文献“Jayaraman B,EvansD.Evaluating Differentially Private Machine Learning in Practice[C].InUSENIX.2019”Jayaraman B,Evans D的论文:实践中评估差分隐私机器学习)。所以,在进行小样本领域自适应图像分类的过程中,只有保护好源域数据中的隐私信息,小样本领域自适应图像分类方法才会在图像领域真正做到落地应用。

因为源域数据缺失下的小样本领域自适应是一个新问题,目前还没有公开文献涉及完全适用于该问题的方法,大家了解的只有3个基本的领域自适应方法。这三个方法是直接迁移(WA)、微调(Finetune,FT)、SHOT。直接迁移(WA)方法是直接使用源域模型来对目标域数据进行分类。这是小样本假设自适应最基本的解决方法,但是该方法在目标域的分类性能因为没有任何的针对目标域的训练而导致源域数据缺失下的小样本领域自适应的分类精度较差;微调(Finetune,FT)方法是:首先固定源域模型的特征提取器,然后使用现有的少量目标域数据对源域模型的分类器进行训练,该方法因为需要重新训练源域模型的分类器而导致源域数据缺失下的小样本领域自适应的训练速度较慢;SHOT是一个最新的源域假设迁移方法(详见文献“Liang J,Hu D,Feng J.Do We Really Need to Access theSource Data?Source Hypothesis Transfer for Unsupervised Domain Adaptation[C].In ICML.2020.”Liang J,Hu D,Feng J的论文:我们真的需要访问源数据吗?无监督领域自适应的源假设转移),但该方法因为需要大量的目标域而导致源域数据缺失下的小样本领域自适应的隐私泄露严重。

在已有工作中,有研究者利用源域数据和无标签的目标域数据把源域假设迁移为目标域分类器。然而,由于现有领域自适应方法都需要大量目标域数据,因而不适用于无源域数据下的小样本领域自适应问题。因此,如何在仅有少量目标域数据的帮助下,在源域数据缺失的情况下,将源域模型迁移到仅有少量样本的目标域上实现对目标域图像的小样本领域自适应图像分类,提高无法获取源域数据情况下的图像分类的精度,仍是图像分类需要解决的难题。

发明内容

本发明要解决的技术问题是提供一种源域数据缺失下的小样本领域自适应图像分类方法,实现在无法获取源域数据情况下只利用少量目标域数据,基于源域模型迁移到目标域,实现对目标域图像的小样本领域自适应图像分类方法,有效解决隐私泄露问题,并提高图像分类精度。

本发明的技术方案是:构建源域数据缺失下的小样本领域自适应图像分类系统。该系统由输入数据处理模块、特征提取模块、中间域生成模块、领域迁移模块、类别推理模块构成。准备目标域的训练集和源域模型:目标域训练集由用户给定;源域模型的神经网络结构以及模型参数(包括源域特征提取器和源域分类器的神经网络模型以及对应的神经网络模型参数)由用户提供。采用目标域训练集对源域数据缺失下的小样本领域自适应图像分类系统进行训练,得到训练后的源域数据缺失下的小样本领域自适应图像分类系统。最后采用训练后的源域数据缺失下的小样本领域自适应图像分类系统对用户输入的目标域图像进行分类,得到目标域图像的分类结果。

本发明包括以下步骤:

第一步,构建源域数据缺失下的小样本领域自适应图像分类系统。

源域数据缺失下的小样本领域自适应图像分类系统由输入数据处理模块、特征提取模块、中间域生成模块、领域迁移模块、类别推理模块构成。

输入数据处理模块与特征提取模块相连,输入数据处理模块接收用户输入的目标域图像数据,对目标域图像数据进行数据增强,得到增强后的目标域图像数据集,将增强后的目标域图像数据集发送给特征提取模块。

中间域生成模块与特征提取模块相连,采用DCGAN的生成器结构(详见文献“Radford A,Metz L,Chintala S.Unsupervised representation learning withdeepconvolutional generative adversarial networks[C].In ICLR.2015.”Radford A,MetzL,Chintala S的论文:基于深度卷积生成对抗网络的无监督表示学习,内含多个卷积层和批量归一化层),该模块采样高斯噪声和类别序号,根据高斯噪声和类别序号反卷积生成中间域图像数据,将中间域图像数据发送给特征提取模块。

特征提取模块与输入数据处理模块、中间域生成模块、领域迁移模块、类别推理模块相连,采用与源域模型的特征提取器相同的神经网络结构。在训练阶段,特征提取模块从输入数据处理模块接收增强后的目标域图像数据,从中间域生成模块接收中间域图像数据,提取增强后的目标域图像数据和中间域图像数据对应的特征,将中间域图像特征和目标域图像特征输出到类别推理模块;并将中间域图像特征和目标域图像特征组合成四种类别的样本对,生成标签对总集合和样本对组标签集合;对样本对进行前向推理,得到样本对的浅层特征集合;将样本对的浅层特征集合和样本对组标签集合输出到领域迁移模块;为了进行领域迁移模块、特征提取模块和类别推理模块的对抗训练将标签对总集合发送给类别推理模块。

领域迁移模块与特征提取模块相连,由一个3层全连接网络和一个四分类器组成。四分类器由单层全连接网络和Softmax激活函数组成。3层全连接网络从特征提取模块接收样本对浅层特征集合,对样本对浅层特征集合进行前向推理,得到样本对深层特征集合,将样本对深层特征集合发送给四分类器;四分类器对样本对深层特征集合进行分类,得到样本对的预测标签集合,并根据预测标签集合和样本对组标签集合的差值计算损失函数,得到损失函数值,用于训练。领域迁移模块用于在训练时不断更新特征提取模块、类别推理模块的神经网络权重参数,提高源域数据缺失下的小样本领域自适应图像分类系统对图像分类的精度,在实际对用户输入的图像进行分类时不参与工作。

类别推理模块与特征提取模块相连,采用与源域模型的分类器相同的神经网络结构。类别推理模块从特征提取模块接收中间域图像特征和目标域图像特征,进行类别推理,得到分类结果。在训练过程中,该分类结果用于训练;在对用户输入的目标域图像进行分类时,该分类结果是最终的目标域图像分类结果。

第二步,将用户提供的源域模型加载到源域数据缺失下的小样本领域自适应图像分类系统的特征提取模块和类别推理模块,得到初始化后的源域数据缺失下的小样本领域自适应图像分类系统,方法是:

2.1复制源域模型的特征提取器的神经网络结构到源域数据缺失下的小样本领域自适应图像分类系统的特征提取模块,复制源域模型的分类器的神经网络结构到源域数据缺失下的小样本领域自适应图像分类系统的类别推理模块。

2.2将源域模型的特征提取器的神经网络结构的权重赋值到源域数据缺失下的小样本领域自适应图像分类系统的特征提取模块,将源域模型的分类器的神经网络结构的权重赋值到源域数据缺失下的小样本领域自适应图像分类系统的类别推理模块。

这样,初始化后的源域数据缺失下的小样本领域自适应图像分类系统拥有了与源域模型的特征提取器和分类器相同的神经网络结构。

第三步,初始化后的源域数据缺失下的小样本领域自适应图像分类系统的输入数据处理模块接收用户输入的目标域图像训练集,目标域图像训练集由N类图像数据组成,每一类有support张带标签的图像,图像的格式为jpg或者png,support的取值范围一般为1或者5。N的取值与源域模型的分类数目保持一致。记目标域图像训练集为Train_Set_input。

第四步,采用Train_Set_input,对初始化后的源域数据缺失下的小样本领域自适应图像分类系统进行训练,得到特征提取模块和类别推理模块在目标域的最佳权重参数,方法是:

4.1初始化中间域生成模块、特征提取模块、类别推理模块和领域迁移模块的权重。将中间域生成模块中的所有卷积层权重都初始化为[0,1]之间的随机数,所有批量归一化层中的均值权重初始化为0、标准差权重初始化为1。将领域迁移模块中的3层全连接层网络和四分类器的单层全连接网络的权重矩阵的值都初始化为[0,1]之间符合高斯分布的随机数,将偏置初始化为0。

4.2设置网络训练参数,包括设置学习率γ=0.001,中间域生成模块批处理尺寸batch=32,总训练迭代参数T

4.3输入数据处理模块采用增强方法对Train_Set_input进行增强,得到增强后的目标域图像数据集D

4.3.1输入数据处理模块将Train_Set_input中的每个图像的分辨率改变为32×32大小,得到修改了分辨率的目标域样本集Train_Set_1。

4.3.2输入数据处理模块将Train_Set_input中的每个图像以0.25的概率进行随机旋转,得到随机旋转后样本集Train_Set_2。

4.3.3输入数据处理模块将Train_Set_input中的每个图像以0.25的概率进行随机平移,得到随机平移后样本集Train_Set_3。

4.3.4输入数据处理模块将Train_Set_input,Train_Set_1,Train_Set_2,Train_Set_3组成增强后的目标域图像数据集D

4.3.5输入数据处理模块将D

4.4特征提取模块对D

4.5初始化训练迭代参数为epoch=1。

4.6训练初始化后的源域数据缺失下的小样本领域自适应图像分类系统,方法是:首先训练中间域生成模块,直到训练迭代参数epoch=T

4.6.1初始化中间域图像数据集合D

4.6.2训练中间域生成模块,方法是:

4.6.2.1初始化中间域生成模块的生成类别的序号class_n=1。

4.6.2.2中间域生成模块从均值为1、方差为0的高斯分布中随机采样batch个噪声向量noise。

4.6.2.3中间域生成模块对noise和class_n进行反卷积(合适),得到batch个属于第class_n类的中间域图像数据G_noise,将G_noise存储到D

4.6.2.4特征提取模块从中间域生成模块接收G_noise,对G_noise进行神经网络前传,提取出G_noise的中间域图像特征X

4.6.2.5类别推理模块从特征提取模块接收X

4.6.2.6类别推理模块从特征提取模块接收目标域图像特征X

4.6.2.7计算损失值

4.6.2.8使用随机梯度下降(SGD)算法(见文献“Robbins H,Monro S.AStochastic Approximation Method[J].Annals of Mathematical Statistics,1951.”Robbins H,Monro S的论文:一种随机近似法)对

4.6.2.9令class_n=class_n+1,若class_n≤N,转4.6.2.2,继续进行中间域生成模块的训练。若class

4.6.3若此时不满足epoch=T

4.6.3.1初始化领域迁移模块的训练次数train_domain_number=0。

4.6.3.2特征提取模块从中间域图像数据集合D

4.6.3.2.1初始化G

4.6.3.2.2初始化样本对类别序号cid=1。

4.6.3.2.3从D

4.6.3.2.4从D

4.6.3.2.5从D

4.6.3.2.6从D

4.6.3.2.7令cid=cid+1。若cid≤N,转4.6.3.2.3。若cid>N,令样本对总集合G={G

4.6.3.3特征提取模块对G进行前向推理(参考文献“Rumelhart D E,Hinton G E,Williams R J.Learning representations by back-propagating errors[J].1988.”Rumelhart D E,Hinton G E,Williams R J等人的论文:通过反向传播误差学习特征),得到G的浅层特征集Φ,Φ={Φ

4.6.3.4领域迁移模块的3层全连接网络对Φ进行前向推理,得到样本对的深层特征集G

4.6.3.5四分类器对G

得到的预测组标签,D(Φ

,D(Φ

4.6.3.6四分类器根据预测组标签集合Predict和组标签总集合GY计算领域迁移模块的损失函数

其中one_hot(GY)表示求GY的独热编码,log(Predict)为对Predict求对数。

4.6.3.7使用随机梯度下降算法(SGD)对

4.6.3.8令train_domain_number=train_domain_number+1。

4.6.3.9若train_domain_number≤T

4.6.4若epoch<(T

4.6.4.1特征提取模块从中间域图像数据集合D

4.6.4.2特征提取模块对采用4.6.3.3所述的前向推理G进行特征提取,再次得到G的浅层特征集Φ,Φ={Φ

4.6.4.3领域迁移模块的3层全连接网络对Φ进行前向推理,得到样本对的深层特征集G

4.6.4.4四分类器对G

4.6.4.5四分类器根据Predict和GY按照公式(1)重新计算领域迁移模块的损失函数

4.6.4.6使用随机梯度下降算法(SGD)对

4.6.4.7为了增加数据的丰富性,输入数据处理模块使用用户提供的Train_Set_input,采用4.3步所述增强方法对Train_Set_input进行再次增强,得到新的增强后的目标域样本集合D

4.6.4.8特征提取模块对D

4.6.4.9类别推理模块对X

4.6.4.10计算特征提取模块和类别推理模块的对抗损失

其中,β为领域迁移超参数,one_hot(GY

4.6.4.11使用随机梯度下降算法(SGD)对

4.6.5令epoch=epoch+1,若epoch≤T

第五步,加载第四步中训练得到的特征提取模块、类别推理模块的权重参数,得到训练后的源域数据缺失下的小样本领域自适应图像分类系统。

第六步,训练后的源域数据缺失下的小样本领域自适应图像分类系统对用户输入的目标域图像进行分类,得到目标域图像的分类结果。方法是:

6.1训练后的源域数据缺失下的小样本领域自适应图像分类系统的输入数据处理模块接收用户输入的目标域图像D

6.2特征提取模块对D

6.3类别推理模块对从特征提取模块收到的Φ

采用本发明可以达到以下技术效果:

1.相比于背景技术所述的基于小样本领域自适应的图像分类方法,本发明无需使用任何源域数据即可以对图像进行分类,有效解决了棘手的隐私泄露问题,实现了无源域数据下的图像分类。

2.本发明第4.6.2步训练中间域生成模块时将源域数据和目标域数据结合起来,构造出一个新的中间域数据集,提高了源域模型在目标域上的分类精度。

3.本发明第4.6.3和4.6.4步仅使用少量的目标域数据对领域迁移模块、特征提取模块和类别推理模块进行训练,使得训练速度大大加快,因此本发明在实际应用中更加实用和可行,特别是在数据规模较大且难以获得源域数据的情况下。

4.本发明在STL-10数据集和CIFAR-10数据集进行了实验,证实了本发明可以有效地把源域模型信息迁移到第一步构建的源域数据缺失下的小样本领域自适应图像分类系统,并且能成功训练源域数据缺失下的小样本领域自适应图像分类系统,采用训练后的源域数据缺失下的小样本领域自适应图像分类系统按第六步所述方法对用户的图像进行分类,精度超出了背景技术所述基于小样本领域自适应的图像分类方法。

因此,本发明成功解决了图像分类精度和源域数据隐私泄露的矛盾难题,既避免源域数据隐私泄露,又提高了图像分类精度,应用范围广泛。

附图说明

图1是本发明第一步构建的源域数据缺失下的小样本领域自适应图像分类系统逻辑结构图。

图2是本发明总体流程图。

具体实施方式

为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。

为了验证本发明的图像分类效果,采用CIFAR-10数据集(

如图2所示,本发明在上述测试环境下的一个实施例包括以下步骤:

第一步,构建源域数据缺失下的小样本领域自适应图像分类系统。

如图1所示,源域数据缺失下的小样本领域自适应图像分类系统由输入数据处理模块、特征提取模块、中间域生成模块、领域迁移模块、类别推理模块构成。

输入数据处理模块与特征提取模块相连,输入数据处理模块接收用户输入的目标域图像数据,对目标域图像数据进行数据增强,得到增强后的目标域图像数据集,将增强后的目标域图像数据集发送给特征提取模块。

中间域生成模块与特征提取模块相连,采用DCGAN的生成器结构(详见文献“Radford A,Metz L,Chintala S.Unsupervised representation learning withdeepconvolutional generative adversarial networks[C].In ICLR.2015.”Radford A,MetzL,Chintala S的论文:基于深度卷积生成对抗网络的无监督表示学习),该模块从均值为0,方差为1的高斯分布中采样得到高斯噪声和类别序号,根据高斯噪声和类别序号反卷积,生成中间域图像数据,将中间域图像数据发送给特征提取模块。

特征提取模块与输入数据处理模块、中间域生成模块、领域迁移模块、类别推理模块相连,采用与源域模型的特征提取器相同的神经网络结构(如DenseNet-169(详见文献“[50]Huang G,Liu Z,Van Der Maaten L,et al.Densely connectedconvolutionalnetworks[C].In CVPR.2017.”Huang G,Liu Z,Van Der Maaten L等人的论文密集连接卷积神经网络)。在训练阶段,特征提取模块从输入数据处理模块接收增强后的目标域图像数据,从中间域生成模块接收中间域图像数据,提取增强后的目标域图像数据和中间域图像数据对应的特征,将中间域图像特征和目标域图像特征输出到类别推理模块;并将中间域图像特征和目标域图像特征组合成四种类别的样本对,并生成标签对总集合和样本对组标签集合;对样本对进行前向推理,得到样本对的浅层特征集合;将样本对的浅层特征集合和样本对组标签集合输出到领域迁移模块;为了进行类别推理模块的对抗训练将标签对总集合发送给类别推理模块。

领域迁移模块与特征提取模块相连,由一个3层全连接网络和一个四分类器组成。四分类器由单层全连接网络和Softmax激活函数组成。3层全连接网络从特征提取模块接收样本对浅层特征集合,对样本对的浅层特征集合进行前向推理,得到样本对的深层特征集合,将样本对深层特征集合发送给四分类器;四分类器对样本对深层特征集合进行分类,得到样本对的预测标签集合,并根据预测标签集合和样本对组标签集合的差值计算损失函数,得到损失函数值,用于训练。领域迁移模块用于在训练时不断更新特征提取模块、类别推理模块的神经网络权重参数,提高源域数据缺失下的小样本领域自适应图像分类系统对图像分类的精度,在实际对用户输入的图像进行分类时不参与工作。

类别推理模块与特征提取模块相连,采用与源域模型的分类器相同的神经网络结构(如选定一层全连接网络,并且最后伴随Softmax激活函数作为类别推理模块)。类别推理模块从特征提取模块接收中间域图像特征和目标域图像特征,进行类别推理,得到分类结果。在训练过程中,该分类结果用于训练;在对用户输入的目标域图像进行分类时,该分类结果是最终的目标域图像分类结果。

第二步,将用户提供的源域模型加载到源域数据缺失下的小样本领域自适应图像分类系统的特征提取模块和类别推理模块,得到初始化后的源域数据缺失下的小样本领域自适应图像分类系统,方法是:

2.1复制源域模型的特征提取器的神经网络结构到源域数据缺失下的小样本领域自适应图像分类系统的特征提取模块,复制源域模型的分类器的神经网络结构到源域数据缺失下的小样本领域自适应图像分类系统的类别推理模块。

2.2将源域模型的特征提取器的神经网络结构的权重赋值到源域数据缺失下的小样本领域自适应图像分类系统的特征提取模块,将源域模型的分类器的神经网络结构的权重赋值到源域数据缺失下的小样本领域自适应图像分类系统的类别推理模块。

这样,源域数据缺失下的小样本领域自适应图像分类系统拥有了与源域模型的特征提取器和分类器相同的神经网络结构。

第三步,初始化后的源域数据缺失下的小样本领域自适应图像分类系统的输入数据处理模块接收用户输入的目标域图像训练集,目标域图像训练集由N类图像数据组成,每一类有support张带标签的图像,图片的格式为jpg或者png,support的取值范围一般为1或者5。N的取值与源域模型的分类数目保持一致。记目标域图像训练集为Train_Set_input。如本实施例构建的目标域图像训练集选取STL-10的9(N=9)个重叠类别(airplane(飞机)、automobile(汽车)、bird(鸟)、cat(猫)、deer(鹿)、dog(狗)、horse(马)、ship(船)、truck(卡车)),每个类别随机选取5(support=5)张图片,因此一共有45张图片得到目标域训练集。

第四步,采用Train_Set_input,对初始化后的源域数据缺失下的小样本领域自适应图像分类系统进行训练,得到特征提取模块和类别推理模块在目标域的最佳权重参数,方法是:

4.1初始化中间域生成模块、特征提取模块、类别推理模块和领域迁移模块的权重。将中间域生成模块中的所有卷积层权重都初始化为[0,1]之间的随机数,所有批量归一化层中的均值权重初始化为0、标准差权重初始化为1。将领域迁移模块中的3层全连接层网络和四分类器的单层全连接网络的权重矩阵的值都初始化为[0,1]之间符合高斯分布的随机数,将偏置初始化为0。

4.2设置网络训练参数,包括设置学习率γ=0.001,中间域生成模块批处理尺寸batch=32,总训练迭代参数T

4.3输入数据处理模块采用增强方法对Train_Set_input进行增强,得到增强后的目标域图像数据集D

4.3.1输入数据处理模块将Train_Set_input中的每个图像的分辨率改变为32×32大小,得到修改了分辨率的目标域样本集Train_Set_1。

4.3.2输入数据处理模块将Train_Set_input中的每个图像以0.25的概率进行随机旋转,得到随机旋转后样本集Train_Set_2。

4.3.3输入数据处理模块将Train_Set_input中的每个图像以0.25的概率进行随机平移,得到随机平移后样本集Train_Set_3。

4.3.4输入数据处理模块将Train_Set_input,Train_Set_1,Train_Set_2,Train_Set_3组成增强后的目标域图像数据集D

4.3.5输入数据处理模块将D

4.4特征提取模块对D

4.5初始化训练迭代参数为epoch=1。

4.6训练初始化后的源域数据缺失下的小样本领域自适应图像分类系统,方法是:首先训练中间域生成模块,直到训练迭代参数epoch=T

4.6.1初始化中间域图像数据集合D

4.6.2训练中间域生成模块,方法是:

4.6.2.1初始化中间域生成模块的生成类别的序号class_n=1。

4.6.2.2中间域生成模块从均值为1、方差为0的高斯分布中随机采样batch个噪声向量noise。

4.6.2.3中间域生成模块的DCGAN的生成器对noise和class_n进行前向推理,得到batch个属于第class_n类的中间域图像数据G_noise,将G_noise存储到D

4.6.2.4特征提取模块从中间域生成模块接收G_noise,对G_noise进行神经网络前传,提取出G_noise的中间域图像特征X

4.6.2.5类别推理模块从特征提取模块接收X

4.6.2.6类别推理模块从特征提取模块接收目标域图像特征X

4.6.2.7计算损失值

4.6.2.8使用随机梯度下降(SGD)算法(见文献“Robbins H,Monro S.AStochastic Approximation Method[J].Annals of Mathematical Statistics,1951.”Robbins H,Monro S的论文:一种随机近似法)对

4.6.2.9令class_n=class_n+1,若class_n≤N,转4.6.2.2,继续进行中间域生成模块的训练。若class

4.6.3若此时不满足epoch=T

4.6.3.1初始化领域迁移模块的训练次数train_domain_number=0。

4.6.3.2特征提取模块从中间域图像数据集合D

4.6.3.2.1初始化G

4.6.3.2.2初始化样本对类别序号cid=1。

4.6.3.2.3从D

4.6.3.2.4从D

4.6.3.2.5从D

4.6.3.2.6从D

4.6.3.2.7令cid=cid+1。若cid≤N,转4.6.3.2.3。若cid>N,令样本对总集合G={G

4.6.3.3特征提取模块对G进行前向推理(参考文献“Rumelhart D E,Hinton G E,Williams R J.Learning representations by back-propagating errors[J].1988.”Rumelhart D E,Hinton G E,Williams R J等人的论文:通过反向传播误差学习特征),得到G的浅层特征集Ф,Φ={Φ

4.6.3.4领域迁移模块的3层全连接网络对Ф进行前向推理,得到样本对的深层特征集G

4.6.3.5四分类器对G

4.6.3.6四分类器根据预测组标签集合Predict和组标签总集合GY计算领域迁移模块的损失函数

其中one_hot(GY)表示求GY的独热编码,log(Predict)为对Predict求对数。

4.6.3.7使用随机梯度下降算法(SGD)对

4.6.3.7令train_domain_number=train_domain_number+1。

4.6.3.8若train_domain_number≤T

4.6.4若epoch<(T

4.6.4.1特征提取模块从中间域图像数据集合D

4.6.4.2特征提取模块对采用4.6.3.3所述的前向推理G进行特征提取,再次得到G的浅层特征集Ф,将Ф发送给领域迁移模块。

4.6.4.3领域迁移模块采用4.6.3.4所述的前向推理对Ф进行前向推理,再次得到样本对的预测组标签集合Predict。

4.6.4.4四分类器根据Predict和GY按照公式(1)重新计算领域迁移模块的损失函数

4.6.4.5使用随机梯度下降算法(SGD)对

4.6.4.6为了增加数据的丰富性,输入数据处理模块使用用户提供的Train_Set_input,采用4.3步所述增强方法对Train_Set_input进行再次增强,得到新的增强后的目标域样本集合D

4.6.4.7特征提取模块对D

4.6.4.8类别推理模块对X

4.6.4.9领域迁移模块的四分类器根据GY、Predict计算特征提取模块和类别推理模块的对抗损失

其中,β为领域迁移超参数,one_hot(GY

4.6.4.10使用随机梯度下降算法(SGD)对

4.6.5令epoch=epoch+1,若epoch≤T

第五步,加载第四步中训练得到的特征提取模块、类别推理模块的权重参数,得到训练后的源域数据缺失下的小样本领域自适应图像分类系统。

第六步,训练后的源域数据缺失下的小样本领域自适应图像分类系统对用户输入的目标域图像(来自于STL-10数据集的9个重叠类别(airplane(飞机)、automobile(汽车)、bird(鸟)、cat(猫)、deer(鹿)、dog(狗)、horse(马)、ship(船)、truck(卡车)的4500张图片中的任意一张图片;每个重叠类别有500张图片,因此共4500张图片)进行分类,得到目标域图像的分类结果。方法是:

6.1训练后的源域数据缺失下的小样本领域自适应图像分类系统的输入数据处理模块接收用户输入的目标域图像D

6.2特征提取模块对D

6.3类别推理模块对从特征提取模块收到的Ф

测试效果如下:

1.相比于对比方法,本发明无需使用任何源域数据即可以对图像进行分类,有效解决了棘手的隐私泄露问题,实现了无源域数据下的图像分类。具体验证如下:

本领域通过计算每个中间域样本和所有源域样本的PSNR值(详见文献“HoréA,Ziou D.Image Quality Metrics:PSNR vs.SSIM[C].In ICPR.2010.”图像质量指标:PSNR与SSIM)来验证中间域样本是否包含源域数据特征。PSNR值表示在给定一个标准图像g的情况下衡量图像f的生成质量,并且定义如下:

通过上述定义可以看出,PSNR值越大意味着两张图片越相似。为了横向比较,通过计算本发明在(源域数据到源域数据),(目标域数据到目标域数据),(源域数据到目标域数据),(中间域数据到目标域数据)上的PSNR值,并在下表中展示了这四类情形的均值。

可以看到,中间域数据相比源域数据更接近于目标域数据,但中间域数据和源域数据几乎没任何相似之处。源域数据到目标域数据的平均PSNR值仅为9.32,显然小于中间域数据和目标域数据的PSNR值(50.1722)。通过这个结果,可以证明中间域数据更接近目标域同时和源域数据非常不同,从而说明生成的中间域数据不包含源域数据特征,因此不会造成源域数据隐私泄露。

2.本发明第4.6.2步将源域数据和目标域数据结合起来,构造出一个新的中间域数据集,提高了源域模型在目标域上的分类性能。具体验证如下:

经过测试,本发明在CIFAR-10到STL-10的小样本领域自适应任务上实现了73.2%的分类精度,较背景技术所述三种比较方法(WA、FT、SHOT)有明显提高,这个提高幅度在该领域有较大意义。具体数据如下表。

3.本发明第4.6.3和4.6.4步仅使用少量的目标域数据对领域迁移模块、特征提取模块和类别推理模块进行训练,使得训练速度大大加快,因此本发明在实际应用中更加实用和可行,特别是在数据规模较大且难以获得源域数据的情况下。具体验证如下:

经过测试,本发明在CIFAR-10到STL-10的小样本领域自适应任务上实现了1.74个小时的训练时间,较背景技术所述三种比较方法(WA、FT、SHOT)有明显提高,这个提高幅度在该领域有较大意义。具体数据如下表。

本发明针对的场景是隐私保护下或者源域数据缺失下的基于小样本领域自适应的图像分类场景。由于源域数据存在大量隐私信息的情况,数据所有者无法直接将其提供给技术人员,但是可以提供用这些源域数据训练好的模型(即源域假设)数据通常包含很多隐私信息,比如个人手机上的数据。因此,如果像通常的小样本领域自适应方法一样直接使用源域数据来训练目标域分类器,其中的隐私信息很可能会被泄露。本发明为了完全阻止源域隐私泄露,提出了一个基于源域数据缺失下的小样本领域自适应的的图像分类方法,进一步的提高了小样本领域自适应在图像分类领域的应用范围,解决了隐私泄露问题。

实验表明,本发明有效的解决了小样本领域自适应本身的隐私泄露问题,同时在目标域的分类性能的显著超过了现有最好的方法SHOT,精度提高1.3%。

因此本发明实现了在源域数据缺失下的基于小样本领域自适应的图像分类。

以上对本发明所提供的一种源域数据缺失下的基于小样本领域自适应的图像分类方法进行了详细介绍。本文对本发明的原理及实施方式进行了阐述,以上说明用于帮助理解本发明的核心思想。应当指出,对于本技术领域的普通研究人员来说,在不脱离本发明原理的前提下,还可以对本发明进行若干改进和修饰,这些改进和修饰也落入本发明权利要求的保护范围内。

相关技术
  • 基于多源域自注意力的小样本遥感图像分类方法及系统
  • 一种将视频模型从源域迁移到目标域的领域自适应方法
技术分类

06120116498266