Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload qasrl_model_pipeline.py
Browse files- qasrl_model_pipeline.py +183 -0
    	
        qasrl_model_pipeline.py
    ADDED
    
    | @@ -0,0 +1,183 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Optional
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            from argparse import Namespace
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
            from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            def get_markers_for_model(is_t5_model: bool) -> Namespace:
                """Return the special-token strings that delimit QASRL model inputs and outputs.

                Args:
                    is_t5_model: If true, use T5 sentinel tokens (``<extra_id_N>``);
                        otherwise use dedicated, descriptively named tokens.

                Returns:
                    A ``Namespace`` with these string attributes:
                    ``separator_input_question_predicate``, ``separator_output_answers``,
                    ``separator_output_questions``, ``separator_output_question_answer``,
                    ``separator_output_pairs``, ``predicate_generic_marker``,
                    ``predicate_verb_marker`` and ``predicate_nominalization_marker``.
                """
                if is_t5_model:
                    # T5 models have 100 special (sentinel) tokens by default, so reuse them.
                    markers = {
                        "separator_input_question_predicate": "<extra_id_1>",
                        "separator_output_answers": "<extra_id_3>",
                        "separator_output_questions": "<extra_id_5>",  # if using only questions
                        "separator_output_question_answer": "<extra_id_7>",
                        "separator_output_pairs": "<extra_id_9>",
                        "predicate_generic_marker": "<extra_id_10>",
                        "predicate_verb_marker": "<extra_id_11>",
                        "predicate_nominalization_marker": "<extra_id_12>",
                    }
                else:
                    markers = {
                        "separator_input_question_predicate": "<question_predicate_sep>",
                        "separator_output_answers": "<answers_sep>",
                        "separator_output_questions": "<question_sep>",  # if using only questions
                        "separator_output_question_answer": "<question_answer_sep>",
                        "separator_output_pairs": "<qa_pairs_sep>",
                        "predicate_generic_marker": "<predicate_marker>",
                        "predicate_verb_marker": "<verbal_predicate_marker>",
                        "predicate_nominalization_marker": "<nominalization_predicate_marker>",
                    }
                return Namespace(**markers)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            def load_trained_model(name_or_path):
                """Load a seq2seq model and its tokenizer, attaching saved preprocessing kwargs.

                The kwargs JSON is looked up in one of two places:

                * ``preprocessing_kwargs.json`` downloaded from the Hugging Face hub, when
                  ``name_or_path`` starts with ``"kleinay/"``;
                * ``experiment_kwargs.json`` inside ``name_or_path``, when that is a local
                  directory containing the file.

                When a kwargs file is found, its content is stored both as
                ``model.config.preprocessing_kwargs`` (a ``Namespace``) and merged into
                ``model.config`` itself (for decoding args, e.g. "num_beams"). When none is
                found, ``model.config`` is left without a ``preprocessing_kwargs`` attribute.

                Args:
                    name_or_path: Hub repository id or local model directory (``str`` or
                        path-like).

                Returns:
                    A ``(model, tokenizer)`` tuple.
                """
                import huggingface_hub as HFhub

                tokenizer = AutoTokenizer.from_pretrained(name_or_path)
                model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)

                # Normalise once so that `Path` inputs work with the string checks below.
                location = str(name_or_path)

                # Load preprocessing kwargs from the model repo on the HF hub, or from the local model directory.
                kwargs_filename = None
                # NOTE: ideally we would also check that 'preprocessing_kwargs.json' is in
                # HFhub.list_repo_files(...), but the supported version of HFhub doesn't provide list_repo_files.
                if location.startswith("kleinay/"):
                    kwargs_filename = HFhub.hf_hub_download(repo_id=location, filename="preprocessing_kwargs.json")
                elif Path(location).is_dir() and (Path(location) / "experiment_kwargs.json").exists():
                    kwargs_filename = Path(location) / "experiment_kwargs.json"

                if kwargs_filename:
                    # Context manager so the file handle is closed deterministically.
                    with open(kwargs_filename, encoding="utf-8") as kwargs_file:
                        preprocessing_kwargs = json.load(kwargs_file)
                    # Integrate into model.config (for decoding args, e.g. "num_beams"),
                    # and also keep a standalone object for preprocessing.
                    model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
                    model.config.update(preprocessing_kwargs)
                return model, tokenizer
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            class QASRL_Pipeline(Text2TextGenerationPipeline):
                """Text-to-text pipeline that generates QASRL question-answer pairs for a marked predicate.

                Input sentences are space-tokenized strings in which a marker token
                (``"<predicate>"`` by default) directly precedes the target predicate word.
                Each result is a dict with the raw ``"generated_text"`` and the parsed ``"QAs"``.
                """

                # Call-time keyword arguments consumed by `preprocess`; every other
                # call-time keyword is treated as a generation argument.
                _PREPROCESS_PARAMS = ("predicate_marker", "predicate_type", "verb_form")

                def __init__(self, model_repo: str, **kwargs):
                    """Load the model and tokenizer from `model_repo` and set up preprocessing options.

                    Args:
                        model_repo: Hub repository id or local model directory.
                        **kwargs: Values written into ``self.model.config`` (e.g. decoding
                            defaults such as ``max_length``).
                    """
                    model, tokenizer = load_trained_model(model_repo)
                    super().__init__(model, tokenizer, framework="pt")
                    self.is_t5_model = "t5" in model.config.model_type
                    self.special_tokens = get_markers_for_model(self.is_t5_model)
                    # `preprocessing_kwargs` exists only when a kwargs file was found with the
                    # model; fall back to an empty namespace so the defaults below apply.
                    self.data_args = getattr(model.config, "preprocessing_kwargs", None)
                    if self.data_args is None:
                        self.data_args = Namespace()
                    # Backward compatibility - default keyword values were implemented in
                    # `run_summarization`, thus not saved in `preprocessing_kwargs`.
                    backward_compatible_defaults = (
                        ("predicate_marker_type", "generic"),
                        ("use_bilateral_predicate_marker", True),
                        ("append_verb_form", True),
                    )
                    for name, default in backward_compatible_defaults:
                        if name not in vars(self.data_args):
                            setattr(self.data_args, name, default)
                    self._update_config(**kwargs)

                def _update_config(self, **kwargs):
                    """Update self.model.config with initialization parameters and necessary defaults."""
                    # Default values that always override model.config, but can be overridden by __init__ kwargs.
                    kwargs.setdefault("max_length", 80)
                    # Override model.config with kwargs.
                    for k, v in kwargs.items():
                        self.model.config.__dict__[k] = v

                def _sanitize_parameters(self, **kwargs):
                    """Split call-time keyword arguments into preprocess / forward / postprocess dicts.

                    ``predicate_marker``, ``predicate_type`` and ``verb_form`` go to `preprocess`.
                    Any other keyword (e.g. ``num_beams``) is forwarded to `_forward`, so that
                    call-time generation arguments are not silently dropped. `postprocess`
                    takes no extra parameters.
                    """
                    preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
                    for name, value in kwargs.items():
                        if name in self._PREPROCESS_PARAMS:
                            preprocess_kwargs[name] = value
                        else:
                            forward_kwargs[name] = value
                    return preprocess_kwargs, forward_kwargs, postprocess_kwargs

                def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
                    """Format the raw sentence(s) for the model, then tokenize via the parent pipeline.

                    Args:
                        inputs: A string or an iterable of strings.
                        predicate_marker: Token that precedes the predicate word in the input.
                        predicate_type: ``"verbal"`` or ``"nominal"``, when the model requires it.
                        verb_form: Verb form of the predicate, when the model requires it.

                    Raises:
                        ValueError: If `inputs` is neither a string nor an iterable.
                    """
                    # Here, inputs is a string or list of strings; apply string preprocessing.
                    if isinstance(inputs, str):
                        processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
                    elif hasattr(inputs, "__iter__"):
                        processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
                    else:
                        raise ValueError("inputs must be str or Iterable[str]")
                    # Now pass to super.preprocess for tokenization.
                    return super().preprocess(processed_inputs)

                def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
                    """Rewrite one marked sentence into the input format the model was trained on.

                    The user-facing marker is replaced by the model's marker token (optionally on
                    both sides of the predicate), the verb form is appended when required, and the
                    source prefix is prepended.

                    Raises:
                        AssertionError: If the marker is missing or not followed by a word, or
                            if `predicate_type` is missing/invalid when the model needs it.
                        ValueError: If the model needs `verb_form` and none was given.
                    """
                    # Validation errors are raised explicitly (not via `assert`) so they survive
                    # `python -O`; AssertionError is kept as the type for backward compatibility.
                    missing_marker_message = f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
                    sent_tokens = seq.split(" ")
                    if predicate_marker not in sent_tokens:
                        raise AssertionError(missing_marker_message)
                    predicate_idx = sent_tokens.index(predicate_marker)
                    sent_tokens.remove(predicate_marker)
                    if predicate_idx >= len(sent_tokens):
                        # The marker was the last token, so there is no predicate word after it.
                        raise AssertionError(missing_marker_message)
                    sentence_before_predicate = " ".join(sent_tokens[:predicate_idx])
                    predicate = sent_tokens[predicate_idx]
                    sentence_after_predicate = " ".join(sent_tokens[predicate_idx + 1:])

                    if self.data_args.predicate_marker_type == "generic":
                        predicate_marker = self.special_tokens.predicate_generic_marker
                    # In case we want a special marker for each predicate type:
                    elif self.data_args.predicate_marker_type == "pred_type":
                        if predicate_type is None:
                            raise AssertionError("For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it")
                        if predicate_type not in ("verbal", "nominal"):
                            raise AssertionError(f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'")
                        predicate_marker = {"verbal": self.special_tokens.predicate_verb_marker,
                                            "nominal": self.special_tokens.predicate_nominalization_marker
                                            }[predicate_type]

                    if self.data_args.use_bilateral_predicate_marker:
                        seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
                    else:
                        seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"

                    # Embed also verb_form.
                    if self.data_args.append_verb_form and verb_form is None:
                        raise ValueError(f"For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
                    elif self.data_args.append_verb_form:
                        seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
                    else:
                        seq = f"{seq} "

                    # Prepend source prefix (for t5 models).
                    prefix = self._get_source_prefix(predicate_type)

                    return prefix + seq

                def _get_source_prefix(self, predicate_type: Optional[str]):
                    """Return the text to prepend to the model input (empty unless a T5 prefix is configured).

                    Raises:
                        ValueError: If the configured prefix is the ``"<predicate-type>"`` template
                            but `predicate_type` is None, or if it is an unrecognised ``<...>``
                            template.
                    """
                    # `source_prefix` may be absent when no preprocessing kwargs were saved with the model.
                    source_prefix = getattr(self.data_args, "source_prefix", None)
                    if not self.is_t5_model or source_prefix is None:
                        return ''
                    if not source_prefix.startswith("<"):  # Regular prefix - not dependent on input row x
                        return source_prefix
                    if source_prefix == "<predicate-type>":
                        if predicate_type is None:
                            raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
                        return f"Generate QAs for {predicate_type} QASRL: "
                    # Previously this fell through and returned None, which crashed later with
                    # an opaque TypeError on `prefix + seq`; fail here with a clear message.
                    raise ValueError(f"Unsupported source_prefix template: '{source_prefix}'")

                def _forward(self, *args, **kwargs):
                    """Run generation by delegating unchanged to the parent pipeline."""
                    outputs = super()._forward(*args, **kwargs)
                    return outputs

                def postprocess(self, model_outputs):
                    """Decode the generated ids and parse them into question-answer dicts.

                    Returns:
                        ``{"generated_text": <decoded string>, "QAs": [<qa dict or None>, ...]}``,
                        where an entry is None if a generated pair lacks the question/answer
                        separator.
                    """
                    output_seq = self.tokenizer.decode(
                        model_outputs["output_ids"].squeeze(),
                        skip_special_tokens=False,
                        clean_up_tokenization_spaces=False,
                    )
                    output_seq = self._strip_boundary_tokens(output_seq)
                    qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
                    qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
                    return {"generated_text": output_seq,
                            "QAs": qas}

                def _strip_boundary_tokens(self, text: str) -> str:
                    """Remove whole pad/eos/bos tokens and whitespace from both ends of `text`.

                    `str.strip(token)` must not be used for this: it strips any of the token's
                    *characters*, so e.g. stripping "</s>" also eats a trailing "s" of real text.
                    """
                    boundary_tokens = [
                        token
                        for token in (
                            self.tokenizer.pad_token,
                            self.tokenizer.eos_token,
                            # bos is included because the old character-set strip of "</s>"
                            # happened to remove a leading "<s>" as well.
                            self.tokenizer.special_tokens_map.get("bos_token"),
                        )
                        if token
                    ]
                    text = text.strip()
                    changed = True
                    while changed:
                        changed = False
                        for token in boundary_tokens:
                            if text.startswith(token):
                                text = text[len(token):].lstrip()
                                changed = True
                            if text.endswith(token):
                                text = text[:-len(token)].rstrip()
                                changed = True
                    return text

                def _postrocess_qa(self, seq: str) -> Optional[dict]:
                    """Parse one generated question-answer subsequence.

                    Returns:
                        ``{"question": str, "answers": list of str}``, or None (after printing
                        a notice) when the question/answer separator is missing.
                    """
                    qa_separator = self.special_tokens.separator_output_question_answer
                    # Split question and answers.
                    if qa_separator not in seq:
                        print("invalid format: no separator between question and answer found...")
                        return None
                    question, answer = seq.split(qa_separator)[:2]
                    # Skip "_" slots in questions.
                    question = ' '.join(t for t in question.split(' ') if t != '_')
                    answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
                    return {"question": question, "answers": answers}
         | 
| 172 | 
            +
                
         | 
| 173 | 
            +
                
         | 
| 174 | 
            +
            if __name__ == "__main__":
                # Demo: load the baseline QANom model and run it on a few marked sentences.
                pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
                # Single sentence with a nominal predicate.
                res1 = pipe("The student was interested in Luke 's <predicate> research about sea animals .", verb_form="research", predicate_type="nominal")
                # A list of sentences sharing the same verb form and predicate type.
                res2 = pipe(["The doctor was interested in Luke 's <predicate> treatment .",
                             "The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."], verb_form="treat", predicate_type="nominal", num_beams=10)
                # Single sentence with a verbal predicate.
                res3 = pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .", verb_form="develop", predicate_type="verbal")
                print(res1)
                print(res2)
                print(res3)
         | 
| 183 | 
            +
                
         |