掌桥专利:专业的专利平台
掌桥专利
首页

一种用于农作物病害识别的轻量级注意力机制网络

文献发布时间:2023-06-19 19:28:50


一种用于农作物病害识别的轻量级注意力机制网络

技术领域

本发明涉及深度学习领域,尤其是涉及一种用于农作物病害识别的轻量级注意力机制网络,属于深度学习模型在农作物病害识别领域的应用。

背景技术

农作物的病害影响着农作物的生长,降低农作物的产量,影响品质。由于农业具有重要的地位,农作物病害对农作物又有着严重的危害,因此快速准确的识别农作物的病害变得尤为重要。

要解决农作物病害问题,一个关键的方面是快速精准的识别农作物的病害,然后再对症下药。农作物病害不好判别,对农作物病害准确高效的判断是主要的挑战。近些年,深度学习技术得到较快发展,其应用领域广泛。在图像识别领域,卷积神经网络网络取得十分不错的效果,能够有效提取图像的特征,对图像进行分类。研究人员提出各种各样的卷积神经网,比如VGGNet、GoogLeNet、ResNet等。VGGNet是一个相对较深的模型,网络结构也相对简单,并取得不错的效果。GoogLeNet中采用Inception模块进行搭建,该模块采用多分支的结构,使用多个卷积层提取不同的信息,提升网络的表征能力。在ResNet中,提出残差连接技术,能够构造出比较深的卷积神经网络,取得很好效果。在图像识别领域,之前主要是卷积神经网络占据着主导地位,近年来,Transformer模型也被应用于计算机视觉领域。Transformer模型在自然语言处理(NLP)领域取得出众的效果,Transformer模型主要的核心点是自注意力机制,这不同于卷积神经网络和循环神经网络。Vision Transformer(ViT)模型被应用到计算机视觉领域,Vision Transformer(ViT)模型使用Transformer结构去处理视觉任务,并取得不错的结果。ViT模型将图片划分成一个个不重叠的patch,然后对这些patch进行线性映射,之后再输入到Transformer模型中进行计算。图像的patch就类似NLP任务中token。

随着深度学习的快速发展,深度学习技术的应用领域也越来越广泛,深度学习技术逐渐被用在农作物的病害识别上面。许多深度学习模型有着出众的效果,能够比较准确的识别农作物的病害,但其往往参数量比较大,对计算和存储资源要求比较高,因此很难在一些移动端和嵌入式设备上面部署使用。此外Transformer模型在视觉领域也有着不错的效果,其能够学习全局表征,但其参数量也较大,也很难在移动端和嵌入式设备上使用。因此,设计一种轻量级的并且能够有效识别农作物病害的模型十分有意义。

发明内容

本发明的目的在于针对现有技术存在的上述问题,为能更好的识别农作的病害,同时考虑模型的参数量和复杂度,能够在一些移动端和嵌入式设备部署使用,提供一种用于农作物病害识别的轻量级注意力机制网络,基于一种轻量级的网络模型搭建,加入注意力机制。

所述一种用于农作物病害识别的轻量级注意力机制网络,以MobileViT模型为基础,在MobileViT模型的部分MobileViT块中加入通道注意力机制,在MobileViT模型的最后加入通道注意力机制和空间注意力机制;

所述通道注意力机制基于CBAM注意力机制,还包括一维卷积;通道注意力机制用于分析图像通道间的关系,给每个通道一个权重,来获取关键信息,从而提升网络的性能;

所述空间注意力机制基于CBAM注意力机制,还包括多分支网络结构、空洞卷积;所述多分支网络结构由不同大小的卷积核构建,所述空洞卷积层用于增加感受野;空间注意力机制用于分析图像空间之间的关系,给每个像素点一个权重,从而在空间维度获取重要的信息。

所述MobileViT模型搭能有效地学习局部表征和全局表征,为能够更好的捕获农作物病害信息,在模型中加入的改进的注意力机制,使用PlantVillage公开数据集所有的数据进行训练测试,训练方法具体如下:

采用公开数据集PlantVillage训练模型,PlantVillage是一个公开的农作病害数据集,包含38个类别;将PlantVillage数据集随机按照6∶2∶2的比例划分成训练集、验证集和测试集;其中训练集用于模型的训练,验证集用于在训练过程中检验模型状态,测试集用于最后测试模型的效果;

在模型训练过程中,训练集中的每个样本由输入图像和该图像对应的真实类别标签构成;将训练集中的样本数据输入到模型中,得到模型的预测输出,模型的输出为一个向量,假设有C个类别,那么会输出一个总共有C个元素的向量,每个位置代表这个类别的概率;输入图像的真实标签也为一个包含C个元素的向量,该向量的元素只有一个是1,其他都为0,为1的那个位置代表图像的真实类别标签;将模型输出的结果(即预测的标签)与输入图像的真实标签进行比较,通过交叉熵损失函数计算损失,交叉熵损失函数的计算公式为

(1)输入一批图像样本数据到模型中。

(2)通过模型计算出这批数据的预测类别。

(3)将模型输出得到的预测类别和真实类别比较,通过交叉熵损失函数计算损失。

(4)进行反向传播操作,计算模型参数的梯度,采用AdamW优化器更新网络模型参数。

(5)重复上述步骤,当达到训练设定的次数时结束。

以下详细说明本发明网络模型的细节和重点,本发明中的网络模型核心部分包括注意力机制和MobileViT块,其中注意力机制包含通道注意力机制和空间注意力机制。

1)构建通道注意力机制,用于分析图像通道间的关系,给每个通道一个权重,获取关键信息,从而提升网络的性能;

2)构建空间注意力机制,给每个像素点一个权重,对于输入的图像,通过全局平均池化和最大池化将通道信息进行压缩,使用不同大小的卷积核构建多分支网络,以更好的融合信息,提升网络的表征能力,捕获农作物病害信息;使用空洞卷积以增加感受野;将压缩后的图像通过多个卷积核进行处理,结果相加融合,再通过sigmoid计算得到注意力分数;

3)构建用于农作物病害识别的轻量级注意力机制网络,采用MobileViT模型作为基础模型,在MobileViT模型的部分MobileViT块中加入通道注意力机制,在MobileViT模型的最后加入通道注意力和空间注意力机制,实现更好地捕获通道信息和空间信息。

在步骤1)中,所述构建通道注意力机制基于CBAM注意力机制构建,通道注意力机制使用多层感知机(MLP),包含两个全连接层,引入一维卷积以缓解参数量太大的问题:

假设输入的图像为

通过全局平均池化和最大池化操作压缩图像的空间信息,分别得到

其中,

在步骤2)中,所述构建空间注意力机制,不同于通道注意力机制,空间注意力关注的是空间之间的关系,给每个像素点一个权重,从而在空间维度获取重要的信息;在空间注意力机制中,对于输入的图像,通过全局平均池化和最大池化将通道信息进行压缩,分别得到

其中,s

在步骤3)中,所述MobileViT模型包含MV2块和MobileViT块;MV2是MobileNetv2中的倒残差结构;MobileViT块为MobileViT模型的核心模块,用于学习局部表征以及全局表征;所述构建农作物病害识别模型包括局部建模和全局建模:

(1)局部建模:对于输入的图像

其中,d为X

(2)全局建模:包括Unfold、Transformer模块计算和Fold操作;局部建模后得到的输出X

X

其中,f

与现有技术相比,本发明具有以下突出的技术效果和优点:

本发明基于轻量级的Transformer模型即MobileViT网络模型搭建,该模型能够有效地学习局部表征和全局表征,为能够更好的捕获农作物的病害信息,在模型中加入的改进的注意力机制,之后使用PlantVillage公开数据集所有的数据进行训练测试,总共包含38个类别,通过在PlantVillage公开数据集上评估验证,本发明中的模型取得99.60%的识别准确率,说明本发明中的网络模型的有效性。

本发明与一些现有的工作进行相比较,比如在《Using Deep Learning forImage-Based Plant Disease Detection》文献中,作者使用GoogLeNet模型在PlantVillage公开数据集上取得99.35%的准确率。在《Tomato cropdiseaseclassification using pre-trained deep learning algorithm》文献中作者Rangarajan等人选用PlantVillage数据集中的西红柿图像,采用VGG16模型取得96.19%的准确率。在《Grapedisease image classification based on lightweight convolution neuralnetworks and channelwise attention》文献中,在ShuffleNet中加入通道注意力机制,选取在PlantVillage数据集中的葡萄图像进行识别,取得99.14%的准确率。与现有的技术相比较,本发明的优点在于一方面使用MobileViT模型,该网络模型能够有效地学习局部表征和全局表征。另一方面考虑到有些农作物病害比较微小,难以识别,因此加入改进的注意力机制,包括通道注意力机制和空间注意力机制,同时考虑通道维度和空间维度,使得模型能够将重点关注于农作物病害图片中的病害区域,从而有效地识别农作物病害。

本发明对卷积神经网络和Transformer模型进行研究,卷积神经网络和Transformer模型在图像的识别中取得很好的效果,能够用于农作物病害的识别。同时考虑到一些深度学习模型具有较大的参数量,很难在一些嵌入式和移动端设备上面使用,因此选用轻量级模型进行农作物病害的识别。此外对注意力机制进行研究,将注意力机制加入到模型中,用于学习重要的信息,提升模型的表征能力,从而能够有效地识别农作物的病害。

附图说明

图1为输入的农作物病害图片。

图2为加入通道注意力机制的MobileViT块的结构图。

图3为(改进的)MoibleViT块中使用Transformer模块计算时分组的划分方式,其中相同颜色的为一组。

图4为注意力机制结构图,包括通道注意力和空间注意力。

图5为网络模型预测输出的结果图。

具体实施方式

为能更详细的说明本发明,下面结合附图和实施例进行详细的说明。

首先说明本发明中网络模型的训练方式。

在本发明中,采用公开数据集PlantVillage训练模型,PlantVillage是一个公开的农作病害数据集,包含38个类别。将PlantVillage数据集随机按照6∶2∶2的比例划分成训练集、验证集和测试集。其中训练集用于模型的训练,验证集用于在训练过程中检验模型状态,测试集用于最后测试模型的效果。

在模型训练过程中,训练集中的每个样本由输入图像和该图像对应的真实类别标签构成。将训练集中的样本数据输入到模型中,得到模型的预测输出,模型的输出为一个向量,假设有C个类别,那么会输出一个总共有C个元素的向量,每个位置代表这个类别的概率。输入图像的真实标签也为一个包含C个元素的向量,该向量的元素只有一个是1,其他都为0,为1的那个位置代表图像的真实类别标签。将模型输出的结果(即预测的标签)与输入图像的真实标签进行比较,通过交叉熵损失函数计算损失,交叉熵损失函数的计算公式为

(1)输入一批图像样本数据到模型中。

(2)通过模型计算出这批数据的预测类别。

(3)将模型输出得到的预测类别和真实类别比较,通过交叉熵损失函数计算损失。

(4)进行反向传播操作,计算模型参数的梯度,采用AdamW优化器更新网络模型参数。

(5)重复上述步骤,当达到训练设定的次数时结束。

接下来详细说明本发明网络模型的细节和重点。

本发明构建一种有效的注意力机制,加入网络模型中,提升网络的性能,能够有效地捕获农作物病害信息。注意力机制广泛应用于自然语言处理,计算机视觉等领域。注意力机制能够指明哪些是重要信息,从而忽略掉不重要的信息,使模型做出更加准确的判断。在农作物病害识别模型中,引入注意力机制,让模型能够有效地捕获病害信息,提高识别准确率。

本发明主要对CBAM注意力机制进行改进,在保证性能的同时,不会太过于复杂。CBAM包括通道注意力和空间注意力机制。通道注意力机制主要是分析图像通道间的关系,给每个通道一个权重,来获取关键信息,从而提升网络的性能。本发明构建的注意力机制如图4所示,在图4的上面部分为通道注意力机制的计算过程,下面部分为空间注意力机制的计算过程,输入的图像经过通道注意力和空间注意力机制计算之后得到的输出再和输入图像相加,得到最后的输出,通道注意力机制和空间注意力机制的具体计算过程如下所述。

本发明构建的通道注意力主要对CBAM注意力机制进行改进。在CBAM中,通道注意力机制使用多层感知机(MLP),其中包含两个全连接层,导致参数量较大。针对这个问题,本发明根据ECA-Net中的思想,引入一维卷积来缓解参数量太大的问题。假设输入的图像为

其中,

