目录
目录README.md

CGAN_mnist

  • 基于Jittor深度学习框架实现CGAN模型的手写数字生成模型,使用的数据集为MNIST

Jittor介绍

Jittor 是一个基于即时编译和元算子的高性能深度学习框架,整个框架在即时编译的同时,还集成了强大的Op编译器和调优器,为您的模型生成定制化的高性能代码。Jittor还包含了丰富的高性能模型库,涵盖范围包括:图像识别、检测、分割、生成、可微渲染、几何学习、强化学习等。

Jittor前端语言为Python,使用了主流的包含模块化和动态图执行的接口设计,后端则使用高性能语言进行了深度优化。

CGAN介绍

Generative Adversarial Nets(GAN)提出了一种新的方法来训练生成模型:输入为一个随机向量 $z$, 生成器$G$输出一幅图像 $G(z)$, 而判别器 $D$ 需要将真实图像 $x$ 与合成图像 $G(z)$ 区分开来。

然而,GAN 对于要生成的图片缺少控制。Conditional GAN(CGAN)通过添加显式的条件或标签,来控制生成的图像。在生成器 generator 和判别器 discriminator 中添加相同的额外信息$y$,GAN 就可以扩展为一个 conditional 模型。$y$ 可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。我们可以通过将 $y$ 作为额外输入层,添加到生成器和判别器来完成条件控制。

运行

使用默认参数在命令行执行以下代码即可:

python gan.py 

自定义参数时将参数加在后边即可:

python gan.py [-h] [--n_epochs N_EPOCHS] [--batch_size BATCH_SIZE] [--lr LR] [--b1 B1] [--b2 B2] [--n_cpu N_CPU] [--latent_dim LATENT_DIM] [--n_classes N_CLASSES] [--img_size IMG_SIZE] [--channels CHANNELS] [--sample_interval SAMPLE_INTERVAL]
关于

基于jittor深度学习框架完成的CGAN模型,通过MNIST数据集训练实现手写数字生成。

31.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号