目录
目录README.md

Conditional GAN - MNIST 生成对抗网络

简介

本项目实现了基于 PyTorch 的条件生成对抗网络 (Conditional GAN),用于生成特定类别的手写数字图像。通过给定的随机噪声和类别标签,模型可以学习生成与该类别相对应的手写数字图片。我的目标是探索深度学习在图像生成任务上的应用,并提高对 GAN 架构的理解。

安装

要运行本项目,请确保你的环境中已安装jittor框架,我提交了两个版本,一个是myCGAN.py(torch框架下的实现)一个是CGAN.py(jittor框架下的实现)

我的修改

我主要是把示例代码中的两个todo给完成了:



class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
        self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
                                   nn.LeakyReLU(0.2),
                                   nn.Linear(512, 512),
                                   nn.Dropout(0.4),
                                   nn.LeakyReLU(0.2),
                                   nn.Linear(512, 512),
                                   nn.Dropout(0.4),
                                   nn.LeakyReLU(0.2),
                                   # TODO: 添加最后一个线性层,最终输出为一个实数
                                   )

    def execute(self, img, labels):
        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
        # TODO: 将d_in输入到模型中并返回计算结果


变成了:


  
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
        self.model = nn.Sequential(
            nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        d_in = torch.cat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
        validity = self.model(d_in)
        return validity
        

结果

经过多轮迭代,我的模型成功地学会了生成指定类别的手写数字图像。下面展示了生成结果,我的数字是”2212123“:

生成结果
关于

A Jittor implementation of Conditional GAN (CGAN).

52.0 KB
邀请码