from typing import override
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
|
from llmfs.tokenizers import Tokenizer
|
|
|
|
|
|
class GPTDataSetV1(Dataset[tuple[torch.Tensor, torch.Tensor]]):
    """Sliding-window dataset for next-token prediction (GPT-style training).

    The text is tokenized once, then cut into windows of ``max_len`` tokens.
    Each target window is the input window shifted one token to the right, so
    the label at position *t* is the token that follows the first ``t + 1``
    input tokens.
    """

    def __init__(
        self, text: str, tokenizer: Tokenizer, max_len: int, stride: int
    ) -> None:
        """Tokenize ``text`` and pre-build every (input, target) window pair.

        Args:
            text: Raw training text.
            tokenizer: Tokenizer used to encode ``text``; only ``encode`` is called.
            max_len: Window (context) length in tokens.
            stride: Step between consecutive window starts. ``stride < max_len``
                produces overlapping windows.
        """
        self.input_ids: list[torch.Tensor] = []
        self.target_ids: list[torch.Tensor] = []

        token_ids = tokenizer.encode(text)

        # The end bound reserves one extra token for the shifted target; if the
        # text encodes to fewer than max_len + 1 tokens the dataset is empty.
        for i in range(0, len(token_ids) - max_len, stride):
            input_chunk = token_ids[i : i + max_len]
            target_chunk = token_ids[i + 1 : i + max_len + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self) -> int:
        """Return the number of (input, target) windows."""
        return len(self.input_ids)

    @override
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (input, target) pair at ``idx``; each tensor has shape (max_len,)."""
        return self.input_ids[idx], self.target_ids[idx]

    @staticmethod
    def data_loader(
        text: str,
        encoder: Tokenizer,
        batch_size: int = 4,
        max_len: int = 256,
        stride: int = 128,
        shuffle: bool = True,
        drop_last: bool = True,
        num_worker: int = 0,
    ) -> DataLoader[tuple[torch.Tensor, torch.Tensor]]:
        """Build a ``DataLoader`` over a :class:`GPTDataSetV1` made from ``text``.

        Args:
            text: Raw training text.
            encoder: Tokenizer used to encode ``text``.
            batch_size: Samples per batch.
            max_len: Window (context) length in tokens.
            stride: Step between consecutive window starts.
            shuffle: Shuffle sample order each epoch.
            drop_last: Drop the final, smaller-than-``batch_size`` batch.
            num_worker: Number of worker processes for data loading
                (0 = load in the main process).

        Returns:
            A ``DataLoader`` yielding batched (input, target) tensor pairs.
        """
        dataset = GPTDataSetV1(text, encoder, max_len, stride)
        # Fixed return annotation: the element type must match the dataset's
        # declared tuple[torch.Tensor, torch.Tensor], not tuple[torch.Tensor, ...].
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=num_worker,
        )
|