Attention Residuals

Attention Residuals (AttnRes) is a drop-in replacement for standard residual connections that enables selective aggregation of earlier representations via learned attention, significantly boosting performance in reasoning and coding tasks.
Figure: (a) Standard residuals with uniform additive accumulation. (b) Full AttnRes: each layer attends over all previous outputs. (c) Block AttnRes: layers are grouped into N blocks, reducing memory from O(Ld) to O(Nd), where L is the number of layers and d the hidden dimension.
This is the official repository for Attention Residuals (AttnRes), a drop-in replacement for standard residual connections in Transformers that enables each layer to selectively aggregate earlier representations via learned, input-dependent attention over depth.
Standard residual connections accumulate all layer outputs with fixed unit weights. As depth grows, this uniform aggregation dilutes each layer's contribution and causes hidden-state magnitudes to grow without bound, a well-known problem with PreNorm.
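As a toy illustration of that growth (not a model of a real network), the sketch below treats every sublayer output as an independent standard-Gaussian vector, a deliberately crude assumption, and tracks the RMS of the residual stream under fixed unit-weight accumulation. Independent terms add in quadrature, so the RMS after k terms grows roughly as sqrt(k). The helper names `rms` and `standard_residual_rms` are hypothetical.

```python
import math
import random


def rms(v: list[float]) -> float:
    """Root-mean-square magnitude of a vector."""
    return math.sqrt(sum(x * x for x in v) / len(v))


def standard_residual_rms(depth: int, dim: int = 256, seed: int = 0) -> list[float]:
    """RMS of the residual stream after each step of h_l = h_{l-1} + f_l.

    Hypothetical model: the embedding and every sublayer output f_l are
    i.i.d. standard Gaussian vectors.
    """
    rng = random.Random(seed)
    h = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # embedding h_0
    history = [rms(h)]
    for _ in range(depth):
        f = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # stand-in sublayer output
        h = [a + b for a, b in zip(h, f)]  # fixed unit-weight accumulation
        history.append(rms(h))
    return history


if __name__ == "__main__":
    mags = standard_residual_rms(depth=63)
    for step in (0, 3, 15, 63):
        # In expectation the RMS is about sqrt(step + 1): ~1, ~2, ~4, ~8
        print(step, round(mags[step], 2))
```

Nothing in this toy setup bounds the sum, which is the behavior AttnRes is designed to avoid.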
AttnRes replaces this fixed accumulation with softmax attention over preceding layer outputs.
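A minimal single-token sketch of Full AttnRes follows. It assumes Full AttnRes scores earlier outputs the same way the Block AttnRes pseudocode in this README scores blocks: each sublayer owns one learned query vector (the role of `proj.weight`), keys are RMS-normalized copies of the earlier outputs, and the sublayer input is their softmax-weighted sum. The plain-list representation, the gain-free `rms_norm`, and the name `full_attn_res` are illustrative assumptions, not the repository's API.

```python
import math


def rms_norm(v: list[float], eps: float = 1e-6) -> list[float]:
    """RMSNorm with its learned gain fixed to 1 (simplification for this sketch)."""
    scale = 1.0 / math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x * scale for x in v]


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def full_attn_res(prev_outputs: list[list[float]], query: list[float]) -> tuple[list[float], list[float]]:
    """Input to the next sublayer: softmax attention over ALL earlier outputs.

    prev_outputs[0] is the embedding, prev_outputs[i] is sublayer i's output.
    The query is a single learned vector, but the weights still depend on the
    input because the keys are built from the outputs themselves.
    """
    logits = [sum(q * k for q, k in zip(query, rms_norm(v))) for v in prev_outputs]
    weights = softmax(logits)
    dim = len(prev_outputs[0])
    h = [sum(w * v[j] for w, v in zip(weights, prev_outputs)) for j in range(dim)]
    return h, weights


if __name__ == "__main__":
    outputs = [[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]]
    h, weights = full_attn_res(outputs, query=[0.0, 0.0])
    print(weights)  # a zero query gives uniform weights of 1/3 each
    print(h)        # so h is the plain average, about [1.333, 1.333]
```

Because the weights form a softmax, `h` is a convex combination of earlier outputs, so its norm can never exceed the largest of theirs. Standard residuals instead fix every weight at 1, which is the special case this replaces.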
Full AttnRes is straightforward but requires O(Ld) memory at scale. Block AttnRes partitions layers into N blocks, accumulates within each block via standard residuals, and applies attention only over block-level representations. With ~8 blocks, it recovers most of Full AttnRes's gains while serving as a practical drop-in replacement with marginal overhead.
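The block scheme can be sketched per token in plain Python to make the memory saving concrete. This is a simplified illustration, not the repository's code: `block_size` here counts sublayers per block (a hypothetical parameterization), sublayers are arbitrary callables, and the function reports the peak number of stored d-dimensional vectors instead of measuring real memory.

```python
import math
from typing import Callable

Vector = list[float]


def rms_norm(v: Vector, eps: float = 1e-6) -> Vector:
    scale = 1.0 / math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x * scale for x in v]


def attend(sources: list[Vector], query: Vector) -> Vector:
    """Softmax over depth: weight each source by exp(query . RMSNorm(source))."""
    logits = [sum(q * k for q, k in zip(query, rms_norm(s))) for s in sources]
    m = max(logits)
    weights = [math.exp(z - m) for z in logits]
    total = sum(weights)
    dim = len(sources[0])
    return [sum(w * s[j] for w, s in zip(weights, sources)) / total for j in range(dim)]


def block_attn_res_stack(
    embedding: Vector,
    sublayers: list[Callable[[Vector], Vector]],
    queries: list[Vector],
    block_size: int,
) -> tuple[Vector, int]:
    """Run sublayers with Block AttnRes; return (running block sum, peak stored vectors).

    Each completed block is kept as ONE summed vector, so storage stays O(N*d)
    instead of the O(L*d) needed to keep every sublayer output.
    """
    blocks: list[Vector] = []   # one summed representation per completed block
    partial: Vector = embedding  # standard-residual sum of the current block
    peak = 1
    for i, (f, q) in enumerate(zip(sublayers, queries)):
        h = attend(blocks + [partial], q)  # inter-block attention
        if i % block_size == 0:           # block boundary: freeze the running sum
            blocks.append(partial)
            partial = [0.0] * len(embedding)
        out = f(h)
        partial = [a + b for a, b in zip(partial, out)]  # intra-block residual add
        peak = max(peak, len(blocks) + 1)
    return partial, peak


if __name__ == "__main__":
    dim, num_sublayers = 4, 32
    sublayers = [lambda h: [0.5 * x for x in h]] * num_sublayers  # toy stand-ins
    queries = [[0.0] * dim] * num_sublayers                         # uniform weights
    _, peak = block_attn_res_stack([1.0] * dim, sublayers, queries, block_size=4)
    print(peak)  # 8 blocks + the partial sum, vs. up to 33 vectors for Full AttnRes
```

With 32 sublayers grouped four at a time, at most 9 vectors are ever attended over, whereas Full AttnRes would keep the embedding plus all 32 outputs.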
PyTorch-style pseudocode
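```python
import torch
from torch import Tensor, nn


def block_attn_res(blocks: list[Tensor], partial_block: Tensor, proj: nn.Linear, norm: nn.Module) -> Tensor:
    """
    Inter-block attention: attend over block reps + partial sum.
    """
    V = torch.stack(blocks + [partial_block])  # values: [N+1, B, T, D]
    K = norm(V)  # keys: RMS-normalized values (norm is an RMSNorm over D)
    # proj is Linear(D, 1); its weight row acts as this sublayer's learned query
    logits = torch.einsum("d,nbtd->nbt", proj.weight.squeeze(0), K)
    # Softmax over the depth axis (dim 0), then mix the un-normalized values
    h = torch.einsum("nbt,nbtd->btd", logits.softmax(0), V)
    return h


# Method of a Transformer layer: self.attn, self.mlp, the norms, the two
# *_res_proj projections, layer_number and block_size are layer attributes.
def forward(self, blocks: list[Tensor], hidden_states: Tensor) -> tuple[list[Tensor], Tensor]:
    partial_block = hidden_states
    # Attention sublayer input: attend over completed blocks + running sum
    h = block_attn_res(blocks, partial_block, self.attn_res_proj, self.attn_res_norm)
    # Block boundary (block_size appears to count sublayers, two per layer)
    if self.layer_number % (self.block_size // 2) == 0:
        blocks.append(partial_block)  # freeze the finished block
        partial_block = None
    attn_out = self.attn(self.attn_norm(h))
    partial_block = partial_block + attn_out if partial_block is not None else attn_out
    # MLP sublayer input: same attention with its own query and norm
    h = block_attn_res(blocks, partial_block, self.mlp_res_proj, self.mlp_res_norm)
    mlp_out = self.mlp(self.mlp_norm(h))
    partial_block = partial_block + mlp_out  # intra-block standard residual
    return blocks, partial_block
```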
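The block-boundary bookkeeping in `forward` is easiest to follow by tracing it. The sketch below replays only that control flow, assuming (not confirmed by this README) that `layer_number` is 0-indexed and that the first layer receives `blocks=[]` with the embedding as `hidden_states`. Under that reading the embedding becomes the first block on its own. `boundary_schedule` is a hypothetical helper.

```python
def boundary_schedule(num_layers: int, block_size: int) -> list[tuple[int, int, int]]:
    """For each layer: (layer_number, #sources seen by attn, #sources seen by MLP).

    Mirrors the control flow of `forward`, assuming a 0-indexed layer_number and
    an initial call with blocks=[] and the embedding as hidden_states.
    """
    num_blocks = 0  # len(blocks)
    schedule = []
    for layer_number in range(num_layers):
        attn_sources = num_blocks + 1  # blocks + [partial_block]
        if layer_number % (block_size // 2) == 0:
            num_blocks += 1  # blocks.append(partial_block)
        mlp_sources = num_blocks + 1
        schedule.append((layer_number, attn_sources, mlp_sources))
    return schedule


if __name__ == "__main__":
    for row in boundary_schedule(num_layers=8, block_size=4):
        print(row)
```

With 8 layers and `block_size=4`, no sublayer attends over more than 5 sources: the embedding, three completed two-layer blocks, and the running partial sum.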
AttnRes consistently outperforms the baseline across all compute budgets. Block AttnRes matches the loss of a baseline trained with 1.25x more compute.
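Taking the 1.25x figure at face value, the equal-loss comparison translates into a compute saving as follows, with $C$ denoting training compute:

```latex
C_{\text{Block AttnRes}} = \frac{C_{\text{baseline}}}{1.25} = 0.8\, C_{\text{baseline}}
\quad\Longrightarrow\quad
1 - 0.8 = 20\% \text{ less compute for the same loss.}
```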
| Category | Benchmark | Baseline | AttnRes |
|---|---|---|---|
| General | MMLU | 73.5 | 74.6 |
| Reasoning | GPQA-Diamond | 36.9 | 44.4 |
| Logic | BBH | 76.3 | 78.0 |
| Knowledge | TriviaQA | 69.9 | 71.8 |
| Math & Code | Math | 53.5 | 57.1 |
| Code | HumanEval | 59.1 | 62.2 |
| Code | MBPP | 72.0 | 73.9 |
| Chinese | CMMLU | 82.0 | 82.9 |
| Chinese | C-Eval | 79.6 | 82.5 |
AttnRes improves across the board, with the largest gains on multi-step reasoning (+7.5 on GPQA-Diamond) and code generation (+3.1 on HumanEval). It mitigates PreNorm dilution: output magnitudes remain bounded across depth and gradient norms distribute more uniformly across layers.
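The per-benchmark gains can be recomputed from the results table. The snippet below is a convenience check with the scores transcribed by hand; it subtracts the baseline score from the AttnRes score for each row and lists the gains from largest to smallest.

```python
# Scores transcribed from the results table: benchmark -> (baseline, AttnRes)
RESULTS = {
    "MMLU": (73.5, 74.6),
    "GPQA-Diamond": (36.9, 44.4),
    "BBH": (76.3, 78.0),
    "TriviaQA": (69.9, 71.8),
    "Math": (53.5, 57.1),
    "HumanEval": (59.1, 62.2),
    "MBPP": (72.0, 73.9),
    "CMMLU": (82.0, 82.9),
    "C-Eval": (79.6, 82.5),
}


def gains(results: dict[str, tuple[float, float]]) -> dict[str, float]:
    """AttnRes minus baseline per benchmark, rounded to one decimal like the table."""
    return {name: round(attn - base, 1) for name, (base, attn) in results.items()}


if __name__ == "__main__":
    g = gains(RESULTS)
    for name in sorted(g, key=g.get, reverse=True):
        print(f"{name:>13}: +{g[name]}")
```

Every difference is positive, and GPQA-Diamond shows the largest gain.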