import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel

import floret

from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)
def get_info(label_map):
    # Number of labels per task, keyed by task name.
    return {task: len(labels) for task, labels in label_map.items()}
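# Example with a hypothetical label map (names are illustrative):
#   get_info({"ner": ["O", "B-PER", "I-PER"], "lang": ["fr", "de"]})
#   -> {"ner": 3, "lang": 2}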
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Dummy parameter so the module exposes at least one parameter,
        # which keeps .device and .to() working on a weight-free model.
        self.dummy_param = nn.Parameter(torch.zeros(1))
        # Load the floret model from the path stored in the config.
        self.model_floret = floret.load_model(self.config.filename)
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # floret operates on raw strings, not token-id tensors, so recover
        # the text: decode input_ids with a provided tokenizer, or fall back
        # to a plain "text" keyword argument.
        tokenizer = kwargs.get("tokenizer")
        if input_ids is not None and tokenizer is not None:
            texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        else:
            texts = kwargs.get("text")

        if texts:
            # predict() returns label strings and their probabilities; only
            # the probabilities are tensor-convertible for HF compatibility.
            labels, probabilities = self.model_floret.predict(texts, k=1)
            return torch.tensor(probabilities)

        # No text available: return a dummy (batch_size, num_classes) tensor.
        return torch.zeros((1, 2))
    def state_dict(self, *args, **kwargs):
        # The floret model lives outside of PyTorch, so there are no
        # weights to serialize.
        return {}

    def load_state_dict(self, state_dict, strict=True):
        # Nothing to load; the floret model comes from config.filename.
        logger.info("Ignoring state_dict since the model has no parameters.")
    def get_floret_model(self):
        return self.model_floret
    def get_extended_attention_mask(
        self, attention_mask, input_shape, device=None, dtype=torch.float
    ):
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        # Broadcast to (batch, 1, 1, seq_len) and map padding positions to
        # large negative values, e.g. [1, 1, 0] -> [0.0, 0.0, -10000.0],
        # so they are suppressed by the attention softmax.
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
    @property
    def device(self):
        # Resolve the device from the dummy parameter.
        return next(self.parameters()).device
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Pretrained weights are intentionally ignored: the model is fully
        # defined by its config, which points at the floret file to load.
        logger.info("Ignoring weights and using custom initialization.")
        config = ImpressoConfig(**kwargs)
        return cls(config)
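
# Minimal usage sketch. Assumptions (not from the original source): that
# ImpressoConfig accepts a `filename` kwarg and that "model.bin" is a valid
# floret binary; both names are illustrative.
if __name__ == "__main__":
    model = ExtendedMultitaskModelForTokenClassification.from_pretrained(
        filename="model.bin"
    )
    floret_model = model.get_floret_model()
    # floret predicts over raw strings; k=1 keeps only the top label.
    labels, probabilities = floret_model.predict(["this is a text"], k=1)
    print(labels, probabilities)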