Articles in Turing Academy cover three major themes: ESG Net Zero Laboratory, AI Laboratory and Lean Management Laboratory. We will share articles on related topics from time to time. We also welcome students who are interested in the above topics to submit articles and share them with you. Insights (I want to contribute)

深入探討 Vision Transformer (ViT) —— 從 PyTorch 實作學習

 

圖靈學院/科楠/2024年8月21日

 

    在電腦視覺領域中,卷積神經網絡 (CNN) 一直以來是處理視覺任務的主流技術。然而,這一局面在 2020 年發生了變化。當年,Dosovitskiy 等人發表了一篇題為《An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale》的論文,提出了 Vision Transformer (ViT) 模型。ViT 的出現被視為電腦視覺領域的一大突破,因為它能夠超越傳統的 CNN,在圖像識別任務中取得更好的表現。

 

CNN 與 ViT 的核心區別

 

 

CNN 的核心概念在於透過卷積層中的卷積核提取圖像特徵。由於卷積核的大小通常相對較小,因此它只能捕捉圖像中局部區域的資訊。為了獲取圖像的全局上下文,需要疊加多層卷積層,從而增加模型的感受野。然而,ViT 則從一開始就直接捕捉全局資訊,這使得 ViT 在資訊提取方面更為全面。

 

Vision Transformer 的架構

 

 

ViT 的架構與 NLP 中的 Transformer 有著密切的關聯,特別是它的編碼器 (Encoder)。在 NLP 中,編碼器用來捕捉輸入序列中不同詞彙之間的關聯,而在 ViT 中,圖像的每個小塊被視為一個 token,編碼器則負責捕捉這些圖像塊之間的關聯性。

 

圖像塊的分割與線性投影

 

ViT 的第一步是將圖像劃分為若干小塊,這些小塊隨後被展平成一維向量。這些向量再經過線性投影,被映射到高維空間,類似於 NLP 中的詞嵌入 (Word Embedding)。此步驟可以透過多層感知機 (MLP) 或卷積層來實現。這一過程使得圖像塊成為可供 Transformer 處理的 token。

 

類別 token 與位置嵌入

 

由於 ViT 的任務通常是圖像分類,因此需要在投影的 token 序列前加入一個稱為類別 token 的特殊 token。該 token 用於聚合其他圖像塊的信息,並最終負責輸出分類結果。為了解決在圖像塊展平過程中丟失的空間信息,ViT 還會向每個 token(包括類別 token)中加入位置嵌入,這樣可以將空間信息重新引入模型中。

 

Transformer 編碼器與 MLP 頭

 

在圖像塊序列準備好後,它們會被傳入 Transformer 編碼器。該編碼器由層歸一化 (Layer Normalization)、多頭注意力機制 (Multi-Head Attention)、以及 MLP 層組成,並在多處引入殘差連接 (Residual Connections)。ViT 的 MLP 頭負責將 Transformer 編碼器的輸出進一步處理,最終生成分類結果。

 

Vision Transformer 的實作

 

在 ViT 的實作過程中,我們可以使用 PyTorch 來一步步構建整個模型架構。首先,需定義圖像的批量大小、圖像尺寸、通道數等基礎參數。接著,劃分圖像塊並進行線性投影。這一部分可以通過 `nn.Unfold()` 或 `nn.Conv2d()` 來實現。兩者的區別在於,`nn.Conv2d()` 可以同時完成展平和線性投影的操作,因此效率更高。

 

接下來,將類別 token 加入到序列中,並引入位置嵌入。這些步驟可以通過 `torch.cat()` 函數將類別 token 與圖像塊序列連接,並在每個 token 上加上位置嵌入。

 

之後,圖像塊序列被傳入 Transformer 編碼器。在這一步中,編碼器將進行兩次層歸一化、一次多頭注意力計算,並通過 MLP 層來進一步處理數據。每個編碼器都會重複多次,以確保模型能夠深入理解圖像中的全局資訊。

 

最終,ViT 的 MLP 頭會將編碼器的輸出進一步投影到類別數量的維度,從而得到最終的分類結果。這個步驟的輸出將是一個包含預測類別概率的向量。

 

結論

 

Vision Transformer (ViT) 的出現,為電腦視覺領域帶來了新的技術方向。透過直接捕捉圖像的全局資訊,ViT 在許多視覺任務中都取得了比 CNN 更好的表現。雖然 ViT 的模型參數量相對較大,但其優越的性能使其成為未來電腦視覺研究的重要工具。

在實作方面,使用 PyTorch 可以讓我們對 ViT 的架構有更深入的理解,並通過編碼器、類別 token、位置嵌入、多頭注意力機制等核心技術來構建整個模型。在實際應用中,ViT 將逐漸成為圖像識別、物體檢測等領域的主流技術之一。

 

Reference:

1. Dosovitskiy, A., et al. (2020). "An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale." [Arxiv]. [Accessed August 8, 2024].

2. Lin, H., et al. (2017). "Maritime Semantic Labeling of Optical Remote Sensing Images with Multi-Scale Fully Convolutional Network." [ResearchGate]. [Accessed August 8, 2024].

3. Vision Transformer. PyTorch. [PyTorch Documentation](https://pytorch.org/vision/main/models/vision_transformer.html). [Accessed August 8, 2024].