Constructing Neural Networks with Automatic Differentiation (Differentiable Programming)

Introduction

In two previous notebooks (Automatic Differentiation for Scalars and Automatic Differentiation for Matrices), you implemented a Node class that supports automatic differentiation.

You've modeled computational graphs using Node objects and their operators. This allowed you to numerically compute the gradients of the output with respect to all named nodes that contributed to it in the graph.

Now, we will use this autodiff class to create the layers of a neural network and obtain their gradients via autodiff. Training a network usually involves a forward pass to calculate the loss, followed by a backward pass to calculate the gradients of the loss with respect to all parameters. With autodiff, you only have to implement the forward pass; the gradients of the loss come 'for free'.

At the end of this notebook, you'll have created an architecture to construct custom Neural Network models. You'll then implement such a neural net for recognising handwritten digits in the MNIST dataset.

Requirements

Prerequisites

Knowledge

  • Knowledge of automatic differentiation and neural networks is expected.

Python Modules

# for everything
import numpy as np
import matplotlib.pyplot as plt

# for the node (autodiff) class
import operator
import uuid
import numbers
import os

# for the neural network
from deep_teaching_commons.data.fundamentals.mnist import Mnist
from collections import namedtuple

Autodiff class Node

Deep learning frameworks support automatic differentiation via backpropagation through the computational graph. You've created such a class in a previous exercise, so you should be able to use your own implementation. But for reference, a full implementation is provided below.
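The core idea of the Node class can be seen in miniature in the sketch below. It is not the provided implementation: the class Scalar is hypothetical and supports only addition and multiplication of scalars. As in Node, every node carries a grad closure that maps an upstream gradient to a dictionary of partial derivatives keyed by name, and gradients arriving along different paths are summed.

```python
class Scalar:
    """Hypothetical minimal scalar node (not the notebook's Node class)."""

    def __init__(self, value, name=None):
        self.value = float(value)
        # named leaves report their own gradient; unnamed nodes report nothing
        self.grad = (lambda g: {name: g}) if name else (lambda g: {})

    @staticmethod
    def _merge(a, b):
        # gradients that reach the same name along different paths are summed
        out = dict(a)
        for k, v in b.items():
            out[k] = out.get(k, 0.0) + v
        return out

    def __add__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        ret = Scalar(self.value + other.value)
        # d(a + b)/da = 1 and d(a + b)/db = 1
        ret.grad = lambda g: Scalar._merge(self.grad(g), other.grad(g))
        return ret

    def __mul__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        ret = Scalar(self.value * other.value)
        # d(a * b)/da = b and d(a * b)/db = a
        ret.grad = lambda g: Scalar._merge(self.grad(g * other.value),
                                           other.grad(g * self.value))
        return ret


x = Scalar(3.0, "x")
y = Scalar(4.0, "y")
f = x * y + x            # forward pass only
grads = f.grad(1.0)      # backward pass: start with df/df = 1
print(f.value, grads)    # 15.0 {'x': 5.0, 'y': 3.0}
```

Here f = xy + x, so df/dx = y + 1 = 5 and df/dy = x = 3. The contribution to x arrives along two paths and is added when the dictionaries are merged, which is the job combine_dicts does in Node.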

# merge two gradient dicts: keys present in both are combined with op (added by default)
def combine_dicts(a, b, op=operator.add):
    x = (list(a.items()) + list(b.items()) +
        [(k, op(a[k], b[k])) for k in set(b) & set(a)])
    return {x[i][0]: x[i][1] for i in range(0, len(x))}

class Node(object):
    
    NodeStore = dict()
    
    def _set_variables(self, name, value):
        self.value = value
        self.shape = value.shape
        self.dtype = value.dtype
        self.name = name
        self.uuid = uuid.uuid4()
        if name:
            self.grad = lambda g : {name : g}
            self._register()
        else: 
            self.grad = lambda g : {}

    def __init__(self, value, name=None):      
        # wrap numbers in a numpy 2D-Array
        if isinstance(value, numbers.Number):
            value = np.array([[value]])
        assert isinstance(value, np.ndarray)
        if len(value.shape)==1:
            value = value.reshape(-1,1)
        
        self._set_variables(name, value)
        # only 2D-Arrays are supported at the moment
        assert len(self.shape)==2
        
    def _register(self):
        Node.NodeStore[self.name] = self
        
    def get_param():
        param = dict()
        for n in Node.NodeStore:
            param[n] = Node.NodeStore[n].value
        return param
        
    def set_param(param):
        for n in Node.NodeStore:
            Node.NodeStore[n].value = param[n]
        
    def _broadcast_g_helper(o, o_):  # sum o_ over the axes along which o was broadcast
        if o_.shape[0] > o.shape[0]:
            o_ = o_.sum(axis=0).reshape(1,-1)
        if o_.shape[1] > o.shape[1]:
            o_ = o_.sum(axis=1).reshape(-1,1)
        return o_    
                  
    def _set_grad(self_, g_total_self, other, g_total_other):
        g_total_self = Node._broadcast_g_helper(self_, g_total_self)
        x = self_.grad(g_total_self)
        g_total_other = Node._broadcast_g_helper(other, g_total_other)
        x = combine_dicts(x, other.grad(g_total_other))
        return x
                
    def __add__(self, other):
        if isinstance(other, numbers.Number):
            other = Node(np.array([[other]]))
        ret = Node(self.value + other.value)
        def grad(g):
            g_total_self = g
            g_total_other = g
            x = Node._set_grad(self, g_total_self, other, g_total_other)
            return x
        ret.grad = grad
        return ret
    
    def __radd__(self, other):
        return Node(other) + self
    
    def __sub__(self, other):
        if isinstance(other, numbers.Number):
            other = Node(other)
        ret = self + (other * -1.)   
        return ret
    
    def __rsub__(self, other):
        if isinstance(other, numbers.Number):
            other = Node(other) 
            return other - self
        raise NotImplementedError()
        
    def __mul__(self, other):
        if isinstance(other, numbers.Number) or isinstance(other, np.ndarray):
            other = Node(other)           
        ret = Node(self.value * other.value) 
        def grad(g):
            g_total_self = g * other.value
            g_total_other = g * self.value
            x = Node._set_grad(self, g_total_self, other, g_total_other)
            return x
        ret.grad = grad
        return ret      
    
    def __rmul__(self, other):
        if isinstance(other, numbers.Number):
            return Node(other) * self
        raise NotImplementedError()
        
    def concatenate(self, other, axis=0):
        assert axis in (0,1) # TODO
        ret = Node(np.concatenate((self.value, other.value), axis=axis))
        def grad(g):
            if axis == 0: 
                g_total_self = g[:self.shape[0]] 
                g_total_other = g[self.shape[0]:]
            elif axis == 1:
                g_total_self = g[:, :self.shape[1]] 
                g_total_other = g[:, self.shape[1]:]
            x = Node._set_grad(self, g_total_self, other, g_total_other)
            return x
        ret.grad = grad
        return ret
    
    # slicing
    def __getitem__(self, val):
        raise NotImplementedError()
            
    def __truediv__(self, other):
        if isinstance(other, numbers.Number):
            other = Node(np.array([[other]]))
        ret = Node(self.value / other.value) 
        def grad(g):
            g_total_self = g / other.value
            g_total_other = -1 * self.value * g / (other.value**2)
            x = Node._set_grad(self, g_total_self, other, g_total_other)
            return x
        ret.grad = grad
        return ret
    
    def __rtruediv__(self, other):
        if isinstance(other, numbers.Number):
            other = Node(other)
            return other/self
        raise NotImplementedError()
        
    def __neg__(self):
        return self * -1.
    
    def dot(self, other):
        ret = Node(np.dot(self.value, other.value))    
        def grad(g):
            g_total_self = np.dot(g, other.value.T)
            g_total_other = np.dot(self.value.T, g)
            x = Node._set_grad(self, g_total_self, other, g_total_other)
            return x
        ret.grad = grad
        return ret
    
    def transpose(self):
        ret = Node(self.value.T)
        def grad(g):
            x = self.grad(g.T)
            return x
        ret.grad = grad
        return ret       
    
    def exp(self):
        ret = Node(np.exp(self.value))
        def grad(g):
            assert self.shape == g.shape
            x = self.grad(np.exp(self.value) * g)
            return x
        ret.grad = grad
        return ret    
    
    def log(self):
        ret = Node(np.log(self.value))
        def grad(g):
            assert self.shape == g.shape
            x = self.grad(1./self.value * g)
            return x
        ret.grad = grad
        return ret     
       
    def square(self):
        ret = Node(np.square(self.value))
        def grad(g):
            assert self.shape == g.shape
            x = self.grad(2 * self.value * g)
            return x
        ret.grad = grad
        return ret       
         
    def sqrt(self):
        ret = Node(np.sqrt(self.value))
        def grad(g):
            assert self.shape == g.shape
            x = self.grad(0.5 * (1/np.sqrt(self.value)) * g)
            return x
        ret.grad = grad
        return ret     
        
    def sum(self, axis=None):
        if axis is None:
            return self._sum_all()
        assert axis in (0,1)
        return self._sum(axis)
               
        
    def _sum_all(self):
        ret = Node(np.sum(self.value).reshape(1,1))
        def grad(g):
            x = self.grad(np.ones_like(self.value) * g)
            return x
        ret.grad = grad
        return ret
    
    def _sum(self, axis):
        ret = self.value.sum(axis=axis)
        if axis==0: 
            ret = ret.reshape(1, -1)
        else:
            ret = ret.reshape(-1, 1)
        ret = Node(ret)
        def grad(g):
            x = self.grad(np.ones_like(self.value) * g)
            return x
        ret.grad = grad
        return ret  
    
    def relu(self):
        self.mask = self.value > 0.
        ret = Node(self.mask * self.value)
        def grad(g):
            assert self.shape == g.shape
            x = self.grad(self.mask * g)
            return x
        ret.grad = grad
        return ret    
    
    def softmax(self):
        ret = self.exp() / self.exp().sum(axis=1)
        np.testing.assert_almost_equal(ret.value.sum(axis=1), 1.)
        return ret
    
    def sigmoid(self):
        # sigmoid(x) = 1 / (1 + exp(-x))
        return 1./(1. + (-self).exp())
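One way to convince yourself that gradient rules like those in Node.dot and Node._broadcast_g_helper are right is a finite-difference check. The sketch below is standalone NumPy (it does not use Node), and the helper names loss_fn, analytic_grads and numeric_grad are hypothetical. It applies the same rules by hand to an affine map X @ W + b: the gradient with respect to W is X.T @ g, and because b of shape (1, k) is broadcast over all rows, its gradient is g summed over axis 0.

```python
import numpy as np

def loss_fn(X, W, b):
    # scalar loss: sum of squares of an affine map; b is broadcast over rows
    return np.sum(np.square(X @ W + b))

def analytic_grads(X, W, b):
    g = 2.0 * (X @ W + b)               # dLoss/dout
    dW = X.T @ g                        # same rule as Node.dot: self.T @ g
    db = g.sum(axis=0, keepdims=True)   # b was broadcast over rows -> sum over axis 0
    return dW, db

def numeric_grad(f, P, eps=1e-6):
    # central differences, one entry at a time (fine for tiny arrays)
    G = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + eps
        f_plus = f()
        P[idx] = old - eps
        f_minus = f()
        P[idx] = old                    # restore the original value
        G[idx] = (f_plus - f_minus) / (2 * eps)
    return G

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
W = rng.normal(size=(3, 2))
b = rng.normal(size=(1, 2))

dW, db = analytic_grads(X, W, b)
dW_num = numeric_grad(lambda: loss_fn(X, W, b), W)
db_num = numeric_grad(lambda: loss_fn(X, W, b), b)
print(np.max(np.abs(dW - dW_num)), np.max(np.abs(db - db_num)))  # both tiny
```

The same comparison can be run against a Node graph: compare the dictionary returned by loss.grad(np.ones_like(loss.value)) with central differences of loss.value.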

Neural Node

This class is used to create individual layers of a neural network. This class is never instantiated and you won't interact with it directly. Rather, it provides a set of static functions used by the upcoming Model class.

Most significantly, the functions _Linear_Layer and _ReLu_Layer are used to create layers for a neural network. The workflow is as follows:

  1. The function _Linear_Layer or _ReLu_Layer takes a few arguments:
      • the number of input features
      • the number of output features
      • a human-readable name
      • (optional) a dictionary of parameters
  2. In the function body, it
      • initialises the weights and biases of the layer if they're not already provided
      • defines the forward function of the layer
  3. It returns two values:
      • the forward function
      • the parameters as a dictionary that maps human-readable names to values
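The pattern described above, a factory function that returns a forward function together with a parameter dictionary, can be sketched in plain NumPy. The functions make_linear and make_relu below are hypothetical analogues of _Linear_Layer and _ReLu_Layer; for brevity they use a simpler initialisation and return NumPy arrays instead of Node objects, so no gradients are tracked.

```python
import numpy as np

def make_linear(fan_in, fan_out, name, param=None):
    """Hypothetical NumPy-only analogue of NeuralNode._Linear_Layer."""
    if param is None:
        param = dict()
    w_name, b_name = name + "_weight", name + "_bias"
    if w_name not in param:
        # initialise only if the caller did not provide values
        param[w_name] = np.random.uniform(-0.1, 0.1, size=(fan_in, fan_out))
        param[b_name] = np.zeros((1, fan_out))
    # the closure reads the arrays from `param` on every call
    forward = lambda X: X @ param[w_name] + param[b_name]
    return forward, param

def make_relu(fan_in, fan_out, name, param=None):
    linear, param = make_linear(fan_in, fan_out, name, param)
    return (lambda X: np.maximum(linear(X), 0.0)), param

forward, param = make_relu(3, 1, "relu")
print(sorted(param))                   # ['relu_bias', 'relu_weight']
print(forward(np.ones((4, 3))).shape)  # (4, 1)
```

One design difference: this forward closure looks the arrays up in param on every call, whereas _Linear_Layer wraps the current values in Node objects when it is called, which is why Model re-creates the layer for every forward pass.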

class NeuralNode(object):
     
    _sep = "_"
    
    def _get_fullname(name, suffix):
        return NeuralNode._sep.join([name, suffix])
        
    def _initialize_W(fan_in, fan_out):
        gain = np.sqrt(2)
        std = gain / np.sqrt(fan_in)
        bound = np.sqrt(3) * std 
        return np.random.uniform(-bound, bound, size=(fan_in, fan_out))
        
    def _initialize_b(fan_in, fan_out):
        bound = 1 / np.sqrt(fan_in)
        return np.random.uniform(-bound, bound, size=(1, fan_out))    
        
    def _Linear_Layer(fan_in, fan_out, name=None, param=None):
        if param is None:
            param = dict()
        assert isinstance(name, str) 
        weight_name = NeuralNode._get_fullname(name, "weight")
        bias_name = NeuralNode._get_fullname(name, "bias")
        
        # reuse parameter values passed in via `param`, if any
        W_value = param.get(weight_name)
        b_value = param.get(bias_name)
        
        assert (W_value is None and b_value is None) or (W_value is not None and b_value is not None)
        if W_value is None:
            W_value = NeuralNode._initialize_W(fan_in, fan_out)
            b_value = NeuralNode._initialize_b(fan_in, fan_out)
            param[weight_name] = W_value
            param[bias_name] = b_value
            
        W = Node(W_value, weight_name)
        b = Node(b_value, bias_name) 
        return lambda X: (X.dot(W) + b), param
    
    def _ReLu_Layer(fan_in, fan_out, name=None, param=None):
        if param is None:
            param = dict()
        ll, param = NeuralNode._Linear_Layer(fan_in, fan_out, name, param)
        f = lambda X: ll(X).relu()
        return f, param
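_initialize_W draws weights from a uniform distribution in the style of He (Kaiming) initialisation for ReLU layers. A uniform distribution U(-a, a) has variance a²/3, so choosing a = sqrt(3) * std gives exactly the target variance std² = gain² / fan_in = 2 / fan_in. The short sketch below checks this empirically with the same formulas; the layer sizes are only an example.

```python
import numpy as np

fan_in, fan_out = 784, 500
gain = np.sqrt(2.0)              # gain used for ReLU layers
std = gain / np.sqrt(fan_in)     # target standard deviation
bound = np.sqrt(3.0) * std       # U(-bound, bound) has variance bound**2 / 3 = std**2

rng = np.random.default_rng(0)
W = rng.uniform(-bound, bound, size=(fan_in, fan_out))
print(W.var() * fan_in)          # should be close to gain**2 = 2
```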

You won't interact with this class directly. But for a demonstration, let's create a named ReLu layer with 3 input features and 1 output feature.

forward_fn,param = NeuralNode._ReLu_Layer(3,1,"relu")

Verify for yourself whether the shapes of the weights and bias make sense.

param
toy_data = np.fromfunction(lambda i,j : i-5, (10,3))
toy_data
out = forward_fn(Node(toy_data))
out.value
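For the demo layer, the expected shapes are W: (fan_in, fan_out) = (3, 1) and b: (1, fan_out) = (1, 1). For the 10 × 3 toy input, the output has shape (10, 1) and no negative entries because of the ReLU. The following NumPy sketch mirrors that computation without Node, using arbitrary random values in place of the layer's initialised parameters.

```python
import numpy as np

toy_data = np.fromfunction(lambda i, j: i - 5, (10, 3))  # row i holds the value i - 5
fan_in, fan_out = 3, 1
W = np.random.uniform(-1, 1, size=(fan_in, fan_out))     # (3, 1)
b = np.random.uniform(-1, 1, size=(1, fan_out))          # (1, 1), broadcast over rows
out = np.maximum(toy_data @ W + b, 0.0)                  # (10, 3) @ (3, 1) + (1, 1) -> (10, 1)
print(toy_data.shape, W.shape, b.shape, out.shape)       # (10, 3) (3, 1) (1, 1) (10, 1)
```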

Model

With NeuralNode we have a mechanism to create individual layers. Now we need a way to connect layers in a coherent network, which is exactly what the Model class does.

A model instance defines one or more layers, provides a function for the forward pass and a loss function, and can access and modify all learnable parameters across all layers.

To define the network:

  1. Create a subclass of Model.
  2. Add layers to the model using the ReLu_Layer or Linear_Layer functions.
  3. Implement the forward pass and the loss function.

To train the model:

  1. get_param(self) returns a dictionary of all parameters in the network.
  2. get_grad(self, x, y) feeds the training samples x and training labels y into the network. It returns the gradients of the loss with respect to all parameters, as well as the loss itself.
  3. set_param(self, param) sets the parameters of the model.

So, in a training loop you get the current parameters, compute the loss and gradients, calculate the new parameters with an update rule, e.g. param_new = param_old - learning_rate * gradient, and set the new parameters.
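The update rule above operates on whole parameter dictionaries. A minimal sketch, assuming the name-to-array dictionaries returned by get_param and get_grad; the function gd_update and the parameter names are hypothetical:

```python
import numpy as np

def gd_update(param, grad, learning_rate):
    """Return new parameters: param_new = param_old - learning_rate * gradient."""
    return {name: value - learning_rate * grad[name]
            for name, value in param.items()}

param = {"layer_weight": np.array([[1.0], [2.0]]), "layer_bias": np.array([[0.5]])}
grad = {"layer_weight": np.array([[10.0], [-10.0]]), "layer_bias": np.array([[1.0]])}
new_param = gd_update(param, grad, 0.01)
print(new_param)   # layer_weight -> [[0.9], [2.1]], layer_bias -> [[0.49]]
```

Building a new dictionary rather than modifying the arrays in place leaves param_old intact, which can be handy for debugging.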

class Model(object):

    # self.nodes maps layer names to (layer_type, param) tuples
    _NNode = namedtuple("NNode", ['layer_type', 'param'])

    def __init__(self):
        self.nodes = dict()

    def get_param(self):
        param = dict()
        for node_name, node_value in self.nodes.items():
            param = {**param, **node_value.param}
        return param

    def set_param(self, param):
        for node_name, node_value in self.nodes.items():
            for param_name in node_value.param:
                node_value.param[param_name] = param[param_name]

    def get_grad(self, x, y):
        loss_ = self.loss(x, y)
        g = np.ones_like(loss_.value)
        return loss_.grad(g), loss_

    def _set_layer(self, fan_in, fan_out, name, layer_type):
        assert isinstance(name, str)
        assert name not in self.nodes.keys()
        _, param = layer_type(fan_in, fan_out, name=name, param=dict())
        self.nodes[name] = Model._NNode(layer_type=layer_type, param=param)
        # rebuild the layer on every call so it uses the current values in param
        return lambda x: layer_type(fan_in, fan_out, name, param)[0](x)

    def ReLu_Layer(self, fan_in, fan_out, name=None):
        return self._set_layer(fan_in, fan_out, name, NeuralNode._ReLu_Layer)

    def Linear_Layer(self, fan_in, fan_out, name=None):
        return self._set_layer(fan_in, fan_out, name, NeuralNode._Linear_Layer)

    # The following methods must be implemented by subclasses
    def forward(self, x):
        raise NotImplementedError()

    def loss(self, x, y):
        raise NotImplementedError()

Demonstration

For a demonstration, we'll set up a model for linear regression. The network comprises a single linear layer with a single neuron. The inputs and outputs are scalar values generated from the target function $y = 2x + 1$ plus Gaussian noise with standard deviation 2.

x = np.linspace(-5,5)
noise = 2 * np.random.normal(size=x.size)
y = 2 * x + 1 + noise
plt.scatter(x,y)

Here, we define the network architecture.

class LinearRegressionNN(Model):
    def __init__(self):
        # Call the parent constructor
        super(LinearRegressionNN, self).__init__()
        # Define the layers of the network. self.layer stores
        # the forward function of that layer
        self.layer = self.Linear_Layer(1,1, "my_amazing_layer")
        
    def forward(self, x):
        if not type(x) == Node:
            x = Node(x)
        # Pipe x through the forward function of all layers
        # (in our case there's only one layer)
        out = self.layer(x)
        return out

    def loss(self, x, y):
        # sum of squared errors (the MSE without the 1/N factor)
        if not isinstance(y, Node):
            y = Node(y)
        out = self.forward(x)
        loss = (out - y).square().sum()
        return loss

Now we create a simple training loop.

lnn = LinearRegressionNN()
for epoch in range(100):
    # compute the loss and gradients
    grad,loss = lnn.get_grad(x,y)
    # get the current parameters
    param_current = lnn.get_param()
    # calculate new parameters
    param_new = { name : param_current[name] - 0.001 * grad[name]
                for name in param_current.keys()}
    # set new parameters
    lnn.set_param(param_new)
    if epoch%10 == 0:
        print('epoch: {}, loss: {}'.format(epoch, loss.value))

Verify for yourself whether the learned weight and bias match the target function.

lnn.get_param()
# Output of the model:
plt.scatter(x,y,label='data')
x_ = np.linspace(-5,5)
p = lnn.forward(x_).value
plt.plot(x_,p,'r',label='model output')
plt.legend()
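Because the model minimises the squared error, gradient descent should approach the ordinary least-squares solution, which NumPy can compute directly with np.linalg.lstsq. The sketch below regenerates data in the same way as the demo, but with its own seeded random generator, so the numbers will differ slightly from yours; to compare against lnn.get_param(), build the design matrix from the notebook's own x and y instead.

```python
import numpy as np

rng = np.random.default_rng(42)
x = np.linspace(-5, 5)                        # 50 points, as in the demo
y = 2 * x + 1 + 2 * rng.normal(size=x.size)   # same target and noise level

# least squares for y ~ w * x + b: design matrix with columns [x, 1]
A = np.column_stack([x, np.ones_like(x)])
(w, b), *_ = np.linalg.lstsq(A, y, rcond=None)
print(w, b)   # near, but not exactly, 2 and 1 because of the noise
```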

Exercises

Exercise: MNIST

In the following exercise you'll create a network to recognise handwritten digits with the MNIST dataset.

Data preparation

This cell downloads the MNIST dataset, by default into

~/deep.Teaching/data/
├── t10k-images-idx3-ubyte.gz
├── t10k-labels-idx1-ubyte.gz
├── train-images-idx3-ubyte.gz
└── train-labels-idx1-ubyte.gz

Modify data_dir to choose a different download location.

The images (originally 28x28 pixels) are flattened into 784-dimensional vectors, and the labels are one-hot encoded.
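The Mnist loader takes care of this through its flatten and one_hot_enc options. For intuition, the equivalent NumPy operations on made-up data look like this (a sketch, not the loader's internal code):

```python
import numpy as np

images = np.random.rand(5, 28, 28)      # stand-in for 5 grayscale 28x28 images
labels = np.array([3, 0, 9, 3, 1])      # stand-in digit labels

flat = images.reshape(len(images), -1)  # (5, 784): one row per image
one_hot = np.eye(10)[labels]            # (5, 10): row i has a single 1 at column labels[i]
decoded = one_hot.argmax(axis=1)        # back to [3, 0, 9, 3, 1]
restored = flat[2].reshape(28, 28)      # undo the flattening for display
print(flat.shape, one_hot.shape, decoded)
```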

data_dir = os.path.expanduser('~') + '/deep.Teaching/data/'
mnist_loader = Mnist(data_dir=data_dir)
train_images, train_labels, test_images, test_labels = mnist_loader.get_all_data(
    flatten=True,
    one_hot_enc=True,
    normalized=True
)

image = 234  # index of the image you want to see
print()
print("Picture shape",train_images[image].shape)
print("label",train_labels[image],"->",np.argmax(train_labels[image]))

pixels = train_images[image].reshape((28,28))
plt.imshow(pixels, cmap='gray')
plt.show()

You can use the function below to generate mini-batches, or implement your own.

def gen_mini_batches(x,y,size):
    """Yields mini-batches from a dataset.
    
    Parameters:
      x: samples
      y: labels
      size: desired mini-batch size
      
    Yields:
      (x_mini, y_mini): subset of samples and labels from the training set"""
    assert len(x) == len(y)
    assert len(x) % size == 0
    assert size <= len(y)
    while True:
        indices = np.random.choice(np.arange(len(y)), size=size, replace=False)
        x_mini = x[indices,:]
        y_mini = y[indices,:]
        yield x_mini, y_mini
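gen_mini_batches samples every batch independently, so over many steps some samples may be drawn several times before others are drawn at all. A common alternative is to shuffle once per epoch and slice the permutation. The generator epoch_mini_batches below is a hypothetical sketch of that approach; it yields batches in the same (x_mini, y_mini) format.

```python
import numpy as np

def epoch_mini_batches(x, y, size, rng=None):
    """Hypothetical alternative to gen_mini_batches: shuffle once per epoch
    and slice, so each sample is visited exactly once per epoch."""
    assert len(x) == len(y)
    rng = np.random.default_rng() if rng is None else rng
    while True:
        perm = rng.permutation(len(x))
        # drop an incomplete last batch so every batch has `size` rows
        for start in range(0, len(x) - size + 1, size):
            idx = perm[start:start + size]
            yield x[idx], y[idx]

x = np.arange(12).reshape(6, 2)
y = np.arange(6).reshape(6, 1)
batches = epoch_mini_batches(x, y, size=2, rng=np.random.default_rng(0))
first_epoch = [next(batches) for _ in range(3)]
seen = np.sort(np.concatenate([yb for _, yb in first_epoch]).ravel())
print(seen)   # [0 1 2 3 4 5]: every label appears exactly once in the first epoch
```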

Exercise: Define the model

Define the network architecture of the model.

  • Define three layers:

    1. Input Layer: ReLu (input: 784 features, output: 500 features)
    2. Hidden Layer: ReLu (input: 500 features, output: 200 features)
    3. Hidden Layer: ReLu (input: 200 features, output: 10 features)
  • In the forward function, compute the forward pass through all layers. At the end, apply the softmax function.
  • In the loss function, compute the cross-entropy loss. If you haven't modified the data preparation cell, the target y is one-hot encoded.
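For reference, the softmax is $p_{ik} = \exp(z_{ik}) / \sum_j \exp(z_{ij})$ and the cross-entropy summed over the batch is $L = -\sum_i \sum_k y_{ik} \log p_{ik}$. The NumPy sketch below (hypothetical helper names, not the Node-based solution) can be used to cross-check the value your Net.loss returns if you also sum over the batch. It subtracts the row maximum before exponentiating, which does not change the result but avoids overflow; Node.softmax in this notebook does not do this, so very large logits can overflow there.

```python
import numpy as np

def softmax(z):
    # subtracting the row maximum leaves the result unchanged but avoids overflow in exp
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(p, y_one_hot, eps=1e-12):
    # summed over the batch; eps keeps log() finite if a probability underflows to 0
    return -np.sum(y_one_hot * np.log(p + eps))

logits = np.array([[1000.0, 0.0, -1000.0],
                   [0.0, 0.0, 0.0]])
p = softmax(logits)
y = np.eye(3)[[0, 2]]          # one-hot targets: class 0 and class 2
loss = cross_entropy(p, y)
print(p.sum(axis=1), loss)     # rows sum to 1; loss is about log(3)
```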
nn = NeuralNode
nn.Model = Model

class Net(nn.Model):
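    def __init__(self):
        # Call the parent constructor first (it creates self.nodes)
        super(Net, self).__init__()
        # Create layers
        self.hidden_1 = None

    def forward(self, x):
        if not type(x) == Node:
            x = Node(x)
        # Implement the forward pass through all
        # layers, then apply softmax
        # (Node has no len(), so use x.shape[0] for the batch size)
        out = np.random.random(size=(x.shape[0], 10))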
        raise NotImplementedError()
        return out
        
    def loss(self, x, y):
        y = Node(y)
        # Implement cross-entropy loss
        loss = None
        raise NotImplementedError()
        return loss
nn = Net()
loss = nn.loss(np.random.randn(1,784), np.eye(10)[2].reshape(1, -1))  # one-hot target for class 2, shape (1, 10)
loss.value

Exercise: Optimizer

Now we want to train our neural network. For this purpose, we'll implement a function that performs mini-batch gradient descent. Its inputs are the network we want to train, the training data and the hyperparameters.

Remember: use get_param and set_param to read and update the parameters of the model.
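A generic version of such a loop can be sketched in plain NumPy. In the sketch below, minibatch_sgd, linreg_grad and the toy data are hypothetical; grad_fn plays the role that get_grad plays for a Model, returning a gradient dictionary and the loss for one mini-batch. It is not a solution to the exercise, which should use your Net and its get_param/set_param methods.

```python
import numpy as np

def minibatch_sgd(param, grad_fn, x, y, steps, lrate, batch_size, rng):
    """Hypothetical generic loop: grad_fn(param, xb, yb) -> (grad_dict, loss)."""
    losses = []
    for _ in range(steps):
        idx = rng.choice(len(x), size=batch_size, replace=False)  # random mini-batch
        grad, loss = grad_fn(param, x[idx], y[idx])
        param = {k: v - lrate * grad[k] for k, v in param.items()}
        losses.append(loss)
    return param, losses

def linreg_grad(param, xb, yb):
    # mean squared error of y ~ x @ w + b and its analytic gradients
    err = xb @ param["w"] + param["b"] - yb
    n = len(xb)
    grad = {"w": 2 * xb.T @ err / n, "b": 2 * err.sum(axis=0, keepdims=True) / n}
    return grad, float(np.mean(err ** 2))

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 3))
y = x @ np.array([[1.0], [-2.0], [0.5]]) + 0.3
param = {"w": np.zeros((3, 1)), "b": np.zeros((1, 1))}
param, losses = minibatch_sgd(param, linreg_grad, x, y, steps=500,
                              lrate=0.1, batch_size=32, rng=rng)
print(param["w"].ravel(), param["b"].ravel(), losses[0], losses[-1])
```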

def sgd(network,x,y,steps,lrate,mini_batch_size,print_every=50):
    for i in range(steps):
        # draw a random mini batch
        # compute the derivatives of the loss function
        # get the parameters of the model
        # compute new parameters
        # set new parameters
        # print the loss
        pass
net = Net()
sgd(net,
         x=train_images,
         y=train_labels,
         steps=1000,
         lrate=0.001,
         mini_batch_size=100,
         print_every=100)

Test the accuracy of your model on the test data.

true = test_labels.argmax(axis=1) # decode one-hot vector
pred = net.forward(test_images).value.argmax(axis=1)
accuracy = np.mean(true == pred)
accuracy

Exercise: Refactoring Optimizer

Now we want to refactor our gradient descent function into a more general Optimizer class. An optimizer is initialised with a model, training data and hyperparameters. It implements functions to train the model.

To use it, subclass Optimizer with a concrete class that represents an optimization technique. You can scroll a few cells down to see the subclass SGD for stochastic gradient descent in action.

Task: You'll find a skeleton for the Optimizer class below.

  • Implement the random_batch method
  • Implement the _train method
class Optimizer(object):
    
    def __init__(self, model, x_train=None, y_train=None, hyperparam=dict(), batch_size=128):
        self.model = model
        self.x_train = x_train
        self.y_train = y_train
        self.batch_size=batch_size
        self.hyperparam = hyperparam
        self._set_param()
        self.grad_stores = [] # list of dicts for momentum, etc. 

    def _train(self, steps=1000, num_grad_stores=0, print_each=100):
        # Initialise grad_store - stores dicts of learnable
        # parameters of the model
        assert num_grad_stores in (0,1,2)
        model = self.model
        if num_grad_stores>0:
            x, y = self.random_batch()
            grad, loss = model.get_grad(x, y)
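            # one independent dict per store (num_grad_stores * [dict()] would alias a single dict)
            self.grad_stores = [dict() for _ in range(num_grad_stores)]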
        for grad_store in self.grad_stores:
            for g in grad:
                grad_store[g] = np.zeros_like(grad[g])
        
        # TODO: get old parameters 
        # for each step...
        for i in range(1, steps+1):
            pass
            # TODO: generate a random batch of train data     
            # TODO: calculate the loss and gradients
            # TODO: print the loss
            # TODO: calculate updated parameters. The update rule is defined in self._update
            # TODO: set new parameters
        
        return loss.value
    
    def random_batch(self):
        # Returns a random batch of training data. The size of the batch
        # is given by self.batch_size
        raise NotImplementedError()
    
    #######################################################
    # The following methods are implemented by subclasses #
    #######################################################
    def train(self, steps=1000, print_each=100):
        # Call the inherited _train function with default parameters
        raise NotImplementedError()
        
    def _update(self, param, grad, g, i):
        # Performs an update for a single parameter.
        # param is a dictionary of all named parameters. g is the name of the
        # parameter to update.
        #
        # param : dict(str, np.ndarray); a dictionary of named parameters
        # grad : dict(str, np.ndarray); gradients of all named parameters (as used by SGD._update)
        # g : str; name of the parameter, used as a key in `param` and `grad`
        # i : int; the step in the training loop in which the update takes place
        raise NotImplementedError() 
        
    def _set_param(self):
        # set hyperparameters as class members
        # e.g. self.alpha = self.hyperparam.get('alpha',0.9)
        pass
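The grad_stores exist for update rules that keep state between steps, such as momentum (one store) or Adam (two stores). The stripped-down sketch below illustrates the template-method idea behind Optimizer with a hypothetical TinyOptimizer base class and a classical momentum subclass, minimising a simple quadratic instead of a network loss. Note that each store must be a separate dictionary.

```python
import numpy as np

class TinyOptimizer:
    """Hypothetical, stripped-down analogue of Optimizer: the base class owns
    the loop, subclasses only define _update (template-method pattern)."""

    def __init__(self, param, grad_fn, num_grad_stores=0):
        self.param = param
        self.grad_fn = grad_fn
        # one independent dict per store; `n * [dict()]` would alias one dict n times
        self.grad_stores = [{k: np.zeros_like(v) for k, v in param.items()}
                            for _ in range(num_grad_stores)]

    def train(self, steps):
        for i in range(1, steps + 1):
            grad = self.grad_fn(self.param)
            for name in self.param:
                self._update(self.param, grad, name, i)
        return self.param

    def _update(self, param, grad, g, i):
        raise NotImplementedError()

class Momentum(TinyOptimizer):
    """Classical momentum: v <- beta * v + grad; param <- param - alpha * v."""

    def __init__(self, param, grad_fn, alpha=0.1, beta=0.9):
        super().__init__(param, grad_fn, num_grad_stores=1)
        self.alpha, self.beta = alpha, beta

    def _update(self, param, grad, g, i):
        v = self.grad_stores[0]
        v[g] = self.beta * v[g] + grad[g]
        param[g] = param[g] - self.alpha * v[g]

# minimise f(w) = sum(w**2), whose gradient is 2 * w
opt = Momentum({"w": np.array([[3.0, -4.0]])}, lambda p: {"w": 2 * p["w"]})
w_final = opt.train(steps=200)["w"]
print(w_final)   # close to [[0, 0]]
```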

Now we'll subclass Optimizer and implement stochastic gradient descent. Then we'll train a MNIST model using the refactored code.

class SGD(Optimizer):
    
    def __init__(self, model, x_train=None, y_train=None, hyperparam=dict(), batch_size=128):
        super(SGD, self).__init__(model, x_train, y_train, hyperparam, batch_size)
        
    def _set_param(self):
        self.alpha = self.hyperparam.get("alpha", 0.001)

    def _update(self, param, grad, g, i):
        param[g] -= self.alpha * grad[g]    
             
    def train(self, steps=1000, print_each=100):
        return self._train(steps, num_grad_stores=1, print_each=print_each)
net = Net()
optimizer = SGD(net,
         x_train=train_images,
         y_train=train_labels,
         hyperparam=({ 'alpha' : 0.001}),
         batch_size=100
)

# A learning rate that is too high causes numerical instability in the
# cross-entropy, so stay below 0.005 when testing the code.

optimizer.train(steps=1000)

Test the accuracy on the test set.

true = test_labels.argmax(axis=1) # decode one-hot vector
pred = net.forward(test_images).value.argmax(axis=1)
accuracy = np.mean(true == pred)
accuracy

Summary and Outlook

You've now created an architecture to create custom neural networks. Where do you go from here?

  • Right now your model only offers linear and ReLU layers; you could extend it with more types of layers.
  • You could refactor parameter initialisation similarly to how you refactored the training process. Right now, every parameter is initialised with a fixed method. You could implement different Initializers and let each layer control how it initialises its parameters.
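As a starting point for the second idea, a hypothetical Initializer hierarchy could look like the sketch below. The class names and the init_linear function are illustrative only; HeUniform reproduces the rule in NeuralNode._initialize_W, while Zeros is a deliberately different choice for biases (the notebook's _initialize_b draws from a uniform distribution).

```python
import numpy as np

class Initializer:
    """Hypothetical base class: maps a shape and fan_in to an initial array."""
    def __call__(self, shape, fan_in, rng):
        raise NotImplementedError()

class HeUniform(Initializer):
    # same rule as NeuralNode._initialize_W: U(-a, a) with a = sqrt(3) * sqrt(2) / sqrt(fan_in)
    def __call__(self, shape, fan_in, rng):
        bound = np.sqrt(3.0) * np.sqrt(2.0) / np.sqrt(fan_in)
        return rng.uniform(-bound, bound, size=shape)

class Zeros(Initializer):
    def __call__(self, shape, fan_in, rng):
        return np.zeros(shape)

def init_linear(fan_in, fan_out, name, weight_init=HeUniform(), bias_init=Zeros(), rng=None):
    # each layer decides how its parameters are initialised
    rng = np.random.default_rng() if rng is None else rng
    return {name + "_weight": weight_init((fan_in, fan_out), fan_in, rng),
            name + "_bias": bias_init((1, fan_out), fan_in, rng)}

param = init_linear(784, 500, "hidden_1", rng=np.random.default_rng(0))
print({k: v.shape for k, v in param.items()})
```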


Licenses

Notebook License (CC-BY-SA 4.0)

The following license applies to the complete notebook, including code cells. It does however not apply to any referenced external media (e.g., images).

Constructing Neural Networks with Automatic Differentiation
by Christian Herta, Diyar Oktay
is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.
Based on a work at https://gitlab.com/deep.TEACHING.

Code License (MIT)

The following license only applies to code cells of the notebook.

Copyright 2019 Christian Herta, Diyar Oktay

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.