如何使用预训练的权重从自定义数据集中生成图像?


来源:磐创AI   时间:2021-04-08 10:55:36


概要

分享我的知识,使用带有示例代码片段的迁移学习逐步在Google colab中的自定义数据集上训练StyleGAN

如何使用预训练的权重从自定义数据集中生成图像

使用不同的种子值生成新图像

介绍

生成对抗网络(GAN) 是机器学习中的一项最新创新,由 Ian J. Goodfellow 及其同事于2014年首次提出。

它是一组神经网络,以两人零和博弈的形式相互对抗。博弈论(一个人的胜利就是另一个人的损失)。它是用于无监督学习的生成模型的一种形式。这里有一个生成器(用于从潜在空间中的某个点在数据上生成新实例)和鉴别器(用于将生成器生成的数据与实际或真实数据值区分开)。

最初,生成器生成虚假或伪造的数据,鉴别器可以将其分类为伪造,但是随着训练的继续,生成器开始学习真实数据的分布并开始生成真实的数据。这种情况一直持续到鉴别器无法将其分类为不真实的并且生成器输出的所有数据看起来都像真实数据。因此,此处生成器的输出连接到鉴别器的输入,并根据鉴别器的输出(是实数还是非实数)计算损失,并通过反向传播,为后续训练(epoch)更新生成器的权重。

StyleGAN目前在市场上有多种GAN变体,但在本文中,我将重点介绍Nvidia在2018年12月推出的StyleGAN。StyleGAN的体系结构使用基线渐进式GAN。即,生成图像的大小从非常低的角度逐渐增加分辨率(4×4)到非常高的分辨率(1024×1024),并使用双线性采样代替基线渐进式GAN中使用的最近邻居上/下采样。

迁移学习在另一个相似的数据集上使用已训练的模型权重并训练自定义数据集。

自定义数据集包含2500个来自时尚的纹理图像。下面几张示例纹理图像可供参考。此处你可以替换成自己的自定义数据集。

重点和前提条件:必须使用GPU,StyleGAN无法在CPU环境中进行训练。为了演示,我已经使用google colab环境进行实验和学习。

确保选择Tensorflow版本1.15.2。StyleGAN仅适用于tf 1.x

StyleGAN训练将花费大量时间(几天之内取决于服务器容量,例如1个GPU,2个GPU等)

如果你正在从事与GAN相关的任何实时项目,那么由于colab中的使用限制和超时,你可能想在 tesla P-80或 P-100专用服务器上训练GAN 。

如果你有google-pro(不是强制性的),则可以节省多达40-50%的本文训练时间 ,我对GAN进行了3500次迭代训练,因为训练整个GAN需要很长时间(要获取高分辨率图像),则需要至少运行25000次迭代(推荐)。另外,我的图像分辨率是64×64,但是styleGAN是在1024×1024分辨率图像上训练的。

我已使用以下预先训练的权重来训练我的自定义数据集(有关更多详细信息,请参见Tensorflow Github官方链接)

https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ使用迁移学习在Google Colab中的自定义数据集上训练style GAN打开colab并打开一个新的botebook。确保在Runtime->Change Runtime type->Hardware accelerator下设置为GPU验证你的帐户并装载G驱动器from google.colab import drivedrive.mount('/content/drive', force_remount=True)确保选择了Tensorflow版本1.15.2。StyleGAN仅适用于tf1.x。%tensorflow_version 1.ximport tensorflowprint(tensorflow.__version__)从 https://github.com/NVlabs/stylegan 克隆stylegan.git!git clone https://github.com/NVlabs/stylegan.git!ls /content/stylegan/You should see something like thisconfig.py LICENSE.txt run_metrics.pydataset_tool.py metrics stylegan-teaser.pngdnnlib pretrained_example.py traininggenerate_figures.py README.md train.py5. 将 stylegan文件夹添加到python,以导入dnnlib模块import syssys.path.insert(0, "/content/stylegan")import dnnlib6. 将自定义数据集从G驱动器提取到你选择的colab服务器文件夹中!unrar x "/content/drive/My Drive/CustomDataset.rar" "/content/CData/"7. Stylegan要求图像必须是正方形,并且为获得很好的分辨率,图像必须为1024×1024。但是在本演示中,我将使用64×64的分辨率,下一步是将所有图像调整为该分辨率。# resize all the images to same sizeimport osfrom tqdm import tqdmimport cv2from PIL import Imagefrom resizeimage import resizeimagepath = '/content/CData/'for filename in tqdm(os.listdir(path),desc ='reading images ...'):image = Image.open(path+filename)image = image.resize((64,64))image.save(path+filename, image.format)8.将自定义数据集复制到colab并调整大小后,使用以下命令将自定义图像转换为tfrecords。这是StyleGAN的要求,因此此步骤对于训练StyleGAN是必不可少的。! python /content/stylegan/dataset_tool.py create_from_images /content/stylegan/datasets/custom-dataset /content/texturereplace your custom dataset path (instead of /content/texture)9.一旦成功创建了tfrecords,你应该查看它们/content/stylegan/datasets/custom-dataset/custom-dataset-r02.tfrecords - 22/content/stylegan/datasets/custom-dataset/custom-dataset-r03.tfrecords - 23/content/stylegan/datasets/custom-dataset/custom-dataset-r04.tfrecords -24/content/stylegan/datasets/custom-dataset/custom-dataset-r05.tfrecords -25/content/stylegan/datasets/custom-dataset/custom-dataset-r06.tfrecords -26These tfrecords correspond to 4x4 , 8x8 ,16x16, 32x32 and 64x64 resolution images (baseline progressive) respectiviely

免责声明:本网站所有信息仅供参考,不做交易和服务的根据,如自行使用本网资料发生偏差,本站概不负责,亦不负任何法律责任。

延伸阅读

最新文章

国家能源局原副局长张玉清:综合能源服务发展有哪些趋势? 国家能源局原副局长张玉清:综合能源服务发展有哪些趋势?

精彩推荐

产业新闻

中国航空规划设计研究总院出品:郑州西部环保能源工程设计图 中国航空规划设计研究总院出品:郑州西部环保能源工程设计图

热门推荐