Spaces:
Runtime error
Runtime error
| from datetime import datetime, timedelta | |
| import functools | |
| import json | |
| import os | |
| import pandas as pd | |
| from prophet import Prophet | |
| from pathlib import Path | |
| from mistralai.client import MistralClient | |
| from mistralai.models.chat_completion import ChatMessage | |
# MODEL
# Mistral chat model used by `forecast` for tool calling.
MODEL = "mistral-large-latest"
# Raises KeyError at import time if the key is not configured (fail fast).
API_KEY = os.environ["MISTRAL_API_KEY"]
CLIENT = MistralClient(api_key=API_KEY)

# PATH
FILE = Path(__file__).resolve()
# Two directory levels up from this file; the `data/` folder is expected there.
BASE_PATH = FILE.parents[1]

# Historical cereal prices, restricted to rows for France.
HISTORY = pd.read_csv(BASE_PATH / "data" / "cereal_price.csv", encoding="latin-1")
# .copy() makes the filtered frame independent of the unfiltered one, so the
# column assignment below modifies HISTORY itself rather than a possible view
# (avoids pandas' SettingWithCopyWarning / chained-assignment pitfall).
HISTORY = HISTORY[HISTORY["memberStateName"] == "France"].copy()
# Prices are written with a decimal comma (e.g. "123,45"); convert to float.
# astype(str) keeps this working even if pandas already parsed the column as
# numeric, where the .str accessor would otherwise raise AttributeError.
HISTORY['price'] = HISTORY['price'].astype(str).str.replace(",", ".").astype('float64')
def model_predict(week=26):
    """
    Predict future prices using the Prophet model.

    A new Prophet model is fitted on the module-level ``HISTORY`` data on every
    call (``endDate`` is used as the date column, ``price`` as the value). It is
    then evaluated on weekly dates from today up to ``week`` weeks ahead.

    Parameters:
    - week (int): Number of weeks to predict into the future (default is 26).

    Returns:
    - dict: Dictionary containing predicted values and confidence intervals.
      Its keys are 'ds' (forecast dates), 'yhat_lower', 'yhat_upper' and
      'yhat', each mapped to a list with one entry per forecast date.

    Raises:
    - ValueError: If the requested horizon contains no forecast dates
      (e.g. a negative ``week``).
    """
    # Prophet expects the history in two columns named 'ds' (date) and 'y' (value).
    data = HISTORY[['endDate', 'price']].rename(columns={'endDate': 'ds', 'price': 'y'})

    model = Prophet()
    model.fit(data)

    today_date = datetime.now().date()
    end_date = today_date + timedelta(weeks=week)
    # pandas' 'W' frequency is weekly, anchored on Sundays ('W-SUN'), so the
    # forecast dates are the Sundays between today and end_date (inclusive).
    future_df = (
        pd.date_range(start=today_date, end=end_date, freq='W')
        .to_frame(name='ds')
        .reset_index(drop=True)
    )
    if future_df.empty:
        # Prophet would also reject an empty frame; fail here with a clearer message.
        raise ValueError(
            f"No forecast dates between {today_date} and {end_date}; "
            f"week must be at least 1 (got {week!r})"
        )

    forecast = model.predict(future_df)
    # Keep only the date, the point estimate and its confidence interval.
    return {
        column: forecast[column].tolist()
        for column in ('ds', 'yhat_lower', 'yhat_upper', 'yhat')
    }
# Tool schema passed to the Mistral chat API so the model can request a call
# to `model_predict`. The description is read by the LLM when it fills in the
# arguments, so it must state the unit ("weeks") explicitly.
model_predict_tool = [{
    "type": "function",
    "function": {
        "name": "model_predict",
        "description": "Predict future prices using the Prophet model.",
        "parameters": {
            "type": "object",
            "properties": {
                "week": {
                    "type": "integer",
                    "description": "Number of weeks to predict into the future (default is 26).",
                },
            },
            "required": ["week"],
        },
    },
}]
# Dispatch table mapping tool names (as declared in `model_predict_tool`) to
# the Python callables that implement them. The function is stored directly:
# wrapping it in functools.partial with no bound arguments added nothing.
names_to_functions = {
    'model_predict': model_predict,
}
| # messages = [ | |
| # ChatMessage(role="user", content="Predict future prices using the Prophet model for 4 weeks in the future") | |
| # ] | |
def forecast(messages):
    """
    Answer ``messages`` with the Mistral model, running ``model_predict`` if it asks to.

    The chat request exposes ``model_predict_tool`` with ``tool_choice="auto"``,
    so the model may either reply in text or request a tool call. Only the
    first requested tool call is executed.

    Parameters:
    - messages (list): Chat history passed as-is to ``CLIENT.chat``
      (e.g. a list of ``ChatMessage`` objects).

    Returns:
    - dict: When a tool call was requested and ran successfully, the last
      forecast point, with keys "date", "prix_minimum", "prix_maximum" and
      "prix_estimé".
    - Otherwise the text content of the model's reply. This may be ``None``
      or empty if the model issued a tool call that then failed.
    """
    import logging

    response = CLIENT.chat(
        model=MODEL,
        messages=messages,
        tools=model_predict_tool,
        tool_choice="auto",
    )
    message = response.choices[0].message

    # No tool call requested: the model answered directly in text.
    tool_calls = message.tool_calls
    if not tool_calls:
        return message.content

    tool_call = tool_calls[0]
    try:
        function = names_to_functions[tool_call.function.name]
        arguments = tool_call.function.arguments
        # Arguments normally arrive as a JSON string; accept an already-decoded dict too.
        params = arguments if isinstance(arguments, dict) else json.loads(arguments)
        result = function(**params)
        # Report only the furthest forecast point.
        date = result["ds"][-1]
        lower = result["yhat_lower"][-1]
        upper = result["yhat_upper"][-1]
        prediction = result["yhat"][-1]
    except Exception:
        # Best-effort boundary: keep the original fallback to the model's text
        # reply, but record why the tool call failed instead of hiding it.
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        logging.getLogger(__name__).exception(
            "Tool call %r failed; falling back to the model's text reply",
            tool_call.function.name,
        )
        return message.content

    return {"date": str(date), "prix_minimum": lower, "prix_maximum": upper, "prix_estimé": prediction}