目录
目录README.md

Readme

这段代码实现了一个生成对抗网络(GAN),用于生成手写数字。该模型在MNIST数据集上进行训练,并可以生成类似训练数据的新数字图像。

要求

Jittor Numpy Pillow

使用方法

运行脚本以训练GAN模型: python gan.py 这将为指定的训练周期数训练模型,并在每10个周期保存生成器和判别器模型。 使用特定数字生成新图像: Python

generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')

number = 20765732071649
number_str = str(number)
n_row = len(number_str)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number_str[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)

img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
min_ = img_array.min()
max_ = img_array.max()
img_array = (img_array - min_) / (max_ - min_) * 255
Image.fromarray(np.uint8(img_array)).save("result.png")

模型架构

GAN模型包括一个生成器和一个判别器网络:

生成器 输入:随机噪声(100维)和一个独热编码的标签(10维) 输出:32x32的灰度图像 判别器 输入:32x32的灰度图像和一个独热编码的标签(10维) 输出:一个实数,表示输入是真实图像(1)还是生成图像(0) 生成器和判别器以对抗的方式进行训练,生成器试图欺骗判别器,而判别器试图正确分类真实和生成的图像。

结果

训练模型后,您可以通过提供特定数字来生成新的手写数字图像。生成的图像将保存为result.png。

关于

使用Jittor框架进行噪声去除,并且进行在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

30.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号