Soumik555 committed on
Commit
b8d0141
·
1 Parent(s): 73f8541

blank image issue on multiple req

Browse files
Files changed (1) hide show
  1. controller.py +367 -139
controller.py CHANGED
@@ -307,179 +307,179 @@ def handle_out_of_range_float(value):
307
 
308
  # CHART CODING STARTS FROM HERE
309
 
310
- instructions = """
311
 
312
- - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
313
- - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
314
- - Use colorblind-friendly palette
315
- - Read above instructions and follow them.
316
 
317
- """
318
 
319
- # Thread-safe configuration for chart endpoints
320
- current_groq_chart_key_index = 0
321
- current_groq_chart_lock = threading.Lock()
322
 
323
- current_langchain_chart_key_index = 0
324
- current_langchain_chart_lock = threading.Lock()
325
 
326
- def model():
327
- global current_groq_chart_key_index, current_groq_chart_lock
328
- with current_groq_chart_lock:
329
- if current_groq_chart_key_index >= len(groq_api_keys):
330
- raise Exception("All API keys exhausted for chart generation")
331
- api_key = groq_api_keys[current_groq_chart_key_index]
332
- return ChatGroq(model=model_name, api_key=api_key)
333
 
334
- def groq_chart(csv_url: str, question: str):
335
- global current_groq_chart_key_index, current_groq_chart_lock
336
 
337
- for attempt in range(len(groq_api_keys)):
338
- try:
339
- # Clean cache before processing
340
- cache_db_path = "/workspace/cache/cache_db_0.11.db"
341
- if os.path.exists(cache_db_path):
342
- try:
343
- os.remove(cache_db_path)
344
- except Exception as e:
345
- print(f"Cache cleanup error: {e}")
346
-
347
- data = clean_data(csv_url)
348
- with current_groq_chart_lock:
349
- current_api_key = groq_api_keys[current_groq_chart_key_index]
350
 
351
- llm = ChatGroq(model=model_name, api_key=current_api_key)
352
 
353
- # Generate unique filename using UUID
354
- chart_filename = f"chart_{uuid.uuid4()}.png"
355
- chart_path = os.path.join("generated_charts", chart_filename)
356
 
357
- # Configure SmartDataframe with chart settings
358
- df = SmartDataframe(
359
- data,
360
- config={
361
- 'llm': llm,
362
- 'save_charts': True, # Enable chart saving
363
- 'open_charts': False,
364
- 'save_charts_path': os.path.dirname(chart_path), # Directory to save
365
- 'custom_chart_filename': chart_filename # Unique filename
366
- }
367
- )
368
 
369
- answer = df.chat(question + instructions)
370
 
371
- if process_answer(answer):
372
- return "Chart not generated"
373
- return answer
374
 
375
- except Exception as e:
376
- error = str(e)
377
- if "429" in error:
378
- with current_groq_chart_lock:
379
- current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
380
- else:
381
- print(f"Chart generation error: {error}")
382
- return {"error": error}
383
 
384
- return {"error": "All API keys exhausted for chart generation"}
385
 
386
 
387
 
388
- def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
389
- global current_langchain_chart_key_index, current_langchain_chart_lock
390
 
391
- data = clean_data(csv_url)
392
 
393
- for attempt in range(len(groq_api_keys)):
394
- try:
395
- with current_langchain_chart_lock:
396
- api_key = groq_api_keys[current_langchain_chart_key_index]
397
- current_key = current_langchain_chart_key_index
398
- current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
399
-
400
- llm = ChatGroq(model=model_name, api_key=api_key)
401
- tool = PythonAstREPLTool(locals={
402
- "df": data,
403
- "pd": pd,
404
- "np": np,
405
- "plt": plt,
406
- "sns": sns,
407
- "matplotlib": matplotlib,
408
- "uuid": uuid
409
- })
 
 
 
 
 
 
 
 
 
 
410
 
411
- agent = create_pandas_dataframe_agent(
412
- llm,
413
- data,
414
- agent_type="openai-tools",
415
- verbose=True,
416
- allow_dangerous_code=True,
417
- extra_tools=[tool],
418
- return_intermediate_steps=True
419
- )
420
 
421
- result = agent.invoke({"input": _prompt_generator(question, True)})
422
- output = result.get("output", "")
 
 
423
 
424
- # Verify chart file creation
425
- chart_files = extract_chart_filenames(output)
426
- if len(chart_files) > 0:
427
- return chart_files
428
 
429
- if attempt < len(groq_api_keys) - 1:
430
- print(f"Langchain chart error (key {current_key}): {output}")
431
-
432
- except Exception as e:
433
- print(f"Langchain chart error (key {current_key}): {str(e)}")
434
 
