Implementatie van StyleGAN1 van scratch

Inleiding

Dit artikel gaat over één van de beste GANs van tegenwoordig, StyleGAN van het paper Een Style-Based Generator Architecture for Generative Adversarial Networks, we zullen een schone, eenvoudige en leesbare implementatie ervan maken met PyTorch, en proberen de originele paper zoveel mogelijk na te bootsen, dus als je het paper leest, zou de implementatie vrijwel identiek moeten zijn.

De dataset die we in deze blog gebruiken is deze dataset van Kaggle die 16240 bovenkleding voor vrouwen bevat met een resolutie van 256*192.

Voorkennis

Voordat je aan de slag gaat met StyleGAN met PyTorch, zorg ervoor dat je de volgende voorkennis hebt:

  • Basiskennis van Diepe Lering
    Begrip van convolutionele neurale netwerken (CNNs).
    Vertrouwdheid met Generative Adversarial Networks (GANs), inclusief concepten zoals de generator, discriminator en adversarische verlies.

  • Hardware vereisten
    Een krachtige GPU (NVIDIA aanbevolen) voor snellere training en inferentie.
    CUDA toolkit geïnstalleerd voor GPU-versnelling (cuda en cudnn).

  • Vertrouwdheid met StyleGAN
    Het is nuttig om de originele StyleGAN of StyleGAN2 papers te hebben gelezen om de architectuurverbeteringen en belangrijkste concepten te begrijpen.

Laad alle afhankelijkheden die we nodig hebben

We gaan eerst torch importeren omdat we PyTorch gaan gebruiken, en daarna importeren we nn. Dat helpt ons om netwerken te maken en te trainen, en ook om optim te importeren, een pakket dat verschillende optimalisatiealgoritmen implementeert (bijv. sgd, adam, …). Van torchvision importeren we datasets en transforms om de data voor te bereiden en enkele transformaties toe te passen.

We importeren functional als F van torch.nn om de afbeeldingen op te schalen met interpolate, DataLoader van torch.utils.data om mini-batchgroottes te maken, save_image van torchvision.utils om enkele valse voorbeelden op te slaan, en log2 van math omdat we de inverse representatie van de macht van 2 nodig hebben om de aanpasbare minibatchgrootte afhankelijk van de uitvoerresolutie te implementeren, NumPy voor lineaire algebra, os voor interactie met het besturingssysteem, tqdm om voortgangsbalken te tonen, en uiteindelijk matplotlib.pyplot om de resultaten te tonen en te vergelijken met de echte.

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

Hyperparameters

  • Initialiseer de DATASET met het pad van de echte afbeeldingen.
  • Specificeer de start van de training bij een afbeeldingsgrootte van 8×8.
  • Initialiseer het apparaat met Cuda als het beschikbaar is en CPU anders, en de leergraad op 0.001.
  • De batchgrootte zal verschillen afhankelijk van de resolutie van de afbeeldingen die we willen genereren, dus we initialiseren BATCH_SIZES met een lijst van getallen, je kunt ze aanpassen afhankelijk van je VRAM.
  • Initialiseer image_size op 128 en CHANNELS_IMG op 3 omdat we 128×128 RGB-afbeeldingen gaan genereren.
  • In het oorspronkelijke paper initialiseren ze Z_DIM, W_DIM en IN_CHANNELS met 512, maar ik initialiseer ze met 256 in plaats daarvan voor minder VRAM-gebruik en versnelde training. We zouden zelfs betere resultaten kunnen krijgen als we ze verdubbelden.
  • Voor StyleGAN kunnen we elke GANs-verliesfunctie gebruiken die we willen, dus ik gebruik WGAN-GP uit het paper Improved Training of Wasserstein GANs. Deze verlies bevat een parameter genaamd λ en het is gebruikelijk om λ = 10 in te stellen.
  • Initialiseer PROGRESSIVE_EPOCHS met 30 voor elke afbeeldingsgrootte.
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #De auteurs beginnen met 8x8 afbeeldingen in plaats van 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

Verkrijg data loader

Laten we nu een functie get_loader maken om:

  • Enkele transformaties toe te passen op de afbeeldingen (afbeeldingen resizing naar de gewenste resolutie, ze om te zetten naar tensors, dan enkele augmentaties toe te passen en uiteindelijk te normaliseren zodat alle pixels variëren van -1 tot 1).
  • De huidige batchgrootte te identificeren met behulp van de lijst BATCH_SIZES, en als index te nemen het integere getal van de inverse vertegenwoordiging van de macht van 2 van image_size/4. En dit is eigenlijk hoe we de adaptieve minibatchgrootte implementeren afhankelijk van de uitvoerresolutie.
  • De dataset voor te bereiden door ImageFolder te gebruiken omdat het al op een nette manier is gestructureerd.
  • Maak mini-batchgroottes met DataLoader die de dataset en batchgrootte gebruiken met het shuffelen van de gegevens.
  • Daarnaast retourneren we de loader en dataset.
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

Implementatie van modellen

Laten we nu de StyleGAN1 generator en discriminator implementeren (ProGAN en StyleGAN1 hebben dezelfde discriminatorarchitectuur) met de belangrijkste attributen uit het paper. We zullen proberen de implementatie compact te maken, maar ook leesbaar en begrijpelijk te houden. Specifiek de belangrijkste punten:

  • Noise Mapping Network
  • Adaptive Instance Normalization (AdaIN)
  • Progressieve groei

In deze tutorial zullen we alleen afbeeldingen genereren met StyleGAN1, en geen style mixing en stochastische variatie implementeren, maar dat zou niet moeilijk moeten zijn.

Laten we een variabele definiëren met de naam factors die de getallen bevatten die vermenigvuldigd moeten worden met IN_CHANNELS om het aantal kanalen te verkrijgen dat we in elke afbeeldingsresolutie willen.

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

Noise Mapping Network

Het geluidsmappingnetwerk neemt Z en voert het door acht volledig verbonden lagen, gescheiden door een activatie, en vergeet niet het leerrendement te egaliseren zoals de auteurs dat doen in ProGAN (ProGAN en StyleGan, geschreven door dezelfde onderzoekers).

Laten we eerst een klasse bouwen met de naam WSLinear (gewogen geschaalde Lineair) die wordt afgeleid van nn.Module.

  • In het init-gedeelte sturen we in_features en out_channels door. We maken een lineaire laag, definiëren vervolgens een schaal die gelijk is aan de vierkantswortel van 2 gedeeld door in_features, kopiëren de bias van de huidige kolomlaag naar een variabele omdat we de bias van de lineaire laag niet willen schalen, verwijderen we deze vervolgens en initialiseren we de lineaire laag.
  • In het forward-gedeelte sturen we x door en alles wat we gaan doen is x vermenigvuldigen met scale en de bias toevoegen nadat deze is herschikt.
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # initialiseren van de lineaire laag
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

Laten we nu de MappingNetwork-klasse maken.

  • In het init-gedeelte sturen we z_dim en w_din door, en we definiëren het netwerk mapping dat eerst z_dim normaliseert, gevolgd door acht WSLInear en ReLU als activatiefuncties.
  • In het forward-gedeelte retourneren we het netwerk mapping.

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

Adaptieve Instance Normalisatie (AdaIN)

Laten we nu de AdaIN-klasse maken.

  • In de init-deel sturen we kanalen, w_dim, en we initialiseren instance_norm die het deel voor instance normalisatie zal zijn, en we initialiseren style_scale en style_bias die de adaptieve delen zullen zijn met WSLinear die de Noise Mapping Network W naar kanalen mapt.
  • In de forward-doorvoer sturen we x, passen we instance normalization toe ervoor, en retourneren we style_sclate * x + style_bias.

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

Inject Noise

Laten we nu de klasse InjectNoise maken om ruis in de generator in te voegen

  • In het init-deel stuurden we kanalen en we initialiseren het gewicht uit een willekeurige normale verdeling en we gebruiken nn.Parameter zodat deze gewichten geoptimaliseerd kunnen worden
  • In de forward-deel sturen we een afbeelding x en we retourneren het met toegevoegde willekeurige ruis
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

helpful classes

De auteurs bouwen StyleGAN op de officiële implementatie van ProGAN door Karras et al, ze gebruiken dezelfde discriminatorarchitectuur, adaptieve minibatch-grootte, hyperparameters, etc. Dus er zijn veel klassen die hetzelfde blijven van de ProGAN-implementatie.

In dit gedeelte zullen we de klassen maken die niet veranderen van de ProGAN-architectuur.

In de onderstaande codefragment kunt u de klasse WSConv2d (gegewogen geschaalde convolutielaag) vinden voor Equalized Learning Rate voor de conv lagen.

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialiseer conv laag
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

