目录
目录README.md

CGAN_jittor——A Jittor implementation of Conditional GAN (CGAN)

项目简述

        本项目是一个通过 Jittor 框架实现的简单 CGAN 模型,具体是在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。内部的模型结构较为简单,主要是通过全连接层进行构造,配合有 Dropout、LeakyReLU 等方法。

CGAN

        GAN 模型通过同时训练生成器 generator 和判别器 discriminator 进行对抗,目的是训练得到一个好的 generator。具体来说,是 generator 使用随机向量 Z 生成图像 G(Z);discriminator 对生成图像 G(Z)和真实图像 X 同时进行判别分别输出 D(G(Z))和 D(X),对认为是真实图像的输出 1,认为是生成图像的输出 0,然后计算此次判别的$loss={1\over 2} {(0-D(G(Z)))^2 + (1-D(X))^2}$,进行反向传播训练 discriminator;而对于判别结果 D(G(Z)),generator 没能骗过 discriminator 的即为$loss={(1-D(G(Z)))^2}$。CGAN 模型在 GAN 模型生成中加入了条件 label Y,增加了一个额外的控制输入。

Jittor

        一个完全基于动态编译(Just-in-time),内部使用创新的元算子和统一计算图的深度学习框架, 元算子和 Numpy 一样易于使用,并且超越 Numpy 能够实现更复杂更高效的操作。而统一计算图则是融合了静态计算图和动态计算图的诸多优点,在易于使用的同时,提供高性能的优化。基于元算子开发的深度学习模型,可以被计图实时的自动优化并且运行在指定的硬件上,如 CPU,GPU,TPU。具体关于 Jittor 的介绍参见https://cg.cs.tsinghua.edu.cn/jittor/


项目功能

        该项目可以将输入的数字字符串生成为图片。生成图片由模型生成,因其在 MNSIST 上训练,因此生成的图片会与其中的图片风格接近

        具体图片效果如下(以默认输入数字字符串’11024271101857’为例):

        output


使用方法

        本项目使用前需要首先完成’项目安装’,然后可以选择’直接使用’已经训练好的模型,或者选择自己’训练模型’,请分别参考下文的三个内容


项目安装

        主要需要安装 jittor 库,在 https://cg.cs.tsinghua.edu.cn/jittor/download/ 中选择适合自己系统的 jittor 类型,依照网站所给的指令进行安装,推荐在 linux 系统上运行。

        下面给出在 linux 系统上安装 jittor 库的具体步骤,其他操作系统可以在前文的网址中自行查找。

        以下是 Linux 系统 CPU 版本通过 conda 安装的指令,也是本人开发时所使用的的环境:

sudo apt install python3.7-dev libomp-dev
python3.7 -m pip install jittor
python3.7 -m jittor.test.test_example

        而关于 Jittor 在 linux 系统的依赖则如下所示:

Python:版本 >= 3.7
C++编译器 (需要下列至少一个)
    g++ (>=5.4.0 for linux)
    clang (>=8.0 for mac)
GPU 编译器(可选):nvcc >=10.0
GPU 加速库(可选):cudnn-dev

直接使用

        如果想使用已经训练好的模型,可以在完成”项目安装”后,从 https://cloud.tsinghua.edu.cn/d/75dda58ad047452bb9a3/ 中下载 generator_last.pkl,将其与 gen.py 置于同一文件夹下,运行:

python gen.py

        直接使用时有 2 个可选参数:

--input,参数为为想生成图片的数字,默认11024271101857

--output,参数为生成的图片的地址,默认"output.png"

        完整版的运行方法如下,以在地址 temp.png 生成 2021012806 的图片为例:

python gen.py --input 2021012806 --output temp.png

训练模型

        如果想自己训练模型,可以在完成”项目安装”后,运行:

python CGAN.py

        训练模型时有如下 11 个可选参数:

参数名 参数意义 默认值
–n_epochs number of epochs of training 100
–batch_size size of the batches 64
–lr adam: learning rate 0.0002
–b1 adam: decay of first order momentum of gradient 0.5
–b2 adam: decay of first order momentum of gradient 0.999
–n_cpu number of cpu threads to use during batch generation 8
–latent_dim dimensionality of the latent space 100
-n_classes number of classes for dataset 10
–img_size size of each image dimension 32
–channels number of image channels 1
–sample_interval interval between image sampling 1000

        训练中每当训练量增加 sample_interval 的时候,会输出一张当前训练效果下的数字的图像(训练量.png),每个数字生成 10 个。

        例如以下是训练数据量为 90000 时的输出,90000.png:

        90000

        此外,每训练 10 轮会将模型中的参数进行一次保存,生成器的参数保存在”generator_last.pkl”,判别器的参数保存在”discriminator_last.pkl”,方便训练中断后加载已训练的数据,避免重复训练。

        训练结束后,会自动以’11024271101857’为样例进行生成测试,生成’result.png’图片。

        后续如果想使用自己训练的模型,参照’直接使用’部分 gen.py 的使用方法即可。注意如果对 latent_dim、n_class、img_size、channels 这四个参数进行了修改,在执行 gen.py 的时候也要显示地带上修改后的参数的值,如’python gen.py–latent_dim xx’。

关于

A Jittor implementation of Conditional GAN (CGAN)

106.0 KB
邀请码