tensorflow中control_flow_ops.while_loop

2023-05-16

self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
self.h0 = tf.stack([self.h0, self.h0])  ## 相当于 h0和C0

 # generator on initial randomness 
 gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length, dynamic_size=False, infer_shape=True)
 gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length, dynamic_size=False, infer_shape=True)

 def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x): 
     h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple, h_t其实包含h_t和c_t, tm1即是t减去1
     o_t = self.g_output_unit(h_t)  # batch x vocab , logits
     log_prob = tf.log(tf.nn.softmax(o_t)) 
     next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32) ## next_token的shape是[batch]
     x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)  # batch x emb_dim
     gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0), tf.nn.softmax(o_t)), 1))
     gen_x = gen_x.write(i, next_token)  # indices, batch_size
     return i + 1, x_tp1, h_t, gen_o, gen_x  ## x_tp1其实就是p就是Plus

 _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop(
     cond=lambda i, _1, _2, _3, _4: i < self.sequence_length, ## cond的值要么为True或者为False 
     body=_g_recurrence,
     loop_vars=(tf.constant(0, dtype=tf.int32), tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x)  
     )

这段代码出自文章《SeqGAN:Sequence Generative Adversarial Nets with Policy Gradient》源码模块target_lstm.py中,我其实是不太明白control_flow_ops.while_loop的用法,琢磨后为避免忘记特记录在此。

代 码 是 1 、 2 行 : \color{red}{代码是1、2行:} 12:
lstm或gru执行的初始状态

代 码 第 6 行 : \color{red}{代码第6行:} 6
gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length, dynamic_size=False, infer_shape=True)
TensorArray可以看做是具有动态size功能的Tensor数组。通常都是跟while_loop或map_fn结合使用。
我是不是可以理解成一个list,在代码第15行的时候即是将新生成的next_token写入到gen_x中
代 码 第 15 行 : \color{red}{代码第15行:} 15
gen_x = gen_x.write(i, next_token)
指定index位置写入Tensor, 我觉得write就类似与python中list的append方法,将生成的next_token存储到gen_x中
代 码 第 19 行 : \color{red}{代码第19行:} 19
cond=lambda i, _1, _2, _3, _4: i < self.sequence_length
这行代码是while_loop执行的条件,如果 i < self.sequence_length条件满足, 则cond=True, 执行control_flow_ops.while_loop这个循环,再看lambda表达式,其可以有任意多个形参,在这个表达式里有五个,分别是 i, _1, _2, _3, _4, 为什么是五个参数呢?这里暂且不说(问题1)

代 码 第 20 行 : \color{red}{代码第20行:} 20
body=_g_recurrence
循环主体,_g_recurrence函数已经定义(第8行到第16行),这个函数需要传入5个参数,所以在cand这个条件中需要定义5个形参(问题1的答案),如果cand = True, 就一直执行body,需要注意一点的,每次执行_g_recurrence这个body时参数的参数是不同的,是上一步执行的结果作为本次传入的参数

代 码 第 21 行 : \color{red}{代码第21行:} 21
loop_vars=(tf.constant(0, dtype=tf.int32), tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x)
loop_vars是循环起始参数,这五个是实参,对应与cand中五个形参

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow中control_flow_ops.while_loop 的相关文章

随机推荐

  • [svn]status命令

    wangyetao 64 linux u1604 LinuxRoom SVN FILE 个人空间 xx wangyetao 64 linux u1604 LinuxRoom SVN FILE 个人空间 xx svn help status
  • 维护型项目的管理

    最近 xff0c 一直在维护一个项目 项目很大 xff0c 有很多个系统相互配合 xff0c 且使用的语言也不一样 有JAVA写的系统 xff0c 有PHP写的 xff0c 各系统用的数据库也不一样 xff0c 还有一些我说不出来的技术 项
  • cas5.2.6 搭建cas服务端

    1 打包cas服务器端war包 下载cas overlay template 5 2 zip 1 1配置pom xml lt dependencies gt lt dependency gt lt groupId gt org apereo
  • PHP516 用phpize增加扩展PDO_OCI和OCI8

    环境 xff1a centos5 5 PHP5 1 6 oracle10 2 0 5 客户端 1 从oracle官网下载oracle客户端包 oracle instantclient basic 10 2 0 5 1 i386 rpm or
  • npm ERR! enoent This is related to npm not being able to find a file.解决

    一 问题描述 运行sudo npm install color name出现如下错误 xff1a npm ERR path root blog node modules color namenpm ERR code ENOENT npm E
  • ROS中最重要的变量$ROS_PACKAGE_PATH

    昨天刚成功安装了ardrone autonomy 和 tum ardrone xff0c 运行也是通过了 今天又尝试了一下昨天的命令 xff0c 结果发现tum ardrone居然又运行不了了 xff0c 郁闷 xff01 说是没有在环境变
  • 用TIKZ在LaTex中画图

    我之前是用Edraw max画图的 xff0c 但是有一个致命的问题就是在图上写字母的时候与图解释中不一致 xff0c 所以尝试了一下LaTex画图 xff0c 哎呀 xff0c 耗费我一下午的时间呀 首先导入包 xff1a usepack
  • NLP中三种特征抽取器的优与劣

    RNN LSTM GRU xff1a 缺点 xff08 1 xff09 xff1a 无法并行 xff0c 因此速度较慢 xff08 2 xff09 xff1a RNN无法很好地学习到全局的结构信息 xff0c 尤其对于序列结构很长的 CNN
  • python List中元素两两组合

    aa span class token operator 61 span span class token punctuation span span class token string 39 a 39 span span class t
  • JRE not compatible with project .class file compatibility: 1.7

    电脑上刚装了jdk1 7 xff0c 运行一般程序的时候没有出现什么问题 xff0c 由于内存不够用 xff0c 在设置虚拟内存时却出现问题 xff0c 如下 xff1a 还好找到了解决办法 xff0c 错误的原因是JRE库配置与Java
  • BufferedWriter 的 flush() 方法

    package com corpus import java io import java util List import edu stanford nlp ling HasWord import edu stanford nlp lin
  • 正则表达式匹配连续多个空格或tab空格

    Pattern p 61 Pattern compile 34 s 2 t 34 Matcher m 61 p matcher str String strNoBlank 61 m replaceAll 34 34 System out p
  • LaTex中插入花体字母

    特别要注意的是 xff1a 在LaTeX中 xff0c 别把希腊字母和英文的花体字母搞混哦 xff0c 哈哈 举个例子 xff1a 后面显示的 X 不是希腊字母 西 即 也就是说不能通过 Chi 的方式插入这个特殊符号 xff0c 正确的花
  • 气哭了的C++调试,cmake 找不到 eigen

    这才刚刚开头 xff0c 可是就是不知道错误在哪里 xff1f 百度了问题后 xff0c 打开了很多很多相关的解答 xff0c 从昨天上午遇到这个问题 xff0c 历经昨天下午和晚上 xff0c 还是错误 xff0c 终于在今天上午圆满解决
  • 对ORACLE SCN的理解

    1 SCN数值实际来源于系统的timestamp xff0c 这个实际可以证明 select current scn from v database select timestamp to scn sysdate from dual 这两个
  • Ubuntu 下 终端界面转图形界面

    在运行程序的时候 xff0c 错误的使用了快捷键 ctrl 43 alt 43 F10 然后 unbuntu就黑屏了 xff0c 整个界面只剩下左上角有一个白色的字符在闪 xff0c 然后 Ctrl 43 alt 43 F2时跳出终端的登录
  • python错误:TypeError: 'module' object is not callable

    TrainCorpusStructure py 文件中的代码如下 xff1a class TrainCorpusStructure inputs 61 Demo py中的代码如下 xff1a from corpusProcess impor
  • python 除法保留两位小数点

    span class hljs operator a span 61 span class hljs number 1 span b 61 span class hljs number 3 span print span class hlj
  • pytorch中contiguous()

    contiguous xff1a view只能用在contiguous的variable上 如果在view之前用了transpose permute等 xff0c 需要用contiguous 来返回一个contiguous copy 一种可
  • tensorflow中control_flow_ops.while_loop

    self h0 61 tf zeros self batch size self hidden dim self h0 61 tf stack self h0 self h0 相当于 h0和C0 generator on initial r