딥러닝 모델에서 문장이나 문서 전체를 하나의 벡터로 표현하는 과정은 매우 중요하다. 일반적으로는 [CLS] 토큰이나 평균 풀링(average pooling), 최대값 풀링(max pooling) 등을 사용하지만, 이 방식들은 각 단어의 중요도를 고려하지 못한다는 한계가 있다. 즉, 모든 토큰을 동일하게 취급한다는 점이다. 하지만 실제 문장의므는 특정 단어의 기여도가 훨씬 클 수도 있다.
참고: E5 모델은 문장 단위의 의미를 얻기 위해 평균 풀링을 수행함
def average_pool(last_hidden_states: Tensor,
attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
이를 보안하기 위한 방법 중 하나로 Multi-head Pooling이라는 방법을 사용할 수 있다. 이 방식은 self-attention의 아이디어를 차용하여, 입력의 각 토큰에 대해 학습 가능한 중요도(attention score)를 부여하고, 이를 기반으로 여러 개의 관점(헤드 개수만큼)에서 정보를 요약한다.

Multihead Pooling은 입력 임베딩(예: 문장의 모든 단어 임베딩)에 대해 다음 과정을 수행한다:
- Query 벡터를 학습한다.
입력 임베딩을 통과시켜 각 위치에 대해 attention score를 예측하는 선형층(linear layer)을 만든다. 이때 head 수만큼 독립적인 score를 생성한다.
> input embedding[Batch, Sequence length, Embedding Dimension]를 Query와 Value로 활용하고, Key와 interaction이 없이 Query 순서에 따라 예측된 중요도 score만으로 attention을 계산함
> a = self.ln_attention_score(input_embedding) - Softmax로 attention weight를 만든다.
head별로 각 단어에 대한 가중치를 softmax를 통해 정규화하여 중요도를 표현한다.
> a = F.softmax(a, dim=-1) - 가중합으로 벡터를 요약한다.
계산된 attention weight를 통해 value 벡터(즉, 입력 임베딩)를 head별로 가중합하여 벡터를 만든다.
-> new_v = a.matmul(v) - head별 벡터를 concat하여 하나의 표현으로 만든다.
여러 head가 만든 표현을 이어붙인 후, 최종적으로 선형 변환을 통해 원하는 차원의 문장 벡터를 생성한다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadPoolingLayer(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.num_heads = num_heads
self.dim_per_head = embed_dim // num_heads # 각 head의 차원
self.ln_attention_score = nn.Linear(embed_dim, num_heads) # [B, T, num_heads]
self.ln_value = nn.Linear(embed_dim, num_heads * self.dim_per_head) # [B, T, num_heads * dim_per_head]
self.ln_out = nn.Linear(num_heads * self.dim_per_head, embed_dim) # 다시 하나로 합쳐서 [B, embed_dim]
def forward(self, input_embedding, mask=None, return_attention=False):
"""
input_embedding: [B, T, E]
- B: 배치 크기 (batch size)
- T: 시퀀스 길이
- E: 임베딩 차원 (embed_dim)
mask: [B, T], True일 경우 마스킹
"""
B, T, E = input_embedding.shape
H = self.num_heads
D = self.dim_per_head
# (1) Attention score 계산
a = self.ln_attention_score(input_embedding) # [B, T, H]
a = a.view(B, T, H, 1).transpose(1, 2) # → [B, H, 1, T]
# (2) Value 벡터 계산
v = self.ln_value(input_embedding) # [B, T, H * D]
v = v.view(B, T, H, D).transpose(1, 2) # → [B, H, T, D]
# (3) Mask 처리
if mask is not None:
# mask: [B, T] → [B, 1, 1, T] → broadcast to [B, H, 1, T]
a = a.masked_fill(mask.unsqueeze(1).unsqueeze(2), -1e9)
# (4) Softmax로 attention weight 계산
a = F.softmax(a, dim=-1) # [B, H, 1, T]
# (5) 가중합으로 pooling
new_v = a.matmul(v) # [B, H, 1, D]
new_v = new_v.squeeze(2) # → [B, H, D]
# (6) Head 차원 병합
new_v = new_v.transpose(1, 2).contiguous() # → [B, D, H]
new_v = new_v.view(B, -1) # → [B, H*D]
new_v = F.relu(new_v)
new_v = self.ln_out(new_v) # → [B, E]
if return_attention:
return new_v, a.squeeze(2)
else:
return new_v
문장 개수(개), 단어 개수(단), 임베딩 차원(차)가 들어오면 아래와 같은 작동을 한다.
헤드 개수(해)
입력: [개, 단, 차]
↓
score: [개, 단, 헤] → [개, 헤, 1, 단]
value: [개, 단, 헤×차] → [개, 헤, 단, 차]
↓
Softmax(score) → Attention weights: [개, 헤, 1, 단]
↓
Attention weighted sum: [개, 헤, 1, 차] → squeeze → [개, 헤, 차]
↓
Concat heads: [개, 헤 × 차] → Linear → [개, 차]
참고:
1. Gu, Nianlong, Yingqiang Gao, and Richard HR Hahnloser. "Local citation recommendation with hierarchical-attention text encoder and scibert-based reranking." European conference on information retrieval(ECIR), 2022.
2.Yang Liu and Mirella Lapata. 2019. Hierarchical Transformers for Multi-Document Summarization. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 5070–5081, Florence, Italy. Association for Computational Linguistics.