Tensorflow 无效参数:断言失败 [标签 ID 必须 < n_classes]

2024-02-26

我在使用 Python 2.7 在 Tensorflow 1.3.0 中实现 DNNClassifier 时遇到错误。我从 Tensorflow 获取了示例代码tf.estimator Quickstart教程,我想用我自己的数据集运行它:3D 坐标和 10 个不同的类(int 标签)。这是我的实现:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

def ReadLabels(file):
    #load the labels from test file here
    labelFile = open(file, "r")
    Label = labelFile.readlines();
    returnL = [[Label[i][j+1] for j in range(len(Label[0])-3)] for i in range(len(Label))]
    returnLint = list();
    for i in range(len(returnL)):
        tmp = ''
        for j in range(len(returnL[0])):
            tmp += str(returnL[i][j])
        returnLint.append(int(tmp))
    return returnL, returnLint

def NumpyReadBin(file,numcols,type):
    #load the data from binary file here
    import numpy as np
    trainData = np.fromfile(file,dtype=type)
    numrows = len(trainData)/numcols
    #print trainData[0:100]
    result = [[trainData[i+j*numcols] for i in range(numcols)] for j in range(numrows)]
    return result

def TensorflowDNN():
    #load sample dataset
    trainData = NumpyReadBin('data/TrainingData.dat',3,'float32')
    valData = NumpyReadBin('data/ValidationData.dat',3,'float32')
    testData = NumpyReadBin('data/TestingData.dat',3,'float32')
    #load sample labels
    trainL, trainLint = ReadLabels('data/TrainingLabels.txt')
    validateL, validateLint = ReadLabels('data/ValidationLabels.txt')
    testL, testLint = ReadLabels('data/TestingLabels.txt')

    import tensorflow as tf
    import numpy as np

    #get unique labels
    uniqueTrain = set()
    for l in trainLint:
        uniqueTrain.add(l)
    uniqueTrain = list(uniqueTrain)
    numClasses = len(uniqueTrain)
    numDims = len(trainData[0])

    #All features have real-value data
    feature_columns = [tf.feature_column.numeric_column("x", shape=[3])]

    # Build 3 layer DNN with 10, 20, 10 units respectively.
    classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                              hidden_units=[10, 20, 10],
                                              n_classes=numClasses,
                                              model_dir="../Classification/tmp")

    # Define training inputs
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
                                                x={"x": np.array(trainData)},y=np.array(trainLint),
                                                num_epochs = None, shuffle = True)

    #Train the model
    classifier.train(input_fn = train_input_fn, steps = 2000)

    #Define Validation inputs
    val_input_fn = tf.estimator.inputs.numpy_input_fn(
                                                x={"x": np.array(valData)},y=np.array(validateLint),
                                                num_epochs = 1, shuffle = False)

    # Evaluate accuracy.
    accuracy_score = classifier.evaluate(input_fn=val_input_fn)["accuracy"]
    print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

if __name__ == '__main__':
    TensorflowDNN()

功能RedLabels(...) and NumpyReadBin(...)正在将我保存的数据集加载到张量中。由于标签是我从文本文件中读取的整数,因此该函数有点奇怪,但最终我得到的是一个包含来自这些标签的整数的数组:[11, 12, 21, 22, 23, 31, 32 、33、41、42]。

然而我无法对任何东西进行分类,因为在调用时classifier.train(input_fn = train_input_fn, steps = 2000),我收到以下错误:

...Traceback and stuff like that...
InvalidArgumentError (see above for traceback): assertion failed: [Label IDs must < n_classes] [Condition x < y did not hold element-wise:x (dnn/head/labels:0) = ] [[21][32][42]...] [y (dnn/head/assert_range/Const:0) = ] [10]
[[Node: dnn/head/assert_range/assert_less/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_INT64], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](dnn/head/assert_range/assert_less/Assert/AssertGuard/Assert/Switch/_117, dnn/head/assert_range/assert_less/Assert/AssertGuard/Assert/data_0, dnn/head/assert_range/assert_less/Assert/AssertGuard/Assert/data_1, dnn/head/assert_range/assert_less/Assert/AssertGuard/Assert/Switch_1/_119, dnn/head/assert_range/assert_less/Assert/AssertGuard/Assert/data_3, dnn/head/assert_range/assert_less/Assert/AssertGuard/Assert/Switch_2/_121)]]

有人以前遇到过这个错误或者知道如何解决它吗?我猜它在某种程度上抱怨我的数据集中的类数/标签格式,但我知道 trainLint 包含 10 个不同的类标签,这就是numClasses。难道是我的格式trainLint array?


所以解决方案为伊桑特·姆里纳尔 https://stackoverflow.com/users/1896918/ishant-mrinal指出:

Tensorflow 期望从 0 到类数的整数作为类标签 (range(0, num_classes)),而不是像我这样的“任意”数字。谢谢!:)

...我刚刚遇到的另一个选择是添加label_vocabulary到分类器定义:

classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                          hidden_units=[10, 20, 10],
                                          n_classes=numClasses,
                                          model_dir=saveAt,
                                          label_vocabulary=uniqueTrain)

