# llmfs/llmfs/tokenizers/bpe.py
from collections import Counter
import re
from typing import Self


class BPETokenizer:
    SPLIT_PAT: re.Pattern[str] = re.compile(r"([,\.!?\"'\(\)\[\]\{\}_;:]|--|\s)")
    UNKNOWN_TOKEN: str = "<|unknowntoken|>"
    END_OF_TEXT: str = "<|endoftext|>"
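
    # Illustrative split with the pattern above: punctuation becomes its own
    # piece and whitespace pieces are dropped downstream, e.g.
    #   SPLIT_PAT.split("Hello, world!")
    #   -> ["Hello", ",", "", " ", "world", "!", ""]
    # which lower/strip/filter reduces to ["hello", ",", "world", "!"].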

    def __init__(self, vocabulary: dict[str, int]) -> None:
        self.forward: dict[str, int] = vocabulary
        self.reverse: dict[int, str] = {idx: token for token, idx in vocabulary.items()}
        self.unk_token: int = self.forward[self.UNKNOWN_TOKEN]

    @property
    def max_token_value(self) -> int:
        # Despite the name, this is the vocabulary size: valid token ids
        # run from 0 to len(self.forward) - 1.
        return len(self.forward)

    @classmethod
    def build(cls, text: str, target_vocab_size: int = -1) -> Self:
        """Train a tokenizer on `text` by repeatedly merging the most frequent
        adjacent pair, stopping once `target_vocab_size` is reached (if
        positive) or when no pair is left to merge."""
        preprocessed = list(
            filter(bool, map(lambda x: x.lower().strip(), cls.SPLIT_PAT.split(text)))
        )
        # The initial vocabulary is the set of individual characters.
        pre_vocab: set[str] = set()
        for word in preprocessed:
            pre_vocab |= set(word)
        vocab: list[str] = sorted(pre_vocab)
        all_words = [list(word) for word in preprocessed]
        while True:
            if target_vocab_size > 0 and len(vocab) >= target_vocab_size:
                break
            # Count adjacent pairs, keyed by their concatenation (so distinct
            # splits that join to the same string share one count).
            pairs: Counter[str] = Counter()
            for word in all_words:
                for i in range(len(word) - 1):
                    pairs["".join((word[i], word[i + 1]))] += 1
            mc = pairs.most_common(1)
            if not mc:
                break
            [(mc_pair, _)] = mc
            vocab.append(mc_pair)
            # Re-segment every word, merging each occurrence of the winning pair.
            new_words: list[list[str]] = []
            for word in all_words:
                new_word: list[str] = []
                i = 0
                while i < len(word):
                    if i + 1 == len(word):
                        new_word.append(word[i])
                        break
                    pair = "".join((word[i], word[i + 1]))
                    if pair == mc_pair:
                        # `pair` is already the joined string.
                        new_word.append(pair)
                        i += 2
                        continue
                    new_word.append(word[i])
                    i += 1
                if not new_word:
                    continue
                new_words.append(new_word)
            all_words = new_words
        vocab.extend([" ", cls.UNKNOWN_TOKEN, cls.END_OF_TEXT])
        vocab_dict = {token: i for i, token in enumerate(vocab)}
        return cls(vocab_dict)

    def _encode_word(self, word: str) -> list[int]:
        # Greedy longest-match: at each position, take the longest substring
        # that is a known token; fall back to the unknown token otherwise.
        parts = list(word.strip())
        start_part_idx = 0
        encoded: list[int] = []
        while start_part_idx < len(parts):
            found = False
            for i in range(len(parts), start_part_idx, -1):
                token = self.forward.get("".join(parts[start_part_idx:i]))
                if token is not None:
                    found = True
                    encoded.append(token)
                    start_part_idx = i
                    break
            if not found:
                # The loop above already tried the single character, so it
                # must be unknown; emit the unknown token and move on.
                encoded.append(self.unk_token)
                start_part_idx += 1
        return encoded
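
    # For instance, assuming a vocabulary containing "a", "b", and "ab" (a
    # made-up example, not tied to any particular corpus), _encode_word("aab")
    # matches "a" first, then "ab", returning the ids of ["a", "ab"].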

    def encode(self, text: str | list[str]) -> list[int]:
        if isinstance(text, list):
            text = f" {self.END_OF_TEXT} ".join(text)
        preprocessed: list[str] = list(
            filter(bool, map(lambda x: x.lower().strip(), self.SPLIT_PAT.split(text)))
        )
        tokens: list[int] = []
        for word in preprocessed:
            if tokens:
                # Separate pieces with an explicit space token; decode() later
                # removes the space in front of punctuation.
                tokens.append(self.forward[" "])
            tokens.extend(self._encode_word(word))
        return tokens

    def decode(self, tokens: list[int]) -> str:
        text = "".join(
            self.reverse.get(token, self.UNKNOWN_TOKEN) for token in tokens
        )
        # Undo the space that encode() inserts before punctuation.
        return re.sub(r"\s+([\,\.:;\?\!\"\'\(\)\[\]\{\}])", r"\1", text)
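

if __name__ == "__main__":
    # Minimal round-trip sketch (illustrative, not part of the module above):
    # train on a tiny corpus, then encode and decode a phrase from it.
    sample = "the quick brown fox jumps over the lazy dog."
    tokenizer = BPETokenizer.build(sample, target_vocab_size=64)
    ids = tokenizer.encode("the lazy dog.")
    print(ids)
    print(tokenizer.decode(ids))  # prints "the lazy dog."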