import datetime
import json
import logging
from typing import Optional, Union

from jinja2 import Template

from db import Database
from llm import LLM

# Fixed width of the `embedding` vector column in the `items` table.
# Embeddings shorter than this are zero-padded before storage/search.
EMBED_DIM = 2048


def relative_to_actual_dates(text: str) -> str:
    """Replace relative date expressions in *text* with absolute YYYY-MM-DD dates.

    Every date-like substring found by dateparser is substituted in place
    with the parsed date formatted as ``YYYY-MM-DD``.
    """
    # Imported lazily: dateparser is only needed by this helper, which is
    # currently disabled at its (commented-out) call site in Agent.insert.
    from dateparser.search import search_dates

    # NOTE(review): "se" is Northern Sami in dateparser; Swedish is "sv".
    # Confirm which language was intended before re-enabling this helper.
    matches = search_dates(text, languages=["en", "se"])
    if matches:
        for matched_text, parsed_date in matches:
            text = text.replace(matched_text, parsed_date.strftime("%Y-%m-%d"))
    return text


class Agent:
    """RAG agent: stores knowledge as pgvector embeddings and answers questions via an LLM."""

    def __init__(self, config: dict) -> None:
        self.config = config
        self.db = Database("rag")
        self.llm = LLM(config=self.config["llm"])

    @staticmethod
    def _pad_vector(vector: list) -> list:
        """Zero-pad an embedding vector to the fixed EMBED_DIM column width."""
        return vector + [0] * (EMBED_DIM - len(vector))

    def flush(self) -> None:
        """
        Flushes the agents knowledge.
        """
        logging.info("Truncating database")
        self.db.cur.execute("TRUNCATE TABLE items")
        self.db.conn.commit()

    def insert(self, text: str) -> None:
        """
        Append knowledge.

        Embeds *text*, zero-pads the vector to EMBED_DIM, and stores it
        together with today's date.
        """
        logging.info("Inserting item into embedding table (%s).", text)
        # logging.info("\tReplacing relative dates with actual dates.")
        # text = relative_to_actual_dates(text)
        # logging.info("\tDone. text: %s", text)
        vector = self.llm.embedding(text)
        self.db.cur.execute(
            "INSERT INTO items (text, embedding, date_added) VALUES(%s, %s, %s)",
            (
                text,
                self._pad_vector(vector),
                datetime.datetime.now().strftime("%Y-%m-%d"),
            ),
        )
        self.db.conn.commit()

    def keywords(self, query: str, n_keywords: int = 5) -> list:
        """
        Suggest keywords related to a query.

        Returns the JSON list parsed from the LLM response; raises
        json.JSONDecodeError if the model does not answer with pure JSON.
        """
        sys_msg = "You are a helpful assistant. Make your response as concise as possible, with no introduction or background at the start."
        prompt = (
            f"Provide a json list of {n_keywords} words or nouns that represents in a vector database query for the following prompt: {query}."
            + f"Do not add any context or text other than the json list of {n_keywords} words."
        )
        response = self.llm.query(prompt, system_msg=sys_msg)
        keywords = json.loads(response)
        return keywords

    def retrieve(self, query: str, n_retrieve: int = 10, n_rerank: int = 5) -> list:
        """
        Retrieve relevant knowledge.

        Parameters:
        - n_retrieve (int): How many notes to retrieve.
        - n_rerank (int): How many notes to keep after reranking.

        Returns a list of {"text", "date_added"} dicts ordered by vector
        distance (and reranked by the LLM when n_retrieve > n_rerank).
        """
        logging.info("Retrieving knowledge from database using the query: '%s'", query)
        logging.debug("Using embedding model on '%s'.", query)
        vector_keywords = self.llm.embedding(query)
        logging.debug("Embedding received. len(vector_keywords): %s", len(vector_keywords))
        logging.debug("Querying database")
        # LIMIT is passed as a bound parameter rather than interpolated into
        # the SQL string, so a caller-supplied n_retrieve cannot inject SQL.
        self.db.cur.execute(
            "SELECT text, date_added FROM items ORDER BY embedding <-> %s::vector LIMIT %s",
            (self._pad_vector(vector_keywords), n_retrieve),
        )
        db_response = self.db.cur.fetchall()
        knowledge_arr = [{"text": row[0], "date_added": row[1]} for row in db_response]
        logging.debug("Database returned: ")
        logging.debug(knowledge_arr)
        if n_retrieve > n_rerank:
            logging.info("Reranking the results")
            reranked_texts = self.llm.rerank(
                query, [note["text"] for note in knowledge_arr], n_rerank=n_rerank
            )
            # Order notes by their position in the reranked list; notes the
            # reranker dropped sort last, then get cut by the slice below.
            reranked_knowledge_arr = sorted(
                knowledge_arr,
                key=lambda x: (
                    reranked_texts.index(x["text"])
                    if x["text"] in reranked_texts
                    else len(knowledge_arr)
                ),
            )
            reranked_knowledge_arr = reranked_knowledge_arr[:n_rerank]
            logging.debug("Reranked results: ")
            logging.debug(reranked_knowledge_arr)
            return reranked_knowledge_arr
        else:
            return knowledge_arr

    def ask(
        self,
        question: str,
        person: str,
        history: Optional[list] = None,
        n_retrieve: int = 5,
        n_rerank: int = 3,
        tools: Union[None, list] = None,
        prompts: Union[None, list] = None,
    ):
        """
        Ask the agent a question.

        Parameters:
        - question (str): The question to ask
        - person (str): Who asks the question?
        - history (list | None): Prior conversation turns passed to the LLM.
        - n_retrieve (int): How many notes to retrieve. Set to 0 to skip
          retrieval entirely (no knowledge section in the prompt).
        - n_rerank (int): How many notes to keep after reranking.
        - tools (list | None): Tool definitions forwarded to the LLM.
        - prompts (list | None): Extra plugin prompts rendered into the template.
        """
        logging.info("Answering question: %s", question)
        if n_retrieve > 0:
            knowledge = self.retrieve(question, n_retrieve=n_retrieve, n_rerank=n_rerank)
            with open('templates/knowledge.en.j2', 'r') as file:
                template = Template(file.read())
            knowledge_str = template.render(
                knowledge=knowledge,
            )
        else:
            knowledge_str = ""
        with open('templates/ask.en.j2', 'r') as file:
            template = Template(file.read())
        # The rendered template becomes the system message; the raw question
        # is sent as the user message.
        prompt = template.render(
            today=datetime.datetime.now().strftime("%Y-%m-%d %H:%M"),
            knowledge=knowledge_str,
            plugin_prompts=prompts,
            person=person,
        )
        logging.debug("Asking the LLM with the prompt: '%s'.", prompt)
        answer = self.llm.query(
            question, system_msg=prompt, tools=tools, history=history
        )
        logging.info("Got answer from LLM: '%s'.", answer)
        return answer