Search
✴️

Grouped Query Attention (Llama3)

Category
BlogPost
Venue
Backbone
Text
PPT

Problem in Multi-Query Attention

Query head는 여러개 사용하고 Single Key, Value head를 사용하는 MQA는 training시의 instability와 quality degradation을 야기 (representation 한계)
이를 위해 GQA는 Multi-Head Attention으로 학습된 LM에 Grouped Query Attention으로 Continual Learning해도 (1) 성능 저하 없이 (2) Efficiency를 향상시킬 수 있는 GQA 방법론을 제시한다.

Grouped Query Attention

#### What is GQA
GQA는 MHA와 MQA의 interpolation이라고 논문에서 주장한다.
GQA는 기존의 H개의 Query heads를 G groups으로 나누는데, G groups내에서는 single key head와 value head가 존재하도록 한다.
GQA-1는 MQA랑 동일하고, GQA-H는 MHA랑 동일하다.
그렇다면 GQA-G에서 H/G개로 묶인 key와 value head representation은 어떻게 만들어야 할까?
(이 논문은 MHA pretrained LM을 GQA로 continual learning하는 상황을 상정하기에)
아래의 방법론들을 고려해보았고 continual pre-trainined setting에서느 2가 가장 좋았다고 한다.
(Llama3부터는 GQA로 from the scratch pre-training을 하는듯!)
1.
random initialized
2.
mean pooling
GQA는 어떤 장점이 있는걸까?
1.
LM의 사이즈를 키울수록 head 개수도 증가하게 되는데, 이는 곧 key-value 캐시의 크기와 로드해야 하는 데이터의 양 모델 크기가 증가한다는 것을 의미하고 GQA는 이를 줄이는 방법론이기에 비례적으로 대역폭과 용량을 감소시킬 수 있다는 장점이 존재한다.
2.
Popeet al., 2022에 따르면 standard sharding에서는 single key와 value를 model partition마다 복사를 해줘야한다는 한계가 있지만 GQA는 이를 완화해준다고 합니다. (아래의 코드가 관련 해답을 찾아줍니다!)
#### CODE
class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights
Python
복사
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
head 개수로 projection
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
H/G 개수로 projection
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
H/G 개수로 projection
num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
나중에 viewing한 group 수
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights
Python
복사
repeat (expand를 활용한 차원 맞추기)
초기 KV 상태: (batch, num_kv_heads, seq_len, head_dim)
None 추가 후: (batch, num_kv_heads, 1, seq_len, head_dim)
hidden_states[:, :, None, :, :]
expand 후: (batch, num_kv_heads, module.num_key_value_groups, seq_len, head_dim)
.expand(batch, num_key_value_heads, n_rep, slen, head_dim)
reshape 후: (batch, num_kv_heads * module.num_key_value_groups, seq_len, head_dim)
hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
expand는:
실제 메모리 할당 없이 stride 조작으로 구현
동일한 메모리 위치를 여러 번 참조하는 view를 생성
메모리 사용량 증가 없이 차원만 확장
참고 그림 (1)
참고 그림 (2)
→ 위에서도 설명했지만 조금 더 구체적으로 설명하면,
None을 추가한 차원(세 번째 차원)이 n_rep만큼 확장될 때
hidden_states[:, :, None, :, :]
그 아래 차원들(seq_len, head_dim)에 대해서는
.expand(batch, num_key_value_heads, n_rep, slen, head_dim)
num_kv_heads 각각의 데이터를 n_rep만큼 포인팅
expand로 생긴 텐서는 사실상 텐서인데 데이터 자체보다는 다른 데이터를 참조 ⇒ PyTorch에서는 view라고 부름.
eager_attention_forward
attention_interface: Callable에서 함수 정의
별 조건 없으면 ‘attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]’에서 attention function으로 설정
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling로 차원 문제 없이 attention 연산 수행 가능

References