掌桥专利:专业的专利平台
掌桥专利
首页

一种基于生成对抗网络的量化模型训练的方法及装置

文献发布时间:2023-06-19 09:44:49


一种基于生成对抗网络的量化模型训练的方法及装置

技术领域

本公开涉及计算机技术,特别涉及一种基于生成对抗网络的量化模型训练的方法及装置。

背景技术

在日常生活中,在不同使用场景下,对模型精度有不同的要求,特定场景下,模型的精度可以满足日常的生活需求即可,例如,计算圆的周长的模型,仅仅需要使用3.14代表圆周率即可,无需在计算时使用精确到小数点后几十位,甚至上百位,因此,我们需要对一些模型进行量化。

但是,相关技术下,所述量化技术都会降低模型的精确度,同时也难以保证量化后模型的性能。

因此,需要一种新的量化模型训练的方法及装置,以克服上述缺陷。

发明内容

本公开提供一种基于生成对抗网络的量化模型训练的方法及装置,用以降低模型在低比特量化时的性能损失。

本发明提供的具体技术方案如下:

第一方面,一种基于生成对抗网络的量化模型训练的方法,包括:

确定第一分类模型,第二分类模型以及第三分类模型,其中所述第二分类模型为所述第一分类模型经过量化后得到的模型,所述第三分类模型为与第一模型不同的模型;

基于样本数据集合,采用循环迭代方式执行以下步骤,直到训练完毕为止:

将样本数据分别输入所述第一分类模型和所述第二分类模型,得到相应的第一处理结果和第二处理结果;

将所述样本数据,所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果;

将第三处理结果与真实结果进行误差比较,获得相应的梯度,并采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整;输出训练完毕的第二分类模型。

可选的,确定第一分类模型,第二分类模型以及第三分类模型之前,进一步包括:

将训练完成后的第一分类模型直接进行量化,将量化后的第一分类模型作为第二分类模型;

或者,

将带有模拟量化的噪声数据的样本数据集合,送入所述第一分类模型进行训练,输出训练完成后的第一分类模型,作为第二分类模型。

可选的,将所述样本数据,所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果,包括:

将所述样本数据、所述第一分类模型和所述第二分类模型在分类过程中使用的各个标签、所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果,所述第三处理结果表征所述第一处理结果是否由所述第一分类模型输出,以及所述第二处理结果是否由所述第二分类模型输出。

可选的,采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整,包括:

采用所述梯度,对所述第一分类模型以及所述第二分类模型分别进行参数调整,使得参数调整后的第一分类模型再次输出的第一处理结果与参数调整后的第二分类模型再次输出的第二处理结果趋于相同;

采用所述梯度,对所述第三分类模型进行参数调整,使得参数调整后的第三分类模型再次输出的第三处理结果的判断准确率提升。

可选的,输出训练完毕的第二分类模型,包括:

当所述第三分类模型不能正确区分所述第一处理结果和所述第二处理结果时,输出训练完毕的第二分类模型;

或者,

当所述第一处理结果和所述第二处理结果之间的误差小于预设的误差门限值时,输出训练完毕的第二分类模型。

第二方面,一种基于生成对抗网络的量化模型训练的装置,包括:

生成单元,用于确定第一分类模型,第二分类模型以及第三分类模型,其中所述第二分类模型为所述第一分类模型经过量化后得到的模型,所述第三分类模型为与第一模型不同的模型;

训练单元,用于基于样本数据集合,采用循环迭代方式执行以下步骤,直到训练完毕为止:

将样本数据分别输入所述第一分类模型和所述第二分类模型,得到相应的第一处理结果和第二处理结果;

将所述样本数据,所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果;

将第三处理结果与真实结果进行误差比较,获得相应的梯度,并采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整;输出单元,用于输出训练完毕的第二分类模型。

可选的,确定第一分类模型,第二分类模型以及第三分类模型之前,生成单元进一步用于:

将训练完成后的第一分类模型直接进行量化,将量化后的第一分类模型作为第二分类模型;

或者,

将带有模拟量化的噪声数据的样本数据集合,送入所述第一分类模型进行训练,输出训练完成后的第一分类模型,作为第二分类模型。

可选的,将所述样本数据,所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果,训练单元用于:

将所述样本数据、所述第一分类模型和所述第二分类模型在分类过程中使用的各个标签、所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果,所述第三处理结果表征所述第一处理结果是否由所述第一分类模型输出,以及所述第二处理结果是否由所述第二分类模型输出。

可选的,采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整,训练单元用于:

采用所述梯度,对所述第一分类模型以及所述第二分类模型分别进行参数调整,使得参数调整后的第一分类模型再次输出的第一处理结果与参数调整后的第二分类模型再次输出的第二处理结果趋于相同;

采用所述梯度,对所述第三分类模型进行参数调整,使得参数调整后的第三分类模型再次输出的第三处理结果的判断准确率提升。

可选的,输出训练完毕的第二分类模型,输出单元用于:

当所述第三分类模型不能正确区分所述第一处理结果和所述第二处理结果时,输出训练完毕的第二分类模型;

或者,

当所述第一处理结果和所述第二处理结果之间的误差小于预设的误差门限值时,输出训练完毕的第二分类模型。

第三方面,一种基于生成对抗网络的量化模型训练的装置,包括:

存储器,用于存储可执行计算机程序;

处理器,用于读取并执行所述存储器中存储的可执行指令,以实现如上述第一方面中任一项所述的方法。

第四方面,一种计算机可读存储介质,当所述存储介质中的指令由处理器执行时,使得所述处理器能够执行如上述第一方面中任一项所述的方法。

本公开实施例中,服务器采用对抗方式对第一分类模型、第一分类模型量化后的第二分类模型,以及第三分类模型进行训练,即将样本数据分别输入第一分类模型和第二分类模型,再将获得的第一处理结果和第二处理结果,以及上述样本数据,输入第三分类模型,得到第三处理结果,再通过对第三处理结果与真实结果进行比对获得相应的梯度,并采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整,最终,输出训练完毕的第二分类模型。这样,在与第三分类模型对抗的过程中,使得第一分类模型与第二分类模型输出结果尽可能相同,既降低了第一分类模型在低比特量化时的性能损失,也保证了第二分类模型的精度不受影响,从而有效确保了第二分类模型的性能。

附图说明

图1为本公开实施例中基于生成对抗网络的量化模型训练的流程示意图;

图2为本公开实施例中第一分类模型量化为第二分类模型的一种方法示意图;

图3为本公开实施例中第一分类模型量化为第二分类模型的一种方法示意图;

图4为本公开实施例中第一分类模型输出第一处理结果的示意图;

图5为本公开实施例中第二分类模型输出第二处理结果的示意图;

图6为本公开实施例中第三分类模型输出第三处理结果的示意图;

图7为本公开实施例中参数调整后的第一分类模型输出第一处理结果的示意图;

图8为本公开实施例中参数调整后的第二分类模型输出第二处理结果的示意图;

图9为本公开实施例中参数调整后的第三分类模型输出第三处理结果的示的意图;

图10为本公开实施例中服务器的逻辑架构示意图;

图11为本公开实施例中服务器的实体架构示意图。

具体实施方式

为了降低模型在低比特量化时的性能损失,本公开实施例中,服务器将样本数据,以及第一分类模型基于样本数据输出的第一处理结果和第二分类模型基于样本数据输出的第二处理结果,送入第三分类模型,得到第三处理结果,并基于第三处理结果对第一分类模型,第二分类模型以及第三分类模型进行参数调整,从而实现对抗性训练。

下面结合附图对本公开优选的实施方式作出进一步详细说明。

参阅图1所示,本公开实施例中,基于生成对抗网络的量化模型训练具体流程如下:

步骤100:服务器确定第一分类模型,第二分类模型以及第三分类模型。

本公开实施例中,第二分类模型是第一分类模型量化后的分类模型,第三分类模型是与第一分类模型不同的分类模型。

具体实施例,基于第一分类模型获得第二分类模型的方式,包含但不限于以下两种中任意一种:

A、服务器将训练完成后的第一分类模型直接进行量化,将量化后的第一分类模型作为第二分类模型。

例如,参阅图2所示,以一个训练完成的分类模型G1为例,假设G1记为第一分类模型,服务器可以采取将G1直接量化的方法,得到量化后的分类模型G2,将G2作为第二分类模型;

B、服务器将带有模拟量化的噪声数据的样本数据集合,送入所述第一分类模型进行训练,输出训练完成后的第一分类模型,作为第二分类模型。

例如,参阅图3所示,仍以一个训练完成的分类模型G1为例,假设G1记为第一分类模型,服务器可以将带有模拟量化噪声数据的样本数据集合,送入G1进行训练,得到重新训练完成后的分类模型G2,将G2作为第二分类模型。其中,所述带有模拟量化噪声数据的样本数据集合,指包含有原始图像信息和将原始图像信息模拟量化后的图像信息的样本数据集合。

可选的,本公开实施例中,以图像信息作为样本数据进行模型训练,而实际应用中样本数据包括但不限于图像信息,音频信息,文字信息等等。本公开实施例中仅以图像信息为例进行说明,后续实施例中将不再赘述。

步骤110:服务器将样本数据分别输入所述第一分类模型和所述第二分类模型,得到相应的第一处理结果和第二处理结果。

具体的,服务器将样本数据输入到第一分类模型中,获得第一分类模型输出的第一处理结果,服务器将样本数据输入到第二分类模型中,获得第二处理结果。

参阅图4所示,以一个松树图片作为样本数据为例,服务器将松树图片输入到第一分类模型G1中,第一分类模型G1输出第一处理结果O1为:松树。

参阅图5所示,服务器将同样的样本数据输入到第二分类模型G2中,第二分类模型G2输出第二分类结果O2为:树。

步骤120:服务器将所述样本数据,所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果。

具体的,服务器在获得所述第一处理结果和第二处理结果后,将所述样本数据与所述第一处理结果和所述第二处理结果送入第三分类模型,所述第三分类模型输出第三处理结果,其中,所述第三处理结果表征所述第一处理结果是否由所述第一分类模型输出,以及所述第二处理结果是否由所述第二分类模型输出。

参阅图6所示,假设第一处理结果为O1,第二处理结果为O2,服务器将O1、O2和样本数据松树送入第三分类模型,输出的第三处理结果P为:O1由第一分类模型G1输出,O2由第二分类模型G2输出。

步骤130:服务器将第三处理结果与真实结果进行误差比较,获得相应的梯度,并采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整;

具体的,服务器将获得的第三处理结果,与真实结果进行比较,获得相应的误差,基于所述误差,获得相应的梯度,采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整;

进一步的,服务器令参数调整后的第一分类模型再次输出的第一处理结果与参数调整后的第二分类模型再次输出的第二处理结果趋于相同,同时,令参数调整后的第三分类模型再次输出的第三处理结果的判断准确率提升。

仍以上述第一分类模型G1,第二分类模型G2,第三分类模型D为例,将第三分类模型输出的结果P与真实结果进行误差比较,基于所述误差,采用梯度下降法计算相应的梯度,采用所述梯度,对G1、G2和D进行参数调整。

仍以上述第一分类模型G1,第二分类模型G2,第三分类模型D为例,假设向G1输入一张图片,且所述图片的特征向量为[3,H,W],则G1输出一个概率分布Q[N_C],其中,针对不同的图像分类任务,所述概率分布具有不同的类别,例如,猫狗二分类中,假设0.9表示猫,0.1表示狗,则概率分布为[0.9,0.1],G2的输入和输出与G1相同,向D中输入一张图片,且所述图片的特征向量为[3,H,W],概率分布Q[N_C]以及一个表示图像原始类别的标签,其中,假设输入的概率分布为Q[0.1],输入的图像原始类别的标签为狗,则D输出概率分布P[A],P[A]表示Q来自于G1还是来自于G2的概率,基于获得的输出结果以及交叉熵损失函数(cross-entropy loss),获得损失,将所述损失返回至D以及G1或G2,对所述损失采用梯度下降法进行计算,获得计算结果,基于所述计算结果,对G1,G2,D进行参数调整。

采用上述方式对第一分类模型G1、第二分类模型G2和第三分类模型D进行参数调整,得到经过参数调整后的第一分类模型G1a,第二分类模型G2a,第三分类模型Da。

步骤140:第三分类模型是否可以正确区分第一处理结果和第二处理结果,若是,则执行步骤110,否则,执行步骤150。

具体的,当第三分类模型可以正确区分第一处理结果和第二处理结果,则说明第一分类模型与第二分类模型之间的性能损失仍存在,需要继续训练,调整参数;当第三分类不能正确区分第一处理结果和第二处理结果时,则说明第一分类模型与第二分类模型之间的性能损失已经可以视为不存在,可以继续执行下一个步骤。

仍以上述第一分类模型G1,第二分类模型G2,第三分类模型D为例,第三分类模型D输出的第三处理结果P,正确的区分了第一处理结果O1由第一分类模型G1输出,第二处理结果O2由第二分类模型G2输出。因此,需要对第一分类模型G1,第二分类模型G2,第三分类模型D进行参数调整,得到经过参数调整后的第一分类模型G1a,第二分类模型G2a,第三分类模型Da。进而,采用经过参数调整后的第一分类模型G1a,第二分类模型G2a,第三分类模型Da从步骤110开始执行流程。

重新执行步骤110,参阅图7所示,以一个苹果图片作为样本数据为例,服务器将苹果图片输入到第一分类模型G1a中,第一分类模型G1a输出第一处理结果O1a为:苹果。

参阅图8所示,服务器将同样的样本数据输入到第二分类模型G2a中,第二分类模型G2a输出第二分类结果O2a为:苹果。

重新执行步骤120,参阅图9所示,假设第一处理结果为O1a,第二处理结果为O2a,服务器将O1a、O2a和样本数据苹果送入第三分类模型,输出的第三处理结果Pa为:O1a由第一分类模型G1a输出,O2a由第一分类模型G1a输出。

重新执行步骤130,对第一分类模型G1a、第二分类模型G2a和第三分类模型Da进行参数调整,得到经过参数调整后的第一分类模型G1b,第二分类模型G2b,第三分类模型Db。

第三分类模型Da不能正确区分第一处理结果O1a和第二处理结果O2a,因此不进行循环,开始执行下一个步骤。

步骤150:服务器输出训练完毕的第二分类模型。

具体的,当第三分类模型或者第一处理结果和第二处理结果之间满足以下任意一种情况,服务器输出训练完毕的第二分类模型:

1)当第三分类模型不能正确区分第一处理结果和第二处理结果时,则说明第一分类模型与第二分类模型之间的性能损失已经可以视为不存在,进而可以说明此时第二分类模型已经训练完成,可以输出训练完成的第二分类模型。

具体的,服务器可以将经过参数调整的第二分类模型G2b,作为最终的模型输出。

2)当第一处理结果和第二处理结果之间的误差小于预设的误差门限值时,也可以说明第一分类模型与第二分类模型之间的性能损失已经可以视为不存在,进而可以说明此时第二分类模型已经训练完成,可以输出训练完毕的第二分类模型。

在本公开实施例中,预设只要第一处理结果和第二处理结果输出为同一小类时,认定训练完成。

例如,对同一样本数据:苹果图片,若第一处理结果为:苹果,第二处理结果为:苹果。此时,第一处理结果与第二处理结果为同一小类。即使第三分类模型仍可以正确区分第一处理结果和第二处理结果,但是仍可以直接将经过参数调整的第二分类模型G2c,作为最终的模型输出。

上述步骤110—步骤130为一个对抗过程,所述对抗过程指:在循环执行步骤110—步骤130的过程中,服务器令第一分类模型,以及第二分类模型通过参数调整,使第一处理结果和第二处理结果趋于相同,而在循环执行步骤110—步骤130的过程中,令第三模型可以正确的区分第一处理结果是否由第一分类模型输出以及第二处理结果是否有第二分类模型输出,即不断提高第三处理结果的判断准确率。

可以看出,在循环执行步骤110—步骤130的过程中,第一分类模型和第二分类模型参数调整目的是阻止第三分类模型正确区分第一分类模型输出的第一处理结果和第二分类模型输出的第二处理结果。第三分类模型参数调整的目的是尽可能的正确区分第一分类模型输出的第一处理结果和第二分类模型输出的第二处理结果,因此将这个过程称为对抗过程。

基于同一发明构思,参阅图10所示,本公开实施例提供一种基于生成对抗网络的量化模型训练的装置(如,一种服务器),包括:

