Implementação do StyleGAN1 do Zero

Introdução

Este artigo falará sobre um dos melhores GANs atuais, StyleGAN, a partir do artigo <diy5 Uma Arquitetura de Gerador Baseada em Estilo para Redes Adversárias Gerativas, faremos uma implementação limpa, simples e legível usando PyTorch e tentaremos replicar o artigo original o mais próximo possível, então, se você leu o artigo, a implementação deve ser praticamente idêntica.

O conjunto de dados que usaremos neste blog é este conjunto de dados do Kaggle, que contém 16.240 peças de roupas de cima femininas com resolução 256*192.

Pré-requisitos

Antes de mergulhar no trabalho com StyleGAN usando PyTorch, certifique-se de que você tem os seguintes pré-requisitos:

  • Conhecimento Básico de Aprendizado profundo
    Entendimento de redes neurais convolucionais (CNNs).
    Familiaridade com Redes Adversárias Gerativas (GANs), incluindo conceitos como gerador, discriminador e perda adversária.

  • Requisitos de hardware
    Uma GPU poderosa (recomendada pela NVIDIA) para treinamento e inferência mais rápidos.
    Kit de ferramentas CUDA instalado para aceleração de GPU (cuda e cudnn).

  • Familiaridade com o StyleGAN
    É útil ter lido os documentos originais StyleGAN ou StyleGAN2 para compreender as melhorias na arquitetura e os conceitos chave.

Carregar todas as dependências de que precisamos

Vamos primeiramente importar torch, pois utilizaremos PyTorch, e a partir dele importamos nn. Isso nos ajudará a criar e treinar as redes, e também nos permitirá importar optim, um pacote que implementa vários algoritmos de otimização (por exemplo, sgd, adam,…). Do torchvision importamos datasets e transforms para preparar os dados e aplicar algumas transformações.

Vamos importar functional como F de torch.nn para reamostrar as imagens usando interpolate, DataLoader de torch.utils.data para criar tamanhos de mini-lotes, save_image de torchvision.utils para salvar alguns exemplos falsos, e log2 de math, pois precisamos da representação inversa da potência de 2 para implementar o tamanho de mini-lote adaptativo dependendo da resolução de saída, NumPy para álgebra linear, os para interação com o sistema operacional, tqdm para mostrar barras de progresso, e finalmente matplotlib.pyplot para mostrar os resultados e compará-los com os reais.

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

Hiperparâmetros

  • Inicializar o DATASET pelo caminho das imagens reais.
  • Especificar o início do treinamento no tamanho de imagem 8×8.
  • Inicializar o dispositivo por Cuda, se disponível, ou CPU, caso contrário, e a taxa de aprendizado por 0.001.
  • O tamanho do lote será diferente dependendo da resolução das imagens que queremos gerar, então inicializamos BATCH_SIZES por uma lista de números, você pode mudá-los dependendo da sua VRAM.
  • Inicializar image_size por 128 e CHANNELS_IMG por 3, pois vamos gerar imagens RGB de 128 por 128.
  • No artigo original, eles inicializam Z_DIM, W_DIM e IN_CHANNELS com 512, mas eu os inicializo com 256 para menos uso de VRAM e aceleração do treinamento. Talvez até possamos obter melhores resultados se os dobrássemos.
  • Para o StyleGAN podemos usar qualquer função de perda de GANs que quisermos, então uso a WGAN-GP do artigo Improved Training of Wasserstein GANs. Esta perda contém um parâmetro chamado λ e é comum definir λ = 10.
  • Inicialize PROGRESSIVE_EPOCHS com 30 para cada tamanho de imagem.
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #Os autores começam com imagens de 8x8 em vez de 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

Obter carregador de dados

