Modifies orchestrator, adds code_exec tool and openai_chat (later we add chat)

Files changed:
- code_exec_service.py +128 -0
- openai_pandasai_service.py +145 -0
- orchestrator_agent.py +40 -49
- orchestrator_functions.py +93 -75
- util_service.py +1 -1
code_exec_service.py
ADDED
@@ -0,0 +1,128 @@
import datetime
import io
import time
from contextlib import redirect_stdout, redirect_stderr
import uuid
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import traceback
import seaborn as sns


plt.style.use('seaborn-v0_8-whitegrid')

class PythonDataAnalysisExecutor:
    """
    Simplified Python code execution environment for data analysis
    """

    def __init__(self, timeout_seconds=30):
        self.timeout = timeout_seconds
        self.safe_globals = {
            '__builtins__': {
                'print': print, 'range': range, 'len': len,
                'str': str, 'int': int, 'float': float, 'bool': bool,
                'list': list, 'dict': dict, 'tuple': tuple, 'set': set,
                'min': min, 'max': max, 'sum': sum, 'abs': abs,
                'round': round, 'zip': zip, 'enumerate': enumerate,
                '__import__': __import__
            },
            'pd': pd, 'np': np,
            'matplotlib.pyplot': plt,
            'seaborn': sns,
            'uuid': uuid.uuid4,
            'datetime': datetime, 'time': time,
            'DataFrame': pd.DataFrame, 'Series': pd.Series
        }
        self.last_result = None

    def _validate_code(self, code):
        """Basic security checks"""
        forbidden = ['sys.', 'open(', 'exec(', 'eval(']
        if any(f in code for f in forbidden):
            raise ValueError("Potentially unsafe code detected")

    def execute(self, code: str) -> dict:
        """
        Execute Python code safely with timeout
        Returns dict with: success, output, error, execution_time, variables, result
        """
        result = {
            'success': False,
            'output': '',
            'error': '',
            'execution_time': 0,
            'variables': {},
            'result': None  # This will capture the last expression result
        }

        start_time = time.time()
        output = io.StringIO()

        try:
            self._validate_code(code)

            with redirect_stdout(output), redirect_stderr(output):
                # Split into lines and handle last expression
                lines = [line for line in code.split('\n') if line.strip()]
                if lines:
                    # Execute all but last line normally
                    if len(lines) > 1:
                        exec('\n'.join(lines[:-1]), self.safe_globals)

                    # Handle last line specially to capture its value
                    last_line = lines[-1].strip()
                    if last_line:
                        # If it's an expression (not assignment or control structure)
                        if not (last_line.startswith((' ', '\t')) or
                                last_line.split()[0] in ('if', 'for', 'while', 'def', 'class') or
                                '=' in last_line):
                            self.last_result = eval(last_line, self.safe_globals)
                            result['result'] = self.last_result
                            output.write(str(self.last_result) + '\n')
                        else:
                            exec(last_line, self.safe_globals)

            result['output'] = output.getvalue()
            result['variables'] = {
                k: v for k, v in self.safe_globals.items()
                if not k.startswith('__') and k in code
            }
            result['success'] = True

        except Exception as e:
            result['error'] = f"{str(e)}\n{traceback.format_exc()}"

        result['execution_time'] = time.time() - start_time
        return result


def run_analysis(code: str, timeout=20) -> dict:
    """Simplified interface for code execution"""
    executor = PythonDataAnalysisExecutor(timeout_seconds=timeout)
    return executor.execute(code)


# Example usage
# if __name__ == "__main__":
#     analysis_code = """
# import datetime
# print(datetime.datetime.now())
# """

#     result = run_analysis(analysis_code)

#     # Improved output formatting
#     if result['success']:
#         print("Execution successful")
#         print("Execution time:", result['execution_time'], "seconds")
#         print("Output:", result['output'].strip())
#         print("Result:", result['result'])
#         print("Variables:", list(result['variables'].keys()))
#     else:
#         print("Execution failed")
#         print("Error:", result['error'])
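A quick, hedged sketch of how run_analysis behaves (the DataFrame values here are made up, and pd/np come from the executor's safe globals rather than an import in the snippet): the last bare expression is evaluated and returned in result['result'], while printed text is captured in result['output'].

from code_exec_service import run_analysis

snippet = """
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
print(df["y"].sum())
df["x"].mean()
"""

result = run_analysis(snippet)
print(result["success"])   # True
print(result["output"])    # "15" from the print call, then "2.0" echoed for the last expression
print(result["result"])    # 2.0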
openai_pandasai_service.py
ADDED
@@ -0,0 +1,145 @@
import os
import threading
import uuid
from langchain_openai import ChatOpenAI
import pandas as pd
from pandasai import SmartDataframe
from csv_service import clean_data
from dotenv import load_dotenv
from util_service import handle_out_of_range_float, process_answer

load_dotenv()
openai_api_keys = os.getenv("OPENAI_API_KEYS").split(",")
openai_api_base = os.getenv("OPENAI_API_BASE")

# Thread-safe key management for openai_chat
current_openai_key_index = 0
current_openai_key_lock = threading.Lock()

instructions = """
- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
- Use professional and color-blind friendly palettes.
- Do not use sns.set_palette()
- Read above instructions and follow them.
"""


# Modified openai_chat function with thread-safe key rotation
openai_model_name = 'gpt-4o'

def openai_chat(csv_url: str, question: str):
    global current_openai_key_index, current_openai_key_lock

    while True:
        with current_openai_key_lock:
            if current_openai_key_index >= len(openai_api_keys):
                return {"error": "All API keys exhausted."}
            current_api_key = openai_api_keys[current_openai_key_index]

        try:
            data = clean_data(csv_url)
            llm = ChatOpenAI(model=openai_model_name, api_key=current_api_key, base_url=openai_api_base)
            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename  # Unique filename
                }
            )

            answer = df.chat(question)
            # Process different response types
            if isinstance(answer, pd.DataFrame):
                processed = answer.apply(handle_out_of_range_float).to_dict(orient="records")
            elif isinstance(answer, pd.Series):
                processed = answer.apply(handle_out_of_range_float).to_dict()
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}

            if process_answer(processed):
                return {"error": "Answer is not valid."}
            return processed

        except Exception as e:
            error_message = str(e)
            if error_message:
                with current_openai_key_lock:
                    current_openai_key_index += 1
                    if current_openai_key_index >= len(openai_api_keys):
                        print("All API keys exhausted.")
                        return None
                    else:
                        print(f"Error with API key index {current_openai_key_index}: {error_message}")
                        return None


def openai_chart(csv_url: str, question: str):
    global current_openai_key_index, current_openai_key_lock

    while True:
        with current_openai_key_lock:
            if current_openai_key_index >= len(openai_api_keys):
                return {"error": "All API keys exhausted."}
            current_api_key = openai_api_keys[current_openai_key_index]

        try:
            data = clean_data(csv_url)
            llm = ChatOpenAI(model=openai_model_name, api_key=current_api_key, base_url=openai_api_base)
            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename  # Unique filename
                }
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            print(f"Error with API key index {current_openai_key_index}: {error}")
            if "429" in error or error is not None:
                with current_openai_key_lock:
                    current_openai_key_index = (current_openai_key_index + 1) % len(openai_api_keys)
            else:
                print(f"Chart generation error: {error}")
                return {"error": error}

    print("All API keys exhausted for chart generation")
    return None
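For context, a hedged sketch of how these helpers might be called; the CSV URL is a placeholder and OPENAI_API_KEYS / OPENAI_API_BASE are assumed to be set in the environment.

from openai_pandasai_service import openai_chat, openai_chart

csv_url = "https://example.com/sales.csv"  # placeholder URL

# Tabular / scalar answers come back as JSON-serialisable structures.
print(openai_chat(csv_url, "What is the average order value per region?"))

# Chart questions return a saved chart path ("Chart not generated" or an error dict otherwise).
print(openai_chart(csv_url, "Plot total sales per month as a bar chart"))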
orchestrator_agent.py
CHANGED
@@ -1,4 +1,5 @@
 
+from datetime import datetime
 import os
 from typing import Dict, List, Any
 from pydantic_ai import Agent

@@ -7,10 +8,12 @@ from pydantic_ai.providers.google_gla import GoogleGLAProvider
 from pydantic_ai import RunContext
 from pydantic import BaseModel
 from google.api_core.exceptions import ResourceExhausted  # Import the exception for quota exhaustion
+from code_exec_service import run_analysis
 from csv_service import get_csv_basic_info
 from orchestrator_functions import csv_chart, csv_chat
 from dotenv import load_dotenv
 
+
 load_dotenv()
 
 

@@ -125,8 +128,9 @@ def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agen
    - Highlight limitations/caveats
 
 5. TOOL USAGE:
-   -
-   -
+   - Python Code Executor Tool (To execute Python code, To get date-time, For lightweight data analysis etc.)
+   - Data Analysis Tool
+   - Chart Generation Tool
 
 ## Current Context:
 - Working with CSV_URL: {csv_url}

@@ -141,58 +145,44 @@ def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agen
    4. Provide interpretation
    5. Offer next-step suggestions
    """
-
-    # system_prompt = (
-    #     "You are a data analyst. "
-    #     "You have all the tools you need to answer any question. "
-    #     "If the user asks for multiple answers or charts, break the question into several well-defined questions. "
-    #     "Pass the CSV URL or file path along with the questions to the tools to generate the answer. "
-    #     "The tools are actually LLMs with Python code execution capabilities. "
-    #     "Modify the query if needed to make it simpler for the LLM to understand. "
-    #     "Answer in a friendly and helpful manner. "
-    #     "**Format images** in Markdown: ``. "
-    #     f"Your CSV URL is {csv_url}. "
-    #     f"Your CSV metadata is {csv_metadata}."
-    # )
-
-
-    # system_prompt = (
-    #     "You are a data analyst assistant with limited tool capabilities. "
-    #     "Available tools can only handle simple data queries: "
-    #     "- Count rows/columns\n- Calculate basic stats (avg, sum, min/max)\n"
-    #     "- Create simple visualizations (pie charts, bar graphs)\n"
-    #     "- Show column names/types\n\n"
-
-    #     "Query Handling Rules:\n"
-    #     "1. If query is complex, ambiguous, or exceeds tool capabilities:\n"
-    #     "   - Break into simpler sub-questions\n"
-    #     "   - Ask for clarification\n"
-    #     "   - Rephrase to nearest simple query\n"
-    #     "2. For 'full report' requests:\n"
-    #     "   - Outline possible analysis steps\n"
-    #     "   - Ask user to select one component at a time\n\n"
-
-    #     "Examples:\n"
-    #     "- Bad query: 'Show me everything'\n"
-    #     "  Response: 'I can show row count (10), columns (5: Name, Age...), "
-    #     "or a pie chart of categories. Which would you like?'\n"
-    #     "- Bad query: 'Analyze trends'\n"
-    #     "  Response: 'For trend analysis, I can show monthly averages or "
-    #     "year-over-year comparisons. Please specify time period and metric.'\n\n"
-
-    #     "Current CSV Context:\n"
-    #     f"- URL: {csv_url}\n"
-    #     f"- Metadata: {csv_metadata}\n\n"
-
-    #     "Always format images as: "
-    # )
-
-    return Agent(
+    gemini_csv_orchestrator_agent = Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt
    )
+
+    @gemini_csv_orchestrator_agent.tool_plain
+    def python_code_executor(analysis_code: str) -> dict:
+        """_summary_
+
+        Args:
+            analysis_code (str): _description_
+            Ex:
+                df = pd.read_csv({csv_url})
+                len(df)
+
+        Returns:
+            dict: _description_
+        """
+
+        print(f'LLM Passed a code: {analysis_code}')
+        result = run_analysis(analysis_code)
+        if result['success']:
+            print("Execution successful")
+            print("Execution time:", result['execution_time'], "seconds")
+            print("Output:", result['output'].strip())
+            print("Result:", result['result'])
+            print("Variables:", list(result['variables'].keys()))
+            # convert the result to a string
+            result_str = str(result['output'])
+            return result_str
+        else:
+            print("Execution failed")
+            print("Error:", result['error'])
+            error_str = str(result['error'])
+            return error_str
+    return gemini_csv_orchestrator_agent
 
 def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> str:
    print("CSV URL:", csv_url)

@@ -216,3 +206,4 @@ def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history
    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None
+
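To make the new tool concrete, here is a hedged sketch of driving the orchestrator through csv_orchestrator_chat, whose signature appears in the diff above; the URL is a placeholder and the empty list stands in for a fresh conversation history.

from orchestrator_agent import csv_orchestrator_chat

csv_url = "https://example.com/data.csv"  # placeholder URL
question = "What is today's date, and how many rows does the file have?"

# The agent can now route the first part to python_code_executor and the
# second part to the data analysis / chart tools.
answer = csv_orchestrator_chat(csv_url, question, [])
print(answer)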
orchestrator_functions.py
CHANGED
@@ -1,5 +1,6 @@
 # Import necessary modules
 import asyncio
+import logging
 import os
 import threading
 import uuid

@@ -20,6 +21,7 @@ import matplotlib.pyplot as plt
 import matplotlib
 import seaborn as sns
 from gemini_langchain_agent import langchain_gemini_csv_handler
+from openai_pandasai_service import openai_chart
 from supabase_service import upload_file_to_supabase
 from util_service import _prompt_generator, process_answer
 import matplotlib

@@ -69,6 +71,10 @@ def handle_out_of_range_float(value):
         return "Infinity"
     return value
 
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 # Modified groq_chat function with thread-safe key rotation
 def groq_chat(csv_url: str, question: str):

@@ -87,7 +93,7 @@ def groq_chat(csv_url: str, question: str):
         try:
             os.remove(cache_db_path)
         except Exception as e:
-
+            logger.info(f"Error deleting cache DB file: {e}")
 
         data = clean_data(csv_url)
         llm = ChatGroq(model=model_name, api_key=current_api_key)

@@ -129,10 +135,10 @@
                 with current_groq_key_lock:
                     current_groq_key_index += 1
                     if current_groq_key_index >= len(groq_api_keys):
-
+                        logger.info("All API keys exhausted.")
                         return None
                     else:
-
+                        logger.info(f"Error with API key index {current_groq_key_index}: {error_message}")
                         return None
 
 

@@ -183,10 +189,10 @@ def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
             return result.get("output")
 
         except Exception as e:
-
+            logger.info(f"Error with key index {current_key}: {str(e)}")
 
     # If all keys are exhausted, return None
-
+    logger.info("All API keys have been exhausted.")
     return None
 
 

@@ -241,7 +247,7 @@ def groq_chart(csv_url: str, question: str):
         try:
             os.remove(cache_db_path)
         except Exception as e:
-
+            logger.info(f"Cache cleanup error: {e}")
 
         data = clean_data(csv_url)
         with current_groq_chart_lock:

@@ -277,10 +283,10 @@
                 with current_groq_chart_lock:
                     current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
             else:
-
+                logger.info(f"Chart generation error: {error}")
                 return {"error": error}
 
-
+    logger.info("All API keys exhausted for chart generation")
     return None
 
 

@@ -327,12 +333,12 @@ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
                 return chart_files
 
             if attempt < len(groq_api_keys) - 1:
-
+                logger.info(f"Langchain chart error (key {current_key}): {output}")
 
         except Exception as e:
-
+            logger.info(f"Langchain chart error (key {current_key}): {str(e)}")
 
-
+    logger.info("All API keys exhausted for chart generation")
     return None
 
 

@@ -363,50 +369,50 @@
 #         # First try Groq-based chart generation
 #         try:
 #             groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
-#
+#             logger.info(f"Generated Chart (Groq): {groq_result}")
 
 #             if groq_result != 'Chart not generated':
 #                 unique_file_name = f'{str(uuid.uuid4())}.png'
 #                 image_public_url = await upload_file_to_supabase(groq_result, unique_file_name)
-#
+#                 logger.info(f"Image uploaded to Supabase: {image_public_url}")
 #                 return {"image_url": image_public_url}
 
 #         except Exception as groq_error:
-#
+#             logger.info(f"Groq chart generation failed, falling back to Langchain: {str(groq_error)}")
 
 #             # Fallback to Langchain if Groq fails
 #             try:
 #                 langchain_paths = await asyncio.to_thread(langchain_csv_chart, csv_url, query, True)
-#
+#                 logger.info("Fallback langchain chart result:", langchain_paths)
 
 #                 if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
 #                     unique_file_name = f'{str(uuid.uuid4())}.png'
-#
+#                     logger.info("Uploading the chart to supabase...")
 #                     image_public_url = await upload_file_to_supabase(langchain_paths[0], unique_file_name)
-#
+#                     logger.info("Image uploaded to Supabase and Image URL is... ", image_public_url)
 #                     return {"image_url": image_public_url}
 
 #             except Exception as langchain_error:
-#
+#                 logger.info(f"Langchain chart generation also failed: {str(langchain_error)}")
 #                 try:
 #                     # Last resort: Try with the gemini langchain agent
-#
+#                     logger.info("Trying with the gemini langchain agent...")
 #                     lc_gemini_chart_result = await asyncio.to_thread(langchain_gemini_csv_handler, csv_url, query, True)
 #                     if lc_gemini_chart_result is not None:
 #                         clean_path = lc_gemini_chart_result.strip()
 #                         unique_file_name = f'{str(uuid.uuid4())}.png'
-#
+#                         logger.info("Uploading the chart to supabase...")
 #                         image_public_url = await upload_file_to_supabase(clean_path, unique_file_name)
-#
+#                         logger.info("Image uploaded to Supabase and Image URL is... ", image_public_url)
 #                         return {"image_url": image_public_url}
 #                 except Exception as gemini_error:
-#
+#                     logger.info(f"Gemini Langchain chart generation also failed: {str(gemini_error)}")
 
 #             # If both methods fail
 #             return {"error": "Could not generate the chart, please try again."}
 
 #     except Exception as e:
-#
+#         logger.info(f"Critical chart error: {str(e)}")
 #         return {"error": "Internal system error"}
 
 

@@ -436,7 +442,7 @@ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
 #         # Process with Groq first
 #         try:
 #             groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
-#
+#             logger.info("groq_answer:", groq_answer)
 
 #             if process_answer(groq_answer) == "Empty response received." or groq_answer == None:
 #                 return {"answer": "Sorry, I couldn't find relevant data..."}

@@ -447,7 +453,7 @@ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
 #             return {"answer": jsonable_encoder(groq_answer)}
 
 #         except Exception as groq_error:
-#
+#             logger.info(f"Groq error, falling back to LangChain: {str(groq_error)}")
 
 #             # Process with LangChain if Groq fails
 #             try:

@@ -458,7 +464,7 @@ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
 #                     return {"answer": jsonable_encoder(lang_answer)}
 #                 return {"answer": "Sorry, I couldn't find relevant data..."}
 #             except Exception as langchain_error:
-#
+#                 logger.info(f"LangChain processing error: {str(langchain_error)}")
 
 #                 # last resort: Try with the gemini langchain agent
 #                 try:

@@ -469,11 +475,11 @@ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
 #                         return {"answer": jsonable_encoder(gemini_answer)}
 #                     return {"answer": "Sorry, I couldn't find relevant data..."}
 #                 except Exception as gemini_error:
-#
+#                     logger.info(f"Gemini Langchain processing error: {str(gemini_error)}")
 #                     return {"answer": "error"}
 
 #     except Exception as e:
-#
+#         logger.info(f"Error processing request: {str(e)}")
 #         return {"answer": "error"}
 
 

@@ -511,7 +517,7 @@ async def csv_chat(csv_url: str, query: str):
             gemini_answer = await asyncio.to_thread(
                 langchain_gemini_csv_handler, csv_url, updated_query, False
             )
-
+            logger.info("LangChain-Gemini answer:", gemini_answer)
 
             if not process_answer(gemini_answer) or gemini_answer is None:
                 return {"answer": jsonable_encoder(gemini_answer)}

@@ -519,14 +525,14 @@
             raise Exception("LangChain-Gemini response not usable, falling back to LangChain-Groq")
 
         except Exception as gemini_error:
-
+            logger.info(f"LangChain-Gemini error: {str(gemini_error)}")
 
             # --- 2. Second Attempt: LangChain Groq ---
             try:
                 lang_groq_answer = await asyncio.to_thread(
                     langchain_csv_chat, csv_url, updated_query, False
                 )
-
+                logger.info("LangChain-Groq answer:", lang_groq_answer)
 
                 if not process_answer(lang_groq_answer):
                     return {"answer": jsonable_encoder(lang_groq_answer)}

@@ -534,12 +540,12 @@
                 raise Exception("LangChain-Groq response not usable, falling back to raw Groq")
 
             except Exception as lang_groq_error:
-
+                logger.info(f"LangChain-Groq error: {str(lang_groq_error)}")
 
                 # --- 3. Final Attempt: Raw Groq Chat ---
                 try:
                     raw_groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
-
+                    logger.info("Raw Groq answer:", raw_groq_answer)
 
                     if process_answer(raw_groq_answer) == "Empty response received." or raw_groq_answer is None:
                         return {"answer": "Sorry, I couldn't find relevant data..."}

@@ -550,11 +556,11 @@
                     return {"answer": jsonable_encoder(raw_groq_answer)}
 
                 except Exception as raw_groq_error:
-
+                    logger.info(f"Raw Groq error: {str(raw_groq_error)}")
                     return {"answer": "error"}
 
     except Exception as e:
-
+        logger.info(f"Unexpected error: {str(e)}")
        return {"answer": "error"}
 
 

@@ -567,7 +573,7 @@ async def csv_chat(csv_url: str, query: str):
 async def csv_chart(csv_url: str, query: str):
     """
     Generate a chart based on the provided CSV URL and query.
-    Prioritizes raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
+    Prioritizes OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
 
     Parameters:
     - csv_url (str): The URL of the CSV file.

@@ -589,61 +595,73 @@
         """Helper function to handle image uploads"""
         unique_name = f'{uuid.uuid4()}.png'
         public_url = await upload_file_to_supabase(image_path, unique_name)
-
+        logger.info(f"Uploaded chart: {public_url}")
         os.remove(image_path)  # Remove the local image file after upload
         return {"image_url": public_url}
 
     try:
-        # --- 1. First Attempt: Raw Groq ---
         try:
-
-
+            # --- 1. First Attempt: OpenAI ---
+            openai_result = await asyncio.to_thread(openai_chart, csv_url, query)
+            logger.info(f"OpenAI chart result:", openai_result)
+
+            if openai_result and openai_result != 'Chart not generated':
+                return await upload_and_return(openai_result)
+
+            raise Exception("OpenAI failed to generate chart")
+
+        except Exception as openai_error:
+            logger.info(f"OpenAI failed ({str(openai_error)}), trying raw Groq...")
+            # --- 2. Second Attempt: Raw Groq ---
+            try:
+                groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
+                logger.info(f"Raw Groq chart result:", groq_result)
 
-
-
+                if groq_result and groq_result != 'Chart not generated':
+                    return await upload_and_return(groq_result)
 
-
+                raise Exception("Raw Groq failed to generate chart")
 
-
-
+            except Exception as groq_error:
+                logger.info(f"Raw Groq failed ({str(groq_error)}), trying LangChain Gemini...")
 
-
-
-
-
-
-
+                # --- 3. Third Attempt: LangChain Gemini ---
+                try:
+                    gemini_result = await asyncio.to_thread(
+                        langchain_gemini_csv_handler, csv_url, query, True
+                    )
+                    logger.info("LangChain Gemini chart result:", gemini_result)
 
-
-
-
-
+                    # --- i) If Gemini result is a string, return it ---
+                    if gemini_result and isinstance(gemini_result, str):
+                        clean_path = gemini_result.strip()
+                        return await upload_and_return(clean_path)
 
-
-
-
+                    # --- ii) If Gemini result is a list, return the first element ---
+                    if gemini_result and isinstance(gemini_result, list) and len(gemini_result) > 0:
+                        return await upload_and_return(gemini_result[0])
 
-
+                    raise Exception("LangChain Gemini returned empty result")
 
-
-
-
-
-
-
-
-
-
+                except Exception as gemini_error:
+                    logger.info(f"LangChain Gemini failed ({str(gemini_error)}), trying LangChain Groq...")
+
+                    # --- 4. Final Attempt: LangChain Groq ---
+                    try:
+                        lc_groq_paths = await asyncio.to_thread(
+                            langchain_csv_chart, csv_url, query, True
+                        )
+                        logger.info("LangChain Groq chart result:", lc_groq_paths)
 
-
-
+                        if isinstance(lc_groq_paths, list) and lc_groq_paths:
+                            return await upload_and_return(lc_groq_paths[0])
 
-
+                        return {"error": "All chart generation methods failed"}
 
-
-
-
+                    except Exception as lc_groq_error:
+                        logger.info(f"LangChain Groq failed: {str(lc_groq_error)}")
+                        return {"error": "Could not generate chart"}
 
     except Exception as e:
-
+        logger.info(f"Critical error: {str(e)}")
         return {"error": "Internal system error"}
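A hedged, end-to-end sketch of the new fallback order (OpenAI, then raw Groq, then LangChain Gemini, then LangChain Groq) as seen from a caller; the URL is a placeholder and the relevant API keys are assumed to be configured.

import asyncio

from orchestrator_functions import csv_chart, csv_chat

async def demo():
    csv_url = "https://example.com/sales.csv"  # placeholder URL

    chart = await csv_chart(csv_url, "Plot monthly revenue as a bar chart")
    print(chart)   # {"image_url": ...} on success, {"error": ...} otherwise

    answer = await csv_chat(csv_url, "Which month had the highest revenue?")
    print(answer)  # {"answer": ...}

if __name__ == "__main__":
    asyncio.run(demo())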
util_service.py
CHANGED
@@ -1,7 +1,7 @@
 from langchain_core.prompts import ChatPromptTemplate
 import numpy as np
 
-keywords = ["i encountered","unfortunately", "unsupported", "error", "sorry", "response", "unable", "because", "too many"]
+keywords = ["i encountered", "429", "unfortunately", "unsupported", "error", "sorry", "response", "unable", "because", "too many"]
 
 def contains_keywords(text, keywords):
     return any(keyword.lower() in text.lower() for keyword in keywords)