TF多层 LSTM 以及 State 之间的融合

2023-05-16

第一是实现多层的LSTM的网络;
第二是实现两个LSTM的state的concat操作, 分析 state 的结构.

对于第一个问题,之前一直没有注意过, 看下面两个例子:

在这里插入代码片
import tensorflow as tf

num_units = [20, 20]

#Unit1, OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = [tf.nn.rnn_cell.BasicLSTMCell(num_units=units) for units in num_units]
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

#Unit2, OK
# X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# X = tf.reshape(X, [-1, 5, 6])
# multi_rnn = []
# for i in range(2):
#     multi_rnn.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units[i]))
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

# Unit3 *********ERROR***********
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
X = tf.reshape(X, [-1, 5, 6])
# single_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=20) # same as below
lstm_cells = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(num_units=20)] * 2)
output, state = tf.nn.dynamic_rnn(lstm_cells, X, time_major=True, dtype=tf.float32)

print(output)
print(state)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for var in tf.global_variables():
        print(var.op.name)
    output_run, state_run = sess.run([output, state])
   
 之前还真没注意到这个问题, 虽然一般都是多层的维度一致,但是都是写成 Unit2 这种形式.

第二个问题两个 Encoder 的 State 的融合, 并保持 State 类型 (LSTM/GRU)

import tensorflow as tf

def concate_rnn_states(num_layers, encoder_state_local, encoder_state_global):
    '''
    :param num_layers:
    :param encoder_fw_state:
    For LSTM:
    (LSTMStateTuple(c=<tf.Tensor 'encoder1/rnn/while/Exit_3:0' shape=(3, 20) dtype=float32>,
        h=<tf.Tensor 'encoder1/rnn/while/Exit_4:0' shape=(3, 20) dtype=float32>),
    LSTMStateTuple(c=<tf.Tensor 'encoder1/rnn/while/Exit_5:0' shape=(3, 20) dtype=float32>,
        h=<tf.Tensor 'encoder1/rnn/while/Exit_6:0' shape=(3, 20) dtype=float32>))
    For GRU:
    (<tf.Tensor 'encoder1/rnn/while/Exit_3:0' shape=(3, 20) dtype=float32>,
        <tf.Tensor 'encoder1/rnn/while/Exit_4:0' shape=(3, 20) dtype=float32>)
    :param encoder_bw_state: same as fw
    :return: tuple
    '''
    encoder_states = []
    for i in range(num_layers):
        if isinstance(encoder_state_local[i], tf.nn.rnn_cell.LSTMStateTuple):
            # for lstm
            encoder_state_c = tf.concat(values=(encoder_state_local[i].c, encoder_state_global[i].c), axis=1,
                                        name="concat_layer{}_state_c".format(i))
            encoder_state_h = tf.concat(values=(encoder_state_local[i].h, encoder_state_global[i].h), axis=1,
                                        name="concat_layer{}_state_h".format(i))
            encoder_state = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
        elif isinstance(encoder_state_local[i], tf.Tensor):
            # for gru and rnn
            encoder_state = tf.concat(values=(encoder_state_local[i], encoder_state_global[i]), axis=1,
                                      name='GruOrRnn_concat')

        encoder_states.append(encoder_state)
    return tuple(encoder_states)

num_units = [20, 20]

#Unit1
X = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
X = tf.reshape(X, [-1, 5, 6])

with tf.variable_scope("encoder1") as scope:
    local_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    local_lstm_cells = tf.contrib.rnn.MultiRNNCell(local_multi_rnn)
    encoder_output_local, encoder_state_local = tf.nn.dynamic_rnn(local_lstm_cells, X, time_major=False, dtype=tf.float32)

with tf.variable_scope("encoder2") as scope:
    global_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    global_lstm_cells = tf.contrib.rnn.MultiRNNCell(global_multi_rnn)
    encoder_output_global, encoder_state_global = tf.nn.dynamic_rnn(global_lstm_cells, X, time_major=False, dtype=tf.float32)

print("concat output")
encoder_outputs = tf.concat([encoder_output_local, encoder_output_global], axis=-1)
print(encoder_output_local)
print(encoder_outputs)

print("concat state")
print(encoder_state_local)
print(encoder_state_global)
encoder_states = concate_rnn_states(2, encoder_state_local, encoder_state_global)
print(encoder_states)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TF多层 LSTM 以及 State 之间的融合 的相关文章

随机推荐

  • 6个常用的React组件库

    Ant Design 项目链接 xff1a Ant Design 包大小 xff08 来自 BundlePhobia xff09 xff1a 缩小后 1 2mB xff0c 缩小 43 gzip 压缩后 349 2kB xff0c 通过摇树
  • 大数据培训课程数据清洗案例实操-简单解析版

    数据清洗 xff08 ETL xff09 在运行核心业务MapReduce程序之前 xff0c 往往要先对数据进行清洗 xff0c 清理掉不符合用户要求的数据 清理的过程往往只需要运行Mapper程序 xff0c 不需要运行Reduce程序
  • 宋红康2023版Java视频发布

    1500万 43 播放量见证经典 xff0c 尚硅谷宋红康老师的Java入门视频堪称神作 xff0c 如今经典再次超级进化 xff0c 新版Java视频教程震撼来袭 xff01 开发环境全新升级 xff1a JDK17 43 IDEA202
  • Java消息队列:消息在什么时候会变成Dead Letter?

    在较为重要的业务队列中 xff0c 确保未被正确消费的消息不被丢弃 xff0c 通过配置死信队列 xff0c 可以让未正确处理的消息暂存到另一个队列中 xff0c 待后续排查清楚问题后 xff0c 编写相应的处理代码来处理死信消息 一 什么
  • Vue2和Vue3数据双向绑定原理的区别及优缺点(下篇)

    上篇我们讲到了Vue2的数据双向绑定原理 xff0c 如果你没有阅读上篇 xff0c 建议先阅读一下上篇中的内容 Vue2和Vue3数据双向绑定原理的区别及优缺点 xff08 上篇 xff09 在上篇中我们抛出了一个问题 xff1a 是不是
  • FlinkTable时间属性

    像窗口 xff08 在 Table API 和 SQL xff09 这种基于时间的操作 xff0c 需要有时间信息 因此 xff0c Table API 中的表就需要提供逻辑时间属性来表示时间 xff0c 以及支持时间相关的操作 一 处理时
  • kafka学习(1)

    目录 kafka是什么 xff1f 为什么要用kafka kafka的特点 kafka结构 Kafka Producer的Ack机制 kafka是什么 xff1f 收集nginx日志 xff0c 将nginx日志的关键字段进行分析 xff0
  • spss 因子分析

    是通过研究变量间的相关系数矩阵 xff0c 把这些变量间错综复杂的关系归结成少数几个综合因子 xff0c 并据此对变量进行分类的一种统计方法 xff0c 归结出的因子个数少于原始变量的个数 xff0c 但是他们又包含原始变量的信息 xff0
  • Hive 报错 Invalid column reference 列名

    两张表 当我执行 select m movieid m moviename substr m moviename 5 4 as years avg r rate as avgScore FROM t movie as m join t ra
  • 20数学建模C-中小微企业的信贷决策

    前言 源码文末获取 小编在 9 月份参加了今年的数学建模 xff0c 成绩怎么样不知道 xff0c 能有个成功参与奖就不错了哈哈 最近整理了一下 xff0c 写下这篇文章分享小编的思路 能力知识水平有限 xff0c 欢迎各位大佬前来指教 o
  • playwright 爬虫使用

    官方文档 xff1a Getting started Playwright Python 参考链接 xff1a 强大易用 xff01 新一代爬虫利器 Playwright 的介绍 目录 安装 基本使用 代码生成 AJAX 动态加载数据获取
  • kmeans聚类选择最优K值python实现

    来源 xff1a https www omegaxyz com 2018 09 03 k means find k 下面利用python中sklearn模块进行数据聚类的K值选择 数据集自制数据集 xff0c 格式如下 xff1a 维度为3
  • mysql构造页损坏

    构造页损坏 及修复方式可参考 gg gMysql页面crash问题复现 amp 恢复方法 阿里云开发者社区 也可通过 dd 命令进行构造 dd xff0c 命令参考 xff1a Linux dd 命令 菜鸟教程
  • mysql审计日志过滤sql功能

    审计日志功能是一个插件 xff0c 需要先安装插件才可以使用 过滤 sql 语句 xff0c 可以通过插件内核参数 audit log include commands 与 audit log exclude commands 参数设置 x
  • setDaemon python守护进程,队列通信子线程

    使用setDaemon 和守护线程这方面知识有关 xff0c 比如在启动线程前设置thread setDaemon True xff0c 就是设置该线程为守护线程 xff0c 表示该线程是不重要的 进程退出时不需要等待这个线程执行完成 这样
  • 中文与 \u5927\u732a\u8e44\u5b50 这一类编码互转

    了解更多关注微信公众号 木下学Python 吧 a 61 39 大猪蹄子 39 a 61 a encode 39 unicode escape 39 print a 运行结果 xff1a b 39 u5927 u732a u8e44 u5b
  • python字典删除键值对

    https blog csdn net uuihoo article details 79496440
  • 计算机网络(4)传输层

    目录 小知识点 xff1a 三次握手 xff1a 状态 xff1a tcpdump xff1a 一 xff1a 命令介绍 xff1a 二 xff1a 命令选项 xff1a tcpdump的表达式 xff1a 使用python扫描LAN工具
  • MSE 治理中心重磅升级-流量治理、数据库治理、同 AZ 优先

    作者 xff1a 流士 本次 MSE 治理中心在限流降级 数据库治理及同 AZ 优先方面进行了重磅升级 xff0c 对微服务治理的弹性 依赖中间件的稳定性及流量调度的性能进行全面增强 xff0c 致力于打造云原生时代的微服务治理平台 前情回
  • TF多层 LSTM 以及 State 之间的融合

    第一是实现多层的LSTM的网络 第二是实现两个LSTM的state的concat操作 分析 state 的结构 对于第一个问题 之前一直没有注意过 看下面两个例子 在这里插入代码片 import tensorflow as tf num u