掌桥专利:专业的专利平台
掌桥专利
首页

一种基于对抗迁移学习的预训练漏洞修复方法

文献发布时间:2024-04-18 19:58:21


一种基于对抗迁移学习的预训练漏洞修复方法

技术领域

本发明属于软件调试领域,具体涉及一种基于对抗迁移学习的预训练漏洞修复方法。

背景技术

随着软件漏洞的数量和复杂性的增加,开发人员需要对软件漏洞深入了解,并尽可能的减少对系统功能的影响,大大增加了软件漏洞修复的成本。为了减小软件漏洞修复成本,研究人员提出了自动修复软件漏洞的技术。但是从互联网能够采集到的漏洞修复数据集规模小,给研究人员带来了很大的挑战。

扬州大学在其申请的专利文献“一种基于树的漏洞修复系统及修复方法”(专利申请号:202210027014.7,公开号:CN114547619A)中提出了一种使用语法树表征代码进行漏洞代码自动修复的技术。该方法首先在GitHub上收集漏洞修复数据集,将漏洞修复数据集中的代码转为具有数据流依赖和控制流依赖的语法树AST,将所述语法树AST进行抽象化和规范化得到token序列,然后将所述token序列划分为训练集和测试集,将训练集和测试集输入具有相同编码器和解码器数量的Transformer模型进行训练和测试。该发明利用语法树和Transformer模型,实现了对代码的自动修复,提高了代码修复的效率。但是该方法依然存在不足:

(1)该方法仅依赖漏洞修复数据集进行模型训练,在漏洞修复数据集规模较小的现状下,个别CWE类型的漏洞在数据集里面数量少或者干扰较强,导致模型未完全学习到该漏洞的特征时,会使模型表现不佳,泛化性、鲁棒性减弱;

(2)该方法对代码数据集的抽象化和规范化处理时,会将数据的函数名、变量、值进行替换,不能让模型学习到代码潜在的语意,导致模型的代码理解能力差;

(3)该方法过度依赖Transformer模型生成修复代码,模型会将部分正确代码错误修复,导致模型过拟合。

本发明提出了一种基于对抗迁移学习的预训练漏洞修复方法,其优点在于:

(1)本发明通过在大型代码数据集上进行预训练后得到预训练的代码生成器模型,使得模型具备更好的代码理解能力、代码生成和补全能力;

(2)本发明借助生成对抗网络架构在漏洞修复数据集上对预训练的代码生成器模型进行微调,通过生成对抗网络的对抗训练机制,提升模型的抗干扰能力和修复能力,使得模型具备更高的鲁棒性、泛化性,同时解决模型过拟合的问题。

本发明首先在大型代码数据集上进行预训练后得到预训练的代码生成器模型,然后直接使用漏洞修复数据集进行对抗训练,使得模型减少源领域数据的依赖性,模型能够更好的适应目标领域的数据和特征分布,有助于减小源领域和目标领域的差异,并且提高了模型的训练速度,最终使得模型的漏洞修复准确率提升。

发明内容

发明目的:本发明的目的是设计一种泛化能力强、鲁棒性强和修复准确率高的漏洞修复方法,以适应漏洞修复数据集规模小的现状。

技术方案:为了解决上述技术问题,本发明设计了一种基于对抗迁移学习的预训练漏洞修复方法,包括以下步骤:

S100.构建浅编码器-深解码器架构的代码生成器模型;

S200.基于步骤S100,利用函数级别的大型代码数据集对所述代码生成器模型使用改进的预训练技术进行预训练,得到预训练的代码生成器模型;

S300.基于步骤S200,提取所述代码生成器模型的编码器组构建判别器模型;

S400.基于步骤S200和步骤S300,利用所述预训练代码生成器模型和判别器模型构建生成对抗网络;利用函数级别的漏洞修复数据集对所述生成对抗网络进行再训练,得到适用于修复漏洞代码的最优代码生成器模型;

S500.基于步骤S400,将函数级别的漏洞代码输入所述最优代码生成器模型,得到修复的代码。

进一步的,在步骤S100中,步骤S100具体为:

所述编码器和解码器是基于CodeT5模型中的编码器和解码器,所述浅编码器-深解码器架构表示代码生成器模型中解码器数量多于编码器数量。

进一步的,在步骤S200中,步骤S200包括以下步骤:

S210.利用初始Unigram LM(一元语言模型)分词器将所述函数级别的大型代码数据集转为代码token序列,得到预训练的分词器、代码token序列;

S220.基于步骤S100和步骤S210,利用改进的因果语言建模技术对所述代码生成器模型进行第一步预训练,得到初步预训练的代码生成器模型;

S230.基于步骤S210和步骤S220,利用改进的Span Denoising(跨度去噪)技术对所述初步预训练的代码生成器模型进行第二步预训练,得到预训练的代码生成器模型;

其中,所述改进的Span Denoising技术包括:

在编码器的输入token序列中按50%的概率替换10%的token“[TOKEN 0],,,[TOKEN n]”为预定义token“[LABEL 0],,,[LABEL n]”,并在其之前添加特殊token“[SOM]”;在正确的token序列之前添加特殊token“[EOM]”作为解码器输出的目标token序列;让解码器生成被替换的token序列“[TOKEN 0],,,[ TOKEN n]”,得到预训练的代码生成器模型。

进一步的,在步骤S220中,步骤S220包括以下步骤:

S221.在所述代码token序列中的5%到100%之间按照50%的概率选择一个token;在所选token之前的token序列的后面添加一个特殊token“[GOB]”;将添加特殊token后的token序列作为模型输入,将所选token之后的token序列作为模型输出;

S222.在所述代码token序列中的5%到100%之间按照50%的概率选择一个token;在所选token之后的token序列的前面添加一个特殊token“[GOF]”;将添加特殊token后的token序列作为模型输入,将所选token之前的token序列作为模型输出,得到初步预训练的代码生成器模型。

进一步的,在步骤S300中,步骤S300包括以下步骤:

S310.基于步骤S200,提取所述预训练的代码生成器模型的编码器,得到编码器组;

其中,所述编码器组包含所述预训练的代码生成器模型编码器组的参数;

S320.基于步骤S310,将所述编码器组与线性变化层、输出层组合,得到判别器模型;

进一步的,在步骤S400中,步骤S400包括以下步骤:

S410.基于步骤S200和步骤S300,利用所述预训练代码生成器模型和判别器模型构建生成对抗网络;

S420.基于步骤S210,利用所述预训练的分词器对函数级别的漏洞修复数据集分词得到漏洞代码token序列和修复代码token序列;

S430.基于步骤S410和步骤S420,将所述漏洞代码token序列和修复代码token序列同时输入所述生成对抗网络的代码生成器模型得到生成概率序列;

同时,所述代码生成器模型学习所述生成概率序列与输入的修复代码token序列之间的差异,得到损失值a;

S440.基于步骤S410、步骤S420和步骤S430,利用Nucleus Sampling(又称Top-pSampling,核心采样)算法对所述生成概率序列进行最优排列得到漏洞代码修复token序列;

同时,将所述修复代码token序列和漏洞代码修复token序列输入所述生成对抗网络的判别器模型,判别器模型学习修复代码token序列和漏洞代码修复token序列的差异,得到损失值b;

S450.优化器根据损失值a和损失值b,优化代码生成器模型,得到最优的代码生成器模型。

进一步的,在步骤S500中,步骤S500包括以下步骤:

S510.基于步骤S210,利用所述预训练的分词器对函数级别的漏洞代码进行分词得到待修复漏洞代码token序列;

S520.基于步骤S400和步骤S510,将所述待修复漏洞代码token序列输入最优的代码生成器模型,得到修复代码概率序列;

S530.基于步骤S520,再次利用Nucleus Sampling算法对所述修复代码概率序列进行最优排列得到修复的代码。

本发明与现有技术相比的有益效果是:

(1)本发明使用的浅编码器-深解码器架构进行概率序列生成,相较于使用编码器和解码器数量一致的Transformer模型而言,浅编码器-深解码器架构对于代码生成任务具有更好的表现;

(2)本发明使用函数级别的大型代码数据集进行代码生成器模型预训练,相较于将漏洞修复数据集抽象化和规范化后用于Transformer模型训练而言,代码生成器模型可以学到更广泛的代码结构、语意和特征,以适应在漏洞修复数据集小的现状下修复漏洞的任务;

(3)本发明使用对抗迁移学习来进行模型训练,相较于直接训练Transformer模型训练而言,对抗训练可以通过将生成的错误修复代码用于反向训练代码生成器模型,迁移学习可以将代码生成领域知识迁移至漏洞代码修复领域中,提高了模型的鲁棒性、泛化性。

附图说明

图1是本发明的系统流程图;

图2是本发明中预训练代码生成器模型的一种实施例的流程图;

图3是本发明中构建生成对抗网络并训练得到最优代码生成器模型的一种实施例的流程图;

图4是本发明中构建的生成对抗网络的一种实施例的示意图;

图5是本发明中修复待修复漏洞代码的一种实施例的流程图。

具体实施方式

下面将结合附图和实施例,对本发明的技术方案进行清楚、完整地描述。以下实施例或者附图用于说明本发明,但不用来限制本发明的范围。

参阅图1-图5,本实施提供了一种基于对抗迁移学习的预训练漏洞修复方法,包括:

在一个实施例中,如图1所示,所述图1为本发明的系统流程。

S100.构建浅编码器-深解码器架构的代码生成器模型。

具体的,所述编码器和解码器基于CodeT5模型中的编码器和解码器,所述架构为T5架构,所述浅编码器-深解码器架构表示代码生成器模型中解码器数量多于编码器数量,所述代码生成器模型中的编码器和解码器之间通过交叉注意力层连接,所述代码生成器模型中的最后一个解码器后连接线性变化层和输出层。

例如,所述代码生成器模型的浅编码器-深解码器架构可以设置编码器数量为12个,解码器数量为18个;

其中,所述线性变化层可以采用适用于回归任务的神经网络,所述输出层可以采用适用于文本生成任务的神经网络。

例如,线性变化层采用全连接层神经网络,其激活函数可以采用ReLU激活函数;输出层可以采用Softmax概率输出层。

S200.基于步骤S100,利用函数级别的大型代码数据集对所述代码生成器模型使用改进的预训练技术进行预训练,得到预训练的代码生成器模型。

在一个实施例中,如图2所示,所述图2为如何预训练代码生成器模型,所述步骤S200包括以下步骤:

S210.利用初始Unigram LM(一元语言模型)分词器将所述函数级别的大型代码数据集转为代码token序列,得到预训练的分词器、代码token序列;

例如,所述函数级别的大型代码数据集有软件工程领域公开的CodeSearchNet数据集、Github-Code数据集等,或者是将CodeSearchNet数据集和Github-Code数据集等不同的大型代码数据集进行组合再去重的数据集;

例如,将语句“This is a test.”输入分词器,分词后得到的代码token序列为:“'_Thi', 's', '_is', '_a', '_t', 'est', '.'”。

S220.基于步骤S100和步骤S210,利用改进的因果语言建模技术对所述代码生成器模型进行第一步预训练,得到初步预训练的代码生成器模型;

其中,所述改进的因果语言建模技术分为两步;

S221.在所述代码token序列中的5%到100%之间按照50%的概率选择一个token;在所选token之前的token序列的后面添加一个特殊token“[GOB]”;将添加特殊token后的token序列作为模型输入,将所选token之后的token序列作为模型输出;

例如,对于代码token序列“'_Thi', 's', '_is', '_a', '_t', 'est', '.'”,选择的token为“_is”,则模型的输入token序列为:“'_Thi', 's', '[GOB]'”,模型的输出token序列为:“'_a', '_t', 'est', '.'”。

S222.在所述代码token序列中的5%到100%之间按照50%的概率选择一个token;在所选token之后的token序列的前面添加一个特殊token“[GOF]”;将添加特殊token后的token序列作为模型输入,将所选token之前的token序列作为模型输出,得到初步预训练的代码生成器模型;

例如,对于代码token序列“'_Thi', 's', '_is', '_a', '_t', 'est', '.'”,选择的token为“_a”,则模型的输入token序列为:“'[GOF]', '_t', 'est', '.'”,模型的输出token序列为:“'_Thi', 's', '_is'”。

S230.基于步骤S210和步骤S220,利用改进的Span Denoising(跨度去噪)技术对所述初步预训练的代码生成器模型进行第二步预训练,得到预训练的代码生成器模型;

其中,所述改进的Span Denoising技术包括:

在编码器的输入token序列中按50%的概率替换10%的token“[TOKEN 0],,,[TOKEN n]”为预定义token“[LABEL 0],,,[LABEL n]”,并在其之前添加特殊token“[SOM]”;在正确的token序列之前添加特殊token“[EOM]”作为解码器输出的目标token序列;让解码器生成被替换的token序列“[TOKEN 0],,,[ TOKEN n]”,得到预训练的代码生成器模型。

例如,编码器的输入token序列为:“'_Thi', 's', '_is', '_a', '_t', 'est','.'”,被替换的token为:“'_is'”,则替换后的token序列为:“'_Thi', 's', '[SOM]', '[LABEL 0]', '_a', '_t', 'est', '.'”,解码器输出的目标token序列为:“'_Thi', 's','[EOM]', '_is', '_a', '_t', 'est', '.'”。

S300.基于步骤S200,提取所述代码生成器模型的编码器组构建判别器模型。

在一个实施例中,所述步骤S300包括以下步骤:

S310.基于步骤S200,提取所述预训练的代码生成器模型的编码器,得到编码器组;

其中,所述编码器组包含所述预训练的代码生成器模型编码器组的参数;

S320.基于步骤S310,将所述编码器组与线性变化层、输出层组合,得到判别器模型。

例如,线性变化层和输出层可以采用和步骤S100中一样的全连接层和Softmax概率输出层。

S400.基于步骤S200和步骤S300,利用所述预训练代码生成器模型和判别器模型构建生成对抗网络;利用函数级别的漏洞修复数据集对所述生成对抗网络进行再训练,得到适用于修复漏洞代码的最优代码生成器模型。

在一个实施例中,如图3、图4所示,所述图3为如何构建生成对抗网络并训练最优代码生成器模型,所述图4为本步骤构建的生成对抗网络的一种实施例的示意图,所述步骤S400包括以下步骤:

S410.基于步骤S200和步骤S300,利用所述预训练代码生成器模型和判别器模型构建生成对抗网络;

S420.基于步骤S210,利用所述预训练的分词器对函数级别的漏洞修复数据集分词得到漏洞代码token序列和修复代码token序列;

例如,所述函数级别的漏洞修复数据集有软件安全领域公开的CVEfixes数据集、Big-Vul数据集等,或者是将CVEfixes数据集和Big-Vul数据集等不同的漏洞修复数据集进行组合去重的漏洞修复数据集。

S430.基于步骤S410和步骤S420,将所述漏洞代码token序列和修复代码token序列同时输入所述生成对抗网络的代码生成器模型得到生成概率序列;

同时,所述代码生成器模型学习所述生成概率序列与输入的修复代码token序列之间的差异,得到损失值a。

例如,损失值a可以使用交叉熵损失函数学习。

S440.基于步骤S410、步骤S420和步骤S430,利用Nucleus Sampling(又称Top-pSampling,核心采样)算法对所述生成概率序列进行最优排列得到漏洞代码修复token序列;

同时,将所述修复代码token序列和漏洞代码修复token序列输入所述生成对抗网络的判别器模型,判别器模型学习修复代码token序列和漏洞代码修复token序列的差异,得到损失值b。

例如,利用Nucleus Sampling算法对该概率序列进行最优排列得到漏洞代码修复token序列时,设置Nucleus Sampling算法的top_p=0.9(此参数表示累计概率值)、max_length=50(此参数表示生成序列的最大长度)、temperature=0.8(此参数表示采样过程中概率分布的平滑程度的参数)、num_return_sequences=50(此参数表示生成序列的数量),经过所述参数设置后,会生成50个漏洞代码修复序列,每个序列的最大长度为50个token;损失值b可以使用交叉熵损失函数学习。

S450.优化器根据损失值a和损失值b,优化代码生成器模型,得到最优的代码生成器模型。

例如,优化器可以使用AdamW优化器,以8为批尺寸对上述生成对抗神经网络进行100期训练,学习率设置为2e-5,权重衰减设置为1e-4,预热步数设置为200,梯度积累步数设置为8,得到最优的代码生成器模型。

S500.基于步骤S400,将函数级别的漏洞代码输入所述最优代码生成器模型,得到修复的代码。

在一个实施例中,如图5所示,所述图5为如何修复待修复漏洞代码,所述步骤S500包括以下步骤:

S510.基于步骤S210,利用所述预训练的分词器对函数级别的漏洞代码进行分词得到待修复漏洞代码token序列;

S520.基于步骤S400和步骤S510,将所述待修复漏洞代码token序列输入最优的代码生成器模型,得到修复代码概率序列;

S530.基于步骤S520,再次利用Nucleus Sampling算法对所述修复代码概率序列进行最优排列得到修复的代码。

例如,再次利用Nucleus Sampling算法得到修复的代码时,设置NucleusSampling算法的top_p=0.9(此参数表示累计概率值)、max_length=50(此参数表示生成序列的最大长度)、temperature=0.8(此参数表示采样过程中概率分布的平滑程度的参数)、num_return_sequences=5(此参数表示生成序列的数量),经过所述参数设置后,会生成5个漏洞代码修复序列,每个序列的最大长度为50个token。

最后应说明的是:以上所述仅为本发明的一种实施例而已,并不用于限制本发明,尽管参照前述实施例对本发明进行了详细的说明,对于本领域的技术人员来说,其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换。

凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进,均应包含在本发明的保护范围之内。

相关技术
  • 基于掩码生成对抗网络迁移学习的无监督图像修复方法
  • 一种基于生成对抗网络模型的水下图像修复方法和生成对抗网络模型训练方法
  • 基于掩码生成对抗网络迁移学习的无监督图像修复方法
技术分类

06120116482397