掌桥专利:专业的专利平台
掌桥专利
首页

用于训练对象检测模型的方法及对象检测方法

文献发布时间:2023-06-19 18:35:48


用于训练对象检测模型的方法及对象检测方法

技术领域

本公开内容总体上涉及图像处理,更具体的,涉及用于训练对象检测模型的方法和对象检测方法。

背景技术

近年来,随着神经网络技术的发展,基于神经网络的图像处理模型已在多种领域被应用。例如,人脸识别、对象分类、对象检测(object detection)、自动驾驶、行为识别等领域。

通常,基于神经网络的对象检测模型在进行对象检测前,要使用大量已标注过的样本图像进行训练,以优化对象检测模型使得模型具有满意的检测性能。在完成训练后,可以向对象检测模型输入待检测图像,经过对象检测模型对待检测图像的各种处理(例如,特征提取)后,对象检测模型可以输出该待检测图像中包括的各对象实例的位置和类型。

发明内容

在下文中将给出关于本公开内容的简要概述,以便提供关于本公开内容的某些方面的基本理解。应当理解,此概述并不是关于本公开内容的穷举性概述。它并不是意图确定本公开内容的关键或重要部分,也不是意图限定本公开内容的范围。其目的仅仅是以简化的形式给出某些概念,以此作为稍后论述的更详细描述的前序。

根据本公开内容的一个方面,提供了一种用于训练对象检测模型的计算机实现的方法,方法包括以迭代方式训练对象检测模型,并且对象检测模型基于神经网络。在训练期间,当前训练迭代轮包括以下操作:分别从具有较大量标签的源域数据集和具有较小量标签的目标域数据集读取用于当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集和与至少一个松散标注的目标域图像对应的具有较小量标签的目标域数据子集;通过对象检测模型对至少一个全面标注的源域图像进行处理来确定针对源域数据子集的检测损失,以及针对至少一个全面标注的源域图像的源域实例分类特征集;通过对象检测模型对至少一个松散标注的目标域图像进行处理确定针对至少一个松散标注的目标域图像的目标域实例分类特征集;基于源域实例分类特征集和目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失;以及基于与检测损失和实例对齐损失有关的总损失通过调整对象检测模型的参数来优化对象检测模型。

根据本公开内容的另一方面,提供了一种对象检测方法。该方法包括:使用上述模型训练方法训练对象检测模型;以及使用训练后的对象检测模型确定待检测图像中的对象的位置及类别。

据本公开内容的一个方面,提供了一种用于训练对象检测模型的装置。该装置包括:存储器,其上存储有指令;以及一个或更多个处理器,一个或更多个处理器能够与存储器通信以执行从存储器获取的指令,并且指令使一个或更多个处理器:分别从具有较大量标签的源域数据集和具有较小量标签的目标域数据集读取用于当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集和与至少一个松散标注的目标域图像对应的具有较小量标签的目标域数据子集;通过对象检测模型对至少一个全面标注的源域图像进行处理来确定针对源域数据子集的检测损失,以及针对至少一个全面标注的源域图像的源域实例分类特征集;通过对象检测模型对至少一个松散标注的目标域图像进行处理确定针对至少一个松散标注的目标域图像的目标域实例分类特征集;基于源域实例分类特征集和目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失;以及基于与检测损失和实例对齐损失有关的总损失通过调整对象检测模型的参数来优化对象检测模型。

据本公开内容的一个方面,提供了一种其上存储有程序的计算机可读存储介质。该程序使运行程序的计算机:分别从具有较大量标签的源域数据集和具有较小量标签的目标域数据集读取用于当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集和与至少一个松散标注的目标域图像对应的具有较小量标签的目标域数据子集;通过对象检测模型对至少一个全面标注的源域图像进行处理来确定针对源域数据子集的检测损失,以及针对至少一个全面标注的源域图像的源域实例分类特征集;通过对象检测模型对至少一个松散标注的目标域图像进行处理确定针对至少一个松散标注的目标域图像的目标域实例分类特征集;基于源域实例分类特征集和目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失;以及基于与检测损失和实例对齐损失有关的总损失通过调整对象检测模型的参数来优化对象检测模型。

本公开内容的方案的有益效果至少包括以下中的至少一个:对标签噪声鲁棒、克服类别不均衡、改善实例级对齐以及改善检测准确度。

根据本文提供的描述,其他应用领域将变得明显。上述描述仅旨在达到说明的目的,而并非意在限制本公开内容的范围。

附图说明

参照附图下面说明本公开内容的实施例,这将有助于更加容易地理解本公开内容的以上和其他目的、特点和优点。附图只是为了示出本公开内容的原理。在附图中不必依照比例绘制出单元的尺寸和相对位置。相同的附图标记可以表示相同的特征。在附图中:

图1示出了根据本公开内容的一个实施例的用于训练对象检测模型的方法中的一个训练迭代轮所包含的操作的流程图;

图2示出了根据本公开内容的一个实施例的用于训练对象检测模型的方法的示例性流程图;

图3示出了根据本公开内容的一个实施例的用于确定实例级对齐损失的方法的示例性流程图;

图4示出了根据本公开内容的实施例的不同处理阶段实例点在特征空间的示意性分布;

图5示出了根据本公开内容的一个实施例的对象检测方法的示例性流程图;

图6示出了根据本公开内容的一个实施例的用于训练对象检测模型的装置的结构的框图;

图7示出了根据本公开内容的一个实施例的用于训练对象检测模型的装置的结构的框图;以及

图8示出了根据本公开内容的一个实施例的信息处理设备的示例性框图。

具体实施方式

在下文中将结合附图对本公开内容的示例性实施例进行描述。为了清楚和简明起见,在说明书中并未描述实际实施例的所有特征。然而,应该了解,在开发任何这种实际实施例的过程中可以做出很多特定于实施例的决定,以便实现开发人员的具体目标,并且这些决定可能会随着实施例的不同而有所改变。

在此,还需要说明的一点是,为了避免因不必要的细节而模糊了本公开内容,在附图中仅仅示出了与根据本公开内容的方案密切相关的装置结构,而省略了与本公开内容关系不大的其他细节。

应理解的是,本公开内容并不会由于如下参照附图的描述而只限于所描述的实施形式。在本文中,在可行的情况下,实施例可以相互组合、不同实施例之间的特征替换或借用、在一个实施例中省略一个或多个特征。

用于执行本公开内容的实施例的各方面的操作的计算机程序代码可以以一种或多种程序设计语言的任何组合来编写,所述程序设计语言包括面向对象的程序设计语言,诸如Java、Smalltalk、C++之类,还包括常规的过程式程序设计语言,诸如"C"程序设计语言或类似的程序设计语言。

本公开内容的方法可以通过具有相应功能配置的电路来实现。所述电路包括用于处理器的电路。

本公开内容的一个方面提供了一种用于训练对象检测模型M的计算机实现的方法。对象检测模型M基于神经网络。采用迭代方式训练对象检测模型M。在每个训练迭代轮,会输入一批经过了标注的训练样本图像及标注数据。下面将参照图1对一个示例性训练迭代轮所包括的操作进行示例性描述。

图1示出了根据本公开内容的一个实施例的用于训练对象检测模型的方法(简称“模型训练方法”)中的一个训练迭代轮Iter[j]所包括的操作的示例性流程图,其中,j代表训练迭代轮的编号。为了讨论方便,第j训练迭代轮也可以被称为“当前训练迭代轮”。

在步骤S101,分别从具有较大量标签的源域数据集

需要说明的是,对于训练用样本图像,如果图像中的感兴趣类型的实例(即,前景)未被标注,并且对象检测模型所使用的对象类别集包括背景类的话,未被标注的实例可能会被标注为背景类。这就会带来标签噪声。松散标注的目标域图像可能带来标签噪声。另外,对于完全标注的源域图像和松散标注的目标域图像,过大的交并比(IoU,IntersectionoverUnion)可以导致背景类实例的边界框内含部分前景实例,这也可以导致标签噪声。标签噪声可能导致样本点(实例分类特征)无法对齐,对对象检测模型的性能有负面影响。

在步骤S103,通过对象检测模型M对至少一个全面标注的源域图像

在步骤S105,通过对象检测模型M对至少一个松散标注的目标域图像

在步骤S107基于源域实例分类特征集O

在步骤S109,基于与检测损失L

本公开内容的模型训练方法可以包括是否结束训练的判断。下面参照图2对本公开内容的用于训练对象检测模型的计算机实现的方法进行描述,其中示出了判断训练结束条件的步骤。

图2示出了根据本公开内容的一个实施例的用于训练对象检测模型M的方法200的示例性流程图。方法200为计算机实现的用于训练对象检测模型的方法,其包括以迭代方式训练对象检测模型M。方法200包括图1描述过的训练迭代轮Iter[j]包括的步骤S101、S103、S105和S107。

在步骤S209-1,确定是否满足预定训练结束条件。在确定结果为“是”的情况下,结束训练;确定结果为“否”的情况下,执行步骤S209-2。预定训练结束条件可以为以下条件中的一个:总损失小于预定阈值;总损失已收敛。总损失已收敛例如是指当前训练迭代轮的总损失相对于前一训练迭代轮的总损失的变化小于预定阈值。

在步骤S209-2,基于总损失通过调整对象检测模型M的参数来优化对象检测模型M。然后返回到步骤S101,进入下一训练迭代轮。

图1中的步骤S109可以细分为图2中的步骤S209-1和步骤S209-2。

