AshwinSankar committed
Commit d4c5a78 · verified · 1 Parent(s): d6686f1

Upload model

Files changed (2):
  1. README.md  +36 -7
  2. modeling_gemma3_punctuation.py  +308 -0
README.md CHANGED
@@ -69,7 +69,7 @@ pip install cadence-punctuation
 ### Using the Simple Interface
 
 ```python
-from cadence-punctuation import PunctuationModel
+from cadence import PunctuationModel
 
 # Load model (local path or Hugging Face model ID)
 model = PunctuationModel("path/to/download/weights")
@@ -103,16 +103,45 @@ model_name = "ai4bharat/Cadence"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
 
-# Prepare input
-text = "hello world how are you"
+id2label = model.config.id2label
+
+text = "यह एक वाक्य है इसका क्या मतलब है"
+# text = "this is a test sentence what do you think"
+
+# Tokenize input and prepare for model
 inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+input_ids = inputs['input_ids'][0]  # Get input_ids for the first (and only) sentence
 
-# Get predictions
 with torch.no_grad():
     outputs = model(**inputs)
-predictions = torch.argmax(outputs.logits, dim=-1)
+predictions_for_sentence = torch.argmax(outputs.logits, dim=-1)[0]
+
+
+result_tokens_and_punctuation = []
+all_token_strings = tokenizer.convert_ids_to_tokens(input_ids.tolist())  # Get all token strings
+
+for i, token_id_value in enumerate(input_ids.tolist()):
+    # Process only non-padding tokens based on the attention mask
+    if inputs['attention_mask'][0][i] == 0:
+        continue
+
+    current_token_string = all_token_strings[i]
+
+    is_special_token = token_id_value in tokenizer.all_special_ids
+
+    if not is_special_token:
+        result_tokens_and_punctuation.append(current_token_string)
+
+    predicted_punctuation_id = predictions_for_sentence[i].item()
+    punctuation_character = id2label[predicted_punctuation_id]
+
+    if punctuation_character != "O" and not is_special_token:
+        result_tokens_and_punctuation.append(punctuation_character)
+
+punctuated_text = tokenizer.convert_tokens_to_string(result_tokens_and_punctuation)
 
-print(predictions)
+print(f"Original Text: {text}")
+print(f"Punctuated Text: {punctuated_text}")
 ```
 
 
@@ -160,7 +189,7 @@ The model can predict the following punctuation marks:
 model = PunctuationModel(
     model_path="path/to/download/weights",
     gpu_id=0,             # Use specific GPU
-    max_length=512,       # Longer sequences
+    max_length=512,       # Length for truncation; also used as window size when sliding_window=True
     sliding_window=True,  # Handle long texts
     verbose=False,        # Quiet mode
     d_type="bfloat16"
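
For repeated use, the raw `transformers` example in the hunk above can be folded into a single helper. The sketch below is illustrative rather than part of the package: the helper name `punctuate_with_transformers` is hypothetical, and it assumes the tokenizer and model are loaded exactly as in the README (`AutoTokenizer`/`AutoModel` with `trust_remote_code=True`).

```python
# Minimal sketch: wraps the README's decoding loop into one reusable function.
# Assumes torch + transformers are installed and the checkpoint loads as in the README.
import torch
from transformers import AutoModel, AutoTokenizer


def punctuate_with_transformers(text: str, model, tokenizer) -> str:
    """Return `text` with the predicted punctuation mark inserted after each token."""
    id2label = model.config.id2label
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"][0]

    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=-1)[0]

    tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())
    pieces = []
    for i, token_id in enumerate(input_ids.tolist()):
        if inputs["attention_mask"][0][i] == 0:
            continue  # skip padding positions
        if token_id in tokenizer.all_special_ids:
            continue  # drop special tokens entirely
        pieces.append(tokens[i])
        label = id2label[predictions[i].item()]
        if label != "O":  # "O" means no punctuation after this token
            pieces.append(label)
    return tokenizer.convert_tokens_to_string(pieces)


if __name__ == "__main__":
    name = "ai4bharat/Cadence"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name, trust_remote_code=True)
    print(punctuate_with_transformers("this is a test sentence what do you think", model, tokenizer))
```

The loop mirrors the README's decoding logic: padding positions are skipped via the attention mask, special tokens are dropped, and any label other than "O" is inserted directly after its token.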
modeling_gemma3_punctuation.py ADDED
@@ -0,0 +1,308 @@
+"""
+Custom Gemma3 model for token classification with non-causal attention
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple, List, Dict, Any
+import types
+
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers import Gemma3ForCausalLM
+from transformers.models.gemma3.modeling_gemma3 import (
+    Gemma3Attention,
+    repeat_kv,
+    apply_rotary_pos_emb,
+    ALL_ATTENTION_FUNCTIONS,
+    Cache,
+    FlashAttentionKwargs,
+)
+from transformers.modeling_outputs import TokenClassifierOutput
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3PunctuationConfig(PretrainedConfig):
+    """
+    Configuration class for the Gemma3 punctuation model.
+    """
+
+    model_type = "cadence_punctuation"
+
+    def __init__(
+        self,
+        num_labels: int = 31,
+        classifier_dropout_prob: float = 0.0,
+        use_non_causal_attention: bool = True,
+        **kwargs
+    ):
+        self.num_labels = num_labels
+        self.classifier_dropout_prob = classifier_dropout_prob
+        self.use_non_causal_attention = use_non_causal_attention
+        super().__init__(**kwargs)
+
+
+def _extract_padding_mask_corrected(
+    combined_mask_4d: Optional[torch.Tensor],
+    debug_print: bool = False
+) -> Optional[torch.Tensor]:
+    """Extract the padding-only mask from a combined 4D attention mask."""
+    if combined_mask_4d is None:
+        return None
+
+    mask_value = torch.finfo(combined_mask_4d.dtype).min
+    is_key_padding = (combined_mask_4d == mask_value).all(dim=2, keepdim=True)
+    padding_only_mask = torch.where(
+        is_key_padding.expand_as(combined_mask_4d),
+        torch.full_like(combined_mask_4d, mask_value),
+        torch.zeros_like(combined_mask_4d)
+    )
+    return padding_only_mask
+
+
+def non_causal_eager_attention_forward_with_padding(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    **kwargs: Any,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Non-causal eager attention implementation."""
+    dropout = kwargs.get("dropout", 0.0)
+    scaling = kwargs.get("scaling", None)
+    softcap = kwargs.get("softcap", None)
+
+    if scaling is None:
+        head_dim = getattr(module, "head_dim", query.shape[-1])
+        scaling = head_dim**-0.5
+
+    num_key_value_groups = getattr(module, "num_key_value_groups", 1)
+    key_states = repeat_kv(key, num_key_value_groups)
+    value_states = repeat_kv(value, num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+
+    if softcap is not None:
+        attn_weights = attn_weights / softcap
+        attn_weights = torch.tanh(attn_weights)
+        attn_weights = attn_weights * softcap
+
+    if attention_mask is not None:
+        mask_slice = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + mask_slice
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    is_training = getattr(module, "training", False)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=is_training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights
+
+
+def modified_gemma3_attention_forward_non_causal(
+    self: Gemma3Attention,
+    hidden_states: torch.Tensor,
+    position_embeddings: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    past_key_value: Optional[Cache] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs: Any,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """Modified Gemma3 attention forward for non-causal behavior."""
+    bsz, q_len, _ = hidden_states.size()
+    input_shape = hidden_states.shape[:-1]
+    hidden_shape = (*input_shape, -1, self.head_dim)
+
+    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+    query_states = self.q_norm(query_states)
+    key_states = self.k_norm(key_states)
+    cos, sin = position_embeddings
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_value is not None:
+        cache_kwargs = {
+            "sin": sin,
+            "cos": cos,
+            "cache_position": cache_position,
+            "sliding_window": self.sliding_window
+        }
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_idx, cache_kwargs
+        )
+
+    effective_attn_implementation = self.config._attn_implementation
+    output_attentions = kwargs.get("output_attentions", False)
+
+    if effective_attn_implementation == "sdpa" and output_attentions:
+        effective_attn_implementation = "eager"
+    elif effective_attn_implementation == "flash_attention_2" and output_attentions:
+        effective_attn_implementation = "eager"
+
+    padding_only_mask = _extract_padding_mask_corrected(attention_mask)
+    use_causal_flag = False  # Non-causal for punctuation
+
+    # Select attention interface
+    if effective_attn_implementation == "eager":
+        attention_interface = non_causal_eager_attention_forward_with_padding
+    elif effective_attn_implementation == "sdpa":
+        attention_interface = ALL_ATTENTION_FUNCTIONS.get("sdpa", non_causal_eager_attention_forward_with_padding)
+    elif effective_attn_implementation == "flash_attention_2":
+        attention_interface = ALL_ATTENTION_FUNCTIONS.get("flash_attention_2", non_causal_eager_attention_forward_with_padding)
+    else:
+        attention_interface = non_causal_eager_attention_forward_with_padding
+
+    final_attention_mask = padding_only_mask
+    if final_attention_mask is not None:
+        final_attention_mask = final_attention_mask.to(query_states.device)
+
+    # Prepare kwargs for the selected attention interface
+    attn_specific_kwargs: Dict[str, Any] = {}
+    if attention_interface == non_causal_eager_attention_forward_with_padding:
+        attn_specific_kwargs = {
+            "dropout": 0.0,
+            "scaling": self.scaling,
+            "softcap": getattr(self, "softcap", None)
+        }
+    elif effective_attn_implementation == "sdpa":
+        attn_specific_kwargs = {"is_causal": use_causal_flag}
+        if output_attentions:
+            attn_specific_kwargs["output_attentions"] = True
+    elif effective_attn_implementation == "flash_attention_2":
+        attn_specific_kwargs = {
+            "causal": use_causal_flag,
+            "softcap": getattr(self, "softcap", None),
+            "dropout": 0.0
+        }
+        if output_attentions:
+            attn_specific_kwargs["output_attentions"] = True
+
+    attn_output, attn_weights = attention_interface(
+        self, query_states, key_states, value_states, final_attention_mask, **attn_specific_kwargs
+    )
+
+    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+    attn_output = self.o_proj(attn_output)
+
+    returned_weights = attn_weights if output_attentions and attn_weights is not None else None
+
+    return attn_output, returned_weights
+
+
+class Gemma3ForTokenClassification(Gemma3ForCausalLM):
+    """
+    Gemma3 model for token classification (punctuation prediction).
+    Inherits from Gemma3ForCausalLM and replaces the LM head with a classification head.
+    """
+
+    config_class = Gemma3PunctuationConfig
+
+    def __init__(self, config):
+        # Initialize the parent Gemma3ForCausalLM
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        # Replace the lm_head with a classification head.
+        # Don't create a separate classifier - just replace lm_head directly.
+        classifier_dropout_prob = getattr(config, 'classifier_dropout_prob', 0.0)
+        self.lm_head = nn.Sequential(
+            nn.Dropout(classifier_dropout_prob),
+            nn.Linear(config.hidden_size, config.num_labels)
+        )
+
+        # Update config for classification
+        self.config.num_labels = config.num_labels
+
+        # Initialize weights for the new head
+        self.post_init()
+
+        # Apply non-causal attention patching if requested
+        if getattr(config, 'use_non_causal_attention', True):
+            self._patch_attention_layers()
+
+    def _patch_attention_layers(self):
+        """Patch attention layers to use non-causal attention."""
+        count = 0
+
+        # The model structure is self.model.layers (inherited from Gemma3ForCausalLM)
+        if hasattr(self, 'model') and hasattr(self.model, 'layers'):
+            target_layers = self.model.layers
+        else:
+            logger.warning("Could not find model.layers for attention patching")
+            return
+
+        for idx, layer in enumerate(target_layers):
+            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, Gemma3Attention):
+                layer.self_attn.layer_idx = idx
+                layer.self_attn.forward = types.MethodType(
+                    modified_gemma3_attention_forward_non_causal,
+                    layer.self_attn
+                )
+                count += 1
+
+        logger.info(f"Patched {count} attention layers for non-causal attention")
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> TokenClassifierOutput:
+        """
+        Forward pass for token classification.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Call the backbone and use its hidden states instead of LM logits
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        # Get the hidden states from the model output
+        sequence_output = outputs[0]
+
+        # Apply the classification head (which is now self.lm_head)
+        logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+# Register the model for AutoModel
+from transformers import AutoConfig, AutoModel
+AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
+AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
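
The final three lines map the `cadence_punctuation` model type to these classes, which is what lets the checkpoint be loaded through the `Auto*` API with `trust_remote_code=True`, exactly as the README shows. A minimal sketch of that path (the shape comments are illustrative):

```python
# Minimal sketch: load the repository's custom classes through AutoModel.
# trust_remote_code=True fetches modeling_gemma3_punctuation.py from the repo,
# so Gemma3PunctuationConfig / Gemma3ForTokenClassification handle the loading.
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "ai4bharat/Cadence"  # model id from the README
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

inputs = tokenizer("hello world how are you", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# forward() returns a TokenClassifierOutput with one row of label logits per token
print(type(outputs).__name__)  # TokenClassifierOutput
print(outputs.logits.shape)    # (batch_size, sequence_length, num_labels)
```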