掌桥专利:专业的专利平台
掌桥专利
首页

一种深度强化学习训练方法及计算机可读存储介质

文献发布时间:2023-06-19 11:02:01


一种深度强化学习训练方法及计算机可读存储介质

技术领域

本发明涉及人工智能技术领域,尤其涉及一种深度强化学习训练方法及计算机可读存储介质。

背景技术

在强化学习领域,深度神经网络强大的学习能力使得智能体直接从高维连续环境中学习有效的控制策略成为可能。理论上,为了实现稳定的训练性能,神经网络一般要求训练数据满足独立同分布(i.i.d.)的特点,这在一般的强化学习范式中几乎是不可能成立的。强化学习边探索边学习的训练模式使得训练数据具有高度时间相关和非平稳的固有属性,由于神经网络在训练过程前后采用的训练数据分布不同,后期训练得到的权重很可能干扰甚至完全覆盖前期已经学习到的好的策略,从而导致模型性能受到干扰甚至是突然崩溃,使得模型训练过程非常不稳定,甚至很难收敛到优策略。对应于实际具体应用,如人工智能围棋系统等各类游戏对战、机器人调优工业设备参数等工业自动化应用、自动驾驶领域车辆运动规划等凡是利用强化学习来自动化寻求最佳序贯决策的真实应用场景,则表现为强化学习智能体在特定环境中学习完成特定任务的策略过程非常不稳定,随着学习的进行,智能体可能会突然忘记已经学习到的稍好的策略以致于面对相应的环境场景做出错误的决策,从而必须重新从头开始再次学习,后期再次遗忘并再次重新学习,如此反复,使得智能体学习优策略的效率大大降低,甚至最终无法学习到完成相应任务的优策略。

以上问题被称为灾难性干扰和遗忘(Catastrophic Interference andForgetting)。现有基于值的深度强化学习训练框架一般采用经验回放和固定目标网络两种策略来缓解灾难性干扰和遗忘问题,其中,经验回放对计算内存有很高的要求,尤其是当处理复杂图像或视频输入问题时,为了能更好地产生近似独立同分布的训练数据,需要设置百万甚至更高级别的经验存储缓冲区大小,这对一般计算机而言是非常困难的;此外,固定目标网络也只能使输出目标相对平稳,单独使用时对灾难性干扰和遗忘问题改善效果非常有限。

现有技术中缺乏解决强化学习领域神经网络模型在训练过程中所遇到的灾难性干扰和遗忘问题的方案。

以上背景技术内容的公开仅用于辅助理解本发明的构思及技术方案,其并不必然属于本专利申请的现有技术,在没有明确的证据表明上述内容在本专利申请的申请日已经公开的情况下,上述背景技术不应当用于评价本申请的新颖性和创造性。

发明内容

本发明为解决现有深度强化学习神经网络模型在训练过程中普遍遭遇的灾难性干扰和遗忘问题,提供一种深度强化学习训练方法及计算机可读存储介质。

为了解决上述问题,本发明采用的技术方案如下所述:

一种深度强化学习训练方法,包括如下步骤:S1:指定情境数量,初始化深度强化学习多头神经网络模型的权重参数;智能体随机决策,收集样本存于经验回放缓冲区;S2:依据所述情境数量,采用在线聚类算法实现自适应情境划分,对当前时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心;S3:从所述经验回放缓冲区随机采样小批量样本,并依据各所述样本对应的状态与各所述情境中心的欧氏距离依次将各所述样本分配至距离最近的所述情境中;S4:依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数;S5:下一时间步,智能体依据所述值函数继续决策,收集样本存于所述经验回放缓冲区,重复进行所述自适应情境划分和所述深度强化学习多头神经网络模型的权重参数更新迭代,直至所述深度强化学习多头神经网络模型完成预先指定的训练次数或达到收敛。

优选地,指定所述情境数量k,其中,k>1;选用一个共享特征提取器和一组线性输出头组成的神经网络结构参数化值函数,每个线性输出头对应于一个特定情境;初始化所述深度强化学习多头神经网络模型的权重参数

优选地,收集样本存于所述经验回放缓冲区,在t时刻收集到的样本表示为{s

优选地,对所述深度强化学习多头神经网络模型训练过程中经历的所有状态进行划分得到有限个簇,每个所述簇称为一个情境ω,Ω={ω

优选地,利用Sequential K-Means算法对当前时刻t智能体所处的环境状态s

优选地,对每个时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心,具体操作包括:当前时刻t下,各所述情境中心为

优选地,从所述经验回放缓冲区

优选地,依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数包括:所述深度强化学习多头神经网络模型值函数估计的原始损失函数为

优选地,所述深度强化学习多头神经网络模型是DQN算法,所述深度强化学习多头神经网络模型值函数估计的原始损失函数

当前时刻智能体所处的环境状态s

其他输出头对应的蒸馏损失为:

联合优化损失函数如下:

其中,λ∈[0,1]为控制深度强化学习多头神经网络模型可塑性和稳定性平衡系数。

本发明还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上任一所述方法的步骤。

本发明的有益效果为:提供一种深度强化学习训练方法及计算机可读存储介质,通过基于情境划分和知识蒸馏的深度强化学习训练框架,联合在线聚类算法实现自适应情境划分,并采用蒸馏损失和多头神经网络架构对各情境下深度强化学习模型值函数进行更有针对性地估计,大幅提升深度强化学习模型训练的稳定性和可塑性。

进一步地,本发明的方法可方便地整合到各类现有的基于值的深度强化学习模型中,大大减少模型训练过程中因训练数据分布漂移而产生的灾难性干扰和遗忘,并显著降低现有模型对计算内存的需求,大幅提升模型的稳定性和可塑性,提升其在各类强化学习任务上的性能,具有很强的通用应用价值。

附图说明

图1是本发明实施例中强化学习模型训练过程中数据分布漂移示例图。

图2是本发明实施例中一种深度强化学习训练方法的示意图。

图3是本发明实施例中一种基于情境划分和知识蒸馏的深度强化学习训练框架示意图。

图4(a)-图4(c)分别是本发明实施例中对现有深度强化学习代表性算法DQN及利用本发明的方法分别在经验回放缓冲区存储容量为1、100和50000下训练OpenAI Gym经典控制游戏CartPole-v1所获得的累计奖励曲线展示。

具体实施方式

为了使本发明实施例所要解决的技术问题、技术方案及有益效果更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。

需要说明的是,当元件被称为“固定于”或“设置于”另一个元件,它可以直接在另一个元件上或者间接在另一个元件上。当一个元件被称为是“连接于”另一个元件,它可以是直接连接到另一个元件或间接连接至另一个元件上。另外,连接既可以是用于固定作用也可以是用于电路连通作用。

需要理解的是,术语“长度”、“宽度”、“上”、“下”、“前”、“后”、“左”、“右”、“竖直”、“水平”、“顶”、“底”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明实施例和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。

此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多该特征。在本发明实施例的描述中,“多个”的含义是两个或两个以上,除非另有明确具体的限定。

如图1所示,为强化学习模型训练过程中数据分布漂移示例图。实线记录了模型训练过程中遇到的状态的分布情况,虚线展示了相应时刻的模型训练性能。该图显示了模型训练过程中数据分布及模型训练性能的动态变化,揭示了干扰和遗忘产生的内部机理。在T3时刻之前,随着模型训练,数据分布逐步从P1转移至P2再至P3,当神经网络逐步拟合至P3,模型权重被很大程度地更新,使得已学习的P1和P2分布上的信息受到干扰甚至被完全覆盖,因此,当智能体再次遇到P1分布的状态时会突然无法做出正确决策,从而导致模型性能突然下降,此时模型必须在P1分布上重新学习。图1显示了训练数据分布漂移导致的灾难性干扰和遗忘伴随着模型性能的急剧波动。

本发明的目的是为了解决现有基于值的深度强化学习模型训练过程中普遍存在的因训练数据分布漂移引起神经网络出现灾难性干扰和遗忘,从而导致模型训练过程中性能非常不稳定,甚至无法学习到优策略的问题。

如图2和图3所示,为了解决上述技术问题,本发明提供一种深度强化学习训练方法,包括如下步骤:

S1:指定情境数量,初始化深度强化学习多头神经网络模型的权重参数;智能体随机决策,收集样本存于经验回放缓冲区;

S2:依据所述情境数量,采用在线聚类算法实现自适应情境划分,对当前时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心;

S3:从所述经验回放缓冲区随机采样小批量样本,并依据各所述样本对应的状态与各所述情境中心的欧氏距离依次将各所述样本分配至距离最近的所述情境中;

S4:依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数;

S5:下一时间步,智能体依据所述值函数继续决策,收集样本存于所述经验回放缓冲区,重复进行所述自适应情境划分和所述深度强化学习多头神经网络模型的权重参数更新迭代,直至所述深度强化学习多头神经网络模型完成预先指定的训练次数或达到收敛。

本发明的方法通过基于情境划分和知识蒸馏的深度强化学习训练框架,联合在线聚类算法实现自适应情境划分,并采用蒸馏损失和多头神经网络架构对各情境下强化学习模型值函数进行更有针对性地估计,大幅提升强化学习模型训练的稳定性和可塑性。

进一步地,与现有技术相比,本发明具有如下有益效果:

大大减少模型训练过程中因训练数据分布漂移而产生的灾难性干扰和遗忘,大幅提升模型的稳定性和可塑性;

显著降低现有模型对计算内存的需求;

可方便地整合到各类现有的基于值的深度强化学习模型中,提升其在各类强化学习任务上的性能,具有很强的通用应用价值。

在一种具体的实施例中,指定情境数量k,其中,k>1;选用一个共享特征提取器和一组线性输出头组成的神经网络结构参数化值函数,每个线性输出头对应于一个特定情境;初始化深度强化学习多头神经网络模型的权重参数

本发明将针对一个特定任务估计一个值函数的问题转化为针对任务中包含的多个情境分别估计一个单独的值函数问题,从而解耦不同情境包含的状态间的干扰;同时共享特征提取器也能最大限度地促进不同情境在特征提取层的正向泛化,加速训练进程。

可以理解的是,理论上,k值越大则代表情境划分粒度越细,在训练次数足够大的情况下,所获得的模型性能越好,但k值太大会造成神经网络输出头太多从而导致模型非常复杂,增加了训练难度,因此k值也不宜过大。依据经验,对任务情境进行粗略划分,如k=3~5,即可实现明显地模型训练性能改进。

本发明中只需要指定一个超参数来控制情境划分的细粒度,从而确保本发明的方法在实践中的可用性。

收集样本存于经验回放缓冲区,在t时刻收集到的样本表示为{s

对深度强化学习多头神经网络模型训练过程中经历的所有状态进行划分得到有限个簇,每个簇称为一个情境ω,Ω={ω

在本发明中,将单个强化学习环境划分为若干个情境,并针对每个情境中包含的状态采用共享特征提取器和一组特定于单个情境的线性输出头组成的神经网络结构进行值函数估计;多个情境间共享特征提取层,提升特征提取层训练效率,加速训练进程。

情境划分:在本发明的一种实施例中,利用Sequential K-Means算法对当前时刻t智能体所处的环境状态s

进一步地,对每个时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心,具体操作包括:

当前时刻t下,各情境中心为

其中,i∈{1,2,...,k},j=argmin

本发明中采用多头神经网络结构对每个情境中包含的状态对应的值函数分别进行估计,解耦不同情境状态间对神经网络训练的干扰,提升网络训练稳定性和可塑性。

状态分配:从经验回放缓冲区

样本总体表示为:

联合优化:依据样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数包括:

所述深度强化学习多头神经网络模型值函数估计的原始损失函数为

本发明采用知识蒸馏正则化损失对神经网络参数进行优化,最大限度保留网络已学习的知识。知识蒸馏包含两项内容,分别是训练当前输入状态对应输出头的蒸馏损失和同步更新其他输出头的蒸馏损失,其中,前项蒸馏损失表示为

深度强化学习多头神经网络模型可以为任一基于值函数的深度强化学习模型,以深度强化学习代表性算法DQN为例,深度强化学习多头神经网络模型值函数估计的原始损失函数

当前时刻智能体所处的环境状态s

其他输出头对应的蒸馏损失为:

联合优化损失函数如下:

其中,λ∈[0,1]为控制深度强化学习多头神经网络模型可塑性和稳定性平衡系数。

智能体依据所得值函数进行下一步决策,存于所述经验回放缓冲区,重复进行自适应情境划分和深度强化学习多头神经网络模型的权重参数更新迭代,直至深度强化学习多头神经网络模型完成预先指定的训练次数T或达到收敛,最终得到训练好的模型参数θ和情境划分中心

模型部署:指导智能体决策以完成相应任务。对于当前状态s,首先依据其与各情境中心距离判断其所属情境:

计算第j个输出头对应的Q值

如图4所示,是在不同大小的经验回放缓冲区容量设置下,分别采用深度强化学习代表性算法DQN及利用本发明提出的训练框架训练的DQN在OpenAI Gym经典控制游戏CartPole-v1上训练所获得的累计奖励曲线展示。从图中可以看出,在不同大小的缓冲区容量下,原始的DQN方法在训练过程中都出现了非常明显的灾难性遗忘和性能波动,尤其是当缓冲区容量非常小(为1)时,DQN模型根本无法学习到最优策略而实现最大的累计奖励。对比之下,融合了本发明提出的训练框架的DQN方法在不同缓冲区容量设置下的训练性能都稳定得多,且即使是缓冲区容量设置为1时也还是可以学习到完成任务的最优策略,实现最大的累计奖励。

如表1所示,是对图3中所示两种方法训练曲线分别就训练过程中所达到的最大累计奖励和最大的累计奖励下降比例两种指标进行统计所得的结果。

表1统计所得的结果

从表中所示结果可以再次印证对图3分析所得出的结论。不论是在哪种缓冲区容量设置下,融合本发明所提出的训练框架的DQN方法都获得了最大的累计奖励(即,融合本发明提出的训练框架的DQN模型具有很强的可塑性),并且在训练过程中,累计奖励波动的最大值都比原始的DQN小得多(即,融合本发明提出的训练框架的DQN模型具有很好的稳定性)。

本申请实施例还提供一种控制装置,包括处理器和用于存储计算机程序的存储介质;其中,处理器用于执行所述计算机程序时至少执行如上所述的方法。

本申请实施例还提供一种存储介质,用于存储计算机程序,该计算机程序被执行时至少执行如上所述的方法。

本申请实施例还提供一种处理器,所述处理器执行计算机程序,至少执行如上所述的方法。

所述存储介质可以由任何类型的易失性或非易失性存储设备、或者它们的组合来实现。其中,非易失性存储器可以是只读存储器(ROM,Read Only Memory)、可编程只读存储器(PROM,Programmable Read-Only Memory)、可擦除可编程只读存储器(EPROM,ErasableProgrammable Read-Only Memory)、电可擦除可编程只读存储器(EEPROM,ElectricallyErasable Programmable Read-Only Memory)、磁性随机存取存储器(FRAM,FerromagneticRandom Access Memory)、快闪存储器(Flash Memory)、磁表面存储器、光盘、或只读光盘(CD-ROM,Compact Disc Read-Only Memory);磁表面存储器可以是磁盘存储器或磁带存储器。易失性存储器可以是随机存取存储器(RAM,Random Access Memory),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(SRAM,Static Random Access Memory)、同步静态随机存取存储器(SSRAM,SynchronousStatic Random Access Memory)、动态随机存取存储器(DRAM,Dynamic Random AccessMemory)、同步动态随机存取存储器(SDRAM,Synchronous Dynamic Random AccessMemory)、双倍数据速率同步动态随机存取存储器(DDRSDRAM,Double Data RateSynchronous Dynamic Random Access Memory)、增强型同步动态随机存取存储器(ESDRAM,Enhanced Synchronous Dynamic Random Access Memory)、同步连接动态随机存取存储器(SLDRAM,Sync Link Dynamic Random Access Memory)、直接内存总线随机存取存储器(DRRAM,Direct Rambus Random Access Memory)。本发明实施例描述的存储介质旨在包括但不限于这些和任意其它适合类型的存储器。

在本申请所提供的几个实施例中,应该理解到,所揭露的系统和方法,可以通过其它的方式实现。以上所描述的设备实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,如:多个单元或组件可以结合,或可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的各组成部分相互之间的耦合、或直接耦合、或通信连接可以是通过一些接口,设备或单元的间接耦合或通信连接,可以是电性的、机械的或其它形式的。

上述作为分离部件说明的单元可以是、或也可以不是物理上分开的,作为单元显示的部件可以是、或也可以不是物理单元,即可以位于一个地方,也可以分布到多个网络单元上;可以根据实际的需要选择其中的部分或全部单元来实现本实施例方案的目的。

另外,在本发明各实施例中的各功能单元可以全部集成在一个处理单元中,也可以是各单元分别单独作为一个单元,也可以两个或两个以上单元集成在一个单元中;上述集成的单元既可以采用硬件的形式实现,也可以采用硬件加软件功能单元的形式实现。

本领域普通技术人员可以理解:实现上述方法实施例的全部或部分步骤可以通过程序指令相关的硬件来完成,前述的程序可以存储于一计算机可读取存储介质中,该程序在执行时,执行包括上述方法实施例的步骤;而前述的存储介质包括:移动存储设备、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。

或者,本发明上述集成的单元如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明实施例的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机、服务器、或者网络设备等)执行本发明各个实施例所述方法的全部或部分。而前述的存储介质包括:移动存储设备、ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。

本申请所提供的几个方法实施例中所揭露的方法,在不冲突的情况下可以任意组合,得到新的方法实施例。

本申请所提供的几个产品实施例中所揭露的特征,在不冲突的情况下可以任意组合,得到新的产品实施例。

本申请所提供的几个方法或设备实施例中所揭露的特征,在不冲突的情况下可以任意组合,得到新的方法实施例或设备实施例。

以上内容是结合具体的优选实施方式对本发明所做的进一步详细说明,不能认定本发明的具体实施只局限于这些说明。对于本发明所属技术领域的技术人员来说,在不脱离本发明构思的前提下,还可以做出若干等同替代或明显变型,而且性能或用途相同,都应当视为属于本发明的保护范围。

相关技术
  • 一种深度强化学习训练方法及计算机可读存储介质
  • 存储器的数据训练方法、计算机装置及计算机可读存储介质
技术分类

06120112773138