128 lines
2.8 KiB
Python
128 lines
2.8 KiB
Python
import importlib
import logging
import os
from typing import Optional

import yaml
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from agent import Agent
|
|
|
|
# Directory (and import package) that is scanned for plugins.
PLUGINS_FOLDER = "plugins"

# Root logger setup: INFO and above, each record prefixed with a timestamp,
# the level and the name of the function that emitted it.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(funcName)s : %(message)s",
)

# Parse the YAML configuration once at import time; `config` is shared by the
# plugin loader and the agent.
with open("config.yml", mode="r", encoding="utf8") as file:
    config = yaml.safe_load(file)
|
|
|
|
def load_plugins(config):
    """Import every plugin found under ``PLUGINS_FOLDER`` and collect its output.

    Each sub-directory whose name does not contain ``__pycache__`` is imported
    as ``<PLUGINS_FOLDER>.<name>.plugin``. The module's ``Plugin`` class is
    instantiated with ``config``; its ``tools()`` are appended to the tool
    list and its ``prompt()`` is kept when it is not ``None``.

    Args:
        config: Configuration object passed unchanged to every ``Plugin``.

    Returns:
        A ``(tools, prompts)`` tuple of lists, in plugin-name order.

    Raises:
        Whatever ``os.listdir`` or ``importlib.import_module`` raise, e.g.
        when the plugins folder or a plugin's ``plugin`` module is missing.
    """
    tools = []
    prompts = []
    plugins_root = os.path.join(".", PLUGINS_FOLDER)

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # order of tools and prompts reproducible across machines and runs.
    for plugin_folder in sorted(os.listdir(plugins_root)):
        if "__pycache__" in plugin_folder:
            continue
        if not os.path.isdir(os.path.join(plugins_root, plugin_folder)):
            continue

        logging.info("Loading plugin: %s", plugin_folder)
        module = importlib.import_module(f"{PLUGINS_FOLDER}.{plugin_folder}.plugin")
        plugin = module.Plugin(config)

        tools.extend(plugin.tools())
        prompt = plugin.prompt()
        if prompt is not None:
            prompts.append(prompt)

    return tools, prompts
|
|
|
|
# Plugins are loaded once at import time; every /ask request reuses the
# resulting tool and prompt lists.
tools, prompts = load_plugins(config)

# HTTP application; files in ./static are served under /static.
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Single agent instance shared by all endpoints below.
agent = Agent(config)
|
|
|
|
|
|
class Query(BaseModel):
    # Request body carrying a single free-text field; used by the
    # /retrieve and /learn endpoints.
    # (Kept as comments: a class docstring would end up in the JSON schema.)
    query: str
|
|
|
|
class Message(BaseModel):
    # One entry of a conversation history (see Question.history).
    # (Kept as comments: a class docstring would end up in the JSON schema.)

    # Who produced the message.
    role: str
    # The message text.
    content: str
|
|
|
|
class Question(BaseModel):
    # Request body of the /ask endpoint.
    # (Kept as comments: a class docstring would end up in the JSON schema.)

    # The question to put to the agent.
    question: str
    # Name of the person asking; "User" when the client does not send one.
    person: str = "User"
    # Earlier messages of the conversation. Declared Optional so that an
    # explicit JSON null is accepted as well as an omitted field; the previous
    # `list[Message] = None` annotation rejected null during validation.
    history: Optional[list[Message]] = None
|
|
|
|
|
|
@app.get("/ping")
def ping():
    """
    # Ping
    Check if the API is up.
    """
    # Constant payload: the agent is not touched, so this only proves that
    # the web application itself is responding.
    return dict(status="ok")
|
|
|
|
@app.post("/ask")
def ask(q: Question):
    """
    # Ask
    Ask the agent a question.

    ## Parameters
    - **question**: the question to ask.
    - **person**: who is asking the question?
    - **history**: optional list of earlier messages (`role`, `content`).
    """
    # The agent receives the history as dumped models rather than Message
    # objects; a missing or empty history is passed as None.
    history = None
    if q.history:
        history = [message.model_dump() for message in q.history]

    answer = agent.ask(
        question=q.question,
        # Forward the person sent by the client (default "User") instead of
        # a hard-coded name, as the documented parameter promises.
        person=q.person,
        history=history,
        tools=tools,
        prompts=prompts,
        n_retrieve=0,
        n_rerank=0,
    )
    return {"question": q.question, "answer": answer, "history": q.history}
|
|
|
|
|
|
@app.post("/retrieve")
def retrieve(q: Query):
    """
    # Retrieve
    Retrieve knowledge related to a query.

    ## Parameters
    - **query**: the question.
    """
    search_text = q.query
    # NOTE(review): n_retrieve / n_rerank presumably bound how many items are
    # fetched and how many survive reranking — confirm against Agent.retrieve.
    hits = agent.retrieve(query=search_text, n_retrieve=10, n_rerank=5)
    return {"query": search_text, "answer": hits}
|
|
|
|
|
|
@app.post("/learn")
def learn(q: Query):
    """
    # Learn
    Teach the agent something. E.g. *"John likes to play volleyball".*

    ## Parameters
    - **query**: the note.
    """
    # This handler used to be called `retrieve` as well, which rebound the
    # module-level name of the /retrieve handler and gave both routes the
    # same generated summary in the API docs.
    # The result of insert() is not part of the response, so it is not kept.
    agent.insert(text=q.query)
    return {"text": q.query}
|
|
|
|
|
|
@app.get("/flush")
def flush():
    """
    # Flush
    Remove all knowledge from the agent.
    """
    # NOTE(review): this mutates agent state but is exposed via GET, so
    # crawlers or link prefetchers could trigger it — confirm this is intended.
    agent.flush()
    confirmation = {"message": "Knowledge flushed."}
    return confirmation
|