diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index a516a53c5d..638e19498b 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -10,7 +10,17 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--min-torch-version",
-        help="Minimu torch version",
+        help="Minimum torch version",
+    )
+
+    parser.add_argument(
+        "--torch-version",
+        help="If given, generate the matrix for only this torch version",
+    )
+
+    parser.add_argument(
+        "--python-version",
+        help="If given, generate the matrix for only this Python version",
     )
 
     return parser.parse_args()
@@ -52,7 +62,7 @@ def get_torchaudio_version(torch_version):
         return torch_version
 
 
-def get_matrix(min_torch_version):
+def get_matrix(min_torch_version, specified_torch_version, specified_python_version):
     k2_version = "1.24.4.dev20241029"
     kaldifeat_version = "1.25.5.dev20241029"
     version = "20241218"
@@ -71,6 +81,12 @@ def get_matrix(min_torch_version):
         torch_version += ["2.5.0"]
         torch_version += ["2.5.1"]
 
+    if specified_torch_version:
+        torch_version = [specified_torch_version]
+
+    if specified_python_version:
+        python_version = [specified_python_version]
+
     matrix = []
     for p in python_version:
         for t in torch_version:
@@ -115,7 +131,11 @@ def get_matrix(min_torch_version):
 def main():
     args = get_args()
-    matrix = get_matrix(min_torch_version=args.min_torch_version)
+    matrix = get_matrix(
+        min_torch_version=args.min_torch_version,
+        specified_torch_version=args.torch_version,
+        specified_python_version=args.python_version,
+    )
 
     print(json.dumps({"include": matrix}))
diff --git a/.github/scripts/librispeech/ASR/run_rknn.sh b/.github/scripts/librispeech/ASR/run_rknn.sh
new file mode 100755
index 0000000000..3044717240
--- /dev/null
+++ b/.github/scripts/librispeech/ASR/run_rknn.sh
@@ -0,0 +1,200 @@
+#!/usr/bin/env bash
+
+set -ex
+
+python3 -m pip install kaldi-native-fbank soundfile librosa
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+
+# https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed
+# sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
+function export_bilingual_zh_en() {
+  d=exp_zh_en
+
+  mkdir $d
+  pushd $d
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/exp/pretrained.pt
+  mv pretrained.pt epoch-99.pt
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/tokens.txt
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/0.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/1.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/2.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/3.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/4.wav
+  ls -lh
+  popd
+
+  ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
+    --dynamic-batch 0 \
+    --enable-int8-quantization 0 \
+    --tokens $d/tokens.txt \
+    --use-averaged-model 0 \
+    --epoch 99 \
+    --avg 1 \
+    --exp-dir $d/ \
+    --decode-chunk-len 64 \
+    --num-encoder-layers "2,4,3,2,4" \
+    --feedforward-dims "1024,1024,1536,1536,1024" \
+    --nhead "8,8,8,8,8" \
+    --encoder-dims "384,384,384,384,384" \
+    --attention-dims "192,192,192,192,192" \
+    --encoder-unmasked-dims "256,256,256,256,256" \
+    --zipformer-downsampling-factors "1,2,4,8,2" \
+    --cnn-module-kernels "31,31,31,31,31" \
+    --decoder-dim 512 \
+    --joiner-dim 512
+
+  ls -lh $d/
+
+  ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+    --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
+    --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
+    --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
+    --tokens $d/tokens.txt \
+    $d/0.wav
+
+  ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+    --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
+    --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
+    --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
+    --tokens $d/tokens.txt \
+    $d/1.wav
+
+  mkdir -p /icefall/rknn-models
+
+  for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
+    mkdir -p $platform
+
+    ./pruned_transducer_stateless7_streaming/export_rknn.py \
+      --in-encoder $d/encoder-epoch-99-avg-1.onnx \
+      --in-decoder $d/decoder-epoch-99-avg-1.onnx \
+      --in-joiner $d/joiner-epoch-99-avg-1.onnx \
+      --out-encoder $platform/encoder.rknn \
+      --out-decoder $platform/decoder.rknn \
+      --out-joiner $platform/joiner.rknn \
+      --target-platform $platform 2>/dev/null
+
+    ls -lh $platform/
+
+    ./pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py \
+      --encoder $d/encoder-epoch-99-avg-1.onnx \
+      --decoder $d/decoder-epoch-99-avg-1.onnx \
+      --joiner $d/joiner-epoch-99-avg-1.onnx \
+      --tokens $d/tokens.txt \
+      --wav $d/0.wav
+
+    cp $d/tokens.txt $platform
+    cp $d/*.wav $platform
+
+    cp -av $platform /icefall/rknn-models
+  done
+
+  ls -lh /icefall/rknn-models
+}
+
+# https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t
+# sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16
+function export_bilingual_zh_en_small() {
+  d=exp_zh_en_small
+
+  mkdir $d
+  pushd $d
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/exp/pretrained.pt
+  mv pretrained.pt epoch-99.pt
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/data/lang_char_bpe/tokens.txt
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/0.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/1.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/2.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/3.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/4.wav
+
+  ls -lh
+
+  popd
+
+
+  ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
+    --dynamic-batch 0 \
+    --enable-int8-quantization 0 \
+    --tokens $d/tokens.txt \
+    --use-averaged-model 0 \
+    --epoch 99 \
+    --avg 1 \
+    --exp-dir $d/ \
+    --decode-chunk-len 64 \
+    \
+    --num-encoder-layers 2,2,2,2,2 \
+    --feedforward-dims 768,768,768,768,768 \
+    --nhead 4,4,4,4,4 \
+    --encoder-dims 256,256,256,256,256 \
+    --attention-dims 192,192,192,192,192 \
+    --encoder-unmasked-dims 192,192,192,192,192 \
+    \
+    --zipformer-downsampling-factors "1,2,4,8,2" \
+    --cnn-module-kernels "31,31,31,31,31" \
+    --decoder-dim 512 \
+    --joiner-dim 512
+
+  ls -lh $d/
+
+  ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+    --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
+    --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
+    --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
+    --tokens $d/tokens.txt \
+    $d/0.wav
+
+  ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+    --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
+    --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
+    --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
+    --tokens $d/tokens.txt \
+    $d/1.wav
+
+  mkdir -p /icefall/rknn-models-small
+
+  for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
+    mkdir -p $platform
+
+    ./pruned_transducer_stateless7_streaming/export_rknn.py \
+      --in-encoder $d/encoder-epoch-99-avg-1.onnx \
+      --in-decoder $d/decoder-epoch-99-avg-1.onnx \
+      --in-joiner $d/joiner-epoch-99-avg-1.onnx \
+      --out-encoder $platform/encoder.rknn \
+      --out-decoder $platform/decoder.rknn \
+      --out-joiner $platform/joiner.rknn \
+      --target-platform $platform 2>/dev/null
+
+    ls -lh $platform/
+
+    ./pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py \
+      --encoder $d/encoder-epoch-99-avg-1.onnx \
+      --decoder $d/decoder-epoch-99-avg-1.onnx \
+      --joiner $d/joiner-epoch-99-avg-1.onnx \
+      --tokens $d/tokens.txt \
+      --wav $d/0.wav
+
+    cp $d/tokens.txt $platform
+    cp $d/*.wav $platform
+
+    cp -av $platform /icefall/rknn-models-small
+  done
+
+  ls -lh /icefall/rknn-models-small
+}
+
+export_bilingual_zh_en_small
+
+export_bilingual_zh_en
diff --git a/.github/workflows/rknn.yml b/.github/workflows/rknn.yml
new file mode 100644
index 0000000000..51aa4eb9b5
--- /dev/null
+++ b/.github/workflows/rknn.yml
@@ -0,0 +1,180 @@
+name: rknn
+
+on:
+  push:
+    branches:
+      - master
+      - ci-rknn-2
+
+  pull_request:
+    branches:
+      - master
+
+  workflow_dispatch:
+
+concurrency:
+  group: rknn-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  generate_build_matrix:
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+    # see /~https://github.com/pytorch/pytorch/pull/50633
+    runs-on: ubuntu-latest
+    outputs:
+      matrix: ${{ steps.set-matrix.outputs.matrix }}
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Generating build matrix
+        id: set-matrix
+        run: |
+          # outputting for debugging purposes
+          python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10)
+          echo "matrix=${MATRIX}" >> "$GITHUB_OUTPUT"
+
+  rknn:
+    needs: generate_build_matrix
+    name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python
+        if: false
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Export ONNX model
+        uses: addnab/docker-run-action@v3
+        with:
+          image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+          options: |
+            --volume ${{ github.workspace }}/:/icefall
+          shell: bash
+          run: |
+            cat /etc/*release
+            lsb_release -a
+            uname -a
+            python3 --version
+            export PYTHONPATH=/icefall:$PYTHONPATH
+            cd /icefall
+            git config --global --add safe.directory /icefall
+
+            python3 -m torch.utils.collect_env
+            python3 -m k2.version
+            pip list
+
+            # Install rknn
+            curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+            pip install ./*.whl "numpy<=1.26.4"
+            pip list | grep rknn
+            echo "---"
+            pip list
+            echo "---"
+
+            .github/scripts/librispeech/ASR/run_rknn.sh
+
+      - name: Display rknn models
+        shell: bash
+        run: |
+          ls -lh
+
+          ls -lh rknn-models/*
+          echo "----"
+          ls -lh rknn-models-small/*
+
+      - name: Collect results (small)
+        shell: bash
+        run: |
+          for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
+            dst=sherpa-onnx-$platform-streaming-zipformer-small-bilingual-zh-en-2023-02-16
+            mkdir $dst
+            mkdir $dst/test_wavs
+            src=rknn-models-small/$platform
+
+            cp -v $src/*.rknn $dst/
+            cp -v $src/tokens.txt $dst/
+            cp -v $src/*.wav $dst/test_wavs/
+            ls -lh $dst
+            tar cjfv $dst.tar.bz2 $dst
+            rm -rf $dst
+          done
+
+      - name: Collect results
+        shell: bash
+        run: |
+          for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
+            dst=sherpa-onnx-$platform-streaming-zipformer-bilingual-zh-en-2023-02-20
+            mkdir $dst
+            mkdir $dst/test_wavs
+            src=rknn-models/$platform
+
+            cp -v $src/*.rknn $dst/
+            cp -v $src/tokens.txt $dst/
+            cp -v $src/*.wav $dst/test_wavs/
+            ls -lh $dst
+            tar cjfv $dst.tar.bz2 $dst
+            rm -rf $dst
+          done
+
+      - name: Display results
+        shell: bash
+        run: |
+          ls -lh *rk*.tar.bz2
+
+      - name: Release to GitHub
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          overwrite: true
+          file: sherpa-onnx-*.tar.bz2
+          repo_name: k2-fsa/sherpa-onnx
+          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+          tag: asr-models
+
+      - name: Upload model to huggingface
+        if: github.event_name == 'push'
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 20
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            rm -rf huggingface
+            export GIT_LFS_SKIP_SMUDGE=1
+
+            git clone https://huggingface.co/csukuangfj/sherpa-onnx-rknn-models huggingface
+            cd huggingface
+
+            git fetch
+            git pull
+            git merge -m "merge remote" --ff origin main
+
+            dst=streaming-asr
+            mkdir -p $dst
+            rm -fv $dst/*
+            cp ../*rk*.tar.bz2 $dst/
+
+            ls -lh $dst
+            git add .
+            git status
+            git commit -m "update models"
+            git status
+
+            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-rknn-models main || true
+            rm -rf huggingface
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
index 2de56837e6..a4fbd93baf 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
@@ -85,6 +85,20 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--dynamic-batch",
+        type=int,
+        default=1,
+        help="1 to support dynamic batch size. 0 to support only batch size == 1",
+    )
+
+    parser.add_argument(
+        "--enable-int8-quantization",
+        type=int,
+        default=1,
+        help="1 to also export int8 onnx models. 0 to skip int8 export",
+    )
+
     parser.add_argument(
         "--epoch",
         type=int,
@@ -257,6 +271,7 @@ def export_encoder_model_onnx(
     encoder_model: OnnxEncoder,
     encoder_filename: str,
     opset_version: int = 11,
+    dynamic_batch: bool = True,
 ) -> None:
     """
     Onnx model inputs:
@@ -274,6 +289,8 @@ def export_encoder_model_onnx(
         The filename to save the exported ONNX model.
       opset_version:
         The opset version to use.
+      dynamic_batch:
+        True to export a model supporting dynamic batch size
     """
 
     encoder_model.encoder.__class__.forward = (
@@ -379,7 +396,9 @@ def build_inputs_outputs(tensors, name, N):
             "encoder_out": {0: "N"},
             **inputs,
             **outputs,
-        },
+        }
+        if dynamic_batch
+        else {},
     )
 
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
@@ -389,6 +408,7 @@ def export_decoder_model_onnx(
     decoder_model: nn.Module,
     decoder_filename: str,
     opset_version: int = 11,
+    dynamic_batch: bool = True,
 ) -> None:
     """Export the decoder model to ONNX format.
 
@@ -412,7 +432,7 @@ def export_decoder_model_onnx(
     """
     context_size = decoder_model.decoder.context_size
     vocab_size = decoder_model.decoder.vocab_size
-    y = torch.zeros(10, context_size, dtype=torch.int64)
+    y = torch.zeros(1, context_size, dtype=torch.int64)
     decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
@@ -425,7 +445,9 @@ def export_decoder_model_onnx(
         dynamic_axes={
             "y": {0: "N"},
             "decoder_out": {0: "N"},
-        },
+        }
+        if dynamic_batch
+        else {},
     )
     meta_data = {
         "context_size": str(context_size),
@@ -438,6 +460,7 @@ def export_joiner_model_onnx(
     joiner_model: nn.Module,
     joiner_filename: str,
     opset_version: int = 11,
+    dynamic_batch: bool = True,
 ) -> None:
     """Export the joiner model to ONNX format.
     The exported joiner model has two inputs:
@@ -452,8 +475,8 @@ def export_joiner_model_onnx(
     joiner_dim = joiner_model.output_linear.weight.shape[1]
     logging.info(f"joiner dim: {joiner_dim}")
 
-    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
-    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
 
     torch.onnx.export(
         joiner_model,
@@ -470,7 +493,9 @@ def export_joiner_model_onnx(
             "encoder_out": {0: "N"},
             "decoder_out": {0: "N"},
             "logit": {0: "N"},
-        },
+        }
+        if dynamic_batch
+        else {},
     )
     meta_data = {
         "joiner_dim": str(joiner_dim),
@@ -629,6 +654,7 @@ def main():
         encoder,
         encoder_filename,
         opset_version=opset_version,
+        dynamic_batch=params.dynamic_batch == 1,
     )
     logging.info(f"Exported encoder to {encoder_filename}")
 
@@ -638,6 +664,7 @@ def main():
         decoder,
         decoder_filename,
         opset_version=opset_version,
+        dynamic_batch=params.dynamic_batch == 1,
    )
     logging.info(f"Exported decoder to {decoder_filename}")
 
@@ -647,37 +674,39 @@ def main():
         joiner,
         joiner_filename,
         opset_version=opset_version,
+        dynamic_batch=params.dynamic_batch == 1,
     )
     logging.info(f"Exported joiner to {joiner_filename}")
 
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-    logging.info("Generate int8 quantization models")
+    if params.enable_int8_quantization:
+        logging.info("Generate int8 quantization models")
 
-    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=encoder_filename,
-        model_output=encoder_filename_int8,
-        op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
-    )
+        encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
+        quantize_dynamic(
+            model_input=encoder_filename,
+            model_output=encoder_filename_int8,
+            op_types_to_quantize=["MatMul"],
+            weight_type=QuantType.QInt8,
+        )
 
-    decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=decoder_filename,
-        model_output=decoder_filename_int8,
-        op_types_to_quantize=["MatMul", "Gather"],
-        weight_type=QuantType.QInt8,
-    )
+        decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
+        quantize_dynamic(
+            model_input=decoder_filename,
+            model_output=decoder_filename_int8,
+            op_types_to_quantize=["MatMul", "Gather"],
+            weight_type=QuantType.QInt8,
+        )
 
-    joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=joiner_filename,
-        model_output=joiner_filename_int8,
-        op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
-    )
+        joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
+        quantize_dynamic(
+            model_input=joiner_filename,
+            model_output=joiner_filename_int8,
+            op_types_to_quantize=["MatMul"],
+            weight_type=QuantType.QInt8,
+        )
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export_rknn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export_rknn.py
new file mode 100755
index 0000000000..cb872cca01
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export_rknn.py
@@ -0,0 +1,261 @@
+#!/usr/bin/env python3
+# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List
+
+from rknn.api import RKNN
+
+logging.basicConfig(level=logging.WARNING)
+
+g_platforms = [
+    # "rv1103",
+    # "rv1103b",
+    # "rv1106",
+    # "rk2118",
+    "rk3562",
+    "rk3566",
+    "rk3568",
+    "rk3576",
+    "rk3588",
+]
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--target-platform",
+        type=str,
+        required=True,
+        help=f"Supported values are: {','.join(g_platforms)}",
+    )
+
+    parser.add_argument(
+        "--in-encoder",
+        type=str,
+        required=True,
+        help="Path to the encoder onnx model",
+    )
+
+    parser.add_argument(
+        "--in-decoder",
+        type=str,
+        required=True,
+        help="Path to the decoder onnx model",
+    )
+
+    parser.add_argument(
+        "--in-joiner",
+        type=str,
+        required=True,
+        help="Path to the joiner onnx model",
+    )
+
+    parser.add_argument(
+        "--out-encoder",
+        type=str,
+        required=True,
+        help="Path to the encoder rknn model",
+    )
+
+    parser.add_argument(
+        "--out-decoder",
+        type=str,
+        required=True,
+        help="Path to the decoder rknn model",
+    )
+
+    parser.add_argument(
+        "--out-joiner",
+        type=str,
+        required=True,
+        help="Path to the joiner rknn model",
+    )
+
+    return parser
+
+
+def export_rknn(rknn, filename):
+    ret = rknn.export_rknn(filename)
+    if ret != 0:
+        exit(f"Export rknn model to {filename} failed!")
+
+
+def init_model(filename: str, target_platform: str, custom_string=None):
+    rknn = RKNN(verbose=False)
+
+    rknn.config(target_platform=target_platform, custom_string=custom_string)
+    if not Path(filename).is_file():
+        exit(f"{filename} does not exist")
+
+    ret = rknn.load_onnx(model=filename)
+    if ret != 0:
+        exit(f"Load model {filename} failed!")
+
+    # Build the RKNN graph without quantization, i.e., keep float weights
+    ret = rknn.build(do_quantization=False)
+    if ret != 0:
+        exit(f"Build model {filename} failed!")
+
+    return rknn
+
+
+class MetaData:
+    def __init__(
+        self,
+        model_type: str,
+        attention_dims: List[int],
+        encoder_dims: List[int],
+        T: int,
+        left_context_len: List[int],
+        decode_chunk_len: int,
+        cnn_module_kernels: List[int],
+        num_encoder_layers: List[int],
+        context_size: int,
+    ):
+        self.model_type = model_type
+        self.attention_dims = attention_dims
+        self.encoder_dims = encoder_dims
+        self.T = T
+        self.left_context_len = left_context_len
+        self.decode_chunk_len = decode_chunk_len
+        self.cnn_module_kernels = cnn_module_kernels
+        self.num_encoder_layers = num_encoder_layers
+        self.context_size = context_size
+
+    def __str__(self) -> str:
+        return self.to_str()
+
+    def to_str(self) -> str:
+        def to_s(ll):
+            return ",".join(list(map(str, ll)))
+
+        s = f"model_type={self.model_type}"
+        s += ";attention_dims=" + to_s(self.attention_dims)
+        s += ";encoder_dims=" + to_s(self.encoder_dims)
+        s += ";T=" + str(self.T)
+        s += ";left_context_len=" + to_s(self.left_context_len)
+        s += ";decode_chunk_len=" + str(self.decode_chunk_len)
+        s += ";cnn_module_kernels=" + to_s(self.cnn_module_kernels)
+        s += ";num_encoder_layers=" + to_s(self.num_encoder_layers)
+        s += ";context_size=" + str(self.context_size)
+
+        assert len(s) < 1024, (s, len(s))
+
+        return s
+
+
+def get_meta_data(encoder: str, decoder: str):
+    import onnxruntime
+
+    session_opts = onnxruntime.SessionOptions()
+    session_opts.inter_op_num_threads = 1
+    session_opts.intra_op_num_threads = 1
+
+    m_encoder = onnxruntime.InferenceSession(
+        encoder,
+        sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
+    )
+
+    m_decoder = onnxruntime.InferenceSession(
+        decoder,
+        sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
+    )
+
+    encoder_meta = m_encoder.get_modelmeta().custom_metadata_map
+    print(encoder_meta)
+
+    # {'attention_dims': '192,192,192,192,192', 'version': '1',
+    #  'model_type': 'zipformer', 'encoder_dims': '256,256,256,256,256',
+    #  'model_author': 'k2-fsa', 'T': '103',
+    #  'left_context_len': '192,96,48,24,96',
+    #  'decode_chunk_len': '96',
+    #  'cnn_module_kernels': '31,31,31,31,31',
+    #  'num_encoder_layers': '2,2,2,2,2'}
+
+    def to_int_list(s):
+        return list(map(int, s.split(",")))
+
+    decoder_meta = m_decoder.get_modelmeta().custom_metadata_map
+    print(decoder_meta)
+
+    model_type = encoder_meta["model_type"]
+    attention_dims = to_int_list(encoder_meta["attention_dims"])
+    encoder_dims = to_int_list(encoder_meta["encoder_dims"])
+    T = int(encoder_meta["T"])
+    left_context_len = to_int_list(encoder_meta["left_context_len"])
+    decode_chunk_len = int(encoder_meta["decode_chunk_len"])
+    cnn_module_kernels = to_int_list(encoder_meta["cnn_module_kernels"])
+    num_encoder_layers = to_int_list(encoder_meta["num_encoder_layers"])
+    context_size = int(decoder_meta["context_size"])
+
+    return MetaData(
+        model_type=model_type,
+        attention_dims=attention_dims,
+        encoder_dims=encoder_dims,
+        T=T,
+        left_context_len=left_context_len,
+        decode_chunk_len=decode_chunk_len,
+        cnn_module_kernels=cnn_module_kernels,
+        num_encoder_layers=num_encoder_layers,
+        context_size=context_size,
+    )
+
+
+class RKNNModel:
+    def __init__(
+        self,
+        encoder: str,
+        decoder: str,
+        joiner: str,
+        target_platform: str,
+    ):
+        self.meta = get_meta_data(encoder, decoder)
+        self.encoder = init_model(
+            encoder,
+            custom_string=self.meta.to_str(),
+            target_platform=target_platform,
+        )
+        self.decoder = init_model(decoder, target_platform=target_platform)
+        self.joiner = init_model(joiner, target_platform=target_platform)
+
+    def export_rknn(self, encoder, decoder, joiner):
+        export_rknn(self.encoder, encoder)
+        export_rknn(self.decoder, decoder)
+        export_rknn(self.joiner, joiner)
+
+    def release(self):
+        self.encoder.release()
+        self.decoder.release()
+        self.joiner.release()
+
+
+def main():
+    args = get_parser().parse_args()
+    print(vars(args))
+
+    model = RKNNModel(
+        encoder=args.in_encoder,
+        decoder=args.in_decoder,
+        joiner=args.in_joiner,
+        target_platform=args.target_platform,
+    )
+    print(model.meta)
+
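+    # Serialize the three converted graphs to the requested .rknn output paths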
+    model.export_rknn(
+        encoder=args.out_encoder,
+        decoder=args.out_decoder,
+        joiner=args.out_joiner,
+    )
+
+    model.release()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
index 298d1889b0..e5e5136714 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
@@ -132,10 +132,18 @@ def init_encoder(self, encoder_model_filename: str):
             sess_options=self.session_opts,
             providers=["CPUExecutionProvider"],
         )
+        print("==========Encoder input==========")
+        for i in self.encoder.get_inputs():
+            print(i)
+        print("==========Encoder output==========")
+        for i in self.encoder.get_outputs():
+            print(i)
+
         self.init_encoder_states()
 
     def init_encoder_states(self, batch_size: int = 1):
         encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
+        print(encoder_meta)
 
         model_type = encoder_meta["model_type"]
         assert model_type == "zipformer", model_type
@@ -232,6 +240,12 @@ def init_decoder(self, decoder_model_filename: str):
             sess_options=self.session_opts,
             providers=["CPUExecutionProvider"],
         )
+        print("==========Decoder input==========")
+        for i in self.decoder.get_inputs():
+            print(i)
+        print("==========Decoder output==========")
+        for i in self.decoder.get_outputs():
+            print(i)
 
         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
         self.context_size = int(decoder_meta["context_size"])
@@ -247,6 +261,13 @@ def init_joiner(self, joiner_model_filename: str):
             providers=["CPUExecutionProvider"],
         )
 
+        print("==========Joiner input==========")
+        for i in self.joiner.get_inputs():
+            print(i)
+        print("==========Joiner output==========")
+        for i in self.joiner.get_outputs():
+            print(i)
+
         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
         self.joiner_dim = int(joiner_meta["joiner_dim"])
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py
new file mode 100755
index 0000000000..a543c6083e
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py
@@ -0,0 +1,413 @@
+#!/usr/bin/env python3
+# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
+
+import argparse
+from pathlib import Path
+from typing import List, Tuple
+
+import kaldi_native_fbank as knf
+import numpy as np
+import soundfile as sf
+from rknn.api import RKNN
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--encoder",
+        type=str,
+        required=True,
+        help="Path to the encoder onnx model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        type=str,
+        required=True,
+        help="Path to the decoder onnx model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        type=str,
+        required=True,
+        help="Path to the joiner onnx model",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        required=True,
+        help="Path to the tokens.txt",
+    )
+
+    parser.add_argument(
+        "--wav",
+        type=str,
+        required=True,
+        help="Path to test wave",
+    )
+
+    return parser
+
+
+def load_audio(filename: str) -> Tuple[np.ndarray, int]:
+    data, sample_rate = sf.read(
+        filename,
+        always_2d=True,
+        dtype="float32",
+    )
+    data = data[:, 0]  # use only the first channel
+
+    samples = np.ascontiguousarray(data)
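+    # samples is a 1-D float32 array with values in [-1, 1];
+    # sample_rate is taken from the file header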
+    return samples, sample_rate
+
+
+def compute_features(filename: str, dim: int = 80) -> np.ndarray:
+    """
+    Args:
+      filename:
+        Path to an audio file.
+    Returns:
+      Return a 2-D float32 array of shape (num_frames, dim)
+      containing the features.
+    """
+    wave, sample_rate = load_audio(filename)
+    if sample_rate != 16000:
+        import librosa
+
+        wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
+        sample_rate = 16000
+
+    features = []
+    opts = knf.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.mel_opts.num_bins = dim
+    opts.frame_opts.snip_edges = False
+    fbank = knf.OnlineFbank(opts)
+
+    fbank.accept_waveform(16000, wave)
+    tail_paddings = np.zeros(int(0.5 * 16000), dtype=np.float32)
+    fbank.accept_waveform(16000, tail_paddings)
+    fbank.input_finished()
+    for i in range(fbank.num_frames_ready):
+        f = fbank.get_frame(i)
+        features.append(f)
+
+    features = np.stack(features, axis=0)
+
+    return features
+
+
+def load_tokens(filename):
+    tokens = dict()
+    with open(filename, "r") as f:
+        for line in f:
+            t, i = line.split()
+            tokens[int(i)] = t
+    return tokens
+
+
+def init_model(filename, target_platform="rk3588", custom_string=None):
+    rknn = RKNN(verbose=False)
+
+    rknn.config(target_platform=target_platform, custom_string=custom_string)
+    if not Path(filename).is_file():
+        exit(f"{filename} does not exist")
+
+    ret = rknn.load_onnx(model=filename)
+    if ret != 0:
+        exit(f"Load model {filename} failed!")
+
+    ret = rknn.build(do_quantization=False)
+    if ret != 0:
+        exit(f"Build model {filename} failed!")
+
+    ret = rknn.init_runtime()
+    if ret != 0:
+        exit(f"Failed to init rknn runtime for {filename}")
+    return rknn
+
+
+class MetaData:
+    def __init__(
+        self,
+        model_type: str,
+        attention_dims: List[int],
+        encoder_dims: List[int],
+        T: int,
+        left_context_len: List[int],
+        decode_chunk_len: int,
+        cnn_module_kernels: List[int],
+        num_encoder_layers: List[int],
+    ):
+        self.model_type = model_type
+        self.attention_dims = attention_dims
+        self.encoder_dims = encoder_dims
+        self.T = T
+        self.left_context_len = left_context_len
+        self.decode_chunk_len = decode_chunk_len
+        self.cnn_module_kernels = cnn_module_kernels
+        self.num_encoder_layers = num_encoder_layers
+
+    def __str__(self) -> str:
+        return self.to_str()
+
+    def to_str(self) -> str:
+        def to_s(ll):
+            return ",".join(list(map(str, ll)))
+
+        s = f"model_type={self.model_type}"
+        s += ";attention_dims=" + to_s(self.attention_dims)
+        s += ";encoder_dims=" + to_s(self.encoder_dims)
+        s += ";T=" + str(self.T)
+        s += ";left_context_len=" + to_s(self.left_context_len)
+        s += ";decode_chunk_len=" + str(self.decode_chunk_len)
+        s += ";cnn_module_kernels=" + to_s(self.cnn_module_kernels)
+        s += ";num_encoder_layers=" + to_s(self.num_encoder_layers)
+
+        assert len(s) < 1024, (s, len(s))
+
+        return s
+
+
+def get_meta_data(encoder: str):
+    import onnxruntime
+
+    session_opts = onnxruntime.SessionOptions()
+    session_opts.inter_op_num_threads = 1
+    session_opts.intra_op_num_threads = 1
+
+    m = onnxruntime.InferenceSession(
+        encoder,
+        sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
+    )
+
+    meta = m.get_modelmeta().custom_metadata_map
+    print(meta)
+    # {'attention_dims': '192,192,192,192,192', 'version': '1',
+    #  'model_type': 'zipformer', 'encoder_dims': '256,256,256,256,256',
+    #  'model_author': 'k2-fsa', 'T': '103',
+    #  'left_context_len': '192,96,48,24,96',
+    #  'decode_chunk_len': '96',
+    #  'cnn_module_kernels': '31,31,31,31,31',
+    #  'num_encoder_layers': '2,2,2,2,2'}
+
+    def to_int_list(s):
+        return list(map(int, s.split(",")))
+
+    model_type = meta["model_type"]
+    attention_dims = to_int_list(meta["attention_dims"])
+    encoder_dims = to_int_list(meta["encoder_dims"])
+    T = int(meta["T"])
+    left_context_len = to_int_list(meta["left_context_len"])
+    decode_chunk_len = int(meta["decode_chunk_len"])
+    cnn_module_kernels = to_int_list(meta["cnn_module_kernels"])
+    num_encoder_layers = to_int_list(meta["num_encoder_layers"])
+
+    return MetaData(
+        model_type=model_type,
+        attention_dims=attention_dims,
+        encoder_dims=encoder_dims,
+        T=T,
+        left_context_len=left_context_len,
+        decode_chunk_len=decode_chunk_len,
+        cnn_module_kernels=cnn_module_kernels,
+        num_encoder_layers=num_encoder_layers,
+    )
+
+
+class RKNNModel:
+    def __init__(
+        self, encoder: str, decoder: str, joiner: str, target_platform="rk3588"
+    ):
+        self.meta = get_meta_data(encoder)
+        self.encoder = init_model(encoder, custom_string=self.meta.to_str())
+        self.decoder = init_model(decoder)
+        self.joiner = init_model(joiner)
+
+    def release(self):
+        self.encoder.release()
+        self.decoder.release()
+        self.joiner.release()
+
+    def get_init_states(
+        self,
+    ) -> List[np.ndarray]:
+
+        cached_len = []
+        cached_avg = []
+        cached_key = []
+        cached_val = []
+        cached_val2 = []
+        cached_conv1 = []
+        cached_conv2 = []
+
+        num_encoder_layers = self.meta.num_encoder_layers
+        encoder_dims = self.meta.encoder_dims
+        left_context_len = self.meta.left_context_len
+        attention_dims = self.meta.attention_dims
+        cnn_module_kernels = self.meta.cnn_module_kernels
+
+        num_encoders = len(num_encoder_layers)
+        N = 1
+
+        for i in range(num_encoders):
+            cached_len.append(np.zeros((num_encoder_layers[i], N), dtype=np.int64))
+            cached_avg.append(
+                np.zeros((num_encoder_layers[i], N, encoder_dims[i]), dtype=np.float32)
+            )
+            cached_key.append(
+                np.zeros(
+                    (num_encoder_layers[i], left_context_len[i], N, attention_dims[i]),
+                    dtype=np.float32,
+                )
+            )
+
+            cached_val.append(
+                np.zeros(
+                    (
+                        num_encoder_layers[i],
+                        left_context_len[i],
+                        N,
+                        attention_dims[i] // 2,
+                    ),
+                    dtype=np.float32,
+                )
+            )
+            cached_val2.append(
+                np.zeros(
+                    (
+                        num_encoder_layers[i],
+                        left_context_len[i],
+                        N,
+                        attention_dims[i] // 2,
+                    ),
+                    dtype=np.float32,
+                )
+            )
+            cached_conv1.append(
+                np.zeros(
+                    (
+                        num_encoder_layers[i],
+                        N,
+                        encoder_dims[i],
+                        cnn_module_kernels[i] - 1,
+                    ),
+                    dtype=np.float32,
+                )
+            )
+            cached_conv2.append(
+                np.zeros(
+                    (
+                        num_encoder_layers[i],
+                        N,
+                        encoder_dims[i],
+                        cnn_module_kernels[i] - 1,
+                    ),
+                    dtype=np.float32,
+                )
+            )
+
+        ans = (
+            cached_len
+            + cached_avg
+            + cached_key
+            + cached_val
+            + cached_val2
+            + cached_conv1
+            + cached_conv2
+        )
+        # for i, s in enumerate(ans):
+        #     if s.ndim == 4:
+        #         ans[i] = np.transpose(s, (0, 2, 3, 1))
+        return ans
+
+    def run_encoder(self, x: np.ndarray, states: List[np.ndarray]):
+        """
+        Args:
+          x: (T, C), np.float32
+          states: A list of states
+        """
+        x = np.expand_dims(x, axis=0)
+
+        out = self.encoder.inference(inputs=[x] + states, data_format="nchw")
+        # out[0], encoder_out, shape (1, 24, 512)
+        return out[0], out[1:]
+
+    def run_decoder(self, x: np.ndarray):
+        """
+        Args:
+          x: (1, context_size), np.int64
+        Returns:
+          Return decoder_out, (1, C), np.float32
+        """
+        return self.decoder.inference(inputs=[x])[0]
+
+    def run_joiner(self, encoder_out: np.ndarray, decoder_out: np.ndarray):
+        """
+        Args:
+          encoder_out: (1, encoder_out_dim), np.float32
+          decoder_out: (1, decoder_out_dim), np.float32
+        Returns:
+          joiner_out: (1, vocab_size), np.float32
+        """
+        return self.joiner.inference(inputs=[encoder_out, decoder_out])[0]
+
+
+def main():
+    args = get_parser().parse_args()
+    print(vars(args))
+
+    id2token = load_tokens(args.tokens)
+    features = compute_features(args.wav)
+    model = RKNNModel(
+        encoder=args.encoder,
+        decoder=args.decoder,
+        joiner=args.joiner,
+    )
+    print(model.meta)
+
+    states = model.get_init_states()
+
+    segment = model.meta.T
+    offset = model.meta.decode_chunk_len
+
+    # must match the context_size baked into the exported decoder
+    context_size = 2
+    hyp = [0] * context_size
+    decoder_input = np.array([hyp], dtype=np.int64)
+    decoder_out = model.run_decoder(decoder_input)
+
+    i = 0
+    while True:
+        if i + segment > features.shape[0]:
+            break
+        x = features[i : i + segment]
+        i += offset
+        encoder_out, states = model.run_encoder(x, states)
+        encoder_out = encoder_out.squeeze(0)  # (1, T, C) -> (T, C)
+
+        num_frames = encoder_out.shape[0]
+        for k in range(num_frames):
+            joiner_out = model.run_joiner(encoder_out[k : k + 1], decoder_out)
+            joiner_out = joiner_out.squeeze(0)
+            max_token_id = joiner_out.argmax()
+
+            # assume 0 is the blank id
+            if max_token_id != 0:
+                hyp.append(max_token_id)
+                decoder_input = np.array([hyp[-context_size:]], dtype=np.int64)
+                decoder_out = model.run_decoder(decoder_input)
+    print(hyp)
+    final_hyp = hyp[context_size:]
+    print(final_hyp)
+    text = "".join([id2token[i] for i in final_hyp])
+    text = text.replace("▁", " ")
+    print(text)
+
+
+if __name__ == "__main__":
+    main()