掌桥专利:专业的专利平台
掌桥专利
首页

一种基于度量学习的无监督目标检测模型训练方法

文献发布时间:2024-04-18 19:58:30


一种基于度量学习的无监督目标检测模型训练方法

技术领域

本发明涉及图像识别领域,尤其是涉及一种基于度量学习的无监督目标检测模型训练方法。

背景技术

近年来,目标检测已经逐步成为计算机视觉领域的重要研究课题,并且已经取得广泛应用,例如工业互联网,安防,医学辅助诊断,遥感影像分析等。目标检测主要是从视觉图像或视频中识别感兴趣目标的位置及类别。由于具有对图像特征的深度提取能力,基于深度学习在目标检测中取得了巨大的进展。然而,目前目标检测的主流模型训练时仍然需要大量的人工标注数据,标注的巨大成本和目标检测应用的中感兴趣目标的不确定性使得目标检测模型训练成本高,应用推广难。此外,近些年来基于多模态预训练大模型开启了深度学习利用多模态数据图文信息的能力。然而,目前的预训练大模型对于目标的特征表达能力,及对目标的细节信息提取能力仍有不足,限制了其在实际应用的效果。

针对以上两个难点,如何在不采用人工标注数据的情况下,提升目标检测模型对目标特征及细节信息的表达能力,提升目标检测的效果,是目前所需要亟待解决的问题。

发明内容

本发明主要是解决现有技术所存在的难以在不采用人工标注数据的情况下完成目标检测模型的训练、提升目标检测的效果的技术问题,提供一种基于度量学习的无监督目标检测模型训练方法,可以脱离人工标注数据仍然完成对目标检测模型的训练过程。

本发明针对上述技术问题主要是通过下述技术方案得以解决的:一种基于度量学习的无监督目标检测模型训练方法,包括以下步骤:

S1、通过开源预训练模型对训练数据集所包含的图片进行目标提取,获得伪标签,训练数据集为图文对,即包含图片和描述图片的文本标签,伪标签包括目标的坐标信息和目标的描述文本;

S2、待训练目标检测模型包括图像编码骨干模型、特征金字塔网络和检测头;待训练目标检测模型在随机初始化或者加载预训练模型检测点后(加载预训练模型检测点也表示待训练目标检测模型获得了基本参数,可以进行目标检测过程),将训练数据集的图片输入到待训练目标检测模型中,所得到的结果包括整体图片i和各目标的区域图片;目标j为待训练目标检测模型得到的目标之一,将目标j与开源预训练模型得到的各目标进行对比,将重合度大于0.5的目标j’的伪标签作为目标j的伪标签,以同样方式确定待训练目标检测模型得到的所有目标的伪标签,然后提取待训练目标检测模型得到的目标的特征向量;检测头用于从整体图片i中检测目标;一般是先将整体图片i输入到图像编码骨干模型中得到图片i的整体多层特征图,然后用特征金字塔网络从整体多层特征图中提取整体特征向量,最后检测头依据整体特征向量从整体图片i中检测目标;重合度为两者的交集除以两者的并集;如果有多个重合度大于0.5的伪标签则选取重合度最大的作为目标j的伪标签;没有重合度大于0.5的伪标签则将此目标丢弃不加入训练;

S3、基于度量模型对待训练目标检测模型进行训练,对于每一个目标j,其损失函数为:

L=λ

其中,L

本方案适用于训练常规的任意目标检测模型,只需要目标检测模型符合包括图像编码骨干模型、特征金字塔网络和检测头三部分这个特点即可。开源预训练模型可以采用常规的开源并且经过训练的目标检测模型,只需要此开源预训练模型可以输出目标检测结果和描述文本即可。

作为优选,所述步骤S2中,提取目标j的特征向量具体过程为:

S201、将整体图片i输入到图像编码骨干模型,抽取每个输出层输出的特征得到图片i的整体多层特征图;

S202、依据目标j的区域图片在图片i中的位置,从整体多层特征图中截取得到目标多层特征图;

S203、使用特征金字塔网络对目标多层特征图进行提取池化,得到目标j的特征向量。

作为优选,所述三元组相似性度量学习损失为:

L

式中max为取最大值,S为计算括号中两个目标特征向量的余弦距离,具体公式为:

式中,T表示转置,双竖线表示求向量长度,即norm2;S(a,n)的定义与S(a,p)相同,只是将p替换成n;a为基准目标样本的特征向量,p为正目标样本的特征向量,n为负目标样本的特征向量;margin为预设的间隔参数;训练过程中,对每一个待训练目标检测模型生成的匹配到伪标签的目标j均为基准目标样本;对于每一个基准目标样本j,选择与目标j的目标特征向量余弦距离最远且含有相同伪标签描述文本的目标作为正目标样本;选择与目标j的目标特征向量余弦距离最近且含有不同伪标签描述文本的目标作为负目标样本。

通过对目标三元组的度量学习损失,可以使目标检测模型网络模块抽取到不同文本描述的目标的细节特征,提高模型对目标的细粒度识别能力。

作为优选,目标框的位置L1损失具体为:

