From d6c30288943d9f4c8f216198f76e9686f97007ee Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 3 Jul 2017 15:04:00 +0800 Subject: [PATCH] fix a bug in scorer --- deep_speech_2/deploy.py | 3 ++- .../deploy/ctc_beam_search_decoder.cpp | 26 +++++++++---------- deep_speech_2/deploy/scorer.cpp | 9 +++---- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/deep_speech_2/deploy.py b/deep_speech_2/deploy.py index 3272371bf86..0f85fa0d9c9 100644 --- a/deep_speech_2/deploy.py +++ b/deep_speech_2/deploy.py @@ -162,9 +162,10 @@ def infer(): for i, probs in enumerate(probs_split) ] + # external scorer ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) - ## decode and print + ## decode and print wer_sum, wer_counter = 0, 0 for i, probs in enumerate(probs_split): beam_result = ctc_beam_search_decoder( diff --git a/deep_speech_2/deploy/ctc_beam_search_decoder.cpp b/deep_speech_2/deploy/ctc_beam_search_decoder.cpp index 297c7c24b21..cf58aa7fdad 100644 --- a/deep_speech_2/deploy/ctc_beam_search_decoder.cpp +++ b/deep_speech_2/deploy/ctc_beam_search_decoder.cpp @@ -15,10 +15,10 @@ bool pair_comp_second_rev(const std::pair a, const std::pair b) return a.second > b.second; } -/* CTC beam search decoder in C++, the interface is consistent with the original +/* CTC beam search decoder in C++, the interface is consistent with the original decoder in Python version. */ -std::vector > +std::vector > ctc_beam_search_decoder(std::vector > probs_seq, int beam_size, std::vector vocabulary, @@ -29,15 +29,15 @@ std::vector > ) { int num_time_steps = probs_seq.size(); - - // assign space ID + + // assign space ID std::vector::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); int space_id = it-vocabulary.begin(); if(space_id >= vocabulary.size()) { std::cout<<"The character space is not in the vocabulary!"; - exit(1); + exit(1); } - + // initialize // two sets containing selected and candidate prefixes respectively std::map prefix_set_prev, prefix_set_next; @@ -47,7 +47,7 @@ std::vector > prefix_set_prev["\t"] = 1.0; probs_b_prev["\t"] = 1.0; probs_nb_prev["\t"] = 0.0; - + for (int time_step=0; time_step > } prob_idx = std::vector >(prob_idx.begin(), prob_idx.begin()+cutoff_len); } - // extend prefix - for (std::map::iterator it = prefix_set_prev.begin(); + // extend prefix + for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { std::string l = it->first; if( prefix_set_next.find(l) == prefix_set_next.end()) { @@ -109,12 +109,12 @@ std::vector > } } - prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; + prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l]; } probs_b_prev = probs_b_cur; probs_nb_prev = probs_nb_cur; - std::vector > + std::vector > prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end()); std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev); int k = beam_size > // post processing std::vector > beam_result; - for (std::map::iterator it = prefix_set_prev.begin(); + for (std::map::iterator it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { if (it->second > 0.0 && it->first.size() > 1) { double prob = it->second; @@ -134,7 +134,7 @@ std::vector > prob = prob * ext_scorer->get_score(sentence); } double log_prob = log(it->second); - beam_result.push_back(std::pair(log_prob, it->first)); + beam_result.push_back(std::pair(log_prob, sentence)); } } // sort the result and return diff --git a/deep_speech_2/deploy/scorer.cpp b/deep_speech_2/deploy/scorer.cpp index 9cb68055679..6343c32852b 100644 --- a/deep_speech_2/deploy/scorer.cpp +++ b/deep_speech_2/deploy/scorer.cpp @@ -35,7 +35,7 @@ inline void strip(std::string &str, char ch=' ') { break; } } - + if (start == 0 && end == str.size()-1) return; if (start > end) { std::string emp_str; @@ -47,13 +47,12 @@ inline void strip(std::string &str, char ch=' ') { int Scorer::word_count(std::string sentence) { strip(sentence); - int cnt = 0; + int cnt = 1; for (int i=0; i 0) cnt ++; return cnt; } @@ -68,8 +67,8 @@ double Scorer::language_model_score(std::string sentence) { ret = model->FullScore(state, vocab, out_state); state = out_state; } - double score = ret.prob; - + double score = ret.prob; + return pow(10, score); }