Soumik555 committed on
Commit
73f8541
·
1 Parent(s): 282f670

blank image issue on multiple requests

Browse files
Files changed (2) hide show
  1. Dockerfile +0 -3
  2. controller.py +139 -370
Dockerfile CHANGED
@@ -13,9 +13,6 @@ RUN mkdir -p /app/cache && chmod -R 777 /app/cache
13
  # Create the log file and set permissions
14
  RUN touch /app/pandasai.log && chmod 666 /app/pandasai.log
15
 
16
- # Set the MPLCONFIGDIR environment variable to a writable directory
17
- ENV MPLCONFIGDIR=/tmp/matplotlib
18
-
19
  # Copy the requirements file first
20
  COPY requirements.txt .
21
 
 
13
  # Create the log file and set permissions
14
  RUN touch /app/pandasai.log && chmod 666 /app/pandasai.log
15
 
 
 
 
16
  # Copy the requirements file first
17
  COPY requirements.txt .
18
 
controller.py CHANGED
@@ -38,7 +38,6 @@ os.makedirs("/app/cache", exist_ok=True)
38
 
39
  os.makedirs("/app", exist_ok=True)
40
  open("/app/pandasai.log", "a").close() # Create the file if it doesn't exist
41
- os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
42
 
43
  # Ensure the generated_charts directory exists
44
  os.makedirs("/app/generated_charts", exist_ok=True)
@@ -308,179 +307,179 @@ def handle_out_of_range_float(value):
308
 
309
  # CHART CODING STARTS FROM HERE
310
 
311
- # instructions = """
312
 
313
- # - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
314
- # - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
315
- # - Use colorblind-friendly palette
316
- # - Read above instructions and follow them.
317
 
318
- # """
319
 
320
- # # Thread-safe configuration for chart endpoints
321
- # current_groq_chart_key_index = 0
322
- # current_groq_chart_lock = threading.Lock()
323
 
324
- # current_langchain_chart_key_index = 0
325
- # current_langchain_chart_lock = threading.Lock()
326
 
327
- # def model():
328
- # global current_groq_chart_key_index, current_groq_chart_lock
329
- # with current_groq_chart_lock:
330
- # if current_groq_chart_key_index >= len(groq_api_keys):
331
- # raise Exception("All API keys exhausted for chart generation")
332
- # api_key = groq_api_keys[current_groq_chart_key_index]
333
- # return ChatGroq(model=model_name, api_key=api_key)
334
 
335
- # def groq_chart(csv_url: str, question: str):
336
- # global current_groq_chart_key_index, current_groq_chart_lock
337
 
338
- # for attempt in range(len(groq_api_keys)):
339
- # try:
340
- # # Clean cache before processing
341
- # cache_db_path = "/workspace/cache/cache_db_0.11.db"
342
- # if os.path.exists(cache_db_path):
343
- # try:
344
- # os.remove(cache_db_path)
345
- # except Exception as e:
346
- # print(f"Cache cleanup error: {e}")
347
-
348
- # data = clean_data(csv_url)
349
- # with current_groq_chart_lock:
350
- # current_api_key = groq_api_keys[current_groq_chart_key_index]
351
 
352
- # llm = ChatGroq(model=model_name, api_key=current_api_key)
353
 
354
- # # Generate unique filename using UUID
355
- # chart_filename = f"chart_{uuid.uuid4()}.png"
356
- # chart_path = os.path.join("generated_charts", chart_filename)
357
 
358
- # # Configure SmartDataframe with chart settings
359
- # df = SmartDataframe(
360
- # data,
361
- # config={
362
- # 'llm': llm,
363
- # 'save_charts': True, # Enable chart saving
364
- # 'open_charts': False,
365
- # 'save_charts_path': os.path.dirname(chart_path), # Directory to save
366
- # 'custom_chart_filename': chart_filename # Unique filename
367
- # }
368
- # )
369
 
370
- # answer = df.chat(question + instructions)
371
 
372
- # if process_answer(answer):
373
- # return "Chart not generated"
374
- # return answer
375
 
376
- # except Exception as e:
377
- # error = str(e)
378
- # if "429" in error:
379
- # with current_groq_chart_lock:
380
- # current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
381
- # else:
382
- # print(f"Chart generation error: {error}")
383
- # return {"error": error}
384
 
385
- # return {"error": "All API keys exhausted for chart generation"}
386
 
387
 
388
 
389
- # def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
390
- # global current_langchain_chart_key_index, current_langchain_chart_lock
391
 
392
- # data = clean_data(csv_url)
393
 
394
- # for attempt in range(len(groq_api_keys)):
395
- # try:
396
- # with current_langchain_chart_lock:
397
- # api_key = groq_api_keys[current_langchain_chart_key_index]
398
- # current_key = current_langchain_chart_key_index
399
- # current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
400
-
401
- # llm = ChatGroq(model=model_name, api_key=api_key)
402
- # tool = PythonAstREPLTool(locals={
403
- # "df": data,
404
- # "pd": pd,
405
- # "np": np,
406
- # "plt": plt,
407
- # "sns": sns,
408
- # "matplotlib": matplotlib,
409
- # "uuid": uuid
410
- # })
411
-
412
- # agent = create_pandas_dataframe_agent(
413
- # llm,
414
- # data,
415
- # agent_type="openai-tools",
416
- # verbose=True,
417
- # allow_dangerous_code=True,
418
- # extra_tools=[tool],
419
- # return_intermediate_steps=True
420
- # )
421
 
422
- # result = agent.invoke({"input": _prompt_generator(question, True)})
423
- # output = result.get("output", "")
 
 
 
 
 
 
 
424
 
425
- # # Verify chart file creation
426
- # chart_files = extract_chart_filenames(output)
427
- # if len(chart_files) > 0:
428
- # return chart_files
429
 
430
- # if attempt < len(groq_api_keys) - 1:
431
- # print(f"Langchain chart error (key {current_key}): {output}")
 
 
432
 
433
- # except Exception as e:
434
- # print(f"Langchain chart error (key {current_key}): {str(e)}")
 
 
 
435
 
436
- # return "Chart generation failed after all retries"
437
 
438
- # @app.post("/api/csv-chart")
439
- # async def csv_chart(request: dict, authorization: str = Header(None)):
440
- # # Authorization verification
441
- # if not authorization or not authorization.startswith("Bearer "):
442
- # raise HTTPException(status_code=401, detail="Authorization required")
443
 
444
- # token = authorization.split(" ")[1]
445
- # if token != os.getenv("AUTH_TOKEN"):
446
- # raise HTTPException(status_code=403, detail="Invalid credentials")
447
 
448
- # try:
449
- # query = request.get("query", "")
450
- # csv_url = unquote(request.get("csv_url", ""))
451
 
452
- # # Parallel processing with thread pool
453
- # if if_initial_chart_question(query):
454
- # chart_paths = await asyncio.to_thread(
455
- # langchain_csv_chart, csv_url, query, True
456
- # )
457
- # print(chart_paths)
458
 
459
- # if len(chart_paths) > 0:
460
- # return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
461
 
462
- # # Groq-based chart generation
463
- # groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
464
- # print(f"Generated Chart: {groq_result}")
465
- # if groq_result != 'Chart not generated':
466
- # return FileResponse(groq_result, media_type="image/png")
467
 
468
 
469
- # # Fallback to Langchain
470
- # langchain_paths = await asyncio.to_thread(
471
- # langchain_csv_chart, csv_url, query, True
472
- # )
473
 
474
- # print (langchain_paths)
475
 
476
- # if len(langchain_paths) > 0:
477
- # return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
478
- # else:
479
- # return {"error": "All chart generation methods failed"}
480
-
481
- # except Exception as e:
482
- # print(f"Critical chart error: {str(e)}")
483
- # return {"error": "Internal system error"}
484
 
485
 
486
 
