PyTorch 中 mean 和 std 的调查实例

时间:2021-05-22

如下所示:

# coding: utf-8from __future__ import print_functionimport copyimport clickimport cv2import numpy as npimport torchfrom torch.autograd import Variablefrom torchvision import models, transformsimport matplotlib.pyplot as pltimport load_caffemodelimport scipy.io as sio# if model has LSTM# torch.backends.cudnn.enabled = Falseimgpath = 'D:/ck/files_detected_face224/' imgname = 'S055_002_00000025.png' # angerimage_path = imgpath + imgnamemean_file = [0.485, 0.456, 0.406]std_file = [0.229, 0.224, 0.225]raw_image = cv2.imread(image_path)[..., ::-1]print(raw_image.shape)raw_image = cv2.resize(raw_image, (224, ) * 2)image = transforms.Compose([ transforms.ToTensor(), transforms.Normalize( mean=mean_file, std =std_file, #mean = mean_file, #std = std_file, )])(raw_image).unsqueeze(0)print(image.shape)convert_image1 = image.numpy()convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * Wconvert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * Cprint(convert_image1.shape)convert_image1 = convert_image1 * 255diff = raw_image - convert_image1err = np.max(diff)print(err)plt.imshow(np.uint8(convert_image1))plt.show()

结论:

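input_image = (raw_image / 255 - mean) / std

即对每个通道 $c$：$x_c = \dfrac{I_c / 255 - \mu_c}{\sigma_c}$。其中 $I$ 是 RGB 顺序、取值为 0–255 的原图，$\mu_c$ 和 $\sigma_c$ 分别是 mean_file 和 std_file 中对应通道的值。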

下面调查均值（mean_file）和标准差（std_file）是如何得到的：

# coding: utf-8
"""Print the per-channel pixel mean and std of a torchvision training set.

The statistics are divided by 255, i.e. expressed on the [0, 1] scale that
``transforms.ToTensor`` produces. For CIFAR-10 it also displays one
training image.
"""
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms

# Constants under investigation; the code below does not read them.
mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]

dataset_names = ('cifar10', 'cifar100', 'mnist')
# One entry per accepted --dataset choice.
DATASET_CLASSES = {
    'cifar10': torchvision.datasets.CIFAR10,
    'cifar100': torchvision.datasets.CIFAR100,
    'mnist': torchvision.datasets.MNIST,
}


def raw_train_data(dataset):
    """Return the raw (untransformed) sample storage of a torchvision dataset.

    Current torchvision exposes it as ``.data``; older releases named it
    ``.train_data``, so fall back to that when ``.data`` is absent.
    """
    data = getattr(dataset, 'data', None)
    return data if data is not None else dataset.train_data


parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10',
                    choices=dataset_names,
                    help='dataset to be used: ' + ' | '.join(dataset_names)
                    + ' (default: cifar10)')
args = parser.parse_args()
data_dir = os.path.join('.', args.dataset)
print(args.dataset)

train_transform = transforms.Compose([transforms.ToTensor()])
train_set = DATASET_CLASSES[args.dataset](root=data_dir, train=True,
                                          download=True,
                                          transform=train_transform)
train_data = raw_train_data(train_set)

if args.dataset == 'mnist':
    # MNIST storage is a torch tensor; reduce over every element to get a
    # single scalar mean/std.
    print(list(train_data.size()))
    pixels = train_data.float()
    print(pixels.mean() / 255)
    print(pixels.std() / 255)
else:
    # CIFAR storage is a numpy array laid out (N, H, W, C); reducing over
    # axes 0-2 leaves one value per channel.
    print(train_data.shape)
    print(np.mean(train_data, axis=(0, 1, 2)) / 255)
    print(np.std(train_data, axis=(0, 1, 2)) / 255)

if args.dataset == 'cifar10':
    # Channel-order check: the stored channels are already RGB (not BGR as
    # in Caffe), so no cv2.split/merge swap is needed before plt.imshow.
    ind = 100
    img0 = train_data[ind, ...]
    print(img0.shape)
    print(type(img0))
    plt.imshow(img0)  # index 100 shows a ship at sea
    plt.show()

结果:

```
cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968  0.48215841  0.44653091]
[ 0.24703223  0.24348513  0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>
```

使用 MATLAB 验证 mean_file 和 std_file 是如何计算的：

% Recompute the CIFAR-10 per-channel mean and std (scaled to [0, 1]) in MATLAB.
% train_data is indexed as (sample, row, col, channel) in the loop below.
data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));
disp(size(mean(train_data, 1)));
train_data = double(train_data);

% Mean: nesting means over dims 1..3 gives each channel's overall mean,
% because every slice along a dimension holds the same number of elements.
mean_val = squeeze(mean(mean(mean(train_data, 1), 2), 3) / 255);

% Std does not nest like mean: pool each channel's values first, then take std.
std_val = zeros(1, 3);
for c = 1:3
    channel = train_data(:, :, :, c);
    std_val(c) = std(channel(:)) / 255;
end

disp(mean_val);
disp(std_val);
% result: mean_val: [0.4914, 0.4822, 0.4465]
%         std_val:  [0.2470, 0.2435, 0.2616]

均值也可以按照与标准差相同的方式逐通道计算。不过为了简单，可以利用这样一个性质：对一个矩阵而言，所有元素的均值等于先沿一个方向求均值、再沿另一个方向求均值的结果。这是因为每一行（或每一列）包含的元素个数相同。所以直接采用如下形式：

mean_val = mean(mean(mean(train_data,1),2),3)/255;

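标准差则需要对每个通道单独计算：把该通道在所有样本、所有像素位置上的取值放在一起求标准差，然后再除以 255。与均值不同，标准差不能按维度嵌套计算，因为各行标准差的标准差并不等于整体标准差。另外，MATLAB 的 std 默认按 N−1 归一化，而 NumPy 的 np.std 默认按 N 归一化；在 CIFAR-10 这样的数据量下，两者的差异可以忽略。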

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章