# Qwen3-30B-A3B-abliterated-FP8-dynamic / modeling_qwen3_moe.py
import os

import torch
import torch.nn.functional as F
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeForCausalLM
from transformers.utils import logging

logger = logging.get_logger(__name__)

# This custom layer contains the core "abliterated" logic: it removes a component of the
# hidden states along a per-layer "refusal direction" before running the standard
# decoder-layer forward pass.
class AbliteratedDecoderLayer(Qwen2MoeDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # Keep the layer index so forward() can look up the matching refusal direction.
        self.layer_idx = layer_idx

    def forward(self, hidden_states, refusal_directions=None, *args, **kwargs):
        if refusal_directions is not None and self.layer_idx in refusal_directions:
            # Move this layer's refusal direction to the hidden states' device and dtype.
            layer_refusal_direction = refusal_directions[self.layer_idx].to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            # Cosine similarity between each token's hidden state and the refusal direction
            # gives the per-token steering coefficient.
            steering_coefficients = F.cosine_similarity(
                hidden_states, layer_refusal_direction.unsqueeze(0).unsqueeze(0), dim=-1
            )
            # Scale the refusal direction by the coefficients to get the steering vectors.
            steering_vector = steering_coefficients.unsqueeze(-1) * layer_refusal_direction
            # Subtract the steering vectors from the hidden states.
            hidden_states = hidden_states - steering_vector
        # Call the original forward pass of the layer.
        return super().forward(hidden_states, *args, **kwargs)
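

# Illustrative sketch (not used by the model): per token, the steering step above computes
#     h' = h - cos(h, r) * r
# where r is the stored refusal direction. Note the coefficient is the cosine similarity rather
# than the raw projection h . r_hat of classic projection removal (h - (h . r_hat) r_hat), so
# with a unit-norm r the subtracted component is additionally scaled by 1 / ||h||. The helper
# below reproduces the operation on a toy tensor and is safe to delete.
def _demo_steering_subtraction():
    hidden = torch.randn(1, 4, 8)             # (batch, seq_len, hidden_size)
    direction = torch.randn(8)
    direction = direction / direction.norm()  # assume a unit-norm refusal direction
    coefficients = F.cosine_similarity(hidden, direction.view(1, 1, -1), dim=-1)
    return hidden - coefficients.unsqueeze(-1) * direction
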
# This custom model class will automatically patch itself upon loading.
class AbliteratedQwen3MoeForCausalLM(Qwen2MoeForCausalLM):
def __init__(self, config):
super().__init__(config)
self.refusal_directions = None
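        # self.refusal_directions is expected to hold a mapping from decoder-layer index to a
        # refusal-direction tensor of shape (hidden_size,), as saved in 'final_refusal_dirs.pt'.
        # (This format is inferred from how the layers index into it, not from the file itself.)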
try:
            # config._name_or_path is the local model directory; when loading by Hub repo id
            # alone, the file will not be found on disk and abliteration is simply skipped.
            refusal_directions_path = os.path.join(config._name_or_path, 'final_refusal_dirs.pt')
if os.path.exists(refusal_directions_path):
self.refusal_directions = torch.load(refusal_directions_path, map_location="cpu")
logger.info("Successfully loaded 'final_refusal_dirs.pt' for model abliteration.")
else:
logger.warning("'final_refusal_dirs.pt' not found. Model will not be abliterated.")
return
except Exception as e:
logger.error(f"Failed to load 'final_refusal_dirs.pt'. Model will not be abliterated. Error: {e}")
return
        # Patch the model by swapping the stock decoder layers for abliterated ones.
        logger.info("Patching model with AbliteratedDecoderLayer.")
        for i in range(len(self.model.layers)):
            old_layer = self.model.layers[i]
            # Rebuild the layer from the model config and its position in the stack
            # (the stock layer does not reliably expose .config / .layer_idx attributes).
            new_layer = AbliteratedDecoderLayer(config, i)
            # Copy all weights and buffers from the old layer.
            new_layer.load_state_dict(old_layer.state_dict())
            self.model.layers[i] = new_layer
        logger.info("Model patching complete.")

    def forward(self, *args, **kwargs):
        # The stock Qwen2MoeModel.forward does not accept refusal_directions, so its layer loop
        # is monkey-patched to pass the extra argument to each (now abliterated) decoder layer.
        # The patch is applied lazily, once, by shadowing self.model.forward with a closure.
        if self.refusal_directions is not None and not getattr(self.model, '_forward_patched', False):

            def new_model_forward(*f_args, **f_kwargs):
                # Simplified re-implementation of Qwen2MoeModel.forward: embed the inputs, run
                # the decoder layers (injecting refusal_directions), and apply the final norm.
                # Cache handling, causal-mask preparation and the optional hidden-state /
                # attention outputs of the original are not reproduced here.
                hidden_states = f_kwargs.get('inputs_embeds')
                if hidden_states is None:
                    hidden_states = self.model.embed_tokens(f_kwargs.get('input_ids'))
                for decoder_layer in self.model.layers:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        refusal_directions=self.refusal_directions,
                        attention_mask=f_kwargs.get('attention_mask'),
                        position_ids=f_kwargs.get('position_ids'),
                        past_key_value=f_kwargs.get('past_key_values'),
                        output_attentions=f_kwargs.get('output_attentions'),
                        output_router_logits=f_kwargs.get('output_router_logits'),
                        use_cache=f_kwargs.get('use_cache'),
                    )
                    hidden_states = layer_outputs[0]
                hidden_states = self.model.norm(hidden_states)
                # Return a ModelOutput so Qwen2MoeForCausalLM.forward can read the optional
                # fields (past_key_values, attentions, router_logits, ...) as attributes.
                return MoeModelOutputWithPast(last_hidden_state=hidden_states)

            # Assigning the closure as an instance attribute shadows the class-level forward,
            # so nn.Module.__call__ picks it up on subsequent calls.
            self.model.forward = new_model_forward
            self.model._forward_patched = True
        return super().forward(*args, **kwargs)
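

# Minimal usage sketch (an illustration, not part of the model's required code): this module is
# presumably meant to be loaded via trust_remote_code, with the repo's config.json auto_map
# pointing AutoModelForCausalLM at AbliteratedQwen3MoeForCausalLM. The repo path below is a
# placeholder.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_path = "path/to/Qwen3-30B-A3B-abliterated-FP8-dynamic"  # placeholder: local dir or repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_path)
    model = AutoModelForCausalLM.from_pretrained(repo_path, trust_remote_code=True, torch_dtype="auto")
    inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)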