add CGAN.py .gitignore
这个仓库包含一个使用 Jittor 深度学习框架实现的条件生成对抗网络(CGAN)。CGAN 是一种生成对抗网络(GAN),其中生成器和判别器都依赖于一些附加信息,在本项目中是 MNIST 数据集中的类别标签。
本项目实现了一个 CGAN,用于生成基于类别标签的图像。该实现使用 Jittor 深度学习框架和 MNIST 数据集。CGAN 由一个生成器和一个判别器组成,两者共同训练以生成与给定标签对应的逼真图像。
要安装本项目所需的依赖项,请按照以下步骤操作:
克隆仓库:
git clone https://gitlink.org.cn/yzx2004/CGAN_jittor.git cd CGAN-Jittor
安装 Jittor:
按照 Jittor 官方文档 的说明安装 Jittor。
安装其他依赖项:
pip install numpy pillow
要训练 CGAN,可以运行 CGAN.py 脚本,并通过命令行参数自定义训练参数(都是可选的)。下面是一个示例:
CGAN.py
python CGAN.py [--n_epochs 100] [--batch_size 64] [--lr 0.0002] [--latent_dim 100] [--n_classes 10] [--img_size 32] [--channels 1] [--sample_interval 1000]
--n_epochs
--batch_size
--lr
--b1
--b2
--n_cpu
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
在训练过程中,生成器和判别器的模型参数会定期保存到检查点文件 generator_last.pkl 和 discriminator_last.pkl 中。这些文件的用途如下:
generator_last.pkl
discriminator_last.pkl
要加载这些检查点文件,可以在脚本中调用 generator.load('generator_last.pkl') 和 discriminator.load('discriminator_last.pkl')。
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')
训练后,可以使用训练好的生成器生成图像。脚本将在训练期间按指定间隔保存采样的图像。最终生成的图像在训练循环结束后保存为 result.png。
result.png
在训练期间,图像将按指定间隔保存(由 --sample_interval 定义)。下面是训练中采样的图像10000.png以及训练结束后生成的图像result.png的示例。
10000.png
A Jittor implementation of Conditional GAN (CGAN).
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
CGAN_jittor
这个仓库包含一个使用 Jittor 深度学习框架实现的条件生成对抗网络(CGAN)。CGAN 是一种生成对抗网络(GAN),其中生成器和判别器都依赖于一些附加信息,在本项目中是 MNIST 数据集中的类别标签。
目录
介绍
本项目实现了一个 CGAN,用于生成基于类别标签的图像。该实现使用 Jittor 深度学习框架和 MNIST 数据集。CGAN 由一个生成器和一个判别器组成,两者共同训练以生成与给定标签对应的逼真图像。
功能
安装
要安装本项目所需的依赖项,请按照以下步骤操作:
克隆仓库:
安装 Jittor:
按照 Jittor 官方文档 的说明安装 Jittor。
安装其他依赖项:
使用方法
训练 CGAN
要训练 CGAN,可以运行
CGAN.py
脚本,并通过命令行参数自定义训练参数(都是可选的)。下面是一个示例:命令行参数
--n_epochs
:训练的轮数(默认:100)--batch_size
:批次大小(默认:64)--lr
:Adam 优化器的学习率(默认:0.0002)--b1
:Adam 优化器的一阶动量的衰减(默认:0.5)--b2
:Adam 优化器的二阶动量的衰减(默认:0.999)--n_cpu
:生成批次时使用的 CPU 线程数(默认:8)--latent_dim
:潜在空间的维度(默认:100)--n_classes
:数据集的类别数量(默认:10)--img_size
:每个图像的尺寸(默认:32)--channels
:图像通道数(默认:1)--sample_interval
:图像采样的间隔(默认:1000)模型检查点
在训练过程中,生成器和判别器的模型参数会定期保存到检查点文件
generator_last.pkl
和discriminator_last.pkl
中。这些文件的用途如下:generator_last.pkl
:保存生成器的最新模型参数。可以使用这个文件在训练结束后恢复生成器模型,以继续训练或生成图像。discriminator_last.pkl
:保存判别器的最新模型参数。可以使用这个文件在训练结束后恢复判别器模型,以继续训练或评估生成器的性能。要加载这些检查点文件,可以在脚本中调用
generator.load('generator_last.pkl')
和discriminator.load('discriminator_last.pkl')
。生成图像
训练后,可以使用训练好的生成器生成图像。脚本将在训练期间按指定间隔保存采样的图像。最终生成的图像在训练循环结束后保存为
result.png
。结果展示
在训练期间,图像将按指定间隔保存(由
--sample_interval
定义)。下面是训练中采样的图像10000.png
以及训练结束后生成的图像result.png
的示例。10000.png
result.png