diff --git a/llmfs/__main__.py b/llmfs/__main__.py
index 36d1fbc..3cf0fe0 100644
--- a/llmfs/__main__.py
+++ b/llmfs/__main__.py
@@ -3,11 +3,23 @@ from pathlib import Path
 import torch
 
 from llmfs.attn import MultiHeadAttention
 from llmfs.datasets.v1 import GPTDataSetV1
+from llmfs.gpt import GPTConfig
 from llmfs.tokenizers import BPETokenizer
 
 DATA_DIR = Path(__file__).parent.parent / "data"
 
+GPT_CONFIG_124M = GPTConfig(
+    vocab_size=50257,
+    context_length=1024,
+    embedding_dim=768,
+    n_heads=12,
+    n_layers=12,
+    dropout=0.1,
+    qkv_bias=False,
+)
+
+
 def process_text(text: str):
     tokenizer = BPETokenizer.build(text)
     vocab_size = tokenizer.max_token_value + 1
@@ -48,7 +60,7 @@ def attn_test():
     batch = torch.stack((inps, inps), dim=0)
     attn = MultiHeadAttention(
         inps.shape[1],
-        4,
+        8,
         inps.shape[0],
-        dropout=True,
+        dropout=0.1,
         num_heads=2,
diff --git a/llmfs/attn.py b/llmfs/attn.py
index 7a3a223..80a3d16 100644
--- a/llmfs/attn.py
+++ b/llmfs/attn.py
@@ -1,4 +1,3 @@
-from typing import final
 import torch
 
 
@@ -15,7 +14,7 @@ class SelfAttention(torch.nn.Module):
         self.w_query = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
-        self.dropout = torch.nn.Dropout(inplace=True)
+        self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.register_buffer(
             "mask",
             torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
@@ -54,14 +53,14 @@ class MultiHeadAttention(torch.nn.Module):
         self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
         self.out_proj = torch.nn.Linear(d_out, d_out)
-        self.dropout = torch.nn.Dropout(inplace=True)
+        self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1),
+            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
         )
 
     def forward(self, data: torch.Tensor) -> torch.Tensor:
-        batches, num_tokens, _features = data.shape
+        batches, num_tokens, _ = data.shape
         queries = self.w_query(data)
         keys = self.w_key(data)
         values = self.w_val(data)
@@ -71,16 +70,22 @@ class MultiHeadAttention(torch.nn.Module):
         values_v = values.view(
             batches, num_tokens, self.num_heads, self.head_dim
         ).transpose(1, 2)
+        # (2, 6, 2, 4) -> (2, 2, 6, 4)
+        # (batches, tokens, heads, head_dim) -> (batches, heads, tokens, head_dim)
         queries_v = queries.view(
            batches, num_tokens, self.num_heads, self.head_dim
        ).transpose(1, 2)
+        # (2, 2, 6, 4) @ (2, 2, 4, 6) -> (2, 2, 6, 6)
+        # (batches, heads, tokens, head_dim) @ (batches, heads, head_dim, tokens) -> (batches, heads, tokens, tokens)
         attn_scores = queries_v @ keys_v.transpose(2, 3)
         attn_scores.masked_fill_(
-            self.mask.bool()[:num_tokens, :num_tokens],
+            self.mask[:num_tokens, :num_tokens],
             -torch.inf,
         )
-        attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
+        attn_weights = (attn_scores / keys_v.shape[-1] ** 0.5).softmax(dim=-1)
         self.dropout(attn_weights)
+        # (2, 2, 6, 6) @ (2, 2, 6, 4) -> (2, 2, 6, 4)
+        # then transpose(1, 2): (2, 2, 6, 4) -> (2, 6, 2, 4), i.e. (batches, tokens, heads, head_dim)
         context_vec = (attn_weights @ values_v).transpose(1, 2)
         context_vec = context_vec.contiguous().view(batches, num_tokens, self.d_out)
         context_vec = self.out_proj(context_vec)
diff --git a/llmfs/gpt.py b/llmfs/gpt.py
new file mode 100644
index 0000000..a1fbbc2
--- /dev/null
+++ b/llmfs/gpt.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class GPTConfig:
+    vocab_size: int
+    context_length: int
+    embedding_dim: int
+    n_heads: int
+    n_layers: int
+    dropout: float
+    qkv_bias: bool
+
+
+class DummyGPT(torch.nn.Module):
+    def __init__(self, config: GPTConfig):
+        super().__init__()
+        self.tok_embedding = torch.nn.Embedding(
+            config.vocab_size,
+            config.embedding_dim,
+        )
+        self.pos_embedding = torch.nn.Embedding(
+            config.context_length,
+            config.embedding_dim,
+        )
+        self.dropout = torch.nn.Dropout(config.dropout)
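
For reference, a quick sanity check of the MultiHeadAttention changes above. This is a minimal sketch, assuming the constructor signature implied by attn_test (d_in, d_out, ctx_len, then dropout and num_heads keywords, with any remaining parameters left at their defaults) and an arbitrary d_in of 3; the printed shape matches the (batches, tokens, d_out) layout the forward() comments describe.

import torch

from llmfs.attn import MultiHeadAttention

inps = torch.rand(6, 3)                    # 6 tokens, 3 input features (d_in=3 is arbitrary)
batch = torch.stack((inps, inps), dim=0)   # (2, 6, 3)
attn = MultiHeadAttention(
    inps.shape[1],   # d_in = 3
    8,               # d_out = 8, split into num_heads * head_dim = 2 * 4
    inps.shape[0],   # ctx_len = 6
    dropout=0.1,
    num_heads=2,
)
print(attn(batch).shape)  # torch.Size([2, 6, 8]): (batches, tokens, d_out)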
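
And a sketch of how the new gpt.py pieces might compose. DummyGPT has no forward() yet, so the usual GPT input pipeline (token embeddings plus positional embeddings, then dropout) is wired up by hand here; that composition is an assumption about where the class is headed, not part of this diff.

import torch

from llmfs.gpt import DummyGPT, GPTConfig

config = GPTConfig(
    vocab_size=50257,
    context_length=1024,
    embedding_dim=768,
    n_heads=12,
    n_layers=12,
    dropout=0.1,
    qkv_bias=False,
)
model = DummyGPT(config)

token_ids = torch.randint(0, config.vocab_size, (2, 6))      # (batches, tokens)
tok = model.tok_embedding(token_ids)                         # (2, 6, 768)
pos = model.pos_embedding(torch.arange(token_ids.shape[1]))  # (6, 768), broadcasts over batches
x = model.dropout(tok + pos)                                 # (2, 6, 768)
print(x.shape)  # torch.Size([2, 6, 768])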