目录
目录README.md

cgan_jittor

本项目使用 Jittor 机器学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。

环境说明

Jittor 框架目前支持 Linux 或 Windows,需要使用 Python 及 C++编译器(g++或 clang)。 Jittor 提供了三种安装方法:docker,pip 和手动安装,具体安装教程请参考: https://cg.cs.tsinghua.edu.cn/jittor/download/。

代码框架

本次代码仅包含一个文件 CGAN.py。

生成器Generator和判别器Discriminator 中的 init 函数用于定义模型架构,execute 函数给定网络输入返回网络输出。

模型中主要使用 的模块有

  • nn.Embedding(num, dim):用于将 num 类整数标签转换为 dim 维向量。
  • nn.Linear(in_features, out_features):全连接层,输入向量维度 in_features,输出向量 维度 out_features。
  • nn.Drouout(p):将比例为 p 的特征置为 0。
  • nn.LeakyReLU(scale):ReLU 函数的变种,输入为负值时输出乘以 scale。

因为图像的尺度较小,我们直接使用了全连接层而不是通常的卷积层。

代码会自动下载 MNIST 数据集。每轮迭代 中,我们枚举数据集中的图片(imgs)和类别标签(labels)对,并随机生成一组输入向量,计算生成器和判别器损失函数,回传梯度并更新网络参数。每迭代若干轮会随机采样生成一批数字图片,如下:

随机采样生成的数字图片

运行说明

在根目录下运行python CGAN.py,可以添加的参数如下:

参数 默认值 含义
–n_epochs 50 number of epochs of training
–batch_size 64 size of the batches
–lr 0.0002 adam: learning rate
–b1 0.5 adam: decay of first order momentum of gradient
–b2 0.999 adam: decay of first order momentum of gradient
–n_cpu 8 number of cpu threads to use during batch generation
–latent_dim 100 dimensionality of the latent space
–n_classes 10 number of classes for dataset
–img_size 32 size of each image dimension
–channels 1 number of image channels
–sample_interval 1000 interval between image sampling

python CGAN.py --n_epochs 100

示例结果

更改CGAN.py中的number变量值,可以改变输出的数字序列。下面为一个样例输出:

样例输出

关于

A Jittor implementation of Conditional GAN (CGAN)

130.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号