Реализация StyleGAN1 с нуля

Введение

Эта статья посвящена одному из лучших GAN на сегодняшний день, StyleGAN из статьи Style-Based Generator Architecture for Generative Adversarial Networks, мы создадим чистую, простую и читаемую реализацию на основе PyTorch и постараемся как можно ближе复制 исходную статью, чтобы если вы читали статью, реализация должна быть практически идентичной.

Массив данных, который мы будем использовать в этом блоге, это массив данных с Kaggle, который содержит 16240female upper clothes с разрешением 256*192.

Предпосылки

Прежде чем приступить к работе с StyleGAN на основе PyTorch, убедитесь, что у вас есть следующие предпосылки:

  • Основные знания в области深海学习
    Понимание convolutions neural networks (CNN).
    Знакомство с Generative Adversarial Networks (GAN), включая такие концепции, как генератор, дискриминатор и враждебная потеря.

  • Требования к оборудованию
    Мощная видеокарта (рекомендуется NVIDIA) для ускорения обучения и вывода.
    Установленный toolkit CUDA для ускорения работы с GPU (cuda и cudnn).

  • Знакомство со StyleGAN
    Полезно будет прочитать оригинальные статьи StyleGAN или StyleGAN2, чтобы понять улучшения архитектуры и ключевые концепции.

Загрузить все необходимые зависимости

Мы сначала импортируем torch, так как мы будем использовать PyTorch, и из него импортируем nn. Это поможет нам создавать и обучать сети, а также позволит импортировать optim, пакет, реализующий различные оптимизационные алгоритмы (например, sgd, adam,…). Из torchvision мы импортируем datasets и transforms для подготовки данных и применения некоторых преобразований.

Мы импортируем functional как F из torch.nn для increase images usando interpolate, DataLoader из torch.utils.data для создания размеров mini-batch, save_image из torchvision.utils для сохранения некоторых подделанных примеров, log2 из math, так как нам нужно обратное представление степени 2 для реализации адаптивного размера mini-batch в зависимости от разрешающей способности вывода, NumPy для линейной алгебры, os для взаимодействия с операционной системой, tqdm для отображения индикаторов прогресса и, наконец, matplotlib.pyplot для отображения результатов и сравнения их с реальными.

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

Гиперпараметры

  • Инициализируем DATASET по пути к реальным изображениям.
  • Указываем размер изображения для начала обучения 8×8.
  • Инициализируем устройство через Cuda, если оно доступно, и через CPU в противном случае, а такжеlearning rate в 0.001.
  • Размер batch size будет различаться в зависимости от разрешения изображений, которые мы хотим генерировать, поэтому мы инициализируем BATCH_SIZES списком чисел, которые вы можете изменить в зависимости от вашей VRAM.
  • Инициализируем image_size в 128 и CHANNELS_IMG в 3, так как мы будем генерировать изображения 128 на 128 в формате RGB.
  • В оригинальной статье они инициализируют Z_DIM, W_DIM и IN_CHANNELS значением 512, но я инициализирую их значением 256 вместо этого для уменьшения использования VRAM и ускорения обучения. Возможно, мы могли бы получить даже лучшие результаты, если бы удвоили их.
  • Для StyleGAN мы можем использовать любую функцию потерь GAN, которую хотим, поэтому я использую WGAN-GP из статьи Improved Training of Wasserstein GANs. Эта функция потерь содержит параметр с именем λ, и обычно устанавливают λ = 10.
  • Инициализируйте PROGRESSIVE_EPOCHS значением 30 для каждого размера изображения.
DATASET                 = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #Авторы начинают с изображений 8x8 вместо 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

Получите загрузчик данных

Теперь создадим функцию get_loader, чтобы:

  • Применить некоторые преобразования к изображениям (изменить размер изображений до desired resolution, convert them to tensors, затем применить некоторые улучшения и, наконец, normalize их, чтобы все пиксели находились в диапазоне от -1 до 1).
  • Определить текущий размер пакета, используя список BATCH_SIZES, и взять в качестве индекса целое число, обратное представлению степени 2 размера изображения/4. И именно так мы реализуем адаптивный размер минипакета в зависимости от разрешающей способности выхода.
  • Подготовить набор данных, используя ImageFolder, так как он уже структурирован удобным образом.
  • >Создайте размеры mini-batch с использованием DataLoader, который принимает набор данных и размер batch size с перемешиванием данных.
  • Наконец, верните загрузчик и набор данных.
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader, dataset

Реализация моделей

Теперь давайте реализуем генератор и дискриминатор StyleGAN1 (ProGAN и StyleGAN1 имеют одинаковую архитектуру дискриминатора) с основными атрибутами из статьи. Мы постараемся сделать реализацию компактной, но также читаемой и понятной. Specifically, the key points:

  • Network Noise Mapping
  • Adaptive Instance Normalization (AdaIN)
  • Progressive growing

В этом руководстве мы будем просто генерировать изображения с помощью StyleGAN1, не реализуя стилирование и стохастическое разнообразие, но это shouldn’t быть太难.

Давайте определим переменную с названием factors, которая содержит числа, умножаемые на IN_CHANNELS для получения desired количества каналов в каждом разрешении изображения.

factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

Network Noise Mapping

Шумовая картографическая сеть принимает Z и пропускает его через восемь teljes connected слоев, разделенных какой-то активацией. И не забудьте equalize темп обучения, как это делают авторы в ProGAN (ProGAN и StyleGan написаны теми же исследователями).

Давайте сперва создадим класс с названием WSLinear (weighted scaled Linear), который будет унаследован от nn.Module.

  • В части init мы передаем in_features и out_channels. Создаем линейный слой, затем определяем scale, который будет равен квадратному корню из 2, деленному на in_features, мы копируем bias текущего столбцового слоя в переменную, потому что мы не хотим, чтобы bias линейного слоя был масштабирован, затем удаляем его, наконец, инициализируем линейный слой.
  • В части forward мы передаем x и все, что мы будем делать, это умножить x на scale и добавить bias после его преобразования.
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features,
    ):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features)**0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # инициализация линейного слоя
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias

Теперь создадим класс MappingNetwork.

  • В части init мы передаем z_dim и w_din, и определяем сеть картографирования, которая сначала normalize z_dim, затем следует восемь WSLInear и ReLU в качестве активационных функций.
  • В части forward мы возвращаем сеть картографирования.

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

Адаптивная инстанс normalization (AdaIN)

Теперь создадим класс AdaIN

  • В части init мы отправляем каналы, w_dim, и мы инициализируем instance_norm, который будет частью инстанс-нормализации, а также инициализируем style_scale и style_bias, которые будут адаптивными частями с WSLinear, который отображает Noise Mapping Network W в каналы.
  • В процессе forward мы отправляем x, применяем для него инстанс-нормализацию и возвращаем style_sclate * x + style_bias.

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

Введение шума

Теперь создадим класс InjectNoise для внедрения шума в генератор

  • В части init мы отправляем каналы и инициализируем вес из случайного нормального распределения, используя nn.Parameter, чтобы эти веса можно было оптимизировать
  • В части forward мы отправляем изображение x и возвращаем его с добавленным случайным шумом
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

полезные классы

Авторы создали StyleGAN на основе официальной реализации ProGAN от Karras и др., они используют ту же архитектуру дискриминатора, адаптивный размер minibatch, гиперпараметры и т.д. Поэтому есть множество классов, которые remained неизменными по сравнению с реализацией ProGAN.

В этом разделе мы создадим классы, которые не изменяются по сравнению с архитектурой ProGAN.

В фрагменте кода ниже вы можете найти класс WSConv2d (ciosл加权 масштабированный конволюционный слой) для Equalized Learning Rate для конволюционных слоев.

class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # инициализация конволюционного слоя
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

В фрагменте кода ниже вы можете найти класс PixelNorm для normalization Z передNoise Mapping Network.

class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)   

В фрагменте кода ниже вы можете найти класс ConvBock, который поможет нам создать дискриминатор.

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

В фрагменте кода ниже вы можете найти класс Discriminatowich, который такой же, как и в ProGAN.

class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        здесь мы работаем обратными путями от факторов, потому что дискриминант
        должен быть зеркальным для генератора.	So первый prog_block и
        rgb слой, который мы добавим, будет работать для входного размера 1024x1024, затем 512->256-> и т.д.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        возможно, путающее имя "initial_rgb" это просто rgb слой для входного размера 4x4
        сделал это, чтобы "зеркалить" начальный rgb генератора
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )   downsampling с использованием среднего пула

        это блок для входного размера 4x4
        self.final_block = nn.Sequential(
            +1 к in_channels, потому что мы конкатенируем из MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  мы используем это вместо линейного слоя
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        alpha должен быть скаляром в диапазоне [0, 1], и upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
         мы берем std для каждого примера (по всем каналам и пикселям), затем повторяем его
         для одного канала и конкатенируем с изображением. Таким образом дискриминант
        получит информацию о variации в партии/изображении
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        где мы должны начать в списке prog_blocks, может быть немного запутанно, но
        последний для 4x4.	So пример, давайте предположим, steps=1, тогда мы должны начать
         со второго с конца, потому что входной размер будет 8x8. Если steps==0, мы просто
        используем конечный блок
        cur_step = len(self.prog_blocks) - steps

        преобразуем из rgb как начальный шаг, это будет зависеть от
        размера изображения (каждое будет иметь свой rgb слой)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  то есть, изображение 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

         потому что prog_blocks могут изменить каналы, для уменьш. масштаба мы используем rgb_layer
         из предыдущего/меньшего размера, который в нашем случае коррелирует с +1 в индексации
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

         fade_in выполняется сначала между уменьш. масштабом и входом
         это противоположно генератору
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

for Генератор

В архитектуре генератора у нас есть некоторые паттерны, которые повторяются, поэтому давайте сперва создадим класс для них, чтобы сделать наш код как можно чище, назовем класс GenBlock, который будет наследоваться от nn.Module.

  • В части инициализации мы передаем in_channels, out_channels и w_dim, затем инициализируем conv1 через WSConv2d, который отображает in_channels на out_channels, conv2 через WSConv2d, который отображает out_channels на out_channels, leaky через Leaky ReLU с наклоном 0.2, как это используется в статье, inject_noise1 и inject_noise2 через InjectNoise, adain1 и adain2 через AdaIN.
  • В части forward мы передаем x и пропускаем его через conv1, затем через inject_noise1 с leaky, затем normalize с adain1, и снова пропускаем его через conv2, затем через inject_noise2 с leaky и normalize с adain2. И finally, мы возвращаем x.
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

Теперь у нас есть все, что нам нужно для создания генератора.

  • В части init давайте инициируем ‘starting_constant’ константой 4 x 4 (x 512 каналов для оригинальной статьи, и 256 в нашем случае) тензора, который проходит итерацию генератора, map через ‘MappingNetwork’, initial_adain1, initial_adain2 через AdaIN, initial_noise1, initial_noise2 через InjectNoise, initial_conv через конволюционный слой, который отображает in_channels на самого себя, leaky через Leaky ReLU с наклоном 0.2, initial_rgb через WSConv2d, который отображает in_channels на img_channels, что equals 3 для RGB, prog_blocks через ModuleList(), который будет содержать все прогрессивные блоки (мы указываем входные/выходные каналы конволюции, умножая in_channels, что equals 512 в статье и 256 в нашем случае, на коэффициенты), и rgb_blocks через ModuleList(), который будет содержать все RGB блоки.
  • Чтобы gradually добавлять новые слои (оригинальный компонент ProGAN), мы добавляем часть fade_in, в которую мы передаем alpha, scaled и generated, и мы возвращаем [tanh(alpha∗generated+(1−alpha)∗upscale)], Причина, по которой мы используем tanh, заключается в том, что это будет вывод (сгенерированное изображение), и мы хотим, чтобы пиксели находились в диапазоне между 1 и -1.
  • В части вперёд, мы отправляем шум (Z_dim), значение alpha, которое будет медленно исчезать во время обучения (alpha находится между 0 и 1), и шаги, которые являются номером текущего разрешения, с которым мы работаем, передаём x в карту, чтобы получить промежуточный вектор шума W, передаём starting_constant в initial_noise1, применяем его и для W initial_adain1, затем передаём его в initial_conv, и снова добавляем initial_noise2 для него с leaky в качестве функции активации, и применяем его и W initial_adain2. Затем проверяем, равны ли шаги 0, если да, то всё, что мы хотим сделать, это пропустить его через initial RGB и мы закончили, в противном случае, мы循环 по количеству шагов, и в каждом цикле мы масштабируем (upscaled) и пропускаем через прогрессирующий блок, соответствующий этому разрешению (out). В конце, мы возвращаем fade_in, который принимает alpha, final_out и final_upscaled после его mapping к RGB.
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1, чтобы предотвратить ошибку индекса из-за factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha должен быть скаляром в диапазоне [0, 1], и upscale.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # Количество каналов в upscale останется тем же, в то время как
        # out, который прошёл через prog_blocks, может измениться. Чтобы обеспечить
        # мы можем преобразовать их в rgb, мы используем разные rgb_layers
        # (steps-1) и шаги для upscaled, out соответственно
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

Утилиты

В фрагменте кода ниже вы можете найти функцию generate_examples, которая принимает генератор gen, количество шагов для определения текущего разрешения и число n=100. Цель этой функции – сгенерировать n поддельных изображений и сохранить их в качестве результата.

def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIM).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

В фрагменте кода ниже вы можете найти функцию gradient_penalty для WGAN-GP loss.

def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Расчет оценок критика
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Принять градиент оценок по отношению к изображениям
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Функция обучения

Для функции обучения мы отправляем критик (который является дискриминатором), gen (генератор), лоадер, набор данных, шаг, альфу и оптимайзеры для генератора и критика.

Мы начинаем с цикла по всем размерам迷你- батчей, которые мы создаем с помощью DataLoader, и берем только изображения, так как нам не нужен ярлык.

Затем мы настраиваем обучение для дискриминатора\Критика, когда хотим максимизировать E(critic(real)) – E(critic(fake)). Это уравнение означает, насколько критик может различать реальные и поддельные изображения.

После этого мы настраиваем обучение генератора, когда хотим максимизировать E(critic(fake)).

Наконец, мы обновляем цикл и значение альфа для fade_in и обеспечиваем, чтобы оно было между 0 и 1, и возвращаем его.

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]


        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)

        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Обновление альфа и обеспечение значения меньше 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )


    return alpha

Обучение

Теперь, когда у нас есть все, давайте объединим их, чтобы обучить наш StyleGAN.

Мы начинаем с инициализации генератора, дискриминатора/критика и оптимизаторов, затем переводим генератор и критик в режим обучения, затем循环 по PROGRESSIVE_EPOCHS и в каждом цикле вызываем функцию обучения количество итераций, затем генерируем некоторые поддельные изображения и сохраняем их, как результат, используя функцию generate_examples, и, наконец, переходим к следующему разрешению изображения.

gen = Generator(
        Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
    ).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# инициализация оптимизаторов
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)


gen.train()
critic.train()

# начать с шага, соответствующего размеру изображения, который мы установили в config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # начать с очень низкого значения альфа
    loader, dataset = get_loader(4 * 2 ** step)  
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # перейти к следующему размеру изображения

Результат

Надеюсь, вам удастся следовать всем шагам и получить хорошее понимание того, как реализовать StyleGAN的正确ным образом. Теперь давайте рассмотрим результаты, которые мы получаем после обучения этой модели на этом наборе данных с разрешением 128×128.

Заключение

В этой статье мы создаем чистую, простую и читаемую реализацию с нуля StyleGAN1 с использованием PyTorch. Мы как можно ближе replicируем оригинальную статью, поэтому, если вы читали статью, реализация должна быть практически идентичной.

Source:
https://www.digitalocean.com/community/tutorials/implementation-stylegan-from-scratch