diff --git a/llmfs/__main__.py b/llmfs/__main__.py index 25778df..1895b82 100644 --- a/llmfs/__main__.py +++ b/llmfs/__main__.py @@ -3,7 +3,7 @@ from pathlib import Path import tiktoken import torch from llmfs.gpt import DummyGPT, GPTConfig, TransformerBlock -from llmfs.tokenizers import BPETokenizer +from llmfs.tokenizers import BPETokenizer, Tokenizer DATA_DIR = Path(__file__).parent.parent / "data" @@ -34,22 +34,25 @@ def generate_text_simple( return idx +def txt_to_tokens(tokenizer: Tokenizer, text: str) -> torch.Tensor: + encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"}) + return torch.tensor(encoded).unsqueeze(0) + + +def tokens_to_txt(tokenizer: Tokenizer, tokens: torch.Tensor) -> str: + return tokenizer.decode(tokens.squeeze(0).tolist()) + + def process_text(text: str): print("Buiding tokenizer") # tokenizer = BPETokenizer.build(text) tokenizer = tiktoken.encoding_for_model("gpt2") vocab_size = tokenizer.max_token_value + 1 print(f"Tokenizer is ready. Vocab size: {vocab_size}") - batch = torch.stack( - [ - torch.tensor(tokenizer.encode("Every effort moves you")), - torch.tensor(tokenizer.encode("Every day holds a")), - ], - dim=0, - ) + cfg = GPTConfig( vocab_size=vocab_size, - context_length=1024, + context_length=256, embedding_dim=768, n_heads=12, n_layers=12, @@ -58,12 +61,11 @@ def process_text(text: str): ) gpt = DummyGPT(cfg) gpt.eval() - start_ctx = "Hello, I am" - encoded = tokenizer.encode(start_ctx) - encoded_tensor = torch.tensor(encoded).unsqueeze(0) - out = generate_text_simple(gpt, encoded_tensor, 6, cfg.context_length) - decoded_text = tokenizer.decode(out.squeeze(0).tolist()) - print(decoded_text) + text = "Every effort moves you" + encoded = txt_to_tokens(tokenizer, text) + out = generate_text_simple(gpt, encoded, 6, cfg.context_length) + decoded = tokens_to_txt(tokenizer, out) + print(decoded) # logits = gpt(batch) # print(logits) # print(logits.shape) diff --git a/llmfs/tokenizers/__init__.py b/llmfs/tokenizers/__init__.py index e51a69e..0a56135 100644 --- a/llmfs/tokenizers/__init__.py +++ b/llmfs/tokenizers/__init__.py @@ -1,12 +1,14 @@ from .stoopid import StoopidTokenizer from .bpe import BPETokenizer -from typing import Protocol +from typing import AbstractSet, Iterable, Protocol __all__ = ["BPETokenizer", "StoopidTokenizer", "Tokenizer"] class Tokenizer(Protocol): - def encode(self, text: str) -> list[int]: ... + def encode( + self, text: str, allowed_special: AbstractSet[str] = set() + ) -> list[int]: ... def decode(self, tokens: list[int]) -> str: ... @property def max_token_value(self) -> int: ... diff --git a/llmfs/tokenizers/bpe.py b/llmfs/tokenizers/bpe.py index 2ee7f9d..14b4b36 100644 --- a/llmfs/tokenizers/bpe.py +++ b/llmfs/tokenizers/bpe.py @@ -1,6 +1,7 @@ from collections import Counter +from collections.abc import Iterable import re -from typing import Self +from typing import AbstractSet, Self class BPETokenizer: @@ -8,9 +9,11 @@ class BPETokenizer: UNKNOWN_TOKEN: str = "<|unknowntoken|>" END_OF_TEXT: str = "<|endoftext|>" - def __init__(self, vocabulary: dict[str, int]) -> None: + def __init__(self, vocabulary: dict[str, int], specials: dict[str, int]) -> None: self.forward: dict[str, int] = vocabulary self.reverse: dict[int, str] = {idx: token for token, idx in vocabulary.items()} + self.specials = specials + self.special_values = set(specials.values()) self.unk_token: int = self.forward[self.UNKNOWN_TOKEN] @property @@ -18,11 +21,16 @@ class BPETokenizer: return len(self.forward) @classmethod - def build(cls, text: str, target_vocab_size: int = -1) -> Self: + def build( + cls, + text: str, + target_vocab_size: int = -1, + specials: set[str] = {END_OF_TEXT, UNKNOWN_TOKEN}, + ) -> Self: preprocessed = list( filter(bool, map(lambda x: x.lower().strip(), cls.SPLIT_PAT.split(text))) ) - pre_vocab: set[str] = set() + pre_vocab: set[str] = specials for word in preprocessed: pre_vocab |= set(word) vocab: list[str] = sorted(pre_vocab) @@ -63,18 +71,27 @@ class BPETokenizer: vocab.extend([" ", cls.UNKNOWN_TOKEN, cls.END_OF_TEXT]) vocab_dict = {token: i for i, token in enumerate(vocab)} - return cls(vocab_dict) + specials_dict = {special: vocab_dict[special] for special in specials} + return cls(vocab_dict, specials_dict) - def _encode_word(self, word: str) -> list[int]: + def _encode_word( + self, + word: str, + allowed_specials: set[int], + ) -> list[int]: + encoded: list[int] = [] parts = list(word.strip()) start_part_idx = 0 - encoded: list[int] = [] while start_part_idx < len(parts): found = False for i in range(len(parts), start_part_idx, -1): token = self.forward.get("".join(parts[start_part_idx:i])) if token is not None: found = True + if token in self.special_values and token not in allowed_specials: + raise ValueError( + f"The token '{self.reverse[token]}' is not allowed." + ) encoded.append(token) start_part_idx = i break @@ -86,7 +103,11 @@ class BPETokenizer: start_part_idx += 1 return encoded - def encode(self, text: str | list[str]) -> list[int]: + def encode( + self, + text: str | list[str], + allowed_special: AbstractSet[str] = set(), + ) -> list[int]: if isinstance(text, list): text = f" {self.END_OF_TEXT} ".join(text) @@ -94,10 +115,15 @@ class BPETokenizer: filter(bool, map(lambda x: x.lower().strip(), self.SPLIT_PAT.split(text))) ) tokens: list[int] = [] + allowed_specials_tokens = { + self.forward[token] for token in allowed_special or [] + } for word in preprocessed: if tokens: tokens.append(self.forward[" "]) - tokens.extend(self._encode_word(word)) + tokens.extend( + self._encode_word(word, allowed_specials=allowed_specials_tokens) + ) return tokens def decode(self, tokens: list[int]) -> str: