华为开源自研AI框架昇思MindSpore应用实践:DCGAN生成漫画头像

2023-05-16

目录

  • 一、原理说明
    • 1.GAN基础原理
    • 2.DCGAN原理
  • 二、环境准备
    • 1.进入ModelArts官网
    • 2.使用CodeLab体验Notebook实例
  • 三、数据准备与处理
    • 1.数据处理
  • 四、创建网络
    • 1.生成器
    • 2.判别器
    • 3.损失和优化器
    • 4.优化器
  • 五、训练模型
  • 六、结果展示

本教程是通过示例代码说明DCGAN网络如何设置网络、优化器、如何计算损失函数以及如何初始化模型权重。在本教程中,使用的动漫头像数据集共有70,171张动漫头像图片,图片大小均为96*96

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

一、原理说明

1.GAN基础原理

生成式对抗网络(Generative Adversarial Networks,GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,GAN由两个不同的模型组成——生成器和判别器:

生成器的任务是生成看起来像训练图像的“假”图像;

判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

2.DCGAN原理

DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量z,输出是3x64x64的RGB图像。

本教程将使用动漫头像数据集来训练一个生成式对抗网络,接着使用该网络生成动漫头像图片。

二、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述
等待环境搭建完成
在这里插入图片描述

2.使用CodeLab体验Notebook实例

下载NoteBook样例代码.ipynb为样例代码,faces文件夹中有动漫头像数据集共有70,171张动漫头像图片,图片大小均为96*96

在这里插入图片描述

在这里插入图片描述

选择ModelArts Upload Files上传.ipynb文件

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

选择Kernel环境

在这里插入图片描述
进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述

回到Notebook中,在第一块代码前加入三块命令
在这里插入图片描述

pip install --upgrade pip
conda install mindspore-gpu=1.9.0 cudatoolkit=10.1 -c mindspore -c conda-forge
pip install mindvision

依次运行即可

在这里插入图片描述

在这里插入图片描述

三、数据准备与处理

首先我们将数据集下载到指定目录下并解压。示例代码如下:


from mindvision import dataset

dl_path = "./datasets"
dl_url = "https://download.mindspore.cn/dataset/Faces/faces.zip"

dl = dataset.DownLoad()  # 下载数据集
dl.download_and_extract_archive(url=dl_url, download_path=dl_path)

在这里插入图片描述

注意:如果这里显示

ImportError: libcudart.so.10.1: cannot open shared object file: No such file or directory
说明你选择的MindSpore安装版本有问题,请从头再来,并切换至GPU版本的MindSpore,同时在选择执行模式为图模式,指定训练使用的平台为"GPU"

得到动漫头像数据集
在这里插入图片描述

1.数据处理

首先为执行过程定义一些输入:


import mindspore as ms

# 选择执行模式为图模式;指定训练使用的平台为"GPU",如需使用昇腾硬件可将其替换为"Ascend"
ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")

data_root = "./datasets"  # 数据集根目录
batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 10           # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

在这里插入图片描述

定义create_dataset_imagenet函数对数据进行处理和增强操作。


import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

from mindspore import nn, ops

def create_dataset_imagenet(dataset_path):
    """数据加载"""
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=4,
                                     shuffle=True,
                                     decode=True)

    # 数据增强操作
    transform_img = [
        vision.Resize(image_size),
        vision.CenterCrop(image_size),
        vision.HWC2CHW(),
        lambda x: ((x / 255).astype("float32"), np.random.normal(size=(nz, 1, 1)).astype("float32"))]

    # 数据映射操作
    data_set = data_set.map(input_columns="image",
                            num_parallel_workers=4,
                            operations=transform_img,
                            output_columns=["image", "latent_code"],
                            column_order=["image", "latent_code"])

    # 批量操作
    data_set = data_set.batch(batch_size)
    return data_set

# 获取处理后的数据集
data = create_dataset_imagenet(data_root)

# 获取数据集大小
size = data.get_dataset_size()

通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。


import matplotlib.pyplot as plt
%matplotlib inline

data_iter = next(data.create_dict_iterator(output_numpy=True))

# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['image'][:30], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow(image.transpose(1, 2, 0))
plt.show()

在这里插入图片描述

在这里插入图片描述

四、创建网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

1.生成器

我们通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。


from mindspore.common import initializer as init

def conv_t(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="pad"):
    """定义转置卷积层"""
    weight_init = init.Normal(mean=0, sigma=0.02)
    return nn.Conv2dTranspose(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding,
                              weight_init=weight_init, has_bias=False, pad_mode=pad_mode)

def bn(num_features):
    """定义BatchNorm2d层"""
    gamma_init = init.Normal(mean=1, sigma=0.02)
    return nn.BatchNorm2d(num_features=num_features, gamma_init=gamma_init)

class Generator(nn.Cell):
    """DCGAN网络生成器"""

    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.SequentialCell()
        self.generator.append(conv_t(nz, ngf * 8, 4, 1, 0))
        self.generator.append(bn(ngf * 8))
        self.generator.append(nn.ReLU())
        self.generator.append(conv_t(ngf * 8, ngf * 4, 4, 2, 1))
        self.generator.append(bn(ngf * 4))
        self.generator.append(nn.ReLU())
        self.generator.append(conv_t(ngf * 4, ngf * 2, 4, 2, 1))
        self.generator.append(bn(ngf * 2))
        self.generator.append(nn.ReLU())
        self.generator.append(conv_t(ngf * 2, ngf, 4, 2, 1))
        self.generator.append(bn(ngf))
        self.generator.append(nn.ReLU())
        self.generator.append(conv_t(ngf, nc, 4, 2, 1))
        self.generator.append(nn.Tanh())

    def construct(self, x):
        return self.generator(x)

# 实例化生成器
netG = Generator()

在这里插入图片描述

2.判别器

判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。通过一系列的Conv2dBatchNorm2dLeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。

DCGAN论文提到,使用卷积而不是通过池化来进行下采样是一个好方法,因为它可以让网络学习自己的池化特征。

判别器的代码实现如下:


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="pad"):
    """定义卷积层"""
    weight_init = init.Normal(mean=0, sigma=0.02)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight_init, has_bias=False, pad_mode=pad_mode)

class Discriminator(nn.Cell):
    """DCGAN网络判别器"""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.SequentialCell()
        self.discriminator.append(conv(nc, ndf, 4, 2, 1))
        self.discriminator.append(nn.LeakyReLU(0.2))
        self.discriminator.append(conv(ndf, ndf * 2, 4, 2, 1))
        self.discriminator.append(bn(ndf * 2))
        self.discriminator.append(nn.LeakyReLU(0.2))
        self.discriminator.append(conv(ndf * 2, ndf * 4, 4, 2, 1))
        self.discriminator.append(bn(ndf * 4))
        self.discriminator.append(nn.LeakyReLU(0.2))
        self.discriminator.append(conv(ndf * 4, ndf * 8, 4, 2, 1))
        self.discriminator.append(bn(ndf * 8))
        self.discriminator.append(nn.LeakyReLU(0.2))
        self.discriminator.append(conv(ndf * 8, 1, 4, 1))
        self.discriminator.append(nn.Sigmoid())

    def construct(self, x):
        return self.discriminator(x)

# 实例化判别器
netD = Discriminator()

在这里插入图片描述

3.损失和优化器

MindSpore将损失函数、优化器等操作都封装到了Cell中,因为GAN结构上的特殊性,其损失是判别器和生成器的多输出形式,这就导致它和一般的分类网络不同。所以我们需要自定义WithLossCell类,将网络和Loss连接起来。

损失函数
当定义了DG后,接下来将使用MindSpore中定义的二进制交叉熵损失函数BCELoss ,为DG加上损失函数和优化器。

连接生成器和损失函数,代码如下:


# 定义损失函数
loss = nn.BCELoss(reduction='mean')

class WithLossCellG(nn.Cell):
    """连接生成器和损失"""

    def __init__(self, netD, netG, loss_fn):
        super(WithLossCellG, self).__init__(auto_prefix=True)
        self.netD = netD
        self.netG = netG
        self.loss_fn = loss_fn

    def construct(self, latent_code):
        """构建生成器损失计算结构"""
        fake_data = self.netG(latent_code)
        out = self.netD(fake_data)
        label_real = ops.OnesLike()(out)
        loss = self.loss_fn(out, label_real)
        return loss

在这里插入图片描述

连接判别器和损失函数,代码如下:


class WithLossCellD(nn.Cell):
    """连接判别器和损失"""

    def __init__(self, netD, netG, loss_fn):
        super(WithLossCellD, self).__init__(auto_prefix=True)
        self.netD = netD
        self.netG = netG
        self.loss_fn = loss_fn

    def construct(self, real_data, latent_code):
        """构建判别器损失计算结构"""
        out_real = self.netD(real_data)
        label_real = ops.OnesLike()(out_real)
        loss_real = self.loss_fn(out_real, label_real)

        fake_data = self.netG(latent_code)
        fake_data = ops.stop_gradient(fake_data)
        out_fake = self.netD(fake_data)
        label_fake = ops.ZerosLike()(out_fake)
        loss_fake = self.loss_fn(out_fake, label_fake)
        return loss_real + loss_fake

在这里插入图片描述

4.优化器

