AshwinSankar and psidharth567 committed
Commit 17d96ff · verified · 1 parent: bf71431

Completely overhauled the attention implementation. Using the existing Gemma-3 attention implementation rather than a custom monkey-patched implementation. (#10)


- Completely overhauled the attention implementation. Using the existing Gemma-3 attention implementation rather than a custom monkey-patched implementation. (efb4d2d4f654499b929a467d423403d3830628e7)


Co-authored-by: Pulipaka Prem Sidharth <[email protected]>
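For context on the title: the old code rebound each attention module's forward method at runtime, while the new code subclasses the upstream modules and sets the relevant flags at construction time. A minimal sketch of the two approaches (toy stand-in classes, illustrative only, not code from this repo):

import types

import torch.nn as nn


class TinyAttention(nn.Module):
    """Stand-in for Gemma3Attention in this sketch."""
    def __init__(self):
        super().__init__()
        self.is_causal = True

    def forward(self, x):
        return x


# Before: monkey-patching rebinds forward on every live instance.
def non_causal_forward(self, x):
    return x  # custom non-causal attention math would go here

attn = TinyAttention()
attn.forward = types.MethodType(non_causal_forward, attn)

# After: a subclass states the same intent once, declaratively.
class NonCausalTinyAttention(TinyAttention):
    def __init__(self):
        super().__init__()
        self.is_causal = False  # flag consumed by the attention backends

The subclass survives library upgrades better because it keeps the upstream forward pass instead of carrying a frozen copy of it.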

Files changed (1):
  modeling_gemma3_punctuation.py (+155 −190)

modeling_gemma3_punctuation.py CHANGED
@@ -1,22 +1,21 @@
 """
-Custom Gemma3 model for token classification with non-causal attention
+Change the attention of Gemma3 to be bidirectional.
 """
 
 import torch
 import torch.nn as nn
-from typing import Optional, Tuple, List, Dict, Any
-import types
+from typing import Optional, List, Dict, Any
+from functools import partial
 
 from transformers import PretrainedConfig, PreTrainedModel
-from transformers import Gemma3ForCausalLM
+from transformers import Gemma3ForCausalLM, Gemma3TextConfig
 from transformers.models.gemma3.modeling_gemma3 import (
     Gemma3Attention,
-    repeat_kv,
-    apply_rotary_pos_emb,
-    ALL_ATTENTION_FUNCTIONS,
-    Cache,
-    FlashAttentionKwargs,
+    Gemma3DecoderLayer,
+    Gemma3TextModel,
+
 )
+
 from transformers.modeling_outputs import TokenClassifierOutput
 from transformers.utils import logging
 
@@ -27,7 +26,6 @@ class Gemma3PunctuationConfig(PretrainedConfig):
     """
     Configuration class for Gemma3 punctuation model.
     """
-
     model_type = "cadence_punctuation"
 
     def __init__(
@@ -43,171 +41,141 @@ class Gemma3PunctuationConfig(PretrainedConfig):
         super().__init__(**kwargs)
 
 
-def _extract_padding_mask_corrected(
-    combined_mask_4d: Optional[torch.Tensor],
-    debug_print: bool = False
-) -> Optional[torch.Tensor]:
-    """Extract padding mask from combined 4D attention mask."""
-    if combined_mask_4d is None:
-        return None
-
-    mask_value = torch.finfo(combined_mask_4d.dtype).min
-    is_key_padding = (combined_mask_4d == mask_value).all(dim=2, keepdim=True)
-    padding_only_mask = torch.where(
-        is_key_padding.expand_as(combined_mask_4d),
-        torch.full_like(combined_mask_4d, mask_value),
-        torch.zeros_like(combined_mask_4d)
-    )
-    return padding_only_mask
-
-
-def non_causal_eager_attention_forward_with_padding(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    **kwargs: Any,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Non-causal eager attention implementation."""
-    dropout = kwargs.get("dropout", 0.0)
-    scaling = kwargs.get("scaling", None)
-    softcap = kwargs.get("softcap", None)
-
-    if scaling is None:
-        head_dim = getattr(module, "head_dim", query.shape[-1])
-        scaling = head_dim**-0.5
-
-    num_key_value_groups = getattr(module, "num_key_value_groups", 1)
-    key_states = repeat_kv(key, num_key_value_groups)
-    value_states = repeat_kv(value, num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-
-    if softcap is not None:
-        attn_weights = attn_weights / softcap
-        attn_weights = torch.tanh(attn_weights)
-        attn_weights = attn_weights * softcap
-
-    if attention_mask is not None:
-        mask_slice = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + mask_slice
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    is_training = getattr(module, "training", False)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=is_training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    return attn_output, attn_weights
-
-
-def modified_gemma3_attention_forward_non_causal(
-    self: Gemma3Attention,
-    hidden_states: torch.Tensor,
-    position_embeddings: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    past_key_value: Optional[Cache] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    **kwargs: Any,
-) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    """Modified Gemma3 attention forward for non-causal behavior."""
-    bsz, q_len, _ = hidden_states.size()
-    input_shape = hidden_states.shape[:-1]
-    hidden_shape = (*input_shape, -1, self.head_dim)
-
-    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
-    query_states = self.q_norm(query_states)
-    key_states = self.k_norm(key_states)
-    cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-    if past_key_value is not None:
-        cache_kwargs = {
-            "sin": sin,
-            "cos": cos,
-            "cache_position": cache_position,
-            "sliding_window": self.sliding_window
-        }
-        key_states, value_states = past_key_value.update(
-            key_states, value_states, self.layer_idx, cache_kwargs
+# ============ Token Classification Model Components ============
+
+class NonCausalGemma3Attention(Gemma3Attention):
+    """Gemma3Attention configured for non-causal token classification."""
+    def __init__(self, config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.is_causal = False
+        self.sliding_window = None
+
+
+class NonCausalGemma3DecoderLayer(Gemma3DecoderLayer):
+    """Decoder layer with non-causal attention for token classification."""
+    def __init__(self, config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.self_attn = NonCausalGemma3Attention(config, layer_idx)
+
+
+class Gemma3TokenClassificationModel(Gemma3TextModel):
+    """Gemma3 base model configured for token classification."""
+    _no_split_modules = ["NonCausalGemma3DecoderLayer"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        if getattr(config, 'use_non_causal_attention', True):
+            # Replace layers with non-causal versions
+            self.layers = nn.ModuleList(
+                [
+                    NonCausalGemma3DecoderLayer(config, layer_idx)
+                    for layer_idx in range(config.num_hidden_layers)
+                ]
+            )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values=None,
+        output_attentions: bool = False,
+    ):
+        """Override to create a bidirectional attention mask (no causal masking)."""
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        past_seen_tokens = (
+            past_key_values.get_seq_length() if past_key_values is not None else 0
         )
+        using_static_cache = past_key_values is not None and hasattr(past_key_values, "get_max_length")
 
-    effective_attn_implementation = self.config._attn_implementation
-    output_attentions = kwargs.get("output_attentions", False)
-
-    if effective_attn_implementation == "sdpa" and output_attentions:
-        effective_attn_implementation = "eager"
-    elif effective_attn_implementation == "flash_attention_2" and output_attentions:
-        effective_attn_implementation = "eager"
-
-    padding_only_mask = _extract_padding_mask_corrected(attention_mask)
-    use_causal_flag = False  # Non-causal for punctuation
-
-    # Select attention interface
-    if effective_attn_implementation == "eager":
-        attention_interface = non_causal_eager_attention_forward_with_padding
-    elif effective_attn_implementation == "sdpa":
-        attention_interface = ALL_ATTENTION_FUNCTIONS.get("sdpa", non_causal_eager_attention_forward_with_padding)
-    elif effective_attn_implementation == "flash_attention_2":
-        attention_interface = ALL_ATTENTION_FUNCTIONS.get("flash_attention_2", non_causal_eager_attention_forward_with_padding)
-    else:
-        attention_interface = non_causal_eager_attention_forward_with_padding
-
-    final_attention_mask = padding_only_mask
-    if final_attention_mask is not None:
-        final_attention_mask = final_attention_mask.to(query_states.device)
-
-    # Prepare kwargs for attention interface
-    attn_specific_kwargs: Dict[str, Any] = {}
-    if attention_interface == non_causal_eager_attention_forward_with_padding:
-        attn_specific_kwargs = {
-            "dropout": 0.0,
-            "scaling": self.scaling,
-            "softcap": getattr(self, "softcap", None)
-        }
-    elif effective_attn_implementation == "sdpa":
-        attn_specific_kwargs = {"is_causal": use_causal_flag}
-        if output_attentions:
-            attn_specific_kwargs["output_attentions"] = True
-    elif effective_attn_implementation == "flash_attention_2":
-        attn_specific_kwargs = {
-            "causal": use_causal_flag,
-            "softcap": getattr(self, "softcap", None),
-            "dropout": 0.0
-        }
-        if output_attentions:
-            attn_specific_kwargs["output_attentions"] = True
-
-    attn_output, attn_weights = attention_interface(
-        self, query_states, key_states, value_states, final_attention_mask, **attn_specific_kwargs
-    )
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+
+        if using_static_cache:
+            target_length = past_key_values.get_max_length()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
 
-    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-    attn_output = self.o_proj(attn_output)
-
-    returned_weights = attn_weights if output_attentions and attn_weights is not None else None
-
-    return attn_output, returned_weights
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume the mask already comes in inverted form and requires no inversion or slicing.
+            if attention_mask.max() != 0:
+                raise ValueError(
+                    "Custom 4D attention mask should be passed in inverted form with max==0"
+                )
+            causal_mask = attention_mask
+        else:
+            # KEY CHANGE: start with zeros (attend to everything) instead of min_dtype (mask everything)
+            causal_mask = torch.zeros(
+                (sequence_length, target_length), dtype=dtype, device=device
+            )
+            # REMOVED: the causal masking that would make the mask lower triangular
+            # if sequence_length != 1:
+            #     causal_mask = torch.triu(causal_mask, diagonal=1)
+
+            causal_mask *= torch.arange(
+                target_length, device=device
+            ) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(
+                input_tensor.shape[0], 1, -1, -1
+            )
+
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = (
+                    causal_mask[:, :, :, :mask_length]
+                    + attention_mask[:, None, None, :]
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
+
+        # Handle SDPA-specific optimizations if needed
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            try:
+                from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+                causal_mask = AttentionMaskConverter._unmask_unattended(
+                    causal_mask, min_dtype
+                )
+            except ImportError:
+                pass  # Fallback for older transformers versions
+
+        return causal_mask
 
 
 class Gemma3ForTokenClassification(Gemma3ForCausalLM):
     """
     Gemma3 model for token classification (punctuation prediction).
-    Inherits from Gemma3ForCausalLM and replaces the LM head with classification head.
+    Uses a class-based architecture without monkey patching.
     """
 
     config_class = Gemma3PunctuationConfig
 
     def __init__(self, config):
-        # Initialize the parent Gemma3ForCausalLM
+        # Initialize with the base Gemma3ForCausalLM structure
        super().__init__(config)
        self.num_labels = config.num_labels
 
+        # Replace the base model with the token classification version
+        if getattr(config, 'use_non_causal_attention', True):
+            self.model = Gemma3TokenClassificationModel(config)
+
         # Replace the lm_head with classification head
-        # Don't create a separate classifier - just replace lm_head directly
         classifier_dropout_prob = getattr(config, 'classifier_dropout_prob', 0.0)
         self.lm_head = nn.Sequential(
             nn.Dropout(classifier_dropout_prob),
@@ -219,32 +187,6 @@ class Gemma3ForTokenClassification(Gemma3ForCausalLM):
 
         # Initialize weights for the new head
         self.post_init()
-
-        # Apply non-causal attention patching if requested
-        if getattr(config, 'use_non_causal_attention', True):
-            self._patch_attention_layers()
-
-    def _patch_attention_layers(self):
-        """Patch attention layers to use non-causal attention."""
-        count = 0
-
-        # The model structure is self.model.layers (inherited from Gemma3ForCausalLM)
-        if hasattr(self, 'model') and hasattr(self.model, 'layers'):
-            target_layers = self.model.layers
-        else:
-            logger.warning("Could not find model.layers for attention patching")
-            return
-
-        for idx, layer in enumerate(target_layers):
-            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, Gemma3Attention):
-                layer.self_attn.layer_idx = idx
-                layer.self_attn.forward = types.MethodType(
-                    modified_gemma3_attention_forward_non_causal,
-                    layer.self_attn
-                )
-                count += 1
-
-        logger.info(f"Patched {count} attention layers for non-causal attention")
 
     def forward(
         self,
@@ -260,12 +202,10 @@ class Gemma3ForTokenClassification(Gemma3ForCausalLM):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> TokenClassifierOutput:
-        """
-        Forward pass for token classification.
-        """
+        """Forward pass for token classification."""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # Call the parent's forward method but get the hidden states instead of logits
+        # Get hidden states from the model
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -302,7 +242,32 @@ class Gemma3ForTokenClassification(Gemma3ForCausalLM):
         )
 
 
-# Register the model for AutoModel
+# ============ Model Registration ============
+
 from transformers import AutoConfig, AutoModel
+
+# Register the punctuation config and model
 AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
-AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
+AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
+
+
+# ============ Utility Functions ============
+
+
+def create_token_classification_model(config: Gemma3PunctuationConfig):
+    """Create a token classification model with non-causal attention."""
+    return Gemma3ForTokenClassification(config)
+
+
+def load_from_pretrained_with_config_detection(model_path: str, **kwargs):
+    """
+    Load a model and auto-detect whether it is for token classification or bidirectional tasks
+    based on the config.
+    """
+    from transformers import AutoConfig
+
+    config = AutoConfig.from_pretrained(model_path)
+
+    if hasattr(config, 'model_type') and config.model_type == "cadence_punctuation":
+        # Token classification model
+        return Gemma3ForTokenClassification.from_pretrained(model_path, config=config, **kwargs)
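The heart of the new file is the _update_causal_mask override: where the stock model starts from a fully masked matrix and opens the lower triangle, the override starts from zeros, so every token attends to every other token, and only padding key positions are re-masked. A standalone sketch of that difference (toy sizes; illustrative, not code from the commit):

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 4, 4

# Stock causal construction: start fully masked, then open the lower triangle.
causal = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal = torch.triu(causal, diagonal=1)  # zeros on and below the diagonal

# This commit's construction: start from zeros, i.e. attend to everything.
bidirectional = torch.zeros((sequence_length, target_length), dtype=dtype)

# Both variants still mask padding; suppose the last key position is padding.
attention_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
bidirectional = bidirectional.masked_fill(attention_mask[None, :] == 0, min_dtype)

print(causal)         # min_dtype above the diagonal: each token sees only its past
print(bidirectional)  # min_dtype only in the padded key column: full context otherwise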
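Nothing changes for downstream users: the AutoConfig/AutoModel registration above still resolves "cadence_punctuation" checkpoints to Gemma3ForTokenClassification. A hypothetical usage sketch (the checkpoint path is a placeholder; trust_remote_code=True is assumed because the modeling code ships inside the repo):

import torch
from transformers import AutoModel, AutoTokenizer

model_path = "path/to/cadence-punctuation-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.eval()

inputs = tokenizer("hello how are you", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)          # TokenClassifierOutput
pred_ids = outputs.logits.argmax(-1)   # one punctuation label id per token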