Implementación de StyleGAN1 desde cero

Introducción

Este artículo trata sobre una de las mejores GAN hoy en día, StyleGAN del artículo Una Arquitectura Generadora Basada en Estilo para Redes Adversariales Generativas, haremos una implementación limpia, sencilla y legible de la misma utilizando PyTorch, y trataremos de replicar el artículo original lo más posible, así que si lees el artículo, la implementación debería ser prácticamente idéntica.

El conjunto de datos que utilizaremos en este blog es este conjunto de datos de Kaggle que contiene 16240 prendas de ropa superior para mujeres con resolución 256*192.

Prerrequisitos

Antes de sumergirte en trabajar con StyleGAN utilizando PyTorch, asegúrate de cumplir con los siguientes prerrequisitos:

  • Conocimientos Básicos de Aprendizaje Profundo
    Entendimiento de redes neuronales convolucionales (CNNs).
    Familiaridad con Redes Adversariales Generativas (GANs), incluyendo conceptos como el generador, el discriminador y la pérdida adversaria.

  • Requisitos de Hardware
    Una GPU potente (se recomienda NVIDIA) para un entrenamiento e inferencia más rápidos.
    Kit de herramientas CUDA instalado para la aceleración por GPU (cuda y cudnn).

  • Familiaridad con StyleGAN
    Es útil haber leído los papers originales de StyleGAN o StyleGAN2 para entender las mejoras en la arquitectura y conceptos clave.

Cargar todas las dependencias que necesitamos

Primero importaremos torch ya que utilizaremos PyTorch, y desde allí importamos nn. Eso nos ayudará a crear y entrenar las redes, y también nos permitirá importar optim, un paquete que implementa varios algoritmos de optimización (por ejemplo, sgd, adam,…). Desde torchvision importamos datasets y transforms para preparar los datos y aplicar algunas transformaciones.

Importaremos functional como F desde torch.nn para upsampler las imágenes utilizando interpolate, DataLoader desde torch.utils.data para crear tamaños de mini-lotes, save_image desde torchvision.utils para guardar algunos ejemplos falsos, y log2 desde math porque necesitamos la representación inversa de la potencia de 2 para implementar el tamaño de mini-lote adaptable dependiendo de la resolución de salida, NumPy para algebra lineal, os para interactuar con el sistema operativo, tqdm para mostrar barras de progreso, y finalmente matplotlib.pyplot para mostrar los resultados y compararlos con los reales.

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

Hipерпarámetros

  • Inicializamos el DATASET con la ruta de las imágenes reales.
  • Especificamos el inicio del entrenamiento en tamaño de imagen 8×8.
  • Inicializamos el dispositivo con Cuda si está disponible y CPU de lo contrario, y la tasa de aprendizaje en 0.001.
  • El tamaño del lote será diferente dependiendo de la resolución de las imágenes que queramos generar, por lo que inicializamos BATCH_SIZES con una lista de números, puedes cambiarlas dependiendo de tu VRAM.
  • Inicializamos image_size en 128 y CHANNELS_IMG en 3 porque generaremos imágenes RGB de 128 por 128.
  • En el documento original, inicializan Z_DIM, W_DIM, y IN_CHANNELS en 512, pero yo los inicializo en 256 para reducir el uso de VRAM y acelerar el entrenamiento. Podríamos quizás obtener mejores resultados si los duplicáramos.
  • Para StyleGAN podemos usar cualquiera de las funciones de pérdida de GAN que queramos, así que uso WGAN-GP del paper Improved Training of Wasserstein GANs. Esta pérdida contiene un parámetro llamado λ y es común establecer λ = 10.
  • Inicializar PROGRESSIVE_EPOCHS en 30 para cada tamaño de imagen.
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #Los autores parten de imágenes de 8x8 en lugar de 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

Obtener el cargador de datos

Ahora creemos una función get_loader para:

  • Aplicar algunas transformaciones a las imágenes (redimensionar las imágenes a la resolución que queremos, convertirlas a tensores, luego aplicar algunas aumentaciones, y finalmente normalizarlas para que todos los píxeles varíen de -1 a 1).
  • Identificar el tamaño actual del lote utilizando la lista BATCH_SIZES, y tomar como índice el número entero de la representación inversa de la potencia de 2 del tamaño de imagen/4. Y esto es en realidad cómo implementamos el tamaño de minibatch adaptativo dependiendo de la resolución de salida.
  • Preparar el conjunto de datos usando ImageFolder porque ya está estructurado de una manera adecuada.
  • Crear tamaños de mini-lotes utilizando DataLoader que tomen el conjunto de datos y el tamaño de lote con mezcla de datos.
  • Finalmente, devuelve el cargador y el conjunto de datos.
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

Implementación de modelos

Now let’s Implement the StyleGAN1 generator and discriminator(ProGAN and StyleGAN1 have the same discriminator architecture) with the key attributions from the paper. We will try to make the implementation compact but also keep it readable and understandable. Specifically, the key points:

  • Red de Mapeo de Ruido
  • Normalización Adaptativa de Instancia (AdaIN)
  • Crecimiento Progresivo

En este tutorial, solo generaremos imágenes con StyleGAN1, y no implementaremos la mezcla de estilos y la variación estocástica, pero no debería ser difícil hacerlo.

Definamos una variable con el nombre factors que contenga los números que se multiplicarán con IN_CHANNELS para tener el número de canales que queremos en cada resolución de imagen.

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

Red de Mapeo de Ruido

La red de mapeo de ruido toma Z y lo hace pasar por ocho capas completamente conectadas separadas por alguna activación. Y no olvides igualar la tasa de aprendizaje como hacen los autores en ProGAN (ProGAN y StyleGan escrito por los mismos investigadores).

Vamos a construir primero una clase con el nombre WSLinear (Lineal Escalado Ponderado) que se heredará de nn.Module.

  • En la parte init enviamos in_features y out_channels. Creamos una capa lineal, luego definimos una escala que será igual a la raíz cuadrada de 2 dividida por in_features, copiamos el sesgo de la capa actual en una variable porque no queremos que el sesgo de la capa lineal sea escalado, luego lo quitamos, finalmente, inicializamos la capa lineal.
  • En la parte forward, enviamos x y todo lo que vamos a hacer es multiplicar x con la escala y agregar el sesgo después de deformarlo.
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # inicializar capa lineal
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

Ahora creemos la clase MappingNetwork.

  • En la parte init enviamos z_dim y w_din, y definimos la red de mapeo que primero normaliza z_dim, seguido de ocho WSLInear y ReLU como funciones de activación.
  • En la parte forward, devolvemos la red de mapeo.

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

Normalización Adaptativa de Instancia (AdaIN)

Ahora creemos la clase AdaIN

  • En la parte init enviamos canales, w_dim, y inicializamos instance_norm que será la parte de normalización de instancia, y también inicializamos style_scale y style_bias que serán las partes adaptables con WSLinear que mapea la Red de Mapeo de Ruido W en canales.
  • En la pasada forward, enviamos x, aplicamos la normalización de instancia para él, y devolvemos style_sclate * x + style_bias.

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

Inyectar Ruido

Ahora creemos la clase InjectNoise para inyectar el ruido en el generador

  • En la parte init enviamos canales y inicializamos el peso desde una distribución normal aleatoria y usamos nn.Parameter para que estos pesos puedan ser optimizados
  • En la parte forward, enviamos una imagen x y la devolvemos con ruido aleatorio agregado
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

clases útiles

Los autores construyen StyleGAN sobre la implementación oficial de ProGAN de Karras et al, usan la misma arquitectura de discriminador, tamaño de minibatch adaptativo, hiperparámetros, etc. Así que hay muchas clases que se mantienen igual en la implementación de ProGAN.

En esta sección, crearemos las clases que no cambian de la arquitectura de ProGAN.

En el fragmento de código a continuación puedes encontrar la clase WSConv2d (capa de convolución ponderada y escalada) para Equalized Learning Rate para las capas de convolución.

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # inicializar capa de convolución
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

En el fragmento de código a continuación puedes encontrar la clase PixelNorm para normalizar Z antes de la Red de Mapeo de Ruido.

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

