目录
目录README.md

Jittor 热身赛 Conditional GAN

主要结果

简介

| 简单介绍项目背景、项目特点

本项目包含了第三届计图挑战赛计图 - 生成特定数字图像的代码实现。

本项目的特点是:采用了 Conditional GAN 模型在MNIST数据集上训练,通过输入一个随机向量z和额外的辅助信息y,生成特定数字的图像。

安装

| 介绍基本的硬件需求、运行环境、依赖安装方法

本项目可在 1 张 3060Ti上运行,训练时间约为 20 分钟。

运行环境

  • Windows 10
  • python >= 3.8
  • jittor >= 1.3.8

安装依赖

创建虚拟环境

conda create -n jittor python=3.8

激活虚拟环境

conda activate jittor

安装 jittor所需环境

python -m pip install jittor
python -m jittor.test.test_core
python -m jittor.test.test_example
python -m jittor.test.test_cudnn_op

预训练模型

generator预训练模型为generator_last.pkl

discriminator预训练模型为discriminator_last.pkl

预训练模型模型在仓库的根目录中,下载后放到项目的根目录下。

数据预处理

在这个代码中,数据预处理是在数据集加载时完成的。在数据集加载时,使用transforms.Compose函数将多个数据预处理步骤组合在一起。这些步骤包括将图像转换为张量、将图像像素值归一化到[-1,1]之间、将图像大小调整为指定大小等。这些预处理步骤可以帮助模型更好地学习数据集中的模式和特征。在这个代码中,数据集是MNIST手写数字数据集,因此预处理步骤包括将图像转换为灰度图像、将图像大小调整为32x32、将图像像素值归一化到[-1,1]之间。这些预处理步骤可以帮助模型更好地学习手写数字的特征和模式。

训练

在这个代码中,模型的训练是通过训练循环来实现的。训练循环迭代数据集中的每个批次,并使用生成器和鉴别器网络对批次中的真实图像和生成的图像进行训练。在每个迭代中,生成器网络使用随机噪声向量和标签生成一批假图像,并将其传递给鉴别器网络进行训练。鉴别器网络使用真实图像和生成的图像进行训练,并将其与真实标签进行比较。在每个迭代中,生成器和鉴别器网络的损失值都会被打印到控制台上。

训练循环还使用batches_done变量来跟踪已处理的批次数。如果已处理的批次数是sample_interval的倍数,则使用sample_image函数生成一张样本图像。这个函数使用生成器网络生成一批假图像,并将它们保存到文件中。

每10个epoch,生成器和鉴别器网络都会使用save方法保存到磁盘上。然后使用load方法从磁盘上加载生成器和鉴别器网络。这样做是为了确保始终使用最新版本的网络进行图像生成。

在训练过程中,可以通过调整超参数和训练数据集来改进生成的图像的质量。可以使用其他评估指标来评估模型的性能,例如生成图像的多样性、清晰度和真实性等。

推理

在这个代码中,模型的评估可以通过生成的图像进行视觉检查。可以通过调整生成器的超参数和训练数据集来改进生成的图像的质量。此外,可以使用其他评估指标来评估模型的性能,例如生成图像的多样性、清晰度和真实性等。可以使用人类评估者对生成的图像进行主观评估,或者使用自动评估指标,例如Inception Score和Fréchet Inception Distance等。这些指标可以帮助评估生成的图像与真实图像之间的差异。在测试和推理方面,可以使用训练好的生成器网络来生成新的手写数字图像。可以通过调整输入噪声向量和标签来生成不同的图像。可以使用不同的数字标签来生成不同的数字图像。可以使用生成的图像来测试模型的性能和生成能力。

致谢

此项目代码参考了 jittor热身赛示例代码

关于

2023Jittor挑战赛热身题

38.0 KB
邀请码