基于卷积神经网络的手写数字识别(自建模型)

2023-11-03

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录


前言

卷积神经网络是一种多层的监督学习神经网络,隐含层的卷积层和池采样层是实现卷积神经网络特征提取功能的核心模块。该网络模型通过采用梯度下降法最小化损失函数对网络中的权重参数逐层反向调节,通过频繁的迭代训练提高网络的精度。


提示:以下是本篇文章正文内容,下面案例可供参考

一、任务要求

数字识别是计算机从纸质文档、照片或其他来源接收、理解并识别可读的数字的能力,目前比较受关注的是手写数字识别。手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别、手写邮政编码识别,大大缩短了业务处理时间,提升了工作效率和质量。

在处理如 1 所示的手写邮政编码的简单图像分类任务时,可以使用基于MNIST数据集的手写数字识别模型。MNIST是深度学习领域标准、易用的成熟数据集,包含60000条训练样本和10000条测试样本。

  • 任务输入:一系列手写数字图片,其中每张图片都是28x28的像素矩阵。
  • 任务输出:经过了大小归一化和居中处理,输出对应的0~9数字标签。

二、模型搭建

1.卷积层

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

参数:

  • in_channels:输入通道
  • out_channels:输出通道
  • kernel_size:卷积核大小
  • stride:步长
  • padding:填充

 激活层

激活层使用ReLU激活函数。
线性整流函数(Rectified Linear Unit, ReLU),又称修正线性单元,是一种人工神经网络中常用的激活函数(activation function),通常指代以斜坡函数及其变种为代表的非线性函数。

https://img-blog.csdnimg.cn/e3891b78542d420b8aa59b40f170585c.png

torch.nn.ReLU()

 池化层

池化层采用最大池化。
池化(pooling)后,C(C.hannels)不变,W(width)H(Height)变。

https://img-blog.csdnimg.cn/6f188dd455bf40428442cdd7c44cf2ae.png

torch.nn.MaxPool2d(input, kernel_size, stride, padding)

参数:

  • input:输入
  • kernel_size:卷积核大小
  • stride:步长
  • padding:填充

全连接层

之前卷积层要求输入输出是四维张量(B,C,W,H),而全连接层的输入与输出都是二维张量(B,Input_feature),经过卷积、激活、池化后,使用view打平,进入全连接层。

 LeNet5模型

比如输入一个手写数字“5”的图像,它的维度为(batch,1,28,28)即单通道高宽分别为28像素。

首先通过一个卷积核为5×5的卷积层,其通道数从1变为10,高宽分别为24像素;

然后通过一个卷积核为2×2的最大池化层,通道数不变,高宽变为一半,即维度变成(batch,10,12,12);

然后再通过一个卷积核为5×5的卷积层,其通道数从10变为20,高宽分别为8像素;

再通过一个卷积核为2×2的最大池化层,通道数不变,高宽变为一半,即维度变成(batch,20,4,4);

之后将其view展平,使其维度变为320(2044)之后进入全连接层,用线性函数将其输出为10类,即“0-910个数字。

模型代码:

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(320, 100),
            torch.nn.Linear(100, 50),
            torch.nn.Linear(50, 10)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)  # 一层卷积层,一层池化层,一层激活层(图是先卷积后激活再池化,差别不大)
        x = self.conv2(x)  # 再来一次
        x = x.view(batch_size, -1)  # flatten 变成全连接网络需要的输入 (batch, 20,4,4) ==> (batch,320), -1 此处自动算出的是320
        x = self.fc(x)
        return x  # 最后输出的是维度为10的,也就是(对应数学符号的0~9)

2.读入数据

1.引入库

import torch
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F
 

2.装载数据集

compose的作用是组合打包,这边需要将数据转化为Tensor同时进行标准化

dataset用来下载数据集   下载前把download设置为ture  下载完设为false

dataloader用来装载数据集 shuffle用来打乱数据集,防止数据存放顺序影响特征提取

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data/mnist', train=True, download=False, transform=transform)
test_dataset = datasets.MNIST(root='./data/mnist', train=False, download=False, transform=transform)  # train=True训练集,=False测试集
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

3.展示数据集内容

fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i+1)
    plt.tight_layout()
    plt.imshow(train_dataset.train_data[i], cmap='gray', interpolation='none')
    plt.title("Labels: {}".format(train_dataset.train_labels[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

4.加载优化器和参数

model = Net()

criterion = torch.nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)  # lr学习率,momentum冲量

.参数设置

batch_size = 64
learning_rate = 0.01
momentum = 0.5
EPOCH = 5

5.训练轮        


def train(epoch):
    running_loss = 0.0  # 这整个epoch的loss清零
    running_total = 0
    running_correct = 0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()

        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)

        loss.backward()
        optimizer.step()

        # 把运行中的loss累加起来,为了下面300次一除
        running_loss += loss.item()
        # 把运行中的准确率acc算出来
        _, predicted = torch.max(outputs.data, dim=1)
        running_total += inputs.shape[0]
        running_correct += (predicted == target).sum().item()

        if batch_idx % 300 == 299:  # 不想要每一次都出loss,浪费时间,选择每300次出一个平均损失,和准确率
            print('[%d, %5d]: loss: %.3f , acc: %.2f %%'
                  % (epoch + 1, batch_idx + 1, running_loss / 300, 100 * running_correct / running_total))
            running_loss = 0.0  # 这小批300的loss清零
            running_total = 0
            running_correct = 0  # 这小批300的acc清零

6.测试轮

        

def test():
    correct = 0
    total = 0
    with torch.no_grad():  # 测试集不用算梯度
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度,沿着行(第1个维度)去找1.最大值和2.最大值的下标
            total += labels.size(0)  # 张量之间的比较运算
            correct += (predicted == labels).sum().item()
    acc = correct / total
    print('[%d / %d]: Accuracy on test set: %.1f %% ' % (epoch+1, EPOCH, 100 * acc))  # 求测试的准确率,正确数/总数
    return acc

7.主代码

if __name__ == '__main__':
    acc_list_test = []
    for epoch in range(EPOCH):
        train(epoch)
        # if epoch % 10 == 9:  #每训练10轮 测试1次
        acc_test = test()
        acc_list_test.append(acc_test)

    plt.plot(acc_list_test)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy On TestSet')
    plt.show()


mymodel自建模型

设计思路参考了vgg16网络和BN网络

采用每两层卷积叠加一层池化的方式丰富参数量,同时引入BN网络的思想,在每层的输入前增加一次归一化,增强非线性,加快模型收敛速度。最后引入drop正则机制,防止系统过拟合。

模型使用为(32*32*1)

class mymodel(nn.Module):
    def __init__(self):
        super(mymodel,self).__init__()

        self.cov1_1 = nn.Conv2d(1, 6,kernel_size=(3,3), stride=(1, 1),padding="same")
        self.bn1_1 = nn.BatchNorm2d(6)
        self.cov1_2 = nn.Conv2d(6, 12,kernel_size=(3,3), stride=(1, 1),padding="same")
        self.bn1_2 = nn.BatchNorm2d(12)
        self.maxpool1 = nn.MaxPool2d((2,2),stride=2)

        self.cov2_1 = nn.Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1), padding="same")
        self.bn2_1 = nn.BatchNorm2d(24)
        self.cov2_2 = nn.Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding="same")
        self.bn2_2 = nn.BatchNorm2d(48)
        self.maxpool2 = nn.MaxPool2d((2, 2), stride=2)

        self.cov3_1 = nn.Conv2d(48,96, kernel_size=(3, 3), stride=(1, 1), padding="same")
        self.bn3_1 = nn.BatchNorm2d(96)
        self.cov3_2 = nn.Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding="same")
        self.bn3_2 = nn.BatchNorm2d(192)
        self.maxpool3 = nn.MaxPool2d((2, 2), stride=2)

        self.cov4_1 = nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1))
        self.bn4_1 = nn.BatchNorm2d(384)

        self.flatten = nn.Flatten()
        self.do = nn.Dropout(0.5)
        self.fc = nn.Linear(384*2*2, 10)



    def forward(self, x):
        x = self.cov1_1(x)
        x = self.bn1_1(x)
        x = F.relu(x)
        x = self.cov1_2(x)
        x = self.bn1_2(x)
        x = F.relu(x)
        x = self.maxpool1(x)

        x = self.cov2_1(x)
        x = self.bn2_1(x)
        x = F.relu(x)
        x = self.cov2_2(x)
        x = self.bn2_2(x)
        x = F.relu(x)
        x = self.maxpool2(x)

        x = self.cov3_1(x)
        x = self.bn3_1(x)
        x = F.relu(x)
        x = self.cov3_2(x)
        x = self.bn3_2(x)
        x = F.relu(x)
        x = self.maxpool3(x)

        x = self.cov4_1(x)
        x = self.bn4_1(x)
        x = F.relu(x)

        x = self.flatten(x)
        x = self.do(x)
        x = self.fc(x)
        x = F.relu(x)

        return x

 

