Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
Internal change
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 345711057
  • Loading branch information
PEGASUS Team authored and peterjliu committed Jul 12, 2022
1 parent d7316de commit d4371c3
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 71 deletions.
19 changes: 10 additions & 9 deletions pegasus/data/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,16 @@ def parser(input_dic):

supervised = input_dic["supervised"]

pretrain_inputs, pretrain_targets, pretrain_masked_inputs = pretrain_parsing_ops.sentence_mask_and_encode(
input_dic[input_feature], max_input_len, max_target_len,
max_total_words, parser_strategy, parser_masked_sentence_ratio,
parser_masked_words_ratio, parser_mask_word_options_prob,
parser_mask_sentence_options_prob, vocab_filename, encoder_type,
parser_rouge_ngrams_size, parser_rouge_metric_type,
parser_rouge_stopwords_filename, parser_rouge_compute_option,
parser_rouge_noise_ratio, parser_dynamic_mask_min_ratio,
shift_special_token_id)
(pretrain_inputs, pretrain_targets, pretrain_masked_inputs, _,
_) = pretrain_parsing_ops.sentence_mask_and_encode(
input_dic[input_feature], max_input_len, max_target_len,
max_total_words, parser_strategy, parser_masked_sentence_ratio,
parser_masked_words_ratio, parser_mask_word_options_prob,
parser_mask_sentence_options_prob, vocab_filename, encoder_type,
parser_rouge_ngrams_size, parser_rouge_metric_type,
parser_rouge_stopwords_filename, parser_rouge_compute_option,
parser_rouge_noise_ratio, parser_dynamic_mask_min_ratio,
shift_special_token_id)

supervised_inputs = parsing_ops.encode(
tf.reshape(input_dic["inputs"], [1]), max_input_len, vocab_filename,
Expand Down
13 changes: 13 additions & 0 deletions pegasus/ops/pretrain_parsing_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ REGISTER_OP("SentenceMaskAndEncode")
.Output("input_ids: int64")
.Output("target_ids: int64")
.Output("mlm_ids: int64")
.Output("num_sentences: int32")
.Output("num_masked_sentences: int32")
.Attr("strategy: string")
.Attr("masked_sentence_ratio: float")
.Attr("masked_words_ratio: float")
Expand Down Expand Up @@ -300,15 +302,23 @@ class SentenceMaskAndEncodeOp : public OpKernel {
Tensor* input_ids;
Tensor* target_ids;
Tensor* mlm_ids;
Tensor* num_sentences;
Tensor* num_masked_sentences;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({1, max_input_len}),
&input_ids));
OP_REQUIRES_OK(ctx, ctx->allocate_output(
1, TensorShape({1, max_target_len}), &target_ids));
OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({1, max_input_len}),
&mlm_ids));
OP_REQUIRES_OK(ctx,
ctx->allocate_output(3, TensorShape({1}), &num_sentences));
OP_REQUIRES_OK(
ctx, ctx->allocate_output(4, TensorShape({1}), &num_masked_sentences));
input_ids->flat<int64>().setZero();
target_ids->flat<int64>().setZero();
mlm_ids->flat<int64>().setZero();
num_sentences->flat<int32>().setZero();
num_masked_sentences->flat<int32>().setZero();

std::string text = "";
// set a limit on the total number of words in the text.
Expand All @@ -327,6 +337,8 @@ class SentenceMaskAndEncodeOp : public OpKernel {
}

std::vector<std::string> sentences_vec = SentenceSegment(text);
num_sentences->flat<int32>()(0) = sentences_vec.size();

std::vector<std::vector<int64>> sentences_ids_vec =
EncodeSentences(sentences_vec, encoder_);

Expand Down Expand Up @@ -386,6 +398,7 @@ class SentenceMaskAndEncodeOp : public OpKernel {
VecToTensor(input_ids_vec, input_ids, kPadTokenId, 0);
VecToTensor(target_ids_vec, target_ids, kPadTokenId, 0);
VecToTensor(mlm_ids_vec, mlm_ids, kPadTokenId, 0);
num_masked_sentences->flat<int32>()(0) = indices.size();
}

