Rom89823974978 committed on
Commit 86fd3c3 · 1 Parent(s): 9e4b8b0
Dockerfile CHANGED
@@ -1,4 +1,3 @@
-# Stage 1: build frontend (React + Vite)
 FROM node:18-alpine AS frontend-builder
 WORKDIR /app/frontend
 COPY frontend/package.json ./
@@ -6,7 +5,6 @@ RUN npm install
 COPY frontend/ .
 RUN npm run build
 
-# Stage 2: build backend (FastAPI)
 FROM python:3.11-slim AS backend-builder
 WORKDIR /app/backend
 COPY backend/requirements.txt ./
@@ -14,11 +12,9 @@ RUN pip install --no-cache-dir -r requirements.txt
 COPY backend/ .
 RUN pip install --no-cache-dir gunicorn uvicorn
 
-# Stage 3: runtime image with nginx and run script
 FROM python:3.11-slim as runtime
 
-# Install nginx
-# Install OS deps
+
 USER root
 USER root
 RUN apt-get update && \
@@ -39,21 +35,8 @@ RUN mkdir -p \
 /var/lib/nginx/proxy \
 /var/lib/nginx/fastcgi \
 /var/lib/nginx/scgi \
-/var/lib/nginx/uwsgi \
-&& chmod -R a+rwx /var/cache/nginx /var/log/nginx /var/run/nginx /var/lib/nginx
-
-RUN mkdir -p /var/cache/nginx/client_temp \
-    /var/cache/nginx/proxy_temp \
-    /var/cache/nginx/fastcgi_temp \
-    /var/cache/nginx/scgi_temp \
-    /var/cache/nginx/uwsgi_temp \
-    /var/log/nginx \
-    /var/run/nginx \
-    /var/lib/nginx/body \
-    /var/lib/nginx/proxy \
-    /var/lib/nginx/fastcgi \
-    /var/lib/nginx/scgi \
-    /var/lib/nginx/uwsgi && \
+/var/lib/nginx/uwsgi && \
+chmod -R a+rwx /var/cache/nginx /var/log/nginx /var/run/nginx /var/lib/nginx && \
 touch /var/log/nginx/error.log /var/log/nginx/access.log && \
 chown -R www-data:www-data /var/cache/nginx /var/log/nginx /var/run/nginx /var/lib/nginx
@@ -65,21 +48,16 @@ ENV HF_HOME=/tmp/hf_cache \
 RUN mkdir -p /tmp/hf_cache \
  && chmod 777 /tmp/hf_cache
 
-# Install Python deps from requirements (ensures numpy/pandas compatibility), then ASGI
-# copy in your requirements
 COPY --from=backend-builder /app/backend/requirements.txt /tmp/requirements.txt
 
 RUN python3 -m pip install --no-cache-dir \
-    # 2) Now install the rest (including gptqmodel)
     -r /tmp/requirements.txt \
  && python3 -m pip install --no-cache-dir fastapi starlette uvicorn
 
 
-# Copy frontend build and backend app
 COPY --from=frontend-builder /app/frontend/dist /app/static
 COPY --from=backend-builder /app/backend /app/app
 
-# Copy nginx config and run script
 COPY nginx.conf /etc/nginx/nginx.conf
 COPY run.sh /app/run.sh
 RUN chmod +x /app/run.sh
@@ -87,6 +65,5 @@ RUN chmod -R a+rwx /var/log/nginx
 
 WORKDIR /app
 
-# Use run.sh as entrypoint (runs nginx, static server, uvicorn)
 ENTRYPOINT ["/bin/bash", "/app/run.sh"]
 
 
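Note (not part of the commit): a minimal smoke test for the image this Dockerfile builds. It is a sketch only — the port (7860, the usual Hugging Face Spaces default) and the routes are assumptions, since nginx.conf and run.sh are referenced but not shown in this diff.

# Hypothetical smoke test; port and paths are assumptions, not taken from the repo.
import urllib.request

BASE = "http://localhost:7860"  # assumed nginx listen port

def status(path: str) -> int:
    """Return the HTTP status code for a GET against the running container."""
    with urllib.request.urlopen(BASE + path, timeout=10) as resp:
        return resp.status

if __name__ == "__main__":
    print("frontend:", status("/"))            # static build served by nginx
    print("backend:", status("/api/filters"))  # FastAPI route proxied to uvicorn
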
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: EU Explorer # the display name of your Space
+title: EU Explorer (MDA Assignment) # the display name of your Space
 emoji: 🤖 # a single emoji that represents your app
 colorFrom: purple # one of: red, yellow, green, blue, indigo, purple, pink, gray
 colorTo: indigo # another one from the same list, for the gradient
backend/Dockerfile DELETED
@@ -1,10 +0,0 @@
-FROM python:3.11-slim
-WORKDIR /app
-COPY requirements.txt ./
-RUN pip install --no-cache-dir -r requirements.txt
-RUN adduser --disabled-password appuser
-USER appuser
-COPY --chown=appuser:appuser . .
-ENV PARQUET_PATH="data/consolidated_clean.parquet"
-EXPOSE 8080
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
backend/__init__.py DELETED
File without changes
backend/main.py CHANGED
@@ -1,639 +1,118 @@
1
- from fastapi import FastAPI, Request, HTTPException, Depends
2
- from fastapi.middleware.cors import CORSMiddleware
 
 
3
  import traceback
4
- from starlette.concurrency import run_in_threadpool
5
- from pydantic import BaseModel
6
- from pydantic_settings import BaseSettings
7
  from contextlib import asynccontextmanager
8
- from typing import Any, Dict, List, Optional, AsyncGenerator, Tuple
 
9
 
10
- import os
11
- import logging
12
  import aiofiles
 
 
13
  import polars as pl
 
14
  import zipfile
15
- import gcsfs
16
 
17
- from langchain.schema import Document,BaseRetriever
18
- from langchain.text_splitter import RecursiveCharacterTextSplitter
19
- from langchain_community.vectorstores import FAISS
20
- from langchain.retrievers.document_compressors import DocumentCompressorPipeline
21
- from langchain_community.document_transformers import EmbeddingsRedundantFilter
22
- from langchain.memory import ConversationBufferWindowMemory
23
  from langchain.chains import ConversationalRetrievalChain
 
24
  from langchain.prompts import PromptTemplate
25
- from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
26
-
27
- from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5Tokenizer,T5ForConditionalGeneration
28
- from sentence_transformers import CrossEncoder
29
-
30
- from whoosh import index
31
- from whoosh.fields import Schema, TEXT, ID
32
- from whoosh.analysis import StemmingAnalyzer
33
- from whoosh.qparser import MultifieldParser
34
- import pickle
35
- from pydantic import PrivateAttr
36
- from tqdm import tqdm
37
- import faiss
38
- import torch
39
- import tempfile
40
- import shutil
41
 
42
- from functools import lru_cache
 
43
 
44
  # ---------------------------------------------------------------------------- #
45
  # Settings #
46
  # ---------------------------------------------------------------------------- #
47
- # === Logging ===
48
- logging.basicConfig(level=logging.INFO)
49
- logger = logging.getLogger(__name__)
50
 
51
- class Settings(BaseSettings):
52
- # Parquet + Whoosh/FAISS
 
 
 
53
  parquet_path: str = "gs://mda_eu_project/data/consolidated_clean_pred.parquet"
54
- whoosh_dir: str = "gs://mda_eu_project/whoosh_index"
55
  vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
56
- # Models
57
- embedding_model: str = "sentence-transformers/LaBSE"
58
- llm_model: str = "google/flan-t5-base"
 
59
  cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
 
60
  # RAG parameters
61
- chunk_size: int = 750
62
  chunk_overlap: int = 100
63
- hybrid_k: int = 2
64
  assistant_role: str = (
65
- "You are a knowledgeable project analyst. You have access to the following retrieved document snippets."
66
  )
67
  skip_warmup: bool = True
 
 
68
  allowed_origins: List[str] = ["*"]
69
 
70
  class Config:
71
  env_file = ".env"
72
 
 
73
  settings = Settings()
 
 
74
 
75
- # Preinstantiate embedding model (used by filter/compressor)
76
- EMBEDDING = HuggingFaceEmbeddings(model_name=settings.embedding_model,
77
- model_kwargs={"trust_remote_code": True})
 
 
78
 
79
  @lru_cache(maxsize=256)
80
  def embed_query_cached(query: str) -> List[float]:
81
- """Cache embedding vectors for queries."""
82
  return EMBEDDING.embed_query(query.strip().lower())
83
 
84
- # === Whoosh Cache & Builder ===
85
- """_WHOOSH_CACHE: Dict[str, index.Index] = {}
86
-
87
- async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
88
- key = whoosh_dir
89
- fs = gcsfs.GCSFileSystem()
90
- local_dir = key
91
- is_gcs = key.startswith("gs://")
92
- try:
93
- # stage local copy for GCS
94
- if is_gcs:
95
- local_dir = "/tmp/whoosh_index"
96
- if not os.path.exists(local_dir):
97
- if await run_in_threadpool(fs.exists, key):
98
- await run_in_threadpool(fs.get, key, local_dir, recursive=True)
99
- else:
100
- os.makedirs(local_dir, exist_ok=True)
101
- # build once
102
- if key not in _WHOOSH_CACHE:
103
- os.makedirs(local_dir, exist_ok=True)
104
- schema = Schema(
105
- id=ID(stored=True, unique=True),
106
- content=TEXT(analyzer=StemmingAnalyzer()),
107
- )
108
- ix = index.create_in(local_dir, schema)
109
- with ix.writer() as writer:
110
- for doc in docs:
111
- writer.add_document(
112
- id=doc.metadata.get("id", ""),
113
- content=doc.page_content,
114
- )
115
- # push back to GCS atomically
116
- if is_gcs:
117
- await run_in_threadpool(fs.put, local_dir, key, recursive=True)
118
- _WHOOSH_CACHE[key] = ix
119
- return _WHOOSH_CACHE[key]
120
- except Exception as e:
121
- logger.error(f"Failed to build Whoosh index: {e}")
122
- raise"""
123
-
124
- async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
125
- """
126
- If gs://.../whoosh_index.zip exists, download & extract it once.
127
- Otherwise build locally from docs and upload the ZIP back to GCS.
128
- """
129
- fs = gcsfs.GCSFileSystem()
130
- is_gcs = whoosh_dir.startswith("gs://")
131
- zip_uri = whoosh_dir.rstrip("/") + ".zip"
132
-
133
- local_zip = "/tmp/whoosh_index.zip"
134
- local_dir = "/tmp/whoosh_index"
135
-
136
- # Clean slate
137
- if os.path.exists(local_dir):
138
- shutil.rmtree(local_dir)
139
- os.makedirs(local_dir, exist_ok=True)
140
-
141
- # 1️⃣ Try downloading the ZIP if it exists on GCS
142
- if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
143
- logger.info("Found whoosh_index.zip on GCS; downloading…")
144
- await run_in_threadpool(fs.get, zip_uri, local_zip)
145
- # Extract all files (flat) into local_dir
146
- with zipfile.ZipFile(local_zip, "r") as zf:
147
- for member in zf.infolist():
148
- if member.is_dir():
149
- continue
150
- filename = os.path.basename(member.filename)
151
- if not filename:
152
- continue
153
- target = os.path.join(local_dir, filename)
154
- os.makedirs(os.path.dirname(target), exist_ok=True)
155
- with zf.open(member) as src, open(target, "wb") as dst:
156
- dst.write(src.read())
157
- logger.info("Whoosh index extracted from ZIP.")
158
- else:
159
- logger.info("No whoosh_index.zip found; building index from docs.")
160
-
161
- # Define the schema with stored content
162
- schema = Schema(
163
- id=ID(stored=True, unique=True),
164
- content=TEXT(stored=True, analyzer=StemmingAnalyzer()),
165
- )
166
-
167
- # Create the index
168
- ix = index.create_in(local_dir, schema)
169
- writer = ix.writer()
170
- for doc in docs:
171
- writer.add_document(
172
- id=doc.metadata.get("id", ""),
173
- content=doc.page_content,
174
- )
175
- writer.commit()
176
- logger.info("Whoosh index built locally.")
177
-
178
- # Upload the ZIP back to GCS
179
- if is_gcs:
180
- logger.info("Zipping and uploading new whoosh_index.zip to GCS…")
181
- with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
182
- for root, _, files in os.walk(local_dir):
183
- for fname in files:
184
- full = os.path.join(root, fname)
185
- arc = os.path.relpath(full, local_dir)
186
- zf.write(full, arc)
187
- await run_in_threadpool(fs.put, local_zip, zip_uri)
188
- logger.info("Uploaded whoosh_index.zip to GCS.")
189
-
190
- # 2️⃣ Finally open the index and return it
191
- ix = index.open_dir(local_dir)
192
- return ix
193
-
194
- # === Document Loader ===
195
- async def load_documents(
196
- path: str,
197
- sample_size: Optional[int] = None
198
- ) -> List[Document]:
199
- """
200
- Load a Parquet file from local or GCS, convert to a list of Documents.
201
- """
202
- def _read_local(p: str, n: Optional[int]):
203
- # streaming scan keeps memory low
204
- lf = pl.scan_parquet(p)
205
- if n:
206
- lf = lf.limit(n)
207
- return lf.collect(streaming=True)
208
-
209
- def _read_gcs(p: str, n: Optional[int]):
210
- # download to a temp file synchronously, then read with Polars
211
- fs = gcsfs.GCSFileSystem()
212
- with tempfile.TemporaryDirectory() as td:
213
- local_path = os.path.join(td, "data.parquet")
214
- fs.get(p, local_path, recursive=False)
215
- df = pl.read_parquet(local_path)
216
- if n:
217
- df = df.head(n)
218
- return df
219
-
220
- try:
221
- if path.startswith("gs://"):
222
- df = await run_in_threadpool(_read_gcs, path, sample_size)
223
- else:
224
- df = await run_in_threadpool(_read_local, path, sample_size)
225
- except Exception as e:
226
- logger.error(f"Error loading documents: {e}")
227
- raise HTTPException(status_code=500, detail="Document loading failed.")
228
-
229
- docs: List[Document] = []
230
- for row in df.rows(named=True):
231
- context_parts: List[str] = []
232
- # build metadata context
233
- max_contrib = row.get("ecMaxContribution", "")
234
- end_date = row.get("endDate", "")
235
- duration = row.get("durationDays", "")
236
- status = row.get("status", "")
237
- legal = row.get("legalBasis", "")
238
- framework = row.get("frameworkProgramme", "")
239
- scheme = row.get("fundingScheme", "")
240
- names = row.get("list_name", []) or []
241
- cities = row.get("list_city", []) or []
242
- countries = row.get("list_country", []) or []
243
- activity = row.get("list_activityType", []) or []
244
- contributions = row.get("list_ecContribution", []) or []
245
- smes = row.get("list_sme", []) or []
246
- project_id =row.get("id", "")
247
- pred=row.get("predicted_label", "")
248
- proba=row.get("predicted_prob", "")
249
- top1_feats=row.get("top1_features", "")
250
- top2_feats=row.get("top2_features", "")
251
- top3_feats=row.get("top3_features", "")
252
- top1_shap=row.get("top1_shap", "")
253
- top2_shap=row.get("top2_shap", "")
254
- top3_shap=row.get("top3_shap", "")
255
-
256
-
257
- context_parts.append(
258
- f"This project under framework {framework} with funding scheme {scheme}, status {status}, legal basis {legal}."
259
- )
260
- context_parts.append(
261
- f"It ends on {end_date} after {duration} days and has a max EC contribution of {max_contrib}."
262
- )
263
- context_parts.append("Participating organizations:")
264
- for i, name in enumerate(names):
265
- city = cities[i] if i < len(cities) else ""
266
- country = countries[i] if i < len(countries) else ""
267
- act = activity[i] if i < len(activity) else ""
268
- contrib = contributions[i] if i < len(contributions) else ""
269
- sme_flag = "SME" if (smes and i < len(smes) and smes[i]) else "non-SME"
270
- context_parts.append(
271
- f"- {name} in {city}, {country}, activity: {act}, contributed: {contrib}, {sme_flag}."
272
- )
273
- if status in (None,"signed","SIGNED","Signed"):
274
- if int(pred) == 1:
275
- label = "TERMINATED"
276
- score = float(proba)
277
- else:
278
- label = "CLOSED"
279
- score = 1 - float(proba)
280
-
281
- score_str = f"{score:.2f}"
282
-
283
- context_parts.append(
284
- f"- Project {project_id} is predicted to be {label} (score={score_str}). "
285
- f"The 3 most predictive features were: "
286
- f"{top1_feats} ({top1_shap:.3f}), "
287
- f"{top2_feats} ({top2_shap:.3f}), "
288
- f"{top3_feats} ({top3_shap:.3f})."
289
- )
290
-
291
- title_report = row.get("list_title_report", "")
292
- objective = row.get("objective", "")
293
- full_body = f"{title_report} {objective}"
294
- full_text = " ".join(context_parts + [full_body])
295
- meta: Dict[str, Any] = {"id": str(row.get("id", "")),"startDate": str(row.get("startDate", "")),"endDate": str(row.get("endDate", "")),"status":str(row.get("status", "")),"legalBasis":str(row.get("legalBasis",""))}
296
- meta.update({"id": str(row.get("id", "")),"startDate": str(row.get("startDate", "")),"endDate": str(row.get("endDate", "")),"status":str(row.get("status", "")),"legalBasis":str(row.get("legalBasis",""))})
297
- docs.append(Document(page_content=full_text, metadata=meta))
298
- return docs
299
-
300
- # === BM25 Search ===
301
- async def bm25_search(ix: index.Index, query: str, k: int) -> List[Document]:
302
- parser = MultifieldParser(["content"], schema=ix.schema)
303
- def _search() -> List[Document]:
304
- with ix.searcher() as searcher:
305
- hits = searcher.search(parser.parse(query), limit=k)
306
- return [Document(page_content=h["content"], metadata={"id": h["id"]}) for h in hits]
307
- return await run_in_threadpool(_search)
308
-
309
- # === Helper: build or load FAISS with mmap ===
310
- """async def build_or_load_faiss(
311
- chunks: List[Document],
312
- vectorstore_path: str,
313
- batch_size: int = 15000
314
- ) -> FAISS:"""
315
- """
316
- Always uses GCS. Expects 'index.faiss' and 'index.pkl' under vectorstore_path.
317
- Reconstructs the FAISS store using your provided logic.
318
- """
319
- """assert vectorstore_path.startswith("gs://")
320
- fs = gcsfs.GCSFileSystem()
321
-
322
- base = vectorstore_path.rstrip("/")
323
- uri_index = f"{base}/index.faiss"
324
- uri_meta = f"{base}/index.pkl"
325
-
326
- local_index = "/tmp/index.faiss"
327
- local_meta = "/tmp/index.pkl"
328
-
329
- # 1) If existing index + metadata on GCS → load
330
- if fs.exists(uri_index) and fs.exists(uri_meta):
331
- logger.info("Found existing FAISS index on GCS; loading…")
332
- os.makedirs(os.path.dirname(local_index), exist_ok=True)
333
- await run_in_threadpool(
334
- fs.get_bulk,
335
- [uri_index, uri_meta],
336
- [local_index, local_meta]
337
- )
338
-
339
- # 3) Memory‐map load
340
- mmap_idx = await run_in_threadpool(
341
- faiss.read_index, local_index, faiss.IO_FLAG_MMAP
342
- )
343
-
344
- # Load metadata
345
- with open(local_meta, "rb") as f:
346
- saved = pickle.load(f)
347
-
348
- # extract metadata
349
- if isinstance(saved, tuple):
350
- # Handle tuple of length 2 or 3
351
- if len(saved) == 3:
352
- _, docstore, index_to_docstore = saved
353
- elif len(saved) == 2:
354
- docstore, index_to_docstore = saved
355
- else:
356
- raise ValueError(f"Unexpected metadata tuple length: {len(saved)}")
357
- else:
358
- # saved is an object with attributes
359
- if hasattr(saved, "docstore"):
360
- docstore = saved.docstore
361
- elif hasattr(saved, "_docstore"):
362
- docstore = saved._docstore
363
- else:
364
- raise AttributeError("Could not find docstore in FAISS metadata")
365
-
366
- if hasattr(saved, "index_to_docstore"):
367
- index_to_docstore = saved.index_to_docstore
368
- elif hasattr(saved, "_index_to_docstore"):
369
- index_to_docstore = saved._index_to_docstore
370
- elif hasattr(saved, "_faiss_index_to_docstore"):
371
- index_to_docstore = saved._faiss_index_to_docstore
372
- else:
373
- raise AttributeError("Could not find index_to_docstore in FAISS metadata")
374
-
375
- # reconstruct FAISS wrapper
376
- vs = FAISS(
377
- embedding_function=EMBEDDING, # your embedding function
378
- index=mmap_idx,
379
- docstore=docstore,
380
- index_to_docstore_id=index_to_docstore,
381
- )
382
- return vs
383
-
384
- # 2) Else: build from scratch in batches
385
- # parse bucket & prefix
386
- _, rest = vectorstore_path.split("://", 1)
387
- bucket, *path_parts = rest.split("/", 1)
388
- prefix = path_parts[0] if path_parts else ""
389
-
390
- # helper to upload entire local dir to GCS
391
- def upload_dir(local_dir: str):
392
- for root, _, files in os.walk(local_dir):
393
- for fname in files:
394
- local_path = os.path.join(root, fname)
395
- # construct the corresponding GCS path
396
- rel_path = os.path.relpath(local_path, local_dir)
397
- gcs_path = f"gs://{bucket}/{prefix}/{rel_path}"
398
- fs.makedirs(os.path.dirname(gcs_path), exist_ok=True)
399
- fs.put(local_path, gcs_path)
400
-
401
- # temporary local staging area
402
- local_store = "/tmp/faiss_store"
403
- os.makedirs(local_store, exist_ok=True)
404
-
405
- # 2) Else: build from scratch in batches
406
- logger.info(f"Building FAISS index in batches of {batch_size}…")
407
- vs: Optional[FAISS] = None
408
-
409
- for i in tqdm(range(0, len(chunks), batch_size),
410
- desc="Building FAISS index",
411
- unit="batch"):
412
- batch = chunks[i : i + batch_size]
413
-
414
- if vs is None:
415
- vs = FAISS.from_documents(batch, EMBEDDING)
416
- else:
417
- vs.add_documents(batch)
418
-
419
- # periodic save every 5 batches
420
- if (i // batch_size) % 5 == 0:
421
- # save into local_store
422
- vs.save_local(local_store)
423
- # push local_store → GCS
424
- upload_dir(local_store)
425
- logger.info(f" • Saved batch up to document {i + len(batch)} / {len(chunks)}")
426
-
427
- assert vs is not None, "No documents to index!"
428
-
429
- # final save at end
430
- vs.save_local(local_store)
431
- upload_dir(local_store)
432
- logger.info("Finished building index and uploaded to GCS.")
433
-
434
- return vs"""
435
-
436
- async def build_or_load_faiss(
437
- docs: List[Document],
438
- vectorstore_path: str,
439
- batch_size: int = 15000
440
- ) -> FAISS:
441
- """
442
- Expects a ZIP at vectorstore_path + ".zip" containing:
443
- - index.faiss
444
- - index.pkl
445
- Files may be nested under a subfolder (e.g. vectorstore_index_colab/).
446
- If the ZIP exists on GCS, download & load only.
447
- Otherwise, build from `docs`, save, re-zip, and upload.
448
- """
449
- fs = gcsfs.GCSFileSystem()
450
- is_gcs = vectorstore_path.startswith("gs://")
451
- zip_uri = vectorstore_path.rstrip("/") + ".zip"
452
-
453
- local_zip = "/tmp/faiss_index.zip"
454
- local_dir = "/tmp/faiss_store"
455
-
456
- # 1) if ZIP exists, download & extract
457
- if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
458
- logger.info("Found FAISS ZIP on GCS; loading only.")
459
- # clean slate
460
- if os.path.exists(local_dir):
461
- shutil.rmtree(local_dir)
462
- os.makedirs(local_dir, exist_ok=True)
463
-
464
- # download zip
465
- await run_in_threadpool(fs.get, zip_uri, local_zip)
466
-
467
- # extract
468
- def _extract():
469
- with zipfile.ZipFile(local_zip, "r") as zf:
470
- zf.extractall(local_dir)
471
- await run_in_threadpool(_extract)
472
-
473
- # locate the two files anywhere under local_dir
474
- idx_path = None
475
- meta_path = None
476
- for root, _, files in os.walk(local_dir):
477
- if "index.faiss" in files:
478
- idx_path = os.path.join(root, "index.faiss")
479
- if "index.pkl" in files:
480
- meta_path = os.path.join(root, "index.pkl")
481
- if not idx_path or not meta_path:
482
- raise FileNotFoundError("Couldn't find index.faiss or index.pkl in extracted ZIP.")
483
-
484
- # memory-map load
485
- mmap_index = await run_in_threadpool(
486
- faiss.read_index, idx_path, faiss.IO_FLAG_MMAP
487
- )
488
-
489
- # load metadata
490
- with open(meta_path, "rb") as f:
491
- saved = pickle.load(f)
492
-
493
- # unpack metadata
494
- if isinstance(saved, tuple):
495
- _, docstore, index_to_docstore = (
496
- saved if len(saved) == 3 else (None, *saved)
497
- )
498
- else:
499
- docstore = getattr(saved, "docstore", saved._docstore)
500
- index_to_docstore = getattr(
501
- saved,
502
- "index_to_docstore",
503
- getattr(saved, "_index_to_docstore", saved._faiss_index_to_docstore)
504
- )
505
-
506
- # reconstruct FAISS
507
- vs = FAISS(
508
- embedding_function=EMBEDDING,
509
- index=mmap_index,
510
- docstore=docstore,
511
- index_to_docstore_id=index_to_docstore,
512
- )
513
- logger.info("FAISS index loaded from ZIP.")
514
- return vs
515
-
516
- # 2) otherwise, build from scratch and upload
517
- logger.info("No FAISS ZIP found; building index from scratch.")
518
- if os.path.exists(local_dir):
519
- shutil.rmtree(local_dir)
520
- os.makedirs(local_dir, exist_ok=True)
521
-
522
- vs: FAISS = None
523
- for i in range(0, len(docs), batch_size):
524
- batch = docs[i : i + batch_size]
525
- if vs is None:
526
- vs = FAISS.from_documents(batch, EMBEDDING)
527
- else:
528
- vs.add_documents(batch)
529
- assert vs is not None, "No documents to index!"
530
-
531
- # save locally
532
- vs.save_local(local_dir)
533
-
534
- if is_gcs:
535
- # re-zip all contents of local_dir (flattened)
536
- def _zip_dir():
537
- with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
538
- for root, _, files in os.walk(local_dir):
539
- for fname in files:
540
- full = os.path.join(root, fname)
541
- arc = os.path.relpath(full, local_dir)
542
- zf.write(full, arc)
543
- await run_in_threadpool(_zip_dir)
544
- await run_in_threadpool(fs.put, local_zip, zip_uri)
545
- logger.info("Built FAISS index and uploaded ZIP to GCS.")
546
-
547
- return vs
548
-
549
-
550
- # === Index Builder ===
551
- async def build_indexes(
552
- parquet_path: str,
553
- vectorstore_path: str,
554
- whoosh_dir: str,
555
- chunk_size: int,
556
- chunk_overlap: int,
557
- debug_size: Optional[int]
558
- ) -> Tuple[FAISS, index.Index]:
559
- docs = await load_documents(parquet_path, debug_size)
560
- ix = await build_whoosh_index(docs, whoosh_dir)
561
-
562
- splitter = RecursiveCharacterTextSplitter(
563
- chunk_size=chunk_size, chunk_overlap=chunk_overlap
564
- )
565
- chunks = splitter.split_documents(docs)
566
-
567
- # build or load (with mmap) FAISS
568
- vs = await build_or_load_faiss(chunks, vectorstore_path)
569
-
570
- return vs, ix
571
-
572
- # === Hybrid Retriever ===
573
- class HybridRetriever(BaseRetriever):
574
- """Hybrid retriever combining BM25 and FAISS with cross-encoder re-ranking."""
575
- # store FAISS and Whoosh under private attributes to avoid Pydantic field errors
576
- _vs: FAISS = PrivateAttr()
577
- _ix: index.Index = PrivateAttr()
578
- _compressor: DocumentCompressorPipeline = PrivateAttr()
579
- _cross_encoder: CrossEncoder = PrivateAttr()
580
-
581
- def __init__(
582
- self,
583
- vs: FAISS,
584
- ix: index.Index,
585
- compressor: DocumentCompressorPipeline,
586
- cross_encoder: CrossEncoder
587
- ) -> None:
588
- super().__init__()
589
- object.__setattr__(self, '_vs', vs)
590
- object.__setattr__(self, '_ix', ix)
591
- object.__setattr__(self, '_compressor', compressor)
592
- object.__setattr__(self, '_cross_encoder', cross_encoder)
593
-
594
- async def _aget_relevant_documents(self, query: str) -> List[Document]:
595
- # BM25 retrieval using Whoosh index
596
- bm_docs = await bm25_search(self._ix, query, settings.hybrid_k)
597
- # Dense retrieval using FAISS
598
- dense_docs = self._vs.similarity_search_by_vector(
599
- embed_query_cached(query), k=settings.hybrid_k
600
- )
601
- # Cross-encoder re-ranking
602
- candidates = bm_docs + dense_docs
603
- scores = self._cross_encoder.predict([
604
- (query, doc.page_content) for doc in candidates
605
- ])
606
- ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
607
- top = [doc for _, doc in ranked[: settings.hybrid_k]]
608
- # Compress and return
609
- return self._compressor.compress_documents(top, query=query)
610
-
611
- def _get_relevant_documents(self, query: str) -> List[Document]:
612
- import asyncio
613
- return asyncio.get_event_loop().run_until_complete(
614
- self._aget_relevant_documents(query)
615
- )
616
-
617
  # ---------------------------------------------------------------------------- #
618
- # Lifespan #
619
  # ---------------------------------------------------------------------------- #
 
 
 
620
  @asynccontextmanager
621
  async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
622
- # --- 1) RAG Initialization --- #
623
- logger = logging.getLogger("uvicorn")
624
- logger.info("Initializing RAG components…")
625
-
626
- # Compressor pipeline to de‐duplicate via embeddings
627
  logger.info("Initializing Document Compressor")
628
  compressor = DocumentCompressorPipeline(
629
  transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
630
  )
631
 
632
- # Cross‐encoder ranker
633
  logger.info("Initializing Cross-Encoder")
634
  cross_encoder = CrossEncoder(settings.cross_encoder_model)
635
-
636
- # Apply dynamic quantization to speed up CPU inference
637
  cross_encoder.model = torch.quantization.quantize_dynamic(
638
  cross_encoder.model,
639
  {torch.nn.Linear},
@@ -641,34 +120,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
641
  )
642
  logger.info("Cross-Encoder quantized")
643
 
644
- # Seq2seq pipeline
645
- logger.info("Initializing Pipeline")
646
- #full_model=AutoModelForSeq2SeqLM.from_pretrained(settings.llm_model)
647
- #full_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)#, device_map="auto")
648
-
649
- # Apply dynamic quantization to all Linear layers
650
- #llm_model = torch.quantization.quantize_dynamic(
651
- # full_model,
652
- # {torch.nn.Linear},
653
- # dtype=torch.qint8
654
- #)
655
- # Create your text-generation pipeline on CPU
656
- #gen_pipe = pipeline(
657
- # "text-generation",#"text2text-generation",##"text2text-generation",
658
- # model=llm_model,
659
- # tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
660
- # device=-1, # CPU
661
- # max_new_tokens=256,
662
- # do_sample=True,
663
- # temperature=0.7,
664
- # #device_map="auto"
665
- #)
666
  tokenizer = T5Tokenizer.from_pretrained(settings.llm_model)
667
- model = T5ForConditionalGeneration.from_pretrained(settings.llm_model)
668
- model = torch.quantization.quantize_dynamic(
669
  model, {torch.nn.Linear}, dtype=torch.qint8
670
  )
671
-
672
  gen_pipe = pipeline(
673
  "text2text-generation",
674
  model=model,
@@ -678,20 +136,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
678
  do_sample=True,
679
  temperature=0.7,
680
  )
681
- # Wrap in LangChain's HuggingFacePipeline
682
  llm = HuggingFacePipeline(pipeline=gen_pipe)
683
 
684
- # Conversational memory
685
  logger.info("Initializing Conversation Memory")
686
  memory = ConversationBufferWindowMemory(
687
  memory_key="chat_history",
688
- k=5,
689
  input_key="question",
690
  output_key="answer",
691
  return_messages=True,
 
692
  )
693
- logger.info("Initializing Indexes")
694
- # Build or load FAISS & Whoosh once
 
695
  vs, ix = await build_indexes(
696
  settings.parquet_path,
697
  settings.vectorstore_path,
@@ -700,21 +158,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
700
  settings.chunk_overlap,
701
  None,
702
  )
703
- logger.info("Initializing Hybrid Retriever")
704
  retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
705
-
 
706
  prompt = PromptTemplate.from_template(
707
- f"{settings.assistant_role} \n\n"
708
  "{context}\n"
709
- "User Question:\n"
710
- "{question}\n"
711
- "Please answer thoroughly, following these rules:\n"
712
- "1. Write at least 4-6 full sentences.\n"
713
- "2. Use clear, technical language in full sentences.\n"
714
- "3. Cite any document you reference by including its ID in [brackets] inline.\n"
715
- "4. Conclude with high-level insights or recommendations.\n"
716
- "Answer:")
717
 
 
718
  logger.info("Initializing Retrieval Chain")
719
  app.state.rag_chain = ConversationalRetrievalChain.from_llm(
720
  llm=llm,
@@ -724,35 +178,33 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
724
  return_source_documents=True,
725
  )
726
 
 
727
  if not settings.skip_warmup:
728
- logger.info("Warming up RAG chain")
729
  await app.state.rag_chain.ainvoke({"question": "warmup"})
730
- logger.info("RAG ready.")
731
 
732
- # --- 2) Dataframe Initialization --- #
733
- logger.info("Loading Parquet data from GCS")
734
  fs = gcsfs.GCSFileSystem()
735
  with fs.open(settings.parquet_path, "rb") as f:
736
  df = pl.read_parquet(f)
737
-
738
  df = df.with_columns(
739
- pl.col("id").cast(pl.Int64).alias("id")
 
 
 
740
  )
741
 
742
- # lowercase for filtering
743
- for col in ("title", "status", "legalBasis","fundingScheme"):
744
- df = df.with_columns(pl.col(col).str.to_lowercase().alias(f"_{col}_lc"))
745
-
746
- # materialize unique filter values
747
  app.state.df = df
748
- app.state.statuses = df["_status_lc"].unique().to_list()
749
- app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
750
- app.state.orgs_list = df.explode("list_name")["list_name"].unique().to_list()
751
- app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()
752
  app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()
753
 
754
- yield
755
-
756
  # ---------------------------------------------------------------------------- #
757
  # App Setup #
758
  # ---------------------------------------------------------------------------- #
@@ -765,19 +217,29 @@ app.add_middleware(
765
  )
766
 
767
  # ---------------------------------------------------------------------------- #
768
- # RAG Endpoint #
769
  # ---------------------------------------------------------------------------- #
 
770
  class RAGRequest(BaseModel):
771
- session_id: Optional[str] = None
772
- query: str
773
 
774
  class RAGResponse(BaseModel):
775
  answer: str
776
  source_ids: List[str]
777
 
778
  def rag_chain_depender(app: FastAPI = Depends(lambda: app)) -> Any:
779
  chain = app.state.rag_chain
780
  if chain is None:
 
781
  raise HTTPException(status_code=500, detail="RAG chain not initialized")
782
  return chain
783
 
@@ -786,27 +248,47 @@ async def ask_rag(
786
  req: RAGRequest,
787
  rag_chain = Depends(rag_chain_depender)
788
  ):
789
- logger.info("Starting to answer")
790
  try:
 
791
  result = await rag_chain.ainvoke({"question": req.query})
792
- logger.info("Results retrieved")
 
 
793
  if not isinstance(result, dict):
 
794
  result2 = await rag_chain.acall({"question": req.query})
795
- raise ValueError(f"Expected dict from chain, got {type(result)} and acall(): {result2} with type {type(result2)}")
796
  answer = result.get("answer")
797
  docs = result.get("source_documents", [])
798
- sources = [d.metadata.get("id","") for d in docs]
 
799
  return RAGResponse(answer=answer, source_ids=sources)
800
 
801
  except Exception as e:
802
- # print full traceback to your container logs
803
  traceback.print_exc()
804
- # return a proper JSON 500
805
  raise HTTPException(status_code=500, detail=str(e))
806
 
 
807
  # ---------------------------------------------------------------------------- #
808
  # Data Endpoints #
809
  # ---------------------------------------------------------------------------- #
 
810
  @app.get("/api/projects")
811
  def get_projects(
812
  page: int = 0,
@@ -821,10 +303,25 @@ def get_projects(
821
  sortOrder: str = "desc",
822
  sortField: str = "startDate",
823
  ):
824
  df: pl.DataFrame = app.state.df
825
  start = page * limit
826
  sel = df
827
 
 
828
  if search:
829
  sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
830
  if status:
@@ -843,19 +340,22 @@ def get_projects(
843
  if proj_id:
844
  sel = sel.filter(pl.col("id") == int(proj_id))
845
 
 
846
  base_cols = [
847
  "id","title","status","startDate","endDate","ecMaxContribution","acronym",
848
  "legalBasis","objective","frameworkProgramme","list_euroSciVocTitle",
849
  "list_euroSciVocPath","totalCost","list_isPublishedAs","fundingScheme"
850
  ]
851
- # add shap/explanation columns
852
  for i in range(1,7):
853
  base_cols += [f"top{i}_feature", f"top{i}_shap"]
854
  base_cols += ["predicted_label","predicted_prob"]
855
 
856
- sort_desc = True if sortOrder=="desc" else False
 
857
  sortField = sortField if sortField in df.columns else "startDate"
858
 
 
859
  rows = (
860
  sel.sort(sortField, descending=sort_desc)
861
  .slice(start, limit)
@@ -865,6 +365,7 @@ def get_projects(
865
 
866
  projects = []
867
  for row in rows:
 
868
  explanations = []
869
  for i in range(1,7):
870
  feat = row.pop(f"top{i}_feature", None)
@@ -873,7 +374,7 @@ def get_projects(
873
  explanations.append({"feature": feat, "shap": shap})
874
  row["explanations"] = explanations
875
 
876
- # publications aggregation
877
  raw_pubs = row.pop("list_publications", []) or []
878
  pub_counts: Dict[str,int] = {}
879
  for p in raw_pubs:
@@ -886,14 +387,20 @@ def get_projects(
886
 
887
  @app.get("/api/filters")
888
  def get_filters(request: Request):
889
  df = app.state.df
890
  params = request.query_params
891
 
 
892
  if s := params.get("status"):
893
- df = df.filter(pl.col("status").is_null() if s=="UNKNOWN"
894
- else pl.col("_status_lc")==s.lower())
895
  if lb := params.get("legalBasis"):
896
- df = df.filter(pl.col("_legalBasis_lc")==lb.lower())
897
  if org := params.get("organization"):
898
  df = df.filter(pl.col("list_name").list.contains(org))
899
  if c := params.get("country"):
@@ -902,6 +409,7 @@ def get_filters(request: Request):
902
  df = df.filter(pl.col("_title_lc").str.contains(search.lower()))
903
 
904
  def normalize(vals):
 
905
  return sorted({("UNKNOWN" if v is None else v) for v in vals})
906
 
907
  return {
@@ -910,31 +418,37 @@ def get_filters(request: Request):
910
  "organizations": normalize(df["list_name"].explode().to_list())[:500],
911
  "countries": normalize(df["list_country"].explode().to_list()),
912
  "fundingSchemes": normalize(df["fundingScheme"].explode().to_list()),
913
- #"ids": normalize(df["id"].to_list()),
914
  }
915
 
916
  @app.get("/api/stats")
917
  def get_stats(request: Request):
918
  lf = app.state.df.lazy()
919
  params = request.query_params
920
 
 
921
  if s := params.get("status"):
922
- lf = lf.filter(pl.col("_status_lc")==s.lower())
923
  if lb := params.get("legalBasis"):
924
- lf = lf.filter(pl.col("_legalBasis_lc")==lb.lower())
925
  if org := params.get("organization"):
926
  lf = lf.filter(pl.col("list_name").list.contains(org))
927
  if c := params.get("country"):
928
  lf = lf.filter(pl.col("list_country").list.contains(c))
929
  if mn := params.get("minFunding"):
930
- lf = lf.filter(pl.col("ecMaxContribution")>=int(mn))
931
  if mx := params.get("maxFunding"):
932
- lf = lf.filter(pl.col("ecMaxContribution")<=int(mx))
933
  if y1 := params.get("minYear"):
934
- lf = lf.filter(pl.col("startDate").dt.year()>=int(y1))
935
  if y2 := params.get("maxYear"):
936
- lf = lf.filter(pl.col("startDate").dt.year()<=int(y2))
937
 
 
938
  grouped = (
939
  lf.select(pl.col("startDate").dt.year().alias("year"))
940
  .group_by("year")
@@ -944,6 +458,7 @@ def get_stats(request: Request):
944
  )
945
  years, counts = grouped["year"].to_list(), grouped["count"].to_list()
946
 
 
947
  return {
948
  "Projects per Year": {"labels": years, "values": counts},
949
  "Projects per Year 2": {"labels": years, "values": counts},
@@ -955,11 +470,17 @@ def get_stats(request: Request):
955
 
956
  @app.get("/api/project/{project_id}/organizations")
957
  def get_project_organizations(project_id: str):
958
  df = app.state.df
959
- sel = df.filter(pl.col("id")==int(project_id))
960
  if sel.is_empty():
961
  raise HTTPException(status_code=404, detail="Project not found")
962
 
 
963
  orgs_df = (
964
  sel.select([
965
  pl.col("list_name").explode().alias("name"),
@@ -973,9 +494,11 @@ def get_project_organizations(project_id: str):
973
  pl.col("list_geolocation").explode().alias("geoloc"),
974
  ])
975
  .with_columns([
 
976
  pl.col("geoloc").str.split(",").alias("latlon"),
977
  ])
978
  .with_columns([
 
979
  pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
980
  pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
981
  ])
@@ -985,5 +508,6 @@ def get_project_organizations(project_id: str):
985
  "activityType","orgURL","country","latitude","longitude"
986
  ])
987
  )
988
- logger.info(f"{orgs_df.to_dicts()}")
989
  return orgs_df.to_dicts()
 
 
1
+ import logging
2
+ import os
3
+ import shutil
4
+ import tempfile
5
  import traceback
 
 
 
6
  from contextlib import asynccontextmanager
7
+ from functools import lru_cache
8
+ from typing import Any, AsyncGenerator, Dict, List, Optional
9
 
 
 
10
  import aiofiles
11
+ import faiss
12
+ import gcsfs
13
  import polars as pl
14
+ import torch
15
  import zipfile
16
+ from fastapi import Depends, FastAPI, HTTPException, Request
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel, PrivateAttr
19
+ from pydantic_settings import BaseSettings as SettingsBase
20
+ from sentence_transformers import CrossEncoder
21
+ from starlette.concurrency import run_in_threadpool
22
+ from tqdm import tqdm
23
+ from transformers import ( # Transformers for LLM pipeline
24
+ AutoModelForCausalLM,
25
+ AutoModelForSeq2SeqLM,
26
+ AutoTokenizer,
27
+ pipeline,
28
+ T5ForConditionalGeneration,
29
+ T5Tokenizer,
30
+ )
31
 
32
+ # LangChain imports for RAG
33
  from langchain.chains import ConversationalRetrievalChain
34
+ from langchain.memory import ConversationBufferWindowMemory
35
  from langchain.prompts import PromptTemplate
36
+ from langchain.schema import BaseRetriever, Document
37
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
38
+ from langchain_community.document_transformers import EmbeddingsRedundantFilter
39
+ from langchain_community.vectorstores import FAISS
40
+ from langchain.retrievers.document_compressors import DocumentCompressorPipeline
41
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
42
 
43
+ # Project-specific imports
44
+ from app.rag import build_indexes, HybridRetriever
45
 
46
  # ---------------------------------------------------------------------------- #
47
  # Settings #
48
  # ---------------------------------------------------------------------------- #
 
 
 
49
 
50
+ class Settings(SettingsBase):
51
+ """
52
+ Configuration settings loaded from environment or .env file.
53
+ """
54
+ # Data sources
55
  parquet_path: str = "gs://mda_eu_project/data/consolidated_clean_pred.parquet"
56
+ whoosh_dir: str = "gs://mda_eu_project/whoosh_index"
57
  vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
58
+
59
+ # Model names
60
+ embedding_model: str = "sentence-transformers/LaBSE"
61
+ llm_model: str = "google/flan-t5-base"
62
  cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
63
+
64
  # RAG parameters
65
+ chunk_size: int = 750
66
  chunk_overlap: int = 100
67
+ hybrid_k: int = 2
68
  assistant_role: str = (
69
+ "You are a knowledgeable project analyst. You have access to the following retrieved document snippets."
70
  )
71
  skip_warmup: bool = True
72
+
73
+ # CORS
74
  allowed_origins: List[str] = ["*"]
75
 
76
  class Config:
77
  env_file = ".env"
78
 
79
+ # Instantiate settings and logger
80
  settings = Settings()
81
+ logging.basicConfig(level=logging.INFO)
82
+ logger = logging.getLogger(__name__)
83
 
84
+ # Pre-instantiate the embedding model for reuse
85
+ EMBEDDING = HuggingFaceEmbeddings(
86
+ model_name=settings.embedding_model,
87
+ model_kwargs={"trust_remote_code": True},
88
+ )
89
 
90
  @lru_cache(maxsize=256)
91
  def embed_query_cached(query: str) -> List[float]:
92
+ """Cache embedding vectors for repeated queries."""
93
  return EMBEDDING.embed_query(query.strip().lower())
94
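A quick illustration (toy embedder, not the repo's model) of why embed_query_cached normalizes before caching: lru_cache keys on the exact string, so unnormalized variants of the same query would each trigger a real embedding call.

# Toy stand-in for EMBEDDING.embed_query, to show the cache behavior only.
from functools import lru_cache

@lru_cache(maxsize=256)
def toy_embed(normalized: str) -> tuple:
    print("embedding:", normalized)  # fires once per distinct normalized query
    return tuple(ord(c) % 7 for c in normalized)

for q in ("  Nginx ", "nginx", "NGINX"):
    toy_embed(q.strip().lower())  # one real call, then two cache hits
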
95
  # ---------------------------------------------------------------------------- #
96
+ # Application Lifespan #
97
  # ---------------------------------------------------------------------------- #
98
+
99
+ app = FastAPI(lifespan=lambda app: lifespan(app))
100
+
101
  @asynccontextmanager
102
  async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
103
+ """
104
+ Startup: initialize RAG chain, embeddings, memory, indexes, and load data.
105
+ Shutdown: clean up resources if needed.
106
+ """
107
+ # 1) Initialize document compressor
108
  logger.info("Initializing Document Compressor")
109
  compressor = DocumentCompressorPipeline(
110
  transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
111
  )
112
 
113
+ # 2) Initialize and quantize Cross-Encoder
114
  logger.info("Initializing Cross-Encoder")
115
  cross_encoder = CrossEncoder(settings.cross_encoder_model)
 
 
116
  cross_encoder.model = torch.quantization.quantize_dynamic(
117
  cross_encoder.model,
118
  {torch.nn.Linear},
 
120
  )
121
  logger.info("Cross-Encoder quantized")
122
 
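For reference, a self-contained sketch of what torch.quantization.quantize_dynamic does to the cross-encoder above (toy module, not the repo's model): Linear weights are stored as int8 and dequantized on the fly, trading a little accuracy for CPU speed.

# Minimal dynamic-quantization demo on a toy module (assumes CPU inference).
import torch

net = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
qnet = torch.quantization.quantize_dynamic(net, {torch.nn.Linear}, dtype=torch.qint8)
print(qnet[0])  # DynamicQuantizedLinear(in_features=16, out_features=16, ...)
x = torch.randn(1, 16)
print((net(x) - qnet(x)).abs().max())  # small quantization error, not zero
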
123
+ # 3) Build Seq2Seq pipeline and wrap in LangChain
124
+ logger.info("Initializing LLM pipeline")
125
  tokenizer = T5Tokenizer.from_pretrained(settings.llm_model)
126
+ model = T5ForConditionalGeneration.from_pretrained(settings.llm_model)
127
+ model = torch.quantization.quantize_dynamic(
128
  model, {torch.nn.Linear}, dtype=torch.qint8
129
  )
 
130
  gen_pipe = pipeline(
131
  "text2text-generation",
132
  model=model,
 
136
  do_sample=True,
137
  temperature=0.7,
138
  )
 
139
  llm = HuggingFacePipeline(pipeline=gen_pipe)
140
 
141
+ # 4) Initialize conversation memory
142
  logger.info("Initializing Conversation Memory")
143
  memory = ConversationBufferWindowMemory(
144
  memory_key="chat_history",
 
145
  input_key="question",
146
  output_key="answer",
147
  return_messages=True,
148
+ k=5,
149
  )
150
+
151
+ # 5) Build or load indexes for vectorstore and Whoosh
152
+ logger.info("Building or loading indexes")
153
  vs, ix = await build_indexes(
154
  settings.parquet_path,
155
  settings.vectorstore_path,
 
158
  settings.chunk_overlap,
159
  None,
160
  )
 
161
  retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
162
+
163
+ # 6) Define prompt template for RAG chain
164
  prompt = PromptTemplate.from_template(
165
+ f"{settings.assistant_role}\n"
166
  "{context}\n"
167
+ "User Question:\n{question}\n"
168
+ "Answer:" # Rules are embedded in assistant_role
169
+ )
170
 
171
+ # 7) Instantiate the conversational retrieval chain
172
  logger.info("Initializing Retrieval Chain")
173
  app.state.rag_chain = ConversationalRetrievalChain.from_llm(
174
  llm=llm,
 
178
  return_source_documents=True,
179
  )
180
 
181
+ # Optional warmup
182
  if not settings.skip_warmup:
183
+ logger.info("Warming up RAG chain")
184
  await app.state.rag_chain.ainvoke({"question": "warmup"})
 
185
 
186
+ # 8) Load project data into Polars DataFrame
187
+ logger.info("Loading Parquet data from GCS")
188
  fs = gcsfs.GCSFileSystem()
189
  with fs.open(settings.parquet_path, "rb") as f:
190
  df = pl.read_parquet(f)
191
+ # Cast id to integer and lowercase key columns for filtering
192
  df = df.with_columns(
193
+ pl.col("id").cast(pl.Int64),
194
+ *(pl.col(col).str.to_lowercase().alias(f"_{col}_lc") for col in [
195
+ "title", "status", "legalBasis", "fundingScheme"
196
+ ])
197
  )
198
 
199
+ # Cache DataFrame and filter values in app state
200
  app.state.df = df
201
+ app.state.statuses = df["_status_lc"].unique().to_list()
202
+ app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
203
+ app.state.orgs_list = df.explode("list_name")["list_name"].unique().to_list()
 
204
  app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()
205
 
206
+ yield # Application is ready
207
+
208
  # ---------------------------------------------------------------------------- #
209
  # App Setup #
210
  # ---------------------------------------------------------------------------- #
 
217
  )
218
 
219
  # ---------------------------------------------------------------------------- #
220
+ # Pydantic Models #
221
  # ---------------------------------------------------------------------------- #
222
+
223
  class RAGRequest(BaseModel):
224
+ session_id: Optional[str] = None # Optional conversation ID
225
+ query: str # User's query text
226
 
227
  class RAGResponse(BaseModel):
228
  answer: str
229
  source_ids: List[str]
230
+
231
+ # ---------------------------------------------------------------------------- #
232
+ # RAG Endpoint #
233
+ # ---------------------------------------------------------------------------- #
234
 
235
  def rag_chain_depender(app: FastAPI = Depends(lambda: app)) -> Any:
236
+ """
237
+ Dependency injector to retrieve the initialized RAG chain from the application state.
238
+ Raises HTTPException if chain is not yet initialized.
239
+ """
240
  chain = app.state.rag_chain
241
  if chain is None:
242
+ # If the chain isn't set up, respond with a 500 server error
243
  raise HTTPException(status_code=500, detail="RAG chain not initialized")
244
  return chain
245
 
 
248
  req: RAGRequest,
249
  rag_chain = Depends(rag_chain_depender)
250
  ):
251
+ """
252
+ Endpoint to process a RAG-based query.
253
+
254
+ 1. Logs start of processing.
255
+ 2. Invokes the RAG chain asynchronously with the user question.
256
+ 3. Validates returned result structure and extracts answer + source IDs.
257
+ 4. Handles any exceptions by logging traceback and returning a JSON error.
258
+ """
259
+ logger.info("Starting to answer RAG query")
260
  try:
261
+ # Asynchronously invoke the chain to get answer + docs
262
  result = await rag_chain.ainvoke({"question": req.query})
263
+ logger.info("RAG results retrieved")
264
+
265
+ # Validate that the chain returned expected dict
266
  if not isinstance(result, dict):
267
+ # Try sync call for debugging
268
  result2 = await rag_chain.acall({"question": req.query})
269
+ raise ValueError(
270
+ f"Expected dict from chain, got {type(result)}; "
271
+ f"acall() returned {type(result2)}"
272
+ )
273
+
274
+ # Extract answer text and source document IDs
275
  answer = result.get("answer")
276
  docs = result.get("source_documents", [])
277
+ sources = [d.metadata.get("id", "") for d in docs]
278
+
279
  return RAGResponse(answer=answer, source_ids=sources)
280
 
281
  except Exception as e:
282
+ # Log full stacktrace to container logs
283
  traceback.print_exc()
284
+ # Return HTTP 500 with error detail
285
  raise HTTPException(status_code=500, detail=str(e))
286
 
287
+
288
  # ---------------------------------------------------------------------------- #
289
  # Data Endpoints #
290
  # ---------------------------------------------------------------------------- #
291
+
292
  @app.get("/api/projects")
293
  def get_projects(
294
  page: int = 0,
 
303
  sortOrder: str = "desc",
304
  sortField: str = "startDate",
305
  ):
306
+ """
307
+ Paginated project listing with optional filtering and sorting.
308
+
309
+ Query Parameters:
310
+ - page: zero-based page index
311
+ - limit: number of items per page
312
+ - search: substring search in project title
313
+ - status, legalBasis, organization, country, fundingScheme: filters
314
+ - proj_id: exact project ID filter
315
+ - sortOrder: 'asc' or 'desc'
316
+ - sortField: field name to sort by (fallback to startDate)
317
+
318
+ Returns a list of project dicts including explanations and publication counts.
319
+ """
320
  df: pl.DataFrame = app.state.df
321
  start = page * limit
322
  sel = df
323
 
324
+ # Apply text and field filters as needed
325
  if search:
326
  sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
327
  if status:
 
340
  if proj_id:
341
  sel = sel.filter(pl.col("id") == int(proj_id))
342
 
343
+ # Base columns to return
344
  base_cols = [
345
  "id","title","status","startDate","endDate","ecMaxContribution","acronym",
346
  "legalBasis","objective","frameworkProgramme","list_euroSciVocTitle",
347
  "list_euroSciVocPath","totalCost","list_isPublishedAs","fundingScheme"
348
  ]
349
+ # Append top feature & SHAP value columns
350
  for i in range(1,7):
351
  base_cols += [f"top{i}_feature", f"top{i}_shap"]
352
  base_cols += ["predicted_label","predicted_prob"]
353
 
354
+ # Determine sort direction and safe field
355
+ sort_desc = sortOrder.lower() == "desc"
356
  sortField = sortField if sortField in df.columns else "startDate"
357
 
358
+ # Query, sort, slice, and collect to Python dicts
359
  rows = (
360
  sel.sort(sortField, descending=sort_desc)
361
  .slice(start, limit)
 
365
 
366
  projects = []
367
  for row in rows:
368
+ # Reformat SHAP explanations into list of dicts
369
  explanations = []
370
  for i in range(1,7):
371
  feat = row.pop(f"top{i}_feature", None)
 
374
  explanations.append({"feature": feat, "shap": shap})
375
  row["explanations"] = explanations
376
 
377
+ # Aggregate publications counts
378
  raw_pubs = row.pop("list_publications", []) or []
379
  pub_counts: Dict[str,int] = {}
380
  for p in raw_pubs:
 
387
 
388
  @app.get("/api/filters")
389
  def get_filters(request: Request):
390
+ """
391
+ Retrieve available filter options based on current dataset and optional query filters.
392
+
393
+ Returns JSON with lists for statuses, legalBases, organizations, countries, and fundingSchemes.
394
+ """
395
  df = app.state.df
396
  params = request.query_params
397
 
398
+ # Dynamically filter df based on provided params
399
  if s := params.get("status"):
400
+ df = df.filter(pl.col("status").is_null() if s == "UNKNOWN"
401
+ else pl.col("_status_lc") == s.lower())
402
  if lb := params.get("legalBasis"):
403
+ df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
404
  if org := params.get("organization"):
405
  df = df.filter(pl.col("list_name").list.contains(org))
406
  if c := params.get("country"):
 
409
  df = df.filter(pl.col("_title_lc").str.contains(search.lower()))
410
 
411
  def normalize(vals):
412
+ # Map None to "UNKNOWN" and return sorted unique list
413
  return sorted({("UNKNOWN" if v is None else v) for v in vals})
414
 
415
  return {
 
418
  "organizations": normalize(df["list_name"].explode().to_list())[:500],
419
  "countries": normalize(df["list_country"].explode().to_list()),
420
  "fundingSchemes": normalize(df["fundingScheme"].explode().to_list()),
 
421
  }
 
  @app.get("/api/stats")
  def get_stats(request: Request):
+     """
+     Compute annual statistics on projects with optional filters for status, legal basis, etc.
+
+     Returns a dict of chart data for projects per year.
+     """
      lf = app.state.df.lazy()
      params = request.query_params
 
+     # Apply lazy filters
      if s := params.get("status"):
+         lf = lf.filter(pl.col("_status_lc") == s.lower())
      if lb := params.get("legalBasis"):
+         lf = lf.filter(pl.col("_legalBasis_lc") == lb.lower())
      if org := params.get("organization"):
          lf = lf.filter(pl.col("list_name").list.contains(org))
      if c := params.get("country"):
          lf = lf.filter(pl.col("list_country").list.contains(c))
      if mn := params.get("minFunding"):
+         lf = lf.filter(pl.col("ecMaxContribution") >= int(mn))
      if mx := params.get("maxFunding"):
+         lf = lf.filter(pl.col("ecMaxContribution") <= int(mx))
      if y1 := params.get("minYear"):
+         lf = lf.filter(pl.col("startDate").dt.year() >= int(y1))
      if y2 := params.get("maxYear"):
+         lf = lf.filter(pl.col("startDate").dt.year() <= int(y2))
 
+     # Group by year and count
      grouped = (
          lf.select(pl.col("startDate").dt.year().alias("year"))
          .group_by("year")
 
      )
      years, counts = grouped["year"].to_list(), grouped["count"].to_list()
 
+     # Return data ready for frontend charts
      return {
          "Projects per Year": {"labels": years, "values": counts},
          "Projects per Year 2": {"labels": years, "values": counts},
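The lazy filter-then-aggregate pattern used by get_stats defers all work until collect(), so Polars can apply the filters before grouping. A self-contained sketch of the same idea on toy data (the column names mirror the app's schema, the values are invented, and pl.len() assumes a recent Polars):

import datetime as dt
import polars as pl

df = pl.DataFrame({
    "startDate": [dt.date(2020, 1, 1), dt.date(2020, 6, 1), dt.date(2021, 3, 1)],
    "ecMaxContribution": [100_000, 250_000, 50_000],
})

per_year = (
    df.lazy()
    .filter(pl.col("ecMaxContribution") >= 75_000)    # applied before the aggregation
    .select(pl.col("startDate").dt.year().alias("year"))
    .group_by("year")
    .agg(pl.len().alias("count"))                     # one count per year
    .sort("year")
    .collect()
)
print(per_year["year"].to_list(), per_year["count"].to_list())  # [2020] [2]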
 
  @app.get("/api/project/{project_id}/organizations")
  def get_project_organizations(project_id: str):
+     """
+     Retrieve organization details for a given project ID, including geolocation.
+
+     Raises 404 if the project ID does not exist.
+     """
      df = app.state.df
+     sel = df.filter(pl.col("id") == int(project_id))
      if sel.is_empty():
          raise HTTPException(status_code=404, detail="Project not found")
 
+     # Explode list columns and parse latitude/longitude
      orgs_df = (
          sel.select([
              pl.col("list_name").explode().alias("name"),
 
              pl.col("list_geolocation").explode().alias("geoloc"),
          ])
          .with_columns([
+             # Split "lat,lon" string into list
              pl.col("geoloc").str.split(",").alias("latlon"),
          ])
          .with_columns([
+             # Cast to floats for numeric use
              pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
              pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
          ])
 
              "activityType","orgURL","country","latitude","longitude"
          ])
      )
+     logger.info(f"Organization data for project {project_id}: {orgs_df.to_dicts()}")
      return orgs_df.to_dicts()
+
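Exercising the organizations endpoint above from a client is equally short; the host, port, and project ID below are placeholders:

import requests

BASE = "http://localhost:7860"      # hypothetical local deployment
project_id = 101069499              # placeholder project ID

resp = requests.get(f"{BASE}/api/project/{project_id}/organizations")
if resp.status_code == 404:
    print("Project not found")
else:
    for org in resp.json():
        print(org["name"], org.get("latitude"), org.get("longitude"))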
backend/rag.py CHANGED
@@ -33,6 +33,48 @@ from whoosh.qparser import MultifieldParser
  from tqdm import tqdm
  import faiss
 
+ from functools import lru_cache
+ from fastapi import FastAPI, Request, HTTPException, Depends
+ from fastapi.middleware.cors import CORSMiddleware
+ import traceback
+ from starlette.concurrency import run_in_threadpool
+ from pydantic import BaseModel
+ from pydantic_settings import BaseSettings
+ from contextlib import asynccontextmanager
+ from typing import Any, Dict, List, Optional, AsyncGenerator, Tuple
+
+ import os
+ import logging
+ import aiofiles
+ import polars as pl
+ import zipfile
+ import gcsfs
+
+ from langchain.schema import Document,BaseRetriever
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain.retrievers.document_compressors import DocumentCompressorPipeline
+ from langchain_community.document_transformers import EmbeddingsRedundantFilter
+ from langchain.memory import ConversationBufferWindowMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.prompts import PromptTemplate
+ from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
+
+ from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5Tokenizer,T5ForConditionalGeneration
+ from sentence_transformers import CrossEncoder
+
+ from whoosh import index
+ from whoosh.fields import Schema, TEXT, ID
+ from whoosh.analysis import StemmingAnalyzer
+ from whoosh.qparser import MultifieldParser
+ import pickle
+ from pydantic import PrivateAttr
+ from tqdm import tqdm
+ import faiss
+ import torch
+ import tempfile
+ import shutil
+
  from functools import lru_cache
 
  # === Logging ===
@@ -46,16 +88,16 @@ class Settings(BaseSettings):
      vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
      # Models
      embedding_model: str = "sentence-transformers/LaBSE"
-     llm_model: str = "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16"
+     llm_model: str = "google/flan-t5-base"
      cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
      # RAG parameters
      chunk_size: int = 750
      chunk_overlap: int = 100
-     hybrid_k: int = 50
+     hybrid_k: int = 2
      assistant_role: str = (
-         "You are a concise, factual assistant. Cite Document [ID] for each claim."
+         "You are a knowledgeable project analyst. You have access to the following retrieved document snippets."
      )
-     skip_warmup: bool = False
+     skip_warmup: bool = True
      allowed_origins: List[str] = ["*"]
 
      class Config:
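Because Settings extends pydantic's BaseSettings, every default above can be overridden from the environment without code changes. A minimal sketch of that behavior under pydantic-settings defaults (env vars match field names case-insensitively); the override values are hypothetical:

import os
os.environ["LLM_MODEL"] = "google/flan-t5-large"   # hypothetical override
os.environ["HYBRID_K"] = "4"

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    llm_model: str = "google/flan-t5-base"
    hybrid_k: int = 2

settings = Settings()
print(settings.llm_model, settings.hybrid_k)       # google/flan-t5-large 4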
@@ -66,50 +108,81 @@ settings = Settings()
  # === Global Embeddings & Cache ===
  EMBEDDING = HuggingFaceEmbeddings(model_name=settings.embedding_model)
 
- @lru_cache(maxsize=256)
+ @lru_cache(maxsize=128)
  def embed_query_cached(query: str) -> List[float]:
      """Cache embedding vectors for queries."""
      return EMBEDDING.embed_query(query.strip().lower())
 
  # === Whoosh Cache & Builder ===
- _WHOOSH_CACHE: Dict[str, index.Index] = {}
-
  async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
-     key = whoosh_dir
+     """
+     If gs://.../whoosh_index.zip exists, download & extract it once.
+     Otherwise build locally from docs and upload the ZIP back to GCS.
+     """
      fs = gcsfs.GCSFileSystem()
-     local_dir = key
-     is_gcs = key.startswith("gs://")
-     try:
-         # stage local copy for GCS
-         if is_gcs:
-             local_dir = "/tmp/whoosh_index"
-             if not os.path.exists(local_dir):
-                 if await run_in_threadpool(fs.exists, key):
-                     await run_in_threadpool(fs.get, key, local_dir, recursive=True)
-                 else:
-                     os.makedirs(local_dir, exist_ok=True)
-         # build once
-         if key not in _WHOOSH_CACHE:
-             os.makedirs(local_dir, exist_ok=True)
-             schema = Schema(
-                 id=ID(stored=True, unique=True),
-                 content=TEXT(analyzer=StemmingAnalyzer()),
+     is_gcs = whoosh_dir.startswith("gs://")
+     zip_uri = whoosh_dir.rstrip("/") + ".zip"
+
+     local_zip = "/tmp/whoosh_index.zip"
+     local_dir = "/tmp/whoosh_index"
+
+     # Clean slate
+     if os.path.exists(local_dir):
+         shutil.rmtree(local_dir)
+     os.makedirs(local_dir, exist_ok=True)
+
+     # 1️⃣ Try downloading the ZIP if it exists on GCS
+     if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
+         logger.info("Found whoosh_index.zip on GCS; downloading…")
+         await run_in_threadpool(fs.get, zip_uri, local_zip)
+         # Extract all files (flat) into local_dir
+         with zipfile.ZipFile(local_zip, "r") as zf:
+             for member in zf.infolist():
+                 if member.is_dir():
+                     continue
+                 filename = os.path.basename(member.filename)
+                 if not filename:
+                     continue
+                 target = os.path.join(local_dir, filename)
+                 os.makedirs(os.path.dirname(target), exist_ok=True)
+                 with zf.open(member) as src, open(target, "wb") as dst:
+                     dst.write(src.read())
+         logger.info("Whoosh index extracted from ZIP.")
+     else:
+         logger.info("No whoosh_index.zip found; building index from docs.")
+
+         # Define the schema with stored content
+         schema = Schema(
+             id=ID(stored=True, unique=True),
+             content=TEXT(stored=True, analyzer=StemmingAnalyzer()),
+         )
+
+         # Create the index
+         ix = index.create_in(local_dir, schema)
+         writer = ix.writer()
+         for doc in docs:
+             writer.add_document(
+                 id=doc.metadata.get("id", ""),
+                 content=doc.page_content,
              )
-         ix = index.create_in(local_dir, schema)
-         with ix.writer() as writer:
-             for doc in docs:
-                 writer.add_document(
-                     id=doc.metadata.get("id", ""),
-                     content=doc.page_content,
-                 )
-         # push back to GCS atomically
-         if is_gcs:
-             await run_in_threadpool(fs.put, local_dir, key, recursive=True)
-         _WHOOSH_CACHE[key] = ix
-         return _WHOOSH_CACHE[key]
-     except Exception as e:
-         logger.error(f"Failed to build Whoosh index: {e}")
-         raise
+         writer.commit()
+         logger.info("Whoosh index built locally.")
+
+         # Upload the ZIP back to GCS
+         if is_gcs:
+             logger.info("Zipping and uploading new whoosh_index.zip to GCS…")
+             with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
+                 for root, _, files in os.walk(local_dir):
+                     for fname in files:
+                         full = os.path.join(root, fname)
+                         arc = os.path.relpath(full, local_dir)
+                         zf.write(full, arc)
+             await run_in_threadpool(fs.put, local_zip, zip_uri)
+             logger.info("Uploaded whoosh_index.zip to GCS.")
+
+     # 2️⃣ Finally open the index and return it
+     ix = index.open_dir(local_dir)
+     return ix
 
  # === Document Loader ===
  async def load_documents(
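For context, querying an index produced by build_whoosh_index takes only a few lines with Whoosh; this is an illustrative sketch (the repo's own BM25 helper, bm25_search, appears further down) against the schema defined above:

from whoosh import index
from whoosh.qparser import MultifieldParser

ix = index.open_dir("/tmp/whoosh_index")            # local path used by build_whoosh_index
with ix.searcher() as searcher:
    parser = MultifieldParser(["content"], schema=ix.schema)
    query = parser.parse("marine biodiversity funding")   # made-up query
    for hit in searcher.search(query, limit=5):
        print(hit["id"], hit.score)                 # id is a stored field in the schema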
@@ -117,7 +190,8 @@ async def load_documents(
      sample_size: Optional[int] = None
  ) -> List[Document]:
      """
-     Load a Parquet file from local or GCS, convert to a list of Documents.
+     Load project data from a Parquet file (local path or GCS URI),
+     assemble metadata context for each row, and return as Document objects.
      """
      def _read_local(p: str, n: Optional[int]):
          # streaming scan keeps memory low
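A minimal sketch of the streaming-scan idea this loader relies on; the text and id columns here are placeholders, not the app's real schema:

import polars as pl
from langchain.schema import Document

def parquet_to_documents(path: str, sample_size: int | None = None) -> list[Document]:
    lf = pl.scan_parquet(path)            # lazy scan: nothing is read yet
    if sample_size:
        lf = lf.limit(sample_size)
    return [
        Document(
            page_content=str(r.get("objective", "")),   # placeholder text column
            metadata={"id": str(r.get("id", ""))},
        )
        for r in lf.collect().to_dicts()
    ]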
@@ -228,81 +302,119 @@ async def bm25_search(ix: index.Index, query: str, k: int) -> List[Document]:
 
  # === Helper: build or load FAISS with mmap ===
  async def build_or_load_faiss(
-     chunks: List[Document],
+     docs: List[Document],
      vectorstore_path: str,
      batch_size: int = 15000
  ) -> FAISS:
-     faiss_index_file = os.path.join(vectorstore_path, "index.faiss")
-     # If on-disk exists: memory-map the FAISS index and load metadata separately
-     if os.path.exists(faiss_index_file):
-         logger.info("Memory-mapping existing FAISS index...")
-         mmap_idx = faiss.read_index(faiss_index_file, faiss.IO_FLAG_MMAP)
-         # Manually load metadata (docstore and index_to_docstore) without loading the index
-         import pickle
-         for meta_file in ["faiss.pkl", "index.pkl"]:
-             meta_path = os.path.join(vectorstore_path, meta_file)
-             if os.path.exists(meta_path):
-                 with open(meta_path, "rb") as f:
-                     saved = pickle.load(f)
-                 break
-         else:
-             raise FileNotFoundError(
-                 f"Could not find FAISS metadata pickle in {vectorstore_path}"
-             )
-         # extract metadata
+     """
+     Expects a ZIP at vectorstore_path + ".zip" containing:
+       - index.faiss
+       - index.pkl
+     Files may be nested under a subfolder (e.g. vectorstore_index_colab/).
+     If the ZIP exists on GCS, download & load only.
+     Otherwise, build from `docs`, save, re-zip, and upload.
+     """
+     fs = gcsfs.GCSFileSystem()
+     is_gcs = vectorstore_path.startswith("gs://")
+     zip_uri = vectorstore_path.rstrip("/") + ".zip"
+
+     local_zip = "/tmp/faiss_index.zip"
+     local_dir = "/tmp/faiss_store"
+
+     # 1) if ZIP exists, download & extract
+     if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
+         logger.info("Found FAISS ZIP on GCS; loading only.")
+         # clean slate
+         if os.path.exists(local_dir):
+             shutil.rmtree(local_dir)
+         os.makedirs(local_dir, exist_ok=True)
+
+         # download zip
+         await run_in_threadpool(fs.get, zip_uri, local_zip)
+
+         # extract
+         def _extract():
+             with zipfile.ZipFile(local_zip, "r") as zf:
+                 zf.extractall(local_dir)
+         await run_in_threadpool(_extract)
+
+         # locate the two files anywhere under local_dir
+         idx_path = None
+         meta_path = None
+         for root, _, files in os.walk(local_dir):
+             if "index.faiss" in files:
+                 idx_path = os.path.join(root, "index.faiss")
+             if "index.pkl" in files:
+                 meta_path = os.path.join(root, "index.pkl")
+         if not idx_path or not meta_path:
+             raise FileNotFoundError("Couldn't find index.faiss or index.pkl in extracted ZIP.")
+
+         # memory-map load
+         mmap_index = await run_in_threadpool(
+             faiss.read_index, idx_path, faiss.IO_FLAG_MMAP
+         )
+
+         # load metadata
+         with open(meta_path, "rb") as f:
+             saved = pickle.load(f)
+
+         # unpack metadata
          if isinstance(saved, tuple):
-             # Handle metadata tuple of length 2 or 3
-             if len(saved) == 3:
-                 _, docstore, index_to_docstore = saved
-             elif len(saved) == 2:
-                 docstore, index_to_docstore = saved
-             else:
-                 raise ValueError(f"Unexpected metadata tuple length: {len(saved)}")
+             _, docstore, index_to_docstore = (
+                 saved if len(saved) == 3 else (None, *saved)
+             )
          else:
-             if hasattr(saved, 'docstore'):
-                 docstore = saved.docstore
-             elif hasattr(saved, '_docstore'):
-                 docstore = saved._docstore
-             else:
-                 raise AttributeError("Could not find docstore in FAISS metadata")
-             if hasattr(saved, 'index_to_docstore'):
-                 index_to_docstore = saved.index_to_docstore
-             elif hasattr(saved, '_index_to_docstore'):
-                 index_to_docstore = saved._index_to_docstore
-             elif hasattr(saved, '_faiss_index_to_docstore'):
-                 index_to_docstore = saved._faiss_index_to_docstore
-             else:
-                 raise AttributeError("Could not find index_to_docstore in FAISS metadata")
-         # reconstruct FAISS wrapper
+             docstore = getattr(saved, "docstore", saved._docstore)
+             index_to_docstore = getattr(
+                 saved,
+                 "index_to_docstore",
+                 getattr(saved, "_index_to_docstore", saved._faiss_index_to_docstore)
+             )
+
+         # reconstruct FAISS
          vs = FAISS(
              embedding_function=EMBEDDING,
-             index=mmap_idx,
+             index=mmap_index,
              docstore=docstore,
             index_to_docstore_id=index_to_docstore,
          )
+         logger.info("FAISS index loaded from ZIP.")
          return vs
 
-     # 2) Else: build from scratch in batches
-     logger.info(f"Building FAISS index in batches of {batch_size}…")
-     vs: Optional[FAISS] = None
-     for i in tqdm(range(0, len(chunks), batch_size),
-                   desc="Building FAISS index",
-                   unit="batch"):
-         batch = chunks[i : i + batch_size]
+     # 2) otherwise, build from scratch and upload
+     logger.info("No FAISS ZIP found; building index from scratch.")
+     if os.path.exists(local_dir):
+         shutil.rmtree(local_dir)
+     os.makedirs(local_dir, exist_ok=True)
+
+     vs: FAISS = None
+     for i in range(0, len(docs), batch_size):
+         batch = docs[i : i + batch_size]
          if vs is None:
             vs = FAISS.from_documents(batch, EMBEDDING)
         else:
             vs.add_documents(batch)
+     assert vs is not None, "No documents to index!"
 
-         # periodic save every 5 batches
-         if (i // batch_size) % 5 == 0:
-             vs.save_local(vectorstore_path)
-
-         logger.info(f" • Saved batch up to document {i + len(batch)} / {len(chunks)}")
-     assert vs is not None, "No documents to index!"
+     # save locally
+     vs.save_local(local_dir)
+
+     if is_gcs:
+         # re-zip all contents of local_dir (flattened)
+         def _zip_dir():
+             with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
+                 for root, _, files in os.walk(local_dir):
+                     for fname in files:
+                         full = os.path.join(root, fname)
+                         arc = os.path.relpath(full, local_dir)
+                         zf.write(full, arc)
+         await run_in_threadpool(_zip_dir)
+         await run_in_threadpool(fs.put, local_zip, zip_uri)
+         logger.info("Built FAISS index and uploaded ZIP to GCS.")
 
      return vs
 
+
  # === Index Builder ===
  async def build_indexes(
      parquet_path: str,
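Whichever branch runs, the returned wrapper behaves like any LangChain FAISS store; a short usage sketch (the query text is invented):

# vs = await build_or_load_faiss(docs, settings.vectorstore_path)
hits = vs.similarity_search("offshore wind energy pilots", k=3)
for doc in hits:
    print(doc.metadata.get("id"), doc.page_content[:80])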
@@ -312,6 +424,9 @@ async def build_indexes(
      chunk_overlap: int,
      debug_size: Optional[int]
  ) -> Tuple[FAISS, index.Index]:
+     """
+     Load documents, build/load Whoosh and FAISS indices, and return both.
+     """
      docs = await load_documents(parquet_path, debug_size)
      ix = await build_whoosh_index(docs, whoosh_dir)
 
@@ -324,3 +439,48 @@ async def build_indexes(
      vs = await build_or_load_faiss(chunks, vectorstore_path)
 
      return vs, ix
+
+ # === Hybrid Retriever ===
+ class HybridRetriever(BaseRetriever):
+     """Hybrid retriever combining BM25 and FAISS with cross-encoder re-ranking."""
+     # store FAISS and Whoosh under private attributes to avoid Pydantic field errors
+     _vs: FAISS = PrivateAttr()
+     _ix: index.Index = PrivateAttr()
+     _compressor: DocumentCompressorPipeline = PrivateAttr()
+     _cross_encoder: CrossEncoder = PrivateAttr()
+
+     def __init__(
+         self,
+         vs: FAISS,
+         ix: index.Index,
+         compressor: DocumentCompressorPipeline,
+         cross_encoder: CrossEncoder
+     ) -> None:
+         super().__init__()
+         object.__setattr__(self, '_vs', vs)
+         object.__setattr__(self, '_ix', ix)
+         object.__setattr__(self, '_compressor', compressor)
+         object.__setattr__(self, '_cross_encoder', cross_encoder)
+
+     async def _aget_relevant_documents(self, query: str) -> List[Document]:
+         # BM25 retrieval using Whoosh index
+         bm_docs = await bm25_search(self._ix, query, settings.hybrid_k)
+         # Dense retrieval using FAISS
+         dense_docs = self._vs.similarity_search_by_vector(
+             embed_query_cached(query), k=settings.hybrid_k
+         )
+         # Cross-encoder re-ranking
+         candidates = bm_docs + dense_docs
+         scores = self._cross_encoder.predict([
+             (query, doc.page_content) for doc in candidates
+         ])
+         ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
+         top = [doc for _, doc in ranked[: settings.hybrid_k]]
+         # Compress and return
+         return self._compressor.compress_documents(top, query=query)
+
+     def _get_relevant_documents(self, query: str) -> List[Document]:
+         import asyncio
+         return asyncio.get_event_loop().run_until_complete(
+             self._aget_relevant_documents(query)
+         )
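A sketch of how the pieces above could be wired together; the compressor and cross-encoder construction follow this file's imports, while vs and ix are assumed to come from build_indexes during app startup:

from sentence_transformers import CrossEncoder
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_transformers import EmbeddingsRedundantFilter

# vs, ix = await build_indexes(...)          # assumed to run during the app's lifespan
compressor = DocumentCompressorPipeline(
    transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
)
cross_encoder = CrossEncoder(settings.cross_encoder_model)

retriever = HybridRetriever(vs, ix, compressor, cross_encoder)
docs = retriever.get_relevant_documents("hydrogen storage pilots")   # sync entry point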
backend/requirements.txt CHANGED
@@ -16,12 +16,8 @@ google-cloud-storage==2.11.0
  # RAG & embeddings
  pytorch-lightning==2.5.1
  langchain
- #==0.3.25
  langchain-huggingface
- #==0.2.0
  sentence-transformers
- #==4.1.0
- # 2.7.0
  langchain-community==0.3.24
 
  # Vector store
@@ -37,15 +33,12 @@ tiktoken>=0.4.0
  sentencepiece==0.2.0
  transformers
  torchvision
- #accelerate==0.32.0
  sympy>=1.13.1
  peft
- #==0.11.1
 
  aiofiles==24.1.0
  optimum==1.25.3
  bitsandbytes==0.45.5
- #gptqmodel==2.2.0
  hf_xet
  HuggingFace
  huggingface_hub
frontend/Dockerfile DELETED
@@ -1,10 +0,0 @@
- FROM node:18-alpine AS builder
- WORKDIR /app
- COPY package.json package-lock.json ./
- RUN npm ci
- COPY . .
- RUN npm run build
- FROM nginx:alpine
- COPY --from=builder /app/dist /usr/share/nginx/html
- EXPOSE 8000
- CMD ["nginx", "-g", "daemon off;"]
frontend/src/hooks/useAppState.ts CHANGED
@@ -1,6 +1,12 @@
  import { useState, useEffect, useRef } from "react";
  import { debounce } from "lodash";
- import type { Project, OrganizationLocation, FilterState, ChatMessage,AvailableFilters } from "./types";
+ import type {
+   Project,
+   OrganizationLocation,
+   FilterState,
+   ChatMessage,
+   AvailableFilters,
+ } from "./types";
 
  interface Stats {
    [key: string]: {
@@ -12,20 +18,25 @@ interface Stats {
  type SortOrder = "asc" | "desc";
 
  export const useAppState = () => {
+   // Projects state and pagination
    const [projects, setProjects] = useState<Project[]>([]);
+   const [page, setPage] = useState<number>(0);
+
+   // Search and filter states
    const [search, setSearch] = useState<string>("");
    const [statusFilter, setStatusFilter] = useState<string>("");
-   const [page, setPage] = useState<number>(0);
-   const [question, setQuestion] = useState<string>("");
-   const [selectedProject, setSelectedProject] = useState<Project | null>(null);
-   const [stats, setStats] = useState<Stats>({});
-   const [legalFilter, setLegalFilter] = useState('');
-   const [orgFilter, setOrgFilter] = useState('');
-   const [countryFilter, setCountryFilter] = useState('');
-   const [fundingSchemeFilter, setFundingSchemeFilter ] = useState('');
-   const [idFilter, setIdFilter] = useState('');
-   const [sortField, setSortField] = useState('');
+   const [legalFilter, setLegalFilter] = useState<string>("");
+   const [orgFilter, setOrgFilter] = useState<string>("");
+   const [countryFilter, setCountryFilter] = useState<string>("");
+   const [fundingSchemeFilter, setFundingSchemeFilter] = useState<string>("");
+   const [idFilter, setIdFilter] = useState<string>("");
+
+   // Sorting
+   const [sortField, setSortField] = useState<string>("");
    const [sortOrder, setSortOrder] = useState<SortOrder>("asc");
+
+   // Dashboard stats and available filters
+   const [stats, setStats] = useState<Stats>({});
    const [filters, setFilters] = useState<FilterState>({
      status: "",
      organization: "",
@@ -35,146 +46,135 @@ export const useAppState = () => {
      maxYear: "2025",
      minFunding: "0",
      maxFunding: "10000000",
-   });
-
-   const [chatHistory, setChatHistory] = useState<ChatMessage[]>([]);
-   const [loading, setLoading] = useState<boolean>(false);
-
+   });
    const [availableFilters, setAvailableFilters] = useState<AvailableFilters>({
-     statuses: ["SIGNED", "CLOSED", "TERMINATED","UNKNOWN"],
+     statuses: ["SIGNED", "CLOSED", "TERMINATED", "UNKNOWN"],
      organizations: [],
      countries: [],
      legalBases: [],
-     fundingSchemes:[],
-     ids:[]
+     fundingSchemes: [],
+     ids: [],
    });
 
+   // Chatbot states
+   const [question, setQuestion] = useState<string>("");
+   const [chatHistory, setChatHistory] = useState<ChatMessage[]>([]);
+   const [loading, setLoading] = useState<boolean>(false);
    const messagesEndRef = useRef<HTMLDivElement | null>(null);
 
+   // Fetch projects with current filters, pagination, sorting
    const fetchProjects = () => {
-     fetch(`/api/projects?page=${page}&search=${encodeURIComponent(search)}&status=${statusFilter}&legalBasis=${legalFilter}&organization=${orgFilter}&country=${countryFilter}&fundingScheme=${fundingSchemeFilter}&proj_id=${idFilter}&sortField=${sortField}&sortOrder=${sortOrder}`)
-       .then(res => res.json())
+     const query = new URLSearchParams({
+       page: page.toString(),
+       search,
+       status: statusFilter,
+       legalBasis: legalFilter,
+       organization: orgFilter,
+       country: countryFilter,
+       fundingScheme: fundingSchemeFilter,
+       proj_id: idFilter,
+       sortField,
+       sortOrder,
+     }).toString();
+
+     fetch(`/api/projects?${query}`)
+       .then((res) => res.json())
       .then((data: Project[]) => setProjects(data))
-       .catch(console.error);
+       .catch((err) => console.error("Error fetching projects:", err));
    };
 
-   const fetchStats = debounce((filters: FilterState) => {
-     const params = new URLSearchParams(filters);
-     fetch(`/api/stats?${params.toString()}`)
-       .then(res => res.json())
+   // Fetch stats with debouncing to limit requests
+   const fetchStats = debounce((filterParams: FilterState) => {
+     const query = new URLSearchParams(filterParams as any).toString();
+     fetch(`/api/stats?${query}`)
+       .then((res) => res.json())
        .then((data: Stats) => setStats(data))
-       .catch(console.error);
+       .catch((err) => console.error("Error fetching stats:", err));
    }, 500);
 
-   const fetchAvailableFilters = (filters: FilterState) => {
-     const params = new URLSearchParams(filters);
-     fetch(`/api/filters?${params.toString()}`)
-       .then(res => res.json())
-       .then((data: Omit<AvailableFilters, 'statuses'>) => {
+   // Fetch available filter options based on dataset and active filters
+   const fetchAvailableFilters = (filterParams: FilterState) => {
+     const query = new URLSearchParams(filterParams as any).toString();
+     fetch(`/api/filters?${query}`)
+       .then((res) => res.json())
+       .then((data) => {
          setAvailableFilters({
            statuses: ["SIGNED", "CLOSED", "TERMINATED", "UNKNOWN"],
            organizations: data.organizations,
            countries: data.countries,
            legalBases: data.legalBases,
            fundingSchemes: data.fundingSchemes,
-           ids: []
+           ids: [],
          });
-       });
+       })
+       .catch((err) => console.error("Error fetching filters:", err));
    };
+
    interface RagResponse {
      answer: string;
      source_ids: string[];
    }
 
+   // Handle chat submission
    const askChatbot = async () => {
-     if (!question.trim() || loading) return;
-     const newChat: ChatMessage[] = [
-       ...chatHistory,
-       { role: "user", content: question },
-     ];
-     setChatHistory(newChat);
+     if (!question.trim() || loading) return;
+
+     // Append user message
+     setChatHistory((prev) => [...prev, { role: "user", content: question }]);
      setQuestion("");
-     setLoading(true);
+     setLoading(true);
 
-     // 1) placeholder
-     setChatHistory((h) => [
-       ...h,
-       { role: "assistant", content: "Generating answer..." },
-     ]);
+     // Add placeholder while generating
+     setChatHistory((prev) => [...prev, { role: "assistant", content: "Generating answer..." }]);
 
      try {
-       const res = await fetch("/api/rag", {
+       const response = await fetch("/api/rag", {
          method: "POST",
          headers: { "Content-Type": "application/json" },
          body: JSON.stringify({ query: question }),
        });
 
-       const text = await res.text();
-       if (!res.ok) {
-         let errDetail = text;
-         try {
-           errDetail = JSON.parse(text).detail;
-         } catch {}
-         throw new Error(errDetail);
+       const text = await response.text();
+       if (!response.ok) {
+         let detail = text;
+         try { detail = JSON.parse(text).detail; } catch {}
+         throw new Error(detail);
        }
 
-       const data: RagResponse = JSON.parse(text);
-       const idList = data.source_ids.join(", ") || "none";
-       const assistantContent = `${data.answer}
-
- The output was based on the following Project IDs: ${idList}`;
+       const result: RagResponse = JSON.parse(text);
+       const sources = result.source_ids.length ? result.source_ids.join(", ") : "none";
+       const assistantContent = `${result.answer}\n\nSources: ${sources}`;
 
-       // 2) replace placeholder with real answer
-       setChatHistory((h) => [
-         ...h.slice(0, -1),
-         { role: "assistant", content: assistantContent },
-       ]);
+       // Replace placeholder with actual answer
+       setChatHistory((prev) => [...prev.slice(0, -1), { role: "assistant", content: assistantContent }]);
     } catch (err: any) {
-       // replace placeholder with error message
-       setChatHistory((h) => [
-         ...h.slice(0, -1),
-         {
-           role: "assistant",
-           content: `Something went wrong: ${err.message}`,
-         },
-       ]);
+       // Replace placeholder with error message
+       setChatHistory((prev) => [
+         ...prev.slice(0, -1),
+         { role: "assistant", content: `Error: ${err.message}` },
+       ]);
     } finally {
       setLoading(false);
-       // scroll to bottom
       messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
     }
    };
 
+   // Effects: refetch when filters/sorting/pagination change
    useEffect(() => {
-     // If the user has typed something but it's too short, don't refetch
      fetchProjects();
-   }, [
-     page,
-     search,
-     statusFilter,
-     legalFilter,
-     orgFilter,
-     countryFilter,
-     fundingSchemeFilter,
-     idFilter,
-     sortField,
-     sortOrder,
-   ]);
+   }, [page, search, statusFilter, legalFilter, orgFilter, countryFilter, fundingSchemeFilter, idFilter, sortField, sortOrder]);
 
    useEffect(() => {
-     console.log("Updated filters:", filters);
      fetchStats(filters);
    }, [filters]);
-   useEffect(() => fetchAvailableFilters(filters), [filters]);
+
+   useEffect(() => {
+     fetchAvailableFilters(filters);
+   }, [filters]);
 
    return {
-     selectedProject,
-     dashboardProps: {
-       stats,
-       filters,
-       setFilters,
-       availableFilters
-     },
+     selectedProject: projects[0] || null,
+     dashboardProps: { stats, filters, setFilters, availableFilters },
     explorerProps: {
       projects,
       search,
@@ -191,29 +191,28 @@ The output was based on the following Project IDs: ${idList}`;
      setFundingSchemeFilter,
      idFilter,
      setIdFilter,
-     setSortField,
      sortField,
+     setSortField,
-     setSortOrder,
      sortOrder,
+     setSortOrder,
      page,
      setPage,
-     setSelectedProject,
+     setSelectedProject: () => {},
      question,
      setQuestion,
      chatHistory,
-     setChatHistory,
      askChatbot,
      loading,
-     messagesEndRef
+     messagesEndRef,
    },
    detailsProps: {
-     project: selectedProject!,
+     project: projects[0]!,
      question,
      setQuestion,
      chatHistory,
      askChatbot,
      loading,
-     messagesEndRef
-   }
+     messagesEndRef,
+   },
  };
};
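The hook serializes its state into the /api/projects query string via URLSearchParams; the same contract can be exercised from Python when debugging the backend (host and port are placeholders):

import requests

params = {
    "page": 0, "search": "energy", "status": "SIGNED",
    "legalBasis": "", "organization": "", "country": "",
    "fundingScheme": "", "proj_id": "",
    "sortField": "startDate", "sortOrder": "desc",
}
resp = requests.get("http://localhost:7860/api/projects", params=params)
print(len(resp.json()), "projects returned")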
run.sh CHANGED
@@ -1,18 +1,4 @@
  #!/bin/bash
- # Start nginx directly
- #service nginx start &
- #NGINX_PID=$!
- #nginx -g "daemon off;" &
- # Wait briefly to ensure nginx is up
- #sleep 1
- # Serve static build
- #python -m http.server --directory ./static --bind 0.0.0.0 8000 & echo $! > http_server.pid
- # Start FastAPI
- #uvicorn "app.main:app" --host 0.0.0.0 --port 7860
- # Cleanup static server on shutdown
- #pkill -F http_server.pid
- #rm http_server.pid
- # Start nginx in foreground
  echo "HF_SDK = $HF_SPACE_SDK, APP_PORT = $APP_PORT, PORT = $PORT"
  echo "$GCP_SA_JSON" > /tmp/sa.json
  chmod 600 /tmp/sa.json