CNN、RNN用于时间序列预测的代码接口和数据格式详解(pytorch)

2023-10-27

网上对时序问题的代码详解很少,这里自己整理对CNN和RNN用于时序问题的代码部分记录,便于深入理解代码每步的操作。
本文中涉及的代码:https://github.com/EavanLi/CNN-RNN-TSF-a-toy

一、1D-CNN

1. Conv1d的接口

class torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

in_channels (int):输入通道数,在时间序列背景下即为输入序列的元数,或称为特征数。
out_channels (int):输出通道数,时序预测背景下单元预测单元/多元预测多元时out_channels和in_channels保持一致。
kernel_size (int or tuple):卷积核的尺寸;卷积核的第二个维度由in_channels决定,所以实际上卷积核的大小为kernel_size * in_channels
stride (int or tuple, optional) – 卷积操作的步长。 默认:1
padding (int or tuple, optional) – 输入数据各维度各边上要补齐0的层数。 默认: 0
dilation (int or tuple, optional) – 卷积核各元素之间的距离。 默认: 1
groups (int, optional) – 输入通道与输出通道之间相互隔离的连接的个数。 默认:1
bias (bool, optional) – 如果被置为True,向输出增加一个偏差量,此偏差是可学习参数。 默认:True

2. 输入数据shape介绍及应用卷积

(1)单元时序

  • 输入数据shape介绍

举例:任意生成batch_size为5,长为50的单元时序数据。其中univariate_data.shape为torch.Size([5, 1, 50]),分别表示batch_size, 输入通道数/元数/特征数,时序长度。

univariate_data = torch.rand(5, 1, 50)
  • 卷积构建并传入数据
    举例:对单元预测单元来说,Conv1d的输入通道数和输出通道数均为1,以kernel size为3来说,构造卷积层并将数据传入。
univariate_conv = nn.Conv1d(in_channels=1, out_channels = 1, kernel_size = 3)
univariate_out = univariate_conv(univariate_data)
  • 输出数据shape介绍
    这里univariate_out的shape是torch.Size([5, 1, 48])。
  • 以前几个数据的卷积操作为例,看下图解释其计算过程:
    在这里插入图片描述

(2)多元预测多元

  • 输入数据shape介绍
    举例:任意生成batch_size为5,特征数为2且长为50的多元时序数据。其中univariate_data.shape为torch.Size([5, 2, 50]),分别表示batch_size, 输入通道数/元数/特征数,时序长度。
multivariate_data = torch.rand(5, 2, 50)
  • 卷积构建并传入数据
    举例:对多元预测多元来说,Conv1d的输入通道数和输出通道数均为特征数,以kernel size为3来说,构造卷积层并将数据传入。
multivariate_conv1 = nn.Conv1d(in_channels=2, out_channels = 2, kernel_size = 3)
multivariate_out = multivariate_conv1(multivariate_data)
  • 输出数据shape介绍
    这里multivariate_out的shape是torch.Size([5, 2, 48])。
  • 以前几个数据的卷积操作为例,看下图解释其计算过程:
    在这里插入图片描述

(3)多元预测单元

  • 输入数据shape介绍
    与多元预测多元相同
multivariate_data = torch.rand(5, 2, 50)
  • 卷积构建并传入数据
    举例:对多元预测单元来说,Conv1d的输入通道数为特征数,输出特征数为1,以kernel size为3来说,构造卷积层并将数据传入。
multivariate_conv = nn.Conv1d(in_channels=2, out_channels = 1, kernel_size = 3)
univariate_out = multivariate_conv(multivariate_data)
  • 输出数据shape介绍
    这里univariate_out的shape是torch.Size([5, 1, 48])。
  • 以前几个数据的卷积操作为例,看下图解释其计算过程:
    在这里插入图片描述

3. 1D-CNN的前馈过程

多元预测单元,长为50的历史数据,其特征数为2,预测未来30个时刻的数值,其代码如下:

  input = torch.rand(5, 2, 50)

  conv = nn.Conv1d(in_channels=2, out_channels = 1, kernel_size = 3)
  pool = nn.MaxPool1d(2, 2)
  linear = nn.Linear(24,30)

  output = conv(input)# 结束后为torch.Size([5, 1, 48])
  output = torch.relu(output) # 结束后为torch.Size([5, 1, 48])
  output = pool(output)# 结束后为torch.Size([5, 1, 24])
  output = linear(output) # 结束后为torch.Size([5, 1, 30])

预测结果查看(以上代码仅为卷积操作的一次前馈过程,卷积参数未经过训练,这个结果代码就是一个little toy):

plt.plot(input[0][0])
plt.plot(range(len(input[0][0]),len(input[0][0]) + len(output[0][0])),output[0][0].detach().numpy())

4. 1D-CNN的完整训练过程

以sinewave数据集为例,写出1维卷积的完整预测过程,并给出中间特征结果图。

#以下为数据生成、完整的训练和预测代码
# ------------------------------------ 完整的卷积预测,以单元预测单元为例,历史输入长为50,做三步预测 ------------------------------------
def series_to_supervised(data, input_length, output_length, drop=True):
  supervised_x, supervised_y = [], []
  for i in range(len(data)-input_length-output_length):  # 多余的数据抛弃
      supervised_x.append(data[i:i+input_length])
      supervised_y.append(data[i + input_length: i + input_length + output_length])
  return supervised_x, supervised_y

def sinewave(N, period, amplitude):
  x1 = np.arange(0, N, 1)
  frequency = 1/period
  theta = 0
  y = amplitude * np.sin(2 * np.pi * frequency * x1 + theta)
  return y

np.random.seed(0)

# 生成sinewave数据
N = 1500
y1 = sinewave(N, 24, 1)  # plt.plot(range(len(y1)), y1) 24指的是周期长度,1是震动幅度
y2 = sinewave(N, 168, 1.5)  # plt.plot(range(len(y2)), y2)
y3 = sinewave(N, 672, 2)  # plt.plot(range(len(y3)), y3)
y = y1 + y2 + y3 + np.random.normal(0, 0.2, N)#y = y1 + y2 + y3 + np.random.normal(0, 0.2, N)
y[672:] += 10  # 模拟从样本中间开始的突然变化

# 划分训练数据和测试数据
train_data = y[:int(len(y)*0.6)]
test_data = y[int(len(y)*0.6):]

# 转化为监督学习格式
input_length, output_length = 50, 3
train_x, train_y = series_to_supervised(train_data, input_length, output_length)
train_x, train_y = torch.Tensor(train_x), torch.Tensor(train_y)
train_x, train_y = train_x.resize(train_x.shape[0],1,train_x.shape[1]), train_y.resize(train_y.shape[0],1,train_y.shape[1])
test_x, test_y = series_to_supervised(test_data, input_length, output_length)
test_x, test_y = torch.Tensor(test_x), torch.Tensor(test_y)
test_x, test_y = test_x.resize(test_x.shape[0],1,test_x.shape[1]), test_y.resize(test_y.shape[0],1,test_y.shape[1])

# 搭建网络
class Model(nn.Module):
  def __init__(self):
      super().__init__()
      self.conv = nn.Conv1d(1, 1, 3)
      self.pool = nn.MaxPool1d(2, 2)
      self.linear = nn.Linear(24, 3) # 根据当前tensor形状和预测步长确定
  def forward(self,x):
      x.requires_grad_()
      output = self.conv(x)
      output = torch.relu(output)
      output= self.pool(output)
      output = self.linear(output)
      return output

