掌桥专利:专业的专利平台
掌桥专利
首页

一种模型训练方法及装置

文献发布时间:2024-01-17 01:16:56


一种模型训练方法及装置

技术领域

本公开涉及人工智能技术领域,尤其涉及一种模型训练方法及装置。

背景技术

近年来,强化学习的应用场景逐渐增多,主要的应用场景如下:1.机器人控制:强化学习可以用于控制机器人,以实现机器人自主学习和行动;2.自动驾驶:强化学习可以用于自动驾驶,使车辆能够在复杂的环境中安全驾驶;3.游戏:强化学习可以用于游戏,让游戏角色能够从自身的行为中学习,从而更好地控制游戏。4.无人机:强化学习可以用于控制无人机,让无人机能够自主学习和行动,以实现更好的空中控制。5.金融:强化学习可以用于金融交易系统,以提高金融交易系统的准确性和效率。

神经网络的强化学习场景中常用的是深度强化学习,深度强化学习是一种使用深度神经网络的强化学习方法,对复杂问题的解决较为友好,比如运行游戏,控制机器人,控制自动驾驶等。深度强化学习主要通过不断的学习和实践,通过模拟环境中的行为,来最大化未来的奖励。深度强化学习使用深度神经网络来实现,深度强化学习可以处理高维度和非线性的环境,并能够更好地学习和表现。但是,目前的强化学习方法也面临模型精度难以保证的技术困难。由于目前基于神经网络模型的强化学习方法往往需要建立精确的环境模型,且模型的准确性会随着环境的变化而变化,使得基于现有的强化学习方法训练得到的模型的预测准确性较差。另外,目前的强化学习算法存在一定的问题,包括强化学习训练过程中模型的收敛速度慢,在高维度状态空间的环境收敛效果更差,且存在模型过拟合的问题,这样,导致模型泛化能力差,且进一步导致在模型使用时,模型的预测精度较差。因此,亟需一种新的针对神经网络的强化学习训练方法。

发明内容

有鉴于此,本公开实施例提供了一种模型训练方法、装置、计算机设备及计算机可读存储介质,以解决现有技术中基于现有的强化学习方法训练得到的模型的预测准确性较差的问题。

本公开实施例的第一方面,提供了一种模型训练方法,所述方法包括:

获取原始训练样本集和增强训练样本集;其中,所述原始训练样本集包括原始样本图片和所述原始样本图片对应的参考标签,所述增强训练样本集包括增强样本图片;所述增强样本图片为根据所述原始样本图片所确定的;

利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型;

将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量;

将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量;

利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。

本公开实施例的第二方面,提供了一种模型训练装置,所述装置包括:

集合获取单元,用于获取原始训练样本集和增强训练样本集;其中,所述原始训练样本集包括原始样本图片和所述原始样本图片对应的参考标签,所述增强训练样本集包括增强样本图片;所述增强样本图片为根据所述原始样本图片所确定的;

第一训练单元,用于利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型;

第二训练单元,用于将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量;

第三训练单元,用于将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量;

模型调整单元,用于利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。

本公开实施例的第三方面,提供了一种计算机设备,包括存储器、处理器以及存储在存储器中并且可以在处理器上运行的计算机程序,该处理器执行计算机程序时实现上述方法的步骤。

本公开实施例的第四方面,提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机程序,该计算机程序被处理器执行时实现上述方法的步骤。

本公开实施例与现有技术相比存在的有益效果是:本公开实施例可以先获取原始训练样本集和增强训练样本集;其中,所述原始训练样本集包括原始样本图片和所述原始样本图片对应的参考标签,所述增强训练样本集包括增强样本图片;所述增强样本图片为根据所述原始样本图片所确定的。然后,可以利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型。接着,可以将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量。紧接着,可以将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量。最后,可以利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。可见,在本实施例中,先利用原始训练样本集对决策模型进行强化学习训练,接着,利用原始样本图片和增强样本图片对决策模型和泛化模型进行无监督的对比学习训练,这样,可以通过基于对比学习的强化学习方式对决策模型进行训练,可以让决策模型能够从泛化模型所得到的已有的经验中提取知识,并使得决策模型的训练过程更加鲁棒,能减少决策模型对训练样本集的过度拟合。并且,由于增强样本图片为根据原始样本图片所确定的,这样,本实施例可以实现在原始样本图片的基础上,通过原始样本图片的处理,新增得到增强样本图片,并利用基于对比学习的强化学习算法使得决策模型进行原始样本图片及增强样本图片的对比学习,从而可以使决策模型可以更好地适应变化后的图片的预测效果。综上所述,本实施例可以通过基于对比学习的强化学习方式对决策模型进行训练,提升决策模型的预测结果的准确性和精度。

附图说明

