掌桥专利:专业的专利平台
掌桥专利
首页

一种基于全局-局部知识蒸馏的跨域小样本图像分类方法

文献发布时间:2023-06-19 19:28:50



技术领域

本发明属图像处理技术领域,具体涉及一种基于全局-局部知识蒸馏的跨域小样本图像分类方法。

背景技术

图像处理是机器视觉走向工业应用的关键技术,而图像分类是图像处理技术的基础。在医学、遥感等多种场景下,图像数据往往难以获取,呈现典型的小样本特性。为了缓解小样本问题,一种有效的方式是利用源域数据学习可迁移的知识,并将学习到的知识泛化到目标域的小样本任务中。然而,由于源域与目标域之间存在域差异,导致源域上训练的模型难以有效地泛化到目标域中。为此,研究适用于跨域场景下的小样本图像分类技术具有重要的应用价值。文献“Snell J,Swersky K,Zemel R.Prototypical networks for few

发明内容

为了克服现有技术的不足,本发明提供一种基于全局-局部知识蒸馏的跨域小样本图像分类方法。构建了由全局分支和局部分支构成的分类模型,其中,全局分支以原始图像为输入,用于提取图像的全局特征,局部分支以原始图像的局部块为输入,用于提取该图像的局部特征;在两分支之间,通过构建全局-局部知识蒸馏损失促进全局特征关注到图像的局部区域,使得全局特征捕获丰富的语义信息,进而提升全局特征在跨域小样本任务上的泛化性能。

一种基于全局-局部知识蒸馏的跨域小样本图像分类方法,其特征在于步骤如下:

步骤1:基于现有的图像数据集构建小样本任务训练数据集,包括支持集

步骤2:构建模型的全局分支,其处理过程如下:

首先,按照下式获得支持集

其中,

然后,基于原型表示对查询集

其中,

接着,根据预测得分中的最大相似度对应的类别作为该查询样本的预测标签

其中,H(·)表示交叉熵损失函数,

步骤3:构建模型的局部分支,其处理过程如下:

对于查询样本

然后,使用局部分支中的特征提取网络

接着,使用步骤2计算的原型对局部特征

其中,

步骤4:按照下式计算模型的总损失

其中,I表示小样本任务中查询样本的总个数,

所述的查询样本

所述的跨图像的局部-全局蒸馏损失

其中,

步骤5:根据步骤4计算的模型总损失,使用随机梯度下降法,端到端的训练全局分支的网络参数,并按下式进行局部分支的网络参数的更新:

θ

其中,θ

步骤6:将待处理图像数据集输入到步骤5训练后得到的全局分支,预测得到其中每幅图像的隶属类别,完成图像分类。

本发明的有益效果是:通过训练阶段构建的全局-局部知识蒸馏框架促进全局特征关注到图像的局部信息,从而使模型能够学习到泛化性强的语义表征,提升在跨域小样本任务上的泛化性能;采用端到端的框架设计方式,一旦模型在源域(训练数据集)上训练完成之后,即可在任意目标域(待处理图像数据集)的小样本任务上进行测试,无需微调特征提取模型;本发明能够在跨域小样本图像分类中获得较好的分类效果。

具体实施方式

下面结合实施例对本发明进一步说明,本发明包括但不仅限于下述实施例。

本发明提供了一种基于全局-局部知识蒸馏的跨域小样本图像分类方法,其具体实现过程如下:

1、构建小样本任务训练数据集

跨域小样本图像分类任务要求模型在源域

2、全局分支计算

构建模型的全局分支,其处理过程如下:

首先,按照下式获得支持集

其中,

然后,基于原型表示对查询集

其中,

接着,根据预测得分中的最大相似度对应的类别作为该查询样本的预测标签

其中,H(·)表示交叉熵损失函数,

3、局部分支计算

构建模型的局部分支,其处理过程如下:

对于查询样本

然后,与步骤2类似,先使用局部分支中的特征提取网络

其中,

4、计算总损失

按照下式计算模型的总损失

其中,I表示小样本任务中查询样本的总个数,

所述的查询样本

所述的跨图像的局部-全局蒸馏损失

其中,

5、训练模型

根据步骤4计算的模型总损失,使用随机梯度下降法,端到端的训练全局分支的网络参数。对于局部分支的网络参数,使用全局分支的指数移动平均进行参数更新,即:

θ

其中,θ

6、图像分类

模型训练完成之后,丢弃局部分支,只保留全局分支对目标域中的小样本图像进行分类处理,即将待处理图像数据集输入到步骤5训练后得到的全局分支,按照步骤2的计算过程,预测得到其中每幅图像的隶属类别,完成图像分类任务。

本发明能够在跨域小样本图像分类任务中取得较好的分类性能。例如,本实施例使用mini-ImageNet数据集作为源域的训练数据集进行模型训练,然后对遥感场景分类数据集EuroSAT以及医学图像数据集ISIC作为目标域进行分类处理,本发明方法在5-way 1-shot任务上(支持集中包含5个类别,每个类别中有1个样本)分别取得了63.70%和33.51%的分类精度,相比于现有基于原型的小样本图像分类方法分别提高了4.59%和1.78%。

相关技术
  • 一种基于多教师知识蒸馏的跨域小样本识别方法
  • 基于自监督和小样本学习的跨域高光谱图像分类方法
技术分类

06120115921494