Added MultiHeadSelfAttention.

This commit is contained in:
Pavel Kirilin
2026-04-30 03:46:35 +02:00
parent 7a67050b77
commit ac8852d25e
9 changed files with 691 additions and 27 deletions

87
llmfs/attn.py Normal file
View File

@ -0,0 +1,87 @@
from typing import final
import torch
class SelfAttention(torch.nn.Module):
def __init__(
self,
d_in: int,
d_out: int,
ctx_len: int,
dropout: float,
bias: bool = False,
) -> None:
super().__init__()
self.w_query = torch.nn.Linear(d_in, d_out, bias=bias)
self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
self.dropout = torch.nn.Dropout(inplace=True)
self.register_buffer(
"mask",
torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
)
def forward(self, data: torch.Tensor) -> torch.Tensor:
queries = self.w_query(data)
keys = self.w_key(data)
values = self.w_val(data)
attn_scores = queries @ keys.transpose(1, 2)
attn_scores.masked_fill(self.mask, -torch.inf)
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(-1)
self.dropout(attn_weights)
return attn_weights @ values # Context vec
class MultiHeadAttention(torch.nn.Module):
def __init__(
self,
d_in: int,
d_out: int,
ctx_len: int,
dropout: float,
bias: bool = False,
num_heads: int = 2,
) -> None:
super().__init__()
if d_out % num_heads != 0:
raise RuntimeError(
"Output dimention should be divisible by number of heads",
)
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.w_query = torch.nn.Linear(d_in, d_out, bias=bias)
self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
self.out_proj = torch.nn.Linear(d_out, d_out)
self.dropout = torch.nn.Dropout(inplace=True)
self.register_buffer(
"mask",
torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1),
)
def forward(self, data: torch.Tensor) -> torch.Tensor:
batches, num_tokens, _features = data.shape
queries = self.w_query(data)
keys = self.w_key(data)
values = self.w_val(data)
keys_v = keys.view(
batches, num_tokens, self.num_heads, self.head_dim
).transpose(1, 2)
values_v = values.view(
batches, num_tokens, self.num_heads, self.head_dim
).transpose(1, 2)
queries_v = queries.view(
batches, num_tokens, self.num_heads, self.head_dim
).transpose(1, 2)
attn_scores = queries_v @ keys_v.transpose(2, 3)
attn_scores.masked_fill_(
self.mask.bool()[:num_tokens, :num_tokens],
-torch.inf,
)
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
self.dropout(attn_weights)
context_vec = (attn_weights @ values_v).transpose(1, 2)
context_vec = context_vec.contiguous().view(batches, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec)
return context_vec