Merge remote-tracking branch 'origin/reranker'
Files changed:
- .txt +0 -0
- requirements.txt +2 -0
- test.ipynb +0 -0
- web2json/__pycache__/ai_extractor.cpython-311.pyc +0 -0
- web2json/__pycache__/pipeline.cpython-311.pyc +0 -0
- web2json/__pycache__/postprocessor.cpython-311.pyc +0 -0
- web2json/__pycache__/preprocessor.cpython-311.pyc +0 -0
- web2json/ai_extractor.py +120 -60
- web2json/pipeline.py +4 -4
.txt
ADDED
File without changes
requirements.txt
CHANGED
@@ -13,4 +13,6 @@ langchain-text-splitters
 sentence-transformers
 openai
 html_chunking
+langchain_nvidia_ai_endpoints
+langchain_core
 lxml
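The two added packages back the new reranker path in web2json/ai_extractor.py: langchain_nvidia_ai_endpoints provides the hosted NVIDIA reranker client, and langchain_core provides the Document wrapper it consumes. A minimal smoke test of the pair, assuming NVIDIA_API_KEY is set (the model name mirrors the default used later in this commit):

    from langchain_nvidia_ai_endpoints import NVIDIARerank
    from langchain_core.documents import Document

    reranker = NVIDIARerank(model="nvidia/llama-3.2-nv-rerankqa-1b-v2")
    ranked = reranker.compress_documents(
        query="product price",
        documents=[Document(page_content="<span>$19.99</span>"),
                   Document(page_content="<footer>Contact us</footer>")],
    )
    print([doc.page_content for doc in ranked])  # most relevant passage first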
test.ipynb
CHANGED
The diff for this file is too large to render.
web2json/__pycache__/ai_extractor.cpython-311.pyc
CHANGED
Binary files a/web2json/__pycache__/ai_extractor.cpython-311.pyc and b/web2json/__pycache__/ai_extractor.cpython-311.pyc differ
web2json/__pycache__/pipeline.cpython-311.pyc
CHANGED
Binary files a/web2json/__pycache__/pipeline.cpython-311.pyc and b/web2json/__pycache__/pipeline.cpython-311.pyc differ
web2json/__pycache__/postprocessor.cpython-311.pyc
CHANGED
Binary files a/web2json/__pycache__/postprocessor.cpython-311.pyc and b/web2json/__pycache__/postprocessor.cpython-311.pyc differ
web2json/__pycache__/preprocessor.cpython-311.pyc
CHANGED
Binary files a/web2json/__pycache__/preprocessor.cpython-311.pyc and b/web2json/__pycache__/preprocessor.cpython-311.pyc differ
web2json/ai_extractor.py
CHANGED
@@ -11,12 +11,16 @@ from google.genai import types
 from pydantic import BaseModel
 from concurrent.futures import ThreadPoolExecutor
 from html_chunking import get_html_chunks
+from langchain_nvidia_ai_endpoints import NVIDIARerank
+from langchain_core.documents import Document
 from abc import ABC, abstractmethod
 from typing import List, Any, Dict, Tuple, Optional
 import re
 import json
 from langchain_text_splitters import HTMLHeaderTextSplitter
 from sentence_transformers import SentenceTransformer
+import requests
+
 class LLMClient(ABC):
     """
     Abstract base class for calling LLM APIs.
@@ -227,6 +231,93 @@ class NvidiaLLMClient(LLMClient):
             # You could set results[idx] = None or a default string
             results[idx] = f"<failed after retries>"
         return results
+
+
+class NvidiaRerankerClient(LLMClient):
+    """
+    Concrete implementation of LLMClient that wraps the NVIDIA reranking API (non-streaming).
+    """
+
+    def __init__(self, config: dict):
+        self.model_name = config.get("model_name", "nvidia/llama-3.2-nv-rerankqa-1b-v2")
+        self.client = NVIDIARerank(
+            model=self.model_name,
+            api_key=os.getenv("NVIDIA_API_KEY"),
+        )
+
+    def set_model(self, model_name: str):
+        """
+        Set the model name for the NVIDIA API client.
+
+        Args:
+            model_name (str): The name of the model to use.
+        """
+        self.model_name = model_name
+
+    @retry_on_ratelimit(max_retries=6, base_delay=0.5, max_delay=5.0)
+    def call_api(self, prompt: str) -> str:
+        pass  # not used; this client only reranks
+
+    def call_batch(self, prompts, max_workers=8):
+        pass  # not used; this client only reranks
+
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+
+
+class HFRerankerClient(LLMClient):
+    """
+    Local Hugging Face reranker client (default model: Qwen/Qwen3-Reranker-0.6B).
+    """
+
+    def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-0.6B", device: str = None):
+        """
+        Initialize the Hugging Face reranker.
+        """
+        self.model_name = model_name
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
+        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+
+    def rerank(self, query: str, passages: List[str], top_k: int = 3) -> List[str]:
+        """
+        Rerank passages based on relevance to query.
+
+        Args:
+            query (str): Query string.
+            passages (List[str]): List of passages.
+            top_k (int): Number of top passages to return.
+
+        Returns:
+            List[str]: Top-k most relevant passages.
+        """
+        inputs = [self.tokenizer(f"{query} [SEP] {p}", return_tensors="pt", truncation=True, padding=True).to(self.device) for p in passages]
+        scores = []
+
+        with torch.no_grad():
+            for inp in inputs:
+                logits = self.model(**inp).logits
+                score = torch.softmax(logits, dim=1)[0, 1].item()  # probability of relevance
+                scores.append(score)
+
+        print(f"Scores for passages: {scores}")
+
+        top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
+        print(f"top indices: {top_indices}")
+        return [passages[i] for i in top_indices]
+
+    @retry_on_ratelimit(max_retries=6, base_delay=0.5, max_delay=5.0)
+    def call_api(self, prompt: str) -> str:
+        pass  # not used; this client only reranks
+
+    def call_batch(self, prompts, max_workers=8):
+        pass  # not used; this client only reranks
 
 
 class AIExtractor:
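A usage sketch of the new HFRerankerClient (model weights download on first use). One caveat worth flagging: the class loads Qwen/Qwen3-Reranker-0.6B via AutoModelForSequenceClassification and softmaxes a two-logit output, but Qwen ships this reranker as a causal LM scored through its "yes"/"no" token logits; the unused token_true_id/token_false_id fields suggest that scoring style was the intent, so the sequence-classification load may require a checkpoint with an actual classification head:

    reranker = HFRerankerClient()  # defaults to Qwen/Qwen3-Reranker-0.6B, CPU or CUDA
    top = reranker.rerank(
        query="Which chunk contains the product price?",
        passages=["<div>About us</div>", "<span>$19.99</span>", "<nav>Home</nav>"],
        top_k=1,
    )
    print(top)  # expected: ['<span>$19.99</span>']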
@@ -264,7 +355,7 @@ class LLMClassifierExtractor(AIExtractor):
     Extractor that uses an LLM to classify and extract structured information from text content.
     This class is designed to handle classification tasks where the LLM generates structured output based on a provided schema.
     """
-    def __init__(self, llm_client: LLMClient, prompt_template: str, classifier_prompt: str, ):
+    def __init__(self, reranker: LLMClient, llm_client: LLMClient, prompt_template: str, classifier_prompt: str):
         """
         Initializes the LLMClassifierExtractor with an LLM client and a prompt template.
@@ -273,6 +364,7 @@ class LLMClassifierExtractor(AIExtractor):
         prompt_template (str): The template to use for generating prompts for the LLM.
         """
         super().__init__(llm_client, prompt_template)
+        self.reranker = reranker
         self.classifier_prompt = classifier_prompt
 
     def chunk_content(self, content: str, max_tokens: int = 500, is_clean: bool = True) -> List[str]:
@@ -288,79 +380,47 @@ class LLMClassifierExtractor(AIExtractor):
         # Use the get_html_chunks function to split the content into chunks
         return get_html_chunks(html=content, max_tokens=max_tokens, is_clean_html=is_clean, attr_cutoff_len=5)
 
-    def classify_chunks(self, chunks: List[str], schema: BaseModel) -> List[Dict[str, Any]]:
-        """
-        Classifies each chunk using the LLM based on the provided schema.
-
-        Args:
-            chunks (List[str]): The content chunks to classify.
-            schema (BaseModel): A Pydantic model defining the structure of the expected output.
-
-        Returns:
-            List[Dict[str, Any]]: One classification result per chunk.
-        """
-        prompts = [self.classifier_prompt.format(content=chunk, schema=schema.model_json_schema()) for chunk in chunks]
-        classified_chunks = []
-        responses = self.llm_client.call_batch(prompts)
-        for response in responses:
-            # extract the json from the response
-            json_data = extract_markdown_json(response)
-            if json_data:
-                classified_chunks.append(json_data)
-            else:
-                classified_chunks.append({
-                    "error": "Failed to extract JSON from response",
-                    "relevant": 1,
-                })
-        return classified_chunks
+    def classify_chunks(self, passages, top_k=3, hf: bool = False):  # rerank instead of LLM-classify
+        query = self.classifier_prompt
+
+        if hf:
+            print("Using Hugging Face reranker for classification.")
+            return self.reranker.rerank(query, passages, top_k=top_k)
+
+        # NVIDIA reranker path
+        responses = self.reranker.client.compress_documents(
+            query=query,
+            documents=[Document(page_content=passage) for passage in passages]
+        )
+        return [response.page_content for response in responses[:top_k]]
 
-    def extract(self, content: str, schema: BaseModel) -> str:
+    def extract(self, content, schema, hf: bool = False):
         """
         Extracts structured information from the given content based on the provided schema.
 
         Args:
             content (str): The raw content to extract information from.
             schema (BaseModel): A Pydantic model defining the structure of the expected output.
-
-        Returns:
-            str: The structured JSON object as a string.
+            hf (bool): Whether to use the Hugging Face reranker or the NVIDIA endpoint (default).
         """
-        chunks = self.chunk_content(content)
-        print(f"Content successfully chunked into {len(chunks)} chunks.")
-        classified_chunks = self.classify_chunks(chunks, schema)
-        print(f"Classified {classified_chunks} chunks.")
-        positive_chunks = []
-        for i, chunk in enumerate(classified_chunks):
-            if chunk.get("relevant", 0) > 0:
-                positive_chunks.append(chunks[i])
-        if len(positive_chunks) == 0:
-            positive_chunks = chunks
-        filtered_content = "\n\n".join(positive_chunks)
-        print(f"Filtered content for extraction: {filtered_content}")  # Log the first 500 characters of filtered content
+        chunks = self.chunk_content(content, max_tokens=1500)
+        print(f"Content successfully chunked into {len(chunks)}.")
+        print(f"Content successfully chunked: {chunks}")
+        classified_chunks = self.classify_chunks(chunks, hf=hf)  # conditional reranker
+        filtered_content = "\n\n".join(classified_chunks)
+
         if not filtered_content:
             print("Warning: No relevant chunks found. Returning empty response.")
             return "{}"
 
         prompt = self.prompt_template.format(content=filtered_content, schema=schema.model_json_schema())
-        print(f"Generated prompt for extraction: {prompt[:500]}...")
-        # Call the LLM to extract structured information
+        # print(f"Generated prompt for extraction: {prompt[:500]}...")
         llm_response = self.llm_client.call_api(prompt)
-        print(f"LLM response: {llm_response[:500]}...")
-        return "{}"
-
-        # json_response = extract_markdown_json(llm_response)
-        # if json_response is None:
-        #     print("Warning: Failed to extract JSON from LLM response. Returning empty response.")
-        #     return "{}"
-
-        return llm_response
+        # print(f"LLM response: {llm_response[:500]}...")
+
+        return llm_response or "{}"
 
 # TODO: RAGExtractor class
 class RAGExtractor(AIExtractor):
@@ -486,7 +546,7 @@ class RAGExtractor(AIExtractor):
 
         if not query:
             query = f"Extract information based on the following JSON schema: {schema.model_json_schema()}"
-            print(f"No explicit query provided for retrieval. Using default: '{query[:100]}...'")
+            # print(f"No explicit query provided for retrieval. Using default: '{query[:100]}...'")
 
         chunks = self._langchain_HHTS(content)
         print(f"Content successfully chunked into {len(chunks)} pieces.")
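Net effect of this file's changes: LLMClassifierExtractor no longer asks the LLM to classify every chunk. It chunks the HTML, reranks the chunks against classifier_prompt (locally through HFRerankerClient when hf=True, otherwise through the NVIDIA endpoint), keeps the top k, and sends only those to the extraction LLM. A hedged end-to-end sketch; the prompt strings and schema here are illustrative placeholders rather than the repo's actual prompts, and the NvidiaLLMClient config shape is assumed to mirror NvidiaRerankerClient's:

    from pydantic import BaseModel

    class Product(BaseModel):
        name: str
        price: str

    extractor = LLMClassifierExtractor(
        reranker=NvidiaRerankerClient(config={}),  # hosted reranker; needs NVIDIA_API_KEY
        llm_client=NvidiaLLMClient(config={}),     # extraction LLM; config shape assumed
        prompt_template="Extract JSON matching this schema:\n{schema}\n\nContent:\n{content}",
        classifier_prompt="chunks containing the product name and price",
    )
    json_str = extractor.extract(html, Product, hf=False)  # html: preprocessed page markup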
web2json/pipeline.py
CHANGED
@@ -13,7 +13,7 @@ class Pipeline:
         self.ai_extractor = ai_extractor
         self.postprocessor = postprocessor
 
-    def run(self, content: str, is_url: bool, schema: BaseModel) -> dict:
+    def run(self, content: str, is_url: bool, schema: BaseModel, hf=False) -> dict:
         """
         Run the entire pipeline: preprocess, extract, and postprocess.
 
@@ -27,11 +27,11 @@ class Pipeline:
         """
         # Step 1: Preprocess the content
         preprocessed_content = self.preprocessor.preprocess(content, is_url)
-        print(f"Preprocessed content: {preprocessed_content}...")
+        # print(f"Preprocessed content: {preprocessed_content}...")
         print('+'*80)
         # Step 2: Extract structured information using AI
-        extracted_data = self.ai_extractor.extract(preprocessed_content, schema)
-        print(f"Extracted data: {extracted_data[:100]}...")
+        extracted_data = self.ai_extractor.extract(preprocessed_content, schema, hf=hf)
+        # print(f"Extracted data: {extracted_data[:100]}...")
         print('+'*80)
         # Step 3: Post-process the extracted data
         final_output = self.postprocessor.process(extracted_data)
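The new hf flag threads straight through Pipeline.run to the extractor, so callers pick the reranker per request. A usage sketch, assuming the preprocessor, extractor, and postprocessor components are wired as elsewhere in the app and Product is a Pydantic schema like the one sketched above:

    pipeline = Pipeline(preprocessor, ai_extractor, postprocessor)
    result = pipeline.run(
        "https://example.com/product",  # hypothetical URL
        is_url=True,
        schema=Product,
        hf=True,  # route chunk filtering through the local HF reranker instead of NVIDIA
    )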