# webqsp.py
import os
import torch
from copy import deepcopy
from datasets import DatasetDict
from torch.utils.data import Dataset
from torch.utils.data.dataset import T_co
from tqdm import tqdm


class Constructor(object):
    def __init__(self, args):
        self.args = args

    def to_seq2seq(self, raw_datasets: DatasetDict, cache_root: str):
        if len(raw_datasets) != 3:
            raise AssertionError("Train, Dev, Test sections of dataset expected.")
        train_dataset = TrainDataset(self.args, raw_datasets['train'], cache_root)
        dev_dataset = DevDataset(self.args, raw_datasets['validation'], cache_root)
        test_dataset = TestDataset(self.args, raw_datasets['test'], cache_root)
        return train_dataset, dev_dataset, test_dataset
"""
{
"id": datasets.Value("string"),
"question": datasets.Value("string"),
"answers": datasets.features.Sequence(datasets.features.Sequence(datasets.Value("string"))),
"kg_tuples": datasets.features.Sequence(datasets.features.Sequence(datasets.Value("string"))),
}
"""


def serialize_kg_tuples(kg_tuples: list) -> str:
    # [[head, rel, tail], [head, rel, tail]] -> "head rel tail | head rel tail"
    return " | ".join([" ".join(t) for t in kg_tuples])


def kgqa_get_input(question: str, kg_tuples: list, entities: list) -> tuple:
    # Returns (stripped question, serialized entities + " | " + serialized KG),
    # where each entity's fields are joined with ": ".
    serialized_kg = serialize_kg_tuples(kg_tuples).strip()
    serialized_entity = " ".join([": ".join(elm) for elm in entities]).strip()
    return question.strip(), serialized_entity + " | " + serialized_kg
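
# Illustrative behavior of the two helpers above (values are invented):
#   serialize_kg_tuples([["m.06w2sn5", "people.person.name", "Justin Bieber"]])
#     -> "m.06w2sn5 people.person.name Justin Bieber"
#   kgqa_get_input("who is justin bieber brother",
#                  [["m.06w2sn5", "people.person.name", "Justin Bieber"]],
#                  [["m.06w2sn5", "Justin Bieber"]])
#     -> ("who is justin bieber brother",
#         "m.06w2sn5: Justin Bieber | m.06w2sn5 people.person.name Justin Bieber")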


class TrainDataset(Dataset):
    def __init__(self, args, raw_datasets, cache_root):
        self.raw_datasets = raw_datasets
        cache_path = os.path.join(cache_root, 'webqsp_train.cache')
        if os.path.exists(cache_path) and args.dataset.use_cache:
            self.data = torch.load(cache_path)
        else:
            self.data = []
            for raw_data in tqdm(self.raw_datasets):
                extend_data = deepcopy(raw_data)
                question = extend_data["question"]
                kg_tuples = extend_data["kg_tuples"]
                entities = extend_data["entities"]
                question, serialized_kg = kgqa_get_input(question, kg_tuples, entities)
                seq_out = extend_data["s_expression"]
                # Keep only training examples that have a gold s-expression.
                if seq_out != "null":
                    extend_data.update({"struct_in": serialized_kg, "text_in": question, "seq_out": seq_out})
                    self.data.append(extend_data)
            if args.dataset.use_cache:
                torch.save(self.data, cache_path)

    def __getitem__(self, index) -> T_co:
        return self.data[index]

    def __len__(self):
        return len(self.data)


class DevDataset(Dataset):
    def __init__(self, args, raw_datasets, cache_root):
        self.raw_datasets = raw_datasets
        cache_path = os.path.join(cache_root, 'webqsp_dev.cache')
        if os.path.exists(cache_path) and args.dataset.use_cache:
            self.data = torch.load(cache_path)
        else:
            self.data = []
            for raw_data in tqdm(self.raw_datasets):
                extend_data = deepcopy(raw_data)
                question = extend_data["question"]
                kg_tuples = extend_data["kg_tuples"]
                entities = extend_data["entities"]
                question, serialized_kg = kgqa_get_input(question, kg_tuples, entities)
                seq_out = extend_data["s_expression"]
                # Unlike training, keep every example, even without a gold s-expression.
                extend_data.update({"struct_in": serialized_kg, "text_in": question, "seq_out": seq_out})
                self.data.append(extend_data)
            if args.dataset.use_cache:
                torch.save(self.data, cache_path)

    def __getitem__(self, index) -> T_co:
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TestDataset(Dataset):
    def __init__(self, args, raw_datasets, cache_root):
        self.raw_datasets = raw_datasets
        cache_path = os.path.join(cache_root, 'webqsp_test.cache')
        if os.path.exists(cache_path) and args.dataset.use_cache:
            self.data = torch.load(cache_path)
        else:
            self.data = []
            for raw_data in tqdm(self.raw_datasets):
                extend_data = deepcopy(raw_data)
                question = extend_data["question"]
                kg_tuples = extend_data["kg_tuples"]
                entities = extend_data["entities"]
                question, serialized_kg = kgqa_get_input(question, kg_tuples, entities)
                seq_out = extend_data["s_expression"]
                # As with dev, keep every example regardless of the s-expression.
                extend_data.update({"struct_in": serialized_kg, "text_in": question, "seq_out": seq_out})
                self.data.append(extend_data)
            if args.dataset.use_cache:
                torch.save(self.data, cache_path)

    def __getitem__(self, index) -> T_co:
        return self.data[index]

    def __len__(self):
        return len(self.data)
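

# A minimal usage sketch (an assumption, not part of the original file): it
# presumes a DatasetDict with "train"/"validation"/"test" splits that follow
# the schema documented above, and an args object exposing args.dataset.use_cache.
if __name__ == "__main__":
    from types import SimpleNamespace

    from datasets import load_dataset

    args = SimpleNamespace(dataset=SimpleNamespace(use_cache=False))
    raw = load_dataset("path/to/webqsp_loader.py")  # hypothetical loader script
    train, dev, test = Constructor(args).to_seq2seq(raw, cache_root="./cache")
    print(len(train), train[0]["text_in"][:50], train[0]["seq_out"][:50])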