Update README.md
本项目使用Jittor框架实现了一个条件生成对抗网络(Conditional GAN, cGAN),旨在基于随机噪声和标签生成手写数字图像,并通过判别器进行训练,以使生成的图像尽可能接近真实的手写数字图像。该实现基于MNIST数据集,可通过标签控制生成特定数字的图像。
. ├── generator_last.pkl # 训练后的生成器模型 ├── discriminator_last.pkl # 训练后的判别器模型 ├── result.png # 生成的图像 ├── CGAN.py # 主训练脚本 └── README.md # 项目的说明文件
该项目依赖以下Python包:
可以通过以下命令安装依赖:
pip install jittor numpy pillow
本项目使用MNIST数据集,该数据集包含手写数字的28x28像素图像。Jittor提供了MNIST数据集的加载器,支持数据的自动下载和预处理。
以下是训练脚本支持的命令行参数:
--n_epochs
--batch_size
--lr
--b1
--b2
--n_cpu
--latent_dim
--n_classes
--img_size
--channels
--sample_interval
在生成图像时,可在CGAN.py文件中修改number变量来指定生成图像的标签序列(例如,可替换为电话号码、自定义数字序列等)。
CGAN.py
number
每隔sample_interval训练步数,生成器会生成并保存当前训练阶段的图像,这些图像会保存为.png格式(如result.png)。
sample_interval
.png
result.png
在训练过程中,生成器和判别器交替训练。生成器尝试生成逼真的手写数字图像,以欺骗判别器将其判断为真实图像;判别器则努力区分生成图像与真实图像的差异,通过不断调整两者的参数,使生成器生成的图像质量逐渐提高。
python3 CGAN.py
generator.eval() discriminator.eval() generator.load('generator_last.pkl') discriminator.load('discriminator_last.pkl') number = "1234567890" # 替换为你自己的数字序列(例如电话号码等) n_row = len(number) z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad() labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad() gen_imgs = generator(z, labels) img_array = gen_imgs.data.transpose((1, 2, 0, 3))[0].reshape((gen_imgs.shape[2], -1)) min_ = img_array.min() max_ = img_array.max() img_array = (img_array - min_) / (max_ - min_) * 255 Image.fromarray(np.uint8(img_array)).save("result.png")
result.png:生成的图像文件,基于输入的标签和噪声生成。例如,当number为”12345”时,生成的图像将尝试呈现出与数字1、2、3、4、5相关的手写数字特征。
jt.has_cuda
batch_size
项目仓库链接如下:https://gitlink.org.cn/yifan_personal/yifan_jitu_hw.git
本项目是使用 Jittor 深度学习框架实现的条件生成对抗网络(CGAN),旨在根据特定标签生成手写数字图像。
©Copyright 2023 CCF 开源发展委员会 Powered by Trustie& IntelliDE 京ICP备13000930号
基于Jittor的条件生成对抗网络手写数字生成项目
本项目使用Jittor框架实现了一个条件生成对抗网络(Conditional GAN, cGAN),旨在基于随机噪声和标签生成手写数字图像,并通过判别器进行训练,以使生成的图像尽可能接近真实的手写数字图像。该实现基于MNIST数据集,可通过标签控制生成特定数字的图像。
项目结构
. ├── generator_last.pkl # 训练后的生成器模型 ├── discriminator_last.pkl # 训练后的判别器模型 ├── result.png # 生成的图像 ├── CGAN.py # 主训练脚本 └── README.md # 项目的说明文件
依赖
该项目依赖以下Python包:
可以通过以下命令安装依赖:
数据集
本项目使用MNIST数据集,该数据集包含手写数字的28x28像素图像。Jittor提供了MNIST数据集的加载器,支持数据的自动下载和预处理。
参数说明
训练参数
以下是训练脚本支持的命令行参数:
--n_epochs
(默认值:100):训练的总轮数。--batch_size
(默认值:64):每个批次的图像数量。--lr
(默认值:0.0002):Adam优化器的学习率。--b1
(默认值:0.5):Adam优化器的一阶矩动量衰减。--b2
(默认值:0.999):Adam优化器的二阶矩动量衰减。--n_cpu
(默认值:8):用于批量生成的CPU线程数。--latent_dim
(默认值:100):隐变量的维度。--n_classes
(默认值:10):数据集的类别数量(MNIST数据集中为10个数字类别)。--img_size
(默认值:32):图像的尺寸(宽度/高度,本项目将MNIST图像调整为此尺寸)。--channels
(默认值:1):图像的通道数(MNIST为灰度图像,通道数为1)。--sample_interval
(默认值:1000):每多少步生成并保存一次图像。生成图像相关参数
在生成图像时,可在
CGAN.py
文件中修改number
变量来指定生成图像的标签序列(例如,可替换为电话号码、自定义数字序列等)。生成的图像
每隔
sample_interval
训练步数,生成器会生成并保存当前训练阶段的图像,这些图像会保存为.png
格式(如result.png
)。训练过程
在训练过程中,生成器和判别器交替训练。生成器尝试生成逼真的手写数字图像,以欺骗判别器将其判断为真实图像;判别器则努力区分生成图像与真实图像的差异,通过不断调整两者的参数,使生成器生成的图像质量逐渐提高。
如何运行
CGAN.py
文件中的以下代码: 将number
替换为你自己的数字序列(如手机号或其他自定义数字序列),然后运行代码即可生成与标签对应的图像。生成图像示例
result.png
:生成的图像文件,基于输入的标签和噪声生成。例如,当number
为”12345”时,生成的图像将尝试呈现出与数字1、2、3、4、5相关的手写数字特征。可能遇到的问题
jt.has_cuda
)以及CUDA相关驱动和环境变量是否设置正确来排查。batch_size
。如果内存不足,可以尝试减小batch_size
或者关闭其他占用内存的程序。仓库信息
项目仓库链接如下:https://gitlink.org.cn/yifan_personal/yifan_jitu_hw.git