Compare commits

..

No commits in common. "5dec83354454900fd6ab5b25f0db385acdb10fd2" and "5ecd5ded3eb26276ff6cdbafc786019c87397400" have entirely different histories.

11 changed files with 214 additions and 201 deletions

1
agent/.gitignore vendored
View File

@ -1 +0,0 @@
config.yml

View File

@ -1,127 +0,0 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
This is a voice assistant agent that integrates with Home Assistant and provides tool-calling capabilities through a plugin architecture. The system uses OpenAI's API (or compatible endpoints) for LLM interactions and supports multiple languages (English and Swedish).
## Architecture
### Core Components
- **[src/backend.py](src/backend.py)**: FastAPI server that loads plugins and exposes the `/ask` endpoint
- **[src/agent.py](src/agent.py)**: Main Agent class that orchestrates LLM queries with plugin tools
- **[src/llm.py](src/llm.py)**: LLM abstraction layer supporting OpenAI, LMStudio, and Llama models
- `OpenAIChat`: Handles OpenAI API calls with function calling support (parallel execution)
- `LMStudioChat`: Local LLM via LMStudio
- `LlamaChat`: Direct llama.cpp integration
- **[src/db.py](src/db.py)**: PostgreSQL database with pgvector extension for embeddings
- **[src/frontend.py](src/frontend.py)**: Gradio web interface for chat
### Plugin System
The plugin architecture is the heart of this system. All plugins inherit from [src/plugins/base_plugin.py](src/plugins/base_plugin.py) and must:
1. Define methods starting with `tool_` to expose functions to the LLM
2. Use Pydantic models for type-safe function parameters
3. Return JSON-formatted strings from tool functions
4. Optionally provide a `prompt()` method to inject context into the system prompt
**Plugin discovery**: The backend automatically loads all folders in `src/plugins/` that contain a `plugin.py` file with a `Plugin` class.
**Base plugin features**:
- Automatically converts `tool_*` methods into OpenAI function calling format
- Introspects Pydantic models to generate JSON schemas
- Provides access to Home Assistant helper class
**Home Assistant integration**: Plugins can use `self.homeassistant` (from [src/plugins/homeassistant.py](src/plugins/homeassistant.py)) to interact with Home Assistant via REST API and WebSocket.
### Template System
System prompts are managed through Jinja2 templates in `src/templates/`:
- Templates are language-specific (e.g., `ask.en.j2`, `ask.sv.j2`)
- Plugins can inject their own prompts via the `prompt()` method
- Templates receive: current datetime, plugin prompts, and optional knowledge
## Development Commands
### Database Setup
Start PostgreSQL with pgvector:
```bash
docker-compose up -d
```
Create database and enable pgvector extension:
```sql
CREATE DATABASE rag;
\c rag
CREATE EXTENSION vector;
```
Run schema initialization:
```bash
psql -U postgres -h localhost -p 5433 -d rag -f db.sql
```
### Running the Application
Backend (FastAPI server):
```bash
cd src
python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
```
Frontend (Gradio UI):
```bash
cd src
python frontend.py
```
### Configuration
Copy `src/config.default.yml` to `src/config.yml` and configure:
- OpenAI API settings (or compatible base_url)
- Home Assistant URL and token
- Plugin-specific settings (Spotify, Västtrafik, etc.)
## Working with Plugins
### Creating a New Plugin
1. Create folder: `src/plugins/myplugin/`
2. Create `plugin.py` with a `Plugin` class inheriting from `BasePlugin`
3. Define tool functions with `tool_` prefix
4. Use Pydantic models for parameters
5. Add plugin config to `config.yml` under `plugins.myplugin`
Example structure:
```python
from pydantic import BaseModel, Field
from ..base_plugin import BasePlugin
class MyInput(BaseModel):
param: str = Field(..., description="Parameter description")
class Plugin(BasePlugin):
def tool_my_function(self, input: MyInput):
"""Function description for LLM"""
# Implementation
return json.dumps({"status": "success"})
def prompt(self) -> str | None:
return "Additional context for the LLM"
```
### Testing Function Calling
The LLM class in [src/llm.py](src/llm.py) handles parallel function execution automatically. When the LLM responds with tool calls, they are executed concurrently using `ThreadPoolExecutor`.
## Important Notes
- **Working directory**: The backend expects to be run from the `src/` directory due to relative imports and file paths
- **Database connection**: Hardcoded to localhost:5433 in [src/db.py](src/db.py:17-23)
- **Function calling**: Currently only fully supported with `OpenAIChat` backend
- **Response format**: Tool functions must return JSON strings for proper LLM interpretation
- **Logging**: All modules use Python's logging module at INFO level

