【Mo 人工智能技术博客】时间序列预测——DA-RNN模型_da rnn-程序员宅基地

技术标签: python  rnn  机器学习  深度学习  人工智能  

时间序列预测——DA-RNN模型

作者:梅昊铭

1. 背景介绍

传统的用于时间序列预测的非线性自回归模型(NRAX)很难捕捉到一段较长的时间内的数据间的时间相关性并选择相应的驱动数据来进行预测。本文将介绍一种基于 Seq2Seq 模型(Encoder-Decoder 模型)并结合 Attention 机制的时间序列预测方法。作者提出了一种双阶段的注意力机制循环神经网络模型(DA-RNN),能够很好的解决上述两个问题。

模型的第一部分,我们引入输入注意力机制在每个时间步选择相应的输入特征。模型的第二部分,我们使用时间注意力机制在整个时间步长中选择相应的隐藏层状态。通过这种双阶段注意力机制,我们能够有效地解决一些时序预测方面的问题。我们将这两个注意力机制模型集成在基于 LSTM 的循环神经网络中,使用标准反向传播进行联合训练。

2. DA-RNN 模型

2.1 输入与输出

输入:给定 n 个驱动序列(输入特征), X = ( x 1 , x 2 , . . . , x n ) T = ( x 1 , x 2 , . . . , x T ) ∈ R n × T X = (x^1,x^2,...,x^n)^T = (x_1,x_2,...,x_T) \in R^{n \times T} X=x1,x2,...,xnT=(x1,x2,...,xT)Rn×T T T T 表示时间步长, n n n 表示输入特征的维度。

输出: y ^ T = F ( y 1 , . . . , y T − 1 , x 1 , . . . , x T ) \hat{y}_{T}= F(y_1,...,y_{T-1},x_1,...,x_T) y^T=F(y1,...,yT1,x1,...,xT) ( y 1 , . . . , y T − 1 ) (y_1,...,y_{T-1}) (y1,...,yT1)表示预测目标过去的值,其中 y t ∈ R y_t\in R ytR ( x 1 , . . . , x T ) (x_1,...,x_T) (x1,...,xT) 为时间 T T T n n n 维的外源驱动输入序列, x t ∈ R n x_t \in R^n xtRn F ( ⋅ ) F(\cdot) F() 为模型需要学习的非线性映射函数。

2.2 模型结构

DA-RNN 模型是一种基于注意力机制的 Encoder-Decoder 模型。在编码器部分,我们引入了输入注意力机制来选择相应的驱动序列;在解码器部分,我们使用时间注意力机制来选择整个儿时间步长中相应的隐藏层状态。通过这个两种注意力机制,DA-RNN 模型能够选择最相关的输入特征,并且捕捉到较长时间内的时间序列之间的依赖关系,如图1所示。


图 1:DA-RNN 模型结构

读者福利:知道你对人工智能、Python 感兴趣,小Mo 便精心准备了这门适合零基础小白学习的《人工智能导论》9.9元 浙大教授吴超老师带你进入AI大门!

2.3 编码器

编码器本质上是一个 RNN 模型,它能够将输入序列转换为一种特征表示,我们称之为隐藏层状态。对于时间序列预测问题,给定输入 X = ( x 1 , x 2 , . . . , x T ) ∈ R n × T , x t ∈ R n X = (x_1,x_2,...,x_T) \in R^{n \times T},x_t \in R^n X=(x1,x2,...,xT)Rn×T,xtRn,在时刻 t t t ,编码器将 x t x_t xt 映射为 h t h_t ht h t = f 1 ( h t − 1 , x t ) h_t = f_1(h_{t-1},x_t) ht=f1(ht1,xt) h t ∈ R m h_t \in R^m htRm 表示编码器隐藏层在时刻 t t t 的状态, m m m 表示隐藏层的维度,KaTeX parse error: Expected group after '_' at position 2: f_̲ 为非线性激活函数,本文中我们使用 LSTM。

