Lets go
Browse files- controller.py +27 -10
controller.py
CHANGED
@@ -17,7 +17,7 @@ from dotenv import load_dotenv
|
|
17 |
from pydantic import BaseModel
|
18 |
from csv_service import clean_data, extract_chart_filenames
|
19 |
from urllib.parse import unquote
|
20 |
-
import
|
21 |
from langchain_groq import ChatGroq
|
22 |
import pandas as pd
|
23 |
from langchain_experimental.tools import PythonAstREPLTool
|
@@ -43,6 +43,17 @@ logger = logging.getLogger(__name__)
|
|
43 |
max_cpus = os.cpu_count()
|
44 |
logger.info(f"Max CPUs: {max_cpus}")
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
# Ensure the cache directory exists
|
47 |
os.makedirs("/app/cache", exist_ok=True)
|
48 |
|
@@ -99,7 +110,13 @@ async def basic_csv_data(request: CsvUrlRequest):
|
|
99 |
try:
|
100 |
decoded_url = unquote(request.csv_url)
|
101 |
logger.info(f"Fetching CSV data from URL: {decoded_url}")
|
102 |
-
csv_data =
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
logger.info(f"CSV data fetched successfully: {csv_data}")
|
104 |
return {"data": csv_data}
|
105 |
except Exception as e:
|
@@ -136,7 +153,11 @@ async def get_csv_data(request: CsvUrlRequest):
|
|
136 |
try:
|
137 |
decoded_url = unquote(request.csv_url)
|
138 |
logger.info(f"Fetching CSV data from URL: {decoded_url}")
|
139 |
-
csv_data =
|
|
|
|
|
|
|
|
|
140 |
return csv_data
|
141 |
except Exception as e:
|
142 |
logger.error(f"Error while fetching CSV data: {e}")
|
@@ -585,13 +606,8 @@ def groq_chart(csv_url: str, question: str):
|
|
585 |
# Global locks for key rotation (chart endpoints)
|
586 |
# current_groq_chart_key_index = 0
|
587 |
# current_groq_chart_lock = threading.Lock()
|
588 |
-
current_langchain_chart_key_index = 0
|
589 |
-
current_langchain_chart_lock = threading.Lock()
|
590 |
|
591 |
|
592 |
-
# Use a process pool to run CPU-bound chart generation
|
593 |
-
process_executor = ProcessPoolExecutor(max_workers=10)
|
594 |
-
|
595 |
# --- GROQ-BASED CHART GENERATION ---
|
596 |
# def groq_chart(csv_url: str, question: str):
|
597 |
# """
|
@@ -753,8 +769,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
753 |
try:
|
754 |
query = request.get("query", "")
|
755 |
csv_url = unquote(request.get("csv_url", ""))
|
756 |
-
|
757 |
-
loop = asyncio.get_running_loop()
|
758 |
# First, try the langchain-based method if the question qualifies
|
759 |
if if_initial_chart_question(query):
|
760 |
langchain_result = await loop.run_in_executor(
|
@@ -773,6 +788,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
773 |
return FileResponse(groq_result, media_type="image/png")
|
774 |
|
775 |
# Fallback: try langchain-based again
|
|
|
776 |
langchain_paths = await loop.run_in_executor(
|
777 |
process_executor, langchain_csv_chart, csv_url, query, True
|
778 |
)
|
@@ -780,6 +796,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
780 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
781 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
782 |
else:
|
|
|
783 |
return {"error": "All chart generation methods failed"}
|
784 |
|
785 |
except Exception as e:
|
|
|
17 |
from pydantic import BaseModel
|
18 |
from csv_service import clean_data, extract_chart_filenames
|
19 |
from urllib.parse import unquote
|
20 |
+
from csv_service import generate_csv_data, get_csv_basic_info, get_image_by_file_name
|
21 |
from langchain_groq import ChatGroq
|
22 |
import pandas as pd
|
23 |
from langchain_experimental.tools import PythonAstREPLTool
|
|
|
43 |
max_cpus = os.cpu_count()
|
44 |
logger.info(f"Max CPUs: {max_cpus}")
|
45 |
|
46 |
+
|
47 |
+
# Thread-safe configuration for chart endpoints
|
48 |
+
current_langchain_chart_key_index = 0
|
49 |
+
current_langchain_chart_lock = threading.Lock()
|
50 |
+
|
51 |
+
# running loop for asyncio
|
52 |
+
loop = asyncio.get_running_loop()
|
53 |
+
|
54 |
+
# Use a process pool to run CPU-bound chart generation
|
55 |
+
process_executor = ProcessPoolExecutor(max_workers=10)
|
56 |
+
|
57 |
# Ensure the cache directory exists
|
58 |
os.makedirs("/app/cache", exist_ok=True)
|
59 |
|
|
|
110 |
try:
|
111 |
decoded_url = unquote(request.csv_url)
|
112 |
logger.info(f"Fetching CSV data from URL: {decoded_url}")
|
113 |
+
# csv_data = get_csv_basic_info(decoded_url)
|
114 |
+
|
115 |
+
# Run the synchronous function in a thread pool executor
|
116 |
+
csv_data = await loop.run_in_executor(
|
117 |
+
process_executor, get_csv_basic_info, decoded_url
|
118 |
+
)
|
119 |
+
|
120 |
logger.info(f"CSV data fetched successfully: {csv_data}")
|
121 |
return {"data": csv_data}
|
122 |
except Exception as e:
|
|
|
153 |
try:
|
154 |
decoded_url = unquote(request.csv_url)
|
155 |
logger.info(f"Fetching CSV data from URL: {decoded_url}")
|
156 |
+
# csv_data = generate_csv_data(decoded_url)
|
157 |
+
# Run the synchronous function in a thread pool executor
|
158 |
+
csv_data = await loop.run_in_executor(
|
159 |
+
process_executor, generate_csv_data, decoded_url
|
160 |
+
)
|
161 |
return csv_data
|
162 |
except Exception as e:
|
163 |
logger.error(f"Error while fetching CSV data: {e}")
|
|
|
606 |
# Global locks for key rotation (chart endpoints)
|
607 |
# current_groq_chart_key_index = 0
|
608 |
# current_groq_chart_lock = threading.Lock()
|
|
|
|
|
609 |
|
610 |
|
|
|
|
|
|
|
611 |
# --- GROQ-BASED CHART GENERATION ---
|
612 |
# def groq_chart(csv_url: str, question: str):
|
613 |
# """
|
|
|
769 |
try:
|
770 |
query = request.get("query", "")
|
771 |
csv_url = unquote(request.get("csv_url", ""))
|
772 |
+
|
|
|
773 |
# First, try the langchain-based method if the question qualifies
|
774 |
if if_initial_chart_question(query):
|
775 |
langchain_result = await loop.run_in_executor(
|
|
|
788 |
return FileResponse(groq_result, media_type="image/png")
|
789 |
|
790 |
# Fallback: try langchain-based again
|
791 |
+
logger.error("Groq chart generation failed, trying langchain....")
|
792 |
langchain_paths = await loop.run_in_executor(
|
793 |
process_executor, langchain_csv_chart, csv_url, query, True
|
794 |
)
|
|
|
796 |
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
797 |
return FileResponse(langchain_paths[0], media_type="image/png")
|
798 |
else:
|
799 |
+
logger.error("All chart generation methods failed")
|
800 |
return {"error": "All chart generation methods failed"}
|
801 |
|
802 |
except Exception as e:
|