掌桥专利:专业的专利平台
掌桥专利
首页

一种基于概率图和ViT模型的图片分类方法

文献发布时间:2023-06-19 19:27:02


一种基于概率图和ViT模型的图片分类方法

技术领域

本发明涉及一种基于概率图和ViT模型的图片分类方法,属于计算机视觉图片分类技术领域。

背景技术

Transformer在自然语言处理领域取得了巨大的成功,激励了人们尝试将多头注意力机制引入主流框架为卷积神经网络的计算机视觉领域。相较于卷积神经网络,Transformer在捕捉图片全局信息方面有着巨大的优势,同时,Transformer的可并行化计算也促进了其在视觉领域的应用。目前Vision Transformer在计算机视觉的各类任务,如图片分类、目标检测和图片降噪等方面取得了令人瞩目的效果。但是大量研究发现,Transformer的核心多头注意力机制中的不同头之间存在参数冗余,严重影响了模型的整体性能。

发明内容

本发明的目的是提供一种基于概率图和ViT模型的图片分类方法,用来解决上述缺陷。

为了实现上述目的,本发明采用的技术方案是:

一种基于概率图和ViT模型的图片分类方法,包括以下步骤:

S1、将输入模型的图片进行分块,然后将每个图片块展平成一维向量,最后通过线性变换生成patch embedding;

S2、给每个patch embedding加上位置编码,补充位置信息;

S3、增加一个用于分类的Token,学习其他图片patch的整体信息;

S4、基于头部交互的Transformer Block,把attention values看作隐变量,利用概率图模型中的Explaining-away Effects以及Transformer的层级结构,将attentionlogits层层传递,并将相邻层的值进行融合,促进不同头部之间的交互;

S5、使用两层全连接层,将分类Token输入分类层,得到图片的分类结果。

本发明技术方案的进一步改进在于,所述S1的具体步骤为:

S11、将输入模型的图片进行分块、展平,具体操作为:

将图片patch的长宽均设置为P,即将图片数据H*W*C变换为

其中,N为一张图分割的patch数量,C为通道数,H为图片高度,W为图片宽度;

S12、将patch向量线性变换为patch embedding:

patch_embedding=nn.Linear(patch_dim,dim)

其中,patch_dim为patch向量的维度,dim为patch embedding的维度。

本发明技术方案的进一步改进在于,所述S2的具体操作为:

pos_embedding=nn.Parameter(torch.randn(1,num_patches+1,dim))

其中,pos_embedding为patch的位置编码,num_patches为patch的数量。

本发明技术方案的进一步改进在于,所述S3的具体操作为:

添加一个专门用于分类的可学习编码,与输入进行拼接,具体为:

cls_token=nn.Parameter(torch.randn(1,1,dim))

其中,cls_token为分类Token,然后与其他patch token进行拼接。

本发明技术方案的进一步改进在于,所述S4的具体步骤为:

S41、attention head序列建模,将attention value看作隐变量,

p(Y∣X)=∫

其中,Y为图片label,X为输入照片,A为中间层Attention values,p(A∣X)是联合先验分布;

S42、Transformer层次化建模,利用transformer的层级结构,将此过程可表示为:

其中,A

S43、相邻层的attention融合,在Transformer的层级Block的多头注意力计算模块添加MLP,将各层之间的attention vlaue进行融合交互,促进不同头部的去冗余,将此过程可表示为:

A

其中,z

本发明技术方案的进一步改进在于,所述S5的具体操作公式为:

x=self.to_cls_token(x[:,0])

y=self.mlp_head(x)

其中,x为输出的分类Token,mlp_head()为分类层,y为输出的预测。

由于采用了上述技术方案,本发明取得的技术效果有:

本发明设计的一种基于概率图和ViT模型的图片分类方法针对普通VisionTransformer模型中多头注意力机制头部参数的冗余问题,将多头注意力机制建模为概率图模型,将注意力值看作隐变量,促进不同注意力头部之间的交互。

