掌桥专利:专业的专利平台
掌桥专利
首页

一种神经网络的训练方法及相关装置

文献发布时间:2024-04-18 19:58:21


一种神经网络的训练方法及相关装置

技术领域

本申请实施例涉及人工智能领域,尤其涉及一种神经网络的训练方法及相关装置。

背景技术

许多的工业场景中,都需要检测来自外界的入侵事件,其中,基于光缆传感的入侵检测是一种通用的检测手段。而由于光缆部署于复杂多样的地质环境和背景震动中,因此,光缆很容易获取到与入侵事件很相似的非入侵事件(属于OOD事件)的信号。例如,农耕机械旋耕机事件(非入侵事件)跟夯机敲击(入侵事件)很相似,导致入侵识别模型把非入侵事件误报为入侵事件。

因此,在光缆感知到外界的入侵信号后,可以采用入侵识别模型来识别入侵事件或非入侵事件的类别。而由于入侵事件信号和非入侵事件信号较为相似,因此,入侵事件信号和非入侵事件信号都会作为入侵识别模型的训练输入,入侵识别模型的输出则为入侵事件或非入侵事件的类别。对此,上述入侵事件信号则属于模型训练过程中的分布内(in-distribution,ID)数据,非入侵事件信号则属于模型训练过程中的分布外(out-of-distribution,OOD)数据。当获取到新的非入侵事件后,一般会采集该非入侵事件信号(属于OOD数据),从而将现有的入侵识别模型重新训练,得到入侵识别模型的输出则包括了该非入侵事件(OOD事件)的新增入侵类别。

非入侵事件的类别随着时间发展可能会不断增多,因此,入侵识别模型所输出的入侵类别也会不断增多,严重降低入侵识别模型的性能。

发明内容

本申请实施例提供了一种神经网络的训练方法及相关装置,用于提高分类网络的性能。

第一方面,本申请实施例提供了一种神经网络的训练方法。本申请中神经网络的训练方法,可以适用于神经网络模型的输入端存在OOD数据的分类网络场景,用于去除OOD数据所带来的无效特征贡献的干扰,提高神经网络模型的性能。例如,上述入侵事件识别的场景中,光缆所采集到的信号包括了与入侵事件信号较为相似的非入侵事件信号(高速行驶或管道正常),又例如,在宠物猫的品种分类的图像识别场景中,输入端存在老虎或豹子等与猫较为相似的图片,还可以是其他的存在OOD数据干扰的分类网络场景,具体此处不做限定。这些OOD数据会对神经网络模型的输出造成极大的干扰。因此,本申请中,训练数据应当包括OOD数据和ID数据。每个训练数据有各自的标签,用于指示该训练数据属于OOD数据或ID数据。

将训练数据(包括OOD数据和ID数据)输入第一特征提取网络,得到每个训练数据对应的特征,即得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征。而由于每个训练数据的标签是已知的(属于OOD数据或ID数据),因此,该训练数据输入第一特征提取网络后所得到的特征的标签,与该训练数据的标签相同。示例性的,若某个训练数据的标签指示其属于OOD数据,则该训练数据输入第一特征提取网络后所得到的特征的标签,也同样指示其属于OOD数据。

在得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征后,其中的第一OOD特征则不需要作为第一分类网络的输入,而仅将第一ID特征输入到第一分类网络中,得到第一分类结果。

根据第一OOD特征、第一ID特征和所述第一分类结果对第一特征提取网络进行训练,得到第二特征提取网络,以及,根据第一分类结果对第一分类网络进行训练,得到第二分类网络。

具体的,可以根据第一OOD特征和第一ID特征对第一特征提取网络进行训练,而第一分类网络所输出的第一分类结果则可以用于训练第一特征提取网络和第一分类网络。

由于在第一特征提取网络生成第一OOD特征和第一ID特征之后,便已经可以计算第一OOD特征和第一ID特征之间的损失函数。对此,“根据第一OOD特征和第一ID特征对第一特征提取网络进行训练”的步骤,可以在第一特征提取网络生成第一OOD特征和第一ID特征之后执行,也可以在第一分类网络生成第一分类结果之后执行,具体本申请对此不做限定。

另一方面,由于训练数据的标签也已经指示了每个训练数据的真实分类结果,因此,可以根据训练数据的真实分类结果与第一分类结果之间的损失函数,来对第一特征提取网络和第一分类网络进行端到端的训练,其训练目的为使得该损失函数的值最小化。

进一步的,为了提高神经网络模型的可靠性,在神经网络模型的训练过程中,一般需要经过多轮次的训练迭代来更新模型参数。因此,本申请的神经网络的训练方法,应用于每一个轮次的训练过程中。每一轮训练都需要从总的训练数据集中抽取一定比例的OOD数据和ID数据,作为神经网络模型本轮次的训练数据,换句话说,本申请的训练数据,可以是总的训练数据集的子集。若完成本轮次的训练后,训练流程未能收敛,则继续从总的训练数据集中抽取另一部分数据作为下一轮次的训练数据,继续训练模型,直至训练流程成功收敛。每一个轮次的训练过程,都执行本申请中的神经网络的训练方法,直至满足预设条件(例如损失函数的值满足预设条件),从而得到第二特征提取网络和第二分类网络。其中,第二特征提取网络为执行过训练操作的第一特征提取网络,第二分类网络为执行过训练操作的第一分类网络。

