Pytorch学习笔记(六)之完整的模型训练(以Cifar10为例)

2023-05-16

文章目录

  • 前言:数据集介绍
    • 0.准备工作:首先导入相关包,设置参数等
    • 1.数据预处理之增强(transforms等)
    • 2.数据的读取(Dataset&Dataloader)
    • 3.模型的搭建(nn.model)
    • 4.开始训练(loss函数,优化器,训练epoch)
      • 先定义损失函数,优化器等
      • 训练集上开始训练
      • 测试集上计算loss及准确率
  • 验证测试模型(没有标签的测试图片)

前言:数据集介绍

在学习完深度学习的理论之后,就要开始代码实战,基本上分为数据的读取,数据预处理(增强),模型的搭建及训练

首先介绍一下数据集,此次采用的数据集是CIFAR10,是一个经典的10分类彩色图片数据集。

  • 一共包含 10 个类别的 RGB 彩色图 片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
  • 图片的尺寸为 32×32
  • 数据集中一共有 50000 张训练图片和 10000 张测试图片
  • 官方地址:https://www.cs.toronto.edu/~kriz/cifar.html 可以通过官网直接下载,有三个版本的(python,matlab,c);也可以直接在代码中下载(pytorch已封装,下文就采用这种方法)

从官网下载的是cifar-10-python.tar.gz,解压之后会得到相应文件夹,里面的文件如下图所示,显然并不是直接以图片的形式存放。由于pytorch的torchvision.datasets包中已经写好了CIFAR10的类,就说明可以直接调用再进行处理,所以不需要关心怎么转化成普通的图片类型(当然也可以通过代码将其转换成jpg或png)

0.准备工作:首先导入相关包,设置参数等

0-4: trainer.py

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import argparse
from torch import nn
# 设置命令行参数
parser=argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int , default=10, help="the number of epochs")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
args=parser.parse_args()
print(args)  # 可以把参数输出来看一下

# 选择设备,GPU/CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) # 看一下自己的设备(支不支持gpu)
Namespace(batch_size=64, lr=0.001, n_epochs=10)
cuda

torch.device可以选择使用GPU训练,可以选择上述写法(更方便广泛),也可以像下面这样直接指定:

device=torch.device("cpu") # 选择cpu训练
device=torch.device("cuda") # 选择gpu训练
device=torch.device("cuda:0") # gpu多张卡时,选择卡0

设置好device之后,就需要把网络模型,数据,loss函数等放到这个设备上,使用的时.to(device),具体可以看下面步骤的操作。

再说一下argparse包,这是python自带的命令行参数解析包,可以用来方便地读取命令行参数。如果代码需要频繁地修改参数的时候,使用这个工具可以将参数和代码分离开来,让代码更简洁,适用范围更广。
比如我们执行代码时,一般是python train.py,此时程序就会执行默认参数;如果需要改变参数可以这样python train.py --batch_size 32
主要讲一下add_argument()函数的参数

参数描述
dest默认的变量名是–或-后面的字符串,也可以通过dest=xxx来设置参数的变量名。在代码中用args.xxx来获取参数的值
default没有设置值情况下的默认参数
type参数类型(int,float,string……)
required表示这个参数是否一定需要设置
choices参数值只能从几个选项里面选择
help指定参数的说明信息
action相当于把参数设成了一个“开关”,不需要给这个开关传递具体的值(常用在参数为true或false的情况)

更多argparse信息:http://vra.github.io/2017/12/02/argparse-usage/

1.数据预处理之增强(transforms等)

一般来说只需要对训练集进行transforms操作,测试集一般只需转换为tensor

train_tfm=transforms.Compose([transforms.Resize((32,32)),
                              transforms.RandomHorizontalFlip(),
                              transforms.ToTensor()])

主要是设置transforms来对训练集进行数据增强,常用的还有很多,这边只是进行了resize(由于图片本身就是32x32规则,所以这个操作也可以不用),随机翻转和转化为Tensor。更多transforms操作可参见:常用transforms大集合

2.数据的读取(Dataset&Dataloader)

"""
数据加载,将数据写进Dataset和DataLoader中
"""
# Dataset
train_data=datasets.CIFAR10(root="./dataset",train=True,
                            transform=train_tfm,download=True)
