File size: 3,367 Bytes
30e7daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import threading
import uuid
from dotenv import load_dotenv
from langchain_groq import ChatGroq
import pandas as pd
from pandasai import SmartDataframe
import numpy as np
import logging
from csv_service import clean_data
from util_service import handle_out_of_range_float

load_dotenv()

# Key-rotation state shared by every groq_chat call in this process.
# current_groq_key_index selects the entry of groq_api_keys currently in use;
# groq_chat only reads/advances it while holding current_groq_key_lock.
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list. Whitespace around each key is
# stripped and empty entries (e.g. from a trailing comma) are dropped, so
# "k1, k2," yields ["k1", "k2"] rather than ["k1", " k2", ""].
_raw_groq_api_keys = os.getenv("GROQ_API_KEYS")
if _raw_groq_api_keys is None:
    # Fail fast at import with a clear message instead of an opaque
    # "'NoneType' object has no attribute 'split'".
    raise RuntimeError("GROQ_API_KEYS environment variable is not set")
groq_api_keys = [key.strip() for key in _raw_groq_api_keys.split(",") if key.strip()]
# NOTE(review): if GROQ_LLM_MODEL is unset this is None and is passed to
# ChatGroq(model=None) -- confirm whether a default model is intended.
model_name = os.getenv("GROQ_LLM_MODEL")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# pandasai cache database removed before every query.
# NOTE(review): presumably deleted so pandasai does not replay a cached answer
# from an earlier request; the file name embeds a version ("0.11") -- confirm it
# matches the installed pandasai.
_PANDASAI_CACHE_DB_PATH = "/workspace/cache/cache_db_0.11.db"


def _delete_pandasai_cache() -> None:
    """Remove the pandasai cache DB file if it exists; log, don't raise, on failure."""
    try:
        os.remove(_PANDASAI_CACHE_DB_PATH)
    except FileNotFoundError:
        # Nothing to delete (or a concurrent call already removed it).
        pass
    except OSError as e:
        logger.warning("Error deleting cache DB file: %s", e)


def _normalize_answer(answer):
    """Convert a ``SmartDataframe.chat`` result into plain Python data.

    - DataFrame -> list of row dicts (``orient="records"``)
    - Series    -> dict keyed by the Series index
    - list      -> list with each item passed through handle_out_of_range_float
    - dict      -> dict with each value passed through handle_out_of_range_float
    - anything else -> ``{"answer": str(handle_out_of_range_float(answer))}``
    """
    if isinstance(answer, pd.DataFrame):
        # NOTE(review): DataFrame.apply hands each *column* (a Series) to
        # handle_out_of_range_float, not individual cells. Confirm the helper
        # accepts a Series; otherwise an element-wise map is required.
        return answer.apply(handle_out_of_range_float).to_dict(orient="records")
    if isinstance(answer, pd.Series):
        return answer.apply(handle_out_of_range_float).to_dict()
    if isinstance(answer, list):
        return [handle_out_of_range_float(item) for item in answer]
    if isinstance(answer, dict):
        return {k: handle_out_of_range_float(v) for k, v in answer.items()}
    return {"answer": str(handle_out_of_range_float(answer))}


def groq_chat(csv_url: str, question: str):
    """Answer *question* about the CSV at *csv_url* with pandasai backed by Groq.

    Keys in ``groq_api_keys`` are used in order via the shared
    ``current_groq_key_index``. When a request fails with an error whose message
    contains "429", that key is retired (the shared index is advanced) and the
    question is retried with the next key.

    Args:
        csv_url: Location of the CSV; passed to ``clean_data``.
        question: Natural-language question passed to ``SmartDataframe.chat``.

    Returns:
        The normalized answer (see ``_normalize_answer``) on success;
        ``{"error": "All API keys exhausted."}`` if no key is left when an
        attempt starts; ``None`` if the last key is exhausted during this call
        or if any non-429 error occurs.
    """
    global current_groq_key_index

    while True:
        # Snapshot the key index under the lock so the 429 handler below can
        # tell whether another thread has already rotated past this key.
        with current_groq_key_lock:
            key_index = current_groq_key_index
            if key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            current_api_key = groq_api_keys[key_index]

        try:
            _delete_pandasai_cache()

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # A UUID-based filename keeps charts from concurrent requests apart.
            charts_dir = "generated_charts"
            chart_filename = f"chart_{uuid.uuid4()}.png"

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': charts_dir,
                    'custom_chart_filename': chart_filename,
                },
            )

            answer = df.chat(question)
            return _normalize_answer(answer)

        except Exception as e:
            error_message = str(e)
            # NOTE(review): rate limiting is detected by a plain substring match
            # on "429" in the exception text.
            if "429" not in error_message:
                logger.error("Error with API key index %s: %s", key_index, error_message)
                return None

            with current_groq_key_lock:
                # Advance only if the index still points at the key that just
                # failed; otherwise several threads hitting 429 on the same key
                # would each increment and skip keys that are still valid.
                if current_groq_key_index == key_index:
                    current_groq_key_index += 1
                exhausted = current_groq_key_index >= len(groq_api_keys)

            if exhausted:
                logger.warning("All API keys exhausted.")
                return None
            # Otherwise loop and retry with the next key.