可随意转载! Update2022.10.11

前言

在之前的文章中,我们学习了DatasetGAN半监督生成语义分割数据集(image+semantic label pair),本文将要介绍的EditGAN利用了DatasetGAN生成的成对数据,通过训练edit vector,在推理时可以利用预训练的edit vector实时编辑(或者在线finetune)编辑目标图片。它的优点是对edit vector解偶能力,可高精度编辑图片。

论文标题:《EditGAN: High-Precision Semantic Image Editing》

简介

EditGAN基本方法:把输入图片x用FPNEncoder训练模型嵌入latent空间,作为”edit vector”; 训练一个semantic网络分支,利用联合概率的特性,修改语义图,通过训练好的网络影响RGB图片(达到了编辑图片的效果)。”edit vector”既可以直接用来生成图片,也可以继续finetune(默认30步)以提高编辑精度。

相关工作

学术上GAN图片编辑一般分为以下几种:

  • 利用类标签或者像素级语义标签
  • 辅助属性分类器
  • 在源和第三方图片间混合(mixing)或者插值(interpolating),用第三方图片来控制目标
  • 通过找latent code变量或者网络参数

EditGAN采用的是不同于它们的方法,它依赖于源图和对应的语义分割图的联合概率P(x, y); x是源图,y是语义分割图,且它们的latent code相同。

EditGAN的原理

人为修改语义分割yedit后,因为语义分割图和源图共享latent code,那么我们只要训练网络,优化以下参数:

new w+edit = w++δw+edit ,以匹配yedit的变化。

然而网络可学习的参数只有 δ w+≈ δ w+edit,因此在推理的时候完全可以直接用“edit vector”。当然作者也说了自监督优化一下更好(更慢)。所以EditGAN有三种模式:

  1. 用训练好的”edit vector” 可以直接做训练
  2. 用训练好的”edit vector” 参数初始化网络,做refine
  3. 从零开始训练&推理

EditGAN的训练步骤

  1. 复用StyleGAN2模型
  2. 训练FPNEncoder模型,把图片z映射到w+
  3. 利用StyleGAN2模型和FPNEnocer模型把DatasetGAN数据编码到w+空间
  4. 训练DatasetGAN模型
  5. 运行web界面,把StyleGAN2模型,FPNEncoder模型,DatasetGAN模型和FPNEncoder编码的w+图片都用起来

预训练数据

  • checkpoint/stylegan_pretrain
  • checkpoint/encoder_pretrain/training_embedding 包含masks(png文件),一些npy文件
  • checkpoint/datasetgan_pretrain 预训练的的语义分类器(decoder,classifier)

总结

这篇论文公式很少,也没有画任何网络结构,主要是在做实验。论文说除了各种实验,最终实验结果花了3500小时的GPU。

代码精读

页面流程run_app.py

api/index
api/edit_from_mask
api/apply_edit_vector

代码问题

run_optimization_post_process()方法对应页面上finetune下拉框(选择:0次),这是上面介绍的EditGAN的第一种模式。

run_optimization_editGAN() 方法实时做steps=30次优化动作,这是上面论文介绍的EditGAN第二种模式。

这两个方法都会爆显存(3090 24g),论文作者建议去改datasetGAN的输入尺寸(默认1024×1024)和MODEL_NUMBER。

经过定位爆显存的问题是在run_optimization_editGAN()run_seg()方法,如下:

# 16个affine,拼接的image_features很大,里面存的是Tensor(显存)
# print(image_features.element_size()*image_features.nelement() //1024//1024)
# run_optimization_editGAN 6G;run_seg() 4G  

image_features = []       
for i in range(len(affine_layers)):
   image_features.append(self.bi_upsamplers[i](
       affine_layers[i]))
image_features = torch.cat(image_features, 1)
image_features = image_features[:, :, 64:448]
image_features = image_features[0]
image_features = image_features.reshape(self.args['dim'], -1).transpose(1, 0)

有三种修改方法:

  • 上策:写一个torch.cat 的c++扩展算子,避免内存拷贝
  • 中策:按附录的做法,先开一块内存然后直接把Tensor拼接上去
  • 下策:每循环一次就做一次gc

a) train_encoder.py代码精读

main方法训练的是FPNEncoder网络(取名叫stylegan_encoder)。

test方法用训练好的FPNEncoder模型,通过embed_one_example()把图片转换到w+空间

FPNEncoder是从SemanticGAN改过来的(DatasetGAN中没有)。

FPN(特征金字塔)是来自目标检测方向论文《Feature pyramid networks for object detection》

b) train_interpreter.py代码精读

训练的是分类器pixel_classifier,就是semantic分支

prepare_data是从DatasetGAN改过来的。它的prepare_model相当与DatasetGAN的prepare_stylegan, 只是因为它支持StyleGAN2和StyleGAN1切换,而DatasetGAN只支持StyleGAN1。

c) EditGAN_tool.py代码精读

Tool类初始化过程如下:

加载配置文件”experiments/tool_car.json”,读取配置项’encoder_checkpoint’, ‘classfier_checkpoint’;创建四个目录:editing_vector_path,sampling_path,result_path,upload_latent_path。

# 准备模型
self.g_all, self.upsamplers, self.bi_upsamplers, self.classifier_list, self.avg_latent =prepare_model(self.args,classfier_checkpoint,self.args['classifier_iter'],num_class,self.num_classifier)
# Interpolate网络(上采样,等同与upsample,pytorch建议用这个api)
self.inter = Interpolate(self.args['im_size'][1], 'bilinear')
# FPN网络,这个在DatasetGAN里面还没有
self.stylegan_encoder = FPNEncoder(3, n_latent=self.args['n_latent'], only_last_layer=self.args['use_w'])
# 定义lpips损失(人类可视的一种损失函数定义)        
self.percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True,
                                       normalize=self.args['normalize']).to(device)

附录

torch.cat减少显存拷贝的方法

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

def high_mem():
    mylist = []
    for i in range(10):
        a = torch.zeros(20000,10000).to(device)
        mylist.append(a)

    print("memory: {}M".format(mylist[0].element_size()*mylist[0].nelement() //1024 //1024 *10))
    mylist = torch.cat(mylist, 1)
    # del(mylist)
    # torch.cuda.empty_cache()
    print("memory: {}M".format(mylist.element_size()*mylist.nelement() //1024 //1024))

def low_mem():
    x = 20000
    y = 10000
    mylist = torch.zeros(x*10, y).to(device)
    for i in range(10):
        start = i * x
        end = (i+1) * x
        mylist[ start:end, :] = torch.ones(20000,10000).to(device)

    print("memory: {}M".format(mylist.element_size()*mylist.nelement() //1024 //1024))


if __name__ == '__main__':
    # high_mem()
    low_mem()