test_data = datasets.CIFAR10(root="./dataset", train=False,
                             transform=transforms.ToTensor(),download=True)
# DataLoader
train_loader= DataLoader(train_data,batch_size = args.batch_size,
                             shuffle = True,num_workers = 0)
test_loader=DataLoader(test_data,batch_size = args.batch_size,
                             shuffle = False,num_workers = 0)
# 看一下训练集和测试集的大小
train_len=len(train_data)
test_len=len(test_data)
print("训练集的大小是:{}".format(train_len))
print("测试集的大小是:{}".format(test_len))
Files already downloaded and verified
Files already downloaded and verified
训练集的大小是:50000
测试集的大小是:10000

采用的是官网提供的datasets类来获取数据集,基本上比较清晰。更多Dataset&DataLoader用法(如想读取自己的数据集)可以参见:click here

3.模型的搭建(nn.model)

参照的网络模型是下图:
在这里插入图片描述
为增加模型的泛化能力和减小误差,在原模型的基础上,加了几个非线性层(激活函数),每一层的输入输出和上图保持一致

# 搭建网络
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.module=nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32,
                      kernel_size=5, stride=1, padding=2),
            nn.RelU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(1024, 10))

    def forward(self,x):
        x = self.module(x)
        return x

卷积网络各个层的详细用法和作用可以参见:click here

注意:torch.nn 只支持小批量输入,而不支持单个样本。例如,nn.Conv2d 接受一个 4 维的张量,每一维分别是 sSamples x nChannels x Height x Width(样本数 x 通道数 x 高 x 宽),直接输入一张图片(三维)是不行的
如果是单个样本:
需使用 input.unsqueeze(0) 或者reshape(input,(1,3,224,224))来添加其它的维数,具体可以参见下文选取一张图片进行验证的代码写法。

4.开始训练(loss函数,优化器,训练epoch)

先定义损失函数,优化器等

if __name__ == '__main__':
    my_module = MyModule()  # 实例化对象
    # 将模型放到device设备上
    my_module = my_module.to(device) 
    # 采用交叉熵损失函数
    criterion = nn.CrossEntropyLoss()
    criterion=criterion.to(device) # 同样把损失函数放到device设备
    # 定义优化器
    optim = torch.optim.SGD(my_module.parameters(), args.lr)

训练集上开始训练

    for i in range(args.n_epochs):
        # 将模型设置为训练模式
        my_module.train()
        print("第{}轮训练开始:".format(i+1))

        for train_step,(imgs,targets) in enumerate(train_loader):
            imgs = imgs.to(device)
            targets = targets.to(device)
            
            output = my_module(imgs)
            
            # 计算模型输出和实际标签的loss
            loss = criterion(output, targets)
            # 梯度手动清零
            optim.zero_grad()
            # 后向传播,计算梯度
            loss.backward()
            # 优化器优化参数
            optim.step()

            if (train_step+1) % 100 == 0:
                print("训练次数:{},loss:{}".format(train_step,loss.item()))           

测试集上计算loss及准确率

        test_loss = 0 # 测试集上的loss
        total_correct = 0  # 正确预测的数目
        #将模型设置为测试模式
        my_module.eval() 
        with torch.no_grad():
            for imgs,targets in test_loader:
                imgs=imgs.to(device)
                targets=targets.to(device)

                output=my_module(imgs)  # 将图片输入到模型中
                
                loss=criterion(output,targets)
                 # 测试集loss,loss.item()可以避免显存爆炸
                test_loss=loss.item()+test_loss

                # 将输出中概率最大的取出,与targets比较,相等就代表预测正确
                correct = (output.argmax(1) == targets).sum()
                total_correct = correct + total_correct
                """
                # 可以保存每一次训练的模型pth
                torch.save(my_module,"my_model{}.pth".format(i+1))
                """
                
            print("测试集上的loss:{}".format(test_loss))
            print("测试集的准确率:{}".format(total_correct/test_len))

    torch.save(my_module,"my_module.pth")
    print("模型已保存")

输出:

完整的训练到这里就结束了,下一步就是优化模型,让test loss更小——炼丹开始

