目录
目录readme.md

一、项目介绍

​ 该项目实现了一个条件生成对抗网络(CGAN),使用Jittor深度学习框架对MNIST数据集进行训练,以生成与特定数字标签相对应的手写数字图像。代码首先导入所需的库并设置CUDA支持,然后通过命令行参数解析配置训练的各种超参数,如训练轮数、批大小、学习率等。

​ 定义了两个主要的神经网络模型:生成器(Generator)和判别器(Discriminator)。生成器接受随机噪声和类别标签作为输入,生成对应的图像;而判别器则判断输入的图像是真实的还是生成的。通过均方误差损失函数来评估生成器和判别器的表现,使用Adam优化器分别优化这两个网络。数据加载部分通过对MNIST数据集进行预处理来生成训练数据,然后在训练循环中,交替训练生成器和判别器,计算损失并更新参数;同时在特定的间隔打印训练过程中的损失信息,并保存生成的样本图像。训练完成后,模型的权重被保存以便后续使用,

​ 最后,代码还展示了如何使用训练好的生成器根据给定的数字标签生成并保存新的手写数字图像。

Understanding GAN Machine Learning: Basics & Applications

二、部署方法

(一)基础环境要求:

  • Python 3.7或以上版本
  • CUDA支持(如果要使用GPU)
  • 足够的内存和存储空间

(二)安装步骤:

# 检查python版本大于等于3.8
python --version
conda install pywin32
python -m pip install jittor
python -m jittor.test.test_core
python -m jittor.test.test_example

(三)运行项目:

# 克隆项目(如果是从git仓库)
git clone [项目地址]
cd [项目目录]

# 运行训练
python CGAN.py

三、原理讲解

​ 生成对抗网络(GAN)是一种强大的机器学习框架,主要用于生成与真实数据相似的图像、视频或其他类型的数据。以下是对GAN的简单介绍及其应用的概述。

(一)生成对抗网络简介

生成对抗网络(Generative Adversarial Network,简称GAN)由伊恩·古德费洛(Ian Goodfellow)等人于2014年提出。GAN的基本结构包括两个神经网络:

  1. 生成网络(Generator):这个网络的任务是从一个随机的潜在空间中生成数据。生成网络通过学习生成与真实样本相似的输出,尽量“欺骗”判别网络。
  2. 判别网络(Discriminator):这个网络的任务是判断输入的数据是真实的还是生成的。它通过对比真实样本与生成样本,学习分辨二者的区别。

这两个网络是在一个对抗的过程中进行训练的。生成网络试图生成越来越真实的数据,而判别网络则试图提高其判别能力。随着训练的进行,生成网络生成的样本越来越接近真实数据,判别网络也变得越来越难以分辨。

(二)应用

生成对抗网络的应用广泛,以下是一些主要领域的应用实例:

  1. 时尚和广告
    • GAN可以生成虚构的模特图像,节省聘请模特和制作广告的费用。这种技术可以用于多样化广告中的模特,吸引更多目标消费者。
  2. 科学
    • 在科学研究中,GAN被用于改善天文图像的质量,模拟重力透镜现象以研究暗物质。例如,GAN能够预测暗物质在特定方向上的分布。
  3. 电子游戏
    • 在电子游戏领域,GAN可以用来对旧游戏的图像进行高分辨率重建。通过训练,GAN能够生成更清晰的2D纹理,同时保留原始图像的细节和颜色,使得游戏画面更加精美。

四、代码分析

(一)导入库

python
import jittor as jt
from jittor import init
import argparse
import os
import numpy as np
import math
from jittor import nn

导入所需的库,其中jittor是一个深度学习框架,argparse用于处理命令行参数,numpy用于数值计算。

(二)CUDA支持

python
if jt.has_cuda:
    jt.flags.use_cuda = 1

检查CUDA是否可用,如可用则启用GPU加速。

(三)解析命令行参数

python
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
# 其他参数...
opt = parser.parse_args()
print(opt)

使用argparse定义训练过程中需要的参数,如训练的轮数、批大小、学习率、潜在空间维度等。

(四)定义图像形状

python
img_shape = (opt.channels, opt.img_size, opt.img_size)

定义生成图像的形状,例如通道数和图像尺寸。

(五)生成器

python
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block((opt.latent_dim + opt.n_classes), 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def execute(self, noise, labels):
        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
        img = self.model(gen_input)
        img = img.view((img.shape[0], *img_shape))
        return img
  • Generator类继承自nn.Module,用于生成图像。
  • label_emb是一个嵌入层,用于将类别标签嵌入为向量。
  • block方法定义了一个由全连接层、批归一化和Leaky ReLU激活函数组成的构建模块。
  • execute方法生成图像,其中输入为噪声和标签。

(六)判别器

python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
        self.model = nn.Sequential(
            nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1)
        )

    def execute(self, img, labels):
        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
        validity = self.model(d_in)
        return validity
  • Discriminator类也继承自nn.Module,用于判别生成的图像是否真实。
  • label_embedding将类别标签嵌入。
  • execute方法将输入图像和标签结合,计算有效性得分。

(七)损失函数

python
adversarial_loss = nn.MSELoss()

使用均方误差损失函数来衡量生成器和判别器的性能。

(八)数据加载

python
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
    transform.Resize(opt.img_size),
    transform.Gray(),
    transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

加载MNIST数据集,并对图像进行预处理,如调整大小、灰度化和标准化。

(九)优化器

python
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

为生成器和判别器定义Adam优化器。

(十)保存图像的函数

python
def save_image(img, path, nrow=10, padding=5):
    # 处理图像并保存

定义一个函数,用于保存生成的图像。

(十一)生成图像的函数

python
def sample_image(n_row, batches_done):
    # 随机采样生成的图像并保存

从生成器中随机采样图像并保存。

(十二)训练模型

python
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.shape[0]
        # 定义真实和假图像的标签
        # 训练生成器
        # 训练判别器
        # 打印损失

这是模型训练的主要部分:

  • 对每个epoch和每个批次,分别训练生成器和判别器。
  • 计算生成器的损失,更新生成器的参数。
  • 计算判别器的损失,更新判别器的参数。
  • 每隔一定的批次打印损失信息并生成样本图像。

(十三)模型保存与加载

python
generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')

在训练结束后,将生成的模型保存到文件中,并在需要时加载。

(十四)自定义输入生成图像

python
number = "2212784"
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)

根据自定义的数字生成图像。

(十五)保存最终生成的图像

python
img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
min_=img_array.min()
max_=img_array.max()
img_array=(img_array-min_)/(max_-min_)*255
Image.fromarray(np.uint8(img_array)).save("result.png")

将生成的图像处理并保存到文件中。

五、运行结果

可以看到我们正确的生成了手写数字。

img

关于

Jittor Gan

44.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号