介绍
猫狗分类来源于Kaggle上的一个入门竞赛。
https://www.kaggle.com/competitions/dogs-vs-cats-redux-kernels-edition/overview
代码及解释
首先,导入一系列的库。
import numpy as np
from PIL import Image
from pathlib import Path
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
这段代码主要是导入了一些Python库,包括:
- numpy:Python中常用的科学计算库,用于处理数组、矩阵等数值数据。
- PIL(Python Imaging Library):Python图像处理库,用于处理各种图像格式。
- pathlib:Python处理文件和目录路径的标准库,支持多平台。
- torch:PyTorch深度学习框架的核心库。
- nn:PyTorch中用于构建神经网络的模块。
- F(functional):PyTorch中用于创建自定义卷积层、激活函数等的函数。
- DataLoader:PyTorch中用于加载和批量处理数据的工具。
- transforms:PyTorch中对图像和数据进行预处理的工具。
- matplotlib:Python绘图库,用于绘制数据和图像。
这些库的导入是PyTorch实践项目中经常用到的基础操作,其中PIL、numpy和matplotlib主要用于读取和展示图像、transforms用于对图像进行数据增强,torch和nn则是构建和训练深度神经网络的核心。
而后,启用GPU加速。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device: ", device)
get_label = lambda x: x.name.split('.')[0]
class get_dataset(Dataset):
def __init__(self, root, transform=None):
self.images = list(Path(root).glob('*.jpg'))
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img = self.images[idx]
label = get_label(img)
label = 1 if label == 'dog' else 0
if self.transform:
img = self.transform(Image.open(img))
return img, torch.tensor(label, dtype=torch.int64)
这段代码定义了一个类get_dataset,用于加载和预处理数据集。
在类的初始化函数中,root为数据集路径,transform为数据预处理函数。通过list和glob函数获取符合条件的文件名,即所有后缀为jpg的图片文件名,并将其转为列表self.images。同时记录transform函数,即数据预处理函数。
__len__函数返回数据集中的图片数量,__getitem__函数根据索引idx获取对应图片和标签。首先获取索引对应的图片img,并通过get_label函数获取该图片对应的标签。该函数将图片文件名以’.‘分割,并将第一个分割出来的字符串作为标签。如果标签等于’dog’,则将其转为数字1,否则转为数字0。
接着如果有定义transform函数,就将img通过transform函数进行数据预处理。最后返回处理后的图片和标签,其中标签用torch.tensor转为整型。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)