import datetime as dt
import io
from typing import Iterator

import numpy as np
import pydub
import scipy.io.wavfile as wavfile
import torch
from elevenlabs.client import ElevenLabs
from faster_whisper import WhisperModel
from silero_vad import get_speech_timestamps, load_silero_vad


class AudioSegment(pydub.AudioSegment):
    """
    Wrapper around pydub.AudioSegment that adds conversions to numpy and torch.
    """

    def __init__(self, data=None, *args, **kwargs):
        super().__init__(data, *args, **kwargs)

    def to_ndarray(self) -> np.ndarray:
        """
        Convert the AudioSegment to a mono float32 numpy array in [-1.0, 1.0].
        """
        _buffer = io.BytesIO()
        self.export(_buffer, format="wav")
        _buffer.seek(0)

        _, wav_data = wavfile.read(_buffer)
        if wav_data.dtype != np.float32:
            # Normalise integer PCM to float32 by the peak absolute value.
            max_abs = np.abs(wav_data).max()
            wav_data = wav_data.astype(np.float32) / (max_abs if max_abs != 0 else 1)
        if wav_data.ndim == 2:
            # Downmix multi-channel audio to mono.
            wav_data = wav_data.mean(axis=1)
        return wav_data

    def to_tensor(self) -> torch.Tensor:
        """
        Convert the AudioSegment to a 1-D float32 PyTorch tensor.
        """
        return torch.tensor(self.to_ndarray(), dtype=torch.float32)


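# Example (sketch): wrap raw 16 kHz mono PCM bytes and convert them for model input.
# `pcm_bytes` is a placeholder for any little-endian 16-bit PCM payload you already hold;
# it is not defined in this module.
#
# segment = AudioSegment(data=pcm_bytes, sample_width=2, frame_rate=16000, channels=1)
# samples = segment.to_ndarray()   # mono float32 numpy array in [-1.0, 1.0]
# tensor = segment.to_tensor()     # 1-D float32 torch.Tensor, e.g. for Silero VAD

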
class AudioBuffer:
    """
    Rolling buffer of audio, stored as a single AudioSegment.
    """

    def __init__(self, max_size: int) -> None:
        self.data: AudioSegment | None = None
        self.max_size: int = max_size  # maximum buffer length in milliseconds

    def put(self, audio_segment: AudioSegment) -> None:
        """
        Append an AudioSegment to the buffer, keeping at most `max_size` milliseconds.
        """
        if self.data:
            self.data = self.data + audio_segment
            if len(self.data) >= self.max_size:
                # Drop the oldest audio so only the most recent max_size ms remain.
                self.data = self.data[-self.max_size :]
        else:
            self.data = audio_segment

    def clear(self) -> None:
        """
        Clear the buffer.
        """
        self.data = None

    def get(self) -> AudioSegment | None:
        """
        Get the buffered AudioSegment, or None if the buffer is empty.
        """
        return self.data

    def get_last(self, duration: int) -> AudioSegment | None:
        """
        Get the last `duration` milliseconds of buffered audio.
        """
        return self.data[-duration:] if self.data else None

    def is_empty(self) -> bool:
        """
        Check whether the buffer is empty.
        """
        return self.data is None

    def length(self) -> int:
        """
        Get the length of the buffered audio in milliseconds.
        """
        return len(self.data) if self.data else 0


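# Example (sketch): keep a rolling 30-second window of incoming audio.
# `chunks` stands in for any iterable of AudioSegment pieces (e.g. ~100 ms each)
# coming from a capture loop; audio capture itself is outside this module.
#
# buffer = AudioBuffer(max_size=30_000)      # max_size is in milliseconds
# for chunk in chunks:
#     buffer.put(chunk)
# tail = buffer.get_last(2_000)              # last two seconds, or None if empty

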
class VAD:
    """
    Voice Activity Detection using Silero VAD.
    """

    def __init__(self) -> None:
        self.model = load_silero_vad()
        self.ts_voice_detected: dt.datetime | None = None

    def detect_voice(self, audio_segment: AudioSegment) -> bool:
        """
        Detect voice in an AudioSegment and, if found, refresh the detection timestamp.
        """
        # Silero VAD expects mono float audio at 16 kHz (or 8 kHz).
        speech_timestamps = get_speech_timestamps(
            audio_segment.to_tensor(),
            self.model,
            return_seconds=True,
        )
        if len(speech_timestamps) > 0:
            self.ts_voice_detected = dt.datetime.now()
            return True

        return False

    def ms_since_voice_detected(self) -> float | None:
        """
        Get the milliseconds elapsed since voice was last detected, or None if
        no voice has been detected since the last reset.
        """
        if self.ts_voice_detected is None:
            return None
        return (dt.datetime.now() - self.ts_voice_detected).total_seconds() * 1000

    def reset_timer(self) -> None:
        """
        Reset the voice-detection timestamp.
        """
        self.ts_voice_detected = None


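# Example (sketch): end-of-utterance detection by pairing VAD with an AudioBuffer.
# The 800 ms silence threshold is an illustrative value, not part of this module.
#
# vad = VAD()
# if not buffer.is_empty():
#     vad.detect_voice(buffer.get_last(1_000))        # refresh timestamp if speech found
#     silence_ms = vad.ms_since_voice_detected()
#     if silence_ms is not None and silence_ms > 800:
#         pass  # speaker has gone quiet; hand the buffer to STT, then clear it

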
class STT:
    """
    Speech-to-Text using faster-whisper.
    """

    def __init__(self, model="medium", device="cuda" if torch.cuda.is_available() else "cpu", benchmark=False) -> None:
        # int8 keeps CPU inference manageable; float16 is the usual choice on GPU.
        compute_type = "int8" if device == "cpu" else "float16"
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        self.benchmark = benchmark

    def transcribe(self, audio_segment: AudioSegment, language="sv") -> str:
        """
        Transcribe an AudioSegment and return the text as a single string.
        """
        if self.benchmark:
            ts_start = dt.datetime.now()
        segments, _info = self.model.transcribe(audio_segment.to_ndarray(), language=language)
        # The result is a lazy generator; materialise it to run the actual transcription.
        segments = list(segments)
        text = " ".join(seg.text for seg in segments).strip()
        if self.benchmark:
            ts_done = dt.datetime.now()
            delta_ms = (ts_done - ts_start).total_seconds() * 1000
            print(f"[Benchmark] STT: {delta_ms:.2f} ms.")
        return text


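# Example (sketch): transcribe whatever audio is currently in an AudioBuffer.
#
# stt = STT(model="medium", benchmark=True)
# utterance = buffer.get()
# if utterance is not None:
#     print(stt.transcribe(utterance, language="sv"))

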
class TTS:
    """
    Text-to-Speech (TTS) class that uses the ElevenLabs API.
    """

    def __init__(self, voice="Brian", model="eleven_multilingual_v2") -> None:
        self.client = ElevenLabs()
        self.voice = voice
        self.model = model

    def generate(self, text: str, voice: str = "default") -> AudioSegment:
        """
        Generate speech for `text` and return it as an AudioSegment.
        """
        if voice == "default":
            voice = self.voice
        audio = self.client.generate(
            text=text,
            voice=voice,
            model=self.model,
            stream=False,
            output_format="pcm_16000",
            optimize_streaming_latency=2,
        )
        # The client yields byte chunks; join them into a single raw 16 kHz PCM payload.
        audio = b"".join(audio)
        audio_segment = AudioSegment(data=audio, sample_width=2, frame_rate=16000, channels=1)
        return audio_segment

    def stream(self, text_stream: Iterator[str], voice: str = "default") -> Iterator[bytes]:
        """
        Stream speech for a stream of text and yield raw 16 kHz PCM chunks.
        """
        if voice == "default":
            voice = self.voice
        audio_stream = self.client.generate(
            text=text_stream,
            voice=voice,
            model=self.model,
            stream=True,
            output_format="pcm_16000",
            optimize_streaming_latency=2,
        )

        for chunk in audio_stream:
            if chunk is not None:
                yield chunk
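

# Example (sketch): synthesise a reply and hand the raw samples to your playback layer.
# Assumes the ElevenLabs API key is configured in the environment; how the audio is
# actually played back (e.g. via pydub.playback or sounddevice) is up to the caller.
#
# tts = TTS(voice="Brian")
# reply = tts.generate("Hej! Hur kan jag hjälpa dig?")
# samples = reply.to_ndarray()   # float32 mono samples at 16 kHz, ready for playback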