# NOTE(review): the lines that previously appeared here ("Spaces:", stray tab
# runs, and "Runtime error" banners) were artifacts of the page/table the code
# was extracted from, not Python code; they are removed so the module parses.
	| from typing import Optional | |
| import json | |
| from argparse import Namespace | |
| from pathlib import Path | |
| from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer | |
def get_markers_for_model(is_t5_model: bool) -> Namespace:
    """Return the special-token constants used to (de)serialize QASRL sequences.

    T5 checkpoints reuse their built-in ``<extra_id_*>`` sentinel tokens; other
    architectures get dedicated, human-readable marker strings.
    """
    if is_t5_model:
        # T5 models ship with 100 reserved special tokens by default.
        token_map = {
            "separator_input_question_predicate": "<extra_id_1>",
            "separator_output_answers": "<extra_id_3>",
            "separator_output_questions": "<extra_id_5>",  # if using only questions
            "separator_output_question_answer": "<extra_id_7>",
            "separator_output_pairs": "<extra_id_9>",
            "predicate_generic_marker": "<extra_id_10>",
            "predicate_verb_marker": "<extra_id_11>",
            "predicate_nominalization_marker": "<extra_id_12>",
        }
    else:
        token_map = {
            "separator_input_question_predicate": "<question_predicate_sep>",
            "separator_output_answers": "<answers_sep>",
            "separator_output_questions": "<question_sep>",  # if using only questions
            "separator_output_question_answer": "<question_answer_sep>",
            "separator_output_pairs": "<qa_pairs_sep>",
            "predicate_generic_marker": "<predicate_marker>",
            "predicate_verb_marker": "<verbal_predicate_marker>",
            "predicate_nominalization_marker": "<nominalization_predicate_marker>",
        }
    return Namespace(**token_map)
def load_trained_model(name_or_path):
    """Load a seq2seq model + tokenizer, restoring saved preprocessing kwargs.

    The preprocessing-kwargs JSON is looked up either on the HF hub (for
    ``kleinay/*`` repos) or inside a local model directory, then merged into
    ``model.config`` and kept as ``model.config.preprocessing_kwargs``.

    Returns:
        (model, tokenizer) tuple.
    """
    import huggingface_hub as HFhub
    tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
    # Load preprocessing_kwargs from the model repo on HF hub, or from the local model directory
    kwargs_filename = None
    if name_or_path.startswith("kleinay/"):
        # NOTE: the supported huggingface_hub version lacks `list_repo_files`,
        # so we download optimistically instead of checking existence first.
        kwargs_filename = HFhub.hf_hub_download(repo_id=name_or_path, filename="preprocessing_kwargs.json")
    elif Path(name_or_path).is_dir() and (Path(name_or_path) / "experiment_kwargs.json").exists():
        kwargs_filename = Path(name_or_path) / "experiment_kwargs.json"
    if kwargs_filename:
        # `with` closes the file handle (the original `json.load(open(...))` leaked it)
        with open(kwargs_filename) as kwargs_file:
            preprocessing_kwargs = json.load(kwargs_file)
        # Integrate into model.config (for decoding args, e.g. "num_beams"),
        # and save also as a standalone Namespace for preprocessing.
        model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
        model.config.update(preprocessing_kwargs)
    return model, tokenizer
