目录
目录README.md

Jittor 热身赛 CGAN手写数字生成

主要结果

简介

本项目包含了第二届计图挑战赛计图 - 热身赛CGAN手写数字生成的代码实现。本项目的特点是:使用Conditional Generative Adversarial Nets (CGAN)实现基于手写数字数据集MNIST的数字生成。

运行环境

  • Ubuntu 18.04.5
  • python >= 3.7
  • jittor >= 1.3.0
  • cuda >= 11.0

使用

运行代码后,模型将首先训练CGAN生成器与判别器,之后保存模型;模型训练结束后,自动加载模型并开始推理过程。 单卡训练可运行以下命令:

python CGAN.py --[PARAM_1] [VALUE_1] --[PARAM_2] [VALUE_2]

多卡训练(NVIDIA)可以运行以下命令:

CUDA_VISIBLE_DEVICES=[INDEX_OF_DEVICE] python CGAN.py --[PARAM_1] [VALUE_1] --[PARAM_2] [VALUE_2]

其中[INDEX_OF_DEVICE]处填写显卡编号,[PARAM_X]处填写参数名,[VALUE_X]处填写对应参数值。

主要参数

  • n_epochs:训练轮次,默认为1
  • batch_size:批次数据量大小,默认为64
  • lr:学习率,默认2e-4
  • latent_dim:隐向量维度,默认为100
  • n_classes:类别数,默认为10
  • img_size:图像尺寸,默认为32(32x32)
  • channels:图像通道数,默认为1
  • sample_interval:采样间隔,默认为1000

参考

@article{mirza2014conditional,
  title={Conditional generative adversarial nets},
  author={Mirza, Mehdi and Osindero, Simon},
  journal={arXiv preprint arXiv:1411.1784},
  year={2014}
}
关于

本项目关于Jittor挑战赛热身赛相关代码

30.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号