如何计算cifar10数据的平均值和标准差

2024-05-13

Pytorch 使用以下值作为 cifar10 数据的平均值和标准差: 变换.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

我需要理解计算背后的概念,因为这些数据是 3 通道图像,我不明白什么是相加的,什么是除什么的等等。 另外,如果有人可以分享计算平均值和标准差的代码,将非常感激。


0.5 值只是近似值cifar10三个通道(r、g、b)的平均值和标准值。 cifar10 训练集的精确值为

  • mean: 0.49139968, 0.48215827 ,0.44653124
  • std: 0.24703233 0.24348505 0.26158768

您可以使用以下脚本计算这些:

import torch
import numpy
import torchvision.datasets as datasets
from torchvision import transforms

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())

imgs = [item[0] for item in cifar_trainset] # item[0] and item[1] are image and its label
imgs = torch.stack(imgs, dim=0).numpy()

# calculate mean over each channel (r,g,b)
mean_r = imgs[:,0,:,:].mean()
mean_g = imgs[:,1,:,:].mean()
mean_b = imgs[:,2,:,:].mean()
print(mean_r,mean_g,mean_b)

# calculate std over each channel (r,g,b)
std_r = imgs[:,0,:,:].std()
std_g = imgs[:,1,:,:].std()
std_b = imgs[:,2,:,:].std()
print(std_r,std_g,std_b)

此外,您可能会发现相同的平均值和标准值here https://github.com/tomgoldstein/loss-landscape/blob/64ef4d57f8dabe79b57a637819c44e48eda98f33/cifar10/dataloader.py#L15 and here https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151#gistcomment-2851662

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何计算cifar10数据的平均值和标准差 的相关文章

随机推荐