Agora vamos criar uma função get_loader para:

  • Aplicar algumas transformações nas imagens (redimensionar as imagens para a resolução que queremos, convertê-las em tensores, então aplicar algumas augmentações e, finalmente, normalizá-las para que todos os pixels variem de -1 a 1).
  • Identificar o tamanho atual do lote usando a lista BATCH_SIZES, e pegar como índice o número inteiro da representação inversa da potência de 2 do image_size/4. E é assim que implementamos o tamanho de minibatch adaptativo dependendo da resolução de saída.
  • Preparar o conjunto de dados usando ImageFolder, pois já está estruturado de uma maneira agradável.
  • Criar tamanhos de mini-lote usando DataLoader que tomam o conjunto de dados e o tamanho do lote com mesclagem dos dados.
  • Por fim, retornar o carregador e o conjunto de dados.
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

Implementação de Modelos

Agora vamos implementar o gerador e discriminador StyleGAN1 (ProGAN e StyleGAN1 têm a mesma arquitetura de discriminador) com as atribuições-chave do artigo. Tentaremos fazer a implementação compacta, mas também manter ela legível e compreensível. Especificamente, os pontos-chave:

  • Rede de Mapeamento de Ruido
  • Normalização Adaptativa de Instância (AdaIN)
  • Crescimento Progressivo

Neste tutorial,我们将 apenas gerar imagens com StyleGAN1, e não implementaremos mesclagem de estilos e variação estocástica, mas não deve ser difícil fazer isso.

Vamos definir uma variável com o nome factors que contém os números que serão multiplicados por IN_CHANNELS para ter o número de canais que queremos em cada resolução de imagem.

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

Rede de Mapeamento de Ruido

A rede de mapeamento de ruído leva Z e o coloca por oito camadas completamente conectadas separadas por alguma ativação. E não esqueça de equalizar a taxa de aprendizado como os autores fazem no ProGAN (ProGAN e StyleGan escrito pelos mesmos pesquisadores).

Vamos primeiramente construir uma classe com o nome WSLinear (Linear com Ponderação e Escala) que será herdada de nn.Module.

  • No parte init nós mandamos in_features e out_channels. Criamos uma camada linear, então definimos uma escala que será igual à raiz quadrada de 2 dividida por in_features, copiamos o bias da camada atual na coluna para uma variável porque não queremos que o bias da camada linear seja escalado, então o removemos, finalmente inicializamos a camada linear.
  • No parte forward, mandamos x e tudo o que vamos fazer é multiplicar x pela escala e adicionar o bias após变形á-lo.
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # inicializar camada linear
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

Agora vamos criar a classe MappingNetwork.

  • No parte init mandamos z_dim e w_din, e definimos a rede de mapeamento que primeiro normaliza z_dim, seguido por oito WSLInear e ReLU como funções de ativação.
  • No parte forward retornamos a rede de mapeamento.

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

Normalização Adaptativa de Instância (AdaIN)

Agora vamos criar a classe AdaIN

  • Na parte init enviamos channels, w_dim e inicializamos instance_norm, que será a parte de normalização de instância, e inicializamos style_scale e style_bias, que serão as partes adaptativas com WSLinear que mapeia a Rede de Mapeamento de Ruido W nos canais.
  • Na passagem forward, enviamos x, aplicamos a normalização de instância para ele, e retornamos style_sclate * x + style_bias.

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

Injetar Ruido

Agora vamos criar a classe InjectNoise para injetar o ruido no gerador

  • Na parte init enviamos canais e inicializamos o peso a partir de uma distribuição normal aleatória e usamos nn.Parameter para que esses pesos possam ser otimizados
  • Na parte forward, enviamos uma imagem x e a retornamos com ruído aleatório adicionado
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

classes úteis

Os autores construíram StyleGAN com base na implementação oficial do ProGAN de Karras et al, eles usam a mesma arquitetura de discriminador, tamanho de minibatch adaptativo, hiperparâmetros, etc. Portanto, há muitos casos que permanecem os mesmos da implementação do ProGAN.

Nesta seção, criaremos as classes que não mudam da arquitetura do ProGAN.

No trecho de código abaixo, você pode encontrar a classe WSConv2d (camada de convolução ponderada e escalonada) para Equalized Learning Rate para as camadas de conv.

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # inicializar camada de conv
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

