Implémentation de StyleGAN1 à partir de zéro

Introduction

Cet article porte sur l’un des meilleurs GANs actuels, StyleGAN tiré de l’article A Style-Based Generator Architecture for Generative Adversarial Networks, nous allons en faire une implémentation propre, simple et lisible en utilisant PyTorch, et essayer de reproduire l’article original aussi fidèlement que possible, de sorte que si vous lisez l’article, l’implémentation devrait être à peu près identique.

Le jeu de données que nous allons utiliser dans ce blog est ce dataset de Kaggle qui contient 16240 vêtements supérieurs pour les femmes avec une résolution de 256*192.

Prérequis

Avant de vous plonger dans le travail avec StyleGAN en utilisant PyTorch, assurez-vous que vous avez les prérequis suivants:

  • .

    Connaissances de base en apprentissage profond
    Compréhension des réseaux de neurones convolutifs (CNN).
    Familiarité avec les réseaux adversaires génératifs (GAN), y compris des concepts tels que le générateur, le discriminateur et la perte adversaire.

  • Exigences matérielles
    Une GPU puissante (NVIDIA recommandée) pour un entraînement et une inférence plus rapides.
    Kit CUDA installé pour l’accélération GPU (cuda et cudnn).

  • Familiarité avec StyleGAN
    Il est utile d’avoir lu les articles originaux StyleGAN ou StyleGAN2 pour comprendre les améliorations d’architecture et les concepts clés.

Charger toutes les dépendances dont nous avons besoin

Nous allons d’abord importer torch puisque nous utiliserons PyTorch, et à partir de là, nous importerons nn. Cela nous aidera à créer et à entraîner les réseaux, et nous permettra également d’importer optim, un package qui implémente divers algorithmes d’optimisation (par exemple, sgd, adam,…). De torchvision, nous importons datasets et transforms pour préparer les données et appliquer certaines transformations.

Nous allons importer functional comme F de torch.nn pour upsample les images en utilisant interpolate, DataLoader de torch.utils.data pour créer des tailles de mini-lots, save_image de torchvision.utils pour enregistrer quelques échantillons faux, et log2 de math car nous avons besoin de la représentation inverse de la puissance de 2 pour implémenter la taille de mini-lot adaptative en fonction de la résolution de sortie, NumPy pour l’algèbre linéaire, os pour l’interaction avec le système d’exploitation, tqdm pour afficher les barres de progression, et enfin matplotlib.pyplot pour montrer les résultats et les comparer avec les vrais.

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

Hyperparamètres

  • Initialiser le DATASET par le chemin des images réelles.
  • Spécifier le démarrage de l’entraînement à la taille d’image 8×8.
  • Initialiser le périphérique par Cuda s’il est disponible et par CPU sinon, et le taux d’apprentissage par 0,001.
  • La taille du lot sera différente en fonction de la résolution des images que nous voulons générer, donc nous initialisons BATCH_SIZES par une liste de nombres, vous pouvez les modifier en fonction de votre VRAM.
  • Initialiser image_size par 128 et CHANNELS_IMG par 3 car nous allons générer des images RGB de 128 par 128.
  • Dans l’article original, ils initialisent Z_DIM, W_DIM et IN_CHANNELS à 512, mais je les initialise à 256 plutôt pour utiliser moins de VRAM et accélérer l’entraînement. Nous pourrions peut-être même obtenir de meilleurs résultats si nous les doubillions.
  • Pour StyleGAN, nous pouvons utiliser n’importe laquelle des fonctions de perte GAN que nous voulons, donc j’utilise WGAN-GP de l’article Improved Training of Wasserstein GANs. Cette perte contient un paramètre nommé λ et il est commun de régler λ = 10.
  • Initialiser PROGRESSIVE_EPOCHS à 30 pour chaque taille d’image.
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #Les auteurs commencent avec des images 8x8 au lieu de 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

Obtenir le chargeur de données

Jetzt erstellen wir eine Funktion get_loader, um :

  • Appliquer certaines transformations aux images (redimensionner les images à la résolution que nous voulons, les convertir en tenseurs, puis appliquer certaines augmentations, et enfin normaliser les pixels pour qu’ils varient de -1 à 1).
  • Identifier la taille actuelle du lot en utilisant la liste BATCH_SIZES, et prendre comme index le nombre entier de la représentation inverse de la puissance de 2 de image_size/4. Et c’est ainsi que nous mettons en œuvre la taille de lot adaptative en fonction de la résolution de sortie.
  • Préparer le jeu de données en utilisant ImageFolder car il est déjà structuré de manière agréable.
  • Créer des tailles de mini-lots en utilisant DataLoader qui prennent le jeu de données et la taille du lot avec mélange des données.
  • Finalement, retourner le chargeur et le jeu de données.
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

