Compare commits
1 commit: master...e51df51b46
@@ -3,11 +3,23 @@ from pathlib import Path
 import torch
 from llmfs.attn import MultiHeadAttention
 from llmfs.datasets.v1 import GPTDataSetV1
+from llmfs.gpt import GPTConfig
 from llmfs.tokenizers import BPETokenizer
 
 DATA_DIR = Path(__file__).parent.parent / "data"
 
 
+GPT_CONFIG_124M = GPTConfig(
+    vocab_size=50257,
+    context_length=1024,
+    embedding_dim=768,
+    n_heads=12,
+    n_layers=12,
+    dropout=0.1,
+    qkv_bias=False,
+)
+
+
 def process_text(text: str):
     tokenizer = BPETokenizer.build(text)
     vocab_size = tokenizer.max_token_value + 1
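
A rough, hypothetical sketch (not part of this commit) of how a config object like GPT_CONFIG_124M is typically consumed downstream: the embedding width has to split evenly across attention heads, so a helper such as the made-up head_dim below would return 768 // 12 = 64 for this configuration.

# Hypothetical helper, not in the repository: derive per-head width from a GPTConfig.
from llmfs.gpt import GPTConfig

def head_dim(config: GPTConfig) -> int:
    # 768 // 12 == 64 for the 124M configuration
    assert config.embedding_dim % config.n_heads == 0
    return config.embedding_dim // config.n_heads

cfg = GPTConfig(
    vocab_size=50257,
    context_length=1024,
    embedding_dim=768,
    n_heads=12,
    n_layers=12,
    dropout=0.1,
    qkv_bias=False,
)
print(head_dim(cfg))  # 64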
@@ -48,7 +60,7 @@ def attn_test():
     batch = torch.stack((inps, inps), dim=0)
     attn = MultiHeadAttention(
         inps.shape[1],
-        4,
+        8,
         inps.shape[0],
         dropout=True,
         num_heads=2,
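
A hedged standalone sketch of exercising the module the way attn_test() does; the 6-token, 3-feature input is invented here (the diff never shows inps), and the constructor arguments simply mirror the call above.

# Sketch only, not part of the commit.
import torch
from llmfs.attn import MultiHeadAttention

inps = torch.randn(6, 3)                  # 6 tokens, 3 input features (toy values)
batch = torch.stack((inps, inps), dim=0)  # (2, 6, 3)

attn = MultiHeadAttention(
    inps.shape[1],   # d_in = 3
    8,               # d_out = 8, so each of the 2 heads works with 8 // 2 = 4 dims
    inps.shape[0],   # context length = 6
    dropout=True,
    num_heads=2,
)
out = attn(batch)
print(out.shape)     # expected: torch.Size([2, 6, 8])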
llmfs/attn.py

@@ -1,4 +1,3 @@
-from typing import final
 import torch
 
 
@@ -15,7 +14,7 @@ class SelfAttention(torch.nn.Module):
         self.w_query = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
-        self.dropout = torch.nn.Dropout(inplace=True)
+        self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.register_buffer(
             "mask",
             torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
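
The Dropout fix matters because torch.nn.Dropout's first positional argument is the drop probability; Dropout(inplace=True) silently keeps the default p=0.5 instead of the value handed to the layer. A quick standalone check (not part of the commit):

import torch

d_default = torch.nn.Dropout(inplace=True)         # p defaults to 0.5
d_explicit = torch.nn.Dropout(0.1, inplace=True)   # p taken from the constructor argument
print(d_default.p, d_explicit.p)                    # 0.5 0.1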
@@ -54,14 +53,14 @@ class MultiHeadAttention(torch.nn.Module):
         self.w_key = torch.nn.Linear(d_in, d_out, bias=bias)
         self.w_val = torch.nn.Linear(d_in, d_out, bias=bias)
         self.out_proj = torch.nn.Linear(d_out, d_out)
-        self.dropout = torch.nn.Dropout(inplace=True)
+        self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1),
+            torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool(),
         )
 
     def forward(self, data: torch.Tensor) -> torch.Tensor:
-        batches, num_tokens, _features = data.shape
+        batches, num_tokens, _ = data.shape
         queries = self.w_query(data)
         keys = self.w_key(data)
         values = self.w_val(data)
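
Registering the buffer with .bool() gives masked_fill_ the boolean mask it expects, mirroring what SelfAttention already did. A small standalone illustration of the mask it produces (ctx_len=4 is chosen here purely for display):

import torch

ctx_len = 4
mask = torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1).bool()
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

scores = torch.zeros(ctx_len, ctx_len)
scores.masked_fill_(mask, -torch.inf)   # future positions get -inf before softmax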
@@ -71,16 +70,22 @@ class MultiHeadAttention(torch.nn.Module):
         values_v = values.view(
             batches, num_tokens, self.num_heads, self.head_dim
         ).transpose(1, 2)
+        # 2, 6, 2, 4 -> 2, 2, 6, 4
+        # (batches, tokens, heads, head_dim) -> (batches, heads, tokens, head_dim)
         queries_v = queries.view(
             batches, num_tokens, self.num_heads, self.head_dim
         ).transpose(1, 2)
+        # (2, 2, 6, 4) @ (2, 2, 4, 6) -> (2, 2, 6, 6)
+        # (batches, heads, tokens, head_dim) @ (batches, heads, head_dim, tokens) -> (batches, heads, tokens, tokens)
         attn_scores = queries_v @ keys_v.transpose(2, 3)
         attn_scores.masked_fill_(
-            self.mask.bool()[:num_tokens, :num_tokens],
+            self.mask[:num_tokens, :num_tokens],
             -torch.inf,
         )
         attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
         self.dropout(attn_weights)
+        # (2, 2, 6, 6) @ (2, 2, 6, 4) -> (2, 2, 6, 4)
+        # (2, 2, 6, 4) -> transpose(1, 2) -> (2, 6, 2, 4)
         context_vec = (attn_weights @ values_v).transpose(1, 2)
         context_vec = context_vec.contiguous().view(batches, num_tokens, self.d_out)
         context_vec = self.out_proj(context_vec)
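
The shape comments can be checked in isolation; the following is a hedged standalone walk-through of the same split-heads/merge-heads dance with the toy sizes used in those comments (batch 2, 6 tokens, 2 heads, head_dim 4), independent of the MultiHeadAttention class.

import torch

batches, num_tokens, num_heads, head_dim = 2, 6, 2, 4
d_out = num_heads * head_dim  # 8

x = torch.randn(batches, num_tokens, d_out)

# split heads: (2, 6, 8) -> (2, 6, 2, 4) -> (2, 2, 6, 4)
heads = x.view(batches, num_tokens, num_heads, head_dim).transpose(1, 2)

# attention scores per head: (2, 2, 6, 4) @ (2, 2, 4, 6) -> (2, 2, 6, 6)
scores = heads @ heads.transpose(2, 3)
weights = (scores / head_dim ** 0.5).softmax(dim=-1)

# weighted values: (2, 2, 6, 6) @ (2, 2, 6, 4) -> (2, 2, 6, 4)
ctx = weights @ heads

# merge heads back: (2, 2, 6, 4) -> (2, 6, 2, 4) -> (2, 6, 8)
merged = ctx.transpose(1, 2).contiguous().view(batches, num_tokens, d_out)
print(merged.shape)  # torch.Size([2, 6, 8])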
llmfs/gpt.py (new file, 27 lines)

@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class GPTConfig:
+    vocab_size: int
+    context_length: int
+    embedding_dim: int
+    n_heads: int
+    n_layers: int
+    dropout: float
+    qkv_bias: bool
+
+
+class DummyGPT:
+    def __init__(self, config: GPTConfig):
+        self.tok_embedding = torch.nn.Embedding(
+            config.vocab_size,
+            config.embedding_dim,
+        )
+        self.pos_embedding = torch.nn.Embedding(
+            config.context_length,
+            config.embedding_dim,
+        )
+        self.dropout = torch.nn.Dropout(config.dropout)
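
A hedged usage sketch for the new module: the batch size and sequence length below are made up, and since DummyGPT as committed only builds the embedding and dropout layers, this just runs them by hand.

import torch
from llmfs.gpt import GPTConfig, DummyGPT

config = GPTConfig(
    vocab_size=50257,
    context_length=1024,
    embedding_dim=768,
    n_heads=12,
    n_layers=12,
    dropout=0.1,
    qkv_bias=False,
)
model = DummyGPT(config)

token_ids = torch.randint(0, config.vocab_size, (2, 8))        # batch of 2, 8 tokens each
tok = model.tok_embedding(token_ids)                           # (2, 8, 768)
pos = model.pos_embedding(torch.arange(token_ids.shape[1]))    # (8, 768), broadcast over the batch
x = model.dropout(tok + pos)
print(x.shape)  # torch.Size([2, 8, 768])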