CNN_model = Model()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(CNN_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
loss_record = [] #记录训练损失变化

# 训练网络
for epoch in range(700):
  predict_y = CNN_model(train_x)
  loss = criterion(predict_y, train_y)
  print('loss:', loss.item())
  loss_record.append(loss.item())
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

# 进行测试
with torch.no_grad():
  predict_test = CNN_model(test_x)
  predict_test = predict_test.detach().numpy()

# 绘制测试结果
plt.plot(y[:int(len(y)*0.6)], label = 'training data') # 训练数据
plt.plot(range(len(y[:int(len(y)*0.6)]), len(y)), y[int(len(y)*0.6):], label = 'True value of the testing data') # 测试数据的真实值
for sample_No in range(len(predict_test)):
  plt_range_min = sample_No+len(y[:int(len(y)*0.6)]) + input_length
  plt_range_max = plt_range_min + output_length
  plt.plot(range(plt_range_min,plt_range_max),predict_test[sample_No][0],'--') # 绘制预测结果
plt.legend()
plt.show()

下图结果图。相对来说,数据构成越复杂,有较大的跳跃/含噪声,预测结果越差。
在这里插入图片描述在这里插入图片描述在这里插入图片描述

二、RNN

1. RNN的接口

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CNN、RNN用于时间序列预测的代码接口和数据格式详解(pytorch) 的相关文章

随机推荐

  • Markdown高级

    Markdown高级 警告 Markdown 未正式支持的解决方法 概述 大多数使用 Markdown 的人会发现基本和扩展的语法元素可以满足他们的需求 但很有可能 如果你使用 Markdown 的时间足够长 你会不可避免地发现它不支持你需
  • 微信小程序报错 Error: errCode: -1

    如果你是因为请求云数据库内的数据 那就是权限问题 解决方法如下 勾选 所有用户可读 仅创建者可读写 如果你需要让所有用户都可读写那要怎么办呢 答案是创建云函数 调用云函数写入数据库 因为云函数就是创建者权限
  • RISC-V IDE MRS使用笔记(三):提升浮点计算效率

    RISC V IDE MRS使用笔记 三 提升浮点计算效率 MRS内置CH32V30X系列芯片 此系列芯片支持FPU 浮点计算单元 想要打开时需要开启相应的扩展 如下图所示 此时如果编译单精度类型的浮点变量 就会启用FPU进行浮点计算 提高
  • windows 64位 安装mvn提示 不是内部或外部命令

    在安装mvn的过程中当在mvn的目录下去执行mvn命令的时候是可以正常执行的 当设置好环境变量后执行后发现提示mvn不是内部命令 原因是设置的MAVEN HOME变量未被Path解析 解决办法是 直接把path中的 MAVEN HOME b
  • vue子组件弹窗 el-dialog

    弹窗 在vue中结合elementui 在父组件中通过点击事件 控制子组件弹窗显示 将父组件的值通过props的方式传递给子组件弹窗 子组件通过computed的方式来获取值和设置新值 在set中调用 emit update porp 方法
  • 51nod 1165 整边直角三角形的数量(两种解法)

    链接 http www 51nod com Challenge Problem html problemId 1165 直角三角形 三条边的长度都是整数 给出周长N 求符合条件的三角形数量 例如 N 120 共有3种不同的满足条件的直角3角
  • Understanding and Detecting Software Upgrade Failures in Distributed Systems

    分布式系统中的升级故障 Tips 摘要 介绍 方法论 升级故障的严重程度 升级故障的根本原因 升级故障的触发因素 测试和检测升级故障 未来研究方向 相关工作和总结 后记 Tips 作者主页 论文下载地址 摘要 升级是破坏分布式系统可用性的不
  • iOS内购(IAP,In App Purchases-在APP内部支付),设置及使用

    项目中使用到了中间货币 金币 的形式来进行功能使用 模式是使用RMB换成 金币比如 1RMB 10金币 所以会集成第三方的支付平台 使用了微信和支付宝的第三方平台过后 发现审核失败 被苹果拒绝 查了一查原因 才是因为苹果对app内的中间币的
  • 分支语句简单讲

    分支语句之if语句 if 表达式 语句1 else 语句2 if语句中if else后默认只有一条语句 若要跟多条语句 要用 把语句括起来 例如下面 if age lt 18 printf 未成年 n printf 不能喝酒 n else
  • 如何开发kanzi插件,越详细越好

    开发Kanzi插件需要使用Kanzi SDK 它提供了一系列工具和技术 可以帮助开发者实现自己的设想 并将其应用在Kanzi上 首先 我们需要从Kanzi官网上下载Kanzi SDK 然后使用IDE 如Visual Studio或Eclip
  • Python代码规范:企业级代码静态扫描-代码规范、逻辑、语法、安全检查,以及代码规范自动编排(1)

    适用于企业实际使用Python或Python框架 Tornado Django Flask等 开发的项目作为扫描目标 进行代码规范 逻辑 语法 安全检查 代码风格规范主要有几个方面 命名规范 语言规范 格式规范 其中大部分命名规范和语言规范
  • rabbitmq命令小记录

    rabbitmq学习的一些链接 http blog csdn net anzhsoft article details 19563091 检查是否有内存泄露 sudo rabbitmqctl list queues name message
  • JAVA毕业设计课设源码分享50+例

    1 基于Springboot员工薪资管理系统 2 基于server jsp智能化停车场管理系统 3 基于SSM网上点餐系统 4 基于springboot商城购物系统 5 基于springboot中小学教务管理系统 6 基于springboo
  • 如何搭建测试环境

    1 首先检查环境和本地网络是否正确 环境就是检查系统版本是否符合开发要求 系统与本地是否能连接 2 找开发要软件包 安装数据库和服务器 把压缩包拖入 或rz 一键安装或单个yum install 3 上传项目包 确认上传的路径 文档 开发
  • vs更换本地git账号

    有人认为vs中用的git账号是哪个无所谓 其实不然 git账号不同 访问的权限就不一样 那么如果想跟换git账号该怎么做呢 win7 控制面板 gt 用户帐户和家庭安全 gt 凭据管理器 编辑普通凭据中的git账号或者直接删除 然后重启vs
  • rdesktop架构解析(RDP协议分析)

    转载自 http blog csdn net songbohr article details 5309650 本文立足于rdesktop的架构层次进行解析 算是抛砖引玉 国内对RDP协议深入解析的资料到本文发布时为空白 ps 昨天在nok
  • UVM基础-sequence library

    一 sequence library的用法 1 1 sequence library在环境中的使用 uvm sequence library定义为是一堆sequence的集合 本质上其实就是uvm sequence 只不过在普通的uvm s
  • Python实现一个情人节必备表白神器——跳动的爱心,基于tkinter实现

    前言 包子们 晚上好 一般能够看到这篇文章的小伙伴 不是单身狗 那也得是一个贵族 如果你有心仪的对象啦 如果你想表白一个女生啦 如果你还在想着怎么表白女神 这不是就给大家安排好了 跳动的爱心 怎么说呢 用这个表白也可以的 万一就成了呢 哈哈
  • GOF设计模式(04)桥接模式

    简介 一 定义 1 概念 桥接 Bridge 模式 将抽象部分与其实现部分分离 使得他们都可以独立地变化 它是一种对象结构型模式 又称为接口模式 桥接符合开闭原则和单一职责原则 2 理解 在使用桥接模式时 我们首先应该识别出一个类所具有的两
  • CNN、RNN用于时间序列预测的代码接口和数据格式详解(pytorch)

    网上对时序问题的代码详解很少 这里自己整理对CNN和RNN用于时序问题的代码部分记录 便于深入理解代码每步的操作 本文中涉及的代码 https github com EavanLi CNN RNN TSF a toy 一 1D CNN 1