private:
Expand Down
163 changes: 102 additions & 61 deletions pegasus/ops/pretrain_parsing_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,17 @@ class PretrainParsingOpsTest(tf.test.TestCase, parameterized.TestCase):
)
def test_sentence_mask_and_encode(self, parser_strategy, stopwords_filename):
string = " ".join("sentence %d." % i for i in range(10))
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, parser_strategy, 0.3, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", stopwords_filename,
"standard")
(input_t, target_t, mlm_t, num_sentences,
num_masked_sentences) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, parser_strategy, 0.3, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F",
stopwords_filename, "standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 100])
self.assertAllEqual(num_sentences.shape, [1])
self.assertAllEqual(num_sentences, [10])
self.assertAllEqual(num_masked_sentences.shape, [1])

@parameterized.named_parameters(
("mlm_0", 0.5, [0.3, 0.3, 0.4]),
Expand All @@ -58,14 +62,18 @@ def test_sentence_mask_and_encode(self, parser_strategy, stopwords_filename):
)
def test_mlm_only(self, masked_words_ratio, bert_masking_procedure):
string = "the brown the brown the brown the brown"
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 100, 0, "none", 0.3, masked_words_ratio,
bert_masking_procedure, [1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1,
"F", _STOPWORDS, "standard")
(input_t, target_t, mlm_t, num_sentences,
num_masked_sentences) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 100, 0, "none", 0.3, masked_words_ratio,
bert_masking_procedure, [1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1,
"F", _STOPWORDS, "standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 100])
self.assertAllEqual(target_t, [[1] + [0] * 99])
self.assertAllEqual(mlm_t.shape, [1, 100])
self.assertAllEqual(num_sentences.shape, [1])
self.assertAllEqual(num_masked_sentences.shape, [1])
self.assertAllEqual(num_masked_sentences, [0])
if masked_words_ratio == 0.0:
self.assertAllEqual(mlm_t, [[0] * 100])

Expand All @@ -81,14 +89,19 @@ def test_mlm_only(self, masked_words_ratio, bert_masking_procedure):
)
def test_empty_input(self, parser_strategy):
string = ""
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, parser_strategy, 0.4, 0.2, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS,
"standard")
(input_t, target_t, mlm_t, num_sentences,
num_masked_sentences) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, parser_strategy, 0.4, 0.2, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS,
"standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(target_t, [[1] + [0] * 9])
self.assertAllEqual(mlm_t.shape, [1, 100])
self.assertAllEqual(num_sentences.shape, [1])
self.assertAllEqual(num_masked_sentences.shape, [1])
self.assertAllEqual(num_sentences, [0])
self.assertAllEqual(num_masked_sentences, [0])

@parameterized.named_parameters(
("rouge_0", "rouge"),
Expand All @@ -98,20 +111,26 @@ def test_empty_input(self, parser_strategy):
def test_rouge_sentence_mask_and_encode_with_stopwords(self, parser_strategy):
string = " ".join("sentence me %d." % i for i in range(5)) + " "
string += " ".join("sentence %d." % i for i in range(5, 10))
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 300, 10, 0, parser_strategy, 0.3, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS,
"standard") # remove stopwords when computing rouge
(input_t, target_t, mlm_t, num_sentences,
num_masked_sentences) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 300, 10, 0, parser_strategy, 0.3, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS,
"standard") # remove stopwords when computing rouge
self.assertAllEqual(input_t.shape, [1, 300])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 300])
input_t_2, target_t_2, mlm_t_2 = pretrain_parsing_ops.sentence_mask_and_encode(
string, 300, 10, 0, parser_strategy, 0.3, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "",
"standard") # without removing stopwords
self.assertAllEqual(num_sentences.shape, [1])
self.assertAllEqual(num_masked_sentences.shape, [1])
(input_t_2, target_t_2, mlm_t_2, num_sentences_2,
num_masked_sentences_2) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 300, 10, 0, parser_strategy, 0.3, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "",
"standard") # without removing stopwords
self.assertAllEqual(input_t_2.shape, [1, 300])
self.assertAllEqual(target_t_2.shape, [1, 10])
self.assertAllEqual(mlm_t_2.shape, [1, 300])
self.assertAllEqual(num_sentences_2.shape, [1])
self.assertAllEqual(num_masked_sentences_2.shape, [1])
self.assertNotAllEqual(input_t, input_t_2)
self.assertNotAllEqual(target_t, target_t_2)

Expand All @@ -120,57 +139,72 @@ def test_rouge_sentence_mask_and_encode_with_stopwords(self, parser_strategy):
("rouge_recall", "recall", "1 1 1 1. 1. 2 1. 3 1. 4 5 6."),
)
def test_rouge_metric_type_precision_recall(self, metric_type, string):
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.2, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, metric_type, _STOPWORDS, "standard")
(input_t, target_t, mlm_t, num_sentences,
num_masked_sentences) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.2, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, metric_type, _STOPWORDS, "standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(input_t[0][0], 2)
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 100])
self.assertAllEqual(num_sentences.shape, [1])
self.assertAllEqual(num_masked_sentences.shape, [1])

def test_rouge_compute_option_deduplicate(self):
string = " ".join("sentence %d." % i for i in range(5)) + " "
string += " ".join("sentence %d %d." % (i, i) for i in range(5, 10))
# the last five sentences should be masked
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS, "deduplicate")
(input_t, target_t, mlm_t, num_sentences,
num_masked_sentences) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS, "deduplicate")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 100])
self.assertAllEqual(num_sentences, [10])
self.assertAllEqual(num_masked_sentences, [5])
string_2 = " ".join("sentence %d." % i for i in range(10))
# the last five sentences should be masked
input_t_2, target_t_2, mlm_t_2 = pretrain_parsing_ops.sentence_mask_and_encode(
string_2, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS, "deduplicate")
(input_t_2, target_t_2, mlm_t_2, num_sentences_2,
num_masked_sentences_2) = pretrain_parsing_ops.sentence_mask_and_encode(
string_2, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS, "deduplicate")
self.assertAllEqual(input_t_2.shape, [1, 100])
self.assertAllEqual(target_t_2.shape, [1, 10])
self.assertAllEqual(mlm_t_2.shape, [1, 100])
self.assertAllEqual(num_sentences_2, [10])
self.assertAllEqual(num_masked_sentences_2, [5])
# since the first five sentences which are unmasked are identical
self.assertAllEqual(input_t, input_t_2)
self.assertNotAllEqual(target_t, target_t_2)

def test_rouge_compute_option_log(self):
string = " ".join("sentence %d." % i for i in range(10))
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS, "log")
(input_t, target_t, mlm_t, num_sentences,
num_masked_sentences) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", _STOPWORDS, "log")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 100])
self.assertAllEqual(num_sentences.shape, [1])
self.assertAllEqual(num_sentences, [10])
self.assertAllEqual(num_masked_sentences.shape, [1])

def test_greedy_rouge(self):
string_list = ["1. 2. 4. 4.", "1. 2. 4. 4. 5. 6. 6. 6. 6."]
for string in string_list:
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
(input_t, target_t, mlm_t, _,
_) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 100])
input_t_2, target_t_2, mlm_t_2 = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "greedy_rouge", 0.5, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
(input_t_2, target_t_2, mlm_t_2, _,
_) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "greedy_rouge", 0.5, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
self.assertAllEqual(input_t_2.shape, [1, 100])
self.assertAllEqual(target_t_2.shape, [1, 10])
self.assertAllEqual(mlm_t_2.shape, [1, 100])
Expand All @@ -180,15 +214,17 @@ def test_greedy_rouge(self):

