掌桥专利:专业的专利平台
掌桥专利
首页

模型训练方法、目标检测方法及对应装置

文献发布时间:2023-06-19 11:45:49


模型训练方法、目标检测方法及对应装置

技术领域

本发明涉及目标检测技术领域,具体而言,涉及一种模型训练方法、目标检测方法及对应装置。

背景技术

目标检测是计算机视觉相关业务的基础技术之一,其具体任务是找预出给定图像中所有目标物体的位置并判断其类别。在现有技术中,普遍通过特定结构的神经网络模型实现目标检测,例如,单阶段检测器、两阶段检测器等。然而,由于训练样本的质量欠佳,导致这些方法的检测效果仍有待提高。

发明内容

本申请实施例的目的在于提供一种模型训练方法、目标检测方法及对应装置,以改善上述技术问题。

为实现上述目的,本申请提供如下技术方案:

第一方面,本申请实施例提供一种模型训练方法,用于训练第一检测模型,所述第一检测模型包括特征提取模块、密集预测模块、质量分布编码模块以及质量分布采样模块,所述方法包括:利用所述特征提取模块提取训练图像的特征图;利用所述密集预测模块针对所述特征图中的每个像素点预测对应的检测框及其类别分数;利用所述质量分布编码模块根据所述特征图以及所述训练图像中的真实框预测质量分布的分布参数,所述质量分布表征所述密集预测模块预测的检测框和所述真实框的重叠程度在所述特征图所在的平面上满足的分布;利用所述质量分布采样模块根据所述质量分布进行采样,得到多个采样点,确定所述多个采样点中的正样本点,并根据所述特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,计算得到每个正样本点对应的检测框及其类别分数;计算所述密集预测模块预测的检测框和所述真实框的重叠程度,根据所述质量分布在所述特征图中的每个像素点处的取值与对应的重叠程度计算第一损失,并根据所述第一损失更新所述质量分布编码模块以及所述特征提取模块的参数;根据每个正样本点对应的检测框及其类别分数,以及,所述真实框及其对应的真实类别计算第二损失,并根据所述第二损失更新所述密集预测模块以及所述特征提取模块的参数。

上述方法利用密集预测模块预测的检测框和真实框的重叠程度(重叠程度反映预测质量),作为监督信号来监督质量分布的学习过程,从而使得该质量分布能够有效地描述密集预测模块对检测框的预测质量在空间上的分布状况。根据该质量分布进行采样,将得到更多靠近分布中心(质量分布的峰值位置)的采样点,并且由于质量分布在空间上是连续的,所以采样点未必落在特征图的像素点位置,即具有亚像素精度。根据质量分布表征的含义,在靠近分布中心且具有亚像素精度的采样点处密集预测模块对检测框的预测具有较高的质量,因此它们是高质量的正样本点(即作为正样本的采样点),而利用更多高质量的正样本点进行模型训练,模型的检测效果(包括检测框的预测精度和类别的预测精度)将显著提高。

在第一方面的一种实现方式中,在所述利用所述质量分布采样模块根据所述质量分布进行采样,得到多个采样点之后,所述方法还包括:确定所述多个采样点中的负样本点,并根据所述特征图中位于每个负样本点周围的像素点对应的类别分数,计算得到每个负样本点对应的类别分数;根据每个负样本点对应的类别分数,以及,所述真实框对应的真实类别计算第三损失,并根据所述第三损失更新所述密集预测模块以及所述特征提取模块的参数。

训练目标检测模型需要正样本和负样本,在上述实现方式中,负样本点也可以从采样点中选取。相较于正样本点,负样本点远离质量分布的分布中心,并且也具有亚像素精度,从而在用于训练时有利于提高模型的检测效果。

在第一方面的一种实现方式中,所述确定所述多个采样点中的正样本点,包括:通过比较所述质量分布在所述多个采样点处的取值,将其中取值最大的预设数量的采样点确定为所述正样本点;所述确定所述多个采样点中的负样本点,包括:将所述多个采样点中除所述正样本点之外的采样点确定为所述负样本点。

质量分布可以用概率密度函数等方式来表示,从而质量分布在某个采样点处的取值就是指概率密度函数在该采样点处的函数值。通过一定的算法(例如,topK算法)可以确定对应的函数值靠前的若干个采样点,这些采样点在位置上接近分布中心(即概率密度函数的峰值位置),从而是理想的正样本点,剩余的采样点则可以作为负样本点,这样采样的结果也得到了充分的利用。

