diff --git a/agent/.gitignore b/agent/.gitignore
new file mode 100644
index 0000000..e9abc7f
--- /dev/null
+++ b/agent/.gitignore
@@ -0,0 +1 @@
+config.yml
\ No newline at end of file
diff --git a/agent/CLAUDE.md b/agent/CLAUDE.md
new file mode 100644
index 0000000..2de8e6e
--- /dev/null
+++ b/agent/CLAUDE.md
@@ -0,0 +1,127 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+This is a voice assistant agent that integrates with Home Assistant and provides tool-calling capabilities through a plugin architecture. The system uses OpenAI's API (or compatible endpoints) for LLM interactions and supports multiple languages (English and Swedish).
+
+## Architecture
+
+### Core Components
+
+- **[src/backend.py](src/backend.py)**: FastAPI server that loads plugins and exposes the `/ask` endpoint
+- **[src/agent.py](src/agent.py)**: Main Agent class that orchestrates LLM queries with plugin tools
+- **[src/llm.py](src/llm.py)**: LLM abstraction layer supporting OpenAI, LMStudio, and Llama models
+  - `OpenAIChat`: Handles OpenAI API calls with function calling support (parallel execution)
+  - `LMStudioChat`: Local LLM via LMStudio
+  - `LlamaChat`: Direct llama.cpp integration
+- **[src/db.py](src/db.py)**: PostgreSQL database with pgvector extension for embeddings
+- **[src/frontend.py](src/frontend.py)**: Gradio web interface for chat
+
+### Plugin System
+
+The plugin architecture is the heart of this system. All plugins inherit from [src/plugins/base_plugin.py](src/plugins/base_plugin.py) and must:
+
+1. Define methods starting with `tool_` to expose functions to the LLM
+2. Use Pydantic models for type-safe function parameters
+3. Return JSON-formatted strings from tool functions
+4. Optionally provide a `prompt()` method to inject context into the system prompt
+
+**Plugin discovery**: The backend automatically loads all folders in `src/plugins/` that contain a `plugin.py` file with a `Plugin` class.
+
+**Base plugin features**:
+- Automatically converts `tool_*` methods into OpenAI function calling format
+- Introspects Pydantic models to generate JSON schemas
+- Provides access to the Home Assistant helper class
+
+**Home Assistant integration**: Plugins can use `self.homeassistant` (from [src/plugins/homeassistant.py](src/plugins/homeassistant.py)) to interact with Home Assistant via its REST API and WebSocket.
+
+### Template System
+
+System prompts are managed through Jinja2 templates in `src/templates/`:
+- Templates are language-specific (e.g., `ask.en.j2`, `ask.sv.j2`)
+- Plugins can inject their own prompts via the `prompt()` method
+- Templates receive: the current datetime, plugin prompts, and optional knowledge
+
+## Development Commands
+
+### Database Setup
+
+Start PostgreSQL with pgvector:
+```bash
+docker-compose up -d
+```
+
+Create the database and enable the pgvector extension:
+```sql
+CREATE DATABASE rag;
+\c rag
+CREATE EXTENSION vector;
+```
+
+Run schema initialization:
+```bash
+psql -U postgres -h localhost -p 5433 -d rag -f db.sql
+```
+
+### Running the Application
+
+Backend (FastAPI server):
+```bash
+cd src
+python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
+```
+
+Frontend (Gradio UI):
+```bash
+cd src
+python frontend.py
+```
+
+### Configuration
+
+Copy `src/config.default.yml` to `src/config.yml` and configure:
+- OpenAI API settings (or a compatible base_url)
+- Home Assistant URL and token
+- Plugin-specific settings (Spotify, Västtrafik, etc.)
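+
+Once configured, a quick way to smoke-test the backend is to POST to the `/ask` endpoint. A minimal sketch (assumes the backend is running at the default `server.url` of `http://127.0.0.1:8000` and that the `requests` package is available):
+
+```python
+import requests
+
+# Fields accepted by the /ask endpoint's Question model: question, history, language
+payload = {"question": "What time is it?", "language": "en"}
+
+response = requests.post("http://127.0.0.1:8000/ask", json=payload)
+print(response.json()["answer"])
+```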
+
+## Working with Plugins
+
+### Creating a New Plugin
+
+1. Create folder: `src/plugins/myplugin/`
+2. Create `plugin.py` with a `Plugin` class inheriting from `BasePlugin`
+3. Define tool functions with the `tool_` prefix
+4. Use Pydantic models for parameters
+5. Add plugin config to `config.yml` under `plugins.myplugin`
+
+Example structure:
+```python
+import json
+
+from pydantic import BaseModel, Field
+from ..base_plugin import BasePlugin
+
+class MyInput(BaseModel):
+    param: str = Field(..., description="Parameter description")
+
+class Plugin(BasePlugin):
+    def tool_my_function(self, input: MyInput):
+        """Function description for LLM"""
+        # Implementation
+        return json.dumps({"status": "success"})
+
+    def prompt(self) -> str | None:
+        return "Additional context for the LLM"
+```
+
+### Testing Function Calling
+
+The LLM class in [src/llm.py](src/llm.py) handles parallel function execution automatically. When the LLM responds with tool calls, they are executed concurrently using `ThreadPoolExecutor`.
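+
+A minimal sketch of that dispatch pattern (illustrative only; the actual implementation lives in [src/llm.py](src/llm.py), and the `call_tool` helper here is a hypothetical stand-in for looking up and invoking a plugin's `tool_*` method):
+
+```python
+import json
+from concurrent.futures import ThreadPoolExecutor
+
+def call_tool(name: str, arguments: str) -> str:
+    # Hypothetical stand-in: resolve the plugin's tool_* method and call it.
+    return json.dumps({"tool": name, "args": json.loads(arguments)})
+
+def run_tool_calls(tool_calls: list[dict]) -> list[str]:
+    # Execute every tool call requested by the LLM concurrently,
+    # preserving the original order of the results.
+    with ThreadPoolExecutor() as executor:
+        futures = [
+            executor.submit(call_tool, call["name"], call["arguments"])
+            for call in tool_calls
+        ]
+        return [future.result() for future in futures]
+```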
+
+## Important Notes
+
+- **Working directory**: The backend expects to be run from the `src/` directory due to relative imports and file paths
+- **Database connection**: Hardcoded to localhost:5433 in [src/db.py](src/db.py:17-23)
+- **Function calling**: Currently only fully supported with the `OpenAIChat` backend
+- **Response format**: Tool functions must return JSON strings for proper LLM interpretation
+- **Logging**: All modules use Python's logging module at INFO level
diff --git a/agent/src/agent.py b/agent/src/agent.py
index fc64f63..5dab9d0 100644
--- a/agent/src/agent.py
+++ b/agent/src/agent.py
@@ -8,145 +8,36 @@ from db import Database
 from llm import LLM
 
-def relative_to_actual_dates(text: str):
-    from dateparser.search import search_dates
-
-    matches = search_dates(text, languages=["en", "se"])
-
-    if matches:
-        for match in matches:
-            text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
-    return text
-
-
 class Agent:
     def __init__(self, config: dict) -> None:
         self.config = config
         self.db = Database("rag")
         self.llm = LLM(config=self.config["llm"])
 
-    def flush(self):
-        """
-        Flushes the agents knowledge.
-        """
-        logging.info("Truncating database")
-        self.db.cur.execute("TRUNCATE TABLE items")
-        self.db.conn.commit()
-
-    def insert(self, text):
-        """
-        Append knowledge.
-        """
-        logging.info(f"Inserting item into embedding table ({text}).")
-
-        #logging.info(f"\tReplacing relative dates with actual dates.")
-        #text = relative_to_actual_dates(text)
-        #logging.info(f"\tDone. text: {text}")
-
-        vector = self.llm.embedding(text)
-        vector_padded = vector + ([0]*(2048-len(vector)))
-
-        self.db.cur.execute(
-            "INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
-            (text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
-        )
-        self.db.conn.commit()
-
-    def keywords(self, query: str, n_keywords: int = 5):
-        """
-        Suggest keywords related to a query.
-        """
-        sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
-        prompt = (
-            f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
-            + f"Do not add any context or text other than the json list of {n_keywords} words."
-        )
-        response = self.llm.query(prompt, system_msg=sys_msg)
-        keywords = json.loads(response)
-        return keywords
-
-    def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
-        """
-        Retrieve relevant knowledge.
-
-        Parameters:
-        - n_retrieve (int): How many notes to retrieve.
-        - n_rerank (int): How many notes to keep after reranking.
-        """
-        logging.info(f"Retrieving knowledge from database using the query: '{query}'")
-
-        logging.debug(f"Using embedding model on '{query}'.")
-        vector_keywords = self.llm.embedding(query)
-        logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
-
-        logging.debug("Querying database")
-        vector_padded = vector_keywords + ([0]*(2048-len(vector_keywords)))
-        self.db.cur.execute(
-            f"SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT {n_retrieve}",
-            (vector_padded,),
-        )
-        db_response = self.db.cur.fetchall()
-        knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
-        logging.debug("Database returned: ")
-        logging.debug(knowledge_arr)
-
-        if n_retrieve > n_rerank:
-            logging.info("Reranking the results")
-            reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
-            reranked_knowledge_arr = sorted(
-                knowledge_arr,
-                key=lambda x: (
-                    reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
-                ),
-            )
-            reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
-            logging.debug("Reranked results: ")
-            logging.debug(reranked_knowledge_arr)
-            return reranked_knowledge_arr
-        else:
-            return knowledge_arr
-
     def ask(
         self,
         question: str,
-        person: str,
         history: list = None,
-        n_retrieve: int = 5,
-        n_rerank=3,
         tools: Union[None, list] = None,
         prompts: Union[None, list] = None,
+        language: str = "en"
     ):
         """
         Ask the agent a question.
 
         Parameters:
         - question (str): The question to ask
-        - person (str): Who asks the question?
-        - n_retrieve (int): How many notes to retrieve.
-        - n_rerank (int): How many notes to keep after reranking.
""" logging.info(f"Answering question: {question}") - if n_retrieve > 0: - knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank) - - with open('templates/knowledge.en.j2', 'r') as file: - template = Template(file.read()) - - knowledge_str = template.render( - knowledge = knowledge, - ) - else: - knowledge_str = "" - - with open('templates/ask.en.j2', 'r') as file: + allowed_languages = ['en', 'sv'] + assert language in allowed_languages, "Language must be any of " + ", ".join(allowed_languages) + with open(f'templates/ask.{language}.j2', 'r') as file: template = Template(file.read()) prompt = template.render( today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), - knowledge = knowledge_str, - plugin_prompts = prompts, - person = person + plugin_prompts = prompts ) logging.debug(f"Asking the LLM with the prompt: '{prompt}'.") diff --git a/agent/src/backend.py b/agent/src/backend.py index ff16e6a..d6ad555 100644 --- a/agent/src/backend.py +++ b/agent/src/backend.py @@ -16,15 +16,19 @@ logging.basicConfig( level=logging.INFO, ) -with open("config.yml", "r", encoding='utf8') as file: +with open("config.yml", "r", encoding="utf8") as file: config = yaml.safe_load(file) + def load_plugins(config): tools = [] prompts = [] for plugin_folder in os.listdir("./" + PLUGINS_FOLDER): - if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder: + if ( + os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) + and "__pycache__" not in plugin_folder + ): logging.info(f"Loading plugin: {plugin_folder}") module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin" module = importlib.import_module(module_path) @@ -36,6 +40,7 @@ def load_plugins(config): return tools, prompts + tools, prompts = load_plugins(config) app = FastAPI() @@ -47,14 +52,18 @@ agent = Agent(config) class Query(BaseModel): query: str + class Message(BaseModel): role: str content: str + class Question(BaseModel): question: str - person: str = "User" history: list[Message] = None + language: str = "en" + device_id: str = "" + agent_id: str = "" @app.get("/ping") @@ -65,6 +74,7 @@ def ping(): """ return {"status": "ok"} + @app.post("/ask") def ask(q: Question): """ @@ -79,49 +89,5 @@ def ask(q: Question): if q.history: history = [q.model_dump() for q in q.history] - answer = agent.ask( - question=q.question, - person="Pierre", - history=history, - tools=tools, - prompts=prompts, - n_retrieve=0, - n_rerank=0 - ) + answer = agent.ask(question=q.question, history=history, tools=tools, prompts=prompts, language=q.language) return {"question": q.question, "answer": answer, "history": q.history} - - -@app.post("/retrieve") -def retrieve(q: Query): - """ - # Retrieve - Retrieve knowledge related to a query. - - ## Parameters - - **query**: the question. - """ - result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5) - return {"query": q.query, "answer": result} - - -@app.post("/learn") -def retrieve(q: Query): - """ - # Learn - Learn the agent something. E.g. *"John likes to play volleyball".* - - ## Parameters - - **query**: the note. - """ - answer = agent.insert(text=q.query) - return {"text": q.query} - - -@app.get("/flush") -def flush(): - """ - # Flush - Remove all knowledge from the agent. 
- """ - agent.flush() - return {"message": "Knowledge flushed."} diff --git a/agent/src/config.default.yml b/agent/src/config.default.yml index 52d6b3b..6d0cfd1 100644 --- a/agent/src/config.default.yml +++ b/agent/src/config.default.yml @@ -1,11 +1,6 @@ server: url: "http://127.0.0.1:8000" llm: - use_local_chat: false - use_local_embedding: false - local_model_dir: "" - local_embedding_model: "intfloat/e5-large-v2" - local_rerank_model: "cross-encoder/stsb-distilroberta-base" openai_embedding_model: "text-embedding-3-small" openai_chat_model: "gpt-3.5-turbo-0125" temperature: 0.1 @@ -22,4 +17,10 @@ plugins: secret: "" default_from_station: "Brunnsparken" default_to_station: "Centralstationen" - delay: 0 # minutes \ No newline at end of file + delay: 0 # minutes + music: + default_speaker: "media_player.kok" + device_speakers: + "d364e96a6664016ea4ac35f64a0d7498": "media_player.media_player.theo" + spotify_client_id: "" + spotify_client_secret: "" \ No newline at end of file diff --git a/agent/src/llm.py b/agent/src/llm.py index d7c5099..740bd67 100644 --- a/agent/src/llm.py +++ b/agent/src/llm.py @@ -5,13 +5,6 @@ import numpy as np from openai import OpenAI from typing import Union, List -try: - from llama_cpp.llama import Llama - from sentence_transformers import SentenceTransformer, CrossEncoder -except ImportError as e: - print("Failed to import packages:") - print(e) - class BaseChat: def __init__(self, config: dict) -> None: @@ -173,19 +166,7 @@ class LLM: """ self.config = config - if self.config["use_local_chat"]: - self.chat_client = LlamaChat(self.config) - else: - self.chat_client = OpenAIChat(self.config) - - if self.config["use_local_embedding"]: - self.embedding_model = self.config["local_embedding_model"] - self.rerank_model = self.config["local_rerank_model"] - self.embedding_client = SentenceTransformer(self.embedding_model) - else: - self.embedding_model = self.config["openai"]["embedding_model"] - self.rerank_model = None - self.embedding_client = OpenAI(api_key=self.config["openai"]["api_key"]) + self.chat_client = OpenAIChat(self.config) def query( self, @@ -236,22 +217,3 @@ class LLM: logging.info(f"Sending request to LLM with {len(messages)} messages.") answer = self.chat_client.chat(messages=messages, tools=tools) return answer - - def embedding(self, text: str): - if self.config["use_local_embedding"]: - embedding = self.embedding_client.encode(text) - return embedding # len = 1024 - else: - llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")]) - return llm_response.data[0].embedding - - def rerank(self, text: str, corpus: List[str], n_rerank: int = 5): - if self.rerank_model is not None and len(corpus) > 0: - sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus] - logging.info("Reranking:") - logging.info(sentence_combinations) - similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations) - result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))] - return result[:n_rerank] - else: - return corpus diff --git a/agent/src/plugins/music/plugin.py b/agent/src/plugins/music/plugin.py index d496b5a..4b8d5b6 100644 --- a/agent/src/plugins/music/plugin.py +++ b/agent/src/plugins/music/plugin.py @@ -15,6 +15,7 @@ class PlayMusicInput(BaseModel): class Plugin(BasePlugin): def __init__(self, config: dict) -> None: super().__init__(config=config) + self.entity_id = self.config["plugins"]["music"]["default_speaker"] self.spotify = 
             auth_manager=SpotifyOAuth(
@@ -46,7 +47,7 @@ class Plugin(BasePlugin):
         track = self._search(input.query, limit=1)[0]
         logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
         payload = {
-            "entity_id": self.config["plugins"]["music"]["default_speaker"],
+            "entity_id": self.entity_id,
             "media_content_id": track["uri"],
             "media_content_type": "music",
             "enqueue": "play",
@@ -59,6 +60,6 @@ class Plugin(BasePlugin):
         Stop playback of music.
         """
         self.homeassistant.call_api(
-            f"services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
+            "services/media_player/media_pause", payload={"entity_id": self.entity_id}
         )
         return json.dumps({"status": "success", "message": f"Music paused."})
diff --git a/agent/src/templates/ask.en.j2 b/agent/src/templates/ask.en.j2
index 9fbe5c0..c45ccb9 100644
--- a/agent/src/templates/ask.en.j2
+++ b/agent/src/templates/ask.en.j2
@@ -9,4 +9,4 @@ Today is {{ today }}.
 - {{ prompt }}
 {% endfor %}
 
-Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a sweedish question.
\ No newline at end of file
+Answer the questions to the best of your knowledge.
\ No newline at end of file
diff --git a/agent/src/templates/ask.sv.j2 b/agent/src/templates/ask.sv.j2
new file mode 100644
index 0000000..7665354
--- /dev/null
+++ b/agent/src/templates/ask.sv.j2
@@ -0,0 +1,12 @@
+Du är en hjälpsam assistent. Dina svar kommer att konverteras till ljud/tal, så svara kort och koncist och undvik för många specialtecken.
+Idag är det {{ today }}.
+
+{% if knowledge is not none %}
+  {{ knowledge }}
+{% endif %}
+
+{% for prompt in plugin_prompts %}
+- {{ prompt }}
+{% endfor %}
+
+Svara på frågorna på svenska efter bästa förmåga.
\ No newline at end of file