스터디/AI

LLM component(1): Residual Connection(작성중)

민서타 2025. 3. 5. 19:46

0. LLM workflow

1)입력 인베딩 -> 2)트랜스포머 블록(멀티 헤드 어텐션, 피드포워드 네트워크, 잔차 연결 및 정규화) -> 3)출력 레이어 4) 출력 디코딩

1. 정의

Residual Connection은 신경망에서 발생하는 기울기 소실(vanishing gradient) 문제를 해결하기 위해 도입되었다.

신경망의 특정 층에서 입력값을 출력값에 더해주는 방식인데, 어떤 변환을 수행한 F(X)에

원래 입력값인 X를 더해 최종 출력을 만든다. Y = F(X) + X

## 트랜스포머 블록 구조 상 멀티 헤드 어텐션(출력)과 피드 포워드 네트워크(출력과 입력)에 적용

import torch
import torch.nn as nn

class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x  # 입력값을 residual에 저장
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x += residual  # 잔차 연결
        x = self.layer_norm(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        attn_output, _ = self.attention(x, x, x)
        x = attn_output + residual
        x = self.layer_norm(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardNetwork(d_model, d_ff)

    def forward(self, x):
        x = self.attention(x)
        x = self.feed_forward(x)
        return x

# 예시
d_model = 512  #
num_heads = 8  #
d_ff = 2048   #
batch_size = 32
seq_length = 10

input_tensor = torch.randn(batch_size, seq_length, d_model)

transformer_block = TransformerBlock(d_model, num_heads, d_ff)
output_tensor = transformer_block(input_tensor)

print(output_tensor.shape)

2. 역할

그렇다면 이 구성요소를 왜 다루느냐, 앞에서 언급했던 (1)기울기 소실 방지가 가장 크고 (2)빠른 학습을 위해 사용하며 (3)효과적인 정보 전달을 위해 사용하게 된다.

신경망이 깊어질수록 기존 input의 정보가 소실될 수 있기 때문에 주어진 F(x)가 0에 가까워지더라도 입력값 x가 더해진다면 정보가 그대로 전달될 수 있기 때문에 사용한다.

반응형