基于CIFAR100的VGG网络结构详解

2023-11-01

基于CIFAR100的VGG网络详解

码字不易,点赞收藏


1 数据集概况

1.1 CIFAR100

cifar100包含20个大类,共100类,train集50000张图片,test集10000张图片。
在这里插入图片描述
CIFAR100下载地址:http://www.cs.toronto.edu/~kriz/cifar.html

1.2 showdata.py查看数据

import cv2
import numpy as np
import pickle
import os


# 解压缩,返回解压后的字典
def unpickle(file):
    fo = open(file, 'rb')
    dict = pickle.load(fo, encoding='latin1')
    fo.close()
    return dict


def cifar100_to_images():
    tar_dir = './data/cifar-100-python/'  # 原始数据库目录
    train_root_dir = './data/cifar100/train/'  # 图片保存目录
    test_root_dir = './data/cifar100/test/'
    if not os.path.exists(train_root_dir):
        os.makedirs(train_root_dir)
    if not os.path.exists(test_root_dir):
        os.makedirs(test_root_dir)

    # 获取label对应的class,分为20个coarse class,共100个 fine class
    meta_Name = tar_dir + "meta"
    Meta_dic = unpickle(meta_Name)
    coarse_label_names = Meta_dic['coarse_label_names']
    fine_label_names = Meta_dic['fine_label_names']
    print(fine_label_names)

    # 生成训练集图片,如果需要png格式,只需要改图片后缀名即可。
    dataName = tar_dir + "train"
    Xtr = unpickle(dataName)
    print(dataName + " is loading...")
    for i in range(0, Xtr['data'].shape[0]):
        image = np.reshape(Xtr['data'][i], (-1,1024))  # Xtr['data']为图片二进制数据
        r = image[0, :].reshape(32, 32)  # 红色分量
        g = image[1, :].reshape(32, 32)  # 绿色分量
        b = image[2, :].reshape(32, 32)  # 蓝色分量
        img = np.zeros((32, 32, 3))
        # RGB还原成彩色图像
        img[:, :, 0] = r
        img[:, :, 1] = g
        img[:, :, 2] = b
        ###img_name:fine_label+coarse_label+fine_class+coarse_class+index
        picName = train_root_dir + str(Xtr['fine_labels'][i]) + '_' + str(Xtr['coarse_labels'][i]) + '_&' + \
                  fine_label_names[Xtr['fine_labels'][i]] + '&_' + coarse_label_names[
                      Xtr['coarse_labels'][i]] + '_' + str(i) + '.jpg'
        cv2.imwrite(picName, img)
    print(dataName + " loaded.")

    print("test_batch is loading...")
    # 生成测试集图片
    testXtr = unpickle(tar_dir + "test")
    for i in range(0, testXtr['data'].shape[0]):
        img = np.reshape(testXtr['data'][i], (3, 32, 32))
        img = img.transpose(1, 2, 0)
        picName = test_root_dir + str(testXtr['fine_labels'][i]) + '_' + str(testXtr['coarse_labels'][i]) + '_&' + \
                  fine_label_names[testXtr['fine_labels'][i]] + '&_' + coarse_label_names[
                      testXtr['coarse_labels'][i]] + '_' + str(i) + '.jpg'
        cv2.imwrite(picName, img)
    print("test_batch loaded.")

if __name__ == '__main__':
    cifar100_to_images()

在这里插入图片描述

2 VGG网络结构

2.1 网络结构总览

在VGG的网络中,卷积核尺寸都是3x3(padding=1),即卷积操作不会使得特征图的尺寸改变
使得特征图尺寸发生改变的只有池化操作特征图尺寸由(H,W)变为(H/2,W/2)

搞清楚以上两点后,VGG网络就十分清晰易懂。例如下图VGG16的网络结构,输入图片的尺寸为224x224x3,VGG16中包含5个池化操作,故特征图进行扁平化之前的尺寸应该是7(224/32);
至于通道的变化就更简单了,只有卷积操作会带来通道的改变,而卷积操作中的通道改变则是通过不同的卷积核组来实现的

在这里插入图片描述

2.2 VGG网络结构源码

class VGG(nn.Module):

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)

        return output

基于CIFAR100的VGG网络结构定义得很简洁,可以说是一目了然了。

对于网络结构定义,从forward函数可以看出主要包括两部分:features和classifier
features:卷积层+池化层,只涉及对网络尺寸、网络通道的改变,即VGG中全连接层之前的所有操作
classifier:全连接+分类,把从features中得到的特征图扁平化后,经过3层全连接后将类别映射到100进行分类预测

2.3 VGG中的Features构建过程

cfg = {
    'A' : [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],# vgg11
    'B' : [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],# vgg13
    'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],# vgg16
    'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] # vgg19
}

def make_layers(cfg, batch_norm=False):
    layers = []

    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue

        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]

        if batch_norm:
            layers += [nn.BatchNorm2d(l)]

        layers += [nn.ReLU(inplace=True)]
        input_channel = l

    return nn.Sequential(*layers)

def vgg16_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))

2.3.1 字典cfg

首先,VGG将features部分的网络结构使用字典进行了表示,其中的数字表示通道数的改变字母M表示maxpooling操作

2.3.2 make_layers+vgg16_bn

make_layer部分就是按照cfg中给出的VGG网络结构进行网络的构建,以vgg16_bn为例:
1)确定当前层是卷积还是池化操作:如果是卷积操作,按照cfg中给出的通道数进行网络的连接;如果是池化操作,则以尺寸为2的池化核进行特征图的最大池化,当前的特征图尺寸减半。
2)对每一批batch进行标准化操作。
3)最后进行当前层的激活操作

在CIFAR100中,输入图片尺寸为3*32*32,经过features后特征图尺寸变为512*1*1,扁平化后只剩下通道维度的尺寸512,并且512也是分类器的输入
其实这种make_layer的方式在源码中特别常见,引用量刚破30000的ResNet源码也是用这种方式进行构建的

2.4 分类器

512->4096->4096->100

三层全连接层,输出前经过softmax便构成了最后的分类器

2.5 网络完整源码

import torch.nn as nn

cfg = {
    'A' : [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'B' : [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],
    'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}

class VGG(nn.Module):

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)

        return output

def make_layers(cfg, batch_norm=False):
    layers = []

    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue

        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]

        if batch_norm:
            layers += [nn.BatchNorm2d(l)]

        layers += [nn.ReLU(inplace=True)]
        input_channel = l

    return nn.Sequential(*layers)

def vgg11_bn():
    return VGG(make_layers(cfg['A'], batch_norm=True))

def vgg13_bn():
    return VGG(make_layers(cfg['B'], batch_norm=True))

def vgg16_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))

def vgg19_bn():
    return VGG(make_layers(cfg['E'], batch_norm=True))

3 贴个最近常在脑海回响的诗

梦游天姥吟留别
李白
海客谈瀛洲,烟涛微茫信难求;
越人语天姥,云霞明灭或可睹。
天姥连天向天横,势拔五岳掩赤城。
天台四万八千丈,对此欲倒东南倾。
我欲因之梦吴越,一夜飞度镜湖月。
湖月照我影,送我至剡溪。
谢公宿处今尚在,渌水荡漾清猿啼。
脚著谢公屐,身登青云梯。
半壁见海日,空中闻天鸡。
千岩万转路不定,迷花倚石忽已暝。
熊咆龙吟殷岩泉,栗深林兮惊层巅。
云青青兮欲雨,水澹澹兮生烟。
列缺霹雳,丘峦崩摧。
洞天石扉,訇然中开。
青冥浩荡不见底,日月照耀金银台。
霓为衣兮风为马,云之君兮纷纷而来下。
虎鼓瑟兮鸾回车,仙之人兮列如麻。
忽魂悸以魄动,恍惊起而长嗟。
惟觉时之枕席,失向来之烟霞。
世间行乐亦如此,古来万事东流水。
别君去兮何时还?且放白鹿青崖间,须行即骑访名山。
安能摧眉折腰事权贵,使我不得开心颜!

OVER
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于CIFAR100的VGG网络结构详解 的相关文章

随机推荐

  • 把渲染当作核心任务

    和leader聊了一段时间 他的思路是完全底层自行开发 他写代码20多年了 实力是很强悍的 关于是否使用UE4 的shader 他的意思是可以学习后吃透了再写出来 但是不直接使用 搞pbr主要还是要看数学公式 遇到过一些问题 他往往能够解决
  • Kubernetes1.14 学习笔记二: 安装K8S

    一 安装rpm 包 rpm 制作过程参考上一节 https blog csdn net yulei qq article details 89205022 运行如下命令 root k8s x86 64 yum localinstall rp
  • EasyClangComplete CMake环境修复

    Sublime使用EasyClangComplete插件写代码时 如果你的文档目录或它的上层目录下有一个CMakeLists txt文件 那么插件会去调用cmake命令编译这个文件 如果系统没有装cmake或者cmake编译出错 就会导致编
  • yml中对特殊字符的处理

    一 双引号包住可以解决 二 单引加中括号包住可以解决 简单实用 over
  • sqlserver远程链接设置

    需要别人远程你的数据库 首先需要的是在一个局域网内 或者连接的是同一个路由器 接下来就是具体步骤 1 首先是要检查SQLServer数据库服务器中是否允许远程链接 其具体操作为 1 打开数据库 用本地帐户登录 右击第一个选项 选择属性 2
  • 矩阵键盘(stm32f103)

    最近需要用到矩阵键盘 在网上搜了很久看见的大多数都是根据判断寄存器的值来进行矩阵键盘取值 反正我找了一天 免费的文章 大都是这样的 付费的我也不知道 因为本人是初学者 对寄存器的操作不懂 刚开始也照着写了 逻辑上没有问题 但最后返回不了值
  • vant组件时间选择器修改时间格式以及默认展示当天时间

    vant的时间控件默认展示当天时间
  • 源码安装以太坊/wtc

    1 安装go 先更新一下 sudo apt get update sudo apt get y upgrade 下载源码https www golangtc com download 并解压 sudo tar xvf go1 9 2 lin
  • SQL盲注及python脚本编写

    1 什么是盲注 盲注就是在 sql 注入过程中 sql 语句执行的选择后 选择的数据不能回显 到前端页面 此时 我们需要利用一些方法进行判断或者尝试 这个过程称之为盲注 从 background 1 中 我们可以知道盲注分为三类 基于布尔
  • 基于 SpringMvc + OpenCV 实现的答题卡识别系统(附源码)

    java opencv 项目介绍 OpenCV是一个基于BSD许可 开源 发行的跨平台计算机视觉库 它提供了一系列图像处理和计算机视觉方面很多通用算法 是研究图像处理技术的一个很不错的工具 最初开始接触是2016年因为公司项目需要 但是当时
  • AlertDialog全屏显示的问题

    有时候 我们需要直接显示全屏的dialog 平常的时候会有一圈边框 不好看 第一步 编写style 第二步 在使用的时候带入 最简单的全屏就这么完成了 简单不 咩哈哈哈哈哈哈哈
  • Python入门实战题目

    1 有1 2 3 4个数字 能组成多少个互不相同且无重复数字的三位数 都是多少 2 两个乒乓球队进行比赛 各出三人 甲队为a b c三人 乙队为x y z三人 已抽签决定比赛名单 有人向队员打听比赛的名单 a说他不和x比 c说他不和x z比
  • Python3 [爬虫实战] Redis+Flask 动态维护cookies池(上)

    Redis 使用 1 首先去官网下载Reidszip文件 http www redis cn topics config html 2 Reids的安装 直接解压缩zip文件 然后放在一个文件夹中 在文件夹路径下用dos窗口启动服务器端 r
  • 入门算法题002

    题目 给你一个正整数n 假设有两个质数加起来等于n 问一共有多少组这样的质数 思路 1 我们得要先有一个函数去判断是否是质数 2 循环拆解为两个数 暴力拆解 试下10 15分钟内做出来 public class Leecode002 pub
  • selenium爬虫运行慢如何解决?

    Selenium作为一个强大的自动化工具 可用于编写爬虫程序 尽管Selenium在处理动态网页上非常强大 但对于静态网页爬简单数据提取 使用轻量级库或工具可能更加上所述 Selenium作为一个灵活可定动化工具 在需要模拟用户行为 处理动
  • VS2005中分页和多列排序

    最近在使用ASP net 2 0的GridView 控件时 发现排序与分页功能Microsoft实现的都很简单 比如排序 在点击列名的时候来触发整页的PostBack 然后排序 但是在列头上没有一个显示升序降序的图标 这会让最终用户使用时很
  • OJ在线编程常见输入输出练习(11题)

    1 输入包括两个正整数a b 1 lt a b lt 10 9 输入数据包括多组 输出描述 输出a b的结果 输入例子1 1 5 10 20 输出例子1 6 30 import java io BufferedReader import j
  • java web 项目配置日志框架log4j

    第一步 log4j 框架所关联的第三方jar 文件 commons logging xxx jar log4j xxx jar slf4j api xxx jar slf4j log4j12 xxx jar 以下是我搭建web框架集成log
  • 【C++】“没有可用成员”问题的原因之一

    今天碰到一个定义类成员函数的时候一直提示没有可用成员的问题 琢磨半天终于解决 记录一下 以免再犯 问题描述 在头文件中声明了名称空间SALES 并在名称空间中声明了类Sales 在类中声明了一系列类成员后 切换到另一个cpp文件中定义相关的
  • 基于CIFAR100的VGG网络结构详解

    基于CIFAR100的VGG网络详解 码字不易 点赞收藏 1 数据集概况 1 1 CIFAR100 cifar100包含20个大类 共100类 train集50000张图片 test集10000张图片 CIFAR100下载地址 http w