diff --git a/transformations/diverse_paraphrase/README.md b/transformations/diverse_paraphrase/README.md
new file mode 100644
index 000000000..d6d50d1f9
--- /dev/null
+++ b/transformations/diverse_paraphrase/README.md
@@ -0,0 +1,60 @@
+# Diverse Paraphrase Generation 🦎 + ⌨️ → 🐍
+This transformation produces multiple diverse paraphrases for a given English sentence.
+
+## What type of a transformation is this?
+This transformation is a multi-output, sentence-level paraphrase generation model, specifically catered towards generating `num_outputs` (user-specified) diverse outputs.
+
+It supports four candidate selection methods:
+
+a) `dips`: based on Kumar et al., 2019 (see below).\
+b) `diverse_beam`: based on Vijayakumar et al., 2018 (see below).\
+c) `beam`: selects the top `num_outputs` candidates from the beam search.\
+d) `random`: randomly selects `num_outputs` candidates.
+
+Example:
+```python
+>>> t = DiverseParaphrase(augmenter='dips', num_outputs=3)
+>>> t.generate('Joe Biden is the President of USA.')
+```
+
+Replace `augmenter` with any of the options listed above. \
+\
+Default: augmenter='dips', num_outputs=3. \
+In most cases, `dips` is the preferred choice.
+
+## What tasks does it intend to benefit?
+This transformation benefits any task that needs diverse paraphrase candidates for augmentation, such as text classification and text generation.
+
+## Previous Work
+1. DiPS: Original implementation [here](/~https://github.com/malllabiisc/DiPS)
+```bibtex
+@inproceedings{dips2019,
+    title = "Submodular Optimization-based Diverse Paraphrasing and its Effectiveness in Data Augmentation",
+    author = "Kumar, Ashutosh and
+      Bhattamishra, Satwik and
+      Bhandari, Manik and
+      Talukdar, Partha",
+    booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
+    month = jun,
+    year = "2019",
+    address = "Minneapolis, Minnesota",
+    publisher = "Association for Computational Linguistics",
+    url = "https://www.aclweb.org/anthology/N19-1363",
+    pages = "3609--3619"
+}
+```
+
+2. Diverse Beam Search
+```bibtex
+@inproceedings{AAAI1817329,
+    author = {Ashwin Vijayakumar and Michael Cogswell and Ramprasaath Selvaraju and Qing Sun and Stefan Lee and David Crandall and Dhruv Batra},
+    title = {Diverse Beam Search for Improved Description of Complex Scenes},
+    booktitle = {AAAI Conference on Artificial Intelligence},
+    year = {2018},
+    keywords = {Recurrent Neural Networks, Beam Search, Diversity},
+    url = {https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/17329}
+}
+```
+
+## What are the limitations of this transformation?
+The base paraphrasing model used by this transformation is backtranslation (En -> De -> En). If the base model is weak, the candidate outputs will be of low quality.
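+
+As a rough illustration of this dependence, here is a minimal sketch of the backtranslation round trip on its own, using the FSMT models that `transformation.py` builds on. The `facebook/wmt19-en-de` / `facebook/wmt19-de-en` checkpoint names are assumptions for illustration, not necessarily the ones loaded internally:
+
+```python
+# Minimal backtranslation sketch (En -> De -> En).
+# Checkpoint names below are illustrative assumptions.
+from transformers import FSMTForConditionalGeneration, FSMTTokenizer
+
+def backtranslate(sentence, num_candidates=3):
+    en_de_tok = FSMTTokenizer.from_pretrained("facebook/wmt19-en-de")
+    en_de = FSMTForConditionalGeneration.from_pretrained("facebook/wmt19-en-de")
+    de_en_tok = FSMTTokenizer.from_pretrained("facebook/wmt19-de-en")
+    de_en = FSMTForConditionalGeneration.from_pretrained("facebook/wmt19-de-en")
+
+    # En -> De: take a single German pivot translation.
+    de_ids = en_de.generate(en_de_tok.encode(sentence, return_tensors="pt"), num_beams=5)
+    german = en_de_tok.decode(de_ids[0], skip_special_tokens=True)
+
+    # De -> En: over-generate English candidates; dips/diverse_beam/beam/random
+    # then pick num_outputs of them.
+    en_ids = de_en.generate(
+        de_en_tok.encode(german, return_tensors="pt"),
+        num_beams=max(5, num_candidates),
+        num_return_sequences=num_candidates,
+    )
+    return [de_en_tok.decode(ids, skip_special_tokens=True) for ids in en_ids]
+```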
diff --git a/transformations/diverse_paraphrase/__init__.py b/transformations/diverse_paraphrase/__init__.py
new file mode 100644
index 000000000..930cdce0b
--- /dev/null
+++ b/transformations/diverse_paraphrase/__init__.py
@@ -0,0 +1 @@
+from .transformation import *
diff --git a/transformations/diverse_paraphrase/requirements.txt b/transformations/diverse_paraphrase/requirements.txt
new file mode 100644
index 000000000..5636ba49b
--- /dev/null
+++ b/transformations/diverse_paraphrase/requirements.txt
@@ -0,0 +1,3 @@
+# for diverse paraphrase dips
+nltk==3.6.2
+torchtext
diff --git a/transformations/diverse_paraphrase/submod/submodopt.py b/transformations/diverse_paraphrase/submod/submodopt.py
new file mode 100644
index 000000000..80f06aa70
--- /dev/null
+++ b/transformations/diverse_paraphrase/submod/submodopt.py
@@ -0,0 +1,111 @@
+import numpy as np
+
+from transformations.diverse_paraphrase.submod.submodular_funcs import (
+    distinct_ngrams,
+    ngram_overlap,
+    similarity_func,
+    seq_func,
+    ngram_overlap_unit,
+    similarity_gain,
+    seq_gain,
+)
+
+
+class SubmodularOpt:
+    """
+    Selects the final candidates for diverse paraphrasing
+    using greedy submodular optimization.
+    """
+
+    def __init__(self, V=None, v=None, **kwargs):
+        """
+        Parameters
+        ---
+
+        V : list of str
+            Ground set of generated candidates to select from
+        v : str
+            Source sentence used to select semantically equivalent
+            outputs
+        """
+        self.v = v
+        self.V = V
+
+    def initialize_function(self, lam, a1=1.0, a2=1.0, b1=1.0, b2=1.0):
+        """
+        Parameters
+        ---
+
+        lam : float (0 <= lam <= 1)
+            Trade-off between the fidelity (quality) and diversity
+            components: lam weights quality, (1 - lam) weights diversity.
+        a1 : float
+            Weight assigned to semantic similarity between V and v based
+            on word vectors.
+        a2 : float
+            Weight assigned to semantic similarity based on lexical
+            (n-gram) overlap between V and v.
+        b1 : float
+            Weight assigned to n-gram distinctness within the candidates
+            in V.
+        b2 : float
+            Weight assigned to the edit-distance-based coverage function
+            promoting diversity within the candidates in V.
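+
+        Notes
+        ---
+        These weights combine into the selection objective evaluated by
+        ``final_func`` below (sketch of the scoring rule, written to match
+        the code rather than any external formulation):
+
+            score(S) = lam * (a1 * sim + a2 * overlap)
+                       + (1 - lam) * (b1 * distinct + b2 * edit)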
+        """
+        self.a1 = a1
+        self.a2 = a2
+        self.b1 = b1
+        self.b2 = b2
+
+        # Normalization constants so each component lies on a comparable scale.
+        self.noverlap_norm = ngram_overlap(self.v, self.V)
+        self.ndistinct_norm = distinct_ngrams(self.V)
+        self.sim_norm = similarity_func(self.v, self.V)
+        self.edit_norm = np.sqrt(len(self.V))
+        self.lam = lam
+
+    def final_func(self, pos_sets, rem_list, selec_set):
+        distinct_score = (
+            np.array(list(map(distinct_ngrams, pos_sets))) / self.ndistinct_norm
+        )
+
+        base_noverlap_score = ngram_overlap(self.v, selec_set)
+        base_sim_score = similarity_func(self.v, selec_set)
+        base_edit_score = seq_func(self.V, selec_set)
+
+        noverlap_score = []
+        for sent in rem_list:
+            noverlap_score.append(ngram_overlap_unit(self.v, sent, base_noverlap_score))
+        noverlap_score = np.array(noverlap_score) / self.noverlap_norm
+
+        sim_score = []
+        for sent in rem_list:
+            sim_score.append(similarity_gain(self.v, sent, base_sim_score))
+        sim_score = np.array(sim_score) / self.sim_norm
+
+        edit_score = []
+        for sent in rem_list:
+            # The gain is computed against the full ground set V, matching
+            # how base_edit_score is computed above.
+            edit_score.append(seq_gain(self.V, sent, base_edit_score))
+        edit_score = np.array(edit_score) / self.edit_norm
+
+        quality_score = self.a1 * sim_score + self.a2 * noverlap_score
+        diversity_score = self.b1 * distinct_score + self.b2 * edit_score
+
+        final_score = self.lam * quality_score + (1 - self.lam) * diversity_score
+
+        return final_score
+
+    def maximize_func(self, k=5):
+        # Greedy maximization: repeatedly add the remaining candidate with
+        # the highest marginal score until k candidates are selected.
+        selec_sents = set()
+        ground_set = set(self.V)
+        selec_set = set(selec_sents)
+        rem_set = ground_set.difference(selec_set)
+        while len(selec_sents) < k:
+            rem_list = list(rem_set)
+            pos_sets = [list(selec_set.union({x})) for x in rem_list]
+
+            score_map = self.final_func(pos_sets, rem_list, selec_set)
+            max_idx = np.argmax(score_map)
+
+            selec_sents = pos_sets[max_idx]
+            selec_set = set(selec_sents)
+            rem_set = ground_set.difference(selec_set)
+
+        return selec_sents
diff --git a/transformations/diverse_paraphrase/submod/submodular_funcs.py b/transformations/diverse_paraphrase/submod/submodular_funcs.py
new file mode 100644
index 000000000..b1848b553
--- /dev/null
+++ b/transformations/diverse_paraphrase/submod/submodular_funcs.py
@@ -0,0 +1,218 @@
+import difflib
+
+import numpy as np
+import torch
+from nltk import ngrams
+from numpy import dot
+from numpy.linalg import norm
+from torchtext.vocab import GloVe
+
+
+model = None
+
+
+def trigger_dips():
+    # Lazily load 50-dimensional GloVe (6B) vectors; unknown tokens get a
+    # random vector so every word contributes to the similarity score.
+    global model
+
+    def unk_init(x):
+        return torch.randn_like(x)
+
+    model = GloVe('6B', dim=50, unk_init=unk_init)
+
+
+cos_sim = lambda a, b: dot(a, b) / (norm(a) * norm(b))
+rbf = lambda a, b, sigma: np.exp(-(np.sum((a - b) ** 2)) / sigma ** 2)
+
+
+def sent2wvec(s):
+    v = model.get_vecs_by_tokens(s, lower_case_backup=True)
+    v = v.detach().cpu().numpy()
+    return v
+
+
+def sentence_compare(s1, s2, kernel="cos", **kwargs):
+    l1 = s1.split()
+    l2 = s2.split()
+
+    v1 = sent2wvec(l1)
+    v2 = sent2wvec(l2)
+    score = 0
+    len_s1 = v1.shape[0]
+    for v in v1:
+        # For every word in s1, take its best match in s2 and average.
+        if kernel == "cos":
+            wscore = np.max(np.array([cos_sim(v, i) for i in v2]))
+        elif kernel == "rbf":
+            wscore = np.max(np.array([rbf(v, i, kwargs["sigma"]) for i in v2]))
+        else:
+            raise ValueError("Unknown kernel type: {}".format(kernel))
+        score += wscore / len_s1
+
+    return score
+
+
+def similarity_func(v, S):
+    if len(S):
+        score = 0.0
+
+        for sent in S:
+            score += sentence_compare(v, sent, kernel="rbf", sigma=1.0)
+
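+        # Assumed rationale: the square root keeps the accumulated score
+        # concave, giving the diminishing-returns behavior the greedy
+        # optimizer in submodopt.py relies on.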
+        return np.sqrt(score)
+    else:
+        return 0.0
+
+
+def similarity_gain(v, s, base_score=0.0):
+    # Gains are combined in squared space so they compose with the
+    # square-rooted set scores above.
+    score = 0.0
+    score += sentence_compare(v, s, kernel="rbf", sigma=1.0)
+    score += base_score ** 2
+
+    return np.sqrt(score)
+
+
+#####################################################################################################################
+#####################################################################################################################
+
+########################################### NGRAM FUNCTIONS #########################################################
+
+
+def ngram_toks(sents, n=1):
+    ntoks = []
+    for sent in sents:
+        ntoks += list(ngrams(sent.split(), n))
+    return ntoks
+
+
+def distinct_ngrams(S):
+    if len(S):
+        S = " ".join(S)
+        N = [1, 2, 3]
+        score = 0.0
+        for n in N:
+            toks = set(ngram_toks([S], n))
+            score += (1.0 / n) * len(toks)
+
+        return score
+    else:
+        return 0.0
+
+
+def ngram_overlap(v, S):
+    if len(S):
+        N = [1, 2, 3]
+        score = 0.0
+
+        for n in N:
+            src_toks = set(ngram_toks([v], n))
+            for sent in S:
+                # Overlap is computed per sentence, matching the
+                # per-sentence gain in ngram_overlap_unit below.
+                sent_toks = set(ngram_toks([sent], n))
+
+                overlap = src_toks.intersection(sent_toks)
+
+                score += (1.0 / (4 - n)) * len(overlap)
+
+        return np.sqrt(score)
+    else:
+        return 0.0
+
+
+def ngram_overlap_unit(v, S, base_score=0.0):
+    N = [1, 2, 3]
+    score = 0.0
+    # S may be a single sentence or a list of sentences.
+    if not isinstance(S, str):
+        S = " ".join(S)
+
+    for n in N:
+        src_toks = set(ngram_toks([v], n))
+        sent_toks = set(ngram_toks([S], n))
+        overlap = src_toks.intersection(sent_toks)
+
+        score += (1.0 / (4 - n)) * len(overlap)
+
+    return np.sqrt((base_score ** 2) + score)
+
+
+#####################################################################################################################
+
+########################################### EDIT DISTANCE FUNCTION ##################################################
+
+
+def seq_func(V, S):
+    if len(S):
+        score = 0.0
+        for v in V:
+            for s in S:
+                vx = v.split()
+                sx = s.split()
+
+                seq = difflib.SequenceMatcher(None, vx, sx)
+                score += seq.ratio()
+
+        return np.sqrt(score)
+    else:
+        return 0.0
+
+
+def seq_gain(V, s, base_score=0.0):
+    gain = 0.0
+    for v in V:
+        vx = v.split()
+        sx = s.split()
+
+        seq = difflib.SequenceMatcher(None, vx, sx)
+        gain += seq.ratio()
+
+    score = (base_score ** 2) + gain
+
+    return np.sqrt(score)
+
+
+def info_func(S, orig_count, ref_count):
+    if len(S):
+        score = 0.0
+        for s in S:
+            stoks = set(s.split())
+            orig_toks = set(orig_count.keys())
+
+            int_toks = stoks.intersection(orig_toks)
+            for tok in int_toks:
+                try:
+                    score += orig_count[tok] / (1 + ref_count[tok])
+                except KeyError:
+                    score += orig_count[tok]
+
+        return np.sqrt(score)
+
+    else:
+        return 0.0
+
+
+def info_gain(s, orig_count, ref_count, base_score=0.0):
+    score = 0.0
+    stoks = set(s.split())
+    orig_toks = set(orig_count.keys())
+
+    int_toks = stoks.intersection(orig_toks)
+    for tok in int_toks:
+        try:
+            score += orig_count[tok] / (1 + ref_count[tok])
+        except KeyError:
+            score += orig_count[tok]
+
+    score += base_score ** 2
+
+    return np.sqrt(score)
diff --git a/transformations/diverse_paraphrase/test.json b/transformations/diverse_paraphrase/test.json
new file mode 100644
index 000000000..d233a55e4
--- /dev/null
+++ b/transformations/diverse_paraphrase/test.json
@@ -0,0 +1,5 @@
+{
+    "type": "diverse_paraphrase",
+    "test_cases": [
+    ]
+}
diff --git a/transformations/diverse_paraphrase/transformation.py b/transformations/diverse_paraphrase/transformation.py
new file mode 100644
index 000000000..c6284a292
--- /dev/null
+++ b/transformations/diverse_paraphrase/transformation.py
@@ -0,0 +1,149 @@
+import random
+from random import sample
+
+import numpy as np
+import torch
+from transformers import FSMTForConditionalGeneration, FSMTTokenizer
+
+from interfaces.SentenceOperation import SentenceOperation
+from tasks.TaskTypes import TaskType
+from transformations.diverse_paraphrase.submod.submodopt import SubmodularOpt
+from transformations.diverse_paraphrase.submod.submodular_funcs import trigger_dips
+
+
+class DiverseParaphrase(SentenceOperation):
+    tasks = [TaskType.TEXT_CLASSIFICATION, TaskType.TEXT_TO_TEXT_GENERATION]
+    languages = ["en"]
+
+    def __init__(self, augmenter="dips", num_outputs=3, seed=42):
+        super().__init__()
+
+        # Seed every source of randomness so generations are reproducible.
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.use_deterministic_algorithms(True)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(seed)
+
+        choices = ["dips", "random", "diverse_beam", "beam"]
+        assert augmenter in choices
+        if self.verbose:
+            print(
+                "The base paraphraser being used is Backtranslation - Generating {} candidates based on {}\n".format(
+                    num_outputs, augmenter
+                )
+            )
+            print("Primary options for augmenter : {}. \n".format(str(choices)))
+            print(
+                "Default: augmenter='dips', num_outputs=3. Change using DiverseParaphrase(augmenter=