验证测试模型(没有标签的测试图片)

tester.py:选取几张数据集中没有的照片,如下:

img_path="./pic/dog.jpg"

img_path2="./pic/airplane.jpg"

from PIL import Image
from torch import nn
from torchvision import transforms
import torch
# 导入这一句就可以不用写下面的模型
from trainer import MyModule

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 读入图片
img_path="./pic/airplane.jpg" # 需要验证的图片
image = Image.open(img_path)

tfm = transforms.Compose([transforms.Resize((32,32)),
                          transforms.ToTensor()])
img_tfm = tfm(image)

"""
就是上文的模型,如果不导入from trainer import MyModule,就需要把模型再写一遍
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.Sequential(
             nn.Conv2d(in_channels=3, out_channels=32,
                      kernel_size=5, stride=1, padding=2),
            nn.RelU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(1024, 10))

    def forward(self, x):
        x = self.module(x)
        return x
"""
# 加载模型
model = torch.load("my_module.pth")
print(model) # 输出模型看一下

# 注意模型输入的尺寸,上文已说明。需要(N,C,W,H),N表示batch_size,C是通道数
img = torch.reshape(img_tfm,(1,3,32,32))

# 注意模型和数据要在同一设备
img=img.to(device)  

output = model(img)
print(output)
print(output.argmax(1)) # 输出概率最大的

输出如下图所示,可以看出模型判断其为0类(查看最开始数据集介绍,第一个就是0类,airplane)
在这里插入图片描述

用官方提供的预训练模型——VGG16
VGG16是用ImageNet训练的,他的输入一般是224x224,这里我们也resize以下。一共有1000个类别,我们修改以下模型让他用于十分类

from PIL import Image
from torchvision import transforms,models

img_path="./pic/dog.jpeg"
image = Image.open(img_path)

tfm = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
img_tfm = tfm(image)

vgg16_true = models.vgg16(pretrained=True)
vgg16_true.add_module("linear",nn.Linear(1000,10))
print(vgg16_true)

# vgg16 的网络架构也是继承于torch.nn,所以需要改变维度,才能输入
img = torch.reshape(img_tfm,(1,3,224,224))
output = vgg16_true(img)

# print(output)   # 1000 维的 tensor ,就不输出演示了
print("输出的标签是:",output.argmax(1))
输出的标签是: tensor([208])

查一下ImageNet千分类对应的标签:
在这里插入图片描述
只能说,千分类太细了,我也不太懂狗的品种……


参考链接:https://www.bilibili.com/video/BV1hE411t7RN?
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch学习笔记(六)之完整的模型训练(以Cifar10为例) 的相关文章

