# Convoluted Stuff - Volume 2

JAX strikes back, more puppies and a deeper look into Winograd

## Part 1

#### In which we benchmark our ideas against `jax.lax.conv`

In part 1 of this series, we found out that our `numpy` im2col implementation can nearly match the TensorFlow implementation in speed. Soon after, however, Matthew Johnson of `JAX` fame commented the following:

> Awesome reading, and thanks for the very kind words :) Might be interesting to compare to jax.lax.conv, though on GPUs that’ll typically call into a carefully-selected autotuned cuDNN kernel.
>
> — Matthew Johnson (@SingularMattrix) May 17, 2020

Since we are trying a different implementation, we need a different puppy as well, but we’ll keep the kernel and everything else identical to be fair.

```
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from functools import partial
show = partial(plt.imshow, cmap="gray")
import jax
import scipy
```

```
im = (np.array(Image.open("wolf.jpg").resize((150,150)))/255.).astype(np.float32)
```

```
show(im)
im.dtype
plt.suptitle("W O W S O Q T")
```

```
Text(0.5, 0.98, 'W O W S O Q T')
```

```
kernel = np.array([[1,0,-1],[0,0,0],[-1,0,1]], dtype=np.float32)
show(kernel.squeeze())
```

```
<matplotlib.image.AxesImage at 0x7fd12039ee50>
```

```
print(im.shape, kernel.shape)
```

```
(150, 150, 3) (3, 3)
```

```
input_matrix, input_kernel = np.tile(im, (8, 1, 1, 1)), np.tile(kernel[...,None, None], (1,1,3,16))
print(input_matrix.shape, input_kernel.shape)
```

```
(8, 150, 150, 3) (3, 3, 3, 16)
```

We’ll reuse our im2col implementation from last time and benchmark it against `lax.conv_general_dilated`.

```
def im2col_conv2d(image, kernel):
    batch_size, x_out, y_out, in_channels = image.shape
    kx, ky, _, out_channels = kernel.shape
    # strided view: one (kx, ky, in_channels) patch per output pixel, no copy
    im_col = np.lib.stride_tricks.as_strided(
        image, (batch_size, x_out - kx + 1, y_out - ky + 1, kx, ky, in_channels),
        image.strides[:3] + image.strides[1:], writeable=False)
    return np.tensordot(im_col, kernel, axes=3)
```

```
im2col_conv = im2col_conv2d(input_matrix, input_kernel)
```

```
np.testing.assert_allclose(im2col_conv, jax.lax.conv_general_dilated(input_matrix, input_kernel, padding="VALID", dimension_numbers=("NHWC", "HWIO", "NHWC"), window_strides=(1,1)))
```

```
%timeit im2col_conv2d(input_matrix, input_kernel)
```

```
5.85 ms ± 35.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

```
%timeit jax.lax.conv_general_dilated(input_matrix, input_kernel, padding="VALID", dimension_numbers=("NHWC", "HWIO", "NHWC"), window_strides=(1,1))
```

```
2.71 ms ± 55.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

😕

Carefully selected and autotuned cuDNN kernels everyone… Can’t beat that.
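One caveat worth flagging when benchmarking `JAX`: dispatch is asynchronous and the first call to a `jax.jit`-compiled function includes compilation time, so a fair timing should compile once up front and call `block_until_ready()` on the result. A sketch (the wrapper name `conv` and the all-ones dummy inputs are mine):

```python
import numpy as np
import jax
import jax.numpy as jnp

# jit-compile the convolution so repeated calls hit the cached executable
conv = jax.jit(lambda x, k: jax.lax.conv_general_dilated(
    x, k, window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC")))

x = jnp.ones((8, 150, 150, 3), dtype=jnp.float32)
k = jnp.ones((3, 3, 3, 16), dtype=jnp.float32)

# first call compiles; block_until_ready() waits for the async computation
out = conv(x, k).block_until_ready()
# then time the steady state: %timeit conv(x, k).block_until_ready()
```

Without the `block_until_ready()`, `%timeit` can end up measuring only how fast JAX *enqueues* work, not how fast it finishes it.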

Also, I received this awesome comment about im2col which we could use to make the code more efficient. Thanks to Vladimir Ilievski!

> Great article! For the Caffe trick convolution, if your input is with constant shape (which is the case usually), you can cache the im2col indices in a dictionary and use index selection in Python to instantly transform the matrices in im2col format.
>
> — Vladimir Ilievski (@VladOsaurus) May 17, 2020

## Part 2

#### Wino$\nabla$ returns

I received a lot of comments declaring a certain surprise about the Winograd implementation and asking for a demo of how the matrix multiplications work in the 2D case. So here’s a little more detail about the implementation:

Let’s remind ourselves of the premise of the Winograd convolution: it is possible to increase the computational efficiency of a general matrix multiplication-based convolution by reducing the number of float multiplications (typically expensive) and replacing them with additions (less expensive), resulting in up to a 4x efficiency gain in the optimal case.
To do this, pre-computed matrices are used which depend **only** on the input and kernel size. These matrices ($A$, $B$ and $G$) are (again) **identical** for a given size of input and kernel. The output of the convolution of an input $d$ with a kernel $g$ then becomes:

$$ Y = A^T \bigg[ [G g G^T] \odot [B^T d B] \bigg] A $$
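Before the 2D case, the same identity in 1D makes the savings concrete. A minimal $F(2,3)$ sketch (these smaller matrices come from the Winograd paper; the variable names are mine): two outputs of a 3-tap filter computed with 4 multiplications instead of 6:

```python
import numpy as np

# F(2,3) transform matrices for a 1-D, 3-tap, 2-output tile
BT = np.array([[1, 0, -1, 0],
               [0, 1,  1, 0],
               [0, -1, 1, 0],
               [0, 1,  0, -1]], dtype=np.float64)
G = np.array([[1,   0,   0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0,   0,   1]], dtype=np.float64)
AT = np.array([[1, 1, 1,  0],
               [0, 1, -1, -1]], dtype=np.float64)

d = np.array([1., 2., 3., 4.])   # input tile
g = np.array([1., 0., -1.])      # kernel

# 4 elementwise multiplications in the transform domain
winograd = AT @ ((G @ g) * (BT @ d))
# 6 multiplications in the direct sliding dot product
direct = np.array([d[0:3] @ g, d[1:4] @ g])
# both give [-2., -2.]
```

The multiplications inside the element-wise product are the only "real" ones; everything in $A$, $B^T$ and $G$ reduces to additions and cheap scalings.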

Let’s try it out. The mathematical derivation of matrices $A$, $B$ and $G$ for various input sizes can be found in the appendix of the Winograd convolution paper, and a tool for computing them for arbitrary sizes can be found in this repository by the authors.

Our input is of size $150 \times 150$, which is neatly divisible by $6$. It turns out that $6 \times 6$ patches are ideal here, since they allow us to use the $F(4,3)$ implementation, which yields the full 4x computational efficiency gain.

The matrices describing the operation then are as follows:

```
A = np.array([[1,  0, 0,  0],
              [1,  1, 1,  1],
              [1, -1, 1, -1],
              [1,  2, 4,  8],
              [1, -2, 4, -8],
              [0,  0, 0,  1]], dtype=np.float32)
B = np.array([[ 4,  0,  0,  0,  0,  0],
              [ 0, -4,  4, -2,  2,  4],
              [-5, -4, -4, -1, -1,  0],
              [ 0,  1, -1,  2, -2, -5],
              [ 1,  1,  1,  1,  1,  0],
              [ 0,  0,  0,  0,  0,  1]], dtype=np.float32)
G = np.array([[1/4,      0,     0],
              [-1/6,  -1/6,  -1/6],
              [-1/6,   1/6,  -1/6],
              [1/24,  1/12,   1/6],
              [1/24, -1/12,   1/6],
              [0,        0,     1]], dtype=np.float32)
print(f"A={A} \n\n B={B} \n\n G={G}")
```

```
A=[[ 1. 0. 0. 0.]
[ 1. 1. 1. 1.]
[ 1. -1. 1. -1.]
[ 1. 2. 4. 8.]
[ 1. -2. 4. -8.]
[ 0. 0. 0. 1.]]
B=[[ 4. 0. 0. 0. 0. 0.]
[ 0. -4. 4. -2. 2. 4.]
[-5. -4. -4. -1. -1. 0.]
[ 0. 1. -1. 2. -2. -5.]
[ 1. 1. 1. 1. 1. 0.]
[ 0. 0. 0. 0. 0. 1.]]
G=[[ 0.25 0. 0. ]
[-0.16666667 -0.16666667 -0.16666667]
[-0.16666667 0.16666667 -0.16666667]
[ 0.04166667 0.08333334 0.16666667]
[ 0.04166667 -0.08333334 0.16666667]
[ 0. 0. 1. ]]
```

If we apply the convolution to a patch of the image, the primitive operation is:

```
X = input_matrix[0, :6, :6, 0]
k = input_kernel[0,...,0]
```

```
winograd_conv = A.T @ (np.multiply(G@k@G.T, B.T@X@B)) @ A
```

Let’s now check our implementation against the `lax.conv_general_dilated` call above:

```
np.allclose(winograd_conv, jax.lax.conv_general_dilated(X[None,...,None], k[...,None,None], window_strides=(1,1), dimension_numbers=("NHWC", "HWIO", "NHWC"), padding="VALID").squeeze())
```

```
False
```

Oops! What happened here…?

What becomes apparent is that matrix $G$ contains fractions, which will result in some loss of precision! The authors actually discuss this and find it not to be a problem (in fact, I tend to argue that lower precision in CNNs is a form of regularisation). If we therefore slightly loosen the tolerance of the `np.allclose` call, things will be fine:

```
np.allclose(winograd_conv, jax.lax.conv_general_dilated(X[None,...,None], k[...,None,None], window_strides=(1,1), dimension_numbers=("NHWC", "HWIO", "NHWC"), padding="VALID").squeeze(), atol=1e-7)
```

```
True
```

👍

```
%timeit A.T @ (np.multiply(G@k@G.T, B.T@X@B)) @ A
```

```
6.69 µs ± 55.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

```
%%timeit
jax.lax.conv_general_dilated(X[None,...,None], k[...,None,None], window_strides=(1,1), dimension_numbers=("NHWC", "HWIO", "NHWC"), padding="VALID")
```

```
131 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

And once again, we emerge victorious!

Just kidding, this one is really not fair. `JAX` has to do a *LOT* more work deciding stuff this time around, and for such a small input it can’t possibly be as efficient as a pre-selected operation (JAX is optimised for performing this over a batch of n-dimensional tensors). For example, `scipy.signal.correlate` is much closer to what we are doing and is also much closer to the speed we are getting:

```
%%timeit
scipy.signal.correlate(X, k, mode="valid")
```

```
18.9 µs ± 130 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
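As an aside on why `correlate` (and not `convolve`) is the right comparison: what deep-learning frameworks call "convolution" is really cross-correlation, i.e. the kernel is slid without being flipped. A quick sketch with a deliberately asymmetric kernel (our kernel from above is centro-symmetric, so the two would coincide):

```python
import numpy as np
from scipy import signal

X = np.arange(16, dtype=np.float64).reshape(4, 4)
k = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

corr = signal.correlate(X, k, mode="valid")             # kernel not flipped
conv = signal.convolve(X, k[::-1, ::-1], mode="valid")  # flip, then convolve
# the two agree: true convolution with a flipped kernel is cross-correlation
```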

## Conclusion

So here we go! A little more insight into what makes our tools tick.

I’d like to thank everyone who commented on the initial post on Twitter for the great ideas and hope to have the chance to do many more of these posts!