掌桥专利:专业的专利平台
掌桥专利
首页

一种分类模型的训练方法及装置

文献发布时间:2023-06-19 10:11:51


一种分类模型的训练方法及装置

技术领域

本申请涉及人工智能技术领域,具体而言,涉及一种分类模型的训练方法及装置。

背景技术

随着PolSAR图像分类方法在图像解释中的重要应用,现已吸引了越来越多国内外研究者的热切关注。而今,通常通过分类模型来实现图像的分类处理,现有的分类模型训练方法,通常先将样本分为标记样本与无标记样本,然后利用有限的标记样本与大量的无标记样本对模型进行训练,进而得到分类模型。然而,在实践中发现,训练样本未考虑标记样本的可信度且未进行样本扩充,所得到的分类模型的分类性能较差,准确性低。

发明内容

本申请实施例的目的在于提供一种分类模型的训练方法及装置,能够考虑样本的可信度,同时对样本进行扩充,训练得到的分类模型分类性能好,准确性高。

本申请实施例第一方面提供了一种分类模型的训练方法,包括:

获取包括标记样本集和未标记样本集的原始样本集;

根据所述标记样本集计算初始聚类中心,根据所述初始聚类中心和所述未标记样本集确定高可信度样本集;

对所述高可信度样本集进行样本扩充处理,得到训练样本集;

通过所述训练样本集对预先构建的原始分类模型进行训练,得到训练好的分类模型。

在上述实现过程中,该方法能够获取标记样本集和未标记样本集,然后根据标记样本集计算初始聚类中心,根据初始聚类中心和未标记样本集确定高可信度样本集;然后再对高可信度样本集进行样本扩充处理,得到训练样本集;通过训练样本集对预先构建的原始分类模型进行训练,得到训练好的分类模型。可见,实施这种实施方式,能够根据原始样本集自动处理计算,从而获取到更高置信度的样本,进而能够通过提高样本的可信度提高分类模型的训练效果,使得训练好的分类模型的分类性能更好,准确性更高。

进一步地,所述根据所述初始聚类中心和所述未标记样本集确定高可信度样本集包括:

计算所述未标记样本集中未标记样本与所述初始聚类中心之间的最小距离;

根据所述最小距离对所述未标记样本集进行伪标签预测,得到伪标签预测结果;

根据所述伪标签预测结果从所述未标记样本集中确定高可信度样本集。

在上述实现过程中,该方法能够根据初始聚类中心和未标记样本集确定出高可信度样本,能够挖掘更加丰富的样本特征,进而有利于提高分类模型的训练效果。

进一步地,所述对所述高可信度样本集进行样本扩充处理,得到训练样本集,包括:

采用预先构建的样本预测模型对所述高可信度样本集进行预测处理,得到标签预测结果;

根据所述标签预测结果和所述高可信度样本集确定待扩充样本集;

将所述标记样本集与所述待扩充样本集进行合并,得到训练样本集。

在上述实现过程中,该方法可以通过预设的样本预测模型对高可信度样本集进行进一步样本提取,得到符合标签预测结果的待扩充样本集,并能够将待扩充样本集和标记样本集合并,得到训练样本集。可见,实施这种实施方式,能够对高可信度样本集进行二次处理,以使其处理结果和标记样本集能够合并成为一个特征更加丰富的训练样本集,从而提高训练样本集的可信度,进而提高分类模型的训练效果。

进一步地,所述通过所述训练样本集对预先构建的原始分类模型进行训练,得到训练好的分类模型,包括:

对所述训练样本集中的样本图像进行特征图像块提取处理,得到目标特征块集合;

将所述目标特征块集合中每个特征图像块输入到预先构建的原始分类模型中进行预测处理,得到分类预测结果;

通过第二预设分类器对所述分类预测结果进行处理,得到目标预测结果;

根据所述目标预测结果和所述训练样本集确定所述原始分类模型的模型参数;

根据所述模型参数和所述原始分类模型生成训练好的分类模型。

在上述实现过程中,该方法能够通过对训练集进行预测,所得到的预测结果进行模型参数的确定,旨在得到的分类模型能够根据实际情况确定最合适的训练参数,从而使得通过训练得到的分类模型的使用效果更好。

进一步地,所述方法还包括:

获取待分类图片;

通过所述训练好的分类模型对所述待分类图片进行处理,得到所述待分类图片的类别标签预测结果。

在上述实现过程中,该方法还可以使用训练好的分类模型对待分类图片进行分类,得到对应的类别标签预测结果,从而使得训练模型能够被直接使用。

进一步地,所述方法还包括:

采用预设算法对所述训练好的分类模型进行优化处理,得到优化分类模型。

在上述实现过程中,该方法还可以进一步对分类模型进行优化,使得优化后的分类模型效果更好。

本申请实施例第二方面提供了一种分类模型的训练装置,所述分类模型的训练装置包括:

获取单元,用于获取包括标记样本集和未标记样本集的原始样本集;

计算单元,用于根据所述标记样本集计算初始聚类中心;

确定单元,用于根据所述初始聚类中心和所述未标记样本集确定高可信度样本集;

扩充单元,用于对所述高可信度样本集进行样本扩充处理,得到训练样本集;

训练单元,用于通过所述训练样本集对预先构建的原始分类模型进行训练,得到训练好的分类模型。

在上述实现过程中,该装置能够根据原始样本集自动处理计算,从而获取到更高置信度的样本,进而能够通过提高样本的可信度提高分类模型的训练效果,使得训练好的分类模型的分类性能更好,准确性更高。

进一步地,所述确定单元包括:

计算子单元,用于根据所述标记样本集计算初始聚类中心,以及计算所述未标记样本集中未标记样本与所述初始聚类中心之间的最小距离;

伪标签预测子单元,用于根据所述最小距离对所述未标记样本集进行伪标签预测,得到伪标签预测结果;

样本集确定子单元,用于根据所述伪标签预测结果从所述未标记样本集中确定高可信度样本集。

在上述实现过程中,该装置能够根据初始聚类中心和未标记样本集确定出高可信度样本,从而避免不好的样本参与训练,进而提高分类模型的训练效果。

本申请实施例第三方面提供了一种电子设备,包括存储器以及处理器,所述存储器用于存储计算机程序,所述处理器运行所述计算机程序以使所述电子设备执行本申请实施例第一方面中任一项所述的分类模型的训练方法。

本申请实施例第四方面提供了一种计算机可读存储介质,其存储有计算机程序指令,所述计算机程序指令被一处理器读取并运行时,执行本申请实施例第一方面中任一项所述的分类模型的训练方法。

附图说明

为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。

图1为本申请实施例提供的一种分类模型的训练方法的流程示意图;

图2为本申请实施例提供的另一种分类模型的训练方法的流程示意图;

图3为本申请实施例提供的一种分类模型的训练装置的结构示意图;

图4为本申请实施例提供的另一种分类模型的训练装置的结构示意图;

图5为本申请实施例提供的一种该分类模型的训练方法的部分举例示意图。

具体实施方式

下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行描述。

应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。同时,在本申请的描述中,术语“第一”、“第二”等仅用于区分描述,而不能理解为指示或暗示相对重要性。

实施例1

请参看图1,图1为本申请实施例提供了一种分类模型的训练方法的流程示意图。其中,该分类模型的训练方法包括:

S101、获取包括标记样本集和未标记样本集的原始样本集。

本实施例中,该方法可以通过Wishart分类器来选择较为优质的样本。

S102、根据标记样本集计算初始聚类中心,根据初始聚类中心和未标记样本集确定高可信度样本集。

S103、对高可信度样本集进行样本扩充处理,得到训练样本集。

本实施例中,该方法可以通过由随机森林(Random forest,RF)、Bagging、SVM三种分类器所构成的协同三体分类器(Tri-training)的模型确定基分类器。以使该模型可以根据该基分类器对待测样本进行类标的预测以得到伪标签样本,再将该样本加入到标记样本中以扩充训练样本集,从而实现样本扩充处理的目的。

S104、通过训练样本集对预先构建的原始分类模型进行训练,得到训练好的分类模型。

本实施例中,该方法可以利用所得到的训练样本集对预设的多深度、多尺度的CV-CNN(Semi-Supervised Complex-Valued Convolution Neural Network,CV-CNN)分类模型(即原始分类模型)进行训练。

本实施例中,该方法可以利用最终训练得到的Tri-training半监督的多深度与多尺度CV-CNN模型对整幅PolSAR图像(即极化合成孔径雷达图像)进行标签的预测。

本申请实施例中,该方法的执行主体可以为计算机、服务器等计算装置,对此本实施例中不作任何限定。

在本申请实施例中,该方法的执行主体还可以为智能手机、平板电脑等智能设备,对此本实施例中不作任何限定。

可见,实施本实施例所描述的分类模型的训练方法,能够获取标记样本集和未标记样本集,然后根据标记样本集计算初始聚类中心,根据初始聚类中心和未标记样本集确定高可信度样本集;然后再对高可信度样本集进行样本扩充处理,得到训练样本集;通过训练样本集对预先构建的原始分类模型进行训练,得到训练好的分类模型。可见,实施这种实施方式,能够根据原始样本集自动处理计算,从而获取到更高置信度的样本,进而能够通过提高样本的可信度提高分类模型的训练效果,使得训练好的分类模型的分类性能更好,准确性更高。

实施例2

请参看图2,图2为本申请实施例提供的一种分类模型的训练方法的流程示意图。如图2所示,其中,该分类模型的训练方法包括:

S201、获取包括标记样本集和未标记样本集的原始样本集。

S202、根据标记样本集计算初始聚类中心。

本实施例中,该方法能够随机选取少量的标记样本X,并计算它们的复数域相干矩阵T的平均值,以确定初始聚类中心V

S203、计算未标记样本集中未标记样本与初始聚类中心之间的最小距离。

本实施例中,该方法能够计算所有未标记样本与每个聚类中心之间的距离d(T,V

S204、根据最小距离对未标记样本集进行伪标签预测,得到伪标签预测结果。

本实施例中,该方法可以利用最小距离进行未标记数据的伪标签的预测。

S205、根据伪标签预测结果从未标记样本集中确定高可信度样本集。

本实施例中,该方法可以选择带有伪标签的数据作为下一步分类步骤的样本。

S206、采用预先构建的样本预测模型对高可信度样本集进行预测处理,得到标签预测结果。

本实施例中,该方法可以将带有伪标签的数据作为Tri-training算法中参与预测的样本。

在本实施例中,该方法可以选取准确性与多样性较好的RF算法、Bagging算法和稀疏性与稳健性较好的SVM算法对应的三种分类器构成一种新颖的Tri-training模型,以挖掘更加丰富的图像信息。

在本实施例中,该方法可以利用上述改进的Tri-training模型对该待测样本进行类标预测,从而进一步提高其样本标签的可信度。

S207、根据标签预测结果和高可信度样本集确定待扩充样本集。

本实施例中,该方法能够将得到的伪标记样本与原有的标记样本合并,得到扩充后的训练样本集U’。

S208、将标记样本集与待扩充样本集进行合并,得到训练样本集。

S209、对训练样本集中的样本图像进行特征图像块提取处理,得到目标特征块集合。

S210、将目标特征块集合中每个特征图像块输入到预先构建的原始分类模型中进行预测处理,得到分类预测结果。

S211、通过第二预设分类器对分类预测结果进行处理,得到目标预测结果。

S212、根据目标预测结果和训练样本集确定原始分类模型的模型参数。

S213、根据模型参数和原始分类模型生成训练好的分类模型。

本实施例中,该方法可以优先对扩充样本集中的所有图像块提取不同尺度的特征,以挖掘图像中更加丰富的特征;然后,将每个特征分别输入到不同深度的CV-CNN模型中进行单独预测;接着,将每个分类结果作为Softmax分类器p(y

S214、采用预设算法对训练好的分类模型进行优化处理,得到优化分类模型。

在本实施例中,该方法可以应用BP算法对上述模型进行优化,进而利用训练好的半监督多深度与多尺度的Tri-CV-CNN模型对整幅PolSAR图像进行分类。

作为一种可选的实施方式,该方法还可以包括:

获取待分类图片;

通过训练好的分类模型对待分类图片进行处理,得到待分类图片的类别标签预测结果。

实施这种实施方式,能够实时对待分类图片进行标签分类,从而实现物以致用的效果。

本实施例中,该方法可以利用最终训练得到的Tri-training半监督的多深度与多尺度CV-CNN模型对整幅PolSAR图像进行标签的预测。

本实施例中,利用Tri-training半监督的多深度与多尺度CV-CNN模型(Semi-Supervised Complex-Valued Muti-depth and Muti-Scale Convolution NeuralNetwork with Tri-training Algorithm,Tri-CV-CNN)进行图像分类处理,能够有效地提升模型的泛化能力与分类性能,最终提高了模型的分类准确率与空间一致性。

请参阅图5,图5描述了一种该分类模型的训练方法的部分举例示意图。具体解释说明可以参照实施例中所描述的内容。

举例来说,该种方法可以理解为一种基于半监督多深度与多尺度的复值卷积神经网络的雷达图像分类方法。其中,该方法优先进行数据预选,在该过程中,该方法可以优先输入待测Flevoland数据集(750×1024)的相关矩阵T,提取其上三角{T

其中,Wishart分类器选取样本的步骤如下:

1、计算标记样本X的平均值,确定聚类中心V

2、将c

3、选择上述伪标签的样本作为样本扩充过程中的输入样本。

进一步举例来说,对于样本扩充的过程而言,该方法可以使用一种新颖的Tri-training模型(即样本预测模型),其选取了RF,Bagging与SVM算法作为模型的基学习器。其中,RF算法通过生成很多的决策树,将多个低效模型整合为一个高效模型;Bagging算法通过“有放回”的采样策略,降低学习器的方差,进而提高算法稳定性;SVM算通过利用少量标记样本构建多个超平面对样本进行划分,节省了内存空间。然后,利用改进的Tri-training模型对该未标记样本进行标签预测,利用半监督和集成学习机制,大大提高其样本标签的可信度与模型的泛化能力。其中,样本的扩充准则如下:

1、若三个分类器的预测结果相同,则将它连同预测标签加入到总训练集当中;

2、若有两个分类器的预测结果相同,则将该标签结果加入到另一个分类器所对应的训练样本中,再让这个分类器对更新的训练样本进行进一步的训练;

3、若三者预测结果均不同,则将其放回到未标记样本中。

最后,将预测得到的伪标记样本与原有的标记样本合并,得到扩充后的训练样本集U’。

进一步举例来说,多深度与多尺度CV-CNN模型的搭建过程可以使用上述所得到的扩充后的训练样本集对多深度与多尺度CV-CNN模型进行训练。其中,该过程可以包括:

1、根据训练样本集U’的中心像素位置,分别在原图上提取尺度大小为6*6*6,12*12*6,24*24*6的特征图像块,并对其进行归一化。

2、将每个特征分别输入到不同深度的CV-CNN模型中进行单独预测。

其中,尺度大小为6*6*6输入特征的图像块的深度为2,包括一个3*3*3的卷积层与一个2*2的池化层;尺度大小为12*12*6输入特征的图像块的深度为3,包括两个大小为3*3*12与3*3*11的卷积层与一个2*2的池化层;尺度大小为24*24*6输入特征的图像块的深度为4,包括两个大小为5*5*12与3*3*11的卷积层与两个2*2的池化层。

3、将每个单独预测所得到的分类结果作为Softmax分类器p(y

其中,C是类别数。y表示预测标签。输出结果表示像素p属于第q类的概率,a

4、根据扩充后的训练集和初始化后的多深度与多尺度的CV-CNN模型,通过最小化最小二乘损失函数E来学习整个网络的参数:

其中,T[n]与O[n]表示第n个输入标签与输出标签。k表示神经元数目。

第四步,半监督多深度与多尺度的Tri-CV-CNN模型的优化与分类。

首先,应用BP算法对上述模型进行优化:

其中,l表示全连接层数。η为学习率。

最后,利用训练好的半监督多深度与多尺度的Tri-CV-CNN模型对整幅Flevoland图像进行分类,得到所有像素点的类别标签预测结果。

可见,实施本实施例所描述的分类模型的训练方法,能够根据原始样本集自动处理计算,从而获取到更高置信度的样本,进而能够通过提高样本的可信度提高分类模型的训练效果,使得训练好的分类模型的分类性能更好,准确性更高。

实施例3

请参看图3,图3为本申请实施例提供的一种分类模型的训练装置的结构示意图。如图3所示,该分类模型的训练装置包括:

获取单元310,用于获取包括标记样本集和未标记样本集的原始样本集;

计算单元320,用于根据标记样本集计算初始聚类中心;

确定单元330,用于根据初始聚类中心和未标记样本集确定高可信度样本集;

扩充单元340,用于对高可信度样本集进行样本扩充处理,得到训练样本集;

训练单元350,用于通过训练样本集对预先构建的原始分类模型进行训练,得到训练好的分类模型。

本申请实施例中,对于分类模型的训练装置的解释说明可以参照实施例1或实施例2中的描述,对此本实施例中不再多加赘述。

可见,实施本实施例所描述的分类模型的训练装置,能够根据原始样本集自动处理计算,从而获取到更高置信度的样本,进而能够通过提高样本的可信度提高分类模型的训练效果,使得训练好的分类模型的分类性能更好,准确性更高。

实施例4

请一并参阅图4,图4是本申请实施例提供的一种分类模型的训练装置的结构示意图。其中,图4所示的分类模型的训练装置是由图3所示的分类模型的训练装置进行优化得到的。如图4所示,确定单元330包括:

计算子单元331,用于根据标记样本集计算初始聚类中心,以及计算未标记样本集中未标记样本与初始聚类中心之间的最小距离;

伪标签预测子单元332,用于根据最小距离对未标记样本集进行伪标签预测,得到伪标签预测结果;

样本集确定子单元333,用于根据伪标签预测结果从未标记样本集中确定高可信度样本集。

作为一种可选的实施方式,扩充单元340包括:

预测子单元341,用于采用预先构建的样本预测模型对高可信度样本集进行预测处理,得到标签预测结果;

确定子单元342,用于根据标签预测结果和高可信度样本集确定待扩充样本集;

合并子单元343,用于将标记样本集与待扩充样本集进行合并,得到训练样本集。

作为一种可选的实施方式,训练单元350包括:

提取子单元351,用于对训练样本集中的样本图像进行特征图像块提取处理,得到目标特征块集合;

处理子单元352,用于将目标特征块集合中每个特征图像块输入到预先构建的原始分类模型中进行预测处理,得到分类预测结果;

处理子单元352,还用于通过第二预设分类器对分类预测结果进行处理,得到目标预测结果;

参数确定子单元353,用于根据目标预测结果和训练样本集确定原始分类模型的模型参数;

训练子单元354,用于根据模型参数和原始分类模型生成训练好的分类模型。

作为一种可选的实施方式,分类模型的训练装置还包括:

获取单元310,还用于获取待分类图片;

预测单元360,用于通过训练好的分类模型对待分类图片进行处理,得到待分类图片的类别标签预测结果。

作为一种可选的实施方式,分类模型的训练装置还包括:

优化单元370,用于采用预设算法对训练好的分类模型进行优化处理,得到优化分类模型。

本申请实施例中,对于分类模型的训练装置的解释说明可以参照实施例1或实施例2中的描述,对此本实施例中不再多加赘述。

可见,实施本实施例所描述的分类模型的训练装置,能够根据原始样本集自动处理计算,从而获取到更高置信度的样本,进而能够通过提高样本的可信度提高分类模型的训练效果,使得训练好的分类模型的分类性能更好,准确性更高。

本申请实施例提供了一种电子设备,包括存储器以及处理器,所述存储器用于存储计算机程序,所述处理器运行所述计算机程序以使所述电子设备执行本申请实施例1或实施例2中任一项分类模型的训练方法。

本申请实施例提供了一种计算机可读存储介质,其存储有计算机程序指令,所述计算机程序指令被一处理器读取并运行时,执行本申请实施例1或实施例2中任一项分类模型的训练方法。

在本申请所提供的几个实施例中,应该理解到,所揭露的装置和方法,也可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,附图中的流程图和框图显示了根据本申请的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,所述模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现方式中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。

另外,在本申请各个实施例中的各功能模块可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或两个以上模块集成形成一个独立的部分。

所述功能如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。

以上所述仅为本申请的实施例而已,并不用于限制本申请的保护范围,对于本领域的技术人员来说,本申请可以有各种更改和变化。凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。

以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应所述以权利要求的保护范围为准。

需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。

相关技术
  • 一种分类模型训练方法和一种分类模型训练装置
  • 基于分类模型的文本分类方法及装置,以及模型训练方法
技术分类

06120112456012