grid_sample用于在pytprch的tensor中做不规则采样,下例从一个5*5的空间a中进行4点采样,采样坐标系为(-1,1),grid取了4个角的点,为了便于观察,填充方式为取最接近的点的取值。可以看到,a空间左上角坐标为[-1,-1],右下角坐标为[1,1]。取值方式有‘nearest’、‘bilinear’。当grid为非平均间隔的坐标点时,即可实现不规则采样

import cv2
# import torch_geometric
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


a = np.arange(25).reshape(5,5)
a=torch.FloatTensor(list(a)).unsqueeze(0).unsqueeze(0)
grid = torch.tensor([[[-1.0,-1.0],[1.0,-1.0]],[[-1.0,1.0],[1.0,1.0]]]).unsqueeze(0)
# # 目的是得到一个 长宽为2的tensor
# out_h = 2
# out_w = 2
#  # grid的生成方式等价于用mesh_grid
#  # 都是(-1,1)之间插值
# new_h = torch.linspace(-1, 1, out_h).view(-1, 1).repeat(1, out_w)
# new_w = torch.linspace(-1, 1, out_w).repeat(out_h, 1)
# grid = torch.cat((new_h.unsqueeze(2), new_w.unsqueeze(2)), dim=2)

outp = F.grid_sample(a, grid=grid, mode='nearest')   # mode用bilinear时,会和周围pad的0平均掉,不便于直接观察

print(a.size())   # (N, C, H, W)
print(grid.size())   # (N, H, W, 2)
print(outp.size())   # (N, C, H, W) [1, 1, 2, 2]

print(a)
print(grid)
print(outp)

"""
torch.Size([1, 1, 5, 5])
torch.Size([1, 2, 2, 2])
torch.Size([1, 1, 2, 2])
tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]])
tensor([[[[-1., -1.],
          [ 1., -1.]],

         [[-1.,  1.],
          [ 1.,  1.]]]])
tensor([[[[ 0.,  4.],
          [20., 24.]]]])
          
"""

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