Estilo de Implementação GAN1 do zero

Introdução

Este artigo é sobre um dos melhores GANs hoje em dia, o StyleGAN, a partir do artigo <diy5 Uma Arquitetura de Gerador Baseada em Estilo para Redes Adversárias Gerativas, faremos uma implementação limpa, simples e legível usando PyTorch, e tentaremos replicar o artigo original o mais próximo possível, então, se você leu o artigo, a implementação deve ser praticamente idêntica.

O conjunto de dados que usaremos neste blog é este conjunto de dados do Kaggle, que contém 16.240 peças de roupas de cima para mulheres com resolução 256*192.

Pré-requisitos

Antes de mergulhar no trabalho com StyleGAN usando PyTorch, certifique-se de que você tem os seguintes pré-requisitos:

  • Conhecimento Básico de Aprendizado Profundo
    Compreensão de redes neurais convolucionais (CNNs).
    Familiaridade com Redes Adversárias Gerativas (GANs), incluindo conceitos como gerador, discriminador e perda adversária.

  • Requisitos de Hardware
    Uma GPU poderosa (NVIDIA recomendada) para treinamento e inferência mais rápidos.
    Kit CUDA instalado para aceleração de GPU (cuda e cudnn).

  • Familiaridade com StyleGAN
    É útil ter lido os papers originais do StyleGAN ou StyleGAN2 para entender melhorias na arquitetura e conceitos-chave.

Carregar todas as dependências que precisamos

Primeiro importaremos torch, pois utilizaremos PyTorch, e a partir disso importamos nn. Isso nos ajudará a criar e treinar as redes, e também nos permitirá importar optim, um pacote que implementa vários algoritmos de otimização (por exemplo, sgd, adam,…). Do torchvision importamos datasets e transforms para preparar os dados e aplicar algumas transformações.

Importaremos functional como F de torch.nn para upsampler as imagens usando interpolate, DataLoader de torch.utils.data para criar tamanhos de mini-lotes, save_image de torchvision.utils para salvar alguns exemplos falsos, e log2 de math, pois precisamos da representação inversa da potência de 2 para implementar o tamanho de mini-lote adaptativo dependendo da resolução de saída, NumPy para álgebra linear, os para interação com o sistema operacional, tqdm para exibir barras de progresso, e finalmente matplotlib.pyplot para mostrar os resultados e compará-los com os reais.

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

Hiperparâmetros

  • Inicializaremos o DATASET pelo caminho das imagens reais.
  • Especificaremos o início do treinamento no tamanho de imagem 8×8.
  • Inicializaremos o dispositivo por Cuda, se disponível, ou CPU, caso contrário, e a taxa de aprendizado por 0.001.
  • O tamanho do lote será diferente dependendo da resolução das imagens que queremos gerar, então inicializamos BATCH_SIZES com uma lista de números, você pode mudá-los dependendo da sua VRAM.
  • Inicializamos image_size por 128 e CHANNELS_IMG por 3, pois geraremos imagens RGB de 128 por 128.
  • No artigo original, eles inicializam Z_DIM, W_DIM e IN_CHANNELS com 512, mas eu os inicializo com 256 para menos uso de VRAM e aceleração do treinamento. Talvez até possamos obter melhores resultados se os dobrássemos.
  • Para o StyleGAN podemos usar qualquer função de perda de GANs que quisermos, então uso WGAN-GP do artigo Improved Training of Wasserstein GANs. Esta perda contém um parâmetro chamado λ e é comum definir λ = 10.
  • Inicialize PROGRESSIVE_EPOCHS com 30 para cada tamanho de imagem.
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #Os autores começam com imagens de 8x8 em vez de 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

Obter carregador de dados

Agora vamos criar uma função get_loader para:

  • Aplicar algumas transformações às imagens (redimensionar as imagens para a resolução que queremos, convertê-las em tensores, então aplicar algumas augmentações e, finalmente, normalizá-las para que todos os pixels variem de -1 a 1).
  • Identificar o tamanho atual do lote usando a lista BATCH_SIZES, e tomar como índice o número inteiro da representação inversa da potência de 2 do image_size/4. E isso é realmente como implementamos o tamanho de minibatch adaptativo dependendo da resolução de saída.
  • Preparar o conjunto de dados usando ImageFolder, pois já está estruturado de uma maneira agradável.
  • Criar tamanhos de mini-lote usando DataLoader que tomam o conjunto de dados e o tamanho do lote com mesclagem dos dados.
  • Finalmente, retornar o carregador e o conjunto de dados.
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

Implementação de Modelos

Agora vamos Implementar o gerador e discriminador StyleGAN1 (ProGAN e StyleGAN1 têm a mesma arquitetura de discriminador) com as atribuições-chave do artigo. Tentaremos tornar a implementação compacta, mas também legível e compreensível. Especificamente, os pontos-chave:

  • Rede de Mapeamento de Ruído
  • Normalização Adaptativa de Instância (AdaIN)
  • Crescimento Progressivo

Neste tutorial,我们将 apenas gerar imagens com StyleGAN1, e não implementaremos mesclagem de estilos e variação estocástica, mas não deve ser difícil fazer isso.

Vamos definir uma variável com o nome factors que contém os números que serão multiplicados por IN_CHANNELS para ter o número de canais que queremos em cada resolução de imagem.

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

Rede de Mapeamento de Ruído

A rede de mapeamento de ruído pega Z e o passa por oito camadas completamente conectadas separadas por alguma ativação. E não esqueça de equalizar a taxa de aprendizado como os autores fazem no ProGAN (ProGAN e StyleGan são escritos pelos mesmos pesquisadores).

Vamos primeiramente construir uma classe com o nome WSLinear (Linear com Ponderação e Escala) que será herdada de nn.Module.

  • No parte init nós mandamos in_features e out_channels. Criamos uma camada linear, então definimos uma escala que será igual à raiz quadrada de 2 dividida por in_features, copiamos o bias da camada atual na coluna para uma variável porque não queremos que o bias da camada linear seja escalado, então o removemos, finalmente inicializamos a camada linear.
  • No parte forward, mandamos x e tudo o que vamos fazer é multiplicar x pela escala e adicionar o bias após变形.
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # inicializar camada linear
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

Agora vamos criar a classe MappingNetwork.

  • No parte init mandamos z_dim e w_din, e definimos a rede de mapeamento que primeiramente normaliza z_dim, seguido por oito WSLInear e ReLU como funções de ativação.
  • No parte forward, retornamos o mapeamento da rede.

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

Normalização Adaptativa de Instância (AdaIN)

Agora vamos criar a classe AdaIN

  • Na parte init enviamos canais, w_dim e inicializamos instance_norm, que será a parte de normalização de instância, e inicializamos style_scale e style_bias, que serão as partes adaptativas com WSLinear que mapeia a Rede de Mapeamento de Ruído W nos canais.
  • Na passagem forward, enviamos x, aplicamos a normalização de instância para ele e retornamos style_scale * x + style_bias.

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

Injetar Ruído

Agora vamos criar a classe InjectNoise para injetar o ruído no gerador

  • Na parte init enviamos canais e inicializamos o peso a partir de uma distribuição normal aleatória e usamos nn.Parameter para que esses pesos possam ser otimizados
  • Na parte forward, enviamos uma imagem x e a retornamos com ruído aleatório adicionado
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

classes úteis

Os autores construíram StyleGAN com base na implementação oficial do ProGAN de Karras et al, eles usam a mesma arquitetura de discriminador, tamanho de minibatch adaptativo, hiperparâmetros, etc. Então há muitos classes que permanecem as mesmas da implementação do ProGAN.

Nesta seção, criaremos as classes que não mudam da arquitetura do ProGAN.

No trecho de código abaixo, você pode encontrar a classe WSConv2d (camada de convolução ponderada e escalonada) para Equalized Learning Rate para as camadas de convolução.

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # inicializar camada de convolução
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

No trecho de código abaixo, você pode encontrar a classe PixelNorm para normalizar Z antes da Rede de Mapeamento de Ruído.

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

No trecho de código abaixo, você pode encontrar a classe ConvBlock que nos ajudará a criar o discriminador.

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

No trecho de código abaixo, você pode encontrar a classe Discriminatowich é a mesma que na ProGAN.

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        aqui trabalhamos com caminhos invertidos a partir dos fatores porque o discriminador
        # deve ser espelhado a partir do gerador. Portanto, o primeiro bloco prog_block e
        # a camada rgb que anexaremos funcionará para o tamanho de entrada 1024x1024, depois 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # talvez o nome confuso "initial_rgb" seja apenas a camada RGB para o tamanho de entrada 4x4
        # fizemos isso para "espelhar" o initial_rgb do gerador
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # amostragem descendente usando pool de média

        # este é o bloco para o tamanho de entrada 4x4
        self.final_block = nn.Sequential(
            # +1 em in_channels porque concatenamos a partir do std do MiniBatch
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # usamos isso em vez de uma camada linear
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha deve ser escalar dentro de [0, 1], e upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # pegamos o std para cada exemplo (across todas as kanais e pixels) e repetimos
        # para um único canal e concatenamos com a imagem. Dessa forma, o discriminador
        # obterá informações sobre a variação na batelada/imagem
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        # onde我们应该开始 na lista de prog_blocks, talvez um pouco confuso, mas
        # o último é para o 4x4. Portanto, exemplo, digamos que steps=1, então
        # devemos começar no penúltimo porque o input_size será 8x8. Se steps==0,
        # apenas usamos o bloco final
        cur_step = len(self.prog_blocks) - steps

        # converter do rgb como passo inicial, isso dependerá
        # do tamanho da imagem (cada uma terá sua própria camada rgb)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # ou seja, a imagem é 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # porque prog_blocks pode mudar os canais, para a redução de escala usamos rgb_layer
        # da tamanho anterior/menor, o que em nosso caso corresponde a +1 no índice
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # a fade_in é feita primeiro entre o escalado para baixo e a entrada
        # isso é o oposto do gerador
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

Gerador

Na arquitetura do gerador, temos alguns padrões que se repetem, então vamos primeiro criar uma classe para isso para tornar nosso código o mais limpo possível, vamos nomear a classe GenBlock que será herdada de nn.Module.

  • No parte do init, enviamos in_channels, out_channels e w_dim, então inicializamos conv1 com WSConv2d que mapeia in_channels para out_channels, conv2 com WSConv2d que mapeia out_channels para out_channels, leaky com Leaky ReLU com uma inclinação de 0.2 como usam no artigo, inject_noise1 e inject_noise2 com InjectNoise, adain1 e adain2 com AdaIN
  • No parte do forward, enviamos x, e o passamos por conv1, depois para inject_noise1 com leaky, então normalizamos com adain1, e novamente passamos isso para conv2, depois para inject_noise2 com leaky e normalizamos com adain2. E finalmente, retornamos x.
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

Agora temos tudo o que precisamos para criar o gerador.

  • na parte init vamos inicializar ‘starting_constant’ por um tensor 4 x 4 (x 512 canais para o artigo original, e 256 no nosso caso) que passa por uma iteração do gerador, mapeado pelo ‘MappingNetwork’, initial_adain1, initial_adain2 pelo AdaIN, initial_noise1, initial_noise2 pelo InjectNoise, initial_conv por uma camada convolucional que mapeia os in_channels para si mesmo, leaky pelo Leaky ReLU com uma inclinação de 0.2, initial_rgb pelo WSConv2d que mapeia os in_channels para img_channels que é 3 para RGB, prog_blocks por ModuleList() que contém todos os blocos progressivos (indicamos os canais de entrada/saída da convolução multiplicando in_channels, que é 512 no artigo e 256 no nosso caso, por fatores), e rgb_blocks por ModuleList() que contém todos os blocos RGB.
  • Para fundir novas camadas (um componente original do ProGAN), adicionamos a parte fade_in, na qual enviamos alpha, scaled e generated, e retornamos [tanh(alpha∗generated+(1−alpha)∗upscale)], A razão pela qual usamos tanh é que será a saída (a imagem gerada) e queremos que os pixels estejam no intervalo entre 1 e -1.
  • Na parte avançada, enviamos o ruído (Z_dim), o valor alpha que vai desvanecer gradualmente durante o treinamento (alpha está entre 0 e 1) e os passos que é o número da resolução atual com a qual estamos trabalhando, passamos x para o mapa para obter o vetor de ruído intermediário W, passamos starting_constant para initial_noise1, aplicamos para ele e para W initial_adain1, então passamos para initial_conv, e novamente adicionamos initial_noise2 para ele com leaky como função de ativação, e aplicamos para ele e W initial_adain2. Em seguida, verificamos se steps = 0, se for, então tudo o que queremos fazer é executá-lo através do initial RGB e estamos feitos, caso contrário, loopamos sobre o número de passos, e em cada loop faremos o upscaling (upscaled) e executamos através do bloco progressivo que corresponde àquela resolução (out). No final, retornamos fade_in que recebe alpha, final_out e final_upscaled após mapeá-lo para RGB.
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 para evitar erro de índice devido a factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha deve ser escalar dentro de [0, 1], e upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # O número de canais em upscale permanecerá o mesmo, enquanto
        # out que passou pelos prog_blocks pode mudar. Para garantir
        # que podemos converter ambos para rgb usamos diferentes rgb_layers
        # (steps-1) e steps para upscaled, out respectivamente
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

Utils

No trecho de código abaixo, você pode encontrar a função generate_examples que recebe o gerador gen, o número de passos para identificar a resolução atual e um número n=100. O objetivo desta função é gerar n imagens falsas e salvá-las como resultado.

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

No trecho de código abaixo, você pode encontrar a função gradient_penalty para a perda WGAN-GP.

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calcular notas do crítico
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Tomar o gradiente das notas em relação às imagens
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Função de treinamento

Para a função de treinamento, enviamos o crítico (que é o discriminador), gen (gerador), loader, conjunto de dados, passo, alpha e otimizador para o gerador e para o crítico.

Iniciamos fazendo um loop sobre todos os tamanhos de mini-lote que criamos com o DataLoader, e pegamos apenas as imagens, pois não precisamos de um rótulo.

Em seguida, configuramos o treinamento para o discriminador\Crítico quando queremos maximizar E(critic(real)) – E(critic(fake)). Esta equação significa quanto o crítico pode distinguir entre imagens reais e falsas.

Após isso, configuramos o treinamento do gerador quando queremos maximizar E(critic(fake)).

Finalmente, atualizamos o loop e o valor alpha para fade_in e garantimos que ele está entre 0 e 1, e o retornamos.

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Atualizar alpha e garantir que seja menor que 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

Treinamento

Agora que temos tudo, vamos juntar as peças para treinar nosso StyleGAN.

Iniciamos pelo descarregamento do gerador, do discriminador/crítico e dos otimizadores, depois convertemos o gerador e o crítico para o modo de treinamento, depois-loop sobre PROGRESSIVE_EPOCHS e em cada loop, chamamos a função de treinamento o número de vezes de epoch, então geramos algumas imagens falsas e as salvamos, como resultado, usando a função generate_examples, e finalmente, progredimos para a próxima resolução de imagem.

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# inicializar otimizadores
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# começar no passo que corresponde ao tamanho da img que definimos na config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # começar com alpha muito baixo
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # progredir para o próximo tamanho de img

Resultado

Espero que você consiga seguir todos os passos e obtenha uma boa compreensão de como implementar o StyleGAN da maneira correta. Agora vamos verificar os resultados que obtemos após treinar este modelo neste conjunto de dados com resolução de 128*x 128.

Conclusão

Neste artigo, fizemos uma implementação limpa, simples e legível do StyleGAN1 do zero usando PyTorch. replicamos o artigo original o mais próximo possível, então, se você ler o artigo, a implementação deve ser praticamente idêntica.

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch