Skip to content

Commit

Permalink
dont let the thread die
Browse files Browse the repository at this point in the history
  • Loading branch information
andimarafioti committed Oct 21, 2024
1 parent 66533a2 commit 75b364b
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 30 deletions.
2 changes: 1 addition & 1 deletion arguments_classes/module_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class ModuleArguments:
mode: Optional[str] = field(
default="socket",
metadata={
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
"help": "The mode to run the pipeline in. Either 'local', 'socket', or 'none'. Default is 'socket'."
},
)
local_mac_optimal_settings: bool = field(
Expand Down
100 changes: 73 additions & 27 deletions audio_streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,30 @@
from queue import Queue
import sounddevice as sd
import numpy as np
import requests
import base64
import time
from dataclasses import dataclass, field
import websocket
import threading
import ssl


# Configuration for AudioStreamingClient: one dataclass field per option, with
# the user-facing help text stored in each field's metadata.
# NOTE(review): this listing appears to be a rendered diff with the +/- markers
# stripped, so each field shows up twice -- first as the single-line
# declaration, then as the reformatted multi-line one. The defaults and help
# strings are the same in both forms.
@dataclass
class AudioStreamingClientArguments:
sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."})
chunk_size: int = field(default=512, metadata={"help": "The size of audio chunks in samples. Default is 512."})
api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."})
auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."})
# Sample rate handed to both the sounddevice input and output streams.
sample_rate: int = field(
default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."}
)
# Used as the stream blocksize. Defaults to 512 here, whereas the --chunk_size
# flag in the __main__ block defaults to 1024.
chunk_size: int = field(
default=512,
metadata={"help": "The size of audio chunks in samples. Default is 512."},
)
# Hard-coded endpoint hostname; the __main__ block marks --api_url as required,
# so this default only applies when the dataclass is built directly.
api_url: str = field(
default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud",
metadata={"help": "The URL of the API endpoint."},
)
# Placeholder value; the client sends it as "Bearer <auth_token>" in its
# Authorization header.
auth_token: str = field(
default="your_auth_token",
metadata={"help": "Authentication token for the API."},
)


class AudioStreamingClient:
def __init__(self, args: AudioStreamingClientArguments):
Expand All @@ -27,9 +37,11 @@ def __init__(self, args: AudioStreamingClientArguments):
self.headers = {
"Accept": "application/json",
"Authorization": f"Bearer {self.args.auth_token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
self.session_state = "idle" # Possible states: idle, sending, processing, waiting
self.session_state = (
"idle" # Possible states: idle, sending, processing, waiting
)
self.ws_ready = threading.Event()

def start(self):
Expand All @@ -43,12 +55,14 @@ def start(self):
on_open=self.on_open,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close
on_close=self.on_close,
)

ws_thread = threading.Thread(target=self.ws.run_forever, kwargs={'sslopt': {"cert_reqs": ssl.CERT_NONE}})
ws_thread.start()

self.ws_thread = threading.Thread(
target=self.ws.run_forever, kwargs={"sslopt": {"cert_reqs": ssl.CERT_NONE}}
)
self.ws_thread.start()

# Wait for the WebSocket to be ready
self.ws_ready.wait()
self.start_audio_streaming()
Expand All @@ -57,17 +71,25 @@ def start_audio_streaming(self):
self.send_thread = threading.Thread(target=self.send_audio)
self.play_thread = threading.Thread(target=self.play_audio)

with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_input_callback, blocksize=self.args.chunk_size):
with sd.InputStream(
samplerate=self.args.sample_rate,
channels=1,
dtype="int16",
callback=self.audio_input_callback,
blocksize=self.args.chunk_size,
):
self.send_thread.start()
self.play_thread.start()
input("Press Enter to stop streaming...")
self.on_shutdown()

# WebSocket "open" handler, registered through on_open=self.on_open in start().
# Logs that the connection is up and sets self.ws_ready, which releases the
# self.ws_ready.wait() call in start() so audio streaming can begin.
# The ws argument is not used in the body.
def on_open(self, ws):
print("WebSocket connection opened.")
self.ws_ready.set() # Signal that the WebSocket is ready

def on_message(self, ws, message):
# message is bytes
if message == b'DONE':
if message == b"DONE":
print("listen")
self.session_state = "listen"
else:
Expand Down Expand Up @@ -97,7 +119,7 @@ def send_audio(self):
if self.session_state != "processing":
self.ws.send(chunk.tobytes(), opcode=websocket.ABNF.OPCODE_BINARY)
else:
self.ws.send([], opcode=websocket.ABNF.OPCODE_BINARY) # handshake
self.ws.send([], opcode=websocket.ABNF.OPCODE_BINARY) # handshake
time.sleep(0.01)

def audio_input_callback(self, indata, frames, time, status):
Expand All @@ -106,33 +128,57 @@ def audio_input_callback(self, indata, frames, time, status):
# Output-stream callback (passed as callback= to sd.OutputStream in play_audio).
# Fills `outdata` with the next chunk from self.recv_queue, or with silence when
# the queue is empty. `frames` and `status` are not used in the body.
# NOTE(review): the parameter named `time` shadows the imported `time` module
# inside this function; the body uses neither. The name is presumably dictated
# by the sounddevice callback signature -- TODO confirm.
# NOTE(review): this listing appears to be a rendered diff without +/- markers;
# the adjacent pairs of slice assignments below are the same statement before
# and after reformatting, not two separate writes.
def audio_out_callback(self, outdata, frames, time, status):
if not self.recv_queue.empty():
chunk = self.recv_queue.get()

# Ensure chunk is int16 and clip to valid range
# (clipping first keeps out-of-range values from wrapping in the int16 cast).
# assumes chunk is a 1-D array of samples -- TODO confirm against on_message
chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)

if len(chunk_int16) < len(outdata):
# Short chunk: copy it into the start of channel 0 and zero the remainder.
outdata[:len(chunk_int16), 0] = chunk_int16
outdata[len(chunk_int16):] = 0
outdata[: len(chunk_int16), 0] = chunk_int16
outdata[len(chunk_int16) :] = 0
else:
# Chunk fills the buffer: only the first len(outdata) samples are written;
# any samples beyond that are not re-queued here.
outdata[:, 0] = chunk_int16[:len(outdata)]
outdata[:, 0] = chunk_int16[: len(outdata)]
else:
# Nothing received yet: emit silence.
outdata[:] = 0

# Body of the playback thread (target of self.play_thread in
# start_audio_streaming). Opens a mono int16 sounddevice output stream whose
# samples are supplied by self.audio_out_callback, then keeps the stream open
# by polling self.stop_event every 0.1 s; leaving the `with` block closes it.
# NOTE(review): self.stop_event is not created anywhere in the visible hunks;
# presumably it is set up in __init__ -- TODO confirm.
# NOTE(review): this listing appears to be a rendered diff without +/- markers;
# the single-line `with` and the multi-line `with` that follows are the same
# statement before and after reformatting.
def play_audio(self):
with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_out_callback, blocksize=self.args.chunk_size):
with sd.OutputStream(
samplerate=self.args.sample_rate,
channels=1,
dtype="int16",
callback=self.audio_out_callback,
blocksize=self.args.chunk_size,
):
while not self.stop_event.is_set():
time.sleep(0.1)


# Script entry point: parse the CLI flags, build AudioStreamingClientArguments
# from them, and start the client.
# NOTE(review): this listing appears to be a rendered diff without +/- markers;
# the four single-line add_argument calls are the pre-reformat versions of the
# four multi-line calls that follow them.
if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Audio Streaming Client")
parser.add_argument("--sample_rate", type=int, default=16000, help="Audio sample rate in Hz. Default is 16000.")
parser.add_argument("--chunk_size", type=int, default=1024, help="The size of audio chunks in samples. Default is 1024.")
parser.add_argument("--api_url", type=str, required=True, help="The URL of the API endpoint.")
parser.add_argument("--auth_token", type=str, required=True, help="Authentication token for the API.")
parser.add_argument(
"--sample_rate",
type=int,
default=16000,
help="Audio sample rate in Hz. Default is 16000.",
)
# The CLI default is 1024, while the AudioStreamingClientArguments dataclass
# defaults chunk_size to 512.
parser.add_argument(
"--chunk_size",
type=int,
default=1024,
help="The size of audio chunks in samples. Default is 1024.",
)
# --api_url and --auth_token are required, so the dataclass defaults for these
# two fields are never used when running through this entry point.
parser.add_argument(
"--api_url", type=str, required=True, help="The URL of the API endpoint."
)
parser.add_argument(
"--auth_token",
type=str,
required=True,
help="Authentication token for the API.",
)

args = parser.parse_args()
# The argparse destination names match the dataclass field names one-to-one,
# so the namespace can be splatted straight into the constructor.
client_args = AudioStreamingClientArguments(**vars(args))
client = AudioStreamingClient(client_args)
# NOTE(review): the two client.start() lines appear to be the before/after pair
# of a diff (likely an end-of-file newline change), not two calls -- TODO
# confirm against the repository.
client.start()
client.start()
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ funasr>=1.1.6
faster-whisper>=1.0.3
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
openai>=1.40.1
websocket-client>=1.8.0
3 changes: 2 additions & 1 deletion requirements_mac.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ funasr>=1.1.6
faster-whisper>=1.0.3
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
openai>=1.40.1
websocket-client>=1.8.0

0 comments on commit 75b364b

Please sign in to comment.