------------------------------------------------2021年6月18日重大更新--------------------------------------------------------------

目前已经退出bug修复之后的tensorflow2.3物体分类代码,大家可以训练自己的数据集,快来试试吧

csdn教程链接:手把手教你用tensorflow2.3训练自己的分类数据集_CSDN博客

b站视频链接:手把手教你用tensorflow2训练自己的数据集

数据集链接:计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_CSDN博客

代码链接:vegetables_tf2.3: 基于tensorflow2.3开发的水果蔬菜识别系统 (gitee.com)

------------------------------------------------------------------dejahu---------------------------------------------------------------------
image-20210305103139860

垃圾分类是目前社会的一个热点,分类的任务是计算机视觉任务中的基础任务,相对来说比较简单,只要找到合适的数据集,垃圾分类的模型构建并不难,这里我找到一份关于垃圾分类的数据集,一共有四个大类和245个小类,大类分别是厨余垃圾、可回收物、其他垃圾和有害垃圾,小类主要是垃圾的具体类别,果皮、纸箱等。

为了方便大家使用,我已经提前将数据集进行了处理,按照8比1比1的比例将原始数据集划分成了训练集、验证集和测试集,链接如下:

垃圾分类数据集和tf代码-8w张图片245个类.zip-深度学习文档类资源-CSDN下载

代码结构

trash1.0
├─ .idea idea配置文件
├─ imgs 图片文件
├─ main_window.py 图形界面代码
├─ models
│    └─ mobilenet_trashv1_2.pt
├─ old 一些废弃的代码
├─ readme.md 你现在看到的
├─ test.py 测试文件
├─ test4dataset.py  测试所有的数据集
├─ test4singleimg.py 测试单一的图片
├─ train_245_class.py 训练代码
└─ utils.py 工具类,用于划分数据集

训练

训练前请执行命令按照好项目所需的依赖库,关于如何在python中使用conda和pip对项目包管理可以看这篇文章或者是看我b站的这个视频,里面有详细的讲解。

csdn文章:Windows下GPU深度学习环境的配置(pytorch和tensorflow)_ECHOSON的博客-CSDN博客

b站视频:【大作业怎么搞01】基于tensorflow2.3的花卉识别程序_哔哩哔哩 (゜-゜)つロ 干杯~-bilibili

conda create -n torch1.6 python==3.6.10
conda activate torch1.6
conda install pytorch torchvision cudatoolkit=10.2 # GPU(可选)
conda install pytorch torchvision cpuonly
pip install opencv-python
pip install matplotlib

首先需要把数据集下载之后进行解压,记住解压的路径,并在train.py的18行将数据集路径修改为你本地的数据集路径,修改之后执行运行train.py文件即可开始模型训练,训练之后的模型将会保存在models目录下。

模型训练部分则选用了大名鼎鼎的mobilenet,mobilenet是比较轻量的网络,在cpu上也可以运行的很快,训练的代码如下,首先通过pytorch的dataloader加载数据集,并加载预训练的mobilenet进行微调。

# coding:utf-8
# TODO 添加一个图形化界面
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
from old.train_based_torchvision import Net

