目录

ImageFolder 加载数据集

使用pytorch提供的Dataset类创建自己的数据集。

Dataset加载数据集

接下来我们就可以构建我们的网络架构:

 训练我们的网络:

 保存网络模型(这里不止是保存参数,还保存了网络结构)


pytorch加载图片数据集有两种方法。

1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别

 

在Flower_Orig_dataset文件夹下有flower_orig 和 sunflower这两个文件夹, 这两个文件夹下放着同一个类别的图片。 使用 ImageFolder 加载的图片, 就会返回图片信息和对应的label信息, 但是label信息是根据文件夹给出的, 如flower_orig就是标签0, sunflower就是标签1。

ImageFolder 加载数据集

1. 导入包和设置transform

from torchvision.datasets import ImageFolder

import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader

transforms = transforms.Compose([
    transforms.Resize(256),    # 将图片短边缩放至256,长宽比保持不变:
    transforms.CenterCrop(224),   #将图片从中心切剪成3*224*224大小的图片
    transforms.ToTensor()          #把图片进行归一化,并把数据转换成Tensor类型
]) 

2.加载数据集: 将分类图片的父目录作为路径传递给ImageFolder(), 并传入transform。这样就有了要加载的数据集, 之后就可以使用DataLoader加载数据, 并构建网络训练。

path = r'D:\dataset_deep_learning\Flower_Orig_dataset'

data_train = datasets.ImageFolder(path, transform=transforms)

data_loader = DataLoader(data_train, batch_size=64, shuffle=True)

for i, data in enumerate(data_loader):
    images, labels = data

    # 打印数据集中的图片
    img = torchvision.utils.make_grid(images).numpy()
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()

    break

使用pytorch提供的Dataset类创建自己的数据集。

具体步骤:

1.  首先要有一个txt文件, 这个文件格式是: 图片路径   标签  图片文件夹.  这样的格式, 所以使用os库, 遍历自己的图片名, 并把标签和图片路径写入txt文件。

2. 有了这个txt文件, 我们就可以在类里面构造我们的数据集.

2.1    把图片路径和图片标签分割开, 有三个列表, 一个列表是图片路径名, 一个列表是标签号,一个列表是这类图片的文件夹 。 有一点就是第 i 个图片列表和 第 i 个标签是对应的

3. 重写__len__方法  和  __getitem__方法

3.1 getitem方法中, 获得对应的图片路径,并用PIL库读取文件把图片transfrom后, 在getitem函数中返回读取的图片和标签即可

4.就可以构建数据集实例和加载数据集.

文件结构如图:

 定义一个用来生成[ 图片路径 标签  该类图片文件夹名] 这样的txt文件函数(因为用了a追加的方式,所以,flower_orig和sunflower两个文件夹下的都被写进data.txt文件了)

#打开存放图片的文件夹,然后遍历文件名,把文件名字, label 还有 文件夹名写入data.txt文件中。

import os

def make_txt(root, file_name, label):
    path = os.path.join(root, file_name)  

    data = os.listdir(path)

    f = open(root + '\\' + 'data.txt', 'a')

    for line in data:
        f.write(line + ' ' + str(label) + ' ' + file_name + '\n')
    f.close()

path = r'D:\dataset_deep_learning\Flower_Orig_dataset'

# 调用函数生成两个文件夹下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)

 现在看看查看data.txt文件的格式如图:(由图中三部分组成)

现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息

Dataset加载数据集

#我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径
import os

import numpy as np
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms



transforms = transforms.Compose([
    transforms.Resize(256),    # 将图片短边缩放至256,长宽比保持不变:
    transforms.CenterCrop(224),   #将图片从中心切剪成3*224*224大小的图片
    transforms.ToTensor()          #把图片进行归一化,并把数据转换成Tensor类型
])


class MyDataset(Dataset):
    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = img_path

        self.txt_root = self.root + '\\' + 'data.txt'

        f = open(self.txt_root, 'r')
        data = f.readlines()

        imgs = []
        labels = []
        for line in data:
            line = line.rstrip()
            word = line.split()
            #print(word[0], word[1], word[2])   
            #word[0]是图片名字.jpg  word[1]是label  word[2]是文件夹名,如sunflower
            imgs.append(os.path.join(self.root,word[2], word[0]))

            labels.append(word[1])
        self.img = imgs
        self.label = labels
        self.transform = transform

    def __len__(self):
        return len(self.label)

    def __getitem__(self, item):
        img = self.img[item]
        label = self.label[item]

        img = Image.open(img).convert('RGB')

        # 此时img是PIL.Image类型   label是str类型

        if self.transform is not None:
            img = self.transform(img)

        label = np.array(label).astype(np.int64)
        label = torch.from_numpy(label)

        return img, label

 加载我们的数据集并查看我们加载到图片:

path = r'D:\数据集\Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)

data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)


for i, data in enumerate(data_loader):
    images, labels = data

    # 打印数据集中的图片
    img = torchvision.utils.make_grid(images).numpy()
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()

    break

接下来我们就可以构建我们的网络架构:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.maxpool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16,5,3)

        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(55*55*5, 1200)
        self.fc2 = nn.Linear(1200,64)
        self.fc3 = nn.Linear(64,2)

    def forward(self,x):
        x = self.maxpool(self.relu(self.conv1(x)))    #113
        x = self.maxpool(self.relu(self.conv2(x)))    #55
        x = x.view(-1, self.num_flat_features(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s

        return num_features

 训练我们的网络:

model = Net()

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(data_loader):
        images, label = data

        out = model(images)

        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if(i+1)%10 == 0:
            print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))
            running_loss = 0.0

print('finished train')

 保存网络模型(这里不止是保存参数,还保存了网络结构)

#保存模型
torch.save(net, 'model_name.pth')   #保存的是模型, 不止是w和b权重值

# 读取模型
model = torch.load('model_name.pth')

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