Spaces:
Sleeping
Sleeping
from datetime import datetime, timedelta
import functools
import json
import logging
import os
from pathlib import Path

import pandas as pd
from prophet import Prophet
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
# MODEL
MODEL = "mistral-large-latest"
# Raises KeyError at import time if MISTRAL_API_KEY is not set.
API_KEY = os.environ["MISTRAL_API_KEY"]
CLIENT = MistralClient(api_key=API_KEY)

# PATH
FILE = Path(__file__).resolve()
# parents[1] is the parent of the directory that contains this file.
BASE_PATH = FILE.parents[1]

# Historical prices, restricted to rows whose memberStateName is "France".
# `.copy()` makes the filtered frame independent of the full CSV frame so the
# column assignment below does not raise pandas' SettingWithCopyWarning.
HISTORY = pd.read_csv(BASE_PATH / "data" / "cereal_price.csv", encoding="latin-1")
HISTORY = HISTORY[HISTORY["memberStateName"] == "France"].copy()
# The price column is text using ',' as the decimal separator; normalise it to
# '.' (literal replace, not a regex) before casting to float.
HISTORY["price"] = HISTORY["price"].str.replace(",", ".", regex=False).astype("float64")
@functools.lru_cache(maxsize=1)
def _fitted_model():
    """Fit a Prophet model on HISTORY once and reuse it for every prediction.

    HISTORY is built at import time, and the code visible in this file does not
    modify it afterwards, so the fitted model can be cached for the life of the
    process instead of being refit on every tool call.
    NOTE(review): if HISTORY is ever reloaded elsewhere, call
    `_fitted_model.cache_clear()`.
    """
    # Prophet expects the date column to be named 'ds' and the target 'y'.
    data = HISTORY[['endDate', 'price']].rename(columns={'endDate': 'ds', 'price': 'y'})
    model = Prophet()
    model.fit(data)
    return model


def model_predict(week=26):
    """
    Predict future prices using the Prophet model.

    Parameters:
    - week (int): Number of weeks ahead to forecast (default is 26). Must be >= 1.

    Returns:
    - dict: Index-aligned lists under the keys 'ds' (forecast dates),
      'yhat_lower' / 'yhat_upper' (confidence-interval bounds) and
      'yhat' (point predictions).

    Raises:
    - ValueError: if week is less than 1 (there would be nothing to forecast).
    """
    # Validate before doing any (potentially expensive) model work.
    if week < 1:
        raise ValueError(f"week must be >= 1, got {week!r}")

    model = _fitted_model()

    today_date = datetime.now().date()
    end_date = today_date + timedelta(weeks=week)
    # freq='W' is pandas' weekly alias anchored on Sunday: one row per Sunday
    # falling in [today_date, end_date].
    future_df = (
        pd.date_range(start=today_date, end=end_date, freq='W')
        .to_frame(name='ds')
        .reset_index(drop=True)
    )
    prediction_df = model.predict(future_df)

    return {
        'ds': prediction_df['ds'].tolist(),
        'yhat_lower': prediction_df['yhat_lower'].tolist(),
        'yhat_upper': prediction_df['yhat_upper'].tolist(),
        'yhat': prediction_df['yhat'].tolist(),
    }
# Tool schema advertised to the Mistral chat API. It describes `model_predict`
# so the model can request a forecast. The parameter description states the
# unit explicitly (weeks, as used by `timedelta(weeks=week)`), so the LLM does
# not have to guess what a "period" is.
model_predict_tool = [{
    "type": "function",
    "function": {
        "name": "model_predict",
        "description": "Predict future prices using the Prophet model.",
        "parameters": {
            "type": "object",
            "properties": {
                "week": {
                    "type": "integer",
                    "description": "Number of weeks ahead to forecast, as a positive integer (default is 26).",
                },
            },
            "required": ["week"],
        },
    },
}]
# Dispatch table from the tool name the LLM returns to the Python callable that
# implements it. The function is referenced directly; wrapping it in
# `functools.partial` with no bound arguments added nothing.
names_to_functions = {
    'model_predict': model_predict,
}
def forecast(messages):
    """Send a chat conversation to Mistral and run the forecast tool if it asks for it.

    Example:
        messages = [
            ChatMessage(role="user", content="Predict future prices using the Prophet model for 4 weeks in the future")
        ]
        forecast(messages)

    Parameters:
    - messages (list): Chat history passed unchanged to ``CLIENT.chat``.

    Returns:
    - dict: When the model requests a tool and it runs successfully, the last
      forecast point: "date" (str), "prix_minimum" (lower bound),
      "prix_maximum" (upper bound) and "prix_estimé" (point estimate).
    - str: Otherwise, the model's text reply (``message.content``). This covers
      both "no tool requested" and "tool call failed"; failures are logged.
    """
    response = CLIENT.chat(
        model=MODEL,
        messages=messages,
        tools=model_predict_tool,
        tool_choice="auto",
    )
    message = response.choices[0].message

    # No tool requested (None or empty list): plain conversational answer.
    if not message.tool_calls:
        return message.content

    # Only the first tool call is executed.
    tool_call = message.tool_calls[0]
    try:
        function = names_to_functions[tool_call.function.name]
        function_params = json.loads(tool_call.function.arguments)
        function_result = function(**function_params)
        date = function_result["ds"][-1]
        lower = function_result["yhat_lower"][-1]
        upper = function_result["yhat_upper"][-1]
        prediction = function_result["yhat"][-1]
    except Exception:
        # UI boundary: keep the best-effort fallback to the model's text, but
        # record why the tool failed instead of silently discarding the error.
        logging.getLogger(__name__).exception(
            "Tool call %s failed; falling back to the model's text reply",
            tool_call.function.name,
        )
        return message.content

    return {"date": str(date), "prix_minimum": lower, "prix_maximum": upper, "prix_estimé": prediction}