Cycle GAN学习及实现
之前看的AnoSeg的那篇论文,其中使用了两个Generator,但并不太清楚另一个Generator的作用。最近找到一个蛮有意思的GAN,就是CycleGAN做的是风格迁移。这个GAN中也使用了两个Generator,想学习一下,从中看看有没有帮助理解的地方。也是挺好玩的一个GAN,下面是我训练100个Epoch出来的效果。
马—->斑马
斑马——>马
Cycle GAN原理
CycleGAN主要是做的在没有成对的图像情况下,弄清楚如何将一个图像的特征转化为另一个图像的特征。假设X是A类型的图,Y是B类型的图。要训练一种Generator,将X的原图生成新的y,让判别器D无法分辨生成的y和Y,认为他们两个是同一类。然而实际操作中发现无法单独优化判别器。然后,从翻译中找到的灵感,假如从英语翻译到法语,再从法语翻译回英语,那么原来的英语应该和翻译回的英语相同。
就是说CycleGAN当中有两个Generator,一个将X类转为Y类,一个是将Y类转为X类。对应的也有两个Discriminator,一个是认X类的,一个是认Y类的。目标是把X转成Y。
如果只训练一个Generator直接转换的话存在一个问题
这里生成器生成一张图片,判别器的任务是判别是否是Y图像,生成器的任务是生成图片尽可能的变成Y类型的图。但问题是,从X->Y有无数种映射可以完成X->Y。生成器可以完全无视输入条件,直接生成一个与原始图像无关的东西,但也被判别器认为是Y类型的图。因此要求生成器生成的图片不光要欺骗过判别器,同时还要与原图像有一定的关系。
CycleGan的做法:
为了保证生成图像和原始图像有关联,它使用了一个G(x->y)先将x原图转为fake_y的图,然后用G(y->x)将生成的fake_y图转回x类型的图,这两张x图应该是尽可能相同,因为应当是同一张x图。对于G(y->x)也是同理的训练方法。这个模型可以视为两个自动编码器(auto-encoder)。
训练生成器的时候,就和普通的GAN是一样的,用的对抗损失。x -> G(x) -> F(G(x)) ≈ x,两个x应当尽可能相等,使用的是循环一致性损失(cycle-consistency loss)
训练过程
训练过程中使用了3个loss,主要是对抗损失(adversarial loss)、循环一致性损失(Cycle consistency loss)、身份损失(identity loss)。其中身份损失可以不加。不使用Identityloss会使得两个生成器自由改变输入图像的色调,但是会一定程度上提高训练速度。开头的我跑的效果图是使用了identity loss的效果。
对抗损失
对于Generator,将D(G(x))的得分尽量趋于1,对于Discriminator将D(x)得分趋于1,D(G(x))尽量趋于0。论文中第4部分Implementation中的训练细节中说道,使用最小二乘损失取代原来的负对数似然损失,使用最小二乘损失效果更好。
循环一致性损失
就是前面说过的例子,英语->法语->英语的过程。前后英语应当是一样的。应用到图片上,公式如下:
损失函数如下
将原始的x和经过G生成y,y在经过F生成的新的x之间使用L1Loss。然后对y也是这个操作,两个loss求和。
身份损失
这个损失主要是解决色调问题。不加影响不大。
就是将生成器生成的图像和原图做一次L1loss,然后相加
完整模型对象
两个生成器的对抗损失和一个循环一致性损失乘上λ加和即为损失
过程
参数:学习率为0.0002,batch_size 为1,λ为10
以 斑马 <=> 马 为例。
gen_h为生成马的生成器。gen_z为生成斑马的生成器。
disc_h为判别马的判别器。disc_z为判别斑马的判别器。
先训练两个Discriminator
gen_h输入一个zebra图像,生成fake_horse
gen_z输入一个horse图像,生成fake_zebra
将两个生成器生成的fake_horse,fake_zebra分别放到disc_h,disc_z产生分数,让分数靠近0.
disc_z和disc_h在分别接收之前给生成器的zebra和horse图像,这个是真图,所以训练的时候让分数靠近1。
损失使用mseloss
训练两个Generator
将之前生成的fake_zebra和fake_horse分别给对应的判别器,得到分数,靠近1。
然后计算Cycle consistency loss。用fake_zebra放到gen_h,fake_horse放到gen_z。产生cycle_zebra和cycle_horse,这两个应当和最初始的zebra和horse相同。使用l1loss计算。
计算identity loss
用gen_z喂入zebra产生identity_zebra,gen_h为入horse产生identity_horse,zebra对identity_zebra,horse对identity_horse使用L1loss计算。
最终的G_loss = 两个分数的loss + CycleLoss + identityloss(可以不加这个)
网络结构
生成器结构
对于128 x 128的训练图像使用了6个残差块。对于256 x 256或更高分辨率的图像使用了9个残差块。
说明:
c7s1-k:表示 7 x 7的卷积-InstanceNorm-ReLU层,k个滤波器,步长为1.
dk:表示3 x 3卷积层-InstanceNorm-ReLU层,k个滤波器,步长为2
Rk:表示一个残差块,这个残差块有两个3 x 3卷积层,两层之间有相同数量的滤波器。
uk:表示一个3 × 3分数步卷积-InstanceNorm-ReLU层,具有k个滤波器和步长 1/2。分数步长卷积也就是反卷积。
判别器结构
对于判别器,论文中使用的70 x 70的PatchGAN。
说明:
Ck:表示一个4 x 4 卷积-InstanceNorm-LeakyReLU层,有k个滤波器,步长为2。在最后一层,应用了一个卷积去产生一个以为的输出。对于第一层C64层不使用InstanceNorm,LeakyReLU使用的斜率为0.2.
代码实现
config.py
1 | import torch |
Discriminator.py
1 | import torch |
Generator.py
1 | import torch |
train.py
1 | import os.path |
dataset.py
1 | import torch |