这里设置了两个单独的优化器,一个用于D,另一个用于G。这两个都是lr = 0.0002beta1 = 0.5的Adam优化器。

为了跟踪生成器的学习进度,在训练的过程中,我们定期将一批固定的遵循高斯分布的隐向量fixed_noise输入到G中,可以看到隐向量生成的图像。


# 创建一批隐向量用来观察G
np.random.seed(1)
fixed_noise = ms.Tensor(np.random.randn(64, nz, 1, 1), dtype=ms.float32)

# 为生成器和判别器设置优化器
optimizerD = nn.Adam(netD.trainable_params(), learning_rate=lr, beta1=beta1)
optimizerG = nn.Adam(netG.trainable_params(), learning_rate=lr, beta1=beta1)

五、训练模型

训练判别器的目的是最大程度地提高判别图像真伪的概率。按照Goodfellow的方法,是希望通过提高其随机梯度来更新判别器,所以我们要最大化logD(x)+log(1−D(G(z))的值。

训练生成器如DCGAN论文所述,我们希望通过最小化log(1−D(G(z)))来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计,将fixed_noise批量推送到生成器中,以直观地跟踪G的训练进度。

下面进行训练:


class DCGAN(nn.Cell):
    """定义DCGAN网络"""

    def __init__(self, myTrainOneStepCellForD, myTrainOneStepCellForG):
        super(DCGAN, self).__init__(auto_prefix=True)
        self.myTrainOneStepCellForD = myTrainOneStepCellForD
        self.myTrainOneStepCellForG = myTrainOneStepCellForG

    def construct(self, real_data, latent_code):
        output_D = self.myTrainOneStepCellForD(real_data, latent_code).view(-1)
        netD_loss = output_D.mean()
        output_G = self.myTrainOneStepCellForG(latent_code).view(-1)
        netG_loss = output_G.mean()
        return netD_loss, netG_loss

实例化生成器和判别器的WithLossCellTrainOneStepCell


# 实例化WithLossCell
netD_with_criterion = WithLossCellD(netD, netG, loss)
netG_with_criterion = WithLossCellG(netD, netG, loss)

# 实例化TrainOneStepCell
myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)

在这里插入图片描述

循环训练网络,每经过50次迭代,就收集生成器和判别器的损失,以便于后面绘制训练过程中损失函数的图像。


# 实例化DCGAN网络
dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
dcgan.set_train()

# 创建迭代器
data_loader = data.create_dict_iterator(output_numpy=True, num_epochs=num_epochs)
G_losses = []
D_losses = []
image_list = []

# 开始循环训练
print("Starting Training Loop...")

for epoch in range(num_epochs):
    # 为每轮训练读入数据
    for i, d in enumerate(data_loader):
        real_data = ms.Tensor(d['image'])
        latent_code = ms.Tensor(d["latent_code"])
        netD_loss, netG_loss = dcgan(real_data, latent_code)
        if i % 50 == 0 or i == size - 1:
            # 输出训练记录
            print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (
                epoch + 1, num_epochs, i + 1, size, netD_loss.asnumpy(), netG_loss.asnumpy()))
        D_losses.append(netD_loss.asnumpy())
        G_losses.append(netG_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    img = netG(fixed_noise)
    image_list.append(img.transpose(0, 2, 3, 1).asnumpy())

    # 保存网络模型参数为ckpt文件
    ms.save_checkpoint(netG, "Generator.ckpt")
    ms.save_checkpoint(netD, "Discriminator.ckpt")

这里训练时间比较长,请耐心等待
在这里插入图片描述
在这里插入图片描述

Starting Training Loop...
[ 1/10][  1/523]   Loss_D: 1.3341  Loss_G: 4.4303
[ 1/10][ 51/523]   Loss_D: 0.0001  Loss_G:27.6309
[ 1/10][101/523]   Loss_D: 0.0000  Loss_G:27.6309
[ 1/10][151/523]   Loss_D: 0.0000  Loss_G:27.6309
[ 1/10][201/523]   Loss_D: 0.0000  Loss_G:27.6309
[ 1/10][251/523]   Loss_D: 0.0000  Loss_G:27.6308
[ 1/10][301/523]   Loss_D: 0.0000  Loss_G:27.6309
[ 1/10][351/523]   Loss_D: 0.0000  Loss_G:27.6306
[ 1/10][401/523]   Loss_D: 7.1362  Loss_G:10.8959
[ 1/10][451/523]   Loss_D: 2.7982  Loss_G: 1.6938
[ 1/10][501/523]   Loss_D: 0.5665  Loss_G: 3.3509
[ 1/10][523/523]   Loss_D: 0.8589  Loss_G: 5.8118
[ 2/10][  1/523]   Loss_D: 0.7220  Loss_G: 3.6486
[ 2/10][ 51/523]   Loss_D: 0.9084  Loss_G: 3.4355
[ 2/10][101/523]   Loss_D: 0.7106  Loss_G: 3.3597
[ 2/10][151/523]   Loss_D: 1.2464  Loss_G: 3.8619
[ 2/10][201/523]   Loss_D: 1.4379  Loss_G: 1.4148
[ 2/10][251/523]   Loss_D: 0.5010  Loss_G: 2.6713
[ 2/10][301/523]   Loss_D: 0.8369  Loss_G: 3.2203
[ 2/10][351/523]   Loss_D: 0.8340  Loss_G: 2.7246
[ 2/10][401/523]   Loss_D: 0.7258  Loss_G: 3.1784
[ 2/10][451/523]   Loss_D: 0.6898  Loss_G: 3.4755
[ 2/10][501/523]   Loss_D: 0.9853  Loss_G: 3.4425
[ 2/10][523/523]   Loss_D: 0.8548  Loss_G: 2.3108
[ 3/10][  1/523]   Loss_D: 1.1206  Loss_G: 6.0529
[ 3/10][ 51/523]   Loss_D: 0.6412  Loss_G: 3.2571
[ 3/10][101/523]   Loss_D: 0.7830  Loss_G: 3.2050
[ 3/10][151/523]   Loss_D: 1.0531  Loss_G: 4.0849
[ 3/10][201/523]   Loss_D: 0.4773  Loss_G: 3.4415
[ 3/10][251/523]   Loss_D: 1.0287  Loss_G: 5.1689
[ 3/10][301/523]   Loss_D: 0.7435  Loss_G: 4.2903
[ 3/10][351/523]   Loss_D: 0.7258  Loss_G: 3.4914
[ 3/10][401/523]   Loss_D: 0.9525  Loss_G: 1.8072
[ 3/10][451/523]   Loss_D: 0.7222  Loss_G: 2.1848
[ 3/10][501/523]   Loss_D: 0.4841  Loss_G: 3.8900
[ 3/10][523/523]   Loss_D: 1.3593  Loss_G: 1.6790
[ 4/10][  1/523]   Loss_D: 1.3692  Loss_G: 6.2913
[ 4/10][ 51/523]   Loss_D: 0.8611  Loss_G: 3.9655
[ 4/10][101/523]   Loss_D: 1.3133  Loss_G: 2.4826
[ 4/10][151/523]   Loss_D: 0.6847  Loss_G: 5.1198
[ 4/10][201/523]   Loss_D: 0.6726  Loss_G: 3.9191
[ 4/10][251/523]   Loss_D: 1.3120  Loss_G: 2.4799
[ 4/10][301/523]   Loss_D: 0.5391  Loss_G: 2.5938
[ 4/10][351/523]   Loss_D: 0.5148  Loss_G: 3.3189
[ 4/10][401/523]   Loss_D: 0.5152  Loss_G: 2.1859
[ 4/10][451/523]   Loss_D: 0.4354  Loss_G: 3.7258
[ 4/10][501/523]   Loss_D: 0.8461  Loss_G: 1.6059
[ 4/10][523/523]   Loss_D: 0.8209  Loss_G: 1.4153
[ 5/10][  1/523]   Loss_D: 1.3621  Loss_G: 8.4941
[ 5/10][ 51/523]   Loss_D: 0.6527  Loss_G: 3.3710
[ 5/10][101/523]   Loss_D: 0.4800  Loss_G: 3.0760
[ 5/10][151/523]   Loss_D: 0.5460  Loss_G: 2.8898
[ 5/10][201/523]   Loss_D: 0.7443  Loss_G: 2.4008
[ 5/10][251/523]   Loss_D: 0.9210  Loss_G: 5.4013
[ 5/10][301/523]   Loss_D: 0.5267  Loss_G: 3.1586
[ 5/10][351/523]   Loss_D: 0.5461  Loss_G: 4.4159
[ 5/10][401/523]   Loss_D: 0.5737  Loss_G: 3.2949
[ 5/10][451/523]   Loss_D: 0.9223  Loss_G: 1.4930
[ 5/10][501/523]   Loss_D: 0.9890  Loss_G: 5.1565
[ 5/10][523/523]   Loss_D: 0.8597  Loss_G: 5.6968
[ 6/10][  1/523]   Loss_D: 0.8149  Loss_G: 1.9866
[ 6/10][ 51/523]   Loss_D: 1.3344  Loss_G: 8.2650
[ 6/10][101/523]   Loss_D: 0.5464  Loss_G: 2.9574
[ 6/10][151/523]   Loss_D: 0.5783  Loss_G: 3.9141
[ 6/10][201/523]   Loss_D: 0.5426  Loss_G: 4.5565
[ 6/10][251/523]   Loss_D: 0.5757  Loss_G: 2.4842
[ 6/10][301/523]   Loss_D: 0.7165  Loss_G: 4.2469
[ 6/10][351/523]   Loss_D: 0.5514  Loss_G: 1.9710
[ 6/10][401/523]   Loss_D: 0.5034  Loss_G: 3.3386
[ 6/10][451/523]   Loss_D: 0.5529  Loss_G: 2.5434
[ 6/10][501/523]   Loss_D: 0.5793  Loss_G: 4.5730
[ 6/10][523/523]   Loss_D: 0.4959  Loss_G: 2.3813
[ 7/10][  1/523]   Loss_D: 0.5583  Loss_G: 4.7816
[ 7/10][ 51/523]   Loss_D: 0.4124  Loss_G: 3.1867
[ 7/10][101/523]   Loss_D: 0.5679  Loss_G: 2.6333
[ 7/10][151/523]   Loss_D: 0.4654  Loss_G: 3.8254
[ 7/10][201/523]   Loss_D: 0.6624  Loss_G: 1.2572
[ 7/10][251/523]   Loss_D: 0.6794  Loss_G: 4.7149
[ 7/10][301/523]   Loss_D: 0.5441  Loss_G: 4.5748
[ 7/10][351/523]   Loss_D: 0.5405  Loss_G: 4.4008
[ 7/10][401/523]   Loss_D: 0.8556  Loss_G: 5.3858
[ 7/10][451/523]   Loss_D: 0.8062  Loss_G: 1.3542
[ 7/10][501/523]   Loss_D: 0.7903  Loss_G: 1.2369
[ 7/10][523/523]   Loss_D: 1.0799  Loss_G: 1.1563
[ 8/10][  1/523]   Loss_D: 1.1528  Loss_G: 6.3701
[ 8/10][ 51/523]   Loss_D: 0.5500  Loss_G: 2.5632
[ 8/10][101/523]   Loss_D: 0.8834  Loss_G: 5.6649
[ 8/10][151/523]   Loss_D: 0.4682  Loss_G: 1.9880
[ 8/10][201/523]   Loss_D: 0.8519  Loss_G: 2.0310
[ 8/10][251/523]   Loss_D: 1.5056  Loss_G: 7.7112
[ 8/10][301/523]   Loss_D: 0.4374  Loss_G: 3.1714
[ 8/10][351/523]   Loss_D: 0.3988  Loss_G: 3.2287
[ 8/10][401/523]   Loss_D: 0.6580  Loss_G: 3.8090
[ 8/10][451/523]   Loss_D: 0.5487  Loss_G: 3.6912
[ 8/10][501/523]   Loss_D: 0.5297  Loss_G: 3.9933
[ 8/10][523/523]   Loss_D: 0.7350  Loss_G: 4.5166
[ 9/10][  1/523]   Loss_D: 0.8367  Loss_G: 1.3991
[ 9/10][ 51/523]   Loss_D: 1.0498  Loss_G: 5.8035
[ 9/10][101/523]   Loss_D: 0.5274  Loss_G: 2.9916
[ 9/10][151/523]   Loss_D: 0.9688  Loss_G: 1.4680
[ 9/10][201/523]   Loss_D: 0.4435  Loss_G: 3.0589
[ 9/10][251/523]   Loss_D: 0.4547  Loss_G: 3.3577
[ 9/10][301/523]   Loss_D: 0.5956  Loss_G: 3.5646
[ 9/10][351/523]   Loss_D: 0.4052  Loss_G: 2.3165
[ 9/10][401/523]   Loss_D: 0.4558  Loss_G: 2.6287
[ 9/10][451/523]   Loss_D: 0.8953  Loss_G: 5.1640
[ 9/10][501/523]   Loss_D: 0.5268  Loss_G: 2.0344
[ 9/10][523/523]   Loss_D: 0.4568  Loss_G: 2.3330
[10/10][  1/523]   Loss_D: 0.6627  Loss_G: 4.1249
[10/10][ 51/523]   Loss_D: 0.6725  Loss_G: 3.5604
[10/10][101/523]   Loss_D: 0.7393  Loss_G: 2.1902
[10/10][151/523]   Loss_D: 2.1423  Loss_G: 6.3001
[10/10][201/523]   Loss_D: 0.6502  Loss_G: 1.6308
[10/10][251/523]   Loss_D: 0.6091  Loss_G: 3.5198
[10/10][301/523]   Loss_D: 0.3418  Loss_G: 3.1872
[10/10][351/523]   Loss_D: 0.9850  Loss_G: 1.7839
[10/10][401/523]   Loss_D: 0.6159  Loss_G: 1.9957
[10/10][451/523]   Loss_D: 0.4779  Loss_G: 2.7053
[10/10][501/523]   Loss_D: 0.6780  Loss_G: 2.0838
[10/10][523/523]   Loss_D: 0.5710  Loss_G: 3.4589

六、结果展示

运行下面代码,描绘DG损失与训练迭代的关系图:


plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

在这里插入图片描述

可视化训练过程中通过隐向量fixed_noise生成的图像。


import matplotlib.pyplot as plt
import matplotlib.animation as animation

def showGif(image_list):
    show_list = []
    fig = plt.figure(figsize=(8, 3), dpi=120)
    for epoch in range(len(image_list)):
        images = []
        for i in range(3):
            row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)
            images.append(row)
        img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
        plt.axis("off")
        show_list.append([plt.imshow(img)])

    ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
    ani.save('./dcgan.gif', writer='pillow', fps=1)

