Soumik555 commited on
Commit
cde701c
·
1 Parent(s): 7f09169

blank image issue on multiple req

Browse files
Files changed (2) hide show
  1. .env +0 -19
  2. controller.py +371 -139
.env DELETED
@@ -1,19 +0,0 @@
1
- # Pandas API keys
2
- PANDASAI_API_KEYS=$2a$10$VVwPEnzFxnEnJhk2u5ef1ewTuT3rNK59QpYQWAhUY29FHH4b7fwNC,$2a$10$5ikmN9RtNWHvP8aLnHfm.epO/XhVF1Pvk1Chy2Fqa.4x232a374xK,$2a$10$aAvr1DH3Pt3KLPDYa.JED..d83Pl6M4xnQd6uY8fadNkqSEv9KaYK,$2a$10$tJkqyS9Us36ernP1N4/8dOE088rCm7MC3gIj2RQMlaalY34EkkDHy,$2a$10$V0tThT/XnmlHbJucM00yN.hxz9r3ZwVqe0sQRQwDZAGHmhMq81D7O,$2a$10$d9vj8iPtD/L/i2B5AhiKTexOSpZ52XZTRUDkZa4p0vnI6RCj7f0K2,$2a$10$PZdCvVJB8301iDIZrZ2z8uB9d68kaBeOjaOIbbXGgqlZ2frbTm0eG,$2a$10$SHK.YTrTQcol/yM/RD8tZOcIF2fUTXtaETpDo8G0At90NxQ1HGk.C,$2a$10$QYz2Fp2fFZNq80HjAC/Okuy/PZFMgGgpPuQAyFDVtvB0G9bCn8Cee,$2a$10$SGY3HoCX0jbBXHSbpwGH1OEC/yPwT5792MjSZeWYVLew52pE4gR0y,$2a$10$QHPpvXwCXhHtKyx4jWMTh.8Mz1azTEQbDdDMpmikOzdgKtFfOq3FG,$2a$10$KoTsqdLPNIBiLRHWUg/6guqxNrB4ByljnMDTN0HJXmGl.PagdxpGm,$2a$10$ERsxnbIwk0LOMqmFX1SfjuMSXzh5gsBqm1BnYXFNEBAS3J1AfK24m,$2a$10$zwX4F0/pxXgmuAfDteFlHeXswX8cvVAvkv8mBAJ4WLvAEaUM3v266,$2a$10$LPA4FUIjg6CbZYEhi3NLRuY2Yar5SbT9gYoQ/oZuPaFUxNUyaJ/ii,$2a$10$kLDISr9ivaqcYiAZ1TmBOeclXK0C5a/LPPB3Rsxme19NwVPhznQya,$2a$10$qpoxy4k4sQya0tY7/lSEkuEuwVQGEl757A.jVPGNEh6p5tN6Yofyq,$2a$10$TDndpw.NWwx2k5X.9eI30uAaga8pbYO/erUEblVGcj6ydzSgzdVde,$2a$10$TtZtCWXgVSUhaNMMsuOjLuC6tCY1GTzUR/PvIUdowXYQdmefgpvbW,$2a$10$Orj1ZiURJkREK30gdwEYLeV7mY657jJhif8SckIPdvctjkWHXHrq6,$2a$10$CxEXDLjFtK1.nE9GuIt1duxLbvYtz2EA7x1LqddNF44kKVcc8aGZC
3
- PANDASAI_ERROR_MESSAGE="Your BambooLLM inference limit has been reached. Please use pandas-ai with any other LLM available in the library. If you wish to continue using BambooLLM, visit this link https://tally.so/r/wzZNWg for more information."
4
-
5
- # Nessasary file paths
6
- IMAGE_FILE_PATH=generated_charts
7
- IMAGE_NOT_FOUND=exports/charts/image_not_found.png
8
-
9
- # Allowed hosts
10
- ALLOWED_HOSTS=https://vercel-test-10-rho.vercel.app,https://freechat-rho.vercel.app,https://chatcsvandpdf.vercel.app
11
-
12
- # Auth token
13
- AUTH_TOKEN=raVkp0VgY3z0wICaUi95d73VSUGxu8DA3mXUoiheC8B8gBQ1Rk6Zj4aO6kcba3gWy8KU10deRMte8GvG36wGauocooLLp6y9fJt7XUgYgzvxz6y8cfhwuifWSQzeB8qMbQXZjkH1oP3rocFGArSzWZdj5phDpGdwQoxkuBpOSfA4WhMPhMr4HdohjBuiy2TlIa7ICpd5fq35LCRt2ZERaXYGUbD7MqrpDOICgXyABTjTWGHe6r0hMK7k4JiIM36rZ028a777FausbLPke9V0lPqAz5ialT0j7RbMj2fxheiZCoErx15Qx5dGfpcS9O5Xi6bpTADrYcRej0wrJv3rZrcrCBrY4m6ep0eXRkElQM389H2KFu1MlI7Twf0TxcerPh6GMAZTg2YefZU1QE8Y0ODsruCU3Jiq6UfaYXMHP5YMwpcwwHzioybjFVfuMtDePjya7y7qwdwjXTqDDJAsSe061sMDKvHpPYgAOpaYerTVy4qGMuWTwDceUzqs39X0
14
-
15
- # GROQ API keys
16
- GROQ_API_KEYS=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,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,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,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,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
17
- GROQ_LLM_MODEL=llama3-70b-8192
18
-
19
- ALLOW_DANGEROUS_CODE=true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
controller.py CHANGED
@@ -307,179 +307,179 @@ def handle_out_of_range_float(value):
307
 
