Soumik555 committed on
Commit
4fbcf68
·
1 Parent(s): 45e593b

put middle orchestrator

Browse files
Files changed (3) hide show
  1. controller.py +17 -0
  2. orchestrator_agent.py +94 -0
  3. orchestrator_functions.py +381 -0
controller.py CHANGED
@@ -26,6 +26,7 @@ import matplotlib.pyplot as plt
26
  import matplotlib
27
  import seaborn as sns
28
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
 
29
  from supabase_service import upload_image_to_supabase
30
  from util_service import _prompt_generator, process_answer
31
  from fastapi.middleware.cors import CORSMiddleware
@@ -306,6 +307,14 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
306
  )
307
  logger.info("langchain_answer:", answer)
308
  return {"answer": jsonable_encoder(answer)}
 
 
 
 
 
 
 
 
309
 
310
  # Process with groq_chat first
311
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
@@ -802,6 +811,14 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
802
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
803
  return {"image_url": image_public_url}
804
  # return FileResponse(langchain_result[0], media_type="image/png")
 
 
 
 
 
 
 
 
805
 
806
  # Next, try the groq-based method
807
  groq_result = await loop.run_in_executor(
 
26
  import matplotlib
27
  import seaborn as sns
28
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
29
+ from orchestrator_agent import csv_orchestrator_chat
30
  from supabase_service import upload_image_to_supabase
31
  from util_service import _prompt_generator, process_answer
32
  from fastapi.middleware.cors import CORSMiddleware
 
307
  )
308
  logger.info("langchain_answer:", answer)
309
  return {"answer": jsonable_encoder(answer)}
310
+
311
+ # Orchestrate the execution
312
+ orchestrator_answer = await asyncio.to_thread(
313
+ csv_orchestrator_chat, decoded_url, query
314
+ )
315
+
316
+ if orchestrator_answer is not None:
317
+ return {"answer": jsonable_encoder(orchestrator_answer)}
318
 
319
  # Process with groq_chat first
320
  groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
 
811
  logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
812
  return {"image_url": image_public_url}
813
  # return FileResponse(langchain_result[0], media_type="image/png")
814
+
815
+ # Use orchestrator to handle the user's chart query first
816
+ orchestrator_answer = await asyncio.to_thread(
817
+ process_executor,csv_orchestrator_chat, csv_url, query
818
+ )
819
+
820
+ if orchestrator_answer is not None:
821
+ return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}
822
 
823
  # Next, try the groq-based method
824
  groq_result = await loop.run_in_executor(
orchestrator_agent.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from typing import Dict, List, Any
4
+ from pydantic_ai import Agent
5
+ from pydantic_ai.models.gemini import GeminiModel
6
+ from pydantic_ai.providers.google_gla import GoogleGLAProvider
7
+ from pydantic_ai import RunContext
8
+ from pydantic import BaseModel
9
+ from google.api_core.exceptions import ResourceExhausted # Import the exception for quota exhaustion
10
+ from csv_service import get_csv_basic_info
11
+ from orchestrator_functions import csv_chart, csv_chat
12
+
13
+
14
+ # Load all API keys from the environment variable
15
# GEMINI_API_KEYS is expected to hold a comma-separated list of keys.
# Blank entries are dropped: "".split(",") yields [""], which would otherwise
# make an unset variable look like a single (empty) key, and stray
# whitespace around a key would be sent to the provider verbatim.
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
16
+
17
+ # Function to initialize the model with a specific API key
18
def initialize_model(api_key: str, model_name: str = 'gemini-2.0-flash') -> GeminiModel:
    """Build a Gemini model bound to a single API key.

    Args:
        api_key: Key handed to the Google GLA provider.
        model_name: Gemini model identifier. Defaults to the model that was
            previously hard-coded, so existing one-argument callers are
            unaffected.

    Returns:
        A ``GeminiModel`` configured with a ``GoogleGLAProvider`` for
        ``api_key``.
    """
    return GeminiModel(
        model_name,
        provider=GoogleGLAProvider(api_key=api_key),
    )
23
+
24
+ # Define the tools
25
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """Answer one or more plain-text questions about the data in a CSV file.

    This function is registered as an agent tool, so this docstring is also
    what describes the tool to the model.

    Args:
        csv_url: URL or path of the CSV file to query.
        user_questions: Individual questions to answer; each is answered
            separately, in the order given.

    Returns:
        A list with one dict per question, each holding the keys
        ``question`` and ``answer``.
    """
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Questions are answered one at a time rather than concurrently, keeping
    # the original sequential behaviour.
    answers = []
    for question in user_questions:
        answer = await csv_chat(csv_url, question)
        answers.append({"question": question, "answer": answer})
    return answers
37
+
38
async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
    """Generate one chart per request from the data in a CSV file.

    This function is registered as an agent tool, so this docstring is also
    what describes the tool to the model.

    Args:
        csv_url: URL or path of the CSV file to plot from.
        user_questions: Individual chart requests; each produces its own
            chart, in the order given.

    Returns:
        A list with one dict per request, each holding the keys ``question``
        and ``image_url`` (the value returned by ``csv_chart`` for that
        request).
    """
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Charts are produced one at a time rather than concurrently, keeping the
    # original sequential behaviour.
    charts = []
    for question in user_questions:
        chart = await csv_chart(csv_url, question)
        charts.append({"question": question, "image_url": chart})
    return charts
51
+
52
+ # Function to create an agent with a specific CSV URL
53
def create_agent(csv_url: str, api_key: str) -> Agent:
    """Create an orchestrator agent for one CSV file.

    Args:
        csv_url: URL or path of the CSV file; it is embedded in the system
            prompt together with the metadata returned by
            ``get_csv_basic_info``.
        api_key: Gemini API key used to build the underlying model.

    Returns:
        An ``Agent`` with the ``generate_csv_answer`` and ``generate_chart``
        tools registered.
    """
    csv_metadata = get_csv_basic_info(csv_url)

    # Each instruction goes on its own line. Implicit concatenation of
    # adjacent literals inserted no separator at all, which glued the CSV URL
    # directly onto the following sentence ("...data.csvYour csv metadata
    # is...") and so corrupted the URL the model passes to the tools.
    system_prompt = "\n".join([
        "You are a data analyst.",
        "You have all the tools you need to answer any question.",
        "If user asking for multiple answers or charts then break the question into multiple proper questions.",
        "Pass csv_url/path with the questions to the tools to generate the answer.",
        "Explain the answer in a friendly way.",
        "**Format images** in Markdown: `![alt_text](direct-image-url)`",
        f"Your csv url is {csv_url}",
        f"Your csv metadata is {csv_metadata}",
    ])

    return Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt,
    )
72
+
73
def csv_orchestrator_chat(csv_url: str, user_question: str) -> Any:
    """Run the orchestrator agent, trying each configured Gemini key in turn.

    Args:
        csv_url: URL or path of the CSV file the question is about.
        user_question: The user's question, passed to the agent unchanged.

    Returns:
        The agent's result data from the first key that succeeds, or ``None``
        when no usable key is configured or every key failed. Callers are
        expected to treat ``None`` as "fall through to the next strategy".
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    # Skip blank entries so an unset/empty GEMINI_API_KEYS does not trigger a
    # pointless attempt with an empty key.
    usable_keys = [key.strip() for key in GEMINI_API_KEYS if key and key.strip()]

    for position, api_key in enumerate(usable_keys, start=1):
        # Keys are identified by position only; the secret itself is never
        # written to the logs.
        key_label = f"#{position} of {len(usable_keys)}"
        try:
            print(f"Attempting with API key {key_label}")
            agent = create_agent(csv_url, api_key)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        except ResourceExhausted:
            # The original clause `except ResourceExhausted or Exception`
            # evaluated to `except ResourceExhausted`; that is spelled out
            # here and every other failure is handled by the clause below.
            print(f"Quota exhausted for API key {key_label}. Switching to the next key.")
        except Exception as e:
            print(f"Error with API key {key_label}: {e}")

    # Reached when there were no usable keys or all of them failed.
    print("All API keys have been exhausted or failed.")
    return None
orchestrator_functions.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary modules
2
+ import asyncio
3
+ import os
4
+ import threading
5
+ import uuid
6
+ from fastapi.encoders import jsonable_encoder
7
+ import numpy as np
8
+ import pandas as pd
9
+ from pandasai import SmartDataframe
10
+ from langchain_groq.chat_models import ChatGroq
11
+ from dotenv import load_dotenv
12
+ from pydantic import BaseModel
13
+ from csv_service import clean_data, extract_chart_filenames
14
+ from langchain_groq import ChatGroq
15
+ import pandas as pd
16
+ from langchain_experimental.tools import PythonAstREPLTool
17
+ from langchain_experimental.agents import create_pandas_dataframe_agent
18
+ import numpy as np
19
+ import matplotlib.pyplot as plt
20
+ import matplotlib
21
+ import seaborn as sns
22
+ from supabase_service import upload_image_to_supabase
23
+ from util_service import _prompt_generator, process_answer
24
+ import matplotlib
25
+ matplotlib.use('Agg')
26
+
27
+
28
+ load_dotenv()
29
+
30
+ image_file_path = os.getenv("IMAGE_FILE_PATH")
31
+ image_not_found = os.getenv("IMAGE_NOT_FOUND")
32
+ allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")
33
+
34
+
35
+ # Load environment variables
36
# GROQ_API_KEYS is expected to hold a comma-separated list of keys. A default
# of "" keeps the module importable when the variable is unset (calling
# .split on the None returned by os.getenv raised AttributeError at import
# time); blank entries are dropped so an unset variable yields no keys rather
# than one empty key.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
model_name = os.getenv("GROQ_LLM_MODEL")
38
+
39
class CsvUrlRequest(BaseModel):
    """Request model holding a single CSV URL."""

    # Location of the CSV file, as a string.
    csv_url: str
41
+
42
class ImageRequest(BaseModel):
    """Request model holding a single image path."""

    # Path of the image, as a string.
    image_path: str
44
+
45
class CsvCommonHeadersRequest(BaseModel):
    """Request model holding a list of file URLs."""

    # URLs of the files to inspect; presumably CSV files - verify with callers.
    file_urls: list[str]
47
+
48
class CsvsMergeRequest(BaseModel):
    """Request model describing a merge of several files."""

    # URLs of the files to merge.
    file_urls: list[str]
    # Merge strategy name; the accepted values are not defined in this module.
    merge_type: str
    # Names of the columns shared by the files.
    common_columns_name: list[str]
52
+
53
+ # Thread-safe key management for groq_chat
54
+ current_groq_key_index = 0
55
+ current_groq_key_lock = threading.Lock()
56
+
57
+ # Thread-safe key management for langchain_csv_chat
58
+ current_langchain_key_index = 0
59
+ current_langchain_key_lock = threading.Lock()
60
+
61
+
62
+ # CHAT CODING STARTS FROM HERE
63
+
64
+ # Modified groq_chat function with thread-safe key rotation
65
def groq_chat(csv_url: str, question: str):
    """Answer a question about a CSV file with PandasAI backed by Groq.

    The shared key index only ever moves forward: once every key has
    reported a 429, all later calls return the "exhausted" error until the
    process restarts.

    Args:
        csv_url: URL or path handed to ``clean_data`` to load the data.
        question: Natural-language question passed to ``SmartDataframe.chat``.

    Returns:
        The sanitized answer: a list of row dicts for a DataFrame, a dict for
        a Series or dict, a list for a list, otherwise
        ``{"answer": "<text>"}``. On failure, ``{"error": "<message>"}``.
    """
    global current_groq_key_index

    def _sanitize_mapping(mapping):
        # Clean values after conversion to plain Python objects so that a
        # None replacement is not turned back into NaN by pandas.
        return {key: handle_out_of_range_float(value) for key, value in mapping.items()}

    while True:
        with current_groq_key_lock:
            if current_groq_key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            # Remember which key this attempt used so that a 429 only
            # advances the shared index if nobody else advanced it already.
            key_index = current_groq_key_index
            current_api_key = groq_api_keys[key_index]

        try:
            # Delete the PandasAI cache file if it exists (best effort).
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    print(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Unique filename so concurrent requests do not overwrite charts.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )

            answer = df.chat(question)

            # DataFrame.apply calls the function once per column, so the
            # previous `answer.apply(handle_out_of_range_float)` never saw an
            # individual float and left NaN/inf in place. Converting to
            # dicts first and sanitizing the values fixes that.
            if isinstance(answer, pd.DataFrame):
                processed = [
                    _sanitize_mapping(row)
                    for row in answer.to_dict(orient="records")
                ]
            elif isinstance(answer, pd.Series):
                processed = _sanitize_mapping(answer.to_dict())
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = _sanitize_mapping(answer)
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}

            return processed

        except Exception as e:
            error_message = str(e)
            # Only rate-limit errors justify switching keys; anything else is
            # reported to the caller immediately.
            if "429" not in error_message:
                return {"error": error_message}

            with current_groq_key_lock:
                # Compare-and-advance: if several threads hit a 429 on the
                # same key, only the first one moves the index, so no key is
                # skipped without being tried.
                if current_groq_key_index == key_index:
                    current_groq_key_index += 1
                if current_groq_key_index >= len(groq_api_keys):
                    return {"error": "All API keys exhausted."}
126
+
127
+ # Modified langchain_csv_chat with thread-safe key rotation
128
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer a question with a LangChain pandas agent, rotating Groq keys.

    Makes at most one attempt per configured key, starting from the shared
    round-robin position and wrapping to the first key when the end of the
    list is reached.

    Args:
        csv_url: URL or path handed to ``clean_data``.
        question: The user's question.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The agent's ``output`` value from the first attempt that does not
        raise, or ``{"error": "All API keys exhausted"}`` when every attempt
        raised.
    """
    global current_langchain_key_index

    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Claim the next key and advance the shared position atomically.
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            current_key = current_langchain_key_index
            api_key = groq_api_keys[current_key]
            current_langchain_key_index = current_key + 1

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            repl_namespace = {
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            }
            tool = PythonAstREPLTool(locals=repl_namespace)

            # NOTE(review): allow_dangerous_code=True lets the agent execute
            # model-generated Python in this process - confirm the deployment
            # is sandboxed.
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )

            prompt = _prompt_generator(question, chart_required)
            result = agent.invoke({"input": prompt})
            return result.get("output")

        except Exception as e:
            # Any failure moves on to the next key.
            print(f"Error with key index {current_key}: {str(e)}")

    return {"error": "All API keys exhausted"}
