Pierre Wessman 2025-01-16 16:22:58 +01:00
commit 7b45d19308
38 changed files with 2004 additions and 0 deletions

13
.gitignore vendored Normal file

@ -0,0 +1,13 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
config.yml
debug.py
notebook.ipynb
.vscode/
secrets.yaml

1
agent-api/.cache Normal file

@ -0,0 +1 @@
{"access_token": "BQDSGB6rh1-6-bc7wxFTZ2_Idjr6W-xLSZEMizMF6wNXew6_2M4DyWeSaVovZZLPdVv0aQOU8LqIQqQSNr5OwNXDpRk6icpcC6HKYGaZ7I_U9PRWbIOWE-QUakMcAmMFTWwO3FJEBRDYARR43cWTeRn8w9MHG48V5awv2A1447yz8kd4cUk8lAOTbcCeMnpxjJJQqqaigIWAy8NF7ghJXr4dibKj2JtTP5yECA", "token_type": "Bearer", "expires_in": 3600, "scope": "user-library-read", "expires_at": 1736781939, "refresh_token": "AQBaAMOoUVzcichmSIgDC4_CvkXXNNwfjnSbwpypzuaaCQqcQw-_2VP6Rkayp08Y-5gSEbEGVWc0-1CP4-JmQI2kQ_0UmJT5g2G3LvOjufSYowzXRxeedIOWcwbRvjSA1XI"}

161
agent-api/agent.py Normal file

@ -0,0 +1,161 @@
import datetime
import json
from jinja2 import Template
import logging
from typing import Union
from db import Database
from llm import LLM
def relative_to_actual_dates(text: str):
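    """
    Replace relative date expressions (e.g. "yesterday") in `text` with absolute YYYY-MM-DD dates.
    """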
from dateparser.search import search_dates
    matches = search_dates(text, languages=["en", "sv"])  # "sv" = Swedish
if matches:
for match in matches:
text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
return text
class Agent:
def __init__(self, config: dict) -> None:
self.config = config
self.db = Database("rag")
self.llm = LLM(config=self.config["llm"])
def flush(self):
"""
        Flushes the agent's knowledge.
"""
logging.info("Truncating database")
self.db.cur.execute("TRUNCATE TABLE items")
self.db.conn.commit()
def insert(self, text):
"""
Append knowledge.
"""
logging.info(f"Inserting item into embedding table ({text}).")
#logging.info(f"\tReplacing relative dates with actual dates.")
#text = relative_to_actual_dates(text)
#logging.info(f"\tDone. text: {text}")
vector = self.llm.embedding(text)
vector_padded = vector + ([0]*(2048-len(vector)))
self.db.cur.execute(
"INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
(text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
)
self.db.conn.commit()
def keywords(self, query: str, n_keywords: int = 5):
"""
Suggest keywords related to a query.
"""
sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
prompt = (
f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
+ f"Do not add any context or text other than the json list of {n_keywords} words."
)
response = self.llm.query(prompt, system_msg=sys_msg)
keywords = json.loads(response)
return keywords
def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
"""
Retrieve relevant knowledge.
Parameters:
- n_retrieve (int): How many notes to retrieve.
- n_rerank (int): How many notes to keep after reranking.
"""
logging.info(f"Retrieving knowledge from database using the query: '{query}'")
logging.debug(f"Using embedding model on '{query}'.")
vector_keywords = self.llm.embedding(query)
logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
logging.debug("Querying database")
vector_padded = vector_keywords + ([0]*(2048-len(vector_keywords)))
self.db.cur.execute(
f"SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT {n_retrieve}",
(vector_padded,),
)
db_response = self.db.cur.fetchall()
knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
logging.debug("Database returned: ")
logging.debug(knowledge_arr)
if n_retrieve > n_rerank:
logging.info("Reranking the results")
reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
reranked_knowledge_arr = sorted(
knowledge_arr,
key=lambda x: (
reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
),
)
reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
logging.debug("Reranked results: ")
logging.debug(reranked_knowledge_arr)
return reranked_knowledge_arr
else:
return knowledge_arr
def ask(
self,
question: str,
person: str,
history: list = None,
n_retrieve: int = 5,
n_rerank=3,
tools: Union[None, list] = None,
prompts: Union[None, list] = None,
):
"""
Ask the agent a question.
Parameters:
- question (str): The question to ask
- person (str): Who asks the question?
- n_retrieve (int): How many notes to retrieve.
- n_rerank (int): How many notes to keep after reranking.
"""
logging.info(f"Answering question: {question}")
if n_retrieve > 0:
knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
with open('templates/knowledge.en.j2', 'r') as file:
template = Template(file.read())
knowledge_str = template.render(
knowledge = knowledge,
)
else:
knowledge_str = ""
with open('templates/ask.en.j2', 'r') as file:
template = Template(file.read())
prompt = template.render(
today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
knowledge = knowledge_str,
plugin_prompts = prompts,
person = person
)
logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
answer = self.llm.query(
question,
system_msg=prompt,
tools=tools,
history=history
)
logging.info(f"Got answer from LLM: '{answer}'.")
return answer

127
agent-api/backend.py Normal file

@ -0,0 +1,127 @@
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import importlib
import logging
import os
from pydantic import BaseModel
import yaml
from agent import Agent
PLUGINS_FOLDER = "plugins"
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
with open("config.yml", "r", encoding='utf8') as file:
config = yaml.safe_load(file)
def load_plugins(config):
tools = []
prompts = []
for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder:
logging.info(f"Loading plugin: {plugin_folder}")
module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
module = importlib.import_module(module_path)
plugin = module.Plugin(config)
tools += plugin.tools()
prompt = plugin.prompt()
if prompt is not None:
prompts.append(prompt)
return tools, prompts
tools, prompts = load_plugins(config)
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
agent = Agent(config)
class Query(BaseModel):
query: str
class Message(BaseModel):
role: str
content: str
class Question(BaseModel):
question: str
person: str = "User"
history: list[Message] = None
@app.get("/ping")
def ping():
"""
# Ping
Check if the API is up.
"""
return {"status": "ok"}
@app.post("/ask")
def ask(q: Question):
"""
# Ask
Ask the agent a question.
## Parameters
- **question**: the question to ask.
- **person**: who is asking the question?
"""
history = None
if q.history:
        history = [m.model_dump() for m in q.history]
answer = agent.ask(
question=q.question,
person="Pierre",
history=history,
tools=tools,
prompts=prompts,
n_retrieve=0,
n_rerank=0
)
return {"question": q.question, "answer": answer, "history": q.history}
@app.post("/retrieve")
def retrieve(q: Query):
"""
# Retrieve
Retrieve knowledge related to a query.
## Parameters
- **query**: the question.
"""
result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5)
return {"query": q.query, "answer": result}
@app.post("/learn")
def retrieve(q: Query):
"""
# Learn
Learn the agent something. E.g. *"John likes to play volleyball".*
## Parameters
- **query**: the note.
"""
answer = agent.insert(text=q.query)
return {"text": q.query}
@app.get("/flush")
def flush():
"""
# Flush
Remove all knowledge from the agent.
"""
agent.flush()
return {"message": "Knowledge flushed."}


@ -0,0 +1,25 @@
server:
  url: "http://127.0.0.1:8000"
llm:
  use_local_chat: false
  use_local_embedding: false
  local_model_dir: ""
  local_embedding_model: "intfloat/e5-large-v2"
  local_rerank_model: "cross-encoder/stsb-distilroberta-base"
  openai:
    base_url: ""  # leave empty to use the default OpenAI endpoint
    embedding_model: "text-embedding-3-small"
    chat_model: "gpt-3.5-turbo-0125"
  temperature: 0.1
homeassistant:
  url: "http://localhost:8123"
  token: ""
plugins:
  calendar:
    default_calendar: "calendar.my_calendar"
  todo:
    default_list: "todo.todo"
  music:
    default_speaker: ""  # media_player entity used by the music plugin
  vasttrafik:
    key: ""
    secret: ""
    default_from_station: "Brunnsparken"
    default_to_station: "Centralstationen"
    delay: 0 # minutes

32
agent-api/db.py Normal file

@ -0,0 +1,32 @@
import logging
from pgvector.psycopg2 import register_vector
import psycopg2
class Database:
    def __init__(
        self,
        database,
        user="postgres",
        password="postgres",
        host="localhost",
        port="5433",
    ) -> None:
        logging.info("Connecting to database")
        self.conn = psycopg2.connect(
            database=database,
            user=user,
            password=password,
            host=host,
            port=port,
        )
register_vector(self.conn)
self.cur = self.conn.cursor()
self.cur.execute("SELECT version();")
logging.info(" DB Version: %s", self.cur.fetchone()[0])
logging.info(" psycopg2 Version: %s", psycopg2.__version__)
def close(self):
logging.info("Closing connection to database")
self.cur.close()
self.conn.close()

7
agent-api/db.sql Normal file

@ -0,0 +1,7 @@
CREATE TABLE public.items (
id bigserial NOT NULL,
embedding vector(2048) NULL,
"text" varchar(1024) NULL,
date_added date NULL,
CONSTRAINT items_pkey PRIMARY KEY (id)
);


@ -0,0 +1,15 @@
version: '3.8'
services:
db:
image: pgvector/pgvector:pg16
restart: always
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
ports:
- '5433:5432'
volumes:
- db:/var/lib/postgresql/data
volumes:
db:
driver: local

41
agent-api/frontend.py Normal file

@ -0,0 +1,41 @@
import gradio as gr
import logging
import requests
import yaml
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
with open("config.yml", "r") as file:
config = yaml.safe_load(file)
BASE_URL = config["server"]["url"]
def chat(message, history):
messages = []
    for (user_msg, bot_msg) in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
logging.info(f"Sending chat message to backend ('{message}').")
response = requests.post(BASE_URL + "/ask", json={"question": message, "person": "Pierre", "history": messages})
if response.status_code == 200:
logging.info("Response:")
logging.info(response.json())
data = response.json()
return data["answer"]
    else:
        logging.error("Request failed with status %s: %s", response.status_code, response.text)
        return f"Error: request failed with status {response.status_code}."
ui = gr.ChatInterface(chat, title="Chat")
if __name__ == "__main__":
logging.info("Launching frontend.")
ui.launch(
#share=False,
server_name='0.0.0.0',
)

3
agent-api/install.md Normal file

@ -0,0 +1,3 @@
# llama.cpp
```
set CMAKE_ARGS=-DLLAMA_CUBLAS=on
pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python==0.2.39
```

37
agent-api/llama_server.py Normal file

@ -0,0 +1,37 @@
"""Example FastAPI server for llama.cpp.
To run this example:
```bash
pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/...
```
Then run:
```
uvicorn --factory llama_cpp.server.app:create_app --reload
```
or
```
python3 -m llama_cpp.server
```
Then visit http://localhost:8000/docs to see the interactive API docs.
To actually see the implementation of the server, see llama_cpp/server/app.py
"""
import os
import uvicorn
from llama_cpp.server.app import create_app
if __name__ == "__main__":
app = create_app()
uvicorn.run(
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
)

252
agent-api/llm.py Normal file

@ -0,0 +1,252 @@
import concurrent.futures
import json
from llama_cpp.llama import Llama
import logging
import numpy as np
from openai import OpenAI
from sentence_transformers import SentenceTransformer, CrossEncoder
from typing import Union, List
class BaseChat:
def __init__(self, config: dict) -> None:
self.config = config
def prepare_function_calling(self, tools):
# prepare function calling
function_map = {}
if tools is not None and len(tools) > 0:
for tool in tools:
function_name = tool["function"]["name"]
function_map[function_name] = tool["function"]["function_to_call"]
functions = []
for tool in tools:
fn = tool["function"]
functions.append(
{"type": tool["type"], "function": {x: fn[x] for x in fn if x != "function_to_call"}}
)
logging.info(f"{len(tools)} available functions:")
logging.info(function_map.keys())
else:
functions = None
return function_map, functions
class OpenAIChat(BaseChat):
def __init__(self, config: dict) -> None:
self.config = config
if self.config["openai"]["base_url"] is not None and self.config["openai"]["base_url"] != "":
base_url = self.config["openai"]["base_url"]
else:
base_url = None
self.client = OpenAI(base_url=base_url)
def chat(self, messages, tools) -> str:
function_map, functions = self.prepare_function_calling(tools)
logging.debug("Sending request to OpenAI.")
llm_response = self.client.chat.completions.create(
model=self.config["openai"]["chat_model"],
messages=messages,
tools=functions,
temperature=self.config["temperature"],
tool_choice="auto" if functions is not None else None,
)
logging.debug("LLM response:")
logging.debug(llm_response.choices)
if llm_response.choices[0].message.tool_calls:
# Handle function calls
followup_response = self.execute_function_call(
llm_response.choices[0].message, function_map, messages
)
return followup_response.choices[0].message.content.strip()
else:
return llm_response.choices[0].message.content.strip()
def execute_function_call(self, message, function_map: dict, messages: list) -> str:
"""
        Executes function calls embedded in an LLM message in parallel, and returns an LLM response based on the results.
        Parameters:
        - message: LLM message containing the function calls.
        - function_map (dict): dict of {"function_name": function()}
        - messages (list): message history
"""
tool_calls = message.tool_calls
logging.info(f"Got {len(tool_calls)} function call(s).")
def execute_single_tool_call(tool_call):
"""Helper function to execute a single tool call"""
logging.info(f"Attempting to execute function requested by LLM ({tool_call.function.name}, {tool_call.function.arguments}).")
if tool_call.function.name in function_map:
function_to_call = function_map[tool_call.function.name]
args = json.loads(tool_call.function.arguments)
logging.debug(f"Calling function {tool_call.function.name} with args: {args}")
function_response = function_to_call(**args)
logging.debug(function_response)
return {
"role": "function",
"tool_call_id": tool_call.id,
"name": tool_call.function.name,
"content": function_response,
}
else:
logging.info(f"{tool_call.function.name} not in function_map")
logging.info(function_map.keys())
return None
        # The assistant message that requested the tool calls must precede the
        # tool results in the conversation history.
        messages.append(message)
        # Execute tool calls in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submit all tool calls to the executor
future_to_tool_call = {
executor.submit(execute_single_tool_call, tool_call): tool_call
for tool_call in tool_calls
}
# Collect results as they complete
for future in concurrent.futures.as_completed(future_to_tool_call):
result = future.result()
if result is not None:
messages.append(result)
logging.debug("Functions called, sending the results to LLM.")
llm_response = self.client.chat.completions.create(
model=self.config["openai"]["chat_model"],
messages=messages,
temperature=self.config["temperature"],
)
logging.debug("Got response from LLM:")
logging.debug(llm_response)
return llm_response
class LMStudioChat(BaseChat):
def __init__(self, config: dict) -> None:
self.config = config
self.client = OpenAI(base_url="http://localhost:1234/v1", api_key="not-needed")
def chat(self, messages, tools) -> str:
logging.info("Sending request to local LLM.")
llm_response = self.client.chat.completions.create(
model="",
messages=messages,
temperature=self.config["temperature"],
)
return llm_response.choices[0].message.content.strip()
class LlamaChat(BaseChat):
def __init__(self, config: dict) -> None:
self.config = config
self.client = Llama(
self.config["local_model_dir"] + "TheBloke/Mistral-7B-Instruct-v0.1-GGUF/mistral-7b-instruct-v0.1.Q6_K.gguf",
n_gpu_layers=32,
n_ctx=2048,
verbose=False
)
def chat(self, messages: list, response_format: dict = None, tools: list = None) -> str:
if tools is not None:
logging.warning("Tools was provided to LlamaChat, but it's not yet supported.")
logging.info("Sending request to local LLM.")
logging.info(messages)
llm_response = self.client.create_chat_completion(
messages = messages,
response_format = response_format,
temperature = self.config["temperature"],
)
return llm_response["choices"][0]["message"]["content"].strip()
class LLM:
def __init__(self, config: dict) -> None:
"""
LLM constructor.
Parameters:
- config (dict): llm-config parsed from the llm part of config.yml
"""
self.config = config
if self.config["use_local_chat"]:
self.chat_client = LlamaChat(self.config)
else:
self.chat_client = OpenAIChat(self.config)
if self.config["use_local_embedding"]:
self.embedding_model = self.config["local_embedding_model"]
self.rerank_model = self.config["local_rerank_model"]
self.embedding_client = SentenceTransformer(self.embedding_model)
else:
self.embedding_model = self.config["openai"]["embedding_model"]
self.rerank_model = None
self.embedding_client = OpenAI()
def query(
self,
user_msg: str,
system_msg: Union[None, str] = None,
history: list = [],
tools: Union[None, list] = None,
):
"""
Query the LLM
Parameters:
- user_msg (str): query from the user (will be appended to messages as {"role": "user"})
        - system_msg (str): system message (will be prepended to messages as {"role": "system"})
- history (list): optional, list of messages to inject between system_msg and user_msg.
- tools: optional, list of functions that may be called by the LLM. Example:
{
"type": "function",
"function": {
"name": "get_current_weather",
"function": get_current_weather,
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
"required": ["location", "format"],
},
},
}
"""
messages = []
if system_msg is not None:
messages.append({"role": "system", "content": system_msg})
if history and len(history) > 0:
logging.info("History")
logging.info(history)
messages += history
messages.append({"role": "user", "content": user_msg})
logging.info(f"Sending request to LLM with {len(messages)} messages.")
answer = self.chat_client.chat(messages=messages, tools=tools)
return answer
def embedding(self, text: str):
if self.config["use_local_embedding"]:
            embedding = self.embedding_client.encode(text)
            return embedding.tolist()  # len = 1024; a list so it can be padded by concatenation
else:
llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")])
return llm_response.data[0].embedding
def rerank(self, text: str, corpus: List[str], n_rerank: int = 5):
if self.rerank_model is not None and len(corpus) > 0:
sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus]
logging.info("Reranking:")
logging.info(sentence_combinations)
similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations)
result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))]
return result[:n_rerank]
else:
return corpus


@ -0,0 +1,68 @@
import inspect
from plugins.homeassistant import HomeAssistant
class BasePlugin:
def __init__(self, config: dict) -> None:
self.config = config
self.homeassistant = HomeAssistant(config)
def prompt(self) -> str | None:
return
def tools(self) -> list:
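        # Auto-discover methods prefixed with `tool_` and expose each one as an
        # OpenAI-style function definition: the docstring becomes the description
        # and, when the method takes a pydantic `input` model, its JSON schema
        # becomes the parameter schema.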
tools = []
for tool_fn in self._list_tool_methods():
tool_fn_metadata = self._get_function_metadata(tool_fn)
if "input" in tool_fn_metadata["parameters"]:
json_schema = tool_fn_metadata["parameters"]["input"].annotation.model_json_schema()
del json_schema["title"]
for k, v in json_schema["properties"].items():
del json_schema["properties"][k]["title"]
fn_to_call = self._create_function_with_kwargs(
tool_fn_metadata["parameters"]["input"].annotation, tool_fn
)
else:
json_schema = {}
fn_to_call = tool_fn
tools.append(
{
"type": "function",
"function": {
"name": tool_fn_metadata["name"],
"function_to_call": fn_to_call,
"description": tool_fn_metadata["docstring"],
"parameters": json_schema,
},
}
)
return tools
def _get_function_metadata(self, func):
function_name = func.__name__
docstring = inspect.getdoc(func)
signature = inspect.signature(func)
parameters = signature.parameters
metadata = {"name": function_name, "docstring": docstring or "", "parameters": parameters}
return metadata
def _create_function_with_kwargs(self, model_cls, original_function):
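        # Wrap `original_function` so that keyword arguments coming back from the
        # LLM are first validated/coerced through the pydantic model `model_cls`.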
def dynamic_function(**kwargs):
model_instance = model_cls(**kwargs)
return original_function(model_instance)
return dynamic_function
def _list_tool_methods(self) -> list:
attributes = dir(self)
tool_functions = [
getattr(self, attr)
for attr in attributes
if callable(getattr(self, attr)) and attr.startswith("tool_")
]
return tool_functions


@ -0,0 +1,26 @@
import asyncio
import json
from ..base_plugin import BasePlugin
class Plugin(BasePlugin):
def __init__(self, config: dict) -> None:
super().__init__(config=config)
def tool_get_calendar_events(self, entity_id=None, days=3):
"""
        Get the user's calendar events.
"""
if entity_id is None:
entity_id = self.config["plugins"]["calendar"]["default_calendar"]
response = asyncio.run(
self.homeassistant.send_command(
"call_service",
domain="calendar",
service="get_events",
target={"entity_id": entity_id},
service_data={"duration": {"days": days}},
return_response=True,
)
)
return json.dumps(response)


@ -0,0 +1,51 @@
import json
from typing import Any
import requests
from aiohttp import ClientSession
from hass_client import HomeAssistantClient
class HomeAssistant:
def __init__(self, config: dict) -> None:
self.config = config
self.ha_config = self.call_api("config")
def call_api(self, endpoint, payload=None):
"""
Call the REST API
"""
base_url = self.config["homeassistant"]["url"]
headers = {
"Authorization": f"Bearer {self.config['homeassistant']['token']}",
"content-type": "application/json",
}
if payload is None:
response = requests.get(f"{base_url}/api/{endpoint}", headers=headers)
else:
response = requests.post(f"{base_url}/api/{endpoint}", headers=headers, json=payload)
if response.status_code == 200:
            try:
                return response.json()
            except ValueError:
                pass
            try:
                # Some endpoints return Python-repr style payloads; try to coerce them to JSON.
                return json.loads(response.text.replace("'", '"').replace("None", "null"))
            except ValueError:
                return response.text
else:
return {"status": "error", "message": response.text}
async def send_command(self, command: str, **kwargs: dict[str, Any]):
"""
Send command using the WebSocket API
"""
async with ClientSession() as session:
async with HomeAssistantClient(
self.config["homeassistant"]["url"] + "/api/websocket",
self.config["homeassistant"]["token"],
session,
) as client:
response = await client.send_command(command, **kwargs)
return response


@ -0,0 +1,73 @@
import json
from ..base_plugin import BasePlugin
class Plugin(BasePlugin):
def __init__(self, config: dict) -> None:
super().__init__(config=config)
self.light_map = self.get_light_map()
def prompt(self):
prompt = "These are the lights available:\n" + json.dumps(
self.light_map, indent=2, ensure_ascii=False
)
return prompt
def tools(self):
tools = [
{
"type": "function",
"function": {
"name": "control_light",
"function_to_call": self.tool_control_light,
"description": "Control lights in users a home",
"parameters": {
"type": "object",
"properties": {
"entity_id": {
"type": "string",
"enum": self.available_lights(),
"description": "What light entity to control",
},
"state": {
"type": "string",
"enum": ["on", "off"],
"description": "on or off",
},
},
"required": ["room", "state"],
},
},
}
]
return tools
def tool_control_light(self, entity_id: str, state: str):
self.homeassistant.call_api(f"services/light/turn_{state}", payload={"entity_id": entity_id})
return json.dumps({"status": "success", "message": f"{entity_id} was turned {state}."})
def available_lights(self):
available_lights = []
for room in self.light_map:
for light in self.light_map[room]:
available_lights.append(light["entity_id"])
return available_lights
def get_light_map(self):
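        # Render a Home Assistant Jinja template that lists every light entity
        # with its friendly name and area, then group the results by room.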
template = (
"{% for state in states.light %}\n"
+ '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}", "room": "{{area_name(state.entity_id)}}"}\n'
+ "{% endfor %}\n"
)
response = self.homeassistant.call_api(f"template", payload={"template": template})
light_map = {}
for item_str in response.split("\n\n"):
item_dict = json.loads(item_str)
if item_dict["room"] not in light_map:
light_map[item_dict["room"]] = []
light_map[item_dict["room"]].append(
{"entity_id": item_dict["entity_id"], "name": item_dict["name"]}
)
return light_map


@ -0,0 +1 @@
{"access_token": "BQC4auYl6DX-_xbS0zDjbv2_OZGbbdvyWLDa7I6EBE91GhBAExs3TGmhBNCgwS3P0WJwkUyL-kRRH2x-80QpWaRdLhEm0Q7Tfm8vzsV1rRBToy87Vm1jJJX3S08bYQwD3UvxFlaOhGyyWoSa37JnPkmS9T8d2GTPeHne1HaYR5KEJeYu9stixoONiTjHidQJi_7hOozIqwUkTC7TFWBWezJfa_7GlkE", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "AQDSGqpEaLoyR_A-s78bFVHYE1F2PIbA_cfgrgMt7jZA-A3xmEo5V7GazlgB-okD0JHX-w_fVc-n0_b8nvvZUz5PT6iCd7jRHLVoRA-h6XbQKcSAUVwJGL2djOTfnC4wOIE", "scope": "user-library-read", "expires_at": 1736686521}


@ -0,0 +1,59 @@
import json
import logging
import spotipy
from pydantic import BaseModel, Field
from spotipy.oauth2 import SpotifyOAuth
from ..base_plugin import BasePlugin
class PlayMusicInput(BaseModel):
query: str = Field(..., description="Can be a song, artist, album, or playlist")
class Plugin(BasePlugin):
def __init__(self, config: dict) -> None:
super().__init__(config=config)
self.spotify = spotipy.Spotify(
auth_manager=SpotifyOAuth(scope="user-library-read", redirect_uri="http://localhost:8080")
)
def _search(self, query: str, limit: int = 10):
_result = self.spotify.search(query, limit=limit)
result = []
for track in _result["tracks"]["items"]:
artists = [artist["name"] for artist in track["artists"]]
result.append(
{
"name": track["name"],
"artists": artists,
"uri": track["uri"]
}
)
return result
def tool_play_music(self, input: PlayMusicInput):
"""
Play music using a search query.
"""
track = self._search(input.query, limit=1)[0]
logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
payload = {
"entity_id": self.config["plugins"]["music"]["default_speaker"],
"media_content_id": track["uri"],
"media_content_type": "music",
"enqueue": "play",
}
result = self.homeassistant.call_api(f"services/media_player/play_media", payload=payload)
return json.dumps({"status": "success", "message": f"Playing music.", "track": track})
def tool_stop_music(self):
"""
Stop playback of music.
"""
self.homeassistant.call_api(
f"services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
)
return json.dumps({"status": "success", "message": f"Music paused."})


@ -0,0 +1,3 @@
# Plugins
Each folder in /plugins is a plugin.
Class files directly within /plugins (e.g. homeassistant.py) are helper classes that plugins can use in order to avoid duplicating code.
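A minimal plugin could look like the sketch below (a hypothetical `example` plugin; the `Plugin` class name, the `tool_` method prefix, and the optional `prompt()` hook are what `backend.py` and `base_plugin.py` expect):

```python
# plugins/example/plugin.py -- hypothetical minimal plugin
import json

from ..base_plugin import BasePlugin


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

    def prompt(self):
        # Optional extra context injected into the system prompt; return None to skip.
        return None

    def tool_say_hello(self):
        """
        Return a friendly greeting.
        """
        # Methods prefixed with `tool_` are auto-discovered by BasePlugin.tools()
        # and exposed to the LLM; the docstring becomes the tool description.
        return json.dumps({"status": "success", "message": "Hello!"})
```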


@ -0,0 +1,74 @@
import json
from smhi.smhi_lib import Smhi
from ..homeassistant import HomeAssistant
class Plugin:
def __init__(self, config: dict) -> None:
self.config = config
self.homeassistant = HomeAssistant(config)
self.station = Smhi(
self.homeassistant.ha_config["longitude"], self.homeassistant.ha_config["latitude"]
)
self.weather_conditions = {
1: "Clear sky",
2: "Nearly clear sky",
3: "Variable cloudiness",
4: "Halfclear sky",
5: "Cloudy sky",
6: "Overcast",
7: "Fog",
8: "Light rain showers",
9: "Moderate rain showers",
10: "Heavy rain showers",
11: "Thunderstorm",
12: "Light sleet showers",
13: "Moderate sleet showers",
14: "Heavy sleet showers",
15: "Light snow showers",
16: "Moderate snow showers",
17: "Heavy snow showers",
18: "Light rain",
19: "Moderate rain",
20: "Heavy rain",
21: "Thunder",
22: "Light sleet",
23: "Moderate sleet",
24: "Heavy sleet",
25: "Light snowfall",
26: "Moderate snowfall",
27: "Heavy snowfall",
}
def prompt(self):
return None
def tools(self):
tools = [
{
"type": "function",
"function": {
"name": "get_weather_forecast",
"function_to_call": self.get_weather_forecast,
"description": "Get weather forecast for the following 24 hours.",
},
}
]
return tools
def get_weather_forecast(self, hours=24):
forecast = []
for hour in self.station.get_forecast_hour()[:hours]:
forecast.append(
{
"time": hour.valid_time.strftime("%Y-%m-%d %H:%M"),
"weather": self.weather_conditions[hour.symbol],
"temperature": hour.temperature,
# "cloudiness": hour.cloudiness,
"total_precipitation": hour.total_precipitation,
"wind_speed": hour.wind_speed,
}
)
return json.dumps(forecast)


@ -0,0 +1 @@
smhi-pkg>=1.0.16


@ -0,0 +1,107 @@
import asyncio
import json
from ..base_plugin import BasePlugin
class Plugin(BasePlugin):
def __init__(self, config: dict) -> None:
super().__init__(config=config)
def prompt(self):
prompt = "These are the todo lists available:\n" + json.dumps(
self.get_todo_lists(), indent=2, ensure_ascii=False
)
return prompt
def tools(self):
tools = [
{
"type": "function",
"function": {
"name": "get_todo_list_items",
"function_to_call": self.get_todo_list_items,
"description": "Get items from todo-list.",
"parameters": {
"type": "object",
"properties": {
"entity_id": {
"type": "string",
"description": "What list to get items from.",
}
},
"required": ["entity_id"],
},
},
},
{
"type": "function",
"function": {
"name": "add_todo_item",
"function_to_call": self.add_todo_item,
"description": "Add item to todo-list.",
"parameters": {
"type": "object",
"properties": {
"entity_id": {
"type": "string",
"description": "What list to add item to.",
},
"item": {
"type": "string",
"description": "Description of the item.",
},
},
"required": ["entity_id", "item"],
},
},
},
]
return tools
def get_todo_list_items(self, entity_id: str = None):
"""
Get items from todo-list.
"""
if entity_id is None:
entity_id = self.config["plugins"]["todo"]["default_list"]
response = asyncio.run(
self.homeassistant.send_command(
"call_service",
domain="todo",
service="get_items",
target={"entity_id": entity_id},
service_data={"status": "needs_action"},
return_response=True,
)
)
return json.dumps(response["response"])
def add_todo_item(self, item: str, entity_id: str = None):
"""
Add item to todo-list.
"""
if entity_id is None:
entity_id = self.config["plugins"]["todo"]["default_list"]
asyncio.run(
self.homeassistant.send_command(
"call_service",
domain="todo",
service="add_item",
target={"entity_id": entity_id},
service_data={"item": item},
)
)
return json.dumps({"status": f"{item} added to list."})
def get_todo_lists(self):
template = (
"{% for state in states.todo %}\n"
+ '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}"}\n'
+ "{% endfor %}\n"
)
response = self.homeassistant.call_api(f"template", payload={"template": template})
todo_lists = {}
for item_str in response.split("\n\n"):
item_dict = json.loads(item_str)
todo_lists[item_dict["name"]] = item_dict["entity_id"]
return todo_lists


@ -0,0 +1,98 @@
import json
from datetime import datetime
import vasttrafik
from ..homeassistant import HomeAssistant
class Plugin:
def __init__(self, config: dict) -> None:
self.config = config
self.homeassistant = HomeAssistant(config)
self.vasttrafik = vasttrafik.JournyPlanner(
key=self.config["plugins"]["vasttrafik"]["key"],
secret=self.config["plugins"]["vasttrafik"]["secret"],
)
self.default_from_station_id = self.get_station_id(
self.config["plugins"]["vasttrafik"]["default_from_station"]
)
self.default_to_station_id = self.get_station_id(
self.config["plugins"]["vasttrafik"]["default_to_station"]
)
def prompt(self):
from_station = self.config["plugins"]["vasttrafik"]["default_from_station"]
to_station = self.config["plugins"]["vasttrafik"]["default_to_station"]
return "# Public transportation #\nHome station is: %s, and the default to station to use is %s" % (
from_station,
to_station,
)
def tools(self):
tools = [
{
"type": "function",
"function": {
"name": "search_trip",
"function_to_call": self.search_trip,
"description": "Search for trips by public transportation.",
"parameters": {
"type": "object",
"properties": {
"from_station": {
"type": "string",
"description": "Station to travel from.",
},
"to_station": {
"type": "string",
"description": "Station to travel to.",
},
},
"required": ["from_station", "to_station"],
},
},
}
]
return tools
def search_trip(self, from_station=None, to_station=None):
if from_station is None:
from_station_id = self.default_from_station_id
else:
from_station_id = self.get_station_id(from_station)
if to_station is None:
to_station_id = self.default_to_station_id
else:
to_station_id = self.get_station_id(to_station)
trip = self.vasttrafik.trip(from_station_id, to_station_id)
response = []
for departure in trip:
try:
departure_dt = datetime.strptime(
departure["tripLegs"][0]["origin"]["estimatedTime"][:19], "%Y-%m-%dT%H:%M:%S"
)
                minutes_to_departure = (departure_dt - datetime.now()).total_seconds() / 60
if minutes_to_departure >= self.config["plugins"]["vasttrafik"]["delay"]:
response.append(
{
"origin": {
"stationName": departure["tripLegs"][0]["origin"]["stopPoint"]["name"],
# "plannedTime": departure["tripLegs"][0]["origin"]["plannedTime"],
"estimatedDepartureTime": departure["tripLegs"][0]["origin"]["estimatedTime"],
},
"destination": {
"stationName": departure["tripLegs"][0]["destination"]["stopPoint"]["name"],
"estimatedArrivalTime": departure["tripLegs"][0]["estimatedArrivalTime"],
},
}
)
except Exception as e:
print(f"Error: {e}")
assert len(response) > 0, "No trips found."
return json.dumps(response, ensure_ascii=False)
def get_station_id(self, location_name: str) -> str:
station_id = self.vasttrafik.location_name(location_name)[0]["gid"]
return station_id


@ -0,0 +1 @@
vtjp>=0.2.1

2
agent-api/pyproject.toml Normal file

@ -0,0 +1,2 @@
[tool.black]
line-length = 110

21
agent-api/readme.md Normal file

@ -0,0 +1,21 @@
## Start postgres with pgvector
```bash
docker-compose up -d
```
Then create a new db called "rag", and run this SQL:
```sql
CREATE EXTENSION vector;
```
Finally, run the contents of `db.sql`.
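A sketch of these steps with `createdb`/`psql`, assuming the Postgres container from `docker-compose.yml` (user `postgres`, host port `5433`):

```bash
# create the database and enable pgvector (password is "postgres" per docker-compose.yml)
createdb -h localhost -p 5433 -U postgres rag
psql -h localhost -p 5433 -U postgres -d rag -c "CREATE EXTENSION vector;"
# create the items table
psql -h localhost -p 5433 -U postgres -d rag -f db.sql
```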
## Run backend
```bash
conda activate llm-api
python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
```
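With the backend running, the `/ask` endpoint from `backend.py` can be tried directly; the reply depends on the configured LLM and plugins. For example:

```bash
curl -X POST http://127.0.0.1:8000/ask \
  -H "Content-Type: application/json" \
  -d '{"question": "What can you do?", "person": "User"}'
```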
## Run frontend
```bash
python frontend.py
```


@ -0,0 +1,11 @@
fastapi
uvicorn[standard]
openai>=1.11.1
sentence_transformers
dateparser
pgvector
psycopg2
pyyaml
gradio
hass-client>=1.0.1
llama_cpp_python>=0.2.44

BIN
agent-api/static/tada.wav Normal file

Binary file not shown.


@ -0,0 +1,12 @@
You are a helpful assistant. Your replies will be converted to speech, so keep it short and concise and avoid too many special characters.
Today is {{ today }}.
{% if knowledge is not none %}
{{ knowledge }}
{% endif %}
{% for prompt in plugin_prompts %}
- {{ prompt }}
{% endfor %}
Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a Swedish question.


@ -0,0 +1,5 @@
Here are some saved notes that might help you answer any question the user might have:
{% for note in knowledge %}
- {{ note }}
{% endfor %}
Never make notes of knowledge you already have (i.e. the list above).

203
esphome-va-bridge/audio.py Normal file

@ -0,0 +1,203 @@
import datetime as dt
import io
from typing import Iterator
import numpy as np
import pydub
import scipy.io.wavfile as wavfile
import torch
from elevenlabs.client import ElevenLabs
from faster_whisper import WhisperModel
from silero_vad import get_speech_timestamps, load_silero_vad
class AudioSegment(pydub.AudioSegment):
"""
Wrapper class for pydub.AudioSegment
"""
def __init__(self, data=None, *args, **kwargs):
super().__init__(data, *args, **kwargs)
def to_ndarray(self) -> np.ndarray:
"""
Convert the AudioSegment to a numpy array.
"""
_buffer = io.BytesIO()
self.export(_buffer, format="wav")
_buffer.seek(0)
_, wav_data = wavfile.read(_buffer)
        if wav_data.dtype != np.float32:
            max_abs = np.abs(wav_data).max()
            wav_data = wav_data.astype(np.float32) / (max_abs if max_abs != 0 else 1)
        if len(wav_data.shape) == 2:
            wav_data = wav_data.mean(axis=1)
return wav_data
def to_tensor(self) -> torch.Tensor:
"""
Convert the AudioSegment to a PyTorch tensor.
"""
return torch.tensor(self.to_ndarray(), dtype=torch.float32)
class AudioBuffer:
"""
Buffer for AudioSegments
"""
def __init__(self, max_size: int) -> None:
self.data: AudioSegment | None = None
self.max_size: int = max_size
def put(self, audio_segment: AudioSegment) -> None:
"""
Append an AudioSegment to the buffer.
"""
if self.data:
self.data = self.data + audio_segment
if len(self.data) >= self.max_size:
self.data = self.data[-self.max_size :]
else:
self.data = audio_segment
def clear(self) -> None:
"""
Clear the buffer.
"""
self.data = None
def get(self) -> AudioSegment | None:
"""
Get the AudioSegment from the buffer.
"""
return self.data
def get_last(self, duration: int) -> AudioSegment | None:
"""
Get the last `duration` milliseconds of the AudioSegment.
"""
return self.data[-duration:] if self.data else None
def is_empty(self) -> bool:
"""
Check if the buffer is empty.
"""
return self.data is None
def length(self) -> int:
"""
Get the length of the buffer.
"""
return len(self.data) if self.data else 0
class VAD:
"""
Voice Activity Detection
"""
def __init__(self) -> None:
self.model = load_silero_vad()
self.ts_voice_detected: dt.datetime | None = None
def detect_voice(self, audio_segment: AudioSegment) -> bool:
"""
Detect voice in an AudioSegment.
"""
speech_timestamps = get_speech_timestamps(
audio_segment.to_tensor(),
self.model,
return_seconds=True,
)
if len(speech_timestamps) > 0:
if self.ts_voice_detected is None:
pass # print("VAD: Voice detected.")
self.ts_voice_detected = dt.datetime.now()
return True
return False
def ms_since_voice_detected(self) -> int | None:
"""
Get the milliseconds since voice was detected.
"""
if self.ts_voice_detected is None:
return None
return (dt.datetime.now() - self.ts_voice_detected).total_seconds() * 1000
def reset_timer(self):
"""
Reset the timer for voice detected.
"""
self.ts_voice_detected = None
class STT:
"""
Speech-to-Text
"""
def __init__(self, model="medium", device="cuda" if torch.cuda.is_available() else "cpu", benchmark=False) -> None:
compute_type = "int8" if device == "cpu" else "float16"
self.model = WhisperModel(model, device=device, compute_type=compute_type)
self.benchmark = benchmark
def transcribe(self, audio_segment: AudioSegment, language="sv"):
"""
Transcribe an AudioSegment.
"""
if self.benchmark:
ts_start = dt.datetime.now()
segments, info = self.model.transcribe(audio_segment.to_ndarray(), language=language)
segments = list(segments)
text = " ".join([seg.text for seg in segments])
text = text.strip()
if self.benchmark:
ts_done = dt.datetime.now()
delta_ms = (ts_done - ts_start).total_seconds() * 1000
print(f"[Benchmark] STT: {delta_ms:.2f} ms.")
return text
class TTS:
"""
Text-to-Speech (TTS) class that uses the ElevenLabs API.
"""
def __init__(self, voice="Brian", model="eleven_multilingual_v2") -> None:
self.client = ElevenLabs()
self.voice = voice
self.model = model
def generate(self, text: str, voice: str = "default") -> AudioSegment:
if voice == "default":
voice = self.voice
audio = self.client.generate(
text=text,
            voice=voice,
model=self.model,
stream=False,
output_format="pcm_16000",
optimize_streaming_latency=2,
)
audio = b"".join(audio)
audio_segment = AudioSegment(data=audio, sample_width=2, frame_rate=16000, channels=1)
return audio_segment
def stream(self, text_stream: Iterator[str], voice: str = "default") -> Iterator[bytes]:
if voice == "default":
voice = self.voice
audio_stream = self.client.generate(
text=text_stream,
            voice=voice,
model=self.model,
stream=True,
output_format="pcm_16000",
optimize_streaming_latency=2,
)
for chunk in audio_stream:
if chunk is not None:
yield chunk


@ -0,0 +1,113 @@
[16:22:06.368][D][i2s_audio.microphone:377]: Starting I2S Audio Microphne
[16:22:06.368][D][i2s_audio.microphone:381]: Started I2S Audio Microphone
[16:22:06.368][D][micro_wake_word:417]: State changed from IDLE to DETECTING_WAKE_WORD
[16:22:08.878][D][micro_wake_word:355]: Detected 'Okay Nabu' with sliding average probability is 0.97 and max probability is 1.00
[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
[16:22:08.878][D][media_player:084]: Command: STOP
[16:22:08.878][D][media_player:093]: Announcement: yes
[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
[16:22:08.878][D][media_player:093]: Announcement: yes
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536
[16:22:08.879][D][nabu_media_player.pipeline:173]: Reading FLAC file type
[16:22:08.879][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
[16:22:08.879][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio
[16:22:08.879][D][ring_buffer:034][speaker_task]: Created ring buffer with size 19200
[16:22:08.879][D][i2s_audio.speaker:111]: Starting Speaker
[16:22:08.880][D][i2s_audio.speaker:116]: Started Speaker
[16:22:08.880][D][voice_assistant:515]: State changed from IDLE to START_MICROPHONE
[16:22:08.880][D][voice_assistant:522]: Desired state set to START_PIPELINE
[16:22:08.880][D][voice_assistant:225]: Starting Microphone
[16:22:08.880][D][ring_buffer:034]: Created ring buffer with size 16384
[16:22:08.880][D][voice_assistant:515]: State changed from START_MICROPHONE to STARTING_MICROPHONE
[16:22:08.880][D][voice_assistant:515]: State changed from STARTING_MICROPHONE to START_PIPELINE
[16:22:08.880][D][voice_assistant:280]: Requesting start...
[16:22:08.880][D][voice_assistant:515]: State changed from START_PIPELINE to STARTING_PIPELINE
[16:22:08.880][D][voice_assistant:537]: Client started, streaming microphone
[16:22:08.881][D][voice_assistant:515]: State changed from STARTING_PIPELINE to STREAMING_MICROPHONE
[16:22:08.881][D][voice_assistant:522]: Desired state set to STREAMING_MICROPHONE
[16:22:08.881][D][voice_assistant:641]: Event Type: 1 (VOICE_ASSISTANT_RUN_START)
[16:22:08.881][D][voice_assistant:644]: Assist Pipeline running
[16:22:08.881][D][voice_assistant:641]: Event Type: 3 (VOICE_ASSISTANT_STT_START)
[16:22:08.881][D][voice_assistant:655]: STT started
[16:22:08.881][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:08.881][D][light:047]: State: ON
[16:22:08.881][D][light:051]: Brightness: 66%
[16:22:08.882][D][light:109]: Effect: 'Waiting for Command'
[16:22:08.882][D][power_supply:033]: Enabling power supply.
[16:22:09.671][D][voice_assistant:641]: Event Type: 11 (VOICE_ASSISTANT_STT_VAD_START)
[16:22:09.671][D][voice_assistant:804]: Starting STT by VAD
[16:22:09.671][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:09.671][D][light:051]: Brightness: 66%
[16:22:09.671][D][light:109]: Effect: 'Listening For Command'
[16:22:14.806][D][voice_assistant:641]: Event Type: 12 (VOICE_ASSISTANT_STT_VAD_END)
[16:22:14.806][D][voice_assistant:808]: STT by VAD end
[16:22:14.806][D][voice_assistant:515]: State changed from STREAMING_MICROPHONE to STOP_MICROPHONE
[16:22:14.806][D][voice_assistant:522]: Desired state set to AWAITING_RESPONSE
[16:22:14.806][D][voice_assistant:515]: State changed from STOP_MICROPHONE to STOPPING_MICROPHONE
[16:22:14.807][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:14.807][D][light:051]: Brightness: 66%
[16:22:14.807][D][light:109]: Effect: 'Thinking'
[16:22:14.807][D][voice_assistant:515]: State changed from STOPPING_MICROPHONE to AWAITING_RESPONSE
[16:22:14.807][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to AWAITING_RESPONSE
[16:22:14.807][D][power_supply:033]: Enabling power supply.
[16:22:14.807][D][power_supply:033]: Enabling power supply.
[16:22:14.807][D][voice_assistant:641]: Event Type: 4 (VOICE_ASSISTANT_STT_END)
[16:22:14.807][D][voice_assistant:669]: Speech recognised as: " Who are you?"
[16:22:14.807][D][voice_assistant:641]: Event Type: 5 (VOICE_ASSISTANT_INTENT_START)
[16:22:14.808][D][voice_assistant:674]: Intent started
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][voice_assistant:641]: Event Type: 6 (VOICE_ASSISTANT_INTENT_END)
[16:22:14.808][D][voice_assistant:641]: Event Type: 7 (VOICE_ASSISTANT_TTS_START)
[16:22:14.808][D][voice_assistant:697]: Response: "I am an AI assistant designed to help you with various tasks, answer questions, and provide information. If there's anything specific you need help with, feel free to ask!"
[16:22:14.809][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:14.809][D][light:051]: Brightness: 66%
[16:22:14.809][D][light:109]: Effect: 'Replying'
[16:22:14.809][D][voice_assistant:641]: Event Type: 8 (VOICE_ASSISTANT_TTS_END)
[16:22:14.809][D][voice_assistant:719]: Response URL: "https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac"
[16:22:14.809][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to STREAMING_RESPONSE
[16:22:14.809][D][voice_assistant:522]: Desired state set to STREAMING_RESPONSE
[16:22:14.809][D][media_player:080]: 'Media Player' - Setting
[16:22:14.809][D][media_player:087]: Media URL: https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac
[16:22:14.810][D][media_player:093]: Announcement: yes
[16:22:14.810][D][voice_assistant:641]: Event Type: 2 (VOICE_ASSISTANT_RUN_END)
[16:22:14.810][D][voice_assistant:733]: Assist Pipeline ended
[16:22:14.810][D][esp-idf:000][ann_read]: I (40514) esp-x509-crt-bundle: Certificate validated
[16:22:14.810][D][nabu_media_player.pipeline:173]: Reading FLAC file type
[16:22:14.810][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
[16:22:14.810][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio
[16:22:26.112][D][voice_assistant:515]: State changed from STREAMING_RESPONSE to IDLE
[16:22:26.112][D][voice_assistant:522]: Desired state set to IDLE


@ -0,0 +1,283 @@
import asyncio
import hashlib
import logging
import math
import os
import wave
from enum import StrEnum
import traceback
import aioesphomeapi
from aioesphomeapi import (VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag, VoiceAssistantEventType)
from aiohttp import web
from audio import STT, TTS, VAD, AudioBuffer, AudioSegment
from llm_api import LlmApi
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
routes = web.RouteTableDef()
class PipelineStage(StrEnum):
"""Stages of a pipeline."""
WAKE_WORD = "wake_word"
STT = "stt"
INTENT = "intent"
TTS = "tts"
END = "end"
class Satellite:
def __init__(self, host="192.168.10.155", port=6053, password=""):
self.client = aioesphomeapi.APIClient(host, port, password)
self.audio_queue: AudioBuffer = AudioBuffer(max_size=60000)
self.state = "idle"
self.vad = VAD()
self.stt = STT(device="cuda")
self.tts = TTS()
self.agent = LlmApi()
async def connect(self):
await self.client.connect(login=True)
async def disconnect(self):
await self.client.disconnect()
async def handle_pipeline_start(
self,
conversation_id: str,
flags: int,
audio_settings: VoiceAssistantAudioSettings,
wake_word_phrase: str | None,
) -> int:
logging.debug(
f"Pipeline starting with conversation_id={conversation_id}, flags={flags}, audio_settings={audio_settings}, wake_word_phrase={wake_word_phrase}"
)
# Device triggered pipeline (wake word, etc.)
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
start_stage = PipelineStage.WAKE_WORD
else:
start_stage = PipelineStage.STT
end_stage = PipelineStage.TTS
logging.info(f"Starting pipeline from {start_stage} to {end_stage}")
self.state = "running_pipeline"
# Run pipeline
asyncio.create_task(self.run_pipeline(wake_word_phrase))
return 0
async def run_pipeline(self, wake_word_phrase: str | None = None):
logging.info(f"Pipeline started using the wake word '{wake_word_phrase}'.")
try:
logging.debug(" > STT start")
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None)
logging.debug(" > STT VAD start")
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START, None)
# VAD
logging.debug(" > VAD: Waiting for silence...")
VAD_TIMEOUT = 1000
_voice_detected = False
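            # Poll the incoming audio until one of three things happens:
            #   1. the buffer exceeds 20 s of audio (hard cap),
            #   2. the VAD reports a voice -> silence transition, or
            #   3. more than 2 s of audio has arrived without any voice detected.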
while True:
if self.audio_queue.length() > 20000:
break
elif self.audio_queue.length() >= VAD_TIMEOUT:
voice_detected = self.vad.detect_voice(self.audio_queue.get_last(VAD_TIMEOUT))
if voice_detected != _voice_detected:
_voice_detected = voice_detected
if not _voice_detected:
logging.debug(" > VAD: Silence detected.")
break
if self.audio_queue.length() > 2000 and not voice_detected:
logging.debug(" > VAD: Silence detected (no voice).")
break
await asyncio.sleep(0.1)
logging.debug(" > STT VAD end")
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END, None)
# STT
text = self.stt.transcribe(self.audio_queue.get())
self.audio_queue.clear()
logging.info(f" > STT: {text}")
logging.debug(" > STT end")
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": text})
logging.debug(" > Intent start")
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None)
agent_response = self.agent.ask(text)["answer"]
logging.info(f">Intent: {agent_response}")
logging.debug(" > Intent end")
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, None)
logging.debug(" > TTS start")
TTS = "announce"
if TTS == "announce":
self.client.send_voice_assistant_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": agent_response}
)
voice = "default"
audio_hash_string = hashlib.md5((voice + ":" + agent_response).encode("utf-8")).hexdigest()
audio_file_path = f"static/{audio_hash_string}.wav"
if not os.path.exists(audio_file_path):
logging.debug(" > TTS: Generating audio...")
audio_segment = self.tts.generate(agent_response, voice=voice)
else:
logging.debug(" > TTS: Using cached audio.")
audio_segment = AudioSegment.from_file(audio_file_path)
audio_segment.export(audio_file_path, format="wav")
self.client.send_voice_assistant_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
{"url": f"http://192.168.10.111:8002/{audio_file_path}"},
)
elif TTS == "stream":
await self._stream_tts_audio("debug")
logging.debug(" > TTS end")
except Exception as e:
logging.error("Error: ")
logging.error(e)
logging.error(traceback.format_exc())
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, {"message": str(e)})
logging.debug(" > Run end")
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END, None)
logging.info("Pipeline done")
async def handle_pipeline_stop(self, abort: bool):
logging.info("Pipeline stopped")
self.state = "idle"
async def handle_audio(self, data: bytes) -> None:
"""Handle incoming audio chunk from API."""
self.audio_queue.put(AudioSegment(data=data, sample_width=2, frame_rate=16000, channels=1))
async def handle_announcement_finished(self, event: aioesphomeapi.VoiceAssistantAnnounceFinished):
if event.success:
logging.info("Announcement finished successfully.")
else:
logging.error("Announcement failed.")
def handle_state_change(self, state):
logging.debug("State change:")
logging.debug(state)
def handle_log(self, log):
logging.info("Log from device:", log)
def subscribe(self):
self.client.subscribe_states(self.handle_state_change)
self.client.subscribe_logs(self.handle_log)
self.client.subscribe_voice_assistant(
handle_start=self.handle_pipeline_start,
handle_stop=self.handle_pipeline_stop,
handle_audio=self.handle_audio,
handle_announcement_finished=self.handle_announcement_finished,
)
async def _stream_tts_audio(
self,
media_id: str,
sample_rate: int = 16000,
sample_width: int = 2,
sample_channels: int = 1,
samples_per_chunk: int = 1024, # 512
) -> None:
"""Stream TTS audio chunks to device via API or UDP."""
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
logging.info("TTS stream start")
await asyncio.sleep(1)
try:
# if not self._is_running:
# return
with wave.open(f"{media_id}.wav", "rb") as wav_file:
if (
(wav_file.getframerate() != sample_rate)
or (wav_file.getsampwidth() != sample_width)
or (wav_file.getnchannels() != sample_channels)
):
logging.info(f"Can only stream 16Khz 16-bit mono WAV, got {wav_file.getparams()}")
return
logging.info("Streaming %s audio samples", wav_file.getnframes())
rate = wav_file.getframerate()
width = wav_file.getsampwidth()
channels = wav_file.getnchannels()
audio_bytes = wav_file.readframes(wav_file.getnframes())
bytes_per_sample = width * channels
bytes_per_chunk = bytes_per_sample * samples_per_chunk
num_chunks = int(math.ceil(len(audio_bytes) / bytes_per_chunk))
# Split into chunks
for i in range(num_chunks):
offset = i * bytes_per_chunk
chunk = audio_bytes[offset : offset + bytes_per_chunk]
self.client.send_voice_assistant_audio(chunk)
# Wait for 90% of the duration of the audio that was
# sent for it to be played. This will overrun the
# device's buffer for very long audio, so using a media
# player is preferred.
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
seconds_in_chunk = samples_in_chunk / sample_rate
logging.info(f"Samples in chunk: {samples_in_chunk}, seconds in chunk: {seconds_in_chunk}")
await asyncio.sleep(seconds_in_chunk * 0.9)
except asyncio.CancelledError:
return # Don't trigger state change
finally:
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {})
logging.info("TTS stream end")
async def start_http_server(port=8002):
app = web.Application()
app.router.add_static("/static/", path=os.path.join(os.path.dirname(__file__), "static"))
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "0.0.0.0", port)
await site.start()
logging.info(f"HTTP server started at http://0.0.0.0:{port}/")
async def main():
"""Connect to an ESPHome device and wait for state changes."""
await start_http_server()
satellite = Satellite()
await satellite.connect()
logging.info("Connected to ESPHome voice assistant.")
satellite.subscribe()
satellite.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START, None)
logging.info("Listening for events...")
while True:
await asyncio.sleep(0.1)
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
pass


@ -0,0 +1,13 @@
import requests
class LlmApi:
def __init__(self, host="127.0.0.1", port=8000):
self.host = host
self.port = port
self.url = f"http://{host}:{port}"
def ask(self, question):
url = self.url + "/ask"
response = requests.post(url, json={"question": question})
return response.json()


@ -0,0 +1,62 @@
from __future__ import annotations
# Helper script and aioesphomeapi to view logs from an esphome device
import argparse
import asyncio
import logging
import sys
from datetime import datetime
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
from aioesphomeapi.client import APIClient
from aioesphomeapi.log_runner import async_run
async def main(argv: list[str]) -> None:
parser = argparse.ArgumentParser("aioesphomeapi-logs")
parser.add_argument("--port", type=int, default=6053)
parser.add_argument("--password", type=str)
parser.add_argument("--noise-psk", type=str)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("address")
args = parser.parse_args(argv[1:])
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
cli = APIClient(
args.address,
args.port,
args.password or "",
noise_psk=args.noise_psk,
keepalive=10,
)
def on_log(msg: SubscribeLogsResponse) -> None:
time_ = datetime.now()
message: bytes = msg.message
text = message.decode("utf8", "backslashreplace")
nanoseconds = time_.microsecond // 1000
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{nanoseconds:03}]{text}")
stop = await async_run(cli, on_log)
try:
await asyncio.Event().wait()
finally:
await stop()
def cli_entry_point() -> None:
"""Run the CLI."""
try:
asyncio.run(main(sys.argv))
except KeyboardInterrupt:
pass
if __name__ == "__main__":
cli_entry_point()
sys.exit(0)

1
external/home-assistant-voice-pe vendored Submodule

@ -0,0 +1 @@
Subproject commit 4c09d89ddbf234546f244bf98b1053a2f89e5ab6

2
pyproject.toml Normal file

@ -0,0 +1,2 @@
[tool.black]
line-length = 120