从官网下载/处理 MNIST 数据集,并构造CNN网络训练

2023-05-16

这里写自定义目录标题

  • MNIST 网络 测试用
    • 1. 导入所需要的模块
    • 2. 下载 MNIST 数据集
    • 3. 读取 MNIST 数据集

MNIST 网络 测试用

1. 导入所需要的模块

import sys
#sys.path.append('../../')
#from zfdplearn import fdutils, fdtorch_net, fddata
import os.path as path
import gzip

from typing import Dict, List, Tuple, AnyStr, KeysView, Any

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt


from tqdm import tqdm

2. 下载 MNIST 数据集

2.1 下载地址: http://yann.lecun.com/exdb/mnist/
2.1.1 下载的文件有 4 个,分别是:

train-images-idx3-ubyte.gz ==> 训练集的图片
train-label-idx1-ubyte.gz ==> 训练集的标签
t10k-images-idx3-ubyte.gz ==> 测试集的图片
t10k-label-idx1-ubyte.gz ==> 测试集的标签

下载的数据集格式为 .gz,因此需要使用到 python 的 gzip 包

# 下载地址: http://yann.lecun.com/exdb/mnist/
dataset_folder = '../datasets/mnist'
files_name = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'vali_img': 't10k-images-idx3-ubyte.gz',
    'vali_label': 't10k-labels-idx1-ubyte.gz'
}

3. 读取 MNIST 数据集

3.1 下载的数据集格式为 .gz,因此需要使用 gzip 中的 open 函数打开。
3.2 打开模式设置为 mode=‘rb’,以字节流的方式打开。因为下载的数据集的格式为字节方式封装
3.3 由于使用字节流打开,因此需要使用 torch.frombuffer() 或者 np.frombuffer() 函数打开。
3.3 根据 MNIST 数据集官网可知,读取数据集需要 offset,因为,在数据头部的数据存储了数据集的一些信息
3.4.1 training set label file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),因此需要 offset 8
trainSetLable
3.4.2 training set images file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),第 8-11 是每张图片的行数,第 12-15 是每张图片的列数, 因此需要 offset 16
trainSetImg
3.4.2 test set label file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),因此需要 offset 8
testSetLable
3.4.3 test set images file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),第 8-11 是每张图片的行数,第 12-15 是每张图片的列数,因此需要 offset 16
testSetImg

PS: torch/np. frombuffer()

# 加载训练集 图片
def load_mnist_data(files_name) -> Tuple:
    with gzip.open(path.join(dataset_folder, files_name['train_img']), mode='rb') as data:
        train_img = torch.frombuffer(data.read(), dtype=torch.uint8, offset=16).reshape(-1, 1, 28, 28)
    # 加载训练集 标签
    with gzip.open(path.join(dataset_folder, files_name['train_label']), mode='rb') as label:
        train_label = torch.frombuffer(label.read(), dtype=torch.uint8, offset=8)
    # 加载验证集 图片
    with gzip.open(path.join(dataset_folder, files_name['vali_img']), mode='rb') as data:
        vali_img = torch.frombuffer(data.read(), dtype=torch.uint8, offset=16).reshape(-1, 1, 28, 28)
    # 加载验证集 label
    with gzip.open(path.join(dataset_folder, files_name['vali_label']), mode='rb') as label:
        vali_label = torch.frombuffer(label.read(), dtype=torch.uint8, offset=8)
    return (train_img, train_label),(vali_img, vali_label)


# save -- fdutils.Accumulator
class Accumulator():
    """
    collecting metrics in experiment
    """
    def __init__(self, names: List[Any]):
        self.accumulator = {}
        if not isinstance(names, list):
            raise Exception(f'type error, expected list but got {type(names)}')
        for name in names:
            self.accumulator[name] = list()

    def __getitem__(self, item) -> List[Any]:
        if item not in self.accumulator.keys():
            raise Exception(f'key error, {item} is not in accumulator')
        return self.accumulator[item]

    def add(self, name: AnyStr, val: Any):
        self.accumulator[name].append(val)

    def add_name(self, name: AnyStr):
        if name in self.accumulator.keys():
            raise Exception(f'{name} is  already in accumulator.keys')
        self.accumulator[name] = list()

    def gets(self) -> Dict[AnyStr, Any]:
        return self.accumulator

    def get_item(self, name: AnyStr) -> List[Any]:
        if name not in self.accumulator.keys():
            raise Exception(f'key error, {name} is not in accumulator')
        return self.accumulator[name]

    def clear(self):
        self.accumulator.clear()

    def get_names(self) -> KeysView:
        return self.accumulator.keys()