435
- return "Chart generation failed after all retries"
436
 
437
- @app.post("/api/csv-chart")
438
- async def csv_chart(request: dict, authorization: str = Header(None)):
439
- # Authorization verification
440
- if not authorization or not authorization.startswith("Bearer "):
441
- raise HTTPException(status_code=401, detail="Authorization required")
442
 
443
- token = authorization.split(" ")[1]
444
- if token != os.getenv("AUTH_TOKEN"):
445
- raise HTTPException(status_code=403, detail="Invalid credentials")
446
 
447
- try:
448
- query = request.get("query", "")
449
- csv_url = unquote(request.get("csv_url", ""))
450
 
451
- # Parallel processing with thread pool
452
- if if_initial_chart_question(query):
453
- chart_paths = await asyncio.to_thread(
454
- langchain_csv_chart, csv_url, query, True
455
- )
456
- print(chart_paths)
457
 
458
- if len(chart_paths) > 0:
459
- return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
460
 
461
- # Groq-based chart generation
462
- groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
463
- print(f"Generated Chart: {groq_result}")
464
- if groq_result != 'Chart not generated':
465
- return FileResponse(groq_result, media_type="image/png")
466
 
467
 
468
- # Fallback to Langchain
469
- langchain_paths = await asyncio.to_thread(
470
- langchain_csv_chart, csv_url, query, True
471
- )
472
 
473
- print (langchain_paths)
474
 
475
- if len(langchain_paths) > 0:
476
- return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
477
- else:
478
- return {"error": "All chart generation methods failed"}
479
-
480
- except Exception as e:
481
- print(f"Critical chart error: {str(e)}")
482
- return {"error": "Internal system error"}
483
 
484
 
485
 
@@ -568,3 +568,231 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
568
 
569
 
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  # CHART CODING STARTS FROM HERE
309
 
310
+ # instructions = """
311
 
312
+ # - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
313
+ # - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
314
+ # - Use colorblind-friendly palette
315
+ # - Read above instructions and follow them.
316
 
317
+ # """
318
 
319
+ # # Thread-safe configuration for chart endpoints
320
+ # current_groq_chart_key_index = 0
321
+ # current_groq_chart_lock = threading.Lock()
322
 
323
+ # current_langchain_chart_key_index = 0
324
+ # current_langchain_chart_lock = threading.Lock()
325
 
326
+ # def model():
327
+ # global current_groq_chart_key_index, current_groq_chart_lock
328
+ # with current_groq_chart_lock:
329
+ # if current_groq_chart_key_index >= len(groq_api_keys):
330
+ # raise Exception("All API keys exhausted for chart generation")
331
+ # api_key = groq_api_keys[current_groq_chart_key_index]
332
+ # return ChatGroq(model=model_name, api_key=api_key)
333
 
334
+ # def groq_chart(csv_url: str, question: str):
335
+ # global current_groq_chart_key_index, current_groq_chart_lock
336
 
337
+ # for attempt in range(len(groq_api_keys)):
338
+ # try:
339
+ # # Clean cache before processing
340
+ # cache_db_path = "/workspace/cache/cache_db_0.11.db"
341
+ # if os.path.exists(cache_db_path):
342
+ # try:
343
+ # os.remove(cache_db_path)
344
+ # except Exception as e:
345
+ # print(f"Cache cleanup error: {e}")
346
+
347
+ # data = clean_data(csv_url)
348
+ # with current_groq_chart_lock:
349
+ # current_api_key = groq_api_keys[current_groq_chart_key_index]
350
 
351
+ # llm = ChatGroq(model=model_name, api_key=current_api_key)
352
 
353
+ # # Generate unique filename using UUID
354
+ # chart_filename = f"chart_{uuid.uuid4()}.png"
355
+ # chart_path = os.path.join("generated_charts", chart_filename)
356
 
357
+ # # Configure SmartDataframe with chart settings
358
+ # df = SmartDataframe(
359
+ # data,
360
+ # config={
361
+ # 'llm': llm,
362
+ # 'save_charts': True, # Enable chart saving
363
+ # 'open_charts': False,
364
+ # 'save_charts_path': os.path.dirname(chart_path), # Directory to save
365
+ # 'custom_chart_filename': chart_filename # Unique filename
366
+ # }
367
+ # )
368
 
369
+ # answer = df.chat(question + instructions)
370
 
371
+ # if process_answer(answer):
372
+ # return "Chart not generated"
373
+ # return answer
374
 
