掌桥专利:专业的专利平台
掌桥专利
首页

一种基于注意力机制的生成式对抗轨迹预测方法

文献发布时间:2023-06-19 10:54:12


一种基于注意力机制的生成式对抗轨迹预测方法

技术领域

本发明涉及人工智能领域,特别是涉及到一种基于注意力机制的生成式对抗轨迹预测方法。

背景技术

行人轨迹预测是指根据行人过往一段时间的运动轨迹来预测其未来一段时间的运动轨迹。随着移动服务机器人与自动驾驶等领域的兴起,在动态场景中进行行人轨迹预测成为了一个热门的研究方向。行人轨迹的正确预测有助于智能导航系统规划出更加合理有效的路径。然而行人轨迹预测问题是极其复杂的,行人的运动具有一定的随机性,其在决策的过程中也是比较主观灵活的,因而行人轨迹具有多样性的特征。其次,行人在行走的过程中,行人的轨迹受到周围动态环境的影响,行人通常会根据自己的常识以及社会规范来调整自己的路径。上述特征使得行人轨迹预测问题充满挑战性。

在行人轨迹预测问题中,如何对行人之间的交互进行有效的建模对于行人轨迹预测非常重要。目前主流的方法大都基于深度学习技术来学习行人之间的交互作用,从而对行人轨迹进行预测。其中,基于长短时记忆网络LSTM的方法已经被证明在处理时序问题时非常奏效,但是基于LSTM的方法无法对行人之间的空间关系进行有效建模。为了解决这一问题,Alahi等在LSTM网络模型的基础上提出了社会长短期记忆网络(S-LSTM),该模型通过对空间进行网格化处理,并根据网格对每个行人周围行人的不同特征进行隐藏池化,利用池化的结果预测出多条符合社会规范的轨迹(见“Social lstm:Human trajectoryprediction in crowded spaces,CVPR 2016”)。由于该方法仅仅只能建模目标行人局部区域内的行人交互作用,无法高效模拟场景中所有行人的交互。Gupta等将生成式对抗网络引入到行人轨迹预测问题当中,提出了社会对抗网络模型,通过对生成器与判别器进行反向训练以及一个池化模块提取场景中所有行人的交互信息,生成了多种符合社会规范的轨迹,并且提升了预测精度(见“Social GAN:Socially acceptable trajectories withgenerative adversarial networks,CVPR 2018”)。但是该方法在提取行人之间的交互信息时仅仅考虑了行人之间的空间位置关系,忽略了周围行人运动方向、速度等因素对目标行人未来轨迹的影响,因而无法对行人之间的交互信息进行充分提取。此外,基于生成式对抗网络的方法在网络训练过程中极易出现生成器与判别器强弱不均衡的现象,从而导致梯度消失难以训练的问题。

针对以上问题,广东工业大学申请了专利号为202010110743.X,专利名称为一种基于长短期记忆的行人轨迹预测方法,该发明公开了一种基于长短期记忆的行人轨迹预测方法,主要包括以下步骤:对数据进行预处理,转换为一个[行人数量,4]的矩阵;引入注意力机制选择对当前行人行走时的方向、速度等指标产生影响的信息,并通过全连接层连接所有当前位置信息;将同一场景下全局行人的历史状态隐藏信息输入池化层进行池化,达到“共享”全局隐藏信息的目的;通过长短期记忆单元将当前状态下所有行人的历史状态隐藏信息的池化张量,当前行人的位置信息以及经注意力机制所选择的对行人产生影响的信息,转化为长短期记忆序列信息;将当前的状态信息通过一个多层感知机结构转化到坐标空间,生成预测轨迹序列:

该专利仍然存在如下缺陷:

首先该专利在注意力机制方面,获取注意力权重的方法仅仅考虑了第i个行人相对于第j个行人的相对位置信息,并未将行人j的速度、相对于行人i的运动方向、与行人i的相对距离等因素综合考虑进去来获取注意力机制的,因此申请人改进了注意力权重的方式为:为了刻画行人j对目标行人i运动产生的影响,注意力池化模块将行人j的速度矢量v

其次本申请针对传统GAN网络在训练过程中存在的生成器与判别器强弱不匹配导致训练梯度消失难以训练的问题,通过修改损失函数,训练过程中在判别器端引入随时间减小的噪声,改善了模型的训练效果,提升了轨迹的预测精度。GAN网络的损失函数可以表示为:

L

而我们改进的GAN网络的损失函数表示为:

L

其中,h(·)表示随时间减小的噪声函数。这样改进的好处在于在网络训练初期,训练数据集数据分布和生成器生成数据分布交集很小,因此判别器可以轻易地区分真实数据与生成数据,从而网络缺乏训练梯度。因此,训练初期在判别器端添加一定的噪声使得训练数据与生成数据存在一定的交集。随着训练时间的增加,生成器生成数据的分布逐渐接近真实数据分布,此时逐渐减小噪声仍可以保证网络具备一定的训练梯度,从而改善网络的训练效果。

发明内容

为了解决上述存在问题。本发明提供一种基于注意力机制的生成式对抗轨迹预测方法,以充分提取行人之间的交互信息从而提升轨迹预测的精度。若将该方法用于服务机器人的导航规划系统中,可有助于服务机器人在与人共融的动态环境中规划出更加合理有效的路径,从而提高导航的舒适性。

本发明提供一种基于注意力机制的生成式对抗轨迹预测方法,其特征在于,包括以下步骤:

步骤1:将行人轨迹数据预处理并送入编码器进行编码处理;

步骤2:将编码后的向量送入一个基于注意力机制的池化模块进行影响力权重分配并得到池化向量;

步骤3:使用一个基于LSTM网络的解码器输出行人的预测轨迹;

步骤4:利用改进后的损失函数使用Adam算法对生成器与判别器进行对抗训练;

步骤5:将行人的观测轨迹送入训练好的网络模型的生成器中,得到预测的行人轨迹坐标。

进一步的,步骤1所述的对行人轨迹进行编码处理,包括:

网络接收行人的历史轨迹并通过一个单层的全连接网络作为嵌入层,将行人i在t时刻的位置变化信息

其中,f(·)是一个采用ReLU激活函数的嵌入层,W

进一步的,步骤2所述的通过基于注意力机制的池化模块来对同一场景中的行人进行影响力权重分配,并输出表征行人交互信息的池化向量,包括:

为了刻画行人j对目标行人i运动产生的影响,模块首先使用池化的方式获取池化向量h

之后,将场景中所有其他行人相对于目标行人i的池化向量汇聚为最终的池化向量H

q

q

W

H

p

其中,s(·)表示使用softmax激活函数的多层全连接网络,W

进一步的,步骤3所述的利用一个基于LSTM网络的解码器输出行人的预测轨迹,包括:

将注意力池化模块输出的池化向量

其中,j(·)、m(·)和g(·)均为带ReLU激活函数的全连接网络,W

进一步的,步骤4所述的对生成器和判别器利用改进的损失函数进行反向训练,包括:

利用改进的损失函数使用Adam算法对网络进行反向训练,改进的损失函数主要包含两个部分,一部分是GAN网络的对抗损失L

假设真实的训练数据x所代表的分布为p

L

但是传统的GAN网络在训练过程中极易出现判别器判别能力过强从而可以轻易的区分生成器的生成数据与训练集的真实数据,从而造成梯度消失无法训练的情况,为了解决传统GAN网络训练困难的问题,步骤4对GAN网络在训练过程中对判别器端的损失函数施加随时间减小的噪声,使得训练数据与生成数据存在一定的交集,随着训练时间的增加,生成器生成数据的分布逐渐接近真实数据分布,此时逐渐减小噪声仍可以保证网络具备一定的训练梯度;因此,改进后的对抗损失L

L

其中,h(·)表示随时间减小的噪声函数;

为了鼓励网络生成多种符合社会规范的轨迹,网络每次采样k个预测轨迹,并选取位置偏移误差最小的轨迹用于计算位置偏移损失,因此,网络的位置偏移损失L

其中,Y

因此,网络总体的损失函数表示为:

L

其中,l为超参数。

进一步的,步骤5所述的将行人的观测轨迹送入到生成器中,即可得到预测的行人轨迹坐标,包括:

依次执行步骤1、步骤2、步骤3,即将行人的观测轨迹送入编码器中进行编码处理从而获取行人运动的隐藏特征,并通过注意力池化模块提取行人的交互信息,最后通过一个解码器输出行人的预测轨迹坐标。

与现有的技术相比,本发明提供的技术方案具有以下有益效果:

1.针对现有方法无法充分提取行人之间交互信息的缺点,通过引入一个注意力池化模块,将行人的运动方向、速度等要素与他们的未来轨迹关联起来,并以此对同一场景中的行人进行影响力权重分配,从而更加有效地提取行人之间的交互信息,同时提升了模型的可解释性。

2.针对生成式对抗网络在训练过程中存在的生成器与判别器强弱不匹配导致训练梯度消失难以训练的问题,通过修改损失函数,训练过程中在判别器端引入随时间减小的噪声,改善了模型的训练效果,提升了轨迹的预测精度。

附图说明

图1为本发明工作流程示意图;

图2为网络模型的整体结构图;

图3为注意力池化模块示意图;

图4为GAN网络训练过程示意图;

图5为预测轨迹可视化对比图。

具体实施方式

下面结合附图与具体实施方式对本发明作进一步详细描述:

本发明提供一种基于注意力机制的生成式对抗轨迹预测方法,以充分提取行人之间的交互信息从而提升轨迹预测的精度。若将该方法用于服务机器人的导航规划系统中,可有助于服务机器人在与人共融的动态环境中规划出更加合理有效的路径,从而提高导航的舒适性。

如图1和2所示,为本发明网络模型的整体结构图,该网络模型主要包括一个生成器模块和判别器模块。生成器模块基于编码器-解码器架构,由编码器、注意力池化模块、解码器三部分组成,生成器接收行人的历史轨迹,并将行人的轨迹经过编码器编码得到行人运动的隐藏特征,然后经过一个结合注意力机制的池化模块来提取行人的交互信息,最终通过解码器模块输出网络预测的行人位置坐标。判别器模块主要由一个编码器模块组成,其接受轨迹输入并通过一个编码器对轨迹编码,之后通过一个分类网络对轨迹的真实程度进行打分。

