目录
目录README.md

Jittor挑战热身赛 手写数字生成赛题 ConditionalGAN

主要结果

简介

本项目包含了第三届计图挑战赛热身赛计图 - 手写数字生成比赛的代码实现。本项目的特点是:将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

安装

| 介绍基本的硬件需求、运行环境、依赖安装方法

本项目可在 cpu上运行,训练时间约为 6小时。

运行环境

  • python >= 3.9
  • jittor >= 1.3.8.5

安装依赖

执行以下命令安装 python 依赖

conda install pywin32=302

预训练模型

预训练模型模型下载地址为 https:abc.def.gh,下载后放入目录 <root>/ 下。 由于文件大小问题,generator_last.pkl无法上传

训练

| 介绍模型训练的方法

# ----------
#  模型训练
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # 数据标签,valid=1表示真实的图片,fake=0表示生成的图片
        valid = jt.ones([batch_size, 1]).float32().stop_grad()
        fake = jt.zeros([batch_size, 1]).float32().stop_grad()

        # 真实图片及其类别
        real_imgs = jt.array(imgs)
        labels = jt.array(labels)

        # -----------------
        #  训练生成器
        # -----------------

        # 采样随机噪声和数字类别作为生成器输入
        z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
        gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()

        # 生成一组图片
        gen_imgs = generator(z, gen_labels)
        # 损失函数衡量生成器欺骗判别器的能力,即希望判别器将生成图片分类为valid
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.sync()
        optimizer_G.step(g_loss)

        # ---------------------
        #  训练判别器
        # ---------------------

        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real,valid)
        """TODO: 计算真实类别的损失函数"""

        validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake,fake)
        """TODO: 计算虚假类别的损失函数"""

        # 总的判别器损失
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.sync()
        optimizer_D.step(d_loss)
        if i  % 50 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

    if epoch % 10 == 0:
        generator.save("generator_last.pkl")
        discriminator.save("discriminator_last.pkl")

推理

| 介绍模型推理、测试、或者评估的方法

生成测试集上的结果可以运行以下命令:

generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')

number = str(14226321421396)#TODO: 写入比赛页面中指定的数字序列(字符串类型)
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)

img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
min_=img_array.min()
max_=img_array.max()
img_array=(img_array-min_)/(max_-min_)*255
Image.fromarray(np.uint8(img_array)).save("result.png")

致谢

部分代码参考了 https://blog.csdn.net/qq_52852138/article/details/121682072。

关于

本赛道将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

3.8 MB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号