Update README.md
Jittor Numpy Pillow
运行脚本以训练GAN模型: python gan.py 这将为指定的训练周期数训练模型,并在每10个周期保存生成器和判别器模型。 使用特定数字生成新图像: Python
generator.eval() discriminator.eval() generator.load('generator_last.pkl') discriminator.load('discriminator_last.pkl') number = 20765732071649 number_str = str(number) n_row = len(number_str) z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad() labels = jt.array(np.array([int(number_str[num]) for num in range(n_row)])).float32().stop_grad() gen_imgs = generator(z, labels) img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)) min_ = img_array.min() max_ = img_array.max() img_array = (img_array - min_) / (max_ - min_) * 255 Image.fromarray(np.uint8(img_array)).save("result.png")
GAN模型包括一个生成器和一个判别器网络:
生成器 输入:随机噪声(100维)和一个独热编码的标签(10维) 输出:32x32的灰度图像 判别器 输入:32x32的灰度图像和一个独热编码的标签(10维) 输出:一个实数,表示输入是真实图像(1)还是生成图像(0) 生成器和判别器以对抗的方式进行训练,生成器试图欺骗判别器,而判别器试图正确分类真实和生成的图像。
训练模型后,您可以通过提供特定数字来生成新的手写数字图像。生成的图像将保存为result.png。
使用Jittor框架进行噪声去除,并且进行在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
Readme
这段代码实现了一个生成对抗网络(GAN),用于生成手写数字。该模型在MNIST数据集上进行训练,并可以生成类似训练数据的新数字图像。
要求
Jittor Numpy Pillow
使用方法
运行脚本以训练GAN模型: python gan.py 这将为指定的训练周期数训练模型,并在每10个周期保存生成器和判别器模型。 使用特定数字生成新图像: Python
模型架构
GAN模型包括一个生成器和一个判别器网络:
生成器 输入:随机噪声(100维)和一个独热编码的标签(10维) 输出:32x32的灰度图像 判别器 输入:32x32的灰度图像和一个独热编码的标签(10维) 输出:一个实数,表示输入是真实图像(1)还是生成图像(0) 生成器和判别器以对抗的方式进行训练,生成器试图欺骗判别器,而判别器试图正确分类真实和生成的图像。
结果
训练模型后,您可以通过提供特定数字来生成新的手写数字图像。生成的图像将保存为result.png。