GAN Training with Jittor

This repository contains a Python script for training a Generative Adversarial Network (GAN) using the Jittor deep learning framework. The script is designed to train a GAN on the MNIST dataset, though it can be adapted to other datasets and tasks. The goal of this GAN is to generate digit images that mimic the style of the MNIST dataset.
Features
Customizable Training Parameters: Adjust various training parameters such as number of epochs, batch size, learning rate, etc.
Image Generation: Generate images at specified intervals to monitor the progress of the model training.
Model Saving and Loading: Save the state of the generator and discriminator during training and load them for later use.
CUDA Support: Leverages GPU acceleration if available to speed up the training process.
Prerequisites
Before you can run this script, ensure that you have the following installed:
Python 3.6 or higher
Jittor
NumPy
PIL (Pillow)
Jittor can utilize NVIDIA GPUs if you have CUDA installed. Make sure your CUDA toolkit is compatible with the installed version of Jittor.
Installation
Clone this repository:
git clone https://github.com/your-username/your-repo-name.git
cd your-repo-name
Install dependencies (assuming you have Python and pip installed):
pip install jittor numpy pillow
If the pip install fails, follow the platform-specific installation instructions on the official Jittor website.
Usage
Run the script using the following command:
python3 CGAN.py
Command Line Arguments
--n_epochs: Number of epochs for training.
--batch_size: Size of the batches.
--lr: Adam optimizer learning rate.
--b1: Decay of first-order momentum of gradient (Adam optimizer).
--b2: Decay of second-order momentum of gradient (Adam optimizer).
--n_cpu: Number of CPU threads to use during batch generation.
--latent_dim: Dimensionality of the latent space.
--n_classes: Number of classes for dataset.
--img_size: Size of each image dimension.
--channels: Number of image channels.
--sample_interval: Interval between image sampling.
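Flags like these are typically defined with Python's argparse module; a sketch of how the script might declare them (the defaults shown are illustrative, not the script's actual values):

```python
import argparse

# Illustrative defaults; the actual script may use different values.
parser = argparse.ArgumentParser(description="Train a GAN on MNIST with Jittor")
parser.add_argument("--n_epochs", type=int, default=100, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="Adam first-moment decay")
parser.add_argument("--b2", type=float, default=0.999, help="Adam second-moment decay")
parser.add_argument("--n_cpu", type=int, default=8, help="CPU threads for batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="latent space dimensionality")
parser.add_argument("--n_classes", type=int, default=10, help="number of dataset classes")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=1000, help="batches between image samples")

opt = parser.parse_args(["--n_epochs", "50", "--batch_size", "128"])
print(opt.n_epochs, opt.batch_size)  # unspecified flags fall back to their defaults
```

For a real run, parse_args() is called with no arguments so the values come from the command line, e.g. python3 CGAN.py --n_epochs 50 --batch_size 128.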
Image Generation
Generated images will be saved periodically during training in the working directory. These images help visualize the progress of the GAN training.
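Periodic sampling along these lines can be implemented by tiling a batch of generated images into a single PNG with NumPy and Pillow. A sketch, assuming the generator outputs float arrays in [-1, 1] (the usual Tanh output range); the function and file names are illustrative:

```python
import numpy as np
from PIL import Image

def save_image_grid(batch, path, nrow=5):
    """Tile a batch of (N, H, W) images in [-1, 1] into an nrow-wide grayscale grid."""
    imgs = ((batch + 1.0) * 127.5).clip(0, 255).astype(np.uint8)  # rescale to [0, 255]
    n, h, w = imgs.shape
    ncol = (n + nrow - 1) // nrow                    # rows needed for n images
    grid = np.zeros((ncol * h, nrow * w), dtype=np.uint8)
    for i, img in enumerate(imgs):
        r, c = divmod(i, nrow)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img
    Image.fromarray(grid, mode="L").save(path)

# Stand-in for generator output, called e.g. every sample_interval batches:
fake = np.random.uniform(-1.0, 1.0, size=(25, 28, 28))
save_image_grid(fake, "sample_000.png")
```

Writing one grid per sampling interval makes it easy to flip through the files and watch sample quality evolve over training.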
Customization
To customize the script for different datasets or to change the architecture of the GAN, modify the Generator and Discriminator classes as needed. This may involve changing the layer configurations or the input and output dimensions according to the characteristics of your dataset.
License
This project is licensed under the MIT License - see the LICENSE file for details.