掌桥专利:专业的专利平台
掌桥专利
首页

基于RNN的语言模型的训练方法及相关装置

文献发布时间:2023-06-19 10:11:51


基于RNN的语言模型的训练方法及相关装置

技术领域

本发明涉及数据处理领域,更具体地说,涉及一种基于RNN的语言模型的训练方法及相关装置。

背景技术

基于RNN(Recurrent Neural Network,循环神经网络)的语言模型实际上是一种预测模型,预测空间为字典空间,得到的结果为预测空间中各个事件的后验概率。目前,考虑到计算效率的问题,在训练基于RNN的语言模型时,通常使用Nce(Noise-ContrastiveEstimation,噪声对比估计)函数。在训练模型时,将所有正样本分为n批,每一批(即batch)中包括若干条数据,针对每一批,准备一组负样本;负样本是从预设的词库中抽取出来的。利用每个batch正样本与对应的负样本,计算Nce损失;计算出Nce损失后,使用BPTT(Backpropagation Through Time,随时间反向传播)算法更新模型的参数。

发明人发现RNN使用的是随时间方向传播,即沿着时间序列一步一步向反向传播梯度,这很容易造成梯度消失。梯度消失会导致模型训练时候的梯度无法影响到参数调节,从而导致收敛不足,造成训练好的语言模型的预测结果有偏差。

发明内容

有鉴于此,本发明提出一种基于RNN的语言模型的训练方法及相关装置,欲提高语言模型的训练速度,以及提高语言模型的训练效果。

为了实现上述目的,现提出的方案如下:

第一方面,提供一种基于RNN的语言模型的训练方法,包括:

将一条长句子训练数据的元素按照时间划分为N部分,N≥2;

对于所述长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值;

根据计算得到的Nce损失值,利用BPTT计算得到梯度;

根据所述梯度更新基于RNN的语言模型的参数;

判断所述长句子训练数据划分的所有部分是否均计算了Nce损失值,若否,则转入执行对于所述长句子训练数据中还未计算Nce损失值的时间最早,结合与其对应的一组负样本,计算得到Nce损失值的步骤,若是,则结束。

优选的,在每次执行对于所述长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值的步骤之前,还包括:

从预设的元素库中随机抽取若干元素作为与所述长句子训练数据中还未计算Nce损失值的时间最早的部分对应的一组负样本。

优选的,在所述将一条长句子训练数据的元素按照时间划分为N部分的步骤之后,且在对于所述长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值的步骤之前,还包括:

从预设的元素库中随机抽取元素生成M组负样本,2≤M≤N;

将所述长句子训练数据的N部分划分为M组,且所述N部分划分的M组中的各组与所述M组负样本中的各组负样本一一对应。

优选的,所述基于RNN的语言模型,具体为:

基于LSTM的语言模型。

第二方面,提供一种基于RNN的语言模型的训练装置,包括:

元素划分单元,用于将一条长句子训练数据的元素按照时间划分为N部分,N≥2;

损失值计算单元,用于对于所述长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值;

梯度计算单元,用于根据计算得到的Nce损失值,利用BPTT计算得到梯度;

模型优化单元,用于根据所述梯度更新基于RNN的语言模型的参数;

判断单元,用于判断所述长句子训练数据划分的所有部分是否均计算了Nce损失值,若否,则转入执行损失值计算单元,若是,则结束。

优选的,所述基于RNN的语言模型的训练装置,还包括:

第一负样本生成单元,用于在每次执行所述损失值计算单元之前,从预设的元素库中随机抽取若干元素作为与所述长句子训练数据中还未计算Nce损失值的时间最早的部分对应的一组负样本。

优选的,所述基于RNN的语言模型的训练装置,还包括:

第二负样本生成单元,用于从预设的元素库中随机抽取元素生成M组负样本,2≤M≤N;以及将所述长句子训练数据的N部分划分为M组,且所述N部分划分的M组中的各组与所述M组负样本中的各组负样本一一对应。

