
Conditional GAN Using Jittor

This project implements a Conditional GAN (Generative Adversarial Network) using the Jittor deep learning framework. The GAN is trained on the MNIST dataset to generate images conditioned on class labels.

Requirements

  • Python 3.x
  • Jittor
  • Jittor's MNIST dataset module (jittor.dataset.mnist)
  • NumPy
  • PIL (Pillow)
  • argparse (included in the Python standard library)

Files and Directories

  • CGAN.py: The main script, containing the GAN implementation and the training loop.
  • generator_last.pkl: Trained generator model.
  • discriminator_last.pkl: Trained discriminator model.
  • result.png: Image generated from a specified digit sequence.

Usage

  1. Install Dependencies: Make sure the required libraries are installed, for example with the following command:

    pip install jittor numpy pillow
  2. Train the Model

    Run the main.py script to train the model. Training parameters can be adjusted with command-line arguments:

    python main.py --n_epochs 100 --batch_size 64 --lr 0.0002 --latent_dim 100 --n_classes 10 --img_size 32 --channels 1 --sample_interval 1000
  3. Generate Images

    After training, use the trained generator to generate images. The script generates an image from a specified digit sequence.

    python main.py
  4. Output

    The generated image will be saved as result.png.
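The training command in step 2 passes eight flags. A minimal sketch of how such flags could be declared with the standard argparse module is shown below; the function name build_parser, the help texts, and the default values (copied from the example command) are assumptions, and the real script may define them differently. The meaning of sample_interval (number of batches between saved samples) is likewise assumed.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags of the example training command.

    Defaults are assumptions taken from that command, not from the script itself.
    """
    parser = argparse.ArgumentParser(description="Train a Conditional GAN on MNIST")
    parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="mini-batch size")
    parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
    parser.add_argument("--latent_dim", type=int, default=100, help="size of the latent vector z")
    parser.add_argument("--n_classes", type=int, default=10, help="number of class labels (digits 0-9)")
    parser.add_argument("--img_size", type=int, default=32, help="height and width of each image")
    parser.add_argument("--channels", type=int, default=1, help="image channels (1 = grayscale)")
    parser.add_argument("--sample_interval", type=int, default=1000,
                        help="assumed: batches between saved image samples")
    return parser
```

With this sketch, build_parser().parse_args(["--n_epochs", "5"]) overrides only n_epochs and keeps every other default, which is why the full command in step 2 can be shortened to just the flags you want to change.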

Model Architecture

Generator

The generator network takes latent space vectors and class labels as input and generates an image conditioned on the class label.
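One common way to condition a generator on a class label is to concatenate a label encoding with the latent vector before the first layer. The NumPy sketch below assumes a one-hot encoding; this project may instead use a learned label embedding, and the helper names one_hot and generator_input are hypothetical.

```python
import numpy as np


def one_hot(labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Encode integer labels of shape (N,) as float32 one-hot rows of shape (N, n_classes)."""
    out = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
    out[np.arange(labels.shape[0]), labels] = 1.0
    return out


def generator_input(z: np.ndarray, labels: np.ndarray, n_classes: int = 10) -> np.ndarray:
    """Concatenate latent vectors z (N, latent_dim) with one-hot labels.

    The result has shape (N, latent_dim + n_classes): the vector a fully
    connected first layer of a conditional generator would receive.
    """
    return np.concatenate([z.astype(np.float32), one_hot(labels, n_classes)], axis=1)
```

With latent_dim = 100 and n_classes = 10, each generator input row therefore has 110 entries, and changing only the label part changes which digit the generator is asked to draw.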

Discriminator

The discriminator network takes an image and class label as input and outputs a single scalar value representing the probability that the image is real.
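Since the discriminator outputs a probability, a natural training signal is binary cross-entropy. The sketch below assumes the discriminator sees a flattened image concatenated with a one-hot label and ends in a sigmoid; the helper names are hypothetical, and the project's actual loss function may differ.

```python
import numpy as np


def discriminator_input(images: np.ndarray, labels: np.ndarray, n_classes: int = 10) -> np.ndarray:
    """Flatten images (N, C, H, W) and append one-hot labels -> (N, C*H*W + n_classes)."""
    flat = images.reshape(images.shape[0], -1).astype(np.float32)
    onehot = np.eye(n_classes, dtype=np.float32)[labels]
    return np.concatenate([flat, onehot], axis=1)


def bce(prob: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    p = np.clip(prob, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(target * np.log(p) + (1.0 - target) * np.log(1.0 - p)))
```

Under these assumptions the discriminator minimizes bce(D(real), 1) + bce(D(fake), 0), while the generator minimizes bce(D(fake), 1), where D(...) stands for the discriminator's predicted probabilities on real or generated images paired with their labels.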

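About

This project uses the Jittor deep learning framework to build and train a Conditional GAN (Generative Adversarial Network) model. The training set consists of MNIST handwritten grayscale digit images, where each image's label is the value of its digit. After training, the generator can produce the corresponding grayscale images for a given digit sequence.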
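Generating an image for a digit sequence involves two steps around the generator: turning the digits into labels plus latent vectors, and joining the generated digits into a single image. The NumPy sketch below is a hypothetical illustration of both steps; it assumes one standard-normal latent vector per digit, a generator output of shape (N, 1, H, W) with values in [-1, 1] (a tanh output layer), and horizontal tiling. None of these details are confirmed by the project.

```python
import numpy as np


def prepare_generator_inputs(digits: str, latent_dim: int = 100, seed: int = 0):
    """Hypothetical helper: turn a digit string such as "2023" into generator inputs.

    Returns
      z      -- float32 array (len(digits), latent_dim), sampled from N(0, 1)
      labels -- int64 array (len(digits),), one class label per digit
    """
    if not digits or any(ch not in "0123456789" for ch in digits):
        raise ValueError("digits must be a non-empty string of 0-9 characters")
    labels = np.array([int(ch) for ch in digits], dtype=np.int64)
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((len(digits), latent_dim)).astype(np.float32)
    return z, labels


def tile_horizontally(images: np.ndarray) -> np.ndarray:
    """Hypothetical helper: place N single-channel images side by side.

    images -- float array (N, 1, H, W) with values in [-1, 1].
    Returns a uint8 array of shape (H, N * W).
    """
    if images.ndim != 4 or images.shape[1] != 1:
        raise ValueError("expected an array of shape (N, 1, H, W)")
    # Map [-1, 1] to [0, 255] and drop the channel axis -> (N, H, W).
    pixels = np.round((np.clip(images, -1.0, 1.0) + 1.0) * 127.5).astype(np.uint8)[:, 0]
    # Join the N images along the width axis -> (H, N * W).
    return np.concatenate(list(pixels), axis=1)
```

In use, z and labels from prepare_generator_inputs("2023") would be converted to Jittor arrays (for example with jt.array) and passed to the trained generator; the resulting images, converted back to NumPy, can then be saved with Pillow via Image.fromarray(tile_horizontally(images)).save("result.png").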
