Mizuiro-sakura's picture
Update deberta_squad.py
981f30e
raw
history blame contribute delete
911 Bytes
import torch
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('ku-nlp/deberta-v2-base-japanese')
model=torch.load('C:\\[.pth modelのあるディレクトリ]\\My_deberta_model_squad.pth') # 学習済みモデルの読み込み
text={
'context':'私の名前はEIMIです。好きな食べ物は苺です。 趣味は皆さんと会話することです。',
'question' :'好きな食べ物は何ですか'
}
input_ids=tokenizer.encode(text['question'],text['context']) # tokenizerで形態素解析しつつコードに変換する
con=tokenizer.encode(text['question'])
output= model(torch.tensor([input_ids])) # 学習済みモデルを用いて解析
prediction = tokenizer.decode(input_ids[torch.argmax(output.start_logits): torch.argmax(output.end_logits)]) # 答えに該当する部分を抜き取る
prediction=prediction.replace('</s>','')
print(prediction)