目录
目录README.md

下面是一个 README 文档示例,适用于你的项目:


GAN 图像生成项目

该项目使用 Jittor 框架实现了一个条件生成对抗网络(Conditional GAN, cGAN)。模型能够基于随机噪声和标签生成图像,并通过判别器来进行训练,使生成的图像尽可能接近真实的图像。该实现基于 MNIST 数据集,支持标签控制生成的数字图像。

项目结构

.
├── generator_last.pkl         # 训练后的生成器模型
├── discriminator_last.pkl     # 训练后的判别器模型
├── result.png                 # 生成的图像
├── main.py                    # 主训练脚本
└── README.md                  # 项目的说明文件

依赖

该项目依赖以下 Python 包:

  • Jittor: 用于高效的深度学习框架。
  • Numpy: 用于数组操作。
  • Pillow: 用于图像保存和处理。
  • argparse: 用于命令行参数解析。

可以通过以下命令安装依赖:

pip install jittor numpy pillow

数据集

本项目使用的是 MNIST 数据集,该数据集包含手写数字的 28x28 像素图像。Jittor 提供了 MNIST 数据集的加载器,支持数据的自动下载和预处理。

参数说明

训练参数

以下是训练脚本支持的命令行参数:

  • --n_epochs (default: 100): 训练的总轮数。
  • --batch_size (default: 64): 每个批次的图像数量。
  • --lr (default: 0.0002): Adam 优化器的学习率。
  • --b1 (default: 0.5): Adam 优化器的一阶矩动量衰减。
  • --b2 (default: 0.999): Adam 优化器的二阶矩动量衰减。
  • --n_cpu (default: 8): 用于批量生成的 CPU 线程数。
  • --latent_dim (default: 100): 隐变量的维度。
  • --n_classes (default: 10): 数据集的类别数量。
  • --img_size (default: 32): 图像的尺寸(宽度/高度)。
  • --channels (default: 1): 图像的通道数(灰度图像为 1)。
  • --sample_interval (default: 1000): 每多少步生成并保存一次图像。

生成的图像

每隔一定的训练步数,生成器会生成并保存当前训练阶段的图像,这些图像会保存为 .png 格式。

训练过程

在训练过程中,生成器和判别器交替训练。生成器尝试生成尽可能真实的图像,欺骗判别器将其判断为真实图像;判别器则尝试区分生成图像与真实图像的差异。

如何运行

  1. 下载并安装依赖项:

    pip install jittor numpy pillow
  2. 下载 MNIST 数据集:

    训练脚本会自动下载 MNIST 数据集并进行预处理。无需手动下载。

  3. 运行训练脚本:

    python main.py --n_epochs 100 --batch_size 64 --lr 0.0002 --b1 0.5 --b2 0.999 --n_cpu 8 --latent_dim 100 --n_classes 10 --img_size 32 --channels 1 --sample_interval 1000

    训练过程中,模型会自动保存生成器和判别器的状态。每 1000 步,会生成一个新图像并保存。

  4. 使用生成的模型生成图像:

    在训练完成后,你可以使用保存的生成器模型生成图像。修改 main.py 文件中的以下代码:

    generator.eval()
    discriminator.eval()
    generator.load('generator_last.pkl')
    discriminator.load('discriminator_last.pkl')
    
    number = "19197163644"  # 替换为你自己的电话号码(或其他标签)
    n_row = len(number)
    z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
    labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
    gen_imgs = generator(z, labels)
    
    img_array = gen_imgs.data.transpose((1, 2, 0, 3))[0].reshape((gen_imgs.shape[2], -1))
    min_ = img_array.min()
    max_ = img_array.max()
    img_array = (img_array - min_) / (max_ - min_) * 255
    Image.fromarray(np.uint8(img_array)).save("result.png")

    替换 number 为你自己的手机号或其他标签,运行代码即可生成与标签对应的图像。

生成图像示例

  • result.png: 生成的图像文件,基于输入的标签和噪声生成。

可能遇到的问题

  1. CUDA 错误:如果使用 GPU 训练,确保 Jittor 正确安装并且 CUDA 配置正常。
  2. 内存问题:训练过程中可能会占用较多内存,建议使用合适的 batch size。

仓库信息

链接如下:https://gitlink.org.cn/zsy123/a_new_try.git


这个 README 文件提供了项目的基本概述、运行步骤以及参数设置等信息,可以根据需要进行修改和扩展。如果有更具体的使用场景或功能,可以根据项目实际需求进一步补充。

关于

A Jittor implementation of Conditional GAN (CGAN)

36.0 KB
邀请码
    Gitlink(确实开源)
  • 加入我们
  • 官网邮箱:gitlink@ccf.org.cn
  • QQ群
  • QQ群
  • 公众号
  • 公众号

©Copyright 2023 CCF 开源发展委员会
Powered by Trustie& IntelliDE 京ICP备13000930号