掌桥专利:专业的专利平台
掌桥专利
首页

数据分类识别方法、装置、设备及可读存储介质

文献发布时间:2023-06-19 11:22:42


数据分类识别方法、装置、设备及可读存储介质

技术领域

本申请实施例涉及机器学习领域,特别涉及一种数据分类识别方法、装置、设备及可读存储介质。

背景技术

在基于医学影像的疾病诊断方面通常包括罕见病的诊断和常见病的诊断,也即,将医学影像输入至机器学习模型后,由机器学习模型对医学影像进行分析,从而判断医学影像所对应的身体异常情况。

相关技术中,在针对罕见病进行诊断时,将医学影像输入至罕见病的分类模型中,由分类模型对医学影像进行分析诊断,从而确定医学影像所表达的图像特征是否属于罕见病,以及属于哪一种罕见病。其中,分类模型在训练过程中,需要大量有标注的训练数据集,即标注有罕见病信息的图像数据集进行训练,从而确保模型准确率。

然而,罕见病本身属于出现几率较低的病症,收集罕见病的图像数据以及对罕见病信息进行标注的难度较大,导致分类模型的训练效率较低。

发明内容

本申请实施例提供了一种数据分类识别方法、装置、设备及可读存储介质,能够提高对针对罕见病进行识别分类的识别模型的训练效率。所述技术方案如下。

一方面,提供了一种数据分类识别方法,所述方法包括:

获取第一数据集和第二数据集,所述第一数据集中包括第一数据,所述第二数据集中包括标注有样本标签的第二数据,所述第二数据属于目标分类集;

通过所述第一数据以无监督训练模式,以及所述第二数据以监督训练模式训练得到分类教师模型;

获取分类学生模型,所述分类学生模型为模型参数待训练的模型;

通过所述第一数据以所述分类教师模型为基准模型,对所述分类学生模型的所述模型参数进行蒸馏训练,得到数据分类模型;

通过所述数据分类模型对目标数据进行分类预测,得到所述目标数据在所述目标分类集中所属的分类结果。

另一方面,提供了一种数据分类识别装置,所述装置包括:

获取模块,用于获取第一数据集和第二数据集,所述第一数据集中包括第一数据,所述第二数据集中包括标注有样本标签的第二数据,所述第二数据属于目标分类集;

训练模块,用于通过所述第一数据以无监督训练模式,以及所述第二数据以监督训练模式训练得到分类教师模型;

所述获取模块,还用于获取分类学生模型,所述分类学生模型为模型参数待训练的模型;

所述训练模块,还用于通过所述第一数据以所述分类教师模型为基准模型,对所述分类学生模型的所述模型参数进行蒸馏训练,得到数据分类模型;

预测模块,用于通过所述数据分类模型对目标数据进行分类预测,得到所述目标数据在所述目标分类集中所属的分类结果。

另一方面,提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如上述本申请实施例中任一所述数据分类识别方法。

另一方面,提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如上述本申请实施例中任一所述的数据分类识别方法。

另一方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中任一所述的数据分类识别方法。

本申请实施例提供的技术方案带来的有益效果至少包括:

在通过无标签的第一数据进行无监督训练以及有标签的第二数据进行监督训练后,得到分类教师模型,从而在分类教师模型的基础上,创建分类学生模型进行知识蒸馏训练,利用教师模型进行监督训练来达到蒸馏的目的,最终得到更高性能和精度的学生模型,训练主要依赖大量的第一数据,而对有标签的第二数据的数据量要求较小,避免了对样本数据进行大量标注的繁琐过程,提高了数据分类模型的训练效率以及准确率。

附图说明

为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。

图1是本申请一个示例性实施例提供的整体方案实施流程示意图;

图2是本申请一个示例性实施例提供的实施环境示意图;

图3是本申请一个示例性实施例提供的数据分类识别方法的流程图;

图4是本申请另一个示例性实施例提供的数据分类识别方法的流程图;

图5是本申请另一个示例性实施例提供的数据分类识别方法;

图6是本申请一个示例性实施例提供的罕见病分类识别模型的训练过程整体示意图;

图7是本申请一个示例性实施例提供的数据分类识别装置的结构框图;

图8是本申请另一个示例性实施例提供的数据分类识别装置的结构框图;

图9是本申请一个示例性实施例提供的服务器的结构框图。

具体实施方式

为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。

首先,针对本申请实施例中涉及的名词进行简单介绍。

人工智能(Artificial Intelligence,AI):是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。

人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。

机器学习(Machine Learning,ML):是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。

