Skip to content

Commit

Permalink
CovidQA, duo cleanup (#140)
Browse files Browse the repository at this point in the history
* fix some bugs introduced with trec-covid introduction

* fix monoT5

* fix model_type

* bf

* fix model

* upload 10k msmarco model, load directly

* addback bert for seq clas

* fix default for du and change title to optional
  • Loading branch information
ronakice authored Jan 5, 2021
1 parent 05a63f1 commit 623285a
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 54 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ texts = hits_to_texts(hits)
# Option 2: here's what Pyserini would have retrieved, hard-coded
passages = [['7744105', 'For Earth-centered it was Geocentric Theory proposed by greeks under the guidance of Ptolemy and Sun-centered was Heliocentric theory proposed by Nicolas Copernicus in 16th century A.D. In short, Your Answers are: 1st blank - Geo-Centric Theory. 2nd blank - Heliocentric Theory.'], ['2593796', 'Copernicus proposed a heliocentric model of the solar system â\x80\x93 a model where everything orbited around the Sun. Today, with advancements in science and technology, the geocentric model seems preposterous.he geocentric model, also known as the Ptolemaic system, is a theory that was developed by philosophers in Ancient Greece and was named after the philosopher Claudius Ptolemy who lived circa 90 to 168 A.D. It was developed to explain how the planets, the Sun, and even the stars orbit around the Earth.'], ['6217200', 'The geocentric model, also known as the Ptolemaic system, is a theory that was developed by philosophers in Ancient Greece and was named after the philosopher Claudius Ptolemy who lived circa 90 to 168 A.D. It was developed to explain how the planets, the Sun, and even the stars orbit around the Earth.opernicus proposed a heliocentric model of the solar system â\x80\x93 a model where everything orbited around the Sun. Today, with advancements in science and technology, the geocentric model seems preposterous.'], ['3276925', 'Copernicus proposed a heliocentric model of the solar system â\x80\x93 a model where everything orbited around the Sun. Today, with advancements in science and technology, the geocentric model seems preposterous.Simple tools, such as the telescope â\x80\x93 which helped convince Galileo that the Earth was not the center of the universe â\x80\x93 can prove that ancient theory incorrect.ou might want to check out one article on the history of the geocentric model and one regarding the geocentric theory. Here are links to two other articles from Universe Today on what the center of the universe is and Galileo one of the advocates of the heliocentric model.'], ['6217208', 'Copernicus proposed a heliocentric model of the solar system â\x80\x93 a model where everything orbited around the Sun. Today, with advancements in science and technology, the geocentric model seems preposterous.Simple tools, such as the telescope â\x80\x93 which helped convince Galileo that the Earth was not the center of the universe â\x80\x93 can prove that ancient theory incorrect.opernicus proposed a heliocentric model of the solar system â\x80\x93 a model where everything orbited around the Sun. Today, with advancements in science and technology, the geocentric model seems preposterous.'], ['4280557', 'The geocentric model, also known as the Ptolemaic system, is a theory that was developed by philosophers in Ancient Greece and was named after the philosopher Claudius Ptolemy who lived circa 90 to 168 A.D. It was developed to explain how the planets, the Sun, and even the stars orbit around the Earth.imple tools, such as the telescope â\x80\x93 which helped convince Galileo that the Earth was not the center of the universe â\x80\x93 can prove that ancient theory incorrect. You might want to check out one article on the history of the geocentric model and one regarding the geocentric theory.'], ['264181', 'Nicolaus Copernicus (b. 1473â\x80\x93d. 1543) was the first modern author to propose a heliocentric theory of the universe. From the time that Ptolemy of Alexandria (c. 150 CE) constructed a mathematically competent version of geocentric astronomy to Copernicusâ\x80\x99s mature heliocentric version (1543), experts knew that the Ptolemaic system diverged from the geocentric concentric-sphere conception of Aristotle.'], ['4280558', 'A Geocentric theory is an astronomical theory which describes the universe as a Geocentric system, i.e., a system which puts the Earth in the center of the universe, and describes other objects from the point of view of the Earth. Geocentric theory is an astronomical theory which describes the universe as a Geocentric system, i.e., a system which puts the Earth in the center of the universe, and describes other objects from the point of view of the Earth.'], ['3276926', 'The geocentric model, also known as the Ptolemaic system, is a theory that was developed by philosophers in Ancient Greece and was named after the philosopher Claudius Ptolemy who lived circa 90 to 168 A.D. It was developed to explain how the planets, the Sun, and even the stars orbit around the Earth.ou might want to check out one article on the history of the geocentric model and one regarding the geocentric theory. Here are links to two other articles from Universe Today on what the center of the universe is and Galileo one of the advocates of the heliocentric model.'], ['5183032', "After 1,400 years, Copernicus was the first to propose a theory which differed from Ptolemy's geocentric system, according to which the earth is at rest in the center with the rest of the planets revolving around it."]]

texts = [ Text(p[1], '', {'docid': p[0]}, 0) for p in passages] # Note, pyserini scores don't matter since T5 will ignore them.
texts = [ Text(p[1], {'docid': p[0]}, 0) for p in passages] # Note, pyserini scores don't matter since T5 will ignore them.

# Either option, let's print out the passages prior to reranking:
for i in range(0, 10):
Expand Down
6 changes: 4 additions & 2 deletions docs/experiments-CovidQA.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ NL Question:
```
python -um pygaggle.run.evaluate_kaggle_highlighter --method t5 \
--dataset data/kaggle-lit-review-0.2.json \
--index-dir indexes/lucene-index-cord19-paragraph-2020-05-12
--index-dir indexes/lucene-index-cord19-paragraph-2020-05-12 \
--model castorini/monot5-base-msmarco-10k
```

The following output will be visible after it has finished:
Expand All @@ -129,7 +130,8 @@ Keyword Query:
python -um pygaggle.run.evaluate_kaggle_highlighter --method t5 \
--split kq \
--dataset data/kaggle-lit-review-0.2.json \
--index-dir indexes/lucene-index-cord19-paragraph-2020-05-12
--index-dir indexes/lucene-index-cord19-paragraph-2020-05-12 \
--model castorini/monot5-base-msmarco-10k
```

The following output will be visible after it has finished:
Expand Down
2 changes: 1 addition & 1 deletion pygaggle/data/msmarco.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def to_relevance_examples(self,
for k, v in mean_stats.items():
logging.info(f'{k}: {np.mean(v)}')
return [RelevanceExample(Query(text=query_text, id=qid),
list(map(lambda s: Text(s[1], '', dict(docid=s[0])),
list(map(lambda s: Text(s[1], dict(docid=s[0])),
zip(cands, cands_text))),
rel_cands)
for qid, (query_text, cands, cands_text, rel_cands) in example_map.items()]
3 changes: 3 additions & 0 deletions pygaggle/data/relevance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from functools import lru_cache
import logging
from typing import List, Optional
import json
import re
Expand All @@ -10,6 +11,7 @@

__all__ = ['RelevanceExample', 'Cord19DocumentLoader', 'Cord19AbstractLoader']


@dataclass
class RelevanceExample:
query: Query
Expand All @@ -27,6 +29,7 @@ class Cord19Document:
def all_text(self):
return '\n'.join((self.abstract, self.body_text, self.ref_entries))


@dataclass
class Cord19Abstract:
title: str
Expand Down
10 changes: 5 additions & 5 deletions pygaggle/data/trec_covid.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def from_folder(cls,

def query_document_tuples(self):
return [((ex.qid, ex.text, ex.relevant_candidates), perm_pas)
for ex in self.examples
for ex in self.examples
for perm_pas in permutations(ex.candidates, r=1)]

def to_relevance_examples(self,
Expand All @@ -110,7 +110,7 @@ def to_relevance_examples(self,
for passage in passages][0])
except ValueError as e:
logging.error(e)
logging.warning(f'Skipping passages')
logging.warning('Skipping passages')
continue
example_map[qid][3].append(cands[0] in rel_cands)
mean_stats = defaultdict(list)
Expand Down Expand Up @@ -143,7 +143,7 @@ def to_relevance_examples(self,
for k, v in mean_stats.items():
logging.info(f'{k}: {np.mean(v)}')
rel = [RelevanceExample(Query(text=query_text, id=qid),
list(map(lambda s: Text(s[1], s[2], dict(docid=s[0])),
zip(cands, cands_text, title))),
rel_cands) for qid, (query_text, cands, cands_text, rel_cands, title) in example_map.items()]
list(map(lambda s: Text(s[1], dict(docid=s[0]), title=s[2]),
zip(cands, cands_text, title))), rel_cands)
for qid, (query_text, cands, cands_text, rel_cands, title) in example_map.items()]
return rel
6 changes: 4 additions & 2 deletions pygaggle/rerank/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ class Text:
score : Optional[float]
The score of the text. For example, the score might be the BM25 score
from an initial retrieval stage.
title : Optional[str]
The text's title.
"""

def __init__(self,
text: str,
title: str = '',
metadata: Mapping[str, Any] = None,
score: Optional[float] = 0):
score: Optional[float] = 0,
title: Optional[str] = None):
self.text = text
if metadata is None:
metadata = dict()
Expand Down
52 changes: 21 additions & 31 deletions pygaggle/run/evaluate_kaggle_highlighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from pydantic import BaseModel, validator
from transformers import (AutoModel,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoTokenizer,
BertForQuestionAnswering,
BertForSequenceClassification)
import torch

Expand All @@ -22,10 +20,8 @@
)
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
from pygaggle.model import (CachedT5ModelLoader,
RerankerEvaluator,
from pygaggle.model import (RerankerEvaluator,
SimpleBatchTokenizer,
T5BatchTokenizer,
metric_names)
from pygaggle.data import LitReviewDataset
from pygaggle.settings import Cord19Settings
Expand All @@ -45,52 +41,43 @@ class KaggleEvaluationOptions(BaseModel):
split: str
do_lower_case: bool
metrics: List[str]
model_name: Optional[str]
model: Optional[str]
tokenizer_name: Optional[str]

@validator('dataset')
def dataset_exists(cls, v: Path):
assert v.exists(), 'dataset must exist'
return v

@validator('model_name')
def model_name_sane(cls, v: Optional[str], values, **kwargs):
@validator('model')
def model_sane(cls, v: Optional[str], values, **kwargs):
method = values['method']
if method == 'transformer' and v is None:
raise ValueError('transformer name must be specified')
elif method == 't5':
return SETTINGS.t5_model_type
if v == 'biobert':
return 'monologg/biobert_v1.1_pubmed'
return v

@validator('tokenizer_name')
def tokenizer_sane(cls, v: str, values, **kwargs):
if v is None:
return values['model_name']
return values['model']
return v


def construct_t5(options: KaggleEvaluationOptions) -> Reranker:
loader = CachedT5ModelLoader(SETTINGS.t5_model_dir,
SETTINGS.cache_dir,
'ranker',
SETTINGS.t5_model_type,
SETTINGS.flush_cache)
device = torch.device(options.device)
model = loader.load().to(device).eval()
tokenizer = MonoT5.get_tokenizer(options.model_type,
do_lower_case=options.do_lower_case,
batch_size=options.batch_size)
model = MonoT5.get_model(options.model,
device=options.device)
tokenizer = MonoT5.get_tokenizer(options.model, batch_size=options.batch_size)
return MonoT5(model, tokenizer)


def construct_transformer(options: KaggleEvaluationOptions) -> Reranker:
device = torch.device(options.device)
try:
model = AutoModel.from_pretrained(options.model_name).to(device).eval()
model = AutoModel.from_pretrained(options.model).to(device).eval()
except OSError:
model = AutoModel.from_pretrained(options.model_name,
model = AutoModel.from_pretrained(options.model,
from_tf=True).to(device).eval()
tokenizer = SimpleBatchTokenizer(
AutoTokenizer.from_pretrained(
Expand All @@ -103,11 +90,11 @@ def construct_transformer(options: KaggleEvaluationOptions) -> Reranker:
def construct_seq_class_transformer(options:
KaggleEvaluationOptions) -> Reranker:
try:
model = MonoBERT.get_model(options.model_name, device=options.device)
model = MonoBERT.get_model(options.model, device=options.device)
except OSError:
try:
model = MonoBERT.get_model(
options.model_name,
options.model,
from_tf=True,
device=options.device)
except AttributeError:
Expand All @@ -117,7 +104,7 @@ def construct_seq_class_transformer(options:
BertForSequenceClassification.weight = torch.nn.Parameter(
torch.zeros(2, 768))
model = BertForSequenceClassification.from_pretrained(
options.model_name, from_tf=True)
options.model, from_tf=True)
model.classifier.weight = BertForSequenceClassification.weight
model.classifier.bias = BertForSequenceClassification.bias
device = torch.device(options.device)
Expand All @@ -132,14 +119,14 @@ def construct_qa_transformer(options: KaggleEvaluationOptions) -> Reranker:
# Refactor
try:
model = AutoModelForQuestionAnswering.from_pretrained(
options.model_name)
options.model)
except OSError:
model = AutoModelForQuestionAnswering.from_pretrained(
options.model_name, from_tf=True)
options.model, from_tf=True)
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
options.tokenizer_name,
options.tokenizer_name,
do_lower_case=options.do_lower_case,
use_fast=False)
return QuestionAnsweringTransformerReranker(model, tokenizer)
Expand All @@ -157,7 +144,9 @@ def main():
required=True,
type=str,
choices=METHOD_CHOICES),
opt('--model-name', type=str),
opt('--model',
type=str,
help='Path to pre-trained model or huggingface model name'),
opt('--split', type=str, default='nq', choices=('nq', 'kq')),
opt('--batch-size', '-bsz', type=int, default=96),
opt('--device', type=str, default='cuda:0'),
Expand All @@ -167,7 +156,8 @@ def main():
type=str,
nargs='+',
default=metric_names(),
choices=metric_names()))
choices=metric_names()),
opt('--model-type', type=str))
args = apb.parser.parse_args()
options = KaggleEvaluationOptions(**vars(args))
ds = LitReviewDataset.from_file(str(options.dataset))
Expand Down
10 changes: 4 additions & 6 deletions pygaggle/run/evaluate_passage_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from pydantic import BaseModel, validator
from transformers import (AutoModel,
AutoTokenizer,
AutoModelForSequenceClassification,
BertForSequenceClassification,
T5ForConditionalGeneration)
BertForSequenceClassification)
import torch

from .args import ArgumentParserBuilder, opt
Expand All @@ -22,7 +20,6 @@
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
from pygaggle.model import (SimpleBatchTokenizer,
T5BatchTokenizer,
RerankerEvaluator,
DuoRerankerEvaluator,
metric_names,
Expand Down Expand Up @@ -92,8 +89,8 @@ def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
def construct_duo_t5(options: PassageRankingEvaluationOptions) -> Tuple[Reranker, Reranker]:
mono_reranker = construct_t5(options)
model = DuoT5.get_model(options.duo_model,
from_tf=options.from_tf,
device=options.device)
from_tf=options.from_tf,
device=options.device)
tokenizer = DuoT5.get_tokenizer(options.model_type, batch_size=options.batch_size)
return mono_reranker, DuoT5(model, tokenizer)

Expand Down Expand Up @@ -152,6 +149,7 @@ def main():
help='Path to pre-trained model or huggingface model name'),
opt('--duo_model',
type=str,
default='',
help='Path to pre-trained model or huggingface model name'),
opt('--mono_hits',
type=int,
Expand Down
13 changes: 7 additions & 6 deletions pygaggle/run/evaluate_trec_covid_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
METHOD_CHOICES = ('transformer', 'bm25', 't5', 'seq_class_transformer',
'random')


class DocumentRankingEvaluationOptions(BaseModel):
task: str
dataset: Path
Expand Down Expand Up @@ -165,20 +166,20 @@ def main():
help='The T5 tokenizer name.'),
opt('--tokenizer-name',
type=str,
help = 'The name of the tokenizer to pull from huggingface using the AutoTokenizer class. If '
'left empty, this will be set to the model name.'),
help='The name of the tokenizer to pull from huggingface using the AutoTokenizer class. If '
'left empty, this will be set to the model name.'),
opt('--seg-size',
type=int,
default=10,
help='The number of sentences in each segment. For example, given a document with sentences'
'[1,2,3,4,5], a seg_size of 3, and a stride of 2, the document will be broken into segments'
'[[1, 2, 3], [3, 4, 5], and [5]].'),
'[1,2,3,4,5], a seg_size of 3, and a stride of 2, the document will be broken into segments'
'[[1, 2, 3], [3, 4, 5], and [5]].'),
opt('--seg-stride',
type=int,
default=5,
help='The number of sentences to increment between each segment. For example, given a document'
'with sentences [1,2,3,4,5], a seg_size of 3, and a stride of 2, the document will be broken into'
'segments [[1, 2, 3], [3, 4, 5], and [5]].'),
'with sentences [1,2,3,4,5], a seg_size of 3, and a stride of 2, the document will be broken into'
'segments [[1, 2, 3], [3, 4, 5], and [5]].'),
opt('--aggregate-method',
type=str,
default="max",
Expand Down

0 comments on commit 623285a

Please sign in to comment.