GAN生成动漫头像
这篇文章主要是听的李宏毅老师的GAN课程,结合了一些《深度学习框架pytorch入门与实践》中的代码实现的。
GAN原理简介
GAN(Generative adversarial Networks)生成对抗网络,GAN解决了一个著名问题:给定一批样本,训练一个系统可以生成类似的新样本。生成对抗网络,顾名思义,有两个部分一个是生成器(Generator),一个判别器(Discriminator),两者相互对抗,左右互博。
- 生成器(Generator):输入一个随机噪声,生成一张图片
- 判别器(Discriminator):判别图像是真图片还是假图片
如图展示了基本训练过程,第一代的生成器最开始由于是随机初始化,传入一个随机噪点,产生的图片是模糊一片,第一代判别器做的事情就是,能够辨别图片是第一代的生成器产生的图片,还是真实图片。可能第一代判别器认为有颜色的是真图,而生成器要骗过判别器,所以要进化,进化成了第二代,产生有颜色的图片。判别器也随之进化,找出有嘴巴的是真图,辨别出了生成图和真图。那么生成器也进化到第三代,产生了嘴巴的图片,骗过第二代的判别器,判别器又进化成第三代。一步一步对抗,相互进化,最终生成二次元头像。
算法过程
训练Discriminator
首先初始化生成器和判别器,在每次训练迭代中,先固定住生成器,传入随机噪点给生成器,生成对应的图片。前面说了判别器的任务是判别真假图片,判别器主要是给图片打分,分数越高认为真图片的概率越高,分数越低认为是假图片概率越高。所以训练判别器时,收到Database中的图和生成器给的图,调整参数,如果是Database的图就给高分,如果是生成器给的图就给低分。换句话说,我们训练处的判别器就是要数据集给的图越接近1越好,生成器给的图越接近0越好。
训练Generator
前面训练好了判别器,现在训练生成器,生成器需要做的事情是,收到随机噪点,产生一张图片,要“骗”过判别器。如何“骗”过生成器呢。实际上的做法是,将生成的图片,送到判别器中进行评分,目标是让评分越高越好(越接近1越好)。整个部分可以看做整体,就是一个巨大的hidden layer,其中有Generator和Discriminator,输入噪点,生成评分。调整参数让评分不断靠近1,但是中间部分Discriminator不能调整。
整个过程
- 从数据集中选出m个样本{x1,x2….xm};接着创建m个噪声,这个z的维度由自己决定,这个就是后面生成器输入的噪声;获取生成数据,就是G(z)生成器生成的;更新判别器参数,使得下面的公式最大化,下边公式的意思是:D(x)判别器判别真图片的分数去log平均值 加上 1 - 判别器判别假图片的值 的log平均值。简单说就是要判别真图片的分数越大越好,判别假图片时,假图片的评分距离1的值越远越好。(相当于训练二分类器,用bceloss)
- m个噪点z,更新生成器,生成器根据z产生的图片,喂给判别器,产生的分数越高越好。
BCE Loss
bce loss分类,用于二分类问题。数学公式如下
pytorch中bceloss
1 | class torch.nn.BCELoss(weight: Optional[torch.Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') |
weight: 初始化权重矩阵
size_average: 默认是True,对loss求平均数
reduction: 默认求和, 对于batch_size的loss平均数
代码实现
前面的理论了解之后,可以着手实现了我这里用的数据集是Extra Data,也可以用Anime Dataset尝试。请自行找梯子。
初始化
1 | import matplotlib.pyplot as plt |
生成器
1 | class Generator(nn.Module): |
判别器
1 | class Discriminator(nn.Module): |
优化器
1 | generator = Generator().to(device) |
损失函数
1 | def loss_g_func(testLabel, trueLabel): |
开始训练
1 | # 训练Discriminator |
结果
这个是1个Epoch的效果
5个Epoch
10个Epoch
25个Epoch