@@ -569,233 +568,3 @@ def handle_out_of_range_float(value):
569
 
570
 
571
 
572
-
573
-
574
- import os
575
- import asyncio
576
- import threading
577
- import uuid
578
- from fastapi import FastAPI, HTTPException, Header
579
- from fastapi.responses import FileResponse
580
- from urllib.parse import unquote
581
- from pydantic import BaseModel
582
- from concurrent.futures import ProcessPoolExecutor
583
- import matplotlib.pyplot as plt
584
- import matplotlib
585
- import pandas as pd
586
- import numpy as np
587
- import seaborn as sns
588
-
589
- # Import your custom modules (assumed available)
590
- from csv_service import clean_data, extract_chart_filenames
591
- from langchain_experimental.tools import PythonAstREPLTool
592
- from langchain_experimental.agents import create_pandas_dataframe_agent
593
- from langchain_groq import ChatGroq
594
- from util_service import _prompt_generator, process_answer
595
- from intitial_q_handler import if_initial_chart_question
596
-
597
- # Use non-interactive backend
598
- matplotlib.use('Agg')
599
-
600
- # FastAPI app initialization
601
- app = FastAPI()
602
-
603
- # Ensure necessary directories exist
604
- os.makedirs("/app/generated_charts", exist_ok=True)
605
-
606
-
607
- groq_api_keys = os.getenv("GROQ_API_KEYS", "").split(",")
608
- model_name = os.getenv("GROQ_LLM_MODEL")
609
- image_file_path = os.getenv("IMAGE_FILE_PATH") # e.g. "/app/generated_charts"
610
-
611
- # Global locks for key rotation (chart endpoints)
612
- current_groq_chart_key_index = 0
613
- current_groq_chart_lock = threading.Lock()
614
- current_langchain_chart_key_index = 0
615
- current_langchain_chart_lock = threading.Lock()
616
-
617
- # Use a process pool to run CPU-bound chart generation
618
- process_executor = ProcessPoolExecutor(max_workers=2)
619
-
620
- # --- GROQ-BASED CHART GENERATION ---
621
- def groq_chart(csv_url: str, question: str):
622
- """
623
- Generate a chart using the groq-based method.
624
- Modifications:
625
- • No deletion of a shared cache file (avoid interference).
626
- • After chart generation, close all matplotlib figures.
627
- • Return the full path of the saved chart.
628
- """
629
- global current_groq_chart_key_index, current_groq_chart_lock
630
-
631
- for attempt in range(len(groq_api_keys)):
632
- try:
633
- # Instead of deleting a global cache file, you might later configure a per-request cache.
634
- data = clean_data(csv_url)
635
- with current_groq_chart_lock:
636
- current_api_key = groq_api_keys[current_groq_chart_key_index]
637
-
638
- llm = ChatGroq(model=model_name, api_key=current_api_key)
639
-
640
- # Generate a unique filename and full path for the chart
641
- chart_filename = f"chart_{uuid.uuid4().hex}.png"
642
- chart_path = os.path.join("generated_charts", chart_filename)
643
-
644
- # Configure your dataframe tool (e.g. using SmartDataframe) to save charts.
645
- # (Assuming your SmartDataframe uses these settings to save charts.)
646
- from pandasai import SmartDataframe # Import here if not already imported
647
- df = SmartDataframe(
648
- data,
649
- config={
650
- 'llm': llm,
651
- 'save_charts': True,
652
- 'open_charts': False,
653
- 'save_charts_path': os.path.dirname(chart_path),
654
- 'custom_chart_filename': chart_filename
655
- }
656
- )
657
-
658
- # Append any extra instructions if needed
659
- instructions = """
660
- - Ensure each value is clearly visible.
661
- - Adjust font sizes, rotate labels if necessary.
662
- - Use a colorblind-friendly palette.
663
- - Arrange multiple charts in a grid if needed.
664
- """
665
- answer = df.chat(question + instructions)
666
-
667
- # Make sure to close figures so they don't conflict between processes
668
- plt.close('all')
669
-
670
- # If process_answer indicates a problem, return a failure message.
671
- if process_answer(answer):
672
- return "Chart not generated"
673
- # Return the chart path that was used for saving
674
- return chart_path
675
-
676
- except Exception as e:
677
- error = str(e)
678
- if "429" in error:
679
- with current_groq_chart_lock:
680
- current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
681
- else:
682
- print(f"Groq chart generation error: {error}")
683
- return {"error": error}
684
-
685
- return {"error": "All API keys exhausted for chart generation"}
686
-
687
-
688
- # --- LANGCHAIN-BASED CHART GENERATION ---
689
- def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
690
- """
691
- Generate a chart using the langchain-based method.
692
- Modifications:
693
- • No shared deletion of cache.
694
- • Close matplotlib figures after generation.
695
- • Return a list of full chart file paths.
696
- """
697
- global current_langchain_chart_key_index, current_langchain_chart_lock
698
-
699
- data = clean_data(csv_url)
700
-
701
- for attempt in range(len(groq_api_keys)):
702
- try:
703
- with current_langchain_chart_lock:
704
- api_key = groq_api_keys[current_langchain_chart_key_index]
705
- current_key = current_langchain_chart_key_index
706
- current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
707
-
708
- llm = ChatGroq(model=model_name, api_key=api_key)
709
- tool = PythonAstREPLTool(locals={
710
- "df": data,
711
- "pd": pd,
712
- "np": np,
713
- "plt": plt,
714
- "sns": sns,
715
- "matplotlib": matplotlib,
716
- "uuid": uuid
717
- })
718
-
719
- agent = create_pandas_dataframe_agent(
720
- llm,
721
- data,
722
- agent_type="openai-tools",
723
- verbose=True,
724
- allow_dangerous_code=True,
725
- extra_tools=[tool],
726
- return_intermediate_steps=True
727
- )
728
-
729
- result = agent.invoke({"input": _prompt_generator(question, True)})
730
- output = result.get("output", "")
731
-
732
- # Close figures to avoid interference
733
- plt.close('all')
734
-
735
- # Extract chart filenames (assuming extract_chart_filenames returns a list)
736
- chart_files = extract_chart_filenames(output)
737
- if len(chart_files) > 0:
738
- # Return full paths (join with your image_file_path)
739
- return [os.path.join(image_file_path, f) for f in chart_files]
740
-
741
- if attempt < len(groq_api_keys) - 1:
742
- print(f"Langchain chart error (key {current_key}): {output}")
743
-
744
- except Exception as e:
745
- print(f"Langchain chart error (key {current_key}): {str(e)}")
746
-
747
- return "Chart generation failed after all retries"
748
-
749
-
750
- # --- FASTAPI ENDPOINT FOR CHART GENERATION ---
751
- @app.post("/api/csv-chart")
752
- async def csv_chart(request: dict, authorization: str = Header(None)):
753
- """
754
- Endpoint for generating a chart from CSV data.
755
- This endpoint uses a ProcessPoolExecutor to run the (CPU-bound) chart generation
756
- functions in separate processes so that multiple requests can run in parallel.
757
- """
758
- # --- Authorization Check ---
759
- if not authorization or not authorization.startswith("Bearer "):
760
- raise HTTPException(status_code=401, detail="Authorization required")
761
-
762
- token = authorization.split(" ")[1]
763
- if token != os.getenv("AUTH_TOKEN"):
764
- raise HTTPException(status_code=403, detail="Invalid credentials")
765
-
766
- try:
767
- query = request.get("query", "")
768
- csv_url = unquote(request.get("csv_url", ""))
769
-
770
- loop = asyncio.get_running_loop()
771
- # First, try the langchain-based method if the question qualifies
772
- if if_initial_chart_question(query):
773
- langchain_result = await loop.run_in_executor(
774
- process_executor, langchain_csv_chart, csv_url, query, True
775
- )
776
- print("Langchain chart result:", langchain_result)
777
- if isinstance(langchain_result, list) and len(langchain_result) > 0:
778
- return FileResponse(langchain_result[0], media_type="image/png")
779
-
780
- # Next, try the groq-based method
781
- groq_result = await loop.run_in_executor(
782
- process_executor, groq_chart, csv_url, query
783
- )
784
- print(f"Groq chart result: {groq_result}")
785
- if isinstance(groq_result, str) and groq_result != "Chart not generated":
786
- return FileResponse(groq_result, media_type="image/png")
787
-
788
- # Fallback: try langchain-based again
789
- langchain_paths = await loop.run_in_executor(
790
- process_executor, langchain_csv_chart, csv_url, query, True
791
- )
792
- print("Fallback langchain chart result:", langchain_paths)
793
- if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
794
- return FileResponse(langchain_paths[0], media_type="image/png")
795
- else:
796
- return {"error": "All chart generation methods failed"}
797
-
798
- except Exception as e:
799
- print(f"Critical chart error: {str(e)}")
800
- return {"error": "Internal system error"}
801
-
 
