El Modelo de Difusión Anotado

'Annotated Diffusion Model'

En esta publicación de blog, echaremos un vistazo más profundo a los Modelos Probabilísticos de Difusión para Desruido (también conocidos como DDPM, modelos de difusión, modelos generativos basados en puntajes o simplemente autoencoders) ya que los investigadores han logrado resultados notables con ellos para generación de imágenes/sonido/video (in)condicional. Ejemplos populares (al momento de escribir esto) incluyen GLIDE y DALL-E 2 de OpenAI, Latent Diffusion de la Universidad de Heidelberg e ImageGen de Google Brain.

Repasaremos el artículo original de DDPM de (Ho et al., 2020), implementándolo paso a paso en PyTorch, basado en la implementación de Phil Wang, que a su vez se basa en la implementación original de TensorFlow. Hay que tener en cuenta que la idea de la difusión para la modelización generativa ya se introdujo en (Sohl-Dickstein et al., 2015). Sin embargo, pasó hasta (Song et al., 2019) (en la Universidad de Stanford) y luego (Ho et al., 2020) (en Google Brain) quienes mejoraron el enfoque de manera independiente.

Ten en cuenta que hay varias perspectivas sobre los modelos de difusión. Aquí, utilizamos la perspectiva de un modelo de variables latentes en tiempo discreto, pero asegúrate de consultar las otras perspectivas también.

¡Bien, vamos a sumergirnos!

from IPython.display import Image
Image(filename='assets/78_annotated-diffusion/ddpm_paper.png')

Primero, instalaremos e importaremos las bibliotecas necesarias (asumiendo que tienes PyTorch instalado).

!pip install -q -U einops datasets matplotlib tqdm

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

¿Qué es un modelo de difusión?

Un modelo de difusión (de desruido) no es tan complejo si lo comparas con otros modelos generativos como Flujos Normalizados, GANs o VAEs: todos convierten ruido de alguna distribución simple en una muestra de datos. Este también es el caso aquí, donde una red neuronal aprende a desruidar gradualmente los datos comenzando desde ruido puro.

Un poco más en detalle para las imágenes, la configuración consta de 2 procesos:

  • un proceso de difusión hacia adelante fijo (o predefinido) qq q de nuestra elección, que agrega gradualmente ruido gaussiano a una imagen hasta que terminas con ruido puro
  • un proceso de desruido de difusión hacia atrás aprendido p θ p_\theta p θ ​ , donde una red neuronal se entrena para desruidar gradualmente una imagen comenzando desde ruido puro hasta que terminas con una imagen real.

Tanto el proceso hacia adelante como el proceso hacia atrás indexado por tt t ocurren durante un número finito de pasos de tiempo TT T (los autores de DDPM usan T = 1000). Comienzas con t = 0t=0t = 0 donde muestreas una imagen real x 0 \mathbf{x}_0 x 0 ​ de la distribución de tus datos (digamos una imagen de un gato de ImageNet), y el proceso hacia adelante muestrea algo de ruido de una distribución gaussiana en cada paso de tiempo tt t , que se agrega a la imagen del paso de tiempo anterior. Dado un valor suficientemente grande de TT T y un horario bien comportado para agregar ruido en cada paso de tiempo, terminas con lo que se llama una distribución gaussiana isotrópica en t = Tt=Tt = T a través de un proceso gradual.

En forma más matemática

Escribamos esto de forma más formal, ya que finalmente necesitamos una función de pérdida tratable que nuestra red neuronal debe optimizar.

Sea qq q ( x 0 ) q(\mathbf{x}_0) q ( x 0 ​ ) la distribución de datos reales, digamos de “imágenes reales”. Podemos muestrear de esta distribución para obtener una imagen, x 0 ∼ q ( x 0 ) \mathbf{x}_0 \sim q(\mathbf{x}_0) x 0 ​ ∼ q ( x 0 ​ ) . Definimos el proceso de difusión hacia adelante qq q ( x t ∣ x t − 1 ) q(\mathbf{x}_t | \mathbf{x}_{t-1}) q ( x t ​ ∣ x t − 1 ​ ) que agrega ruido gaussiano en cada paso de tiempo tt t , de acuerdo con un horario de varianza conocido 0 < β 1 < β 2 < . . . < β T < 1 0 < \beta_1 < \beta_2 < … < \beta_T < 1 0 < β 1 ​ < β 2 ​ < . . . < β T ​ < 1 como q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) . q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 – \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}). q ( x t ​ ∣ x t − 1 ​ ) = N ( x t ​ ; 1 − β t ​ ​ x t − 1 ​ , β t ​ I ) .

Recordemos que una distribución normal (también llamada distribución gaussiana) está definida por 2 parámetros: una media μ \mu μ y una varianza σ 2 ≥ 0 \sigma^2 \geq 0 σ 2 ≥ 0 . Básicamente, cada nueva imagen (ligeramente más ruidosa) en el paso de tiempo t t t se extrae de una distribución gaussiana condicional con μ t = 1 − β t x t − 1 \mathbf{\mu}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} μ t ​ = 1 − β t ​ ​ x t − 1 ​ y σ t 2 = β t \sigma^2_t = \beta_t σ t 2 ​ = β t ​ , lo cual podemos hacer muestreando ϵ ∼ N ( 0 , I ) \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵ ∼ N ( 0 , I ) y luego estableciendo x t = 1 − β t x t − 1 + β t ϵ \mathbf{x}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} + \sqrt{\beta_t} \mathbf{\epsilon} x t ​ = 1 − β t ​ ​ x t − 1 ​ + β t ​ ​ ϵ .

Observa que los β t \beta_t β t ​ no son constantes en cada paso de tiempo t t t (de ahí el subíndice) — de hecho se define una llamada “planificación de la varianza”, que puede ser lineal, cuadrática, coseno, etc. como veremos más adelante (un poco como una planificación de la tasa de aprendizaje).

Entonces, comenzando desde x 0 \mathbf{x}_0 x 0 ​ , terminamos con x 1 , . . . , x t , . . . , x T \mathbf{x}_1, …, \mathbf{x}_t, …, \mathbf{x}_T x 1 ​ , . . . , x t ​ , . . . , x T ​ , donde x T \mathbf{x}_T x T ​ es ruido gaussiano puro si establecemos la planificación de manera apropiada.

Ahora, si conociéramos la distribución condicional p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ​ ∣ x t ​ ) , entonces podríamos ejecutar el proceso en reversa: muestreando algo de ruido gaussiano aleatorio x T \mathbf{x}_T x T ​ , y luego gradualmente “desruido” hasta obtener una muestra de la distribución real x 0 \mathbf{x}_0 x 0 ​ .

Sin embargo, no conocemos p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ​ ∣ x t ​ ) . Es inmanejable ya que requiere conocer la distribución de todas las imágenes posibles para poder calcular esta probabilidad condicional. Por lo tanto, vamos a aprovechar una red neuronal para aproximar (aprender) esta distribución de probabilidad condicional , llamémosla p θ ( x t − 1 ∣ x t ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) p θ ​ ( x t − 1 ​ ∣ x t ​ ) , con θ \theta θ siendo los parámetros de la red neuronal, actualizados mediante descenso de gradiente.

De acuerdo, entonces necesitamos una red neuronal para representar una distribución de probabilidad (condicional) del proceso inverso. Si asumimos que este proceso inverso también es gaussiano, entonces recordemos que cualquier distribución gaussiana está definida por 2 parámetros:

  • una media parametrizada por μ θ \mu_\theta μ θ ​ ;
  • una varianza parametrizada por Σ θ \Sigma_\theta Σ θ ​ ;

por lo tanto, podemos parametrizar el proceso como p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta (\mathbf{x}_{t},t)) p θ ​ ( x t − 1 ​ ∣ x t ​ ) = N ( x t − 1 ​ ; μ θ ​ ( x t ​ , t ) , Σ θ ​ ( x t ​ , t ) ) donde la media y la varianza también se condicionan en el nivel de ruido t t t .

Por lo tanto, nuestra red neuronal necesita aprender/representar la media y la varianza. Sin embargo, los autores de DDPM decidieron mantener la varianza fija y permitir que la red neuronal solo aprenda (represente) la media μ θ \mu_\theta μ θ ​ de esta distribución de probabilidad condicional. Según el artículo:

Primero, establecemos Σ θ ( x t , t ) = σ t 2 I \Sigma_\theta ( \mathbf{x}_t, t) = \sigma^2_t \mathbf{I} Σ θ ​ ( x t ​ , t ) = σ t 2 ​ I como constantes dependientes del tiempo no entrenadas. Experimentalmente, tanto σ t 2 = β t \sigma^2_t = \beta_t σ t 2 ​ = β t ​ como σ t 2 = β ~ t \sigma^2_t = \tilde{\beta}_t σ t 2 ​ = β ~ ​ t ​ (ver artículo) tuvieron resultados similares.