names = ['其他垃圾_PE塑料袋', '其他垃圾_U型回形针', '其他垃圾_一次性杯子', '其他垃圾_一次性棉签', '其他垃圾_串串竹签', '其他垃圾_便利贴', '其他垃圾_创可贴', '其他垃圾_卫生纸',
         '其他垃圾_厨房手套', '其他垃圾_厨房抹布', '其他垃圾_口罩', '其他垃圾_唱片', '其他垃圾_图钉', '其他垃圾_大龙虾头', '其他垃圾_奶茶杯', '其他垃圾_干燥剂', '其他垃圾_彩票',
         '其他垃圾_打泡网', '其他垃圾_打火机', '其他垃圾_搓澡巾', '其他垃圾_果壳', '其他垃圾_毛巾', '其他垃圾_涂改带', '其他垃圾_湿纸巾', '其他垃圾_烟蒂', '其他垃圾_牙刷',
         '其他垃圾_电影票', '其他垃圾_电蚊香', '其他垃圾_百洁布', '其他垃圾_眼镜', '其他垃圾_眼镜布', '其他垃圾_空调滤芯', '其他垃圾_笔', '其他垃圾_胶带', '其他垃圾_胶水废包装',
         '其他垃圾_苍蝇拍', '其他垃圾_茶壶碎片', '其他垃圾_草帽', '其他垃圾_菜板', '其他垃圾_车票', '其他垃圾_酒精棉', '其他垃圾_防霉防蛀片', '其他垃圾_除湿袋', '其他垃圾_餐巾纸',
         '其他垃圾_餐盒', '其他垃圾_验孕棒', '其他垃圾_鸡毛掸', '厨余垃圾_八宝粥', '厨余垃圾_冰激凌', '厨余垃圾_冰糖葫芦', '厨余垃圾_咖啡', '厨余垃圾_圣女果', '厨余垃圾_地瓜',
         '厨余垃圾_坚果', '厨余垃圾_壳', '厨余垃圾_巧克力', '厨余垃圾_果冻', '厨余垃圾_果皮', '厨余垃圾_核桃', '厨余垃圾_梨', '厨余垃圾_橙子', '厨余垃圾_残渣剩饭', '厨余垃圾_水果',
         '厨余垃圾_泡菜', '厨余垃圾_火腿', '厨余垃圾_火龙果', '厨余垃圾_烤鸡', '厨余垃圾_瓜子', '厨余垃圾_甘蔗', '厨余垃圾_番茄', '厨余垃圾_秸秆杯', '厨余垃圾_秸秆碗',
         '厨余垃圾_粉条', '厨余垃圾_肉类', '厨余垃圾_肠', '厨余垃圾_苹果', '厨余垃圾_茶叶', '厨余垃圾_草莓', '厨余垃圾_菠萝', '厨余垃圾_菠萝蜜', '厨余垃圾_萝卜', '厨余垃圾_蒜',
         '厨余垃圾_蔬菜', '厨余垃圾_薯条', '厨余垃圾_薯片', '厨余垃圾_蘑菇', '厨余垃圾_蛋', '厨余垃圾_蛋挞', '厨余垃圾_蛋糕', '厨余垃圾_豆', '厨余垃圾_豆腐', '厨余垃圾_辣椒',
         '厨余垃圾_面包', '厨余垃圾_饼干', '厨余垃圾_鸡翅', '可回收物_不锈钢制品', '可回收物_乒乓球拍', '可回收物_书', '可回收物_体重秤', '可回收物_保温杯', '可回收物_保鲜膜内芯',
         '可回收物_信封', '可回收物_充电头', '可回收物_充电宝', '可回收物_充电牙刷', '可回收物_充电线', '可回收物_凳子', '可回收物_刀', '可回收物_包', '可回收物_单车', '可回收物_卡',
         '可回收物_台灯', '可回收物_吊牌', '可回收物_吹风机', '可回收物_呼啦圈', '可回收物_地球仪', '可回收物_地铁票', '可回收物_垫子', '可回收物_塑料制品', '可回收物_太阳能热水器',
         '可回收物_奶粉桶', '可回收物_尺子', '可回收物_尼龙绳', '可回收物_布制品', '可回收物_帽子', '可回收物_手机', '可回收物_手电筒', '可回收物_手表', '可回收物_手链',
         '可回收物_打包绳', '可回收物_打印机', '可回收物_打气筒', '可回收物_扫地机器人', '可回收物_护肤品空瓶', '可回收物_拉杆箱', '可回收物_拖鞋', '可回收物_插线板', '可回收物_搓衣板',
         '可回收物_收音机', '可回收物_放大镜', '可回收物_日历', '可回收物_暖宝宝', '可回收物_望远镜', '可回收物_木制切菜板', '可回收物_木桶', '可回收物_木棍', '可回收物_木质梳子',
         '可回收物_木质锅铲', '可回收物_木雕', '可回收物_枕头', '可回收物_果冻杯', '可回收物_桌子', '可回收物_棋子', '可回收物_模具', '可回收物_毯子', '可回收物_水壶',
         '可回收物_水杯', '可回收物_沙发', '可回收物_泡沫板', '可回收物_灭火器', '可回收物_灯罩', '可回收物_烟灰缸', '可回收物_热水瓶', '可回收物_燃气灶', '可回收物_燃气瓶',
         '可回收物_玩具', '可回收物_玻璃制品', '可回收物_玻璃器皿', '可回收物_玻璃壶', '可回收物_玻璃球', '可回收物_瑜伽球', '可回收物_电动剃须刀', '可回收物_电动卷发棒',
         '可回收物_电子秤', '可回收物_电熨斗', '可回收物_电磁炉', '可回收物_电脑屏幕', '可回收物_电视机', '可回收物_电话', '可回收物_电路板', '可回收物_电风扇', '可回收物_电饭煲',
         '可回收物_登机牌', '可回收物_盒子', '可回收物_盖子', '可回收物_盘子', '可回收物_碗', '可回收物_磁铁', '可回收物_空气净化器', '可回收物_空气加湿器', '可回收物_笼子',
         '可回收物_箱子', '可回收物_纸制品', '可回收物_纸牌', '可回收物_罐子', '可回收物_网卡', '可回收物_耳套', '可回收物_耳机', '可回收物_衣架', '可回收物_袋子', '可回收物_袜子',
         '可回收物_裙子', '可回收物_裤子', '可回收物_计算器', '可回收物_订书机', '可回收物_话筒', '可回收物_豆浆机', '可回收物_路由器', '可回收物_轮胎', '可回收物_过滤网',
         '可回收物_遥控器', '可回收物_量杯', '可回收物_金属制品', '可回收物_钉子', '可回收物_钥匙', '可回收物_铁丝球', '可回收物_铅球', '可回收物_铝制用品', '可回收物_锅',
         '可回收物_锅盖', '可回收物_键盘', '可回收物_镊子', '可回收物_闹铃', '可回收物_雨伞', '可回收物_鞋', '可回收物_音响', '可回收物_餐具', '可回收物_餐垫', '可回收物_饰品',
         '可回收物_鱼缸', '可回收物_鼠标', '有害垃圾_指甲油', '有害垃圾_杀虫剂', '有害垃圾_温度计', '有害垃圾_灯', '有害垃圾_电池', '有害垃圾_电池板', '有害垃圾_纽扣电池',
         '有害垃圾_胶水', '有害垃圾_药品包装', '有害垃圾_药片', '有害垃圾_药瓶', '有害垃圾_药膏', '有害垃圾_蓄电池', '有害垃圾_血压计']


class MainWindow(QTabWidget):
    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('imgs/面性铅笔.png'))
        self.setWindowTitle('垃圾识别')
        # 加载网络
        self.net = torch.load("models/mobilenet_trashv1_2.pt", map_location=lambda storage, loc: storage)
        self.transform = transforms.Compose(
            # 这里只对其中的一个通道进行归一化的操作
            [transforms.Resize([224, 224]),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

        self.resize(800, 600)
        self.initUI()

    def initUI(self):
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("测试样本")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        self.predict_img_path = "imgs/img111.jpeg"
        img_init = cv2.imread(self.predict_img_path)
        img_init = cv2.resize(img_init, (400, 400))
        cv2.imwrite('imgs/target.png', img_init)
        self.img_label.setPixmap(QPixmap('imgs/target.png'))
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" 上传垃圾图像 ")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" 识别垃圾种类 ")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)

        label_result = QLabel(' 识 别 结 果 ')
        self.result = QLabel("待识别")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)

        # 关于页面
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用智能垃圾识别系统')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('imgs/logoxx.png'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel()
        label_super.setText("<a href='https://blog.csdn.net/ECHOSON'>我的个人主页</a>")
        label_super.setFont(QFont('楷体', 12))
        label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)
        # git_img = QMovie('images/')
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)
        self.addTab(main_widget, '主页面')
        self.addTab(about_widget, '关于')
        self.setTabIcon(0, QIcon('imgs/面性计算器.png'))
        self.setTabIcon(1, QIcon('imgs/面性本子vg.png'))

    def change_img(self):
        openfile_name = QFileDialog.getOpenFileName(self, '选择文件', '', 'Image files(*.jpg , *.png, *.jpeg)')
        print(openfile_name)
        img_name = openfile_name[0]
        if img_name == '':
            pass
        else:
            self.predict_img_path = img_name
            img_init = cv2.imread(self.predict_img_path)
            img_init = cv2.resize(img_init, (400, 400))
            cv2.imwrite('imgs/target.png', img_init)
            self.img_label.setPixmap(QPixmap('imgs/target.png'))

    def predict_img(self):
        # 预测图片
        # 开始预测
        # img = Image.open()
        transform = transforms.Compose(
            [transforms.Resize([224, 224]),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

        img = Image.open(self.predict_img_path)
        RGB_img = img.convert('RGB')
        img_torch = transform(RGB_img)
        img_torch = img_torch.view(-1, 3, 224, 224)
        outputs = self.net(img_torch)
        _, predicted = torch.max(outputs, 1)
        result = str(names[predicted[0].numpy()])

        self.result.setText(result)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    x = MainWindow()
    x.show()
    sys.exit(app.exec_())

测试

模型训练好之后就可以进行模型的测试了,其中test4dataset.py文件主要是对数据集进行测试,也就是解压之后的test目录下的所有文件进行测试,那么test4singleimg.py文件主要是对单一的图片进行测试。

考虑到大家可能想省去训练的过程,所以我在models目录下放了我训练好的模型,你可以直接使用我训练好的模型进行测试,目前在测试集上的准确率大概在80%左右,不是很高,但是也足够使用。

另外,处理基本的测试之外,还有分类别的测试以及heatmap形式的演示,这部分的代码写的比较乱,暂时放在了abandon目录下,如果项目的star超过100的话,我会再更新这部分的内容。以下就是部分测试的代码;

# from train import load_data
from PIL import ImageFile
import torch
import os
from torchvision import transforms, datasets
import numpy as np
from torch.utils.data import Dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# 有些图片信息不全,不能读取,跳过这些图片
ImageFile.LOAD_TRUNCATED_IMAGES = True
np.set_printoptions(suppress=True)

# todo
def load_test_data(data_dir="E:/遥感目标检测数据集/垃圾分类数据集/trash_real_split"):
    data_transforms = {
        'val': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['val', 'test']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                                  shuffle=True, num_workers=0)
                   for x in ['val', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['val', 'test']}
    class_names = image_datasets['test'].classes
    return dataloaders, dataset_sizes, class_names


def test_test_dataset(model_path="models/mobilenet_trashv1_2.pt"):
    # 加载模型
    net = torch.load(model_path, map_location=lambda storage, loc: storage)
    dataloaders, dataset_sizes, class_names = load_test_data()
    testloader = dataloaders['test']
    test_size = dataset_sizes['test']
    net.to(device)
    net.eval()
    # 测试全部的准确率
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += torch.sum(predicted == labels.data)
    correct = correct.cpu().numpy()
    print('Accuracy of the network on the %d test images: %d %%' % (test_size,
                                                                    100 * correct / total))


def test_test_dataset_by_classes(model_path="models/mobilenet_trashv1_2.pt"):
    # 加载模型
    net = torch.load(model_path, map_location=lambda storage, loc: storage)
    dataloaders, dataset_sizes, class_names = load_test_data()
    testloader = dataloaders['test']
    test_size = dataset_sizes['test']
    net.to(device)
    net.eval()
    classes = class_names
    # 测试每一类的准确率
    class_correct = list(0. for i in range(len(class_names)))
    class_total = list(0. for i in range(len(class_names)))
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    for i in range(len(class_names)):
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))


if __name__ == '__main__':
    print('模型在整个数据集上的表现:')
    test_test_dataset()
    print('模型在每一类上的表现:')
    test_test_dataset_by_classes()

测试结果如下图:

image-20210305140104656

图形化界面

图形化界面主要通过Pyqt5来进行开发,主要是完成一些上传图片,对图片进行识别并把识别结果进行输出的功能,俺的审美不是很好,所以设计的界面可能不是很好看,大家后面可以根据自己的需要修改界面。

image-20210305142518950

image-20210305142537660

image-20210305142602138

代码链接

代码链接:trash_torch1.5: 基于pyotrch开发的垃圾分类程序! (gitee.com)

如果你觉得这个项目帮助了你,可以请我喝杯咖啡😊

wxs

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