38
 
39
  os.makedirs("/app", exist_ok=True)
40
  open("/app/pandasai.log", "a").close() # Create the file if it doesn't exist
 
41
 
42
  # Ensure the generated_charts directory exists
43
  os.makedirs("/app/generated_charts", exist_ok=True)
 
307
 
308
  # CHART CODING STARTS FROM HERE
309
 
310
+ instructions = """
311
 
312
+ - Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
313
+ - For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
314
+ - Use colorblind-friendly palette
315
+ - Read above instructions and follow them.
316
 
317
+ """
318
 
319
+ # Thread-safe configuration for chart endpoints
320
+ current_groq_chart_key_index = 0
321
+ current_groq_chart_lock = threading.Lock()
322
 
323
+ current_langchain_chart_key_index = 0
324
+ current_langchain_chart_lock = threading.Lock()
325
 
326
+ def model():
327
+ global current_groq_chart_key_index, current_groq_chart_lock
328
+ with current_groq_chart_lock:
329
+ if current_groq_chart_key_index >= len(groq_api_keys):
330
+ raise Exception("All API keys exhausted for chart generation")
331
+ api_key = groq_api_keys[current_groq_chart_key_index]
332
+ return ChatGroq(model=model_name, api_key=api_key)
333
 
334
+ def groq_chart(csv_url: str, question: str):
335
+ global current_groq_chart_key_index, current_groq_chart_lock
336
 
