python
for epoch in range(opt.n_epochs):
for i, (imgs, labels) in enumerate(dataloader):
batch_size = imgs.shape[0]
# 定义真实和假图像的标签
# 训练生成器
# 训练判别器
# 打印损失
python
number = "2212784"
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)
一、项目介绍
该项目实现了一个条件生成对抗网络(CGAN),使用Jittor深度学习框架对MNIST数据集进行训练,以生成与特定数字标签相对应的手写数字图像。代码首先导入所需的库并设置CUDA支持,然后通过命令行参数解析配置训练的各种超参数,如训练轮数、批大小、学习率等。
定义了两个主要的神经网络模型:生成器(Generator)和判别器(Discriminator)。生成器接受随机噪声和类别标签作为输入,生成对应的图像;而判别器则判断输入的图像是真实的还是生成的。通过均方误差损失函数来评估生成器和判别器的表现,使用Adam优化器分别优化这两个网络。数据加载部分通过对MNIST数据集进行预处理来生成训练数据,然后在训练循环中,交替训练生成器和判别器,计算损失并更新参数;同时在特定的间隔打印训练过程中的损失信息,并保存生成的样本图像。训练完成后,模型的权重被保存以便后续使用,
最后,代码还展示了如何使用训练好的生成器根据给定的数字标签生成并保存新的手写数字图像。
二、部署方法
(一)基础环境要求:
(二)安装步骤:
(三)运行项目:
三、原理讲解
生成对抗网络(GAN)是一种强大的机器学习框架,主要用于生成与真实数据相似的图像、视频或其他类型的数据。以下是对GAN的简单介绍及其应用的概述。
(一)生成对抗网络简介
生成对抗网络(Generative Adversarial Network,简称GAN)由伊恩·古德费洛(Ian Goodfellow)等人于2014年提出。GAN的基本结构包括两个神经网络:
这两个网络是在一个对抗的过程中进行训练的。生成网络试图生成越来越真实的数据,而判别网络则试图提高其判别能力。随着训练的进行,生成网络生成的样本越来越接近真实数据,判别网络也变得越来越难以分辨。
(二)应用
生成对抗网络的应用广泛,以下是一些主要领域的应用实例:
四、代码分析
(一)导入库
导入所需的库,其中
jittor
是一个深度学习框架,argparse
用于处理命令行参数,numpy
用于数值计算。(二)CUDA支持
检查CUDA是否可用,如可用则启用GPU加速。
(三)解析命令行参数
使用
argparse
定义训练过程中需要的参数,如训练的轮数、批大小、学习率、潜在空间维度等。(四)定义图像形状
定义生成图像的形状,例如通道数和图像尺寸。
(五)生成器
Generator
类继承自nn.Module
,用于生成图像。label_emb
是一个嵌入层,用于将类别标签嵌入为向量。block
方法定义了一个由全连接层、批归一化和Leaky ReLU激活函数组成的构建模块。execute
方法生成图像,其中输入为噪声和标签。(六)判别器
Discriminator
类也继承自nn.Module
,用于判别生成的图像是否真实。label_embedding
将类别标签嵌入。execute
方法将输入图像和标签结合,计算有效性得分。(七)损失函数
使用均方误差损失函数来衡量生成器和判别器的性能。
(八)数据加载
加载MNIST数据集,并对图像进行预处理,如调整大小、灰度化和标准化。
(九)优化器
为生成器和判别器定义Adam优化器。
(十)保存图像的函数
定义一个函数,用于保存生成的图像。
(十一)生成图像的函数
从生成器中随机采样图像并保存。
(十二)训练模型
这是模型训练的主要部分:
(十三)模型保存与加载
在训练结束后,将生成的模型保存到文件中,并在需要时加载。
(十四)自定义输入生成图像
根据自定义的数字生成图像。
(十五)保存最终生成的图像
将生成的图像处理并保存到文件中。
五、运行结果
可以看到我们正确的生成了手写数字。