目录
目录readme.md

CGAN生成MNIST手写数字图像

项目说明

第三届计图人工智能挑战赛热身赛赛题,在MNIST手写数字图片集上训练Conditional GAN模型,通过输入一个随机向量作为随机噪声和类别变迁作为辅助信息,来生成特定数字的图像。

在本次项目中,通过同时训练一个生成器Generate和一个判别器Discriminator并使二者进行最大最小博弈,来不断提高Generator生成图片的能力与Discriminator判别图片来源的能力。

项目获取

  1. 通过git clone获取
git clone https://gitlink.org.cn/KIDSSCC/Jittor-CGAN_MNIST.git
  1. 通过压缩包获取

CGAN.zip:https://pan.baidu.com/s/1CYQqhJncslDEV5diToo0ww?pwd=2333

项目使用

该项目使用python语言进行编写,采用Jittor框架,参照已有的比赛代码框架进行实现。

运行环境为:

  • 操作系统:Ubuntu22.04
  • Python:Python3.10.6
  • Jittor:Jittor: 1.3.7.15

训练与测试过程可以通过如下命令开始

python3 CGAN.py [OPTION]

关于OPTION的说明:

  • –train:是否进行训练,默认取值为1,代表进行训练,当设置为0时,需要在工作目录下已保存训练模型。
  • –n_epochs:训练过程的迭代次数,默认取值为100
  • –batch_size:每一个批次的大小,默认取值为64
  • –lr:Adam优化器学习率,默认取值为0.0002
  • –b1:Adam优化器一阶矩估计衰减因子,默认取值为0.5
  • –b2:Adam优化器二阶矩估计衰减因子,默认取值为0.999
  • –n_cpu:在训练过程中使用的进程数,默认取值为8
  • –latent_dim:输入噪声的隐藏维度,默认取值为100
  • –n_classes:类别标签的类型数,默认取值为10,即对应MNIST数据集的0-9标签
  • –img_size:数据空间大小,默认取值为32,对应手写数字图像的图片大小
  • –channels:图像通道数,默认取值为1,对应手写数字图像的通道数
  • –sample_interval:图片采样的间隔数,默认取值为1000

一种可能的调用方法为:

python3 CGAN.py --n_epochs=10

训练模型并进行测试,训练迭代次数为10

项目测试

在头歌平台进行测试,指定ID为:13780231376878,

在训练参数保持默认的情况下,生成的熟悉图像如下所示:

result

在头歌平台取得0.9975的训练成绩

预训练模型

该项目提供预训练模型,包含两个文件discriminator_last.pkl与generator_last.pkl

使用方法:

将预训练模型与CGAN.py文件放置于同一目录下,随后调用:

python3 CGAN.py --train=0

即可使用预训练模型

预训练模型获取方式:https://pan.baidu.com/s/1CYQqhJncslDEV5diToo0ww?pwd=2333

关于

A Jittor implementation of Conditional GAN (CGAN).

10.0 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号