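"""ESPHome voice assistant satellite server.

Connects to an ESPHome voice satellite over the native API, receives microphone
audio from the device, and runs a VAD -> STT -> LLM -> TTS pipeline, serving the
generated speech back to the device over HTTP.
"""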
import asyncio
import hashlib
import logging
import math
import os
import traceback
import wave
from enum import StrEnum

import aioesphomeapi
from aioesphomeapi import (
    VoiceAssistantAudioSettings,
    VoiceAssistantCommandFlag,
    VoiceAssistantEventType,
)
from aiohttp import web

from audio import STT, TTS, VAD, AudioBuffer, AudioSegment
from llm_api import LlmApi
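
# The local `audio` and `llm_api` modules are project-specific. As used below,
# they are expected to provide roughly:
#   AudioBuffer(max_size)      -> .put(segment), .get(), .get_last(n), .length(), .clear()
#   AudioSegment(data, sample_width, frame_rate, channels), .from_file(path), .export(path, format=...)
#   VAD().detect_voice(segment)  -> bool
#   STT(device).transcribe(seg)  -> str
#   TTS().generate(text, voice)  -> AudioSegment
#   LlmApi().ask(text)           -> dict with an "answer" key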

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
routes = web.RouteTableDef()


class PipelineStage(StrEnum):
    """Stages of a pipeline."""

    WAKE_WORD = "wake_word"
    STT = "stt"
    INTENT = "intent"
    TTS = "tts"
    END = "end"


class Satellite:
    """Runs the voice assistant pipeline for a single ESPHome voice satellite."""

    def __init__(self, host="192.168.10.155", port=6053, password=""):
        self.client = aioesphomeapi.APIClient(host, port, password)
        self.audio_queue: AudioBuffer = AudioBuffer(max_size=60000)
        self.state = "idle"
        self.vad = VAD()
        self.stt = STT(device="cuda")
        self.tts = TTS()
        self.agent = LlmApi()

    async def connect(self):
        await self.client.connect(login=True)

    async def disconnect(self):
        await self.client.disconnect()

    async def handle_pipeline_start(
        self,
        conversation_id: str,
        flags: int,
        audio_settings: VoiceAssistantAudioSettings,
        wake_word_phrase: str | None,
    ) -> int:
        logging.debug(
            f"Pipeline starting with conversation_id={conversation_id}, flags={flags}, "
            f"audio_settings={audio_settings}, wake_word_phrase={wake_word_phrase}"
        )

        # Device-triggered pipeline (wake word, etc.)
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            start_stage = PipelineStage.WAKE_WORD
        else:
            start_stage = PipelineStage.STT

        end_stage = PipelineStage.TTS

        logging.info(f"Starting pipeline from {start_stage} to {end_stage}")
        self.state = "running_pipeline"

        # Run the pipeline without blocking the API callback
        asyncio.create_task(self.run_pipeline(wake_word_phrase))

        # 0 = no UDP port; audio is exchanged over the API connection
        return 0

    async def run_pipeline(self, wake_word_phrase: str | None = None):
        logging.info(f"Pipeline started using the wake word '{wake_word_phrase}'.")

        try:
            logging.debug(" > STT start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None)

            logging.debug(" > STT VAD start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START, None)

            # VAD: wait until the speaker goes quiet (or the buffer gets too long)
            logging.debug(" > VAD: Waiting for silence...")
            VAD_TIMEOUT = 1000
            _voice_detected = False
            while True:
                if self.audio_queue.length() > 20000:
                    # Hard cap: stop collecting audio regardless of voice activity
                    break
                elif self.audio_queue.length() >= VAD_TIMEOUT:
                    # Check only the most recent window of audio for voice activity
                    voice_detected = self.vad.detect_voice(self.audio_queue.get_last(VAD_TIMEOUT))
                    if voice_detected != _voice_detected:
                        _voice_detected = voice_detected
                        if not _voice_detected:
                            # Voice was present and has now stopped
                            logging.debug(" > VAD: Silence detected.")
                            break
                    if self.audio_queue.length() > 2000 and not voice_detected:
                        # Enough audio buffered but still no voice detected
                        logging.debug(" > VAD: Silence detected (no voice).")
                        break
                await asyncio.sleep(0.1)

            logging.debug(" > STT VAD end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END, None)

            # STT
            text = self.stt.transcribe(self.audio_queue.get())
            self.audio_queue.clear()
            logging.info(f" > STT: {text}")

            logging.debug(" > STT end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": text})

            # Intent: send the transcript to the LLM agent
            logging.debug(" > Intent start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None)
            agent_response = self.agent.ask(text)["answer"]
            logging.info(f" > Intent: {agent_response}")

            logging.debug(" > Intent end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, None)

            # TTS
            logging.debug(" > TTS start")
            tts_mode = "announce"  # "announce": serve a WAV over HTTP; "stream": push audio chunks over the API
            if tts_mode == "announce":
                self.client.send_voice_assistant_event(
                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": agent_response}
                )

                voice = "default"
                audio_hash_string = hashlib.md5((voice + ":" + agent_response).encode("utf-8")).hexdigest()
                audio_file_path = f"static/{audio_hash_string}.wav"
                if not os.path.exists(audio_file_path):
                    logging.debug(" > TTS: Generating audio...")
                    audio_segment = self.tts.generate(agent_response, voice=voice)
                else:
                    logging.debug(" > TTS: Using cached audio.")
                    audio_segment = AudioSegment.from_file(audio_file_path)
                audio_segment.export(audio_file_path, format="wav")

                # The URL must point at the HTTP server started in start_http_server()
                self.client.send_voice_assistant_event(
                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
                    {"url": f"http://192.168.10.111:8002/{audio_file_path}"},
                )
            elif tts_mode == "stream":
                await self._stream_tts_audio("debug")
            logging.debug(" > TTS end")

        except Exception as e:
            logging.error("Error: %s", e)
            logging.error(traceback.format_exc())
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, {"message": str(e)})

        logging.debug(" > Run end")
        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END, None)
        logging.info("Pipeline done")

    async def handle_pipeline_stop(self, abort: bool):
        logging.info("Pipeline stopped")
        self.state = "idle"

    async def handle_audio(self, data: bytes) -> None:
        """Handle incoming audio chunk from API."""
        self.audio_queue.put(AudioSegment(data=data, sample_width=2, frame_rate=16000, channels=1))

    async def handle_announcement_finished(self, event: aioesphomeapi.VoiceAssistantAnnounceFinished):
        if event.success:
            logging.info("Announcement finished successfully.")
        else:
            logging.error("Announcement failed.")

    def handle_state_change(self, state):
        logging.debug("State change:")
        logging.debug(state)

    def handle_log(self, log):
        logging.info("Log from device: %s", log)

    def subscribe(self):
        self.client.subscribe_states(self.handle_state_change)
        self.client.subscribe_logs(self.handle_log)
        self.client.subscribe_voice_assistant(
            handle_start=self.handle_pipeline_start,
            handle_stop=self.handle_pipeline_stop,
            handle_audio=self.handle_audio,
            handle_announcement_finished=self.handle_announcement_finished,
        )

    async def _stream_tts_audio(
        self,
        media_id: str,
        sample_rate: int = 16000,
        sample_width: int = 2,
        sample_channels: int = 1,
        samples_per_chunk: int = 1024,  # 512
    ) -> None:
        """Stream TTS audio chunks to device via API or UDP."""

        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
        logging.info("TTS stream start")
        await asyncio.sleep(1)

        try:
            # if not self._is_running:
            #     return

            with wave.open(f"{media_id}.wav", "rb") as wav_file:
                if (
                    (wav_file.getframerate() != sample_rate)
                    or (wav_file.getsampwidth() != sample_width)
                    or (wav_file.getnchannels() != sample_channels)
                ):
                    logging.info(f"Can only stream 16 kHz 16-bit mono WAV, got {wav_file.getparams()}")
                    return

                logging.info("Streaming %s audio samples", wav_file.getnframes())

                rate = wav_file.getframerate()
                width = wav_file.getsampwidth()
                channels = wav_file.getnchannels()

                audio_bytes = wav_file.readframes(wav_file.getnframes())
                bytes_per_sample = width * channels
                bytes_per_chunk = bytes_per_sample * samples_per_chunk
                num_chunks = int(math.ceil(len(audio_bytes) / bytes_per_chunk))

                # Split into chunks
                for i in range(num_chunks):
                    offset = i * bytes_per_chunk
                    chunk = audio_bytes[offset : offset + bytes_per_chunk]
                    self.client.send_voice_assistant_audio(chunk)

                    # Wait for 90% of the duration of the audio that was
                    # sent for it to be played. This will overrun the
                    # device's buffer for very long audio, so using a media
                    # player is preferred.
                    samples_in_chunk = len(chunk) // (sample_width * sample_channels)
                    seconds_in_chunk = samples_in_chunk / sample_rate
                    logging.info(f"Samples in chunk: {samples_in_chunk}, seconds in chunk: {seconds_in_chunk}")
                    await asyncio.sleep(seconds_in_chunk * 0.9)

        except asyncio.CancelledError:
            return  # Don't trigger state change
        finally:
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {})
            logging.info("TTS stream end")


async def start_http_server(port=8002):
    app = web.Application()
    app.router.add_static("/static/", path=os.path.join(os.path.dirname(__file__), "static"))
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "0.0.0.0", port)
    await site.start()
    logging.info(f"HTTP server started at http://0.0.0.0:{port}/")


async def main():
    """Connect to an ESPHome device and wait for state changes."""
    await start_http_server()

    satellite = Satellite()
    await satellite.connect()
    logging.info("Connected to ESPHome voice assistant.")
    satellite.subscribe()
    satellite.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START, None)
    logging.info("Listening for events...")

    while True:
        await asyncio.sleep(0.1)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass