import logging
import re
from typing import List, Union

import torch
from transformers import pipeline

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pre-compiled regex patterns for faster text cleaning
HTML_TAG_PATTERN = re.compile(r'<[^>]+>')             # HTML/XML tags
SPECIAL_CHARS_PATTERN = re.compile(r'[^\w\s.,;?!-]')  # anything except word chars, whitespace, and basic punctuation
MULTISPACE_PATTERN = re.compile(r'\s+')               # runs of whitespace

# Model and summarizer pipeline configuration
MODEL_NAME = "sshleifer/distilbart-cnn-6-6"
MAX_LENGTH = 1024  # Maximum input length in tokens for BART/DistilBART models
MAX_INPUT_CHARS = 4 * MAX_LENGTH  # Coarse character cap (~4 chars/token heuristic); exact truncation happens in the tokenizer
MAX_SUMMARY_LENGTH = 150  # Maximum summary length in tokens
MIN_SUMMARY_LENGTH = 30   # Minimum summary length in tokens

# Device selection for the pipeline: 0 = first CUDA GPU, -1 = CPU
_device = 0 if torch.cuda.is_available() else -1

# Load the summarizer pipeline ONCE at module level to avoid reloading it on every call
logger.info(f"Loading summarizer pipeline: {MODEL_NAME}")
_summarizer = pipeline(
    "summarization",
    model=MODEL_NAME,
    device=_device
)


def clean_text(text: str) -> str:
    """Remove HTML tags and special characters, then collapse repeated whitespace."""
    text = HTML_TAG_PATTERN.sub(' ', text)
    text = SPECIAL_CHARS_PATTERN.sub(' ', text)
    return MULTISPACE_PATTERN.sub(' ', text).strip()
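
# A quick sanity check of the cleaner's behavior (a sketch; the result follows
# directly from the three patterns above):
#   clean_text("<p>Hello,   world!</p>")  ->  "Hello, world!"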


def get_summary_points(texts: Union[str, List[str]], max_points: int = 3,
                       batch_size: int = 4) -> Union[List[str], List[List[str]]]:
    """Summarize one text (returns a list of points) or a batch of texts
    (returns a list of point lists, aligned with the input order)."""
    is_batch = isinstance(texts, list)  # Ensure is_batch is always defined
    try:
        # Handle both single text and batch of texts
        texts = [texts] if not is_batch else texts
        # Clean texts and apply a coarse character cap; exact truncation to the
        # model's token limit happens inside the pipeline (truncation=True below)
        cleaned_texts = [clean_text(t)[:MAX_INPUT_CHARS] for t in texts]
        # Filter out texts that are too short to summarize meaningfully
        valid_texts = []
        valid_indices = []
        for idx, t in enumerate(cleaned_texts):
            if len(t.split()) >= 30:
                valid_texts.append(t)
                valid_indices.append(idx)
            else:
                logger.warning(f"Text at index {idx} is too short for summarization (length: {len(t.split())} words)")
        if not valid_texts:
            logger.warning("No valid texts found for summarization")
            return [] if not is_batch else [[] for _ in texts]
        # Generate summaries in batches; truncation=True lets the tokenizer cut
        # inputs to the model's token limit instead of failing on long texts
        try:
            summaries = _summarizer(
                valid_texts,
                batch_size=batch_size,
                truncation=True,
                max_length=MAX_SUMMARY_LENGTH,
                min_length=MIN_SUMMARY_LENGTH,
                length_penalty=2.0,      # >1.0 favors longer summaries under beam search
                num_beams=4,             # Beam search for better quality
                no_repeat_ngram_size=3,  # Avoid repeated trigrams
                early_stopping=True,
                do_sample=False          # Disable sampling for deterministic results
            )
            summaries = [s['summary_text'] for s in summaries]
        except Exception as e:
            logger.error(f"Error during summarization: {e}")
            return [] if not is_batch else [[] for _ in texts]
        # Split each summary into bullet points at sentence boundaries
        # (period, question mark, exclamation mark)
        all_points = []
        for summary in summaries:
            points = []
            sentences = re.split(r'[.!?]+', summary)
            for sentence in sentences:
                sentence = sentence.strip()
                if sentence and len(sentence.split()) >= 5:  # Keep only sentences with at least 5 words
                    # Uppercase the first letter without lowercasing the rest
                    # (str.capitalize() would mangle acronyms and proper nouns)
                    points.append(sentence[0].upper() + sentence[1:] + '.')
            all_points.append(points[:max_points])
        # Map points back to their original positions in the batch
        if is_batch:
            result = [[] for _ in texts]
            for idx, points in zip(valid_indices, all_points):
                result[idx] = points
            return result
        return all_points[0] if all_points else []
    except Exception as e:
        logger.error(f"Error generating summary: {e}")
        return [] if not is_batch else [[] for _ in texts]


if __name__ == "__main__":
    # Example usage with a short placeholder passage (any text of at least
    # 30 words will pass the length filter)
    texts = [
        """
        Text summarization condenses a longer document into a shorter version
        that preserves its key information. Extractive methods select important
        sentences directly from the source, while abstractive methods, such as
        the BART-based model used here, generate new sentences that paraphrase
        the original content, often producing more fluent summaries.
        """
    ]
    try:
        print("\nGenerating summaries...")
        results = get_summary_points(texts)
        for idx, points in enumerate(results, 1):
            print(f"\n=== Summary {idx} ===")
            if points:
                for point_idx, point in enumerate(points, 1):
                    print(f"{point_idx}. {point}")
            else:
                print("No summary could be generated.")
    except Exception as e:
        print(f"An error occurred: {e}")