本申请中,只将训练数据所生成的ID特征输入到第一分类网络,即第一分类网络只关注对于ID特征的处理。因此,即便增加了新的OOD数据类别,也不会增加第一分类网络的输出类别,从而提高了第一分类网络训练后的性能。

另一方面,第一分类网络的输入只关注ID数据,即只有ID数据作为神经网络模型的输入时,才会体现在第一分类网络的分类结果中,而OOD数据是不会体现在第一分类网络的分类结果中的,从而去除了OOD数据所带来的无效特征贡献的干扰,提高神经网络模型的准确性和稳定性。

基于第一方面,一种可选的实施方式中,在训练完毕得到第二特征提取网络和第二分类网络之后,可以再根据训练数据中的OOD数据来构建OOD锚点库,OOD锚点库包括至少一个OOD特征锚点。OOD锚点库用于确定第二特征属于OOD特征或ID特征,其中,第二特征为第二特征提取网络对待分类数据进行特征提取所得到的特征。具体的,OOD锚点库可以通过训练数据中的OOD数据和第二特征提取网络来构建:先获取用于构建OOD锚点库的训练数据中的OOD数据,然后将OOD数据输入到第二特征提取网络中,得到OOD数据对应的第二OOD特征,将该第二OOD特征作为OOD特征锚点,保存到OOD锚点库中。

需要说明的是,本申请并不限定OOD锚点库中的OOD特征锚点的数量。该OOD锚点库包括至少一个OOD特征锚点,但OOD锚点库中的OOD特征锚点的数量越多,则过滤器区分OOD特征和ID特征的准确性和稳定性则越高。一般来说,在模型训练的过程中,OOD数据的数量会远低于ID数据的数量,因此,过滤器通过结合OOD锚点库的方式来区分OOD特征和ID特征,能够在OOD数据的样本数量较少的场景下,仍然能够具备良好的准确性和稳定性。

基于第一方面,一种可选的实施方式中,在训练完毕得到第二特征提取网络和第二分类网络之后,便可以进行神经网络模型的预测流程。将待分类数据输入到第二特征提取网络,得到该待分类数据对应的第二特征。若第二特征属于ID特征,则将第二特征输入到第二分类网络,从而得到该第二特征对应的分类结果。若第二特征属于OOD特征,则第二特征不会输入到第二分类网络,则第二分类网络也不会输出该第二特征的分类结果。

基于第一方面,一种可选的实施方式中,训练数据输入到第一特征提取网络,生成训练数据的特征后,是根据该训练数据的标签的指示来确定该特征属于OOD特征或ID特征的。然而在,模型训练完毕之后(即得到第二特征提取网络和第二分类网络),该神经网络模型的预测过程中,所输入的待分类数据,是没有标签来指示该待分类数据属于OOD数据或ID数据的,待分类数据输入到第二特征提取网络后,生成待分类数据对应的第二特征,该第二特征同样也没有标签来指示该特征属于OOD特征或ID特征。因此,需要对第二特征进行过滤,确认第二特征属于OOD特征或ID特征。

若第二特征属于ID特征,则将该第二特征输入到第二分类网络中,由第二分类网络对第二特征进行处理,得到该第二特征对应的第二分类结果;若第二特征属于OOD特征,则该第二特征不会作为第二分类网络的输入。

基于第一方面,一种可选的实施方式中,可以根据OOD锚点库来确定第二特征属于OOD特征或ID特征。具体的,待分类数据输入第二特征提取网络,生成该待分类数据对应的第二特征。第二特征和计算OOD锚点库中的OOD特征锚点,都可以用二维的向量来进行表示。根据第二特征的向量表示和OOD特征锚点的向量表示,计算OOD锚点库中的OOD特征锚点与第二特征之间的距离,该距离越小,则说明第二特征与OOD特征锚点之间的相似度越高、关联度越高,该第二特征有更高的概率属于OOD特征;反之,该距离越大,则说明第二特征与OOD特征锚点之间的相似度越低、关联度越低,该第二特征有更高的概率属于ID特征。因此,可以先配置一个预设阈值,当第二特征与OOD锚点库中的OOD特征锚点之间的距离小于预设阈值时,确定第二特征为OOD特征;当第二特征与OOD锚点库中的OOD特征锚点之间的距离大于或等于预设阈值时,确定第二特征为ID特征。

在构建了OOD锚点库之后,即便输入了新增的OOD数据类型,该OOD锚点库也同样适用于识别过滤新增的OOD数据,不需要针对该新增的OOD数据类型来重新构建或优化OOD锚点库中的OOD特征锚点,可以继续基于该OOD锚点库来识别新增的OOD数据,提高了神经网络模型的效率。

基于第一方面,一种可选的实施方式中,以增加第一OOD特征和第一ID特征之间的距离为方向,来构建损失函数。其中,第一OOD特征和第一ID特征之间的距离越大,则该损失函数的值越大,说明第一OOD特征和第一ID特征之间的差异越大,越有利于后续进行区分第一OOD特征和第一ID特征。因此,可以根据第一OOD特征和第一ID特征之间的差异的损失函数,以该损失函数的值越来越大作为优化方向,来对第一特征提取网络进行训练,从而使得第一特征提取网络所提取得到的第一OOD特征和第一ID特征之间的差异越来越大。

