The paper "Deep Grokking: Would Deep Neural Networks Generalize Better?" by Simin Fan, Razvan Pascanu, and Martin Jaggi investigates the grokking phenomenon in deep neural networks, a previously unexplored area.
Before we dive into the paper, a quick introduction to "Grokking" might help.
Grokking, a term originating from Robert A. Heinlein's 1961 science fiction novel Stranger in a Strange Land, refers to a profound understanding of a subject that goes beyond step-by-step reasoning. In the novel, to "grok" literally means "to drink" and, figuratively, to understand something fully.
In the context of machine learning, grokking describes a phenomenon where a neural network suddenly and significantly improves its performance on a test dataset after an extended period of overfitting on the training data. In other words, grokking is used to describe a sudden phase transition where a model, after seemingly failing to generalize, abruptly starts performing well on unseen data.

Summary of the Paper . . .
The paper explores how deep neural networks can suddenly achieve high test accuracy after a prolonged period of overfitting. Here's an overview:
1. Grokking Phenomenon: Grokking is observed when a deep neural network, after overfitting to the training data, experiences a sudden improvement in test accuracy. This study focuses on deep Multi-Layer Perceptrons (MLPs) with up to 12 layers.
2. Comparison with Shallow Networks: The paper shows that deep networks are more prone to grokking than shallow networks. This means that deeper networks have a higher tendency to improve their generalization capabilities after a period of poor performance on test data.
3. Feature Rank Correlation: A significant finding is the correlation between feature ranks and generalization. As the network starts to generalize better, the numerical ranks of the feature matrices (roughly, the effective dimensionality of the representations learned at each layer) tend to decrease. This indicates that the network compresses its representations into fewer effective dimensions over time, learning more refined and efficient features.
4. Multi-Stage Generalization: Deep networks often exhibit a multi-stage improvement in test accuracy. This is characterized by distinct phases of accuracy gains, which are linked to changes in the feature ranks. The paper observes a "double-descent" pattern, where the feature ranks first increase and then decrease, corresponding to phases of generalization.
5. Implications for Training: The study suggests that monitoring feature ranks can provide better insights into the generalization behavior of deep networks compared to traditional metrics like weight-norm. This can help in understanding the training dynamics and improving the generalization performance of deep neural networks.
In essence, the paper highlights the complex but fascinating dynamics of how deep neural networks learn and generalize, offering new perspectives on improving AI model performance.
In More Detail . . .
Key Insights from the Paper
Susceptibility of Deep Networks to Grokking:
Deep neural networks, specifically those with more layers, are more prone to experiencing grokking compared to shallow networks.
Multi-Stage Generalization:
In deep networks, test accuracy can exhibit multiple sharp increases rather than a single phase transition. This is a novel observation, differing from the previously documented single-stage grokking in shallow networks.
Feature Rank as an Indicator:
The study finds that internal feature ranks (the complexity and quality of features learned by each layer) correlate strongly with the grokking phase transition. This suggests that feature rank could be a more reliable indicator of a model's generalization capability than weight-norms.
Experimental Setup and Findings
Model and Dataset
Architecture and Configuration:
The researchers utilized Multi-Layer Perceptrons (MLPs) with varying depths, specifically focusing on models with 3, 6, 9, and 12 layers. Each hidden layer was wide, used ReLU (Rectified Linear Unit) activations, and was followed by batch normalization to stabilize training.
They conducted experiments on the MNIST dataset, a standard benchmark for image classification tasks.
Initialization and Training:
The networks were initialized with large weights, which is known to induce grokking behavior.
A small amount of weight decay was employed during training, providing gentle regularization without overly penalizing the large initial weights.
Stochastic Gradient Descent (SGD) with momentum was used as the optimization algorithm; momentum accelerates updates along consistent gradient directions, leading to faster convergence.
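To make the setup concrete, here is a minimal PyTorch sketch of this kind of deep MLP. It is not the authors' code: the width, initialization scale, learning rate, momentum, and weight-decay values are placeholder assumptions; only the overall recipe (deep ReLU MLP with batch normalization, large-scale initialization, SGD with momentum, small weight decay) follows the description above.

```python
import torch
import torch.nn as nn

def make_mlp(depth=12, width=512, num_classes=10, init_scale=8.0):
    """Illustrative deep MLP: [Linear -> BatchNorm -> ReLU] blocks on flattened MNIST.

    `width` and `init_scale` are placeholder values, not the paper's exact settings.
    """
    layers, in_dim = [], 28 * 28  # flattened 28x28 MNIST images
    for _ in range(depth - 1):
        layers += [nn.Linear(in_dim, width), nn.BatchNorm1d(width), nn.ReLU()]
        in_dim = width
    layers.append(nn.Linear(in_dim, num_classes))
    model = nn.Sequential(*layers)

    # Large-scale initialization: multiply the default weights by a constant
    # factor, the trick commonly used to induce grokking-like behavior.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, nn.Linear):
                m.weight.mul_(init_scale)
    return model

model = make_mlp()
# SGD with momentum and a small weight decay, as described above
# (learning rate, momentum, and decay values here are illustrative).
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)
```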
Key Measurements
Linear Probing Accuracy:
Linear probing involves training a linear classifier on the features extracted from each layer of the network.
This method helps in understanding the quality of features learned at different stages of the training process.
During the grokking phase, a sharp increase in linear probing accuracy was observed, indicating a sudden improvement in the feature quality.
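For concreteness, here is one way such a probe could be implemented in PyTorch. This is a hedged sketch rather than the authors' protocol: `collect_features`, the layer indexing, and the probe's optimizer and epoch count are all illustrative choices.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def collect_features(model, loader, layer_index, device="cpu"):
    """Collect one hidden layer's activations for every example in `loader`.

    Assumes `model` is an nn.Sequential like the sketch above; `layer_index`
    selects how much of the network to run (a hypothetical choice).
    """
    model.eval()                               # use running BatchNorm statistics
    trunk = model[: layer_index + 1]           # sub-network up to the probed layer
    feats, labels = [], []
    for x, y in loader:
        x = x.view(x.size(0), -1).to(device)   # flatten MNIST images
        feats.append(trunk(x).cpu())
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)

def linear_probe_accuracy(train_feats, train_y, test_feats, test_y, epochs=200):
    """Fit a linear classifier on frozen features and report test accuracy."""
    probe = nn.Linear(train_feats.size(1), int(train_y.max()) + 1)
    opt = torch.optim.Adam(probe.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):                    # full-batch training of the probe
        opt.zero_grad()
        loss_fn(probe(train_feats), train_y).backward()
        opt.step()
    with torch.no_grad():
        preds = probe(test_feats).argmax(dim=1)
    return (preds == test_y).float().mean().item()
```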
Numerical Rank:
The numerical rank of the feature matrix at each layer was measured to understand the complexity of features.
A significant drop in numerical rank was observed during the grokking phase, suggesting that the network was learning more compact and efficient representations.
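A common way to define such a numerical rank is to count the singular values of the (samples x features) activation matrix that exceed a small fraction of the largest singular value. The sketch below follows that convention with an illustrative tolerance; the paper's exact threshold may differ.

```python
import torch

def numerical_rank(features: torch.Tensor, rel_tol: float = 1e-3) -> int:
    """Numerical rank of a (num_samples, feature_dim) activation matrix.

    Counts singular values above `rel_tol` times the largest one; the tolerance
    is a placeholder, not necessarily the paper's exact definition.
    """
    sv = torch.linalg.svdvals(features)         # singular values, descending order
    return int((sv > rel_tol * sv[0]).sum())

# Example: track one layer's rank across saved checkpoints (`checkpoints` and
# `test_loader` are hypothetical; `collect_features` is from the sketch above).
# ranks = [numerical_rank(collect_features(ckpt, test_loader, layer_index=6)[0])
#          for ckpt in checkpoints]
```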
Experimental Results
Deep MLPs vs. Shallow MLPs:
Deeper MLPs showed delayed improvements in training and test accuracy, indicating a stronger tendency towards grokking.
Shallow networks exhibited more immediate improvements but did not experience the same dramatic phase transitions as deeper networks.
Multi-Stage Generalization:
For deeper networks, the test accuracy exhibited multiple sharp increases, a phenomenon mirrored by double-descent patterns in feature ranks.
This multi-stage grokking is a novel observation, differing from the previously documented single-stage grokking in shallow networks.
Feature Rank Correlation:
Feature ranks decreased significantly during the grokking phase transition, suggesting that feature rank is a promising indicator of generalization performance.
This finding highlights the importance of internal feature quality over traditional weight-norm metrics for understanding model generalization capabilities.
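One practical way to act on this finding is to log both quantities during training and see which one better anticipates the jumps in test accuracy. The sketch below reuses the hypothetical `collect_features` and `numerical_rank` helpers from the earlier examples; the training helpers it mentions are placeholders.

```python
import torch

def weight_norm(model: torch.nn.Module) -> float:
    """Global L2 norm of all parameters, the traditional metric mentioned above."""
    return torch.sqrt(sum(p.detach().pow(2).sum() for p in model.parameters())).item()

# Hypothetical monitoring loop: `train_one_epoch`, the loaders, and `layer_index`
# are placeholders. Logging both metrics each epoch lets you compare how closely
# each one tracks the sharp increases in test accuracy described above.
# history = []
# for epoch in range(num_epochs):
#     train_one_epoch(model, train_loader, optimizer)
#     feats, _ = collect_features(model, test_loader, layer_index=6)
#     history.append({"epoch": epoch,
#                     "weight_norm": weight_norm(model),
#                     "feature_rank": numerical_rank(feats)})
```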
Conclusion
The study explores grokking in deep neural networks, revealing that deeper models are more susceptible to this phenomenon. The multi-stage generalization observed in these networks and the correlation between feature rank and grokking provide fresh insights into understanding neural network training dynamics. These findings suggest that feature rank could serve as a more effective indicator of a model's potential for generalization compared to traditional metrics like weight-norm.
This exploration of grokking in deep networks opens avenues for further research, particularly in developing more robust indicators for model generalization and understanding the underlying mechanisms driving these phase transitions. Future work could investigate the applicability of these findings to other neural network architectures and datasets, potentially leading to improved training techniques and more reliable neural network models.
For a more detailed understanding, you can access the full paper here.