如何在 Tensorflow RNN 中构建嵌入层?

2023-12-29

我正在构建一个 RNN LSTM 网络,根据作者的年龄对文本进行分类(二元分类 - 年轻/成人)。

看起来网络没有学习并突然开始过度拟合:

rnn_overfitting
Red: train
Blue: validation

一种可能是数据表示不够好。我只是按频率对独特的单词进行排序并给它们索引。例如。:

unknown -> 0
the     -> 1
a       -> 2
.       -> 3
to      -> 4

所以我试图用词嵌入来代替它。 我看到了几个例子,但我无法在我的代码中实现它。大多数示例如下所示:

embedding = tf.Variable(tf.random_uniform([vocab_size, hidden_size], -1, 1))
inputs = tf.nn.embedding_lookup(embedding, input_data)

这是否意味着我们正在构建一个层learns嵌入?我认为应该下载一些 Word2Vec 或 Glove 并使用它。

无论如何,假设我想构建这个嵌入层......
如果我在代码中使用这两行,我会收到错误:

类型错误:传递给参数“索引”的值的数据类型 float32 不在允许值列表中:int32、int64

所以我想我必须改变input_data键入至int32。所以我这样做了(毕竟都是索引),我得到了这个:

类型错误:输入必须是序列

我尝试包裹inputs(论点tf.contrib.rnn.static_rnn)和一个列表:[inputs]如建议的这个答案 https://stackoverflow.com/a/45217776/900394,但这又产生了另一个错误:

ValueError:输入大小(输入的维度 0)必须可通过 形状推断,但看到值 None。


Update:

我正在拆开张量x在将其传递给之前embedding_lookup。我在嵌入后移动了拆垛。

更新的代码:

MIN_TOKENS = 10
MAX_TOKENS = 30
x = tf.placeholder("int32", [None, MAX_TOKENS, 1])
y = tf.placeholder("float", [None, N_CLASSES]) # 0.0 / 1.0
...
seqlen = tf.placeholder(tf.int32, [None]) #list of each sequence length*
embedding = tf.Variable(tf.random_uniform([VOCAB_SIZE, HIDDEN_SIZE], -1, 1))
inputs = tf.nn.embedding_lookup(embedding, x) #x is the text after converting to indices
inputs = tf.unstack(inputs, MAX_POST_LENGTH, 1)
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, inputs, dtype=tf.float32, sequence_length=seqlen) #---> Produces error

*seqlen:我对序列进行了零填充,因此所有序列都具有相同的列表大小,但由于实际大小不同,我准备了一个描述没有填充的长度的列表。

新错误:

ValueError:层 basic_lstm_cell_1 的输入 0 与 该层:预期 ndim=2,发现 ndim=3。收到的完整形状:[无, 1, 64]

64是每个隐藏层的大小。

很明显,我的尺寸有问题......如何使输入在嵌入后适合网络?


来自tf.nn.static_rnn https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn,我们可以看到inputs参数为:

长度为 T 的输入列表,每个输入都是形状为 [batch_size, input_size] 的张量

所以你的代码应该是这样的:

x = tf.placeholder("int32", [None, MAX_TOKENS])
...
inputs = tf.unstack(inputs, axis=1)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 Tensorflow RNN 中构建嵌入层? 的相关文章

随机推荐

  • 如何生成唯一的订单号?

    我正在寻找一种生成唯一订单 ID 的好方法 你能看出下面的代码有什么问题吗 int customerId 10000000 long ticks DateTime UtcNow Ticks long orderId customerId t
  • 我尝试将 postgresql md5 更改为 scram-sha-256,但出现 FATAL 密码身份验证失败

    我在用着postgresql作为学习的一部分 我尝试更改登录方法以获得更安全的登录方法 例如使用scram sha 256代替md5 我试图改变我的password encryption to scram sha256 in postgre
  • Ansible,字段 args 具有无效值[重复]

    这个问题在这里已经有答案了 我向我的 playbook yml 添加了一个名为 common 的角色 但配置失败并显示以下消息 TASK common Host is present gt cd fatal localhost FAILED
  • HTTP请求头X-Requested-With从哪里来

    众所周知 我们可以使用X Requested Withhttp请求头来判断http请求是否来自Ajax 许多 javascript 框架会自动添加X Requested Withajax请求中的header 比如jQuery Ajax ht
  • Java:如何通过 org.w3c.dom.document 上的 xpath 字符串定位元素

    如何通过给定 org w3c dom document 上的 xpath 字符串快速定位一个或多个元素 似乎没有FindElementsByXpath 方法 例如 html body p div 3 a 我发现当存在大量同名元素时 递归迭代
  • 如何访问白色的消息框?

    我在 WPF 应用程序中有一个简单的消息框 启动方式如下 private void Button Click object sender RoutedEventArgs e MessageBox Show Howdy Howdy 我可以得到
  • 如何检查一个 JavaScript 对象中的值是否存在于另一个 JavaScript 对象中?

    我正在尝试比较 json str1 和 json str2 这里它应该返回 true 因为 json str1 中的所有元素都存在于 json str2 中 现在我正在这样做 json str1 0 a 1 b 2 c json str2
  • jsp 包含中的 HTTP 状态 500 文件未找到错误

    我包含根目录中的文件 它在本地工作正常 但当我托管我的网站时 它给我 错误 HTTP 状态 500 未找到 connection jsp 我的文件在 public html myfolder connection jsp 在共享主机上 我想
  • 将集合绑定到 StackPanel

    我想获取一个对象集合并将其绑定到 StackPanel 所以基本上如果该集合有 4 个元素 那么在堆栈面板内应该生成 4 个按钮 我尝试过这个 但我认为这不是正确的方法 我过去使用 DataTemplated 做过这种类型的想法 如果我错了
  • 尽管集群已启动,Flink localhost 仪表板仍无法工作

    我已经下载了Flink 1 5 0并运行启动集群脚本 集群似乎已成功启动 bin start cluster sh Starting cluster Starting standalonesession daemon on host LAP
  • 为什么 Microsoft.NET.CoreRuntime.1.1.appx 出现依赖性 - UWP APPX

    当我用来创建项目的appx文件 x64发布模式 时 在依赖文件夹下只创建了Microsoft VCLibs x64 14 00 appx文件 但现在 当我尝试创建 appx 时 还在依赖项文件夹下创建了一个附加文件 Microsoft NE
  • 在 C# Blazor 中的分部类中初始化 RenderFragment

    我正在使用第三方包中的组件 该组件接受 RenderFragment 作为参数 并且我想通过索引页的部分类为该 RenderFragment 分配一个值 我意识到当我在中构建 RenderFragment 时code标签 它有效 但一旦你把
  • iPhone 5 (4") 底部工具栏没有响应

    我正在尝试修改一个应用程序以适应新的 iPhone 5 4 屏幕 我添加了新的启动图像 电子邮件受保护 cdn cgi l email protection 之后一切似乎都很好 我的视图的中间部分可以调整大小 但是我注意到 在有底部工具栏的
  • 获取 UIScrollView 内容的可见矩形

    我怎样才能找到屏幕上实际可见的显示视图内容的矩形 CGRect myScrollView bounds 上面的代码在没有缩放时有效 但一旦允许缩放 它就会在 1 以外的缩放比例下中断 为了澄清 我想要一个 CGRect 包含滚动视图内容相对
  • Rust:从标准输入读取和映射行并处理不同的错误类型

    我正在学习 Rust 并尝试用它解决一些基本的算法问题 在许多情况下 我想从标准输入读取行 对每行执行一些转换并返回结果项的向量 我这样做的一种方法是这样的 Fully working Rust code let my values Vec
  • Boost::Spirit 后跟默认值时字符加倍

    我使用 boost spirit 来解析单项式的 一部分 如 x y xy x 2 x 3yz 我想将单项式的变量保存到一个映射中 该映射还存储相应的指数 因此 语法还应该保存 1 的隐式指数 因此 x 存储起来就像写成 x 1 一样 st
  • 单击 R 字符串输出中的 URL

    假设我有 R 的 cat 函数的输出 它是一个 URL 例如 cat https en wikipedia org wiki Statistics Output https en wikipedia org wiki Statistics
  • 使用泛型类型时,“From”的实现如何会发生冲突?

    我正在尝试实现一个错误枚举 它可以包含与我们的特征之一相关的错误 如下所示 trait Storage type Error enum MyError
  • Python/NetworkX:动态计算边权重

    我有一个未加权的创建的图表networkx为此 我想根据边缘出现的计数 频率来计算节点之间的边缘权重 我的图中的一条边可以多次出现 但事先并不知道边出现的频率 目的是根据连接节点之间移动的权重 例如计数 频率 可视化边缘 本质上 我想创建连
  • 如何在 Tensorflow RNN 中构建嵌入层?

    我正在构建一个 RNN LSTM 网络 根据作者的年龄对文本进行分类 二元分类 年轻 成人 看起来网络没有学习并突然开始过度拟合 Red train Blue validation 一种可能是数据表示不够好 我只是按频率对独特的单词进行排序