Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use past argument for GPT2 to speed up decoding #63

Merged
merged 8 commits into from
Nov 19, 2019
23 changes: 12 additions & 11 deletions nlpaug/augmenter/sentence/context_word_embs_sentence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def init_context_word_embs_sentence_model(model_path, device, force_reload=False, temperature=1.0, top_k=None,
top_p=None):
top_p=None, return_past=True):
global CONTEXT_WORD_EMBS_SENTENCE_MODELS

model_name = os.path.basename(model_path)
Expand All @@ -23,16 +23,12 @@ def init_context_word_embs_sentence_model(model_path, device, force_reload=False
return CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name]

if 'xlnet' in model_path:
model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p)
model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, return_past=return_past)
elif 'gpt2' in model_path:
model = nml.Gpt2(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p)
model = nml.Gpt2(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, return_past=return_past)
else:
raise ValueError('Model name value is unexpected. Only support XLNet and GPT2 model.')

CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name] = model
return model


class ContextualWordEmbsForSentenceAug(SentenceAugmenter):
# https://arxiv.org/pdf/1707.07328.pdf
"""
Expand All @@ -52,8 +48,8 @@ class ContextualWordEmbsForSentenceAug(SentenceAugmenter):
Default value is False and suggesting to keep it as False if performance is the consideration.
:param str name: Name of this augmenter

>>> import nlpaug.augmenter.word as naw
>>> aug = naw.ContextualWordEmbsForSentenceAug()
>>> import nlpaug.augmenter.sentence as nas
>>> aug = nas.ContextualWordEmbsForSentenceAug()
"""

def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=100, top_p=None,
Expand Down Expand Up @@ -86,7 +82,8 @@ def insert(self, data):
if data is None or data == '' or data.strip() == '':
return data

max_try = 100
max_try = 30 # On average 30 should be enough to complete a sentence
past = None
augmented_text = ''

for _ in range(max_try):
Expand All @@ -95,7 +92,11 @@ def insert(self, data):
if self.model_type in ['xlnet']:
text += ' ' + self.model.MASK_TOKEN

results = self.model.predict(text, n=1)
results = self.model.predict(text, n=1, past=past)

if self.model.return_past:
results, past = results

new_word, proba = results[0]

if new_word in self.SENTENCE_SEPARATOR:
Expand Down
14 changes: 11 additions & 3 deletions nlpaug/model/lang_models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Gpt2(LanguageModels):
# https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
SUBWORD_PREFIX = 'Ġ'

def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, device=None):
def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, device=None, return_past=False):
super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p)
self.model_path = model_path

Expand All @@ -22,17 +22,21 @@ def __init__(self, model_path='gpt2', temperature=1.0, top_k=None, top_p=None, d
self.model.to(self.device)
self.model.eval()

self.return_past = return_past

def id2token(self, _id):
return self.tokenizer.decode(_id, clean_up_tokenization_spaces=True).strip()

def predict(self, text, target_word=None, n=1):
def predict(self, text, target_word=None, n=1, past=None):
# Convert feature
input_idxes = self.tokenizer.encode(text)
if past is not None:
input_idxes = input_idxes[-1:]
input_idxes = torch.tensor(input_idxes, device=self.device).unsqueeze(0).repeat(1, 1)

# Prediction
with torch.no_grad():
outputs = self.model(input_idxes)
outputs = self.model(input_ids=input_idxes, past=past)
target_token_logits = outputs[0][0][-1] # GPT2 only predict last token

# Selection
Expand All @@ -41,4 +45,8 @@ def predict(self, text, target_word=None, n=1):
target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed)

results = self.pick(target_token_logits, target_word=target_word, n=n)
if self.return_past:
past = outputs[1]
results = (results, past,)

return results
13 changes: 10 additions & 3 deletions nlpaug/model/lang_models/xlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class XlNet(LanguageModels):
NEW_PARAGRAPH_TOKEN = '<eop>'

def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=None, top_p=None, padding_text=None,
device=None):
device=None, return_past=False):
super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p)
self.model_path = model_path

Expand All @@ -39,13 +39,16 @@ def __init__(self, model_path='xlnet-base-cased', temperature=1.0, top_k=None, t
self.model.to(self.device)
self.model.eval()

self.return_past = return_past

def id2token(self, _id):
return self.tokenizer.decode(_id, clean_up_tokenization_spaces=True).strip()

def clean(self, text):
return text.replace(self.NEW_PARAGRAPH_TOKEN, '').strip()

def predict(self, text, target_word=None, n=1):
def predict(self, text, target_word=None, n=1, past=None):
        # xlnet does not support `past`; instead it has `mems`, which works differently
# Convert feature
input_idxes = self.tokenizer.encode(text)
concatenated_idxes = self.padding_text_idxes + input_idxes
Expand All @@ -72,4 +75,8 @@ def predict(self, text, target_word=None, n=1):
target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed)

results = self.pick(target_token_logits, target_word=target_word, n=n)
return results

if self.return_past:
            results = (results, past,)  # Only for API compatibility; past is not used for xlnet

return results