图像分类(1),数据预处理

2023-10-27

本文介绍如何使用pytorh利用预训练模型进行图像分类,主要参考Transfer Learning Tutorial

具体代码可以参考Image_classification

  1. 下载代码文件:git clone https://github.com/chenmozxh/pytorch_studying
  2. 下载数据集:wget https://download.pytorch.org/tutorial/hymenoptera_data.zip 
    这个数据集是imagenet的一个小子集,包含ants和bees两个分类
  3. 解压数据集:unzip hymenoptera_data.zip

    数据集结构为:文件夹hymenoptera_data下存在训练集路径train和测试集路径test,train和test下都有ants和bees两个文件夹,即相应的图像。

  4. 运行python3 example1.py就开始训练了,可以看出随着epoch的加深,loss越来越小,而准确率acc越来越高
  5. example1.py代码解析:
      数据导入,使用官方写好的torchvision.datasets.ImageFolder接口实现数据导入。这个函数只需要你提供图像所在文件夹data_dir/train和data_dir/test即可。这两个目录下分别为N个子文件夹,N为分类的类别数,每个文件夹下为这个类别的图像。这样,torchvision.datasets.ImageFloder就会返回一个列表,列表中每一个值都是一个tuple,每个tuple包含图像和标签信息
    def Data_loader(Data_Path):
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
             ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                #transforms.Resize(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }
    
        data_dir = Data_Path
        image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
        class_names = image_datasets['train'].classes
     
        return dataloaders, image_datasets, class_names
    
    dataloaders, image_datasets, class_names = Data_loader('hymenoptera_data')
    print(image_datasets)
    for e in image_datasets:
        print(e)
        print(image_datasets[e])
        for index, k in enumerate(image_datasets[e]):
            print(type(k), len(k))
            print(index, k[0].size(), k[1])
    
    

    transform对图像进行预处理。torchvision.transform.Compose是用来管理所有的transforms操作的。RandomSizeCrop和RandomHorizontalFlip的输入是PIL Image,也就是用python的PIL Image库读进来图像内容。而Normalize的对象是Tensor,因此需要增加一个ToTensor()用来将图像生成成Tensor。另外,transforms.Scale(256)是resize操作,目前已经被Resize取代。
    ImageFolder只是返回list,list是不能作为模型输入,因此在pytorch中,用另外一个类来封装list,那就是torch.utils.data.DataLoader。这个类将list类型的输入数据,图像和标签分别封装成一个Tensor数据格式,让模型使用。
    另外一个非常重要的类是torch.utils.data.Dataset,这个类是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现,比如torchvision.datasets.ImageFolder和torch.utils.data.DataLoader这两个类。所以,如果数据不是按照上面的格式存储是,需要自定义一个类来读取数据,自定义的这个类必须继承自torch.utils.data.Dataset这个基类。代码如下:

    def default_loader(path):
        try:
            img = Image.open(path)
            return img.convert('RGB')
        except:
            print("Cannot read image: {}".format(path))
    
    
    class customData(Dataset):
        def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):
            with open(txt_path) as input_file:
                lines = input_file.readlines()
                #self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
                #self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
                self.img_name = [os.path.join(img_path, line.strip()[:-2]) for line in lines]
                self.img_label = [int(line.strip()[-1:]) for line in lines]
            self.data_transforms = data_transforms
            self.dataset = dataset
            self.loader = loader
    
        def __len__(self):
            return len(self.img_name)
    
        def __getitem__(self, item):
            img_name = self.img_name[item]
            label = self.img_label[item]
            img = self.loader(img_name)
    
            if self.data_transforms is not None:
                try:
                    img = self.data_transforms[self.dataset](img)
                except:
                    print("Cannot transform image: {}".format(img_name))
            return img, label
    
    def Data_loader():
        batch_size = 4
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }
    
        image_datasets = {x: customData(img_path='hymenoptera_data_cp/',
                                        txt_path=(x + '.txt'),
                                        data_transforms=data_transforms,
                                        dataset=x) for x in ['train', 'val']}
    
        # wrap your data and label into Tensor
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                     batch_size=batch_size,
                                                     shuffle=True) for x in ['train', 'val']}
    
        dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    
        return image_datasets, dataloaders
    

     

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

