Python时间序列LSTM预测系列学习笔记(10)-多步预测_persistence model-程序员宅基地

技术标签: LSTM预测从入门到实战  

本文是对:

https://machinelearningmastery.com/multi-step-time-series-forecasting-long-short-term-memory-networks-python/

https://blog.csdn.net/iyangdi/article/details/77881755

博文的学习笔记,博主笔风都很浪,有些细节一笔带过,本人以谦逊的态度进行了学习和整理,笔记内容都在代码的注释中。有不清楚的可以去原博主文中查看。

数据集下载:https://datamarket.com/data/set/22r0/sales-of-shampoo-over-a-three-year-period

后期我会补上我的github
源码地址:https://github.com/yangwohenmai/LSTM/tree/master/LSTM%E7%B3%BB%E5%88%97/Multi-Step%20LSTM%E9%A2%84%E6%B5%8B1

本文主要是告诉我们一种多步预测的方法。

什么是多步预测?可以理解为我取最新的3条数据去预测下一条数据,再用最近的2条数据和刚预测出来的1条数据去重新进行预测。

文章中没有对数据进行正常的预测,只是简单的把最后一次观测值作为预测值,演示了一下多步预测的流程,最后算了一下损失函数的数值。

数据上并没有参考意义,所以最后做出的图形也没有参考意义,我们只需理解其思路即可。

本人对文章的批注都加载代码的注释中,文章整理如下:

 

数据准备与模型评估

1、拆分成训练和测试数据。
训练数据=前两年香皂销售数据
测试数据=剩下一年的香皂销售数据

2、Multi-Step 预测
假设需要预测3个月的销售数据

3、模型评估
用rolling-forcast(walk-forward)方式模型验证
测试数据每个时间步,滑动一个值,预测;之后测试数据的下一个真实观测值加入模型,并预测

用RMSE评估

持久模型(Persistence Model)

他是很好的时间序列预测的基准
是最简单的预测

原理:
用当前值作为之后的预测值
 

 

# coding=utf-8
from pandas import read_csv
from pandas import DataFrame
from pandas import concat
from sklearn.metrics import mean_squared_error
from math import sqrt
from matplotlib import pyplot
from pandas import datetime


def parser(x):
    return datetime.strptime(x, '%Y/%m/%d')

# 把数据拆分,线性数据变成四个一组的监督型数据
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):
    n_vars = 1 if type(data) is list else data.shape[1]
    df = DataFrame(data)  # 数据多了行标、列标
    cols, names = list(), list()
    for i in range(n_in, 0, -1):
        cols.append(df.shift(i))
        names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]
    for i in range(0, n_out, 1):
        cols.append(df.shift(-i))
        if i == 0:
            names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]
        else:
            names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]
    agg = concat(cols, axis=1)
    agg.columns = names
    if dropnan:
        agg.dropna(inplace=True)
    return agg


# 拆分正训练+测试数据
def prepare_data(series, n_test, n_lay, n_seq):
    raw_values = series.values
    raw_values = raw_values.reshape(len(raw_values), 1)
     #转换成四个一组的监督型数据
    supervised = series_to_supervised(raw_values, n_lay, n_seq)
    supervised_values = supervised.values
    # 前3/4作为训练数据,后1/4作为预测 测试数据
    train, test = supervised_values[0:-n_test], supervised_values[-n_test:]
    return train, test


# persistence model预测
# 用上一次观察值作为之后n_seq的预测值
# 其实只是单纯的把上一次的观测值,重复三次写入一个包含三个元素的数组,作为一个包含三个元素的预测结果
def persistence(last_ob, n_seq):
    return [last_ob for i in range(n_seq)]


# 评估persistence model
# 把由
def make_forcast(train, test, n_lay, n_seq):
    forcasts = list()
    for i in range(len(test)):
        x, y = test[i, 0:n_lag], test[i, n_lag:]
        # 这里的预测其实就是抄写上一次的观测值,把观测值变成一个数组列表
        forcast = persistence(x[-1], n_seq)
        forcasts.append(forcast)
    return forcasts