在第一方面的一种实现方式中,所述质量分布编码模块包括感兴趣区域池化层以及至少一个全连接层,所述利用所述质量分布编码模块根据所述特征图以及所述训练图像中的真实框预测质量分布的分布参数,包括:利用所述感兴趣区域池化层从所述特征图中扣取出位于真实框内的真实特征;利用所述至少一个全连接层根据所述真实特征预测质量分布的分布参数。

感兴趣区域池化层可以采用RoI Pooling或者RoI Align操作,用于从特征图中扣取出位于真实框内的部分,并将其缩小至固定的尺寸(便于输入后续网络),得到真实特征。至少一个全连接层构成编码器,编码器利用全连接层的信息整合能力将输入特征编码为质量分布的分布参数,编码器的参数可以在训练过程中不断调整。

在第一方面的一种实现方式中,所述质量分布采用高斯混合模型,所述分布参数包括所述高斯混合模型中每个高斯分布的权重、均值以及标准差。

高斯混合模型可以有效地模拟几乎任何分布,并且其构成形式也比较简单,所需的分布参数数量较少,因此质量分布可以采用高斯混合模型,以便有效地模拟检测框预测质量的实际分布。

在第一方面的一种实现方式中,所述根据所述特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,计算得到每个正样本点对应的检测框及其类别分数,包括:根据所述特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,利用插值运算得到每个正样本点对应的检测框及其类别分数。

基于密集预测结果,利用插值运算(例如,最邻近插值、双线性插值、双三次插值等)可以精确地计算出特征图范围内任意位置的预测结果,即提供了一种得到亚像素精度的正样本的方法。

第二方面,本申请实施例提供一种目标检测方法,用于利用第二检测模型进行目标检测,所述第二检测模型包括特征提取模块以及密集预测模块,所述方法包括:利用所述特征提取模块提取待检测图像的特征图;利用所述密集预测模块针对所述特征图中的每个像素点预测对应的检测框及其类别分数;根据所述特征图中的每个像素点对应的检测框及其类别分数,计算所述待检测图像中最终的检测框及其类别分数;其中,所述特征提取模块以及所述密集预测模块利用第一方面或第一方面的任意一种可能的实现方式提供的方法训练得到。

上述方法使用第二检测模型实现目标检测,由于第二检测模型中的特征提取模块和密集预测模块是利用第一方面或其任意一种可能的实现方式训练得到的,根据前文的分析可知,第二检测模型具有较好的检测效果。需要指出,质量分布编码模块和质量分布采样模块仅在模型的训练阶段使用,在模型的测试阶段不使用,即第二检测模型可视为第一检测模型针对测试阶段的简化版本。

第三方面,本申请实施例提供一种模型训练装置,用于训练第一检测模型,所述第一检测模型包括特征提取模块、密集预测模块、质量分布编码模块以及质量分布采样模块,所述装置包括:第一特征提取单元,用于利用所述特征提取模块提取训练图像的特征图;第一密集预测单元,用于利用所述密集预测模块针对所述特征图中的每个像素点预测对应的检测框及其类别分数;质量分布编码单元,用于利用所述质量分布编码模块根据所述特征图以及所述训练图像中的真实框预测质量分布的分布参数,所述质量分布表征所述密集预测模块预测的检测框和所述真实框的重叠程度在所述特征图所在的平面上满足的分布;质量分布采样单元,用于利用所述质量分布采样模块根据所述质量分布进行采样,得到多个采样点,确定所述多个采样点中的正样本点,并根据所述特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,计算得到每个正样本点对应的检测框及其类别分数;第一训练单元,用于计算所述密集预测模块预测的检测框和所述真实框的重叠程度,根据所述质量分布在所述特征图中的每个像素点处的取值与对应的重叠程度计算第一损失,并根据所述第一损失更新所述质量分布编码模块以及所述特征提取模块的参数;第二训练单元,用于根据每个正样本点对应的检测框及其类别分数,以及,所述真实框及其对应的真实类别计算第二损失,并根据所述第二损失更新所述密集预测模块以及所述特征提取模块的参数。

第四方面,本申请实施例提供一种目标检测装置,用于利用第二检测模型进行目标检测,所述第二检测模型包括特征提取模块以及密集预测模块,所述装置包括:第二特征提取单元,用于利用所述特征提取模块提取待检测图像的特征图;第二密集预测单元,用于利用所述密集预测模块针对所述特征图中的每个像素点预测对应的检测框及其类别分数;最终预测单元,用于根据所述特征图中的每个像素点对应的检测框及其类别分数,计算所述待检测图像中最终的检测框及其类别分数;其中,所述特征提取模块以及所述密集预测模块利用第一方面或第一方面的任意一种可能的实现方式提供的方法训练得到。

