图像风格迁移模型-CycleGAN

作者: FutureSI
日期: 2021.03
摘要: 本案例实现了CycleGAN模型用于风格迁移。

一、CycleGAN介绍

CycleGAN,即循环生成对抗网络,是一种用于图片风格迁移的模型。原来的图片风格迁移模型通过在两组一一匹配的图片上进行训练,来学习输入图片组与输出图片组的特征映射关系,从而实现将输入图片的特征迁移到输出图片上,比如将A组图片的斑马的条纹外观特征迁移到B组普通马匹图片上。但是,训练所要求的两组一一对应训练集图片往往难以获得。CycleGAN通过给GAN网络添加循环一致性损失(cycle consistency loss)的方法打破了训练集图片数据的一一对应限制。

二、框架导入设置

  1. # 解压 ai studio 数据集(首次执行后注释)
  2. !unzip -qa -d ~/data/data10040/ ~/data/data10040/horse2zebra.zip
  3. # 如果用wget自行下载数据集需要自行添加训练集列表文件
  4. # !wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
  5. # !unzip -qa -d /home/aistudio/data/data10040/ horse2zebra.zip
  1. import paddle
  2. from paddle.io import Dataset, DataLoader, IterableDataset
  3. import numpy as np
  4. import cv2
  5. import random
  6. import time
  7. import warnings
  8. import matplotlib.pyplot as plt
  9. %matplotlib inline
  10. warnings.filterwarnings("ignore", category=Warning) # 过滤报警信息
  11. BATCH_SIZE = 1
  12. DATA_DIR = '/home/aistudio/data/data10040/horse2zebra/' # 设置训练集数据地址

三、准备数据集

  1. from PIL import Image
  2. from paddle.vision.transforms import RandomCrop
  3. # 处理图片数据:随机裁切、调整图片数据形状、归一化数据
  4. def data_transform(img, output_size):
  5. h, w, _ = img.shape
  6. assert h == w and h >= output_size # check picture size
  7. # random crop
  8. rc = RandomCrop(224)
  9. img = rc(img)
  10. # normalize
  11. img = img / 255. * 2. - 1.
  12. # from [H,W,C] to [C,H,W]
  13. img = np.transpose(img, (2, 0, 1))
  14. # data type
  15. img = img.astype('float32')
  16. return img
  17. # 定义horse2zebra数据集对象
  18. class H2ZDateset(Dataset):
  19. def __init__(self, data_dir):
  20. super(H2ZDateset, self).__init__()
  21. self.data_dir = data_dir
  22. self.pic_list_a = np.loadtxt(data_dir+'trainA.txt', dtype=np.str)
  23. np.random.shuffle(self.pic_list_a)
  24. self.pic_list_b = np.loadtxt(data_dir+'trainB.txt', dtype=np.str)
  25. np.random.shuffle(self.pic_list_b)
  26. self.pic_list_lenth = min(int(self.pic_list_a.shape[0]), int(self.pic_list_b.shape[0]))
  27. def __getitem__(self, idx):
  28. img_dir_a = self.data_dir+self.pic_list_a[idx]
  29. img_a = cv2.imread(img_dir_a)
  30. img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
  31. img_a = data_transform(img_a, 224)
  32. img_dir_b = self.data_dir+self.pic_list_b[idx]
  33. img_b = cv2.imread(img_dir_b)
  34. img_b = cv2.cvtColor(img_b, cv2.COLOR_BGR2RGB)
  35. img_b = data_transform(img_b, 224)
  36. return np.array([img_a, img_b])
  37. def __len__(self):
  38. return self.pic_list_lenth
# Build the image loader and fetch one batch to verify shapes.
h2zdateset = H2ZDateset(DATA_DIR)
loader = DataLoader(h2zdateset, shuffle=True, batch_size=BATCH_SIZE, drop_last=False, num_workers=0, use_shared_memory=False)
data = next(loader())[0]
# [N, 2, C, H, W] -> [2, N, C, H, W] so data[0]/data[1] index the two domains.
data = np.transpose(data, (1, 0, 2, 3, 4))
print("读取的数据形状:", data.shape)
  1. 读取的数据形状: [2, 1, 3, 224, 224]

四、模型组网

4.1 定义辅助功能函数

判别器负责区分图片的“真假”。如果输入的是训练集图片,判别器的输出越趋近于数值1(即判别此图片为真);如果输入的是生成器生成的图片,判别器的输出越趋近于数值0(即判别此图片为假)。这样,生成器就可以根据判别器输出的变化而计算梯度以优化生成网络。

  1. from PIL import Image
  2. import os
  3. # 打开图片
  4. def open_pic(file_name='./data/data10040/horse2zebra/testA/n02381460_1300.jpg'):
  5. img = Image.open(file_name).resize((256, 256), Image.BILINEAR)
  6. img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
  7. img = img.transpose((2, 0, 1))
  8. img = img.reshape((-1, img.shape[0], img.shape[1], img.shape[2]))
  9. return img
  10. # 存储图片
  11. def save_pics(pics, file_name='tmp', save_path='./output/pics/', save_root_path='./output/'):
  12. if not os.path.exists(save_root_path):
  13. os.makedirs(save_root_path)
  14. if not os.path.exists(save_path):
  15. os.makedirs(save_path)
  16. for i in range(len(pics)):
  17. pics[i] = pics[i][0]
  18. pic = np.concatenate(tuple(pics), axis=2)
  19. pic = pic.transpose((1,2,0))
  20. pic = (pic + 1) / 2
  21. # plt.imshow(pic)
  22. pic = np.clip(pic * 256, 0, 255)
  23. img = Image.fromarray(pic.astype('uint8')).convert('RGB')
  24. img.save(save_path+file_name+'.jpg')
  25. # 显示图片
  26. def show_pics(pics):
  27. print(pics[0].shape)
  28. plt.figure(figsize=(3 * len(pics), 3), dpi=80)
  29. for i in range(len(pics)):
  30. pics[i] = (pics[i][0].transpose((1,2,0)) + 1) / 2
  31. plt.subplot(1, len(pics), i + 1)
  32. plt.imshow(pics[i])
  33. plt.xticks([])
  34. plt.yticks([])
  35. # 图片缓存队列
  36. class ImagePool(object):
  37. def __init__(self, pool_size=50):
  38. self.pool = []
  39. self.count = 0
  40. self.pool_size = pool_size
  41. def pool_image(self, image):
  42. return image
  43. image = image.numpy()
  44. rtn = ''
  45. if self.count < self.pool_size:
  46. self.pool.append(image)
  47. self.count += 1
  48. rtn = image
  49. else:
  50. p = np.random.rand()
  51. if p > 0.5:
  52. random_id = np.random.randint(0, self.pool_size - 1)
  53. temp = self.pool[random_id]
  54. self.pool[random_id] = image
  55. rtn = temp
  56. else:
  57. rtn = image
  58. return paddle.to_tensor(rtn)

4.2 查看读取的数据集图片

# Preview one image from each domain of the batch fetched above.
show_pics([data[0].numpy(), data[1].numpy()])
  1. (1, 3, 224, 224)

png

4.3 定义判别器

  1. import paddle
  2. import paddle.nn as nn
  3. import numpy as np
  4. # 定义基础的“卷积层+实例归一化”块
  5. class ConvIN(nn.Layer):
  6. def __init__(self, num_channels, num_filters, filter_size, stride=1, padding=1, bias_attr=None,
  7. weight_attr=None):
  8. super(ConvIN, self).__init__()
  9. model = [
  10. nn.Conv2D(num_channels, num_filters, filter_size, stride=stride, padding=padding,
  11. bias_attr=bias_attr, weight_attr=weight_attr),
  12. nn.InstanceNorm2D(num_filters),
  13. nn.LeakyReLU(negative_slope=0.2)
  14. ]
  15. self.model = nn.Sequential(*model)
  16. def forward(self, x):
  17. return self.model(x)
  18. # 定义CycleGAN的判别器
  19. class Disc(nn.Layer):
  20. def __init__(self, weight_attr=nn.initializer.Normal(0., 0.02)):
  21. super(Disc, self).__init__()
  22. model = [
  23. ConvIN(3, 64, 4, stride=2, padding=1, bias_attr=True, weight_attr=weight_attr),
  24. ConvIN(64, 128, 4, stride=2, padding=1, bias_attr=False, weight_attr=weight_attr),
  25. ConvIN(128, 256, 4, stride=2, padding=1, bias_attr=False, weight_attr=weight_attr),
  26. ConvIN(256, 512, 4, stride=1, padding=1, bias_attr=False, weight_attr=weight_attr),
  27. nn.Conv2D(512, 1, 4, stride=1, padding=1, bias_attr=True, weight_attr=weight_attr)
  28. ]
  29. self.model = nn.Sequential(*model)
  30. def forward(self, x):
  31. return self.model(x)

4.4 测试判别器模块

# Smoke-test the ConvIN block and the discriminator on one sample batch.
ci = ConvIN(3, 3, 3, weight_attr=nn.initializer.Normal(0., 0.02))
logit = ci(paddle.to_tensor(data[0]))
print('ConvIN块输出的特征图形状:', logit.shape)
d = Disc()
logit = d(paddle.to_tensor(data[0]))
print('判别器输出的特征图形状:', logit.shape)
  1. ConvIN块输出的特征图形状: [1, 3, 224, 224]
  2. 判别器输出的特征图形状: [1, 1, 26, 26]

4.5 定义生成器

  1. # 定义基础的“转置卷积层+实例归一化”块
  2. class ConvTransIN(nn.Layer):
  3. def __init__(self, num_channels, num_filters, filter_size, stride=1, padding='same', padding_mode='constant',
  4. bias_attr=None, weight_attr=None):
  5. super(ConvTransIN, self).__init__()
  6. model = [
  7. nn.Conv2DTranspose(num_channels, num_filters, filter_size, stride=stride, padding=padding,
  8. bias_attr=bias_attr, weight_attr=weight_attr),
  9. nn.InstanceNorm2D(num_filters),
  10. nn.LeakyReLU(negative_slope=0.2)
  11. ]
  12. self.model = nn.Sequential(*model)
  13. def forward(self, x):
  14. return self.model(x)
  15. # 定义残差块
  16. class Residual(nn.Layer):
  17. def __init__(self, dim, bias_attr=None, weight_attr=None):
  18. super(Residual, self).__init__()
  19. model = [
  20. nn.Conv2D(dim, dim, 3, stride=1, padding=1, padding_mode='reflect', bias_attr=bias_attr,
  21. weight_attr=weight_attr),
  22. nn.InstanceNorm2D(dim),
  23. nn.LeakyReLU(negative_slope=0.2),
  24. ]
  25. self.model = nn.Sequential(*model)
  26. def forward(self, x):
  27. return x + self.model(x)
# CycleGAN generator: encoder (downsample) -> residual bottleneck -> decoder (upsample).
class Gen(nn.Layer):
    def __init__(self, base_dim=64, residual_num=7, downup_layer=2, weight_attr=nn.initializer.Normal(0., 0.02)):
        """base_dim      -- channel width after the stem convolution
        residual_num  -- number of residual blocks at the bottleneck
        downup_layer  -- number of down-sampling (and up-sampling) stages
        weight_attr   -- weight initializer for the convolutions
        """
        super(Gen, self).__init__()
        # Stem: 7x7 conv keeps the spatial size (reflect padding), 3 -> base_dim channels.
        model=[
            nn.Conv2D(3, base_dim, 7, stride=1, padding=3, padding_mode='reflect', bias_attr=False,
                      weight_attr=weight_attr),
            nn.InstanceNorm2D(base_dim),
            nn.LeakyReLU(negative_slope=0.2)
        ]
        # Down-sampling blocks: each halves H/W and doubles the channels.
        for i in range(downup_layer):
            model += [
                ConvIN(base_dim * 2 ** i, base_dim * 2 ** (i + 1), 3, stride=2, padding=1, bias_attr=False,
                       weight_attr=weight_attr),
            ]
        # Residual blocks at the bottleneck resolution.
        for i in range(residual_num):
            model += [
                Residual(base_dim * 2 ** downup_layer, True, weight_attr=nn.initializer.Normal(0., 0.02))
            ]
        # Up-sampling blocks: each doubles H/W and halves the channels.
        for i in range(downup_layer):
            model += [
                ConvTransIN(base_dim * 2 ** (downup_layer - i), base_dim * 2 ** (downup_layer - i - 1), 3,
                            stride=2, padding='same', padding_mode='constant', bias_attr=False, weight_attr=weight_attr),
            ]
        # Head: 7x7 conv back to 3 channels; tanh maps the output to [-1, 1].
        model += [
            nn.Conv2D(base_dim, 3, 7, stride=1, padding=3, padding_mode='reflect', bias_attr=True,
                      weight_attr=nn.initializer.Normal(0., 0.02)),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

4.6 测试生成器模块

# Smoke-test the ConvTransIN block, the Residual block, and the generator.
cti = ConvTransIN(3, 3, 3, stride=2, padding='same', padding_mode='constant', bias_attr=False,
                  weight_attr=nn.initializer.Normal(0., 0.02))
logit = cti(paddle.to_tensor(data[0]))
print('ConvTransIN块输出的特征图形状:', logit.shape)
r = Residual(3, True, weight_attr=nn.initializer.Normal(0., 0.02))
logit = r(paddle.to_tensor(data[0]))
print('Residual块输出的特征图形状:', logit.shape)
g = Gen()
logit = g(paddle.to_tensor(data[0]))
print('生成器输出的特征图形状:', logit.shape)
  1. ConvTransIN块输出的特征图形状: [1, 3, 448, 448]
  2. Residual块输出的特征图形状: [1, 3, 224, 224]
  3. 生成器输出的特征图形状: [1, 3, 224, 224]

五、训练CycleGAN网络

# Model training function.
def train(epoch_num=99999, adv_weight=1, cycle_weight=10, identity_weight=10, \
          load_model=False, model_path='./model/', model_path_bkp='./model_bkp/', \
          print_interval=1, max_step=5, model_bkp_interval=2000):
    """Train CycleGAN: two generator/discriminator pairs optimized with
    LSGAN, cycle-consistency and identity losses.

    epoch_num          -- maximum number of epochs over the dataset
    adv_weight         -- weight of the adversarial (LSGAN) loss
    cycle_weight       -- weight of the cycle-consistency loss
    identity_weight    -- weight of the identity loss
    load_model         -- resume from checkpoints in model_path when True
    model_path         -- directory for final checkpoints
    model_path_bkp     -- directory for periodic backup checkpoints
    print_interval     -- steps between loss printouts / inline previews
    max_step           -- number of steps to run in this call
    model_bkp_interval -- steps between backup checkpoint saves
    """
    # Two generator/discriminator pairs. Per the checkpoint names below,
    # g_a maps B->A (saved as gen_b2a) and g_b maps A->B (saved as gen_a2b).
    g_a = Gen()
    g_b = Gen()
    d_a = Disc()
    d_b = Disc()
    # Data reader yielding paired (A, B) samples.
    dataset = H2ZDateset(DATA_DIR)
    reader_ab = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=False, num_workers=2)
    # One Adam optimizer per network (beta1=0.5 as is conventional for GANs).
    g_a_optimizer = paddle.optimizer.Adam(learning_rate=0.0002, beta1=0.5, beta2=0.999, parameters=g_a.parameters())
    g_b_optimizer = paddle.optimizer.Adam(learning_rate=0.0002, beta1=0.5, beta2=0.999, parameters=g_b.parameters())
    d_a_optimizer = paddle.optimizer.Adam(learning_rate=0.0002, beta1=0.5, beta2=0.999, parameters=d_a.parameters())
    d_b_optimizer = paddle.optimizer.Adam(learning_rate=0.0002, beta1=0.5, beta2=0.999, parameters=d_b.parameters())
    # Image buffer queues used when training the discriminators.
    fa_pool, fb_pool = ImagePool(), ImagePool()
    # Total iteration counter (1-element array so it round-trips via np.save).
    total_step_num = np.array([0])
    # Load stored models when resuming.
    if load_model == True:
        ga_para_dict = paddle.load(model_path+'gen_b2a.pdparams')
        g_a.set_state_dict(ga_para_dict)
        gb_para_dict = paddle.load(model_path+'gen_a2b.pdparams')
        g_b.set_state_dict(gb_para_dict)
        da_para_dict = paddle.load(model_path+'dis_ga.pdparams')
        d_a.set_state_dict(da_para_dict)
        db_para_dict = paddle.load(model_path+'dis_gb.pdparams')
        d_b.set_state_dict(db_para_dict)
        total_step_num = np.load('./model/total_step_num.npy')
    # Step counter at the start of this run.
    step = total_step_num[0]
    # Main training loop.
    print('Start time :', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), 'start step:', step + 1)
    for epoch in range(epoch_num):
        for data_ab in reader_ab:
            step += 1
            # Training mode (affects layers such as BN / dropout).
            g_a.train()
            g_b.train()
            d_a.train()
            d_b.train()
            # Unpack the A/B images: [N, 2, C, H, W] -> [2, N, C, H, W].
            data_ab = np.transpose(data_ab[0], (1, 0, 2, 3, 4))
            img_ra = paddle.to_tensor(data_ab[0])
            img_rb = paddle.to_tensor(data_ab[1])
            # Train discriminator DA: real A -> 1, pooled fake A -> 0 (LSGAN).
            d_loss_ra = paddle.mean((d_a(img_ra.detach()) - 1) ** 2)
            d_loss_fa = paddle.mean(d_a(fa_pool.pool_image(g_a(img_rb.detach()))) ** 2)
            da_loss = (d_loss_ra + d_loss_fa) * 0.5
            da_loss.backward() # back-propagate gradients
            d_a_optimizer.step() # update the model weights
            d_a_optimizer.clear_grad() # clear accumulated gradients
            # Train discriminator DB: real B -> 1, pooled fake B -> 0.
            d_loss_rb = paddle.mean((d_b(img_rb.detach()) - 1) ** 2)
            d_loss_fb = paddle.mean(d_b(fb_pool.pool_image(g_b(img_ra.detach()))) ** 2)
            db_loss = (d_loss_rb + d_loss_fb) * 0.5
            db_loss.backward()
            d_b_optimizer.step()
            d_b_optimizer.clear_grad()
            # Train generator GA: fool DA + cycle B->A->B + identity on A.
            ga_gan_loss = paddle.mean((d_a(g_a(img_rb.detach())) - 1) ** 2)
            ga_cyc_loss = paddle.mean(paddle.abs(img_rb.detach() - g_b(g_a(img_rb.detach()))))
            ga_ide_loss = paddle.mean(paddle.abs(img_ra.detach() - g_a(img_ra.detach())))
            ga_loss = ga_gan_loss * adv_weight + ga_cyc_loss * cycle_weight + ga_ide_loss * identity_weight
            ga_loss.backward()
            g_a_optimizer.step()
            g_a_optimizer.clear_grad()
            # Train generator GB: fool DB + cycle A->B->A + identity on B.
            gb_gan_loss = paddle.mean((d_b(g_b(img_ra.detach())) - 1) ** 2)
            gb_cyc_loss = paddle.mean(paddle.abs(img_ra.detach() - g_a(g_b(img_ra.detach()))))
            gb_ide_loss = paddle.mean(paddle.abs(img_rb.detach() - g_b(img_rb.detach())))
            gb_loss = gb_gan_loss * adv_weight + gb_cyc_loss * cycle_weight + gb_ide_loss * identity_weight
            gb_loss.backward()
            g_b_optimizer.step()
            g_b_optimizer.clear_grad()
            # Save generated pictures, more frequently early in training.
            if step in range(1, 101):
                pic_save_interval = 1
            elif step in range(101, 1001):
                pic_save_interval = 10
            elif step in range(1001, 10001):
                pic_save_interval = 100
            else:
                pic_save_interval = 500
            if step % pic_save_interval == 0:
                save_pics([img_ra.numpy(), g_b(img_ra).numpy(), g_a(g_b(img_ra)).numpy(), g_b(img_rb).numpy(), \
                           img_rb.numpy(), g_a(img_rb).numpy(), g_b(g_a(img_rb)).numpy(), g_a(img_ra).numpy()], \
                          str(step))
                test_pic = open_pic()
                test_pic_pp = paddle.to_tensor(test_pic)
                save_pics([test_pic, g_b(test_pic_pp).numpy()], str(step), save_path='./output/pics_test/')
            # Print loss values and show generated pictures.
            if step % print_interval == 0:
                print([step], \
                      'DA:', da_loss.numpy(), \
                      'DB:', db_loss.numpy(), \
                      'GA:', ga_loss.numpy(), \
                      'GB:', gb_loss.numpy(), \
                      time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
                show_pics([img_ra.numpy(), g_b(img_ra).numpy(), g_a(g_b(img_ra)).numpy(), g_b(img_rb).numpy()])
                show_pics([img_rb.numpy(), g_a(img_rb).numpy(), g_b(g_a(img_rb)).numpy(), g_a(img_ra).numpy()])
            # Periodically back up the models.
            if step % model_bkp_interval == 0:
                paddle.save(g_a.state_dict(), model_path_bkp+'gen_b2a.pdparams')
                paddle.save(g_b.state_dict(), model_path_bkp+'gen_a2b.pdparams')
                paddle.save(d_a.state_dict(), model_path_bkp+'dis_ga.pdparams')
                paddle.save(d_b.state_dict(), model_path_bkp+'dis_gb.pdparams')
                np.save(model_path_bkp+'total_step_num', np.array([step]))
            # Save the models when this run's step budget is exhausted.
            if step >= max_step + total_step_num[0]:
                paddle.save(g_a.state_dict(), model_path+'gen_b2a.pdparams')
                paddle.save(g_b.state_dict(), model_path+'gen_a2b.pdparams')
                paddle.save(d_a.state_dict(), model_path+'dis_ga.pdparams')
                paddle.save(d_b.state_dict(), model_path+'dis_gb.pdparams')
                np.save(model_path+'total_step_num', np.array([step]))
                print('End time :', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), 'End Step:', step)
                return
# Train from scratch.
train()
# Resume training.
# train(print_interval=1, max_step=5, load_model=True)
# train(print_interval=500, max_step=20000, load_model=True)
  1. Start time : 2021-03-10 11:36:45 start step: 1
  2. [1] DA: [1.5323195] DB: [2.9221125] GA: [13.066509] GB: [20.061096] 2021-03-10 11:36:46
  3. (1, 3, 224, 224)
  4. (1, 3, 224, 224)
  5. [2] DA: [3.431984] DB: [4.0848613] GA: [13.800614] GB: [12.840221] 2021-03-10 11:36:46
  6. (1, 3, 224, 224)
  7. (1, 3, 224, 224)
  8. [3] DA: [3.3024106] DB: [2.2502034] GA: [12.881987] GB: [12.331587] 2021-03-10 11:36:47
  9. (1, 3, 224, 224)
  10. (1, 3, 224, 224)
  11. [4] DA: [3.911097] DB: [1.5154138] GA: [12.64529] GB: [14.333654] 2021-03-10 11:36:47
  12. (1, 3, 224, 224)
  13. (1, 3, 224, 224)
  14. [5] DA: [1.9493798] DB: [1.8769395] GA: [14.874502] GB: [11.431137] 2021-03-10 11:36:48
  15. (1, 3, 224, 224)
  16. (1, 3, 224, 224)
  17. End time : 2021-03-10 11:36:48 End Step: 5

png

png

png

png

png

png

png

png

png

png

六、用训练好的模型进行预测

def infer(img_path, model_path='./model/'):
    """Run the trained A->B generator on one image and show the result.

    img_path   -- path of the input image (domain A)
    model_path -- directory containing gen_a2b.pdparams
    """
    # Generator for the A -> B direction (per the checkpoint name below).
    g_b = Gen()
    # Evaluation mode: affects layers like BN / dropout.
    # (Original comment wrongly said "training mode"; the code calls eval().)
    g_b.eval()
    # Load the stored weights.
    gb_para_dict = paddle.load(model_path+'gen_a2b.pdparams')
    g_b.set_state_dict(gb_para_dict)
    # Read and preprocess the image (same transform as training).
    img_a = cv2.imread(img_path)
    img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
    img_a = data_transform(img_a, 224)
    img_a = paddle.to_tensor(np.array([img_a]))
    # Forward pass to translate the image.
    img_b = g_b(img_a)
    # Print/show the input and output pictures.
    print(img_a.numpy().shape, img_a.numpy().dtype)
    show_pics([img_a.numpy(), img_b.numpy()])
infer('./data/data10040/horse2zebra/testA/n02381460_1300.jpg')
  1. (1, 3, 224, 224) float32

png