Attention is All You Need 구현
본 포스팅은 2018년 harvardnlp(The Annotated Transformer (harvard.edu))에 게시된 “Attention is All You Need” paper를 바탕으로 실제 PyTorch로 구현하는 내용을 리뷰하는 것이다. 또한, 본 코드는 OpenNMT 패키지를 기반으로 한다.
Prelims
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline
Background
CNN은 입출력 위치의 모든 hidden representations을 병렬로 계산한다. 이러한 경우, 두 임의의 입력 또는 출력 위치 간 신호를 연결하는데 필요한 연산의 수가 위치 간 거리에 따라 증가하며, 이로 인한 위치에 따른 의존성을 학습하는 것이 더 어려워진다.
Transformer에서는 이를 상수 연산 수로 줄였으나, 이는 attention 가중치를 평균화하여 위치를 반영하는 과정에서 효과적인 해상도가 감소하는 비용을 수반한다. 이를 해결하기 위해 Multi-Head Attention을 도입했다.
- self attention : 하나의 시퀀스 내 서로 다른 위치를 연결하여 시퀀스의 표현을 계산하는 attention mechanism.
우리가 알고 있는 Transformer는 RNN이나 합성곱을 사용하지 않고, 오직 self-attention만으로 입출력의 표현을 계산하는 모델인 것이다.
Model Architecture
Transformer 역시 Encoder-Decoder 구조를 갖고 있다. 기본적인 동작 방식은 다음과 같다.
🚀 인코더(Encoder)
- 입력 시퀀스 (x1, …, xn)을 받아 연속적인 표현 z = (z1, …, zn)으로 매핑한다.
- 인코더의 역할은 원본 입력 데이터를 고차원적인 특성 표현으로 변환하는 것이다.
🚀 디코더(Decoder)
- 인코더가 생성한 표현 z를 사용하여 출력 시퀀스 (y1, …, yn)을 한 번에 하나의 요소씩 생성한다.
- 모델을 자동회귀(auto-regressive) 방식을 따른다 → 즉, 이전에 생성된 출력을 다음 요소를 생성할 때 추가 입력으로 사용한다.
🚀 실제 모델 흐름
입력 시퀀스 → 임베딩 → 인코더 → 연속 표현(memory) → 디코더 → 출력 시퀀스 확률 분포
Code 분석
✅ EncoderDecoder 클래스
- 전체 모델의 구조를 정의하는 기본 클래스이다.
- 주요 컴포넌트 : encoder, decoder, src_embed(입력 시퀀스를 임베딩하는 함수), tgt_embed(출력 시퀀스를 임베딩하는 함수), generator(디코더의 출력을 최종적으로 변환하여 출력 확률을 생헌하는 컴포넌트)로 구성이 되어있다.
- 주요 메서드 :
- forward() - 마스크가 적용된 입출력 시퀀스를 받아 인코더와 디코더를 차례로 호출한다.
- encode() - 입력 데이터를 임베딩한 후 인코더를 통해 처리한다.
- decode() - 인코더의 출력 결과(memory)를 받아 디코더에서 처리한다.
class EncoderDecoder(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
super(EncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.generator = generator
def forward(self, src, tgt, src_mask, tgt_mask):
return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
"입력 시퀀스 src를 인코더에 전달해 memory(연속 표현) 생성"
"이후 memory와 디코더에 입력 시퀀스 tgt를 전달해 출력 생성"
def encode(self, src, src_mask):
return self.encoder(self.src_embed(src), src_mask)
def decode(self, memory, src_mask, tgt, tgt_mask):
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
✅ Generator 클래스
- 디코더의 출력 결과를 실제 어휘 확률 분포로 변환
- 주요 컴퍼넌트 )
- proj : d_model → vocab size로 변환하는 선형 계층
- 주요 메서드
- forward(x) : 디코더 출력 x에 대해 선형 변환을 수행한 후 log-softmax를 적용하여 확률 분포 반환
class Generator(nn.Module):
"Define standard linear + softmax generation step."
def __init__(self, d_model, vocab):
super(Generator, self).__init__()
self.proj = nn.Linear(d_model, vocab)
def forward(self, x):
return F.log_softmax(self.proj(x), dim=-1)
Encoder와 Decoder 구조 상세 설명
🚀 인코더(Encoder)
- 인코더는 6개의 동일한 레이어(stack)으로 구성
- 레이어 구성 = 2개의 서브레이어(sub-layer)으로 이루어져있으며, residual connection과 layer normalization이 적용됨
[코드 상세]
✅ Encoder 클래스 & Layernorm 클래스
- clones() 함수를 사용하여 동일한 레이어 6개 복제
- 각 레이어의 출력을 LayerNorm으로 정규화
"Encoder class"
class Encoder(nn.Module) :
def __init__(self, layer, N):
super(Encoder, self).__init__()
self.layers = clones(layer, N) #N개의 동일한 레이어 복제
self.norm = LayerNorm(layer.size) #마지막 출력 정규화
def forward(self, x, mask):
for layer in self.layers: #레이어를 순차적으로 통과
x = layer(x, mask)
return self.norm(x) # 정규화 후 최종 출력
"Layernorm class"
class LayerNorm(nn.Module) :
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = nn.Parameter(torch.ones(features)) # 가중치(초깃값 : 1)
self.b_2 = nn.Parameter(torch.zeros(features)) # 편향(초깃값 : 0)
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
✅ Sub-Layer1) Residual Connection
- 출력에 입력 x를 더한 뒤 정규화 / 드롭아웃을 추가로 적용
class SublayerConnection(nn.Module):
def __init__(self, size, dropout):
super(SublayerConnection, self).__init__()
self.norm = LayerNorm(size) # 정규화
self.dropout = nn.Dropout(dropout) # 드롭아웃
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x))) # 잔차 연결
✅ Sub-Layer2) Multi-head Self-Attention & Feed-forward Network
class EncoderLayer(nn.Module):
def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = self_attn # Self-Attention
self.feed_forward = feed_forward # Feed-Forward Network
self.sublayer = clones(SublayerConnection(size, dropout), 2) # 2개의 서브레이어
self.size = size
def forward(self, x, mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) # Self-Attention
return self.sublayer[1](x, self.feed_forward) # Feed-Forward
🚀 디코더(Decoder)
- 디코더도 6개의 동일한 레이어로 구성이 되며, 3개의 서브레이어를 포함한다
- Multi-head Self-Attention
- Source-Target Attention (인코더의 출력과 디코더 입력 간의 상호작용)
- Feed-forward Network
class Decoder(nn.Module):
def __init__(self, layer, N):
super(Decoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self, x, memory, src_mask, tgt_mask):
for layer in self.layers:
x = layer(x, memory, src_mask, tgt_mask)
return self.norm(x)
class DecoderLayer(nn.Module):
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 3) # 3개의 서브레이어
def forward(self, x, memory, src_mask, tgt_mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) # Self-Attention
x = self.sublayer[1](x, lambda x: self.src_attn(x, memory, memory, src_mask)) # Source-Target Attention
return self.sublayer[2](x, self.feed_forward) # Feed-Forward
def subsequent_mask(size):
attn_shape = (1, size, size)
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
return torch.from_numpy(subsequent_mask) == 0
Attention Mechanism
<Attention의 개념>
→ Attention 메커니즘은 Query, Key, Value라는 세 가지 벡터를 입력받아 출력을 생성하는 함수로 출력은 Value들의 가중 합으로 계산된다. weight은 Query와 Key 간의 유사도를 측정하는 호환 함수(compativility function)을 통해 계산된다.
🚀 Scaled Dot-Product Attention
- 입력 벡터의 차원 :
- Q, K = 차원 d_k
- V = 차원 d_v
- 동작 방식
- Q와 모든 K 간의 내적(dot-product)을 통해 유사도 측정
- 유사도를 제곱근 d_k로 나누어 스케일링
- 이후 softmax 함수를 적용해 valude에 대한 가중치 생성
- 가중치를 V에 곱해 최종 출력 계산
- 공식
- 코드
def attention(query, key, value, mask=None, dropout=None):
d_k = query.size(-1) # Query(Key) 벡터의 차원
scores = torch.matmul(query, key.transpose(-2, -1)) \\
/ math.sqrt(d_k) # Query와 Key의 내적을 스케일링
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9) # 마스크 적용 (미래 위치 차단 등)
p_attn = F.softmax(scores, dim=-1) # Softmax로 가중치 계산
if dropout is not None:
p_attn = dropout(p_attn) # 드롭아웃 적용 (선택)
return torch.matmul(p_attn, value), p_attn # 가중치를 Value에 곱해 출력 생성
🚀 Multi-head Attention
- 여러 개의 병렬 Attention head를 통해 다양한 관점에서 정보를 학습할 수 있다.
- Multi-Head Attention 공식
- 코드
class MultiHeadedAttention(nn.Module):
def __init__(self, h, d_model, dropout=0.1):
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0 # d_model이 h로 나누어 떨어져야 함
self.d_k = d_model // h # 각 헤드의 차원
self.h = h # 헤드의 개수
self.linears = clones(nn.Linear(d_model, d_model), 4) # Query, Key, Value, Output용 선형 변환
self.attn = None # Attention 가중치 저장
self.dropout = nn.Dropout(p=dropout) # 드롭아웃
def forward(self, query, key, value, mask=None):
if mask is not None:
mask = mask.unsqueeze(1) # 모든 헤드에 동일한 마스크 적용
nbatches = query.size(0) # 배치 크기
# 1) 선형 변환 및 차원 재조정: d_model => h x d_k
query, key, value = [
l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linears, (query, key, value))
]
# 2) 모든 헤드에 대해 Attention 적용
x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
# 3) 헤드들을 결합(Concat)하고 최종 선형 변환 적용
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x) # 최종 출력
'AI > Deep Learning' 카테고리의 다른 글
[Pytorch] Pre-traing Vision Transformer로 Fine-tuning : 이미지 분류기 (0) | 2025.03.04 |
---|---|
[Transformer] C로 Transformer 구현하기 (0) | 2025.02.26 |
[Transformer 정리] 03. Positional Encoding과 특수 토큰 (0) | 2025.01.13 |
[Transformer 정리] 02. 트랜스포머 기본 구조 (0) | 2025.01.13 |
[Transformer 정리] 01. 개요 (2) | 2024.12.26 |