Pytorch自定义CNN网络实现猫狗分类

2023-05-16

数据集下载地址:https://www.microsoft.com/en-us/download/confirmation.aspx?id=54765

Dogs vs. Cats(猫狗大战)来源Kaggle上的一个竞赛题,任务为给定一个数据集,设计一种算法中的猫狗图片进行判别。

数据集包括25000张带标签的训练集图片,猫和狗各125000张,标签都是以cat  or  dog命名的。图像为RGB格式jpg图片,size不一样。截图如下:

一. 数据预处理

pytorch的数据预处理部分要写成一个类,这个类继承Dataset类,并必须要实现三个函数。

from torch.utils.data import DataLoader,Dataset
from torchvision import transforms as T
import matplotlib.pyplot as plt
import os
from PIL import Image

class DogCat(Dataset):
    def __init__(self, root, transforms=None, train=True):
        imgs = [os.path.join(root,img) for img in os.listdir(root)]
        imgs_num = len(imgs)

        if train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.3 * imgs_num):]

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

            self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
            ])
        else:
            self.transforms = transforms
             
    def __getitem__(self, index):
        img_path = self.imgs[index]
        # dog label : 1           cat label : 0
        label = 1 if "dog" in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data,label

    def __len__(self):
        return len(self.imgs)

__init__为构造函数,我这里用力定义数据路径,数据集划分,transforms。

__getitem__为迭代函数,用来return单个数据的data和label。

__len__返回数据集的长度。

二. 定义网络

在这个例子中,我们用一个简单的4层卷积,2层全连接,最后跟一个sigmoid输出二分类的概率的CNN网络。

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.conv4 = nn.Conv2d(128, 128, 3)
        self.max_pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # 12*12 for size(224,224)    7*7 for size(150,150)
        self.fc1 = nn.Linear(128*12*12, 512)
        self.fc2 = nn.Linear(512, 1)
        
    def forward(self, x):
        in_size = x.size(0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.max_pool(x)
        # 展开
        x = x.view(in_size, -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

pytorch定义网络时,必须实现两个函数,构造函数主要定义一些网络块,forward函数实现前向推理过程。且在后续代码中,如果定义对象model: ConvNet和数据image,可以直接通过model(image)来调用froward函数(python真的很神奇,C++出身的我理解这些骚操作好难)

三.训练模型

数据准备好了,模型网络定义好了,下一步当然是训练权重了。


import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from dataset import DogCat
from network import ConvNet
from draw import draw_acc,draw_loss


train_data_root = "/home/elvis/workfile/dataset/dataset_kaggledogvscat/train"
batch_size = 256
# 1. prepare dataset
train_data = DogCat(train_data_root, train=True)
val_data = DogCat(train_data_root, train=False)
train_dataloader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
val_dataloader = DataLoader(val_data,batch_size=batch_size,shuffle=True)

# 2. load model
model = ConvNet()
if torch.cuda.is_available():
    model.cuda()

# 3. prepare super parameters
criterion = nn.BCELoss()
learning_rate = 1e-3
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 4. train
train_loss_epoch = []
train_acc_epoch = []
val_loss_epoch = []
val_acc_epoch = []
for epoch in range(1, 10):
    model.train()
    train_loss = 0;
    train_acc = 0;
    for batch_idx, (data, target) in enumerate(train_dataloader):
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda().float().unsqueeze(-1)
        else:
            data, target = data, target.float().unsqueeze(-1)
        optimizer.zero_grad()
        output = model(data)
        # print(output)
        loss = criterion(output, target)
        train_loss += loss.item();
        pred = torch.tensor([[1] if num[0] >= 0.5 else [0] for num in output]).cuda();
        train_acc += pred.eq(target.long()).sum().item();
        loss.backward()
        optimizer.step()
        if(batch_idx+1)%10 == 0: 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx+1) * len(data), len(train_dataloader.dataset),
                100. * (batch_idx+1) / len(train_dataloader), loss.item()))
    train_loss_epoch.append(train_loss / len(train_dataloader));
    train_acc_epoch.append(train_acc / len(train_dataloader.dataset));
    print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(train_loss / len(train_dataloader), train_acc, len(train_dataloader.dataset),
                                                                                    100. * train_acc / len(train_dataloader.dataset)));

    # val
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_dataloader):
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda().float().unsqueeze(-1)
            else:
                data, target = data, target.float().unsqueeze(-1)
            output = model(data)
            # print(output)
            test_loss += criterion(output, target).item(); #每个批次平均,一个epoch里所有批次求和
            pred = torch.tensor([[1] if num[0] >= 0.5 else [0] for num in output]).cuda()
            correct += pred.eq(target.long()).sum().item()
    print('Valid set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss/len(val_dataloader), correct, len(val_dataloader.dataset),
                                                                                    100. * correct / len(val_dataloader.dataset)));
    val_loss_epoch.append(test_loss / len(val_dataloader));
    val_acc_epoch.append(correct / len(val_dataloader.dataset));

    # Save model
    val_acc_rate = correct / len(val_dataloader.dataset);
    save = True
    best = "best.pt"
    last = "last.pt"
    if save:
        # Save last, best and delete
        torch.save(model.state_dict(), last)
        if val_acc_rate == max(val_acc_epoch):
            torch.save(model.state_dict(), best)
            print("save epoch {} model".format(epoch))

    
