From 3227a88e869039773ba6103e14f8ed981bcec315 Mon Sep 17 00:00:00 2001
From: zhoushunjie
Date: Mon, 25 Oct 2021 18:29:25 +0800
Subject: [PATCH 1/2] optimize fast tokenizer

---
 .../operators/string/faster_tokenizer_op.cc   | 22 +++++++++++--------
 .../operators/string/faster_tokenizer_op.h    | 19 ++++++++--------
 2 files changed, 22 insertions(+), 19 deletions(-)
 mode change 100755 => 100644 paddle/fluid/operators/string/faster_tokenizer_op.h

diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.cc b/paddle/fluid/operators/string/faster_tokenizer_op.cc
index 49457af8f00c80..794d65d0fa0fe5 100644
--- a/paddle/fluid/operators/string/faster_tokenizer_op.cc
+++ b/paddle/fluid/operators/string/faster_tokenizer_op.cc
@@ -100,9 +100,14 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
     // String is converted into wstring failedly.
     return;
   }
-
-  std::wstring dest_text;
-  for (auto ch : unicode_text) {
+  std::wstring cache_text = L"";
+  auto PushCacheText = [&]() {
+    if (cache_text != L"") {
+      res->emplace_back(cache_text);
+      cache_text = L"";
+    }
+  };
+  for (auto& ch : unicode_text) {
     if (ch == 0 || ch == 0xfffd || IsControl(ch)) {
       continue;
     }
@@ -110,16 +115,15 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
       ch = do_lower_case(ch);
     }
     if (IsChineseChar(ch) || IsPunctuation(ch)) {
-      dest_text += ' ';
-      dest_text += ch;
-      dest_text += ' ';
+      PushCacheText();
+      res->emplace_back(std::wstring{ch});
     } else if (IsWhiteSpace(ch)) {
-      dest_text += ' ';
+      PushCacheText();
     } else {
-      dest_text += ch;
+      cache_text += ch;
     }
   }
-  boost::split(*res, dest_text, boost::is_any_of(kStripChars));
+  PushCacheText();
 }
 
 WordPieceTokenizer::WordPieceTokenizer(
diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.h b/paddle/fluid/operators/string/faster_tokenizer_op.h
old mode 100755
new mode 100644
index d9b7fa26a6704b..28e54295d4732d
--- a/paddle/fluid/operators/string/faster_tokenizer_op.h
+++ b/paddle/fluid/operators/string/faster_tokenizer_op.h
@@ -140,21 +140,21 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       return;
     }
 
-    BertTokenizer* tokenizer_ptr =
-        new BertTokenizer(const_cast<framework::Vocab*>(vocab), do_lower_case);
+    BertTokenizer tokenizer(const_cast<framework::Vocab*>(vocab),
+                            do_lower_case);
     size_t batch_max_seq_len = 0;
     size_t batch_size = text->size();
 
     vector<unordered_map<string, vector<int64_t>>> batch_encode_inputs(
         batch_size);
     if (text_pair) {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, *text_pair,
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, *text_pair,
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     } else {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, vector<string>(),
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, vector<string>(),
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     }
 
     for (size_t i = 0; i < batch_size; ++i) {
@@ -173,7 +173,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
                               static_cast<int64_t>(batch_max_seq_len)}));
     auto* seg_ids_data = seg_ids->mutable_data<T>(ctx.GetPlace());
 
-    auto pad_token_id = tokenizer_ptr->GetPadTokenID();
+    auto pad_token_id = tokenizer.GetPadTokenID();
     for (size_t i = 0; i < batch_size; i++) {
       auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"];
       auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"];
@@ -188,7 +188,6 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, pad_token_id,
                   (batch_max_seq_len - seq_len) * sizeof(T));
     }
-    delete tokenizer_ptr;
   }
 };

From 2127f6d6a087465443c899114a9fd6e4c5a7ab50 Mon Sep 17 00:00:00 2001
From: zhoushunjie
Date: Mon, 25 Oct 2021 21:40:38 +0800
Subject: [PATCH 2/2] remove const_cast

---
 .../operators/string/faster_tokenizer_op.cc      | 16 ++++++++--------
 .../fluid/operators/string/faster_tokenizer_op.h | 12 ++++++------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.cc b/paddle/fluid/operators/string/faster_tokenizer_op.cc
index 794d65d0fa0fe5..42047021b408a8 100644
--- a/paddle/fluid/operators/string/faster_tokenizer_op.cc
+++ b/paddle/fluid/operators/string/faster_tokenizer_op.cc
@@ -127,12 +127,12 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
 }
 
 WordPieceTokenizer::WordPieceTokenizer(
-    framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
+    const framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
     const size_t max_input_chars_per_word /* = 100 */)
     : vocab_(vocab),
       unk_token_(unk_token),
       max_input_chars_per_word_(max_input_chars_per_word) {
-  unk_token_id_ = (*vocab_)[unk_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
 }
 
 void WordPieceTokenizer::Tokenize(const wstring& text,
@@ -182,7 +182,7 @@ void WordPieceTokenizer::Tokenize(const wstring& text,
   }
 }
 
-BertTokenizer::BertTokenizer(framework::Vocab* vocab,
+BertTokenizer::BertTokenizer(const framework::Vocab* vocab,
                              bool do_lower_case /* = false */,
                              const wstring& unk_token /* = L"[UNK]" */,
                              const wstring& pad_token /* = L"[PAD]" */,
@@ -200,11 +200,11 @@ BertTokenizer::BertTokenizer(const framework::Vocab* vocab,
       vocab_(vocab),
       basic_tokenizer_(do_lower_case_),
       word_piece_tokenizer_(vocab_, unk_token) {
-  unk_token_id_ = (*vocab_)[unk_token_];
-  pad_token_id_ = (*vocab_)[pad_token_];
-  cls_token_id_ = (*vocab_)[cls_token_];
-  mask_token_id_ = (*vocab_)[mask_token_];
-  sep_token_id_ = (*vocab_)[sep_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
+  pad_token_id_ = vocab_->at(pad_token_);
+  cls_token_id_ = vocab_->at(cls_token_);
+  mask_token_id_ = vocab_->at(mask_token_);
+  sep_token_id_ = vocab_->at(sep_token_);
 
   all_special_tokens_ = vector<wstring>(
       {unk_token_, pad_token_, cls_token_, mask_token_, sep_token_});
diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.h b/paddle/fluid/operators/string/faster_tokenizer_op.h
index 28e54295d4732d..5218b7c2eaa51d 100644
--- a/paddle/fluid/operators/string/faster_tokenizer_op.h
+++ b/paddle/fluid/operators/string/faster_tokenizer_op.h
@@ -56,13 +56,13 @@ class BasicTokenizer {
 
 class WordPieceTokenizer {
  public:
-  explicit WordPieceTokenizer(framework::Vocab* vocab,
+  explicit WordPieceTokenizer(const framework::Vocab* vocab,
                               const wstring& unk_token = L"[UNK]",
                               const size_t max_input_chars_per_word = 100);
   void Tokenize(const wstring& text, vector<int64_t>* output) const;
 
  private:
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   wstring unk_token_{L"[UNK]"};
   int64_t unk_token_id_;
   size_t max_input_chars_per_word_;
@@ -70,7 +70,8 @@ class WordPieceTokenizer {
 
 class BertTokenizer {
  public:
-  explicit BertTokenizer(framework::Vocab* vocab, bool do_lower_case = false,
+  explicit BertTokenizer(const framework::Vocab* vocab,
+                         bool do_lower_case = false,
                          const wstring& unk_token = L"[UNK]",
                          const wstring& pad_token = L"[PAD]",
                          const wstring& cls_token = L"[CLS]",
@@ -106,7 +107,7 @@ class BertTokenizer {
   bool do_lower_case_;
   wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_;
   string padding_site_;
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   BasicTokenizer basic_tokenizer_;
   WordPieceTokenizer word_piece_tokenizer_;
   int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_,
@@ -140,8 +141,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       return;
     }
 
-    BertTokenizer tokenizer(const_cast<framework::Vocab*>(vocab),
-                            do_lower_case);
+    BertTokenizer tokenizer(vocab, do_lower_case);
     size_t batch_max_seq_len = 0;
     size_t batch_size = text->size();
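
Note: as a standalone illustration of the splitting strategy patch 1 switches to (flush an accumulating buffer at every whitespace, punctuation, or CJK boundary instead of building a spaced-out copy of the text and running boost::split over it afterwards), here is a minimal self-contained sketch. The file name, function names, and the simplified character classifiers below are assumptions for the demo, not the actual Paddle helpers.

// buffer_and_flush_sketch.cc -- hypothetical, simplified demo (not the Paddle code).
#include <cwctype>
#include <iostream>
#include <string>
#include <vector>

// Stand-ins for the operator's Unicode-aware IsWhiteSpace/IsPunctuation/IsChineseChar.
static bool IsCjk(wchar_t ch) { return ch >= 0x4E00 && ch <= 0x9FFF; }

static void SplitTokens(const std::wstring& text, std::vector<std::wstring>* res) {
  std::wstring cache;
  auto flush = [&]() {
    if (!cache.empty()) {
      res->emplace_back(cache);
      cache.clear();
    }
  };
  for (wchar_t ch : text) {
    if (std::iswspace(ch)) {
      flush();  // whitespace only terminates the current token
    } else if (std::iswpunct(ch) || IsCjk(ch)) {
      flush();  // punctuation / CJK characters become single-char tokens
      res->emplace_back(std::wstring{ch});
    } else {
      cache += ch;  // extend the current token
    }
  }
  flush();  // emit a trailing token, if any
}

int main() {
  std::vector<std::wstring> tokens;
  SplitTokens(L"hello, world", &tokens);
  std::wcout << tokens.size() << L" tokens" << std::endl;  // prints: 3 tokens
  return 0;
}

Flushing per character keeps a single pass over the input and drops the intermediate dest_text string entirely, which is what the diff above achieves with cache_text and the PushCacheText lambda.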