|
import os |
|
import uuid |
|
from langchain_openai import ChatOpenAI |
|
import pandas as pd |
|
from pandasai import SmartDataframe |
|
from csv_service import clean_data |
|
from dotenv import load_dotenv |
|
from util_service import handle_out_of_range_float, process_answer |
|
|
|
load_dotenv()

# OPENAI_API_KEYS holds one or more OpenAI keys separated by commas.
# Whitespace around each key and empty entries (e.g. a trailing comma) are
# dropped, so "k1, k2," yields ["k1", "k2"].
openai_api_keys = [
    key.strip()
    for key in os.getenv("OPENAI_API_KEYS", "").split(",")
    if key.strip()
]
if not openai_api_keys:
    # Fail fast with an actionable message instead of an AttributeError on
    # None.split() (variable unset) or clients built with an empty key.
    raise RuntimeError(
        "OPENAI_API_KEYS is not set or contains no keys; "
        "expected a comma-separated list of OpenAI API keys."
    )

# Optional custom API endpoint; None when OPENAI_API_BASE is unset.
openai_api_base = os.getenv("OPENAI_API_BASE")


# One gpt-4o client per configured key, in the order the keys were listed.
llm_instances = [
    ChatOpenAI(model='gpt-4o', api_key=key, base_url=openai_api_base)
    for key in openai_api_keys
]

# Index into llm_instances of the client currently in use; updated by the
# request functions in this module when keys are rotated.
current_llm_index = 0


# Extra plotting guidance appended to chart-request prompts.
instructions = """

- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).

- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)

- Use professional and color-blind friendly palettes.

- Do not use sns.set_palette()

- Read above instructions and follow them.

"""
|
|
|
def should_rotate_key(response):
    """Tell whether a chat response signals that the current API key is spent.

    Args:
        response: Value returned by the chat call. Only ``str`` responses are
            inspected; any other type yields False.

    Returns:
        True if the text contains one of the known quota / billing markers,
        otherwise False.
    """
    if not isinstance(response, str):
        return False

    # NOTE: unlike handle_api_error, "429" is not checked here; a bare number
    # could plausibly appear in an ordinary answer.
    quota_markers = (
        "plan.rule:api_request",
        "resource limit",
        "exhausted the available",
        "update your payment method",
    )
    for marker in quota_markers:
        if marker in response:
            return True
    return False
|
|
|
def handle_api_error(error, current_index):
    """Decide whether an exception means the current API key should be abandoned.

    Args:
        error: The exception raised while using the key; only its ``str()``
            form is inspected.
        current_index: Index of the key in use, included in the log line.

    Returns:
        True (after printing a rotation notice) if the message contains a
        rate-limit / quota marker, otherwise False.
    """
    message = str(error)

    # NOTE(review): "429" is a plain substring test, so any error text that
    # happens to contain those digits also triggers rotation.
    limit_markers = (
        "plan.rule:api_request",
        "429",
        "resource limit",
        "exhausted the available",
        "update your payment method",
    )
    hit_limit = any(marker in message for marker in limit_markers)
    if hit_limit:
        print(f"Rotating API key due to resource limit (key index {current_index})")
    return hit_limit
|
|
|
# Directory passed to PandasAI as ``save_charts_path`` for generated charts.
_CHARTS_DIR = "generated_charts"

# Error returned when every configured key has been tried without success.
_KEYS_EXHAUSTED_ERROR = "All API keys exhausted. Please update billing information."


def _chat_with_key_rotation(csv_url: str, prompt: str, error_prefix: str):
    """Run ``prompt`` against the CSV at ``csv_url``, rotating API keys as needed.

    Each client in ``llm_instances`` is tried at most once per call, starting
    at ``current_llm_index`` and wrapping around. A key is skipped when the
    attempt raises an error that ``handle_api_error`` classifies as a quota or
    rate-limit problem, or when the answer is flagged by ``should_rotate_key``
    or ``process_answer``. On success ``current_llm_index`` is set to the key
    that worked. A temporarily limited key is therefore retried on later
    calls instead of being abandoned for the rest of the process lifetime.

    Args:
        csv_url: Location handed to ``clean_data`` to load the table.
        prompt: Full text sent to ``SmartDataframe.chat``.
        error_prefix: Prefix of the message returned on a non-recoverable error.

    Returns:
        The answer from ``SmartDataframe.chat`` on success, otherwise a dict
        ``{"error": <message>}``.
    """
    global current_llm_index

    key_count = len(llm_instances)
    if key_count == 0:
        return {"error": _KEYS_EXHAUSTED_ERROR}

    # The table does not depend on which key is used, so load it once rather
    # than once per attempt. A failure here is not a key problem, so it is
    # reported directly instead of triggering key rotation.
    try:
        data = clean_data(csv_url)
    except Exception as e:
        return {"error": f"{error_prefix}: {e}"}

    start = current_llm_index % key_count
    for offset in range(key_count):
        index = (start + offset) % key_count
        chart_filename = f"chart_{uuid.uuid4()}.png"
        try:
            df = SmartDataframe(
                data,
                config={
                    'llm': llm_instances[index],
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': _CHARTS_DIR,
                    'custom_chart_filename': chart_filename,
                },
            )
            answer = df.chat(prompt)
            rotate = should_rotate_key(answer) or process_answer(answer)
        except Exception as e:
            if handle_api_error(e, index):
                continue
            return {"error": f"{error_prefix}: {e}"}

        if rotate:
            continue

        # Remember the working key so the next request tries it first.
        current_llm_index = index
        return answer

    return {"error": _KEYS_EXHAUSTED_ERROR}


def openai_chat(csv_url: str, question: str):
    """Answer ``question`` about the CSV at ``csv_url``.

    Returns:
        The PandasAI answer, or ``{"error": ...}`` when the request fails with
        a non-key error ("Non-recoverable error: ...") or every API key is
        exhausted.
    """
    return _chat_with_key_rotation(csv_url, question, "Non-recoverable error")


def openai_chart(csv_url: str, question: str):
    """Answer a charting ``question`` about the CSV at ``csv_url``.

    The module-level ``instructions`` text is appended to the question to
    steer chart layout and styling.

    Returns:
        The PandasAI answer, or ``{"error": ...}`` when the request fails with
        a non-key error ("Chart generation failed: ...") or every API key is
        exhausted.
    """
    return _chat_with_key_rotation(
        csv_url, question + instructions, "Chart generation failed"
    )