RWKV (Peng et al., 2023)
Peng, Alcaide, Anthony, Albalak, Arcadinho, Biderman, Cao, Cheng, and many collaborators' "RWKV: Reinventing RNNs for the Transformer Era" proposes Receptance Weighted Key Value, a language-model architecture that trains with Transformer-like parallelism but runs autoregressive inference like an RNN. Its central design is a channelwise linear-attention recurrence with a small fixed state rather than a growing key-value cache.
RWKV follows Hyena in attacking the Transformer's long-sequence cost, but it chooses a recurrent formulation rather than FFT long convolution. It also anticipates Mamba: both show that modern recurrent models can scale, but Mamba replaces RWKV's time-invariant WKV recurrence with selective state-space dynamics.
Definitions
Problem and motivation. A decoder-only Transformer stores key and value vectors for every previous token during generation. That cache grows linearly with context length and is read repeatedly. Ordinary RNNs avoid the cache by keeping a fixed hidden state, but historically they were harder to train at scale and underperformed Transformers. RWKV tries to combine the two: parallelizable training over sequences, constant-state inference, and enough expressive capacity to behave competitively with Transformers.
RWKV stands for Receptance Weighted Key Value. The main time-mixing block produces vectors analogous to attention quantities, a receptance $r_t$, a key $k_t$, and a value $v_t$, from token-shifted inputs:

$$r_t = W_r(\mu_r \odot x_t + (1-\mu_r) \odot x_{t-1}), \qquad k_t = W_k(\mu_k \odot x_t + (1-\mu_k) \odot x_{t-1}), \qquad v_t = W_v(\mu_v \odot x_t + (1-\mu_v) \odot x_{t-1}).$$
The $\mu$ parameters implement token shift: each channel can interpolate between the current token $x_t$ and the previous token $x_{t-1}$ before projection.
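To make token shift concrete, here is a minimal PyTorch sketch of one token-shifted projection; `mu` and `proj` are placeholder names for a learned interpolation vector and projection matrix, not the paper's parameter names.

```python
import torch

def token_shift_projection(x, mu, proj):
    """x: [batch, time, channels]; mu: [channels], entries in [0, 1]; proj: [channels, channels].

    Each channel interpolates between the current token and the previous one
    (zero-padded at the first position) before the linear projection.
    """
    x_prev = torch.nn.functional.pad(x, (0, 0, 1, 0))[:, :-1, :]  # x shifted right by one step
    mixed = mu * x + (1.0 - mu) * x_prev
    return mixed @ proj
```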
The WKV operation is a per-channel exponentially decayed weighted average. A simplified single-channel form is

$$wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i + e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}.$$

Here $w$ is a learned nonnegative time-decay parameter and $u$ is a learned bonus for the current token. The output gate, called receptance, is usually a sigmoid, giving a block output of the form $o_t = W_o(\sigma(r_t) \odot wkv_t)$.
RWKV blocks also contain a channel-mixing subblock, analogous to the feed-forward or MLP part of a Transformer, but with token shift and gating.
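As a rough sketch of that channel-mixing subblock (hedged: the placeholder weight names and the squared-ReLU nonlinearity follow the common RWKV reference implementation, not anything stated in this summary):

```python
import torch

def channel_mix(x, x_prev, mu_k, mu_r, w_k, w_v, w_r):
    """One channel-mixing step for a single token position.

    x, x_prev: [batch, channels]; mu_k, mu_r: [channels];
    w_k: [channels, hidden], w_v: [hidden, channels], w_r: [channels, channels].
    """
    k = (mu_k * x + (1.0 - mu_k) * x_prev) @ w_k   # token-shifted key projection
    r = (mu_r * x + (1.0 - mu_r) * x_prev) @ w_r   # token-shifted receptance projection
    return torch.sigmoid(r) * (torch.relu(k) ** 2 @ w_v)  # gated MLP-style output
```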
Key results
Method. RWKV replaces quadratic query-key attention with channelwise decayed accumulation. The WKV numerator and denominator can be updated recurrently, so inference only needs a fixed amount of state per layer and channel. During training, the same recurrence can be computed in a time-parallel mode using scan-like kernels and matrix multiplications. This is why the paper describes RWKV as having a Transformer form for training and an RNN form for inference.
The model is a stack of residual blocks. Each block has a time-mixing subblock and a channel-mixing subblock, with layer normalization and carefully designed initialization. The time-mixing recurrence contains exponential terms, so the paper uses numerically stable updates. Rather than store raw sums that can overflow, an implementation tracks scaled numerator, denominator, and a running maximum-like term.
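A sketch of the kind of numerically stable update the paragraph describes, tracking a scaled numerator `a`, a scaled denominator `b`, and a running exponent `p` (the maximum-like term). This mirrors the commonly published RWKV reference recurrence; exact kernel details vary, so treat it as illustrative.

```python
import torch

def rwkv_wkv_step_stable(k, v, state, w, u):
    """Stable single-step WKV update.

    state = (a, b, p): scaled numerator, scaled denominator, and the exponent
    they are implicitly scaled by (initialize p to a large negative value).
    w: per-channel nonnegative decay; u: per-channel current-token bonus.
    """
    a, b, p = state
    # Output: combine the scaled past with the current token's bonus weight,
    # shifting both exponents by their maximum before exponentiating.
    q = torch.maximum(p, u + k)
    e_past, e_cur = torch.exp(p - q), torch.exp(u + k - q)
    out = (e_past * a + e_cur * v) / (e_past * b + e_cur)
    # State update: decay the past by exp(-w), i.e. subtract w from its
    # exponent, then fold in the current token without the bonus.
    q = torch.maximum(p - w, k)
    e_past, e_cur = torch.exp(p - w - q), torch.exp(k - q)
    return out, (e_past * a + e_cur * v, e_past * b + e_cur, q)
```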
Architecture details and hyperparameters. The paper scales models from 169M to 14B parameters, trained for one epoch on The Pile, about 330B tokens. The training context length is 1024 tokens for the main pretrained models. The paper uses Adam without weight decay, bfloat16 precision, dynamic batch sizes of 128 or 256 sequences, and an exponential learning-rate decay schedule. It reports a parameter formula of the form

$$\#\text{params} = 2VD + 13D^2L + D(11L + 4),$$

where $V$ is the vocabulary size, $D$ is the model dimension, and $L$ is the number of layers. It also emphasizes custom initialization: many weights are initialized near zero or to identity-like mappings to stabilize deep recurrent training.
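As a sanity check of the formula, a back-of-the-envelope computation, assuming (not stated in this summary) that the roughly 430M checkpoint uses $L = 24$, $D = 1024$, and the GPT-NeoX/Pile vocabulary of $V = 50277$:

```python
# Hypothetical configuration for the ~430M model; treat these values as
# assumptions used only to exercise the formula.
V, D, L = 50277, 1024, 24
params = 2 * V * D + 13 * D * D * L + D * (11 * L + 4)
print(params)  # 430_397_440, i.e. about 430M
```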
The recurrent state is small. The paper describes each layer's state as a handful of vectors of dimension $D$; in practical terms, this is independent of generated sequence length. That independence is the main contrast with a Transformer KV cache.
Benchmarks. RWKV reports pretrained checkpoints from 169M to 14B and compares with similarly sized Transformers such as Pythia, OPT, and BLOOM on a FLOP-matched basis. The paper reports competitive average zero-shot performance across tasks including ARC, BoolQ, COPA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, and WinoGrande. It also reports that RWKV follows a Transformer-like scaling-law relationship between compute and loss, with a strong log-log fit in their experiments.
For long context, the paper fine-tunes by progressively increasing context length from 1024 to 2048, 4096, and 8192 tokens and observes decreasing Pile test loss as context grows. On Long Range Arena, the paper reports RWKV as second only to S4 across five datasets, with stronger behavior on text and code-like tasks than on image/pathfinder tasks. The paper is careful about limitations: fixed-state recurrence can struggle with tasks requiring exact recall of many small details over long contexts.
Visual
| Architecture | Training over tokens | Inference state | Long-context bottleneck | Main tradeoff |
|---|---|---|---|---|
| Transformer decoder | Parallel | KV cache grows with length | Attention over cache | Strong exact access, high memory |
| Basic RNN | Sequential unless scanned | Fixed hidden state | State compression | Efficient but historically weaker |
| RWKV | Parallelizable recurrence | Fixed WKV state | Compressed channel memory | Fast inference, harder exact recall |
| Mamba | Parallel selective scan | Fixed SSM state | Selective state compression | Stronger content-dependent recurrence |
Worked example 1: one-channel WKV update
Problem: compute the simplified WKV value at $t = 3$ for one channel. Let the keys be $k_1 = 0$, $k_2 = 2$, $k_3 = 0$ and the values be $v_1 = 1$, $v_2 = 2$, $v_3 = 1$,
with decay $w = \ln 2$ (one step of decay multiplies a weight by $0.5$) and current-token bonus $u = 0$.
- The past terms for $t = 3$ use positions $i = 1$ and $i = 2$:
For $i = 1$: weight $e^{-(3-1-1)w + k_1} = e^{-\ln 2} = 0.5$, contributing $0.5 \cdot v_1 = 0.5$.
For $i = 2$: weight $e^{-(3-1-2)w + k_2} = e^{2} \approx 7.389$, contributing $7.389 \cdot v_2 \approx 14.778$.
- The current term is $e^{u + k_3} = e^{0} = 1$, contributing $1 \cdot v_3 = 1$.
- The numerator is $0.5 + 14.778 + 1 \approx 16.278$.
- The denominator is $0.5 + 7.389 + 1 \approx 8.889$.
- Therefore $wkv_3 \approx 16.278 / 8.889 \approx 1.83$.
Check: the second token dominates because its key is high and it has not decayed; the first token still contributes but is halved by time decay.
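A few lines of Python reproduce the arithmetic of this example (the values are the illustrative ones chosen above, not numbers from the paper):

```python
import math

# Worked example 1: simplified one-channel WKV at t = 3.
w, u = math.log(2.0), 0.0          # decay and current-token bonus
k = [0.0, 2.0, 0.0]                # keys k_1..k_3
v = [1.0, 2.0, 1.0]                # values v_1..v_3
t = 3

num = sum(math.exp(-(t - 1 - i) * w + k[i - 1]) * v[i - 1] for i in range(1, t))
den = sum(math.exp(-(t - 1 - i) * w + k[i - 1]) for i in range(1, t))
num += math.exp(u + k[t - 1]) * v[t - 1]   # current-token term with bonus u
den += math.exp(u + k[t - 1])

print(num, den, num / den)  # ~16.28, ~8.89, ~1.83
```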
Worked example 2: fixed recurrent state versus KV cache
Problem: compare state growth for a 24-layer decoder with model dimension $D = 1024$ and generated length $T = 4096$. Use a simplified Transformer with one key and one value vector of dimension $D$ per token per layer, and a simplified RWKV state of $5D$ scalars per layer.
- Transformer cache scalars: $\text{layers} \times T \times 2D = 24 \times 4096 \times 2048$.
- First compute $T \times 2D = 4096 \times 2048 = 8{,}388{,}608$ scalars per layer.
- Then multiply by the $24$ layers: $24 \times 8{,}388{,}608 = 201{,}326{,}592$ scalars.
- RWKV state scalars: $24 \times 5D = 24 \times 5120 = 122{,}880$.
- Ratio: $201{,}326{,}592 / 122{,}880 \approx 1638$, so the cache is roughly $1600\times$ larger at this length.
Check: this simplified calculation ignores heads, precision, and implementation details, but it captures the key asymptotic point: the Transformer cache grows with , while RWKV's recurrent state does not.
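The same comparison in a few lines of Python (the configuration is the simplified one used above, not a measured memory footprint):

```python
# Worked example 2: KV-cache scalars versus RWKV recurrent-state scalars.
layers, D, T = 24, 1024, 4096

kv_cache_scalars = layers * T * 2 * D     # one key + one value vector per token per layer
rwkv_state_scalars = layers * 5 * D       # fixed-size recurrent state per layer

print(kv_cache_scalars)                        # 201_326_592
print(rwkv_state_scalars)                      # 122_880
print(kv_cache_scalars / rwkv_state_scalars)   # ~1638
```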
Code
```python
import torch


def rwkv_wkv_step(k, v, state, time_decay, time_first):
    """Numerically simple RWKV-style single step.

    k, v: [batch, channels] key and value projections for the current token.
    state: (numerator, denominator) of the exponentially decayed past.
    time_decay: per-channel decay parameter (nonnegative).
    time_first: per-channel bonus applied only to the current token.
    """
    num, den = state
    # Output mixes the decayed past with the current token, which gets the bonus.
    current_weight = torch.exp(time_first + k)
    out = (num + current_weight * v) / (den + current_weight).clamp_min(1e-6)
    # State update: decay the past and add the current token without the bonus.
    decay = torch.exp(-time_decay).clamp(max=1.0)
    next_num = decay * num + torch.exp(k) * v
    next_den = decay * den + torch.exp(k)
    return out, (next_num, next_den)


batch, channels = 2, 4
state = (torch.zeros(batch, channels), torch.zeros(batch, channels))
time_decay = torch.full((channels,), 0.7)
time_first = torch.zeros(channels)
for _ in range(3):
    k = torch.randn(batch, channels)
    v = torch.randn(batch, channels)
    y, state = rwkv_wkv_step(k, v, state, time_decay, time_first)
    print(y.shape)
```
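Note that this sketch implements the simplified, unrescaled WKV recurrence from the Definitions section; production RWKV kernels use the running-maximum rescaling sketched under Key results so the exponentials stay in a safe range.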
Common pitfalls
- Treating RWKV as a normal RNN. Its time-mixing recurrence is designed to be parallelizable for training and constant-state for inference.
- Ignoring numerical stability. The exponential WKV formula should be implemented with stable rescaling in real systems.
- Assuming fixed state means unlimited exact memory. A fixed state can carry useful summaries, but exact retrieval of many details can be harder than with full attention.
- Comparing RWKV with Transformers trained on different data or token budgets without caveats. The paper makes FLOP-matched comparisons, but external baselines still differ.
- Forgetting token shift. The interpolation between $x_t$ and $x_{t-1}$ is part of how RWKV injects local temporal information.
- Using Transformer-style prompts blindly. The paper reports sensitivity to prompt ordering because an RNN-like model cannot revisit earlier instruction text with explicit attention.
Connections
- Responds to the KV-cache and quadratic-attention costs introduced by Attention Is All You Need.
- Extends recurrent ideas from Sequence Modeling and RNNs with modern scaling, initialization, and parallel scan-style training.
- Differs from Hyena, which uses implicit long convolutions rather than WKV recurrence.
- Sets up Mamba, whose selective state-space recurrence can be seen as a stronger content-dependent recurrent token mixer.
- Related D2L pages: Gated RNNs and Sequence-to-Sequence, Attention and Transformers, and Computational Performance.
- Further reading: Attention Free Transformer, linear attention, QRNN, S4, RetNet, Mamba, and the Pythia suite for controlled language-model scaling comparisons.