Tensorflow:ValueError:形状必须为 2 级,但为 3 级

2023-11-27

我是张量流新手,我正在尝试将双向 LSTM 的一些代码从旧版本的张量流更新到最新版本(1.0),但出现此错误:

形状必须为等级 2,但“MatMul_3”(操作:“MatMul”)的等级为 3,输入形状为:[100,?,400]、[400,2]。

错误发生在 pred_mod 上。

    _weights = {
    # Hidden layer weights => 2*n_hidden because of foward + backward cells
        'w_emb' : tf.Variable(0.2 * tf.random_uniform([max_features,FLAGS.embedding_dim], minval=-1.0, maxval=1.0, dtype=tf.float32),name='w_emb',trainable=False),
        'c_emb' : tf.Variable(0.2 * tf.random_uniform([3,FLAGS.embedding_dim],minval=-1.0, maxval=1.0, dtype=tf.float32),name='c_emb',trainable=True),
        't_emb' : tf.Variable(0.2 * tf.random_uniform([tag_voc_size,FLAGS.embedding_dim], minval=-1.0, maxval=1.0, dtype=tf.float32),name='t_emb',trainable=False),
        'hidden_w': tf.Variable(tf.random_normal([FLAGS.embedding_dim, 2*FLAGS.num_hidden])),
        'hidden_c': tf.Variable(tf.random_normal([FLAGS.embedding_dim, 2*FLAGS.num_hidden])),
        'hidden_t': tf.Variable(tf.random_normal([FLAGS.embedding_dim, 2*FLAGS.num_hidden])),
        'out_w': tf.Variable(tf.random_normal([2*FLAGS.num_hidden, FLAGS.num_classes]))}

    _biases = {
         'hidden_b': tf.Variable(tf.random_normal([2*FLAGS.num_hidden])),
         'out_b': tf.Variable(tf.random_normal([FLAGS.num_classes]))}


    #~ input PlaceHolders
    seq_len = tf.placeholder(tf.int64,name="input_lr")
    _W = tf.placeholder(tf.int32,name="input_w")
    _C = tf.placeholder(tf.int32,name="input_c")
    _T = tf.placeholder(tf.int32,name="input_t")
    mask = tf.placeholder("float",name="input_mask")

    # Tensorflow LSTM cell requires 2x n_hidden length (state & cell)
    istate_fw = tf.placeholder("float", shape=[None, 2*FLAGS.num_hidden])
    istate_bw = tf.placeholder("float", shape=[None, 2*FLAGS.num_hidden])
    _Y = tf.placeholder("float", [None, FLAGS.num_classes])

    #~ transfortm into Embeddings
    emb_x = tf.nn.embedding_lookup(_weights['w_emb'],_W)
    emb_c = tf.nn.embedding_lookup(_weights['c_emb'],_C)
    emb_t = tf.nn.embedding_lookup(_weights['t_emb'],_T)

    _X = tf.matmul(emb_x, _weights['hidden_w']) + tf.matmul(emb_c, _weights['hidden_c']) + tf.matmul(emb_t, _weights['hidden_t']) + _biases['hidden_b']

    inputs = tf.split(_X, FLAGS.max_sent_length, axis=0, num=None, name='split')

    lstmcell = tf.contrib.rnn.BasicLSTMCell(FLAGS.num_hidden, forget_bias=1.0, 
    state_is_tuple=False)

    bilstm = tf.contrib.rnn.static_bidirectional_rnn(lstmcell, lstmcell, inputs, 
    sequence_length=seq_len, initial_state_fw=istate_fw, initial_state_bw=istate_bw)


    pred_mod = [tf.matmul(item, _weights['out_w']) + _biases['out_b'] for item in bilstm]

任何帮助表示赞赏。


对于将来遇到此问题的任何人,上面的代码片段不应该使用。

From tf.contrib.rnn.static_bidirectional_rnnv1.1 文档:

Returns:

A tuple (outputs, output_state_fw, output_state_bw)其中:outputs 是长度为 T 的输出列表(每个输入一个),它们是深度连接的前向和后向输出。 output_state_fw 是前向 rnn 的最终状态。 output_state_bw 是后向 rnn 的最终状态。

上面的列表理解需要 LSTM 输出,获取这些输出的正确方法是:

outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(lstmcell, lstmcell, ...)
pred_mod = [tf.matmul(item, _weights['out_w']) + _biases['out_b'] 
            for item in outputs]

这会起作用,因为每个item in outputs有形状[batch_size, 2 * num_hidden]并可以乘以权重tf.matmul().


附加组件:从tensorflow v1.2+开始,推荐使用的函数位于另一个包中:tf.nn.static_bidirectional_rnn。返回的张量是相同的,因此代码没有太大变化:

outputs, _, _ = tf.nn.static_bidirectional_rnn(lstmcell, lstmcell, ...)
pred_mod = [tf.matmul(item, _weights['out_w']) + _biases['out_b'] 
            for item in outputs]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow:ValueError:形状必须为 2 级,但为 3 级 的相关文章

  • Python PAM 模块的安全问题?

    我有兴趣编写一个 PAM 模块 该模块将利用流行的 Unix 登录身份验证机制 我过去的大部分编程经验都是使用 Python 进行的 并且我正在交互的系统已经有一个 Python API 我用谷歌搜索发现pam python http pa
  • DreamPie 不适用于 Python 3.2

    我最喜欢的 Python shell 是DreamPie http dreampie sourceforge net 我想将它与 Python 3 2 一起使用 我使用了 添加解释器 DreamPie 应用程序并添加了 Python 3 2
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • 如何等到 Excel 计算公式后再继续 win32com

    我有一个 win32com Python 脚本 它将多个 Excel 文件合并到电子表格中并将其另存为 PDF 现在的工作原理是输出几乎都是 NAME 因为文件是在计算 Excel 文件内容之前输出的 这可能需要一分钟 如何强制工作簿计算值
  • keras加载模型错误尝试将包含17层的权重文件加载到0层的模型中

    我目前正在使用 keras 开发 vgg16 模型 我用我的一些图层微调 vgg 模型 拟合我的模型 训练 后 我保存我的模型model save name h5 可以毫无问题地保存 但是 当我尝试使用以下命令重新加载模型时load mod
  • 交换keras中的张量轴

    我想将图像批次的张量轴从 batch size row col ch 交换为 批次大小 通道 行 列 在 numpy 中 这可以通过以下方式完成 X batch np moveaxis X batch 3 1 我该如何在 Keras 中做到
  • Python 中的二进制缓冲区

    在Python中你可以使用StringIO https docs python org library struct html用于字符数据的类似文件的缓冲区 内存映射文件 https docs python org library mmap
  • 循环中断打破tqdm

    下面的简单代码使用tqdm https github com tqdm tqdm在循环迭代时显示进度条 import tqdm for f in tqdm tqdm range 100000000 if f gt 100000000 4 b
  • Python - 按月对日期进行分组

    这是一个简单的问题 起初我认为很简单而忽略了它 一个小时过去了 我不太确定 所以 我有一个Python列表datetime对象 我想用图表来表示它们 x 值是年份和月份 y 值是此列表中本月发生的日期对象的数量 也许一个例子可以更好地证明这
  • 从 pygame 获取 numpy 数组

    我想通过 python 访问我的网络摄像头 不幸的是 由于网络摄像头的原因 openCV 无法工作 Pygame camera 使用以下代码就像魅力一样 from pygame import camera display camera in
  • 设置 torch.gather(...) 调用的结果

    我有一个形状为 n x m 的 2D pytorch 张量 我想使用索引列表来索引第二个维度 可以使用 torch gather 完成 然后然后还设置新值到索引的结果 Example data torch tensor 0 1 2 3 4
  • VSCode:调试配置中的 Python 路径无效

    对 Python 和 VSCode 以及 stackoverflow 非常陌生 直到最近 我已经使用了大约 3 个月 一切都很好 当尝试在调试器中运行任何基本的 Python 程序时 弹出窗口The Python path in your
  • 在 Pandas DataFrame Python 中添加新列[重复]

    这个问题在这里已经有答案了 例如 我在 Pandas 中有数据框 Col1 Col2 A 1 B 2 C 3 现在 如果我想再添加一个名为 Col3 的列 并且该值基于 Col2 式中 如果Col2 gt 1 则Col3为0 否则为1 所以
  • 从 Python 中的类元信息对 __init__ 函数进行类型提示

    我想做的是复制什么SQLAlchemy确实 以其DeclarativeMeta班级 有了这段代码 from sqlalchemy import Column Integer String from sqlalchemy ext declar
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo
  • 您可以在 Python 类型注释中指定方差吗?

    你能发现下面代码中的错误吗 米皮不能 from typing import Dict Any def add items d Dict str Any gt None d foo 5 d Dict str str add items d f
  • 协方差矩阵的对角元素不是 1 pandas/numpy

    我有以下数据框 A B 0 1 5 1 2 6 2 3 7 3 4 8 我想计算协方差 a df iloc 0 values b df iloc 1 values 使用 numpy 作为 cov numpy cov a b I get ar
  • Spark.read 在 Databricks 中给出 KrbException

    我正在尝试从 databricks 笔记本连接到 SQL 数据库 以下是我的代码 jdbcDF spark read format com microsoft sqlserver jdbc spark option url jdbc sql
  • Python - 字典和列表相交

    给定以下数据结构 找出这两种数据结构共有的交集键的最有效方法是什么 dict1 2A 3A 4B list1 2A 4B Expected output 2A 4B 如果这也能产生更快的输出 我可以将列表 不是 dict1 组织到任何其他数
  • PyAudio ErrNo 输入溢出 -9981

    我遇到了与用户相同的错误 Python 使用 Pyaudio 以 16000Hz 录制音频时出错 https stackoverflow com questions 12994981 python error audio recording