本文中,我们提出了一种输入注意力机制编码器。它能够适当地选择相应的驱动序列,这对时间序列预测是至关重要的。我们通过确定性注意力模型来构建一个输入注意力层。它需要将之前的隐藏层状态 h t − 1 h_{t-1} ht1 和** LSTM** 单元的** cell **状态 s t − 1 s_{t-1} st1 作为该层的输入得到:
e t k = v e T t a n h ( W e [ h t − 1 ; s t − 1 ] + U e x k ) e^k_t = v^T_etanh(W_e[h_{t-1};s_{t-1}]+U_ex^k) etk=veTtanh(We[ht1;st1]+Uexk),其中 v e ∈ R T , W e ∈ R T × 2 m , U e ∈ R T × T v_e \in R^T,W_e \in R^{T \times 2m},U_e \in R^{T \times T} veRT,WeRT×2m,UeRT×T是需要学习的参数。
输入注意力层的输出 ( e t 1 , e t 2 , . . . , e t n ) (e^1_t,e^2_t,...,e^n_t) (et1,et2,...,etn) 输入到 softmax 层得到 α t k \alpha_t^k αtk 以确保所有的注意力权重的和为1, α t k \alpha_t^k αtk 表示在时刻 t t t k k k 个输入特征的重要性。

得到注意权重后,我们可以自适应的提取驱动序列 x ~ t = ( α t 1 x t 1 , α t 2 x t 2 , . . . , α t n x t n ) \tilde x_t = (\alpha^1_tx^1_t,\alpha^2_tx^2_t,...,\alpha^n_tx^n_t) x~t=(αt1xt1,αt2xt2,...,αtnxtn),此时我们更新隐藏层的状态为 h t = f 1 ( h t − 1 , x ~ t ) h_t = f_1(h_{t-1},\tilde x_t) ht=f1(ht1,x~t)

2.4 解码器

为了预测输出 y ^ T \hat y_T y^T,我们使用另外一个 LSTM 网络层来解码编码器的信息,即 隐藏层状态 KaTeX parse error: Expected group after '_' at position 2: h_̲。当输入序列过长时,传统的Encoder-Decoder 模型效果会急速恶化。因此,在解码器部分,我们引入了时间注意力机制来选择相应的隐藏层状态。