第二方面,本申请实施例提供了一种神经网络的训练装置,该装置包括:

处理单元,用于获取训练数据,训练数据包括分布外OOD数据和分布内ID数据;用于将训练数据输入第一特征提取网络,得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征;还用于将第一ID特征输入第一分类网络,得到第一分类结果;

训练单元,用于根据第一OOD特征、第一ID特征和第一分类结果对第一特征提取网络进行训练,得到第二特征提取网络;根据第一分类结果对第一分类网络进行训练,得到第二分类网络。

基于第二方面,一种可选的实施方式中,处理单元,还用于:

将训练数据中的OOD数据输入第二特征网络,得到OOD数据对应的第二OOD特征;将第二OOD特征作为OOD特征锚点,保存到OOD锚点库,其中,OOD锚点库用于确定第二特征属于OOD特征或ID特征,第二特征为第二特征提取网络对待分类数据进行特征提取所得到的特征。

基于第二方面,一种可选的实施方式中,处理单元,还用于:

将待分类数据输入第二特征提取网络,得到第二特征;若第二特征属于ID特征,则将第二特征输入第二分类网络,得到第二分类结果。

基于第二方面,一种可选的实施方式中,处理单元,还用于:对第二特征进行过滤,确认第二特征属于OOD特征或ID特征。

基于第二方面,一种可选的实施方式中,处理单元,具体用于:

获取OOD锚点库,OOD锚点库包括至少一个OOD特征锚点;计算OOD锚点库中的OOD特征锚点与第二特征之间的距离;若距离的值大于预设阈值,则确定第二特征为ID特征。

基于第二方面,一种可选的实施方式中,处理单元,具体用于:

计算第一OOD特征与第一ID特征之间的距离;以增加第一OOD特征与第一ID特征之间的距离作为损失函数,对第一特征提取网络进行训练;根据第一分类结果对第一特征提取网络进行训练。

第三方面,本发明实施例提供了一种计算机设备,包括处理器,用于执行上述任一方面的神经网络的训练方法。

基于第三方面,一种可选的实施方式中,还包括存储器,用于存储代码,与所述处理器耦合;所述处理器具体用于执行所述存储器中的代码,来实现上述任一方面的神经网络的训练方法。

第四方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,当其在计算机上运行时,使得计算机执行上述任一方面所述的神经网络的训练方法。

第五方面,本申请实施例提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机程序或指令,当其在计算机上运行时,使得计算机执行上述任一方面所述的神经网络的训练方法。

第六方面,本申请实施例提供了一种芯片系统,该芯片系统包括处理器,用于实现上述各个方面中所涉及的功能,例如,发送或处理上述方法中所涉及的数据和/或信息。在一种可能的设计中,所述芯片系统还包括存储器,所述存储器,用于保存服务器或通信设备必要的程序指令和数据。该芯片系统,可以由芯片构成,也可以包括芯片和其他分立器件。

从以上技术方案可以看出,本申请实施例具有以下优点:

本申请公开了一种神经网络的训练方法及相关装置。获取训练数据,训练数据包括分布外OOD数据和分布内ID数据;将训练数据输入第一特征提取网络,得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征;将第一ID特征输入第一分类网络,得到第一分类结果;根据第一OOD特征、第一ID特征对和第一分类结果第一特征提取网络进行训练,得到第二特征提取网络,以及,根据第一分类结果对第一分类网络进行训练,得到第二分类网络。本申请中,只将训练数据所生成的ID特征输入到第一分类网络,即第一分类网络只关注对于ID特征的处理。因此,即便增加了新的OOD数据类别,也不会增加第一分类网络的输出类别,从而提高了第一分类网络训练后的性能。

附图说明

图1为人工智能主体框架的一种结构示意图;

图2为入侵识别模型的架构示意图;

图3为入侵识别模型处理新增的OOD事件的示意图;

图4为本申请实施例中神经网络模型的架构示意图;

图5为本申请实施例中神经网络的训练方法的流程示意图;

图6为本申请实施例中神经网络模型的预测流程示意图;

图7为本申请实施例中OOD锚点库的构建流程示意图;

图8为本申请实施例提供的一种神经网络的训练装置的结构示意图;

图9为本申请实施例提供的计算机设备的一种结构示意图。

具体实施方式

本申请实施例提供了一种神经网络的训练方法及相关装置,用于提高分类网络的性能。

下面结合本发明实施例中的附图对本发明实施例进行描述。本发明的实施方式部分使用的术语仅用于对本发明的具体实施例进行解释,而非旨在限定本发明。本领域普通技术人员可知,随着技术的发展和新场景的出现,本申请实施例提供的技术方案对于类似的技术问题,同样适用。

本申请中,“至少一个”是指一个或者多个,“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B的情况,其中A,B可以是单数或者复数。字符“/”一般表示前后关联对象是一种“或”的关系。“以下至少一项(个)”或其类似表达,是指的这些项中的任意组合,包括单项(个)或复数项(个)的任意组合。例如,a,b,或c中的至少一项(个),可以表示:a,b,c,a-b,a-c,b-c,或a-b-c,其中a,b,c可以是单个,也可以是多个。

本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例例如能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。

首先对人工智能系统总体工作流程进行描述,请参见图1,图1示出的为人工智能主体框架的一种结构示意图,下面从“智能信息链”(水平轴)和“IT价值链”(垂直轴)两个维度对上述人工智能主题框架进行阐述。其中,“智能信息链”反映从数据的获取到处理的一列过程。举例来说,可以是智能信息感知、智能信息表示与形成、智能推理、智能决策、智能执行与输出的一般过程。在这个过程中,数据经历了“数据—信息—知识—智慧”的凝练过程。“IT价值链”从人智能的底层基础设施、信息(提供和处理技术实现)到系统的产业生态过程,反映人工智能为信息技术产业带来的价值。

(1)基础设施。

基础设施为人工智能系统提供计算能力支持,实现与外部世界的沟通,并通过基础平台实现支撑。通过传感器与外部沟通;计算能力由智能芯片来提供,例如,中央处理器(central processing units,CPU)、嵌入式神经网络处理器(neural-network processingunit,NPU)、图形处理器(graphics processing unit,GPU)、专用集成电路(applicationspecific integrated circuit,ASIC)、现场可编程逻辑门阵列(field programmablegate array,FPGA)等硬件加速芯片;基础平台包括分布式计算框架及网络等相关的平台保障和支持,可以包括云存储和计算、互联互通网络等。举例来说,传感器和外部沟通获取数据,这些数据提供给基础平台提供的分布式计算系统中的智能芯片进行计算。

(2)数据。

基础设施的上一层的数据用于表示人工智能领域的数据来源。数据涉及到图形、图像、语音、文本,还涉及到传统设备的物联网数据,包括已有系统的业务数据以及力、位移、液位、温度、湿度等感知数据。

(3)数据处理。

数据处理通常包括数据训练,机器学习,深度学习,搜索,推理,决策等方式。

其中,机器学习和深度学习可以对数据进行符号化和形式化的智能信息建模、抽取、预处理、训练等。

推理是指在计算机或智能系统中,模拟人类的智能推理方式,依据推理控制策略,利用形式化的信息进行机器思维和求解问题的过程,典型的功能是搜索与匹配。

决策是指智能信息经过推理后进行决策的过程,通常提供分类、排序、预测等功能。

(4)通用能力。

对数据经过上面提到的数据处理后,进一步基于数据处理的结果可以形成一些通用的能力,比如可以是算法或者一个通用系统,例如,翻译,文本的分析,计算机视觉的处理,语音识别,图像的识别等等。

(5)智能产品及行业应用。

智能产品及行业应用指人工智能系统在各领域的产品和应用,是对人工智能整体解决方案的封装,将智能信息决策产品化、实现落地应用,其应用领域主要包括:智能终端、智能交通、智能医疗、自动驾驶、智慧城市等。

本申请中所提供的基于预训练大模型的图像处理方法,所应用的场景,包括但不限于上述示例。具体的,可以应用于数据训练、机器学习、深度学习等数据处理方法,对训练数据进行符号化和形式化的智能信息建模、抽取、预处理、训练等,最终得到训练好的神经网络模型(如本申请实施例中的目标神经网络模型);并且目标神经网络模型可以用于进行模型推理,具体可以将输入数据输入到目标神经网络模型中,得到输出数据。

在机器学习领域,用于训练模型的数据通常被称为分布内(in-distribution,ID)数据,而分布外(out-of-distribution,OOD)数据是指和训练的ID数据分布不一致的数据。在神经网络模型的实际应用中,输入数据中有时存在OOD数据,这会导致模型预测不准确,进而限制神经网络模块的应用。因此,对机器学习模型的输入数据进行OOD数据检测是提高模型预测准确率的一种重要手段。

接下来,结合示例,对涉及OOD数据以及ID数据的一种模型训练场景介绍。

许多工业场景都需要进行外界非法的入侵检测,例如电力、交通、安防、石化或通信等场景。而入侵事件的类型是多种多样的,例如在石化领域,外界非法的入侵事件包括挖掘机施工破坏或人工铲挖掘偷油等。

基于光缆传感的入侵检测是检测入侵事件的一个有效手段。入侵信号通过地质传播被光缆感知,由于光缆的长度较长(一般在50千米左右),并且,光缆部署于复杂多样的地质环境(农田、泥塘、裸漏管道、河流或山区等)和背景震动(公路、铁路、快速路、地铁、轻轨、高架桥或隧道等)中,因此,光缆很容易获取到与入侵事件很相似的非入侵事件(属于OOD事件)的信号。然而,对于工业场景来说,非入侵事件(例如高速行驶或管道正常)并不会影响到该场景下的正常工作,而非入侵事件信号输入到入侵识别模型后,由于该非入侵事件信号与入侵事件信号很相似,入侵识别模型很容易将该非入侵事件信号误报为入侵事件。因此,在工业场景中,会影响正常工作的入侵事件信号,才是入侵识别模型应当关注的数据,则入侵事件信号属于ID数据;而非入侵事件信号会对入侵识别模型的输出造成干扰、引发误报,则该非入侵事件信号属于OOD数据,这些OOD数据会降低入侵识别模型的效率和准确率。例如,农耕机械旋耕机事件(非入侵事件)跟夯机敲击事件(入侵事件)很相似,光缆获取到农耕机械旋耕机事件的信号后,将该信号输入到入侵识别模型,则入侵识别模型很容易将该信号与夯机敲击(入侵事件)的信号混淆,导致入侵识别模型把非入侵事件误报为入侵事件,即发生了农耕机械旋耕机事件后,却误报为夯机敲击事件。

因此,在光缆感知到外界的入侵信号后,可以采用入侵识别模型来识别入侵事件或非入侵事件的类别。请参阅图2,图2为入侵识别模型的架构示意图。如图2所示,由于入侵事件信号和非入侵事件信号较为相似,因此,入侵事件信号和非入侵事件信号都会作为入侵识别模型的训练输入,入侵识别模型的输出则为入侵事件(挖掘机行驶、挖掘机敲击或挖掘机挖掘)或非入侵事件(高速行驶或管道正常)的类别。

由此可见,上述入侵事件信号属于模型训练过程中的ID数据,非入侵事件信号则属于模型训练过程中的OOD数据。

请参阅图3,图3为入侵识别模型处理新增的OOD事件的示意图。如图3所示,当获取到新的非入侵事件后,一般会采集该非入侵事件信号(属于OOD数据),从而将现有的入侵识别模型重新训练,得到入侵识别模型的输出则包括了该非入侵事件(OOD事件)的新增入侵类别。

非入侵事件的类别随着时间发展可能会不断增多,因此,入侵识别模型所输出的入侵类别也会不断增多,严重降低入侵识别模型的性能。

有鉴于此,本申请提供了一种神经网络的训练方法,用于提高分类网络的性能。请参阅图4,图4为本申请实施例中神经网络模型的架构示意图,该神经网络模型基于本申请的神经网络的训练方法进行模型训练。如图4所示,训练数据(包括OOD数据和ID数据)输入到特征提取网络中,由特征提取网络处理训练数据,得到OOD特征和ID特征。其中,ID特征将作为分类网络的输入,而ODD特征将不会输入到分类网络中。例如,在上述入侵事件的场景中,入侵事件信号的特征(ID特征)将会被输入到分类网络,而非入侵事件信号的特征将不会被输入到分类网络。分类网络将输入的ID特征进行处理,得到ID特征的分类结果。例如,在上述入侵事件的场景中,分类网络将入侵事件信号的特征进行处理,得到该入侵事件的类别。在该神经网络模型的训练过程中,分类网络所输出的分类结果,用于对特征提取网络和分类网络进行训练,并且,OOD特征和ID特征之间的损失函数,也可以用于训练特征提取网络。

为了构建图4所示的神经网络模型,本申请提供了一种神经网络的训练方法。请参阅图5,图5为本申请实施例中神经网络的训练方法的流程示意图,如图5所示,本申请实施例中神经网络的训练方法包括:

101.获取训练数据,训练数据包括OOD数据和ID数据;

传统的神经网络模型一般是在一个封闭环境中进行训练,即整个训练过程都是在基于测试数据和训练数据都来自同样的分布(即分布内)的假设下完成的。然而,在实际使用中,神经网络模型的输入端总是会遇到一些不属于上述封闭环境中的类别的数据(即OOD数据)。

应理解,本申请中神经网络的训练方法,可以适用于神经网络模型的输入端存在OOD数据的分类网络场景,用于去除OOD数据所带来的无效特征贡献的干扰,提高神经网络模型的性能。例如,上述入侵事件识别的场景中,光缆所采集到的信号包括了与入侵事件信号较为相似的非入侵事件信号(高速行驶或管道正常),又例如,在宠物猫的品种分类的图像识别场景中,输入端存在老虎或豹子等与猫较为相似的图片,这些老虎或豹子的图片即为该场景下的OOD数据,猫的图片即为该场景下的ID数据;或者,还可以是其他的存在OOD数据干扰的分类网络场景,具体此处不做限定。这些OOD数据会对神经网络模型的输出造成极大的干扰。

对此,本申请的神经网络的训练方法中,引入了OOD数据对于神经网络模型的影响,因此,训练数据应当包括OOD数据和ID数据。每个训练数据有各自的标签,用于指示该训练数据属于OOD数据或ID数据。

102.将训练数据输入第一特征提取网络,得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征;

将训练数据(包括OOD数据和ID数据)输入第一特征提取网络,得到每个训练数据对应的特征,即得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征。而由于每个训练数据的标签是已知的(属于OOD数据或ID数据),因此,该训练数据输入第一特征提取网络后所得到的特征的标签,与该训练数据的标签相同。示例性的,若某个训练数据的标签指示其属于OOD数据,则该训练数据输入第一特征提取网络后所得到的特征的标签,也同样指示其属于OOD数据。

103.将第一ID特征输入第一分类网络,得到第一分类结果。

在得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征后,其中的第一OOD特征则不需要作为第一分类网络的输入,而仅将第一ID特征输入到第一分类网络中,得到第一分类结果。

104.对第一特征提取网络和第一分类网络进行训练。

根据第一OOD特征、第一ID特征和所述第一分类结果对第一特征提取网络进行训练,得到第二特征提取网络,以及,根据第一分类结果对第一分类网络进行训练,得到第二分类网络。

具体的,可以根据第一OOD特征和第一ID特征对第一特征提取网络进行训练,而第一分类网络所输出的第一分类结果则可以用于训练第一特征提取网络和第一分类网络。

