0. LLM workflow
1)입력 인베딩 -> 2)트랜스포머 블록(멀티 헤드 어텐션, 피드포워드 네트워크, 잔차 연결 및 정규화) -> 3)출력 레이어 4) 출력 디코딩
1. 정의
Residual Connection은 신경망에서 발생하는 기울기 소실(vanishing gradient) 문제를 해결하기 위해 도입되었다.
신경망의 특정 층에서 입력값을 출력값에 더해주는 방식인데, 어떤 변환을 수행한 F(X)에
원래 입력값인 X를 더해 최종 출력을 만든다. Y = F(X) + X
## 트랜스포머 블록 구조 상 멀티 헤드 어텐션(출력)과 피드 포워드 네트워크(출력과 입력)에 적용
import torch
import torch.nn as nn
class FeedForwardNetwork(nn.Module):
def __init__(self, d_model, d_ff):
super(FeedForwardNetwork, self).__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(d_ff, d_model)
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, x):
residual = x # 입력값을 residual에 저장
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x += residual # 잔차 연결
x = self.layer_norm(x)
return x
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads)
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, x):
residual = x
attn_output, _ = self.attention(x, x, x)
x = attn_output + residual
x = self.layer_norm(x)
return x
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super(TransformerBlock, self).__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.feed_forward = FeedForwardNetwork(d_model, d_ff)
def forward(self, x):
x = self.attention(x)
x = self.feed_forward(x)
return x
# 예시
d_model = 512 #
num_heads = 8 #
d_ff = 2048 #
batch_size = 32
seq_length = 10
input_tensor = torch.randn(batch_size, seq_length, d_model)
transformer_block = TransformerBlock(d_model, num_heads, d_ff)
output_tensor = transformer_block(input_tensor)
print(output_tensor.shape)
2. 역할
그렇다면 이 구성요소를 왜 다루느냐, 앞에서 언급했던 (1)기울기 소실 방지가 가장 크고 (2)빠른 학습을 위해 사용하며 (3)효과적인 정보 전달을 위해 사용하게 된다.
신경망이 깊어질수록 기존 input의 정보가 소실될 수 있기 때문에 주어진 F(x)가 0에 가까워지더라도 입력값 x가 더해진다면 정보가 그대로 전달될 수 있기 때문에 사용한다.
반응형
'스터디 > AI' 카테고리의 다른 글
벡터 DB(1) 튜토리얼 (0) | 2024.07.04 |
---|---|
파이토치(2): 로지스틱 회귀와 클래스를 통한 구현 (0) | 2023.11.05 |
파이토치(1): Date Set & Tensor (0) | 2023.09.27 |
생성형 인공지능(1): Auto-GPT 소개 및 사용 (1) | 2023.09.17 |