ADD file via upload
本项目实现了基于 PyTorch 的条件生成对抗网络 (Conditional GAN),用于生成特定类别的手写数字图像。通过给定的随机噪声和类别标签,模型可以学习生成与该类别相对应的手写数字图片。我的目标是探索深度学习在图像生成任务上的应用,并提高对 GAN 架构的理解。
要运行本项目,请确保你的环境中已安装jittor框架,我提交了两个版本,一个是myCGAN.py(torch框架下的实现)一个是CGAN.py(jittor框架下的实现)
我主要是把示例代码中的两个todo给完成了:
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes) self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512), nn.LeakyReLU(0.2), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2), # TODO: 添加最后一个线性层,最终输出为一个实数 ) def execute(self, img, labels): d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1) # TODO: 将d_in输入到模型中并返回计算结果
变成了:
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes) self.model = nn.Sequential( nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512), nn.LeakyReLU(0.2), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2), nn.Linear(512, 1), ) def forward(self, img, labels): d_in = torch.cat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1) validity = self.model(d_in) return validity
经过多轮迭代,我的模型成功地学会了生成指定类别的手写数字图像。下面展示了生成结果,我的数字是”2212123“:
A Jittor implementation of Conditional GAN (CGAN).
Conditional GAN - MNIST 生成对抗网络
简介
本项目实现了基于 PyTorch 的条件生成对抗网络 (Conditional GAN),用于生成特定类别的手写数字图像。通过给定的随机噪声和类别标签,模型可以学习生成与该类别相对应的手写数字图片。我的目标是探索深度学习在图像生成任务上的应用,并提高对 GAN 架构的理解。
安装
要运行本项目,请确保你的环境中已安装jittor框架,我提交了两个版本,一个是myCGAN.py(torch框架下的实现)一个是CGAN.py(jittor框架下的实现)
我的修改
我主要是把示例代码中的两个todo给完成了:
变成了:
结果
经过多轮迭代,我的模型成功地学会了生成指定类别的手写数字图像。下面展示了生成结果,我的数字是”2212123“: