目录

一、CIFAR10数据集简介:

二、CIRAR10数据集格式

三、CIFAR10数据集下载与读取

1、下载:

 2、读取

 四、class:torchvision.datasets.CIFAR10 解读


一、CIFAR10数据集简介:

  CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。

             

 

注:杰弗里·埃弗里斯特·辛顿FRS(英语:Geoffrey Everest Hinton,1947年12月6日-),英国出生的加拿大计算机学家和心理学家多伦多大学教授。以其在类神经网络方面的贡献闻名。辛顿是反向传播算法对比散度算法的发明人之一,也是深度学习的积极推动者[1],被誉为“深度学习之父”[2]。辛顿因在深度学习方面的贡献与约书亚·本希奥杨立昆一同被授予了2018年的图灵奖[3]                                                                                                        -wikipedia

 

二、CIRAR10数据集格式

数据集包括:50000张训练集,10000张测试集,其中每张图片是RGB格式,像素32*32。

图片各种格式简介:

图片格式介绍 - 简书

图片格式介绍 - 知乎

 

三、CIFAR10数据集下载与读取

1、下载

 

from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


train_dataset = CIFAR10(root='.',
                        train=True,
                        transform=transform,
                        download=True)

test_dataset  = CIFAR10(root='.',
                        train=False,
                        transform=transform,)

 2、读取与展示

 

显示图片:

Pytorch中Tensor与各种图像格式的相互转化_明泽.的博客-CSDN博客_tensor转图像

 2 python 读取并显示图片的两种方法 - 邊城浪子 - 博客园

pytorch 张量tensor 转为 jpg 图片_Tchunren的博客-CSDN博客_pytorch将tensor保存为图片 

 四、class:torchvision.datasets.CIFAR10 解读

import os.path
import pickle
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image

from .utils import check_integrity, download_and_extract_archive
from .vision import VisionDataset



[docs]class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, "rb") as f:
                entry = pickle.load(f, encoding="latin1")
                self.data.append(entry["data"])
                if "labels" in entry:
                    self.targets.extend(entry["labels"])
                else:
                    self.targets.extend(entry["fine_labels"])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
        with open(path, "rb") as infile:
            data = pickle.load(infile, encoding="latin1")
            self.classes = data[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

 参考:

1

 torchvision.datasets.cifar — Torchvision 0.12 documentation

Dataset之CIFAR-10:CIFAR-10数据集简介、下载、使用方法之详细攻略_一个处女座的程序猿的博客-CSDN博客_cifar-10cifar10数据集下载及图片格式解析_Briwisdom的博客-CSDN博客_cifar10下载

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