torch.utils.data.dataset.random_split随机划分后对划分后数据处理

在使用torch.utils.data.dataset.random_split后,生成同属于Dataset类型的Subset类,若想对划分后的训练集(train)和验证集(validation)再进行处理,只需对train_set对象进行浅拷贝即可改变类内属性。

         data_set = MySegmentation(cfg, split='train')
	     # data_set.change_split()
	     n_val = int(len(data_set) * cfg["train"]["val_percent"])
	     n_train = len(data_set) - n_val
	     train_set, val_set = random_split(data_set, [n_train, n_val])
	     # 对划分后的数据集浅拷贝,修改默认类内属性
  		 train_set.dataset = copy(data_set)
         val_set.dataset.split = "val"
	     print("train dataset:{}".format(len(train_set)))
	     print("validation dataset:{}".format(len(val_set)))
	     train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True,**kwargs)
	     val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last=True, **kwargs)
	     return train_loader, val_loader

在随机划分数据集后,数据集为subset类,包含两个对象dataset和indices,其中indice为对应随机抽出的索引位置。其中dataset在划分的 训练集 和 测试集中数据仍指向相同地址,改变其中一个对象属性则都会全部修改。
在这里插入图片描述
通过使用浅拷贝的方法,改变其中一个数据集的指向地址,再对属性进行修改,就会修改指定的对象属性。
在这里插入图片描述
上图则是浅拷贝之后改变数据集地址。

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