Update youtube_utils.py
youtube_utils.py  CHANGED  (+3 -12)
@@ -1,3 +1,4 @@
+# youtube_utils.py
 import re
 import torch
 from transformers import BartForConditionalGeneration, BartTokenizer
@@ -5,17 +6,14 @@ from youtube_transcript_api import YouTubeTranscriptApi
 from nltk.tokenize import sent_tokenize
 import nltk
 
-# Ensure NLTK data is downloaded during the first run
 nltk.download('punkt')
 
 def clean_text(text):
-    """Clean up text by removing extra whitespace and quotes."""
     cleaned_text = re.sub(r'\s+', ' ', text)
     cleaned_text = cleaned_text.replace("'", "")
     return cleaned_text
 
 def get_youtube_captions(video_id):
-    """Fetch captions for a YouTube video, translating to English if needed."""
     try:
         transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
         full_transcript = ""
@@ -36,25 +34,19 @@ def get_youtube_captions(video_id):
         return None
 
 def summarize_large_text_with_bart(input_text):
-    """Summarize large text using BART model."""
     model_name = "facebook/bart-large-cnn"
-
-    # Load tokenizer and model
-    tokenizer = BartTokenizer.from_pretrained(model_name)
     model = BartForConditionalGeneration.from_pretrained(model_name)
+    tokenizer = BartTokenizer.from_pretrained(model_name)
 
-    # Use GPU if available
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model.to(device)
 
-    # Tokenize input and calculate summary lengths
    input_tokens = tokenizer.encode(input_text, add_special_tokens=False)
     total_input_length = len(input_tokens)
 
     desired_min_length = int(total_input_length * 0.28)
     desired_max_length = int(total_input_length * 0.40)
 
-    # Split input into chunks of <= 1024 tokens with overlap
     sentences = sent_tokenize(input_text)
     max_chunk_length = 1024
     overlap = 2
@@ -79,7 +71,6 @@ def summarize_large_text_with_bart(input_text):
 
     chunks.append(' '.join(current_chunk))
 
-    # Generate summaries for each chunk
     summaries = []
     for chunk in chunks:
         inputs = tokenizer.encode(chunk, return_tensors='pt', max_length=1024, truncation=True).to(device)
@@ -97,4 +88,4 @@ def summarize_large_text_with_bart(input_text):
 
         summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
 
-    return ' '.join(summaries)
+    return ' '.join(summaries)
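The loop that packs sentences into chunks (old lines 61-78) and the per-chunk generate() call (old lines 86-96) fall between the hunks above and are not shown in this diff. As a rough sketch only, assuming the intent stated by the removed comment ("Split input into chunks of <= 1024 tokens with overlap"), sentence packing with a two-sentence overlap could look like the code below; chunk_sentences is a hypothetical helper introduced here for illustration and is not part of youtube_utils.py.

    from transformers import BartTokenizer

    def chunk_sentences(sentences, tokenizer, max_chunk_length=1024, overlap=2):
        # Hypothetical helper (not in youtube_utils.py): greedily packs sentences
        # into chunks of at most max_chunk_length tokens, carrying the last
        # `overlap` sentences into the next chunk for context.
        chunks, current_chunk, current_len = [], [], 0
        for sentence in sentences:
            n_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
            if current_chunk and current_len + n_tokens > max_chunk_length:
                chunks.append(' '.join(current_chunk))
                current_chunk = current_chunk[-overlap:]
                current_len = sum(len(tokenizer.encode(s, add_special_tokens=False))
                                  for s in current_chunk)
            current_chunk.append(sentence)
            current_len += n_tokens
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks

    # Example, reusing the names from the file:
    #   tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    #   chunks = chunk_sentences(sent_tokenize(input_text), tokenizer)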
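For context, a minimal sketch of how the two public functions are typically wired together by the Space's app code; the app file is not part of this commit, so the wiring is an assumption, the import path is taken from the file name above, and the video ID is only a placeholder.

    from youtube_utils import get_youtube_captions, summarize_large_text_with_bart

    video_id = "dQw4w9WgXcQ"  # placeholder video ID
    captions = get_youtube_captions(video_id)
    if captions:
        print(summarize_large_text_with_bart(captions))
    else:
        print("No transcript available for this video.")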