init
commit 7b45d19308
13
.gitignore
vendored
Normal file
@@ -0,0 +1,13 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

config.yml
debug.py
notebook.ipynb
.vscode/
secrets.yaml
1
agent-api/.cache
Normal file
@ -0,0 +1 @@
|
||||
{"access_token": "BQDSGB6rh1-6-bc7wxFTZ2_Idjr6W-xLSZEMizMF6wNXew6_2M4DyWeSaVovZZLPdVv0aQOU8LqIQqQSNr5OwNXDpRk6icpcC6HKYGaZ7I_U9PRWbIOWE-QUakMcAmMFTWwO3FJEBRDYARR43cWTeRn8w9MHG48V5awv2A1447yz8kd4cUk8lAOTbcCeMnpxjJJQqqaigIWAy8NF7ghJXr4dibKj2JtTP5yECA", "token_type": "Bearer", "expires_in": 3600, "scope": "user-library-read", "expires_at": 1736781939, "refresh_token": "AQBaAMOoUVzcichmSIgDC4_CvkXXNNwfjnSbwpypzuaaCQqcQw-_2VP6Rkayp08Y-5gSEbEGVWc0-1CP4-JmQI2kQ_0UmJT5g2G3LvOjufSYowzXRxeedIOWcwbRvjSA1XI"}
|
||||
161
agent-api/agent.py
Normal file
@ -0,0 +1,161 @@
|
||||
import datetime
|
||||
import json
|
||||
from jinja2 import Template
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
from db import Database
|
||||
from llm import LLM
|
||||
|
||||
|
||||
def relative_to_actual_dates(text: str):
|
||||
from dateparser.search import search_dates
|
||||
|
||||
matches = search_dates(text, languages=["en", "se"])
|
||||
|
||||
if matches:
|
||||
for match in matches:
|
||||
text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
|
||||
return text
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
self.db = Database("rag")
|
||||
self.llm = LLM(config=self.config["llm"])
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
Flushes the agent's knowledge.
|
||||
"""
|
||||
logging.info("Truncating database")
|
||||
self.db.cur.execute("TRUNCATE TABLE items")
|
||||
self.db.conn.commit()
|
||||
|
||||
def insert(self, text):
|
||||
"""
|
||||
Append knowledge.
|
||||
"""
|
||||
logging.info(f"Inserting item into embedding table ({text}).")
|
||||
|
||||
#logging.info(f"\tReplacing relative dates with actual dates.")
|
||||
#text = relative_to_actual_dates(text)
|
||||
#logging.info(f"\tDone. text: {text}")
|
||||
|
||||
vector = self.llm.embedding(text)
|
||||
vector_padded = vector + ([0]*(2048-len(vector)))
|
||||
|
||||
self.db.cur.execute(
|
||||
"INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
|
||||
(text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
|
||||
)
|
||||
self.db.conn.commit()
|
||||
|
||||
def keywords(self, query: str, n_keywords: int = 5):
|
||||
"""
|
||||
Suggest keywords related to a query.
|
||||
"""
|
||||
sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
|
||||
prompt = (
|
||||
f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
|
||||
+ f"Do not add any context or text other than the json list of {n_keywords} words."
|
||||
)
|
||||
response = self.llm.query(prompt, system_msg=sys_msg)
|
||||
keywords = json.loads(response)
|
||||
return keywords
|
||||
|
||||
def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
|
||||
"""
|
||||
Retrieve relevant knowledge.
|
||||
|
||||
Parameters:
|
||||
- n_retrieve (int): How many notes to retrieve.
|
||||
- n_rerank (int): How many notes to keep after reranking.
|
||||
"""
|
||||
logging.info(f"Retrieving knowledge from database using the query: '{query}'")
|
||||
|
||||
logging.debug(f"Using embedding model on '{query}'.")
|
||||
vector_keywords = self.llm.embedding(query)
|
||||
logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
|
||||
|
||||
logging.debug("Querying database")
|
||||
vector_padded = vector_keywords + ([0]*(2048-len(vector_keywords)))
|
||||
self.db.cur.execute(
|
||||
f"SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT {n_retrieve}",
|
||||
(vector_padded,),
|
||||
)
|
||||
db_response = self.db.cur.fetchall()
|
||||
knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
|
||||
logging.debug("Database returned: ")
|
||||
logging.debug(knowledge_arr)
|
||||
|
||||
if n_retrieve > n_rerank:
|
||||
logging.info("Reranking the results")
|
||||
reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
|
||||
reranked_knowledge_arr = sorted(
|
||||
knowledge_arr,
|
||||
key=lambda x: (
|
||||
reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
|
||||
),
|
||||
)
|
||||
reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
|
||||
logging.debug("Reranked results: ")
|
||||
logging.debug(reranked_knowledge_arr)
|
||||
return reranked_knowledge_arr
|
||||
else:
|
||||
return knowledge_arr
|
||||
|
||||
def ask(
|
||||
self,
|
||||
question: str,
|
||||
person: str,
|
||||
history: list = None,
|
||||
n_retrieve: int = 5,
|
||||
n_rerank=3,
|
||||
tools: Union[None, list] = None,
|
||||
prompts: Union[None, list] = None,
|
||||
):
|
||||
"""
|
||||
Ask the agent a question.
|
||||
|
||||
Parameters:
|
||||
- question (str): The question to ask
|
||||
- person (str): Who asks the question?
|
||||
- n_retrieve (int): How many notes to retrieve.
|
||||
- n_rerank (int): How many notes to keep after reranking.
|
||||
"""
|
||||
logging.info(f"Answering question: {question}")
|
||||
|
||||
if n_retrieve > 0:
|
||||
knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
|
||||
|
||||
with open('templates/knowledge.en.j2', 'r') as file:
|
||||
template = Template(file.read())
|
||||
|
||||
knowledge_str = template.render(
|
||||
knowledge = knowledge,
|
||||
)
|
||||
else:
|
||||
knowledge_str = ""
|
||||
|
||||
with open('templates/ask.en.j2', 'r') as file:
|
||||
template = Template(file.read())
|
||||
|
||||
prompt = template.render(
|
||||
today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
knowledge = knowledge_str,
|
||||
plugin_prompts = prompts,
|
||||
person = person
|
||||
)
|
||||
|
||||
logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
|
||||
answer = self.llm.query(
|
||||
question,
|
||||
system_msg=prompt,
|
||||
tools=tools,
|
||||
history=history
|
||||
)
|
||||
logging.info(f"Got answer from LLM: '{answer}'.")
|
||||
|
||||
return answer
|
||||
127
agent-api/backend.py
Normal file
@ -0,0 +1,127 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
import yaml
|
||||
|
||||
from agent import Agent
|
||||
|
||||
PLUGINS_FOLDER = "plugins"
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
with open("config.yml", "r", encoding='utf8') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
def load_plugins(config):
|
||||
tools = []
|
||||
prompts = []
|
||||
|
||||
for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
|
||||
if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder:
|
||||
logging.info(f"Loading plugin: {plugin_folder}")
|
||||
module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
|
||||
module = importlib.import_module(module_path)
|
||||
plugin = module.Plugin(config)
|
||||
tools += plugin.tools()
|
||||
prompt = plugin.prompt()
|
||||
if prompt is not None:
|
||||
prompts.append(prompt)
|
||||
|
||||
return tools, prompts
|
||||
|
||||
tools, prompts = load_plugins(config)
|
||||
|
||||
app = FastAPI()
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
agent = Agent(config)
|
||||
|
||||
|
||||
class Query(BaseModel):
|
||||
query: str
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class Question(BaseModel):
|
||||
question: str
|
||||
person: str = "User"
|
||||
history: list[Message] = None
|
||||
|
||||
|
||||
@app.get("/ping")
|
||||
def ping():
|
||||
"""
|
||||
# Ping
|
||||
Check if the API is up.
|
||||
"""
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.post("/ask")
|
||||
def ask(q: Question):
|
||||
"""
|
||||
# Ask
|
||||
Ask the agent a question.
|
||||
|
||||
## Parameters
|
||||
- **question**: the question to ask.
|
||||
- **person**: who is asking the question?
|
||||
"""
|
||||
history = None
|
||||
if q.history:
|
||||
history = [message.model_dump() for message in q.history]
|
||||
|
||||
answer = agent.ask(
|
||||
question=q.question,
|
||||
person="Pierre",
|
||||
history=history,
|
||||
tools=tools,
|
||||
prompts=prompts,
|
||||
n_retrieve=0,
|
||||
n_rerank=0
|
||||
)
|
||||
return {"question": q.question, "answer": answer, "history": q.history}
|
||||
|
||||
|
||||
@app.post("/retrieve")
|
||||
def retrieve(q: Query):
|
||||
"""
|
||||
# Retrieve
|
||||
Retrieve knowledge related to a query.
|
||||
|
||||
## Parameters
|
||||
- **query**: the question.
|
||||
"""
|
||||
result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5)
|
||||
return {"query": q.query, "answer": result}
|
||||
|
||||
|
||||
@app.post("/learn")
|
||||
def learn(q: Query):
|
||||
"""
|
||||
# Learn
|
||||
Teach the agent something, e.g. *"John likes to play volleyball"*.
|
||||
|
||||
## Parameters
|
||||
- **query**: the note.
|
||||
"""
|
||||
agent.insert(text=q.query)
|
||||
return {"text": q.query}
|
||||
|
||||
|
||||
@app.get("/flush")
|
||||
def flush():
|
||||
"""
|
||||
# Flush
|
||||
Remove all knowledge from the agent.
|
||||
"""
|
||||
agent.flush()
|
||||
return {"message": "Knowledge flushed."}
|
||||
25
agent-api/config.default.yml
Normal file
@@ -0,0 +1,25 @@
server:
  url: "http://127.0.0.1:8000"
llm:
  use_local_chat: false
  use_local_embedding: false
  local_model_dir: ""
  local_embedding_model: "intfloat/e5-large-v2"
  local_rerank_model: "cross-encoder/stsb-distilroberta-base"
  openai_embedding_model: "text-embedding-3-small"
  openai_chat_model: "gpt-3.5-turbo-0125"
  temperature: 0.1
homeassistant:
  url: "http://localhost:8123"
  token: ""
plugins:
  calendar:
    default_calendar: "calendar.my_calendar"
  todo:
    default_list: "todo.todo"
  vasttrafik:
    key: ""
    secret: ""
    default_from_station: "Brunnsparken"
    default_to_station: "Centralstationen"
    delay: 0 # minutes
32
agent-api/db.py
Normal file
@ -0,0 +1,32 @@
|
||||
import logging
|
||||
from pgvector.psycopg2 import register_vector
|
||||
import psycopg2
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(
|
||||
self,
|
||||
database,
|
||||
user="postgres",
|
||||
password="postgres",
|
||||
host="localhost",
|
||||
port="5433",
|
||||
) -> None:
|
||||
logging.info("Connecting to database")
|
||||
# Use the constructor arguments rather than hard-coded connection values
self.conn = psycopg2.connect(
    database=database,
    user=user,
    password=password,
    host=host,
    port=port,
)
|
||||
register_vector(self.conn)
|
||||
self.cur = self.conn.cursor()
|
||||
self.cur.execute("SELECT version();")
|
||||
logging.info(" DB Version: %s", self.cur.fetchone()[0])
|
||||
logging.info(" psycopg2 Version: %s", psycopg2.__version__)
|
||||
|
||||
def close(self):
|
||||
logging.info("Closing connection to database")
|
||||
self.cur.close()
|
||||
self.conn.close()
|
||||
7
agent-api/db.sql
Normal file
@@ -0,0 +1,7 @@
CREATE TABLE public.items (
    id bigserial NOT NULL,
    embedding vector(2048) NULL,
    "text" varchar(1024) NULL,
    date_added date NULL,
    CONSTRAINT items_pkey PRIMARY KEY (id)
);
15
agent-api/docker-compose.yml
Normal file
@@ -0,0 +1,15 @@
version: '3.8'
services:
  db:
    image: pgvector/pgvector:pg16
    restart: always
    environment:
      - POSTGRES_USER=postgres
      - POSTGRES_PASSWORD=postgres
    ports:
      - '5433:5432'
    volumes:
      - db:/var/lib/postgresql/data
volumes:
  db:
    driver: local
41
agent-api/frontend.py
Normal file
@ -0,0 +1,41 @@
|
||||
import gradio as gr
|
||||
import logging
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
with open("config.yml", "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
BASE_URL = config["server"]["url"]
|
||||
|
||||
def chat(message, history):
|
||||
messages = []
|
||||
for (user_msg, bot_msg) in history:
|
||||
messages.append({"role": "user", "content": user_msg})
|
||||
messages.append({"role": "system", "content": bot_msg})
|
||||
|
||||
logging.info(f"Sending chat message to backend ('{message}').")
|
||||
response = requests.post(BASE_URL + "/ask", json={"question": message, "person": "Pierre", "history": messages})
|
||||
|
||||
if response.status_code == 200:
|
||||
logging.info("Response:")
|
||||
logging.info(response.json())
|
||||
data = response.json()
|
||||
return data["answer"]
|
||||
else:
|
||||
logging.error(response)
return f"Error: backend returned status {response.status_code}."
|
||||
|
||||
ui = gr.ChatInterface(chat, title="Chat")
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.info("Launching frontend.")
|
||||
ui.launch(
|
||||
#share=False,
|
||||
server_name='0.0.0.0',
|
||||
)
|
||||
3
agent-api/install.md
Normal file
@@ -0,0 +1,3 @@
# llama.cpp
set CMAKE_ARGS=-DLLAMA_CUBLAS=on
pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python==0.2.39
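The `set` syntax above is for a Windows shell. On Linux or macOS the same build flag is passed inline; a minimal equivalent sketch, assuming the same pinned version:

```bash
# Build llama-cpp-python with cuBLAS support (POSIX shells)
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install --upgrade --verbose --force-reinstall --no-cache-dir llama-cpp-python==0.2.39
```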
37
agent-api/llama_server.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Example FastAPI server for llama.cpp.
|
||||
|
||||
To run this example:
|
||||
|
||||
```bash
|
||||
pip install fastapi uvicorn sse-starlette
|
||||
export MODEL=../models/7B/...
|
||||
```
|
||||
|
||||
Then run:
|
||||
```
|
||||
uvicorn --factory llama_cpp.server.app:create_app --reload
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
python3 -m llama_cpp.server
|
||||
```
|
||||
|
||||
Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||
|
||||
|
||||
To actually see the implementation of the server, see llama_cpp/server/app.py
|
||||
|
||||
"""
|
||||
import os
|
||||
import uvicorn
|
||||
|
||||
from llama_cpp.server.app import create_app
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
|
||||
uvicorn.run(
|
||||
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
|
||||
)
|
||||
252
agent-api/llm.py
Normal file
@ -0,0 +1,252 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
from llama_cpp.llama import Llama
|
||||
import logging
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
class BaseChat:
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
|
||||
def prepare_function_calling(self, tools):
|
||||
# prepare function calling
|
||||
function_map = {}
|
||||
if tools is not None and len(tools) > 0:
|
||||
for tool in tools:
|
||||
function_name = tool["function"]["name"]
|
||||
function_map[function_name] = tool["function"]["function_to_call"]
|
||||
functions = []
|
||||
for tool in tools:
|
||||
fn = tool["function"]
|
||||
functions.append(
|
||||
{"type": tool["type"], "function": {x: fn[x] for x in fn if x != "function_to_call"}}
|
||||
)
|
||||
logging.info(f"{len(tools)} available functions:")
|
||||
logging.info(function_map.keys())
|
||||
else:
|
||||
functions = None
|
||||
return function_map, functions
|
||||
|
||||
|
||||
class OpenAIChat(BaseChat):
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
if self.config["openai"]["base_url"] is not None and self.config["openai"]["base_url"] != "":
|
||||
base_url = self.config["openai"]["base_url"]
|
||||
else:
|
||||
base_url = None
|
||||
self.client = OpenAI(base_url=base_url)
|
||||
|
||||
def chat(self, messages, tools) -> str:
|
||||
function_map, functions = self.prepare_function_calling(tools)
|
||||
|
||||
logging.debug("Sending request to OpenAI.")
|
||||
llm_response = self.client.chat.completions.create(
|
||||
model=self.config["openai"]["chat_model"],
|
||||
messages=messages,
|
||||
tools=functions,
|
||||
temperature=self.config["temperature"],
|
||||
tool_choice="auto" if functions is not None else None,
|
||||
)
|
||||
logging.debug("LLM response:")
|
||||
logging.debug(llm_response.choices)
|
||||
if llm_response.choices[0].message.tool_calls:
|
||||
# Handle function calls
|
||||
followup_response = self.execute_function_call(
|
||||
llm_response.choices[0].message, function_map, messages
|
||||
)
|
||||
return followup_response.choices[0].message.content.strip()
|
||||
else:
|
||||
return llm_response.choices[0].message.content.strip()
|
||||
|
||||
def execute_function_call(self, message, function_map: dict, messages: list) -> str:
|
||||
"""
|
||||
Executes function calls embedded in an LLM message in parallel, and returns an LLM response based on the results.
|
||||
|
||||
Parameters:
|
||||
- message: LLM message containing the function calls.
|
||||
- function_map (dict): dict of {"function_name": function()}
|
||||
- messages (list): message history
|
||||
"""
|
||||
tool_calls = message.tool_calls
|
||||
logging.info(f"Got {len(tool_calls)} function call(s).")
|
||||
|
||||
def execute_single_tool_call(tool_call):
|
||||
"""Helper function to execute a single tool call"""
|
||||
logging.info(f"Attempting to execute function requested by LLM ({tool_call.function.name}, {tool_call.function.arguments}).")
|
||||
if tool_call.function.name in function_map:
|
||||
function_to_call = function_map[tool_call.function.name]
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
logging.debug(f"Calling function {tool_call.function.name} with args: {args}")
|
||||
function_response = function_to_call(**args)
|
||||
logging.debug(function_response)
|
||||
return {
|
||||
"role": "function",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"content": function_response,
|
||||
}
|
||||
else:
|
||||
logging.info(f"{tool_call.function.name} not in function_map")
|
||||
logging.info(function_map.keys())
|
||||
return None
|
||||
|
||||
# Append the assistant message that requested the tool calls, then execute the calls in parallel
messages.append(message)
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# Submit all tool calls to the executor
|
||||
future_to_tool_call = {
|
||||
executor.submit(execute_single_tool_call, tool_call): tool_call
|
||||
for tool_call in tool_calls
|
||||
}
|
||||
|
||||
# Collect results as they complete
|
||||
for future in concurrent.futures.as_completed(future_to_tool_call):
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
messages.append(result)
|
||||
|
||||
logging.debug("Functions called, sending the results to LLM.")
|
||||
llm_response = self.client.chat.completions.create(
|
||||
model=self.config["openai"]["chat_model"],
|
||||
messages=messages,
|
||||
temperature=self.config["temperature"],
|
||||
)
|
||||
logging.debug("Got response from LLM:")
|
||||
logging.debug(llm_response)
|
||||
return llm_response
|
||||
|
||||
|
||||
class LMStudioChat(BaseChat):
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
self.client = OpenAI(base_url="http://localhost:1234/v1", api_key="not-needed")
|
||||
|
||||
def chat(self, messages, tools) -> str:
|
||||
logging.info("Sending request to local LLM.")
|
||||
llm_response = self.client.chat.completions.create(
|
||||
model="",
|
||||
messages=messages,
|
||||
temperature=self.config["temperature"],
|
||||
)
|
||||
return llm_response.choices[0].message.content.strip()
|
||||
|
||||
|
||||
class LlamaChat(BaseChat):
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
self.client = Llama(
|
||||
self.config["local_model_dir"] + "TheBloke/Mistral-7B-Instruct-v0.1-GGUF/mistral-7b-instruct-v0.1.Q6_K.gguf",
|
||||
n_gpu_layers=32,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
def chat(self, messages: list, response_format: dict = None, tools: list = None) -> str:
|
||||
if tools is not None:
|
||||
logging.warning("Tools was provided to LlamaChat, but it's not yet supported.")
|
||||
logging.info("Sending request to local LLM.")
|
||||
logging.info(messages)
|
||||
llm_response = self.client.create_chat_completion(
|
||||
messages = messages,
|
||||
response_format = response_format,
|
||||
temperature = self.config["temperature"],
|
||||
)
|
||||
return llm_response["choices"][0]["message"]["content"].strip()
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(self, config: dict) -> None:
|
||||
"""
|
||||
LLM constructor.
|
||||
|
||||
Parameters:
|
||||
- config (dict): llm-config parsed from the llm part of config.yml
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
if self.config["use_local_chat"]:
|
||||
self.chat_client = LlamaChat(self.config)
|
||||
else:
|
||||
self.chat_client = OpenAIChat(self.config)
|
||||
|
||||
if self.config["use_local_embedding"]:
|
||||
self.embedding_model = self.config["local_embedding_model"]
|
||||
self.rerank_model = self.config["local_rerank_model"]
|
||||
self.embedding_client = SentenceTransformer(self.embedding_model)
|
||||
else:
|
||||
self.embedding_model = self.config["openai"]["embedding_model"]
|
||||
self.rerank_model = None
|
||||
self.embedding_client = OpenAI()
|
||||
|
||||
def query(
|
||||
self,
|
||||
user_msg: str,
|
||||
system_msg: Union[None, str] = None,
|
||||
history: list = [],
|
||||
tools: Union[None, list] = None,
|
||||
):
|
||||
"""
|
||||
Query the LLM
|
||||
|
||||
Parameters:
|
||||
- user_msg (str): query from the user (will be appended to messages as {"role": "user"})
|
||||
- system_msg (str): query from the user (will be prepended to messages as {"role": "system"})
|
||||
- history (list): optional, list of messages to inject between system_msg and user_msg.
|
||||
- tools: optional, list of functions that may be called by the LLM. Example:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"function": get_current_weather,
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
}
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
messages = []
|
||||
|
||||
if system_msg is not None:
|
||||
messages.append({"role": "system", "content": system_msg})
|
||||
|
||||
if history and len(history) > 0:
|
||||
logging.info("History")
|
||||
logging.info(history)
|
||||
messages += history
|
||||
|
||||
messages.append({"role": "user", "content": user_msg})
|
||||
|
||||
logging.info(f"Sending request to LLM with {len(messages)} messages.")
|
||||
answer = self.chat_client.chat(messages=messages, tools=tools)
|
||||
return answer
|
||||
|
||||
def embedding(self, text: str):
|
||||
if self.config["use_local_embedding"]:
|
||||
embedding = self.embedding_client.encode(text)
|
||||
return embedding # len = 1024
|
||||
else:
|
||||
llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")])
|
||||
return llm_response.data[0].embedding
|
||||
|
||||
def rerank(self, text: str, corpus: List[str], n_rerank: int = 5):
|
||||
if self.rerank_model is not None and len(corpus) > 0:
|
||||
sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus]
|
||||
logging.info("Reranking:")
|
||||
logging.info(sentence_combinations)
|
||||
similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations)
|
||||
result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))]
|
||||
return result[:n_rerank]
|
||||
else:
|
||||
return corpus
|
||||
0
agent-api/plugins/__init__.py
Normal file
68
agent-api/plugins/base_plugin.py
Normal file
@ -0,0 +1,68 @@
|
||||
import inspect
|
||||
|
||||
from plugins.homeassistant import HomeAssistant
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
self.homeassistant = HomeAssistant(config)
|
||||
|
||||
def prompt(self) -> str | None:
|
||||
return
|
||||
|
||||
def tools(self) -> list:
|
||||
tools = []
|
||||
for tool_fn in self._list_tool_methods():
|
||||
tool_fn_metadata = self._get_function_metadata(tool_fn)
|
||||
|
||||
if "input" in tool_fn_metadata["parameters"]:
|
||||
json_schema = tool_fn_metadata["parameters"]["input"].annotation.model_json_schema()
|
||||
del json_schema["title"]
|
||||
for k, v in json_schema["properties"].items():
|
||||
del json_schema["properties"][k]["title"]
|
||||
fn_to_call = self._create_function_with_kwargs(
|
||||
tool_fn_metadata["parameters"]["input"].annotation, tool_fn
|
||||
)
|
||||
else:
|
||||
json_schema = {}
|
||||
fn_to_call = tool_fn
|
||||
|
||||
tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_fn_metadata["name"],
|
||||
"function_to_call": fn_to_call,
|
||||
"description": tool_fn_metadata["docstring"],
|
||||
"parameters": json_schema,
|
||||
},
|
||||
}
|
||||
)
|
||||
return tools
|
||||
|
||||
def _get_function_metadata(self, func):
|
||||
function_name = func.__name__
|
||||
docstring = inspect.getdoc(func)
|
||||
signature = inspect.signature(func)
|
||||
parameters = signature.parameters
|
||||
|
||||
metadata = {"name": function_name, "docstring": docstring or "", "parameters": parameters}
|
||||
|
||||
return metadata
|
||||
|
||||
def _create_function_with_kwargs(self, model_cls, original_function):
|
||||
def dynamic_function(**kwargs):
|
||||
model_instance = model_cls(**kwargs)
|
||||
return original_function(model_instance)
|
||||
|
||||
return dynamic_function
|
||||
|
||||
def _list_tool_methods(self) -> list:
|
||||
attributes = dir(self)
|
||||
tool_functions = [
|
||||
getattr(self, attr)
|
||||
for attr in attributes
|
||||
if callable(getattr(self, attr)) and attr.startswith("tool_")
|
||||
]
|
||||
return tool_functions
|
||||
26
agent-api/plugins/calendar/plugin.py
Normal file
@ -0,0 +1,26 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from ..base_plugin import BasePlugin
|
||||
|
||||
class Plugin(BasePlugin):
|
||||
def __init__(self, config: dict) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
def tool_get_calendar_events(self, entity_id=None, days=3):
|
||||
"""
|
||||
Get the user's calendar events.
|
||||
"""
|
||||
if entity_id is None:
|
||||
entity_id = self.config["plugins"]["calendar"]["default_calendar"]
|
||||
response = asyncio.run(
|
||||
self.homeassistant.send_command(
|
||||
"call_service",
|
||||
domain="calendar",
|
||||
service="get_events",
|
||||
target={"entity_id": entity_id},
|
||||
service_data={"duration": {"days": days}},
|
||||
return_response=True,
|
||||
)
|
||||
)
|
||||
return json.dumps(response)
|
||||
51
agent-api/plugins/homeassistant.py
Normal file
@ -0,0 +1,51 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
from hass_client import HomeAssistantClient
|
||||
|
||||
|
||||
class HomeAssistant:
|
||||
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
self.ha_config = self.call_api("config")
|
||||
|
||||
def call_api(self, endpoint, payload=None):
|
||||
"""
|
||||
Call the REST API
|
||||
"""
|
||||
base_url = self.config["homeassistant"]["url"]
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.config['homeassistant']['token']}",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
if payload is None:
|
||||
response = requests.get(f"{base_url}/api/{endpoint}", headers=headers)
|
||||
else:
|
||||
response = requests.post(f"{base_url}/api/{endpoint}", headers=headers, json=payload)
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
return response.json()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
return json.loads(response.text.replace("'", '"').replace("None", "null"))
|
||||
except:
|
||||
return response.text
|
||||
else:
|
||||
return {"status": "error", "message": response.text}
|
||||
|
||||
async def send_command(self, command: str, **kwargs: dict[str, Any]):
|
||||
"""
|
||||
Send command using the WebSocket API
|
||||
"""
|
||||
async with ClientSession() as session:
|
||||
async with HomeAssistantClient(
|
||||
self.config["homeassistant"]["url"] + "/api/websocket",
|
||||
self.config["homeassistant"]["token"],
|
||||
session,
|
||||
) as client:
|
||||
response = await client.send_command(command, **kwargs)
|
||||
return response
|
||||
73
agent-api/plugins/lights/plugin.py
Normal file
@ -0,0 +1,73 @@
|
||||
import json
|
||||
|
||||
from ..base_plugin import BasePlugin
|
||||
|
||||
|
||||
class Plugin(BasePlugin):
|
||||
def __init__(self, config: dict) -> None:
|
||||
super().__init__(config=config)
|
||||
self.light_map = self.get_light_map()
|
||||
|
||||
def prompt(self):
|
||||
prompt = "These are the lights available:\n" + json.dumps(
|
||||
self.light_map, indent=2, ensure_ascii=False
|
||||
)
|
||||
return prompt
|
||||
|
||||
def tools(self):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "control_light",
|
||||
"function_to_call": self.tool_control_light,
|
||||
"description": "Control lights in users a home",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"enum": self.available_lights(),
|
||||
"description": "What light entity to control",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"enum": ["on", "off"],
|
||||
"description": "on or off",
|
||||
},
|
||||
},
|
||||
"required": ["room", "state"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
return tools
|
||||
|
||||
def tool_control_light(self, entity_id: str, state: str):
|
||||
self.homeassistant.call_api(f"services/light/turn_{state}", payload={"entity_id": entity_id})
|
||||
return json.dumps({"status": "success", "message": f"{entity_id} was turned {state}."})
|
||||
|
||||
def available_lights(self):
|
||||
available_lights = []
|
||||
for room in self.light_map:
|
||||
for light in self.light_map[room]:
|
||||
available_lights.append(light["entity_id"])
|
||||
return available_lights
|
||||
|
||||
def get_light_map(self):
|
||||
template = (
|
||||
"{% for state in states.light %}\n"
|
||||
+ '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}", "room": "{{area_name(state.entity_id)}}"}\n'
|
||||
+ "{% endfor %}\n"
|
||||
)
|
||||
response = self.homeassistant.call_api(f"template", payload={"template": template})
|
||||
light_map = {}
|
||||
for item_str in response.split("\n\n"):
|
||||
item_dict = json.loads(item_str)
|
||||
if item_dict["room"] not in light_map:
|
||||
light_map[item_dict["room"]] = []
|
||||
light_map[item_dict["room"]].append(
|
||||
{"entity_id": item_dict["entity_id"], "name": item_dict["name"]}
|
||||
)
|
||||
|
||||
return light_map
|
||||
1
agent-api/plugins/music/.cache
Normal file
@ -0,0 +1 @@
|
||||
{"access_token": "BQC4auYl6DX-_xbS0zDjbv2_OZGbbdvyWLDa7I6EBE91GhBAExs3TGmhBNCgwS3P0WJwkUyL-kRRH2x-80QpWaRdLhEm0Q7Tfm8vzsV1rRBToy87Vm1jJJX3S08bYQwD3UvxFlaOhGyyWoSa37JnPkmS9T8d2GTPeHne1HaYR5KEJeYu9stixoONiTjHidQJi_7hOozIqwUkTC7TFWBWezJfa_7GlkE", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "AQDSGqpEaLoyR_A-s78bFVHYE1F2PIbA_cfgrgMt7jZA-A3xmEo5V7GazlgB-okD0JHX-w_fVc-n0_b8nvvZUz5PT6iCd7jRHLVoRA-h6XbQKcSAUVwJGL2djOTfnC4wOIE", "scope": "user-library-read", "expires_at": 1736686521}
|
||||
59
agent-api/plugins/music/plugin.py
Normal file
@ -0,0 +1,59 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import spotipy
|
||||
from pydantic import BaseModel, Field
|
||||
from spotipy.oauth2 import SpotifyOAuth
|
||||
|
||||
from ..base_plugin import BasePlugin
|
||||
|
||||
|
||||
class PlayMusicInput(BaseModel):
|
||||
query: str = Field(..., description="Can be a song, artist, album, or playlist")
|
||||
|
||||
|
||||
class Plugin(BasePlugin):
|
||||
def __init__(self, config: dict) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
self.spotify = spotipy.Spotify(
|
||||
auth_manager=SpotifyOAuth(scope="user-library-read", redirect_uri="http://localhost:8080")
|
||||
)
|
||||
|
||||
def _search(self, query: str, limit: int = 10):
|
||||
_result = self.spotify.search(query, limit=limit)
|
||||
result = []
|
||||
for track in _result["tracks"]["items"]:
|
||||
artists = [artist["name"] for artist in track["artists"]]
|
||||
result.append(
|
||||
{
|
||||
"name": track["name"],
|
||||
"artists": artists,
|
||||
"uri": track["uri"]
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def tool_play_music(self, input: PlayMusicInput):
|
||||
"""
|
||||
Play music using a search query.
|
||||
"""
|
||||
track = self._search(input.query, limit=1)[0]
|
||||
logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
|
||||
payload = {
|
||||
"entity_id": self.config["plugins"]["music"]["default_speaker"],
|
||||
"media_content_id": track["uri"],
|
||||
"media_content_type": "music",
|
||||
"enqueue": "play",
|
||||
}
|
||||
result = self.homeassistant.call_api(f"services/media_player/play_media", payload=payload)
|
||||
return json.dumps({"status": "success", "message": f"Playing music.", "track": track})
|
||||
|
||||
def tool_stop_music(self):
|
||||
"""
|
||||
Stop playback of music.
|
||||
"""
|
||||
self.homeassistant.call_api(
|
||||
f"services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
|
||||
)
|
||||
return json.dumps({"status": "success", "message": f"Music paused."})
|
||||
3
agent-api/plugins/readme.md
Normal file
@@ -0,0 +1,3 @@
# Plugins
Each folder in /plugins is a plugin.
Class files placed directly in /plugins (e.g. homeassistant.py) are helper classes that plugins can use to avoid duplicating code.
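For reference, a minimal plugin sketch following the conventions that base_plugin.py discovers: a `tool_`-prefixed method exposed to the LLM, an optional Pydantic `input` model describing its arguments, and an optional `prompt()` that can contribute extra context. The notification service and field names here are illustrative assumptions, not part of this repository:

```python
# plugins/example/plugin.py -- hypothetical example plugin
import json

from pydantic import BaseModel, Field

from ..base_plugin import BasePlugin


class NotifyInput(BaseModel):
    message: str = Field(..., description="Text of the notification to send")


class Plugin(BasePlugin):
    def __init__(self, config: dict) -> None:
        super().__init__(config=config)

    def prompt(self):
        # Optional extra context for the system prompt; None means nothing is added.
        return None

    def tool_send_notification(self, input: NotifyInput):
        """
        Send a notification to the user.
        """
        # Assumed Home Assistant notify service; adjust to the actual setup.
        self.homeassistant.call_api("services/notify/notify", payload={"message": input.message})
        return json.dumps({"status": "success", "message": "Notification sent."})
```

Dropping such a folder into /plugins is enough for load_plugins() in backend.py to pick it up at startup.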
74
agent-api/plugins/smhi/plugin.py
Normal file
@ -0,0 +1,74 @@
|
||||
import json
|
||||
|
||||
from smhi.smhi_lib import Smhi
|
||||
|
||||
from ..homeassistant import HomeAssistant
|
||||
|
||||
|
||||
class Plugin:
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
self.homeassistant = HomeAssistant(config)
|
||||
self.station = Smhi(
|
||||
self.homeassistant.ha_config["longitude"], self.homeassistant.ha_config["latitude"]
|
||||
)
|
||||
self.weather_conditions = {
|
||||
1: "Clear sky",
|
||||
2: "Nearly clear sky",
|
||||
3: "Variable cloudiness",
|
||||
4: "Halfclear sky",
|
||||
5: "Cloudy sky",
|
||||
6: "Overcast",
|
||||
7: "Fog",
|
||||
8: "Light rain showers",
|
||||
9: "Moderate rain showers",
|
||||
10: "Heavy rain showers",
|
||||
11: "Thunderstorm",
|
||||
12: "Light sleet showers",
|
||||
13: "Moderate sleet showers",
|
||||
14: "Heavy sleet showers",
|
||||
15: "Light snow showers",
|
||||
16: "Moderate snow showers",
|
||||
17: "Heavy snow showers",
|
||||
18: "Light rain",
|
||||
19: "Moderate rain",
|
||||
20: "Heavy rain",
|
||||
21: "Thunder",
|
||||
22: "Light sleet",
|
||||
23: "Moderate sleet",
|
||||
24: "Heavy sleet",
|
||||
25: "Light snowfall",
|
||||
26: "Moderate snowfall",
|
||||
27: "Heavy snowfall",
|
||||
}
|
||||
|
||||
def prompt(self):
|
||||
return None
|
||||
|
||||
def tools(self):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather_forecast",
|
||||
"function_to_call": self.get_weather_forecast,
|
||||
"description": "Get weather forecast for the following 24 hours.",
|
||||
},
|
||||
}
|
||||
]
|
||||
return tools
|
||||
|
||||
def get_weather_forecast(self, hours=24):
|
||||
forecast = []
|
||||
for hour in self.station.get_forecast_hour()[:hours]:
|
||||
forecast.append(
|
||||
{
|
||||
"time": hour.valid_time.strftime("%Y-%m-%d %H:%M"),
|
||||
"weather": self.weather_conditions[hour.symbol],
|
||||
"temperature": hour.temperature,
|
||||
# "cloudiness": hour.cloudiness,
|
||||
"total_precipitation": hour.total_precipitation,
|
||||
"wind_speed": hour.wind_speed,
|
||||
}
|
||||
)
|
||||
return json.dumps(forecast)
|
||||
1
agent-api/plugins/smhi/requirements.txt
Normal file
@@ -0,0 +1 @@
smhi-pkg>=1.0.16
107
agent-api/plugins/todo/plugin.py
Normal file
@ -0,0 +1,107 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from ..base_plugin import BasePlugin
|
||||
|
||||
class Plugin(BasePlugin):
|
||||
def __init__(self, config: dict) -> None:
|
||||
super().__init__(config=config)
|
||||
|
||||
def prompt(self):
|
||||
prompt = "These are the todo lists available:\n" + json.dumps(
|
||||
self.get_todo_lists(), indent=2, ensure_ascii=False
|
||||
)
|
||||
return prompt
|
||||
|
||||
def tools(self):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_todo_list_items",
|
||||
"function_to_call": self.get_todo_list_items,
|
||||
"description": "Get items from todo-list.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": "What list to get items from.",
|
||||
}
|
||||
},
|
||||
"required": ["entity_id"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add_todo_item",
|
||||
"function_to_call": self.add_todo_item,
|
||||
"description": "Add item to todo-list.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": "What list to add item to.",
|
||||
},
|
||||
"item": {
|
||||
"type": "string",
|
||||
"description": "Description of the item.",
|
||||
},
|
||||
},
|
||||
"required": ["entity_id", "item"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
return tools
|
||||
|
||||
def get_todo_list_items(self, entity_id: str = None):
|
||||
"""
|
||||
Get items from todo-list.
|
||||
"""
|
||||
if entity_id is None:
|
||||
entity_id = self.config["plugins"]["todo"]["default_list"]
|
||||
response = asyncio.run(
|
||||
self.homeassistant.send_command(
|
||||
"call_service",
|
||||
domain="todo",
|
||||
service="get_items",
|
||||
target={"entity_id": entity_id},
|
||||
service_data={"status": "needs_action"},
|
||||
return_response=True,
|
||||
)
|
||||
)
|
||||
return json.dumps(response["response"])
|
||||
|
||||
def add_todo_item(self, item: str, entity_id: str = None):
|
||||
"""
|
||||
Add item to todo-list.
|
||||
"""
|
||||
if entity_id is None:
|
||||
entity_id = self.config["plugins"]["todo"]["default_list"]
|
||||
asyncio.run(
|
||||
self.homeassistant.send_command(
|
||||
"call_service",
|
||||
domain="todo",
|
||||
service="add_item",
|
||||
target={"entity_id": entity_id},
|
||||
service_data={"item": item},
|
||||
)
|
||||
)
|
||||
return json.dumps({"status": f"{item} added to list."})
|
||||
|
||||
def get_todo_lists(self):
|
||||
template = (
|
||||
"{% for state in states.todo %}\n"
|
||||
+ '{ "entity_id": "{{ state.entity_id }}", "name" : "{{ state.attributes.friendly_name }}"}\n'
|
||||
+ "{% endfor %}\n"
|
||||
)
|
||||
response = self.homeassistant.call_api(f"template", payload={"template": template})
|
||||
todo_lists = {}
|
||||
for item_str in response.split("\n\n"):
|
||||
item_dict = json.loads(item_str)
|
||||
todo_lists[item_dict["name"]] = item_dict["entity_id"]
|
||||
return todo_lists
|
||||
98
agent-api/plugins/vasttrafik/plugin.py
Normal file
@ -0,0 +1,98 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import vasttrafik
|
||||
|
||||
from ..homeassistant import HomeAssistant
|
||||
|
||||
|
||||
class Plugin:
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.config = config
|
||||
self.homeassistant = HomeAssistant(config)
|
||||
self.vasttrafik = vasttrafik.JournyPlanner(
|
||||
key=self.config["plugins"]["vasttrafik"]["key"],
|
||||
secret=self.config["plugins"]["vasttrafik"]["secret"],
|
||||
)
|
||||
self.default_from_station_id = self.get_station_id(
|
||||
self.config["plugins"]["vasttrafik"]["default_from_station"]
|
||||
)
|
||||
self.default_to_station_id = self.get_station_id(
|
||||
self.config["plugins"]["vasttrafik"]["default_to_station"]
|
||||
)
|
||||
|
||||
def prompt(self):
|
||||
from_station = self.config["plugins"]["vasttrafik"]["default_from_station"]
|
||||
to_station = self.config["plugins"]["vasttrafik"]["default_to_station"]
|
||||
return "# Public transportation #\nHome station is: %s, and the default to station to use is %s" % (
|
||||
from_station,
|
||||
to_station,
|
||||
)
|
||||
|
||||
def tools(self):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_trip",
|
||||
"function_to_call": self.search_trip,
|
||||
"description": "Search for trips by public transportation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"from_station": {
|
||||
"type": "string",
|
||||
"description": "Station to travel from.",
|
||||
},
|
||||
"to_station": {
|
||||
"type": "string",
|
||||
"description": "Station to travel to.",
|
||||
},
|
||||
},
|
||||
"required": ["from_station", "to_station"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
return tools
|
||||
|
||||
def search_trip(self, from_station=None, to_station=None):
|
||||
if from_station is None:
|
||||
from_station_id = self.default_from_station_id
|
||||
else:
|
||||
from_station_id = self.get_station_id(from_station)
|
||||
if to_station is None:
|
||||
to_station_id = self.default_to_station_id
|
||||
else:
|
||||
to_station_id = self.get_station_id(to_station)
|
||||
trip = self.vasttrafik.trip(from_station_id, to_station_id)
|
||||
|
||||
response = []
|
||||
for departure in trip:
|
||||
try:
|
||||
departure_dt = datetime.strptime(
|
||||
departure["tripLegs"][0]["origin"]["estimatedTime"][:19], "%Y-%m-%dT%H:%M:%S"
|
||||
)
|
||||
minutes_to_departure = (departure_dt - datetime.now()).seconds / 60
|
||||
if minutes_to_departure >= self.config["plugins"]["vasttrafik"]["delay"]:
|
||||
response.append(
|
||||
{
|
||||
"origin": {
|
||||
"stationName": departure["tripLegs"][0]["origin"]["stopPoint"]["name"],
|
||||
# "plannedTime": departure["tripLegs"][0]["origin"]["plannedTime"],
|
||||
"estimatedDepartureTime": departure["tripLegs"][0]["origin"]["estimatedTime"],
|
||||
},
|
||||
"destination": {
|
||||
"stationName": departure["tripLegs"][0]["destination"]["stopPoint"]["name"],
|
||||
"estimatedArrivalTime": departure["tripLegs"][0]["estimatedArrivalTime"],
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
assert len(response) > 0, "No trips found."
|
||||
return json.dumps(response, ensure_ascii=False)
|
||||
|
||||
def get_station_id(self, location_name: str) -> str:
|
||||
station_id = self.vasttrafik.location_name(location_name)[0]["gid"]
|
||||
return station_id
|
||||
1
agent-api/plugins/vasttrafik/requirements.txt
Normal file
@@ -0,0 +1 @@
vtjp>=0.2.1
2
agent-api/pyproject.toml
Normal file
@@ -0,0 +1,2 @@
[tool.black]
line-length = 110
21
agent-api/readme.md
Normal file
@@ -0,0 +1,21 @@
## Start postgres with pgvector
docker-compose up -d

Then create a new database called "rag", and run this SQL in it:

```sql
CREATE EXTENSION vector;
```

Finally, run the contents of db.sql.
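For example, those three steps can be scripted with the standard Postgres client tools, assuming the container from docker-compose.yml (user `postgres`, password `postgres`, host port 5433):

```bash
# Create the database, enable pgvector, and create the items table
PGPASSWORD=postgres createdb -h localhost -p 5433 -U postgres rag
PGPASSWORD=postgres psql -h localhost -p 5433 -U postgres -d rag -c "CREATE EXTENSION vector;"
PGPASSWORD=postgres psql -h localhost -p 5433 -U postgres -d rag -f db.sql
```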

## Run backend
```bash
conda activate llm-api
python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
```

## Run frontend
```bash
python frontend.py
```
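Once the backend is up, the endpoints defined in backend.py can be exercised directly, for example:

```bash
# Health check
curl http://127.0.0.1:8000/ping

# Teach the agent a note
curl -X POST http://127.0.0.1:8000/learn -H "Content-Type: application/json" \
  -d '{"query": "John likes to play volleyball"}'

# Ask a question
curl -X POST http://127.0.0.1:8000/ask -H "Content-Type: application/json" \
  -d '{"question": "What does John like to do?", "person": "User"}'
```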
11
agent-api/requirements.txt
Normal file
@@ -0,0 +1,11 @@
fastapi
uvicorn[standard]
openai>=1.11.1
sentence_transformers
dateparser
pgvector
psycopg2
pyyaml
gradio
hass-client>=1.0.1
llama_cpp_python>=0.2.44
BIN
agent-api/static/tada.wav
Normal file
Binary file not shown.
12
agent-api/templates/ask.en.j2
Normal file
@@ -0,0 +1,12 @@
You are a helpful assistant. Your replies will be converted to speech, so keep them short and concise and avoid too many special characters.
Today is {{ today }}.

{% if knowledge is not none %}
{{ knowledge }}
{% endif %}

{% for prompt in plugin_prompts %}
- {{ prompt }}
{% endfor %}

Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a Swedish question.
5
agent-api/templates/knowledge.en.j2
Normal file
@@ -0,0 +1,5 @@
Here are some saved notes that might help you answer any question the user might have:
{% for note in knowledge %}
- {{ note }}
{% endfor %}
Never make notes of knowledge you already have (i.e. the list above).
203
esphome-va-bridge/audio.py
Normal file
@ -0,0 +1,203 @@
|
||||
import datetime as dt
|
||||
import io
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
import scipy.io.wavfile as wavfile
|
||||
import torch
|
||||
from elevenlabs.client import ElevenLabs
|
||||
from faster_whisper import WhisperModel
|
||||
from silero_vad import get_speech_timestamps, load_silero_vad
|
||||
|
||||
|
||||
class AudioSegment(pydub.AudioSegment):
|
||||
"""
|
||||
Wrapper class for pydub.AudioSegment
|
||||
"""
|
||||
|
||||
def __init__(self, data=None, *args, **kwargs):
|
||||
super().__init__(data, *args, **kwargs)
|
||||
|
||||
def to_ndarray(self) -> np.ndarray:
|
||||
"""
|
||||
Convert the AudioSegment to a numpy array.
|
||||
"""
|
||||
_buffer = io.BytesIO()
|
||||
self.export(_buffer, format="wav")
|
||||
_buffer.seek(0)
|
||||
|
||||
_, wav_data = wavfile.read(_buffer)
|
||||
if wav_data.dtype != np.float32:
|
||||
max_abs = np.abs(wav_data).max() if np.abs(wav_data).max() != 0 else 1
|
||||
wav_data = wav_data.astype(np.float32) / max_abs
|
||||
if len(wav_data.shape) == 2:
|
||||
wav_data = wav_data.mean(axis=1)
|
||||
return wav_data
|
||||
|
||||
def to_tensor(self) -> torch.Tensor:
|
||||
"""
|
||||
Convert the AudioSegment to a PyTorch tensor.
|
||||
"""
|
||||
return torch.tensor(self.to_ndarray(), dtype=torch.float32)
|
||||
|
||||
|
||||
class AudioBuffer:
|
||||
"""
|
||||
Buffer for AudioSegments
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int) -> None:
|
||||
self.data: AudioSegment | None = None
|
||||
self.max_size: int = max_size
|
||||
|
||||
def put(self, audio_segment: AudioSegment) -> None:
|
||||
"""
|
||||
Append an AudioSegment to the buffer.
|
||||
"""
|
||||
if self.data:
|
||||
self.data = self.data + audio_segment
|
||||
if len(self.data) >= self.max_size:
|
||||
self.data = self.data[-self.max_size :]
|
||||
else:
|
||||
self.data = audio_segment
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear the buffer.
|
||||
"""
|
||||
self.data = None
|
||||
|
||||
def get(self) -> AudioSegment | None:
|
||||
"""
|
||||
Get the AudioSegment from the buffer.
|
||||
"""
|
||||
return self.data
|
||||
|
||||
def get_last(self, duration: int) -> AudioSegment | None:
|
||||
"""
|
||||
Get the last `duration` milliseconds of the AudioSegment.
|
||||
"""
|
||||
return self.data[-duration:] if self.data else None
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if the buffer is empty.
|
||||
"""
|
||||
return self.data is None
|
||||
|
||||
def length(self) -> int:
|
||||
"""
|
||||
Get the length of the buffer.
|
||||
"""
|
||||
return len(self.data) if self.data else 0
|
||||
|
||||
|
||||
class VAD:
|
||||
"""
|
||||
Voice Activity Detection
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.model = load_silero_vad()
|
||||
self.ts_voice_detected: dt.datetime | None = None
|
||||
|
||||
def detect_voice(self, audio_segment: AudioSegment) -> bool:
|
||||
"""
|
||||
Detect voice in an AudioSegment.
|
||||
"""
|
||||
speech_timestamps = get_speech_timestamps(
|
||||
audio_segment.to_tensor(),
|
||||
self.model,
|
||||
return_seconds=True,
|
||||
)
|
||||
if len(speech_timestamps) > 0:
|
||||
if self.ts_voice_detected is None:
|
||||
pass # print("VAD: Voice detected.")
|
||||
self.ts_voice_detected = dt.datetime.now()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def ms_since_voice_detected(self) -> int | None:
|
||||
"""
|
||||
Get the milliseconds since voice was detected.
|
||||
"""
|
||||
if self.ts_voice_detected is None:
|
||||
return None
|
||||
return (dt.datetime.now() - self.ts_voice_detected).total_seconds() * 1000
|
||||
|
||||
def reset_timer(self):
|
||||
"""
|
||||
Reset the timer for voice detected.
|
||||
"""
|
||||
self.ts_voice_detected = None
|
||||
|
||||
|
||||
class STT:
|
||||
"""
|
||||
Speech-to-Text
|
||||
"""
|
||||
|
||||
def __init__(self, model="medium", device="cuda" if torch.cuda.is_available() else "cpu", benchmark=False) -> None:
|
||||
compute_type = "int8" if device == "cpu" else "float16"
|
||||
self.model = WhisperModel(model, device=device, compute_type=compute_type)
|
||||
self.benchmark = benchmark
|
||||
|
||||
def transcribe(self, audio_segment: AudioSegment, language="sv"):
|
||||
"""
|
||||
Transcribe an AudioSegment.
|
||||
"""
|
||||
if self.benchmark:
|
||||
ts_start = dt.datetime.now()
|
||||
segments, info = self.model.transcribe(audio_segment.to_ndarray(), language=language)
|
||||
segments = list(segments)
|
||||
text = " ".join([seg.text for seg in segments])
|
||||
text = text.strip()
|
||||
if self.benchmark:
|
||||
ts_done = dt.datetime.now()
|
||||
delta_ms = (ts_done - ts_start).total_seconds() * 1000
|
||||
print(f"[Benchmark] STT: {delta_ms:.2f} ms.")
|
||||
return text
|
||||
|
||||
|
||||
class TTS:
|
||||
"""
|
||||
Text-to-Speech (TTS) class that uses the ElevenLabs API.
|
||||
"""
|
||||
|
||||
def __init__(self, voice="Brian", model="eleven_multilingual_v2") -> None:
|
||||
self.client = ElevenLabs()
|
||||
self.voice = voice
|
||||
self.model = model
|
||||
|
||||
def generate(self, text: str, voice: str = "default") -> AudioSegment:
|
||||
if voice == "default":
|
||||
voice = self.voice
|
||||
audio = self.client.generate(
|
||||
text=text,
|
||||
voice=self.voice,
|
||||
model=self.model,
|
||||
stream=False,
|
||||
output_format="pcm_16000",
|
||||
optimize_streaming_latency=2,
|
||||
)
|
||||
audio = b"".join(audio)
|
||||
audio_segment = AudioSegment(data=audio, sample_width=2, frame_rate=16000, channels=1)
|
||||
return audio_segment
|
||||
|
||||
def stream(self, text_stream: Iterator[str], voice: str = "default") -> Iterator[bytes]:
|
||||
if voice == "default":
|
||||
voice = self.voice
|
||||
audio_stream = self.client.generate(
|
||||
text=text_stream,
|
||||
voice=self.voice,
|
||||
model=self.model,
|
||||
stream=True,
|
||||
output_format="pcm_16000",
|
||||
optimize_streaming_latency=2,
|
||||
)
|
||||
|
||||
for chunk in audio_stream:
|
||||
if chunk is not None:
|
||||
yield chunk
|
||||
113
esphome-va-bridge/esp-working.log
Normal file
@ -0,0 +1,113 @@
|
||||
[16:22:06.368][D][i2s_audio.microphone:377]: Starting I2S Audio Microphne
|
||||
[16:22:06.368][D][i2s_audio.microphone:381]: Started I2S Audio Microphone
|
||||
|
||||
[16:22:06.368][D][micro_wake_word:417]: State changed from IDLE to DETECTING_WAKE_WORD
|
||||
[16:22:08.878][D][micro_wake_word:355]: Detected 'Okay Nabu' with sliding average probability is 0.97 and max probability is 1.00
|
||||
|
||||
[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
|
||||
[16:22:08.878][D][media_player:084]: Command: STOP
|
||||
[16:22:08.878][D][media_player:093]: Announcement: yes
|
||||
[16:22:08.878][D][media_player:080]: 'Media Player' - Setting
|
||||
[16:22:08.878][D][media_player:093]: Announcement: yes
|
||||
|
||||
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
|
||||
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 48000
|
||||
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536
|
||||
[16:22:08.879][D][ring_buffer:034]: Created ring buffer with size 65536
|
||||
|
||||
[16:22:08.879][D][nabu_media_player.pipeline:173]: Reading FLAC file type
|
||||
[16:22:08.879][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
|
||||
[16:22:08.879][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio
|
||||
|
||||
[16:22:08.879][D][ring_buffer:034][speaker_task]: Created ring buffer with size 19200
|
||||
|
||||
[16:22:08.879][D][i2s_audio.speaker:111]: Starting Speaker
|
||||
[16:22:08.880][D][i2s_audio.speaker:116]: Started Speaker
|
||||
|
||||
[16:22:08.880][D][voice_assistant:515]: State changed from IDLE to START_MICROPHONE
|
||||
[16:22:08.880][D][voice_assistant:522]: Desired state set to START_PIPELINE
|
||||
[16:22:08.880][D][voice_assistant:225]: Starting Microphone
|
||||
|
||||
[16:22:08.880][D][ring_buffer:034]: Created ring buffer with size 16384
|
||||
|
||||
[16:22:08.880][D][voice_assistant:515]: State changed from START_MICROPHONE to STARTING_MICROPHONE
|
||||
[16:22:08.880][D][voice_assistant:515]: State changed from STARTING_MICROPHONE to START_PIPELINE
|
||||
[16:22:08.880][D][voice_assistant:280]: Requesting start...
|
||||
[16:22:08.880][D][voice_assistant:515]: State changed from START_PIPELINE to STARTING_PIPELINE
|
||||
[16:22:08.880][D][voice_assistant:537]: Client started, streaming microphone
|
||||
[16:22:08.881][D][voice_assistant:515]: State changed from STARTING_PIPELINE to STREAMING_MICROPHONE
|
||||
[16:22:08.881][D][voice_assistant:522]: Desired state set to STREAMING_MICROPHONE
|
||||
[16:22:08.881][D][voice_assistant:641]: Event Type: 1 (VOICE_ASSISTANT_RUN_START)
|
||||
[16:22:08.881][D][voice_assistant:644]: Assist Pipeline running
|
||||
[16:22:08.881][D][voice_assistant:641]: Event Type: 3 (VOICE_ASSISTANT_STT_START)
|
||||
[16:22:08.881][D][voice_assistant:655]: STT started
|
||||
|
||||
[16:22:08.881][D][light:036]: 'voice_assistant_leds' Setting:
|
||||
[16:22:08.881][D][light:047]: State: ON
|
||||
[16:22:08.881][D][light:051]: Brightness: 66%
|
||||
[16:22:08.882][D][light:109]: Effect: 'Waiting for Command'
|
||||
|
||||
[16:22:08.882][D][power_supply:033]: Enabling power supply.
|
||||
|
||||
[16:22:09.671][D][voice_assistant:641]: Event Type: 11 (VOICE_ASSISTANT_STT_VAD_START)
|
||||
[16:22:09.671][D][voice_assistant:804]: Starting STT by VAD
|
||||
|
||||
[16:22:09.671][D][light:036]: 'voice_assistant_leds' Setting:
|
||||
[16:22:09.671][D][light:051]: Brightness: 66%
|
||||
[16:22:09.671][D][light:109]: Effect: 'Listening For Command'
|
||||
|
||||
[16:22:14.806][D][voice_assistant:641]: Event Type: 12 (VOICE_ASSISTANT_STT_VAD_END)
|
||||
[16:22:14.806][D][voice_assistant:808]: STT by VAD end
|
||||
[16:22:14.806][D][voice_assistant:515]: State changed from STREAMING_MICROPHONE to STOP_MICROPHONE
|
||||
[16:22:14.806][D][voice_assistant:522]: Desired state set to AWAITING_RESPONSE
|
||||
[16:22:14.806][D][voice_assistant:515]: State changed from STOP_MICROPHONE to STOPPING_MICROPHONE
|
||||
|
||||
[16:22:14.807][D][light:036]: 'voice_assistant_leds' Setting:
|
||||
[16:22:14.807][D][light:051]: Brightness: 66%
|
||||
[16:22:14.807][D][light:109]: Effect: 'Thinking'
|
||||
|
||||
[16:22:14.807][D][voice_assistant:515]: State changed from STOPPING_MICROPHONE to AWAITING_RESPONSE
|
||||
[16:22:14.807][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to AWAITING_RESPONSE
|
||||
|
||||
[16:22:14.807][D][power_supply:033]: Enabling power supply.
|
||||
[16:22:14.807][D][power_supply:033]: Enabling power supply.
|
||||
|
||||
[16:22:14.807][D][voice_assistant:641]: Event Type: 4 (VOICE_ASSISTANT_STT_END)
|
||||
[16:22:14.807][D][voice_assistant:669]: Speech recognised as: " Who are you?"
|
||||
[16:22:14.807][D][voice_assistant:641]: Event Type: 5 (VOICE_ASSISTANT_INTENT_START)
|
||||
[16:22:14.808][D][voice_assistant:674]: Intent started
|
||||
|
||||
[16:22:14.808][D][power_supply:033]: Enabling power supply.
|
||||
[16:22:14.808][D][power_supply:033]: Enabling power supply.
|
||||
[16:22:14.808][D][power_supply:033]: Enabling power supply.
|
||||
[16:22:14.808][D][power_supply:033]: Enabling power supply.
|
||||
[16:22:14.808][D][power_supply:033]: Enabling power supply.
|
||||
|
||||
[16:22:14.808][D][voice_assistant:641]: Event Type: 6 (VOICE_ASSISTANT_INTENT_END)
|
||||
[16:22:14.808][D][voice_assistant:641]: Event Type: 7 (VOICE_ASSISTANT_TTS_START)
|
||||
[16:22:14.808][D][voice_assistant:697]: Response: "I am an AI assistant designed to help you with various tasks, answer questions, and provide information. If there's anything specific you need help with, feel free to ask!"
|
||||
|
||||
[16:22:14.809][D][light:036]: 'voice_assistant_leds' Setting:
|
||||
[16:22:14.809][D][light:051]: Brightness: 66%
|
||||
[16:22:14.809][D][light:109]: Effect: 'Replying'
|
||||
|
||||
[16:22:14.809][D][voice_assistant:641]: Event Type: 8 (VOICE_ASSISTANT_TTS_END)
|
||||
[16:22:14.809][D][voice_assistant:719]: Response URL: "https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac"
|
||||
[16:22:14.809][D][voice_assistant:515]: State changed from AWAITING_RESPONSE to STREAMING_RESPONSE
|
||||
[16:22:14.809][D][voice_assistant:522]: Desired state set to STREAMING_RESPONSE
|
||||
|
||||
[16:22:14.809][D][media_player:080]: 'Media Player' - Setting
|
||||
[16:22:14.809][D][media_player:087]: Media URL: https://homeassistant.wessman.xyz/api/tts_proxy/wrNRlBud6QN83FrSO0oa1A.flac
|
||||
[16:22:14.810][D][media_player:093]: Announcement: yes
|
||||
|
||||
[16:22:14.810][D][voice_assistant:641]: Event Type: 2 (VOICE_ASSISTANT_RUN_END)
|
||||
[16:22:14.810][D][voice_assistant:733]: Assist Pipeline ended
|
||||
|
||||
[16:22:14.810][D][esp-idf:000][ann_read]: I (40514) esp-x509-crt-bundle: Certificate validated
|
||||
|
||||
[16:22:14.810][D][nabu_media_player.pipeline:173]: Reading FLAC file type
|
||||
[16:22:14.810][D][nabu_media_player.pipeline:184]: Decoded audio has 1 channels, 48000 Hz sample rate, and 16 bits per sample
|
||||
[16:22:14.810][D][nabu_media_player.pipeline:211]: Converting mono channel audio to stereo channel audio
|
||||
|
||||
[16:22:26.112][D][voice_assistant:515]: State changed from STREAMING_RESPONSE to IDLE
|
||||
[16:22:26.112][D][voice_assistant:522]: Desired state set to IDLE
|
||||
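The trace above covers one complete assist run on the satellite: the announcement pipeline is prepared, the microphone pipeline starts, VAD brackets the spoken command, and the STT, intent, and TTS events arrive before the response URL is streamed back. As a quick orientation aid (not an exhaustive list of possible events), this is the event order the log shows, with the numeric types copied exactly from the "Event Type: N (...)" lines; the esphome_server.py added below emits the same sequence from the host side.

# Sketch only: assist-run event order as observed in the log above.
EXPECTED_EVENT_ORDER = [
    (1, "VOICE_ASSISTANT_RUN_START"),
    (3, "VOICE_ASSISTANT_STT_START"),
    (11, "VOICE_ASSISTANT_STT_VAD_START"),
    (12, "VOICE_ASSISTANT_STT_VAD_END"),
    (4, "VOICE_ASSISTANT_STT_END"),      # carries the recognised text
    (5, "VOICE_ASSISTANT_INTENT_START"),
    (6, "VOICE_ASSISTANT_INTENT_END"),
    (7, "VOICE_ASSISTANT_TTS_START"),    # carries the response text
    (8, "VOICE_ASSISTANT_TTS_END"),      # carries the TTS audio URL
    (2, "VOICE_ASSISTANT_RUN_END"),
]
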
283
esphome-va-bridge/esphome_server.py
Normal file
283
esphome-va-bridge/esphome_server.py
Normal file
@ -0,0 +1,283 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import wave
|
||||
from enum import StrEnum
|
||||
import traceback
|
||||
|
||||
import aioesphomeapi
|
||||
from aioesphomeapi import (VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag, VoiceAssistantEventType)
|
||||
from aiohttp import web
|
||||
|
||||
from audio import STT, TTS, VAD, AudioBuffer, AudioSegment
|
||||
from llm_api import LlmApi
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
routes = web.RouteTableDef()
|
||||
|
||||
|
||||
class PipelineStage(StrEnum):
|
||||
"""Stages of a pipeline."""
|
||||
|
||||
WAKE_WORD = "wake_word"
|
||||
STT = "stt"
|
||||
INTENT = "intent"
|
||||
TTS = "tts"
|
||||
END = "end"
|
||||
|
||||
|
||||
class Satellite:
|
||||
|
||||
def __init__(self, host="192.168.10.155", port=6053, password=""):
|
||||
self.client = aioesphomeapi.APIClient(host, port, password)
|
||||
self.audio_queue: AudioBuffer = AudioBuffer(max_size=60000)
|
||||
self.state = "idle"
|
||||
self.vad = VAD()
|
||||
self.stt = STT(device="cuda")
|
||||
self.tts = TTS()
|
||||
self.agent = LlmApi()
|
||||
|
||||
async def connect(self):
|
||||
await self.client.connect(login=True)
|
||||
|
||||
async def disconnect(self):
|
||||
await self.client.disconnect()
|
||||
|
||||
async def handle_pipeline_start(
|
||||
self,
|
||||
conversation_id: str,
|
||||
flags: int,
|
||||
audio_settings: VoiceAssistantAudioSettings,
|
||||
wake_word_phrase: str | None,
|
||||
) -> int:
|
||||
logging.debug(
|
||||
f"Pipeline starting with conversation_id={conversation_id}, flags={flags}, audio_settings={audio_settings}, wake_word_phrase={wake_word_phrase}"
|
||||
)
|
||||
|
||||
# Device triggered pipeline (wake word, etc.)
|
||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||
start_stage = PipelineStage.WAKE_WORD
|
||||
else:
|
||||
start_stage = PipelineStage.STT
|
||||
|
||||
end_stage = PipelineStage.TTS
|
||||
|
||||
logging.info(f"Starting pipeline from {start_stage} to {end_stage}")
|
||||
self.state = "running_pipeline"
|
||||
|
||||
# Run pipeline
|
||||
asyncio.create_task(self.run_pipeline(wake_word_phrase))
|
||||
|
||||
return 0  # 0 = no UDP audio port; audio chunks are expected over the API connection (see handle_audio)
|
||||
|
||||
async def run_pipeline(self, wake_word_phrase: str | None = None):
|
||||
logging.info(f"Pipeline started using the wake word '{wake_word_phrase}'.")
|
||||
|
||||
try:
|
||||
logging.debug(" > STT start")
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None)
|
||||
|
||||
logging.debug(" > STT VAD start")
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START, None)
|
||||
|
||||
# VAD
|
||||
logging.debug(" > VAD: Waiting for silence...")
|
||||
VAD_TIMEOUT = 1000
|
||||
voice_detected = _voice_detected = False  # initialise both so the silence check below never hits an unbound name
|
||||
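# Stop buffering when the queue passes 20000, when VAD flips from voice to silence, or when ~2000 samples arrive with no voice at all.
|
||||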
while True:
|
||||
if self.audio_queue.length() > 20000:
|
||||
break
|
||||
elif self.audio_queue.length() >= VAD_TIMEOUT:
|
||||
voice_detected = self.vad.detect_voice(self.audio_queue.get_last(VAD_TIMEOUT))
|
||||
if voice_detected != _voice_detected:
|
||||
_voice_detected = voice_detected
|
||||
if not _voice_detected:
|
||||
logging.debug(" > VAD: Silence detected.")
|
||||
break
|
||||
if self.audio_queue.length() > 2000 and not voice_detected:
|
||||
logging.debug(" > VAD: Silence detected (no voice).")
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
logging.debug(" > STT VAD end")
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END, None)
|
||||
|
||||
# STT
|
||||
text = self.stt.transcribe(self.audio_queue.get())
|
||||
self.audio_queue.clear()
|
||||
logging.info(f" > STT: {text}")
|
||||
|
||||
logging.debug(" > STT end")
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": text})
|
||||
|
||||
logging.debug(" > Intent start")
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None)
|
||||
agent_response = self.agent.ask(text)["answer"]
|
||||
logging.info(f">Intent: {agent_response}")
|
||||
|
||||
logging.debug(" > Intent end")
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, None)
|
||||
|
||||
logging.debug(" > TTS start")
|
||||
TTS = "announce"
|
||||
if tts_mode == "announce":
|
||||
self.client.send_voice_assistant_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": agent_response}
|
||||
)
|
||||
|
||||
voice = "default"
|
||||
audio_hash_string = hashlib.md5((voice + ":" + agent_response).encode("utf-8")).hexdigest()  # cache key: voice + response text
|
||||
audio_file_path = f"static/{audio_hash_string}.wav"
|
||||
if not os.path.exists(audio_file_path):
|
||||
logging.debug(" > TTS: Generating audio...")
|
||||
audio_segment = self.tts.generate(agent_response, voice=voice)
|
||||
else:
|
||||
logging.debug(" > TTS: Using cached audio.")
|
||||
audio_segment = AudioSegment.from_file(audio_file_path)
|
||||
audio_segment.export(audio_file_path, format="wav")  # write the WAV so the static HTTP server can serve it
|
||||
self.client.send_voice_assistant_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
|
||||
{"url": f"http://192.168.10.111:8002/{audio_file_path}"},
|
||||
)
|
||||
elif tts_mode == "stream":
|
||||
await self._stream_tts_audio("debug")
|
||||
logging.debug(" > TTS end")
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Error: ")
|
||||
logging.error(e)
|
||||
logging.error(traceback.format_exc())
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, {"message": str(e)})
|
||||
|
||||
logging.debug(" > Run end")
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END, None)
|
||||
logging.info("Pipeline done")
|
||||
|
||||
async def handle_pipeline_stop(self, abort: bool):
|
||||
logging.info("Pipeline stopped")
|
||||
self.state = "idle"
|
||||
|
||||
async def handle_audio(self, data: bytes) -> None:
|
||||
"""Handle incoming audio chunk from API."""
|
||||
self.audio_queue.put(AudioSegment(data=data, sample_width=2, frame_rate=16000, channels=1))
|
||||
|
||||
async def handle_announcement_finished(self, event: aioesphomeapi.VoiceAssistantAnnounceFinished):
|
||||
if event.success:
|
||||
logging.info("Announcement finished successfully.")
|
||||
else:
|
||||
logging.error("Announcement failed.")
|
||||
|
||||
def handle_state_change(self, state):
|
||||
logging.debug("State change:")
|
||||
logging.debug(state)
|
||||
|
||||
def handle_log(self, log):
|
||||
logging.info("Log from device:", log)
|
||||
|
||||
def subscribe(self):
|
||||
self.client.subscribe_states(self.handle_state_change)
|
||||
self.client.subscribe_logs(self.handle_log)
|
||||
self.client.subscribe_voice_assistant(
|
||||
handle_start=self.handle_pipeline_start,
|
||||
handle_stop=self.handle_pipeline_stop,
|
||||
handle_audio=self.handle_audio,
|
||||
handle_announcement_finished=self.handle_announcement_finished,
|
||||
)
|
||||
|
||||
async def _stream_tts_audio(
|
||||
self,
|
||||
media_id: str,
|
||||
sample_rate: int = 16000,
|
||||
sample_width: int = 2,
|
||||
sample_channels: int = 1,
|
||||
samples_per_chunk: int = 1024, # 512
|
||||
) -> None:
|
||||
"""Stream TTS audio chunks to device via API or UDP."""
|
||||
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
|
||||
logging.info("TTS stream start")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
try:
|
||||
# if not self._is_running:
|
||||
# return
|
||||
|
||||
with wave.open(f"{media_id}.wav", "rb") as wav_file:
|
||||
if (
|
||||
(wav_file.getframerate() != sample_rate)
|
||||
or (wav_file.getsampwidth() != sample_width)
|
||||
or (wav_file.getnchannels() != sample_channels)
|
||||
):
|
||||
logging.info(f"Can only stream 16Khz 16-bit mono WAV, got {wav_file.getparams()}")
|
||||
return
|
||||
|
||||
logging.info("Streaming %s audio samples", wav_file.getnframes())
|
||||
|
||||
rate = wav_file.getframerate()
|
||||
width = wav_file.getsampwidth()
|
||||
channels = wav_file.getnchannels()
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
bytes_per_sample = width * channels
|
||||
bytes_per_chunk = bytes_per_sample * samples_per_chunk
|
||||
num_chunks = int(math.ceil(len(audio_bytes) / bytes_per_chunk))
|
||||
|
||||
# Split into chunks
|
||||
for i in range(num_chunks):
|
||||
offset = i * bytes_per_chunk
|
||||
chunk = audio_bytes[offset : offset + bytes_per_chunk]
|
||||
self.client.send_voice_assistant_audio(chunk)
|
||||
|
||||
# Wait for 90% of the duration of the audio that was
|
||||
# sent for it to be played. This will overrun the
|
||||
# device's buffer for very long audio, so using a media
|
||||
# player is preferred.
|
||||
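# With the defaults above: 1024 samples * 2 bytes * 1 channel = 2048 bytes per chunk,
|
||||
# i.e. 1024 / 16000 = 0.064 s of audio, so the sleep below waits roughly 0.058 s per chunk.
|
||||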
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
|
||||
seconds_in_chunk = samples_in_chunk / sample_rate
|
||||
logging.info(f"Samples in chunk: {samples_in_chunk}, seconds in chunk: {seconds_in_chunk}")
|
||||
await asyncio.sleep(seconds_in_chunk * 0.9)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return # Don't trigger state change
|
||||
finally:
|
||||
self.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {})
|
||||
logging.info("TTS stream end")
|
||||
|
||||
|
||||
async def start_http_server(port=8002):
|
||||
app = web.Application()
|
||||
app.router.add_static("/static/", path=os.path.join(os.path.dirname(__file__), "static"))
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "0.0.0.0", port)
|
||||
await site.start()
|
||||
logging.info(f"HTTP server started at http://0.0.0.0:{port}/")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Connect to an ESPHome device and wait for state changes."""
|
||||
await start_http_server()
|
||||
|
||||
satellite = Satellite()
|
||||
await satellite.connect()
|
||||
logging.info("Connected to ESPHome voice assistant.")
|
||||
satellite.subscribe()
|
||||
satellite.client.send_voice_assistant_event(VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START, None)
|
||||
logging.info("Listening for events...")
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
13
esphome-va-bridge/llm_api.py
Normal file
13
esphome-va-bridge/llm_api.py
Normal file
@ -0,0 +1,13 @@
|
||||
import requests
|
||||
|
||||
|
||||
class LlmApi:
|
||||
def __init__(self, host="127.0.0.1", port=8000):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.url = f"http://{host}:{port}"
|
||||
|
||||
def ask(self, question):
|
||||
url = self.url + "/ask"
|
||||
response = requests.post(url, json={"question": question})
|
||||
return response.json()
|
||||
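For reference, a minimal usage sketch of the client above; the defaults (127.0.0.1:8000 and the /ask route) come from the class itself, and "answer" is the field esphome_server.py reads from the response:

api = LlmApi()                    # assumes the agent API is reachable on 127.0.0.1:8000
result = api.ask("Who are you?")  # POST {"question": ...} to /ask
print(result["answer"])
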
62
esphome-va-bridge/log_reader.py
Normal file
62
esphome-va-bridge/log_reader.py
Normal file
@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Helper script that uses aioesphomeapi to view logs from an ESPHome device
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from aioesphomeapi.api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.log_runner import async_run
|
||||
|
||||
|
||||
async def main(argv: list[str]) -> None:
|
||||
parser = argparse.ArgumentParser("aioesphomeapi-logs")
|
||||
parser.add_argument("--port", type=int, default=6053)
|
||||
parser.add_argument("--password", type=str)
|
||||
parser.add_argument("--noise-psk", type=str)
|
||||
parser.add_argument("-v", "--verbose", action="store_true")
|
||||
parser.add_argument("address")
|
||||
args = parser.parse_args(argv[1:])
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
cli = APIClient(
|
||||
args.address,
|
||||
args.port,
|
||||
args.password or "",
|
||||
noise_psk=args.noise_psk,
|
||||
keepalive=10,
|
||||
)
|
||||
|
||||
def on_log(msg: SubscribeLogsResponse) -> None:
|
||||
time_ = datetime.now()
|
||||
message: bytes = msg.message
|
||||
text = message.decode("utf8", "backslashreplace")
|
||||
milliseconds = time_.microsecond // 1000
|
||||
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}.{nanoseconds:03}]{text}")
|
||||
|
||||
stop = await async_run(cli, on_log)
|
||||
try:
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
await stop()
|
||||
|
||||
|
||||
def cli_entry_point() -> None:
|
||||
"""Run the CLI."""
|
||||
try:
|
||||
asyncio.run(main(sys.argv))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_entry_point()
|
||||
sys.exit(0)
|
||||
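Example invocation of the log reader (the address is whatever your satellite uses; 192.168.10.155 is the default host hard-coded in esphome_server.py, and --noise-psk is only needed if the device's API is encrypted):

python log_reader.py -v --port 6053 192.168.10.155
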
1
external/home-assistant-voice-pe
vendored
Submodule
1
external/home-assistant-voice-pe
vendored
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 4c09d89ddbf234546f244bf98b1053a2f89e5ab6
|
||||
2
pyproject.toml
Normal file
2
pyproject.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[tool.black]
|
||||
line-length = 120
|
||||