Soumik555 committed on
Commit
27ef145
·
1 Parent(s): 4ee7a1a

added gemini too

Browse files
Files changed (1) hide show
  1. gemini_report_generator.py +295 -0
gemini_report_generator.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import pandas as pd
4
+ import re
5
+ import os
6
+ import uuid
7
+ import logging
8
+ from io import StringIO
9
+ import sys
10
+ import traceback
11
+ from typing import Optional, Dict, Any, List
12
+ from pydantic import BaseModel, Field
13
+ from google.generativeai import GenerativeModel, configure
14
+ from dotenv import load_dotenv
15
+ import seaborn as sns
16
+ import datetime as dt
17
+
18
# Set each pandas display limit to None so printed DataFrames are not
# truncated by column count, row count or cell width.
for _display_option in (
    'display.max_columns',
    'display.max_rows',
    'display.max_colwidth',
):
    pd.set_option(_display_option, None)
del _display_option

# Load environment variables (python-dotenv) before any os.getenv lookups below.
load_dotenv()
23
+
24
+
25
# Comma-separated Gemini API keys read from GEMINI_API_KEYS, in reverse of the
# listed order. Entries are stripped and blank ones dropped: "".split(",")
# yields [""], which would otherwise look like one usable (empty) key when the
# variable is unset or contains stray commas.
API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")[::-1]
    if key.strip()
]
MODEL_NAME = 'gemini-2.0-flash'
27
+
28
# Ask matplotlib for the non-interactive Agg backend so figures can be rendered
# to files without a display.
os.environ['MPLBACKEND'] = 'agg'
import matplotlib.pyplot as plt

# seaborn is imported earlier in this file and loads matplotlib itself, so the
# environment variable above may be read too late to take effect; select the
# backend explicitly as well.
plt.switch_backend('agg')

# Generated code may call plt.show(), possibly with arguments such as
# block=False; make it a no-op that accepts any call signature.
plt.show = lambda *args, **kwargs: None

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
37
+
38
class GeminiKeyManager:
    """Manage multiple Gemini API keys with failover.

    Keys are consumed from the front of ``available_keys``; each call to
    :meth:`configure` activates the next key that the SDK accepts.
    """

    def __init__(self, api_keys: List[str]):
        """Store the candidate keys.

        Args:
            api_keys: Candidate API keys in the order they should be tried.
                The list itself is not modified.
        """
        self.original_keys = api_keys.copy()
        # Blank entries (e.g. from splitting an empty environment variable)
        # can never authenticate, so they are excluded from the rotation.
        self.available_keys = [key.strip() for key in api_keys if key and key.strip()]
        self.active_key = None
        # Maps each rejected key to the error text it produced.
        self.failed_keys = {}

    def configure(self) -> bool:
        """Activate the next usable key.

        Returns:
            True once a key has been configured, False when every remaining
            key raised during configuration.
        """
        # NOTE(review): the SDK's configure() may only store the key without
        # contacting the service, in which case a bad key surfaces at request
        # time rather than here - confirm.
        while self.available_keys:
            key = self.available_keys.pop(0)
            try:
                # Module-level google.generativeai.configure, not this method.
                configure(api_key=key)
            except Exception as e:
                self.failed_keys[key] = str(e)
                logger.error(f"Key failed: {self._mask_key(key)}. Error: {str(e)}")
                continue
            self.active_key = key
            logger.info(f"Configured with key: {self._mask_key(key)}")
            return True
        logger.critical("All API keys failed")
        return False

    def _mask_key(self, key: str) -> str:
        """Return a log-safe rendering of ``key``.

        Long keys keep their first 8 and last 4 characters. Keys of 12
        characters or fewer are fully starred, because the prefix and suffix
        slices would otherwise cover the whole secret.
        """
        if not key:
            return ""
        if len(key) <= 12:
            return "*" * len(key)
        return f"{key[:8]}...{key[-4:]}"
63
+
64
class PythonREPL:
    """Run generated Python against a DataFrame and track the files it writes.

    WARNING: code is executed in-process with ``exec`` and has ``os`` (and
    normal imports) available. This is NOT a sandbox; only run code from a
    source you are prepared to trust.
    """

    def __init__(self, df: pd.DataFrame):
        """Create a per-instance output directory and execution namespace.

        Args:
            df: DataFrame exposed to executed code as ``df`` (a copy is placed
                in the namespace).
        """
        self.df = df
        self.output_dir = os.path.abspath(f'generated_outputs/{uuid.uuid4()}')
        os.makedirs(self.output_dir, exist_ok=True)
        # Names visible to executed code; persists across execute() calls.
        self.local_env = {
            "pd": pd,
            "df": self.df.copy(),
            "plt": plt,
            "os": os,
            "uuid": uuid,
            "sns": sns,
            "json": json,
            "dt": dt,
            "output_dir": self.output_dir
        }

    def execute(self, code: str) -> Dict[str, Any]:
        """Execute ``code`` and report stdout, errors and generated files.

        Args:
            code: Python source to run in this instance's namespace.

        Returns:
            Dict with keys ``output`` (captured stdout), ``error`` (bool),
            ``error_message`` (traceback text or None), ``df`` (the ``df``
            binding after execution), ``output_dir`` and ``files`` (``csv``
            and ``images`` lists of absolute paths; populated only when the
            code ran without error).
        """
        logger.info("Executing code...\n%s", code)
        csv_files = set()
        image_files = set()
        error = False
        error_msg = None

        old_stdout = sys.stdout
        sys.stdout = mystdout = StringIO()
        try:
            # Re-establish the plotting module and backend on every run in
            # case earlier generated code rebound ``plt`` or switched backend.
            self.local_env["plt"] = plt
            plt.switch_backend('agg')
            # Compile the code as given (no wrapper lines) so traceback line
            # numbers match the source that is fed back for correction.
            exec(compile(code, '<generated>', 'exec'), self.local_env)
            self.df = self.local_env.get('df', self.df)

            # Track generated files (extension match is case-insensitive).
            for fname in os.listdir(self.output_dir):
                lowered = fname.lower()
                if lowered.endswith('.csv'):
                    csv_files.add(fname)
                elif lowered.endswith(('.png', '.jpg', '.jpeg')):
                    image_files.add(fname)
        except (Exception, SystemExit):
            # SystemExit is included so exit()/sys.exit() in generated code is
            # reported as a failure instead of terminating the host process.
            error_msg = traceback.format_exc()
            error = True
        finally:
            sys.stdout = old_stdout
            # Close figures even when the code raised, so failed attempts do
            # not accumulate open figures.
            plt.close('all')

        return {
            "output": mystdout.getvalue(),
            "error": error,
            "error_message": error_msg if error else None,
            "df": self.local_env.get('df', self.df),
            "output_dir": self.output_dir,
            "files": {
                "csv": [os.path.join(self.output_dir, f) for f in sorted(csv_files)],
                "images": [os.path.join(self.output_dir, f) for f in sorted(image_files)]
            }
        }
127
+
128
class RethinkAgent(BaseModel):
    """Generate pandas/matplotlib code with Gemini, run it, and retry on failure.

    Each failed execution is fed back to the model (error text plus the code
    that failed) until the code runs cleanly or ``max_retries`` is reached.
    """

    # DataFrame the generated code operates on.
    df: pd.DataFrame
    # Maximum number of failed executions tolerated per query.
    max_retries: int = Field(default=5, ge=1)
    # Set by initialize_model(); None until then.
    gemini_model: Optional[GenerativeModel] = None
    # Failed executions so far for the query currently being processed.
    current_retry: int = Field(default=0, ge=0)
    # Execution environment; created afresh by execute_query().
    repl: Optional[PythonREPL] = None
    # Set by initialize_model(); None until then.
    key_manager: Optional[GeminiKeyManager] = None

    class Config:
        arbitrary_types_allowed = True

    def _describe_columns(self) -> str:
        """Return the DataFrame's column labels as a comma-separated string."""
        # Labels may be non-strings (ints, tuples); str() keeps join() from failing.
        return ', '.join(str(column) for column in self.df.columns)

    def _extract_code(self, response: str) -> str:
        """Pull the code out of a model response.

        Returns the contents of the first fenced block (```python, ```py or a
        bare ``` fence); if there is no fence, the whole response, stripped.
        """
        code_match = re.search(
            r'```(?:python|py)?(.*?)```', response, re.DOTALL | re.IGNORECASE
        )
        return code_match.group(1).strip() if code_match else response.strip()

    def _generate_initial_prompt(self, query: str) -> str:
        """Build the first-attempt prompt for ``query``."""
        columns = self._describe_columns()
        # Doubled braces emit a literal {output_dir} for the generated code's
        # own f-strings to resolve.
        return f"""Generate DIRECT EXECUTION CODE (no functions, no explanations) following STRICT RULES:

MANDATORY REQUIREMENTS:
1. Operate directly on existing 'df' variable
2. Save ALL final DataFrames to CSV using: df.to_csv(f'{{output_dir}}/descriptive_name.csv')
3. For visualizations: plt.savefig(f'{{output_dir}}/chart_name.png')
4. Use EXACTLY this structure:
# Data processing
df_processed = df[...] # filtering/grouping
# Save results
df_processed.to_csv(f'{{output_dir}}/result.csv')
# Visualizations (if needed)
plt.figure()
... plotting code ...
plt.savefig(f'{{output_dir}}/chart.png')
plt.close()

FORBIDDEN:
- Function definitions
- Dummy data creation
- Any code blocks besides pandas operations and matplotlib
- Print statements showing dataframes

DATAFRAME COLUMNS: {columns}
USER QUERY: {query}

EXAMPLE RESPONSE FOR "Sales by region":
# Data processing
sales_by_region = df.groupby('region')['sales'].sum().reset_index()
# Save results
sales_by_region.to_csv(f'{{output_dir}}/sales_by_region.csv')
"""

    def _generate_retry_prompt(self, query: str, error: str, code: str) -> str:
        """Build the correction prompt after a failed execution.

        The user query and column list are repeated here because each prompt
        is sent on its own; without them the model cannot tell what the code
        is meant to do or which columns exist.
        """
        columns = self._describe_columns()
        return f"""FIX THIS CODE (failed with: {error}) by STRICTLY FOLLOWING:

1. REMOVE ALL FUNCTION DEFINITIONS
2. ENSURE DIRECT DF OPERATIONS
3. USE EXPLICIT output_dir PATHS
4. ADD NECESSARY IMPORTS IF MISSING
5. VALIDATE COLUMN NAMES EXIST

DATAFRAME COLUMNS: {columns}
USER QUERY: {query}

BAD CODE:
{code}

CORRECTED CODE:"""

    def initialize_model(self, api_keys: List[str]) -> bool:
        """Configure an API key and create the Gemini model.

        Args:
            api_keys: Candidate API keys, tried in order.

        Returns:
            True when the model was created, False when model construction failed.

        Raises:
            RuntimeError: If no API key could be configured.
        """
        self.key_manager = GeminiKeyManager(api_keys)
        if not self.key_manager.configure():
            raise RuntimeError("API key initialization failed")
        try:
            self.gemini_model = GenerativeModel(MODEL_NAME)
            return True
        except Exception as e:
            logger.error(f"Model init failed: {str(e)}")
            return False

    def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str:
        """Ask the model for code, failing over to the next API key on error.

        Args:
            query: The user's request.
            error: Error text from the previous attempt; when truthy a
                correction prompt is sent instead of the initial prompt.
            previous_code: The code that produced ``error``.

        Returns:
            The code extracted from the model's response.

        Raises:
            RuntimeError: If initialize_model() has not been called.
            Exception: The last generation error once no further key is available.
        """
        if self.gemini_model is None or self.key_manager is None:
            raise RuntimeError("Model not initialized; call initialize_model() first")

        if error:
            prompt = self._generate_retry_prompt(query, error, previous_code)
        else:
            prompt = self._generate_initial_prompt(query)

        while True:
            try:
                response = self.gemini_model.generate_content(prompt)
                return self._extract_code(response.text)
            except Exception as e:
                logger.warning("Code generation failed: %s", e)
                if not (self.key_manager.available_keys and self.key_manager.configure()):
                    raise
                # NOTE(review): a GenerativeModel appears to cache the client
                # it was first used with; rebuild it so the newly configured
                # key is actually applied - confirm against the SDK.
                self.gemini_model = GenerativeModel(MODEL_NAME)

    def execute_query(self, query: str) -> Dict[str, Any]:
        """Generate and run code for ``query``, retrying after failed executions.

        Returns:
            On success: ``text`` (captured stdout), ``csv_files`` and
            ``image_files``. On failure: ``error`` plus empty ``csv_files``
            and ``image_files`` lists.
        """
        self.repl = PythonREPL(self.df)
        # Start each query with a fresh retry budget; otherwise a second call
        # on the same agent inherits the first call's exhausted counter.
        self.current_retry = 0
        last_error = None
        last_code = None

        while self.current_retry < self.max_retries:
            try:
                code = self.generate_code(query, last_error, last_code)
                execution_result = self.repl.execute(code)
            except Exception as e:
                return {
                    "error": f"Critical failure: {str(e)}",
                    "csv_files": [],
                    "image_files": []
                }

            if not execution_result["error"]:
                return {
                    "text": execution_result["output"],
                    "csv_files": execution_result["files"]["csv"],
                    "image_files": execution_result["files"]["images"]
                }

            # Remember what failed so the next prompt asks for a correction.
            self.current_retry += 1
            last_error = execution_result["error_message"]
            last_code = code

        return {
            "error": f"Failed after {self.max_retries} retries",
            "csv_files": [],
            "image_files": []
        }
247
+
248
def gemini_llm_chat(csv_url: str, query: str) -> Dict[str, Any]:
    """Answer ``query`` about the CSV at ``csv_url`` using generated code.

    Args:
        csv_url: Path or URL passed to ``pd.read_csv``.
        query: Natural-language request describing the desired analysis.

    Returns:
        On success: ``message`` (captured stdout of the generated code),
        ``csv_files`` and ``image_files``. On failure: ``error`` plus empty
        ``csv_files`` and ``image_files`` lists. Exceptions are caught and
        reported through the ``error`` key rather than raised.
    """
    try:
        df = pd.read_csv(csv_url)
        agent = RethinkAgent(df=df)

        if not agent.initialize_model(API_KEYS):
            # Same shape as every other error result so callers can read the
            # file lists unconditionally.
            return {
                "error": "API configuration failed",
                "csv_files": [],
                "image_files": []
            }

        result = agent.execute_query(query)

        if "error" in result:
            return result

        return {
            "message": result["text"],
            "csv_files": result["csv_files"],
            "image_files": result["image_files"]
        }
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Processing failed: %s", e)
        return {
            "error": f"Processing error: {str(e)}",
            "csv_files": [],
            "image_files": []
        }
273
+
274
+
275
def generate_csv_report(csv_url: str, query: str) -> str:
    """Run ``gemini_llm_chat`` and return its result as a JSON string.

    Args:
        csv_url: Path or URL of the CSV to analyse.
        query: Natural-language request describing the desired report.

    Returns:
        A JSON document (indent=2). Failures are also returned as JSON with
        an ``error`` key and empty ``csv_files`` / ``image_files`` lists, so
        the return type is a string in every case.
    """
    try:
        result = gemini_llm_chat(csv_url, query)
        json_result = json.dumps(result, indent=2)
    except Exception as e:
        logger.error(f"Report generation failed: {str(e)}")
        return json.dumps(
            {
                "error": f"Report generation error: {str(e)}",
                "csv_files": [],
                "image_files": []
            },
            indent=2,
        )

    # An error result serialises fine, so check it before claiming success.
    if isinstance(result, dict) and "error" in result:
        logger.error(f"Report generation failed: {json_result}")
    else:
        logger.info(f"Report generated successfully: {json_result}")
    return json_result
289
+
290
+
291
+
292
+ # if __name__ == "__main__":
293
+ # result = gemini_llm_chat("./documents/enterprise_sales_data.csv",
294
+ # "Generate a detailed sales report of the last 6 months from all the aspects and include a bar chart showing the sales by region.")
295
+ # print(json.dumps(result, indent=2))