from transformers import pipeline
import re
from typing import List, Union
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pre-compiled regex patterns for faster text cleaning
HTML_TAG_PATTERN = re.compile(r'<[^>]+>')
SPECIAL_CHARS_PATTERN = re.compile(r'[^\w\s.,;?!-]')
MULTISPACE_PATTERN = re.compile(r'\s+')

# Model and summarizer pipeline configuration
MODEL_NAME = "sshleifer/distilbart-cnn-6-6"
MAX_INPUT_CHARS = 4096    # Rough character cap (~1,024 tokens); exact truncation is done by the tokenizer via truncation=True
MAX_SUMMARY_LENGTH = 150  # Maximum summary length in tokens
MIN_SUMMARY_LENGTH = 30   # Minimum summary length in tokens

# Device selection for pipeline
_device = 0 if torch.cuda.is_available() else -1

# Load the summarizer pipeline ONCE at module level
logger.info(f"Loading summarizer pipeline: {MODEL_NAME}")
_summarizer = pipeline(
    "summarization",
    model=MODEL_NAME,
    device=_device
)
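
# Note: on a GPU, passing torch_dtype=torch.float16 to pipeline() would roughly
# halve memory usage at a small quality cost; default precision is kept here.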

def clean_text(text: str) -> str:
    """Strip HTML tags and special characters, then collapse whitespace."""
    text = HTML_TAG_PATTERN.sub(' ', text)
    text = SPECIAL_CHARS_PATTERN.sub(' ', text)
    return MULTISPACE_PATTERN.sub(' ', text).strip()
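
# Illustrative example (sample input, not from the original source):
#   clean_text("<p>Hello,   world!</p>")  ->  "Hello, world!"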

def get_summary_points(
    texts: Union[str, List[str]],
    max_points: int = 3,
    batch_size: int = 4,
) -> Union[List[str], List[List[str]]]:
    """Summarize one text or a batch of texts into at most max_points bullet points.

    Returns a flat list of points for a single string, or one list of points
    per input (empty where summarization was skipped) for a batch.
    """
    is_batch = isinstance(texts, list)  # Track input shape so the return shape matches
    try:
        # Handle both single text and batch of texts
        texts = [texts] if not is_batch else texts
        # Clean, then apply a rough character cap; exact token truncation happens in the pipeline
        cleaned_texts = [clean_text(t)[:MAX_INPUT_CHARS] for t in texts]
        # Filter out texts that are too short
        valid_texts = []
        valid_indices = []
        for idx, t in enumerate(cleaned_texts):
            if len(t.split()) >= 30:
                valid_texts.append(t)
                valid_indices.append(idx)
            else:
                logger.warning(f"Text at index {idx} is too short for summarization (length: {len(t.split())})")
        if not valid_texts:
            logger.warning("No valid texts found for summarization")
            return [] if not is_batch else [[] for _ in texts]
        # Generate summaries in batch
        try:
            summaries = _summarizer(
                valid_texts,
                batch_size=batch_size,
                truncation=True,  # Truncate inputs to the model's token limit
                max_length=MAX_SUMMARY_LENGTH,
                min_length=MIN_SUMMARY_LENGTH,
                length_penalty=2.0,  # >1.0 favors longer summaries under beam search
                num_beams=4,  # Beam search gives better quality than greedy decoding
                no_repeat_ngram_size=3,
                early_stopping=True,
                do_sample=False  # Deterministic output
            )
            summaries = [s['summary_text'] for s in summaries]
        except Exception as e:
            logger.error(f"Error during summarization: {e}")
            return [] if not is_batch else [[] for _ in texts]
        # Process summaries into points
        all_points = []
        for summary in summaries:
            points = []
            # Split by sentence boundaries (period, question mark, exclamation mark)
            sentences = re.split(r'[.!?]+', summary)
            for sentence in sentences:
                sentence = sentence.strip()
                if sentence and len(sentence.split()) >= 5:  # Keep only sentences with at least 5 words
                    # Upper-case the first letter without lowercasing the rest
                    # (str.capitalize() would mangle proper nouns and acronyms)
                    points.append(sentence[0].upper() + sentence[1:] + '.')
            all_points.append(points[:max_points])
        # Handle results for batch processing
        if is_batch:
            result = [[] for _ in texts]
            for idx, points in zip(valid_indices, all_points):
                result[idx] = points
            return result
        return all_points[0] if all_points else []
    except Exception as e:
        logger.error(f"Error generating summary: {e}")
        return [] if not is_batch else [[] for _ in texts]

if __name__ == "__main__":
    # Example usage
    texts = [
        # Placeholder sample text (at least 30 words, as required by the filter above)
        """
        The transformer architecture has become the dominant approach in natural
        language processing. By relying on self-attention rather than recurrence,
        transformers process all tokens in a sequence in parallel, which makes
        training on large corpora far more efficient. Pretrained models such as
        BART can be fine-tuned for summarization, and distilled variants such as
        DistilBART trade a small amount of quality for much faster inference.
        """
    ]
    try:
        print("\nGenerating summaries...")
        results = get_summary_points(texts)
        for idx, points in enumerate(results, 1):
            print(f"\n=== Summary {idx} ===")
            if points:
                for point_idx, point in enumerate(points, 1):
                    print(f"{point_idx}. {point}")
            else:
                print("No summary could be generated.")
    except Exception as e:
        print(f"An error occurred: {e}")