import datetime
import json
import logging
from typing import Union

from jinja2 import Template

from db import Database
from llm import LLM


def relative_to_actual_dates(text: str) -> str:
    """
    Replace relative date expressions (e.g. "yesterday") with absolute
    YYYY-MM-DD dates.
    """
    from dateparser.search import search_dates

    # "sv" is the ISO 639-1 code for Swedish; "se" (Northern Sami) is not
    # supported by dateparser.
    matches = search_dates(text, languages=["en", "sv"])
    if matches:
        for matched_text, parsed_date in matches:
            text = text.replace(matched_text, parsed_date.strftime("%Y-%m-%d"))
    return text
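
# Example (hypothetical output; the substituted date depends on the day this
# runs):
#   relative_to_actual_dates("We met yesterday")  ->  "We met 2024-05-01"
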
class Agent:
def __init__(self, config: dict) -> None:
self.config = config
self.db = Database("rag")
self.llm = LLM(config=self.config["llm"])
def flush(self):
"""
        Flushes the agent's knowledge by truncating the items table.
"""
logging.info("Truncating database")
self.db.cur.execute("TRUNCATE TABLE items")
self.db.conn.commit()
    def insert(self, text: str) -> None:
"""
Append knowledge.
"""
logging.info(f"Inserting item into embedding table ({text}).")
#logging.info(f"\tReplacing relative dates with actual dates.")
#text = relative_to_actual_dates(text)
#logging.info(f"\tDone. text: {text}")
        vector = self.llm.embedding(text)
        # Zero-pad the embedding to the fixed 2048-dimension width of the
        # pgvector column, so models with shorter vectors still fit.
        vector_padded = vector + [0] * (2048 - len(vector))
self.db.cur.execute(
"INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
(text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
)
self.db.conn.commit()
    def keywords(self, query: str, n_keywords: int = 5) -> list:
"""
Suggest keywords related to a query.
"""
        sys_msg = (
            "You are a helpful assistant. Make your response as concise as "
            "possible, with no introduction or background at the start."
        )
        prompt = (
            f"Provide a JSON list of {n_keywords} words or nouns that represent "
            f"the following prompt in a vector database query: {query}. "
            f"Do not add any text other than the JSON list of {n_keywords} words."
        )
response = self.llm.query(prompt, system_msg=sys_msg)
keywords = json.loads(response)
return keywords
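    # Example (illustrative; the exact keywords depend on the LLM):
    #   agent.keywords("Where did I park the car?")
    #   -> ["car", "parking", "location", "street", "vehicle"]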
    def retrieve(self, query: str, n_retrieve: int = 10, n_rerank: int = 5) -> list:
"""
Retrieve relevant knowledge.
        Parameters:
        - query (str): The query to embed and search with.
        - n_retrieve (int): How many notes to retrieve.
        - n_rerank (int): How many notes to keep after reranking.
"""
logging.info(f"Retrieving knowledge from database using the query: '{query}'")
logging.debug(f"Using embedding model on '{query}'.")
vector_keywords = self.llm.embedding(query)
logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
logging.debug("Querying database")
        # Zero-pad to the 2048-dimension pgvector column, as in insert().
        vector_padded = vector_keywords + [0] * (2048 - len(vector_keywords))
        # "<->" is pgvector's L2-distance operator, so the nearest embeddings
        # come first. LIMIT is parameterized too, keeping the SQL free of
        # interpolated values.
        self.db.cur.execute(
            "SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT %s",
            (vector_padded, n_retrieve),
        )
db_response = self.db.cur.fetchall()
knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
logging.debug("Database returned: ")
logging.debug(knowledge_arr)
if n_retrieve > n_rerank:
logging.info("Reranking the results")
reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
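            # Sort the original dicts into the reranker's order so date_added
            # metadata survives; texts the reranker dropped sort last and are
            # removed by the slice below.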
reranked_knowledge_arr = sorted(
knowledge_arr,
key=lambda x: (
reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
),
)
reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
logging.debug("Reranked results: ")
logging.debug(reranked_knowledge_arr)
return reranked_knowledge_arr
else:
return knowledge_arr
    def ask(
        self,
        question: str,
        person: str,
        history: Union[None, list] = None,
        n_retrieve: int = 5,
        n_rerank: int = 3,
        tools: Union[None, list] = None,
        prompts: Union[None, list] = None,
    ):
"""
Ask the agent a question.
Parameters:
- question (str): The question to ask
- person (str): Who asks the question?
- n_retrieve (int): How many notes to retrieve.
- n_rerank (int): How many notes to keep after reranking.
"""
logging.info(f"Answering question: {question}")
if n_retrieve > 0:
knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
            with open("templates/knowledge.en.j2", "r") as file:
                template = Template(file.read())
            knowledge_str = template.render(knowledge=knowledge)
else:
knowledge_str = ""
        with open("templates/ask.en.j2", "r") as file:
            template = Template(file.read())
        prompt = template.render(
            today=datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
            knowledge=knowledge_str,
            plugin_prompts=prompts,
            person=person,
        )
logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
answer = self.llm.query(
question,
system_msg=prompt,
tools=tools,
history=history
)
logging.info(f"Got answer from LLM: '{answer}'.")
return answer
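

if __name__ == "__main__":
    # Minimal usage sketch. The config passed to Agent is a placeholder: the
    # real keys depend on how LLM is configured in llm.py, and Database("rag")
    # assumes a running Postgres instance with the pgvector extension.
    logging.basicConfig(level=logging.INFO)
    agent = Agent(config={"llm": {}})  # hypothetical config
    agent.insert("The car is parked on Main Street.")
    print(agent.ask("Where is the car parked?", person="Alice"))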