308
  # CHART CODING STARTS FROM HERE
309
 
310
- instructions = """
311
 
312
- - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
313
- - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
314
- - Use colorblind-friendly palette
315
- - Read above instructions and follow them.
316
 
317
- """
318
 
319
- # Thread-safe configuration for chart endpoints
320
- current_groq_chart_key_index = 0
321
- current_groq_chart_lock = threading.Lock()
322
 
323
- current_langchain_chart_key_index = 0
324
- current_langchain_chart_lock = threading.Lock()
325
 
326
- def model():
327
- global current_groq_chart_key_index, current_groq_chart_lock
328
- with current_groq_chart_lock:
329
- if current_groq_chart_key_index >= len(groq_api_keys):
330
- raise Exception("All API keys exhausted for chart generation")
331
- api_key = groq_api_keys[current_groq_chart_key_index]
332
- return ChatGroq(model=model_name, api_key=api_key)
333
 
334
- def groq_chart(csv_url: str, question: str):
335
- global current_groq_chart_key_index, current_groq_chart_lock
336
 
337
- for attempt in range(len(groq_api_keys)):
338
- try:
339
- # Clean cache before processing
340
- cache_db_path = "/workspace/cache/cache_db_0.11.db"
341
- if os.path.exists(cache_db_path):
342
- try:
343
- os.remove(cache_db_path)
344
- except Exception as e:
345
- print(f"Cache cleanup error: {e}")
346
-
347
- data = clean_data(csv_url)
348
- with current_groq_chart_lock:
349
- current_api_key = groq_api_keys[current_groq_chart_key_index]
350
 
351
- llm = ChatGroq(model=model_name, api_key=current_api_key)
352
 
353
- # Generate unique filename using UUID
354
- chart_filename = f"chart_{uuid.uuid4()}.png"
355
- chart_path = os.path.join("generated_charts", chart_filename)
356
 
357
- # Configure SmartDataframe with chart settings
358
- df = SmartDataframe(
359
- data,
360
- config={
361
- 'llm': llm,
362
- 'save_charts': True, # Enable chart saving
363
- 'open_charts': False,
364
- 'save_charts_path': os.path.dirname(chart_path), # Directory to save
365
- 'custom_chart_filename': chart_filename # Unique filename
366
- }
367
- )
368
 
369
- answer = df.chat(question + instructions)
370
 