计算机视觉技术(Computer Vision,CV):是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别、跟踪和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、光学字符识别(Optical Character Recognition,OCR)、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、3D技术、虚拟现实、增强现实、同步定位与地图构建等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。

伪标签:是指通过经过训练的模型对未标注的数据进行预测后得到预测结果,并基于预测结果对数据进行标注的标签。也即,伪标签并非根据数据的实际情况人工标注的标签,而是由训练好的模型标注的存在一定容错率的标签。

相关技术中,针对罕见病的诊断,需要通过用于罕见病诊断的分类模型,而分类模型的训练则需要通过标注有罕见病信息的大量样本图像数据,通过分类模型对样本图像数据进行分类识别后,得到识别结果,通过标注的罕见病信息与识别结果之间的差异对分类模型进行训练。然而,由于罕见病本身的罕见性,导致样本图像数据的获取难度较大,需要大量的人力采集样本图像数据,以及对样本图像数据进行罕见病信息的标注,分类模型的训练效率较低。

本申请实施例中,提供了一种数据分类识别方法,在样本数量较少的情况下提高了数据分类模型的训练效率和准确率。

示意性的,图1是本申请一个示例性实施例提供的整体方案实施流程示意图,以罕见病的分类模型训练过程为例,如图1所示。

首先获取第一图像数据集110和第二图像数据集120,其中,第一图像数据集110中包括常见病的医学影像,且第一图像数据集110中的医学影像不包括标注信息;第二图像数据集120中包括少量罕见病的医学影像,且第二图像数据集120中的医学影像包括标注信息用于标注医学影像对应的罕见病信息。

通过第一图像数据集110对特征提取网络f

其次,对本申请实施例中涉及的实施环境进行说明,示意性的,请参考图2,该实施环境中涉及终端210、服务器220,终端210和服务器220之间通过通信网络230连接。

在一些实施例中,终端210包括第一终端211和第二终端212。

第一终端211用于向服务器220发送医学影像。示意性的,第一终端211为医生应用的终端,医生在通过医学影像对罕见病进行诊断的过程中,通过分类模型进行辅助诊断,从而提高诊断准确率;或者,第一终端211为用户应用的终端,如:患者本人,或者患者的亲属等,用户将医学影像发送至服务器,从而获取参考诊断结果;或者,第一终端211为医学影像扫描设备所连接的终端,医学影像扫描设备在扫描得到医学影像后传输至第一终端211,第一终端211在接收到医学影像后,将医学影像转发至服务器220进行辅助诊断。

服务器220通过上述图1所示的方式进行罕见病分类模型221的训练,得到罕见病分类模型221后,接收第一终端211上传的医学影像,并通过分类模型对医学影像进行分类识别,得到医学影像在罕见病分类集中的分类诊断结果。服务器220将分类诊断结果反馈至第一终端211或者将分类诊断结果发送至第二终端212。

其中,当第一终端211实现为与医学影像扫描设备连接的终端时,服务器220将分类诊断结果发送至第二终端212,第二终端212实现为医生应用的终端或者用户应用的终端。

上述终端可以是手机、平板电脑、台式电脑、便携式笔记本电脑等多种形式的终端设备,本申请实施例对此不加以限定。

值得注意的是,上述服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。

其中,云技术(Cloud technology)是指在广域网或局域网内将硬件、软件、网络等系列资源统一起来,实现数据的计算、储存、处理和共享的一种托管技术。云技术基于云计算商业模式应用的网络技术、信息技术、整合技术、管理平台技术、应用技术等的总称,可以组成资源池,按需所用,灵活便利。云计算技术将变成重要支撑。技术网络系统的后台服务需要大量的计算、存储资源,如视频网站、图片类网站和更多的门户网站。伴随着互联网行业的高度发展和应用,将来每个物品都有可能存在自己的识别标志,都需要传输到后台系统进行逻辑处理,不同程度级别的数据将会分开处理,各类行业数据皆需要强大的系统后盾支撑,只能通过云计算来实现。

在一些实施例中,上述服务器还可以实现为区块链系统中的节点。区块链(Blockchain)是分布式数据存储、点对点传输、共识机制、加密算法等计算机技术的新型应用模式。区块链,本质上是一个去中心化的数据库,是一串使用密码学方法相关联产生的数据块,每一个数据块中包含了一批次网络交易的信息,用于验证其信息的有效性(防伪)和生成下一个区块。区块链可以包括区块链底层平台、平台产品服务层以及应用服务层。

结合上述名词简介,对本申请实施例中涉及的应用场景进行举例说明。

第一,医生通过分类模型进行辅助诊断的场景。

也即,医生通过终端将医学影像发送至服务器,服务器通过训练好的分类模型对医学影像进行分类识别,得到与医学影像对应的分类诊断结果,并将分类诊断结果反馈至医生所应用的终端进行展示,从而医生通过分类诊断结果进行辅助诊断,并得出最终的诊断结果。

第二,用户通过分类模型进行预诊断。

用户(患者或者患者的亲友)将医学影像发送至服务器,服务器通过训练好的分类模型对医学影像进行分类识别,得到与医学影像对应的分类诊断结果,并将分类诊断结果反馈至用户应用的终端进行展示,用户根据分类诊断结果首先对异常生命状态进行初步了解,继而通过医生诊断得到详细诊断结果。

第三,分类模型还可以应用于其他分类场景。

示意性的,该分类模型还可以应用于物体识别场景、语音识别场景、人脸识别场景等,本申请实施例对此不加以限定。

结合上述名词简介和应用场景,对本申请提供的数据分类识别方法进行说明,以该方法应用于服务器中为例,如图3所示,该方法包括如下步骤。

步骤301,获取第一数据集和第二数据集。

其中,第一数据集中包括第一数据,第二数据集中包括标注有样本标签的第二数据,第二数据属于目标分类集。

在一些实施例中,第一数据集中的第一数据为未标注有标签的数据,而第二数据集中的第二数据为标注有样本标签的数据。

可选地,第一数据属于第一分类集,第二数据属于目标分类集,也即第一数据和第二数据属于不同分类集对应的数据。示意性的,第一分类集对应为常见病分类集,以眼部疾病为例,如:第一分类集中包括近视、远视、结膜炎等常见眼疾类型;目标分类集对应为罕见病分类集,以眼部疾病为例,如:目标分类集中包括干眼症、视雪症、遗传性视神经病变等罕见眼疾类型。

本申请实施例中,常见病和罕见病是针对同一器官或者同一身体部分对应的病症,或者,常见病和罕见病属于同一病症类型。

在一些实施例中,第一数据集中包括的第一数据为与常见病对应的医学影像,如:电子计算机断层扫描(Computed Tomography,CT)图像、X光图像、超声波图像等形式的影像;第二数据集中包括的第二数据为与和罕见病对应的医学影像,如:CT图像、X光图像、超声波图像等形式的影像。

值得注意的是,上述医学影像仅为示意性的举例,本申请实施例中的第一数据和第二数据还可以实现为其他类型的数据,本申请实施例对此不加以限定。

第一数据集中第一数据的数据量(即医学影像的数量)大于第二数据集中第二数据的数据量(即医学影像的数量),可选地,第二数据集中第二数据的数量在要求数量范围内,如:小于预设数量。

可选地,第一数据集中的第一数据是从基础数据集中随机采样的数据,基础数据集中包括常见病数据;第二数据集中的第二数据是从罕见病数据集中随机采样的数据,罕见病数据集中包括罕见病数据,第二数据标注有罕见病信息,也即每个医学影像所对应的罕见病类型。

步骤302,通过第一数据以无监督训练模式,以及第二数据以监督训练模式训练得到分类教师模型。

在一些实施例中,基于第一数据集中的第一数据对特征提取网络进行无监督训练,将分类回归网络与经过无监督训练的特征提取网络结合,得到分类模型,其中,分类回归网络用于在目标分类集中进行数据分类,通过第二数据集中的第二数据和样本标签对分类模型进行监督训练,得到分类教师模型。

由于第一数据集中的第一数据为不存在对应标注的标签的数据,故,第一数据仅能够用于对特征提取网络进行无监督训练。而第二数据集中的第二数据存在对应标注的样本标签,故,第二数据能够用于对分类模型进行监督训练。

步骤303,获取分类学生模型。

分类学生模型为模型参数待训练的模型。

可选的,分类学生模型为随机初始化的分类模型,分类学生模型中包括模型参数,分类学生模型用于根据分类教师模型输出的知识进行蒸馏训练。其中,知识蒸馏是指将教师模型输出的监督信息作为知识,由学生模型学习迁移自教师模型的监督信息作为蒸馏过程,利用教师模型进行监督训练来达到蒸馏的目的,最终得到更高性能和精度的学生模型。

步骤304,通过第一数据以分类教师模型为基准模型,对分类学生模型的模型参数进行蒸馏训练,得到数据分类模型。

可选地,通过分类教师模型对第一数据集中的第一数据进行分类预测,得到与第一数据对应目标分类集中类别的伪标签;通过分类学生模型对第一数据集中的第一数据进行分类预测,得到与第一数据对应的预测结果,基于伪标签与预测结果之间的差异对分类学生模型的模型参数进行调整,得到数据分类模型。

