Explicación del artículo de las Redes de Atención en Grafos con ilustración e implementación en PyTorch

Explicación de las Redes de Atención en Grafos con ilustración e implementación en PyTorch

Una descripción detallada e ilustrada del artículo “Graph Attention Networks” de Veličković et al. con la implementación en PyTorch del modelo propuesto.

Ilustración de la capa de paso de mensajes en una red de atención de gráficos — imagen del autor

Introducción

Las redes neuronales gráficas (GNN) son una poderosa clase de redes neuronales que operan en datos estructurados en forma de grafo. Aprenden representaciones (incrustaciones) de los nodos al agregar información del vecindario local de un nodo. Este concepto se conoce como ‘paso de mensajes’ en la literatura de aprendizaje de representaciones de gráficos.

Los mensajes (incrustaciones) se pasan entre nodos en el grafo a través de múltiples capas de la GNN. Cada nodo agrega los mensajes de sus vecinos para actualizar su representación. Este proceso se repite en todas las capas, permitiendo que los nodos obtengan representaciones que codifican información más rica sobre el grafo. Algunas de las variantes importantes de las GNN son GraphSAGE [2], Graph Convolution Network [3], etc. Puedes explorar más variantes de GNN aquí.

Ilustración simple de un paso único de paso de mensajes — imagen del autor

Redes de Atención de Gráficos (GAT) [1] son una clase especial de GNN que se propusieron para mejorar este esquema de paso de mensajes. Introdujeron un mecanismo de atención aprendible que permite a un nodo decidir qué nodos vecinos son más importantes al agregar mensajes de su vecindario local asignando un peso entre cada nodo de origen y destino en lugar de agregar información de todos los vecinos con pesos iguales.

Empíricamente, se ha demostrado que las Redes de Atención de Gráficos superan a muchos otros modelos de GNN en tareas como clasificación de nodos, predicción de enlaces y clasificación de gráficos. Han demostrado un rendimiento de vanguardia en varios conjuntos de datos de gráficos de referencia.

En esta publicación, recorreremos la parte crucial del artículo original “Graph Attention Networks” de Veličković et al. [1], explicaremos estas partes e implementaremos simultáneamente las nociones propuestas en el artículo utilizando el marco de trabajo PyTorch para comprender mejor la intuición del método GAT.

También puedes acceder al código completo utilizado en esta publicación, que contiene el código de entrenamiento y validación en este repositorio de GitHub

Leyendo el Artículo

Sección 1 — Introducción

Después de revisar ampliamente los métodos existentes en la literatura de aprendizaje de representaciones de gráficos en la Sección 1, “Introducción”, se presenta la Red de Atención de Gráficos (GAT). Los autores mencionan:

  1. Una visión general del mecanismo de atención incorporado.
  2. Tres propiedades de GAT, a saber, cálculo eficiente, aplicabilidad general a todos los nodos y utilidad en el aprendizaje inductivo.
  3. Referencias y conjuntos de datos en los que evaluaron el rendimiento de GAT.
sección seleccionada del artículo original de GAT

Luego, después de comparar su enfoque con algunos métodos existentes y mencionar las similitudes y diferencias generales entre ellos, pasan a la siguiente sección del artículo.

Sección 2 — Arquitectura GAT

En esta sección, que constituye la parte principal del artículo, se presenta en detalle la arquitectura de la Red de Atención de Gráficos. Para continuar con la explicación, supongamos que la arquitectura propuesta funciona en un grafo con N nodos (V = {vᵢ}; i=1,…,N) y cada nodo se representa con un vector hᵢ de F elementos, con cualquier configuración arbitraria de aristas existentes entre los nodos.

Ejemplo de gráfico de entrada - imagen por el autor

Los autores comienzan caracterizando una sola Capa de Atención de Gráficos y cómo opera, lo cual se convierte en los bloques de construcción de una Red de Atención de Gráficos. En general, se supone que una sola capa GAT toma un gráfico con sus representaciones de nodos (incrustaciones) dadas como entrada, propaga información a los nodos vecinos locales y produce una representación actualizada de los nodos.

sección seleccionada del artículo GAT original

Como se resalta anteriormente, para hacerlo, primero afirman que todos los vectores de características de los nodos de entrada (hᵢ) a la capa GA son transformados linealmente (es decir, multiplicados por una matriz de pesos W), en PyTorch, generalmente se hace de la siguiente manera:

Transformación lineal de las características de los nodos - imagen por el autor
import torchfrom torch import nn# in_features -> F y out_feature -> F'in_features = ...out_feature = ...# instanciar la matriz de pesos aprendibles W (FxF')W = nn.Parameter(torch.empty(size=(in_features, out_feature)))# Inicializar la matriz de pesos Wnn.init.xavier_normal_(W)# multiplicar W y h (h es la entrada de características de todos los nodos -> matriz NxF)h_transformed = torch.mm(h, W)

Ahora teniendo en cuenta que hemos obtenido una versión transformada de nuestras características de nodos de entrada (incrustaciones), avanzamos algunos pasos para observar y comprender cuál es nuestro objetivo final en una capa GAT.

Como se describe en el artículo, al final de una capa de atención de gráficos, para cada nodo i, necesitamos obtener un nuevo vector de características que sea más consciente de la estructura y el contexto de su vecindario.

Esto se logra calculando una suma ponderada de las características de los nodos vecinos seguida de una función de activación no lineal σ. Esta suma ponderada también se conoce como el ‘Paso de Agregación’ en las operaciones generales de capas GNN, según la literatura de Graph ML.

Estos pesos αᵢⱼ ∈ [0, 1] se aprenden y se calculan mediante un mecanismo de atención que denota la importancia de las características del vecino j para el nodo i durante el paso de mensajes y la agregación.

sección seleccionada del artículo GAT original

Ahora veamos cómo se calculan estos pesos de atención αᵢⱼ para cada par de nodos i y su vecino j:

En resumen, los pesos de atención αᵢⱼ se calculan de la siguiente manera:

sección seleccionada del artículo GAT original

Donde los eᵢⱼ son puntuaciones de atención y se aplica la función Softmax para que todos los pesos estén en el intervalo [0, 1] y sumen 1.

Los puntajes de atención eᵢⱼ se calculan ahora entre cada nodo i y sus vecinos j ∈ N a través de la función de atención a(…) de la siguiente manera:

Sección seleccionada del artículo original de GAT

Donde || denota la concatenación de dos incrustaciones de nodos transformados, y a es un vector de parámetros aprendibles (es decir, parámetros de atención) de tamaño 2 * F’ (el doble del tamaño de las incrustaciones transformadas).

Y el (aᵀ) es la transpuesta del vector a, lo que resulta en que toda la expresión aᵀ [Whᵢ|| Whⱼ] sea el producto punto (producto interno) entre “a” y la concatenación de las incrustaciones transformadas.

La operación completa se ilustra a continuación:

Cálculo de los puntajes de atención en GAT - imagen del autor

En PyTorch, para obtener estos puntajes, tomamos un enfoque ligeramente diferente. Debido a que es más eficiente calcular eᵢⱼ entre todas las parejas de nodos y luego seleccionar solo aquellas que representan aristas existentes entre nodos. Para calcular todos los eᵢⱼ:

# instanciar el vector de parámetros de atención aprendibles 'a'a = nn.Parameter(torch.empty(size=(2 * out_feature, 1)))# Inicializar el vector de parámetros 'a'nn.init.xavier_normal_(a)# hemos obtenido 'h_transformed' en el fragmento de código anterior# calcular el producto punto de todas las incrustaciones de nodos# y la primera mitad de los parámetros del vector de atención (correspondientes a los mensajes vecinos)source_scores = torch.matmul(h_transformed, self.a[:out_feature, :])# calcular el producto punto de todas las incrustaciones de nodos# y la segunda mitad de los parámetros del vector de atención (correspondientes al nodo objetivo)target_scores = torch.matmul(h_transformed, self.a[out_feature:, :])# realizar una suma por difusión e = source_scores + target_scores.Te = self.leakyrelu(e)

La última parte del fragmento de código (# suma por difusión) suma todos los puntajes fuente y objetivo uno a uno, lo que resulta en una matriz NxN que contiene todos los puntajes eᵢⱼ (ilustrado a continuación)

Cálculo paralelo vectorizado de los puntajes de atención entre todos los nodos en GAT - imagen del autor

Hasta ahora, hemos asumido que el grafo está completamente conectado y hemos calculado los puntajes de atención entre todas las posibles parejas de nodos. Para abordar esto, después de aplicar la activación LeakyReLU a los puntajes de atención, los puntajes de atención se enmascaran en función de las aristas existentes en el grafo, lo que significa que solo mantenemos los puntajes que corresponden a las aristas existentes.

Esto se puede lograr asignando un puntaje negativo grande (para aproximar -∞) a los elementos en la matriz de puntajes entre nodos sin aristas existentes para que sus pesos de atención correspondientes sean cero después de la función softmax.

Podemos lograr esto utilizando la matriz de adyacencia del grafo. La matriz de adyacencia es una matriz NxN con 1 en la fila i y columna j si hay una arista entre el nodo i y j, y 0 en cualquier otro lugar. Por lo tanto, creamos la máscara asignando -∞ a los elementos cero de la matriz de adyacencia y asignando 0 en cualquier otro lugar. Luego, agregamos la máscara a nuestra matriz de puntajes y aplicamos la función softmax a través de sus filas.

connectivity_mask = -9e16 * torch.ones_like(e)# adj_mat es la matriz de adyacencia N por N
e = torch.where(adj_mat > 0, e, connectivity_mask) # puntuaciones de atención enmascaradas        # los coeficientes de atención se calculan aplicando una softmax por filas# para cada columna j en la matriz de puntuaciones de atención e
attention = F.softmax(e, dim=-1)

Finalmente, según el artículo, después de obtener las puntuaciones de atención y enmascararlas con las aristas existentes, obtenemos los pesos de atención αᵢⱼ realizando una softmax por filas de la matriz de puntuaciones.

sección seleccionada del artículo original de GAT
Ilustración de la aplicación de máscara de conectividad y softmax a las puntuaciones de atención para obtener los coeficientes de atención - imagen del autor.

Y como se discutió anteriormente, calculamos la suma ponderada de las incrustaciones de nodos:

# las incrustaciones finales de los nodos se calculan como un promedio ponderado de las características de sus vecinosh_prime = torch.matmul(attention, h_transformed)

Finalmente, el artículo introduce el concepto de atención multi-head, donde todas las operaciones discutidas se realizan a través de múltiples flujos paralelos de operaciones, donde las cabezas de resultado finales se promedian o se concatenan.

sección seleccionada del artículo original de GAT

El proceso de atención y agregación multi-head se ilustra a continuación:

Ilustración de atención multi-head (con K = 3 cabezas) por el nodo 1 en su vecindario. Diferentes estilos y colores de flechas denotan cálculos de atención independientes. Las características agregadas de cada cabeza se concatenan o promedian para obtener h'. - Imagen del artículo original

Para concluir la implementación de forma modular y limpia (como un módulo PyTorch) y para incorporar la funcionalidad de atención multi-head, la implementación completa de Graph Attention Layer se realiza de la siguiente manera:

import torch
from torch import nn
import torch.nn.functional as F

################################### DEFINICIÓN DE LA CAPA GAT ###################################
class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 n_heads: int, concat: bool = False, dropout: float = 0.4,
                 leaky_relu_slope: float = 0.2):
        super(GraphAttentionLayer, self).__init__()
        self.n_heads = n_heads # Número de cabezas de atención
        self.concat = concat # si concatenar las cabezas de atención finales
        self.dropout = dropout # Tasa de dropout
        if concat: # concatenar las cabezas de atención
            self.out_features = out_features # Número de características de salida por nodo
            assert out_features % n_heads == 0 # Asegurarse de que out_features sea múltiplo de n_heads
            self.n_hidden = out_features // n_heads
        else: # promediar la salida sobre las cabezas de atención (usado en el artículo principal)
            self.n_hidden = out_features
        # Se aplica una transformación lineal compartida, parametrizada por una matriz de pesos W, a cada nodo
        # Inicializar la matriz de pesos W
        self.W = nn.Parameter(torch.empty(size=(in_features, self.n_hidden * n_heads)))
        # Inicializar los pesos de atención a
        self.a = nn.Parameter(torch.empty(size=(n_heads, 2 * self.n_hidden, 1)))
        self.leakyrelu = nn.LeakyReLU(leaky_relu_slope) # Función de activación LeakyReLU
        self.softmax = nn.Softmax(dim=1) # Función de activación softmax para los coeficientes de atención
        self.reset_parameters() # Reiniciar los parámetros

    def reset_parameters(self):
        nn.init.xavier_normal_(self.W)
        nn.init.xavier_normal_(self.a)

    def _get_attention_scores(self, h_transformed: torch.Tensor):
        source_scores = torch.matmul(h_transformed, self.a[:, :self.n_hidden, :])
        target_scores = torch.matmul(h_transformed, self.a[:, self.n_hidden:, :])
        # suma de transmisión
        # (n_heads, n_nodes, 1) + (n_heads, 1, n_nodes) = (n_heads, n_nodes, n_nodes)
        e = source_scores + target_scores.mT
        return self.leakyrelu(e)

    def forward(self,  h: torch.Tensor, adj_mat: torch.Tensor):
        n_nodes = h.shape[0]
        # Aplicar transformación lineal a la característica del nodo -> W h
        # forma de salida (n_nodes, n_hidden * n_heads)
        h_transformed = torch.mm(h, self.W)
        h_transformed = F.dropout(h_transformed, self.dropout, training=self.training)
        # dividir las cabezas mediante el cambio de forma del tensor y colocando la dimensión de las cabezas primero
        # forma de salida (n_heads, n_nodes, n_hidden)
        h_transformed = h_transformed.view(n_nodes, self.n_heads, self.n_hidden).permute(1, 0, 2)

        # obtener las puntuaciones de atención
        # forma de salida (n_heads, n_nodes, n_nodes)
        e = self._get_attention_scores(h_transformed)
        # Establecer la puntuación de atención para las aristas que no existen en -9e15 (ENMASCARAMIENTO DE ARISTAS QUE NO EXISTEN)
        connectivity_mask = -9e16 * torch.ones_like(e)
        e = torch.where(adj_mat > 0, e, connectivity_mask) # puntuaciones de atención enmascaradas

        # se calculan los coeficientes de atención como un softmax por filas
        # para cada columna j en la matriz de puntuaciones de atención e
        attention = F.softmax(e, dim=-1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        # las incrustaciones finales de los nodos se calculan como un promedio ponderado de las características de sus vecinos
        h_prime = torch.matmul(attention, h_transformed)
        # concatenar/promediar las cabezas de atención
        # forma de salida (n_nodes, out_features)
        if self.concat:
            h_prime = h_prime.permute(1, 0, 2).contiguous().view(n_nodes, self.out_features)
        else:
            h_prime = h_prime.mean(dim=0)
        return h_prime

A continuación, los autores hacen una comparación entre GATs y algunas de las otras metodologías/arquitecturas existentes de GNN. Argumentan que:

  1. Los GATs son computacionalmente más eficientes que algunos métodos existentes debido a que pueden calcular los pesos de atención y realizar la agregación local en paralelo.
  2. Los GATs pueden asignar diferente importancia a los vecinos de un nodo mientras agregan mensajes, lo que puede permitir un salto en la capacidad del modelo y aumentar la interpretabilidad.
  3. GAT considera el vecindario completo de los nodos (no requiere muestreo de vecinos) y no asume ningún orden dentro de los nodos.
  4. GAT puede reformularse como una instancia particular de MoNet (Monti et al., 2016) al establecer la función de pseudo-coordenadas como u(x, y) = f(x)||f(y), donde f(x) representa características (potencialmente transformadas por MLP) del nodo x y || es la concatenación; y la función de peso como wj(u) = softmax(MLP(u))

Sección 3 — Evaluación

En la tercera sección del artículo, primero, los autores describen los benchmarks, conjuntos de datos y tareas en las que se evalúa el GAT. Luego presentan los resultados de su evaluación del modelo.

Aprendizaje transductivo vs. Aprendizaje inductivo Los conjuntos de datos utilizados como benchmarks en este artículo se diferencian en dos tipos de tareas, Transductivo e Inductivo.

  • Aprendizaje inductivo: Es un tipo de tarea de aprendizaje supervisado en el que un modelo se entrena solo en un conjunto de ejemplos de entrenamiento etiquetados y el modelo entrenado se evalúa y prueba en ejemplos que no se observaron durante el entrenamiento. Es el tipo de aprendizaje que se conoce como aprendizaje supervisado común.
  • Aprendizaje transductivo: En este tipo de tarea, todos los datos, incluidos los ejemplos de entrenamiento, validación y prueba, se utilizan durante el entrenamiento. Pero en cada fase, el modelo solo accede al conjunto correspondiente de etiquetas. Esto significa que durante el entrenamiento, el modelo solo se entrena utilizando la pérdida resultante de las instancias y etiquetas de entrenamiento, pero las características de prueba y validación se utilizan para el paso de mensajes. Esto se debe principalmente a la información estructural y contextual existente en los ejemplos.

Conjuntos de datos En el artículo, se utilizan cuatro conjuntos de datos de referencia para evaluar los GATs, tres de los cuales corresponden a aprendizaje transductivo y otro se utiliza como una tarea de aprendizaje inductivo.

Los conjuntos de datos de aprendizaje transductivo, a saber, los conjuntos de datos Cora, Citeseer y Pubmed (Sen et al., 2008) son todos gráficos de citas en los que los nodos son documentos publicados y las aristas (conexiones) son citas entre ellos, y las características del nodo son elementos de una representación de bolsa de palabras de un documento. El conjunto de datos de aprendizaje inductivo es un conjunto de datos de interacción de proteínas-proteínas (PPI) que contiene gráficos de diferentes tejidos humanos (Zitnik & Leskovec, 2017). Los conjuntos de datos se describen más abajo:

Resumen de los conjuntos de datos utilizados en nuestros experimentos — del artículo original.

Configuración y resultados

  • Para las tres tareas transductivas, la configuración utilizada para el entrenamiento es:Utilizan 2 capas GAT — la capa 1 usa- K = 8 cabezas de atención- F’ = 8 dimensión de características de salida por cabeza- Activación ELUy para la segunda capa [Cora & Citeseer | Pubmed]- [1 | 8] cabeza de atención con C número de clases dimensión de salida- Activación Softmax para la salida de probabilidad de clasificacióny para la red general- Dropout con p = 0.6– Regularización L2 con λ = [0.0005 | 0.001]
  • Para las tres tareas transductivas, la configuración utilizada para el entrenamiento es:Tres capas — – Capa 1 y 2: K = 4 | F’ = 256 | ELU – Capa 3: K = 6 | F’ = C clases | Sigmoid (multi-etiqueta)sin regularización y dropout

La implementación de la primera configuración en PyTorch se realiza a continuación utilizando la capa que definimos anteriormente:

class GAT(nn.Module):    def __init__(self,        in_features,        n_hidden,        n_heads,        num_classes,        concat=False,        dropout=0.4,        leaky_relu_slope=0.2):        super(GAT, self).__init__()        # Definir las capas de atención de gráficos        self.gat1 = GraphAttentionLayer(            in_features=in_features, out_features=n_hidden, n_heads=n_heads,            concat=concat, dropout=dropout, leaky_relu_slope=leaky_relu_slope            )                self.gat2 = GraphAttentionLayer(            in_features=n_hidden, out_features=num_classes, n_heads=1,            concat=False, dropout=dropout, leaky_relu_slope=leaky_relu_slope            )    def forward(self, input_tensor: torch.Tensor , adj_mat: torch.Tensor):        # Aplicar la primera capa de atención de gráficos        x = self.gat1(input_tensor, adj_mat)        x = F.elu(x) # Aplicar la función de activación ELU a la salida de la primera capa        # Aplicar la segunda capa de atención de gráficos        x = self.gat2(x, adj_mat)        return F.softmax(x, dim=1) # Aplicar la función de activación softmax

Después de las pruebas, los autores informan el siguiente rendimiento para los cuatro puntos de referencia, mostrando los resultados comparables de GAT en comparación con los métodos GNN existentes.

Resumen de los resultados en términos de precisión de clasificación para Cora, Citeseer y Pubmed - del artículo original.
Resumen de los resultados en términos de puntuaciones F1 micro promedio, para el conjunto de datos PPI - del artículo original.

Conclusión

En conclusión, en esta publicación de blog, intenté adoptar un enfoque detallado y fácil de seguir para explicar el artículo “Graph Attention Networks” de Veličković et al. mediante el uso de ilustraciones para ayudar a los lectores a comprender las ideas principales detrás de estas redes y por qué son importantes para trabajar con datos estructurados en gráficos complejos (por ejemplo, redes sociales o moléculas). Además, la publicación incluye una implementación práctica del modelo utilizando PyTorch, un marco de programación popular. Al leer la publicación de blog y probar el código, espero que los lectores puedan obtener una comprensión sólida de cómo funcionan los GAT y cómo se pueden aplicar en escenarios del mundo real. Espero que esta publicación haya sido útil y alentadora para explorar aún más esta emocionante área de investigación.

Además, puede acceder al código completo utilizado en esta publicación, que contiene el código de entrenamiento y validación, en este repositorio de GitHub.

Estaré encantado de escuchar cualquier pensamiento o sugerencia/cambio en la publicación.

Referencias

[1] – Graph Attention Networks (2017), Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, Yoshua Bengio. arXiv:1710.10903v3

[2] – Inductive Representation Learning on Large Graphs (2017), William L. Hamilton, Rex Ying, Jure Leskovec. arXiv:1706.02216v4

[3] – Semi-Supervised Classification with Graph Convolutional Networks (2016), Thomas N. Kipf, Max Welling. arXiv:1609.02907v4

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Varias filtraciones de datos en 23andMe

Datos genéticos robados llevan a una demanda colectiva contra la empresa de pruebas.

Inteligencia Artificial

15+ Herramientas de IA para Desarrolladores (Diciembre 2023)

GitHub Copilot GitHub Copilot se destaca como un asistente de codificación impulsado por IA líder en el mercado. Dise...

Inteligencia Artificial

¿Cómo funciona realmente la Difusión Estable? Una explicación intuitiva

Este breve artículo explica cómo funciona la Difusión Estable de manera intuitiva para principiantes. Es un vistazo b...

Inteligencia Artificial

Esta investigación de IA presenta MeshGPT Un enfoque novedoso para la generación de formas que produce mallas directamente en forma de triángulos

MeshGPT es propuesto por investigadores de la Universidad Técnica de Munich, Politécnico de Turín, AUDI AG como un mé...

Inteligencia Artificial

Conoce Cursive Un Marco de Inteligencia Artificial Universal e Intuitivo para Interactuar con LLMs

En el ámbito de la interfaz con los Modelos de Lenguaje Grande (LLMs, por sus siglas en inglés), los desarrolladores ...

Investigación

Imágenes detalladas desde el espacio ofrecen una imagen más clara de los efectos de la sequía en las plantas.

Los investigadores de J-WAFS están utilizando observaciones de teledetección para construir sistemas de alta resoluci...