F.grid_sample
grid_sample用于在pytprch的tensor中做不规则采样,下例从一个5*5的空间a中进行4点采样,采样坐标系为(-1,1),grid取了4个角的点,为了便于观察,填充方式为取最接近的点的取值。可以看到,a空间左上角坐标为[-1,-1],右下角坐标为[1,1]。取值方式有‘nearest’、‘bilinear’。当grid为非平均间隔的坐标点时,即可实现不规则采样import cv2#
·
grid_sample用于在pytprch的tensor中做不规则采样,下例从一个5*5的空间a中进行4点采样,采样坐标系为(-1,1),grid取了4个角的点,为了便于观察,填充方式为取最接近的点的取值。可以看到,a空间左上角坐标为[-1,-1],右下角坐标为[1,1]。取值方式有‘nearest’、‘bilinear’。当grid为非平均间隔的坐标点时,即可实现不规则采样
import cv2
# import torch_geometric
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
a = np.arange(25).reshape(5,5)
a=torch.FloatTensor(list(a)).unsqueeze(0).unsqueeze(0)
grid = torch.tensor([[[-1.0,-1.0],[1.0,-1.0]],[[-1.0,1.0],[1.0,1.0]]]).unsqueeze(0)
# # 目的是得到一个 长宽为2的tensor
# out_h = 2
# out_w = 2
# # grid的生成方式等价于用mesh_grid
# # 都是(-1,1)之间插值
# new_h = torch.linspace(-1, 1, out_h).view(-1, 1).repeat(1, out_w)
# new_w = torch.linspace(-1, 1, out_w).repeat(out_h, 1)
# grid = torch.cat((new_h.unsqueeze(2), new_w.unsqueeze(2)), dim=2)
outp = F.grid_sample(a, grid=grid, mode='nearest') # mode用bilinear时,会和周围pad的0平均掉,不便于直接观察
print(a.size()) # (N, C, H, W)
print(grid.size()) # (N, H, W, 2)
print(outp.size()) # (N, C, H, W) [1, 1, 2, 2]
print(a)
print(grid)
print(outp)
"""
torch.Size([1, 1, 5, 5])
torch.Size([1, 2, 2, 2])
torch.Size([1, 1, 2, 2])
tensor([[[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.]]]])
tensor([[[[-1., -1.],
[ 1., -1.]],
[[-1., 1.],
[ 1., 1.]]]])
tensor([[[[ 0., 4.],
[20., 24.]]]])
"""
更多推荐
已为社区贡献1条内容
所有评论(0)