随机推荐

  • 打印包含“word”的行 python

    我只想打印以下输出中包含 Server 的行 Date Sun 16 Dec 2012 20 07 44 GMT Expires 1 Cache Control private max age 0 Content Type text htm
  • Laravel 表单不会 PATCH,只会 POST - 嵌套 RESTfull 控制器、MethodNotAllowedHttpException

    我正在尝试允许users编辑他们的playlist 但是 每当我尝试执行 PATCH 请求时 我都会得到MethodNotAllowedHttpException错误 它正在等待一个帖子 我已经设置了 RESTful 资源控制器 路线 ph
  • 如何在 bash 的 CURL 请求中使用变量?

    Goal 我正在使用 bash CURL 脚本连接到 Cloudflare APIv4 目标是更新 A 记录 我的脚本 Get current public IP current ip curl silent ipecho net plai
  • 如何在android中动态提供地图api密钥

    我的 Android 应用程序中有一个用例 我的应用程序的用户必须提供 API 密钥 以便他们可以使用地图 但我发现我必须在清单文件中提供 API 密钥 我无法在运行时编辑它 有没有其他方法可以动态地将地图 API 密钥提供给谷歌地图 我正
  • iframe 中 url 的基本身份验证

    我需要验证从页面上的 iframe 通过 javascript 创建 发送的请求 身份验证将通过基本的 http 身份验证完成 我试过做 http user password server 但显然由于安全异常 这在 IE 中不可用 http
  • 如何在 IIS 上设置反向代理,以允许 host1.mydomain.com 和 host2.mydomain.com 之间进行跨主机通信?

    我在 host1 mydomain com page from host1 jsp 上有一个页面 在 host2 mydomain com page from host2 html 上有一个 HTML 页面 host1 是 IIS7 Tom
  • 在 Android 4.4 中启用 TLS 1.2

    我使用 Retrofit 和 OkHttp3 来发出请求 我知道在 Android 4 4 中 默认情况下未启用 TLS 1 1 和 TLS 1 2 所以我正在尝试启用它们 但到目前为止我还没有成功 我读到这可能是 android stud
  • 如何移动google地图的中心位置

    我正在使用以下代码在脚本中创建谷歌地图 var mapElement parent mapOptions map marker latLong openMarker parent document getElementsByClassNam
  • Gitlab 端口 8080

    我目前正在尝试在我的私人 Debian 服务器上安装 Gitlab 综合总线 它在端口 80 上运行得很好 问题是我还有一个 Apache 服务器在监听端口 80 所以我正在尝试让 Nginx监听端口 8080 但由于某种原因我得到了 50
  • 为什么多态性在没有指针/引用的情况下不起作用?

    我确实在 StackOverflow 上发现了一些具有类似标题的问题 但是当我阅读答案时 他们关注的是问题的不同部分 这些部分非常具体 例如 STL 容器 有人可以告诉我 为什么必须使用指针 引用来实现多态性吗 我可以理解指针可能会有所帮助
  • 检测用户所在国家/地区的最快方法

    我需要检测用户的国家 地区并按他 她的国家 地区显示网站的语言 土耳其人用土耳其语 其他人用英语 我怎样才能以最快的方式做到这一点 表现对我来说很重要 我在看IPInfoDB 的 API 还有更好的选择吗 我使用的是PHP 对于可能在 20
  • 消息 8114,级别 16,状态 5,第 1 行将数据类型 varchar 转换为数字时出错

    Select CAST de ornum AS numeric 1 as ornum2 from Cpaym as de left outer join Cpaym as de1 on CAST de ornum AS numeric de
  • 毕加索实际上是如何缓存图像的

    我想知道毕加索图书馆到底是如何缓存应用程序内的图像的 我知道它使用 HttpHeaders 来检查天气以从网络获取图像 但是 它缓存图像有时间范围吗 比如一天后使缓存无效之类的 问题是我的项目正在从网络加载大量小图像 有时 新图像会反映在下
  • 预测精度:没有以两个向量作为参数的 MASE

    我正在使用accuracy函数从forecast包 计算精度测量 我使用它来计算拟合时间序列模型的度量 例如 ARIMA 或指数平滑 当我在不同维度和聚合级别上测试不同模型类型时 我使用 Hyndman 等人引入的 MASE 平均绝对比例误
  • ggplot2 的图像文件压缩选项

    是否可以使用压缩图形的文件大小ggsave 我尝试过使用compression lzw 参数 但文件大小保持不变 使用 R studio 98 501 OS X Yosemite My code ggsave Figure1 tiff wi
  • Selenium Phantomjs 浏览器在启动时挂起。我该如何调试它?

    我正在尝试帮助在其他人的设置上运行我的 selenium Python 绑定版本 2 测试 它可以与 Firefox esr 两台机器上 配合使用 也可以与我的机器上最新的 phantomjs 配合使用 它挂在他的机器上 唯一明显的区别是他
  • 如何根据用户输入动态构建并返回 linq 谓词

    在这件事上有点卡住了 基本上我有一个方法 我想返回一个谓词表达式 我可以将其用作Where 条件 我认为我需要做的与此类似 http msdn microsoft com en us library bb882637 aspx但我对我需要做
  • 如何加速嵌套循环?

    我正在 python 中执行一个嵌套循环 如下所示 这是搜索现有金融时间序列并在时间序列中寻找符合某些特征的周期的基本方法 在这种情况下 有两个独立的 大小相等的数组 分别代表 收盘价 即资产的价格 和 交易量 即一段时间内交换的资产数量
  • 如何通过 SendKeys 发送特殊字符?

    我正在尝试在 Selenium2 中填写表格 One input has an autocomplete that I want to close preferably by sending esc after the search ter
  • Tensorflow:ValueError:形状必须为 2 级,但为 3 级

    我是张量流新手 我正在尝试将双向 LSTM 的一些代码从旧版本的张量流更新到最新版本 1 0 但出现此错误 形状必须为等级 2 但 MatMul 3 操作 MatMul 的等级为 3 输入形状为 100 400 400 2 错误发生在 pr