build_doc_tfidf.py
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A script to build the tf-idf document matrices for retrieval."""
import numpy as np
import scipy.sparse as sp
import argparse
import os
import math
import logging
import json
import copy
import pandas as pd
from multiprocessing import Pool as ProcessPool
from multiprocessing.util import Finalize
from functools import partial
from collections import Counter
import tfidf_util
from simple_tokenizer import SimpleTokenizer
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
console = logging.StreamHandler()
console.setFormatter(fmt)
logger.addHandler(console)
# ------------------------------------------------------------------------------
# Multiprocessing functions
# ------------------------------------------------------------------------------
DOC2IDX = None
PROCESS_TOK = None
PROCESS_DB = None
def init(tokenizer_class, db):
    global PROCESS_TOK, PROCESS_DB
    PROCESS_TOK = tokenizer_class()
    Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100)
    PROCESS_DB = db


def fetch_text(doc_id):
    global PROCESS_DB
    return PROCESS_DB[doc_id]


def tokenize(text):
    global PROCESS_TOK
    return PROCESS_TOK.tokenize(text)

# ------------------------------------------------------------------------------
# Build article --> word count sparse matrix.
# ------------------------------------------------------------------------------
def count(ngram, hash_size, doc_id):
    """Fetch the text of a document and compute hashed ngram counts."""
    global DOC2IDX
    row, col, data = [], [], []

    # Tokenize
    tokens = tokenize(tfidf_util.normalize(fetch_text(doc_id)))

    # Get ngrams from tokens, with stopword/punctuation filtering.
    ngrams = tokens.ngrams(
        n=ngram, uncased=True, filter_fn=tfidf_util.filter_ngram
    )

    # Hash ngrams and count occurrences
    counts = Counter([tfidf_util.hash(gram, hash_size) for gram in ngrams])

    # Return in sparse matrix data format.
    row.extend(counts.keys())
    col.extend([DOC2IDX[doc_id]] * len(counts))
    data.extend(counts.values())
    return row, col, data

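
# Illustrative sketch (not produced by the code above) of what `count` returns
# for one document: with DOC2IDX[doc_id] == 7 and hashed ngram counts of, say,
# {103: 2, 9: 1} -- made-up bucket ids -- the returned triples would be
#   row  = [103, 9]   # hashed ngram bucket ids (rows of the count matrix)
#   col  = [7, 7]     # this document's column index, repeated per ngram
#   data = [2, 1]     # raw counts, later assembled into a CSR matrix
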
def get_count_matrix(args, file_path):
    """Form a sparse word to document count matrix (inverted index).
    M[i, j] = # times word i appears in document j.
    """
    # Map doc_ids to indexes
    global DOC2IDX
    doc_ids = {}
    doc_metas = {}
    year_score = {}
    if_score = {}
    uid2docid = {}
    pmid2docid = {}
    nan_cnt = 0
    for filename in sorted(os.listdir(file_path)):
        print(filename)
        with open(os.path.join(file_path, filename), 'r') as f:
            articles = json.load(f)['data']
            for article in articles:
                title = article['title']
                kk = 0
                while title in doc_ids:
                    title += f'_{kk}'
                    kk += 1
                uid2docid[article['cord_uid']] = len(doc_ids)
                doc_ids[title] = ' '.join([par['context'] for par in article['paragraphs']])
                # year_score[title] = 0 if article['publish_time']['year'] == '2020' else -1e15 # Penalty if not 2020
                # if_score[title] = float(article['IF']) if article['IF'] != 'NaN' else 0

                # Keep metadata
                doc_meta = {}
                for key, val in article.items():
                    if key != 'paragraphs':
                        doc_meta[key] = val if val == val else 'NaN'
                    else:
                        doc_meta[key] = []
                        for para in val:
                            para_meta = {}
                            for para_key, para_val in para.items():
                                if para_key != 'context':
                                    para_meta[para_key] = para_val if para_val == para_val else 'NaN'
                            doc_meta[key].append(para_meta)

                if not pd.isnull(article.get('pubmed_id', np.nan)):
                    doc_metas[str(article['pubmed_id'])] = doc_meta  # For BEST (might be duplicate)
                    pmid2docid[str(article['pubmed_id'])] = len(doc_ids) - 1
                else:
                    nan_cnt += 1
                doc_metas[article['title']] = doc_meta

    DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)}
    IDX2DOC = {i: doc_id for i, doc_id in enumerate(doc_ids)}
    print('doc ids:', len(DOC2IDX))
    print('doc metas:', len(doc_metas), 'with nan', str(nan_cnt))
    # assert len(doc_ids)*2 == len(doc_metas) + nan_cnt

    # Make score matrix
    year_matrix = np.array([score for score in year_score.values()])
    if_matrix = np.array([score for score in if_score.values()])
    # assert len(year_matrix) == len(DOC2IDX) == len(if_matrix)

    # Covidex scores
    covidex_trec_question = json.load(open('data/ir_scores/covidex_trec_question.json'))
    covidex_trec_query = json.load(open('data/ir_scores/covidex_trec_query.json'))
    covidex_evalquery = json.load(open('data/ir_scores/covidex_evalquery.json'))
    covidex_all = covidex_trec_question + covidex_trec_query + covidex_evalquery
    print(f'Saving {len(covidex_all)} query results from Covidex')
    covidex_dict = [{}, {}]
    for result in covidex_all:
        covidex_dict[0][result['question']] = [
            uid2docid[cord_uid] for cord_uid in result['cord_uid'] if cord_uid in uid2docid
        ]
        covidex_dict[1][result['question']] = [
            score for score, cord_uid in zip(result['score'], result['cord_uid']) if cord_uid in uid2docid
        ]
        tmp = [title for cord_uid, title in zip(result['cord_uid'], result['title']) if cord_uid in uid2docid]
        # assert all([title.lower() in IDX2DOC[docid] for title, docid in zip(tmp, covidex_dict[result['question']])])

    # Covidex BEST
    best_trec_query = json.load(open('data/ir_scores/best_trec_query.json'))
    best_evalquery = json.load(open('data/ir_scores/best_evalquery.json'))
    best_all = best_trec_query + best_evalquery
    print(f'Saving {len(best_all)} query results from BEST')
    best_dict = [{}, {}]
    for result in best_all:
        best_dict[0][result['question']] = [
            pmid2docid[pmid] for pmids in result['pmid'] for pmid in pmids if pmid in pmid2docid
        ]
        best_dict[1][result['question']] = [  # Dummy score
            pmid2docid[pmid] for pmids in result['pmid'] for pmid in pmids if pmid in pmid2docid
        ]

    meta_score = {
        'covidex': covidex_dict,
        'best': best_dict,
    }

    # Setup worker pool
    tok_class = SimpleTokenizer
    workers = ProcessPool(
        args.num_workers,
        initializer=init,
        initargs=(tok_class, doc_ids)
    )
    doc_ids = list(doc_ids.keys())

    # Compute the count matrix in steps (to keep in memory)
    logger.info('Mapping...')
    row, col, data = [], [], []
    step = max(int(len(doc_ids) / 10), 1)
    batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)]
    _count = partial(count, args.ngram, args.hash_size)
    for i, batch in enumerate(batches):
        logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25)
        for b_row, b_col, b_data in workers.imap_unordered(_count, batch):
            row.extend(b_row)
            col.extend(b_col)
            data.extend(b_data)
    workers.close()
    workers.join()

    logger.info('Creating sparse matrix...')
    count_matrix = sp.csr_matrix(
        (data, (row, col)), shape=(args.hash_size, len(doc_ids))
    )
    count_matrix.sum_duplicates()
    return count_matrix, (DOC2IDX, doc_ids, doc_metas, meta_score)

# ------------------------------------------------------------------------------
# Transform count matrix to different forms.
# ------------------------------------------------------------------------------
def get_tfidf_matrix(cnts):
    """Convert the word count matrix into a tfidf matrix.

    tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
    * tf = term frequency in document
    * N = number of documents
    * Nt = number of documents the term appears in (document frequency)
    """
    Ns = get_doc_freqs(cnts)
    idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5))
    idfs[idfs < 0] = 0
    idfs = sp.diags(idfs, 0)
    tfs = cnts.log1p()
    tfidfs = idfs.dot(tfs)
    return tfidfs

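
# Worked instance of the idf factor above (numbers are illustrative only):
# with N = 1000 documents and a term occurring in Nt = 50 of them,
#   idf = log((1000 - 50 + 0.5) / (50 + 0.5)) = log(950.5 / 50.5) ~= 2.94
# Terms appearing in more than half of the documents would get a negative idf,
# which the line `idfs[idfs < 0] = 0` clips to zero.
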
def get_doc_freqs(cnts):
    """Return word --> # of docs it appears in."""
    binary = (cnts > 0).astype(int)
    freqs = np.array(binary.sum(1)).squeeze()
    return freqs

# ------------------------------------------------------------------------------
# Main.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('file_path', type=str, default=None,
                        help='Path to document texts')
    parser.add_argument('out_dir', type=str, default=None,
                        help='Directory for saving output files')
    parser.add_argument('--ngram', type=int, default=2,
                        help=('Use up to N-size n-grams '
                              '(e.g. 2 = unigrams + bigrams)'))
    parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)),
                        help='Number of buckets to use for hashing ngrams')
    parser.add_argument('--num-workers', type=int, default=None,
                        help='Number of CPU processes (for tokenizing, etc)')
    args = parser.parse_args()

    logging.info('Counting words...')
    count_matrix, doc_dict = get_count_matrix(
        args, args.file_path
    )

    logger.info('Making tfidf vectors...')
    tfidf = get_tfidf_matrix(count_matrix)

    logger.info('Getting word-doc frequencies...')
    freqs = get_doc_freqs(count_matrix)

    basename = os.path.splitext(os.path.basename(args.file_path))[0]
    basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=simple' %
                 (args.ngram, args.hash_size))
    filename = os.path.join(args.out_dir, basename)

    logger.info('Saving to %s.npz' % filename)
    metadata = {
        'doc_freqs': freqs,
        'tokenizer': 'simple',
        'hash_size': args.hash_size,
        'ngram': args.ngram,
        'doc_dict': doc_dict
    }
    tfidf_util.save_sparse_csr(filename, tfidf, metadata)
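
# Example invocation (a sketch -- the paths below are placeholders; note that
# the script also expects the Covidex/BEST score files under data/ir_scores/
# relative to the working directory, as hard-coded in get_count_matrix):
#   python build_doc_tfidf.py /path/to/doc_json_dir /path/to/out_dir \
#       --ngram 2 --hash-size 16777216 --num-workers 8
# With the defaults above this saves
#   <out_dir>/<doc_json_dir>-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz
# assuming tfidf_util.save_sparse_csr appends the .npz extension, as the
# 'Saving to %s.npz' log line suggests.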