深度学习之基于CNN实现汉字版手写数字识别(Chinese-Mnist)

2023-11-01

Mnist数据集是深度学习入门的数据集,昨天发现了Chinese-Mnist数据集,与Mnist数据集类似,只不过是汉字数字,例如‘一’、‘二’、‘三’等,本次实验利用自己搭建的CNN网络实现Chinese版的手写数字识别。

1.导入库

import tensorflow as tf
import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keras

warnings.filterwarnings("ignore")#忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

2.数据加载

原数据中包括15000张图片,如下所示:
在这里插入图片描述
原数据并没有将各类数据分开,而是给出了一个csv文件:
在这里插入图片描述
在进行训练之前将图片分类,首先对数据的标签进行切片

train = pd.read_csv("E:/tmp/.keras/datasets/chinese_mnist/chinese_mnist.csv")
#训练数据的标签
train_image_label = [i for i in train["character"]]
#将标签切片
train_label_ds = tf.data.Dataset.from_tensor_slices(train_image_label)

统计每张图片的具体路径:

#训练数据的具体路径
img_dir = "E:/tmp/.keras/datasets/chinese_mnist/data/data/input"
train_image_paths = []
for row in train.itertuples():
    suite_id = row[1]
    sample_id = row[2]
    code = row[3]
    train_image_paths.append(img_dir+"_"+str(suite_id)+"_"+str(sample_id)+"_"+str(code)+".jpg")
#对图片路径进行切片
train_path_ds = tf.data.Dataset.from_tensor_slices(train_image_paths)

train_image_paths结果如下:

E:/tmp/.keras/datasets/chinese_mnist/data/data/input_1_1_10.jpg

读取图片并进行预处理,然后切片

#图片预处理
def preprocess_image(image):
    image = tf.image.decode_jpeg(image,channels = 3)
    image = tf.image.resize(image,[height,width])
    return image / 255.0
def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)
#根据路径读取图片并进行预处理
train_image_ds = train_path_ds.map(load_and_preprocess_image,num_parallel_calls=tf.data.experimental.AUTOTUNE)

将train_image_ds与train_label_ds组合在一起

image_label_ds = tf.data.Dataset.zip((train_image_ds,train_label_ds))

显示图片:

for i in range(20):
    plt.subplot(4, 5, i + 1)
    num +=1
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)

    # 显示图片
    images = plt.imread(train_image_paths[i])
    plt.imshow(images)
    # 显示标签
    plt.xlabel(train_image_label[i])

plt.show()

在并未对数据进行shuffle之前,如下所示:
在这里插入图片描述
原数据中一共15000张图片,分为15类,每类1000张,并按照顺序排列,因此需要对数据进行打乱。

image_label_ds = image_label_ds.shuffle(15000)

按照8:2的比例划分训练集与测试集

train_ds = image_label_ds.take(12000).shuffle(2000)
test_ds = image_label_ds.skip(12000).shuffle(3000)

超参数的设置

height = 64
width = 64
batch_size = 128
epochs = 50

对训练集与测试集进行batch_size 划分

train_ds = train_ds.batch(batch_size)#设置batch_size
train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_ds = test_ds.batch(batch_size)
test_ds = test_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

再次检查图片,看看是否被打乱顺序:

plt.figure(figsize=(8, 8))

for images, labels in train_ds.take(1):
    # print(images.shape)
    for i in range(12):
        ax = plt.subplot(4, 3, i + 1)
        plt.imshow(images[i])
        plt.title(labels[i].numpy())  # 使用.numpy()将张量转换为 NumPy 数组

        plt.axis("off")
    break
plt.show()

在这里插入图片描述
顺序已被打乱,初始目标完成。

3.网络搭建&&编译

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=32,kernel_size=(3,3),padding="same",activation="relu",input_shape=[64, 64, 3]),
    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(15, activation="softmax")
])

model.compile(optimizer="adam",
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
model.summary()
history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs = epochs
)

经过50次epochs,训练结果如下:
在这里插入图片描述
准确率达到了100%

4.混淆矩阵的绘制

模型加载:

model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/chinese_mnist/model.h5")

标签列表如下所示:

all_label_names = ['零','一','二','三','四','五','六','七','八','九','十','百','千','万','亿']

绘制混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd
    # 绘制混淆矩阵
all_label_names = ['零','一','二','三','四','五','六','七','八','九','十','百','千','万','亿']
def plot_cm(labels, pre):
    conf_numpy = confusion_matrix(labels, pre)  # 根据实际值和预测值绘制混淆矩阵
    conf_df = pd.DataFrame(conf_numpy, index=all_label_names,
                               columns=all_label_names)  # 将data和all_label_names制成DataFrame
    plt.figure(figsize=(8, 7))

    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")  # 将data绘制为混淆矩阵
    plt.title('混淆矩阵', fontsize=15)
    plt.ylabel('真实值', fontsize=14)
    plt.xlabel('预测值', fontsize=14)
    plt.show()

model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/chinese_mnist/model.h5")

test_pre = []
test_label = []
for images, labels in test_ds:
    for image, label in zip(images, labels):
        img_array = tf.expand_dims(image, 0)  # 增加一个维度
        pre = model.predict(img_array)  # 预测结果
        test_pre.append(all_label_names[np.argmax(pre)])  # 将预测结果传入列表
        test_label.append(all_label_names[label.numpy()])  # 将真实结果传入列表
plot_cm(test_label, test_pre)  # 绘制混淆矩阵#

在这里插入图片描述
总结:本次实验最复杂的就是标签处理那一块,只有处理好这一步骤,才能正确的将图片和标签划分到一起。实验数据只有15000张,而Mnist数据集有70000张,虽然本次的模型准确率达到了100%,但是仍有可能在别的图片预测错误。

努力加油a啊

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习之基于CNN实现汉字版手写数字识别(Chinese-Mnist) 的相关文章

  • 如何保存 Tensorflow.js 模型?

    我想制作一个创建 保存和训练 tensorflow js 模型的用户界面 但我无法在创建模型后保存模型 我什至从tensorflow js文档复制了这段代码 但它不起作用 const model tf sequential layers t
  • TensorFlow CUDA_ERROR_OUT_OF_MEMORY

    我正在尝试在 TensorFlow 中构建一个大型 CNN 并打算在多 GPU 系统上运行它 我采用了 塔式 系统 并为两个 GPU 拆分批次 同时将变量和其他计算保留在 CPU 上 我的系统有 32GB 内存 但是当我运行代码时出现错误
  • Keras,如何获取每一层的输出?

    我已经用 CNN 训练了一个二元分类模型 这是我的代码 model Sequential model add Convolution2D nb filters kernel size 0 kernel size 1 border mode
  • 如何正确将 tflite_graph.pb 转换为 detector.tflite

    我正在使用tensorflow对象检测API使用tensorflow中的ssdlite mobilenet v2 coco 2018 05 09来训练自定义模型模型动物园 https github com tensorflow models
  • 从图中删除节点或重置整个默认图

    使用默认全局图时 是否可以在添加节点后将其删除 或者将默认图重置为空 当我在 IPython 中交互地使用 TF 时 我发现自己必须反复重新启动内核 如果可能的话 我希望能够更轻松地尝试图表 更新 11 2 2016 tf reset de
  • 如何安装libcusolver.so.11

    我正在尝试安装 Tensorflow 但它要求 libcusolver so 11 而我只有 libcusolver so 10 有人可以告诉我我做错了什么吗 这是我的 Ubuntu nvidia 和 CUDA 版本 uname a Lin
  • 跨多个 GPU/机器的 TF-Slim 的配置/标志

    我很好奇是否有关于如何使用部署 model deploy py 在多台机器上的多个 GPU 上运行 TF Slim models slim 的示例 该文档非常好 但我缺少一些内容 具体来说 需要为worker device和ps devic
  • 如何使用 Keras 中的 Conv2D 在 5D 张量的最后三个维度上应用卷积?

    通常的输入张量Conv2DKeras 中是一个 4D 张量 其维度为batch size n n channel size 现在我有一个 5D 张量 其尺寸为batch size N n n channel size我想对中的每个 i 应用
  • Keras:加载多个模型并在不同线程中进行预测

    我正在使用带有张量流核心的 Keras 我想在构造函数中加载 2 个不同的模型 然后在不同的线程中进行预测 根据请求 我尝试在张量流图上下文中加载这些模型 但它不起作用 我的代码 from keras models import load
  • 如何在nodejs(tensorflow.js)中训练模型?

    我想做一个图像分类器 但我不会python Tensorflow js 使用我熟悉的 javascript 可以用它来训练模型吗 训练步骤是什么 坦白说 我不知道从哪里开始 我唯一想到的是如何加载 mobilenet 它显然是一组预先训练的
  • 支持 Nvidia CUDA 工具包 9.2

    Tensorflow gpu 绑定到 Nvidia CUDA Toolkit 的特定版本的原因是什么 当前版本似乎专门寻找 9 0 并且不适用于任何更高版本 例如 我安装了最新的 Toolkit 9 2 并将其添加到路径中 但 Tensor
  • Tensorboard 和 Dropout 层

    我有一个非常基本的查询 我制作了 4 个几乎相同 差异在于输入形状 的 CNN 并在连接到全连接层的前馈网络时合并了它们 几乎相同的 CNN 的代码 model3 Sequential model3 add Convolution2D 32
  • Tensorflow中的Tensor和Variable有什么区别

    有什么区别Tensor and Variable在张量流中 我注意到在这个 stackoverflow 答案 https stackoverflow com questions 38556078 in tensorflow what is
  • Tensorflow如何生成不平衡组合数据集

    我对新数据集 API tensorflow 1 4 有疑问 我有两个数据集 我需要创建一个组合的不平衡数据集 即 每个批次应包含第一个数据集中一定数量的元素和第二个数据集中一定数量的元素 例如 dataset1 tf data Datase
  • 如何将两个 keras 模型连接成一个模型?

    假设我有一个 ResNet50 模型 我希望将该模型的输出层连接到 VGG 模型的输入层 这是 ResNet 模型和 ResNet50 的输出张量 img shape 164 164 3 resnet50 model ResNet50 in
  • 安装后 Anaconda 提示损坏

    我刚刚安装张量流GPU创建单独的后环境按照以下指示here https github com antoniosehk keras tensorflow windows installation 但是 安装后当我关闭提示窗口并打开新航站楼弹出
  • 使用 Tkinter 显示 numpy 数组中的图像

    我对 Python 缺乏经验 第一次使用 Tkinter 制作一个 UI 显示我的数字分类程序与 mnist 数据集的结果 当图像来自 numpy 数组而不是我的 PC 上的文件路径时 我有一个关于在 Tkinter 中显示图像的问题 我为
  • 如何解释tf.map_fn的结果?

    看代码 import tensorflow as tf import numpy as np elems tf ones 1 2 3 dtype tf int64 alternates tf map fn lambda x x x x el
  • 从 swift 数组创建张量

    这工作正常 import TensorFlow var t Tensor
  • 我无法使用 scikeras.wrappers.KerasRegressor 执行 cross_val_score

    from tensorflow import keras from sklearn model selection import cross val score from sklearn datasets import make regre

