Rishiraj's blog

Scaling Beyond 4k

If you’re training a Large Language Model today, a 4k context window just doesn't cut it anymore. Users want to drop entire codebases into the prompt, analyze massive PDFs, and run complex, multi-turn agentic workflows. We all want our models to handle 128k tokens or more.

But as anyone who has actually tried to train a model on long sequences knows, pushing beyond a few thousand tokens breaks things. Your attention scores start decaying, your GPUs run out of memory, and your training throughput falls off a cliff.

In this post, we are going to walk through exactly how to solve these problems. We will look at how to tweak your model architecture to understand long-range dependencies, how to fix your data packing, and finally, how to distribute massive sequences across your GPU cluster using Context Parallelism and Zig-Zag Ring Attention.

Let's dig in.

The Positional Problem: Why Models Forget Distant Tokens

Transformers naturally have no sense of word order. They process everything at once. To fix this, we use positional embeddings. Almost every modern model uses RoPE (Rotary Positional Embedding), which encodes position information as rotation angles in a high-dimensional space.

RoPE is great, but it has a flaw when it comes to long contexts. As the distance between two tokens grows, so do the relative rotation angles between them, and the attention scores for distant tokens decay too rapidly. If a token is 60,000 positions away, the model struggles to "see" it.

To fix this, we have to adjust the math. There are two common ways to slow down this attention decay:

  1. ABF (Adjusted Base Frequency): This is a brute-force but effective method. You simply increase RoPE's base (theta). A larger base lowers the rotation frequencies, so angles accumulate more slowly with distance and attention scores no longer decay too fast.
  2. YaRN (Yet another RoPE extensioN): This is a bit smarter. Instead of scaling everything uniformly, YaRN interpolates frequencies unevenly across the RoPE dimensions. It gives you smoother scaling and prevents catastrophic attention loss at extreme lengths.
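To make ABF concrete, here is a minimal sketch of the standard RoPE frequency schedule in plain Python. The 2M base is illustrative (it matches the value used later in this post):

```python
def rope_inv_freq(head_dim, base):
    # RoPE rotates channel pair i of a head by angle position * base**(-2i/head_dim).
    # A larger base lowers every rotation frequency, so angles accumulate
    # more slowly with distance -- this is all ABF does.
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]

freq_default = rope_inv_freq(128, 10_000.0)   # common default base
freq_abf = rope_inv_freq(128, 2_000_000.0)    # adjusted base for long context

# Raising the base shrinks every rotation frequency (pair 0 is always 1.0).
assert freq_default[0] == 1.0
assert all(a <= d for a, d in zip(freq_abf, freq_default))
```

The fastest-rotating pairs are the ones that wrap around many full turns over tens of thousands of positions; slowing them down is what keeps distant tokens distinguishable.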

But what if we just... turn off positional embeddings entirely?

There is a fascinating approach called NoPE (No Positional Embedding). If you train a transformer without explicit positional encodings, it implicitly learns positional information through causal masking. The beauty of NoPE is that it naturally extrapolates to longer contexts far better than RoPE.

The catch? NoPE models usually perform worse on short-context reasoning and knowledge tasks.

So, how do we get the best of both worlds? For models like Llama 4, Command A, and SmolLM3, the answer is a hybrid approach. You alternate layers. You apply RoPE to some layers (to keep that sharp, local reasoning and recency bias) and apply NoPE to the rest (to allow information retrieval across massive distances). When researchers ran ablations on a 1B parameter model, they found this hybrid approach maintained short-context performance while vastly improving long-context handling.
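The alternation itself is trivial to express. A toy sketch, where the 1-in-4 NoPE interval is an illustrative assumption rather than a universal rule:

```python
def uses_rope(layer_idx, nope_interval=4):
    # Hybrid schedule: most layers keep RoPE for sharp local reasoning and
    # recency bias, but every nope_interval-th layer drops positional
    # encoding (NoPE) so it can retrieve information at any distance.
    # The 1-in-4 interval here is an illustrative choice, not a fixed rule.
    return (layer_idx + 1) % nope_interval != 0

schedule = ["RoPE" if uses_rope(i) else "NoPE" for i in range(8)]
# every fourth layer (indices 3 and 7) runs without positional embeddings
```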

The Data Problem: Stop Attending to Granola Recipes

Before we talk about GPU clusters, we have to talk about how we feed data into the model.

During pretraining, we use fixed sequence lengths. But our actual documents (web pages, code files) are variable lengths. To avoid wasting compute on padding tokens, we pack multiple documents together into a single sequence, separated by End-Of-Sequence (EOS) tokens.

If you are training on a 4k context, this isn't a huge deal. But what happens when you train on 32k or 64k sequences?

If you look at datasets like FineWeb-Edu or Python-Edu, over 80% of the documents are shorter than 2,000 tokens. That means a single 32k training sequence is actually a mixed khichdi of 15 to 20 completely unrelated documents.

If you use standard causal masking, a token in document 15 can attend to everything in documents 1 through 14. You might have a Python script spending compute cycles attending to a granola bar recipe and a Wikipedia article on climate change. This doesn't just waste computation; it actually introduces noise that degrades model performance.

The fix is Intra-document masking. You modify the attention mask so that tokens can only attend to previous tokens within their specific document. It essentially resets the attention window at every document boundary. When researchers tested this, it had zero negative impact on short-context tasks, but it became absolutely crucial for speeding up and stabilizing their long-context training runs.
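Here is a minimal sketch of what intra-document masking means, as a boolean mask over a toy packed sequence. (Real implementations never materialize this matrix; they pass document boundaries to a fused kernel such as FlashAttention's variable-length mode.)

```python
def intra_doc_causal_mask(doc_ids):
    # doc_ids[t] identifies which packed document token t belongs to.
    # Token q may attend to token k only if k <= q (causal) AND both
    # tokens sit in the same document -- attention resets at every
    # document boundary.
    n = len(doc_ids)
    return [[k <= q and doc_ids[k] == doc_ids[q] for k in range(n)]
            for q in range(n)]

# Two packed documents: tokens 0-2 and tokens 3-5.
mask = intra_doc_causal_mask([0, 0, 0, 1, 1, 1])
assert mask[4][3] is True   # within-document, past token: allowed
assert mask[4][1] is False  # cross-document: blocked
```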

The Memory Wall: Why TP and SP Aren't Enough

Okay, your architecture is set and your data is masked correctly. Now you have to actually fit a 64k or 128k sequence into GPU memory.

You might think, “I’ll just use Tensor Parallelism (TP) and Sequence Parallelism (SP).”

And you should! TP splits the matrix multiplications across the hidden dimension, and SP shards the activation memory for operations like Dropout and LayerNorm across the sequence dimension. By combining TP and SP, you significantly reduce your memory footprint per GPU.

But there is a hard limit. As your sequence length grows, the activation memory for the Attention block still blows up. Even if you use selective activation recomputation (throwing away activations during the forward pass and recalculating them during the backward pass to save memory), you still have to hold some activations at layer boundaries.

If your sequence is long enough, the Attention block simply will not fit on a single node, even with TP=8.

This is where we introduce Context Parallelism (CP).

Enter Context Parallelism and Ring Attention

Context Parallelism takes the same idea as Sequence Parallelism—splitting the input along the sequence dimension—but applies it to the Attention block itself.

But Attention is tricky. In an MLP layer, every token is processed independently. In an Attention layer, every single token needs to access the Key and Value (K/V) pairs from every previous token in the sequence. If you shard the sequence across 4 GPUs, GPU 4 needs the K/V pairs that are sitting over on GPU 1, GPU 2, and GPU 3.

How do we do this without grinding our cluster to a halt with communication overhead?

You could just do a massive "All-Gather" operation, where every GPU sends its K/V pairs to every other GPU simultaneously. But that requires a ton of temporary memory because every GPU suddenly has to hold the entire sequence's K/V pairs at once. It defeats the purpose of splitting the memory in the first place.

The more elegant solution is Ring Attention.

Instead of one big All-Gather, each GPU passes its K/V pairs around a ring using point-to-point sends and receives. Imagine we have 4 GPUs.

  1. Each GPU starts computing the attention scores for the chunk of data it holds locally.
  2. At the same time, it asynchronously sends its K/V pairs to the next GPU in the ring, while receiving K/V pairs from the previous GPU.
  3. Once the local computation is done, the GPU immediately starts computing on the newly arrived K/V pairs.
  4. It repeats this process until the K/V pairs have made a full circle around the GPUs.

It’s a bucket brigade. It is incredibly memory efficient because each GPU only needs enough memory to hold one extra chunk of data at a time. And because communication is overlapped with computation, it stays fast.
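Stripping away the attention math and the asynchronous communication, the K/V routing of those four steps reduces to a simple rotation. A sketch:

```python
def ring_schedule(num_gpus):
    # After step s, GPU g holds the K/V chunk originally owned by GPU
    # (g - s) mod num_gpus: every step, each chunk moves one hop around
    # the ring while the current chunk's attention is being computed.
    return [[(g - s) % num_gpus for s in range(num_gpus)]
            for g in range(num_gpus)]

# With 4 GPUs, GPU 0 computes against K/V chunks 0, 3, 2, 1 in that order;
# after num_gpus steps, every GPU has seen every chunk exactly once.
assert ring_schedule(4)[0] == [0, 3, 2, 1]
```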

The Triangle Problem: Why Zig-Zag is Necessary

Ring Attention sounds perfect, but a naive implementation falls apart in autoregressive language models. The culprit is the causal attention mask.

Think about what a causal mask looks like: it's a triangle. Early tokens only attend to a few things. Late tokens attend to everything.

If we deal out our sequence sequentially (GPU 1 gets tokens 1-4, GPU 2 gets tokens 5-8, etc.), we create a massive workload imbalance.

GPU 1 can compute its attention immediately. It doesn't need to wait for anything because tokens 1-4 only look at themselves. It finishes its work instantly and then sits idle. Meanwhile, GPU 4 (which has the last tokens in the sequence) has to wait for multiple rounds of the ring to get all the previous K/V pairs, and it has to do far more math.

This creates a massive "bubble"—idle time where most of your GPUs are just staring at the wall waiting for one GPU to finish the heavy lifting.

The fix for this is Zig-Zag Ring Attention.

Instead of dealing the sequence out in a straight line, we mix up the ordering. We distribute the sequence in a zig-zag pattern so that every single GPU gets a mix of "early" tokens and "late" tokens.

When you rearrange the data this way, the causal mask is no longer a lopsided triangle at the cluster level. Every GPU does essentially the same amount of work at the same time, eliminating the idle bubble.
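A sketch of the zig-zag split: chop the sequence into twice as many chunks as GPUs, then pair the earliest remaining chunk with the latest. (This pairing is the scheme used in common ring-attention implementations; details vary.)

```python
def zigzag_assignment(num_gpus):
    # Split the sequence into 2 * num_gpus chunks. GPU g gets chunk g
    # (cheap, early tokens) plus chunk 2*num_gpus - 1 - g (expensive,
    # late tokens), so every rank's causal-mask workload is balanced.
    n = 2 * num_gpus
    return {g: [g, n - 1 - g] for g in range(num_gpus)}

# With 4 GPUs: GPU 0 gets chunks 0 and 7, GPU 1 gets 1 and 6, and so on.
assert zigzag_assignment(4) == {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
```

Note that each rank's chunk indices sum to the same value, which is exactly why the triangle-shaped workload evens out.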

The Practical Recipe: How to Actually Train It

So you have your hybrid RoPE/NoPE architecture, your intra-document masking, and your Zig-Zag Context Parallelism ready to go. Should you just set the sequence length to 128k and hit train?

Absolutely not. That will destroy your compute budget.

Attention scales quadratically with sequence length. Training on massive sequences from step one is a waste of time and money, because early in training, the model is just trying to learn basic, short-range correlations between words. Long sequences don't help it learn English any faster.

The standard industry practice is sequential scaling.

Stage 1: The Foundation (4k) You do the vast majority of your training here. For SmolLM3, the folks at HF trained for 8 trillion tokens at a 4k context length. This builds the core intelligence of the model.

Stage 2: The First Jump (32k) Toward the very end of training, you extend the context. HF folks ran a stage over 50 billion tokens where they pushed the context to 32k. Crucial detail: When you make this jump, you must adjust your RoPE base frequency. HF folks bumped their theta value up to 2 million.

Stage 3: The Final Stretch (64k) HF folks did one final, short training stage to push the context to 64k, bumping the RoPE base frequency again to 5 million. They found that going any higher on the frequency (like 10 million) slightly hurt short-context tasks like math (GSM8k), so 5 million was the sweet spot.
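The three training stages can be summarized as a small config table. Numbers are the ones quoted in this post; the foundation stage's theta is whatever base the model was pretrained with, which the post doesn't specify:

```python
# Context-extension schedule as plain data (token counts / thetas from
# the SmolLM3 numbers above; None marks values this post doesn't give).
stages = [
    {"stage": "foundation", "seq_len": 4 * 1024,  "rope_theta": None,
     "tokens": 8_000_000_000_000},
    {"stage": "extend-32k", "seq_len": 32 * 1024, "rope_theta": 2_000_000,
     "tokens": 50_000_000_000},
    {"stage": "extend-64k", "seq_len": 64 * 1024, "rope_theta": 5_000_000,
     "tokens": None},  # short final stage; length not stated in the post
]

# Sanity check: theta must grow with every context jump (that's ABF).
thetas = [s["rope_theta"] for s in stages if s["rope_theta"] is not None]
assert thetas == sorted(thetas)
```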

Notice they stopped training at 64k.

Stage 4: Inference Extrapolation (128k) HF folks wanted the model to support 128k, but training on it was too expensive. Instead, they rely on YaRN to extrapolate at inference time. YaRN allows for roughly a 4x increase in sequence length without retraining.
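If you serve the model with Hugging Face transformers, YaRN is typically switched on through the config's rope_scaling field. A hedged sketch: field names follow recent transformers versions, so check the docs for yours:

```python
# Hypothetical config fragment: stretch a model trained at 64k to a 128k
# inference window, i.e. a YaRN scaling factor of 2 over the trained length.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 2.0,                                  # 64k * 2 = 128k
    "original_max_position_embeddings": 64 * 1024,  # trained context length
}

assert rope_scaling["factor"] * 64 * 1024 == 128 * 1024
```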

They actually tested using YaRN to push the 64k model to 256k, but performance on long-context benchmarks like RULER degraded. 128k was the safe, reliable limit.

Summary

Scaling beyond 4k isn't just about throwing more GPUs at the problem. It requires a full-stack approach.

You have to rethink your positional embeddings (mixing RoPE and NoPE). You have to clean up your dataloader (intra-document masking). You have to escape the memory wall by sharding the sequence across GPUs (Context Parallelism). And you have to keep your hardware utilization high by passing data in a circle and balancing the causal mask workload (Zig-Zag Ring Attention).

And most importantly, you have to be pragmatic about your compute budget, scaling up your sequence lengths in careful, deliberate stages. If you get all of these pieces right, you can build models that don't just generate text, but can actually reason over entire books and code repositories.