目录
目录README.md

CGAN_jittor

这个仓库包含一个使用 Jittor 深度学习框架实现的条件生成对抗网络(CGAN)。CGAN 是一种生成对抗网络(GAN),其中生成器和判别器都依赖于一些附加信息,在本项目中是 MNIST 数据集中的类别标签。

目录

介绍

本项目实现了一个 CGAN,用于生成基于类别标签的图像。该实现使用 Jittor 深度学习框架和 MNIST 数据集。CGAN 由一个生成器和一个判别器组成,两者共同训练以生成与给定标签对应的逼真图像。

功能

  • 使用类别标签的条件图像生成
  • 通过命令行参数自定义训练参数
  • 在训练过程中定期采样图像
  • 模型检查点保存和加载训练好的模型

安装

要安装本项目所需的依赖项,请按照以下步骤操作:

  1. 克隆仓库:

    git clone https://gitlink.org.cn/yzx2004/CGAN_jittor.git
    cd CGAN-Jittor
  2. 安装 Jittor:

    按照 Jittor 官方文档 的说明安装 Jittor。

  3. 安装其他依赖项:

    pip install numpy pillow

使用方法

训练 CGAN

要训练 CGAN,可以运行 CGAN.py 脚本,并通过命令行参数自定义训练参数(都是可选的)。下面是一个示例:

python CGAN.py [--n_epochs 100] [--batch_size 64] [--lr 0.0002] [--latent_dim 100] [--n_classes 10] [--img_size 32] [--channels 1] [--sample_interval 1000]

命令行参数

  • --n_epochs:训练的轮数(默认:100)
  • --batch_size:批次大小(默认:64)
  • --lr:Adam 优化器的学习率(默认:0.0002)
  • --b1:Adam 优化器的一阶动量的衰减(默认:0.5)
  • --b2:Adam 优化器的二阶动量的衰减(默认:0.999)
  • --n_cpu:生成批次时使用的 CPU 线程数(默认:8)
  • --latent_dim:潜在空间的维度(默认:100)
  • --n_classes:数据集的类别数量(默认:10)
  • --img_size:每个图像的尺寸(默认:32)
  • --channels:图像通道数(默认:1)
  • --sample_interval:图像采样的间隔(默认:1000)

模型检查点

在训练过程中,生成器和判别器的模型参数会定期保存到检查点文件 generator_last.pkldiscriminator_last.pkl 中。这些文件的用途如下:

  • generator_last.pkl:保存生成器的最新模型参数。可以使用这个文件在训练结束后恢复生成器模型,以继续训练或生成图像。
  • discriminator_last.pkl:保存判别器的最新模型参数。可以使用这个文件在训练结束后恢复判别器模型,以继续训练或评估生成器的性能。

要加载这些检查点文件,可以在脚本中调用 generator.load('generator_last.pkl')discriminator.load('discriminator_last.pkl')

生成图像

训练后,可以使用训练好的生成器生成图像。脚本将在训练期间按指定间隔保存采样的图像。最终生成的图像在训练循环结束后保存为 result.png

结果展示

在训练期间,图像将按指定间隔保存(由 --sample_interval 定义)。下面是训练中采样的图像10000.png以及训练结束后生成的图像result.png的示例。

10000.png

1716856384268

result.png

1716856402959

关于

A Jittor implementation of Conditional GAN (CGAN).

81.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号