Backprop through a convolutional layer is one of the most fundamental operations in deep learning. Although the derivation is surprisingly simple, there are very few good resources on the web explaining it. In this post, we’ll derive it, implement it, show that the derivation and Pytorch’s autograd agree perfectly, and provide some intuition as to what is going on.

Recap of a Convolutional Layer

Before we go into the backprop derivation, we’ll review the basic operation of a convolutional layer, which actually implements cross-correlation in modern libraries like Pytorch. To make things easy to understand, we’ll work with a small numerical example. Imagine a simple 3x3 kernel \(k\) (a Sobel filter):

import numpy as np
import torch
k = np.array(
    [
        [1,0,-1],
        [2,0,-2],
        [1,0,-1]
    ]).reshape(1,1,3,3).astype(np.float32)

and a simple 6x5 input image:

x = np.array(
    [
        [1,1,1,2,3],
        [1,1,1,2,3],
        [1,1,1,2,3],
        [2,2,2,2,3],
        [3,3,3,3,3],
        [4,4,4,4,4]
    ]).reshape(1,1,6,5).astype(np.float32)

We can perform cross-correlation of \(x\) with \(k\) with Pytorch:

conv = torch.nn.Conv2d(
    in_channels=1,
    out_channels=1,
    kernel_size=3,
    bias=False,
    stride = 1,
    padding_mode='zeros',
    padding=0
)

x_tensor = torch.from_numpy(x)
x_tensor.requires_grad = True
conv.weight = torch.nn.Parameter(torch.from_numpy(k))
out = conv(x_tensor)

If we want to examine gradients w.r.t. the kernel \(k\) and w.r.t. the input \(x\), we need to create a scalar loss from the output \(o\) (named out above) and perform backprop:

loss = out.sum()
loss.backward()

print(conv.weight.grad)
print(x_tensor.grad)

We’ll now show how these quantities are derived.

Backprop through Convolution to Weights

Cross-correlation above takes 12 windows from \(x\) (one per output pixel) and dot-products each with the flattened kernel \(k\): \(x = \begin{bmatrix} 1 & 1 & 1 & 2 & 3 \\ 1 & 1 & 1 & 2 & 3 \\ 1 & 1 & 1 & 2 & 3 \\ 2 & 2 & 2 & 2 & 3 \\ 3 & 3 & 3 & 3 & 3 \\ 4 & 4 & 4 & 4 & 4 \\ \end{bmatrix}\)

Suppose we number each window from \(x\) as \(x_{w_i}\), from \(x_{w_1}, ..., x_{w_{12}}\):

[1,1,1]   [1,1,2]   [1,2,3]
[1,1,1]   [1,1,2]   [1,2,3]
[1,1,1]   [1,1,2]   [1,2,3]

[1,1,1]   [1,1,2]   [1,2,3]
[1,1,1]   [1,1,2]   [1,2,3]
[2,2,2]   [2,2,2]   [2,2,3]

[1,1,1]   [1,1,2]   [1,2,3]
[2,2,2]   [2,2,2]   [2,2,3]
[3,3,3]   [3,3,3]   [3,3,3]

[2,2,2]   [2,2,2]   [2,2,2]
[3,3,3]   [3,3,3]   [3,3,3]
[4,4,4]   [4,4,4]   [4,4,4]
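
As a sanity check on this windowing view, here is a minimal sketch (assuming NumPy ≥ 1.20 for sliding_window_view; the names windows and o_manual are introduced here) that extracts the 12 windows and reproduces the forward output by dot-producting each flattened window with the flattened kernel:

from numpy.lib.stride_tricks import sliding_window_view

# All 4x3 = 12 windows of the 6x5 image, each of size 3x3 (stride 1, no padding).
windows = sliding_window_view(x.squeeze(), (3, 3))   # shape (4, 3, 3, 3)

# Dot each flattened window with the flattened kernel to reproduce the forward pass.
o_manual = (windows.reshape(12, 9) @ k.reshape(9)).reshape(4, 3)
print(np.allclose(o_manual, out.detach().numpy().squeeze()))   # True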

Now, since \(L = \sum\limits_{i=1}^4 \sum\limits_{j=1}^3 o_{ij}\) as defined above, \(\frac{\partial L}{\partial o} = \begin{bmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{bmatrix}\). This will be our upstream gradient.

By the chain rule, we need to multiply the upstream gradient with the conv layer’s gradient, to get gradients w.r.t. the inputs to the conv layer: \(\frac{ \partial L }{\partial k} = \frac{\partial L}{\partial o} \cdot \frac{ \partial o }{\partial k}\)

Likewise, to backprop to the input image, we see: \(\frac{ \partial L }{\partial x} = \frac{\partial L}{\partial o} \cdot \frac{ \partial o }{\partial x}\)

We’ll focus just on \(\frac{ \partial o }{\partial k}\) for now – the derivative with respect to our kernel weights.

Suppose our input image is \(N_1 \times N_2\), and our kernel \(k\) has dimensions \(k_1 \times k_2\). Let \(x\) be indexed by row and column \((r,c)\), and let the kernel, which we’ll write as \(w\) (matching conv.weight), be indexed by \((a,b)\).

Suppose our input has been zero-padded so that the output \(o\) has the same size as the input. Then each “pixel” of our output, i.e. \(o[r,c]\), can be computed as the dot product between the kernel and a window of \(x\) whose top-left corner lies at \((r,c)\):

\[\begin{aligned} o[r,c] &= \sum\limits_{a=0}^{k_1-1} \sum\limits_{b=0}^{k_2-1} x[r+a,c+b] w[a,b] \\ o[r,c] &= \vec{x_{vec}}^T \vec{w_{vec}} \end{aligned}\]

Only one term in the double sum involves \(w[a^\prime,b^\prime]\), so \(\frac{\partial o[r,c] }{\partial w[a^\prime,b^\prime] } = x[r + a^\prime, c + b^\prime]\). By the chain rule,

\[\begin{aligned} \frac{\partial L}{\partial w[a^\prime,b^\prime]} &= \sum\limits_{r=0}^{N_1-1} \sum\limits_{c=0}^{N_2-1} \frac{\partial L}{\partial o[r,c]} \frac{\partial o[r,c]}{\partial w[a^\prime,b^\prime]} \\ &= \sum\limits_{r=0}^{N_1-1} \sum\limits_{c=0}^{N_2-1} \frac{\partial L}{\partial o[r,c]} \cdot x[r+a^\prime,c+b^\prime] \end{aligned}\]

This double sum of products of scalars should remind us of cross-correlation, because this is its definition. Thus, we can re-interpret the equation above as cross-correlation between the upstream gradient and the input image, where the “filter” (here the upstream gradient) can suddenly be quite large (\(4 \times 3\) in this example): \(\frac{\partial L}{\partial w} = x * \frac{\partial L}{\partial o}\), where the \(*\) operator represents cross-correlation, not convolution.

Let’s check if this matches the gradients w.r.t. the kernel that Pytorch computed for us:

print(conv.weight.grad)
tensor([[[[15., 18., 25.],
          [21., 23., 28.],
          [30., 31., 34.]]]])
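
As a further sanity check on this cross-correlation interpretation, here is a minimal sketch that implements the double sum directly (dout_s1, x_img, and dw_s1 are names introduced here for illustration):

# dL/dw[a,b] = sum over output positions (r,c) of dL/do[r,c] * x[r + a, c + b]
dout_s1 = np.ones((4, 3), dtype=np.float32)   # upstream gradient: all ones
x_img = x.squeeze()                           # the 6x5 input image

dw_s1 = np.zeros((3, 3), dtype=np.float32)
for a in range(3):
    for b in range(3):
        dw_s1[a, b] = (dout_s1 * x_img[a:a + 4, b:b + 3]).sum()

print(dw_s1)   # matches conv.weight.grad above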

There are 12 possible places where w[0,0] can be placed, highlighted in red below:

\[x = \begin{bmatrix} \color{red}{1} & \color{red}{1} & \color{red}{1} & 2 & 3 \\ \color{red}{1} & \color{red}{1} & \color{red}{1} & 2 & 3 \\ \color{red}{1} & \color{red}{1} & \color{red}{1} & 2 & 3 \\ \color{red}{2} & \color{red}{2} & \color{red}{2} & 2 & 3 \\ 3 & 3 & 3 & 3 & 3 \\ 4 & 4 & 4 & 4 & 4 \\ \end{bmatrix}\]

When we sum them up, we find that indeed, \(\frac{\partial L}{\partial w[0,0]} = 15\)

print(x.squeeze()[:4,:3].sum())
15.0

We can now check all the places where w[0,1] can be placed (where valid), highlighted in blue below:

\[x = \begin{bmatrix} 1 & \color{blue}{1} & \color{blue}{1} & \color{blue}{2} & 3 \\ 1 & \color{blue}{1} & \color{blue}{1} & \color{blue}{2} & 3 \\ 1 & \color{blue}{1} & \color{blue}{1} & \color{blue}{2} & 3 \\ 2 & \color{blue}{2} & \color{blue}{2} & \color{blue}{2} & 3 \\ 3 & 3 & 3 & 3 & 3 \\ 4 & 4 & 4 & 4 & 4 \\ \end{bmatrix}\]

When we sum them up, we find that indeed, \(\frac{\partial L}{\partial w[0,1]} = 18\)

print(x.squeeze()[0:0+4,1:1+3].sum())
18.0

Hopefully, the story is becoming clear by now for the stride=1 case. However, the stride=2 case is not so trivial, since the valid places where w[0,0] can be placed on the input image are no longer adjacent to one another; rather, we need to use dilated cross-correlation.

Now suppose we keep the same kernel, but change the stride to 2 (see the setup sketch below). The output and the upstream gradient are then:
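
Here is a minimal sketch of the assumed stride-2 setup (it isn’t shown in the original code above). It reuses k and x, re-creates conv, out, and loss with stride=2, and defines the upstream gradient dout explicitly; x_tensor_s2 is a fresh input tensor introduced here so that x_tensor from the stride-1 run keeps its gradient for the final section:

conv = torch.nn.Conv2d(
    in_channels=1,
    out_channels=1,
    kernel_size=3,
    bias=False,
    stride=2,
    padding=0
)
conv.weight = torch.nn.Parameter(torch.from_numpy(k))

x_tensor_s2 = torch.from_numpy(x)
x_tensor_s2.requires_grad = True

out = conv(x_tensor_s2)   # shape (1, 1, 2, 2)
loss = out.sum()
loss.backward()

# Upstream gradient dL/do: all ones, since the loss is a plain sum over o.
dout = np.ones(out.shape, dtype=np.float32)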

print(out)
array([[[[ 0., -8.],
         [ 0., -4.]]]])

print(dout)
array([[[[1., 1.],
         [1., 1.]]]])

\(\frac{\partial L}{\partial w}\) is equal to:

print(conv.weight.grad.numpy())
array([[[[ 4.,  6.,  8.],
         [ 6.,  7.,  9.],
         [ 8.,  9., 10.]]]], dtype=float32)
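
The stride-1 formula generalizes directly: with stride \(s\), \(\frac{\partial L}{\partial w[a,b]} = \sum\limits_{r} \sum\limits_{c} \frac{\partial L}{\partial o[r,c]} \, x[sr+a, sc+b]\), which is exactly the dilated cross-correlation mentioned above. A minimal sketch of this computation for our stride-2 example, reusing dout and x (dw_s2 is a name introduced here):

# dL/dw[a,b] = sum over output positions (r,c) of dout[r,c] * x[2r + a, 2c + b]
x_img = x.squeeze()   # the 6x5 input image
dw_s2 = np.zeros((3, 3), dtype=np.float32)
for a in range(3):
    for b in range(3):
        for r in range(2):
            for c in range(2):
                dw_s2[a, b] += dout[0, 0, r, c] * x_img[2 * r + a, 2 * c + b]

print(dw_s2)   # matches conv.weight.grad above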

The values of \(x\) at the valid places where w[0,0] can be placed on the input image, when dot-producted with dout, yield \(\color{red}{4}\), as expected.

\[x = \begin{bmatrix} \color{red}{1} & 1 & \color{red}{1} & 2 & 3 \\ 1 & 1 & 1 & 2 & 3 \\ \color{red}{1} & 1 & \color{red}{1} & 2 & 3 \\ 2 & 2 & 2 & 2 & 3 \\ 3 & 3 & 3 & 3 & 3 \\ 4 & 4 & 4 & 4 & 4 \\ \end{bmatrix}\]

w[0,1] can be placed at the blue locations, and when the dout filter values are placed at these locations, the dot product is \(\color{blue}{6}\), as expected:

\[x = \begin{bmatrix} 1 & \color{blue}{1} & 1 & \color{blue}{2} & 3 \\ 1 & 1 & 1 & 2 & 3 \\ 1 & \color{blue}{1} & 1 & \color{blue}{2} & 3 \\ 2 & 2 & 2 & 2 & 3 \\ 3 & 3 & 3 & 3 & 3 \\ 4 & 4 & 4 & 4 & 4 \\ \end{bmatrix}\]
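
Mirroring the checks from the stride=1 case, we can sum the highlighted entries directly (since dout is all ones, the dot product reduces to a plain sum):

print(x.squeeze()[0:3:2, 0:3:2].sum())   # red entries:  1 + 1 + 1 + 1 = 4.0
print(x.squeeze()[0:3:2, 1:4:2].sum())   # blue entries: 1 + 2 + 1 + 2 = 6.0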

Backprop through Convolution to an Input Image

Now, we’ll work through the gradient w.r.t. the input image.

Returning to the stride=1 case: our upstream gradient was \(\frac{\partial L}{\partial o} = \begin{bmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{bmatrix}\), since it came from a sum operation over a \(4 \times 3\) grid. To backprop to the input, we need to zero-pad this gradient and convolve it with the kernel \(w\). Equivalently, since our libraries implement cross-correlation, we can flip \(w\) about both the x and y axes and perform cross-correlation. In total, the upstream gradient must be padded with a border of \(k_1 - 1 = 2\) zeros on each side (a “full” cross-correlation); below, one ring of zeros is written out explicitly, and the second ring is supplied by zero-padding the cross-correlation itself:

\[\begin{bmatrix} 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 \end{bmatrix} * \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix}\]

We find that for our toy example, \(\frac{\partial L}{\partial x}\) is equal to the following \(6 \times 5\) matrix:

print(x_tensor.grad)
tensor([[[[ 1.,  1.,  0., -1., -1.],
          [ 3.,  3.,  0., -3., -3.],
          [ 4.,  4.,  0., -4., -4.],
          [ 4.,  4.,  0., -4., -4.],
          [ 3.,  3.,  0., -3., -3.],
          [ 1.,  1.,  0., -1., -1.]]]])
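
We can check this construction directly in code. Here is a minimal sketch (dLdo and k_flipped are names introduced here): it pads the raw \(4 \times 3\) upstream gradient with \(k_1 - 1 = 2\) zeros on each side via padding=2 and cross-correlates it with the doubly-flipped kernel, reproducing x_tensor.grad:

import torch.nn.functional as F

# Upstream gradient for the stride-1 case: a 4x3 grid of ones.
dLdo = torch.ones(1, 1, 4, 3)

# Flip the kernel about both spatial axes, then cross-correlate with padding=2,
# i.e. a "full" cross-correlation of dL/do with the flipped kernel.
k_flipped = torch.flip(torch.from_numpy(k), dims=(2, 3))
dx = F.conv2d(dLdo, k_flipped, padding=2)

print(dx)   # matches x_tensor.grad above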

CMU has a nice discussion here.

Reflection

We’ve observed an interesting duality: cross-correlation in the forward pass becomes cross-correlation again in the backward pass, with the input image (for the weight gradient) or with the flipped kernel (for the input gradient). Just as the forward and backward passes of a linear layer are both matrix multiplications, we see the same duality for convolution.