maliahson committed on
Commit 6c007b8 · verified · 1 Parent(s): 353620b

Update youtube_utils.py

Files changed (1)
  1. youtube_utils.py +100 -91
youtube_utils.py CHANGED
@@ -1,91 +1,100 @@
- # youtube_utils.py
- import re
- import torch
- from transformers import BartForConditionalGeneration, BartTokenizer
- from youtube_transcript_api import YouTubeTranscriptApi
- from nltk.tokenize import sent_tokenize
- import nltk
-
- nltk.download('punkt')
-
- def clean_text(text):
-     cleaned_text = re.sub(r'\s+', ' ', text)
-     cleaned_text = cleaned_text.replace("'", "")
-     return cleaned_text
-
- def get_youtube_captions(video_id):
-     try:
-         transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
-         full_transcript = ""
-
-         for transcript in transcript_list:
-             try:
-                 english_transcript = transcript.translate('en').fetch()
-                 for caption in english_transcript:
-                     full_transcript += caption['text'] + " "
-                 break
-             except Exception:
-                 continue
-
-         return clean_text(full_transcript)
-
-     except Exception as e:
-         print(f"Error fetching captions: {e}")
-         return None
-
- def summarize_large_text_with_bart(input_text):
-     model_name = "facebook/bart-large-cnn"
-     model = BartForConditionalGeneration.from_pretrained(model_name)
-     tokenizer = BartTokenizer.from_pretrained(model_name)
-
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     model.to(device)
-
-     input_tokens = tokenizer.encode(input_text, add_special_tokens=False)
-     total_input_length = len(input_tokens)
-
-     desired_min_length = int(total_input_length * 0.28)
-     desired_max_length = int(total_input_length * 0.40)
-
-     sentences = sent_tokenize(input_text)
-     max_chunk_length = 1024
-     overlap = 2
-     chunks = []
-
-     sentence_tokens = [tokenizer.encode(sentence, add_special_tokens=False) for sentence in sentences]
-     sentence_lengths = [len(tokens) for tokens in sentence_tokens]
-
-     i = 0
-     while i < len(sentences):
-         current_chunk = []
-         current_length = 0
-         start = i
-
-         while i < len(sentences) and current_length + sentence_lengths[i] <= max_chunk_length:
-             current_chunk.append(sentences[i])
-             current_length += sentence_lengths[i]
-             i += 1
-
-         if i < len(sentences):
-             i = i - overlap if i - overlap > start else start
-
-         chunks.append(' '.join(current_chunk))
-
-     summaries = []
-     for chunk in chunks:
-         inputs = tokenizer.encode(chunk, return_tensors='pt', max_length=1024, truncation=True).to(device)
-
-         with torch.no_grad():
-             summary_ids = model.generate(
-                 inputs,
-                 max_length=desired_max_length // len(chunks),
-                 min_length=desired_min_length // len(chunks),
-                 num_beams=4,
-                 length_penalty=2.0,
-                 no_repeat_ngram_size=3,
-                 early_stopping=True
-             )
-
-         summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
-
-     return ' '.join(summaries)
+ import re
+ import torch
+ from transformers import BartForConditionalGeneration, BartTokenizer
+ from youtube_transcript_api import YouTubeTranscriptApi
+ from nltk.tokenize import sent_tokenize
+ import nltk
+
+ # Ensure NLTK data is downloaded during the first run
+ nltk.download('punkt')
+
+ def clean_text(text):
+     """Clean up text by removing extra whitespace and quotes."""
+     cleaned_text = re.sub(r'\s+', ' ', text)
+     cleaned_text = cleaned_text.replace("'", "")
+     return cleaned_text
+
+ def get_youtube_captions(video_id):
+     """Fetch captions for a YouTube video, translating to English if needed."""
+     try:
+         transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
+         full_transcript = ""
+
+         for transcript in transcript_list:
+             try:
+                 english_transcript = transcript.translate('en').fetch()
+                 for caption in english_transcript:
+                     full_transcript += caption['text'] + " "
+                 break
+             except Exception:
+                 continue
+
+         return clean_text(full_transcript)
+
+     except Exception as e:
+         print(f"Error fetching captions: {e}")
+         return None
+
+ def summarize_large_text_with_bart(input_text):
+     """Summarize large text using BART model."""
+     model_name = "facebook/bart-large-cnn"
+
+     # Load tokenizer and model
+     tokenizer = BartTokenizer.from_pretrained(model_name)
+     model = BartForConditionalGeneration.from_pretrained(model_name)
+
+     # Use GPU if available
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model.to(device)
+
+     # Tokenize input and calculate summary lengths
+     input_tokens = tokenizer.encode(input_text, add_special_tokens=False)
+     total_input_length = len(input_tokens)
+
+     desired_min_length = int(total_input_length * 0.28)
+     desired_max_length = int(total_input_length * 0.40)
+
+     # Split input into chunks of <= 1024 tokens with overlap
+     sentences = sent_tokenize(input_text)
+     max_chunk_length = 1024
+     overlap = 2
+     chunks = []
+
+     sentence_tokens = [tokenizer.encode(sentence, add_special_tokens=False) for sentence in sentences]
+     sentence_lengths = [len(tokens) for tokens in sentence_tokens]
+
+     i = 0
+     while i < len(sentences):
+         current_chunk = []
+         current_length = 0
+         start = i
+
+         while i < len(sentences) and current_length + sentence_lengths[i] <= max_chunk_length:
+             current_chunk.append(sentences[i])
+             current_length += sentence_lengths[i]
+             i += 1
+
+         if i < len(sentences):
+             i = i - overlap if i - overlap > start else start
+
+         chunks.append(' '.join(current_chunk))
+
+     # Generate summaries for each chunk
+     summaries = []
+     for chunk in chunks:
+         inputs = tokenizer.encode(chunk, return_tensors='pt', max_length=1024, truncation=True).to(device)
+
+         with torch.no_grad():
+             summary_ids = model.generate(
+                 inputs,
+                 max_length=desired_max_length // len(chunks),
+                 min_length=desired_min_length // len(chunks),
+                 num_beams=4,
+                 length_penalty=2.0,
+                 no_repeat_ngram_size=3,
+                 early_stopping=True
+             )
+
+         summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
+
+     return ' '.join(summaries)