# 5. drawing
draw_loss(train_loss_epoch, val_loss_epoch)
draw_acc(train_acc_epoch,val_acc_epoch)
    

第一步,准备数据。先用我们之前定义的DogCat类来加载数据,但这个类继承自dataset,是加载一条数据的。如果要批量加载数据,还要用pytorch内部的另一个类DataLoader,然后在构造函数里传入batchsize就可以批量加载数据了。注意这里的类对象实际是一个生成器,后续通过循环就可以一直批量的去取数据了。

第二步,定义模型对象,有用显卡就把模型放在显卡上,没有的话就用cpu跑。

第三步,定义一些超参数。因为是二分类,网络最后一层为sigmoid输出类别的概率值,所以选用二分类交叉熵损失函数。再设置一下学习率和优化器。

第四步,训练n个epoch。在每一个epoch里计算训练集准去率,验证集准确率,并保存模型。

最后结果像这样

有条件的可以多训练几个epoch试试。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch自定义CNN网络实现猫狗分类 的相关文章

  • Wireshark的常见提示

    概述 本文主要介绍Wireshark中出现的一些常见提示 详细信息 Wireshark简介 Gerald Combs是堪萨斯城密苏里大学计算机科学专业的毕业生 1998年发布了第一版Ethereal工具 xff0c Ethereal工具使用
  • shell报错bad substitution 解决办法

    bin bash a 61 34 hello 34 b 61 34 hi is a 34 echo b echo a echo a echo a 1 2 执行脚本方式不同出现的结果不同 xff1a 方式1 xff1a sh shell sh
  • centos8软件安装dnf命令

    DNF是新一代的rpm软件包管理器 它首先出现在 Fedora 18 这个发行版中 而目前 xff0c 它取代了yum xff0c 正式成为从 Fedora 22 起 Fedora 版本的包管理器 DNF包管理器克服了YUM包管理器的一些瓶
  • 多目标规则在 Makefile 中的应用与示例

    在 Makefile 中 xff0c 如果一个规则有多个目标 xff0c 而且它们之间用空格分隔 xff0c 我们称之为 34 多目标规则 34 这意味着这个规则适用于列出的所有目标 在目标下面的命令是 C 64 xff0c 它通常与 ma
  • 计算机中内存、cache和寄存器之间的关系及区别

    1 寄存器是中央处理器内的组成部份 寄存器是有限存贮容量的高速存贮部件 xff0c 它们可用来暂存指令 数据和位址 在中央处理器的控制部件中 xff0c 包含的寄存器有指令寄存器 IR 和程序计数器 PC 在中央处理器的算术及逻辑部件中 x
  • dell 台式电脑设置每天定时开机和关机

    每天定时开机设置 xff1a 戴尔电脑通过CMOS设置实现自动开机的设置过程如下 xff1a 1 首先进入 CMOS SETUP 程序 大多数主板是在计算机启动时按DEL或F2键进入 xff1b 2 然后将光条移到 Power Manage
  • windows批处理自动获取电脑配置信息

    39 2 gt nul 3 gt nul amp cls amp 64 echo off 39 amp rem 获取本机系统及硬件配置信息 39 amp set 61 Any question amp set 64 61 WX amp se
  • Centos7搭建cisco ocserv

    一 安装的部分直接yum安装即可 yum y install ocserv 二 配置文件根据实际情况调整 auth方式有两种 1 系统账号认证 配置的话就是 xff1a auth 61 34 pam 34 2 本地文件认证 配置的话就是 x
  • 私有harbor部署(docker方式)

    环境准备 docker compose v Docker Compose version v2 14 2 wget https github com docker compose releases download v2 14 2 dock
  • ORACLE扩展表空间

    一 查询表空间使用情况 SELECT UPPER F TABLESPACE NAME 34 表空间名 34 D TOT GROOTTE MB 34 表空间大小 M 34 D TOT GROOTTE MB F TOTAL BYTES 34 已
  • Oracle 常用性能监控SQL语句

    1 查看表锁 SELECT FROM SYS V SQLAREA WHERE DISK READS gt 100 2 监控事例的等待 SELECT EVENT SUM DECODE WAIT TIME 0 0 1 34 Prev 34 SU
  • Nginx出现“ 413 (499 502 404) Request Entity Too Large”错误解决方法

    1 Nginx413错误的排查 修改上传文件大小限制 在使用上传POST一段数据时 xff0c 被提示413 Request Entity Too Large xff0c 应该是nginx限制了上传数据的大小 解决方法就是 打开nginx主
  • 查看弹出广告来自哪个软件

    打开VS的Spy 43 43 将指针移到广告处 xff0c 然后点OK xff0c 在Process标签页可以看到进程id和线程id将获得的16进制进程id xff08 例如 xff1a 000025F8 xff09 通过计算器转成10进制
  • C++多态虚函数实现原理,对象和虚函数表的内存布局

    基本概念 我们知道C 43 43 动态多态是用虚函数实现的 xff0c 而虚函数的实现方式虽说C 43 43 标准没有要求 xff0c 但是基本都是用虚函数表实现的 xff08 编译器决定 xff09 所以我们有必要了解一下虚函数表的实现原
  • C++ STL中递归锁与普通锁的区别

    在多线程编程中 xff0c 保护共享资源的访问很重要 xff0c 为了实现这个目标 xff0c C 43 43 标准库 xff08 STL xff09 中提供了多种锁 xff0c 如std mutex和std recursive mutex
  • VS+Qt开发环境

    VS Qt下载 VS下载 xff1a https visualstudio microsoft com zh hans vs Qt下载安装 xff1a https www bilibili com video BV1gx4y1M7cM VS
  • windows下使用ShiftMediaProject编译调试FFmpeg

    为什么要编译FFmpeg xff1f 定制模块调试源码 windows下编译 推荐项目ShiftMediaProject FFmpeg 平时总是看到一些人说windows下编译FFmpeg很麻烦 xff0c 这时候我就都是微微一笑 xff0
  • RTSP分析

    RTSP使用TCP来发送控制命令 xff08 OPTIONS DESCRIBE SETUP PLAY xff09 xff0c 因为TCP提供可靠有序的数据传输 xff0c 而且TCP还提供错误检测和纠正 RTSP的报文格式可以参考HTTP的
  • RTP分析

    参考 RTP xff08 A Transport Protocol for Real Time Applications 实时传输协议 xff0c rfc3550 xff09 RTP Payload Format for H 264 Vid
  • VS链接器工具错误 LNK2019:无法解析的外部符号

    常见的问题 以下是一些导致 LNK2019 的常见问题 xff1a 未链接的对象文件或包含符号定义的库 在 Visual Studio 中 xff0c 验证包含定义源代码文件是生成 xff0c 分别链接为项目的一部分 在命令行中 xff0c

随机推荐

  • FFmpeg合并视频流与音频流

    mux h ifndef MUX H define MUX H ifdef cplusplus extern 34 C 34 endif include 34 common h 34 include 34 encode h 34 typed
  • 解决电脑同时使用有线网上内网,无线网上外网的冲突

    由于内网有网络限制 xff08 限制娱乐等 xff09 xff0c 所以肯定要用外网 xff08 无线网卡 xff09 但是有的网站只能用内网访问 xff0c 比如gitlab xff0c oa等 我电脑刚开始连接了wifi后上不了gitl
  • Python斗鱼直播间自动发弹幕脚本

    工具 xff1a Python xff0c Chrome浏览器 因为不会用短信验证码登录 xff0c 所以使用QQ帐号登录 xff0c 必须要斗鱼帐号绑定QQ号 难点主要是帧的切换 查找元素可以通过chrome浏览器鼠标指向该元素 xff0
  • Qt+FFmpeg录屏录音

    欢迎加QQ群309798848交流C C 43 43 linux Qt 音视频 OpenCV 源码 xff1a Qt 43 FFmpeg录屏录音 NanaRecorder 之前的录屏项目ScreenCapture存在音视频同步问题 xff0
  • Qt源码分析(一)

    欢迎加QQ群309798848交流C C 43 43 linux Qt 音视频 OpenCV 源码面前 xff0c 了无秘密 阅读源码能帮助我们理解实现原理 xff0c 然后更灵活的运用 接下来我用VS2015调试Qt5 9源码 首先提一下
  • python实现批量提取图片中文字的小工具

    要实现批量提取图片中的文字 xff0c 我们可以使用Python的pytesseract和Pillow库 pytesseract是一个OCR xff08 Optical Character Recognition xff0c 光学字符识别
  • 在虚拟机与wsl中编译内核,并启用kasan

    linux内核编译 虚拟机编译自己的内核WSL编译自己的内核测试kasan 想学习一下kasan的相关配置 xff0c 但kasan是内核中的相关配置 xff0c 开启得编译内核 xff0c 本文使用虚拟机与wsl两种方式来编译内核启动ka
  • mac时间机器删除旧备份

    查 span class token function sudo span tmutil listlocalsnapshots 删除 span class token function sudo span tmutil deleteloca
  • Linux的ssh服务

    Linux中真机与虚拟机的连接 1 先修改虚拟机的IP xff0c 在虚拟机中输入命令 xff1a nm connection editor 2 选中已存在的System eth0 xff0c 选择Delete 然后点击Add xff0c
  • 51单片机——控制步进电机加速、减速及反转

    加速 xff1a include lt reg52 h gt define uchar unsigned char define uint unsigned int define MotorData P1 uchar phasecw 4 6
  • 【安装教程】——Linux安装opencv

    安装教程 Linux安装opencv 一 安装相关软件包二 获取Source三 安装OpenCV 未完待续 一 安装相关软件包 安装相关软件包 打开终端 xff0c 安装以下软件包 sudo apt install build span c
  • 请教:linux下的/opt目录是做什么用的?

    请教 linux下的 opt目录是做什么用的 蛋疼YMG 浏览 4934 次 发布于2014 10 24 13 35 荒漠探险 答题闯关 好礼连连 最佳答案 opt 主机额外安装软件所摆放的目录 默认是空的 一般安装软件的时候 xff0c
  • kali里装java

    1 下载最新的JAVA JDK jdk 8u91 linux x64 2 解压缩文件并移动至 opt tar xzvf jdk 8u91 linux x64 tar gz mv jdk1 8 0 91 opt cd opt jdk1 8 0
  • ffmpeg--tcp

    TCP代码分析 xff1a include 34 avformat h 34 include 34 libavutil avassert h 34 include 34 libavutil parseutils h 34 include 3
  • GO调用ffmpeg动态库

    package main cgo CFLAGS I usr local ffmpeg include cgo LDFLAGS L usr local ffmpeg lib lavformat include 34 libavformat a
  • ftp上传文件时出现 550 Permission denied,不是用户权限问题

    查了半天 xff0c 发现是因为服务器已经有同名的文件了 xff0c 所以无法上传 xff0c 上传文件名改成不重复的就可以了
  • Ubuntu16.04解决登录闪退问题

    Ubuntu16 04解决登录闪退问题 夏哈哈 64 64 2020 07 21 08 42 46 637 收藏 1 分类专栏 xff1a 程序开发 Ubuntu装机 文章标签 xff1a 深度学习 版权 问题 xff1a Ubuntu16
  • syntax error at or near “.“

    SQL select t id t name t sex t age t address t phone from user where id 61 1 Cause org postgresql util PSQLException ERR
  • ubuntu编译安装opencv4.5.1(C++)

    1 在https github com opencv opencv上下载opencv源码 包括opencv4 5 1和opencv contrib4 5 1 也可以不用 xff0c 其中opencv是主体 xff0c opencv cont
  • Pytorch自定义CNN网络实现猫狗分类

    数据集下载地址 xff1a https www microsoft com en us download confirmation aspx id 61 54765 Dogs vs Cats 猫狗大战 来源Kaggle上的一个竞赛题 xff