from collections import Counter
import re
from typing import Self


class BPETokenizer:
    """A minimal BPE-style tokenizer: merges frequent character pairs into tokens."""

    # Pre-tokenization pattern: split on single punctuation marks, "--", or whitespace.
    SPLIT_PAT: re.Pattern[str] = re.compile(r"([,\.!?\"'\(\)\[\]\{\}_;:]|--|\s)")
    # Special tokens appended to the vocabulary by build().
    UNKNOWN_TOKEN: str = "<|unknowntoken|>"
    END_OF_TEXT: str = "<|endoftext|>"

    def __init__(self, vocabulary: dict[str, int]) -> None:
        # Token -> id and id -> token lookup tables.
        self.forward: dict[str, int] = vocabulary
        self.reverse: dict[int, str] = {idx: token for token, idx in vocabulary.items()}
        self.unk_token: int = self.forward[self.UNKNOWN_TOKEN]

    @property
    def max_token_value(self) -> int:
        # Vocabulary size; valid token ids range from 0 to max_token_value - 1.
        return len(self.forward)

    @classmethod
    def build(cls, text: str, target_vocab_size: int = -1) -> Self:
        """Train a vocabulary on `text` by repeatedly merging the most frequent pair.

        If `target_vocab_size` is not positive (the default), merging continues
        until every word has collapsed into a single token.
        """
        # Pre-tokenize: lowercase, split on punctuation/whitespace, drop empty pieces.
        preprocessed = list(
            filter(bool, map(lambda x: x.lower().strip(), cls.SPLIT_PAT.split(text)))
        )
        # Seed the vocabulary with every individual character seen in the text.
        pre_vocab: set[str] = set()
        for word in preprocessed:
            pre_vocab |= set(word)
        vocab: list[str] = sorted(pre_vocab)
        # Track each word as a list of its current parts (initially single characters).
        all_words = [list(word) for word in preprocessed]
        while True:
            if target_vocab_size > 0 and len(vocab) >= target_vocab_size:
                break
            # Count every adjacent pair of parts across all words.
            pairs: Counter[str] = Counter()
            for word in all_words:
                for i in range(len(word) - 1):
                    pairs["".join((word[i], word[i + 1]))] += 1
            mc = pairs.most_common(1)
            if not mc:
                break
            [(mc_pair, _)] = mc
            vocab.append(mc_pair)
            # Merge every occurrence of the most common pair in every word.
            new_words: list[list[str]] = []
            for word in all_words:
                new_word: list[str] = []
                i = 0
                while i < len(word):
                    if i + 1 == len(word):
                        new_word.append(word[i])
                        break
                    pair = "".join((word[i], word[i + 1]))
                    if pair == mc_pair:
                        new_word.append(pair)
                        i += 2
                        continue
                    new_word.append(word[i])
                    i += 1

                if not new_word:
                    continue
                new_words.append(new_word)

            all_words = new_words

        # Space and the special tokens are added on top of target_vocab_size.
        vocab.extend([" ", cls.UNKNOWN_TOKEN, cls.END_OF_TEXT])
        vocab_dict = {token: i for i, token in enumerate(vocab)}
        return cls(vocab_dict)

    def _encode_word(self, word: str) -> list[int]:
        # Greedy longest-match encoding: at each position, emit the longest prefix
        # of the remaining characters that exists in the vocabulary.
        parts = list(word.strip())
        start_part_idx = 0
        encoded: list[int] = []
        while start_part_idx < len(parts):
            found = False
            for i in range(len(parts), start_part_idx, -1):
                token = self.forward.get("".join(parts[start_part_idx:i]))
                if token is not None:
                    found = True
                    encoded.append(token)
                    start_part_idx = i
                    break
            if not found:
                # Fall back to the single character, or the unknown token if even
                # that character is out of vocabulary.
                token = self.forward.get(parts[start_part_idx])
                if token is None:
                    token = self.unk_token
                encoded.append(token)
                start_part_idx += 1
        return encoded

    def encode(self, text: str | list[str]) -> list[int]:
        # A list of texts is joined into a single string separated by END_OF_TEXT.
        if isinstance(text, list):
            text = f" {self.END_OF_TEXT} ".join(text)

        preprocessed: list[str] = list(
            filter(bool, map(lambda x: x.lower().strip(), self.SPLIT_PAT.split(text)))
        )
        tokens: list[int] = []
        for word in preprocessed:
            # Separate consecutive words with an explicit space token.
            if tokens:
                tokens.append(self.forward[" "])
            tokens.extend(self._encode_word(word))
        return tokens

    def decode(self, tokens: list[int]) -> str:
        text = "".join(
            self.reverse.get(token, self.UNKNOWN_TOKEN) for token in tokens
        )
        # Drop the space that encode() inserts before punctuation.
        return re.sub(r"\s+([\,\.:;\?\!\"\'\(\)\[\]\{\}])", r"\1", text)
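

# Example usage (a minimal sketch, not part of the original module). The sample
# text and the target vocabulary size below are arbitrary choices for illustration:
# build() trains a tiny vocabulary, then a sentence is round-tripped through
# encode() and decode() (output is lowercased by the pre-tokenizer).
if __name__ == "__main__":
    sample_text = (
        "The quick brown fox jumps over the lazy dog. "
        "The dog, however, was not amused!"
    )
    tokenizer = BPETokenizer.build(sample_text, target_vocab_size=300)
    ids = tokenizer.encode("The quick dog was amused.")
    print(ids)
    print(tokenizer.decode(ids))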