掌桥专利:专业的专利平台
掌桥专利
首页

敏感度分析和强化学习的神经网络剪枝方法、系统及装置

文献发布时间:2023-06-19 09:27:35


敏感度分析和强化学习的神经网络剪枝方法、系统及装置

技术领域

本申请涉及深度学习压缩技术领域,具体而言,涉及一种基于敏感度分析和强化学习的神经网络剪枝方法。

背景技术

剪枝(prune)是卷积神经网络(CNN)的一种压缩技术,它主要用来减少卷积神经网络(CNN)计算量。剪枝算法通常情况下是通过裁剪掉神经网络权重(weight)中不重要的张量(tensor)来达到降低整个神经网络的计算量的目的。

神经网络权重(weight)中哪些张量(tensor)不重要是由其稀疏度(sparsity)判定的。稀疏度(sparsity)用来衡量张量(tensor)中0的个数与张量的大小来表示。所以裁剪掉权重(weight)中稀疏度(sparsity)较高的张量(tensor),就可以达到压缩卷积神经网络(CNN)的目的。

卷积神经网络(CNN)压缩的准则是在减少计算量的同时保证网路的精度。文献1中提到了一种敏感度分析(sensitivity analyse)解决了裁剪掉权重(weight)中稀疏度大于多少的张量(tensor)的问题。就是先独立的裁剪掉每个权重的tensor,然后再通过数据验证集来检测网络的精度。通过这种方法可以分析出每个权重的敏感度,用以确认该裁剪掉当前权重的多少张量(tensor)。基于敏感度分析(sensitivity analyse)的剪枝办法,主要针对的是独立的权重,并没有考虑到不同权重之间的相关性,所以往往并不能取得较好的压缩效率。

基于强化学习的神经网络剪枝是一种自动化的剪枝技术,可以自动分析出神经网络权重(weight)的稀疏度(sparsity),然后做出合理的决策对网路进行剪枝,并且在多数情况下被剪枝的网络的网络精度和压缩率都比较好。

这种基于强化学习的神经网络剪枝分为三步:随机对神经网络的多个权重(权重)进行随机裁剪,然后对裁剪后的网络进行微调(fine tuning)记录网络的精度。然后把裁剪办法和裁剪后神经网络的精度一起记录下来放入数据缓冲区。第一步重复n次,在缓冲区积累一定数据后,则进行第二步利用数据缓冲区中的数据训练强化学习代理(agent),并利用第二步训练的代理(agent),预测具体的裁剪动作,利用该方法对网络进行剪枝,对第二步剪枝后的网络,进行微调(fine tuning),记录下微调(fine tuning)后的网络精度,将第二步预测得到的裁剪动作和第三步微调(fine tuning)后的网络精度放入缓冲区。然后跳到第二步。当第三步微调(fine tuning)后的网络精度达到预期,则停止循环。基于强化学习的方法,就是要教会网络采用什么样的办法对网络进行裁剪,才能够得到高的回报(网络精度)。这就需要训练代理(agent)的数据够好,包含的信息够全。但是在上面提到的第一步经过随机的对多个权重(weight)进行裁剪后,有时所获得的网络精度并不理想,利用这些“不好”的数据很难训练出有效的代理(agent),或者使训练代理(agent)的时间增加。基于强化学习的剪枝办法,虽然考虑单个权重(weight)和多个权重(weight)裁剪对网络精度的影响。但有时候因为无法得到有效的数据,无法训练出比较好的代理(agent),这些不好的代理(agent),往往无法产生好的剪枝办法。

发明内容

1、本发明的目的

本发明为了解决强化学习方法中的训练数据无法包含全部信息导致的网络精度不高的问题,而提出了一种基于敏感度分析和强化学习的神经网络剪枝方法。

2、本发明所采用的技术方案

本发明提出了一种敏感度分析和强化学习的神经网络剪枝方法,包括:

设定稀疏度阈值步骤,选择低敏感度的权重进行剪枝;

获取裁剪办法和精度步骤,根据上述的敏感度权重确定需要进行随机剪枝的权重;对被选定的每一个权重进行随机裁剪,将多次随机裁剪的剪枝办法和精度放入缓冲区;

训练强化学习步骤,利用缓冲区中的数据训练强化学习代理,训练后生成的裁剪办法和精度放入缓冲区;重复进行,直到网络精度达到预设值。

优选的,设定稀疏度阈值步骤,选择低敏感度的权重进行剪枝,即:设定各权重的稀疏度阈值,采用当前稀疏度阈值裁剪后,网络下降的精度保持在预设范围内。

优选的,获取裁剪办法和精度步骤,对被选定的每一个权重进行随机裁剪,保证其稀疏度小于稀疏度阈值。

优选的,训练强化学习步骤,利用缓冲区中的数据训练强化学习代理,利用训练后生成的代理,当前选定的权重进行预测,确定对应的裁剪办法,然后利用生成的裁剪办法对各网络权重进行裁剪,最后对裁剪后的网络进行多次微调,记录最终的网络精度,并将训练后的裁剪办法和精度放入缓冲区。

优选的,网络下降的精度保持在20%以内。

本发明提出了一种敏感度分析和强化学习的神经网络剪枝系统,包括:

设定稀疏度阈值模块,用于选择低敏感度的权重进行剪枝;

获取裁剪办法和精度模块,用于根据上述的敏感度权重确定需要进行随机剪枝的权重;对被选定的每一个权重进行随机裁剪,将多次随机裁剪的剪枝办法和精度放入缓冲区;

训练强化学习模块,用于利用缓冲区中的数据训练强化学习代理,训练后生成的裁剪办法和精度放入缓冲区;重复进行,直到网络精度达到预设值。

优选的,设定稀疏度阈值模块,用于选择低敏感度的权重进行剪枝,即:设定各权重的稀疏度阈值,采用当前稀疏度阈值裁剪后,网络下降的精度保持在预设范围内。

优选的,获取裁剪办法和精度模块,用于对被选定的每一个权重进行随机裁剪,保证其稀疏度小于稀疏度阈值。

优选的,训练强化学习模块,用于利用缓冲区中的数据训练强化学习代理,利用训练后生成的代理,当前选定的权重进行预测,确定对应的裁剪办法,然后利用生成的裁剪办法对各网络权重进行裁剪,最后对裁剪后的网络进行多次微调,记录最终的网络精度,并将训练后的裁剪办法和精度放入缓冲区。

优选的,网络下降的精度保持在20%以内。

本发明提出了一种敏感度分析和强化学习的神经网络剪枝装置,包括存储器和处理器,存储器存储有计算机程序,所述处理器执行所述计算机程序时实现所述的方法步骤。

本发明提出了一种计算机可度存储介质,其上存储有计算机程序,所述的计算机程序被处理器执行时实现所述的方法步骤。

3、本发明所采用的有益效果

(1)本发明选择低敏感度的权重进行剪枝,设定各权重的稀疏度阈值保证被裁剪的权重采用当前稀疏度进程裁剪后,网络下降的精度保持在预设范围以内。在保证网络高精度,

(2)本发明中提到的是一种结构剪枝办法而非元素剪枝,在保证高精度的情况下,最大化的提升了神经网络的压缩率,具体为:

压缩率有两种衡量方式,一种是基于模型参数多少的比值

(3)现有技术中针对模型大小进行剪枝,比较容易控制神经网络模型的大小,但是很难比较好的控制神经网络的计算量MACs(乘加次数),并且剪枝后的网络需要部署在特定的硬件上,不具有普遍性。

综上,本发明采用的方法可以同时针对模型计算量和存储空间进行剪枝。

附图说明

为了更清楚地说明本申请实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。

图1为本发明流程图;

图2为本发明获取裁剪办法和精度步骤流程图;

图3为本发明训练强化学习步骤。

具体实施方式

下面结合本发明实例中的附图,对本发明实例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域技术人员在没有做创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。

下面将结合附图对本发明实例作进一步地详细描述。

实施例1

利用敏感度分析(sensitivity analyse)对神经网络进行分析,使强化学习的数据缓冲区中的初始数据是在一定范围内随机进行。

通过敏感度(sensitivity analyse)可以确定出哪些权重敏感度较高,不能进行剪枝,那些权重敏感度过低,可以进行剪枝。

本发明提出的一种敏感度分析和强化学习的神经网络剪枝方法,如图1所示,包括:

S100、选择低敏感度的权重进行剪枝,设这些被选定的权重为W(w0,w1,w2,...wn),然后设定各权重的稀疏度阈值T(t0,t1,t2...tn)。这些阈值的选定必须保证被裁剪的权重采用当前稀疏度进行裁剪后,网络下降的精度保持在20%以内。

S200、根据步骤S100得到的W确定需要进行随机剪枝的权重。如图2所示,包括:

S201、对被选定的每一个权重(weight)wi进行随机裁剪,保证其稀疏度小于ti。进行m轮实验,对每一轮剪枝的神经网络进行p次微调(fine tuning),并记录下微调(finetuning)后的网络精度;将这m轮实验的剪枝办法和精度放入缓冲区B。

S300、利用缓冲区B中的数据训练强化学习代理(agent),如图3所示,包括:

S301、利用训练后生成的代理(agent),当前选定的权重进行预测,确定对应的裁剪办法;

S302、利用生成的裁剪办法对各网络权重进行裁剪;

S303、对裁剪后的网络进行p次微调(fine tuning),记录最终的网络精度,并将裁剪办法和精度放入缓冲区B。重复进行步骤300,直到网络精度达到要求。

本发明提出的一种敏感度分析和强化学习的神经网络剪枝系统,包括:设定稀疏度阈值模块、获取裁剪办法和精度模块、训练强化学习模块;

设定稀疏度阈值模块,用于选择低敏感度的权重进行剪枝;

获取裁剪办法和精度模块,用于根据上述的敏感度权重确定需要进行随机剪枝的权重;对被选定的每一个权重进行随机裁剪,将多次随机裁剪的剪枝办法和精度放入缓冲区;

训练强化学习模块,用于利用缓冲区中的数据训练强化学习代理,训练后生成的裁剪办法和精度放入缓冲区;重复进行,直到网络精度达到预设值。

其中,设定稀疏度阈值模块,用于选择低敏感度的权重进行剪枝,即:设定各权重的稀疏度阈值,采用当前稀疏度阈值裁剪后,网络下降的精度保持在预设范围内。

其中,获取裁剪办法和精度模块,用于对被选定的每一个权重进行随机裁剪,保证其稀疏度小于稀疏度阈值。

其中,训练强化学习模块,用于利用缓冲区中的数据训练强化学习代理,利用训练后生成的代理,当前选定的权重进行预测,确定对应的裁剪办法,然后利用生成的裁剪办法对各网络权重进行裁剪,最后对裁剪后的网络进行多次微调,记录最终的网络精度,并将训练后的裁剪办法和精度放入缓冲区。

机器可读存储介质作为一种计算机可读存储介质,可用于存储软件程序、计算机可执行程序以及模块,如本申请实施例中的虚拟现实对象控制方法对应的程序指令/模块(所示的获取模块、第一确定模块、第二确定模块以及对象控制模块)。处理器通过检测存储在机器可读存储介质中的软件程序、指令以及模块,从而执行终端设备的各种功能应用以及数据处理,即实现上述的虚拟现实对象控制方法,在此不再赘述。

机器可读存储介质可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序;存储数据区可存储根据终端的使用所创建的数据等。此外,机器可读存储介质可以是易失性存储器或非易失性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(Read-OnlyMemory,ROM)、可编程只读存储器(Programmable ROM,PROM)、可擦除可编程只读存储器(ErasablePROM,EPROM)、电可擦除可编程只读存储器(Electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(Random Access Memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(Static RAM,SRAM)、动态随机存取存储器(Dynamic RAM,DRAM)、同步动态随机存取存储器(SynchronousDRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(Double DataRateSDRAM,DDRSDRAM)、增强型同步动态随机存取存储器(Enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(Synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(DirectRambus RAM,DR RAM)。应注意,本文描述的系统和方法的存储器旨在包括但不限于这些和任意其它适合发布节点的存储器。在一些实例中,机器可读存储介质可进一步包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至虚拟现实设备。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。

处理器可能是一种集成电路芯片,具有信号的处理能力。在实现过程中,上述方法实施例的各步骤可以通过处理器中的硬件的集成逻辑电路或者软件形式的指令完成。上述的处理器可以是通用处理器、数字信号处理器(Digital SignalProcessor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(FieldProgrammable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。可以实现或者执行本申请实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。结合本申请实施例所公开的方法的步骤可以直接体现为硬件译码处理器执行完成,或者用译码处理器中的硬件及软件模块组合执行完成。

在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、虚拟现实设备或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、虚拟现实设备或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的虚拟现实设备、数据中心等数据存储设备。所述可用介质可以是磁性介质(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘(solid state disk,SSD))等。

本申请实施例是参照根据本申请实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。

这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。

这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。

以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。

技术分类

06120112171924