NOAH SYRKIS

Mechanistic Interpretability on (multi-task) Irreducible Integer Identifiers
Noah Syrkis
January 16, 2025

Outline:
1 | Mechanistic Interpretability (MI)
2 | Modular Arithmetic
3 | Grokking on 𝒯miiii
4 | Embeddings
5 | Neurons
6 | The 𝜔-Spike

Figure 1: Numbers < 𝑝² that are multiples of 13 or 27 (left), of 11 (mid.), or prime (right)

"This disgusting pile of matrices is actually just an astoundingly poorly written, elegant and concise algorithm" — Neel Nanda¹
¹Not verbatim, but the gist of it

1 | Mechanistic Interpretability (MI)

- The sub-symbolic nature of deep learning obscures model mechanisms
- There is no obvious mapping from the weights of a trained model to math notation
- MI is about reverse engineering these models and looking closely at them
- Many low-hanging fruits / the practical-botany phase of the science
- How does a given model work? How can we train it faster? Is it safe?

1 | Grokking

- Grokking [1] is "sudden generalization"
- MI (often) needs a mechanism
- Grokking is thus convenient for MI
- Lee et al. (2024) speed up grokking by boosting slow gradients as per Eq. 1
- For more, see Appendix A

ℎ(𝑡) = ℎ(𝑡−1) 𝛼 + 𝑔(𝑡) (1 − 𝛼)    (1.1)
̂𝑔(𝑡) = 𝑔(𝑡) + 𝜆 ℎ(𝑡)    (1.2)
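To make Eq. 1 concrete, below is a minimal NumPy sketch of a GrokFast-style filter. It is not the project's implementation: treating the gradients as one flat vector and the values of 𝛼 and 𝜆 are illustrative assumptions.

```python
import numpy as np

def grokfast_filter(grads, alpha=0.98, lam=0.5):
    """Boost the slow-varying component of a gradient sequence (Eq. 1).

    h(t)     = alpha * h(t-1) + (1 - alpha) * g(t)   # Eq. 1.1: EMA of gradients
    g_hat(t) = g(t) + lam * h(t)                     # Eq. 1.2: amplified gradient
    """
    h = np.zeros_like(grads[0])
    filtered = []
    for g in grads:
        h = alpha * h + (1 - alpha) * g   # slow (low-pass) component
        filtered.append(g + lam * h)      # add the amplified slow component
    return filtered

# Toy usage: noisy gradients with a slow drift; the filter emphasizes the drift.
rng = np.random.default_rng(0)
grads = [rng.normal(size=4) + 0.01 * t for t in range(100)]
boosted = grokfast_filter(grads)
```

The EMA ℎ acts as a low-pass filter, so adding 𝜆ℎ back onto 𝑔 amplifies the slow-varying component that GrokFast associates with generalization (see Appendix A).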
1 | Visualizing

- MI needs creativity … but there are tricks:
- For two-token samples, plot them varying one token on each axis (Figure 2)
- When a matrix is periodic, use a Fourier basis
- Singular value decomposition (Appendix C)
- Take away: get comfy with esch plots

Figure 2: Top singular vectors of 𝐔 of 𝑊_E for 𝒯nanda (top), and varying 𝑥₀ and 𝑥₁ in sample (left) and freq. (right) space in 𝑊_out for 𝒯miiii

Figure 3: Shameless plug: visit github.com/syrkis/esch for more esch plots

2 | Modular Arithmetic

- A "seminal" MI paper by Nanda et al. (2023) [3] focuses on modular addition (Eq. 2)
- Their final setup trains on 𝑝 = 113
- They train a one-layer transformer
- We call their task 𝒯nanda
- And ours, seen in Eq. 3, we call 𝒯miiii

(𝑥₀ + 𝑥₁) mod 𝑝,  𝑥₀, 𝑥₁ < 𝑝    (2)
(𝑥₀ 𝑝⁰ + 𝑥₁ 𝑝¹) mod 𝑞,  𝑞 < 𝑝    (3)

2 | Modular Arithmetic

- 𝒯miiii is non-commutative …
- … and multi-task: 𝑞 ranges from 2 to 109¹
- 𝒯nanda uses a single-layer transformer
- Note that these tasks are synthetic and trivial to solve with conventional programming
- They are used in the MI literature to turn black boxes transparent
¹Largest prime less than 𝑝 = 113

3 | Grokking on 𝒯miiii

- The model groks on 𝒯miiii (Figure 4)
- GrokFast [2] was needed given the compute budget
- Final hyperparameters are seen in Table 1

Table 1: Hyperparameters for 𝒯miiii
rate    𝜆      wd     𝑑      lr        heads
1/10    1/2    1/3    256    3·10⁻⁴    4

Figure 4: Training (top) and validation (bottom) accuracy during training on 𝒯miiii

4 | Embeddings

- The positional embeddings of Figure 5 reflect that 𝒯nanda is commutative and 𝒯miiii is not
- Maybe: this corrects for the non-commutativity of 𝒯miiii?
- Correlation is 0.95 for 𝒯nanda and 0.64 for 𝒯miiii

Figure 5: Positional embeddings for 𝒯nanda (top) and 𝒯miiii (bottom)

4 | Embeddings

- For 𝒯nanda, token embeddings are essentially linear combinations of 5 frequencies (𝜔)
- For 𝒯miiii, more frequencies are in play
- Each 𝒯miiii subtask targets a unique prime
- Possibility: one basis per prime task

Figure 6: 𝒯nanda (top) and 𝒯miiii (bottom) token embeddings in Fourier basis

4 | Embeddings

- Masking 𝑞 ∈ {2, 3, 5, 7} yields a slight decrease in token embedding frequencies
- Sanity check: 𝒯baseline has no periodicity
- Do the token embeddings encode a basis per subtask?

Figure 7: 𝒯baseline (top), 𝒯miiii (middle) and 𝒯masked (bottom) token embeddings in Fourier basis

5 | Neurons

- Figure 8 shows transformer MLP neuron activations as 𝑥₀ and 𝑥₁ vary on each axis
- In spite of the dense Fourier basis of 𝑊_E for 𝒯miiii, the periodicity is clear

Figure 8: Activations of the first three neurons for 𝒯nanda (top) and 𝒯miiii (bottom)

5 | Neurons

- (Probably redundant) sanity check: Figure 9 confirms the neurons are periodic
- We see some frequencies 𝜔 rise into significance
- Let's log |{𝜔 > 𝜇_𝜔 + 2𝜎_𝜔}| while training

Figure 9: FFT of activations of the first three neurons for 𝒯nanda (top) and 𝒯miiii (bottom)

Figure 10: Neurons as archive and algorithm. 𝒯baseline on top, FFT on right.

Figure 11: Number of neurons with frequency 𝜔 above the threshold 𝜇_𝜔 + 2𝜎_𝜔

6 | The 𝜔-Spike

- Neurons become periodic on solving 𝑞 ∈ {2, 3, 5, 7}
- When we generalize to the remaining tasks, many frequencies activate (64-sample)
- Those 𝜔's are not useful for memorization and not useful after generalization

Table 2: Active 𝜔's through training
time    256    1024    4096    16384    65536
|𝝎|     0      0       10      18       10

Figure 12: Figure 11 (top) and validation accuracy from Figure 4 (bottom)

6 | The 𝜔-Spike

- GrokFast [2] treats the gradient sequence over time as (arguably) a stochastic signal with:
  - a fast-varying overfitting component
  - a slow-varying generalizing component
- My work confirms this to be true for 𝒯miiii
- … and observes a structure that seems to fit neither of the two

6 | The 𝜔-Spike

Future work:
- Modify GrokFast to assume a third stochastic component
- Relate to the signal processing literature
- Can more depth make the token embeddings sparse?

Thank you!

References

[1] A. Power, Y. Burda, H. Edwards, I. Babuschkin, and V. Misra, "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets," arXiv:2201.02177, Jan. 2022. doi: 10.48550/arXiv.2201.02177.
[2] J. Lee, B. G. Kang, K. Kim, and K. M. Lee, "Grokfast: Accelerated Grokking by Amplifying Slow Gradients," arXiv:2405.20233, Jun. 2024.
[3] N. Nanda, L. Chan, T. Lieberum, J. Smith, and J. Steinhardt, "Progress Measures for Grokking via Mechanistic Interpretability," arXiv:2301.05217, Oct. 2023.

A | Stochastic Signal Processing

We denote the weights of a model as 𝜃. The gradient of 𝜃 with respect to our loss function at time 𝑡 we denote 𝑔(𝑡). As we train the model, 𝑔(𝑡) varies, going up and down. This can be thought of as a stochastic signal, which we can represent in a Fourier basis (Appendix B). GrokFast posits that the slow-varying frequencies contribute to grokking. Higher frequencies are then muted, and grokking is indeed accelerated.

B | Discrete Fourier Transform

A function can be expressed as a linear combination of cosine and sine waves of different frequencies. The same can be done for discrete data / vectors: the discrete Fourier transform expresses a length-𝑛 vector in a basis of sampled cosines and sines.

C | Singular Value Decomposition

An 𝑚×𝑛 matrix 𝑀 can be written as 𝑈Σ𝑉*, where 𝑈 is an 𝑚×𝑚 complex unitary matrix, Σ an 𝑚×𝑛 rectangular diagonal matrix (padded with zeros), and 𝑉 an 𝑛×𝑛 complex unitary matrix. Multiplying by 𝑀 can thus be viewed as first rotating in 𝑛-space with 𝑉*, then scaling (and changing dimension) with Σ, and then rotating with 𝑈 in 𝑚-space.
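As a concrete illustration of the decomposition above, here is a minimal NumPy sketch; the 5×3 random matrix is an arbitrary example, not data from the project.

```python
import numpy as np

# A random 5x3 real matrix M (so V* is simply V transposed).
M = np.random.default_rng(1).normal(size=(5, 3))

# Full SVD: U is 5x5, s holds the singular values, Vt is 3x3.
U, s, Vt = np.linalg.svd(M, full_matrices=True)

# Rebuild the rectangular diagonal Sigma (padded with zeros) and verify M = U @ Sigma @ Vt.
Sigma = np.zeros(M.shape)
np.fill_diagonal(Sigma, s)
assert np.allclose(M, U @ Sigma @ Vt)

# U[:, :k] gives the top-k singular vectors, the kind of thing plotted for the embeddings in Figure 2.
```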