式中,ti表示待训练目标检测模型得到的目标j的目标框坐标,为一个四点向量{tx、ty、tw、th},分别表示坐标在x轴和y轴上的位置及目标对应的宽w和高h。ti*表示开源预训练模型得到的目标j’的目标框坐标,为一个四点向量{tx*、ty*、tw*、th*}。

作为优选,目标框的位置GIOU损失具体为:

式中,A表示待训练目标检测模型检测到的目标j的目标框,B为开源预训练模型检测到的与目标j对应的目标j’的目标框;C表示A和B两个目标框的最小外接矩形面积,IOU为A和B的重叠度。A、B、C均为面积。

作为优选,图文对比损失具体为:

式中,N为训练所用的图文对样本总数,训练所用的样本包括整体图像和整体图像的文本标签(来源于数据),v_m是第m个图文对样本的图片经过待训练目标检测模型的表征提取后得到的整体多层特征图(步骤S201得到),l_m是第m个图文对样本的文本标签经过文本编码器后得到的文本表征,文本编码器采用Bert或Roberta,这里的文本编码器不需要是S1模型的文本编码器。文本编码器参与训练但不会被更新,p(v_m,l_m)的计算公式如下:

式中,S为计算括号中两个对象的相似度,τ为温度超参数,Nri为同一批训练所用样本中除第m个图文对样本之外的其它图文对样本的文本标签经过文本编码器后得到的文本表征集合,即文本标签k和图片i不匹配,但是和同一个batch(训练样本集)中的其他图像k匹配。

开源预训练模型包括一个图像编码器和一个文本编码器,图像编码器和文本编码器在预训练时在语义空间对齐;图像编码器生成图像i中的目标j的坐标信息并对每个目标进行目标特征提取;文本编码器对图像对应的文本的每个单词/词组进行特征提取;利用开源预训练模型的多模态对齐能力,将图像的目标特征和文本的每个单词/词组特征进行点乘,该点乘结果作为图片目标与文本单词/词组的对齐分数,对于每个目标j,选择对齐分数最高的文本单词/词组作为该目标的描述文本。

本发明带来的实质性效果是,在没有人工标注数据的情况下,利用图文多模态数据,及大规模预训练进行无监督的目标检测训练。首先利用大规模预训练模型进行目标检测模型初始化,并形成目标及目标标签的伪标签信息。然后,利用伪标签信息,对图像目标的区域表征进行度量学习,即采用三元组相似度量模型去区分目标之间的细节信息,从而充分提取图像目标的细粒度信息,从而达到提高目标检测效果的作用。

附图说明

图1是本发明的一种流程图。

具体实施方式

下面通过实施例,并结合附图,对本发明的技术方案作进一步具体的说明。

实施例:一种基于度量学习的无监督目标检测模型训练方法,如图1所示,包括以下步骤:

S1、通过开源预训练模型对训练数据集所包含的图片进行目标提取,获得伪标签,训练数据集为图文对,即包含图片和描述图片的文本标签,伪标签包括目标的坐标信息和目标的描述文本;

S2、待训练目标检测模型包括图像编码骨干模型、特征金字塔网络和检测头;待训练目标检测模型在随机初始化或者加载预训练模型检测点后,将训练数据集的图片输入到待训练目标检测模型中,所得到的结果包括整体图片i和各目标的区域图片;目标j为待训练目标检测模型得到的目标之一,将目标j与开源预训练模型得到的各目标进行对比,将重合度大于0.5的目标j’的伪标签作为目标j的伪标签,以同样方式确定待训练目标检测模型得到的所有目标的伪标签,然后提取待训练目标检测模型得到的目标的特征向量;检测头用于从整体图片i中检测目标;一般是先将整体图片i输入到图像编码骨干模型中得到图片i的整体多层特征图,然后用特征金字塔网络从整体多层特征图中提取整体特征向量,最后检测头依据整体特征向量从整体图片i中检测目标;重合度为两者的交集除以两者的并集;如果有多个重合度大于0.5的伪标签则选取重合度最大的作为目标j的伪标签;没有重合度大于0.5的伪标签则将此目标丢弃不加入训练;

S3、基于度量模型对待训练目标检测模型进行训练,对于每一个目标j,其损失函数为:

L=λ

其中,L

本方案适用于训练常规的任意目标检测模型,只需要目标检测模型符合包括图像编码骨干模型、特征金字塔网络和检测头三部分这个特点即可。图像编码骨干模型为任意的基于图像的神经网络模型,例如resnet、convnext、vision transformer等。目标检测所得到的包含目标j的结果包括图片i和目标j区域图片;图片i一般为3个颜色通道的自然图像;目标j区域图片理解为从图片i中截取的目标j所在区域的图像,具体形式可以是若干个坐标围成的范围。开源预训练模型可以采用常规的开源并且经过训练的目标检测模型,只需要此开源预训练模型可以输出目标检测结果和描述文本即可。

所述步骤S2中,提取目标j的特征向量具体过程为:

S201、将整体图片i输入到图像编码骨干模型,抽取每个输出层输出的特征得到图片i的整体多层特征图;

S202、依据目标j的区域图片在图片i中的位置,从整体多层特征图中截取得到目标多层特征图;

S203、使用特征金字塔网络对目标多层特征图进行提取池化,得到目标j的特征向量。

提取池化即对图像编码骨干模型中conv2,conv3,conv4和conv5的输出层{C2,C3,C4,C5}作为FPN的特征,对于图像目标i,对目标{C2,C3,C4,C5}所在位置进行池化操作,例如ROIAlign,得到目标j的特征向量。

具体操作如下:

1)假设输入一张800*800的图片,图片上有一个640*640的目标j。图片经过主干网络部分提取特征后,特征图缩放步长(stride)为32,得到25*25的全图特征图。对于目标j,目标特征图的大小为(640/32=)20*20;

2)假设需要框内的特征池化为7*7的大小,将在特征图上映射的20*20的目标特征图划分成49个同等大小的小区域,每个小区域的大小20/7=2.86,即2.86*2.86;

3)假定采样点数为4,即表示,对于每个2.86*2.86的小区域,平分四份,每一份取其中心点位置,而中心点位置的像素,采用双线性插值法进行计算,这样,就会得到四个点的像素值。最后取4个像素值最大值作为这个小区域的像素值,如此类推,得到49个像素值,成为7x7的特征图。

对于目标j对{C2,C3,C4,C5}池化得到的特征图进行拼接,得到目标j的特征图。该特征图经过一个非线性变换(神经网络全连接层)被用来将目标特征表示映射到对比损失的空间;对图像i的目标j,通过图像区域编码器将目标抽取为一个特征向量v_ij。

所述三元组相似性度量学习损失为:

L

式中max为取最大值,S为计算括号中两个目标特征向量的余弦距离,具体公式为:

式中,T表示转置,双竖线表示求向量长度,即norm2;S(a,n)的定义与S(a,p)相同,只是将p替换成n;a为基准目标样本的特征向量,p为正目标样本的特征向量,n为为负目标样本的特征向量;margin为预设的间隔参数;训练过程中,对每一个待训练目标检测模型生成的匹配到伪标签的目标j均为基准目标样本;对于每一个基准目标样本j,选择与目标j的目标特征向量余弦距离最远且含有相同伪标签描述文本的目标作为正目标样本;选择与目标j的目标特征向量余弦距离最近且含有不同伪标签描述文本的目标作为负目标样本。

目标框的位置L1损失具体为:

式中,ti表示待训练目标检测模型得到的目标j的目标框坐标,为一个四点向量{tx、ty、tw、th},分别表示坐标在x轴和y轴上的位置及目标对应的宽w和高h。ti*表示开源预训练模型得到的目标j’的目标框坐标,为一个四点向量{tx*、ty*、tw*、th*}。

作为优选,目标框的位置GIOU损失具体为:

式中,A表示待训练目标检测模型检测到的目标j的目标框,B为开源预训练模型检测到的与目标j对应的目标j’的目标框;C表示A和B两个目标框的最小外接矩形面积,IOU为A和B的重叠度。

图文对比损失具体为:

式中,N为训练所用的图文对样本总数,训练所用的样本包括整体图像和整体图像的文本标签(来源于数据),v_m是第m个图文对样本的图片经过待训练目标检测模型的表征提取后得到的图像表征,l_m是第m个图文对样本的文本标签经过文本编码器后得到的文本表征,文本编码器采用Bert或Roberta,这里的文本编码器不需要是S1模型的文本编码器。文本编码器参与训练但不会被更新,p(v_m,l_m)的计算公式如下:

式中,S为计算括号中两个对象的相似度,τ为温度超参数,Nri为同一批训练所用样本中除第m个图文对样本之外的其它图文对样本的文本标签经过文本编码器后得到的文本表征集合,即文本标签k和图片i不匹配,但是和同一个batch(训练样本集)中的其他图像k匹配。

开源预训练模型包括一个图像编码器和一个文本编码器,图像编码器和文本编码器在预训练时在语义空间对齐;图像编码器生成图像i中的目标j的坐标信息并对每个目标进行目标特征提取;文本编码器对图像对应的文本的每个单词/词组进行特征提取;利用开源预训练模型的多模态对齐能力,将图像的目标特征和文本的每个单词/词组特征进行点乘,该点乘结果作为图片目标与文本单词/词组的对齐分数,对于每个目标j,选择对齐分数最高的文本单词/词组作为该目标的描述文本。

本文中所描述的具体实施例仅仅是对本发明精神作举例说明。本发明所属技术领域的技术人员可以对所描述的具体实施例做各种各样的修改或补充或采用类似的方式替代,但并不会偏离本发明的精神或者超越所附权利要求书所定义的范围。

尽管本文较多地使用了伪标签、图像编码骨干网络、检测头等术语,但并不排除使用其它术语的可能性。使用这些术语仅仅是为了更方便地描述和解释本发明的本质;把它们解释成任何一种附加的限制都是与本发明精神相违背的。

相关技术
  • 用于度量语音数据库覆盖性的无监督模型训练方法及装置
  • 用于度量语音数据库覆盖性的无监督模型训练方法及装置
技术分类

06120116503317