diff --git a/openspeech/models/contextnet/model.py b/openspeech/models/contextnet/model.py index 2a73d9f..6de8f3d 100644 --- a/openspeech/models/contextnet/model.py +++ b/openspeech/models/contextnet/model.py @@ -56,7 +56,13 @@ class ContextNetModel(OpenspeechCTCModel): """ def __init__(self, configs: DictConfig, tokenizer: Tokenizer) -> None: super(ContextNetModel, self).__init__(configs, tokenizer) - self.fc = Linear(self.configs.model.encoder_dim, self.num_classes, bias=False) + supported_models = { + 'small': 0.5, + 'medium': 1, + 'large': 2, + } + alpha = supported_models[self.configs.model.model_size] + self.fc = Linear(int(self.configs.model.encoder_dim * alpha), self.num_classes, bias=False) def build_model(self): self.encoder = ContextNetEncoder(