Neural Networks Intuitions: 19. Transformers
Hello All!
It is great to be back again! This time I am really excited, because I get to write about a concept I have long been waiting to cover, and the most transformational one (in the literal sense ;)) in the AI space over the past 6–7 years: the Transformers.
Yes, we will learn about the concept of Attention in detail, and about what makes a Transformer so effective at learning rich representations from data, which help in various downstream tasks ranging from Computer Vision to NLP.
To make things more intuitive (and less complex), I will first discuss the fundamentals of the Attention mechanism in the context of a Computer Vision problem, and then touch upon NLP :)
Most importantly, I will approach Transformers from the perspective of a general-purpose algorithm, or differentiable computer: why Transformers are so general, why they can be applied to all types of tasks, be it Computer Vision or NLP, and why “a transformer can be viewed as a more general case of MLPs whereas convolutions and recurrences are special cases of MLPs.”
Problem Statement — Image Classification
Consider the problem of image classification. We are all well aware of the neural network architectures widely used for vision tasks: Convolutional Neural Networks and Multilayer Perceptrons. Given the right amount of data, both work exceedingly well for Computer Vision problems.
MLPs process the entire image as input. ConvNets focus on local patches of the input image, and are in fact a special case of MLPs. Either way, the network learns representations specific to every class from the training data.
Now let’s say we want to solve this specific vision problem with a newer neural network architecture that possesses the following properties:
- the network doesn’t assume local patches (i.e. have that inductive bias) as ConvNets do. Instead, it is a much more general form of MLP. Rather than having static (predefined) connections between input neurons, it learns these connections on its own from the data.
- the network learns which patches/regions in an image are highly correlated with each other, i.e. it is a network that can learn the semantics (of images).
- it learns to integrate these local correlations globally to uniquely learn embeddings/representations that can classify various objects/patterns.
*Inductive bias/priors are certain assumptions made in the algorithm to effectively learn the target function.
Let us come up with an architecture that is able to perform all of the above :-)
Solution:
A. Preprocessing
- The first step is to split the image into a sequence of patches, each of size 16x16 or 14x14, and flatten every patch into a 1D vector. Assume there are N patches.
- Pass every patch in the sequence through a linear projection (or FC) layer to generate the input embeddings. The weights of this layer are learned.
- To preserve the location/position of these local patches, add another embedding to the above input embedding. This one is also learnable, and it uniquely represents each patch’s position in the input image. This gives us the final, refined input embedding.
*A patch is essentially a token in the input, hence used interchangeably.
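The preprocessing steps above can be sketched in a few lines of NumPy. This is a minimal illustration under assumed toy sizes: a 224x224 RGB image, 16x16 patches (so N = 14 x 14 = 196) and an embedding width of 64. The projection and position embeddings are random stand-ins here, whereas in a real model they would be learned.

```python
import numpy as np

rng = np.random.default_rng(0)

def patchify(image, patch_size):
    """Split an (H, W, C) image into N flattened patches of length P*P*C."""
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by the patch size"
    # (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C) -> (N, P*P*C), patches in row-major order
    patches = image.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    return patches.reshape(-1, P * P * C)

# Assumed toy sizes, not taken from the article.
image = rng.standard_normal((224, 224, 3))
patches = patchify(image, 16)                  # (196, 768): N = 14 * 14 patches
d_model = 64

# Random stand-ins for parameters that would be learned during training.
W_proj = rng.standard_normal((patches.shape[1], d_model)) * 0.02   # linear projection (FC layer)
pos_emb = rng.standard_normal((patches.shape[0], d_model)) * 0.02  # one embedding per patch position

E = patches @ W_proj + pos_emb                 # final input embeddings, (196, 64)
print(patches.shape, E.shape)                  # (196, 768) (196, 64)
```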
B. Learning correlations among the sequence of patches
Now that we have a sequence of patch embeddings (with position information encoded as well), how do we enable these patches to communicate with each other? They need to pass on their information/features and learn patterns such as semantics and class-specific information.
- Take a token embedding and pass it through three weight matrices to transform it into three different embeddings. For simplicity, think of them as square matrices; in practice, each attention head usually projects to a smaller dimension.
- The first two weight matrices transform the input token embedding into two embeddings, and the dot product between such embeddings measures similarity. Repeating this for all tokens in the input gives 2N embedding vectors.
- Compute the dot product of every embedding produced by the first weight matrix with every embedding produced by the second. This produces an N x N matrix whose values represent the similarity of every token with every other token. We call it the attention pattern (or map).
- The third weight matrix transforms the same input token embedding into a third embedding. Weight these third embeddings by the attention scores computed above and sum them over all tokens. This produces a new embedding (delta E, best thought of as a change), which, when added to the original input embedding, produces a much more refined embedding.
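Here is a minimal NumPy sketch of those four steps for a single head, assuming 8 tokens of width 16 and random stand-in weights. One detail is borrowed from the Attention formula discussed later in this post: the raw scores are divided by the square root of the key dimension and passed through a softmax, so each row of the attention pattern sums to 1. The names `W_q`, `W_k` and `W_v` are chosen just for this sketch.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)      # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_head(E, W_q, W_k, W_v):
    """One self-attention head over N token embeddings E of shape (N, d)."""
    Q = E @ W_q                      # first embedding per token
    K = E @ W_k                      # second embedding per token
    V = E @ W_v                      # third embedding per token
    scores = Q @ K.T                 # (N, N): every token compared with every other token
    A = softmax(scores / np.sqrt(K.shape[-1]))   # attention pattern; row i sums to 1
    delta_E = A @ V                  # row i: V vectors of all tokens, weighted by row i of A
    return E + delta_E, A            # refined embedding = original embedding + change

rng = np.random.default_rng(0)
N, d = 8, 16                         # 8 patches, as in the figure; d is an arbitrary choice
E = rng.standard_normal((N, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

E_refined, A = self_attention_head(E, W_q, W_k, W_v)
print(E_refined.shape, A.shape)      # (8, 16) (8, 8)
```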
Before going any further, let us first understand why we use these three different weight matrices and what kind of learning is enabled through them.
The first two weight matrices learn token embeddings such that, for tokens (or sets of tokens) that are highly related to each other, the resulting embeddings are similar, i.e. they have a large dot product.
The third weight matrix learns sub-image (token) specific representations. When weighted by the attention pattern, these produce much more refined embeddings that represent higher-level, more abstract information learned from the low-level features of the input image.
Visually, the above process can be represented as below:
In the image below, we see a transformation from E to K embeddings on the rows and from E to Q embeddings on the columns. K and Q are the embeddings produced by the first two weight matrices discussed above, and i indexes a patch (in this case, there are 8 patches).
The two weight matrices Wk and Wq (from the second step above) learn to relate input patches. The dot products of these K and Q embeddings (every patch with every other patch) produce the attention pattern/map. The bigger the circle, the higher the similarity between the K and Q embeddings (and vice versa).
Okay, now we saw how a network can be made to learn an attention pattern or a map on its own from the data, but what does this attention pattern really mean?
For instance, consider the below image where “a dog is sitting on a lawn”:
The attention pattern (i.e. the above similarity matrix) helps a neural network understand which patches of the input are highly correlated. For example, it shows which regions form the lawn and which parts form the dog. It also shows which regions of the image directly correlate with the class pattern (say, when the network needs to learn the pattern of a dog). This helps the network learn effective representations.
Now that we have this attention pattern, how do we incorporate these correlations into the input embeddings? We want the resulting embeddings to be much more refined, representing either a part or the whole of the image.
The image below shows another transformation, from E to V (using the third weight matrix). The vectors of V are then weighted by the attention scores to produce the change in the embeddings (E delta). Adding this change to the original input embedding produces the refined embedding.
The third weight matrix Wv helps in learning the sub-image or patch specific embeddings from the input image, which in turn helps in classifying the input image.
These three weight matrices are called the Key, Query and Value matrices, and they form the basis of the attention mechanism. The mechanism described above (a key, a query and a value matrix) forms a single head of attention. More specifically, it is known as Self-Attention, since the attention is computed within the same input data (in this case, the image).
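This is wonderful, because we have just learned one of the most famous and important formulas in deep learning: Attention!
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$
Here, the dot products of the Query (Q) and Key (K) embeddings produce the attention pattern. The softmax normalizes each query's attention scores so that they sum to 1. Dividing by √d_k, where d_k is the dimension of the key vectors, keeps the dot products from growing too large and saturating the softmax, which keeps training numerically stable. Finally, multiplying by the Value (V) embeddings produces the weighted update (the delta E described earlier) that refines the embeddings.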
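A quick numerical check helps explain the denominator in the Attention formula, which divides the Q·K dot products by √d_k (d_k being the dimension of the key vectors). Assume the query and key entries are independent, with zero mean and unit variance (an idealization). Then q · k has a variance of about d_k, so without scaling the softmax becomes nearly one-hot as d_k grows. The sample sizes and dimensions below are arbitrary choices for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Variance of q . k before and after dividing by sqrt(d_k), estimated from random samples.
variances = {}
for d_k in (4, 64, 1024):
    q = rng.standard_normal((10_000, d_k))
    k = rng.standard_normal((10_000, d_k))
    dots = np.einsum("nd,nd->n", q, k)           # 10,000 independent dot products
    variances[d_k] = (dots.var(), (dots / np.sqrt(d_k)).var())
    print(d_k, round(variances[d_k][0], 1), round(variances[d_k][1], 2))

# Effect on the softmax for one query against 8 keys with d_k = 1024.
d_k = 1024
query = rng.standard_normal(d_k)
keys = rng.standard_normal((8, d_k))
scores = keys @ query
peak_unscaled = softmax(scores).max()                # typically very close to 1 (nearly one-hot)
peak_scaled = softmax(scores / np.sqrt(d_k)).max()   # a much softer distribution
print(round(peak_unscaled, 3), round(peak_scaled, 3))
```

The unscaled variance should come out close to d_k and the scaled one close to 1 for every dimension tried.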
This is great! But we haven’t seen how the network above is able to learn the final class embedding, as we have only seen how refined patch embeddings are learnt through attention.
To learn the final class representation/embedding, one option is to add an MLP + softmax (+ cross-entropy loss) on top of any one of the patch embeddings (preferably the first or last patch) to predict the class. This naturally forces the network to learn a global representation of the entire input in that patch embedding.
Another way is to add an extra learnable class embedding (an extra patch) and connect the MLP + softmax on top of it to produce the final class scores. ViT takes this approach with its [class] token.
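The class-embedding approach can be sketched as follows, with a single self-attention head standing in for the full Transformer. Several choices are assumptions made to keep the sketch short: the sizes, the zero initialization of the class token, and a single linear layer in place of the MLP head. In a real model, all of these parameters are trained with the cross-entropy loss.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(E, W_q, W_k, W_v):
    """Single self-attention head with a skip connection (stand-in for the Transformer)."""
    Q, K, V = E @ W_q, E @ W_k, E @ W_v
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return E + A @ V

N, d, num_classes = 196, 64, 10                 # assumed toy sizes
patch_emb = rng.standard_normal((N, d))         # stands in for the preprocessed patch embeddings
cls_token = np.zeros((1, d))                    # learnable class embedding (an extra "patch")
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
W_head = rng.standard_normal((d, num_classes)) * 0.02   # one linear layer in place of the MLP head

tokens = np.concatenate([cls_token, patch_emb], axis=0)  # (N + 1, d), class token first
out = attention(tokens, W_q, W_k, W_v)
probs = softmax(out[0] @ W_head)                # classify from the class token's output only
label = 3                                       # hypothetical ground-truth class
loss = -np.log(probs[label])                    # cross-entropy loss for this one example
```

Because the class token attends to every patch, its output can gather information from the whole image. That is what makes a classifier on top of it sensible.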
We have looked at what forms a single head of attention. Now repeat the same head h times, each head with its own Q, K and V weights. Run the heads in parallel and concatenate their outputs (followed by a linear projection). The result is Multi-Head Attention. Having multiple heads lets the network learn many such correlations among the input patches (the more Q, K and V weights, the more expressive the network can be).
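One compact way to implement h heads is sketched below under assumed sizes. It computes Q, K and V with full-width (d x d) matrices and splits their columns into h chunks, so each chunk plays the role of one head's own Q, K and V weights. The head outputs are concatenated and mixed by an output matrix `W_o` (a name used just for this sketch).

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(E, W_q, W_k, W_v, W_o, num_heads):
    """E: (N, d); W_q, W_k, W_v, W_o: (d, d). Columns are split evenly across heads."""
    N, d = E.shape
    d_head = d // num_heads

    def split_heads(X):                           # (N, d) -> (num_heads, N, d_head)
        return X.reshape(N, num_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split_heads(E @ W_q), split_heads(E @ W_k), split_heads(E @ W_v)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_head))   # one (N, N) pattern per head
    heads = A @ V                                             # (num_heads, N, d_head)
    concat = heads.transpose(1, 0, 2).reshape(N, d)           # concatenate the heads
    return concat @ W_o, A

rng = np.random.default_rng(0)
N, d, h = 8, 32, 4                                # assumed toy sizes
E = rng.standard_normal((N, d))
W_q, W_k, W_v, W_o = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
out, A = multi_head_attention(E, W_q, W_k, W_v, W_o, num_heads=h)
print(out.shape, A.shape)                         # (8, 32) (4, 8, 8)
```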
Along with this, add a skip (residual) connection around the attention output, together with a normalization layer and an MLP (which gets its own skip connection), and we get the so-called Transformer :-)
Now, take a look at the architecture of Vision Transformers used for Image Classification:
The above architecture pretty much summarizes what we just learned. You can read the paper AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, which basically introduces Vision Transformers(or ViTs).
Above, we see a block that consists of multiple attention heads, skip connection + normalization, and an MLP at the end. This is known as a single block of the Transformer.
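Putting the pieces together, here is a minimal sketch of one Transformer block in the pre-norm arrangement used by ViT: normalization before the attention and before the MLP, with a skip connection around each. The original "Attention is all you need" Transformer places the normalization after the skip connection instead. Some simplifications apply. The learnable scale and shift of layer normalization are omitted. The MLP width of 4d and the GELU activation are common choices rather than requirements. All weights are random stand-ins.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    """Normalize each token vector to zero mean / unit variance (no learnable scale or shift)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    """tanh approximation of the GELU activation."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def multi_head_attention(x, p, num_heads):
    N, d = x.shape
    dh = d // num_heads

    def split(X):                                               # (N, d) -> (heads, N, dh)
        return X.reshape(N, num_heads, dh).transpose(1, 0, 2)

    Q, K, V = split(x @ p["W_q"]), split(x @ p["W_k"]), split(x @ p["W_v"])
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(dh))
    return (A @ V).transpose(1, 0, 2).reshape(N, d) @ p["W_o"]

def transformer_block(x, p, num_heads):
    x = x + multi_head_attention(layer_norm(x), p, num_heads)   # attention + skip connection
    x = x + gelu(layer_norm(x) @ p["W_1"]) @ p["W_2"]          # MLP + skip connection
    return x

rng = np.random.default_rng(0)
N, d, h = 8, 32, 4                                              # assumed toy sizes
shapes = {"W_q": (d, d), "W_k": (d, d), "W_v": (d, d), "W_o": (d, d),
          "W_1": (d, 4 * d), "W_2": (4 * d, d)}
x = rng.standard_normal((N, d))
y = x
for _ in range(3):                                              # a stack of 3 blocks
    params = {name: rng.standard_normal(s) / np.sqrt(s[0]) for name, s in shapes.items()}
    y = transformer_block(y, params, num_heads=h)
print(y.shape)                                                  # (8, 32)
```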
*Note: Cross-attention is the same mechanism as self-attention, except that it learns correlations between two different pieces of data. Examples are source and target tokens in machine translation, or image and text tokens in multi-modal networks. The queries come from one and the keys and values come from the other.
Context Window:
The number of tokens that can take part in an attention pattern is called the context window. In other words, it determines how far one token can attend to (or talk to) another token. The bigger the window, the more context or correlations the model can learn. This comes at a computational cost, though, because the N x N attention matrix grows quadratically with the number of tokens.
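To put a rough number on that cost, the snippet below computes the memory needed just to store one N x N attention matrix. It assumes a single head in a single layer, 4-byte floats, and a naive implementation that materializes the whole matrix. The function name is hypothetical, and real models multiply this figure by the number of heads, layers and batch size.

```python
def attention_matrix_bytes(num_tokens, num_heads=1, bytes_per_value=4):
    """Bytes for the raw attention scores of one layer: num_heads matrices of shape (N, N)."""
    return num_tokens * num_tokens * num_heads * bytes_per_value

for n in (512, 2048, 8192):
    print(n, attention_matrix_bytes(n) / 2**20, "MiB")
# 512 1.0 MiB
# 2048 16.0 MiB
# 8192 256.0 MiB
```

A 4x longer context window needs 16x the memory for these scores, which is the quadratic growth in action.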
Transformers for Language Modeling:
So far, we have seen Transformers only in the context of image classification. But the picture is not complete until we discuss this architecture in the context of NLP, and more specifically, Language Modeling.
Language modeling is the task of modeling a language: its semantics, which word follows which, and which words co-occur. One way to do so is to predict the next token/word given the preceding tokens/words. A model capable of doing so is termed a Language Model.
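Written as a formula, predicting the next token from the preceding ones corresponds to factorizing the probability of a sequence w_1, ..., w_T with the chain rule. A model with parameters θ is then trained to minimize the negative log-likelihood of each next token. The notation here is chosen just for this illustration.

```latex
p(w_1, \dots, w_T) = \prod_{t=1}^{T} p(w_t \mid w_1, \dots, w_{t-1})

\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log p_\theta(w_t \mid w_1, \dots, w_{t-1})
```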
Now that we are clear on the problem statement, how do we go about solving it?
In simplest terms, what we need is a neural network that takes all the preceding tokens as inputs and embeds them into representations (or embeddings). It then keeps learning refined embeddings or correlations among the tokens. This transformation goes on and on until it produces an embedding that is close enough to the next word (as per the training data).
The above solution seems straightforward: keep stacking multiple transformer blocks to produce this refined embedding.
However, there is more to it. Let us take a look at what the breakthrough paper “Attention is all you need” says about Transformers for language tasks.
In addition to the encoder block we have discussed, the Transformer Architecture defined for the Machine Translation task has a decoder block as well.
Check the architecture below:
So far, we have only read about the encoder part of a Transformer. What really happens in the decoder?
Well, I prefer not to think of these as two separate components, because they are built from essentially the same kind of transformer block. One set of transformer blocks learns information about the input sentence: this is the Encoder. The other set learns to predict the embedding of the next token based on the encoded representation: this is the Decoder.
The main detail to look at is where the decoder gets its inputs. It takes keys and values computed from the encoder's output (which has seen the input sentence) and uses its own queries. The intuition is that the network can focus on the most relevant parts of the input sentence before decoding what the next word/token could be.
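Here is a minimal sketch of this encoder-decoder (cross-)attention. It assumes 7 encoded source tokens, 3 target tokens decoded so far, and random stand-in weights. The queries are computed from the decoder's tokens, while the keys and values are computed from the encoder's output. As a result, the attention pattern has one row per target token and one column per source token.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(dec_x, enc_out, W_q, W_k, W_v):
    Q = dec_x @ W_q                          # queries from the decoder's own tokens
    K = enc_out @ W_k                        # keys from the encoder's output
    V = enc_out @ W_v                        # values from the encoder's output
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))   # (T_dec, T_enc)
    return dec_x + A @ V, A                  # skip connection back onto the decoder tokens

rng = np.random.default_rng(0)
d = 16
enc_out = rng.standard_normal((7, d))        # e.g. 7 source-sentence tokens after the encoder
dec_x = rng.standard_normal((3, d))          # 3 target tokens decoded so far
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

out, A = cross_attention(dec_x, enc_out, W_q, W_k, W_v)
print(A.shape)                               # (3, 7): each target token attends over the source
```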
*One can see why the name Transformers — a series of transformations to generate rich, relevant and informative representations from the initial input representations.
Parallelism in Transformers:
We have gone through most of the details of the transformer architecture, but there is one more significant practical advantage that the Transformer brings: parallelism.
Consider a sequence-to-sequence problem (be it language modeling or machine translation), and say you have several hundred thousand examples. Before Transformers, these tasks were typically solved with Recurrent Neural Networks, which are sequential in nature.
Why are RNNs sequential in nature?
RNNs are designed with a certain assumption about the nature of the problem they are meant to solve (i.e. an inductive bias). Here, the assumption is that seq2seq tasks are best solved by processing the input tokens sequentially, one after the other, either forward or backward.
To process the next token in a sequence, the network must already have processed the token preceding it. This forces computation within a sequence to happen one time step at a time, because step t has to wait for the hidden state from step t − 1. Different samples in a batch can still be processed in parallel, but the positions within a sample cannot.
Transformers have no such recurrence. The network has no sequential inductive bias about the input task, and it learns attention patterns across all input tokens together. As a result, it can process the entire sequence at once.
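The contrast shows up clearly in a toy sketch. A vanilla RNN (tanh update, random weights, assumed sizes) needs a Python loop, because each hidden state depends on the previous one. A self-attention head instead gets all T x T token interactions from a handful of matrix products that can be computed in parallel.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 6, 8                                   # sequence length and width (assumed)
X = rng.standard_normal((T, d))               # one sequence of token embeddings

# Vanilla RNN: step t needs h from step t - 1, so the time steps must run one after another.
W_x = rng.standard_normal((d, d)) / np.sqrt(d)
W_h = rng.standard_normal((d, d)) / np.sqrt(d)
h = np.zeros(d)
rnn_states = []
for t in range(T):
    h = np.tanh(X[t] @ W_x + h @ W_h)
    rnn_states.append(h)
rnn_states = np.stack(rnn_states)             # (T, d)

# Self-attention: no dependency between positions, so every position is computed at once.
W_q = rng.standard_normal((d, d)) / np.sqrt(d)
W_k = rng.standard_normal((d, d)) / np.sqrt(d)
scores = (X @ W_q) @ (X @ W_k).T / np.sqrt(d) # (T, T) in one matrix product
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True) # row-wise softmax
attn_out = weights @ X                        # values = X here, purely to keep the sketch short
```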
Take this training sample: “<start> Why are RNNs sequential in nature? <end>”
This one sample can be turned into multiple training samples:
<start> Why
<start> Why are
<start> Why are RNNs
<start> Why are RNNs sequential
<start> Why are RNNs sequential in
<start> Why are RNNs sequential in nature? <end>
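In practice, a Transformer does not need to process these prefixes separately. A single forward pass over the full sequence produces a next-token prediction at every position at once. This effectively batches all of the prefixes within one sample and makes training much more computationally efficient. For this to be valid, masking is required during the attention calculation so that later tokens cannot influence earlier ones.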
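Here is a minimal NumPy sketch of that causal mask, with random stand-in weights. Scores for "future" positions are set to -inf before the softmax, so they receive zero weight. As a check, the outputs for the first four positions of the full sequence match the result of running attention on the four-token prefix alone. This is why one pass over the full sample covers all of its prefixes. The 8-token length assumes `nature?` is treated as a single token.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_self_attention(E, W_q, W_k, W_v):
    N = E.shape[0]
    Q, K, V = E @ W_q, E @ W_k, E @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    future = np.triu(np.ones((N, N), dtype=bool), k=1)   # True where key position > query position
    scores = np.where(future, -np.inf, scores)            # exp(-inf) = 0 after the softmax
    return softmax(scores) @ V

rng = np.random.default_rng(0)
N, d = 8, 16       # 8 tokens: <start> Why are RNNs sequential in nature? <end>
E = rng.standard_normal((N, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

full = causal_self_attention(E, W_q, W_k, W_v)            # one pass over the whole sample
prefix = causal_self_attention(E[:4], W_q, W_k, W_v)      # "<start> Why are RNNs" on its own
print(np.allclose(full[:4], prefix))                      # True
```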
Intuition — “Transformers as a General Purpose Algorithm”:
I would like to put forth an interesting, more intuitive way of thinking about what Transformers really are.
Whether the problem is image classification, language modeling or machine translation, they all have one thing in common:
“How do we enable a neural network to learn a global embedding by learning the relation among the input tokens?”
A brute-force solution: why not pass the information of every token to every other token, based on some predefined structure/connections? The problem is that these connections are static (not learned from the data). They don't help the network focus (or attend) on the most important tokens/words in the input during the decoding stage; the network doesn't know which token is more important than another.
But, what if these connections can be learned by the network from the training data?
In other words, can a neural network come up with learnable weights that help establish correlations among the words/image patches in the input sequence? Those weights should help each token attend to the most relevant ones, for better downstream representations. This is pretty much what Transformers do!
These learned, data-dependent connection weights are what we call Attention, and the Key and Query weight matrices are what allow these connections to be learned dynamically.
Therefore, one can think of Transformers as a message-passing architecture, where information flows from every token to every other token, weighted by what matters (which is learned from the data). Hence, Transformers can be viewed as a much more general version of Multi-layer Perceptrons, whereas Convolutions and RNNs are specific versions of MLPs.
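The "special case" claim can be made concrete for convolutions under a few assumptions: a 1D sequence, a 3-tap kernel and 'same' zero padding. A convolution over a sequence is just a fixed, banded mixing matrix applied to the tokens, i.e. an MLP layer with most connections removed and the rest tied together. An attention pattern is also a mixing matrix over tokens, but it is dense and recomputed from each input. The helper names are made up for this sketch.

```python
import numpy as np

def conv1d_as_matrix(kernel, n):
    """Write a 'same'-padded 1D convolution (as in deep learning layers, i.e. cross-correlation)
    over a length-n sequence as an (n x n) matrix."""
    r = len(kernel) // 2
    M = np.zeros((n, n))
    for i in range(n):
        for j, w in enumerate(kernel):
            src = i + j - r                      # input position seen by kernel tap j
            if 0 <= src < n:
                M[i, src] = w
    return M

def attention_mixing_matrix(E, W_q, W_k):
    """Row-softmaxed attention pattern: a token-mixing matrix computed from the input itself."""
    S = (E @ W_q) @ (E @ W_k).T / np.sqrt(W_k.shape[1])
    S = S - S.max(axis=1, keepdims=True)
    A = np.exp(S)
    return A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, d = 6, 4
kernel = np.array([0.2, 0.5, 0.3])
M = conv1d_as_matrix(kernel, n)                  # static and banded: only neighbours connect
x = rng.standard_normal(n)
print(np.allclose(M @ x, np.convolve(x, kernel[::-1], mode="same")))   # True

W_q, W_k = rng.standard_normal((d, d)), rng.standard_normal((d, d))
A1 = attention_mixing_matrix(rng.standard_normal((n, d)), W_q, W_k)   # same weights,
A2 = attention_mixing_matrix(rng.standard_normal((n, d)), W_q, W_k)   # different inputs
print(np.allclose(A1, A2))                       # False: the connections depend on the data
```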
This very property makes Transformers really good at learning general, transferable representations that improve as the data and network scale. It also makes them applicable to all kinds of problems, irrespective of the problem domain.
That’s all in this article about the attention mechanism and the transformer architecture. I hope I was able to intuitively explain the underlying concept of transformers(and attention), how it mimics a communication or a graph based system/a message passing network and how it learns to come up with these connections from the training data.
While learning about transformers, I spent most of my time going through Yannic Kilcher’s excellent videos on Attention and Vision Transformers and 3blue1brown’s amazing visual explanation of attention. I would encourage all of you to go through them as well.
Thanks and cheers :-)
References:
- Attention is all you need — https://arxiv.org/abs/1706.03762
- An image is worth 16x16 words — https://arxiv.org/pdf/2010.11929
- Vision Transformers — Yannic Kilcher — https://www.youtube.com/watch?v=iDulhoQ2pro
- Transformers by 3blue1brown — https://www.youtube.com/watch?v=eMlx5fFNoYc
- Andrej Karpathy’s tweet on Transformers — https://x.com/karpathy/status/1582807367988654081