圖靈學院/科楠/2024年8月21日
在電腦視覺領域中,卷積神經網絡 (CNN) 一直以來是處理視覺任務的主流技術。然而,這一局面在 2020 年發生了變化。當年,Dosovitskiy 等人發表了一篇題為《An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale》的論文,提出了 Vision Transformer (ViT) 模型。ViT 的出現被視為電腦視覺領域的一大突破,因為它能夠超越傳統的 CNN,在圖像識別任務中取得更好的表現。
CNN 與 ViT 的核心區別
CNN 的核心概念在於透過卷積層中的卷積核提取圖像特徵。由於卷積核的大小通常相對較小,因此它只能捕捉圖像中局部區域的資訊。為了獲取圖像的全局上下文,需要疊加多層卷積層,從而增加模型的感受野。然而,ViT 則從一開始就直接捕捉全局資訊,這使得 ViT 在資訊提取方面更為全面。
Vision Transformer 的架構
ViT 的架構與 NLP 中的 Transformer 有著密切的關聯,特別是它的編碼器 (Encoder)。在 NLP 中,編碼器用來捕捉輸入序列中不同詞彙之間的關聯,而在 ViT 中,圖像的每個小塊被視為一個 token,編碼器則負責捕捉這些圖像塊之間的關聯性。
圖像塊的分割與線性投影
ViT 的第一步是將圖像劃分為若干小塊,這些小塊隨後被展平成一維向量。這些向量再經過線性投影,被映射到高維空間,類似於 NLP 中的詞嵌入 (Word Embedding)。此步驟可以透過多層感知機 (MLP) 或卷積層來實現。這一過程使得圖像塊成為可供 Transformer 處理的 token。
類別 token 與位置嵌入
由於 ViT 的任務通常是圖像分類,因此需要在投影的 token 序列前加入一個稱為類別 token 的特殊 token。該 token 用於聚合其他圖像塊的信息,並最終負責輸出分類結果。為了解決在圖像塊展平過程中丟失的空間信息,ViT 還會向每個 token(包括類別 token)中加入位置嵌入,這樣可以將空間信息重新引入模型中。
Transformer 編碼器與 MLP 頭
在圖像塊序列準備好後,它們會被傳入 Transformer 編碼器。該編碼器由層歸一化 (Layer Normalization)、多頭注意力機制 (Multi-Head Attention)、以及 MLP 層組成,並在多處引入殘差連接 (Residual Connections)。ViT 的 MLP 頭負責將 Transformer 編碼器的輸出進一步處理,最終生成分類結果。
Vision Transformer 的實作
在 ViT 的實作過程中,我們可以使用 PyTorch 來一步步構建整個模型架構。首先,需定義圖像的批量大小、圖像尺寸、通道數等基礎參數。接著,劃分圖像塊並進行線性投影。這一部分可以通過 `nn.Unfold()` 或 `nn.Conv2d()` 來實現。兩者的區別在於,`nn.Conv2d()` 可以同時完成展平和線性投影的操作,因此效率更高。
接下來,將類別 token 加入到序列中,並引入位置嵌入。這些步驟可以通過 `torch.cat()` 函數將類別 token 與圖像塊序列連接,並在每個 token 上加上位置嵌入。
之後,圖像塊序列被傳入 Transformer 編碼器。在這一步中,編碼器將進行兩次層歸一化、一次多頭注意力計算,並通過 MLP 層來進一步處理數據。每個編碼器都會重複多次,以確保模型能夠深入理解圖像中的全局資訊。
最終,ViT 的 MLP 頭會將編碼器的輸出進一步投影到類別數量的維度,從而得到最終的分類結果。這個步驟的輸出將是一個包含預測類別概率的向量。
結論
Vision Transformer (ViT) 的出現,為電腦視覺領域帶來了新的技術方向。透過直接捕捉圖像的全局資訊,ViT 在許多視覺任務中都取得了比 CNN 更好的表現。雖然 ViT 的模型參數量相對較大,但其優越的性能使其成為未來電腦視覺研究的重要工具。
在實作方面,使用 PyTorch 可以讓我們對 ViT 的架構有更深入的理解,並通過編碼器、類別 token、位置嵌入、多頭注意力機制等核心技術來構建整個模型。在實際應用中,ViT 將逐漸成為圖像識別、物體檢測等領域的主流技術之一。
Reference:
1. Dosovitskiy, A., et al. (2020). "An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale." [Arxiv]. [Accessed August 8, 2024].
2. Lin, H., et al. (2017). "Maritime Semantic Labeling of Optical Remote Sensing Images with Multi-Scale Fully Convolutional Network." [ResearchGate]. [Accessed August 8, 2024].
3. Vision Transformer. PyTorch. [PyTorch Documentation](https://pytorch.org/vision/main/models/vision_transformer.html). [Accessed August 8, 2024].
Copyright © 2024 利創智能科技股份有限公司 All rights reserved.
Replace this text with information about you and your business or add information that will be useful for your customers.