本发明所提出的方法,具体包括以下步骤:

步骤1:将行人轨迹数据预处理并送入编码器进行编码;

网络接收行人的历史轨迹并通过一个单层的全连接网络作为嵌入层,将行人i在t时刻的位置变化信息

其中,f(·)是一个采用ReLU激活函数的嵌入层,W

步骤2:将编码后的向量送入一个基于注意力机制的池化模块进行影响力权重分配并得到池化向量;

行人的未来轨迹总是受到前面行人的影响,并且与这些行人的速度、运动方向、相对距离等要素有关。如图3所示,目标行人1的未来轨迹主要受视线前方的行人2和3影响,其几乎不受行人4的影响。并且行人2的速度越大,与行人1的相对距离越小,其对行人1的轨迹影响就越大。

为了刻画行人j对目标行人i运动产生的影响,模块首先使用池化的方式获取池化向量h

之后,将场景中所有其他行人相对于目标行人i的池化向量汇聚为最终的池化向量H

q

q

W

H

p

其中,s(·)表示使用softmax激活函数的多层全连接网络,W

步骤3:使用一个基于LSTM网络的解码器输出行人的预测轨迹;

将注意力池化模块输出的池化向量

其中,j(·)、m(·)和g(·)均为带ReLU激活函数的全连接网络,W

步骤4:利用改进后的损失函数使用Adam算法对生成器与判别器进行对抗训练;

改进的损失函数主要包含两个部分,一部分是GAN网络的对抗损失L

假设真实的训练数据x所代表的分布为p

L

但是传统的GAN网络在训练过程中极易出现判别器判别能力过强从而可以轻易的区分生成器的生成数据与训练集的真实数据,从而造成梯度消失无法训练的情况。

为了解决传统GAN网络训练困难的问题,步骤4对GAN网络在训练过程中对判别器端的损失函数施加随时间减小的噪声,如图4所示,图中深色实线代表训练集数据分布p

L

其中,h(·)表示随时间减小的噪声函数。

为了鼓励网络生成多种符合社会规范的轨迹,网络每次采样k个预测轨迹,并选取位置偏移误差最小的轨迹用于计算位置偏移损失,因此,网络的位置偏移损失L

其中,Y

因此,网络总体的损失函数可以表示为:

L

其中,l为超参数。

步骤5:将行人的观测轨迹送入训练好的网络模型的生成器中,得到预测的行人轨迹坐标;

只需依次执行步骤1、步骤2、步骤3,即将行人的观测轨迹送入编码器中进行编码处理从而获取行人运动的隐藏特征,并通过注意力池化模块提取行人的交互信息,最后通过一个解码器输出行人的预测轨迹坐标。

图5展示了三幅具有代表性的行人轨迹预测场景。在每一幅场景中,左子图表示真实的行人运动轨迹,右子图表示行人的观测轨迹与预测轨迹,其中实心圆圈和星形分别代表观测轨迹与预测轨迹。可以看出,本发明提出的方法可以捕捉到结伴而行、相互礼让等复杂的行人间交互,其预测出的轨迹比较符合实际的运动场景,并且网络预测的轨迹没有与其他轨迹发生冲突。因此整体上来说,本发明提出的网络模型输出的预测轨迹既符合社会规范,也是满足物理约束的。

表1.不同模型的ADE和FDE对比(t

本发明使用如下两个指标来刻画预测轨迹的准确性。

1)平均偏移误差(Average Displacement Error,ADE)。表示预测轨迹与真实轨迹序列在每个时间步长的欧氏距离的平均值。

2)最终偏移误差(Final Displacement Error,FDE)。表示预测轨迹与真实轨迹序列在最终时刻的欧式距离。

本发明选取最具代表性的Linear、LSTM、S-LSTM和SGAN网络模型作为比较基准,各种轨迹预测模型的对比结果如表1所示。其中,表中数据单位为米,加粗数据表示最佳结果,Atten-GAN为本发明对应的网络模型,+DN表示Atten-GAN在训练过程引入了随时间减小的噪声,-DN反之。

综合表中数据可以看出,本发明由于引入了注意力池化机制,其可以有选择地融合对目标行人未来轨迹有影响的信息,因此模型具有更强的表现力,其可以对行人的交互进行更准确的刻画。同时,训练过程中在判别器中添加随时间减小的噪声可以在一定程度上改善生成器与判别器强弱不均衡从而造成梯度消失的问题,从而进一步提升网络的预测精度。

以上所述,仅是本发明的较佳实施例而已,并非是对本发明作任何其他形式的限制,而依据本发明的技术实质所作的任何修改或等同变化,仍属于本发明所要求保护的范围。

相关技术
  • 一种基于注意力机制的生成式对抗轨迹预测方法
  • 一种基于生成式对抗网络的交通流预测方法
技术分类

06120112721350