(11层的模型)

 上述代码是32*32,数据集是28*28,前面需要加一步尺寸变换

可以看出自建模型的收敛速度明显更快,准确度也明显提高。

总结

总代码就不展示了  按顺序复制粘贴就可以了,有不清楚的评论区可以提问

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于卷积神经网络的手写数字识别(自建模型) 的相关文章

  • Google 政策更新后不允许仅使用用户名和密码,如何使用 python 发送电子邮件?

    我正在尝试学习如何使用 python 发送电子邮件 我读过的所有网络教程都解释了如何使用 Gmail 进行操作 但是 从 2022 年 5 月 30 日起 尽管每个人都可以自由地使用自己的帐户做任何他想做的事情 Google 制定了一项新政
  • Matplotlib imshow:如何在矩阵上应用蒙版

    我正在尝试以图形方式分析二维数据 matplotlib imshow在这方面非常有用 但我觉得如果我可以从矩阵中排除一些单元格 超出感兴趣范围的值 我可以更多地利用它 我的问题是这些值使我感兴趣的范围内的色彩图 变平 排除这些值后 我可以获
  • 在 HSV 颜色空间内定义组织学图像掩模的颜色范围(Python、OpenCV、图像分析):

    为了根据颜色将组织学切片分成多个层 我修改了 OpenCV 社区提供的一些广泛分布的代码 1 我们的染色程序用不同的颜色标记组织横截面的不同细胞类型 B 细胞为红色 巨噬细胞为棕色 背景细胞核为蓝色 I m interested in se
  • 如何分组显示argparse子命令?

    对于具有许多子命令的程序 我想在 help 输出中显示它们按逻辑分组 Python argparse 有一个add argument group http docs python org library argparse html argp
  • 代码运行时出现内存问题(Python、Networkx)

    我编写了一个代码来生成具有 379613734 条边的图 但由于内存问题 代码无法完成 当经过 6200 万行时 大约会占用服务器内存的 97 所以我杀了它 您有解决这个问题的想法吗 我的代码是这样的 import os sys impor
  • 为什么 Python 中的无分支函数和内置函数速度较慢?

    我发现了 2 个无分支函数 它们可以在 python 中查找两个数字的最大值 并将它们与 if 语句和内置 max 函数进行比较 我认为无分支或内置函数将是最快的 但最快的是 if 语句函数 有人知道这是为什么吗 以下是功能 If 语句 2
  • Python 2.7从非默认目录打开多个文件(对于opencv)

    我在 64 位 win7 上使用 python 2 7 并拥有 opencv 2 4 x 当我写 cv2 imread pic 时 它会在我的默认 python 路径中打开 pic 即C Users Myname 但是我如何设法浏览不同的目
  • 字典键中的通配符

    假设我有一本字典 rank dict V 1 A 2 V 3 A 4 正如您所看到的 我在一个 V 的末尾添加了一个 虽然 3 可能只是 V 的值 但我想要 V1 V2 V2234432 等的另一个密钥 我想检查它 checker V30
  • 如何在 matplotlib 图中禁用 xkcd?

    您可以通过以下方式打开 xkcd 风格 import matplotlib pyplot as plt plt xkcd 但如何禁用它呢 I try self fig clf 但这行不通 简而言之 要么使用 Valentin 提到的上下文管
  • XGBOOST 功能名称错误 - Python

    也许这个问题已经以不同的形式被问过很多次了 但是 我的问题是当我使用XGBClassifier 对于像数据这样的产品 我收到功能名称不匹配错误 我希望有人能告诉我我做错了什么 这是我的代码 顺便说一句 数据完全是编造的 import pan
  • 使用字符串迭代 url - python

    我现在完全被我的代码困住了 首先 我尝试从 volkskrant 的存档页面检索所有网址 这是我被打击的第一步 某一特定日期的 url 如下所示 http www volkskrant nl archief detail 01012016
  • 将多个 csv 文件连接成具有相同标头的单个 csv

    我目前正在使用以下代码导入 6 000 个 csv 文件 带标题 并将它们导出到单个 csv 文件 带单个标题行 import csv files from folder path r data US market merged data
  • Python 结构的 PHP 替代品

    我很高兴在我的 Python 项目中使用 Fabric 进行部署 现在我正在从事一个更大的 PHP 项目 想知道是否有类似 PHP 的 Fabric 之类的东西 唔 为什么这有关系 Fabric 只是 python 脚本 所以它与项目语言无
  • 如何从张量流数据集迭代器返回同一批次两次?

    我正在转换一些旧代码以使用数据集 API 此代码使用feed dict将一批数据送入列车运行 实际上是三次 然后重新计算损失以供显示使用同一批 所以我需要一个迭代器来返回完全相同的批次两次 或多次 不幸的是 我似乎找不到一种使用张量流数据集
  • 从 SQL 数据库导入表并按日期过滤行时,将 Pandas 列解析为日期时间

    我有一个DataFrame列名为date 我们如何将 日期 列转换 解析为DateTime object 我使用 Postgresql 数据库加载日期列sql read frame 的一个例子date列是2013 04 04 我想做的是选择
  • 如何在 django-rest-framework 查询集响应中添加注释数据?

    我正在为查询集中的每个项目生成聚合 def get queryset self from django db models import Count queryset Book objects annotate Count authors
  • 当没有 main 函数时,为什么 sys.settrace 不触发?

    import sys def printer frame event arg print frame event arg return printer sys settrace printer x 1 sys settrace None 上
  • Python 3d 金字塔

    我是 3D 绘图新手 我只想用 5 个点建造一个金字塔并通过它切出一个平面 我的问题是我不知道如何填充两侧 points np array 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 fig plt figure ax fi
  • 捕获 subprocess.run() 的输入

    我在 Windows 上有一个交互式命令行 exe 文件 是由其他人编写的 当程序出现异常时 它会终止 并且我对程序的所有输入都会丢失 所以我正在编写一个 python 程序 它调用一个阻塞子进程subprocess run 并捕获所有输入
  • python nltk从句子中提取关键字

    我们要做的第一件事 就是杀掉所有律师 威廉 莎士比亚 鉴于上面的引用 我想退出 kill and lawyers 作为两个突出的关键词来描述句子的整体含义 我提取了以下名词 动词 POS 标签 First NNP thing NN do V

随机推荐

  • Java基础知识查阅表(四)[线程、网络编程、注解、java8新特性]

    文章目录 Java中的线程 线程的分类 线程调度规则 获取线程的优先级 其他几个方法 线程的通信 守护线程 线程的生命周期 线程安全问题 线程安全的类 ReentrantLock加锁 关于锁的面试题 定时器Timer Java网络编程 两个
  • 数据结构—顺序表基本操作(c语言代码)

    顺序表 计算机内部存储一张线性表是用一组连续地址内存单元 这种存储结构即为顺序存储结构 这种结构下的线性表叫顺序表 顺序表有两种定义方法 1 静态定义 2 动态生成 顺序表是最简单的一种线性存储结构 优点 构造简单 操作方便 通过顺序表的首
  • python装饰器原理

    装饰器作用 装饰器在实际开发中应用广发 如 1 引入日志 2 函数执行时间统计 3 执行函数前预备处理 4 执行函数后清理功能 5 权限校验等场景 6 缓存 装饰器可以实现在不修改之前已经写好并且封装好的代码的前提下对之前的代码进行功能上的
  • LASlib/LAStools:Win10 + VS2017 编译LASlib/LAStools

    一 下载解压 下载地址 http lastools github io download LAStools zip 解压地址 G LAStools 二 编译 2 1 打开 用VS2017打开lastools dsw 历史原因 一直点确定就可
  • Linux shell 从文件中随机选择内容

    如果需要从文件中随机选择一定行的内容 可以借助sort 命令 如下 使用sort 命令将文件随机排序 选择前100行 sort random sort file head n 100
  • 《自然语言处理》第二次作业:语言模型和文本分类

    文章目录 作业要求 代码 读取数据集 建立二元语法模型 朴素贝叶斯分类 分类和评估 计算困惑度 完整代码 运行结果 作业要求 题目 语言模型和文本分类 数据集 text classification data用户评论 包括训练集 开发测试集
  • 三位数除以两位数怎么算竖式_四年级数学上册三位数除以两位数竖式笔算专项练习(10套)...

    四年级数学上册三位数除以两位数竖式笔算专项练习 一 三位数除以两位数的除法 包括以下两部分 一 三位数除以整十数 如 二 三位数除以两位数 二 除数是两位数的除法法则 从被除数左边的高位起 先用除数试除被除数的前两位数 如果它比除数小 再试
  • 深入浅出讲解 NAT 和 UDP/TCP 点对点通讯

    深入浅出讲解 NAT 和 UDP TCP 点对点通讯 转自 http blog csdn net g brightboy article details 12704933 一 什么是NAT 为什么要使用NAT NAT是将私有地址转换为合法I
  • Java面向对象(基础总结)

    Java面向对象 基础总结 面向对象是一种编程思想 面向对象的三大基本特征 封装 继承 多态 面向对象的编程思想就是把事物看作一个整体 从事物的特征 属性 和行为 方法 两个方面进行描述 面向对象的过程就是找对象 建立对象 使用对象 维护对
  • angular中涉及rxjs请求beego接口跨域问题解决

    今天遇到一个调用服务端接口跨域问题 我用本地的angular运行项目 访问本地的beego接口 发现请求接口状态404 并且接口方法还是OPTIONS 一查知道是跨域了 在网上搜索一些跨域访问的方法 发现跨域时访问可以了 但正常post接口
  • Java中的代理(二)--JDK动态代理

    JDK动态代理借助接口实现 目标类需是接口形式 代理类继承InvocationHandler类 通过反射方式动态创建目标类 1 目标对象 public interface ByShoot void byShoot String size p
  • LWN 翻译:Atomic Mode Setting 设计简介(下)

    译者注 紧接上篇文章 本篇翻译起来有难度 同时对读者的技术背景有一定要求 适合深入研究 DRM 驱动的开发人员阅读 通过阅读本文 你将了解如下内容 DRM MODE ATOMIC ALLOW MODESET 标志位的由来及其作用 驱动中随处
  • 常见CMS系统总结

    一 CMS系统是什么 CMS系统指的是内容管理系统 CMS可以理解为CMS帮你把一个网站的程序部分的事全做完了 你要做的只是一个网站里面美工的部份 只要搞几个静态网页模板 一个门户级的网站就出来了 二 CMS系统的分类 企业建站系统 Met
  • MYSQL:ER_NOT_SUPPORTED_AUTH_MODE:Client does not support authentication protocol

    今天新建一个koa项目 启动调用mysql驱动的时候报该错误 solution 在系统mysql终端输入下面命令 重启koa进程即可 yourpassword 是你的数据库账户密码 root和host也是 ALTER USER root l
  • 印刷MES管理系统等数字化系统,应用发展如此迅速

    作为印刷企业最基本的数字化管理系统 印刷MES管理系统与印刷ERP管理系统在最近两年普遍受到印刷企业的关注并得到迅速发展 市场需求旺盛 1 ERP逐渐普及到中小企业 ERP管理系统延续了前两年的发展趋势 市场正在从普及阶段转入升级阶段 一方
  • 构建微服务开源生态,TARS项目将成立基金会

    导语 在20世纪60至70年代 软件开发人员通常在大型机和小型机上使用单体架构进行软件开发 没有一个应用程序能够满足大多数最终用户的需求 垂直行业使用的软件代码量更小 与其他应用程序的接口更简单 而可伸缩性在当时并不是优先考虑的 随着互联网
  • -day11--函数进阶

    day11 函数进阶 目标 掌握函数相关易错点 项目开发必备技能 概要 参数的补充 函数名 函数名到底是什么 区分返回值和print 函数的作用域 1 参数的补充 函数进阶 在特定情况下可以让代码更加简洁 提升开发效率 1 1 参数内存地址
  • 一、Nginx源码安装与yum安装

    目录标题 源码安装 yum安装 源码安装 wget http nginx org download nginx 1 15 8 tar gz tar zxvf nginx 1 15 8 tar gz cd nginx 1 15 8 confi
  • 关于Vue中element按需引入

    在项目中使用elementui确实是很方便的一件事 但是如果我只需要用到其中的某一些元素来简化代码的话 全局引入就显得有点臃肿了 这就有了按需引入的概念 需要什么就引入什么 方便 一 安装element ui npm i element u
  • 基于卷积神经网络的手写数字识别(自建模型)

    提示 文章写完后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 前言 卷积神经网络是一种多层的监督学习神经网络 隐含层的卷积层和池采样层是实现卷积神经网络特征提取功能的核心模块 该网络模型通过采用梯度下降法最小化损失函数对网络中