172
+
173
+
174
def handle_out_of_range_float(value):
    """Replace float values that JSON cannot represent.

    Handles both built-in floats and NumPy floating scalars (for example
    ``np.float32``, which is not a ``float`` subclass and was previously
    passed through untouched).

    Args:
        value: Any value; non-float values are returned unchanged.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity (previously reported without
        its sign), otherwise ``value`` itself.
    """
    # NOTE: this module defines handle_out_of_range_float twice; the later
    # definition is the one bound at import time.
    if not isinstance(value, (float, np.floating)):
        return value
    if np.isnan(value):
        return None
    if np.isinf(value):
        return "Infinity" if value > 0 else "-Infinity"
    return value
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+ # CHART CODING STARTS FROM HERE
189
+
190
+ instructions = """
191
+
192
+ - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
193
+ - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
194
+ - Use colorblind-friendly palette
195
+ - Read above instructions and follow them.
196
+
197
+ """
198
+
199
+ # Thread-safe configuration for chart endpoints
200
+ current_groq_chart_key_index = 0
201
+ current_groq_chart_lock = threading.Lock()
202
+
203
+ current_langchain_chart_key_index = 0
204
+ current_langchain_chart_lock = threading.Lock()
205
+
206
def model():
    """Return a ``ChatGroq`` client using the current chart API key.

    Raises:
        Exception: If the shared chart key index is at or past the end of
            the configured key list.
    """
    # The index is read under the lock so it cannot change between the
    # bounds check and the lookup.
    with current_groq_chart_lock:
        key_position = current_groq_chart_key_index
        if key_position >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        selected_key = groq_api_keys[key_position]
    return ChatGroq(model=model_name, api_key=selected_key)
