使用可变批量大小加载数据?

2023-12-28

我目前正在研究基于补丁的超分辨率。大多数论文将图像分割成更小的补丁,然后使用这些补丁作为模型的输入。我能够使用自定义数据加载器创建补丁。代码如下:

import torch.utils.data as data
from torchvision.transforms import CenterCrop, ToTensor, Compose, ToPILImage, Resize, RandomHorizontalFlip, RandomVerticalFlip
from os import listdir
from os.path import join
from PIL import Image
import random
import os
import numpy as np
import torch

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".bmp"])

class TrainDatasetFromFolder(data.Dataset):
    def __init__(self, dataset_dir, patch_size, is_gray, stride):
        super(TrainDatasetFromFolder, self).__init__()
        self.imageHrfilenames = []
        self.imageHrfilenames.extend(join(dataset_dir, x)
                                     for x in sorted(listdir(dataset_dir)) if is_image_file(x))
        self.is_gray = is_gray
        self.patchSize = patch_size
        self.stride = stride

    def _load_file(self, index):
        filename = self.imageHrfilenames[index]
        hr = Image.open(self.imageHrfilenames[index])
        downsizes = (1, 0.7, 0.45)
        downsize = 2
        w_ = int(hr.width * downsizes[downsize])
        h_ = int(hr.height * downsizes[downsize])
        aug = Compose([Resize([h_, w_], interpolation=Image.BICUBIC),
                       RandomHorizontalFlip(),
                       RandomVerticalFlip()])

        hr = aug(hr)
        rv = random.randint(0, 4)
        hr = hr.rotate(90*rv, expand=1)
        filename = os.path.splitext(os.path.split(filename)[-1])[0]
        return hr, filename

    def _patching(self, img):

        img = ToTensor()(img)
        LR_ = Compose([ToPILImage(), Resize(self.patchSize//2, interpolation=Image.BICUBIC), ToTensor()])

        HR_p, LR_p = [], []
        for i in range(0, img.shape[1] - self.patchSize, self.stride):
            for j in range(0, img.shape[2] - self.patchSize, self.stride):
                temp = img[:, i:i + self.patchSize, j:j + self.patchSize]
                HR_p += [temp]
                LR_p += [LR_(temp)]

        return torch.stack(LR_p),torch.stack(HR_p)

    def __getitem__(self, index):
        HR_, filename = self._load_file(index)
        LR_p, HR_p = self._patching(HR_)
        return LR_p, HR_p

    def __len__(self):
        return len(self.imageHrfilenames)

假设批量大小为 1,它获取图像并给出 size 的输出[x,3,patchsize,patchsize]。当批量大小为 2 时,我将有两个不同大小的输出[x,3,patchsize,patchsize](例如图像 1 可能给出[50,3,patchsize,patchsize],图像2可能给出[75,3,patchsize,patchsize])。为了处理这个问题,需要一个自定义的整理函数来沿着维度 0 堆叠这两个输出。整理函数如下:

def my_collate(batch):
    data = torch.cat([item[0] for item in batch],dim = 0)
    target = torch.cat([item[1] for item in batch],dim = 0)

    return [data, target]

这个整理函数沿着 x 连接(从上面的例子中,我终于得到[125,3,patchsize,pathsize]。出于训练目的,我需要使用 25 的小批量大小来训练模型。是否有任何方法或函数可以用来直接获得大小的输出[25 , 3, patchsize, pathsize]直接从数据加载器使用必要数量的图像作为数据加载器的输入?


以下代码片段可满足您的目的。

首先,我们定义一个 ToyDataset,它接受张量列表(tensors) of variable length in dimension 0。这与数据集返回的样本类似。

import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import RandomSampler

class ToyDataset(Dataset):
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(tensors)

其次,我们定义一个自定义数据加载器。创建数据集和数据加载器的常见 Pytorch 二分法大致如下:dataset,您可以向其传递索引,它会从数据集中返回关联的样本。有一个sampler产生一个索引,有不同的策略来绘制索引,从而产生不同的采样器。采样器由batch_sampler一次绘制多个索引(与batch_size指定的数量相同)。有一个dataloader它结合了采样器和数据集,让您可以迭代数据集,重要的是数据加载器还拥有一个函数(collate_fn),它指定如何组合使用来自batch_sampler的索引从数据集中检索的多个样本。对于您的用例,通常的 PyTorch 二分法效果不佳,因为我们需要绘制索引,直到与索引关联的对象超过我们期望的累积大小,而不是绘制固定数量的索引。这意味着我们需要立即检查对象并使用这些知识来决定是否返回批次或保留绘图索引。这就是下面的自定义数据加载器的作用:

class CustomLoader(object):

    def __init__(self, dataset, my_bsz, drop_last=True):
        self.ds = dataset
        self.my_bsz = my_bsz
        self.drop_last = drop_last
        self.sampler = RandomSampler(dataset)

    def __iter__(self):
        batch = torch.Tensor()
        for idx in self.sampler:
            batch = torch.cat([batch, self.ds[idx]])
            while batch.size(0) >= self.my_bsz:
                if batch.size(0) == self.my_bsz:
                    yield batch
                    batch = torch.Tensor()
                else:
                    return_batch, batch = batch.split([self.my_bsz,batch.size(0)-self.my_bsz])
                    yield return_batch
        if batch.size(0) > 0 and not self.drop_last:
            yield batch

在这里,我们迭代数据集,在绘制索引并加载关联对象后,我们将其连接到我们之前绘制的张量(batch)。我们继续这样做,直到达到所需的尺寸,这样我们就可以切割并生产一批。我们保留行batch,我们没有屈服。因为可能会出现单个实例超过所需的batch_size的情况,所以我们使用while loop.

您可以修改这个最小CustomDataloader以 PyTorch 数据加载器的风格添加更多功能。也不需要使用 RandomSampler 来绘制索引,其他的也同样可以工作。如果您的数据很大,通过使用列表并跟踪其张量的累积长度,也可以避免重复的连接。

这是一个示例,演示了它的工作原理:

patch_size = 5
channels = 3
dim0sizes = torch.LongTensor(100).random_(1, 100)
data = torch.randn(size=(dim0sizes.sum(), channels, patch_size, patch_size))
tensors = torch.split(data, list(dim0sizes))

ds = ToyDataset(tensors)
dl = CustomLoader(ds, my_bsz=250, drop_last=False)
for i in dl:
    print(i.size(0))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用可变批量大小加载数据? 的相关文章

  • 初始化 dask 分布式工作线程的状态

    我正在尝试做类似的事情 resource MyResource def fn x something dosemthing x resource return something client Client results client m
  • 图像处理:什么是遮挡?

    我正在开发一个图像处理项目 我遇到了这个词闭塞在许多科学论文中 遮挡在图像处理中意味着什么 字典只是给出了一般的定义 谁能使用图像作为上下文来描述它们 遮挡意味着您想看到某些内容 但由于传感器设置的某些属性或某些事件而无法看到 它到底如何表
  • 使用 Pycharm 在 Windows 下启动应用程序时出现 UnicodeDecodeError

    问题是当我尝试启动应用程序 app py 时 我收到以下错误 UnicodeDecodeError utf 8 编解码器无法解码位置 5 中的字节 0xb3 起始字节无效 整个文件app py coding utf 8 from flask
  • Pytorch LSTM:计算交叉熵损失的目标维度

    我一直在尝试在 Pytorch 中使用 LSTM LSTM 后跟自定义模型中的线性层 但在计算损失时出现以下错误 Assertion cur target gt 0 cur target lt n classes failed 我用以下函数
  • 在f字符串中转义字符[重复]

    这个问题在这里已经有答案了 我遇到了以下问题f string gt gt gt a hello how to print hello gt gt gt f a a gt gt gt f a File
  • 使用强光混合模式时突出显示伪影

    我正在 iPhone 应用程序中使用顶部图像的 HardLight 混合模式混合两个图像 它看起来像这样 UIGraphicsBeginImageContext size sourceImage drawInRect rectangle b
  • 如何使用魔杖扭曲图像

    我正在尝试做同样的事情this https stackoverflow com questions 52090350 how to insert image in a mock up老问题但在python using wand 到目前为止我
  • H2O服务器崩溃

    去年我一直在使用 H2O 我已经厌倦了服务器崩溃 我已经放弃了 夜间发布 因为它们很容易被我的数据集崩溃 请告诉我在哪里可以下载稳定的版本 Charles 我的环境是 Windows 10 企业版 内部版本 1607 具有 64 GB 内存
  • While 在范围内循环用户输入

    我有一些代码 我想要求用户输入 1 100 之间的数字 如果他们在这些数字之间输入一个数字 它将打印 Size input 并打破循环 但是 如果他们在外部输入一个数字1 100 它将打印 大小 输入 并继续向他们重新询问一个数字 但我遇到
  • Flask-migrate:更改模型属性并重命名相应的数据库列

    我对 Flask 有一些经验 但对数据库 Flask migrate alembic SqlAlchemy 不太了解 我正在跟进this https blog miguelgrinberg com post the flask mega t
  • 导入错误:无法导入名称“FFProbe”

    我无法获取ffprobe包 https github com simonh10 ffprobe在 Python 3 6 中工作 我使用 pip 安装它 但是当我输入import ffprobe it says Traceback most
  • Spacy 实体规则不适用于基数(社会安全号码)

    我已使用实体规则为社会保障号添加新标签 即使设置了 overwrite ents true 但它仍然无法识别 我验证了正则表达式是正确的 不知道我还需要做什么 我之前尝试过 ner 但结果相同 text My name is yuyyvb
  • multiprocessing.Queue 中的 ctx 参数

    我正在尝试使用 multiprocessing Queue 模块中的队列 实施 https docs python org 3 4 library multiprocessing html exchang objects Between p
  • 哪种方式最适合Python工厂注册?

    这是一个关于这些方法中哪一种被认为是最有效的问题 Pythonic 我不是在寻找个人意见 而是在寻找惯用的观点 我的背景不是Python 所以这会对我有帮助 我正在开发一个可扩展的 Python 3 项目 这个想法类似于工厂模式 只不过它是
  • Python 3 中 int() 和 Floor() 有什么区别?

    在Python 2中 floor 返回一个浮点值 虽然对我来说并不明显 但我发现了一些解释来澄清为什么它可能有用floor 返回浮点数 对于类似的情况float inf and float nan 然而 在Python 3中 floor 返
  • 从 numpy 数组中删除连续的 RGB 值

    我最初根据灰度图像的初始数组创建了一个子数组 从 numpy 数组中删除连续数字 https stackoverflow com questions 50743769 deleting consecutive numbers from a
  • python 中带有 lambda 函数字典的奇怪行为

    我编写了一个用于生成 lambda 常量函数字典的函数 它是一个更复杂函数的一部分 但我已将其简化为下面的代码 def function a interpolators for key in a keys interpolators key
  • 来自 dll 的 Java 调用函数

    我有这个 python 脚本导入zkemkeeperdll 并连接到考勤设备 ZKTeco 这是我正在使用的脚本 from win32com client import Dispatch zk Dispatch zkemkeeper ZKE
  • 打印包含字符串和其他 2 个变量的变量

    var a 8 var b 3 var c hello my name is var a and var b bye print var c 当我运行程序时 var c 会像这样打印出来 hello my name is 8 and 3 b
  • 如何将Python3设置为Mac上的默认Python版本?

    有没有办法将 Python 3 8 3 设置为 macOS Catalina 版本 10 15 2 上的默认 Python 版本 我已经完成的步骤 看看它安装在哪里 ls l usr local bin python 我得到的输出是这样的

随机推荐

  • 是否可以告诉自动映射器在运行时忽略映射?

    我正在使用 Entity Framework 6 和 Automapper 将实体映射到 dtos 我有这个型号 public class PersonDto public int Id get set public string Name
  • MathJax 方程换行

    嘿 如果包含的元素具有固定大小 有谁知道让 MathJax 自动换行方程的好方法 MathJax v2 0 现在包括针对长显示方程的自动 可选 换行 它是由linebreaks的部分HTML CSS您的配置块 请参阅MathJax 文档 h
  • 在 TypeScript 中解构对象时重命名剩余属性变量

    EDIT 我在github上开了一个与此相关的问题 https github com Microsoft TypeScript issues 21265 https github com Microsoft TypeScript issue
  • PostgreSQL 从 9.1 升级到 9.4 后性能下降

    将 Postgres 9 1 升级到 9 4 后 我的性能变得非常慢 下面是两个查询的示例 它们的运行速度明显慢得多 注意 我意识到这些查询可能可以被重写以更有效地工作 但是我主要担心的是升级到较新版本的 Postgres 后 它们的运行速
  • 差异化包装

    升级应用程序时 Test ServiceFabricApplicationPackage命令会对版本号未更改的每个代码包抛出错误 这表示内容已更改 即使代码未更改 我知道有一个功能可以创建部分包 但我无法使用它 我的问题是 如何检查代码包内
  • 如何在其他工作表的应用程序脚本中请求或获得谷歌电子表格访问权限?

    我正在为我的自定义函数编写 A 电子表格的应用程序脚本 并尝试使用从那里获取 B 电子表格中的值openUrl 然而 我得到了ERROR当我使用自定义函数时在电子表格中 在谷歌文档中 它说 如果您的自定义函数抛出错误消息 You do no
  • 使用powershell在其他域上查找“网络用户”?

    我想做的是 net user user1 DOMAIN 但是 我想为计算机未加入但可以访问的域执行此操作 用户分布在 DOMAIN1 和 DOMAIN2 中 我运行它的计算机已加入 DOMAIN1 但会在 DOMAIN2 上查找用户 这可以
  • 在 mongodb 的嵌套数组中插入数据[重复]

    这个问题在这里已经有答案了 可能的重复 MongoDB 更新嵌套数组中的字段 https stackoverflow com questions 9611833 mongodb updating fields in nested array
  • Safari 中的垂直居中

    我在 Safari 中使用 margin auto 0 时遇到垂直居中问题 在嵌套在带有 display inline flex 的 div 内的 div 上 它在 Firefox Chrome Opera 中工作得很好 但在 Safari
  • Travis CI 失败,因为无法接受许可证约束布局

    在我写这个问题之前 我已经搜索过同样的问题 他们确实有导出许可证 因为仍然使用 alpha 版本的约束布局 但现在android已经发布了约束布局的稳定版本 我尝试了很多设置但仍然失败 我最新的 travis yml language an
  • Django - 显示图像字段

    我刚刚开始使用 Django 还没有找到很多关于如何显示的信息imageField 所以我做了这个 模型 py class Car models Model name models CharField max length 255 pric
  • 如何判断闭合路径是否包含给定点?

    在 Android 中 我有一个 Path 对象 我碰巧知道它定义了一条闭合路径 并且我需要弄清楚给定点是否包含在路径中 我所希望的是类似的东西 路径 contains int x int y 但这似乎不存在 我寻找这个的具体原因是因为我在
  • 如何使用 signalr 将 json 对象发送到 .net 服务器

    我正在开发一个 Angular 应用程序 我必须使用 netcore 服务器和 signalR 将数据从角度形式发送到外部服务器 我可以使用信号集线器在 Angular 客户端和控制器之间建立连接 但我很困惑如何将 json 对象从客户端发
  • 在 bash 中选择不同的可执行文件

    当我想跑步的时候make为了生成一些可执行文件 它总是使用 Sunmake位于 在 usr local bin make而不是 GNU make 可以在以下位置找到 usr sfw bin gmake 我如何告诉操作系统使用 GNU mak
  • TkInter:了解解除绑定功能

    TkInter 是否unbind http effbot org tkinterbook widget htm Tkinter Widget unbind method函数阻止应用它的小部件将更多事件绑定到小部件 澄清 假设我在程序的早期将
  • Python 中以下代码有什么问题?

    我试图对一个字段实施约束 但它不会导致约束验证 而是允许保存记录而不显示任何约束消息 def check contact number self cr uid ids context None for rec in self browse
  • 在 AOSP Android 6.0 上更新 WebView

    我正在开发基于 AOSP Android 6 0 Marshmallow 的设备 我想将标准 Android webview 更新到最新版本以使用最新的 JavaScript 为此我更换了external chromium webview
  • 使 JButton 在 JTable 内可单击

    这是我想做的事情的屏幕截图 发生的情况是 JButton 显示正确 但当我单击它时没有任何反应 经过一番搜索 我发现Object由返回table getValueAt 是一个字符串而不是 JButton 这是代码 tblResult new
  • 是否有“纯粹适用的任一”的标准名称或实现?

    我经常发现我所谓的 纯粹应用性 的用处Either i e Either与Applicative只要我们不实现一个实例就可用Monad实例也是如此 newtype AEither e a AEither unAEither Either e
  • 使用可变批量大小加载数据?

    我目前正在研究基于补丁的超分辨率 大多数论文将图像分割成更小的补丁 然后使用这些补丁作为模型的输入 我能够使用自定义数据加载器创建补丁 代码如下 import torch utils data as data from torchvisio