|
|
|
import asyncio |
|
import logging |
|
import os |
|
import threading |
|
import uuid |
|
from fastapi.encoders import jsonable_encoder |
|
import numpy as np |
|
import pandas as pd |
|
from pandasai import SmartDataframe |
|
from langchain_groq.chat_models import ChatGroq |
|
from dotenv import load_dotenv |
|
from pydantic import BaseModel |
|
from csv_service import clean_data, extract_chart_filenames |
|
from langchain_groq import ChatGroq |
|
import pandas as pd |
|
from langchain_experimental.tools import PythonAstREPLTool |
|
from langchain_experimental.agents import create_pandas_dataframe_agent |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
import seaborn as sns |
|
from gemini_langchain_agent import langchain_gemini_csv_handler |
|
from openai_pandasai_service import openai_chart |
|
from supabase_service import upload_file_to_supabase |
|
from util_service import _prompt_generator, process_answer |
|
import matplotlib |
|
# Use the non-interactive Agg backend so charts can be rendered to files
# without a display.
matplotlib.use('Agg')

# Pull settings from a local .env file (if any) into the process environment.
load_dotenv()

image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")

# Comma-separated list of Groq API keys used for key rotation.  A missing
# variable yields an empty list (instead of crashing on ``None.split``) and
# blank / whitespace-only entries are dropped so they are never tried as keys.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
model_name = os.getenv("GROQ_LLM_MODEL")
|
|
|
class CsvUrlRequest(BaseModel):
    """Request model with a single string field, ``csv_url``."""

    # Location of the CSV file (presumably an HTTP(S) URL -- confirm with callers).
    csv_url: str
|
|
|
class ImageRequest(BaseModel):
    """Request model with a single string field, ``image_path``."""

    # Path of the image being requested (format not visible here -- confirm with callers).
    image_path: str
|
|
|
class CsvCommonHeadersRequest(BaseModel):
    """Request model carrying a list of file URLs in ``file_urls``."""

    # One entry per CSV file (presumably the files whose headers are compared).
    file_urls: list[str]
|
|
|
class CsvsMergeRequest(BaseModel):
    """Request model describing a merge of several CSV files."""

    # URLs of the CSV files to merge.
    file_urls: list[str]
    # Merge strategy name (accepted values are not visible here -- confirm with the handler).
    merge_type: str
    # Column names the merge is keyed on.
    common_columns_name: list[str]
|
|
|
|
|
# Shared key-rotation state for groq_chat: position of the Groq API key
# currently in use, and the lock guarding reads/updates of that position.
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

# Same pair of values, kept separately for langchain_csv_chat.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()
|
|
|
|
|
|
|
def handle_out_of_range_float(value):
    """Replace float values that JSON cannot represent.

    Parameters:
    - value: Any object.  Only Python floats and NumPy floating scalars are
      inspected; everything else is returned unchanged.

    Returns:
    - None for NaN,
    - "Infinity" for positive infinity,
    - "-Infinity" for negative infinity (the sign is preserved),
    - the original value otherwise.
    """
    # np.floating covers scalars such as np.float32, which do not subclass float.
    if not isinstance(value, (float, np.floating)):
        return value
    if np.isnan(value):
        return None
    if np.isinf(value):
        return "Infinity" if value > 0 else "-Infinity"
    return value
|
|
|
|
|
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
def _sanitize_groq_answer(answer):
    """Turn a SmartDataframe answer into plain containers with NaN/inf scrubbed.

    DataFrames become a list of row dicts, Series become a dict, lists and
    dicts keep their shape, and anything else is wrapped as
    ``{"answer": "<str of value>"}``.  Every individual value is passed
    through ``handle_out_of_range_float``.
    """
    if isinstance(answer, pd.DataFrame):
        # DataFrame.apply hands whole columns (Series) to the callback, so the
        # scrubber never saw individual cells; export first, then scrub cells.
        return [
            {column: handle_out_of_range_float(cell) for column, cell in row.items()}
            for row in answer.to_dict(orient="records")
        ]
    if isinstance(answer, pd.Series):
        # Scrub after export so pandas cannot coerce None back to NaN.
        return {
            label: handle_out_of_range_float(cell)
            for label, cell in answer.to_dict().items()
        }
    if isinstance(answer, list):
        return [handle_out_of_range_float(item) for item in answer]
    if isinstance(answer, dict):
        return {key: handle_out_of_range_float(item) for key, item in answer.items()}
    return {"answer": str(handle_out_of_range_float(answer))}


