Compare commits
No commits in common. "5dec83354454900fd6ab5b25f0db385acdb10fd2" and "5ecd5ded3eb26276ff6cdbafc786019c87397400" have entirely different histories.
5dec833544
...
5ecd5ded3e
1
agent/.gitignore
vendored
1
agent/.gitignore
vendored
@ -1 +0,0 @@
|
|||||||
config.yml
|
|
||||||
127
agent/CLAUDE.md
127
agent/CLAUDE.md
@ -1,127 +0,0 @@
|
|||||||
# CLAUDE.md
|
|
||||||
|
|
||||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
|
||||||
|
|
||||||
## Project Overview
|
|
||||||
|
|
||||||
This is a voice assistant agent that integrates with Home Assistant and provides tool-calling capabilities through a plugin architecture. The system uses OpenAI's API (or compatible endpoints) for LLM interactions and supports multiple languages (English and Swedish).
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
### Core Components
|
|
||||||
|
|
||||||
- **[src/backend.py](src/backend.py)**: FastAPI server that loads plugins and exposes the `/ask` endpoint
|
|
||||||
- **[src/agent.py](src/agent.py)**: Main Agent class that orchestrates LLM queries with plugin tools
|
|
||||||
- **[src/llm.py](src/llm.py)**: LLM abstraction layer supporting OpenAI, LMStudio, and Llama models
|
|
||||||
- `OpenAIChat`: Handles OpenAI API calls with function calling support (parallel execution)
|
|
||||||
- `LMStudioChat`: Local LLM via LMStudio
|
|
||||||
- `LlamaChat`: Direct llama.cpp integration
|
|
||||||
- **[src/db.py](src/db.py)**: PostgreSQL database with pgvector extension for embeddings
|
|
||||||
- **[src/frontend.py](src/frontend.py)**: Gradio web interface for chat
|
|
||||||
|
|
||||||
### Plugin System
|
|
||||||
|
|
||||||
The plugin architecture is the heart of this system. All plugins inherit from [src/plugins/base_plugin.py](src/plugins/base_plugin.py) and must:
|
|
||||||
|
|
||||||
1. Define methods starting with `tool_` to expose functions to the LLM
|
|
||||||
2. Use Pydantic models for type-safe function parameters
|
|
||||||
3. Return JSON-formatted strings from tool functions
|
|
||||||
4. Optionally provide a `prompt()` method to inject context into the system prompt
|
|
||||||
|
|
||||||
**Plugin discovery**: The backend automatically loads all folders in `src/plugins/` that contain a `plugin.py` file with a `Plugin` class.
|
|
||||||
|
|
||||||
**Base plugin features**:
|
|
||||||
- Automatically converts `tool_*` methods into OpenAI function calling format
|
|
||||||
- Introspects Pydantic models to generate JSON schemas
|
|
||||||
- Provides access to Home Assistant helper class
|
|
||||||
|
|
||||||
**Home Assistant integration**: Plugins can use `self.homeassistant` (from [src/plugins/homeassistant.py](src/plugins/homeassistant.py)) to interact with Home Assistant via REST API and WebSocket.
|
|
||||||
|
|
||||||
### Template System
|
|
||||||
|
|
||||||
System prompts are managed through Jinja2 templates in `src/templates/`:
|
|
||||||
- Templates are language-specific (e.g., `ask.en.j2`, `ask.sv.j2`)
|
|
||||||
- Plugins can inject their own prompts via the `prompt()` method
|
|
||||||
- Templates receive: current datetime, plugin prompts, and optional knowledge
|
|
||||||
|
|
||||||
## Development Commands
|
|
||||||
|
|
||||||
### Database Setup
|
|
||||||
|
|
||||||
Start PostgreSQL with pgvector:
|
|
||||||
```bash
|
|
||||||
docker-compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
Create database and enable pgvector extension:
|
|
||||||
```sql
|
|
||||||
CREATE DATABASE rag;
|
|
||||||
\c rag
|
|
||||||
CREATE EXTENSION vector;
|
|
||||||
```
|
|
||||||
|
|
||||||
Run schema initialization:
|
|
||||||
```bash
|
|
||||||
psql -U postgres -h localhost -p 5433 -d rag -f db.sql
|
|
||||||
```
|
|
||||||
|
|
||||||
### Running the Application
|
|
||||||
|
|
||||||
Backend (FastAPI server):
|
|
||||||
```bash
|
|
||||||
cd src
|
|
||||||
python -m uvicorn backend:app --host 0.0.0.0 --reload --reload-include config.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
Frontend (Gradio UI):
|
|
||||||
```bash
|
|
||||||
cd src
|
|
||||||
python frontend.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Configuration
|
|
||||||
|
|
||||||
Copy `src/config.default.yml` to `src/config.yml` and configure:
|
|
||||||
- OpenAI API settings (or compatible base_url)
|
|
||||||
- Home Assistant URL and token
|
|
||||||
- Plugin-specific settings (Spotify, Västtrafik, etc.)
|
|
||||||
|
|
||||||
## Working with Plugins
|
|
||||||
|
|
||||||
### Creating a New Plugin
|
|
||||||
|
|
||||||
1. Create folder: `src/plugins/myplugin/`
|
|
||||||
2. Create `plugin.py` with a `Plugin` class inheriting from `BasePlugin`
|
|
||||||
3. Define tool functions with `tool_` prefix
|
|
||||||
4. Use Pydantic models for parameters
|
|
||||||
5. Add plugin config to `config.yml` under `plugins.myplugin`
|
|
||||||
|
|
||||||
Example structure:
|
|
||||||
```python
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from ..base_plugin import BasePlugin
|
|
||||||
|
|
||||||
class MyInput(BaseModel):
|
|
||||||
param: str = Field(..., description="Parameter description")
|
|
||||||
|
|
||||||
class Plugin(BasePlugin):
|
|
||||||
def tool_my_function(self, input: MyInput):
|
|
||||||
"""Function description for LLM"""
|
|
||||||
# Implementation
|
|
||||||
return json.dumps({"status": "success"})
|
|
||||||
|
|
||||||
def prompt(self) -> str | None:
|
|
||||||
return "Additional context for the LLM"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Testing Function Calling
|
|
||||||
|
|
||||||
The LLM class in [src/llm.py](src/llm.py) handles parallel function execution automatically. When the LLM responds with tool calls, they are executed concurrently using `ThreadPoolExecutor`.
|
|
||||||
|
|
||||||
## Important Notes
|
|
||||||
|
|
||||||
- **Working directory**: The backend expects to be run from the `src/` directory due to relative imports and file paths
|
|
||||||
- **Database connection**: Hardcoded to localhost:5433 in [src/db.py](src/db.py:17-23)
|
|
||||||
- **Function calling**: Currently only fully supported with `OpenAIChat` backend
|
|
||||||
- **Response format**: Tool functions must return JSON strings for proper LLM interpretation
|
|
||||||
- **Logging**: All modules use Python's logging module at INFO level
|
|
||||||
@ -8,37 +8,145 @@ from db import Database
|
|||||||
from llm import LLM
|
from llm import LLM
|
||||||
|
|
||||||
|
|
||||||
|
def relative_to_actual_dates(text: str):
|
||||||
|
from dateparser.search import search_dates
|
||||||
|
|
||||||
|
matches = search_dates(text, languages=["en", "se"])
|
||||||
|
|
||||||
|
if matches:
|
||||||
|
for match in matches:
|
||||||
|
text = text.replace(match[0], match[1].strftime("%Y-%m-%d"))
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, config: dict) -> None:
|
def __init__(self, config: dict) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.db = Database("rag")
|
self.db = Database("rag")
|
||||||
self.llm = LLM(config=self.config["llm"])
|
self.llm = LLM(config=self.config["llm"])
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
"""
|
||||||
|
Flushes the agents knowledge.
|
||||||
|
"""
|
||||||
|
logging.info("Truncating database")
|
||||||
|
self.db.cur.execute("TRUNCATE TABLE items")
|
||||||
|
self.db.conn.commit()
|
||||||
|
|
||||||
|
def insert(self, text):
|
||||||
|
"""
|
||||||
|
Append knowledge.
|
||||||
|
"""
|
||||||
|
logging.info(f"Inserting item into embedding table ({text}).")
|
||||||
|
|
||||||
|
#logging.info(f"\tReplacing relative dates with actual dates.")
|
||||||
|
#text = relative_to_actual_dates(text)
|
||||||
|
#logging.info(f"\tDone. text: {text}")
|
||||||
|
|
||||||
|
vector = self.llm.embedding(text)
|
||||||
|
vector_padded = vector + ([0]*(2048-len(vector)))
|
||||||
|
|
||||||
|
self.db.cur.execute(
|
||||||
|
"INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
|
||||||
|
(text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
|
||||||
|
)
|
||||||
|
self.db.conn.commit()
|
||||||
|
|
||||||
|
def keywords(self, query: str, n_keywords: int = 5):
|
||||||
|
"""
|
||||||
|
Suggest keywords related to a query.
|
||||||
|
"""
|
||||||
|
sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
|
||||||
|
prompt = (
|
||||||
|
f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
|
||||||
|
+ f"Do not add any context or text other than the json list of {n_keywords} words."
|
||||||
|
)
|
||||||
|
response = self.llm.query(prompt, system_msg=sys_msg)
|
||||||
|
keywords = json.loads(response)
|
||||||
|
return keywords
|
||||||
|
|
||||||
|
def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
|
||||||
|
"""
|
||||||
|
Retrieve relevant knowledge.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- n_retrieve (int): How many notes to retrieve.
|
||||||
|
- n_rerank (int): How many notes to keep after reranking.
|
||||||
|
"""
|
||||||
|
logging.info(f"Retrieving knowledge from database using the query: '{query}'")
|
||||||
|
|
||||||
|
logging.debug(f"Using embedding model on '{query}'.")
|
||||||
|
vector_keywords = self.llm.embedding(query)
|
||||||
|
logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
|
||||||
|
|
||||||
|
logging.debug("Querying database")
|
||||||
|
vector_padded = vector_keywords + ([0]*(2048-len(vector_keywords)))
|
||||||
|
self.db.cur.execute(
|
||||||
|
f"SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT {n_retrieve}",
|
||||||
|
(vector_padded,),
|
||||||
|
)
|
||||||
|
db_response = self.db.cur.fetchall()
|
||||||
|
knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
|
||||||
|
logging.debug("Database returned: ")
|
||||||
|
logging.debug(knowledge_arr)
|
||||||
|
|
||||||
|
if n_retrieve > n_rerank:
|
||||||
|
logging.info("Reranking the results")
|
||||||
|
reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
|
||||||
|
reranked_knowledge_arr = sorted(
|
||||||
|
knowledge_arr,
|
||||||
|
key=lambda x: (
|
||||||
|
reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
|
||||||
|
logging.debug("Reranked results: ")
|
||||||
|
logging.debug(reranked_knowledge_arr)
|
||||||
|
return reranked_knowledge_arr
|
||||||
|
else:
|
||||||
|
return knowledge_arr
|
||||||
|
|
||||||
def ask(
|
def ask(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
|
person: str,
|
||||||
history: list = None,
|
history: list = None,
|
||||||
|
n_retrieve: int = 5,
|
||||||
|
n_rerank=3,
|
||||||
tools: Union[None, list] = None,
|
tools: Union[None, list] = None,
|
||||||
prompts: Union[None, list] = None,
|
prompts: Union[None, list] = None,
|
||||||
language: str = "en",
|
|
||||||
device_id: str = ""
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Ask the agent a question.
|
Ask the agent a question.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- question (str): The question to ask
|
- question (str): The question to ask
|
||||||
|
- person (str): Who asks the question?
|
||||||
|
- n_retrieve (int): How many notes to retrieve.
|
||||||
|
- n_rerank (int): How many notes to keep after reranking.
|
||||||
"""
|
"""
|
||||||
logging.info(f"Answering question: {question}")
|
logging.info(f"Answering question: {question}")
|
||||||
|
|
||||||
allowed_languages = ['en', 'sv']
|
if n_retrieve > 0:
|
||||||
assert language in allowed_languages, "Language must be any of " + ", ".join(allowed_languages)
|
knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
|
||||||
with open(f'templates/ask.{language}.j2', 'r') as file:
|
|
||||||
|
with open('templates/knowledge.en.j2', 'r') as file:
|
||||||
|
template = Template(file.read())
|
||||||
|
|
||||||
|
knowledge_str = template.render(
|
||||||
|
knowledge = knowledge,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
knowledge_str = ""
|
||||||
|
|
||||||
|
with open('templates/ask.en.j2', 'r') as file:
|
||||||
template = Template(file.read())
|
template = Template(file.read())
|
||||||
|
|
||||||
prompt = template.render(
|
prompt = template.render(
|
||||||
today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
|
today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||||
plugin_prompts = prompts
|
knowledge = knowledge_str,
|
||||||
|
plugin_prompts = prompts,
|
||||||
|
person = person
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
|
logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
|
||||||
|
|||||||
@ -16,34 +16,27 @@ logging.basicConfig(
|
|||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open("config.yml", "r", encoding="utf8") as file:
|
with open("config.yml", "r", encoding='utf8') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
|
||||||
def load_plugins(config):
|
def load_plugins(config):
|
||||||
tools = []
|
tools = []
|
||||||
prompts = []
|
prompts = []
|
||||||
plugin_instances = []
|
|
||||||
|
|
||||||
for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
|
for plugin_folder in os.listdir("./" + PLUGINS_FOLDER):
|
||||||
if (
|
if os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder)) and "__pycache__" not in plugin_folder:
|
||||||
os.path.isdir(os.path.join("./" + PLUGINS_FOLDER, plugin_folder))
|
|
||||||
and "__pycache__" not in plugin_folder
|
|
||||||
):
|
|
||||||
logging.info(f"Loading plugin: {plugin_folder}")
|
logging.info(f"Loading plugin: {plugin_folder}")
|
||||||
module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
|
module_path = f"{PLUGINS_FOLDER}.{plugin_folder}.plugin"
|
||||||
module = importlib.import_module(module_path)
|
module = importlib.import_module(module_path)
|
||||||
plugin = module.Plugin(config)
|
plugin = module.Plugin(config)
|
||||||
plugin_instances.append(plugin)
|
|
||||||
tools += plugin.tools()
|
tools += plugin.tools()
|
||||||
prompt = plugin.prompt()
|
prompt = plugin.prompt()
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
|
|
||||||
return tools, prompts, plugin_instances
|
return tools, prompts
|
||||||
|
|
||||||
|
tools, prompts = load_plugins(config)
|
||||||
tools, prompts, plugins = load_plugins(config)
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||||
@ -54,18 +47,14 @@ agent = Agent(config)
|
|||||||
class Query(BaseModel):
|
class Query(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class Question(BaseModel):
|
class Question(BaseModel):
|
||||||
question: str
|
question: str
|
||||||
|
person: str = "User"
|
||||||
history: list[Message] = None
|
history: list[Message] = None
|
||||||
language: str = "en"
|
|
||||||
device_id: str = ""
|
|
||||||
agent_id: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ping")
|
@app.get("/ping")
|
||||||
@ -76,7 +65,6 @@ def ping():
|
|||||||
"""
|
"""
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/ask")
|
@app.post("/ask")
|
||||||
def ask(q: Question):
|
def ask(q: Question):
|
||||||
"""
|
"""
|
||||||
@ -91,9 +79,49 @@ def ask(q: Question):
|
|||||||
if q.history:
|
if q.history:
|
||||||
history = [q.model_dump() for q in q.history]
|
history = [q.model_dump() for q in q.history]
|
||||||
|
|
||||||
# Set device_id on all plugin instances
|
answer = agent.ask(
|
||||||
for plugin in plugins:
|
question=q.question,
|
||||||
plugin.device_id = q.device_id
|
person="Pierre",
|
||||||
|
history=history,
|
||||||
answer = agent.ask(question=q.question, history=history, tools=tools, prompts=prompts, language=q.language, device_id=q.device_id)
|
tools=tools,
|
||||||
|
prompts=prompts,
|
||||||
|
n_retrieve=0,
|
||||||
|
n_rerank=0
|
||||||
|
)
|
||||||
return {"question": q.question, "answer": answer, "history": q.history}
|
return {"question": q.question, "answer": answer, "history": q.history}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/retrieve")
|
||||||
|
def retrieve(q: Query):
|
||||||
|
"""
|
||||||
|
# Retrieve
|
||||||
|
Retrieve knowledge related to a query.
|
||||||
|
|
||||||
|
## Parameters
|
||||||
|
- **query**: the question.
|
||||||
|
"""
|
||||||
|
result = agent.retrieve(query=q.query, n_retrieve=10, n_rerank=5)
|
||||||
|
return {"query": q.query, "answer": result}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/learn")
|
||||||
|
def retrieve(q: Query):
|
||||||
|
"""
|
||||||
|
# Learn
|
||||||
|
Learn the agent something. E.g. *"John likes to play volleyball".*
|
||||||
|
|
||||||
|
## Parameters
|
||||||
|
- **query**: the note.
|
||||||
|
"""
|
||||||
|
answer = agent.insert(text=q.query)
|
||||||
|
return {"text": q.query}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/flush")
|
||||||
|
def flush():
|
||||||
|
"""
|
||||||
|
# Flush
|
||||||
|
Remove all knowledge from the agent.
|
||||||
|
"""
|
||||||
|
agent.flush()
|
||||||
|
return {"message": "Knowledge flushed."}
|
||||||
|
|||||||
@ -1,6 +1,11 @@
|
|||||||
server:
|
server:
|
||||||
url: "http://127.0.0.1:8000"
|
url: "http://127.0.0.1:8000"
|
||||||
llm:
|
llm:
|
||||||
|
use_local_chat: false
|
||||||
|
use_local_embedding: false
|
||||||
|
local_model_dir: ""
|
||||||
|
local_embedding_model: "intfloat/e5-large-v2"
|
||||||
|
local_rerank_model: "cross-encoder/stsb-distilroberta-base"
|
||||||
openai_embedding_model: "text-embedding-3-small"
|
openai_embedding_model: "text-embedding-3-small"
|
||||||
openai_chat_model: "gpt-3.5-turbo-0125"
|
openai_chat_model: "gpt-3.5-turbo-0125"
|
||||||
temperature: 0.1
|
temperature: 0.1
|
||||||
@ -18,9 +23,3 @@ plugins:
|
|||||||
default_from_station: "Brunnsparken"
|
default_from_station: "Brunnsparken"
|
||||||
default_to_station: "Centralstationen"
|
default_to_station: "Centralstationen"
|
||||||
delay: 0 # minutes
|
delay: 0 # minutes
|
||||||
music:
|
|
||||||
default_speaker: "media_player.kok"
|
|
||||||
device_speakers:
|
|
||||||
"d364e96a6664016ea4ac35f64a0d7498": "media_player.media_player.theo"
|
|
||||||
spotify_client_id: ""
|
|
||||||
spotify_client_secret: ""
|
|
||||||
@ -5,6 +5,13 @@ import numpy as np
|
|||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
|
|
||||||
|
try:
|
||||||
|
from llama_cpp.llama import Llama
|
||||||
|
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||||
|
except ImportError as e:
|
||||||
|
print("Faield to import packages:")
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
class BaseChat:
|
class BaseChat:
|
||||||
def __init__(self, config: dict) -> None:
|
def __init__(self, config: dict) -> None:
|
||||||
@ -166,7 +173,19 @@ class LLM:
|
|||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.chat_client = OpenAIChat(self.config)
|
if self.config["use_local_chat"]:
|
||||||
|
self.chat_client = LlamaChat(self.config)
|
||||||
|
else:
|
||||||
|
self.chat_client = OpenAIChat(self.config)
|
||||||
|
|
||||||
|
if self.config["use_local_embedding"]:
|
||||||
|
self.embedding_model = self.config["local_embedding_model"]
|
||||||
|
self.rerank_model = self.config["local_rerank_model"]
|
||||||
|
self.embedding_client = SentenceTransformer(self.embedding_model)
|
||||||
|
else:
|
||||||
|
self.embedding_model = self.config["openai"]["embedding_model"]
|
||||||
|
self.rerank_model = None
|
||||||
|
self.embedding_client = OpenAI(api_key=self.config["openai"]["api_key"])
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
@ -217,3 +236,22 @@ class LLM:
|
|||||||
logging.info(f"Sending request to LLM with {len(messages)} messages.")
|
logging.info(f"Sending request to LLM with {len(messages)} messages.")
|
||||||
answer = self.chat_client.chat(messages=messages, tools=tools)
|
answer = self.chat_client.chat(messages=messages, tools=tools)
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
def embedding(self, text: str):
|
||||||
|
if self.config["use_local_embedding"]:
|
||||||
|
embedding = self.embedding_client.encode(text)
|
||||||
|
return embedding # len = 1024
|
||||||
|
else:
|
||||||
|
llm_response = self.embedding_client.embeddings.create(model=self.embedding_model, input=[text.replace("\n", " ")])
|
||||||
|
return llm_response.data[0].embedding
|
||||||
|
|
||||||
|
def rerank(self, text: str, corpus: List[str], n_rerank: int = 5):
|
||||||
|
if self.rerank_model is not None and len(corpus) > 0:
|
||||||
|
sentence_combinations = [[text, corpus_sentence] for corpus_sentence in corpus]
|
||||||
|
logging.info("Reranking:")
|
||||||
|
logging.info(sentence_combinations)
|
||||||
|
similarity_scores = CrossEncoder(self.rerank_model, max_length=1024).predict(sentence_combinations)
|
||||||
|
result = [corpus[idx] for idx in reversed(np.argsort(similarity_scores))]
|
||||||
|
return result[:n_rerank]
|
||||||
|
else:
|
||||||
|
return corpus
|
||||||
|
|||||||
@ -7,7 +7,6 @@ class BasePlugin:
|
|||||||
def __init__(self, config: dict) -> None:
|
def __init__(self, config: dict) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.homeassistant = HomeAssistant(config)
|
self.homeassistant = HomeAssistant(config)
|
||||||
self.device_id = ""
|
|
||||||
|
|
||||||
def prompt(self) -> str | None:
|
def prompt(self) -> str | None:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -25,22 +25,6 @@ class Plugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_speaker_for_device(self, device_id: str) -> str:
|
|
||||||
"""
|
|
||||||
Get the appropriate speaker entity_id based on device_id.
|
|
||||||
Falls back to default_speaker if device_id is not found or empty.
|
|
||||||
"""
|
|
||||||
device_speakers = self.config["plugins"]["music"].get("device_speakers", {})
|
|
||||||
|
|
||||||
if device_id and device_id in device_speakers:
|
|
||||||
speaker = device_speakers[device_id]
|
|
||||||
logging.info(f"Using device-specific speaker for {device_id}: {speaker}")
|
|
||||||
return speaker
|
|
||||||
|
|
||||||
default_speaker = self.config["plugins"]["music"]["default_speaker"]
|
|
||||||
logging.info(f"Using default speaker: {default_speaker}")
|
|
||||||
return default_speaker
|
|
||||||
|
|
||||||
def _search(self, query: str, limit: int = 10):
|
def _search(self, query: str, limit: int = 10):
|
||||||
_result = self.spotify.search(query, limit=limit)
|
_result = self.spotify.search(query, limit=limit)
|
||||||
result = []
|
result = []
|
||||||
@ -60,10 +44,9 @@ class Plugin(BasePlugin):
|
|||||||
Play music using a search query.
|
Play music using a search query.
|
||||||
"""
|
"""
|
||||||
track = self._search(input.query, limit=1)[0]
|
track = self._search(input.query, limit=1)[0]
|
||||||
speaker = self._get_speaker_for_device(self.device_id)
|
logging.info(f"Playing {track['name']} by {', '.join(track['artists'])}")
|
||||||
logging.info(f"Playing {track['name']} by {', '.join(track['artists'])} on {speaker}")
|
|
||||||
payload = {
|
payload = {
|
||||||
"entity_id": speaker,
|
"entity_id": self.config["plugins"]["music"]["default_speaker"],
|
||||||
"media_content_id": track["uri"],
|
"media_content_id": track["uri"],
|
||||||
"media_content_type": "music",
|
"media_content_type": "music",
|
||||||
"enqueue": "play",
|
"enqueue": "play",
|
||||||
@ -75,8 +58,7 @@ class Plugin(BasePlugin):
|
|||||||
"""
|
"""
|
||||||
Stop playback of music.
|
Stop playback of music.
|
||||||
"""
|
"""
|
||||||
speaker = self._get_speaker_for_device(self.device_id)
|
|
||||||
self.homeassistant.call_api(
|
self.homeassistant.call_api(
|
||||||
f"services/media_player/media_pause", payload={"entity_id": speaker}
|
f"services/media_player/media_pause", payload={"entity_id": self.config["plugins"]["music"]["default_speaker"]}
|
||||||
)
|
)
|
||||||
return json.dumps({"status": "success", "message": f"Music paused."})
|
return json.dumps({"status": "success", "message": f"Music paused."})
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
requests
|
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
openai>=1.11.1
|
openai>=1.11.1
|
||||||
|
|||||||
@ -9,4 +9,4 @@ Today is {{ today }}.
|
|||||||
- {{ prompt }}
|
- {{ prompt }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
Answer the questions to the best of your knowledge.
|
Answer the questions from {{ person }} to the best of your knowledge. Reply in Swedish when you get a sweedish question.
|
||||||
@ -1,12 +0,0 @@
|
|||||||
Du är en hjälpsam assistent. Dina svar kommer att konverteras till ljud/tal, så svara kort och koncist och undvik för många specialtecken.
|
|
||||||
Idag är det {{ today }}.
|
|
||||||
|
|
||||||
{% if knowledge is not none %}
|
|
||||||
{{ knowledge }}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
{% for prompt in plugin_prompts %}
|
|
||||||
- {{ prompt }}
|
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
Svara på frågorna på svenska efter bästa förmåga.
|
|
||||||
Loading…
x
Reference in New Issue
Block a user