也即,将分类教师模型对第一数据进行分类预测后输出的伪标签作为知识,由分类学生模型迁移该伪标签进行蒸馏,从而实现分类学生模型的蒸馏训练。

步骤305,通过数据分类模型对目标数据进行分类预测,得到目标数据在目标分类集中所属的分类结果。

在分类学生模型经过训练后,得到数据分类模型,通过数据分类模型对目标数据进行分类,即可得到目标数据在目标分类集中的分类结果。其中,目标数据可以是实际应用时的数据,如:实际应用时的医学影像;或者,目标数据也可以是测试集中用于对数据分类模型进行测试的数据。

综上所述,本实施例提供的数据分类识别方法,在通过无标签的第一数据进行无监督训练以及有标签的第二数据进行监督训练后,得到分类教师模型,从而在分类教师模型的基础上,创建分类学生模型进行知识蒸馏训练,利用教师模型进行监督训练来达到蒸馏的目的,最终得到更高性能和精度的学生模型,训练主要依赖大量的第一数据,而对有标签的第二数据的数据量要求较小,避免了对样本数据进行大量标注的繁琐过程,提高了数据分类模型的训练效率以及准确率。

在一些实施例中,通过分类教师模型对分类学生模型进行蒸馏训练的过程中,需要通过分类教师模型识别得到的伪标签作为知识,图4是本申请另一个示例性实施例提供的数据分类识别方法的流程图,以该方法应用于服务器中为例进行说明,如图4所示,该方法包括如下步骤。

步骤401,获取第一数据集和第二数据集。

其中,第一数据集中包括第一数据,第二数据集中包括标注有样本标签的第二数据,第二数据属于目标分类集。

在一些实施例中,第一数据集中的第一数据为未标注有标签的数据,而第二数据集中的第二数据为标注有样本标签的数据。

可选地,第一数据属于第一分类集,第二数据属于目标分类集,也即第一数据和第二数据属于不同分类集对应的数据。示意性的,第一分类集对应为常见病分类集;目标分类集对应为罕见病分类集。

步骤402,通过第一数据以无监督训练模式,以及第二数据以监督训练模式训练得到分类教师模型。

在一些实施例中,基于第一数据集中的第一数据对特征提取网络进行无监督训练,将分类回归网络与经过无监督训练的特征提取网络结合,得到分类模型,其中,分类回归网络用于在目标分类集中进行数据分类,通过第二数据集中的第二数据和样本标签对分类模型进行监督训练,得到分类教师模型。

分类教师模型具有较好的分类性能,但在表征学习的过程中,忽略了与目标分类集相关的知识,故,本申请实施例中,将分类教师模型作为基准模型,通过分类教师模型输出的知识对分类学生模型进行蒸馏训练。其中,分类学生模型为模型参数待调整的用于在目标分类集进行分类的模型。

步骤403,通过分类教师模型对第一数据集中的第一数据进行分类预测,得到与第一数据对应目标分类集中类别的伪标签。

由于即使第一数据集和第二数据经济所包含的数据类别不同,但数据具有相似的特征,示意性的,第一数据集为常见病的医学影像,第二数据集为罕见病的医学影像,则第一数据集和第二数据集的数据在颜色、纹理或者形状上具有相似的特征。因此,采用分类教师模型作为基准模型预测第一数据集中的图像属于目标分类集的概率。

在一些实施例中,通过分类教师模型对第一数据集中的第一数据进行分类预测,得到第一数据对应目标分类集中分类的概率值,基于概率值从目标分类集中确定第一数据对应的伪标签。

