abdo-Mansour committed · Commit e202963 · 2 parents: 1add4c0 2ae29dd

Merge remote-tracking branch 'origin/reranker'
.txt ADDED
File without changes
requirements.txt CHANGED
@@ -13,4 +13,6 @@ langchain-text-splitters
 sentence-transformers
 openai
 html_chunking
+langchain_nvidia_ai_endpoints
+langchain_core
 lxml
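For context, the two new requirements back the reranker wiring introduced in `web2json/ai_extractor.py` below. A quick sanity check that they resolve; the module paths are taken from the imports added in that diff:

```python
# Sanity check for the new dependencies; module paths match the imports
# added to web2json/ai_extractor.py in this merge.
from langchain_core.documents import Document
from langchain_nvidia_ai_endpoints import NVIDIARerank

# Document wraps raw text for the reranker; NVIDIARerank is the client
# class that the new NvidiaRerankerClient builds on.
print(Document(page_content="example chunk").page_content)
```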
test.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
web2json/__pycache__/ai_extractor.cpython-311.pyc CHANGED
Binary files a/web2json/__pycache__/ai_extractor.cpython-311.pyc and b/web2json/__pycache__/ai_extractor.cpython-311.pyc differ
 
web2json/__pycache__/pipeline.cpython-311.pyc CHANGED
Binary files a/web2json/__pycache__/pipeline.cpython-311.pyc and b/web2json/__pycache__/pipeline.cpython-311.pyc differ
 
web2json/__pycache__/postprocessor.cpython-311.pyc CHANGED
Binary files a/web2json/__pycache__/postprocessor.cpython-311.pyc and b/web2json/__pycache__/postprocessor.cpython-311.pyc differ
 
web2json/__pycache__/preprocessor.cpython-311.pyc CHANGED
Binary files a/web2json/__pycache__/preprocessor.cpython-311.pyc and b/web2json/__pycache__/preprocessor.cpython-311.pyc differ
 
web2json/ai_extractor.py CHANGED
@@ -11,12 +11,16 @@ from google.genai import types
 from pydantic import BaseModel
 from concurrent.futures import ThreadPoolExecutor
 from html_chunking import get_html_chunks
+from langchain_nvidia_ai_endpoints import NVIDIARerank
+from langchain_core.documents import Document
 from abc import ABC, abstractmethod
 from typing import List, Any, Dict, Tuple, Optional
 import re
 import json
 from langchain_text_splitters import HTMLHeaderTextSplitter
 from sentence_transformers import SentenceTransformer
+import requests
+
 class LLMClient(ABC):
     """
     Abstract base class for calling LLM APIs.
@@ -227,6 +231,92 @@ class NvidiaLLMClient(LLMClient):
                     # You could set results[idx] = None or a default string
                     results[idx] = f"<failed after retries>"
         return results
+
+
+class NvidiaRerankerClient(LLMClient):
+    """
+    Reranker client backed by the NVIDIA API (non-streaming), wrapping
+    LangChain's NVIDIARerank document compressor.
+    """
+
+    def __init__(self, config: dict):
+        self.model_name = config.get("model_name", "nvidia/llama-3.2-nv-rerankqa-1b-v2")
+        self.client = NVIDIARerank(
+            model=self.model_name,
+            api_key=os.getenv("NVIDIA_API_KEY"),
+        )
+
+    def set_model(self, model_name: str):
+        """
+        Set the model name for the NVIDIA API client.
+
+        Args:
+            model_name (str): The name of the model to use.
+        """
+        self.model_name = model_name
+
+    @retry_on_ratelimit(max_retries=6, base_delay=0.5, max_delay=5.0)
+    def call_api(self, prompt: str) -> str:
+        pass
+
+    def call_batch(self, prompts, max_workers=8):
+        pass
+
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from typing import List, Dict
+
+
+class HFRerankerClient(LLMClient):
+    """
+    Hugging Face reranker client, defaulting to Qwen/Qwen3-Reranker-0.6B,
+    scored as a cross-encoder over query/passage pairs.
+    """
+
+    def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-0.6B", device: str = None):
+        """
+        Initialize the Hugging Face reranker.
+        """
+        self.model_name = model_name
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name).to(self.device)
+        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+
+    def rerank(self, query: str, passages: List[str], top_k: int = 3) -> List[str]:
+        """
+        Rerank passages based on relevance to the query.
+
+        Args:
+            query (str): Query string.
+            passages (List[str]): List of passages.
+            top_k (int): Number of top passages to return.
+
+        Returns:
+            List[str]: Top-k most relevant passages.
+        """
+        inputs = [self.tokenizer(f"{query} [SEP] {p}", return_tensors="pt", truncation=True, padding=True).to(self.device) for p in passages]
+        scores = []
+
+        with torch.no_grad():
+            for inp in inputs:
+                logits = self.model(**inp).logits
+                score = torch.softmax(logits, dim=1)[0, 1].item()  # probability of relevance
+                scores.append(score)
+
+        print(f"Scores for passages: {scores}")
+
+        top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
+        print(f"Top indices: {top_indices}")
+        return [passages[i] for i in top_indices]
+
+    @retry_on_ratelimit(max_retries=6, base_delay=0.5, max_delay=5.0)
+    def call_api(self, prompt: str) -> str:
+        pass
+
+    def call_batch(self, prompts, max_workers=8):
+        pass
 
 
 class AIExtractor:
@@ -264,7 +355,7 @@ class LLMClassifierExtractor(AIExtractor):
     Extractor that uses an LLM to classify and extract structured information from text content.
     This class is designed to handle classification tasks where the LLM generates structured output based on a provided schema.
     """
-    def __init__(self, llm_client: LLMClient, prompt_template: str, classifier_prompt: str, ):
+    def __init__(self, reranker: LLMClient, llm_client: LLMClient, prompt_template: str, classifier_prompt: str):
         """
         Initializes the LLMClassifierExtractor with an LLM client and a prompt template.
 
@@ -273,6 +364,7 @@ class LLMClassifierExtractor(AIExtractor):
             prompt_template (str): The template to use for generating prompts for the LLM.
         """
         super().__init__(llm_client, prompt_template)
+        self.reranker = reranker
         self.classifier_prompt = classifier_prompt
 
     def chunk_content(self, content: str, max_tokens: int = 500, is_clean: bool = True) -> List[str]:
@@ -288,79 +380,47 @@ class LLMClassifierExtractor(AIExtractor):
         # Use the get_html_chunks function to split the content into chunks
         return get_html_chunks(html=content, max_tokens=max_tokens, is_clean_html=is_clean, attr_cutoff_len=5)
 
-
-    def classify_chunks(self, chunks: List[str], schema: BaseModel) -> List[Dict[str, Any]]:
-        """
-        Classifies each chunk using the LLM based on the provided schema.
-
-        Args:
-            chunks (List[str]): A list of text chunks to classify.
-            schema (BaseModel): A Pydantic model defining the structure of the expected output.
-
-        Returns:
-            List[Dict[str, Any]]: A list of dictionaries containing classified information.
-        """
-        prompts = [self.classifier_prompt.format(content=chunk, schema=schema.model_json_schema()) for chunk in chunks]
-        classified_chunks = []
-        responses = self.llm_client.call_batch(prompts)
-        for response in responses:
-            # extract the json from the response
-            json_data = extract_markdown_json(response)
-            if json_data:
-                classified_chunks.append(json_data)
-            else:
-                classified_chunks.append({
-                    "error": "Failed to extract JSON from response",
-                    "relevant": 1,
-                })
-        return classified_chunks
+    def classify_chunks(self, passages, top_k=3, hf: bool = False):  # reranker-based "classification"
+        query = self.classifier_prompt
+
+        if hf:
+            print("Using Hugging Face reranker for classification.")
+            return self.reranker.rerank(query, passages, top_k=top_k)
+
+        # NVIDIA reranker path
+        responses = self.reranker.client.compress_documents(
+            query=query,
+            documents=[Document(page_content=passage) for passage in passages]
+        )
+        return [response.page_content for response in responses[:top_k]]
 
-    def extract(self, content: str, schema: BaseModel) -> str:
+    def extract(self, content, schema, hf: bool = False):
         """
         Extracts structured information from the given content based on the provided schema.
 
         Args:
             content (str): The raw content to extract information from.
             schema (BaseModel): A Pydantic model defining the structure of the expected output.
-
-        Returns:
-            str: The structured JSON object as a string.
+            hf (bool): Whether to use the Hugging Face reranker or NVIDIA (default).
         """
-        # Chunk the HTML
-        chunks = self.chunk_content(content, max_tokens=1500)
-        print(f"Content successfully chunked into {len(chunks)} pieces.")
-        # Classify each chunk using the LLM
-        classified_chunks = self.classify_chunks(chunks, schema)
-        # Concatenate the positive classified chunks into a single string
-        print(f"Classified {classified_chunks} chunks.")
-        positive_chunks = []
-        for i, chunk in enumerate(classified_chunks):
-            if chunk.get("relevant", 0) > 0:
-                positive_chunks.append(chunks[i])
-        if len(positive_chunks) == 0:
-            positive_chunks = chunks
-        filtered_content = "\n\n".join(positive_chunks)
-        print(f"Filtered content for extraction: {filtered_content}")
+        chunks = self.chunk_content(content, max_tokens=1500)
+        print(f"Content successfully chunked into {len(chunks)} chunks.")
+        # print(f"Content successfully chunked: {chunks}")
+        classified_chunks = self.classify_chunks(chunks, hf=hf)  # conditional reranker
+        filtered_content = "\n\n".join(classified_chunks)
+
         if not filtered_content:
             print("Warning: No relevant chunks found. Returning empty response.")
             return "{}"
-        # Generate the final prompt for extraction
+
         prompt = self.prompt_template.format(content=filtered_content, schema=schema.model_json_schema())
-        print(f"Generated prompt for extraction: {prompt[:500]}...")
-        # Call the LLM to extract structured information
+        # print(f"Generated prompt for extraction: {prompt[:500]}...")
         llm_response = self.llm_client.call_api(prompt)
-        print(f"LLM response: {llm_response[:500]}...")
-        # Return the structured response
-        if not llm_response:
-            print("Warning: LLM response is empty. Returning empty response.")
-            return "{}"
-
-        # json_response = extract_markdown_json(llm_response)
-        # if json_response is None:
-        #     print("Warning: Failed to extract JSON from LLM response. Returning empty response.")
-        #     return "{}"
-
-        return llm_response
+        # print(f"LLM response: {llm_response[:500]}...")
+
+        return llm_response or "{}"
 
 # TODO: RAGExtractor class
 class RAGExtractor(AIExtractor):
@@ -486,7 +546,7 @@ class RAGExtractor(AIExtractor):
 
         if not query:
             query = f"Extract information based on the following JSON schema: {schema.model_json_schema()}"
-            print(f"No explicit query provided for retrieval. Using default: '{query[:100]}...'")
+            # print(f"No explicit query provided for retrieval. Using default: '{query[:100]}...'")
 
         chunks = self._langchain_HHTS(content)
         print(f"Content successfully chunked into {len(chunks)} pieces.")
 
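The NVIDIA path in the new `classify_chunks` goes through LangChain's `NVIDIARerank.compress_documents`, which returns `Document`s ordered by relevance to the query (each carrying a `relevance_score` in its metadata). A minimal standalone sketch of that call; the query and passages are illustrative, and `NVIDIA_API_KEY` must be set in the environment:

```python
import os

from langchain_core.documents import Document
from langchain_nvidia_ai_endpoints import NVIDIARerank

# Same call pattern as NvidiaRerankerClient / classify_chunks above.
reranker = NVIDIARerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    api_key=os.getenv("NVIDIA_API_KEY"),
)

passages = [
    "<div>Acme Widget, price $9.99</div>",
    "<footer>Copyright 2024</footer>",
    "<div>In stock: 42 units</div>",
]

# compress_documents returns documents sorted by relevance to the query;
# classify_chunks keeps the page_content of the first top_k results.
ranked = reranker.compress_documents(
    query="product name and price",
    documents=[Document(page_content=p) for p in passages],
)
for doc in ranked:
    print(doc.metadata.get("relevance_score"), doc.page_content)
```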
 
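`HFRerankerClient.rerank` scores each `query [SEP] passage` pair with a sequence-classification head and softmaxes the two logits. Two caveats worth flagging: `token_true_id`/`token_false_id` are computed but never used, and Qwen/Qwen3-Reranker-0.6B is published as a causal LM that derives relevance from "yes"/"no" token logits, so loading it via `AutoModelForSequenceClassification` assumes a converted checkpoint. Below is a self-contained sketch of the same cross-encoder scoring with a model that does ship a classification head; the model name is an illustration, not part of this commit:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Illustrative cross-encoder reranker; this checkpoint emits a single
# relevance logit per query/passage pair (unlike the two-class softmax
# assumed by HFRerankerClient).
model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

query = "product name and price"
passages = ["Acme Widget costs $9.99", "Copyright 2024", "In stock: 42 units"]

# Encode all pairs in one padded batch instead of one forward pass per passage.
batch = tokenizer([query] * len(passages), passages,
                  padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**batch).logits.squeeze(-1)

top_k = 2
top_indices = torch.topk(scores, k=min(top_k, len(passages))).indices.tolist()
print([passages[i] for i in top_indices])
```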
 
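With this merge, `classify_chunks` repurposes `classifier_prompt` as a rerank query, so chunk "classification" becomes top-k retrieval instead of the per-chunk JSON labeling that the deleted implementation did. One behavioral consequence: the reranker always returns up to `top_k` chunks (default 3), so `filtered_content` is empty only when chunking itself produced nothing. A control-flow sketch of the new `extract` with stub clients (hypothetical, for illustration only):

```python
from typing import List

# Stubs standing in for the real reranker and LLM clients.
class StubReranker:
    def rerank(self, query: str, passages: List[str], top_k: int = 3) -> List[str]:
        # Pretend the first top_k chunks are the most relevant.
        return passages[:top_k]

class StubLLM:
    def call_api(self, prompt: str) -> str:
        return '{"name": "Acme Widget", "price": "$9.99"}'

chunks = ["<div>Acme Widget $9.99</div>", "<footer>Copyright 2024</footer>"]

# Mirrors extract(): rerank, join the survivors, then prompt the LLM.
filtered_content = "\n\n".join(StubReranker().rerank("find product fields", chunks))
llm_response = StubLLM().call_api(f"Extract JSON from:\n{filtered_content}")
print(llm_response or "{}")
```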
 
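The extractor's constructor now takes the reranker ahead of the generation client. A wiring sketch; the prompt strings are placeholders, and the `NvidiaLLMClient` config key is assumed to mirror `NvidiaRerankerClient`'s, since its `__init__` is not shown in this diff:

```python
# Hypothetical wiring of the new LLMClassifierExtractor signature; prompts
# and config values are placeholders, not taken from this commit.
reranker = NvidiaRerankerClient(
    config={"model_name": "nvidia/llama-3.2-nv-rerankqa-1b-v2"}
)
llm = NvidiaLLMClient(config={"model_name": "meta/llama-3.1-8b-instruct"})

extractor = LLMClassifierExtractor(
    reranker=reranker,
    llm_client=llm,
    prompt_template="Extract JSON matching this schema:\n{schema}\n\nContent:\n{content}",
    classifier_prompt="Chunks containing the fields required by the target schema",
)
```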
web2json/pipeline.py CHANGED
@@ -13,7 +13,7 @@ class Pipeline:
         self.ai_extractor = ai_extractor
         self.postprocessor = postprocessor
 
-    def run(self, content: str, is_url: bool, schema: BaseModel) -> dict:
+    def run(self, content: str, is_url: bool, schema: BaseModel, hf: bool = False) -> dict:
         """
         Run the entire pipeline: preprocess, extract, and postprocess.
 
@@ -27,11 +27,11 @@ class Pipeline:
         """
         # Step 1: Preprocess the content
        preprocessed_content = self.preprocessor.preprocess(content, is_url)
-        print(f"Preprocessed content: {preprocessed_content}...")
+        # print(f"Preprocessed content: {preprocessed_content}...")
         print('+' * 80)
         # Step 2: Extract structured information using AI
-        extracted_data = self.ai_extractor.extract(preprocessed_content, schema)
-        print(f"Extracted data: {extracted_data[:100]}...")
+        extracted_data = self.ai_extractor.extract(preprocessed_content, schema, hf=hf)
+        # print(f"Extracted data: {extracted_data[:100]}...")
         print('+' * 80)
         # Step 3: Post-process the extracted data
         final_output = self.postprocessor.process(extracted_data)
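End to end, the new `hf` flag threads from `Pipeline.run` into the extractor's choice of reranker. A usage sketch, assuming the extractor wired up above and placeholder pre/post-processing components (their constructors are not part of this diff):

```python
from pydantic import BaseModel

class Product(BaseModel):
    name: str
    price: str

# preprocessor/postprocessor are placeholders for the web2json components;
# extractor is the LLMClassifierExtractor built earlier.
pipeline = Pipeline(preprocessor, extractor, postprocessor)

# hf=False (the default) routes classification through the NVIDIA reranker;
# hf=True switches to the local Hugging Face reranker.
result = pipeline.run(
    "https://example.com/product",
    is_url=True,
    schema=Product,
    hf=False,
)
print(result)
```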