En el fragmento de código a continuación puedes encontrar la clase ConvBock que nos ayudará a crear el discriminador.

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

En el fragmento de código a continuación puedes encontrar la clase Discriminatowich que es la misma que en ProGAN.

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        aquí trabajamos hacia atrás desde los factores porque el discriminador
        debería estar espejado desde el generador. Así que el primer prog_block y
        la primera capa rgb que adjuntamos funcionará para el tamaño de entrada 1024x1024, luego 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        quizás el nombre confundente "initial_rgb" esta es solo la capa RGB para el tamaño de entrada 4x4
        hicimos esto para "espejar" el initial_rgb del generador
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  reducción de muestra utilizando avg pool

        este es el bloque para el tamaño de entrada 4x4
        self.final_block = nn.Sequential(
            +1 a in_channels porque concatenamos desde MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  usamos esto en lugar de una capa lineal
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        alpha debería ser un escalar dentro de [0, 1], y upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        tomamos la std para cada ejemplo (a través de todos los canales y píxeles) y lo repetimos
        para un solo canal y lo concatenamos con la imagen. De esta manera el discriminador
        obtendrá información sobre la variación en el lote/imagen
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        donde deberíamos comenzar en la lista de prog_blocks, tal vez un poco confundido pero
        el último es para el 4x4. Así que, por ejemplo, digamos que steps=1, entonces deberíamos comenzar
        en el segundo último porque el input_size será 8x8. Si steps==0 simplemente
        usamos el bloque final
        cur_step = len(self.prog_blocks) - steps

        convertir desde rgb como paso inicial, esto dependerá de
        el tamaño de la imagen (cada uno tendrá su propia capa rgb)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  es decir, la imagen es 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        porque prog_blocks podría cambiar los canales, para la reducción de escala usamos rgb_layer
        del tamaño anterior/más pequeño, lo cual en nuestro caso se correlaciona con +1 en el índice
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        el fade_in se realiza primero entre el escalado descendente y la entrada
        esto es lo contrario del generador
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

Generador

En la arquitectura del generador, tenemos algunos patrones que se repiten, así que primero creemos una clase para ellos y hagamos que nuestro código sea lo más limpio posible, llamemos a la clase GenBlock que se heredará de nn.Module.

  • En la parte de init enviamos in_channels, out_channels y w_dim, luego inicializamos conv1 con WSConv2d que mapea in_channels a out_channels, conv2 con WSConv2d que mapea out_channels a out_channels, leaky con Leaky ReLU con una pendiente de 0.2 como lo usan en el paper, inject_noise1, inject_noise2 con InjectNoise, adain1 y adain2 con AdaIN
  • En la parte de forward, enviamos x, y lo pasamos a conv1 luego a inject_noise1 con leaky, luego lo normalizamos con adain1, y de nuevo pasamos eso a conv2 luego a inject_noise2 con leaky y lo normalizamos con adain2. Y finalmente, devolvemos x.
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

Ahora tenemos todo lo que necesitamos para crear el generador.

  • en la parte init inicialicemos ‘starting_constant’ con un tensor de 4 x 4 (x 512 canales para el paper original, y 256 en nuestro caso) que se somete a una iteración del generador, mapeado por ‘MappingNetwork’, initial_adain1, initial_adain2 por AdaIN, initial_noise1, initial_noise2 por InjectNoise, initial_conv por una capa convolucional que mapea in_channels a sí misma, leaky por Leaky ReLU con una pendiente de 0.2, initial_rgb por WSConv2d que mapea in_channels a img_channels que es 3 para RGB, prog_blocks por ModuleList() que contendrá todos los bloques progresivos (indicamos los canales de entrada/salida de la convolución multiplicando in_channels que es 512 en el paper y 256 en nuestro caso con factores), y rgb_blocks por ModuleList() que contendrá todos los bloques RGB.
  • Para fundir nuevas capas (un componente original de ProGAN), añadimos la parte fade_in, a la que enviamos alpha, scaled y generated, y retornamos [tanh(alpha∗generated+(1−alpha)∗upscale)], La razón por la que usamos tanh es que será la salida (la imagen generada) y queremos que los píxeles estén en un rango entre 1 y -1.
  • En la parte adelante, enviamos el ruido (Z_dim), el valor de alpha que se va a desvanecer lentamente durante el entrenamiento (alpha está entre 0 y 1), y los pasos que es el número de la resolución actual con la que estamos trabajando, pasamos x al mapa para obtener el vector de ruido intermedio W, pasamos starting_constant a initial_noise1, aplicamos para él y para W initial_adain1, luego lo pasamos a initial_conv, y de nuevo añadimos initial_noise2 para él con leaky como función de activación, y aplicamos para él y W initial_adain2. Luego verificamos si steps = 0, si es así, entonces todo lo que queremos hacer es ejecutarlo a través del initial RGB y ya está, de lo contrario, recorremos el número de pasos, y en cada bucle escalamos (upscaled) y pasamos a través del bloque progresivo que corresponde a esa resolución(out). Al final, devolvemos fade_in que toma alpha, final_out, y final_upscaled después de mapearlo a RGB.
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 para evitar error de índice debido a factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha debería ser un escalar dentro de [0, 1], y upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # El número de canales en upscale se mantendrá igual, mientras que
        # out que ha pasado a través de prog_blocks podría cambiar. Para asegurar
        # que podemos convertir ambos a rgb usamos diferentes rgb_layers
        # (steps-1) y steps para upscaled, out respectivamente
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

Utils

En el siguiente fragmento de código puedes encontrar la función generate_examples que toma el generador gen, el número de pasos para identificar la resolución actual, y un número n=100. El objetivo de esta función es generar n imágenes falsas y guardarlas como resultado.

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

En el siguiente fragmento de código puedes encontrar la función gradient_penalty para la pérdida WGAN-GP.

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calcular las puntuaciones del crítico
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Tomar el gradiente de las puntuaciones con respecto a las imágenes
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Función de entrenamiento

Para la función de entrenamiento, enviamos el crítico (que es el discriminador), gen (el generador), el cargador, el conjunto de datos, el paso, alfa, y el optimizador para el generador y para el crítico.

Comenzamos ciclando sobre todos los tamaños de mini-lotes que creamos con el DataLoader, y solo tomamos las imágenes porque no necesitamos una etiqueta.

Luego configuramos el entrenamiento para el discriminador\Critico cuando queremos maximizar E(critic(real)) – E(critic(fake)). Esta ecuación significa cuánto puede distinguir el crítico entre imágenes reales y falsas.

Después de eso, configuramos el entrenamiento para el generador cuando queremos maximizar E(critic(fake)).

Finalmente, actualizamos el bucle y el valor de alpha para fade_in y nos aseguramos de que esté entre 0 y 1, y lo devolvemos.

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Actualizar alpha y asegurarse de que sea menor que 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

Entrenamiento

Ahora que tenemos todo, vamos a unirlo todo para entrenar nuestro StyleGAN.

Comenzamos inicializando el generador, el discriminador/critic y los optimizadores, luego convertimos el generador y el critic al modo de entrenamiento, luego iteramos sobre PROGRESSIVE_EPOCHS, y en cada bucle, llamamos a la función de entrenamiento el número de épocas veces, luego generamos algunas imágenes falsas y las guardamos como resultado, utilizando la función generate_examples, y finalmente, avanzamos a la siguiente resolución de imagen.

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# inicializar optimizadores
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# comenzar en el paso que corresponde al tamaño de img que configuramos en config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # comenzar con un alpha muy bajo
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # avanzar a la siguiente img tamaño

Resultado

Espero que puedas seguir todos los pasos y obtener una buena comprensión de cómo implementar StyleGAN de la manera correcta. Ahora vamos a revisar los resultados que obtenemos después de entrenar este modelo en este conjunto de datos con resolución 128*x 128.

Conclusión

En este artículo, realizamos una implementación limpia, simple y legible desde cero de StyleGAN1 utilizando PyTorch. replicamos el paper original lo más cerca posible, por lo que si lees el paper, la implementación debería ser prácticamente idéntica.

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch