File size: 4,810 Bytes
854a552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import pdb
import sys

# Indices into a term row (e.g. the 5-slot rows built by
# set_POS_based_on_entities): the word text and its POS tag.
WORD_POS = 1
TAG_POS = 2
# Token that stands in for a whole noun span in each sentence built by gen_sentence.
MASK_TAG = "__entity__"
# Suffix on an input word that marks it as an entity; set_POS_based_on_entities
# strips it and tags the word as a noun.
INPUT_MASK_TAG = ":__entity__"
# Tag given by set_POS_based_on_entities to words not marked as entities.
RESET_POS_TAG='RESET'


# Tags treated as noun/entity material when building masked sentences.
# noun_tags[0] is used as the generic noun tag (hack_for_no_nouns_case,
# set_POS_based_on_entities), so the first element's position matters.
# NOTE(review): values look like Penn Treebank tags -- assumption, confirm.
noun_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','POS','CD']
# Tags whose words get their first character upper-cased by capitalize().
cap_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','PRP']


def detect_masked_positions(terms_arr):
    """Return the plain words plus the masked-sentence data for a tagged sentence.

    Args:
        terms_arr: list of term rows, each holding the word at WORD_POS and
            its tag at TAG_POS. generate_masked_sentences may retag the first
            row in place (via hack_for_no_nouns_case).

    Returns:
        A 3-tuple ``(words, sentence_arr, span_arr)``: ``words`` is the word of
        every term in order; the other two are exactly what
        generate_masked_sentences returns.
    """
    masked_sentences, spans = generate_masked_sentences(terms_arr)
    words = [row[WORD_POS] for row in terms_arr]
    return words, masked_sentences, spans

def generate_masked_sentences(terms_arr):
    """Build one masked sentence per maximal run of noun-tagged terms.

    If no term carries a tag from noun_tags, hack_for_no_nouns_case first
    retags terms_arr[0] in place so at least one run exists.

    Args:
        terms_arr: list of term rows with the word at WORD_POS and the tag at
            TAG_POS.

    Returns:
        ``(sentence_arr, span_arr)``: sentence_arr holds, per noun run, a word
        list in which that run is replaced by MASK_TAG (see gen_sentence);
        span_arr has one entry per term -- 1 inside a noun run, 0 otherwise.
    """
    total = len(terms_arr)
    masked = []
    spans = []
    hack_for_no_nouns_case(terms_arr)
    pos = 0
    while pos < total:
        if terms_arr[pos][TAG_POS] not in noun_tags:
            spans.append(0)
            pos += 1
            continue
        # gen_sentence consumes the whole consecutive noun run starting here
        # and reports how many terms it covered.
        run_len = gen_sentence(masked, terms_arr, pos)
        spans.extend([1] * run_len)
        pos += run_len
    return masked, spans

def hack_for_no_nouns_case(terms_arr):
    """Force a noun into a sentence that has none.

    Workaround for odd inputs (e.g. a lone word like "eg") where no term is
    tagged with anything in noun_tags: the first term is retagged in place as
    noun_tags[0] so masking can proceed. Does nothing for an empty list or
    when some term is already noun-tagged. Returns None.
    """
    has_noun = any(row[TAG_POS] in noun_tags for row in terms_arr)
    if not has_noun and len(terms_arr) > 0:
        terms_arr[0][TAG_POS] = noun_tags[0]


def gen_sentence(sentence_arr,terms_arr,index):
    """Append one masked sentence for the noun run starting at ``index``.

    The run is the consecutive terms, beginning at ``index``, whose tag is in
    noun_tags. The appended sentence is the words before the run, then a
    single MASK_TAG for the whole run, then the words after it.

    Args:
        sentence_arr: list that receives the new sentence (a list of words).
        terms_arr: term rows with the word at WORD_POS and tag at TAG_POS.
        index: position of the first term of the run.

    Returns:
        The run length, i.e. how many terms MASK_TAG replaced.

    Raises:
        AssertionError: (when assertions are enabled) if the term at ``index``
            is not noun-tagged; sentence_arr is then left unchanged.
    """
    words_before = [row[WORD_POS] for row in terms_arr[:index]]
    run_end = index
    while run_end < len(terms_arr) and terms_arr[run_end][TAG_POS] in noun_tags:
        run_end += 1
    skip = run_end - index
    masked = words_before + [MASK_TAG]
    masked.extend(terms_arr[j][WORD_POS] for j in range(run_end, len(terms_arr)))
    assert skip != 0
    sentence_arr.append(masked)
    return skip



