# fair-plai / src/utils_fct.py
# G-T's picture
# Update src/utils_fct.py
# c9fb5b3 verified
from datetime import datetime, timedelta
import functools
import json
import logging
import os
from pathlib import Path

import pandas as pd
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from prophet import Prophet
# MODEL
# Credentials come from the environment. Indexing os.environ (rather than
# calling .get) makes a missing MISTRAL_API_KEY fail with KeyError at import.
API_KEY = os.environ["MISTRAL_API_KEY"]
# Single client shared by every chat request issued from this module.
CLIENT = MistralClient(api_key=API_KEY)
# Name of the Mistral chat model queried by forecast().
MODEL = "mistral-large-latest"
# PATH
FILE = Path(__file__).resolve()
# Project root: the directory containing src/ (this file lives in src/).
BASE_PATH = FILE.parents[1]
HISTORY = pd.read_csv(BASE_PATH / "data" / "cereal_price.csv", encoding="latin-1")
# Keep only the French rows. .copy() detaches the filtered rows from the full
# table, so the column assignment below updates HISTORY itself rather than a
# possible view (avoids pandas' SettingWithCopyWarning / chained assignment).
HISTORY = HISTORY[HISTORY["memberStateName"] == "France"].copy()
# Prices presumably use a decimal comma in the CSV (hence the "," -> "." swap).
# astype(str) keeps this working even if pandas already parsed the column as
# numbers. regex=False makes the literal replacement explicit.
HISTORY["price"] = (
    HISTORY["price"].astype(str).str.replace(",", ".", regex=False).astype("float64")
)
@functools.lru_cache(maxsize=1)
def _fitted_model():
    """Fit a Prophet model on HISTORY once and return it for reuse.

    HISTORY is built at module import and is not modified elsewhere in this
    module, so the slow fit does not need to be repeated on every forecast.
    """
    # Prophet expects exactly two columns: "ds" (date) and "y" (value).
    data = HISTORY[["endDate", "price"]].rename(columns={"endDate": "ds", "price": "y"})
    model = Prophet()
    model.fit(data)
    return model


def model_predict(week=26):
    """
    Predict future prices using the Prophet model.

    Parameters:
    - week (int): Number of weeks to forecast, counted from today (default is 26).

    Returns:
    - dict: Lists keyed by 'ds' (forecast dates), 'yhat_lower' and
      'yhat_upper' (Prophet's uncertainty interval bounds) and 'yhat'
      (point prediction), with one entry per weekly date in the horizon.

    Raises:
    - ValueError: if ``week`` is smaller than 1.
    """
    if week < 1:
        raise ValueError(f"week must be at least 1, got {week!r}")

    model = _fitted_model()

    today_date = datetime.now().date()
    end_date = today_date + timedelta(weeks=week)
    # freq='W' is pandas' alias for weekly dates anchored on Sundays, so the
    # horizon holds every Sunday between today and end_date (inclusive).
    future_df = (
        pd.date_range(start=today_date, end=end_date, freq="W")
        .to_frame(name="ds")
        .reset_index(drop=True)
    )
    prediction_df = model.predict(future_df)

    return {
        "ds": prediction_df["ds"].tolist(),
        "yhat_lower": prediction_df["yhat_lower"].tolist(),
        "yhat_upper": prediction_df["yhat_upper"].tolist(),
        "yhat": prediction_df["yhat"].tolist(),
    }
# Tool definition (JSON-schema parameters) exposing model_predict to the chat
# model. "week" is optional, so model_predict's default of 26 weeks applies
# when the model does not supply it, as the description promises.
model_predict_tool = [
    {
        "type": "function",
        "function": {
            "name": "model_predict",
            "description": "Predict future prices using the Prophet model.",
            "parameters": {
                "type": "object",
                "properties": {
                    "week": {
                        "type": "integer",
                        "minimum": 1,
                        "description": "Number of weeks to predict into the future (default is 26).",
                    },
                },
            },
        },
    },
]
# Dispatch table from tool name (as declared in model_predict_tool) to the
# Python callable implementing it. forecast() looks up here the function name
# the chat model asks for. The function itself is stored directly, since a
# functools.partial with no bound arguments adds nothing.
names_to_functions = {
    "model_predict": model_predict,
}
# Example usage:
# messages = [
#     ChatMessage(role="user", content="Predict future prices using the Prophet model for 4 weeks in the future")
# ]
# forecast(messages)
def forecast(messages):
    """
    Send ``messages`` to the Mistral chat model, running the forecast tool if requested.

    Parameters:
    - messages (list): Chat history passed unchanged to ``CLIENT.chat``
      (e.g. a list of ``ChatMessage`` objects).

    Returns:
    - dict: if the model called a tool and it ran successfully, the values
      for the last forecast date: "date" (as a string), "prix_minimum"
      (lower bound), "prix_maximum" (upper bound) and "prix_estimé"
      (point estimate).
    - otherwise, the model's reply text (``message.content``).
    """
    response = CLIENT.chat(
        model=MODEL,
        messages=messages,
        tools=model_predict_tool,
        tool_choice="auto",
    )
    message = response.choices[0].message

    # With tool_choice="auto" the model may answer in plain text instead of
    # calling a tool; then there is nothing to execute.
    if not message.tool_calls:
        return message.content

    tool_call = message.tool_calls[0]
    try:
        function = names_to_functions[tool_call.function.name]
        function_params = json.loads(tool_call.function.arguments)
        function_result = function(**function_params)
        # Only the last point of the forecast horizon is reported.
        date = function_result["ds"][-1]
        lower = function_result["yhat_lower"][-1]
        upper = function_result["yhat_upper"][-1]
        prediction = function_result["yhat"][-1]
    except Exception:
        # Best effort: an unknown tool, malformed arguments, a failing model
        # or an empty forecast falls back to the model's text reply. The
        # failure is logged instead of being silently discarded.
        logging.getLogger(__name__).exception(
            "Tool call %r failed; falling back to the text reply",
            tool_call.function.name,
        )
        return message.content

    return {"date": str(date), "prix_minimum": lower, "prix_maximum": upper, "prix_estimé": prediction}