# youtube_utils.py
import re
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
from youtube_transcript_api import YouTubeTranscriptApi
from nltk.tokenize import sent_tokenize
import nltk

nltk.download('punkt', quiet=True)  # sentence tokenizer data used by sent_tokenize

def clean_text(text):
    # Collapse runs of whitespace and drop apostrophes from the transcript text.
    cleaned_text = re.sub(r'\s+', ' ', text)
    cleaned_text = cleaned_text.replace("'", "")
    return cleaned_text

def get_youtube_captions(video_id):
    # Fetch the first available transcript that can be translated to English
    # and return it as a single cleaned string, or None on failure.
    try:
        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
        full_transcript = ""

        for transcript in transcript_list:
            try:
                english_transcript = transcript.translate('en').fetch()
                for caption in english_transcript:
                    full_transcript += caption['text'] + " "
                break
            except Exception:
                # This transcript could not be translated or fetched; try the next one.
                continue

        if not full_transcript:
            return None

        return clean_text(full_transcript)

    except Exception as e:
        print(f"Error fetching captions: {e}")
        return None

def summarize_large_text_with_bart(input_text):
    """Summarize long text by splitting it into ~1024-token chunks and summarizing each with BART."""
    model_name = "facebook/bart-large-cnn"
    model = BartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = BartTokenizer.from_pretrained(model_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    input_tokens = tokenizer.encode(input_text, add_special_tokens=False)
    total_input_length = len(input_tokens)

    # Target an overall summary of roughly 28-40% of the input length in tokens,
    # to be divided evenly across chunks below.
    desired_min_length = int(total_input_length * 0.28)
    desired_max_length = int(total_input_length * 0.40)
    
    # Split into sentences and pack them greedily into chunks that fit within
    # BART's 1024-token input limit, with a small sentence overlap between chunks.
    sentences = sent_tokenize(input_text)
    max_chunk_length = 1024
    overlap = 2
    chunks = []

    sentence_tokens = [tokenizer.encode(sentence, add_special_tokens=False) for sentence in sentences]
    sentence_lengths = [len(tokens) for tokens in sentence_tokens]

    i = 0
    while i < len(sentences):
        current_chunk = []
        current_length = 0
        start = i

        while i < len(sentences) and current_length + sentence_lengths[i] <= max_chunk_length:
            current_chunk.append(sentences[i])
            current_length += sentence_lengths[i]
            i += 1

        # A single sentence longer than the chunk budget: take it anyway (the
        # encoder truncates it later) so the loop always makes progress.
        if not current_chunk:
            current_chunk.append(sentences[i])
            i += 1

        # Step back to overlap with the next chunk, but never back to the start
        # of this chunk, which would loop forever on short chunks.
        if i < len(sentences):
            i = max(i - overlap, start + 1)

        chunks.append(' '.join(current_chunk))
    
    if not chunks:
        return ""

    # Divide the overall length budget evenly across chunks, with small floors
    # so generate() always receives a usable min/max range for short inputs.
    per_chunk_min = max(desired_min_length // len(chunks), 30)
    per_chunk_max = max(desired_max_length // len(chunks), per_chunk_min + 30)

    summaries = []
    for chunk in chunks:
        inputs = tokenizer.encode(chunk, return_tensors='pt', max_length=1024, truncation=True).to(device)

        with torch.no_grad():
            summary_ids = model.generate(
                inputs,
                max_length=per_chunk_max,
                min_length=per_chunk_min,
                num_beams=4,
                length_penalty=2.0,
                no_repeat_ngram_size=3,
                early_stopping=True
            )

        summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

    return ' '.join(summaries)
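

# A minimal usage sketch, not part of the original module: the video ID below is a
# placeholder, and the flow assumes get_youtube_captions() returns None when no
# transcript could be fetched.
if __name__ == "__main__":
    video_id = "your_video_id"  # hypothetical example; replace with a real YouTube video ID
    captions = get_youtube_captions(video_id)
    if captions:
        print(summarize_large_text_with_bart(captions))
    else:
        print("No captions available for this video.")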