import torch
from azure.cosmos import CosmosClient
from transformers import DistilBertTokenizerFast
class Model:
    def __init__(self) -> None:
        # NOTE: the Cosmos DB endpoint and key are hardcoded here; in
        # practice they belong in configuration or environment variables.
        self.endPoint = "https://productdevelopmentstorage.documents.azure.com:443/"
        self.primaryKey = "nVds9dPOkPuKu8RyWqigA1DIah4SVZtl1DIM0zDuRKd95an04QC0qv9TQIgrdtgluZo7Z0HXACFQgKgOQEAx1g=="
        self.client = CosmosClient(self.endPoint, self.primaryKey)
        self.tokenizer = None  # set by Tokenizer()
    def GetData(self, container_name):
        # Read every item from the given container in the "squadstorage"
        # database, fetching at most 10 items per page.
        database = self.client.get_database_client("squadstorage")
        container = database.get_container_client(container_name)
        item_list = list(container.read_all_items(max_item_count=10))
        return item_list
    def ArrangeData(self, container_name):
        squad_dict = self.GetData(container_name)
        contexts = []
        questions = []
        answers = []
        for item in squad_dict:
            contexts.append(item["context"])
            questions.append(item["question"])
            answers.append(item["answers"])
        return contexts, questions, answers
    def add_end_idx(self, answers, contexts):
        # SQuAD answers store only a start index; derive the end index and
        # repair gold labels that are off by one or two characters. Update
        # answer_start in place (as a list element) so that
        # add_token_positions can still index it with [0].
        for answer, context in zip(answers, contexts):
            gold_text = answer['text'][0]
            start_idx = answer['answer_start'][0]
            end_idx = start_idx + len(gold_text)
            if context[start_idx:end_idx] == gold_text:
                answer['answer_end'] = end_idx
            elif context[start_idx - 1:end_idx - 1] == gold_text:
                # The gold label is off by one character
                answer['answer_start'][0] = start_idx - 1
                answer['answer_end'] = end_idx - 1
            elif context[start_idx - 2:end_idx - 2] == gold_text:
                # The gold label is off by two characters
                answer['answer_start'][0] = start_idx - 2
                answer['answer_end'] = end_idx - 2
        return answers, contexts
    def Tokenizer(self, train_contexts, train_questions, val_contexts, val_questions):
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        train_encodings = self.tokenizer(train_contexts, train_questions, truncation=True, padding=True)
        val_encodings = self.tokenizer(val_contexts, val_questions, truncation=True, padding=True)
        return train_encodings, val_encodings
    def add_token_positions(self, encodings, answers):
        # Convert character-level answer spans to token positions; assumes
        # Tokenizer() has already been called so self.tokenizer is set.
        start_positions = []
        end_positions = []
        for i in range(len(answers)):
            start_positions.append(encodings.char_to_token(i, answers[i]['answer_start'][0]))
            end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
            # if a position is None, the answer passage has been truncated
            if start_positions[-1] is None:
                start_positions[-1] = self.tokenizer.model_max_length
            if end_positions[-1] is None:
                end_positions[-1] = self.tokenizer.model_max_length
        encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
        return encodings
# train_contexts, train_questions, train_answers = read_squad('squad/train-v2.0.json')
# val_contexts, val_questions, val_answers = read_squad('squad/dev-v2.0.json')
class SquadDataset(torch.utils.data.Dataset):
    # Thin torch Dataset wrapper over the tokenized encodings.
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)
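# A minimal usage sketch, not part of the original module: it shows how the
# pieces above are meant to chain together. The container names "train" and
# "validation" are assumptions; substitute whichever containers actually
# exist in the "squadstorage" database.
if __name__ == "__main__":
    model = Model()

    # Pull raw SQuAD-style records from Cosmos DB.
    train_contexts, train_questions, train_answers = model.ArrangeData("train")
    val_contexts, val_questions, val_answers = model.ArrangeData("validation")

    # Repair answer spans and record character-level end indices.
    train_answers, train_contexts = model.add_end_idx(train_answers, train_contexts)
    val_answers, val_contexts = model.add_end_idx(val_answers, val_contexts)

    # Tokenize and map character spans to token positions.
    train_encodings, val_encodings = model.Tokenizer(
        train_contexts, train_questions, val_contexts, val_questions
    )
    train_encodings = model.add_token_positions(train_encodings, train_answers)
    val_encodings = model.add_token_positions(val_encodings, val_answers)

    # Wrap the encodings for use with a torch DataLoader.
    train_dataset = SquadDataset(train_encodings)
    val_dataset = SquadDataset(val_encodings)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)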