commit 7b45d1930823db9cad434fcd150fb321783ddad7
Author: Pierre Wessman <4029607+pierrewessman@users.noreply.github.com>
Date:   Thu Jan 16 16:22:58 2025 +0100

    init

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0a6a95f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+config.yml
+debug.py
+notebook.ipynb
+.vscode/
+secrets.yaml
\ No newline at end of file
diff --git a/agent-api/.cache b/agent-api/.cache
new file mode 100644
index 0000000..01bce28
--- /dev/null
+++ b/agent-api/.cache
@@ -0,0 +1 @@
+{"access_token": "BQDSGB6rh1-6-bc7wxFTZ2_Idjr6W-xLSZEMizMF6wNXew6_2M4DyWeSaVovZZLPdVv0aQOU8LqIQqQSNr5OwNXDpRk6icpcC6HKYGaZ7I_U9PRWbIOWE-QUakMcAmMFTWwO3FJEBRDYARR43cWTeRn8w9MHG48V5awv2A1447yz8kd4cUk8lAOTbcCeMnpxjJJQqqaigIWAy8NF7ghJXr4dibKj2JtTP5yECA", "token_type": "Bearer", "expires_in": 3600, "scope": "user-library-read", "expires_at": 1736781939, "refresh_token": "AQBaAMOoUVzcichmSIgDC4_CvkXXNNwfjnSbwpypzuaaCQqcQw-_2VP6Rkayp08Y-5gSEbEGVWc0-1CP4-JmQI2kQ_0UmJT5g2G3LvOjufSYowzXRxeedIOWcwbRvjSA1XI"}
\ No newline at end of file
diff --git a/agent-api/agent.py b/agent-api/agent.py
new file mode 100644
index 0000000..fc64f63
--- /dev/null
+++ b/agent-api/agent.py
@@ -0,0 +1,163 @@
+import datetime
+import json
+from jinja2 import Template
+import logging
+from typing import Union
+
+from db import Database
+from llm import LLM
+
+
+def relative_to_actual_dates(text: str):
+    from dateparser.search import search_dates
+
+    matches = search_dates(text, languages=["en", "se"])
+
+    if matches:
+        for match in matches:
+            text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
+    return text
+
+
+class Agent:
+    def __init__(self, config: dict) -> None:
+        self.config = config
+        self.db = Database("rag")
+        self.llm = LLM(config=self.config["llm"])
+
+    def flush(self):
+        """
+        Flushes the agent's knowledge.
+        """
+        logging.info("Truncating database")
+        self.db.cur.execute("TRUNCATE TABLE items")
+        self.db.conn.commit()
+
+    def insert(self, text):
+        """
+        Append knowledge.
+        """
+        logging.info(f"Inserting item into embedding table ({text}).")
+
+        #logging.info(f"\tReplacing relative dates with actual dates.")
+        #text = relative_to_actual_dates(text)
+        #logging.info(f"\tDone. text: {text}")
+
+        vector = self.llm.embedding(text)
+        vector_padded = vector + ([0]*(2048-len(vector)))
+
+        self.db.cur.execute(
+            "INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
+            (text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
+        )
+        self.db.conn.commit()
+
+    def keywords(self, query: str, n_keywords: int = 5):
+        """
+        Suggest keywords related to a query.
+        """
+        sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
+        prompt = (
+            f"Provide a json list of {n_keywords} words or nouns that represent a vector database query for the following prompt: {query}."
+            + f"Do not add any context or text other than the json list of {n_keywords} words."
+        )
+        response = self.llm.query(prompt, system_msg=sys_msg)
+        keywords = json.loads(response)
+        return keywords
+
+    def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
+        """
+        Retrieve relevant knowledge.
+
+        Parameters:
+        - n_retrieve (int): How many notes to retrieve.
+        - n_rerank (int): How many notes to keep after reranking.
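+
+        Returns the retrieved notes as a list of {"text": ..., "date_added": ...} dicts.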
+        """
+        logging.info(f"Retrieving knowledge from database using the query: '{query}'")
+
+        logging.debug(f"Using embedding model on '{query}'.")
+        vector_keywords = self.llm.embedding(query)
+        logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
+
+        logging.debug("Querying database")
+        vector_padded = vector_keywords + ([0]*(2048-len(vector_keywords)))
+        self.db.cur.execute(
+            f"SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT {n_retrieve}",
+            (vector_padded,),
+        )
+        db_response = self.db.cur.fetchall()
+        knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
+        logging.debug("Database returned: ")
+        logging.debug(knowledge_arr)
+
+        if n_retrieve > n_rerank:
+            logging.info("Reranking the results")
+            reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
+            reranked_knowledge_arr = sorted(
+                knowledge_arr,
+                key=lambda x: (
+                    reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
+                ),
+            )
+            reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
+            logging.debug("Reranked results: ")
+            logging.debug(reranked_knowledge_arr)
+            return reranked_knowledge_arr
+        else:
+            return knowledge_arr
+
+    def ask(
+        self,
+        question: str,
+        person: str,
+        history: list = None,
+        n_retrieve: int = 5,
+        n_rerank=3,
+        tools: Union[None, list] = None,
+        prompts: Union[None, list] = None,
+    ):
+        """
+        Ask the agent a question.
+
+        Parameters:
+        - question (str): The question to ask
+        - person (str): Who asks the question?
+        - n_retrieve (int): How many notes to retrieve.
+        - n_rerank (int): How many notes to keep after reranking.
+        """
+        logging.info(f"Answering question: {question}")
+
+        if n_retrieve > 0:
+            knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
+
+            with open('templates/knowledge.en.j2', 'r') as file:
+                template = Template(file.read())
+
+            knowledge_str = template.render(
+                knowledge = knowledge,
+            )
+        else:
+            knowledge_str = ""
+
+        with open('templates/ask.en.j2', 'r') as file:
+            template = Template(file.read())
+
+        prompt = template.render(
+            today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
+            knowledge = knowledge_str,
+            plugin_prompts = prompts,
+            person = person
+        )
+
+        logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
+        answer = self.llm.query(
+            question,
+            system_msg=prompt,
+            tools=tools,
+            history=history
+        )
+        logging.info(f"Got answer from LLM: '{answer}'.")
+
+        return answer
diff --git a/agent-api/backend.py b/agent-api/backend.py
new file mode 100644
index 0000000..ff16e6a
--- /dev/null
+++ b/agent-api/backend.py
@@ -0,0 +1,127 @@
+from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
+import importlib
+import logging
+import os
+from pydantic import BaseModel
+import yaml
+
+from agent import Agent
+
+PLUGINS_FOLDER = "plugins"
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+)
+
+with open("config.yml", "r", encoding='utf8') as file:
+    config = yaml.safe_load(file)
+
+def load_plugins(config):
+    tools = []
+    prompts = []
+
+    for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
+        if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder:
+            logging.info(f"Loading plugin: {plugin_folder}")
+            module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
+            module = importlib.import_module(module_path)
+            plugin = module.Plugin(config)
+            tools += plugin.tools()
+            prompt = plugin.prompt()
+            if prompt is not None:
+                prompts.append(prompt)
+
+    return tools, prompts
+
+tools, prompts = load_plugins(config)
+
+app = FastAPI()
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+agent = Agent(config)
+
+
+class Query(BaseModel):
+    query: str
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+class Question(BaseModel):
+    question: str
+    person: str = "User"
+    history: list[Message] | None = None
+
+
+@app.get("/ping")
+def ping():
+    """
+    # Ping
+    Check if the API is up.
+    """
+    return {"status": "ok"}
+
+@app.post("/ask")
+def ask(q: Question):
+    """
+    # Ask
+    Ask the agent a question.
+
+    ## Parameters
+    - **question**: the question to ask.
+    - **person**: who is asking the question?
+    """
+    history = None
+    if q.history:
+        history = [m.model_dump() for m in q.history]
+
+    answer = agent.ask(
+        question=q.question,
+        person="Pierre",
+        history=history,
+        tools=tools,
+        prompts=prompts,
+        n_retrieve=0,
+        n_rerank=0
+    )
+    return {"question": q.question, "answer": answer, "history": q.history}
+
+
+@app.post("/retrieve")
+def retrieve(q: Query):
+    """
+    # Retrieve
+    Retrieve knowledge related to a query.
+
+    ## Parameters
+    - **query**: the question.
+    """
+    result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5)
+    return {"query": q.query, "answer": result}
+
+
+@app.post("/learn")
+def learn(q: Query):
+    """
+    # Learn
+    Teach the agent something. E.g. *"John likes to play volleyball".*
+
+    ## Parameters
+    - **query**: the note.
+    """
+    agent.insert(text=q.query)
+    return {"text": q.query}
+
+
+@app.get("/flush")
+def flush():
+    """
+    # Flush
+    Remove all knowledge from the agent.
+    """
+    agent.flush()
+    return {"message": "Knowledge flushed."}
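As a usage illustration, a minimal client sketch for the endpoints above — assuming the backend is running at the `server.url` from config.default.yml; the note and question texts are placeholders:

```python
# Hypothetical client for the /learn and /ask endpoints defined above.
import requests

BASE_URL = "http://127.0.0.1:8000"  # server.url in config.default.yml

# Teach the agent a note, then ask a question that should use it.
requests.post(BASE_URL + "/learn", json={"query": "John likes to play volleyball"})

resp = requests.post(
    BASE_URL + "/ask",
    json={"question": "What does John like to do?", "person": "User", "history": []},
)
print(resp.json()["answer"])
```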
diff --git a/agent-api/config.default.yml b/agent-api/config.default.yml
new file mode 100644
index 0000000..52d6b3b
--- /dev/null
+++ b/agent-api/config.default.yml
@@ -0,0 +1,29 @@
+server:
+  url: "http://127.0.0.1:8000"
+llm:
+  use_local_chat: false
+  use_local_embedding: false
+  local_model_dir: ""
+  local_embedding_model: "intfloat/e5-large-v2"
+  local_rerank_model: "cross-encoder/stsb-distilroberta-base"
+  openai:
+    base_url: ""
+    embedding_model: "text-embedding-3-small"
+    chat_model: "gpt-3.5-turbo-0125"
+  temperature: 0.1
+homeassistant:
+  url: "http://localhost:8123"
+  token: ""
+plugins:
+  calendar:
+    default_calendar: "calendar.my_calendar"
+  todo:
+    default_list: "todo.todo"
+  music:
+    default_speaker: ""  # media_player entity used by the music plugin
+  vasttrafik:
+    key: ""
+    secret: ""
+    default_from_station: "Brunnsparken"
+    default_to_station: "Centralstationen"
+    delay: 0 # minutes
\ No newline at end of file
diff --git a/agent-api/db.py b/agent-api/db.py
new file mode 100644
index 0000000..da0bb1e
--- /dev/null
+++ b/agent-api/db.py
@@ -0,0 +1,32 @@
+import logging
+from pgvector.psycopg2 import register_vector
+import psycopg2
+
+
+class Database:
+    def __init__(
+        self,
+        database,
+        user="postgres",
+        password="postgres",
+        host="localhost",
+        port="5433",
+    ) -> None:
+        logging.info("Connecting to database")
+        self.conn = psycopg2.connect(
+            database=database,
+            user=user,
+            password=password,
+            host=host,
+            port=port,
+        )
+        register_vector(self.conn)
+        self.cur = self.conn.cursor()
+        self.cur.execute("SELECT version();")
+        logging.info(" DB Version: %s", self.cur.fetchone()[0])
+        logging.info(" psycopg2 Version: %s", psycopg2.__version__)
+
+    def close(self):
+        logging.info("Closing connection to database")
+        self.cur.close()
+        self.conn.close()
diff --git a/agent-api/db.sql b/agent-api/db.sql
new file mode 100644
index 0000000..bbc919a
--- /dev/null
+++ b/agent-api/db.sql
@@ -0,0 +1,7 @@
+CREATE TABLE public.items (
+	id bigserial NOT NULL,
+	embedding vector(2048) NULL,
+	"text" varchar(1024) NULL,
+	date_added date NULL,
+	CONSTRAINT items_pkey PRIMARY KEY (id)
+);
\ No newline at end of file
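To illustrate how db.py and this schema fit together, a minimal sketch — assuming the docker-compose Postgres below is running, the `rag` database exists, and `CREATE EXTENSION vector` plus db.sql have been applied; the zero vector stands in for a real embedding:

```python
# Sketch: nearest-neighbour query against the items table from db.sql.
from db import Database

db = Database("rag")

query_vec = [0.0] * 2048  # placeholder for an embedding padded to 2048 dims
db.cur.execute(
    "SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT 5",
    (query_vec,),
)
for text, date_added in db.cur.fetchall():
    print(date_added, text)
db.close()
```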
diff --git a/agent-api/docker-compose.yml b/agent-api/docker-compose.yml
new file mode 100644
index 0000000..0032aa1
--- /dev/null
+++ b/agent-api/docker-compose.yml
@@ -0,0 +1,15 @@
+version: '3.8'
+services:
+  db:
+    image: pgvector/pgvector:pg16
+    restart: always
+    environment:
+      - POSTGRES_USER=postgres
+      - POSTGRES_PASSWORD=postgres
+    ports:
+      - '5433:5432'
+    volumes:
+      - db:/var/lib/postgresql/data
+volumes:
+  db:
+    driver: local
\ No newline at end of file
diff --git a/agent-api/frontend.py b/agent-api/frontend.py
new file mode 100644
index 0000000..51b51a3
--- /dev/null
+++ b/agent-api/frontend.py
@@ -0,0 +1,42 @@
+import gradio as gr
+import logging
+import requests
+import yaml
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+)
+
+with open("config.yml", "r") as file:
+    config = yaml.safe_load(file)
+
+BASE_URL = config["server"]["url"]
+
+def chat(message, history):
+    messages = []
+    for (user_msg, bot_msg) in history:
+        messages.append({"role": "user", "content": user_msg})
+        messages.append({"role": "assistant", "content": bot_msg})
+
+    logging.info(f"Sending chat message to backend ('{message}').")
+    response = requests.post(BASE_URL + "/ask", json={"question": message, "person": "Pierre", "history": messages})
+
+    if response.status_code == 200:
+        logging.info("Response:")
+        logging.info(response.json())
+        data = response.json()
+        return data["answer"]
+    else:
+        logging.error(response)
+        return f"Error: backend returned status {response.status_code}."
+
+ui = gr.ChatInterface(chat, title="Chat")
+
+if __name__ == "__main__":
+    logging.info("Launching frontend.")
+    ui.launch(
+        #share=False,
+        server_name='0.0.0.0',
+    )
diff --git a/agent-api/install.md b/agent-api/install.md
new file mode 100644
index 0000000..0833f57
--- /dev/null
+++ b/agent-api/install.md
@@ -0,0 +1,5 @@
+# llama.cpp
+```
+set CMAKE_ARGS=-DLLAMA_CUBLAS=on
+pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python==0.2.39
+```
\ No newline at end of file
diff --git a/agent-api/llama_server.py b/agent-api/llama_server.py
new file mode 100644
index 0000000..9421db5
--- /dev/null
+++ b/agent-api/llama_server.py
@@ -0,0 +1,37 @@
+"""Example FastAPI server for llama.cpp.
+
+To run this example:
+
+```bash
+pip install fastapi uvicorn sse-starlette
+export MODEL=../models/7B/...
+```
+
+Then run:
+```
+uvicorn --factory llama_cpp.server.app:create_app --reload
+```
+
+or
+
+```
+python3 -m llama_cpp.server
+```
+
+Then visit http://localhost:8000/docs to see the interactive API docs.
+
+
+To actually see the implementation of the server, see llama_cpp/server/app.py
+
+"""
+import os
+import uvicorn
+
+from llama_cpp.server.app import create_app
+
+if __name__ == "__main__":
+    app = create_app()
+
+    uvicorn.run(
+        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
+    )
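(Aside: `python3 -m llama_cpp.server --help` lists the available model and server flags, which is handy when trying local models before pointing config.yml at them.)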
diff --git a/agent-api/llm.py b/agent-api/llm.py
new file mode 100644
index 0000000..9a1cc98
--- /dev/null
+++ b/agent-api/llm.py
@@ -0,0 +1,254 @@
+import concurrent.futures
+import json
+from llama_cpp.llama import Llama
+import logging
+import numpy as np
+from openai import OpenAI
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from typing import Union, List
+
+
+class BaseChat:
+    def __init__(self, config: dict) -> None:
+        self.config = config
+
+    def prepare_function_calling(self, tools):
+        # prepare function calling
+        function_map = {}
+        if tools is not None and len(tools) > 0:
+            for tool in tools:
+                function_name = tool["function"]["name"]
+                function_map[function_name] = tool["function"]["function_to_call"]
+            functions = []
+            for tool in tools:
+                fn = tool["function"]
+                functions.append(
+                    {"type": tool["type"], "function": {x: fn[x] for x in fn if x != "function_to_call"}}
+                )
+            logging.info(f"{len(tools)} available functions:")
+            logging.info(function_map.keys())
+        else:
+            functions = None
+        return function_map, functions
+
+
+class OpenAIChat(BaseChat):
+    def __init__(self, config: dict) -> None:
+        self.config = config
+        if self.config["openai"]["base_url"] is not None and self.config["openai"]["base_url"] != "":
+            base_url = self.config["openai"]["base_url"]
+        else:
+            base_url = None
+        self.client = OpenAI(base_url=base_url)
+
+    def chat(self, messages, tools) -> str:
+        function_map, functions = self.prepare_function_calling(tools)
+
+        logging.debug("Sending request to OpenAI.")
+        llm_response = self.client.chat.completions.create(
+            model=self.config["openai"]["chat_model"],
+            messages=messages,
+            tools=functions,
+            temperature=self.config["temperature"],
+            tool_choice="auto" if functions is not None else None,
+        )
+        logging.debug("LLM response:")
+        logging.debug(llm_response.choices)
+        if llm_response.choices[0].message.tool_calls:
+            # Handle function calls
+            followup_response = self.execute_function_call(
+                llm_response.choices[0].message, function_map, messages
+            )
+            return followup_response.choices[0].message.content.strip()
+        else:
+            return llm_response.choices[0].message.content.strip()
+
+    def execute_function_call(self, message, function_map: dict, messages: list) -> str:
+        """
+        Executes function calls embedded in an LLM message in parallel, and returns an LLM response based on the results.
+
+        Parameters:
+        - message: LLM message containing the function calls.
+        - function_map (dict): dict of {"function_name": function()}
+        - messages (list): message history
+        """
+        tool_calls = message.tool_calls
+        logging.info(f"Got {len(tool_calls)} function call(s).")
+
+        def execute_single_tool_call(tool_call):
+            """Helper function to execute a single tool call"""
+            logging.info(f"Attempting to execute function requested by LLM ({tool_call.function.name}, {tool_call.function.arguments}).")
+            if tool_call.function.name in function_map:
+                function_to_call = function_map[tool_call.function.name]
+                args = json.loads(tool_call.function.arguments)
+                logging.debug(f"Calling function {tool_call.function.name} with args: {args}")
+                function_response = function_to_call(**args)
+                logging.debug(function_response)
+                return {
+                    "role": "function",
+                    "tool_call_id": tool_call.id,
+                    "name": tool_call.function.name,
+                    "content": function_response,
+                }
+            else:
+                logging.info(f"{tool_call.function.name} not in function_map")
+                logging.info(function_map.keys())
+                return None
+
+        # Execute tool calls in parallel
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            # Submit all tool calls to the executor
+            future_to_tool_call = {
+                executor.submit(execute_single_tool_call, tool_call): tool_call
+                for tool_call in tool_calls
+            }
+
+            # Collect results as they complete
+            for future in concurrent.futures.as_completed(future_to_tool_call):
+                result = future.result()
+                if result is not None:
+                    messages.append(result)
+
+        logging.debug("Functions called, sending the results to LLM.")
+        llm_response = self.client.chat.completions.create(
+            model=self.config["openai"]["chat_model"],
+            messages=messages,
+            temperature=self.config["temperature"],
+        )
+        logging.debug("Got response from LLM:")
+        logging.debug(llm_response)
+        return llm_response
+
+
+class LMStudioChat(BaseChat):
+    def __init__(self, config: dict) -> None:
+        self.config = config
+        self.client = OpenAI(base_url="http://localhost:1234/v1", api_key="not-needed")
+
+    def chat(self, messages, tools) -> str:
+        if tools is not None:
+            logging.warning("Tools were provided to LMStudioChat, but it's not yet supported.")
+        logging.info("Sending request to local LLM.")
+        llm_response = self.client.chat.completions.create(
+            model="",
+            messages=messages,
+            temperature=self.config["temperature"],
+        )
+        return llm_response.choices[0].message.content.strip()
+
+
+class LlamaChat(BaseChat):
+    def __init__(self, config: dict) -> None:
+        self.config = config
+        self.client = Llama(
+            self.config["local_model_dir"] + "TheBloke/Mistral-7B-Instruct-v0.1-GGUF/mistral-7b-instruct-v0.1.Q6_K.gguf",
+            n_gpu_layers=32,
+            n_ctx=2048,
+            verbose=False
+        )
+
+    def chat(self, messages: list, response_format: dict = None, tools: list = None) -> str:
+        if tools is not None:
+            logging.warning("Tools was provided to LlamaChat, but it's not yet supported.")
+        logging.info("Sending request to local LLM.")
+        logging.info(messages)
+        llm_response = self.client.create_chat_completion(
+            messages = messages,
+            response_format = response_format,
+            temperature = self.config["temperature"],
+        )
+        return llm_response["choices"][0]["message"]["content"].strip()
+
+
+class LLM:
+    def __init__(self, config: dict) -> None:
+        """
+        LLM constructor.
+
+        Parameters:
+        - config (dict): llm-config parsed from the llm part of config.yml
+        """
+        self.config = config
+
+        if self.config["use_local_chat"]:
+            self.chat_client = LlamaChat(self.config)
+        else:
+            self.chat_client = OpenAIChat(self.config)
+
+        if self.config["use_local_embedding"]:
+            self.embedding_model = self.config["local_embedding_model"]
+            self.rerank_model = self.config["local_rerank_model"]
+            self.embedding_client = SentenceTransformer(self.embedding_model)
+        else:
+            self.embedding_model = self.config["openai"]["embedding_model"]
+            self.rerank_model = None
+            self.embedding_client = OpenAI()
+
+    def query(
+        self,
+        user_msg: str,
+        system_msg: Union[None, str] = None,
+        history: Union[None, list] = None,
+        tools: Union[None, list] = None,
+    ):
+        """
+        Query the LLM
+
+        Parameters:
+        - user_msg (str): query from the user (will be appended to messages as {"role": "user"})
+        - system_msg (str): system message (will be prepended to messages as {"role": "system"})
+        - history (list): optional, list of messages to inject between system_msg and user_msg.
+        - tools: optional, list of functions that may be called by the LLM. Example:
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "function_to_call": get_current_weather,
+                    "description": "Get the current weather",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            }
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        """
+        messages = []
+
+        if system_msg is not None:
+            messages.append({"role": "system", "content": system_msg})
+
+        if history and len(history) > 0:
+            logging.info("History")
+            logging.info(history)
+            messages += history
+
+        messages.append({"role": "user", "content": user_msg})
+
+        logging.info(f"Sending request to LLM with {len(messages)} messages.")
+        answer = self.chat_client.chat(messages=messages, tools=tools)
+        return answer
+
+    def embedding(self, text: str):
+        if self.config["use_local_embedding"]:
+            embedding = self.embedding_client.encode(text)
+            return embedding  # len = 1024
+        else:
+            llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")])
+            return llm_response.data[0].embedding
+
+    def rerank(self, text: str, corpus: List[str], n_rerank: int = 5):
+        if self.rerank_model is not None and len(corpus) > 0:
+            sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus]
+            logging.info("Reranking:")
+            logging.info(sentence_combinations)
+            similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations)
+            result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))]
+            return result[:n_rerank]
+        else:
+            return corpus
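To make the tools plumbing concrete, here is a hypothetical end-to-end sketch of `LLM.query` with one tool. The weather function, its description and the stub return value are illustrative; an OpenAI API key and a config.yml matching config.default.yml are assumed:

```python
# Hypothetical: wiring one callable tool through LLM.query (names illustrative).
import yaml

from llm import LLM


def get_current_weather(location: str) -> str:
    # Stub -- a real implementation would call a weather service.
    return f"Sunny and 21 degrees in {location}."


with open("config.yml", "r", encoding="utf8") as file:
    config = yaml.safe_load(file)

llm = LLM(config=config["llm"])
answer = llm.query(
    "What's the weather like in Gothenburg?",
    system_msg="You are a helpful assistant.",
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "function_to_call": get_current_weather,
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "City name"},
                    },
                    "required": ["location"],
                },
            },
        }
    ],
)
print(answer)
```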
diff --git a/agent-api/plugins/__init__.py b/agent-api/plugins/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/agent-api/plugins/base_plugin.py b/agent-api/plugins/base_plugin.py
new file mode 100644
index 0000000..8cbe8b9
--- /dev/null
+++ b/agent-api/plugins/base_plugin.py
@@ -0,0 +1,68 @@
+import inspect
+
+from plugins.homeassistant import HomeAssistant
+
+
+class BasePlugin:
+    def __init__(self, config: dict) -> None:
+        self.config = config
+        self.homeassistant = HomeAssistant(config)
+
+    def prompt(self) -> str | None:
+        return
+
+    def tools(self) -> list:
+        tools = []
+        for tool_fn in self._list_tool_methods():
+            tool_fn_metadata = self._get_function_metadata(tool_fn)
+
+            if "input" in tool_fn_metadata["parameters"]:
+                json_schema = tool_fn_metadata["parameters"]["input"].annotation.model_json_schema()
+                del json_schema["title"]
+                for k, v in json_schema["properties"].items():
+                    del json_schema["properties"][k]["title"]
+                fn_to_call = self._create_function_with_kwargs(
+                    tool_fn_metadata["parameters"]["input"].annotation, tool_fn
+                )
+            else:
+                json_schema = {}
+                fn_to_call = tool_fn
+
+            tools.append(
+                {
+                    "type": "function",
+                    "function": {
+                        "name": tool_fn_metadata["name"],
+                        "function_to_call": fn_to_call,
+                        "description": tool_fn_metadata["docstring"],
+                        "parameters": json_schema,
+                    },
+                }
+            )
+        return tools
+
+    def _get_function_metadata(self, func):
+        function_name = func.__name__
+        docstring = inspect.getdoc(func)
+        signature = inspect.signature(func)
+        parameters = signature.parameters
+
+        metadata = {"name": function_name, "docstring": docstring or "", "parameters": parameters}
+
+        return metadata
+
+    def _create_function_with_kwargs(self, model_cls, original_function):
+        def dynamic_function(**kwargs):
+            model_instance = model_cls(**kwargs)
+            return original_function(model_instance)
+
+        return dynamic_function
+
+    def _list_tool_methods(self) -> list:
+        attributes = dir(self)
+        tool_functions = [
+            getattr(self, attr)
+            for attr in attributes
+            if callable(getattr(self, attr)) and attr.startswith("tool_")
+        ]
+        return tool_functions
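Because `tools()` builds everything through introspection — any method whose name starts with `tool_` becomes a tool, and a pydantic model annotated on an `input` parameter supplies the JSON schema — a new plugin needs very little code. A hypothetical skeleton (folder, class and field names are illustrative, and `BasePlugin.__init__` assumes a reachable Home Assistant instance):

```python
# plugins/example/plugin.py -- hypothetical skeleton, not part of this commit.
from pydantic import BaseModel, Field

from ..base_plugin import BasePlugin


class EchoInput(BaseModel):
    message: str = Field(..., description="Text to echo back")


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

    def tool_echo(self, input: EchoInput):
        """
        Echo a message back to the caller.
        """
        # The docstring above becomes the tool description; EchoInput's
        # fields become the JSON schema sent to the LLM.
        return input.message
```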
diff --git a/agent-api/plugins/calendar/plugin.py b/agent-api/plugins/calendar/plugin.py
new file mode 100644
index 0000000..7cf2fb6
--- /dev/null
+++ b/agent-api/plugins/calendar/plugin.py
@@ -0,0 +1,26 @@
+import asyncio
+import json
+
+from ..base_plugin import BasePlugin
+
+class Plugin(BasePlugin):
+    def __init__(self, config: dict) -> None:
+        super().__init__(config=config)
+
+    def tool_get_calendar_events(self, entity_id=None, days=3):
+        """
+        Get the user's calendar events.
+        """
+        if entity_id is None:
+            entity_id = self.config["plugins"]["calendar"]["default_calendar"]
+        response = asyncio.run(
+            self.homeassistant.send_command(
+                "call_service",
+                domain="calendar",
+                service="get_events",
+                target={"entity_id": entity_id},
+                service_data={"duration": {"days": days}},
+                return_response=True,
+            )
+        )
+        return json.dumps(response)
diff --git a/agent-api/plugins/homeassistant.py b/agent-api/plugins/homeassistant.py
new file mode 100644
index 0000000..150d451
--- /dev/null
+++ b/agent-api/plugins/homeassistant.py
@@ -0,0 +1,51 @@
+import json
+from typing import Any
+
+import requests
+from aiohttp import ClientSession
+from hass_client import HomeAssistantClient
+
+
+class HomeAssistant:
+
+    def __init__(self, config: dict) -> None:
+        self.config = config
+        self.ha_config = self.call_api("config")
+
+    def call_api(self, endpoint, payload=None):
+        """
+        Call the REST API
+        """
+        base_url = self.config["homeassistant"]["url"]
+        headers = {
+            "Authorization": f"Bearer {self.config['homeassistant']['token']}",
+            "content-type": "application/json",
+        }
+        if payload is None:
+            response = requests.get(f"{base_url}/api/{endpoint}", headers=headers)
+        else:
+            response = requests.post(f"{base_url}/api/{endpoint}", headers=headers, json=payload)
+        if response.status_code == 200:
+            try:
+                return response.json()
+            except Exception:
+                pass
+            try:
+                return json.loads(response.text.replace("'", '"').replace("None", "null"))
+            except Exception:
+                return response.text
+        else:
+            return {"status": "error", "message": response.text}
+
+    async def send_command(self, command: str, **kwargs: dict[str, Any]):
+        """
+        Send command using the WebSocket API
+        """
+        async with ClientSession() as session:
+            async with HomeAssistantClient(
+                self.config["homeassistant"]["url"] + "/api/websocket",
+                self.config["homeassistant"]["token"],
+                session,
+            ) as client:
+                response = await client.send_command(command, **kwargs)
+                return response
diff --git a/agent-api/plugins/lights/plugin.py b/agent-api/plugins/lights/plugin.py
new file mode 100644
index 0000000..64ade0a
--- /dev/null
+++ b/agent-api/plugins/lights/plugin.py
@@ -0,0 +1,73 @@
+import json
+
+from ..base_plugin import BasePlugin
+
+
+class Plugin(BasePlugin):
+    def __init__(self, config: dict) -> None:
+        super().__init__(config=config)
+        self.light_map = self.get_light_map()
+
+    def prompt(self):
+        prompt = "These are the lights available:\n" + json.dumps(
+            self.light_map, indent=2, ensure_ascii=False
+        )
+        return prompt
+
+    def tools(self):
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "control_light",
+                    "function_to_call": self.tool_control_light,
+                    "description": "Control lights in the user's home",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "entity_id": {
+                                "type": "string",
+                                "enum": self.available_lights(),
+                                "description": "What light entity to control",
+                            },
+                            "state": {
+                                "type": "string",
+                                "enum": ["on", "off"],
+                                "description": "on or off",
+                            },
+                        },
+                        "required": ["entity_id", "state"],
+                    },
+                },
+            }
+        ]
+        return tools
+
+    def tool_control_light(self, entity_id: str, state: str):
+        self.homeassistant.call_api(f"services/light/turn_{state}", payload={"entity_id": entity_id})
+        return json.dumps({"status": "success", "message": f"{entity_id} was turned {state}."})
+
+    def available_lights(self):
+        available_lights = []
+        for room in self.light_map:
+            for light in self.light_map[room]:
+                available_lights.append(light["entity_id"])
+        return available_lights
+
+    def get_light_map(self):
+        template = (
+            "{% for state in states.light %}\n"
+            + '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}", "room": "{{area_name(state.entity_id)}}"}\n'
+            + "{% endfor %}\n"
+        )
+        response = self.homeassistant.call_api("template", payload={"template": template})
+        light_map = {}
+        for item_str in response.split("\n\n"):
+            item_dict = json.loads(item_str)
+            if item_dict["room"] not in light_map:
+                light_map[item_dict["room"]] = []
+            light_map[item_dict["room"]].append(
+                {"entity_id": item_dict["entity_id"], "name": item_dict["name"]}
+            )
+
+        return light_map
diff --git a/agent-api/plugins/music/.cache b/agent-api/plugins/music/.cache
new file mode 100644
index 0000000..420d072
--- /dev/null
+++ b/agent-api/plugins/music/.cache
@@ -0,0 +1 @@
+{"access_token": "BQC4auYl6DX-_xbS0zDjbv2_OZGbbdvyWLDa7I6EBE91GhBAExs3TGmhBNCgwS3P0WJwkUyL-kRRH2x-80QpWaRdLhEm0Q7Tfm8vzsV1rRBToy87Vm1jJJX3S08bYQwD3UvxFlaOhGyyWoSa37JnPkmS9T8d2GTPeHne1HaYR5KEJeYu9stixoONiTjHidQJi_7hOozIqwUkTC7TFWBWezJfa_7GlkE", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "AQDSGqpEaLoyR_A-s78bFVHYE1F2PIbA_cfgrgMt7jZA-A3xmEo5V7GazlgB-okD0JHX-w_fVc-n0_b8nvvZUz5PT6iCd7jRHLVoRA-h6XbQKcSAUVwJGL2djOTfnC4wOIE", "scope": "user-library-read", "expires_at": 1736686521}
\ No newline at end of file
diff --git a/agent-api/plugins/music/plugin.py b/agent-api/plugins/music/plugin.py
new file mode 100644
index 0000000..6f5d939
--- /dev/null
+++ b/agent-api/plugins/music/plugin.py
@@ -0,0 +1,59 @@
+import json
+import logging
+
+import spotipy
+from pydantic import BaseModel, Field
+from spotipy.oauth2 import SpotifyOAuth
+
+from ..base_plugin import BasePlugin
+
+
+class PlayMusicInput(BaseModel):
+    query: str = Field(..., description="Can be a song, artist, album, or playlist")
+
+
+class Plugin(BasePlugin):
+    def __init__(self, config: dict) -> None:
+        super().__init__(config=config)
+
+        self.spotify = spotipy.Spotify(
+            auth_manager=SpotifyOAuth(scope="user-library-read", redirect_uri="http://localhost:8080")
+        )
+
+    def _search(self, query: str, limit: int = 10):
+        _result = self.spotify.search(query, limit=limit)
+        result = []
+        for track in _result["tracks"]["items"]:
+            artists = [artist["name"] for artist in track["artists"]]
+            result.append(
+                {
+                    "name": track["name"],
+                    "artists": artists,
+                    "uri": track["uri"]
+                }
+            )
+        return result
+
+    def tool_play_music(self, input: PlayMusicInput):
+        """
+        Play music using a search query.
+        """
+        track = self._search(input.query, limit=1)[0]
+        logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
+        payload = {
+            "entity_id": self.config["plugins"]["music"]["default_speaker"],
+            "media_content_id": track["uri"],
+            "media_content_type": "music",
+            "enqueue": "play",
+        }
+        self.homeassistant.call_api("services/media_player/play_media", payload=payload)
+        return json.dumps({"status": "success", "message": "Playing music.", "track": track})
+
+    def tool_stop_music(self):
+        """
+        Stop playback of music.
+        """
+        self.homeassistant.call_api(
+            "services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
+        )
+        return json.dumps({"status": "success", "message": "Music paused."})
diff --git a/agent-api/plugins/readme.md b/agent-api/plugins/readme.md
new file mode 100644
index 0000000..194fe33
--- /dev/null
+++ b/agent-api/plugins/readme.md
@@ -0,0 +1,3 @@
+# Plugins
+Each folder in /plugins is a plugin.
+Class files within /plugins (e.g. homeassistant.py) are helper classes that plugins can use so they don't have to duplicate code.
\ No newline at end of file
diff --git a/agent-api/plugins/smhi/plugin.py b/agent-api/plugins/smhi/plugin.py
new file mode 100644
index 0000000..72113d9
--- /dev/null
+++ b/agent-api/plugins/smhi/plugin.py
@@ -0,0 +1,74 @@
+import json
+
+from smhi.smhi_lib import Smhi
+
+from ..homeassistant import HomeAssistant
+
+
+class Plugin:
+    def __init__(self, config: dict) -> None:
+        self.config = config
+        self.homeassistant = HomeAssistant(config)
+        self.station = Smhi(
+            self.homeassistant.ha_config["longitude"], self.homeassistant.ha_config["latitude"]
+        )
+        self.weather_conditions = {
+            1: "Clear sky",
+            2: "Nearly clear sky",
+            3: "Variable cloudiness",
+            4: "Halfclear sky",
+            5: "Cloudy sky",
+            6: "Overcast",
+            7: "Fog",
+            8: "Light rain showers",
+            9: "Moderate rain showers",
+            10: "Heavy rain showers",
+            11: "Thunderstorm",
+            12: "Light sleet showers",
+            13: "Moderate sleet showers",
+            14: "Heavy sleet showers",
+            15: "Light snow showers",
+            16: "Moderate snow showers",
+            17: "Heavy snow showers",
+            18: "Light rain",
+            19: "Moderate rain",
+            20: "Heavy rain",
+            21: "Thunder",
+            22: "Light sleet",
+            23: "Moderate sleet",
+            24: "Heavy sleet",
+            25: "Light snowfall",
+            26: "Moderate snowfall",
+            27: "Heavy snowfall",
+        }
+
+    def prompt(self):
+        return None
+
+    def tools(self):
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather_forecast",
+                    "function_to_call": self.get_weather_forecast,
+                    "description": "Get weather forecast for the following 24 hours.",
+                },
+            }
+        ]
+        return tools
+
+    def get_weather_forecast(self, hours=24):
+        forecast = []
+        for hour in self.station.get_forecast_hour()[:hours]:
+            forecast.append(
+                {
+                    "time": hour.valid_time.strftime("%Y-%m-%d %H:%M"),
+                    "weather": self.weather_conditions[hour.symbol],
+                    "temperature": hour.temperature,
+                    # "cloudiness": hour.cloudiness,
+                    "total_precipitation": hour.total_precipitation,
+                    "wind_speed": hour.wind_speed,
+                }
+            )
+        return json.dumps(forecast)
diff --git a/agent-api/plugins/smhi/requirements.txt b/agent-api/plugins/smhi/requirements.txt
new file mode 100644
index 0000000..df77dbb
--- /dev/null
+++ b/agent-api/plugins/smhi/requirements.txt
@@ -0,0 +1 @@
+smhi-pkg>=1.0.16
\ No newline at end of file
diff --git a/agent-api/plugins/todo/plugin.py b/agent-api/plugins/todo/plugin.py
new file mode 100644
index 0000000..dafd5f2
--- /dev/null
+++ b/agent-api/plugins/todo/plugin.py
@@ -0,0 +1,107 @@
+import asyncio
+import json
+
+from ..base_plugin import BasePlugin
+
+class Plugin(BasePlugin):
+    def __init__(self, config: dict) -> None:
+        super().__init__(config=config)
+
+    def prompt(self):
+        prompt = "These are the todo lists available:\n" + json.dumps(
+            self.get_todo_lists(), indent=2, ensure_ascii=False
+        )
+        return prompt
+
+    def tools(self):
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_todo_list_items",
+                    "function_to_call": self.get_todo_list_items,
+                    "description": "Get items from todo-list.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "entity_id": {
+                                "type": "string",
+                                "description": "What list to get items from.",
+                            }
+                        },
+                        "required": ["entity_id"],
+                    },
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "add_todo_item",
+                    "function_to_call": self.add_todo_item,
+                    "description": "Add item to todo-list.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "entity_id": {
+                                "type": "string",
+                                "description": "What list to add item to.",
+                            },
+                            "item": {
+                                "type": "string",
+                                "description": "Description of the item.",
+                            },
+                        },
+                        "required": ["entity_id", "item"],
+                    },
+                },
+            },
+        ]
+        return tools
+
+    def get_todo_list_items(self, entity_id: str = None):
+        """
+        Get items from todo-list.
+        """
+        if entity_id is None:
+            entity_id = self.config["plugins"]["todo"]["default_list"]
+        response = asyncio.run(
+            self.homeassistant.send_command(
+                "call_service",
+                domain="todo",
+                service="get_items",
+                target={"entity_id": entity_id},
+                service_data={"status": "needs_action"},
+                return_response=True,
+            )
+        )
+        return json.dumps(response["response"])
+
+    def add_todo_item(self, item: str, entity_id: str = None):
+        """
+        Add item to todo-list.
+        """
+        if entity_id is None:
+            entity_id = self.config["plugins"]["todo"]["default_list"]
+        asyncio.run(
+            self.homeassistant.send_command(
+                "call_service",
+                domain="todo",
+                service="add_item",
+                target={"entity_id": entity_id},
+                service_data={"item": item},
+            )
+        )
+        return json.dumps({"status": f"{item} added to list."})
+
+    def get_todo_lists(self):
+        template = (
+            "{% for state in states.todo %}\n"
+            + '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}"}\n'
+            + "{% endfor %}\n"
+        )
+        response = self.homeassistant.call_api("template", payload={"template": template})
+        todo_lists = {}
+        for item_str in response.split("\n\n"):
+            item_dict = json.loads(item_str)
+            todo_lists[item_dict["name"]] = item_dict["entity_id"]
+        return todo_lists
"item": { + "type": "string", + "description": "Description of the item.", + }, + }, + "required": ["entity_id", "item"], + }, + }, + }, + ] + return tools + + def get_todo_list_items(self, entity_id: str = None): + """ + Get items from todo-list. + """ + if entity_id is None: + entity_id = self.config["plugins"]["todo"]["default_list"] + response = asyncio.run( + self.homeassistant.send_command( + "call_service", + domain="todo", + service="get_items", + target={"entity_id": entity_id}, + service_data={"status": "needs_action"}, + return_response=True, + ) + ) + return json.dumps(response["response"]) + + def add_todo_item(self, item: str, entity_id: str = None): + """ + Add item to todo-list. + """ + if entity_id is None: + entity_id = self.config["plugins"]["todo"]["default_list"] + asyncio.run( + self.homeassistant.send_command( + "call_service", + domain="todo", + service="add_item", + target={"entity_id": entity_id}, + service_data={"item": item}, + ) + ) + return json.dumps({"status": f"{item} added to list."}) + + def get_todo_lists(self): + template = ( + "{% for state in states.todo %}\n" + + '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}"}\n' + + "{% endfor %}\n" + ) + response = self.homeassistant.call_api(f"template", payload={"template": template}) + todo_lists = {} + for item_str in response.split("\n\n"): + item_dict = json.loads(item_str) + todo_lists[item_dict["name"]] = item_dict["entity_id"] + return todo_lists diff --git a/agent-api/plugins/vasttrafik/plugin.py b/agent-api/plugins/vasttrafik/plugin.py new file mode 100644 index 0000000..51fa29a --- /dev/null +++ b/agent-api/plugins/vasttrafik/plugin.py @@ -0,0 +1,98 @@ +import json +from datetime import datetime + +import vasttrafik + +from ..homeassistant import HomeAssistant + + +class Plugin: + def __init__(self, config: dict) -> None: + self.config = config + self.homeassistant = HomeAssistant(config) + self.vasttrafik = vasttrafik.JournyPlanner( + key=self.config["plugins"]["vasttrafik"]["key"], + secret=self.config["plugins"]["vasttrafik"]["secret"], + ) + self.default_from_station_id = self.get_station_id( + self.config["plugins"]["vasttrafik"]["default_from_station"] + ) + self.default_to_station_id = self.get_station_id( + self.config["plugins"]["vasttrafik"]["default_to_station"] + ) + + def prompt(self): + from_station = self.config["plugins"]["vasttrafik"]["default_from_station"] + to_station = self.config["plugins"]["vasttrafik"]["default_to_station"] + return "# Public transportation #\nHome station is: %s, and the default to station to use is %s" % ( + from_station, + to_station, + ) + + def tools(self): + tools = [ + { + "type": "function", + "function": { + "name": "search_trip", + "function_to_call": self.search_trip, + "description": "Search for trips by public transportation.", + "parameters": { + "type": "object", + "properties": { + "from_station": { + "type": "string", + "description": "Station to travel from.", + }, + "to_station": { + "type": "string", + "description": "Station to travel to.", + }, + }, + "required": ["from_station", "to_station"], + }, + }, + } + ] + return tools + + def search_trip(self, from_station=None, to_station=None): + if from_station is None: + from_station_id = self.default_from_station_id + else: + from_station_id = self.get_station_id(from_station) + if to_station is None: + to_station_id = self.default_to_station_id + else: + to_station_id = self.get_station_id(to_station) + trip = self.vasttrafik.trip(from_station_id, 
diff --git a/agent-api/readme.md b/agent-api/readme.md
new file mode 100644
index 0000000..fed8912
--- /dev/null
+++ b/agent-api/readme.md
@@ -0,0 +1,23 @@
+## Start postgres with pgvector
+```bash
+docker-compose up -d
+```
+
+Then create a new db called "rag", and run this SQL:
+
+```sql
+CREATE EXTENSION vector;
+```
+
+Finally run the content of db.sql.
+
+## Run backend
+```bash
+conda activate llm-api
+python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
+```
+
+## Run frontend
+```bash
+python frontend.py
+```
\ No newline at end of file
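For reference, the manual setup steps in the readme can be scripted roughly like this — assuming the psql client is installed locally and the docker-compose Postgres is listening on port 5433 with the default credentials:

```bash
# Sketch of the one-time database setup described in the readme above.
export PGPASSWORD=postgres
psql -h localhost -p 5433 -U postgres -c 'CREATE DATABASE rag;'
psql -h localhost -p 5433 -U postgres -d rag -c 'CREATE EXTENSION vector;'
psql -h localhost -p 5433 -U postgres -d rag -f db.sql
```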
diff --git a/agent-api/requirements.txt b/agent-api/requirements.txt
new file mode 100644
index 0000000..e6c1512
--- /dev/null
+++ b/agent-api/requirements.txt
@@ -0,0 +1,11 @@
+fastapi
+uvicorn[standard]
+openai>=1.11.1
+sentence_transformers
+dateparser
+pgvector
+psycopg2
+pyyaml
+gradio
+hass-client>=1.0.1
+llama_cpp_python>=0.2.44
diff --git a/agent-api/static/tada.wav b/agent-api/static/tada.wav
new file mode 100644
index 0000000..d74c3d3
Binary files /dev/null and b/agent-api/static/tada.wav differ
diff --git a/agent-api/templates/ask.en.j2 b/agent-api/templates/ask.en.j2
new file mode 100644
index 0000000..9fbe5c0
--- /dev/null
+++ b/agent-api/templates/ask.en.j2
@@ -0,0 +1,12 @@
+You are a helpful assistant. Your replies will be converted to speech, so keep it short and concise and avoid too many special characters.
+Today is {{ today }}.
+
+{% if knowledge is not none %}
+    {{ knowledge }}
+{% endif %}
+
+{% for prompt in plugin_prompts %}
+- {{ prompt }}
+{% endfor %}
+
+Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a Swedish question.
\ No newline at end of file
diff --git a/agent-api/templates/knowledge.en.j2 b/agent-api/templates/knowledge.en.j2
new file mode 100644
index 0000000..b53dcec
--- /dev/null
+++ b/agent-api/templates/knowledge.en.j2
@@ -0,0 +1,5 @@
+Here are some saved notes that might help you answer any question the user might have:
+{% for note in knowledge %}
+- {{ note }}
+{% endfor %}
+Never make notes of knowledge you already have (i.e. the list above).
\ No newline at end of file
diff --git a/esphome-va-bridge/audio.py b/esphome-va-bridge/audio.py
new file mode 100644
index 0000000..488afc2
--- /dev/null
+++ b/esphome-va-bridge/audio.py
@@ -0,0 +1,204 @@
+import datetime as dt
+import io
+from typing import Iterator
+
+import numpy as np
+import pydub
+import scipy.io.wavfile as wavfile
+import torch
+from elevenlabs.client import ElevenLabs
+from faster_whisper import WhisperModel
+from silero_vad import get_speech_timestamps, load_silero_vad
+
+
+class AudioSegment(pydub.AudioSegment):
+    """
+    Wrapper class for pydub.AudioSegment
+    """
+
+    def __init__(self, data=None, *args, **kwargs):
+        super().__init__(data, *args, **kwargs)
+
+    def to_ndarray(self) -> np.ndarray:
+        """
+        Convert the AudioSegment to a numpy array.
+        """
+        _buffer = io.BytesIO()
+        self.export(_buffer, format="wav")
+        _buffer.seek(0)
+
+        _, wav_data = wavfile.read(_buffer)
+        if wav_data.dtype != np.float32:
+            max_abs = max(np.abs(wav_data)) if max(np.abs(wav_data)) != 0 else 1
+            wav_data = wav_data.astype(np.float32) / max_abs
+        if len(wav_data.shape) == 2:
+            wav_data = wav_data.mean(axis=1)
+        return wav_data
+
+    def to_tensor(self) -> torch.Tensor:
+        """
+        Convert the AudioSegment to a PyTorch tensor.
+        """
+        return torch.tensor(self.to_ndarray(), dtype=torch.float32)
+
+
+class AudioBuffer:
+    """
+    Buffer for AudioSegments
+    """
+
+    def __init__(self, max_size: int) -> None:
+        self.data: AudioSegment | None = None
+        self.max_size: int = max_size
+
+    def put(self, audio_segment: AudioSegment) -> None:
+        """
+        Append an AudioSegment to the buffer.
+        """
+        if self.data:
+            self.data = self.data + audio_segment
+            if len(self.data) >= self.max_size:
+                self.data = self.data[-self.max_size :]
+        else:
+            self.data = audio_segment
+
+    def clear(self) -> None:
+        """
+        Clear the buffer.
+        """
+        self.data = None
+
+    def get(self) -> AudioSegment | None:
+        """
+        Get the AudioSegment from the buffer.
+        """
+        return self.data
+
+    def get_last(self, duration: int) -> AudioSegment | None:
+        """
+        Get the last `duration` milliseconds of the AudioSegment.
+        """
+        return self.data[-duration:] if self.data else None
+
+    def is_empty(self) -> bool:
+        """
+        Check if the buffer is empty.
+        """
+        return self.data is None
+
+    def length(self) -> int:
+        """
+        Get the length of the buffer.
+        """
+        return len(self.data) if self.data else 0
+
+
+class VAD:
+    """
+    Voice Activity Detection
+    """
+
+    def __init__(self) -> None:
+        self.model = load_silero_vad()
+        self.ts_voice_detected: dt.datetime | None = None
+
+    def detect_voice(self, audio_segment: AudioSegment) -> bool:
+        """
+        Detect voice in an AudioSegment.
+        """
+        speech_timestamps = get_speech_timestamps(
+            audio_segment.to_tensor(),
+            self.model,
+            return_seconds=True,
+        )
+        if len(speech_timestamps) > 0:
+            if self.ts_voice_detected is None:
+                pass  # print("VAD: Voice detected.")
+            self.ts_voice_detected = dt.datetime.now()
+            return True
+
+        return False
+
+    def ms_since_voice_detected(self) -> float | None:
+        """
+        Get the milliseconds since voice was detected.
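+        Returns None if no voice has been detected yet.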
+        """
+        if self.ts_voice_detected is None:
+            return None
+        return (dt.datetime.now() - self.ts_voice_detected).total_seconds() * 1000
+
+    def reset_timer(self):
+        """
+        Reset the timer for voice detected.
+        """
+        self.ts_voice_detected = None
+
+
+class STT:
+    """
+    Speech-to-Text
+    """
+
+    def __init__(self, model="medium", device="cuda" if torch.cuda.is_available() else "cpu", benchmark=False) -> None:
+        compute_type = "int8" if device == "cpu" else "float16"
+        self.model = WhisperModel(model, device=device, compute_type=compute_type)
+        self.benchmark = benchmark
+
+    def transcribe(self, audio_segment: AudioSegment, language="sv"):
+        """
+        Transcribe an AudioSegment.
+        """
+        if self.benchmark:
+            ts_start = dt.datetime.now()
+        segments, info = self.model.transcribe(audio_segment.to_ndarray(), language=language)
+        segments = list(segments)
+        text = " ".join([seg.text for seg in segments])
+        text = text.strip()
+        if self.benchmark:
+            ts_done = dt.datetime.now()
+            delta_ms = (ts_done - ts_start).total_seconds() * 1000
+            print(f"[Benchmark] STT: {delta_ms:.2f} ms.")
+        return text
+
+
+class TTS:
+    """
+    Text-to-Speech (TTS) class that uses the ElevenLabs API.
+    """
+
+    def __init__(self, voice="Brian", model="eleven_multilingual_v2") -> None:
+        self.client = ElevenLabs()
+        self.voice = voice
+        self.model = model
+
+    def generate(self, text: str, voice: str = "default") -> AudioSegment:
+        if voice == "default":
+            voice = self.voice
+        audio = self.client.generate(
+            text=text,
+            voice=voice,
+            model=self.model,
+            stream=False,
+            output_format="pcm_16000",
+            optimize_streaming_latency=2,
+        )
+        audio = b"".join(audio)
+        audio_segment = AudioSegment(data=audio, sample_width=2, frame_rate=16000, channels=1)
+        return audio_segment
+
+    def stream(self, text_stream: Iterator[str], voice: str = "default") -> Iterator[bytes]:
+        if voice == "default":
+            voice = self.voice
+        audio_stream = self.client.generate(
+            text=text_stream,
+            voice=voice,
+            model=self.model,
+            stream=True,
+            output_format="pcm_16000",
+            optimize_streaming_latency=2,
+        )
+
+        for chunk in audio_stream:
+            if chunk is not None:
+                yield chunk
diff --git a/esphome-va-bridge/esp-working.log b/esphome-va-bridge/esp-working.log
new file mode 100644
index 0000000..9cd5a57
--- /dev/null
+++ b/esphome-va-bridge/esp-working.log
@@ -0,0 +1,113 @@
+[16:22:06.368][D][i2s_audio.microphone:377]: Starting I2S Audio Microphne
+[16:22:06.368][D][i2s_audio.microphone:381]: Started I2S Audio Microphone
+
+[16:22:06.368][D][micro_wake_word:417]: State changed from IDLE to DETECTING_WAKE_WORD
+[16:22:08.878][D][micro_wake_word:355]: Detected 'Okay Nabu' with sliding average probability is 0.97 and max probability is 1.00
+
+[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
+[16:22:08.878][D][media_player:084]: Command: STOP
+[16:22:08.878][D][media_player:093]: Announcement: yes
+[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
+[16:22:08.878][D][media_player:093]: Announcement: yes
+
+[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
+[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
+[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536
+[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536
+
+[16:22:08.879][D][nabu_media_player.pipeline:173]: Reading FLAC file type
+[16:22:08.879][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
+[16:22:08.879][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio
+
+[16:22:08.879][D][ring_buffer:034][speaker_task]: Created ring buffer with size 19200
+
+[16:22:08.879][D][i2s_audio.speaker:111]: Starting Speaker
+[16:22:08.880][D][i2s_audio.speaker:116]: Started Speaker
+
+[16:22:08.880][D][voice_assistant:515]: State changed from IDLE to START_MICROPHONE
+[16:22:08.880][D][voice_assistant:522]: Desired state set to START_PIPELINE
+[16:22:08.880][D][voice_assistant:225]: Starting Microphone
+
+[16:22:08.880][D][ring_buffer:034]: Created ring buffer with size 16384
+
+[16:22:08.880][D][voice_assistant:515]: State changed from START_MICROPHONE to STARTING_MICROPHONE
+[16:22:08.880][D][voice_assistant:515]: State changed from STARTING_MICROPHONE to START_PIPELINE
+[16:22:08.880][D][voice_assistant:280]: Requesting start...
+[16:22:08.880][D][voice_assistant:515]: State changed from START_PIPELINE to STARTING_PIPELINE
+[16:22:08.880][D][voice_assistant:537]: Client started, streaming microphone
+[16:22:08.881][D][voice_assistant:515]: State changed from STARTING_PIPELINE to STREAMING_MICROPHONE
+[16:22:08.881][D][voice_assistant:522]: Desired state set to STREAMING_MICROPHONE
+[16:22:08.881][D][voice_assistant:641]: Event Type: 1 (VOICE_ASSISTANT_RUN_START)
+[16:22:08.881][D][voice_assistant:644]: Assist Pipeline running
+[16:22:08.881][D][voice_assistant:641]: Event Type: 3 (VOICE_ASSISTANT_STT_START)
+[16:22:08.881][D][voice_assistant:655]: STT started
+
+[16:22:08.881][D][light:036]: 'voice_assistant_leds' Setting:
+[16:22:08.881][D][light:047]: State: ON
+[16:22:08.881][D][light:051]: Brightness: 66%
+[16:22:08.882][D][light:109]: Effect: 'Waiting for Command'
+
+[16:22:08.882][D][power_supply:033]: Enabling power supply.
+
+[16:22:09.671][D][voice_assistant:641]: Event Type: 11 (VOICE_ASSISTANT_STT_VAD_START)
+[16:22:09.671][D][voice_assistant:804]: Starting STT by VAD
+
+[16:22:09.671][D][light:036]: 'voice_assistant_leds' Setting:
+[16:22:09.671][D][light:051]: Brightness: 66%
+[16:22:09.671][D][light:109]: Effect: 'Listening For Command'
+
+[16:22:14.806][D][voice_assistant:641]: Event Type: 12 (VOICE_ASSISTANT_STT_VAD_END)
+[16:22:14.806][D][voice_assistant:808]: STT by VAD end
+[16:22:14.806][D][voice_assistant:515]: State changed from STREAMING_MICROPHONE to STOP_MICROPHONE
+[16:22:14.806][D][voice_assistant:522]: Desired state set to AWAITING_RESPONSE
+[16:22:14.806][D][voice_assistant:515]: State changed from STOP_MICROPHONE to STOPPING_MICROPHONE
+
+[16:22:14.807][D][light:036]: 'voice_assistant_leds' Setting:
+[16:22:14.807][D][light:051]: Brightness: 66%
+[16:22:14.807][D][light:109]: Effect: 'Thinking'
+
+[16:22:14.807][D][voice_assistant:515]: State changed from STOPPING_MICROPHONE to AWAITING_RESPONSE
+[16:22:14.807][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to AWAITING_RESPONSE
+
+[16:22:14.807][D][power_supply:033]: Enabling power supply.
+[16:22:14.807][D][power_supply:033]: Enabling power supply.
+
+[16:22:14.807][D][voice_assistant:641]: Event Type: 4 (VOICE_ASSISTANT_STT_END)
+[16:22:14.807][D][voice_assistant:669]: Speech recognised as: " Who are you?"
+[16:22:14.807][D][voice_assistant:641]: Event Type: 5 (VOICE_ASSISTANT_INTENT_START)
+[16:22:14.808][D][voice_assistant:674]: Intent started
+
+[16:22:14.808][D][power_supply:033]: Enabling power supply.
+[16:22:14.808][D][power_supply:033]: Enabling power supply.
+[16:22:14.808][D][power_supply:033]: Enabling power supply.
+[16:22:14.808][D][power_supply:033]: Enabling power supply.
+[16:22:14.808][D][power_supply:033]: Enabling power supply.
+
+[16:22:14.808][D][voice_assistant:641]: Event Type: 6 (VOICE_ASSISTANT_INTENT_END)
+[16:22:14.808][D][voice_assistant:641]: Event Type: 7 (VOICE_ASSISTANT_TTS_START)
+[16:22:14.808][D][voice_assistant:697]: Response: "I am an AI assistant designed to help you with various tasks, answer questions, and provide information. If there's anything specific you need help with, feel free to ask!"
+
+[16:22:14.809][D][light:036]: 'voice_assistant_leds' Setting:
+[16:22:14.809][D][light:051]: Brightness: 66%
+[16:22:14.809][D][light:109]: Effect: 'Replying'
+
+[16:22:14.809][D][voice_assistant:641]: Event Type: 8 (VOICE_ASSISTANT_TTS_END)
+[16:22:14.809][D][voice_assistant:719]: Response URL: "https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac"
+[16:22:14.809][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to STREAMING_RESPONSE
+[16:22:14.809][D][voice_assistant:522]: Desired state set to STREAMING_RESPONSE
+
+[16:22:14.809][D][media_player:080]: 'Media Player' - Setting
+[16:22:14.809][D][media_player:087]: Media URL: https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac
+[16:22:14.810][D][media_player:093]: Announcement: yes
+
+[16:22:14.810][D][voice_assistant:641]: Event Type: 2 (VOICE_ASSISTANT_RUN_END)
+[16:22:14.810][D][voice_assistant:733]: Assist Pipeline ended
+
+[16:22:14.810][D][esp-idf:000][ann_read]: I (40514) esp-x509-crt-bundle: Certificate validated
+
+[16:22:14.810][D][nabu_media_player.pipeline:173]: Reading FLAC file type
+[16:22:14.810][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
+[16:22:14.810][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio
+
+[16:22:26.112][D][voice_assistant:515]: State changed from STREAMING_RESPONSE to IDLE
+[16:22:26.112][D][voice_assistant:522]: Desired state set to IDLE
\ No newline at end of file
diff --git a/esphome-va-bridge/esphome_server.py b/esphome-va-bridge/esphome_server.py
new file mode 100644
index 0000000..d5431c1
--- /dev/null
+++ b/esphome-va-bridge/esphome_server.py
@@ -0,0 +1,288 @@
+import asyncio
+import hashlib
+import logging
+import math
+import os
+import wave
+from enum import StrEnum
+import traceback
+
+import aioesphomeapi
+from aioesphomeapi import (VoiceAssistantAudioSettings,
+                           VoiceAssistantCommandFlag, VoiceAssistantEventType)
+from aiohttp import web
+
+from audio import STT, TTS, VAD, AudioBuffer, AudioSegment
+from llm_api import LlmApi
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+)
+routes = web.RouteTableDef()
+
+
+class PipelineStage(StrEnum):
+    """Stages of a pipeline."""
+
+    WAKE_WORD = "wake_word"
+    STT = "stt"
+    INTENT = "intent"
+    TTS = "tts"
+    END = "end"
+
+
+class Satellite:
+
+    def __init__(self, host="192.168.10.155", port=6053, password=""):
+        self.client = aioesphomeapi.APIClient(host, port, password)
+        self.audio_queue: AudioBuffer = AudioBuffer(max_size=60000)
+        self.state = "idle"
+        self.vad = VAD()
+        self.stt = STT(device="cuda")
+        self.tts = TTS()
+        self.agent = LlmApi()
+
+    async def connect(self):
+        await self.client.connect(login=True)
+
+    async def disconnect(self):
+        await self.client.disconnect()
+
+    async def handle_pipeline_start(
+        self,
+        conversation_id: str,
+        flags: int,
+        audio_settings: VoiceAssistantAudioSettings,
+        wake_word_phrase: str | None,
+    ) -> int:
+        logging.debug(
+            f"Pipeline starting with conversation_id={conversation_id}, flags={flags}, audio_settings={audio_settings}, wake_word_phrase={wake_word_phrase}"
+        )
+
+        # Device triggered pipeline (wake word, etc.)
+        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
+            start_stage = PipelineStage.WAKE_WORD
+        else:
+            start_stage = PipelineStage.STT
+
+        end_stage = PipelineStage.TTS
+
+        logging.info(f"Starting pipeline from {start_stage} to {end_stage}")
+        self.state = "running_pipeline"
+
+        # Run pipeline
+        asyncio.create_task(self.run_pipeline(wake_word_phrase))
+
+        return 0
+
+    async def run_pipeline(self, wake_word_phrase: str | None = None):
+        logging.info(f"Pipeline started using the wake word '{wake_word_phrase}'.")
+
+        try:
+            logging.debug(" > STT start")
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None)
+
+            logging.debug(" > STT VAD start")
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START, None)
+
+            # VAD
+            logging.debug(" > VAD: Waiting for silence...")
+            VAD_TIMEOUT = 1000
+            voice_detected = False
+            _voice_detected = False
+            while True:
+                if self.audio_queue.length() > 20000:
+                    break
+                elif self.audio_queue.length() >= VAD_TIMEOUT:
+                    voice_detected = self.vad.detect_voice(self.audio_queue.get_last(VAD_TIMEOUT))
+                    if voice_detected != _voice_detected:
+                        _voice_detected = voice_detected
+                        if not _voice_detected:
+                            logging.debug(" > VAD: Silence detected.")
+                            break
+                if self.audio_queue.length() > 2000 and not voice_detected:
+                    logging.debug(" > VAD: Silence detected (no voice).")
+                    break
+                await asyncio.sleep(0.1)
+
+            logging.debug(" > STT VAD end")
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END, None)
+
+            # STT
+            text = self.stt.transcribe(self.audio_queue.get())
+            self.audio_queue.clear()
+            logging.info(f" > STT: {text}")
+
+            logging.debug(" > STT end")
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": text})
+
+            logging.debug(" > Intent start")
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None)
+            agent_response = self.agent.ask(text)["answer"]
+            logging.info(f" > Intent: {agent_response}")
+
+            logging.debug(" > Intent end")
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, None)
+
+            logging.debug(" > TTS start")
+            tts_mode = "announce"
+            if tts_mode == "announce":
+                self.client.send_voice_assistant_event(
+                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": agent_response}
+                )
+
+                voice = "default"
+                audio_hash_string = hashlib.md5((voice + ":" + agent_response).encode("utf-8")).hexdigest()
+                audio_file_path = f"static/{audio_hash_string}.wav"
+                if not os.path.exists(audio_file_path):
+                    logging.debug(" > TTS: Generating audio...")
+                    audio_segment = self.tts.generate(agent_response, voice=voice)
+                else:
+                    logging.debug(" > TTS: Using cached audio.")
+                    audio_segment = AudioSegment.from_file(audio_file_path)
+                audio_segment.export(audio_file_path, format="wav")
+                self.client.send_voice_assistant_event(
+                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
+                    {"url": f"http://192.168.10.111:8002/{audio_file_path}"},
+                )
+            elif tts_mode == "stream":
+                await self._stream_tts_audio("debug")
+            logging.debug(" > TTS end")
+
+        except Exception as e:
+            logging.error("Error: ")
+            logging.error(e)
+            logging.error(traceback.format_exc())
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, {"message": str(e)})
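+
+        # RUN_END is sent outside the try/except so the device always leaves
+        # the pipeline state, even when an error occurs above.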
+        logging.debug(" > Run end")
+        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END, None)
+        logging.info("Pipeline done")
+
+    async def handle_pipeline_stop(self, abort: bool):
+        logging.info(f"Pipeline stopped (abort={abort})")
+        self.state = "idle"
+
+    async def handle_audio(self, data: bytes) -> None:
+        """Handle incoming audio chunk from API."""
+        self.audio_queue.put(AudioSegment(data=data, sample_width=2, frame_rate=16000, channels=1))
+
+    async def handle_announcement_finished(self, event: aioesphomeapi.VoiceAssistantAnnounceFinished):
+        if event.success:
+            logging.info("Announcement finished successfully.")
+        else:
+            logging.error("Announcement failed.")
+
+    def handle_state_change(self, state):
+        logging.debug("State change:")
+        logging.debug(state)
+
+    def handle_log(self, log):
+        logging.info("Log from device: %s", log)
+
+    def subscribe(self):
+        """Register all device callbacks with the ESPHome API client."""
+        self.client.subscribe_states(self.handle_state_change)
+        self.client.subscribe_logs(self.handle_log)
+        self.client.subscribe_voice_assistant(
+            handle_start=self.handle_pipeline_start,
+            handle_stop=self.handle_pipeline_stop,
+            handle_audio=self.handle_audio,
+            handle_announcement_finished=self.handle_announcement_finished,
+        )
+
+    async def _stream_tts_audio(
+        self,
+        media_id: str,
+        sample_rate: int = 16000,
+        sample_width: int = 2,
+        sample_channels: int = 1,
+        samples_per_chunk: int = 1024,
+    ) -> None:
+        """Stream TTS audio chunks to device via API or UDP."""
+
+        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
+        logging.info("TTS stream start")
+        await asyncio.sleep(1)  # brief pause before the first chunk
+
+        try:
+            # if not self._is_running:
+            #     return
+
+            with wave.open(f"{media_id}.wav", "rb") as wav_file:
+                if (
+                    (wav_file.getframerate() != sample_rate)
+                    or (wav_file.getsampwidth() != sample_width)
+                    or (wav_file.getnchannels() != sample_channels)
+                ):
+                    logging.info(f"Can only stream 16 kHz 16-bit mono WAV, got {wav_file.getparams()}")
+                    return
+
+                logging.info("Streaming %s audio samples", wav_file.getnframes())
+
+                width = wav_file.getsampwidth()
+                channels = wav_file.getnchannels()
+
+                audio_bytes = wav_file.readframes(wav_file.getnframes())
+                bytes_per_sample = width * channels
+                bytes_per_chunk = bytes_per_sample * samples_per_chunk
+                num_chunks = int(math.ceil(len(audio_bytes) / bytes_per_chunk))
+
+                # Split into chunks
+                for i in range(num_chunks):
+                    offset = i * bytes_per_chunk
+                    chunk = audio_bytes[offset : offset + bytes_per_chunk]
+                    self.client.send_voice_assistant_audio(chunk)
+
+                    # Wait for 90% of the duration of the audio that was
+                    # sent for it to be played. This will overrun the
+                    # device's buffer for very long audio, so using a media
+                    # player is preferred.
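+                    # Example: a 2048-byte chunk of 16-bit mono audio at 16 kHz
+                    # is 1024 samples = 64 ms, so we sleep ~57.6 ms before the
+                    # next chunk.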
+                    samples_in_chunk = len(chunk) // (sample_width * sample_channels)
+                    seconds_in_chunk = samples_in_chunk / sample_rate
+                    logging.debug(f"Samples in chunk: {samples_in_chunk}, seconds in chunk: {seconds_in_chunk}")
+                    await asyncio.sleep(seconds_in_chunk * 0.9)
+
+        except asyncio.CancelledError:
+            return  # Don't trigger state change
+        finally:
+            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {})
+            logging.info("TTS stream end")
+
+
+async def start_http_server(port=8002):
+    """Serve the static/ folder over HTTP so the device can fetch TTS audio."""
+    app = web.Application()
+    app.router.add_static("/static/", path=os.path.join(os.path.dirname(__file__), "static"))
+    runner = web.AppRunner(app)
+    await runner.setup()
+    site = web.TCPSite(runner, "0.0.0.0", port)
+    await site.start()
+    logging.info(f"HTTP server started at http://0.0.0.0:{port}/")
+
+
+async def main():
+    """Connect to an ESPHome device and wait for state changes."""
+    await start_http_server()
+
+    satellite = Satellite()
+    await satellite.connect()
+    logging.info("Connected to ESPHome voice assistant.")
+    satellite.subscribe()
+    satellite.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START, None)
+    logging.info("Listening for events...")
+
+    # Keep the event loop alive; all work happens in callbacks
+    while True:
+        await asyncio.sleep(0.1)
+
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        pass
diff --git a/esphome-va-bridge/llm_api.py b/esphome-va-bridge/llm_api.py
new file mode 100644
index 0000000..69b3149
--- /dev/null
+++ b/esphome-va-bridge/llm_api.py
@@ -0,0 +1,14 @@
+import requests
+
+
+class LlmApi:
+    def __init__(self, host="127.0.0.1", port=8000):
+        self.host = host
+        self.port = port
+        self.url = f"http://{host}:{port}"
+
+    def ask(self, question):
+        """POST a question to the agent API's /ask endpoint and return its JSON response."""
+        url = self.url + "/ask"
+        response = requests.post(url, json={"question": question})
+        return response.json()
diff --git a/esphome-va-bridge/log_reader.py b/esphome-va-bridge/log_reader.py
new file mode 100644
index 0000000..8b899e1
--- /dev/null
+++ b/esphome-va-bridge/log_reader.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+# Helper script using aioesphomeapi to view logs from an ESPHome device
+import argparse
+import asyncio
+import logging
+import sys
+from datetime import datetime
+
+from aioesphomeapi.api_pb2 import SubscribeLogsResponse  # type: ignore
+from aioesphomeapi.client import APIClient
+from aioesphomeapi.log_runner import async_run
+
+
+async def main(argv: list[str]) -> None:
+    parser = argparse.ArgumentParser("aioesphomeapi-logs")
+    parser.add_argument("--port", type=int, default=6053)
+    parser.add_argument("--password", type=str)
+    parser.add_argument("--noise-psk", type=str)
+    parser.add_argument("-v", "--verbose", action="store_true")
+    parser.add_argument("address")
+    args = parser.parse_args(argv[1:])
+
+    logging.basicConfig(
+        format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
+        level=logging.DEBUG if args.verbose else logging.INFO,
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    cli = APIClient(
+        args.address,
+        args.port,
+        args.password or "",
+        noise_psk=args.noise_psk,
+        keepalive=10,
+    )
+
+    def on_log(msg: SubscribeLogsResponse) -> None:
+        time_ = datetime.now()
+        message: bytes = msg.message
+        text = message.decode("utf8", "backslashreplace")
+        milliseconds = time_.microsecond // 1000
+        print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{milliseconds:03}]{text}")
+
+    stop = await async_run(cli, on_log)
+    try:
+        await asyncio.Event().wait()
+    finally:
+        await stop()
+
+
+def cli_entry_point() -> None:
+    """Run the CLI."""
+    try:
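+        # Stream logs until interrupted; Ctrl+C exits cleanly via KeyboardInterrupt.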
+        asyncio.run(main(sys.argv))
+    except KeyboardInterrupt:
+        pass
+
+
+if __name__ == "__main__":
+    cli_entry_point()
+    sys.exit(0)
diff --git a/external/home-assistant-voice-pe b/external/home-assistant-voice-pe
new file mode 160000
index 0000000..4c09d89
--- /dev/null
+++ b/external/home-assistant-voice-pe
@@ -0,0 +1 @@
+Subproject commit 4c09d89ddbf234546f244bf98b1053a2f89e5ab6
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..e34796e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 120
\ No newline at end of file