371
- if process_answer(answer):
372
- return "Chart not generated"
373
- return answer
374
 
375
- except Exception as e:
376
- error = str(e)
377
- if "429" in error:
378
- with current_groq_chart_lock:
379
- current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
380
- else:
381
- print(f"Chart generation error: {error}")
382
- return {"error": error}
383
 
384
- return {"error": "All API keys exhausted for chart generation"}
385
 
386
 
387
 
388
- def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
389
- global current_langchain_chart_key_index, current_langchain_chart_lock
390
 
391
- data = clean_data(csv_url)
392
-
393
- for attempt in range(len(groq_api_keys)):
394
- try:
395
- with current_langchain_chart_lock:
396
- api_key = groq_api_keys[current_langchain_chart_key_index]
397
- current_key = current_langchain_chart_key_index
398
- current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
399
-
400
- llm = ChatGroq(model=model_name, api_key=api_key)
401
- tool = PythonAstREPLTool(locals={
402
- "df": data,
403
- "pd": pd,
404
- "np": np,
405
- "plt": plt,
406
- "sns": sns,
407
- "matplotlib": matplotlib,
408
- "uuid": uuid
409
- })
410
 
411
- agent = create_pandas_dataframe_agent(
412
- llm,
413
- data,
414
- agent_type="openai-tools",
415
- verbose=True,
416
- allow_dangerous_code=True,
417
- extra_tools=[tool],
418
- return_intermediate_steps=True
419
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
- result = agent.invoke({"input": _prompt_generator(question, True)})
422
- output = result.get("output", "")
423
 
424
- # Verify chart file creation
425
- chart_files = extract_chart_filenames(output)
426
- if len(chart_files) > 0:
427
- return chart_files
428
 
429
- if attempt < len(groq_api_keys) - 1:
430
- print(f"Langchain chart error (key {current_key}): {output}")
431
 
432
- except Exception as e:
433
- print(f"Langchain chart error (key {current_key}): {str(e)}")
434
 
435
- return "Chart generation failed after all retries"
436
 
437
- @app.post("/api/csv-chart")
438
- async def csv_chart(request: dict, authorization: str = Header(None)):
439
- # Authorization verification
440
- if not authorization or not authorization.startswith("Bearer "):
441
- raise HTTPException(status_code=401, detail="Authorization required")
442
 
443
- token = authorization.split(" ")[1]
444
- if token != os.getenv("AUTH_TOKEN"):
445
- raise HTTPException(status_code=403, detail="Invalid credentials")
446
 
447
- try:
448
- query = request.get("query", "")
449
- csv_url = unquote(request.get("csv_url", ""))
450
 
451
- # Parallel processing with thread pool
452
- if if_initial_chart_question(query):
453
- chart_paths = await asyncio.to_thread(
454
- langchain_csv_chart, csv_url, query, True
455
- )
456
- print(chart_paths)
457
 
458
- if len(chart_paths) > 0:
459
- return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
460
 
461
- # Groq-based chart generation
462
- groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
463
- print(f"Generated Chart: {groq_result}")
464
- if groq_result != 'Chart not generated':
465
- return FileResponse(groq_result, media_type="image/png")
466
 
467
 
468
- # Fallback to Langchain
469
- langchain_paths = await asyncio.to_thread(
470
- langchain_csv_chart, csv_url, query, True
471
- )
472
 
473
- print (langchain_paths)
474
 
475
- if len(langchain_paths) > 0:
476
- return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
477
- else:
478
- return {"error": "All chart generation methods failed"}
479
-
480
- except Exception as e:
481
- print(f"Critical chart error: {str(e)}")
482
- return {"error": "Internal system error"}
483
 
484
 
485
 
@@ -568,3 +568,235 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
568
 
569
 
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  # CHART CODING STARTS FROM HERE
309
 
310
+ # instructions = """
311
 
312
+ # - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
313
+ # - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
314
+ # - Use colorblind-friendly palette
315
+ # - Read above instructions and follow them.
316
 
317
+ # """
318
 
319
+ # # Thread-safe configuration for chart endpoints
320
+ # current_groq_chart_key_index = 0
321
+ # current_groq_chart_lock = threading.Lock()
322
 
323
+ # current_langchain_chart_key_index = 0
324
+ # current_langchain_chart_lock = threading.Lock()
325
 
326
+ # def model():
327
+ # global current_groq_chart_key_index, current_groq_chart_lock
328
+ # with current_groq_chart_lock:
329
+ # if current_groq_chart_key_index >= len(groq_api_keys):
330
+ # raise Exception("All API keys exhausted for chart generation")
331
+ # api_key = groq_api_keys[current_groq_chart_key_index]
332
+ # return ChatGroq(model=model_name, api_key=api_key)
333
 
334
+ # def groq_chart(csv_url: str, question: str):
335
+ # global current_groq_chart_key_index, current_groq_chart_lock
336
 
337
+ # for attempt in range(len(groq_api_keys)):
338
+ # try:
339
+ # # Clean cache before processing
340
+ # cache_db_path = "/workspace/cache/cache_db_0.11.db"
341
+ # if os.path.exists(cache_db_path):
342
+ # try:
343
+ # os.remove(cache_db_path)
344
+ # except Exception as e:
345
+ # print(f"Cache cleanup error: {e}")
346
+
347
+ # data = clean_data(csv_url)
348
+ # with current_groq_chart_lock:
349
+ # current_api_key = groq_api_keys[current_groq_chart_key_index]
350
 
351
+ # llm = ChatGroq(model=model_name, api_key=current_api_key)
352
 
353
+ # # Generate unique filename using UUID
354
+ # chart_filename = f"chart_{uuid.uuid4()}.png"
355
+ # chart_path = os.path.join("generated_charts", chart_filename)
356
 
357
+ # # Configure SmartDataframe with chart settings
358
+ # df = SmartDataframe(
359
+ # data,
360
+ # config={
361
+ # 'llm': llm,
362
+ # 'save_charts': True, # Enable chart saving
363
+ # 'open_charts': False,
364
+ # 'save_charts_path': os.path.dirname(chart_path), # Directory to save
365
+ # 'custom_chart_filename': chart_filename # Unique filename
366
+ # }
367
+ # )
368
 
369
+ # answer = df.chat(question + instructions)
370
 
371
+ # if process_answer(answer):
372
+ # return "Chart not generated"
373
+ # return answer
374
 
375
+ # except Exception as e:
376
+ # error = str(e)
377
+ # if "429" in error:
378
+ # with current_groq_chart_lock:
379
+ # current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
380
+ # else:
381
+ # print(f"Chart generation error: {error}")
382
+ # return {"error": error}
383
 
384
+ # return {"error": "All API keys exhausted for chart generation"}
385
 
386
 
387
 
388
+ # def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
389
+ # global current_langchain_chart_key_index, current_langchain_chart_lock
390
 
391
+ # data = clean_data(csv_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
+ # for attempt in range(len(groq_api_keys)):
394
+ # try:
395
+ # with current_langchain_chart_lock:
396
+ # api_key = groq_api_keys[current_langchain_chart_key_index]
397
+ # current_key = current_langchain_chart_key_index
398
+ # current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
399
+
400
+ # llm = ChatGroq(model=model_name, api_key=api_key)
401
+ # tool = PythonAstREPLTool(locals={
402
+ # "df": data,
403
+ # "pd": pd,
404
+ # "np": np,
405
+ # "plt": plt,
406
+ # "sns": sns,
407
+ # "matplotlib": matplotlib,
408
+ # "uuid": uuid
409
+ # })
410
+
411
+ # agent = create_pandas_dataframe_agent(
412
+ # llm,
413
+ # data,
414
+ # agent_type="openai-tools",
415
+ # verbose=True,
416
+ # allow_dangerous_code=True,
417
+ # extra_tools=[tool],
418
+ # return_intermediate_steps=True
419
+ # )
420
 
421
+ # result = agent.invoke({"input": _prompt_generator(question, True)})
422
+ # output = result.get("output", "")
423
 
424
+ # # Verify chart file creation
425
+ # chart_files = extract_chart_filenames(output)
426
+ # if len(chart_files) > 0:
427
+ # return chart_files
428
 
429
+ # if attempt < len(groq_api_keys) - 1:
430
+ # print(f"Langchain chart error (key {current_key}): {output}")
431
 
432
+ # except Exception as e:
433
+ # print(f"Langchain chart error (key {current_key}): {str(e)}")
434
 
435
+ # return "Chart generation failed after all retries"
436
 
437
+ # @app.post("/api/csv-chart")
438
+ # async def csv_chart(request: dict, authorization: str = Header(None)):
439
+ # # Authorization verification
440
+ # if not authorization or not authorization.startswith("Bearer "):
441
+ # raise HTTPException(status_code=401, detail="Authorization required")
442
 
443
+ # token = authorization.split(" ")[1]
444
+ # if token != os.getenv("AUTH_TOKEN"):
445
+ # raise HTTPException(status_code=403, detail="Invalid credentials")
446
 
447
+ # try:
448
+ # query = request.get("query", "")
449
+ # csv_url = unquote(request.get("csv_url", ""))
450
 
451
+ # # Parallel processing with thread pool
452
+ # if if_initial_chart_question(query):
453
+ # chart_paths = await asyncio.to_thread(
454
+ # langchain_csv_chart, csv_url, query, True
455
+ # )
456
+ # print(chart_paths)
457
 
458
+ # if len(chart_paths) > 0:
459
+ # return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
460
 
461
+ # # Groq-based chart generation
462
+ # groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
463
+ # print(f"Generated Chart: {groq_result}")
464
+ # if groq_result != 'Chart not generated':
465
+ # return FileResponse(groq_result, media_type="image/png")
466
 
467
 
468
+ # # Fallback to Langchain
469
+ # langchain_paths = await asyncio.to_thread(
470
+ # langchain_csv_chart, csv_url, query, True
471
+ # )
472
 
473
+ # print (langchain_paths)
474
 
475
+ # if len(langchain_paths) > 0:
476
+ # return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
477
+ # else:
478
+ # return {"error": "All chart generation methods failed"}
479
+
480
+ # except Exception as e:
481
+ # print(f"Critical chart error: {str(e)}")
482
+ # return {"error": "Internal system error"}
483
 
484
 
485
 
 
568
 
569
 
570
 
571
+
572
+
573
+ import os
574
+ import asyncio
575
+ import threading
576
+ import uuid
577
+ from fastapi import FastAPI, HTTPException, Header
578
+ from fastapi.responses import FileResponse
579
+ from urllib.parse import unquote
580
+ from pydantic import BaseModel
581
+ from concurrent.futures import ProcessPoolExecutor
582
+ import matplotlib.pyplot as plt
583
+ import matplotlib
584
+ import pandas as pd
585
+ import numpy as np
586
+ import seaborn as sns
587
+
588
+ # Import your custom modules (assumed available)
589
+ from csv_service import clean_data, extract_chart_filenames
590
+ from langchain_experimental.tools import PythonAstREPLTool
591
+ from langchain_experimental.agents import create_pandas_dataframe_agent
592
+ from langchain_groq import ChatGroq
593
+ from util_service import _prompt_generator, process_answer
594
+ from intitial_q_handler import if_initial_chart_question
595
+
596
+ # Use non-interactive backend
597
+ matplotlib.use('Agg')
598
+
599
+ # FastAPI app initialization
600
+ app = FastAPI()
601
+
602
+ # Ensure necessary directories exist
603
+ os.makedirs("/app/generated_charts", exist_ok=True)
604
+ os.makedirs("/workspace/cache", exist_ok=True)
605
+
606
+ # Environment variables and configuration
607
+ import os
608
+ groq_api_keys = os.getenv("GROQ_API_KEYS", "").split(",")
609
+ model_name = os.getenv("GROQ_LLM_MODEL")
610
+ image_file_path = os.getenv("IMAGE_FILE_PATH") # e.g. "/app/generated_charts"
611
+
612
+ # Global locks for key rotation (chart endpoints)
613
+ current_groq_chart_key_index = 0
614
+ current_groq_chart_lock = threading.Lock()
615
+ current_langchain_chart_key_index = 0
616
+ current_langchain_chart_lock = threading.Lock()
617
+
618
+ # Use a process pool to run CPU-bound chart generation
619
+ process_executor = ProcessPoolExecutor(max_workers=2)
620
+
621
+ # --- GROQ-BASED CHART GENERATION ---
622
+ def groq_chart(csv_url: str, question: str):
623
+ """
624
+ Generate a chart using the groq-based method.
625
+ Modifications:
626
+ • No deletion of a shared cache file (avoid interference).
627
+ • After chart generation, close all matplotlib figures.
628
+ • Return the full path of the saved chart.
629
+ """
630
+ global current_groq_chart_key_index, current_groq_chart_lock
631
+
632
+ for attempt in range(len(groq_api_keys)):
633
+ try:
634
+ # Instead of deleting a global cache file, you might later configure a per-request cache.
635
+ data = clean_data(csv_url)
636
+ with current_groq_chart_lock:
637
+ current_api_key = groq_api_keys[current_groq_chart_key_index]
638
+
639
+ llm = ChatGroq(model=model_name, api_key=current_api_key)
640
+
641
+ # Generate a unique filename and full path for the chart
642
+ chart_filename = f"chart_{uuid.uuid4().hex}.png"
643
+ chart_path = os.path.join("generated_charts", chart_filename)
644
+
645
+ # Configure your dataframe tool (e.g. using SmartDataframe) to save charts.
646
+ # (Assuming your SmartDataframe uses these settings to save charts.)
647
+ from pandasai import SmartDataframe # Import here if not already imported
648
+ df = SmartDataframe(
649
+ data,
650
+ config={
651
+ 'llm': llm,
652
+ 'save_charts': True,
653
+ 'open_charts': False,
654
+ 'save_charts_path': os.path.dirname(chart_path),
655
+ 'custom_chart_filename': chart_filename
656
+ }
657
+ )
658
+
659
+ # Append any extra instructions if needed
660
+ instructions = """
661
+ - Ensure each value is clearly visible.
662
+ - Adjust font sizes, rotate labels if necessary.
663
+ - Use a colorblind-friendly palette.
664
+ - Arrange multiple charts in a grid if needed.
665
+ """
666
+ answer = df.chat(question + instructions)
667
+
668
+ # Make sure to close figures so they don't conflict between processes
669
+ plt.close('all')
670
+
671
+ # If process_answer indicates a problem, return a failure message.
672
+ if process_answer(answer):
673
+ return "Chart not generated"
674
+ # Return the chart path that was used for saving
675
+ return chart_path
676
+
677
+ except Exception as e:
678
+ error = str(e)
679
+ if "429" in error:
680
+ with current_groq_chart_lock:
681
+ current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
682
+ else:
683
+ print(f"Groq chart generation error: {error}")
684
+ return {"error": error}
685
+
686
+ return {"error": "All API keys exhausted for chart generation"}
687
+
688
+
689
+ # --- LANGCHAIN-BASED CHART GENERATION ---
690
+ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
691
+ """
692
+ Generate a chart using the langchain-based method.
693
+ Modifications:
694
+ • No shared deletion of cache.
695
+ • Close matplotlib figures after generation.
696
+ • Return a list of full chart file paths.
697
+ """
698
+ global current_langchain_chart_key_index, current_langchain_chart_lock
699
+
700
+ data = clean_data(csv_url)
701
+
702
+ for attempt in range(len(groq_api_keys)):
703
+ try:
704
+ with current_langchain_chart_lock:
705
+ api_key = groq_api_keys[current_langchain_chart_key_index]
706
+ current_key = current_langchain_chart_key_index
707
+ current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
708
+
709
+ llm = ChatGroq(model=model_name, api_key=api_key)
710
+ tool = PythonAstREPLTool(locals={
711
+ "df": data,
712
+ "pd": pd,
713
+ "np": np,
714
+ "plt": plt,
715
+ "sns": sns,
716
+ "matplotlib": matplotlib,
717
+ "uuid": uuid
718
+ })
719
+
720
+ agent = create_pandas_dataframe_agent(
721
+ llm,
722
+ data,
723
+ agent_type="openai-tools",
724
+ verbose=True,
725
+ allow_dangerous_code=True,
726
+ extra_tools=[tool],
727
+ return_intermediate_steps=True
728
+ )
729
+
730
+ result = agent.invoke({"input": _prompt_generator(question, True)})
731
+ output = result.get("output", "")
732
+
733
+ # Close figures to avoid interference
734
+ plt.close('all')
735
+
736
+ # Extract chart filenames (assuming extract_chart_filenames returns a list)
737
+ chart_files = extract_chart_filenames(output)
738
+ if len(chart_files) > 0:
739
+ # Return full paths (join with your image_file_path)
740
+ return [os.path.join(image_file_path, f) for f in chart_files]
741
+
742
+ if attempt < len(groq_api_keys) - 1:
743
+ print(f"Langchain chart error (key {current_key}): {output}")
744
+
745
+ except Exception as e:
746
+ print(f"Langchain chart error (key {current_key}): {str(e)}")
747
+
748
+ return "Chart generation failed after all retries"
749
+
750
+
751
+ # --- FASTAPI ENDPOINT FOR CHART GENERATION ---
752
+ @app.post("/api/csv-chart")
753
+ async def csv_chart(request: dict, authorization: str = Header(None)):
754
+ """
755
+ Endpoint for generating a chart from CSV data.
756
+ This endpoint uses a ProcessPoolExecutor to run the (CPU-bound) chart generation
757
+ functions in separate processes so that multiple requests can run in parallel.
758
+ """
759
+ # --- Authorization Check ---
760
+ if not authorization or not authorization.startswith("Bearer "):
761
+ raise HTTPException(status_code=401, detail="Authorization required")
762
+
763
+ token = authorization.split(" ")[1]
764
+ if token != os.getenv("AUTH_TOKEN"):
765
+ raise HTTPException(status_code=403, detail="Invalid credentials")
766
+
767
+ try:
768
+ query = request.get("query", "")
769
+ csv_url = unquote(request.get("csv_url", ""))
770
+
771
+ loop = asyncio.get_running_loop()
772
+ # First, try the langchain-based method if the question qualifies
773
+ if if_initial_chart_question(query):
774
+ langchain_result = await loop.run_in_executor(
775
+ process_executor, langchain_csv_chart, csv_url, query, True
776
+ )
777
+ print("Langchain chart result:", langchain_result)
778
+ if isinstance(langchain_result, list) and len(langchain_result) > 0:
779
+ return FileResponse(langchain_result[0], media_type="image/png")
780
+
781
+ # Next, try the groq-based method
782
+ groq_result = await loop.run_in_executor(
783
+ process_executor, groq_chart, csv_url, query
784
+ )
785
+ print(f"Groq chart result: {groq_result}")
786
+ if isinstance(groq_result, str) and groq_result != "Chart not generated":
787
+ return FileResponse(groq_result, media_type="image/png")
788
+
789
+ # Fallback: try langchain-based again
790
+ langchain_paths = await loop.run_in_executor(
791
+ process_executor, langchain_csv_chart, csv_url, query, True
792
+ )
793
+ print("Fallback langchain chart result:", langchain_paths)
794
+ if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
795
+ return FileResponse(langchain_paths[0], media_type="image/png")
796
+ else:
797
+ return {"error": "All chart generation methods failed"}
798
+
799
+ except Exception as e:
800
+ print(f"Critical chart error: {str(e)}")
801
+ return {"error": "Internal system error"}
802
+