STGAN (CVPR 2019): Selected Code Notes

Face attribute transfer: the training code train.py

GitHub: https://github.com/csmliu/STGAN

STGAN paper: https://arxiv.org/abs/1904.09709


Data preparation

```python
xa = tr_data.batch_op[0]  # training images
a = tr_data.batch_op[1]   # attribute labels of the training images
b = tf.random_shuffle(a)  # target labels; random_shuffle only permutes a, so b is not guaranteed to differ from a
_a = (tf.to_float(a) * 2 - 1) * thres_int  # rescale labels from {0, 1} to {-thres_int, +thres_int}
_b = (tf.to_float(b) * 2 - 1) * thres_int
```
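
A quick numeric sketch of the label rescaling (thres_int = 0.5 here is an assumption for illustration; check the actual flag value in train.py):

```python
import numpy as np

thres_int = 0.5  # assumed value for illustration; see the repo's command-line defaults
a = np.array([[1, 0, 1]], dtype=np.int64)        # one sample with three binary attributes
_a = (a.astype(np.float32) * 2 - 1) * thres_int  # -> [[ 0.5 -0.5  0.5]]
print(_a)
```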

Generation


```python
z = Genc(xa)  # encode the input image
# feed the encoder features and the label into the STU module;
# when label == 'diff' the label is the difference _b - _a
zb = Gstu(z, _b - _a if label == 'diff' else _b) if use_stu else z
xb_ = Gdec(zb, _b - _a if label == 'diff' else _b)  # decode the STU output with the same label
with tf.control_dependencies([xb_]):  # ensure xb_ is computed before the ops below run
    za = Gstu(z, _a - _a if label == 'diff' else _a) if use_stu else z
    xa_ = Gdec(za, _a - _a if label == 'diff' else _a)  # reconstruction of the input
```
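
With 'diff' labels, entries for unchanged attributes are exactly zero, and the reconstruction branch receives an all-zero label. A small illustration (values again assume thres_int = 0.5):

```python
import numpy as np

_a = np.array([[ 0.5, -0.5,  0.5]], dtype=np.float32)  # source label, rescaled
_b = np.array([[-0.5, -0.5,  0.5]], dtype=np.float32)  # target label, rescaled
print(_b - _a)  # [[-1.  0.  0.]]; only the attribute being edited is nonzero
print(_a - _a)  # [[ 0.  0.  0.]]; the zero label drives pure reconstruction
```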

Discrimination

```python
xa_logit_gan, xa_logit_att = D(xa)    # real image: adversarial logit and attribute logits
xb__logit_gan, xb__logit_att = D(xb_) # generated (fake) image

if mode == 'wgan':  # WGAN-GP loss
    wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)  # Wasserstein distance estimate
    d_loss_gan = -wd
    gp = models.gradient_penalty(D, xa, xb_)

xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)  # attribute classification loss on real images
d_loss = d_loss_gan + gp * 10.0 + xa_loss_att  # total discriminator loss

if mode == 'wgan':
    xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)  # attribute classification loss on generated images
xa__loss_rec = tf.losses.absolute_difference(xa, xa_)  # L1 reconstruction loss
g_loss = xb__loss_gan + xb__loss_att * 10.0 + xa__loss_rec * rec_loss_weight  # total generator loss
```
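
models.gradient_penalty comes from the repo itself; the sketch below is a generic WGAN-GP penalty, shown only to illustrate what such a function typically computes, not the repo's exact implementation:

```python
import tensorflow as tf

def gradient_penalty_sketch(D, real, fake):
    # interpolate randomly between real and fake samples (4-D image batches assumed)
    eps = tf.random_uniform([tf.shape(real)[0], 1, 1, 1], 0.0, 1.0)
    x_hat = real + eps * (fake - real)
    logit_gan, _ = D(x_hat)  # D returns (gan logit, attribute logits) as above
    grads = tf.gradients(logit_gan, x_hat)[0]
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-10)
    return tf.reduce_mean(tf.square(norm - 1.0))  # push the gradient norm toward 1
```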

Preparing the 384×384 HD-CelebA dataset

GitHub: https://github.com/LynnHo/HD-CelebA-Cropper

```python
import cv2
from functools import partial  # needed for the imwrite wrapper below

imread = cv2.imread
# bind the JPEG write quality (default 95; constant defined elsewhere in the script) into imwrite
imwrite = partial(cv2.imwrite, params=[int(cv2.IMWRITE_JPEG_QUALITY), _DEAFAULT_JPG_QUALITY])
align_crop = cropper.align_crop_5pts_opencv  # alias for the align-and-crop function
```
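
The partial simply pre-fills the params argument, so a call like imwrite(path, img) expands to:

```python
cv2.imwrite(path, img, params=[int(cv2.IMWRITE_JPEG_QUALITY), 95])  # 95 assumes the script's default quality
```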

```python
# create the output directory
save_dir = os.path.join(args.data_dir, 'data_crop_%s_%s' % (args.crop_size, args.save_format))
if not os.path.isdir(save_dir):
    os.mkdir(save_dir)

# paths to the aligned images and the landmark annotation file
img_dir = os.path.join(args.data_dir, 'img_align_celeba')
landmark_file = os.path.join(args.data_dir, 'list_landmarks_align_celeba.csv')

# np.loadtxt: skiprows skips header lines, delimiter splits each row, usecols selects columns
img_names = np.loadtxt(landmark_file, skiprows=2, usecols=0, dtype=str)
landmarks = np.loadtxt(landmark_file, skiprows=2, delimiter=',', usecols=range(1, 11))
landmarks.shape = -1, 5, 2  # reshape to (N, 5, 2): x, y pairs of the five landmark points
mean_lm = cropper._DEFAULT_MEAN_LANDMARKS  # the mean landmark coordinates are fixed constants
```
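
The reshape turns each row of ten values into five (x, y) pairs, following the column order x1, y1, ..., x5, y5 of the CelebA landmark file. A small check with illustrative values:

```python
import numpy as np

row = np.array([69., 111., 108., 112., 88., 134., 73., 152., 108., 154.])  # one face, 10 values
lm = row.reshape(-1, 5, 2)
print(lm.shape)   # (1, 5, 2)
print(lm[0, 0])   # first landmark (left eye): [ 69. 111.]
```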

```python
from multiprocessing import Pool

pool = Pool(args.n_worker)  # a pool of worker processes (not threads)
for _ in tqdm(pool.imap(work, range(len(img_names))), total=len(img_names)):
    pass
pool.close()
pool.join()
```
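
The imap/tqdm pattern streams results lazily, so the progress bar advances as each image finishes. A minimal self-contained version of the same pattern, with a hypothetical square() standing in for work():

```python
from multiprocessing import Pool
from tqdm import tqdm

def square(i):  # stand-in for work(); must be defined at module top level
    return i * i

if __name__ == '__main__':
    with Pool(4) as pool:
        for _ in tqdm(pool.imap(square, range(100)), total=100):
            pass
```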

```python
def work(i):  # process a single image
    for _ in range(3):  # retry up to three times
        try:
            img = imread(os.path.join(img_dir, img_names[i]))
            img_crop = align_crop(img,
                                  landmarks[i],
                                  mean_lm,
                                  crop_size=args.crop_size,
                                  face_factor=args.face_factor,          # default 0.65
                                  landmark_factor=args.landmark_factor,  # default 0.35
                                  align_type=args.align_type,
                                  order=args.order,
                                  mode=args.mode)
            imwrite(os.path.join(save_dir, img_names[i].replace('jpg', args.save_format)), img_crop)
            break  # success, stop retrying
        except Exception:
            pass  # retry on failure
```

```python
def align_crop_5pts_opencv(img, src_landmarks, mean_landmarks=_DEFAULT_MEAN_LANDMARKS,
                           crop_size=384, face_factor=0.65, landmark_factor=0.35,
                           align_type='similarity', order=3, mode='edge'):
    move = np.array([img.shape[1] // 2, img.shape[0] // 2])  # image center (x, y)
    v_border = img.shape[0] - crop_size
    w_border = img.shape[1] - crop_size

    # center the mean landmarks on the midpoint of the two eyes
    mean_landmarks -= np.array([mean_landmarks[0, :] + mean_landmarks[1, :]]) / 2.0
    # target landmark coordinates in the output image
    trg_landmarks = mean_landmarks * (crop_size * face_factor * landmark_factor) + move
    # estimate the affine matrix mapping target landmarks to source landmarks
    tform = cv2.estimateAffine2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.Inf)[0]
    # shift the transform so the eye midpoint of the target lands on the eye midpoint of the source
    trg_mid = (trg_landmarks[0, :] + trg_landmarks[1, :]) / 2.0
    src_mid = (src_landmarks[0, :] + src_landmarks[1, :]) / 2.0
    new_trg_mid = cv2.transform(np.array([[trg_mid]]), tform)[0, 0]
    tform[:, 2] += src_mid - new_trg_mid

    output_shape = (crop_size // 2 + move[1] + 1, crop_size // 2 + move[0] + 1)
    # WARP_INVERSE_MAP because tform maps target -> source; INTER_CUBIC sets the interpolation;
    # border maps mode names such as 'edge' to cv2 border flags elsewhere in cropper.py
    img_align = cv2.warpAffine(img, tform, output_shape[::-1], flags=cv2.WARP_INVERSE_MAP + cv2.INTER_CUBIC,
                               borderMode=border[mode])
    # ... (the excerpt ends here; the full function goes on to crop img_align down to crop_size)
```
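
The midpoint correction relies on cv2.transform applying the 2×3 affine matrix to a point. A tiny sanity check of that call, with illustrative values:

```python
import numpy as np
import cv2

tform = np.array([[1., 0., 10.],
                  [0., 1., 20.]])      # a pure translation as the affine matrix
pt = np.array([[[5., 5.]]])            # one point, shaped (1, 1, 2)
print(cv2.transform(pt, tform)[0, 0])  # [15. 25.]
```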

Results