第一次这么认真分享,是因为,我找了好久也没找到和自己目标一致的,只好参考别人的,自行修改了一下下。我的目标是对单通道、只包含一个分割目标的数据集进行归一化。如果想要了解多个分割目标的归一化,可以参考下面的链接:
https://hulin.blog.csdn.net/article/details/116600119?spm=1001.2014.3001.5506
1、准备数据集(dataset)

import os
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.img_path = os.listdir(root_dir)#获取训练数据列表
        self.label_path = os.listdir(label_dir)
        self.transformer = torchvision.transforms.Compose([
            transforms.Resize((256,256)),
            transforms.CenterCrop(256),
            transforms.RandomRotation(180),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.4573,std=0.3618)#输入是tensor类型,0.4573和0.3618是计算出来的,后面说如何计算,目前准备阶段可以先忽略这一行语句
        ])
        
    def __getitem__(self, index):
        #read img:
        image_name = self.img_path[index]#获取每一个训练数据名称
        image_path = os.path.join(self.root_dir, image_name)#获取每一个训练数据的路径!!!很重要,不然会打不开文件!!!!
        img_pil = Image.open(image_path)
        #read label:
        label_name = self.label_path[index]
        label_path = os.path.join(self.label_dir,label_name)
        label_pil = Image.open(label_path)
        #data enhance:
        img_tran = self.transformer(img_pil)
        label_tran = self.transformer(label_pil)
        label_tran = torch.squeeze(label_tran)#CEloss要求label:[b,h,w]
        image = img_tran.float()#指定img为floattensor型
        labels = label_tran.long()#指定label为longtensor型
        return image,labels

    def __len__(self):
        return len(self.img_path)
#训练和验证数据的路径:
train_image_dir = 'E:\\pycharm\\UNet\\train_image'
train_label_dir = 'E:\\pycharm\\UNet\\train_label'
valid_image_dir = 'E:\\pycharm\\UNet\\valid_image'
valid_label_dir = 'E:\\pycharm\\UNet\\valid_label'
#调用MyData类,创建dataset
train_dataset = MyData(train_image_dir,train_label_dir)
valid_dataset = MyData(valid_image_dir,valid_label_dir)

以上就准备好了要归一化的数据集,使用pytorch中的transforms.Normalize(mean,std)进行归一化,要求input是tensor类型,所以前面我将该语句放在了transforms.ToTensor()后面。
接下来的问题就是,如何知道我们的这个数据集的mean(均值)和std(标准差)。
2、计算mean和std:
该方法适用于单通道、只包含一个分割目标的数据集,如果是多个目标,则需要对每个目标(一个目标对应一个通道)进行计算mean和std,参照开头链接。

def getstat(dataset):
    print(len(dataset))
    loader = torch.utils.data.DataLoader(dataset,batch_size=1,shuffle=False,num_workers=0,pin_memory = True)
    mean = torch.zeros(1)#因为我的数据集是单通道的,只包含目标(1)和背景(0),所以我只需要计算一个通道的mean和std
    std = torch.zeros(1)
    for x,_ in loader:#计算loader中所有数据的mean和atd的累积
        mean += x.mean()
        std += x.std()
    mean = torch.div(mean,len(dataset))#得到整体数据集mean的平均
    std = torch.div(std,len(dataset))
    return list(mean.numpy()),list(std.numpy())#返回mean和std的list

mean,std = getstat(train_dataset)#调用getstat
mean_,std_ = getstat(valid_dataset)
print(mean,std)
print(mean_,std_)

结果如下:

如有问题,还望指教。(小声说:作者也是刚开始学习)

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