如何在Pytorch中对Tensor进行shuffle
今天在训练网络的时候,考虑做一个实验需要将pytorch里面的某个Tensor沿着特征维度进行shuffle,之前考虑的是直接使用shuffle函数(random.shuffle),但是发现random.shuffle函数会导致Tensor里面数据被重复提取,进而导致输出跟输入分布发生变化。a = torch.randint(0, 2, (3, 3))print(a)random.shuffle(
·
今天在训练网络的时候,考虑做一个实验需要将pytorch里面的某个Tensor沿着特征维度进行shuffle,之前考虑的是直接使用shuffle函数(random.shuffle),但是发现random.shuffle函数会导致Tensor里面数据被重复提取,进而导致输出跟输入分布发生变化。
a = torch.randint(0, 2, (3, 3))
print(a)
random.shuffle(a)
print(a)
因此,经过查找资料后,总结pytorch的tensor可以利用如下方式进行shuffle。
1、所有elements随机shuffle
先上代码,代码示例来源于https://discuss.pytorch.org/t/shuffling-a-tensor/25422/4
t=torch.tensor([[1, 2, 3],[3, 4, 5]])
print(t)
idx = torch.randperm(t.nelement())
t = t.view(-1)[idx].view(t.size())
print(t)
核心思路就是生成一个index,然后将index进行shuffle,最终利用shuffle后的index把Tensor进行重新排列。
2、按照某一特定维度进行shuffle
此外,如果想像我一样按照某一特定维度进行shuffle,那么可以利用如下代码:
t=torch.tensor([[1, 2, 3],[3, 4, 5]])
print(t)
idx = torch.randperm(t.shape[1])
t = t[:, idx].view(t.size())
print(t)
如果是想针对多维Tensor可以用同样的方法进行shuffle。
希望这篇文章对你有帮助。
更多推荐
已为社区贡献1条内容
所有评论(0)