第五方面,本申请实施例提供一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序指令,所述计算机程序指令被处理器读取并运行时,执行第一方面、第二方面或以上两方面的任意一种可能的实现方式提供的方法。

第六方面,本申请实施例提供一种电子设备,包括:存储器以及处理器,所述存储器中存储有计算机程序指令,所述计算机程序指令被所述处理器读取并运行时,执行第一方面、第二方面或以上两方面的任意一种可能的实现方式提供的方法。

附图说明

为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。

图1示出了本申请实施例提供的模型训练方法的一种可能的流程;

图2示出了本申请实施例提供的第一检测模型的一种可能的结构;

图3示出了本申请实施例提供的目标检测方法的一种可能的流程;

图4示出了本申请实施例提供的模型训练装置的一种可能的结构;

图5示出了本申请实施例提供的目标检测装置的一种可能的结构;

图6示出了本申请实施例提供的电子设备的一种可能的结构。

具体实施方式

下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行描述。应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。

术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。

术语“第一”、“第二”等仅用于将一个实体或者操作与另一个实体或操作区分开来,而不能理解为指示或暗示相对重要性,也不能理解为要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。

图1示出了本申请实施例提供的模型训练方法的一种可能的流程。该方法可以但不限于由图6示出的电子设备执行,该电子设备的具体结构可参考后文关于图6的阐述。该方法所要训练的目标检测模型称为第一检测模型,第一检测模型为卷积神经网络模型,其至少包括特征提取模块、密集预测模块、质量分布编码模块以及质量分布采样模块,每一个模块都可以视为一个子网络,这四个模块在模型中的位置如图2所示,后文还将结合图2对各模块进行具体的介绍,可以理解的,第一检测模型还可能包含更多的模块。

参照图1,模型训练方法包括:

步骤S110:利用特征提取模块提取训练图像的特征图。

训练图像可以指训练集中任一张的图像,训练图像经过了预先标注,标注结果包括图像中每个目标的真实位置以及类别,分别称为真实框和真实类别。

特征提取模块能够对训练图像进行特征提取,并输出训练图像对应的特征图,不妨将特征图的维度记为H×W×C0,其中H、W、C0分别是特征图的高度、宽度以及通道数。特征提取模块可以是一个卷积神经网络,其可以采用多种不同的结构:

例如,特征提取模块可以包括主干网络(Backbone),比如ResNet、VGG等网络,主干网络输出的特征图即为步骤S110中的特征图。

又例如,特征提取模块可以包括主干网络和特征金字塔网络(Feature PyramidNetwork,简称FPN),特征金字塔网络基于主干网络提取到的特征构建特征金字塔,即输出多尺度的特征,用于检测不同尺度下的目标,如图2所示。每种尺度下的特征都可以视为一张S110中的特征图,针对每张特征图都可以分别执行后续的步骤S120至S160,后文仅以其中一张特征图为例进行说明。进一步的,由于步骤S120至S160中使用了一组模块(密集预测模块、质量分布编码模块、质量分布预测模块),因此针对每张特征图都需要单独设置一组这样的模块,比如,对于图2中的P3、P4、P5三张特征图,就要设置对应的三组模块,各组模块结构相同,但具体的网络参数则可能不同。

步骤S120:利用密集预测模块针对特征图中的每个像素点预测对应的检测框及其类别分数。

密集预测模块,顾名思义,其预测结果是密集的,即针对特征图中的每个像素点都会给出一个预测结果,预测结果包括两项:一项是检测框的位置信息(后文简称检测框),一项是类别分数,类别分数描述了检测框中的目标属于各目标类别的概率。

密集预测模块可以是一个卷积神经网络,该网络可以进一步包括分类分支和回归分支,分别用于对类别分数和检测框进行密集预测。例如,在图2中,分类分支和回归分支均包括4个卷积层。分类分支输出的预测结果维度是H×W×C,即针对特征图(维度H×W×C0)中的每个像素点均预测出C个类别分数,其中C为目标类别总数。回归分支输出的预测结果维度H×W×4,即针对特征图中的每个像素点均预测出4个表示检测框位置的数值,根据不同的实现方式,这4个数值可以是坐标值、可以是偏移量值、可以是变换系数等。

应当理解,图2中密集预测模块的结构仅为示例,例如,在其他一些实现方式中,分类分支和回归分支中卷积层的数量可能不是4层,分类分支和回归分支也可能包含公共的部分,等等。

步骤S130:利用质量分布编码模块根据特征图以及训练图像中的真实框预测质量分布的分布参数。

