从头实现 StyleGAN1

简介

本文介绍的是目前最好的生成对抗网络之一,StyleGAN,源自论文《基于风格的生成对抗网络生成器架构》。我们将使用PyTorch对其进行清晰、简单且易于理解的实现,并尽可能接近原文的实现,所以如果你阅读了论文,那么这个实现应该与论文非常相似。

本文将使用的数据集是来自Kaggle的数据集,其中包含16240件分辨率为256*192的女性上装。

预备知识

在开始使用PyTorch进行StyleGAN的工作之前,请确保你具备以下基础知识:

  • 深度学习基础知识
    了解卷积神经网络(CNN)。
    熟悉生成对抗网络(GANs),包括生成器、判别器和对抗损失等概念。

  • 硬件要求
    推荐使用强大的GPU(NVIDIA显卡)以加快训练和推理速度。
    已安装CUDA工具包以实现GPU加速(cudacudnn)。

  • 熟悉StyleGAN
    阅读原始StyleGANStyleGAN2论文有助于理解架构改进和关键概念。

加载我们需要的所有依赖项

我们首先会导入torch,因为我们将会使用PyTorch,并从中导入nn。这将帮助我们创建和训练网络,同时也能让我们导入optim,这是一个实现了各种优化算法(例如sgd、adam等)的包。从torchvision我们导入datasets和transforms来准备数据和应用一些转换。

我们将从torch.nn导入functional作为F来使用interpolate上采样图像,从torch.utils.data导入DataLoader来创建小批量大小,从torchvision.utils导入save_image来保存一些假样本,从math导入log2,因为我们需要2的幂的逆表示来实现根据输出分辨率调整自适应小批量大小,导入NumPy进行线性代数计算,导入os与操作系统交互,导入tqdm显示进度条,最后导入matplotlib.pyplot来显示结果并与真实结果进行比较。

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

超参数

  • 通过真实图片的路径初始化数据集DATASET。
  • 指定起始训练的图像大小为8×8。
  • 根据Cuda是否可用初始化设备,否则使用CPU,并将学习率初始化为0.001。
  • 批量大小将根据我们想要生成的图像分辨率而有所不同,因此我们通过一个数字列表初始化BATCH_SIZES,你可以根据你的VRAM更改它们。
  • 将image_size初始化为128,CHANNELS_IMG初始化为3,因为我们将会生成128×128的RGB图像。
  • 在原始论文中,他们使用512初始化Z_DIM、W_DIM和IN_CHANNELS,但我为了减少VRAM的使用并加速训练,将它们初始化为256。如果我们加倍这些值,也许还能得到更好的结果。
  • 对于StyleGAN,我们可以使用任何我们想要的GAN损失函数,所以我使用了论文中的WGAN-GP Improved Training of Wasserstein GANs。这个损失函数包含一个名为λ的参数,通常设置λ=10。
  • 将PROGRESSIVE_EPOCHS初始化为每个图像大小的30。
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 作者从8x8的图像开始,而不是4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

获取数据加载器

现在我们来创建一个函数get_loader来:

  • 对图像应用一些转换(将图像调整为我们想要的分辨率,将它们转换为张量,然后应用一些增强,最后将它们归一化,使所有像素值在-1到1之间)。
  • 使用列表BATCH_SIZES识别当前批处理大小,并取图像大小除以4的2的幂的反向表示的整数值作为索引。这实际上就是我们如何实现根据输出分辨率自适应最小批处理大小的。
  • 通过使用ImageFolder准备数据集,因为它已经以很好的方式结构化了。
  • 使用DataLoader创建mini-batch大小,它接收数据集和批量大小,并对数据进行打乱。
  • 最后,返回加载器和数据集。
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

模型实现

现在让我们实现StyleGAN1生成器和判别器(ProGAN和StyleGAN1具有相同的判别器架构),并从论文中提取关键属性。我们将尝试使实现简洁,同时保持可读性和可理解性。具体的关键点:

  • 噪声映射网络
  • 自适应实例归一化(AdaIN)
  • 逐步增长

在本教程中,我们将只使用StyleGAN1生成图像,而不实现风格混合和随机变化,但这样做应该并不困难。

让我们定义一个名为factors的变量,它包含将乘以IN_CHANNELS的数字,以便在每个图像分辨率中拥有我们想要的通道数。

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

噪声映射网络

噪声映射网络将Z输入并通过八个全连接层,每层之间有一些激活函数。别忘了像ProGAN(由同一研究人员创作的ProGAN和StyleGan)的作者那样平衡学习率。

