生成对抗网络系列
【生成对抗网络】GAN入门与代码实现(一)
【生成对抗网络】GAN入门与代码实现(二)
【生成对抗网络】基于DCGAN的二次元人物头像生成(TensorFlow2)
【生成对抗网络】ACGAN的代码实现

1. 生成对抗网络介绍

生成对抗网络(Generative Adversarial Network)于2014年被Goodfellow等人提出,然后迅速流行。GAN能通过学习特定领域知识创造出新的图像、文本等。2016年,GAN热潮席卷人工智能领域顶级会议,从ICLR到NIPS,大量论文被发表和探讨。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法”。

在GAN中主要由生成器G(generator)与判别器D(discriminator)构成。其中生成器用于生成逼真的假数据、判别器则需要在判别出真实数据与假数据,生成器与判别器相互博弈,在能力上有所提升,生成器生成的数据越来越像是真实的数据,判别器则能更好地将两者分辨出来,直到两者达到一种平衡。

假如以小狗图片作为生成的目标:

  • 生成器:接收一个随机噪声(随机变量)作为输入,输出一个小狗的图片(假图片)。
  • 判别器:将原真实的小狗图片和生成器生成的小狗图片两者区分出来,判断谁真谁假。

在模型训练的过程中:

​ 生成器:学习如何更好的将生成的小狗图片更加像真实,从而让判别器误认为是真实的。

​ 判别器:不断地将生成器生成的图片与真实的图片用于判别器模型的训练,提高自己的判别准确率。

GAN的整个训练过程如下:

  1. ​ 生成器接收随机噪声,并生成假图像;
  2. ​ 判别器接收假图像和真实图像组合的数据,学习如何判别真假图像;
  3. ​ 生成器生成新的图像,并使用判别器来判别真假,同时通过判别器来判别此次造假的水平;
  4. ​ 重复步骤 1-3。

2. 基于TensorFlow2的GAN的简单实现

我们以手写数据集MNIST为例进行演示。让GAN学习生成一些新的手写数字图片,每张图片的尺寸为28*28。
在这里插入图片描述

代码实现步骤如下:

  1. 定义生成器,接收随机噪声,输出图像张量
  2. 定义判别器,接收图像张量,输出真假张量
  3. 定义生成对抗网络,接收随机噪声,输出真假张量。生成对抗网络由前面定义的生成器的模型层和判别器的模型创建( 它们共享权重),同时需要冻结判别器的权重。
  4. 将随机噪声输入生成器,生成一批图像
  5. 使用生成的图像与真实图像训练判别器(假图像的目标为0,真图像的目标为1)
  6. 使用新随机噪声输入生成对抗网络,输出真假(使生成的假图像判别为1),提高“造假”水平
  7. 重复4-6步骤

2.1 导包与参数设置

import numpy as np # 用于数据处理
import tensorflow as tf # 版本2.0及以上
from tensorflow import keras # 主要使用keras实现
import tqdm # 进度条,使用pip install tqdm安装
import matplotlib.pyplot as plt # 绘图函数库
%matplotlib inline
LATENT_DTM = 100 # 随机噪声的长度
IMAGE_SHAPE = (28,28,1) # 手写数字图片的尺寸与通道数

2.2 生成器

生成器接收随机向量,然后通过模型生成一张手写数字图片。

关键点:

  • 使用随机噪声作为输入,保证模型具有一定的随机性
  • 使用tanh作为最后一层的激活函数,可以获得更好的效果
  • 使用LeakyReLU激活函数来代替ReLU激活函数
generator_net = [
    keras.layers.Input(shape=(LATENT_DTM,)), # 输入为长度100点随机向量
    keras.layers.Dense(256),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(512),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(1024),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.BatchNormalization(momentum = 0.8),
    keras.layers.Dense(np.prod(IMAGE_SHAPE),activation='tanh'),
    keras.layers.Reshape(IMAGE_SHAPE) #  将向量重塑shape为(28,28,1),输出图片
]
generator = keras.models.Sequential(generator_net)

2.3 判别器

判别器是一个二分类问题,接收一个图片,输出真假。

discriminator_net =[
    keras.layers.Input(shape=IMAGE_SHAPE),
    keras.layers.Flatten(),
    keras.layers.Dense(512),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.Dense(256),
    keras.layers.LeakyReLU(alpha = 0.2),
    keras.layers.Dense(1,activation='sigmoid')
]
discriminator = keras.models.Sequential(discriminator_net)

优化器:

optimizer = keras.optimizers.Adam(0.0002,0.5)

模型编译:

discriminator.compile(loss=keras.losses.binary_crossentropy,optimizer=optimizer,metrics=['acc'])

2.4 搭建生成对抗网络

将生成器与判别器组合在一起,同时冻结判别器的权重。

该过程将生成器生成的图片直接送入判别器模型,从而直接输出结果。在该网络中,需要冻结判别器的权重,因为我们需要在此过程中训练生成器,让判别器的结果输出为“真”,从而不断完善生成器生成图像的水平,所以只需要训练生成器的层。

# 生成对抗网络使用生成器模型层和判别器模型层,它们共享权重。
adversarial_net = generator_net + discriminator_net

# 冻结判别器的层的权重
# trainable 属性只有编译后才生效,所以之前的判别器中同样的层还是可以训练的
for layer in discriminator_net:
    layer.trainable = False
adversarial = keras.models.Sequential(adversarial_net)

优化器:

optimizer = keras.optimizers.Adam(0.0002,0.5) # 优化器

模型编译:

adversarial.compile(loss=keras.losses.binary_crossentropy,optimizer=optimizer,metrics=['acc']) # 模型编译

2.5 数据准备与预处理

加载keras中内置的手写数据集

(image_set,_),_ = keras.datasets.mnist.load_data() # 加载数据集
image_set = image_set/127.5 - 1 # shape为(60000,28,28)
image_set = image_set.reshape((image_set.shape[0],28,28,1)) # shape为(60000,28,28,1)

准备训练过程中可视化的随机向量seed

num_example_to_generate = 6 # 用于绘图过程中生成图片的数量

seed = np.random.normal(0,1,(num_example_to_generate,LATENT_DTM)) # 生成6个长度为100的随机向量

用于记录训练过程中的准确率与损失

# 损失
g_loss_list = [] # 生成器
d_loss_list = [] # 判别器
# 准确率
g_acc_list = [] # 生成器
d_acc_list = [] # 判别器

2.6 主训练方法

def train(batch = 30000,batch_size = 300):
    # 准备batch_size大小的真假数据标签
    valid = np.ones((batch_size)) # 全是1
    fake = np.zeros((batch_size)) # 全是0
    
    # 使用进度条tqdm库
    batch_tqdm = tqdm.trange(batch)
    for index in batch_tqdm:
        
        # 随机选择batch_size数量的数据作为训练数据
        idx = np.random.randint(0,image_set.shape[0],batch_size)
        imgs = image_set[idx] 
        
        # 生成噪声数据并作为生成器的输入
        noise = np.random.normal(0,1,(batch_size,LATENT_DTM))
        
        # 使用生成器生成图像
        gen_imgs = generator.predict(noise)
        
        # 训练判别器
        # 使用真实图像和生成图像训练判别器,真实图像的标签全部为1,生成图像的标签全部为0
        d_state_real = discriminator.train_on_batch(imgs,valid) # 返回的是loss和acc
        d_state_fake = discriminator.train_on_batch(gen_imgs,fake)
        # 判别器在生成图像与真实图像两者的结果取平局值
        d_state = 0.5*(np.add(d_state_real,d_state_fake))
        
        # 训练判别器
        noise = np.random.normal(0,1,(batch_size,LATENT_DTM))
        # 训练生成对抗网络,目标是生成判别器人物真实的图像,因此标签为1
        # 因为生成对抗网络中的判别器的层都冻结了,所以实际上在训练生成器,不断生成更加逼真的图像
        adv_state = adversarial.train_on_batch(noise,valid)
        
        # 更新进度条后缀文本,用于输出训练进度
        state = f"[D loss:{d_state[0]:.4f} acc: {d_state[1]:.4f}]" \
                f"[G loss:{adv_state[0]:.4f} acc: {adv_state[1]:.4f}]"
        batch_tqdm.set_postfix(state=state)
        # 存储损失值和准确率
        g_loss_list.append(adv_state[0])
        g_acc_list.append(adv_state[1])
        d_loss_list.append(d_state[0])
        d_acc_list.append(d_state[1])
        
        if index%500 == 0: # 每500次绘图一次
            generate_plot_image(seed) # 绘图函数,每次都用同一个随机噪声seed生成图片,可以看到数字的变化

注意model的train_on_batch方法的使用。

2.7 绘图函数

用固定的noise绘制6张图片,以便观察训练效果。

# 画图函数
def generate_plot_image(test_noise):

    pre_image = generator(test_noise,training = False) # 用生成器,生成手写图片
    # print(pre_image.shape) # (6,28,28,1)
    fig = plt.figure(figsize=(16,3)) # figsize:指定figure的宽和高,单位为英寸
    for i in range(pre_image.shape[0]):   # pre_image的shape的第一个维度就是个数,这里是6
        plt.subplot(1,6,i+1) # 几行几列的 第i+1个图片(从1开始)
        plt.imshow((pre_image[i,:,:,:] + 1)/2) # 加1除2: 将生成的-1~1的图片弄到0-1之间,
        plt.axis('off') # 不要坐标
    plt.show()

2.8 开始训练

训练30000个batch,每个batch随机拿出300个图片用于训练。

batch = 30000
batch_size = 300
train(batch,batch_size)

2.9 loss与acc绘图

损失Loss:

plt.plot(range(1, batch+1), g_loss_list, label='g_loss')
plt.plot(range(1, batch+1), d_loss_list, label='d_loss')
plt.legend()

准确率Acc:

plt.plot(range(1, batch+1), g_acc_list, label='g_acc')
plt.plot(range(1, batch+1), d_acc_list, label='d_acc')
plt.legend()

2.10 结果

可以看到生成器生成图片的效果越来越好

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

loss:

在这里插入图片描述

acc:

更新GAN的另一种实现方法:使用TensorFlow2中求导机制进行自定义训练的GAN代码实现,可对比进行学习。
博客链接:【生成对抗网络】GAN入门与代码实现(二)

参考文献:《TensorFlow2实战》艾力

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