337
+ for attempt in range(len(groq_api_keys)):
338
+ try:
339
+ # Clean cache before processing
340
+ cache_db_path = "/workspace/cache/cache_db_0.11.db"
341
+ if os.path.exists(cache_db_path):
342
+ try:
343
+ os.remove(cache_db_path)
344
+ except Exception as e:
345
+ print(f"Cache cleanup error: {e}")
346
+
347
+ data = clean_data(csv_url)
348
+ with current_groq_chart_lock:
349
+ current_api_key = groq_api_keys[current_groq_chart_key_index]
350
 
351
+ llm = ChatGroq(model=model_name, api_key=current_api_key)
352
 
353
+ # Generate unique filename using UUID
354
+ chart_filename = f"chart_{uuid.uuid4()}.png"
355
+ chart_path = os.path.join("generated_charts", chart_filename)
356
 
357
+ # Configure SmartDataframe with chart settings
358
+ df = SmartDataframe(
359
+ data,
360
+ config={
361
+ 'llm': llm,
362
+ 'save_charts': True, # Enable chart saving
363
+ 'open_charts': False,
364
+ 'save_charts_path': os.path.dirname(chart_path), # Directory to save
365
+ 'custom_chart_filename': chart_filename # Unique filename
366
+ }
367
+ )
368
 
369
+ answer = df.chat(question + instructions)
370
 
371
+ if process_answer(answer):
372
+ return "Chart not generated"
373
+ return answer
374
 