随机推荐

  • Flume实战

    前言 在一个完整的大数据处理系统中 xff0c 除了hdfs 43 mapreduce 43 hive组成分析系统的核心之外 xff0c 还需要数据采集 结果数据导出 任务调度等不可或缺的辅助系统 xff0c 而这些辅助工具在hadoop生
  • sqoop安装

    sqoop安装 作为大数据协作框架之一 xff0c Sqoop是一款用于Hadoop和关系型数据库之间进行相互的数据导入和导出的工具 安装sqoop的前提是已经具备java和hadoop的环境 1 下载并解压 最新版下载地址http ftp
  • 利用视图进行多表关联

    疑问 在Maxcompute中我们关联的码表大于8个 xff0c 然后数据存储量大于500W xff0c 那么在进行sql清洗的时候极有可能会被卡死 可是我们就是要在一张表上关联10多个表 xff0c 比如一张表中的很多字段都要关联码表 x
  • CM&CDH安装

    笔者当时自己装CM amp CDH看了不下10篇博客 xff0c 重装集群不下3次 xff0c 后来快照这个功能深深的刻在了我的心里 这篇博客笔者呕心沥血啊 不过还是会有同学会挂掉 xff0c 所以希望大家做到那里一步记得快照 发一下牢骚
  • jvm复习:主动产生fullGC

    一 jdk8参数 Xms100m Xmx100m Xmn30m XX 43 PrintGCDetails 二 代码 xff1a package cn edu tju test public class GcTest01 public sta
  • Zookeeper机制和应用场景

    Zookeeper简介 Zookeeper 分布式服务框架是 Apache Hadoop 的一个子项目 xff0c 它主要是用来解决分布式应用中经常遇到的一些数据管理问题 xff0c 如 xff1a 统一命名服务 状态同步服务 集群管理 分
  • crontab定时器

    crontab定时器 linux下的定时任务 1 编辑使用crontab e 一共6列 xff0c 分别是 xff1a 分 时 日 月 周 命令 2 查看使用crontab l 3 删除任务crontab r 4 查看crontab执行日志
  • Linux后台运行程序

    在我们平常的时候运行程序的时候会产生很多的信息 xff0c 这些信息有时候有用 xff0c 有时候没用 xff0c 不过这些数据都会在该程序的log中保存 xff0c 所以把这些信息放在前台就不是很好 我们可以将脚本放在后台运行 xff0c
  • vnc的两种配置方法及解决vnc连不上的情况

    1 vnc连不上的现象 xff1a Timed out waiting for a response from the computer 解决方法 xff1a sudo sbin iptables I INPUT 1 p TCP dport
  • 【随写笔记】TouchGFX

    https www cnblogs com firege p 5805823 html https blog csdn net u013766436 article details 50805808 LTDC STM32F429系列芯片内部
  • BGP路由协议

    特点 BGP是一种外部网关协议 xff08 EGP xff09 xff0c 不擅长路由计算 xff0c 擅长路由控制 OSPF ISIS等内部网关协议 xff08 IGP xff09 xff0c 擅长路由计算 xff0c 不擅长路由控制 B
  • sed命令的使用(合并行)

    1 把所有不以句号结尾的行 xff0c 和下一行合并 span class token function sed span i span class token string 39 N s n 39 span abc txt 2 把两行合并
  • Mybatis常见面试题及答案

    文章目录 1 什么是Mybatis xff1f 2 Mybaits的优缺点 xff1a 3 和 的区别是什么 xff1f 4 通常一个mapper xml文件 xff0c 都会对应一个Dao接口 xff0c 这个Dao接口的工作原理是什么
  • HDFS排查路径

    遇到HDFS的问题 xff0c 首先需要排除可用类问题 可用类问题按影响 紧急程度不同 xff0c 可继续分为HDFS功能性受损 lt 61 HDFS高可靠性 高可用性受损 按照以下步骤进行排查 xff0c 以下任意一项有异常 xff0c
  • Docker复习: jar包打成docker

    FROM openjdk 8 ARG JAR FILE COPY springbootmybatis 1 0 SNAPSHOT jar app jar EXPOSE 9012 ENTRYPOINT 34 sh 34 34 c 34 34 j
  • iscsi磁盘挂载并设置为开机自动挂载

    前提准备 xff1a 安装iscsi客户端软件 yum y install iscsi initiator utils 第一步 xff1a 发现ISCSI设备 root 64 sdw4 iscsiadm m discovery t st p
  • 关于华为AC6507S能ping通web和ssh却登录不上排障记录(管理面隔离)

    一 客户描述PC和服务器能ping通AC但是web却登录不上 测试 xff1a 设置服务器地址为192 168 0 100 24 AC地址192 168 0 2 24 用0 100去ping0 2可以ping通 xff0c web登录连接失
  • tensor 和 numpy 的互相转换

    为什么要相互转换 xff1a 简单一句话 numpy操作多样 简单 但网络前向只能是tensor类型 各有优势 所以需要相互转换补充 convert Tensor x of torch to array y of numpy y 61 x
  • 图像畸变矫正算法实现 matlab版

    真正的相机镜头不理想 xff0c 并在图像中引入一些失真 为了解释这些非理想性 xff0c 有必要在透视投影的方程中添加失真模型 一 原图如下 xff1a 二 实现的效果图 三 算法具体实现 function undistorted img
  • Pytorch学习笔记(六)之完整的模型训练(以Cifar10为例)

    文章目录 前言 xff1a 数据集介绍0 准备工作 xff1a 首先导入相关包 xff0c 设置参数等1 数据预处理之增强 transforms等 2 数据的读取 Dataset amp Dataloader 3 模型的搭建 nn mode