added gemini too
Browse files- controller.py +45 -3
controller.py
CHANGED
@@ -7,7 +7,7 @@ import threading
|
|
7 |
import uuid
|
8 |
from fastapi import FastAPI, HTTPException, Header
|
9 |
from fastapi.encoders import jsonable_encoder
|
10 |
-
from typing import Dict
|
11 |
from fastapi.responses import FileResponse
|
12 |
import numpy as np
|
13 |
import pandas as pd
|
@@ -75,6 +75,48 @@ class CsvUrlRequest(BaseModel):
|
|
75 |
|
76 |
class ImageRequest(BaseModel):
|
77 |
image_path: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
# Thread-safe key management for groq_chat
|
80 |
current_groq_key_index = 0
|
@@ -305,7 +347,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
305 |
generate_report = request.get("generate_report")
|
306 |
|
307 |
if generate_report is True:
|
308 |
-
return {"answer":
|
309 |
|
310 |
if if_initial_chat_question(query):
|
311 |
answer = await asyncio.to_thread(
|
@@ -806,7 +848,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
806 |
conversation_history = request.get("conversation_history", [])
|
807 |
generate_report = request.get("generate_report", False)
|
808 |
if generate_report is True:
|
809 |
-
return {"orchestrator_response":
|
810 |
|
811 |
loop = asyncio.get_running_loop()
|
812 |
# First, try the langchain-based method if the question qualifies
|
|
|
7 |
import uuid
|
8 |
from fastapi import FastAPI, HTTPException, Header
|
9 |
from fastapi.encoders import jsonable_encoder
|
10 |
+
from typing import Dict, List
|
11 |
from fastapi.responses import FileResponse
|
12 |
import numpy as np
|
13 |
import pandas as pd
|
|
|
75 |
|
76 |
class ImageRequest(BaseModel):
|
77 |
image_path: str
|
78 |
+
|
79 |
+
class FileProps(BaseModel):
|
80 |
+
fileName: str
|
81 |
+
filePath: str
|
82 |
+
fileType: str # 'csv' | 'image'
|
83 |
+
|
84 |
+
class Files(BaseModel):
|
85 |
+
csv_files: List[FileProps]
|
86 |
+
image_files: List[FileProps]
|
87 |
+
|
88 |
+
class FileBoxProps(BaseModel):
|
89 |
+
files: Files
|
90 |
+
|
91 |
+
dummy_response = FileBoxProps(
|
92 |
+
files=Files(
|
93 |
+
csv_files=[
|
94 |
+
FileProps(
|
95 |
+
fileName="sales_data.csv",
|
96 |
+
filePath="/downloads/sales_data.csv",
|
97 |
+
fileType="csv"
|
98 |
+
),
|
99 |
+
FileProps(
|
100 |
+
fileName="customer_data.csv",
|
101 |
+
filePath="/downloads/customer_data.csv",
|
102 |
+
fileType="csv"
|
103 |
+
)
|
104 |
+
],
|
105 |
+
image_files=[
|
106 |
+
FileProps(
|
107 |
+
fileName="chart.png",
|
108 |
+
filePath="/downloads/chart.png",
|
109 |
+
fileType="image"
|
110 |
+
),
|
111 |
+
FileProps(
|
112 |
+
fileName="graph.jpg",
|
113 |
+
filePath="/downloads/graph.jpg",
|
114 |
+
fileType="image"
|
115 |
+
)
|
116 |
+
]
|
117 |
+
)
|
118 |
+
)
|
119 |
+
|
120 |
|
121 |
# Thread-safe key management for groq_chat
|
122 |
current_groq_key_index = 0
|
|
|
347 |
generate_report = request.get("generate_report")
|
348 |
|
349 |
if generate_report is True:
|
350 |
+
return {"answer": jsonable_encoder(dummy_response)}
|
351 |
|
352 |
if if_initial_chat_question(query):
|
353 |
answer = await asyncio.to_thread(
|
|
|
848 |
conversation_history = request.get("conversation_history", [])
|
849 |
generate_report = request.get("generate_report", False)
|
850 |
if generate_report is True:
|
851 |
+
return {"orchestrator_response": jsonable_encoder(dummy_response)}
|
852 |
|
853 |
loop = asyncio.get_running_loop()
|
854 |
# First, try the langchain-based method if the question qualifies
|