优选的,所述基于RNN的语言模型,具体为:

基于LSTM的语言模型。

第三方面,提供一种可读存储介质,其上存储有程序,所述程序被处理器执行时,实现如第一方面中任意一种基于RNN的语言模型的训练方法的各个步骤。

第四方面,提供一种基于RNN的语言模型的训练设备,包括:存储器和处理器;

所述存储器,用于存储程序;

所述处理器,用于执行所述程序,实现如第一方面中任意一种基于RNN的语言模型的训练方法的各个步骤。

与现有技术相比,本发明的技术方案具有以下优点:

上述技术方案提供的一种基于RNN的语言模型的训练方法及相关装置,方法包括:将一条长句子训练数据的元素按照时间划分为至少两部分,针对每部分,计算一次Nce损失值;根据Nce损失值,利用BPTT计算梯度;最后根据计算得到的梯度更新基于RNN的语言模型的参数。将一条长句子训练数据的元素按照时间划分为至少两部分,针对每部分计算一次Nce损失值,多次更新基于RNN的语言模型的参数,相比于一条长句子训练数据更新一次基于RNN的语言模型的参数,提高了语言模型的训练速度。将一条长句子训练数据分为至少两部分,基于拆分后的各部分计算梯度优化模型,由于拆分后的各部分相比于未拆分的一条长句子训练数据,时序上的时间点变少,梯度消失的概率变小,可以提高了语言模型的训练效果,即提高了语言模型的预测准确率。

附图说明

为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。

图1为基于RNN的语言模型进行训练的现有方案的原理图;

图2为本发明实施例提供的一种基于RNN的语言模型的训练方法的原理图;

图3为本发明实施例提供的一种基于RNN的语言模型的训练方法的流程图;

图4为本发明实施例提供的另一种基于RNN的语言模型的训练方法的流程图;

图5为本发明实施例提供的又一种基于RNN的语言模型的训练方法的流程图;

图6为本发明实施例提供的一种基于RNN的语言模型的训练装置的示意图;

图7为本发明实施例提供的一种基于RNN的语言模型的训练设备的示意图。

具体实施方式

下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。

参见图1,为对基于RNN的语言模型进行训练的现有方案。X1、X2、……、Xn是一个长句子训练数据包括的所有元素。在将X1输入到基于RNN的语言模型计算Y1,将X2输入到基于RNN的语言模型计算Y2,……,以及将Xn输入到基于RNN的语言模型计算Yn时,均是使用的同一组负样本。这是因为需要把Y1、Y2、……、Yn都计算出来后,统一计算Nce loss(即Nce损失值),然后通过Nce损失值计算出一个向量(即梯度),更新基于RNN的语言模型中计算函数的参数,使其可以更好的拟合当前的数据分布。但是,通过Nce损失计算梯度的方法是BPTT,这个方法是跟时间相关的,有一个特性就是每条长句子训练数据的元素越多,发生梯度消失的几率越大。

为了解决上述问题,本发明的核心思路是将长句子训练数据的元素按照时间划分为至少两部分,针对每部分分别计算一次Nce损失,进而对语言模型进行一次更新。将一条长句子训练数据的元素按照时间划分为至少两部分,多次更新基于RNN的语言模型的参数,相比于一条长句子训练数据更新一次基于RNN的语言模型的参数,提高了语言模型的训练速度。将一条长句子训练数据分为至少两部分,基于拆分后的各部分分别计算梯度优化模型,由于拆分后的各部分相比于未拆分的一条长句子训练数据,时序上的时间点变少,梯度消失的概率变小,可以提高了语言模型的训练效果,即提高了语言模型的预测准确率。

参见图2,为本发明的实施例提供的基于RNN的语言模型的训练方法的示意图。该实施例中两个元素划分为一部分;时序最早的一部分结合与其对应的一组负样本,计算Nce损失值,更新语言模型的参数;之后再将剩下的时序最早的一部分结合与其对应的一组负样本,计算Nce损失值,更新语言模型的参数,直到对每一部分都计算了Nce损失值以及更新了语言模型的参数。在一个具体实施例中,基于RNN的语言模型,具体为基于LSTM的语言模型。元素具体可以是字,也可以是词,本发明中对此不做限定,都属于本发明的保护范围。

参见图3,为本发明实施例提供的一种基于RNN的语言模型的训练方法的流程图,该方法可以包括以下步骤:

S31:将一条长句子训练数据的元素按照时间划分为N部分,N≥2。

输入到循环神经网络的数据是时序的序列。一条长句子训练数据包括的每个元素表示时序上的一个时间点,每个时间点都会计算一个输出,这个输出的计算依靠一组负样本。例如,长句子训练数据为“今天天气怎么样”,今天、天气、怎么样分别为一个元素,即对应时序上的一个时间点。每输入到循环神经网络一个词,循环神经网络就输出截止目前为止,下一个最可能的词。

S32:对于长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值。

未计算Nce损失值的部分,指的是还没有利用该部分的元素计算输出,以及利用该部分的各个元素计算得到的所有输出计算得到Nce损失值。时间最早的部分指的是该部分中元素表示的时间点,比其它部分的元素表示的是时间点都早。

现有技术方案如图1所示,是在得到每个时间点的输出(即Y1、Y2、……、Yn-1和Yn)后,计算Nce损失值;而本发明是在得到长句子训练数据中某一部分的各个时间点的输出后,计算Nce损失值,示例性的,在得到图2中的输出Y1和Y2后,就计算Nce损失值,并进行一次后续的模型参数更新。具体的,对于长句子训练数据中的某一部分,结合与其对应的一组负样本,计算得到Nce损失值的过程为:在确定某一部分的正样本和多个负样本以后,根据Nce二分类准则,结合采样分布概率以及RNN输出的对应样本的代价值,就可以分别计算得到正样本被分类为正的概率、负样本被分类为负的概率,综合这些概率值就可以计算近似的期望损失函数值,即NCE损失值。

S33:根据计算得到的Nce损失值,利用BPTT计算得到梯度。

S34:根据计算得到的梯度更新基于RNN的语言模型的参数。

现有技术中Nce损失值,利用BPTT计算得到梯度的方法,本实施例都可以采用;现有技术中基于梯度更新基于RNN的语言模型的参数的方法,本实施例也都可以采用,本发明对此不做限定。

S35:判断长句子训练数据划分的所有部分是否均计算了Nce损失值,若否,则转入步骤S32,若是,则结束。

上述实施例提供的基于RNN的语言模型的训练方法,将一条长句子训练数据的元素按照时间划分为至少两部分,多次更新基于RNN的语言模型的参数,相比于一条长句子训练数据更新一次基于RNN的语言模型的参数,提高了语言模型的训练速度。将一条长句子训练数据分为至少两部分,基于拆分后的各部分分别计算梯度优化模型,由于拆分后的各部分相比于未拆分的一条长句子训练数据,时序上的时间点变少,梯度消失的概率变小,可以提高了语言模型的训练效果,即提高了语言模型的预测准确率。

进一步的,对于图1所示的方案,发明人还发现正样本的batch包括很多条数据,且batch中数据使用同一组负样本时,受到该组负样本的随机分布的影响,容易产生偏向性,进而导致训练好的语言模型的预测结果有偏差。针对该技术问题,本发明中对于一个batch,采用多组负样本,不同组负样本有不同的元素分布,随机性更好,降低了训练得到的语言模型的预测偏差。

参见图4,为本实施例提供的另一种基于RNN的语言模型的训练方法,该方法相对于图3所示的方法,在每次执行对于长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值的步骤之前,还包括:

步骤S42:从预设的元素库中随机抽取若干元素作为与长句子训练数据中还未计算Nce损失值的时间最早的部分对应的一组负样本。

预先设置一个元素库,该元素库包括的元素不是正确的元素。从该元素库中获取多个元素组成一组负样本。一组负样本中的每个元素都是一个负样本。一组负样本中元素的个数与实际应用要求有关,本发明对此不做限定。

步骤S41、S43、S44、S45、S46分别与S31、S32、S33、S34、S35一致,本实施例不再赘述。

参见图5,为本实施例提供的又一种基于RNN的语言模型的训练方法,该方法相对于图3所示的方法,在将一条长句子训练数据的元素按照时间划分为N部分的步骤之后,且在对于长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值的步骤之前,还包括:

步骤S52:从预设的元素库中随机抽取元素生成M组负样本,2≤M≤N。

执行步骤S52生成M组负样本,M的具体值可以根据实际应用来设定,本发明对此不做限定,只要2≤M≤N都属于本发明的保护范围。

步骤S53:将长句子训练数据的N部分划分为M组,且N部分划分的M组中的各组与M组负样本中的各组负样本一一对应。

当M=N时,长句子训练数据的N部分,每一部分均为一组,这样每一部分均对应一组负样本。

步骤S51、S54、S55、S56、S57分别与S31、S32、S33、S34、S35一致,本实施例不再赘述。

对于前述的各方法实施例,为了简单描述,故将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本发明并不受所描述的动作顺序的限制,因为依据本发明,某些步骤可以采用其他顺序或者同时进行。

参见图6,为本发明的实施例提供的一种基于RNN的语言模型的训练装置,包括元素划分单元61、损失值计算单元62、梯度计算单元63、模型优化单元64和判断单元65。

元素划分单元61,用于将一条长句子训练数据的元素按照时间划分为N部分,N≥2。

损失值计算单元62,用于对于长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值。

梯度计算单元63,用于根据计算得到的Nce损失值,利用BPTT计算得到梯度。

模型优化单元64,用于根据梯度更新基于RNN的语言模型的参数。

判断单元65,用于判断长句子训练数据划分的所有部分是否均计算了Nce损失值,若否,则转入执行损失值计算单元62,若是,则结束。

在一些具体实施例中,基于RNN的语言模型为基于LSTM的语言模型。

在一些具体实施例中,基于RNN的语言模型的训练装置,还包括:第一负样本生成单元,用于在每次执行损失值计算单元之前,从预设的元素库中随机抽取若干元素作为与长句子训练数据中还未计算Nce损失值的时间最早的部分对应的一组负样本。

在一些具体实施例中,基于RNN的语言模型的训练装置,还包括:第二负样本生成单元,用于从预设的元素库中随机抽取元素生成M组负样本,2≤M≤N;以及将长句子训练数据的N部分划分为M组,且N部分划分的M组中的各组与M组负样本中的各组负样本一一对应。

本实施例还提供一种基于RNN的语言模型的训练设备,如PC终端、云平台、服务器及服务器集群等。服务器可以是机架式服务器、刀片式服务器、塔式服务器以及机柜式服务器中的一种或几种。参见图7,为本实施例提供的一种基于RNN的语言模型的训练设备。该基于RNN的语言模型的训练设备的硬件结构可以包括:至少一个处理器71,至少一个通信接口72,至少一个存储器73和至少一个通信总线74;且处理器71、通信接口72、存储器73通过通信总线74完成相互间的通信。

处理器71在一些实施例中可以是一个CPU(Central Processing Unit,中央处理器),或者是ASIC(Application Specific Integrated Circuit,特定集成电路),或者是被配置成实施本发明实施例的一个或多个集成电路等。

通信接口72可以包括标准的有线接口、无线接口。通常用于在基于RNN的语言模型的训练设备与其他电子设备或系统之间建立通信连接。

存储器73包括至少一种类型的可读存储介质。可读存储介质可以为如闪存、硬盘、多媒体卡、卡型存储器等NVM(non-volatile memory,非易失性存储器)。可读存储介质还可以是高速RAM(random access memory,随机存取存储器)存储器。可读存储介质在一些实施例中可以是基于RNN的语言模型的训练设备的内部存储单元,例如该基于RNN的语言模型的训练设备的硬盘。在另一些实施例中,可读存储介质还可以是基于RNN的语言模型的训练设备的外部存储设备,例如该基于RNN的语言模型的训练设备上配备的插接式硬盘、SMC(Smart Media Card,智能存储卡)、SD(Secure Digital,安全数字)卡,闪存卡(FlashCard)等。

其中,存储器73存储有计算机程序,处理器71可调用存储器73存储的计算机程序,所述计算机程序用于:

将一条长句子训练数据的元素按照时间划分为N部分,N≥2;

对于所述长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值;

根据计算得到的Nce损失值,利用BPTT计算得到梯度;

根据所述梯度更新基于RNN的语言模型的参数;

判断所述长句子训练数据划分的所有部分是否均计算了Nce损失值,若否,则转入执行对于所述长句子训练数据中还未计算Nce损失值的时间最早,结合与其对应的一组负样本,计算得到Nce损失值的步骤,若是,则结束。

所述程序的细化功能和扩展功能可参照上文描述。

图7仅示出了具有组件71~74的基于RNN的语言模型的训练设备,但是应理解的是,并不要求实施所有示出的组件,可以替代的实施更多或者更少的组件。

可选地,该基于RNN的语言模型的训练设备还可以包括用户接口,用户接口可以包括输入单元(比如键盘)、语音输入装置(比如包含麦克风的具有语音识别功能的设备)和/或语音输出装置(比如音响、耳机等)。可选地,用户接口还可以包括标准的有线接口和/或无线接口。

可选地,该基于RNN的语言模型的训练设备还可以包括显示器,显示器也可以称为显示屏或显示单元。在一些实施例中可以是LED显示器、液晶显示器、触控式液晶显示器以及OLED(Organic Light-Emitting Diode,有机发光二极管)显示器等。

可选地,该基于RNN的语言模型的训练设备还包括触摸传感器。触摸传感器所提供的供用户进行触摸操作的区域称为触控区域。此外,触摸传感器可以为电阻式触摸传感器、电容式触摸传感器等。而且,触摸传感器不仅包括接触式的触摸传感器,也可包括接近式的触摸传感器等。此外,触摸传感器可以为单个传感器,也可以为例如阵列布置的多个传感器。用户可以通过触摸触控区域输入身份识别信息。

此外,该基于RNN的语言模型的训练设备的显示器的面积可以与触摸传感器的面积相同,也可以不同。可选地,将显示器与触摸传感器层叠设置,以形成触摸显示屏。该装置基于触摸显示屏侦测用户触发的触控操作。

该基于RNN的语言模型的训练设备还可以包括RF(Radio Frequency,射频)电路、传感器和音频电路等等,在此不再赘。

本发明实施例还提供一种可读存储介质,该可读存储介质可存储有适于处理器执行的程序,所述程序用于:

将一条长句子训练数据的元素按照时间划分为N部分,N≥2;

对于所述长句子训练数据中还未计算Nce损失值的时间最早的部分,结合与其对应的一组负样本,计算得到Nce损失值;

根据计算得到的Nce损失值,利用BPTT计算得到梯度;

根据所述梯度更新基于RNN的语言模型的参数;

判断所述长句子训练数据划分的所有部分是否均计算了Nce损失值,若否,则转入执行对于所述长句子训练数据中还未计算Nce损失值的时间最早,结合与其对应的一组负样本,计算得到Nce损失值的步骤,若是,则结束。

所述程序的细化功能和扩展功能可参照上文描述。

以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。

在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。

本说明书中各个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可,且本说明书中各实施例中记载的特征可以相互替换或者组合。

对本发明所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本发明。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。

相关技术
  • 基于RNN的语言模型的训练方法及相关装置
  • 基于RNN的反洗钱模型的训练方法、装置、设备及介质
技术分类

06120112456464