Implémentation des modèles

Jetzt implementieren wir den StyleGAN1 générateur et discriminateur (ProGAN et StyleGAN1 ont la même architecture de discriminateur) avec les attributions clés de l’article. Nous essayerons de rendre l’implémentation compacte mais aussi lisible et compréhensible. Plus précisément, les points clés :

  • Réseau de cartographie du bruit
  • Normalisation adaptative d’instance (AdaIN)
  • Croissance progressive

Dans ce tutoriel, nous ne générerons que des images avec StyleGAN1, et ne mettrons pas en œuvre le mélange de styles et la variation stochastique, mais cela ne devrait pas être difficile à faire.

Reprenons une variable avec le nom factors qui contient les nombres qui multiplieront IN_CHANNELS pour avoir le nombre de canaux que nous voulons dans chaque résolution d’image.

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

Réseau de cartographie du bruit

Le réseau de cartographie du bruit prend Z et le fait passer par huit couches entièrement connectées, séparées par une activation.

Et n’oubliez pas d’égaliser le taux d’apprentissage comme le font les auteurs dans ProGAN (ProGAN et StyleGan rédigés par les mêmes chercheurs).

  • Commençons par construire une classe nommée WSLinear (weighted scaled Linear) qui héritera de nn.Module.Dans la partie init, nous envoyons in_features et out_channels. Créons une couche linéaire, puis nous définissons un scale qui sera égal à la racine carrée de 2 divisée par in_features, nous copions le biais de la couche colonne actuelle dans une variable car nous ne voulons pas que le biais de la couche linéaire soit scalaire, puis nous le retirons, enfin, nous initialisons la couche linéaire.
  • Dans la partie forward, nous envoyons x et tout ce que nous allons faire, c’est multiplier x par scale et ajouter le biais après l’avoir redimensionné.
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # initialiser la couche linéaire
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

Jetzt erstellen wir die MappingNetwork-Klasse.

  • Dans la partie init, nous envoyons z_dim et w_din, et nous définissons le réseau de cartographie qui commence par normaliser z_dim, suivi de huit WSLInear et ReLU en tant que fonctions d’activation.
  • Dans la partie forward, nous renvoyons le réseau de cartographie.

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

Normalisation Adaptative Instance (AdaIN)

Jetzt erstellen wir die AdaIN-Klasse

  • Dans la partie init, nous envoyons les canaux, w_dim, et nous initialisons instance_norm qui sera la partie de normalisation par instance, et nous initialisons style_scale et style_bias qui seront les parties adaptatives avec WSLinear qui mappe le réseau de cartographie du bruit W en canaux.
  • Dans la passe forward, nous envoyons x, appliquons la normalisation par instance pour celui-ci, et retournons style_sclate * x + style_bias.

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

Inject Noise

Jetzt erstellen wir die Klasse InjectNoise, um den Generator mit Noise zu injizieren

  • Dans la partie init, nous envoyons les canaux et nous initialisons le poids à partir d’une distribution normale aléatoire et nous utilisons nn.Parameter afin que ces poids puissent être optimisés
  • Dans la partie forward, nous envoyons une image x et nous la retournons avec du bruit aléatoire ajouté
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

classes utiles

Les auteurs ont construit StyleGAN sur l’implémentation officielle de ProGAN par Karras et al., ils utilisent la même architecture de discriminateur, taille de mini-lot adaptative, hyperparamètres, etc. Il y a donc beaucoup de classes qui restent les mêmes que dans l’implémentation de ProGAN.

Dans cette section, nous allons créer les classes qui ne changent pas de l’architecture de ProGAN.

Dans l’extrait de code ci-dessous, vous pouvez trouver la classe WSConv2d (couche de convolution pondérée et échelonnée) pour Equalized Learning Rate pour les couches de convolution.

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialiser la couche de convolution
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

Dans l’extrait de code ci-dessous, vous pouvez trouver la classe PixelNorm pour normaliser Z avant le réseau de cartographie du bruit.

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

Dans l’extrait de code ci-dessous, vous pouvez trouver la classe ConvBlock qui nous aidera à créer le discriminant.

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

