旋转位置编码
Swift Lv6

旋转位置编码具有良好的外推性,即模型在预测时可以处理比训练时更长的序列。

想要获得良好的外推性,必须使用相对位置编码。Transformer使用的是绝对位置编码,外推性不强。

pos

那么,如何使用绝对位置编码来实现相对位置编码呢?

推导过程

欧拉公式:$e^{i x}=\cos x+i \sin x$

其中,$m$ 就是位置下标,$\theta_j=10000^{-2(j-1) / d}, j \in[1,2, \ldots, d / 2]$,跟transformer基本一致。

https://zhuanlan.zhihu.com/p/642884818

下面这是极简的证明:

prove

源码解读

LLaMA的官方代码并不是直接乘以一个旋转矩阵,而是利用复数乘法性质来实现RoPE。我们的目标是对 $x_m$ 添加位置编码,即:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
# 计算词向量元素两两分组之后,每组元素对应的旋转角度
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 生成 token 序列索引 t = [0, 1,..., seq_len-1]
t = torch.arange(seq_len, device=freqs.device)
# freqs.shape = [seq_len, dim // 2]
freqs = torch.outer(t, freqs).float()
# torch.polar 的文档
# https://pytorch.org/docs/stable/generated/torch.polar.html
# 计算结果是个复数向量:e^{im\theta}
# polar(abs, angle, *, out=None) -> Tensor: abs是幅值,angle是相位角
# 假设 freqs = [x, y]
# 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis

def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# xq.shape = [batch_size, seq_len, dim]
# xq_.shape = [batch_size, seq_len, dim // 2, 2]
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)

# 转为复数域:q_m^{(1)} + iq_m^{(2)}
xq_ = torch.view_as_complex(xq_)
xk_ = torch.view_as_complex(xk_)

# 应用旋转操作,然后将结果转回实数域
# xq_out.shape = [batch_size, seq_len, dim]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()

self.wq = Linear(...)
self.wk = Linear(...)
self.wv = Linear(...)

self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

def forward(self, x: torch.Tensor):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(batch_size, seq_len, dim)
xk = xk.view(batch_size, seq_len, dim)
xv = xv.view(batch_size, seq_len, dim)

# attention 操作之前,应用旋转位置编码
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

# scores.shape = (bs, seqlen, seqlen)
scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, xv) # (batch_size, seq_len, dim)
# ......

参考

Powered by Hexo & Theme Keep
Unique Visitor Page View