Soumik555 committed
Commit
a172af8
1 Parent(s): 5f8515f

added openai in orchestrator

Files changed (1)
  1. gemini_report_generator.py +1184 -316
gemini_report_generator.py CHANGED
@@ -1,366 +1,1234 @@
1
- import json
2
- import numpy as np
3
- import pandas as pd
4
- import re
5
- import os
6
- import uuid
7
- import logging
8
- from io import StringIO
9
- import sys
10
- import traceback
11
- from typing import Optional, Dict, Any, List
12
- from pydantic import BaseModel, Field
13
- from google.generativeai import GenerativeModel, configure
14
- from dotenv import load_dotenv
15
- import seaborn as sns
16
- import datetime as dt
17
 
18
- from supabase_service import upload_file_to_supabase
19
 
20
- pd.set_option('display.max_columns', None)
21
- pd.set_option('display.max_rows', None)
22
- pd.set_option('display.max_colwidth', None)
23
 
24
- load_dotenv()
25
 
26
 
27
- API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",")[::-1]
28
- MODEL_NAME = 'gemini-2.0-flash'
29
 
30
- class FileProps(BaseModel):
31
- fileName: str
32
- filePath: str
33
- fileType: str # 'csv' | 'image'
34
 
35
- class Files(BaseModel):
36
- csv_files: List[FileProps]
37
- image_files: List[FileProps]
38
 
39
- class FileBoxProps(BaseModel):
40
- files: Files
41
 
42
- os.environ['MPLBACKEND'] = 'agg'
43
- import matplotlib.pyplot as plt
44
- plt.show = lambda: None
45
 
46
- logging.basicConfig(
47
- level=logging.INFO,
48
- format='%(asctime)s - %(levelname)s - %(message)s'
49
- )
50
- logger = logging.getLogger(__name__)
51
 
52
- class GeminiKeyManager:
53
- """Manage multiple Gemini API keys with failover"""
54
 
55
- def __init__(self, api_keys: List[str]):
56
- self.original_keys = api_keys.copy()
57
- self.available_keys = api_keys.copy()
58
- self.active_key = None
59
- self.failed_keys = {}
60
 
61
- def configure(self) -> bool:
62
- while self.available_keys:
63
- key = self.available_keys.pop(0)
64
- try:
65
- configure(api_key=key)
66
- self.active_key = key
67
- logger.info(f"Configured with key: {self._mask_key(key)}")
68
- return True
69
- except Exception as e:
70
- self.failed_keys[key] = str(e)
71
- logger.error(f"Key failed: {self._mask_key(key)}. Error: {str(e)}")
72
- logger.critical("All API keys failed")
73
- return False
74
 
75
- def _mask_key(self, key: str) -> str:
76
- return f"{key[:8]}...{key[-4:]}" if key else ""
77
 
78
- class PythonREPL:
79
- """Secure Python REPL with file generation tracking"""
80
 
81
- def __init__(self, df: pd.DataFrame):
82
- self.df = df
83
- self.output_dir = os.path.abspath(f'generated_outputs/{uuid.uuid4()}')
84
- os.makedirs(self.output_dir, exist_ok=True)
85
- self.local_env = {
86
- "pd": pd,
87
- "df": self.df.copy(),
88
- "plt": plt,
89
- "os": os,
90
- "uuid": uuid,
91
- "sns": sns,
92
- "json": json,
93
- "dt": dt,
94
- "output_dir": self.output_dir
95
- }
96
 
97
- def execute(self, code: str) -> Dict[str, Any]:
98
- print('Executing code...', code)
99
- old_stdout = sys.stdout
100
- sys.stdout = mystdout = StringIO()
101
- file_tracker = {
102
- 'csv_files': set(),
103
- 'image_files': set()
104
- }
105
 
106
- try:
107
- code = f"""
108
- import matplotlib.pyplot as plt
109
- plt.switch_backend('agg')
110
- {code}
111
- plt.close('all')
112
- """
113
- exec(code, self.local_env)
114
- self.df = self.local_env.get('df', self.df)
115
 
116
- # Track generated files
117
- for fname in os.listdir(self.output_dir):
118
- if fname.endswith('.csv'):
119
- file_tracker['csv_files'].add(fname)
120
- elif fname.lower().endswith(('.png', '.jpg', '.jpeg')):
121
- file_tracker['image_files'].add(fname)
122
 
123
- error = False
124
- except Exception as e:
125
- error_msg = traceback.format_exc()
126
- error = True
127
- finally:
128
- sys.stdout = old_stdout
129
 
130
- return {
131
- "output": mystdout.getvalue(),
132
- "error": error,
133
- "error_message": error_msg if error else None,
134
- "df": self.local_env.get('df', self.df),
135
- "output_dir": self.output_dir,
136
- "files": {
137
- "csv": [os.path.join(self.output_dir, f) for f in file_tracker['csv_files']],
138
- "images": [os.path.join(self.output_dir, f) for f in file_tracker['image_files']]
139
- }
140
- }
141
-
142
- class RethinkAgent(BaseModel):
143
- df: pd.DataFrame
144
- max_retries: int = Field(default=5, ge=1)
145
- gemini_model: Optional[GenerativeModel] = None
146
- current_retry: int = Field(default=0, ge=0)
147
- repl: Optional[PythonREPL] = None
148
- key_manager: Optional[GeminiKeyManager] = None
149
 
150
- class Config:
151
- arbitrary_types_allowed = True
152
 
153
- def _extract_code(self, response: str) -> str:
154
- code_match = re.search(r'```python(.*?)```', response, re.DOTALL)
155
- return code_match.group(1).strip() if code_match else response.strip()
156
 
157
- def _generate_initial_prompt(self, query: str) -> str:
158
- return f"""Generate DIRECT EXECUTION CODE (no functions, no explanations) following STRICT RULES:
159
 
160
- MANDATORY REQUIREMENTS:
161
- 1. Operate directly on existing 'df' variable
162
- 2. Save ALL final DataFrames to CSV using: df.to_csv(f'{{output_dir}}/descriptive_name.csv')
163
- 3. For visualizations: plt.savefig(f'{{output_dir}}/chart_name.png')
164
- 4. Use EXACTLY this structure:
165
- # Data processing
166
- df_processed = df[...] # filtering/grouping
167
- # Save results
168
- df_processed.to_csv(f'{{output_dir}}/result.csv')
169
- # Visualizations (if needed)
170
- plt.figure()
171
- ... plotting code ...
172
- plt.savefig(f'{{output_dir}}/chart.png')
173
- plt.close()
174
-
175
- FORBIDDEN:
176
- - Function definitions
177
- - Dummy data creation
178
- - Any code blocks besides pandas operations and matplotlib
179
- - Print statements showing dataframes
180
-
181
- DATAFRAME COLUMNS: {', '.join(self.df.columns)}
182
- DATAFRAME'S FIRST FIVE ROWS: {self.df.head().to_dict('records')}
183
- USER QUERY: {query}
184
-
185
- EXAMPLE RESPONSE FOR "Sales by region":
186
- # Data processing
187
- sales_by_region = df.groupby('region')['sales'].sum().reset_index()
188
- # Save results
189
- sales_by_region.to_csv(f'{{output_dir}}/sales_by_region.csv')
190
- """
191
-
192
- def _generate_retry_prompt(self, query: str, error: str, code: str) -> str:
193
- return f"""FIX THIS CODE (failed with: {error}) by STRICTLY FOLLOWING:
194
 
195
- 1. REMOVE ALL FUNCTION DEFINITIONS
196
- 2. ENSURE DIRECT DF OPERATIONS
197
- 3. USE EXPLICIT output_dir PATHS
198
- 4. ADD NECESSARY IMPORTS IF MISSING
199
- 5. VALIDATE COLUMN NAMES EXIST
200
 
201
- BAD CODE:
202
- {code}
203
 
204
- CORRECTED CODE:"""
205
 
206
- def initialize_model(self, api_keys: List[str]) -> bool:
207
- self.key_manager = GeminiKeyManager(api_keys)
208
- if not self.key_manager.configure():
209
- raise RuntimeError("API key initialization failed")
210
- try:
211
- self.gemini_model = GenerativeModel(MODEL_NAME)
212
- return True
213
- except Exception as e:
214
- logger.error(f"Model init failed: {str(e)}")
215
- return False
216
 
217
- def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str:
218
- prompt = self._generate_retry_prompt(query, error, previous_code) if error else self._generate_initial_prompt(query)
219
- try:
220
- response = self.gemini_model.generate_content(prompt)
221
- return self._extract_code(response.text)
222
- except Exception as e:
223
- if self.key_manager.available_keys and self.key_manager.configure():
224
- return self.generate_code(query, error, previous_code)
225
- raise
226
 
227
- def execute_query(self, query: str) -> Dict[str, Any]:
228
- self.repl = PythonREPL(self.df)
229
- result = None
230
 
231
- while self.current_retry < self.max_retries:
232
- try:
233
- code = self.generate_code(query,
234
- result["error_message"] if result else None,
235
- result["code"] if result else None)
236
- execution_result = self.repl.execute(code)
237
 
238
- if execution_result["error"]:
239
- self.current_retry += 1
240
- result = {
241
- "error_message": execution_result["error_message"],
242
- "code": code
243
- }
244
- else:
245
- return {
246
- "text": execution_result["output"],
247
- "csv_files": execution_result["files"]["csv"],
248
- "image_files": execution_result["files"]["images"]
249
- }
250
- except Exception as e:
251
- return {
252
- "error": f"Critical failure: {str(e)}",
253
- "csv_files": [],
254
- "image_files": []
255
- }
256
 
257
- return {
258
- "error": f"Failed after {self.max_retries} retries",
259
- "csv_files": [],
260
- "image_files": []
261
- }
262
-
263
- def gemini_llm_chat(csv_url: str, query: str) -> Dict[str, Any]:
264
- try:
265
- df = pd.read_csv(csv_url)
266
- agent = RethinkAgent(df=df)
267
 
268
- if not agent.initialize_model(API_KEYS):
269
- return {"error": "API configuration failed"}
270
 
271
- result = agent.execute_query(query)
272
 
273
- if "error" in result:
274
- return result
275
 
276
- return {
277
- "message": result["text"],
278
- "csv_files": result["csv_files"],
279
- "image_files": result["image_files"]
280
- }
281
- except Exception as e:
282
- logger.error(f"Processing failed: {str(e)}")
283
- return {
284
- "error": f"Processing error: {str(e)}",
285
- "csv_files": [],
286
- "image_files": []
287
- }
288
 
289
 
290
- async def generate_csv_report(csv_url: str, query: str) -> FileBoxProps:
291
- try:
292
- result = gemini_llm_chat(csv_url, query)
293
- logger.info(f"Raw result from gemini_llm_chat: {result}")
294
 
295
- csv_files = []
296
- image_files = []
297
 
298
- # Check if we got the expected response structure
299
- if isinstance(result, dict) and 'csv_files' in result and 'image_files' in result:
300
- # Process CSV files
301
- for csv_path in result['csv_files']:
302
- if os.path.exists(csv_path):
303
- file_name = os.path.basename(csv_path)
304
- try:
305
- unique_file_name = f"{uuid.uuid4()}_{file_name}"
306
- public_url = await upload_file_to_supabase(
307
- file_path=csv_path,
308
- file_name=unique_file_name
309
- )
310
- csv_files.append(FileProps(
311
- fileName=file_name,
312
- filePath=public_url,
313
- fileType="csv"
314
- ))
315
- os.remove(csv_path) # Clean up
316
- except Exception as upload_error:
317
- logger.error(f"Failed to upload CSV {file_name}: {str(upload_error)}")
318
- continue
319
 
320
- # Process image files
321
- for img_path in result['image_files']:
322
- if os.path.exists(img_path):
323
- file_name = os.path.basename(img_path)
324
- try:
325
- unique_file_name = f"{uuid.uuid4()}_{file_name}"
326
- public_url = await upload_file_to_supabase(
327
- file_path=img_path,
328
- file_name=unique_file_name
329
- )
330
- image_files.append(FileProps(
331
- fileName=file_name,
332
- filePath=public_url,
333
- fileType="image"
334
- ))
335
- os.remove(img_path) # Clean up
336
- except Exception as upload_error:
337
- logger.error(f"Failed to upload image {file_name}: {str(upload_error)}")
338
- continue
339
 
340
- return FileBoxProps(
341
- files=Files(
342
- csv_files=csv_files,
343
- image_files=image_files
344
- )
345
  )
346
- else:
347
- raise ValueError("Unexpected response format from gemini_llm_chat")
348
 
349
- except Exception as e:
350
- logger.error(f"Report generation failed: {str(e)}")
351
- # Return empty response but log the files we found
352
- if 'csv_files' in locals() and 'image_files' in locals():
353
- logger.info(f"Files that were generated but not processed: CSV: {result.get('csv_files', [])}, Images: {result.get('image_files', [])}")
354
- return FileBoxProps(
355
- files=Files(
356
- csv_files=[],
357
- image_files=[]
358
  )
359
- )
360
 
361
 
 
362
 
363
- # if __name__ == "__main__":
364
- # result = gemini_llm_chat("./documents/enterprise_sales_data.csv",
365
- # "Generate a detailed sales report of the last 6 months from all the aspects and include a bar chart showing the sales by region.")
366
- # print(json.dumps(result, indent=2))
 
1
+ # import json
2
+ # import numpy as np
3
+ # import pandas as pd
4
+ # import re
5
+ # import os
6
+ # import uuid
7
+ # import logging
8
+ # from io import StringIO
9
+ # import sys
10
+ # import traceback
11
+ # from typing import Optional, Dict, Any, List
12
+ # from pydantic import BaseModel, Field
13
+ # from google.generativeai import GenerativeModel, configure
14
+ # from dotenv import load_dotenv
15
+ # import seaborn as sns
16
+ # import datetime as dt
17
 
18
+ # from supabase_service import upload_file_to_supabase
19
 
20
+ # pd.set_option('display.max_columns', None)
21
+ # pd.set_option('display.max_rows', None)
22
+ # pd.set_option('display.max_colwidth', None)
23
 
24
+ # load_dotenv()
25
 
26
 
27
+ # API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",")[::-1]
28
+ # MODEL_NAME = 'gemini-2.0-flash'
29
 
30
+ # class FileProps(BaseModel):
31
+ # fileName: str
32
+ # filePath: str
33
+ # fileType: str # 'csv' | 'image'
34
 
35
+ # class Files(BaseModel):
36
+ # csv_files: List[FileProps]
37
+ # image_files: List[FileProps]
38
 
39
+ # class FileBoxProps(BaseModel):
40
+ # files: Files
41
 
42
+ # os.environ['MPLBACKEND'] = 'agg'
43
+ # import matplotlib.pyplot as plt
44
+ # plt.show = lambda: None
45
 
46
+ # logging.basicConfig(
47
+ # level=logging.INFO,
48
+ # format='%(asctime)s - %(levelname)s - %(message)s'
49
+ # )
50
+ # logger = logging.getLogger(__name__)
51
 
52
+ # class GeminiKeyManager:
53
+ # """Manage multiple Gemini API keys with failover"""
54
 
55
+ # def __init__(self, api_keys: List[str]):
56
+ # self.original_keys = api_keys.copy()
57
+ # self.available_keys = api_keys.copy()
58
+ # self.active_key = None
59
+ # self.failed_keys = {}
60
 
61
+ # def configure(self) -> bool:
62
+ # while self.available_keys:
63
+ # key = self.available_keys.pop(0)
64
+ # try:
65
+ # configure(api_key=key)
66
+ # self.active_key = key
67
+ # logger.info(f"Configured with key: {self._mask_key(key)}")
68
+ # return True
69
+ # except Exception as e:
70
+ # self.failed_keys[key] = str(e)
71
+ # logger.error(f"Key failed: {self._mask_key(key)}. Error: {str(e)}")
72
+ # logger.critical("All API keys failed")
73
+ # return False
74
 
75
+ # def _mask_key(self, key: str) -> str:
76
+ # return f"{key[:8]}...{key[-4:]}" if key else ""
77
 
78
+ # class PythonREPL:
79
+ # """Secure Python REPL with file generation tracking"""
80
 
81
+ # def __init__(self, df: pd.DataFrame):
82
+ # self.df = df
83
+ # self.output_dir = os.path.abspath(f'generated_outputs/{uuid.uuid4()}')
84
+ # os.makedirs(self.output_dir, exist_ok=True)
85
+ # self.local_env = {
86
+ # "pd": pd,
87
+ # "df": self.df.copy(),
88
+ # "plt": plt,
89
+ # "os": os,
90
+ # "uuid": uuid,
91
+ # "sns": sns,
92
+ # "json": json,
93
+ # "dt": dt,
94
+ # "output_dir": self.output_dir
95
+ # }
96
 
97
+ # def execute(self, code: str) -> Dict[str, Any]:
98
+ # print('Executing code...', code)
99
+ # old_stdout = sys.stdout
100
+ # sys.stdout = mystdout = StringIO()
101
+ # file_tracker = {
102
+ # 'csv_files': set(),
103
+ # 'image_files': set()
104
+ # }
105
 
106
+ # try:
107
+ # code = f"""
108
+ # import matplotlib.pyplot as plt
109
+ # plt.switch_backend('agg')
110
+ # {code}
111
+ # plt.close('all')
112
+ # """
113
+ # exec(code, self.local_env)
114
+ # self.df = self.local_env.get('df', self.df)
115
 
116
+ # # Track generated files
117
+ # for fname in os.listdir(self.output_dir):
118
+ # if fname.endswith('.csv'):
119
+ # file_tracker['csv_files'].add(fname)
120
+ # elif fname.lower().endswith(('.png', '.jpg', '.jpeg')):
121
+ # file_tracker['image_files'].add(fname)
122
 
123
+ # error = False
124
+ # except Exception as e:
125
+ # error_msg = traceback.format_exc()
126
+ # error = True
127
+ # finally:
128
+ # sys.stdout = old_stdout
129
 
130
+ # return {
131
+ # "output": mystdout.getvalue(),
132
+ # "error": error,
133
+ # "error_message": error_msg if error else None,
134
+ # "df": self.local_env.get('df', self.df),
135
+ # "output_dir": self.output_dir,
136
+ # "files": {
137
+ # "csv": [os.path.join(self.output_dir, f) for f in file_tracker['csv_files']],
138
+ # "images": [os.path.join(self.output_dir, f) for f in file_tracker['image_files']]
139
+ # }
140
+ # }
141
+
142
+ # class RethinkAgent(BaseModel):
143
+ # df: pd.DataFrame
144
+ # max_retries: int = Field(default=5, ge=1)
145
+ # gemini_model: Optional[GenerativeModel] = None
146
+ # current_retry: int = Field(default=0, ge=0)
147
+ # repl: Optional[PythonREPL] = None
148
+ # key_manager: Optional[GeminiKeyManager] = None
149
 
150
+ # class Config:
151
+ # arbitrary_types_allowed = True
152
 
153
+ # def _extract_code(self, response: str) -> str:
154
+ # code_match = re.search(r'```python(.*?)```', response, re.DOTALL)
155
+ # return code_match.group(1).strip() if code_match else response.strip()
156
 
157
+ # def _generate_initial_prompt(self, query: str) -> str:
158
+ # return f"""Generate DIRECT EXECUTION CODE (no functions, no explanations) following STRICT RULES:
159
 
160
+ # MANDATORY REQUIREMENTS:
161
+ # 1. Operate directly on existing 'df' variable
162
+ # 2. Save ALL final DataFrames to CSV using: df.to_csv(f'{{output_dir}}/descriptive_name.csv')
163
+ # 3. For visualizations: plt.savefig(f'{{output_dir}}/chart_name.png')
164
+ # 4. Use EXACTLY this structure:
165
+ # # Data processing
166
+ # df_processed = df[...] # filtering/grouping
167
+ # # Save results
168
+ # df_processed.to_csv(f'{{output_dir}}/result.csv')
169
+ # # Visualizations (if needed)
170
+ # plt.figure()
171
+ # ... plotting code ...
172
+ # plt.savefig(f'{{output_dir}}/chart.png')
173
+ # plt.close()
174
+
175
+ # FORBIDDEN:
176
+ # - Function definitions
177
+ # - Dummy data creation
178
+ # - Any code blocks besides pandas operations and matplotlib
179
+ # - Print statements showing dataframes
180
+
181
+ # DATAFRAME COLUMNS: {', '.join(self.df.columns)}
182
+ # DATAFRAME'S FIRST FIVE ROWS: {self.df.head().to_dict('records')}
183
+ # USER QUERY: {query}
184
+
185
+ # EXAMPLE RESPONSE FOR "Sales by region":
186
+ # # Data processing
187
+ # sales_by_region = df.groupby('region')['sales'].sum().reset_index()
188
+ # # Save results
189
+ # sales_by_region.to_csv(f'{{output_dir}}/sales_by_region.csv')
190
+ # """
191
+
192
+ # def _generate_retry_prompt(self, query: str, error: str, code: str) -> str:
193
+ # return f"""FIX THIS CODE (failed with: {error}) by STRICTLY FOLLOWING:
194
 
195
+ # 1. REMOVE ALL FUNCTION DEFINITIONS
196
+ # 2. ENSURE DIRECT DF OPERATIONS
197
+ # 3. USE EXPLICIT output_dir PATHS
198
+ # 4. ADD NECESSARY IMPORTS IF MISSING
199
+ # 5. VALIDATE COLUMN NAMES EXIST
200
 
201
+ # BAD CODE:
202
+ # {code}
203
 
204
+ # CORRECTED CODE:"""
205
 
206
+ # def initialize_model(self, api_keys: List[str]) -> bool:
207
+ # self.key_manager = GeminiKeyManager(api_keys)
208
+ # if not self.key_manager.configure():
209
+ # raise RuntimeError("API key initialization failed")
210
+ # try:
211
+ # self.gemini_model = GenerativeModel(MODEL_NAME)
212
+ # return True
213
+ # except Exception as e:
214
+ # logger.error(f"Model init failed: {str(e)}")
215
+ # return False
216
 
217
+ # def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str:
218
+ # prompt = self._generate_retry_prompt(query, error, previous_code) if error else self._generate_initial_prompt(query)
219
+ # try:
220
+ # response = self.gemini_model.generate_content(prompt)
221
+ # return self._extract_code(response.text)
222
+ # except Exception as e:
223
+ # if self.key_manager.available_keys and self.key_manager.configure():
224
+ # return self.generate_code(query, error, previous_code)
225
+ # raise
226
 
227
+ # def execute_query(self, query: str) -> Dict[str, Any]:
228
+ # self.repl = PythonREPL(self.df)
229
+ # result = None
230
 
231
+ # while self.current_retry < self.max_retries:
232
+ # try:
233
+ # code = self.generate_code(query,
234
+ # result["error_message"] if result else None,
235
+ # result["code"] if result else None)
236
+ # execution_result = self.repl.execute(code)
237
 
238
+ # if execution_result["error"]:
239
+ # self.current_retry += 1
240
+ # result = {
241
+ # "error_message": execution_result["error_message"],
242
+ # "code": code
243
+ # }
244
+ # else:
245
+ # return {
246
+ # "text": execution_result["output"],
247
+ # "csv_files": execution_result["files"]["csv"],
248
+ # "image_files": execution_result["files"]["images"]
249
+ # }
250
+ # except Exception as e:
251
+ # return {
252
+ # "error": f"Critical failure: {str(e)}",
253
+ # "csv_files": [],
254
+ # "image_files": []
255
+ # }
256
 
257
+ # return {
258
+ # "error": f"Failed after {self.max_retries} retries",
259
+ # "csv_files": [],
260
+ # "image_files": []
261
+ # }
262
+
263
+ # def gemini_llm_chat(csv_url: str, query: str) -> Dict[str, Any]:
264
+ # try:
265
+ # df = pd.read_csv(csv_url)
266
+ # agent = RethinkAgent(df=df)
267
 
268
+ # if not agent.initialize_model(API_KEYS):
269
+ # return {"error": "API configuration failed"}
270
 
271
+ # result = agent.execute_query(query)
272
 
273
+ # if "error" in result:
274
+ # return result
275
 
276
+ # return {
277
+ # "message": result["text"],
278
+ # "csv_files": result["csv_files"],
279
+ # "image_files": result["image_files"]
280
+ # }
281
+ # except Exception as e:
282
+ # logger.error(f"Processing failed: {str(e)}")
283
+ # return {
284
+ # "error": f"Processing error: {str(e)}",
285
+ # "csv_files": [],
286
+ # "image_files": []
287
+ # }
288
 
289
 
290
+ # async def generate_csv_report(csv_url: str, query: str) -> FileBoxProps:
291
+ # try:
292
+ # result = gemini_llm_chat(csv_url, query)
293
+ # logger.info(f"Raw result from gemini_llm_chat: {result}")
294
 
295
+ # csv_files = []
296
+ # image_files = []
297
 
298
+ # # Check if we got the expected response structure
299
+ # if isinstance(result, dict) and 'csv_files' in result and 'image_files' in result:
300
+ # # Process CSV files
301
+ # for csv_path in result['csv_files']:
302
+ # if os.path.exists(csv_path):
303
+ # file_name = os.path.basename(csv_path)
304
+ # try:
305
+ # unique_file_name = f"{uuid.uuid4()}_{file_name}"
306
+ # public_url = await upload_file_to_supabase(
307
+ # file_path=csv_path,
308
+ # file_name=unique_file_name
309
+ # )
310
+ # csv_files.append(FileProps(
311
+ # fileName=file_name,
312
+ # filePath=public_url,
313
+ # fileType="csv"
314
+ # ))
315
+ # os.remove(csv_path) # Clean up
316
+ # except Exception as upload_error:
317
+ # logger.error(f"Failed to upload CSV {file_name}: {str(upload_error)}")
318
+ # continue
319
+
320
+ # # Process image files
321
+ # for img_path in result['image_files']:
322
+ # if os.path.exists(img_path):
323
+ # file_name = os.path.basename(img_path)
324
+ # try:
325
+ # unique_file_name = f"{uuid.uuid4()}_{file_name}"
326
+ # public_url = await upload_file_to_supabase(
327
+ # file_path=img_path,
328
+ # file_name=unique_file_name
329
+ # )
330
+ # image_files.append(FileProps(
331
+ # fileName=file_name,
332
+ # filePath=public_url,
333
+ # fileType="image"
334
+ # ))
335
+ # os.remove(img_path) # Clean up
336
+ # except Exception as upload_error:
337
+ # logger.error(f"Failed to upload image {file_name}: {str(upload_error)}")
338
+ # continue
339
+
340
+ # return FileBoxProps(
341
+ # files=Files(
342
+ # csv_files=csv_files,
343
+ # image_files=image_files
344
+ # )
345
+ # )
346
+ # else:
347
+ # raise ValueError("Unexpected response format from gemini_llm_chat")
348
 
349
+ # except Exception as e:
350
+ # logger.error(f"Report generation failed: {str(e)}")
351
+ # # Return empty response but log the files we found
352
+ # if 'csv_files' in locals() and 'image_files' in locals():
353
+ # logger.info(f"Files that were generated but not processed: CSV: {result.get('csv_files', [])}, Images: {result.get('image_files', [])}")
354
+ # return FileBoxProps(
355
+ # files=Files(
356
+ # csv_files=[],
357
+ # image_files=[]
358
+ # )
359
+ # )
360
+
361
+
362
+
363
+
364
+
365
+
366
+
367
+
368
+
369
+
370
+
371
+
372
+
373
+
374
+
375
+
376
+
377
+
378
+
379
+
380
+
381
+
382
+
383
+
384
+
385
+
386
+
387
+ # Newly modified code with OpenAI
388
+
389
+ # Import necessary modules
390
+ import asyncio
391
+ import os
392
+ import threading
393
+ from typing import Any, Dict, Union
394
+ import uuid
395
+ from fastapi.encoders import jsonable_encoder
396
+ from langchain_openai import ChatOpenAI
397
+ import numpy as np
398
+ import pandas as pd
399
+ from pandasai import SmartDataframe
400
+ from langchain_groq.chat_models import ChatGroq
401
+ from dotenv import load_dotenv
402
+ from pydantic import BaseModel
403
+ from csv_service import clean_data, extract_chart_filenames
406
+ from langchain_experimental.tools import PythonAstREPLTool
407
+ from langchain_experimental.agents import create_pandas_dataframe_agent
409
+ import matplotlib.pyplot as plt
410
+ import matplotlib
411
+ import seaborn as sns
412
+ from gemini_langchain_agent import langchain_gemini_csv_handler
413
+ from supabase_service import upload_file_to_supabase
414
+ from util_service import _prompt_generator, process_answer
416
+ import logging
417
+ matplotlib.use('Agg')
418
+
419
+
420
+ load_dotenv()
421
+
422
+ image_file_path = os.getenv("IMAGE_FILE_PATH")
423
+ image_not_found = os.getenv("IMAGE_NOT_FOUND")
424
+ allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")
425
+
426
+
427
+ # Load environment variables
428
+ groq_api_keys = os.getenv("GROQ_API_KEYS").split(",")
429
+ model_name = os.getenv("GROQ_LLM_MODEL")
430
+
431
+ openai_api_keys = os.getenv("OPENAI_API_KEYS").split(",")
433
+ openai_api_base = os.getenv("OPENAI_BASE_URL")
434
+
435
+ # Set up logging
436
+ logging.basicConfig(level=logging.INFO)
437
+ logger = logging.getLogger(__name__)
438
+
439
+ class CsvUrlRequest(BaseModel):
440
+ csv_url: str
441
+
442
+ class ImageRequest(BaseModel):
443
+ image_path: str
444
+
445
+ class CsvCommonHeadersRequest(BaseModel):
446
+ file_urls: list[str]
447
+
448
+ class CsvsMergeRequest(BaseModel):
449
+ file_urls: list[str]
450
+ merge_type: str
451
+ common_columns_name: list[str]
452
+
453
+ # Thread-safe key management for openai_chat
454
+ current_openai_key_index = 0
455
+ current_openai_key_lock = threading.Lock()
456
+
457
+ # Thread-safe key management for groq_chat
458
+ current_groq_key_index = 0
459
+ current_groq_key_lock = threading.Lock()
460
+
461
+ # Thread-safe key management for langchain_csv_chat
462
+ current_langchain_key_index = 0
463
+ current_langchain_key_lock = threading.Lock()
464
+
465
+
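The three index/lock pairs above all implement the same round-robin idea. A minimal sketch of that pattern, with hypothetical names (api_keys, next_key); the real functions below keep a separate counter per caller:

import threading

api_keys = ["key-A", "key-B", "key-C"]  # hypothetical placeholder keys
_key_index = 0
_key_lock = threading.Lock()

def next_key() -> str:
    # Read and advance the shared index atomically so concurrent
    # requests never hand out the same position twice.
    global _key_index
    with _key_lock:
        key = api_keys[_key_index % len(api_keys)]
        _key_index += 1
    return key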
466
+ # CHAT CODING STARTS FROM HERE
467
+ def handle_out_of_range_float(value):
468
+ if isinstance(value, float):
469
+ if np.isnan(value):
470
+ return None
471
+ elif np.isinf(value):
472
+ return "Infinity"
473
+ return value
474
+
475
+
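handle_out_of_range_float exists because NaN and infinity are not valid JSON. A small illustration, assuming strict encoding with allow_nan=False:

import json

raw = {"mean_fare": float("nan"), "max_fare": float("inf"), "rows": 891}
safe = {k: handle_out_of_range_float(v) for k, v in raw.items()}

# json.dumps(raw, allow_nan=False) would raise ValueError;
# the sanitized dict serializes cleanly instead.
print(json.dumps(safe, allow_nan=False))
# {"mean_fare": null, "max_fare": "Infinity", "rows": 891}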
476
+ # Modified groq_chat function with thread-safe key rotation
477
+ def groq_chat(csv_url: str, question: str):
478
+ global current_groq_key_index, current_groq_key_lock
479
+
480
+ while True:
481
+ with current_groq_key_lock:
482
+ if current_groq_key_index >= len(groq_api_keys):
483
+ return {"error": "All API keys exhausted."}
484
+ current_api_key = groq_api_keys[current_groq_key_index]
485
+
486
+ try:
487
+ # Delete cache file if exists
488
+ cache_db_path = "/workspace/cache/cache_db_0.11.db"
489
+ if os.path.exists(cache_db_path):
490
+ try:
491
+ os.remove(cache_db_path)
492
+ except Exception as e:
493
+ print(f"Error deleting cache DB file: {e}")
494
+
495
+ data = clean_data(csv_url)
496
+ llm = ChatGroq(model=model_name, api_key=current_api_key)
497
+ # Generate unique filename using UUID
498
+ chart_filename = f"chart_{uuid.uuid4()}.png"
499
+ chart_path = os.path.join("generated_charts", chart_filename)
500
+
501
+ # Configure SmartDataframe with chart settings
502
+ df = SmartDataframe(
503
+ data,
504
+ config={
505
+ 'llm': llm,
506
+ 'save_charts': True, # Enable chart saving
507
+ 'open_charts': False,
508
+ 'save_charts_path': os.path.dirname(chart_path), # Directory to save
509
+ 'custom_chart_filename': chart_filename # Unique filename
510
+ }
511
+ )
512
+
513
+ answer = df.chat(question)
514
+
515
+ # Process different response types
516
+ if isinstance(answer, pd.DataFrame):
517
+ processed = answer.applymap(handle_out_of_range_float).to_dict(orient="records")  # element-wise, not per-column
518
+ elif isinstance(answer, pd.Series):
519
+ processed = answer.apply(handle_out_of_range_float).to_dict()
520
+ elif isinstance(answer, list):
521
+ processed = [handle_out_of_range_float(item) for item in answer]
522
+ elif isinstance(answer, dict):
523
+ processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
524
+ else:
525
+ processed = {"answer": str(handle_out_of_range_float(answer))}
526
+
527
+ return processed
528
+
529
+ except Exception as e:
530
+ error_message = str(e)
531
+ if error_message:
532
+ with current_groq_key_lock:
533
+ current_groq_key_index += 1
534
+ if current_groq_key_index >= len(groq_api_keys):
535
+ print("All API keys exhausted.")
536
+ return None
537
+ else:
538
+ print(f"Error with API key index {current_groq_key_index}: {error_message}")
539
+ return None
540
+
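A usage sketch for groq_chat with placeholder values; it assumes GROQ_API_KEYS and GROQ_LLM_MODEL are set in the environment:

answer = groq_chat(
    "https://example.com/data/titanic.csv",      # placeholder CSV URL
    "What is the average fare per passenger class?",
)
# A JSON-safe dict/list on success, or None / {"error": ...}
# once every configured Groq key has been tried.
print(answer)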
541
+
542
+
543
+
544
+
545
+
546
+
547
+ # Modified langchain_csv_chat with thread-safe key rotation
548
+ def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
549
+ global current_langchain_key_index, current_langchain_key_lock
550
+
551
+ data = clean_data(csv_url)
552
+ attempts = 0
553
+
554
+ while attempts < len(groq_api_keys):
555
+ with current_langchain_key_lock:
556
+ if current_langchain_key_index >= len(groq_api_keys):
557
+ current_langchain_key_index = 0
558
+ api_key = groq_api_keys[current_langchain_key_index]
559
+ current_key = current_langchain_key_index
560
+ current_langchain_key_index += 1
561
+ attempts += 1
562
+
563
+ try:
564
+ llm = ChatGroq(model=model_name, api_key=api_key)
565
+ tool = PythonAstREPLTool(locals={
566
+ "df": data,
567
+ "pd": pd,
568
+ "np": np,
569
+ "plt": plt,
570
+ "sns": sns,
571
+ "matplotlib": matplotlib
572
+ })
573
+
574
+ agent = create_pandas_dataframe_agent(
575
+ llm,
576
+ data,
577
+ agent_type="tool-calling",
578
+ verbose=True,
579
+ allow_dangerous_code=True,
580
+ extra_tools=[tool],
581
+ return_intermediate_steps=True
582
+ )
583
+
584
+ prompt = _prompt_generator(question, chart_required)
585
+ result = agent.invoke({"input": prompt})
586
+ return result.get("output")
587
+
588
+ except Exception as e:
589
+ print(f"Error with key index {current_key}: {str(e)}")
590
+
591
+ # If all keys are exhausted, return None
592
+ print("All API keys have been exhausted.")
593
+ return None
594
+
595
+
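A hedged usage sketch for langchain_csv_chat; the URL and question are placeholders, and chart_required=False keeps the agent in text-only mode, as in the csv_chat orchestrator further down:

answer = langchain_csv_chat(
    "https://example.com/data/sales.csv",        # placeholder CSV URL
    "Which product line has the highest revenue?",
    chart_required=False,
)
# Returns the agent's "output" string, or None once every Groq key has failed.
print(answer)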
596
+ def handle_out_of_range_float(value):
597
+ if isinstance(value, float):
598
+ if np.isnan(value):
599
+ return None
600
+ elif np.isinf(value):
601
+ return "Infinity"
602
+ return value
603
+
604
+
605
+
606
+
607
+
608
+
609
+
610
+ # CHART CODING STARTS FROM HERE
611
+
612
+ instructions = """
613
+
614
+ - Please ensure that each value is clearly visible. You may need to adjust the font size, rotate the labels, or truncate long values to improve readability.
615
+ - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
616
+ - Use colorblind-friendly palette
617
+ - Read the above instructions and follow them.
618
+
619
+ """
620
+
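The instructions above ask the model for readable labels, grid layouts and a colorblind-friendly palette; a purely illustrative matplotlib/seaborn sketch of that kind of output, with made-up data:

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_palette("colorblind")                     # colorblind-friendly colors
fig, axes = plt.subplots(2, 2, figsize=(10, 8))   # 2x2 grid of charts

regions = ["North", "South", "East", "West"]
sales = [120, 95, 140, 80]

axes[0, 0].bar(regions, sales)
axes[0, 0].set_title("Sales by region")
axes[0, 0].tick_params(axis="x", rotation=45)     # rotate labels for readability

# ...the remaining panels would be filled in the same way...
fig.tight_layout()
fig.savefig("example_grid.png")
plt.close(fig)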
621
+ # Thread-safe configuration for chart endpoints
622
+ current_groq_chart_key_index = 0
623
+ current_groq_chart_lock = threading.Lock()
624
+
625
+ current_langchain_chart_key_index = 0
626
+ current_langchain_chart_lock = threading.Lock()
627
+
628
+ def model():
629
+ global current_groq_chart_key_index, current_groq_chart_lock
630
+ with current_groq_chart_lock:
631
+ if current_groq_chart_key_index >= len(groq_api_keys):
632
+ raise Exception("All API keys exhausted for chart generation")
633
+ api_key = groq_api_keys[current_groq_chart_key_index]
634
+ return ChatGroq(model=model_name, api_key=api_key)
635
+
636
+ def groq_chart(csv_url: str, question: str):
637
+ global current_groq_chart_key_index, current_groq_chart_lock
638
+
639
+ for attempt in range(len(groq_api_keys)):
640
+ try:
641
+ # Clean cache before processing
642
+ cache_db_path = "/workspace/cache/cache_db_0.11.db"
643
+ if os.path.exists(cache_db_path):
644
+ try:
645
+ os.remove(cache_db_path)
646
+ except Exception as e:
647
+ print(f"Cache cleanup error: {e}")
648
+
649
+ data = clean_data(csv_url)
650
+ with current_groq_chart_lock:
651
+ current_api_key = groq_api_keys[current_groq_chart_key_index]
652
+
653
+ llm = ChatGroq(model=model_name, api_key=current_api_key)
654
+
655
+ # Generate unique filename using UUID
656
+ chart_filename = f"chart_{uuid.uuid4()}.png"
657
+ chart_path = os.path.join("generated_charts", chart_filename)
658
+
659
+ # Configure SmartDataframe with chart settings
660
+ df = SmartDataframe(
661
+ data,
662
+ config={
663
+ 'llm': llm,
664
+ 'save_charts': True, # Enable chart saving
665
+ 'open_charts': False,
666
+ 'save_charts_path': os.path.dirname(chart_path), # Directory to save
667
+ 'custom_chart_filename': chart_filename # Unique filename
668
+ }
669
+ )
670
+
671
+ answer = df.chat(question + instructions)
672
+
673
+ if process_answer(answer):
674
+ return "Chart not generated"
675
+ return answer
676
+
677
+ except Exception as e:
678
+ error = str(e)
679
+ if "429" in error or error is not None:
680
+ with current_groq_chart_lock:
681
+ current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
682
+ else:
683
+ print(f"Chart generation error: {error}")
684
+ return {"error": error}
685
+
686
+ print("All API keys exhausted for chart generation")
687
+ return None
688
+
689
+
690
+
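A usage sketch for groq_chart; on success it is expected to return a local chart path (that is how csv_chart below consumes it), while "Chart not generated" or None signal failure. Values are placeholders:

chart_path = groq_chart(
    "https://example.com/data/titanic.csv",      # placeholder CSV URL
    "Plot passenger count per embarkation port as a bar chart",
)
if chart_path and chart_path != "Chart not generated":
    print(f"Chart saved locally at: {chart_path}")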
691
+ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
692
+ global current_langchain_chart_key_index, current_langchain_chart_lock
693
+
694
+ data = clean_data(csv_url)
695
+
696
+ for attempt in range(len(groq_api_keys)):
697
+ try:
698
+ with current_langchain_chart_lock:
699
+ api_key = groq_api_keys[current_langchain_chart_key_index]
700
+ current_key = current_langchain_chart_key_index
701
+ current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
702
+
703
+ llm = ChatGroq(model=model_name, api_key=api_key)
704
+ tool = PythonAstREPLTool(locals={
705
+ "df": data,
706
+ "pd": pd,
707
+ "np": np,
708
+ "plt": plt,
709
+ "sns": sns,
710
+ "matplotlib": matplotlib,
711
+ "uuid": uuid
712
+ })
713
+
714
+ agent = create_pandas_dataframe_agent(
715
+ llm,
716
+ data,
717
+ agent_type="tool-calling",
718
+ verbose=True,
719
+ allow_dangerous_code=True,
720
+ extra_tools=[tool],
721
+ return_intermediate_steps=True
722
+ )
723
+
724
+ result = agent.invoke({"input": _prompt_generator(f"{question} and use this csv_url: {csv_url} to read the csv file", True)})
725
+ output = result.get("output", "")
726
+
727
+ # Verify chart file creation
728
+ chart_files = extract_chart_filenames(output)
729
+ if len(chart_files) > 0:
730
+ return chart_files
731
+
732
+ if attempt < len(groq_api_keys) - 1:
733
+ print(f"Langchain chart error (key {current_key}): {output}")
734
+
735
+ except Exception as e:
736
+ print(f"Langchain chart error (key {current_key}): {str(e)}")
737
+
738
+ print("All API keys exhausted for chart generation")
739
+ return None
740
+
741
+
742
+ ####################################### OpenAI + PandasAI #######################################
743
+
744
+
745
+
746
+
747
+ # Modified openai_chat function with thread-safe key rotation
748
+ openai_model_name = 'gpt-4o'
749
+
750
+ def openai_chat(csv_url: str, question: str):
751
+ global current_openai_key_index, current_openai_key_lock
752
+
753
+ while True:
754
+ with current_openai_key_lock:
755
+ if current_openai_key_index >= len(openai_api_keys):
756
+ return {"error": "All API keys exhausted."}
757
+ current_api_key = openai_api_keys[current_openai_key_index]
758
+
759
+ try:
760
+ # Delete cache file if exists
761
+ cache_db_path = "/workspace/cache/cache_db_0.11.db"
762
+ if os.path.exists(cache_db_path):
763
+ try:
764
+ os.remove(cache_db_path)
765
+ except Exception as e:
766
+ print(f"Error deleting cache DB file: {e}")
767
+
768
+ data = clean_data(csv_url)
769
+ llm = ChatOpenAI(model=openai_model_name, api_key=current_api_key, base_url=openai_api_base)
770
+ # Generate unique filename using UUID
771
+ chart_filename = f"chart_{uuid.uuid4()}.png"
772
+ chart_path = os.path.join("generated_charts", chart_filename)
773
 
774
+ # Configure SmartDataframe with chart settings
775
+ df = SmartDataframe(
776
+ data,
777
+ config={
778
+ 'llm': llm,
779
+ 'save_charts': True, # Enable chart saving
780
+ 'open_charts': False,
781
+ 'save_charts_path': os.path.dirname(chart_path), # Directory to save
782
+ 'custom_chart_filename': chart_filename # Unique filename
783
+ }
784
  )
785
 
786
+ answer = df.chat(question)
787
+ # Process different response types
788
+ if isinstance(answer, pd.DataFrame):
789
+ processed = answer.applymap(handle_out_of_range_float).to_dict(orient="records")  # element-wise, not per-column
790
+ elif isinstance(answer, pd.Series):
791
+ processed = answer.apply(handle_out_of_range_float).to_dict()
792
+ elif isinstance(answer, list):
793
+ processed = [handle_out_of_range_float(item) for item in answer]
794
+ elif isinstance(answer, dict):
795
+ processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
796
+ else:
797
+ processed = {"answer": str(handle_out_of_range_float(answer))}
798
+
799
+ return processed
800
+
801
+ except Exception as e:
802
+ error_message = str(e)
803
+ if error_message:
804
+ with current_openai_key_lock:
805
+ current_openai_key_index += 1
806
+ if current_openai_key_index >= len(openai_api_keys):
807
+ print("All API keys exhausted.")
808
+ return None
809
+ else:
810
+ print(f"Error with API key index {current_openai_key_index}: {error_message}")
811
+ return None
812
+
813
+
814
+
815
+
816
+
817
+
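A usage sketch for openai_chat, mirroring groq_chat. The values are placeholders, and the call assumes OPENAI_API_KEYS (and optionally OPENAI_BASE_URL, which ChatOpenAI accepts as base_url for an OpenAI-compatible endpoint) are set:

result = openai_chat(
    "https://example.com/data/sales.csv",        # placeholder CSV URL
    "Which month had the highest total revenue?",
)
# Same contract as groq_chat: a JSON-safe structure on success,
# None or {"error": ...} once all OpenAI keys are exhausted.
print(result)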
818
+ def openai_chart(csv_url: str, question: str):
819
+ global current_openai_key_index, current_openai_key_lock
820
+
821
+ while True:
822
+ with current_openai_key_lock:
823
+ if current_openai_key_index >= len(openai_api_keys):
824
+ return {"error": "All API keys exhausted."}
825
+ current_api_key = openai_api_keys[current_openai_key_index]
826
+
827
+ try:
828
+ # Delete cache file if exists
829
+ cache_db_path = "/workspace/cache/cache_db_0.11.db"
830
+ if os.path.exists(cache_db_path):
831
+ try:
832
+ os.remove(cache_db_path)
833
+ except Exception as e:
834
+ print(f"Error deleting cache DB file: {e}")
835
+
836
+ data = clean_data(csv_url)
837
+ llm = ChatOpenAI(model=openai_model_name, api_key=current_api_key, base_url=openai_api_base)
838
+ # Generate unique filename using UUID
839
+ chart_filename = f"chart_{uuid.uuid4()}.png"
840
+ chart_path = os.path.join("generated_charts", chart_filename)
841
+
842
+ # Configure SmartDataframe with chart settings
843
+ df = SmartDataframe(
844
+ data,
845
+ config={
846
+ 'llm': llm,
847
+ 'save_charts': True, # Enable chart saving
848
+ 'open_charts': False,
849
+ 'save_charts_path': os.path.dirname(chart_path), # Directory to save
850
+ 'custom_chart_filename': chart_filename # Unique filename
851
+ }
852
  )
853
+
854
+ answer = df.chat(question + instructions)
855
+
856
+ if process_answer(answer):
857
+ return "Chart not generated"
858
+ return answer
859
+
860
+ except Exception as e:
861
+ error = str(e)
862
+ print(f"Error with API key index {current_openai_key_index}: {error}")
863
+ if "429" in error or error is not None:
864
+ with current_openai_key_lock:
865
+ current_openai_key_index = (current_openai_key_index + 1) % len(openai_api_keys)
866
+ else:
867
+ print(f"Chart generation error: {error}")
868
+ return {"error": error}
869
+
870
+ print("All API keys exhausted for chart generation")
871
+ return None
872
+
873
+
874
+
875
+
876
+
877
+ ####################################### Start with lc_gemini #######################################
878
+
879
+
880
+ # async def csv_chat(csv_url: str, query: str):
881
+ # """
882
+ # Generate a response based on the provided CSV URL and query.
883
+ # Prioritizes LangChain-Gemini, then LangChain-Groq, then raw OpenAI and finally raw Groq as fallback.
884
+
885
+ # Parameters:
886
+ # - csv_url (str): The URL of the CSV file.
887
+ # - query (str): The query for generating the response.
888
+
889
+ # Returns:
890
+ # - dict: A dictionary containing the generated response.
891
+
892
+ # Example:
893
+ # - csv_url: "https://example.com/data.csv"
894
+ # - query: "What is the total sales for the year 2022?"
895
+ # Returns:
896
+ # - dict: {"answer": "The total sales for 2022 is $100,000."}
897
+ # """
898
+ # try:
899
+ # updated_query = f"{query} and Do not show any charts or graphs."
900
+
901
+ # # --- 1. First Attempt: LangChain Gemini ---
902
+ # try:
903
+ # gemini_answer = await asyncio.to_thread(
904
+ # langchain_gemini_csv_handler, csv_url, updated_query, False
905
+ # )
906
+ # print("LangChain-Gemini answer:", gemini_answer)
907
+
908
+ # if not process_answer(gemini_answer) or gemini_answer is None:
909
+ # return {"answer": jsonable_encoder(gemini_answer)}
910
+
911
+ # raise Exception("LangChain-Gemini response not usable, falling back to LangChain-Groq")
912
+
913
+ # except Exception as gemini_error:
914
+ # print(f"LangChain-Gemini error: {str(gemini_error)}")
915
+
916
+ # # --- 2. Second Attempt: LangChain Groq ---
917
+ # try:
918
+ # lang_groq_answer = await asyncio.to_thread(
919
+ # langchain_csv_chat, csv_url, updated_query, False
920
+ # )
921
+ # print("LangChain-Groq answer:", lang_groq_answer)
922
+
923
+ # if not process_answer(lang_groq_answer):
924
+ # return {"answer": jsonable_encoder(lang_groq_answer)}
925
+
926
+ # raise Exception("LangChain-Groq response not usable, falling back to raw Groq")
927
+
928
+ # except Exception as lang_groq_error:
929
+ # print(f"LangChain-Groq error: {str(lang_groq_error)}")
930
+
931
+ # --- 3. Third Attempt: Raw OpenAI Chat ---
932
+ # try:
933
+ # raw_openai_answer = await asyncio.to_thread(openai_chat, csv_url, updated_query)
934
+ # print("Raw OpenAI answer:", raw_openai_answer)
935
+
936
+ # if process_answer(raw_openai_answer) == "Empty response received." or raw_openai_answer is None:
937
+ # return {"answer": "Sorry, I couldn't find relevant data..."}
938
+
939
+ # if process_answer(raw_openai_answer):
940
+ # except Exception as openai_exception:
941
+ # print(f"Raw OpenAI error: {str(openai_exception)}")
942
+
943
+
944
+ # # --- 4. Final Attempt: Raw Groq Chat ---
945
+ # try:
946
+ # raw_groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
947
+ # print("Raw Groq answer:", raw_groq_answer)
948
+
949
+ # if process_answer(raw_groq_answer) == "Empty response received." or raw_groq_answer is None:
950
+ # return {"answer": "Sorry, I couldn't find relevant data..."}
951
+
952
+ # if process_answer(raw_groq_answer):
953
+ # raise Exception("All fallbacks exhausted")
954
+
955
+ # return {"answer": jsonable_encoder(raw_groq_answer)}
956
+
957
+ # except Exception as raw_groq_error:
958
+ # print(f"Raw Groq error: {str(raw_groq_error)}")
959
+ # return {"answer": "error"}
960
+
961
+ # except Exception as e:
962
+ # print(f"Unexpected error: {str(e)}")
963
+ # return {"answer": "error"}
964
+
965
+
966
+
967
+
968
+
969
+
970
+
971
+
972
+ # async def csv_chart(csv_url: str, query: str):
973
+ # """
974
+ # Generate a chart based on the provided CSV URL and query.
975
+ # Prioritizes raw OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
976
+
977
+ # Parameters:
978
+ # - csv_url (str): The URL of the CSV file.
979
+ # - query (str): The query for generating the chart.
980
+
981
+ # Returns:
982
+ # - dict: A dictionary containing either:
983
+ # - {"image_url": "https://example.com/chart.png"} on success, or
984
+ # - {"error": "error message"} on failure
985
+
986
+ # Example:
987
+ # - csv_url: "https://example.com/data.csv"
988
+ # - query: "Show sales trends as a line chart"
989
+ # Returns:
990
+ # - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
991
+ # """
992
+
993
+ # async def upload_and_return(image_path: str) -> dict:
994
+ # """Helper function to handle image uploads"""
995
+ # unique_name = f'{uuid.uuid4()}.png'
996
+ # public_url = await upload_file_to_supabase(image_path, unique_name)
997
+ # print(f"Uploaded chart: {public_url}")
998
+ # os.remove(image_path) # Remove the local image file after upload
999
+ # return {"image_url": public_url}
1000
+
1001
+ # try:
1002
+ # # --- 1. First Attempt: Raw OpenAI ---
1003
+ # try:
1004
+ # openai_result = await asyncio.to_thread(openai_chart, csv_url, query)
1005
+ # print(f"OpenAI chart result:", openai_result)
1006
+
1007
+ # if openai_result and openai_result != 'Chart not generated':
1008
+ # return await upload_and_return(openai_result)
1009
+
1010
+ # raise Exception("OpenAI failed to generate chart")
1011
+
1012
+ # except Exception as openai_error:
1013
+ # print(f"OpenAI failed ({str(openai_error)}), trying LangChain Gemini...")
1014
+
1015
+ # # --- 2.. First Attempt: Raw Groq ---
1016
+ # try:
1017
+ # groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
1018
+ # print(f"Raw Groq chart result:", groq_result)
1019
+
1020
+ # if groq_result and groq_result != 'Chart not generated':
1021
+ # return await upload_and_return(groq_result)
1022
+
1023
+ # raise Exception("Raw Groq failed to generate chart")
1024
+
1025
+ # except Exception as groq_error:
1026
+ # print(f"Raw Groq failed ({str(groq_error)}), trying LangChain Gemini...")
1027
+
1028
+ # --- 3. Third Attempt: LangChain Gemini ---
1029
+ # try:
1030
+ # gemini_result = await asyncio.to_thread(
1031
+ # langchain_gemini_csv_handler, csv_url, query, True
1032
+ # )
1033
+ # print("LangChain Gemini chart result:", gemini_result)
1034
+
1035
+ # # --- i) If Gemini result is a string, return it ---
1036
+ # if gemini_result and isinstance(gemini_result, str):
1037
+ # clean_path = gemini_result.strip()
1038
+ # return await upload_and_return(clean_path)
1039
+
1040
+ # # --- ii) If Gemini result is a list, return the first element ---
1041
+ # if gemini_result and isinstance(gemini_result, list) and len(gemini_result) > 0:
1042
+ # return await upload_and_return(gemini_result[0])
1043
+
1044
+ # raise Exception("LangChain Gemini returned empty result")
1045
+
1046
+ # except Exception as gemini_error:
1047
+ # print(f"LangChain Gemini failed ({str(gemini_error)}), trying LangChain Groq...")
1048
+
1049
+ # # --- 4. Final Attempt: LangChain Groq ---
1050
+ # try:
1051
+ # lc_groq_paths = await asyncio.to_thread(
1052
+ # langchain_csv_chart, csv_url, query, True
1053
+ # )
1054
+ # print("LangChain Groq chart result:", lc_groq_paths)
1055
+
1056
+ # if isinstance(lc_groq_paths, list) and lc_groq_paths:
1057
+ # return await upload_and_return(lc_groq_paths[0])
1058
+
1059
+ # return {"error": "All chart generation methods failed"}
1060
+
1061
+ # except Exception as lc_groq_error:
1062
+ # print(f"LangChain Groq failed: {str(lc_groq_error)}")
1063
+ # return {"error": "Could not generate chart"}
1064
+
1065
+ # except Exception as e:
1066
+ # print(f"Critical error: {str(e)}")
1067
+ # return {"error": "Internal system error"}
1068
+
1069
+
1070
+
1071
+ ####################################### Optimized Version #######################################
1072
+
1073
+
1074
+ async def csv_chat(csv_url: str, query: str) -> Dict[str, Any]:
1075
+ """
1076
+ Generate a response based on the provided CSV URL and query.
1077
+ Prioritizes LangChain-Gemini, then LangChain-Groq, then raw OpenAI and finally raw Groq as fallback.
1078
+
1079
+ Parameters:
1080
+ - csv_url (str): The URL of the CSV file.
1081
+ - query (str): The query for generating the response.
1082
+
1083
+ Returns:
1084
+ - dict: A dictionary containing the generated response or error message.
1085
+
1086
+ Example:
1087
+ - csv_url: "https://example.com/data.csv"
1088
+ - query: "What is the total sales for the year 2022?"
1089
+ Returns:
1090
+ - dict: {"answer": "The total sales for 2022 is $100,000."}
1091
+ """
1092
+ updated_query = f"{query} and Do not show any charts or graphs."
1093
+ fallback_answer = "Sorry, I couldn't find relevant data..."
1094
+ error_answer = "An error occurred while processing your request."
1095
+
1096
+ async def try_chat_method(method_name: str, method, *args) -> Dict[str, Any]:
1097
+ """Attempt to get answer from a specific chat method"""
1098
+ try:
1099
+ logger.info(f"Attempting {method_name}")
1100
+ answer = await asyncio.to_thread(method, *args)
1101
+
1102
+ if answer is None:
1103
+ logger.warning(f"{method_name} returned None")
1104
+ return {"status": "empty", "answer": None}
1105
+
1106
+ processed = process_answer(answer)
1107
+ if processed == "Empty response received.":
1108
+ logger.warning(f"{method_name} returned empty response")
1109
+ return {"status": "empty", "answer": answer}
1110
+ elif processed:
1111
+ logger.warning(f"{method_name} response not usable")
1112
+ return {"status": "invalid", "answer": answer}
1113
+ else:
1114
+ logger.info(f"{method_name} succeeded")
1115
+ return {"status": "success", "answer": answer}
1116
+
1117
+ except Exception as e:
1118
+ logger.error(f"{method_name} failed: {str(e)}")
1119
+ return {"status": "error", "error": str(e)}
1120
+
1121
+ # Define the methods to try in priority order
1122
+ chat_methods = [
1123
+ ("LangChain-Gemini", langchain_gemini_csv_handler, csv_url, updated_query, False),
1124
+ ("LangChain-Groq", langchain_csv_chat, csv_url, updated_query, False),
1125
+ ("Raw OpenAI", openai_chat, csv_url, updated_query),
1126
+ ("Raw Groq", groq_chat, csv_url, updated_query),
1127
+ ]
1128
+
1129
+ for method_name, method, *args in chat_methods:
1130
+ result = await try_chat_method(method_name, method, *args)
1131
+
1132
+ if result["status"] == "success":
1133
+ return {"answer": jsonable_encoder(result["answer"])}
1134
+ elif result["status"] == "empty":
1135
+ return {"answer": fallback_answer}
1136
+
1137
+ # If all methods failed or returned invalid responses
1138
+ logger.error("All chat methods failed to produce a valid response")
1139
+ return {"answer": error_answer}
1140
+
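Since csv_chat is async, a minimal sketch of driving it from a script (placeholder URL and query; requires the environment keys used above):

import asyncio

async def _demo_chat() -> None:
    response = await csv_chat(
        "https://example.com/data/titanic.csv",   # placeholder CSV URL
        "How many passengers survived, grouped by class?",
    )
    print(response["answer"])

# asyncio.run(_demo_chat())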
1141
+
1142
+
1143
+
1144
+
1145
+
1146
+
1147
+
1148
+
1149
+ async def csv_chart(csv_url: str, query: str) -> Dict[str, str]:
1150
+ """
1151
+ Generate a chart based on the provided CSV URL and query.
1152
+ Prioritizes raw OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
1153
+
1154
+ Parameters:
1155
+ - csv_url (str): The URL of the CSV file.
1156
+ - query (str): The query for generating the chart.
1157
+
1158
+ Returns:
1159
+ - dict: A dictionary containing either:
1160
+ - {"image_url": "https://example.com/chart.png"} on success, or
1161
+ - {"error": "error message"} on failure
1162
+
1163
+ Example:
1164
+ - csv_url: "https://example.com/data.csv"
1165
+ - query: "Show sales trends as a line chart"
1166
+ Returns:
1167
+ - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
1168
+ """
1169
+ async def upload_and_return(image_path: str) -> Dict[str, str]:
1170
+ """Helper function to handle image uploads and cleanup"""
1171
+ try:
1172
+ if not os.path.exists(image_path):
1173
+ raise FileNotFoundError(f"Image file not found at {image_path}")
1174
+
1175
+ unique_name = f'{uuid.uuid4()}.png'
1176
+ public_url = await upload_file_to_supabase(image_path, unique_name)
1177
+ logger.info(f"Uploaded chart: {public_url}")
1178
+
1179
+ try:
1180
+ os.remove(image_path)
1181
+ except OSError as e:
1182
+ logger.warning(f"Failed to remove local image file: {e}")
1183
+
1184
+ return {"image_url": public_url}
1185
+ except Exception as e:
1186
+ logger.error(f"Error in upload_and_return: {e}")
1187
+ raise e
1188
+
1189
+ async def try_generation(method_name: str, method, *args) -> Union[str, None]:
1190
+ """Attempt chart generation with a specific method"""
1191
+ try:
1192
+ logger.info(f"Attempting chart generation with {method_name}")
1193
+ result = await asyncio.to_thread(method, *args)
1194
+
1195
+ if not result or result == 'Chart not generated':
1196
+ raise ValueError(f"{method_name} returned empty or invalid result")
1197
+
1198
+ if isinstance(result, str):
1199
+ return result.strip()
1200
+ elif isinstance(result, list) and result:
1201
+ return result[0]
1202
+
1203
+ raise ValueError(f"{method_name} returned unexpected result type")
1204
+ except Exception as e:
1205
+ logger.warning(f"{method_name} failed: {str(e)}")
1206
+ return None
1207
+
1208
+ generation_methods = [
1209
+ ("Raw OpenAI", openai_chart, csv_url, query),
1210
+ ("Raw Groq", groq_chart, csv_url, query),
1211
+ ("LangChain Gemini", lambda u, q: langchain_gemini_csv_handler(u, q, True), csv_url, query),
1212
+ ("LangChain Groq", lambda u, q: langchain_csv_chart(u, q, True), csv_url, query),
1213
+ ]
1214
+
1215
+ for attempt, (method_name, method, *args) in enumerate(generation_methods, 1):
1216
+ try:
1217
+ result = await try_generation(method_name, method, *args)
1218
+ if result:
1219
+ return await upload_and_return(result)
1220
+ except Exception as e:
1221
+ logger.error(f"Error processing {method_name}: {e}")
1222
+ if attempt == len(generation_methods):
1223
+ logger.error("All chart generation methods failed")
1224
+ return {"error": "Could not generate chart using any available method"}
1225
+
1226
+ return {"error": "All chart generation methods failed"}
1227
 
1228
 
1229
+ # Example usage:
1230
 
1231
+ # csv_url = './documents/titanic.csv'
1232
+ # query = "Create a pie chart of male vs female passengers?"
1233
+ # result = openai_chart(csv_url, query)
1234
+ # print(result)
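A comparable sketch for the async chart orchestrator, kept commented out like the example above; it assumes the same local CSV and a configured Supabase bucket for the upload step:

# import asyncio
# chart = asyncio.run(csv_chart('./documents/titanic.csv', "Create a pie chart of male vs female passengers?"))
# print(chart)  # {"image_url": "..."} on success, {"error": "..."} on failure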