目录
目录README.md

第四届计图人工智能挑战赛-热身赛

简介

本项目是在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型。
通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

Image 1 Image 2

本项目中的随机ID为:11874931186746 生成的图像为: 1713271763005

安装

运行环境

  • Windows 10
  • python >= 3.7
  • jittor >= 1.3.0

安装依赖

执行以下命令安装 jittor

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple jittor

MNIST数据集

如果已有MNIST数据集,或者觉得下载太慢,可以尝试先从其他地方下载,然后将下载的文件复制到C:\Users\Administrator\.cache\jittor\dataset\mnist_data路径当中。

文件夹结构为:

├─mnist_data
│ ├─t10k-images-idx3-ubyte
│ ├─t10k-images-idx3-ubyte.gz
│ └─t10k-labels-idx1-ubyte
│ └─t10k-labels-idx1-ubyte.gz
│ └─train-images-idx3-ubyte
│ └─train-images-idx3-ubyte.gz
│ └─train-labels-idx1-ubyte
│ └─train-labels-idx1-ubyte.gz

最后将download设置为False即可。

dataloader = MNIST(train=True, transform=transform,download=False).set_attrs(batch_size=opt.batch_size, shuffle=True)

训练

运行训练脚本

python CGAN.py

测试

注释掉CGAN.py中模型训练部分的内容,加载已经训练好的权重,指定待生成的数字序列:

number = "11874931186746" # 自定义数字序列

然后重新运行CGAN.py即可生成指定数字序列对应的图像。

致谢

此项目基于论文 A Style-Based Generator Architecture for Generative Adversarial Networks 实现,部分代码参考了 jittor-gan

关于

Jittor 是一个基于即时编译和元算子的高性能深度学习框架,整个框架在即时编译的同时,还集成了强大的Op编译器和调优器,为您的模型生成定制化的高性能代码。Jittor还包含了丰富的高性能模型库,涵盖范围包括:图像识别、检测、分割、生成、可微渲染、几何学习、强化学习等。

10.2 MB
邀请码