Compare commits

...

1 Commits

Author SHA1 Message Date
cbdd32faaa Added DummyGPT. 2026-05-01 12:56:05 +02:00
3 changed files with 51 additions and 7 deletions

View File

@@ -3,11 +3,23 @@ from pathlib import Path
 import torch
 from llmfs.attn import MultiHeadAttention
 from llmfs.datasets.v1 import GPTDataSetV1
+from llmfs.gpt import GPTConfig
 from llmfs.tokenizers import BPETokenizer

 DATA_DIR = Path(__file__).parent.parent / "data"

+GPT_CONFIG_124M = GPTConfig(
+    vocab_size=50257,
+    context_length=1024,
+    embedding_dim=768,
+    n_heads=12,
+    n_layers=12,
+    dropout=0.1,
+    qkv_bias=False,
+)

 def process_text(text: str):
     tokenizer = BPETokenizer.build(text)
     vocab_size = tokenizer.max_token_value + 1
@@ -48,7 +60,7 @@ def attn_test():
     batch = torch.stack((inps, inps), dim=0)
     attn = MultiHeadAttention(
         inps.shape[1],
-        4,
+        8,
         inps.shape[0],
         dropout=True,
         num_heads=2,
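Side note on the changed call above: the second positional argument of MultiHeadAttention (presumably d_out, going by the constructor shown in the next file) moves from 4 to 8 while num_heads stays at 2. A minimal sketch of the constraint this satisfies, assuming head_dim is derived as d_out // num_heads (that derivation is not itself visible in this diff):

d_out, num_heads = 8, 2        # values from the updated attn_test() call
assert d_out % num_heads == 0  # each head must get an integer slice of d_out
head_dim = d_out // num_heads  # 4 features per head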

View File

@@ -1,4 +1,3 @@
-from typing import final
 import torch
@@ -15,7 +14,7 @@ class SelfAttention(torch.nn.Module):
         self.w_query = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
-        self.dropout = torch.nn.Dropout(inplace=True)
+        self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.register_buffer(
             "mask",
             torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
@@ -54,14 +53,14 @@ class MultiHeadAttention(torch.nn.Module):
         self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
         self.out_proj = torch.nn.Linear(d_out, d_out)
-        self.dropout = torch.nn.Dropout(inplace=True)
+        self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1),
+            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
         )

     def forward(self, data: torch.Tensor) -> torch.Tensor:
-        batches, num_tokens, _features = data.shape
+        batches, num_tokens, _ = data.shape
         queries = self.w_query(data)
         keys = self.w_key(data)
         values = self.w_val(data)
@@ -71,16 +70,22 @@
         values_v = values.view(
             batches, num_tokens, self.num_heads, self.head_dim
         ).transpose(1, 2)
+        # 2, 6, 2, 4 -> 2, 2, 6, 4
+        # (batches, tokens, heads, head_dim) -> (batches, heads, tokens, head_dim)
         queries_v = queries.view(
             batches, num_tokens, self.num_heads, self.head_dim
         ).transpose(1, 2)
+        # (2, 2, 6, 4) @ (2, 2, 4, 6) -> (2, 2, 6, 6)
+        # (batches, heads, tokens, head_dim) @ (batches, heads, head_dim, tokens) -> (batches, heads, tokens, tokens)
         attn_scores = queries_v @ keys_v.transpose(2, 3)
         attn_scores.masked_fill_(
-            self.mask.bool()[:num_tokens, :num_tokens],
+            self.mask[:num_tokens, :num_tokens],
             -torch.inf,
         )
         attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
         self.dropout(attn_weights)
+        # (2, 2, 6, 6) @ (2, 2, 6, 4) -> (2, 2, 6, 4)
+        # (2, 2, 6, 4) -> T(1, 2) -> (2, 6, 2, 4)
         context_vec = (attn_weights @ values_v).transpose(1, 2)
         context_vec = context_vec.contiguous().view(batches, num_tokens, self.d_out)
         context_vec = self.out_proj(context_vec)
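Two behaviour-affecting fixes are visible in this file. First, torch.nn.Dropout previously received no probability and so silently used PyTorch's default p=0.5; it now gets the dropout argument (and keeps inplace=True, which is why the bare self.dropout(attn_weights) call still has an effect). Second, the causal mask buffer is registered as bool up front, so masked_fill_ can consume it directly instead of converting on every forward pass. A standalone sketch of the same pattern, independent of the repo's classes:

import torch

p, ctx_len = 0.1, 4
dropout = torch.nn.Dropout(p, inplace=True)  # omitting p would mean p=0.5
# Upper-triangular bool mask: True marks future positions to hide.
mask = torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool()
scores = torch.randn(ctx_len, ctx_len)
scores.masked_fill_(mask[:ctx_len, :ctx_len], -torch.inf)  # bool mask indexes directly
weights = scores.softmax(dim=-1)  # each row attends only to itself and earlier positions
dropout(weights)                  # inplace=True mutates weights without reassignment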

llmfs/gpt.py Normal file (27 additions)
View File

@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class GPTConfig:
+    vocab_size: int
+    context_length: int
+    embedding_dim: int
+    n_heads: int
+    n_layers: int
+    dropout: float
+    qkv_bias: bool
+
+
+class DummyGPT:
+    def __init__(self, config: GPTConfig):
+        self.tok_embedding = torch.nn.Embedding(
+            config.vocab_size,
+            config.embedding_dim,
+        )
+        self.pos_embedding = torch.nn.Embedding(
+            config.context_length,
+            config.embedding_dim,
+        )
+        self.dropout = torch.nn.Dropout(config.dropout)
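For orientation, a minimal usage sketch of the new module. It assumes DummyGPT is imported from llmfs.gpt alongside GPTConfig and that token and position embeddings are meant to be summed as in the usual GPT setup; the committed class is still a plain Python class with no forward method, so the combination step below is an assumption, not code from this commit:

import torch

from llmfs.gpt import DummyGPT, GPTConfig

config = GPTConfig(
    vocab_size=50257,
    context_length=1024,
    embedding_dim=768,
    n_heads=12,
    n_layers=12,
    dropout=0.1,
    qkv_bias=False,
)
model = DummyGPT(config)
token_ids = torch.randint(0, config.vocab_size, (2, 6))      # hypothetical (batch, tokens) input
tok = model.tok_embedding(token_ids)                         # (2, 6, 768)
pos = model.pos_embedding(torch.arange(token_ids.shape[1]))  # (6, 768), broadcasts over the batch
x = model.dropout(tok + pos)                                 # (2, 6, 768)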