Posteriormente, esto se mejoró en el artículo de Improved diffusion models, donde una red neuronal también aprende la varianza de este proceso inverso, además de la media.

Por lo tanto, continuamos asumiendo que nuestra red neuronal solo necesita aprender/representar la media de esta distribución de probabilidad condicional.

Definir una función objetivo (reparametrizando la media)

Para derivar una función objetivo para aprender la media del proceso inverso, los autores observan que la combinación de q q q y p θ p_\theta p θ ​ se puede ver como un autoencoder variacional (VAE) (Kingma et al., 2013). Por lo tanto, se puede utilizar el <strong{límite inferior variacional} (también llamado ELBO) para minimizar el logaritmo negativo de la verosimilitud con respecto a la muestra de datos de la verdad básica x 0 \mathbf{x}_0 x 0 ​ (nos referimos al artículo de VAE para más detalles sobre ELBO). Resulta que el ELBO para este proceso es una suma de pérdidas en cada paso de tiempo t t t , L = L 0 + L 1 + . . . + L T L = L_0 + L_1 + … + L_T L = L 0 ​ + L 1 ​ + . . . + L T ​ . Debido a la construcción del proceso directo q q q y el proceso inverso, cada término (excepto L 0 L_0 L 0 ​ ) de la pérdida es en realidad la <strong{divergencia KL entre 2 distribuciones gaussianas} que se puede escribir explícitamente como una pérdida L2 con respecto a las medias.

Una consecuencia directa del proceso directo construido q q q, como se muestra por Sohl-Dickstein et al., es que podemos muestrear x t \mathbf{x}_t x t ​ en cualquier nivel de ruido arbitrario condicionado a x 0 \mathbf{x}_0 x 0 ​ (ya que las sumas de gaussianas también son gaussianas). Esto es muy conveniente: no necesitamos aplicar q q q repetidamente para muestrear x t \mathbf{x}_t x t ​. Tenemos que q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(\mathbf{x}_t | \mathbf{x}_0) = \cal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1- \bar{\alpha}_t) \mathbf{I}) q ( x t ​ ∣ x 0 ​ ) = N ( x t ​ ; α ˉ t ​ ​ x 0 ​ , ( 1 − α ˉ t ​ ) I )

con α t : = 1 − β t \alpha_t := 1 – \beta_t α t ​ : = 1 − β t ​ y α ˉ t : = Π s = 1 t α s \bar{\alpha}_t := \Pi_{s=1}^{t} \alpha_s α ˉ t ​ : = Π s = 1 t ​ α s ​ . Llamemos a esta ecuación la “propiedad agradable”. Esto significa que podemos muestrear ruido gaussiano y escalarlo adecuadamente y agregarlo a x 0 \mathbf{x}_0 x 0 ​ para obtener x t \mathbf{x}_t x t ​ directamente. Tenga en cuenta que α ˉ t \bar{\alpha}_t α ˉ t ​ son funciones del horario de varianza β t \beta_t β t ​ conocido y, por lo tanto, también son conocidos y se pueden precalcular. Esto nos permite, durante el entrenamiento, <strong{optimizar términos aleatorios de la función de pérdida L L L} (o en otras palabras, muestrear aleatoriamente t t t durante el entrenamiento y optimizar L t L_t L t ​).

Otra belleza de esta propiedad, como se muestra en Ho et al., es que uno puede (después de algunos cálculos, a los cuales referimos al lector a esta excelente publicación de blog) en su lugar reparametrizar la media para hacer que la red neuronal aprenda (prediga) el ruido añadido (a través de una red ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) ϵ θ ​ ( x t ​ , t ) ) para el nivel de ruido t t t en los términos de KL que constituyen las pérdidas. Esto significa que nuestra red neuronal se convierte en un predictor de ruido, en lugar de un predictor de media (directo). La media se puede calcular de la siguiente manera:

μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \mathbf{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t – \frac{\beta_t}{\sqrt{1- \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right) μ θ ​ ( x t ​ , t ) = α t ​ ​ 1 ​ ( x t ​ − 1 − α ˉ t ​ ​ β t ​ ​ ϵ θ ​ ( x t ​ , t ) )

La función objetivo final L t L_t L t ​ se ve entonces de la siguiente manera (para un paso de tiempo aleatorio t t t dado ϵ ∼ N ( 0 , I ) \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵ ∼ N ( 0 , I ) ):

∥ ϵ − ϵ θ ( x t , t ) ∥ 2 = ∥ ϵ − ϵ θ ( α ˉ t x 0 + ( 1 − α ˉ t ) ϵ , t ) ∥ 2 . \| \mathbf{\epsilon} – \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \|^2 = \| \mathbf{\epsilon} – \mathbf{\epsilon}_\theta( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{(1- \bar{\alpha}_t) } \mathbf{\epsilon}, t) \|^2. ∥ ϵ − ϵ θ ​ ( x t ​ , t ) ∥ 2 = ∥ ϵ − ϵ θ ​ ( α ˉ t ​ ​ x 0 ​ + ( 1 − α ˉ t ​ ) ​ ϵ , t ) ∥ 2 .

Aquí, x 0 \mathbf{x}_0 x 0 ​ es la imagen inicial (real, no corrompida) y vemos el nivel de ruido directo t t t dado por el proceso directo fijo. ϵ \mathbf{\epsilon} ϵ es el ruido puro muestreado en el paso de tiempo t t t, y ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta (\mathbf{x}_t, t) ϵ θ ​ ( x t ​ , t ) es nuestra red neuronal. La red neuronal se optimiza utilizando un error cuadrático medio (MSE) entre el ruido gaussiano verdadero y el predicho.

El algoritmo de entrenamiento se ve de la siguiente manera:

En otras palabras:

  • tomamos una muestra aleatoria x 0 \mathbf{x}_0 x 0 ​ de la distribución de datos reales desconocida y posiblemente compleja q ( x 0 ) q(\mathbf{x}_0) q ( x 0 ​ )
  • tomamos una nivel de ruido t t t uniformemente entre 1 1 1 y T T T (es decir, un paso de tiempo aleatorio)
  • muestreamos algo de ruido de una distribución gaussiana y corrompemos la entrada con este ruido en el nivel t t t (usando la propiedad definida anteriormente)
  • se entrena la red neuronal para predecir este ruido basado en la imagen corrompida x t \mathbf{x}_t x t ​ (es decir, ruido aplicado a x 0 \mathbf{x}_0 x 0 ​ basado en el programa conocido β t \beta_t β t ​ )

En realidad, todo esto se hace en lotes de datos, ya que se utiliza el descenso de gradiente estocástico para optimizar las redes neuronales.

La red neuronal

La red neuronal debe tomar una imagen con ruido en un paso de tiempo específico y devolver el ruido predicho. Tenga en cuenta que el ruido predicho es un tensor que tiene el mismo tamaño/resolución que la imagen de entrada. Por lo tanto, técnicamente, la red recibe y devuelve tensores de la misma forma. ¿Qué tipo de red neuronal podemos usar para esto?

Lo que se utiliza típicamente aquí es muy similar a un Autoencoder, que es posible que recuerde de tutoriales típicos de “introducción al aprendizaje profundo”. Los Autoencoders tienen una capa llamada “cuello de botella” entre el codificador y el decodificador. El codificador primero codifica una imagen en una representación oculta más pequeña llamada “cuello de botella”, y luego el decodificador decodifica esa representación oculta en una imagen real. Esto obliga a la red a mantener solo la información más importante en la capa del cuello de botella.

En términos de arquitectura, los autores de DDPM optaron por una U-Net, presentada por (Ronneberger et al., 2015) (que en ese momento logró resultados de vanguardia en la segmentación de imágenes médicas). Esta red, al igual que cualquier autoencoder, consta de un cuello de botella en el medio que se asegura de que la red aprenda solo la información más importante. Es importante destacar que se introdujeron conexiones residuales entre el codificador y el decodificador, mejorando en gran medida el flujo de gradientes (inspirado en ResNet de He et al., 2015).

Como se puede ver, un modelo U-Net primero reduce el tamaño de la entrada (es decir, hace que la entrada sea más pequeña en términos de resolución espacial), y luego se realiza un aumento de tamaño.

A continuación, implementamos esta red paso a paso.

Funciones auxiliares de la red

Primero, definimos algunas funciones y clases auxiliares que se utilizarán al implementar la red neuronal. Es importante destacar que se define un módulo Residual, que simplemente suma la entrada a la salida de una función en particular (en otras palabras, agrega una conexión residual a una función específica).

También definimos alias para las operaciones de aumento y reducción de tamaño.

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    # No More Strided Convolutions or Pooling
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )

Incrustaciones de posición

Dado que los parámetros de la red neuronal se comparten a lo largo del tiempo (nivel de ruido), los autores utilizan incrustaciones sinusoidales de posición para codificar t t t, inspiradas en el Transformer (Vaswani et al., 2017). Esto hace que la red neuronal “sepa” en qué paso de tiempo particular (nivel de ruido) está funcionando, para cada imagen en un lote.

El módulo SinusoidalPositionEmbeddings toma un tensor de forma (batch_size, 1) como entrada (es decir, los niveles de ruido de varias imágenes ruidosas en un lote) y lo convierte en un tensor de forma (batch_size, dim), donde dim es la dimensionalidad de las incrustaciones de posición. Esto se agrega a cada bloque residual, como veremos más adelante.

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

Bloque ResNet

A continuación, definimos el bloque de construcción principal del modelo U-Net. Los autores de DDPM emplearon un bloque de ResNet amplio (Zagoruyko et al., 2016), pero Phil Wang ha reemplazado la capa convolucional estándar por una versión “estandarizada de peso”, que funciona mejor en combinación con la normalización de grupo (ver (Kolesnikov et al., 2019) para más detalles).

class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    La estandarización de pesos supuestamente funciona de manera sinérgica con la normalización de grupo
    """

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)

Módulo de atención

A continuación, definimos el módulo de atención, que los autores de DDPM agregaron entre los bloques convolucionales. La atención es el bloque de construcción de la famosa arquitectura Transformer (Vaswani et al., 2017), que ha mostrado gran éxito en diversos campos de la IA, desde el procesamiento del lenguaje natural y la visión hasta el plegamiento de proteínas. Phil Wang utiliza 2 variantes de atención: una es la autoatención multi-cabeza regular (como se usa en el Transformer), la otra es una variante de atención lineal (Shen et al., 2018), cuyos requisitos de tiempo y memoria escalan linealmente en la longitud de la secuencia, a diferencia de la atención regular que escala cuadráticamente.

Para una explicación detallada del mecanismo de atención, remitimos al lector a la maravillosa publicación en el blog de Jay Allamar.

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

Normalización de grupo

Los autores de DDPM intercalan las capas de convolución/atención de la U-Net con normalización de grupo (Wu et al., 2018). A continuación, definimos una clase PreNorm, que se utilizará para aplicar la groupnorm antes de la capa de atención, como veremos más adelante. Es importante destacar que ha habido un debate sobre si aplicar la normalización antes o después de la atención en los Transformers.

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

U-Net condicional

Ahora que hemos definido todos los bloques de construcción (incrustaciones de posición, bloques ResNet, atención y normalización de grupo), es hora de definir la red neuronal completa. Recordemos que el objetivo de la red ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) ϵ θ ​ ( x t ​ , t ) es tomar un lote de imágenes ruidosas y sus respectivos niveles de ruido como entrada, y devolver el ruido añadido a la imagen de entrada. De manera más formal:

  • la red recibe como entrada un lote de imágenes ruidosas de forma (tamaño_lote, num_canales, altura, ancho) y un lote de niveles de ruido de forma (tamaño_lote, 1), y devuelve un tensor de forma (tamaño_lote, num_canales, altura, ancho)

La red se construye de la siguiente manera:

  • primero, se aplica una capa convolucional en el lote de imágenes ruidosas, y se calculan las incrustaciones de posición para los niveles de ruido
  • a continuación, se aplican una serie de etapas de submuestreo. Cada etapa de submuestreo consta de 2 bloques ResNet + groupnorm + atención + conexión residual + una operación de submuestreo
  • en el centro de la red, nuevamente se aplican bloques ResNet, intercalados con atención
  • a continuación, se aplican una serie de etapas de sobremuestreo. Cada etapa de sobremuestreo consta de 2 bloques ResNet + groupnorm + atención + conexión residual + una operación de sobremuestreo
  • finalmente, se aplica un bloque ResNet seguido de una capa convolucional.

En última instancia, las redes neuronales apilan capas como si fueran bloques de lego (pero es importante entender cómo funcionan).

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # determinar dimensiones
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # cambiado a 1 y 0 desde 7,3

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # incrustaciones de tiempo
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # capas
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

Definiendo el proceso de difusión directa

El proceso de difusión directa añade gradualmente ruido a una imagen desde la distribución real, en un número de pasos de tiempo T T T . Esto sucede de acuerdo a un programa de varianza . Los autores originales de DDPM emplearon un programa lineal:

Establecemos las varianzas del proceso directo en constantes que aumentan linealmente desde β 1 = 1 0 − 4 \beta_1 = 10^{−4} β 1 ​ = 1 0 − 4 hasta β T = 0.02 \beta_T = 0.02 β T ​ = 0 . 0 2 .

Sin embargo, se demostró en (Nichol et al., 2021) que se pueden obtener mejores resultados al emplear un programa coseno.

A continuación, definimos varios programas para los pasos de tiempo T T T (elegiremos uno más adelante).

def programa_beta_coseno(pasos_de_tiempo, s=0.008):
    """
    Programa coseno propuesto en https://arxiv.org/abs/2102.09672
    """
    pasos = pasos_de_tiempo + 1
    x = torch.linspace(0, pasos_de_tiempo, pasos)
    alphas_cumprod = torch.cos(((x / pasos_de_tiempo) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def programa_beta_lineal(pasos_de_tiempo):
    beta_inicio = 0.0001
    beta_fin = 0.02
    return torch.linspace(beta_inicio, beta_fin, pasos_de_tiempo)

def programa_beta_cuadratico(pasos_de_tiempo):
    beta_inicio = 0.0001
    beta_fin = 0.02
    return torch.linspace(beta_inicio**0.5, beta_fin**0.5, pasos_de_tiempo) ** 2

def programa_beta_sigmoidal(pasos_de_tiempo):
    beta_inicio = 0.0001
    beta_fin = 0.02
    betas = torch.linspace(-6, 6, pasos_de_tiempo)
    return torch.sigmoid(betas) * (beta_fin - beta_inicio) + beta_inicio

Para empezar, usemos el programa lineal para T = 300 pasos de tiempo y definamos las diversas variables a partir de los β t \beta_t β t ​ que necesitaremos, como el producto acumulativo de las varianzas α ˉ t \bar{\alpha}_t α ˉ t ​ . Cada una de las variables a continuación son tensores unidimensionales, que almacenan valores desde t t t hasta T T T . Es importante destacar que también definimos una función extract, que nos permitirá extraer el índice t t t apropiado para un lote de índices.

pasos_de_tiempo = 300

# definir programa beta
betas = programa_beta_lineal(pasos_de_tiempo=pasos_de_tiempo)

# definir alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# cálculos para la difusión q(x_t | x_{t-1}) y otros
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# cálculos para la posterior q(x_{t-1} | x_t, x_0)
varianza_posterior = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    tamaño_lote = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(tamaño_lote, *((1,) * (len(x_shape) - 1))).to(t.device)

Ilustraremos con una imagen de gatos cómo se añade ruido en cada paso de tiempo del proceso de difusión.

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
imagen = Image.open(requests.get(url, stream=True).raw) # imagen PIL de forma HWC
imagen

El ruido se agrega a los tensores de PyTorch, en lugar de las imágenes de Pillow. Primero definiremos transformaciones de imagen que nos permitan convertir una imagen PIL a un tensor de PyTorch (al que podemos agregar el ruido) y viceversa.

Estas transformaciones son bastante simples: primero normalizamos las imágenes dividiendo por 255 (para que estén en el rango [0, 1]) y luego nos aseguramos de que estén en el rango [-1, 1]. Según el artículo de DPPM:

Suponemos que los datos de imagen consisten en enteros en {0, 1, …, 255} escalados linealmente a [-1, 1]. Esto asegura que el proceso de reversión de la red neuronal opere sobre entradas escaladas de manera consistente, comenzando desde la distribución normal estándar prior p (xT).

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(), # convertir a tensor de PyTorch de forma CHW, dividir por 255
    Lambda(lambda t: (t * 2) - 1),
    
])

x_start = transform(image).unsqueeze(0)
x_start.shape

También definimos la transformación inversa, que toma un tensor de PyTorch que contiene valores en el rango [-1, 1] y los convierte de nuevo en una imagen PIL:

import numpy as np

reverse_transform = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW a HWC
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])

Verifiquemos esto:

reverse_transform(x_start.squeeze())

Ahora podemos definir el proceso de difusión directa como en el artículo:

# difusión directa (usando la propiedad agradable)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

Probémoslo en un paso de tiempo particular:

def get_noisy_image(x_start, t):
  # agregar ruido
  x_noisy = q_sample(x_start, t=t)

  # convertir de nuevo a imagen PIL
  noisy_image = reverse_transform(x_noisy.squeeze())

  return noisy_image

# tomar paso de tiempo
t = torch.tensor([40])

get_noisy_image(x_start, t)

Visualicémoslo para varios pasos de tiempo:

import matplotlib.pyplot as plt

# usar semilla para reproducibilidad
torch.manual_seed(0)

# fuente: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Crear una cuadrícula 2D incluso si solo hay 1 fila
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Imagen original')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

Esto significa que ahora podemos definir la función de pérdida dada el modelo de la siguiente manera:

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

El denoise_model será nuestro U-Net definido anteriormente. Utilizaremos la pérdida de Huber entre el ruido verdadero y el ruido predicho.

Definir un Dataset + DataLoader de PyTorch

Aquí definimos un Dataset regular de PyTorch. El Dataset simplemente consiste en imágenes de un conjunto de datos real, como Fashion-MNIST, CIFAR-10 o ImageNet, escaladas linealmente a [ − 1 , 1 ] [−1, 1] [ − 1 , 1 ] .

Cada imagen se redimensiona al mismo tamaño. Es interesante destacar que las imágenes también se voltean horizontalmente al azar. Según el artículo:

Utilizamos volteos horizontales aleatorios durante el entrenamiento para CIFAR10; probamos entrenar tanto con como sin volteos, y encontramos que los volteos mejoran ligeramente la calidad de las muestras.

Aquí utilizamos la biblioteca 🤗 Datasets para cargar fácilmente el conjunto de datos Fashion MNIST desde el centro. Este conjunto de datos consta de imágenes que ya tienen la misma resolución, es decir, 28×28.

from datasets import load_dataset

# cargar el conjunto de datos desde el centro
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128

A continuación, definimos una función que aplicaremos sobre la marcha en todo el conjunto de datos. Utilizamos la funcionalidad with_transform para eso. La función simplemente aplica algunos procesamientos básicos de imágenes: volteos horizontales aleatorios, reescalado y finalmente los hace tener valores en el rango [ − 1 , 1 ] [-1,1] [ − 1 , 1 ] .

from torchvision import transforms
from torch.utils.data import DataLoader

# definir transformaciones de imagen (por ejemplo, utilizando torchvision)
transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

# definir función
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# crear dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

batch = next(iter(dataloader))
print(batch.keys())

Muestreo

Dado que vamos a muestrear del modelo durante el entrenamiento (con el fin de hacer un seguimiento del progreso), definimos el código para eso a continuación. El muestreo se resume en el artículo como el Algoritmo 2:

Generar nuevas imágenes a partir de un modelo de difusión implica invertir el proceso de difusión: empezamos desde T T T , donde muestreamos ruido puro de una distribución gaussiana, y luego utilizamos nuestra red neuronal para desruidizarlo gradualmente (usando la probabilidad condicional que ha aprendido), hasta llegar al paso de tiempo t = 0 t = 0 t = 0 . Como se muestra arriba, podemos derivar una imagen ligeramente menos desruidizada x t − 1 \mathbf{x}_{t-1 } x t − 1 ​ enchufando la reparametrización de la media, usando nuestro predictor de ruido. Recuerda que la varianza se conoce de antemano.

Idealmente, terminamos con una imagen que parece haber venido de la distribución de datos reales.

El código a continuación implementa esto.

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Ecuación 11 en el artículo
    # Utilizar nuestro modelo (predictor de ruido) para predecir la media
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algoritmo 2 línea 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# Algoritmo 2 (incluyendo devolver todas las imágenes)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # comenzar desde ruido puro (para cada ejemplo en el lote)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

Tenga en cuenta que el código anterior es una versión simplificada de la implementación original. Encontramos que nuestra simplificación (que está en línea con el Algoritmo 2 en el artículo) funciona tan bien como la implementación original más compleja, que emplea el recorte.

Entrenar el modelo

A continuación, entrenamos el modelo de manera habitual en PyTorch. También definimos cierta lógica para guardar periódicamente imágenes generadas, utilizando el método sample definido anteriormente.

from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

A continuación, definimos el modelo y lo movemos a la GPU. También definimos un optimizador estándar (Adam).

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

¡Comencemos el entrenamiento!

from torchvision.utils import save_image

epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algoritmo 1 línea 3: muestrear t uniformemente para cada ejemplo en el lote
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Pérdida:", loss.item())

      loss.backward()
      optimizer.step()

      # guardar imágenes generadas
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

Muestreo (inferencia)

Para muestrear del modelo, simplemente podemos usar nuestra función de muestra definida anteriormente:

# muestrear 64 imágenes
muestras = sample(model, image_size=image_size, batch_size=64, channels=channels)

# mostrar una al azar
índice_aleatorio = 5
plt.imshow(muestras[-1][índice_aleatorio].reshape(image_size, image_size, channels), cmap="gray")

¡Parece que el modelo es capaz de generar una bonita camiseta! Tenga en cuenta que el conjunto de datos en el que entrenamos es de baja resolución (28×28).

También podemos crear un gif del proceso de eliminación de ruido:

import matplotlib.animation as animation

índice_aleatorio = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(muestras[i][índice_aleatorio].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

Tenga en cuenta que el artículo de DDPM mostró que los modelos de difusión son una dirección prometedora para la generación de imágenes (in)condicionales. Desde entonces, esto se ha mejorado enormemente, especialmente para la generación de imágenes condicionadas por texto. A continuación, enumeramos algunos trabajos importantes (pero lejos de ser exhaustivos) de seguimiento:

  • Modelos probabilísticos de difusión mejorados para la eliminación de ruido (Nichol et al., 2021): encuentra que aprender la varianza de la distribución condicional (además de la media) ayuda a mejorar el rendimiento
  • Modelos de difusión en cascada para generación de imágenes de alta fidelidad (Ho et al., 2021): introduce la difusión en cascada, que comprende una secuencia de varios modelos de difusión que generan imágenes de resolución creciente para la síntesis de imágenes de alta fidelidad
  • Los modelos de difusión superan a las GAN en la síntesis de imágenes (Dhariwal et al., 2021): muestran que los modelos de difusión pueden lograr una calidad de muestra de imagen superior a los modelos generativos de vanguardia actual mejorando la arquitectura U-Net, así como introduciendo orientación de clasificadores
  • Orientación de difusión sin clasificador (Ho et al., 2021): muestra que no se necesita un clasificador para guiar un modelo de difusión mediante el entrenamiento conjunto de un modelo de difusión condicional y no condicional con una única red neuronal
  • Generación de imágenes textuales condicionadas jerárquicas con CLIP Latents (DALL-E 2) (Ramesh et al., 2022): utiliza una prioridad para convertir un título de texto en una incrustación de imagen CLIP, después de lo cual un modelo de difusión lo decodifica en una imagen
  • Modelos de difusión de texto a imagen fotorrealistas con comprensión profunda del lenguaje (ImageGen) (Saharia et al., 2022): muestra que combinar un modelo de lenguaje preentrenado grande (por ejemplo, T5) con difusión en cascada funciona bien para la síntesis de texto a imagen

Tenga en cuenta que esta lista solo incluye trabajos importantes hasta la fecha de escritura, que es el 7 de junio de 2022.

Por ahora, parece que la principal (tal vez única) desventaja de los modelos de difusión es que requieren múltiples pasadas hacia adelante para generar una imagen (lo cual no es el caso de los modelos generativos como GANs). Sin embargo, hay investigaciones en curso que permiten generar con alta fidelidad en tan solo 10 pasos de eliminación de ruido.

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Ciencia de Datos

El Desafío de Ver la Imagen Completa de la Inteligencia Artificial

Cada vez es más difícil tener conversaciones reflexivas sobre el cambiante (y rápidamente creciente) impacto de la IA...

Inteligencia Artificial

Conoce a MetaGPT El asistente de IA impulsado por ChatGPT que convierte texto en aplicaciones web.

¡Esta revolucionaria herramienta de IA te permite crear aplicaciones web sin código en solo segundos!

Inteligencia Artificial

Destilando lo que sabemos

Los investigadores buscan reducir el tamaño de los modelos GPT grandes.

Inteligencia Artificial

Miles en fila para obtener el implante de chip cerebral de Neuralink, de Elon Musk

El implante, diseñado para reemplazar la pieza del cráneo removida, leerá y analizará la actividad cerebral de la per...