
LSTM Variants

The vanilla LSTM cell with input / forget / output gates is the most commonly taught form, but a long series of papers explored variants that adjust gating, memory pathways, or topology. This page surveys the variants most likely to appear in production code, in papers, or in interview questions, and explains when each is actually worth using.

For the baseline LSTM definition (cell state $C_t$, gates $i_t$, $f_t$, $o_t$, additive memory path), see Gated RNNs and Sequence-to-Sequence.

Quick reference

| Variant | Reference | Modification | When to use |
|---|---|---|---|
| Peephole | Gers & Schmidhuber 2000 | Gates see $C_{t-1}$ | Precise timing tasks; mostly historical |
| Coupled input-forget (CIFG) | Greff et al. 2017 | $i_t = 1 - f_t$ | Fewer parameters with little quality loss |
| Bidirectional (BiLSTM) | Schuster & Paliwal 1997 | Two LSTMs read fwd + bwd | Offline tagging / NER / encoder side |
| Stacked / Deep LSTM | Graves 2013 | $L$ LSTM layers in sequence | Standard for capacity scaling |
| Projection LSTM (LSTMP) | Sak et al. 2014 | Hidden state linearly projected before output | Speech recognition; reduces large hidden dim |
| ConvLSTM | Shi et al. 2015 | Replace matmuls with convolutions | Spatiotemporal data (video, weather) |
| Tree-LSTM | Tai et al. 2015 | Children aggregate into parent | Syntactic structure / dependency trees |
| Highway / Residual LSTM | Zhang et al. 2016 | Skip connection over LSTM block | Very deep stacks |
| mLSTM / sLSTM (xLSTM) | Beck et al. 2024 | Matrix memory / exponential gates | Modern revival; competes with linear attention |

Peephole connections

Vanilla LSTM gates use only the previous hidden state $h_{t-1}$ and the current input $x_t$. Peephole connections let each gate also inspect the previous cell state $C_{t-1}$ (and sometimes $C_t$ for the output gate):

$$
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + V_f \odot C_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + V_i \odot C_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + V_o \odot C_t + b_o)
\end{aligned}
$$

The diagonal weight vector $V_g$ (one per gate $g$) acts elementwise on the cell state. Peephole connections improved LSTM accuracy on timing tasks where the model needs to count intervals — e.g., the original test was learning to produce spikes at precise times after a trigger.
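A minimal single-step sketch of the peephole equations, assuming a flat dict of illustrative parameter names (`W`, `U`, `V`, `b` are not a library API):

```python
import torch

def peephole_lstm_step(x, h_prev, c_prev, params):
    """One peephole LSTM step. `params` holds input weights W, recurrent
    weights U, biases b, and diagonal peephole vectors V (names illustrative)."""
    W, U, V, b = params["W"], params["U"], params["V"], params["b"]
    f = torch.sigmoid(x @ W["f"] + h_prev @ U["f"] + V["f"] * c_prev + b["f"])
    i = torch.sigmoid(x @ W["i"] + h_prev @ U["i"] + V["i"] * c_prev + b["i"])
    c_tilde = torch.tanh(x @ W["c"] + h_prev @ U["c"] + b["c"])
    c = f * c_prev + i * c_tilde
    # The output gate peeks at the *current* cell state C_t.
    o = torch.sigmoid(x @ W["o"] + h_prev @ U["o"] + V["o"] * c + b["o"])
    h = o * torch.tanh(c)
    return h, c
```

Note that $V_g$ is a vector multiplied elementwise, not a full matrix, which keeps the added cost negligible.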

Practical relevance today: small. The empirical study by Greff et al. (2017) tested eight variants on speech, handwriting, and music tasks and found peephole connections offered no statistically significant improvement on those problems. Most modern frameworks omit them by default.

Coupled input-forget gate (CIFG)

Observation: the forget gate $f_t$ controls how much old memory is kept, while the input gate $i_t$ controls how much new candidate $\tilde C_t$ enters. In many trained LSTMs the two are anti-correlated. CIFG removes this redundancy by setting

$$i_t = 1 - f_t.$$

This is the same idea as a GRU's update gate. A CIFG LSTM has 25% fewer parameters in the gating block. The Greff study found CIFG performed comparably to vanilla LSTM. If memory is tight, CIFG is a clean win.
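A single-step sketch of the coupled gate, with the same illustrative parameter layout as above (not a library API); note there are no input-gate parameters at all:

```python
import torch

def cifg_lstm_step(x, h_prev, c_prev, params):
    """One CIFG step: the input gate is tied to the forget gate,
    i_t = 1 - f_t, so the input-gate weight block disappears."""
    W, U, b = params["W"], params["U"], params["b"]
    f = torch.sigmoid(x @ W["f"] + h_prev @ U["f"] + b["f"])
    c_tilde = torch.tanh(x @ W["c"] + h_prev @ U["c"] + b["c"])
    c = f * c_prev + (1 - f) * c_tilde   # coupled: i_t = 1 - f_t
    o = torch.sigmoid(x @ W["o"] + h_prev @ U["o"] + b["o"])
    h = o * torch.tanh(c)
    return h, c
```

Three weight blocks (f, candidate, o) instead of four is where the parameter saving comes from.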

Bidirectional LSTM (BiLSTM)

For tasks where the entire input sequence is available before any output is needed — POS tagging, named-entity recognition, encoder side of sequence-to-sequence — a unidirectional LSTM wastes information about future context. BiLSTM runs two LSTMs and concatenates their hidden states:

$$\overrightarrow{h_t} = \mathrm{LSTM}_\rightarrow(x_t, \overrightarrow{h_{t-1}}) \qquad \overleftarrow{h_t} = \mathrm{LSTM}_\leftarrow(x_t, \overleftarrow{h_{t+1}}) \qquad h_t = [\overrightarrow{h_t}; \overleftarrow{h_t}]$$

A BiLSTM cannot be used in causal language modeling (the backward pass would leak future tokens) — it's an encoder-only tool. PyTorch: nn.LSTM(..., bidirectional=True).
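A quick shape check of the concatenation, using PyTorch's built-in bidirectional flag:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.LSTM(input_size=16, hidden_size=32, bidirectional=True, batch_first=True)
x = torch.randn(4, 10, 16)          # (batch, time, features)
out, (h_n, c_n) = rnn(x)

# Forward and backward hidden states are concatenated on the feature axis.
print(out.shape)   # torch.Size([4, 10, 64])
# h_n stacks directions: (num_layers * num_directions, batch, hidden)
print(h_n.shape)   # torch.Size([2, 4, 32])
```

Downstream layers (e.g., a tagging head) must therefore expect `2 * hidden_size` features per time step.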

Stacked / Deep LSTM

The standard recipe for adding capacity to an RNN is to stack layers: the hidden state output of layer $\ell$ becomes the input of layer $\ell+1$.

Empirically, 2-4 layers help; beyond that, training becomes difficult without skip connections (see Residual LSTM below). Most major encoder-decoder MT systems before transformers used deep (4-8 layer) LSTM encoders, e.g., GNMT's 8-layer encoder.
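In PyTorch the stack is a single `num_layers` argument; `out` is the top layer's hidden sequence, while `h_n` holds each layer's final state:

```python
import torch
import torch.nn as nn

# A 3-layer stack in one call: layer l's hidden sequence feeds layer l+1.
stacked = nn.LSTM(input_size=16, hidden_size=32, num_layers=3, batch_first=True)
x = torch.randn(4, 10, 16)
out, (h_n, c_n) = stacked(x)

print(out.shape)   # torch.Size([4, 10, 32])  -- top layer's hidden sequence
print(h_n.shape)   # torch.Size([3, 4, 32])   -- final state of each layer
```

The `dropout` argument of `nn.LSTM` applies between stacked layers (not across time), which is the standard regularizer for this topology.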

Projection LSTM (LSTMP)

When the hidden size $n$ is large, the LSTM weight matrix is $O(n^2)$ per gate. The Google ASR team (Sak et al. 2014) introduced a linear projection layer inside the recurrence:

$$h_t' = W_r h_t, \qquad h_t' \in \mathbb{R}^{r}, \quad r \ll n$$

The projected $h_t'$ is what the next time step (and the next layer) sees, while the larger $h_t$ stays internal. This decouples representational capacity ($n$) from inter-time-step bandwidth ($r$), letting models be wider without quadratic blowup. It was standard in production speech models for years before transformers took over.
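PyTorch's `nn.LSTM` implements LSTMP directly via the `proj_size` argument: the hidden/output dimension shrinks to $r$ while the cell state keeps the full size $n$:

```python
import torch
import torch.nn as nn

n, r = 512, 128   # large internal cell, small projected output
lstmp = nn.LSTM(input_size=64, hidden_size=n, proj_size=r, batch_first=True)
x = torch.randn(2, 20, 64)
out, (h_n, c_n) = lstmp(x)

print(out.shape)   # torch.Size([2, 20, 128]) -- projected to r, not n
print(c_n.shape)   # torch.Size([1, 2, 512])  -- cell state stays at n
```

The recurrent matrices now cost $O(nr)$ instead of $O(n^2)$, which is the whole point of the variant.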

ConvLSTM

For spatiotemporal data — video, radar, weather grids — a vanilla LSTM flattens spatial structure into a vector. ConvLSTM (Shi et al. 2015) keeps the data 4D ($\text{batch} \times \text{channel} \times H \times W$) and replaces every matrix multiplication with a 2D convolution:

$$
\begin{aligned}
i_t &= \sigma(W_{xi} * X_t + W_{hi} * H_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} * X_t + W_{hf} * H_{t-1} + b_f) \\
o_t &= \sigma(W_{xo} * X_t + W_{ho} * H_{t-1} + b_o) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c) \\
H_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$

The $*$ operator is 2D convolution; gates and cell state are now 4D tensors. ConvLSTM is used in precipitation nowcasting and other grid-time-series problems. It has been largely superseded by 3D CNNs, ConvNets-with-attention, and Video Transformers, but remains a strong baseline.
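A minimal cell implementing these equations (a simplification: one `Conv2d` computes all four gate pre-activations, and the paper's optional cell-state peephole terms are omitted):

```python
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    """Minimal ConvLSTM cell: gates are computed by convolving the
    channel-wise concatenation of X_t and H_{t-1}."""
    def __init__(self, in_ch, hid_ch, kernel=3):
        super().__init__()
        self.conv = nn.Conv2d(in_ch + hid_ch, 4 * hid_ch, kernel,
                              padding=kernel // 2)

    def forward(self, x, state):
        h, c = state
        gates = self.conv(torch.cat([x, h], dim=1))
        i, f, o, g = gates.chunk(4, dim=1)        # split along channels
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        c = f * c + i * torch.tanh(g)             # additive memory path
        h = o * torch.tanh(c)
        return h, c
```

Because padding preserves H and W, the hidden state keeps the input's spatial resolution across time steps.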

Tree-LSTM

For syntactic trees (dependency or constituency parses), the natural recurrence order is bottom-up over the tree rather than left-to-right over a sequence. Tree-LSTM (Tai et al. 2015) defines two variants:

  • Child-Sum Tree-LSTM: each parent sums hidden states from its children before applying gates. Useful when the number of children is variable.
  • N-ary Tree-LSTM: each parent has at most $N$ children with separate gate parameters per child position. Used when children are ordered (e.g., binary constituency parses).

Tree-LSTMs were strong on sentiment classification over Stanford Sentiment Treebank, but the architecture is fiddly to batch and largely fell out of favor when Transformers handled sentence-level structure with attention.
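A sketch of the Child-Sum node update (unbatched for clarity; the per-node batching difficulty mentioned above is exactly what this hides). Children's hidden states are summed for the i/o/candidate gates, but each child gets its own forget gate:

```python
import torch
import torch.nn as nn

class ChildSumTreeLSTMCell(nn.Module):
    """Child-Sum Tree-LSTM node update (after Tai et al. 2015)."""
    def __init__(self, in_dim, mem_dim):
        super().__init__()
        self.iou = nn.Linear(in_dim + mem_dim, 3 * mem_dim)
        self.fx = nn.Linear(in_dim, mem_dim)   # forget gate, input part
        self.fh = nn.Linear(mem_dim, mem_dim)  # forget gate, per-child part

    def forward(self, x, child_h, child_c):
        # child_h, child_c: (num_children, mem_dim)
        h_sum = child_h.sum(dim=0)
        i, o, u = self.iou(torch.cat([x, h_sum])).chunk(3)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        f = torch.sigmoid(self.fx(x) + self.fh(child_h))  # one gate per child
        c = i * u + (f * child_c).sum(dim=0)
        h = o * torch.tanh(c)
        return h, c
```

A leaf is just the same update with zero children (empty `child_h`/`child_c`), and a full tree is processed by calling the cell bottom-up.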

Highway / Residual LSTM

Very deep RNN stacks (10+ layers) suffer from vanishing gradients across both time and depth. The fix mirrors what ResNet did for CNNs: skip connections.

A simple Residual LSTM layer computes

$$h_t^{(\ell)} = \mathrm{LSTM}^{(\ell)}(h_t^{(\ell-1)}) + h_t^{(\ell-1)}$$

(possibly with a linear projection if dimensions mismatch). Highway LSTM generalizes this with a learned gate $T$ controlling how much of the shortcut to use:

$$h_t^{(\ell)} = T \odot \mathrm{LSTM}^{(\ell)}(h_t^{(\ell-1)}) + (1 - T) \odot h_t^{(\ell-1)}.$$

Used in deep ASR stacks of ~10-20 layers in 2016-2018.
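A residual wrapper is a few lines; this sketch assumes the projection (when dimensions mismatch) is applied to the shortcut:

```python
import torch
import torch.nn as nn

class ResidualLSTMLayer(nn.Module):
    """One LSTM layer with an identity skip connection over it."""
    def __init__(self, in_dim, hidden):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden, batch_first=True)
        # Project the shortcut only when dimensions differ.
        self.proj = nn.Identity() if in_dim == hidden else nn.Linear(in_dim, hidden)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out + self.proj(x)
```

Stacking these layers gives the deep-ASR recipe: gradients flow through the additive shortcut even when the LSTM blocks saturate.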

xLSTM (mLSTM and sLSTM)

In 2024 Beck et al. revived LSTM with two new cells designed to compete with Transformers:

  • sLSTM keeps a scalar memory cell but uses exponential gating $i_t = \exp(\cdot)$ and $f_t = \exp(\cdot)$ with a normalizer state, allowing the cell to forget arbitrarily fast and remember arbitrarily long.
  • mLSTM replaces the scalar cell with a matrix memory $C_t \in \mathbb{R}^{d \times d}$, updated by an outer product of key and value vectors — directly analogous to a linear-attention KV cache.

Empirically xLSTM matches or beats Mamba on some language-modeling benchmarks at similar scale. See Mamba and the broader linear-attention family for context.
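The mLSTM update can be sketched schematically as follows. This is a heavily simplified, unnormalized-head version of the Beck et al. 2024 recurrence (the exponential-gate stabilization machinery is omitted), intended only to show the outer-product memory and its linear-attention flavor:

```python
import torch

def mlstm_step(q, k, v, C_prev, n_prev, f_gate, i_gate):
    """Schematic mLSTM-style step: the matrix memory accumulates
    outer products of values and keys, then is queried like linear
    attention. f_gate/i_gate are scalar gate values for this step."""
    C = f_gate * C_prev + i_gate * torch.outer(v, k)    # (d, d) matrix memory
    n = f_gate * n_prev + i_gate * k                    # normalizer state
    h = (C @ q) / torch.clamp((n @ q).abs(), min=1.0)   # stabilized read-out
    return h, C, n
```

Reading the memory with a query `q` recovers a weighted mix of stored values, which is exactly the recurrent form of a linear-attention KV cache.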

Choosing a variant in practice

Before choosing among these variants, consider (and this is the honest 2025 recommendation) whether you should be using an LSTM at all. For most sequence problems the practical choice is a Transformer (see Attention Is All You Need) or a modern linear-time alternative (Mamba, RWKV, Griffin). LSTMs remain useful for:

  • Online / streaming inference where the model truly cannot look ahead and a tiny constant-memory state matters
  • Edge devices where Transformer key-value cache memory is unaffordable
  • Reinforcement-learning policies that need a small hidden state passed across episodes
  • Time-series forecasting at small scale where data is insufficient to train attention

Common pitfalls

  • Treating GRU and LSTM as drop-in replacements. GRU exposes one state; LSTM exposes two ($h_t$, $C_t$). When passing initial states or saving checkpoints, the API differs.
  • Forgetting that BiLSTM is non-causal. It cannot be used in next-token language modeling. A common bug: training a BiLSTM language model and seeing miraculous perplexity that does not survive inference.
  • Initializing the forget gate bias too low. The standard trick is to initialize $b_f$ to $1.0$ so the cell starts in "remember" mode. PyTorch's nn.LSTM does NOT do this by default — you must set it manually after construction.
  • Stacking too deep without skip connections. 4 layers is usually the sweet spot for vanilla stacked LSTM. Beyond 6 you need residual or highway connections.
  • Confusing ConvLSTM with running a regular LSTM after a CNN. ConvLSTM keeps spatial dims inside the recurrence; a CNN-then-LSTM pipeline collapses spatial information before recurrence.

Code: forget-gate bias init and BiLSTM in PyTorch

```python
import torch
import torch.nn as nn

class BiLSTMTagger(nn.Module):
    def __init__(self, vocab, emb_dim, hidden, n_tags, layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab, emb_dim)
        self.lstm = nn.LSTM(
            emb_dim, hidden,
            num_layers=layers,
            bidirectional=True,
            batch_first=True,
            dropout=0.3,
        )
        self.head = nn.Linear(2 * hidden, n_tags)
        self._init_forget_bias()

    def _init_forget_bias(self):
        # PyTorch packs gates as [i, f, g, o]; the second quarter is the forget gate.
        for name, p in self.lstm.named_parameters():
            if "bias" in name:
                n = p.size(0)
                start, end = n // 4, n // 2
                p.data[start:end].fill_(1.0)

    def forward(self, tokens):
        h = self.embed(tokens)
        out, _ = self.lstm(h)
        return self.head(out)
```

This module is a standard sequence tagger: BiLSTM with forget-bias init at 1.0, dropout between layers, linear head over the bidirectional hidden states.

Connections