본문 바로가기

블로그/딥러닝

문장 벡터 요약 방법: Multi-Head Pooling

딥러닝 모델에서 문장이나 문서 전체를 하나의 벡터로 표현하는 과정은 매우 중요하다.  일반적으로는 [CLS] 토큰이나 평균 풀링(average pooling), 최대값 풀링(max pooling) 등을 사용하지만, 이 방식들은 각 단어의 중요도를 고려하지 못한다는 한계가 있다.  즉, 모든 토큰을 동일하게 취급한다는 점이다. 하지만 실제 문장의므는 특정 단어의 기여도가 훨씬 클 수도 있다.
 
참고: E5 모델은 문장 단위의 의미를 얻기 위해 평균 풀링을 수행함

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

 
이를 보안하기 위한 방법 중 하나로 Multi-head Pooling이라는 방법을 사용할 수 있다. 이 방식은 self-attention의 아이디어를 차용하여, 입력의 각 토큰에 대해 학습 가능한 중요도(attention score)를 부여하고, 이를 기반으로 여러 개의 관점(헤드 개수만큼)에서 정보를 요약한다.
 

Yang Liu and Mirella Lapata. 2019. Hierarchical Transformers for Multi-Document Summarization. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 5070–5081, Florence, Italy. Association for Computational Linguistics.

 
Multihead Pooling은 입력 임베딩(예: 문장의 모든 단어 임베딩)에 대해 다음 과정을 수행한다:

  1. Query 벡터를 학습한다.
    입력 임베딩을 통과시켜 각 위치에 대해 attention score를 예측하는 선형층(linear layer)을 만든다. 이때 head 수만큼 독립적인 score를 생성한다.

    > input embedding[Batch, Sequence length, Embedding Dimension]를 Query와 Value로 활용하고, Key와 interaction이 없이 Query 순서에 따라 예측된 중요도 score만으로 attention을 계산함

    >  a = self.ln_attention_score(input_embedding)
  2. Softmax로 attention weight를 만든다.
    head별로 각 단어에 대한 가중치를 softmax를 통해 정규화하여 중요도를 표현한다.
    > a = F.softmax(a, dim=-1)  
  3. 가중합으로 벡터를 요약한다.
    계산된 attention weight를 통해 value 벡터(즉, 입력 임베딩)를 head별로 가중합하여 벡터를 만든다.
    -> new_v = a.matmul(v)
  4. head별 벡터를 concat하여 하나의 표현으로 만든다.
    여러 head가 만든 표현을 이어붙인 후, 최종적으로 선형 변환을 통해 원하는 차원의 문장 벡터를 생성한다.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadPoolingLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()

        self.num_heads = num_heads
        self.dim_per_head = embed_dim // num_heads  # 각 head의 차원

        self.ln_attention_score = nn.Linear(embed_dim, num_heads)               # [B, T, num_heads]
        self.ln_value = nn.Linear(embed_dim, num_heads * self.dim_per_head)     # [B, T, num_heads * dim_per_head]
        self.ln_out = nn.Linear(num_heads * self.dim_per_head, embed_dim)       # 다시 하나로 합쳐서 [B, embed_dim]

    def forward(self, input_embedding, mask=None, return_attention=False):
        """
        input_embedding: [B, T, E]  
            - B: 배치 크기 (batch size)  
            - T: 시퀀스 길이
            - E: 임베딩 차원 (embed_dim)

        mask: [B, T], True일 경우 마스킹
        """
        
        B, T, E = input_embedding.shape
        H = self.num_heads
        D = self.dim_per_head
        
        # (1) Attention score 계산
        a = self.ln_attention_score(input_embedding)           # [B, T, H]
        a = a.view(B, T, H, 1).transpose(1, 2)                 # → [B, H, 1, T]

        # (2) Value 벡터 계산
        v = self.ln_value(input_embedding)                     # [B, T, H * D]
        v = v.view(B, T, H, D).transpose(1, 2)                 # → [B, H, T, D]

        # (3) Mask 처리
        if mask is not None:
            # mask: [B, T] → [B, 1, 1, T] → broadcast to [B, H, 1, T]
            a = a.masked_fill(mask.unsqueeze(1).unsqueeze(2), -1e9)

        # (4) Softmax로 attention weight 계산
        a = F.softmax(a, dim=-1)                               # [B, H, 1, T]

        # (5) 가중합으로 pooling
        new_v = a.matmul(v)                                    # [B, H, 1, D]
        new_v = new_v.squeeze(2)                               # → [B, H, D]

        # (6) Head 차원 병합
        new_v = new_v.transpose(1, 2).contiguous()             # → [B, D, H]
        new_v = new_v.view(B, -1)                              # → [B, H*D]
        new_v = F.relu(new_v)                             
        new_v = self.ln_out(new_v)                             # → [B, E]

        if return_attention:
            return new_v, a.squeeze(2)                      
        else:
            return new_v

 
 
 
문장 개수(개), 단어 개수(단), 임베딩 차원(차)가 들어오면 아래와 같은 작동을 한다.
헤드 개수(해)

입력: [개, 단, 차]
↓
score:      [개, 단, 헤] → [개, 헤, 1, 단]
value:      [개, 단, 헤×차] → [개, 헤, 단, 차]
↓
Softmax(score) → Attention weights: [개, 헤, 1, 단]
↓
Attention weighted sum: [개, 헤, 1, 차] → squeeze → [개, 헤, 차]
↓
Concat heads: [개, 헤 × 차] → Linear → [개, 차]

 
 
참고:
1. Gu, Nianlong, Yingqiang Gao, and Richard HR Hahnloser. "Local citation recommendation with hierarchical-attention text encoder and scibert-based reranking." European conference on information retrieval(ECIR), 2022.
2.Yang Liu and Mirella Lapata. 2019. Hierarchical Transformers for Multi-Document Summarization. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 5070–5081, Florence, Italy. Association for Computational Linguistics.