今天在训练网络的时候,考虑做一个实验需要将pytorch里面的某个Tensor沿着特征维度进行shuffle,之前考虑的是直接使用shuffle函数(random.shuffle),但是发现random.shuffle函数会导致Tensor里面数据被重复提取,进而导致输出跟输入分布发生变化。

a = torch.randint(0, 2, (3, 3))
print(a)
random.shuffle(a)
print(a)

因此,经过查找资料后,总结pytorch的tensor可以利用如下方式进行shuffle。

 

1、所有elements随机shuffle

先上代码,代码示例来源于https://discuss.pytorch.org/t/shuffling-a-tensor/25422/4

t=torch.tensor([[1, 2, 3],[3, 4, 5]])
print(t)
idx = torch.randperm(t.nelement())
t = t.view(-1)[idx].view(t.size())
print(t)

 

核心思路就是生成一个index,然后将index进行shuffle,最终利用shuffle后的index把Tensor进行重新排列。

 

2、按照某一特定维度进行shuffle

此外,如果想像我一样按照某一特定维度进行shuffle,那么可以利用如下代码:

t=torch.tensor([[1, 2, 3],[3, 4, 5]])
print(t)
idx = torch.randperm(t.shape[1])
t = t[:, idx].view(t.size())
print(t)

 

如果是想针对多维Tensor可以用同样的方法进行shuffle。

希望这篇文章对你有帮助。

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