# 预测评估
# 计算预测结果的损失值,把抄写的观测值结果带入运算损失值,输出。
def evaluate_forcasts(test, forcasts, n_lag, n_seq):
    for i in range(n_seq):
        actual = test[:, (n_lag + i)]
        predicted = [forcast[i] for forcast in forcasts]
        print('predicted')
        print(predicted)
        rmse = sqrt(mean_squared_error(actual, predicted))
        print('t+%d RMSE:%f' % ((i + 1), rmse))  # 1~n_seq各个长度的预测的rmse


def plot_forcasts(series, forcasts, n_test):
    # 原始数据
    pyplot.plot(series.values)
    # 预测数据
    for i in range(len(forcasts)):
        off_s = len(series) - n_test + i - 1
        off_e = off_s + len(forcasts[i]) + 1
        xaxis = [x for x in range(off_s, off_e)]
        yaxis = [series.values[off_s]] + forcasts[i]
        print('xaxis')
        print(xaxis)
        print('yaxis')
        print(yaxis)
        print('series.values[off_s]')
        print(series.values[off_s])
        pyplot.plot(xaxis, yaxis, color='red')
    pyplot.show()


series = read_csv('data_set/shampoo-sales.csv', header=0, parse_dates=[0], index_col=0, squeeze=True, date_parser=parser)

# 一步数据,预测3步
n_lag = 1
n_seq = 3
n_test = 10  # 给了最后12个月,预测3个月,则能预测的次数是10,即10个3个月,即1,2,3->4 2,3,4->5 3,4,5->6 ...
train, test = prepare_data(series, n_test, n_lag, n_seq)
print('train data')
print(train)
print('test data')
print(test)
forecasts = make_forcast(train, test, n_lag, n_seq)
print('forecasts')
print(forecasts)
# 没有任何意义,只是为了教你如何进行多步的预测,数据全是根据最后观测值编造的
evaluate_forcasts(test, forecasts, n_lag, n_seq)
plot_forcasts(series, forecasts, n_test + 2)

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/yangwohenmai1/article/details/84568633

智能推荐

关于vue中的组件渲染函数render中的scopedSlots属性和this.$scopedSlots和this.$slots的疑惑记录-程序员宅基地

文章浏览阅读1.3w次,点赞14次,收藏25次。最近研究了vue的官方文档,到组件自定义渲染函数时,对第二个属性对象参数中的scopedSlots不太明白作用是什么,官网的案例也是一笔带过,于是连查带试算是明白了他的作用,这里记录一下,希望能帮到遇到相同问题的童鞋.先说一下函数中的$slots吧,这个用起来很简单,直接获取到组件中对应的插槽虚拟节点.this.$slots.插槽名称.废话少说,直接上代码:<div id="a..._this.$scopedslots

如何运用DDD(三):领域服务-程序员宅基地

文章浏览阅读884次。本文将介绍领域驱动设计(DDD)战术模式中另一个非常重要的概念——领域服务。在前面两篇博文中,我们已经学习到了什么是值对象和实体,并且能够比较清晰的定位它们自身的行为。但是在某些时候,你..._领域服务需要用到的服务,放在什么位置

消息 无法为JSP编译类:org.apache.jasper.JasperException-程序员宅基地

文章浏览阅读1.8k次。无法为JSP编译类:记录一个较少见的问题。在maven+ssm搭建中遇到的 org.apache.jasper.JasperException解决办法:(未细究其根本原因)将tomcat有tomcat7变更为tomcat8,启动项目可正产访问!_无法为jsp编译类:

Qdrant向量数据库-程序员宅基地

文章浏览阅读1.7k次,点赞8次,收藏11次。是专为支持而设计的向量和向量,这使得它适用于各种基于的等应用。Qdrant 使用编写,即使在高负载下也能快速、可靠地工作。_qdrant

windows下mysql 备份和恢复_在cmd中执行mysql数据恢复语句-程序员宅基地

