目录
目录README.md

Jittor 第四届人工智能挑战赛热身赛项目

result

简介

本项目包含了第四届计图人工智能挑战赛——热身赛的代码实现。本项目的特点是:在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

安装

本项目可在 1 张 3090 上运行,训练时间约为 20分钟。

运行环境

  • ubuntu 18.04 LTS
  • python >= 3.8
  • jittor >= 1.3.1
  • CUDA 11.3

安装依赖

执行以下命令安装 python 依赖

pip install jittor==1.3.1

训练

  1. 导入必要的库和设置参数:导入了Jittor和其他必要的库,并设置了训练所需的参数。
  2. 定义生成器(Generator):定义了一个生成器模型,它通过接受随机噪声和数字类别作为输入,生成与给定类别相关的图像。生成器使用了全连接层和批归一化层,最终输出一个与给定图像大小相匹配的图像。
  3. 定义判别器(Discriminator):定义了一个判别器模型,它接受真实图像和相应的标签(类别),并尝试将生成的图像与真实图像区分开来。判别器使用了几个全连接层和激活函数,最终输出一个实数,表示输入图像是真实图像的概率。
  4. 定义损失函数:使用了均方误差(MSELoss)作为生成器和判别器的损失函数。生成器的目标是欺骗判别器,使其将生成的图像分类为真实图像,而判别器的目标是正确地将真实图像和生成的图像分类。
  5. 导入数据集:导入了MNIST数据集,并对图像进行了预处理。
  6. 训练模型:使用导入的数据集训练生成器和判别器模型。在每个 epoch 中,通过迭代数据集中的每个批次进行训练,优化生成器和判别器的参数。同时,定期保存生成的图像用于可视化。
  7. 生成图像:使用训练好的生成器模型生成指定数字序列对应的图像,并保存为图片文件。

推理

在训练过程中,生成器(Generator)和判别器(Discriminator)相互对抗,以提高生成器生成逼真图像的能力。

在训练之后,模型进入评估阶段。评估的任务是使用生成器生成与指定数字序列对应的图像。这个任务可以分为以下步骤:

  1. 为给定的数字序列生成随机噪声向量(latent vectors),可以使用 np.random.normal 函数生成高斯分布的随机噪声。
  2. 将数字序列转换为相应的标签,用于生成器的输入。
  3. 使用生成器(Generator)模型生成图像。生成器将噪声向量和标签作为输入,并生成与数字序列对应的图像。
  4. 将生成的图像保存到文件中,以供进一步检查和评估。

在代码的最后一部分,将生成的图像保存到名为 “result.png” 的文件中。这个图像是根据输入的数字序列生成的逼真图像。

致谢

基于计图官方示例代码填充注释为TODO的部分完成

关于

利用jittor在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

48.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号