Walkthrough of the code in example1.py:
Data import uses the off-the-shelf torchvision.datasets.ImageFolder interface. You only need to supply the image directories data_dir/train and data_dir/val. Each of these directories contains N sub-folders, where N is the number of classes, and each sub-folder holds the images of that class. torchvision.datasets.ImageFolder then returns a dataset object that can be indexed like a list: each element is a tuple of (image, label).

def Data_loader(Data_Path):
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            #transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    data_dir = Data_Path
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                                  shuffle=True, num_workers=4)
                   for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
    return dataloaders, image_datasets, class_names
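ImageFolder derives class_names from the sub-folder names, sorted alphabetically, and assigns the integer labels in that sorted order. A minimal pure-Python sketch of that convention (the folder names below are from the hymenoptera example; no torchvision needed):

```python
# ImageFolder assigns labels by sorting the sub-folder names alphabetically:
# class_names[label] == folder name.
subfolders = ["bees", "ants"]          # e.g. contents of hymenoptera_data/train
class_names = sorted(subfolders)
class_to_idx = {name: i for i, name in enumerate(class_names)}
print(class_names)    # ['ants', 'bees']
print(class_to_idx)   # {'ants': 0, 'bees': 1}
```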
dataloaders, image_datasets, class_names = Data_loader('hymenoptera_data')
print(image_datasets)
for e in image_datasets:
    print(e)
    print(image_datasets[e])
    for index, k in enumerate(image_datasets[e]):
        print(type(k), len(k))
        print(index, k[0].size(), k[1])
transforms preprocesses the images. torchvision.transforms.Compose manages the whole chain of transform operations. RandomResizedCrop and RandomHorizontalFlip take a PIL Image as input, i.e. the image content read with Python's PIL library. Normalize, on the other hand, operates on a Tensor, so a ToTensor() step is needed first to convert the image into a Tensor. Note also that transforms.Scale(256), an older name for the resize operation, has since been replaced by transforms.Resize.
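The Normalize step above computes (x - mean) / std per channel, on values that ToTensor has already scaled to [0, 1]. A minimal pure-Python sketch of that arithmetic (no torchvision needed; the pixel value is illustrative):

```python
# Per-channel normalization as done by transforms.Normalize:
# output[c] = (input[c] - mean[c]) / std[c]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean, std):
    """Normalize one RGB pixel whose values are already in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

# A mid-gray pixel after ToTensor:
print(normalize_pixel([0.5, 0.5, 0.5], mean, std))
```

A pixel exactly equal to the channel means normalizes to zero, which is the point of the operation: the training data ends up roughly zero-centered.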
ImageFolder by itself only yields individual (image, label) samples, which cannot be fed to a model directly. PyTorch therefore wraps the dataset in another class, torch.utils.data.DataLoader, which collates the images and the labels of each batch into Tensors that the model can consume.
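Conceptually, the DataLoader walks the dataset in chunks of batch_size and collates each chunk, turning a list of (image, label) tuples into one (images, labels) pair. A minimal pure-Python sketch of that grouping, using strings as stand-in images (the real DataLoader additionally stacks each group into a single Tensor and can shuffle and use worker processes):

```python
def collate(samples):
    """Turn a list of (image, label) tuples into (images, labels).
    torch.utils.data.DataLoader does the same, but stacks each list into a Tensor."""
    images, labels = zip(*samples)
    return list(images), list(labels)

def batches(dataset, batch_size):
    """Yield collated batches of batch_size samples (the last batch may be smaller)."""
    for i in range(0, len(dataset), batch_size):
        yield collate(dataset[i:i + batch_size])

# Fake dataset: strings standing in for images, alternating labels.
fake = [("img%d" % i, i % 2) for i in range(10)]
for imgs, labels in batches(fake, 4):
    print(imgs, labels)
```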
Another very important class is torch.utils.data.Dataset. It is an abstract class, and every dataset class in PyTorch inherits from it, torchvision.datasets.ImageFolder included (torch.utils.data.DataLoader, by contrast, wraps a Dataset rather than inheriting from it). So if your data is not stored in the folder layout described above, you need to define your own class to read it, and that class must inherit from the torch.utils.data.Dataset base class. The code is as follows:
def default_loader(path):
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except IOError:
        # falls through and returns None if the image cannot be read
        print("Cannot read image: {}".format(path))
class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset='', data_transforms=None, loader=default_loader):
        # each line of txt_path holds "relative/image/path label"
        with open(txt_path) as input_file:
            lines = input_file.readlines()
        #self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
        #self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.img_name = [os.path.join(img_path, line.strip()[:-2]) for line in lines]
        self.img_label = [int(line.strip()[-1:]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)
        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except Exception:
                print("Cannot transform image: {}".format(img_name))
        return img, label
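The annotation file is assumed to hold one "relative/path label" pair per line, with a single space and a single-digit label, which is exactly what the slicing in __init__ relies on. It can be checked in isolation (the filename below is just an example):

```python
import os

line = "ants/0013035.jpg 0\n"  # hypothetical annotation line: "path label"
img_path = "hymenoptera_data_cp/"

name = os.path.join(img_path, line.strip()[:-2])  # [:-2] drops the trailing " 0"
label = int(line.strip()[-1:])                    # the last character is the label

print(name)
print(label)
```

Note this slicing breaks for labels with two or more digits; the commented-out split('\t') variant in __init__ is the more general parse when the file is tab-separated.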
def Data_loader():
    batch_size = 4
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    image_datasets = {x: customData(img_path='hymenoptera_data_cp/',
                                    txt_path=(x + '.txt'),
                                    data_transforms=data_transforms,
                                    dataset=x) for x in ['train', 'val']}
    # wrap your data and label into Tensor
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=batch_size,
                                                  shuffle=True) for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    return image_datasets, dataloaders