class QASRL_Pipeline(Text2TextGenerationPipeline):
    """Text2Text pipeline that generates QASRL question-answer pairs for a marked predicate.

    Input sentences must contain a predicate-marker token (default ``"<predicate>"``)
    immediately before the target predicate word, e.g.::

        pipe("Luke 's <predicate> research about sea animals .", verb_form="research",
             predicate_type="nominal")
    """

    def __init__(self, model_repo: str, **kwargs):
        """Load the trained model/tokenizer from `model_repo` and configure decoding."""
        model, tokenizer = load_trained_model(model_repo)
        super().__init__(model, tokenizer, framework="pt")
        self.is_t5_model = "t5" in model.config.model_type
        self.special_tokens = get_markers_for_model(self.is_t5_model)
        self.data_args = model.config.preprocessing_kwargs
        # Backward compatibility - default keyword values implemented in `run_summarization`,
        # thus not saved in `preprocessing_kwargs`.
        if "predicate_marker_type" not in vars(self.data_args):
            self.data_args.predicate_marker_type = "generic"
        if "use_bilateral_predicate_marker" not in vars(self.data_args):
            self.data_args.use_bilateral_predicate_marker = True
        if "append_verb_form" not in vars(self.data_args):
            self.data_args.append_verb_form = True
        self._update_config(**kwargs)

    def _update_config(self, **kwargs):
        """Update self.model.config with initialization parameters and necessary defaults."""
        # Defaults that always override model.config, but can be overridden by __init__ kwargs.
        kwargs["max_length"] = kwargs.get("max_length", 80)
        # Override model.config with kwargs.
        for k, v in kwargs.items():
            self.model.config.__dict__[k] = v

    def _sanitize_parameters(self, **kwargs):
        """Route caller kwargs into the (preprocess, forward, postprocess) stages."""
        preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
        if "predicate_marker" in kwargs:
            preprocess_kwargs["predicate_marker"] = kwargs["predicate_marker"]
        if "predicate_type" in kwargs:
            preprocess_kwargs["predicate_type"] = kwargs["predicate_type"]
        if "verb_form" in kwargs:
            preprocess_kwargs["verb_form"] = kwargs["verb_form"]
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
        """Apply string-level preprocessing, then delegate tokenization to the parent class."""
        # Here, inputs is a string or a list of strings; apply string preprocessing.
        if isinstance(inputs, str):
            processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
        elif hasattr(inputs, "__iter__"):
            processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form)
                                for s in inputs]
        else:
            raise ValueError("inputs must be str or Iterable[str]")
        # Now pass to super.preprocess for tokenization.
        return super().preprocess(processed_inputs)

    def _preprocess_string(self, seq: str, predicate_marker: str,
                           predicate_type: Optional[str], verb_form: Optional[str]) -> str:
        """Rewrite `seq` into the model's input format: replace the user-facing
        predicate marker with the model's special marker(s), optionally append
        the verb form, and prepend the source prefix (for T5 models)."""
        sent_tokens = seq.split(" ")
        assert predicate_marker in sent_tokens, \
            f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
        predicate_idx = sent_tokens.index(predicate_marker)
        sent_tokens.remove(predicate_marker)
        sentence_before_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx)])
        predicate = sent_tokens[predicate_idx]
        sentence_after_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx + 1, len(sent_tokens))])
        if self.data_args.predicate_marker_type == "generic":
            predicate_marker = self.special_tokens.predicate_generic_marker
        elif self.data_args.predicate_marker_type == "pred_type":
            # Special marker per predicate type.
            assert predicate_type is not None, \
                "For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it"
            assert predicate_type in ("verbal", "nominal"), \
                f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'"
            predicate_marker = {"verbal": self.special_tokens.predicate_verb_marker,
                                "nominal": self.special_tokens.predicate_nominalization_marker
                                }[predicate_type]
        if self.data_args.use_bilateral_predicate_marker:
            seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
        else:
            seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"
        # Embed also verb_form.
        if self.data_args.append_verb_form and verb_form is None:
            raise ValueError(f"For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
        elif self.data_args.append_verb_form:
            seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
        else:
            seq = f"{seq} "
        # Append source prefix (for t5 models).
        prefix = self._get_source_prefix(predicate_type)
        return prefix + seq

    def _get_source_prefix(self, predicate_type: Optional[str]) -> str:
        """Return the task prefix to prepend to the input (T5-style models only)."""
        if not self.is_t5_model or self.data_args.source_prefix is None:
            return ''
        if not self.data_args.source_prefix.startswith("<"):  # Regular prefix - not dependent on input row x
            return self.data_args.source_prefix
        if self.data_args.source_prefix == "<predicate-type>":
            if predicate_type is None:
                raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
            return f"Generate QAs for {predicate_type} QASRL: "
        # Fix: the original implicitly returned None here for any other "<...>"
        # prefix, which then raised TypeError at `prefix + seq`; fail loudly instead.
        raise ValueError(f"unsupported source_prefix: '{self.data_args.source_prefix}'")

    def _forward(self, *args, **kwargs):
        """Delegate generation to the parent pipeline."""
        outputs = super()._forward(*args, **kwargs)
        return outputs

    def _strip_wrapping_tokens(self, seq: str) -> str:
        """Remove pad/eos tokens (as whole tokens) and whitespace from both ends.

        Fix: the original used `str.strip(pad_token)/strip(eos_token)`, which
        strips a *character set* — e.g. a trailing "animals</s>" lost its final
        's' because 's' is in the set {'<','/','s','>'}.
        """
        special = (self.tokenizer.pad_token, self.tokenizer.eos_token)
        seq = seq.strip()
        changed = True
        while changed:
            changed = False
            for tok in special:
                if tok and seq.startswith(tok):
                    seq = seq[len(tok):].lstrip()
                    changed = True
                if tok and seq.endswith(tok):
                    seq = seq[:-len(tok)].rstrip()
                    changed = True
        return seq

    def postprocess(self, model_outputs):
        """Decode generated ids and parse them into a list of QA dicts."""
        output_seq = self.tokenizer.decode(
            model_outputs["output_ids"].squeeze(),
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        output_seq = self._strip_wrapping_tokens(output_seq)
        qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
        qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
        return {"generated_text": output_seq,
                "QAs": qas}

    def _postrocess_qa(self, seq: str) -> Optional[dict]:
        """Parse one generated QA sub-sequence into {"question": ..., "answers": [...]}.

        Returns None when the sequence lacks the question/answer separator.
        (Annotation fixed: the original declared `-> str` but returned dict/None.)
        """
        # Split question and answers.
        if self.special_tokens.separator_output_question_answer in seq:
            question, answer = seq.split(self.special_tokens.separator_output_question_answer)[:2]
        else:
            print("invalid format: no separator between question and answer found...")
            return None
            # question, answer = seq, ''  # Or: backoff to only question
        # Skip "_" slots in questions.
        question = ' '.join(t for t in question.split(' ') if t != '_')
        answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
        return {"question": question, "answers": answers}
if __name__ == "__main__":
    # Smoke-test the pipeline on nominal and verbal predicates.
    pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
    results = [
        pipe("The student was interested in Luke 's <predicate> research about sea animals .",
             verb_form="research", predicate_type="nominal"),
        pipe(["The doctor was interested in Luke 's <predicate> treatment .",
              "The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."],
             verb_form="treat", predicate_type="nominal", num_beams=10),
        pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .",
             verb_form="develop", predicate_type="verbal"),
    ]
    for res in results:
        print(res)