目录
目录README.md

GAN MNIST图像生成

这是一个使用GAN生成手写数字MNIST图像的代码,它由一个生成器和一个鉴别器组成。

算法介绍

GAN(Generative Adversarial Networks)是一种生成对抗网络,由一个生成器和一个鉴别器组成,用于生成与真实数据相似的数据。生成器用于生成数据,而鉴别器则用于区分生成的数据与真实数据。生成器和鉴别器通过训练不断提高各自的能力,最终达到生成与真实数据无法区分的数据的目的。

本代码使用GAN生成手写数字MNIST图像,具体实现方法为:

  1. 生成器:使用全连接层和LeakyReLU激活函数构建生成器,将噪声向量和标签进行拼接,生成图像。

  2. 鉴别器:使用全连接层和LeakyReLU激活函数构建鉴别器,将图像和标签进行拼接,输出一个实数,用于判断真实或生成的数据。

  3. 损失函数:使用平方误差损失函数进行训练,让鉴别器输出的实数与真实标签的平方误差最小,同时让生成器输出的图像经过鉴别器后的实数与真实标签的平方误差最小。

代码说明

  • generator:生成器类,用于生成图像。

  • discriminator:鉴别器类,用于判断真实或生成的数据。

  • adversarial_loss:平方误差损失函数。

  • dataloader:使用Jittor框架自带的MNIST数据集进行训练。

  • save_image:用于保存生成的图像。

使用方法

  1. 克隆代码库。

  2. 安装Jittor框架。

  3. 运行python main.py,开始训练。

  4. 训练结束后,在output文件夹中可以看到生成的图像。

关于

A Jittor implementation of Conditional GAN (CGAN).

35.0 KB
邀请码