图像分类(1),数据预处理 的相关文章

  • 图像分类(1),数据预处理

    本文介绍如何使用pytorh利用预训练模型进行图像分类 主要参考Transfer Learning Tutorial和 具体代码可以参考Image classification 下载代码文件 git clone https github c
  • 图像分类之花卉图像分类(一)数据增强

    网上有很多图像分类的代码 有很多是必须要在GPU上面才能跑的 因为我想在自己的电脑跑 所以很多都是不能用的 而且说实话很多对我这个小白来说 都很难看懂 所以我找了一个就是之间用CNN写的神经卷积模型用来进行花卉识别 其中主要参考了以下的博主
  • Shuffle Net系列【V1—V2】

    1 ShuffleNet V1 1 1 Abstract 我们提出了一个极其效率的CNN架构 ShuffleNet 其专为计算能力非常有限的移动设备设计 这个新的架构利用了两个新的操作 pointwise group conv和channe
  • 图像分类_PyTorch图像数据分类

    图像分类数据集中最常用的是手写数字识别数据集MNIST 但大部分模型在MNIST上的分类精度都超过了95 为了更直观地观察算法之间的差异 我们将使用一个图像内容更加复杂的数据集Fashion MNIST 这个数据集也比较小 只有几十M 没有
  • 保姆级使用PyTorch训练与评估自己的ResNeXt网络教程

    文章目录 前言 0 环境搭建 快速开始 1 数据集制作 1 1 标签文件制作 1 2 数据集划分 1 3 数据集信息文件制作 2 修改参数文件 3 训练 4 评估 5 其他教程 前言 项目地址 https github com Fafa D
  • Yolov5-7.0图像分类算法修改Resnet18/50主干网络流程

    网上大多数都是基于yolov5算法的目标检测网络进行修改主干网络 我最近在尝试图像分类算法 流程如下 以resnet50为例 1 打开models下的common py文件 添加下面的代码 模型 resnet50 class resnet5
  • 综述---图像处理中的注意力机制

    重磅好文 微软亚研 对深度神经网络中空间注意力机制的经验性研究 论文 An Empirical Study of Spatial Attention Mechanisms in Deep Networks 高效Transformer层出不穷
  • 关于Descriptors cannot not be created directly报错

    报错信息为 TypeError Descriptors cannot not be created directly If this call came from a pb2 py file your generated code is o
  • 图像分类之花卉图像分类(二)数据预处理代码

    经过上一节数据增强 我们来说说数据预处理吧 首先我们要知道图片进入网络训练都是要统一大小格式的 所以我们需要对训练集和验证集的图片进行裁剪 让他们大小统一 注意测试集不用裁剪 我选择裁剪成了64 64的 没改源码的裁剪大小 其实图片大些识别
  • 人脸图像数据增强

    为什么要做数据增强 在计算机视觉相关任务中 数据增强 Data Augmentation 是一种常用的技术 用于扩展训练数据集的多样性 它包括对原始图像进行一系列随机或有规律的变换 以生成新的训练样本 数据增强的主要目的是增加模型的泛化能力
  • 模型实战(6)之Alex实现图像分类:模型原理+训练+预测(详细教程!)

    Alex实现图像分类 模型原理 训练 预测 图像分类或者检索任务在浏览器中的搜索操作 爬虫搜图中应用较广 本文主要通过Alex模型实现猫狗分类 并且将可以复用的开源模型在文章中给出 数据集可以由此下载 Data 本文将从以下内容做出讲述 1
  • 计算机视觉系列-2-图像分类

    给定一张输入图像 图像分类的任务是判断该图像属于哪类 如果是多任务分类 可以用于分类该图像包含哪个类别 深度学习作为机器学习中非常重要的分支 在图像领域中应用非常广泛 在图像分类任务中 通常采用卷积层 CNN 提取特征 加上全连接层进行分类
  • 通用图片分类项目

    generalImageClassification 文章目录 generalImageClassification 1 数据准备 1 1 开源数据集 1 2 利用特定网站爬数据 2 分类模型的选择 3 代码结构及使用方法 3 1 代码结构
  • ResNet详解:ResNet到底在解决什么问题?

    原作者开源代码 https github com KaimingHe deep residual networks 论文 https arxiv org pdf 1512 03385 pdf 1 网络退化问题 在ResNet诞生之前 Ale
  • mmclassification

    mmclassification 一 MMCLS项目 0 下载链接 Torch安装方法 CPU pip install torch i https download pytorch org whl torch stable html 指定清
  • 基于keras的图像分类CNN模型的搭建以及可视化(附详细代码)

    基于keras的图像分类CNN模型的搭建以及可视化 本文借助keras实现了热图像的分类模型的搭建 以及可视化的工作 本文主要由以下内容组成 Keras模型介绍 CNN模型搭建 模型可视化 Keras模型介绍 简介 Keras 是 Goog
  • CNN中特征融合的一些策略

    Introduction 特征融合的方法很多 如果数学化地表示 大体可以分为以下几种 X Y textbf X textbf Y X Y X
  • Linux下ImageNet2012数据集下载及其配置

    简明扼要 一 训练集下载 137G http www image net org challenges LSVRC 2012 nnoupb ILSVRC2012 img train tar 验证集下载 http www image net
  • 利用pytorch训练网络---垃圾分类,(resnet18)

    数据集包含6种垃圾 分别为cardboard 纸箱 glass 玻璃 metal 金属 paper 纸 plastic 塑料 其他废品 trash 数据数量较小 仅供学习 数据集标准备工作 包括将数据集分为训练集和测试集 制作标签文件 代码
  • 保姆级使用PyTorch训练与评估自己的ConvNeXt网络教程

    文章目录 前言 0 环境搭建 快速开始 1 数据集制作 1 1 标签文件制作 1 2 数据集划分 1 3 数据集信息文件制作 2 修改参数文件 3 训练 4 评估 5 其他教程 前言 项目地址 https github com Fafa D

