import math
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # One frequency per pair of dimensions: theta^(-2i/dim) for i in [0, dim/2)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(seq_len, device=freqs.device)
    # Outer product gives one rotation angle per (position, frequency) pair
    freqs = torch.outer(t, freqs).float()
    # Complex numbers on the unit circle: cos(angle) + i * sin(angle)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # View the last dimension as (dim // 2) pairs and reinterpret each pair as a complex number
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    # Multiplying by a unit complex number rotates each pair by its position-dependent angle
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
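A quick, self-contained sanity check of the two helpers above (the sizes and variable names here are illustrative choices, not from the original code). It confirms the expected shapes and that the rotation leaves vector norms unchanged:

# Shape check for the RoPE helpers (dim, seq_len, bsz are arbitrary demo values).
dim, seq_len, bsz = 64, 16, 2
freqs_cis = precompute_freqs_cis(dim, seq_len)   # (seq_len, dim // 2), complex64
xq = torch.randn(bsz, seq_len, dim)
xk = torch.randn(bsz, seq_len, dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
print(freqs_cis.shape)                # torch.Size([16, 32])
print(xq_rot.shape, xk_rot.shape)     # torch.Size([2, 16, 64]) torch.Size([2, 16, 64])
# Each 2-D pair is only rotated, never scaled, so per-token norms are preserved
print(torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-4))  # True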
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)
        # Precompute the rotation table once, for up to max_seq_len * 2 positions
        self.freqs_cis = precompute_freqs_cis(args.dim, args.max_seq_len * 2)

    def forward(self, x: torch.Tensor):
        bsz, seqlen, _ = x.shape
        # Project the input into queries, keys, and values
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.dim)
        xk = xk.view(bsz, seqlen, self.dim)
        xv = xv.view(bsz, seqlen, self.dim)

        # Rotate queries and keys by their absolute positions before the dot product
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[:seqlen])

        # Scaled dot-product attention
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(self.dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)
        return output
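The reason queries and keys are rotated before the dot product is that the resulting attention score then depends only on the relative offset between the two positions, not on their absolute values. A minimal sketch of that property, assuming the helpers defined above (the rotate_at helper and all sizes are invented for this demo):

# RoPE property check: the q·k score depends only on the relative offset m - n.
dim, seq_len = 64, 32
freqs_cis = precompute_freqs_cis(dim, seq_len)
q = torch.randn(1, 1, dim)  # a single query vector
k = torch.randn(1, 1, dim)  # a single key vector

def rotate_at(x, pos):
    # Rotate one vector as if it sat at absolute position `pos` (demo-only helper).
    rotated, _ = apply_rotary_emb(x, x, freqs_cis[pos : pos + 1])
    return rotated

# Same relative offset (3) at different absolute positions gives the same score.
s1 = (rotate_at(q, 10) * rotate_at(k, 7)).sum()
s2 = (rotate_at(q, 23) * rotate_at(k, 20)).sum()
print(torch.allclose(s1, s2, atol=1e-4))  # True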