Pytorch+LSTM+Encoder+Decoder实现Seq2Seq模型

2023-10-30

# !/usr/bin/env Python3
# -*- coding: utf-8 -*-
# @version: v1.0
# @Author   : Meng Li
# @contact: 925762221@qq.com
# @FILE     : torch_seq2seq.py
# @Time     : 2022/6/8 11:11
# @Software : PyCharm
# @site:
# @Description : 将Seq2Seq网络采用编码器和解码器两个类进行融合
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchsummary
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class my_dataset(Dataset):
    def __init__(self, enc_input, dec_input, dec_output):
        super().__init__()
        self.enc_input = enc_input
        self.dec_input = dec_input
        self.dec_output = dec_output

    def __getitem__(self, index):
        return self.enc_input[index], self.dec_input[index], self.dec_output[index]

    def __len__(self):
        return self.enc_input.size(0)


class Encoder(nn.Module):
    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.encoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, dropout=0.5, num_layers=1)  # encoder

    def forward(self, enc_input):
        seq_len, batch_size, embedding_size = enc_input.size()
        h_0 = torch.rand(1, batch_size, self.hidden_size)
        c_0 = torch.rand(1, batch_size, self.hidden_size)
        # en_ht:[num_layers * num_directions,Batch_size,hidden_size]
        encode_output, (encode_ht, decode_ht) = self.encoder(enc_input, (h_0, c_0))
        return encode_output, (encode_ht, decode_ht)


class Decoder(nn.Module):
    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.crition = nn.CrossEntropyLoss()
        self.fc = nn.Linear(hidden_size, in_features)
        self.decoder = nn.LSTM(input_size=in_features, hidden_size=hidden_size, dropout=0.5, num_layers=1)  # encoder

    def forward(self, enc_output, dec_input):
        (h0, c0) = enc_output
        # en_ht:[num_layers * num_directions,Batch_size,hidden_size]
        de_output, (_, _) = self.decoder(dec_input, (h0, c0))
        return de_output


class Seq2seq(nn.Module):
    def __init__(self, encoder, decoder, in_features, hidden_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.fc = nn.Linear(hidden_size, in_features)
        self.crition = nn.CrossEntropyLoss()

    def forward(self, enc_input, dec_input, dec_output):
        enc_input = enc_input.permute(1, 0, 2)  # [seq_len,Batch_size,embedding_size]
        dec_input = dec_input.permute(1, 0, 2)  # [seq_len,Batch_size,embedding_size]
        # output:[seq_len,Batch_size,hidden_size]
        _, (ht, ct) = self.encoder(enc_input)  # en_ht:[num_layers * num_directions,Batch_size,hidden_size]
        de_output = self.decoder((ht, ct), dec_input)  # de_output:[seq_len,Batch_size,in_features]
        output = self.fc(de_output)
        output = output.permute(1, 0, 2)
        loss = 0
        for i in range(len(output)):  # 对seq的每一个输出进行二分类损失计算
            loss += self.crition(output[i], dec_output[i])
        return output, loss


def make_data(seq_data):
    enc_input_all, dec_input_all, dec_output_all = [], [], []
    vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]
    word2idx = {j: i for i, j in enumerate(vocab)}
    V = np.max([len(j) for i in seq_data for j in i])  # 求最长元素的长度
    for seq in seq_data:
        for i in range(2):
            seq[i] = seq[i] + '?' * (V - len(seq[i]))  # 'man??', 'women'

        enc_input = [word2idx[n] for n in (seq[0] + 'E')]
        dec_input = [word2idx[i] for i in [i for i in len(enc_input) * '?']]
        dec_output = [word2idx[n] for n in (seq[1] + 'E')]

        enc_input_all.append(np.eye(len(vocab))[enc_input])
        dec_input_all.append(np.eye(len(vocab))[dec_input])
        dec_output_all.append(dec_output)  # not one-hot

    # make tensor
    return torch.Tensor(enc_input_all), torch.Tensor(dec_input_all), torch.LongTensor(dec_output_all)


def translate(word):
    vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]
    idx2word = {i: j for i, j in enumerate(vocab)}
    V = 5
    x, y, z = make_data([[word, "?" * V]])
    if not os.path.exists("translate.pt"):
        train()
    net = torch.load("translate.pt")
    pre, loss = net(x, y, z)
    pre = torch.argmax(pre, 2)[0]
    pre_word = [idx2word[i] for i in pre.numpy()]
    pre_word = "".join([i.replace("?", "") for i in pre_word])
    print(word, "->  ", pre_word[:pre_word.index('E')])


def train():
    vocab = [i for i in "SE?abcdefghijklmnopqrstuvwxyz上下人低国女孩王男白色高黑"]
    word2idx = {j: i for i, j in enumerate(vocab)}
    idx2word = {i: j for i, j in enumerate(vocab)}
    seq_data = [['man', '男人'], ['black', '黑色'], ['king', '国王'], ['girl', '女孩'], ['up', '上'],
                ['high', '高'], ['women', '女人'], ['white', '白色'], ['boy', '男孩'], ['down', '下'], ['low', '低'],
                ['queen', '女王']]
    enc_input, dec_input, dec_output = make_data(seq_data)
    batch_size = 3
    in_features = len(vocab)
    hidden_size = 128

    train_data = my_dataset(enc_input, dec_input, dec_output)
    train_iter = DataLoader(train_data, batch_size, shuffle=True)

    encoder = Encoder(in_features, hidden_size)
    decoder = Decoder(in_features, hidden_size)
    net = Seq2seq(encoder, decoder, in_features, hidden_size)
    learning_rate = 0.001
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    loss = 0

    for i in range(1000):
        for en_input, de_input, de_output in train_iter:
            output, loss = net(en_input, de_input, de_output)
            pre = torch.argmax(output, 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if i % 100 == 0:
            print("step {0} loss {1}".format(i, loss))
    torch.save(net, "translate.pt")


if __name__ == '__main__':
    before_test = ['man', 'black', 'king', 'girl', 'up', 'high', 'women', 'white', 'boy', 'down', 'low', 'queen',
                   'mman', 'woman']
    [translate(i) for i in before_test]
    # train()

仍然先上代码,接上一篇文章,这里将Seq2Seq模型个构建采用Encoder类和Decoder类融合起来

主要是为了后面的Attention作铺垫

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch+LSTM+Encoder+Decoder实现Seq2Seq模型 的相关文章

  • Lucene 标准分析器与 Snowball

    刚刚开始使用 Lucene Net 我使用标准分析器索引了 100 000 行 运行了一些测试查询 并注意到如果原始术语是单数 则复数查询不会返回结果 我知道雪球分析器增加了词干支持 这听起来不错 不过 我想知道 超过标准的雪球锣是否有任何
  • Google Colab 使用 Transformers 和 PyTorch 微调 BERT Base Case 时出现间歇性“RuntimeError: CUDA out of memory”错误

    我正在运行以下代码来微调 Google Colab 中的 BERT Base Cased 模型 有时代码第一次运行良好 没有错误 其他时候 相同的代码使用相同的数据 会导致 CUDA 内存不足 错误 以前 重新启动运行时或退出笔记本 返回笔
  • Pytorch - 推断线性层 in_features

    我正在构建一个玩具模型来获取一些图像并进行分类 我的模型看起来像 conv2d gt pool gt conv2d gt linear gt linear 我的问题是 当我们创建模型时 我们必须计算第一个线性层的大小in features基
  • 计算机AI算法写句子?

    我正在寻找有关处理文本句子或在创建在正常人类语言 例如英语 中有效的句子时遵循结构的算法的信息 我想知道这个领域是否有我可以学习或开始使用的项目 例如 如果我给一个程序一个名词 为其提供同义词库 相关单词 和词性 以便它理解每个单词在句子中
  • 如何将 35 类城市景观数据集转换为 19 类?

    以下是我的代码的一小段 使用它 我可以在城市景观数据集上训练名为 lolnet 的模型 但数据集包含 35 个类别 标签 0 34 imports trainloader torch utils data DataLoader datase
  • FastText - 由于 C++ 扩展未能分配内存,无法加载 model.bin

    我正在尝试使用 FastText Python APIhttps pypi python org pypi fasttext https pypi python org pypi fasttext虽然 据我所知 此 API 无法加载较新的
  • 对产品列表进行分类的算法? [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有一个代表或多或少相同的产品的列表 例如 在下面的列表中 它们都是希捷硬盘 希捷硬盘 500Go 适用于笔记本电脑的希捷硬盘 120
  • 如何使用 lstm 执行多类多输出分类

    I have multiclass multioutput classification see https scikit learn org stable modules multiclass html https scikit lear
  • NLTK 可用的停用词语言

    我想知道在哪里可以找到 NLTK 停用词支持的语言 及其键 的完整列表 我找到一个列表https pypi org project stop words https pypi org project stop words 但它不包含每个国家
  • R 中带有变音符号的字符列表

    我试图将字符串中的电话 字符 出现次数制成表格 但变音符号单独作为字符制成表格 理想情况下 我有一个国际音标的单词列表 其中包含大量变音符号以及它们与基本字符的几种组合 我在这里给出了仅包含一个单词的 MWE 但对于单词列表和更多类型的组合
  • 使用 NLP 进行句子压缩 [关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 使用机器翻译 我可以获得一个句子的非常压缩的版本 例如 我真的很想喝一杯美味可口的咖啡将被翻译为我想喝咖
  • 如何改进 NLTK 中的荷兰语 NER 词块划分器

    感谢这个伟大的答案 我使用 NLTK 和 Conll2002 语料库训练自己的荷兰语 NE 词块划分器 有了一个良好的开端 NLTK 荷兰语命名实体识别 https stackoverflow com questions 11293149
  • 更换色谱柱时出现稀疏效率警告

    def tdm modify feature names tdm non useful words kill stampede trigger cause death hospital minister said told say inju
  • 在Python中表示语料库句子的一种热门编码

    我是 Python 和 Scikit learn 库的初学者 我目前需要从事一个 NLP 项目 该项目首先需要通过 One Hot Encoding 来表示一个大型语料库 我已经阅读了 Scikit learn 关于 preprocessi
  • PyTorch 中的数据增强

    我对 PyTorch 中执行的数据增强有点困惑 现在 据我所知 当我们执行数据增强时 我们保留原始数据集 然后添加它的其他版本 翻转 裁剪 等 但 PyTorch 中似乎并没有发生这种情况 据我从参考文献中了解到 当我们使用data tra
  • 如何在 scikit-learn 的 SVM 中使用非整数字符串标签? Python

    Scikit learn 具有相当用户友好的用于机器学习的 python 模块 我正在尝试训练用于自然语言处理 NLP 的 SVM 标记器 其中我的标签和输入数据是单词和注释 例如 词性标记 而不是使用双精度 整数数据作为输入元组 1 2
  • 带有填充掩码的 TransformerEncoder

    我正在尝试使用 src key padding mask 不等于 none 来实现 torch nn TransformerEncoder 想象输入的形状src 20 95 二进制填充掩码的形状为src mask 20 95 填充标记的位置
  • Pytorch LSTM:计算交叉熵损失的目标维度

    我一直在尝试在 Pytorch 中使用 LSTM LSTM 后跟自定义模型中的线性层 但在计算损失时出现以下错误 Assertion cur target gt 0 cur target lt n classes failed 我用以下函数
  • 如何使用 PyTorch 沿特定维度进行热编码?

    我有一个大小的张量 3 15 136 where 3 is batch size 15 sequence length and 136 is tokens 我想使用中的概率来单热我的张量tokens维度 136 为此 我想提取序列长度中每个
  • NLTK 2.0分类器批量分类器方法

    当我运行此代码时 它会抛出一个错误 我认为这是由于 NLTK 3 0 中不存在batch classify 方法 我很好奇如何解决旧版本中的某些内容在新版本中消失的此类问题 def accuracy classifier gold resu

随机推荐

  • M1 Mac创建虚拟环境遇到的问题

    报错信息 PackagesNotFoundError The following packages are not available from current channels python 3 7 Current channels ht
  • 数据结构:线性表(栈的实现)

    文章目录 1 栈 Stack 1 1 栈的概念 1 2 栈的结构 链表栈 数组栈 2 栈的定义 3 栈的实现 3 1 初始化栈 StackInit 3 2 入栈 StackPush 3 3 出栈 StackPop 3 4 检测栈是否为空 S
  • Micro SD Card参数基本介绍

    Micro SD Card原名Trans Flash Card或T Flash Card 由SanDisk 闪迪 公司发明 目前主要用于可移动设备储存 比如数码相机 手机 MP4等可移动设备 一 品牌标识 Micro SD Card虽是Sa
  • c++定义变量存放rsa密匙_解惑RSA

    RSA的算法 最好的文章是 somenzz 一文搞懂 RSA 算法 文章直接举例 让人很容易理解 RSA 中的 p q n m e d 的关系 一个关键点就是 n e 组成公钥 n d 组成私钥 公私钥并不是一个独立的大数 而是一个二元组
  • 低代码、中台化,中国ERP迎来产业变局

    对中国的ERP而言 在思维和观念背后 需要的是底层技术产品的重新解构和对生态的重新理解和构建 这对应的不单纯是面向客户层面的需求服务 更是对自身企业模型的优化调整 道阻且长 但值得肯定的是 它正在迈出更好的下一步 作者 斗斗 编辑 皮爷 出
  • docker从安装到部署项目,一篇文章教会你

    1 什么是Docker 首先看下 Docker 图标 一条小鲸鱼上面有些集装箱 比较形象的说明了 Docker 的特点 以后见到这个图标等同见到了 Docker 1 1 容器技术 1 Docker 是一个开源的应用容器引擎 它基于 Go 语
  • Android ViewModel,Lifecycles和LiveData组件讲解

    文章目录 一 ViewModel ViewModel基本用法 向ViewModel传递参数 二 Lifecycles 三 LiveData LiveData的基本用法 map和switchMap JetPack是一个开发组件工具集 他的主要
  • js empty() vs remove()

    转自 jQuery empty vs remove empty will remove all the contents of the selection remove will remove the selection and its c
  • 如何将MySQL中指定的表结构同步到人大金仓数据库

    场景 刚开始做数据库适配的时候 这是一个棘手的问题 因为MySQL的库里 表结构 字段都是最新的 但是金仓的库 全是旧版本的表结构 需要把我们模块的表结构 同步到金仓中 虽然金仓有数据库同步工具 但是直接把所有表都给同步过来 难免会影响到其
  • java使用flatbuffer基础篇

    先放上flatbuffer的github链接flatbuffer 里面可以直接下载针对模板文件生成代码exe程序和所有支持语言的库代码 之前写的一套系统是http的 里面也用到了websocket 但是服务器都是放在国外的 国内的电信运营商
  • 动态规划-国王与金矿

    动态规划 国王与金矿 图文有点长 慢慢看 题目 有一座高度是10级台阶的楼梯 从下往上走 每跨一步只能向上1级或者2级台阶 要求用程序来求出一共有多少种走法 比如 每次走1级台阶 一共走10步 这是其中一种走法 我们可以简写成 1 1 1
  • python+django学习资料在线分享系统vue

    本站是一个B S模式系统 采用vue框架作为开发技术 MYSQL数据库设计开发 充分保证系统的稳定性 系统具有界面清晰 操作简单 功能齐全的特点 使得校园资料分享平台管理工作系统化 规范化 技术栈 后端 python django 前端 v
  • kafka整合lua消费不到数据解决方案

    用lua脚本将前端页面获取到的数据塞给kafka kafka不报错 nginx不报错 lua脚本也没有问题 topic生成了但就是消费不到数据 自己写一个生产者测试过证明消费者也没问题 折腾了很久 最后在kafka配置文件中加了host n
  • springboot中jsp配置tiles

    tiles是jsp的前端框架 像fream标签一样可以把多个页面组合起来 完成后的目录结构 1 pom xml中添加依赖
  • 深度Linux怎样关闭休眠,deepin如何休眠,

    deepin如何休眠 deepin官网 休眠这个功能还是很酷很实用的 对于Linux系统 休眠一般就是把内存中的数据写入硬盘 swap文件 然后关机 在下一次开机的时候将数据重新载入内存 让你快速回到上一次的工作状态 这在你开启了大量的程序
  • NATAPP使用详细教程(免费隧道内网映射)

    NATAPP https natapp cn tunnel lists NATAPP 在开发时可能会有将自己开发的机器上的应用提供到公网上进行访问 但是并不想通过注册域名 搭建服务器 由此可以使用natapp 内网穿透 购买免费隧道 修改隧
  • 动手学强化学习Day1-基本概念

    文章目录 1 1 什么是强化学习 1 2 强化学习的环境 1 3 强化学习的目标 1 4 强化学习的数据 1 5 强化学习的特征 1 1 什么是强化学习 在机器学习领域 有一类重要的任务和人生选择很相似 即序贯决策 sequential 任
  • C++ win32编程 02 常见消息

    02 常见消息 1 打印消息相关信息 1 1 将消息内容转化为字符串 第一步 定义字符串变量 用来保存转化后的消息 wchar t szInfo 300 定义消息内容变量 第二步 用宽字符格式化函数转化消息内容 wsprintf szInf
  • Ethereum开发

    Ethereum开发 1 简介 1 下载源码 使用Git Bath git clone https github com ethereum go ethereum git 或者使用浏览器下载 2 下载安装包 根据您的系统选择下载 2 官方网
  • Pytorch+LSTM+Encoder+Decoder实现Seq2Seq模型

    usr bin env Python3 coding utf 8 version v1 0 Author Meng Li contact 925762221 qq com FILE torch seq2seq py Time 2022 6