from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os
import floret
from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


def get_info(label_map):
    """Map each task in ``label_map`` to its number of labels."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict


# class MyCustomModel:
#     def __init__(self):
#         # Custom initialization
#         pass
#
#     @classmethod
#     def from_pretrained(cls, *args, **kwargs):
#         print("Ignoring weights and using custom initialization.")
#         return cls()


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """PreTrainedModel wrapper around a floret model.

    The actual weights live in the floret binary referenced by
    ``config.filename``; the PyTorch side only holds a dummy parameter.
    """

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Dummy parameter so the module exposes at least one parameter
        # (used e.g. by the `device` property below).
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # Load the floret model from the path given in the config.
        self.model_floret = floret.load_model(self.config.filename)

        # Smoke test: run one prediction to make sure the model loaded correctly.
        sample_text = "this is a text"
        _predictions, _probabilities = self.model_floret.predict([sample_text], k=1)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        logger.debug(
            "forward called with input_ids=%s (%s), floret model type=%s",
            input_ids,
            type(input_ids),
            type(self.model_floret),
        )

        # Inference goes through the floret model directly (see
        # `get_floret_model`); the commented-out path below sketches how token
        # ids could be decoded back to text and passed to floret instead.
        # if input_ids is not None:
        #     tokenizer = kwargs.get("tokenizer")
        #     texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        # else:
        #     texts = kwargs.get("text", None)
        #
        # if texts:
        #     # Floret expects strings, not tensors
        #     predictions = [self.model_floret(text) for text in texts]
        #     # Convert predictions to tensors for Hugging Face compatibility
        #     return torch.tensor(predictions)

        # Fall back to a dummy tensor of shape (batch_size, num_classes).
        return torch.zeros((1, 2))

    def state_dict(self, *args, **kwargs):
        # The weights live in the external floret binary, so there is nothing
        # to serialize on the PyTorch side.
        return {}

    def load_state_dict(self, state_dict, strict=True):
        # Nothing to load either: the floret model is read from
        # `config.filename` in `__init__`, so incoming state dicts are ignored.
        logger.info("Ignoring state_dict; weights are loaded from the floret file.")

    def get_floret_model(self):
        return self.model_floret

    def get_extended_attention_mask(
        self, attention_mask, input_shape, device=None, dtype=torch.float
    ):
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    @property
    def device(self):
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Bypass the standard weight-loading machinery: build the config from
        # the keyword arguments and let `__init__` load the floret model.
        logger.info("Ignoring pretrained weights and using custom initialization.")
        config = ImpressoConfig(**kwargs)
        model = cls(config)
        return model


# class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
#
#     config_class = ImpressoConfig
#     _keys_to_ignore_on_load_missing = [r"position_ids"]
#
#     def __init__(self, config):
#         super().__init__(config)
#         # self.num_token_labels_dict = get_info(config.label_map)
#         # self.config = config
#         # # print(f"I dont think it arrives here: {self.config}")
#         # self.bert = AutoModel.from_pretrained(
#         #     config.pretrained_config["_name_or_path"], config=config.pretrained_config
#         # )
#         self.model_floret = floret.load_model(self.config.filename)
#         # print(f"Model loaded: {self.model_floret}")
#         # if "classifier_dropout" not in config.__dict__:
#         #     classifier_dropout = 0.1
#         # else:
#         #     classifier_dropout = (
#         #         config.classifier_dropout
#         #         if config.classifier_dropout is not None
#         #         else config.hidden_dropout_prob
#         #     )
#         # self.dropout = nn.Dropout(classifier_dropout)
#         #
#         # # Additional transformer layers
#         # self.transformer_encoder = nn.TransformerEncoder(
#         #     nn.TransformerEncoderLayer(
#         #         d_model=config.hidden_size, nhead=config.num_attention_heads
#         #     ),
#         #     num_layers=2,
#         # )
#
#         # For token classification, create a classifier for each task
#         # self.token_classifiers = nn.ModuleDict(
#         #     {
#         #         task: nn.Linear(config.hidden_size, num_labels)
#         #         for task, num_labels in self.num_token_labels_dict.items()
#         #     }
#         # )
#         #
#         # # Initialize weights and apply final processing
#         # self.post_init()
#
#     def get_floret_model(self):
#         return self.model_floret
#
#     @classmethod
#     def from_pretrained(cls, *args, **kwargs):
#         print("Ignoring weights and using custom initialization.")
#
#         # Manually create the config
#         config = ImpressoConfig()
#
#         # Pass the manually created config to the class
#         model = cls(config)
#         return model
#
#     # def forward(
#     #     self,
#     #     input_ids: Optional[torch.Tensor] = None,
#     #     attention_mask: Optional[torch.Tensor] = None,
#     #     token_type_ids: Optional[torch.Tensor] = None,
#     #     position_ids: Optional[torch.Tensor] = None,
#     #     head_mask: Optional[torch.Tensor] = None,
#     #     inputs_embeds: Optional[torch.Tensor] = None,
#     #     labels: Optional[torch.Tensor] = None,
#     #     token_labels: Optional[dict] = None,
#     #     output_attentions: Optional[bool] = None,
#     #     output_hidden_states: Optional[bool] = None,
#     #     return_dict: Optional[bool] = None,
#     # ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
#     #     r"""
#     #     token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
#     #         Labels for computing the token classification loss. Keys should match the tasks.
#     #     """
#     #     return_dict = (
#     #         return_dict if return_dict is not None else self.config.use_return_dict
#     #     )
#     #
#     #     bert_kwargs = {
#     #         "input_ids": input_ids,
#     #         "attention_mask": attention_mask,
#     #         "token_type_ids": token_type_ids,
#     #         "position_ids": position_ids,
#     #         "head_mask": head_mask,
#     #         "inputs_embeds": inputs_embeds,
#     #         "output_attentions": output_attentions,
#     #         "output_hidden_states": output_hidden_states,
#     #         "return_dict": return_dict,
#     #     }
#     #
#     #     if any(
#     #         keyword in self.config.name_or_path.lower()
#     #         for keyword in ["llama", "deberta"]
#     #     ):
#     #         bert_kwargs.pop("token_type_ids")
#     #         bert_kwargs.pop("head_mask")
#     #
#     #     outputs = self.bert(**bert_kwargs)
#     #
#     #     # For token classification
#     #     token_output = outputs[0]
#     #     token_output = self.dropout(token_output)
#     #
#     #     # Pass through additional transformer layers
#     #     token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
#     #         0, 1
#     #     )
#     #
#     #     # Collect the logits and compute the loss for each task
#     #     task_logits = {}
#     #     total_loss = 0
#     #     for task, classifier in self.token_classifiers.items():
#     #         logits = classifier(token_output)
#     #         task_logits[task] = logits
#     #         if token_labels and task in token_labels:
#     #             loss_fct = CrossEntropyLoss()
#     #             loss = loss_fct(
#     #                 logits.view(-1, self.num_token_labels_dict[task]),
#     #                 token_labels[task].view(-1),
#     #             )
#     #             total_loss += loss
#     #
#     #     if not return_dict:
#     #         output = (task_logits,) + outputs[2:]
#     #         return ((total_loss,) + output) if total_loss != 0 else output
#     #     print(f"Is there anybody coming here?")
#     #     return TokenClassifierOutput(
#     #         loss=total_loss,
#     #         logits=task_logits,
#     #         hidden_states=outputs.hidden_states,
#     #         attentions=outputs.attentions,
#     #     )
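

# A minimal usage sketch (an illustration, not part of the shipped model code):
# it shows how the custom `from_pretrained` above is typically exercised.
# The file path below is hypothetical; pass the path to your own floret binary
# via the `filename` field, which `ImpressoConfig` is assumed to accept and
# store as a keyword argument.
if __name__ == "__main__":
    model = ExtendedMultitaskModelForTokenClassification.from_pretrained(
        filename="path/to/floret_model.bin"  # hypothetical floret .bin path
    )
    floret_model = model.get_floret_model()
    # floret follows the fastText API: predict() returns labels and probabilities.
    labels, probabilities = floret_model.predict(["this is a text"], k=1)
    print(labels, probabilities)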