Fashion MNIST

MNIST数据集可能是计算机视觉所接触的第一个图片数据集。而 Fashion MNIST 是在遵循 MNIST 的格式和大小的基础上,提升了一定的难度,在比较算法的性能时可以有更好的区分度。
Fashion MNIST 数据集包含 60000 张图片的训练集和 10000 张图片的测试集。图片的大小为 28×28,共784个像素。像素的灰度值介于0~255之间的整数。
数据集分为10个类别,分别是:

  • ‘T-shirt/top’,
  • ‘Trouser’,
  • ‘Pullover’,
  • ‘Dress’,
  • ‘Coat’,
  • ‘Sandal’,
  • ‘Shirt’,
  • ‘Sneaker’,
  • ‘Bag’,
  • ‘Ankle boot’

1. 使用 torchvision.datasets 加载数据集

import torch
import torchvision
import Image

# 使用 torchvision.datasets.FashionMNIST 下载数据
# data_path: 数据集的保存路径
# train: True下载训练集,False下载测试集
# transform: 图片预处理
# download: 是否从网络上下载数据
>>> data_path = "保存路径"
>>> train_data = torchvision.datasets.FashionMNIST(data_path, train=True, transform=None, download=True)

# 查看数据内容
>>> train_data.data.shape
torch.Size([60000, 28, 28])
>>> train_data.targets.shape
torch.Size([60000])
>>> train_data.classes
['T-shirt/top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']
# 查看数据集的数据类型
>>> train_data
Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: ..\pytorch\data
    Split: Train

# Dataset FashionMNIST对象是 torch.utils.data.dataset.Dataset 的子集

>>> train_data.__class__.__mro__
(torchvision.datasets.mnist.FashionMNIST,
 torchvision.datasets.mnist.MNIST,
 torchvision.datasets.vision.VisionDataset,
 torch.utils.data.dataset.Dataset,
 typing.Generic,
 object)

2. 解析“…-idx3-ubyte”文件

如果已经将数据集下载到了本地,可以直接解析文件来导入数据集。
数据集包括以下四个文件:

train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte

这是四个二进制文件,其中的idx3表示有三个维度,idx1表示有一个维度。 这里使用常用的二进制解析库 struct 来解析文件。

import numpy as np
import struct

# 文件路径
data_path = r'路径'
file_names = ['t10k-images-idx3-ubyte',
              't10k-labels-idx1-ubyte',
              'train-images-idx3-ubyte',
              'train-labels-idx1-ubyte']

def decode_idx3_ubyte(file):
    """
    解析数据文件
    """
    # 读取二进制数据
    with open(file, 'rb') as fp:
        bin_data = fp.read()
    
    # 解析文件中的头信息
    # 从文件头部依次读取四个32位,分别为:
    # magic,numImgs, numRows, numCols
    # 偏置
    offset = 0
    # 读取格式: 大端
    fmt_header = '>iiii'
    magic, numImgs, numRows, numCols = struct.unpack_from(fmt_header, bin_data, offset)
    print(magic,numImgs,numRows,numCols)
    
    # 解析图片数据
    # 偏置掉头文件信息
    offset = struct.calcsize(fmt_header)
    # 读取格式
    fmt_image = '>'+str(numImgs*numRows*numCols)+'B'
    data = torch.tensor(struct.unpack_from(fmt_image, bin_data, offset)).reshape(numImgs, numRows, numCols)
    return data


def decode_idx1_ubyte(file):
    """
    解析标签文件
    """
    # 读取二进制数据
    with open(file, 'rb') as fp:
        bin_data = fp.read()
    
    # 解析文件中的头信息
    # 从文件头部依次读取两个个32位,分别为:
    # magic,numImgs
    # 偏置
    offset = 0
    # 读取格式: 大端
    fmt_header = '>ii'
    magic, numImgs = struct.unpack_from(fmt_header, bin_data, offset)
    print(magic,numImgs)
    
    # 解析图片数据
    # 偏置掉头文件信息
    offset = struct.calcsize(fmt_header)
    # 读取格式
    fmt_image = '>'+str(numImgs)+'B'
    data = torch.tensor(struct.unpack_from(fmt_image, bin_data, offset))
    return data

train_set = (decode_idx3_ubyte(os.path.join(data_path, file_names[0])),
             decode_idx1_ubyte(os.path.join(data_path, file_names[1])))
test_set = (decode_idx3_ubyte(os.path.join(data_path, file_names[2])),
            decode_idx1_ubyte(os.path.join(data_path, file_names[3])))

运行结果:

2051 10000 28 28
2049 10000
2051 60000 28 28
2049 60000

3. 数据分批

使用 pytorchDataLoader 对象分批读取数据。

# 将data和label张量封装为数据类,可通过第一个维度来索引每一个样本。
train_data = torch.utils.data.TensorDataset(*train_set)
test_data = torch.utils.data.TensorDataset(*test_set)

# 创建数据加载器,小批次读取数据
batch_size = 5050
train_Loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# 分批读取
for X,y in train_Loader:
    print(X.shape, y.shape)

运行结果:

torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([4450, 28, 28]) torch.Size([4450])
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