Spaces:
Build error
Build error
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import gradio as gr | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
import pickle | |
vectorizer = pickle.load(open("tfidf.pickle", "rb")) | |
# clf = pickle.load(open("classifier.pickle", "rb")) | |
example_context = "ফলস্বরূপ, ১৯৭৯ সালে, সনি এবং ফিলিপস একটি নতুন ডিজিটাল অডিও ডিস্ক ডিজাইন করার জন্য প্রকৌশলীদের একটি যৌথ টাস্ক ফোর্স গঠন করে। ইঞ্জিনিয়ার কিস শুহামার ইমমিনক এবং তোশিতাদা দোই এর নেতৃত্বে, গবেষণাটি লেজার এবং অপটিক্যাল ডিস্ক প্রযুক্তিকে এগিয়ে নিয়ে যায়। এক বছর পরীক্ষা-নিরীক্ষা ও আলোচনার পর টাস্ক ফোর্স রেড বুক সিডি-ডিএ স্ট্যান্ডার্ড তৈরি করে। প্রথম প্রকাশিত হয় ১৯৮০ সালে। আইইসি কর্তৃক ১৯৮৭ সালে আন্তর্জাতিক মান হিসেবে আনুষ্ঠানিকভাবে এই মান গৃহীত হয় এবং ১৯৯৬ সালে বিভিন্ন সংশোধনী মানের অংশ হয়ে ওঠে।'" | |
example_answer = "১৯৮০" | |
def choose_model(model_choice): | |
if model_choice=="mt5-small": | |
return "jannatul17/squad-bn-qgen-mt5-small-v1" | |
elif model_choice=="mt5-base": | |
return "Tahsin-Mayeesha/squad-bn-mt5-base2" | |
else : | |
return "jannatul17/squad-bn-qgen-banglat5-v1" | |
def generate_questions(model_choice,context,answer,numReturnSequences=1,num_beams=None,do_sample=False,top_p=None,top_k=None,temperature=None): | |
model_name = choose_model(model_choice) | |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
text='answer: '+answer + ' context: ' + context | |
text_encoding = tokenizer.encode_plus( | |
text,return_tensors="pt" | |
) | |
model.eval() | |
generated_ids = model.generate( | |
input_ids=text_encoding['input_ids'], | |
attention_mask=text_encoding['attention_mask'], | |
max_length=120, | |
num_beams=num_beams, | |
do_sample=do_sample, | |
top_k = top_k, | |
top_p = top_p, | |
temperature = temperature, | |
num_return_sequences=numReturnSequences | |
) | |
text = [] | |
for id in generated_ids: | |
text.append(tokenizer.decode(id,skip_special_tokens=True,clean_up_tokenization_spaces=True).replace('question: ',' ')) | |
question = " ".join(text) | |
#correctness_pred = clf.predict(vectorizer.transform([question]))[0] | |
#if correctness_pred == 1: | |
# correctness = "Correct" | |
#else : | |
# correctness = "Incorrect" | |
return question | |
demo = gr.Interface(fn=generate_questions, inputs=[gr.Dropdown(label="Model", choices=["mt5-small","mt5-base","banglat5"],value="banglat5"), | |
gr.Textbox(label='Context'), | |
gr.Textbox(label='Answer'), | |
# hyperparameters | |
gr.Slider(1, 3, 1, step=1, label="Num return Sequences"), | |
# beam search | |
gr.Slider(1, 10,value=None, step=1, label="Beam width"), | |
# top-k/top-p | |
gr.Checkbox(label="Do Random Sample",value=False), | |
gr.Slider(0, 50, value=None, step=1, label="Top K"), | |
gr.Slider(0, 1, value=None, label="Top P/Nucleus Sampling"), | |
gr.Slider(0, 1, value=None, label="Temperature") ] , | |
# output | |
outputs=[gr.Textbox(label='Question')], | |
examples=[["banglat5",example_context,example_answer]], | |
cache_examples=False, | |
title="Bangla Question Generation") | |
demo.launch() | |