375
+ # except Exception as e:
376
+ # error = str(e)
377
+ # if "429" in error:
378
+ # with current_groq_chart_lock:
379
+ # current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
380
+ # else:
381
+ # print(f"Chart generation error: {error}")
382
+ # return {"error": error}
383
 
384
+ # return {"error": "All API keys exhausted for chart generation"}
385
 
386
 
387
 
388
+ # def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
389
+ # global current_langchain_chart_key_index, current_langchain_chart_lock
390
 
391
+ # data = clean_data(csv_url)
392
 
393
+ # for attempt in range(len(groq_api_keys)):
394
+ # try:
395
+ # with current_langchain_chart_lock:
396
+ # api_key = groq_api_keys[current_langchain_chart_key_index]
397
+ # current_key = current_langchain_chart_key_index
398
+ # current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
399
+
400
+ # llm = ChatGroq(model=model_name, api_key=api_key)
401
+ # tool = PythonAstREPLTool(locals={
402
+ # "df": data,
403
+ # "pd": pd,
404
+ # "np": np,
405
+ # "plt": plt,
406
+ # "sns": sns,
407
+ # "matplotlib": matplotlib,
408
+ # "uuid": uuid
409
+ # })
410
+
411
+ # agent = create_pandas_dataframe_agent(
412
+ # llm,
413
+ # data,
414
+ # agent_type="openai-tools",
415
+ # verbose=True,
416
+ # allow_dangerous_code=True,
417
+ # extra_tools=[tool],
418
+ # return_intermediate_steps=True
419
+ # )
420
 
421
+ # result = agent.invoke({"input": _prompt_generator(question, True)})
422
+ # output = result.get("output", "")
 
 
 
 
 
 
 
423
 
424
+ # # Verify chart file creation
425
+ # chart_files = extract_chart_filenames(output)
426
+ # if len(chart_files) > 0:
427
+ # return chart_files
428
 
429
+ # if attempt < len(groq_api_keys) - 1:
430
+ # print(f"Langchain chart error (key {current_key}): {output}")
 
 
431
 
432
+ # except Exception as e:
433
+ # print(f"Langchain chart error (key {current_key}): {str(e)}")
 
 
 
434
 
435
+ # return "Chart generation failed after all retries"
436
 
437
+ # @app.post("/api/csv-chart")
438
+ # async def csv_chart(request: dict, authorization: str = Header(None)):
439
+ # # Authorization verification
440
+ # if not authorization or not authorization.startswith("Bearer "):
441
+ # raise HTTPException(status_code=401, detail="Authorization required")
442
 
443
+ # token = authorization.split(" ")[1]
444
+ # if token != os.getenv("AUTH_TOKEN"):
445
+ # raise HTTPException(status_code=403, detail="Invalid credentials")
446
 
447
+ # try:
448
+ # query = request.get("query", "")
449
+ # csv_url = unquote(request.get("csv_url", ""))
450
 
451
+ # # Parallel processing with thread pool
452
+ # if if_initial_chart_question(query):
453
+ # chart_paths = await asyncio.to_thread(
454
+ # langchain_csv_chart, csv_url, query, True
455
+ # )
456
+ # print(chart_paths)
457
 
458
+ # if len(chart_paths) > 0:
459
+ # return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
460
 
461
+ # # Groq-based chart generation
462
+ # groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
463
+ # print(f"Generated Chart: {groq_result}")
464
+ # if groq_result != 'Chart not generated':
465
+ # return FileResponse(groq_result, media_type="image/png")
466
 
467
 
468
+ # # Fallback to Langchain
469
+ # langchain_paths = await asyncio.to_thread(
470
+ # langchain_csv_chart, csv_url, query, True
471
+ # )
472
 
473
+ # print (langchain_paths)
474
 
475
+ # if len(langchain_paths) > 0:
476
+ # return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
477
+ # else:
478
+ # return {"error": "All chart generation methods failed"}
479
+
480
+ # except Exception as e:
481
+ # print(f"Critical chart error: {str(e)}")
482
+ # return {"error": "Internal system error"}
483
 
484
 
485
 
 
568
 
569
 
570
 
571
+
572
+
573
+ import os
574
+ import asyncio
575
+ import threading
576
+ import uuid
577
+ from fastapi import FastAPI, HTTPException, Header
578
+ from fastapi.responses import FileResponse
579
+ from urllib.parse import unquote
580
+ from pydantic import BaseModel
581
+ from concurrent.futures import ProcessPoolExecutor
582
+ import matplotlib.pyplot as plt
583
+ import matplotlib
584
+ import pandas as pd
585
+ import numpy as np
586
+ import seaborn as sns
587
+
588
+ # Import your custom modules (assumed available)
589
+ from csv_service import clean_data, extract_chart_filenames
590
+ from langchain_experimental.tools import PythonAstREPLTool
591
+ from langchain_experimental.agents import create_pandas_dataframe_agent
592
+ from langchain_groq import ChatGroq
593
+ from util_service import _prompt_generator, process_answer
594
+ from intitial_q_handler import if_initial_chart_question
595
+
596
# Use a non-interactive backend so charts render without a display.
matplotlib.use('Agg')

# FastAPI app initialization
# NOTE(review): this section is added mid-file, and the earlier chart endpoint
# already used an `app` object. If `app` is defined earlier in this module,
# rebinding it here would discard the routes registered on it -- confirm.
app = FastAPI()

# Environment variables and configuration
import os

# GROQ_API_KEYS is a comma-separated list. Strip whitespace and drop empty
# entries so an unset variable yields [] rather than [""], which would make
# the retry loops spend an attempt on a blank API key.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
model_name = os.getenv("GROQ_LLM_MODEL")
image_file_path = os.getenv("IMAGE_FILE_PATH")  # e.g. "/app/generated_charts"

# Key-rotation state for the chart generators. Each lock guards the index
# declared next to it.
# NOTE(review): the generators are submitted to `process_executor`, so each
# worker process presumably holds its own copy of this state; rotation is
# then per worker rather than global -- confirm this is acceptable.
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()

# Process pool used by the chart endpoint to run chart generation.
process_executor = ProcessPoolExecutor(max_workers=2)
616
+
617
+ # --- GROQ-BASED CHART GENERATION ---
618
def groq_chart(csv_url: str, question: str):
    """Generate a chart with pandasai's SmartDataframe backed by a Groq LLM.

    One attempt is made per configured API key. A failure whose message
    contains "429" advances the shared key index and retries with the next
    key; any other failure aborts immediately.

    Args:
        csv_url: Location of the CSV, handed to ``clean_data``.
        question: User prompt; chart-formatting instructions are appended.

    Returns:
        The path of a chart image that exists on disk, the string
        ``"Chart not generated"`` when no usable chart was produced, or a
        ``{"error": message}`` dict on failure or key exhaustion.
    """
    global current_groq_chart_key_index, current_groq_chart_lock

    # Prompt suffix appended to every question; built once, outside the loop.
    instructions = (
        "\n"
        "- Ensure each value is clearly visible.\n"
        "- Adjust font sizes, rotate labels if necessary.\n"
        "- Use a colorblind-friendly palette.\n"
        "- Arrange multiple charts in a grid if needed.\n"
    )
    chart_dir = "generated_charts"

    for _ in range(len(groq_api_keys)):
        try:
            # Imported lazily, inside the try, so an import failure is
            # reported through the same error path as any other failure.
            from pandasai import SmartDataframe

            data = clean_data(csv_url)
            with current_groq_chart_lock:
                current_api_key = groq_api_keys[current_groq_chart_key_index]

            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Unique per-request filename so concurrent requests do not
            # overwrite each other's image.
            chart_filename = f"chart_{uuid.uuid4().hex}.png"
            chart_path = os.path.join(chart_dir, chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': chart_dir,
                    # NOTE(review): confirm that the installed pandasai
                    # version honors this key; the existence checks below
                    # cover the case where it does not.
                    'custom_chart_filename': chart_filename,
                },
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"

            # Only hand back a path that really exists. Prefer the unique
            # path requested above; otherwise fall back to the path reported
            # by df.chat (which the previous implementation returned).
            if os.path.isfile(chart_path):
                return chart_path
            if isinstance(answer, str) and os.path.isfile(answer):
                return answer
            return "Chart not generated"

        except Exception as e:
            error = str(e)
            if "429" in error:
                # Rate limited: rotate to the next key and retry.
                with current_groq_chart_lock:
                    current_groq_chart_key_index = (
                        current_groq_chart_key_index + 1
                    ) % len(groq_api_keys)
            else:
                print(f"Groq chart generation error: {error}")
                return {"error": error}
        finally:
            # Always release figures, including on failure, so leftover
            # matplotlib state cannot bleed into the next request handled by
            # this process.
            plt.close('all')

    return {"error": "All API keys exhausted for chart generation"}
