import torch
from torch import nn
import torch.nn.functional as F
import types
import os

from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeForCausalLM, Qwen2MoeModel
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.utils import logging


logger = logging.get_logger(__name__)
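

# "Abliteration" removes a model's tendency to refuse by erasing the component of
# its residual-stream activations that points along a precomputed refusal
# direction. The per-layer directions are read from 'final_refusal_dirs.pt', an
# object indexable by layer index (e.g. a dict mapping layer index to a
# hidden_size-dimensional tensor) stored alongside the checkpoint.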
class AbliteratedDecoderLayer(Qwen2MoeDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # Keep the layer index explicitly; forward() needs it to look up this
        # layer's refusal direction, and the stock decoder layer may not store it.
        self.layer_idx = layer_idx

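    # forward() accepts the mapping of refusal directions as an extra argument in
    # front of the standard decoder-layer arguments. When a direction exists for
    # this layer, its component is subtracted from the hidden states before the
    # stock Qwen2MoeDecoderLayer.forward runs.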
    def forward(self, hidden_states, refusal_directions, *args, **kwargs):
        if refusal_directions is not None and self.layer_idx in refusal_directions:
            layer_refusal_directions = refusal_directions[self.layer_idx].to(hidden_states.device)

            # Per-token cosine similarity with the refusal direction, shape (batch, seq_len).
            projected_hidden_states = F.cosine_similarity(hidden_states, layer_refusal_directions.unsqueeze(0).unsqueeze(0), dim=-1)

            # Scale the refusal direction by that similarity and subtract it from the residual stream.
            steering_vector = projected_hidden_states.unsqueeze(-1) * layer_refusal_directions
            hidden_states = hidden_states - steering_vector
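            # Note: F.cosine_similarity divides by both vector norms, so the subtracted
            # term is the orthogonal projection onto the refusal direction scaled by
            # ||direction|| / ||hidden_states||, not a full projection removal; the
            # ablation strength therefore varies with the activation norm.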

        return super().forward(hidden_states, *args, **kwargs)


class AbliteratedQwen3MoeForCausalLM(Qwen2MoeForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        self.refusal_directions = None
        try:
            # The refusal directions are expected to live next to the model weights in
            # the directory the checkpoint was loaded from.
            refusal_directions_path = os.path.join(config._name_or_path, 'final_refusal_dirs.pt')
            if os.path.exists(refusal_directions_path):
                self.refusal_directions = torch.load(refusal_directions_path, map_location="cpu")
                logger.info("Successfully loaded 'final_refusal_dirs.pt' for model abliteration.")
            else:
                logger.warning("'final_refusal_dirs.pt' not found. Model will not be abliterated.")
                return
        except Exception as e:
            logger.error(f"Failed to load 'final_refusal_dirs.pt'. Model will not be abliterated. Error: {e}")
            return

        logger.info("Patching model with AbliteratedDecoderLayer.")
        for i in range(len(self.model.layers)):
            old_layer = self.model.layers[i]
            # Rebuild each decoder layer as an AbliteratedDecoderLayer and copy its
            # weights over. The model-level config and the loop index are used here
            # because the stock Qwen2MoeDecoderLayer may not expose .config or
            # .layer_idx attributes.
            new_layer = AbliteratedDecoderLayer(config, i)
            new_layer.load_state_dict(old_layer.state_dict())
            self.model.layers[i] = new_layer
        logger.info("Model patching complete.")

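    # The stock Qwen2MoeModel.forward has no way to pass refusal_directions to the
    # decoder layers, so the sub-model's forward is swapped out lazily on the first
    # call to this method.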
    def forward(self, *args, **kwargs):
        # Patch self.model.forward once, and only when refusal directions were
        # actually loaded (otherwise the layers were never replaced and would not
        # accept the extra refusal_directions argument).
        if self.refusal_directions is not None and not hasattr(self.model, '_forward_patched'):

            def new_model_forward(*f_args, **f_kwargs):
                # Simplified re-implementation of Qwen2MoeModel.forward that threads
                # the refusal directions into every decoder layer. Note that the stock
                # forward additionally prepares the causal attention mask and position
                # information; here the caller's kwargs are forwarded unchanged.
                hidden_states = f_kwargs.get('inputs_embeds')
                if hidden_states is None:
                    hidden_states = self.model.embed_tokens(f_kwargs.get('input_ids'))

                for decoder_layer in self.model.layers:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        refusal_directions=self.refusal_directions,
                        attention_mask=f_kwargs.get('attention_mask'),
                        position_ids=f_kwargs.get('position_ids'),
                        past_key_value=f_kwargs.get('past_key_values'),
                        output_attentions=f_kwargs.get('output_attentions'),
                        output_router_logits=f_kwargs.get('output_router_logits'),
                        use_cache=f_kwargs.get('use_cache'),
                    )
                    hidden_states = layer_outputs[0]

                hidden_states = self.model.norm(hidden_states)
                # Return a MoeModelOutputWithPast (rather than a bare tuple) so that
                # Qwen2MoeForCausalLM.forward can read attribute fields such as
                # past_key_values and router_logits from it.
                return MoeModelOutputWithPast(
                    last_hidden_state=hidden_states,
                    past_key_values=f_kwargs.get('past_key_values'),
                )

            self.model.forward = types.MethodType(new_model_forward, self)
            self.model._forward_patched = True

        return super().forward(*args, **kwargs)
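

# Example usage sketch (an assumption about how this custom modeling file is shipped):
# the checkpoint's config.json would map its architecture to
# AbliteratedQwen3MoeForCausalLM via auto_map, so loading with trust_remote_code
# picks up this class together with the refusal-directions file:
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)
#
# 'final_refusal_dirs.pt' must be present in checkpoint_dir for the ablation to apply.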