Made some updates.

Signed-off-by: Pavel Kirilin <s3riussan@gmail.com>
This commit is contained in:
2026-06-09 19:53:27 +02:00
parent 1c9029ec78
commit 9c343a65aa
3 changed files with 56 additions and 26 deletions

View File

@@ -3,7 +3,7 @@ from pathlib import Path
import tiktoken
import torch
from llmfs.gpt import DummyGPT, GPTConfig, TransformerBlock
from llmfs.tokenizers import BPETokenizer
from llmfs.tokenizers import BPETokenizer, Tokenizer
DATA_DIR = Path(__file__).parent.parent / "data"
@@ -34,22 +34,25 @@ def generate_text_simple(
return idx
def txt_to_tokens(tokenizer: Tokenizer, text: str) -> torch.Tensor:
encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
return torch.tensor(encoded).unsqueeze(0)
def tokens_to_txt(tokenizer: Tokenizer, tokens: torch.Tensor) -> str:
return tokenizer.decode(tokens.squeeze(0).tolist())
def process_text(text: str):
print("Buiding tokenizer")
# tokenizer = BPETokenizer.build(text)
tokenizer = tiktoken.encoding_for_model("gpt2")
vocab_size = tokenizer.max_token_value + 1
print(f"Tokenizer is ready. Vocab size: {vocab_size}")
batch = torch.stack(
[
torch.tensor(tokenizer.encode("Every effort moves you")),
torch.tensor(tokenizer.encode("Every day holds a")),
],
dim=0,
)
cfg = GPTConfig(
vocab_size=vocab_size,
context_length=1024,
context_length=256,
embedding_dim=768,
n_heads=12,
n_layers=12,
@@ -58,12 +61,11 @@ def process_text(text: str):
)
gpt = DummyGPT(cfg)
gpt.eval()
start_ctx = "Hello, I am"
encoded = tokenizer.encode(start_ctx)
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
out = generate_text_simple(gpt, encoded_tensor, 6, cfg.context_length)
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(decoded_text)
text = "Every effort moves you"
encoded = txt_to_tokens(tokenizer, text)
out = generate_text_simple(gpt, encoded, 6, cfg.context_length)
decoded = tokens_to_txt(tokenizer, out)
print(decoded)
# logits = gpt(batch)
# print(logits)
# print(logits.shape)

View File

@@ -1,12 +1,14 @@
from .stoopid import StoopidTokenizer
from .bpe import BPETokenizer
from typing import Protocol
from typing import AbstractSet, Iterable, Protocol
__all__ = ["BPETokenizer", "StoopidTokenizer", "Tokenizer"]
class Tokenizer(Protocol):
def encode(self, text: str) -> list[int]: ...
def encode(
self, text: str, allowed_special: AbstractSet[str] = set()
) -> list[int]: ...
def decode(self, tokens: list[int]) -> str: ...
@property
def max_token_value(self) -> int: ...

View File

@@ -1,6 +1,7 @@
from collections import Counter
from collections.abc import Iterable
import re
from typing import Self
from typing import AbstractSet, Self
class BPETokenizer:
@@ -8,9 +9,11 @@ class BPETokenizer:
UNKNOWN_TOKEN: str = "<|unknowntoken|>"
END_OF_TEXT: str = "<|endoftext|>"
def __init__(self, vocabulary: dict[str, int]) -> None:
def __init__(self, vocabulary: dict[str, int], specials: dict[str, int]) -> None:
self.forward: dict[str, int] = vocabulary
self.reverse: dict[int, str] = {idx: token for token, idx in vocabulary.items()}
self.specials = specials
self.special_values = set(specials.values())
self.unk_token: int = self.forward[self.UNKNOWN_TOKEN]
@property
@@ -18,11 +21,16 @@ class BPETokenizer:
return len(self.forward)
@classmethod
def build(cls, text: str, target_vocab_size: int = -1) -> Self:
def build(
cls,
text: str,
target_vocab_size: int = -1,
specials: set[str] = {END_OF_TEXT, UNKNOWN_TOKEN},
) -> Self:
preprocessed = list(
filter(bool, map(lambda x: x.lower().strip(), cls.SPLIT_PAT.split(text)))
)
pre_vocab: set[str] = set()
pre_vocab: set[str] = specials
for word in preprocessed:
pre_vocab |= set(word)
vocab: list[str] = sorted(pre_vocab)
@@ -63,18 +71,27 @@ class BPETokenizer:
vocab.extend([" ", cls.UNKNOWN_TOKEN, cls.END_OF_TEXT])
vocab_dict = {token: i for i, token in enumerate(vocab)}
return cls(vocab_dict)
specials_dict = {special: vocab_dict[special] for special in specials}
return cls(vocab_dict, specials_dict)
def _encode_word(self, word: str) -> list[int]:
def _encode_word(
self,
word: str,
allowed_specials: set[int],
) -> list[int]:
encoded: list[int] = []
parts = list(word.strip())
start_part_idx = 0
encoded: list[int] = []
while start_part_idx < len(parts):
found = False
for i in range(len(parts), start_part_idx, -1):
token = self.forward.get("".join(parts[start_part_idx:i]))
if token is not None:
found = True
if token in self.special_values and token not in allowed_specials:
raise ValueError(
f"The token '{self.reverse[token]}' is not allowed."
)
encoded.append(token)
start_part_idx = i
break
@@ -86,7 +103,11 @@ class BPETokenizer:
start_part_idx += 1
return encoded
def encode(self, text: str | list[str]) -> list[int]:
def encode(
self,
text: str | list[str],
allowed_special: AbstractSet[str] = set(),
) -> list[int]:
if isinstance(text, list):
text = f" {self.END_OF_TEXT} ".join(text)
@@ -94,10 +115,15 @@ class BPETokenizer:
filter(bool, map(lambda x: x.lower().strip(), self.SPLIT_PAT.split(text)))
)
tokens: list[int] = []
allowed_specials_tokens = {
self.forward[token] for token in allowed_special or []
}
for word in preprocessed:
if tokens:
tokens.append(self.forward[" "])
tokens.extend(self._encode_word(word))
tokens.extend(
self._encode_word(word, allowed_specials=allowed_specials_tokens)
)
return tokens
def decode(self, tokens: list[int]) -> str: