File size: 3,647 Bytes
30e7daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbda383
30e7daa
 
2ccbdb1
30e7daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from util_service import process_answer
import os
import threading
import uuid
from dotenv import load_dotenv
from langchain_groq import ChatGroq
import pandas as pd
from pandasai import SmartDataframe
import numpy as np
import logging
from csv_service import clean_data
from util_service import handle_out_of_range_float

# Set up logging first so configuration problems below can be reported.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# python-dotenv: load a .env file (if any) before reading settings below.
load_dotenv()

# Thread-safe key management for langchain_csv_chat
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

# Load environment variables.
# GROQ_API_KEYS is a comma-separated list of keys. Whitespace around each key
# and empty entries (e.g. from "key1, key2,") are dropped so they are never
# sent to the API as invalid keys. A missing variable yields an empty list
# instead of crashing at import time with AttributeError on None.
groq_api_keys = [
    key.strip() for key in os.getenv("GROQ_API_KEYS", "").split(",") if key.strip()
]
if not groq_api_keys:
    logger.warning("GROQ_API_KEYS is not set or contains no keys; Groq calls will fail")
model_name = os.getenv("GROQ_LLM_MODEL")


# Extra guidance appended to every chart question before it is sent to the LLM.
# Built from adjacent literals; the leading blank lines separate it from the
# user's question and the trailing blank line closes the prompt.
instructions = (
    "\n"
    "\n"
    "- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).\n"
    "- For multiple charts, put all of them in a single file.\n"
    "- Use colorblind-friendly palette\n"
    "- Read above instructions and follow them.\n"
    "- Please do not use any visualization library other than matplotlib or seaborn.\n"
    "\n"
)

# Thread-safe configuration for chart endpoints:
# index into groq_api_keys of the key currently used for charts, guarded by
# the lock below.
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()

def model():
    """Return a ChatGroq client bound to the currently selected chart API key.

    The shared ``current_groq_chart_key_index`` is read under
    ``current_groq_chart_lock``; the client is built outside the lock.

    Raises:
        Exception: if the shared index is not a valid position in
            ``groq_api_keys`` (e.g. when no keys are configured).
    """
    with current_groq_chart_lock:
        index = current_groq_chart_key_index
        if index < len(groq_api_keys):
            selected_key = groq_api_keys[index]
        else:
            raise Exception("All API keys exhausted for chart generation")
    return ChatGroq(model=model_name, api_key=selected_key)

def _clear_pandasai_cache(cache_db_path="/workspace/cache/cache_db_0.11.db"):
    """Best-effort removal of the on-disk cache database at *cache_db_path*.

    A missing file is ignored; any other OS error is logged and swallowed so
    chart generation can still proceed.
    """
    try:
        os.remove(cache_db_path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.info(f"Cache cleanup error: {e}")


def groq_chart(csv_url: str, question: str):
    """Ask the LLM to draw a chart answering *question* over the CSV at *csv_url*.

    Keys from ``groq_api_keys`` are tried starting at the shared
    ``current_groq_chart_key_index``. An error whose message contains "429"
    rotates the shared index to the next key and retries, with at most one
    attempt per configured key.

    Returns:
        The result of ``SmartDataframe.chat`` on success;
        ``"Chart not generated"`` when ``process_answer`` flags the answer;
        ``{"error": <message>}`` for any error not containing "429";
        ``None`` when every attempt was rate limited (or no keys are configured).
    """
    global current_groq_chart_key_index

    for _ in range(len(groq_api_keys)):
        # Snapshot which key this attempt uses, so a 429 rotates away from
        # *this* key only, even while other threads rotate concurrently.
        with current_groq_chart_lock:
            key_index = current_groq_chart_key_index
            current_api_key = groq_api_keys[key_index]

        try:
            # Clean cache before processing
            _clear_pandasai_cache()

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_dir = "generated_charts"
            # Ensure the save directory exists before a chart is written to it.
            os.makedirs(chart_dir, exist_ok=True)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    "llm": llm,
                    "save_charts": True,  # Enable chart saving
                    "open_charts": False,
                    "save_charts_path": chart_dir,  # Directory to save
                    "custom_chart_filename": chart_filename,  # Unique filename
                },
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            if "429" not in error:
                logger.error(f"Chart generation error: {error}")
                return {"error": error}

            logger.warning(
                f"Chart API key #{key_index} was rate limited (429); rotating to the next key"
            )
            with current_groq_chart_lock:
                # Advance only if no other thread has already moved past this
                # key; otherwise two concurrent 429s on the same key would
                # advance twice and skip a healthy key.
                if current_groq_chart_key_index == key_index:
                    current_groq_chart_key_index = (key_index + 1) % len(groq_api_keys)

    logger.warning("All API keys exhausted for chart generation")
    return None