生成单元1001,用于确定第一分类模型,第二分类模型以及第三分类模型,其中所述第二分类模型为所述第一分类模型经过量化后得到的模型,所述第三分类模型为与第一模型不同的模型;

训练单元1002,用于基于样本数据集合,采用循环迭代方式执行以下步骤,直到训练完毕为止:

将样本数据分别输入所述第一分类模型和所述第二分类模型,得到相应的第一处理结果和第二处理结果;

将所述样本数据,所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果;

将第三处理结果与真实结果进行误差比较,获得相应的梯度,并采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整;输出单元1003,用于输出训练完毕的第二分类模型。

可选的,确定第一分类模型,第二分类模型以及第三分类模型之前,生成单元1001进一步用于:

将训练完成后的第一分类模型直接进行量化,将量化后的第一分类模型作为第二分类模型;

或者,

将带有模拟量化的噪声数据的样本数据集合,送入所述第一分类模型进行训练,输出训练完成后的第一分类模型,作为第二分类模型。

可选的,将所述样本数据,所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果,训练单元1002用于:

将所述样本数据、所述第一分类模型和所述第二分类模型在分类过程中使用的各个标签、所述第一处理结果和所述第二处理结果,输入所述第三分类模型,得到第三处理结果,所述第三处理结果表征所述第一处理结果是否由所述第一分类模型输出,以及所述第二处理结果是否由所述第二分类模型输出。

可选的,采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整,训练单元1002用于:

采用所述梯度,对所述第一分类模型以及所述第二分类模型分别进行参数调整,使得参数调整后的第一分类模型再次输出的第一处理结果与参数调整后的第二分类模型再次输出的第二处理结果趋于相同;

采用所述梯度,对所述第三分类模型进行参数调整,使得参数调整后的第三分类模型再次输出的第三处理结果的判断准确率提升。

可选的,输出训练完毕的第二分类模型,输出单元1003用于:

当所述第三分类模型不能正确区分所述第一处理结果和所述第二处理结果时,输出训练完毕的第二分类模型;

或者,

当所述第一处理结果和所述第二处理结果之间的误差小于预设的误差门限值时,输出训练完毕的第二分类模型。

基于同一发明构思,参阅图11所示,本公开实施例提供一种服务器,包括:

存储器1101,用于存储可执行计算机程序;

处理器1102,用于读取并执行所述存储器中存储的可执行指令,以实现上述各个实施例中服务器执行的任意一种方法。

基于同一发明构思,本公开实施例提供一种计算机可读存储介质,当所述存储介质中的指令由处理器执行时,使得所述处理器能够执行上述各个实施例中服务器执行的任意一种方法。

综上所述,本公开实施例中,服务器采用对抗方式对第一分类模型、第一分类模型量化后的第二分类模型,以及第三分类模型进行训练,即将样本数据分别输入第一分类模型和第二分类模型,再将获得的第一处理结果和第二处理结果,以及上述样本数据,输入第三分类模型,得到第三处理结果,再通过对第三处理结果与真实结果进行比对获得相应的梯度,并采用所述梯度对第一分类模型,第二分类模型以及第三分类模型进行参数调整,最终,输出训练完毕的第二分类模型。这样,在与第三分类模型对抗的过程中,使得第一分类模型与第二分类模型输出结果尽可能相同,既降低了第一分类模型在低比特量化时的性能损失,也保证了第二分类模型的精度不受影响,从而有效确保了第二分类模型的性能。

本领域内的技术人员应明白,本公开的实施例可提供为方法、系统、或计算机程序产品。因此,本公开可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本公开可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。

本公开是参照根据本公开实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。

这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。

这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。

尽管已描述了本公开的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本公开范围的所有变更和修改。

显然,本领域的技术人员可以对本公开实施例进行各种改动和变型而不脱离本公开实施例的精神和范围。这样,倘若本公开实施例的这些修改和变型属于本公开权利要求及其等同技术的范围之内,则本公开也意图包含这些改动和变型在内。

相关技术
  • 一种基于生成对抗网络的量化模型训练的方法及装置
  • 一种基于生成对抗网络的模型训练方法及设备
技术分类

06120112280979