本发明为了促进不同头部之间的交互,将attention logits逐层传递,将相邻层之间的attention logits进行融合,促使不同的头部捕捉不同的特征。

本发明通过促进头部交互,提升了参数效率,进而提升了图形分类正确率以及迁移学习效果,同时提高了特征的可解释性。

附图说明

图1是本发明的算法流程图;

图2是本发明的模型架构图;

具体实施方式

下面结合附图及具体实施例对本发明做进一步详细说明:

一种基于概率图和ViT模型的图片分类方法,如图1所示,包括如下步骤:

S1、图片分块;

将输入模型的图片进行分块,然后将每个图片块展平成一维向量,最后通过线性变换生成patch embedding;

S2、位置编码;

给每个patch embedding加上位置编码,补充位置信息;

S3、分类Token;

增加一个用于分类的Token,学习其他图片patch的整体信息;

S4、基于头部交互的Transformer Block;

把attention values看作隐变量,利用概率图模型中的Explaining-awayEffects以及Transformer的层级结构,将attention logits层层传递,并将相邻层的值进行融合,促进不同头部之间的交互;

S5、使用两层全连接层,将分类Token输入分类层,得到图片的分类结果。

具体实施过程如下:

S1、图片分块

S11、将输入模型的图片进行分块、展平,此过程可描述为:

将图片patch的长宽均设置为P,即将图片数据H*W*C变换为

其中,N为一张图分割的patch数量,C为通道数,H为图片高度,W为图片宽度。

S12、将patch向量线性变换为patch embedding,此过程可描述为:

patch_embedding=nn.Linear(patch_dim,dim)

其中,patch_dim为patch向量的维度,dim为patch embedding的维度。

S2、位置编码,为保留图片的相对位置信息,给各个patch加上可学习的位置编码

pos_embedding=nn.Parameter(torch.randn(1,num_patches+1,dim))

其中,pos_embedding为patch的位置编码,num_patches为patch的数量。

S3、分类Token,添加一个专门用于分类的可学习编码,与输入进行拼接,具体为:

cls_token=nn.Parameter(torch.randn(1,1,dim))

其中,cls_token为分类Token,然后与其他patch token进行拼接。

S4、基于头部交互的Transformer Block

S41、attention head序列建模,将attention value看作隐变量,

p(Y∣X)=∫

其中,Y为图片label,X为输入照片,A为中间层Attention values,p(A∣X)是联合先验分布。

S42、Transformer层次化建模,利用transformer的层级结构,将此过程可表示为:

其中,A

S43、相邻层的attention融合,在Transformer的层级Block的多头注意力计算模块添加MLP,将各层之间的attention vlaue进行融合交互,促进不同头部的去冗余,将此过程可表示为:

A

其中,z

S5、用两层全连接层,将分类Token输入分类层,得到图片的分类结果:

x=self.to_cls_token(x[:,0])

y=self.mlp_head(x)

其中,x为输出的分类Token,mlp_head()为分类层,y为输出的预测。

本发明提出将多头注意力机制从概率论的角度进行建模。将多头注意力中的attention value看作隐变量,利用概率图模型的Explaining-away Effects以及Transformer的层级结构,将attention logits层层传递,并将相邻层的值进行融合,促进不同头部之间的交互。在实验中,首先将本发明方法设计的模型在ImageNet1K数据集上预训练,图片分类正确率相较于普通ViT模型提高了1.56%;然后将训练得到的模型参数应用到下游小型数据集CIFAR100数据集上进行迁移学习,分类正确率提高了3.47%。实验证明,本发明方法在一定程度上促进了参数的有效性以及不同头部之间的独立性,提高了图片分类的正确率和迁移学习性能。另一方面,将图片的attention可视化发现,本发明方法实现了不同的头部注意到图片不同的部分,证明了模型学习在特征的可解释性方面具有一定的研究意义和应用价值。

相关技术
  • 一种基于概率图模型的工业过程报警根源识别方法
  • 基于概率图模型以及场景分类的光伏特性曲线预测方法
  • 基于概率图模型的频繁模式关联分类方法
技术分类

06120115917924