目标数组形状与使用 Tensorflow 的预期输出不同

2024-03-18

我正在尝试制作 CNN(仍然是初学者)。当尝试拟合模型时,我收到此错误:

ValueError:形状为 (10000, 10) 的目标数组被传递用于形状 (None, 6, 6, 10) 的输出,同时用作损失categorical_crossentropy。这种损失期望目标具有与输出相同的形状。

标签的形状 = (10000, 10) 图像数据的形状 = (10000, 32, 32, 3)

Code:

import pickle
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Dense, Dropout, Activation, Flatten, 
                                     Conv2D, MaxPooling2D)
from tensorflow.keras.callbacks import TensorBoard
from keras.utils import to_categorical
import numpy as np
import time

MODEL_NAME = f"_________{int(time.time())}"
BATCH_SIZE = 64

class ConvolutionalNetwork():
    '''
    A convolutional neural network to be used to classify images
    from the CIFAR-10 dataset.
    '''

    def __init__(self):
        '''
        self.training_images -- a 10000x3072 numpy array of uint8s. Each 
                                a row of the array stores a 32x32 colour image. 
                                The first 1024 entries contain the red channel 
                                values, the next 1024 the green, and the final 
                                1024 the blue. The image is stored in row-major 
                                order, so that the first 32 entries of the array are the red channel values of the first row of the image.
        self.training_labels -- a list of 10000 numbers in the range 0-9. 
                                The number at index I indicates the label 
                                of the ith image in the array data.
        '''
        # List of image categories
        self.label_names = (self.unpickle("cifar-10-batches-py/batches.meta",
                            encoding='utf-8')['label_names'])

        self.training_data = self.unpickle("cifar-10-batches-py/data_batch_1")
        self.training_images = self.training_data[b'data']
        self.training_labels = self.training_data[b'labels']

        # Reshaping the images + scaling 
        self.shape_images()  

        # Converts labels to one-hot
        self.training_labels = np.array(to_categorical(self.training_labels))

        self.create_model()

        self.tensorboard = TensorBoard(log_dir=f'logs/{MODEL_NAME}')

    def unpickle(self, file, encoding='bytes'):
        '''
        Unpickles the dataset files.
        '''
        with open(file, 'rb') as fo:
            training_dict = pickle.load(fo, encoding=encoding)
        return training_dict

    def shape_images(self):
        '''
        Reshapes the images and scales by 255.
        '''
        images = list()
        for d in self.training_images:
            image = np.zeros((32,32,3), dtype=np.uint8)
            image[...,0] = np.reshape(d[:1024], (32,32)) # Red channel
            image[...,1] = np.reshape(d[1024:2048], (32,32)) # Green channel
            image[...,2] = np.reshape(d[2048:], (32,32)) # Blue channel
            images.append(image)

        for i in range(len(images)):
            images[i] = images[i]/255

        images = np.array(images)
        self.training_images = images
        print(self.training_images.shape)

    def create_model(self):
        '''
        Creating the ConvNet model.
        '''
        self.model = Sequential()
        self.model.add(Conv2D(64, (3, 3), input_shape=self.training_images.shape[1:]))
        self.model.add(Activation("relu"))
        self.model.add(MaxPooling2D(pool_size=(2,2)))

        self.model.add(Conv2D(64, (3,3)))
        self.model.add(Activation("relu"))
        self.model.add(MaxPooling2D(pool_size=(2,2)))

        # self.model.add(Flatten())
        # self.model.add(Dense(64))
        # self.model.add(Activation('relu'))

        self.model.add(Dense(10))
        self.model.add(Activation(activation='softmax'))

        self.model.compile(loss="categorical_crossentropy", optimizer="adam", 
                           metrics=['accuracy'])

    def train(self):
        '''
        Fits the model.
        '''
        print(self.training_images.shape)
        print(self.training_labels.shape)
        self.model.fit(self.training_images, self.training_labels, batch_size=BATCH_SIZE, 
                       validation_split=0.1, epochs=5, callbacks=[self.tensorboard])


network = ConvolutionalNetwork()
network.train()

感谢您的帮助,已经尝试修复一个小时了。


您需要取消注释Flatten创建模型时的图层。本质上,该层的作用是接受 4D 输入(batch_size, height, width, num_filters)并将其展开为 2D 的(batch_size, height * width * num_filters)。这是获得您想要的输出形状所必需的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

目标数组形状与使用 Tensorflow 的预期输出不同 的相关文章

  • 没有名为 crypto.cipher 的模块

    我现在正在尝试加密一段时间 我最近得到了这个基于 python 的密码器 名为PythonCrypter https github com jbertman PythonCrypter 我对 Python 相当陌生 当我尝试通过终端打开 C
  • 通过 Scrapy 抓取 Google Analytics

    我一直在尝试使用 Scrapy 从 Google Analytics 获取一些数据 尽管我是一个完全的 Python 新手 但我已经取得了一些进展 我现在可以通过 Scrapy 登录 Google Analytics 但我需要发出 AJAX
  • Python 的键盘中断不会中止 Rust 函数 (PyO3)

    我有一个使用 PyO3 用 Rust 编写的 Python 库 它涉及一些昂贵的计算 单个函数调用最多需要 10 分钟 从 Python 调用时如何中止执行 Ctrl C 好像只有执行结束后才会处理 所以本质上没什么用 最小可重现示例 Ca
  • Django 管理员在模型编辑时间歇性返回 404

    我们使用 Django Admin 来维护导出到我们的一些站点的一些数据 有时 当单击标准更改列表视图来获取模型编辑表单而不是路由到正确的页面时 我们会得到 Django 404 页面 模板 它是偶尔发生的 我们可以通过重新加载三次来重现它
  • 为 Anaconda Python 安装 psycopg2

    我有 Anaconda Python 3 4 但是每当我运行旧代码时 我都会通过输入 source activate python2 切换到 Anaconda Python 2 7 我的问题是我为 Anaconda Python 3 4 安
  • 使用 on_bad_lines 将 pandas.read_csv 中的无效行写入文件

    我有一个 CSV 文件 我正在使用 Python 来解析该文件 我发现文件中的某些行具有不同的列数 001 Snow Jon 19801201 002 Crom Jake 19920103 003 Wise Frank 19880303 l
  • OpenCV 无法从 MacBook Pro iSight 捕获

    几天后 我无法再从 opencv 应用程序内部打开我的 iSight 相机 cap cv2 VideoCapture 0 返回 并且cap isOpened 回报true 然而 cap grab 刚刚返回false 有任何想法吗 示例代码
  • 如何使用 OpencV 从 Firebase 读取图像?

    有没有使用 OpenCV 从 Firebase 读取图像的想法 或者我必须先下载图片 然后从本地文件夹执行 cv imread 功能 有什么办法我可以使用cv imread link of picture from firebase 您可以
  • 如何使用Python创建历史时间线

    So I ve seen a few answers on here that helped a bit but my dataset is larger than the ones that have been answered prev
  • 无法在 Python 3 中导入 cProfile

    我试图将 cProfile 模块导入 Python 3 3 0 但出现以下错误 Traceback most recent call last File
  • 如何在seaborn displot中使用hist_kws

    我想在同一图中用不同的颜色绘制直方图和 kde 线 我想为直方图设置绿色 为 kde 线设置蓝色 我设法弄清楚使用 line kws 来更改 kde 线条颜色 但 hist kws 不适用于显示 我尝试过使用 histplot 但我无法为
  • 对年龄列进行分组/分类

    我有一个数据框说df有一个柱子 Ages gt gt gt df Age 0 22 1 38 2 26 3 35 4 35 5 1 6 54 我想对这个年龄段进行分组并创建一个像这样的新专栏 If age gt 0 age lt 2 the
  • 类型错误:预期单个张量时的张量列表 - 将 const 与 tf.random_normal 一起使用时

    我有以下 TensorFlow 代码 tf constant tf random normal time step batch size 1 1 我正进入 状态TypeError List of Tensors when single Te
  • 有人用过 Dabo 做过中型项目吗? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我们正处于一个新的 ERP 风格的客户端 服务器应用程序的开始阶段 该应用程序是作为 Python 富客户端开发的 我们目前正在评估 Dabo
  • Python:如何将列表列表的元素转换为无向图?

    我有一个程序 可以检索 PubMed 出版物列表 并希望构建一个共同作者图 这意味着对于每篇文章 我想将每个作者 如果尚未存在 添加为顶点 并添加无向边 或增加每个合著者之间的权重 我设法编写了第一个程序 该程序检索每个出版物的作者列表 并
  • 如何计算 pandas 数据帧上的连续有序值

    我试图从给定的数据帧中获取连续 0 值的最大计数 其中包含来自 pandas 数据帧的 id date value 列 如下所示 id date value 354 2019 03 01 0 354 2019 03 02 0 354 201
  • 发送用户注册密码,django-allauth

    我在 django 应用程序上使用 django alluth 进行身份验证 注册 我需要创建一个自定义注册表单 其中只有一个字段 电子邮件 密码将在服务器上生成 这是我创建的表格 from django import forms from
  • Rocket UniData/UniVerse:ODBC 无法分配足够的内存

    每当我尝试使用pyodbc连接到 Rocket UniData UniVerse 数据时我不断遇到错误 pyodbc Error 00000 00000 Rocket U2 U2ODBC 0302810 Unable to allocate
  • 导入错误:没有名为 site 的模块 - mac

    我已经有这个问题几个月了 每次我想获取一个新的 python 包并使用它时 我都会在终端中收到此错误 ImportError No module named site 我不知道为什么会出现这个错误 实际上 我无法使用任何新软件包 因为每次我
  • 如何将输入读取为数字?

    这个问题的答案是社区努力 help privileges edit community wiki 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 Why are x and y下面的代码中使用字符串而不是整数 注意 在Python 2

