torch.roll 函数的理解
torch.roll 函数官方解释翻译torch.roll(input, shifts, dims=None) → Tensorinput (Tensor) —— 输入张量。shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果移位是一个元组,dims必须是一个相同大小的元组,并且每个维度将移位相应的值。dims (int 或 tuple
·
如果是看swin-transformer进来的,推荐看看GitHub上的这个问题,会很有帮助!
https://github.com/microsoft/Swin-Transformer/issues/38
翻译
torch.roll(input, shifts, dims=None) → Tensor
- input (Tensor) —— 输入张量。
- shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
- dims (int 或 tuple of python:int) 确定的维度。
沿给定维数滚动张量,移动到最后一个位置以外的元素将在第一个位置重新引入。如果没有指定尺寸,张量将在轧制前被压平,然后恢复到原始形状。
官方例子
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
'''第0维度向下移1位,多出的[7,8]补充到顶部'''
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
[1, 2],
[3, 4],
[5, 6]])
'''第0维度向上移1位,多出的[1,2]补充到底部'''
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
[5, 6],
[7, 8],
[1, 2]])
'''tuple元祖,维度一一对应:
第0维度向下移2位,多出的[5,6][7,8]补充到顶部,
第1维向右移1位,多出的[6,8,2,4]补充到最左边'''
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
[8, 7],
[2, 1],
[4, 3]])
简单理解:shifts的值为正数相当于向下挤牙膏,挤出的牙膏又从顶部塞回牙膏里面;shifts的值为负数相当于向上挤牙膏,挤出的牙膏又从底部塞回牙膏里面
以下一个多维张量的例子(参考swin transformer论文源码):
torch.roll(x, shifts=(-20, -20), dims=(1, 2))
完整代码
import torch
import numpy as np
import matplotlib.pyplot as plt
shift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)
if shift_size > 0:
shifted_x = torch.roll(x, shifts=(-20, -20), dims=(1, 2))
#shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
print("---------经过循环位移了---------")
else:
shifted_x = x
'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):
plt.title("non_shifted")
else:
plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()
点击阅读全文
更多推荐
活动日历
查看更多
直播时间 2025-02-26 16:00:00


直播时间 2025-01-08 16:30:00


直播时间 2024-12-11 16:30:00


直播时间 2024-11-27 16:30:00


直播时间 2024-11-21 16:30:00


目录
所有评论(0)