目录
目录README.md

第四届计图Jittor 人工智能挑战赛热身赛

简介

在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

安装

本项目可在 1 张 RTX 3050Ti 上运行,训练时间约为25分钟。

运行环境

windows 11 python >= 3.8 jittor >= 1.3.0 ###安装依赖 pip install jittor ##数据预处理 加载MNIST数据集:使用Jittor内置的MNIST数据集加载器,读取MNIST数据集的训练集和测试集。 将像素值归一化:将像素值从0-255范围内的整数归一化到-1到1之间的浮点数,以便更好地训练GAN。 将图像转换为张量:将每个图像转换为一个张量,并将标签也转换为张量。 对标签进行一次性编码:将标签转换为一个one-hot编码的张量,以便在训练GAN时能够提供给生成器和判别器网络。 打乱数据集:将训练集中的图像和标签打乱,以便更好地训练GAN。 ##训练 在训练前可以修改 number 变量以生成不同的手写数字序列。在训练过程中,程序将会输出每个 epoch 中的判别器(D)和生成器(G)的损失函数。每个 epoch 结束时,程序将会生成一组图片并保存在当前目录下。此外,在代码结束后,程序将会生成一张包含指定数字序列的手写数字图片并保存在当前目录下。

推理

在这个代码中,模型的推理、测试和评估都是通过生成器(Generator)实现的。具体来说,通过生成器的 execute 方法,输入随机噪声和数字类别,生成一组图片。在推理和测试时,我们可以使用该方法生成图片,并进行可视化或者其他的结果展示。在评估时,我们可以根据具体的评估指标,计算生成的图片与真实图片之间的差异,来评估生成器的性能。在本代码中,我们使用了训练过程中的损失函数作为评估指标,以衡量生成器欺骗判别器的能力。

致谢

感谢计图官方提供的示例代码。

注意事项

注意环境要搭配好,不然会出现 jittor 调用不了的情况。

关于
71.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号