随机推荐

  • 如何在 po gettext 文件中将空翻译 (msgstr) 标记为已翻译?

    我发现字符串 msgid 的翻译为空 所有 gettext 工具都会将该字符串视为未翻译 有解决方法吗 我确实想要一个空字符串作为该项目的翻译 由于这似乎是 gettext 规范中的一个很大的设计缺陷 我决定使用 Unicode Chara
  • Spark Streaming数据放入HBase的问题

    我是这个领域的初学者 所以我无法理解它 HBase 版本 0 98 24 hadoop2 火花版本 2 1 0 以下代码尝试将从 Spark Streming Kafka 生产者接收的数据放入 HBase 中 Kafka输入数据格式是这样的
  • 点“.”的 java keyevent 字段是什么?

    我知道如何使用 keyevent 调用 1 应该像 aaa keyPress KeyEvent VK 1 现在我需要输入 点 但我找不到 KeyEvent VK DOT 或一些类似的命令 请帮忙 Thanks 这个 点 被称为period
  • 如何使用带有条纹元素的引导浮动标签?

    我想知道如何使用浮动标签设置条纹元素的样式 bootstrap 5 我的所有其他字段都采用这种方式设计 因此最好对信用卡输入和 cvv 输入进行设计 以匹配我网站的主题 我尝试过使用以下答案 如何使用 Bootstrap 设置 Stripe
  • 从本地开发环境访问ElastiCache memcache实例

    有没有办法从本地开发环境访问缓存节点 尽管可以从 EC2 实例访问相同的缓存节点 我正在使用带有 C 的 Enyim memcache 客户端库 我发现很少有文章说这是不可能的 那么最好的方法应该是什么 我是否需要在本地设置内存缓存以进行开
  • 最流行的 C 通用集合数据结构库是什么?

    我正在寻找一个提供通用集合数据结构 例如列表 关联数组 集合等 的 C 库 该库应该稳定且经过良好测试 我基本上是在寻找比蹩脚的 C 标准库更好的东西 哪些 C 库符合此描述 编辑 我希望该库是跨平台的 但如果做不到这一点 任何可以在 Ma
  • 将数据存储在自定义字段中或将附件存储在 ics iCal 文件中

    我需要为我手动构建的 iCal 文件 ics 提供一些我实际上不希望日历应用程序用户看到的附加信息 因此 当我在 iOS 应用程序中创建事件并 稍后 从日历事件中读取它们时 我需要能够手动设置它们 我想知道是否可以将自定义字段 属性添加到
  • 使用 dplyr 进行 SQL in-db 操作时的 ifelse 和 grepl 命令

    在R数据帧上运行的dplyr中 很容易运行 df lt df gt mutate income topcoded ifelse income gt topcode income topcode 我现在正在使用一个大型 SQL 数据库 使用
  • SharePoint Designer 动态重新格式化 HTML,是否可以禁用?

    在我彻底放弃之前 我一直在尝试修改 SharePoint Designer 中的一些母版页 每当我更改 HTML 标记时 它都会根据需要重新设置它们的格式 例如 我试图使代码可读 因此我将项目移动到自己的行等 一旦我保存 它就会将所有内容移
  • 将数据从 s3 复制到带有前缀的本地

    我正在尝试使用 aws cli 将数据从 s3 复制到带有前缀的本地 但我在使用不同的正则表达式时遇到错误 aws s3 cp s3 my bucket name RAW TIMESTAMP 0506 profile prod error
  • DirectQuery 模式下的 AAS 表格模型性能优势

    假设您有 10 个相当大的事实表 每个 50 100 GB 应该使用 Power BI 进行查询 它们不适合 Azure Analysis Services RAM 价格合理 因此 为了使用表格模型和 AAS 您必须使用以下模式 1 Pow
  • 如何在 Playframework 中将 Oracle 存储过程与 Scala Anorm 结合使用

    我有许多存储过程 其结果是字符串列表 我如何使用scala访问play 2 0框架中的refcurser 有人可以举一个简单的例子 我如何填写一个列表吗 我试过这个 case class XXXX name String descripti
  • 为什么 UIView 中有一个框架矩形和一个边界矩形?

    好吧 虽然已经是深夜了 但我不明白为什么有两个不同的矩形 frame and bounds 据我了解 一个矩形就足以完成所有操作 相对于另一个坐标系定位视图本身 然后将其内容剪切到指定的大小 你还想用两个矩形做什么 他们如何相互作用 有人有
  • 通过循环在renderUI中创建Value Box

    我想根据我拥有的数据创建一个值框 假设我有 5 个数据变量consumerdata像这样 id data number1 number2 1 k4j A 67 53 2 rls B 30 62 3 yv9 C 45 28 4 l6h D 6
  • 如何在 Eclipse 中使用 SonarLint

    我被分配使用 SonarQube 来提高代码质量 但是当我将它的插件下载到 Eclipse 时 我知道它已被弃用 新的 插件是 SonarLint 但到目前为止 我找不到任何关于如何使用 SonarLint 的好的文档 如何使用它检查jav
  • Delphi 2010远程调试-无法使断点工作

    我最近发布了这个问题 https stackoverflow com questions 4579654 no breakpoints when remote debugging with delphi 2010 so stuck on d
  • 如何从C中的文件中读取最后n行

    这是一道微软面试题 使用 C 读取文件的最后 n 行 精确地 实现这一目标的方法有很多 但其中很少有 gt 最简单的是 在第一遍中 计算文件中的行数 在第二遍中显示最后 n 行 gt 或者可以为每一行维护一个双向链表 并通过向后遍历链表直到
  • 查询 Firestore 文档中的参考字段

    我正在尝试编写一个函数 在文档 Firestore 艺术家 集合中 中的数据发生更改后 Google Cloud Functions 将查找另一个集合 显示 中具有引用字段的所有文档 artist 指向刚刚更改的文档 在 artists 集
  • XMLHttpRequest 已弃用。用什么代替?

    尝试使用纯 JS 方法来检查我是否有有效的 JS 图像 url 我收到警告XMLHttpRequest已弃用 有什么更好的方法来做到这一点 urlExists url const http new XMLHttpRequest http o
  • 目标数组形状与使用 Tensorflow 的预期输出不同

    我正在尝试制作 CNN 仍然是初学者 当尝试拟合模型时 我收到此错误 ValueError 形状为 10000 10 的目标数组被传递用于形状 None 6 6 10 的输出 同时用作损失categorical crossentropy 这