目录
Zhaoyuanming3个月前3次提交
目录README.md

条件生成对抗网络(CGAN)项目报告

概述

本次作业使用 Jittor 框架实现的条件生成对抗网络(CGAN)。该模型使用 MNIST 数据集进行训练,生成基于类别标签的图像。与GAN不同的是CGAN的判别器会同时收到生成的图片和真实图片的训练。

开发环境

安装依赖

在开始项目之前,安装了以下依赖项:

  • Jittor:用于实现神经网络训练和推理。
  • NumPy:用于数值计算。
  • PIL(Python Imaging Library):用于处理和保存图像。
  • MNIST 数据集:用于训练模型。

项目结构

  • Generator:生成器网络,用于根据输入的随机噪声和类别标签生成图像。
  • Discriminator:判别器网络,用于判断图像是真实的还是生成的。
  • adversarial_loss:对抗损失函数,使用均方误差(MSE)来衡量生成器和判别器的性能。
  • save_image:用于保存生成的图像。
  • sample_image:用于生成和保存一组样本图像。

模型实现

生成器(Generator)

生成器的目的是根据输入的噪声和标签生成图像。它通过一个多层全连接网络来处理输入。模型结构包括:

  1. 标签嵌入层:将类别标签转化为向量表示。
  2. 一系列全连接层:使用 LeakyReLU 激活函数和批归一化层。
  3. 输出层:将最终的向量转化为图像的像素值,使用 Tanh 激活函数。
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block((opt.latent_dim + opt.n_classes), 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def execute(self, noise, labels):
        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
        img = self.model(gen_input)
        img = img.view((img.shape[0], *img_shape))
        return img

判别器(Discriminator)

判别器的任务是判断输入图像是否为真实图像。它将图像和对应的标签信息拼接,并通过一系列全连接层进行处理,最终输出一个标量表示图像的真实性。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
        self.model = nn.Sequential(
            nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
        )

    def execute(self, img, labels):
        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
        x = self.model(d_in)
        return x

损失函数与优化器

在训练过程中,生成器和判别器的损失函数是对抗损失。生成器试图欺骗判别器使其认为生成的图像是真实的,而判别器则尝试正确区分真实图像和生成图像。

adversarial_loss = nn.MSELoss()

optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

数据集与预处理

使用 MNIST 数据集作为训练数据。每张图像是 28x28 的灰度图像,通过图像缩放和标准化处理以适应输入尺寸。

from jittor.dataset.mnist import MNIST
import jittor.transform as transform

transform = transform.Compose([
    transform.Resize(opt.img_size),
    transform.Gray(),
    transform.ImageNormalize(mean=[0.5], std=[0.5]),
])

dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

训练过程

在训练过程中,生成器和判别器交替训练。生成器通过生成逼真的图像来最小化其损失,而判别器则试图正确分类真实图像和生成图像。每隔一定的批次,保存生成的图像以便观察训练效果。

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.shape[0]
        valid = jt.ones([batch_size, 1]).float32().stop_grad()
        fake = jt.zeros([batch_size, 1]).float32().stop_grad()

        real_imgs = jt.array(imgs)
        labels = jt.array(labels)

        # 训练生成器
        z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
        gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
        gen_imgs = generator(z, gen_labels)
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.sync()
        optimizer_G.step(g_loss)

        # 训练判别器
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.sync()
        optimizer_D.step(d_loss)

        if i % 50 == 0:
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data))

        if epoch % 10 == 0:
            generator.save("generator_last.pkl")
            discriminator.save("discriminator_last.pkl")

生成图像

训练完成后,使用训练好的生成器,根据数字序列2211757,生成对应的图像。

number = "2211757"
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z, labels)

img_array = gen_imgs.data.transpose((1, 2, 0, 3))[0].reshape((gen_imgs.shape[2], -1))
min_ = img_array.min()
max_ = img_array.max()
img_array = (img_array - min_) / (max_ - min_) * 255
Image.fromarray(np.uint8(img_array)).save("result.png")

实验结果

可以看到生成了清晰可见的数字序列,CGAN训练完成 result

总结

这次作业实现了一个基于 Jittor 框架的条件生成对抗网络(CGAN)。使用 MNIST 数据集,训练了一个能够根据输入的数字标签生成对应数字图像的模型。生成的图像质量逐步提升,最终实现了测试序列的清晰生成。

关于

这是赵元鸣2211757的计算机图形学作业项目

35.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号