import torch
from torch import nn
import torch.nn.functional as F
import types
import os
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
    Qwen2MoeDecoderLayer,
    Qwen2MoeForCausalLM,
    Qwen2MoeModel,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


# This custom layer contains the core "abliteration" logic: it removes the
# component of the hidden states that lies along a per-layer refusal direction.
class AbliteratedDecoderLayer(Qwen2MoeDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # Store the index explicitly; the base layer does not reliably expose it.
        self.layer_idx = layer_idx

    def forward(self, hidden_states, *args, refusal_directions=None, **kwargs):
        if refusal_directions is not None and self.layer_idx in refusal_directions:
            # Move the refusal direction to the right device and dtype, then normalize it.
            direction = refusal_directions[self.layer_idx].to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            direction = direction / direction.norm()
            # Scalar projection of each hidden state onto the refusal direction.
            projection = torch.matmul(hidden_states, direction)
            # The steering vector is the component of the hidden states along that direction.
            steering_vector = projection.unsqueeze(-1) * direction
            # Subtract (ablate) the refusal component.
            hidden_states = hidden_states - steering_vector
        # Call the original forward pass of the layer.
        return super().forward(hidden_states, *args, **kwargs)


# This custom model class automatically patches itself upon loading.
class AbliteratedQwen3MoeForCausalLM(Qwen2MoeForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.refusal_directions = None
        try:
            # When the checkpoint is a local directory, config._name_or_path points to it.
            refusal_directions_path = os.path.join(config._name_or_path, 'final_refusal_dirs.pt')
            if os.path.exists(refusal_directions_path):
                self.refusal_directions = torch.load(refusal_directions_path, map_location="cpu")
                logger.info("Successfully loaded 'final_refusal_dirs.pt' for model abliteration.")
            else:
                logger.warning("'final_refusal_dirs.pt' not found. Model will not be abliterated.")
                return
        except Exception as e:
            logger.error(f"Failed to load 'final_refusal_dirs.pt'. Model will not be abliterated. Error: {e}")
            return

        # Patch the model by swapping in abliterated decoder layers.
        logger.info("Patching model with AbliteratedDecoderLayer.")
        for i, old_layer in enumerate(self.model.layers):
            # Rebuild the layer from the model config and its index.
            new_layer = AbliteratedDecoderLayer(self.config, i)
            # Copy all weights and buffers from the old layer; the parameter keys are
            # identical, so pretrained weights loaded afterwards still map correctly.
            new_layer.load_state_dict(old_layer.state_dict())
            self.model.layers[i] = new_layer
        logger.info("Model patching complete.")

    def forward(self, *args, **kwargs):
        # The replaced layers accept an extra `refusal_directions` argument, but the
        # decoder-layer loop inside `Qwen2MoeModel.forward` does not know to pass it.
        # The most robust way is to monkey-patch `self.model.forward` once.
        if not hasattr(self.model, '_forward_patched'):
            refusal_directions = self.refusal_directions

            def new_model_forward(model, *f_args, **f_kwargs):
                # A simplified re-implementation of `Qwen2MoeModel.forward` that injects
                # `refusal_directions` into every decoder-layer call. It omits the cache,
                # attention-mask, and rotary-embedding handling of the stock implementation,
                # so it is only intended for simple forward passes.
                hidden_states = f_kwargs.get('inputs_embeds')
                if hidden_states is None:
                    hidden_states = model.embed_tokens(f_kwargs.get('input_ids'))
                for decoder_layer in model.layers:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        refusal_directions=refusal_directions,
                        attention_mask=f_kwargs.get('attention_mask'),
                        position_ids=f_kwargs.get('position_ids'),
                        past_key_value=f_kwargs.get('past_key_values'),
                        output_attentions=f_kwargs.get('output_attentions'),
                        output_router_logits=f_kwargs.get('output_router_logits'),
                        use_cache=f_kwargs.get('use_cache'),
                    )
                    hidden_states = layer_outputs[0]
                hidden_states = model.norm(hidden_states)
                # Return a proper model output object, as `Qwen2MoeForCausalLM.forward` expects.
                return MoeModelOutputWithPast(last_hidden_state=hidden_states)

            # Bind the patched forward to the inner `Qwen2MoeModel` instance.
            self.model.forward = types.MethodType(new_model_forward, self.model)
            self.model._forward_patched = True
        return super().forward(*args, **kwargs)
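

# --- Usage sketch (illustrative, not part of the modeling code) -------------
# A minimal example of exercising the patched model with a single forward pass,
# assuming a local checkpoint directory that contains the weights, this modeling
# file, and 'final_refusal_dirs.pt'. The directory path below is a placeholder.
# A plain forward pass is used because the simplified patched loop above skips
# cache and attention-mask handling.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint_dir = "/path/to/abliterated-qwen-moe-checkpoint"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    model = AbliteratedQwen3MoeForCausalLM.from_pretrained(checkpoint_dir, torch_dtype=torch.bfloat16)
    model.eval()

    inputs = tokenizer("Hello, how are you?", return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    # Greedily pick just the next token from the final position's logits.
    next_token_id = outputs.logits[0, -1].argmax().item()
    print(tokenizer.decode([next_token_id]))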