By 雷郭 (Lei Guo)
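What the function does (torchvision.utils.save_image):
It saves an NCHW tensor to disk as a single grid image, which is also called a sprite image.
As shown in the figure below,
multiple images are tiled into a grid; each image here is 28×28 pixels with a single channel.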
So how are the width and height of the grid determined?
Let's look at the function's source code:
```python
# Excerpt from torchvision/utils.py in an older torchvision release
# (later releases renamed the `range` argument to `value_range`).
import pathlib
from typing import BinaryIO, List, Optional, Text, Tuple, Union

import torch
from torchvision.utils import make_grid


def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    format: Optional[str] = None,
) -> None:
    """Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format(Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    from PIL import Image
    # Tile the N images of the batch into a single (C, H, W) grid tensor
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    # (H, W, C) uint8 array -> PIL image -> file on disk
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)
```
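save_image leaves all of the layout work to make_grid. The following is a minimal pure-Python sketch of that tiling step, for illustration only: it assumes a non-empty list of equally sized single-channel images stored as nested lists, and `tile_grid` is a hypothetical helper, not a torchvision function. It follows the same placement rule as make_grid, which itself works on tensors and also handles normalization.

```python
import math
from typing import List

Pixels = List[List[float]]  # one single-channel image as a list of pixel rows


def tile_grid(images: List[Pixels], nrow: int = 8, padding: int = 2,
              pad_value: float = 0.0) -> Pixels:
    """Hypothetical helper: place equally sized images into one padded grid,
    using the same placement rule as torchvision's make_grid."""
    n = len(images)
    h, w = len(images[0]), len(images[0][0])
    cols = min(nrow, n)              # images per row
    rows = math.ceil(n / cols)       # a partially filled last row still counts
    # Each cell is (h + padding) x (w + padding); one extra padding strip
    # closes off the bottom and right edges.
    height = rows * (h + padding) + padding
    width = cols * (w + padding) + padding
    canvas = [[pad_value] * width for _ in range(height)]
    for k, img in enumerate(images):
        r, c = divmod(k, cols)       # fill row by row, left to right
        top = r * (h + padding) + padding
        left = c * (w + padding) + padding
        for i in range(h):
            canvas[top + i][left:left + w] = img[i]  # copy one pixel row
    return canvas


# 96 all-ones 28x28 images with the defaults nrow=8, padding=2
images = [[[1.0] * 28 for _ in range(28)] for _ in range(96)]
grid = tile_grid(images)
print(len(grid), len(grid[0]))  # 362 242  (height, width)
```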
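As you can see, nrow (the number of images per row of the grid) defaults to 8,
and padding (the gap, in pixels, around each image) defaults to 2.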
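When our tensor has shape 96×1×28×28 (N = 96),
the grid has ⌈N/nrow⌉ rows and nrow columns (make_grid rounds the row count up, and uses N columns if N < nrow),
which here is 12 rows × 8 columns.
That is exactly the layout shown in the first figure.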
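The row/column rule can be sketched in a few lines of Python. This assumes make_grid's layout logic (columns = min(nrow, N), rows = N / nrow rounded up); `grid_cells` is a hypothetical helper for illustration, not part of torchvision.

```python
import math
from typing import Tuple


def grid_cells(n: int, nrow: int = 8) -> Tuple[int, int]:
    """Hypothetical helper: (rows, columns) of the grid make_grid builds for n images."""
    cols = min(nrow, n)          # never more columns than there are images
    rows = math.ceil(n / cols)   # round up: a partial last row still takes a full row
    return rows, cols


print(grid_cells(96))    # (12, 8)
print(grid_cells(100))   # (13, 8): the 13th row holds only 4 images
```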
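However, when I checked the actual pixel size of the saved image, it was not (12×28, 8×28) as (height, width), because of the padding,
but (12×28 + 13×2, 8×28 + 9×2) = (362, 242):
the 12 rows of images have 13 horizontal padding strips between and around them, and the 8 columns have 9 vertical ones.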
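As a worked formula, following make_grid's layout: with r rows and c columns of h×w images and padding p, the grid's height H and width W are

```latex
H = r\,(h + p) + p = r\,h + (r + 1)\,p, \qquad W = c\,(w + p) + p = c\,w + (c + 1)\,p

% With r = 12, c = 8, h = w = 28, p = 2:
H = 12 \cdot 28 + 13 \cdot 2 = 336 + 26 = 362, \qquad W = 8 \cdot 28 + 9 \cdot 2 = 224 + 18 = 242
```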
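One more thing to note: when you save images batch by batch, the total number of images may not be divisible by the batch size, so the last batch is smaller than the rest (PyTorch's DataLoader, for example, keeps this partial batch by default because drop_last=False).
So if the number of cells in a sprite image does not match the batch size,
don't worry:
this is normal.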
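To see where the odd-sized batch comes from, here is a stdlib-only sketch of how a loader splits a dataset into batches. `batch_sizes` is a hypothetical helper that mimics the batch sizes PyTorch's DataLoader yields with drop_last=False (its default) or drop_last=True; it is not DataLoader itself.

```python
import math
from typing import List


def batch_sizes(total: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Hypothetical helper: sizes of the batches a loader yields for `total` samples.

    With drop_last=False every batch is full except possibly the last one,
    which holds the remainder; with drop_last=True that remainder is discarded.
    """
    full, rest = divmod(total, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)   # the short final batch
    return sizes


sizes = batch_sizes(1000, 128)
print(len(sizes), sizes[-1])  # 8 104
# Rows of an nrow=8 sprite image for a full batch vs. the last batch
print(math.ceil(sizes[0] / 8), math.ceil(sizes[-1] / 8))  # 16 13
```

So a sprite saved from the last batch simply has fewer rows than one saved from a full batch.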
It took me several hours to figure this out, and the clue finally came from the printout below:
```text
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([128, 1, 28, 28])
real_images shape: torch.Size([96, 1, 28, 28])
real_img shape: torch.Size([96, 784])
```
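As you can see, the first dimension is 128 for every batch at first,
becomes 96 for the last batch,
and is still 96 when that batch is used afterwards (flattened to 96×784).
That immediately made me think of a remainder.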
I then checked it: 60000 = 128 × 468 + 96, where 60000 is the total number of images,
which matches perfectly.
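The same check written out, which also shows where a 96×1×28×28 tensor like the one used earlier could come from: 468 full batches of 128 followed by a final batch of 96, which with nrow = 8 fills exactly 12 rows.

```latex
128 \times 468 = 59904, \qquad 59904 + 96 = 60000
\;\Rightarrow\; 60000 = 128 \times 468 + 96 \quad (468 \text{ full batches, last batch of } 96)

\left\lceil 96 / 8 \right\rceil = 12 \text{ rows} \times 8 \text{ columns}
```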