從頭實現 StyleGAN1

介紹

本文將介紹目前最佳的生成對抗網絡之一,來自論文《A Style-Based Generator Architecture for Generative Adversarial Networks》的StyleGAN。我們將使用PyTorch進行乾淨、簡單且易於閱讀的實現,並盡可能地模擬原論文的內容,所以如果您讀過這篇論文,這個實現應該與之幾乎完全相同。

我們在這篇部落格中將使用的數據集是來自Kaggle的數據集,其中包含了16240張解析度為256*192的女性上衣圖片。

先備知識

在您開始使用PyTorch進行StyleGAN的工作之前,請確保您具備以下先備知識:

  • 深度學習的基本知識
    理解卷積神經網絡(CNNs)。
    熟悉生成對抗網絡(GANs),包括生成器、判別器和對抗損失等概念。

  • 硬件要求
    一個強大的GPU(推薦NVIDIA)以加快訓練和推理速度。
    安裝CUDA工具包以支持GPU加速(cudacudnn)。

  • 熟悉StyleGAN
    閱讀原始的StyleGANStyleGAN2論文有助於理解架構改進和關鍵概念。

加載我們所需的所有依賴

我們首先會導入torch,因為我們將使用PyTorch,然後從中導入nn。這將幫助我們創建和訓練網絡,並讓我們導入optim,一個實現各種優化算法(例如sgd、adam…)的包。從torchvision我們導入datasets和transforms來準備數據和應用一些轉換。

我們將從torch.nn導入functional作為F來使用interpolate上掇圖像,從torch.utils.data導入DataLoader來創建小批量大小,從torchvision.utils導入save_image來保存一些假樣本,從math導入log2因為我們需要2的冪的逆表示來實現根據輸出解析度自適應的批量大小,導入NumPy來進行線性代數計算,導入os來與操作系統交互,導入tqdm來顯示進度條,最後導入matplotlib.pyplot來顯示結果和與真實的進行比較。

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

超參數

  • 通過真實圖像的路徑初始化DATASET。
  • 指定從8×8圖像大小開始訓練。
  • 通過Cuda(如果可用)或CPU初始化設備,並將學習率設為0.001。
  • 批量大小將根據我們想要生成的圖像解析度而有所不同,因此我們通過一個數字列表初始化BATCH_SIZES,你可以根據你的VRAM更改它們。
  • 將image_size初始化為128,將CHANNELS_IMG初始化為3,因為我們將生成128×128的RGB圖像。
  • 在原始論文中,他們將Z_DIM、W_DIM和IN_CHANNELS初始化為512,但我則將它們初始化為256,以減少VRAM的使用並加速訓練。我們甚至可以通過將它們翻倍來獲得更好的結果。
  • 對於StyleGAN,我們可以使用任何我們想要的GAN損失函數,所以我選擇了論文中的WGAN-GPImproved Training of Wasserstein GANs。這個損失函數包含一個名為λ的參數,通常設置λ = 10。
  • 將PROGRESSIVE_EPOCHS初始化為每個圖像大小30。
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #作者們從8x8圖像開始,而不是4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

獲取數據加載器

現在讓我們創建一個函數get_loader來:

  • 對圖像應用一些轉換(將圖像調整到我們想要的解析度,將它們轉換為張量,然後應用一些增強,最後將它們正規化以使所有像素範圍在-1到1之間)。
  • 使用列表BATCH_SIZES識別當前批量大小,並將圖像大小除以4的2的冪的逆表示的整數數字作為索引。這就是我們如何實現依賴於輸出解析度的自適應最小批量大小的方法。
  • 使用ImageFolder準備數據集,因為它已經以一種很好的方式結構化。
  • 創建使用DataLoader的迷你批次大小,該DataLoader接受數據集和批次大小,並對數據進行洗牌。
  • 最後,返回加載器和數據集。
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

模型實現

現在讓我們實現StyleGAN1的生成器和判別器(ProGAN和StyleGAN1具有相同的判別器架構),並從論文中提取關鍵屬性。我們將努力使實現簡潔,同時保持可讀性和可理解性。具體來說,關鍵點:

  • 噪聲映射網絡
  • 自適應實例歸一化(AdaIN)
  • 逐步增長

在本教程中,我們將僅使用StyleGAN1生成圖像,而不實現風格混合和隨機變化,但這應該不難做到。

讓我們定義一個名為factors的變量,其中包含將與IN_CHANNELS相乘的數字,以在我們想要的每個圖像解析度中獲得通道數。

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

噪聲映射網絡

噪音映射網絡將 Z 輸入並通過八個全連接層,層與層之間有一些激活函數。並不要忘了像 ProGAN(由同一組研究者創作的 ProGAN 和 StyleGan)的作者那樣平衡學習率。

