掌桥专利:专业的专利平台
掌桥专利
首页

图像分类网络模型的训练方法、图像分类方法及相关设备

文献发布时间:2023-06-19 11:22:42


图像分类网络模型的训练方法、图像分类方法及相关设备

技术领域

本申请涉及图像处理技术领域,特别是涉及一种图像分类网络模型的训练方法、图像分类方法及相关设备。

背景技术

图像分类是图像处理技术领域最为基础的一类问题。现有技术中主要采用深度神经网络的图像分类方法,具体为将待分类图像及待分类图像的类别标签输入深度神经网络模型,以对深度神经网络模型进行训练。但上述方式所得深度神经网络模型输出的待分类图像的预测分类标签存在错误的可能性,且存在不可解释性。

发明内容

本申请提供了一种图像分类网络模型的训练方法、图像分类方法及相关设备。

为解决上述技术问题,本申请提供了一种图像分类网络模型的训练方法,所述方法包括:

获取训练图像和外部知识库,所述外部知识库包括所述训练图像的真实类别标签;

对所述外部知识库进行编码处理,得到类别距离矩阵;

将所述训练图像及其真实类别标签和所述类别距离矩阵输入所述图像分类网络模型,得到所述训练图像的预测类别概率分布,其中,所述预测类别概率分布包括所述图像分类网络模型输出的预测类别标签与所述真实类别标签之间的差距概率;

利用所述类别距离矩阵中所述真实类别标签与所述预测类别标签之间的深度距离以及所述预测类别概率分布计算目标损失函数;

基于所述目标损失函数训练所述图像分类网络模型。

为解决上述技术问题,本申请提供了一种图像分类方法,所述图像分类方法包括:

获取待分类图像;

将所述待分类图像输入到图像分类网络模型,得到所述待分类图像的类别标签,其中,所述图像分类网络模型为利用上述方法训练所得的图像分类网络模型;

对所述待分类图像的类别标签进行评价,得到可解释性评分。

为解决上述技术问题,本申请提供了一种终端设备,所述设备包括存储器以及与所述存储器耦接的处理器;

所述存储器用于存储程序数据,所述处理器用于执行所述程序数据以实现如上述的图像分类网络模型的训练方法和/或上述的图像分类方法。

为解决上述技术问题,本申请还提供了一种计算机存储介质,所述计算机存储介质用于存储程序数据,所述程序数据在被处理器执行时,用以实现如上述的图像分类网络模型的训练方法和/或上述的图像分类方法。

本申请的有益效果是:获取训练图像和外部知识库,外部知识库包括训练图像的真实类别标签;对外部知识库进行编码处理,得到类别距离矩阵;将训练图像及其真实类别标签和类别距离矩阵输入图像分类网络模型,得到训练图像的预测类别概率分布,其中,预测类别概率分布包括图像分类网络模型输出的预测类别标签与真实类别标签之间的差距概率;利用类别距离矩阵中真实类别标签与预测类别标签之间的深度距离以及预测类别概率分布计算目标损失函数;基于目标损失函数训练网络模型。本申请引用外部知识库对图像分类网络模型输出的预测类别概率分布进行约束,兼顾提高了图像分类的准确性及增强了预测结果的可解释性。

附图说明

为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。其中:

图1是本申请提供的图像分类网络模型的训练方法一实施例的流程示意图;

图2是本申请提供的图像分类网络模型的训练方法中外部知识库的简易示意图;

图3是图1所示的图像分类网络模型的训练方法中S102一实施例的流程示意图;

图4是图1所示的图像分类网络模型的训练方法中S104一实施例的流程示意图;

图5是本申请提供的图像分类方法的一实施例的流程示意图;

图6是本申请提供的终端设备一实施例的结构示意图;

图7是本申请提供的计算机存储介质一实施例的结构示意图。

具体实施方式

下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请的一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。

本申请提出了一种图像分类网络模型的训练方法,具体请参阅图1,图1是本申请提供的图像分类网络模型的训练方法一实施例的流程示意图。本实施例中图像分类网络模型的训练方法可以应用于图像分类装置,本申请的图像分类装置可以为服务器,也可以为移动设备,还可以为由服务器和移动设备相互配合的系统。相应地,移动设备包括的各个部分,例如各个单元、子单元、模块、子模块可以全部设置于服务器中,也可以全部设置于移动设备中,还可以分别设置于服务器和移动设备中。

进一步地,上述服务器可以是硬件,也可以是软件。当服务器为硬件时,可以实现成多个服务器组成的分布式服务器集群,也可以实现成单个服务器。当服务器为软件时,可以实现成多个软件或软件模块,例如用来提供分布式服务器的软件或软件模块,也可以实现成单个软件或软件模块,在此不做具体限定。

本实施例的图像分类网络模型的训练方法具体包括以下步骤:

S101:获取训练图像和外部知识库,外部知识库包括训练图像的真实类别标签。

本公开实施例中,考虑到现有技术中单纯利用训练图像和训练图像的真实类别标签对图像分类网络模型进行训练,所得的图像分类网络模型输出的预测分类标签存在错误的可能性,并且存在不可解释性。为避免上述问题,本申请的图像分类装置引用外部知识库对图像分类网络模型的预测结果进行约束。

可参阅图2,图2是本申请提供的图像分类网络模型的训练方法中外部知识库的简易示意图。由图可知,外部知识库为由多个类别标签构成的树状结构,树状结构中的每个节点表示类别标签,且树状结构中位置越接近的节点之间的类别标签越相似。为了利用外部知识库对图像分类网络模型输出的预测类别概率分布进行约束,本实施例的外部知识库应包括图像分类网络模型能够区分的所有类别标签。进一步地,在对图像分类网络模型进行训练的过程中,外部知识库至少包括训练图像的真实类别标签。

进一步地,考虑到训练图像的多类性,单一的外部知识库可能无法包括所有训练图像的真实类别标签。为此,本实施例的图像分类网络模型的训练方法可通过人工从额外的知识库中提取类别标签对单一外部知识库中的缺失类别标签进行补充。

考虑到训练图像的数量对图像分类网络模型输出的预测结果的影响,本实施例所需训练图像的数量应尽可能地多。在具体实施例中,训练图像的数量至少为1000张。

需要说明的是,本实施例的图像分类装置在利用训练图像对图像分类网络模型进行训练前,应统一训练图像的像素大小,例如,统一缩放为256x256,方便利用相同像素大小的训练图像对图像分类网络模型进行训练。

S102:对外部知识库进行编码处理,得到类别距离矩阵。

可继续参阅图2,为了利用外部知识库中类别标签之间的深度距离对网络模型输出的预测类别概率分布进行约束。本实施例需对外部知识库进行编码处理,得到类别距离矩阵。其中,类别距离矩阵包括外部知识库中任意两类别标签之间的深度距离,也即语义距离。

可选地,本实施例可采用图3实施例实现S102,具体包括S201至S203:

S201:获取外部知识库中任意两个类别标签。

为了方便从类别距离矩阵中获取真实类别标签和预测类别标签之间的深度距离,本实施例的图像分类装置可以预先获知包括真实类别标签和预测类别标签之间的深度距离的类别距离矩阵。具体地,对于类别距离矩阵的获取,图像分类装置首先需获取外部知识库中任意两个类别标签。

S202:获取任意两个类别标签之间的公共类别标签。

进一步地,图像分类装置在外部知识库中获取任意两个类别标签之间的公共类别标签,也即最近公共祖先。其中,最近公共祖先为任意两个类别标签中一类别标签和另一类别标签的祖先且该祖先深度尽可能大。

S203:基于公共类别标签计算任意两个类别标签的深度距离,以得到包括任意两个类别标签的深度距离的类别距离矩阵。

其中,本实施例的图像分类装置利用公共类别标签计算任意两个类别标签的深度距离,以得到包括任意两个类别标签的深度距离的类别距离矩阵。

具体地,图像分类装置分别获取公共类别标签、两个类别标签中一类别标签的深度和另一类别标签的深度;计算一类别标签的深度与另一类别标签的深度之和;利用公共类别标签的深度与上述一类别标签的深度与另一类别标签的深度之和的比值计算这任意两个类别标签的深度距离。

其中,采用Wup(Wu-Palmer)语义相似度计算外部知识库中任意两个类别标签之间的深度距离。深度距离的具体计算公式如下:

其中,c

进一步地,本实施例的图像分类装置通过定位公共类别标签在外部知识库中的标签位置,利用外部知识库中标签位置获取公共类别标签的层数,并以此确定公共类别标签的深度。在具体实施例中,类别标签c

S103:将训练图像及其真实类别标签和类别距离矩阵输入图像分类网络模型,得到训练图像的预测类别概率分布。

本实施例的图像分类装置将训练图像及其真实类别标签和类别距离矩阵输入图像分类网络模型,得到训练图像的预测类别概率分布。其中,预测类别概率分布包括图像分类网络模型输出的预测类别标签与真实类别标签之间的差距概率。

S104:利用类别距离矩阵中真实类别标签与预测类别标签之间的深度距离以及预测类别概率分布计算目标损失函数。

由于现有图像分类网络模型训练方法中所用损失函数在训练图像的预测类别标签与真实类别标签一致时,存在损失函数值。在训练图像的预测类别标签与真实类别标签不一致时,损失函数值为0。所以,现有损失函数忽略了训练图像的预测类别标签与真实类别标签不一致的情况下对图像分类网络模型训练的影响,导致图像分类网络模型输出的预测类别概率分布与常识不符。为解决上述问题,本实施例的图像分类网络模型训练方法通过拓展损失函数,兼顾考虑训练图像的预测类别标签与真实类别标签不一致情况对图像分类网络模型的影响。具体地,本实施例的图像分类装置利用类别距离矩阵中真实类别标签与预测类别标签之间的深度距离以及预测类别概率分布计算目标损失函数。

可选地,本实施例可采用图4实施例实现S104,具体包括S301至S304:

S301:获取类别距离矩阵中真实类别标签与预测类别标签之间的深度距离。

由于本实施例拓展了现有图像分类网络模型训练方法中的损失函数,因此,本实施例图像分类网络模型中的目标损失函数包括第一损失函数和第二损失函数。第一损失函数和第二损失函数分别表征网络模型的不同方面特征。具体地,第一损失函数表征在训练图像的预测类别与真实类别之间一致性时,图像分类网络模型输出的预测类别概率分布与预设类别概率分布之间的损失。第二损失函数表征在训练图像的预测类别与真实类别不一致性时,利用训练图像的预测类别与真实类别之间的深度距离,即语义距离,获取图像分类网络模型输出的预测类别概率分布与深度距离之间的损失。

S302:利用预测类别概率分布与深度距离计算第一损失函数。

图像分类装置利用预测类别概率分布与深度距离计算第一损失函数。

具体地,第一损失函数满足下式:

其中,L

需要说明的是,在具体实施例中,第一损失函数可以为交叉熵损失函数。

S303:利用预测类别概率分布与真实类别标签计算第二损失函数。

基于S302的第一损失函数可知,当l与k的真实类别标签不一致时,指示函数为0,导致第一损失函数为0,图像分类网络模型输出的预测结果忽视了训练图像的真实类别标签与预测类别标签不一致的情况。为解决上述问题,本实施例的图像分类装置拓展第一损失函数,对除真实类别标签的其他类别标签的预测结果进行约束。具体地,图像分类装置利用预测类别概率分布与真实类别标签计算第二损失函数。

具体地,第二损失函数满足下式:

其中,L

S304:基于第一损失函数和第二损失函数计算目标损失函数。

其中,图像分类装置利用第一损失函数和第二损失函数计算目标损失函数。

具体地,目标损失函数满足下式:

其中,L(k)为目标损失函数,α为权重系数,用于平衡第一损失函数和第二损失函数以最优化图像分类网络模型的训练。

在具体实施例中,图像分类装置可采用网格搜索法确定权重系数α。

S105:基于目标损失函数训练图像分类网络模型。

本实施例的图像分类装置以目标损失函数训练图像分类网络模型。具体地,本实施例的图像分类装置可利用梯度下降技术对目标损失函数进行训练。

上述方案中,图像分类装置引用外部知识库对图像分类网络模型输出的预测类别概率分布进行约束,兼顾提高了图像分类的准确性及增强了预测结果的可解释性;利用预测类别概率分布与深度距离计算目标损失函数,扩展了现有损失函数,避免因现有损失函数忽视训练图像的预测类别标签与真实类别标签不一致的情况,导致图像分类网络模型输出的预测类别概率分布与常识不符。

可参阅图5,图5是本申请提供的图像分类方法的一实施例的流程示意图。本实施例图像分类方法可应用于上述图像分类网络模型的训练方法中训练所得的图像分类网络模型,从而兼顾提高图像分类的准确性及预测结果的可解释性。下面以用于图像分类方法的服务器为例,介绍本申请提供的图像分类方法,本实施例图像分类方法具体包括以下步骤:

S401:获取待分类图像。

本实施例获取待分类图像与上述实施例S101中训练图像获取相似,在此不再赘述。

S402:将待分类图像输入到图像分类网络模型,得到待分类图像的类别标签。

其中,本实施例的图像分类装置将待分类图像输入到图像分类网络模型中,得到待分类图像的类别标签。

S403:对待分类图像的类别标签进行评价,得到可解释性评分。

为了兼顾提高图像分类准确性及增强预测结果可解释性,本实施例需对图像分类网络模型输出的待分类图像的类别标签进行评价,得到可解释性评分。具体地,本实施例将待分类图像输入图像分类网络模型,得到类别标签排序值及类别概率分布,利用包括第一类别标签排序值和第二类别标签排序值的类别标签排序值计算可解释性评分。

其中,第一类别标签排序值为类别概率分布中待分类图像的类别标签与待分类图像的真实类别标签之间的差距概率排序值。第二类别标签排序值为类别距离矩阵中待分类图像的类别标签与待分类图像的真实类别标签之间的深度距离排序值。

其中,可解释性评价满足下式:

其中,r

本实施例,获取待分类图像,将待分类图像输入到图像分类网络模型,得到待分类图像的类别标签,对待分类图像的类别标签进行评价,得到可解释性评分,实现了兼顾提高图像分类准确性及增强预测结果可解释性。

为实现上述实施例的图像分类网络模型训练方法和/或图像分类方法,本申请提出了一种终端设备,具体请参阅图6,图6是本申请提供的终端设备一实施例的结构示意图。

终端设备600包括存储器61和处理器62,其中,存储器61和处理器62耦接。

存储器61用于存储程序数据,处理器62用于执行程序数据以实现上述实施例的图像分类网络模型训练方法和/或图像分类方法。

在本实施例中,处理器62还可以称为CPU(Central Processing Unit,中央处理单元)。处理器62可能是一种集成电路芯片,具有信号的处理能力。处理器62还可以是通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。通用处理器可以是微处理器或者该处理器62也可以是任何常规的处理器等。

本申请还提供一种计算机存储介质700,如图7所示,计算机存储介质700用于存储程序数据71,程序数据71在被处理器执行时,用以实现如本申请方法实施例中所述的图像分类网络模型训练方法和/或图像分类方法。

本申请图像分类网络模型训练方法和/或图像分类方法实施例中所涉及到的方法,在实现时以软件功能单元的形式存在并作为独立的产品销售或使用时,可以存储在装置中,例如一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本发明各个实施方式所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。

以上所述仅为本申请的实施方式,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。

相关技术
  • 图像分类网络模型的训练方法、图像分类方法及相关设备
  • 图像分类方法及装置、神经网络模型的训练方法及装置
技术分类

06120112899890