Add optimized decoder for the deployment of DS2 #139
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,8 @@ | |
|
||
## Installation | ||
|
||
### Basic setup | ||
|
||
Please make sure the above [prerequisites](#prerequisites) have been satisfied before moving on. | ||
|
||
```bash | ||
|
@@ -32,6 +34,16 @@ cd models/deep_speech_2 | |
sh setup.sh | ||
``` | ||
|
||
### Decoders setup | ||
|
||
```bash | ||
cd decoders/swig | ||
sh setup.sh | ||
cd ../.. | ||
``` | ||
|
||
These commands install the decoders that translate the output probability vectors of the DS2 model into text, including the CTC greedy decoder, the CTC beam search decoder and its batch version. Detailed usage is given in the following sections. | |
|
||
## Getting Started | ||
|
||
Several shell scripts provided in `./examples` help you quickly try out most major modules, including data preparation, model training, case inference and model evaluation, with a few public datasets (e.g. [LibriSpeech](http://www.openslr.org/12/), [Aishell](http://www.openslr.org/33)). Reading these examples will also help you understand how to make it work with your own data. | |
|
@@ -176,6 +188,8 @@ Data augmentation has often been a highly effective technique to boost the deep | |
|
||
Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline. | ||
|
||
### Inference | ||
Comment: Why add L191-192?
Comment: This is a mistake. Removed.
Comment: Remove L179.
Comment: Done |
||
|
||
- Volume Perturbation | ||
- Speed Perturbation | ||
- Shifting Perturbation | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,332 @@ | ||
#include "ctc_decoders.h" | ||
Comment: Please follow Google Coding Style for import order.
Comment: This is the include order required by Google Coding Style: https://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes |
||
|
||
#include <algorithm> | ||
#include <cmath> | ||
#include <iostream> | ||
#include <limits> | ||
#include <map> | ||
#include <utility> | ||
|
||
#include "ThreadPool.h" | ||
#include "fst/fstlib.h" | ||
|
||
#include "decoder_utils.h" | ||
#include "path_trie.h" | ||
|
||
std::string ctc_greedy_decoder( | ||
Comment: Put ctc_greedy_decoder into another file?
Comment: Done |
||
const std::vector<std::vector<double>> &probs_seq, | ||
const std::vector<std::string> &vocabulary) { | ||
// dimension check | ||
int num_time_steps = probs_seq.size(); | ||
Comment: int --> size_t
Comment: Done |
||
for (int i = 0; i < num_time_steps; i++) { | ||
Comment: int --> size_t
Comment: Done |
||
if (probs_seq[i].size() != vocabulary.size() + 1) { | ||
std::cout << "The shape of probs_seq does not match" | ||
<< " with the shape of the vocabulary!" << std::endl; | ||
exit(1); | ||
Comment: exit(1) --> throw exception
Comment: Done |
||
} | ||
} | ||
|
||
int blank_id = vocabulary.size(); | ||
Comment: int --> size_t
Comment: Done |
||
|
||
std::vector<int> max_idx_vec; | ||
double max_prob = 0.0; | ||
int max_idx = 0; | ||
for (int i = 0; i < num_time_steps; i++) { | ||
Comment: For brevity, L32-33 could be moved before L35, and then L42-43 removed.
Comment: Done |
||
for (int j = 0; j < probs_seq[i].size(); j++) { | ||
if (max_prob < probs_seq[i][j]) { | ||
max_idx = j; | ||
max_prob = probs_seq[i][j]; | ||
} | ||
} | ||
max_idx_vec.push_back(max_idx); | ||
max_prob = 0.0; | ||
max_idx = 0; | ||
} | ||
|
||
std::vector<int> idx_vec; | ||
for (int i = 0; i < max_idx_vec.size(); i++) { | ||
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { | ||
idx_vec.push_back(max_idx_vec[i]); | ||
} | ||
} | ||
|
||
std::string best_path_result; | ||
for (int i = 0; i < idx_vec.size(); i++) { | ||
if (idx_vec[i] != blank_id) { | ||
best_path_result += vocabulary[idx_vec[i]]; | ||
} | ||
} | ||
return best_path_result; | ||
} | ||
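The review comments on the greedy decoder above (use `size_t`, throw instead of calling `exit(1)`, and merge the argmax and collapse passes) can be combined in a minimal sketch like the one below. The function name `ctc_greedy_decoder_sketch` is hypothetical, the blank label is assumed to be the last index as in the diff, and this is not the code that was eventually merged.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch of the reviewed greedy (best path) decoder.
std::string ctc_greedy_decoder_sketch(
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary) {
  // Assumption carried over from the diff: blank is the last label.
  const std::size_t blank_id = vocabulary.size();
  // Sentinel that no argmax can produce, so the first step is always kept.
  std::size_t prev_idx = vocabulary.size() + 1;

  std::string best_path_result;
  for (std::size_t i = 0; i < probs_seq.size(); ++i) {
    const std::vector<double> &row = probs_seq[i];
    // Dimension check: report the error to the caller instead of exiting.
    if (row.size() != vocabulary.size() + 1) {
      throw std::invalid_argument(
          "The shape of probs_seq does not match the shape of the vocabulary");
    }
    // Argmax over the labels at this time step (first maximum wins on ties).
    const std::size_t max_idx = static_cast<std::size_t>(
        std::distance(row.begin(), std::max_element(row.begin(), row.end())));
    // Collapse consecutive repeats first, then drop blanks.
    if (max_idx != prev_idx && max_idx != blank_id) {
      best_path_result += vocabulary[max_idx];
    }
    prev_idx = max_idx;
  }
  return best_path_result;
}
```

Because repeats are collapsed before blanks are dropped, a blank between two identical labels keeps both of them, which matches the behaviour of the three separate loops in the diff.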
|
||
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( | ||
Comment: This function is too long. Could you try to separate it into multiple functions to make it more readable?
Comment: Done |
||
const std::vector<std::vector<double>> &probs_seq, | ||
int beam_size, | ||
std::vector<std::string> vocabulary, | ||
Comment: const &
Comment: Did not use
Comment: Done
Comment: Both Done
Comment: If there's no modification for vocabulary, please use const & instead.
Comment: Done |
||
int blank_id, | ||
double cutoff_prob, | ||
int cutoff_top_n, | ||
Scorer *extscorer) { | ||
Comment: Done |
||
// dimension check | ||
Comment: Wrap the dimension check in a function?
Comment: Done |
||
size_t num_time_steps = probs_seq.size(); | ||
for (int i = 0; i < num_time_steps; i++) { | ||
if (probs_seq[i].size() != vocabulary.size() + 1) { | ||
std::cout << " The shape of probs_seq does not match" | ||
<< " with the shape of the vocabulary!" << std::endl; | ||
exit(1); | ||
Comment: Do not use exit(1)
Comment: Done |
||
} | ||
} | ||
|
||
// blank_id check | ||
if (blank_id > vocabulary.size()) { | ||
Comment: In the greedy decoder, blank_id is derived automatically, whereas here it is passed in as an argument. Can the two decoders handle this consistently?
Comment: Removed the argument |
||
std::cout << " Invalid blank_id! " << std::endl; | ||
exit(1); | ||
} | ||
|
||
// assign space ID | ||
std::vector<std::string>::iterator it = | ||
std::find(vocabulary.begin(), vocabulary.end(), " "); | ||
int space_id = it - vocabulary.begin(); | ||
// if no space in vocabulary | ||
if (space_id >= vocabulary.size()) { | ||
Comment: compare signed value with unsigned value.
Comment: Done |
||
space_id = -2; | ||
} | ||
|
||
// init prefixes' root | ||
PathTrie root; | ||
root.score = root.log_prob_b_prev = 0.0; | ||
std::vector<PathTrie *> prefixes; | ||
prefixes.push_back(&root); | ||
|
||
if (extscorer != nullptr) { | ||
if (extscorer->is_char_map_empty()) { | ||
extscorer->set_char_map(vocabulary); | ||
} | ||
if (!extscorer->is_character_based()) { | ||
if (extscorer->dictionary == nullptr) { | ||
// fill dictionary for fst with space | ||
extscorer->fill_dictionary(true); | ||
} | ||
Comment: L94 to L101 could all be put into a setup method inside Scorer itself, which should be called outside the decoder. That means a scorer must be well prepared before being passed into the decoder function.
Comment: Done |
||
auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary); | ||
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); | ||
Comment: Why
Comment: This is required by FST
Comment: This looks a bit odd; could you check the interface? What does Copy(true) do?
Comment: This is probably not a simple pointer assignment; depending on the state, it returns a different |
||
root.set_dictionary(dict_ptr); | ||
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); | ||
root.set_matcher(matcher); | ||
} | ||
} | ||
|
||
// prefix search over time | ||
for (int time_step = 0; time_step < num_time_steps; time_step++) { | ||
std::vector<double> prob = probs_seq[time_step]; | ||
Comment: use
Comment: Done |
||
std::vector<std::pair<int, double>> prob_idx; | ||
for (int i = 0; i < prob.size(); i++) { | ||
prob_idx.push_back(std::pair<int, double>(i, prob[i])); | ||
} | ||
|
||
float min_cutoff = -NUM_FLT_INF; | ||
bool full_beam = false; | ||
if (extscorer != nullptr) { | ||
int num_prefixes = std::min((int)prefixes.size(), beam_size); | ||
std::sort( | ||
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); | ||
min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) - | ||
std::max(0.0, extscorer->beta); | ||
full_beam = (num_prefixes == beam_size); | ||
} | ||
|
||
// pruning of vocabulary | |
Comment: For example, if this code segment does only one thing, could it be extracted into a function?
Comment: Done |
||
int cutoff_len = prob.size(); | ||
if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) { | ||
std::sort( | ||
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); | ||
if (cutoff_prob < 1.0) { | ||
Comment: here we can use while loop to remove
Comment: Feel unnecessary |
||
double cum_prob = 0.0; | ||
cutoff_len = 0; | ||
for (int i = 0; i < prob_idx.size(); i++) { | ||
cum_prob += prob_idx[i].second; | ||
cutoff_len += 1; | ||
if (cum_prob >= cutoff_prob) break; | ||
} | ||
} | ||
cutoff_len = std::min(cutoff_len, cutoff_top_n); | ||
Comment: move line 143 into for loop as a stop condition
Comment: This part has been moved to a function |
||
prob_idx = std::vector<std::pair<int, double>>( | ||
prob_idx.begin(), prob_idx.begin() + cutoff_len); | ||
} | ||
std::vector<std::pair<int, float>> log_prob_idx; | ||
for (int i = 0; i < cutoff_len; i++) { | ||
log_prob_idx.push_back(std::pair<int, float>( | ||
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); | ||
Comment: --> |
||
} | ||
|
||
// loop over chars | ||
for (int index = 0; index < log_prob_idx.size(); index++) { | ||
auto c = log_prob_idx[index].first; | ||
float log_prob_c = log_prob_idx[index].second; | ||
|
||
for (int i = 0; i < prefixes.size() && i < beam_size; i++) { | ||
auto prefix = prefixes[i]; | ||
|
||
if (full_beam && log_prob_c + prefix->score < min_cutoff) { | ||
break; | ||
} | ||
// blank | ||
if (c == blank_id) { | ||
prefix->log_prob_b_cur = | ||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); | ||
continue; | ||
} | ||
// repeated character | ||
if (c == prefix->character) { | ||
prefix->log_prob_nb_cur = log_sum_exp( | ||
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); | ||
} | ||
// get new prefix | ||
auto prefix_new = prefix->get_path_trie(c); | ||
|
||
if (prefix_new != nullptr) { | ||
float log_p = -NUM_FLT_INF; | ||
|
||
if (c == prefix->character && | ||
prefix->log_prob_b_prev > -NUM_FLT_INF) { | ||
log_p = log_prob_c + prefix->log_prob_b_prev; | ||
} else if (c != prefix->character) { | ||
log_p = log_prob_c + prefix->score; | ||
} | ||
|
||
// language model scoring | ||
if (extscorer != nullptr && | ||
(c == space_id || extscorer->is_character_based())) { | ||
PathTrie *prefix_toscore = nullptr; | ||
|
||
// skip scoring the space | ||
if (extscorer->is_character_based()) { | ||
prefix_toscore = prefix_new; | ||
} else { | ||
prefix_toscore = prefix; | ||
} | ||
|
||
double score = 0.0; | ||
std::vector<std::string> ngram; | ||
ngram = extscorer->make_ngram(prefix_toscore); | ||
score = extscorer->get_log_cond_prob(ngram) * extscorer->alpha; | ||
|
||
log_p += score; | ||
log_p += extscorer->beta; | ||
} | ||
prefix_new->log_prob_nb_cur = | ||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p); | ||
} | ||
} // end of loop over prefix | ||
} // end of loop over chars | ||
|
||
prefixes.clear(); | ||
// update log probs | ||
root.iterate_to_vec(prefixes); | ||
|
||
// only preserve top beam_size prefixes | ||
if (prefixes.size() >= beam_size) { | ||
std::nth_element(prefixes.begin(), | ||
prefixes.begin() + beam_size, | ||
prefixes.end(), | ||
prefix_compare); | ||
|
||
for (size_t i = beam_size; i < prefixes.size(); i++) { | ||
prefixes[i]->remove(); | ||
} | ||
} | ||
} // end of loop over time | ||
|
||
// compute approximate ctc score as the return score | |
Comment: Put into a separate function?
Comment: This part is temporarily kept for comparison with the speech-dl decoder and will be removed later. The following part has already been moved into a function. |
||
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { | ||
double approx_ctc = prefixes[i]->score; | ||
|
||
if (extscorer != nullptr) { | ||
std::vector<int> output; | ||
prefixes[i]->get_path_vec(output); | ||
size_t prefix_length = output.size(); | ||
auto words = extscorer->split_labels(output); | ||
// remove word insert | ||
approx_ctc = approx_ctc - prefix_length * extscorer->beta; | ||
// remove language model weight: | ||
approx_ctc -= (extscorer->get_sent_log_prob(words)) * extscorer->alpha; | ||
} | ||
|
||
prefixes[i]->approx_ctc = approx_ctc; | ||
} | ||
|
||
// allow for the post processing | ||
Comment: Put into a separate function?
Comment: Done |
||
std::vector<PathTrie *> space_prefixes; | ||
if (space_prefixes.empty()) { | ||
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { | ||
space_prefixes.push_back(prefixes[i]); | ||
} | ||
} | ||
|
||
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); | ||
std::vector<std::pair<double, std::string>> output_vecs; | ||
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) { | ||
std::vector<int> output; | ||
space_prefixes[i]->get_path_vec(output); | ||
// convert index to string | ||
std::string output_str; | ||
for (int j = 0; j < output.size(); j++) { | ||
output_str += vocabulary[output[j]]; | ||
} | ||
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc, | ||
output_str); | ||
output_vecs.emplace_back(output_pair); | ||
} | ||
|
||
return output_vecs; | ||
} | ||
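Several comments above ask for the vocabulary pruning step to be pulled out of `ctc_beam_search_decoder` into its own function. The sketch below shows one way that step could look in isolation. The name `prune_log_probs_sketch` is hypothetical, and `std::numeric_limits<float>::min()` stands in for the `NUM_FLT_MIN` constant from `decoder_utils.h`, whose actual value is not shown in this diff.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical helper: keep the most probable labels of one time step and
// return them as (label index, log probability) pairs.
std::vector<std::pair<std::size_t, float>> prune_log_probs_sketch(
    const std::vector<double> &prob,
    double cutoff_prob,
    std::size_t cutoff_top_n) {
  typedef std::pair<std::size_t, double> IdxProb;
  std::vector<IdxProb> prob_idx;
  prob_idx.reserve(prob.size());
  for (std::size_t i = 0; i < prob.size(); ++i) {
    prob_idx.push_back(IdxProb(i, prob[i]));
  }

  std::size_t cutoff_len = prob.size();
  if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
    // Sort labels by descending probability.
    std::sort(prob_idx.begin(), prob_idx.end(),
              [](const IdxProb &a, const IdxProb &b) {
                return a.second > b.second;
              });
    if (cutoff_prob < 1.0) {
      // Keep the smallest prefix whose cumulative mass reaches cutoff_prob.
      double cum_prob = 0.0;
      cutoff_len = 0;
      for (std::size_t i = 0; i < prob_idx.size(); ++i) {
        cum_prob += prob_idx[i].second;
        ++cutoff_len;
        if (cum_prob >= cutoff_prob) break;
      }
    }
    cutoff_len = std::min(cutoff_len, cutoff_top_n);
    prob_idx.resize(cutoff_len);
  }

  // Convert to log space; the tiny offset (an assumption) avoids log(0).
  std::vector<std::pair<std::size_t, float>> log_prob_idx;
  log_prob_idx.reserve(prob_idx.size());
  for (std::size_t i = 0; i < prob_idx.size(); ++i) {
    const double p = prob_idx[i].second + std::numeric_limits<float>::min();
    log_prob_idx.push_back(std::pair<std::size_t, float>(
        prob_idx[i].first, static_cast<float>(std::log(p))));
  }
  return log_prob_idx;
}
```

As in the diff, the labels are only sorted when some pruning is requested; with `cutoff_prob >= 1.0` and `cutoff_top_n` at least the number of labels, the original label order is preserved.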
|
||
std::vector<std::vector<std::pair<double, std::string>>> | ||
ctc_beam_search_decoder_batch( | ||
const std::vector<std::vector<std::vector<double>>> &probs_split, | ||
int beam_size, | ||
const std::vector<std::string> &vocabulary, | ||
int blank_id, | ||
int num_processes, | ||
double cutoff_prob, | ||
int cutoff_top_n, | ||
Scorer *extscorer) { | ||
if (num_processes <= 0) { | ||
std::cout << "num_processes must be positive!" << std::endl; | |
exit(1); | ||
} | ||
// thread pool | ||
ThreadPool pool(num_processes); | ||
// number of samples | ||
int batch_size = probs_split.size(); | ||
|
||
// scorer filling up | ||
if (extscorer != nullptr) { | ||
if (extscorer->is_char_map_empty()) { | ||
extscorer->set_char_map(vocabulary); | ||
} | ||
if (!extscorer->is_character_based() && extscorer->dictionary == nullptr) { | ||
// init dictionary | ||
extscorer->fill_dictionary(true); | ||
} | ||
} | ||
|
||
// enqueue the tasks of decoding | ||
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; | ||
for (int i = 0; i < batch_size; i++) { | ||
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, | ||
probs_split[i], | ||
beam_size, | ||
vocabulary, | ||
blank_id, | ||
cutoff_prob, | ||
cutoff_top_n, | ||
extscorer)); | ||
} | ||
|
||
// get decoding results | ||
std::vector<std::vector<std::pair<double, std::string>>> batch_results; | ||
for (int i = 0; i < batch_size; i++) { | ||
batch_results.emplace_back(res[i].get()); | ||
} | ||
return batch_results; | ||
} |
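The batch decoder above fans one decoding task per utterance out to a thread pool and then collects the results in input order. The sketch below illustrates the same fan-out and fan-in pattern using only the standard library, with `std::async` swapped in for the third-party `ThreadPool` and an exception in place of `exit(1)`. The names `decode_batch_sketch`, `Probs`, `Hypotheses` and the `decode_one` callback are all hypothetical; `decode_one` stands in for a single-utterance decoder such as `ctc_beam_search_decoder` with its other arguments already bound.

```cpp
#include <cstddef>
#include <functional>
#include <future>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

typedef std::vector<std::vector<double>> Probs;
typedef std::vector<std::pair<double, std::string>> Hypotheses;

// Hypothetical sketch: decode a batch with a fixed number of workers.
std::vector<Hypotheses> decode_batch_sketch(
    const std::vector<Probs> &probs_split,
    std::size_t num_workers,
    const std::function<Hypotheses(const Probs &)> &decode_one) {
  if (num_workers == 0) {
    throw std::invalid_argument("num_workers must be positive");
  }
  // One pre-allocated slot per utterance keeps results in input order and
  // lets workers write without locking, since no two workers share a slot.
  std::vector<Hypotheses> batch_results(probs_split.size());

  std::vector<std::future<void>> workers;
  for (std::size_t w = 0; w < num_workers; ++w) {
    workers.emplace_back(std::async(std::launch::async, [&, w]() {
      // Worker w handles utterances w, w + num_workers, w + 2 * num_workers...
      for (std::size_t i = w; i < probs_split.size(); i += num_workers) {
        batch_results[i] = decode_one(probs_split[i]);
      }
    }));
  }
  // Wait for all workers; get() rethrows any exception raised while decoding.
  for (std::size_t w = 0; w < workers.size(); ++w) {
    workers[w].get();
  }
  return batch_results;
}
```

Note that a shared scorer passed to concurrent decoders must be safe to read from several threads; this sketch does not address that and leaves it to `decode_one`.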
Comment: This is not needed. Put the install script directly into the setup.sh in the root directory so that it is transparent to users; there is no need to explain it in the doc.
Comment: Done