forked from timoschick/pet
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathsuperglue_data_splitting.py
41 lines (36 loc) · 1.52 KB
/
superglue_data_splitting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import random
import shutil
# Per-task number of lines to sample out of data/<task>/train.jsonl and write
# to split_data/<task>/val.jsonl as a held-out dev set; every other line of the
# original training file goes to split_data/<task>/train.jsonl.
TASK_DEV_SIZES = {
    "boolq": 500,
    "cb": 50,
    "copa": 50,
    "multirc": 50,
    "record": 7500,
    "rte": 250,
    "wic": 100,
    "wsc": 50,
}
def file_len(fname):
    """Return the number of lines in the text file *fname*.

    Every line produced by iterating the file is counted, including blank
    lines and a final line that lacks a trailing newline, so the result
    matches the number of iterations of ``enumerate`` over the same file.

    :param fname: path of the file to count
    :return: the line count (0 for an empty file)
    """
    # Decode as UTF-8, the same encoding the caller uses when it re-reads the
    # file; relying on the locale default can raise UnicodeDecodeError on
    # UTF-8 data (e.g. under a cp1252 locale on Windows).
    with open(fname, encoding="utf8") as file:
        return sum(1 for _ in file)
if __name__ == "__main__":
    # For every task in TASK_DEV_SIZES:
    #   data/<task>/train.jsonl -> split_data/<task>/train.jsonl (lines not sampled)
    #                           -> split_data/<task>/val.jsonl   (`size` sampled lines)
    #   data/<task>/val.jsonl   -> split_data/<task>/test.jsonl  (verbatim copy)
    # NOTE(review): no random seed is set, so the sampled dev split differs
    # between runs; call random.seed(...) before this loop if a reproducible
    # split is required.
    for task_name, size in TASK_DEV_SIZES.items():
        out_dir = os.path.join("split_data", task_name)
        os.makedirs(out_dir, exist_ok=True)

        train_file_path = os.path.join("data", task_name, "train.jsonl")
        test_file_path = os.path.join("data", task_name, "val.jsonl")
        new_train_file_path = os.path.join(out_dir, "train.jsonl")
        dev_file_path = os.path.join(out_dir, "val.jsonl")
        new_test_file_path = os.path.join(out_dir, "test.jsonl")

        total_lines = file_len(train_file_path)
        print(f"{task_name}: {size} out of {total_lines}")

        # Sample distinct line indices without replacement (random.sample
        # raises ValueError if size > total_lines). Store them in a set so the
        # per-line membership test below is O(1) instead of scanning a list of
        # up to 7500 indices for every training line.
        dev_indices = set(random.sample(range(total_lines), size))

        with open(train_file_path, encoding="utf8") as f, \
                open(new_train_file_path, "w", encoding="utf8") as g, \
                open(dev_file_path, "w", encoding="utf8") as h:
            for i, line in enumerate(f):
                if i in dev_indices:
                    h.write(line)
                else:
                    g.write(line)

        # The original validation set becomes the new test set.
        shutil.copy(test_file_path, new_test_file_path)