Graph Random Neural Network for Semi-Supervised Learning
What started as a concerted effort to explore semi-supervised learning approaches eventually led me to this NeurIPS ‘20 publication, which offers an elegant solution: simple to understand and implement, and, most importantly, grounded in traditional ML concepts rather than adding complexity to existing architectures.
I’m fascinated by Graph Neural Networks and have been actively exploring various use cases and research problems involving them.
The architecture proposed in the linked publication addresses problems that have long plagued GNNs in semi-supervised learning settings:
- Over-smoothing of node representations
- Non-robustness to graph attacks
- Weak generalization when labelled nodes are scarce in the graph
How does GRAND tackle these problems? Random Propagation (RP) is the facilitator:
- It combines graph perturbation methods such as DropNode, Dropout, and DropEdge with a simple propagation step, which increases robustness by making individual nodes much less dependent on particular neighbors, and hence less susceptible to such perturbation attacks.
- The design of RP also separates feature propagation from feature transformation, two steps that are coupled in most GNNs, which reduces the risk of over-smoothing.
- Most importantly, as the name suggests, RP lets each node randomly pass messages to its neighbors, so that, under the homophily assumption, multiple augmented representations of each node are generated. Consistency Regularization (CR) then forces a prediction model such as an MLP or GCN to output similar predictions on the different augmentations of the same unlabelled data, which in turn improves the generalizability of GRAND. A minimal sketch of RP and CR follows this list.
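To make the RP and CR steps concrete, here is a minimal PyTorch sketch based on my reading of the paper. The function names, the sparse/dense handling, and the hyperparameters (drop rate, propagation order, sharpening temperature) are illustrative choices of mine, not the reference implementation.

```python
import torch

def drop_node(X, drop_rate):
    # DropNode: zero out entire node feature rows at random, then rescale
    # the surviving rows so the expected value of the features is unchanged.
    keep_prob = 1.0 - drop_rate
    mask = torch.bernoulli(torch.full((X.size(0), 1), keep_prob, device=X.device))
    return X * mask / keep_prob

def random_propagation(A_hat, X, order, drop_rate):
    # DropNode followed by mixed-order propagation: the average of
    # A_hat^0 X, A_hat^1 X, ..., A_hat^order X, where A_hat is the
    # (symmetrically normalized, self-looped) adjacency matrix.
    H = drop_node(X, drop_rate)
    out, cur = H, H
    for _ in range(order):
        cur = torch.sparse.mm(A_hat, cur) if A_hat.is_sparse else A_hat @ cur
        out = out + cur
    return out / (order + 1)

def consistency_loss(prob_list, temperature=0.5):
    # Consistency Regularization: sharpen the average prediction over the S
    # augmentations, then pull every augmentation's prediction towards that
    # sharpened distribution (squared error).
    avg = torch.stack(prob_list, dim=0).mean(dim=0)
    sharpened = avg.pow(1.0 / temperature)
    sharpened = (sharpened / sharpened.sum(dim=1, keepdim=True)).detach()
    return sum(((p - sharpened) ** 2).sum(dim=1).mean() for p in prob_list) / len(prob_list)
```

In a training step, one would draw S augmentations with random_propagation, feed each through the prediction network, take cross-entropy on the labelled nodes, and add consistency_loss over all nodes weighted by the CR loss coefficient.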
Experiments
Apart from an ablation study, I conducted several experiments to better understand the GRAND architecture. Some of them are:
- MLP vs. GCN as the prediction network (GRAND vs. GRAND_GCN)
- Changes in classification accuracy with respect to the propagation order and the number of data augmentations in RP
- Different augmentation strategies: Dropout (partial node feature dropout), DropNode (dropping a node's entire feature vector), and DropEdge (dropping edges); see the sketch after this list
- Sensitivity with respect to the CR loss coefficient
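For reference, the other two augmentation strategies compared above could be sketched as follows (DropNode appears in the earlier snippet). These are simplified versions of mine, using a dense adjacency matrix and skipping re-normalization, not the exact code used in the experiments.

```python
import torch

def dropout_features(X, p):
    # Dropout: zero individual feature entries independently, then rescale
    # the survivors to preserve the expected feature values.
    mask = torch.bernoulli(torch.full_like(X, 1.0 - p))
    return X * mask / (1.0 - p)

def drop_edge(A, p):
    # DropEdge: remove a random subset of edges from a dense adjacency
    # matrix (re-normalization of A would typically follow).
    mask = torch.bernoulli(torch.full_like(A, 1.0 - p))
    return A * mask
```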
Technical Details
- Language: Python
- Library: NetworkX
- Framework: PyTorch