目录
目录README.md

第二届计图挑战赛之热身赛

基于Jittor框架实现的CGAN网络(数字生成)

image-20220505105649057

简介

本项目包含了第二届计图挑战赛热身赛的代码。该比赛在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像

这里是本项目所使用的[基本代码框架](https://github.com/Jittor/gan-jittor/blob/master/ competition/warm_up_comp/CGAN.py),本项目在此基础上完成了TODO的要求。

安装

本项目在普通安装Nvidia显卡的电脑上即可运行。

安装Jittor的方式请到官网查看详情

运行环境

  • python >= 3.8
  • CUDA >= 10.2
  • jittor >= 1.3.0

安装依赖

执行以下命令安装 python 依赖

pip install jittor 

预训练模型

本次实验代码没有使用预训练模型,但训练完毕之后的checkpoint将会存至到与main.py同目录下的generator_last.pkl与discriminator_last.pkl文件之中,以供后续使用。

数据预处理

本代码的数据是从网上下载的,无需数据预处理操作

训练

单卡训练可直接运行以下命令:

python main.py 

致谢

本份代码基于 jittor-gan所提供的框架实现,感谢代码框架的提供者。

感谢国家自然科学基金委信息科学部北京信息科学与技术国家研究中心清华-腾讯互联网创新技术联合实验室共同指导举办的本次Jittor人工智能挑战赛

感谢腾讯科技(深圳)有限公司对本次赛事的赞助

注意事项

  • 在安装Jittor时若遇到错误,应先检查python的版本,若是版本问题,建议去官网下载python新版本,本人使用的是python 3.10.4,亲测可以顺利安装Jittor。
  • 代码中主要使用到的Jittor框架中的模块有:
    • nn.Embedding(num, dim):用于将 num 类整数标签转换为 dim 维向量
    • nn.Linear(in_features, out_features):全连接层,输入向量维度 in_features,输出向量 维度 out_features
    • nn.Drouout(p):将比例为 p 的特征置为 0
    • nn.LeakyReLU(scale):ReLU 函数的变种,输入为负值时输出乘以 scale
关于

使用 Jittor 机器学习框架,在数字图片数据集 MNIST 上训练一个将随机噪声和类别标签映射为数字图片的 Conditional GAN 模型,生成指定数字序列对应的图片。

32.0 KB
邀请码