99 lines
4.0 KiB
Python
99 lines
4.0 KiB
Python
import json
import logging
from datetime import datetime

import vasttrafik

from ..homeassistant import HomeAssistant
|
|
|
|
|
|
class Plugin:
    """Public-transport plugin backed by the ``vasttrafik`` journey planner.

    Provides a system-prompt snippet naming the configured default stations
    (``prompt``) and a single function-calling tool, ``search_trip``
    (``tools``), that searches trips between two stations.

    Settings are read from ``config["plugins"]["vasttrafik"]``. The keys
    used here are ``key``, ``secret``, ``default_from_station``,
    ``default_to_station`` and ``delay``.
    """

    # Per-trip parse failures in search_trip are reported here instead of
    # aborting the whole search.
    _log = logging.getLogger(__name__)

    def __init__(self, config: dict) -> None:
        """Create the API client and resolve the default station ids.

        Args:
            config: Application configuration. It is passed to
                ``HomeAssistant`` as-is and must contain a
                ``plugins.vasttrafik`` section with ``key``, ``secret``,
                ``default_from_station`` and ``default_to_station``.

        Note:
            Both default stations are resolved here via ``get_station_id``,
            i.e. through the client's ``location_name`` lookup.
        """
        self.config = config
        self.homeassistant = HomeAssistant(config)
        settings = self._settings
        self.vasttrafik = vasttrafik.JournyPlanner(
            key=settings["key"],
            secret=settings["secret"],
        )
        self.default_from_station_id = self.get_station_id(settings["default_from_station"])
        self.default_to_station_id = self.get_station_id(settings["default_to_station"])

    @property
    def _settings(self) -> dict:
        """Return the ``plugins.vasttrafik`` section of the configuration."""
        return self.config["plugins"]["vasttrafik"]

    def prompt(self) -> str:
        """Return the system-prompt section describing the default stations."""
        settings = self._settings
        return (
            "# Public transportation #\n"
            f"Home station is: {settings['default_from_station']}, "
            f"and the default to station to use is {settings['default_to_station']}"
        )

    def tools(self) -> list:
        """Return the function-calling tool specifications for this plugin.

        The single tool, ``search_trip``, is bound to ``self.search_trip``
        through the non-standard ``function_to_call`` key.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "search_trip",
                    "function_to_call": self.search_trip,
                    "description": "Search for trips by public transportation.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "from_station": {
                                "type": "string",
                                "description": "Station to travel from.",
                            },
                            "to_station": {
                                "type": "string",
                                "description": "Station to travel to.",
                            },
                        },
                        "required": ["from_station", "to_station"],
                    },
                },
            }
        ]

    def search_trip(self, from_station=None, to_station=None) -> str:
        """Search trips between two stations and return them as a JSON string.

        Args:
            from_station: Name of the departure station, or ``None`` to use
                the configured default.
            to_station: Name of the arrival station, or ``None`` to use the
                configured default.

        Returns:
            A JSON array (non-ASCII characters kept as-is). Each element has
            ``origin`` (``stationName``, ``estimatedDepartureTime``) and
            ``destination`` (``stationName``, ``estimatedArrivalTime``).
            Only trips departing at least ``delay`` minutes from now are
            included. Trips whose data cannot be parsed are logged and
            skipped.

        Raises:
            AssertionError: If no trip passes the filter. The type is kept
                for callers that already handle it.
        """
        if from_station is None:
            from_station_id = self.default_from_station_id
        else:
            from_station_id = self.get_station_id(from_station)
        if to_station is None:
            to_station_id = self.default_to_station_id
        else:
            to_station_id = self.get_station_id(to_station)

        journeys = self.vasttrafik.trip(from_station_id, to_station_id)

        delay = self._settings["delay"]
        now = datetime.now()
        response = []
        for journey in journeys:
            try:
                legs = journey["tripLegs"]
                # With transfers, the first leg ends at the first change-over
                # stop; the final arrival is described by the last leg.
                first_leg, last_leg = legs[0], legs[-1]
                departure_time = first_leg["origin"]["estimatedTime"]
                # Only the "YYYY-MM-DDTHH:MM:SS" prefix is parsed, dropping
                # fractional seconds and any UTC offset.
                # NOTE(review): assumes the host's local time zone matches the
                # API timestamps - confirm.
                departure_dt = datetime.strptime(departure_time[:19], "%Y-%m-%dT%H:%M:%S")
                # total_seconds() keeps the sign and the day component.
                # timedelta.seconds does not, which let already-departed trips
                # (negative deltas) through the filter.
                minutes_to_departure = (departure_dt - now).total_seconds() / 60
                if minutes_to_departure < delay:
                    continue
                response.append(
                    {
                        "origin": {
                            "stationName": first_leg["origin"]["stopPoint"]["name"],
                            "estimatedDepartureTime": departure_time,
                        },
                        "destination": {
                            "stationName": last_leg["destination"]["stopPoint"]["name"],
                            "estimatedArrivalTime": last_leg["estimatedArrivalTime"],
                        },
                    }
                )
            except (KeyError, IndexError, TypeError, ValueError) as err:
                # Best effort: one malformed trip must not hide the others.
                self._log.warning("Skipping trip that could not be parsed: %s", err)

        if not response:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError("No trips found.")
        return json.dumps(response, ensure_ascii=False)

    def get_station_id(self, location_name: str) -> str:
        """Return the ``gid`` of the first station matching *location_name*.

        Raises:
            IndexError: If the client's ``location_name`` lookup returns no
                matches.
        """
        matches = self.vasttrafik.location_name(location_name)
        if not matches:
            raise IndexError(f"No station found matching {location_name!r}.")
        return matches[0]["gid"]
|