
import datetime as dt
import io
from typing import Iterator
import numpy as np
import pydub
import scipy.io.wavfile as wavfile
import torch
from elevenlabs.client import ElevenLabs
from faster_whisper import WhisperModel
from silero_vad import get_speech_timestamps, load_silero_vad


class AudioSegment(pydub.AudioSegment):
    """
    Wrapper around pydub.AudioSegment that adds numpy/torch conversions.
    """

    def __init__(self, data=None, *args, **kwargs):
        super().__init__(data, *args, **kwargs)

    def to_ndarray(self) -> np.ndarray:
        """
        Convert the AudioSegment to a mono float32 numpy array.
        """
        _buffer = io.BytesIO()
        self.export(_buffer, format="wav")
        _buffer.seek(0)
        _, wav_data = wavfile.read(_buffer)
        if wav_data.dtype != np.float32:
            # Normalize integer PCM to float32. np.max handles any shape,
            # unlike the builtin max(), which raises on 2-D (stereo) arrays.
            max_abs = np.max(np.abs(wav_data))
            wav_data = wav_data.astype(np.float32) / (max_abs if max_abs != 0 else 1)
        if wav_data.ndim == 2:
            # Mix multi-channel audio down to mono.
            wav_data = wav_data.mean(axis=1)
        return wav_data

    def to_tensor(self) -> torch.Tensor:
        """
        Convert the AudioSegment to a PyTorch tensor.
        """
        return torch.tensor(self.to_ndarray(), dtype=torch.float32)
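

# Usage sketch: a quick offline self-check of the conversion helpers. The
# 16 kHz frame rate is an assumption chosen to match the rest of this file.
def _demo_audio_segment() -> None:
    segment = AudioSegment.silent(duration=1_000, frame_rate=16_000)  # 1 s of silence
    samples = segment.to_ndarray()
    assert samples.dtype == np.float32
    assert samples.shape == (16_000,)  # 1 s at 16 kHz, mono
    assert segment.to_tensor().shape == (16_000,)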


class AudioBuffer:
    """
    Rolling buffer that holds at most `max_size` milliseconds of audio.
    """

    def __init__(self, max_size: int) -> None:
        self.data: AudioSegment | None = None
        self.max_size: int = max_size

    def put(self, audio_segment: AudioSegment) -> None:
        """
        Append an AudioSegment, trimming to the newest `max_size` milliseconds.
        """
        if self.data:
            self.data = self.data + audio_segment
            if len(self.data) >= self.max_size:
                self.data = self.data[-self.max_size :]
        else:
            self.data = audio_segment

    def clear(self) -> None:
        """
        Clear the buffer.
        """
        self.data = None

    def get(self) -> AudioSegment | None:
        """
        Get the AudioSegment from the buffer.
        """
        return self.data

    def get_last(self, duration: int) -> AudioSegment | None:
        """
        Get the last `duration` milliseconds of the AudioSegment.
        """
        return self.data[-duration:] if self.data else None

    def is_empty(self) -> bool:
        """
        Check if the buffer is empty.
        """
        return self.data is None

    def length(self) -> int:
        """
        Get the length of the buffer in milliseconds.
        """
        return len(self.data) if self.data else 0
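

# Usage sketch: the buffer keeps a rolling window of the newest audio. The
# 5-second capacity is an arbitrary value chosen for illustration.
def _demo_audio_buffer() -> None:
    buffer = AudioBuffer(max_size=5_000)  # capacity in milliseconds
    chunk = AudioSegment.silent(duration=1_000, frame_rate=16_000)
    for _ in range(10):
        buffer.put(chunk)
    assert buffer.length() == 5_000  # trimmed to the newest 5 s
    assert buffer.get_last(250).duration_seconds == 0.25
    buffer.clear()
    assert buffer.is_empty()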


class VAD:
    """
    Voice Activity Detection backed by the Silero VAD model.
    """

    def __init__(self) -> None:
        self.model = load_silero_vad()
        self.ts_voice_detected: dt.datetime | None = None

    def detect_voice(self, audio_segment: AudioSegment) -> bool:
        """
        Detect voice in an AudioSegment (expected to be 16 kHz mono) and
        stamp the detection time when speech is found.
        """
        speech_timestamps = get_speech_timestamps(
            audio_segment.to_tensor(),
            self.model,
            return_seconds=True,
        )
        if len(speech_timestamps) > 0:
            self.ts_voice_detected = dt.datetime.now()
            return True
        return False

    def ms_since_voice_detected(self) -> float | None:
        """
        Get the milliseconds elapsed since voice was last detected.
        """
        if self.ts_voice_detected is None:
            return None
        return (dt.datetime.now() - self.ts_voice_detected).total_seconds() * 1000

    def reset_timer(self) -> None:
        """
        Reset the voice-detection timestamp.
        """
        self.ts_voice_detected = None
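

# Usage sketch: silence should produce no detection. Note that
# load_silero_vad() fetches the model on first use, and get_speech_timestamps
# defaults to 16 kHz input, so this check needs network access once.
def _demo_vad() -> None:
    vad = VAD()
    silence = AudioSegment.silent(duration=1_000, frame_rate=16_000)
    assert vad.detect_voice(silence) is False
    assert vad.ms_since_voice_detected() is None  # timer was never started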


class STT:
    """
    Speech-to-Text backed by faster-whisper.
    """

    def __init__(
        self,
        model="medium",
        device="cuda" if torch.cuda.is_available() else "cpu",
        benchmark=False,
    ) -> None:
        # float16 is only useful on GPU; fall back to int8 on CPU.
        compute_type = "int8" if device == "cpu" else "float16"
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        self.benchmark = benchmark

    def transcribe(self, audio_segment: AudioSegment, language="sv") -> str:
        """
        Transcribe an AudioSegment and return the recognized text.
        """
        if self.benchmark:
            ts_start = dt.datetime.now()
        segments, _info = self.model.transcribe(audio_segment.to_ndarray(), language=language)
        # The returned generator is lazy; transcription runs while consuming it.
        segments = list(segments)
        text = " ".join(seg.text for seg in segments).strip()
        if self.benchmark:
            delta_ms = (dt.datetime.now() - ts_start).total_seconds() * 1000
            print(f"[Benchmark] STT: {delta_ms:.2f} ms.")
        return text
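

# Usage sketch: transcription only makes sense on a real recording, so the
# file path below is a placeholder. Whisper weights download on first run.
def _demo_stt() -> None:
    stt = STT(model="medium", benchmark=True)
    recording = AudioSegment.from_file("recording.wav")  # placeholder path
    print(stt.transcribe(recording, language="sv"))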


class TTS:
    """
    Text-to-Speech (TTS) class that uses the ElevenLabs API.
    """

    def __init__(self, voice="Brian", model="eleven_multilingual_v2") -> None:
        self.client = ElevenLabs()
        self.voice = voice
        self.model = model

    def generate(self, text: str, voice: str = "default") -> AudioSegment:
        """
        Generate speech for `text` and return it as a 16 kHz mono AudioSegment.
        """
        if voice == "default":
            voice = self.voice
        audio = self.client.generate(
            text=text,
            voice=voice,  # honor the per-call voice override
            model=self.model,
            stream=False,
            output_format="pcm_16000",
            optimize_streaming_latency=2,
        )
        audio = b"".join(audio)
        return AudioSegment(data=audio, sample_width=2, frame_rate=16000, channels=1)

    def stream(self, text_stream: Iterator[str], voice: str = "default") -> Iterator[bytes]:
        """
        Stream raw 16 kHz PCM chunks for an incoming stream of text.
        """
        if voice == "default":
            voice = self.voice
        audio_stream = self.client.generate(
            text=text_stream,
            voice=voice,  # honor the per-call voice override
            model=self.model,
            stream=True,
            output_format="pcm_16000",
            optimize_streaming_latency=2,
        )
        for chunk in audio_stream:
            if chunk is not None:
                yield chunk
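

# End-to-end sketch: run the offline self-checks, then synthesize a line with
# ElevenLabs and transcribe it back with Whisper. The round trip requires
# ELEVENLABS_API_KEY in the environment and downloads model weights on first
# run; the Swedish sample sentence is arbitrary.
if __name__ == "__main__":
    _demo_audio_segment()
    _demo_audio_buffer()
    tts = TTS(voice="Brian")
    stt = STT(model="medium")
    spoken = tts.generate("Hej, hur mår du?")
    print(stt.transcribe(spoken, language="sv"))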