Cycle GAN学习及实现

  1. 1. Cycle GAN学习及实现
    1. 1.1. Cycle GAN原理
    2. 1.2. 训练过程
      1. 1.2.1. 对抗损失
      2. 1.2.2. 循环一致性损失
      3. 1.2.3. 身份损失
      4. 1.2.4. 完整模型对象
      5. 1.2.5. 过程
    3. 1.3. 网络结构
      1. 1.3.1. 生成器结构
      2. 1.3.2. 判别器结构
    4. 1.4. 代码实现
    5. 1.5. 相关链接

Cycle GAN学习及实现

之前看的AnoSeg的那篇论文,其中使用了两个Generator,但并不太清楚另一个Generator的作用。最近找到一个蛮有意思的GAN,就是CycleGAN做的是风格迁移。这个GAN中也使用了两个Generator,想学习一下,从中看看有没有帮助理解的地方。也是挺好玩的一个GAN,下面是我训练100个Epoch出来的效果。

马—->斑马

斑马——>马

Cycle GAN原理

CycleGAN主要是做的在没有成对的图像情况下,弄清楚如何将一个图像的特征转化为另一个图像的特征。假设X是A类型的图,Y是B类型的图。要训练一种Generator,将X的原图生成新的y,让判别器D无法分辨生成的y和Y,认为他们两个是同一类。然而实际操作中发现无法单独优化判别器。然后,从翻译中找到的灵感,假如从英语翻译到法语,再从法语翻译回英语,那么原来的英语应该和翻译回的英语相同。

​ 就是说CycleGAN当中有两个Generator,一个将X类转为Y类,一个是将Y类转为X类。对应的也有两个Discriminator,一个是认X类的,一个是认Y类的。目标是把X转成Y。

​ 如果只训练一个Generator直接转换的话存在一个问题

这里生成器生成一张图片,判别器的任务是判别是否是Y图像,生成器的任务是生成图片尽可能的变成Y类型的图。但问题是,从X->Y有无数种映射可以完成X->Y。生成器可以完全无视输入条件,直接生成一个与原始图像无关的东西,但也被判别器认为是Y类型的图。因此要求生成器生成的图片不光要欺骗过判别器,同时还要与原图像有一定的关系。

​ CycleGan的做法:

为了保证生成图像和原始图像有关联,它使用了一个G(x->y)先将x原图转为fake_y的图,然后用G(y->x)将生成的fake_y图转回x类型的图,这两张x图应该是尽可能相同,因为应当是同一张x图。对于G(y->x)也是同理的训练方法。这个模型可以视为两个自动编码器(auto-encoder)。

训练生成器的时候,就和普通的GAN是一样的,用的对抗损失。x -> G(x) -> F(G(x)) ≈ x,两个x应当尽可能相等,使用的是循环一致性损失(cycle-consistency loss)

训练过程

训练过程中使用了3个loss,主要是对抗损失(adversarial loss)、循环一致性损失(Cycle consistency loss)、身份损失(identity loss)。其中身份损失可以不加。不使用Identityloss会使得两个生成器自由改变输入图像的色调,但是会一定程度上提高训练速度。开头的我跑的效果图是使用了identity loss的效果。

对抗损失

GANLoss

对于Generator,将D(G(x))的得分尽量趋于1,对于Discriminator将D(x)得分趋于1,D(G(x))尽量趋于0。论文中第4部分Implementation中的训练细节中说道,使用最小二乘损失取代原来的负对数似然损失,使用最小二乘损失效果更好。

循环一致性损失

就是前面说过的例子,英语->法语->英语的过程。前后英语应当是一样的。应用到图片上,公式如下:

损失函数如下

将原始的x和经过G生成y,y在经过F生成的新的x之间使用L1Loss。然后对y也是这个操作,两个loss求和。

身份损失

这个损失主要是解决色调问题。不加影响不大。

就是将生成器生成的图像和原图做一次L1loss,然后相加

完整模型对象

两个生成器的对抗损失和一个循环一致性损失乘上λ加和即为损失

过程

参数:学习率为0.0002,batch_size 为1,λ为10

以 斑马 <=> 马 为例。

gen_h为生成马的生成器。gen_z为生成斑马的生成器。

disc_h为判别马的判别器。disc_z为判别斑马的判别器。

先训练两个Discriminator

