Compare commits
3 commits: 5ecd5ded3e ... 5dec833544

| Author | SHA1 | Date |
|---|---|---|
| | 5dec833544 | |
| | b6c0d6d752 | |
| | 79363786d8 | |
agent/.gitignore (new file, vendored, 1 line)

```diff
@@ -0,0 +1 @@
+config.yml
```
agent/CLAUDE.md (new file, 127 lines)

# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This is a voice assistant agent that integrates with Home Assistant and provides tool-calling capabilities through a plugin architecture. The system uses OpenAI's API (or compatible endpoints) for LLM interactions and supports multiple languages (English and Swedish).

## Architecture

### Core Components

- **[src/backend.py](src/backend.py)**: FastAPI server that loads plugins and exposes the `/ask` endpoint
- **[src/agent.py](src/agent.py)**: Main Agent class that orchestrates LLM queries with plugin tools
- **[src/llm.py](src/llm.py)**: LLM abstraction layer supporting OpenAI, LMStudio, and Llama models
  - `OpenAIChat`: Handles OpenAI API calls with function calling support (parallel execution)
  - `LMStudioChat`: Local LLM via LMStudio
  - `LlamaChat`: Direct llama.cpp integration
- **[src/db.py](src/db.py)**: PostgreSQL database with pgvector extension for embeddings
- **[src/frontend.py](src/frontend.py)**: Gradio web interface for chat

### Plugin System

The plugin architecture is the heart of this system. All plugins inherit from [src/plugins/base_plugin.py](src/plugins/base_plugin.py) and must:

1. Define methods starting with `tool_` to expose functions to the LLM
2. Use Pydantic models for type-safe function parameters
3. Return JSON-formatted strings from tool functions
4. Optionally provide a `prompt()` method to inject context into the system prompt

**Plugin discovery**: The backend automatically loads all folders in `src/plugins/` that contain a `plugin.py` file with a `Plugin` class.

**Base plugin features**:

- Automatically converts `tool_*` methods into OpenAI function calling format
- Introspects Pydantic models to generate JSON schemas
- Provides access to Home Assistant helper class

**Home Assistant integration**: Plugins can use `self.homeassistant` (from [src/plugins/homeassistant.py](src/plugins/homeassistant.py)) to interact with Home Assistant via REST API and WebSocket.
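As an illustration of the schema generation described in the two bullets above, the snippet below sketches how a `tool_` method and its Pydantic input model could be turned into an OpenAI-style tool definition. It is a simplified, hypothetical sketch (the helper `build_tool_schema` and the `LightInput` example are not part of the codebase; the real logic lives in `BasePlugin`):

```python
import inspect

from pydantic import BaseModel, Field


class LightInput(BaseModel):
    room: str = Field(..., description="Room whose lights should be switched on")


def tool_turn_on_lights(self, input: LightInput):
    """Turn on the lights in a room."""
    ...


def build_tool_schema(func) -> dict:
    """Hypothetical helper showing the shape of the generated tool definition."""
    # Take the annotated Pydantic model from the last parameter of the tool_* method.
    input_model = list(inspect.signature(func).parameters.values())[-1].annotation
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": (func.__doc__ or "").strip(),
            "parameters": input_model.model_json_schema(),  # Pydantic v2; use .schema() on v1
        },
    }


print(build_tool_schema(tool_turn_on_lights))
```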
### Template System

System prompts are managed through Jinja2 templates in `src/templates/`:

- Templates are language-specific (e.g., `ask.en.j2`, `ask.sv.j2`)
- Plugins can inject their own prompts via the `prompt()` method
- Templates receive: current datetime, plugin prompts, and optional knowledge
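A minimal sketch of rendering one of these templates, mirroring the variables that `Agent.ask()` passes (the example plugin prompt string is illustrative, and the snippet assumes it runs from `src/` so the template path resolves):

```python
import datetime

from jinja2 import Template

# Illustrative render of the English system prompt template.
with open("templates/ask.en.j2", "r") as file:
    template = Template(file.read())

system_prompt = template.render(
    today=datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
    plugin_prompts=["You can control the media players in the house."],  # assumed example prompt
)
print(system_prompt)
```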
## Development Commands

### Database Setup

Start PostgreSQL with pgvector:

```bash
docker-compose up -d
```

Create database and enable pgvector extension:

```sql
CREATE DATABASE rag;
\c rag
CREATE EXTENSION vector;
```

Run schema initialization:

```bash
psql -U postgres -h localhost -p 5433 -d rag -f db.sql
```

### Running the Application

Backend (FastAPI server):

```bash
cd src
python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
```

Frontend (Gradio UI):

```bash
cd src
python frontend.py
```

### Configuration

Copy `src/config.default.yml` to `src/config.yml` and configure:

- OpenAI API settings (or compatible base_url)
- Home Assistant URL and token
- Plugin-specific settings (Spotify, Västtrafik, etc.)

## Working with Plugins

### Creating a New Plugin

1. Create folder: `src/plugins/myplugin/`
2. Create `plugin.py` with a `Plugin` class inheriting from `BasePlugin`
3. Define tool functions with `tool_` prefix
4. Use Pydantic models for parameters
5. Add plugin config to `config.yml` under `plugins.myplugin`

Example structure:

```python
import json

from pydantic import BaseModel, Field

from ..base_plugin import BasePlugin


class MyInput(BaseModel):
    param: str = Field(..., description="Parameter description")


class Plugin(BasePlugin):
    def tool_my_function(self, input: MyInput):
        """Function description for LLM"""
        # Implementation
        return json.dumps({"status": "success"})

    def prompt(self) -> str | None:
        return "Additional context for the LLM"
```
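To sanity-check a plugin without starting the backend, something like the following should work. This is a hedged sketch, not an official test harness: it assumes it is run from `src/`, that `config.yml` has been created, and that the config contains whatever the Home Assistant helper needs, since `BasePlugin` instantiates it.

```python
import yaml

# Resolve the plugin the same way backend.py's loader would (run from src/).
from plugins.myplugin.plugin import MyInput, Plugin

with open("config.yml", "r", encoding="utf8") as file:
    config = yaml.safe_load(file)

plugin = Plugin(config)
print(plugin.tools())    # OpenAI-style tool definitions generated by BasePlugin
print(plugin.prompt())   # optional extra context for the system prompt
print(plugin.tool_my_function(MyInput(param="hello")))
```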
### Testing Function Calling

The LLM class in [src/llm.py](src/llm.py) handles parallel function execution automatically. When the LLM responds with tool calls, they are executed concurrently using `ThreadPoolExecutor`.
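The snippet below sketches what that concurrent dispatch plausibly looks like. It illustrates the pattern only and is not a copy of the `OpenAIChat` implementation; the `tool_registry` mapping and the argument handling are assumptions.

```python
import json
from concurrent.futures import ThreadPoolExecutor


def run_tool_calls(tool_calls, tool_registry):
    """Illustrative pattern: execute every tool call from one LLM response concurrently."""

    def run_one(call):
        func = tool_registry[call.function.name]        # assumed mapping of tool name -> callable
        args = json.loads(call.function.arguments or "{}")
        return call.id, func(args)                      # assumed: func validates/unpacks the dict itself

    with ThreadPoolExecutor() as pool:
        results = list(pool.map(run_one, tool_calls))

    # One "tool" message per call, fed back to the chat completion API.
    return [
        {"role": "tool", "tool_call_id": call_id, "content": content}
        for call_id, content in results
    ]
```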
## Important Notes

- **Working directory**: The backend expects to be run from the `src/` directory due to relative imports and file paths
- **Database connection**: Hardcoded to localhost:5433 in [src/db.py](src/db.py:17-23)
- **Function calling**: Currently only fully supported with `OpenAIChat` backend
- **Response format**: Tool functions must return JSON strings for proper LLM interpretation
- **Logging**: All modules use Python's logging module at INFO level
agent/src/agent.py

```diff
@@ -8,145 +8,37 @@ from db import Database
 from llm import LLM
 
 
-def relative_to_actual_dates(text: str):
-    from dateparser.search import search_dates
-
-    matches = search_dates(text, languages=["en", "se"])
-
-    if matches:
-        for match in matches:
-            text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
-    return text
-
-
 class Agent:
     def __init__(self, config: dict) -> None:
         self.config = config
         self.db = Database("rag")
         self.llm = LLM(config=self.config["llm"])
 
-    def flush(self):
-        """
-        Flushes the agents knowledge.
-        """
-        logging.info("Truncating database")
-        self.db.cur.execute("TRUNCATE TABLE items")
-        self.db.conn.commit()
-
-    def insert(self, text):
-        """
-        Append knowledge.
-        """
-        logging.info(f"Inserting item into embedding table ({text}).")
-
-        #logging.info(f"\tReplacing relative dates with actual dates.")
-        #text = relative_to_actual_dates(text)
-        #logging.info(f"\tDone. text: {text}")
-
-        vector = self.llm.embedding(text)
-        vector_padded = vector + ([0]*(2048-len(vector)))
-
-        self.db.cur.execute(
-            "INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
-            (text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
-        )
-        self.db.conn.commit()
-
-    def keywords(self, query: str, n_keywords: int = 5):
-        """
-        Suggest keywords related to a query.
-        """
-        sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
-        prompt = (
-            f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
-            + f"Do not add any context or text other than the json list of {n_keywords} words."
-        )
-        response = self.llm.query(prompt, system_msg=sys_msg)
-        keywords = json.loads(response)
-        return keywords
-
-    def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
-        """
-        Retrieve relevant knowledge.
-
-        Parameters:
-        - n_retrieve (int): How many notes to retrieve.
-        - n_rerank (int): How many notes to keep after reranking.
-        """
-        logging.info(f"Retrieving knowledge from database using the query: '{query}'")
-
-        logging.debug(f"Using embedding model on '{query}'.")
-        vector_keywords = self.llm.embedding(query)
-        logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
-
-        logging.debug("Querying database")
-        vector_padded = vector_keywords + ([0]*(2048-len(vector_keywords)))
-        self.db.cur.execute(
-            f"SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT {n_retrieve}",
-            (vector_padded,),
-        )
-        db_response = self.db.cur.fetchall()
-        knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
-        logging.debug("Database returned: ")
-        logging.debug(knowledge_arr)
-
-        if n_retrieve > n_rerank:
-            logging.info("Reranking the results")
-            reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
-            reranked_knowledge_arr = sorted(
-                knowledge_arr,
-                key=lambda x: (
-                    reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
-                ),
-            )
-            reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
-            logging.debug("Reranked results: ")
-            logging.debug(reranked_knowledge_arr)
-            return reranked_knowledge_arr
-        else:
-            return knowledge_arr
-
     def ask(
         self,
         question: str,
-        person: str,
         history: list = None,
-        n_retrieve: int = 5,
-        n_rerank=3,
         tools: Union[None, list] = None,
         prompts: Union[None, list] = None,
+        language: str = "en",
+        device_id: str = ""
     ):
         """
         Ask the agent a question.
 
         Parameters:
         - question (str): The question to ask
-        - person (str): Who asks the question?
-        - n_retrieve (int): How many notes to retrieve.
-        - n_rerank (int): How many notes to keep after reranking.
         """
         logging.info(f"Answering question: {question}")
 
-        if n_retrieve > 0:
-            knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
-
-            with open('templates/knowledge.en.j2', 'r') as file:
-                template = Template(file.read())
-
-            knowledge_str = template.render(
-                knowledge = knowledge,
-            )
-        else:
-            knowledge_str = ""
-
-        with open('templates/ask.en.j2', 'r') as file:
+        allowed_languages = ['en', 'sv']
+        assert language in allowed_languages, "Language must be any of " + ", ".join(allowed_languages)
+
+        with open(f'templates/ask.{language}.j2', 'r') as file:
             template = Template(file.read())
 
         prompt = template.render(
             today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
-            knowledge = knowledge_str,
-            plugin_prompts = prompts,
-            person = person
+            plugin_prompts = prompts
         )
 
         logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
```
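After this change, `Agent.ask()` no longer performs retrieval; it selects the prompt template by language and accepts a `device_id`. A minimal direct call might look like the sketch below (values are illustrative; it assumes it is run from `src/` with a valid `config.yml`, and that the PostgreSQL instance is reachable, since `Agent.__init__` still opens a `Database("rag")` connection):

```python
import yaml

from agent import Agent

with open("config.yml", "r", encoding="utf8") as file:
    config = yaml.safe_load(file)

agent = Agent(config)
answer = agent.ask(
    question="Vad är klockan?",   # illustrative question
    history=[],                   # prior chat turns, if any
    tools=None,                   # OpenAI tool definitions from the loaded plugins
    prompts=[],                   # plugin prompt() strings injected into the template
    language="sv",                # must be "en" or "sv" after this commit
    device_id="",                 # forwarded from the calling device, empty when unknown
)
print(answer)
```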
agent/src/backend.py

```diff
@@ -16,27 +16,34 @@ logging.basicConfig(
     level=logging.INFO,
 )
 
-with open("config.yml", "r", encoding='utf8') as file:
+with open("config.yml", "r", encoding="utf8") as file:
     config = yaml.safe_load(file)
 
 
 def load_plugins(config):
     tools = []
     prompts = []
+    plugin_instances = []
 
     for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
-        if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder:
+        if (
+            os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder))
+            and "__pycache__" not in plugin_folder
+        ):
             logging.info(f"Loading plugin: {plugin_folder}")
             module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
             module = importlib.import_module(module_path)
             plugin = module.Plugin(config)
+            plugin_instances.append(plugin)
             tools += plugin.tools()
             prompt = plugin.prompt()
             if prompt is not None:
                 prompts.append(prompt)
 
-    return tools, prompts
+    return tools, prompts, plugin_instances
 
-tools, prompts = load_plugins(config)
+
+tools, prompts, plugins = load_plugins(config)
 
 app = FastAPI()
 app.mount("/static", StaticFiles(directory="static"), name="static")
@@ -47,14 +54,18 @@ agent = Agent(config)
 class Query(BaseModel):
     query: str
 
 
 class Message(BaseModel):
     role: str
     content: str
 
 
 class Question(BaseModel):
     question: str
-    person: str = "User"
     history: list[Message] = None
+    language: str = "en"
+    device_id: str = ""
+    agent_id: str = ""
 
 
 @app.get("/ping")
@@ -65,6 +76,7 @@ def ping():
     """
     return {"status": "ok"}
 
 
 @app.post("/ask")
 def ask(q: Question):
     """
@@ -79,49 +91,9 @@ def ask(q: Question):
     if q.history:
         history = [q.model_dump() for q in q.history]
 
-    answer = agent.ask(
-        question=q.question,
-        person="Pierre",
-        history=history,
-        tools=tools,
-        prompts=prompts,
-        n_retrieve=0,
-        n_rerank=0
-    )
+    # Set device_id on all plugin instances
+    for plugin in plugins:
+        plugin.device_id = q.device_id
+
+    answer = agent.ask(question=q.question, history=history, tools=tools, prompts=prompts, language=q.language, device_id=q.device_id)
     return {"question": q.question, "answer": answer, "history": q.history}
-
-
-@app.post("/retrieve")
-def retrieve(q: Query):
-    """
-    # Retrieve
-    Retrieve knowledge related to a query.
-
-    ## Parameters
-    - **query**: the question.
-    """
-    result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5)
-    return {"query": q.query, "answer": result}
-
-
-@app.post("/learn")
-def retrieve(q: Query):
-    """
-    # Learn
-    Learn the agent something. E.g. *"John likes to play volleyball".*
-
-    ## Parameters
-    - **query**: the note.
-    """
-    answer = agent.insert(text=q.query)
-    return {"text": q.query}
-
-
-@app.get("/flush")
-def flush():
-    """
-    # Flush
-    Remove all knowledge from the agent.
-    """
-    agent.flush()
-    return {"message": "Knowledge flushed."}
```
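With the new `Question` model and the device-aware plugins, a client request to `/ask` would look roughly like this (field values are illustrative; `requests` was added to requirements.txt in this change set, and the server URL matches `config.default.yml`):

```python
import requests

# Illustrative request against a locally running backend (uvicorn backend:app).
response = requests.post(
    "http://127.0.0.1:8000/ask",
    json={
        "question": "Play some jazz in the kitchen",
        "history": [],                                    # optional list of {"role", "content"} messages
        "language": "en",                                 # "en" or "sv"
        "device_id": "d364e96a6664016ea4ac35f64a0d7498",  # picks the device-specific speaker, if configured
    },
    timeout=60,
)
print(response.json()["answer"])
```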
agent/src/config.default.yml

```diff
@@ -1,11 +1,6 @@
 server:
   url: "http://127.0.0.1:8000"
 llm:
-  use_local_chat: false
-  use_local_embedding: false
-  local_model_dir: ""
-  local_embedding_model: "intfloat/e5-large-v2"
-  local_rerank_model: "cross-encoder/stsb-distilroberta-base"
   openai_embedding_model: "text-embedding-3-small"
   openai_chat_model: "gpt-3.5-turbo-0125"
   temperature: 0.1
@@ -23,3 +18,9 @@ plugins:
     default_from_station: "Brunnsparken"
     default_to_station: "Centralstationen"
     delay: 0 # minutes
+  music:
+    default_speaker: "media_player.kok"
+    device_speakers:
+      "d364e96a6664016ea4ac35f64a0d7498": "media_player.media_player.theo"
+    spotify_client_id: ""
+    spotify_client_secret: ""
```
agent/src/llm.py

```diff
@@ -5,13 +5,6 @@ import numpy as np
 from openai import OpenAI
 from typing import Union, List
 
-try:
-    from llama_cpp.llama import Llama
-    from sentence_transformers import SentenceTransformer, CrossEncoder
-except ImportError as e:
-    print("Faield to import packages:")
-    print(e)
-
 
 class BaseChat:
     def __init__(self, config: dict) -> None:
@@ -173,19 +166,7 @@ class LLM:
         """
         self.config = config
 
-        if self.config["use_local_chat"]:
-            self.chat_client = LlamaChat(self.config)
-        else:
-            self.chat_client = OpenAIChat(self.config)
-
-        if self.config["use_local_embedding"]:
-            self.embedding_model = self.config["local_embedding_model"]
-            self.rerank_model = self.config["local_rerank_model"]
-            self.embedding_client = SentenceTransformer(self.embedding_model)
-        else:
-            self.embedding_model = self.config["openai"]["embedding_model"]
-            self.rerank_model = None
-            self.embedding_client = OpenAI(api_key=self.config["openai"]["api_key"])
+        self.chat_client = OpenAIChat(self.config)
 
     def query(
         self,
@@ -236,22 +217,3 @@ class LLM:
         logging.info(f"Sending request to LLM with {len(messages)} messages.")
         answer = self.chat_client.chat(messages=messages, tools=tools)
         return answer
-
-    def embedding(self, text: str):
-        if self.config["use_local_embedding"]:
-            embedding = self.embedding_client.encode(text)
-            return embedding # len = 1024
-        else:
-            llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")])
-            return llm_response.data[0].embedding
-
-    def rerank(self, text: str, corpus: List[str], n_rerank: int = 5):
-        if self.rerank_model is not None and len(corpus) > 0:
-            sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus]
-            logging.info("Reranking:")
-            logging.info(sentence_combinations)
-            similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations)
-            result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))]
-            return result[:n_rerank]
-        else:
-            return corpus
```
agent/src/plugins/base_plugin.py

```diff
@@ -7,6 +7,7 @@ class BasePlugin:
     def __init__(self, config: dict) -> None:
         self.config = config
         self.homeassistant = HomeAssistant(config)
+        self.device_id = ""
 
     def prompt(self) -> str | None:
         return
```
agent/src/plugins/music/plugin.py

```diff
@@ -25,6 +25,22 @@ class Plugin(BasePlugin):
             )
         )
 
+    def _get_speaker_for_device(self, device_id: str) -> str:
+        """
+        Get the appropriate speaker entity_id based on device_id.
+        Falls back to default_speaker if device_id is not found or empty.
+        """
+        device_speakers = self.config["plugins"]["music"].get("device_speakers", {})
+
+        if device_id and device_id in device_speakers:
+            speaker = device_speakers[device_id]
+            logging.info(f"Using device-specific speaker for {device_id}: {speaker}")
+            return speaker
+
+        default_speaker = self.config["plugins"]["music"]["default_speaker"]
+        logging.info(f"Using default speaker: {default_speaker}")
+        return default_speaker
+
     def _search(self, query: str, limit: int = 10):
         _result = self.spotify.search(query, limit=limit)
         result = []
@@ -44,9 +60,10 @@ class Plugin(BasePlugin):
         Play music using a search query.
         """
         track = self._search(input.query, limit=1)[0]
-        logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
+        speaker = self._get_speaker_for_device(self.device_id)
+        logging.info(f"Playing {track['name']} by {', '.join(track['artists'])} on {speaker}")
         payload = {
-            "entity_id": self.config["plugins"]["music"]["default_speaker"],
+            "entity_id": speaker,
            "media_content_id": track["uri"],
            "media_content_type": "music",
            "enqueue": "play",
@@ -58,7 +75,8 @@ class Plugin(BasePlugin):
         """
         Stop playback of music.
         """
+        speaker = self._get_speaker_for_device(self.device_id)
         self.homeassistant.call_api(
-            f"services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
+            f"services/media_player/media_pause", payload={"entity_id": speaker}
         )
         return json.dumps({"status": "success", "message": f"Music paused."})
```
requirements.txt

```diff
@@ -1,3 +1,4 @@
+requests
 fastapi
 uvicorn[standard]
 openai>=1.11.1
```
agent/src/templates/ask.en.j2

```diff
@@ -9,4 +9,4 @@ Today is {{ today }}.
 - {{ prompt }}
 {% endfor %}
 
-Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a sweedish question.
+Answer the questions to the best of your knowledge.
```
agent/src/templates/ask.sv.j2 (new file, 12 lines; the Swedish prompt template, kept verbatim)

```
Du är en hjälpsam assistent. Dina svar kommer att konverteras till ljud/tal, så svara kort och koncist och undvik för många specialtecken.
Idag är det {{ today }}.

{% if knowledge is not none %}
{{ knowledge }}
{% endif %}

{% for prompt in plugin_prompts %}
- {{ prompt }}
{% endfor %}

Svara på frågorna på svenska efter bästa förmåga.
```