import datetime as dt import io from typing import Iterator import numpy as np import pydub import scipy.io.wavfile as wavfile import torch from elevenlabs.client import ElevenLabs from faster_whisper import WhisperModel from silero_vad import get_speech_timestamps, load_silero_vad class AudioSegment(pydub.AudioSegment): """ Wrapper class for pydub.AudioSegment """ def __init__(self, data=None, *args, **kwargs): super().__init__(data, *args, **kwargs) def to_ndarray(self) -> np.ndarray: """ Convert the AudioSegment to a numpy array. """ _buffer = io.BytesIO() self.export(_buffer, format="wav") _buffer.seek(0) _, wav_data = wavfile.read(_buffer) if wav_data.dtype != np.float32: max_abs = max(np.abs(wav_data)) if max(np.abs(wav_data)) != 0 else 1 wav_data = wav_data.astype(np.float32) / max_abs if len(wav_data.shape) == 2: wav_data = wav_data.mean(axis=1) return wav_data def to_tensor(self) -> torch.Tensor: """ Convert the AudioSegment to a PyTorch tensor. """ return torch.tensor(self.to_ndarray(), dtype=torch.float32) class AudioBuffer: """ Buffer for AudioSegments """ def __init__(self, max_size: int) -> None: self.data: AudioSegment | None = None self.max_size: int = max_size def put(self, audio_segment: AudioSegment) -> None: """ Append an AudioSegment to the buffer. """ if self.data: self.data = self.data + audio_segment if len(self.data) >= self.max_size: self.data = self.data[-self.max_size :] else: self.data = audio_segment def clear(self) -> None: """ Clear the buffer. """ self.data = None def get(self) -> AudioSegment | None: """ Get the AudioSegment from the buffer. """ return self.data def get_last(self, duration: int) -> AudioSegment | None: """ Get the last `duration` milliseconds of the AudioSegment. """ return self.data[-duration:] if self.data else None def is_empty(self) -> bool: """ Check if the buffer is empty. """ return self.data is None def length(self) -> int: """ Get the length of the buffer. """ return len(self.data) if self.data else 0 class VAD: """ Voice Activity Detection """ def __init__(self) -> None: self.model = load_silero_vad() self.ts_voice_detected: dt.datetime | None = None def detect_voice(self, audio_segment: AudioSegment) -> bool: """ Detect voice in an AudioSegment. """ speech_timestamps = get_speech_timestamps( audio_segment.to_tensor(), self.model, return_seconds=True, ) if len(speech_timestamps) > 0: if self.ts_voice_detected is None: pass # print("VAD: Voice detected.") self.ts_voice_detected = dt.datetime.now() return True return False def ms_since_voice_detected(self) -> int | None: """ Get the milliseconds since voice was detected. """ if self.ts_voice_detected is None: return None return (dt.datetime.now() - self.ts_voice_detected).total_seconds() * 1000 def reset_timer(self): """ Reset the timer for voice detected. """ self.ts_voice_detected = None class STT: """ Speech-to-Text """ def __init__(self, model="medium", device="cuda" if torch.cuda.is_available() else "cpu", benchmark=False) -> None: compute_type = "int8" if device == "cpu" else "float16" self.model = WhisperModel(model, device=device, compute_type=compute_type) self.benchmark = benchmark def transcribe(self, audio_segment: AudioSegment, language="sv"): """ Transcribe an AudioSegment. """ if self.benchmark: ts_start = dt.datetime.now() segments, info = self.model.transcribe(audio_segment.to_ndarray(), language=language) segments = list(segments) text = " ".join([seg.text for seg in segments]) text = text.strip() if self.benchmark: ts_done = dt.datetime.now() delta_ms = (ts_done - ts_start).total_seconds() * 1000 print(f"[Benchmark] STT: {delta_ms:.2f} ms.") return text class TTS: """ Text-to-Speech (TTS) class that uses the ElevenLabs API. """ def __init__(self, voice="Brian", model="eleven_multilingual_v2") -> None: self.client = ElevenLabs() self.voice = voice self.model = model def generate(self, text: str, voice: str = "default") -> AudioSegment: if voice == "default": voice = self.voice audio = self.client.generate( text=text, voice=self.voice, model=self.model, stream=False, output_format="pcm_16000", optimize_streaming_latency=2, ) audio = b"".join(audio) audio_segment = AudioSegment(data=audio, sample_width=2, frame_rate=16000, channels=1) return audio_segment def stream(self, text_stream: Iterator[str], voice: str = "default") -> Iterator[bytes]: if voice == "default": voice = self.voice audio_stream = self.client.generate( text=text_stream, voice=self.voice, model=self.model, stream=True, output_format="pcm_16000", optimize_streaming_latency=2, ) for chunk in audio_stream: if chunk is not None: yield chunk