掌桥专利:专业的专利平台
掌桥专利
首页

训练多分支网络的方法和对象检测方法

文献发布时间:2023-06-19 11:29:13


训练多分支网络的方法和对象检测方法

技术领域

本发明涉及训练神经网络的方法,更具体地,涉及基于深度相互学习来训练多分支网络的方法以及利用经训练的网络来检测对象的方法。

背景技术

目前,利用神经网络的对象检测技术已经得到了极大的发展。作为对象检测的一个具体应用,在航拍图像中检测对象的目的是对地面上的感兴趣的对象(例如车辆,飞机,桥梁等)进行准确定位和分类。这项工作颇具挑战性,因为与一般图像不同,航拍图像通常是鸟瞰图,在这种情况下,图像中的对象通常具有任意的取向,并且具有非常复杂的背景和多变的外观。

对此已经提出了两阶段目标检测方法,例如Faster RCNN,其在一定程度上能够取得较好的检测效果。图1示出了Faster RCNN的典型架构。如图1所示,Faster RCNN包括以下四个部分:

-卷积层,用于针对图像提取特征,并且输出特征图;

-区域建议网络(RPN),用于推荐候选区域,其输入为特征图,输出为多个候选区域;

-RoI池化,用于将不同尺寸的输入转换为固定尺寸的输出;

-分类和回归,用于确定候选区域在图像中的精确位置以及候选区域中的对象的类别

然而,本发明的发明人已经认识到:诸如Faster RCNN这样的单个网络模型的鲁棒性和效率方面都存在不足,另一方面,利用深度相互学习可以进一步改进网络的训练,从而获得性能改善的对象检测网络。

发明内容

基于上述认识,本发明提出了一种具有多个分支的目标检测网络,其中,对多个分支一起进行训练,并且在训练中该多个分支之间互相学习,在检测目标时将各个分支的输出进行融合以获得最终检测结果。

根据本发明的一个方面,提供了一种训练多分支网络的方法。所述多分支网络用于在图像中检测对象,并且包括针对图像提取特征图的骨干网络以及连接至所述骨干网络的输出的至少两个子网络,每个子网络包括区域建议网络(RPN)和检测器。所述方法包括:在每个子网络中,由所述RPN基于从所述骨干网络输出的特征图,确定多个候选区域的位置以及每个候选区域中包含有对象的概率,并且输出确定结果;以及由所述检测器基于所述RPN的输出,在所述图像中确定包含对象的区域以及所述对象的类别,其中,通过深度相互学习来训练各个子网络中的各个RPN,使得各个RPN所输出的确定结果彼此接近。

根据本发明的另一个方面,提供了一种在图像中检测对象的方法,包括:利用根据上述方法训练得到的多分支网络来检测对象,其中,对各个子网络的检测器的输出进行合并,以在所述图像中确定包含对象的区域以及所述对象的类别。

根据本发明的另一个方面,提供了一种多分支网络,其中,所述多分支网络用于在图像中检测对象,并且包括针对图像提取特征图的骨干网络以及连接至所述骨干网络的输出的至少两个子网络。每个子网络包括区域建议网络(RPN)和检测器。在每个子网络中,由所述RPN基于从所述骨干网络输出的特征图来确定多个候选区域的位置以及每个候选区域中包含有对象的概率,并且输出确定结果,并且由所述检测器基于所述RPN的输出来在所述图像中确定包含对象的区域以及所述对象的类别。通过深度相互学习来训练各个子网络中的各个RPN,使得各个RPN所输出的确定结果彼此接近。

根据本发明的另一个方面,提供了一种训练多分支网络的装置。所述装置包括被配置为执行上述训练方法的一个或多个处理器。

根据本发明的另一个方面,提供了一种存储有计算机可执行指令的存储介质,所述指令使计算机执行上述训练方法。

附图说明

图1示出了传统的Faster RCNN的典型架构。

图2示意性地示出了根据本发明的多分支网络的一个示例。

图3示意性地示出了根据本发明的训练多分支网络的方法的流程图。

图4示意性地示出了根据本发明的训练多分支网络的方法的另一流程图。

图5示意性地示出了根据本发明的训练多分支网络的方法的又一流程图。

图6示出了根据本发明的对象检测的伪代码。

图7示出了实现本发明的计算机硬件的示例性配置框图。

具体实施方式

图2示意性地示出了根据本发明的多分支网络的一个示例。为了描述的简明以及便于理解,图2仅示出了两个分支,但本发明并不限于两个分支,而是可以包括更多个分支子网络。

如图2所示,本发明提出了基于Faster RCNN架构的、具有多个分支的R2CNN网络(旋转区域卷积神经网络)。该多分支R2CNN网络包括骨干网络100以及共享骨干网络100的两个子网络210,220。骨干网络100对应于图1所示的卷积层。两个子网络210,220具有相同的结构,子网络210包括区域建议网络RPN 1和检测器C1,子网络220包括RPN 2以及检测器C2。检测器C1,C2执行与图1中的分类与回归单元类似的功能。

图3示意性地示出了根据本发明的训练多分支网络的方法的流程图。首先在步骤S310,骨干网络100针对输入的图像(例如航拍图像)提取特征图。

RPN 1,RPN 2基于从骨干网络100输出的特征图,各自确定多个候选区域的位置以及每个候选区域中包含有对象的概率,如步骤S320所示。具体来说,在每一个子网络210,220中,由RPN提出轴对齐的多个候选区域(下文也称为“边界框”),各个候选区域中可以包含具有不同取向(orientation)的对象。

对于RPN提出的每个边界框,提取其池化特征以便同时预测水平边界框(HBB)、倾斜边界框(OBB)以及分类概率。水平边界框(HBB)与边界框的中心坐标、宽度以及高度有关,倾斜边界框(OBB)与边界框的倾斜角度有关。特别地,这里的分类是相对粗略的分类,RPN主要分辨在相应的边界框中是否存在对象,换言之,分辨是前景还是背景。因此,这里的分类概率是指示在相应边界框中包含有对象的概率。

然后,在步骤S330,检测器C1和C2分别基于相对应的RPN 1和RPN 2的输出,确定图像中包含对象的区域以及该对象的分类类别。特别地,这里的分类是精细的分类,用于确定所包含对象的具体类别。

在步骤S340,通过深度相互学习(DML)来一起训练RPN1和RPN2,使得RPN1和RPN2的输出结果彼此接近。DML是指多个网络模型同时进行训练,并且在训练中互相学习。每个网络在训练中不仅接受来自真值的监督,而且还要参考其它网络处理相同问题时的输出结果。两个网络之间不断分享学习经验,互相学习借鉴。

如上所述,RPN1和RPN2能够输出候选区域的位置(即,回归输出)以及每个候选区域中包含有对象的概率(即,分类概率)。因此,在基于深度相互学习(DML)的训练中,使用DML回归损失和DML分类损失。作为DML回归损失的一个示例,可以采用平滑L1损失,作为DML分类损失的一个示例,可以采用KL散度。因此,对于子网络i,可以通过以下等式(1)-(4)来计算DML损失:

L

其中,p

图4示出了根据本发明的训练多分支网络的方法的另一流程图。图4中的步骤S410-S440与图3中的步骤S310-S340相同,故不再赘述。图4所示的方法与图3的方法的不同之处在于:除了基于DML对各个RPN进行训练之外,还包括对于构成同一子网络的RPN和检测器的训练,如步骤S450所示。

在步骤S450的训练中使用分类损失和回归损失。作为分类损失的一个示例,可以采用交叉熵损失(cross entropy loss),作为回归损失的一个示例,可以采用平滑L1损失。此外,对于RPN而言,考虑两种损失,即分类损失L

可以通过以下等式(5)-(7)来计算用于该子网络的损失函数:

L({p

其中,p

此外,可以进一步将等式(4)表示的DML损失与等式(7)表示的损失相加,从而得到用于训练该子网络的最终损失,如以下等式(8)所示,其中λ是加权值。

L

图5示出了根据本发明的训练多分支网络的方法的又一流程图。图5中的步骤S510-S550与图4中的步骤S410-S450相同,故不再赘述。图5所示的方法与图4的方法的不同之处在于:除了对子网络进行训练之外,还包括对于骨干网络的训练,如步骤S560所示。这将在下文进行描述。

以下将描述根据本发明的训练方法的其它策略。

在训练图2所示的多分支网络时,以不同的方式来训练各个子网络(RPN),从而最大程度地发挥DML的效果。具体来说,各个子网络被随机地初始化,并且对其应用不同的训练样本。作为一个示例,可以利用第一图像序列作为训练样本来训练子网络210,并且利用第二图像序列来训练子网络220,该第一图像序列包含与第二图像序列不同的图像。或者,第一图像序列与第二图像序列可以包含相同的图像,但图像的顺序彼此不同。

此外,交替地训练各个子网络。作为一个示例,以第一图像序列中的第一图像来训练子网络210,然后以第二图像序列中的第一图像来训练子网络220,然后以第一图像序列中的第二图像来训练子网络210,然后以第二图像序列中的第二图像来训练子网络220……以此类推。

此外,在利用某一训练图像来训练某个子网络时,改变该子网络的参数,而同时保持其它子网络的参数不变。也就是说,在针对一个子网络的单独训练中固定其它子网络的配置。

此外,在训练每个子网络时,对该子网络与骨干网络100联合地进行训练。由于在本发明中各个子网络共享同一骨干网络100,以此方式,骨干网络100可以得到多次训练。

以上描述了对多分支网络的训练方法。在训练完成后,可以获得训练好的多个子网络模型,由于每个子网络都是在不同的模式下训练而获得的,因此这些子网络在某种程度上是互补的。这样,与单个网络模型相比,将各个子网络模型的输出进行融合将会进一步改善检测效果。因此,在应用本发明的多分支网络进行对象检测时,采用专家混合模型(mixture of experts),即,对每个子网络的输出结果进行融合,然后采用倾斜非极大值抑制(inclined Non-Maximum Suppression)来删除重复,从而获得最终检测结果。图6示出了相应的伪代码。

表1示出了不同网络模型的mAP得分,其中以R2CNN网络以及具有多分支的R2CNN网络作为参照,说明了根据本发明的基于DML的多分支R2CNN具有更好的检测性能。

[表1]

从表1中可以看出,深度相互学习(DML)将HBB和OBB检测性能分别提高了约0.5%和0.9%。此外,与所有分支当中具有最佳性能的单个分支相比,专家混合将检测性能进一步提高了0.4%和0.5%。

以上量化的结果证明了根据本发明的网络模型具有优良性能,这可归因于本发明的以下方面:

-在训练每个子网络时一同训练骨干网络,因而骨干网络得到更多的训练,提高了鲁棒性;

-基于DML来训练各个RPN,从而改善每个子网络的检测性能;

-应用专家混合模型,进一步改善检测性能。

以上已经结合具体实施例描述了根据本发明的多分支网络及其训练和应用。本发明可适用于对象检测,特别地,适用于在航拍图像中检测对象,以及适用于检测具有一定取向的对象。

在上文中描述的方法可以由软件、硬件或者软件和硬件的组合来实现。包括在软件中的程序可以事先存储在设备的内部或外部所设置的存储介质中。作为一个示例,在执行期间,这些程序被写入随机存取存储器(RAM)并且由处理器(例如CPU)来执行,从而实现在本文中描述的各种处理。

图7示出了根据程序来执行本发明的方法的计算机硬件的示例性框图,该计算机硬件是根据本发明的用于训练多分支网络的装置的一个示例。

如图7所示,在计算机700中,中央处理单元(CPU)701、只读存储器(ROM)702以及随机存取存储器(RAM)703通过总线704彼此连接。

输入/输出接口705进一步与总线704连接。输入/输出接口705连接有以下组件:以键盘、鼠标、麦克风等形成的输入单元706;以显示器、扬声器等形成的输出单元707;以硬盘、非易失性存储器等形成的存储单元708;以网络接口卡(诸如局域网(LAN)卡、调制解调器等)形成的通信单元709;以及驱动移动介质711的驱动器710,该移动介质711例如是磁盘、光盘、磁光盘或半导体存储器。

在具有上述结构的计算机中,CPU 701将存储在存储单元708中的程序经由输入/输出接口705和总线704加载到RAM 703中,并且执行该程序,以便执行上文中描述的方法。

要由计算机(CPU 701)执行的程序可以被记录在作为封装介质的移动介质711上,该封装介质以例如磁盘(包括软盘)、光盘(包括压缩光盘-只读存储器(CD-ROM))、数字多功能光盘(DVD)等)、磁光盘、或半导体存储器来形成。此外,要由计算机(CPU 701)执行的程序也可以经由诸如局域网、因特网、或数字卫星广播的有线或无线传输介质来提供。

当移动介质711安装在驱动器710中时,可以将程序经由输入/输出接口705安装在存储单元708中。另外,可以经由有线或无线传输介质由通信单元709来接收程序,并且将程序安装在存储单元708中。可替选地,可以将程序预先安装在ROM 702或存储单元708中。

由计算机执行的程序可以是根据本说明书中描述的顺序来执行处理的程序,或者可以是并行地执行处理或当需要时(诸如,当调用时)执行处理的程序。

本文中所描述的单元或装置仅是逻辑意义上的,并不严格对应于物理设备或实体。例如,本文所描述的每个单元的功能可能由多个物理实体来实现,或者,本文所描述的多个单元的功能可能由单个物理实体来实现。此外,在一个实施例中描述的特征、部件、元素、步骤等并不局限于该实施例,而是也可以应用于其它实施例,例如替代其它实施例中的特定特征、部件、元素、步骤等,或者与其相结合。

本发明的范围不限于在本文中描述的具体实施例。本领域普通技术人员应该理解的是,取决于设计要求和其他因素,在不偏离本发明的原理和精神的情况下,可以对本文中的实施例进行各种修改或变化。本发明的范围由所附权利要求及其等同方案来限定。

此外,本发明还可被配置如下。

(1).一种训练多分支网络的方法,其中,所述多分支网络用于在图像中检测对象,并且包括针对图像提取特征图的骨干网络以及连接至所述骨干网络的输出的至少两个子网络,每个所述子网络包括区域建议网络RPN和检测器,所述方法包括:在每个所述子网络中,由所述RPN基于从所述骨干网络输出的特征图来确定多个候选区域的位置以及每个候选区域中包含有对象的概率,并且输出确定结果;以及由所述检测器基于所述RPN的输出来在所述图像中确定包含对象的区域以及所述对象的类别,其中,通过深度相互学习来训练各个子网络中的各个RPN,使得各个RPN所输出的确定结果彼此接近。

(2).根据(1)所述的方法,还包括:在针对所述各个RPN中的特定RPN的训练中,所述特定RPN的参数被改变,并且其他RPN的参数保持不变。

(3).根据(2)所述的方法,其中,利用不同的图像序列来训练不同的RPN。

(4).根据(2)所述的方法,其中,所述各个RPN被随机地初始化。

(5).根据(2)所述的方法,其中,在基于深度相互学习的训练中使用第一损失函数和第二损失函数,其中,所述第一损失函数与所述特定RPN和其他RPN各自确定的、候选区域中包含有对象的概率有关,所述第二损失函数与所述特定RPN和其他RPN各自确定的候选区域的位置有关。

(6).根据(5)所述的方法,其中,候选区域的位置包括所述候选区域的中心坐标,宽度以及高度。

(7).根据(2)所述的方法,还包括:基于第三损失函数和第四损失函数来训练所述特定RPN以及与其属于同一子网络的检测器,其中,所述第三损失函数与所述特定RPN确定的候选区域的位置和候选区域中包含有对象的概率有关,所述第四损失函数与所述检测器确定的包含对象的区域的位置以及对象的类别有关。

(8).根据(7)所述的方法,其中,所述检测器确定的包含对象的区域的位置包括所述区域的中心坐标,宽度、高度以及旋转角度。

(9).根据(7)所述的方法,还包括:对所述骨干网络进行训练。

(10).一种在图像中检测对象的方法,包括:利用根据(1)-(9)所述的方法训练得到的多分支网络来检测对象,其中,对各个子网络的检测器的输出进行合并,以在所述图像中确定包含对象的区域以及所述对象的类别。

(11).一种多分支网络,其中,所述多分支网络用于在图像中检测对象,并且包括针对图像提取特征图的骨干网络以及连接至所述骨干网络的输出的至少两个子网络,每个子网络包括区域建议网络(RPN)和检测器。在每个子网络中,由所述RPN基于从所述骨干网络输出的特征图来确定多个候选区域的位置以及每个候选区域中包含有对象的概率,并且输出确定结果,并且由所述检测器基于所述RPN的输出来在所述图像中确定包含对象的区域以及所述对象的类别。其中,通过深度相互学习来训练各个子网络中的各个RPN,使得各个RPN所输出的确定结果彼此接近。

(12).一种训练多分支网络的装置,包括:被配置为执行(1)-(9)所述的方法的一个或多个处理器。

(13).一种存储有计算机可执行指令的存储介质,所述指令在被计算机执行时使得所述计算机执行(1)-(10)所述的方法。

相关技术
  • 训练多分支网络的方法和对象检测方法
  • 利用对象感知多分支关系网络完成视频中指定对象定位任务的方法和定位系统
技术分类

06120112940652