Convoluted Stuff

Optimising compilers, and how thousand-year-old math shaped deep learning

Disclaimer: I have a day job. If I’ve made a mistake, please either let me know on Twitter in a civil way and I’ll fix it or ask your employer to hire me to do this for money.

Introduction

This post now has a part 2!

I asked a grad student of mine the other day: “How is convolution implemented?” He’s an excellent student, and on a piece of paper he pointed out that a thing called a kernel slides over an image, and that there is some element-wise multiplication and some summation going on. But then I asked him again: “Yeah, but how is it implemented?”

“A Python API calls optimised code in the background!” came the answer (I told you he’s good). But I’m a tenacious man.

“OK then, write some pseudocode!” I have a rule that pseudocode is just a placeholder for Python, so he did. My student didn’t license the code to me, so I won’t put it here, but it looked a lot like the following code from Cottonwood, Brandon Rohrer’s deep learning framework (which is MIT licensed):

import numpy as np
from numba import njit

@njit
def xcorr_1d(signal, kernel, n_steps=None):
    """
    Calculate n_steps of the sliding dot product,
    a.k.a. the cross-correlation,
    between a one dimensional signal and a one dimensional kernel.

    Start with the beginning (zeroth elements) of the kernel and signal
    aligned.
    Shift the kernel up by one position each iteration.
    """
    if n_steps is None:
        n_steps = signal.size - kernel.size + 1

    result = np.zeros(n_steps, dtype=np.double)
    n_ker = kernel.size
    for i in range(n_steps):
        # Using np.dot() instead of np.sum() over the products cuts
        # the computation time down by a factor of 5.
        result[i] = np.dot(signal[i: i + n_ker], kernel)
    return result

This is only a small part of Brandon’s code, but it’s among the most readable and pedagogically useful implementations I know (make sure to also watch his fantastic playlist on the topic).

This is an implementation by a man who has thought about what he does. See those comments? He thought about computation time. He also made sure to get optimised code by making Numba optimise the loop! My student was satisfied having provided a useful answer, and we went on with our days.
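What deep learning frameworks call convolution is, strictly speaking, the cross-correlation computed above. A quick way to convince yourself is to compare a plain-NumPy copy of the loop with NumPy’s own np.correlate and np.convolve. The function name xcorr_1d_py is made up here, and it drops the @njit decorator so that it runs without Numba:

```python
import numpy as np

def xcorr_1d_py(signal, kernel):
    """Same sliding dot product as xcorr_1d, in plain NumPy (no Numba)."""
    n_steps = signal.size - kernel.size + 1
    result = np.zeros(n_steps, dtype=np.double)
    for i in range(n_steps):
        result[i] = np.dot(signal[i:i + kernel.size], kernel)
    return result

signal = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = np.array([1.0, 0.0, -1.0])

print(xcorr_1d_py(signal, kernel))                  # [-2. -2. -2.]
# NumPy's cross-correlation in 'valid' mode agrees ...
print(np.correlate(signal, kernel, mode="valid"))   # [-2. -2. -2.]
# ... while a true convolution flips the kernel first, which flips the sign here.
print(np.convolve(signal, kernel, mode="valid"))    # [2. 2. 2.]
```

The kernel flip is the only difference between the two operations, which is worth remembering when we reach the FFT-based method later on.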

This is an essay about convolution, but not in the obvious way. It doesn’t cover the basics of kernels and sliding filters. Instead, I want to talk about something we rarely discuss in academia: how we get from the maths we teach about deep learning to the maths we actually use in deep learning, and how very intelligent people make it possible for anyone to type something like conv2d and get a state-of-the-art neural network backed by incredibly efficient computation.

Cottonwood 2D

To make an important point, let us first generalise the Cottonwood convolution to two dimensions, by which we obviously mean three dimensions, by which we mean four. Jokes aside, we assume we wish to perform convolution of an input matrix with dimensions (N_Batch_Size, Height, Width, Channels) (NHWC) with a kernel of dimensions (filter_height, filter_width, in_channels, out_channels). For simplicity, we will forgo padding, strides and dilated convolutions. Our operation will be performed on an input matrix of (8x150x150x3) and an edge detection kernel of (3x3x3x16) as per the above notation, giving us an output of (8x148x148x16).

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from functools import partial
import tensorflow as tf
from scipy import linalg
from numba import njit, prange

# Display helper: show single-channel arrays with a grayscale colormap
show = partial(plt.imshow, cmap="gray")

# Load the image, resize it to 150x150 and scale it to [0, 1] as float32
im = (np.array(Image.open("/assets/images/puppy.jpg").resize((150,150)))/255.).astype(np.float32)
show(im)
plt.suptitle("O M G Q T")

kernel = np.array([[1,0,-1],[0,0,0],[-1,0,1]], dtype=np.float32)
show(kernel.squeeze())

print(im.shape, kernel.shape)
(150, 150, 3) (3, 3)
input_matrix, input_kernel = np.tile(im, (8, 1, 1, 1)), np.tile(kernel[...,None, None], (1,1,3,16))
print(input_matrix.shape, input_kernel.shape)
(8, 150, 150, 3) (3, 3, 3, 16)
@njit
def conv2d(image, kernel):
    '''Valid cross-correlation (what deep learning libraries call convolution) of an input
    of shape (batch_size, height, width, in_channels) with a kernel of shape
    (kernel_height, kernel_width, in_channels, out_channels).'''
    batch_size, height, width, in_channels = image.shape
    kx, ky, _, out_channels = kernel.shape

    output = np.zeros((batch_size, height-kx+1, width-ky+1, out_channels))

    for n in range(batch_size):
        for h in range(height-kx+1):
            for w in range(width-ky+1):
                for c in range(out_channels):
                    # Multiply the (kx, ky, in_channels) window elementwise with filter c and sum
                    # (unfortunately, numba.njit fails with np.dot on 4D arrays)
                    output[n,h,w,c] = np.multiply(image[n, h:h+kx, w:w+ky, :], kernel[:,:,:,c]).sum()

    return output
result = conv2d(input_matrix, input_kernel)
show(result[0,...,0])
print(result.shape)
(8, 148, 148, 16)

Works well enough, and it has the desired effect (edge detection). Let’s make sure that we’re doing things right by testing our implementation against the TensorFlow one.

np.testing.assert_allclose(conv2d(input_matrix, input_kernel), tf.nn.conv2d(input_matrix, input_kernel, padding="VALID", strides=(1,1)))

Let’s now test how fast our implementation is compared with the TensorFlow one:

%timeit conv2d(input_matrix, input_kernel)
568 ms ± 8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit tf.nn.conv2d(input_matrix, input_kernel, padding="VALID", strides=(1,1))
4.31 ms ± 89.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

More than two orders of magnitude difference! In this essay, we’ll try to understand what goes into these two orders of magnitude, and how modern deep learning libraries manage to provide this performance on a wide assortment of devices and with great flexibility towards the users' needs.

What lies beneath

Before we go further, let me briefly discuss what kinds of considerations actually go into such a library to bridge those two orders of magnitude (a non-exhaustive list):

  • We won’t talk in detail about the effects of matrix, batch, array or kernel size. In practice, those make a huge difference. For example, parallelisation only pays dividends if the overhead of parallelising an operation is smaller than the time it saves compared with serial execution. Libraries like TensorFlow have optimised primitives (some of which we’ll talk about below) for certain kernel sizes, but need to be flexible towards any kernel size chosen by the user.

  • We’ll hand-wave away considerations about locality, although this is a hugely important question. Is your data small enough to fit in cache/VRAM/RAM? Every time you move data from one location to another, you pay a large time penalty. In fact, in unoptimised implementations, copying arrays from RAM to VRAM can take significantly longer than the convolutions in the network themselves. A big part of why high-performance clusters like NVIDIA DGX systems include NVLink™ (and now Mellanox InfiniBand™) is to reduce this overhead.

  • We won’t consider optimisations from prefetching or array ordering. Array order matters! Although NHWC is the default ordering in many libraries, the alternative ordering (NCHW) can be faster: CPUs fetch memory in blocks (cache lines) rather than one value at a time, so storing values that are processed together next to each other lets them be fetched together (TensorFlow internally re-orders to NCHW, see here). Similarly, if your library uses C-ordering, it makes sense to also use C-ordered arrays instead of forcing the library to reorder Fortran-ordered arrays internally.

  • What precision do you require? If you are doing high-precision scientific computing, you probably need 64-bit floats. In computer vision, you might find that even 32-bit floats are a waste. Recent GPUs can perform calculations in FP16, whereas not every CPU supports FP16 calculations.

  • On the topic of hardware, there are enormous time savings from hardware-specific optimisations. This can mean:

    • CPUs and GPUs are fundamentally very different beasts with different cache and memory architectures and designed for different tasks. Thinking about where your code will be executed (SIMD vs. MIMD) makes a big difference.
    • You have hardware specifically designed to perform these kinds of operations (think TPUs, NVIDIA Tensor Cores, or the neural-network-specific CPU instructions in new Intel chips).
    • You have a compiler which can optimise for these operations. Ever read what TensorFlow prints when you start it up on a CPU? You have probably seen something like:

    This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following CPU instructions in performance critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA. To enable them in non-MKL-DNN operations, rebuild TensorFlow with the appropriate compiler flags
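To make the array-ordering point from the list above concrete, here is a small NumPy illustration of how the same values sit in memory under NHWC and NCHW. It only shows strides (the number of bytes to step along each axis); which layout is actually faster depends on the hardware and the library, so treat it as an illustration, not a benchmark:

```python
import numpy as np

# A small NHWC batch: 2 images of 4x5 pixels with 3 channels, float32 (4 bytes per value).
nhwc = np.arange(2 * 4 * 5 * 3, dtype=np.float32).reshape(2, 4, 5, 3)

# Transposing alone only permutes the strides: no data moves, and the view is not C-contiguous.
view = nhwc.transpose(0, 3, 1, 2)
print(view.strides, view.flags["C_CONTIGUOUS"])  # (240, 4, 60, 12) False

# Materialising NCHW copies the data so that each channel plane is contiguous.
nchw = np.ascontiguousarray(view)

print(nhwc.strides)  # (240, 60, 12, 4): horizontally adjacent pixels of one channel are 12 bytes apart
print(nchw.strides)  # (240, 80, 20, 4): horizontally adjacent pixels of one channel are 4 bytes apart

# Same value, different position in memory.
assert nhwc[1, 2, 3, 0] == nchw[1, 0, 2, 3]
```

Note that a layout change is itself a full copy of the tensor, which is one reason libraries try to pick a layout once and stick with it.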

Lastly, I’d like to encourage you to think about why you are doing this. If you’re trying to learn, an implementation with nested loops is great! If you want to use open source tools only, you may have to use OpenCL. If the time you spend developing matters more to you than the time your computer spends on calculations, your route will be different. If you need GPU support on macOS, you need something like PlaidML. The list goes on.

All that said, I think it’s clear that calling optimised code is neither as trivial as writing a bit of C++, nor the end of the discussion. It requires thinking about computer science and engineering and about numerical methods, which is delightfully different from how we usually think about convolution! Let’s therefore look at some examples.

Goodbye loops, hello JAX/XLA

If you are like me, every for loop in the example above made your stomach acid production increase just a little bit. Contrary to popular belief, CPython is not a slow language: the core developers have done a great job of optimising everything that can be optimised. Between type-checking overhead, the interpreter’s evaluation loop and memory management, however, nested hot loops in CPython are just a plain bad idea. This is where a few possibilities for optimisation come in:

  1. Target the points where CPython is slow by implementing your code in a different language, more suited to the occasion. This could be Cython for example, or C++. numpy actually does exactly that, with a powerful backend calling fast subroutines implemented in C or Fortran and more optimisations we’ll talk about below. Considering this, numpy is essentially just a Python binding for fast libraries implemented in statically typed languages and optimised implementations.
  2. Try to parallelise or eliminate as many loops as possible. In our example above, the batch dimension can trivially be processed in parallel, and we already avoided explicit loops over the kernel window and the input channels by using a single vectorised multiply-and-sum (otherwise we’d have needed yet more loops!).
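To make item 2 concrete, here is a minimal sketch (my own, not from Cottonwood or the code above) that removes the loops over batch, output pixels and channels entirely and keeps only a loop over the kh × kw kernel offsets. For each offset, one slice of the input lines up every output pixel with the input pixel it multiplies by that kernel entry, and np.einsum contracts the input channels for the whole batch at once. Both function names are hypothetical:

```python
import numpy as np

def conv2d_loops(image, kernel):
    """Reference: valid cross-correlation with explicit loops (NHWC input, HWIO kernel)."""
    n, h, w, c_in = image.shape
    kh, kw, _, c_out = kernel.shape
    out = np.zeros((n, h - kh + 1, w - kw + 1, c_out))
    for b in range(n):
        for i in range(h - kh + 1):
            for j in range(w - kw + 1):
                for o in range(c_out):
                    out[b, i, j, o] = np.sum(image[b, i:i + kh, j:j + kw, :] * kernel[:, :, :, o])
    return out

def conv2d_shift_and_add(image, kernel):
    """Same result, looping only over the kh * kw kernel offsets (p, q)."""
    n, h, w, c_in = image.shape
    kh, kw, _, c_out = kernel.shape
    out_h, out_w = h - kh + 1, w - kw + 1
    out = np.zeros((n, out_h, out_w, c_out))
    for p in range(kh):
        for q in range(kw):
            window = image[:, p:p + out_h, q:q + out_w, :]           # (n, out_h, out_w, c_in)
            out += np.einsum("nhwc,co->nhwo", window, kernel[p, q])  # kernel[p, q]: (c_in, c_out)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 7, 6, 3))
k = rng.standard_normal((3, 2, 3, 4))  # deliberately non-square kernel
print(np.allclose(conv2d_loops(x, k), conv2d_shift_and_add(x, k)))  # True
```

For a 3x3 kernel this leaves 9 Python-level iterations regardless of batch or image size, with NumPy doing the heavy lifting inside each one.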

We are already using numpy in our implementation, and we chose to JIT-compile away the overhead of the for loops with Numba. This is how we can take the pure CPython implementation (which takes around 13 seconds) down to about 600 ms. Considering that all we had to do was add an @njit decorator to our code, that’s a lot of speed for not a lot of effort. Still, we are ~100 times slower than TensorFlow! Numba is also a fairly generic JIT, focused on providing broad support for a range of numpy operations. I love using it, as it provides speed-ups comparable to other very fast languages such as C and Julia, while C is more difficult to write and more dangerous (memory management!), and Julia doesn’t (yet) enjoy broad support in the ML community. However, Numba doesn’t provide support for automatic differentiation and is not focused specifically on ML. This is where projects such as JAX come in.

My background is scientific computing, and when I started doing research, the deep learning boom hadn’t started yet (that is, before 2010…). The crowning glory of my doctoral years was writing a custom automatic differentiation engine (which was never used or published) in Object Pascal to implement what is now known as Hamiltonian Monte Carlo from scratch. Had I had a tool such as JAX back then (and GPUs), the trauma would not have been as severe.

Besides my love for JAX (of which I make no secret, and if the JAX team ever needs a radiologist who does ML research, you know where to find me!) for its marriage of the numpy API with functional programming and automatic differentiation, it provides tools perfectly suited for what we are trying to do here, among them JIT compilation and vectorisation. XLA, the compiler JAX sits atop, is a great example of a modern, open source, “domain-specific” compiler. Its focus on ML means that it smartly fuses operations at the array level and knows about memory/speed trade-offs on different platforms, while making use of all the usual JIT goodness (compilation caches, AOT compilation on demand etc.). The increasing focus on XLA means that TensorFlow will use it more and more, so we will hopefully get to use XLA tracing magic on tf.functions soon, and our friends using PyTorch, Julia etc. can also profit!

The backend isn’t everything

Now that we’ve surveyed the low-level, technical considerations and found out that the future is multi-frontend, multi-backend, multi-device and Pythonic (or Julian, Swifty, hopefully Rusty, or whichever other syntactic LLVM sugar du jour you like), let’s think some more about the numerical methods part of the equation (no pun intended). Let’s assume we want to get rid of the loops entirely. We have a few beautiful mathematical ideas at our disposal, which I’ll showcase below.

Convolution as elementwise multiplication in the frequency domain

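Let’s start with the most elegant one. By the convolution theorem, convolution can be interpreted as elementwise multiplication of the input and the kernel in the frequency domain. This implementation replaces the nested loops, whose cost grows with the product of image size and kernel size, with two Fourier transforms and one inverse Fourier transform. Computed naively, a discrete Fourier transform is quadratic in the number of elements, but it is usually implemented as a Fast Fourier Transform (FFT), which is $n \log(n)$. Since this is real convolution, not cross-correlation, we need to flip the kernel to get the same result as the TensorFlow implementation. The output also has the same size as the input instead of shrinking as in the valid convolution above. However, because the FFT treats the image as periodic, this is a circular convolution: for a $k_h \times k_w$ kernel, the first $k_h - 1$ rows and $k_w - 1$ columns contain wrap-around terms from the opposite edges of the image, while the remaining entries match the valid result.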

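def fftconvolve(image, kernel):
    # Flip the kernel so that the result matches cross-correlation (what TensorFlow computes)
    kernel = np.flipud(np.fliplr(kernel))
    # Zero-pad the kernel to the image size, multiply the spectra and transform back
    return np.real(np.fft.ifft2(np.fft.fft2(image)*np.fft.fft2(kernel, s=image.shape)))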
fftresult = fftconvolve(im[...,0], kernel)
show(fftresult)

%timeit fftconvolve(im[...,0], kernel)
738 µs ± 5.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

738 µs!

This is what I call speed (although this is for a single channel of a single image, not for the full tensor)! I love how this superfast algorithm is also one of the simplest ones, conceptually and programmatically, and it goes to show how beautiful the connection between Deep Learning and Signal Processing can be. This method is actually implemented in scipy.signal.correlate (method="fft"), and there is a paper describing the method. GPU libraries include FFT subroutines (cuFFT), and there is an implementation of FFT-based convolutions in cuDNN and, by extension, TensorFlow. However, it requires a lot of memory, since the kernel has to be zero-padded to the size of the image to implement the multiplication in one go. I’m not sure if it’s currently used in practice.
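One detail of the FFT approach is worth checking numerically: multiplying spectra computes a circular convolution, so the image is treated as if it wrapped around at its edges. The sketch below (NumPy only; xcorr_valid is a helper written for this check) reuses the fftconvolve function from the text and verifies two things for a 3x3 kernel: dropping the first two rows and columns recovers the valid cross-correlation, and the full output equals a valid cross-correlation of an image that has been wrap-padded on the top and left:

```python
import numpy as np

def fftconvolve(image, kernel):
    # As in the text: flip the kernel, multiply the spectra, transform back.
    kernel = np.flipud(np.fliplr(kernel))
    return np.real(np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(kernel, s=image.shape)))

def xcorr_valid(image, kernel):
    """Direct valid cross-correlation of a 2D image with a 2D kernel."""
    kh, kw = kernel.shape
    oh, ow = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for p in range(kh):
        for q in range(kw):
            out += kernel[p, q] * image[p:p + oh, q:q + ow]
    return out

rng = np.random.default_rng(0)
img = rng.standard_normal((12, 10))
k = rng.standard_normal((3, 3))

circ = fftconvolve(img, k)
# Away from the top/left border, the FFT result is exactly the valid cross-correlation.
print(np.allclose(circ[2:, 2:], xcorr_valid(img, k)))  # True
# The border entries pick up pixels from the opposite edges (periodic boundary).
wrapped = np.pad(img, ((2, 0), (2, 0)), mode="wrap")
print(np.allclose(circ, xcorr_valid(wrapped, k)))       # True
```

If wrap-around is unwanted, the usual remedy is to zero-pad both arrays to at least (H + k_h − 1) × (W + k_w − 1) before the transforms, at the cost of even more memory.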

Convolution as vanilla matrix multiplication

Since convolution is a linear operation, the idea of interpreting it as a matrix multiplication is not new. It turns out that it is possible to construct a matrix (a doubly blocked Toeplitz matrix, in fact; for circular convolution it would be a doubly blocked circulant matrix, a special case of a Toeplitz matrix) that replicates convolution in a single step.

def convmatrix(kernel, input_shape): #via https://stackoverflow.com/questions/56702873/i
    k_h, k_w = kernel.shape
    i_h, i_w = input_shape
    o_h, o_w = i_h-k_h+1, i_w-k_w+1
    toeplitz = []
    for r in range(k_h):
        toeplitz.append(linalg.toeplitz(c=(kernel[r,0], *np.zeros(i_w-k_w)), r=(*kernel[r], *np.zeros(i_w-k_w))) ) 
    h_blocks, w_blocks = o_h, i_h
    h_block, w_block = toeplitz[0].shape
    W_conv = np.zeros((h_blocks, h_block, w_blocks, w_block))
    for i, B in enumerate(toeplitz):
        for j in range(o_h):
            W_conv[j, :, i+j, :] = B
    W_conv.shape = (h_blocks*h_block, w_blocks*w_block)
    return W_conv
show(np.dot(convmatrix(kernel, input_shape=(150,150)), im[...,0].flatten()).reshape((148,148)))

This technique is both educational and problematic. Let’s have a look at the doubly blocked Toeplitz matrix:

cm = convmatrix(kernel, input_shape=(150,150))
cm
array([[ 1.,  0., -1., ...,  0.,  0.,  0.],
       [ 0.,  1.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  1., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  0.,  0., ...,  1.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  1.,  0.],
       [ 0.,  0.,  0., ..., -1.,  0.,  1.]])
cm.shape
(21904, 22500)
print(f"Sparsity = {(cm==0).sum()/np.prod(cm.shape):.4f}")
Sparsity = 0.9998

This is a very large and very sparse matrix. Yet, we can fully implement our convolution with it. Thanks to its Toeplitz structure (each block row is a shifted copy of the one above it), it takes over the “sliding window” function! Its sparsity also reminds us why CNNs are so parameter efficient: they only need a few parameters! It’s also very problematic, however. Even with a sparse matrix subroutine such as cuSPARSE, allocating and moving around such an enormous matrix and multiplying by it just doesn’t make much sense. There is, however, a spiritual successor to this technique, which was actually used in production (I think it was in Caffe first, correct me if I am wrong).
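To put “very large” into numbers, here is a back-of-the-envelope estimate of the memory footprint of cm stored densely as float64 (what np.zeros allocates by default) versus a compressed sparse row (CSR) layout. The CSR figures assume float64 values, int32 column indices and row pointers, and count every kernel position as a stored entry (our edge kernel actually has only 4 non-zeros per row); these are assumptions for illustration, and real sparse libraries add some overhead:

```python
# Shapes from the text: a 150x150 image, a 3x3 kernel, valid output of 148x148.
i_h = i_w = 150
k_h = k_w = 3
o_h, o_w = i_h - k_h + 1, i_w - k_w + 1

rows, cols = o_h * o_w, i_h * i_w      # one row per output pixel, one column per input pixel
dense_bytes = rows * cols * 8          # float64 entries

# CSR-style storage (assumed): value (8 bytes) + column index (4 bytes) per stored entry,
# plus one int32 row pointer per row and one extra at the end.
nnz = rows * k_h * k_w
csr_bytes = nnz * (8 + 4) + (rows + 1) * 4

print(rows, cols)                        # 21904 22500
print(f"dense: {dense_bytes / 1e9:.2f} GB")  # dense: 3.94 GB
print(f"CSR:   {csr_bytes / 1e6:.2f} MB")    # CSR:   2.45 MB
```

Even the sparse version is far larger than the 9 numbers that actually define the kernel, which is the motivation for the next trick.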

Im2Col or the Caffe Trick

What if, instead of actually creating these huge matrices out of the kernel, we rearrange the image in such a way that we can perform a dense multiplication of the rearranged image matrix with a kernel vector? The trick here is to think about the “sliding window” the other way around! Instead of sliding the kernel, we rearrange the pixels (imagine Picasso meets Tetris¹). Since no arithmetic is required, we can do this without copying or modifying the underlying memory, just by reading the right parts of it: np.lib.stride_tricks.as_strided returns a view whose strides make every window look like a separate slice of the array.

def im2col_conv2d(image, kernel):
    batch_size, x_out, y_out, in_channels = image.shape
    kx, ky, _, out_channels = kernel.shape

    im_col = np.lib.stride_tricks.as_strided(
             image, (batch_size, x_out-kx+1, y_out-ky+1, kx, ky, in_channels), 
             image.strides[:3] + image.strides[1:], writeable=False)

    return np.tensordot(im_col, kernel, axes=3)
im2col_conv = im2col_conv2d(input_matrix, input_kernel)
im2col_conv.shape
(8, 148, 148, 16)
np.testing.assert_allclose(im2col_conv, tf.nn.conv2d(input_matrix, input_kernel, padding="VALID", strides=(1,1)))
%timeit im2col_conv2d(input_matrix, input_kernel)
6.7 ms ± 23.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

BAM!

We did it! We are nearly as fast as TensorFlow on the entire operation, and the implementation is almost eye-watering in its elegance. We can be proud of our puny little CPython and numpy implementation! Let’s have a look at this mysterious intermediate im_col object:

def inspect_im_col(image, kernel):
    batch_size, x_out, y_out, in_channels = image.shape
    kx, ky, _, out_channels = kernel.shape

    im_col = np.lib.stride_tricks.as_strided(
             image, (batch_size, x_out-kx+1, y_out-ky+1, kx, ky, in_channels), 
             image.strides[:3] + image.strides[1:], writeable=False)

    return im_col
mysterious_object = inspect_im_col(input_matrix, input_kernel)
mysterious_object.shape
(8, 148, 148, 3, 3, 3)

We see a little issue here, which is that this intermediate array contains a lot of replication:

show(mysterious_object[0, ..., 0,0,0])

show(mysterious_object[0, ..., 1,0,0])

show(mysterious_object[0, ..., 2,0,0])

These arrays are nearly identical, yet once the view is materialised as a matrix (np.tensordot has to copy it when reshaping it), each takes up memory. In fact, our original batch of 540,000 32-bit floats grows to 4,731,264 32-bit floats, which is nearly 9 times larger. If we use a larger kernel, say 7x7, the size becomes 24,385,536, which is around 45 times larger. Because the blow-up factor grows with the number of kernel elements (quadratically in the kernel side length), the technique becomes rather memory-inefficient for large kernel sizes.
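The element counts for the full 8-image batch follow from simple arithmetic: every output pixel gets its own copy of a kh × kw × in_channels patch. The helper im2col_sizes below is hypothetical, written just for this check, and counts values rather than bytes (multiply by 4 for float32):

```python
def im2col_sizes(batch, h, w, c, kh, kw):
    """Number of values in the input and in the materialised im2col matrix (valid convolution)."""
    out_h, out_w = h - kh + 1, w - kw + 1
    n_input = batch * h * w * c
    n_cols = batch * out_h * out_w * kh * kw * c
    return n_input, n_cols

for k in (3, 7):
    n_in, n_col = im2col_sizes(8, 150, 150, 3, k, k)
    print(f"{k}x{k}: {n_in:,} -> {n_col:,} values ({n_col / n_in:.1f}x)")
# 3x3: 540,000 -> 4,731,264 values (8.8x)
# 7x7: 540,000 -> 24,385,536 values (45.2x)
```

The ratio is roughly kh × kw, slightly reduced because the valid output is smaller than the input.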

Still, Im2Col is used by TensorFlow in two kernels in the CPU library and in quantized_conv_ops.
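As an aside, NumPy 1.20 and later ship numpy.lib.stride_tricks.sliding_window_view, which builds the same kind of read-only strided view without hand-computing strides (as_strided will happily read out of bounds if the strides are wrong). Here is a minimal sketch of the same im2col convolution using it, with np.einsum doing the contraction. The name im2col_conv2d_swv is hypothetical, and I have not benchmarked it against the as_strided version:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def im2col_conv2d_swv(image, kernel):
    """Valid cross-correlation of an NHWC batch with a (kh, kw, in_channels, out_channels) kernel."""
    kh, kw, _, _ = kernel.shape
    # Read-only view of shape (batch, out_h, out_w, in_channels, kh, kw):
    # the window axes are appended after the channel axis.
    windows = sliding_window_view(image, (kh, kw), axis=(1, 2))
    # Contract over input channels (c) and both window axes (i, j).
    return np.einsum("nhwcij,ijco->nhwo", windows, kernel)

# Check against a shift-and-add reference on a small random batch.
rng = np.random.default_rng(1)
x = rng.standard_normal((2, 6, 7, 3))
k = rng.standard_normal((3, 3, 3, 5))
ref = np.zeros((2, 4, 5, 5))
for p in range(3):
    for q in range(3):
        ref += np.einsum("nhwc,co->nhwo", x[:, p:p + 4, q:q + 5, :], k[p, q])
print(np.allclose(im2col_conv2d_swv(x, k), ref))  # True
```

Note the axis order differs from the as_strided version in the text (channels come before the window axes here), which is why the einsum subscripts read "nhwcij" rather than "nhwijc".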

Winograd: The Final Frontier

Although we (nearly) matched TensorFlow’s implementation in speed, I’d like to discuss one more algorithm, which I personally find awesome for many reasons:

The broad idea of the algorithm, which was introduced in this paper, is that convolutions can be sped up by replacing float multiplications with additions and reusing shared intermediate results. Since float additions are typically cheaper than float multiplications, speed gains can be realised (the paper reports a 4x complexity reduction in the ideal case). Concretely, whereas computing a $k \times k$ output tile with a $w \times w$ filter directly requires $(w \times k)^2$ multiplications, a fast algorithm only requires $(w+k-1)^2$ multiplications. It also turns out that for a kernel $g$ and input patch $d$ of a pre-defined size, specific matrices $A$, $B$ and $G$ can be pre-computed, such that the output of 2D convolution becomes:

$$ Y = A^T \bigg[ [G g G^T] \odot [B^T d B] \bigg] A $$

Thus, by pre-computing and storing these matrices (which are the same for every tile, given the tile and kernel sizes), tiling the input appropriately and processing the tiles in parallel, even higher computational gains can be realised.
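Here is a minimal NumPy sketch of the smallest useful 2D case, $F(2 \times 2, 3 \times 3)$: 2x2 output tiles from 4x4 input patches and a 3x3 kernel, following the formula above with $\odot$ as elementwise multiplication. The transform matrices are the standard ones from Lavin and Gray’s paper “Fast Algorithms for Convolutional Neural Networks” (I am assuming that is close to what the paper linked above describes). The function names are hypothetical, it handles a single channel only, and it assumes even output dimensions to keep it short:

```python
import numpy as np

# F(2x2, 3x3) transform matrices (as given by Lavin and Gray).
B_T = np.array([[1.0,  0.0, -1.0,  0.0],
                [0.0,  1.0,  1.0,  0.0],
                [0.0, -1.0,  1.0,  0.0],
                [0.0,  1.0,  0.0, -1.0]])
G = np.array([[1.0,  0.0, 0.0],
              [0.5,  0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0.0,  0.0, 1.0]])
A_T = np.array([[1.0, 1.0,  1.0,  0.0],
                [0.0, 1.0, -1.0, -1.0]])

def transform_kernel(g):
    """G g G^T: a 4x4 matrix, computed once per 3x3 kernel."""
    return G @ g @ G.T

def winograd_tile(d, U):
    """2x2 output of the valid cross-correlation of a 4x4 patch d with the kernel behind U."""
    V = B_T @ d @ B_T.T           # B^T d B (only additions and subtractions)
    return A_T @ (U * V) @ A_T.T  # 16 elementwise products instead of 36

def winograd_xcorr2d(image, g):
    """Single-channel valid cross-correlation using overlapping 4x4 tiles with stride 2."""
    out_h, out_w = image.shape[0] - 2, image.shape[1] - 2
    assert out_h % 2 == 0 and out_w % 2 == 0, "sketch assumes even output size"
    U = transform_kernel(g)
    out = np.empty((out_h, out_w))
    for i in range(0, out_h, 2):
        for j in range(0, out_w, 2):
            out[i:i + 2, j:j + 2] = winograd_tile(image[i:i + 4, j:j + 4], U)
    return out

def direct_xcorr2d(image, g):
    """Reference: direct valid cross-correlation with a 3x3 kernel."""
    out_h, out_w = image.shape[0] - 2, image.shape[1] - 2
    out = np.zeros((out_h, out_w))
    for p in range(3):
        for q in range(3):
            out += g[p, q] * image[p:p + out_h, q:q + out_w]
    return out

rng = np.random.default_rng(0)
img = rng.standard_normal((10, 12))
k = rng.standard_normal((3, 3))
print(np.allclose(winograd_xcorr2d(img, k), direct_xcorr2d(img, k)))  # True
```

With $w = 3$ and $k = 2$, each tile needs $(3+2-1)^2 = 16$ multiplications instead of $(3 \times 2)^2 = 36$, a 2.25x reduction; with 4x4 output tiles the same counting gives $36$ versus $144$, which is the 4x figure mentioned above.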

Summary

I hope this short post took you on a fun journey and stimulated you to find out more about how the libraries we use do some of the things they do under the hood. My huge respect to Brandon Rohrer for his great educational content, Matthew Johnson and the JAX team, the TF folks, and everyone who is using their intellect to advance our field! Modern Deep Learning is indeed a wonderful amalgam of great ideas, be they old or new!

Supplementary reading

  • If you are a beginner or would like to brush up on your convolution math skills, read this paper!
  • For a fun comparison of different convolution methods and a few notes about prefetching, check this post
  • For a great discussion of FFTs, a post by the legendary JakeVDP of (now) JAX fame can be found here

  1. Read nice explanations here and here.

  2. Shmuel Winograd. Arithmetic complexity of computations, volume 33. SIAM, 1980; and Shmuel Winograd. On multiplication of polynomials modulo a polynomial. SIAM Journal on Computing, 9(2):225–229, 1980.

Georgios Kaissis
Specialist Diagnostic Radiologist - Senior Research Scientist

My research interests include AI-based medical image analysis, probabilistic methods and privacy-preserving AI
