目录
目录README.md

项目名称:基于条件生成对抗网络在MNIST数据集上的数字图像生成

实验介绍

本项目旨在通过训练条件生成对抗网络(Conditional GAN, CGAN)模型,在MNIST数字图像数据集上生成特定类别的数字图像。CGAN通过输入随机噪声向量和类别标签,生成对应类别的图像。实验包括模型的定义与训练、生成图像的评估以及结果的分析。最终,模型能够根据给定的随机ID生成对应的数字图像,并在评测中达到较高的分类准确率。

关键技术

Jittor(计图)

  • 官网Jittor
  • 特点:即时编译机制(JIT),动态地将计算图转换为高效的代码,减少计算资源的浪费,提高运行效率。

CGAN(Conditional Generative Adversarial Networks)

  • 特点:生成对抗网络(GAN)的扩展版本,允许在生成过程中过滤或控制生成的样本,以便让生成器在特定条件下生成特定类别或特征的样本。

实验过程

实验过程包括CGAN的网络结构、损失函数设计、使用CGAN生成一串数字、从头训练CGAN、以及在mnist手写数字数据集上的训练结果。详细信息可参考Jittor CGAN教程

实验方法

数据集

  • MNIST数据集:包含60000张训练图像和10000张测试图像,每张图像为28x28像素的灰度图,代表0-9共10个类别的手写数字。

模型架构

  • 生成器Generator:输入包括随机噪声和标签信息,通过多个全连接层逐渐扩展特征空间,最终生成目标图像。
  • 判别器Discriminator:输入是一个图像和标签信息,输出是真图片的概率。

损失函数

  • 生成器的损失 g_loss:使用MSE损失函数计算判别器输出的对生成图像的有效性评分与真实标签之间的均方误差。
  • 判别器的损失 d_loss:由真实图像的损失和生成图像的损失的平均值组成。

实验环境

  • CPU:Intel(R) Xeon(R) Silver 4214R CPU @ 2.40GHz,48核心
  • GPU:Tesla V100-SXM2-16GB

结果与分析

模型在MNIST训练集上训练了120个epoch,使用超参数如下:

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=120, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=1000, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=1000, help='interval between image sampling')
opt = parser.parse_args()

生成结果在文件中的result目录下可以找到。

结论

本实验成功利用CGAN在MNIST数据集上进行了数字图像的生成任务。通过引入类别标签作为条件信息,生成器能够根据指定标签生成对应的手写数字图像,模型在训练过程中不断优化生成图像的质量与多样性。使用Jittor框架实现该任务,结合了即时编译技术提高了训练效率。实验表明,CGAN在图像生成任务中具有较好的性能,并且通过合理的模型架构与超参数调优,能够在合适的训练时间内实现良好的图像生成效果。

参考文献

  1. Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., … & Bengio, Y. (2014). Generative adversarial nets. In Advances in neural information processing systems (pp. 2672-2680).
  2. Mirza, M., & Osindero, S. (2014). Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784.
  3. Radford, A., Metz, L., & Chintala, S. (2015). Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.
关于

A Jittor implementation of Conditional GAN (CGAN).

15.7 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号