掌桥专利:专业的专利平台
掌桥专利
首页

一种结合领域通用型语言模型的领域泛化方法

文献发布时间:2023-06-19 16:04:54



技术领域

本发明涉及一种结合领域通用型语言模型的领域泛化方法,属于人工智能迁移学习技术领域。

背景技术

在人工智能领域,领域泛化(Domain Generalization,DG)技术,是从若干个具有不同数据分布的数据集(领域)中学习一个泛化能力强的模型,以便在未知的测试集上取得较好的效果。例如,给定一个由餐厅、购物等领域的评论文本组成的训练集,要求训练一个良好的机器学习模型能够在对图书领域的数据集上进行分类时具有最小的预测误差。

现有的领域泛化方法,主要从数据操作、表示学习和学习策略三方面进行研究。其中,数据操作主要利用数据增强和数据生成两种技术来帮助学习一般的表示;表示学习主要通过学习域不变表示学习或利用特征解耦来得到域共享表示,提高模型泛化性能;学习策略侧重于利用通用学习策略促进泛化能力,主要包括集成学习、元学习和梯度操作。

领域泛化任务中的目标域数据不可访问性,这使得领域泛化更具挑战性和实用性。例如,在自然语言处理领域,由于预训练语言模型过度参数化的特性能够带来更小的学习误差,研究人员将其用于领域泛化任务,利用多个源域的数据对BERT等预训练语言模型进行微调,再将其在不同的目标域数据集上进行测试。结果发现,预训练语言模型能够比传统的模型拥有更强的泛化能力。但是,由于预训练语言模型中存在一些对于特定领域有效的参数,使得模型内部的神经元会在某些领域数据上激活而其他领域不激活。这种参数的领域不一致性,使得预训练语言模型的领域泛化能力下降。

发明内容

本发明的目的是为了解决因预训练语言模型存在特定域有效参数导致的泛化性能下降的技术问题,创造性地提出了一种结合领域通用型语言模型的领域泛化方法。本方法综合了预训练语言模型和模型裁剪技术。

对于领域泛化任务而言,泛化模型需要在所有域上都具备较好的泛化性能。为了去除预训练语言模型中的只对特定领域有效的参数来提高模型的泛化能力,本方法定义了域不变分数来识别特定域有效参数,通过去除特定域有效参数,保留域不变参数,得到领域通用型语言模型。然后,将领域通用型语言模型用于领域泛化,显著提高了模型的领域泛化性能。

首先,对预训练语言模型微调,利用多个源域数据对预训练语言模型进行训练。基于微调后得到的模型,计算模型中参数的域不变分数,对域不变分数低的参数进行裁剪。最后,对裁剪后的语言模型进行重训练,将训练得到领域通用型语言模型在不同数据上进行泛化性能测试。

本发明采用的技术方式如下:

一种结合领域通用型语言模型的领域泛化方法,包括以下步骤:

步骤1:预训练语言模型微调。

使用预训练语言模型(例如BERT)在给定的源域数据进行训练,利用使用多层感知器(MLP)微调预训练语言模型。其中,多层感知器包含四层:全连接层、双曲正切函数(Tanh)激活函数层、随机丢弃层(dropout层)和全连接层。

利用训练好的预训练语言模型,在目标域数据上获得经过全连接层的输出表示并送至软最大化标准化层(softmax层),对目标域数据进行相应标签预测。

步骤2:计算参数域不变分数。

本发明中,仅对预训练语言中的多头注意力MHA和前馈神经网络FFN模块进行裁剪。

对于待裁剪的参数,当数据集中只有一种领域的数据时,其对应的参数重要程度分数I如下式所示:

其中,

在基于参数重要程度分数I的基础上,将跨领域的参数重要程度分数的期望与方差纳入考虑范围,提出了参数域不变分数I′。对于待裁剪的参数,其对应的参数域不变分数I′如下式所示:

其中,(x,y)是指领域d中的数据点,D是指领域集合。V表示方差,E表示期望。参数域不变分数对将跨领域的参数重要程度分数的均值与方差进行平衡,参数λ用以权衡二者之间的关系。

对参数而言,该参数的域不变分数越大,说明该参数的在各领域上的泛化能力更优。反之,若域不变分数越小,说明该参数领域泛化性弱,仅在某些领域上有效。

步骤3:参数裁剪。

如步骤2所述,对于每个参与域不变分数计算的参数,都有对应的裁剪变量,用以表示该参数是否被裁剪。

在对参数进行域不变分数计算后,根据域不变分数对参数进行升序排列,并优先对域不变分数低的参数进行裁剪。

具体地,当ξ

步骤4:对裁剪后的模型重训练。

对参数进行裁剪后,将裁剪后的模型进行重训练。其中,重训练需要将裁剪后的模型置为步骤1的初始状态,再让裁剪后的模型在给定的多个源域数据进行训练。然后来对目标域数据进行相应标签预测。