在一些实现方式中,质量分布编码模块可以包括两部分结构,第一部分结构用于从特征图中扣取出位于真实框内的真实特征,第二部分结构用于根据真实特征预测质量分布的分布参数。

其中,第一部分结构可以实现为一个感兴趣区域池化层,该层以步骤S110中的特征图以及感兴趣区域的位置(这里就是真实框)为输入,用于从特征图中扣取出位于真实框内的部分,并将其缩小至固定的尺寸后输出,输出的特征称为真实特征,其中缩小至固定尺寸的目的是为了便于后续的第二部分结构能够统一处理。感兴趣区域池化层可以采用RoIPooling或者RoI Align操作,其中前者的计算方式较为简单,后者的计算精度较高。

在介绍步骤S110时提到,真实框在训练图像中是已经标注的,虽然质量分布编码模块的输入是特征图,但特征图和训练图像之间的缩放倍数是已知的,因此真实框在特征图中的位置也可以对应计算得到,如图2中标注有GT的方框所示。

需要指出,由于训练图像中可能有多个目标,因此真实框也可能有多个,此时需将每个真实框依次输入质量分布编码模块,也就是说步骤S130至步骤S160对于每个真实框都要执行一次,在后文中为简单起见,仅以一个真实框为例。

第二部分结构是一个编码器,该编码器可以包括至少一个全连接层(图2示出了2个全连接层),其输入为真实特征,用于利用全连接层的信息整合能力将真实特征编码为质量分布的分布参数输出。

下面说明一下质量分布的含义:

密集预测模块在特征图的每个像素点位置都预测出一个检测框,可以计算该检测框和真实框的重叠程度,显然,计算出的重叠程度代表了检测框的预测质量,即重叠程度越大,检测框的预测质量越高,反之预测质量则越低。两个矩形框的重叠程度的定义方式不限,例如可以定义为二者的交并比(Intersection over Union,简称IoU)或者交并比的变种(例如,GIoU等)。

由于在特征图的每个像素点位置都可以计算出一个重叠程度,从而可以认为特征图所在的平面上形成了一个重叠程度所满足的统计分布,不妨称为重叠程度的真实分布。质量分布可视为上述真实分布的一个模拟,质量分布的分布类型是预先设定好的,其分布参数则根据编码器预测得到。如果质量分布的分布类型选择合适,并且编码器预测的分布参数也足够合适,那么质量分布就可以很好地逼近上述真实分布。质量分布具体的表示方式不限,例如,可以表示为概率密度函数、分布函数等。

在一些实现方式中,质量分布可以采用高斯混合模型(Gaussian Mixed Model,简称GMM)实现(这里的高斯混合模型即一种分布类型),高斯混合模型由多个高斯分布加权形成,其可以有效地模拟几乎任何分布,并且其构成形式也比较简单。高斯混合模型的分布参数包括每个高斯分布的权重π、均值μ以及标准差σ,至于高斯分布的个数N则是预先指定的,例如,可以指定为2、3等。参照图2,若设置N=2,则编码器应当输出2个权重π1和π2、2个均值μ1和μ2以及2个标准差σ1和σ1。其中,均值μ1实际上是两个数值,因为这里的高斯分布是二维的,即μ1包括μ1x和μ1y,对于μ2、σ1和σ1也是同理。可以理解的,质量分布也可以选择其他分布类型。

若希望编码器能够输出合适的分布参数,则需要对其进行监督训练,具体方法在步骤S150中再进行介绍。在训练一段时间后,根据这些分布参数所确定出来的质量分布,就能够较好地模拟密集预测模块预测的检测框和真实框的重叠程度在特征图所在的平面上的真实分布了。

当然,也不排除在某些实现方式中,质量分布编码模块只包括编码器,该编码器以特征图和真实框为输入,输出质量分布的分布参数,并不执行从特征图中扣取真实特征的步骤。

步骤S140:利用质量分布采样模块根据质量分布进行采样,得到多个采样点,确定多个采样点中的正样本点,并根据特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,计算得到每个正样本点对应的检测框及其类别分数。

已知分布的形式进行采样存在多种方式,以随机采样为例,在特征图所在的平面上,质量分布的取值(例如,可以指概率密度函数对应的函数值)越大的位置被选择为采样点的概率越高,从而在获得的采样点中,将有较多的点位于质量分布的分布中心(例如,可以指概率密度函数的峰值位置)附近,例如,在图2中,质量分布有2个分布中心,周围共有12个采样点。可以理解的,也可以采用其他采样方式:例如,部分采样点按照指定的方式(非随机)选取,剩余的采样点则随机选取,比如,要获得2×H×W个采样点,其中H×W个采样点就选取特征图中的像素点,另外H×W个采样点则随机选取;又比如,要获得100个采样点,可将质量分布的2个分布中心作为其中的2个采样点,其余的98个采样点则随机选取,等等。

获得采样点后,可以从中选择正样本点,正样本点即“被作为训练用的正样本的采样点”的一种简化说法。而某个采样点作为正样本,具体是指该采样点对应的预测结果(包括检测框和类别分时)对应于图像中的真实目标,从而可以根据该预测结果和真实目标的标签计算分类损失和回归损失。相对于正样本点还有负样本点,负样本点即“被作为训练用的负样本的采样点”的一种简化说法,某个采样点作为负样本,具体是指该采样点对应的预测结果(包括检测框和类别分时)对应于图像中的背景,从而可以根据该预测结果和真实目标的标签计算分类损失(负样本不计算回归损失,因为不需要针对背景预测检测框)。

对于训练第一检测模型而言,正样本点和负样本点都是需要的,下面说明几种选择正样本点和负样本点方式,但应当理解,选择正样本点和负样本点的方式不限于这些:

方式1

通过比较质量分布在多个采样点处的取值,将其中取值最大的K(K为正整数)个采样点确定为正样本点,将多个采样点中除正样本点之外的采样点确定为负样本点。

在质量分布的形式确定以后,每个采样点处质量分布的取值也就确定了,从而通过一定的算法(例如,topK算法)可以确定质量分布的取值靠前的K个采样点,这些采样点在位置上接近分布中心,从而是理想的正样本点,剩余的采样点则可以作为负样本点,以便充分利用得到的每个采样点。方式1中的K可以是一个固定的数值,例如,12;或者,也可以是一个根据采样点总数和预设的比例计算出的数值,例如,全部采样点中的10%。

方式2

将对应的质量分布的取值大于预设阈值的采样点确定为正样本点,将多个采样点中除正样本点之外的采样点确定为负样本点。方式2达到的效果和方式1类似。

方式3

通过比较质量分布在多个采样点处的取值,将其中取值最大的K1(K1为正整数)个采样点确定为正样本点,将其中取值最小的K2(K2为正整数)个采样点确定为负样本点,多个采样点中除正样本点和负样本点之外的采样点则不作为训练样本,不参与损失计算。

相较于方式1和2,方式3只选择那些“明显”的正样本点或负样本点,从而促使模型多学习那些能够比较明确地区分目标和非目标的特征。

由于质量分布在空间上是连续的,所以采样点未必落在特征图的像素点位置,即采样点具有亚像素的精度。对于那些未落在像素点上的采样点,其对应的检测框及其类别分数是未知的,所以尚不能直接作为训练样本参与训练。但由于特征图中位于每个采样点周围的像素点对应的检测框及其类别分数是已知的(通过密集预测模块得到),所以在理论上可以根据特征图中位于每个采样点周围的像素点对应的检测框计算得到每个采样点对应的检测框,以及,根据特征图中位于每个采样点周围的像素点对应的类别分数,计算得到每个采样点对应的类别分数。

这里的“周围”可以是指距离采样点最近的若干个(例如,4个、16个等)像素点,而计算的方式包括但不限于求均值、插值运算(例如,最邻近插值、双线性插值、双三次插值等)。其中,插值运算可以比较精确地估计采样点对应的检测框及其类别分数。参照图2,P0,0、P0,1、P1,0、P1,1表示4个像素点,某采样点落在这4个像素点形成的方格中,则该采样点对应的检测框及其类别分数可以利用P0,0、P0,1、P1,0、P1,1对应的检测框及其类别分数,通过双线性插值运算得到。

具体到每个正样本点,可以根据特征图中位于其周围的像素点对应的检测框及其类别分数,计算得到其对应的检测框及其类别分数;而对于每个负样本点,则可以根据特征图中位于其周围的像素点对应的类别分数,计算得到其对应的类别分数。这样,就得到了亚像素精度正负样本,可以继续执行后续的损失计算步骤。不难看出,针对负样本点无需计算对应的检测框,其原因在于负样本不参与回归损失的计算,而正样本则分类损失和回归损失都要计算。

可以理解的,若部分采样点既未作为正样本点,也未作为负样本点,其对应的检测框及其类别分数无需计算。而对于某些样本点,如果恰好落在特征图的像素点上(比如,在采样时,就将特征图的像素点作为采样点),这些样本点对应的检测框及其类别分数也无需计算,直接采样密集预测的结果即可。

