Implementazione StileGAN1 da zero

Introduzione

Questo articolo riguarda uno dei migliori GAN oggi disponibili, StyleGAN, tratta dal paper Una Architettura Generatrice Basata sullo Stile per Reti Adversariali Generative, realizzeremo un’implementazione pulita, semplice e leggibile utilizzando PyTorch, cercando di replicare il più possibile il paper originale, quindi se avete letto il paper, l’implementazione dovrebbe essere praticamente identica.

Il dataset che utilizzeremo in questo blog è questo dataset da Kaggle che contiene 16240 capi d’abbigliamento superiori per donne con risoluzione 256*192.

Prerequisiti

Prima di immergervi nel lavoro con StyleGAN utilizzando PyTorch, assicuratevi di avere i seguenti prerequisiti:

  • Conoscenze di Base sull’Apprendimento Profondo
    Comprensione delle reti neurali convoluzionali (CNN).
    Familiarità con le Reti Adversariali Generative (GAN), inclusi concetti come il generatore, il discriminatore e la perdita avversaria.

  • Requisiti Hardware
    Una GPU potente (raccomandata NVIDIA) per una formazione e un’inferenza più rapide.
    CUDA toolkit installato per l’accelerazione GPU (cuda e cudnn).

  • Familiarità con StyleGAN
    È utile aver letto i documenti originali di StyleGAN o StyleGAN2 per comprendere i miglioramenti dell’architettura e i concetti chiave.

Caricare tutte le dipendenze di cui abbiamo bisogno

Prima importeremo torch dato che utilizzeremo PyTorch, e da lì importeremo nn. Questo ci aiuterà a creare e addestrare le reti, e ci permetterà anche di importare optim, un pacchetto che implements vari algoritmi di ottimizzazione (ad esempio sgd, adam,…). Da torchvision importeremo datasets e transforms per preparare i dati e applicare alcune trasformazioni.

Importeremo functional come F da torch.nn per upsample le immagini utilizzando interpolate, DataLoader da torch.utils.data per creare dimensioni di mini-lotti, save_image da torchvision.utils per salvare alcuni campioni falsi, e log2 da math perché abbiamo bisogno della rappresentazione inversa della potenza di 2 per implementare la dimensione di mini-lotto adattiva a seconda della risoluzione di output, NumPy per algebra lineare, os per l’interazione con il sistema operativo, tqdm per mostrare le barre di avanzamento, e infine matplotlib.pyplot per mostrare i risultati e confrontarli con quelli reali.

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

Iperparametri

  • Inizializziamo il DATASET con il percorso delle immagini reali.
  • Specificiamo l’inizio del train alle dimensioni di immagine 8×8.
  • Inizializziamo il device con Cuda se disponibile e CPU altrimenti, e il tasso di apprendimento a 0.001.
  • La dimensione del batch sarà diversa a seconda della risoluzione delle immagini che vogliamo generare, quindi inizializziamo BATCH_SIZES con una lista di numeri, puoi cambiarli a seconda della tua VRAM.
  • Inizializziamo image_size a 128 e CHANNELS_IMG a 3 perché genereremo immagini RGB di 128 per 128.
  • Nel paper originale, inizializzano Z_DIM, W_DIM e IN_CHANNELS a 512, ma io li inizializzo a 256 invece per un minor utilizzo della VRAM e per accelerare l’addestramento. Potremmo forse ottenere risultati migliori se li raddoppiassimo.
  • Per StyleGAN possiamo utilizzare qualsiasi funzione di perdita GAN che vogliamo, quindi utilizzo WGAN-GP dal paper Improved Training of Wasserstein GANs. Questa perdita contiene un parametro chiamato λ e è comune impostare λ = 10.
  • Inizializzare PROGRESSIVE_EPOCHS a 30 per ogni dimensione di immagine.
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 # Gli autori iniziano da immagini 8x8 invece di 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

Ottieni il data loader

Ora creiamo una funzione get_loader per:

  • Applicare alcune trasformazioni alle immagini (ridimensionare le immagini alla risoluzione che vogliamo, convertire in tensori, poi applicare alcune augmentazioni e infine normalizzare i pixel per essere tutti compresi tra -1 e 1).
  • Identificare la dimensione corrente del batch utilizzando la lista BATCH_SIZES, e prendere come indice il numero intero della rappresentazione inversa della potenza di 2 della dimensione dell’immagine/4. Ed è esattamente così che implementiamo la dimensione del minibatch adattiva in base alla risoluzione di output.
  • Preparare il dataset utilizzando ImageFolder perché è già strutturato in modo gradevole.
  • Creare dimensioni di mini-lotti utilizzando DataLoader che prendono il dataset e la dimensione del lotto con la mescolatura dei dati.
  • Infine, restituire il carico e il dataset.
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

Implementazione dei modelli

Adesso implementiamo il generatore e il discriminatore StyleGAN1 (ProGAN e StyleGAN1 hanno la stessa architettura del discriminatore) con le attribuzioni chiave del paper. Cercheremo di rendere l’implementazione compatta ma anche leggibile e comprensibile. In particolare, i punti chiave:

  • Rete di Mappatura del Rumore
  • Normalizzazione Adattiva dell’istanza (AdaIN)
  • Crescita progressiva

In questo tutorial, genereremo solo immagini con StyleGAN1, senza implementare il mixing dello stile e la variazione stocastica, ma non dovrebbe essere difficile farlo.

Definiamo una variabile con il nome factors che contiene i numeri che moltiplicheranno IN_CHANNELS per ottenere il numero di canali che vogliamo in ogni risoluzione di immagine.

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

Rete di Mappatura del Rumore

La rete di mappatura del rumore prende Z e lo passa attraverso otto strati completamente connessi separati dasome attivazione. E non dimenticare di equalizzare il tasso di apprendimento come fanno gli autori in ProGAN (ProGAN e StyleGan scritti dagli stessi ricercatori).

Lasciamo innanzitutto costruire una classe con il nome WSLinear (weighted scaled Linear) che verrà ereditata da nn.Module.

  • Nella parte init inviamo in_features e out_channels. Creiamo un livello lineare, poi definiamo una scala che sarà uguale alla radice quadrata di 2 diviso in_features, copiamo il bias della colonna corrente in una variabile perché non vogliamo che il bias del livello lineare sia scalato, poi lo rimuoviamo, infine inizializziamo il livello lineare.
  • Nella parte forward, inviamo x e tutto ciò che faremo è moltiplicare x per scale e aggiungere il bias dopo averlo ridefinito.
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # inizializza il livello lineare
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

Ora creiamo la classe MappingNetwork.

  • Nella parte init inviamo z_dim e w_din, e definiamo la rete di mappatura che prima normalizza z_dim, seguita da otto WSLInear e ReLU come funzioni di attivazione.
  • Nella parte forward, restituiamo la mappatura della rete.

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

Adaptive Instance Normalization (AdaIN)

Ora creiamo la classe AdaIN

  • Nella parte init inviamo canali, w_dim e inizializziamo instance_norm che sarà la parte di normalizzazione dell’istanza, e inizializziamo style_scale e style_bias che saranno le parti adattive con WSLinear che mappa il Noise Mapping Network W nei canali.
  • Nella parte forward inviamo x, applichiamo la normalizzazione dell’istanza e restituiamo style_sclate * x + style_bias.

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

Inject Noise

Ora creiamo la classe InjectNoise per iniettare il rumore nel generatore

  • Nella parte init abbiamo inviato canali e inizializziamo il peso da una distribuzione normale casuale e utilizziamo nn.Parameter in modo che questi pesi possano essere ottimizzati
  • Nella parte forward inviamo un’immagine x e la restituiamo con rumore casuale aggiunto
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

classi utili

Gli autori costruiscono StyleGAN sull’implementazione ufficiale di ProGAN di Karras et al, utilizzano la stessa architettura del discriminatore, dimensione adattiva del minibatch, iperparametri, ecc. Quindi ci sono molte classi che rimangono le stesse dall’implementazione di ProGAN.

In questa sezione, creeremo le classi che non cambiano dall’architettura ProGAN.

Nel seguente frammento di codice puoi trovare la classe WSConv2d (strato convoluzionale con pesi scalati) per Equalized Learning Rate per gli strati convoluzionali.

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # inizializza strato convoluzionale
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

Nel seguente frammento di codice puoi trovare la classe PixelNorm per normalizzare Z prima della Rete di Mappatura del Rumore.

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

Nel seguente frammento di codice puoi trovare la classe ConvBlock che ci aiuterà a creare il discriminante.

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

Nel seguente frammento di codice puoi trovare la classe Discriminatowhich è la stessa di quella in ProGAN.

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        qui lavoriamo al contrario partendo dai fattori perché il discriminante
        dovrebbe essere specchiato dal generatore. Quindi il primo prog_block e
        il primo strato rgb che aggiungiamo funzioneranno per la dimensione di input 1024x1024, poi 512->256->ecc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        forse il nome "initial_rgb" è confusionario, questa è solo la layer RGB per la dimensione di input 4x4
        ho fatto così per "specchiare" l'initial_rgb del generatore
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  riduzione del campione utilizzando la media pool

        questo è il blocco per la dimensione di input 4x4
        self.final_block = nn.Sequential(
            +1 ai canali in ingresso perché concateniamo dalla std del MiniBatch
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  utilizziamo questo invece di uno strato lineare
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        alpha dovrebbe essere uno scalare tra [0, 1], e upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        prendiamo la std per ogni esempio (attraverso tutti i canali e i pixel) poi la ripetiamo
        per un singolo canale e la concateniamo con l'immagine. In questo modo il discriminante
        ottiene informazioni sulla variazione nel batch/immagine
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        da dove dovremmo iniziare nell'elenco dei prog_blocks, forse un po' confusionario ma
        l'ultimo è per il 4x4. Quindi, ad esempio, se steps=1, dovremmo iniziare
        dal penultimo perché la dimensione di input sarà 8x8. Se steps==0 utilizziamo semplicemente
        il blocco finale
        cur_step = len(self.prog_blocks) - steps

        convertiamo da rgb come primo passo, questo dipenderà
        dalla dimensione dell'immagine (ognuna avrà la sua layer rgb)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  ad esempio, l'immagine è 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        poiché i prog_blocks potrebbero cambiare i canali, per la riduzione utilizziamo rgb_layer
        dal尺寸 precedente/più piccolo che nel nostro caso corrisponde a +1 nell'indicizzazione
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        il fade_in viene eseguito prima tra il ridimensionato e l'input
        questo è l'opposto del generatore
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

Generatore

Nell’architettura del generatore, abbiamo alcuni schemi che si ripetono, quindi creiamo innanzitutto una classe per rendere il nostro codice il più pulito possibile; chiamiamo la classe GenBlock che sarà ereditata da nn.Module.

  • Nella parte init inviamo in_channels, out_channels e w_dim, poi inizializziamo conv1 da WSConv2d che mappa in_channels in out_channels, conv2 da WSConv2d che mappa out_channels in out_channels, leaky da Leaky ReLU con una pendenza di 0.2 come usano nel paper, e poi la classe GenBlock.2 come nel documento, inject_noise1, inject_noise2 da InjectNoise, adain1 e adain2 da AdaIN
  • Nella parte forward, inviamo x, lo passiamo a conv1 e poi a inject_noise1 con leaky, quindi lo normalizziamo con adain1, e di nuovo lo passiamo a conv2 e poi a inject_noise2 con leaky e lo normalizziamo con adain2. Infine, restituiamo x.
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

Ora abbiamo tutto ciò che ci serve per creare il generatore.

  • nella parte init iniziamo ‘starting_constant’ con un tensore costante 4 x 4 (x 512 canali per il paper originale, e 256 nel nostro caso) che viene mandato attraverso un’iterazione del generatore, mappato da ‘MappingNetwork’, initial_adain1, initial_adain2 da AdaIN, initial_noise1, initial_noise2 da InjectNoise, initial_conv da un livello convoluzionale che mappa in_channels a se stesso, leaky da Leaky ReLU con una pendenza di 0.2, initial_rgb da WSConv2d che mappa in_channels a img_channels che è 3 per RGB, prog_blocks da ModuleList() che conterrà tutti i blocchi progressivi (indicando i canali di input/output della convoluzione moltiplicando in_channels che è 512 nel paper e 256 nel nostro caso per i fattori), e rgb_blocks da ModuleList() che conterrà tutti i blocchi RGB.
  • Per sfumare nuovi livelli (un componente originale di ProGAN), aggiungiamo la parte fade_in, alla quale mandiamo alpha, scaled e generated, e restituiamo [tanh(alpha∗generated+(1−alpha)∗upscale)], La ragione per cui usiamo tanh è che sarà l’output (l’immagine generata) e vogliamo che i pixel siano nel range tra 1 e -1.
  • Nella parte in avanti, inviamo il rumore (Z_dim), il valore alpha che si dissolverà gradualmente durante l’addestramento (alpha è compreso tra 0 e 1), e steps che è il numero della risoluzione corrente con cui stiamo lavorando, passiamo x nella mappa per ottenere il vettore di rumore intermedio W, passiamo starting_constant a initial_noise1, applichiamo both e per W initial_adain1, poi lo passiamo in initial_conv, e di nuovo aggiungiamo initial_noise2 per esso con leaky come funzione di attivazione, e applichiamo both e W initial_adain2. Poi controlliamo se steps = 0, se lo è, tutto quello che vogliamo fare è farlo scorrere attraverso l’initial RGB e abbiamo finito, altrimenti, iteriamo sul numero di passaggi, e in ogni ciclo eseguiamo l’upscaling(upscaled) e lo facciamo scorrere attraverso il blocco progressivo che corrisponde a quella risoluzione(out). Alla fine, restituiamo fade_in che prende alpha, final_out, e final_upscaled dopo averlo mappato a RGB.
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 per prevenire l'errore di indice a causa di factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha dovrebbe essere scalare entro [0, 1], e upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # Il numero di canali in upscale rimarrà lo stesso, mentre
        # out che è passato attraverso prog_blocks potrebbe cambiare. Per assicurare
        # possiamo convertire entrambi in rgb utilizziamo layer rgb diversi
        # (steps-1) e steps per upscaled, out rispettivamente
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

Utils

Nella seguente porzione di codice puoi trovare la funzione generate_examples che accetta il generatore gen, il numero di passaggi per identificare la risoluzione corrente e un numero n=100. L’obiettivo di questa funzione è generare n immagini fake e salvarle come risultato.

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

Nella seguente porzione di codice puoi trovare la funzione gradient_penalty per la perdita WGAN-GP.

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calcola i punteggi del critico
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Prendi il gradiente dei punteggi rispetto alle immagini
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Funzione di addestramento

Per la funzione di addestramento, inviamo il critico (che è il discriminatore), gen (il generatore), il loader, il dataset, il passo, alpha e l’ottimizzatore per il generatore e per il critico.

Iniziamo ciclando su tutte le dimensioni dei mini-lotti che creiamo con il DataLoader, e prendiamo solo le immagini perché non ci serve un’etichetta.

Poi impostiamo l’addestramento per il discriminatore\Critico quando vogliamo massimizzare E(critico(reale)) – E(critico(falso)). Questa equazione significa quanto il critico riesce a distinguere tra immagini reali e fake.

Dopo đó, impostiamo l’addestramento per il generatore quando vogliamo massimizzare E(critic(fake)).

Infine, aggiorniamo il ciclo e il valore alpha per fade_in e ci assicuriamo che sia compreso tra 0 e 1, e lo restituiamo.

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Aggiorna alpha e assicurati che sia inferiore a 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

Addestramento

Ora che abbiamo tutto, mettiamo tutto insieme per addestrare il nostro StyleGAN.

Iniziamo initializing il generatore, il discriminatore/critico e gli ottimizzatori, poi convertiamo il generatore e il critico in modalità addestramento, quindi iteriamo su PROGRESSIVE_EPOCHS e in ogni ciclo chiamiamo la funzione di addestramento un numero di volte uguale alle epoche, poi generiamo alcune immagini false e le salviamo, come risultato, utilizzando la funzione generate_examples, e infine passiamo alla risoluzione dell’immagine successiva.

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# inizializza ottimizzatori
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# inizia al passo che corrisponde alla dimensione dell'immagine che abbiamo impostato in config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # inizia con un alpha molto basso
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # progressione alla dimensione dell'immagine successiva

Risultato

Spero che tu possa seguire tutti i passaggi e ottenere una buona comprensione di come implementare StyleGAN nel modo corretto. Ora vediamo i risultati che otteniamo dopo aver addestrato questo modello su questo dataset con risoluzione 128*x 128.

Conclusione

In questo articolo, abbiamo realizzato un’implementazione pulita, semplice e leggibile da zero di StyleGAN1 utilizzando PyTorch. Abbiamo replicato il paper originale il più possibile, quindi se leggi il paper l’implementazione dovrebbe essere praticamente identica.

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch