Mechanistic Interpretability on (multi-task) Irreducible Integer Identifiers
Noah Syrkis
January 16, 2025
1 | Mechanistic Interpretability (MI)
2 | Modular Arithmetic
3 | Grokking on 𝒯_miiii
4 | Embeddings
5 | Neurons
6 | The 𝜔-Spike
Figure 1: ℕ < 𝑝², multiples of 13 or 27 (left), 11 (mid.), or primes (right)
“This disgusting pile of matrices is actually just an astoundingly poorly written, elegant and concise algorithm” — Neel Nanda¹
¹ Not verbatim, but the gist of it
1 | Mechanistic Interpretability (MI)
▶ The sub-symbolic nature of deep learning obscures model mechanisms
▶ There is no obvious mapping from the weights of a trained model to mathematical notation
▶ MI is about reverse engineering these models and looking closely at them
▶ Many low-hanging fruits / the practical-botany phase of the science
▶ How does a given model work? How can we train it faster? Is it safe?
1 | Grokking
▶ Grokking [1] is “sudden generalization”
▶ MI (often) needs a mechanism
▶ Grokking is thus convenient for MI
▶ Lee et al. (2024) speed up grokking by boosting slow gradients as per Eq. 1
▶ For more, see Appendix A
ℎ(𝑡) = ℎ(𝑡−1)𝛼 + 𝑔(𝑡)(1−𝛼)   (1.1)
ĝ(𝑡) = 𝑔(𝑡) + 𝜆ℎ(𝑡)   (1.2)
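A minimal sketch of the filter in Eq. 1, assuming gradients are kept in a dict of numpy arrays; the names (`grokfast_ema`, `grads`, `ema`) and the default 𝛼, 𝜆 values are illustrative placeholders, not the reference GrokFast implementation.

```python
import numpy as np

def grokfast_ema(grads, ema, alpha=0.98, lam=0.5):
    """One step of a GrokFast-style EMA filter (Eq. 1.1 and 1.2)."""
    boosted, new_ema = {}, {}
    for name, g in grads.items():
        h = ema[name] * alpha + g * (1 - alpha)  # Eq. 1.1: slow-moving average of gradients
        boosted[name] = g + lam * h              # Eq. 1.2: amplify the slow component
        new_ema[name] = h
    return boosted, new_ema

# usage: initialise the EMA with zeros, then filter gradients before each optimizer step
grads = {"w": np.ones(3)}
ema = {k: np.zeros_like(v) for k, v in grads.items()}
boosted, ema = grokfast_ema(grads, ema)
```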
1 | Visualizing
▶ MI needs creativity … but there are tricks:
▶ For two-token samples, plot them varying one token on each axis (Figure 2)
▶ When a matrix is periodic, use a Fourier basis
▶ Singular value decomposition (Appendix C); both tricks are sketched after Figure 2
▶ Take away: get comfy with esch-plots
Figure 2: Top singular vectors 𝐔 of 𝑊_E for 𝒯_nanda (top), varying 𝑥₀ and 𝑥₁ in sample (left) and freq. (right) space in 𝑊_out for 𝒯_miiii
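A minimal numpy sketch of the two tricks above; `W` is a random stand-in weight matrix, not the actual trained embeddings shown in Figure 2.

```python
import numpy as np

W = np.random.randn(113, 256)  # stand-in for an embedding matrix of shape (p, d_model)

# Trick 1: a periodic matrix becomes sparse in Fourier space
W_freq = np.abs(np.fft.rfft(W, axis=0))   # magnitude per frequency, per embedding dim

# Trick 2: singular value decomposition exposes the dominant directions (Appendix C)
U, S, Vt = np.linalg.svd(W, full_matrices=False)
top = U[:, :3]                            # top singular vectors, plotted as in Figure 2
```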
Figure 3: Shameless plug: visit github.com/syrkis/esch for more esch plots
2 | Modular Arithmetic
▶ “Seminal” MI paper by Nanda et al. (2023) focuses on modular addition (Eq. 2)
▶ Their final setup trains on 𝑝 = 113
▶ They train a one-layer transformer
▶ We call their task 𝒯_nanda
▶ And ours, seen in Eq. 3, we call 𝒯_miiii
(𝑥₀ + 𝑥₁) mod 𝑝,  ∀ 𝑥₀, 𝑥₁   (2)
(𝑥₀𝑝⁰ + 𝑥₁𝑝¹) mod 𝑞,  ∀ 𝑞 < 𝑝   (3)
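A small sketch of how the two tasks can be materialized, assuming 𝑝 = 113; variable names are illustrative and this is not the project's actual data pipeline.

```python
import numpy as np

p = 113
primes = [q for q in range(2, p) if all(q % k for k in range(2, int(q**0.5) + 1))]  # 2 … 109

# every two-token sample (x0, x1) with x0, x1 < p
x0, x1 = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")

# T_nanda (Eq. 2): a single task, addition mod p
y_nanda = (x0 + x1) % p

# T_miiii (Eq. 3): one subtask per prime q < p, on the base-p number x0*p^0 + x1*p^1
y_miiii = np.stack([(x0 + x1 * p) % q for q in primes], axis=-1)  # shape (p, p, 29)
```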
2 | Modular Arithmetic
▶ 𝒯_miiii is non-commutative …
▶ … and multi-task: 𝑞 ranges from 2 to 109¹
▶ 𝒯_nanda uses a single-layer transformer
▶ Note that these tasks are synthetic and trivial to solve with conventional programming
▶ They are used in the MI literature to make black boxes transparent
¹ Largest prime less than 𝑝 = 113
3 | Grokking on 𝒯_miiii
▶ The model groks on 𝒯_miiii (Figure 4)
▶ Needed GrokFast [2] on our compute budget
▶ Final hyperparams are seen in Table 1
rate   𝜆     wd    𝑑     lr       heads
1/10   1/2   1/3   256   3·10⁻⁴   4

Table 1: Hyperparams for 𝒯_miiii
Figure 4: Training (top) and validation (bottom) accuracy during training on 𝒯_miiii
4 | Embeddings
▶ The positional embeddings in Figure 5 reflect that 𝒯_nanda is commutative and 𝒯_miiii is not
▶ Maybe: this corrects for the non-commutativity of 𝒯_miiii?
▶ Corr. is 0.95 for 𝒯_nanda and −0.64 for 𝒯_miiii (see the sketch after Figure 5)
Figure 5: Positional embeddings for 𝒯_nanda (top) and 𝒯_miiii (bottom).
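The correlation above can be read directly off the two learned position vectors; a hedged sketch, with `pos_emb` standing in for the trained positional embedding matrix (assumed shape 2 × 𝑑, one row per input position).

```python
import numpy as np

pos_emb = np.random.randn(2, 256)  # placeholder: load the trained positional embeddings here
corr = np.corrcoef(pos_emb[0], pos_emb[1])[0, 1]
# reported values: corr ≈ 0.95 for T_nanda (near-identical rows, i.e. commutative)
# and corr ≈ -0.64 for T_miiii (positions kept distinct)
```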
4 | Embeddings
▶ For 𝒯_nanda, token embeddings are essentially linear combinations of 5 frequencies (𝜔)
▶ For 𝒯_miiii, more frequencies are in play
▶ Each 𝒯_miiii subtask targets a unique prime
▶ Possibility: one basis per prime task
Figure 6: 𝒯_nanda (top) and 𝒯_miiii (bottom) token embeddings in Fourier basis
4 | Embeddings
▶ Masking 𝑞 ∈ {2, 3, 5, 7} yields a slight decrease in token embedding frequencies
▶ Sanity check: 𝒯_baseline has no periodicity
▶ Do the token embeddings encode a basis per subtask?
Figure 7: 𝒯_baseline (top), 𝒯_miiii (middle) and 𝒯_masked (bottom) token embeddings in Fourier basis
5 | Neurons
▶ Figure 8 shows transformer MLP neuron activations as 𝑥₀ and 𝑥₁ vary on each axis (sketched after Figure 8)
▶ In spite of the dense Fourier basis of 𝑊_E for 𝒯_miiii, the periodicity is clear
Figure 8: Activations of first three neurons for 𝒯_nanda (top) and 𝒯_miiii (bottom)
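A sketch of how one panel of Figure 8 can be built, assuming access to a function that returns a neuron's MLP activation for a given input pair; `neuron_act` below is a dummy periodic stand-in, not the trained model.

```python
import numpy as np

p = 113

def neuron_act(x0, x1):
    """Placeholder for the trained model's MLP activation of one neuron on (x0, x1)."""
    return np.sin(2 * np.pi * 17 * (x0 + x1) / p)  # dummy periodic signal

grid = np.array([[neuron_act(a, b) for b in range(p)] for a in range(p)])  # one p x p panel
spectrum = np.abs(np.fft.fft2(grid))  # the corresponding panel of Figure 9
```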
5 | Neurons
▶ (Probably redundant) sanity check: Figure 9 confirms the neurons are periodic
▶ We see some frequencies 𝜔 rise into significance
▶ Let's log |{𝜔 : 𝜔 > 𝜇_𝜔 + 2𝜎_𝜔}|, the number of significant frequencies, while training (sketched after Figure 9)
Figure 9: FFT of activations of first three neurons for 𝒯_nanda (top) and 𝒯_miiii (bottom)
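A hedged sketch of the logging described above: count, per training checkpoint, how many frequencies exceed 𝜇_𝜔 + 2𝜎_𝜔. The layout of `acts` (one 𝑝 × 𝑝 activation grid per neuron) and the reading "significant in at least one neuron" are assumptions.

```python
import numpy as np

def n_active_freqs(acts):
    """acts: array of shape (n_neurons, p, p), one activation grid per neuron."""
    mags = np.abs(np.fft.fft2(acts, axes=(1, 2))).reshape(len(acts), -1)
    thresh = mags.mean(axis=1, keepdims=True) + 2 * mags.std(axis=1, keepdims=True)
    return int((mags > thresh).any(axis=0).sum())  # frequencies significant in any neuron
```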
Figure 10: Neurons as archive and algorithm. 𝒯_baseline on top, FFT on right.
Figure 11: Number of neurons with frequency 𝜔 above the threshold 𝜇_𝜔 + 2𝜎_𝜔
6 | The 𝜔-Spike
▶ Neurons become periodic upon solving 𝑞 ∈ {2, 3, 5, 7}
▶ When the model generalizes to the remaining tasks, many frequencies activate (64-sample)
▶ Those 𝜔's are not useful for memorization and not useful after generalization
time   256   1024   4096   16384   65536
|𝝎|    0     0      10     18      10

Table 2: Active 𝜔's through training
Figure 12: Figure 11 (top) and validation accuracy from Figure 4 (bottom)
6 | The 𝜔-Spike
▶ GrokFast [2] views the gradient sequence over time as (arguably) a stochastic signal with:
▶ A fast-varying overfitting component
▶ A slow-varying generalizing component
▶ My work confirms this to be true for 𝒯_miiii …
▶ … and observes a structure that seems to fit neither of the two
6 | The 𝜔-Spike
▶ Future work:
▶ Modify GrokFast to assume a third stochastic component
▶ Relate to the signal processing literature
▶ Can more depth make the token embeddings sparse?
THANK YOU
References
[1] A. Power, Y. Burda, H. Edwards, I. Babuschkin, and V. Misra, “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets,” no. arXiv:2201.02177. arXiv, Jan. 2022. doi: 10.48550/arXiv.2201.02177.
[2] J. Lee, B. G. Kang, K. Kim, and K. M. Lee, “Grokfast: Accelerated Grokking by Amplifying Slow Gradients,” no. arXiv:2405.20233. arXiv, Jun. 2024.
[3] N. Nanda, L. Chan, T. Lieberum, J. Smith, and J. Steinhardt, “Progress Measures for Grokking via Mechanistic Interpretability,” no. arXiv:2301.05217. arXiv, Oct. 2023.
A | Stochastic Signal Processing
We denote the weights of a model as 𝜃. The gradient of 𝜃 with respect to our loss function at time 𝑡 we denote 𝑔(𝑡). As we train the model, 𝑔(𝑡) varies, going up and down. This can be thought of as a stochastic signal. We can represent this signal in a Fourier basis (Appendix B). GrokFast posits that the slow-varying frequencies contribute to grokking. Higher frequencies are then muted, and grokking is indeed accelerated.
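A toy illustration of that slow/fast decomposition on a scalar gradient trace; the signal and the cutoff of 10 frequencies are arbitrary choices for illustration, not GrokFast's actual procedure (which uses the EMA filter of Eq. 1).

```python
import numpy as np

t = np.arange(4096)
g = 0.05 * np.sin(2 * np.pi * t / 4096) + 0.5 * np.random.randn(t.size)  # toy g(t): slow trend + fast noise

G = np.fft.rfft(g)
G_slow = np.zeros_like(G)
G_slow[:10] = G[:10]                       # keep only the lowest frequencies
g_slow = np.fft.irfft(G_slow, n=t.size)    # slow-varying (generalizing) component
g_fast = g - g_slow                        # fast-varying (overfitting) component
```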
B | Discrete Fourier Transform
A function can be expressed as a linear combination of cosine and sine waves. A similar thing can be done for data / vectors.
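A four-point example of that statement using numpy's FFT; purely illustrative.

```python
import numpy as np

v = np.array([1.0, 3.0, 2.0, 0.0])          # a data vector
c = np.fft.fft(v) / len(v)                  # coefficients in the Fourier basis

# rebuild v as a linear combination of complex exponentials (cosines and sines)
n = np.arange(len(v))
v_rec = sum(c[k] * np.exp(2j * np.pi * k * n / len(v)) for k in range(len(v)))
assert np.allclose(v_rec.real, v)
```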
C | Singular Value Decomposition
An 𝑚 × 𝑛 matrix 𝑀 can be represented as 𝑈Σ𝑉*, where 𝑈 is an 𝑚 × 𝑚 complex unitary matrix, Σ a rectangular 𝑚 × 𝑛 diagonal matrix (padded with zeros), and 𝑉 an 𝑛 × 𝑛 complex unitary matrix. Multiplying by 𝑀 can thus be viewed as first rotating in 𝑛-space with 𝑉*, then scaling by Σ, and then rotating into 𝑚-space with 𝑈.
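A small numpy check of the decomposition described above; the matrix is random and purely illustrative.

```python
import numpy as np

M = np.random.randn(4, 3)                 # an m x n matrix (m=4, n=3)
U, S, Vh = np.linalg.svd(M)               # U: 4x4, S: singular values, Vh: 3x3 (this is V*)

Sigma = np.zeros(M.shape)
np.fill_diagonal(Sigma, S)                # rectangular diagonal matrix, padded with zeros
assert np.allclose(M, U @ Sigma @ Vh)     # rotate in n-space, scale, rotate into m-space
```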