Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimized decoder for the deployment of DS2 #139

Merged
merged 48 commits into from
Sep 18, 2017
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
724b0fb
add initial files for deployment
Jun 29, 2017
348d6bb
add deploy.py
Jun 29, 2017
59b4b87
Merge branch 'develop' into ctc_decoder_deploy
Jun 29, 2017
a506198
fix bugs
Jul 4, 2017
903c300
Merge branch 'develop' into ctc_decoder_deploy
Jul 5, 2017
1ca3814
code cleanup for the deployment decoder
Jul 6, 2017
34f98e0
add setup and README for deployment
Jul 6, 2017
ae05535
enable loading language model in multiple format
Jul 10, 2017
4e5b345
change probs' computation into log scale & add best path decoder
Jul 27, 2017
908932f
refine the interface of decoders in swig
Aug 3, 2017
ac3a49c
Delete swig_decoder.py
kuke Aug 3, 2017
9ff48b0
reorganize cpp files
Aug 22, 2017
32047c7
refine wrapper for swig and simplify setup
Aug 22, 2017
f41375b
add the support of parallel beam search decoding in deployment
Aug 23, 2017
a96c650
Refactor scorer and move utility functions to decoder_util.h
pkuyym Aug 23, 2017
3441148
Merge branch 'ctc_decoder_deploy' of /~https://github.com/kuke/models i…
pkuyym Aug 23, 2017
89c4a96
Make setup.py to support parallel processing.
pkuyym Aug 23, 2017
bbbc988
adapt to the last three commits
Aug 23, 2017
d68732b
convert data structure for prefix from map to trie tree
Aug 24, 2017
955d293
enable finite-state transducer in beam search decoding
Aug 29, 2017
20d13a4
streamline source code
Aug 29, 2017
09f4c6e
remove unused functions in Scorer
Aug 29, 2017
b5c4d83
add min cutoff & top n cutoff
Aug 30, 2017
202a06a
Merge branch 'develop' of /~https://github.com/PaddlePaddle/models into…
Aug 30, 2017
beb0c07
clean up code & update README for decoder in deployment
Aug 30, 2017
103a6ac
format C++ source code
Sep 6, 2017
efc5d9b
Merge branch 'develop' of /~https://github.com/PaddlePaddle/models into…
Sep 7, 2017
5a318e9
adapt to the new folder structure of DS2
Sep 8, 2017
f8c7d46
format header includes & update setup info
Sep 8, 2017
e49f505
resolve conflicts in model.py
Sep 8, 2017
d75f27d
Merge branch 'develop' of /~https://github.com/PaddlePaddle/models into…
Sep 8, 2017
41e9e59
append some comments
Sep 8, 2017
c4bc822
adapt to the new structure
Sep 13, 2017
52a862d
add __init__.py in decoders/swig
Sep 13, 2017
552dd52
move deprecated decoders
Sep 14, 2017
0bda37c
refine by following review comments
Sep 15, 2017
902c35b
append some changes
Sep 15, 2017
bb35363
Merge branch 'develop' of upstream into ctc_decoder_deploy
Sep 16, 2017
15728d0
expose param cutoff_top_n
Sep 16, 2017
e6740af
adjust scorer's init & add logging for scorer & separate long functions
Sep 17, 2017
8c5576d
format varabiables' name & add more comments
Sep 17, 2017
bcc236e
Merge branch 'develop' of upstream into ctc_decoder_deploy
Sep 18, 2017
98d35b9
adjust to pass ci
Sep 18, 2017
d7a9752
specify clang_format to ver3.9
Sep 18, 2017
cc2f91f
disable the make output of libsndfile in setup
Sep 18, 2017
f1cd672
use cd instead of pushd in setup.sh
Sep 18, 2017
cfecaa8
Merge branch 'develop' of upstream into ctc_decoder_deploy
Sep 18, 2017
9db0d25
pass unittest for deprecated decoders
Sep 18, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions deep_speech_2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

## Installation

### Basic setup

Please make sure the above [prerequisites](#prerequisites) have been satisfied before moving on.

```bash
Expand All @@ -32,6 +34,16 @@ cd models/deep_speech_2
sh setup.sh
```

### Decoders setup
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要这个,直接把安装脚本放到根目录下的setup.sh里面,对用户透明。不需要在doc中解释。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


```bash
cd decoders/swig
sh setup.sh
cd ../..
```

These commands will install the decoders that translate the ouptut probability vectors of DS2 model to text data, incuding CTC greedy decoder, CTC beam search decoder and its batch version. And a detailed usuage about them will be given in the following sections.

## Getting Started

Several shell scripts provided in `./examples` will help us to quickly give it a try, for most major modules, including data preparation, model training, case inference and model evaluation, with a few public dataset (e.g. [LibriSpeech](http://www.openslr.org/12/), [Aishell](http://www.openslr.org/33)). Reading these examples will also help you to understand how to make it work with your own data.
Expand Down Expand Up @@ -176,6 +188,8 @@ Data augmentation has often been a highly effective technique to boost the deep

Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.

### Inference
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why to add L191-192?

Copy link
Collaborator Author

@kuke kuke Sep 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a mistake. Removed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove L179

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


- Volume Perturbation
- Speed Perturbation
- Shifting Perturbation
Expand Down
Empty file.
Empty file.
332 changes: 332 additions & 0 deletions deep_speech_2/decoders/swig/ctc_decoders.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
#include "ctc_decoders.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please follow Google Coding Style for import order.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the including order required by Google Coding Style: https://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes


#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <utility>

#include "ThreadPool.h"
#include "fst/fstlib.h"

#include "decoder_utils.h"
#include "path_trie.h"

std::string ctc_greedy_decoder(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put ctc_greedy_decoder into another file ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary) {
// dimension check
int num_time_steps = probs_seq.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int --> size_t

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (int i = 0; i < num_time_steps; i++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int --> size_t
i++ --> ++i (modify all other places)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if (probs_seq[i].size() != vocabulary.size() + 1) {
std::cout << "The shape of probs_seq does not match"
<< " with the shape of the vocabulary!" << std::endl;
exit(1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exit(1) --> throw exception

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
}

int blank_id = vocabulary.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int --> size_t
对于size,index等不可能为负的变量,用size_t 更好一点。
同理,请修改其他地方。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


std::vector<int> max_idx_vec;
double max_prob = 0.0;
int max_idx = 0;
for (int i = 0; i < num_time_steps; i++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为了简短,可以把L32-33放到L35前面,然后去除L42-43.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (int j = 0; j < probs_seq[i].size(); j++) {
if (max_prob < probs_seq[i][j]) {
max_idx = j;
max_prob = probs_seq[i][j];
}
}
max_idx_vec.push_back(max_idx);
max_prob = 0.0;
max_idx = 0;
}

std::vector<int> idx_vec;
for (int i = 0; i < max_idx_vec.size(); i++) {
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]);
}
}

std::string best_path_result;
for (int i = 0; i < idx_vec.size(); i++) {
if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]];
}
}
return best_path_result;
}

std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is too long. Could you try to separate it into multiple functions to make it more readable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const std::vector<std::vector<double>> &probs_seq,
int beam_size,
std::vector<std::string> vocabulary,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const &

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did not use const & because there would be an error when using std::find for this variable. And hasn't fixed this problem yet.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Add const reference to vocabulary.
  2. Place const referenced arguments together.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's no modification for vocabulary, please use const & instead.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int blank_id,
double cutoff_prob,
int cutoff_top_n,
Scorer *extscorer) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. const &
  2. Put all const argument before others

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// dimension check
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrap a function to do dimension check ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

