Published on

Attention 코드로 구현하기

Authors
  • avatar
    Name
    Inhwan Cho
    Twitter

스케일드 닷-프로덕트 어텐션

  • 참조 : <위키독스>
  • 닷-프로덕트 어텐션(dot-product attention)에서 스케일링하는 것을 추가하면
    스케일드 닷-프로덕트 어텐션(Scaled dot-product Attention)이라고 합니다
  • scaled_dot_product_attention을 tensorflow로 구현, 살펴보겠습니다.
def scaled_dot_product_attention(query, key, value, mask):
  # query 크기    : (batch_size, num_heads, query의 문장 길이, d_model/num_heads)
  # key 크기      : (batch_size, num_heads, key의 문장 길이,   d_model/num_heads)
  # value 크기    : (batch_size, num_heads, value의 문장 길이, d_model/num_heads)
  # padding_mask : (batch_size, 1, 1, key의 문장 길이)

  # Q와 K의 곱. 어텐션 스코어 행렬.
  matmul_qk = tf.matmul(query, key, transpose_b=True)

  # 스케일링
  depth = tf.cast(tf.shape(key)[-1], tf.float32)
  logits = matmul_qk / tf.math.sqrt(depth)

  # 마스킹, 매우 작은 값이므로 소프트맥스 함수에 의해 0이 된다.
  if mask is not None:
    logits += (mask * -1e9)

  # 소프트맥스 함수는 마지막 차원인 key의 문장 길이 방향으로 수행(axis=-1)
  # attention weight : (batch_size, num_heads, query의 문장 길이, key의 문장 길이)
  attention_weights = tf.nn.softmax(logits, axis=-1)

  # output : (batch_size, num_heads, query의 문장 길이, d_model/num_heads)
  output = tf.matmul(attention_weights, value)

  return output, attention_weights

테스트

  • temp_q의 값 [0, 10, 0]은 Key에 해당하는 temp_k의 두번째 값 [0, 10, 0]과 일치하게 설정
import tensorflow as tf
import numpy as np
# 임의의 Query, Key, Value인 Q, K, V 행렬 생성
np.set_printoptions(suppress=True) #옵션 넣어줘야 보기 편함(소수점 반올림)
temp_k = tf.constant([[10,0,0],
                      [0,10,0],
                      [0,0,10],
                      [0,0,10]], dtype=tf.float32)  # (4, 3)

temp_v = tf.constant([[1,0],
                      [10,0],
                      [100,5],
                      [1000,6]], dtype=tf.float32)  # (4, 2)
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3) #transpose_b
  • 어텐션 분포는 [0, 1, 0, 0]의 값을 가지며
  • Value의 두번째 값인 [10, 0]이 출력되는 것을 확인할 수 있습니다.
# 함수 실행
temp_out, temp_attn = scaled_dot_product_attention(temp_q, temp_k, temp_v, None)

print(temp_attn) # 어텐션 분포(어텐션 가중치의 나열)
# tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)

print(temp_out) # 어텐션 값
# tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)