View File

@ -8,37 +8,145 @@ from db import Database
from llm import LLM
def relative_to_actual_dates(text: str):
from dateparser.search import search_dates
matches = search_dates(text, languages=["en", "se"])
if matches:
for match in matches:
text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
return text
class Agent:
def __init__(self, config: dict) -> None:
self.config = config
self.db = Database("rag")
self.llm = LLM(config=self.config["llm"])
def flush(self):
"""
Flushes the agents knowledge.
"""
logging.info("Truncating database")
self.db.cur.execute("TRUNCATE TABLE items")
self.db.conn.commit()
def insert(self, text):
"""
Append knowledge.
"""
logging.info(f"Inserting item into embedding table ({text}).")
#logging.info(f"\tReplacing relative dates with actual dates.")
#text = relative_to_actual_dates(text)
#logging.info(f"\tDone. text: {text}")
vector = self.llm.embedding(text)
vector_padded = vector + ([0]*(2048-len(vector)))
self.db.cur.execute(
"INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
(text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
)
self.db.conn.commit()
def keywords(self, query: str, n_keywords: int = 5):
"""
Suggest keywords related to a query.
"""
sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
prompt = (
f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
+ f"Do not add any context or text other than the json list of {n_keywords} words."
)
response = self.llm.query(prompt, system_msg=sys_msg)
keywords = json.loads(response)
return keywords
def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
"""
Retrieve relevant knowledge.
Parameters:
- n_retrieve (int): How many notes to retrieve.
- n_rerank (int): How many notes to keep after reranking.
"""
logging.info(f"Retrieving knowledge from database using the query: '{query}'")
logging.debug(f"Using embedding model on '{query}'.")
vector_keywords = self.llm.embedding(query)
logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
logging.debug("Querying database")
vector_padded = vector_keywords + ([0]*(2048-len(vector_keywords)))
self.db.cur.execute(
f"SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT {n_retrieve}",
(vector_padded,),
)
db_response = self.db.cur.fetchall()
knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
logging.debug("Database returned: ")
logging.debug(knowledge_arr)
if n_retrieve > n_rerank:
logging.info("Reranking the results")
reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
reranked_knowledge_arr = sorted(
knowledge_arr,
key=lambda x: (
reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
),
)
reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
logging.debug("Reranked results: ")
logging.debug(reranked_knowledge_arr)
return reranked_knowledge_arr
else:
return knowledge_arr
def ask(
self,
question: str,
person: str,
history: list = None,
n_retrieve: int = 5,
n_rerank=3,
tools: Union[None, list] = None,
prompts: Union[None, list] = None,
language: str = "en",
device_id: str = ""
):
"""
Ask the agent a question.
Parameters:
- question (str): The question to ask
- person (str): Who asks the question?
- n_retrieve (int): How many notes to retrieve.
- n_rerank (int): How many notes to keep after reranking.
"""
logging.info(f"Answering question: {question}")
allowed_languages = ['en', 'sv']
assert language in allowed_languages, "Language must be any of " + ", ".join(allowed_languages)
with open(f'templates/ask.{language}.j2', 'r') as file:
if n_retrieve > 0:
knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
with open('templates/knowledge.en.j2', 'r') as file:
template = Template(file.read())
knowledge_str = template.render(
knowledge = knowledge,
)
else:
knowledge_str = ""
with open('templates/ask.en.j2', 'r') as file:
template = Template(file.read())
prompt = template.render(
today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
plugin_prompts = prompts
knowledge = knowledge_str,
plugin_prompts = prompts,
person = person
)
logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")

View File

@ -16,34 +16,27 @@ logging.basicConfig(
level=logging.INFO,
)
with open("config.yml", "r", encoding="utf8") as file:
with open("config.yml", "r", encoding='utf8') as file:
config = yaml.safe_load(file)
def load_plugins(config):
tools = []
prompts = []
plugin_instances = []
for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
if (
os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder))
and "__pycache__" not in plugin_folder
):
if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder:
logging.info(f"Loading plugin: {plugin_folder}")
module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
module = importlib.import_module(module_path)
plugin = module.Plugin(config)
plugin_instances.append(plugin)
tools += plugin.tools()
prompt = plugin.prompt()
if prompt is not None:
prompts.append(prompt)
return tools, prompts, plugin_instances
return tools, prompts
tools, prompts, plugins = load_plugins(config)
tools, prompts = load_plugins(config)
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
@ -54,18 +47,14 @@ agent = Agent(config)
class Query(BaseModel):
query: str
class Message(BaseModel):
role: str
content: str
class Question(BaseModel):
question: str
person: str = "User"
history: list[Message] = None
language: str = "en"
device_id: str = ""
agent_id: str = ""
@app.get("/ping")
@ -76,7 +65,6 @@ def ping():
"""
return {"status": "ok"}
@app.post("/ask")
def ask(q: Question):
"""
@ -91,9 +79,49 @@ def ask(q: Question):
if q.history:
history = [q.model_dump() for q in q.history]
# Set device_id on all plugin instances
for plugin in plugins:
plugin.device_id = q.device_id
answer = agent.ask(question=q.question, history=history, tools=tools, prompts=prompts, language=q.language, device_id=q.device_id)
answer = agent.ask(
question=q.question,
person="Pierre",
history=history,
tools=tools,
prompts=prompts,
n_retrieve=0,
n_rerank=0
)
return {"question": q.question, "answer": answer, "history": q.history}
@app.post("/retrieve")
def retrieve(q: Query):
"""
# Retrieve
Retrieve knowledge related to a query.
## Parameters
- **query**: the question.
"""
result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5)
return {"query": q.query, "answer": result}
@app.post("/learn")
def retrieve(q: Query):
"""
# Learn
Learn the agent something. E.g. *"John likes to play volleyball".*
## Parameters
- **query**: the note.
"""
answer = agent.insert(text=q.query)
return {"text": q.query}
@app.get("/flush")
def flush():
"""
# Flush
Remove all knowledge from the agent.
"""
agent.flush()
return {"message": "Knowledge flushed."}