showGif(image_list)

在这里插入图片描述
在这里插入图片描述
注意:训练到此已经结束,最终图像如上

这是原始图像
在这里插入图片描述

随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当num_epochs达到50以上时,生成的动漫头像图片与数据集中的较为相似,下面我们通过加载训练周期为50的生成器网络模型参数文件Generator.ckpt来生成图像,代码如下:


from mindvision import dataset

dl_path = "./netG"
dl_url = "https://download.mindspore.cn/vision/classification/Generator.ckpt"

dl = dataset.DownLoad()  # 下载Generator.ckpt文件
dl.download_url(url=dl_url, path=dl_path)

# 从文件中获取模型参数并加载到网络中
param_dict = ms.load_checkpoint("./netG/Generator.ckpt", netG)

img64 = netG(fixed_noise).transpose(0, 2, 3, 1).asnumpy()

fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):
    images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()

在这里插入图片描述

注意:最后这块代码生成的图像是固定的

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

华为开源自研AI框架昇思MindSpore应用实践:DCGAN生成漫画头像 的相关文章

  • 2:Maven-Java Spring

    目录 2 1 Maven介绍2 2 标准目录结构2 3 POM2 4 Maven生命周期2 5 Maven插件 2 1 Maven介绍 Maven是Apache 下的一个纯 Java 开发的开源项目 基于项目对象模型 xff08 缩写 xf
  • 3:SpringBoot-Java Spring

    目录 3 1 SpringBoot介绍3 2 Spring和SpringBoot的区别3 3 系统要求3 4 SpringBootApplication 3 1 SpringBoot介绍 SpringBoot的本质是SpringFramew
  • 4:SpringBoot-Starter-Java Spring

    目录 4 1 SpringBoot Starter介绍4 1 Starter原理4 3 Starter依赖引入4 4 Starter配置 4 1 SpringBoot Starter介绍 Starter是SpringBoot的一种服务 xf
  • 5:SpringBoot-Actuator-Java Spring

    目录 5 1 SpringBoot Actuator介绍5 2 Endpoints 介绍5 3 Actuator原理5 4 Actuator依赖引入 5 1 SpringBoot Actuator介绍 Actuator是Spring Boo
  • Ubuntu 20.04 VNC 安装与设置

    原链接 VNC是一个远程桌面协议 按照本文的说明进行操作可以实现用VNC对Ubuntu 20 04进行远程控制 一般的VNC安装方式在主机没有插显示器的时候是无法使用的 下面的操作可以在主机有显示器和没有显示器时都能够正常工作 首先安装x1
  • 6:RestFul API-Java Spring

    目录 6 1 RestFul API介绍6 2 URL构成6 3 RestFul API原理6 4 RestFul API映射注解6 5 RestFul API操作 6 1 RestFul API介绍 Rest表示性状态转移 xff08 R
  • 7:JSON-Java Spring

    目录 7 1 JSON介绍7 2 JSON和XML的区别7 3 JSON的构成7 4 JSON的语法7 5 JSON parse 7 6 JSON stringify 7 1 JSON介绍 JSON即JavaScript 对象标记法 xff
  • 8:Spring MVC-Java Spring

    目录 8 1 WEB开发模式一8 2 WEB开发模式二8 3 Spring MVC介绍8 4 Spring MVC主要组件8 5 Spring MVC处理流程8 6 Spring MVC的HTTP请求方法 在Web开发中有两种主要的结构 x
  • 9:参数校验-Java Spring

    目录 9 1 参数校验介绍9 2 JSR3039 3 Hibernate Validator9 4 参数校验依赖引入 9 1 参数校验介绍 参数校验即保证数据的合法性 xff0c JCP组织定义了一个标准来规范化这个任务操作 xff0c 即
  • 江服校园导游咨询系统-数据结构课程设计

    目录 1 需求分析1 1 问题描述1 2 系统简介1 3 系统模块功能要求介绍1 4 系统开发环境及开发人员1 5 校园平面图 2 概要设计2 1 算法设计及存储结构说明2 2 系统功能设计 3 详细设计3 1 定义符号变量3 2 主程序模
  • 基于STM32的光敏传感器数据采集系统-嵌入式系统与设计课程设计

    目录 1 项目概述1 1 项目介绍1 2 项目开发环境1 3 小组人员及分工 2 需求分析2 1 系统需求分析2 2 可行性分析2 3 项目实施安排 3 系统硬件设计3 1 系统整体硬件电路设计3 2 STM32 最小系统电路设计3 3 传
  • QX-A51智能小车实现-物联网应用系统设计项目开发

    目录 介绍说明展示 介绍 STC89C52系列单片机是STC推出的新一代高速 低功耗 超强抗干扰 超低价的单片机 xff0c 指令代码完全兼容传统8051单片机 xff0c 12时钟每机器周期和6时钟每机器周期可以任意选择 QX A51智能
  • 11:跨域访问-Java Spring

    目录 11 1 跨域访问11 2 同源策略11 3 跨域解决方案 11 1 跨域访问 跨域指的是浏览器不能执行其他网站的脚本 xff0c 当一个请求url的协议 域名 端口三者有任意一个不同即为跨域 无法跨域是由浏览器的同源策略造成的 xf
  • 10:@Validated和@Valid-Java Spring

    目录 10 1 64 Valid10 2 64 Validated10 2 区别10 2 Controller参数校验 10 1 64 Valid 64 Valid 是 Hibernate validation 提供的注解 xff0c 表示
  • 12:CORS跨域设置-Java Spring

    目录 12 1 CORS介绍12 2 CORS原理12 3 CORS实现 12 1 CORS介绍 CORS跨域资源共享 xff08 Cross origin resource sharing xff09 是指在服务器端定义跨域请求规则 xf
  • Ubuntu虚拟机可以上网,可以ping网络,但是无法update和install,显示不能连接或者无网络

    此方法为我找遍了网上全部解决方案之后还没有解决掉 xff0c 自己琢磨出来的其中一种方法 错误情况 xff1a 可以上浏览器看视频 xff0c 但是不能apt install vim或者gcc 解决方案 1 打开文件夹 2 输入 或者进入
  • 13:SpringBoot跨域解决方案-Java Spring

    目录 13 1 CorsFilter13 2 64 CrossOrigin13 3 WebMvcConfigurer 13 1 CorsFilter SpringBoot设置CORS的的本质都是通过设置响应头信息来告诉前端该请求是否支持跨域
  • 14:Servlet并发机制-Java Spring

    目录 14 1 并发14 2 Servlet并发机制14 3 Tomcat并发特点14 4 Tomcat线程模型 14 1 并发 并发 xff08 Concurrent xff09 是指多个任务交替执行的现象 xff0c 把CPU运行时间划
  • 手写字体识别实验-Python课程设计

    安装python 打开手写识别文件夹中的安装包文件夹 xff0c 双击python3 7 1可执行文件 xff0c 进行安装 弹出窗口 第一步 xff0c 勾选第二个复选框 Add Python 3 7 to PATH xff0c 然后点击
  • 生产企业原材料订购与运输的研究-数据处理课程设计

    目录 摘要1 引言2 规划问题说明3 问题重述3 1 问题分析3 2 数据说明3 3 模型假设3 4 符号说明 4 实验及分析4 1 问题一模型的建立与求解4 2 问题二模型的建立与求解 5 总结5 1 模型的优点5 2 模型的缺点 参考文

