Multi Query Attention & Group Query Attention
Swift Lv6

Multi Query Attention(MQA)在2019年就被提出来了,用于推理加速,但在当时并没有受到很多关注,毕竟一张2080就能跑Bert-base了。随着LLM的大火,MQA所带来的收益得以放大。

思路

Multi Query Attention(MQA)跟Multi Head Attention(MHA)只有一词之差,但其思路非常简单,几乎跟MHA一致:

model

MHA的Query、Key、Value分拆成8个头,每个头进行self-attention运算,而MQA是Query分成8个头,每个头共享一组Key和Value

1
2
3
4
5
6
7
8
9
10
11
MHA: Q, K, V = (512, 768), # seq_len, hidden_dim
拆成8个头:
Q : (8, 512, 96)
k, v: (8, 512, 96)
MQA:
Q -> (512, 768)
K -> (512, 96)
v -> (512, 96)
把Q拆成8个头:
Q: (8, 512, 96)
K, V:(512, 96)

代码实现

  • MHA

    1
    2
    3
    4
    5
    6
    7
    ...
    self.Wqkv = nn.Linear(
    d_model,
    d_model * 3,
    device=device,
    )
    ...

    d_model * 3 拆成3个768维

  • MQA

    1
    2
    3
    4
    5
    6
    7
    ...
    self.Wqkv = nn.Linear(
    d_model,
    d_model + 2 * self.head_dim,
    device=device,
    )
    ...

    d_model + 2 * self.head_dim 拆成1个768维 + 2个96维

可以看到参数数量大幅减少。

实验结果

实验指标略微降低,但推理加速非常明显。

result

Group Query Attention

Q拆分成8个头,K和V分别拆成4个头,然后对应进行attention运算。


参考

Powered by Hexo & Theme Keep
Unique Visitor Page View