首先,对“根据第一OOD特征和第一ID特征对第一特征提取网络进行训练”的步骤进行介绍:本申请中,以增加第一OOD特征和第一ID特征之间的距离为方向,来构建损失函数。其中,第一OOD特征和第一ID特征之间的距离越大,则该损失函数的值越大,说明第一OOD特征和第一ID特征之间的差异越大,越有利于后续进行区分第一OOD特征和第一ID特征。因此,可以根据第一OOD特征和第一ID特征之间的差异的损失函数,以该损失函数的值越来越大作为优化方向,来对第一特征提取网络进行训练,从而使得第一特征提取网络所提取得到的第一OOD特征和第一ID特征之间的差异越来越大。

由于在第一特征提取网络生成第一OOD特征和第一ID特征之后,便已经可以计算第一OOD特征和第一ID特征之间的损失函数。对此,“根据第一OOD特征和第一ID特征对第一特征提取网络进行训练”的步骤,可以在第一特征提取网络生成第一OOD特征和第一ID特征(即步骤102)之后执行,也可以在第一分类网络生成第一分类结果(即步骤103)之后执行,具体此处不做限定。

接下来,对“根据第一分类结果对第一特征提取网络和第一分类网络进行训练”的步骤进行介绍。由于训练数据的标签也已经指示了每个训练数据的真实分类结果,因此,可以根据训练数据的真实分类结果与第一分类结果之间的损失函数,来对第一特征提取网络和第一分类网络进行端到端的训练,其训练目的为使得该损失函数的值最小化。

进一步的,为了提高神经网络模型的可靠性,在神经网络模型的训练过程中,一般需要经过多轮次的训练迭代来更新模型参数。因此,本申请的神经网络的训练方法,应用于每一个轮次的训练过程中。每一轮训练都需要从总的训练数据集中抽取一定比例的OOD数据和ID数据,作为神经网络模型本轮次的训练数据,换句话说,本申请的训练数据,可以是总的训练数据集的子集。若完成本轮次的训练后,训练流程未能收敛,则继续从总的训练数据集中抽取另一部分数据作为下一轮次的训练数据,继续训练模型,直至训练流程成功收敛。每一个轮次的训练过程,都执行本申请中的神经网络的训练方法(即步骤101至步骤104),直至满足预设条件(例如损失函数的值满足预设条件),从而得到第二特征提取网络和第二分类网络。其中,第二特征提取网络为执行过训练操作的第一特征提取网络,第二分类网络为执行过训练操作的第一分类网络。

本申请中,只将训练数据所生成的ID特征输入到第一分类网络,即第一分类网络只关注对于ID特征的处理。因此,即便增加了新的OOD数据类别,也不会增加第一分类网络的输出类别,从而提高了第一分类网络训练后的性能。

另一方面,第一分类网络的输入只关注ID数据,即只有ID数据作为神经网络模型的输入时,才会体现在第一分类网络的分类结果中,而OOD数据是不会体现在第一分类网络的分类结果中的,从而去除了OOD数据所带来的无效特征贡献的干扰,提高神经网络模型的准确性和稳定性。

在训练完毕得到第二特征提取网络和第二分类网络之后,便可以进行神经网络模型的预测流程。将待分类数据输入到第二特征提取网络,得到该待分类数据对应的第二特征。若第二特征属于ID特征,则将第二特征输入到第二分类网络,从而得到该第二特征对应的分类结果。若第二特征属于OOD特征,则第二特征不会输入到第二分类网络,则第二分类网络也不会输出该第二特征的分类结果。

本申请中的神经网络的训练方法中,训练数据输入到第一特征提取网络,生成训练数据的特征后,是根据该训练数据的标签的指示来确定该特征属于OOD特征或ID特征的。然而在,模型训练完毕之后(即得到第二特征提取网络和第二分类网络),该神经网络模型的预测过程中,所输入的待分类数据,是没有标签来指示该待分类数据属于OOD数据或ID数据的。待分类数据输入到第二特征提取网络后,生成待分类数据对应的第二特征,该第二特征同样也没有标签来指示该特征属于OOD特征或ID特征。因此,需要对第二特征进行过滤,确认第二特征属于OOD特征或ID特征。

请参阅图6,图6为本申请实施例中神经网络模型的预测流程示意图。如图6所示,待分类数据输入第二特征提取网络,得到该待分类数据对应的第二特征。接下来,确定第二特征属于OOD特征或ID特征。若第二特征属于ID特征,则将该第二特征输入到第二分类网络中,由第二分类网络对第二特征进行处理,得到该第二特征对应的第二分类结果;若第二特征属于OOD特征,则该第二特征不会作为第二分类网络的输入。

进一步的,确定第二特征属于OOD特征或ID特征的方式有多种,下面分别进行介绍。

在一些可能的实施方式中,可以根据OOD锚点库来确定第二特征属于OOD特征或ID特征。具体的,在训练完毕得到第二特征提取网络和第二分类网络之后,可以再根据训练数据中的OOD数据来构建OOD锚点库,OOD锚点库包括至少一个OOD特征锚点。请参阅图7,图7为本申请实施例中OOD锚点库的构建流程示意图。由于训练数据中包括了OOD数据和ID数据,因此,OOD锚点库可以通过训练数据中的OOD数据和第二特征提取网络来构建。如图7所示,先获取用于构建OOD锚点库的训练数据中的OOD数据,然后将OOD数据输入到第二特征提取网络中,得到OOD数据对应的第二OOD特征,将该第二OOD特征作为OOD特征锚点,保存到OOD锚点库中。其中,第二特征提取网络为第一特征提取网络训练完毕后所得到的。