View File

@ -1,6 +1,11 @@
server:
url: "http://127.0.0.1:8000"
llm:
use_local_chat: false
use_local_embedding: false
local_model_dir: ""
local_embedding_model: "intfloat/e5-large-v2"
local_rerank_model: "cross-encoder/stsb-distilroberta-base"
openai_embedding_model: "text-embedding-3-small"
openai_chat_model: "gpt-3.5-turbo-0125"
temperature: 0.1
@ -18,9 +23,3 @@ plugins:
default_from_station: "Brunnsparken"
default_to_station: "Centralstationen"
delay: 0 # minutes
music:
default_speaker: "media_player.kok"
device_speakers:
"d364e96a6664016ea4ac35f64a0d7498": "media_player.media_player.theo"
spotify_client_id: ""
spotify_client_secret: ""

View File

@ -5,6 +5,13 @@ import numpy as np
from openai import OpenAI
from typing import Union, List
try:
from llama_cpp.llama import Llama
from sentence_transformers import SentenceTransformer, CrossEncoder
except ImportError as e:
print("Faield to import packages:")
print(e)
class BaseChat:
def __init__(self, config: dict) -> None:
@ -166,7 +173,19 @@ class LLM:
"""
self.config = config
self.chat_client = OpenAIChat(self.config)
if self.config["use_local_chat"]:
self.chat_client = LlamaChat(self.config)
else:
self.chat_client = OpenAIChat(self.config)
if self.config["use_local_embedding"]:
self.embedding_model = self.config["local_embedding_model"]
self.rerank_model = self.config["local_rerank_model"]
self.embedding_client = SentenceTransformer(self.embedding_model)
else:
self.embedding_model = self.config["openai"]["embedding_model"]
self.rerank_model = None
self.embedding_client = OpenAI(api_key=self.config["openai"]["api_key"])
def query(
self,
@ -217,3 +236,22 @@ class LLM:
logging.info(f"Sending request to LLM with {len(messages)} messages.")
answer = self.chat_client.chat(messages=messages, tools=tools)
return answer
def embedding(self, text: str):
if self.config["use_local_embedding"]:
embedding = self.embedding_client.encode(text)
return embedding # len = 1024
else:
llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")])
return llm_response.data[0].embedding
def rerank(self, text: str, corpus: List[str], n_rerank: int = 5):
if self.rerank_model is not None and len(corpus) > 0:
sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus]
logging.info("Reranking:")
logging.info(sentence_combinations)
similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations)
result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))]
return result[:n_rerank]
else:
return corpus

View File

@ -7,7 +7,6 @@ class BasePlugin:
def __init__(self, config: dict) -> None:
self.config = config
self.homeassistant = HomeAssistant(config)
self.device_id = ""
def prompt(self) -> str | None:
return

View File

@ -25,22 +25,6 @@ class Plugin(BasePlugin):
)
)
def _get_speaker_for_device(self, device_id: str) -> str:
"""
Get the appropriate speaker entity_id based on device_id.
Falls back to default_speaker if device_id is not found or empty.
"""
device_speakers = self.config["plugins"]["music"].get("device_speakers", {})
if device_id and device_id in device_speakers:
speaker = device_speakers[device_id]
logging.info(f"Using device-specific speaker for {device_id}: {speaker}")
return speaker
default_speaker = self.config["plugins"]["music"]["default_speaker"]
logging.info(f"Using default speaker: {default_speaker}")
return default_speaker
def _search(self, query: str, limit: int = 10):
_result = self.spotify.search(query, limit=limit)
result = []
@ -60,10 +44,9 @@ class Plugin(BasePlugin):
Play music using a search query.
"""
track = self._search(input.query, limit=1)[0]
speaker = self._get_speaker_for_device(self.device_id)
logging.info(f"Playing {track['name']} by {', '.join(track['artists'])} on {speaker}")
logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
payload = {
"entity_id": speaker,
"entity_id": self.config["plugins"]["music"]["default_speaker"],
"media_content_id": track["uri"],
"media_content_type": "music",
"enqueue": "play",
@ -75,8 +58,7 @@ class Plugin(BasePlugin):
"""
Stop playback of music.
"""
speaker = self._get_speaker_for_device(self.device_id)
self.homeassistant.call_api(
f"services/media_player/media_pause", payload={"entity_id": speaker}
f"services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
)
return json.dumps({"status": "success", "message": f"Music paused."})

View File

@ -1,4 +1,3 @@
requests
fastapi
uvicorn[standard]
openai>=1.11.1

View File

@ -9,4 +9,4 @@ Today is {{ today }}.
- {{ prompt }}
{% endfor %}
Answer the questions to the best of your knowledge.
Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a sweedish question.

View File

@ -1,12 +0,0 @@
Du är en hjälpsam assistent. Dina svar kommer att konverteras till ljud/tal, så svara kort och koncist och undvik för många specialtecken.
Idag är det {{ today }}.
{% if knowledge is not none %}
{{ knowledge }}
{% endif %}
{% for prompt in plugin_prompts %}
- {{ prompt }}
{% endfor %}
Svara på frågorna på svenska efter bästa förmåga.