def capitalize(terms_arr):
    """Upper-case the first character of each word whose tag is in cap_tags.

    Modifies the rows of terms_arr in place (the word at WORD_POS) and
    returns None. Empty words are left as empty strings.
    """
    for term_tag in terms_arr:
        if term_tag[TAG_POS] in cap_tags:
            word = term_tag[WORD_POS]
            # word[:1] instead of word[0]: an empty word (e.g. a bare
            # ":__entity__" token after its marker is stripped, tagged
            # noun_tags[0] which is also in cap_tags) must not raise IndexError.
            term_tag[WORD_POS] = word[:1].upper() + word[1:]

def set_POS_based_on_entities(sent):
    """Tag each whitespace-separated word of ``sent`` from user entity markers.

    A word ending in INPUT_MASK_TAG is treated as an entity: the trailing
    marker is removed and the word is tagged noun_tags[0]. Every other word
    is kept as-is and tagged RESET_POS_TAG.

    Args:
        sent: input sentence string.

    Returns:
        List of 5-slot term rows (unused slots hold '-'), with the word at
        WORD_POS and the tag at TAG_POS, in sentence order.
    """
    terms_arr = []
    for word in sent.split():
        term_tag = ['-']*5
        if word.endswith(INPUT_MASK_TAG):
            term_tag[TAG_POS] = noun_tags[0]
            # Strip only the matched suffix; str.replace would also delete
            # any other occurrence of the marker inside the word.
            term_tag[WORD_POS] = word[:-len(INPUT_MASK_TAG)]
        else:
            term_tag[TAG_POS] = RESET_POS_TAG
            term_tag[WORD_POS] = word
        terms_arr.append(term_tag)
    return terms_arr

def filter_common_noun_spans(span_arr,masked_sent_arr,terms_arr,common_descs):
    """Drop noun spans that consist only of common words.

    span_arr marks each term 1 (inside a noun span) or 0; masked_sent_arr
    holds one masked sentence per maximal run of 1s, in order. A run is
    filtered out when the lower-cased word of every term in it is contained
    in common_descs: its entries become 0 in the returned span list, its
    sentence is not kept, and a "Filtering common span:" line listing its
    words is printed to stdout.

    Args:
        span_arr: per-term 0/1 span markers (not modified).
        masked_sent_arr: masked sentences, one per run of 1s.
        terms_arr: term rows aligned with span_arr (word at WORD_POS).
        common_descs: collection of lower-case words considered common.

    Returns:
        ``(kept_sentences, new_spans)``.
    """
    new_spans = span_arr.copy()
    kept_sentences = []
    total = len(span_arr)
    run_number = 0
    pos = 0
    while pos < total:
        if span_arr[pos] != 1:
            pos += 1
            continue
        run_start = pos
        while pos < total and span_arr[pos] == 1:
            pos += 1
        run = range(run_start, pos)
        # Evaluate every term of the run (no short-circuit) before deciding.
        common_flags = [terms_arr[k][WORD_POS].lower() in common_descs for k in run]
        if all(common_flags):
            print("Filtering common span: ",end='')
            for k in run:
                print(terms_arr[k][WORD_POS],' ',end='')
                new_spans[k] = 0
            print()
        else:
            kept_sentences.append(masked_sent_arr[run_number])
        # Each run of 1s owns exactly one masked sentence, kept or not.
        run_number += 1
    return kept_sentences, new_spans

def normalize_casing(sent):
    """Keep each word's first character and lower-case the rest.

    Words are split on any whitespace and rejoined with single spaces, so
    leading/trailing whitespace is dropped and runs of whitespace collapse.
    """
    return ' '.join(token[:1] + token[1:].lower() for token in sent.split())