步骤S150:计算密集预测模块预测的检测框和真实框的重叠程度,根据质量分布在特征图中的每个像素点处的取值与对应的重叠程度计算第一损失,并根据第一损失更新质量分布编码模块以及特征提取模块的参数。

在步骤S130中已经提到,密集预测模块在特征图的每个像素点位置都预测出一个检测框,从而在特征图的每个像素点位置都可以计算出一个检测框和真实框之间的重叠程度,因此特征图所在的平面上形成了一个重叠程度所满足的统计分布,即重叠程度的真实分布。

该真实分布可以视为通过模型预测出的质量分布的标签,即可以作为监督信号用来监督质量分布的学习过程。具体地,特征图中任一像素点处的算出的重叠程度都可以视为质量分布在该像素点处取值的标签,从而,基于质量分布在特征图中的每个像素点处的取值与对应的标签,利用预设的损失函数(例如,交叉熵损失)就可以计算出第一损失,第一损失表征的是质量分布与真实分布之间的差异性,根据第一损失利用反向传播算法更新质量分布编码模块(主要是其中编码器的参数)以及特征提取模块的参数,就可以使得质量分布逐渐接近于重叠程度的真实分布。在图2中,第一损失记为L

需要指出,步骤S150可以步骤S130之后就执行,并不一定要在步骤S140之后执行。

步骤S160:根据每个正样本点对应的检测框及其类别分数,以及,真实框及其对应的真实类别计算第二损失,并根据第二损失更新密集预测模块以及特征提取模块的参数。

第二损失包含两项损失,分别是分类损失和回归损失。其中,根据每个正样本点对应的类别分数以及真实框对应的真实类别计算的是分类损失,其损失函数可采用交叉熵损失,根据分类损失,可以利用反向传播算法更新密集预测模块中的分类分支以及特征提取模块的参数。而根据每个正样本点对应的检测框以及真实框计算的是回归损失,其损失函数可采用交并比损失(IoU Loss),根据回归损失,可以利用反向传播算法更新密集预测模块中的回归分支以及特征提取模块的参数。在图2中,分类损失记为L

进一步的,还可以根据每个负样本点对应的类别分数以及真实框对应的真实类别计算第三损失,并根据第三损失,利用反向传播算法更新密集预测模块以及特征提取模块的参数。第三损失属于分类损失,因此此处更新的是密集预测模块中回归分支的参数。

下面分析一下上述模型训练方法带来的有益效果:

该模型训练方法利用密集预测模块预测的检测框和真实框的重叠程度,作为监督信号来监督质量分布的学习过程,从而使得该质量分布能够有效地描述密集预测模块对检测框的预测质量在空间上的分布状况。根据该质量分布进行采样,按概率将得到更多靠近分布中心的采样点,并且由于质量分布在空间上是连续的,所以采样点未必落在特征图的像素点上,即具有亚像素精度。根据质量分布表征的含义,在靠近分布中心且具有亚像素精度的采样点处密集预测模块对检测框的预测具有较高的质量,因此它们是高质量的正样本点,而利用更多高质量的正样本点进行模型训练,模型的检测效果(包括检测框的预测精度和类别的预测精度)将显著提高。

定性地理解,在图2中,训练图像包含一只羊,2个分布中心位于羊的头部和身体中部,这两个位置最能够体现出羊的特征,从而在这两个位置附近选择较多的正样本点进行学习,可以使模型能够更准确地检测图像中的羊。

类似的,负样本点远离质量分布的分布中心,并且也具有亚像素精度,从而在用于训练时也有利于提高模型的检测效果。

在第一检测模型中,质量分布编码模块和质量分布采样模块是本申请新提出的模型结构,而特征提取模块和密集预测模块则可以直接采用现有目标检测模型中的相关结构,此种应用场景也可以认为是将质量分布编码模块和质量分布采样模块集成到现有的目标检测模型中以改善其性能。这些现有的目标检测模型包括单阶段的检测器FCOS、RetinaNet、ATSS以及两阶段的检测器Faster RCNN(其中的RPN网络作为密集预测模块)等。

上面描述了目标检测模型的训练阶段,下面继续介绍目标检测模型的测试阶段,即如何使用训练好的模型进行目标检测。

图3示出了本申请实施例提供的目标检测方法的一种可能的流程。该方法可以但不限于由图6示出的电子设备执行,该电子设备的具体结构可参考后文关于图6的阐述。该方法检测目标所用的目标检测模型称为第二检测模型,第二检测模型至少包括特征提取模块以及密集预测模块,其中特征提取模块和密集预测模块也就是第一检测模型中的同名模块,二者是利用本申请实施例提供的模型训练方法训练得到的。第二检测模型可视为第一检测模型针对测试阶段的简化版本,第一检测模型中的质量分布编码模块和质量分布采样模块仅在模型的训练阶段使用,在模型的测试阶段不使用,在图2中也用文字进行了相关说明。