随机推荐

  • Acwing 895. 最长上升子序列

    f i 表示所有以第i个数结尾的上升子序列中的最大个数 f i max f j 1 j 0 1 2 i 1 include
  • Openwrt开发笔记(1)—— 开发环境

    OpenWrt简介 OpenWrt 是一个嵌入式设备的 Linux 发行版 以 GPL 许可协议发行 其主要特点有如下几个 代码里不含第三方开源包 只包含开源包地址链接 在编译的时候下载 编译时自动下载源代码 打补丁来满足指定平台要求 并编
  • Oracle生成不重复字符串 sys_guid()

    在oracle8i以后提供了一个生成不重复的数据的一个函数sys guid 一共32位 生成的依据主要是时间和机器码 具有世界唯一性 类似于java中的UUID 都是世界唯一的 其优点就是生成的字符串是唯一的 但其和UUID有同样的弊端 生
  • 论文笔记:FILLING THE G AP S: MULTIVARIATE TIME SERIES IMPUTATION BY GRAPH NEURAL NETWORKS

    0 abstract introduction 之前的补全方法并不能很好地捕获 利用 不同sensor之间的非线性时间 空间依赖关系 高效的时间序列补全方法 不仅应该考虑过去 或者未来 的数值 还应该同时考虑空间上邻近的点的测量值 这里的空
  • 抖音超火的网页表白代码大全(浪漫的html表白源代码)

    精彩专栏推荐 作者主页 进入主页 获取更多源码 web前端期末大作业 HTML5网页期末作业 1000套 程序员有趣的告白方式 HTML七夕情人节表白网页制作 125套 七夕来袭 是时候展现专属于程序员的浪漫了 你打算怎么给心爱的人表达爱意
  • 姿态分析开源工具箱MMPose使用示例:人体姿势估计

    MMPose的介绍及安装参考 https blog csdn net fengbingchun article details 126676309 这里给出人体姿势估计的测试代码 论文 Deep high resolution repres
  • UI自动化测试-selenium元素定位

    在使用Selenium和WebDriver进行UI自动化测试时 我们首先需要对元素定位 那么如何来定位元素呢 HTML 在进行元素定位之前 我们要对html代码有所了解 div class s form div class s form w
  • 【CUDA学习】__syncthreads的理解

    syncthreads 是cuda的内建函数 用于块内线程通信 syncthreads is you garden variety thread barrier Any thread reaching the barrier waits u
  • BUUCTF学习笔记-Secret File

    BUUCTF学习笔记 Secret File 时间 2020 09 28 考点 文件包含 打开页面没发现特别的内容 右键查看源代码才发现下面隐藏了一个a标签只是字体颜色改成了和背景色一样 标签跳转到另外一个页面Archive room ph
  • C语言实现图的邻接矩阵存储结构及深度优先遍历和广度优先遍历

    DFS的核心思想在于对访问的邻接节点进行递归调用 BFS的核心思想在于建立了一个邻接节点的队列 在Dev C 中调试运行通过 用下图进行了测试 include
  • Python里面的[::-1]和[::2]是什么意思

    Python里面的 1 和 2 是什么意思 来源 https docs python org 2 3 whatsnew section slices html L range 10 L 2 0 2 4 6 8 L 1 9 8 7 6 5 4
  • 使用Vue+xlsx+xlsx-style实现导出自定义样式的Excel文件

    本文就是上一篇 使用Python openpyxl实现导出自定义样式的Excel文件 文章中提到的 之前项目的导出Excel文件操作都是在前端完成的 这段话中基于前端实现的导出Excel文件方法 文档地址 https docs sheetj
  • python中circle函数的用法,python画圆运用了什么函数

    python画圆运用了matplotlb库的figure 和Circle 函数 其中 figure 函数用于确定画布大小 而Circle 函数用于配置圆的相关信息 进而画圆 本教程操作环境 windows7系统 Python3版 Dell
  • MySQL多表操作练习题

    数据准备 CREATE table dept deptno INT PRIMARY KEY dname VARCHAR 14 loc VARCHAR 13 INSERT INTO dept VALUES 10 accounting new
  • QT开发(十九)——QT内存泄漏问题

    QT开发 十九 QT内存泄漏问题 一 QT对象间的父子关系 QT最基础和核心的类是 QObject QObject内部有一个list 会保存children 还有一个指针保存parent 当自己析构时 会自己从parent列表中删除并且析构
  • Java 反射 与 主要API

    控制你的大脑 控制你的行为 你会得到更多 收获很多 文章目录 一 反射相关的主要API 二 代码例子演示 三 反射测试类 一 反射相关的主要API API 名称 代表含义 Java lang class 代表一个类 java lang re
  • J-Tech Talk | 6.29首播 Python文档漫谈

    J Tech Talk 由 Jina AI 社区为大家带来的技术分享 围绕 Python 人工智能 深度学习等 给大家带来针对具体实战型问题的讲解 分享 Jina AI 在开发过程中所积累的经验 Github 的开源项目数不胜数 想要扩大项
  • 使用OPENLDAP C API修改 win2003 AD域(Active Directory)用户密码

    参考 http blog csdn net wzhwho article details 6209693 参考 http www 121 name LDAP html 参考 http pig made it com pig adusers
  • 单片机学习笔记1:单片机简介

    单片机 1 什么是单片机 单片机 Single Chip Microcomputer 是一种集成电路芯片 是采用超大规模集成电路技术把具有数据处理能力的中央处理器 CPU 存储器 RAM ROM 中断系统 I O接口电路 定时器 计数器等功
  • 图像分类(1),数据预处理

    本文介绍如何使用pytorh利用预训练模型进行图像分类 主要参考Transfer Learning Tutorial和 具体代码可以参考Image classification 下载代码文件 git clone https github c