Spaces:
Sleeping
Sleeping
Commit
·
86fd3c3
1
Parent(s):
9e4b8b0
Cleaning
Browse files- Dockerfile +3 -26
- README.md +1 -1
- backend/Dockerfile +0 -10
- backend/__init__.py +0 -0
- backend/main.py +196 -672
- backend/rag.py +255 -95
- backend/requirements.txt +0 -7
- frontend/Dockerfile +0 -10
- frontend/src/hooks/useAppState.ts +102 -103
- run.sh +0 -14
Dockerfile
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
# Stage 1: build frontend (React + Vite)
|
2 |
FROM node:18-alpine AS frontend-builder
|
3 |
WORKDIR /app/frontend
|
4 |
COPY frontend/package.json ./
|
@@ -6,7 +5,6 @@ RUN npm install
|
|
6 |
COPY frontend/ .
|
7 |
RUN npm run build
|
8 |
|
9 |
-
# Stage 2: build backend (FastAPI)
|
10 |
FROM python:3.11-slim AS backend-builder
|
11 |
WORKDIR /app/backend
|
12 |
COPY backend/requirements.txt ./
|
@@ -14,11 +12,9 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|
14 |
COPY backend/ .
|
15 |
RUN pip install --no-cache-dir gunicorn uvicorn
|
16 |
|
17 |
-
# Stage 3: runtime image with nginx and run script
|
18 |
FROM python:3.11-slim as runtime
|
19 |
|
20 |
-
|
21 |
-
# Install OS deps
|
22 |
USER root
|
23 |
USER root
|
24 |
RUN apt-get update && \
|
@@ -39,21 +35,8 @@ RUN mkdir -p \
|
|
39 |
/var/lib/nginx/proxy \
|
40 |
/var/lib/nginx/fastcgi \
|
41 |
/var/lib/nginx/scgi \
|
42 |
-
/var/lib/nginx/uwsgi \
|
43 |
-
|
44 |
-
|
45 |
-
RUN mkdir -p /var/cache/nginx/client_temp \
|
46 |
-
/var/cache/nginx/proxy_temp \
|
47 |
-
/var/cache/nginx/fastcgi_temp \
|
48 |
-
/var/cache/nginx/scgi_temp \
|
49 |
-
/var/cache/nginx/uwsgi_temp \
|
50 |
-
/var/log/nginx \
|
51 |
-
/var/run/nginx \
|
52 |
-
/var/lib/nginx/body \
|
53 |
-
/var/lib/nginx/proxy \
|
54 |
-
/var/lib/nginx/fastcgi \
|
55 |
-
/var/lib/nginx/scgi \
|
56 |
-
/var/lib/nginx/uwsgi && \
|
57 |
touch /var/log/nginx/error.log /var/log/nginx/access.log && \
|
58 |
chown -R www-data:www-data /var/cache/nginx /var/log/nginx /var/run/nginx /var/lib/nginx
|
59 |
|
@@ -65,21 +48,16 @@ ENV HF_HOME=/tmp/hf_cache \
|
|
65 |
RUN mkdir -p /tmp/hf_cache \
|
66 |
&& chmod 777 /tmp/hf_cache
|
67 |
|
68 |
-
# Install Python deps from requirements (ensures numpy/pandas compatibility), then ASGI
|
69 |
-
# copy in your requirements
|
70 |
COPY --from=backend-builder /app/backend/requirements.txt /tmp/requirements.txt
|
71 |
|
72 |
RUN python3 -m pip install --no-cache-dir \
|
73 |
-
# 2) Now install the rest (including gptqmodel)
|
74 |
-r /tmp/requirements.txt \
|
75 |
&& python3 -m pip install --no-cache-dir fastapi starlette uvicorn
|
76 |
|
77 |
|
78 |
-
# Copy frontend build and backend app
|
79 |
COPY --from=frontend-builder /app/frontend/dist /app/static
|
80 |
COPY --from=backend-builder /app/backend /app/app
|
81 |
|
82 |
-
# Copy nginx config and run script
|
83 |
COPY nginx.conf /etc/nginx/nginx.conf
|
84 |
COPY run.sh /app/run.sh
|
85 |
RUN chmod +x /app/run.sh
|
@@ -87,6 +65,5 @@ RUN chmod -R a+rwx /var/log/nginx
|
|
87 |
|
88 |
WORKDIR /app
|
89 |
|
90 |
-
# Use run.sh as entrypoint (runs nginx, static server, uvicorn)
|
91 |
ENTRYPOINT ["/bin/bash", "/app/run.sh"]
|
92 |
|
|
|
|
|
1 |
FROM node:18-alpine AS frontend-builder
|
2 |
WORKDIR /app/frontend
|
3 |
COPY frontend/package.json ./
|
|
|
5 |
COPY frontend/ .
|
6 |
RUN npm run build
|
7 |
|
|
|
8 |
FROM python:3.11-slim AS backend-builder
|
9 |
WORKDIR /app/backend
|
10 |
COPY backend/requirements.txt ./
|
|
|
12 |
COPY backend/ .
|
13 |
RUN pip install --no-cache-dir gunicorn uvicorn
|
14 |
|
|
|
15 |
FROM python:3.11-slim as runtime
|
16 |
|
17 |
+
|
|
|
18 |
USER root
|
19 |
USER root
|
20 |
RUN apt-get update && \
|
|
|
35 |
/var/lib/nginx/proxy \
|
36 |
/var/lib/nginx/fastcgi \
|
37 |
/var/lib/nginx/scgi \
|
38 |
+
/var/lib/nginx/uwsgi && \
|
39 |
+
chmod -R a+rwx /var/cache/nginx /var/log/nginx /var/run/nginx /var/lib/nginx && \
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
touch /var/log/nginx/error.log /var/log/nginx/access.log && \
|
41 |
chown -R www-data:www-data /var/cache/nginx /var/log/nginx /var/run/nginx /var/lib/nginx
|
42 |
|
|
|
48 |
RUN mkdir -p /tmp/hf_cache \
|
49 |
&& chmod 777 /tmp/hf_cache
|
50 |
|
|
|
|
|
51 |
COPY --from=backend-builder /app/backend/requirements.txt /tmp/requirements.txt
|
52 |
|
53 |
RUN python3 -m pip install --no-cache-dir \
|
|
|
54 |
-r /tmp/requirements.txt \
|
55 |
&& python3 -m pip install --no-cache-dir fastapi starlette uvicorn
|
56 |
|
57 |
|
|
|
58 |
COPY --from=frontend-builder /app/frontend/dist /app/static
|
59 |
COPY --from=backend-builder /app/backend /app/app
|
60 |
|
|
|
61 |
COPY nginx.conf /etc/nginx/nginx.conf
|
62 |
COPY run.sh /app/run.sh
|
63 |
RUN chmod +x /app/run.sh
|
|
|
65 |
|
66 |
WORKDIR /app
|
67 |
|
|
|
68 |
ENTRYPOINT ["/bin/bash", "/app/run.sh"]
|
69 |
|
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title: EU Explorer
|
3 |
emoji: 🤖 # a single emoji that represents your app
|
4 |
colorFrom: purple # one of: red, yellow, green, blue, indigo, purple, pink, gray
|
5 |
colorTo: indigo # another one from the same list, for the gradient
|
|
|
1 |
---
|
2 |
+
title: EU Explorer (MDA Assignment) # the display name of your Space
|
3 |
emoji: 🤖 # a single emoji that represents your app
|
4 |
colorFrom: purple # one of: red, yellow, green, blue, indigo, purple, pink, gray
|
5 |
colorTo: indigo # another one from the same list, for the gradient
|
backend/Dockerfile
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
FROM python:3.11-slim
|
2 |
-
WORKDIR /app
|
3 |
-
COPY requirements.txt ./
|
4 |
-
RUN pip install --no-cache-dir -r requirements.txt
|
5 |
-
RUN adduser --disabled-password appuser
|
6 |
-
USER appuser
|
7 |
-
COPY --chown=appuser:appuser . .
|
8 |
-
ENV PARQUET_PATH="data/consolidated_clean.parquet"
|
9 |
-
EXPOSE 8080
|
10 |
-
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend/__init__.py
DELETED
File without changes
|
backend/main.py
CHANGED
@@ -1,639 +1,118 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
3 |
import traceback
|
4 |
-
from starlette.concurrency import run_in_threadpool
|
5 |
-
from pydantic import BaseModel
|
6 |
-
from pydantic_settings import BaseSettings
|
7 |
from contextlib import asynccontextmanager
|
8 |
-
from
|
|
|
9 |
|
10 |
-
import os
|
11 |
-
import logging
|
12 |
import aiofiles
|
|
|
|
|
13 |
import polars as pl
|
|
|
14 |
import zipfile
|
15 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
|
18 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
19 |
-
from langchain_community.vectorstores import FAISS
|
20 |
-
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
|
21 |
-
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
22 |
-
from langchain.memory import ConversationBufferWindowMemory
|
23 |
from langchain.chains import ConversationalRetrievalChain
|
|
|
24 |
from langchain.prompts import PromptTemplate
|
25 |
-
from
|
26 |
-
|
27 |
-
from
|
28 |
-
from
|
29 |
-
|
30 |
-
from
|
31 |
-
from whoosh.fields import Schema, TEXT, ID
|
32 |
-
from whoosh.analysis import StemmingAnalyzer
|
33 |
-
from whoosh.qparser import MultifieldParser
|
34 |
-
import pickle
|
35 |
-
from pydantic import PrivateAttr
|
36 |
-
from tqdm import tqdm
|
37 |
-
import faiss
|
38 |
-
import torch
|
39 |
-
import tempfile
|
40 |
-
import shutil
|
41 |
|
42 |
-
|
|
|
43 |
|
44 |
# ---------------------------------------------------------------------------- #
|
45 |
# Settings #
|
46 |
# ---------------------------------------------------------------------------- #
|
47 |
-
# === Logging ===
|
48 |
-
logging.basicConfig(level=logging.INFO)
|
49 |
-
logger = logging.getLogger(__name__)
|
50 |
|
51 |
-
class Settings(
|
52 |
-
|
|
|
|
|
|
|
53 |
parquet_path: str = "gs://mda_eu_project/data/consolidated_clean_pred.parquet"
|
54 |
-
whoosh_dir:
|
55 |
vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
|
|
60 |
# RAG parameters
|
61 |
-
chunk_size:
|
62 |
chunk_overlap: int = 100
|
63 |
-
hybrid_k:
|
64 |
assistant_role: str = (
|
65 |
-
"You are a knowledgeable project analyst.
|
66 |
)
|
67 |
skip_warmup: bool = True
|
|
|
|
|
68 |
allowed_origins: List[str] = ["*"]
|
69 |
|
70 |
class Config:
|
71 |
env_file = ".env"
|
72 |
|
|
|
73 |
settings = Settings()
|
|
|
|
|
74 |
|
75 |
-
# Pre
|
76 |
-
EMBEDDING = HuggingFaceEmbeddings(
|
77 |
-
|
|
|
|
|
78 |
|
79 |
@lru_cache(maxsize=256)
|
80 |
def embed_query_cached(query: str) -> List[float]:
|
81 |
-
"""Cache embedding vectors for queries."""
|
82 |
return EMBEDDING.embed_query(query.strip().lower())
|
83 |
|
84 |
-
# === Whoosh Cache & Builder ===
|
85 |
-
"""_WHOOSH_CACHE: Dict[str, index.Index] = {}
|
86 |
-
|
87 |
-
async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
|
88 |
-
key = whoosh_dir
|
89 |
-
fs = gcsfs.GCSFileSystem()
|
90 |
-
local_dir = key
|
91 |
-
is_gcs = key.startswith("gs://")
|
92 |
-
try:
|
93 |
-
# stage local copy for GCS
|
94 |
-
if is_gcs:
|
95 |
-
local_dir = "/tmp/whoosh_index"
|
96 |
-
if not os.path.exists(local_dir):
|
97 |
-
if await run_in_threadpool(fs.exists, key):
|
98 |
-
await run_in_threadpool(fs.get, key, local_dir, recursive=True)
|
99 |
-
else:
|
100 |
-
os.makedirs(local_dir, exist_ok=True)
|
101 |
-
# build once
|
102 |
-
if key not in _WHOOSH_CACHE:
|
103 |
-
os.makedirs(local_dir, exist_ok=True)
|
104 |
-
schema = Schema(
|
105 |
-
id=ID(stored=True, unique=True),
|
106 |
-
content=TEXT(analyzer=StemmingAnalyzer()),
|
107 |
-
)
|
108 |
-
ix = index.create_in(local_dir, schema)
|
109 |
-
with ix.writer() as writer:
|
110 |
-
for doc in docs:
|
111 |
-
writer.add_document(
|
112 |
-
id=doc.metadata.get("id", ""),
|
113 |
-
content=doc.page_content,
|
114 |
-
)
|
115 |
-
# push back to GCS atomically
|
116 |
-
if is_gcs:
|
117 |
-
await run_in_threadpool(fs.put, local_dir, key, recursive=True)
|
118 |
-
_WHOOSH_CACHE[key] = ix
|
119 |
-
return _WHOOSH_CACHE[key]
|
120 |
-
except Exception as e:
|
121 |
-
logger.error(f"Failed to build Whoosh index: {e}")
|
122 |
-
raise"""
|
123 |
-
|
124 |
-
async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
|
125 |
-
"""
|
126 |
-
If gs://.../whoosh_index.zip exists, download & extract it once.
|
127 |
-
Otherwise build locally from docs and upload the ZIP back to GCS.
|
128 |
-
"""
|
129 |
-
fs = gcsfs.GCSFileSystem()
|
130 |
-
is_gcs = whoosh_dir.startswith("gs://")
|
131 |
-
zip_uri = whoosh_dir.rstrip("/") + ".zip"
|
132 |
-
|
133 |
-
local_zip = "/tmp/whoosh_index.zip"
|
134 |
-
local_dir = "/tmp/whoosh_index"
|
135 |
-
|
136 |
-
# Clean slate
|
137 |
-
if os.path.exists(local_dir):
|
138 |
-
shutil.rmtree(local_dir)
|
139 |
-
os.makedirs(local_dir, exist_ok=True)
|
140 |
-
|
141 |
-
# 1️⃣ Try downloading the ZIP if it exists on GCS
|
142 |
-
if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
|
143 |
-
logger.info("Found whoosh_index.zip on GCS; downloading…")
|
144 |
-
await run_in_threadpool(fs.get, zip_uri, local_zip)
|
145 |
-
# Extract all files (flat) into local_dir
|
146 |
-
with zipfile.ZipFile(local_zip, "r") as zf:
|
147 |
-
for member in zf.infolist():
|
148 |
-
if member.is_dir():
|
149 |
-
continue
|
150 |
-
filename = os.path.basename(member.filename)
|
151 |
-
if not filename:
|
152 |
-
continue
|
153 |
-
target = os.path.join(local_dir, filename)
|
154 |
-
os.makedirs(os.path.dirname(target), exist_ok=True)
|
155 |
-
with zf.open(member) as src, open(target, "wb") as dst:
|
156 |
-
dst.write(src.read())
|
157 |
-
logger.info("Whoosh index extracted from ZIP.")
|
158 |
-
else:
|
159 |
-
logger.info("No whoosh_index.zip found; building index from docs.")
|
160 |
-
|
161 |
-
# Define the schema with stored content
|
162 |
-
schema = Schema(
|
163 |
-
id=ID(stored=True, unique=True),
|
164 |
-
content=TEXT(stored=True, analyzer=StemmingAnalyzer()),
|
165 |
-
)
|
166 |
-
|
167 |
-
# Create the index
|
168 |
-
ix = index.create_in(local_dir, schema)
|
169 |
-
writer = ix.writer()
|
170 |
-
for doc in docs:
|
171 |
-
writer.add_document(
|
172 |
-
id=doc.metadata.get("id", ""),
|
173 |
-
content=doc.page_content,
|
174 |
-
)
|
175 |
-
writer.commit()
|
176 |
-
logger.info("Whoosh index built locally.")
|
177 |
-
|
178 |
-
# Upload the ZIP back to GCS
|
179 |
-
if is_gcs:
|
180 |
-
logger.info("Zipping and uploading new whoosh_index.zip to GCS…")
|
181 |
-
with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
|
182 |
-
for root, _, files in os.walk(local_dir):
|
183 |
-
for fname in files:
|
184 |
-
full = os.path.join(root, fname)
|
185 |
-
arc = os.path.relpath(full, local_dir)
|
186 |
-
zf.write(full, arc)
|
187 |
-
await run_in_threadpool(fs.put, local_zip, zip_uri)
|
188 |
-
logger.info("Uploaded whoosh_index.zip to GCS.")
|
189 |
-
|
190 |
-
# 2️⃣ Finally open the index and return it
|
191 |
-
ix = index.open_dir(local_dir)
|
192 |
-
return ix
|
193 |
-
|
194 |
-
# === Document Loader ===
|
195 |
-
async def load_documents(
|
196 |
-
path: str,
|
197 |
-
sample_size: Optional[int] = None
|
198 |
-
) -> List[Document]:
|
199 |
-
"""
|
200 |
-
Load a Parquet file from local or GCS, convert to a list of Documents.
|
201 |
-
"""
|
202 |
-
def _read_local(p: str, n: Optional[int]):
|
203 |
-
# streaming scan keeps memory low
|
204 |
-
lf = pl.scan_parquet(p)
|
205 |
-
if n:
|
206 |
-
lf = lf.limit(n)
|
207 |
-
return lf.collect(streaming=True)
|
208 |
-
|
209 |
-
def _read_gcs(p: str, n: Optional[int]):
|
210 |
-
# download to a temp file synchronously, then read with Polars
|
211 |
-
fs = gcsfs.GCSFileSystem()
|
212 |
-
with tempfile.TemporaryDirectory() as td:
|
213 |
-
local_path = os.path.join(td, "data.parquet")
|
214 |
-
fs.get(p, local_path, recursive=False)
|
215 |
-
df = pl.read_parquet(local_path)
|
216 |
-
if n:
|
217 |
-
df = df.head(n)
|
218 |
-
return df
|
219 |
-
|
220 |
-
try:
|
221 |
-
if path.startswith("gs://"):
|
222 |
-
df = await run_in_threadpool(_read_gcs, path, sample_size)
|
223 |
-
else:
|
224 |
-
df = await run_in_threadpool(_read_local, path, sample_size)
|
225 |
-
except Exception as e:
|
226 |
-
logger.error(f"Error loading documents: {e}")
|
227 |
-
raise HTTPException(status_code=500, detail="Document loading failed.")
|
228 |
-
|
229 |
-
docs: List[Document] = []
|
230 |
-
for row in df.rows(named=True):
|
231 |
-
context_parts: List[str] = []
|
232 |
-
# build metadata context
|
233 |
-
max_contrib = row.get("ecMaxContribution", "")
|
234 |
-
end_date = row.get("endDate", "")
|
235 |
-
duration = row.get("durationDays", "")
|
236 |
-
status = row.get("status", "")
|
237 |
-
legal = row.get("legalBasis", "")
|
238 |
-
framework = row.get("frameworkProgramme", "")
|
239 |
-
scheme = row.get("fundingScheme", "")
|
240 |
-
names = row.get("list_name", []) or []
|
241 |
-
cities = row.get("list_city", []) or []
|
242 |
-
countries = row.get("list_country", []) or []
|
243 |
-
activity = row.get("list_activityType", []) or []
|
244 |
-
contributions = row.get("list_ecContribution", []) or []
|
245 |
-
smes = row.get("list_sme", []) or []
|
246 |
-
project_id =row.get("id", "")
|
247 |
-
pred=row.get("predicted_label", "")
|
248 |
-
proba=row.get("predicted_prob", "")
|
249 |
-
top1_feats=row.get("top1_features", "")
|
250 |
-
top2_feats=row.get("top2_features", "")
|
251 |
-
top3_feats=row.get("top3_features", "")
|
252 |
-
top1_shap=row.get("top1_shap", "")
|
253 |
-
top2_shap=row.get("top2_shap", "")
|
254 |
-
top3_shap=row.get("top3_shap", "")
|
255 |
-
|
256 |
-
|
257 |
-
context_parts.append(
|
258 |
-
f"This project under framework {framework} with funding scheme {scheme}, status {status}, legal basis {legal}."
|
259 |
-
)
|
260 |
-
context_parts.append(
|
261 |
-
f"It ends on {end_date} after {duration} days and has a max EC contribution of {max_contrib}."
|
262 |
-
)
|
263 |
-
context_parts.append("Participating organizations:")
|
264 |
-
for i, name in enumerate(names):
|
265 |
-
city = cities[i] if i < len(cities) else ""
|
266 |
-
country = countries[i] if i < len(countries) else ""
|
267 |
-
act = activity[i] if i < len(activity) else ""
|
268 |
-
contrib = contributions[i] if i < len(contributions) else ""
|
269 |
-
sme_flag = "SME" if (smes and i < len(smes) and smes[i]) else "non-SME"
|
270 |
-
context_parts.append(
|
271 |
-
f"- {name} in {city}, {country}, activity: {act}, contributed: {contrib}, {sme_flag}."
|
272 |
-
)
|
273 |
-
if status in (None,"signed","SIGNED","Signed"):
|
274 |
-
if int(pred) == 1:
|
275 |
-
label = "TERMINATED"
|
276 |
-
score = float(proba)
|
277 |
-
else:
|
278 |
-
label = "CLOSED"
|
279 |
-
score = 1 - float(proba)
|
280 |
-
|
281 |
-
score_str = f"{score:.2f}"
|
282 |
-
|
283 |
-
context_parts.append(
|
284 |
-
f"- Project {project_id} is predicted to be {label} (score={score_str}). "
|
285 |
-
f"The 3 most predictive features were: "
|
286 |
-
f"{top1_feats} ({top1_shap:.3f}), "
|
287 |
-
f"{top2_feats} ({top2_shap:.3f}), "
|
288 |
-
f"{top3_feats} ({top3_shap:.3f})."
|
289 |
-
)
|
290 |
-
|
291 |
-
title_report = row.get("list_title_report", "")
|
292 |
-
objective = row.get("objective", "")
|
293 |
-
full_body = f"{title_report} {objective}"
|
294 |
-
full_text = " ".join(context_parts + [full_body])
|
295 |
-
meta: Dict[str, Any] = {"id": str(row.get("id", "")),"startDate": str(row.get("startDate", "")),"endDate": str(row.get("endDate", "")),"status":str(row.get("status", "")),"legalBasis":str(row.get("legalBasis",""))}
|
296 |
-
meta.update({"id": str(row.get("id", "")),"startDate": str(row.get("startDate", "")),"endDate": str(row.get("endDate", "")),"status":str(row.get("status", "")),"legalBasis":str(row.get("legalBasis",""))})
|
297 |
-
docs.append(Document(page_content=full_text, metadata=meta))
|
298 |
-
return docs
|
299 |
-
|
300 |
-
# === BM25 Search ===
|
301 |
-
async def bm25_search(ix: index.Index, query: str, k: int) -> List[Document]:
|
302 |
-
parser = MultifieldParser(["content"], schema=ix.schema)
|
303 |
-
def _search() -> List[Document]:
|
304 |
-
with ix.searcher() as searcher:
|
305 |
-
hits = searcher.search(parser.parse(query), limit=k)
|
306 |
-
return [Document(page_content=h["content"], metadata={"id": h["id"]}) for h in hits]
|
307 |
-
return await run_in_threadpool(_search)
|
308 |
-
|
309 |
-
# === Helper: build or load FAISS with mmap ===
|
310 |
-
"""async def build_or_load_faiss(
|
311 |
-
chunks: List[Document],
|
312 |
-
vectorstore_path: str,
|
313 |
-
batch_size: int = 15000
|
314 |
-
) -> FAISS:"""
|
315 |
-
"""
|
316 |
-
Always uses GCS. Expects 'index.faiss' and 'index.pkl' under vectorstore_path.
|
317 |
-
Reconstructs the FAISS store using your provided logic.
|
318 |
-
"""
|
319 |
-
"""assert vectorstore_path.startswith("gs://")
|
320 |
-
fs = gcsfs.GCSFileSystem()
|
321 |
-
|
322 |
-
base = vectorstore_path.rstrip("/")
|
323 |
-
uri_index = f"{base}/index.faiss"
|
324 |
-
uri_meta = f"{base}/index.pkl"
|
325 |
-
|
326 |
-
local_index = "/tmp/index.faiss"
|
327 |
-
local_meta = "/tmp/index.pkl"
|
328 |
-
|
329 |
-
# 1) If existing index + metadata on GCS → load
|
330 |
-
if fs.exists(uri_index) and fs.exists(uri_meta):
|
331 |
-
logger.info("Found existing FAISS index on GCS; loading…")
|
332 |
-
os.makedirs(os.path.dirname(local_index), exist_ok=True)
|
333 |
-
await run_in_threadpool(
|
334 |
-
fs.get_bulk,
|
335 |
-
[uri_index, uri_meta],
|
336 |
-
[local_index, local_meta]
|
337 |
-
)
|
338 |
-
|
339 |
-
# 3) Memory‐map load
|
340 |
-
mmap_idx = await run_in_threadpool(
|
341 |
-
faiss.read_index, local_index, faiss.IO_FLAG_MMAP
|
342 |
-
)
|
343 |
-
|
344 |
-
# Load metadata
|
345 |
-
with open(local_meta, "rb") as f:
|
346 |
-
saved = pickle.load(f)
|
347 |
-
|
348 |
-
# extract metadata
|
349 |
-
if isinstance(saved, tuple):
|
350 |
-
# Handle tuple of length 2 or 3
|
351 |
-
if len(saved) == 3:
|
352 |
-
_, docstore, index_to_docstore = saved
|
353 |
-
elif len(saved) == 2:
|
354 |
-
docstore, index_to_docstore = saved
|
355 |
-
else:
|
356 |
-
raise ValueError(f"Unexpected metadata tuple length: {len(saved)}")
|
357 |
-
else:
|
358 |
-
# saved is an object with attributes
|
359 |
-
if hasattr(saved, "docstore"):
|
360 |
-
docstore = saved.docstore
|
361 |
-
elif hasattr(saved, "_docstore"):
|
362 |
-
docstore = saved._docstore
|
363 |
-
else:
|
364 |
-
raise AttributeError("Could not find docstore in FAISS metadata")
|
365 |
-
|
366 |
-
if hasattr(saved, "index_to_docstore"):
|
367 |
-
index_to_docstore = saved.index_to_docstore
|
368 |
-
elif hasattr(saved, "_index_to_docstore"):
|
369 |
-
index_to_docstore = saved._index_to_docstore
|
370 |
-
elif hasattr(saved, "_faiss_index_to_docstore"):
|
371 |
-
index_to_docstore = saved._faiss_index_to_docstore
|
372 |
-
else:
|
373 |
-
raise AttributeError("Could not find index_to_docstore in FAISS metadata")
|
374 |
-
|
375 |
-
# reconstruct FAISS wrapper
|
376 |
-
vs = FAISS(
|
377 |
-
embedding_function=EMBEDDING, # your embedding function
|
378 |
-
index=mmap_idx,
|
379 |
-
docstore=docstore,
|
380 |
-
index_to_docstore_id=index_to_docstore,
|
381 |
-
)
|
382 |
-
return vs
|
383 |
-
|
384 |
-
# 2) Else: build from scratch in batches
|
385 |
-
# parse bucket & prefix
|
386 |
-
_, rest = vectorstore_path.split("://", 1)
|
387 |
-
bucket, *path_parts = rest.split("/", 1)
|
388 |
-
prefix = path_parts[0] if path_parts else ""
|
389 |
-
|
390 |
-
# helper to upload entire local dir to GCS
|
391 |
-
def upload_dir(local_dir: str):
|
392 |
-
for root, _, files in os.walk(local_dir):
|
393 |
-
for fname in files:
|
394 |
-
local_path = os.path.join(root, fname)
|
395 |
-
# construct the corresponding GCS path
|
396 |
-
rel_path = os.path.relpath(local_path, local_dir)
|
397 |
-
gcs_path = f"gs://{bucket}/{prefix}/{rel_path}"
|
398 |
-
fs.makedirs(os.path.dirname(gcs_path), exist_ok=True)
|
399 |
-
fs.put(local_path, gcs_path)
|
400 |
-
|
401 |
-
# temporary local staging area
|
402 |
-
local_store = "/tmp/faiss_store"
|
403 |
-
os.makedirs(local_store, exist_ok=True)
|
404 |
-
|
405 |
-
# 2) Else: build from scratch in batches
|
406 |
-
logger.info(f"Building FAISS index in batches of {batch_size}…")
|
407 |
-
vs: Optional[FAISS] = None
|
408 |
-
|
409 |
-
for i in tqdm(range(0, len(chunks), batch_size),
|
410 |
-
desc="Building FAISS index",
|
411 |
-
unit="batch"):
|
412 |
-
batch = chunks[i : i + batch_size]
|
413 |
-
|
414 |
-
if vs is None:
|
415 |
-
vs = FAISS.from_documents(batch, EMBEDDING)
|
416 |
-
else:
|
417 |
-
vs.add_documents(batch)
|
418 |
-
|
419 |
-
# periodic save every 5 batches
|
420 |
-
if (i // batch_size) % 5 == 0:
|
421 |
-
# save into local_store
|
422 |
-
vs.save_local(local_store)
|
423 |
-
# push local_store → GCS
|
424 |
-
upload_dir(local_store)
|
425 |
-
logger.info(f" • Saved batch up to document {i + len(batch)} / {len(chunks)}")
|
426 |
-
|
427 |
-
assert vs is not None, "No documents to index!"
|
428 |
-
|
429 |
-
# final save at end
|
430 |
-
vs.save_local(local_store)
|
431 |
-
upload_dir(local_store)
|
432 |
-
logger.info("Finished building index and uploaded to GCS.")
|
433 |
-
|
434 |
-
return vs"""
|
435 |
-
|
436 |
-
async def build_or_load_faiss(
|
437 |
-
docs: List[Document],
|
438 |
-
vectorstore_path: str,
|
439 |
-
batch_size: int = 15000
|
440 |
-
) -> FAISS:
|
441 |
-
"""
|
442 |
-
Expects a ZIP at vectorstore_path + ".zip" containing:
|
443 |
-
- index.faiss
|
444 |
-
- index.pkl
|
445 |
-
Files may be nested under a subfolder (e.g. vectorstore_index_colab/).
|
446 |
-
If the ZIP exists on GCS, download & load only.
|
447 |
-
Otherwise, build from `docs`, save, re-zip, and upload.
|
448 |
-
"""
|
449 |
-
fs = gcsfs.GCSFileSystem()
|
450 |
-
is_gcs = vectorstore_path.startswith("gs://")
|
451 |
-
zip_uri = vectorstore_path.rstrip("/") + ".zip"
|
452 |
-
|
453 |
-
local_zip = "/tmp/faiss_index.zip"
|
454 |
-
local_dir = "/tmp/faiss_store"
|
455 |
-
|
456 |
-
# 1) if ZIP exists, download & extract
|
457 |
-
if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
|
458 |
-
logger.info("Found FAISS ZIP on GCS; loading only.")
|
459 |
-
# clean slate
|
460 |
-
if os.path.exists(local_dir):
|
461 |
-
shutil.rmtree(local_dir)
|
462 |
-
os.makedirs(local_dir, exist_ok=True)
|
463 |
-
|
464 |
-
# download zip
|
465 |
-
await run_in_threadpool(fs.get, zip_uri, local_zip)
|
466 |
-
|
467 |
-
# extract
|
468 |
-
def _extract():
|
469 |
-
with zipfile.ZipFile(local_zip, "r") as zf:
|
470 |
-
zf.extractall(local_dir)
|
471 |
-
await run_in_threadpool(_extract)
|
472 |
-
|
473 |
-
# locate the two files anywhere under local_dir
|
474 |
-
idx_path = None
|
475 |
-
meta_path = None
|
476 |
-
for root, _, files in os.walk(local_dir):
|
477 |
-
if "index.faiss" in files:
|
478 |
-
idx_path = os.path.join(root, "index.faiss")
|
479 |
-
if "index.pkl" in files:
|
480 |
-
meta_path = os.path.join(root, "index.pkl")
|
481 |
-
if not idx_path or not meta_path:
|
482 |
-
raise FileNotFoundError("Couldn't find index.faiss or index.pkl in extracted ZIP.")
|
483 |
-
|
484 |
-
# memory-map load
|
485 |
-
mmap_index = await run_in_threadpool(
|
486 |
-
faiss.read_index, idx_path, faiss.IO_FLAG_MMAP
|
487 |
-
)
|
488 |
-
|
489 |
-
# load metadata
|
490 |
-
with open(meta_path, "rb") as f:
|
491 |
-
saved = pickle.load(f)
|
492 |
-
|
493 |
-
# unpack metadata
|
494 |
-
if isinstance(saved, tuple):
|
495 |
-
_, docstore, index_to_docstore = (
|
496 |
-
saved if len(saved) == 3 else (None, *saved)
|
497 |
-
)
|
498 |
-
else:
|
499 |
-
docstore = getattr(saved, "docstore", saved._docstore)
|
500 |
-
index_to_docstore = getattr(
|
501 |
-
saved,
|
502 |
-
"index_to_docstore",
|
503 |
-
getattr(saved, "_index_to_docstore", saved._faiss_index_to_docstore)
|
504 |
-
)
|
505 |
-
|
506 |
-
# reconstruct FAISS
|
507 |
-
vs = FAISS(
|
508 |
-
embedding_function=EMBEDDING,
|
509 |
-
index=mmap_index,
|
510 |
-
docstore=docstore,
|
511 |
-
index_to_docstore_id=index_to_docstore,
|
512 |
-
)
|
513 |
-
logger.info("FAISS index loaded from ZIP.")
|
514 |
-
return vs
|
515 |
-
|
516 |
-
# 2) otherwise, build from scratch and upload
|
517 |
-
logger.info("No FAISS ZIP found; building index from scratch.")
|
518 |
-
if os.path.exists(local_dir):
|
519 |
-
shutil.rmtree(local_dir)
|
520 |
-
os.makedirs(local_dir, exist_ok=True)
|
521 |
-
|
522 |
-
vs: FAISS = None
|
523 |
-
for i in range(0, len(docs), batch_size):
|
524 |
-
batch = docs[i : i + batch_size]
|
525 |
-
if vs is None:
|
526 |
-
vs = FAISS.from_documents(batch, EMBEDDING)
|
527 |
-
else:
|
528 |
-
vs.add_documents(batch)
|
529 |
-
assert vs is not None, "No documents to index!"
|
530 |
-
|
531 |
-
# save locally
|
532 |
-
vs.save_local(local_dir)
|
533 |
-
|
534 |
-
if is_gcs:
|
535 |
-
# re-zip all contents of local_dir (flattened)
|
536 |
-
def _zip_dir():
|
537 |
-
with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
|
538 |
-
for root, _, files in os.walk(local_dir):
|
539 |
-
for fname in files:
|
540 |
-
full = os.path.join(root, fname)
|
541 |
-
arc = os.path.relpath(full, local_dir)
|
542 |
-
zf.write(full, arc)
|
543 |
-
await run_in_threadpool(_zip_dir)
|
544 |
-
await run_in_threadpool(fs.put, local_zip, zip_uri)
|
545 |
-
logger.info("Built FAISS index and uploaded ZIP to GCS.")
|
546 |
-
|
547 |
-
return vs
|
548 |
-
|
549 |
-
|
550 |
-
# === Index Builder ===
|
551 |
-
async def build_indexes(
|
552 |
-
parquet_path: str,
|
553 |
-
vectorstore_path: str,
|
554 |
-
whoosh_dir: str,
|
555 |
-
chunk_size: int,
|
556 |
-
chunk_overlap: int,
|
557 |
-
debug_size: Optional[int]
|
558 |
-
) -> Tuple[FAISS, index.Index]:
|
559 |
-
docs = await load_documents(parquet_path, debug_size)
|
560 |
-
ix = await build_whoosh_index(docs, whoosh_dir)
|
561 |
-
|
562 |
-
splitter = RecursiveCharacterTextSplitter(
|
563 |
-
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
564 |
-
)
|
565 |
-
chunks = splitter.split_documents(docs)
|
566 |
-
|
567 |
-
# build or load (with mmap) FAISS
|
568 |
-
vs = await build_or_load_faiss(chunks, vectorstore_path)
|
569 |
-
|
570 |
-
return vs, ix
|
571 |
-
|
572 |
-
# === Hybrid Retriever ===
|
573 |
-
class HybridRetriever(BaseRetriever):
|
574 |
-
"""Hybrid retriever combining BM25 and FAISS with cross-encoder re-ranking."""
|
575 |
-
# store FAISS and Whoosh under private attributes to avoid Pydantic field errors
|
576 |
-
_vs: FAISS = PrivateAttr()
|
577 |
-
_ix: index.Index = PrivateAttr()
|
578 |
-
_compressor: DocumentCompressorPipeline = PrivateAttr()
|
579 |
-
_cross_encoder: CrossEncoder = PrivateAttr()
|
580 |
-
|
581 |
-
def __init__(
|
582 |
-
self,
|
583 |
-
vs: FAISS,
|
584 |
-
ix: index.Index,
|
585 |
-
compressor: DocumentCompressorPipeline,
|
586 |
-
cross_encoder: CrossEncoder
|
587 |
-
) -> None:
|
588 |
-
super().__init__()
|
589 |
-
object.__setattr__(self, '_vs', vs)
|
590 |
-
object.__setattr__(self, '_ix', ix)
|
591 |
-
object.__setattr__(self, '_compressor', compressor)
|
592 |
-
object.__setattr__(self, '_cross_encoder', cross_encoder)
|
593 |
-
|
594 |
-
async def _aget_relevant_documents(self, query: str) -> List[Document]:
|
595 |
-
# BM25 retrieval using Whoosh index
|
596 |
-
bm_docs = await bm25_search(self._ix, query, settings.hybrid_k)
|
597 |
-
# Dense retrieval using FAISS
|
598 |
-
dense_docs = self._vs.similarity_search_by_vector(
|
599 |
-
embed_query_cached(query), k=settings.hybrid_k
|
600 |
-
)
|
601 |
-
# Cross-encoder re-ranking
|
602 |
-
candidates = bm_docs + dense_docs
|
603 |
-
scores = self._cross_encoder.predict([
|
604 |
-
(query, doc.page_content) for doc in candidates
|
605 |
-
])
|
606 |
-
ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
|
607 |
-
top = [doc for _, doc in ranked[: settings.hybrid_k]]
|
608 |
-
# Compress and return
|
609 |
-
return self._compressor.compress_documents(top, query=query)
|
610 |
-
|
611 |
-
def _get_relevant_documents(self, query: str) -> List[Document]:
|
612 |
-
import asyncio
|
613 |
-
return asyncio.get_event_loop().run_until_complete(
|
614 |
-
self._aget_relevant_documents(query)
|
615 |
-
)
|
616 |
-
|
617 |
# ---------------------------------------------------------------------------- #
|
618 |
-
# Lifespan
|
619 |
# ---------------------------------------------------------------------------- #
|
|
|
|
|
|
|
620 |
@asynccontextmanager
|
621 |
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
#
|
627 |
logger.info("Initializing Document Compressor")
|
628 |
compressor = DocumentCompressorPipeline(
|
629 |
transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
|
630 |
)
|
631 |
|
632 |
-
# Cross
|
633 |
logger.info("Initializing Cross-Encoder")
|
634 |
cross_encoder = CrossEncoder(settings.cross_encoder_model)
|
635 |
-
|
636 |
-
# Apply dynamic quantization to speed up CPU inference
|
637 |
cross_encoder.model = torch.quantization.quantize_dynamic(
|
638 |
cross_encoder.model,
|
639 |
{torch.nn.Linear},
|
@@ -641,34 +120,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
641 |
)
|
642 |
logger.info("Cross-Encoder quantized")
|
643 |
|
644 |
-
#
|
645 |
-
logger.info("Initializing
|
646 |
-
#full_model=AutoModelForSeq2SeqLM.from_pretrained(settings.llm_model)
|
647 |
-
#full_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)#, device_map="auto")
|
648 |
-
|
649 |
-
# Apply dynamic quantization to all Linear layers
|
650 |
-
#llm_model = torch.quantization.quantize_dynamic(
|
651 |
-
# full_model,
|
652 |
-
# {torch.nn.Linear},
|
653 |
-
# dtype=torch.qint8
|
654 |
-
#)
|
655 |
-
# Create your text-generation pipeline on CPU
|
656 |
-
#gen_pipe = pipeline(
|
657 |
-
# "text-generation",#"text2text-generation",##"text2text-generation",
|
658 |
-
# model=llm_model,
|
659 |
-
# tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
|
660 |
-
# device=-1, # CPU
|
661 |
-
# max_new_tokens=256,
|
662 |
-
# do_sample=True,
|
663 |
-
# temperature=0.7,
|
664 |
-
# #device_map="auto"
|
665 |
-
#)
|
666 |
tokenizer = T5Tokenizer.from_pretrained(settings.llm_model)
|
667 |
-
model
|
668 |
-
model
|
669 |
model, {torch.nn.Linear}, dtype=torch.qint8
|
670 |
)
|
671 |
-
|
672 |
gen_pipe = pipeline(
|
673 |
"text2text-generation",
|
674 |
model=model,
|
@@ -678,20 +136,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
678 |
do_sample=True,
|
679 |
temperature=0.7,
|
680 |
)
|
681 |
-
# Wrap in LangChain's HuggingFacePipeline
|
682 |
llm = HuggingFacePipeline(pipeline=gen_pipe)
|
683 |
|
684 |
-
#
|
685 |
logger.info("Initializing Conversation Memory")
|
686 |
memory = ConversationBufferWindowMemory(
|
687 |
memory_key="chat_history",
|
688 |
-
k=5,
|
689 |
input_key="question",
|
690 |
output_key="answer",
|
691 |
return_messages=True,
|
|
|
692 |
)
|
693 |
-
|
694 |
-
# Build or load
|
|
|
695 |
vs, ix = await build_indexes(
|
696 |
settings.parquet_path,
|
697 |
settings.vectorstore_path,
|
@@ -700,21 +158,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
700 |
settings.chunk_overlap,
|
701 |
None,
|
702 |
)
|
703 |
-
logger.info("Initializing Hybrid Retriever")
|
704 |
retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
|
705 |
-
|
|
|
706 |
prompt = PromptTemplate.from_template(
|
707 |
-
f"{settings.assistant_role}
|
708 |
"{context}\n"
|
709 |
-
"User Question:\n"
|
710 |
-
"
|
711 |
-
|
712 |
-
"1. Write at least 4-6 full sentences.\n"
|
713 |
-
"2. Use clear, technical language in full sentences.\n"
|
714 |
-
"3. Cite any document you reference by including its ID in [brackets] inline.\n"
|
715 |
-
"4. Conclude with high-level insights or recommendations.\n"
|
716 |
-
"Answer:")
|
717 |
|
|
|
718 |
logger.info("Initializing Retrieval Chain")
|
719 |
app.state.rag_chain = ConversationalRetrievalChain.from_llm(
|
720 |
llm=llm,
|
@@ -724,35 +178,33 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
724 |
return_source_documents=True,
|
725 |
)
|
726 |
|
|
|
727 |
if not settings.skip_warmup:
|
728 |
-
logger.info("Warming up RAG chain
|
729 |
await app.state.rag_chain.ainvoke({"question": "warmup"})
|
730 |
-
logger.info("RAG ready.")
|
731 |
|
732 |
-
#
|
733 |
-
logger.info("Loading Parquet data from GCS
|
734 |
fs = gcsfs.GCSFileSystem()
|
735 |
with fs.open(settings.parquet_path, "rb") as f:
|
736 |
df = pl.read_parquet(f)
|
737 |
-
|
738 |
df = df.with_columns(
|
739 |
-
pl.col("id").cast(pl.Int64)
|
|
|
|
|
|
|
740 |
)
|
741 |
|
742 |
-
#
|
743 |
-
for col in ("title", "status", "legalBasis","fundingScheme"):
|
744 |
-
df = df.with_columns(pl.col(col).str.to_lowercase().alias(f"_{col}_lc"))
|
745 |
-
|
746 |
-
# materialize unique filter values
|
747 |
app.state.df = df
|
748 |
-
app.state.statuses
|
749 |
-
app.state.legal_bases
|
750 |
-
app.state.orgs_list
|
751 |
-
app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()
|
752 |
app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()
|
753 |
|
754 |
-
yield
|
755 |
-
|
756 |
# ---------------------------------------------------------------------------- #
|
757 |
# App Setup #
|
758 |
# ---------------------------------------------------------------------------- #
|
@@ -765,19 +217,29 @@ app.add_middleware(
|
|
765 |
)
|
766 |
|
767 |
# ---------------------------------------------------------------------------- #
|
768 |
-
#
|
769 |
# ---------------------------------------------------------------------------- #
|
|
|
770 |
class RAGRequest(BaseModel):
|
771 |
-
session_id: Optional[str] = None
|
772 |
-
query: str
|
773 |
|
774 |
class RAGResponse(BaseModel):
|
775 |
answer: str
|
776 |
source_ids: List[str]
|
|
|
|
|
|
|
|
|
777 |
|
778 |
def rag_chain_depender(app: FastAPI = Depends(lambda: app)) -> Any:
|
|
|
|
|
|
|
|
|
779 |
chain = app.state.rag_chain
|
780 |
if chain is None:
|
|
|
781 |
raise HTTPException(status_code=500, detail="RAG chain not initialized")
|
782 |
return chain
|
783 |
|
@@ -786,27 +248,47 @@ async def ask_rag(
|
|
786 |
req: RAGRequest,
|
787 |
rag_chain = Depends(rag_chain_depender)
|
788 |
):
|
789 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
790 |
try:
|
|
|
791 |
result = await rag_chain.ainvoke({"question": req.query})
|
792 |
-
logger.info("
|
|
|
|
|
793 |
if not isinstance(result, dict):
|
|
|
794 |
result2 = await rag_chain.acall({"question": req.query})
|
795 |
-
raise ValueError(
|
|
|
|
|
|
|
|
|
|
|
796 |
answer = result.get("answer")
|
797 |
docs = result.get("source_documents", [])
|
798 |
-
sources = [d.metadata.get("id","") for d in docs]
|
|
|
799 |
return RAGResponse(answer=answer, source_ids=sources)
|
800 |
|
801 |
except Exception as e:
|
802 |
-
#
|
803 |
traceback.print_exc()
|
804 |
-
#
|
805 |
raise HTTPException(status_code=500, detail=str(e))
|
806 |
|
|
|
807 |
# ---------------------------------------------------------------------------- #
|
808 |
# Data Endpoints #
|
809 |
# ---------------------------------------------------------------------------- #
|
|
|
810 |
@app.get("/api/projects")
|
811 |
def get_projects(
|
812 |
page: int = 0,
|
@@ -821,10 +303,25 @@ def get_projects(
|
|
821 |
sortOrder: str = "desc",
|
822 |
sortField: str = "startDate",
|
823 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
824 |
df: pl.DataFrame = app.state.df
|
825 |
start = page * limit
|
826 |
sel = df
|
827 |
|
|
|
828 |
if search:
|
829 |
sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
|
830 |
if status:
|
@@ -843,19 +340,22 @@ def get_projects(
|
|
843 |
if proj_id:
|
844 |
sel = sel.filter(pl.col("id") == int(proj_id))
|
845 |
|
|
|
846 |
base_cols = [
|
847 |
"id","title","status","startDate","endDate","ecMaxContribution","acronym",
|
848 |
"legalBasis","objective","frameworkProgramme","list_euroSciVocTitle",
|
849 |
"list_euroSciVocPath","totalCost","list_isPublishedAs","fundingScheme"
|
850 |
]
|
851 |
-
#
|
852 |
for i in range(1,7):
|
853 |
base_cols += [f"top{i}_feature", f"top{i}_shap"]
|
854 |
base_cols += ["predicted_label","predicted_prob"]
|
855 |
|
856 |
-
|
|
|
857 |
sortField = sortField if sortField in df.columns else "startDate"
|
858 |
|
|
|
859 |
rows = (
|
860 |
sel.sort(sortField, descending=sort_desc)
|
861 |
.slice(start, limit)
|
@@ -865,6 +365,7 @@ def get_projects(
|
|
865 |
|
866 |
projects = []
|
867 |
for row in rows:
|
|
|
868 |
explanations = []
|
869 |
for i in range(1,7):
|
870 |
feat = row.pop(f"top{i}_feature", None)
|
@@ -873,7 +374,7 @@ def get_projects(
|
|
873 |
explanations.append({"feature": feat, "shap": shap})
|
874 |
row["explanations"] = explanations
|
875 |
|
876 |
-
# publications
|
877 |
raw_pubs = row.pop("list_publications", []) or []
|
878 |
pub_counts: Dict[str,int] = {}
|
879 |
for p in raw_pubs:
|
@@ -886,14 +387,20 @@ def get_projects(
|
|
886 |
|
887 |
@app.get("/api/filters")
|
888 |
def get_filters(request: Request):
|
|
|
|
|
|
|
|
|
|
|
889 |
df = app.state.df
|
890 |
params = request.query_params
|
891 |
|
|
|
892 |
if s := params.get("status"):
|
893 |
-
df = df.filter(pl.col("status").is_null() if s=="UNKNOWN"
|
894 |
-
else pl.col("_status_lc")==s.lower())
|
895 |
if lb := params.get("legalBasis"):
|
896 |
-
df = df.filter(pl.col("_legalBasis_lc")==lb.lower())
|
897 |
if org := params.get("organization"):
|
898 |
df = df.filter(pl.col("list_name").list.contains(org))
|
899 |
if c := params.get("country"):
|
@@ -902,6 +409,7 @@ def get_filters(request: Request):
|
|
902 |
df = df.filter(pl.col("_title_lc").str.contains(search.lower()))
|
903 |
|
904 |
def normalize(vals):
|
|
|
905 |
return sorted({("UNKNOWN" if v is None else v) for v in vals})
|
906 |
|
907 |
return {
|
@@ -910,31 +418,37 @@ def get_filters(request: Request):
|
|
910 |
"organizations": normalize(df["list_name"].explode().to_list())[:500],
|
911 |
"countries": normalize(df["list_country"].explode().to_list()),
|
912 |
"fundingSchemes": normalize(df["fundingScheme"].explode().to_list()),
|
913 |
-
#"ids": normalize(df["id"].to_list()),
|
914 |
}
|
915 |
|
916 |
@app.get("/api/stats")
|
917 |
def get_stats(request: Request):
|
|
|
|
|
|
|
|
|
|
|
918 |
lf = app.state.df.lazy()
|
919 |
params = request.query_params
|
920 |
|
|
|
921 |
if s := params.get("status"):
|
922 |
-
lf = lf.filter(pl.col("_status_lc")==s.lower())
|
923 |
if lb := params.get("legalBasis"):
|
924 |
-
lf = lf.filter(pl.col("_legalBasis_lc")==lb.lower())
|
925 |
if org := params.get("organization"):
|
926 |
lf = lf.filter(pl.col("list_name").list.contains(org))
|
927 |
if c := params.get("country"):
|
928 |
lf = lf.filter(pl.col("list_country").list.contains(c))
|
929 |
if mn := params.get("minFunding"):
|
930 |
-
lf = lf.filter(pl.col("ecMaxContribution")>=int(mn))
|
931 |
if mx := params.get("maxFunding"):
|
932 |
-
lf = lf.filter(pl.col("ecMaxContribution")<=int(mx))
|
933 |
if y1 := params.get("minYear"):
|
934 |
-
lf = lf.filter(pl.col("startDate").dt.year()>=int(y1))
|
935 |
if y2 := params.get("maxYear"):
|
936 |
-
lf = lf.filter(pl.col("startDate").dt.year()<=int(y2))
|
937 |
|
|
|
938 |
grouped = (
|
939 |
lf.select(pl.col("startDate").dt.year().alias("year"))
|
940 |
.group_by("year")
|
@@ -944,6 +458,7 @@ def get_stats(request: Request):
|
|
944 |
)
|
945 |
years, counts = grouped["year"].to_list(), grouped["count"].to_list()
|
946 |
|
|
|
947 |
return {
|
948 |
"Projects per Year": {"labels": years, "values": counts},
|
949 |
"Projects per Year 2": {"labels": years, "values": counts},
|
@@ -955,11 +470,17 @@ def get_stats(request: Request):
|
|
955 |
|
956 |
@app.get("/api/project/{project_id}/organizations")
|
957 |
def get_project_organizations(project_id: str):
|
|
|
|
|
|
|
|
|
|
|
958 |
df = app.state.df
|
959 |
-
sel = df.filter(pl.col("id")==int(project_id))
|
960 |
if sel.is_empty():
|
961 |
raise HTTPException(status_code=404, detail="Project not found")
|
962 |
|
|
|
963 |
orgs_df = (
|
964 |
sel.select([
|
965 |
pl.col("list_name").explode().alias("name"),
|
@@ -973,9 +494,11 @@ def get_project_organizations(project_id: str):
|
|
973 |
pl.col("list_geolocation").explode().alias("geoloc"),
|
974 |
])
|
975 |
.with_columns([
|
|
|
976 |
pl.col("geoloc").str.split(",").alias("latlon"),
|
977 |
])
|
978 |
.with_columns([
|
|
|
979 |
pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
|
980 |
pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
|
981 |
])
|
@@ -985,5 +508,6 @@ def get_project_organizations(project_id: str):
|
|
985 |
"activityType","orgURL","country","latitude","longitude"
|
986 |
])
|
987 |
)
|
988 |
-
logger.info(f"{orgs_df.to_dicts()}")
|
989 |
return orgs_df.to_dicts()
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import tempfile
|
5 |
import traceback
|
|
|
|
|
|
|
6 |
from contextlib import asynccontextmanager
|
7 |
+
from functools import lru_cache
|
8 |
+
from typing import Any, AsyncGenerator, Dict, List, Optional
|
9 |
|
|
|
|
|
10 |
import aiofiles
|
11 |
+
import faiss
|
12 |
+
import gcsfs
|
13 |
import polars as pl
|
14 |
+
import torch
|
15 |
import zipfile
|
16 |
+
from fastapi import Depends, FastAPI, HTTPException, Request
|
17 |
+
from fastapi.middleware.cors import CORSMiddleware
|
18 |
+
from pydantic import BaseModel, BaseSettings, PrivateAttr
|
19 |
+
from pydantic_settings import BaseSettings as SettingsBase
|
20 |
+
from sentence_transformers import CrossEncoder
|
21 |
+
from starlette.concurrency import run_in_threadpool
|
22 |
+
from tqdm import tqdm
|
23 |
+
from transformers import ( # Transformers for LLM pipeline
|
24 |
+
AutoModelForCausalLM,
|
25 |
+
AutoModelForSeq2SeqLM,
|
26 |
+
AutoTokenizer,
|
27 |
+
pipeline,
|
28 |
+
T5ForConditionalGeneration,
|
29 |
+
T5Tokenizer,
|
30 |
+
)
|
31 |
|
32 |
+
# LangChain imports for RAG
|
|
|
|
|
|
|
|
|
|
|
33 |
from langchain.chains import ConversationalRetrievalChain
|
34 |
+
from langchain.memory import ConversationBufferWindowMemory
|
35 |
from langchain.prompts import PromptTemplate
|
36 |
+
from langchain.schema import BaseRetriever, Document
|
37 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
38 |
+
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
39 |
+
from langchain_community.vectorstores import FAISS
|
40 |
+
from langchain_community.retrievers.document_compressors import DocumentCompressorPipeline
|
41 |
+
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
# Project-specific imports
|
44 |
+
from app.rag import build_indexes, HybridRetriever
|
45 |
|
46 |
# ---------------------------------------------------------------------------- #
|
47 |
# Settings #
|
48 |
# ---------------------------------------------------------------------------- #
|
|
|
|
|
|
|
49 |
|
50 |
+
class Settings(SettingsBase):
|
51 |
+
"""
|
52 |
+
Configuration settings loaded from environment or .env file.
|
53 |
+
"""
|
54 |
+
# Data sources
|
55 |
parquet_path: str = "gs://mda_eu_project/data/consolidated_clean_pred.parquet"
|
56 |
+
whoosh_dir: str = "gs://mda_eu_project/whoosh_index"
|
57 |
vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
|
58 |
+
|
59 |
+
# Model names
|
60 |
+
embedding_model: str = "sentence-transformers/LaBSE"
|
61 |
+
llm_model: str = "google/flan-t5-base"
|
62 |
cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
63 |
+
|
64 |
# RAG parameters
|
65 |
+
chunk_size: int = 750
|
66 |
chunk_overlap: int = 100
|
67 |
+
hybrid_k: int = 2
|
68 |
assistant_role: str = (
|
69 |
+
"You are a knowledgeable project analyst. You have access to the following retrieved document snippets."
|
70 |
)
|
71 |
skip_warmup: bool = True
|
72 |
+
|
73 |
+
# CORS
|
74 |
allowed_origins: List[str] = ["*"]
|
75 |
|
76 |
class Config:
|
77 |
env_file = ".env"
|
78 |
|
79 |
+
# Instantiate settings and logger
|
80 |
settings = Settings()
|
81 |
+
logging.basicConfig(level=logging.INFO)
|
82 |
+
logger = logging.getLogger(__name__)
|
83 |
|
84 |
+
# Pre-instantiate the embedding model for reuse
|
85 |
+
EMBEDDING = HuggingFaceEmbeddings(
|
86 |
+
model_name=settings.embedding_model,
|
87 |
+
model_kwargs={"trust_remote_code": True},
|
88 |
+
)
|
89 |
|
90 |
@lru_cache(maxsize=256)
|
91 |
def embed_query_cached(query: str) -> List[float]:
|
92 |
+
"""Cache embedding vectors for repeated queries."""
|
93 |
return EMBEDDING.embed_query(query.strip().lower())
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
# ---------------------------------------------------------------------------- #
|
96 |
+
# Application Lifespan #
|
97 |
# ---------------------------------------------------------------------------- #
|
98 |
+
|
99 |
+
app = FastAPI(lifespan=lambda app: lifespan(app))
|
100 |
+
|
101 |
@asynccontextmanager
|
102 |
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
103 |
+
"""
|
104 |
+
Startup: initialize RAG chain, embeddings, memory, indexes, and load data.
|
105 |
+
Shutdown: clean up resources if needed.
|
106 |
+
"""
|
107 |
+
# 1) Initialize document compressor
|
108 |
logger.info("Initializing Document Compressor")
|
109 |
compressor = DocumentCompressorPipeline(
|
110 |
transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
|
111 |
)
|
112 |
|
113 |
+
# 2) Initialize and quantize Cross-Encoder
|
114 |
logger.info("Initializing Cross-Encoder")
|
115 |
cross_encoder = CrossEncoder(settings.cross_encoder_model)
|
|
|
|
|
116 |
cross_encoder.model = torch.quantization.quantize_dynamic(
|
117 |
cross_encoder.model,
|
118 |
{torch.nn.Linear},
|
|
|
120 |
)
|
121 |
logger.info("Cross-Encoder quantized")
|
122 |
|
123 |
+
# 3) Build Seq2Seq pipeline and wrap in LangChain
|
124 |
+
logger.info("Initializing LLM pipeline")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
tokenizer = T5Tokenizer.from_pretrained(settings.llm_model)
|
126 |
+
model = T5ForConditionalGeneration.from_pretrained(settings.llm_model)
|
127 |
+
model = torch.quantization.quantize_dynamic(
|
128 |
model, {torch.nn.Linear}, dtype=torch.qint8
|
129 |
)
|
|
|
130 |
gen_pipe = pipeline(
|
131 |
"text2text-generation",
|
132 |
model=model,
|
|
|
136 |
do_sample=True,
|
137 |
temperature=0.7,
|
138 |
)
|
|
|
139 |
llm = HuggingFacePipeline(pipeline=gen_pipe)
|
140 |
|
141 |
+
# 4) Initialize conversation memory
|
142 |
logger.info("Initializing Conversation Memory")
|
143 |
memory = ConversationBufferWindowMemory(
|
144 |
memory_key="chat_history",
|
|
|
145 |
input_key="question",
|
146 |
output_key="answer",
|
147 |
return_messages=True,
|
148 |
+
k=5,
|
149 |
)
|
150 |
+
|
151 |
+
# 5) Build or load indexes for vectorstore and Whoosh
|
152 |
+
logger.info("Building or loading indexes")
|
153 |
vs, ix = await build_indexes(
|
154 |
settings.parquet_path,
|
155 |
settings.vectorstore_path,
|
|
|
158 |
settings.chunk_overlap,
|
159 |
None,
|
160 |
)
|
|
|
161 |
retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
|
162 |
+
|
163 |
+
# 6) Define prompt template for RAG chain
|
164 |
prompt = PromptTemplate.from_template(
|
165 |
+
f"{settings.assistant_role}\n"
|
166 |
"{context}\n"
|
167 |
+
"User Question:\n{question}\n"
|
168 |
+
"Answer:" # Rules are embedded in assistant_role
|
169 |
+
)
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
+
# 7) Instantiate the conversational retrieval chain
|
172 |
logger.info("Initializing Retrieval Chain")
|
173 |
app.state.rag_chain = ConversationalRetrievalChain.from_llm(
|
174 |
llm=llm,
|
|
|
178 |
return_source_documents=True,
|
179 |
)
|
180 |
|
181 |
+
# Optional warmup
|
182 |
if not settings.skip_warmup:
|
183 |
+
logger.info("Warming up RAG chain")
|
184 |
await app.state.rag_chain.ainvoke({"question": "warmup"})
|
|
|
185 |
|
186 |
+
# 8) Load project data into Polars DataFrame
|
187 |
+
logger.info("Loading Parquet data from GCS")
|
188 |
fs = gcsfs.GCSFileSystem()
|
189 |
with fs.open(settings.parquet_path, "rb") as f:
|
190 |
df = pl.read_parquet(f)
|
191 |
+
# Cast id to integer and lowercase key columns for filtering
|
192 |
df = df.with_columns(
|
193 |
+
pl.col("id").cast(pl.Int64),
|
194 |
+
*(pl.col(col).str.to_lowercase().alias(f"_{col}_lc") for col in [
|
195 |
+
"title", "status", "legalBasis", "fundingScheme"
|
196 |
+
])
|
197 |
)
|
198 |
|
199 |
+
# Cache DataFrame and filter values in app state
|
|
|
|
|
|
|
|
|
200 |
app.state.df = df
|
201 |
+
app.state.statuses = df["_status_lc"].unique().to_list()
|
202 |
+
app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
|
203 |
+
app.state.orgs_list = df.explode("list_name")["list_name"].unique().to_list()
|
|
|
204 |
app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()
|
205 |
|
206 |
+
yield # Application is ready
|
207 |
+
|
208 |
# ---------------------------------------------------------------------------- #
|
209 |
# App Setup #
|
210 |
# ---------------------------------------------------------------------------- #
|
|
|
217 |
)
|
218 |
|
219 |
# ---------------------------------------------------------------------------- #
|
220 |
+
# Pydantic Models #
|
221 |
# ---------------------------------------------------------------------------- #
|
222 |
+
|
223 |
class RAGRequest(BaseModel):
|
224 |
+
session_id: Optional[str] = None # Optional conversation ID
|
225 |
+
query: str # User's query text
|
226 |
|
227 |
class RAGResponse(BaseModel):
|
228 |
answer: str
|
229 |
source_ids: List[str]
|
230 |
+
|
231 |
+
# ---------------------------------------------------------------------------- #
|
232 |
+
# RAG Endpoint #
|
233 |
+
# ---------------------------------------------------------------------------- #
|
234 |
|
235 |
def rag_chain_depender(app: FastAPI = Depends(lambda: app)) -> Any:
|
236 |
+
"""
|
237 |
+
Dependency injector to retrieve the initialized RAG chain from the application state.
|
238 |
+
Raises HTTPException if chain is not yet initialized.
|
239 |
+
"""
|
240 |
chain = app.state.rag_chain
|
241 |
if chain is None:
|
242 |
+
# If the chain isn't set up, respond with a 500 server error
|
243 |
raise HTTPException(status_code=500, detail="RAG chain not initialized")
|
244 |
return chain
|
245 |
|
|
|
248 |
req: RAGRequest,
|
249 |
rag_chain = Depends(rag_chain_depender)
|
250 |
):
|
251 |
+
"""
|
252 |
+
Endpoint to process a RAG-based query.
|
253 |
+
|
254 |
+
1. Logs start of processing.
|
255 |
+
2. Invokes the RAG chain asynchronously with the user question.
|
256 |
+
3. Validates returned result structure and extracts answer + source IDs.
|
257 |
+
4. Handles any exceptions by logging traceback and returning a JSON error.
|
258 |
+
"""
|
259 |
+
logger.info("Starting to answer RAG query")
|
260 |
try:
|
261 |
+
# Asynchronously invoke the chain to get answer + docs
|
262 |
result = await rag_chain.ainvoke({"question": req.query})
|
263 |
+
logger.info("RAG results retrieved")
|
264 |
+
|
265 |
+
# Validate that the chain returned expected dict
|
266 |
if not isinstance(result, dict):
|
267 |
+
# Try sync call for debugging
|
268 |
result2 = await rag_chain.acall({"question": req.query})
|
269 |
+
raise ValueError(
|
270 |
+
f"Expected dict from chain, got {type(result)}; "
|
271 |
+
f"acall() returned {type(result2)}"
|
272 |
+
)
|
273 |
+
|
274 |
+
# Extract answer text and source document IDs
|
275 |
answer = result.get("answer")
|
276 |
docs = result.get("source_documents", [])
|
277 |
+
sources = [d.metadata.get("id", "") for d in docs]
|
278 |
+
|
279 |
return RAGResponse(answer=answer, source_ids=sources)
|
280 |
|
281 |
except Exception as e:
|
282 |
+
# Log full stacktrace to container logs
|
283 |
traceback.print_exc()
|
284 |
+
# Return HTTP 500 with error detail
|
285 |
raise HTTPException(status_code=500, detail=str(e))
|
286 |
|

# ---------------------------------------------------------------------------- #
#                                 Data Endpoints                                #
# ---------------------------------------------------------------------------- #

@app.get("/api/projects")
def get_projects(
    page: int = 0,

    sortOrder: str = "desc",
    sortField: str = "startDate",
):
    """
    Paginated project listing with optional filtering and sorting.

    Query Parameters:
    - page: zero-based page index
    - limit: number of items per page
    - search: substring search in project title
    - status, legalBasis, organization, country, fundingScheme: filters
    - proj_id: exact project ID filter
    - sortOrder: 'asc' or 'desc'
    - sortField: field name to sort by (fallback to startDate)

    Returns a list of project dicts including explanations and publication counts.
    """
    df: pl.DataFrame = app.state.df
    start = page * limit
    sel = df

    # Apply text and field filters as needed
    if search:
        sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
    if status:

    if proj_id:
        sel = sel.filter(pl.col("id") == int(proj_id))

    # Base columns to return
    base_cols = [
        "id", "title", "status", "startDate", "endDate", "ecMaxContribution", "acronym",
        "legalBasis", "objective", "frameworkProgramme", "list_euroSciVocTitle",
        "list_euroSciVocPath", "totalCost", "list_isPublishedAs", "fundingScheme"
    ]
    # Append top feature & SHAP value columns
    for i in range(1, 7):
        base_cols += [f"top{i}_feature", f"top{i}_shap"]
    base_cols += ["predicted_label", "predicted_prob"]

    # Determine sort direction and safe field
    sort_desc = sortOrder.lower() == "desc"
    sortField = sortField if sortField in df.columns else "startDate"

    # Query, sort, slice, and collect to Python dicts
    rows = (
        sel.sort(sortField, descending=sort_desc)
        .slice(start, limit)

    projects = []
    for row in rows:
        # Reformat SHAP explanations into list of dicts
        explanations = []
        for i in range(1, 7):
            feat = row.pop(f"top{i}_feature", None)

            explanations.append({"feature": feat, "shap": shap})
        row["explanations"] = explanations

        # Aggregate publications counts
        raw_pubs = row.pop("list_publications", []) or []
        pub_counts: Dict[str, int] = {}
        for p in raw_pubs:

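A small client sketch for the listing endpoint (illustrative, not part of the commit; the parameter names come from the signature and docstring above, while the base URL and example values are assumptions):

import requests

params = {
    "page": 0,
    "search": "energy",
    "sortField": "ecMaxContribution",
    "sortOrder": "desc",
}
# Assumed local deployment; adjust the base URL as needed.
rows = requests.get("http://localhost:7860/api/projects", params=params, timeout=30).json()
for proj in rows[:3]:
    print(proj["id"], proj["title"], proj.get("predicted_label"))
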
@app.get("/api/filters")
def get_filters(request: Request):
    """
    Retrieve available filter options based on the current dataset and optional query filters.

    Returns JSON with lists for statuses, legalBases, organizations, countries, and fundingSchemes.
    """
    df = app.state.df
    params = request.query_params

    # Dynamically filter df based on provided params
    if s := params.get("status"):
        df = df.filter(pl.col("status").is_null() if s == "UNKNOWN"
                       else pl.col("_status_lc") == s.lower())
    if lb := params.get("legalBasis"):
        df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        df = df.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):

        df = df.filter(pl.col("_title_lc").str.contains(search.lower()))

    def normalize(vals):
        # Map None to "UNKNOWN" and return sorted unique list
        return sorted({("UNKNOWN" if v is None else v) for v in vals})

    return {

        "organizations": normalize(df["list_name"].explode().to_list())[:500],
        "countries": normalize(df["list_country"].explode().to_list()),
        "fundingSchemes": normalize(df["fundingScheme"].explode().to_list()),
    }

@app.get("/api/stats")
def get_stats(request: Request):
    """
    Compute annual statistics on projects with optional filters for status, legal basis, etc.

    Returns a dict of chart data for projects per year.
    """
    lf = app.state.df.lazy()
    params = request.query_params

    # Apply lazy filters
    if s := params.get("status"):
        lf = lf.filter(pl.col("_status_lc") == s.lower())
    if lb := params.get("legalBasis"):
        lf = lf.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        lf = lf.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        lf = lf.filter(pl.col("list_country").list.contains(c))
    if mn := params.get("minFunding"):
        lf = lf.filter(pl.col("ecMaxContribution") >= int(mn))
    if mx := params.get("maxFunding"):
        lf = lf.filter(pl.col("ecMaxContribution") <= int(mx))
    if y1 := params.get("minYear"):
        lf = lf.filter(pl.col("startDate").dt.year() >= int(y1))
    if y2 := params.get("maxYear"):
        lf = lf.filter(pl.col("startDate").dt.year() <= int(y2))

    # Group by year and count
    grouped = (
        lf.select(pl.col("startDate").dt.year().alias("year"))
        .group_by("year")

    )
    years, counts = grouped["year"].to_list(), grouped["count"].to_list()

    # Return data ready for frontend charts
    return {
        "Projects per Year": {"labels": years, "values": counts},
        "Projects per Year 2": {"labels": years, "values": counts},

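The lazy filter-then-aggregate pattern used by get_stats, shown in isolation on a tiny invented frame (illustrative only; requires a recent polars that provides group_by and pl.len):

import polars as pl
from datetime import date

lf = pl.DataFrame({
    "startDate": [date(2019, 1, 1), date(2019, 6, 1), date(2021, 3, 15)],
    "ecMaxContribution": [100_000, 250_000, 80_000],
}).lazy()

per_year = (
    lf.filter(pl.col("ecMaxContribution") >= 90_000)
    .select(pl.col("startDate").dt.year().alias("year"))
    .group_by("year")
    .agg(pl.len().alias("count"))
    .sort("year")
    .collect()
)
print(per_year["year"].to_list(), per_year["count"].to_list())  # [2019] [2]
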
@app.get("/api/project/{project_id}/organizations")
def get_project_organizations(project_id: str):
    """
    Retrieve organization details for a given project ID, including geolocation.

    Raises 404 if the project ID does not exist.
    """
    df = app.state.df
    sel = df.filter(pl.col("id") == int(project_id))
    if sel.is_empty():
        raise HTTPException(status_code=404, detail="Project not found")

    # Explode list columns and parse latitude/longitude
    orgs_df = (
        sel.select([
            pl.col("list_name").explode().alias("name"),

            pl.col("list_geolocation").explode().alias("geoloc"),
        ])
        .with_columns([
            # Split "lat,lon" string into list
            pl.col("geoloc").str.split(",").alias("latlon"),
        ])
        .with_columns([
            # Cast to floats for numeric use
            pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
            pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
        ])

            "activityType", "orgURL", "country", "latitude", "longitude"
        ])
    )
    logger.info(f"Organization data for project {project_id}: {orgs_df.to_dicts()}")
    return orgs_df.to_dicts()

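The geolocation parsing chain on its own, as a self-contained polars sketch (illustrative, not part of the commit; the sample coordinates are invented):

import polars as pl

df = pl.DataFrame({"geoloc": ["50.8503,4.3517", "48.8566,2.3522"]})
parsed = (
    df.with_columns(pl.col("geoloc").str.split(",").alias("latlon"))
    .with_columns(
        pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
        pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
    )
    .drop("latlon")
)
print(parsed)  # two rows with Float64 latitude/longitude columns
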
backend/rag.py
CHANGED
from tqdm import tqdm
import faiss

from functools import lru_cache
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import traceback
from starlette.concurrency import run_in_threadpool
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, AsyncGenerator, Tuple

import os
import logging
import aiofiles
import polars as pl
import zipfile
import gcsfs

from langchain.schema import Document, BaseRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings

from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import CrossEncoder

from whoosh import index
from whoosh.fields import Schema, TEXT, ID
from whoosh.analysis import StemmingAnalyzer
from whoosh.qparser import MultifieldParser
import pickle
from pydantic import PrivateAttr
from tqdm import tqdm
import faiss
import torch
import tempfile
import shutil

from functools import lru_cache

# === Logging ===

    vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
    # Models
    embedding_model: str = "sentence-transformers/LaBSE"
    llm_model: str = "google/flan-t5-base"
    cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
    # RAG parameters
    chunk_size: int = 750
    chunk_overlap: int = 100
    hybrid_k: int = 2
    assistant_role: str = (
        "You are a knowledgeable project analyst. You have access to the following retrieved document snippets."
    )
    skip_warmup: bool = True
    allowed_origins: List[str] = ["*"]

    class Config:

# === Global Embeddings & Cache ===
EMBEDDING = HuggingFaceEmbeddings(model_name=settings.embedding_model)

@lru_cache(maxsize=128)
def embed_query_cached(query: str) -> List[float]:
    """Cache embedding vectors for queries."""
    return EMBEDDING.embed_query(query.strip().lower())

# === Whoosh Cache & Builder ===
async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
    """
    If gs://.../whoosh_index.zip exists, download & extract it once.
    Otherwise build locally from docs and upload the ZIP back to GCS.
    """
    fs = gcsfs.GCSFileSystem()
    is_gcs = whoosh_dir.startswith("gs://")
    zip_uri = whoosh_dir.rstrip("/") + ".zip"

    local_zip = "/tmp/whoosh_index.zip"
    local_dir = "/tmp/whoosh_index"

    # Clean slate
    if os.path.exists(local_dir):
        shutil.rmtree(local_dir)
    os.makedirs(local_dir, exist_ok=True)

    # 1️⃣ Try downloading the ZIP if it exists on GCS
    if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
        logger.info("Found whoosh_index.zip on GCS; downloading…")
        await run_in_threadpool(fs.get, zip_uri, local_zip)
        # Extract all files (flat) into local_dir
        with zipfile.ZipFile(local_zip, "r") as zf:
            for member in zf.infolist():
                if member.is_dir():
                    continue
                filename = os.path.basename(member.filename)
                if not filename:
                    continue
                target = os.path.join(local_dir, filename)
                os.makedirs(os.path.dirname(target), exist_ok=True)
                with zf.open(member) as src, open(target, "wb") as dst:
                    dst.write(src.read())
        logger.info("Whoosh index extracted from ZIP.")
    else:
        logger.info("No whoosh_index.zip found; building index from docs.")

        # Define the schema with stored content
        schema = Schema(
            id=ID(stored=True, unique=True),
            content=TEXT(stored=True, analyzer=StemmingAnalyzer()),
        )

        # Create the index
        ix = index.create_in(local_dir, schema)
        writer = ix.writer()
        for doc in docs:
            writer.add_document(
                id=doc.metadata.get("id", ""),
                content=doc.page_content,
            )
        writer.commit()
        logger.info("Whoosh index built locally.")

        # Upload the ZIP back to GCS
        if is_gcs:
            logger.info("Zipping and uploading new whoosh_index.zip to GCS…")
            with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
                for root, _, files in os.walk(local_dir):
                    for fname in files:
                        full = os.path.join(root, fname)
                        arc = os.path.relpath(full, local_dir)
                        zf.write(full, arc)
            await run_in_threadpool(fs.put, local_zip, zip_uri)
            logger.info("Uploaded whoosh_index.zip to GCS.")

    # 2️⃣ Finally open the index and return it
    ix = index.open_dir(local_dir)
    return ix

# === Document Loader ===
async def load_documents(

    sample_size: Optional[int] = None
) -> List[Document]:
    """
    Load project data from a Parquet file (local path or GCS URI),
    assemble metadata context for each row, and return as Document objects.
    """
    def _read_local(p: str, n: Optional[int]):
        # streaming scan keeps memory low

# === Helper: build or load FAISS with mmap ===
async def build_or_load_faiss(
    docs: List[Document],
    vectorstore_path: str,
    batch_size: int = 15000
) -> FAISS:
    """
    Expects a ZIP at vectorstore_path + ".zip" containing:
      - index.faiss
      - index.pkl
    Files may be nested under a subfolder (e.g. vectorstore_index_colab/).
    If the ZIP exists on GCS, download & load only.
    Otherwise, build from `docs`, save, re-zip, and upload.
    """
    fs = gcsfs.GCSFileSystem()
    is_gcs = vectorstore_path.startswith("gs://")
    zip_uri = vectorstore_path.rstrip("/") + ".zip"

    local_zip = "/tmp/faiss_index.zip"
    local_dir = "/tmp/faiss_store"

    # 1) if ZIP exists, download & extract
    if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
        logger.info("Found FAISS ZIP on GCS; loading only.")
        # clean slate
        if os.path.exists(local_dir):
            shutil.rmtree(local_dir)
        os.makedirs(local_dir, exist_ok=True)

        # download zip
        await run_in_threadpool(fs.get, zip_uri, local_zip)

        # extract
        def _extract():
            with zipfile.ZipFile(local_zip, "r") as zf:
                zf.extractall(local_dir)
        await run_in_threadpool(_extract)

        # locate the two files anywhere under local_dir
        idx_path = None
        meta_path = None
        for root, _, files in os.walk(local_dir):
            if "index.faiss" in files:
                idx_path = os.path.join(root, "index.faiss")
            if "index.pkl" in files:
                meta_path = os.path.join(root, "index.pkl")
        if not idx_path or not meta_path:
            raise FileNotFoundError("Couldn't find index.faiss or index.pkl in extracted ZIP.")

        # memory-map load
        mmap_index = await run_in_threadpool(
            faiss.read_index, idx_path, faiss.IO_FLAG_MMAP
        )

        # load metadata
        with open(meta_path, "rb") as f:
            saved = pickle.load(f)

        # unpack metadata
        if isinstance(saved, tuple):
            _, docstore, index_to_docstore = (
                saved if len(saved) == 3 else (None, *saved)
            )
        else:
            docstore = getattr(saved, "docstore", saved._docstore)
            index_to_docstore = getattr(
                saved,
                "index_to_docstore",
                getattr(saved, "_index_to_docstore", saved._faiss_index_to_docstore)
            )

        # reconstruct FAISS wrapper
        vs = FAISS(
            embedding_function=EMBEDDING,
            index=mmap_index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore,
        )
        logger.info("FAISS index loaded from ZIP.")
        return vs

    # 2) otherwise, build from scratch and upload
    logger.info("No FAISS ZIP found; building index from scratch.")
    if os.path.exists(local_dir):
        shutil.rmtree(local_dir)
    os.makedirs(local_dir, exist_ok=True)

    vs: FAISS = None
    for i in range(0, len(docs), batch_size):
        batch = docs[i : i + batch_size]
        if vs is None:
            vs = FAISS.from_documents(batch, EMBEDDING)
        else:
            vs.add_documents(batch)
    assert vs is not None, "No documents to index!"

    # save locally
    vs.save_local(local_dir)

    if is_gcs:
        # re-zip all contents of local_dir (flattened)
        def _zip_dir():
            with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
                for root, _, files in os.walk(local_dir):
                    for fname in files:
                        full = os.path.join(root, fname)
                        arc = os.path.relpath(full, local_dir)
                        zf.write(full, arc)
        await run_in_threadpool(_zip_dir)
        await run_in_threadpool(fs.put, local_zip, zip_uri)
        logger.info("Built FAISS index and uploaded ZIP to GCS.")

    return vs


# === Index Builder ===
async def build_indexes(
    parquet_path: str,

    chunk_overlap: int,
    debug_size: Optional[int]
) -> Tuple[FAISS, index.Index]:
    """
    Load documents, build/load Whoosh and FAISS indices, and return both.
    """
    docs = await load_documents(parquet_path, debug_size)
    ix = await build_whoosh_index(docs, whoosh_dir)

    vs = await build_or_load_faiss(chunks, vectorstore_path)

    return vs, ix


# === Hybrid Retriever ===
class HybridRetriever(BaseRetriever):
    """Hybrid retriever combining BM25 and FAISS with cross-encoder re-ranking."""
    # store FAISS and Whoosh under private attributes to avoid Pydantic field errors
    _vs: FAISS = PrivateAttr()
    _ix: index.Index = PrivateAttr()
    _compressor: DocumentCompressorPipeline = PrivateAttr()
    _cross_encoder: CrossEncoder = PrivateAttr()

    def __init__(
        self,
        vs: FAISS,
        ix: index.Index,
        compressor: DocumentCompressorPipeline,
        cross_encoder: CrossEncoder
    ) -> None:
        super().__init__()
        object.__setattr__(self, '_vs', vs)
        object.__setattr__(self, '_ix', ix)
        object.__setattr__(self, '_compressor', compressor)
        object.__setattr__(self, '_cross_encoder', cross_encoder)

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        # BM25 retrieval using Whoosh index
        bm_docs = await bm25_search(self._ix, query, settings.hybrid_k)
        # Dense retrieval using FAISS
        dense_docs = self._vs.similarity_search_by_vector(
            embed_query_cached(query), k=settings.hybrid_k
        )
        # Cross-encoder re-ranking
        candidates = bm_docs + dense_docs
        scores = self._cross_encoder.predict([
            (query, doc.page_content) for doc in candidates
        ])
        ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
        top = [doc for _, doc in ranked[: settings.hybrid_k]]
        # Compress and return
        return self._compressor.compress_documents(top, query=query)

    def _get_relevant_documents(self, query: str) -> List[Document]:
        import asyncio
        return asyncio.get_event_loop().run_until_complete(
            self._aget_relevant_documents(query)
        )
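The actual chain assembly is outside the hunks shown above, so the following is only a hedged wiring sketch: it assumes `vs` and `ix` are the stores returned by build_indexes and reuses the module-level EMBEDDING and settings objects; it is not the module's real setup code.

# Sketch only -- not the module's actual wiring.
compressor = DocumentCompressorPipeline(
    transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
)
cross_encoder = CrossEncoder(settings.cross_encoder_model)
retriever = HybridRetriever(vs, ix, compressor, cross_encoder)

# Outside of an async context the sync wrapper drives the event loop itself.
docs = retriever.get_relevant_documents("EU projects on battery recycling")
for d in docs:
    print(d.metadata.get("id"), d.page_content[:80])
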
backend/requirements.txt
CHANGED
# RAG & embeddings
pytorch-lightning==2.5.1
langchain
langchain-huggingface
sentence-transformers
langchain-community==0.3.24

# Vector store

sentencepiece==0.2.0
transformers
torchvision
sympy>=1.13.1
peft

aiofiles==24.1.0
optimum==1.25.3
bitsandbytes==0.45.5
hf_xet
HuggingFace
huggingface_hub
frontend/Dockerfile
DELETED
FROM node:18-alpine AS builder
WORKDIR /app
COPY package.json package-lock.json ./
RUN npm ci
COPY . .
RUN npm run build
FROM nginx:alpine
COPY --from=builder /app/dist /usr/share/nginx/html
EXPOSE 8000
CMD ["nginx", "-g", "daemon off;"]
frontend/src/hooks/useAppState.ts
CHANGED
import { useState, useEffect, useRef } from "react";
import { debounce } from "lodash";
import type {
  Project,
  OrganizationLocation,
  FilterState,
  ChatMessage,
  AvailableFilters,
} from "./types";

interface Stats {
  [key: string]: {

type SortOrder = "asc" | "desc";

export const useAppState = () => {
  // Projects state and pagination
  const [projects, setProjects] = useState<Project[]>([]);
  const [page, setPage] = useState<number>(0);

  // Search and filter states
  const [search, setSearch] = useState<string>("");
  const [statusFilter, setStatusFilter] = useState<string>("");
  const [legalFilter, setLegalFilter] = useState<string>("");
  const [orgFilter, setOrgFilter] = useState<string>("");
  const [countryFilter, setCountryFilter] = useState<string>("");
  const [fundingSchemeFilter, setFundingSchemeFilter] = useState<string>("");
  const [idFilter, setIdFilter] = useState<string>("");

  // Sorting
  const [sortField, setSortField] = useState<string>("");
  const [sortOrder, setSortOrder] = useState<SortOrder>("asc");

  // Dashboard stats and available filters
  const [stats, setStats] = useState<Stats>({});
  const [filters, setFilters] = useState<FilterState>({
    status: "",
    organization: "",

    maxYear: "2025",
    minFunding: "0",
    maxFunding: "10000000",
  });
  const [availableFilters, setAvailableFilters] = useState<AvailableFilters>({
    statuses: ["SIGNED", "CLOSED", "TERMINATED", "UNKNOWN"],
    organizations: [],
    countries: [],
    legalBases: [],
    fundingSchemes: [],
    ids: [],
  });

  // Chatbot states
  const [question, setQuestion] = useState<string>("");
  const [chatHistory, setChatHistory] = useState<ChatMessage[]>([]);
  const [loading, setLoading] = useState<boolean>(false);
  const messagesEndRef = useRef<HTMLDivElement | null>(null);

  // Fetch projects with current filters, pagination, sorting
  const fetchProjects = () => {
    const query = new URLSearchParams({
      page: page.toString(),
      search,
      status: statusFilter,
      legalBasis: legalFilter,
      organization: orgFilter,
      country: countryFilter,
      fundingScheme: fundingSchemeFilter,
      proj_id: idFilter,
      sortField,
      sortOrder,
    }).toString();

    fetch(`/api/projects?${query}`)
      .then((res) => res.json())
      .then((data: Project[]) => setProjects(data))
      .catch((err) => console.error("Error fetching projects:", err));
  };

  // Fetch stats with debouncing to limit requests
  const fetchStats = debounce((filterParams: FilterState) => {
    const query = new URLSearchParams(filterParams as any).toString();
    fetch(`/api/stats?${query}`)
      .then((res) => res.json())
      .then((data: Stats) => setStats(data))
      .catch((err) => console.error("Error fetching stats:", err));
  }, 500);

  // Fetch available filter options based on dataset and active filters
  const fetchAvailableFilters = (filterParams: FilterState) => {
    const query = new URLSearchParams(filterParams as any).toString();
    fetch(`/api/filters?${query}`)
      .then((res) => res.json())
      .then((data) => {
        setAvailableFilters({
          statuses: ["SIGNED", "CLOSED", "TERMINATED", "UNKNOWN"],
          organizations: data.organizations,
          countries: data.countries,
          legalBases: data.legalBases,
          fundingSchemes: data.fundingSchemes,
          ids: [],
        });
      })
      .catch((err) => console.error("Error fetching filters:", err));
  };

  interface RagResponse {
    answer: string;
    source_ids: string[];
  }

  // Handle chat submission
  const askChatbot = async () => {
    if (!question.trim() || loading) return;

    // Append user message
    setChatHistory((prev) => [...prev, { role: "user", content: question }]);
    setQuestion("");
    setLoading(true);

    // Add placeholder while generating
    setChatHistory((prev) => [...prev, { role: "assistant", content: "Generating answer..." }]);

    try {
      const response = await fetch("/api/rag", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ query: question }),
      });

      const text = await response.text();
      if (!response.ok) {
        let detail = text;
        try { detail = JSON.parse(text).detail; } catch {}
        throw new Error(detail);
      }

      const result: RagResponse = JSON.parse(text);
      const sources = result.source_ids.length ? result.source_ids.join(", ") : "none";
      const assistantContent = `${result.answer}\n\nSources: ${sources}`;

      // Replace placeholder with actual answer
      setChatHistory((prev) => [...prev.slice(0, -1), { role: "assistant", content: assistantContent }]);
    } catch (err: any) {
      // Replace placeholder with error message
      setChatHistory((prev) => [
        ...prev.slice(0, -1),
        { role: "assistant", content: `Error: ${err.message}` },
      ]);
    } finally {
      setLoading(false);
      messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
    }
  };

  // Effects: refetch when filters/sorting/pagination change
  useEffect(() => {
    fetchProjects();
  }, [page, search, statusFilter, legalFilter, orgFilter, countryFilter, fundingSchemeFilter, idFilter, sortField, sortOrder]);

  useEffect(() => {
    fetchStats(filters);
  }, [filters]);

  useEffect(() => {
    fetchAvailableFilters(filters);
  }, [filters]);

  return {
    selectedProject: projects[0] || null,
    dashboardProps: { stats, filters, setFilters, availableFilters },
    explorerProps: {
      projects,
      search,

      setFundingSchemeFilter,
      idFilter,
      setIdFilter,
      sortField,
      setSortField,
      sortOrder,
      setSortOrder,
      page,
      setPage,
      setSelectedProject: () => {},
      question,
      setQuestion,
      chatHistory,
      askChatbot,
      loading,
      messagesEndRef,
    },
    detailsProps: {
      project: projects[0]!,
      question,
      setQuestion,
      chatHistory,
      askChatbot,
      loading,
      messagesEndRef,
    },
  };
};
run.sh
CHANGED
#!/bin/bash
echo "HF_SDK = $HF_SPACE_SDK, APP_PORT = $APP_PORT, PORT = $PORT"
echo "$GCP_SA_JSON" > /tmp/sa.json
chmod 600 /tmp/sa.json
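run.sh only materializes the service-account file; how it is handed to the Google clients is not shown in this diff. Commonly a file like this is consumed via GOOGLE_APPLICATION_CREDENTIALS or passed to gcsfs explicitly -- both shown below as assumptions about the deployment, not as what the Space actually does:

import os
import gcsfs

# Option 1: let google-auth discover the credentials file (assumed env var, not set in this hunk).
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/tmp/sa.json"
fs = gcsfs.GCSFileSystem()

# Option 2: hand the credentials file to gcsfs directly.
fs = gcsfs.GCSFileSystem(token="/tmp/sa.json")
print(fs.exists("gs://mda_eu_project/vectorstore_index.zip"))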