参照图3,目标检测方法包括:

步骤S210:利用特征提取模块提取待检测图像的特征图。

步骤S220:利用密集预测模块针对特征图中的每个像素点预测对应的检测框及其类别分数。

以上两个步骤类似步骤S110和步骤S120,只是其中的训练图像换成了待检测图像,不再重复阐述。

步骤S230:根据特征图中的每个像素点对应的检测框及其类别分数,计算待检测图像中最终的检测框及其类别分数。

在一些简单的实现方式中,对于步骤S220得到的密集预测结果,可以设置分数阈值,以筛选出其中得分较高的检测框,再进一步进行检测框去重等操作,得到最终的目标检测结果(即最终的检测框及其类别分数)。在一些复杂一点的实现方式中,步骤S220也可以通过最终预测模块完成,最终预测模块是第二检测模型的子网络,若最终预测模块包含可学习的参数,其也可以包含在第一检测模型中进行训练。

在上述目标检测方法中,由于第二检测模型包含的特征提取模块和密集预测模块是利用本申请实施例提供的模型训练方法训练得到的,因此根据前文的分析可知,第二检测模型具有较好的检测效果。

图4示出了本申请实施例提供的模型训练装置300的功能模块图。模型训练装置300用于训练第一检测模型,所述第一检测模型包括特征提取模块、密集预测模块、质量分布编码模块以及质量分布采样模块,参照图4,模型训练装置300包括:

第一特征提取单元310,用于利用所述特征提取模块提取训练图像的特征图;

第一密集预测单元320,用于利用所述密集预测模块针对所述特征图中的每个像素点预测对应的检测框及其类别分数;

质量分布编码单元330,用于利用所述质量分布编码模块根据所述特征图以及所述训练图像中的真实框预测质量分布的分布参数,所述质量分布表征所述密集预测模块预测的检测框和所述真实框的重叠程度在所述特征图所在的平面上满足的分布;

质量分布采样单元340,用于利用所述质量分布采样模块根据所述质量分布进行采样,得到多个采样点,确定所述多个采样点中的正样本点,并根据所述特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,计算得到每个正样本点对应的检测框及其类别分数;

第一训练单元350,用于计算所述密集预测模块预测的检测框和所述真实框的重叠程度,根据所述质量分布在所述特征图中的每个像素点处的取值与对应的重叠程度计算第一损失,并根据所述第一损失更新所述质量分布编码模块以及所述特征提取模块的参数;

第二训练单元360,用于根据每个正样本点对应的检测框及其类别分数,以及,所述真实框及其对应的真实类别计算第二损失,并根据所述第二损失更新所述密集预测模块以及所述特征提取模块的参数。

在模型训练装置300的一种实现方式中,质量分布采样单元340还用于:在所述利用所述质量分布采样模块根据所述质量分布进行采样,得到多个采样点之后,确定所述多个采样点中的负样本点,并根据所述特征图中位于每个负样本点周围的像素点对应的类别分数,计算得到每个负样本点对应的类别分数;模型训练装置300还包括:第三训练单元,用于根据每个负样本点对应的类别分数,以及,所述真实框对应的真实类别计算第三损失,并根据所述第三损失更新所述密集预测模块以及所述特征提取模块的参数。

在模型训练装置300的一种实现方式中,质量分布采样单元340确定所述多个采样点中的正样本点,包括:通过比较所述质量分布在所述多个采样点处的取值,将其中取值最大的预设数量的采样点确定为所述正样本点;质量分布采样单元340确定所述多个采样点中的负样本点,包括:将所述多个采样点中除所述正样本点之外的采样点确定为所述负样本点。

在模型训练装置300的一种实现方式中,所述质量分布编码模块包括感兴趣区域池化层以及至少一个全连接层,质量分布编码单元330利用所述质量分布编码模块从所述特征图中扣取出位于真实框内的真实特征,并根据所述真实特征预测质量分布的分布参数,包括:利用所述感兴趣区域池化层从所述特征图中扣取出位于真实框内的真实特征;利用所述至少一个全连接层根据所述真实特征预测质量分布的分布参数。

在模型训练装置300的一种实现方式中,所述质量分布采用高斯混合模型,所述分布参数包括所述高斯混合模型中每个高斯分布的权重、均值以及标准差。

