Vision Transformers Explained
Vision Transformers (ViTs) apply the Transformer architecture, originally designed for natural language processing, to image recognition tasks. Instead of processing text sequences, ViTs process sequences of image patches.
High-Level Overview
The core idea is to treat an image as a sequence of patches, similar to how a sentence is a sequence of words. These patches are then linearly embedded and fed into a standard Transformer encoder. Let's break down the process:
- Patch Extraction: An input image is divided into fixed-size, non-overlapping patches. For example, a 224x224 image divided into patches of 16x16 pixels yields (224/16) * (224/16) = 14 * 14 = 196 patches.
- Linear Embedding of Patches: Each patch is flattened into a vector, which is then linearly projected into a D-dimensional embedding space. The resulting embedding represents the patch's content in a form the Transformer can process.
- Positional Encoding: Self-attention has no built-in notion of order. It is permutation-equivariant: shuffling the input tokens simply shuffles the outputs. For this reason, positional encodings are added to the patch embeddings to tell the model where each patch was located in the original image. In ViT these encodings are usually learned.
- Transformer Encoder: The sequence of embedded patches, with the positional encodings added, is fed into a standard Transformer encoder. The encoder consists of multiple layers of multi-head self-attention and feed-forward networks.
- Classification Head: A classification head, typically a small multi-layer perceptron (MLP), is applied to the encoder's output for a special classification token. This head maps that output to the desired number of classes.
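The patch extraction step can be sketched in NumPy. This is a minimal illustration that assumes a channels-last (H, W, C) array whose height and width are divisible by the patch size. The function name `patchify` is hypothetical and not taken from any particular library.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Returns an array of shape (N, P*P*C) with N = (H/P) * (W/P),
    with patches ordered row by row (raster order).
    """
    h, w, c = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "H and W must be divisible by the patch size"
    # (H, W, C) -> (H/P, P, W/P, P, C): split each spatial axis into (grid index, offset in patch)
    x = image.reshape(h // p, p, w // p, p, c)
    # -> (H/P, W/P, P, P, C): bring the two grid axes to the front
    x = x.transpose(0, 2, 1, 3, 4)
    # -> (N, P*P*C): one flattened row per patch
    return x.reshape(-1, p * p * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))   # stand-in for a real RGB image
patches = patchify(image, 16)
print(patches.shape)  # (196, 768)
```

Each row is one 16x16x3 patch flattened to 768 values, matching the 14 * 14 = 196 patches computed above.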
Detailed Breakdown of Key Components
- Patch Embedding: This step transforms the image patches into a format suitable for the Transformer. Let's say we have an image $x \in \mathbb{R}^{H \times W \times C}$ (height, width, channels). The image is divided into $N = HW / P^2$ patches, each of size $P \times P \times C$. Each patch is flattened into a vector $x_p^i \in \mathbb{R}^{P^2 C}$. A learnable linear projection $E \in \mathbb{R}^{(P^2 C) \times D}$ maps these flattened patches into a $D$-dimensional embedding space: $x_p^i \mapsto x_p^i E$.
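- Positional Encodings: Positional embeddings are added to the patch embeddings to retain spatial information, since self-attention is otherwise insensitive to token order. They are added after the linear projection, and a learnable classification token $x_{\text{class}}$ is prepended to the sequence: $z_0 = [x_{\text{class}}; x_p^1 E; x_p^2 E; \dots; x_p^N E] + E_{\text{pos}}$, where $E_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ is the positional encoding, with one row per token including the classification token.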
- Transformer Encoder: The Transformer encoder is the workhorse of the model. It consists of:
- Multi-Head Self-Attention (MSA): Allows the model to attend to different parts of the input sequence. Each head computes attention weights that determine how strongly every patch draws on every other patch. This part is computationally expensive: its cost grows as O(N^2) in the number of patches N (more precisely O(N^2 D) per layer for embedding size D), so doubling the number of patches roughly quadruples the attention cost.
- Feed-Forward Network (FFN): A two-layer MLP with a non-linear activation function (e.g., GELU) in between, applied to each token embedding independently.
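- Layer Normalization (LN) and Residual Connections: LayerNorm is applied before each MSA and FFN sub-layer, and a residual connection is added around each one. Both improve training stability and performance.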
- Classification Head: After the last of the $L$ encoder layers, the output for the prepended classification token, $z_L^0$, is layer-normalized to give the image representation $y = \mathrm{LN}(z_L^0)$, which is fed into a simple MLP head.
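The embedding step that produces $z_0$ can be sketched in NumPy. This is a minimal illustration: random arrays stand in for real flattened patches and for parameters that would be learned during training, and the function name `embed_patches` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def embed_patches(patches, E, x_class, E_pos):
    """Compute z_0 = [x_class; x_p^1 E; ...; x_p^N E] + E_pos.

    patches: (N, P*P*C) flattened patches
    E:       (P*P*C, D) linear patch projection
    x_class: (D,) classification token
    E_pos:   (N+1, D) positional embeddings, one row per token
    """
    tokens = patches @ E                                          # (N, D)
    tokens = np.concatenate([x_class[None, :], tokens], axis=0)   # prepend class token -> (N+1, D)
    return tokens + E_pos                                         # (N+1, D)

N, patch_dim, D = 196, 16 * 16 * 3, 768
patches = rng.standard_normal((N, patch_dim))        # stand-in for real flattened patches
# Randomly initialized stand-ins for learned parameters.
E = rng.standard_normal((patch_dim, D)) * 0.02
x_class = np.zeros(D)
E_pos = rng.standard_normal((N + 1, D)) * 0.02

z0 = embed_patches(patches, E, x_class, E_pos)
print(z0.shape)  # (197, 768)
```

The extra row (197 rather than 196) is the classification token. Many implementations compute `patches @ E` equivalently with a convolution whose kernel size and stride both equal P, which extracts and projects the patches in one step.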
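One pre-norm encoder block and the classification head can be sketched in NumPy. The sketch makes several simplifying assumptions. Learned LayerNorm scale and shift parameters and all bias terms are omitted. GELU uses its common tanh approximation. The function names and the `params` keys are hypothetical. The `(heads, T, T)` score tensor is where the quadratic cost in the number of tokens comes from.

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-6):
    # Normalize each token over its feature dimension (learned scale/shift omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(z, W_qkv, W_o, num_heads):
    """z: (T, D) tokens; W_qkv: (D, 3D); W_o: (D, D)."""
    T, D = z.shape
    d_h = D // num_heads
    q, k, v = np.split(z @ W_qkv, 3, axis=-1)                # each (T, D)
    # (T, D) -> (heads, T, d_h)
    q, k, v = (a.reshape(T, num_heads, d_h).transpose(1, 0, 2) for a in (q, k, v))
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_h)          # (heads, T, T): quadratic in T
    out = softmax(scores) @ v                                 # (heads, T, d_h)
    out = out.transpose(1, 0, 2).reshape(T, D)                # concatenate the heads
    return out @ W_o

def encoder_block(z, params, num_heads):
    # Pre-norm: LayerNorm before each sub-layer, residual connection around it.
    z = z + multi_head_self_attention(layer_norm(z), params["W_qkv"], params["W_o"], num_heads)
    h = gelu(layer_norm(z) @ params["W_1"])                   # FFN hidden layer, applied per token
    return z + h @ params["W_2"]

def classify(z_L, W_head):
    # y = LN(z_L^0): the final output of the class token (row 0) represents the image.
    y = layer_norm(z_L[0])
    return y @ W_head                                         # (num_classes,) logits

T, D, heads, hidden, classes = 197, 64, 4, 256, 10
params = {
    "W_qkv": rng.standard_normal((D, 3 * D)) * 0.02,
    "W_o": rng.standard_normal((D, D)) * 0.02,
    "W_1": rng.standard_normal((D, hidden)) * 0.02,
    "W_2": rng.standard_normal((hidden, D)) * 0.02,
}
tokens = rng.standard_normal((T, D))                          # stand-in for z_0
z = encoder_block(tokens, params, heads)                      # a real ViT stacks L such blocks
logits = classify(z, rng.standard_normal((D, classes)) * 0.02)
print(z.shape, logits.shape)  # (197, 64) (10,)
```

Nothing in `multi_head_self_attention` depends on token order. Permuting its input rows therefore just permutes its output rows, which is why positional embeddings are added to $z_0$.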
Advantages of Vision Transformers
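- Global Context: Self-attention lets every patch interact with every other patch within a single layer, so ViTs can capture long-range dependencies that help with understanding the overall scene. CNNs find this harder because each convolution has a local receptive field, and global context only emerges after many layers are stacked.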
- Scalability: ViTs scale well to very large models with many parameters, and their performance tends to keep improving as model size and the amount of pre-training data grow.
- Transfer Learning: Pre-trained ViTs can be fine-tuned on downstream tasks with relatively little data.
Disadvantages of Vision Transformers
- Data Requirements: ViTs typically need large amounts of training data to perform well. CNNs build in inductive biases such as locality and translation equivariance, which make them more sample efficient. ViTs lack most of these biases and must learn that structure from data.
- Computational Cost: Training ViTs can be computationally expensive, especially for large models.
- Patch Size Sensitivity: The choice of patch size can significantly affect performance. Smaller patches capture finer-grained detail but produce more tokens. Because self-attention cost grows quadratically with the token count, the computational cost rises quickly.
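The effect of patch size on sequence length can be checked with a little arithmetic. The sketch below assumes a square 224x224 input and one classification token, and counts the entries of one attention matrix (per head, per layer). The helper name `vit_sequence_cost` is hypothetical.

```python
def vit_sequence_cost(image_size: int, patch_size: int) -> tuple[int, int]:
    """Return (tokens, attention scores per head per layer) for a square image."""
    n_patches = (image_size // patch_size) ** 2
    tokens = n_patches + 1             # +1 for the classification token
    return tokens, tokens * tokens     # every token attends to every token

for p in (32, 16, 8):
    tokens, entries = vit_sequence_cost(224, p)
    print(f"P={p:2d}: {tokens:4d} tokens, {entries:7,d} attention scores")

# P=32:   50 tokens,   2,500 attention scores
# P=16:  197 tokens,  38,809 attention scores
# P= 8:  785 tokens, 616,225 attention scores
```

Halving the patch size roughly quadruples the number of tokens and multiplies the size of each attention matrix by about sixteen.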
In Summary
Vision Transformers offer a powerful alternative to CNNs for image recognition tasks. By leveraging the Transformer architecture, ViTs can capture long-range dependencies and, when pre-trained on sufficiently large datasets, achieve state-of-the-art performance. However, they require large amounts of training data and can be computationally expensive.