MIIII
Noah Syrkis
January 8, 2025
Rio de Janeiro, Brazil
1 | Mechanistic Interpretability (MI)
2 | Modular Arithmetic
3 | Grokking on 𝒯_miiii
4 | Embeddings
5 | Neurons
6 | The 𝜔-Spike
“This disgusting pile of matrices is actually just an astoundingly poorly written, elegant and concise algorithm” — Neel Nanda¹
¹ Not verbatim, but the gist of it
1 | Mechanistic Interpretability (MI)
▶ Sub-symbolic nature of deep learning obscures model mechanisms
▶ No obvious mapping from the weights of a trained model to math notation
▶ MI is about reverse engineering these models, and looking closely at them
▶ How does a given model work? How can we train it faster? Is it safe?
1 | Mechanistic Interpretability (MI)
▶ Early MI work focuses on modular addition [1]
▶ 𝒯_nanda focuses on a model mapping (𝑥₀, 𝑥₁) → 𝑦
▶ True mapping given by 𝑦 = (𝑥₀ + 𝑥₁) mod 𝑝

(0,0) (1,0) (2,0)
(0,1) (1,1) (2,1)
(0,2) (1,2) (2,2)
Table 1: Table of (𝑥₀, 𝑥₁)-tuples for 𝑝 = 3

Figure 1: esch diagram of (𝑥₀, 𝑥₁)-tuples for 𝑝 = 5
1 | Mechanistic Interpretability (MI)
▶ Focusing on 𝑦 from 𝒯_nanda
Figure 2: esch diagram of 𝑦 from 𝒯_nanda (remainders over 𝑥₀ and 𝑥₁)
1 | Mechanistic Interpretability (MI)
▶ Array
 7  2 11  4  9  1  8  2 10  6  3 10  5
 3  8  1  7 12  5  2  9 11  4  0  6 10
11  4  9  2  6  0  7  3  8  2  1 11 10
 5 10  3  8  1 12  4  7  2  9  1 10  6
12  6  0 11  4  8  1  5 10  3  7  2  9
 2  9  7  0 11  3 12  6  4  8 10  1  5
 8  1 12  5 10  7  0 11  9  2  6  4  3
 4 11  6  9  3  2 10  1  7  0 12  8  5
10  5  2 12  7  9  3  0  6  1  8 11  4
 6 12  8  3  0 11  5  4  1 10  2  9  7
 1  7  4 10  8  6  9  2 12  5 11  3  0
 9  3 10  6  2  4 11  8  5  7  0 12  1
 0  8  5  1 11 10  6 12  3  9  4  7  2
1 | Mechanistic Interpretability (MI)
▶ Are hard to see
1 | Mechanistic Interpretability (MI)
1. Make a task
2. Solve the task
3. Inspect the solution
▶ Think artificial neuroscience
Figure 3: Target 𝑦 as 𝑥₀ and 𝑥₁ move from 0 to 𝑝 − 1 for the task (𝑥₀ + 𝑥₁) mod 𝑝 = 𝑦
1.1 | Grokking [1]
▶ Sudden generalization long after overfitting
▶ MI (by definition) needs a mechanism
▶ Grokking is thus convenient for MI
Figure 4: Example of grokking
2 | Modular Arithmetic
▶ “Seminal” MI paper by Nanda et al. (2023) focuses on modular addition (𝒯_nanda)
▶ Their final setup trains on 𝑝 = 113
▶ They train a one-layer transformer
▶ We call their task 𝒯_nanda
▶ And ours we call 𝒯_miiii

𝒯_nanda = (𝑥₀ + 𝑥₁) mod 𝑝,  ∀ 𝑥₀, 𝑥₁   (1.1)
𝒯_miiii = (𝑥₀𝑝⁰ + 𝑥₁𝑝¹) mod 𝑞,  ∀ 𝑞 < 𝑝   (1.2)
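To make the two definitions concrete, here is a minimal sketch of how the two datasets could be generated; plain NumPy, the helper name primes_below, and the array layout are my own assumptions rather than the miiii codebase.

```python
import numpy as np

def primes_below(n: int) -> np.ndarray:
    """Sieve of Eratosthenes: all primes strictly below n."""
    sieve = np.ones(n, dtype=bool)
    sieve[:2] = False
    for i in range(2, int(n**0.5) + 1):
        if sieve[i]:
            sieve[i * i :: i] = False
    return np.flatnonzero(sieve)

p = 113
x0, x1 = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
x0, x1 = x0.ravel(), x1.ravel()                 # all p**2 ordered (x0, x1) pairs

# Eq. (1.1): a single target per pair
y_nanda = (x0 + x1) % p

# Eq. (1.2): one target per prime q < p; x1 is scaled by p, so the task is non-commutative
qs = primes_below(p)                            # 2, 3, 5, ..., 109 (29 subtasks)
y_miiii = np.stack([(x0 * p**0 + x1 * p**1) % q for q in qs], axis=1)

print(y_nanda.shape, y_miiii.shape)             # (12769,) (12769, 29)
```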
2 | Modular Arithmetic
▶ 𝒯_miiii is non-commutative …
▶ … and multi-task: 𝑞 ranges from 2 to 109¹
▶ 𝒯_nanda uses a single-layer transformer
▶ Note that these tasks are synthetic and trivial to solve with conventional programming
▶ They are used in the MI literature to turn black boxes transparent
¹ Largest prime less than 𝑝 = 113
Figure 5: ℕ < 𝑝² that are multiples of 13 or 27 (left), of 11 (mid.), or primes (right)

3 | Grokking on 𝒯_miiii
▶ For two-token samples, plot them varying one on each axis (Figure 6)
▶ When a matrix is periodic, use Fourier
▶ Singular value decomposition
Figure 6: Top singular vectors 𝐔 of 𝑊_E for 𝒯_nanda (top); varying 𝑥₀ and 𝑥₁ in sample (left) and freq. (right) space in 𝑊_out for 𝒯_miiii
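As a rough sketch of the two tools named above (a Fourier view of a periodic matrix and the SVD of a weight matrix), the snippet below runs both on placeholder arrays; the random matrices merely stand in for the trained 𝑊_E and 𝑊_out and are not the actual model weights.

```python
import numpy as np

p, d = 113, 256
W = np.random.randn(p, d)                 # placeholder for a learned weight matrix (p tokens, d dims)

# Singular value decomposition: U[:, k] is the k-th singular vector over the token axis
U, S, Vt = np.linalg.svd(W, full_matrices=False)
top_vectors = U[:, :3]                    # these are the curves one would plot against 0..p-1

# If a matrix is periodic along the token axis, its Fourier transform there is sparse
periodic = np.cos(2 * np.pi * 7 * np.arange(p) / p)[:, None] * np.random.randn(1, d)
power = np.abs(np.fft.rfft(periodic, axis=0)) ** 2
print(power.sum(axis=1).argmax())         # dominant frequency index (7 by construction)
```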
3 | Grokking on 𝒯_miiii
▶ The model groks on 𝒯_miiii (Figure 7)
▶ Needed GrokFast [2] on our compute budget
▶ Final hyperparams are seen in Table 3
rate: 1/10   𝜆: 1/2   wd: 1/3   𝑑: 256   lr: 3·10⁻⁴   heads: 4
Table 3: Hyperparams for 𝒯_miiii

Figure 7: Training (top) and validation (bottom) accuracy during training on 𝒯_miiii
4 | Embeddings
How the embedding layer deals with the difference between 𝒯_nanda and 𝒯_miiii
4.1 | Correcting for non-commutativity
▶ The position embs. of Figure 8 reflect that 𝒯_nanda is commutative and 𝒯_miiii is not
▶ Maybe: this corrects for the non-commutativity of 𝒯_miiii?
▶ Corr. is 0.95 for 𝒯_nanda and −0.64 for 𝒯_miiii
Figure 8: Positional embeddings for 𝒯_nanda (top) and 𝒯_miiii (bottom)
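The quoted correlations can be computed directly from the two learned position vectors; the sketch below shows the computation on placeholder arrays (the names pos_emb_0 and pos_emb_1 are mine, not the codebase's).

```python
import numpy as np

d = 256
pos_emb_0 = np.random.randn(d)    # placeholder: learned embedding of position 0
pos_emb_1 = np.random.randn(d)    # placeholder: learned embedding of position 1

# Pearson correlation between the two position vectors.
# Near +1: positions treated as interchangeable (commutative task, T_nanda).
# Negative: positions kept distinct (non-commutative task, T_miiii).
corr = np.corrcoef(pos_emb_0, pos_emb_1)[0, 1]
print(round(float(corr), 2))
```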
4.2 | Correcting for multi-tasking
▶ For 𝒯_nanda, token embs. are essentially linear combinations of 5 frequencies (𝜔)
▶ For 𝒯_miiii, more frequencies are in play
▶ Each 𝒯_miiii subtask targets a unique prime
▶ Possibility: one basis per prime task
Figure 9: 𝒯_nanda (top) and 𝒯_miiii (bottom) token embeddings in Fourier basis
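One way to see "more frequencies in play" is to express the token embedding matrix in a Fourier basis along the token axis and rank frequencies by their norm; the sketch below does this on a placeholder matrix, so shapes and names are assumptions.

```python
import numpy as np

p, d = 113, 256
W_E = np.random.randn(p, d)             # placeholder token embeddings (one row per token 0..p-1)

# Express each embedding dimension in the Fourier basis over the token index
F = np.fft.rfft(W_E, axis=0)            # shape (p // 2 + 1, d)
freq_norm = np.linalg.norm(F, axis=1)   # total contribution of each frequency omega

# For T_nanda only a handful of omegas should dominate; for T_miiii many do
dominant = np.argsort(freq_norm)[::-1][:5]
print(dominant, freq_norm[dominant])
```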
4.3 | Sanity-check and task-mask
▶ Masking 𝑞 ∈ {2, 3, 5, 7} yields a slight decrease in token emb. freqs.
▶ Sanity check: 𝒯_baseline has no periodicity
▶ The tok. embs. encode a basis per subtask?
Figure 10: 𝒯_baseline (top), 𝒯_miiii (middle) and 𝒯_masked (bottom) token embeddings in Fourier basis
5 | Neurons
▶ In spite of the dense Fourier basis of 𝑊_E for 𝒯_miiii, the periodicity is clear
Figure 11: Activations of first three neurons for 𝒯_nanda (top) and 𝒯_miiii (bottom)
5 | Neurons
▶ (Probably redundant) sanity check: Figure 12 confirms neurons are periodic
▶ See some freqs. 𝜔 rise into significance
▶ Let's log |{𝜔 : 𝜔 > 𝜇_𝜔 + 2𝜎_𝜔}| while training
Figure 12: FFT of activations of first three neurons for 𝒯_nanda (top) and 𝒯_miiii (bottom)
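The quantity logged during training, the number of frequencies whose power exceeds 𝜇_𝜔 + 2𝜎_𝜔, could be computed along the lines below; the activation tensor, its (𝑥₀, 𝑥₁, neuron) layout, and the exact thresholding are illustrative assumptions.

```python
import numpy as np

def count_active_frequencies(acts: np.ndarray) -> int:
    """acts: (p, p, n_neurons) activations over the full (x0, x1) grid (assumed layout)."""
    # 2D FFT over the (x0, x1) grid, then average power per frequency across neurons
    power = np.abs(np.fft.fft2(acts, axes=(0, 1))) ** 2
    spectrum = power.mean(axis=-1).ravel()
    # a frequency counts as active if its power exceeds mean + 2 standard deviations
    threshold = spectrum.mean() + 2 * spectrum.std()
    return int((spectrum > threshold).sum())

acts = np.random.randn(113, 113, 512)   # placeholder activations
print(count_active_frequencies(acts))
```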
Figure 13: Neurons as archive for 𝒯_baseline
Figure 14: Neurons as algorithm for 𝒯_miiii
Figure 15: Number of neurons with frequency 𝜔 above the threshold 𝜇_𝜔 + 2𝜎_𝜔
6 | The 𝜔-Spike
▶ Neurs. periodic on solving 𝑞 ∈ {2, 3, 5, 7}
▶ When we generalize to the remaining tasks, many frequencies activate (64-sample)
▶ Those 𝜔's are not useful for memory and not useful after generalization

time    256   1024   4096   16384   65536
|𝝎|     0     0      10     18      10
Table 4: Active 𝜔's through training

Figure 16: Figure 15 (top) and validation accuracy from Figure 7 (bottom)
6 | The 𝜔-Spike
▶ GrokFast [2] shows the gradient sequence through time is (arguably) a stochastic signal with:
▶ A fast-varying overfitting component
▶ A slow-varying generalizing component
▶ My work confirms this to be true for 𝒯_miiii …
▶ … and observes a structure that seems to fit neither of the two
6 | The 𝜔-Spike
▶ Future work:
▶ Modify GrokFast to assume a third stochastic component
▶ Relate to the signal processing literature
▶ Can more depth make the token embeddings sparse?
References
[1] N. Nanda, L. Chan, T. Lieberum, J. Smith, and J. Steinhardt, “Progress Measures for Grokking via Mechanistic Interpretability,” arXiv:2301.05217, Oct. 2023.
[2] J. Lee, B. G. Kang, K. Kim, and K. M. Lee, “Grokfast: Accelerated Grokking by Amplifying Slow Gradients,” arXiv:2405.20233, Jun. 2024.
A | Stochastic Signal Processing
We denote the weights of a model as 𝜃. The gradient of 𝜃 with respect to our loss function at time 𝑡 we denote 𝑔(𝑡). As we train the model, 𝑔(𝑡) varies, going up and down. This can be thought of as a stochastic signal. We can represent this signal in a Fourier basis. GrokFast posits that the slow-varying frequencies contribute to grokking. Higher frequencies are then muted, and grokking is indeed accelerated.
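A minimal sketch of that idea: low-pass filter the gradient signal 𝑔(𝑡) with an exponential moving average and amplify the slow component before the optimizer step, in the spirit of GrokFast's EMA variant. The update rule and the default values of alpha and lamb here are simplified assumptions, not the reference implementation.

```python
import numpy as np

def grokfast_ema_filter(grads, ema, alpha=0.98, lamb=2.0):
    """Amplify the slow-varying component of each gradient g(t) via an EMA low-pass filter.

    grads: dict mapping parameter name -> current gradient array
    ema:   dict with the running EMA of past gradients (same keys and shapes)
    """
    filtered = {}
    for name, g in grads.items():
        ema[name] = alpha * ema[name] + (1 - alpha) * g   # estimate of the slow component
        filtered[name] = g + lamb * ema[name]             # boost it relative to the fast noise
    return filtered, ema

# Toy usage: one parameter with a noisy gradient around a slow drift
ema = {"w": np.zeros(3)}
for t in range(5):
    grads = {"w": np.full(3, 0.1) + 0.01 * np.random.randn(3)}
    grads, ema = grokfast_ema_filter(grads, ema)
```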
B | Discrete Fourier Transform
A function can be expressed as a linear combination of cosine and sine waves. A similar thing can be done for data / vectors.
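A small numerical example of the statement above: a vector built from two waves is fully described by a couple of DFT coefficients, and the inverse transform recovers it exactly. The specific frequencies and length are arbitrary choices for illustration.

```python
import numpy as np

n = 64
t = np.arange(n)
x = 2 * np.cos(2 * np.pi * 3 * t / n) + 0.5 * np.sin(2 * np.pi * 10 * t / n)

X = np.fft.fft(x)                      # complex DFT coefficients
power = np.abs(X[: n // 2]) ** 2
print(np.argsort(power)[::-1][:2])     # the two dominant frequencies: 3 and 10

x_rec = np.fft.ifft(X).real            # the inverse transform recovers the data
print(np.allclose(x, x_rec))           # True
```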
C | Singular Value Decomposition
An 𝑚 × 𝑛 matrix 𝑀 can be represented as 𝑈Σ𝑉*, where 𝑈 is an 𝑚 × 𝑚 complex unitary matrix, Σ a rectangular 𝑚 × 𝑛 diagonal matrix (padded with zeros), and 𝑉 an 𝑛 × 𝑛 complex unitary matrix. Multiplying by 𝑀 can thus be viewed as first rotating in the 𝑛-space with 𝑉*, then scaling by Σ, and then rotating by 𝑈 in the 𝑚-space.
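A quick check of the decomposition in NumPy, which returns 𝑉* directly; the shapes follow the convention stated above, with a small random matrix standing in for 𝑀.

```python
import numpy as np

m, n = 5, 3
M = np.random.randn(m, n)

U, S, Vh = np.linalg.svd(M)              # U: (m, m), S: min(m, n) values, Vh = V*: (n, n)
Sigma = np.zeros((m, n))
Sigma[: len(S), : len(S)] = np.diag(S)   # rectangular diagonal matrix, padded with zeros

print(np.allclose(M, U @ Sigma @ Vh))    # True: M = U Σ V*
```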