项目场景:

跑U-net网络的时候,有一步是torch.cat()操作,出现在这里插入图片描述
下面是代码

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        self.conv1 = DoubleConv(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(128, 256)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(256, 512)
        self.up6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6 = DoubleConv(512, 256)
        self.up7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7 = DoubleConv(256, 128)
        self.up8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8 = DoubleConv(128, 64)
        self.up9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9 = DoubleConv(64, 32)
        self.conv10 = nn.Conv2d(32, out_channels, 1)

    def forward(self, x):
        #print(x.shape)
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        #print(p1.shape)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        #print(p2.shape)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        #print(p3.shape)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        #print(p4.shape)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        #print(up_7.shape,c3.shape)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = nn.Sigmoid()(c10)
        return out

问题描述

本来是在跑彩色图像的去噪任务,想尝试一些灰度图像的去噪。刚开始只是修改了img_channel,将3改为1,想当然的将crop_img_size改为了灰度图像数据集的统一尺寸,没有考虑pool(2)操作后和上采样后的对齐问题。
比如5/2=2
但是2*2=4
这时候,4和5就对不上了。


原因分析:

torch.cat()函数的功能是将多个tensor类型矩阵的连接。它有两个参数,第一个是tensor元组或者tensor列表;第二个是dim,如果tensor是二维的,dim=0指在行上连接,dim=1指在列上连接。

注意:torch.cat 进行连接的tensor的shape,除了需要连接的维度上的shape值可不同,必须拥有相同的shape,a是(2,3),b是(2,20)即torch.cat((a,b),-1)可以进行连接;torch.cat((a,b),0)不可以进行连接,因为3和20值不同

那么问题找到了,就是维度没有对上
为什么没有对上呢?先开始跑彩色图像的时候crop_img_size是256,稳稳的在2的幂结果上,怎么除2都不会有余数,再乘回去也没有误差。
但是我改成了180后,180/2/2=45,45/2=22,这时候22上采样完是44,44与45不齐了,问题产生。


解决方案:

尺寸crop_img_size改为2的幂结果,比如128,这应该也是为什么我们看到的大部分输入的size都是256、512等等这些数的原因,为了能整除,没有余数。

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