commit 7b45d19308
init

13  .gitignore  vendored  Normal file
@@ -0,0 +1,13 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

config.yml
debug.py
notebook.ipynb
.vscode/
secrets.yaml
1  agent-api/.cache  Normal file
@@ -0,0 +1 @@
{"access_token": "BQDSGB6rh1-6-bc7wxFTZ2_Idjr6W-xLSZEMizMF6wNXew6_2M4DyWeSaVovZZLPdVv0aQOU8LqIQqQSNr5OwNXDpRk6icpcC6HKYGaZ7I_U9PRWbIOWE-QUakMcAmMFTWwO3FJEBRDYARR43cWTeRn8w9MHG48V5awv2A1447yz8kd4cUk8lAOTbcCeMnpxjJJQqqaigIWAy8NF7ghJXr4dibKj2JtTP5yECA", "token_type": "Bearer", "expires_in": 3600, "scope": "user-library-read", "expires_at": 1736781939, "refresh_token": "AQBaAMOoUVzcichmSIgDC4_CvkXXNNwfjnSbwpypzuaaCQqcQw-_2VP6Rkayp08Y-5gSEbEGVWc0-1CP4-JmQI2kQ_0UmJT5g2G3LvOjufSYowzXRxeedIOWcwbRvjSA1XI"}
161  agent-api/agent.py  Normal file
@@ -0,0 +1,161 @@
import datetime
import json
from jinja2 import Template
import logging
from typing import Union

from db import Database
from llm import LLM


def relative_to_actual_dates(text: str):
    from dateparser.search import search_dates

    # ISO 639-1 codes: English and Swedish ("sv"; "se" is Northern Sami)
    matches = search_dates(text, languages=["en", "sv"])

    if matches:
        for match in matches:
            text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
    return text


class Agent:
    def __init__(self, config: dict) -> None:
        self.config = config
        self.db = Database("rag")
        self.llm = LLM(config=self.config["llm"])

    def flush(self):
        """
        Flushes the agent's knowledge.
        """
        logging.info("Truncating database")
        self.db.cur.execute("TRUNCATE TABLE items")
        self.db.conn.commit()

    def insert(self, text):
        """
        Append knowledge.
        """
        logging.info(f"Inserting item into embedding table ({text}).")

        # logging.info("\tReplacing relative dates with actual dates.")
        # text = relative_to_actual_dates(text)
        # logging.info(f"\tDone. text: {text}")

        # local embeddings come back as a numpy array; make it a list before padding
        vector = list(self.llm.embedding(text))
        vector_padded = vector + [0] * (2048 - len(vector))

        self.db.cur.execute(
            "INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
            (text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
        )
        self.db.conn.commit()

    def keywords(self, query: str, n_keywords: int = 5):
        """
        Suggest keywords related to a query.
        """
        sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
        prompt = (
            f"Provide a json list of {n_keywords} words or nouns that represent the following prompt in a vector database query: {query}. "
            + f"Do not add any context or text other than the json list of {n_keywords} words."
        )
        response = self.llm.query(prompt, system_msg=sys_msg)
        keywords = json.loads(response)
        return keywords

    def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
        """
        Retrieve relevant knowledge.

        Parameters:
        - n_retrieve (int): How many notes to retrieve.
        - n_rerank (int): How many notes to keep after reranking.
        """
        logging.info(f"Retrieving knowledge from database using the query: '{query}'")

        logging.debug(f"Using embedding model on '{query}'.")
        vector_keywords = list(self.llm.embedding(query))
        logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))

        logging.debug("Querying database")
        vector_padded = vector_keywords + [0] * (2048 - len(vector_keywords))
        self.db.cur.execute(
            "SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT %s",
            (vector_padded, n_retrieve),
        )
        db_response = self.db.cur.fetchall()
        knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
        logging.debug("Database returned: ")
        logging.debug(knowledge_arr)

        if n_retrieve > n_rerank:
            logging.info("Reranking the results")
            reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
            reranked_knowledge_arr = sorted(
                knowledge_arr,
                key=lambda x: (
                    reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
                ),
            )
            reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
            logging.debug("Reranked results: ")
            logging.debug(reranked_knowledge_arr)
            return reranked_knowledge_arr
        else:
            return knowledge_arr

    def ask(
        self,
        question: str,
        person: str,
        history: list = None,
        n_retrieve: int = 5,
        n_rerank=3,
        tools: Union[None, list] = None,
        prompts: Union[None, list] = None,
    ):
        """
        Ask the agent a question.

        Parameters:
        - question (str): The question to ask.
        - person (str): Who asks the question?
        - history (list): Optional message history.
        - n_retrieve (int): How many notes to retrieve.
        - n_rerank (int): How many notes to keep after reranking.
        - tools (list): Optional tool definitions passed through to the LLM.
        - prompts (list): Optional plugin prompts injected into the system prompt.
        """
        logging.info(f"Answering question: {question}")

        if n_retrieve > 0:
            knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)

            with open("templates/knowledge.en.j2", "r") as file:
                template = Template(file.read())

            knowledge_str = template.render(
                knowledge=knowledge,
            )
        else:
            knowledge_str = ""

        with open("templates/ask.en.j2", "r") as file:
            template = Template(file.read())

        prompt = template.render(
            today=datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
            knowledge=knowledge_str,
            plugin_prompts=prompts,
            person=person,
        )

        logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
        answer = self.llm.query(
            question,
            system_msg=prompt,
            tools=tools,
            history=history,
        )
        logging.info(f"Got answer from LLM: '{answer}'.")

        return answer
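For illustration, here is a short sketch of driving the `Agent` class above directly, outside FastAPI, mirroring how backend.py uses it. It assumes a filled-in `config.yml`, the pgvector database from docker-compose.yml running, and the `agent-api` directory as working directory (the templates are opened with relative paths).

```python
# Sketch: using Agent directly, mirroring backend.py.
import yaml

from agent import Agent

with open("config.yml", "r", encoding="utf8") as file:
    config = yaml.safe_load(file)

agent = Agent(config)
agent.insert("John likes to play volleyball")
print(agent.retrieve("What does John like?", n_retrieve=10, n_rerank=5))
print(agent.ask("What does John like to play?", person="User"))
```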
127  agent-api/backend.py  Normal file
@@ -0,0 +1,127 @@
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import importlib
import logging
import os
from pydantic import BaseModel
import yaml

from agent import Agent

PLUGINS_FOLDER = "plugins"

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

with open("config.yml", "r", encoding="utf8") as file:
    config = yaml.safe_load(file)


def load_plugins(config):
    tools = []
    prompts = []

    for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
        if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder:
            logging.info(f"Loading plugin: {plugin_folder}")
            module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
            module = importlib.import_module(module_path)
            plugin = module.Plugin(config)
            tools += plugin.tools()
            prompt = plugin.prompt()
            if prompt is not None:
                prompts.append(prompt)

    return tools, prompts


tools, prompts = load_plugins(config)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")

agent = Agent(config)


class Query(BaseModel):
    query: str


class Message(BaseModel):
    role: str
    content: str


class Question(BaseModel):
    question: str
    person: str = "User"
    history: list[Message] | None = None


@app.get("/ping")
def ping():
    """
    # Ping
    Check if the API is up.
    """
    return {"status": "ok"}


@app.post("/ask")
def ask(q: Question):
    """
    # Ask
    Ask the agent a question.

    ## Parameters
    - **question**: the question to ask.
    - **person**: who is asking the question?
    """
    history = None
    if q.history:
        history = [msg.model_dump() for msg in q.history]

    answer = agent.ask(
        question=q.question,
        person=q.person,
        history=history,
        tools=tools,
        prompts=prompts,
        n_retrieve=0,
        n_rerank=0,
    )
    return {"question": q.question, "answer": answer, "history": q.history}


@app.post("/retrieve")
def retrieve(q: Query):
    """
    # Retrieve
    Retrieve knowledge related to a query.

    ## Parameters
    - **query**: the question.
    """
    result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5)
    return {"query": q.query, "answer": result}


@app.post("/learn")
def learn(q: Query):
    """
    # Learn
    Teach the agent something. E.g. *"John likes to play volleyball".*

    ## Parameters
    - **query**: the note.
    """
    agent.insert(text=q.query)
    return {"text": q.query}


@app.get("/flush")
def flush():
    """
    # Flush
    Remove all knowledge from the agent.
    """
    agent.flush()
    return {"message": "Knowledge flushed."}
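The endpoints above can be exercised with a small client along these lines — a sketch, assuming the backend is running on the default server URL from config.default.yml; the notes and questions are illustrative.

```python
# Sketch: exercising the backend endpoints with requests.
import requests

BASE_URL = "http://127.0.0.1:8000"

print(requests.get(f"{BASE_URL}/ping").json())  # {"status": "ok"}

# Teach the agent a note, then retrieve it
requests.post(f"{BASE_URL}/learn", json={"query": "John likes to play volleyball"})
print(requests.post(f"{BASE_URL}/retrieve", json={"query": "What does John like?"}).json())

# Ask a question, passing a short history
payload = {
    "question": "What does John like to play?",
    "person": "User",
    "history": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ],
}
print(requests.post(f"{BASE_URL}/ask", json=payload).json())
```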
25  agent-api/config.default.yml  Normal file
@@ -0,0 +1,25 @@
server:
  url: "http://127.0.0.1:8000"
llm:
  use_local_chat: false
  use_local_embedding: false
  local_model_dir: ""
  local_embedding_model: "intfloat/e5-large-v2"
  local_rerank_model: "cross-encoder/stsb-distilroberta-base"
  temperature: 0.1
  # llm.py reads the OpenAI settings from a nested "openai" block
  openai:
    base_url: ""
    embedding_model: "text-embedding-3-small"
    chat_model: "gpt-3.5-turbo-0125"
homeassistant:
  url: "http://localhost:8123"
  token: ""
plugins:
  calendar:
    default_calendar: "calendar.my_calendar"
  todo:
    default_list: "todo.todo"
  music:
    default_speaker: ""  # media_player entity used by plugins/music (placeholder)
  vasttrafik:
    key: ""
    secret: ""
    default_from_station: "Brunnsparken"
    default_to_station: "Centralstationen"
    delay: 0 # minutes
32  agent-api/db.py  Normal file
@@ -0,0 +1,32 @@
import logging
from pgvector.psycopg2 import register_vector
import psycopg2


class Database:
    def __init__(
        self,
        database,
        user="postgres",
        password="postgres",
        host="localhost",
        port="5433",
    ) -> None:
        logging.info("Connecting to database")
        self.conn = psycopg2.connect(
            database=database,
            user=user,
            password=password,
            host=host,
            port=port,
        )
        register_vector(self.conn)
        self.cur = self.conn.cursor()
        self.cur.execute("SELECT version();")
        logging.info(" DB Version: %s", self.cur.fetchone()[0])
        logging.info(" psycopg2 Version: %s", psycopg2.__version__)

    def close(self):
        logging.info("Closing connection to database")
        self.cur.close()
        self.conn.close()
7  agent-api/db.sql  Normal file
@@ -0,0 +1,7 @@
CREATE TABLE public.items (
    id bigserial NOT NULL,
    embedding vector(2048) NULL,
    "text" varchar(1024) NULL,
    date_added date NULL,
    CONSTRAINT items_pkey PRIMARY KEY (id)
);
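To make the schema concrete: agent.py stores embeddings of varying native length in this fixed `vector(2048)` column by zero-padding on the right. A sketch of that round trip, mirroring the queries in agent.py; the three-element embedding is a stand-in for a real model output.

```python
# Sketch of the pad-and-query round trip against the items table.
import psycopg2
from pgvector.psycopg2 import register_vector

conn = psycopg2.connect(database="rag", user="postgres", password="postgres", host="localhost", port="5433")
register_vector(conn)
cur = conn.cursor()

embedding = [0.1, 0.2, 0.3]
padded = embedding + [0] * (2048 - len(embedding))  # zero-pad to the column width

cur.execute(
    "INSERT INTO items (text, embedding, date_added) VALUES (%s, %s, %s)",
    ("example note", padded, "2024-01-01"),
)
conn.commit()

# Nearest-neighbour lookup with the same L2 operator agent.retrieve() uses
cur.execute("SELECT text FROM items ORDER BY embedding <-> %s::vector LIMIT 5", (padded,))
print(cur.fetchall())
```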
15  agent-api/docker-compose.yml  Normal file
@@ -0,0 +1,15 @@
version: '3.8'
services:
  db:
    image: pgvector/pgvector:pg16
    restart: always
    environment:
      - POSTGRES_USER=postgres
      - POSTGRES_PASSWORD=postgres
    ports:
      - '5433:5432'
    volumes:
      - db:/var/lib/postgresql/data
volumes:
  db:
    driver: local
41  agent-api/frontend.py  Normal file
@@ -0,0 +1,41 @@
import gradio as gr
import logging
import requests
import yaml

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

with open("config.yml", "r") as file:
    config = yaml.safe_load(file)

BASE_URL = config["server"]["url"]


def chat(message, history):
    messages = []
    for (user_msg, bot_msg) in history:
        messages.append({"role": "user", "content": user_msg})
        # previous bot replies go back to the backend as assistant messages
        messages.append({"role": "assistant", "content": bot_msg})

    logging.info(f"Sending chat message to backend ('{message}').")
    response = requests.post(BASE_URL + "/ask", json={"question": message, "person": "Pierre", "history": messages})

    if response.status_code == 200:
        logging.info("Response:")
        logging.info(response.json())
        data = response.json()
        return data["answer"]
    else:
        logging.error("Backend error %s: %s", response.status_code, response.text)
        return "Sorry, the backend request failed."


ui = gr.ChatInterface(chat, title="Chat")

if __name__ == "__main__":
    logging.info("Launching frontend.")
    ui.launch(
        # share=False,
        server_name='0.0.0.0',
    )
3  agent-api/install.md  Normal file
@@ -0,0 +1,3 @@
# llama.cpp
set CMAKE_ARGS=-DLLAMA_CUBLAS=on
pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python==0.2.39
37  agent-api/llama_server.py  Normal file
@@ -0,0 +1,37 @@
"""Example FastAPI server for llama.cpp.

To run this example:

```bash
pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/...
```

Then run:
```
uvicorn --factory llama_cpp.server.app:create_app --reload
```

or

```
python3 -m llama_cpp.server
```

Then visit http://localhost:8000/docs to see the interactive API docs.

To actually see the implementation of the server, see llama_cpp/server/app.py

"""
import os
import uvicorn

from llama_cpp.server.app import create_app

if __name__ == "__main__":
    app = create_app()

    uvicorn.run(
        app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
    )
252  agent-api/llm.py  Normal file
@@ -0,0 +1,252 @@
import concurrent.futures
import json
from llama_cpp.llama import Llama
import logging
import numpy as np
from openai import OpenAI
from sentence_transformers import SentenceTransformer, CrossEncoder
from typing import Union, List


class BaseChat:
    def __init__(self, config: dict) -> None:
        self.config = config

    def prepare_function_calling(self, tools):
        # prepare function calling
        function_map = {}
        if tools is not None and len(tools) > 0:
            for tool in tools:
                function_name = tool["function"]["name"]
                function_map[function_name] = tool["function"]["function_to_call"]
            functions = []
            for tool in tools:
                fn = tool["function"]
                functions.append(
                    {"type": tool["type"], "function": {x: fn[x] for x in fn if x != "function_to_call"}}
                )
            logging.info(f"{len(tools)} available functions:")
            logging.info(function_map.keys())
        else:
            functions = None
        return function_map, functions


class OpenAIChat(BaseChat):
    def __init__(self, config: dict) -> None:
        self.config = config
        if self.config["openai"]["base_url"] is not None and self.config["openai"]["base_url"] != "":
            base_url = self.config["openai"]["base_url"]
        else:
            base_url = None
        self.client = OpenAI(base_url=base_url)

    def chat(self, messages, tools) -> str:
        function_map, functions = self.prepare_function_calling(tools)

        logging.debug("Sending request to OpenAI.")
        llm_response = self.client.chat.completions.create(
            model=self.config["openai"]["chat_model"],
            messages=messages,
            tools=functions,
            temperature=self.config["temperature"],
            tool_choice="auto" if functions is not None else None,
        )
        logging.debug("LLM response:")
        logging.debug(llm_response.choices)
        if llm_response.choices[0].message.tool_calls:
            # Handle function calls
            followup_response = self.execute_function_call(
                llm_response.choices[0].message, function_map, messages
            )
            return followup_response.choices[0].message.content.strip()
        else:
            return llm_response.choices[0].message.content.strip()

    def execute_function_call(self, message, function_map: dict, messages: list):
        """
        Executes function calls embedded in an LLM message in parallel, and returns an LLM response based on the results.

        Parameters:
        - message: LLM message containing the function calls.
        - function_map (dict): dict of {"function_name": function()}
        - messages (list): message history
        """
        tool_calls = message.tool_calls
        logging.info(f"Got {len(tool_calls)} function call(s).")

        def execute_single_tool_call(tool_call):
            """Helper function to execute a single tool call"""
            logging.info(f"Attempting to execute function requested by LLM ({tool_call.function.name}, {tool_call.function.arguments}).")
            if tool_call.function.name in function_map:
                function_to_call = function_map[tool_call.function.name]
                args = json.loads(tool_call.function.arguments)
                logging.debug(f"Calling function {tool_call.function.name} with args: {args}")
                function_response = function_to_call(**args)
                logging.debug(function_response)
                return {
                    # tool results must use the "tool" role and reference the call id
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_call.function.name,
                    "content": function_response,
                }
            else:
                logging.info(f"{tool_call.function.name} not in function_map")
                logging.info(function_map.keys())
                return None

        # the assistant message that requested the calls must precede the tool results
        messages.append(message)

        # Execute tool calls in parallel
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit all tool calls to the executor
            future_to_tool_call = {
                executor.submit(execute_single_tool_call, tool_call): tool_call
                for tool_call in tool_calls
            }

            # Collect results as they complete
            for future in concurrent.futures.as_completed(future_to_tool_call):
                result = future.result()
                if result is not None:
                    messages.append(result)

        logging.debug("Functions called, sending the results to LLM.")
        llm_response = self.client.chat.completions.create(
            model=self.config["openai"]["chat_model"],
            messages=messages,
            temperature=self.config["temperature"],
        )
        logging.debug("Got response from LLM:")
        logging.debug(llm_response)
        return llm_response


class LMStudioChat(BaseChat):
    def __init__(self, config: dict) -> None:
        self.config = config
        self.client = OpenAI(base_url="http://localhost:1234/v1", api_key="not-needed")

    def chat(self, messages, tools) -> str:
        logging.info("Sending request to local LLM.")
        llm_response = self.client.chat.completions.create(
            model="",
            messages=messages,
            temperature=self.config["temperature"],
        )
        return llm_response.choices[0].message.content.strip()


class LlamaChat(BaseChat):
    def __init__(self, config: dict) -> None:
        self.config = config
        self.client = Llama(
            self.config["local_model_dir"] + "TheBloke/Mistral-7B-Instruct-v0.1-GGUF/mistral-7b-instruct-v0.1.Q6_K.gguf",
            n_gpu_layers=32,
            n_ctx=2048,
            verbose=False
        )

    def chat(self, messages: list, response_format: dict = None, tools: list = None) -> str:
        if tools is not None:
            logging.warning("Tools were provided to LlamaChat, but they are not yet supported.")
        logging.info("Sending request to local LLM.")
        logging.info(messages)
        llm_response = self.client.create_chat_completion(
            messages=messages,
            response_format=response_format,
            temperature=self.config["temperature"],
        )
        return llm_response["choices"][0]["message"]["content"].strip()


class LLM:
    def __init__(self, config: dict) -> None:
        """
        LLM constructor.

        Parameters:
        - config (dict): llm-config parsed from the llm part of config.yml
        """
        self.config = config

        if self.config["use_local_chat"]:
            self.chat_client = LlamaChat(self.config)
        else:
            self.chat_client = OpenAIChat(self.config)

        if self.config["use_local_embedding"]:
            self.embedding_model = self.config["local_embedding_model"]
            self.rerank_model = self.config["local_rerank_model"]
            self.embedding_client = SentenceTransformer(self.embedding_model)
        else:
            self.embedding_model = self.config["openai"]["embedding_model"]
            self.rerank_model = None
            self.embedding_client = OpenAI()

    def query(
        self,
        user_msg: str,
        system_msg: Union[None, str] = None,
        history: Union[None, list] = None,
        tools: Union[None, list] = None,
    ):
        """
        Query the LLM

        Parameters:
        - user_msg (str): query from the user (will be appended to messages as {"role": "user"})
        - system_msg (str): system message (will be prepended to messages as {"role": "system"})
        - history (list): optional, list of messages to inject between system_msg and user_msg.
        - tools: optional, list of functions that may be called by the LLM. Example:
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "function": get_current_weather,
                    "description": "Get the current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            }
                        },
                        "required": ["location"],
                    },
                },
            }
        """
        messages = []

        if system_msg is not None:
            messages.append({"role": "system", "content": system_msg})

        if history:
            logging.info("History")
            logging.info(history)
            messages += history

        messages.append({"role": "user", "content": user_msg})

        logging.info(f"Sending request to LLM with {len(messages)} messages.")
        answer = self.chat_client.chat(messages=messages, tools=tools)
        return answer

    def embedding(self, text: str):
        if self.config["use_local_embedding"]:
            embedding = self.embedding_client.encode(text)
            return embedding  # len = 1024
        else:
            llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")])
            return llm_response.data[0].embedding

    def rerank(self, text: str, corpus: List[str], n_rerank: int = 5):
        if self.rerank_model is not None and len(corpus) > 0:
            sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus]
            logging.info("Reranking:")
            logging.info(sentence_combinations)
            similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations)
            result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))]
            return result[:n_rerank]
        else:
            return corpus
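A hedged sketch of what a complete tool definition looks like when handed to `LLM.query`, following the dict format documented in the docstring above; `get_current_weather` is a stub, the config values are illustrative, and an `OPENAI_API_KEY` in the environment is assumed.

```python
# Sketch: registering a plain Python function as a tool for LLM.query().
import json

from llm import LLM


def get_current_weather(location: str) -> str:
    return json.dumps({"location": location, "forecast": "sunny"})  # stub


tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "function_to_call": get_current_weather,  # stripped out before the request is sent
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    }
                },
                "required": ["location"],
            },
        },
    }
]

config = {
    "use_local_chat": False,
    "use_local_embedding": False,
    "temperature": 0.1,
    "openai": {
        "base_url": "",
        "chat_model": "gpt-3.5-turbo-0125",
        "embedding_model": "text-embedding-3-small",
    },
}
llm = LLM(config=config)
print(llm.query("What's the weather in Gothenburg?", tools=tools))
```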
0  agent-api/plugins/__init__.py  Normal file
68  agent-api/plugins/base_plugin.py  Normal file
@@ -0,0 +1,68 @@
import inspect

from plugins.homeassistant import HomeAssistant


class BasePlugin:
    def __init__(self, config: dict) -> None:
        self.config = config
        self.homeassistant = HomeAssistant(config)

    def prompt(self) -> str | None:
        return

    def tools(self) -> list:
        tools = []
        for tool_fn in self._list_tool_methods():
            tool_fn_metadata = self._get_function_metadata(tool_fn)

            if "input" in tool_fn_metadata["parameters"]:
                json_schema = tool_fn_metadata["parameters"]["input"].annotation.model_json_schema()
                del json_schema["title"]
                for k in json_schema["properties"]:
                    del json_schema["properties"][k]["title"]
                fn_to_call = self._create_function_with_kwargs(
                    tool_fn_metadata["parameters"]["input"].annotation, tool_fn
                )
            else:
                json_schema = {}
                fn_to_call = tool_fn

            tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool_fn_metadata["name"],
                        "function_to_call": fn_to_call,
                        "description": tool_fn_metadata["docstring"],
                        "parameters": json_schema,
                    },
                }
            )
        return tools

    def _get_function_metadata(self, func):
        function_name = func.__name__
        docstring = inspect.getdoc(func)
        signature = inspect.signature(func)
        parameters = signature.parameters

        metadata = {"name": function_name, "docstring": docstring or "", "parameters": parameters}

        return metadata

    def _create_function_with_kwargs(self, model_cls, original_function):
        def dynamic_function(**kwargs):
            model_instance = model_cls(**kwargs)
            return original_function(model_instance)

        return dynamic_function

    def _list_tool_methods(self) -> list:
        attributes = dir(self)
        tool_functions = [
            getattr(self, attr)
            for attr in attributes
            if callable(getattr(self, attr)) and attr.startswith("tool_")
        ]
        return tool_functions
26  agent-api/plugins/calendar/plugin.py  Normal file
@@ -0,0 +1,26 @@
import asyncio
import json

from ..base_plugin import BasePlugin


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

    def tool_get_calendar_events(self, entity_id=None, days=3):
        """
        Get the user's calendar events.
        """
        if entity_id is None:
            entity_id = self.config["plugins"]["calendar"]["default_calendar"]
        response = asyncio.run(
            self.homeassistant.send_command(
                "call_service",
                domain="calendar",
                service="get_events",
                target={"entity_id": entity_id},
                service_data={"duration": {"days": days}},
                return_response=True,
            )
        )
        return json.dumps(response)
51  agent-api/plugins/homeassistant.py  Normal file
@@ -0,0 +1,51 @@
import json
from typing import Any

import requests
from aiohttp import ClientSession
from hass_client import HomeAssistantClient


class HomeAssistant:

    def __init__(self, config: dict) -> None:
        self.config = config
        self.ha_config = self.call_api("config")

    def call_api(self, endpoint, payload=None):
        """
        Call the REST API
        """
        base_url = self.config["homeassistant"]["url"]
        headers = {
            "Authorization": f"Bearer {self.config['homeassistant']['token']}",
            "content-type": "application/json",
        }
        if payload is None:
            response = requests.get(f"{base_url}/api/{endpoint}", headers=headers)
        else:
            response = requests.post(f"{base_url}/api/{endpoint}", headers=headers, json=payload)
        if response.status_code == 200:
            try:
                return response.json()
            except ValueError:
                pass
            try:
                # some endpoints return Python-repr-style payloads; coerce to JSON
                return json.loads(response.text.replace("'", '"').replace("None", "null"))
            except ValueError:
                return response.text
        else:
            return {"status": "error", "message": response.text}

    async def send_command(self, command: str, **kwargs: Any):
        """
        Send command using the WebSocket API
        """
        async with ClientSession() as session:
            async with HomeAssistantClient(
                self.config["homeassistant"]["url"] + "/api/websocket",
                self.config["homeassistant"]["token"],
                session,
            ) as client:
                response = await client.send_command(command, **kwargs)
                return response
73  agent-api/plugins/lights/plugin.py  Normal file
@@ -0,0 +1,73 @@
import json

from ..base_plugin import BasePlugin


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)
        self.light_map = self.get_light_map()

    def prompt(self):
        prompt = "These are the lights available:\n" + json.dumps(
            self.light_map, indent=2, ensure_ascii=False
        )
        return prompt

    def tools(self):
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "control_light",
                    "function_to_call": self.tool_control_light,
                    "description": "Control lights in the user's home",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "entity_id": {
                                "type": "string",
                                "enum": self.available_lights(),
                                "description": "What light entity to control",
                            },
                            "state": {
                                "type": "string",
                                "enum": ["on", "off"],
                                "description": "on or off",
                            },
                        },
                        "required": ["entity_id", "state"],
                    },
                },
            }
        ]
        return tools

    def tool_control_light(self, entity_id: str, state: str):
        self.homeassistant.call_api(f"services/light/turn_{state}", payload={"entity_id": entity_id})
        return json.dumps({"status": "success", "message": f"{entity_id} was turned {state}."})

    def available_lights(self):
        available_lights = []
        for room in self.light_map:
            for light in self.light_map[room]:
                available_lights.append(light["entity_id"])
        return available_lights

    def get_light_map(self):
        template = (
            "{% for state in states.light %}\n"
            + '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}", "room": "{{area_name(state.entity_id)}}"}\n'
            + "{% endfor %}\n"
        )
        response = self.homeassistant.call_api("template", payload={"template": template})
        light_map = {}
        for item_str in response.split("\n\n"):
            item_dict = json.loads(item_str)
            if item_dict["room"] not in light_map:
                light_map[item_dict["room"]] = []
            light_map[item_dict["room"]].append(
                {"entity_id": item_dict["entity_id"], "name": item_dict["name"]}
            )

        return light_map
1  agent-api/plugins/music/.cache  Normal file
@@ -0,0 +1 @@
{"access_token": "BQC4auYl6DX-_xbS0zDjbv2_OZGbbdvyWLDa7I6EBE91GhBAExs3TGmhBNCgwS3P0WJwkUyL-kRRH2x-80QpWaRdLhEm0Q7Tfm8vzsV1rRBToy87Vm1jJJX3S08bYQwD3UvxFlaOhGyyWoSa37JnPkmS9T8d2GTPeHne1HaYR5KEJeYu9stixoONiTjHidQJi_7hOozIqwUkTC7TFWBWezJfa_7GlkE", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "AQDSGqpEaLoyR_A-s78bFVHYE1F2PIbA_cfgrgMt7jZA-A3xmEo5V7GazlgB-okD0JHX-w_fVc-n0_b8nvvZUz5PT6iCd7jRHLVoRA-h6XbQKcSAUVwJGL2djOTfnC4wOIE", "scope": "user-library-read", "expires_at": 1736686521}
59  agent-api/plugins/music/plugin.py  Normal file
@@ -0,0 +1,59 @@
import json
import logging

import spotipy
from pydantic import BaseModel, Field
from spotipy.oauth2 import SpotifyOAuth

from ..base_plugin import BasePlugin


class PlayMusicInput(BaseModel):
    query: str = Field(..., description="Can be a song, artist, album, or playlist")


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

        self.spotify = spotipy.Spotify(
            auth_manager=SpotifyOAuth(scope="user-library-read", redirect_uri="http://localhost:8080")
        )

    def _search(self, query: str, limit: int = 10):
        _result = self.spotify.search(query, limit=limit)
        result = []
        for track in _result["tracks"]["items"]:
            artists = [artist["name"] for artist in track["artists"]]
            result.append(
                {
                    "name": track["name"],
                    "artists": artists,
                    "uri": track["uri"]
                }
            )
        return result

    def tool_play_music(self, input: PlayMusicInput):
        """
        Play music using a search query.
        """
        track = self._search(input.query, limit=1)[0]
        logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
        payload = {
            "entity_id": self.config["plugins"]["music"]["default_speaker"],
            "media_content_id": track["uri"],
            "media_content_type": "music",
            "enqueue": "play",
        }
        self.homeassistant.call_api("services/media_player/play_media", payload=payload)
        return json.dumps({"status": "success", "message": "Playing music.", "track": track})

    def tool_stop_music(self):
        """
        Stop playback of music.
        """
        self.homeassistant.call_api(
            "services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
        )
        return json.dumps({"status": "success", "message": "Music paused."})
3  agent-api/plugins/readme.md  Normal file
@@ -0,0 +1,3 @@
# Plugins
Each folder in /plugins is a plugin.
Class files within /plugins (e.g. homeassistant.py) are helper classes that plugins can use to avoid duplicating code.
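Based on the contract in base_plugin.py and the existing plugins, a new plugin reduces to a folder with a `plugin.py` exposing a `Plugin` class; methods prefixed with `tool_` are auto-exposed, with the JSON schema derived from the pydantic model of the `input` parameter. A minimal hypothetical sketch (the `hello` folder, `GreetInput`, and greeting logic are made up for illustration):

```python
# plugins/hello/plugin.py — hypothetical example plugin.
# backend.load_plugins() imports plugins.<folder>.plugin and instantiates
# its Plugin class.
import json

from pydantic import BaseModel, Field

from ..base_plugin import BasePlugin


class GreetInput(BaseModel):
    name: str = Field(..., description="Who to greet")


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

    def prompt(self):
        # Optional extra text for the system prompt; return None to add nothing.
        return "You can greet people by name."

    def tool_greet(self, input: GreetInput):
        """
        Greet someone by name.
        """
        return json.dumps({"status": "success", "message": f"Hello, {input.name}!"})
```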
74  agent-api/plugins/smhi/plugin.py  Normal file
@@ -0,0 +1,74 @@
import json

from smhi.smhi_lib import Smhi

from ..homeassistant import HomeAssistant


class Plugin:
    def __init__(self, config: dict) -> None:
        self.config = config
        self.homeassistant = HomeAssistant(config)
        self.station = Smhi(
            self.homeassistant.ha_config["longitude"], self.homeassistant.ha_config["latitude"]
        )
        self.weather_conditions = {
            1: "Clear sky",
            2: "Nearly clear sky",
            3: "Variable cloudiness",
            4: "Halfclear sky",
            5: "Cloudy sky",
            6: "Overcast",
            7: "Fog",
            8: "Light rain showers",
            9: "Moderate rain showers",
            10: "Heavy rain showers",
            11: "Thunderstorm",
            12: "Light sleet showers",
            13: "Moderate sleet showers",
            14: "Heavy sleet showers",
            15: "Light snow showers",
            16: "Moderate snow showers",
            17: "Heavy snow showers",
            18: "Light rain",
            19: "Moderate rain",
            20: "Heavy rain",
            21: "Thunder",
            22: "Light sleet",
            23: "Moderate sleet",
            24: "Heavy sleet",
            25: "Light snowfall",
            26: "Moderate snowfall",
            27: "Heavy snowfall",
        }

    def prompt(self):
        return None

    def tools(self):
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather_forecast",
                    "function_to_call": self.get_weather_forecast,
                    "description": "Get weather forecast for the following 24 hours.",
                },
            }
        ]
        return tools

    def get_weather_forecast(self, hours=24):
        forecast = []
        for hour in self.station.get_forecast_hour()[:hours]:
            forecast.append(
                {
                    "time": hour.valid_time.strftime("%Y-%m-%d %H:%M"),
                    "weather": self.weather_conditions[hour.symbol],
                    "temperature": hour.temperature,
                    # "cloudiness": hour.cloudiness,
                    "total_precipitation": hour.total_precipitation,
                    "wind_speed": hour.wind_speed,
                }
            )
        return json.dumps(forecast)
1  agent-api/plugins/smhi/requirements.txt  Normal file
@@ -0,0 +1 @@
smhi-pkg>=1.0.16
107  agent-api/plugins/todo/plugin.py  Normal file
@@ -0,0 +1,107 @@
import asyncio
import json

from ..base_plugin import BasePlugin


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

    def prompt(self):
        prompt = "These are the todo lists available:\n" + json.dumps(
            self.get_todo_lists(), indent=2, ensure_ascii=False
        )
        return prompt

    def tools(self):
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_todo_list_items",
                    "function_to_call": self.get_todo_list_items,
                    "description": "Get items from todo-list.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "entity_id": {
                                "type": "string",
                                "description": "What list to get items from.",
                            }
                        },
                        "required": ["entity_id"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "add_todo_item",
                    "function_to_call": self.add_todo_item,
                    "description": "Add item to todo-list.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "entity_id": {
                                "type": "string",
                                "description": "What list to add item to.",
                            },
                            "item": {
                                "type": "string",
                                "description": "Description of the item.",
                            },
                        },
                        "required": ["entity_id", "item"],
                    },
                },
            },
        ]
        return tools

    def get_todo_list_items(self, entity_id: str = None):
        """
        Get items from todo-list.
        """
        if entity_id is None:
            entity_id = self.config["plugins"]["todo"]["default_list"]
        response = asyncio.run(
            self.homeassistant.send_command(
                "call_service",
                domain="todo",
                service="get_items",
                target={"entity_id": entity_id},
                service_data={"status": "needs_action"},
                return_response=True,
            )
        )
        return json.dumps(response["response"])

    def add_todo_item(self, item: str, entity_id: str = None):
        """
        Add item to todo-list.
        """
        if entity_id is None:
            entity_id = self.config["plugins"]["todo"]["default_list"]
        asyncio.run(
            self.homeassistant.send_command(
                "call_service",
                domain="todo",
                service="add_item",
                target={"entity_id": entity_id},
                service_data={"item": item},
            )
        )
        return json.dumps({"status": f"{item} added to list."})

    def get_todo_lists(self):
        template = (
            "{% for state in states.todo %}\n"
            + '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}"}\n'
            + "{% endfor %}\n"
        )
        response = self.homeassistant.call_api("template", payload={"template": template})
        todo_lists = {}
        for item_str in response.split("\n\n"):
            item_dict = json.loads(item_str)
            todo_lists[item_dict["name"]] = item_dict["entity_id"]
        return todo_lists
98  agent-api/plugins/vasttrafik/plugin.py  Normal file
@@ -0,0 +1,98 @@
import json
from datetime import datetime

import vasttrafik

from ..homeassistant import HomeAssistant


class Plugin:
    def __init__(self, config: dict) -> None:
        self.config = config
        self.homeassistant = HomeAssistant(config)
        self.vasttrafik = vasttrafik.JournyPlanner(
            key=self.config["plugins"]["vasttrafik"]["key"],
            secret=self.config["plugins"]["vasttrafik"]["secret"],
        )
        self.default_from_station_id = self.get_station_id(
            self.config["plugins"]["vasttrafik"]["default_from_station"]
        )
        self.default_to_station_id = self.get_station_id(
            self.config["plugins"]["vasttrafik"]["default_to_station"]
        )

    def prompt(self):
        from_station = self.config["plugins"]["vasttrafik"]["default_from_station"]
        to_station = self.config["plugins"]["vasttrafik"]["default_to_station"]
        return "# Public transportation #\nThe home station is %s, and the default destination station is %s" % (
            from_station,
            to_station,
        )

    def tools(self):
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "search_trip",
                    "function_to_call": self.search_trip,
                    "description": "Search for trips by public transportation.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "from_station": {
                                "type": "string",
                                "description": "Station to travel from.",
                            },
                            "to_station": {
                                "type": "string",
                                "description": "Station to travel to.",
                            },
                        },
                        "required": ["from_station", "to_station"],
                    },
                },
            }
        ]
        return tools

    def search_trip(self, from_station=None, to_station=None):
        if from_station is None:
            from_station_id = self.default_from_station_id
        else:
            from_station_id = self.get_station_id(from_station)
        if to_station is None:
            to_station_id = self.default_to_station_id
        else:
            to_station_id = self.get_station_id(to_station)
        trip = self.vasttrafik.trip(from_station_id, to_station_id)

        response = []
        for departure in trip:
            try:
                departure_dt = datetime.strptime(
                    departure["tripLegs"][0]["origin"]["estimatedTime"][:19], "%Y-%m-%dT%H:%M:%S"
                )
                # total_seconds() stays negative for departures in the past; .seconds would wrap around
                minutes_to_departure = (departure_dt - datetime.now()).total_seconds() / 60
                if minutes_to_departure >= self.config["plugins"]["vasttrafik"]["delay"]:
                    response.append(
                        {
                            "origin": {
                                "stationName": departure["tripLegs"][0]["origin"]["stopPoint"]["name"],
                                # "plannedTime": departure["tripLegs"][0]["origin"]["plannedTime"],
                                "estimatedDepartureTime": departure["tripLegs"][0]["origin"]["estimatedTime"],
                            },
                            "destination": {
                                "stationName": departure["tripLegs"][0]["destination"]["stopPoint"]["name"],
                                "estimatedArrivalTime": departure["tripLegs"][0]["estimatedArrivalTime"],
                            },
                        }
                    )
            except Exception as e:
                print(f"Error: {e}")
        assert len(response) > 0, "No trips found."
        return json.dumps(response, ensure_ascii=False)

    def get_station_id(self, location_name: str) -> str:
        station_id = self.vasttrafik.location_name(location_name)[0]["gid"]
        return station_id
1  agent-api/plugins/vasttrafik/requirements.txt  Normal file
@@ -0,0 +1 @@
vtjp>=0.2.1
2  agent-api/pyproject.toml  Normal file
@@ -0,0 +1,2 @@
[tool.black]
line-length = 110
21  agent-api/readme.md  Normal file
@@ -0,0 +1,21 @@
## Start postgres with pgvector
`docker-compose up -d`

Then create a new db called "rag", and run this SQL:

```sql
CREATE EXTENSION vector;
```

Finally, run the contents of db.sql.

## Run backend
```bash
conda activate llm-api
python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
```

## Run frontend
```bash
python frontend.py
```
11  agent-api/requirements.txt  Normal file
@@ -0,0 +1,11 @@
fastapi
uvicorn[standard]
openai>=1.11.1
sentence_transformers
dateparser
pgvector
psycopg2
pyyaml
gradio
hass-client>=1.0.1
llama_cpp_python>=0.2.44
jinja2    # imported directly by agent.py
requests  # used by frontend.py and plugins/homeassistant.py
spotipy   # used by plugins/music
BIN  agent-api/static/tada.wav  Normal file
Binary file not shown.
12  agent-api/templates/ask.en.j2  Normal file
@@ -0,0 +1,12 @@
You are a helpful assistant. Your replies will be converted to speech, so keep them short and concise and avoid too many special characters.
Today is {{ today }}.

{% if knowledge is not none %}
{{ knowledge }}
{% endif %}

{% for prompt in plugin_prompts %}
- {{ prompt }}
{% endfor %}

Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a Swedish question.
5  agent-api/templates/knowledge.en.j2  Normal file
@@ -0,0 +1,5 @@
Here are some saved notes that might help you answer any question the user might have:
{% for note in knowledge %}
- {{ note }}
{% endfor %}
Never make notes of knowledge you already have (i.e. the list above).
203
esphome-va-bridge/audio.py
Normal file
203
esphome-va-bridge/audio.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
import datetime as dt
|
||||||
|
import io
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pydub
|
||||||
|
import scipy.io.wavfile as wavfile
|
||||||
|
import torch
|
||||||
|
from elevenlabs.client import ElevenLabs
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
from silero_vad import get_speech_timestamps, load_silero_vad
|
||||||
|
|
||||||
|
|
||||||
|
class AudioSegment(pydub.AudioSegment):
|
||||||
|
"""
|
||||||
|
Wrapper class for pydub.AudioSegment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data=None, *args, **kwargs):
|
||||||
|
super().__init__(data, *args, **kwargs)
|
||||||
|
|
||||||
|
def to_ndarray(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert the AudioSegment to a numpy array.
|
||||||
|
"""
|
||||||
|
_buffer = io.BytesIO()
|
||||||
|
self.export(_buffer, format="wav")
|
||||||
|
_buffer.seek(0)
|
||||||
|
|
||||||
|
_, wav_data = wavfile.read(_buffer)
|
||||||
|
if wav_data.dtype != np.float32:
|
||||||
|
max_abs = max(np.abs(wav_data)) if max(np.abs(wav_data)) != 0 else 1
|
||||||
|
wav_data = wav_data.astype(np.float32) / max_abs
|
||||||
|
if len(wav_data.shape) == 2:
|
||||||
|
wav_data = wav_data.mean(axis=1)
|
||||||
|
return wav_data
|
||||||
|
|
||||||
|
def to_tensor(self) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert the AudioSegment to a PyTorch tensor.
|
||||||
|
"""
|
||||||
|
return torch.tensor(self.to_ndarray(), dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBuffer:
|
||||||
|
"""
|
||||||
|
Buffer for AudioSegments
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int) -> None:
|
||||||
|
self.data: AudioSegment | None = None
|
||||||
|
self.max_size: int = max_size
|
||||||
|
|
||||||
|
def put(self, audio_segment: AudioSegment) -> None:
|
||||||
|
"""
|
||||||
|
Append an AudioSegment to the buffer.
|
||||||
|
"""
|
||||||
|
if self.data:
|
||||||
|
self.data = self.data + audio_segment
|
||||||
|
if len(self.data) >= self.max_size:
|
||||||
|
self.data = self.data[-self.max_size :]
|
||||||
|
else:
|
||||||
|
self.data = audio_segment
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""
|
||||||
|
Clear the buffer.
|
||||||
|
"""
|
||||||
|
self.data = None
|
||||||
|
|
||||||
|
def get(self) -> AudioSegment | None:
|
||||||
|
"""
|
||||||
|
Get the AudioSegment from the buffer.
|
||||||
|
"""
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
def get_last(self, duration: int) -> AudioSegment | None:
|
||||||
|
"""
|
||||||
|
Get the last `duration` milliseconds of the AudioSegment.
|
||||||
|
"""
|
||||||
|
return self.data[-duration:] if self.data else None
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the buffer is empty.
|
||||||
|
"""
|
||||||
|
return self.data is None
|
||||||
|
|
||||||
|
def length(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the length of the buffer.
|
||||||
|
"""
|
||||||
|
return len(self.data) if self.data else 0
|
||||||
|
|
||||||
|
|
||||||
|
class VAD:
|
||||||
|
"""
|
||||||
|
Voice Activity Detection
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.model = load_silero_vad()
|
||||||
|
self.ts_voice_detected: dt.datetime | None = None
|
||||||
|
|
||||||
|
def detect_voice(self, audio_segment: AudioSegment) -> bool:
|
||||||
|
"""
|
||||||
|
Detect voice in an AudioSegment.
|
||||||
|
"""
|
||||||
|
speech_timestamps = get_speech_timestamps(
|
||||||
|
audio_segment.to_tensor(),
|
||||||
|
self.model,
|
||||||
|
return_seconds=True,
|
||||||
|
)
|
||||||
|
if len(speech_timestamps) > 0:
|
||||||
|
if self.ts_voice_detected is None:
|
||||||
|
pass # print("VAD: Voice detected.")
|
||||||
|
self.ts_voice_detected = dt.datetime.now()
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def ms_since_voice_detected(self) -> int | None:
|
||||||
|
"""
|
||||||
|
Get the milliseconds since voice was detected.
|
||||||
|
"""
|
||||||
|
if self.ts_voice_detected is None:
|
||||||
|
return None
|
||||||
|
return (dt.datetime.now() - self.ts_voice_detected).total_seconds() * 1000
|
||||||
|
|
||||||
|
def reset_timer(self):
|
||||||
|
"""
|
||||||
|
Reset the timer for voice detected.
|
||||||
|
"""
|
||||||
|
self.ts_voice_detected = None
|
||||||
|
|
||||||
|
|
||||||
|
class STT:
    """
    Speech-to-Text
    """

    def __init__(self, model="medium", device="cuda" if torch.cuda.is_available() else "cpu", benchmark=False) -> None:
        # float16 needs a GPU; fall back to int8 quantization on CPU.
        compute_type = "int8" if device == "cpu" else "float16"
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        self.benchmark = benchmark

    def transcribe(self, audio_segment: AudioSegment, language="sv"):
        """
        Transcribe an AudioSegment.
        """
        if self.benchmark:
            ts_start = dt.datetime.now()
        segments, info = self.model.transcribe(audio_segment.to_ndarray(), language=language)
        segments = list(segments)
        text = " ".join([seg.text for seg in segments])
        text = text.strip()
        if self.benchmark:
            ts_done = dt.datetime.now()
            delta_ms = (ts_done - ts_start).total_seconds() * 1000
            print(f"[Benchmark] STT: {delta_ms:.2f} ms.")
        return text

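# Usage sketch (illustrative only; faster-whisper downloads the model weights
# on first use, and `language` defaults to Swedish here):
#
#   stt = STT(model="medium", benchmark=True)
#   text = stt.transcribe(buf.get(), language="sv")
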
class TTS:
    """
    Text-to-Speech (TTS) class that uses the ElevenLabs API.
    """

    def __init__(self, voice="Brian", model="eleven_multilingual_v2") -> None:
        self.client = ElevenLabs()
        self.voice = voice
        self.model = model

    def generate(self, text: str, voice: str = "default") -> AudioSegment:
        if voice == "default":
            voice = self.voice
        audio = self.client.generate(
            text=text,
            voice=voice,  # was `self.voice`, which ignored the resolved parameter
            model=self.model,
            stream=False,
            output_format="pcm_16000",
            optimize_streaming_latency=2,
        )
        # The client returns an iterator of byte chunks; join them into raw PCM.
        audio = b"".join(audio)
        audio_segment = AudioSegment(data=audio, sample_width=2, frame_rate=16000, channels=1)
        return audio_segment

    def stream(self, text_stream: Iterator[str], voice: str = "default") -> Iterator[bytes]:
        if voice == "default":
            voice = self.voice
        audio_stream = self.client.generate(
            text=text_stream,
            voice=voice,  # was `self.voice`, same fix as in generate()
            model=self.model,
            stream=True,
            output_format="pcm_16000",
            optimize_streaming_latency=2,
        )

        for chunk in audio_stream:
            if chunk is not None:
                yield chunk
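# End-to-end sketch of how these pieces compose (illustrative only;
# `my_agent` is a hypothetical text-in/text-out callable):
#
#   buf = AudioBuffer(max_size=60000)
#   vad, stt, tts = VAD(), STT(), TTS()
#   ...feed microphone chunks into buf.put(...)...
#   if not vad.detect_voice(buf.get_last(1000)):
#       reply_text = my_agent(stt.transcribe(buf.get()))
#       reply_audio = tts.generate(reply_text)   # 16 kHz mono PCM AudioSegment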
113
esphome-va-bridge/esp-working.log
Normal file
@ -0,0 +1,113 @@
[16:22:06.368][D][i2s_audio.microphone:377]: Starting I2S Audio Microphone
[16:22:06.368][D][i2s_audio.microphone:381]: Started I2S Audio Microphone

[16:22:06.368][D][micro_wake_word:417]: State changed from IDLE to DETECTING_WAKE_WORD
[16:22:08.878][D][micro_wake_word:355]: Detected 'Okay Nabu' with sliding average probability is 0.97 and max probability is 1.00

[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
[16:22:08.878][D][media_player:084]: Command: STOP
[16:22:08.878][D][media_player:093]: Announcement: yes
[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
[16:22:08.878][D][media_player:093]: Announcement: yes

[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536

[16:22:08.879][D][nabu_media_player.pipeline:173]: Reading FLAC file type
[16:22:08.879][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
[16:22:08.879][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio

[16:22:08.879][D][ring_buffer:034][speaker_task]: Created ring buffer with size 19200

[16:22:08.879][D][i2s_audio.speaker:111]: Starting Speaker
[16:22:08.880][D][i2s_audio.speaker:116]: Started Speaker

[16:22:08.880][D][voice_assistant:515]: State changed from IDLE to START_MICROPHONE
[16:22:08.880][D][voice_assistant:522]: Desired state set to START_PIPELINE
[16:22:08.880][D][voice_assistant:225]: Starting Microphone

[16:22:08.880][D][ring_buffer:034]: Created ring buffer with size 16384

[16:22:08.880][D][voice_assistant:515]: State changed from START_MICROPHONE to STARTING_MICROPHONE
[16:22:08.880][D][voice_assistant:515]: State changed from STARTING_MICROPHONE to START_PIPELINE
[16:22:08.880][D][voice_assistant:280]: Requesting start...
[16:22:08.880][D][voice_assistant:515]: State changed from START_PIPELINE to STARTING_PIPELINE
[16:22:08.880][D][voice_assistant:537]: Client started, streaming microphone
[16:22:08.881][D][voice_assistant:515]: State changed from STARTING_PIPELINE to STREAMING_MICROPHONE
[16:22:08.881][D][voice_assistant:522]: Desired state set to STREAMING_MICROPHONE
[16:22:08.881][D][voice_assistant:641]: Event Type: 1 (VOICE_ASSISTANT_RUN_START)
[16:22:08.881][D][voice_assistant:644]: Assist Pipeline running
[16:22:08.881][D][voice_assistant:641]: Event Type: 3 (VOICE_ASSISTANT_STT_START)
[16:22:08.881][D][voice_assistant:655]: STT started

[16:22:08.881][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:08.881][D][light:047]: State: ON
[16:22:08.881][D][light:051]: Brightness: 66%
[16:22:08.882][D][light:109]: Effect: 'Waiting for Command'

[16:22:08.882][D][power_supply:033]: Enabling power supply.

[16:22:09.671][D][voice_assistant:641]: Event Type: 11 (VOICE_ASSISTANT_STT_VAD_START)
[16:22:09.671][D][voice_assistant:804]: Starting STT by VAD

[16:22:09.671][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:09.671][D][light:051]: Brightness: 66%
[16:22:09.671][D][light:109]: Effect: 'Listening For Command'

[16:22:14.806][D][voice_assistant:641]: Event Type: 12 (VOICE_ASSISTANT_STT_VAD_END)
[16:22:14.806][D][voice_assistant:808]: STT by VAD end
[16:22:14.806][D][voice_assistant:515]: State changed from STREAMING_MICROPHONE to STOP_MICROPHONE
[16:22:14.806][D][voice_assistant:522]: Desired state set to AWAITING_RESPONSE
[16:22:14.806][D][voice_assistant:515]: State changed from STOP_MICROPHONE to STOPPING_MICROPHONE

[16:22:14.807][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:14.807][D][light:051]: Brightness: 66%
[16:22:14.807][D][light:109]: Effect: 'Thinking'

[16:22:14.807][D][voice_assistant:515]: State changed from STOPPING_MICROPHONE to AWAITING_RESPONSE
[16:22:14.807][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to AWAITING_RESPONSE

[16:22:14.807][D][power_supply:033]: Enabling power supply.
[16:22:14.807][D][power_supply:033]: Enabling power supply.

[16:22:14.807][D][voice_assistant:641]: Event Type: 4 (VOICE_ASSISTANT_STT_END)
[16:22:14.807][D][voice_assistant:669]: Speech recognised as: " Who are you?"
[16:22:14.807][D][voice_assistant:641]: Event Type: 5 (VOICE_ASSISTANT_INTENT_START)
[16:22:14.808][D][voice_assistant:674]: Intent started

[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.
[16:22:14.808][D][power_supply:033]: Enabling power supply.

[16:22:14.808][D][voice_assistant:641]: Event Type: 6 (VOICE_ASSISTANT_INTENT_END)
[16:22:14.808][D][voice_assistant:641]: Event Type: 7 (VOICE_ASSISTANT_TTS_START)
[16:22:14.808][D][voice_assistant:697]: Response: "I am an AI assistant designed to help you with various tasks, answer questions, and provide information. If there's anything specific you need help with, feel free to ask!"

[16:22:14.809][D][light:036]: 'voice_assistant_leds' Setting:
[16:22:14.809][D][light:051]: Brightness: 66%
[16:22:14.809][D][light:109]: Effect: 'Replying'

[16:22:14.809][D][voice_assistant:641]: Event Type: 8 (VOICE_ASSISTANT_TTS_END)
[16:22:14.809][D][voice_assistant:719]: Response URL: "https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac"
[16:22:14.809][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to STREAMING_RESPONSE
[16:22:14.809][D][voice_assistant:522]: Desired state set to STREAMING_RESPONSE

[16:22:14.809][D][media_player:080]: 'Media Player' - Setting
[16:22:14.809][D][media_player:087]: Media URL: https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac
[16:22:14.810][D][media_player:093]: Announcement: yes

[16:22:14.810][D][voice_assistant:641]: Event Type: 2 (VOICE_ASSISTANT_RUN_END)
[16:22:14.810][D][voice_assistant:733]: Assist Pipeline ended

[16:22:14.810][D][esp-idf:000][ann_read]: I (40514) esp-x509-crt-bundle: Certificate validated

[16:22:14.810][D][nabu_media_player.pipeline:173]: Reading FLAC file type
[16:22:14.810][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
[16:22:14.810][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio

[16:22:26.112][D][voice_assistant:515]: State changed from STREAMING_RESPONSE to IDLE
[16:22:26.112][D][voice_assistant:522]: Desired state set to IDLE
283
esphome-va-bridge/esphome_server.py
Normal file
@ -0,0 +1,283 @@
import asyncio
import hashlib
import logging
import math
import os
import traceback
import wave
from enum import StrEnum

import aioesphomeapi
from aioesphomeapi import (VoiceAssistantAudioSettings,
                           VoiceAssistantCommandFlag, VoiceAssistantEventType)
from aiohttp import web

from audio import STT, TTS, VAD, AudioBuffer, AudioSegment
from llm_api import LlmApi

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
routes = web.RouteTableDef()


class PipelineStage(StrEnum):
    """Stages of a pipeline."""

    WAKE_WORD = "wake_word"
    STT = "stt"
    INTENT = "intent"
    TTS = "tts"
    END = "end"


class Satellite:

    def __init__(self, host="192.168.10.155", port=6053, password=""):
        self.client = aioesphomeapi.APIClient(host, port, password)
        self.audio_queue: AudioBuffer = AudioBuffer(max_size=60000)
        self.state = "idle"
        self.vad = VAD()
        self.stt = STT(device="cuda")
        self.tts = TTS()
        self.agent = LlmApi()

    async def connect(self):
        await self.client.connect(login=True)

    async def disconnect(self):
        await self.client.disconnect()

    async def handle_pipeline_start(
        self,
        conversation_id: str,
        flags: int,
        audio_settings: VoiceAssistantAudioSettings,
        wake_word_phrase: str | None,
    ) -> int:
        logging.debug(
            f"Pipeline starting with conversation_id={conversation_id}, flags={flags}, audio_settings={audio_settings}, wake_word_phrase={wake_word_phrase}"
        )

        # Device triggered pipeline (wake word, etc.)
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            start_stage = PipelineStage.WAKE_WORD
        else:
            start_stage = PipelineStage.STT

        end_stage = PipelineStage.TTS

        logging.info(f"Starting pipeline from {start_stage} to {end_stage}")
        self.state = "running_pipeline"

        # Run pipeline
        asyncio.create_task(self.run_pipeline(wake_word_phrase))

        return 0

    async def run_pipeline(self, wake_word_phrase: str | None = None):
        logging.info(f"Pipeline started using the wake word '{wake_word_phrase}'.")

        try:
            logging.debug(" > STT start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None)

            logging.debug(" > STT VAD start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START, None)

            # VAD: keep buffering microphone audio until the speaker goes quiet,
            # but never record more than 20 s in total.
            logging.debug(" > VAD: Waiting for silence...")
            VAD_TIMEOUT = 1000  # VAD analysis window in milliseconds
            _voice_detected = False
            while True:
                if self.audio_queue.length() > 20000:
                    break
                elif self.audio_queue.length() >= VAD_TIMEOUT:
                    voice_detected = self.vad.detect_voice(self.audio_queue.get_last(VAD_TIMEOUT))
                    if voice_detected != _voice_detected:
                        _voice_detected = voice_detected
                        # Voice was present and has now stopped: end of utterance.
                        if not _voice_detected:
                            logging.debug(" > VAD: Silence detected.")
                            break
                    # Bail out if 2 s have been buffered without any voice at all.
                    if self.audio_queue.length() > 2000 and not voice_detected:
                        logging.debug(" > VAD: Silence detected (no voice).")
                        break
                await asyncio.sleep(0.1)

            logging.debug(" > STT VAD end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END, None)

            # STT
            text = self.stt.transcribe(self.audio_queue.get())
            self.audio_queue.clear()
            logging.info(f" > STT: {text}")

            logging.debug(" > STT end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": text})

            logging.debug(" > Intent start")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None)
            agent_response = self.agent.ask(text)["answer"]
            logging.info(f" > Intent: {agent_response}")

            logging.debug(" > Intent end")
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, None)

            logging.debug(" > TTS start")
            # Renamed from `TTS` so the mode flag no longer shadows the imported TTS class.
            tts_mode = "announce"
            if tts_mode == "announce":
                self.client.send_voice_assistant_event(
                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": agent_response}
                )

                voice = "default"
                # Cache generated audio on disk, keyed by voice and response text.
                audio_hash_string = hashlib.md5((voice + ":" + agent_response).encode("utf-8")).hexdigest()
                audio_file_path = f"static/{audio_hash_string}.wav"
                if not os.path.exists(audio_file_path):
                    logging.debug(" > TTS: Generating audio...")
                    audio_segment = self.tts.generate(agent_response, voice=voice)
                else:
                    logging.debug(" > TTS: Using cached audio.")
                    audio_segment = AudioSegment.from_file(audio_file_path)
                audio_segment.export(audio_file_path, format="wav")
                self.client.send_voice_assistant_event(
                    VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
                    {"url": f"http://192.168.10.111:8002/{audio_file_path}"},
                )
            elif tts_mode == "stream":
                await self._stream_tts_audio("debug")
            logging.debug(" > TTS end")

        except Exception as e:
            logging.error("Error: %s", e)
            logging.error(traceback.format_exc())
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, {"message": str(e)})

        logging.debug(" > Run end")
        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END, None)
        logging.info("Pipeline done")

    async def handle_pipeline_stop(self, abort: bool):
        logging.info("Pipeline stopped")
        self.state = "idle"

    async def handle_audio(self, data: bytes) -> None:
        """Handle incoming audio chunk from API."""
        self.audio_queue.put(AudioSegment(data=data, sample_width=2, frame_rate=16000, channels=1))

    async def handle_announcement_finished(self, event: aioesphomeapi.VoiceAssistantAnnounceFinished):
        if event.success:
            logging.info("Announcement finished successfully.")
        else:
            logging.error("Announcement failed.")

    def handle_state_change(self, state):
        logging.debug("State change:")
        logging.debug(state)

    def handle_log(self, log):
        # Was `logging.info("Log from device:", log)`, which raises a logging
        # format error; pass the log as a %s argument instead.
        logging.info("Log from device: %s", log)

    def subscribe(self):
        self.client.subscribe_states(self.handle_state_change)
        self.client.subscribe_logs(self.handle_log)
        self.client.subscribe_voice_assistant(
            handle_start=self.handle_pipeline_start,
            handle_stop=self.handle_pipeline_stop,
            handle_audio=self.handle_audio,
            handle_announcement_finished=self.handle_announcement_finished,
        )

    async def _stream_tts_audio(
        self,
        media_id: str,
        sample_rate: int = 16000,
        sample_width: int = 2,
        sample_channels: int = 1,
        samples_per_chunk: int = 1024,  # 512
    ) -> None:
        """Stream TTS audio chunks to device via API or UDP."""

        self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
        logging.info("TTS stream start")
        await asyncio.sleep(1)

        try:
            # if not self._is_running:
            #     return

            with wave.open(f"{media_id}.wav", "rb") as wav_file:
                if (
                    (wav_file.getframerate() != sample_rate)
                    or (wav_file.getsampwidth() != sample_width)
                    or (wav_file.getnchannels() != sample_channels)
                ):
                    logging.info(f"Can only stream 16 kHz 16-bit mono WAV, got {wav_file.getparams()}")
                    return

                logging.info("Streaming %s audio samples", wav_file.getnframes())

                rate = wav_file.getframerate()
                width = wav_file.getsampwidth()
                channels = wav_file.getnchannels()

                audio_bytes = wav_file.readframes(wav_file.getnframes())
                bytes_per_sample = width * channels
                bytes_per_chunk = bytes_per_sample * samples_per_chunk
                num_chunks = int(math.ceil(len(audio_bytes) / bytes_per_chunk))

                # Split into chunks
                for i in range(num_chunks):
                    offset = i * bytes_per_chunk
                    chunk = audio_bytes[offset : offset + bytes_per_chunk]
                    self.client.send_voice_assistant_audio(chunk)

                    # Wait for 90% of the duration of the audio that was
                    # sent for it to be played. This will overrun the
                    # device's buffer for very long audio, so using a media
                    # player is preferred.
                    samples_in_chunk = len(chunk) // (sample_width * sample_channels)
                    seconds_in_chunk = samples_in_chunk / sample_rate
                    logging.info(f"Samples in chunk: {samples_in_chunk}, seconds in chunk: {seconds_in_chunk}")
                    await asyncio.sleep(seconds_in_chunk * 0.9)

        except asyncio.CancelledError:
            return  # Don't trigger state change
        finally:
            self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {})
            logging.info("TTS stream end")


async def start_http_server(port=8002):
    app = web.Application()
    app.router.add_static("/static/", path=os.path.join(os.path.dirname(__file__), "static"))
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "0.0.0.0", port)
    await site.start()
    logging.info(f"HTTP server started at http://0.0.0.0:{port}/")


async def main():
    """Connect to an ESPHome device and wait for state changes."""
    await start_http_server()

    satellite = Satellite()
    await satellite.connect()
    logging.info("Connected to ESPHome voice assistant.")
    satellite.subscribe()
    satellite.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START, None)
    logging.info("Listening for events...")

    while True:
        await asyncio.sleep(0.1)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass
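# Chunk timing, worked through with the defaults above (illustrative):
# 1024 samples/chunk * 2 bytes/sample * 1 channel = 2048 bytes per chunk;
# at 16 kHz that is 1024 / 16000 = 64 ms of audio per chunk, so the loop
# sleeps ~57.6 ms (90%) after each send to pace playback without starving
# the device's buffer.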
13
esphome-va-bridge/llm_api.py
Normal file
@ -0,0 +1,13 @@
import requests


class LlmApi:
    def __init__(self, host="127.0.0.1", port=8000):
        self.host = host
        self.port = port
        self.url = f"http://{host}:{port}"

    def ask(self, question):
        url = self.url + "/ask"
        response = requests.post(url, json={"question": question})
        return response.json()
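# Usage sketch (illustrative; assumes the agent-api service exposes POST /ask
# and returns JSON with an "answer" key, as run_pipeline expects):
#
#   api = LlmApi(host="127.0.0.1", port=8000)
#   print(api.ask("Who are you?")["answer"])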
62
esphome-va-bridge/log_reader.py
Normal file
@ -0,0 +1,62 @@
from __future__ import annotations

# Helper script using aioesphomeapi to view logs from an ESPHome device
import argparse
import asyncio
import logging
import sys
from datetime import datetime

from aioesphomeapi.api_pb2 import SubscribeLogsResponse  # type: ignore
from aioesphomeapi.client import APIClient
from aioesphomeapi.log_runner import async_run


async def main(argv: list[str]) -> None:
    parser = argparse.ArgumentParser("aioesphomeapi-logs")
    parser.add_argument("--port", type=int, default=6053)
    parser.add_argument("--password", type=str)
    parser.add_argument("--noise-psk", type=str)
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("address")
    args = parser.parse_args(argv[1:])

    logging.basicConfig(
        format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    cli = APIClient(
        args.address,
        args.port,
        args.password or "",
        noise_psk=args.noise_psk,
        keepalive=10,
    )

    def on_log(msg: SubscribeLogsResponse) -> None:
        time_ = datetime.now()
        message: bytes = msg.message
        text = message.decode("utf8", "backslashreplace")
        # Was named `nanoseconds`; microsecond // 1000 yields milliseconds.
        milliseconds = time_.microsecond // 1000
        print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{milliseconds:03}]{text}")

    stop = await async_run(cli, on_log)
    try:
        await asyncio.Event().wait()
    finally:
        await stop()


def cli_entry_point() -> None:
    """Run the CLI."""
    try:
        asyncio.run(main(sys.argv))
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    cli_entry_point()
    sys.exit(0)
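# Usage sketch (illustrative; --noise-psk is only needed if the device uses
# encrypted API transport):
#
#   python log_reader.py -v --noise-psk <device-psk> 192.168.10.155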
1
external/home-assistant-voice-pe
vendored
Submodule
@ -0,0 +1 @@
Subproject commit 4c09d89ddbf234546f244bf98b1053a2f89e5ab6
2
pyproject.toml
Normal file
@ -0,0 +1,2 @@
[tool.black]
line-length = 120