Exercise - Learning Reber Grammar - Recurrent Neural Network with Pytorch

Introduction

This notebook aims to provide an introduction to Recurrent Neural Networks (RNNs).

First, we'll look at the so-called Reber grammar: a simple artificial grammar, which is easy to understand and requires relatively little computational effort.

Then we will go through the theory of RNNs step by step, while at the same time constructing the network.

At the end of the notebook you should have a basic understanding of vanilla RNNs. We won't be dealing with Long Short-Term Memory (LSTM) cells, the more advanced RNN cells.

Requirements

Knowledge

  • Pytorch
  • Basic Feed-Forward-Networks

Python Modules

import numpy as np
import torch
import matplotlib.pyplot as plt

The Reber Grammar

The Reber grammar can be used to generate artificial data for the evaluation of recurrent neural networks.

A valid string can be generated by the following automaton.

[Figure: state automaton of the Reber grammar (image requires an internet connection)]

Example: BTSSXXTVVE is a valid string in the grammar.

Sidenote: The Reber grammar was first introduced in 1967 in [REB67]. Its purpose was to study implicit learning of grammatical rules.

The following code allows us to generate random valid grammar strings:

chars='BTSXPVE'

graph = [[(1, 5), ('T', 'P')], [(1, 2), ('S', 'X')],
         [(3, 5), ('S', 'X')], [(6,), ('E',)],
         [(3, 2), ('V', 'P')], [(4, 5), ('V', 'T')]]


def in_grammar(word):
    if word[0] != 'B':
        return False
    node = 0
    for c in word[1:]:
        if node == 6: # the end state has no outgoing transitions
            return False
        transitions = graph[node]
        try:
            node = transitions[0][transitions[1].index(c)]
        except ValueError: # using exceptions for flow control in python is common
            return False
    return node == 6 # a valid string has to end in the final state, i.e. with 'E'
      
def sequenceToWord(sequence):
    """
    converts a sequence (one-hot) into a Reber string
    """
    reberString = ''
    for s in sequence:
        index = np.where(s==1.)[0][0]
        reberString += chars[index]
    return reberString
    
def generateSequences(minLength):
    while True:
        inchars = ['B']
        node = 0
        while node != 6:
            transitions = graph[node]
            i = np.random.randint(0, len(transitions[0]))
            inchars.append(transitions[1][i])
            node = transitions[0][i]
        if len(inchars) > minLength:  
            return inchars

def get_one_example(minLength):
    inchars = generateSequences(minLength)
    inseq = []
    for c in inchars:
        inpt = np.zeros(7)
        inpt[chars.find(c)] = 1.
        inseq.append(inpt)
    outseq = inseq[1:]
    inseq = inseq[0:-1]
    return inseq, outseq

The cell below shows how to use the code. get_one_example(minLength=10) produces two sequences (the input sequence and the target sequence) in one-hot encoding:

  • for example B will be encoded as vector [1., 0., 0., 0., 0., 0., 0.]
  • E will be encoded as vector [0., 0., 0., 0., 0., 0., 1.]

The function sequenceToWord can be used to transform the one-hot encoding back into characters.

Finally, the function in_grammar checks whether a string is a valid Reber grammar string.

in_seq, out_seq = get_one_example(minLength=10)

print('Input sequence starting at char index 0 until the second-to-last:\n')
print(in_seq)
print('-----------------------------------')
print('Input sequence as characters:')
print(sequenceToWord(in_seq))
print('-----------------------------------')
print('Target sequence starting at char index 1 until the last:')
print(sequenceToWord(out_seq))
print('-----------------------------------')
print('Validate a string:')
### Append 'E' to the input sequence, as the char at the last index is omitted
print(in_grammar(sequenceToWord(in_seq)+'E'))
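
As a quick illustration, we can trace the example string BTSSXXTVVE through the automaton node by node (the node numbers are simply the indices into graph):

### trace the example string through the automaton node by node
word = 'BTSSXXTVVE'
node = 0
for c in word[1:]:
    next_node = graph[node][0][graph[node][1].index(c)]
    print('node', node, '--' + c + '-->', 'node', next_node)
    node = next_node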

RNN

The main concept of an RNN is that not only the input of the current time step $ t $ is fed into the network, but also the hidden state of the previous time step $ t-1 $. In the example picture below, our input data at every time step has 3 features. These features are fully connected to 4 hidden neurons. These 4 hidden neurons are again fully connected to 3 output values.

It makes sense that the numbers of input and output neurons (3 in the picture) match: knowing the current character, we want to predict the next character in the sequence.

[Figure: RNN with 3 input features, 4 hidden neurons and 3 output values (image requires an internet connection)]

Since our input characters are encoded as vectors of length 7, we have 7 input features n_in.

Our hidden layer will have 20 neurons n_hid.

We want to predict the next letter, based on the previous letters, so our output will also be a vector of length 7 n_out.

n_in  = 7 
n_hid = 20
n_out = 7

dtypeTorch = torch.float32
dtypeNumpy = np.float32

RNNs are very prone to the vanishing / exploding gradient problem, so it is crucial to initialize the weights the right way. The following function samples weights and divides them by their largest singular value, so that repeatedly multiplying with these matrices does not make the values grow too large.

def sample_weights(sizeX, sizeY, verbose=False):
    values = np.ndarray([sizeX, sizeY], dtype=dtypeNumpy)
    for dx in range(sizeX):
        vals = np.random.uniform(low=-1., high=1.,  size=(sizeY,))
        values[dx,:] = vals
    _,svs,_ = np.linalg.svd(values)
    #svs[0] is the largest singular value
    if verbose: print('Sum of weights before normalization: ', values.sum())
    values = values / svs[0]
    if verbose: print('Sum of weights after normalization: ', values.sum())
    return values
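
As a quick check (using an arbitrary 5x5 example), the largest singular value of a sampled matrix should be 1 after the normalization:

### the largest singular value of the normalized weights should be (close to) 1
w = sample_weights(5, 5)
print(np.linalg.svd(w)[1][0])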

The following function initializes the weights and biases. Below we pass the argument verbose=True, which solely serves to demonstrate the normalization.

def get_parameter(n_in, n_out, n_hid, verbose=False):
    b_h = torch.tensor(np.zeros(n_hid), dtype=dtypeTorch, requires_grad=True)
    W_ih = torch.tensor(sample_weights(n_in, n_hid, verbose), dtype=dtypeTorch, requires_grad=True)
    W_hh = torch.tensor(sample_weights(n_hid, n_hid, verbose), dtype=dtypeTorch, requires_grad=True)
    W_ho = torch.tensor(sample_weights(n_hid, n_out, verbose), dtype=dtypeTorch, requires_grad=True)
    b_o = torch.tensor(np.zeros(n_out), dtype=dtypeTorch, requires_grad=True)
    return W_ih, W_hh, b_h, W_ho, b_o

W_ih, W_hh, b_h, W_ho, b_o = get_parameter(n_in, n_out, n_hid, verbose=True)   
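
A quick look at the parameter shapes (just an illustrative check) shows how the matrices are oriented, which matters for the matrix multiplications in the task below:

### check the shapes of the parameters
print(W_ih.shape, W_hh.shape, b_h.shape, W_ho.shape, b_o.shape)
### expected: torch.Size([7, 20]) torch.Size([20, 20]) torch.Size([20]) torch.Size([20, 7]) torch.Size([7])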

Task:

Based on the graphic for the RNN above and the formulas below, implement the one_step() function:

  • The hidden state $ h_t $ at time step $ t $ consists of:

    • the input features $ x_t $ at time step $ t $, which are fully connected through the weight matrix $ W_{ih} $,
    • the hidden state values $ h_{t-1} $ of the previous time step, which are fully connected through the weight matrix $ W_{hh} $,
    • and a bias term $ b_h $,
    • all fed into an activation function, e.g. $ \tanh $:

$ h_t = \tanh( W_{ih} x_{t} + W_{hh} h_{t-1} + b_{h} ) $

  • The output then consists of the hidden state values $ h_{t} $, fully connected through the weight matrix $ W_{ho} $, with the bias $ b_o $ added, and finally fed into the logistic function $ \sigma $ to produce values between $ 0 $ and $ 1 $ for every output value:

$ y_t = \sigma( W_{ho} h_{t} + b_{o} ) $

Return the prediction $ y_t $ for the current step and the hidden state $ h_t $, so we can use $ h_t $ in the next step at time $ t+1 $.

Sidenote:

We could use the fact that only one of our output values is on (1), while the others are off (0), and use the $ softmax $ function. But for simplicity, we'll just use the logistic function.

# input at time step t: x_t
# prior hidden state at time step t-1: h_tm1
# weights and biases: W_ih, W_hh, b_h, W_ho, b_o
def one_step(x_t, h_tm1, W_ih, W_hh, b_h, W_ho, b_o):

    raise NotImplementedError()
    return [h_t, y_t]
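
If you want to compare your solution, one possible implementation following the formulas above could look like this (a sketch; note that with the shapes produced by get_parameter the inputs are multiplied from the left):

### a possible implementation of one_step(), directly following the formulas above;
### note that W_ih has shape (n_in, n_hid), so we multiply x_t from the left
def one_step(x_t, h_tm1, W_ih, W_hh, b_h, W_ho, b_o):
    ### hidden state: h_t = tanh(W_ih x_t + W_hh h_tm1 + b_h)
    h_t = torch.tanh(x_t @ W_ih + h_tm1 @ W_hh + b_h)
    ### output: y_t = sigma(W_ho h_t + b_o)
    y_t = torch.sigmoid(h_t @ W_ho + b_o)
    return [h_t, y_t]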

Ordinary binary cross entropy loss:

def bce(y_pred, y):
    return - torch.mean( y * torch.log(y_pred) + (1. - y) * torch.log(1.- y_pred))
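
As a quick sanity check (with arbitrary example values), our bce should agree with PyTorch's built-in binary cross entropy:

### compare our implementation with torch's built-in binary cross entropy
y_pred_example = torch.tensor([0.9, 0.2, 0.6])
y_true_example = torch.tensor([1., 0., 1.])
print(bce(y_pred_example, y_true_example))
print(torch.nn.functional.binary_cross_entropy(y_pred_example, y_true_example))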

Start the training:

W_ih, W_hh, b_h, W_ho, b_o = get_parameter(n_in, n_out, n_hid)   

lr = 0.2
costs = []

for e in range(1000):
    
    ### generate new sequence each epoch
    in_seq, out_seq = get_one_example(minLength=10)
    ### initial hidden state for the time step before t=0 (which does not exist)
    h_tm1 = torch.tensor(np.zeros(n_hid), dtype=dtypeTorch)
    
    ### we accumulate costs over the whole sequence starting with 0
    cost = 0

    for c in range(len(in_seq)):
        x_t = torch.tensor(in_seq[c], dtype=dtypeTorch)
        y_t_true = torch.tensor(out_seq[c], dtype=dtypeTorch)
        h_tm1, y_t_pred = one_step(x_t, h_tm1, W_ih, W_hh, b_h, W_ho, b_o)
        
        ### we accumulate costs over the whole sequence
        cost += bce(y_t_pred, y_t_true)
        
    ### backward when sequence completed
    cost.backward()      
    costs.append(cost.detach().numpy())

    with torch.no_grad():
        W_ih -= W_ih.grad * lr
        W_hh -= W_hh.grad * lr
        b_h -= b_h.grad * lr
        W_ho -= W_ho.grad * lr
        b_o -= b_o.grad * lr

        W_ih.grad.zero_()
        W_hh.grad.zero_()
        b_h.grad.zero_()
        W_ho.grad.zero_()
        b_o.grad.zero_()    

Now plot the costs by executing the cell below. The costs should look like the following:

[Figure: example plot of the costs over the training epochs (image requires an internet connection)]

Costs are still high at the end, since almost everywhere in the grammar graph there are 2 possibilities for the next character. So even if the network decides on one of the 2 valid characters, there is roughly a 50 % chance that its choice does not align with the ground truth data and is therefore treated as wrong. Nevertheless, we can see a downwards trend.

### costs are still high since almost everywhere in the grammar graph we have 2 possibilities,
### but we can see a downwards trend
plt.plot(np.linspace(0,len(costs),len(costs)), costs)
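
To get a rough feeling for how low the cost can realistically go, consider a single ambiguous step: a model that has learned the grammar perfectly can at best put 0.5 on each of the two valid next characters, which with our loss corresponds to about $ 2 \ln(2) / 7 \approx 0.2 $ per time step (a back-of-the-envelope estimate, not an exact bound):

### rough per-step floor of the cost at an ambiguous node: the true next char gets p=0.5
### (-log 0.5), the other valid char also gets p=0.5 (-log 0.5), the remaining 5 outputs
### contribute ~0, and the result is averaged over the 7 outputs
per_step_floor = 2 * np.log(2) / 7
print(per_step_floor)                # ~0.198
print(per_step_floor * len(in_seq))  # rough floor for the accumulated cost of the last sequence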

Besides looking at the graph of the costs, there is another way to check whether our model has learned the grammar, namely by sampling. This means we let the network generate outputs, given a starting sequence:

We start by feeding the starting character B at time step 0 to the network, which yields scores for the next character.

Instead of taking the character with the highest score, we treat the scores as a probability distribution and sample a character from it.

If we always took the character with the highest score, our RNN would always generate the same sequence given the starting sequence B.

Task:

Write some code to sample a sequence:

  1. Start by feeding the one-hot-encoded vector for character B and an initial hidden state into the network.
  2. This will yield $ y_t $, a vector of length 7 with each element in the range 0 to 1, and the hidden state $ h_t $.
  3. In order to treat these values as probabilities, normalize them so they sum up to 1.0.
  4. Sample from it and turn it into a character.
  5. Then start over, this time feeding $ y_t $ and the new hidden state into the network.
  6. Loop until the character E is generated.

Remember to keep track of the generated characters. Use the functions sequenceToWord and in_grammar to validate the string. A possible sketch is given after the code cell below.

With the network used and the hyperparameters proposed, your network should produce around 8-9 valid Reber grammar strings out of 10.

Hint:

The code is almost the same as in the cell used to train the network, with some differences, e.g.:

  • do not use ground truth data to feed into the network
  • no optimization of the weights
### Your code here
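
One possible way to implement the sampling (a sketch that assumes the one_step() implementation above; it feeds the sampled character back in as a one-hot vector, which is one possible choice):

### sample one string from the trained network
def sample_string(max_len=50):
    with torch.no_grad():
        ### start with the one-hot vector for 'B' and a zero hidden state
        x_t = np.zeros(7, dtype=dtypeNumpy)
        x_t[chars.find('B')] = 1.
        h_t = torch.tensor(np.zeros(n_hid), dtype=dtypeTorch)
        sequence = [x_t]
        for _ in range(max_len):
            h_t, y_t = one_step(torch.tensor(x_t, dtype=dtypeTorch), h_t,
                                W_ih, W_hh, b_h, W_ho, b_o)
            ### normalize the outputs so they sum to 1 and sample a character index
            probs = y_t.numpy().astype(np.float64)
            probs = probs / probs.sum()
            index = np.random.choice(len(chars), p=probs)
            ### feed the sampled character back in as a one-hot vector
            x_t = np.zeros(7, dtype=dtypeNumpy)
            x_t[index] = 1.
            sequence.append(x_t)
            if chars[index] == 'E':
                break
    return sequence

### check how many of 10 sampled strings are valid
n_valid = 0
for _ in range(10):
    word = sequenceToWord(sample_string())
    print(word, in_grammar(word))
    n_valid += in_grammar(word)
print(n_valid, 'of 10 strings are valid')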

Summary and Outlook

TODO

Literature

[REB67] Reber, A. S. (1967). Implicit learning of artificial grammars. Journal of Verbal Learning and Verbal Behavior, 6, 855-863.

Licenses

Notebook License (CC-BY-SA 4.0)

The following license applies to the complete notebook, including code cells. It does however not apply to any referenced external media (e.g., images).

Exercise - Learning Reber Grammar - Recurrent Neural Network with Pytorch
by Christian Herta, Klaus Strohmenger
is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.
Based on a work at https://gitlab.com/deep.TEACHING.

Code License (MIT)

The following license only applies to code cells of the notebook.

Copyright 2019 Exercise - Learning Reber Grammar - Recurrent Neural Network with Pytorch

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.