gen_h输入一个zebra图像,生成fake_horse

gen_z输入一个horse图像,生成fake_zebra

将两个生成器生成的fake_horse,fake_zebra分别放到disc_h,disc_z产生分数,让分数靠近0.

disc_z和disc_h在分别接收之前给生成器的zebra和horse图像,这个是真图,所以训练的时候让分数靠近1。

损失使用mseloss

训练两个Generator

将之前生成的fake_zebra和fake_horse分别给对应的判别器,得到分数,靠近1。

然后计算Cycle consistency loss。用fake_zebra放到gen_h,fake_horse放到gen_z。产生cycle_zebra和cycle_horse,这两个应当和最初始的zebra和horse相同。使用l1loss计算。

计算identity loss

用gen_z喂入zebra产生identity_zebra,gen_h为入horse产生identity_horse,zebra对identity_zebra,horse对identity_horse使用L1loss计算。

最终的G_loss = 两个分数的loss + CycleLoss + identityloss(可以不加这个)

网络结构

论文中提供了PytorchTorch版本的实现源码

访问慢的可以使用这个链接:PytorchTorch

生成器结构

对于128 x 128的训练图像使用了6个残差块。对于256 x 256或更高分辨率的图像使用了9个残差块。

GeneratorArchitecttures

说明:

c7s1-k:表示 7 x 7的卷积-InstanceNorm-ReLU层,k个滤波器,步长为1.

dk:表示3 x 3卷积层-InstanceNorm-ReLU层,k个滤波器,步长为2

Rk:表示一个残差块,这个残差块有两个3 x 3卷积层,两层之间有相同数量的滤波器。

uk:表示一个3 × 3分数步卷积-InstanceNorm-ReLU层,具有k个滤波器和步长 1/2。分数步长卷积也就是反卷积。

判别器结构

对于判别器,论文中使用的70 x 70的PatchGAN。

说明:

Ck:表示一个4 x 4 卷积-InstanceNorm-LeakyReLU层,有k个滤波器,步长为2。在最后一层,应用了一个卷积去产生一个以为的输出。对于第一层C64层不使用InstanceNorm,LeakyReLU使用的斜率为0.2.

代码实现

config.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import os
import platform
import torchvision.transforms as Transforms

is_in_windows = platform.system() == "Windows"

Device = "cuda" if torch.cuda.is_available() else "cpu"
train_dir = "" # 训练文件路径
test_dir = "" # 测试文件路径
batch_size = 1
learning_rate = 2e-4

simple_transform = Transforms.Compose([
Transforms.Resize((256, 256)),
Transforms.ToTensor(),
Transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_num = 50 # 训练多少次
lambda_cycle = 10 # 循环一致性损失的权重,论文中的设定为10
lambda_identity = 1.0 # 使用identityloss
load_num_param = 100 # 加载第几个参数

Discriminator.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn

class Block(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super(Block, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
nn.InstanceNorm2d(out_channels),
nn.LeakyReLU(0.2)
)

def forward(self, x):
return self.conv(x)

class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
super(Discriminator, self).__init__()
self.initial = nn.Sequential(
nn.Conv2d(
in_channels,
features[0],
kernel_size=4,
stride=2,
padding=1,
padding_mode="reflect"
),
nn.LeakyReLU(0.2)
)

layers = []
in_channels = features[0]
for feature in features[1:]:
layers.append(Block(in_channels, feature, stride=1 if feature == features[-1] else 2))
in_channels = feature
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
self.model = nn.Sequential(*layers)

def forward(self,x):
x = self.initial(x)
return torch.sigmoid(self.model(x))

def test():
x = torch.randn((5,3,256,256))
model = Discriminator(in_channels=3)
preds = model(x)
print(preds.size())
if __name__ == "__main__":
test()

Generator.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
super(ConvBlock, self).__init__()
self.conv=nn.Sequential(
nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
if down
else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
nn.InstanceNorm2d(out_channels),
nn.ReLU(inplace=True) if use_act else nn.Identity()

)
def forward(self, x):
return self.conv(x)

class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.block = nn.Sequential(
ConvBlock(channels, channels, kernel_size=3, padding=1),
ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
)
def forward(self, x):
return x + self.block(x)

class Generator(nn.Module):
def __init__(self, img_channels, num_features= 64, num_residuals=9):
super(Generator, self).__init__()
self.initial = nn.Sequential(
nn.Conv2d(img_channels, out_channels=num_features, kernel_size=7, stride=1, padding_mode="reflect", padding=3),
nn.InstanceNorm2d(num_features),
nn.ReLU(inplace=True)
)
self.down_blocks = nn.ModuleList(
[
ConvBlock(num_features, num_features * 2, kernel_size=3, stride=2, padding=1),
ConvBlock(num_features * 2, num_features * 4, kernel_size=3, stride=2, padding=1),
]
)
self.residual_blocks = nn.Sequential(
*[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
)
self.up_blocks = nn.ModuleList([
ConvBlock(num_features * 4, num_features *2, down=False, kernel_size=3, stride=2, padding=1, output_padding= 1),
ConvBlock(num_features * 2, num_features * 1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
])
self.last = nn.Conv2d(num_features * 1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

def forward(self, x):
x = self.initial(x)
for layer in self.down_blocks:
x = layer(x)
x = self.residual_blocks(x)
for layer in self.up_blocks:
x = layer(x)
return torch.tanh(self.last(x))

def test():
img_channels = 3
img_size = 256
x = torch.randn((2, img_channels, img_size, img_size))
gen = Generator(img_channels, 9)
print(gen(x).size())

if __name__ == "__main__":
test()

train.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os.path
import torch
import torchvision.utils

from dataSet import HorseZebraDataSet
import sys
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from Discriminator import Discriminator
from Generator import Generator
import config
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import utils as vutils

def save_image_tensor(input_tensor: torch.Tensor, filename):
input_tensor = input_tensor.clone().detach()
input_tensor = input_tensor.to(torch.device("cpu"))
input_tensor = input_tensor * 0.5 + 0.5
vutils.save_image(input_tensor, filename)


def train_fn(disc_h, disc_z, gen_z, gen_h, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
H_reals = 0
H_fakes = 0
loop = tqdm(loader, leave=True)
for idx, (zebra, horse) in enumerate(loop):
zebra = zebra.to(config.Device)
horse = horse.to(config.Device)
with torch.cuda.amp.autocast():
fake_horse = gen_h(zebra)
D_H_real = disc_h(horse)
D_H_fake = disc_h(fake_horse.detach())
D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
D_H_loss = D_H_fake_loss + D_H_real_loss

fake_zebra = gen_z(horse)
D_Z_real = disc_z(zebra)
D_Z_fake = disc_z(fake_zebra.detach())
D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_H_real))
D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_H_fake))
D_Z_loss = D_Z_fake_loss + D_Z_real_loss

D_loss = (D_H_loss + D_Z_loss) / 2

opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
d_scaler.step(opt_disc)
d_scaler.update()

# 训练Generator
with torch.cuda.amp.autocast():
D_H_fake = disc_h(fake_horse)
D_Z_fake = disc_z(fake_zebra)
loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

# cycle loss
cycle_zebra = gen_z(fake_horse)
cycle_horse = gen_h(fake_zebra)
cycle_zebra_loss = l1(zebra, cycle_zebra)
cycle_horse_loss = l1(horse, cycle_horse)

# identity loss
identity_zebra = gen_z(zebra)
identity_horse = gen_h(horse)
identity_zebra_loss = l1(zebra, identity_zebra)
identity_horse_loss = l1(horse, identity_horse)

G_loss = (
loss_G_Z
+ loss_G_H
+ cycle_zebra_loss * config.lambda_cycle
+ cycle_horse_loss * config.lambda_cycle
+ identity_horse_loss * config.lambda_identity
+ identity_zebra_loss * config.lambda_identity
)

opt_gen.zero_grad()
g_scaler.scale(G_loss).backward()
g_scaler.step(opt_gen)
g_scaler.update()
if idx % 200 == 199:
print("loss: ", G_loss.item())
loop.set_postfix(H_real=H_reals/(idx + 1), H_fake=H_fakes/(idx+1))


def saveModel(model, filename):
torch.save(model.state_dict(), filename)

def main():
disc_H = Discriminator(in_channels=3).to(config.Device)
disc_Z = Discriminator(in_channels=3).to(config.Device)
gen_Z = Generator(img_channels=3, num_residuals=9).to(config.Device)
gen_H = Generator(img_channels=3, num_residuals=9).to(config.Device)
opt_disc = optim.Adam(
list(disc_H.parameters()) + list(disc_Z.parameters()),
lr=config.learning_rate,
betas=(0.5, 0.999)
)
opt_gen = optim.Adam(
list(gen_H.parameters()) + list(gen_Z.parameters()),
lr=config.learning_rate,
betas=(0.5, 0.999)
)
L1 = nn.L1Loss()
mse = nn.MSELoss()

dataset = HorseZebraDataSet(root_zebra=config.train_dir + "/trainB", root_horse=config.train_dir + "/trainA",
transform= config.simple_transform)
testDataSet = HorseZebraDataSet(root_zebra=config.test_dir + "/testB", root_horse=config.test_dir + "/testA", transform= config.simple_transform)
loader = DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=True,
)
test_loader = DataLoader(
testDataSet,
batch_size=4,
shuffle=True
)
# 固定测试样本,可以直接观察每个epoch的变化
fix_zebras, fix_horses = iter(test_loader).next()
fix_zebras = fix_zebras.to(config.Device)
fix_horses = fix_horses.to(config.Device)

g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
if config.load_num_param > 0:
gen_Z.load_state_dict(torch.load("net/gen_z_%d.pth" % config.load_num_param))
gen_H.load_state_dict(torch.load("net/gen_h_%d.pth" % config.load_num_param))
disc_Z.load_state_dict(torch.load("net/disc_z_%d.pth" % config.load_num_param))
disc_H.load_state_dict(torch.load("net/disc_h_%d.pth" % config.load_num_param))
for epoch in range(config.train_num):
train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)

# save model parameters
saveModel(disc_H, os.path.join("./net", "disc_h_%d.pth" % (epoch + config.load_num_param + 1)))
saveModel(disc_Z, os.path.join("./net", "disc_z_%d.pth" % (epoch + config.load_num_param + 1)))
saveModel(gen_H, os.path.join("./net", "gen_h_%d.pth" % (epoch + config.load_num_param + 1)))
saveModel(gen_Z, os.path.join("./net", "gen_z_%d.pth" % (epoch + config.load_num_param + 1)))

# save Pic
with torch.no_grad():
fix_horses = fix_horses.to(config.Device)
fix_zebras = fix_zebras.to(config.Device)
fake_horse = gen_H(fix_zebras)
fake_zebra = gen_Z(fix_horses)
comb = torch.cat([fake_horse, fix_zebras])
comb = torchvision.utils.make_grid(comb, nrow=4)
comb2 = torch.cat([fake_zebra, fix_horses])
comb2 = torchvision.utils.make_grid(comb2, nrow=4)
save_image_tensor(comb, "./pic/zebra2horse/%d_epoch_zebra2horse.jpg" % (epoch + config.load_num_param + 1))
save_image_tensor(comb2, "./pic/horse2zebra/%d_epoch_horse2zebra.jpg" % (epoch + config.load_num_param + 1))


if __name__ == "__main__":
main()

dataset.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from PIL import Image
import os
from torch.utils.data import Dataset


class HorseZebraDataSet(Dataset):
def __init__(self, root_zebra, root_horse, transform=None):
self.root_zebra = root_zebra
self.root_horse = root_horse
self.transform = transform

self.zebra_image = os.listdir(root_zebra)
self.horse_image = os.listdir(root_horse)
self.length_dataset = max(len(self.zebra_image), len(self.horse_image))
self.zebra_len = len(self.zebra_image)
self.horse_len = len(self.horse_image)

def __len__(self):
return self.length_dataset

def __getitem__(self, item):
zebra_img = self.zebra_image[item % self.zebra_len]
horse_img = self.horse_image[item % self.horse_len]
zebra_img = Image.open(os.path.join(self.root_zebra, zebra_img))
horse_img = Image.open(os.path.join(self.root_horse, horse_img))
if len(zebra_img.split()) < 3:
bimg = zebra_img.split()
zebra_img = Image.merge("RGB", [bimg[0], bimg[0], bimg[0]])
if len(horse_img.split()) < 3:
bimg = zebra_img.split()
horse_img = Image.merge("RGB", [bimg[0], bimg[0], bimg[0]])
if self.transform:
zebra_img = self.transform(zebra_img)
horse_img = self.transform(horse_img)
return zebra_img, horse_img

相关链接

论文原文

李宏毅GAN

CycleGAN实现