162 lines
5.4 KiB
Python
162 lines
5.4 KiB
Python
import datetime
|
|
import json
|
|
from jinja2 import Template
|
|
import logging
|
|
from typing import Union
|
|
|
|
from db import Database
|
|
from llm import LLM
|
|
|
|
|
|
def relative_to_actual_dates(text: str):
    """
    Replace date expressions in *text* with absolute ISO dates.

    Every substring that dateparser recognizes as a date (including
    relative phrases such as "yesterday") is substituted with its
    resolved ``YYYY-MM-DD`` form. The input is returned unchanged when
    no date expressions are found.
    """
    # Imported lazily so the (third-party) dependency is only required
    # when this helper is actually used.
    from dateparser.search import search_dates

    # NOTE(review): "se" is Northern Sami in ISO 639-1 — if Swedish was
    # intended the code should probably be "sv"; confirm with the author.
    found = search_dates(text, languages=["en", "se"])
    if not found:
        return text

    for phrase, moment in found:
        # str.replace substitutes every occurrence of the phrase.
        text = text.replace(phrase, moment.strftime("%Y-%m-%d"))
    return text
|
|
|
|
|
|
class Agent:
    """
    Retrieval-augmented agent backed by a vector database and an LLM.

    Knowledge items are embedded, zero-padded to a fixed column width and
    stored in the ``items`` table; queries embed the question, fetch the
    nearest neighbours and optionally rerank them before answering.
    """

    # Fixed width of the stored embedding column. Embedding vectors shorter
    # than this are zero-padded before insert and before similarity search,
    # so both sides of the <-> comparison have the same dimensionality.
    EMBEDDING_DIM = 2048

    def __init__(self, config: dict) -> None:
        """
        Create an agent.

        Parameters:
        - config (dict): Full configuration; the "llm" section is handed
          to the LLM wrapper.
        """
        self.config = config
        self.db = Database("rag")
        self.llm = LLM(config=self.config["llm"])

    @staticmethod
    def _pad_vector(vector: list) -> list:
        """Zero-pad an embedding vector up to EMBEDDING_DIM entries."""
        return vector + [0] * (Agent.EMBEDDING_DIM - len(vector))

    def flush(self):
        """
        Flushes the agent's knowledge by truncating the items table.
        """
        logging.info("Truncating database")
        self.db.cur.execute("TRUNCATE TABLE items")
        self.db.conn.commit()

    def insert(self, text):
        """
        Append knowledge.

        Embeds *text*, zero-pads the vector and stores it together with
        today's date (``YYYY-MM-DD``).
        """
        logging.info(f"Inserting item into embedding table ({text}).")

        #logging.info(f"\tReplacing relative dates with actual dates.")
        #text = relative_to_actual_dates(text)
        #logging.info(f"\tDone. text: {text}")

        vector_padded = self._pad_vector(self.llm.embedding(text))

        self.db.cur.execute(
            "INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
            (text, vector_padded, datetime.datetime.now().strftime("%Y-%m-%d")),
        )
        self.db.conn.commit()

    def keywords(self, query: str, n_keywords: int = 5):
        """
        Suggest keywords related to a query.

        Asks the LLM for a JSON list of *n_keywords* words and parses it.

        Raises:
        - json.JSONDecodeError: if the LLM response is not valid JSON.
        """
        sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
        prompt = (
            f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
            + f"Do not add any context or text other than the json list of {n_keywords} words."
        )
        response = self.llm.query(prompt, system_msg=sys_msg)
        keywords = json.loads(response)
        return keywords

    def retrieve(self, query: str, n_retrieve=10, n_rerank=5):
        """
        Retrieve relevant knowledge.

        Parameters:
        - n_retrieve (int): How many notes to retrieve.
        - n_rerank (int): How many notes to keep after reranking.

        Returns a list of ``{"text": ..., "date_added": ...}`` dicts, most
        relevant first. Reranking only happens when n_retrieve > n_rerank.
        """
        logging.info(f"Retrieving knowledge from database using the query: '{query}'")

        logging.debug(f"Using embedding model on '{query}'.")
        vector_keywords = self.llm.embedding(query)
        logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))

        logging.debug("Querying database")
        vector_padded = self._pad_vector(vector_keywords)
        # LIMIT is bound as a query parameter instead of being interpolated
        # into the SQL string with an f-string.
        self.db.cur.execute(
            "SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT %s",
            (vector_padded, n_retrieve),
        )
        db_response = self.db.cur.fetchall()
        knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
        logging.debug("Database returned: ")
        logging.debug(knowledge_arr)

        if n_retrieve > n_rerank:
            logging.info("Reranking the results")
            reranked_texts = self.llm.rerank(query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank)
            # Order the retrieved notes by their position in the reranked
            # text list; notes the reranker dropped sort last and are then
            # cut off by the slice below.
            reranked_knowledge_arr = sorted(
                knowledge_arr,
                key=lambda x: (
                    reranked_texts.index(x["text"]) if x["text"] in reranked_texts else len(knowledge_arr)
                ),
            )
            reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
            logging.debug("Reranked results: ")
            logging.debug(reranked_knowledge_arr)
            return reranked_knowledge_arr
        else:
            return knowledge_arr

    def ask(
        self,
        question: str,
        person: str,
        history: Union[None, list] = None,
        n_retrieve: int = 5,
        n_rerank=3,
        tools: Union[None, list] = None,
        prompts: Union[None, list] = None,
    ):
        """
        Ask the agent a question.

        Parameters:
        - question (str): The question to ask.
        - person (str): Who asks the question?
        - history (list | None): Prior conversation turns passed to the LLM.
        - n_retrieve (int): How many notes to retrieve (0 disables retrieval).
        - n_rerank (int): How many notes to keep after reranking.
        - tools (list | None): Tool definitions forwarded to the LLM.
        - prompts (list | None): Extra plugin prompts rendered into the
          system message.
        """
        logging.info(f"Answering question: {question}")

        if n_retrieve > 0:
            knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)

            # Templates are UTF-8 on disk; decode explicitly rather than
            # relying on the platform default encoding.
            with open('templates/knowledge.en.j2', 'r', encoding='utf-8') as file:
                template = Template(file.read())

            knowledge_str = template.render(
                knowledge = knowledge,
            )
        else:
            knowledge_str = ""

        with open('templates/ask.en.j2', 'r', encoding='utf-8') as file:
            template = Template(file.read())

        # The rendered template becomes the system message; the raw question
        # is sent as the user message.
        prompt = template.render(
            today = datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
            knowledge = knowledge_str,
            plugin_prompts = prompts,
            person = person
        )

        logging.debug(f"Asking the LLM with the prompt: '{prompt}'.")
        answer = self.llm.query(
            question,
            system_msg=prompt,
            tools=tools,
            history=history
        )
        logging.info(f"Got answer from LLM: '{answer}'.")

        return answer
|