作为步骤S109的另一可选实现方式,可以包括以下子步骤:基于总损失通过调整对象检测模型M的参数来优化对象检测模型M;以及,确定训练迭代轮数已达到预定计数。如果确定结果为“是”,则结束训练;如果确定结果为“否”,则返回到步骤S101,进入下一训练迭代轮。

本公开的内容的模型训练方法利用大量的源域有标签的数据和少量的目标域的有标签数据进行训练。少量的目标域的松散标注的图像的使用可以降低训练数据标注成本,缩短训练时间。

在一个实施例中,对象检测模型M被配置成基于同一对象类别集Sc对至少一个全面标注的源域图像

在一个实施例中,对象检测模型M包括特征提取器F和基于Faster R-CNN(FasterRegions with CNN features,更快的具有CNN(卷积神经网络)特征的区域)框架的R网络。R网络被配置成确定输入图像的各感兴趣区域特征。R网络还被配置成确定输入图像的各感兴趣区域ROI的带有分类标签的边界框。R网络例如可以包括区域提出网络RPN(RegionProposal Network)。特征提取器F基于输入图像进行卷积处理,输出图像的特征图(特征)。区域提出网络RPN可以基于特征提取器F的输出结果(特征图)输出与感兴趣区域对应的感兴趣区域特征。各感兴趣区域特征表征模型检测出的对象实例的位置。参考标注信息中的对象实例的真实位置信息,使用各感兴趣区域特征可以确定定位损失。关于FasterR-CNN,可以参考以下文献:

Ren S,He K,Girshick R,et al.Faster r-cnn:Towards real-time objectdetection with region proposal networks[J].Advances in neural informationprocessing systems,2015,28:91-99。

进一步的,R网络还可以包括额外分类特征提取层FC。额外分类特征提取层FC在RPN网络之后并与RPN网络连接,以从RPN网络确定的各感兴趣区域特征中提取用于分类的实例分类特征。各实例分类特征可以表征模型检测出的图像中的感兴趣对象实例的分类。参考标注信息中的对象实例的标注分类信息,使用各实例分类特征可以确定分类损失。考虑到在图像的同一位置可能出现不同类型的对象实例,因此,不直接使用感兴趣区域特征确定对象实例的类型,而是设置额外分类特征提取层FC以提取用于分类的实例分类特征,这是优选的,这有利于改善对象检测模型的性能。

在一个实施例中,对象检测模型M的R网络包含SWDA(Strong-weakdistributionalignment,强弱分布对齐)技术。关于SWDA可以参考以下文献:

Saito K,Ushiku Y,Harada T,et al.Strong-weak distribution alignmentfor adaptive object detection[C]//Proceedings of the IEEE/CVF Conference onComputer Vision and Pattern Recognition.2019:6956-6965。

在本实施例中,R网络整合了弱全局对齐和强局部对齐。SWDA是一个基于FasterR-CNN的对象检测UDA(unsupervised data augmentation,无监督数据增强)框架。为此,R网络还包含局部判别器D

其中,L表示对象检测的损失,它由分类损失和边界框的回归损失(即,定位损失)组成。

尽管实例级别的对齐能够有效提高对象检测模型的性能,但是只靠实例级别对齐可能无法保证目标检测领域自适应的模型性能。因此,在本实施例中,模型训练方法整合了SWDA的弱全局对齐和强局部对齐。为此,首先,利用弱全局对齐来学习得到图像级特征。针对当前训练迭代轮,全局判别器D

其中,γ控制比较难分类的样本的权重。

其次,利用强局部对齐来学习得到局部级特征,例如纹理或者颜色。针对当前训练迭代轮,局部判别器D

其中,W和H表示特征提取器F

L

即,对抗损失L

下面对本公开内容的模型训练方法所涉及的实例级对齐作进一步的描述。

这本公开内容的一些实施例中,模型训练方法包括在额外分类特征层提取的特征之上进行实例级别的对齐。不同于传统的只在前景ROI(感兴趣区域)特征之上进行对齐的方法,在一些实施例中,模型训练方法不仅对齐前景ROI的特征,同时也对齐对应背景类实例的背景的参考框的特征。这是因为利用样本点实例对齐,为了计算实例级别对齐损失,需要分别计算得到每个实例的类内距离和类间距离。然而,在一些场景中只有一类前景,例如,检测汽车而忽略其它物体。在这样的场景中,如果只考虑前景,那么将无法计算类间距离,也就无法计算实例对齐损失。当然,如果存在多个前景类别,本公开内容中的基于样本点对齐的实例对齐也可以用于只对齐前景类。

在一个实施例中,可以将例如用特征向量表示的移动平均类中心也当作一个实例特征,将其添加到实例分类特征集,从而参与确定实例级对齐损失。下面参照图3,对确定实例级对齐损失(例如,图1中的步骤S107)进行描述。图3示出了根据本公开内容的一个实施例的用于确定实例级对齐损失的方法300的示例性流程图。

方法300的处理对象是源域实例分类特征集O

在步骤S301,基于源域实例分类特征集O

其中,x

在步骤S303,基于目标域实例分类特征集O

其中,x

在步骤S305,针对源域,基于当前训练迭代轮的各类的平均类中心和前一训练迭代轮的各类的移动平均类中心确定针对当前训练迭代轮的源域的各类的移动平均类中心。针对源域,第k类别的第j训练迭代轮(当前训练迭代轮)的移动平均类中心

其中,

在步骤S307针对目标域,基于当前训练迭代轮的各类的平均类中心和前一训练迭代轮的各类的移动平均类中心确定针对当前训练迭代轮的目标域的各类的移动平均类中心。针对目标域,第k类别的第j训练迭代轮(当前训练迭代轮)的移动平均类中心

其中,

在步骤S309,通过将针对当前训练迭代轮的源域的各类的移动平均类中心添加到源域分类特征集来更新源域实例分类特征集。图4(b_s)示意性示出了添加了源域的各类的移动平均类中心的源域实例点分布,其中,各实心几何图形对应在源域中的表示各类的移动平均类中心的示例特征点。

在步骤S311;通过将针对当前训练迭代轮的目标域的各类的移动平均类中心添加到目标域实例分类特征集来更新目标域实例分类特征集。图4(b_t)示意性示出了添加了目标域的各类的移动平均类中心的目标域实例点分布,其中,各实心几何图形对应在目标域中的表示各类的移动平均类中心的示例特征点。添加移动平均类中心有利于对所有类别的所有实例计算跨域的类内和类间距离。

在步骤S313,确定更新的源域实例分类特征集和更新的目标域实例分类特征集之间的实例级对齐损失。需要说明的是,如果在某训练迭代轮(例如,在第一训练迭代轮),对对象类别集Sc的某个类别,其源域或目标域的移动平均类中心为零,则不进行对该实例类型的实例点的对齐,不计算针对该实例类型的实例级对齐损失,即,实例级对齐损失中不计入有关该实例类型的对齐损失。

在一个实施例中,更新实例分类特征集还可以包括删除背景类实例。由于背景的参考框所代表的感兴趣区域ROI具有非常多的标签噪声,所以在本实施例中,删除源域实例分类特征集中对应背景类实例的分类特征,同时保留源域实例分类特征集中背景类的移动平均类中心;删除目标域实例分类特征集中对应背景类实例的分类特征,同时保留目标域实例分类特征集中背景类的移动平均类中心。删除背景操作可按照公式(13)和公式(14)进行。

其中,

删除背景类实例特征点(分类特征),同时保留背景类移动平均中心,有利于抑制标签噪声,改善对象检测模型的性能。在本公开内容中,“删除背景类实例特征点(分类特征),同时保留背景类移动平均中心”的操作也简称为“删除背景类实例”。

图4(c_s)示意性示出了删除背景类实例后源域的各类实例点分布,图4(c_t)示意性示出了删除背景类实例后目标域的各类实例点分布。在图4(c_s)和图4(c_t)中,可以看到,以空心三角形表示的真实背景类实例点已被移除,同时保留了以实心三角形表示的背景类移动平均类中心实例点。

在一个实施例中,更新实例分类特征集还可以包括欠采样。众所周知,训练样本集中类别不均衡问题会导致机器学习的性能的下降。同样,实例分布的不均衡也会对实例级别对齐带来负面影响。例如,参见文献1,在Cityscapes数据集上,实例的分布非常不均衡,其中“轿车”(car)和“人员”(person)这两类的实例占了绝大多数:

文献1:Cordts M,Omran M,Ramos S,et al.The cityscapes dataset forsemantic urban scene understanding[C]//Proceedings of the IEEE conference oncomputer vision and pattern recognition.2016:3213-3223。

因此,为减轻这一问题对性能的影响,在本实施例中,更新实例分类特征集还包括对源域实例分类特征集O

其中,undersampling()是一个预定义的用于通过随机丢弃实例来限制相应类的实例最大个数不超过给定阈值

欠采样有利于实例分布均衡,有利于改善对象检测模型的性能。

在一个实施例中,针对源域实例分类特征集和目标域实例分类特征集,更新实例分类特征集包括添加各类的移动平均中心、删除背景类实例以及欠采样。

得到更新的源域实例分类特征集、目标域实例分类特征集后,就可以基于这两个特征集中的特征点的对齐确定实例级对齐损失L

文献2:Xu X,Zhou X,Venkatesan R,et al.d-SNE:Domain adaptation usingstochastic neighborhood embedding(d-SNE)(CVPR 2019)。

d-SNE是目前性能较好的基于样本点的对齐方法。d-SNE损失如公式(17)所示。