可选地,首先通过分类教师模型确定第一数据对应目标分类集中类别的软标签,也即对应目标分类集中类别的概率:p=F(x)=[p

步骤404,获取分类学生模型,并通过分类学生模型对第一数据集中的第一数据进行分类预测,得到与第一数据对应的预测结果。

分类学生模型为模型参数待调整的模型,分类学生模型用于对应目标分类集对数据进行分类。

步骤405,基于伪标签与预测结果之间的差异对分类学生模型的模型参数进行调整,得到数据分类模型。

可选地,分类学生模型中包括第一查询编码器和第一键值编码器,则通过第一查询编码器对第一数据进行编码,得到第一编码结果,通过第一键值编码器对第一数据和第一预设动态字典中的数据进行编码的第二编码结果,基于第一编码结果和第二编码结果的差异对分类学生模型进行蒸馏训练,得到数据分类模型。

在一些实施例中,结合伪标签监督方法与对比判别方法进行混合蒸馏损失的确定,其中伪标签监督方法即为基于伪标签与预测结果之间的差异对分类学生模型的模型参数进行调整,对比判别方法即为通过第一查询编码器与第一键值编码器对分类学生模型进行训练。可选地,采用随机初始化学生模型的策略,其中分类学生模型F’=f’

公式一:L

其中,x为第一数据集中的第一数据,θ’

在一些实施例中,与基准模型不同的是,f’

在实际训练中,由于罕见病对应的第二数据集中数据量较少以及其产生的噪声和偏差,分类教师模型生成的伪标签不是完全可用的并且可能对分类学生模型的训练造成不利影响。故,本申请实施例中,伪标签还对应有置信度参数,获取伪标签的置信度参数,确定伪标签在置信度参数下与预测结果之间的差异,并基于差异对分类学生模型的模型参数进行调整,得到数据分类模型。

示意性的,本实施例中,将分类学生模型的预测值p’与伪标签y结合作为训练目标,请参考如下公式二。

公式二:y

其中,α为置信度参数,控制分类教师模型生成的伪标签y所占训练目标的比例。通常α为一个固定值,然而,在训练的初始阶段,学生模型所产生的预测值的可信度较低。因此本申请采用线性增长方法,在第t个训练回合的α为:α

步骤406,通过数据分类模型对目标数据进行分类预测,得到目标数据在目标分类集中所属的分类结果。

在分类学生模型经过训练后,得到数据分类模型,通过数据分类模型对目标数据进行分类,即可得到目标数据在目标分类集中的分类结果。其中,目标数据可以是实际应用时的数据,如:实际应用时的医学影像;或者,目标数据也可以是测试集中用于对数据分类模型进行测试的数据。

在一些实施例中,获取测试数据集,测试数据集中的测试数据用于对数据分类模型的训练效果进行测试,从测试数据集中获取目标数据,目标数据标注有参考分类信息,通过数据分类模型对目标数据进行分类预测得到分类结果后,基于参考分类信息和分类结果获取数据分类模型的训练效果数据。示意性的,获取测试数据集中的多个目标数据,分别进行分类预测,并与参考分类信息进行比对,根据比对结果正确的目标数据占被测试的目标数据总数的比例,确定训练效果,也即确定数据分类模型的预测准确率。

综上所述,本实施例提供的数据分类识别方法,在通过无标签的第一数据进行无监督训练以及有标签的第二数据进行监督训练后,得到分类教师模型,从而在分类教师模型的基础上,创建分类学生模型进行知识蒸馏训练,利用教师模型进行监督训练来达到蒸馏的目的,最终得到更高性能和精度的学生模型,训练主要依赖大量的第一数据,而对有标签的第二数据的数据量要求较小,避免了对样本数据进行大量标注的繁琐过程,提高了数据分类模型的训练效率以及准确率。

本实施例提供的方法,结合伪标签监督方法与对比判别方法进行混合蒸馏损失的确定,在通过分类教师模型对分类学生模型进行蒸馏训练的同时,避免分类学生模型对数据的特征提取被蒸馏训练过程影响,提高了分类学生模型的训练效率和准确率。

在一些实施例中,分类教师模型是通过第一数据的无监督训练和第二数据的监督训练得到的。图5是本申请另一个示例性实施例提供的数据分类识别方法的流程图,如图5所示,以该方法应用于服务器中为例,该方法包括如下步骤。

步骤501,获取第一数据集和第二数据集。

其中,第一数据集中包括第一数据,第二数据集中包括标注有样本标签的第二数据,第二数据属于目标分类集。

在一些实施例中,第一数据集中的第一数据为未标注有标签的数据,而第二数据集中的第二数据为标注有样本标签的数据。

可选地,第一数据属于第一分类集,第二数据属于目标分类集,也即第一数据和第二数据属于不同分类集对应的数据。示意性的,第一分类集对应为常见病分类集;目标分类集对应为罕见病分类集。

步骤502,基于第一数据集中的第一数据对特征提取网络进行无监督训练。

在一些实施例中,特征提取网络中包括第二查询编码器和第二键值编码器则通过第二查询编码器对第一数据进行编码,得到第三编码结果,获取第二键值编码器对第二预设动态字典中的数据进行编码的第四编码结果,基于第三编码结果和第四编码结果的差异对特征提取网络进行无监督训练。

无监督表征学习能够在无标注数据的情况下训练一个较好的特征提取模型,故,本申请实施例中,采用对比损失作为特征提取网络的优化函数。

可选地,在通过特征提取网络对第一数据进行特征提取时,将第一数据进行数据增强,以第一数据为医学影像为例,则对第一数据集中的医学影像进行图像增强,其中,图像增强的次数两次,从而分别输入第二查询编码器和第二键值编码器。示意性的,对第一数据集中的每张图像进行两次图像增强,得到

公式三:

其中,x

通过第一数据集对特征提取网络完成无监督训练后,冻结参数θ

步骤503,将分类回归网络与经过无监督训练的特征提取网络结合,得到分类模型。

在一些实施例中,分类回归网络用于在目标分类集中进行数据分类。

可选地,由于上述特征提取网络对应有第二查询编码器和第二键值编码器,在将分类回归网络与特征提取网络结合时,本申请实施例中,将分类回归网络与经过无监督训练的第二查询编码器连接,得到分类模型。

步骤504,通过第二数据集中的第二数据和样本标签对分类模型进行监督训练,得到分类教师模型。

在一些实施例中,通过第二数据对第二模型进行监督训练时,将第二数据输入分类模型进行分类预测,得到预测结果,而第二数据本身标注有样本标签,用于指示第二数据的实际分类,从而根据样本标签与预测结果之间的差异反向对分类模型的模型参数进行调整。可选地,根据样本标签与预测结果计算该预测结果的损失值,从而根据损失值反向对分类模型的模型参数进行调整,直至预测结果对应的损失值收敛。

步骤505,获取分类学生模型。

分类学生模型为模型参数待训练的模型。

可选的,分类学生模型为随机初始化的分类模型,分类学生模型中包括模型参数,分类学生模型用于根据分类教师模型输出的知识进行蒸馏训练。其中,知识蒸馏是指将教师模型输出的监督信息作为知识,由学生模型学习迁移自教师模型的监督信息作为蒸馏过程,利用教师模型进行监督训练来达到蒸馏的目的,最终得到更高性能和精度的学生模型。

步骤506,通过第一数据以分类教师模型为基准模型,对分类学生模型的模型参数进行蒸馏训练,得到数据分类模型。

可选地,通过分类教师模型对第一数据集中的第一数据进行分类预测,得到与第一数据对应目标分类集中类别的伪标签;通过分类学生模型对第一数据集中的第一数据进行分类预测,得到与第一数据对应的预测结果,基于伪标签与预测结果之间的差异对分类学生模型的模型参数进行调整,得到数据分类模型。

也即,将分类教师模型对第一数据进行分类预测后输出的伪标签作为知识,由分类学生模型迁移该伪标签进行蒸馏,从而实现分类学生模型的蒸馏训练。

步骤507,通过数据分类模型对目标数据进行分类预测,得到目标数据在目标分类集中所属的分类结果。

在分类学生模型经过训练后,得到数据分类模型,通过数据分类模型对目标数据进行分类,即可得到目标数据在目标分类集中的分类结果。其中,目标数据可以是实际应用时的数据,如:实际应用时的医学影像;或者,目标数据也可以是测试集中用于对数据分类模型进行测试的数据。

综上所述,本实施例提供的数据分类识别方法,在通过无标签的第一数据进行无监督训练以及有标签的第二数据进行监督训练后,得到分类教师模型,从而在分类教师模型的基础上,创建分类学生模型进行知识蒸馏训练,利用教师模型进行监督训练来达到蒸馏的目的,最终得到更高性能和精度的学生模型,训练主要依赖大量的第一数据,而对有标签的第二数据的数据量要求较小,避免了对样本数据进行大量标注的繁琐过程,提高了数据分类模型的训练效率以及准确率。

本实施例提供的方法,通过第一数据集中无标签的第一数据对特征提取网络进行无监督训练,从而通过第二数据集中有标签的第二数据对分类模型进行监督训练,从而在第二数据的采集过程较为繁琐,或者第二数据的收集难度较大时,仅需要少量采集第二数据,即可实现对分类教师模型的有效训练,提高了模型的训练效率。

结合上述内容,以上述第一数据集中的第一数据为常见病的医学影像,第二数据集中的第二数据为罕见病的医学影像为例,进行示意性的说明,图6是本申请一个示例性实施例提供的罕见病分类识别模型的训练过程整体示意图。

如图6所示,该过程中包括无监督训练阶段610、监督训练阶段620、伪标签生成阶段630以及分类学生模型的训练阶段640。

其中,在无监督训练阶段610中,将无标签标注的常见病医学影像611进行两次图像增强得到

在监督训练阶段620中,当查询编码器612与分类回归模型621连接后,得到待训练的分类教师模型622,通过罕见病的医学影像623对分类教师模型622进行监督训练时,根据罕见病的医学影像623对应标注的标签以及分类教师模型622的分类结果确定损失值,并实现对分类教师模型622的监督训练。

在分类教师模型622训练完毕后,在伪标签生成阶段630,通过分类教师模型622对常见病的医学影像611进行分类识别,得到常见病的医学影像611对应的伪标签。

在分类学生模型的训练阶段640,根据常见病的医学影像611对应的伪标签,以及分类学生模型641的预测结果得到第一损失值,根据分类学生模型641中查询编码器642和键值编码器643的编码结果得到第二损失值,从而根据第一损失值和第二损失值确定总的损失值对分类学生模型641进行训练,得到罕见病分类识别模型。

表一给出了本申请的技术方案在皮肤病变分类数据集上的结果对比。此数据集包含7个类别,将病例数量最多的四个类别的数据集作为第一数据集,剩余三个类别的数据集作为第二数据集。评价指标选择了准确率(Accuracy)、统计学中用来衡量二分类模型精确度的指标F1 score。

表一

表一中,N代表测试类别数,K代表每个测试类别提供的有标签的图片数量,本技术方案分别对比了K为1,3,5的结果。将罕见病数据集中剩余的图像组成Q作为测试集用于性能评估。

由表一可见,本技术方案的分类指标优于全部现有技术。本技术方案在基准模型的基础上加入自蒸馏,提升了准确率约1-2%,F1 score 约3-5%。从表一中可以观察到在K=5时,本技术方案无需任何常见病数据集的标注,准确率即可达到81.16%。此结果验证了本方法的假设:通过将伪标签监督信息注入到表征学习过程中并充分利用大量无标注数据集学习能够更好地学习罕见疾病数据的表征及其分类器。

图7是本申请一个示例性实施例提供的数据分类识别装置的结构示意图,如图7所示,该装置包括如下部分:

获取模块710,用于获取第一数据集和第二数据集,所述第一数据集中包括第一数据,所述第二数据集中包括标注有样本标签的第二数据,所述第二数据属于目标分类集;

训练模块720,用于通过所述第一数据以无监督训练模式,以及所述第二数据以监督训练模式训练得到分类教师模型;

所述获取模块710,还用于获取分类学生模型,所述分类学生模型为模型参数待训练的模型;

所述训练模块720,还用于通过所述第一数据以所述分类教师模型为基准模型,对所述分类学生模型的所述模型参数进行蒸馏训练,得到数据分类模型;

预测模块730,用于通过所述数据分类模型对目标数据进行分类预测,得到所述目标数据在所述目标分类集中所属的分类结果。

在一个可选的实施例中,所述预测模块730,还用于通过所述分类教师模型对所述第一数据集中的所述第一数据进行分类预测,得到与所述第一数据对应所述目标分类集中类别的伪标签;

所述预测模块730,还用于通过所述分类学生模型对所述第一数据集中的第一数据进行分类预测,得到与所述第一数据对应的预测结果;

如图8所示,训练模块720,还包括:

调整单元721,用于基于所述伪标签与所述预测结果之间的差异对所述分类学生模型的所述模型参数进行调整,得到所述数据分类模型。

在一个可选的实施例中,所述获取模块710,还用于获取所述伪标签的置信度参数;

所述调整单元721,还用于确定所述伪标签在所述置信度参数下与所述预测结果之间的差异,并基于所述差异对所述分类学生模型的所述模型参数进行调整,得到所述数据分类模型。

在一个可选的实施例中,所述预测模块730,还用于通过所述分类教师模型对所述第一数据集中的所述第一数据进行分类预测,得到所述第一数据对应所述目标分类集中分类的概率值;基于所述概率值从所述目标分类集中确定所述第一数据对应的伪标签。

在一个可选的实施例中,所述分类学生模型中包括第一查询编码器和第一键值编码器;

所述装置还包括:

编码模块740,用于通过所述第一查询编码器对所述第一数据进行编码,得到第一编码结果;

所述获取模块710,还用于获取所述第一键值编码器对所述第一数据和第一预设动态字典中的数据进行编码的第二编码结果;

所述训练模块720,还用于基于所述第一编码结果与所述第二编码结果的差异对所述分类学生模型进行蒸馏训练,得到所述数据分类模型。

在一个可选的实施例中,所述训练模块720,还用于基于所述第一数据集中的第一数据对特征提取网络进行无监督训练;将分类回归网络与经过无监督训练的所述特征提取网络结合,得到分类模型,所述分类回归网络用于在所述目标分类集中进行数据分类;

所述训练模块720,还用于通过所述第二数据集中的所述第二数据和所述样本标签对所述分类模型进行监督训练,得到所述分类教师模型。

在一个可选的实施例中,所述特征提取网络中包括第二查询编码器和第二键值编码器;

所述装置还包括:

编码模块740,用于通过所述第二查询编码器对所述第一数据进行编码,得到第三编码结果;

所述获取模块710,还用于获取所述第二键值编码器对所述第一数据和第二预设动态字典中的数据进行编码的第四编码结果;

所述训练模块720,还用于基于所述第三编码结果与所述第四编码结果的差异对所述特征提取网络进行无监督训练。

在一个可选的实施例中,所述训练模块720,还用于将所述分类回归网络与经过无监督训练的所述第二查询编码器连接,得到所述分类模型。

在一个可选的实施例中,所述获取模块710,还用于获取测试数据集,所述测试数据集中的测试数据用于对所述数据分类模型的训练效果进行测试;从所述测试数据集中获取所述目标数据,所述目标数据标注有参考分类信息;

所述预测模块730,还用于通过所述数据分类模型对目标数据进行分类预测,得到所述分类结果;

所述获取模块710,还用于基于所述参考分类信息和所述分类结果获取所述数据分类模型的训练效果数据。

综上所述,本实施例提供的数据分类识别装置,在通过无标签的第一数据进行无监督训练以及有标签的第二数据进行监督训练后,得到分类教师模型,从而在分类教师模型的基础上,创建分类学生模型进行知识蒸馏训练,利用教师模型进行监督训练来达到蒸馏的目的,最终得到更高性能和精度的学生模型,训练主要依赖大量的第一数据,而对有标签的第二数据的数据量要求较小,避免了对样本数据进行大量标注的繁琐过程,提高了数据分类模型的训练效率以及准确率。

需要说明的是:上述实施例提供的数据分类识别装置,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的数据分类识别装置与数据分类识别方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。

图9示出了本申请一个示例性实施例提供的服务器的结构示意图。

具体来讲:服务器900包括中央处理单元(Central Processing Unit,CPU)901、包括随机存取存储器(Random Access Memory,RAM)902和只读存储器(Read Only Memory,ROM)903的系统存储器904,以及连接系统存储器904和中央处理单元901的系统总线905。服务器900还包括用于存储操作系统913、应用程序914和其他程序模块915的大容量存储设备906。

大容量存储设备906通过连接到系统总线905的大容量存储控制器(未示出)连接到中央处理单元901。大容量存储设备906及其相关联的计算机可读介质为服务器900提供非易失性存储。也就是说,大容量存储设备906可以包括诸如硬盘或者紧凑型光盘只读存储器(Compact Disc Read Only Memory,CD-ROM)驱动器之类的计算机可读介质(未示出)。

不失一般性,计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、可擦除可编程只读存储器(Erasable Programmable Read Only Memory,EPROM)、带电可擦可编程只读存储器(Electrically Erasable Programmable Read Only Memory,EEPROM)、闪存或其他固态存储其技术,CD-ROM、数字通用光盘(Digital Versatile Disc,DVD)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知计算机存储介质不局限于上述几种。上述的系统存储器904和大容量存储设备906可以统称为存储器。

根据本申请的各种实施例,服务器900还可以通过诸如因特网等网络连接到网络上的远程计算机运行。也即服务器900可以通过连接在系统总线905上的网络接口单元911连接到网络912,或者说,也可以使用网络接口单元911来连接到其他类型的网络或远程计算机系统(未示出)。

上述存储器还包括一个或者一个以上的程序,一个或者一个以上程序存储于存储器中,被配置由CPU执行。

本申请的实施例还提供了一种计算机设备,该计算机设备包括处理器和存储器,该存储器中存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行以实现上述各方法实施例提供的数据分类识别方法。

本申请的实施例还提供了一种计算机可读存储介质,该计算机可读存储介质上存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行,以实现上述各方法实施例提供的数据分类识别方法。

本申请的实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中任一所述的数据分类识别方法。

可选地,该计算机可读存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、固态硬盘(SSD,Solid State Drives)或光盘等。其中,随机存取记忆体可以包括电阻式随机存取记忆体(ReRAM,Resistance RandomAccess Memory)和动态随机存取存储器(DRAM,Dynamic Random Access Memory)。上述本申请实施例序号仅仅为了描述,不代表实施例的优劣。

本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。

以上所述仅为本申请的可选实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

相关技术
  • 数据分类识别方法、装置、设备及可读存储介质
  • 数据分类识别方法、装置、设备及可读存储介质
技术分类

06120112899828