from transformers import pipeline
import re
from typing import List, Union
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pre-compiled regex patterns for faster text cleaning
HTML_TAG_PATTERN = re.compile(r'<[^>]+>')
SPECIAL_CHARS_PATTERN = re.compile(r'[^\w\s.,;?!-]')
MULTISPACE_PATTERN = re.compile(r'\s+')

# Model and summarizer pipeline configuration
MODEL_NAME = "sshleifer/distilbart-cnn-6-6"
MAX_LENGTH = 1024  # Coarse character-level input guard; token-level truncation is left to the pipeline
MAX_SUMMARY_LENGTH = 150  # Maximum summary length (tokens)
MIN_SUMMARY_LENGTH = 30  # Minimum summary length (tokens)

# Device selection for pipeline: first GPU if available, otherwise CPU
_device = 0 if torch.cuda.is_available() else -1

# Load the summarizer pipeline ONCE at module level
logger.info(f"Loading summarizer pipeline: {MODEL_NAME}")
_summarizer = pipeline(
    "summarization",
    model=MODEL_NAME,
    device=_device
)


def clean_text(text: str) -> str:
    """Clean text by removing HTML tags and special characters."""
    text = HTML_TAG_PATTERN.sub(' ', text)
    text = SPECIAL_CHARS_PATTERN.sub(' ', text)
    return MULTISPACE_PATTERN.sub(' ', text).strip()


def get_summary_points(
    texts: Union[str, List[str]],
    max_points: int = 3,
    batch_size: int = 4,
) -> Union[List[str], List[List[str]]]:
    """Summarize a single text or a batch of texts into bullet points.

    Returns a flat list of points for a single string input, or a list of
    point lists (one per input, in the original order) for a batch.
    """
    is_batch = isinstance(texts, list)  # Defined before the try so the except clause can use it
    try:
        # Normalize to a list so single texts and batches share one code path
        texts = texts if is_batch else [texts]

        # Clean and truncate texts (character-level cut; see MAX_LENGTH note)
        cleaned_texts = [clean_text(t)[:MAX_LENGTH] for t in texts]

        # Filter out texts that are too short to summarize meaningfully
        valid_texts = []
        valid_indices = []
        for idx, t in enumerate(cleaned_texts):
            if len(t.split()) >= 30:
                valid_texts.append(t)
                valid_indices.append(idx)
            else:
                logger.warning(
                    f"Text at index {idx} is too short for summarization "
                    f"(length: {len(t.split())} words)"
                )

        if not valid_texts:
            logger.warning("No valid texts found for summarization")
            return [[] for _ in texts] if is_batch else []

        # Generate summaries in batch
        try:
            summaries = _summarizer(
                valid_texts,
                batch_size=batch_size,  # Forward the batch size (previously accepted but never used)
                max_length=MAX_SUMMARY_LENGTH,
                min_length=MIN_SUMMARY_LENGTH,
                length_penalty=2.0,  # Increased to favor longer summaries
                num_beams=4,  # Increased for better quality
                no_repeat_ngram_size=3,
                early_stopping=True,
                do_sample=False,  # Disable sampling for more deterministic results
                truncation=True,  # Truncate inputs at the model's token limit
            )
            summaries = [s['summary_text'] for s in summaries]
        except Exception as e:
            logger.error(f"Error during summarization: {e}")
            return [[] for _ in texts] if is_batch else []

        # Split each summary into bullet points at sentence boundaries
        # (period, question mark, exclamation mark)
        all_points = []
        for summary in summaries:
            points = []
            sentences = re.split(r'[.!?]+', summary)
            for sentence in sentences:
                sentence = sentence.strip()
                if sentence and len(sentence.split()) >= 5:  # Only include sentences with at least 5 words
                    # Upper-case only the first character; str.capitalize()
                    # would lower-case the rest and mangle acronyms
                    points.append(sentence[0].upper() + sentence[1:] + '.')
            all_points.append(points[:max_points])

        # Re-align batch results with the original input order, leaving
        # empty lists for inputs that were filtered out as too short
        if is_batch:
            result = [[] for _ in texts]
            for idx, points in zip(valid_indices, all_points):
                result[idx] = points
            return result
        return all_points[0] if all_points else []

    except Exception as e:
        logger.error(f"Error generating summary: {e}")
        return [[] for _ in texts] if is_batch else []


if __name__ == "__main__":
    # Example usage with an illustrative placeholder text (swap in real input)
    texts = [
        """
        The James Webb Space Telescope is a space telescope designed primarily
        to conduct infrared astronomy. It is the largest optical telescope in
        space, and its high resolution and sensitivity allow it to view objects
        too old, distant, or faint for the Hubble Space Telescope. This enables
        investigations across many fields of astronomy and cosmology, such as
        observation of the first stars and the formation of the first galaxies.
        """
    ]

    try:
        print("\nGenerating summaries...")
        results = get_summary_points(texts)
        for idx, points in enumerate(results, 1):
            print(f"\n=== Summary {idx} ===")
            if points:
                for point_idx, point in enumerate(points, 1):
                    print(f"{point_idx}. {point}")
{point}") else: print("No summary could be generated.") except Exception as e: print(f"An error occurred: {e}")