def test_continuous_rouge(self):
string = "1. 2. 3. 4. 5. 6. 7. 8. 9. 0. 1. 1."
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "continuous_rouge", 0.5, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
(input_t, target_t, mlm_t,
_, _) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "continuous_rouge", 0.5, 0.0, [0.8, 0.1, 0.1],
[1, 0, 0, 0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 100])
input_t_2, target_t_2, mlm_t_2 = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
(input_t_2, target_t_2, mlm_t_2,
_, _) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
self.assertAllEqual(input_t_2.shape, [1, 100])
self.assertAllEqual(target_t_2.shape, [1, 10])
self.assertAllEqual(mlm_t_2.shape, [1, 100])
Expand All @@ -197,44 +233,49 @@ def test_continuous_rouge(self):

def test_mask_sentence_rates(self):
string = ". ".join("%s" % i for i in range(10))
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 100, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1],
[0.25, 0.25, 0.25, 0.25], _SUBWORDS_PRETRAIN, "subword", 1, "F", "",
"standard")
(input_t, target_t, mlm_t,
_, _) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 100, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1],
[0.25, 0.25, 0.25, 0.25], _SUBWORDS_PRETRAIN, "subword", 1, "F", "",
"standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 100])
self.assertAllEqual(mlm_t.shape, [1, 100])

def test_dynamic_rouge_rates(self):
string = "1 2 3. 4 5 6. 2 3 4. 2 3 5."
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 40, 40, 0, "dynamic_rouge", 0.8, 0.0, [0.9, 0.0, 0.1],
[0.9, 0.0, 0.1, 0.0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "",
"standard", 0.25, 0.1)
(input_t, target_t, mlm_t,
_, _) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 40, 40, 0, "dynamic_rouge", 0.8, 0.0, [0.9, 0.0, 0.1],
[0.9, 0.0, 0.1, 0.0], _SUBWORDS_PRETRAIN, "subword", 1, "F", "",
"standard", 0.25, 0.1)
self.assertAllEqual(input_t.shape, [1, 40])
self.assertAllEqual(target_t.shape, [1, 40])
self.assertAllEqual(mlm_t.shape, [1, 40])

def test_sentence_piece(self):
string = "beautifully. beautifully. beautifully. beautifully. beautifully."
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 100, 0, "random", 0.3, 0.5, [1, 0, 0], [1, 0, 0, 0],
_SENTENCEPIECE_VOCAB, "sentencepiece", 1, "F", "", "standard")
(input_t, target_t, mlm_t,
_, _) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 100, 0, "random", 0.3, 0.5, [1, 0, 0], [1, 0, 0, 0],
_SENTENCEPIECE_VOCAB, "sentencepiece", 1, "F", "", "standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 100])
self.assertAllEqual(mlm_t.shape, [1, 100])

def test_max_total_words(self):
string = "1. 2. 3. 4. 5. 6. 7. 8. 9. 0. 1. 1."
input_t, target_t, mlm_t = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 10, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
(input_t, target_t, mlm_t,
_, _) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 10, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
self.assertAllEqual(input_t.shape, [1, 100])
self.assertAllEqual(target_t.shape, [1, 10])
self.assertAllEqual(mlm_t.shape, [1, 100])
input_t_2, target_t_2, mlm_t_2 = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
(input_t_2, target_t_2, mlm_t_2,
_, _) = pretrain_parsing_ops.sentence_mask_and_encode(
string, 100, 10, 0, "rouge", 0.5, 0.0, [0.8, 0.1, 0.1], [1, 0, 0, 0],
_SUBWORDS_PRETRAIN, "subword", 1, "F", "", "standard")
self.assertAllEqual(input_t_2.shape, [1, 100])
self.assertAllEqual(target_t_2.shape, [1, 10])
self.assertAllEqual(mlm_t_2.shape, [1, 100])
Expand Down
3 changes: 2 additions & 1 deletion pegasus/ops/sentence_selection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
#include <utility>
#include <vector>

#include "pegasus/ops/rouge.h"
#include "glog/logging.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/random.h"
#include "pegasus/ops/rouge.h"

namespace pegasus {

Expand Down

0 comments on commit d4371c3

Please sign in to comment.