我們首先建立一個名為 WSLinear(加權縮放線性)的類,它將繼承自 nn.Module。

  • init 部分我們傳入 in_features 和 out_channels。創建一個線性層,然後我們定義一個 scale,它將等於 2 的平方根除以 in_features,我們將當前列層的偏置複製到一個變量中,因為我們不希望線性層的偏置被縮放,然後我們移除它,最後初始化線性層。
  • forward 部分中,我們傳入 x,而我們要做的只是將 x 與 scale 相乘,並在重塑後加上偏置。
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # 初始化線性層
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

現在我們來創建 MappingNetwork 類。

  • init 部分我們傳入 z_dim 和 w_din,並定義網絡映射,首先對 z_dim 進行標準化,然後接八個 WSLInear 和 ReLU 作為激活函數。
  • forward 部分中,我們返回網絡映射。

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

自適應實例歸一化(AdaIN)

現在我們來創建 AdaIN 類

  • 初始化部分,我們傳送通道、w_dim,並初始化instance_norm,這將是實例歸一化的部分;我們還初始化style_scale和style_bias,這將是與WSLinear映射Noise Mapping Network W到通道的自適應部分。
  • 前向傳播過程中,我們傳送x,對其應用實例歸一化,並返回style_sclate * x + style_bias。

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

注入噪音

現在讓我們創建類InjectNoise以將噪音注入生成器

  • 初始化部分,我們傳送通道,並從隨機正態分佈中初始化權重,我們使用nn.Parameter以便這些權重可以被優化
  • 前向傳播部分,我們傳送一幅圖像x,並返回添加了隨機噪音的圖像
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

有助的類別

作者在StyleGAN的基礎上建立了ProGAN的官方實現,由Karras等人提出,他們使用相同的判別器結構、自適應小批量大小、超參數等。因此,從ProGAN實現中保留了很多類別。

在這一部分,我們將創建從ProGAN架構中不變的類別。

在以下代碼片段中,您可以找到類別 WSConv2d(加權縮放卷積層)以實現卷積層的等化學習率。

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        初始化卷積層
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

在以下代碼片段中,您可以找到類別 PixelNorm 用於在噪聲映射網絡之前對 Z 進行歸一化。

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

在以下代碼片段中,您可以找到類別 ConvBock,它將幫助我們創建判別器。

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

