RNN Pytorch实现——up主:刘二大人《PyTorch深度学习实践》

2023-11-19

b站up主:刘二大人《PyTorch深度学习实践》
教程: https://www.bilibili.com/video/BV1Y7411d7Ys?p=6&vd_source=715b347a0d6cb8aa3822e5a102f366fe
单层 R N N : t o r c h . n n . R N N + E m b e d d i n g + F C 交叉熵损失函数: n n . C r o s s E n t r o p y L o s s 优化器: o p t i m . A d a m 数据集: h e l l o → 期望输出 o h l o l 单层RNN:torch.nn.RNN+Embedding+FC \\交叉熵损失函数:nn.CrossEntropyLoss \\优化器:optim.Adam \\数据集:hello→期望输出ohlol 单层RNN:torch.nn.RNN+Embedding+FC交叉熵损失函数:nn.CrossEntropyLoss优化器:optim.Adam数据集:hello期望输出ohlol

网络结构:
在这里插入图片描述
训练过程:

在这里插入图片描述

源码:

import torch

input_size = 4
hidden_size = 4
batch_size = 1
embedding_size = 10
num_class = 4
num_layers = 1

idx2char = ['e', 'h', 'l', 'o']
x_data = [[1,0,2,2,3]] #(batch, seq_len)
y_data = [3,1,2,3,2] #(batch, seq_len)

inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)

class Model(torch.nn.Module):
  def __init__(self, input_size, hidden_size, batch_size, num_layers = 1):
    super(Model, self).__init__()
    self.batch_size=batch_size #批量大小
    self.input_size=input_size #freature in X, x_t的维度
    self.hidden_size=hidden_size #hidden层向量的维度 h_t的维度
    self.num_layers=num_layers #层数(上下)
    self.emb = torch.nn.Embedding(input_size, embedding_size)

    self.rnn = torch.nn.RNN(input_size=embedding_size,
              hidden_size=self.hidden_size,
              num_layers=num_layers,
              batch_first=True)
    self.fc = torch.nn.Linear(hidden_size, num_class)

  def forward(self, x):
    hidden = torch.zeros(self.num_layers,
              self.batch_size,
              self.hidden_size)
    x = self.emb(x)
    x,_ = self.rnn(x, hidden)
    x = self.fc(x)
    return x.view(-1, num_class)

net=Model(input_size=input_size, 
    hidden_size=hidden_size,
    batch_size=batch_size,
    num_layers=num_layers)

criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(), lr=0.05)

for epoch in range(0, 30):
  optimizer.zero_grad()
  outputs = net(inputs)
  loss = criterion(outputs,labels)
  loss.backward()
  optimizer.step()

  _,idx=outputs.max(dim=1)
  idx=idx.data.numpy()
  print('Predicted:',''.join([idx2char[x] for x in idx]),end='')
  print(',Epoch[%d/30]loss=%.3f'%(epoch,loss.item()))


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

