93 lines
3.4 KiB
Python
93 lines
3.4 KiB
Python
import torch
|
|
|
|
|
|
class SelfAttention(torch.nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        ctx_len: Maximum supported sequence length; sizes the causal mask.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the query/key/value projections use a bias term.
    """

    mask: torch.Tensor

    def __init__(
        self,
        d_in: int,
        d_out: int,
        ctx_len: int,
        dropout: float,
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.w_query = torch.nn.Linear(d_in, d_out, bias=bias)
        self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
        self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
        # Not in-place: autograd saves the softmax output for its backward
        # pass, so mutating it in place would break training.
        self.dropout = torch.nn.Dropout(dropout)
        # True strictly above the diagonal: position i must not attend to j > i.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
        )

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Compute causal self-attention context vectors.

        Args:
            data: Tensor of shape ``(..., num_tokens, d_in)`` with
                ``num_tokens <= ctx_len``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        num_tokens = data.shape[-2]
        queries = self.w_query(data)
        keys = self.w_key(data)
        values = self.w_val(data)

        attn_scores = queries @ keys.transpose(-2, -1)
        # Slice the mask so sequences shorter than ctx_len are supported, and
        # keep the returned tensor (masked_fill is not in-place).
        attn_scores = attn_scores.masked_fill(
            self.mask[:num_tokens, :num_tokens],
            -torch.inf,
        )
        attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values  # Context vectors
|
|
|
|
|
|
class MultiHeadAttention(torch.nn.Module):
    """Causal multi-head scaled dot-product self-attention.

    The ``d_out``-sized projections are split evenly into ``num_heads`` heads
    of size ``d_out // num_heads``. The heads' outputs are concatenated and
    passed through a final linear projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total size of the query/key/value projections and of the output.
        ctx_len: Maximum supported sequence length; sizes the causal mask.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the query/key/value projections use a bias term.
        num_heads: Number of attention heads.

    Raises:
        RuntimeError: If ``num_heads`` is not positive or ``d_out`` is not
            divisible by ``num_heads``.
    """

    mask: torch.Tensor

    def __init__(
        self,
        d_in: int,
        d_out: int,
        ctx_len: int,
        dropout: float,
        bias: bool = False,
        num_heads: int = 2,
    ) -> None:
        super().__init__()
        if num_heads <= 0:
            raise RuntimeError("Number of heads should be positive")
        if d_out % num_heads != 0:
            raise RuntimeError(
                "Output dimention should be divisible by number of heads",
            )
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.w_query = torch.nn.Linear(d_in, d_out, bias=bias)
        self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
        self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
        self.out_proj = torch.nn.Linear(d_out, d_out)
        # Not in-place: autograd saves the softmax output for its backward
        # pass, so mutating it in place would break training.
        self.dropout = torch.nn.Dropout(dropout)
        # True strictly above the diagonal: position i must not attend to j > i.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
        )

    def _split_heads(
        self, proj: torch.Tensor, batches: int, num_tokens: int
    ) -> torch.Tensor:
        """Reshape (batches, tokens, d_out) -> (batches, heads, tokens, head_dim)."""
        return proj.view(
            batches, num_tokens, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Compute causal multi-head attention.

        Args:
            data: Tensor of shape ``(batches, num_tokens, d_in)`` with
                ``num_tokens <= ctx_len``.

        Returns:
            Tensor of shape ``(batches, num_tokens, d_out)``.
        """
        batches, num_tokens, _ = data.shape
        queries = self._split_heads(self.w_query(data), batches, num_tokens)
        keys = self._split_heads(self.w_key(data), batches, num_tokens)
        values = self._split_heads(self.w_val(data), batches, num_tokens)

        # (b, h, t, hd) @ (b, h, hd, t) -> (b, h, t, t)
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(
            self.mask[:num_tokens, :num_tokens],
            -torch.inf,
        )
        # Each score is a dot product over head_dim elements, so scale by
        # sqrt(head_dim), not sqrt(d_out).
        attn_weights = (attn_scores / self.head_dim ** 0.5).softmax(dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, h, t, t) @ (b, h, t, hd) -> (b, h, t, hd) -> (b, t, h, hd)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # Merge heads back: (b, t, h, hd) -> (b, t, d_out)
        context_vec = context_vec.reshape(batches, num_tokens, self.d_out)
        return self.out_proj(context_vec)
|