# recommendation/src/test_summarize.py
# Committed by sundaram22verma: "initial commit" (9d76e23)
import html
import logging
import re
from typing import List, Union

import torch
from transformers import pipeline
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Configure root logging at INFO so informational messages (such as the
# model-loading notice below) are shown; this is a no-op if the root logger
# already has handlers.
logging.basicConfig(level=logging.INFO)
# Pre-compiled regex patterns used by clean_text.
# HTML/XML tags such as "<p>" or "<a href=...>".
HTML_TAG_PATTERN = re.compile(r'<[^>]+>')
# Any character that is NOT a word character, whitespace, basic punctuation,
# an apostrophe (straight or typographic) or a percent sign. Apostrophes are
# allowed so contractions and possessives ("don't", "company's") survive
# instead of becoming "don t"; "%" keeps figures such as "50%" meaningful.
SPECIAL_CHARS_PATTERN = re.compile(r"[^\w\s.,;?!'’%-]")
# Runs of whitespace, to be collapsed into a single space.
MULTISPACE_PATTERN = re.compile(r'\s+')
# Hugging Face model loaded by the summarization pipeline.
MODEL_NAME: str = "sshleifer/distilbart-cnn-6-6"

# Bounds for each generated summary (passed to the pipeline call as
# min_length / max_length).
MIN_SUMMARY_LENGTH: int = 30
MAX_SUMMARY_LENGTH: int = 150

# Maximum input length, in tokens, of BART-family models such as DistilBART.
MAX_LENGTH: int = 1024
# Pick the pipeline device: the first CUDA GPU when available, otherwise CPU
# (transformers pipeline convention: 0 = GPU 0, -1 = CPU).
if torch.cuda.is_available():
    _device = 0
else:
    _device = -1

# Build the summarizer once, at import time, so every call reuses the model.
logger.info(f"Loading summarizer pipeline: {MODEL_NAME}")
_summarizer = pipeline(task="summarization", model=MODEL_NAME, device=_device)
def clean_text(text: str) -> str:
    """Normalise raw (possibly HTML) text before summarisation.

    Replaces HTML tags with spaces, decodes HTML character entities
    (``&amp;``, ``&nbsp;``, ``&#39;``, ...), replaces every character outside
    the SPECIAL_CHARS_PATTERN allow-list with a space, then collapses runs of
    whitespace into single spaces and strips both ends.

    Args:
        text: Input text. Empty or ``None`` input yields ``""``.

    Returns:
        The cleaned, single-spaced string.
    """
    if not text:
        return ""
    # Strip tags BEFORE decoding entities, so escaped text such as
    # "a &lt; b and c &gt; d" is not turned into a fake tag and deleted.
    text = HTML_TAG_PATTERN.sub(' ', text)
    # Without decoding, entities survive as junk words ("&nbsp;" -> " nbsp;").
    text = html.unescape(text)
    text = SPECIAL_CHARS_PATTERN.sub(' ', text)
    return MULTISPACE_PATTERN.sub(' ', text).strip()
# Inputs with fewer words than this are skipped rather than summarized.
_MIN_INPUT_WORDS = 30
# Generated sentences with fewer words than this are dropped as fragments.
_MIN_POINT_WORDS = 5
# A sentence boundary is whitespace preceded by terminal punctuation.
# Requiring the whitespace keeps decimals such as "3.5" in one piece.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')


def _summary_to_points(summary: str, max_points: int) -> List[str]:
    """Split one generated summary into at most ``max_points`` sentence points."""
    points = []
    for chunk in _SENTENCE_BOUNDARY.split(summary.strip()):
        sentence = chunk.strip()
        # Separate the sentence body from its trailing punctuation (and any
        # whitespace before it) so the ending can be normalised.
        body = sentence.rstrip('.!?').rstrip()
        if len(body.split()) < _MIN_POINT_WORDS:
            continue
        # Keep the original terminal mark ("?" / "!"); default to ".".
        ending = sentence[len(body):].strip()[-1:] or '.'
        # Upper-case only the first character: str.capitalize() would also
        # lower-case the rest ("NASA" -> "Nasa", "U.S." -> "U.s.").
        points.append(body[0].upper() + body[1:] + ending)
    return points[:max_points]


def get_summary_points(
    texts: Union[str, List[str]],
    max_points: int = 3,
    batch_size: int = 4,
) -> Union[List[str], List[List[str]]]:
    """Summarize text(s) with the module-level pipeline and return key points.

    Each input is cleaned with ``clean_text``; inputs with fewer than
    ``_MIN_INPUT_WORDS`` words are skipped (with a warning). The remaining
    inputs are summarized in one pipeline call, and each summary is split into
    sentences of at least ``_MIN_POINT_WORDS`` words.

    Args:
        texts: A single text, or a list/tuple of texts for batch mode.
        max_points: Maximum number of points returned per text.
        batch_size: Batch size forwarded to the summarization pipeline.

    Returns:
        For a single text, a list of point strings. In batch mode, a list with
        one list of points per input, in input order; skipped inputs get an
        empty list. Exceptions are logged rather than raised: on failure an
        empty result of the matching shape is returned.
    """
    is_batch = isinstance(texts, (list, tuple))
    items = list(texts) if is_batch else [texts]

    def empty_result():
        return [[] for _ in items] if is_batch else []

    try:
        cleaned = [clean_text(t) for t in items]

        valid_indices = []
        for idx, text in enumerate(cleaned):
            word_count = len(text.split())
            if word_count >= _MIN_INPUT_WORDS:
                valid_indices.append(idx)
            else:
                logger.warning(
                    "Text at index %d is too short for summarization (length: %d)",
                    idx,
                    word_count,
                )
        if not valid_indices:
            logger.warning("No valid texts found for summarization")
            return empty_result()

        try:
            outputs = _summarizer(
                [cleaned[i] for i in valid_indices],
                batch_size=batch_size,
                # Let the tokenizer cut over-long inputs at the model's token
                # limit (cf. MAX_LENGTH); slicing the text by characters would
                # discard most of the context the model can actually use.
                truncation=True,
                max_length=MAX_SUMMARY_LENGTH,
                min_length=MIN_SUMMARY_LENGTH,
                length_penalty=2.0,  # favour longer summaries
                num_beams=4,  # wider beam search for better quality
                no_repeat_ngram_size=3,
                early_stopping=True,
                do_sample=False,  # deterministic (no sampling)
            )
        except Exception as e:
            logger.exception("Error during summarization: %s", e)
            return empty_result()

        all_points = [
            _summary_to_points(output['summary_text'], max_points)
            for output in outputs
        ]

        if not is_batch:
            return all_points[0] if all_points else []
        # Scatter results back to their original positions.
        result = empty_result()
        for idx, points in zip(valid_indices, all_points):
            result[idx] = points
        return result
    except Exception as e:
        logger.exception("Error generating summary: %s", e)
        return empty_result()
if __name__ == "__main__":
    # Demo: summarize the sample texts and print each one as a numbered list.
    # NOTE(review): the sample below is whitespace only, so it is skipped as
    # too short and reported as "No summary could be generated." -- replace it
    # with real article text to see actual output.
    sample_texts = ["\n"]
    try:
        print("\nGenerating summaries...")
        for summary_no, bullet_points in enumerate(get_summary_points(sample_texts), start=1):
            print(f"\n=== Summary {summary_no} ===")
            if not bullet_points:
                print("No summary could be generated.")
                continue
            for point_no, point in enumerate(bullet_points, start=1):
                print(f"{point_no}. {point}")
    except Exception as err:
        print(f"An error occurred: {err}")