Architectures

GPT / Decoder-only

Autoregressive transformers with causal masking — the architecture behind every modern LLM. Each token attends only to past tokens; generation proceeds one token at a time via next-token prediction.


SECTION 01

The decoder-only design

A decoder-only transformer is a stack of N identical blocks, each containing: Causal Self-Attention → Add & Norm → FFN → Add & Norm. There's no encoder, no cross-attention — just causal self-attention and feed-forward layers stacked repeatedly.

The "decoder-only" name is historical (from the original encoder-decoder transformer for translation). Modern LLMs (GPT-2 through GPT-4, Llama, Mistral, Claude, Gemini) are all decoder-only. The architecture works for generation because: (1) causal masking allows training on all token positions simultaneously via teacher forcing; (2) at inference, the model generates one token at a time, appending each new token to the growing context.

Compared to encoder-only models, each token sees only its past, giving it a narrower view. But decoder-only models compensate with scale: billions of parameters trained on trillions of tokens encode world knowledge directly into the weights, rather than relying on bidirectional context within a single forward pass.
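The autoregressive loop in point (2) can be sketched in a few lines. This is a minimal illustration, assuming `model` is any module mapping token ids to next-token logits (greedy decoding is used here for simplicity; sampling strategies are covered in SECTION 04):

```python
import torch

@torch.no_grad()
def generate(model, input_ids, max_new_tokens):
    # input_ids: (batch, seq_len) prompt tokens
    for _ in range(max_new_tokens):
        logits = model(input_ids)                 # (batch, seq_len, vocab_size)
        # Only the last position's logits predict the next token
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)  # grow context
    return input_ids
```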

SECTION 02

Causal masking in depth

import torch
import torch.nn.functional as F

def causal_self_attention(Q, K, V):
    # Q, K, V: (batch, heads, seq_len, d_k)
    B, H, N, dk = Q.shape

    # Compute raw scores
    scores = Q @ K.transpose(-2, -1) / dk ** 0.5   # (B, H, N, N)

    # Build causal mask: -inf above the diagonal
    causal_mask = torch.triu(
        torch.full((N, N), float('-inf'), device=Q.device), diagonal=1
    )  # (N, N)

    scores = scores + causal_mask   # broadcast over batch and head dims

    weights = F.softmax(scores, dim=-1)
    return weights @ V

# Visualise what causal masking does:
# For a 4-token sequence, the attention weight matrix looks like:
# token 0:  [1.0, 0.0, 0.0, 0.0]  — can only see itself
# token 1:  [0.6, 0.4, 0.0, 0.0]  — can see tokens 0 and 1
# token 2:  [0.3, 0.4, 0.3, 0.0]  — can see tokens 0, 1, 2
# token 3:  [0.2, 0.3, 0.3, 0.2]  — can see all tokens

# PyTorch's efficient implementation (no explicit mask needed):
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
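A quick sanity check (a sketch): building the mask explicitly, as above, produces exactly lower-triangular attention weights, and the result agrees with PyTorch's fused causal path:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(1, 2, 4, 8)
K = torch.randn(1, 2, 4, 8)
V = torch.randn(1, 2, 4, 8)

# Manual masked scores, mirroring causal_self_attention above
scores = Q @ K.transpose(-2, -1) / 8 ** 0.5
mask = torch.triu(torch.full((4, 4), float('-inf')), diagonal=1)
weights = F.softmax(scores + mask, dim=-1)

# Every position above the diagonal receives exactly zero weight
assert (weights.triu(diagonal=1) == 0).all()

# The masked result matches PyTorch's fused causal path
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
assert torch.allclose(weights @ V, out, atol=1e-5)
```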
SECTION 03

Training: next-token prediction

import torch
import torch.nn as nn

# Training objective: predict token[i+1] given tokens[0..i]
# This is maximum likelihood estimation over the training corpus

def compute_lm_loss(model, input_ids: torch.Tensor) -> torch.Tensor:
    # input_ids: (batch, seq_len)
    # Labels are input_ids shifted by 1: predict next token at each position
    logits = model(input_ids)          # (batch, seq_len, vocab_size)

    # Shift: inputs are [0..n-2], targets are [1..n-1]
    shift_logits = logits[:, :-1, :]   # (batch, seq_len-1, vocab_size)
    shift_labels = input_ids[:, 1:]    # (batch, seq_len-1)

    loss = nn.CrossEntropyLoss()(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1)
    )
    return loss

# The cross-entropy loss over tokens is equivalent to:
# -log P(token_1 | token_0) - log P(token_2 | token_0, token_1) - ...
# Minimising this = maximising the probability of the training sequences

# Perplexity = exp(average loss) — the standard language-modelling metric
# A perplexity of 10 means the model is as uncertain as choosing among 10 equally likely options
def perplexity(loss: float) -> float:
    return torch.exp(torch.tensor(loss)).item()
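A quick numerical check of this relationship (a sketch): a model that spreads probability uniformly over the vocabulary has loss ln(vocab_size), so its perplexity equals the vocabulary size:

```python
import torch
import torch.nn as nn

vocab_size = 100
# Uniform logits: the model assigns equal probability to every token
logits = torch.zeros(32, vocab_size)
labels = torch.randint(0, vocab_size, (32,))
loss = nn.CrossEntropyLoss()(logits, labels)

# loss = ln(100) ~ 4.605, so perplexity ~ 100 = vocab_size
ppl = torch.exp(loss).item()
```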
SECTION 04

Sampling strategies

import torch
import torch.nn.functional as F

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                       top_k: int = 0, top_p: float = 1.0) -> int:
    # logits: (vocab_size,) unnormalised scores for next token
    logits = logits / temperature  # higher temp = more random

    # Top-k: keep only top k tokens
    if top_k > 0:
        kth_val = logits.topk(top_k).values[-1]
        logits[logits < kth_val] = float('-inf')

    # Top-p (nucleus): keep smallest set of tokens summing to >= p probability
    if top_p < 1.0:
        sorted_logits, sorted_idx = logits.sort(descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cumsum = probs.cumsum(dim=-1)
        # Remove tokens once cumulative probability exceeds top_p
        remove = cumsum - probs > top_p
        sorted_logits[remove] = float('-inf')
        logits[sorted_idx] = sorted_logits

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()

# Greedy decoding (temperature=0 equivalent):
def greedy_next(logits): return logits.argmax().item()

# Typical production settings:
# temperature=0.7, top_p=0.9 — creative but coherent
# temperature=0.1, top_k=1  — near-deterministic, for code/structured output
# temperature=1.0, top_p=1.0 — sample directly from model distribution
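To see concretely what temperature does to the distribution, here is a small illustrative example (toy logits, not from a real model):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])

sharp = F.softmax(logits / 0.5, dim=-1)   # low temperature: peaked
base  = F.softmax(logits / 1.0, dim=-1)   # raw model distribution
flat  = F.softmax(logits / 2.0, dim=-1)   # high temperature: flattened

# The top token's probability shrinks as temperature rises
assert sharp[0] > base[0] > flat[0]
```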
SECTION 05

KV cache

During autoregressive generation, each new token must attend over the entire prefix (all previously processed tokens). Without caching, every step reruns the forward pass over the whole sequence: step t costs O(t²) in attention, so generating n tokens costs O(n³) attention work in total.

The KV cache stores the K and V tensors for all past tokens at every layer. At each new step, the model computes Q, K, and V only for the new token, appends the new K and V to the cache, and attends over cached plus new K/V. Step t then costs O(t), bringing total generation down to O(n²).

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.kv_cache = None   # tuple (K, V), each (batch, heads, past_len, d_k)

    def forward(self, x, use_cache=True):
        # x: (batch, 1, d_model), a single new token during generation
        B, T, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        q = q.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        if use_cache and self.kv_cache is not None:
            past_k, past_v = self.kv_cache
            k = torch.cat([past_k, k], dim=2)   # append along the sequence dim
            v = torch.cat([past_v, v], dim=2)
        if use_cache:
            self.kv_cache = (k, v)

        # Causal masking is only needed when processing multiple tokens at
        # once (prefill); a single decode step attends to everything cached
        out = F.scaled_dot_product_attention(q, k, v, is_causal=(T > 1))
        out = out.transpose(1, 2).reshape(B, T, -1)
        return self.proj(out)

# KV cache memory: 2 (K and V) * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_element
# Full multi-head attention with d_model=4096 (32 heads x 128 dims), 32 layers:
#   2 * 32 * 4096 * 4096 * 2 bytes = 2 GiB at fp16 for 4K tokens
# Llama-3 8B uses grouped-query attention with 8 KV heads, cutting this 4x to ~0.5 GiB
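The memory formula can be wrapped in a small helper (a sketch; `num_kv_heads` and `head_dim` stand for the model's KV-head count and per-head dimension):

```python
def kv_cache_bytes(num_layers, seq_len, num_kv_heads, head_dim, bytes_per_elem=2):
    # Factor of 2 accounts for storing both K and V
    return 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_elem

# 32 layers, 4K context, fp16, full multi-head attention (32 KV heads x 128 dims):
full = kv_cache_bytes(32, 4096, 32, 128)   # 2 GiB
# Same shape with grouped-query attention (8 KV heads): 4x smaller
gqa = kv_cache_bytes(32, 4096, 8, 128)     # 0.5 GiB
```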
SECTION 06

Minimal GPT implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class GPTBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        N = x.shape[1]
        causal_mask = nn.Transformer.generate_square_subsequent_mask(N, device=x.device)
        # Pre-norm + residual connection (modern GPT uses pre-norm, not post-norm)
        h = self.ln1(x)
        h, _ = self.attn(h, h, h, attn_mask=causal_mask, is_causal=True)
        x = x + h
        x = x + self.ffn(self.ln2(x))
        return x

class TinyGPT(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=4, max_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([GPTBlock(d_model, num_heads, d_model*4)
                                      for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        B, T = idx.shape
        x = self.tok_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.ln_f(x))   # (B, T, vocab_size)

model = TinyGPT(vocab_size=50257)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")  # ~29M (token embedding and lm_head dominate)
SECTION 07

Gotchas

Pre-norm vs post-norm matters. The original transformer uses post-norm (LayerNorm after the residual add). Modern LLMs (GPT-2+, Llama) use pre-norm (LayerNorm before the attention/FFN). Pre-norm is more stable to train at large scale — gradients flow more cleanly through the residual stream. If you see training instability, check whether you're using post-norm.

The FFN is the majority of parameters. For a 7B Llama model, attention is ~25% of parameters, FFN is ~65%, embeddings ~10%. The FFN (two linear layers with a non-linearity, often SwiGLU) is where most of the "factual knowledge" is stored. Don't neglect FFN optimisation when profiling inference speed.
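A back-of-envelope count makes the split concrete (a sketch with Llama-7B-style dimensions; norms, biases, GQA, and embeddings are ignored):

```python
def block_params(d_model=4096, d_ff=11008):
    attn = 4 * d_model * d_model   # Wq, Wk, Wv, Wo projections
    ffn = 3 * d_model * d_ff       # SwiGLU FFN: gate, up, and down projections
    return attn, ffn

attn, ffn = block_params()
# The FFN holds roughly two thirds of each block's parameters
assert ffn / (attn + ffn) > 0.6
```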

Temperature 0 is not greedy decoding. Dividing logits by a temperature of 0 produces a division by zero (infinite or NaN values), not a sharper distribution. Most inference libraries implement "temperature=0" as argmax (true greedy decoding) as a special case. When writing your own sampler, handle temperature≤0 explicitly with argmax rather than the division formula.
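A minimal sketch of that special case (hypothetical helper, not from a specific library):

```python
import torch

def safe_next_token(logits: torch.Tensor, temperature: float) -> int:
    # Temperature at or below zero means greedy decoding, never division
    if temperature <= 0:
        return int(logits.argmax())
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, 1))
```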

SECTION 08

Architecture Comparison

| Property             | Decoder-Only (GPT style)       | Encoder-Decoder (T5 style)        | Encoder-Only (BERT style)    |
|----------------------|--------------------------------|-----------------------------------|------------------------------|
| Primary use          | Generation, completion         | Translation, summarisation        | Classification, embeddings   |
| Attention type       | Causal (left-to-right)         | Full (encoder) + causal (decoder) | Bidirectional full attention |
| Prompt conditioning  | In-context (prefix)            | Separate input sequence           | N/A (no generation)          |
| KV cache benefit     | High — caches all prior tokens | Partial — encoder KV fixed        | N/A                          |
| Parameter efficiency | One model for all tasks        | Better for seq2seq tasks          | Compact for classification   |

The dominance of decoder-only architectures in frontier LLMs reflects a practical advantage: they unify pretraining (next-token prediction) and fine-tuning (instruction following) into a single paradigm. Encoder-decoder models require a separate pretraining objective and a harder alignment between input and output sequences. For most production GenAI use cases today, decoder-only is the default choice; encoder-only models remain valuable for high-throughput classification and embedding workloads where generation is not needed.

Long system prompts that repeat across requests are prime candidates for prompt caching. The KV pairs for static prefixes can be computed once and reused, cutting both latency and cost for the prefill phase. This is possible precisely because decoder-only models process the prompt autoregressively and the KV cache naturally partitions by sequence position.

From a practical inference standpoint, the distinction between prefill and decode phases matters for latency optimisation. Prefill processes your entire prompt in one forward pass (parallelisable on GPU) and is fast even for long prompts. Decode generates tokens one at a time (sequential) and is the bottleneck for long outputs. This asymmetry means that increasing prompt length has a much smaller latency impact than increasing max output tokens, which is useful for system prompt design trade-offs.
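The asymmetry can be made concrete with a toy latency model (the throughput numbers here are hypothetical; real figures depend on hardware, model size, and batching):

```python
def latency_estimate(prompt_tokens, output_tokens,
                     prefill_tok_per_s=5000.0, decode_tok_per_s=50.0):
    # Prefill is one parallel pass; decode is one sequential step per token
    return prompt_tokens / prefill_tok_per_s + output_tokens / decode_tok_per_s

# Adding 4K prompt tokens costs ~0.8 s; adding 100 output tokens costs ~2 s
extra_prompt = latency_estimate(8000, 100) - latency_estimate(4000, 100)
extra_output = latency_estimate(4000, 200) - latency_estimate(4000, 100)
assert extra_output > extra_prompt
```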