为了更清楚地说明本公开实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。

图1是本公开实施例的应用场景的场景示意图;

图2是本公开实施例提供的模型训练方法的流程图;

图3是本公开实施例提供的模型训练流程示意图;

图4是本公开实施例提供的模型训练装置的框图;

图5是本公开实施例提供的计算机设备的示意图。

具体实施方式

以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本公开实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本公开。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本公开的描述。

下面将结合附图详细说明根据本公开实施例的一种模型训练方法和装置。

在现有技术中,由于传统的的强化学习方法也面临模型精度难以保证的技术困难。由于目前基于神经网络模型的强化学习方法往往需要建立精确的环境模型,且模型的准确性会随着环境的变化而变化,使得基于现有的强化学习方法训练得到的模型的预测准确性较差。另外,目前的强化学习算法存在一定的问题,包括强化学习训练过程中模型的收敛速度慢,在高维度状态空间的环境收敛效果更差,且存在模型过拟合的问题,这样,导致模型泛化能力差,且进一步导致在模型使用时,模型的预测精度较差。因此,亟需一种新的针对神经网络的强化学习训练方法。

为了解决上述问题。本发明提供了一种模型训练方法,在本方法中,由于本实施例可以先利用原始训练样本集对决策模型进行强化学习训练,接着,利用原始样本图片和增强样本图片对决策模型和泛化模型进行无监督的对比学习训练,这样,可以通过基于对比学习的强化学习方式对决策模型进行训练,可以让决策模型能够从泛化模型所得到的已有的经验中提取知识,并使得决策模型的训练过程更加鲁棒,能减少决策模型对训练样本集的过度拟合。并且,由于增强样本图片为根据原始样本图片所确定的,这样,本实施例可以实现在原始样本图片的基础上,通过原始样本图片的处理,新增得到增强样本图片,并利用基于对比学习的强化学习算法使得决策模型进行原始样本图片及增强样本图片的对比学习,从而可以使决策模型可以更好地适应变化后的图片的预测效果。综上所述,本实施例可以通过基于对比学习的强化学习方式对决策模型进行训练,提升决策模型的预测结果的准确性和精度。

举例说明,本发明实施例可以应用到如图1所示的应用场景。在该场景中,可以包括终端设备1和服务器2。

终端设备1可以是硬件,也可以是软件。当终端设备1为硬件时,其可以是具有显示屏且支持与服务器2通信的各种电子设备,包括但不限于智能手机、平板电脑、膝上型便携计算机和台式计算机等;当终端设备1为软件时,其可以安装在如上该的电子设备中。终端设备1可以实现为多个软件或软件模块,也可以实现为单个软件或软件模块,本公开实施例对此不作限制。进一步地,终端设备1上可以安装有各种应用,例如数据处理应用、即时通信工具、社交平台软件、搜索类应用、购物类应用等。

服务器2可以是提供各种服务的服务器,例如,对与其建立通信连接的终端设备发送的请求进行接收的后台服务器,该后台服务器可以对终端设备发送的请求进行接收和分析等处理,并生成处理结果。服务器2可以是一台服务器,也可以是由若干台服务器组成的服务器集群,或者还可以是一个云计算服务中心,本公开实施例对此不作限制。

需要说明的是,服务器2可以是硬件,也可以是软件。当服务器2为硬件时,其可以是为终端设备1提供各种服务的各种电子设备。当服务器2为软件时,其可以是为终端设备1提供各种服务的多个软件或软件模块,也可以是为终端设备1提供各种服务的单个软件或软件模块,本公开实施例对此不作限制。

终端设备1与服务器2可以通过网络进行通信连接。网络可以是采用同轴电缆、双绞线和光纤连接的有线网络,也可以是无需布线就能实现各种通信设备互联的无线网络,例如,蓝牙(Bluetooth)、近场通信(Near Field Communication,NFC)、红外(Infrared)等,本公开实施例对此不作限制。

具体地,用户可以通过终端设备1输入原始训练样本集和增强训练样本集;终端设备1将原始训练样本集和增强训练样本集向服务器2发送;其中,所述原始训练样本集包括原始样本图片和所述原始样本图片对应的参考标签,所述增强训练样本集包括增强样本图片;所述增强样本图片为根据所述原始样本图片所确定的。服务器2存储有待训练的决策模型和泛化模型;服务器2可以先利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型;然后,服务器2可以将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量;接着,服务器2可以将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量;最后,服务器2可以利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。这样,由于本申请可以先利用原始训练样本集对决策模型进行强化学习训练,接着,利用原始样本图片和增强样本图片对决策模型和泛化模型进行无监督的对比学习训练,这样,可以通过基于对比学习的强化学习方式对决策模型进行训练,可以让决策模型能够从泛化模型所得到的已有的经验中提取知识,并使得决策模型的训练过程更加鲁棒,能减少决策模型对训练样本集的过度拟合。并且,由于增强样本图片为根据原始样本图片所确定的,这样,本实施例可以实现在原始样本图片的基础上,通过原始样本图片的处理,新增得到增强样本图片,并利用基于对比学习的强化学习算法使得决策模型进行原始样本图片及增强样本图片的对比学习,从而可以使决策模型可以更好地适应变化后的图片的预测效果。综上所述,本实施例可以通过基于对比学习的强化学习方式对决策模型进行训练,提升决策模型的预测结果的准确性和精度。

需要说明的是,终端设备1和服务器2以及网络的具体类型、数量和组合可以根据应用场景的实际需求进行调整,本公开实施例对此不作限制。

需要注意的是,上述应用场景仅是为了便于理解本公开而示出,本公开的实施方式在此方面不受任何限制。相反,本公开的实施方式可以应用于适用的任何场景。

图2是本公开实施例提供的一种模型训练方法的流程图。图2的一种模型训练方法可以由图1的终端设备或服务器执行。如图2所示,该模型训练方法包括:

S201:获取原始训练样本集和增强训练样本集。

在本实施例中,所述原始训练样本集可以包括原始样本图片和所述原始样本图片对应的参考标签。在本实施例中,原始样本图片可以理解为直接采集到的且需要执行预设处理任务的图像,原始样本图片对应的参考标签可以理解为原始样本图片对应的真实的任务处理结果。例如,预设处理任务可以为图片类别识别,则对应的参考标签可以为真实类别标签,举例来说,假设原始样本图片中包括“猫”的图案,则原始样本图片对应的真实类别标签(即参考标签)可以为“动物”;假设原始样本图片中包括“树”的图案,则原始样本图片对应的真实类别标签(即参考标签)可以为“植物”。又例如,预设处理任务可以为识别图中目标对象的关键点,则原始样本图片对应的参考标签可以为真实关键点标签。

需要说明的是,原始训练样本集中的原始样本图片可以为利用图像采集设备(例如照像机、智能手机等)采集得到的图片,或者,设备中预先存储的图片,又或者,可以是从互联网上下载得到的。在本实施例中,不对原始训练样本集中的原始样本图片的获取方式进行限定。

在本实施例中,所述增强训练样本集可以包括增强样本图片,或者,所述增强训练样本集可以包括增强样本图片和增强样本图片对应的参考标签。增强样本图片对应的参考标签可以理解为增强样本图片对应的真实的任务处理结果。其中,所述增强样本图片可以为根据所述原始样本图片所确定的。在一种实现方式中,所述增强样本图片可以为对所述原始样本图片进行预设处理所得到的;其中,所述预设处理的方式包括以下至少一种:裁剪、局部覆盖、增加图像噪声。可以理解的是,增强样本图片对应的参考标签与生成该增强样本图片所基于的原始样本图片对应的参考标签是相同的,例如,预设处理任务为图片类别识别,原始样本图片A中包括“树”的图案,则原始样本图片对应的参考标签可以为“植物”,而增强样本图片B为对原始样本图片A进行裁剪所得到的图片,且增强样本图片B中仍然保留了“树”的图案,因此,增强样本图片B中的参考标签也为“植物”。

S202:利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型。

在本实施例中,决策模型可以理解为需要进行训练的神经网络模型,可以理解的是,决策模型在训练完成后,决策模型可以用于在执行图像处理的任务。在本实施例中,决策模型可以为卷积神经网络、循环神经网络、Transformers模型、多层全联通网络、多层自注意力网络等。在本实施例中,不对决策模型的具体神经网络类型进行限定。

在本实施例中,可以利用原始训练样本集对决策模型进行有监督的强化学习训练。例如,在一种实现方式中可以采用DQN(Deep Q Network)、A3C(AsynchronousAdvantage Actor-Critic)、PPO(Proximal Policy Optimization)、TRPO(Trust RegionPolicy Optimization)、PG(Policy Gradient)等强化学习算法中的一种强化学习算法对决策模型进行有监督的强化学习训练。在决策模型完成强化学习训练后,便可以得到训练后的决策模型。

接下来,以DQN强化学习算法为例,介绍如何利用DQN强化学习算法和原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型。DQN(Deep Q Network)是一种基于神经网络的深度强化学习算法,可以用来解决回合制强化学习问题,可以通过学习一个价值函数来指导决策模型选择动作,从而达到最优化决策模型行为的目的。DQN算法可以用来解决多种类型的强化学习问题,包括但不限于游戏控制、机器人控制、自动驾驶等。DQN算法的核心思想是使用决策模型(即神经网络)来学习一个价值函数,从而指导决策模型选择动作,从而达到最优化决策模型行为的目的。DQN算法的核心思想是将环境的状态和动作映射到一个Q值(价值),从而指导决策模型选择动作。DQN算法使用决策模型来学习这个Q值,通过反向传播算法来更新决策模型的模型参数,从而使得决策模型能够学习最优的行为策略。

DQN算法的主要过程如下:(1)环境描述:首先,定义环境,包括环境状态、可能的动作、奖励等;(2)状态表示:将环境中的状态进行表示,以便决策模型可以接受和处理;3)构建模型:构建DQN模型(即决策模型),包括网络结构、优化器、loss函数等;(4)训练:使用经验回放(Experience Replay)和目标网络(Target Network)来收集经验,并训练决策模型以更新Q值;(5)测试:使用完成训练的决策模型,在测试环境中运行,确定算法的训练效果。

需要说明的是,为了提升决策模型的强化学习的训练效果,除了使用所述原始训练样本集对决策模型进行强化学习训练,还可以使用少量的增强样本图片和增强样本图片对应的参考标签对决策模型进行强化学习训练。

S203:将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量。

在本实施例中,如图3所示,在对决策模型完成强化学习训练,得到训练后的决策模型之后,可以将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,以便通过训练后的决策模型对原始样本图片中的图像信息、语义信息等多维度信息进行提取,得到所述原始样本图片对应的第一特征向量。可以理解的是,原始样本图片对应的第一特征向量可以理解为通过训练后的决策模型所提取出的原始样本图片的特征向量(例如可以是低维表征向量);并且,原始样本图片对应的第一特征向量包括原始样本图片中的语义信息、图像信息等多维度信息,同时基本没有含有信道或噪声等无关信息。

S204:将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量。

在本实施例中,泛化模型可以理解为用于配合决策模型进行训练的神经网络模型;具体地,可以利用泛化模型对决策模型进行无监督的对比学习训练。在本实施例的一种实现方式中,所述决策模型和所述泛化模型的神经网络架构是相同的,且,所述决策模型和所述泛化模型的初始化模型参数是相同的;也就是说,在初始化时,决策模型和泛化模型是两个完全相同的神经网络模型,即神经网络架构和模型参数均是相同的,在后续的模型参数调整过程中,决策模型和泛化模型的神经网络架构是相同的,仅两者的模型参数可能有所不同。

在本实施例中,可以将原始训练样本集中的原始样本图片和增强训练样本集中的增强样本图片分别输入泛化模型。具体地,如图3所示,可以将原始样本图片输入泛化模型,得到原始样本图片对应的第二特征向量;原始样本图片对应的第二特征向量可以理解为通过泛化模型所提取出的原始样本图片的特征向量(例如可以是低维表征向量),并且,原始样本图片对应的第二特征向量包括原始样本图片中的语义信息、图像信息等多维度信息,同时基本没有含有信道或噪声等无关信息。以及,,如图3所示,可以将增强样本图片输入泛化模型,得到增强样本图片对应的第三特征向量;增强样本图片对应的第三特征向量可以理解为通过泛化模型所提取出的增强样本图片的特征向量(例如可以是低维表征向量),并且,增强样本图片对应的第三特征向量包括增强样本图片中的语义信息、图像信息等多维度信息,同时基本没有含有信道或噪声等无关信息。这样,泛化模型便可以得到原始样本图片对应的第二特征向量和增强样本图片对应的第三特征向量,并且,可以利用原始样本图片对应的第二特征向量和增强样本图片对应的第三特征向量,使得决策模型可以进行原始样本图片及增强样本图片的对比学习,即决策模型能够从泛化模型所得到的已有的经验(即第二特征向量和所述第三特征向量)中提取知识,并使得决策模型的训练过程更加鲁棒,能减少决策模型对训练样本集的过度拟合。

S205:利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。

在得到第一特征向量、第二特征向量和第三特征向量后,可以利用第一特征向量、第二特征向量和第三特征向量之间的关系,对所述训练后的决策模型的模型参数进行调整。具体地,由于第一特征向量、第二特征向量均为从原始样本图片中所提取到的,而第三特征向量为从增强样本图片中提取到的,因此,第一特征向量和第二特征向量的距离应该更加接近,即第一特征向量和第二特征向量应该更相似,而第一特征向量和第三特征向量的距离应该更加远,即第一特征向量和第三特征向量应该更不相似。故,在得到所述第一特征向量、所述第二特征向量和所述第三特征向量后,可以根据第一特征向量分别与第二特征向量、第三特征向量之间的距离来对训练后的决策模型的模型参数进行调整,得到目标决策模型。这样,目标决策模型能够从泛化模型所得到的已有的经验(即第二特征向量和所述第三特征向量)中提取知识,可以借助之前学习的经验来推断出新环境中的行为,更容易把握新环境。并且,由于目标决策模型是先经过强化学习,再经过对比学习得到的,因此,基于对比方式的强化学习还可以让目标决策模型可以更高效地训练,因为该训练方法使用了对比学习和扩增数据(即增强样本图片)进行训练,能相对减少对真实互动数据的依赖和采集。

本公开实施例与现有技术相比存在的有益效果是:本公开实施例可以先获取原始训练样本集和增强训练样本集;其中,所述原始训练样本集包括原始样本图片和所述原始样本图片对应的参考标签,所述增强训练样本集包括增强样本图片;所述增强样本图片为根据所述原始样本图片所确定的。然后,可以利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型。接着,可以将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量。紧接着,可以将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量。最后,可以利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。可见,在本实施例中,先利用原始训练样本集对决策模型进行强化学习训练,接着,利用原始样本图片和增强样本图片对决策模型和泛化模型进行无监督的对比学习训练,这样,可以通过基于对比学习的强化学习方式对决策模型进行训练,可以让决策模型能够从泛化模型所得到的已有的经验中提取知识,并使得决策模型的训练过程更加鲁棒,能减少决策模型对训练样本集的过度拟合。并且,由于增强样本图片为根据原始样本图片所确定的,这样,本实施例可以实现在原始样本图片的基础上,通过原始样本图片的处理,新增得到增强样本图片,并利用基于对比学习的强化学习算法使得决策模型进行原始样本图片及增强样本图片的对比学习,从而可以使决策模型可以更好地适应变化后的图片的预测效果。综上所述,本实施例可以通过基于对比学习的强化学习方式对决策模型进行训练,提升决策模型的预测结果的准确性和精度。

可以理解的是,本实施例所提供的训练方法,可以基于对比方式的强化学习的优势,可以让机器学习系统(例如决策模型)能够从已有的经验中提取知识,并使得该过程更加鲁棒,能减少模型(例如决策模型)对训练数据的过度拟合。让机器学习系统(例如决策模型)更容易把握新环境,因为机器学习系统(例如决策模型)可以借助之前学习的经验来推断出新环境中的行为。基于对比方式的强化学习还可以让机器学习系统(例如决策模型)更高效地训练,因为该方法使用了对比学习和扩增数据,能相对减少对真实互动数据的依赖和采集。

在一些实施例中,S202“利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型”的步骤,可以包括以下步骤:

S202a:将所述原始训练样本集中的原始样本图片输入所述决策模型,得到所述原始样本图片对应的第一特征向量。

在本实施例中,,如图3所示,可以先将原始训练样本集中的原始样本图片先输入决策模型,决策模型可以输出原始样本图片对应的第一特征向量。

S202b:将所述原始样本图片对应的第一特征向量输入预设的分类器或回归器,得到所述原始样本图片对应的预测标签。

如图3所示,在获取到原始样本图片对应的第一特征向量后,可以将所述原始样本图片对应的第一特征向量输入预设的分类器或回归器,得到所述原始样本图片对应的预测标签。其中,原始样本图片对应的预测标签可以理解为分类器或回归器基于第一特征向量分析所确定预测标签。需要说明的是,在执行分类任务时,可以采用分类器确定预测标签;在执行确定对象位置或图像区域的任务时,可以采用回归器确定预测标签。

S202c:根据所述原始样本图片对应的预测标签和参考标签,确定所述原始样本图片对应的奖励值。

在本实施例中,在得到原始样本图片对应的预测标签和参考标签后,可以根据原始样本图片对应的预测标签和参考标签,确定所述目标推荐策略对应的奖励值。可以理解的是,奖励值可以反映预测标签的情况,若预测标签越接近参考标签,则奖励值越高,反之,若预测标签与参考标签的差距越大,则奖励值越低。例如,如图2所示,若预测标签与参考标签相同,则原始样本图片对应的奖励值为1,若预测标签与参考标签不相同,则原始样本图片对应的奖励值为0。

需要说明的是,由于奖励值可以反映预测标签和参考标签的相似情况,因此,可以利用奖励值不断完善决策模型,使决策模型的训练数据能够自主实现目标的直接经验来源。决策模型可以通过接受奖励值判断决策模型所预测的标签的准确性,从而通过选择收益高更高的行为使决策模型趋于目标状态。

S202d:利用所述原始样本图片、所述原始样本图片对应的参考标签、所述原始样本图片对应的预测标签和所述原始样本图片对应的奖励值,对所述决策模型的模型参数进行调整,得到训练后的决策模型。

在获取到原始样本图片对应的参考标签、预测标签和所述原始样本图片对应的奖励值后,可以利用所述原始样本图片、所述原始样本图片对应的参考标签、预测标签和所述原始样本图片对应的奖励值,对决策模型的模型参数进行调整,以便可以使决策模型可以更好地适应变化后的图片的预测效果。

具体地,在一种实现方式中,可以先根据原始样本图片对应的参考标签、预测标签,计算强化学习损失值。然后,可以利用强化学习损失值和所述原始样本图片对应的奖励值,对决策模型的模型参数进行调整。在本实施例中,对决策模型的模型参数进行调整的方式可以采用预设策略优化算法。在一种实现方式中,所述预设策略优化算法可以包括以下至少一种:策略梯度算法、演员-批评家算法、PPO近端策略优化算法。即可以采用预设策略优化算法,利用所述原始样本图片、所述原始样本图片对应的参考标签、所述原始样本图片对应的预测标签和所述原始样本图片对应的奖励值更新决策模型的模型参数。

在一些实施例中,S205“利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型”的步骤,可以包括以下步骤:

S205a:利用所述第一特征向量、所述第二特征向量和所述第三特征向量,分别对所述训练后的决策模型的模型参数和所述泛化模型的模型参数进行调整,得到调整后的决策模型和调整后的泛化模型。

作为一种示例,在本实施例中,可以先根据所述第一特征向量、所述第二特征向量和所述第三特征向量,确定对比学习损失值。例如,可以利用三元组损失(triplet loss)函数或lifted structure损失函数,计算所述第一特征向量、所述第二特征向量和所述第三特征向量的对比学习损失值。

然后,可以利用所述对比学习损失值分别对所述训练后的决策模型的模型参数和所述泛化模型的模型参数进行调整,得到调整后的决策模型和调整后的泛化模型。即,可以利用对比学习损失值对所述训练后的决策模型的模型参数进行调整,得到调整后的决策模型,以及,可以利用对比学习损失值对所述泛化模型的模型参数进行调整,得到调整后的泛化模型。

S205b:若所述调整后的决策模型的模型参数满足预设条件,则将所述调整后的决策模型作为所述目标决策模型。

在一种实现方式中,所述预设条件可以为调整后的决策模型的模型参数为收敛状态,或者,所述决策模型的训练次数满足预设次数阈值。因此,在本实施例中,若所述调整后的决策模型的模型参数满足预设条件,说明调整后的决策模型的模型参数已经达到要求,则将所述调整后的决策模型作为所述目标决策模型。

在一些实施例中,在所述S205a的步骤之后,所述方法还包括:

步骤a:若所述调整后的决策模型的模型参数不满足所述预设条件,根据所述调整后的决策模型的模型参数和所述调整后的泛化模型的模型参数,确定目标模型参数。

步骤b:根据所述目标模型参数,对所述调整后的泛化模型的模型参数进行更新,得到更新后的泛化模型;以及,继续执行所述利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型的步骤,直至所述决策模型的模型参数满足所述预设条件。

在本实施例中,若所述调整后的决策模型的模型参数不满足所述预设条件,说明调整后的决策模型的模型参数未达到要求,因此,需要继续对决策模型进行训练。

作为一种示例,可以先根据所述调整后的决策模型的模型参数和所述调整后的泛化模型的模型参数,确定目标模型参数。在一种实现方式中,可以先将所述调整后的泛化模型的模型参数与预设的第一权重值的乘积作为第一参数值,例如,假设调整后的泛化模型的模型参数为q0,预设的第一权重值为m,则第一参数值为m*q0。然后,可以将所述调整后的决策模型的模型参数与预设的第二权重值的乘积作为第二参数值,其中,所述第二权重值是根据所述第一权重值所确定的;例如,假设调整后的决策模型的模型参数为p1,预设的第一权重值为m,预设的第二权重值为m-1,则第二参数值为(m-1)*p1。接着,可以将所述第一参数值与所述第二参数值之和作为目标模型参数,例如,第一参数值为m*q0,第二参数值为(m-1)*p1,则目标模型参数q1为m*q0+(m-1)*p1,即q1=m*q0+(m-1)*p1。其中,m可以为0到1之间的小数,例如,m可以取值为0.8到0.9之间的任一数字,比如m可以为0.8、0.9、0.85等。

在确定目标模型参数后,可以根据所述目标模型参数,对所述调整后的泛化模型的模型参数进行更新,得到更新后的泛化模型。也就是说,可以将目标模型参数作为更新后的泛化模型的模型参数。以及,继续执行所述利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型的步骤,直至所述决策模型的模型参数满足所述预设条件;即,将调整后的决策模型作为S202、S203中的决策模型,将更新后的泛化模型作为S204中的泛化模型,以及,继续执行S202-S205的步骤,直至所述决策模型的模型参数满足所述预设条件。

上述所有可选技术方案,可以采用任意结合形成本公开的可选实施例,在此不再一一赘述。

下述为本公开装置实施例,可以用于执行本公开方法实施例。对于本公开装置实施例中未披露的细节,请参照本公开方法实施例。

图4是本公开实施例提供的模型训练装置的示意图。如图4所示,该模型训练装置包括:

集合获取单元401,用于获取原始训练样本集和增强训练样本集;其中,所述原始训练样本集包括原始样本图片和所述原始样本图片对应的参考标签,所述增强训练样本集包括增强样本图片;所述增强样本图片为根据所述原始样本图片所确定的;

第一训练单元402,用于利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型;

第二训练单元403,用于将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量;

第三训练单元404,用于将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量;

模型调整单元405,用于利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。

可选的,所述增强样本图片为对所述原始样本图片进行预设处理所得到的;其中,所述预设处理的方式包括以下至少一种:裁剪、局部覆盖、增加图像噪声。

可选的,所述第一训练单元402,具体用于:

将所述原始训练样本集中的原始样本图片输入所述决策模型,得到所述原始样本图片对应的第一特征向量;

将所述原始样本图片对应的第一特征向量输入预设的分类器或回归器,得到所述原始样本图片对应的预测标签;

根据所述原始样本图片对应的预测标签和参考标签,确定所述原始样本图片对应的奖励值;

利用所述原始样本图片、所述原始样本图片对应的参考标签、所述原始样本图片对应的预测标签和所述原始样本图片对应的奖励值,对所述决策模型的模型参数进行调整,得到训练后的决策模型。

可选的,所述模型调整单元405,具体用于:

利用所述第一特征向量、所述第二特征向量和所述第三特征向量,分别对所述训练后的决策模型的模型参数和所述泛化模型的模型参数进行调整,得到调整后的决策模型和调整后的泛化模型;

若所述调整后的决策模型的模型参数满足预设条件,则将所述调整后的决策模型作为所述目标决策模型。

可选的,所述模型调整单元405,具体用于:

根据所述第一特征向量、所述第二特征向量和所述第三特征向量,确定对比学习损失值;

利用所述对比学习损失值分别对所述训练后的决策模型的模型参数和所述泛化模型的模型参数进行调整,得到调整后的决策模型和调整后的泛化模型。

可选的,所述装置还包括第四训练单元,用于:

若所述调整后的决策模型的模型参数不满足所述预设条件,根据所述调整后的决策模型的模型参数和所述调整后的泛化模型的模型参数,确定目标模型参数;

根据所述目标模型参数,对所述调整后的泛化模型的模型参数进行更新,得到更新后的泛化模型;以及,继续执行所述利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型的步骤,直至所述决策模型的模型参数满足所述预设条件。

可选的,所述第四训练单元,具体用于:

将所述调整后的泛化模型的模型参数与预设的第一权重值的乘积作为第一参数值;

将所述调整后的决策模型的模型参数与预设的第二权重值的乘积作为第二参数值;其中,所述第二权重值是根据所述第一权重值所确定的;

将所述第一参数值与所述第二参数值之和作为目标模型参数。

可选的,所述决策模型和所述泛化模型的神经网络架构是相同的,且,所述决策模型和所述泛化模型的初始化模型参数是相同的。

本公开实施例与现有技术相比存在的有益效果是:本公开实施例提供了一种模型训练装置,所述装置包括:集合获取单元,用于获取原始训练样本集和增强训练样本集;其中,所述原始训练样本集包括原始样本图片和所述原始样本图片对应的参考标签,所述增强训练样本集包括增强样本图片;所述增强样本图片为根据所述原始样本图片所确定的;第一训练单元,用于利用所述原始训练样本集对决策模型进行强化学习训练,得到训练后的决策模型;第二训练单元,用于将所述原始训练样本集中的原始样本图片输入所述训练后的决策模型,得到所述原始样本图片对应的第一特征向量;第三训练单元,用于将所述原始训练样本集中的原始样本图片和所述增强训练样本集中的增强样本图片分别输入泛化模型,得到所述原始样本图片对应的第二特征向量和所述增强样本图片对应的第三特征向量;模型调整单元,用于利用所述第一特征向量、所述第二特征向量和所述第三特征向量,对所述训练后的决策模型的模型参数进行调整,得到目标决策模型。可见,在本实施例中,先利用原始训练样本集对决策模型进行强化学习训练,接着,利用原始样本图片和增强样本图片对决策模型和泛化模型进行无监督的对比学习训练,这样,可以通过基于对比学习的强化学习方式对决策模型进行训练,可以让决策模型能够从泛化模型所得到的已有的经验中提取知识,并使得决策模型的训练过程更加鲁棒,能减少决策模型对训练样本集的过度拟合。并且,由于增强样本图片为根据原始样本图片所确定的,这样,本实施例可以实现在原始样本图片的基础上,通过原始样本图片的处理,新增得到增强样本图片,并利用基于对比学习的强化学习算法使得决策模型进行原始样本图片及增强样本图片的对比学习,从而可以使决策模型可以更好地适应变化后的图片的预测效果。综上所述,本实施例可以通过基于对比学习的强化学习方式对决策模型进行训练,提升决策模型的预测结果的准确性和精度。

可以理解的是,本实施例所提供的训练方法,可以基于对比方式的强化学习的优势,可以让机器学习系统(例如决策模型)能够从已有的经验中提取知识,并使得该过程更加鲁棒,能减少模型(例如决策模型)对训练数据的过度拟合。让机器学习系统(例如决策模型)更容易把握新环境,因为机器学习系统(例如决策模型)可以借助之前学习的经验来推断出新环境中的行为。基于对比方式的强化学习还可以让机器学习系统(例如决策模型)更高效地训练,因为该方法使用了对比学习和扩增数据,能相对减少对真实互动数据的依赖和采集。

应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本公开实施例的实施过程构成任何限定。

图5是本公开实施例提供的计算机设备5的示意图。如图5所示,该实施例的计算机设备5包括:处理器501、存储器502以及存储在该存储器502中并且可以在处理器501上运行的计算机程序503。处理器501执行计算机程序503时实现上述各个方法实施例中的步骤。或者,处理器501执行计算机程序503时实现上述各装置实施例中各模块/模块的功能。

示例性地,计算机程序503可以被分割成一个或多个模块/模块,一个或多个模块/模块被存储在存储器502中,并由处理器501执行,以完成本公开。一个或多个模块/模块可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述计算机程序503在计算机设备5中的执行过程。

计算机设备5可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算机设备。计算机设备5可以包括但不仅限于处理器501和存储器502。本领域技术人员可以理解,图5仅仅是计算机设备5的示例,并不构成对计算机设备5的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如,计算机设备还可以包括输入输出设备、网络接入设备、总线等。

处理器501可以是中央处理模块(Central Processing Unit,CPU),也可以是其它通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其它可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。

存储器502可以是计算机设备5的内部存储模块,例如,计算机设备5的硬盘或内存。存储器502也可以是计算机设备5的外部存储设备,例如,计算机设备5上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。进一步地,存储器502还可以既包括计算机设备5的内部存储模块也包括外部存储设备。存储器502用于存储计算机程序以及计算机设备所需的其它程序和数据。存储器502还可以用于暂时地存储已经输出或者将要输出的数据。

所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能模块、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块、模块完成,即将装置的内部结构划分成不同的功能模块或模块,以完成以上描述的全部或者部分功能。实施例中的各功能模块、模块可以集成在一个处理模块中,也可以是各个模块单独物理存在,也可以两个或两个以上模块集成在一个模块中,上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。另外,各功能模块、模块的具体名称也只是为了便于相互区分,并不用于限制本公开的保护范围。上述系统中模块、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。

在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。

本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的模块及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本公开的范围。

在本公开所提供的实施例中,应该理解到,所揭露的装置/计算机设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/计算机设备实施例仅仅是示意性的,例如,模块或模块的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,多个模块或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或模块的间接耦合或通讯连接,可以是电性,机械或其它的形式。

作为分离部件说明的模块可以是或者也可以不是物理上分开的,作为模块显示的部件可以是或者也可以不是物理模块,即可以位于一个地方,或者也可以分布到多个网络模块上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。

另外,在本公开各个实施例中的各功能模块可以集成在一个处理模块中,也可以是各个模块单独物理存在,也可以两个或两个以上模块集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用软件功能模块的形式实现。

集成的模块/模块如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本公开实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,计算机程序可以存储在计算机可读存储介质中,该计算机程序在被处理器执行时,可以实现上述各个方法实施例的步骤。计算机程序可以包括计算机程序代码,计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。计算机可读介质可以包括:能够携带计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、电载波信号、电信信号以及软件分发介质等。需要说明的是,计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如,在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。

以上实施例仅用以说明本公开的技术方案,而非对其限制;尽管参照前述实施例对本公开进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本公开各实施例技术方案的精神和范围,均应包含在本公开的保护范围之内。

相关技术
  • 一种欺诈检测模型训练方法和装置及欺诈检测方法和装置
  • 一种神经网络模型训练方法及装置、文本标签确定方法及装置
  • 一种绿光估计模型训练方法及装置、影像合成方法及装置
  • 一种翻译模型的训练方法、翻译方法和装置
  • 一种检测模型的训练方法、装置及终端设备
  • 模型预训练方法、模型训练方法、数据处理方法及其装置
  • 预训练语言模型的训练方法、语言模型的训练方法及装置
技术分类

06120116106650