213
+
214
def groq_chart(csv_url: str, question: str):
    """Generate a chart with PandasAI backed by Groq, rotating keys on 429.

    Makes at most one attempt per configured key. A rate-limit error (a
    message containing "429") moves the shared key index to the next key,
    wrapping around; any other error ends the call.

    Args:
        csv_url: URL or path handed to ``clean_data``.
        question: Chart request; the module-level ``instructions`` text is
            appended before it is sent to ``SmartDataframe.chat``.

    Returns:
        The value returned by ``SmartDataframe.chat`` (presumably the saved
        chart's path - verify against PandasAI), the string
        ``"Chart not generated"`` when ``process_answer`` flags the answer,
        or ``{"error": "<message>"}`` on failure.
    """
    global current_groq_chart_key_index

    for _ in range(len(groq_api_keys)):
        # Index of the key used by this attempt; stays None if the attempt
        # fails before a key was picked.
        key_index = None
        try:
            # Remove the PandasAI cache file before processing (best effort).
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            if os.path.exists(cache_db_path):
                try:
                    os.remove(cache_db_path)
                except Exception as e:
                    print(f"Cache cleanup error: {e}")

            data = clean_data(csv_url)
            with current_groq_chart_lock:
                key_index = current_groq_chart_key_index
                current_api_key = groq_api_keys[key_index]

            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Unique filename so concurrent requests do not overwrite charts.
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            if "429" not in error:
                print(f"Chart generation error: {error}")
                return {"error": error}

            with current_groq_chart_lock:
                # Compare-and-advance: when several threads are rate-limited
                # on the same key, only the first one rotates the index, so
                # the next key is not skipped untried.
                if key_index is None or current_groq_chart_key_index == key_index:
                    current_groq_chart_key_index = (
                        current_groq_chart_key_index + 1
                    ) % len(groq_api_keys)

    return {"error": "All API keys exhausted for chart generation"}
