掌桥专利:专业的专利平台
掌桥专利
首页

一种基于领域自适应的迁移学习方法及系统

文献发布时间:2023-06-19 19:30:30


一种基于领域自适应的迁移学习方法及系统

技术领域

本发明涉及人工智能医学图像处理技术领域,尤其涉及一种基于领域自适应的迁移学习方法及系统。

背景技术

OCT(光学相干断层扫描技术),因其能通过测量组织不同深度反射回来的光信号从而对生物组织的横截面进行快速、无损伤的高分辨率成像,是近30年来发展较为迅速的一种生物医学成像技术,在眼科诊所中,OCT视网膜图像被视为眼科医生筛查糖尿病性黄斑水肿、年龄相关性黄斑变性等常见视网膜疾病的一个重要依据,然而这种人工诊断的方式不仅费时费力且容易受到医生主观经验的影响,近年来,以卷积神经网络代表的深度学习技术(DL),作为机器学习最耀眼的分支,在有关眼科疾病的自动诊断任务中取得了显著的进展。相较于人工诊断的方式,基于DL的计算机辅助诊断系统(CAD)能够根据输入图像自动地输出精确的诊断结果,若能将基于DL的CAD系统真正应用于临床,将会给眼科疾病诊治管理带来巨变,然而,以往基于监督式的DL模型的训练和更新均依赖于大量带有标注的数据。由于在医学领域,对大量数据进行标注需要花费巨大的成本。因此,尽管医院可以提供大量的OCT视网膜图像,但这些图像往往都是很初级的原始形态,即很少有数据被加以正确的人工标注。标注的数据过少,导致DL模型性能不佳,且对于通过传统的预训练加微调的传统迁移学习方法,由于自然图像和OCT图像存在较大差异,因此,即使借助了传统迁移学习的方法,仍需一定数量的OCT标注图像去微调网络,这限制了基于DL的OCT视网膜疾病自动检测系统在真实临床实践中的应用。

发明内容

为了解决上述技术问题,本发明的目的是提供一种基于领域自适应的迁移学习方法及系统,能够加快模型的收敛速度并提高模型的识别分类精度。

本发明所采用的第一技术方案是:一种基于领域自适应的迁移学习方法,包括以下步骤:

对OCT视网膜图像数据集进行图像预处理,得到预处理后的图像数据集;

结合领域自适应损失函数,基于预处理后的图像数据集对初步的领域自适应模型进行训练,得到最终的领域自适应模型;

通过最终的领域自适应模型对OCT视网膜图像数据集进行预测,得到OCT视网膜图像对应的疾病类型。

进一步,所述对OCT视网膜图像数据集进行图像预处理,得到预处理后的图像数据集这一步骤,其具体包括:

通过谱域OCT系统获取OCT视网膜图像数据集;

对OCT视网膜图像数据集进行随机水平翻转处理,得到翻转后的图像数据集;

对翻转后的图像数据集进行随机剪裁处理,得到剪裁后的图像数据集;

对剪裁后的图像数据集进行归一化处理,得到预处理后的图像数据集。

进一步,所述预处理后的图像数据集包括完全标记的图像数据集与部分标记的图像数据集,所述结合领域自适应损失函数,基于预处理后的图像数据集对初步的领域自适应模型进行训练,得到最终的领域自适应模型这一步骤,其具体包括:

基于预处理后的图像数据集对初步的领域自适应模型进行训练,得到训练后的领域自适应模型;

基于部分标记的图像数据集对训练后的领域自适应模型进行微调处理,得到最终的领域自适应模型。

进一步,所述基于预处理后的图像数据集对初步的领域自适应模型进行训练,得到训练后的领域自适应模型这一步骤,其具体包括:

将预处理后的图像数据集输入至初步的领域自适应模型,所述初步的领域自适应模型包括特征提取器、瓶颈层和分类器;

基于特征提取器对预处理后的图像数据集进行特征提取处理,得到预处理后的图像数据集的特征向量;

基于瓶颈层对预处理后的图像数据集的特征向量进行降维处理,得到低维的完全标记图像数据集特征向量与低维的部分标记图像数据集特征向量;

通过领域自适应损失函数对低维的完全标记图像数据集特征向量与低维的部分标记图像数据集特征向量进行差值计算,得到分布差异值;

对完全标记的图像数据集与部分标记的图像数据集进行赋予标签处理,得到具有真实标签的完全标记的图像数据集与具有伪标签的部分标记的图像数据集;

基于分类器,结合分布差异值与交叉熵损失函数对具有真实标签的完全标记的图像数据集与具有伪标签的部分标记的图像数据集进行分类,得到分类结果;

根据分类结果更新初步的领域自适应模型,得到训练后的领域自适应模型。

进一步,所述通过领域自适应损失函数对低维的完全标记图像数据集特征向量与低维的部分标记图像数据集特征向量进行差值计算,得到分布差异值这一步骤,其具体包括:

将低维的完全标记图像数据集特征向量与低维的部分标记图像数据集特征向量进行映射处理,得到映射后的完全标记图像数据集特征向量与映射后的部分标记图像数据集特征向量;

对映射后的完全标记图像数据集特征向量与映射后的部分标记图像数据集特征向量进行加权平均计算,得到完全标记图像数据集的平均值与部分标记图像数据集的平均值;

对完全标记图像数据集的平均值与部分标记图像数据集的平均值进行作差计算,得到分布差异值。

进一步,所述领域自适应损失函数的表达式具体如下所示:

/>

上式中,C表示类别总数,H表示再生核希尔伯特空间(RKHS),

进一步,所述交叉熵损失函数的表达式具体如下所示:

上式中,N表示完全标记的图像数据集中的样本总数,C表示类别总数,

本发明所采用的第二技术方案是:一种基于领域自适应的迁移学习系统,包括:

预处理模块,通过对OCT视网膜图像数据集进行图像预处理,得到预处理后的图像数据集;

训练模块,用于结合领域自适应损失函数,基于预处理后的图像数据集对初步的领域自适应模型进行训练,得到最终的领域自适应模型;

预测模块,用于通过最终的领域自适应模型对OCT视网膜图像数据集进行预测,得到OCT视网膜图像对应的疾病类型。

本发明方法及系统的有益效果是:本发明通过获取完全标记的图像数据集与部分标记的图像数据集对本发明的领域自适应模型进行训练,在训练过程中加入领域自适应损失函数减少完全标记的图像数据集与部分标记的图像数据集之间的特征差异,使得能够成功地进行知识迁移,就可以避免昂贵的数据标记工作,从而大大提高模型的学习性能,再通过部分标记的图像数据集对训练后的领域自适应模型进行微调处理,有效的防止模型过拟合问题,能够加快模型的收敛速度并提高模型的泛化能力进而提高模型对没有标记的OCT视网膜图像的分类精度。

附图说明

图1是本发明一种基于领域自适应的迁移学习方法的步骤流程图;

图2是本发明一种基于领域自适应的迁移学习系统的结构框图;

图3是本发明基于领域自适应模型训练的方法流程图;

图4是本发明构建的领域自适应模型的结构框图;

图5是本发明对领域自适应模型进行微调的流程示意图;

图6是本发明基于最终的领域自适应模型进行OCT图像预测示意图。

具体实施方式

下面结合附图和具体实施例对本发明做进一步的详细说明。对于以下实施例中的步骤编号,其仅为了便于阐述说明而设置,对步骤之间的顺序不做任何限定,实施例中的各步骤的执行顺序均可根据本领域技术人员的理解来进行适应性调整。

参照图1,本发明提供了一种基于领域自适应的迁移学习方法,该方法包括以下步骤:

S1、对OCT视网膜图像数据集进行图像预处理,得到预处理后的图像数据集;

具体地,收集两组OCT视网膜图像数据集(一个完全标注的数据集和一个部分标记的数据集),然后对图像进行预处理,包括:随机水平翻转、随机裁剪、归一化等,得到预处理后的图像数据集。

S2、由于面临着不同的用户群、不同厂商的采集设备有各自的图像特点等现实问题,完全标记的图像数据集与部分标记的图像数据集具有不同分布,因此结合领域自适应损失函数,基于预处理后的图像数据集对初步的领域自适应模型进行训练,得到最终的领域自适应模型;

具体地,参照图4,所搭建的领域自适应模型是一个由特征提取器,瓶颈层以及分类器等组成的卷积神经网络,这里的特征提取器使用的是去除全连接层后的ResNet-50网络,目的是提取不随域变化的特征向量;瓶颈层是一个全连接层,目的是把特征向量降维;分类器也是一个全连接层,目的是能够根据降维后的特征向量对视网膜病变进行精准的分类,整个领域自适应模型包含两个损失函数,分别是领域自适应损失以及由真实标签和伪标签组成的分类损失。

S21、计算预处理后的图像数据集的样本权重;

具体地,所述权值的计算为了对齐子领域之间的特征分布,对于完全标注数据集中的样本,权重

上式中,

对于完全标注数据集中的样本,由于它们带有标签,故将真实标签转成one-hot分布去计算权重,而对于未完全标注数据集中的样本,由于没有标签,故使用网络预测的概率分布去计算权重。

S22、基于预处理后的图像数据集对领域自适应模型进行训练;

具体地,参照图3,依次输入完全标注数据集和未完全标注数据集中的图片到卷积神经网络(ResNet-50),分别得到完全标注数据集中图片的特征向量和未完全标注数据集中图片的特征向量;接着将这些特征向量通过瓶颈层去实现降维,得到低维的特征向量;再接着使用局部最大均值差异衡量降维后的完全标注数据集和未完全标注数据集的特征差异。换言之,该公式表示的意思是:首先将完全标注数据集中的所有样本映射到高维空间,然后根据每个样本与类别c的关联程度,计算加权平均,表示完全标注数据集中的类别c的特征在高维空间的平均值,同理计算出未完全标注数据集中类别c的特征在高维空间的平均值。最后将它们差值表示的是属于类别c的子领域的分布差异,其计算公式如下所示:

上式中,C表示类别总数,H表示再生核希尔伯特空间(RKHS),

进一步的,由于无法直接计算

上式中,k表示核函数,f

S23、预处理后的图像数据集的分类损失计算;

具体地,由于领域自适应损失的作用,两种不同数据集之间的特征分布差异被减少,此外,当训练数据集和测试数据集分布不一致时,通过在训练数据集上按经验误差最小准则训练得到的模型在测试数据集上性能不佳,此时,减少他们之间的特征分布差异就能成功地进行知识迁移,就可以避免昂贵的数据标记工作,从而大大提高模型的学习性能,因为自适应后的特征会交给分类器进行特征映射,映射到样本的标记空间,也就是说,如果分类器能够根据自适应后的特征正确做出预测,就说明自适应后的特征是有意义的,为了使得经过自适应后的特征是有意义的,还需要在特征提取后面加一个分类器,对于完全标注的数据集中的样本,本发明使用的是交叉熵损失去更新分类器参数,如下所示:

上式中,N表示完全标记的图像数据集中的样本总数,C表示类别总数,

然而,仅使用带有真实标签的代表特征去训练分类器可能会导致DL模型过拟合于完全标注数据集,使得模型在未完全标注数据集中的效果不佳,为了解决这个问题,本发明同时还使用带有伪标签的代表特征去训练分类器,基于伪标签的分类损失函数如下所示,也是基于交叉熵(与上面不同的是,由于在领域自适应中,未完全标注数据集图片是没有预先标记好的类别标签,对此我们使用的是伪标签计算交叉熵),其计算公式如下所示:

上式中,M表示部分标记的图像数据集中的样本总数,

对于伪标签的获取,其具体过程为:用网络去预测一个无标签的样本,如果网络有大于95%的自信度觉得它属于a类,那么可以认为这个样本就是a类,此时就得到这个样本的伪标签,且本发明的标签赋予过程存在不一致的处理手段过程,真实标签指的专业医生根据自身经验给出的标注,而伪标签指的是模型输出、具有高自信度(某一样本的预测值大于95%)的预测值。

S24、对训练后的领域自适应模型进行微调处理;

具体地,参照图5,通过降低前几层的学习率,实现对自适应后的模型的微调,使用较小的学习率来训练模型的全连接层,而全连接层前面部分的学习率设为0,即在微调的过程中,仅对模型的全连接层的参数进行更新,全连接层前面部分的模型参数保持不变。

S3、通过最终的领域自适应模型对OCT视网膜图像数据集进行预测,得到OCT视网膜图像对应的疾病类型。

具体地,参照图6,将输入图像送入最终模型,模型的最后一层全连接层会输出图像所属于每个类别的概率值,最大的概率即对应于图像所属的疾病类别;例如最终模型是预测AMD,DME以及正常三个类别,模型最后一层给出的概率值分别为:92%,6%,2%,则模型预测这张图像属于AMD。

参照图2,一种基于领域自适应的迁移学习系统,包括:

预处理模块,通过对OCT视网膜图像数据集进行图像预处理,得到预处理后的图像数据集;

训练模块,用于结合领域自适应损失函数,基于预处理后的图像数据集对初步的领域自适应模型进行训练,得到最终的领域自适应模型;

预测模块,用于通过最终的领域自适应模型对OCT视网膜图像数据集进行预测,得到OCT视网膜图像对应的疾病类型。

上述方法实施例中的内容均适用于本系统实施例中,本系统实施例所具体实现的功能与上述方法实施例相同,并且达到的有益效果与上述方法实施例所达到的有益效果也相同。

以上是对本发明的较佳实施进行了具体说明,但本发明创造并不限于所述实施例,熟悉本领域的技术人员在不违背本发明精神的前提下还可做作出种种的等同变形或替换,这些等同的变形或替换均包含在本申请权利要求所限定的范围内。

相关技术
  • 一种基于特征迁移和自适应学习的人民调解案例分类系统及方法
  • 一种物联网迁移学习方法和系统
  • 一种基于细粒度领域自适应的图迁移学习方法
  • 一种基于细粒度领域自适应的图迁移学习方法
技术分类

06120115934138