Dans l’extrait de code ci-dessous, vous pouvez trouver la classe Discriminatowich qui est la même que dans ProGAN.

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        ici, nous travaillons à rebours à partir des facteurs car le discriminant
        # devrait être mirroiré à partir du générateur. Ainsi, le premier prog_block et
        # la couche rgb que nous ajoutons fonctionnera pour une taille d'entrée de 1024x1024, puis 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # peut-être un nom confus "initial_rgb" c'est simplement la couche RGB pour une taille d'entrée de 4x4
        # fait cela pour "mirroir" le initial_rgb du générateur
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # échantillonnage descendant en utilisant une pool moyenne

        # c'est le bloc pour une taille d'entrée de 4x4
        self.final_block = nn.Sequential(
            # +1 aux in_channels car nous concaténons à partir de MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # nous utilisons cela au lieu d'une couche linéaire
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha devrait être scalaire dans [0, 1], et upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # nous prenons l'écart-type pour chaque exemple (à travers toutes les canaux et les pixels) puis nous le répétons
        # pour un canal unique et le concaténons avec l'image. De cette manière, le discriminant
        # obtenir des informations sur la variation dans le lot/image
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        # où nous devrions commencer dans la liste des prog_blocks, peut-être un peu confus mais
        # le dernier est pour le 4x4. Ainsi, par exemple, disons que steps=1, alors nous devrions commencer
        # à l'avant-dernier car la taille d'entrée sera de 8x8. Si steps==0, nous utilisons simplement
        # le bloc final
        cur_step = len(self.prog_blocks) - steps

        # conversion de rgb comme étape initiale, cela dépendra de
        # la taille de l'image (chaque taille aura sa propre couche rgb)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # par exemple, l'image est de 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # car prog_blocks pourrait changer les canaux, pour l'échelle descendant nous utilisons rgb_layer
        # de la taille précédente plus petite, ce qui dans notre cas correspond à +1 dans l'indexation
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # la fade_in est effectuée d'abord entre le downscaled et l'entrée
        # c'est l'inverse du générateur
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

Générateur

Dans l’architecture du générateur, nous avons des motifs qui se répètent, donc créons d’abord une classe pour cela afin de rendre notre code aussi propre que possible. Nommons cette classe GenBlock qui héritera de nn.Module.

  • Dans la partie init, nous envoyons in_channels, out_channels et w_dim, puis nous initialisons conv1 avec WSConv2d qui mappe in_channels vers out_channels, conv2 avec WSConv2d qui mappe out_channels vers out_channels, leaky avec Leaky ReLU avec une pente de 0,2 comme ils l’utilisent dans le papier, inject_noise1 et inject_noise2 avec InjectNoise, adain1 et adain2 avec AdaIN.
  • Dans la partie forward, nous envoyons x, et nous le faisons passer par conv1 puis par inject_noise1 avec leaky, puis nous le normalisons avec adain1, et à nouveau nous passons celui-ci dans conv2 puis par inject_noise2 avec leaky et nous le normalisons avec adain2. Et enfin, nous retournons x.
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

Jetzt haben wir alles, was wir brauchen, um den Generator zu erstellen.

  • dans la partie init, initialisons ‘starting_constant’ par un tenseur constant 4 x 4 (x 512 canaux pour l’article original, et 256 dans notre cas) qui est passé par une itération du générateur, mappé par ‘MappingNetwork’, initial_adain1, initial_adain2 par AdaIN, initial_noise1, initial_noise2 par InjectNoise, initial_conv par une couche de convolution qui mappe in_channels en lui-même, leaky par Leaky ReLU avec une pente de 0.2, initial_rgb par WSConv2d qui mappe in_channels en img_channels qui est 3 pour RGB, prog_blocks par ModuleList() qui contiendra tous les blocs progressifs (nous indiquons les canaux d’entrée/sortie de la convolution par multiplication de in_channels qui est 512 dans l’article et 256 dans notre cas par des facteurs), et rgb_blocks par ModuleList() qui contiendra tous les blocs RGB.
  • Pour faire apparaître de nouvelles couches (un composant d’origine de ProGAN), nous ajoutons la partie fade_in, à laquelle nous envoyons alpha, scaled et generated, et nous retournons [tanh(alpha∗generated+(1−alpha)∗upscale)], La raison pour laquelle nous utilisons tanh est que cela sera la sortie (l’image générée) et nous voulons que les pixels soient dans une plage comprise entre 1 et -1.
  • Dans la partie avancée, nous envoyons le bruit (Z_dim), la valeur alpha qui va s’estomper progressivement pendant l’entraînement (alpha est entre 0 et 1), et les étapes qui est le numéro de la résolution actuelle avec laquelle nous travaillons, nous passons x dans la carte pour obtenir le vecteur de bruit intermédiaire W, nous passons starting_constant à initial_noise1, l’appliquons et pour W initial_adain1, puis nous le passons dans initial_conv, et de nouveau nous ajoutons initial_noise2 pour lui avec leaky comme fonction d’activation, et l’appliquons pour lui et W initial_adain2. Ensuite, nous vérifions si steps = 0, si c’est le cas, alors tout ce que nous voulons faire, c’est le faire passer par le initial RGB et c’est tout, sinon, nous bouclons sur le nombre d’étapes, et à chaque boucle nous faisons de l’upscaling (upscaled) et nous passons par le bloc progressif qui correspond à cette résolution (out). À la fin, nous retournons fade_in qui prend alpha, final_out, et final_upscaled après l’avoir cartographié en RGB.
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 pour éviter une erreur d'index en raison de factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha devrait être scalaire dans [0, 1], et upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # Le nombre de canaux dans upscale restera le même, tandis que
        # out qui a traversé prog_blocks pourrait changer. Pour nous assurer
        # que nous pouvons les convertir tous deux en rgb, nous utilisons différents rgb_layers
        # (steps-1) et steps pour upscaled, out respectivement
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

Utils

Dans l’extrait de code ci-dessous, vous pouvez trouver la fonction generate_examples qui prend le générateur gen, le nombre de étapes pour identifier la résolution actuelle, et un nombre n=100. Le but de cette fonction est de générer n images factices et de les enregistrer comme résultat.

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

Dans l’extrait de code ci-dessous, vous pouvez trouver la fonction gradient_penalty pour la perte WGAN-GP.

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculer les scores du critique
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Prendre le gradient des scores par rapport aux images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Fonction d’entraînement

Pour la fonction d’entraînement, nous envoyons le critique (qui est le discriminant), gen (générateur), le chargeur, le jeu de données, l’étape, alpha, et l’optimiseur pour le générateur et pour le critique.

Nous commençons par boucler sur toutes les tailles de mini-lots que nous créons avec le DataLoader, et nous prenons seulement les images car nous n’avons pas besoin d’un étiquetage.

Ensuite, nous mettons en place l’entraînement pour le discriminant/Critique lorsque nous voulons maximiser E(critique(réel)) – E(critique(factice)). Cette équation signifie combien le critique peut distinguer entre des images réelles et factices.

Après cela, nous configurons l’entraînement du générateur lorsque nous voulons maximiser E(critic(fake)).

Finalement, nous mettons à jour la boucle et la valeur alpha pour fade_in et nous nous assurons qu’elle est comprise entre 0 et 1, puis nous la retournons.

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Mise à jour de alpha et assurance qu'il est inférieur à 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

Entraînement

Jetzt da wir alles avons, mettons-le ensemble pour entraîner notre StyleGAN.

Nous commençons par initialiser le générateur, le discriminateur/critique et les optimiseurs, puis passons le générateur et le critique en mode entraînement, puis bouclons sur PROGRESSIVE_EPOCHS, et dans chaque boucle, nous appelons la fonction train un nombre d’époques, puis nous générons quelques images factices et les sauvegardons, en utilisant la fonction generate_examples, et enfin, nous passons à la résolution d’image suivante.

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# initialiser les optimiseurs
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# commencer à l'étape correspondant à la taille d'image que nous avons définie dans la config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # commencer avec une alpha très faible
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # progresser vers la taille d'image suivante

Résultat

En espérant que vous serez en mesure de suivre toutes les étapes et d’obtenir une bonne compréhension de la manière d’implémenter StyleGAN de la bonne façon. Maintenant, voyons les résultats que nous obtenons après avoir entraîné ce modèle sur ce jeu de données avec une résolution de 128*x 128.

Conclusion

Dans cet article, nous réalisons une implémentation propre, simple et lisible de StyleGAN1 à partir de zéro en utilisant PyTorch. nous replicons le papier original le plus fidèlement possible, donc si vous lisez le papier, l’implémentation devrait être presque identique.

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch