Update README.md
本项目包含了第三届计图挑战赛热身赛计图 - 手写数字生成比赛的代码实现。本项目的特点是:将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
| 介绍基本的硬件需求、运行环境、依赖安装方法
本项目可在 cpu上运行,训练时间约为 6小时。
执行以下命令安装 python 依赖
conda install pywin32=302
预训练模型模型下载地址为 https:abc.def.gh,下载后放入目录 <root>/ 下。 由于文件大小问题,generator_last.pkl无法上传
<root>/
| 介绍模型训练的方法
# ---------- # 模型训练 # ---------- for epoch in range(opt.n_epochs): for i, (imgs, labels) in enumerate(dataloader): batch_size = imgs.shape[0] # 数据标签,valid=1表示真实的图片,fake=0表示生成的图片 valid = jt.ones([batch_size, 1]).float32().stop_grad() fake = jt.zeros([batch_size, 1]).float32().stop_grad() # 真实图片及其类别 real_imgs = jt.array(imgs) labels = jt.array(labels) # ----------------- # 训练生成器 # ----------------- # 采样随机噪声和数字类别作为生成器输入 z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32() gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32() # 生成一组图片 gen_imgs = generator(z, gen_labels) # 损失函数衡量生成器欺骗判别器的能力,即希望判别器将生成图片分类为valid validity = discriminator(gen_imgs, gen_labels) g_loss = adversarial_loss(validity, valid) g_loss.sync() optimizer_G.step(g_loss) # --------------------- # 训练判别器 # --------------------- validity_real = discriminator(real_imgs, labels) d_real_loss = adversarial_loss(validity_real,valid) """TODO: 计算真实类别的损失函数""" validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels) d_fake_loss = adversarial_loss(validity_fake,fake) """TODO: 计算虚假类别的损失函数""" # 总的判别器损失 d_loss = (d_real_loss + d_fake_loss) / 2 d_loss.sync() optimizer_D.step(d_loss) if i % 50 == 0: print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data) ) batches_done = epoch * len(dataloader) + i if batches_done % opt.sample_interval == 0: sample_image(n_row=10, batches_done=batches_done) if epoch % 10 == 0: generator.save("generator_last.pkl") discriminator.save("discriminator_last.pkl")
| 介绍模型推理、测试、或者评估的方法
生成测试集上的结果可以运行以下命令:
generator.eval() discriminator.eval() generator.load('generator_last.pkl') discriminator.load('discriminator_last.pkl') number = str(14226321421396)#TODO: 写入比赛页面中指定的数字序列(字符串类型) n_row = len(number) z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad() labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad() gen_imgs = generator(z,labels) img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)) min_=img_array.min() max_=img_array.max() img_array=(img_array-min_)/(max_-min_)*255 Image.fromarray(np.uint8(img_array)).save("result.png")
部分代码参考了 https://blog.csdn.net/qq_52852138/article/details/121682072。
本赛道将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
Jittor挑战热身赛 手写数字生成赛题 ConditionalGAN
简介
本项目包含了第三届计图挑战赛热身赛计图 - 手写数字生成比赛的代码实现。本项目的特点是:将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。
安装
| 介绍基本的硬件需求、运行环境、依赖安装方法
本项目可在 cpu上运行,训练时间约为 6小时。
运行环境
安装依赖
执行以下命令安装 python 依赖
预训练模型
预训练模型模型下载地址为 https:abc.def.gh,下载后放入目录
<root>/
下。 由于文件大小问题,generator_last.pkl无法上传训练
| 介绍模型训练的方法
推理
| 介绍模型推理、测试、或者评估的方法
生成测试集上的结果可以运行以下命令:
致谢
部分代码参考了 https://blog.csdn.net/qq_52852138/article/details/121682072。