随机推荐

  • 信号发生器-电路与电子技术课程设计

    目录 1 设计任务与要求1 1 设计任务1 2 设计要求 2 方案设计与论证2 1 方案设计2 2 论证 3 信号发生器设计与计算3 1 信号发生器设计3 2 方波振荡电路图3 3 三角波振荡电路图3 4 参数计算 4 总原理图及元器件清单
  • 增益可控放大电路-电路与电子技术课程设计

    目录 1 设计任务与要求1 1 设计任务1 2 设计要求 2 方案设计与论证2 1 方案设计2 2 论证 3 放大电路设计与计算3 1 放大电路设计3 2 电子开关切换电路设计3 3 六档控制电路3 4 参数计算 4 总原理图及元器件清单4
  • 超声波测距实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习超声波测距传感器的使用方法 xff0c 了解超声波测距传感器的原理和电路及实际应用 xff0c 了解超声波测距传感器的基本操作
  • 光敏传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习光敏传感器的使用方法 xff0c 了解光敏传感器的基本实验原理和实际应用 xff0c 熟练掌握光敏传感器实验的操作步骤 xff
  • 红外反射传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习红外反射传感器的使用方式 xff0c 了解红外反射传感器的实验原理和实际应用 xff0c 学习并理解Modbus数据格式所代表
  • 酒精传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习酒精传感器MQ 3的使用方法 xff0c 了解酒精传感器的实验原理和实际应用 xff0c 了解酒精传感器的基本操作模式 xff
  • hdoj 1575 Tr A (矩阵快速幂)

    Tr A Time Limit 1000 1000 MS Java Others Memory Limit 32768 32768 K Java Others Total Submission s 4549 Accepted Submiss
  • MapReduce排序过程

    排序是MapReduce框架中最重要的操作之一 MapTask和ReduceTask均会对数据按照key 进行排序 该操作属于Hadoop 的默认行为 xff0c 任何应用程序中的数据均会被排序 xff0c 而不管逻辑上是否需要 默认排序是
  • 温湿度传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习温湿度传感器的使用方法 xff0c 了解温湿度传感器的基本实验原理和实际应用 xff0c 熟练掌握温湿度传感器的基本步骤 xf
  • 烟雾检测传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习烟雾检测传感器的原理及检测方式 xff0c 了解烟雾检测传感器的实验原理和技术指标 xff0c 熟练掌握烟雾检测传感器的工作步
  • 4:Servlet-Java Web

    目录 4 1 Servlet简介4 2 HTTP协议4 3 Servlet与JSP4 4 Servlet处理的基本流程4 5 Servlet 容器4 6 Servlet程序实现 4 1 Servlet简介 Servlet是用Java语言编写
  • 5:Servlet程序-Java Web

    目录 5 1 Servlet要求5 2 创建Servlet5 3 第一个Servlet5 4 Servlet编译5 5 Servlet配置 5 1 Servlet要求 如果要开发一个可以处理HTTP请求的Servlet程序 xff0c 首先
  • 6:部署Servlet-Java Web

    目录 6 1 部署Servlet6 2 请求Servlet6 3 找不到servlet包6 4 Servlet映射的细节 6 1 部署Servlet 部署就是把Servlet的字节码文件放在适当的地方 为了在浏览器上访问Servlet xf
  • 7:Servlet表单-Java Web

    目录 7 1 Servlet响应7 2 Servlet获取客户端参数7 3 Servlet接受表单数据 7 1 Servlet响应 通过response对象对用户进行响应 创建输出流对象 PrintWriter out 61 respons
  • 8:Servlet生命周期-Java Web

    目录 8 1 Servlet生命周期8 2 Servlet生命周期对应的方法8 3 Servlet的多线程机制 8 1 Servlet生命周期 Servlet程序是运行在服务器端的一段Java程序 xff0c 其生命周期将受到Web容器的控
  • 9:中文乱码处理-Java Web

    目录 9 1 常见字符集9 2 乱码原因9 3 解决乱码 9 1 常见字符集 ASCII 最原始的一套编码 xff0c 所有编码都是由一个字节的二进制数对应 xff0c 尽管包含8位 xff0c 但是第一位始终是0 xff0c 也就是128
  • 华为云平台零代码搭建物联网可视化大屏体验:疫情防控数据大屏

    目录 一 介绍二 准备三 搭建1 创建疫情防控大屏应用2 组件放置3 组件配置4 应用打包 一 介绍 零代码搭建物联网可视化大屏 xff1a 自定义物联网场景 xff0c 根据个人理解实现基于华为云IoT以及可视化大屏DLV搭建物联网大屏
  • 华为开源自研AI框架昇思MindSpore入门体验:手写数字识别

    目录 一 环境安装1 进入MindSpore官网2 选择安装版本3 确保为Windows系统4 安装MindSpore5 验证安装6 安装依赖 二 模型训练1 下载并处理数据集2 创建模型 本教程是在CPU Ubuntu上安装MindSpo
  • 转型“系统集成商+大数据运营和服务商”,航天信息看好你哟

    毫无疑问 xff0c 人工智能今天已经是一个 风口 抓住这一契机 xff0c 迎风起舞 xff0c 可能是所有厂商的想法 但是每一个新的趋势出现时 xff0c 一定是机遇与挑战并存 对于厂商来说 xff0c 是处变不惊 xff0c 还是急速
  • 华为开源自研AI框架昇思MindSpore应用实践:DCGAN生成漫画头像

    目录 一 原理说明1 GAN基础原理2 DCGAN原理 二 环境准备1 进入ModelArts官网2 使用CodeLab体验Notebook实例 三 数据准备与处理1 数据处理 四 创建网络1 生成器2 判别器3 损失和优化器4 优化器 五