Files
llmfs/llmfs/datasets/v1.py
2026-04-30 03:46:35 +02:00

48 lines
1.5 KiB
Python

from typing import override
import torch
from torch.utils.data import Dataset, DataLoader
from llmfs.tokenizers import Tokenizer
class GPTDataSetV1(Dataset[tuple[torch.Tensor, torch.Tensor]]):
def __init__(
self, text: str, tokenizer: Tokenizer, max_len: int, stride: int
) -> None:
self.input_ids: list[torch.Tensor] = []
self.target_ids: list[torch.Tensor] = []
token_ids = tokenizer.encode(text)
for i in range(0, len(token_ids) - max_len, stride):
input_chunk = token_ids[i : i + max_len]
target_chunk = token_ids[i + 1 : i + max_len + 1]
self.input_ids.append(torch.tensor(input_chunk))
self.target_ids.append(torch.tensor(target_chunk))
def __len__(self) -> int:
return len(self.input_ids)
@override
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
return self.input_ids[idx], self.target_ids[idx]
@staticmethod
def data_loader(
text: str,
encoder: Tokenizer,
batch_size: int = 4,
max_len: int = 256,
stride: int = 128,
shuffle: bool = True,
drop_last: bool = True,
num_worker: int = 0,
) -> DataLoader[tuple[torch.Tensor, ...]]:
dataset = GPTDataSetV1(text, encoder, max_len, stride)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
num_workers=num_worker,
)