让我们首先构建一个名为WSLinear(加权缩放线性)的类,它将从nn.Module继承。

  • 初始化部分,我们传入in_features和out_channels。创建一个线性层,然后我们定义一个等于2的平方根除以in_features的缩放比例,我们将当前列层的偏置复制到一个变量中,因为我们不希望线性层的偏置被缩放,然后我们移除它,最后初始化线性层。
  • 正向传播部分,我们传入x,我们要做的就是将x乘以缩放比例,并在重塑后加上偏置。
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # 初始化线性层
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

现在让我们创建MappingNetwork类。

  • 初始化部分,我们传入z_dim和w_din,我们定义了网络映射,首先正规化z_dim,然后是八个WSLinear和ReLU作为激活函数。
  • 正向传播部分,我们返回网络映射。

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

自适应实例归一化(AdaIN)

现在让我们创建AdaIN类

  • 初始化部分,我们发送通道,w_dim,并且初始化instance_norm,它将是实例归一化部分,我们还初始化style_scale和style_bias,它们将是与WSLinear映射噪声映射网络W到通道的自适应部分。
  • 前向传播过程中,我们发送x,对它应用实例归一化,并返回style_sclate * x + style_bias。

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

注入噪声

现在让我们创建一个InjectNoise类来向生成器中注入噪声

  • 初始化部分,我们发送通道,并且从一个随机正态分布中初始化权重,我们使用nn.Parameter以便这些权重可以被优化
  • 前向传播部分,我们发送一个图像x,并且返回添加了随机噪声的它
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

有用的类

作者在ProGAN的官方实现基础上构建了StyleGAN,他们使用了相同的判别器架构、自适应小批量大小、超参数等。因此,从ProGAN实现中有很多类保持不变。

在本节中,我们将创建与ProGAN架构不变的类。

在下面的代码片段中,你可以找到类 WSConv2d(加权缩放卷积层)用于卷积层的等化学习率。

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        初始化卷积层
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

在下面的代码片段中,你可以找到类 PixelNorm 用于在噪声映射网络之前标准化 Z。

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

在下面的代码片段中,你可以找到类 ConvBlock,它将帮助我们创建判别器。

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

