1. 问题描述

代码运行至下列语句时:

for i, data in enumerate(train_loader):

有时会遇到以下三种报错:

TypeError: img should be PIL Image. Got <class 'dict'>
TypeError: img should be PIL Image. Got <class 'Torch.Tensor'>
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

2. 分析

这类问题往往出错在 datasetgetitem 函数处,因为这里涉及到对数据做 transformtransform 中又涉及到一些图像转换的问题。

2.1 首先,我们先看一下官方的 transform 是怎么写的,以transform.Resize() 类为例
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)

我们来看 call() 方法,这里有几行注释,大致意思是, call 方法接受一个 PIL Image 格式的输入,经过 resize 方法后 返回一个 PIL Image 格式的输出,也就是说, pytorch 官方 中的 transform 默认是需要一个 PIL Image 格式输入的 。而有很多朋友会采用不同的方式读取 image ,比如使用 cv2.imread() 函数,我们可以来测试一下:

import cv2

img = cv2.imread('data/davis2016/JPEGImages/480p/train/00000.jpg')
type(img)
# Out[7]: numpy.ndarray

可以看见,使用 cv2.imread 函数读取的 imagenumpy.ndarray 格式的,这时候如果直接对这个 imagetransform 就会出现类型不匹配问题,这时候,需要在你写的 train_transform 中加一个 transform.ToPILImage() 函数,例如:

train_transforms = t.Compose([t.ToPILImage(), # Here
                              t.RandomHorizontalFlip(),
                              t.Resize((480, 852)),
                              t.ToTensor()])

依次可以类推,transform.Compose() 方法其实就是把一系列我们要对 image 做的操作(数据预处理,数据增强等)排列到一起,因此,我们要保证其从第一个函数到最后一个函数的输入都要是 PIL Image 格式。那些遇见错误的同学,要么是没有将 PIL Image 格式的图像做为 transform.Compose() 方法输入,要么是虽然输入了 PIL Image 格式图像,但是在一些列操作未结束之前就将其转为了 Tensor,见下列代码:

train_transforms = t.Compose([t.ToPILImage(),
                              t.RandomHorizontalFlip(),
                              t.ToTensor(), # Here
                              t.Resize((480, 852))])

这时,不用运行,我们就知道,这里肯定出错了,因为刚刚我们验证过 transform.Resize()call 方法需要接受一个 PIL Image 格式的图像,而你提前使用了 transform.ToTensor() 方法将其转为了torch.Tensor 格式,这就肯定错了。

3. 总结

3.1 对于错误:
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

你应该检查你的图像是否为 PIL Image 格式,如果不是,可以使用 transform.ToPILImage() 方法。

3.2 对于错误:
TypeError: img should be PIL Image. Got <class 'Torch.Tensor'>

你应该检查你的 transform.ToTensor() 方法是否写在了你要做的操作之前,如果是,调换一下它们的位置。

3.3 对于错误:
TypeError: img should be PIL Image. Got <class 'dict'>

这个应该很少有人会遇见,这是我需要将一个 img 和它的 gt 做为一个字典一起返回的时候遇见的一个错误:

    def __getitem__(self, idx):
        img = readImage(self.img_list[idx], channel=3)
        gt = readImage(self.mask_list[idx], channel=1)
        sample = {'images': img, 'gts': gt}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

解决方法是,先分别对 imggttransform 再把它们组合到一个字典里:

    def __getitem__(self, idx):
        img = readImage(self.img_list[idx], channel=3)
        gt = readImage(self.mask_list[idx], channel=1)

        if self.transform is not None:
            img = self.transform(img)
            gt = self.transform(gt)
            
        sample = {'images': img, 'gts': gt}

        return sample

4. 推荐

4.1 通过上面的分析,想必大家已经感觉到,写一个 train_transform 比较好的流程是:
  1. PIL.Image 读取图像文件
  2. 写要对图像做的操作
  3. 最后写 transform.ToTensor() 方法
4.2 我提供一些代码来完整的过一遍4.1:
  1. 这是一个用 PIL 模块读取图像的函数,channel 是通道数目
    代码参考自:https://github.com/michuanhaohao/AlignedReID/blob/master/util/dataset_loader.py
import PIL.Image
def readImage(img_path, channel=3):
    """Keep reading image until succeed.
    This can avoid IOError incurred by heavy IO process."""
    got_img = False
    if not os.path.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while not got_img:
        if channel == 3:
            try:
                img = Image.open(img_path).convert('RGB')
                got_img = True
            except IOError:
                print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
                pass
        elif channel == 1:
            try:
                img = Image.open(img_path).convert('1')
                got_img = True
            except IOError:
                print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
                pass
    return img
  1. 有了 PIL.Image 格式的 img 后,我们就可以送入 transform.Compose() 了:
train_transforms = t.Compose([t.RandomHorizontalFlip(),
                              t.Resize((480, 852)),
                              t.ToTensor()])

注意:transform.ToTensor() 最好写在最后。

ps:其实…也有一些方法(例如,随即擦除,标准化等)不需要输入格式为 PIL.Image,下面这样写是可以的:

    transform_train = T.Compose([
        T.Random2DTranslation((256, 128)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.RandomErasing(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

具体情况还是要自己多试一试,多去看看官方的写法。

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