Neural Network Framework - Exercise: Convolution and Pooling Layer

Introduction

In this exercise, you will continue implementing the neural network framework that you started in exercise e06_nn_framework. By the end of this exercise, the framework will be extended by a convolutional layer and a pooling layer so that you can create simple ConvNets. Your operations, especially the convolution, need to be efficient so that they do not slow down training to an unacceptable rate. Your goal is therefore to implement vectorized versions of both layers.

Requirements

Knowledge

At this point, you should have a good understanding of how a convolutional layer works. At a minimum, you should have solved the exercises exercise-nn-framework and exercise-conv-net-pen-and-paper before starting this one. Exercise exercise-nn-framework introduced the framework architecture that you need, and exercise-conv-net-pen-and-paper gave you the theoretical background on vectorizing the convolutional layer.

You can find both of these exercises in the Convolutional Neural Networks course.

Python-Modules

import numpy as np

Data

While your implementation is intended to work with any image data $ X $, you will also use some hand-picked values during this exercise, namely the toy example $ I_{toy} $ and the kernel $ K_{toy} $, so that you can verify your results. $ I_{toy} $ and $ K_{toy} $ are taken from exercise-conv-net-pen-and-paper, so you already know the results of the convolution and of the backward pass (if $ dout $ is an all-ones matrix) for the hyperparameters stride $ S = 2 $ and padding $ P = 0 $.

$ I_{toy} = \left[ \begin{array}{ccccc} 1 & 1 & -2 & 0 & 1 \\ 1 & 0 & 0 & 2 & 1 \\ 0 & 1 & 0 & 5 & -1 \\ -2 & 1 & 0 & -1 & 1 \\ 0 & 1 & 0 & 5 & -1 \end{array} \right] , \quad K_{toy} = \left[ \begin{array}{ccc} 0 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{array} \right] , \quad dout_{toy} = \left[ \begin{array}{cc} 1 & 1 \\ 1 & 1 \end{array} \right] $

In the code, $ K_{toy} $ is defined later in the section Initializer.

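# Images: random batch of 600 single-channel 28x28 images
X = np.random.randn(600, 1, 28, 28)
# Toy input I_toy from exercise-conv-net-pen-and-paper, shape (1, 1, 5, 5)
I_toy = np.array([[[[1,1,-2,0,1],[1,0,0,2,1],[0,1,0,5,-1],[-2,1,0,-1,1],[0,1,0,5,-1]]]])

# Upflowing toy gradients (all ones), shape (1, 1, 2, 2)
dout_toy = np.array([[[[1.,1.],[1.,1.]]]])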

#print('X shape:', X.shape)
#print('X:', X)

Utility classes and functions

The following classes and functions are utilities to vectorize the convolutional operation and to initialize learnable parameters with different methods. You can put these into the utils.py file of your framework after the exercise. Be sure to read them carefully and figure out how they work.

im2col/col2im Functions

The following functions are taken from assignment 2 of the Stanford course CS231n (Convolutional Neural Networks for Visual Recognition) and have been modified for square images. They are a compact implementation of the im2col approach to vectorizing the convolution. You should be familiar with the theory behind this approach. Let us now take a closer look at the implementation, which consists of three functions: get_indices, im2col and col2im.

get_indices()

Our goal is to transform an image $ I $ into the im2col matrix $ \vec{i} $. For an image of size $ 5 \times 5 $, a kernel of size $ 3 \times 3 $, stride $ S = 2 $ and padding $ P = 0 $, the following applies:

$ I_{5x5} = \left[ \begin{matrix} i_{1,1} & i_{1,2} & i_{1,3} & i_{1,4} & i_{1,5} \\ i_{2,1} & i_{2,2} & i_{2,3} & i_{2,4} & i_{2,5}\\ i_{3,1} & i_{3,2} & i_{3,3} & i_{3,4} & i_{3,5}\\ i_{4,1} & i_{4,2} & i_{4,3} & i_{4,4} & i_{4,5}\\ i_{5,1} & i_{5,2} & i_{5,3} & i_{5,4} & i_{5,5}\\ \end{matrix} \right] $

The corresponding im2col matrix looks like this (each column holds one kernel position on the image, each row one element of the kernel window):

$ \vec{i} = \left[ \begin{matrix} i_{1,1} & i_{1,3} & i_{3,1} & i_{3,3} \\ i_{1,2} & i_{1,4} & i_{3,2} & i_{3,4} \\ i_{1,3} & i_{1,5} & i_{3,3} & i_{3,5} \\ i_{2,1} & i_{2,3} & i_{4,1} & i_{4,3} \\ i_{2,2} & i_{2,4} & i_{4,2} & i_{4,4} \\ i_{2,3} & i_{2,5} & i_{4,3} & i_{4,5} \\ i_{3,1} & i_{3,3} & i_{5,1} & i_{5,3} \\ i_{3,2} & i_{3,4} & i_{5,2} & i_{5,4} \\ i_{3,3} & i_{3,5} & i_{5,3} & i_{5,5} \\ \end{matrix}\right] $

If we look at the row index $ i $ and the column index $ j $ separately, we can store each of them in its own matrix. This is what get_indices() does. The matrix for the row index $ i $ is

$ \left[ \begin{matrix} 1 & 1 & 3 & 3 \\ 1 & 1 & 3 & 3 \\ 1 & 1 & 3 & 3 \\ 2 & 2 & 4 & 4 \\ 2 & 2 & 4 & 4 \\ 2 & 2 & 4 & 4 \\ 3 & 3 & 5 & 5 \\ 3 & 3 & 5 & 5 \\ 3 & 3 & 5 & 5 \\ \end{matrix}\right] $

and the matrix for the column index $ j $ is:

$ \left[ \begin{matrix} 1 & 3 & 1 & 3 \\ 2 & 4 & 2 & 4 \\ 3 & 5 & 3 & 5 \\ 1 & 3 & 1 & 3 \\ 2 & 4 & 2 & 4 \\ 3 & 5 & 3 & 5 \\ 1 & 3 & 1 & 3 \\ 2 & 4 & 2 & 4 \\ 3 & 5 & 3 & 5 \\ \end{matrix}\right] $

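These two matrices correspond to the variables i and j, which get_indices returns as the second and third elements of its tuple. Note that NumPy indexing is 0-based, so the returned arrays contain the values above minus one. To keep the example simple, the depth (channel) dimension $ c $ is trivial here, i.e., there is only one channel.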
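To see how repeat() and tile() produce these matrices, the following sketch rebuilds them for the toy setting (5×5 input, 3×3 kernel, stride 2) with plain NumPy broadcasting. It mirrors the i0/i1/j0/j1 construction in get_indices below, but hard-codes a single channel; the variable values are chosen for illustration only.

```python
import numpy as np

# Toy setting: 5x5 input, 3x3 kernel, stride 2, no padding
kh, kw, stride = 3, 3, 2
out_size = (5 - kh) // stride + 1            # 2 kernel positions per axis

# Offsets inside the kernel window: one entry per row of the im2col matrix
i0 = np.repeat(np.arange(kh), kw)            # [0 0 0 1 1 1 2 2 2]
j0 = np.tile(np.arange(kw), kh)              # [0 1 2 0 1 2 0 1 2]

# Top-left corner of each kernel position: one entry per column
i1 = stride * np.repeat(np.arange(out_size), out_size)   # [0 0 2 2]
j1 = stride * np.tile(np.arange(out_size), out_size)     # [0 2 0 2]

# A column vector plus a row vector broadcasts to the 9x4 index matrices
i = i0.reshape(-1, 1) + i1.reshape(1, -1)
j = j0.reshape(-1, 1) + j1.reshape(1, -1)

# Add 1 to compare with the 1-based matrices in the text
print(i + 1)
print(j + 1)
```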

def get_indices(x_shape, filter_dim=(3, 3), padding=0, stride=1):
    '''Gets the indices at which the elements of the input have to be picked
       in order to transform it into an im2col matrix.

    Args:
        x_shape: Shape of the input data (N, C, H, W), square images (H == W)
        filter_dim: Filter height and width, e.g., (3,3)
        padding: Layer padding size
        stride: Layer stride size

    Returns:
        im2col indices as 3-tuple (c, i, j) - (channel indices, row indices,
                                               column indices)
    '''
    # get shape of the input data and calculate the output size;
    # because we deal with square images, we only have to calculate it once
    N, C, H, W = x_shape
    out_size = (H + 2 * padding - filter_dim[0]) // stride + 1  # // yields an int instead of a float

    # calculate the indices of the channel dimension
    c = np.repeat(np.arange(C), filter_dim[0] * filter_dim[1]).reshape(-1, 1)

    # calculate the indices of the height (i) and width (j) dimensions;
    # repeat() and tile() are used to build the required sequences
    i0 = np.repeat(np.arange(filter_dim[0]), filter_dim[1])
    i0 = np.tile(i0, C)
    i1 = stride * np.repeat(np.arange(out_size), out_size)
    j0 = np.tile(np.arange(filter_dim[1]), filter_dim[0] * C)
    j1 = stride * np.tile(np.arange(out_size), out_size)

    # broadcast kernel offsets (rows) against window offsets (columns)
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)
    return (c, i, j)

im2col()

The function im2col() transforms an image $ I $ into the desired representation using the indices from get_indices(). For the toy example $ I_{toy} $, a $ 5 \times 5 $ matrix with concrete values, we expect the following:

$ I_{toy} = \left[ \begin{array} { c c c c c } { 1 } & { 1 } & { - 2 } & { 0 } & { 1 } \\ { 1 } & { 0 } & { 0 } & { 2 } & { 1 } \\ { 0 } & { 1 } & { 0 } & { 5 } & { - 1 } \\ { - 2 } & { 1 } & { 0 } & { - 1 } & { 1 } \\ { 0 } & { 1 } & { 0 } & { 5 } & { - 1 } \end{array} \right] $

You know from exercise exercise-conv-net-pen-and-paper that the im2col matrix of $ I_{toy} $ should look like this:

$ \vec{i} = \left[ \begin{matrix} 1 & -2 & 0 & 0 \\ 1 & 0 & 1 & 5 \\ -2 & 1 & 0 & -1 \\ 1 & 0 & -2 & 0 \\ 0 & 2 & 1 & -1 \\ 0 & 1 & 0 & 1 \\ 0 & 0 & 0 & 0 \\ 1 & 5 & 1 & 5 \\ 0 & -1 & 0 & -1 \end{matrix} \right] $
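A plain loop makes the structure of this matrix explicit: every column is one 3×3 window, flattened row by row, and the windows are visited row by row across the image. The function im2col_loops below is a hypothetical reference for a single-channel 2D image, useful for checking the vectorized im2col() on small inputs; it is far too slow for real training.

```python
import numpy as np

def im2col_loops(img, kh, kw, stride):
    """Loop-based im2col for a single-channel 2D image (reference only).

    Each column holds one kernel-sized window, flattened row by row;
    columns are ordered row-major over the output positions.
    """
    out_h = (img.shape[0] - kh) // stride + 1
    out_w = (img.shape[1] - kw) // stride + 1
    cols = np.empty((kh * kw, out_h * out_w), dtype=img.dtype)
    for r in range(out_h):
        for c in range(out_w):
            window = img[r * stride:r * stride + kh, c * stride:c * stride + kw]
            cols[:, r * out_w + c] = window.ravel()
    return cols

I_toy_2d = np.array([[1, 1, -2, 0, 1],
                     [1, 0, 0, 2, 1],
                     [0, 1, 0, 5, -1],
                     [-2, 1, 0, -1, 1],
                     [0, 1, 0, 5, -1]])
print(im2col_loops(I_toy_2d, 3, 3, stride=2))
```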

def im2col(x, filter_dim=(3, 3), padding=0, stride=1):
    ''' Transforms an image tensor into an im2col matrix.

    Assume you have images of shape (600, 1, 28, 28), padding=0,
    stride=2 and a filter with dimensions (3,3). You already know
    that the spatial output size of the convolution then has to be
    (13,13), since (28-3)//2 + 1 = 13.
    im2col then creates a new matrix with the shape (9 * 1, 600 * 13 * 13),
    which you can matrix-multiply with your flattened kernels of shape
    (n, 9 * 1). The multiplication results in a new matrix of shape
    (n, 600 * 13 * 13). Its columns are ordered by output position with
    the sample index varying fastest, so reshape it to (n, 13, 13, 600)
    and move the sample axis to the front to get the convolution output
    (600, n, 13, 13). Note that n is the number of filters in your
    convolution layer.

    Args:
        x: Input data of shape (N, C, H, W)
        filter_dim: Filter height and width, e.g., (3,3)
        padding: Layer padding size
        stride: Layer stride size

    Returns:
        im2col matrix, e.g., with shape (9 * 1, 600 * 13 * 13) in the example above
    '''
    # Zero-pad the height and width dimensions of the input
    p = padding
    x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

    # get the indices of the column matrix
    c, i, j = get_indices(x.shape, filter_dim, padding, stride)

    # gather all windows via fancy indexing: shape (N, C*kh*kw, out_size**2)
    cols = x_padded[:, c, i, j]
    # move the sample axis to the end: shape (C*kh*kw, out_size**2, N)
    cols = cols.transpose(1, 2, 0)

    # flatten to the desired 2d shape, e.g., (9 * 1, 600 * 13 * 13)
    cols = cols.reshape(filter_dim[0] * filter_dim[1] * x.shape[1], -1)
    return cols
# Usage example and reproduction of the result from e08
print('I_toy transformed with im2col:')
print(im2col(I_toy, stride=2))
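With more than one image, the column order of im2col() matters when you reshape the result of the matrix multiplication. Because of the transpose(1, 2, 0) step, the sample index varies fastest: column l * N + n belongs to kernel position l of sample n. The following sketch uses made-up values instead of real image windows (each entry encodes 100 * n + l) to show this ordering and one way to bring a product of shape (F, L * N) back into the (N, F, out, out) layout; the sizes N, F and out are arbitrary assumptions.

```python
import numpy as np

N, K, out = 2, 9, 2            # 2 samples, 3x3 kernel (K = 9 rows), 2x2 output
L = out * out                  # number of kernel positions per sample

# Stand-in for x_padded[:, c, i, j] with shape (N, K, L); entry = 100*n + l
base = 100 * np.arange(N)[:, None, None] + np.arange(L)[None, None, :]
gathered = np.repeat(base, K, axis=1)

# Same reordering as in im2col: column index becomes l * N + n
cols = gathered.transpose(1, 2, 0).reshape(K, -1)
print(cols[0].tolist())        # [0, 100, 1, 101, 2, 102, 3, 103]

# Multiply with F flattened filters, then undo the column ordering
F = 3
W_col = np.ones((F, K))
out_col = W_col @ cols                                   # shape (F, L * N)
out4d = out_col.reshape(F, out, out, N).transpose(3, 0, 1, 2)
print(out4d.shape)             # (2, 3, 2, 2)
```

A plain reshape to (N, F, out, out) would also run without error, but it would mix samples and positions, which is why the transpose is needed.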

col2im()

The function col2im() is the counterpart of im2col(). It maps an im2col matrix back to the shape of the input and adds up all values that belong to the same input element, since overlapping windows contribute to some elements several times. You use col2im() to convert the gradient $ \frac{\partial L}{\partial I} $ from its im2col-matrix representation back into the shape of the input. This requires the same hyperparameter settings as in the corresponding im2col transformation.

The im2col matrix that the backward pass of the convolutional layer produces for the toy example is:

$ \frac{\partial L}{\partial I_{toy\_col}} = \left[ \begin{matrix} 0 & 0 & 0 & 0 \\ 1 & 1 & 1 & 1 \\ 1 & 1 & 1 & 1 \\ 1 & 1 & 1 & 1 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 1 & 1 & 1 & 1 \\ 0 & 0 & 0 & 0 \end{matrix} \right] $

From the optional exercise in e08, you should know that the correct gradient $ \frac{\partial L}{\partial I_{toy}} $ is:

$ \frac{\partial L}{\partial I_{toy}} = \left[ \begin{matrix} 0 & 1 & 1 & 1 & 1 \\ 1 & 0 & 1 & 0 & 0 \\ 0 & 2 & 1 & 2 & 1 \\ 1 & 0 & 1 & 0 & 0 \\ 0 & 1 & 0 & 1 & 0 \end{matrix} \right] $

col2im() returns exactly these values when it is called with the $ \frac{\partial L}{\partial I_{toy\_col}} $ matrix, stride $ S = 2 $ and padding $ P = 0 $.
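The summation is easiest to see in a loop: each column of the gradient matrix is reshaped back into a 3×3 window and added onto the input position it came from, so elements covered by two windows (the two entries with value 2 in the third row) receive two contributions. col2im_loops is a hypothetical single-channel reference, not part of the framework.

```python
import numpy as np

def col2im_loops(cols, img_size, kh, kw, stride):
    """Loop-based col2im for a single-channel square image (reference only).

    Adds every column, reshaped to a kh x kw window, back onto its
    position in the image; overlapping windows are summed.
    """
    out_size = (img_size - kh) // stride + 1
    img = np.zeros((img_size, img_size), dtype=cols.dtype)
    for r in range(out_size):
        for c in range(out_size):
            window = cols[:, r * out_size + c].reshape(kh, kw)
            img[r * stride:r * stride + kh, c * stride:c * stride + kw] += window
    return img

# Gradient im2col matrix of the toy example: rows 2, 3, 4 and 8 are ones
dI_col = np.zeros((9, 4))
dI_col[[1, 2, 3, 7], :] = 1.0
print(col2im_loops(dI_col, 5, 3, 3, stride=2))
```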

def col2im(cols, x_shape, filter_dim=(3, 3), padding=0, stride=1):
    ''' Transforms an im2col matrix back into its initial shape.

    You can think of this function as the reverse of im2col. It
    adds up the values that belong to the same input element and
    transforms the given matrix back into the initial shape.

    Assume you have an im2col matrix with shape (9, 600*13*13).
    The original images had a shape of (600,1,28,28) with padding P = 0
    and stride S = 2. From the im2col matrix and the same hyperparameters,
    col2im creates a new tensor with the shape (600, 1, 28, 28).

    Args:
        cols: im2col matrix
        x_shape: Shape of the data before applying the im2col transformation
        filter_dim: Filter height and width
        padding: Layer padding size
        stride: Layer stride size

    Returns:
        Tensor with the initial shape x_shape
    '''
    # Get shapes and padding
    N, C, H, W = x_shape
    padded_shape = H + 2 * padding

    # Create a zero tensor with the input shape plus the padding used during conv
    x_padded = np.zeros((N, C, padded_shape, padded_shape), dtype=cols.dtype)

    # undo the column ordering of im2col: (C*kh*kw, L*N) -> (N, C*kh*kw, L)
    cols_reshaped = cols.reshape(C * filter_dim[0] * filter_dim[1], -1, N)
    cols_reshaped = cols_reshaped.transpose(2, 0, 1)

    # Get the same indices used in the corresponding im2col transformation.
    # np.add.at accumulates values at repeated indices (overlapping windows)
    # and stores the sums in a tensor with the initial (padded) shape.
    c, i, j = get_indices(x_shape, filter_dim, padding, stride)
    np.add.at(x_padded, (slice(None), c, i, j), cols_reshaped)

    # if necessary, remove the conv padding
    if padding == 0:
        return x_padded
    return x_padded[:, :, padding:-padding, padding:-padding]
# Usage example and reproduction of the result from e08
dI_toy_col = np.array([[0., 0., 0., 0.],
                       [1., 1., 1., 1.],
                       [1., 1., 1., 1.],
                       [1., 1., 1., 1.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [0., 0., 0., 0.],
                       [1., 1., 1., 1.],
                       [0., 0., 0., 0.]])
dI_toy = col2im(dI_toy_col, I_toy.shape, stride=2)
print(dI_toy)
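col2im() uses np.add.at rather than an in-place fancy-index update such as x_padded[...] += cols because of overlapping windows: with fancy indexing, += is buffered, so an index that appears several times is only updated once, whereas np.add.at accumulates every occurrence. A minimal illustration with a made-up index array:

```python
import numpy as np

idx = np.array([0, 1, 1, 3])     # index 1 appears twice, like an overlapped pixel

buffered = np.zeros(4)
buffered[idx] += 1               # duplicate index is written only once

unbuffered = np.zeros(4)
np.add.at(unbuffered, idx, 1)    # every occurrence is accumulated

print(buffered)    # [1. 1. 0. 1.]
print(unbuffered)  # [1. 2. 0. 1.]
```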

Initializer

The class WeightInitializer provides functions to initialize the weights of your convolutional layer with different methods, e.g., Xavier Glorot initialization if you use sigmoid or tanh as the activation function. After solving the exercise, move the class into utils.py and extend it so that you can initialize your fully connected layer with it as well. The method toy_initialization is for this exercise only; delete it when you move the class into the script file. It initializes the kernel with $ K_{toy} $, so together with $ I_{toy} $ you can reproduce the values of exercise-conv-net-pen-and-paper. Usage is simple: when creating a convolutional layer, you pass the initialization method as a parameter:

conv1 = Convolution(filter_num=16,filter_dim=(1, 3, 3), initializer=WeightInitializer.he, stride=1, padding=0)

However, be careful to choose a method that is appropriate for your activation function.

class WeightInitializer(object):
    ''' Different weight initialization methods

    All methods are static, so they can be passed around as
    WeightInitializer.<method>, e.g., initializer=WeightInitializer.he.
    '''
    @staticmethod
    def random(filter_num, input_channels, filter_dim):
        ''' Initialize a kernel with small random values

        Returns:
            4d tensor (filter_number, filter_depth, filter_height, filter_width)
        '''
        random_weight = np.random.randn(filter_num, input_channels, filter_dim[0], filter_dim[1]) * 0.1
        return random_weight

    @staticmethod
    def glorot(filter_num, input_channels, filter_dim):
        ''' Xavier Glorot initialization - used for sigmoid, tanh

        Uses the simplified variance 1 / fan_in, where fan_in is the number
        of inputs of one output unit (input_channels * height * width).

        Returns:
            4d tensor (filter_number, filter_depth, filter_height, filter_width)
        '''
        fan_in = input_channels * filter_dim[0] * filter_dim[1]
        glorot_weights = np.random.randn(filter_num, input_channels, filter_dim[0], filter_dim[1]) * np.sqrt(1. / fan_in)
        return glorot_weights

    @staticmethod
    def he(filter_num, input_channels, filter_dim):
        ''' Kaiming He initialization - used for ReLU family

        Uses the variance 2 / fan_in with fan_in = input_channels * height * width.

        Returns:
            4d tensor (filter_number, filter_depth, filter_height, filter_width)
        '''
        fan_in = input_channels * filter_dim[0] * filter_dim[1]
        he_weights = np.random.randn(filter_num, input_channels, filter_dim[0], filter_dim[1]) * np.sqrt(2. / fan_in)
        return he_weights

    @staticmethod
    def toy_initialization(filter_num, input_channels, filter_dim):
        ''' Only for the test case of this exercise (e09). Ignores its
            arguments and returns the filter K_toy used in e08:
            [[[[ 0., 1., 1.],
               [ 1., 0., 0.],
               [ 0., 1., 0.]]]]
        '''
        return np.array([[[[0., 1., 1.], [1., 0., 0.], [0., 1., 0.]]]])
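For a convolutional layer, the fan-in of one output unit is the number of input values a single filter sees, i.e., input_channels * filter_height * filter_width. The sketch below computes the He and simplified Xavier scales for an assumed layer with 16 input channels and 3×3 filters and checks empirically that sampled He weights have roughly the intended standard deviation; the layer sizes and the random seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

filter_num, input_channels, filter_dim = 64, 16, (3, 3)   # assumed layer sizes
fan_in = input_channels * filter_dim[0] * filter_dim[1]    # 16 * 3 * 3 = 144

he_std = np.sqrt(2.0 / fan_in)        # He: Var(w) = 2 / fan_in
glorot_std = np.sqrt(1.0 / fan_in)    # simplified Xavier: Var(w) = 1 / fan_in

# Sample a full weight tensor and compare its empirical std with he_std
W = rng.standard_normal((filter_num, input_channels, *filter_dim)) * he_std
print(fan_in, round(float(he_std), 4), round(float(W.std()), 4))
```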

Convolutional Layer

Vectorized Implementation Convolutional Layer

Task:

Implement a vectorized convolutional layer. The class and its initialization are given; you have to implement the forward and backward pass. Revisit exercise exercise-conv-net-pen-and-paper if necessary to recall how a vectorized convolution works. Then add the Convolution class to the existing layer.py of your framework.

class Convolution():
    ''' Creates a convolutional layer

    To use:
        conv1 = Convolution(filter_num=16, filter_dim=(1, 3, 3),
                            initializer=WeightInitializer.he, stride=1, padding=0)
        conv1.forward(X)
    '''
    def __init__(self, filter_num=32, filter_dim=(1, 3, 3), initializer=WeightInitializer.random, stride=1, padding=0):
        ''' Initialize the convolution layer with the given parameters

        Args:
            filter_num: number of filters used for the convolution
            filter_dim: (filter_depth, filter_height, filter_width);
                        filter_depth has to be equal to the input depth
            initializer: WeightInitializer method used for the weights
            stride: step size to move the filter
            padding: size of zero padding
        '''
        self.input_channels = filter_dim[0]
        self.filter_num = filter_num
        self.filter_dim = (filter_dim[1], filter_dim[2])
        self.stride = stride
        self.padding = padding
        # Initialize weights and bias (one bias per filter, shape (filter_num, 1))
        self.W = initializer(filter_num, self.input_channels, self.filter_dim)
        self.b = np.zeros((filter_num, 1))

        self.params = [self.W, self.b]
        
    def forward(self, X, verbose=False):
        ''' Convolution over input X to create feature maps
        
        Strictly speaking, this implements a vectorized cross-correlation
        (the filters are not flipped) over input X with the filters
        specified at layer creation.

        Args:
            X: 4d tensor (num_images, num_channels, height, width)
            verbose: If set, prints information about shapes of x_col, w_col and out

        Returns:
            out: feature maps, 4d tensor (num_images, filter_num, size_out, size_out)
        '''
        # Get input and weight parameter
        n_filters, d_filter, h_filter, w_filter = self.W.shape
        n_x, d_x, h_x, w_x = X.shape
        
        # Calculate feature map dimensions
        size_out = (h_x - h_filter + 2 * self.padding) // self.stride + 1
        
        ##############################
        ####### BEGIN SOLUTION ####### 
        ##############################

        ############################
        ####### END SOLUTION #######
        ############################
        # store input and im2col-matrix for backprop
        self.X = X
        self.x_col = x_col 
        return out

    def backward(self, dout, verbose=False):
        ''' Calculates the gradients with respect to the input X, the filters W and the bias b.

        Args:
            dout: Upstream gradient of the layer output
            verbose: If set, prints information about shapes of db, dW and dX

        Returns:
            dX : Gradient with respect to X
            dW : Gradient with respect to W
            db : Gradient with respect to b
        '''
        n_filter, d_filter, h_filter, w_filter = self.W.shape

        ##############################
        ####### BEGIN SOLUTION ####### 
        ##############################

        ############################
        ####### END SOLUTION #######
        ############################

        return dX, [dW, db]

Testing Your Conv Implementation

There are two test cases:

  • The toy example to reproduce the values from exercise-conv-net-pen-and-paper. You can verify these results by hand.
  • A more realistic input tensor X with a simple all-ones upstream gradient. Here, only the resulting shapes are checked.
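Beyond the hand-computed toy values, it helps to have an independent reference for the forward pass on random data. The following sketch is not the im2col approach the exercise asks for: it uses numpy.lib.stride_tricks.sliding_window_view (available since NumPy 1.20) and np.einsum to compute the same cross-correlation. conv_reference is a hypothetical helper name; it expects W of shape (F, C, kh, kw) and b of shape (F, 1), like the Convolution class.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv_reference(X, W, b, stride=1, padding=0):
    """Cross-correlation of X (N, C, H, W) with filters W (F, C, kh, kw).

    Reference only: builds all windows as a strided view, subsamples
    them according to the stride and contracts them with the filters.
    """
    p = padding
    X_pad = np.pad(X, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
    kh, kw = W.shape[2], W.shape[3]
    # windows: (N, C, H_pad - kh + 1, W_pad - kw + 1, kh, kw)
    windows = sliding_window_view(X_pad, (kh, kw), axis=(2, 3))
    windows = windows[:, :, ::stride, ::stride]
    out = np.einsum('nchwkl,fckl->nfhw', windows, W)
    return out + b.reshape(1, -1, 1, 1)

I_toy_ref = np.array([[[[1, 1, -2, 0, 1], [1, 0, 0, 2, 1], [0, 1, 0, 5, -1],
                        [-2, 1, 0, -1, 1], [0, 1, 0, 5, -1]]]], dtype=float)
K_toy_ref = np.array([[[[0., 1., 1.], [1., 0., 0.], [0., 1., 0.]]]])
print(conv_reference(I_toy_ref, K_toy_ref, np.zeros((1, 1)), stride=2))
```

Once your forward pass works, np.allclose(conv1.forward(X), conv_reference(X, conv1.W, conv1.b, conv1.stride, conv1.padding)) should hold for the layer conv1 created below.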
# --- Toy example ---
# Create a ConvLayer with parameters from e08
conv_toy = Convolution(filter_num=1,filter_dim=(1, 3, 3), initializer=WeightInitializer.toy_initialization, stride=2, padding=0)

# Forward path
out_toy = conv_toy.forward(I_toy)
np.testing.assert_array_equal(out_toy, np.array([[[[1., 6.], [0., 9.]]]]), verbose=True)

# Backward path 
dout_toy = np.ones((1,1,2,2)) # Gradient from e08
dI_toy, [dW_toy, db_toy] = conv_toy.backward(dout_toy)
np.testing.assert_array_equal(dI_toy, np.array([[[[0., 1., 1., 1., 1.], 
                                                  [1., 0., 1., 0., 0.],
                                                  [0., 2., 1., 2., 1.],
                                                  [1., 0., 1., 0., 0.],
                                                  [0., 1., 0., 1., 0.]]]]), verbose=True)
np.testing.assert_array_equal(dW_toy, np.array([[[[-1., 7., -2.], 
                                                  [-1., 2., 2.],
                                                  [0., 12., -2.]]]]), verbose=True)
np.testing.assert_array_equal(db_toy, np.array([[4]]), verbose=True)
# More complex test case
# Create a ConvLayer with appropriate filter settings
conv1 = Convolution(filter_num=16,filter_dim=(1, 3, 3), initializer=WeightInitializer.he, stride=1, padding=0)

# Forward path
out1 = conv1.forward(X, verbose=False)

#Backward path
dout1 = np.ones((600, 16, 26, 26))
dX1, [dW1, db1] = conv1.backward(dout1, verbose=False)

# Calculate some shapes individually and validate them against the implementation
num, channel, height, width = X.shape
print('Is the shape of the output correct?', out1.shape == (num, conv1.filter_num, 
                                                    (height - conv1.filter_dim[0] + 2 * conv1.padding) // conv1.stride + 1, 
                                                    (width - conv1.filter_dim[0] + 2 * conv1.padding) // conv1.stride + 1), 
                                                    out1.shape)
print('Is the shape of input_gradients correct?', dX1.shape == X.shape, dX1.shape)
print('Is the shape of weight_gradients correct?', dW1.shape == conv1.W.shape, dW1.shape)
print('Is the shape of bias_gradients correct?', db1.shape == conv1.b.shape, db1.shape)
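For random inputs there are no hand-computed gradients, so a common way to test a backward pass is a numerical gradient check with central differences. The sketch below is a generic helper (numerical_gradient is a hypothetical name, not part of the framework). To stay self-contained, it checks the toy example with a small loop-based cross-correlation: with an all-ones upstream gradient the loss is simply the sum of the outputs, and the estimated gradients reproduce dW_toy and dI_toy from the toy test above.

```python
import numpy as np

def numerical_gradient(f, x, h=1e-5):
    """Central-difference estimate of the gradient of a scalar function f at x.

    x must be a float array; each entry is perturbed in place and restored.
    """
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + h
        f_plus = f(x)
        x[idx] = old - h
        f_minus = f(x)
        x[idx] = old                      # restore the original value
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad

def correlate_toy(img, k, stride=2):
    """Valid 2D cross-correlation of a single-channel image (loop version)."""
    out_size = (img.shape[0] - k.shape[0]) // stride + 1
    out = np.zeros((out_size, out_size))
    for r in range(out_size):
        for c in range(out_size):
            window = img[r * stride:r * stride + k.shape[0], c * stride:c * stride + k.shape[1]]
            out[r, c] = np.sum(window * k)
    return out

I_2d = np.array([[1, 1, -2, 0, 1], [1, 0, 0, 2, 1], [0, 1, 0, 5, -1],
                 [-2, 1, 0, -1, 1], [0, 1, 0, 5, -1]], dtype=float)
K_2d = np.array([[0., 1., 1.], [1., 0., 0.], [0., 1., 0.]])

# Loss = sum(out * dout) with dout = ones, i.e. the sum of the outputs
def loss_wrt_kernel(k):
    return np.sum(correlate_toy(I_2d, k))

def loss_wrt_input(img):
    return np.sum(correlate_toy(img, K_2d))

dW_num = numerical_gradient(loss_wrt_kernel, K_2d.copy())
dI_num = numerical_gradient(loss_wrt_input, I_2d.copy())
print(np.round(dW_num, 6))
print(np.round(dI_num, 6))
```

For your layer, f would run the forward pass with the perturbed weights (or input) and return np.sum(out * dout), which you can then compare against the dW (or dX) returned by backward.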

Pooling Layer

Vectorized Implementation MaxPooling Layer

Task:

Implement a max pooling operation as a network layer. After you finish your implementation, move the class into the layer.py file of the framework. Optionally, you can add a parameter for choosing a different pooling function, e.g., mean or sum.
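The core of a vectorized max pooling layer is an argmax over the columns of an im2col matrix in which every column holds one pooling window. The small, made-up column matrix below illustrates the idea for the forward pass and shows how the stored indices route the upstream gradient in the backward pass. Note that np.argmax returns the first maximum when there are ties, so only one position per window receives the gradient.

```python
import numpy as np

# Hypothetical column matrix for one channel: each column is one flattened
# 2x2 pooling window (4 rows); there are three windows in total.
cols = np.array([[0., 0., 7.],
                 [0., 9., 0.],
                 [3., 0., 7.],
                 [1., 2., 0.]])
window_ids = np.arange(cols.shape[1])

# Forward: row of the maximum in each column, and the maxima themselves
max_idx = np.argmax(cols, axis=0)
pooled = cols[max_idx, window_ids]           # same as cols.max(axis=0)
print(max_idx.tolist(), pooled.tolist())     # [2, 1, 0] [3.0, 9.0, 7.0]

# Backward: each upstream gradient goes to the position of its maximum
dout = np.array([10., 20., 30.])
dcols = np.zeros_like(cols)
dcols[max_idx, window_ids] = dout
print(dcols)
```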

class Pooling():
    ''' Creates a MaxPooling layer
    '''

    def __init__(self, filter_dim=(2, 2), stride=2):
        ''' Initialize a pooling layer with `max` as pooling function

        The layer is usually initialized with a (2,2) filter and
        stride 2. With this setting, an input volume of, e.g., (1,1,28,28)
        shrinks by a factor of 2 to (1,1,14,14).

        Args:
            filter_dim: pooling size
            stride: stride size
        '''
        self.filter_dim = filter_dim
        self.stride = stride
        self.params = []

    def forward(self, X, verbose=False):
        ''' Applies the max function to each window position
        
        Args:
            X: input volume - 4d tensor
            verbose: If set, prints shapes of some volumes
            
        Returns:
            4d tensor with reduced dimensions according to the chosen
            hyperparameters and max function applied to each filter
            position
        '''
        # Reshapes the images so that the depth dim is 1
        n_x, d_x, h_x, w_x = X.shape
        x_reshaped = X.reshape(n_x * d_x, 1, h_x, w_x)
        
        ##############################
        ####### BEGIN SOLUTION ####### 
        ##############################
        # HINT: Pooling is just a convolution with a different "filter".
        #       Use im2col to get a column matrix and find the max in each
        #       column. Store the max indices in the object for the backward pass.
        
        ############################
        ####### END SOLUTION #######
        ############################
            
        # Save input and input col for future use
        self.x_col = x_col
        self.X = X
        return out

    def backward(self, dout, verbose=False):
        ''' Backward pass of the max pooling layer

        Remember that a max pooling layer has no learnable parameters,
        so there are no weight gradients to calculate; the upstream
        gradient is only routed back to the input.

        Args:
            dout: Upflowing gradient
            verbose: If set, prints shapes of some volumes

        Returns:
            dX with the following properties:
             - values of dout are passed through to the positions of
               the max values found in the forward pass
             - all other values are set to zero
        '''
        # Save the shape of the input image
        n_x, d_x, h_x, w_x = self.X.shape
        
        ##############################
        ####### BEGIN SOLUTION ####### 
        ##############################
        # Hint: Use the indices from the forward pass and
        #       fancy indexing to choose the right indices of dX
        
        ############################
        ####### END SOLUTION #######
        ############################
        
        # reshape it back to the input image shape
        dX = dX.reshape(self.X.shape)
        return dX, []

Test Your Pooling Implementation

Let us create another toy example and use it for the max pooling layer:

$ I_{pooling} = \left[ \begin{matrix} 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 9 & 0 & 8 & 0 & 9 \\ 0 & 0 & 0 & 0 & 0 & 0 \\ 7 & 0 & 7 & 0 & 8 & 0 \\ 5 & 0 & 6 & 0 & 9 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \\ \end{matrix} \right] $

Use the output of the forward pass as the upstream gradient of the backward pass to recreate the original input. This works because every value in $ I_{pooling} $ that is not the maximum of its window is zero. The second test feeds the output tensor of the convolutional layer from above into a pooling layer and checks its dimensions after the forward and backward pass.
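For non-overlapping windows (filter size equal to the stride, with height and width divisible by it) there is a simple reference based on reshape, which you can use to cross-check your im2col-based layer. maxpool_reshape and maxpool_reshape_backward are hypothetical helpers. The backward sketch uses a mask and therefore passes the gradient to every tied maximum, which differs from an argmax-based layer when a window contains duplicate maxima; the windows of $ I_{pooling} $ have no such ties.

```python
import numpy as np

def maxpool_reshape(X, k=2):
    """Max pooling with k x k windows and stride k (H and W divisible by k)."""
    N, C, H, W = X.shape
    return X.reshape(N, C, H // k, k, W // k, k).max(axis=(3, 5))

def maxpool_reshape_backward(X, dout, k=2):
    """Routes dout to the maxima of each window via a mask (ties all get it)."""
    out = maxpool_reshape(X, k)
    # Upsample pooled values and gradients back to the input resolution
    out_up = out.repeat(k, axis=2).repeat(k, axis=3)
    dout_up = dout.repeat(k, axis=2).repeat(k, axis=3)
    return (X == out_up) * dout_up

I_pooling = np.array([[[[0., 0., 0., 0., 0., 0.],
                        [0., 9., 0., 8., 0., 9.],
                        [0., 0., 0., 0., 0., 0.],
                        [7., 0., 7., 0., 8., 0.],
                        [5., 0., 6., 0., 9., 0.],
                        [0., 0., 0., 0., 0., 0.]]]])
pooled = maxpool_reshape(I_pooling)
print(pooled)
dX = maxpool_reshape_backward(I_pooling, pooled)
print(np.array_equal(dX, I_pooling))   # True
```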

# Another toy example as sanity check
I_toy_pooling = np.array([[[[0., 0., 0., 0., 0., 0.], 
                            [0., 9., 0., 8., 0., 9.],
                            [0., 0., 0., 0., 0., 0.],
                            [7., 0., 7., 0., 8., 0.],
                            [5., 0., 6., 0., 9., 0.],
                            [0., 0., 0., 0., 0., 0.]]]])

pooling_toy = Pooling(filter_dim=(2, 2), stride=2)
out_toy_pooling = pooling_toy.forward(I_toy_pooling, verbose=True)

# backward path - pooling
dout_toy_pooling = pooling_toy.backward(out_toy_pooling, verbose=True)
dX_toy_pooling, empty = dout_toy_pooling

# Tests
np.testing.assert_array_equal(out_toy_pooling, np.array([[[[9., 8., 9.], 
                                                           [7., 7., 8.],
                                                           [5., 6., 9.]]]]), verbose=True)
np.testing.assert_array_equal(dX_toy_pooling, np.array([[[[0., 0., 0., 0., 0., 0.], 
                                                          [0., 9., 0., 8., 0., 9.],
                                                          [0., 0., 0., 0., 0., 0.],
                                                          [7., 0., 7., 0., 8., 0.],
                                                          [5., 0., 6., 0., 9., 0.],
                                                          [0., 0., 0., 0., 0., 0.]]]]), verbose=True)
# A more complex example
# forward path - pooling
pooling2 = Pooling(filter_dim=(2, 2), stride=2)
out2 = pooling2.forward(out1)

# backward path - pooling
gradient2 = pooling2.backward(out2)
dX2, empty = gradient2

num, channel, height, width = out1.shape
print('------------Pooling test------------')
print('Output correct shape?', out2.shape == (num, 
                                              channel, 
                                              (height-pooling2.filter_dim[0])//pooling2.stride + 1, 
                                              (width-pooling2.filter_dim[0])//pooling2.stride + 1), out2.shape)
print('Gradient correct shape?', dX2.shape == out1.shape, dX2.shape)

Outlook

After finishing the exercise, you should have a working and efficient neural network framework. Do not forget to move all classes into the corresponding script files of the framework. A good follow-up is to repeat your experiments from exercise-nn-framework on a dataset of your choice, but with a ConvNet instead of a fully connected network.

Licenses

Notebook License (CC-BY-SA 4.0)

The following license applies to the complete notebook, including code cells. It does however not apply to any referenced external media (e.g., images).

1163150 - Neural Networks - Exercise: Convolutional and Pooling Layer
by Steven Mi, Benjamin Voigt
is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.
Based on a work at https://gitlab.com/deep.TEACHING.

Code License (MIT)

The following license only applies to code cells of the notebook.

Copyright 2018 Steven Mi, Benjamin Voigt

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.