No trecho de código abaixo, você pode encontrar a classe PixelNorm para normalizar Z antes da Rede de Mapeamento de Ruído.

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

No trecho de código abaixo, você pode encontrar a classe ConvBock que nos ajudará a criar o discriminador.

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

No trecho de código abaixo, você pode encontrar a classe Discriminatowhich é a mesma que na ProGAN.

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        aqui trabalhamos de trás para frente com os fatores porque o discriminador
        # deve ser espelhado a partir do gerador. Portanto, o primeiro prog_block e
        # a camada rgb que anexaremos funcionará para tamanho de entrada 1024x1024, depois 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # talvez o nome confuso "initial_rgb" seja apenas a camada RGB para tamanho de entrada 4x4
        # fiz isso para "espelhar" o initial_rgb do gerador
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # amostragem descendente usando pool average

        # este é o bloco para tamanho de entrada 4x4
        self.final_block = nn.Sequential(
            # +1 nos in_channels porque concatenamos a partir do MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # usamos isso em vez da camada linear
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha deve ser escalar dentro de [0, 1], e upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # pegamos o std para cada exemplo (across todas as channels e pixels) e repetimos
        # para um único canal e concatenamos com a imagem. Dessa forma, o discriminador
        # obterá informações sobre a variação na batch/imagem
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        # onde我们应该在prog_blocks lista começar, talvez um pouco confuso, mas
        # o último é para o 4x4. Então, por exemplo, digamos que steps=1, então我们应该 começar
        # na penúltima, porque input_size será 8x8. Se steps==0 usamos apenas
        # o bloco final
        cur_step = len(self.prog_blocks) - steps

        # converter do rgb como passo inicial, isso dependerá
        # do tamanho da imagem (cada uma terá sua própria camada rgb)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # ou seja, a imagem é 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # porque prog_blocks pode mudar os canais, para a redução usamos rgb_layer
        # do tamanho anterior/menor, o que em nosso caso corresponde a +1 no índice
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # a fade_in é feita primeiro entre o redimensionado e o entrada
        # isso é o oposto do gerador
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

G gerador

Na arquitetura do gerador, temos alguns padrões que se repetem, então vamos primeiro criar uma classe para isso para deixar nosso código o mais limpo possível, vamos nomear a classe GenBlock que será herdada de nn.Module.

  • No parte do init, mandamos in_channels, out_channels e w_dim, então inicializamos conv1 com WSConv2d que mapeia in_channels para out_channels, conv2 com WSConv2d que mapeia out_channels para out_channels, leaky com Leaky ReLU com uma inclinação de 0.2 como usam no artigo, inject_noise1, inject_noise2 com InjectNoise, adain1 e adain2 com AdaIN
  • No parte do forward, mandamos x, e o passamos para conv1 depois para inject_noise1 com leaky, então normalizamos com adain1, e novamente passamos isso para conv2 depois para inject_noise2 com leaky e normalizamos com adain2. E finalmente, retornamos x.
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

Agora temos tudo o que precisamos para criar o gerador.

  • na parte init vamos inicializar ‘starting_constant’ por um tensor 4 x 4 (x 512 canais para o artigo original, e 256 no nosso caso) que passa por uma iteração do gerador, mapeado pelo ‘MappingNetwork’, initial_adain1, initial_adain2 pelo AdaIN, initial_noise1, initial_noise2 pelo InjectNoise, initial_conv por uma camada convolucional que mapeia in_channels para si mesmo, leaky pelo Leaky ReLU com uma inclinação de 0.2, initial_rgb pelo WSConv2d que mapeia in_channels para img_channels que é 3 para RGB, prog_blocks por ModuleList() que vai conter todos os blocos progressivos (indicamos os canais de entrada/saída da convolução multiplicando in_channels que é 512 no artigo e 256 no nosso caso pelos fatores), e rgb_blocks por ModuleList() que vai conter todos os blocos RGB.
  • Para introduzir novas camadas (um componente original do ProGAN), adicionamos a parte fade_in, na qual enviamos alpha, escalado e gerado, e retornamos [tanh(alpha∗generated+(1−alpha)∗upscale)], A razão pela qual usamos tanh é que será a saída (a imagem gerada) e queremos que os pixels estejam no intervalo entre 1 e -1.
  • Na parte avançada, enviamos o ruído (Z_dim), o valor alpha que vai desaparecer lentamente durante o treinamento (alpha está entre 0 e 1) e os passos que é o número da resolução atual com a qual estamos trabalhando, passamos x para o mapa para obter o vetor de ruído intermediário W, passamos o starting_constant para initial_noise1, aplicamos para ele e para W o initial_adain1, então passamos para o initial_conv, e novamente adicionamos initial_noise2 para ele com leaky como função de ativação, e aplicamos para ele e W o initial_adain2. Então verificamos se steps = 0, se for, tudo o que queremos fazer é passar pelo initial RGB e我们已经完成, caso contrário, loopamos sobre o número de passos e em cada loop faremos o upscaling (upscaled) e passamos pelo bloco progressivo correspondente àquela resolução (out). No final, retornamos fade_in que recebe alpha, final_out e final_upscaled após mapeá-lo para RGB.
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 para evitar erro de índice devido a factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha deve ser escalar dentro de [0, 1], e upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # O número de canais em upscale permanecerá o mesmo, enquanto
        # out que passou pelos prog_blocks pode mudar. Para garantir
        # que podemos converter ambos para rgb usamos diferentes rgb_layers
        # (steps-1) e steps para upscaled, out respectivamente
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

Utils

No trecho de código abaixo, você pode encontrar a função generate_examples que recebe o gerador gen, o número de passos para identificar a resolução atual e um número n=100. O objetivo dessa função é gerar n imagens falsas e salvá-las como resultado.

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

No trecho de código abaixo, você pode encontrar a função gradient_penalty para a perda WGAN-GP.

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calcular as notas do crítico
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Tomar o gradiente das notas em relação às imagens
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Função de treinamento

Para a função de treinamento, enviamos o crítico (que é o discriminador), gen (gerador), carregador, conjunto de dados, passo, alpha e otimizador para o gerador e para o crítico.

Nos iniciamos fazendo um loop sobre todos os tamanhos de mini-lote que criamos com o DataLoader, e pegamos apenas as imagens, porque não precisamos de um rótulo.

Em seguida, configuramos o treinamento para o discriminador\Crítico quando queremos maximizar E(critico(real)) – E(critico(fake)). Esta equação significa quanto o crítico pode distinguir entre imagens reais e falsas.

Depois disso, configuramos o treinamento do gerador quando queremos maximizar E(critic(fake)).

Por fim, atualizamos o loop e o valor de alpha para fade_in e garantimos que ele esteja entre 0 e 1, e o retornamos.

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Atualizar alpha e garantir que seja menor que 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

Treinamento

Agora,since temos tudo, vamos juntar tudo para treinar nosso StyleGAN.

Começamos inicializando o gerador, o discriminador/critic e os otimizadores, depois convertemos o gerador e o critic para o modo de treinamento, então loopamos sobre PROGRESSIVE_EPOCHS e, em cada loop, chamamos a função de treinamento o número de vezes de epoch, então geramos algumas imagens falsas e salvamos elas, como resultado, usando a função generate_examples, e finalmente, avançamos para a próxima resolução de imagem.

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# inicializar otimizadores
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# começar no passo que corresponde ao tamanho da img que configuramos no config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # começar com alpha muito baixo
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # avançar para o próximo tamanho de img

Resultado

Espero que você consiga seguir todos os passos e obter uma boa compreensão de como implementar o StyleGAN da maneira correta. Agora vamos verificar os resultados que obtemos após treinar este modelo neste conjunto de dados com resolução 128×128.

Conclusão

Neste artigo, fizemos uma implementação limpa, simples e legível do StyleGAN1 do zero usando PyTorch. replicamos o artigo original o mais próximo possível, então, se você leu o artigo, a implementação deve ser praticamente idêntica.

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch