掌桥专利:专业的专利平台
掌桥专利
首页

自适应步长梯度下降的方法、装置、设备及可读存储介质

文献发布时间:2023-06-19 12:18:04


自适应步长梯度下降的方法、装置、设备及可读存储介质

技术领域

本发明涉及机器学习领域,特别涉及一种自适应步长梯度下降的方法、装置、设备及可读存储介质。

背景技术

在现有技术中,机器学习本质是求解一个目标函数,并以此目标函数判断未来数据走向,优化算法是求解目标函数中参数的重要工具,广泛使用的优化算法当属梯度下降,因此,提升梯度下降收敛速度、解决梯度下降算法陷入局部最优解问题、减少梯度下降在最小值附近的震荡性,对机器学习、自然语言处理、图像处理等领域都有着有重要的理论意义和实际应用价值。

梯度下降其中一个优化方向是学习率的优化,目前已有的学习率优化算法采用递减方式,即学习率随迭代递减,导致梯度下降后期收敛慢或不收敛情况。

有鉴于此,提出本申请。

发明内容

本发明公开了一种自适应步长梯度下降的方法、装置、设备及可读存储介质,旨在解决学习率随迭代递减,导致梯度下降后期收敛慢或不收敛情况。

本发明第一实施例提供了一种自适应步长梯度下降的方法,包括:

获取优化模型中的假设函数,并生成所述假设函数的损失函数;

运算所述损失函数,并获取当前时刻的损失函数值与上一时刻的损失函数值,并生成两者的比较结果;

根据所述比较结果,调用调节函数对参数的学习率进行自适应调节,以重新运算损失函数。

优选地,所述假设函数的模型为:h

其中,θ

所述损失函数的模型为:

其中,h

优选地,所述根据所述比较结果,调用调节函数对参数的学习率进行自适应调节,以重新运算损失函数,具体为:

在根据所述比较结果判断到当前时刻的损失函数值大于上一时刻的损失函数值时,增加学习率,重新生成模型参数,并通过重新生成的模型参数运算当前时刻的损失函数。

优选地,还包括:在根据所述比较结果判断到当前时刻的损失函数值小于上一时刻的损失函数值时,减小学习率,重新生成模型参数,并通过重新生成的模型参数运算上一时刻的损失函数。

优选地,所述调节函数为tanh函数。

其中,所述tanh函数的模型为λ=λ(1+tanh(J

优选地,所述模型参数的表达式为:

θ

其中,θ

本发明第二实施例提供了一种自适应步长梯度下降的装置,包括:

损失函数生成单元,用于获取优化模型中的假设函数,并生成所述假设函数的损失函数;

比较结果生成单元,用于运算所述损失函数,并获取当前时刻的损失函数值与上一时刻的损失函数值,并生成两者的比较结果;

调节函数调用单元,用于根据所述比较结果,调用调节函数对参数的学习率进行自适应调节,以重新运算损失函数。

优选地,所述假设函数的模型为:h

其中,θ

所述损失函数的模型为:

其中,h

本发明第三实施例提供了一种自适应步长梯度下降的设备,包括处理器、存储器以及存储在所述存储器中且被配置由所述处理器执行的计算机程序,所述处理器执行所述计算机程序实现如上任意一项所述的一种自适应步长梯度下降的方法。

本发明第四实施例提供了一种可读存储介质,其特征在于,存储有计算机程序,所述计算机程序能够被所述计算机可读存储介质所在设备的处理器执行,以实现如上任意一项所述的一种自适应步长梯度下降的方法。

基于本发明提供的一种自适应步长梯度下降的方法、装置、设备及可读存储介质,通过获取优化模型中的假设函数,并生成所述假设函数的损失函数,对所述损失函数进行运算,获取当前时刻的损失函数值与上一时刻的损失函数值,并生成两者的比较结果,根据比较结果,利用调节函数自适应的修改学习率,进而修改损失函数的模型参数,重新运算损失函数,解决了学习率随迭代递减,导致梯度下降后期收敛慢或不收敛情况。

附图说明

图1是本发明第一实施例提供的一种自适应步长梯度下降的方法流程示意图;

图2是本发明提供的损失函数示意图;

图3是本发明提供的在过大或过小的学习率下的示意图;

图4是本发明提供的tanh函数示意图;

图5是本发明第二实施例提供的一种自适应步长梯度下降的装置模块示意图;

具体实施方式

下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。

为了更好的理解本发明的技术方案,下面结合附图对本发明实施例进行详细描述。

应当明确,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。

在本发明实施例中使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本发明。在本发明实施例和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。

应当理解,本文中使用的术语“和/或”仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。另外,本文中字符“/”,一般表示前后关联对象是一种“或”的关系。

取决于语境,如在此所使用的词语“如果”可以被解释成为“在……时”或“当……时”或“响应于确定”或“响应于检测”。类似地,取决于语境,短语“如果确定”或“如果检测(陈述的条件或事件)”可以被解释成为“当确定时”或“响应于确定”或“当检测(陈述的条件或事件)时”或“响应于检测(陈述的条件或事件)”。

实施例中提及的“第一\第二”仅仅是是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一\第二”在允许的情况下可以互换特定的顺序或先后次序。应该理解“第一\第二”区分的对象在适当情况下可以互换,以使这里描述的实施例能够以除了在这里图示或描述的那些以外的顺序实施。

以下结合附图对本发明的具体实施例做详细说明。

本发明公开了本发明公开了一种自适应步长梯度下降的方法、装置、设备及可读存储介质,旨在解决学习率随迭代递减,导致梯度下降后期收敛慢或不收敛情况。

请参阅图1,本发明第一实施例提供了一种自适应步长梯度下降的方法,可由自适应步长梯度下降设备(以下简称梯度下降设备)来执行,特别的,由梯度下降设备内的一个或者多个处理器来执行,以实现如下步骤:

S101,获取优化模型中的假设函数,并生成所述假设函数的损失函数;

需要说明的是,在本实施例中,梯度下降的先决条件是确认优化模型的假设函数和损失函数,以线性回归为例,所述假设函数的模型可以为:h

其中,θ

所述损失函数的模型为:

其中,h

需要说明的是,梯度下降的任务是在当前模型的状态下,求得梯度,通过参数减去步长乘以梯度,进行参数矫正,直至找到最优模型,即损失函数最低的模型。

本实施例中,所述模型参数的表达式为:

θ

其中,θ

需要说明的是,梯度下降算法中,学习率是非常重要的参数,学习率过小可能导致收敛训练收敛过慢,学习率过大,可能导致函数不停震荡,无法获取最优解。如图3所示。

在本实施例中,所述调节函数可以为tanh函数,其中,tanh函数的图像如图4所示,当然,在其他实施例中,所述调节函数还可是sigmoid函数,其可以根据实际情况对应选择,这里不做具体限定,但这些方案均在本发明的保护范围内。

其中,所述tanh函数的模型为λ=λ(1+tanh(J

需要说明的是,所述调节函数,即tanh函数范围在-1到1之间变化。在梯度下降前期,损失函数下降较快,不会因为损失函数波动过大而产生巨大变化,tanh可以稳定在1附近,自动调节不断增加,通过不断加大步长而加快收敛。当损失函数开始增大,参数重新回到损失函数最小的状态,通过自动减小学习率,改变新的落脚点,使损失函数始终处于不断减小的状态,有利于减小梯度下降的震荡。最后损失函数变化几乎为0,可使学习率稳定在一个固定学习率状态,有利于提高收敛效果。

S102,运算所述损失函数,并获取当前时刻的损失函数值与上一时刻的损失函数值,并生成两者的比较结果;

S103,根据所述比较结果,调用调节函数对参数的学习率进行自适应调节,以重新运算损失函数。

具体地:在本实施例中,在根据所述比较结果判断到当前时刻的损失函数值大于上一时刻的损失函数值时,增加学习率,重新生成模型参数,并通过重新生成的模型参数运算当前时刻的损失函数。

在本实施例中,还包括:在根据所述比较结果判断到当前时刻的损失函数值小于上一时刻的损失函数值时,减小学习率,重新生成模型参数,并通过重新生成的模型参数运算上一时刻的损失函数。

需要说明的是,所述tanh函数的模型(即λ=λ(1+tanh(J

λ=λ(1+tanh(J

需要说明的是,所述tanh函数的模型(即λ=λ(1+tanh(J

应当理解的是,当损失函数是下降状态是仍然使用模型A,若损失函数开始增加,则使用模型C进行计算,模型C与式模型A的区别是式A是已增大的学习率再进行逐渐减小,式C是直接初始化学习率,进行重新计算。

请参阅图5,本发明第二实施例提供了一种自适应步长梯度下降的装置,包括:

损失函数生成单元201,用于获取优化模型中的假设函数,并生成所述假设函数的损失函数;

比较结果生成单元202,用于运算所述损失函数,并获取当前时刻的损失函数值与上一时刻的损失函数值,并生成两者的比较结果;

调节函数调用单元203,用于根据所述比较结果,调用调节函数对参数的学习率进行自适应调节,以重新运算损失函数。

优选地,所述假设函数的模型为:h

其中,θ

所述损失函数的模型为:

其中,h

本发明第三实施例提供了一种自适应步长梯度下降的设备,包括处理器、存储器以及存储在所述存储器中且被配置由所述处理器执行的计算机程序,所述处理器执行所述计算机程序实现如上任意一项所述的一种自适应步长梯度下降的方法。

本发明第四实施例提供了一种可读存储介质,其特征在于,存储有计算机程序,所述计算机程序能够被所述计算机可读存储介质所在设备的处理器执行,以实现如上任意一项所述的一种自适应步长梯度下降的方法。

基于本发明提供的一种自适应步长梯度下降的方法、装置、设备及可读存储介质,通过获取优化模型中的假设函数,并生成所述假设函数的损失函数,对所述损失函数进行运算,获取当前时刻的损失函数值与上一时刻的损失函数值,并生成两者的比较结果,根据比较结果,利用调节函数自适应的修改学习率,进而修改损失函数的模型参数,重新运算损失函数,解决了学习率随迭代递减,导致梯度下降后期收敛慢或不收敛情况

示例性地,本发明第三实施例和第四实施例中所述的计算机程序可以被分割成一个或多个模块,所述一个或者多个模块被存储在所述存储器中,并由所述处理器执行,以完成本发明。所述一个或多个模块可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述所述计算机程序在所述实现一种自适应步长梯度下降的设备中的执行过程。例如,本发明第二实施例中所述的装置。

所称处理器可以是中央处理单元(Central Processing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等,所述处理器是所述一种自适应步长梯度下降的方法的控制中心,利用各种接口和线路连接整个所述实现对一种自适应步长梯度下降的方法的各个部分。

所述存储器可用于存储所述计算机程序和/或模块,所述处理器通过运行或执行存储在所述存储器内的计算机程序和/或模块,以及调用存储在存储器内的数据,实现一种自适应步长梯度下降的方法的各种功能。所述存储器可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、文字转换功能等)等;存储数据区可存储根据手机的使用所创建的数据(比如音频数据、文字消息数据等)等。此外,存储器可以包括高速随机存取存储器,还可以包括非易失性存储器,例如硬盘、内存、插接式硬盘、智能存储卡(Smart Media Card,SMC)、安全数字(SecureDigital,SD)卡、闪存卡(Flash Card)、至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。

其中,所述实现的模块如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一个计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,所述计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。

需说明的是,以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。另外,本发明提供的装置实施例附图中,模块之间的连接关系表示它们之间具有通信连接,具体可以实现为一条或多条通信总线或信号线。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。

以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求的保护范围为准。

相关技术
  • 自适应步长梯度下降的方法、装置、设备及可读存储介质
  • 用于通过增强现实设备来提供信息的方法和装置、用于提供用于控制增强现实设备的显示的信息的方法和装置、用于控制增强现实设备的显示的方法和装置、具有用于执行方法的指令的计算机可读的存储介质
技术分类

06120113239671