Problem in Multi-Query Attention
•
Query head는 여러개 사용하고 Single Key, Value head를 사용하는 MQA는 training시의 instability와 quality degradation을 야기 (representation 한계)
•
이를 위해 GQA는 Multi-Head Attention으로 학습된 LM에 Grouped Query Attention으로 Continual Learning해도 (1) 성능 저하 없이 (2) Efficiency를 향상시킬 수 있는 GQA 방법론을 제시한다.
Grouped Query Attention
#### What is GQA
•
GQA는 MHA와 MQA의 interpolation이라고 논문에서 주장한다.
•
GQA는 기존의 H개의 Query heads를 G groups으로 나누는데, G groups내에서는 single key head와 value head가 존재하도록 한다.
◦
GQA-1는 MQA랑 동일하고, GQA-H는 MHA랑 동일하다.
◦
그렇다면 GQA-G에서 H/G개로 묶인 key와 value head representation은 어떻게 만들어야 할까?
(이 논문은 MHA pretrained LM을 GQA로 continual learning하는 상황을 상정하기에)
아래의 방법론들을 고려해보았고 continual pre-trainined setting에서느 2가 가장 좋았다고 한다.
(Llama3부터는 GQA로 from the scratch pre-training을 하는듯!)
1.
random initialized
2.
mean pooling
•
GQA는 어떤 장점이 있는걸까?
1.
LM의 사이즈를 키울수록 head 개수도 증가하게 되는데, 이는 곧 key-value 캐시의 크기와 로드해야 하는 데이터의 양 모델 크기가 증가한다는 것을 의미하고 GQA는 이를 줄이는 방법론이기에 비례적으로 대역폭과 용량을 감소시킬 수 있다는 장점이 존재한다.
2.
Popeet al., 2022에 따르면 standard sharding에서는 single key와 value를 model partition마다 복사를 해줘야한다는 한계가 있지만 GQA는 이를 완화해준다고 합니다. (아래의 코드가 관련 해답을 찾아줍니다!)
#### CODE
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
Python
복사
•
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
◦
head 개수로 projection
•
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
◦
H/G 개수로 projection
•
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
◦
H/G 개수로 projection
•
num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
◦
나중에 viewing한 group 수
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
Python
복사
•
repeat (expand를 활용한 차원 맞추기)
◦
초기 KV 상태: (batch, num_kv_heads, seq_len, head_dim)
◦
None 추가 후: (batch, num_kv_heads, 1, seq_len, head_dim)
▪
hidden_states[:, :, None, :, :]
◦
expand 후: (batch, num_kv_heads, module.num_key_value_groups, seq_len, head_dim)
▪
.expand(batch, num_key_value_heads, n_rep, slen, head_dim)
◦
reshape 후: (batch, num_kv_heads * module.num_key_value_groups, seq_len, head_dim)
▪
hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
⇒ expand는:
◦
실제 메모리 할당 없이 stride 조작으로 구현
◦
동일한 메모리 위치를 여러 번 참조하는 view를 생성
◦
메모리 사용량 증가 없이 차원만 확장
◦
참고 그림 (1)
◦
참고 그림 (2)
→ 위에서도 설명했지만 조금 더 구체적으로 설명하면,
▪
None을 추가한 차원(세 번째 차원)이 n_rep만큼 확장될 때
hidden_states[:, :, None, :, :]
▪
그 아래 차원들(seq_len, head_dim)에 대해서는
.expand(batch, num_key_value_heads, n_rep, slen, head_dim)
▪
num_kv_heads 각각의 데이터를 n_rep만큼 포인팅
▪
expand로 생긴 텐서는 사실상 텐서인데 데이터 자체보다는 다른 데이터를 참조 ⇒ PyTorch에서는 view라고 부름.
•
eager_attention_forward
◦
attention_interface: Callable에서 함수 정의
▪
별 조건 없으면 ‘attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]’에서 attention function으로 설정
◦
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling로 차원 문제 없이 attention 연산 수행 가능