132 lines
3.6 KiB
Python
132 lines
3.6 KiB
Python
from dataclasses import dataclass
|
|
|
|
import torch
|
|
|
|
from llmfs.attn import MultiHeadAttention
|
|
|
|
|
|
@dataclass
class GPTConfig:
    """Hyperparameters read by the model classes defined in this module.

    Fields are positional in the order declared below, so the order is part
    of the constructor's interface.
    """

    # Row count of the token embedding table and width of the output logits.
    vocab_size: int
    # Row count of the positional embedding table; also passed to attention.
    context_length: int
    # Width of every embedding / hidden vector in the model.
    embedding_dim: int
    # Number of attention heads.
    # NOTE(review): not read by any class visible in this file - confirm usage.
    n_heads: int
    # How many TransformerBlock instances are stacked.
    n_layers: int
    # Probability handed to torch.nn.Dropout and to the attention module.
    dropout: float
    # Forwarded to the attention module as its last positional argument.
    qkv_bias: bool
|
|
|
|
|
|
class DummyTransformerBlock(torch.nn.Module):
    """Placeholder block that passes its input through untouched."""

    def __init__(self, config: GPTConfig):
        """Set up the module; ``config`` is accepted but never read."""
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` itself - no computation and no copy."""
        passthrough = x
        return passthrough
|
|
|
|
|
|
class GELU(torch.nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes, element-wise::

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    which is the same formula as ``torch.nn.GELU(approximate="tanh")``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation to every element of ``x``.

        Args:
            x: Input tensor of any shape.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # A plain Python float avoids allocating a tensor on every call and
        # lets the constant follow the dtype/device of ``x``.
        sqrt_two_over_pi = (2.0 / torch.pi) ** 0.5
        # The cubic coefficient is 0.044715 (not 0.44715).
        inner = sqrt_two_over_pi * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))
|
|
|
|
|
|
class FeedForward(torch.nn.Module):
    """Two-layer MLP: expand to 4x the embedding width, GELU, project back."""

    def __init__(self, cfg: GPTConfig) -> None:
        """Build the Linear -> GELU -> Linear stack.

        Args:
            cfg: Only ``cfg.embedding_dim`` is read; the hidden layer is four
                times that width.
        """
        super().__init__()
        width = cfg.embedding_dim
        hidden_width = 4 * width
        # Sub-modules are created in the same order as they are applied.
        expand = torch.nn.Linear(width, hidden_width)
        activation = GELU()
        project = torch.nn.Linear(hidden_width, width)
        self.layers = torch.nn.Sequential(expand, activation, project)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the stack; the last dimension keeps its size."""
        return self.layers(x)
|
|
|
|
|
|
class TransformerBlock(torch.nn.Module):
    """Pre-norm transformer block: attention and feed-forward sub-layers.

    Each sub-layer is applied as ``x + dropout(sublayer(norm(x)))``.
    """

    def __init__(self, cfg: GPTConfig) -> None:
        """Create the attention, feed-forward, norm and dropout modules.

        Args:
            cfg: Supplies ``embedding_dim``, ``context_length``, ``dropout``
                and ``qkv_bias``.
        """
        super().__init__()
        dim = cfg.embedding_dim
        # NOTE(review): cfg.n_heads is not forwarded to MultiHeadAttention;
        # confirm against its signature that the five positional arguments
        # below line up with the intended parameters.
        self.att = MultiHeadAttention(
            dim,
            dim,
            cfg.context_length,
            cfg.dropout,
            cfg.qkv_bias,
        )
        self.ff = FeedForward(cfg)
        self.norm1 = NormLayer(dim)
        self.norm2 = NormLayer(dim)
        # One Dropout module is shared by both residual branches.
        self.dropout = torch.nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual sub-layers and return the result."""
        # Attention branch, added back onto its own (un-normalised) input.
        attended = self.dropout(self.att(self.norm1(x)))
        x = attended + x

        # Feed-forward branch, same residual pattern.
        transformed = self.dropout(self.ff(self.norm2(x)))
        return transformed + x
|
|
|
|
|
|
class NormLayer(torch.nn.Module):
    """Layer normalisation over the last dimension with learnable affine.

    Normalises each vector along the last axis to zero mean and unit
    (biased) variance, then applies an element-wise ``scale`` and ``shift``.
    Numerically equivalent to ``torch.nn.LayerNorm(dim, eps=eps)``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the learnable parameters.

        Args:
            dim: Size of the last dimension of the tensors to normalise.
            eps: Small constant added to the variance before the square root
                to avoid division by zero.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Affine parameters start as the identity transform.
        self.scale = torch.nn.Parameter(torch.ones(dim))
        self.shift = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last dimension.

        Args:
            x: Tensor whose last dimension is broadcast-compatible with
                ``scale`` and ``shift``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        mean = x.mean(-1, keepdim=True)
        # Layer norm uses the biased (population) variance. The Bessel-corrected
        # estimator would leave the output variance at (n - 1) / n instead of 1
        # and yields NaN when the last dimension has a single element.
        var = x.var(-1, keepdim=True, unbiased=False)
        # Makes mean = 0 and variance = 1
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
|
|
|
|
|
|
class DummyGPT(torch.nn.Module):
    """GPT-style model: embeddings, stacked blocks, final norm, output head."""

    def __init__(self, config: GPTConfig):
        """Build all sub-modules from ``config``.

        Args:
            config: Supplies ``vocab_size``, ``context_length``,
                ``embedding_dim``, ``dropout`` and ``n_layers`` (and is
                passed on to every ``TransformerBlock``).
        """
        super().__init__()
        dim = config.embedding_dim
        self.tok_embedding = torch.nn.Embedding(config.vocab_size, dim)
        self.pos_embedding = torch.nn.Embedding(config.context_length, dim)
        self.drop_emb = torch.nn.Dropout(config.dropout)
        blocks = [TransformerBlock(config) for _ in range(config.n_layers)]
        self.trf_blocks = torch.nn.Sequential(*blocks)
        self.final_norm = NormLayer(dim)
        # Projects hidden vectors to one logit per vocabulary entry, no bias.
        self.out_head = torch.nn.Linear(dim, config.vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token indices to logits over the vocabulary.

        Args:
            x: 2-D tensor of token indices; the second dimension is the
                sequence length (presumably ``(batch, seq_len)`` - confirm
                with callers).

        Returns:
            Tensor whose last dimension has size ``vocab_size``.
        """
        _, seq_len = x.shape
        # Position indices 0..seq_len-1, created on the same device as ``x``.
        positions = torch.arange(seq_len, device=x.device)
        hidden = self.tok_embedding(x) + self.pos_embedding(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        return self.out_head(hidden)
|