其中,d(x

其中,m是一个预定义的边距(margin)值,max()表示取最大值。m可以根据经验来定,一个实例性取值是1。在一个示例中,本公开内容的实例级对齐损失可以根据公式(18)来确定。然而公式(18)所示的d-SNE损失的实现只加大了最大类内距离和最小类间距离之间的相对差异,而没有最大化最小绝对类间距离。为解决这一问题,在一个示例中,采用改进的实例级别对齐的损失,即,扩展d-SNE损失,其用公式(19)来确定。

其中,m

调整模型参数会用到总目标函数。下面对总目标函数进行进一步的描述。

在一个实施例中,总损失可以为检测损失L

L

λ

可以利用mini-max损失函数定义总目标函数(参见公式(21))。利用总目标函数实现通过调整对象检测模型的参数来优化对象检测模型。

其中,

文献3:Ganin Y,Ustinova E,Ajakan H,et al.Domain-adversarial trainingof neural networks[J].The journal of machine learning research,2016,17(1):2096-2030。

图4(e)示意性示出了调整对象检测模型的参数的对特征点对齐的影响的效果。在图4(e)中,为了清楚地示出调整参数对对齐的影响效果,已将由调整参数后的对象检测模型确定的源域实例点和目标域实例点合并在同一空间布置。如图4(e)所示,调整对象检测模型的参数后,同类特征点会倾向于更加聚集、对齐程度变高,类内距离减小,不同类特征点会倾向于间隔变大,类间距离增大。

本公开内容的一个方面提供一种对象检测方法。下面参照图5对该方法进行示例性描述。

图5示出了根据本公开内容的一个实施例的对象检测方法500的示例性流程图。

在步骤S501,训练对象检测模型M。具体的,使用本公开内容的模型训练方法(例如,图2中示出的方法200)训练对象检测模型M。

在步骤S503,对待检测图像进行检测。具体地,使用训练后的对象检测模型确定待检测图像中的对象的位置及类别。

本公开内容的一个方面提供用于训练对象检测模型的装置。下面参照图6对该装置进行描述。图6示出了根据本公开内容的一个实施例的用于训练对象检测模型的装置600的结构的框图。装置600用于以迭代方式训练对象检测模型。对象检测模型基于神经网络。该装置600包括:检测损失确定单元601、分类特征集确定单元603、对齐损失确定单元605、总损失确定单元607和优化单元609。检测损失确定单元601被配置成:基于与用于当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集确定针对源域数据子集的检测损失。分类特征集确定单元603被配置成:确定针对至少一个全面标注的源域图像的源域实例分类特征集,以及确定针对至少一个松散标注的目标域图像的目标域实例分类特征集。对齐损失确定单元605被配置成:基于源域实例分类特征集和目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失。总损失确定失单元607被配置成:基于检测损失和实例对齐损失确定总损失。优化单元609被配置成:基于总损失,通过调整对象检测模型的参数来优化对象检测模型。源域数据子集和目标域数据子集分别来自具有较大量标签的源域数据集和具有较小量标签的目标域数据集。装置600与方法200存在对应关系,装置600的进一步的细节可参考本文对方法200的描述。例如,分类特征集确定单元603还被配置成执行以下操作中的至少一个:确定源域和目标域的各类的移动平均类中心,将各移动平均类中心添加到相应的实例分类特征集,删除实例分类特征集中的背景类实例,以及对实例分类特征集进行欠采样。可选的,装置600还可以包括对抗损失确定单元。对抗损失确定单元用于确定针对源域数据集和目标域数据集的对抗损失。对抗损失确定单元与总损失确定单元607耦接,以使总损失中还包括对抗损失。

据本公开内容的一个方面,提供了一种用于训练对象检测模型的装置。下面参照图7对该装置进行描述。图7示出了根据本公开内容的一个实施例的用于训练对象检测模型的装置700。该装置包括:存储器701,其上存储有指令;以及一个或更多个处理器703,一个或更多个处理器能够与存储器通信以执行从存储器获取的指令,并且指令使一个或更多个处理器:分别从具有较大量标签的源域数据集和具有较小量标签的目标域数据集读取用于当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集和与至少一个松散标注的目标域图像对应的具有较小量标签的目标域数据子集;通过对象检测模型对至少一个全面标注的源域图像进行处理来确定针对源域数据子集的检测损失,以及针对至少一个全面标注的源域图像的源域实例分类特征集;通过对象检测模型对至少一个松散标注的目标域图像进行处理确定针对至少一个松散标注的目标域图像的目标域实例分类特征集;基于源域实例分类特征集和目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失;以及基于与检测损失和实例对齐损失有关的总损失通过调整对象检测模型的参数来优化对象检测模型。装置700与方法200存在对应关系,装置700的进一步的细节可参考本文对方法200的描述。

本公开内容的一个方面提供一种其上存储有程序的计算机可读存储介质。该程序使运行程序的计算机:分别从具有较大量标签的源域数据集和具有较小量标签的目标域数据集读取用于当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集和与至少一个松散标注的目标域图像对应的具有较小量标签的目标域数据子集;通过对象检测模型对至少一个全面标注的源域图像进行处理来确定针对源域数据子集的检测损失,以及针对至少一个全面标注的源域图像的源域实例分类特征集;通过对象检测模型对至少一个松散标注的目标域图像进行处理确定针对至少一个松散标注的目标域图像的目标域实例分类特征集;基于源域实例分类特征集和目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失;以及基于与检测损失和实例对齐损失有关的总损失通过调整对象检测模型的参数来优化对象检测模型。该程序与方法200存在对应关系,该程序的进一步的细节可参考本文对方法200的描述。

本公开内容的一个方面提供一种其上存储有程序的计算机可读存储介质。该程序使运行程序的计算机实现方法500。

根据本公开内容一个方面,还提供一种信息处理设备。

图8是根据本公开内容的一个实施例的信息处理设备800的示例性框图。在图8中,中央处理单元(CPU)801根据存储在只读存储器(ROM)802中的程序或从存储部分808加载到随机存取存储器(RAM)803的程序来进行各种处理。在RAM 803中,也根据需要来存储在CPU801执行各种处理时所需的数据等。

CPU 801、ROM 802以及RAM 803经由总线804彼此连接。输入/输出接口805也连接至总线804。

下述部件连接至输入/输出接口805:包括软键盘等的输入部分806;包括诸如液晶显示器(LCD)等的显示器以及扬声器等的输出部分807;诸如硬盘的存储部分808;以及包括网络接口卡如LAN卡、调制解调器等的通信部分809。通信部分809经由诸如英特网、局域网、移动网络的网络或其组合执行通信处理。

驱动器810根据需要也连接至输入/输出接口805。可拆卸介质811如半导体存储器等根据需要安装在驱动器810上,使得从其中读取的程序根据需要被安装到存储部分808。

CPU 801可以运行用于实现根据本公开内容的用于训练对象识别模型的方法的程序或用于实现本公开内容的对象检测方法的程序。

下面对本公开内容的方案的效果进行描述。

构建了以下三个场景进行实验,比较本公开内容的方案与现有的方法在准确率性能方面的差异:(1)从Cityscapes到Foggy Cityscapes的迁移(C->F);(2)从SIM10K到Cityscapes的迁移(S->C;即,用SIM10K的有标注样本和Cityscapes的少量有标注样本一起训练);(3)从Udacity到Cityscapes的迁移(U->C)。实验结果如表1和表2所示。第一个场景C->F是为了模拟天气变化造成的领域偏移(domain shift)导致的数据偏差。第二个场景S->C是为了模拟虚拟世界和真实世界之间的数据偏差。第三个场景U->C是为了仿真两个不同真实世界之间由于光照、摄像机角度等原因造成的数据偏差。

表1 C->F的实验结果

表2 S->C和U->C的实验结果

引用数据来源如下。

[1]Ren S,He K,Girshick R,et al.Faster r-cnn:Towards real-time objectdetection with region proposal networks[J].Advances in neural informationprocessing systems,2015,28:91-99.

[2]Saito K,Ushiku Y,Harada T,et al.Strong-weak distribution alignmentfor adaptive object detection[C]//Proceedings of the IEEE/CVF Conference onComputer Vision and Pattern Recognition.2019:6956-6965.

[3]Zhuang C,Han X,Huang W,et al.ifan:Image-instance full alignmentnetworks for adaptive object detection[C]//Proceedings of the AAAI Conferenceon Artificial Intelligence.2020,34(07):13122-13129.

[4]Wu,A.,Han,Y.,Zhu,L.,&Yang,Y.(2021).Instance-Invariant DomainAdaptive Object Detection via Progressive Disentanglement.IEEE Transactionson Pattern Analysis and Machine Intelligence,1–1.

[5]Wang T,Zhang X,Yuan L,et al.Few-shot adaptive faster r-cnn[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and PatternRecognition.2019:7173-7182.

其中,Source-only表示仅使用全面标注的源域数据进行训练;Target-only表示仅使用松散标注的目标域数据进行训练;UDA表示无监督领域自适应方法,它使用了所有未标注的目标域数据来进行领域自适应;FUDA表示少样本无监督领域自适应方法;FDA表示少样本领域自适应方法;PICA+SWDA表示本公开内容所采用的方法,PICA表示“点型实例及中心对齐”(point-wise instance and centroid alignment);mAP(0.5)表示平均准率(Meanaverage precision),0.5是阈值;表格内的含小数点数据表示检测准确率mAP。

在S->C和U->C场景中,使用了8张目标域图像,每张图像只标注3辆汽车;在C->F场景中,使用了8张目标域图像,每张图像对应一个类,并且每张图像只标注出对应类的一个实例。FUDA方法使用了和FDA一样的8张图像,但是不使用对应的标注。

表1和表2的实验结果表明本公开内容的方法(PICA+SWDA)在C->F、S->C和U->C上都优于现有方法FAFRCNN以及SWDA。

本公开内容的方案涉及额外分类特征提取层、对抗损失、少量松散标注的目标域图像的使用、移动平均类中心的对齐、删除背景类实例、欠采样、改进的实例级对齐损失。本公开内容的有益效果至少包括以下中的至少一个:对标签噪声鲁棒、克服类别不均衡、改善实例级对齐以及改善检测准确度。

尽管上面已经通过对本发明的具体实施例的描述对本发明进行了披露,但是,应该理解,本领域的技术人员可在所附权利要求的精神和范围内设计对本发明的各种修改(包括在行的情况下,各实施例之间特征的组合或替换)、改进或者等同物。这些修改、改进或者等同物也应当被认为包括在本发明的保护范围内。

应该强调,术语“包括/包含”在本文使用时指特征、要素、步骤或组件的存在,但并不排除一个或更多个其它特征、要素、步骤或组件的存在或附加。

此外,本发明的各实施例的方法不限于按照说明书中描述的或者附图中示出的时间顺序来执行,在技术上可行的情况下,也可以按照其他的时间顺序、并行地或独立地执行。因此,本说明书中描述的方法的执行顺序不对本发明的技术范围构成限制。

本公开内包括但不限于以下技术方案。

1.一种用于训练对象检测模型的计算机实现的方法,所述方法包括以迭代方式训练所述对象检测模型,并且所述对象检测模型基于神经网络,其特征在于,当前训练迭代轮包括以下操作:

分别从具有较大量标签的源域数据集和具有较小量标签的目标域数据集读取用于所述当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集和与至少一个松散标注的目标域图像对应的具有较小量标签的目标域数据子集;

通过所述对象检测模型对所述至少一个全面标注的源域图像进行处理来确定针对所述源域数据子集的检测损失,以及针对所述至少一个全面标注的源域图像的源域实例分类特征集;

通过所述对象检测模型对所述至少一个松散标注的目标域图像进行处理确定针对所述至少一个松散标注的目标域图像的目标域实例分类特征集;

基于所述源域实例分类特征集和所述目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失;以及

基于与所述检测损失和所述实例对齐损失有关的总损失通过调整所述对象检测模型的参数来优化所述对象检测模型。

2.根据附记1所述的方法,其中,基于同一对象类别集使用所述至少一个全面标注的源域图像和所述至少一个松散标注的目标域图像对所述对象检测模型进行训练,并且所述同一对象类别集包括背景类。

3.根据附记1所述的方法,其中,所述对象检测模型包括R网络;

所述R网络基于Faster RCNN框架;

所述R网络被配置成确定输入图像的各感兴趣区域特征;并且

所述R网络还被配置成确定所述输入图像的各感兴趣区域的带有分类标签的边界框。

4.根据附记3所述的方法,其中,所述R网络包括额外分类特征提取层;并且

所述额外分类特征提取层被配置成从各感兴趣区域特征中提取用于分类的实例分类特征。

5.根据附记1所述的方法,其中,所述总损失还与针对所述源域数据子集和所述目标域数据子集的对抗损失有关。

6.根据附记5所述的方法,其中,所述R网络包括全局判别器和局部判别器,所述对抗损失包括由所述全局判别器基于图像级特征确定的弱全局对齐损失和由所述局部判别器基于局部级特征确定的强局部对齐损失。

7.根据附记2所述的方法,其中,基于所述源域实例分类特征集和所述目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失包括:

基于所述源域实例分类特征集确定针对所述当前训练迭代轮源域的各类的平均类中心;

基于所述目标域实例分类特征集确定所述当前训练迭代轮目标域的各类的平均类中心;

针对所述源域,基于所述当前训练迭代轮的各类的平均类中心和前一训练迭代轮的各类的移动平均类中心确定针对所述当前训练迭代轮的所述源域的各类的移动平均类中心;

针对所述目标域,基于所述当前训练迭代轮的各类的平均类中心和前一训练迭代轮的各类的移动平均类中心确定针对所述当前训练迭代轮的所述目标域的各类的移动平均类中心;

通过将针对所述当前训练迭代轮的所述源域的各类的移动平均类中心添加到所述源域分类特征集来更新所述源域实例分类特征集;

通过将针对所述当前训练迭代轮的所述目标域的各类的移动平均类中心添加到所述目标域实例分类特征集来更新所述目标域实例分类特征集;以及

确定所述更新的源域实例分类特征集和所述更新的目标域实例分类特征集之间的所述实例级对齐损失。

8.根据附记7所述的方法,其中,更新所述源域实例分类特征集还包括:对所述源域实例分类特征集进行欠采样;并且

更新所述目标域实例分类特征集还包括:对所述目标域实例分类特征集进行欠采样。

9.根据附记7所述的方法,其中,更新所述源域实例分类特征集还包括:删除所述源域实例分类特征集中对应背景类实例的分类特征,同时保留所述源域实例分类特征集中所述背景类的移动平均类中心;并且

更新所述目标域实例分类特征集还包括:删除所述目标域实例分类特征集中对应背景类实例的分类特征,同时保留所述目标域实例分类特征集中所述背景类的移动平均类中心。

10.根据附记1所述的方法,其中,所述实例级对齐损失为考虑了最大化最小绝对类间距离的扩展d-SNE损失。

11.一种对象检测方法,其特征在于,包括:

使用附记1至10中的任一项所述的方法训练所述对象检测模型;以及

使用训练后的对象检测模型确定待检测图像中的对象的位置及类别。

12.一种其上存储有程序的计算机可读存储介质,其特征在于,所述程序使运行程序的计算机:

分别从具有较大量标签的源域数据集和具有较小量标签的目标域数据集读取用于当前训练迭代轮的与至少一个全面标注的源域图像对应的具有较大量标签的源域数据子集和与至少一个松散标注的目标域图像对应的具有较小量标签的目标域数据子集;

通过对象检测模型对所述至少一个全面标注的源域图像进行处理来确定针对所述源域数据子集的检测损失,以及针对所述至少一个全面标注的源域图像的源域实例分类特征集;

通过所述对象检测模型对所述至少一个松散标注的目标域图像进行处理确定针对所述至少一个松散标注的目标域图像的目标域实例分类特征集;

基于所述源域实例分类特征集和所述目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失;以及

基于与所述检测损失和所述实例对齐损失有关的总损失通过调整所述对象检测模型的参数来优化所述对象检测模型。

13.根据附记12所述的计算机可读存储介质,其中,基于同一对象类别集使用所述至少一个全面标注的源域图像和所述至少一个松散标注的目标域图像对所述对象检测模型进行训练,并且所述同一对象类别集包括背景类。

14.根据附记12所述的计算机可读存储介质,其中,所述对象检测模型包括R网络;

所述R网络基于Faster RCNN框架;

所述R网络被配置成确定输入图像的各感兴趣区域特征;并且

所述R网络还被配置成确定所述输入图像的各感兴趣区域的带有分类标签的边界框。

15.根据附记14所述的计算机可读存储介质,其中,所述R网络包括额外分类特征提取层;并且

所述额外分类特征提取层被配置成从各感兴趣区域特征中提取用于分类的实例分类特征。

16.根据附记12所述的计算机可读存储介质,其中,所述总损失还与针对所述源域数据子集和所述目标域数据子集的对抗损失有关。

17.根据附记16所述的计算机可读存储介质,其中,所述R网络包括全局判别器和局部判别器,所述对抗损失包括由所述全局判别器基于图像级特征确定的弱全局对齐损失和由所述局部判别器基于局部级特征确定的强局部对齐损失。

18.根据附记13所述的计算机可读存储介质,其中,基于所述源域实例分类特征集和所述目标域实例分类特征集确定与实例特征对齐有关的实例级对齐损失包括:

基于所述源域实例分类特征集确定针对所述当前训练迭代轮源域的各类的平均类中心;

基于所述目标域实例分类特征集确定所述当前训练迭代轮目标域的各类的平均类中心;

针对所述源域,基于所述当前训练迭代轮的各类的平均类中心和前一训练迭代轮的各类的移动平均类中心确定针对所述当前训练迭代轮的所述源域的各类的移动平均类中心;

针对所述目标域,基于所述当前训练迭代轮的各类的平均类中心和前一训练迭代轮的各类的移动平均类中心确定针对所述当前训练迭代轮的所述目标域的各类的移动平均类中心;

通过将针对所述当前训练迭代轮的所述源域的各类的移动平均类中心添加到所述源域分类特征集来更新所述源域实例分类特征集;

通过将针对所述当前训练迭代轮的所述目标域的各类的移动平均类中心添加到所述目标域实例分类特征集来更新所述目标域实例分类特征集;以及

确定所述更新的源域实例分类特征集和所述更新的目标域实例分类特征集之间的所述实例级对齐损失。

19.根据附记18所述的计算机可读存储介质,其中,更新所述源域实例分类特征集还包括:对所述源域实例分类特征集进行欠采样;并且

更新所述目标域实例分类特征集还包括:对所述目标域实例分类特征集进行欠采样。

20.根据附记18所述的计算机可读存储介质,其中,更新所述源域实例分类特征集还包括:删除所述源域实例分类特征集中对应背景类实例的分类特征,同时保留所述源域实例分类特征集中所述背景类的移动平均类中心;并且

更新所述目标域实例分类特征集还包括:删除所述目标域实例分类特征集中对应背景类实例的分类特征,同时保留所述目标域实例分类特征集中所述背景类的移动平均类中心。

技术分类

06120115627920