A Visual Guide to Vision Transformers
This is a visual guide to Vision Transformers (ViTs), a class of deep learning models that have achieved state-of-the-art performance on image classification tasks. Vision Transformers apply the transformer architecture, originally designed for natural language processing (NLP), to image data. This guide walks you through the key components of Vision Transformers in a scroll story format, using visualizations and simple explanations to help you understand how these models work and what the flow of data through the model looks like.
Translations
Language | Link | Translated by
---|---|---
🇰🇷 Korean | Link | Junghwan Park
Please enjoy and start scrolling!
0) Let's start with the data
Like standard convolutional neural networks (CNNs), Vision Transformers are trained in a supervised manner. This means that the model is trained on a dataset of images and their corresponding labels.
1) Focus on one data point
To get a better understanding of what happens inside a Vision Transformer, let's focus on a single data point (batch size of 1) and ask: how is this data point prepared so that a transformer can consume it?
2) Forget the label for the moment
The label will become relevant again later. For now, all we are left with is a single image.
3) Create patches of the image
To prepare the image for use inside the transformer, we divide it into equally sized, non-overlapping patches of size p x p. An image of height H and width W (both divisible by p) yields n = (H/p) · (W/p) patches.
4) Flattening the image patches
The patches are now flattened into vectors of dimension p' = p²·c, where p is the patch size and c is the number of channels. For example, a 16 x 16 RGB patch (c = 3) becomes a vector of length 16² · 3 = 768.
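Steps 3 and 4 can be sketched in NumPy as follows. This is a minimal illustration, not the code from the guide's notebook: it assumes a channels-last (H, W, c) image whose height and width are divisible by p, and the 224 x 224 x 3 image with p = 16 is an illustrative choice.

```python
import numpy as np

def image_to_patches(image: np.ndarray, p: int) -> np.ndarray:
    """Split an (H, W, c) image into non-overlapping p x p patches and
    flatten each patch into a vector of length p*p*c."""
    H, W, c = image.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    # (H, W, c) -> (H/p, p, W/p, p, c): separate patch-grid axes from in-patch axes
    x = image.reshape(H // p, p, W // p, p, c)
    # -> (H/p, W/p, p, p, c): put the two patch-grid axes first
    x = x.transpose(0, 2, 1, 3, 4)
    # -> (n, p*p*c) with n = (H/p) * (W/p), patches ordered row by row
    return x.reshape(-1, p * p * c)

image = np.random.default_rng(0).random((224, 224, 3))
patches = image_to_patches(image, p=16)
print(patches.shape)  # (196, 768)
```

Any fixed flattening order works, as long as it is the same for every patch and every image; the row-by-row order here is just a convenient convention.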
5) Creating patch embeddings
Each flattened patch vector is now encoded using the same learnable linear transformation. The resulting patch embedding vector has a fixed size d.
6) Embedding all patches
Now that we have embedded our image patches into vectors of fixed size, we are left with an array of size n x d, where n is the number of image patches and d is the size of the patch embedding.
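The linear patch embedding of steps 5 and 6 is a single matrix multiplication applied to all patches at once. A minimal NumPy sketch, assuming illustrative sizes (n = 196, p' = 768, d = 512) and random values standing in for the learned projection E and bias b:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p_flat, d = 196, 768, 512          # illustrative sizes

patches = rng.random((n, p_flat))     # flattened patches, one per row
# One projection shared by all patches (random stand-in for learned weights)
E = rng.normal(0.0, 0.02, size=(p_flat, d))
b = np.zeros(d)

patch_embeddings = patches @ E + b    # (n, p') @ (p', d) -> (n, d)
print(patch_embeddings.shape)         # (196, 512)
```

Because the same E and b are used for every row, each patch is embedded independently of the others.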
7) Prepending a classification token
To give the model a dedicated vector for classification, we extend the array of patch embeddings by an additional vector called the classification token (cls token). This vector is a learnable parameter of the network and is randomly initialized. In the original ViT it is placed in front of the patch embeddings, at position 0. Note: there is only one cls token, and the same vector is prepended to every data point.
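A minimal NumPy sketch of adding the cls token. Following the original ViT paper, it is placed at index 0 here; since positions are later distinguished by positional embeddings, appending it at the end would also work. Sizes are illustrative and the random values stand in for learned parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 196, 512
patch_embeddings = rng.normal(size=(n, d))

# A single learnable cls vector, shared by every image in the dataset
cls_token = rng.normal(0.0, 0.02, size=(1, d))

# Single image: put the cls token in front of the patch embeddings
tokens = np.concatenate([cls_token, patch_embeddings], axis=0)
print(tokens.shape)  # (197, 512)

# Batch of 4 images: the *same* cls vector is broadcast to every image
batch_patches = rng.normal(size=(4, n, d))
cls_batch = np.broadcast_to(cls_token, (4, 1, d))
tokens_batch = np.concatenate([cls_batch, batch_patches], axis=1)
print(tokens_batch.shape)  # (4, 197, 512)
```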
8) Adding positional embedding vectors
Currently, our patch embeddings carry no information about where in the image each patch came from. We remedy that by adding a learnable, randomly initialized positional embedding vector to each patch embedding, one per position. We also add such a positional embedding vector to our classification token.
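Adding learnable positional embeddings is an element-wise sum with one vector per position, so the shape of the array does not change. A minimal NumPy sketch with illustrative sizes; the random pos_embedding stands in for the learned parameter.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 196, 512
tokens = rng.normal(size=(n + 1, d))   # cls token + n patch embeddings

# One learnable vector per position, including the cls token's position
pos_embedding = rng.normal(0.0, 0.02, size=(n + 1, d))

x = tokens + pos_embedding             # element-wise sum, shape unchanged
print(x.shape)                         # (197, 512)
```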
9) Transformer Input
After the positional embedding vectors have been added, we are left with an array of size (n+1) x d. This is the input for the transformer, which is explained in greater detail in the next steps.
10.1) Transformer: QKV Creation
Each transformer input vector is linearly projected into a larger vector. This new vector is then split into three equal-sized parts: the query vector Q, the key vector K and the value vector V. Since there are n+1 input vectors, we get n+1 of each.
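The QKV creation can be sketched as one projection followed by a split into three equal parts. This NumPy version assumes a per-head size d_head = 64 (an illustrative choice) and uses random values in place of the learned matrix W_qkv.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d, d_head = 197, 512, 64   # d_head is an illustrative choice

x = rng.normal(size=(n_tokens, d))   # transformer input, (n+1) x d

# One matrix maps every token to a long vector of size 3 * d_head
W_qkv = rng.normal(0.0, 0.02, size=(d, 3 * d_head))
qkv = x @ W_qkv                      # (n+1, 3 * d_head)

# Split each long vector into three equal parts
Q, K, V = np.split(qkv, 3, axis=-1)  # each (n+1, d_head)
print(Q.shape, K.shape, V.shape)     # (197, 64) (197, 64) (197, 64)
```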
10.2) Transformer: Attention Score Calculation
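To calculate our attention scores A, we take the dot product of every query vector in Q with every key vector in K. This gives an (n+1) x (n+1) matrix A = QKᵀ, in which entry (i, j) measures how relevant token j is to token i. In standard scaled dot-product attention, these scores are additionally divided by √d_k, where d_k is the size of the key vectors.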
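All query-key dot products can be computed with one matrix product. This NumPy sketch includes the 1/√d_head scaling used in standard scaled dot-product attention; Q and K are random stand-ins and the sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_head = 197, 64
Q = rng.normal(size=(n_tokens, d_head))
K = rng.normal(size=(n_tokens, d_head))

# A[i, j] = (query i) . (key j) / sqrt(d_head)
A = Q @ K.T / np.sqrt(d_head)
print(A.shape)  # (197, 197)
```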
10.3) Transformer: Attention Score Matrix
Now that we have the attention score matrix A, we apply a `softmax` function to every row, so that every entry is non-negative and every row sums up to 1.
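A row-wise softmax in NumPy might look like this. Subtracting each row's maximum before exponentiating is a common numerical-stability trick; it does not change the result because softmax is invariant to adding a constant to a row.

```python
import numpy as np

def softmax_rows(A: np.ndarray) -> np.ndarray:
    # Shift each row so its largest entry is 0 (avoids overflow in exp)
    z = A - A.max(axis=-1, keepdims=True)
    e = np.exp(z)
    # Normalize so that every row sums to 1
    return e / e.sum(axis=-1, keepdims=True)

A = np.array([[1.0, 2.0, 3.0],
              [0.0, 0.0, 0.0]])
W = softmax_rows(A)
print(W.sum(axis=-1))  # each row sums to 1
```

A row of equal scores becomes a uniform distribution (1/3 each in the second row above), and larger scores always receive larger weights.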
10.4) Transformer: Aggregated Contextual Information Calculation
To calculate the aggregated contextual information for the first input vector, we focus on the first row of the attention matrix and use its entries as weights for our value vectors V. The resulting weighted sum of the value vectors is the aggregated contextual information vector for that first vector.
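For a single row, the aggregated contextual information is a weighted sum of all value vectors. In this NumPy sketch the weights are random non-negative numbers normalized to sum to 1, standing in for row 0 of the softmaxed attention matrix; sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_head = 197, 64
V = rng.normal(size=(n_tokens, d_head))

# Stand-in for row 0 of the attention matrix: non-negative, sums to 1
weights = rng.random(n_tokens)
weights /= weights.sum()

# Explicit weighted sum over all value vectors
context_0 = np.zeros(d_head)
for j in range(n_tokens):
    context_0 += weights[j] * V[j]

# The same computation as a single vector-matrix product
assert np.allclose(context_0, weights @ V)
print(context_0.shape)  # (64,)
```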
10.5) Transformer: Aggregated Contextual Information for every patch
Now we repeat this process for every row of our attention score matrix, and the result is n+1 aggregated contextual information vectors: one for every patch plus one for the classification token. This step concludes our first attention head.
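Doing this for every row at once is the matrix product softmax(A) · V. The NumPy sketch below puts steps 10.1 to 10.5 together as one attention head. It uses three separate matrices W_q, W_k and W_v, which is equivalent to one large projection split into three parts; all weights are random stand-ins and sizes are illustrative.

```python
import numpy as np

def single_head_attention(x, W_q, W_k, W_v):
    """One attention head: returns one context vector per input token."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])          # (n+1, n+1)
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # each row sums to 1
    return weights @ V                               # row i = weighted sum of V rows

rng = np.random.default_rng(0)
n_tokens, d, d_head = 197, 512, 64
x = rng.normal(size=(n_tokens, d))
W_q, W_k, W_v = (rng.normal(0.0, 0.02, size=(d, d_head)) for _ in range(3))

context = single_head_attention(x, W_q, W_k, W_v)
print(context.shape)  # (197, 64)
```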
10.6) Transformer: Multi-Head Attention
Because we are dealing with multi-head attention, we repeat the entire process from steps 10.1 to 10.5 with a different, separately learned QKV mapping. For our explanatory setup we assume 2 heads, but a ViT typically has more (ViT-Base, for example, uses 12). In the end, this results in one set of aggregated contextual information vectors per head.
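Multi-head attention runs the same computation once per head, each with its own QKV mapping. A minimal NumPy sketch with 2 heads as in the guide; the helper name attention_head and all sizes are illustrative, and the weights are random stand-ins.

```python
import numpy as np

def attention_head(x, W_qkv):
    # W_qkv maps each token to one long vector that is split into Q, K, V
    Q, K, V = np.split(x @ W_qkv, 3, axis=-1)
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

rng = np.random.default_rng(0)
n_tokens, d, d_head, num_heads = 197, 512, 64, 2

x = rng.normal(size=(n_tokens, d))
# Each head gets its own, independently initialized QKV mapping
head_weights = [rng.normal(0.0, 0.02, size=(d, 3 * d_head)) for _ in range(num_heads)]
head_outputs = [attention_head(x, W) for W in head_weights]
print(len(head_outputs), head_outputs[0].shape)  # 2 (197, 64)
```

Real implementations usually compute all heads in one batched operation instead of a Python loop; the loop is kept here to mirror the step-by-step description.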
10.7) Transformer: Last Attention Layer Step
The outputs of all heads are concatenated and mapped by a learnable linear layer to vectors of size d, the same size as our patch embeddings.
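Combining the heads can be sketched as a concatenation followed by one output projection. In this NumPy version the per-head outputs and the projection W_o are random stand-ins, and the sizes (2 heads of size 64, d = 512) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d, d_head, num_heads = 197, 512, 64, 2
# Stand-ins for the outputs of the attention heads
head_outputs = [rng.normal(size=(n_tokens, d_head)) for _ in range(num_heads)]

# Stack heads side by side: (n+1, num_heads * d_head)
stacked = np.concatenate(head_outputs, axis=-1)

# Learnable output projection back to the embedding size d
W_o = rng.normal(0.0, 0.02, size=(num_heads * d_head, d))
attn_out = stacked @ W_o
print(stacked.shape, attn_out.shape)  # (197, 128) (197, 512)
```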
10.8) Transformer: Attention Layer Result
The previous step concluded the attention layer. We are left with the same number of embeddings (n+1), each of exactly the same size d, as we used as input.
10.9) Transformer: Residual connections
Transformers make heavy use of residual connections, which simply means adding a layer's input to its output. We do this now: the input of the attention layer is added to its output.
10.10) Transformer: Residual connection Result
The element-wise addition results in n+1 vectors of the same size d.
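A residual connection is an element-wise sum, which only works because the attention layer keeps the shape (n+1) x d. A minimal NumPy sketch with random stand-ins for the layer input and output:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(197, 512))         # input to the attention layer
attn_out = rng.normal(size=(197, 512))  # stand-in for the attention layer's output

# Residual connection: add the layer's input to its output
y = x + attn_out
print(y.shape)  # (197, 512)
```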
10.11) Transformer: Feed Forward Network
Now these outputs are fed through a feed forward neural network with non-linear activation functions, which is applied to each vector independently.
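The feed forward network can be sketched as two linear layers with a non-linearity in between. This NumPy version assumes the GELU activation (in its common tanh approximation) and a hidden size of 4·d, which are typical ViT choices; all weights are random stand-ins.

```python
import numpy as np

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def feed_forward(x, W1, b1, W2, b2):
    # Applied to every token independently: d -> hidden -> d
    return gelu(x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
n_tokens, d, hidden = 197, 512, 2048   # hidden = 4 * d is a common choice
W1, b1 = rng.normal(0.0, 0.02, size=(d, hidden)), np.zeros(hidden)
W2, b2 = rng.normal(0.0, 0.02, size=(hidden, d)), np.zeros(d)

x = rng.normal(size=(n_tokens, d))
out = feed_forward(x, W1, b1, W2, b2)
print(out.shape)  # (197, 512)
```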
10.12) Transformer: Final Result
After the feed forward network there is another residual connection, which we skip here for brevity. With that, the transformer layer is complete: it produces outputs of the same size as its input.
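11) Repeat the transformer layers
Repeat the entire transformer calculation (steps 10.1 to 10.12) several times, e.g. 6 times. Each repetition is a separate layer with its own parameters, and the output of one layer is the input of the next.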
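Stacking layers is a loop that feeds each layer's output into the next. The compact NumPy sketch below is a simplified encoder layer under several stated assumptions: like the guide, it omits LayerNorm (real ViTs apply LayerNorm inside each layer), it drops biases, and it uses ReLU as a stand-in for GELU. The helper names init, init_layer and encoder_layer are hypothetical, and all sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

def init(*shape):
    # Random stand-in for a learned weight matrix
    return rng.normal(0.0, 0.02, size=shape)

def init_layer(d, d_head, num_heads, hidden):
    # Every layer gets its own, independent parameters
    return {
        "W_qkv": [init(d, 3 * d_head) for _ in range(num_heads)],
        "W_o": init(num_heads * d_head, d),
        "W1": init(d, hidden),
        "W2": init(hidden, d),
    }

def encoder_layer(x, params):
    heads = []
    for W_qkv in params["W_qkv"]:                          # one pass per head
        Q, K, V = np.split(x @ W_qkv, 3, axis=-1)
        s = Q @ K.T / np.sqrt(Q.shape[-1])
        w = np.exp(s - s.max(axis=-1, keepdims=True))
        heads.append((w / w.sum(axis=-1, keepdims=True)) @ V)
    x = x + np.concatenate(heads, axis=-1) @ params["W_o"]  # attention + residual
    h = np.maximum(x @ params["W1"], 0.0)                   # ReLU stand-in for GELU
    return x + h @ params["W2"]                             # feed-forward + residual

num_layers, n_tokens, d, d_head, num_heads, hidden = 6, 197, 128, 32, 2, 512
layers = [init_layer(d, d_head, num_heads, hidden) for _ in range(num_layers)]

x = rng.normal(size=(n_tokens, d))
for params in layers:           # output of one layer is the input of the next
    x = encoder_layer(x, params)
print(x.shape)  # (197, 128)
```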
12) Identify Classification token output
The next step is to identify the classification token output, i.e. the output vector at the position of the cls token. This vector will be used in the final step of our Vision Transformer journey.
13) Final Step: Predicting classification probabilities
In the final step, we feed this classification token output into another fully connected neural network, the classification head, to predict the classification probabilities of our input image.
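Steps 12 and 13 can be sketched together in NumPy. This assumes the cls token sits at index 0 (as in the original ViT paper; if it were appended, index -1 would be used instead) and models the classification head as a single linear layer followed by softmax. The 10 classes and all weights are illustrative stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d, num_classes = 197, 512, 10        # 10 classes is illustrative

encoder_out = rng.normal(size=(n_tokens, d))   # stand-in for the last layer's output
cls_out = encoder_out[0]                       # output at the cls token's position

# Classification head: a single linear layer mapping d -> num_classes
W_head = rng.normal(0.0, 0.02, size=(d, num_classes))
b_head = np.zeros(num_classes)
logits = cls_out @ W_head + b_head

# Softmax turns the logits into class probabilities
z = np.exp(logits - logits.max())
probs = z / z.sum()
print(probs.shape, int(probs.argmax()))
```

Note that only the cls token's output is used for the prediction; the outputs at the patch positions are discarded at this point.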
14) Training of the Vision Transformer
We train the Vision Transformer using a standard cross-entropy loss function, which compares the predicted class probabilities with the true class labels. The model is trained using backpropagation and gradient descent, updating the model parameters to minimize the loss function.
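The loss and one gradient-descent step can be illustrated on the classification head alone. This NumPy sketch is a simplification: it uses random stand-ins for the cls-token outputs and labels of a batch of 8, and it updates only the head weights W, whereas real training backpropagates the loss into all parameters (patch embedding, cls token, positional embeddings, attention and feed forward weights). The learning rate 0.1 is an illustrative choice.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

def cross_entropy(probs, labels):
    # Mean negative log-probability assigned to the true class
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

rng = np.random.default_rng(0)
batch, d, num_classes, lr = 8, 16, 10, 0.1
cls_outputs = 0.5 * rng.normal(size=(batch, d))    # stand-in cls-token outputs
labels = rng.integers(0, num_classes, size=batch)  # true class indices
W = rng.normal(0.0, 0.02, size=(d, num_classes))   # classification head weights

probs = softmax(cls_outputs @ W)
loss_before = cross_entropy(probs, labels)

# Gradient of the mean cross-entropy w.r.t. the logits is (probs - one_hot) / batch
one_hot = np.eye(num_classes)[labels]
grad_W = cls_outputs.T @ ((probs - one_hot) / batch)
W -= lr * grad_W                                   # one gradient-descent step

loss_after = cross_entropy(softmax(cls_outputs @ W), labels)
print(loss_before, loss_after)
```

With near-zero initial weights every class gets probability close to 1/10, so the initial loss is close to ln(10) ≈ 2.30, and a small gradient step lowers it.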
Conclusion
In this visual guide, we have walked through the key components of Vision Transformers, from the data preparation to the training of the model. We hope this guide has helped you understand how Vision Transformers work and how they can be used to classify images.
I prepared this little Colab Notebook to help you understand the Vision Transformer even better. Please look for the 'Blogpost' comment. The code was taken from @lucidrains' great ViT PyTorch implementation; be sure to check out his work.
If you have any questions or feedback, please feel free to reach out to me. Thank you for reading!
Acknowledgements
- ViT PyTorch implementation
- All images have been taken from Wikipedia and are licensed under the Creative Commons Attribution-Share Alike 4.0 International license.