使用此选项,我可以像以前一样定义标签,并将其转换为字符串。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 无效参数:断言失败 [标签 ID 必须 < n_classes] 的相关文章

  • 使用“iloc”时出现“尝试在 DataFrame 切片的副本上设置值”错误

    Jupyter 笔记本返回此警告 C anaconda lib site packages pandas core indexing py 337 SettingWithCopyWarning A value is trying to be
  • 在 Python 中搜索文本文件并打印相关行?

    如何在文本文件中搜索关键短语或关键字 然后打印关键短语或关键字所在的行 searchfile open file txt r for line in searchfile if searchphrase in line print line
  • 如何将字典转换为字符串

    我正在尝试使用提供的解决方案here https stackoverflow com questions 5192753 how to get the number of occurrences of each character usin
  • 如何向未知用户目录读取/写入文件?

    我正在尝试从用户目录 C Users USERNAME Test Source 读取和写入文件 但我未能成功找到任何有关如何自动检测用户名的资源 其中的 USERNAME上面的例子 或者无论如何 我可以让它读取和写入目录 而不需要知道用户名
  • pandas 系列值之间的过滤

    If s is a pandas Series http pandas pydata org pandas docs stable dsintro html series 我知道我可以这样做 b s lt 4 or b s gt 0 但我做
  • Python 中的自然日/相对日

    我想要一种在 Python 中显示日期项目的自然时间的方法 类似于 Twitter 将显示 刚才 几分钟前 两小时前 三天前 等消息 Django 1 0 在 django contrib 中有一个 人性化 方法 我没有使用 Django
  • Visual Studio Code 调试控制台中的 pydevd 警告

    我已经搜索了一段时间但找不到任何相关问题 当使用 Visual Studio Code 和 Python 扩展来调试大型元素时 计算表示或获取属性可能需要一些时间 在这些情况下 会出现如下警告 pydevd 警告 计算 DataFrame
  • Python 中字典的合并层次结构

    我有两本词典 而我想做的事情有点奇怪 基本上 我想合并它们 这很简单 但它们是字典的层次结构 我想以这样的方式合并它们 如果字典中的项目本身就是字典并且存在于两者中 我也想合并这些字典 如果它不是字典 我希望第二个字典中的值覆盖第一个字典中
  • Python 2.7从非默认目录打开多个文件(对于opencv)

    我在 64 位 win7 上使用 python 2 7 并拥有 opencv 2 4 x 当我写 cv2 imread pic 时 它会在我的默认 python 路径中打开 pic 即C Users Myname 但是我如何设法浏览不同的目
  • 使用 Python 将 Json 转换为换行 Json 标准

    我有一个获取嵌套对象并删除所有嵌套的代码 使对象平坦 def flatten json y param y Unflated Json return Flated Json out def flatten x name if type x
  • 使用字符串迭代 url - python

    我现在完全被我的代码困住了 首先 我尝试从 volkskrant 的存档页面检索所有网址 这是我被打击的第一步 某一特定日期的 url 如下所示 http www volkskrant nl archief detail 01012016
  • 如何从张量流数据集迭代器返回同一批次两次?

    我正在转换一些旧代码以使用数据集 API 此代码使用feed dict将一批数据送入列车运行 实际上是三次 然后重新计算损失以供显示使用同一批 所以我需要一个迭代器来返回完全相同的批次两次 或多次 不幸的是 我似乎找不到一种使用张量流数据集
  • Python - 根据条件调用函数

    我想知道是否有一种简洁的方法来根据条件调用函数 我有这个 if list 1 some dataframe df myfunction 我想知道这是否有可能三元运算符 http book pythontips com en latest t
  • Python 中的否定

    如果路径不存在 我尝试创建一个目录 但是 不 运算符不起作用 我不知道如何在 Python 中进行否定 正确的方法是什么 if os path exists usr share sounds blues proc subprocess Po
  • 检查一个数是否是完全平方数

    如何检查一个数是否是完全平方数 速度并不重要 目前 只是工作 See also Integer square root in python https stackoverflow com questions 15390807 依赖任何浮点计
  • Python 3d 金字塔

    我是 3D 绘图新手 我只想用 5 个点建造一个金字塔并通过它切出一个平面 我的问题是我不知道如何填充两侧 points np array 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 fig plt figure ax fi
  • 捕获 subprocess.run() 的输入

    我在 Windows 上有一个交互式命令行 exe 文件 是由其他人编写的 当程序出现异常时 它会终止 并且我对程序的所有输入都会丢失 所以我正在编写一个 python 程序 它调用一个阻塞子进程subprocess run 并捕获所有输入
  • python nltk从句子中提取关键字

    我们要做的第一件事 就是杀掉所有律师 威廉 莎士比亚 鉴于上面的引用 我想退出 kill and lawyers 作为两个突出的关键词来描述句子的整体含义 我提取了以下名词 动词 POS 标签 First NNP thing NN do V
  • Python DNS服务器IP地址查询

    我正在尝试使用 python 获取 DNS 服务器 IP 地址 要在 Windows 命令提示符下执行此操作 我将使用 ipconfig 全部 如下所示 我想使用 python 脚本做同样的事情 有什么方法可以提取这些值吗 我成功提取了设备
  • 如何将 pygame Surface 转换为 PIL 图像?

    我正在使用 PIL 来透视地变换屏幕的一部分 原始图像数据是一个 pygame Surface 需要转换为 PIL 图像 因此我发现了 pygame 的 tostring 函数就是为了这个目的而存在的 然而结果看起来很奇怪 见附图 这段代码

随机推荐