# NOTE: the lines that were here were non-Python residue from a scraped
# Hugging Face Spaces page (status banners, commit hashes, blame line
# numbers); they are preserved as this comment so the file parses.
# Spaces: Runtime error | file size: 2,723 bytes
import torch
import gradio as gr
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
AutoModelForTokenClassification,
)
import googlesearch
_SUMMARIZE_CHECKPOINT = "sshleifer/distilbart-cnn-12-6"
_summarize_cache = {}


def _load_summarizer():
    # Lazily load the tokenizer/model once per process; the original code
    # re-downloaded/re-loaded the full checkpoint on every request.
    if not _summarize_cache:
        _summarize_cache["tokenizer"] = AutoTokenizer.from_pretrained(_SUMMARIZE_CHECKPOINT)
        _summarize_cache["model"] = AutoModelForSeq2SeqLM.from_pretrained(_SUMMARIZE_CHECKPOINT)
    return _summarize_cache["tokenizer"], _summarize_cache["model"]


def summarize(text):
    """Return an abstractive summary of *text*.

    Uses greedy decoding (do_sample=False) capped at 100 new tokens;
    the input is truncated to the model's maximum length.
    """
    tokenizer, model = _load_summarizer()
    inputs = tokenizer(text, truncation=True, return_tensors="pt").input_ids
    # Inference only — no_grad avoids building an autograd graph.
    with torch.no_grad():
        outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
_QG_CHECKPOINT = "mrm8488/t5-base-finetuned-question-generation-ap"
_qg_cache = {}


def _load_question_generator():
    # Lazily load the T5 question-generation model once per process
    # instead of on every call.
    if not _qg_cache:
        _qg_cache["tokenizer"] = AutoTokenizer.from_pretrained(_QG_CHECKPOINT, use_fast=False)
        _qg_cache["model"] = AutoModelForSeq2SeqLM.from_pretrained(_QG_CHECKPOINT)
    return _qg_cache["tokenizer"], _qg_cache["model"]


def generate_question(text):
    """Generate a question about *text*.

    The fine-tuned T5 checkpoint expects an "answer: ... context: ..."
    prompt; here the whole text is supplied as both answer and context.
    """
    tokenizer, model = _load_question_generator()
    prompt = f"answer: {text} context: {text}"
    inputs = tokenizer(prompt, truncation=True, return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model.generate(inputs)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The model normally emits "question: <q>". The original 2-way unpack of
    # str.split raised ValueError whenever the marker was missing or appeared
    # more than once; partition takes everything after the FIRST marker and
    # falls back to the raw decode when the marker is absent.
    _, marker, question = generated_text.partition("question: ")
    return question if marker else generated_text
_KW_CHECKPOINT = "yanekyuk/bert-keyword-extractor"
_kw_cache = {}


def _load_keyword_extractor():
    # Lazily load the BERT keyword-extraction model once per process.
    if not _kw_cache:
        _kw_cache["tokenizer"] = AutoTokenizer.from_pretrained(_KW_CHECKPOINT)
        _kw_cache["model"] = AutoModelForTokenClassification.from_pretrained(_KW_CHECKPOINT)
    return _kw_cache["tokenizer"], _kw_cache["model"]


def get_similar_articles(text):
    """Extract keywords from *text* and return related website URLs.

    Runs token classification, groups consecutive non-O tokens into
    keyword spans, then Google-searches up to three distinct keywords
    (3 results each). Returns the URLs joined by newlines.
    """
    tokenizer, model = _load_keyword_extractor()
    inputs = tokenizer(text, truncation=True, return_tensors="pt")
    # Inference only — no_grad avoids accumulating an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    keyword_tokens = []
    current_keyword_tokens = []
    for token, logits in zip(inputs.input_ids[0], outputs.logits[0]):
        # Label 0 is assumed to be the "O" (non-keyword) class; any other
        # predicted class extends the current keyword span.
        if torch.argmax(logits).item() > 0:
            current_keyword_tokens.append(token.item())
        elif current_keyword_tokens:
            keyword_tokens.append(current_keyword_tokens)
            current_keyword_tokens = []
    # Bug fix: a keyword span that runs through the final token was
    # previously dropped because it was only flushed on a non-keyword token.
    if current_keyword_tokens:
        keyword_tokens.append(current_keyword_tokens)
    # set() deduplicates; note it also makes keyword order nondeterministic.
    keywords = list(set(tokenizer.batch_decode(keyword_tokens)))
    similar_websites = []
    for keyword in keywords[:3]:
        websites = googlesearch.search(
            keyword,
            tld="com",
            lang="en",
            num=3,
            stop=3,
            pause=0.5,
        )
        similar_websites += list(websites)
    return "\n".join(similar_websites)
# Expose each task on its own tab of the Gradio app: all three take a
# single text input and produce a single text output.
_TASKS = [
    (summarize, "Summarize an article"),
    (generate_question, "Generate a question"),
    (get_similar_articles, "Get similar articles"),
]
_interfaces = [gr.Interface(fn=fn, inputs="text", outputs="text") for fn, _ in _TASKS]
tabs = gr.TabbedInterface(_interfaces, [label for _, label in _TASKS])
tabs.launch()