在模型训练装置300的一种实现方式中,质量分布采样单元340根据所述特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,计算得到每个正样本点对应的检测框及其类别分数,包括:根据所述特征图中位于每个正样本点周围的像素点对应的检测框及其类别分数,利用插值运算得到每个正样本点对应的检测框及其类别分数。

本申请实施例提供的模型训练装置300,其实现原理及产生的技术效果在前述方法实施例中已经介绍,为简要描述,装置实施例部分未提及之处,可参考方法实施例中相应内容。

图5示出了本申请实施例提供的目标检测装置400的功能模块图。目标检测装置400用于利用第二检测模型进行目标检测,所述第二检测模型包括特征提取模块以及密集预测模块,参照图5,目标检测装置400包括:

第二特征提取单元410,用于利用所述特征提取模块提取待检测图像的特征图;

第二密集预测单元420,用于利用所述密集预测模块针对所述特征图中的每个像素点预测对应的检测框及其类别分数;

最终预测单元430,用于根据所述特征图中的每个像素点对应的检测框及其类别分数,计算所述待检测图像中最终的检测框及其类别分数;

其中,所述特征提取模块以及所述密集预测模块利用本申请实施例提供的模型训练方法训练得到。

本申请实施例提供的目标检测装置400,其实现原理及产生的技术效果在前述方法实施例中已经介绍,为简要描述,装置实施例部分未提及之处,可参考方法实施例中相应内容。

图6示出了本申请实施例提供的电子设备500的一种可能的结构。参照图6,电子设备500包括:处理器510、存储器520以及通信接口530,这些组件通过通信总线540和/或其他形式的连接机构(未示出)互连并相互通讯。

其中,处理器510包括一个或多个(图中仅示出一个),其可以是一种集成电路芯片,具有信号的处理能力。上述的处理器510可以是通用处理器,包括中央处理器(CentralProcessing Unit,简称CPU)、微控制单元(Micro Controller Unit,简称MCU)、网络处理器(Network Processor,简称NP)或者其他常规处理器;还可以是专用处理器,包括图形处理器(Graphics Processing Unit,GPU)、神经网络处理器(Neural-network ProcessingUnit,简称NPU)、数字信号处理器(Digital Signal Processor,简称DSP)、专用集成电路(Application Specific Integrated Circuits,简称ASIC)、现场可编程门阵列(FieldProgrammable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。并且,在处理器510为多个时,其中的一部分可以是通用处理器,另一部分可以是专用处理器。

存储器520包括一个或多个(图中仅示出一个),其可以是,但不限于,随机存取存储器(Random Access Memory,简称RAM),只读存储器(Read Only Memory,简称ROM),可编程只读存储器(Programmable Read-Only Memory,简称PROM),可擦除可编程只读存储器(Erasable Programmable Read-Only Memory,简称EPROM),电可擦除可编程只读存储器(Electric Erasable Programmable Read-Only Memory,简称EEPROM)等。处理器510以及其他可能的组件可对存储器520进行访问,读和/或写其中的数据。

特别地,在存储器520中可以存储一个或多个计算机程序指令,处理器510可以读取并运行这些计算机程序指令,以实现本申请实施例提供的模型训练方法和/或目标检测方法。

通信接口530包括一个或多个(图中仅示出一个),可以用于和其他设备进行直接或间接地通信,以便进行数据的交互。通信接口530可以包括进行有线和/或无线通信的接口。

可以理解,图6所示的结构仅为示意,电子设备500还可以包括比图6中所示更多或者更少的组件,或者具有与图6所示不同的配置。图6中所示的各组件可以采用硬件、软件或其组合实现。电子设备500可能是实体设备,例如PC机、笔记本电脑、平板电脑、手机、服务器、嵌入式设备等,也可能是虚拟设备,例如虚拟机、虚拟化容器等。并且,电子设备500也不限于单台设备,也可以是多台设备的组合或者大量设备构成的集群。

本申请实施例还提供一种计算机可读存储介质,该计算机可读存储介质上存储有计算机程序指令,所述计算机程序指令被计算机的处理器读取并运行时,执行本申请实施例提供的模型训练和/或目标检测方法。例如,计算机可读存储介质可以实现为图6中电子设备500中的存储器520。

以上所述仅为本申请的实施例而已,并不用于限制本申请的保护范围,对于本领域的技术人员来说,本申请可以有各种更改和变化。凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

相关技术
  • 模型训练方法、目标检测方法及对应装置
  • 目标检测方法和目标检测模型的训练方法、装置
技术分类

06120113046890