683
+
684
+
685
+ # --- LANGCHAIN-BASED CHART GENERATION ---
686
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate a chart with a LangChain pandas-dataframe agent on Groq.

    The CSV is loaded once, then one attempt is made per configured API key,
    rotating the shared key index on every attempt.

    Args:
        csv_url: Location of the CSV, handed to ``clean_data``.
        question: User prompt, wrapped by ``_prompt_generator``.
        chart_required: Accepted for interface compatibility; not used here.

    Returns:
        A non-empty list of chart paths (joined with ``image_file_path``)
        that exist on disk, or the string
        ``"Chart generation failed after all retries"``.
    """
    global current_langchain_chart_key_index, current_langchain_chart_lock

    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        # Bound before the try so the except handler can always log it.
        current_key = None
        try:
            with current_langchain_chart_lock:
                current_key = current_langchain_chart_key_index
                api_key = groq_api_keys[current_key]
                current_langchain_chart_key_index = (
                    current_key + 1
                ) % len(groq_api_keys)

            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            })

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="openai-tools",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )

            result = agent.invoke({"input": _prompt_generator(question, True)})
            output = result.get("output", "")

            # The names are parsed out of the agent's text output, so verify
            # that the files exist before reporting success.
            candidate_paths = [
                os.path.join(image_file_path, name)
                for name in extract_chart_filenames(output)
            ]
            chart_paths = [p for p in candidate_paths if os.path.isfile(p)]
            if chart_paths:
                return chart_paths

            print(f"Langchain chart error (key {current_key}): {output}")

        except Exception as e:
            print(f"Langchain chart error (key {current_key}): {str(e)}")
        finally:
            # Always release figures, including on failure, so leftover
            # matplotlib state cannot bleed into the next attempt or request.
            plt.close('all')

    return "Chart generation failed after all retries"
745
+
746
+
747
+ # --- FASTAPI ENDPOINT FOR CHART GENERATION ---
748
@app.post("/api/csv-chart")
async def csv_chart(request: dict, authorization: str = Header(None)):
    """Generate a chart image for a CSV and a natural-language query.

    Generators are run through ``process_executor`` in this order: the
    LangChain generator (only when ``if_initial_chart_question`` accepts the
    query), the Groq generator, then the LangChain generator as a fallback.
    The first result that points at an existing file is returned.

    Args:
        request: JSON body; reads the ``"query"`` and ``"csv_url"`` keys.
        authorization: ``Authorization`` header, expected as
            ``"Bearer <token>"``.

    Returns:
        A PNG ``FileResponse`` on success, otherwise an ``{"error": ...}``
        dict.

    Raises:
        HTTPException: 401 when the header is missing or not a bearer token,
            403 when the token does not match ``AUTH_TOKEN``.
    """
    import secrets  # stdlib; imported locally because only this check needs it

    # --- Authorization Check ---
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authorization required")

    token = authorization.split(" ")[1]
    expected_token = os.getenv("AUTH_TOKEN")
    # Constant-time comparison to avoid leaking token contents via timing.
    # An unset AUTH_TOKEN rejects every request, as before.
    if expected_token is None or not secrets.compare_digest(
        token.encode("utf-8"), expected_token.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid credentials")

    def _is_chart_file(path) -> bool:
        # Guard FileResponse: a missing file would only fail while the
        # response is being sent, after the try block below has exited.
        return isinstance(path, str) and os.path.isfile(path)

    try:
        query = request.get("query", "")
        csv_url = unquote(request.get("csv_url", ""))

        loop = asyncio.get_running_loop()

        # First, try the langchain-based method if the question qualifies.
        if if_initial_chart_question(query):
            langchain_result = await loop.run_in_executor(
                process_executor, langchain_csv_chart, csv_url, query, True
            )
            print("Langchain chart result:", langchain_result)
            if (
                isinstance(langchain_result, list)
                and langchain_result
                and _is_chart_file(langchain_result[0])
            ):
                return FileResponse(langchain_result[0], media_type="image/png")

        # Next, try the groq-based method.
        groq_result = await loop.run_in_executor(
            process_executor, groq_chart, csv_url, query
        )
        print(f"Groq chart result: {groq_result}")
        if groq_result != "Chart not generated" and _is_chart_file(groq_result):
            return FileResponse(groq_result, media_type="image/png")

        # Fallback: try the langchain-based method again.
        langchain_paths = await loop.run_in_executor(
            process_executor, langchain_csv_chart, csv_url, query, True
        )
        print("Fallback langchain chart result:", langchain_paths)
        if (
            isinstance(langchain_paths, list)
            and langchain_paths
            and _is_chart_file(langchain_paths[0])
        ):
            return FileResponse(langchain_paths[0], media_type="image/png")

        return {"error": "All chart generation methods failed"}

    except Exception as e:
        print(f"Critical chart error: {str(e)}")
        return {"error": "Internal system error"}
798
+