【Pytorch代码学习】——数据集划分
简介将数据集划分为训练集和测试集代码介绍目录文件目录存放格式运行前运行后代码import osfrom shutil import copy, rmtreeimport randomdef mk_file(file_path: str):if os.path.exists(file_path):# 如果文件夹存在,则先删除原文件夹在重新创建rmtree(file_path)os.makedirs(
·
简介
将数据集划分为训练集和测试集
代码介绍
目录
文件目录存放格式
-
运行前
-
运行后
代码
import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str):
if os.path.exists(file_path):
# 如果文件夹存在,则先删除原文件夹在重新创建
rmtree(file_path)
os.makedirs(file_path)
def main():
# 保证随机可复现
random.seed(0)
# 将数据集中10%的数据划分到验证集中
split_rate = 0.1# **在此处修改想要验证集数量**
# 指向你解压后的flower_photos文件夹
cwd = os.getcwd()
data_root = os.path.join(cwd, "flower_data")# **在此处修改存放数据集文件夹名称**
origin_flower_path = os.path.join(data_root, "flower_photos")# **在此处修改存放未划分数据集文件夹名称**
assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)
flower_class = [cla for cla in os.listdir(origin_flower_path)
if os.path.isdir(os.path.join(origin_flower_path, cla))]# **在此处修改存放分类数据集名称名称flower_class**
# 建立保存训练集的文件夹
train_root = os.path.join(data_root, "train")
mk_file(train_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(train_root, cla))
# 建立保存验证集的文件夹
val_root = os.path.join(data_root, "val")
mk_file(val_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(val_root, cla))
for cla in flower_class:
cla_path = os.path.join(origin_flower_path, cla)
images = os.listdir(cla_path)
num = len(images)
# 随机采样验证集的索引
eval_index = random.sample(images, k=int(num*split_rate))
for index, image in enumerate(images):
if image in eval_index:
# 将分配至验证集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(val_root, cla)
copy(image_path, new_path)
else:
# 将分配至训练集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(train_root, cla)
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar
print()
print("processing done!")
if __name__ == '__main__':
main()
代码学习
- assert
代码:
assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)
代码用来检查条件os.path.exists(origin_flower_path),不符合就终止程序,并输出后面的语句"path ‘{}’ does not exist.".format(origin_flower_path)
检查是否存在origin_flower_path目录,不符合就终止程序,输出path ’地址’ does not exist.
文件目录不存在的输出示例:
2.建立列表
代码
flower_class = [cla for cla in os.listdir(origin_flower_path)
if os.path.isdir(os.path.join(origin_flower_path, cla))]
以origin_flower_path文件目录下的文件名称按顺序建立列表
输出示例
print(flower_class)
3.文件操作
- os.getcwd()
在Python中可以使用os.getcwd()函数获得当前的路径。得到当前脚本的工作目录(并非脚本存放的绝对路径)
- os.path.join
用于路径拼接文件的路径
- os.path.abspath()
获得当前脚本存放的绝对路径。
- data_root = os.path.abspath(os.path.join(os.getcwd(), “…/…”)) # get data root path
"…/…"返回上上级目录
应用
按文件目录格式存储文件(分类名称文件夹名称按自己的数据集改,也就是daisy等文件夹的名称)后直接调用该函数直接可以划分出训练集和测试集
更多推荐
已为社区贡献1条内容
所有评论(0)