CGAN学习及实现

  1. 1. CGAN学习及实现
    1. 1.1. 效果图
    2. 1.2. CGAN思想
    3. 1.3. 训练的细节
      1. 1.3.1. 训练目标
      2. 1.3.2. 使用的数据集
      3. 1.3.3. 开始训练
    4. 1.4. 代码实现
    5. 1.5. 相关链接

CGAN学习及实现

在GAN的学习中,生成了随机的二次元人物头像。人物随机生成,随机生成的人物不可控,如何让Generator按照我们想的样子的二次元头像呢。这里使用了CGAN(conditional GAN)条件GAN完成这个任务。

论文地址

效果图

这个是跑了150个epoch的效果,训练集和写GAN的那篇用的是一样的。

CGAN思想

这个图网上到处都有,一搜全都是。我也简单说一下吧。CGAN是在原始GAN的基础上,在训练生成器G和判别器D的时候都加入了新的条件变量y,y可以是各种信息,标签等信息。y在训练过程中起到指导作用。

图中的做法是,对于生成器G,将y条件信息,和noise拼接,放入生成器中用来生成图片。对于判别器D,将y条件信息和x图像信息拼接。拼接方式可以是用全连接将类别转为相应大小的图片,拼接到原图片的通道当中,然后继续和GAN的D一样进行评分。这个是网上很多人的做法,我的代码也使用的这个。不过我觉得也可以是,图片经过一个网络生成embedding,label通过一个网络生成一个Embedding然后两个特征相组合在进行评分也可以。

训练的细节

训练目标

根据输入的头发颜色,眼睛颜色,让Generator生成出对应的二次元头像。

使用的数据集

这里用的是和GAN那篇一样的,需要用到tags.csv

csv

最上面的一行是我自己加的。tags格式为:图片号,头发颜色 hair 眼睛颜色 eyes。读取csv使用的是pandas的包,没有的可以自行安装

1
pip install pandas

读csv格式标签的方法。

1
2
3
import pandas as pd
df = pd.read_csv(csv_path) # 读csv文件
label = df.iloc[item, 1] # 获取第item行,第1列,列数是以逗号分隔

开始训练

这里是训练判别器D的方式,不同于GAN那样,只判断图像是否为二次元头像。这里需要想到几个问题:这个图片是否清晰,这个label条件是否正确。需要考虑三种情况,也就是三个loss:

第一种,我们先给D一个训练集当中的图片,和对应的label。此时是最好的情况(图片清晰,条件正确),我们应当给高分,让它靠近1.

第二种,我们给G一个随机noise加上当前的label,让它产生out图片,此时的图片变得不清晰,但是label是正确的。将out图片和label送给D评分,此时分数应当降低,让它靠近0

第三种,我们重新从数据集当中选出和当前label不一样的图片,这时属于是label错误(图片足够清晰,条件错误),D同样需要给他低分。

这个是两种评分方式,第一种上面的是常用方法,图片和条件分别产生Embedding然后合拼到新的网络中产生分数。第二种是李宏毅老师说的方法,觉得更符合逻辑一些。这个做法是,一个图片经过一个网络,产生一个图像分数(用来判定图像是否清晰)和一个Embedding特征,将特征和条件结合,进入新的网络,产生图像和条件是否满足的分数。这样让网络知道是哪个部分导致的分数降低,从而达到更好的效果。这个我打算去实现以下,做个尝试。

这个是整个训练过程:(c代表条件label,x代表图片,z代表随机噪声)

D训练

  1. 从数据集选择m个正向样本。
  2. 产生m个随机noise样本
  3. 获取生成数据,将noise和条件c加入G中产生新的图片
  4. 从数据集中获取m个新的图片。
  5. 更新判别器参数,下面的公式可以那GAN那篇的方式来看。

G训练

  1. 产生m个随机噪声样本noise
  2. 从数据集中拿到m个条件c
  3. 用noise和c产生图片,放入D中评分,让评分靠近1

代码实现

一些参数配置,这里用的label是one-hot形式,2 x 13的大小,总共13种颜色,第一行是头发颜色,第二行是眼睛颜色。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def generateLabel(hair_colors, eye_colors, one_hot=False):
hlen = len(hair_colors)
elen = len(eye_colors)
if hlen != elen:
print("头发和眼睛数量不匹配")
return None
# 转为one-hot label
colorClass = {'white': 0, 'blonde': 1, 'aqua': 2, 'gray': 3, 'yellow': 4, 'black': 5,
'blue': 6, 'brown': 7, 'green': 8, 'pink': 9, 'purple': 10, 'red': 11, 'orange': 12}
if one_hot:
# one_hot label
result = torch.empty(hlen, 2, 13)
for i in range(hlen):
hair = torch.zeros(len(colorClass))
hair[colorClass[hair_colors[i]]] = 1
eye = torch.zeros(len(colorClass))
eye[colorClass[eye_colors[i]]] = 1
t = torch.cat([hair, eye], dim=0)
t = t.view(-1, 13)
result.data[i].copy_(t)
else:
result = torch.empty(hlen, 2, 1)
for i in range(hlen):
t = torch.tensor([
colorClass[hair_colors[i]],
colorClass[eye_colors[i]]
], dtype=torch.float).view(2, 1)
result.data[i].copy_(t)
return result

batch_size = 20 # 批数量
simple_transform = transform.Compose([
transform.Resize((64, 64)),
transform.ToTensor(),
transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) # 图片改大小,转Tensor,转[-1, 1]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 使用GPU or CPU
noise_z = 100 # noise维度
generator_feature_map = 64 # 特征数
path = "cartoon" # 数据集路径
train_set = CGANDataSet(pic_dir=os.path.join(path,"images"), label_csv_file=os.path.join(path, "tags.csv"), transform=simple_transform) # 数据集
train_loader = Data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
train_loader2 = Data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
# 正确标签
true_label = torch.ones(batch_size).to(device)
true_label = true_label.view(-1, 1, 1, 1)
# 错误标签
false_label = torch.zeros(batch_size).to(device)
false_label = false_label.view(-1, 1, 1, 1)
# 固定noise,condition 用于每个epoch产生相同图片,用来看效果
fix_noises = torch.randn(4, noise_z, 1, 1).to(device)
hair_colors = ['blue', 'black', 'blue', 'white']
eye_colors = ['blue', 'orange', 'red', 'blue']
fix_condition = generateLabel(hair_colors, eye_colors, one_hot=True).to(device)
# 随机noise
noises = torch.randn(batch_size, noise_z, 1, 1).to(device)
g_train_cycle = 1 # 训练生成器周期
save_img_cycle = 1 # 每几次epoch输出一次结果
print_step = 200 # 打印loss 信息周期
load_last_param = 0 # 是否加载最新的参数 0为不加载,从头训练
train_num = 100 # 训练轮数,几个epoch
bceloss = nn.BCELoss() # 同样使用BCEloss
learning_rate = 0.0002 # 学习率
beta = 0.5

ConditionalGANSet.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import pandas as pd
import os
import torch
from torch.utils.data import Dataset
from PIL import Image

class CGANDataSet(Dataset):
def __init__(self, pic_dir, label_csv_file, transform=None):
self.pic_dir = pic_dir
self.label_file = label_csv_file
self.pic_files = os.listdir(pic_dir)
self.df = pd.read_csv(label_csv_file)
self.transform = transform
self.colorClass = {'white': 0, 'blonde': 1, 'aqua': 2, 'gray': 3, 'yellow': 4, 'black': 5,
'blue': 6, 'brown': 7, 'green': 8, 'pink': 9, 'purple': 10, 'red': 11, 'orange': 12}

def __len__(self):
return len(self.pic_files)

def __getitem__(self, item):
image = Image.open(os.path.join(self.pic_dir, self.pic_files[item]))
imagename = self.pic_files[item].split('.')
imagename = int(imagename[0])
label = self.df.iloc[imagename, 1]
label = label.split(' ')
# one_hot 形式
hair = torch.zeros(len(self.colorClass))
hair[self.colorClass[label[0]]] = 1
eye = torch.zeros(len(self.colorClass))
eye[self.colorClass[label[2]]] = 1
label = torch.cat([hair, eye], dim=0)
label = label.view(-1, len(self.colorClass))
if self.transform:
image = self.transform(image)
return image, label

Generator

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.layer1 = nn.Sequential(
# 100*1*1 --> (64 * 8) * 4 *4
nn.ConvTranspose2d(noise_z + 2 * 13, generator_feature_map * 8, kernel_size=4, bias=False),
nn.BatchNorm2d(generator_feature_map * 8),
nn.ReLU(True))
self.layer2 = nn.Sequential(
# (64 * 8) * 4 * 4 --> (64 * 4)*8*8
nn.ConvTranspose2d(generator_feature_map * 8, generator_feature_map * 4, kernel_size=4, stride=2,
padding=1),
nn.BatchNorm2d(generator_feature_map * 4),
nn.ReLU(True))
self.layer3 = nn.Sequential(

# (64*4)*8*8 --> (64*2)*16*16
nn.ConvTranspose2d(generator_feature_map * 4, generator_feature_map * 2, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(generator_feature_map * 2),
nn.ReLU(True))
self.layer4 = nn.Sequential(

# (64*2)*16*16 --> 64*32*32
nn.ConvTranspose2d(generator_feature_map * 2, generator_feature_map, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(generator_feature_map),
nn.ReLU(True))
self.layer5 = nn.Sequential(
# 64*32*32 --> 3*64*64
nn.ConvTranspose2d(generator_feature_map, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
)

def forward(self, x, label):
label = label.view(-1, 2 * 13, 1, 1)
x = torch.cat([x, label], dim=1)
out = self.layer1(x)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.layer5(out)
return out

Discriminator

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# 定义鉴别器网络D
class Discriminator(nn.Module):
def __init__(self, ndf=64):
super(Discriminator, self).__init__()
self.label_embedding = nn.Linear(2 * 13, 64 * 64)
# 图像先进入一个Network,产生一个Embedding
self.image_network = nn.Sequential(
# (3, 64, 64) --> (64, 32, 32)
nn.Conv2d(3 + 1, ndf, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf),
nn.LeakyReLU(0.2, inplace=True),
# (64, 32, 32) --> (128, 16, 16)
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# (128, 16, 16) --> (256, 8, 8)
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# (256, 8, 8) --> (512, 4, 4)
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# (512, 4, 4) --> (1, 1, 1)
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)

# 定义NetD的前向传播
def forward(self, x, label):
label = label.view(-1, 2*13)
label = self.label_embedding(label)
label = label.view(-1, 1, 64, 64)
x = torch.cat([x, label], dim=1)
out = self.image_network(x)
return out

训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 加载网络参数部分
if load_last_param > 0:
dState = torch.load("CGANOutPut/net/netd_%d.pth" % load_last_param)
gState = torch.load("CGANOutPut/net/netg_%d.pth" % load_last_param)
generator.load_state_dict(gState)
discriminator.load_state_dict(dState)

g_optim = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta, 0.999))
d_optim = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta, 0.999))
# 训练Discriminator
ite = iter(train_loader2)
step = 0
for trainIdx in range(train_num):
for step, data in enumerate(train_loader):
image_x, label = data
image_x = image_x.to(device)
label = label.to(device)
# 训练判别器
noises = torch.randn(batch_size, noise_z, 1, 1).to(device)
# positive examples --> 正样本
positive_score = discriminator(image_x, label)
# 正样本的标签和随机noise samples,产生fake图片
fake_pic = generator(noises, label)
# 由于图像模糊导致分数低的情况。
low_img_score = discriminator(fake_pic.detach(), label)
# 再次从数据集中拿出样本, 要求标签和当前标签不一样。 image3为通过label筛选出与label不一样的图片
image3 = torch.empty_like(image_x)
index = 0
while True:
try:
image2, label2 = ite.next()
except StopIteration:
ite = iter(train_loader2)
image2, label2 = ite.next()
label2 = label2.to(device)
for i in range(label.shape[0]):
if torch.argmax(label[i][0]) == torch.argmax(label2[i][0]) and torch.argmax(label[i][1]) == torch.argmax(label2[i][1]):
continue
image3.data[index].copy_(image2.data[i])
index += 1
if index == batch_size:
break
if index == batch_size:
break
image3 = image3.to(device)
# 图片足够清晰,但是label不匹配导致的分数低的情况
wrong_label_score = discriminator(image3, label)
# 求梯度
d_optim.zero_grad()
real_loss = bceloss(positive_score, true_label)
low_img_loss = bceloss(low_img_score, false_label)
wrong_label_loss = bceloss(wrong_label_score, false_label)
d_loss = real_loss + low_img_loss + wrong_label_loss

d_loss.backward()
d_optim.step()

if step % g_train_cycle == 0:
# 训练生成器
try:
_, lbs = ite.next()
except StopIteration:
ite = iter(train_loader2)
_, lbs = ite.next()
lbs = lbs.to(device)
g_optim.zero_grad()
noises.data.copy_(torch.randn(batch_size, noise_z, 1, 1))
fake_img = generator(noises, lbs)
fake_out = discriminator(fake_img, lbs)
loss_fake = bceloss(fake_out, true_label)
loss_fake.backward()
g_optim.step()

if step % print_step == print_step - 1:
print("train: ", trainIdx, "step: ", step + 1, " d_loss: ", real_loss.item(), "mean score: ",
torch.mean(positive_score).item())
print("train: ", trainIdx, "step: ", step + 1, " g_loss: ", loss_fake.item(), "mean score: ",
torch.mean(fake_out).item())
# 每间隔save_img_cycle个epoch,进行保存网络参数以及样例图像。
if trainIdx % save_img_cycle == 0:
with torch.no_grad():
fix_fake_image = generator(fix_noises, fix_condition)
fix_fake_image = fix_fake_image.data.cpu()
comb_img = torchvision.utils.make_grid(fix_fake_image, nrow=4)
savepath = os.path.join("CGANOutPut", "pics", "g_%s.jpg" % (trainIdx + load_last_param + 1))
saveImg(comb_img, savepath)
torch.save(discriminator.state_dict(), './CGANOutPut/net/netd_%s.pth' % (trainIdx + load_last_param + 1))
torch.save(generator.state_dict(), './CGANOutPut/net/netg_%s.pth' % (trainIdx + load_last_param + 1))

相关链接

李宏毅GAN视频