Building neural networks from first principles using geometry and symmetry
The core insight of Geometric Deep Learning is that successful neural network architectures can be derived from first principles using three foundational concepts:

- The domain: the geometric structure underlying the data (grids, graphs, manifolds)
- The symmetry group: the transformations that should not change the output (translations, rotations, permutations)
- The signal space: the space of functions/features defined on the domain

Together, these provide a unified framework for understanding CNNs, GNNs, and Transformers as special cases of the same geometric blueprint.
Neural network layers should be designed to be equivariant to the symmetry group: a layer f is equivariant if f(g · x) = g · f(x) for every transformation g in the group, so that transforming the input transforms the output in the corresponding way (if the output is unchanged entirely, f(g · x) = f(x), the layer is invariant).
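To make the condition concrete, here is a minimal sketch (the helper conv1d_circular and the toy signal are illustrative assumptions) checking numerically that a circular 1-D convolution commutes with translation, which is exactly the equivariance that makes convolutions the natural layer for grid data:

```python
# Minimal sketch: translation equivariance of a circular 1-D convolution.
import numpy as np

def conv1d_circular(x, w):
    """Circular cross-correlation of signal x with kernel w (illustrative helper)."""
    n, k = len(x), len(w)
    return np.array([sum(w[j] * x[(i + j) % n] for j in range(k)) for i in range(n)])

rng = np.random.default_rng(0)
x = rng.normal(size=16)                  # a signal living on a 1-D grid
w = rng.normal(size=3)                   # stand-in for a learned filter

shift = lambda v, s: np.roll(v, s)       # the group action: translation by s steps
lhs = conv1d_circular(shift(x, 5), w)    # transform the input, then apply the layer
rhs = shift(conv1d_circular(x, w), 5)    # apply the layer, then transform the output
assert np.allclose(lhs, rhs)             # equivariance: both orders give the same result
```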
Geometric Deep Learning organizes data domains into five categories (the "5Gs"), each with a distinct symmetry group and corresponding architectures:
| Domain | Symmetry Group | Architecture | Applications |
|---|---|---|---|
| Grids | Translation | CNNs | Images, video |
| Groups | Group elements | Group-equivariant CNNs | Rotational data |
| Graphs | Permutation | GNNs, Message Passing | Molecules, social networks |
| Geodesics | Isometries | Geometric CNNs | 3D shapes, meshes |
| Gauges | Gauge transformations | Gauge-equivariant networks | Particle physics, manifolds |
The general message-passing update builds each node's new representation from its own state and an aggregate of messages from its neighbors:

h'_u = φ(h_u, ⊕_{v ∈ N(u)} ψ(h_u, h_v))

- ψ = message function (computes messages between node pairs)
- ⊕ = aggregation function (sum, mean, max, or attention)
- φ = update function (MLP that combines a node's state with the aggregated messages)

GNN flavors differ in how neighbor contributions are weighted:

- Convolutional: fixed, pre-computed weights determined by the graph topology. Best for homophilous graphs (similar nodes connect). Most scalable via sparse matrix multiplication.
- Attentional: learned, feature-dependent attention weights. Can handle heterophilous graphs. Examples: GAT, Transformers.
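As a rough sketch of how ψ, ⊕, and φ compose, the NumPy function below implements one round of message passing with sum aggregation; the names message_passing_layer, W_msg, and W_upd are illustrative assumptions rather than any particular library's API:

```python
# Illustrative message-passing layer (not a library API): ⊕ = sum aggregation.
import numpy as np

def message_passing_layer(H, edges, W_msg, W_upd):
    """H: (n, d) node features; edges: directed (src, dst) pairs, messages flow src -> dst;
    W_msg: (2d, d) weights for ψ; W_upd: (2d, d) weights for φ."""
    n, d = H.shape
    agg = np.zeros((n, d))
    for src, dst in edges:
        msg = np.tanh(np.concatenate([H[dst], H[src]]) @ W_msg)  # ψ(h_u, h_v)
        agg[dst] += msg                                          # ⊕: sum at the receiving node
    return np.tanh(np.concatenate([H, agg], axis=1) @ W_upd)     # φ(h_u, aggregated messages)

# Toy usage on a 3-node path graph with edges in both directions
rng = np.random.default_rng(0)
H = rng.normal(size=(3, 4))
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
W_msg, W_upd = rng.normal(size=(8, 4)), rng.normal(size=(8, 4))
H_new = message_passing_layer(H, edges, W_msg, W_upd)            # new features, shape (3, 4)
```

The choice of the aggregation operator ⊕ comes with trade-offs: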
| Operation | Trade-offs |
|---|---|
| Sum | Default choice, preserves multiset info. Sensitive to outliers. |
| Mean | Normalized view, variable neighborhoods. Loses count info. |
| Max | Highlights salient features. Loses multiset info. |
| Attention | Learns importance dynamically. More parameters, slower. |
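A quick illustration of these trade-offs (the toy neighborhoods below are made-up values): the fixed aggregators are all permutation-invariant, but only sum distinguishes a neighborhood from a copy with doubled multiplicity, while mean and max collapse the two:

```python
# Illustrative comparison of aggregation operators on toy neighborhoods.
import numpy as np

def aggregate(neighbors, op):
    return {"sum": np.sum, "mean": np.mean, "max": np.max}[op](neighbors)

a = np.array([1.0, 2.0])               # a node with two neighbors
b = np.array([1.0, 1.0, 2.0, 2.0])     # same values, doubled multiplicity

for op in ("sum", "mean", "max"):
    print(op, aggregate(a, op), aggregate(b, op))
# sum:  3.0 vs 6.0  -> keeps multiset (count) information
# mean: 1.5 vs 1.5  -> loses the neighbor count
# max:  2.0 vs 2.0  -> keeps only the most salient value

# All three are invariant to permuting the neighbors:
perm = np.array([2.0, 1.0])
assert all(aggregate(a, op) == aggregate(perm, op) for op in ("sum", "mean", "max"))
```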
Equivariant architectures pay off in several ways:

- The network doesn't need to learn the same function separately for every transformed version of an input
- Built-in invariances prevent overfitting to spurious correlations
- The architecture respects the known symmetries of the problem domain
| Problem | Solutions |
|---|---|
| Over-smoothing | Skip connections, DropEdge, normalization |
| Over-squashing | Graph rewiring, virtual nodes |
| Limited expressivity | Higher-order WL tests, subgraph methods |
| Long-range dependencies | Graph Transformers, virtual edges |
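To illustrate one of these mitigations, the sketch below wraps a simple mean-aggregation update in a skip connection, a common way to counter over-smoothing; the function gnn_layer_with_skip and the toy graph are illustrative assumptions, not a specific published architecture:

```python
# Illustrative sketch: a skip connection around a mean-aggregation GNN layer.
import numpy as np

def gnn_layer_with_skip(H, A, W):
    """H: (n, d) node features, A: (n, n) adjacency matrix, W: (d, d) weights."""
    deg = A.sum(axis=1, keepdims=True).clip(min=1.0)   # node degrees (avoid division by zero)
    neighbor_mean = (A @ H) / deg                      # mean over each node's neighbors
    update = np.tanh(neighbor_mean @ W)                # plain GNN update
    return H + update                                  # skip connection keeps each node's own state

rng = np.random.default_rng(0)
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)                 # 3-node path graph
H = rng.normal(size=(3, 4))
W = 0.1 * rng.normal(size=(4, 4))
for _ in range(8):                                     # stacking layers: the residual term
    H = gnn_layer_with_skip(H, A, W)                   # preserves per-node information
```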
| Application | Key Architectural Features |
|---|---|
| Molecular property prediction | Invariant to atom permutation, rotation-equivariant for 3D |
| Protein structure (AlphaFold) | SE(3)-equivariant attention, multi-scale |
| Drug discovery | Message passing on molecular graphs |
| Traffic prediction | Spatio-temporal GNNs |
| Weather forecasting | Icosahedral mesh GNNs (GraphCast), diffusion models (GenCast) |
| Physics simulation | Equivariant to physical symmetries |
| Library | Notes |
|---|---|
| PyTorch Geometric (PyG) | Most comprehensive, production-ready |
| DGL | Framework-agnostic |
| Jraph | JAX-based, good for research |
| e3nn | For E(3)-equivariant networks |
"The most successful deep learning architectures (CNNs, GNNs, Transformers) are all special cases of the same geometric blueprint."