掌桥专利:专业的专利平台
掌桥专利
首页

一种基于任务驱动的层次注意力网络的MRI图像分类方法

文献发布时间:2023-06-19 09:49:27


一种基于任务驱动的层次注意力网络的MRI图像分类方法

技术领域

本发明涉及一种基于任务驱动的层次注意力网络的MRI图像分类方法。

背景技术

卷积神经网络(convolutional neural network,简称CNN)由于其出色的特征提取能力被广泛用于图像分类任务中。此类网络可以直接将图像作为输入,并自行提取图像的颜色、纹理等特征,避免了传统识别算法中复杂的特征提取和模型构建过程。

由于磁共振成像(magnetic resonance imaging,简称MRI)图像较大,直接使用用于自然图像分类的CNN结构对其分类往往精度不高。所以有许多学者对自然图像分类的CNN结构做了改进,现有的基于CNN的MRI图像分类方法可以分为三类:基于感兴趣区域(regions-of-interest,简称ROI)的分类方法,基于图像块的分类方法和基于全图像的分类方法。

基于ROI的分类方法首先根据专家的领域知识对原图相关区域进行预分割并提取特征,然后构建MRI图像分类器,但是这种方法通常不能覆盖整个MRI图像所有分类相关区域,同时还需要复杂的预处理步骤。基于图像块的分类方法通常先将整幅MRI图像分割成多个图像块,然后从这些图像块中提取特征,最后简单地融合这些图像块特征用于对样本进行分类。基于图像块的方法可以更好地提取图像的局部特征并且不需要领域知识。但是,这些方法只是简单地使用了CNN的最后几层来融合图像块的特征,可能会导致整个图像中一些潜在的信息丢失。基于全图的分类方法针对整个MRI图像提取特征,它们可以获得全局的特征并且不需要专家知识。然而,由于MRI图像较大而与分类相关区域较小,这类方法无法准确定位到这些区域导致最后分类结果不佳。目前较先进的基于全图的分类方法通常会采取注意力机制,即让网络学习全图各区域的权值,通过加权的方法提升分类准确率。

综上所述,基于ROI的分类方法和基于图像块的分类方法侧重于使用不同的策略来提取具有判别性的局部特征,而基于整幅图像的方法侧重于提取整幅图像的语义特征。然而,前两种方法往往忽略了对整幅图像特征的挖掘,而最后一种方法没有充分进行与分类相关的局部区域特征的提取。总之,上述的前两种方法注重于对图像局部特征的提取,而最后一种方法注重于对图像整体特征的提取。

发明内容

本发明所要解决的技术问题是针对上述现有技术提供一种不仅可以定位与分类相关的区域,而且拥有优越的分类性能的基于任务驱动的层次注意力网络的MRI图像分类方法。

本发明解决上述技术问题所采用的技术方案为:一种基于任务驱动的层次注意力网络的MRI图像分类方法,其特征在于包括如下步骤:

步骤1、获取一定数量、且类别已知的MRI图像,对所有类别已知的MRI图像进行预处理,之后将所有预处理之后的MRI图像归一化为统一的大小,形成样本集;

步骤2、将样本集分成训练集、验证集和测试集;

步骤3、构建信息子网络,并利用训练集中的样本对构建的信息子网络进行训练,得到训练完成后的信息子网络;

构建的信息子网络包括信息图提取网络和图像块分类网络,其中,信息图提取网络为用于分类且以卷积核为1*1*1的卷积层替换掉全局平均池化层和全连接层的3D CNN,通过该信息图提取网络输出通道数为1的信息图;图像块分类网络为用于分类的3D CNN;

所述信息子网络训练的具体过程为:

步骤3-1、初始化信息图提取网络和图像块分类网络中的网络参数;

步骤3-2、在训练集中任意选择R张MRI图像,并将选取的R张MRI图像输入到初始化的信息图提取网络中,得到每张MRI图像所对应的通道数为1的信息图,R为正整数;

步骤3-3、分别提取每张信息图中的K个最高值,并根据3D CNN的映射关系选取出每张信息图中的K个最高值分别在与之对应的MRI图像中的K个L*W*H大小的图像块;其中,K、L、W和H均为正整数;

步骤3-4、将从每张MRI图像中选取出的K个L*W*H大小的图像块输入至初始化的图像块分类网络中,得到从每张MRI图像中选取的K个图像块属于各个类别的分类概率,并根据各个类别的分类概率计算出从每张MRI图像中选取的每个图像块的置信度;

步骤3-5、计算总损失函数L,并根据总损失函数L分别反向更新初始化的信息图提取网络和初始化的图像块分类网络中的网络参数,分别得到一次更新后的信息图提取网络和图像块分类网络,即得到一次训练后的信息子网络;

L=L

其中,L

步骤3-6、重复步骤3-2至步骤3-5,依次从训练集中选择不同的MRI图像,将选中的多张MRI图像输入到一次训练后的信息子网络中,不断更新信息子网络中的参数,得到训练完成后的信息子网络;

步骤4、将验证集中的MRI图像送入步骤3训练完成后的信息子网络中,筛选并保存具有最优网络参数的信息子网络;

步骤5、构建层次注意力子网络,并利用训练集中的样本对构建的层次注意力子网络进行训练,得到训练完成后的层次注意力子网络;

构建的层次注意力子网络是以3D CNN为主干的网络结构,另外,该3D CNN的卷积模块后面还设有注意力模块,注意力模块以信息图为输入,注意力模块的输出F′的计算公式为:

F′=M′⊙F,其中M′=TI(M)

其中,F为3D CNN中某个卷积模块输出的特征图,M为信息图,TI()为三线性插值函数,使M′与F的矩阵空间大小相同;⊙为点积运算;

通过上述计算公式得到的F′即作为3D CNN中输出特征图F的卷积模块下一层的输入;

所述层次注意力子网络训练的具体过程为:

步骤5-1、初始化层次注意力子网络中的网络参数;

步骤5-2、将训练集中的Q张MRI图像经过步骤4中所保存的信息子网络得到Q张信息图,并将Q张信息图作为注意力模块的输入,同时将Q张MRI图像输入到设置有注意力模块的层次注意力子网络中,Q为正整数;

步骤5-3、计算层次注意力子网络的损失函数L

其中,x

步骤5-4、依次从训练集中选择不同的MRI图像,重复步骤5-2至步骤5-3,不断更新层次注意力子网络中的参数,得到训练完成后的层次注意力子网络;

步骤6、将验证集中的MRI图像送入步骤5训练完成后的层次注意力子网络中,筛选并保存具有最优网络参数的层次注意力子网络;

步骤7、获取待分类MRI图像的类别:具体过程为:在测试集中任意选择一张MRI图像,记为待分类MRI图像I’,将该待分类MRI图像I’输入到步骤4得到的信息子网络中得到信息图M’,之后将该信息图M’和待分类MRI图像I’输入到步骤6中得到的层次注意力子网络中得到该待分类MRI图像I’属于每一类别的概率,将概率最高的值对应的类别作为该待分类MRI图像的类别。

优选的,所述步骤3中构建的信息图提取网络为一个去掉了全局平均池化和全连接层的3D ResNet18,并增加了三层卷积核为1*1*1的卷积层,构建的图像块分类网络为3DResNet10。

进一步的,所述步骤3-4中每个图像块的置信度的计算方法为:对每个图像块所属的各个类别分别进行one-hot编码,并将编码为1的类别所对应的分类概率作为该图像块的置信度。

优选的,所述步骤5中构建的层次注意力子网络是一个以3D ResNet34为主干的网络。

与现有技术相比,本发明的优点在于:本发明首先用信息子网络获取一张信息图,该信息图蕴含了原图中的各个区域对分类的重要程度,再将信息图用于分层注意子网络加强网络对重要区域特征的提取,通过此种先定位后加强局部特征的方式,显著提高了CNN网络对分类相关区域的关注并有效地结合了全图特征,提高了分类的准确率。

附图说明

图1为本发明实施例中信息子网络的原理框图;

图2为本发明实施例中层次注意力子网络的原理框图。

具体实施方式

以下结合附图实施例对本发明作进一步详细描述。

一种基于任务驱动的层次注意力网络的MRI图像分类方法,包括如下步骤:

步骤1、获取一定数量、且类别已知的MRI图像,对所有类别已知的MRI图像进行预处理,之后将所有预处理之后的MRI图像归一化为统一的大小,形成样本集;

步骤2、将样本集分成训练集、验证集和测试集;

步骤3、构建信息子网络,并利用训练集中的样本对构建的信息子网络进行训练,得到训练完成后的信息子网络;

构建的信息子网络包括信息图提取网络和图像块分类网络,其中,信息图提取网络为用于分类且以卷积核为1*1*1的卷积层替换掉全局平均池化层和全连接层的3D CNN,通过该信息图提取网络输出通道数为1的信息图;图像块分类网络为用于分类的3D CNN;

信息子网络训练的具体过程为:

步骤3-1、初始化信息图提取网络和图像块分类网络中的网络参数;

步骤3-2、在训练集中任意选择R张MRI图像,并将选取的R张MRI图像输入到初始化的信息图提取网络中,得到每张MRI图像所对应的通道数为1的信息图,R为正整数;

步骤3-3、分别提取每张信息图中的K个最高值,并根据3D CNN的映射关系选取出每张信息图中的K个最高值分别在与之对应的MRI图像中的K个L*W*H大小的图像块;其中,K、L、W和H均为正整数;

步骤3-4、将从每张MRI图像中选取出的K个L*W*H大小的图像块输入至初始化的图像块分类网络中,得到从每张MRI图像中选取的K个图像块属于各个类别的分类概率,并根据各个类别的分类概率计算出从每张MRI图像中选取的每个图像块的置信度;

其中,每个图像块的置信度的计算方法为:对每个图像块所属的各个类别分别进行one-hot编码,并将编码为1的类别所对应的分类概率作为该图像块的置信度;

步骤3-5、计算总损失函数L,并根据总损失函数L分别反向更新初始化的信息图提取网络和初始化的图像块分类网络中的网络参数,分别得到一次更新后的信息图提取网络和图像块分类网络,即得到一次训练后的信息子网络;

L=L

其中,L

步骤3-6、重复步骤3-2至步骤3-5,依次从训练集中选择不同的MRI图像,将选中的多张MRI图像输入到一次训练后的信息子网络中,不断更新信息子网络中的参数,得到训练完成后的信息子网络;

步骤4、将验证集中的MRI图像送入步骤3训练完成后的信息子网络中,筛选并保存具有最优网络参数的信息子网络;

步骤5、构建层次注意力子网络,并利用训练集中的样本对构建的层次注意力子网络进行训练,得到训练完成后的层次注意力子网络;

构建的层次注意力子网络是以3D CNN为主干的网络结构,另外,该3D CNN的卷积模块后面还设有注意力模块,注意力模块以信息图为输入,注意力模块的输出F′的计算公式为:

F′=M′⊙F,其中M′=TI(M)

其中,F为3D CNN中某个卷积模块输出的特征图,M为信息图,TI()为三线性插值函数,使M′与F的矩阵空间大小相同;⊙为点积运算;

通过上述计算公式得到的F′即作为3D CNN中输出特征图F的卷积模块相邻的下一层网络的输入;

卷积模块包括有多个卷积层,每个卷积层后还有激活函数;

层次注意力子网络训练的具体过程为:

步骤5-1、初始化层次注意力子网络中的网络参数;

步骤5-2、将训练集中的Q张MRI图像经过步骤4中所保存的信息子网络得到Q张信息图,并将Q张信息图作为注意力模块的输入,同时将Q张MRI图像输入到设置有注意力模块的层次注意力子网络中,Q为正整数;

步骤5-3、计算层次注意力子网络的损失函数L

其中,x

步骤5-4、依次从训练集中选择不同的MRI图像,重复步骤5-2至步骤5-3,不断更新层次注意力子网络中的参数,得到训练完成后的层次注意力子网络;

步骤6、将验证集中的MRI图像送入步骤5训练完成后的层次注意力子网络中,筛选并保存具有最优网络参数的层次注意力子网络;

步骤7、获取待分类MRI图像的类别:具体过程为:在测试集中任意选择一张MRI图像,记为待分类MRI图像I’,将该待分类MRI图像I’输入到步骤4得到的信息子网络中得到信息图M’,之后将该信息图M’和待分类MRI图像I’输入到步骤6中得到的层次注意力子网络中得到该待分类MRI图像I’属于每一类别的概率,将概率最高的值对应的类别作为该待分类MRI图像的类别。

为了能更好的说明本发明提出的基于任务驱动的层次注意力网络的MRI图像分类方法的作用,本实施例中将该方法应用于判断某一张MRI图像中是否包含患阿尔茨海默病的图像特征,其中,如图1所示,本实施例中的具体使用参数及具体方法为:步骤1中对原有的脑部MRI图像数据进行仿射配准等预处理,上述预处理步骤通过FMRIB SoftwareLibrary 5.0完成,并通过三线性插值和填充0值的方式,将所有图像统一至相同的大小(128*128*128),通道数设为1;步骤2中训练集、验证集和测试集的优选比例为7:2:1,并要求训练集和验证集都包含MRI图像总类别数;步骤3中构建的信息图提取网络为一个去掉了全局平均池化和全连接层的3D ResNet18,并增加了三层卷积核为1*1*1的卷积层,构建的图像块分类网络为3D ResNet10,具体的网络结构如表1所示:

另外,步骤3-4中MRI图像的分类中使用的K值为4,L*W*H大小为48*48*48;根据上述步骤3-4中得到的4个48*48*48的图像块输入到图像块分类网络中,即可得到每个图像块的分类概率,如图1中所示,本实施例中,图像块所属的类别共有两类,即包含有患有阿尔兹海默症特征图像的第一类MRI图像和不包含患有阿尔兹海默症特征图像的第二类MRI图像,对每个MRI图像中选取的每个图像块所属的类别分别进行one-hot编码,图1中c

表1各网络具体结构

本发明中的方法与基于ROI和图像块的方法相比,优点为:不仅考虑局部病变区域的特征提取,同时还结合了全图的整体结构特征,进一步提高了分类能力;另外,本发明中的方法与基于全图的CNN方法相比,优点为:本发明采用了注意力机制,让网络更注重于对局部特征的提取,并且不丢失全局信息;另外,本发明中注意力块从浅层到深层使用,并且总是与分类区域有关,而一般的空间注意力网络通常是从网络本身生成注意图,此类方法由于梯度消失,浅层生成的注意图不能直接与疾病相关区域重合,这会阻碍一般注意力网络的工作。因此本发明中的方法不仅可以定位与分类相关的区域,而且拥有优越的分类性能。

相关技术
  • 一种基于任务驱动的层次注意力网络的MRI图像分类方法
  • 一种基于层次注意力网络模型的恶意软件分类方法
技术分类

06120112316396