size_t num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size() + 1) {
std::cout << " The shape of probs_seq does not match"
<< " with the shape of the vocabulary!" << std::endl;
exit(1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use exit(1)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
}

// blank_id check
if (blank_id > vocabulary.size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在greedy decode中 blank_id 自动获取的,和这里用输入参数不同。两个decoder的这一逻辑能否保持一致?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the argument blank_id

std::cout << " Invalid blank_id! " << std::endl;
exit(1);
}

// assign space ID
std::vector<std::string>::iterator it =
std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if (space_id >= vocabulary.size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compare signed value with unsigned value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

space_id = -2;
}

// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);

if (extscorer != nullptr) {
if (extscorer->is_char_map_empty()) {
extscorer->set_char_map(vocabulary);
}
if (!extscorer->is_character_based()) {
if (extscorer->dictionary == nullptr) {
// fill dictionary for fst with space
extscorer->fill_dictionary(true);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L94 to L101 could be all put into a setup method inside Scorer itself and make sure it is called outside decoder. That means a scorer must be well prepared before being passed into the decoder function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why Copy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is required by FST

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有点奇怪,check一下接口? Copy(true) 干什么用的?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处应该不是一个简单的指针赋值,而是根据状态不同返回不同的match type:
http://www.openfst.org/doxygen/fst/html/matcher_8h_source.html#l00041

root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
}

// prefix search over time
for (int time_step = 0; time_step < num_time_steps; time_step++) {
std::vector<double> prob = probs_seq[time_step];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use const std::vector<double>& prob instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

std::vector<std::pair<int, double>> prob_idx;
for (int i = 0; i < prob.size(); i++) {
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
}

float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (extscorer != nullptr) {
int num_prefixes = std::min((int)prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) -
std::max(0.0, extscorer->beta);
full_beam = (num_prefixes == beam_size);
}

// pruning of vacobulary
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

例如这一个代码段,如果只做一件事,是否可以抽出来形成一个函数。
函数要尽可能形成hierachy调用并且尽可能一个函数只做一件事。 一方面减少了单个函数的作用,另一方面以函数名来取代代码注释, 有利于提升可读性。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int cutoff_len = prob.size();
if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we can use while loop to remove if

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel unnecessary

double cum_prob = 0.0;
cutoff_len = 0;
for (int i = 0; i < prob_idx.size(); i++) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob) break;
}
}
cutoff_len = std::min(cutoff_len, cutoff_top_n);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move line 143 into for loop as a stop condition

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part has been moved to a function

prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<int, float>> log_prob_idx;
for (int i = 0; i < cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log( or std::log(?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--> std::log

}

// loop over chars
for (int index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
float log_prob_c = log_prob_idx[index].second;

for (int i = 0; i < prefixes.size() && i < beam_size; i++) {
auto prefix = prefixes[i];

if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);

if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;

if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}

// language model scoring
if (extscorer != nullptr &&
(c == space_id || extscorer->is_character_based())) {
PathTrie *prefix_toscore = nullptr;

// skip scoring the space
if (extscorer->is_character_based()) {
prefix_toscore = prefix_new;
} else {
prefix_toscore = prefix;
}

double score = 0.0;
std::vector<std::string> ngram;
ngram = extscorer->make_ngram(prefix_toscore);
score = extscorer->get_log_cond_prob(ngram) * extscorer->alpha;

log_p += score;
log_p += extscorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over chars

prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);

// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);

for (size_t i = beam_size; i < prefixes.size(); i++) {
prefixes[i]->remove();
}
}
} // end of loop over time

// compute aproximate ctc score as the return score
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put into a separated function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is temporarily kept to compare with speech-dl decoder, will be removed some time. And have already put the following part into a function

for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
double approx_ctc = prefixes[i]->score;

if (extscorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
size_t prefix_length = output.size();
auto words = extscorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * extscorer->beta;
// remove language model weight:
approx_ctc -= (extscorer->get_sent_log_prob(words)) * extscorer->alpha;
}

prefixes[i]->approx_ctc = approx_ctc;
}

// allow for the post processing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put into a separated function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
space_prefixes.push_back(prefixes[i]);
}
}

std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
std::vector<int> output;
space_prefixes[i]->get_path_vec(output);
// convert index to string
std::string output_str;
for (int j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]];
}
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_str);
output_vecs.emplace_back(output_pair);
}

return output_vecs;
}

std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
int beam_size,
const std::vector<std::string> &vocabulary,
int blank_id,
int num_processes,
double cutoff_prob,
int cutoff_top_n,
Scorer *extscorer) {
if (num_processes <= 0) {
std::cout << "num_processes must be nonnegative!" << std::endl;
exit(1);
}
// thread pool
ThreadPool pool(num_processes);
// number of samples
int batch_size = probs_split.size();

// scorer filling up
if (extscorer != nullptr) {
if (extscorer->is_char_map_empty()) {
extscorer->set_char_map(vocabulary);
}
if (!extscorer->is_character_based() && extscorer->dictionary == nullptr) {
// init dictionary
extscorer->fill_dictionary(true);
}
}

// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (int i = 0; i < batch_size; i++) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i],
beam_size,
vocabulary,
blank_id,
cutoff_prob,
cutoff_top_n,
extscorer));
}

// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (int i = 0; i < batch_size; i++) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
Loading