375
+ except Exception as e:
376
+ error = str(e)
377
+ if "429" in error:
378
+ with current_groq_chart_lock:
379
+ current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
380
+ else:
381
+ print(f"Chart generation error: {error}")
382
+ return {"error": error}
383
 
384
+ return {"error": "All API keys exhausted for chart generation"}
385
 
386
 
387
 
388
+ def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
389
+ global current_langchain_chart_key_index, current_langchain_chart_lock
390
 
391
+ data = clean_data(csv_url)
392
 
393
+ for attempt in range(len(groq_api_keys)):
394
+ try:
395
+ with current_langchain_chart_lock:
396
+ api_key = groq_api_keys[current_langchain_chart_key_index]
397
+ current_key = current_langchain_chart_key_index
398
+ current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
399
+
400
+ llm = ChatGroq(model=model_name, api_key=api_key)
401
+ tool = PythonAstREPLTool(locals={
402
+ "df": data,
403
+ "pd": pd,
404
+ "np": np,
405
+ "plt": plt,
406
+ "sns": sns,
407
+ "matplotlib": matplotlib,
408
+ "uuid": uuid
409
+ })
 
 
 
 
 
 
 
 
 
 
410
 
411
+ agent = create_pandas_dataframe_agent(
412
+ llm,
413
+ data,
414
+ agent_type="openai-tools",
415
+ verbose=True,
416
+ allow_dangerous_code=True,
417
+ extra_tools=[tool],
418
+ return_intermediate_steps=True
419
+ )
420
 
421
+ result = agent.invoke({"input": _prompt_generator(question, True)})
422
+ output = result.get("output", "")
 
 
423
 
424
+ # Verify chart file creation
425
+ chart_files = extract_chart_filenames(output)
426
+ if len(chart_files) > 0:
427
+ return chart_files
428
 
429
+ if attempt < len(groq_api_keys) - 1:
430
+ print(f"Langchain chart error (key {current_key}): {output}")
431
+
432
+ except Exception as e:
433
+ print(f"Langchain chart error (key {current_key}): {str(e)}")
434
 
435
+ return "Chart generation failed after all retries"
436
 
437
+ @app.post("/api/csv-chart")
438
+ async def csv_chart(request: dict, authorization: str = Header(None)):
439
+ # Authorization verification
440
+ if not authorization or not authorization.startswith("Bearer "):
441
+ raise HTTPException(status_code=401, detail="Authorization required")
442
 
443
+ token = authorization.split(" ")[1]
444
+ if token != os.getenv("AUTH_TOKEN"):
445
+ raise HTTPException(status_code=403, detail="Invalid credentials")
446
 
447
+ try:
448
+ query = request.get("query", "")
449
+ csv_url = unquote(request.get("csv_url", ""))
450
 
451
+ # Parallel processing with thread pool
452
+ if if_initial_chart_question(query):
453
+ chart_paths = await asyncio.to_thread(
454
+ langchain_csv_chart, csv_url, query, True
455
+ )
456
+ print(chart_paths)
457
 
458
+ if len(chart_paths) > 0:
459
+ return FileResponse(f"{image_file_path}/{chart_paths[0]}", media_type="image/png")
460
 
461
+ # Groq-based chart generation
462
+ groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
463
+ print(f"Generated Chart: {groq_result}")
464
+ if groq_result != 'Chart not generated':
465
+ return FileResponse(groq_result, media_type="image/png")
466
 
467
 
468
+ # Fallback to Langchain
469
+ langchain_paths = await asyncio.to_thread(
470
+ langchain_csv_chart, csv_url, query, True
471
+ )
472
 
473
+ print (langchain_paths)
474
 
475
+ if len(langchain_paths) > 0:
476
+ return FileResponse(f"{image_file_path}/{langchain_paths[0]}", media_type="image/png")
477
+ else:
478
+ return {"error": "All chart generation methods failed"}
479
+
480
+ except Exception as e:
481
+ print(f"Critical chart error: {str(e)}")
482
+ return {"error": "Internal system error"}
483
 
484
 
485
 
 
568
 
569
 
570