import re
import torch
from transformers import DonutProcessor
from transformers.utils import add_start_docstrings
from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING


# Inspired by https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    def __init__(self, ngram_size: int, skip_tokens=None):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size
        # Token ids that are exempt from the ban, e.g. structural tokens that legitimately repeat.
        self.skip_tokens = skip_tokens

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        return _no_repeat_ngram_logits(
            input_ids,
            cur_len,
            scores,
            batch_size=num_batch_hypotheses,
            no_repeat_ngram_size=self.ngram_size,
            skip_tokens=self.skip_tokens,
        )

def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_ngram_size=0, skip_tokens=None):
    if no_repeat_ngram_size > 0:
        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
        for batch_idx in range(batch_size):
            if skip_tokens is not None:
                logits[batch_idx, [token for token in banned_tokens[batch_idx] if int(token) not in skip_tokens]] = -float("inf")
            else:
                logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
    return logits

def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search.
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        # Record, for every (n-1)-token prefix seen so far, which tokens have followed it.
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
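
# Worked example with made-up token ids: for no_repeat_ngram_size=3 and a single hypothesis
# [2, 4, 6, 2, 4], the generated trigrams are (2, 4, 6), (4, 6, 2) and (6, 2, 4), so
# generated_ngrams[0] == {(2, 4): [6], (4, 6): [2], (6, 2): [4]}. The last two generated
# tokens form the prefix (2, 4), hence _calc_banned_tokens returns [[6]]: emitting 6 next
# would repeat the trigram (2, 4, 6).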

def get_table_token_ids(processor):
    # Ids of added special tokens starting with "<t" or "</t" (the table-structure tags), meant to be passed as `skip_tokens`.
    return {token_id for token, token_id in processor.tokenizer.get_added_vocab().items() if token.startswith("<t") or token.startswith("</t")}
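

# Usage sketch (an assumption about how this module is wired up, not taken from the file
# itself): the processor would typically be passed to `model.generate` through
# `transformers.LogitsProcessorList`, with the table tokens returned by `get_table_token_ids`
# exempted from the ban, e.g.
#
#     model.generate(
#         pixel_values,
#         decoder_input_ids=decoder_input_ids,
#         logits_processor=LogitsProcessorList(
#             [NoRepeatNGramLogitsProcessor(10, skip_tokens=get_table_token_ids(processor))]
#         ),
#     )
#
# `pixel_values`, `decoder_input_ids`, `processor` and the ngram size of 10 are placeholders.
# The self-test below exercises the processor directly with made-up token ids.
if __name__ == "__main__":
    # The bigram (5, 7) was already generated and the prefix ends in 5, so token 7 is banned
    # unless it is listed in `skip_tokens`.
    input_ids = torch.tensor([[5, 7, 9, 5]])
    scores = torch.zeros(1, 20)

    banned = NoRepeatNGramLogitsProcessor(2)(input_ids, scores.clone())
    assert banned[0, 7] == -float("inf")

    kept = NoRepeatNGramLogitsProcessor(2, skip_tokens={7})(input_ids, scores.clone())
    assert kept[0, 7] == 0.0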