Optimize FasterTokenizer #36701

Merged · 2 commits · Oct 26, 2021
38 changes: 21 additions & 17 deletions paddle/fluid/operators/string/faster_tokenizer_op.cc
@@ -100,35 +100,39 @@ void BasicTokenizer::Tokenize(const string& text, vector<wstring>* res) const {
     // Failed to convert the string to wstring.
     return;
   }
 
-  std::wstring dest_text;
-  for (auto ch : unicode_text) {
+  std::wstring cache_text = L"";
+  auto PushCacheText = [&]() {
+    if (cache_text != L"") {
+      res->emplace_back(cache_text);
+      cache_text = L"";
+    }
+  };
+  for (auto& ch : unicode_text) {
     if (ch == 0 || ch == 0xfffd || IsControl(ch)) {
       continue;
     }
     if (do_lower_case_) {
       ch = do_lower_case(ch);
     }
     if (IsChineseChar(ch) || IsPunctuation(ch)) {
-      dest_text += ' ';
-      dest_text += ch;
-      dest_text += ' ';
+      PushCacheText();
+      res->emplace_back(std::wstring{ch});
     } else if (IsWhiteSpace(ch)) {
-      dest_text += ' ';
+      PushCacheText();
    } else {
-      dest_text += ch;
+      cache_text += ch;
    }
   }
-  boost::split(*res, dest_text, boost::is_any_of(kStripChars));
+  PushCacheText();
 }
 
 WordPieceTokenizer::WordPieceTokenizer(
-    framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
+    const framework::Vocab* vocab, const wstring& unk_token /* = L"[UNK]"*/,
     const size_t max_input_chars_per_word /* = 100 */)
     : vocab_(vocab),
       unk_token_(unk_token),
       max_input_chars_per_word_(max_input_chars_per_word) {
-  unk_token_id_ = (*vocab_)[unk_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
 }
 
 void WordPieceTokenizer::Tokenize(const wstring& text,
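
Note on the BasicTokenizer change above: instead of building a space-delimited copy of the text and splitting it afterwards with boost::split, the loop now pushes each finished token straight into res, flushing the pending characters through the small PushCacheText lambda. That removes the boost::split dependency on this path and avoids materializing a second string plus an extra split pass. The following minimal sketch shows the same single-pass pattern in isolation; the function name and the bare space delimiter are illustrative only, not part of the patch:

#include <string>
#include <vector>

// Single-pass splitting: tokens are appended to the output as soon as a
// delimiter is seen, instead of building one big string and splitting it.
void SplitInOnePass(const std::wstring& text, std::vector<std::wstring>* res) {
  std::wstring cache;
  auto Flush = [&]() {
    if (!cache.empty()) {
      res->emplace_back(cache);
      cache.clear();
    }
  };
  for (wchar_t ch : text) {
    if (ch == L' ') {  // delimiter: emit the pending token, if any
      Flush();
    } else {
      cache += ch;     // otherwise keep accumulating the current token
    }
  }
  Flush();  // emit the trailing token
}
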
@@ -178,7 +182,7 @@ void WordPieceTokenizer::Tokenize(const wstring& text,
   }
 }
 
-BertTokenizer::BertTokenizer(framework::Vocab* vocab,
+BertTokenizer::BertTokenizer(const framework::Vocab* vocab,
                              bool do_lower_case /* = false */,
                              const wstring& unk_token /* = L"[UNK]" */,
                              const wstring& pad_token /* = L"[PAD]" */,
@@ -196,11 +200,11 @@ BertTokenizer::BertTokenizer(framework::Vocab* vocab,
       vocab_(vocab),
       basic_tokenizer_(do_lower_case_),
       word_piece_tokenizer_(vocab_, unk_token) {
-  unk_token_id_ = (*vocab_)[unk_token_];
-  pad_token_id_ = (*vocab_)[pad_token_];
-  cls_token_id_ = (*vocab_)[cls_token_];
-  mask_token_id_ = (*vocab_)[mask_token_];
-  sep_token_id_ = (*vocab_)[sep_token_];
+  unk_token_id_ = vocab_->at(unk_token_);
+  pad_token_id_ = vocab_->at(pad_token_);
+  cls_token_id_ = vocab_->at(cls_token_);
+  mask_token_id_ = vocab_->at(mask_token_);
+  sep_token_id_ = vocab_->at(sep_token_);
 
   all_special_tokens_ = vector<wstring>(
       {unk_token_, pad_token_, cls_token_, mask_token_, sep_token_});
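
The constructors above now take const framework::Vocab* and the id lookups change from (*vocab_)[token] to vocab_->at(token). Since the vocabulary is a map type, operator[] is non-const (it would insert a default-constructed id for a missing key), so it cannot be called through a const pointer; at() is a const lookup that throws std::out_of_range for an unknown token instead of silently growing the vocabulary. A small illustration of the difference, using a plain std::unordered_map as a stand-in for framework::Vocab:

#include <cstdint>
#include <string>
#include <unordered_map>

using VocabMap = std::unordered_map<std::wstring, std::int64_t>;

std::int64_t LookupId(const VocabMap& vocab, const std::wstring& token) {
  // vocab[token] would not compile here: operator[] is non-const because it
  // may insert a new entry. at() is const and throws std::out_of_range if
  // the token is missing, leaving the map unchanged.
  return vocab.at(token);
}
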
27 changes: 13 additions & 14 deletions paddle/fluid/operators/string/faster_tokenizer_op.h
file mode changed from 100755 to 100644
@@ -56,21 +56,22 @@ class BasicTokenizer {
 
 class WordPieceTokenizer {
  public:
-  explicit WordPieceTokenizer(framework::Vocab* vocab,
+  explicit WordPieceTokenizer(const framework::Vocab* vocab,
                               const wstring& unk_token = L"[UNK]",
                               const size_t max_input_chars_per_word = 100);
   void Tokenize(const wstring& text, vector<int64_t>* output) const;
 
  private:
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   wstring unk_token_{L"[UNK]"};
   int64_t unk_token_id_;
   size_t max_input_chars_per_word_;
 };
 
 class BertTokenizer {
  public:
-  explicit BertTokenizer(framework::Vocab* vocab, bool do_lower_case = false,
+  explicit BertTokenizer(const framework::Vocab* vocab,
+                         bool do_lower_case = false,
                          const wstring& unk_token = L"[UNK]",
                          const wstring& pad_token = L"[PAD]",
                          const wstring& cls_token = L"[CLS]",
@@ -106,7 +107,7 @@ class BertTokenizer {
   bool do_lower_case_;
   wstring unk_token_, pad_token_, cls_token_, mask_token_, sep_token_;
   string padding_site_;
-  framework::Vocab* vocab_;
+  const framework::Vocab* vocab_;
   BasicTokenizer basic_tokenizer_;
   WordPieceTokenizer word_piece_tokenizer_;
   int64_t unk_token_id_, cls_token_id_, mask_token_id_, pad_token_id_,
@@ -140,21 +141,20 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       return;
     }
 
-    BertTokenizer* tokenizer_ptr =
-        new BertTokenizer(const_cast<framework::Vocab*>(vocab), do_lower_case);
+    BertTokenizer tokenizer(vocab, do_lower_case);
     size_t batch_max_seq_len = 0;
     size_t batch_size = text->size();
 
     vector<unordered_map<string, vector<int64_t>>> batch_encode_inputs(
         batch_size);
     if (text_pair) {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, *text_pair,
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, *text_pair,
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     } else {
-      tokenizer_ptr->BatchEncode(&batch_encode_inputs, *text, vector<string>(),
-                                 is_split_into_words, max_seq_len,
-                                 pad_to_max_seq_len);
+      tokenizer.BatchEncode(&batch_encode_inputs, *text, vector<string>(),
+                            is_split_into_words, max_seq_len,
+                            pad_to_max_seq_len);
     }
 
     for (size_t i = 0; i < batch_size; ++i) {
@@ -173,7 +173,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
                               static_cast<int64_t>(batch_max_seq_len)}));
     auto* seg_ids_data = seg_ids->mutable_data<T>(ctx.GetPlace());
 
-    auto pad_token_id = tokenizer_ptr->GetPadTokenID();
+    auto pad_token_id = tokenizer.GetPadTokenID();
     for (size_t i = 0; i < batch_size; i++) {
       auto& encoder_input_ids = batch_encode_inputs[i]["input_ids"];
       auto& encoder_seg_ids = batch_encode_inputs[i]["token_type_ids"];
@@ -188,7 +188,6 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
       std::memset(seg_ids_data + i * batch_max_seq_len + seq_len, pad_token_id,
                   (batch_max_seq_len - seq_len) * sizeof(T));
     }
-    delete tokenizer_ptr;
   }
 };
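
In the kernel, the heap-allocated tokenizer (new BertTokenizer(...) paired with a manual delete, plus a const_cast to strip constness from the vocab pointer) is replaced by a stack-allocated object. With the constructors now accepting const framework::Vocab*, the const_cast is unnecessary, and the automatic object is destroyed on every exit path, so it cannot be leaked by an early return. A minimal sketch of the before/after pattern; Tokenizer and Encode here are placeholder names, not Paddle APIs:

#include <cstddef>
#include <string>
#include <vector>

struct Tokenizer {
  explicit Tokenizer(const std::vector<std::wstring>* vocab) : vocab_(vocab) {}
  std::size_t Encode(const std::wstring& text) const { return text.size(); }
  const std::vector<std::wstring>* vocab_;
};

std::size_t Run(const std::vector<std::wstring>& vocab,
                const std::wstring& text) {
  // Before: Tokenizer* p = new Tokenizer(&vocab); ... ; delete p;
  // After: an automatic object needs no manual delete; its destructor runs
  // whenever the scope is left, including on early returns.
  Tokenizer tokenizer(&vocab);
  return tokenizer.Encode(text);
}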