在下面的代码片段中,你可以找到类 Discriminatowich,它与 ProGAN 中的相同。

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        这里我们从判别器的因素开始反向工作,因为判别器
        应该是从生成器中镜像出来的。所以第一个prog_block和
        我们附加的rgb层将为1024x1024的输入大小工作,然后是512->256->等等
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        可能令人困惑的名称"initial_rgb",这只是4x4输入大小的RGB层
        这样做是为了"镜像"生成器的initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  使用平均池进行下采样

        这是4x4输入大小的块
        self.final_block = nn.Sequential(
            +1到in_channels,因为我们从MiniBatch std拼接
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  我们使用这个而不是线性层
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        alpha应该是[0, 1]范围内的标量,并且upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        我们对每个示例(跨越所有通道和像素)取std,然后我们重复它
        对于一个通道并将其与图像拼接。这样判别器
        将获得有关批次/图像变化的信息
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        我们应该在prog_blocks列表中的哪里开始,可能有点令人困惑,但是
        最后一个是用于4x4的。所以例如,假设steps=1,那么我们应该从
        倒数第二个开始,因为input_size将是8x8。如果steps==0,我们只是
        使用最后一个块
        cur_step = len(self.prog_blocks) - steps

        从rgb作为初始步骤转换,这将依赖于
        图像大小(每个图像都会有自己的rgb层)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  即,图像是4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        因为prog_blocks可能会改变通道,对于下缩放我们使用rgb_layer
        从前一个/更小的大小,这在我们的情况下与索引中的+1相关联
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        fade_in首先在缩放后的图像和输入之间进行
        这与生成器相反
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

生成器

在生成器架构中,我们有一些模式会重复,所以让我们首先为它创建一个类,以使我们的代码尽可能整洁,我们将这个类命名为GenBlock,它将从nn.Module继承。

  • 初始化部分,我们传入in_channels、out_channels和w_dim,然后我们通过WSConv2d初始化conv1,它将in_channels映射到out_channels,通过WSConv2d初始化conv2,它将out_channels映射到out_channels,leaky通过Leaky ReLU,其斜率为0.2,正如论文中使用的那样,inject_noise1和inject_noise2通过InjectNoise,adain1和adain2通过AdaIN。
  • 前向传播部分,我们传入x,然后将其传递给conv1,接着传递给inject_noise1和leaky,然后我们用adain1进行归一化,再次将结果传递给conv2,然后传递给inject_noise2和leaky,并用adain2进行归一化。最后,我们返回x。
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

现在我们已经有了创建生成器所需的所有东西。

  • 初始化部分,让我们通过一个生成器的迭代来初始化‘starting_constant’,这是一个4×4(原始论文中为512通道,我们这里为256通道)的常量张量,通过‘MappingNetwork’映射,由AdaIN初始化initial_adain1和initial_adain2,通过InjectNoise初始化initial_noise1和initial_noise2,通过一个卷积层初始化initial_conv,该卷积层将输入通道映射到自身,leaky为斜率为0.2的Leaky ReLU,通过WSConv2d初始化initial_rgb,该层将输入通道映射到img_channels,RGB中为3,prog_blocks通过ModuleList()初始化,将包含所有渐进块(我们通过乘以因素来指示卷积输入/输出通道,原始论文中为512,我们这里为256),rgb_blocks通过ModuleList()初始化,将包含所有RGB块。
  • 为了渐显新的层(ProGAN的原始组件),我们添加了渐显部分,我们发送alpha、scaled和generated,然后返回[tanh(alpha∗generated+(1−alpha)∗upscale)],我们使用tanh的原因是输出(生成的图像)的像素值应该在-1到1之间。
  • 前向部分,我们发送噪声(Z_dim),在训练过程中逐渐淡入的alpha值(alpha介于0和1之间),以及steps,它是我们正在处理的当前分辨率的编号,我们将x传入map以获得中间噪声向量W,将starting_constant传递给initial_noise1,应用它以及W的initial_adain1,然后将其传入initial_conv,再次为它添加initial_noise2,并使用泄漏作为激活函数,应用它以及W的initial_adain2。然后我们检查steps是否等于0,如果是,那么我们只需将其通过初始RGB即可完成,否则,我们遍历步数,在每个循环中我们进行升级(upscaled)并运行对应于该分辨率的渐进块(out)。最后,我们返回淡入,它接收alpha、final_out和final_upscaled,在映射到RGB之后。
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):   #-1 以防止因factors[i+1]导致的索引错误
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
         # alpha应该是[0, 1]范围内的标量,并且 upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

         # upgrade中的通道数将保持不变,而
         # 经过prog_blocks的out可能会改变。为了确保
         # 我们可以将它们都转换为rgb,我们使用不同的rgb_layers
         # (steps-1) 和 steps 分别用于upscaled,out
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

工具

在下面的代码片段中,你可以找到generate_examples函数,它接收生成器 gen,用于确定当前分辨率的步数,以及n=100的数字。这个函数的目的是生成n个假图像并将它们保存为结果。

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

在下面的代码片段中,你可以找到用于WGAN-GP损失的gradient_penalty函数。

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # 计算评判者分数
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # 计算图像相对于分数的梯度
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

训练函数

对于训练函数,我们发送评判者(即判别器)、gen(生成器)、loader、数据集、步数、alpha以及生成器和评判者的优化器。

我们首先遍历所有由DataLoader创建的迷你批次大小,并且只取图像,因为我们不需要标签。

然后我们为判别器\Critic设置训练,当我们想要最大化E(评判者(真实)) – E(评判者(伪造))。这个方程意味着评判者能够在多大程度上区分真实和伪造的图像。

之后,当我们想要最大化E(评判者(伪造)).

最后,我们更新循环和fade_in的alpha值,确保它在0和1之间,然后返回它。

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # 更新alpha值并确保小于1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

训练

现在我们已经有了所有东西,让我们将它们结合起来训练我们的StyleGAN。

我们首先初始化生成器、判别器/评判者和优化器,然后将生成器和评判者转换为训练模式,然后在PROGRESSIVE_EPOCHS上循环,每次循环中,我们调用train函数进行epoch次数的训练,然后我们生成一些伪造图像并保存它们,作为结果,使用generate_examples函数,最后我们进展到下一个图像分辨率。

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# 初始化优化器
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# 从对应于我们在配置中设置的图像大小的步骤开始
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # 从非常低的alpha值开始
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # 进展到下一个图像大小

结果

希望您能够遵循所有步骤并很好地理解如何以正确的方式实现StyleGAN。现在让我们看看在数据集中以128*128分辨率训练这个模型后得到的结果。

结论

在本文中,我们使用PyTorch从头开始实现了一个干净、简单且易于理解的StyleGAN1。我们尽可能地复制原始论文,所以如果您阅读了论文,实现应该与论文非常相似。

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch