目录
目录README.md

基于Jittor的条件生成对抗网络手写数字生成项目

本项目使用Jittor框架实现了一个条件生成对抗网络(Conditional GAN, cGAN),旨在基于随机噪声和标签生成手写数字图像,并通过判别器进行训练,以使生成的图像尽可能接近真实的手写数字图像。该实现基于MNIST数据集,可通过标签控制生成特定数字的图像。

项目结构

. ├── generator_last.pkl # 训练后的生成器模型 ├── discriminator_last.pkl # 训练后的判别器模型 ├── result.png # 生成的图像 ├── CGAN.py # 主训练脚本 └── README.md # 项目的说明文件

依赖

该项目依赖以下Python包:

  • Jittor:用于高效的深度学习计算。
  • Numpy:用于数组操作和数值计算。
  • Pillow(PIL):用于图像保存和处理。
  • argparse:用于命令行参数解析。

可以通过以下命令安装依赖:

pip install jittor numpy pillow

数据集

本项目使用MNIST数据集,该数据集包含手写数字的28x28像素图像。Jittor提供了MNIST数据集的加载器,支持数据的自动下载和预处理。

参数说明

训练参数

以下是训练脚本支持的命令行参数:

  • --n_epochs(默认值:100):训练的总轮数。
  • --batch_size(默认值:64):每个批次的图像数量。
  • --lr(默认值:0.0002):Adam优化器的学习率。
  • --b1(默认值:0.5):Adam优化器的一阶矩动量衰减。
  • --b2(默认值:0.999):Adam优化器的二阶矩动量衰减。
  • --n_cpu(默认值:8):用于批量生成的CPU线程数。
  • --latent_dim(默认值:100):隐变量的维度。
  • --n_classes(默认值:10):数据集的类别数量(MNIST数据集中为10个数字类别)。
  • --img_size(默认值:32):图像的尺寸(宽度/高度,本项目将MNIST图像调整为此尺寸)。
  • --channels(默认值:1):图像的通道数(MNIST为灰度图像,通道数为1)。
  • --sample_interval(默认值:1000):每多少步生成并保存一次图像。

生成图像相关参数

在生成图像时,可在CGAN.py文件中修改number变量来指定生成图像的标签序列(例如,可替换为电话号码、自定义数字序列等)。

生成的图像

每隔sample_interval训练步数,生成器会生成并保存当前训练阶段的图像,这些图像会保存为.png格式(如result.png)。

训练过程

在训练过程中,生成器和判别器交替训练。生成器尝试生成逼真的手写数字图像,以欺骗判别器将其判断为真实图像;判别器则努力区分生成图像与真实图像的差异,通过不断调整两者的参数,使生成器生成的图像质量逐渐提高。

如何运行

  1. 下载并安装依赖项:
    pip install jittor numpy pillow
  2. 下载MNIST数据集: 训练脚本会自动下载MNIST数据集并进行预处理,无需手动下载。
  3. 运行训练脚本:
    python3 CGAN.py
    训练过程中,模型会自动保存生成器和判别器的状态。每1000步,会生成一个新图像并保存。
  4. 使用生成的模型生成图像: 训练完成后,若要使用保存的生成器模型生成图像,需修改CGAN.py文件中的以下代码:
    generator.eval()
    discriminator.eval()
    generator.load('generator_last.pkl')
    discriminator.load('discriminator_last.pkl')
    number = "1234567890"  # 替换为你自己的数字序列(例如电话号码等)
    n_row = len(number)
    z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
    labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
    gen_imgs = generator(z, labels)
    img_array = gen_imgs.data.transpose((1, 2, 0, 3))[0].reshape((gen_imgs.shape[2], -1))
    min_ = img_array.min()
    max_ = img_array.max()
    img_array = (img_array - min_) / (max_ - min_) * 255
    Image.fromarray(np.uint8(img_array)).save("result.png")
    number替换为你自己的数字序列(如手机号或其他自定义数字序列),然后运行代码即可生成与标签对应的图像。

生成图像示例

result.png:生成的图像文件,基于输入的标签和噪声生成。例如,当number为”12345”时,生成的图像将尝试呈现出与数字1、2、3、4、5相关的手写数字特征。

可能遇到的问题

  1. CUDA错误(如果使用GPU训练):确保Jittor正确安装并且CUDA配置正常。可通过检查Jittor是否能正确检测到CUDA(运行jt.has_cuda)以及CUDA相关驱动和环境变量是否设置正确来排查。
  2. 内存问题:训练过程中可能会占用较多内存,建议根据系统内存情况使用合适的batch_size。如果内存不足,可以尝试减小batch_size或者关闭其他占用内存的程序。

仓库信息

项目仓库链接如下:https://gitlink.org.cn/yifan_personal/yifan_jitu_hw.git

关于

本项目是使用 Jittor 深度学习框架实现的条件生成对抗网络(CGAN),旨在根据特定标签生成手写数字图像。

68.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号