In His Shoes Part II - GPT2 and Transformers
What is a Transformer?
The Transformer is a type of neural network designed to work with sequential data. As of December 2020, it is considered state-of-the-art for problems such as text translation and summarisation. Introduced in the paper Attention Is All You Need, it improves on the previous top models (LSTMs, GRUs) by allowing computation to be parallelised (hence reducing training times) and by better handling long-term sequential dependencies.
Below is a summary table that compares neural network based NLP frameworks along with some fantastic resources that can give more detailed insight into these concepts.
Framework | Principle | Advantages | Disadvantages | Resources |
---|---|---|---|---|
RNN | RNNs are neural networks designed to take in sequential data (i.e. time series or language data) with no requirement for a fixed input size. Whereas standard neural networks only use previously seen data at the training stage to influence predictions, RNNs can also use previously seen data at the inference stage. | Able to use information from preceding data to make a decision on the current data. | Vanishing (or exploding) gradient problems mean information 'too far into the past' is lost. Slow to train because the computation is difficult to parallelise. | [1] [2] [3] [4] |
LSTM / GRU | A type of RNN that is better able to remember long-term dependencies by allowing historical information to 'pass' through the neurons in the network, typically via 'input' and 'forget' gates. | Improves the 'memory' of RNNs. | Still struggles with very long sequences (~1K tokens) and is slow to train. | [1] [2] |
Transformers | A different architecture style to LSTMs / GRUs whereby a sequence is passed through the model in one 'batch' (typically called a context vector) rather than token by token, and the model learns which parts of that batch are most important for further learning. This concept of deciding what is important is called Attention (see the sketch below the table). | Faster to train as computation can be parallelised, and long-term dependencies are handled far better. | Memory and compute scale quadratically with sequence length. | [1] [2] [3] |
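To make the Attention idea from the table a little more concrete, here is a minimal NumPy sketch of scaled dot-product attention, the core operation inside a Transformer. It is purely illustrative (the random query/key/value matrices and their dimensions are made up for the example), not code from the paper or from this project.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # how much each query 'attends' to each key
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the keys
    return weights @ V                              # weighted sum of the values

# toy example: a 'context' of 4 tokens, each embedded in 8 dimensions
rng = np.random.default_rng(0)
Q = K = V = rng.normal(size=(4, 8))
print(scaled_dot_product_attention(Q, K, V).shape)  # (4, 8)
```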
Problem and Modelling Design
The problem here is the same as in the first part of this series, namely to autocomplete a sentence in the style of a certain persona; in this case, the personas are journalists from different BBC News sections.
My previous modelling approach was to explore generating sentence forecasts using LSTMs. Although LSTMs rose back into popularity around 2015 and powered consumer products like Alexa and Google Assistant, the models I trained did not achieve results as good as I expected. I'll concede that this is in part due to a lack of significant hyperparameter tuning, but there seemed to be an upper limit on performance due to fundamental features of the model (see the model comparison table above).
After further exploration in the NLP space, I discovered the Transformer, a model that in the past few years has demonstrated significantly better results than LSTMs and GRUs. Better still, there are many open-source variants of the Transformer that have been pretrained on large datasets, lending themselves to transfer learning tasks. At the time of writing, the most readily available state-of-the-art models are those created by OpenAI in the GPT-X family; the latest model, GPT-3, has not yet been released, however there are a few available implementations of its predecessor, GPT-2. Some other popular flavours of Transformer include BERT and T5, although I haven't had time to explore these so far.
The Python implementation I chose to go with is the one in the HuggingFace library. It has a whole heap of NLP models encoded as PyTorch or TensorFlow objects, with excellent documentation, strong community support and a fairly simple API.
Load in model
The GPT-2 model (and every other NLP model in the HuggingFace library) has a matching tokenizer object. The tokenizer object performs the same text manipulation steps that were used when the GPT-2 model was originally trained. You can download the pretrained model and load it into memory with the following code.
```python
from transformers import GPT2LMHeadModel, TFGPT2LMHeadModel, GPT2TokenizerFast, \
    OpenAIGPTLMHeadModel, TFOpenAIGPTLMHeadModel, OpenAIGPTTokenizer, \
    T5Model, TFT5Model, T5Tokenizer, GPT2Config, TFGPT2Model


class TransformerLoader:
    """
    Load pretrained models, deals with checkpointing etc.
    At the minute, can be a function.
    """

    @classmethod
    def from_huggingface(cls, model_name, framework='tf'):
        """Load a pretrained tokenizer and model pair from the HuggingFace hub."""
        TOKENIZERS = {
            'gpt': OpenAIGPTTokenizer,
            'gpt2': GPT2TokenizerFast,
            't5': T5Tokenizer
        }
        MODELS = {
            'tf': {
                'gpt': TFOpenAIGPTLMHeadModel,
                'gpt2': TFGPT2LMHeadModel,
                't5': TFT5Model,
            },
            'pt': {
                'gpt': OpenAIGPTLMHeadModel,
                'gpt2': GPT2LMHeadModel,
                't5': T5Model,
            }
        }

        if model_name not in MODELS[framework].keys():
            raise NotImplementedError('Model not imported')

        # load tokenizer
        tokenizer_class = TOKENIZERS[model_name]
        tokenizer = tokenizer_class.from_pretrained(model_name, add_prefix_space=True)
        tokenizer.pad_token = tokenizer.eos_token

        # load model
        model_class = MODELS[framework][model_name]
        model = model_class.from_pretrained(model_name)

        return tokenizer, model


model_name = 'gpt2'
tokenizer, model = TransformerLoader.from_huggingface(model_name, framework='tf')
```
Note that I am loading `TFGPT2LMHeadModel` rather than `GPT2LMHeadModel`. The difference is that `TFGPT2LMHeadModel` is a TensorFlow model and the other is a PyTorch model. You can use these objects directly with their respective frameworks and associated libraries rather than relying solely on HuggingFace's own APIs.
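As a rough illustration (not code from this project), the two classes wrap the same pretrained weights for different frameworks; the seed text here is made up:

```python
from transformers import GPT2LMHeadModel, TFGPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

tf_model = TFGPT2LMHeadModel.from_pretrained('gpt2')  # a tf.keras.Model subclass
pt_model = GPT2LMHeadModel.from_pretrained('gpt2')    # a torch.nn.Module subclass

# the TensorFlow object slots into Keras workflows...
tf_out = tf_model(tokenizer('hello world', return_tensors='tf'))
# ...while the PyTorch object slots into native PyTorch training loops
pt_out = pt_model(**tokenizer('hello world', return_tensors='pt'))

print(tf_out[0].shape, pt_out[0].shape)  # both return logits of shape (1, sequence_length, vocab_size)
```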
Also notice that I am using the `TFGPT2LMHeadModel` object rather than the `TFGPT2DoubleHeadsModel` object; this is because the latter is typically used in problems where you want to train a network with two distinct outputs. For example, the image below demonstrates the dynamics of a double-headed model where the 'backbone' network is the feature map between the input and the penultimate layer of the network, and the two heads perform 1. a classification task of labelling the image, and 2. a regression task of predicting the bounding box of a section of the image. The loss function in a multi-headed model is typically the weighted sum of the individual losses for each head. In general, a multi-headed model is a network design with multiple output layers, allowing different loss functions to be computed, and hence different tasks to be learnt with mostly the same network weights.
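To make this concrete, here is a small Keras sketch of a double-headed model: a shared backbone with a classification head and a bounding-box regression head, compiled with a weighted sum of the two losses. It is a toy illustration (the layer sizes, names and loss weights are arbitrary), not the GPT-2 architecture.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(64, 64, 3))

# shared 'backbone' that both heads reuse
x = tf.keras.layers.Conv2D(16, 3, activation='relu')(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# head 1: classify the image
label_head = tf.keras.layers.Dense(10, activation='softmax', name='label')(x)

# head 2: regress a bounding box (x, y, w, h)
bbox_head = tf.keras.layers.Dense(4, name='bbox')(x)

toy_model = tf.keras.Model(inputs, [label_head, bbox_head])

# the total loss is a weighted sum of the per-head losses
toy_model.compile(
    optimizer='adam',
    loss={'label': 'sparse_categorical_crossentropy', 'bbox': 'mse'},
    loss_weights={'label': 1.0, 'bbox': 0.5},
)
```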
Load and Tokenise Text
I load the text and tokenise with the following snippet of code.
```python
text = DataReader.read_bbc_tech()
text = ' '.join(sentence for sentence in text)  # join into one large string

tokenized_text = tokenizer.encode(text, return_tensors='tf')
tokenized_text = tokenized_text[0]
```
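As a quick sanity check (not part of the original snippet), you can peek at the first few token IDs and decode them back into text:

```python
print(tokenized_text[:10])                    # the first ten token IDs
print(tokenizer.decode(tokenized_text[:10]))  # the text those IDs represent
```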
Shape Input
I create feature and label tokens with the following code, using a sequence length of 10 and setting the labels to be the input features shifted by one token. Unlike in the previous post, I do not pass a sliding window over the training set. In all of the other posts where I've seen GPT-2 trained, no one has used the sliding-window approach on the training data, and I also suspect that the sliding window was contributing significantly to the overfitting problem.
```python
# split into non-overlapping chunks
seq_length = 11

features = []
labels = []  # labels are the features shifted by 1 (next-token prediction)
examples = []

for i in range(0, len(tokenized_text) - seq_length + 1, seq_length):
    examples.append(tokenized_text[i:i + seq_length])

for ex in examples:
    features.append(ex[:-1])
    labels.append(ex[1:])
```
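For contrast, a sliding-window version (roughly what the previous post did, sketched here from memory) would step one token at a time instead of `seq_length` tokens, producing many heavily overlapping, near-duplicate examples:

```python
# sliding window: one example per starting position (stride of 1 rather than seq_length)
windowed_examples = []
for i in range(0, len(tokenized_text) - seq_length + 1):
    windowed_examples.append(tokenized_text[i:i + seq_length])

# roughly seq_length times more examples, almost all of them overlapping
print(len(windowed_examples), len(examples))
```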
I then create train and validation sets using the TensorFlow dataset object.
```python
import math

import tensorflow as tf

BATCH_SIZE = 12
BUFFER_SIZE = len(features)

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).shuffle(BUFFER_SIZE)

train_size = math.ceil(len(features) * 0.8)  # 80 / 20 split
train_dataset = dataset.take(train_size).batch(BATCH_SIZE, drop_remainder=True)
val_dataset = dataset.skip(train_size).batch(BATCH_SIZE, drop_remainder=True)
```
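It is worth peeking at a single batch to confirm the shapes are what we expect (a quick check, not in the original code):

```python
example_features, example_labels = next(iter(train_dataset))
print(example_features.shape, example_labels.shape)  # (12, 10) each: (BATCH_SIZE, seq_length - 1)
```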
I also noticed that when `BATCH_SIZE` was anything other than 12, I kept receiving a shape error at the training stage. I'm not entirely sure what causes this, but I think it has something to do with the loss function I compile with. Note that in this setup `BATCH_SIZE = model.config.n_layer = 12`.
Training and Evaluation
I then compile with a sparse categorical cross-entropy loss applied to the language-model logits, padding the loss list with a `None` for each of the network's other outputs. As noted earlier, I found this loss setup quite tricky to get working and relied on the help of the author of this article for advice.
```python
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# defining our optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)

# defining our loss function
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# defining the metric we want to observe
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

# compiling the model: one real loss for the logits, None for the model's other outputs
model.compile(optimizer=optimizer,
              loss=[loss, *[None] * model.config.n_layer],
              metrics=[metric])
```
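My understanding (worth double-checking against the HuggingFace docs) is that the loss list has this shape because the TF model returns the language-model logits plus the cached key/value states for each of its `model.config.n_layer` transformer blocks, and Keras expects one loss entry per output. A quick inspection along these lines shows the structure:

```python
# inspect the model outputs: the logits come first, then one set of cached states per layer
dummy_features, _ = next(iter(train_dataset))
outputs = model(dummy_features)

print(outputs[0].shape)      # (12, 10, 50257): logits over the vocabulary
print(len(outputs[1]))       # 12 cached key/value tensors, one per transformer block
print(model.config.n_layer)  # 12
```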
Then I can finally begin training. As in the previous article, training is done on Google Colab, where I can use Google's GPUs and TPUs for much quicker training and inference.
```python
num_epochs = 50
model_filepath = './bbc-politics'

callbacks = [
    EarlyStopping(monitor='val_accuracy', patience=25),
    ModelCheckpoint(model_filepath, save_best_only=True,
                    save_weights_only=False, monitor='val_accuracy')
]

history = model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=callbacks
)
```
In the end, the model results are phenomenal compared to the LSTM method; in fact, I suspect they are too good, and I am cautious in case some sort of data leakage is occurring. The training curves after 25 epochs are shown below.
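One quick way to probe the leakage suspicion (a rough check of my own, not something from the training code) is to count token chunks that appear verbatim in both splits; news corpora often repeat boilerplate sentences, which would inflate validation accuracy:

```python
# count feature chunks that appear verbatim in both the train and validation splits
train_chunks = {tuple(feat.numpy()) for feats, _ in train_dataset for feat in feats}
val_chunks = {tuple(feat.numpy()) for feats, _ in val_dataset for feat in feats}
print(f'{len(train_chunks & val_chunks)} duplicated chunks across splits')
```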
The table below compares the metrics the fine-tuned GPT-2 model achieves on a few different input texts with the results from the previous post.
Dataset | Random Guess | Conditional Guess (n=3) | LSTM | GPT-2 |
---|---|---|---|---|
BBC Technology | 0.01 | 0.29 | 0.38 | 0.87 |
BBC Politics | 0.001 | 0.17 | 0.19 | 0.74 |
BBC Business | 0.001 | 0.09 | 0.15 | 0.59 |
It is clear that GPT-2 does miles better, and this can also be seen when you inspect the actual predictions of the model.
Generated Output vs. Previous Method
The following snippet of code allows you to generate sequences based on some seed tokens. I apply this function to the same seed sequences as in the previous post and compare the results of the new model.
```python
def generate_sequence(text, tokenizer, model):
    # encode the input text
    input_ids = tokenizer.encode(text, return_tensors='tf')

    # generate output token sequences
    output = model.generate(
        input_ids,
        max_length=50,           # maximum number of tokens/sub-words to output
        temperature=0.7,         # how innovative the result will be (higher temperature = more innovative)
        do_sample=True,
        no_repeat_ngram_size=2,  # do not allow repetitions of n-grams above this size in the generated text
        num_return_sequences=5   # number of output sequences to generate
    )

    # decode (only) the first generated sequence
    text = tokenizer.decode(output[0])

    # return the output generated from the sample input text
    return text
```
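For example, calling it with a made-up seed fragment looks like this (only the first of the five generated sequences is returned):

```python
seed = "the current generation of mobiles using flash technology can"
print(generate_sequence(seed, tokenizer, model))
```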
BBC Tech - Mobile music challenges 'iPod age'
Actual Sequence
and the current generation of mobiles using flash technology can store up to one gigabyte of music enough for 250 songs we are working in the hard disk area and

LSTM Generated Sequence
and the current generation of mobiles using flash technology can store up to one gigabyte of music enough for 250 songs we are working in the hard disk area and

Out-of-the-Box GPT-2 Generated Sequence
and the current generation of mobiles using flash technology can be more effective. The question is, will flash work for other mobile types? And do other brands of devices that have already been designed for this type of use have a similar market

Fine Tuned GPT-2 Generated Sequence
and the current generation of mobiles using flash technology can overcome these barriers. "Compression technology has the potential to do much more but one thing," says Gavin Patterson, senior manager at Symantec's mobile network business unit. This is because it
BBC Politics - Clarke defends terror detentions
Actual Sequence
the new year in the meantime we will be studying the judgment carefully to see whether it is possible to modify our legislation to address the concerns raised by the

LSTM Generated Sequence
the new year in the meantime we will be studying the announcement on his spending plans on the same after a meeting on labour's media media lord woolf for labour's

Out-of-the-Box GPT-2 Generated Sequence
the new year in the meantime we will be studying he progress of the economy in various ways. We have taken decisions on the issues of capital controls and the budget. We have made decisions to improve working conditions. The government has now taken action

Fine Tuned GPT-2 Generated Sequence
the new year in the meantime we will be studying this carefully to see if it is sufficient to justify the move." He added: "However, in any case any decision we make is subject to further action by the Scottish Executive." Scottish ministers who want
BBC Politics - Age work law could 'fail' over 65s
Actual Sequence
retirement under 65 employers will no longer be able to force workers to retire before 65 unless they can justify it the government has announced that firms will be barred

LSTM Generated Sequence
retirement under 65 employers will no longer be able to advise for the companies said mr rubinsohn is due to delivering improvements and the us's biggest manufacturers could be above

Out-of-the-Box GPT-2 Generated Sequence
retirement under 65 employers will no longer be able to offer the benefit. "The government is looking at ways to expand the range of benefits, and many of them will be available to all employees," said Ms. Lacey, an economist at

Fine Tuned GPT-2 Generated Sequence
retirement under 65 employers will no longer be able to their advantage. The changes will apply to employers who have more than 50 employees and will help introduce a fairness to retirement policies. Employers will now be required
As testament to how good the fine-tuned GPT-2 model is, I have swapped the real sentence found in the original article with the sentence predicted by the GPT-2 model in the examples above; see if you can spot the swap.
Another point to make is that although the out-of-the-box GPT-2 model already works well at generating sentences that make grammatical sense, the fine-tuned model generates samples that are much more in the style of a news article.
As always, I encourage you to check out the code here and follow up if you have any questions.