与编码器中注意力层类似,解码器的注意力层也需要将之前的隐藏层状态 d t − 1 d_{t-1} dt1LSTM 单元的cell状态 s t − 1 ′ s'_{t-1} st1 作为该层的输入得到该层的输出:
l t i = v d T t a n h ( W d [ d t − 1 ; s t − 1 ′ ] + U d h i ) l^i_t = v^T_dtanh(W_d[d_{t-1};s'_{t-1}]+U_dh_i) lti=vdTtanh(Wd[dt1;st1]+Udhi),其中 v d ∈ R m , W d ∈ R m × 2 p , U e ∈ R m × m v_d \in R^m,W_d \in R^{m \times 2p},U_e \in R^{m \times m} vdRm,WdRm×2p,UeRm×m是需要学习的参数。通过 softmax 层,我们可以得到第 i i i 个编码器隐藏状态 h i h_i hi 对于最终预测的重要性 β t i \beta^i_t βti。解码器将所有的编码器隐藏状态按照权重求和得到文本向量 c t = ∑ i = 1 T β t i h i c_t = \sum_{i=1}^T \beta_t^ih_i ct=i=1Tβtihi,注意 c t c_t ct 在不同的时间步是不同的。

在得到文本向量之后,我们将其和目标序列结合起来得到 y ~ t − 1 = w ~ T [ y t − 1 ; c t − 1 ] + b ~ \tilde y_{t-1} = \tilde w^T[y_{t-1};c_{t-1}]+\tilde b y~t1=w~T[yt1;ct1]+b~。利用新计算得到的 y ~ t − 1 \tilde y_{t-1} y~t1,我们来更新解码器隐藏状态 d t = f 2 ( d t − 1 , y ~ t − 1 ) d_t=f_2(d_{t-1},\tilde y_{t-1}) dt=f2(dt1,y~t1),我们使用 LSTM 来作为激活函数 f 2 f_2 f2
通过 DA-RNN 模型,我们预测 y ^ T = F ( y 1 , . . . , y T − 1 , x 1 , . . . , x T ) = v y T ( W y [ d T ; c T ] + b w ) + b v \hat y_T = F(y_1,...,y_{T-1},x_1,...,x_T) = v_y^T(W_y[d_T;c_T]+b_w)+b_v y^T=F(y1,...,yT1,x1,...,xT)=vyT(Wy[dT;cT]+bw)+bv

2.5 训练过程

在该模型中,作者使用平均方差作为目标函数,利用 Adam 优化器,min-batch 为128来进行参数优化。
目标函数:
O ( y T , y ~ T ) = 1 N ∑ i = 1 N ( y ^ T i − y T i ) 2 O(y_T,\tilde y_T)=\frac{1}{N}\sum_{i=1}^N(\hat y^i_T-y_T^i)^2 OyT,y~T=N1i=1N(y^TiyTi)2

3. 实验

3.1 数据集

本文的作者采用了,两种不同的数据集来测试验证 DA-RNN 模型的效果。这里我们仅对 NASDAQ 100 Stock 数据集进行介绍。作者根据 NASDAQ 100 Stock 收集了 81 家主要公司的股票价格作为驱动时间序列,NASDAQ 100 的股票指数做目标序列。数据收集的频率为一分钟一次。该数据集包含了从2016年7月26日至2016年12月22日总共105天的数据。在本实验中,作者使用 35100 条数据作为训练集,2730条数据作为验证集,以及最后2730条数据作为测试集。

3.2 参数设置和评价指标

时间窗口的大小 T ∈ { 3 , 5 , 10 , 15 , 25 } T \in \{3,5,10,15,25\} T{ 3,5,10,15,25}。实验表明 :T=10 时,模型在验证集上的效果最好。编码器和解码器隐藏层的大小 m , p ∈ { 16 , 32 , 64 , 128 , 256 } m ,p\in\{16,32,64,128,256\} m,p{ 16,32,64,128,256}。当 m = p = 64 , 128 m=p=64,128 m=p=64,128 时,实验效果最好。

为评估模型的效果,我们考虑了三种不同的评价指标:RSME,MAE,MAPE。

3.3 模型预测

为展示 DA-RNN 模型的效果,作者将该模型和其他的模型在两个不同的数据集上的预测效果进行了对比,如表1所示。由表1可以看出,DA-RNN模型相对于其他模型,误差更小一些。DA-RNN模型在时间序列预测方面具有良好的表现。

表 1:SML 2010数据集和纳斯达克100股票数据集的时间序列预测结果

为了更好的视觉比较,我们将Encoder-Decoder 模型,Attention RNN 和 DA-RNN 模型的在纳斯达克100股票数据集上的预测结果在图2中展示出来。我们不难看出DA-RNN模型能更好地反映真实情况。

图 3:三种模型在纳斯达克100股票数据集上的预测结果

4. 总结

在本文中,我们介绍了一种基于注意力机制的双阶段循环神经网络模型。该模型由两部分组成:Encoder 和 Decoder。在编码器部分,我们引入了输入注意力机制来对输入特征进行特征提取,为相关性较高的特征变量赋予更高的权重;在解码器部分,我们通过时间注意力机制为不同时间 t t t 的隐藏状态赋予不同的权重,不断地更新文本向量,来找出时间相关性最大的隐藏层状态。Encoder 和 Decode 中的注意力层分别从空间和时间上来寻找特征表示和目标序列之间的相关性,为不同的特征变量赋予不同的权重,以此来更准确地预测目标序列。
项目源码地址:https://momodel.cn/workspace/5da8cc2ccfbef78329c117ed?type=app

5. 参考资料

  1. 论文:A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction
  2. 注意力机制详解:https://blog.csdn.net/BVL10101111/article/details/78470716
  3. 项目源码:https://github.com/chensvm/A-Dual-Stage-Attention-Based-Recurrent-Neural-Network-for-Time-Series-Prediction
  4. 数据集:https://cseweb.ucsd.edu/~yaq007/NASDAQ100_stock_data.html

欢迎关注我们的微信公众号:MomodelAI

同时,欢迎使用 「Mo AI编程」 微信小程序

以及登录官网,了解更多信息:Mo 平台

Mo,发现意外,创造可能

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44015907/article/details/102860962

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签