LayerNorm

Layer normalization (Ba, 2016) standardizes activations to have a mean of 0 and a variance of 1, with the statistics computed over the hidden dimension (the hidden units, or neurons, of a layer).

Ba et al. (2016) define the operation in the context of RNNs as:

\[\mathbf{h}^{t} = f\left[ \frac{\mathbf{g}}{\sigma^{t}} \odot \left(\mathbf{a}^{t} - \mu^{t}\right) + \mathbf{b} \right] \qquad \mu^{t} = \frac{1}{H} \sum_{i=1}^{H} a_i^{t} \qquad \sigma^{t} = \sqrt{ \frac{1}{H} \sum_{i=1}^{H} \left(a_i^{t} - \mu^{t}\right)^2 }\]

where the superscript \(t\) means time step, and \(H\) denotes the number of hidden units in a layer. In a simplified form,

\[\text{LayerNorm}(\mathbf{x}) = \mathbf{g} \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2}} + \mathbf{b}\]

where \(\mu\) and \(\sigma^2\) are the mean and variance of \(\mathbf{x}\), \(\odot\) is element-wise multiplication between two vectors, and \(\mathbf{g}\) and \(\mathbf{b}\) are learnable parameters. \(\mathbf{g}\) and \(\mathbf{b}\) are vectors, not scalars (“defined as the bias and gain parameters of the same dimension as \(h^t\)”, see Section 3.1 of (Ba, 2016)).

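LayerNorm was used in the original Transformer paper (Vaswani, 2017), where it is applied after each sub-layer. GPT-2 (Radford, 2019) instead places it before each sub-layer: “Layer normalization (Ba et al., 2016) was moved to the input of each sub-block”.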

import numpy as np

def layer_norm(x: np.ndarray, g: np.ndarray, b: np.ndarray, eps: float = 1e-5) -> np.ndarray:
  """Input x, learnable vector params g, b."""
  mean = np.mean(x, axis=-1, keepdims=True)
  variance = np.var(x, axis=-1, keepdims=True)
  x = (x - mean) / np.sqrt(variance + eps)  # normalize x to have mean=0 and var=1 over last axis
  return g * x + b  # scale and offset with learnable gamma/beta params

Layer normalization ensures that the inputs for each layer are always within a consistent range, which is supposed to speed up and stabilize the training process. Like Batch Normalization, the normalized output is then scaled and offset with two learnable vectors gamma and beta. The small epsilon term in the denominator is used to avoid a division by zero error.

Layer norm is used instead of batch norm in the transformer for several reasons, which are listed in the BatchNorm section below. The differences between various normalization techniques are outlined in this excellent blog post.

We apply layer normalization over the last axis of the input.

>>> x = np.array([[2, 2, 3], [-5, 0, 1]])
>>> x = layer_norm(x, g=np.ones(x.shape[-1]), b=np.zeros(x.shape[-1]))
>>> x
array([[-0.70709, -0.70709,  1.41418],
       [-1.397  ,  0.508  ,  0.889  ]])
>>> x.var(axis=-1)
array([0.99996, 1.     ]) # slightly below 1 because eps is added to the variance
>>> x.mean(axis=-1)
array([-0., -0.])

Why do we need learnable \(b\) and \(g\)? If strictly normalized data is passed directly into a non-linear activation function (like a sigmoid, tanh, or ReLU), it might restrict what the network can learn. For instance, a sigmoid is approximately linear around 0. If you force all data to center around 0, you lose much of the non-linear representational power of that activation function. The learnable parameters \(g\) (scale/gain) and \(b\) (shift/bias) allow the network to stretch, squish, or move that normalized distribution into the “sweet spot” for the subsequent activation function.

Also, if the network determines during training that normalizing the data actually hurts its performance for a specific neuron, it can learn a gain close to the original standard deviation and a bias close to the original mean, which approximately undoes the normalization. The undo is only approximate because LayerNorm's statistics are computed per example, while \(g\) and \(b\) are shared across examples.

PyTorch and GPT-2 Weight Dissection

We can see that the weight and bias (\(\mathbf{g}, \mathbf{b}\)) in PyTorch’s nn.LayerNorm have a shape equal to the number of hidden dims/channels:

import torch
import torch.nn as nn

batch, seq_len, d_model = 20, 5, 10
x = torch.randn(batch, seq_len, d_model)

ln = nn.LayerNorm(d_model)

print(ln.weight.shape)  # torch.Size([10])
print(ln.bias.shape)    # torch.Size([10])
print(ln(x).shape)      # torch.Size([20, 5, 10])

If we look at the GPT-2 weights (printed below for block 0), \(\mathbf{g}\) averages roughly 0.2 in ln_1 and 0.9 in ln_2, while \(\mathbf{b}\) is considerably smaller on average, e.g. about -0.01 in ln_1.

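LayerNorm is usually computed in high precision (e.g. FP32) because calculating the variance involves squaring values and subtracting means, both of which lose accuracy in low precision. In low precision, the “epsilon” (\(10^{-5}\)) can also be smaller than the smallest representable difference between numbers: in FP16 the spacing between representable numbers near 1 is about \(10^{-3}\), so adding \(10^{-5}\) to a variance of that size has no effect.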
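The precision issue can be seen directly in NumPy. The following is a small sketch; the variance value of 1.0 is an illustrative assumption, not something taken from GPT-2.

```python
import numpy as np

eps = 1e-5

# Spacing between 1.0 and the next representable number in each format.
print(np.finfo(np.float16).eps)
print(np.finfo(np.float32).eps)

var16 = np.float16(1.0)  # assumed example variance, stored in FP16
var32 = np.float32(1.0)  # same value, stored in FP32

# In FP16 the sum rounds back to the original value, so eps is lost.
eps_lost_fp16 = bool((var16 + np.float16(eps)) == var16)
# In FP32 eps is still large enough to change the result.
eps_lost_fp32 = bool((var32 + np.float32(eps)) == var32)

print(eps_lost_fp16, eps_lost_fp32)  # True False
```

Note that eps still protects against division by zero in FP16 when the variance is exactly 0; what is lost is its effect on variances of ordinary size.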

In GPT-2, ln_1 happens before the self-attention sublayer, and ln_2 happens before the MLP/feed-forward sublayer. Specifically, transformer.h.0.ln_1.weight is the gamma / scale parameter of the first LayerNorm in transformer block 0, and transformer.h.0.ln_2.weight is the gamma / scale parameter of the second LayerNorm in transformer block 0:

import torch
from transformers import AutoModelForCausalLM

model_name = "openai-community/gpt2"  # formerly "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)

sd = model.state_dict()

for name in [
    "transformer.h.0.ln_1.weight",
    "transformer.h.0.ln_1.bias",
    "transformer.h.0.ln_2.weight",
    "transformer.h.0.ln_2.bias",
    "transformer.ln_f.weight",
    "transformer.ln_f.bias",
]:
    t = sd[name]
    print(name)
    print("shape:", tuple(t.shape))
    print("first 10:", t[:10])
    print("mean:", t.float().mean().item())
    print("std:", t.float().std().item())
    print("min:", t.float().min().item())
    print("max:", t.float().max().item())
    print()

This prints the following (output for the first three tensors shown):

transformer.h.0.ln_1.weight
shape: (768,)
first 10: tensor([0.2232, 0.1820, 0.1534, 0.1917, 0.2036, 0.1948, 0.1467, 0.1865, 0.2143,
        0.1956])
mean: 0.18035894632339478
std: 0.04131494462490082
min: 0.04186137020587921
max: 0.25266674160957336

transformer.h.0.ln_1.bias
shape: (768,)
first 10: tensor([-0.0037,  0.0272, -0.0640, -0.0050, -0.0157, -0.0115,  0.2019,  0.0358,
        -0.0020, -0.0082])
mean: -0.006593452300876379
std: 0.03580174222588539
min: -0.258882999420166
max: 0.20192869007587433

transformer.h.0.ln_2.weight
shape: (768,)
first 10: tensor([0.1310, 0.2093, 0.2066, 1.2542, 1.2638, 1.2695, 0.0935, 0.0793, 0.2260,
        1.3008])
mean: 0.8678297400474548
std: 0.48494789004325867
min: 0.045285746455192566
max: 1.5110347270965576

We now test the correctness of the forward pass:

import numpy as np
import torch
import torch.nn as nn

def torch_layernorm_reference(
    x_np: np.ndarray, gamma_np: np.ndarray, beta_np: np.ndarray, eps: float = 1e-5
) -> np.ndarray:
    """Reference output using PyTorch's nn.LayerNorm."""
    x = torch.tensor(x_np, dtype=torch.float64)
    gamma = torch.tensor(gamma_np, dtype=torch.float64)
    beta = torch.tensor(beta_np, dtype=torch.float64)

    D = x_np.shape[-1]
    layernorm = nn.LayerNorm(D, eps=eps, elementwise_affine=True).double()

    with torch.no_grad():
        layernorm.weight.copy_(gamma)
        layernorm.bias.copy_(beta)

    out = layernorm(x)
    return out.detach().numpy()


def test_layernorm_forward_once(N: int, D: int, eps: float = 1e-5, seed: int = 0):
    rng = np.random.default_rng(seed)

    x = rng.normal(size=(N, D))
    gamma = rng.normal(size=(D,))
    beta = rng.normal(size=(D,))

    out, cache = layernorm_forward(x, gamma, beta, {"eps": eps})
    expected = torch_layernorm_reference(x, gamma, beta, eps=eps)

    assert out is not None, "layernorm_forward returned out=None"
    assert out.shape == x.shape, f"Expected output shape {x.shape}, got {out.shape}"
    print(out, expected)
    np.testing.assert_allclose(
        out,
        expected,
        rtol=1e-4,
        atol=1e-4,
        err_msg="NumPy layernorm output does not match torch.nn.LayerNorm",
    )


test_cases = [(2, 3), (4, 5), (10, 20), (1, 8), (7, 1)]

for seed, (N, D) in enumerate(test_cases):
  print(f"Testing N={N}, D={D}")
  test_layernorm_forward_once(N, D, seed=seed)

print("All tests passed!")

LayerNorm Backprop

The backpropagation derivation for LayerNorm is a multi-step application of the chain rule. Because LayerNorm calculates a mean and variance for each individual training example (across the specified dimensions), the gradient must account for how each input \(x_i\) (per-dimension) affects those statistics.

1. The Forward Pass

Given an input vector \(x\) (of length \(d\)), LayerNorm computes:

  1. Mean: \(\mu = \frac{1}{d} \sum_{i=1}^d x_i\)
  2. Variance: \(\sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2\)
  3. Normalization: \(\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\)
  4. Scaling & Shifting: \(y_i = \gamma_i \hat{x}_i + \beta_i\)
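The four steps above can be sketched in NumPy as follows. The name layernorm_forward and its (x, gamma, beta, ln_param) signature match the calls used elsewhere in this text, but the body and the cache layout (x_hat, gamma, inv_std) are assumptions made for illustration.

```python
import numpy as np

def layernorm_forward(x, gamma, beta, ln_param):
    """LayerNorm forward pass over the last axis.

    x: (N, D) inputs; gamma, beta: (D,) gain and bias.
    ln_param: dict with optional key "eps".
    Returns (out, cache); the cache layout is an assumption of this sketch.
    """
    eps = ln_param.get("eps", 1e-5)
    mu = x.mean(axis=-1, keepdims=True)    # step 1: per-example mean
    var = x.var(axis=-1, keepdims=True)    # step 2: per-example (biased) variance
    inv_std = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mu) * inv_std             # step 3: normalize
    out = gamma * x_hat + beta             # step 4: scale and shift
    cache = (x_hat, gamma, inv_std)        # values a backward pass would need
    return out, cache

x = np.array([[2.0, 2.0, 3.0], [-5.0, 0.0, 1.0]])
out, cache = layernorm_forward(x, np.ones(3), np.zeros(3), {"eps": 1e-5})
print(out.round(3))
```

Storing x_hat and inv_std rather than the raw mean and variance is a design choice: these are exactly the quantities that appear in the gradient formulas.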

2. The Backward Pass

Let \(L\) be the loss function. We assume we are given the upstream gradient \(\frac{\partial L}{\partial y_i}\).

Step A: Gradients for Learnable Parameters

The gradients for the gain (\(\gamma\)) and bias (\(\beta\)) are straightforward:

\[\frac{\partial L}{\partial \gamma} = \frac{\partial L}{\partial y} \odot \hat{x}\] \[\frac{\partial L}{\partial \beta} = \frac{\partial L}{\partial y}\]

Step B: Gradient for the Normalized Input (\(\hat{x}\))

\[\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gamma_i\]

Step C: Gradient for the Input (\(x\))

This is the “heavy lifting” part. To find \(\frac{\partial L}{\partial x_i}\), you must sum the paths through \(\hat{x}_i\), \(\mu\), and \(\sigma^2\). Using the chain rule and simplifying the terms, the standard form of the gradient is:

\[\frac{\partial L}{\partial x_i} = \frac{1}{d \sqrt{\sigma^2 + \epsilon}} \left[ d \frac{\partial L}{\partial \hat{x}_i} - \sum_{j=1}^d \frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i \sum_{j=1}^d \left( \frac{\partial L}{\partial \hat{x}_j} \cdot \hat{x}_j \right) \right]\]

To get to this final, elegant form, we have to track how \(x_i\) influences the output through three distinct paths: the direct normalized value, the mean, and the variance.

Let \(v = \sigma^2 + \epsilon\) and \(std = \sqrt{v}\). The forward equation is \(\hat{x}_i = (x_i - \mu) \cdot v^{-1/2}\).

1. The Total Derivative Setup

By the chain rule, the gradient for one input \(x_i\) is the sum of its impact on every \(\hat{x}_j\) in the vector:

\[\frac{\partial L}{\partial x_i} = \sum_{j=1}^d \frac{\partial L}{\partial \hat{x}_j} \frac{\partial \hat{x}_j}{\partial x_i}\]

To solve this, we first find \(\frac{\partial \hat{x}_j}{\partial x_i}\) using the quotient rule or product rule on \(\hat{x}_j = (x_j - \mu)v^{-1/2}\):

\[\frac{\partial \hat{x}_j}{\partial x_i} = \frac{\partial (x_j - \mu)}{\partial x_i} v^{-1/2} + (x_j - \mu) \frac{\partial (v^{-1/2})}{\partial x_i}\]

2. Intermediate Partial Derivatives

We need four specific pieces to plug into the equation above:

  1. The Mean: \(\frac{\partial \mu}{\partial x_i} = \frac{1}{d}\)
  2. The Numerator: \(\frac{\partial (x_j - \mu)}{\partial x_i} = \delta_{ij} - \frac{1}{d}\) (where \(\delta_{ij}=1\) if \(i=j\), else \(0\)).
  3. The Variance:
\[\frac{\partial v}{\partial x_i} = \frac{\partial}{\partial x_i} \left[ \frac{1}{d} \sum_{k=1}^d (x_k - \mu)^2 \right] = \frac{2}{d}(x_i - \mu)\]

(Note: the contribution through \(\mu\) vanishes because \(\sum_{k=1}^d (x_k - \mu) = 0\), which simplifies the expansion.)

  4. The Inverse Std Dev:
\[\frac{\partial (v^{-1/2})}{\partial x_i} = -\frac{1}{2} v^{-3/2} \frac{\partial v}{\partial x_i} = -\frac{1}{2} v^{-3/2} \left( \frac{2(x_i - \mu)}{d} \right) = -\frac{x_i - \mu}{d \cdot v^{3/2}}\]

3. Combining the Terms

Now substitute these back into the expression for \(\frac{\partial \hat{x}_j}{\partial x_i}\):

\[\frac{\partial \hat{x}_j}{\partial x_i} = \left( \delta_{ij} - \frac{1}{d} \right) v^{-1/2} + (x_j - \mu) \left( -\frac{x_i - \mu}{d \cdot v^{3/2}} \right)\]

Factor out \(v^{-1/2}\) and recognize that \(\frac{x_j - \mu}{\sqrt{v}} = \hat{x}_j\):

\[\frac{\partial \hat{x}_j}{\partial x_i} = \frac{1}{\sqrt{v}} \left[ \delta_{ij} - \frac{1}{d} - \frac{(x_j - \mu)(x_i - \mu)}{d \cdot v} \right]\] \[\frac{\partial \hat{x}_j}{\partial x_i} = \frac{1}{\sqrt{v}} \left[ \delta_{ij} - \frac{1}{d} - \frac{\hat{x}_j \hat{x}_i}{d} \right]\]
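As a sanity check, the closed-form Jacobian can be compared against central finite differences. This is a minimal sketch for a single vector; the helper names (normalize, analytic_jacobian, numeric_jacobian) and the test values are hypothetical.

```python
import numpy as np

def normalize(x, eps):
    """x_hat = (x - mu) / sqrt(var + eps) for one vector."""
    v = x.var() + eps
    return (x - x.mean()) / np.sqrt(v)

def analytic_jacobian(x, eps):
    """J[j, i] = d x_hat_j / d x_i from the closed form above."""
    d = x.size
    v = x.var() + eps
    x_hat = (x - x.mean()) / np.sqrt(v)
    # (delta_ij - 1/d - x_hat_j * x_hat_i / d) / sqrt(v)
    return (np.eye(d) - 1.0 / d - np.outer(x_hat, x_hat) / d) / np.sqrt(v)

def numeric_jacobian(x, eps, h=1e-6):
    """Central differences, one input coordinate at a time."""
    d = x.size
    J = np.zeros((d, d))
    for i in range(d):
        step = np.zeros(d)
        step[i] = h
        J[:, i] = (normalize(x + step, eps) - normalize(x - step, eps)) / (2 * h)
    return J

rng = np.random.default_rng(0)
x = rng.normal(size=5)
eps = 1e-5
max_diff = np.abs(analytic_jacobian(x, eps) - numeric_jacobian(x, eps)).max()
print(max_diff)  # should be tiny
```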

4. The Final Summation

Now we plug this into our Total Derivative \(\sum_{j=1}^d \frac{\partial L}{\partial \hat{x}_j} \frac{\partial \hat{x}_j}{\partial x_i}\):

\[\frac{\partial L}{\partial x_i} = \sum_{j=1}^d \frac{\partial L}{\partial \hat{x}_j} \cdot \frac{1}{\sqrt{v}} \left[ \delta_{ij} - \frac{1}{d} - \frac{\hat{x}_j \hat{x}_i}{d} \right]\]

Pull the constant \(\frac{1}{d\sqrt{v}}\) out (which requires multiplying the \(\delta_{ij}\) term by \(d\)):

\[\frac{\partial L}{\partial x_i} = \frac{1}{d\sqrt{v}} \sum_{j=1}^d \frac{\partial L}{\partial \hat{x}_j} \left[ d\delta_{ij} - 1 - \hat{x}_j \hat{x}_i \right]\]

Distribute the sum across the three internal terms:

  1. \(\sum d\delta_{ij} \frac{\partial L}{\partial \hat{x}_j} = d \frac{\partial L}{\partial \hat{x}_i}\) (because \(\delta_{ij}\) is only \(1\) when \(j=i\)).
  2. \(\sum 1 \cdot \frac{\partial L}{\partial \hat{x}_j} = \sum \frac{\partial L}{\partial \hat{x}_j}\).
  3. \(\sum \hat{x}_j \hat{x}_i \frac{\partial L}{\partial \hat{x}_j} = \hat{x}_i \sum (\frac{\partial L}{\partial \hat{x}_j} \hat{x}_j)\).

Result

\[\frac{\partial L}{\partial x_i} = \frac{1}{d \sqrt{\sigma^2 + \epsilon}} \left[ d \frac{\partial L}{\partial \hat{x}_i} - \sum_{j=1}^d \frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i \sum_{j=1}^d \left( \frac{\partial L}{\partial \hat{x}_j} \cdot \hat{x}_j \right) \right]\]
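A vectorized NumPy sketch of this result for a batch of shape (N, D) is shown below. The name layernorm_backward matches the call used later in this text, but the body and the cache layout (x_hat, gamma, inv_std) are assumptions. For a batch, the parameter gradients from Step A are summed over the batch axis, since \(\gamma\) and \(\beta\) are shared by all examples.

```python
import numpy as np

def layernorm_backward(dout, cache):
    """Backward pass for LayerNorm over the last axis.

    dout: (N, D) upstream gradient dL/dy.
    cache: (x_hat, gamma, inv_std), an assumed layout for this sketch.
    Returns (dx, dgamma, dbeta).
    """
    x_hat, gamma, inv_std = cache
    d = x_hat.shape[-1]
    # Step A: parameter gradients, summed over the batch axis.
    dgamma = np.sum(dout * x_hat, axis=0)
    dbeta = np.sum(dout, axis=0)
    # Step B: gradient w.r.t. the normalized input.
    dx_hat = dout * gamma
    # Step C: the closed-form result, applied row by row.
    dx = (inv_std / d) * (
        d * dx_hat
        - dx_hat.sum(axis=-1, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=-1, keepdims=True)
    )
    return dx, dgamma, dbeta

# Build the cached quantities by hand for a small random example.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
eps = 1e-5
inv_std = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
x_hat = (x - x.mean(axis=-1, keepdims=True)) * inv_std
dout = rng.normal(size=(4, 3))

dx, dgamma, dbeta = layernorm_backward(dout, (x_hat, gamma, inv_std))
print(dx.shape, dgamma.shape, dbeta.shape)  # (4, 3) (3,) (3,)
```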

3. Intuition of the Terms

  • Term 1 (\(d \frac{\partial L}{\partial \hat{x}_i}\)): The direct impact of the input on the output.
  • Term 2 (\(\sum \frac{\partial L}{\partial \hat{x}_j}\)): This term ensures that the sum of the gradients across the normalized dimensions is zero. This corresponds to the fact that shifting the input by a constant doesn’t change the output (mean-invariance).
  • Term 3 (\(\hat{x}_i \sum \dots\)): This term accounts for the variance. It ensures that scaling the input by a constant doesn’t change the output (scale-invariance).
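The two invariance claims can be checked numerically for a single vector. In this sketch (the helper name dx_from_dxhat and the random inputs are hypothetical), eps is set to 0 so that the scale-invariance identity holds exactly rather than approximately.

```python
import numpy as np

def dx_from_dxhat(dx_hat, x_hat, inv_std):
    """Closed-form dL/dx for one vector, given dL/dx_hat."""
    d = x_hat.size
    return (inv_std / d) * (d * dx_hat - dx_hat.sum() - x_hat * (dx_hat * x_hat).sum())

rng = np.random.default_rng(0)
x = rng.normal(size=6)
eps = 0.0  # assumption: eps = 0 makes the second identity exact
inv_std = 1.0 / np.sqrt(x.var() + eps)
x_hat = (x - x.mean()) * inv_std
dx_hat = rng.normal(size=6)  # arbitrary upstream gradient

dx = dx_from_dxhat(dx_hat, x_hat, inv_std)

# Mean-invariance: the gradient has no component along the all-ones direction.
sum_dx = dx.sum()
# Scale-invariance: the gradient has no component along the centered input.
dot_centered = dx @ (x - x.mean())

print(sum_dx, dot_centered)  # both approximately 0
```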

In PyTorch, the backward pass is implemented in highly optimized C++/CUDA kernels. In the ATen library source (e.g., aten/src/ATen/native/layer_norm.cpp), the LayerNorm backward implementation carries out the same computation as derived above.

Numerical Verification of Backward Pass

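To check the gradients numerically, we wrap the forward pass in functions that accept a single input, one for each argument being perturbed: lambda x: layernorm_forward(x, gamma, beta, ln_param)[0], lambda gamma: layernorm_forward(x, gamma, beta, ln_param)[0], and likewise for beta.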

Now, define random inputs and random upstream gradient dout:

# Test the layernorm_backward function
np.random.seed(31)

N = 10
D = 3

x = np.random.randn(N, D)
gamma = np.random.randn(D)
beta = np.random.randn(D)
ln_param = {}
ln_param['eps'] = 1e-10
dout = np.random.randn(N, D)

dx_num = eval_numerical_gradient_array(lambda x: layernorm_forward(x, gamma, beta, ln_param)[0], x, dout)
dg_num = eval_numerical_gradient_array(lambda gamma: layernorm_forward(x, gamma, beta, ln_param)[0], gamma, dout)
db_num = eval_numerical_gradient_array(lambda beta: layernorm_forward(x, gamma, beta, ln_param)[0], beta, dout)

_, cache = layernorm_forward(x, gamma, beta, ln_param)
dx, dg, db = layernorm_backward(dout, cache)

print("dg_num shape:",  dg_num.shape)
print("dg shape: ", dg.shape)

# The error should be around e-10 or less
print('Testing layernorm_backward function:')
print('dx error: ', rel_error(dx_num, dx))
print('dg error: ', rel_error(dg_num, dg))
print('db error: ', rel_error(db_num, db))
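The helper rel_error is not defined in this text. A common definition, assumed here, is the maximum elementwise relative error with a small floor in the denominator to avoid dividing by zero:

```python
import numpy as np

def rel_error(x, y):
    """Maximum elementwise relative error between two arrays (assumed definition)."""
    return np.max(np.abs(x - y) / np.maximum(1e-8, np.abs(x) + np.abs(y)))

a = np.array([1.0, 2.0, 3.0])
b = a + 1e-9

print(rel_error(a, a))  # 0.0
print(rel_error(a, b))  # on the order of 1e-10
```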

We’ll check numerical gradients using the following function, that applies the chain rule to multiply by the upstream gradient df:

from typing import Callable

import numpy as np

def eval_numerical_gradient_array(f: Callable, x: np.ndarray, df: np.ndarray, h: float = 1e-5):
    """
    Evaluate a numeric gradient for a function `f` that accepts a numpy
    array `x` and returns a numpy array. `df` is upstream gradient.
    """
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"], op_flags=["readwrite"])
    while not it.finished:
        ix = it.multi_index

        oldval = x[ix]
        x[ix] = oldval + h
        pos = f(x).copy()
        x[ix] = oldval - h
        neg = f(x).copy()
        x[ix] = oldval

        grad[ix] = np.sum((pos - neg) * df) / (2 * h)
        it.iternext()
    return grad

RMSNorm

Root Mean Square Normalization (RMSNorm) (Zhang, 2019) is a simpler variant of the LayerNorm used in the original GPT-2. RMSNorm focuses only on re-scaling invariance and removes the mean statistic entirely.

\[\bar{a}_i = \frac{a_i}{\operatorname{RMS}(\mathbf{a})} g_i, \qquad \text{where } \operatorname{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2}.\]

RMSNorm rescales a vector so its values have unit root-mean-square. This keeps activations from growing or shrinking as they flow through the network, which stabilizes training. Karpathy’s microgpt removes the learned scaling parameter \(g\):

def rmsnorm(x):
  ms = sum(xi * xi for xi in x) / len(x)
  scale = (ms + 1e-5) ** -0.5
  return [xi * scale for xi in x]
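For comparison, here is a NumPy sketch that keeps the gain \(g\) from the equation above. The function name rms_norm and the placement of eps inside the square root (mirroring the snippet above) are assumptions; the equation itself has no epsilon term.

```python
import numpy as np

def rms_norm(a, g, eps=1e-5):
    """RMSNorm over the last axis with a learnable gain g (no mean subtraction, no bias)."""
    ms = np.mean(a * a, axis=-1, keepdims=True)  # mean of squares per row
    return a / np.sqrt(ms + eps) * g

a = np.array([[2.0, 2.0, 3.0], [-5.0, 0.0, 1.0]])
out = rms_norm(a, g=np.ones(3))

rms_out = np.sqrt(np.mean(out * out, axis=-1))  # close to 1 for each row
mean_out = out.mean(axis=-1)                    # generally not 0, unlike LayerNorm

print(rms_out, mean_out)
```

Because the mean is not subtracted, the output keeps whatever offset the input had, only rescaled.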

Background on Initialization

First, to understand normalization layers, we need to understand network input initialization, since each layer (not just the first layer) needs such a stable distribution.

Xavier Glorot

Kaiming He

If the weights are initialized to large values, the learning rate is effectively too large.

Why do we want whitened data (linearly transformed to have mean 0 and unit variance)?

  • Why mean 0?

  • Why do we want unit variance?

It has long been known (LeCun et al., 1998b; Wiesler & Ney, 2011) that network training converges faster if its inputs are whitened.

(1) In short, this is due to limited precision in the network. By normalizing, you make the bias-fitting task much easier: the bias stays close to 0, and at a minimum the “usable precision” of the float parameters can be meaningfully spent fitting the data.

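What if we zero-initialize all layer weights? Backprop then multiplies by zero activations (and zero weights), so the gradients are 0.

What if we initialize all weights to the same value? Then every unit in a layer computes the same output and receives the same gradient, so the weights stay identical and the extra parameters are redundant.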
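Both failure modes can be illustrated with a tiny two-layer tanh network. This is a sketch under stated assumptions: the architecture (no biases, one scalar output, squared-error loss) and all values are hypothetical and chosen only to keep the example short.

```python
import numpy as np

def hidden_weight_grad(W1, W2, x, y):
    """Gradient of 0.5 * (pred - y)^2 w.r.t. W1 for a tiny tanh MLP without biases."""
    h = np.tanh(W1 @ x)        # hidden activations, shape (H,)
    pred = W2 @ h              # scalar output
    dpred = pred - y
    dh = dpred * W2            # backprop through the output layer
    dz = dh * (1.0 - h ** 2)   # backprop through tanh
    return np.outer(dz, x)     # dL/dW1, shape (H, D)

x = np.array([0.5, -1.0, 2.0])
y = 1.0
H, D = 4, 3

# Zero init: dh = dpred * W2 = 0, so no gradient reaches W1.
g_zero = hidden_weight_grad(np.zeros((H, D)), np.zeros(H), x, y)

# Constant init: every hidden unit receives the same gradient row,
# so the units stay identical after each update.
g_const = hidden_weight_grad(np.full((H, D), 0.1), np.full(H, 0.1), x, y)

print(np.all(g_zero == 0))               # True
print(np.allclose(g_const, g_const[0]))  # True
```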

(2) Internal “covariate shift”: the distribution of each layer’s inputs (network activations) changes during training as the parameters of the previous layers change. See Ioffe and Szegedy, 2015.

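As in BatchNorm (Ioffe and Szegedy, 2015), the normalized value \(\hat{x}^{(k)}\) of each activation \(k\) is then scaled and shifted by learnable parameters:

\[y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}\]

  • GroupNorm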

BatchNorm

BatchNorm normalizes using batch-level statistics, while LayerNorm normalizes each example independently.

  • BatchNorm fixes the means and variances of layer inputs.

Why is BatchNorm not used in transformers?

  • BatchNorm depends on other examples in the batch: in a transformer, you usually want the representation of one sequence/token not to depend on which other examples happened to be in the same mini-batch.
  • Autoregressive generation often has a tiny batch size: at inference time, LLMs often generate one sequence at a time. BatchNorm needs reliable batch statistics. With tiny batch sizes, the statistics are noisy or meaningless.
  • BatchNorm behaves differently during training and inference: it uses batch statistics during training but running averages at inference, and this mismatch is awkward for sequence models.
  • Masking is required: padding along the sequence dimension would otherwise contaminate the batch statistics.
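The first point can be demonstrated in a few lines of NumPy. In this sketch (learnable parameters are omitted, and batch_norm_train is a hypothetical name for training-mode batch statistics), the same example is placed in two different batches:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Statistics over the feature axis: one mean/variance per example.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def batch_norm_train(x, eps=1e-5):
    # Training-mode statistics over the batch axis: one mean/variance per feature.
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
example = rng.normal(size=(1, 4))

# The same example, surrounded by different batch-mates.
batch_a = np.concatenate([example, rng.normal(size=(3, 4))])
batch_b = np.concatenate([example, rng.normal(size=(3, 4)) * 10.0])

ln_same = bool(np.allclose(layer_norm(batch_a)[0], layer_norm(batch_b)[0]))
bn_same = bool(np.allclose(batch_norm_train(batch_a)[0], batch_norm_train(batch_b)[0]))

print(ln_same, bn_same)  # True False
```

LayerNorm gives the example the same output in both batches, while the batch-statistics version does not.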

References

  1. Sergey Ioffe, Christian Szegedy. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. 2015.

  2. Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton. Layer Normalization. 2016.

  3. Biao Zhang, Rico Sennrich. Root Mean Square Layer Normalization. 2019.

  4. Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever. Language Models are Unsupervised Multitask Learners. 2019.

  5. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. Attention Is All You Need. 2017.