265
+
266
+
267
+
268
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate chart files with a LangChain pandas agent, rotating Groq keys.

    Makes one attempt per configured key, advancing the shared round-robin
    index on every attempt.

    Args:
        csv_url: URL or path handed to ``clean_data``.
        question: The user's chart request.
        chart_required: Accepted for signature parity; the prompt is always
            built with ``True`` regardless of this value.

    Returns:
        The non-empty result of ``extract_chart_filenames`` from the first
        attempt that yields one, otherwise the string
        ``"Chart generation failed after all retries"``.
    """
    global current_langchain_chart_key_index

    data = clean_data(csv_url)
    key_count = len(groq_api_keys)

    for attempt_number in range(key_count):
        try:
            # Claim a key and move the shared index on in one atomic step.
            with current_langchain_chart_lock:
                api_key = groq_api_keys[current_langchain_chart_key_index]
                current_key = current_langchain_chart_key_index
                current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)

            llm = ChatGroq(model=model_name, api_key=api_key)
            repl_namespace = {
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            }
            tool = PythonAstREPLTool(locals=repl_namespace)

            # NOTE(review): allow_dangerous_code=True lets the agent execute
            # model-generated Python in this process - confirm the deployment
            # is sandboxed.
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )

            chart_prompt = _prompt_generator(question, True)
            result = agent.invoke({"input": chart_prompt})
            output = result.get("output", "")

            # A non-empty list of filenames is treated as success.
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return chart_files

            # The agent's output is logged for every attempt except the last.
            is_last_attempt = not (attempt_number < len(groq_api_keys) - 1)
            if not is_last_attempt:
                print(f"Langchain chart error (key {current_key}): {output}")

        except Exception as e:
            print(f"Langchain chart error (key {current_key}): {str(e)}")

    return "Chart generation failed after all retries"
316
+
317
+
318
+
319
+
320
+ ###########################################################################################################################
321
+
322
+
323
+
324
+
325
async def csv_chart(csv_url: str, query: str):
    """Generate a chart for ``query`` and upload it to Supabase.

    Args:
        csv_url: URL or path of the CSV file to plot from.
        query: The user's chart request.

    Returns:
        ``{"image_url": <public url>}`` on success,
        ``{"error": "All chart generation methods failed"}`` when
        ``groq_chart`` could not produce a chart, or
        ``{"error": "Internal system error"}`` if anything raised.
    """
    try:
        # groq_chart is blocking, so it runs in a worker thread.
        groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
        print(f"Generated Chart: {groq_result}")

        # groq_chart reports failure in two ways: an {"error": ...} dict or
        # the "Chart not generated" sentinel. Previously only the sentinel
        # was checked, so an error dict was stringified and uploaded as if
        # it were an image path.
        returned_error = isinstance(groq_result, dict) and "error" in groq_result
        if returned_error or groq_result == 'Chart not generated':
            return {"error": "All chart generation methods failed"}

        unique_file_name = f'{str(uuid.uuid4())}.png'
        image_public_url = await upload_image_to_supabase(f"{groq_result}", unique_file_name)
        print(f"Image uploaded to Supabase: {image_public_url}")
        return {"image_url": image_public_url}

    except Exception as e:
        print(f"Critical chart error: {str(e)}")
        return {"error": "Internal system error"}
342
+
343
+
344
+
345
+
346
+
347
+
348
async def csv_chat(csv_url: str, query: str):
    """Answer a question about a CSV file, falling back from Groq to LangChain.

    Args:
        csv_url: URL or path of the CSV file.
        query: The user's question.

    Returns:
        A dict with a single ``answer`` key: the encoded answer, an apology
        message when the first answer is reported as empty, or the string
        ``"error"`` when both strategies fail or an exception is raised.
    """
    try:
        # First choice: the PandasAI/Groq path, run off the event loop.
        groq_answer = await asyncio.to_thread(groq_chat, csv_url, query)
        print("groq_answer:", groq_answer)

        if process_answer(groq_answer) == "Empty response received.":
            return {"answer": "Sorry, I couldn't find relevant data..."}

        # A falsy verdict from process_answer means the answer is usable.
        if not process_answer(groq_answer):
            return {"answer": jsonable_encoder(groq_answer)}

        # Otherwise fall back to the LangChain agent (no chart requested).
        lang_answer = await asyncio.to_thread(
            langchain_csv_chat, csv_url, query, False
        )
        if process_answer(lang_answer):
            return {"answer": "error"}
        return {"answer": jsonable_encoder(lang_answer)}

    except Exception as e:
        print(f"Error processing request: {str(e)}")
        return {"answer": "error"}
371
+
372
def handle_out_of_range_float(value):
    """Replace float values that JSON cannot represent.

    Handles both built-in floats and NumPy floating scalars (for example
    ``np.float32``, which is not a ``float`` subclass and was previously
    passed through untouched).

    Args:
        value: Any value; non-float values are returned unchanged.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity (previously reported without
        its sign), otherwise ``value`` itself.
    """
    # NOTE: this is the second definition of handle_out_of_range_float in
    # this module; being the later one, it is the definition bound at import
    # time.
    if not isinstance(value, (float, np.floating)):
        return value
    if np.isnan(value):
        return None
    if np.isinf(value):
        return "Infinity" if value > 0 else "-Infinity"
    return value
379
+
380
+
381
+