pytorch保存图片主要用到 make_grid 和save_image这两个方法。
制作网格:
torchvision.utils.make_grid(tensor, nrow, padding) 
# 参数说明
# tensor(tensor or list):四维 (B x C x H x W) mini-batch的tensor数据或者是包含同一尺寸的图片列表。
# nrow(int):网格每行图片的个数,默认是8;千万不要理解为图片的行数。
# padding(int):四周填充的宽度,默认是2,可以理解为网格中图片之间的间距。默认填充值是0,也就是黑色。
注:这是三个比较常用的参数,其它参数请参考官方文档(https://pytorch.org/tutorials/)
保存本地:

用pytorch提供的save_image方法tensor数据类型保存时不用再转为PIL.Imagenumpy.ndarray

torchvision.utils.save_image(tensor, fp)
# 参数
# tensor(Tensor or list):待保存的tensor数据(可以是上述处理好的grid)。如果给以一个四维的batch的tensor,将调用网格方法,然后再保存到本地。最后保存的图片是不带batch的。
# fp:图片保存路径
代码示例:
import torch
from torch.utils.data import DataLoader, dataloader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

# 下载minist数据
dataset = datasets.MNIST(
    root='./data/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)

# 加载minist数据
dataloader = DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True
)

# 保存图片
images, labels = next(iter(dataloader))
print(images.size())  # torch.Size([8, 1, 28, 28])
images = make_grid(images, 4, 0)
print(images.size())  # torch.Size([3, 84, 84])
save_image(images, 'D:\maozan1\Desktop\JDWork\\vscode\pytorch-demo\\test.jpg')
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