Completely overhauled the attention implementation. Uses the existing Gemma-3 attention implementation rather than a custom monkey-patched implementation. (#10)
Commit: efb4d2d4f654499b929a467d423403d3830628e7
Co-authored-by: Pulipaka Prem Sidharth <[email protected]>
- modeling_gemma3_punctuation.py +155 -190
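As a rough usage sketch (not part of this commit, and the checkpoint path below is a placeholder): once modeling_gemma3_punctuation.py is importable, the "cadence_punctuation" config/model pair it registers resolves through the Auto classes.

    # Hypothetical usage; importing the module runs the AutoConfig/AutoModel
    # registration performed at the bottom of the file.
    from transformers import AutoConfig, AutoModel
    import modeling_gemma3_punctuation  # noqa: F401

    checkpoint = "path/to/cadence-punctuation-checkpoint"  # placeholder path
    config = AutoConfig.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint, config=config)
    model.eval()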
modeling_gemma3_punctuation.py  CHANGED
@@ -1,22 +1,21 @@
 """
-
+Change the attention of Gemma3 to be bidirectional.
 """
 
 import torch
 import torch.nn as nn
-from typing import Optional,
-import
+from typing import Optional, List, Dict, Any
+from functools import partial
 
 from transformers import PretrainedConfig, PreTrainedModel
-from transformers import Gemma3ForCausalLM
+from transformers import Gemma3ForCausalLM, Gemma3TextConfig
 from transformers.models.gemma3.modeling_gemma3 import (
     Gemma3Attention,
-
-
-
-    Cache,
-    FlashAttentionKwargs,
+    Gemma3DecoderLayer,
+    Gemma3TextModel,
+
 )
+
 from transformers.modeling_outputs import TokenClassifierOutput
 from transformers.utils import logging
 
@@ -27,7 +26,6 @@ class Gemma3PunctuationConfig(PretrainedConfig):
     """
     Configuration class for Gemma3 punctuation model.
     """
-
     model_type = "cadence_punctuation"
 
     def __init__(
@@ -43,171 +41,141 @@ class Gemma3PunctuationConfig(PretrainedConfig):
         super().__init__(**kwargs)
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        attn_weights = attn_weights + mask_slice
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    is_training = getattr(module, "training", False)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=is_training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    return attn_output, attn_weights
-
-
-def modified_gemma3_attention_forward_non_causal(
-    self: Gemma3Attention,
-    hidden_states: torch.Tensor,
-    position_embeddings: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    past_key_value: Optional[Cache] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    **kwargs: Any,
-) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    """Modified Gemma3 attention forward for non-causal behavior."""
-    bsz, q_len, _ = hidden_states.size()
-    input_shape = hidden_states.shape[:-1]
-    hidden_shape = (*input_shape, -1, self.head_dim)
-
-    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
-    query_states = self.q_norm(query_states)
-    key_states = self.k_norm(key_states)
-    cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-    if past_key_value is not None:
-        cache_kwargs = {
-            "sin": sin,
-            "cos": cos,
-            "cache_position": cache_position,
-            "sliding_window": self.sliding_window
-        }
-        key_states, value_states = past_key_value.update(
-            key_states, value_states, self.layer_idx, cache_kwargs
+# ============ Token Classification Model Components ============
+
+class NonCausalGemma3Attention(Gemma3Attention):
+    """Gemma3Attention configured for non-causal token classification."""
+    def __init__(self, config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.is_causal = False
+        self.sliding_window = None
+
+
+class NonCausalGemma3DecoderLayer(Gemma3DecoderLayer):
+    """Decoder layer with non-causal attention for token classification."""
+    def __init__(self, config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.self_attn = NonCausalGemma3Attention(config, layer_idx)
+
+
+class Gemma3TokenClassificationModel(Gemma3TextModel):
+    """Gemma3 base model configured for token classification."""
+    _no_split_modules = ["NonCausalGemma3DecoderLayer"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        if getattr(config, 'use_non_causal_attention', True):
+            # Replace layers with non-causal versions
+            self.layers = nn.ModuleList(
+                [
+                    NonCausalGemma3DecoderLayer(config, layer_idx)
+                    for layer_idx in range(config.num_hidden_layers)
+                ]
+            )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values = None,
+        output_attentions: bool = False,
+    ):
+        """Override to create bidirectional attention mask (no causal masking)."""
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        past_seen_tokens = (
+            past_key_values.get_seq_length() if past_key_values is not None else 0
         )
+        using_static_cache = isinstance(past_key_values, type(None)) is False and hasattr(past_key_values, 'get_max_length')
 
-
-
-
-
-
-
-
-
-
-
-
-
-    if effective_attn_implementation == "eager":
-        attention_interface = non_causal_eager_attention_forward_with_padding
-    elif effective_attn_implementation == "sdpa":
-        attention_interface = ALL_ATTENTION_FUNCTIONS.get("sdpa", non_causal_eager_attention_forward_with_padding)
-    elif effective_attn_implementation == "flash_attention_2":
-        attention_interface = ALL_ATTENTION_FUNCTIONS.get("flash_attention_2", non_causal_eager_attention_forward_with_padding)
-    else:
-        attention_interface = non_causal_eager_attention_forward_with_padding
-
-    final_attention_mask = padding_only_mask
-    if final_attention_mask is not None:
-        final_attention_mask = final_attention_mask.to(query_states.device)
-
-    # Prepare kwargs for attention interface
-    attn_specific_kwargs: Dict[str, Any] = {}
-    if attention_interface == non_causal_eager_attention_forward_with_padding:
-        attn_specific_kwargs = {
-            "dropout": 0.0,
-            "scaling": self.scaling,
-            "softcap": getattr(self, "softcap", None)
-        }
-    elif effective_attn_implementation == "sdpa":
-        attn_specific_kwargs = {"is_causal": use_causal_flag}
-        if output_attentions:
-            attn_specific_kwargs["output_attentions"] = True
-    elif effective_attn_implementation == "flash_attention_2":
-        attn_specific_kwargs = {
-            "causal": use_causal_flag,
-            "softcap": getattr(self, "softcap", None),
-            "dropout": 0.0
-        }
-        if output_attentions:
-            attn_specific_kwargs["output_attentions"] = True
-
-    attn_output, attn_weights = attention_interface(
-        self, query_states, key_states, value_states, final_attention_mask, **attn_specific_kwargs
-    )
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+
+        if using_static_cache:
+            target_length = past_key_values.get_max_length()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
 
-
-
-
-
-
-
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError(
+                    "Custom 4D attention mask should be passed in inverted form with max==0`"
+                )
+            causal_mask = attention_mask
+        else:
+            # KEY CHANGE: Start with zeros (attend to all) instead of min_dtype (mask all)
+            causal_mask = torch.zeros(
+                (sequence_length, target_length), dtype=dtype, device=device
+            )
+            # REMOVED: Causal masking lines that would make it lower triangular
+            # if sequence_length != 1:
+            #     causal_mask = torch.triu(causal_mask, diagonal=1)
+
+            causal_mask *= torch.arange(
+                target_length, device=device
+            ) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(
+                input_tensor.shape[0], 1, -1, -1
+            )
+
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = (
+                    causal_mask[:, :, :, :mask_length]
+                    + attention_mask[:, None, None, :]
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
+
+        # Handle SDPA-specific optimizations if needed
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            try:
+                from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+                causal_mask = AttentionMaskConverter._unmask_unattended(
+                    causal_mask, min_dtype
+                )
+            except ImportError:
+                pass  # Fallback for older transformers versions
+
+        return causal_mask
 
 
 class Gemma3ForTokenClassification(Gemma3ForCausalLM):
     """
     Gemma3 model for token classification (punctuation prediction).
-
+    Uses class-based architecture without monkey patching.
     """
 
     config_class = Gemma3PunctuationConfig
 
     def __init__(self, config):
-        # Initialize
+        # Initialize with base Gemma3ForCausalLM structure
         super().__init__(config)
         self.num_labels = config.num_labels
 
+        # Replace the base model with token classification version
+        if getattr(config, 'use_non_causal_attention', True):
+            self.model = Gemma3TokenClassificationModel(config)
+
         # Replace the lm_head with classification head
-        # Don't create a separate classifier - just replace lm_head directly
         classifier_dropout_prob = getattr(config, 'classifier_dropout_prob', 0.0)
         self.lm_head = nn.Sequential(
             nn.Dropout(classifier_dropout_prob),
@@ -219,32 +187,6 @@ class Gemma3ForTokenClassification(Gemma3ForCausalLM):
 
         # Initialize weights for the new head
        self.post_init()
-
-        # Apply non-causal attention patching if requested
-        if getattr(config, 'use_non_causal_attention', True):
-            self._patch_attention_layers()
-
-    def _patch_attention_layers(self):
-        """Patch attention layers to use non-causal attention."""
-        count = 0
-
-        # The model structure is self.model.layers (inherited from Gemma3ForCausalLM)
-        if hasattr(self, 'model') and hasattr(self.model, 'layers'):
-            target_layers = self.model.layers
-        else:
-            logger.warning("Could not find model.layers for attention patching")
-            return
-
-        for idx, layer in enumerate(target_layers):
-            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, Gemma3Attention):
-                layer.self_attn.layer_idx = idx
-                layer.self_attn.forward = types.MethodType(
-                    modified_gemma3_attention_forward_non_causal,
-                    layer.self_attn
-                )
-                count += 1
-
-        logger.info(f"Patched {count} attention layers for non-causal attention")
 
     def forward(
         self,
@@ -260,12 +202,10 @@ class Gemma3ForTokenClassification(Gemma3ForCausalLM):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> TokenClassifierOutput:
-        """
-        Forward pass for token classification.
-        """
+        """Forward pass for token classification."""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        #
+        # Get hidden states from the model
        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -302,7 +242,32 @@ class Gemma3ForTokenClassification(Gemma3ForCausalLM):
         )
 
 
-#
+# ============ Model Registration ============
+
 from transformers import AutoConfig, AutoModel
+
+# Register the punctuation config and model
 AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
-AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
+AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
+
+
+# ============ Utility Functions ============
+
+
+def create_token_classification_model(config: Gemma3PunctuationConfig):
+    """Create a token classification model with non-causal attention."""
+    return Gemma3ForTokenClassification(config)
+
+
+def load_from_pretrained_with_config_detection(model_path: str, **kwargs):
+    """
+    Load model and auto-detect whether it's for token classification or bidirectional tasks
+    based on the config.
+    """
+    from transformers import AutoConfig
+
+    config = AutoConfig.from_pretrained(model_path)
+
+    if hasattr(config, 'model_type') and config.model_type == "cadence_punctuation":
+        # Token classification model
+        return Gemma3ForTokenClassification.from_pretrained(model_path, config=config, **kwargs)