import concurrent.futures
import json
import logging
from typing import Union, List

import numpy as np
from openai import OpenAI


class BaseChat:
    def __init__(self, config: dict) -> None:
        self.config = config

    def prepare_function_calling(self, tools):
        """
        Split the tool definitions into a map of local callables (function_map) and the
        JSON-serializable function schemas (functions) expected by the chat API.
        """
        function_map = {}
        if tools is not None and len(tools) > 0:
            # Map each function name to its local callable.
            for tool in tools:
                function_name = tool["function"]["name"]
                function_map[function_name] = tool["function"]["function_to_call"]
            # Build the schema list sent to the API, dropping the local callable.
            functions = []
            for tool in tools:
                fn = tool["function"]
                functions.append(
                    {"type": tool["type"], "function": {x: fn[x] for x in fn if x != "function_to_call"}}
                )
            logging.info(f"{len(tools)} available functions:")
            logging.info(function_map.keys())
        else:
            functions = None
        return function_map, functions
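
    # Illustrative example (hypothetical names): given
    #   tools = [{"type": "function",
    #             "function": {"name": "get_current_weather",
    #                          "function_to_call": get_current_weather,  # local callable
    #                          "description": "Get the current weather",
    #                          "parameters": {...}}}]
    # prepare_function_calling() returns
    #   function_map == {"get_current_weather": get_current_weather}
    #   functions    == [{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}]
    # i.e. the same definitions with "function_to_call" stripped out, ready to be passed
    # to the chat API as the "tools" parameter.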


class OpenAIChat(BaseChat):
    def __init__(self, config: dict) -> None:
        self.config = config
        if self.config["openai"]["base_url"] is not None and self.config["openai"]["base_url"] != "":
            base_url = self.config["openai"]["base_url"]
        else:
            base_url = None
        self.client = OpenAI(base_url=base_url, api_key=self.config["openai"]["api_key"])

    def chat(self, messages, tools) -> str:
        function_map, functions = self.prepare_function_calling(tools)

        # Only pass the tool-related parameters when tools are actually available;
        # some OpenAI-compatible servers reject explicit nulls for these fields.
        tool_kwargs = {"tools": functions, "tool_choice": "auto"} if functions is not None else {}

        logging.debug("Sending request to OpenAI.")
        llm_response = self.client.chat.completions.create(
            model=self.config["openai"]["chat_model"],
            messages=messages,
            temperature=self.config["temperature"],
            **tool_kwargs,
        )
        logging.debug("LLM response:")
        logging.debug(llm_response.choices)

        if llm_response.choices[0].message.tool_calls:
            # The model requested tool calls: execute them and ask for a follow-up answer.
            followup_response = self.execute_function_call(
                llm_response.choices[0].message, function_map, messages
            )
            return followup_response.choices[0].message.content.strip()
        else:
            return llm_response.choices[0].message.content.strip()

    def execute_function_call(self, message, function_map: dict, messages: list):
        """
        Executes the function calls embedded in an LLM message in parallel, appends the
        results to the message history, and returns a follow-up LLM response based on them.

        Parameters:
        - message: LLM message containing the tool calls.
        - function_map (dict): dict of {"function_name": function}
        - messages (list): message history
        """
        tool_calls = message.tool_calls
        logging.info(f"Got {len(tool_calls)} function call(s).")

        # The assistant message that requested the tool calls has to precede the tool results
        # in the history, otherwise the follow-up request is rejected.
        messages.append(message)

        def execute_single_tool_call(tool_call):
            """Helper function to execute a single tool call and wrap it as a tool message."""
            logging.info(f"Attempting to execute function requested by LLM ({tool_call.function.name}, {tool_call.function.arguments}).")
            if tool_call.function.name in function_map:
                function_to_call = function_map[tool_call.function.name]
                args = json.loads(tool_call.function.arguments)
                logging.debug(f"Calling function {tool_call.function.name} with args: {args}")
                function_response = function_to_call(**args)
                logging.debug(function_response)
                content = str(function_response)
            else:
                logging.info(f"{tool_call.function.name} not in function_map")
                logging.info(function_map.keys())
                content = f"Error: function {tool_call.function.name} is not available."
            # Every tool call must be answered by a message with role "tool",
            # matched to the request via its tool_call_id.
            return {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_call.function.name,
                "content": content,
            }

        # Execute the tool calls in parallel
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Submit all tool calls to the executor
            future_to_tool_call = {
                executor.submit(execute_single_tool_call, tool_call): tool_call
                for tool_call in tool_calls
            }

            # Collect results as they complete; ordering does not matter because the
            # results are matched to the requests by tool_call_id.
            for future in concurrent.futures.as_completed(future_to_tool_call):
                messages.append(future.result())

        logging.debug("Functions called, sending the results to LLM.")
        llm_response = self.client.chat.completions.create(
            model=self.config["openai"]["chat_model"],
            messages=messages,
            temperature=self.config["temperature"],
        )
        logging.debug("Got response from LLM:")
        logging.debug(llm_response)
        return llm_response
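
    # Sketch of the message history after execute_function_call() has run (values are hypothetical):
    #   [...original system/user messages...,
    #    <assistant message carrying the tool_calls>,
    #    {"role": "tool", "tool_call_id": "call_abc123", "name": "get_current_weather", "content": "..."},
    #    ...one such entry per tool call...]
    # The follow-up completion is generated from this extended history.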


class LMStudioChat(BaseChat):
    def __init__(self, config: dict) -> None:
        self.config = config
        # LM Studio's local server exposes an OpenAI-compatible API and does not check the API key.
        self.client = OpenAI(base_url="http://localhost:1234/v1", api_key="not-needed")

    def chat(self, messages, tools) -> str:
        if tools is not None:
            logging.warning("Tools were provided to LMStudioChat, but they are not supported by this backend.")
        logging.info("Sending request to local LLM.")
        llm_response = self.client.chat.completions.create(
            model="",  # LM Studio serves whichever model is currently loaded
            messages=messages,
            temperature=self.config["temperature"],
        )
        return llm_response.choices[0].message.content.strip()


class LlamaChat(BaseChat):
    def __init__(self, config: dict) -> None:
        # Imported here so llama-cpp-python is only required when this backend is actually used.
        from llama_cpp import Llama

        self.config = config
        self.client = Llama(
            self.config["local_model_dir"] + "TheBloke/Mistral-7B-Instruct-v0.1-GGUF/mistral-7b-instruct-v0.1.Q6_K.gguf",
            n_gpu_layers=32,
            n_ctx=2048,
            verbose=False,
        )

    def chat(self, messages: list, response_format: dict = None, tools: list = None) -> str:
        if tools is not None:
            logging.warning("Tools were provided to LlamaChat, but they are not yet supported.")
        logging.info("Sending request to local LLM.")
        logging.info(messages)
        llm_response = self.client.create_chat_completion(
            messages=messages,
            response_format=response_format,
            temperature=self.config["temperature"],
        )
        return llm_response["choices"][0]["message"]["content"].strip()


class LLM:
    def __init__(self, config: dict) -> None:
        """
        LLM constructor.

        Parameters:
        - config (dict): llm config parsed from the llm part of config.yml
        """
        self.config = config

        # Defaults to the OpenAI-compatible client; LMStudioChat or LlamaChat can be swapped in here.
        self.chat_client = OpenAIChat(self.config)

    def query(
        self,
        user_msg: str,
        system_msg: Union[None, str] = None,
        history: Union[None, list] = None,
        tools: Union[None, list] = None,
    ):
        """
        Query the LLM.

        Parameters:
        - user_msg (str): query from the user (will be appended to messages as {"role": "user"})
        - system_msg (str): optional system prompt (will be prepended to messages as {"role": "system"})
        - history (list): optional, list of messages to inject between system_msg and user_msg.
        - tools: optional, list of functions that may be called by the LLM. Example:
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "function_to_call": get_current_weather,
                    "description": "Get the current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            }
                        },
                        "required": ["location"],
                    },
                },
            }
        """
        messages = []

        if system_msg is not None:
            messages.append({"role": "system", "content": system_msg})

        if history:
            logging.info("History")
            logging.info(history)
            messages += history

        messages.append({"role": "user", "content": user_msg})

        logging.info(f"Sending request to LLM with {len(messages)} messages.")
        answer = self.chat_client.chat(messages=messages, tools=tools)
        return answer
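

# Minimal usage sketch (illustrative only): assumes a config.yml-style dict providing the
# "openai" credentials and "temperature" read by the classes above; the tool, model name
# and key below are placeholders, not part of this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    example_config = {
        "temperature": 0.2,
        "openai": {
            "base_url": "",           # empty string falls back to the default OpenAI endpoint
            "api_key": "sk-...",      # placeholder, supply a real key
            "chat_model": "gpt-4o-mini",
        },
    }

    def get_current_weather(location: str) -> str:
        """Hypothetical tool implementation used only for this example."""
        return f"The weather in {location} is sunny."

    example_tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "function_to_call": get_current_weather,
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    llm = LLM(example_config)
    print(llm.query("What's the weather like in Paris?", system_msg="You are a helpful assistant.", tools=example_tools))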