add gitignore
本项目使用 Jittor 机器学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。
Jittor 框架目前支持 Linux 或 Windows,需要使用 Python 及 C++编译器(g++或 clang)。 Jittor 提供了三种安装方法:docker,pip 和手动安装,具体安装教程请参考: https://cg.cs.tsinghua.edu.cn/jittor/download/。
本次代码仅包含一个文件 CGAN.py。
生成器Generator和判别器Discriminator 中的 init 函数用于定义模型架构,execute 函数给定网络输入返回网络输出。
模型中主要使用 的模块有
因为图像的尺度较小,我们直接使用了全连接层而不是通常的卷积层。
代码会自动下载 MNIST 数据集。每轮迭代 中,我们枚举数据集中的图片(imgs)和类别标签(labels)对,并随机生成一组输入向量,计算生成器和判别器损失函数,回传梯度并更新网络参数。每迭代若干轮会随机采样生成一批数字图片,如下:
在根目录下运行python CGAN.py,可以添加的参数如下:
python CGAN.py
如python CGAN.py --n_epochs 100。
python CGAN.py --n_epochs 100
更改CGAN.py中的number变量值,可以改变输出的数字序列。下面为一个样例输出:
A Jittor implementation of Conditional GAN (CGAN)
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
cgan_jittor
本项目使用 Jittor 机器学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。
环境说明
Jittor 框架目前支持 Linux 或 Windows,需要使用 Python 及 C++编译器(g++或 clang)。 Jittor 提供了三种安装方法:docker,pip 和手动安装,具体安装教程请参考: https://cg.cs.tsinghua.edu.cn/jittor/download/。
代码框架
本次代码仅包含一个文件 CGAN.py。
生成器Generator和判别器Discriminator 中的 init 函数用于定义模型架构,execute 函数给定网络输入返回网络输出。
模型中主要使用 的模块有
因为图像的尺度较小,我们直接使用了全连接层而不是通常的卷积层。
代码会自动下载 MNIST 数据集。每轮迭代 中,我们枚举数据集中的图片(imgs)和类别标签(labels)对,并随机生成一组输入向量,计算生成器和判别器损失函数,回传梯度并更新网络参数。每迭代若干轮会随机采样生成一批数字图片,如下:
运行说明
在根目录下运行
python CGAN.py
,可以添加的参数如下:如
python CGAN.py --n_epochs 100
。示例结果
更改CGAN.py中的number变量值,可以改变输出的数字序列。下面为一个样例输出: