Commit bd8afde

wip ASR - building up test framework

drowe67 committed Dec 12, 2024
1 parent aaeea81 commit bd8afde

Showing 2 changed files with 227 additions and 0 deletions.
82 changes: 82 additions & 0 deletions asr_test.sh
@@ -0,0 +1,82 @@
#!/usr/bin/env bash
# asr_test.sh
#
# Automatic Speech Recognition (ASR) testing for the Radio Autoencoder

CODEC2_DEV=${CODEC2_DEV:-${HOME}/codec2-dev}
PATH=${PATH}:${CODEC2_DEV}/build_linux/src:${CODEC2_DEV}/build_linux/misc:${PWD}/build/src

which ch >/dev/null || { printf "\n**** Can't find ch - check CODEC2_DEV **** \n\n"; exit 1; }

source utils.sh

function print_help {
echo
echo "Automated Speech Recognition (ASR) testing for the Radio Autoencoder"
echo
echo " usage ./asr_test.sh path/to/source dest [test option below]"
echo " usage ./ota_test.sh ~/.cache/LibriSpeech/test-clean ~/.cache/LibriSpeech/test-awgn-2dB --awgn 2"
echo
echo " --awgn SNRdB AWGN channel simulation"
echo " -d verbose debug information"
exit
}

POSITIONAL=()
while [[ $# -gt 0 ]]
do
key="$1"
case $key in
--awgn)
awgn_snr_dB="$2"
shift
shift
;;
-d)
set -x;
shift
;;
-h)
print_help
;;
*)
POSITIONAL+=("$1") # save it in an array for later
shift
;;
esac
done
set -- "${POSITIONAL[@]}" # restore positional parameters

if [ $# -lt 2 ]; then
print_help
fi

source=$1
dest=$2

# cp translation files to new test directory
function cp_translation_files {
pushd $source; trans=$(find . -name '*.txt'); popd
for f in $trans
do
d=$(dirname $f)
mkdir -p ${dest}/${d}
cp ${source}/${f} ${dest}/${f}
done
}

# process audio files and place in new test directory
function process {
pushd $source; flac=$(find . -name '*.flac'); popd
for f in $flac
do
d=$(dirname $f)
mkdir -p ${dest}/${d}
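# decode FLAC to 16-bit/8 kHz raw, pass it through the ch channel simulator
# (note: the noise density is fixed at --No -30 for now; the --awgn SNR
# parsed above is not yet wired in), then resample to 16 kHz for the ASR stage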
sox ${source}/${f} -t .s16 -r 8000 - | ch - - --No -30 | sox -t .s16 -r 8000 -c 1 - -r 16000 ${dest}/${f}
done
}

cp_translation_files
process
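# Example end-to-end flow (a sketch, assuming the dest directory is a
# LibriSpeech-style split under ~/.cache/LibriSpeech that asr_wer.py can load):
#   ./asr_test.sh ~/.cache/LibriSpeech/test-clean ~/.cache/LibriSpeech/test-awgn-2dB --awgn 2
#   python3 asr_wer.py --test_name test-awgn-2dB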

145 changes: 145 additions & 0 deletions asr_wer.py
@@ -0,0 +1,145 @@
# coding: utf-8

# # Installing Whisper
#
# The commands below will install the Python packages needed to use Whisper models and evaluate the transcription results.

# Run once in your environment:
#   pip install git+/~https://github.com/openai/whisper.git
#   pip install jiwer


# # Loading the LibriSpeech dataset
#
# The following will load the test-clean split of the LibriSpeech corpus using torchaudio.

import os,argparse
import numpy as np


import torch
import pandas as pd
import whisper
import torchaudio

from tqdm import tqdm


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class LibriSpeech(torch.utils.data.Dataset):
"""
A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.
It will drop the last few seconds of a very small portion of the utterances.
"""
def __init__(self, split="test-clean", device=DEVICE):
self.dataset = torchaudio.datasets.LIBRISPEECH(
root=os.path.expanduser("~/.cache"),
url=split,
download=True,
)
self.device = device

def __len__(self):
return len(self.dataset)

def __getitem__(self, item):
audio, sample_rate, text, _, _, _ = self.dataset[item]
assert sample_rate == 16000
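# pad_or_trim fixes every utterance at Whisper's 30 s context
# (480,000 samples at 16 kHz), padding with zeros or trimming as needed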
audio = whisper.pad_or_trim(audio.flatten()).to(self.device)
mel = whisper.log_mel_spectrogram(audio)

return (mel, text)


parser = argparse.ArgumentParser()
parser.add_argument('--test_name', default="test-clean", type=str, help='LibriSpeech test set name')
args = parser.parse_args()
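# e.g. python3 asr_wer.py --test_name test-clean
# (or a processed copy generated by asr_test.sh, such as test-awgn-2dB)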

print("start");
dataset = LibriSpeech(args.test_name)
print("dataset")
loader = torch.utils.data.DataLoader(dataset, batch_size=16)
print("loader")

# # Running inference on the dataset using a base Whisper model
#
# The following will take a few minutes to transcribe all utterances in the dataset.

model = whisper.load_model("base.en")
print(
f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)


# predict without timestamps for short-form transcription
options = whisper.DecodingOptions(language="en", without_timestamps=True)


hypotheses = []
references = []

for mels, texts in tqdm(loader):
results = model.decode(mels, options)
hypotheses.extend([result.text for result in results])
references.extend(texts)


data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))


# # Calculating the word error rate
#
# Now, we use our English normalizer implementation to standardize the transcription and calculate the WER.
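#
# WER = (S + D + I) / N, where S, D and I are the substitutions, deletions
# and insertions needed to turn the hypothesis into the N-word reference.
# For example, "the cat sat down" scored against the reference "the cat sat"
# costs one insertion over three reference words, a WER of 1/3.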

import jiwer
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()
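# The normalizer lowercases, strips punctuation and expands abbreviations
# (e.g. "Mr. Quilter is the apostle." should normalize to
# "mister quilter is the apostle"), so formatting differences between
# Whisper's output and the LibriSpeech transcripts don't inflate the WER.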


data["hypothesis_clean"] = [normalizer(text) for text in data["hypothesis"]]
data["reference_clean"] = [normalizer(text) for text in data["reference"]]
print(data)


wer = jiwer.wer(list(data["reference_clean"]), list(data["hypothesis_clean"]))

print(f"WER: {wer * 100:.2f} %")
