PyTorch自制数据集

2023-10-31

PyTorch加载数据主要分为两类:只有图片的数据集以及含有csv保存标签的数据集。只有图片的数据集又分为两类:标签在文件夹上和标签在图片名上。

学习地址

1.标签在文件夹上

在这里插入图片描述
此情况下导入数据集,只需要调用PyTorch中的ImageFolder进行载入。(可以直接采用split_data.py划分训练集、测试集、验证集)

导入所需的库

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
#以上语句是由于python与torch版本不匹配才加的与加载数据无关

from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torch
from torchvision import transforms, utils,datasets
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

数据增强函数

data_transform = transforms.Compose([
 transforms.Resize(32), # 缩放图片(Image),保持长宽比不变,最短边为32像素
 transforms.CenterCrop(32), # 从图片中间切出32*32的图片
 transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 transforms.Normalize(mean=[0.492, 0.461, 0.417], std=[0.256, 0.248, 0.251]) # 标准化至[-1, 1],规定均值和标准差
])
  • data_transform作用是对图片进行标准化和归一化 Resize(32)缩放图片(Image),保持长宽比不变,最短边为32像素 CenterCrop(32)从图片中间切出32*32的图片 RandomSizedCrop(32)这一句的作用是对原图进行随机大小和高宽比的裁剪,最后的尺寸为32x32
  • RandomHorizontalFlip()这个则是对原图像根据概率进行随机水平翻转
  • transforms.ToTensor()将图片转化为张量,并使图片的形式表现为通道x高x宽的形式
  • transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])这个则是对数据 进行正则化操作,第一个参数为均值,第二个参数为标准差。

计算RGB图片的均值标准差代码如下:

import numpy as np
import cv2
import os

#以下代码用于求RGB图像的均值和标准差
# img_h, img_w = 32, 32
img_h, img_w = 32, 32  # 经过处理后你的图片的尺寸大小
means, stdevs = [], []
img_list = []

imgs_path = "/home/xyjin/PycharmProjects/data_mining/data_test/study/dogs_cats/dogs_cats/data/train/cats"  # 数据集的路径采用绝对引用

imgs_path_list = os.listdir(imgs_path)

len_ = len(imgs_path_list)
i = 0
for item in imgs_path_list:
    img = cv2.imread(os.path.join(imgs_path, item))
    img = cv2.resize(img, (img_w, img_h))
    img = img[:, :, :, np.newaxis]
    img_list.append(img)
    i += 1
    print(i, '/', len_)

imgs = np.concatenate(img_list, axis=3)
imgs = imgs.astype(np.float32) / 255.

for i in range(3):
    pixels = imgs[:, :, i, :].ravel()  # 拉成一行
    means.append(np.mean(pixels))
    stdevs.append(np.std(pixels))

# BGR --> RGB , CV读取的需要转换,PIL读取的不用转换
means.reverse()
stdevs.reverse()

print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))

计算结果如下:
在这里插入图片描述

调用ImageFolder进行数据集的加载

hymenoptera_dataset = datasets.ImageFolder(root="/home/xyjin/PycharmProjects/data_mining/data_test/study/dogs_cats/dogs_cats/data/train",
           transform=data_transform) #导入数据集
  • 第一个参数为数据集路径的参数。
  • 第二个参数为数据增强函数的调用,对加载的数据集进行相关数据操作。

导入数据集后,查看导入情况。

  • 查看图像相关信息
img, label = hymenoptera_dataset[15000] #将启动魔法方法__getitem__(0)
"""这个15000,表示所有文件夹排序后的第15001张图片,0是第一张图片"""
print(label)   #查看标签
"""这里的0表示cat,1表示dog;因为是按文件夹排列的顺序,如果有第三个文件夹pig则2表示pig"""
print(img.size())
print(img)

#处理后的图片信息
for img, label in hymenoptera_dataset:
 print("图像img的形状{},标签label的值{}".format(img.shape, label))
 print("图像数据预处理后:\n",img)
 break

结果:

/home/xyjin/anaconda3/envs/pytorch/bin/python /home/xyjin/PycharmProjects/data_mining/data_test/temp.py
1
torch.Size([3, 32, 32])
tensor([[[-0.1909, -0.2675,  0.0695,  ...,  0.4985,  0.4678,  0.2687],
         [-0.1296, -0.2521,  0.0542,  ...,  0.4525,  0.4066,  0.2381],
         [-0.0530, -0.2368,  0.0389,  ...,  0.4525,  0.4219,  0.2381],
         ...,
         [ 0.0542,  0.1155,  0.2074,  ...,  0.3300,  0.3606,  0.3453],
         [ 0.2534,  0.2840,  0.2840,  ...,  0.2074,  0.3146,  0.3759],
         [ 0.3146,  0.3606,  0.3300,  ...,  0.2534,  0.1615,  0.3453]],

        [[ 0.0070, -0.1353,  0.1810,  ...,  0.8135,  0.7819,  0.3233],
         [ 0.0545, -0.1195,  0.1652,  ...,  0.7661,  0.7502,  0.3075],
         [ 0.1652, -0.1195,  0.1652,  ...,  0.7502,  0.6237,  0.2759],
         ...,
         [ 0.2600,  0.3075,  0.3707,  ...,  0.4498,  0.4972,  0.4972],
         [ 0.4182,  0.4498,  0.4498,  ...,  0.3391,  0.4656,  0.5447],
         [ 0.4972,  0.5289,  0.4972,  ...,  0.3865,  0.2917,  0.4972]],

        [[ 0.0260, -0.0990,  0.1823,  ...,  1.0259,  0.9322,  0.1823],
         [ 0.0885, -0.0834,  0.1510,  ...,  0.9791,  0.8697,  0.1666],
         [ 0.2291, -0.0677,  0.1979,  ...,  0.9009,  0.7291,  0.1198],
         ...,
         [ 0.3541,  0.4010,  0.4635,  ...,  0.4479,  0.4791,  0.4947],
         [ 0.5260,  0.5728,  0.5416,  ...,  0.3229,  0.4635,  0.5260],
         [ 0.5728,  0.6197,  0.6041,  ...,  0.3854,  0.2760,  0.5260]]])
图像img的形状torch.Size([3, 32, 32]),标签label的值0
图像数据预处理后:
 tensor([[[ 1.8159,  1.8618,  1.8925,  ...,  1.9384,  1.9231,  1.9078],
         [ 1.8006,  1.8465,  1.8771,  ...,  1.9384,  1.9231,  1.9231],
         [ 1.7546,  1.8006,  1.8618,  ...,  1.9384,  1.9384,  1.9384],
         ...,
         [-0.8036, -0.7270, -0.7270,  ..., -1.4930, -1.5389, -1.6002],
         [-0.8496, -0.7883, -0.7883,  ..., -0.6351, -1.2172, -1.6155],
         [-0.8343, -0.8496, -0.8343,  ..., -1.1253, -1.3857, -1.6768]],

        [[ 1.3195,  1.3669,  1.4618,  ...,  1.7148,  1.7464,  1.7781],
         [ 1.3037,  1.3511,  1.4144,  ...,  1.7148,  1.7306,  1.7781],
         [ 1.2562,  1.3037,  1.3669,  ...,  1.7148,  1.7464,  1.7939],
         ...,
         [-1.0366, -1.0208, -1.0050,  ..., -1.4952, -1.5268, -1.5901],
         [-1.0524, -1.0524, -1.0524,  ..., -0.8469, -1.3054, -1.5901],
         [-1.0208, -1.0840, -1.0999,  ..., -1.2264, -1.4161, -1.6375]],

        [[ 0.2447,  0.2916,  0.3697,  ...,  0.7135,  0.7291,  0.7760],
         [ 0.2291,  0.2760,  0.3229,  ...,  0.7135,  0.7291,  0.8072],
         [ 0.1823,  0.2291,  0.2760,  ...,  0.6978,  0.7447,  0.8541],
         ...,
         [-1.5520, -1.5051, -1.5051,  ..., -1.5989, -1.5989, -1.6145],
         [-1.5676, -1.5364, -1.5207,  ..., -1.4114, -1.5520, -1.6145],
         [-1.5676, -1.5676, -1.5520,  ..., -1.5051, -1.5832, -1.6145]]])