def groq_chat(csv_url: str, question: str):
    """Answer *question* about a CSV using a PandasAI SmartDataframe on Groq.

    The Groq API key at the shared rotation index is used.  When a call fails
    with an error whose text contains "429", the rotation index is advanced
    and the request is retried with the next key.

    Parameters:
    - csv_url (str): Location of the CSV, passed to ``clean_data``.
    - question (str): The question forwarded to ``SmartDataframe.chat``.

    Returns:
    - The sanitized answer (list of row dicts, dict, list, or
      ``{"answer": str}``) on success.
    - ``{"error": "All API keys exhausted."}`` when no key is left on entry.
    - None when a non-429 error occurs or the last key hits a 429.
    """
    global current_groq_key_index

    while True:
        with current_groq_key_lock:
            if current_groq_key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            # Remember which key this attempt used so that only one of several
            # concurrent 429s on the same key advances the shared index.
            key_index = current_groq_key_index
            current_api_key = groq_api_keys[key_index]

        try:
            # Presumably the PandasAI response cache; it is removed before
            # every attempt so stale answers are not reused -- TODO confirm.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            try:
                os.remove(cache_db_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.info(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )

            return _sanitize_groq_answer(df.chat(question))

        except Exception as e:
            error_message = str(e)
            if "429" not in error_message:
                logger.info(f"Error with API key index {key_index}: {error_message}")
                return None

            # Rate limited: move on to the next key unless another thread
            # already rotated past the key this attempt used.
            with current_groq_key_lock:
                if current_groq_key_index == key_index:
                    current_groq_key_index += 1
                if current_groq_key_index >= len(groq_api_keys):
                    logger.info("All API keys exhausted.")
                    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer *question* about a CSV with a LangChain pandas agent on Groq.

    Each attempt takes the next Groq API key from the shared rotation index
    (wrapping to 0 at the end of the list).  Any exception is logged and the
    next key is tried, for at most ``len(groq_api_keys)`` attempts.

    Parameters:
    - csv_url (str): Location of the CSV, passed to ``clean_data``.
    - question (str): The user question.
    - chart_required (bool): Forwarded to ``_prompt_generator``.

    Returns:
    - The agent's "output" value, or None once every attempt has failed.
    """
    global current_langchain_key_index

    frame = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Claim a key position and advance the shared index for the next caller.
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            key_position = current_langchain_key_index
            current_langchain_key_index = key_position + 1
        selected_key = groq_api_keys[key_position]

        try:
            groq_llm = ChatGroq(model=model_name, api_key=selected_key)

            # Names made available to the code the agent executes.
            repl_namespace = {
                "df": frame,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            }
            repl_tool = PythonAstREPLTool(locals=repl_namespace)

            agent = create_pandas_dataframe_agent(
                groq_llm,
                frame,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[repl_tool],
                return_intermediate_steps=True,
            )

            response = agent.invoke({"input": _prompt_generator(question, chart_required)})
            return response.get("output")

        except Exception as e:
            logger.info(f"Error with key index {key_position}: {str(e)}")

    logger.info("All API keys have been exhausted.")
    return None
|
|
|
|
|
def handle_out_of_range_float(value):
    """Replace float values that JSON cannot represent.

    NOTE: this redefines the function of the same name declared earlier in
    this module; being the later definition, it is the one in effect at
    runtime.

    Parameters:
    - value: Any object.  Only Python floats and NumPy floating scalars are
      inspected; everything else is returned unchanged.

    Returns:
    - None for NaN,
    - "Infinity" for positive infinity,
    - "-Infinity" for negative infinity (the sign is preserved),
    - the original value otherwise.
    """
    # np.floating covers scalars such as np.float32, which do not subclass float.
    if not isinstance(value, (float, np.floating)):
        return value
    if np.isnan(value):
        return None
    if np.isinf(value):
        return "Infinity" if value > 0 else "-Infinity"
    return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Extra chart-formatting guidance; groq_chart appends this text to the user's
# question before sending it to the LLM.
instructions = """

- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
- Use colorblind-friendly palette
- Read above instructions and follow them.

"""
|
|
|
|
|
# Key-rotation state for the chart helpers that talk to Groq directly
# (model / groq_chart): current key position and the lock guarding it.
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()

# Same pair of values, kept separately for langchain_csv_chart.
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()
|
|
|
def model():
    """Build a ChatGroq client using the chart key currently selected.

    Returns:
    - ChatGroq: client configured with ``model_name`` and the key found at
      ``current_groq_chart_key_index``.

    Raises:
    - Exception: when the index points past the end of ``groq_api_keys``.
    """
    with current_groq_chart_lock:
        position = current_groq_chart_key_index
        if position >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        selected_key = groq_api_keys[position]
    return ChatGroq(model=model_name, api_key=selected_key)
|
|
|
def groq_chart(csv_url: str, question: str):
    """Generate a chart for *question* with a PandasAI SmartDataframe on Groq.

    Makes at most one attempt per configured Groq key.  An error whose text
    contains "429" rotates the shared key index (wrapping around) and retries;
    any other error ends the call immediately.

    Parameters:
    - csv_url (str): Location of the CSV, passed to ``clean_data``.
    - question (str): Chart request; the module-level ``instructions`` text
      is appended before it is sent to the model.

    Returns:
    - The value returned by ``SmartDataframe.chat`` on success.
    - "Chart not generated" when ``process_answer`` flags the answer.
    - ``{"error": <message>}`` on a non-429 error.
    - None when every attempt was rate limited (or no keys are configured).
    """
    global current_groq_chart_key_index

    for attempt in range(len(groq_api_keys)):
        # Set only once this attempt has picked a key, so a failure before
        # that point never rotates on behalf of a key it did not use.
        key_index = None
        try:
            # Presumably the PandasAI response cache; it is removed before
            # every attempt so stale answers are not reused -- TODO confirm.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            try:
                os.remove(cache_db_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.info(f"Cache cleanup error: {e}")

            data = clean_data(csv_url)
            with current_groq_chart_lock:
                key_index = current_groq_chart_key_index
                current_api_key = groq_api_keys[key_index]

            llm = ChatGroq(model=model_name, api_key=current_api_key)

            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            if "429" not in error:
                logger.info(f"Chart generation error: {error}")
                return {"error": error}

            # Rate limited: rotate to the next key, but only if no other
            # thread has already rotated away from the key this attempt used.
            # Otherwise concurrent 429s would skip keys or wrap back onto the
            # exhausted one.
            with current_groq_chart_lock:
                if key_index is None or current_groq_chart_key_index == key_index:
                    current_groq_chart_key_index = (
                        current_groq_chart_key_index + 1
                    ) % len(groq_api_keys)

    logger.info("All API keys exhausted for chart generation")
    return None
|
|
|
|
|
|
|
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate chart files for *question* with a LangChain pandas agent on Groq.

    Makes one attempt per configured Groq key, taking the key at the shared
    rotation index and advancing that index (with wrap-around) each time.

    Parameters:
    - csv_url (str): Location of the CSV; loaded via ``clean_data`` and also
      embedded in the prompt text.
    - question (str): The chart request.
    - chart_required (bool): Accepted for signature parity; the prompt is
      always generated with the chart flag set to True.

    Returns:
    - The non-empty result of ``extract_chart_filenames`` on the agent output,
      or None when no attempt produced chart files.
    """
    global current_langchain_chart_key_index

    frame = clean_data(csv_url)
    key_count = len(groq_api_keys)

    for attempt in range(key_count):
        try:
            with current_langchain_chart_lock:
                current_key = current_langchain_chart_key_index
                api_key = groq_api_keys[current_key]
                current_langchain_chart_key_index = (current_key + 1) % len(groq_api_keys)

            groq_llm = ChatGroq(model=model_name, api_key=api_key)

            # Names made available to the code the agent executes.
            repl_namespace = {
                "df": frame,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            }
            repl_tool = PythonAstREPLTool(locals=repl_namespace)

            agent = create_pandas_dataframe_agent(
                groq_llm,
                frame,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[repl_tool],
                return_intermediate_steps=True,
            )

            prompt = _prompt_generator(
                f"{question} and use this csv_url: {csv_url} to read the csv file", True
            )
            output = agent.invoke({"input": prompt}).get("output", "")

            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return chart_files

            # No chart file names in the output: log it unless this was the last attempt.
            if attempt < key_count - 1:
                logger.info(f"Langchain chart error (key {current_key}): {output}")

        except Exception as e:
            logger.info(f"Langchain chart error (key {current_key}): {str(e)}")

    logger.info("All API keys exhausted for chart generation")
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def csv_chat(csv_url: str, query: str):
    """
    Generate a response based on the provided CSV URL and query.
    Prioritizes LangChain-Groq, then raw Groq, and finally LangChain-Gemini as fallback.

    Each provider runs in a worker thread via ``asyncio.to_thread``.  A provider
    that raises or returns an unusable answer is logged and the next one is tried.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query for generating the response.

    Returns:
    - dict: ``{"answer": ...}`` holding the JSON-encodable answer of the first
      provider that succeeds, ``{"answer": "Sorry, I couldn't find relevant data..."}``
      when every provider fails, or ``{"answer": "error"}`` on an unexpected error.

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "What is the total sales for the year 2022?"
    Returns:
    - dict: {"answer": "The total sales for 2022 is $100,000."}
    """
    try:
        updated_query = f"{query} and Do not show any charts or graphs."

        # 1) LangChain agent on Groq.
        try:
            lang_groq_answer = await asyncio.to_thread(
                langchain_csv_chat, csv_url, updated_query, False
            )
            # %s placeholder: passing the value as a bare extra argument makes
            # logging fail to format the record and drop the value.
            logger.info("LangChain-Groq answer: %s", lang_groq_answer)

            if lang_groq_answer is not None:
                return {"answer": jsonable_encoder(lang_groq_answer)}
            failure = "LangChain-Groq response not usable, falling back to raw Groq"
        except Exception as lang_groq_error:
            failure = str(lang_groq_error)
        logger.info(f"LangChain-Groq error: {failure}")

        # 2) PandasAI on Groq.
        try:
            raw_groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
            logger.info("Raw Groq answer: %s", raw_groq_answer)

            # Check for None first so process_answer is never called on it; a
            # truthy process_answer result marks the answer as unusable.
            if raw_groq_answer is not None and not process_answer(raw_groq_answer):
                return {"answer": jsonable_encoder(raw_groq_answer)}
            failure = "Raw Groq response not usable, falling back to LangChain-Gemini"
        except Exception as raw_groq_error:
            failure = str(raw_groq_error)
        logger.info(f"Raw Groq error: {failure}")

        # 3) LangChain agent on Gemini (last resort).
        try:
            gemini_answer = await asyncio.to_thread(
                langchain_gemini_csv_handler, csv_url, updated_query, False
            )
            logger.info("LangChain-Gemini answer: %s", gemini_answer)

            if gemini_answer is not None:
                return {"answer": jsonable_encoder(gemini_answer)}
            failure = "All fallbacks exhausted"
        except Exception as gemini_error:
            failure = str(gemini_error)
        logger.info(f"LangChain-Gemini error: {failure}")
        return {"answer": "Sorry, I couldn't find relevant data..."}

    except Exception as e:
        logger.info(f"Unexpected error: {str(e)}")
        return {"answer": "error"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def csv_chart(csv_url: str, query: str, chat_id: str):
    """
    Generate a chart based on the provided CSV URL and query.
    Prioritizes OpenAI, then raw Groq, then LangChain Gemini, and finally LangChain Groq as fallback.

    Each provider runs in a worker thread via ``asyncio.to_thread``.  The first
    chart file produced is uploaded with ``upload_file_to_supabase`` and the
    local file is removed.  A provider that raises, returns nothing usable, or
    whose upload fails is logged and the next one is tried.

    Parameters:
    - csv_url (str): The URL of the CSV file.
    - query (str): The query for generating the chart.
    - chat_id (str): Identifier forwarded to ``upload_file_to_supabase``.

    Returns:
    - dict: A dictionary containing either:
      - {"image_url": "https://example.com/chart.png"} on success, or
      - {"error": "error message"} on failure

    Example:
    - csv_url: "https://example.com/data.csv"
    - query: "Show sales trends as a line chart"
    Returns:
    - dict: {"image_url": "https://storage.example.com/chart_uuid.png"}
    """

    async def upload_and_return(image_path: str, chat_id: str) -> dict:
        """Upload the chart under a fresh UUID name, delete the local file, return its URL."""
        unique_name = f'{uuid.uuid4()}.png'
        public_url = await upload_file_to_supabase(image_path, unique_name, chat_id)
        logger.info(f"Uploaded chart: {public_url}")
        os.remove(image_path)
        return {"image_url": public_url}

    try:
        # 1) OpenAI.
        try:
            openai_result = await asyncio.to_thread(openai_chart, csv_url, query)
            # %s placeholder: passing the value as a bare extra argument makes
            # logging fail to format the record and drop the value.
            logger.info("OpenAI chart result: %s", openai_result)

            if openai_result and openai_result != 'Chart not generated':
                return await upload_and_return(openai_result, chat_id)
            failure = "OpenAI failed to generate chart"
        except Exception as openai_error:
            failure = str(openai_error)
        logger.info(f"OpenAI failed ({failure}), trying raw Groq...")

        # 2) PandasAI on Groq.
        try:
            groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
            logger.info("Raw Groq chart result: %s", groq_result)

            # groq_chart reports non-rate-limit failures as an {"error": ...}
            # dict; that is not a file path and must not reach the uploader.
            if (
                groq_result
                and groq_result != 'Chart not generated'
                and not isinstance(groq_result, dict)
            ):
                return await upload_and_return(groq_result, chat_id)
            failure = "Raw Groq failed to generate chart"
        except Exception as groq_error:
            failure = str(groq_error)
        logger.info(f"Raw Groq failed ({failure}), trying LangChain Gemini...")

        # 3) LangChain agent on Gemini; the result may be one path or a list of paths.
        try:
            gemini_result = await asyncio.to_thread(
                langchain_gemini_csv_handler, csv_url, query, True
            )
            logger.info("LangChain Gemini chart result: %s", gemini_result)

            if gemini_result and isinstance(gemini_result, str):
                return await upload_and_return(gemini_result.strip(), chat_id)
            if gemini_result and isinstance(gemini_result, list):
                return await upload_and_return(gemini_result[0], chat_id)
            failure = "LangChain Gemini returned empty result"
        except Exception as gemini_error:
            failure = str(gemini_error)
        logger.info(f"LangChain Gemini failed ({failure}), trying LangChain Groq...")

        # 4) LangChain agent on Groq (last resort).
        try:
            lc_groq_paths = await asyncio.to_thread(
                langchain_csv_chart, csv_url, query, True
            )
            logger.info("LangChain Groq chart result: %s", lc_groq_paths)

            if isinstance(lc_groq_paths, list) and lc_groq_paths:
                return await upload_and_return(lc_groq_paths[0], chat_id)
            return {"error": "All chart generation methods failed"}
        except Exception as lc_groq_error:
            logger.info(f"LangChain Groq failed: {str(lc_groq_error)}")
            return {"error": "Could not generate chart"}

    except Exception as e:
        logger.info(f"Critical error: {str(e)}")
        return {"error": "Internal system error"}