RNN Pytorch实现——up主:刘二大人《PyTorch深度学习实践》 的相关文章

  • 如何在pytorch中查看DataLoader中的数据

    我在 Github 上的示例中看到类似以下内容 如何查看该数据的类型 形状和其他属性 train data MyDataset int 1e3 length 50 train iterator DataLoader train data b
  • 使 CUDA 内存不足

    我正在尝试训练网络 但我明白了 我将批量大小设置为 300 并收到此错误 但即使我将其减少到 100 我仍然收到此错误 更令人沮丧的是 在 1200 个图像上运行 10 epoch 大约需要 40 分钟 有什么建议吗 错了 我怎样才能加快这
  • 为什么 pytorch matmul 在 cpu 和 gpu 上执行时得到不同的结果?

    我试图找出 numpy pytorch gpu cpu float16 float32 数字之间的舍入差异 而我发现的内容让我感到困惑 基本版本是 a torch rand 3 4 dtype torch float32 b torch r
  • Pytorch Tensor 如何获取元素索引? [复制]

    这个问题在这里已经有答案了 我有 2 个名为x and list它们的定义如下 x torch tensor 3 list torch tensor 1 2 3 4 5 现在我想获取元素的索引x from list 预期输出是一个整数 2
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • PyTorch 中的交叉熵

    交叉熵公式 但为什么下面给出loss 0 7437代替loss 0 since 1 log 1 0 import torch import torch nn as nn from torch autograd import Variable
  • 在Pytorch中计算欧几里得范数..理解和实现上的麻烦

    我见过另一个 StackOverflow 线程讨论计算欧几里德范数的各种实现 但我很难理解特定实现的原因 如何工作 该代码可以在 MMD 指标的实现中找到 https github com josipd torch two sample b
  • TensorFlow 相当于 PyTorch 的 Transforms.Normalize()

    我正在尝试推断最初在 PyTorch 中构建的 TFLite 模型 我一直在遵循PyTorch 实现 https github com leoxiaobin deep high resolution net pytorch blob 1ee
  • ValueError:使用火炬张量时需要解压的值太多

    对于神经网络项目 我使用 Pytorch 并使用 EMNIST 数据集 已经给出的代码加载到数据集中 train dataset dsets MNIST root data train True transform transforms T
  • PyTorch LSTM 中的“隐藏”和“输出”有什么区别?

    我无法理解 PyTorch 的 LSTM 模块 以及类似的 RNN 和 GRU 的文档 关于输出 它说 输出 输出 h n c n 输出 seq len batch hidden size num directions 包含RNN最后一层的
  • 无法在 Windows 10 上构建 Detectron2

    尽管 Windows 上的 Detectron2 没有官方支持 但有很多可用的说明 我尝试按照这些说明进行操作 但最终出现了相同的错误 这是我的设置 OS Windows 10 专业版 19043 1466 微软视觉工作室 2019 CUD
  • softmax_cross_entropy_with_logits 的 PyTorch 等效项

    我想知道 TensorFlow 是否有等效的 PyTorch 损失函数softmax cross entropy with logits TensorFlow 是否有等效的 PyTorch 损失函数softmax cross entropy
  • 如何屏蔽 PyTorch 权重参数中的权重?

    我正在尝试在 PyTorch 中屏蔽 强制为零 特定权重值 我试图掩盖的权重是这样定义的def init class LSTM MASK nn Module def init self options inp dim super LSTM
  • 没有名为“torch”或“torch.C”的模块

    希望得到像我 5 这样的解释 因为我已经检查了所有相关答案 但没有一个有帮助 我已经安装了Python 我已经安装了Pycharm 我已经安装了Anaconda 我已经安装了 Microsoft Visual Studio 我有not安装了
  • 使用 Huggingface 变压器仅保存最佳权重

    目前 我正在使用 Huggingface transformers 构建一个新的基于 Transformer 的模型 其中注意力层与原始模型不同 我用了run glue py检查我的模型在 GLUE 基准测试上的性能 但是 我发现huggi
  • 如何让火车装载机使用特定数量的图像?

    假设我正在使用以下调用 trainset torchvision datasets ImageFolder root imgs transform transform trainloader torch utils data DataLoa
  • PyTorch 中的数据增强

    我对 PyTorch 中执行的数据增强有点困惑 现在 据我所知 当我们执行数据增强时 我们保留原始数据集 然后添加它的其他版本 翻转 裁剪 等 但 PyTorch 中似乎并没有发生这种情况 据我从参考文献中了解到 当我们使用data tra
  • 如何在 Google Colab 上安装 PyTorch v1.0.0+?

    PyTorch v1 0 0 稳定版是发布于 2018 年 12 月 8 日 https github com pytorch pytorch releases tag v1 0 0成为之后7个月前宣布 https code fb com
  • CUDA 与 DataParallel:为什么有区别?

    我有一个简单的神经网络模型 我应用cuda or DataParallel 在模型上如下所示 model torch nn DataParallel model cuda OR model model cuda 当我不使用 DataPara
  • 从 torch.autograd.gradcheck 导入 zero_gradients

    我想复制代码here https github com LTS4 DeepFool blob master Python deepfool py 并且我在 Google Colab 中运行时收到以下错误 ImportError 无法导入名称

随机推荐

  • VMWARE虚拟机更新Ubuntu卡在登陆界面的问题解决

    昨天在搭建开发环境的时候 需要安装一些图形包和升级系统的组件 升级重启后 发现系统进不去了 如下图所示 我的是VMWARE虚拟机 不存在独显驱动问题 所以排除这个问题 将lightdm组件重新装一次 问题可以解决 步骤如下 1 重启 看到如
  • Cuda Streams的概述(四)-- 同步

    同步 同步的APIs 同步所有的事情 阻塞host端 直到所有的CUDA调用完成 cudaDeviceSynchronize 同步主机端特定的流 阻塞host端 直到流里的CUDA调用完成 cudaStreamSynchronize str
  • PyQt开发样例: 利用QToolBox开发的桌面工具箱Demo

    老猿Python博文目录 专栏 使用PyQt开发图形界面Python应用 老猿Python博客地址 一 引言 toolBox工具箱是一个容器部件 对应类为QToolBox 在其内有一列从上到下顺序排列的标签部件项 tabbed widget
  • (转)AI技术能给金融带来什么

    AI技术能给金融带来什么 2017 04 13 今日投资官微 来源 维基百科 文因互联分析 人工智能的热潮被AlphaGo带到顶点 然而在人工智能的学科发展史上是有繁荣期和稳定期的 一个技术突破会带来一定时期内难以想象的繁荣 之后的科学发展
  • [Hadoop] 实际应用场景之 - 阿里

    http blog csdn net u010415792 article details 9151475 Hadoop在淘宝和支付宝的应用从09年开始 用于对海量数据的离线处理 例如对日志的分析 也涉及内容部分 结构化数据等 使用Hado
  • 阿里老哥独家珍藏的Java面试突击宝典,轻松应对95%秋招面试题

    临近秋招 想必有不少老哥已经在为面试做准备了 大家想必也知道现在面试就是看项目经验 基本技术 个人潜力 也就是值不值得培养 总之就是每一次面试都是对我们能力的检验 无论是软实力还是硬实力 软实力其实就是简历包装 自我介绍 与面试官交谈技巧等
  • UVA 1347 Tour

    描述 Click Here quad 给定平面上n n lt 1000 个点的坐标 按照x递增的顺序给出 各点x坐标不同 且均为整数 你的任务是设计一条路线 从最左边的点出发走到最右边的点再返回 要求除了最左边和最右边之外 每个点恰好经过一
  • java中什么是并发,如何解决?

    多个进程或线程同时 或着说在同一段时间内 访问同一资源会产生并发问题 银行两操作员同时操作同一账户就是典型的例子 比如A B操作员同时读取一余额为1000元的账户 A操作员为该账户增加100元 B操作员同时为该账户减去 50元 A先提交 B
  • awk读取ini配置文件

    awk读取ini配置文件 一 awk基础 二 读取ini 1 net ini文件 2 打印 三 读取特定Section的Key的值 1 设置特定值 2 查找匹配项 四 总结 一 awk基础 F 指定分割符 print 打印 0 表示整个当前
  • 软件项目管理与开发流程管理 课程

    软件项目管理与开发流程管理 课程背景 以 IT领域典型的软件开发项目管理为主线 结合业界公认最成功的Rational软件开发统一流程架构 Rational Unified Process 和奉为项目管理圣经的美国项目管理协会 PMI 项目管
  • mnist数据集及其读写格式

    mnist数据集及其读写格式 1 mnist 数据集 2 idx 数据格式 参考文献 1 mnist 数据集 mnist数据是手写的数字0 9的数据集 共包含训练集60000个样本和测试集10000个样本 mnist是NIST的子集 同时m
  • Sa-Token获取当前所有可用Token

    记录一下写的小工具 里面的逻辑只能让各位大佬自己看了 各位用的时候自己改改 TokenInfo这个类是自定义的 我这里是获取了一下当前token对应的用户最大的生命周期 各位大佬们自行享用了 获取所有有效token集合 return pub
  • Python -- The eric Python IDE

    Python The eric Python IDE ataru 21 Jul 2008 14 26 最近在看Qt4 結果就看到eric這個為了Python與Ruby開發的IDE工具 它本身是用Python Qt寫出來的 因此這正可以當作希
  • 我的世界1.12 Java崩溃,1.12.2崩溃报告求助

    官方认证版本 mod都是在百科或者CurseForge上下载的 每次崩溃都没得前兆 突然崩溃 之后打开存档就会直接崩溃 crash reports如下 Minecraft Crash Report WARNING coremods are
  • STM32使用HAL库实现按键的单击、双击、长按

    STM32使用HAL库实现按键的单击 双击 长按 目录 STM32使用HAL库实现按键的单击 双击 长按 前言 具体思路 工程配置 代码实现 实验效果 扫描以下二维码 关注公众号雍正不秃头获取更多STM32资源及干货 前言 编程开发环境 S
  • 【完结版】jmeter+ant+python自动化框架,且支持jenkins持续集成

    前言 本文是实现jmeter ant python脚本的自动化测试框架 并且把整套部署在jenkins 通过jenkins的构建来出发脚本的运行 而且还会在jenkins上展示html报告 本文记录搭建框架的整个步骤 以及遇到的问题和记录解
  • css 排除具有某个class的项

    idname li not classA display inline block 增加样式时将li中有class的去掉
  • 使用决策树进行特征选择

    使用决策树进行特征选择 决策树也是常用的特征选取方法 使用决策树集合 如随机森林等 也可以计算每个特征的相对重要性 这些重要性能够辅助进行特征选择 该方法主要使用信息增益率来进行特征选择 from sklearn import datase
  • Keil的软件仿真和硬件仿真

    一 软件仿真 Keil有很强大的软件仿真功能 通过软件仿真可以发现很多将要出现的问题 Keil的仿真可以查看很多硬件相关的寄存器 通过观察这些寄存器值的变化可以知道代码有没有正常运行 这样可以避免频繁下载程序 延长单片机Flash寿命 开始
  • RNN Pytorch实现——up主:刘二大人《PyTorch深度学习实践》

    b站up主 刘二大人 PyTorch深度学习实践 教程 https www bilibili com video BV1Y7411d7Ys p 6 vd source 715b347a0d6cb8aa3822e5a102f366fe 单层