通过设置不同的裁剪率,得到领域泛化效果最好的裁剪后的模型。至此,即获得了领域通用型语言模型。

步骤5:利用领域通用型语言模型,对训练领域数据以外的其他领域数据进行分类预测。

有益效果

本方法,对比现有技术,具有以下优点:

1.本发明解决了过度参数化的预训练语言模型中学习方差大的问题,基于域不变分数对预训练语言中特定域有效参数进行裁剪,保留下对领域泛化更有用的通用域有效参数。

2.本发明提出的领域通用型语言模型,在亚马逊评论数据集和多类型自然语言推理数据库上的明显优于相应的基线模型,具体表现为领域通用型语言模型的准确率得分比相应基线模型平均提高1.5个百分点。领域通用型语言模型在自然语言处理领域泛化任务上取得了最新的最好效果。

附图说明

图1是本发明的整体流程图。

图2是构建领域通用型语言模型的模型结构图。

具体实施方式

下面结合附图和实施例对本发明进一步详细描述。

实施例

如图1、图2所示,一种结合领域通用型语言模型的领域泛化方法,包括以下步骤:

步骤1:预训练语言模型微调。

具体地,包括以下步骤:

步骤1.1:加载多领域评论语料集,数据集分为训练集、验证集和测试集,并构造成批数据形式。

步骤1.2:加载预训练语言模型M,初始化后保存。其中,预训练语言模型,可以是BERT-base或BERT-large模型。

步骤1.3:模型训练。

批数据再经过BERT结构后获得句子向量表示。本方法使用多层感知器(MLP)来微调预训练语言模型。本方法中的多层感知器MLP包含四层:全连接层、ReLU(线性整流函数)激活函数层、dropout层(随机丢弃层)和全连接层。最后,将经过全连接层的输出表示送至softmax层(软最大化标准化层)以预测相应的标签。其中,模型训练的目标函数为交叉熵函数,具体表示形式如下:

其中,m是标签类别的数量,c表示m的某一类别,N为训练样本个数,y

模型训练为达到最小交叉熵损失,采用随机梯度下降法对其进行优化。在模型训练过程中,每一次训练后,用验证集数据对模型进行效果评价,此处采用的评价指标为各领域的平均准确率。在每轮验证后,保存效果最优的模型M′。

步骤1.4:效果评价。

利用测试集数据对步骤1.3获得的模型M′进行效果评价。首先加载最优模型M′,将测试集数据作为模型的输入,预测步骤与步骤1.3相同,此处使用的评价指标与步骤1.3相同。

步骤2:计算参数域不变分数。

具体地,包括以下步骤:

步骤2.1:加载步骤1.3保存的最优模型M′和训练集数据。

步骤2.2:将训练集数据输入模型M′,对参数计算其在各个领域的重要程度分数,此处参数特指MHA(多头注意力)和FFN(前馈网络),计算公式如下所示:

其中,(x,y)是指数据点,

步骤2.3:基于各个领域的参数重要程度分数,计算参数的域不变分数,计算公式如下所示:

其中,(x,y)是指领域d中的数据点,D是指领域集合。参数域不变分数对将跨领域的参数重要程度分数的均值与方差进行平衡,用参数λ用以权衡二者之间的关系。

步骤2.4:基于参数的域不变分数对参数进行升序排列并保存。

步骤3:参数裁剪。

步骤3.1:设置参数裁剪的比例。

步骤3.2:加载步骤1.2保存的预训练模型M。

步骤3.3:根据裁剪比例计算出待裁剪参数数量n,将步骤2.4保存的参数序列中的前n个参数在模型M中进行裁剪,具体操作为将参数对应的裁剪变量的值设为0。对裁剪后模型M″进行保存。

步骤4:裁剪后模型重训练。

具体地,包括以下步骤:

步骤4.1:加载多领域评论语料集和裁剪后模型M″。

步骤4.2:模型训练。具体方法同步骤1.3。训练后得到领域通用型模型。

步骤4.3:效果评价。对领域通用型模型进行效果评价,具体方法同步骤1.4。

步骤5:利用领域通用型语言模型,对训练领域数据以外的其他领域数据进行分类预测。

领域通用型语言模型由于去除了特定域有效的参数,在其他领域数据上预测的效果会优于原有的预训练模型。例如,基于领域为“电影”,“软件”和“自动汽车”的训练集数据,利用本方法获得的领域通用型语言模型在预测“工业”领域数据的标签时,由于其具有领域通用性,模型的预测效果会远超于原有的预训练模型。

以上所述为本发明的较佳实施例,本发明不应该局限于该实施例和附图所公开的内容。凡是不脱离本发明所公开的精神下完成的等效或修改,都落入本发明保护的范围。

技术分类

06120114696210