Convoluted Stuff - Volume 2

JAX strikes back, more puppies and a deeper look into Winograd

Disclaimer: I have a day job. If I’ve made a mistake, please either let me know on Twitter in a civil way and I’ll fix it or ask your employer to hire me to do this for money.

Part 1

In which we benchmark our ideas against jax.lax.conv

In the first volume of this series, we found that our numpy im2col implementation can nearly match the TensorFlow implementation in speed. Soon after, however, Matthew Johnson of JAX fame commented the following:

Since we are trying a different implementation, we need a different puppy as well, but we’ll keep the kernel and everything else identical to be fair.

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from functools import partial
show = partial(plt.imshow, cmap="gray")
import jax
import scipy.signal
# load the puppy, resize it to 150x150 and scale the pixel values to [0, 1]
im = (np.array(Image.open("wolf.jpg").resize((150,150)))/255.).astype(np.float32)
show(im)
plt.suptitle("W O W S O Q T")

kernel = np.array([[1,0,-1],[0,0,0],[-1,0,1]], dtype=np.float32)

print(im.shape, kernel.shape)
(150, 150, 3) (3, 3)
# a batch of 8 identical images (NHWC) and a kernel with 3 input and 16 output channels (HWIO)
input_matrix, input_kernel = np.tile(im, (8, 1, 1, 1)), np.tile(kernel[...,None, None], (1,1,3,16))
print(input_matrix.shape, input_kernel.shape)
(8, 150, 150, 3) (3, 3, 3, 16)
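The tiling can look a bit magical. np.tile simply stacks copies. This gives a batch of 8 identical images in NHWC layout. It also gives a kernel tensor in HWIO layout, where every (input channel, output channel) slice is our original $3 \times 3$ kernel. Here is a small self-contained check of that. A random array stands in for the puppy (an assumption, so the sketch runs without wolf.jpg):

```python
import numpy as np

# random stand-in for the 150x150 RGB puppy (placeholder, not the real image)
rng = np.random.default_rng(0)
im = rng.random((150, 150, 3), dtype=np.float32)
kernel = np.array([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], dtype=np.float32)

# NHWC: a batch of 8 identical images
input_matrix = np.tile(im, (8, 1, 1, 1))
# HWIO: add input/output channel axes, then repeat for 3 input x 16 output channels
input_kernel = np.tile(kernel[..., None, None], (1, 1, 3, 16))

print(input_matrix.shape, input_kernel.shape)  # (8, 150, 150, 3) (3, 3, 3, 16)

# every spatial (H, W) slice of the kernel tensor is the original kernel
all_equal = all(
    np.array_equal(input_kernel[:, :, c_in, c_out], kernel)
    for c_in in range(3)
    for c_out in range(16)
)
print(all_equal)  # True
```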

We’ll reuse our im2col implementation from last time and benchmark it against lax.conv_general_dilated.

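def im2col_conv2d(image, kernel):
    # image: (batch, height, width, in_channels); kernel: (kx, ky, in_channels, out_channels)
    batch_size, height, width, in_channels = image.shape
    kx, ky, _, out_channels = kernel.shape

    # read-only strided view holding, for every output position, the kx x ky x in_channels
    # patch underneath it; no data is copied
    im_col = np.lib.stride_tricks.as_strided(
             image, (batch_size, height-kx+1, width-ky+1, kx, ky, in_channels), 
             image.strides[:3] + image.strides[1:], writeable=False)

    # contract the (kx, ky, in_channels) axes of the patches with those of the kernel
    return np.tensordot(im_col, kernel, axes=3)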
im2col_conv = im2col_conv2d(input_matrix, input_kernel)
np.testing.assert_allclose(im2col_conv, jax.lax.conv_general_dilated(input_matrix, input_kernel, padding="VALID", dimension_numbers=("NHWC", "HWIO", "NHWC"), window_strides=(1,1)))
%timeit im2col_conv2d(input_matrix, input_kernel)
5.85 ms ± 35.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax.lax.conv_general_dilated(input_matrix, input_kernel, padding="VALID", dimension_numbers=("NHWC", "HWIO", "NHWC"), window_strides=(1,1))
2.71 ms ± 55.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Carefully selected and autotuned cuDNN kernels everyone… Can’t beat that.
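One caveat when reading these numbers: JAX dispatches work asynchronously. A JAX call can therefore return before the convolution has actually finished. With numpy arguments, every call also pays for copying the inputs to the device. A fairer benchmark does three things:

- it waits on the result with block_until_ready();
- it runs a warm-up call first, so one-off compilation is excluded;
- ideally, it applies jax.jit to inputs already moved with jax.device_put.

Below is a sketch of such a timing helper. The name benchmark is our own hypothetical choice, and the JAX part only runs if JAX is installed:

```python
import time

import numpy as np


def benchmark(fn, *args, warmup=1, repeats=20):
    """Mean wall-clock seconds per call of fn(*args), waiting for asynchronous results."""

    def run():
        out = fn(*args)
        # JAX returns arrays before the computation has finished; wait for it
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()

    for _ in range(warmup):  # keeps one-off costs such as JIT compilation out of the timing
        run()
    start = time.perf_counter()
    for _ in range(repeats):
        run()
    return (time.perf_counter() - start) / repeats


a = np.random.default_rng(0).random((256, 256), dtype=np.float32)
numpy_time = benchmark(np.dot, a, a)
print(f"numpy matmul: {numpy_time * 1e6:.1f} µs per call")

try:
    import jax

    @jax.jit
    def conv(x, w):
        return jax.lax.conv_general_dilated(
            x, w, window_strides=(1, 1), padding="VALID",
            dimension_numbers=("NHWC", "HWIO", "NHWC"))

    x = jax.device_put(np.ones((8, 150, 150, 3), np.float32))  # transfer once, outside the timed loop
    w = jax.device_put(np.ones((3, 3, 3, 16), np.float32))
    print(f"jitted lax conv: {benchmark(conv, x, w) * 1e6:.1f} µs per call")
except ImportError:
    pass  # JAX not installed: only the numpy timing is shown
```

The helper only checks for a block_until_ready method. This means it works unchanged for plain numpy functions, whose results are already complete when they return.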

Also, I received this awesome comment about im2col which we could use to make the code more efficient. Thanks to Vladimir Ilievski!
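To see what the strided view in im2col_conv2d actually does, we can compare it with a deliberately naive loop over output positions. The loop computes the same "valid" cross-correlation. The sketch below repeats the im2col function so it runs on its own. naive_conv2d is a hypothetical helper for illustration, and the inputs are small random arrays:

```python
import numpy as np


def im2col_conv2d(image, kernel):
    # image: (batch, height, width, in_channels); kernel: (kx, ky, in_channels, out_channels)
    batch_size, height, width, in_channels = image.shape
    kx, ky, _, out_channels = kernel.shape
    # strided view: one kx x ky x in_channels patch per output position, no copy
    im_col = np.lib.stride_tricks.as_strided(
        image, (batch_size, height - kx + 1, width - ky + 1, kx, ky, in_channels),
        image.strides[:3] + image.strides[1:], writeable=False)
    return np.tensordot(im_col, kernel, axes=3)


def naive_conv2d(image, kernel):
    # explicit loops over output positions: slow, but easy to verify by eye
    batch_size, height, width, _ = image.shape
    kx, ky, _, out_channels = kernel.shape
    out = np.zeros((batch_size, height - kx + 1, width - ky + 1, out_channels), dtype=image.dtype)
    for i in range(height - kx + 1):
        for j in range(width - ky + 1):
            patch = image[:, i:i + kx, j:j + ky, :]  # (batch, kx, ky, in_channels)
            out[:, i, j, :] = np.tensordot(patch, kernel, axes=3)
    return out


rng = np.random.default_rng(0)
image = rng.random((2, 10, 12, 3), dtype=np.float32)
weights = rng.random((3, 3, 3, 4), dtype=np.float32)

fast = im2col_conv2d(image, weights)
slow = naive_conv2d(image, weights)
print(fast.shape)               # (2, 8, 10, 4)
print(np.allclose(fast, slow))  # True
```

Both versions perform exactly the same multiplications. The strided view just lets numpy do all of them in a single tensordot call instead of one per output position.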

Part 2

Wino$\nabla$ returns

I received a lot of comments expressing some surprise about the Winograd implementation and asking for a demo of how the matrix multiplications work in the 2D case. So here’s a little more detail about the implementation:

Let’s remind ourselves of the premise of the Winograd convolution. Compared to a general matrix multiplication (GEMM) based convolution, it increases computational efficiency by reducing the number of floating-point multiplications (typically expensive) at the cost of extra additions (less expensive). For the tile size used below, this yields a 4x reduction in multiplications. To do this, it uses pre-computed transform matrices ($A$, $B$ and $G$) which depend only on the sizes of the input tile and the kernel. They are therefore (again) identical for every tile of a given size. The output of the convolution of an input tile $d$ with a kernel $g$ then becomes:

$$ Y = A^T \bigg[ [G g G^T] \odot [B^T d B] \bigg] A $$
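To put a number on "fewer multiplications", take an $m \times m$ output tile and an $r \times r$ kernel. Direct convolution needs $m^2 r^2$ multiplications. The Winograd variant $F(m \times m, r \times r)$ needs one per element of the $(m + r - 1) \times (m + r - 1)$ element-wise product. The small sketch below counts only those general multiplications and ignores the transforms. The helper name multiplications is ours:

```python
def multiplications(m, r):
    """Multiplications needed for one m x m output tile with an r x r kernel: (direct, Winograd)."""
    direct = (m * r) ** 2        # m*m outputs, each a sum of r*r products
    winograd = (m + r - 1) ** 2  # one product per element of the (m+r-1) x (m+r-1) Hadamard product
    return direct, winograd


for m in (2, 4):
    direct, wino = multiplications(m, 3)
    print(f"F({m}x{m}, 3x3): {direct} vs {wino} -> {direct / wino:.2f}x fewer")
# F(2x2, 3x3): 36 vs 16 -> 2.25x fewer
# F(4x4, 3x3): 144 vs 36 -> 4.00x fewer
```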

Let’s try it out. The mathematical derivation of matrices $A$, $B$ and $G$ for various tile and kernel sizes can be found in the appendix of the Winograd convolution paper. A tool for computing them for arbitrary sizes can be found in this repository by the authors.

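Our input is of size $150 \times 150$, so a "valid" convolution with a $3 \times 3$ kernel produces a $148 \times 148$ output, which is neatly divisible by $4$. This makes $6 \times 6$ input patches ideal, since they allow us to utilise the $F(4 \times 4, 3 \times 3)$ algorithm ($F(4,3)$ along each dimension). Each $6 \times 6$ patch yields a $4 \times 4$ output tile, and neighbouring patches overlap by $2$ pixels. The $36$ element-wise multiplications replace the $4 \cdot 4 \cdot 9 = 144$ multiplications of a direct convolution, a 4x reduction in multiplications.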
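The tile bookkeeping is easy to get wrong. Here is a quick sketch of where the overlapping $6 \times 6$ input patches start for a $150 \times 150$ image when each patch produces a $4 \times 4$ output tile. It is our own arithmetic and assumes a "valid" convolution with a $3 \times 3$ kernel:

```python
n, r, m = 150, 3, 4               # image size, kernel size, output tile size
tile = m + r - 1                  # input patch size: 6
out = n - r + 1                   # "valid" output size per dimension: 148
starts = list(range(0, out, m))   # patch origins advance by m, so patches overlap by r - 1

print(tile, out, len(starts))     # 6 148 37
print(starts[:3], starts[-1])     # [0, 4, 8] 144
print(starts[-1] + tile == n)     # True: the last patch ends exactly at the image border
```

So the image is covered by $37 \times 37$ overlapping patches rather than by $25 \times 25$ disjoint ones.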

The matrices describing the operation then are as follows:

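A = np.array([[1,0,0,0],
              [1,1,1,1],
              [1,-1,1,-1],
              [1,2,4,8],
              [1,-2,4,-8],
              [0,0,0,1],
             ], dtype=np.float32)

B = np.array([[4,0,0,0,0,0],
              [0,-4,4,-2,2,4],
              [-5,-4,-4,-1,-1,0],
              [0,1,-1,2,-2,-5],
              [1,1,1,1,1,0],
              [0,0,0,0,0,1],
             ], dtype=np.float32)

G = np.array([[1/4, 0, 0],
              [-1/6, -1/6, -1/6],
              [-1/6, 1/6, -1/6],
              [1/24, 1/12, 1/6],
              [1/24, -1/12, 1/6],
              [0, 0, 1],
             ], dtype=np.float32)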

print(f"A={A} \n\n B={B} \n\n G={G}")
A=[[ 1.  0.  0.  0.]
 [ 1.  1.  1.  1.]
 [ 1. -1.  1. -1.]
 [ 1.  2.  4.  8.]
 [ 1. -2.  4. -8.]
 [ 0.  0.  0.  1.]] 

 B=[[ 4.  0.  0.  0.  0.  0.]
 [ 0. -4.  4. -2.  2.  4.]
 [-5. -4. -4. -1. -1.  0.]
 [ 0.  1. -1.  2. -2. -5.]
 [ 1.  1.  1.  1.  1.  0.]
 [ 0.  0.  0.  0.  0.  1.]] 

 G=[[ 0.25        0.          0.        ]
 [-0.16666667 -0.16666667 -0.16666667]
 [-0.16666667  0.16666667 -0.16666667]
 [ 0.04166667  0.08333334  0.16666667]
 [ 0.04166667 -0.08333334  0.16666667]
 [ 0.          0.          1.        ]]
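The 2D formula is just the 1D algorithm $F(4,3)$ applied along both axes. As a sanity check on the matrices above, here is a sketch of the 1D version, $y = A^T \big[ (G g) \odot (B^T d) \big]$. It runs on a random length-6 signal and is compared against numpy’s "valid" cross-correlation (np.correlate). It works in float64 so that rounding does not get in the way:

```python
import numpy as np

A = np.array([[1, 0, 0, 0], [1, 1, 1, 1], [1, -1, 1, -1],
              [1, 2, 4, 8], [1, -2, 4, -8], [0, 0, 0, 1]], dtype=np.float64)
B = np.array([[4, 0, 0, 0, 0, 0], [0, -4, 4, -2, 2, 4], [-5, -4, -4, -1, -1, 0],
              [0, 1, -1, 2, -2, -5], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]], dtype=np.float64)
G = np.array([[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6],
              [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]], dtype=np.float64)

rng = np.random.default_rng(0)
d = rng.random(6)   # one 6-element input tile
g = rng.random(3)   # a 3-tap kernel

y_winograd = A.T @ ((G @ g) * (B.T @ d))       # 6 multiplications in the element-wise product
y_direct = np.correlate(d, g, mode="valid")    # y[i] = sum_j d[i + j] * g[j]: 4 * 3 = 12 multiplications

print(np.allclose(y_winograd, y_direct))  # True
```

Note that this is a cross-correlation (no kernel flip), which is also what lax.conv computes.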

If we apply the convolution to a patch of the image, the primitive operation is:

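X = input_matrix[0, :6, :6, 0]  # 6x6 patch: first image, first channel
k = input_kernel[0,...,0]  # 3x3 slice over (ky, in_channels): [[1,1,1],[0,0,0],[-1,-1,-1]], not kernel itself
winograd_conv = A.T @ (np.multiply(G@k@G.T, B.T@X@B)) @ A  # 4x4 output tile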

Let’s now check our implementation against the lax.conv implementation above:

np.allclose(winograd_conv, jax.lax.conv_general_dilated(X[None,...,None], k[...,None,None], window_strides=(1,1), dimension_numbers=("NHWC", "HWIO", "NHWC"), padding="VALID").squeeze())

Ooops! What happened here…?

What becomes apparent is that matrix $G$ contains fractions such as $1/6$ and $1/24$. These have no exact binary floating-point representation, so some precision is lost along the way! The authors actually discuss this and find it not to be a problem (in fact, I tend to argue that lower precision in CNNs is a form of regularisation). If we therefore slightly loosen the absolute tolerance (atol) of the np.allclose call, things will be fine:

np.allclose(winograd_conv, jax.lax.conv_general_dilated(X[None,...,None], k[...,None,None], window_strides=(1,1), dimension_numbers=("NHWC", "HWIO", "NHWC"), padding="VALID").squeeze(), atol=1e-7)
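If the culprit really is floating-point rounding rather than a bug in our matrices, then two things should hold. The same single-tile computation carried out in float64 should agree with a direct computation almost perfectly. The float32 version should be off by a small amount.

The sketch below tests that using a random $6 \times 6$ patch and the original $3 \times 3$ kernel. Both are stand-ins rather than the puppy patch above, so the exact error sizes will differ. The float32 error also includes rounding the inputs themselves. winograd_tile and direct_tile are our own helper names:

```python
import numpy as np

# F(4x4, 3x3) transform matrices from above, kept in float64 here
A = np.array([[1, 0, 0, 0], [1, 1, 1, 1], [1, -1, 1, -1],
              [1, 2, 4, 8], [1, -2, 4, -8], [0, 0, 0, 1]], dtype=np.float64)
B = np.array([[4, 0, 0, 0, 0, 0], [0, -4, 4, -2, 2, 4], [-5, -4, -4, -1, -1, 0],
              [0, 1, -1, 2, -2, -5], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]], dtype=np.float64)
G = np.array([[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6],
              [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]], dtype=np.float64)


def winograd_tile(X, k, dtype):
    """One F(4x4, 3x3) tile with every array cast to `dtype`."""
    A_, B_, G_ = (M.astype(dtype) for M in (A, B, G))
    X, k = X.astype(dtype), k.astype(dtype)
    return A_.T @ ((G_ @ k @ G_.T) * (B_.T @ X @ B_)) @ A_


def direct_tile(X, k):
    """Reference 'valid' cross-correlation of a 6x6 tile with a 3x3 kernel, in float64."""
    return np.array([[np.sum(X[i:i + 3, j:j + 3] * k) for j in range(4)] for i in range(4)])


rng = np.random.default_rng(0)
X = rng.random((6, 6))  # random stand-in for a 6x6 image patch
k = np.array([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], dtype=np.float64)
reference = direct_tile(X, k)

err32 = np.abs(winograd_tile(X, k, np.float32) - reference).max()
err64 = np.abs(winograd_tile(X, k, np.float64) - reference).max()
print(f"max abs error, float32: {err32:.1e}, float64: {err64:.1e}")
```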


%timeit A.T @ (np.multiply(G@k@G.T, B.T@X@B)) @ A
6.69 µs ± 55.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit jax.lax.conv_general_dilated(X[None,...,None], k[...,None,None], window_strides=(1,1), dimension_numbers=("NHWC", "HWIO", "NHWC"), padding="VALID")
131 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

And once again, we emerge victorious!

Just kidding, this one is really not fair. For such a tiny input, JAX’s fixed per-call overhead (dispatching the operation and copying the numpy inputs to the device) dominates. It therefore can’t possibly be as efficient as a single pre-selected operation, since JAX is optimised for performing this over a batch of n-dimensional tensors. A fairer comparison is scipy.signal.correlate, which is much closer to what we are doing and also much closer to the speed we are getting:

%timeit scipy.signal.correlate(X, k, mode="valid")
18.9 µs ± 130 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
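Readers asked how the matrix multiplications work in the 2D case. For completeness, here is a sketch of applying them to a whole image:

1. Slide overlapping $6 \times 6$ patches over the image with a stride of 4.
2. Transform each patch.
3. Write the resulting $4 \times 4$ tile into place.

The sketch assumes a single-channel image whose "valid" output size is divisible by 4. This holds for our $150 \times 150$ puppy, which gives $148 \times 148$. A random image is used as a stand-in. winograd_conv2d is a hypothetical helper of ours, and the Python loop is there for readability only. A practical implementation would batch all tiles into a few large matrix multiplications. The result is checked against a direct sliding-window computation:

```python
import numpy as np

# F(4x4, 3x3) transform matrices, as printed earlier
A = np.array([[1, 0, 0, 0], [1, 1, 1, 1], [1, -1, 1, -1],
              [1, 2, 4, 8], [1, -2, 4, -8], [0, 0, 0, 1]], dtype=np.float32)
B = np.array([[4, 0, 0, 0, 0, 0], [0, -4, 4, -2, 2, 4], [-5, -4, -4, -1, -1, 0],
              [0, 1, -1, 2, -2, -5], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]], dtype=np.float32)
G = np.array([[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6],
              [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]], dtype=np.float32)


def winograd_conv2d(image, k):
    """'Valid' 3x3 cross-correlation of a 2D image using F(4x4, 3x3) tiles."""
    out_h, out_w = image.shape[0] - 2, image.shape[1] - 2
    if out_h % 4 or out_w % 4:
        raise ValueError("this sketch needs an output size divisible by 4")
    U = G @ k @ G.T                               # kernel transform, done once for all tiles
    out = np.empty((out_h, out_w), dtype=image.dtype)
    for i in range(0, out_h, 4):
        for j in range(0, out_w, 4):
            d = image[i:i + 6, j:j + 6]           # 6x6 patch; neighbouring patches overlap by 2
            out[i:i + 4, j:j + 4] = A.T @ (U * (B.T @ d @ B)) @ A
    return out


def direct_conv2d(image, k):
    """Reference: plain sliding-window cross-correlation."""
    windows = np.lib.stride_tricks.sliding_window_view(image, k.shape)  # (out_h, out_w, 3, 3)
    return np.einsum("ijkl,kl->ij", windows, k)


rng = np.random.default_rng(0)
image = rng.random((150, 150), dtype=np.float32)  # stand-in for one channel of the puppy
k = np.array([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], dtype=np.float32)

result = winograd_conv2d(image, k)
print(result.shape)  # (148, 148)
print(np.abs(result - direct_conv2d(image, k)).max())
```

The kernel transform U is computed once and reused for all $37 \times 37$ tiles. This is one reason Winograd pays off more the larger the image (and, in a real CNN, the more input channels share each transformed patch).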


So here we go! A little more insight into what makes our tools tick.

I’d like to thank everyone who commented on the initial post on Twitter for the great ideas and hope to have the chance to do many more of these posts!

Georgios Kaissis
Specialist Diagnostic Radiologist - Senior Research Scientist

My research interests include AI-based medical image analysis, probabilistic methods and privacy-preserving AI