data.py

import torch
import pytorch_lightning as pl
from datasets import load_dataset
from transformers import AutoTokenizer


class DataModule(pl.LightningDataModule):
    """Lightning data module that downloads and tokenizes the GLUE CoLA dataset."""

    def __init__(
        self,
        model_name="google/bert_uncased_L-2_H-128_A-2",
        batch_size=64,
        max_length=128,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def prepare_data(self):
        # Download the CoLA subset of GLUE and keep the train/validation splits.
        cola_dataset = load_dataset("glue", "cola")
        self.train_data = cola_dataset["train"]
        self.val_data = cola_dataset["validation"]

    def tokenize_data(self, example):
        # Tokenize a batch of sentences, padding/truncating to a fixed length.
        return self.tokenizer(
            example["sentence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

    def setup(self, stage=None):
        # We set up only the relevant datasets when a stage is specified.
        if stage == "fit" or stage is None:
            self.train_data = self.train_data.map(self.tokenize_data, batched=True)
            self.train_data.set_format(
                type="torch", columns=["input_ids", "attention_mask", "label"]
            )

            self.val_data = self.val_data.map(self.tokenize_data, batched=True)
            self.val_data.set_format(
                type="torch",
                columns=["input_ids", "attention_mask", "label"],
                output_all_columns=True,
            )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )


if __name__ == "__main__":
    data_model = DataModule()
    data_model.prepare_data()
    data_model.setup()
    print(next(iter(data_model.train_dataloader()))["input_ids"].shape)
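
    # A minimal sketch of how this DataModule would typically be consumed in training,
    # assuming a LightningModule (here called `ColaModel`, a hypothetical name not
    # defined in this file) exists elsewhere in the project:
    #
    #   from model import ColaModel
    #
    #   model = ColaModel()
    #   trainer = pl.Trainer(max_epochs=1)
    #   trainer.fit(model, datamodule=data_model)  # Lightning calls prepare_data/setup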