dappyx committed on
Commit 4297696
1 Parent(s): f7bc05a

Create app.py

Files changed (1)
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
+ from transformers import DistilBertForQuestionAnswering, DistilBertConfig, DistilBertTokenizerFast
+ import torch
+
+ model = DistilBertForQuestionAnswering(DistilBertConfig.from_pretrained('distilbert/distilbert-base-multilingual-cased')).to("cpu").eval()
+ st_dict = torch.load("save/best_f1/checkpoint/QazDistilBERT.pt", map_location="cpu")
+ model.load_state_dict(st_dict)
+ tokenizer = DistilBertTokenizerFast.from_pretrained("dappyx/QazDistilbertFast-tokenizerV3")
+
+ import gradio as gr
+
+ def qa_pipeline(text, question):
+     inputs = tokenizer(question, text, return_tensors="pt")
+     input_ids = inputs['input_ids'].to("cpu")
+     attention_mask = inputs['attention_mask'].to("cpu")
+     outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+
+     start_index = torch.argmax(outputs.start_logits, dim=-1).item()
+     end_index = torch.argmax(outputs.end_logits, dim=-1).item()
+
+     predict_answer_tokens = inputs.input_ids[0, start_index : end_index + 1]
+     return tokenizer.decode(predict_answer_tokens)
+
+ def answer_question(context, question):
+     result = qa_pipeline(context, question)
+     return result
+
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=answer_question,
+     inputs=[
+         gr.Textbox(lines=10, label="Context"),
+         gr.Textbox(lines=2, label="Question")
+     ],
+     outputs="text",
+     title="Question Answering Model",
+     description="Enter a context and ask a question to get an answer."
+ )
+
+ # Launch the interface
+ iface.launch()