随机推荐

  • PyQt专题结题感言:Python图形用户界面开发

    PyQt专题结题感言 Python图形用户界面开发 在这篇文章中 我将为您介绍PyQt框架 这是一个功能强大的Python图形用户界面 GUI 开发工具 我将详细解释PyQt的基本概念和用法 并提供一些源代码示例来帮助您入门 PyQt是一个
  • 将两个有序数组合并为一个新的有序数组(Java实现)

    不可否认的是Java确实C语言方便许多 这种写法比我在C中的那种写法要好 代码如下 public class Test public static void main String args int arr1 1 3 5 7 9 int a
  • linux 删除的文件太多:bash:/usr/bin/rm: Argument list too long

    背景 删除一个文件夹内指定后缀的文件时 遇到错误 提示 bash usr bin rm Argument list too long 很明显是指定后缀的文件太多 导致无法删除 解决方案 通过命令find来进行删除 比如要删除所有的json文
  • 基于python+flask实现视频数据可视化

    使用爬虫对视频弹幕进行爬取并保存为csv文件 导入数据库中 进而实现前后端交互功能 数据集中包含的数据分别为爬取的热门视频的标题 播放量 弹幕量 收藏量 综合得分以及视频的类别等信息 便于后续我们进行数据分析 我们使用数据库中的数据评出综合
  • 任务分配的穷举法、匈牙利法、分支定界法

    1 必做 任务分配问题 设有 4 项任务 B1 B2 B3 B4 派 4 个人 A1 A2 A3 A4 去完成 每个人都可以承担 4 项任务中的任何一项 但所消耗的资金不同 设 Ai 完成 Bj 所需资金为 问如何分配任务 使总费用最少 假
  • 无缝漫游的过程!

    无缝漫游中无线AP的配置与普通无线AP的配置基本相同 只是应当注意以下几个方面的问题 所有无线AP必须使用同一SSID 所有无线AP必须使用同一网段的IP地址 并且处于同一VLAN中 信号相互覆盖的无线AP不能使用相同的频道 由于多个AP信
  • 图形分析之Nsight的使用

    作者 i dovelemon 日期 2017 06 11 来源 CSDN 主题 Nsight OpenGL 引言 最开始的时候 我进行图形编程使用的是DX 所以那时候进行图形分析的时候 基本都是使用PIX 后来转向了OpenGL 分析的时候
  • JVM-17(垃圾回收器)上

    目录 17 1 GC分类与性能指标 17 1 1 JVM的发展 17 1 2评估GC的性能指标 17 2 不同的垃圾回收器概述 17 3 Serial回收器 串行回收 17 4 ParNew回收器 并行回收 17 5 Parallel回收器
  • html超链接打开共享文件夹,教你如何访问共享文件夹

    现在我们往往要讲究 资源共享 就是有好的东西跟大家一起分享 那么到电脑上呢经常有一些文件夹 有的是加密的 有的是共享的 今天呢小编就要给大家讲讲如何访问这些共享文件夹 要想查看共享文件夹其实也是有步骤可言的 首先 要先打开控制面板 有一个W
  • 电脑开机就显示360服务器,我用360给电脑杀毒,一直到开机启动项会停止,显示“扫面服务意外终止,无法继续扫描,这可能是由于程序...

    希望我的回答可以帮助楼主解决问题哦 这个问题很明显是杀毒软件自身的问题 不太知道诺顿这款杀毒软件 是不是在升级过程中发生什么问题造成的 楼主可以尝试换用腾讯电脑管家 这款杀毒软件在病毒以及木马的查杀方面很权威 很成熟 下面是我总结的电脑容易
  • ZGC收集器介绍

    ZGC收集器 XX UseZGC ZGC是一款JDK 11中新加入的具有实验性质的低延迟垃圾收集器 ZGC可以说源自于是Azul System公司开发的C4 Concurrent Continuously Compacting Collec
  • OK6410矩阵键盘驱动问题解决方案

    在嵌入式系统开发中 矩阵键盘是一种常见的输入设备 OK6410是一款广泛使用的ARM开发板 本文将介绍如何在OK6410开发板上实现矩阵键盘的驱动 硬件连接 首先 我们需要将矩阵键盘与OK6410开发板进行连接 矩阵键盘通常由多个行和列组成
  • ResNet到底在解决一个什么问题呢?

    点击上方 小白学视觉 选择加 星标 或 置顶 重磅干货 第一时间送达 来源 知乎 https www zhihu com question 64494691 文仅交流 侵删 ResNet发布于2015年 目前仍有大量CV任务用其作为back
  • C# 代码转化为Java代码

    http www tangiblesoftwaresolutions com Free Editions html Install Instant C converts VB NET code to C Install Instant VB
  • 史上最全midjourney关键词

    最全midjourney关键词 篇幅太长 文章最后有可编辑版本获取链接 增强图片真实感 清晰度 unreal engine 虚幻引擎 ultra realistic 超真实 photography 摄影图片 detailed 细节 4K 4
  • LaTeX 使用笔记——公式篇

    目录 一 行内公式 二 独立公式 一 行内公式 二 独立公式 一 括号 1 当括号的两边分别位于上下两行公式 且可能出现两个括号大小不一致的情况 例如 使用LaTeX代码 begin aligned dot V k v 1 z k v 1
  • 一次性搞清楚unicode、codepoint、代码点、UTF

    最近在处理字符过滤 重新研究了下字符 unicode和代码点的相关知识 首先要说一下编码的基本知识unicode unicode unicode是计算机科学领域里的一项业界标准 包括字符集 编码方案等 计算机采用八比特一个字节 一个字节最大
  • Python 爬虫获取某贴吧所有成员用户名

    最近想用Python爬虫搞搞百度贴吧的操作 所以我得把原来申请的小号找出来用 有一个小号我忘了具体ID 只记得其中几个字母以及某个加入的贴吧 所以今天就用爬虫来获取C语言贴吧的所有成员 计划很简单 爬百度贴吧的会员页面 把结果存到MySQL
  • FreeMarker模板使用方法讲解

    项目需要 刚接触 正在学习 FreeMarker简介 FreeMarker模板文件主要由如下4个部分组成 1 文本 直接输出的部分 2 注释 lt gt 格式部分 不会输出 3 插值 即 或 格式的部分 将使用数据模型中的部分替代输出 4
  • 深度学习之基于CNN实现汉字版手写数字识别(Chinese-Mnist)

    Mnist数据集是深度学习入门的数据集 昨天发现了Chinese Mnist数据集 与Mnist数据集类似 只不过是汉字数字 例如 一 二 三 等 本次实验利用自己搭建的CNN网络实现Chinese版的手写数字识别 1 导入库 import