Files
llmfs/llmfs/gpt.py

132 lines
3.6 KiB
Python

from dataclasses import dataclass
import torch
from llmfs.attn import MultiHeadAttention
@dataclass
class GPTConfig:
    """Hyper-parameters shared by the GPT building blocks in this module.

    Attributes:
        vocab_size: Number of rows in the token embedding table and number
            of logits produced by the output head.
        context_length: Number of rows in the positional embedding table;
            also forwarded to the attention layer.
        embedding_dim: Width of the token/positional embeddings and of every
            hidden representation inside the transformer blocks.
        n_heads: Intended number of attention heads. NOTE(review): no code
            visible in this file reads this field -- confirm where it is used.
        n_layers: How many transformer blocks the model stacks.
        dropout: Dropout probability used after the embeddings and inside
            each transformer block; also forwarded to the attention layer.
        qkv_bias: Flag forwarded to the attention layer (presumably enables
            bias terms on the query/key/value projections -- confirm against
            ``llmfs.attn``).
    """

    vocab_size: int
    context_length: int
    embedding_dim: int
    n_heads: int
    n_layers: int
    dropout: float
    qkv_bias: bool
class DummyTransformerBlock(torch.nn.Module):
    """Placeholder transformer block that passes its input through untouched.

    The constructor accepts a ``GPTConfig`` so it can be swapped for a real
    block, but the configuration is ignored and no parameters are created.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged (the very same tensor object)."""
        passthrough = x
        return passthrough
class GELU(torch.nn.Module):
    """GELU activation using the tanh approximation.

    Computes, element-wise::

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    The module has no parameters and preserves the shape and dtype of its
    input.
    """

    # sqrt(2 / pi), precomputed once as a plain Python float so that no
    # temporary tensor is allocated on every forward pass.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    # Cubic coefficient of the tanh approximation. The previous value,
    # 0.44715, was off by a factor of ten.
    _CUBIC_COEFF = 0.044715

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the tanh-approximated GELU to ``x``.

        Args:
            x: Input tensor of any shape.

        Returns:
            Tensor with the same shape as ``x`` holding the activations.
        """
        inner = self._SQRT_2_OVER_PI * (x + self._CUBIC_COEFF * x**3)
        return 0.5 * x * (1 + torch.tanh(inner))
class FeedForward(torch.nn.Module):
    """Position-wise feed-forward network: expand, activate, project back.

    The hidden layer is four times as wide as ``cfg.embedding_dim``; the
    output has the same width as the input.
    """

    def __init__(self, cfg: GPTConfig) -> None:
        """Build the Linear -> GELU -> Linear stack.

        Args:
            cfg: Model configuration; only ``embedding_dim`` is read.
        """
        super().__init__()
        model_width = cfg.embedding_dim
        hidden_width = 4 * model_width
        expand = torch.nn.Linear(model_width, hidden_width)
        activation = GELU()
        contract = torch.nn.Linear(hidden_width, model_width)
        self.layers = torch.nn.Sequential(expand, activation, contract)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the feed-forward stack and return the result."""
        out = self.layers(x)
        return out
class TransformerBlock(torch.nn.Module):
    """Transformer block with two residual sub-layers.

    Each sub-layer normalises its input first, applies the transformation
    (attention, then feed-forward), applies dropout, and finally adds the
    sub-layer's original input back (residual connection). One ``Dropout``
    module is shared by both sub-layers.
    """

    def __init__(self, cfg: GPTConfig) -> None:
        """Create the attention, feed-forward, norm and dropout sub-modules.

        Args:
            cfg: Model configuration supplying ``embedding_dim``,
                ``context_length``, ``dropout`` and ``qkv_bias``.
        """
        super().__init__()
        width = cfg.embedding_dim
        # Positional arguments: width is passed twice (presumably input and
        # output size -- confirm against llmfs.attn).
        # NOTE(review): cfg.n_heads is not forwarded here; verify that
        # MultiHeadAttention's signature really does not expect it.
        self.att = MultiHeadAttention(
            width,
            width,
            cfg.context_length,
            cfg.dropout,
            cfg.qkv_bias,
        )
        self.ff = FeedForward(cfg)
        self.norm1 = NormLayer(width)
        self.norm2 = NormLayer(width)
        self.dropout = torch.nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            x: Input tensor; its last dimension is normalised by the norm
                layers (presumably ``(batch, seq, embedding_dim)`` -- confirm
                against the caller).

        Returns:
            Tensor produced after both residual sub-layers.
        """
        # Attention sub-layer: norm -> attention -> dropout -> add input.
        attended = self.dropout(self.att(self.norm1(x)))
        after_attention = attended + x

        # Feed-forward sub-layer: norm -> MLP -> dropout -> add input.
        transformed = self.dropout(self.ff(self.norm2(after_attention)))
        return transformed + after_attention
class NormLayer(torch.nn.Module):
    """Layer normalisation over the last dimension with learnable affine.

    Each vector along the last dimension is shifted to zero mean and scaled
    to unit variance, then multiplied by ``scale`` and offset by ``shift``
    (initialised to ones and zeros respectively).

    The variance is the biased (population) estimate, i.e. divided by ``N``
    rather than ``N - 1``. This matches the standard definition of layer
    normalisation and stays finite when the last dimension has size 1,
    where the Bessel-corrected estimate would divide by zero and yield NaN.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the learnable scale and shift vectors.

        Args:
            dim: Size of the last dimension of the tensors to normalise.
            eps: Small constant added to the variance before the square
                root to avoid division by zero.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = torch.nn.Parameter(torch.ones(dim))
        self.shift = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last dimension and apply scale/shift.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        mean = x.mean(-1, keepdim=True)
        # Biased estimator (divide by N), as in standard layer normalisation.
        var = x.var(-1, keepdim=True, unbiased=False)
        # Makes mean = 0 and (population) variance = 1 along the last dim.
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
class DummyGPT(torch.nn.Module):
    """GPT-style model: embeddings, a stack of transformer blocks, and a head.

    Token ids are embedded and summed with learned positional embeddings,
    passed through dropout and ``config.n_layers`` transformer blocks,
    normalised, and projected to ``config.vocab_size`` logits per position.
    """

    def __init__(self, config: GPTConfig):
        """Build all sub-modules from ``config``.

        Args:
            config: Model configuration supplying vocabulary size, context
                length, embedding width, layer count and dropout rate.
        """
        super().__init__()
        width = config.embedding_dim
        self.tok_embedding = torch.nn.Embedding(config.vocab_size, width)
        self.pos_embedding = torch.nn.Embedding(config.context_length, width)
        self.drop_emb = torch.nn.Dropout(config.dropout)

        blocks = [TransformerBlock(config) for _ in range(config.n_layers)]
        self.trf_blocks = torch.nn.Sequential(*blocks)

        self.final_norm = NormLayer(width)
        # Output projection to vocabulary logits, without a bias term.
        self.out_head = torch.nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 2-D tensor of token ids to per-position vocabulary logits.

        Args:
            x: Two-dimensional tensor of token indices; the second dimension
                is the sequence length (the first is presumably the batch --
                confirm against the caller).

        Returns:
            Output of the final linear head applied to every position.

        Raises:
            ValueError: If ``x`` does not have exactly two dimensions
                (raised by the shape unpacking).
        """
        _batch, n_tokens = x.shape
        token_vecs = self.tok_embedding(x)
        # Position indices are created on the same device as the input ids.
        positions = torch.arange(n_tokens, device=x.device)
        hidden = token_vecs + self.pos_embedding(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        logits = self.out_head(hidden)
        return logits