In de onderstaande codefragment kunt u de klasse PixelNorm vinden om Z te normaliseren voor de Noise Mapping Network.

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

In de onderstaande codefragment kunt u de klasse ConvBock vinden die ons zal helpen de discriminator te maken.

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

In de onderstaande codefragment kunt u de klasse Discriminatowich vinden die hetzelfde is als in ProGAN.

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        hier werken we achteruit vanaf factoren omdat de diskriminant
        moet worden gemirrored van de generator. Dus de eerste prog_block en
        rgb laag die we toevoegen zal werken voor een invoergrootte van 1024x1024, dan 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        misschien verwarrende naam "initial_rgb" dit is gewoon de RGB laag voor een invoergrootte van 4x4
        deze deed om de generator initial_rgb te "mirror"
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  down sampling met gebruik van avg pool

        dit is het blok voor een invoergrootte van 4x4
        self.final_block = nn.Sequential(
            +1 aan in_channels omdat we samenvoegen van MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  we gebruiken dit in plaats van een lineaire laag
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        alpha moet een schaal zijn binnen [0, 1], en upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        we nemen de std voor elk voorbeeld (over alle kanalen en pixels) dan herhalen we het
        voor een enkele kanal en voegen het samen met de afbeelding. Op deze manier zal de diskriminant
        informatie krijgen over de variatie in de batch/afbeelding
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        waar we moeten beginnen in de lijst van prog_blocks, misschien een beetje verwarrend maar
        de laatste is voor de 4x4. Dus bijvoorbeeld als steps=1, dan moeten we beginnen
        bij de tweede van het laatst omdat input_size zal zijn 8x8. Als steps==0 gebruiken we gewoon
        het laatste blok
        cur_step = len(self.prog_blocks) - steps

        converteren van rgb als initiële stap, dit zal afhankelijk zijn van
        de afbeeldingsgrootte (elke zal zijn eigen rgb laag hebben)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  bijvoorbeeld, de afbeelding is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        omdat prog_blocks de kanalen kan wijzigen, gebruiken we voor down scale de rgb_layer
        van de vorige/kleinere grootte die in ons geval correleert met +1 in de indexering
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        de fade_in wordt eerst gedaan tussen de downscaled en de input
        dit is het tegenovergestelde van de generator
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

Generator

In de generatorarchitectuur hebben we enkele patronen die zich herhalen, dus laten we eerst een klasse maken voor het zo proper mogelijke code, laten we de klasse GenBlock noemen die wordt afgeleid van nn.Module.

  • In het init-gedeelte sturen we in_channels, out_channels en w_dim door, dan initialiseren we conv1 met WSConv2d die in_channels naar out_channels mapt, conv2 met WSConv2d die out_channels naar out_channels mapt, leaky met Leaky ReLU met een helling van 0.2 zoals ze dat in het paper gebruiken, inject_noise1 en inject_noise2 met InjectNoise, adain1 en adain2 met AdaIN.
  • In het forward-gedeelte sturen we x door en we passen het toe op conv1 dan op inject_noise1 met leaky, dan normaliseren we het met adain1, en weer passen we dat toe op conv2 dan op inject_noise2 met leaky en we normaliseren het met adain2. En tenslotte geven we x terug.
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

Nu hebben we alles wat we nodig hebben om de generator te maken.

  • in het init gedeelte initialiseren we ‘starting_constant’ door een constante 4 x 4 (x 512 kanalen voor het originele paper, en 256 in ons geval) tensor die door een iteratie van de generator wordt gestuurd, gemapt door ‘MappingNetwork’, initial_adain1, initial_adain2 door AdaIN, initial_noise1, initial_noise2 door InjectNoise, initial_conv door een conv laag die in_channels naar zichzelf mapt, leaky door Leaky ReLU met een helling van 0.2, initial_rgb door WSConv2d die in_channels naar img_channels mapt wat 3 is voor RGB, prog_blocks door ModuleList() die alle progressieve blokken zal bevatten (we geven convolutie invoer/uitvoer kanalen aan door in_channels te vermenigvuldigen wat 512 is in het paper en 256 in ons geval met factoren), en rgb_blocks door ModuleList() die alle RGB blokken zal bevatten.
  • Om nieuwe lagen in te faden (een oorspronkelijke component van ProGAN), voegen we het fade_in gedeelte toe, waarin we alpha, scaled, en gegenereerd sturen, en we retourneren [tanh(alpha∗gegenereerd+(1−alpha)∗opschaal)], De reden dat we tanh gebruiken is dat dit de uitvoer (het gegenereerde beeld) zal zijn en we willen dat de pixels in het bereik tussen 1 en -1 liggen.
  • In de voorwaartse deel, sturen we het geluid (Z_dim), de alpha-waarde die langzaam fade-in zal gaan tijdens de training (alpha ligt tussen 0 en 1), en stappen die het nummer van de huidige resolutie is waar we mee werken, we voeren x door de kaart om de tussenliggende ruisvector W te krijgen, we voeren starting_constant door initial_noise1, passen het toe en voor W initial_adain1, vervolgens voeren we het door initial_conv, en weer voegen we initial_noise2 toe voor het met leaky als activatiefunctie, en passen het toe en W initial_adain2. Dan controleren we of stappen = 0, als dat zo is, dan willen we alleen maar door de initial RGB laten lopen en dat is het, anders, we lopen over het aantal stappen, en in elke lus schalen we op (upscaled) en we lopen door het progressieve blok dat overeenkomt met die resolutie (out). Aan het einde, retourneren we fade_in dat alpha, final_out, en final_upscaled neemt na het mappen naar RGB.
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 om indexfout te voorkomen vanwege factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha moet een schaal zijn binnen [0, 1], en upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # Het aantal kanalen in upscale blijft hetzelfde, terwijl
        # out die door prog_blocks is gegaan mogelijk verandert. Om te
        # ervoor te zorgen dat we beide naar rgb kunnen converteren, gebruiken we
        # verschillende rgb_lagen voor (stappen-1) en stappen voor upscaled, out respectievelijk.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

Utils

In de onderstaande code snippet kun je de functie generate_examples vinden die de generator gen, het aantal stappen om de huidige resolutie te identificeren, en een getal n=100 accepteert. Het doel van deze functie is om n nepafbeeldingen te genereren en deze op te slaan als resultaat.

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

In de onderstaande code snippet kun je de gradient_penalty functie voor WGAN-GP verlies vinden.

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Bereken critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Neem de afgeleide van de scores met betrekking tot de afbeeldingen
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Train functie

Voor de train functie sturen we de critic (wat de discriminator is), gen(generator), loader, dataset, stap, alpha, en optimizer voor de generator en voor de critic.

We beginnen met het doorlopen van alle mini-batch groottes die we maken met de DataLoader, en we nemen alleen de afbeeldingen omdat we geen label nodig hebben.

Daarna stellen we de training voor de discriminator\Critic in when we want to maximize E(critic(real)) – E(critic(fake)). Deze vergelijking betekent hoeveel de critic kan onderscheiden tussen echte en nepafbeeldingen.

Daarna stellen we de training in voor de generator wanneer we E(critic(fake)).

uiteindelijk de lus bijwerken en de alpha-waarde voor fade_in bijstellen en ervoor zorgen dat deze tussen 0 en 1 ligt, en deze retourneren.

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Alpha bijwerken en ervoor zorgen dat het minder dan 1 is
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

Training

Nu we alles hebben, laten we het samenstellen om onze StyleGAN te trainen.

We beginnen met het initialiseren van de generator, de discriminator/critic en de optimizers, zetten de generator en de critic in train-modus, en lopen vervolgens over PROGRESSIVE_EPOCHS, en in elke lus roepen we het trainingsfunctie het aantal keren dat het aantal epochs is, vervolgens genereren we wat nepafbeeldingen en slaan deze op als resultaat met behulp van de generate_examples-functie, en tenslotte gaan we naar de volgende afbeeldingsresolutie.

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# optimizers initialiseren
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# beginnen bij de stap die overeenkomt met de img-grootte die we hebben ingesteld in de config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # beginnen met een zeer lage alpha
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # doorgaan naar de volgende img-grootte

Resultaat

Hopelijk kun je alle stappen volgen en een goed begrip krijgen van hoe je StyleGAN op de juiste manier kunt implementeren. Laten we nu de resultaten bekijken die we verkrijgen na het trainen van dit model in deze dataset met een resolutie van 128*x 128.

Conclusie

In dit artikel maken we een schone, eenvoudige en leesbare implementatie van StyleGAN1 vanaf nul met behulp van PyTorch. We repliceren het originele paper zo dicht mogelijk, dus als je het paper leest, zou de implementatie vrijwel identiek moeten zijn.

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch