Soumik555 committed
Commit aa0bf91 · Parent(s): 4167849

use supabase

controller.py CHANGED
@@ -26,6 +26,7 @@ import matplotlib.pyplot as plt
 import matplotlib
 import seaborn as sns
 from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
+from supabase_service import upload_image_to_supabase
 from util_service import _prompt_generator, process_answer
 from fastapi.middleware.cors import CORSMiddleware
 import matplotlib
@@ -128,7 +129,12 @@ async def get_image(request: ImageRequest, authorization: str = Header(None)):
 
     try:
        image_file_path = request.image_path
-        return FileResponse(image_file_path, media_type="image/png")
+        unique_file_name =f'{str(uuid.uuid4())}.png'
+        logger.info("Uploading the chart to supabase...")
+        image_public_url = await upload_image_to_supabase(f"{image_file_path}", unique_file_name)
+        logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
+        return {"image_url": image_public_url}
+        # return FileResponse(image_file_path, media_type="image/png")
    except Exception as e:
        logger.error(f"Error: {e}")
        return {"answer": "error"}
@@ -789,7 +795,12 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
        )
        logger.info("Langchain chart result:", langchain_result)
        if isinstance(langchain_result, list) and len(langchain_result) > 0:
-            return FileResponse(langchain_result[0], media_type="image/png")
+            unique_file_name =f'{str(uuid.uuid4())}.png'
+            logger.info("Uploading the chart to supabase...")
+            image_public_url = await upload_image_to_supabase(f"{langchain_result[0]}", unique_file_name)
+            logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
+            return {"image_url": image_public_url}
+            # return FileResponse(langchain_result[0], media_type="image/png")
 
        # Next, try the groq-based method
        groq_result = await loop.run_in_executor(
@@ -797,7 +808,12 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
        )
        logger.info(f"Groq chart result: {groq_result}")
        if isinstance(groq_result, str) and groq_result != "Chart not generated":
-            return FileResponse(groq_result, media_type="image/png")
+            unique_file_name =f'{str(uuid.uuid4())}.png'
+            logger.info("Uploading the chart to supabase...")
+            image_public_url = await upload_image_to_supabase(f"{groq_result}", unique_file_name)
+            logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
+            return {"image_url": image_public_url}
+            # return FileResponse(groq_result, media_type="image/png")
 
        # Fallback: try langchain-based again
        logger.error("Groq chart generation failed, trying langchain....")
@@ -806,7 +822,12 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
        )
        logger.info("Fallback langchain chart result:", langchain_paths)
        if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
-            return FileResponse(langchain_paths[0], media_type="image/png")
+            unique_file_name =f'{str(uuid.uuid4())}.png'
+            logger.info("Uploading the chart to supabase...")
+            image_public_url = await upload_image_to_supabase(f"{langchain_paths[0]}", unique_file_name)
+            logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
+            return {"image_url": image_public_url}
+            # return FileResponse(langchain_paths[0], media_type="image/png")
        else:
            logger.error("All chart generation methods failed")
            return {"answer": "error"}
requirements.txt CHANGED
@@ -12,4 +12,5 @@ langchain_experimental==0.3.3
 tabulate==0.9.0
 gradio>=4.0.0
 google-generativeai==0.8.3
-langchain-google-genai==2.0.7
+langchain-google-genai==2.0.7
+supabase==2.13.0
rethink_gemini_agents/gemini_langchain_service.py DELETED
@@ -1,205 +0,0 @@
-import os
-import re
-import uuid
-from langchain_google_genai import ChatGoogleGenerativeAI
-import pandas as pd
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_experimental.tools import PythonAstREPLTool
-from langchain_experimental.agents import create_pandas_dataframe_agent
-from dotenv import load_dotenv
-import numpy as np
-import matplotlib.pyplot as plt
-import matplotlib
-import seaborn as sns
-
-
-# Set the backend for matplotlib to 'Agg' to avoid GUI issues
-matplotlib.use('Agg')
-
-load_dotenv()
-model_name = os.getenv("GOOGLE_GENERATIVE_AI_MODEL_LANGCHAIN_AGENT")
-google_api_keys = os.getenv("GOOGLE_GENERATIVE_AI_API_KEYS").split(",")
-current_key_index = 0  # Global index for API keys
-
-
-def _prompt_generator(question: str, chart_required: bool):
-
-    chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
-
-    1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
-    2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
-    3. **Communication:** Provide concise, professional, and well-structured responses.
-    4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
-
-    **Query:** {question}
-
-    """
-
-    chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
-
-    1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
-    2. Visualization requirements:
-       - Adjust font sizes, rotate labels (45° if needed), truncate for readability
-       - Figure size: (12, 6)
-       - Descriptive titles (fontsize=14)
-       - Colorblind-friendly palettes
-    3. File handling rules:
-       - Create MAXIMUM 2 charts if absolutely necessary
-       - For multiple charts:
-         * Arrange in grid format (2x1 vertical layout preferred)
-         * Use SAME unique_id with suffixes:
-           - f"{{unique_id}}_1.png"
-           - f"{{unique_id}}_2.png"
-       - Save EXCLUSIVELY to "generated_charts" folder
-       - File naming: f"chart_{{unique_id}}.png" (for single chart)
-    4. FINAL OUTPUT MUST BE:
-       - For single chart: f"generated_charts/chart_{{unique_id}}.png"
-       - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
-       - ONLY return this full path string, nothing else
-
-    **Query:** {question}
-
-    IMPORTANT:
-    - Generate the unique_id FIRST before any operations
-    - Use THE SAME unique_id throughout entire process
-    - NEVER generate new UUIDs after initial creation
-    - Return EXACT filepath string of the final saved chart
-    """
-
-
-    if chart_required:
-        return ChatPromptTemplate.from_template(chart_prompt)
-    else:
-        return ChatPromptTemplate.from_template(chat_prompt)
-
-
-
-def langchain_gemini_csv_chat(csv_url: str, question: str, chart_required: bool):
-    global current_key_index
-
-    data = pd.read_csv(csv_url)
-    # Try each API key until a successful response is generated or keys run out
-    attempts = 0
-    total_keys = len(google_api_keys)
-    while attempts < total_keys:
-        try:
-            # Select the current API key
-            api_key = google_api_keys[current_key_index]
-            print(f"Using API key index {current_key_index}")
-
-            # Initialize the LLM with the current API key
-            llm = ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
-
-            # Prepare the Python REPL tool with the dataframe and necessary libraries
-            tool = PythonAstREPLTool(locals={
-                "df": data,
-                "pd": pd,
-                "np": np,
-                "plt": plt,  # Ensure plt is available
-                "sns": sns,
-                "matplotlib": matplotlib,
-                "uuid": uuid,
-            })
-
-            # Create the pandas agent with the provided tools and settings
-            agent = create_pandas_dataframe_agent(
-                llm,
-                data,
-                agent_type="openai-tools",
-                verbose=True,
-                allow_dangerous_code=True,
-                extra_tools=[tool],
-                return_intermediate_steps=True
-            )
-
-            chat_prompt = _prompt_generator(question, chart_required)
-            # Attempt to invoke the agent with the question
-            result = agent.invoke({"input": chat_prompt})
-            # If successful, return the output
-            return result.get("output")
-
-        except Exception as e:
-            # Log the error along with the current API key index
-            print(f"Error using API key index {current_key_index}: {e}")
-
-            # Move to the next API key
-            current_key_index += 1
-            attempts += 1
-
-            # If all keys have been exhausted, exit the loop
-            if current_key_index >= total_keys:
-                print("All API keys have been exhausted.")
-                return None
-
-
-
-
-def langchain_gemini_csv_chart(csv_url: str, question: str, chart_required: bool):
-    global current_key_index
-    data = pd.read_csv(csv_url)
-
-    # Try each API key until a successful response is generated or keys run out
-    attempts = 0
-    total_keys = len(google_api_keys)
-    while attempts < total_keys:
-        try:
-            # Select the current API key
-            api_key = google_api_keys[current_key_index]
-            print(f"Using API key index {current_key_index}")
-
-            # Initialize the LLM with the current API key
-            llm = ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
-
-            # Prepare the Python REPL tool with the dataframe and necessary libraries
-            tool = PythonAstREPLTool(locals={
-                "df": data,
-                "pd": pd,
-                "np": np,
-                "plt": plt,  # Ensure plt is available
-                "sns": sns,
-                "matplotlib": matplotlib
-            })
-
-            # Create the pandas agent with the provided tools and settings
-            agent = create_pandas_dataframe_agent(
-                llm,
-                data,
-                agent_type="openai-tools",
-                verbose=True,
-                allow_dangerous_code=True,
-                extra_tools=[tool],
-                return_intermediate_steps=True
-            )
-
-            chart_prompt = _prompt_generator(question, chart_required)
-            # Attempt to invoke the agent with the question
-            result = agent.invoke({"input": chart_prompt})
-            # If successful, return the output
-            return result.get("output")
-
-        except Exception as e:
-            # Log the error along with the current API key index
-            print(f"Error using API key index {current_key_index}: {e}")
-
-            # Move to the next API key
-            current_key_index += 1
-            attempts += 1
-
-            # If all keys have been exhausted, exit the loop
-            if current_key_index >= total_keys:
-                print("All API keys have been exhausted.")
-                return None
-
-
-
-
-
-
-
-# Example usage:
-if __name__ == "__main__":
-    csv_url = "./documents/titanic.csv"
-    question = "Create a pie chart of males vs females"
-    output = langchain_gemini_csv_chat(csv_url, question, True)
-    print("Agent output:", output)
-
rethink_gemini_agents/rethink_chart.py DELETED
@@ -1,266 +0,0 @@
-import pandas as pd
-import re
-import os
-import uuid
-import logging
-from io import StringIO
-import sys
-import traceback
-from typing import Optional, Dict, Any, List
-from pydantic import BaseModel, Field
-from google.generativeai import GenerativeModel, configure
-from dotenv import load_dotenv
-
-# Load environment variables from .env file
-load_dotenv()
-
-API_KEYS = os.getenv("GOOGLE_GENERATIVE_AI_API_KEYS", "").split(",")
-MODEL_NAME = os.getenv("GOOGLE_GENERATIVE_AI_MODEL")
-
-# Set up non-interactive matplotlib backend
-os.environ['MPLBACKEND'] = 'agg'
-import matplotlib.pyplot as plt
-plt.show = lambda: None  # Monkey patch to disable display
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s',
-    handlers=[logging.FileHandler('api_key_rotation.log'), logging.StreamHandler()]
-)
-logger = logging.getLogger(__name__)
-
-class GeminiKeyManager:
-    """Manage multiple Gemini API keys with failover"""
-
-    def __init__(self, api_keys: List[str]):
-        self.original_keys = api_keys.copy()
-        self.available_keys = api_keys.copy()
-        self.active_key = None
-        self.failed_keys = {}
-
-    def configure(self) -> bool:
-        """Try to configure API with available keys"""
-        while self.available_keys:
-            key = self.available_keys.pop(0)
-            try:
-                configure(api_key=key)
-                self.active_key = key
-                logger.info(f"Successfully configured with key: {self._mask_key(key)}")
-                return True
-            except Exception as e:
-                self.failed_keys[key] = str(e)
-                logger.error(f"Key failed: {self._mask_key(key)}. Error: {str(e)}")
-
-        logger.critical("All API keys failed to configure")
-        return False
-
-    def _mask_key(self, key: str) -> str:
-        return f"{key[:8]}...{key[-4:]}" if key else ""
-
-class PythonREPL:
-    """Secure Python REPL with non-interactive plotting"""
-
-    def __init__(self, df: pd.DataFrame):
-        self.df = df
-        self.local_env = {
-            "pd": pd,
-            "df": self.df.copy(),
-            "plt": plt,
-            "os": os,
-            "uuid": uuid,
-            "plt": plt
-        }
-        os.makedirs('generated_charts', exist_ok=True)
-
-    def execute(self, code: str) -> Dict[str, Any]:
-        old_stdout = sys.stdout
-        sys.stdout = mystdout = StringIO()
-
-        try:
-            # Ensure figure closure and non-interactive mode
-            code = f"""
-import matplotlib.pyplot as plt
-plt.switch_backend('agg')
-{code}
-plt.close('all')
-"""
-            exec(code, self.local_env)
-            self.df = self.local_env.get('df', self.df)
-            error = False
-        except Exception as e:
-            error_msg = traceback.format_exc()
-            error = True
-        finally:
-            sys.stdout = old_stdout
-
-        return {
-            "output": mystdout.getvalue(),
-            "error": error,
-            "error_message": error_msg if error else None,
-            "df": self.local_env.get('df', self.df)
-        }
-
-class RethinkAgent(BaseModel):
-    df: pd.DataFrame
-    max_retries: int = Field(default=5, ge=1)
-    gemini_model: Optional[GenerativeModel] = None
-    current_retry: int = Field(default=0, ge=0)
-    repl: Optional[PythonREPL] = None
-    key_manager: Optional[GeminiKeyManager] = None
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def _extract_code(self, response: str) -> str:
-        code_match = re.search(r'```python(.*?)```', response, re.DOTALL)
-        return code_match.group(1).strip() if code_match else response.strip()
-
-    def _generate_initial_prompt(self, query: str) -> str:
-        columns = "\n".join(self.df.columns)
-        return f"""
-        Generate Python code to analyze this DataFrame with columns:
-        {columns}
-
-        Query: {query}
-
-        Requirements:
-        1. Save visualizations to 'generated_charts/' with UUID filename
-        2. Use plt.savefig() with format='png'
-        3. No plt.show() calls allowed
-        4. After saving each chart, print exactly: CHART_SAVED: generated_charts/{{uuid}}.png
-        5. Start with 'import pandas as pd'
-        6. The DataFrame is available as 'df'
-        7. Wrap code in ```python``` blocks
-        """
-
-    def _generate_retry_prompt(self, query: str, error: str, code: str) -> str:
-        return f"""
-        Previous code failed with error:
-        {error}
-
-        Revise this code:
-        {code}
-
-        New requirements:
-        1. Fix the error
-        2. Ensure plots are saved to generated_charts/
-        3. After saving each chart, print exactly: CHART_SAVED: generated_charts/{{uuid}}.png
-        4. No figure display
-        5. Complete query: {query}
-
-        Explain the error first, then show corrected code in ```python``` blocks
-        """
-
-    def initialize_model(self, api_keys: List[str]) -> bool:
-        """Initialize Gemini model with key rotation"""
-        self.key_manager = GeminiKeyManager(api_keys)
-        if not self.key_manager.configure():
-            raise RuntimeError("All API keys failed to initialize")
-
-        try:
-            self.gemini_model = GenerativeModel(MODEL_NAME)
-            return True
-        except Exception as e:
-            logger.error(f"Model initialization failed: {str(e)}")
-            return False
-
-    def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str:
-        if error:
-            prompt = self._generate_retry_prompt(query, error, previous_code)
-        else:
-            prompt = self._generate_initial_prompt(query)
-
-        try:
-            response = self.gemini_model.generate_content(prompt)
-            return self._extract_code(response.text)
-        except Exception as e:
-            logger.error(f"API call failed: {str(e)}")
-            if self.key_manager.available_keys:
-                logger.info("Attempting key rotation...")
-                if self.key_manager.configure():
-                    self.gemini_model = GenerativeModel(MODEL_NAME)
-                    return self.generate_code(query, error, previous_code)
-            raise
-
-    def execute_query(self, query: str) -> str:
-        self.repl = PythonREPL(self.df)
-        error = None
-        previous_code = None
-
-        while self.current_retry < self.max_retries:
-            try:
-                code = self.generate_code(query, error, previous_code)
-                result = self.repl.execute(code)
-
-                if result["error"]:
-                    self.current_retry += 1
-                    error = result["error_message"]
-                    previous_code = code
-                    logger.warning(f"Retry {self.current_retry}/{self.max_retries}...")
-                else:
-                    self.df = result["df"]
-                    return result["output"]
-            except Exception as e:
-                logger.error(f"Critical failure: {str(e)}")
-                return f"System error: {str(e)}"
-
-        return f"Failed after {self.max_retries} attempts. Last error: {error}"
-
-
-
-def gemini_llm_chart(csv_url: str, query: str) -> str:
-    df = pd.read_csv(csv_url)
-
-    agent = RethinkAgent(df=df)
-    if not agent.initialize_model(API_KEYS):
-        print("Failed to initialize model with provided keys")
-        exit(1)
-
-    result = agent.execute_query(query)
-    print("\nAnalysis Result:")
-    print(result)
-
-    if isinstance(result, str):
-        result = result.strip()  # Remove any leading/trailing spaces or newlines
-
-        match = re.search(r'CHART_SAVED:\s*(\S+)', result)
-
-        if match:
-            chart_path = match.group(1)
-            print("Chart Path:", chart_path)
-            return chart_path
-        else:
-            print("Chart path not found")
-            return "Chart path not found"
-    else:
-        print("Unexpected result format:", type(result))
-        return "Chart path not found"
-
-
-
-# Usage Example
-# if __name__ == "__main__":
-#     df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv')
-
-#     agent = RethinkAgent(df=df)
-#     if not agent.initialize_model(API_KEYS):
-#         print("Failed to initialize model with provided keys")
-#         exit(1)
-
-#     result = agent.execute_query("Create a scatter plot of total_bill vs tip with kernel density estimate")
-#     print("\nAnalysis Result:")
-#     print(result)
-
-#     if isinstance(result, str):
-#         result = result.strip()  # Remove any leading/trailing spaces or newlines
-
-#         match = re.search(r'CHART_SAVED:\s*(\S+)', result)
-
-#         if match:
-#             chart_path = match.group(1)
-#             print("Chart Path:", chart_path)
-#         else:
-#             print("Chart path not found")
-#     else:
-#         print("Unexpected result format:", type(result))
rethink_gemini_agents/rethink_chat.py DELETED
@@ -1,259 +0,0 @@
-import pandas as pd
-import re
-import os
-import uuid
-import logging
-from io import StringIO
-import sys
-import traceback
-from typing import Optional, Dict, Any, List
-from pydantic import BaseModel, Field
-from google.generativeai import GenerativeModel, configure
-from dotenv import load_dotenv
-import seaborn as sns
-from csv_service import clean_data
-from util_service import handle_out_of_range_float
-
-pd.set_option('display.max_columns', None)  # Show all columns
-pd.set_option('display.max_rows', None)  # Show all rows
-pd.set_option('display.max_colwidth', None)  # Do not truncate cell content
-
-# Load environment variables from .env file
-load_dotenv()
-
-API_KEYS = os.getenv("GOOGLE_GENERATIVE_AI_API_KEYS", "").split(",")
-MODEL_NAME = os.getenv("GOOGLE_GENERATIVE_AI_MODEL")
-
-# Set up non-interactive matplotlib backend
-os.environ['MPLBACKEND'] = 'agg'
-import matplotlib.pyplot as plt
-plt.show = lambda: None  # Monkey patch to disable display
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s',
-    handlers=[logging.FileHandler('api_key_rotation.log'), logging.StreamHandler()]
-)
-logger = logging.getLogger(__name__)
-
-class GeminiKeyManager:
-    """Manage multiple Gemini API keys with failover"""
-
-    def __init__(self, api_keys: List[str]):
-        self.original_keys = api_keys.copy()
-        self.available_keys = api_keys.copy()
-        self.active_key = None
-        self.failed_keys = {}
-
-    def configure(self) -> bool:
-        """Try to configure API with available keys"""
-        while self.available_keys:
-            key = self.available_keys.pop(0)
-            try:
-                configure(api_key=key)
-                self.active_key = key
-                logger.info(f"Successfully configured with key: {self._mask_key(key)}")
-                return True
-            except Exception as e:
-                self.failed_keys[key] = str(e)
-                logger.error(f"Key failed: {self._mask_key(key)}. Error: {str(e)}")
-
-        logger.critical("All API keys failed to configure")
-        return False
-
-    def _mask_key(self, key: str) -> str:
-        return f"{key[:8]}...{key[-4:]}" if key else ""
-
-class PythonREPL:
-    """Secure Python REPL with non-interactive plotting"""
-
-    def __init__(self, df: pd.DataFrame):
-        self.df = df
-        self.local_env = {
-            "pd": pd,
-            "df": self.df.copy(),
-            "plt": plt,
-            "os": os,
-            "uuid": uuid,
-            "plt": plt,
-            "sns": sns,
-        }
-        os.makedirs('generated_charts', exist_ok=True)
-
-    def execute(self, code: str) -> Dict[str, Any]:
-        old_stdout = sys.stdout
-        sys.stdout = mystdout = StringIO()
-
-        try:
-            # Ensure figure closure and non-interactive mode
-            code = f"""
-import matplotlib.pyplot as plt
-plt.switch_backend('agg')
-{code}
-plt.close('all')
-"""
-            exec(code, self.local_env)
-            self.df = self.local_env.get('df', self.df)
-            error = False
-        except Exception as e:
-            error_msg = traceback.format_exc()
-            error = True
-        finally:
-            sys.stdout = old_stdout
-
-        return {
-            "output": mystdout.getvalue(),
-            "error": error,
-            "error_message": error_msg if error else None,
-            "df": self.local_env.get('df', self.df)
-        }
-
-class RethinkAgent(BaseModel):
-    df: pd.DataFrame
-    max_retries: int = Field(default=5, ge=1)
-    gemini_model: Optional[GenerativeModel] = None
-    current_retry: int = Field(default=0, ge=0)
-    repl: Optional[PythonREPL] = None
-    key_manager: Optional[GeminiKeyManager] = None
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def _extract_code(self, response: str) -> str:
-        code_match = re.search(r'```python(.*?)```', response, re.DOTALL)
-        return code_match.group(1).strip() if code_match else response.strip()
-
-    def _generate_initial_prompt(self, query: str) -> str:
-        columns = "\n".join(self.df.columns)
-        return f"""
-        You are a data analyst assistant. Generate Python code to analyze this DataFrame with columns:
-        {columns}
-
-        Query: {query}
-
-        Requirements:
-        1. Use print() to show results
-        2. Start with 'import pandas as pd'
-        3. The DataFrame is available as 'df'
-        4. Wrap code in ```python``` blocks
-        """
-
-    def _generate_retry_prompt(self, query: str, error: str, code: str) -> str:
-        return f"""
-        Previous code failed with error:
-        {error}
-
-        Failed code:
-        {code}
-
-        Revise the code to fix the error and complete this query:
-        {query}
-
-        Requirements:
-        1. Explain the error first
-        2. Show corrected code in ```python``` blocks
-        """
-
-    def initialize_model(self, api_keys: List[str]) -> bool:
-        """Initialize Gemini model with key rotation"""
-        self.key_manager = GeminiKeyManager(api_keys)
-        if not self.key_manager.configure():
-            raise RuntimeError("All API keys failed to initialize")
-
-        try:
-            self.gemini_model = GenerativeModel(MODEL_NAME)
-            return True
-        except Exception as e:
-            logger.error(f"Model initialization failed: {str(e)}")
-            return False
-
-    def generate_code(self, query: str, error: Optional[str] = None, previous_code: Optional[str] = None) -> str:
-        if error:
-            prompt = self._generate_retry_prompt(query, error, previous_code)
-        else:
-            prompt = self._generate_initial_prompt(query)
-
-        try:
-            response = self.gemini_model.generate_content(prompt)
-            return self._extract_code(response.text)
-        except Exception as e:
-            logger.error(f"API call failed: {str(e)}")
-            if self.key_manager.available_keys:
-                logger.info("Attempting key rotation...")
-                if self.key_manager.configure():
-                    self.gemini_model = GenerativeModel(MODEL_NAME)
-                    return self.generate_code(query, error, previous_code)
-            raise
-
-    def execute_query(self, query: str) -> str:
-        self.repl = PythonREPL(self.df)
-        error = None
-        previous_code = None
-
-        while self.current_retry < self.max_retries:
-            try:
-                code = self.generate_code(query, error, previous_code)
-                result = self.repl.execute(code)
-
-                if result["error"]:
-                    self.current_retry += 1
-                    error = result["error_message"]
-                    previous_code = code
-                    logger.warning(f"Retry {self.current_retry}/{self.max_retries}...")
-                else:
-                    self.df = result["df"]
-                    return result["output"]
-            except Exception as e:
-                logger.error(f"Critical failure: {str(e)}")
-                return f"System error: {str(e)}"
-
-        return f"Failed after {self.max_retries} attempts. Last error: {error}"
-
-
-def gemini_llm_chat(csv_url: str, query: str) -> str:
-
-    try:
-        # Assuming clean_data and RethinkAgent are defined elsewhere
-        df = clean_data(csv_url)
-        agent = RethinkAgent(df=df)
-
-        # Assuming API_KEYS is defined elsewhere
-        if not agent.initialize_model(API_KEYS):
-            print("Failed to initialize model with provided keys")
-            exit(1)
-
-        result = agent.execute_query(query)
-
-        # Process different response types
-        if isinstance(result, pd.DataFrame):
-            processed = result.apply(handle_out_of_range_float).to_dict(orient="records")
-        elif isinstance(result, pd.Series):
-            processed = result.apply(handle_out_of_range_float).to_dict()
-        elif isinstance(result, list):
-            processed = [handle_out_of_range_float(item) for item in result]
-        elif isinstance(result, dict):
-            processed = {k: handle_out_of_range_float(v) for k, v in result.items()}
-        else:
-            processed = {"answer": str(handle_out_of_range_float(result))}
-
-        logger.info(f"gemini processed result: {processed}")
-        return processed
-    except Exception as e:
-        logger.error(f"Error in gemini_llm_chat: {str(e)}")
-        return None
-
-# uvicorn controller:app --host localhost --port 8000 --reload
-
-# Usage Example
-# if __name__ == "__main__":
-#     df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv')

-#     agent = RethinkAgent(df=df)
-#     if not agent.initialize_model(API_KEYS):
-#         print("Failed to initialize model with provided keys")
-#         exit(1)
-
-#     result = agent.execute_query("How many rows and cols r there and what r their names?")
-#     print("\nAnalysis Result:")
-#     print(result)
supabase_service.py ADDED
@@ -0,0 +1,42 @@
+import os
+from supabase import create_client, Client
+
+# Replace with your Supabase URL and API key
+SUPABASE_URL: str = os.getenv("SUPABASE_URL")
+SUPABASE_KEY: str = os.getenv("SUPABASE_KEY")
+
+# Initialize the Supabase client
+supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
+
+# Define the bucket name (you can create one in the Supabase Storage section)
+BUCKET_NAME = "csvcharts"
+
+async def upload_image_to_supabase(file_path: str, file_name: str) -> str:
+    """
+    Uploads an image to Supabase Storage and returns the public URL.
+
+    :param file_path: Path to the image file on your local machine.
+    :param file_name: Name to save the file as in Supabase Storage.
+    :return: Public URL of the uploaded image.
+    """
+    # Check if the file exists
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"The file {file_path} does not exist.")
+
+    # Read the file in binary mode
+    with open(file_path, "rb") as f:
+        file_data = f.read()
+
+    # Upload the file to Supabase Storage
+    try:
+        res = supabase.storage.from_(BUCKET_NAME).upload(file_name, file_data)
+        print("Upload response:", res)  # Debugging: Print the response
+    except Exception as e:
+        raise Exception(f"Failed to upload file: {e}")
+
+    # Get the public URL of the uploaded file
+    public_url = supabase.storage.from_(BUCKET_NAME).get_public_url(file_name)
+    print("Public URL:", public_url)  # Debugging: Print the public URL
+
+    return public_url
+
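A minimal way to exercise the new module, as a sketch under the assumption that SUPABASE_URL and SUPABASE_KEY are set in the environment and that a chart PNG already exists locally; the path below is a placeholder:

import asyncio
import uuid

from supabase_service import upload_image_to_supabase

async def main():
    # Placeholder path: any chart saved by the generation code would work here.
    local_chart = "generated_charts/example_chart.png"
    public_url = await upload_image_to_supabase(local_chart, f"{uuid.uuid4()}.png")
    print("Chart is publicly reachable at:", public_url)

if __name__ == "__main__":
    asyncio.run(main())

Note that upload_image_to_supabase is declared async but performs blocking file I/O (and, with the synchronous client created above, a blocking network call), so awaiting it directly, as controller.py now does, keeps the call simple at the cost of blocking the event loop during the upload.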