import os
import types

import torch
import torch.nn.functional as F

from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeForCausalLM
from transformers.utils import logging


logger = logging.get_logger(__name__)

# This custom layer contains the core "abliteration" logic: it scales a per-layer refusal
# direction by its cosine similarity with each hidden state and subtracts the resulting
# "steering vector" before running the original decoder layer.
class AbliteratedDecoderLayer(Qwen2MoeDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # The stock decoder layer does not necessarily store its index, but the
        # per-layer lookup in forward() needs it.
        self.layer_idx = layer_idx

    def forward(self, hidden_states, *args, refusal_directions=None, **kwargs):
        if refusal_directions is not None and self.layer_idx in refusal_directions:
            # Match this layer's refusal direction to the hidden states' device and dtype.
            layer_refusal_directions = refusal_directions[self.layer_idx].to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )

            # Cosine similarity between each hidden state and the refusal direction,
            # shape (batch, seq_len).
            projected_hidden_states = F.cosine_similarity(
                hidden_states, layer_refusal_directions.unsqueeze(0).unsqueeze(0), dim=-1
            )

            # Scale the refusal direction by that similarity to get the steering vector.
            steering_vector = projected_hidden_states.unsqueeze(-1) * layer_refusal_directions

            # Remove the steering vector from the hidden states.
            hidden_states = hidden_states - steering_vector

        # Call the original forward pass of the layer.
        return super().forward(hidden_states, *args, **kwargs)

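# For orientation, the steering step above in isolation (the shapes and the snippet are an
# illustrative sketch, not code used by the model):
#
#     h = torch.randn(2, 5, 4096)                               # (batch, seq_len, hidden_size)
#     d = torch.randn(4096)                                     # one refusal direction
#     cos = F.cosine_similarity(h, d.view(1, 1, -1), dim=-1)    # (2, 5)
#     h_abliterated = h - cos.unsqueeze(-1) * d                 # same shape as h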

# This custom model class will automatically patch itself upon loading.
class AbliteratedQwen3MoeForCausalLM(Qwen2MoeForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
        self.refusal_directions = None
        try:
            # In a Hugging Face model repo, config._name_or_path is the repo path
            refusal_directions_path = os.path.join(config._name_or_path, 'final_refusal_dirs.pt')
            if os.path.exists(refusal_directions_path):
                self.refusal_directions = torch.load(refusal_directions_path, map_location="cpu")
                logger.info("Successfully loaded 'final_refusal_dirs.pt' for model abliteration.")
            else:
                logger.warning("'final_refusal_dirs.pt' not found. Model will not be abliterated.")
                return

        except Exception as e:
            logger.error(f"Failed to load 'final_refusal_dirs.pt'. Model will not be abliterated. Error: {e}")
            return

        # Patch the model by swapping the decoder layers.
        logger.info("Patching model with AbliteratedDecoderLayer.")
        for i in range(len(self.model.layers)):
            old_layer = self.model.layers[i]
            # The stock decoder layer keeps no reference to the config or its own index,
            # so build the replacement from the model config and the loop index.
            new_layer = AbliteratedDecoderLayer(config, i)
            # Copy all weights and buffers from the old layer.
            new_layer.load_state_dict(old_layer.state_dict())
            self.model.layers[i] = new_layer
        logger.info("Model patching complete.")

    def forward(self, *args, **kwargs):
        # The refusal directions must reach the decoder layers, but the layer loop lives
        # inside `self.model.forward` (the underlying Qwen2MoeModel), which knows nothing
        # about our extra argument. The most robust fix is to monkey-patch
        # `self.model.forward` once with a loop that passes `refusal_directions` along.
        if not hasattr(self.model, '_forward_patched'):

            def new_model_forward(model_self, *f_args, **f_kwargs):
                # Simplified re-implementation of Qwen2MoeModel.forward: it covers only the
                # embedding -> decoder-layer loop -> final norm path and leaves the
                # attention-mask preparation and cache bookkeeping to the layers themselves.
                hidden_states = f_kwargs.get('inputs_embeds')
                if hidden_states is None:
                    hidden_states = model_self.embed_tokens(f_kwargs.get('input_ids'))

                for decoder_layer in model_self.layers:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        refusal_directions=self.refusal_directions,
                        attention_mask=f_kwargs.get('attention_mask'),
                        position_ids=f_kwargs.get('position_ids'),
                        past_key_value=f_kwargs.get('past_key_values'),
                        output_attentions=f_kwargs.get('output_attentions'),
                        output_router_logits=f_kwargs.get('output_router_logits'),
                        use_cache=f_kwargs.get('use_cache'),
                    )
                    hidden_states = layer_outputs[0]

                hidden_states = model_self.norm(hidden_states)
                # Return a model output object so the causal-LM head can read
                # `last_hidden_state` (and the optional fields) as usual.
                return MoeModelOutputWithPast(last_hidden_state=hidden_states)

            # Bind the replacement to the inner Qwen2MoeModel instance and mark it as
            # patched so later forward calls do not re-wrap it.
            self.model.forward = types.MethodType(new_model_forward, self.model)
            self.model._forward_patched = True

        return super().forward(*args, **kwargs)
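

# A minimal usage sketch, assuming this file ships in a model repo alongside `config.json`
# (whose `architectures`/`auto_map` point at AbliteratedQwen3MoeForCausalLM) and
# `final_refusal_dirs.pt`. The repo id below is a placeholder, not a real model.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "your-namespace/your-abliterated-qwen-moe"  # hypothetical

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # trust_remote_code=True lets transformers import this module and use the custom class.
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    inputs = tokenizer("Hello, how are you?", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(logits.shape)  # (1, seq_len, vocab_size)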