考虑完通道维度的注意力之后,开始构建空间注意力机制,不同于通道注意力机制,空间注意力关注的是空间之间的关系,给每个像素点一个权重,从而在空间维度获取重要的信息。在空间注意力机制中,对于输入的图像,首先通过全局平均池化和最大池化将通道信息进行压缩,分别得到

/>

其中,

采用MobileViT模型作为基础的模型。对于Transformer模型来说,它能够基于自注意力机制学习全局表征,这是它的优势。而对于卷积神经网络来说,其具有空间归纳偏置,能够通过较少的参数学习局部表征。考虑到卷积神经网络和Transformer模型有各自的优势,本发明中选择MobileViT模型构建农作物病害识别模型,该模型结合卷积神经网络和Transformer的优势。在MobileViT中包含的主要模块是MV2块和MobileViT块。MV2块是MobileNetv2块,是MobileNetv2中的倒残差结构。在MobileViT中,核心模块是MobileViT块,该模块结合CNN和Transformer的优势,用于学习局部表征以及全局表征。对于输入的图像

其中d为X

X

其中,f

在MobileViT的基础上,为了更好的捕获信息,本发明在MobileViT的部分MobileViT块中加入通道注意力机制。

以下为输入一张图像,通过网络模型预测得到识别输出的过程。

(1)输入一张图像,如图1所示,其大小为224×224×3,表示该图像的高和宽为224,有3个通道,类别为“Apple_Apple_scab”。首先通过一个3×3的卷积层进行计算,然后再经过BatchNorm层和SiLU激活函数。

(2)在第一阶段,通过一个MobileNetv2模块进行计算,在该模块中,首先通过一个1×1的卷积模块计算,通过1×1卷积层进行升维。然后通过一个3×3的深度卷积进行计算,深度卷积的输入通道数和输出通道数相等,每个通道只使用一个卷积核,减少参数量。最后再通过1×1的卷积进行降维。

(3)在第二阶段,通过三个MobileNetv2模块进行计算,每个MobileNetv2模块计算过程与步骤(2)类似,区别在于当前阶段输出图像的通道数为48,尺寸大小为56×56;在步骤(2)中输出图像的通道数为32,尺寸大小为112×112。

(4)在第三阶段,首先通过一个MobileNetv2模块进行计算,具体过程与步骤(2)类似,区别在于当前阶段输出图像的通道数为64,尺寸大小为28×28。通过改进的MobileViT块进行计算,其结构如图2所示。在改进的MobileViT块中,先通过通道注意力机制进行计算,假设输入为

(5)在第四阶段,首先通过一个MobileNetv2模块进行计算,再通过改进的MobileViT块进行计算,具体过程与步骤(4)类似,区别在于当前阶段输出图像通道数为80,尺寸大小为14×14;步骤(4)中输出图像通道数为64,尺寸大小为28×28。

(6)在第五阶段,首先通过一个MobileNetv2模块进行计算,再通过MobileViT块进行计算,具体过程与步骤(4)类似,区别在于没有加入通道注意力机制以及当前阶段输出图像通道数为96,尺寸大小为7×7。

(7)将前面步骤得到的输出先使用一个1×1的卷积进行计算,然后输入到注意力机制中。假设之前的输出为

然后再将注意力分数和输入图像相乘得到输出

然后再将注意力分数和通道注意力输出相乘得到输出

(8)将前面步骤得到的输出,输入到最后的分类器中,在分类器中首先会通过全局平均池化将之前得到的输出在通道维度进行池化运算,之后再通过全连接层进行计算,最后输出的元素个数为类别的个数,每个元素为输入图像属于该类别的概率,这些元素中的最大值就是网络模型预测的类别,最后将该类别输出,具体如图5所示。

本发明结合MobileViT模型和注意力机制构建一种用于农作物病害识别的轻量级注意力机制网络。对于基本的MobileViT模型,本发明在MobileViT的部分MobileViT块中加入通道注意力机制,此外在MobileViT模型的最后,加入通道注意力和空间注意力机制,从而能更好的捕获通道信息和空间信息。

相关技术
  • 基于注意力特征融合的轻量级农作物病害图像识别方法
  • 一种利用轻量级注意力网络识别水稻病害的系统
技术分类

06120115925595