from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging
import floret
from .configuration_stacked import ImpressoConfig
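# Note: TokenClassifierOutput, AutoModel, CrossEntropyLoss, and the typing
# imports above are only referenced by the commented-out legacy implementation
# at the bottom of this file.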

logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return the number of labels per task for a {task: [labels]} mapping."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict
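
# Example of get_info with a hypothetical label_map (not taken from this repo):
#   get_info({"ner": ["O", "B-PER", "I-PER"], "lang": ["de", "fr"]})
#   -> {"ner": 3, "lang": 2}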


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """Thin Hugging Face wrapper around a floret (fastText-style) classifier."""

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Dummy parameter so .parameters() is non-empty and the .device property works.
        self.dummy_param = nn.Parameter(torch.zeros(1))
        # Load the floret model from the path stored in the config.
        self.model_floret = floret.load_model(self.config.filename)
        # Smoke test: run one prediction to confirm the model loaded correctly.
        _labels, _probs = self.model_floret.predict(["this is a text"], k=1)
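    # Note: floret's predict() follows the fastText API (an assumption worth
    # verifying for your build): predict(["text"], k=1) returns a (labels,
    # probabilities) pair with one entry per input string.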
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # floret operates on raw strings, so recover text either by decoding
        # input_ids with a tokenizer passed via kwargs or from a "text" kwarg.
        tokenizer = kwargs.get("tokenizer")
        if input_ids is not None and tokenizer is not None:
            texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        else:
            texts = kwargs.get("text")
        if isinstance(texts, str):
            texts = [texts]

        if texts:
            # floret expects strings, not tensors. Keep the probabilities
            # (assuming the fastText-style (labels, probs) return of predict)
            # so the output can be packed into a tensor for HF compatibility.
            probabilities = []
            for text in texts:
                _labels, probs = self.model_floret.predict([text], k=1)
                probabilities.append(torch.as_tensor(probs[0], dtype=torch.float))
            return torch.stack(probabilities)

        # If no text is available, return a dummy output.
        return torch.zeros((1, 2))  # dummy tensor with shape (batch_size, num_classes)
    def state_dict(self, *args, **kwargs):
        # Return an empty state dict: the floret weights are stored outside PyTorch.
        return {}

    def load_state_dict(self, state_dict, strict=True):
        # Nothing to restore; the floret weights are read from config.filename.
        logger.info("Ignoring state_dict since the model keeps no PyTorch weights.")
def get_floret_model(self):
return self.model_floret
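
    # The helper below follows the additive attention-mask convention of
    # BERT-style models: a (batch, seq_len) mask of 1s/0s is broadcast to
    # (batch, 1, 1, seq_len) with 0.0 for kept positions and -10000.0 for
    # masked ones. Illustrative shapes (not tied to this repo):
    #   mask = torch.tensor([[1, 1, 0]])                       # (1, 3)
    #   ext = model.get_extended_attention_mask(mask, (1, 3))  # (1, 1, 1, 3)
    #   # ext[0, 0, 0] == tensor([-0., -0., -10000.])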
def get_extended_attention_mask(
self, attention_mask, input_shape, device=None, dtype=torch.float
):
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
extended_attention_mask = attention_mask[:, None, None, :]
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
    @property
    def device(self):
        # The dummy parameter gives the module a device even though the actual
        # model lives in floret, outside PyTorch.
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        logger.info("Ignoring pretrained weights and using custom initialization.")
        # Manually create the config from the keyword arguments and pass it to
        # the class; no checkpoint tensors are loaded.
        config = ImpressoConfig(**kwargs)
        model = cls(config)
        return model
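
# Minimal usage sketch (illustrative only; "model.bin" is a placeholder path):
#
#   config = ImpressoConfig(filename="model.bin")
#   model = ExtendedMultitaskModelForTokenClassification(config)
#   labels, probs = model.get_floret_model().predict(["this is a text"], k=1)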


# Earlier BERT-based multitask implementation, left commented out below:
# class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
#
# config_class = ImpressoConfig
# _keys_to_ignore_on_load_missing = [r"position_ids"]
#
# def __init__(self, config):
# super().__init__(config)
# # self.num_token_labels_dict = get_info(config.label_map)
# # self.config = config
# # # print(f"I dont think it arrives here: {self.config}")
# # self.bert = AutoModel.from_pretrained(
# # config.pretrained_config["_name_or_path"], config=config.pretrained_config
# # )
# self.model_floret = floret.load_model(self.config.filename)
# # print(f"Model loaded: {self.model_floret}")
# # if "classifier_dropout" not in config.__dict__:
# # classifier_dropout = 0.1
# # else:
# # classifier_dropout = (
# # config.classifier_dropout
# # if config.classifier_dropout is not None
# # else config.hidden_dropout_prob
# # )
# # self.dropout = nn.Dropout(classifier_dropout)
# #
# # # Additional transformer layers
# # self.transformer_encoder = nn.TransformerEncoder(
# # nn.TransformerEncoderLayer(
# # d_model=config.hidden_size, nhead=config.num_attention_heads
# # ),
# # num_layers=2,
# # )
#
# # For token classification, create a classifier for each task
# # self.token_classifiers = nn.ModuleDict(
# # {
# # task: nn.Linear(config.hidden_size, num_labels)
# # for task, num_labels in self.num_token_labels_dict.items()
# # }
# # )
# #
# # # Initialize weights and apply final processing
# # self.post_init()
#
# def get_floret_model(self):
# return self.model_floret
#
# @classmethod
# def from_pretrained(cls, *args, **kwargs):
# print("Ignoring weights and using custom initialization.")
#
# # Manually create the config
# config = ImpressoConfig()
#
# # Pass the manually created config to the class
# model = cls(config)
# return model
#
# # def forward(
# # self,
# # input_ids: Optional[torch.Tensor] = None,
# # attention_mask: Optional[torch.Tensor] = None,
# # token_type_ids: Optional[torch.Tensor] = None,
# # position_ids: Optional[torch.Tensor] = None,
# # head_mask: Optional[torch.Tensor] = None,
# # inputs_embeds: Optional[torch.Tensor] = None,
# # labels: Optional[torch.Tensor] = None,
# # token_labels: Optional[dict] = None,
# # output_attentions: Optional[bool] = None,
# # output_hidden_states: Optional[bool] = None,
# # return_dict: Optional[bool] = None,
# # ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
# # r"""
# # token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
# # Labels for computing the token classification loss. Keys should match the tasks.
# # """
# # return_dict = (
# # return_dict if return_dict is not None else self.config.use_return_dict
# # )
# #
# # bert_kwargs = {
# # "input_ids": input_ids,
# # "attention_mask": attention_mask,
# # "token_type_ids": token_type_ids,
# # "position_ids": position_ids,
# # "head_mask": head_mask,
# # "inputs_embeds": inputs_embeds,
# # "output_attentions": output_attentions,
# # "output_hidden_states": output_hidden_states,
# # "return_dict": return_dict,
# # }
# #
# # if any(
# # keyword in self.config.name_or_path.lower()
# # for keyword in ["llama", "deberta"]
# # ):
# # bert_kwargs.pop("token_type_ids")
# # bert_kwargs.pop("head_mask")
# #
# # outputs = self.bert(**bert_kwargs)
# #
# # # For token classification
# # token_output = outputs[0]
# # token_output = self.dropout(token_output)
# #
# # # Pass through additional transformer layers
# # token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
# # 0, 1
# # )
# #
# # # Collect the logits and compute the loss for each task
# # task_logits = {}
# # total_loss = 0
# # for task, classifier in self.token_classifiers.items():
# # logits = classifier(token_output)
# # task_logits[task] = logits
# # if token_labels and task in token_labels:
# # loss_fct = CrossEntropyLoss()
# # loss = loss_fct(
# # logits.view(-1, self.num_token_labels_dict[task]),
# # token_labels[task].view(-1),
# # )
# # total_loss += loss
# #
# # if not return_dict:
# # output = (task_logits,) + outputs[2:]
# # return ((total_loss,) + output) if total_loss != 0 else output
# # print(f"Is there anobidy coming here?")
# # return TokenClassifierOutput(
# # loss=total_loss,
# # logits=task_logits,
# # hidden_states=outputs.hidden_states,
# # attentions=outputs.attentions,
# # )