Commit d8d5ce9 · "add index from tokenizer"
Author: sumit · Parent: 81c680b
Files changed: BertForJointParsing.py (+47, -10)

BertForJointParsing.py CHANGED
@@ -186,7 +186,7 @@ class BertForJointParsing(BertPreTrainedModel):
             morph_logits=morph_logits
         )
 
-    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
+    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, detailed_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
         is_single_sentence = isinstance(sentences, str)
         if is_single_sentence:
             sentences = [sentences]
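The only signature change here is the new detailed_ner=False keyword, so existing callers keep working. A minimal usage sketch, assuming the model is loaded with trust_remote_code the way the upstream DictaBERT-joint README does; the checkpoint name below is a placeholder and not part of this commit:

from transformers import AutoModel, AutoTokenizer

# Placeholder checkpoint name; substitute the repo this file actually ships in.
model_name = 'dicta-il/dictabert-joint'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

# per_token_ner attaches a 'ner' field to every token; detailed_ner is the flag
# introduced by this commit and defaults to False.
output = model.predict(['sentence to parse'], tokenizer, per_token_ner=True, detailed_ner=False, output_style='json')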
@@ -234,32 +234,66 @@ class BertForJointParsing(BertPreTrainedModel):
             for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label, offset_mapping)):
                 if per_token_ner:
                     merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
-                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
-
+                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
+
         if output_style in ['ud', 'iahlt_ud']:
             final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
 
         if is_single_sentence:
             final_output = final_output[0]
+
+        words_index = parse_index(inputs['input_ids'], tokenizer)[0]
+        for idx, w in zip(words_index, final_output[0]['tokens']):
+            w['idx'] = idx
+
         return final_output
 
+def parse_index(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
+    # Create input_indices for each input_id, handling word-pieces
+    input_indices = []
+    for batch_idx, ids in enumerate(input_ids):
+        sentence_indices = []
+        current_word_indices = []
+        for idx, id_value in enumerate(ids):
+            # Skip special tokens
+            if id_value in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
+                continue
+
+            token_id = input_ids[batch_idx, idx]
+            token = tokenizer._convert_id_to_token(token_id)
+
+            # If the token is a continuation of a previous word (word-piece), append the index
+            if token.startswith('##'):
+                current_word_indices.append(idx)
+            else:
+                # If there's a current word, add it to sentence indices
+                if current_word_indices:
+                    sentence_indices.append(current_word_indices)
+                current_word_indices = [idx]
+
+        # Add the last word to sentence indices if not empty
+        if current_word_indices:
+            sentence_indices.append(current_word_indices)
+        input_indices.append(sentence_indices)
+    return input_indices
 
 
 def aggregate_ner_tokens(predictions):
     entities = []
     prev = None
-    for word, pred, start, end in predictions:
+    for word, pred, start, end, idx in predictions:
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append([[word], prev, start, end])
+            entities.append([[word], prev, start, end, idx])
         else:
             entities[-1][0].append(word)
             entities[-1][3] = end
+            entities[-1][4].extend(idx)
 
-    return [dict(phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end in entities]
+    return [dict(idx=idx, phrase=' '.join(words), label=label, start=start, end=end) for words, label, start, end, idx in entities]
 
 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
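The new parse_index groups the positions of word-piece tokens back into whole words, skipping [CLS]/[SEP]/[PAD]. A self-contained sketch of the same grouping over a hard-coded token list (the tokens and positions are invented for illustration; the real function derives them from input_ids through the tokenizer):

# Illustrative only: mimic parse_index's grouping on a fake word-piece sequence.
# Position 0 is [CLS] and position 6 is [SEP]; both are skipped like special tokens.
tokens = ['[CLS]', 'un', '##related', 'words', 'go', 'here', '[SEP]']

sentence_indices, current = [], []
for idx, token in enumerate(tokens):
    if token in ('[CLS]', '[SEP]', '[PAD]'):
        continue
    if token.startswith('##'):
        current.append(idx)          # continuation of the previous word
    else:
        if current:
            sentence_indices.append(current)
        current = [idx]              # start of a new word
if current:
    sentence_indices.append(current)

print(sentence_indices)  # [[1, 2], [3], [4], [5]] -- word-piece positions grouped per word

Each inner list holds the input positions of one word, which is what lets zip(words_index, final_output[0]['tokens']) line the index lists up with the word-level tokens that predict emits.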
@@ -276,7 +310,6 @@ def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast
 
 def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
     input_ids = inputs['input_ids']
-
     predictions = torch.argmax(logits, dim=-1)
     batch_ret = []
     for batch_idx in range(len(sentences)):
@@ -295,11 +328,15 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str], offset_mapping):
             # we modify the last token in ret
             # by discarding the original end position and replacing it with the new token's end position
             if token.startswith('##'):
-                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
+                ret[-1] = [ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item()]
                 continue
             # for each token, we append a tuple containing: token, label, start position, end position
-            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()))
-
+            ret.append([token, id2label[predictions[batch_idx, tok_idx].item()], start_pos.item(), end_pos.item()])
+
+    words_index = parse_index(inputs['input_ids'], tokenizer)[0]
+    for idx, w in zip(words_index, batch_ret[0]):
+        w.append(idx)
+
     return batch_ret
 
 def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
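With the tuples replaced by lists, each per-token NER record can have its word-piece index list appended at the end, and aggregate_ner_tokens (changed above) carries those indices into the entity dicts. A rough sketch of the data shapes involved, with invented words, offsets, and labels:

# Hypothetical records as ner_parse_logits would now build them, after the
# parse_index positions have been appended: [token, label, start, end, idx].
parsed = [
    ['Tel',  'B-GPE', 0, 3,  [1]],
    ['Aviv', 'I-GPE', 4, 8,  [2, 3]],   # e.g. merged from 'Av' + '##iv'
    ['is',   'O',     9, 11, [4]],
    ['big',  'O',    12, 15, [5]],
]

# aggregate_ner_tokens(parsed) merges the B-/I- run into a single entity,
# extending its idx list, and would return roughly:
# [{'idx': [1, 2, 3], 'phrase': 'Tel Aviv', 'label': 'GPE', 'start': 0, 'end': 8}]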