import json
import os
from pathlib import Path

import torch
from azure.cosmos import CosmosClient
from transformers import DistilBertTokenizerFast


class Model:

    def __init__(self) -> None:
        self.endPoint = "https://productdevelopmentstorage.documents.azure.com:443/"
        # Read the account key from the environment instead of committing a
        # live credential to source control.
        self.primaryKey = os.environ["COSMOS_PRIMARY_KEY"]
        self.client = CosmosClient(self.endPoint, self.primaryKey)
        self.tokenizer = None

    def GetData(self, container_name):
        # Fetch items from the given container in the squadstorage database.
        # max_item_count is a per-page hint, not a cap on the total results.
        database = self.client.get_database_client("squadstorage")
        container = database.get_container_client(container_name)
        item_list = list(container.read_all_items(max_item_count=10))
        return item_list

    def ArrangeData(self, container_name):
        # Split the stored SQuAD items into parallel context/question/answer lists.
        squad_dict = self.GetData(container_name)

        contexts = []
        questions = []
        answers = []

        for item in squad_dict:
            contexts.append(item["context"])
            questions.append(item["question"])
            answers.append(item["answers"])

        return contexts, questions, answers

    def add_end_idx(self, answers, contexts):
        # Compute the character index where each answer ends. Some SQuAD gold
        # labels are off by one or two characters, so shift the span until it
        # matches the gold text.
        for answer, context in zip(answers, contexts):
            gold_text = answer['text'][0]
            start_idx = answer['answer_start'][0]
            end_idx = start_idx + len(gold_text)

            if context[start_idx:end_idx] == gold_text:
                answer['answer_end'] = end_idx
            elif context[start_idx - 1:end_idx - 1] == gold_text:
                # The gold label is off by one character. Write the corrected
                # start back into the list so answer_start keeps a consistent
                # shape for add_token_positions.
                answer['answer_start'][0] = start_idx - 1
                answer['answer_end'] = end_idx - 1
            elif context[start_idx - 2:end_idx - 2] == gold_text:
                # The gold label is off by two characters.
                answer['answer_start'][0] = start_idx - 2
                answer['answer_end'] = end_idx - 2

        return answers, contexts

    def Tokenizer(self, train_contexts, train_questions, val_contexts, val_questions):
        # Encode each (context, question) pair, truncating to the model's
        # maximum length and padding to the longest sequence in the batch.
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

        train_encodings = self.tokenizer(train_contexts, train_questions, truncation=True, padding=True)
        val_encodings = self.tokenizer(val_contexts, val_questions, truncation=True, padding=True)

        return train_encodings, val_encodings


    def add_token_positions(self, encodings, answers):
        # Convert character-level answer spans into token positions.
        start_positions = []
        end_positions = []
        for i in range(len(answers)):
            start_positions.append(encodings.char_to_token(i, answers[i]['answer_start'][0]))
            end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))

            # If a position is None, the answer passage was truncated out of
            # the encoding; fall back to the sentinel model_max_length.
            if start_positions[-1] is None:
                start_positions[-1] = self.tokenizer.model_max_length
            if end_positions[-1] is None:
                end_positions[-1] = self.tokenizer.model_max_length

        encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
        return encodings

    # train_contexts, train_questions, train_answers = read_squad('squad/train-v2.0.json')
    # val_contexts, val_questions, val_answers = read_squad('squad/dev-v2.0.json')
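
# The commented-out calls above reference a read_squad() helper that is not
# defined in this file. Below is a minimal sketch of such a loader for local
# SQuAD v2.0 JSON files; since the original implementation is not in this
# repository, treat it as an assumption based on the standard SQuAD layout.
def read_squad(path):
    with open(Path(path), 'rb') as f:
        squad_dict = json.load(f)

    contexts = []
    questions = []
    answers = []
    # SQuAD JSON nests paragraphs inside articles and QA pairs inside paragraphs.
    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                for answer in qa['answers']:
                    contexts.append(context)
                    questions.append(qa['question'])
                    # Normalize to the list-valued layout GetData() returns,
                    # i.e. {'text': [...], 'answer_start': [...]}.
                    answers.append({'text': [answer['text']],
                                    'answer_start': [answer['answer_start']]})

    return contexts, questions, answers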





class SquadDataset(torch.utils.data.Dataset):
    # Thin torch Dataset wrapper over the tokenizer's BatchEncoding.
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)
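

# A minimal end-to-end sketch of how the pieces above compose. The container
# names "train" and "validation" are assumptions; substitute whatever
# containers actually exist in the squadstorage database.
if __name__ == "__main__":
    model = Model()

    train_contexts, train_questions, train_answers = model.ArrangeData("train")
    val_contexts, val_questions, val_answers = model.ArrangeData("validation")

    # Repair the character spans before tokenizing.
    train_answers, train_contexts = model.add_end_idx(train_answers, train_contexts)
    val_answers, val_contexts = model.add_end_idx(val_answers, val_contexts)

    train_encodings, val_encodings = model.Tokenizer(
        train_contexts, train_questions, val_contexts, val_questions
    )

    # Map character spans to token positions, then wrap in Datasets.
    train_encodings = model.add_token_positions(train_encodings, train_answers)
    val_encodings = model.add_token_positions(val_encodings, val_answers)

    train_dataset = SquadDataset(train_encodings)
    val_dataset = SquadDataset(val_encodings)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)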