# voice-assistant/esphome-bridge/esphome_server.py
import asyncio
import hashlib
import logging
import math
import os
import wave
from enum import StrEnum
import traceback
import aioesphomeapi
from aioesphomeapi import (VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag, VoiceAssistantEventType)
from aiohttp import web
from audio import STT, TTS, VAD, AudioBuffer, AudioSegment
from llm_api import LlmApi
# Root logger configuration for the whole bridge process: timestamped,
# function-scoped lines at INFO level.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# aiohttp route table.  NOTE(review): `routes` is never registered with the
# Application in this file (start_http_server only adds the /static/ route) --
# presumably a leftover or used by code outside this file; confirm before
# removing.
routes = web.RouteTableDef()
class PipelineStage(StrEnum):
"""Stages of a pipeline."""
WAKE_WORD = "wake_word"
STT = "stt"
INTENT = "intent"
TTS = "tts"
END = "end"
class Satellite:
    """Bridge between an ESPHome voice-assistant device and local pipelines.

    Connects to the device over the native ESPHome API, receives streamed
    microphone audio, runs VAD -> STT -> LLM intent -> TTS, and reports
    pipeline progress events back to the device.
    """

    # VAD tuning.  Units follow AudioBuffer.length(); presumably milliseconds
    # of buffered audio -- TODO confirm against the AudioBuffer implementation.
    VAD_WINDOW = 1000        # window analysed per voice-activity check
    NO_VOICE_TIMEOUT = 2000  # give up if no voice was ever detected
    MAX_BUFFER = 20000       # hard cap: stop capturing regardless of VAD

    def __init__(self, host="192.168.10.155", port=6053, password="",
                 base_url="http://192.168.10.111:8002"):
        """Create the API client and pipeline components.

        Args:
            host: IP/hostname of the ESPHome device.
            port: ESPHome native API port.
            password: API password, if the device requires one.
            base_url: externally reachable base URL of the HTTP server that
                serves generated TTS files (see ``start_http_server``).
        """
        self.client = aioesphomeapi.APIClient(host, port, password)
        self.audio_queue: AudioBuffer = AudioBuffer(max_size=60000)
        self.state = "idle"  # "idle" | "running_pipeline"
        self.base_url = base_url
        self.vad = VAD()
        self.stt = STT(device="cuda")
        self.tts = TTS()
        self.agent = LlmApi()

    async def connect(self):
        """Connect and log in to the ESPHome device."""
        await self.client.connect(login=True)

    async def disconnect(self):
        """Disconnect from the ESPHome device."""
        await self.client.disconnect()

    async def handle_pipeline_start(
        self,
        conversation_id: str,
        flags: int,
        audio_settings: VoiceAssistantAudioSettings,
        wake_word_phrase: str | None,
    ) -> int:
        """Device requested a pipeline run; launch it as a background task.

        Returns:
            0 (no UDP audio port; audio arrives via the API callback).
        """
        logging.debug(
            "Pipeline starting with conversation_id=%s, flags=%s, audio_settings=%s, wake_word_phrase=%s",
            conversation_id, flags, audio_settings, wake_word_phrase,
        )
        # Device-triggered pipeline (wake word, etc.) starts earlier in the chain.
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            start_stage = PipelineStage.WAKE_WORD
        else:
            start_stage = PipelineStage.STT
        end_stage = PipelineStage.TTS
        logging.info(f"Starting pipeline from {start_stage} to {end_stage}")
        self.state = "running_pipeline"
        # Run the pipeline without blocking the API callback.
        asyncio.create_task(self.run_pipeline(wake_word_phrase))
        return 0

    async def run_pipeline(self, wake_word_phrase: str | None = None):
        """Run VAD -> STT -> intent -> TTS, emitting progress events throughout.

        Any exception is reported to the device as a VOICE_ASSISTANT_ERROR
        event; RUN_END is always sent at the end.
        """
        logging.info(f"Pipeline started using the wake word '{wake_word_phrase}'.")
        try:
            logging.debug(" > STT start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None)
            logging.debug(" > STT VAD start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START, None)

            # --- VAD: wait until the speaker stops talking -------------------
            logging.debug(" > VAD: Waiting for silence...")
            voice_detected = False       # result of the most recent VAD window
            was_voice_detected = False   # previous result, to detect transitions
            while True:
                # Snapshot the buffer length once per iteration: audio keeps
                # arriving concurrently via handle_audio, and re-reading the
                # length could make the branches below disagree (the original
                # could hit an unbound `voice_detected` this way).
                buffered = self.audio_queue.length()
                if buffered > self.MAX_BUFFER:
                    break
                if buffered >= self.VAD_WINDOW:
                    voice_detected = self.vad.detect_voice(self.audio_queue.get_last(self.VAD_WINDOW))
                    if voice_detected != was_voice_detected:
                        was_voice_detected = voice_detected
                        if not voice_detected:
                            # Transition voice -> silence: utterance finished.
                            logging.debug(" > VAD: Silence detected.")
                            break
                if buffered > self.NO_VOICE_TIMEOUT and not voice_detected:
                    # Enough audio buffered but nobody ever spoke: give up.
                    logging.debug(" > VAD: Silence detected (no voice).")
                    break
                await asyncio.sleep(0.1)
            logging.debug(" > STT VAD end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END, None)

            # --- STT ---------------------------------------------------------
            text = self.stt.transcribe(self.audio_queue.get())
            self.audio_queue.clear()
            logging.info(f" > STT: {text}")
            logging.debug(" > STT end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": text})

            # --- Intent (LLM) ------------------------------------------------
            logging.debug(" > Intent start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None)
            agent_response = self.agent.ask(text)["answer"]
            logging.info(f">Intent: {agent_response}")
            logging.debug(" > Intent end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, None)

            # --- TTS ---------------------------------------------------------
            logging.debug(" > TTS start")
            # Renamed from `TTS`, which shadowed the imported TTS class.
            tts_mode = "announce"
            if tts_mode == "announce":
                self.client.send_voice_assistant_event(
                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": agent_response}
                )
                voice = "default"
                # Cache synthesized audio by (voice, text) hash so repeated
                # answers are not re-generated.
                audio_hash_string = hashlib.md5((voice + ":" + agent_response).encode("utf-8")).hexdigest()
                audio_file_path = f"static/{audio_hash_string}.wav"
                if not os.path.exists(audio_file_path):
                    logging.debug(" > TTS: Generating audio...")
                    audio_segment = self.tts.generate(agent_response, voice=voice)
                    audio_segment.export(audio_file_path, format="wav")
                else:
                    # Cached file already on disk; the device fetches it over
                    # HTTP, so there is no need to load and re-export it.
                    logging.debug(" > TTS: Using cached audio.")
                self.client.send_voice_assistant_event(
                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
                    {"url": f"{self.base_url}/{audio_file_path}"},
                )
            elif tts_mode == "stream":
                await self._stream_tts_audio("debug")
            logging.debug(" > TTS end")
        except Exception as e:
            logging.error("Error: ")
            logging.error(e)
            logging.error(traceback.format_exc())
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, {"message": str(e)})
        logging.debug(" > Run end")
        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END, None)
        logging.info("Pipeline done")

    async def handle_pipeline_stop(self, abort: bool):
        """Device requested the pipeline to stop; mark ourselves idle."""
        logging.info("Pipeline stopped")
        self.state = "idle"

    async def handle_audio(self, data: bytes) -> None:
        """Handle incoming audio chunk from API (16 kHz, 16-bit, mono)."""
        self.audio_queue.put(AudioSegment(data=data, sample_width=2, frame_rate=16000, channels=1))

    async def handle_announcement_finished(self, event: aioesphomeapi.VoiceAssistantAnnounceFinished):
        """Log the outcome of a TTS announcement played on the device."""
        if event.success:
            logging.info("Announcement finished successfully.")
        else:
            logging.error("Announcement failed.")

    def handle_state_change(self, state):
        """Log entity state changes pushed by the device."""
        logging.debug("State change:")
        logging.debug(state)

    def handle_log(self, log):
        """Forward a device log line to our logger."""
        # Fixed: the original passed `log` as a %-argument with no placeholder
        # in the format string, which makes logging raise a formatting error.
        logging.info("Log from device: %s", log)

    def subscribe(self):
        """Register all device callbacks (states, logs, voice assistant)."""
        self.client.subscribe_states(self.handle_state_change)
        self.client.subscribe_logs(self.handle_log)
        self.client.subscribe_voice_assistant(
            handle_start=self.handle_pipeline_start,
            handle_stop=self.handle_pipeline_stop,
            handle_audio=self.handle_audio,
            handle_announcement_finished=self.handle_announcement_finished,
        )

    async def _stream_tts_audio(
        self,
        media_id: str,
        sample_rate: int = 16000,
        sample_width: int = 2,
        sample_channels: int = 1,
        samples_per_chunk: int = 1024,
    ) -> None:
        """Stream a local WAV file ("<media_id>.wav") to the device via the API.

        Only 16 kHz / 16-bit / mono WAVs are streamed; anything else is logged
        and skipped.  Chunks are paced to ~90% of real time to avoid racing
        ahead of the device's buffer (long audio may still overrun it; a media
        player is preferred for that).
        """
        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
        logging.info("TTS stream start")
        # Brief pause before streaming -- presumably gives the device time to
        # open its playback path; TODO confirm whether it is still needed.
        await asyncio.sleep(1)
        try:
            with wave.open(f"{media_id}.wav", "rb") as wav_file:
                if (
                    (wav_file.getframerate() != sample_rate)
                    or (wav_file.getsampwidth() != sample_width)
                    or (wav_file.getnchannels() != sample_channels)
                ):
                    logging.info(f"Can only stream 16Khz 16-bit mono WAV, got {wav_file.getparams()}")
                    return
                logging.info("Streaming %s audio samples", wav_file.getnframes())
                width = wav_file.getsampwidth()
                channels = wav_file.getnchannels()
                audio_bytes = wav_file.readframes(wav_file.getnframes())
            bytes_per_sample = width * channels
            bytes_per_chunk = bytes_per_sample * samples_per_chunk
            num_chunks = int(math.ceil(len(audio_bytes) / bytes_per_chunk))
            for i in range(num_chunks):
                offset = i * bytes_per_chunk
                chunk = audio_bytes[offset : offset + bytes_per_chunk]
                self.client.send_voice_assistant_audio(chunk)
                # Wait for 90% of the duration of the audio that was sent for
                # it to be played, so we do not overrun the device's buffer.
                samples_in_chunk = len(chunk) // (sample_width * sample_channels)
                seconds_in_chunk = samples_in_chunk / sample_rate
                logging.info(f"Samples in chunk: {samples_in_chunk}, seconds in chunk: {seconds_in_chunk}")
                await asyncio.sleep(seconds_in_chunk * 0.9)
        except asyncio.CancelledError:
            return  # Don't trigger state change
        finally:
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {})
            logging.info("TTS stream end")
async def start_http_server(port=8002, host="0.0.0.0"):
    """Serve the ./static directory (generated TTS WAV files) over HTTP.

    Args:
        port: TCP port to listen on.
        host: interface to bind to; defaults to all interfaces.
    """
    static_dir = os.path.join(os.path.dirname(__file__), "static")
    # The TTS pipeline writes into static/; create it up front so route
    # registration (and later WAV exports) succeed on a fresh checkout.
    os.makedirs(static_dir, exist_ok=True)
    app = web.Application()
    app.router.add_static("/static/", path=static_dir)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()
    logging.info(f"HTTP server started at http://{host}:{port}/")
async def main():
    """Connect to an ESPHome device and service voice-assistant pipelines forever."""
    await start_http_server()
    satellite = Satellite()
    await satellite.connect()
    logging.info("Connected to ESPHome voice assistant.")
    satellite.subscribe()
    satellite.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START, None)
    logging.info("Listening for events...")
    # All work happens in the subscription callbacks; just keep the loop
    # alive.  Waiting on a never-set Event avoids the 10 Hz wakeups of the
    # original `while True: await asyncio.sleep(0.1)` polling loop.
    await asyncio.Event().wait()
# Script entry point: run the bridge until interrupted.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the bridge; exit quietly.
        pass