Introdução
Modelos de aprendizagem profunda maiores requerem mais potência de computação e recursos de memória. O treinamento rápido de redes neurais profundas foi alcançado através do desenvolvimento de novas técnicas. Em vez de FP32 (formato de números de ponto flutuante com precisão total), você pode usar FP16 (formato de ponto flutuante com precisão parcial), e os pesquisadores descobriram que usá-los em conjunto é uma opção melhor.
A precisão mista permite o treinamento com precisão parcial enquanto ainda preserva muito da precisão da rede de precisão simples. O termo “técnica de precisão mista” refere-se ao fato de que este método faz uso de representações tanto de precisão simples quanto de precisão parcial.
Neste overview da treinagem de Precisão Mista Automática (AMP) com PyTorch, nós mostramos como a técnica funciona, percorrendo passo a passo o processo de usar AMP, e discutimos aplicações avançadas de técnicas de AMP com estruturas de código para que usuários possam integrá-las com o seu próprio código.
Pré-requisitos
Conhecimento Básico de PyTorch: Familiaridade com PyTorch, incluindo seus conceitos básicos como tensores, módulos e o laço de treinamento.
Compreensão dos Fundamentos de Aprendizagem Profunda: Conceitos como redes neurais, backpropagation e otimização.
Conhecimento de Treinamento com Precisão Mista: Conhecimento dos benefícios e desvantagens do treinamento com precisão mixta, incluindo o uso reduzido de memória e o cálculo mais rápido.
Acesso a Hardware Compatível: Uma GPU que suporte precisão mixta, como as GPUs NVIDIA com Cores Tensor (por exemplo, arquiteturas Volta, Turing, Ampere).
Configuração de Python e CUDA: Um ambiente de Python funcional com o PyTorch instalado e o CUDA configurado para aceleração por GPU.
Visão Geral da Precisão Mista
Como a maioria dos frameworks de aprendizado profundo, o PyTorch normalmente é treinado com dados de ponto flutuante de 32 bits (FP32). No entanto, o FP32 nem sempre é necessário para o sucesso. É possível usar um ponto flutuante de 16 bits para algumas operações, onde o FP32 consome mais tempo e memória.
Como resultado, os engenheiros da NVIDIA desenvolveram uma técnica que permite que o treinamento com precisão mixta seja realizado em FP32 para uma pequena quantidade de operações enquanto a maior parte da rede funciona em FP16.
- Converta o modelo para utilizar o tipo de dado float16 onde possível.
- Manter os pesos mestres float32 para acumular atualizações de peso em cada iteração.
- O uso de escalonamento de perda para preservar valores de gradiente pequenos.
Treinamento com Precisão mista em PyTorch
Para treinamento com precisão mista, o PyTorch já oferece um conjunto rico de recursos integrados.
Os parâmetros de um módulo são convertidos para FP16 quando você chama o método .half()
, e os dados de um tensor são convertidos para FP16 quando você chama .half()
. Serão usadas aritméticas rápidas FP16 para executar quaisquer operações nestes módulos ou tensores. As bibliotecas de matemática NVIDIA (cuBLAS e cuDNN) são bem suportadas pelo PyTorch. Os dados da pipeline FP16 são processados usando Cores de Tensor para realizar GEMMs e convoluções. Para empregar Cores de Tensor em cuBLAS, as dimensões de um GEMM ([M, K] x [K, N] -> [M, N]) devem ser múltiplos de 8.
Introduzindo Apex
As ferramentas de precisão mista do Apex são projetadas para aumentar a velocidade de treinamento enquanto mantêm a precisão e a estabilidade do treinamento de single-precision. O Apex pode executar operações em FP16 ou FP32, manipular automaticamente a conversão de parâmetros mestres, e escalar automaticamente as perdas.
Apex foi criado para tornar mais fácil para investigadores incluir treinamento de precisão mista em seus modelos. Amp, abreviatura de Automatic Mixed-Precision, é uma das funcionalidades do Apex, uma extensão leve do PyTorch. Algumas linhas adicionais em suas redes são o que os usuários precisam para se beneficiar do treinamento de precisão mista com Amp. O Apex foi lançado em CVPR 2018, e vale a pena notar que a comunidade PyTorch mostrou forte apoio para o Apex desde o seu lançamento.
Alterando o modelo de execução apenas de forma menor, Amp torna o processo de criação ou execução de seu script livre de preocupações com tipos mistos. As suposições de Amp podem não se ajustar tão bem em modelos que utilizam o PyTorch de maneiras menos comuns, mas há ganchos para ajustar essas suposições conforme necessário.
O Amp oferece todos os benefícios do treinamento de precisão mista sem a necessidade de escalonamento de perda ou conversões de tipo que devem ser gerenciadas explicitamente. O site do GitHub do Apex contém instruções para o processo de instalação, e a documentação oficial da API pode ser encontrada aqui.
Como as Amps Funcionam
O Amp utiliza um paradigma de whitelist/blacklist ao nível lógico. As operações de tensor no PyTorch incluem funções de rede neural, como torch.nn.functional.conv2d, funções matemáticas simples, como torch.log, e métodos de tensor, como torch.Tensor.add__. Existem três categorias principais de funções neste universo:
- Whitelist: Funções que poderiam beneficiar do speedup de matemática de FP16. Aplicações típicas incluem multiplicação de matrizes e convolução.
- Blacklist: Os inputs devem estar em FP32 para funções onde 16 bits de precisão pode não ser o suficiente.
- Todo o resto (quaisquer funções que sobraram): Funções que podem executar em FP16, mas o custo de um cast de FP32 -> FP16 para executá-las em FP16 não é válido, já que o speedup não é significativo.
A tarefa do Amp é simples, pelo menos em teoria. O Amp determina se uma função do PyTorch está whitelistada, blacklistada ou nem uma nem outra antes de a chamar. Todos os argumentos devem ser convertidos para FP16 se whitelistados ou FP32 se blacklistados. Se nenhum, apenas certifique-se que todos os argumentos forem do mesmo tipo. Esta política não é tão simples de aplicar na prática quanto parece.
Usando Amp em conjunto com um modelo PyTorch
Para incluir o Amp the um script PyTorch atual, siga estes passos:
- Use a biblioteca Apex para importar o Amp.
- Inicialize o Amp para que ele faça as mudanças necessárias no modelo, no otimizador e nas funções internas de PyTorch.
- Note onde a backpropagation (.backward()) ocorre para que o Amp possa escalar a perda e limpar o estado de iteração por vez.
Passo 1
Há apenas uma linha de código para o primeiro passo:
Passo 2
O modelo de rede neural e o otimizador usados para treinamento devem já estar especificados para concluir este passo, que é apenas uma linha de comprimento.
Configurações adicionais permitem que você ajuste as tensões e tipos de operação do Amp. A função amp.initialize() aceita muitos parâmetros, entre os quais vamos especificar três deles:
- models (torch.nn.Module ou lista de torch.nn.Modules) – Modelos para modificar/cast.
- optimizadores (opcional, torch.optim.Optimizer ou lista de torch.optim.Optimizers) – Optimizadores para modificar/cast. OBRIGATÓRIO para treinamento, opcional para inferência.
- opt_level (str, opcional, padrão=“O1”) – Nível de otimização de precisão pura ou mista. Valores aceitos são “O0”, “O1”, “O2” e “O3”, explicados em detalhes acima. Há quatro níveis de otimização:
O0 para treinamento de FP32: Isto é um no-op. Não precisa se preocupar com isso pois o seu modelo de entrada deve já ser em FP32, e O0 pode ajudar a estabelecer um baseline para a precisão.
O1 para Precisão Mista (recomendado para uso típico): Modifique todos os métodos de Tensor e Torch para usar um esquema de casting de input whitelist-blacklist. Em FP16, operações whitelist, como por exemplo as operações amigas do Tensor Core como GEMMs e convoluções são executadas. A softmax, por exemplo, é uma operação blacklist que exige precisão FP32. A menos que explicitamente afirmado o contrário, O1 também usa escalonamento dinâmico de perda.
O2 para “Almost FP16” Precisão Mista: O2 casting os pesos do modelo para FP16, patchando o método forward do modelo para casting dos dados de entrada para FP16, mantendo batchorns em FP32, mantendo os master weights em FP32, atualizando o param_groups do otimizador para que o otimizador.step() agisse diretamente nos pesos em FP32 e implementando escalonamento dinâmico de perda (a menos que seja sobrescrito). Diferentemente de O1, O2 não patcha funções de Torch ou métodos de Tensor.
O3 para treinamento com FP16: O3 pode não ser tão estável quanto O1 e O2 em relação à verdadeira precisão mista. Consequentemente, pode ser vantajoso definir uma velocidade de base para seu modelo, contra a qual a eficiência de O1 e O2 possa ser avaliada.
A propriedade extra sobridefinida keep_batchnorm_fp32=True em O3 pode ajudá-lo a determinar a “velocidade da luz” se seu modelo usar normalização em lotes, permitindo a normalização em lotes cudnn.
O0 e O3 não são verdadeiras precisão mista, mas eles ajudam a definir bases de precisão e velocidade, respectivamente. Uma implementação de precisão mista é definida como O1 e O2.
Você pode tentar ambas e ver qual melhora mais a performance e a precisão de seu modelo particular.
Step 3
Certifique-se de identificar onde ocorre a passagem de trás em seu código.
Algumas linhas de código que parecem com isso aparecerão:
loss = criterion(…)
loss.backward()
optimizer.step()
Step4
Usando o gerenciador de contexto Amp, você pode habilitar a escala de perda simplesmente envolvendo a passagem de trás:
loss = criterion(…)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
É tudo. Agora você pode executar seu script com o treinamento de precisão mista habilitado.
Capturando Chamadas de Função
O PyTorch não possui um objeto de modelo estático ou gráfico para se agarrar e inserir as conversões mencionadas acima, pois é flexível e dinâmico. Através de “monkey patching” das funções necessárias, o Amp pode interceptar e converter parâmetros dinamicamente.
Por exemplo, você pode usar o código abaixo para garantir que os argumentos para o método torch.nn.functional.linear sempre forem convertidos para fp16:
orig_linear = torch.nn.functional.linear
def wrapped_linear(*args):
casted_args = []
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
casted_args.append(torch.cast(arg, torch.float16))
else:
casted_args.append(arg)
return orig_linear(*casted_args)
torch.nn.functional.linear = wrapped_linear
Embora o Amp possa adicionar refinações para tornar o código mais resistente, chamar Amp.init() efetivamente causa inserções de monkey patches em todas as funções relevantes do PyTorch para que os argumentos sejam corretamente convertidos em tempo de execução.
Minimizando Conversões
Cada peso é convertido FP32 -> FP16 apenas uma vez por iteração, pois o Amp mantém um cache interno de todas as conversões de parâmetro e reutiliza-os conforme necessário. Em cada iteração, o gerenciador de contexto para o passo de retrocesso indica ao Amp quando limpar o cache.
Autocasting e Escala de Gradientes Usando o PyTorch
“Treinamento de precisão mista automatizado” se refere à combinação de torch.cuda.amp.autocast e torch.cuda.amp.GradScaler. Usando torch.cuda.amp.autocast, você pode configurar o autocasting apenas para certas áreas. O autocasting seleciona automaticamente a precisão para as operações GPU para otimizar a eficiência enquanto mantém a precisão.
As instâncias de torch.cuda.amp.GradScaler tornam as etapas de escalonamento de gradientes mais fáceis de executar. O escalonamento de gradientes reduz o underflow de gradientes, ajudando as redes com gradientes de float16 a alcançar melhor convergência.
Aqui está um exemplo de código que demonstra como usar autocast() para obter precisão mista automatizada no PyTorch:
# Cria modelo e otimizador em precisão padrão
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# Cria um GradScaler uma vez no início do treinamento.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Executa a passagem de frente com autocasting.
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
# As operações de retrocesso são executadas com a mesma dtype que o autocast escolheu para as operações de frente correspondentes.
scaler.scale(loss).backward()
# scaler.step() primeiro desescala os gradientes dos parâmetros atribuídos ao otimizador.
scaler.step(optimizer)
# Atualiza o escala para a próxima iteração.
scaler.update()
Se a passagem para frente de uma operação específica tiver entradas de float16, então a passagem para trás dessa operação produz gradientes de float16, e o float16 pode não conseguir expressar gradientes com pequenas magnitude.
As atualizações para os parâmetros relacionados serão perdidas se esses valores forem esvaziados para zero (“”underflow””).
A escalonamento de gradiente é uma técnica que usa um fator de escala para multiplicar as perdas da rede e então realizar a passagem para trás no escalonado de perda para evitar o underflow. Também é necessário escalonar os gradientes que fluem para trás pela rede com esse mesmo fator. Consequentemente, os valores dos gradientes têm uma magnitude maior, o que os impede de esvaziar para zero.
Antes de atualizar os parâmetros, cada gradiente de parâmetro (atributo .grad) deve ser desescalado para que o fator de escala não interfira com a taxa de aprendizagem. Tanto autocast quanto GradScaler podem ser usados independentemente, já que são módulos.
Trabalhando com Gradientes Não escalonados
Clipping de Gradientes
Toda a escalação dos gradientes pode ser feita usando o método Scaler.scale(Loss).backward()
. As propriedades .grad
dos parâmetros entre backward()
e scaler.step(optimizer)
devem ser desescaladas antes de serem modificadas ou inspecionadas. Se você quiser limitar a norma global (veja torch.nn.utils.clip_grad_norm_()) ou a magnitude máxima (veja torch.nn.utils.clip_grad_value_()) do conjunto de gradientes para ser menor ou igual a um certo valor (um limiar imposto pelo usuário), você pode usar uma técnica chamada “clipping de gradientes”
. O clipping sem desescalar resultaria na norma/magnitude máxima dos gradientes sendo escalada, invalidando seu limiar solicitado (que era esperado para ser o limiar para gradientes desescalados). Os gradientes contidos pelos parâmetros fornecidos ao optimizador são desescalados por scaler.unscale (optimizer).
Você pode desescalar os gradientes de outros parâmetros que foram anteriormente fornecidos the outro optimizer (como optimizer1) usando scaler.unscale (optimizer1). Podemos ilustrar esse conceito adicionando duas linhas de código:
# Desescala os gradientes dos parâmetros atribuídos ao optimizador em tempo de execução
scaler.unscale_(optimizer)
# Como os gradientes dos parâmetros atribuídos ao optimizador já foram desescalados, faz o clip como de costume:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
Trabalhando com Gradientes Escalados
A acumulação de gradientes
A acumulação de gradientes é baseada em um conceito absurdamente básico. Em vez de atualizar os parâmetros do modelo, ela aguarda e acumula os gradientes através de lotes sucessivos para calcular a perda e o gradiente.
Após um determinado número de lotes, os parâmetros são atualizados dependendo do gradiente acumulado. Aqui está um trecho de código sobre como usar acumulação de gradientes usando pytorch:
scaler = GradScaler()
for epoch in epochs:
for i, (input, target) in enumerate(data):
with autocast():
output = model(input)
loss = loss_fn(output, target)
# normalizar a perda
loss = loss / iters_to_accumulate
# Acumula gradientes escalados.
scaler.scale(loss).backward()
# atualização de pesos
if (i + 1) % iters_to_accumulate == 0:
# pode desescalar aqui se desejado
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
- A acumulação de gradientes adiciona gradientes através de um tamanho de lote adequado de batch_per_iter * iters_to_accumulate.
A escala deve ser calibrada para o lote efetivo; isso envolve verificar para graduações de inf/NaN, pular o passo se qualquer inf/NaN for detectado e atualizar a escala à granularidade do lote efetivo.
Também é crucial manter os grads em uma escala escalada e consistente quando são adicionados grads para um lote efetivo particular.
Se as graduações forem não escalonadas (ou o fator de escala muda) antes da acumulação estiver completa, a próxima passagem de retrocesso adicionará graduações escalonadas a graduações não escalonadas (ou graduações escalonadas por um fator diferente), depois da qual é impossível recuperar as graduações não escalonadas acumuladas que devem ser aplicadas.
- Você pode desescalonar graduações usando unscale pouco antes do passo, depois que todas as graduações escalonadas para o próximo passo tenham sido acumuladas.
Para garantir um lote efetivo completo, chame simplesmente update no final de cada iteração onde você chamou anteriormente step - enumerate(data) function permite que nós mantenhamos o rastreamento do índice do lote enquanto iteramos pelos dados.
- Divida a perda em execução por iters_to_accumulate(loss / iters_to_accumulate). Isso reduz a contribuição de cada mini-batch que estamos processando, normalizando a perda. Se você calcular a média da perda dentro de cada lote, a divisão já está correta e nenhuma normalização adicional é necessária. Este passo pode não ser necessário dependendo de como você calcula a perda.
- Quando usamos
scaler.scale(loss).backward()
no PyTorch, as gradientes escaladas são acumuladas e armazenadas até chamarmosoptimizer.zero_grad()
.
Multiplicar perda por taxa de aprendizado
Ao implementar uma multa de gradiente, o torch.autograd.grad() é usado para construir gradientes, que são combinados para formar o valor da multa, e depois adicionados à perda. Uma multa L2 sem escalar ou autocasting é mostrada no exemplo abaixo.
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Cria gradientes
grad_prams = torch.autograd.grad(outputs=loss,
inputs=model.parameters(),
create_graph=True)
# Computa o termo de multa e o adiciona à perda
grad_norm = 0
for grad in grad_prams:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
loss.backward()
# Você pode cortar gradientes aqui
optimizer.step()
Os tensores fornecidos a torch.autograd.grad() devem ser escalados para implementar uma multa de gradiente. É necessário desescalar os gradientes antes de combiná-los para obter o valor da multa. since a computação do termo de multa faz parte da passagem de frente, deve acontecer dentro de um contexto de autocast.
Para a mesma multa L2, aqui é como ela se parece:
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast():
output = model(input)
loss = loss_fn(output, target)
# Realizar escalonamento de perda para a passagem de retrocesso de autograd.grad, resultando #scaled_grad_prams
scaled_grad_prams = torch.autograd.grad(outputs=scaler.scale(loss),
inputs=model.parameters(),
create_graph=True)
# Cria grad_prams antes de computar a penalidade (grad_prams deve ser #desescalado).
# Como nenhum otimizador possui scaled_grad_prams, é usada divisão convencional em vez de scaler.unscale_:
inv_scaled = 1./scaler.get_scale()
grad_prams = [p * inv_scaled for p in scaled_grad_prams]
# A termo de penalidade é calculado e adicionado à perda.
with autocast():
grad_norm = 0
for grad in grad_prams:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
# Aplica escalonamento à chamada de retrocesso.
# Acumula gradientes folhas escalonados corretamente.
scaler.scale(loss).backward()
# Você pode desescalar aqui
# step() e update() prosseguem como de costume.
scaler.step(optimizer)
scaler.update()
Trabalhando com Múltiplos Modelos, Perdas e Otimizadores
Scaler.scale deve ser chamado em cada perda na sua rede se tiver muitas delas.
Se você tiver muitos otimizadores em sua rede, você pode executar scaler.unscale em qualquer um deles, e você deve chamar scaler.step em cada um deles. No entanto, scaler.update deve ser usado apenas uma vez, depois do passo de todos os otimizadores usados nesta iteração:
scaler = torch.cuda.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer1.zero_grad()
optimizer2.zero_grad()
with autocast():
output1 = model1(input)
output2 = model2(input)
loss1 = loss_fn(2 * output1 + 3 * output2, target)
loss2 = loss_fn(3 * output1 - 5 * output2, target)
# Embora o retenção do gráfico seja desassociada do amp, está presente neste exemplo já que ambas as chamadas para backward() compartilham certas regiões do gráfico.
scaler.scale(loss1).backward(retain_graph=True)
scaler.scale(loss2).backward()
# Se você quiser visualizar ou ajustar os gradientes dos parâmetros que eles possuem, você pode especificar quais otimizadores obtêm desescalação explícita. .
scaler.unscale_(optimizer1)
scaler.step(optimizer1)
scaler.step(optimizer2)
scaler.update()
Cada otimizador verifica seus gradientes para infs/NaNs e faz um julgamento individual se deve ou não pular o passo. Alguns otimizadores podem pular o passo, enquanto outros podem não fazer isso. A omissão do passo acontece apenas uma vez a cada algumas centenas de iterações; portanto, não deveria afetar a convergência. Para modelos com vários otimizadores, você pode relatar o problema se você ver má convergência após adicionar a escala de gradientes.
Trabalhando com Múltiplos GPUs
Um dos problemas mais significativos com modelos de Aprendizagem Profunda é que eles estão se tornando demasiado grandes para serem treinados em um único GPU. Pode demorar muito tempo para treinar um modelo em um único GPU, e o treinamento multi-GPU é necessário para pronto os modelos o quanto antes possível. Um pesquisador bem-sucedido conseguiu encurtar o período de treinamento do ImageNet de duas semanas para 18 minutos ou treinar o mais amplo e avançado Transformer-XL em duas semanas, em vez de quatro anos.
DataParallel e DistributedDataParallel
Sem comprometer a qualidade, o PyTorch oferece a melhor combinação de facilidade de uso e controle. nn.DataParallel e nn.parallel.DistributedDataParallel são duas funcionalidades do PyTorch para distribuir o treinamento em vários GPUs. Você pode usar esses wrappers fáceis de usar e mudanças para treinar a rede em vários GPUs.
DataParallel em um único processo
Em um único computador, DataParallel ajuda a espalhar o treinamento sobre muitos GPUs.
Vamos olhar mais de perto como o DataParallel realmente funciona na prática.
Ao utilizar DataParallel para treinar uma rede neural, os seguintes estágios ocorrem:
- O mini-batch é dividido no GPU:0.
- Dividir e distribuir o mini-batch para todos os GPUs disponíveis.
- Copiar o modelo para os GPUs.
- O passo de encontro é executado em todos os GPUs.
- Calcular a perda em relação às saídas da rede no GPU:0, além de retornar as perdas aos vários GPUs. As gradientes devem ser calculadas em cada GPU.
- Soma de gradientes no GPU:0 e aplica o otimizador para atualizar o modelo.
Deverá ser observado que as preocupações discutidas aqui se aplicam unicamente a autocast. O comportamento do GradScaler permanece inalterado. Não importa se o torch.nn.DataParallel criar threads para cada dispositivo para executar a passagem de frente. O estado de autocast é comunicado em cada um deles, e o seguinte funcionará:
model = Model_m()
p_model = nn.DataParallel(model)
# Define o autocast na thread principal
with autocast():
# Haverá autocasting em p_model.
output = p_model(input)
# loss_fn também é autocast
loss = loss_fn(output)
DistributedDataParallel, um GPU por processo
A documentação para torch.nn.parallel.DistributedDataParallel recomenda o uso de um GPU por processo para melhor desempenho. Nesta situação, DistributedDataParallel não lança threads internamente; portanto, o uso de autocast e GradScaler não é afetado.
DistributedDataParallel, vários GPUs por processo
Aqui, torch.nn.parallel.DistributedDataParallel pode gerar uma thread secundária para executar a passagem de frente em cada dispositivo, como torch.nn.DataParallel. A solução é a mesma: aplicar autocast como parte do método forward do seu modelo para garantir que estiver habilitado em threads secundárias.
Conclusão
Neste artigo, nós :
- Iniciamos Apex.
- Vimos como o Amps funciona.
- Vimos como realizar escalonamento de gradiente, corte de gradiente, acumulação de gradiente e penáltia de gradiente.
- Vimos como podemos trabalhar com vários modelos, perdas e otimizadores.
- Vimos como podemos executar DataParallel em um único processo quando trabalhando com vários GPU.
Referências
https://developer.nvidia.com/blog/apex-pytorch-easy-mixed-precision-training/
https://nvidia.github.io/apex/amp.html
https://discuss.pytorch.org/t/accumulating-gradients/30020
https://towdatascience.com/how-to-scale-training-on-multiple-gpus-dae1041f49d2
Source:
https://www.digitalocean.com/community/tutorials/automatic-mixed-precision-using-pytorch