11 February, 2024

Easier to Understand: What is a Transformer? How does GPT work?

This article delves into the Transformer model, explaining its significance in natural language processing and how it forms the backbone of large language models like GPT and LLaMA.


    In the previous installment of our "Easier to Understand" series, we explored the techniques used to process natural language before Transformers were invented. In this article, we'll continue our "journey through time" to understand how the field of natural language processing developed, as well as how large language models like GPT or LLaMA work.

    What are Transformers?

    Transformers are a neural network structure introduced in the 2017 research paper "Attention Is All You Need". This structure solves many of the problems that RNNs faced (as discussed in the previous article).

    There are many models based on the transformer structure, the most famous being BERT and, more recently, the GPT and LLaMA families. In this article, we'll use GPT and LLaMA as examples, so the transformer structure introduced here will be the decoder-only type.

    GPT stands for Generative Pre-trained Transformer:

    • Generative: the model can create new content; for example, if you input a question, it can generate an answer.
    • Pre-trained: the model is first trained on a very large amount of unlabeled text, mostly to learn the language itself, before being taught any specific task.
    • Transformer: (as explained above)

    Note: As this article is aimed at newcomers, I'll omit some overly complex details in my explanations, such as encoder-decoder, softmax, masked attention, etc.

    Next token prediction

    Transformers have many different applications, but in this article, we'll focus on next token prediction. To understand what a token is, you can read the previous installment of this series.

    ChatGPT actually operates on the principle of next token prediction. This means if you provide an incomplete piece of text, for example: "Today it's raining, I don't want to...", GPT will complete it, perhaps with something like: "...go to school". Of course, there are many other ways to complete it, such as "go to work", "go outside", etc., but the important thing is that the completion must be consistent with the preceding part.

    To complete accurately, the model needs to understand the connection between words in the sentence, and more importantly, the overall context of the text.
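    To make the loop concrete, here is a minimal Python sketch of next token prediction. The lookup table `TOY_NEXT_TOKEN` and the functions `predict_next_token` and `generate` are hypothetical names made up for illustration: a real model scores every token in its vocabulary using the whole text so far, while this toy only looks at the last token. What it does show is the loop itself: predict one token, append it, and feed the longer text back in.

```python
# Hypothetical toy "model": maps the last token to the most likely next token.
# A real model scores every token in its vocabulary using the whole context.
TOY_NEXT_TOKEN = {
    "The": "cat",
    "cat": "sat",
    "sat": "on",
    "on": "the",
    "the": "mat",
}


def predict_next_token(tokens):
    """Return the predicted next token, or None if the toy model has no guess."""
    return TOY_NEXT_TOKEN.get(tokens[-1])


def generate(prompt_tokens, max_new_tokens=10):
    """Next token prediction in a loop: predict one token, append it, repeat."""
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        next_token = predict_next_token(tokens)
        if next_token is None:  # the toy model has nothing more to add
            break
        tokens.append(next_token)
    return tokens


print(" ".join(generate(["The", "cat", "sat"])))  # The cat sat on the mat
```

    Starting from "The", "cat", "sat", the toy completes the sentence one token at a time and stops when it has no prediction left, or after `max_new_tokens` steps.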

    For ChatGPT, you can roughly think of it this way: the model is trained to learn that "after a question comes an answer". Therefore, when you input a question like "How are you?", it will respond with something like "I'm fine, thank you". If you read OpenAI's blog post from when they first released InstructGPT (the predecessor of ChatGPT), you'll see that without this training, GPT-3 often ignored the instruction and simply continued the text in unhelpful or irrelevant ways. This additional "teaching" process is called fine-tuning: further training the pre-trained model on a smaller, targeted dataset to teach it a specific behavior (e.g., how to answer questions) rather than additional knowledge.

    How we read

    "A very interesting thing about natural language is that not every word in a sentence is equally important."

    The above sentence can be shortened to:

    "Interesting about natural language, not every word is equally important."

    So, our problem is how to design a model that can "measure" the importance of each word in a sentence. The goal is not to shorten sentences as above, but to know which information the model needs to pay attention to in order to generate the next word (in other words, if you ask a question, the model needs to know which words in the question to focus on to generate the answer).

    To achieve this, transformers use a technique called "self-attention", which means finding the importance of a word in relation to other words within the same text (hence the "self" in "self-attention").

    In the example above, the clause starting with "that" describes the "interesting thing" mentioned at the beginning of the sentence, so when the model reads the word "that", the weight given to the word "thing" should be higher than the weight given to most other words.

    Measuring this "level of importance" is the core, and also the most complex part of the Transformer.

    Query-Key-Value

    To calculate the "level of importance" of any word in a sentence compared to the other words, transformers use matrices for calculation. I won't go into mathematical details (because I'm not that good at math either, hehe), but in summary, we need to calculate 3 matrices serving 3 different purposes:

    • The Query matrix is like the question you type into Google.
    • The Key matrix is like a webpage's title, which Google compares with your question to find a match.
    • The Value matrix is the actual content of the page: the information that answers your question.
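    In practice, each token's vector is multiplied by three different weight matrices to get its Query, Key and Value. The sketch below assumes tiny, made-up weights (`W_Q`, `W_K`, `W_V`) and a hypothetical 4-number vector for "cat"; real models use vectors with thousands of numbers and weights learned during training.

```python
def vec_mat(vector, matrix):
    """Multiply a row vector (length d_in) by a matrix (d_in rows x d_out columns)."""
    return [
        sum(v * row[j] for v, row in zip(vector, matrix))
        for j in range(len(matrix[0]))
    ]


# Hypothetical weight matrices (4-number token vectors -> 2-number Q/K/V).
# In a real model these are learned during training.
W_Q = [[1, 0], [0, 1], [1, 1], [0, 0]]
W_K = [[0, 1], [1, 0], [0, 0], [1, 1]]
W_V = [[1, 1], [0, 0], [0, 1], [1, 0]]


def compute_qkv(token_vector):
    """Turn one token's vector into its Query, Key and Value."""
    query = vec_mat(token_vector, W_Q)  # what this token is "searching for"
    key = vec_mat(token_vector, W_K)    # what this token offers, like a page title
    value = vec_mat(token_vector, W_V)  # the content this token passes on
    return query, key, value


x_cat = [1.0, 2.0, 0.0, 3.0]  # hypothetical vector for the token "cat"
q, k, v = compute_qkv(x_cat)
print(q, k, v)  # [1.0, 2.0] [5.0, 4.0] [4.0, 1.0]
```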

    In the example below, let's assume I've entered 3 tokens, "The", "cat" and "sat", and want the model to predict the 4th token, "on". The complete sentence we want is "The cat sat on the mat".

    When you input a token into the model, what happens is:

    • The Query-Key-Value (Q-K-V) matrices are calculated for that token.
    • We take the Query of the current token and multiply it with the Key of every token so far (including the current one). This multiplication tells us how "related" the current token is to each of those tokens.
    • After these results are normalized so that they add up to 1, we get the "level of importance" coefficient of each token.

    Below, let's assume we're at the third token, "sat":

    1. We calculate Q-K-V for the third token
    2. Then multiply Q of the third token with K of tokens 1, 2, and 3

    In the example above, after multiplying Q with K, we find that the word "sat" is the most important (with a coefficient of 0.9), because the phrase "sat on" is so common.
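    The step of multiplying Q with K can be sketched in a few lines of Python. Two assumptions go beyond the text above: the vectors are made-up numbers chosen so that "sat" wins, and the raw scores are scaled by the square root of the vector size and normalized with softmax (the step this article otherwise skips), as in the original paper, so that the coefficients are positive and add up to 1. The function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    """Turn raw scores into positive weights that add up to 1."""
    biggest = max(scores)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - biggest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def importance(query, keys):
    """Compare the current Query with the Key of every token so far."""
    scale = math.sqrt(len(query))  # scaling used in "Attention Is All You Need"
    return softmax([dot(query, key) / scale for key in keys])


# Hypothetical vectors: the Query of "sat" and the Keys of "The", "cat", "sat".
q_sat = [1.0, 2.0]
keys = {"The": [0.1, 0.0], "cat": [0.5, 0.5], "sat": [1.0, 2.0]}
weights = importance(q_sat, list(keys.values()))
for token, weight in zip(keys, weights):
    print(f"{token}: {weight:.2f}")  # The: 0.03, cat: 0.08, sat: 0.90
```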

    Next, we multiply the Value vector of each token by its "level of importance" coefficient, then sum them all up. This process is called a "weighted sum". The final result is a single vector containing information about the content of everything that came before:
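    The weighted sum itself is just multiply-and-add. In this sketch, the coefficients and Value vectors are hypothetical numbers, and `weighted_sum` is a name chosen for illustration:

```python
def weighted_sum(weights, values):
    """Scale each token's Value by its importance coefficient, then add them up."""
    size = len(values[0])
    return [sum(w * value[i] for w, value in zip(weights, values)) for i in range(size)]


# Hypothetical coefficients for "The", "cat", "sat" and their Value vectors.
weights = [0.03, 0.07, 0.9]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
context = weighted_sum(weights, values)
print(context)  # one vector summarizing everything read so far
```

    Because "sat" has by far the largest coefficient, the result ends up close to the Value vector of "sat".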

    Finally, the resulting vector is passed through a feed-forward neural network and then projected onto the model's vocabulary: every possible token receives a score, and the token with the highest score (or one sampled according to these scores) becomes the next token:
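    Here is a minimal sketch of that last step, assuming a hypothetical four-token vocabulary (`VOCAB`) and made-up output weights (`W_OUT`). It scores every token, converts the scores into probabilities with softmax, and greedily picks the most likely one; real systems often sample from these probabilities instead of always taking the top token. In a real transformer, the feed-forward network sits inside every layer, and this projection onto the vocabulary happens once at the very end.

```python
import math

VOCAB = ["on", "the", "mat", "dog"]  # hypothetical 4-token vocabulary

# Hypothetical output weights: one row per vector dimension, one column per token.
W_OUT = [
    [2.0, 0.5, 0.1, -1.0],
    [1.5, 0.2, 0.3, -0.5],
]


def pick_next_token(vector):
    """Score every vocabulary token, turn scores into probabilities, take the best."""
    scores = [
        sum(x * W_OUT[i][j] for i, x in enumerate(vector))
        for j in range(len(VOCAB))
    ]
    biggest = max(scores)
    exps = [math.exp(s - biggest) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(VOCAB)), key=lambda j: probs[j])  # greedy choice
    return VOCAB[best], probs[best]


token, prob = pick_next_token([1.83, 1.87])
print(token, round(prob, 2))
```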

    We can repeat all the above processes to "fill in" the next word after "on", for example "The cat sat on the mat". Of course, what to fill in completely depends on the dataset that was used to train the model.

    The remaining question is: how to calculate the Query-Key-Value matrices?

    In reality, the model learns how to calculate these matrices during training. More precisely, it learns the weight matrices that turn each token's vector into its Query, Key and Value. You can roughly understand that how well these weights are learned directly affects the "intelligence" of the model. The more it learns, the higher the IQ - just like humans.

    Positional encoding

    At this point, we have another problem: the weighted sum above only multiplies and adds, and addition is commutative (a + b = b + a), so the order of the tokens is lost. From the machine's perspective, "the dog eats" and "eats the dog" are no different (that's a bad joke 💀).

    So before inputting the token to calculate Q-K-V, we need to "merge" the position information together with the token. You can roughly understand that after the token has been transformed into a vector (as described in the previous article), this vector will be added to a vector specifically describing the position of that token:

    In the example above, at the bottom layer (output after addition), I've written out the original words for easier understanding. In reality, this output is a vector, ready to be fed into the Q-K-V calculation.
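    One concrete way to build the position vector is the sinusoidal formula from "Attention Is All You Need", sketched below. The function names and the 4-number "dog" vector are hypothetical, and newer models such as LLaMA use a different method (RoPE).

```python
import math


def position_vector(pos, size):
    """Sinusoidal position vector from "Attention Is All You Need":
    even dimensions use sin, odd dimensions use cos, each pair at its own frequency."""
    vec = []
    for i in range(size):
        pair = i - i % 2  # the "2i" in the paper's formula
        angle = pos / (10000 ** (pair / size))
        vec.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return vec


def add_position(token_vector, pos):
    """Merge position information into a token vector by element-wise addition."""
    return [t + p for t, p in zip(token_vector, position_vector(pos, len(token_vector)))]


dog = [0.5, -0.2, 0.1, 0.3]  # hypothetical vector for "dog"
print(add_position(dog, 0))  # "dog" as the 1st token
print(add_position(dog, 2))  # "dog" as the 3rd token: a different vector
```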

    There are many methods to create the position vector. At the time of writing this article, open-source models like LLaMA, Mistral, Phi-2, etc., all use Rotary Positional Embedding (RoPE). As GPT-3's weights and code were never released, I don't have official information about the exact method OpenAI uses (the earlier GPT-2, whose code is public, uses learned position vectors).
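    RoPE works differently from the addition described above: instead of adding a position vector to the token, it rotates each pair of numbers in the Query and Key by an angle that grows with the token's position. The sketch below follows the pairing used in the original RoPE (RoFormer) formulation; libraries differ in how they pair up dimensions, and the Query/Key numbers here are made up. The check at the end shows the useful property: the score between two tokens depends only on how far apart they are.

```python
import math


def rope(vec, pos, base=10000.0):
    """Rotary Positional Embedding in its paired form: dimensions (0, 1), (2, 3), ...
    are each rotated by the angle pos * base**(-i / d)."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]  # 2D rotation of this pair
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 0.5, -0.3, 0.8]  # hypothetical Query
k = [0.2, -1.0, 0.7, 0.4]  # hypothetical Key
# Query at position 5 and Key at position 3, versus positions 12 and 10:
# the distance is 2 in both cases, so the two scores match.
print(dot(rope(q, 5), rope(k, 3)), dot(rope(q, 12), rope(k, 10)))
```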

    Application in practice

    In real models, everything explained above is repeated many times over, for example:

    The model doesn't compute just one set of Q-K-V, but many sets at the same time, so that different "sets" can "filter" different information from the same piece of text. This technique is called multi-head attention, and each set is called a "head". For example, LLaMA 2 7B uses 32 heads in each layer.
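    A minimal sketch of multi-head attention: every vector is split into `n_heads` equal slices, each slice runs its own attention, and the results are joined back together. This is a simplification: real models give each head its own learned projections and mix the joined result with one more weight matrix, both left out here. The vectors and function names are hypothetical.

```python
import math


def softmax(scores):
    biggest = max(scores)
    exps = [math.exp(s - biggest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def single_head(query, keys, values):
    """Attention for one head: score each Key, normalize, weighted sum of Values."""
    scale = math.sqrt(len(query))
    weights = softmax([sum(a * b for a, b in zip(query, key)) / scale for key in keys])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


def multi_head(query, keys, values, n_heads):
    """Split every vector into n_heads slices, attend per slice, then concatenate."""
    d = len(query)
    if d % n_heads != 0:
        raise ValueError("vector size must be divisible by the number of heads")
    size = d // n_heads
    out = []
    for h in range(n_heads):
        part = slice(h * size, (h + 1) * size)  # this head's share of each vector
        out += single_head(query[part], [k[part] for k in keys], [v[part] for v in values])
    return out


# Hypothetical 4-number vectors for 3 tokens, split across 2 heads.
q = [1.0, 0.0, 0.0, 1.0]
keys = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
values = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]
print(multi_head(q, keys, values, n_heads=2))
```

    Each head sees the same tokens but only its own slice of the numbers, so each head can end up weighting the tokens differently.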

    Not only that, in reality these attention blocks are stacked on top of each other (the output of one block is the input of the next), and each block in the stack is called a layer; LLaMA 2 7B, for instance, has 32 of them. Below, I asked the model "The capital of Vietnam is..." The first layers tend to "separate" grammatical information, while the later layers "separate" more abstract information (like city, Vietnam, etc.). The following image was obtained by modifying the source code of llama.cpp:
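    Stacking layers is simply a loop where each layer's output becomes the next layer's input. The sketch below is heavily simplified and assumes things the article does not cover: Q, K and V are taken to be the token vectors themselves (no learned weights), each position only looks at itself and earlier positions, the feed-forward part of a real layer is omitted, and a residual connection adds each layer's result back onto its input, as real transformers do. The function names and numbers are hypothetical.

```python
import math


def attention_layer(xs):
    """One simplified attention layer. Each position looks at itself and all
    earlier positions; Q = K = V = the input vector (no learned weights)."""
    out = []
    for t, q in enumerate(xs):
        visible = xs[: t + 1]  # only tokens up to and including position t
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in visible]
        biggest = max(scores)
        exps = [math.exp(s - biggest) for s in scores]
        total = sum(exps)
        mixed = [sum(e / total * v[i] for e, v in zip(exps, visible)) for i in range(len(q))]
        # Residual connection: add the layer's result on top of its input.
        out.append([a + b for a, b in zip(q, mixed)])
    return out


def run_layers(xs, n_layers):
    """Stack layers: the output of one layer becomes the input of the next."""
    for _ in range(n_layers):
        xs = attention_layer(xs)
    return xs


tokens = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]  # hypothetical token vectors
print(run_layers(tokens, n_layers=3))
```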

    In addition, you can view a 3D representation here to better understand the structure of a typical transformer model: https://bbycroft.net/llm

    Comparison with Recurrent neural network

    In the previous article, I mentioned the Recurrent neural network (RNN) model structure. Although Transformers are more difficult to understand, they solve several of the RNN's major problems:

    • Multiple tokens can be input at the same time, increasing the efficiency of parallel processing (e.g., if you have multiple GPUs running in parallel). This is possible because the calculation of Q-K-V matrices can be performed in parallel, and the positional embedding of each token can be calculated independently.
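    • Computing attention involves many matrix multiplications, but the Keys and Values of earlier tokens never change once computed, so they can be stored and reused. With this K-V cache technique, each new token only needs its own Q-K-V to be computed, instead of recomputing them for the whole text. The cost of each new token still grows with the length of the text (its Query has to be compared with every cached Key), but far more slowly than without the cache, even at the 100th, 1000th, or 10000th token.
    • While an RNN "condenses" the meaning of a sentence into a single vector, a Transformer keeps the vectors of every token in its context window in memory. You use more RAM, but no information within that window is lost.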
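    The K-V cache idea can be checked with a small experiment. With the made-up weights below (`W_Q`, `W_K`, `W_V`), both functions produce exactly the same outputs, but the version without a cache recomputes every earlier token's Key and Value at every step, while the cached version computes each one only once. The function names are hypothetical, and the counter only tracks Key/Value projections.

```python
import math

# Hypothetical projection weights for 2-number token vectors.
W_Q = [[1.0, 0.0], [0.0, 1.0]]
W_K = [[0.5, 1.0], [1.0, 0.5]]
W_V = [[1.0, 2.0], [3.0, 4.0]]


def project(x, w):
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]


def attend(q, keys, values):
    """Scaled dot-product attention of one Query over the given Keys and Values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    biggest = max(scores)
    exps = [math.exp(s - biggest) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[i] for e, v in zip(exps, values)) for i in range(len(values[0]))]


def run_without_cache(embeddings):
    """At every step, recompute K and V for all tokens seen so far."""
    outputs, projections = [], 0
    for t in range(len(embeddings)):
        seen = embeddings[: t + 1]
        keys = [project(x, W_K) for x in seen]
        values = [project(x, W_V) for x in seen]
        projections += 2 * len(seen)
        outputs.append(attend(project(embeddings[t], W_Q), keys, values))
    return outputs, projections


def run_with_cache(embeddings):
    """Compute K and V once per token and reuse them from the cache afterwards."""
    outputs, projections = [], 0
    key_cache, value_cache = [], []
    for x in embeddings:
        key_cache.append(project(x, W_K))
        value_cache.append(project(x, W_V))
        projections += 2
        outputs.append(attend(project(x, W_Q), key_cache, value_cache))
    return outputs, projections


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
out_a, n_a = run_without_cache(tokens)
out_b, n_b = run_with_cache(tokens)
print(out_a == out_b, n_a, n_b)  # True 20 8
```

    For 4 tokens that is 20 projections versus 8, and the gap widens as the text grows. Note that `attend` still loops over every cached Key, which is why the cost of each new token still grows with the length of the text.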