需要说明的是,本申请并不限定OOD锚点库中的OOD特征锚点的数量。该OOD锚点库包括至少一个OOD特征锚点,但OOD锚点库中的OOD特征锚点的数量越多,则区分OOD特征和ID特征的准确性和稳定性则越高。一般来说,在模型训练的过程中,OOD数据的数量会远低于ID数据的数量,因此,通过结合OOD锚点库的方式来区分OOD特征和ID特征,能够在OOD数据的样本数量较少的场景下,仍然能够具备良好的准确性和稳定性。

在完成OOD锚点库的构建之后,便可以根据OOD锚点库来确定第二特征属于OOD特征或ID特征。具体的,待分类数据输入第二特征提取网络,生成该待分类数据对应的第二特征。第二特征和计算OOD锚点库中的OOD特征锚点,都可以用二维的向量来进行表示。根据第二特征的向量表示和OOD特征锚点的向量表示,计算OOD锚点库中的OOD特征锚点与第二特征之间的距离,该距离越小,则说明第二特征与OOD特征锚点之间的相似度越高、关联度越高,该第二特征有更高的概率属于OOD特征,即该第二特征属于OOD特征的置信度越高;反之,该距离越大,则说明第二特征与OOD特征锚点之间的相似度越低、关联度越低,该第二特征有更高的概率属于ID特征,即该第二特征属于OOD特征的置信度越低。因此,可以先配置一个预设阈值,当第二特征与OOD锚点库中的OOD特征锚点之间的距离小于预设阈值时,确定第二特征为OOD特征;当第二特征与OOD锚点库中的OOD特征锚点之间的距离大于或等于预设阈值时,确定第二特征为ID特征。

在构建了OOD锚点库之后,即便输入了新增的OOD数据类型,该OOD锚点库也同样适用于识别过滤新增的OOD数据,不需要针对该新增的OOD数据类型来重新构建或优化OOD锚点库中的OOD特征锚点,可以继续基于该OOD锚点库来识别新增的OOD数据,提高了神经网络模型的效率。

示例性的,可以将OOD锚点库中的OOD特征锚点通过k均值聚类算法(k-meansclustering algorithm),将OOD锚点库中的所有OOD特征锚点进行聚类,得到若干个簇,每个簇由一个OOD特征锚点作为这个簇的质心。计算第二特征与每个簇的质心之间的二维向量上的距离,当该第二特征与某个簇的质心的距离小于预设阈值时,则可以确定该第二特征属于这个簇,即该第二特征为OOD特征;当第二特征与所有簇的质心的距离都大于或等于预设阈值时,则表明该第二特征不属于任何簇,即该第二特征为ID特征。例如,第二特征与每个簇的质心之间的距离的预设阈值为0.5的欧式距离,那么假设当第二特征与某个簇的质心的距离为0.3时,则可以确定该第二特征属于这个簇,即该第二特征为OOD特征;假设当第二特征与某个簇的质心的距离为0.7时,则可以确定该第二特征不属于这个簇;如果第二特征与所有簇的质心的距离均大于0.5,则该第二特征为ID特征。

在一些可能的实施方式中,可以通过二分类神经网络来确定第二特征属于OOD特征或ID特征。在该二分类神经网络的训练过程中,将训练数据中的OOD数据和ID数据均作为训练的输入。训练数据由第二特征提取网络进行特征提取,从而作为二分类神经网络的输入,该二分类神经网络的输出端则指示训练数据属于OOD类型或ID类型,由于每个训练数据的标签是已知的(OOD数据或ID数据),因此,根据该二分类神经网络的输出和训练数据的标签之间的损失函数来对该二分类神经网络进行训练。待二分类神经网络训练完毕之后,待分类数据输入到第二特征提取网络,得到第二特征。第二特征输入到二分类神经网络,二分类神经网络的输出端指示出该第二特征属于OOD特征或ID特征。

应理解,在实际应用中,除了上述OOD锚点库和二分类神经网络的示例之外,还有其他方式可以用于实现OOD特征和ID特征的区分,本申请对此不做限定。

在图5所对应的实施例的基础上,为了更好的实施本申请实施例的上述方案,下面还提供用于实施上述方案的相关设备。具体的,请参阅图8,图8为本申请实施例提供的一种神经网络的训练装置的结构示意图。如图8所示,神经网络的训练装置包括:

处理单元201,用于获取训练数据,训练数据包括分布外OOD数据和分布内ID数据;用于将训练数据输入第一特征提取网络,得到OOD数据对应的第一OOD特征和ID数据对应的第一ID特征;还用于将第一ID特征输入第一分类网络,得到第一分类结果;

训练单元202,用于根据第一OOD特征、第一ID特征和第一分类结果对第一特征提取网络进行训练,得到第二特征提取网络;根据第一分类结果对第一分类网络进行训练,得到第二分类网络。

在一种可能的设计中,处理单元201,还用于:

将训练数据中的OOD数据输入第二特征网络,得到OOD数据对应的第二OOD特征;将第二OOD特征作为OOD特征锚点,保存到OOD锚点库,其中,OOD锚点库用于确定第二特征属于OOD特征或ID特征,第二特征为第二特征提取网络对待分类数据进行特征提取所得到的特征。

在一种可能的设计中,处理单元201,还用于:

将待分类数据输入第二特征提取网络,得到第二特征;若第二特征属于ID特征,则将第二特征输入第二分类网络,得到第二分类结果。

在一种可能的设计中,处理单元201,还用于:对第二特征进行过滤,确认第二特征属于OOD特征或ID特征。

在一种可能的设计中,处理单元201,具体用于:

获取OOD锚点库,OOD锚点库包括至少一个OOD特征锚点;计算OOD锚点库中的OOD特征锚点与第二特征之间的距离;若距离的值大于预设阈值,则确定第二特征为ID特征。

在一种可能的设计中,处理单元201,具体用于:

计算第一OOD特征与第一ID特征之间的距离;以增加第一OOD特征与第一ID特征之间的距离作为损失函数,对第一特征提取网络进行训练;根据第一分类结果对第一特征提取网络进行训练。

神经网络的训练装置中各模块/单元之间的信息交互、执行过程等内容,与本申请中图5对应的方法实施例基于同一构思,具体内容可参见本申请前述所示的方法实施例中的叙述,此处不再赘述。

本申请实施例还提供了一种计算机设备,请参阅图9,图9为本申请实施例提供的计算机设备的一种结构示意图,计算机设备300上可以部署有图8对应实施例中所描述的神经网络的训练装置,用于实现图5对应实施例中第一网络设备设备的功能,具体的,计算机设备300由一个或多个服务器实现,计算机设备300可因配置或性能不同而产生比较大的差异,可以包括一个或一个以上中央处理器(central processing units,CPU)322(例如,一个或一个以上处理器)和存储器332,一个或一个以上存储应用程序342或数据344的存储介质330(例如一个或一个以上海量存储设备)。其中,存储器332和存储介质330可以是短暂存储或持久存储。存储在存储介质330的程序可以包括一个或一个以上模块(图示没标出),每个模块可以包括对计算机设备中的一系列指令操作。更进一步地,中央处理器322可以设置为与存储介质330通信,在计算机设备300上执行存储介质330中的一系列指令操作。

计算机设备300还可以包括一个或一个以上电源326,一个或一个以上有线或无线网络接口350,一个或一个以上输入输出接口358,和/或,一个或一个以上操作系统341,例如Windows Server

需要说明的是,计算机设备中各模块/单元之间的信息交互、执行过程等内容,与本申请中图5对应的方法实施例基于同一构思,具体内容可参见本申请前述所示的方法实施例中的叙述,此处不再赘述。

本申请实施例中还提供一种包括计算机程序产品,当其在计算机上运行时,使得计算机执行如前述图5所示实施例描述的方法。

本申请实施例中还提供一种计算机可读存储介质,该计算机可读存储介质中存储有用于进行信号处理的程序,当其在计算机上运行时,使得计算机执行如前述图5所示实施例描述的方法。

本申请实施例提供的神经网络的训练装置具体可以为芯片,芯片包括:处理单元和通信单元,所述处理单元例如可以是处理器,所述通信单元例如可以是输入/输出接口、管脚或电路等。该处理单元可执行存储单元存储的计算机执行指令,以使芯片执行上述图5所示实施例描述的方法。可选地,所述存储单元为所述芯片内的存储单元,如寄存器、缓存等,所述存储单元还可以是所述无线接入设备端内的位于所述芯片外部的存储单元,如只读存储器(read-only memory,ROM)或可存储静态信息和指令的其他类型的静态存储设备,随机存取存储器(random access memory,RAM)等。

所另外需说明的是,以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。另外,本申请提供的装置实施例附图中,模块之间的连接关系表示它们之间具有通信连接,具体可以实现为一条或多条通信总线或信号线。

通过以上的实施方式的描述,所属领域的技术人员可以清楚地了解到本申请可借助软件加必需的通用硬件的方式来实现,当然也可以通过专用硬件包括专用集成电路、专用CPU、专用存储器、专用元器件等来实现。一般情况下,凡由计算机程序完成的功能都可以很容易地用相应的硬件来实现,而且,用来实现同一功能的具体硬件结构也可以是多种多样的,例如模拟电路、数字电路或专用电路等。但是,对本申请而言更多情况下软件程序实现是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在可读取的存储介质中,如计算机的软盘、U盘、移动硬盘、ROM、RAM、磁碟或者光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,训练设备,或者网络设备等)执行本申请各个实施例所述的方法。

在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。

所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、训练设备或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、训练设备或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存储的任何可用介质或者是包含一个或多个可用介质集成的训练设备、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘(Solid State Disk,SSD))等。

相关技术
  • 一种神经网络模型训练方法及装置、文本标签确定方法及装置
  • 一种用于皮肤病理图像处理的神经网络训练方法及装置
  • 一种神经网络训练方法及装置、设备、介质
  • 一种神经网络训练方法、装置、计算机设备和存储介质
  • 一种模型训练方法、合成说话表情的方法和相关装置
  • 一种神经网络训练方法及相关装置
  • 一种基于物理信息神经网络的模型训练方法及相关装置
技术分类

06120116482548