charlie572's picture
Add question generator
6c31d06
raw
history blame
1.35 kB
from functools import lru_cache

import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@lru_cache(maxsize=1)
def _load_summarizer():
    """Load and cache the summarization tokenizer/model pair.

    The original code reloaded both from the hub on every request; caching
    makes the first call pay the download cost and later calls reuse it.
    """
    checkpoint = "sshleifer/distilbart-cnn-12-6"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tokenizer, model


def summarize(text):
    """Summarize *text* with distilbart-cnn-12-6.

    Args:
        text: Article text; tokenized with truncation to the model's
            maximum input length.

    Returns:
        str: A greedy-decoded (do_sample=False) summary of at most
        100 newly generated tokens.
    """
    tokenizer, model = _load_summarizer()
    inputs = tokenizer(text, truncation=True, return_tensors="pt").input_ids
    outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
@lru_cache(maxsize=1)
def _load_question_generator():
    """Load and cache the T5 question-generation tokenizer/model pair.

    Cached so repeated requests do not re-download/re-instantiate the
    checkpoint on every call.
    """
    checkpoint = "mrm8488/t5-base-finetuned-question-generation-ap"
    # use_fast=False kept from the original — presumably this checkpoint
    # needs the slow (SentencePiece) tokenizer; verify before changing.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tokenizer, model


def generate_question(text):
    """Generate a question about *text* with a T5 question-generation model.

    Args:
        text: Source passage; used as both the "answer" and the "context"
            fields of the model's expected prompt format.

    Returns:
        str: The generated question (text after the "question: " marker),
        or the full generated string if the marker is absent.
    """
    tokenizer, model = _load_question_generator()
    prompt = f"answer: {text} context: {text}"
    inputs = tokenizer(prompt, truncation=True, return_tensors="pt").input_ids
    outputs = model.generate(inputs)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # partition() is robust where the original split()-unpacking raised
    # ValueError whenever "question: " was missing or appeared more than once.
    _, marker, question = generated_text.partition("question: ")
    return question if marker else generated_text
# Wire each task into its own simple text-in/text-out interface, group them
# into a tabbed app, and serve it.
summarize_tab = gr.Interface(fn=summarize, inputs="text", outputs="text")
question_tab = gr.Interface(fn=generate_question, inputs="text", outputs="text")

app = gr.TabbedInterface(
    [summarize_tab, question_tab],
    ["Summarize an article", "Generate a question"],
)
app.launch()