Process finished with exit code 0

对加载的数据进行batch size处理:表示加载的数据切分为4个为一组的 shuffle=True,送入训练的的数据是打乱后的,而不是顺序输入。

dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4,shuffle=True)
  • 使用显示图片与对应标签的方式进行查验
import torchvision
import matplotlib.pyplot as plt
import numpy as np
# %matplotlib inline
# 显示图像
def imshow(img):
 img = img / 2 + 0.5  # unnormalize
 npimg = img.numpy()
 plt.imshow(np.transpose(npimg, (1, 2, 0)))
 plt.show()
# 随机获取部分训练数据
dataiter = iter(dataset_loader)#此处填写加载的数据集
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

图片结果:
在这里插入图片描述
[ ’ 小 狗 ’ , ’ 小 狗 ’ , ’ 小 猫 ’ , ’ 小 狗 ’ ]

总结:标签在文件夹上的数据载入,代码如下:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torch
from torchvision import transforms, utils,datasets
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

data_transform = transforms.Compose([
 transforms.Resize(32), # 缩放图片(Image),保持长宽比不变,最短边为32像素
 transforms.CenterCrop(32), # 从图片中间切出32*32的图片
 transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 transforms.Normalize(mean=[0.492, 0.461, 0.417], std=[0.256, 0.248, 0.251]) # 标准化至[-1, 1],规定均值和标准差
])

hymenoptera_dataset = hymenoptera_dataset = datasets.ImageFolder(root="/home/xyjin/PycharmProjects/data_mining/data_test/study/dogs_cats/dogs_cats/data/train",
           transform=data_transform) #导入数据集
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4,shuffle=True) #

2. 标签在图片名上

3. 将数据集分为训练集,验证集和测试集

4.标签存储在csv文件中

数据集介绍

  • 本数据集加载使用的数据集是给狗进行种类的识别的数据集
    在这里插入图片描述
    该文件夹中有:测试集,训练集,还有存有训练集标签和对应标签的文件名的csv文件
    在这里插入图片描述

数据集加载的方式和第一种相同,需要将每一个种类的数据照片放到对应种类命名的文件夹中

首先进行数据集的拆分,导入对应的库;并设置用到的一些变量

import math
import os
import shutil
from collections import Counter

data_dir = "/home/xyjin/PycharmProjects/data_mining/data_test/study/dog-breed-identification" #数据集的根目录
label_file = 'labels.csv'#根目录中csv的文件名加后缀
train_dir = 'train'#根目录中的训练集文件夹的名字
test_dir = 'test'#根目录中的测试集文件夹的名字
input_dir = 'train_valid_test'#用于存放拆分数据集的文件夹的名字,可以不用先创建,会自动创建
batch_size = 4#送往训练的一批次中的数据集的个数
valid_ratio = 0.1#将训练集拆分为90%为训练集10%为验证集
  • 训练集拆分为训练集(test)和验证集(valid)然后分别放到对应的文件夹中
    在这里插入图片描述

  • train_valid文件夹中包含了所有训练集、验证集的数据
    在这里插入图片描述

  • 训练集、验证集和train_valid的文件夹中都是一个个种类的小文件夹,其中存放着对应的数据集图像

  • test中没有标签所以所有的数据照片都存放在unknown中的

程序如下:

def reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,
                   valid_ratio):
    # 读取训练数据标签,label.csv文件读取标签以及对应的文件名
    with open(os.path.join(data_dir, label_file), 'r') as f:
        # 跳过文件头行(栏名称)
        lines = f.readlines()[1:]
        tokens = [l.rstrip().split(',') for l in lines]
        idx_label = dict(((idx, label) for idx, label in tokens))
    labels = set(idx_label.values())

    num_train = len(os.listdir(os.path.join(data_dir, train_dir)))#获取训练集的数量便于数据集的分割
    # 训练集中数量最少一类的狗的数量
    min_num_train_per_label = (
        Counter(idx_label.values()).most_common()[:-2:-1][0][1])
    # 验证集中每类狗的数量
    num_valid_per_label = math.floor(min_num_train_per_label * valid_ratio)
    label_count = dict()

    def mkdir_if_not_exist(path):#判断是否有存放拆分后数据集的文件夹,没有就创建一个
        if not os.path.exists(os.path.join(*path)):
            os.makedirs(os.path.join(*path))

    # 整理训练和验证集,将数据集进行拆分复制到预先设置好的存放文件夹中。
    for train_file in os.listdir(os.path.join(data_dir, train_dir)):
        idx = train_file.split('.')[0]
        label = idx_label[idx]
        mkdir_if_not_exist([data_dir, input_dir, 'train_valid', label])
        shutil.copy(os.path.join(data_dir, train_dir, train_file),
                    os.path.join(data_dir, input_dir, 'train_valid', label))
        if label not in label_count or label_count[label] < num_valid_per_label:
            mkdir_if_not_exist([data_dir, input_dir, 'valid', label])
            shutil.copy(os.path.join(data_dir, train_dir, train_file),
                        os.path.join(data_dir, input_dir, 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            mkdir_if_not_exist([data_dir, input_dir, 'train', label])
            shutil.copy(os.path.join(data_dir, train_dir, train_file),
                        os.path.join(data_dir, input_dir, 'train', label))

    # 整理测试集,将测试集复制存放在新建路径下的unknown文件夹中
    mkdir_if_not_exist([data_dir, input_dir, 'test', 'unknown'])
    for test_file in os.listdir(os.path.join(data_dir, test_dir)):
        shutil.copy(os.path.join(data_dir, test_dir, test_file),
                    os.path.join(data_dir, input_dir, 'test', 'unknown'))

reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,valid_ratio)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch自制数据集 的相关文章

随机推荐

  • Redis实现点赞功能模块

    功能点设计 统计文章点赞的总数 用户所有文章的点赞数 因此设计的点赞功能模块具有以下功能点 某篇文章的点赞数 用户所有文章的点赞数 用户点赞的文章 持久化到MySQL数据库 数据库设计 1 Redis数据库设计Redis是K V数据库 没有
  • Java if判断语句的用法(一)

    If语句概述和使用格式 1 if语句用于判断不同的条件 根据判断的结果执行不同的代码 2 if语句判断的条件可以是关系运算 逻辑运算 if语句根据逻辑值true false来决定执行不同的代码 3 if语句在开发中使用极为广泛 if语句格式
  • 【华为OD机试真题 JAVA】火锅

    JS版 华为OD机试真题 JS 火锅 标题 火锅 时间限制 1秒 内存限制 262144K 语言限制 不限 入职后 导师会请你吃饭 你选择了火锅 火锅里会在不同时间下很多菜 不同食材要煮不同的时间 才能变得刚好合适 你希望吃到最多的刚好合适
  • 华为OD机试-求满足要求的最长子串

    题目描述 给定一个字符串 只包含字母和数字 按要求找出字符串中的最长 连续 子串的长度 字符串本身是其最长的子串 子串要求 1 只包含1个字母 a z A Z 其余必须是数字 2 字母可以在子串中的任意位置 如果找不到满足要求的子串 如全是
  • @Bean 的用法

    Bean是一个方法级别上的注解 主要用在 Configuration注解的类里 也可以用在 Component注解的类里 添加的bean的id为方法名 定义bean 下面是 Configuration里的一个例子 Configuration
  • 【Vue学习笔记7】Vue3中如何开发组件

    重点学习 vue3 0之组件通信机制defineProps 组件接收外部传来的参数 defineEmits 向组件外部传递参数 1 评级组件第一版 简单的评级需求 只需要一行代码就可以实现 slice 5 rate 10 rate 只需要传
  • 动态规划(最大子序和 && 乘积最大子序列)

    一 最大子序列和 给定一个整数数组 nums 找到一个具有最大和的连续子数组 子数组最少包含一个元素 返回其最大和 示例 输入 2 1 3 4 1 2 1 5 4 输出 6 解释 连续子数组 4 1 2 1 的和最大 为 6 https l
  • 使用三种方法获取远程连接服务器上的文件

    文章目录 概要 alt p sz 使用kettle软件 概要 第一种方法 alt p 第二种方法 sz 第三种方法 使用kettle软件 alt p 在crt连接页面使用快捷键 alt p 打开sftp页面 使用例如 get a txt 获
  • LeetCode(力扣)62. 不同路径Python

    LeetCode62 不同路径 题目链接 代码 题目链接 https leetcode cn problems unique paths 代码 递归 class Solution def uniquePaths self m int n i
  • C++忘记返回值导致异常bug

    问题 在C 函数实现时 定义一个函数如下 bool MCUSerialImpl InitDevInfo devInfo std make shared
  • 【MySQL】不就是多表查询

    前言 嗨 小伙伴们大家好呀 忙碌的一周就要开始 在此之前我们学习的MySQL数据库的各种操作都是在一张表之中 今天我们学习要对多张表进行相关操作 相比较于单一的表来说 多张表操作相对复杂一些 我相信只要认真学习多表查询也不再话下 目录 目录
  • 2023年Python面试题(爬虫)

    爬取数据后使用哪个数据库存储数据的 为什么 MongoDB 是使用比较多的数据库 这里以 MongoDB 为例 大家需要结合自己真实开发环境回答 原因 1 与关系型数据库相比 MongoDB 的优点如下 1 弱一致性 最终一致 更能保证用户
  • Linux DRM框架详解

    Linux DRM框架详解
  • c++ 给定n个十六进制正整数,输出它们对应的八进制数。

    问题描述 给定n个十六进制正整数 输出它们对应的八进制数 输入格式 输入的第一行为一个正整数n 1 lt n lt 10 接下来n行 每行一个由09 大写字母AF组成的字符串 表示要转换的十六进制正整数 每个十六进制数长度不超过100000
  • 【Node.js】模块化:

    文章目录 1 模块化的基本概念 2 Node js 中模块化 1 Node js 中模块的分类 2 加载模块 3 模块作用域 4 向外共享模块作用域中的成员 5 模块化规范 3 npm与包 包 依赖 插件 1 包的基本知识 2 开发属于自己
  • 进程池、线程池、协程

    什么是池 保证计算机硬件安全的情况下最大限度利用计算机 降低了程序的运行效率 但保证了硬件的安全 受限于硬件的物理极限 硬件的发展跟不上软件的速度 迫不得已提出了池的概念 进程池 线程池 提交任务的方式 同步 提交任务之后 原地等待任务的返
  • 多线程 UDP传输速率 实验

    现阶段问题 丢包问题 丢包率达到50 但是ping的时候反应良好 1 分析UDP丢包的原因 1 现象是每隔一个包丢失一个 所以考虑是否是缓冲区的问题 答 用不同数据包大小10 100 500 1500发现都是收一个丢一个 说明缓冲区大小并不
  • android 手机内存64实际不到,为什么你的手机内存总是达不到64G?丢失的内存去哪了?详细解读...

    近些年手机各项参数快速发展 除了屏幕 处理器 相机等主要零部件性能提升的同时 我们手机的内存也是越来越大 从最刚开始的2GB 到4G 8G 16与32G 再到现在标配64G起步 手机软件生态越来越完善 现在64G的手机也变得捉襟见肘 但是当
  • Linux man 命令详解

    man 命令 Linux man 命令用于显示 Linux 操作系统中的手册页 manual page 它提供了对 Linux 操作系统中各种命令 函数 库等的详细说明 man 命令有许多参数 参数介绍 下面简要介绍一下主要参数的功能 f
  • PyTorch自制数据集

    PyTorch加载数据主要分为两类 只有图片的数据集以及含有csv保存标签的数据集 只有图片的数据集又分为两类 标签在文件夹上和标签在图片名上 学习地址 1 标签在文件夹上 此情况下导入数据集 只需要调用PyTorch中的ImageFold