pytorch分割网络数据输入接口

2023-11-20

PyTorch 的自定义数据集接口用起来非常方便，这里记录一下我自己用于分割任务的数据输入脚本：

# -*- coding: utf-8 -*-
# @Time    : 2019/10/31 21:36
# @Author  : Yunyun Xu
# @Contact : 1443563995@qq.com
# @File    : MyDatasetReader.py
# @Software: Pycharm
# @Blog    : https://me.csdn.net/xuyunyunaixuexi

import os
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from mypath import Path
from torchvision import transforms
import custom_transforms as tr

class MyEggSegmentation(data.Dataset):
    """Paired image / label segmentation dataset laid out Cityscapes-style.

    Images are collected from ``<root>/leftImg8bit/<split>`` and labels from
    ``<root>/gtFine_trainvaltest/gtFine/<split>``. Every ``.png`` found
    recursively under each directory is used. Both lists are sorted by path,
    and the i-th image is paired with the i-th label. This assumes image and
    label files sort into the same order (e.g. matching file names).

    Each item is the dict ``{"images": <RGB PIL image>, "label": <PIL image>}``
    passed through a split-specific ``transforms.Compose`` pipeline built from
    the project's ``custom_transforms`` module (``tr``).
    """

    # Channel mean/std used by every pipeline (the common ImageNet statistics).
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, args, root=None, split="train"):
        """Collect the image and label files for ``split``.

        Args:
            args: Options object. The transform pipelines read
                ``args.base_size`` and ``args.crop_size`` from it.
            root: Dataset root directory. ``None`` (the default) means
                ``Path.db_root_dir("MyEggs")``.
            split: Sub-directory to load. ``__getitem__`` supports
                ``"train"``, ``"val"`` and ``"test"``.

        Raises:
            Exception: if no image files are found for ``split``.
            ValueError: if the label count differs from the image count.
                Index-based pairing would then be wrong.
        """
        if root is None:
            # Resolved here rather than as a default argument, so importing
            # this module does not evaluate the project path configuration.
            root = Path.db_root_dir("MyEggs")

        self.root = root
        self.split = split
        self.args = args
        # Both dicts are keyed by split name and map to sorted path lists.
        self.image_files = {}
        self.label_files = {}

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split)
        self.image_files[split] = self.recursive_glob(rootdir=self.images_base, suffix=".png")
        self.label_files[split] = self.recursive_glob(rootdir=self.annotations_base, suffix=".png")

        if not self.image_files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        if len(self.label_files[split]) != len(self.image_files[split]):
            raise ValueError(
                "Found %d images in %s but %d labels in %s for split=[%s]"
                % (len(self.image_files[split]), self.images_base,
                   len(self.label_files[split]), self.annotations_base, split))

        print("Found %d %s images" % (len(self.image_files[split]), split))

    def __len__(self):
        """Return the number of image/label pairs in the current split."""
        return len(self.image_files[self.split])

    def __getitem__(self, index):
        """Load the ``index``-th image/label pair and apply the split's pipeline.

        Raises:
            ValueError: if ``self.split`` has no associated transform pipeline.
        """
        # Each split gets its own pipeline. Only training uses random augmentation.
        pipelines = {
            "train": self.transform_tr,
            "val": self.transform_val,
            "test": self.transform_ts,
        }
        if self.split not in pipelines:
            raise ValueError("Unknown split %r; expected one of %s"
                             % (self.split, sorted(pipelines)))

        img_path = self.image_files[self.split][index].rstrip()
        lbl_path = self.label_files[self.split][index].rstrip()
        # Force a 3-channel RGB image (drops an alpha channel if present).
        _img = Image.open(img_path).convert("RGB")
        # The label is opened without mode conversion. Presumably it is an
        # index/palette map whose pixel values are class ids -- TODO confirm.
        _target = Image.open(lbl_path)

        # NOTE(review): the image key is "images" (plural). Confirm that the
        # custom_transforms pipeline reads this key rather than "image".
        sample = {"images": _img, "label": _target}

        return pipelines[self.split](sample)

    def recursive_glob(self, rootdir='.', suffix=" "):
        """Return the sorted paths of all files under ``rootdir`` ending in ``suffix``.

        Sorting makes the result deterministic (``os.walk`` order is not
        guaranteed), which the index-based image/label pairing relies on.
        NOTE: the default ``suffix`` is a single space, which matches almost no
        real files, so callers should always pass one explicitly.
        """
        return sorted(
            os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames
            if filename.endswith(suffix)
        )

    def transform_tr(self, sample):
        """Training pipeline: random flip, random scale + crop, random blur,
        normalisation, tensor conversion.
        """
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            # fill=255 pads label pixels; presumably the ignore index -- confirm.
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):
        """Validation pipeline: deterministic scale + crop, normalisation, tensor conversion."""
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_ts(self, sample):
        """Test pipeline: fixed resize to ``args.crop_size``, normalisation, tensor conversion."""
        composed_transforms = transforms.Compose([
            tr.FixedResize(size=self.args.crop_size),
            tr.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            tr.ToTensor()])

        return composed_transforms(sample)

经测试，该数据集可以被正常遍历，说明自定义数据集接口（继承 data.Dataset）的写法是正确的。

不过我也有一个问题：分割网络的 crop_size 应该如何根据自己数据集的图像尺寸来确定？还是只能不断尝试、对比效果？希望大家能够解答。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch分割网络数据输入接口 的相关文章

随机推荐