掌桥专利:专业的专利平台
掌桥专利
首页

基于在线多源迁移学习的快速自适应轨迹预测方法

文献发布时间:2024-04-18 19:58:21


基于在线多源迁移学习的快速自适应轨迹预测方法

技术领域

本发明涉及轨迹预测领域,具体涉及一种基于在线多源迁移学习的快速自适应轨迹预测方法。

背景技术

轨迹预测是智能监控、无人驾驶、社会机器人等领域的研究热点。目前轨迹预测主要通过在固定离线数据上使用批处理方式训练深度模型实现。然而,现实世界中预测模型需要面对实时更新的数据以及场景变化,传统批处理方式很难实时地处理获取到的新数据,限制了轨迹预测模型在现实场景中的应用。

尽管在线学习在图像分类方面有大量研究,但是在轨迹预测领域研究较为缺乏,轨迹预测模型仍以在已知公共数据集上的离线学习方式为主。但是,现实世界中轨迹数据每时每刻都在产生,场景也在不断动态变化,模型的适应能力遭遇挑战。同时,离线学习方式将不同源域的数据联合训练,没有考虑到不同源域之间的差异性。此外,离线学习方式也会导致模型无法实时提高推理能力,需要花费大量时间和成本重新训练和验证。

发明内容

本发明要解决的技术问题是克服现有技术的缺陷,提供一种基于在线多源迁移学习的快速自适应轨迹预测方法,它能够适应实时变化的场景和轨迹数据,进而对未来轨迹进行准确预测。

为了解决上述技术问题,本发明的技术方案是:一种基于在线多源迁移学习的快速自适应轨迹预测方法,包括:

将n个基学习器{f

数据{X

所述更新操作包括:

计算在线学习器f

计算每个基学习器的预测损失L

根据预测损失L

对n+1个权重

利用松弛后的权重对n个基学习器和在线学习器

进一步,利用预测损失L

其中,η表示学习率。

进一步,根据预测损失调整权重的公式为:

其中,i=1,...,n+1,β∈(0,1)为权重衰减因子。

进一步,计算学习器在t时刻的预测损失的步骤包括:

学习器根据观测轨迹X

基于未来真实轨迹Y

所述学习器包括在线学习器f

进一步,对n+1个权重

其中s为平滑系数。

进一步,基学习器引入有记忆提取模块和轨迹预测模块,在对应源域中预先训练基学习器f

在源域

通过预训练过的记忆提取模块提取键值记忆得到键值外部记忆体M′,保存每个键值记忆[H

基于k-means聚类将行人轨迹聚类为多个,选取每类对应的键值记忆得到外部记忆体M;

轨迹预测模块在源域

进一步,记忆提取模块的工作过程包括:

记忆提取模块的工作过程包括:

将行人p的历史轨迹的相对位移

行人p的完整轨迹C

最后将H

进一步,所述轨迹预测模块包括编码器h、解码器d和负责从外部记忆体M迭代查询相关键值存储的多跳注意力机制,编码器h在记忆提取模块中被预先训练,解码器d被用于输出预测的轨迹;其中,外部记忆体M'聚类后成为外部记忆体M。

采用上述技术方案后,本发明结合了在线迁移学习和集成学习来训练在线学习器,在线学习器能够适应实时变化的场景和轨迹数据,实现轨迹预测的快速自适应。

附图说明

图1为本发明的基于在线多源迁移学习的快速自适应轨迹预测方法的流程图;

图2为本发明的基学习器的记忆提取模块的结构图;

图3为本发明的基学习器的轨迹预测模块的结构图。

具体实施方式

为了使本发明的内容更容易被清楚地理解,下面根据具体实施例并结合附图,对本发明作进一步详细的说明。

假设在给定的场景中有m个行人,在时间t处的第p个行人的坐标由

假设有n个源域

下面结合具体的实施例,对上述实施例涉及的技术方案进行详细介绍。

如图1所示,一种基于在线多源迁移学习的快速自适应轨迹预测方法,包括:

将n个基学习器{f

数据{X

所述更新操作包括:

计算在线学习器f

计算每个基学习器的预测损失L

根据预测损失L

由于学习器的结合方式为权重与学习器参数乘积的结合,为了避免权重过大引起更新后的在线学习器的参数变化过大而导致预测效果太差,对n+1个权重

利用松弛后的权重对n个基学习器和在线学习器

基于神经网络的单个基学习器可能会导致欠拟合或过拟合,为了提高泛化性能,本实施例训练多个基学习器学习不同的任务,并通过合理的组合策略形成强学习器。本实施例结合了在线迁移学习和集成学习来训练在线学习器,在线学习器能够适应实时变化的场景和轨迹数据,实现轨迹预测的快速自适应。

在一个实施例中,利用预测损失L

其中,η表示学习率。

在一个实施例中,根据预测损失调整权重的公式为:

其中,i=1,...,n+1,β∈(0,1)为权重衰减因子。

在一个实施例中,计算学习器在t时刻的预测损失的步骤包括:

学习器根据观测轨迹预测行人的未来轨迹,得到未来预测轨迹;

基于未来真实轨迹和未来预测轨迹计算t时刻的预测损失;

具体公式为:

其中||·||

在一个实施例中,对n+1个权重

其中s为平滑系数。

在一个实施例中,基学习器引入有记忆提取模块和轨迹预测模块,在对应源域中预先训练基学习器f

在源域

通过预训练过的记忆提取模块提取键值记忆得到键值外部记忆体M′,保存每个键值记忆[H

基于k-means聚类将行人轨迹聚类为多个,选取每类对应的键值记忆得到外部记忆体M;

轨迹预测模块在源域

需要注意的是,得到外部记忆体M后,在预测时使用得到的历史轨迹与M中的每个键计算相似度,取相似度最高的键对应的值,将这个值送到在线学习器f

具体地,记忆提取模块负责键值记忆的读写操作,保持外部记忆体的低冗余和样本多样性。记忆提取模块的工作过程包括:

记忆提取模块的工作过程包括:

将行人p的历史轨迹的相对位移

为了降低外部记忆体M'的冗余并保持多样性,通过训练控制器(CTR)进行读写操作,CTR通过计算观测轨迹特征编码与记忆键值的余弦相似性得分查找记忆,然后从当前任务中提取典型的行人运动模式存储在外部记忆体M'中。在模型训练开始时,由于外部记忆体M'中缺少可用的观测轨迹编码和未来轨迹编码,以及行人观测轨迹的最后时刻位置,最初随机初始化内存M',并通过CTR进行更新。读写操作如下:

读取操作:在整个过程中,首先计算行人p的观测特征编码H

随后,选取余弦相似度最高的前k个键,并检索它们的对应值G

其中ω

写入操作:在该阶段中,首先计算每个未来预测轨迹

其中l

P

其中ω

L=e·(1-P

其中e的大小决定模型的预测效果,P

集成学习要求基学习器具有已知的结构。然而在外部记忆体提取期间对外部记忆体的更新和修改不能被控制,因此确保外部记忆体的外部记忆体数量的一致,采用k均值聚类根据键值记忆对应行人p的完整轨迹C

M=k-means(C

其中n

在一个实施例中,所述轨迹预测模块包括编码器h、解码器d和负责从外部记忆体M迭代查询相关键值存储的多跳注意力机制,编码器h在记忆提取模块中被预先训练,解码器d被用于输出预测的轨迹;其中,外部记忆体M'聚类后成为外部记忆体M。

观测轨迹和未来轨迹是一对相关的信息,通过多跳注意机制(MHA)预测未来轨迹,MHA用于保证输出轨迹的多样性,即用于预测多条可能的轨迹,该机制从外部记忆体中进行多次迭代推理,并使用多跳的结果作为生成的多模态轨迹。

在给定场景的行人p,观察轨迹X

键寻址:外部记忆体M每个密钥H

p

其中p

值读取:外部记忆体M中的每个值G

轨迹预测:考虑行人的历史轨迹对未来轨迹的影响,将观测轨迹H

其中ω

查询向量更新:生成轨迹后,当前的查询向量q

q

在k次迭代后,生成k个预测轨迹

其中||·||

需要注意的是,记忆提取模块生成的轨迹只用于在训练集中训练网络提取观测轨迹和未来轨迹的特征,也就是训练网络提取更好的记忆体M。而轨迹预测模块则是从输出真正的预测轨迹,相当于在真实世界中预测轨迹。

以上述依据本发明的理想实施例为启示,通过上述的说明内容,相关工作人员完全可以在不偏离本项发明技术思想的范围内,进行多样的变更以及修改。本项发明的技术性范围并不局限于说明书上的内容,必须要根据权利要求范围来确定其技术性范围。

相关技术
  • 一种基于卷积神经网络和迁移学习的无源定位轨迹数据识别方法及系统
  • 一种基于卷积神经网络和迁移学习的无源定位轨迹数据识别方法及系统
技术分类

06120116485219