文章浏览阅读870次。Windows下备份和恢复1、备份①、用管理员身份启动cmd命令行②、进入mysql安装目录下bin文件夹内cd C:\Program Files\MySQL\MySQL Server 5.6\bin③、执行命令,回车,输入密码,备份文件在mysql bin文件下(xxx是数据库名,aaa是备份后的文件名,可以自己设置)mysqldump -u root -p xxx> aaa...._在cmd中执行mysql数据恢复语句

《数据产品经理修炼手册:从零基础到大数据产品实践》阅读笔记_4. if you can鈥檛 measure it, you can鈥檛 improve it-程序员宅基地

文章浏览阅读330次。数据产品经理修炼手册:从零基础到大数据产品实践IF you can't measure it, you can't improve it。-->意思是:如果你无法衡量,你就无法增长。数据产品:是可以发挥数据价值去辅助用户做更优决策的一种产品形式。它在用户的决策和行动过程中,可以提供更多的分析展现和数据洞察,让数据更直观、功效地驱动业务。从受众用户群体来看,数据产品可分为三类:1.企业内部使用的数据产品。2.企业针对公司推出的商业型数据产品。3.每个用户均可使用的数据产品。._4. if you can鈥檛 measure it, you can鈥檛 improve it

随便推点

.NET WinForm 文本控件加入水印文字_c#winform给图片添加水印-程序员宅基地

文章浏览阅读2.7k次。今天突然来了一个这样的需求,需要在C#的编辑框上加入一个Hint水印效果,类似如下图:以前在手机上(wp)上做过类似的效果。参考silverlight toolkit 的searchTextBox。现在要在winform下制作,开始我还以为应该有啥啥属性可以一键搞定,结果目测了一下,没有什么属性,于是乎百度了一下,网上说用win32API来做,这倒挺神奇的,参考别人做了如下列子。申明一_c#winform给图片添加水印

线性代数|线性方程组的矩阵形式_方程写成矩阵形式-程序员宅基地

文章浏览阅读2.4k次。线性代数学习笔记_方程写成矩阵形式

智能互联时代,打开未来的大门——电信物联卡-程序员宅基地

文章浏览阅读13次。首先,它采用全球通用的SIM卡标准,可以在全球范围内实现设备和网络的即插即用,实现方便快捷的连接。其次,电信物联卡具备稳定可靠的网络连接能力,不论是在城市还是偏远的乡村,不论是低功耗设备还是高频传输设备,都能提供高质量的网络服务。在智慧城市中,电信物联卡可以实现对公共设施的监控和管理,提高城市的运行效率和人们的生活质量。电信物联卡的普及和应用,不仅改变了我们的生活方式,也为未来的智能互联时代带来了无限可能。在当今信息化高速发展的时代,物联网作为技术创新的重要领域,正深刻地改变着我们的生活和工作方式。

2022 年顶级网络安全专家最爱用的10大工具_信息安全专业用到哪些软件-程序员宅基地

文章浏览阅读2.3w次,点赞35次,收藏450次。随着互联网安全威胁的不断加剧,越来越多的企业,尤其是大企业需要雇佣持有CISSP证书的网络安全专家来保护自己的网站、APP、服务、数据不受侵害,不受破坏。_信息安全专业用到哪些软件

PL/SQL 导出oracle数据库表数据的sql文件中文乱码_oracle导出sql中文乱码-程序员宅基地

文章浏览阅读6.4k次,点赞4次,收藏7次。接到了一个需求,把正式数据库的A表(2.9万+)数据,导到测试数据库的A表中。数据量小,可以用“+”这种方式数据量大只能用dmp、sql文件、DBlink由于两个数据库的表空间不一样,dmp的方式先放弃。采用导出sql文件的形式,导出来后查看文件竟然中文乱码。导入数据库后也还是乱码状态。于是去网上查找到了解决办法。 1、首先产生问题的原因是,自己电脑环境变量NLS_LANG的值和数据库字符集不一致就造成了导入之后数据是乱码的问题。 2、解决办..._oracle导出sql中文乱码

javaScript中的继承方式12种_12种继承方式-程序员宅基地

文章浏览阅读186次。javaScript中的对象的继承方式摘自《JavaScript面向对象编程指南(第二版)》_12种继承方式

推荐文章

热门文章

相关标签