1.nn.Unfold()函数

描述:pytorch中的nn.Unfold()函数,在图像处理领域,经常需要用到卷积操作,但是有时我们只需要在图片上进行滑动的窗口操作,将图片切割成patch,而不需要进行卷积核和图片值的卷积乘法操作。这是就需要用到nn.Unfold()函数,该函数是从一个batch图片中,提取出滑动的局部区域块,也就是卷积操作中的提取kernel filter对应的滑动窗口。

torch.nn.Unfold(kernel_size,dilation=1,paddding=0,stride=1)

该函数的输入是(bs,c,h,w),其中bs为batch-size,C是channel的个数。
而该函数的输出是(bs,Cxkernel_size[0]xkernel_size[1],L)其中L是特征图或者图片的尺寸根据kernel_size的长宽滑动裁剪后得到的多个patch的数量。

import torch.nn as nn
import torch
batches_img=torch.rand(1,2,4,4)#模拟图片数据(bs,2,4,4),通道数C为2
print("batches_img:\n",batches_img)

nn_Unfold=nn.Unfold(kernel_size=(2,2),dilation=1,padding=0,stride=2)
patche_img=nn_Unfold(batches_img)
print("patche_img.shape:",patche_img.shape)
print("patch_img:\n",patche_img)

在这里插入图片描述
该方法的主要应用场景是将图片切割成不同的patch,配合一下代码实现

#上面的代码能够获取到patch_img,(bs,C*K*K,L),L代表的是将每张图片分割成多少块
reshape_patche_img=patche_img.view(batches_img.shape[0],batches_img.shape[1],2,2,-1)
print(reshape_patche_img.shape)#[bs, C, k, k, L]
reshape_patche_img=reshape_patche_img.permute(0,4,1,2,3)#[N, L, C, k, k]
print(reshape_patche_img.shape)

结果:
在这里插入图片描述

2.nn.Fold()函数

该函数是nn.Unfold()函数的逆操作。

fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patches)
print(inputs_restore)
print(inputs_restore.size())

在这里插入图片描述

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