pytorch中的save_image
pytorch中的save_imagepytorch保存图片主要用到 make_grid 和 save_image这两个方法。制作网格:torchvision.utils.make_grid(tensor, nrow, padding)# 参数说明# tensor(tensor or list):四维 (B x C x H x W) mini-batch的tensor数据或者是包含同一尺寸的图片列
·
pytorch保存图片主要用到 make_grid 和save_image这两个方法。
制作网格:
torchvision.utils.make_grid(tensor, nrow, padding)
# 参数说明
# tensor(tensor or list):四维 (B x C x H x W) mini-batch的tensor数据或者是包含同一尺寸的图片列表。
# nrow(int):网格每行图片的个数,默认是8;千万不要理解为图片的行数。
# padding(int):四周填充的宽度,默认是2,可以理解为网格中图片之间的间距。默认填充值是0,也就是黑色。
注:这是三个比较常用的参数,其它参数请参考官方文档(https://pytorch.org/tutorials/)。
保存本地:
用pytorch提供的save_image方法tensor
数据类型保存时不用再转为PIL.Image
或numpy.ndarray
torchvision.utils.save_image(tensor, fp)
# 参数
# tensor(Tensor or list):待保存的tensor数据(可以是上述处理好的grid)。如果给以一个四维的batch的tensor,将调用网格方法,然后再保存到本地。最后保存的图片是不带batch的。
# fp:图片保存路径
代码示例:
import torch
from torch.utils.data import DataLoader, dataloader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
# 下载minist数据
dataset = datasets.MNIST(
root='./data/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
# 加载minist数据
dataloader = DataLoader(
dataset=dataset,
batch_size=8,
shuffle=True
)
# 保存图片
images, labels = next(iter(dataloader))
print(images.size()) # torch.Size([8, 1, 28, 28])
images = make_grid(images, 4, 0)
print(images.size()) # torch.Size([3, 84, 84])
save_image(images, 'D:\maozan1\Desktop\JDWork\\vscode\pytorch-demo\\test.jpg')
更多推荐
已为社区贡献1条内容
所有评论(0)