|
import torch |
|
from transformers import AutoTokenizer |
|
|
|
tokenizer = AutoTokenizer.from_pretrained('ku-nlp/deberta-v2-base-japanese') |
|
model=torch.load('C:\\[.pth modelのあるディレクトリ]\\My_deberta_model_squad.pth') |
|
|
|
text={ |
|
'context':'私の名前はEIMIです。好きな食べ物は苺です。 趣味は皆さんと会話することです。', |
|
'question' :'好きな食べ物は何ですか' |
|
} |
|
|
|
input_ids=tokenizer.encode(text['question'],text['context']) |
|
con=tokenizer.encode(text['question']) |
|
output= model(torch.tensor([input_ids])) |
|
prediction = tokenizer.decode(input_ids[torch.argmax(output.start_logits): torch.argmax(output.end_logits)]) |
|
prediction=prediction.replace('</s>','') |
|
print(prediction) |
|
|