[NLP] Multi-Head Attention - Details

1 minute read

Multi - Head Attention - details

이전 포스팅에서 Attention is all you need 논문 리뷰를 하며 Attention mechanismTransformer 에 대해 알아보았습니다. 저는 개인적으로 공부하면서 Transformer 내 Multi-head attention 부분이 조금 헷갈렸었는데요, 그래서 이번 포스팅에서는 Multi-head attention 에 대해서 조금 더 자세히 이야기해보고자 합니다.

Architecture

일단 Multi-Head Attention 의 전체적인 구조는 아래와 같습니다.

Input 부분부터 살펴보자면 Query, Key, Value 벡터들이 Linear layer를 거쳐 Multi-head 로 나뉘어져 Attention이 계산됩니다. 처음 거치는 Linear layer는 보통 Query, Key, Value의 차원을 감소시킴과 동시에 서로 간 차원이 상이할 경우 맞춰주는 역할을 한다고 합니다.

그러고 나서 Multi-Head로 찢어서 계산이 됩니다. 그림으로 표현하자면 아래와 같습니다.

논문에서는 이렇게 나눠서 계산하는 것이 성능적 측면에서 더 좋다고 언급하였습니다. 그리고 이전 What does BERT look at? 논문 리뷰 포스팅에서 보았던 것처럼 BERT의 각 Attention head 내에서 중점적으로 가지게 되는 Linguistic information 등이 다르기 때문에 Multi-Head 구조를 이해하고 각 Head를 분석하는 것은 중요합니다.

이렇게 각 Attention head에서 Scaled Dot-Product Attention이 계산된 뒤, 이들을 모두 Concat하여 다시 한 번 Linear layer를 거쳐주게 되면 Multi-Head Attention 연산이 완료됩니다. 이 과정을 코드로 잘 표현해놓은 것을 발견해서 공유해보겠습니다.

def multi_head_attention(query,key,value,num_units,heads,masked=False):
    ## 1) Linear layers
    query = keras.layers.Dense(num_units,activation='relu')(query)
    key = keras.layers.Dense(num_units,activation='relu')(key)
    value = keras.layers.Dense(num_units,activation='relu')(value)
    ## 2) MULTI-head
    query = tf.concat(tf.split(query,heads,axis=-1),axis=0)
    key = tf.concat(tf.split(key,heads,axis=-1),axis=0)
    value = tf.concat(tf.split(value,heads,axis=-1),axis=0)
    ## 3) Compute attention
    attention = scaled_dot_product_attention(query,key,value,masked)
    ## 4) Concat
    output = tf.concat(tf.split(attention,heads,axis=0),axis=-1)
    ## 5) Linear layer again
    output = keras.layers.Dense(num_units,activation='relu')(output)
    return output

(Code from : https://simpling.tistory.com/4)

위 코드를 보면 Multi-head 내 일련의 과정이 비교적 잘 이해되는 듯 합니다. 특히 2) MULTI-head 부분에서 feature length를 헤드 수만큼 나눠준 뒤, Batch 파트를 헤드 차원만큼 늘려줍니다. (Batch, Seq, Feature) 로 구성된 3D input을 생각하시면서 헤드 수만큼 쪼개고 다시 Batch쪽에 붙여서 늘리고 하는 과정을 상상해보시면 될 것 같습니다. Concat과 Linear layer 까지 거쳐서 최종 Output이 산출되는 그림은 아래와 같습니다.

Transformer 는 결국 위와 같은 Multi-Head Attention 구조에 Positional Encoding, Layer normalization, FFN 등의 테크니컬 디테일이 합쳐서 구성되는 것이기에 Multi-Head Attention을 잘 이해하는 것이 중요하다고 생각합니다. 저처럼 Multi-Head Attention이 헷갈리셨던 분들에게 조금이나마 도움을 드렸으면 하는 마음에서 이번 포스팅을 작성해보았습니다.

이후 포스팅들에서는 BERT 외 다른 Pre-trained models 에 대해서 하나하나 알아보도록 하겠습니다!