Intuitions for Transformer Circuits

An exploration of mechanistic interpretability to understand the inner workings of Transformer models, focusing on the residual stream as shared memory and attention as an addressing mechanism.
A mental model for addressing the residual stream
In a previous post on language modeling, I implemented a GPT-style transformer. Lately I’ve been learning mechanistic interpretability to go deeper and understand why the transformer works on a mathematical level.
This post is a brain dump of what I’ve learned so far after reading A Mathematical Framework for Transformer Circuits (herein: “Framework”) and working through the Intro to Mech Interp section on ARENA. My goal is to describe my current intuition for the paper, especially parts I was confused about so that perhaps my take can help others gain clarity on these areas as well.
First, a brief aside on my overall motivation for working on this stuff. Mechanistic Interpretability (MI/mech interp) is the study of ML model internals whose aim is to understand from first principles why models behave and work as they do. You can kind of think of it as the machine learning analogue of reverse engineering software. It is similar in spirit to the science of biological neural networks, but applied to artificial neural networks instead.
MI is part of the broader field of interpretability, which in turn serves another field called AI alignment. Alignment strives to make our large AI models aligned with human values. Basically, the overall goal is to understand and control the models before they control us, and to ensure that they don’t engage in harmful, deceptive, dangerous, or subversive behavior. Unfortunately, we live in a world where large language models have encouraged “successful” suicide, engaged in blackmail for self-preservation, and asserted humans should be enslaved by AI. This current version of reality is unacceptable to me.
And as if that weren’t enough, we don’t even understand why these models do what they do. They are the only man-made technology in history that we don’t fully understand from first principles. Given this state of reality, I think that alignment is one of the most important problems we face today and one we have to get right. As a personal bonus, the alignment problem is as fascinating as it is important. It provides an outlet for me to leverage my specific technical skills and interests towards a meaningful cause. It is also extremely difficult, and I like a good challenge.
Ok, now back to the originally scheduled programming.
Attention-Only Transformers
Framework does a deep dive into the key components of a simplified transformer-based language model. It analyzes transformer blocks that only have multi-head attention. This means no MLPs and no layernorms. This leaves the token embedding and positional encoding at the beginning, followed by n layers of multi-head attention, followed by the unembedding at the end.
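To make that architecture concrete, here is a minimal NumPy sketch of an attention-only forward pass. It is my own toy illustration, not code from Framework or ARENA: the sizes, the random weights, the learned-style positional embedding, and the names `init_params` and `attention_only_forward` are all assumptions chosen for readability.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def init_params(rng, d_vocab=50, n_ctx=16, d_model=32, n_heads=2, d_head=8, n_layers=2):
    # Hypothetical toy sizes and random weights; a real model would be trained.
    def w(*shape):
        return rng.normal(0.0, 0.02, size=shape)
    layers = [
        [(w(d_model, d_head), w(d_model, d_head), w(d_model, d_head), w(d_head, d_model))
         for _ in range(n_heads)]
        for _ in range(n_layers)
    ]
    return {"W_E": w(d_vocab, d_model), "W_pos": w(n_ctx, d_model),
            "layers": layers, "W_U": w(d_model, d_vocab)}

def attention_only_forward(tokens, params):
    """tokens: int array of shape (seq,). Returns logits of shape (seq, d_vocab)."""
    seq = len(tokens)
    # Token embedding plus positional embedding initialize the residual stream.
    resid = params["W_E"][tokens] + params["W_pos"][:seq]
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    for layer in params["layers"]:
        layer_out = np.zeros_like(resid)
        for W_Q, W_K, W_V, W_O in layer:  # one weight tuple per head
            q, k, v = resid @ W_Q, resid @ W_K, resid @ W_V
            scores = q @ k.T / np.sqrt(W_Q.shape[1])
            scores = np.where(causal, scores, -np.inf)  # no attending to the future
            layer_out += softmax(scores) @ v @ W_O
        # No MLP and no layernorm: heads just add their output back in.
        resid = resid + layer_out
    return resid @ params["W_U"]  # unembedding

rng = np.random.default_rng(0)
params = init_params(rng)
tokens = np.array([3, 1, 4, 1, 5])
logits = attention_only_forward(tokens, params)
```

Every head in a layer reads the same `resid` and the layer's outputs are summed back in, so the whole model is embed, add, add, ..., unembed.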
My goal in this post is not to re-derive all the math, because the *Framework* paper does a better job. Instead, I want to share how I conceptualize the most important takeaways, especially in areas I found confusing at first, so that if you have the same confusion, my take might bring some clarity. In my view, the most important concepts to understand from this paper are the residual stream, attention, circuits, and induction heads.
The Residual Stream
Mathematically, the residual stream is a high-dimensional vector space. You will usually see the dimension of the residual stream specified as d_model in GPT-related papers and code. For example, GPT2-small uses a d_model of 768.
Conceptually, the residual stream is like shared memory. It is used much like the DRAM on your computer. Different components of the model (attention, MLPs, etc.) perform loads and stores from that memory. The loads and stores occur sequentially through the forward pass, one layer at a time. However, each component in a given layer loads in parallel and stores in parallel with the others. The model learns to carve out subspaces in this vector space, which helps prevent components from clobbering what previous components have written. The residual stream itself doesn’t do any computation; it serves as a shared medium through which layers communicate with each other.
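The shared-memory picture can be sketched in a few lines of NumPy. This is a toy under stated assumptions: the subspaces `sub_a` and `sub_b` are hand-picked and exactly orthogonal, whereas a trained model learns its own directions, which are usually only approximately orthogonal. The helper names `store` and `load` are mine.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8

# Hypothetical orthonormal basis for a tiny residual stream (QR of a random matrix).
Q, _ = np.linalg.qr(rng.normal(size=(d_model, d_model)))
sub_a = Q[:, :2].T   # rows span a 2-d subspace used by "component A"
sub_b = Q[:, 2:4].T  # rows span an orthogonal 2-d subspace used by "component B"

def store(resid, subspace, coords):
    """A store is an addition of a vector that lies in the given subspace."""
    return resid + coords @ subspace

def load(resid, subspace):
    """A load is a projection onto the subspace's basis vectors."""
    return subspace @ resid

resid = np.zeros(d_model)
resid = store(resid, sub_a, np.array([1.0, -2.0]))
resid = store(resid, sub_b, np.array([0.5, 3.0]))

a_back = load(resid, sub_a)  # recovers A's coordinates
b_back = load(resid, sub_b)  # recovers B's coordinates

# A later write into the same subspace overwrites (adds onto) what A stored,
# while B's orthogonal subspace is untouched.
clobbered = store(resid, sub_a, np.array([10.0, 10.0]))
```

Because stores are additions, the order of the two writes does not matter, which is consistent with components in one layer storing in parallel.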
When I was presented with this view of the residual stream, my mind immediately started asking: how far can we push this analogy to memory? Having worked in computer security for a decade, I wondered whether there is an analogue to page tables and memory permissions. Could we bring in the concepts of userspace and kernelspace to prevent “privileged” subspaces from being accessed by “unprivileged” subspaces?
But I’m getting ahead of myself. Let’s start with a simpler question: how does addressing work for the residual stream? In order to access a memory location, you have to have an address. Residual stream addresses can be decomposed into two logical parts, token:subspace, much like the classic segment:offset logical address from the x86 architecture. One major difference is that a traditional memory address is deterministic in the sense that only one value from one location is loaded. Addresses into the residual stream are “soft”, in general specifying a set of locations to load according to some learned probability distribution.
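The difference between a deterministic address and a “soft” one can be shown with a small NumPy sketch. It covers only the token part of the address, and it assumes the soft weights come from a softmax over raw scores; the `memory` contents and the function names `hard_load` and `soft_load` are made up for illustration.

```python
import numpy as np

# A toy "memory" with one row per location.
memory = np.array([[1.0, 0.0, 0.0],
                   [0.0, 2.0, 0.0],
                   [0.0, 0.0, 3.0],
                   [4.0, 4.0, 4.0]])

def hard_load(memory, index):
    """Traditional addressing: exactly one location is read."""
    return memory[index]

def soft_load(memory, scores):
    """Soft addressing: a probability-weighted blend of all locations."""
    scores = np.asarray(scores, dtype=float)
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()
    return weights @ memory, weights

# A sharply peaked distribution behaves almost like a hard load of row 2.
peaked, w_peaked = soft_load(memory, [0.0, 0.0, 20.0, 0.0])

# A flat distribution returns the average of every row.
uniform, w_uniform = soft_load(memory, [0.0, 0.0, 0.0, 0.0])
```

A hard load is the limiting case of a soft load in which all of the probability mass sits on a single location.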
Attention
Conceptually, attention computes the first part of the token:subspace address. The fundamental purpose of attention is to specify which source token locations to load information from. Each row in the attention matrix is the “soft” distribution over the source (i.e. key) token indices from which information will be moved into the destination token (i.e. query).
It is important to understand that attention is all about figuring out the token indices to read from. If we look at the residual stream as a two-dimensional memory array, with one row per token position and one column per residual dimension, then attention probabilistically selects rows of this memory for each query.
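Here is a rough sketch of how one head's attention pattern could be computed and read as row selection. It assumes a GPT-style causal mask and random weights; `attention_pattern` and the toy shapes are my own choices, not taken from Framework.

```python
import numpy as np

def attention_pattern(resid, W_Q, W_K):
    """resid: (seq, d_model). Returns a (seq, seq) matrix whose row i holds
    query i's weights over the source (key) positions."""
    q = resid @ W_Q
    k = resid @ W_K
    scores = q @ k.T / np.sqrt(W_Q.shape[1])
    seq = resid.shape[0]
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.where(causal, scores, -np.inf)  # a query cannot see later tokens
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq, d_model, d_head = 5, 16, 4
resid = rng.normal(size=(seq, d_model))     # one row per token position
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))

pattern = attention_pattern(resid, W_Q, W_K)

# Each destination token gets a probability-weighted mix of source rows.
selected_rows = pattern @ resid
```

Each row of `pattern` sums to 1 and is zero above the diagonal, so the first token can only select its own row.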
So the token part of the address selects the rows in the residual stream via attention. What about the subspace part? How is it computed? Once we have this part then we can determine the actual *value* that is stored into the destination token’s location. To answer t
Source: Hacker News