class MNIST_dataset(Dataset):
    def __init__(self, data: List, label: List):
        self.__data = data
        self.__label = label

    def __getitem__(self, item):
        if not item < self.__len__():
            return f'Error, index {item} is out of range'
        return self.__data[item], self.__label[item]

    def __len__(self):
        return len(self.__data)

# 读取数据
train_data, vali_data = load_mnist_data(files_name)
# 将数据封装为 MNIST 类
train_dataset = MNIST_dataset(*train_data)
vali_dataset = MNIST_dataset(*vali_data)
len(train_dataset), len(vali_dataset)
(60000, 10000)
class YLMnistNet(nn.Module):
    def __init__(self):
        super(YLMnistNet, self).__init__()
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.conv1 = nn.Conv2d(6, 16, kernel_size=(5, 5))
        self.pool0 = nn.AvgPool2d(kernel_size=(2, 2))
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2))
        self.linear0 = nn.Linear(16*4*4, 120)
        self.linear1 = nn.Linear(120, 84)
        self.linear2 = nn.Linear(84, 10)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.layers = [self.conv0, self.pool0, self.conv1, self.pool1, self.flatten, self.linear0, self.relu, self.linear1, self.relu, self.linear2, self.relu]

    def forward(self, x):
        output = self.conv0(x)
        output = self.pool0(output)
        output = self.conv1(output)
        output = self.pool1(output)
        output = self.flatten(output)
        output = self.linear0(output)
        output = self.relu(output)
        output = self.linear1(output)
        output = self.relu(output)
        output = self.linear2(output)
        output = self.relu(output)
        return output

    # get depth of MNIST Net
    def __len__(self):
        return len(self.layers)

    # get specified layer
    def __getitem__(self, item):
        return self.layers[item]

    def __name__(self):
        return 'YNMNISTNET'

net = YLMnistNet()

def train(net, loss, train_iter, vali_iter, optimizer, epochs, device) -> Accumulator:
    net = net.to(device)
    one_hot_f = nn.functional.one_hot
    accumulator = Accumulator(['train_loss', 'vali_loss', 'train_acc', 'vali_acc'])
    epoch_loss = []
    for epoch in range(epochs):
        len_train =  0
        len_vali = 0

        net.train()
        epoch_loss.clear()
        correct_num = 0
        for img, label in train_iter:
            img, label = img.to(device, dtype=torch.float), label.to(device)
            oh_label = one_hot_f(label.long(), num_classes=10)
            optimizer.zero_grad()
            y_hat = net(img)
            l = loss(y_hat, oh_label.to(dtype=float))
            l.backward()
            optimizer.step()
            epoch_loss.append(l.item())
            correct_num += (y_hat.argmax(dim=1, keepdim=True) == label.reshape(-1, 1)).sum().item()
            len_train += len(label)
        accumulator['train_loss'].append(sum(epoch_loss)/len(epoch_loss))
        accumulator['train_acc'].append(correct_num/len_train)
        print(f'-----------epoch: {epoch+1} start --------------')
        print(f'epoch: {epoch+1} train loss: {accumulator["train_loss"][-1]}')
        print(f'epoch: {epoch+1} train acc: {accumulator["train_acc"][-1]}')

        # validation
        epoch_loss.clear()
        correct_num = 0
        with torch.no_grad():
            net.eval()
            for img, label in vali_iter:
                img, label = img.to(device, dtype=torch.float), label.to(device)
                # print(img.dtype)
                oh_label = one_hot_f(label.long(), num_classes=10)
                vali_y_hat = net(img)
                l = loss(vali_y_hat, oh_label.to(dtype=float))
                epoch_loss.append(l.item())
                correct_num += (vali_y_hat.argmax(dim=1, keepdim=True) == label.reshape(-1, 1)).sum().item()
                len_vali += len(label)
            accumulator['vali_loss'].append(sum(epoch_loss)/len(epoch_loss))
            accumulator['vali_acc'].append(correct_num / len_vali)
            print(f'epoch: {epoch+1} vali loss: {accumulator["vali_loss"][-1]}')
            print(f'epoch: {epoch+1} vali acc: {accumulator["vali_acc"][-1]}')
            print(f'-----------epoch: {epoch+1} end --------------')
    return accumulator

# from torch.utils.data import DataLoader
net = YLMnistNet()
batch_size = 32
train_iter = DataLoader(train_dataset, batch_size=batch_size)
vali_iter = DataLoader(vali_dataset, batch_size=batch_size)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_epoch = 1
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
accumulator = train(net, loss, train_iter, vali_iter, optimizer, num_epoch, device)

--------------epoch: 1 start ----------------------
epoch: 1 train loss: 0.4480924960145155
epoch: 1 train acc: 0.85015
epoch: 1 vali loss: 0.14723741338332405
epoch: 1 vali acc: 0.9559
--------------epoch: 1 end ----------------------
		    	…
		     	…
--------------epoch: 20 start ----------------------
epoch: 20 train loss: 0.01722543535107635
epoch: 20 train acc: 0.9943166666666666
epoch: 20 vali loss: 0.12874228838498014
epoch: 20 vali acc: 0.9754
--------------epoch: 20 end ----------------------
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

从官网下载/处理 MNIST 数据集,并构造CNN网络训练 的相关文章

  • HRZ的序列

    问题描述 xff1a 相较于咕咕东 xff0c 瑞神是个起早贪黑的好孩子 xff0c 今天早上瑞神起得很早 xff0c 刷B站时看到了一个序列a xff0c 他对这个序列产生了浓厚的兴趣 xff0c 他好奇是否存在一个数K xff0c 使得
  • 东东学打牌

    问题描述 xff1a 最近 xff0c 东东沉迷于打牌 所以他找到 HRZ ZJM 等人和他一起打牌 由于人数众多 xff0c 东东稍微修改了亿下游戏规则 xff1a 所有扑克牌只按数字来算大小 xff0c 忽略花色 每张扑克牌的大小由一个
  • 咕咕东的目录管理器

    文章目录 问题描述样例输入样例输出 解题思路代码 问题描述 咕咕东的雪梨电脑的操作系统在上个月受到宇宙射线的影响 xff0c 时不时发生故障 xff0c 他受不了了 xff0c 想要写一个高效易用零bug的操作系统 这工程量太大了 xff0
  • 针对CSP-T1,T2的练习

    文章目录 题目1问题描述样例输入样例输出 解题思路代码 题目2问题描述样例输入样例输出 解题思路代码 题目1 问题描述 给出n个数 xff0c zjm想找出出现至少 n 43 1 2次的数 xff0c 现在需要你帮忙找出这个数是多少 xff
  • Rust的控制流:条件、循环以及模式匹配

    文章目录 条件控制循环控制forwhileloopbreak continue 模式匹配 条件控制 Rust的条件控制也是使用if else xff0c 和其他语言相比没有多大区别 xff0c 直接看例子 xff1a fn main let
  • 在Windows上搭建Rust开发环境——Clion篇

    文章目录 在Windows上搭建Rust开发环境 Clion篇安装mingw64安装Rusthello world安装Clion使用Clion创建并调试项目 在Windows上搭建Rust开发环境 Clion篇 刚开始学习Rust的时候 x
  • 洛谷P3366最小生成树模板

    kruskal span class token macro property span class token directive keyword include span span class token string lt cstdi
  • 在家远程控制 少了它俩简直太遗憾了

    互联网公司的值班 xff0c 本意在于出现问题时有人及时处理 xff0c 毕竟上线运行的产品 xff0c 出问题可能会影响到公司的整体收益 虽然工作是965 xff0c 但值班日程表却明明白白写着谁负责保障今天的产品运行正常 涉及到技术 运
  • Openstack Kolla-Ansible安装部署

    Openstack Kolla Ansible安装部署 部署节点制作 环境准备 CentOS环境安装 配置国内pypi源 xff1a mkdir p config pip vim config pip pip conf global ind
  • Windows 远程桌面登录蓝屏、不显示桌面问题解决方法

    远程桌面登录蓝屏 不显示桌面问题解决方法 有时候的不当操作 xff0c 可以使Windows服务器或vps远程桌面出现蓝屏或者黑屏 xff01 遇到此问题 xff0c 不要急急忙忙的让机房值班给你重启机器 xff0c 因为此时除了远程连接不
  • 【5G核心网】5GC核心网之网元UPF

    UPF xff08 User Plane Function xff0c 用户面功能 xff09 xff1a ts 29 244 23 501 5 8 1 UPF User Plane Function 用户平面功能 用于RAT内 RAT间移
  • 玩转ADB命令(ADB命令使用大全)

    此文章内容整合自网络 xff0c 欢迎转载 我相信做Android开发的朋友都用过ADB命令 xff0c 但是也只是限于安装应用push文件和设备重启相关 xff0c 更深的就不知道了 xff0c 其实我们完全可以了解多一点 xff0c 有
  • Ubuntu12.04操作系统安装时,出现的问题及解决方案

    问题一 Windows 下用 putty 连接不上虚拟机上的 Ubuntu12 04 解决方案 预探索 问题可能的原因 A 先确定你能不能ping通远程的ubuntu或者虚拟机 B 如果还不能登录 xff0c 分析原因是大多数没有真正开启s
  • 获取镜像源来搭建本地Ubuntu14.04源

    针对公司的网络限制 xff0c 可以在局域网内搭建一台本地的ubuntu源 1 修改源配置 换成搜狐源 默认的ubuntu源不如某些国内的源速度快 vi etc apt source list deb http mirrors sohu c
  • Ubuntu Desktop 16 配置ssh远程登录

    文章目录 环境介绍1 安装openssh server2 允许用户登录 xff1b 编辑配置文件3 重启sshd服务并检查状态4 查看Ubuntu主机的IP5 远程登录Ubuntu6 退出远程登录参考文献英语好的同学请忽略 环境介绍 主机系
  • 关闭Linux防火墙

    文章目录 查看防火墙状态临时关闭防火墙禁止开机启动防火墙开启防火墙允许开机启动防火墙关闭防火墙的步骤 查看防火墙状态 CentOS 6 service iptables status CentOS 7 firewall cmd state
  • ubuntu挂载sd卡到分区目录+修改docker镜像存储位置

    ubuntu挂载sd卡到分区目录 43 修改docker镜像存储位置 一 挂载SD卡到 data 1 查看Linux硬盘信息 lsblk 或 fdisk l lsblk 新的硬盘 xff0c 最好删除之前的分区 xff0c 再新建分区 de
  • xRDP "Password failed, error - problem connecting"

    Add this in sesman ini under Xvnc solved my problem param8 61 SecurityTypes param9 61 None This solved my problum sudo n
  • 如何远程公司 上班族必选大集合

    老张是我们销售部的经理 xff0c 为人随和 xff0c 一点架子也没有 xff0c 和我们关系搞的都很好 xff0c 也很袒护我们 xff0c 由于疫情的原因 xff0c 不得已要居家办公了 xff0c 这让同事们都很不适应 xff0c
  • C语言排序算法之简单交换法排序,直接选择排序,冒泡排序

    C语言排序算法之简单交换法排序 xff0c 直接选择排序 xff0c 冒泡排序 xff0c 最近考试要用到 xff0c 网上也有很多例子 xff0c 我觉得还是自己写的看得懂一些 简单交换法排序 1 简单交换法 排序 2 根据序列中两个记录

随机推荐