Published on

ViT(Vision Transformer)

Authors
  • avatar
    Name
    Inhwan Cho
    Twitter

NLP분야에서, 기존의 RNN에서의 한계때문에 [Context vector의 크기가 제한적이기 때문에 문장의 모든 정보를 보낼 수 없음]
seq2seq모델(문장을 입력받아 문장을 출력해주는)에서 RNN대신 제안된 모델이 Attention이라는 모델입니다.
이 Attention을 통해 만든 Transformer가 NLP분야에서 너무 좋은 성능을 보여줘서
CNN이 아닌Attention을 통해 vision분야에도 적용해 본 모델이 Vit입니다.

먼저, ViT에 들어가기 전에 Attention, Transformer의 개념을 간단하게 잡고 가겠습니다.

Attention과 Self-Attention의 차이

Attention = (Decoder -> Query)(Encoder -> Key, Value) == encoder와 decoder의 상관관계를 통해 특징 추출
Self-Attention = (입력데이터에서 동일한  Query, Key, Value을 얻음) == 입력데이터 내에서의 상관관계를 통해 특징 추출
(ViT는 self-Attention만 사용)

Transformer와 CNN의 차이

Transformer는 하나의 layer에서 전체 이미지 정보를 통합 가능하나
CNN의 경우 멀리 떨어진 두 정보를 통합할때 여러개의 layer를 통과해야 정보가 통합이 됩니다.

CNN에서의 Inductive bias는 convolution filer라는 것을 통하여 지역적 정보를 추출할 수 있습니다.
Transfoer에서는 1차원 벡터로 만든 후 self-attention을 통해 특징을 추출하기 때문에 지역적 정보는 없습니다.
그러나 그렇기 때문에 모델의 자유도는 증가합니다.(inductivie bias가 CNN보다 적다)

ViT(ViT-Base / 16) 설명

아래는 논문에 나와있는 전체적인 ViT overview 입니다.

VIT overview

여기서 16은 패치의 사이즈를 의미합니다. 즉, 논문의 그림으로 설명하자면 기존의 이미지 사이즈가 48x48이고 이를 9개의 16x16의 패치로 자른 모델을 의미합니다.
(실제 이미지넷 데이터 기준으로 설명하면 H=W=224, Patch_size=16이면 실제 패치의 수 = 14*14 == 196이 됩니다.)

ViT의 연산 과정은 다음과 같습니다.

16x16사이즈의 패치를 1차원으로 만들어줍니다(컬러 이미지이기 때문에 3채널이라 곱하기3) ==> 16x16x3 = 768차원
이를 batch가 있는 실제 코드로 살펴보면 다음과 같습니다.

# input_img[batch, channel, height, width]
x = torch.randn(1,3,224,224)

# 16 pixels
patch_size = 16 
pathes = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)

print(pathes.shape)
# torch.Size([1, 196, 768])

이를 Linear Projection을 통과하고 그 값에(768차원)

position embedding값과 Classification token(CLS token과 동일한 기능)을 더해줍니다.
[여기서 position embedding값과 Classification token은 학습을 통해 결정되는 파라미터 값] == z값
코드로 살펴보면 다음과 같습니다.

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x
    
PatchEmbedding()(x).shape
#torch.Size([1, 197, 768])

그리고 여기서 나온 값을 Transformer의 Encoder의 input값으로 사용합니다.
또한, Transformer의 인코더와 약간 차이가 있습니다. 기존의 Transformer에서는 연산 후 normalization을 수행하였지만, ViT에서는 Norm을 먼저 진행하고 attention연산/MLP연산을 진행합니다.
여기서 Normalization은 batch norm이 아닌 layer normalization입니다.(batch단위가 아닌 instance(sample)단위로 norm을 진행)

input(768차원)값으로 사용되는 z값에 Weight matrix(학습되는 파라미터)인 W_QW\_Q 값을 곱하여 [zbulletW_Q z \\bullet W\_Q]-> q(64차원)값을 얻고 W_KW\_K를 곱하여 [zbulletW_K z \\bullet W\_K] -> k(64차원)를 얻고 W_VW\_V를 곱하여 [zbulletW_V z \\bullet W\_V] -> v값(64차원)을 얻습니다.

그리고 얻어진 q,k를 곱한 값에 softmax를 취한 a값[a=Softmax(qTdotk)a = Softmax(q^T\\dot k)]을 구하고 (a=attention score) a와 v를 곱한 최종 score(64차원)를 구합니다.

ViT에서는 이 self-attention을 12번 수행하게 됩니다.

여기서 12번 수행하고 concate을 하면 12*64 -> 768차원으로 다시 복귀됩니다.

그 후 인코더의 두번째 부분은 norm을 수행하고, 768차원 -> 3072차원(히든 레이어) -> 768차원인 MLP를 통과합니다.

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    
patches_embedded = PatchEmbedding()(x)
MultiHeadAttention()(patches_embedded).shape
#torch.Size([1, 197, 768])

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )))
TransformerEncoderBlock()(patches_embedded).shape
#torch.Size([1, 197, 768])

그 후 최종 Classification을 위해 MLP head를 통과하여 분류를 합니다.

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
        
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))
            

class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )

아래는 논문에 나와있는 ViT 수식입니다.

VIT 수식

수식으로 설명하면 다음과 같습니다.

(1) 각각의 Linear projection된 패치들에 position encoding값을 더합니다.(+CLS token)
그리고 이 값을 layer normalization 시켜준 뒤 Multi head self-attention에 input으로 넣어줍니다.(12번 연산)

(2) residual connection이기 때문에 기존의 값에 더해주어 값을 보존합니다.[MSA는 Multihead Self Attention]

(3) transformer를 통해 구한output값을 한번 더 Layer normalization 시켜준 뒤 MLP를 통과합니다. 그 후 기존의 값을 더해줍니다.

(4) 최종적으로 Layer normalization을 통해 최종 y vector값이 나오게 됩니다.

  • 이 y값에 MLP를 통과하여 최종 이미지 classification을 수행합니다.

참고 문헌

ViT 논문
관련 유튜브
코드 구현