在以下代碼片段中,您可以找到類別 Discriminatowich,它與 ProGAN 中的相同。

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        這裡我們從因子逆向操作,因為判別器
        應該從生成器中對稱。所以第一個 prog_block 和
        我們附加的 rgb 層將對於輸入尺寸 1024x1024 有作用,然後 512->256->等等
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        也許名稱 "initial_rgb" 會令人困惑,這只是 4x4 輸入尺寸的 RGB 層
        這是為了 "對稱" 生成器的 initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  使用 avg pool 進行下採樣

        這是 4x4 輸入尺寸的區塊
        self.final_block = nn.Sequential(
            +1 到 in_channels,因為我們從 MiniBatch std 連接
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  我們使用這個而不是線性層
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        alpha 應該是 [0, 1] 范圍內的標量,且 upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        我們為每個示例(跨所有通道和像素)計算 std,然後我們重複它
        對於單一通道並將其與圖像連接。這樣判別器
        將獲得批次/圖像變化的信息
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        我們應該在 prog_blocks 列表中的哪裡開始,也許有點混亂,但
        最後一個是為了 4x4。所以舉例來說,假設 steps=1,那麼我們應該從
        倒數第二個開始,因為 input_size 將是 8x8。如果 steps==0,我們就
        使用最後的區塊
        cur_step = len(self.prog_blocks) - steps

        從 rgb 作為初始步驟轉換,這將取決於
        圖像尺寸(每個都會有自己的 rgb 層)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  例如,圖像是 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        因為 prog_blocks 可能會改變通道,對於縮小我們使用 rgb_layer
        從前一個/更小的尺寸,在我們的例子中相關於索引 +1
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        fade_in 首先在縮小後的圖像和輸入之間進行
        這與生成器相反
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

生成器

在生成器架构中,我们有一些模式会重复,所以让我们首先为它创建一个类,以使我们的代码尽可能清晰,我们将这个类命名为GenBlock,它将从nn.Module继承。

  • 初始化部分,我们传入in_channels、out_channels和w_dim,然后我们通过WSConv2d初始化conv1,它将in_channels映射到out_channels,通过WSConv2d初始化conv2,它将out_channels映射到out_channels,leaky通过Leaky ReLU,其斜率为0.2,正如论文中使用的那样,inject_noise1和inject_noise2通过InjectNoise,adain1和adain2通过AdaIN。
  • 前向传播部分,我们传入x,然后将其传递给conv1,再传递给inject_noise1和leaky,然后我们用adain1对其进行归一化,接着我们再次将结果传递给conv2,再传递给inject_noise2和leaky,并用adain2进行归一化。最后,我们返回x。
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

现在我们已经拥有创建生成器所需的所有内容。

  • 初始化部分,讓我們通過生成器的一次迭代來初始化 ‘starting_constant’,這個常數是一個 4 x 4(原始論文中為 512 個通道,我們的案例中為 256 個通道)的張量。通過 ‘MappingNetwork’ 映射,使用 AdaIN 初始化 initial_adain1 和 initial_adain2,使用 InjectNoise 初始化 initial_noise1 和 initial_noise2,使用卷積層初始化 initial_conv 將輸入通道映射到自身,leaky 是具有 0.2 倾斜度的 Leaky ReLU,使用 WSConv2d 初始化 initial_rgb 將輸入通道映射到 RGB 的 img_channels,即 3。prog_blocks 是 ModuleList(),將包含所有的進展性區塊(我們用乘積表示卷積的輸入/輸出通道,論文中為 512,我們的案例中為 256,並乘以係數)。rgb_blocks 是 ModuleList(),將包含所有的 RGB 區塊。
  • 為了逐漸引入新層(ProGAN 的原始组件),我們添加了淡入部分,我們傳送 alpha、scaled 和 generated,並返回 [tanh(alpha∗generated+(1−alpha)∗upscale)]。我們使用 tanh 的原因是這將是輸出(生成的圖像),我們希望像素值在 -1 到 1 之間。
  • 前向部分,我們傳送噪音(Z_dim),在訓練過程中會逐漸淡入的alpha值(alpha介於0和1之間),以及steps,這是我們正在处理的當前解析度的數字,我們將x傳入map以獲得中間噪音向量W,我們將starting_constant傳給initial_noise1,對其進行應用,並對W進行initial_adain1的應用,然後我們將其傳入initial_conv,再次為其添加initial_noise2並使用leaky作為激活函數,並對其和W進行initial_adain2的應用。然後我們檢查steps是否等於0,如果是,那麼我們所要做的就是讓它通過initial RGB,然後就完成了,否則,我們會循環步數,並在每次循環中進行升采样(upscaled),並通過與該解析度相對應的進步塊(out)。最後,我們返回淡入,它將alpha、final_out和經過RGB映射的final_upscaled作為參數。
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):   #-1以防止由於factors[i+1]造成的索引錯誤
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
         # alpha應該是[0, 1]範圍內的标量,並且upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

         # upscale中的通道數將保持不變,而
         # 經過prog_blocks的out可能會變化。為了確保
         # 我們可以將其都轉換為rgb,我們使用不同的rgb_layers
         # 分別為upscaled的(steps-1)和steps的out
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

工具

在下面的代碼片段中,您可以找到 generate_examples 函數,它接受 生成器 gen、用於識別當前解析度的 步數,以及 n=100 的數字。這個函數的目標是生成 n 個假圖像並將它們作為結果保存。

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

在下面的代碼片段中,您可以找到用於 WGAN-GP 損失的 gradient_penalty 函數。

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # 計算評論家得分
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # 對圖像的得分計算梯度
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

訓練函數

對於訓練函數,我們傳遞評論家(即判別器)、生成器 gen、數據加載器 loader、數據集 dataset、步數 step、alpha,以及生成器和評論家的優化器。

我們首先遍歷我們使用 DataLoader 創建的所有的最小批量大小,並且只取圖像,因為我們不需要標籤。

然後我們設置判別器\評論家的訓練,當我們想要最大化 E(評論家(真)) – E(評論家(假)) 時。這個方程意味著評論家區分真實和假圖像的能力。

之後,當我們想要最大化E(評論者(假圖))

時,我們為生成器設置訓練。

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # 更新alpha並確保小於1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

訓練

現在既然我們已經有了所有東西,讓我們把它們組合起來來訓練我們的StyleGAN。

我們從初始化生成器、判別器/評論者以及優化器開始,然後將生成器和評論者轉換為訓練模式,接著循環遍歷PROGRESSIVE_EPOCHS,在每次循環中,我們調用train函數指定次數的epoch,然後生成一些假圖像並將它們保存起來,因此,使用generate_examples函數,最後我們進展到下一個圖像解析度。

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# 初始化優化器
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# 從對應於配置中設置的圖像大小的步驟開始
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # 從非常低的alpha開始
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # 進展到下一個圖像大小

結果

希望您能遵循所有步骤,並充分理解如何以正確的方式實現StyleGAN。現在我們來查看在這個數據集中以128*x 128解析度訓練此模型後我們得到的结果。

結論

在本文中,我們從頭開始使用PyTorch進行了一個乾淨、簡單且易於閱讀的StyleGAN1的實現。我們盡可能接近原論文的內容,所以如果您閱讀了論文,那麼這個實現應該與之非常相似。

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch