vox committed
Commit · ab817e7 · 1 Parent(s): 0760dd0

abliteration
Browse files
- .DS_Store              +0 -0
- final_refusal_dirs.pt  +3 -0
- modeling_qwen3_moe.py  +120 -170
.DS_Store
ADDED
Binary file (6.15 kB)
final_refusal_dirs.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c00500a2055f8669d27bc4a6c26296eb484396d9bd8868f77e31fd8b77196c8
+size 210282
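
Those three lines are a Git LFS pointer, not the tensor itself. A minimal sketch of fetching and inspecting the real file (the repo id is a placeholder; with a plain `git clone`, running `git lfs pull` first works just as well):

import torch
from huggingface_hub import hf_hub_download

# Placeholder repo id -- substitute the actual model repo. This downloads the
# real 210 kB tensor file that the pointer above stands in for.
path = hf_hub_download("vox/some-repo", "final_refusal_dirs.pt")
refusal_dirs = torch.load(path, map_location="cpu", weights_only=True)
# The removed script below indexes this object by layer (e.g. refusal_dirs[16]),
# so expect one refusal-direction tensor per candidate layer.
print(len(refusal_dirs))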
modeling_qwen3_moe.py
CHANGED
@@ -1,174 +1,124 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer
 import torch
-… (removed line 3 is collapsed in the diff view)
 import os
-… (removed line 5 is collapsed)
-from …
-import …
-… (removed lines 8-36 are collapsed)
-)
-… (removed lines 38-57 are collapsed; line 58 begins "self.")
 
     def forward(self, *args, **kwargs):
-… (removed lines 61-116 are collapsed)
-    torch.cuda.empty_cache()
-    signal.signal(signal.SIGINT, signal.SIG_DFL)
-
-    return streamer.generated_text, streamer.stop_flag
-… (removed lines 121-123 are collapsed)
-final_refusal_dirs = torch.load(MODEL_ID + "/final_refusal_dirs.pt", map_location='cpu', weights_only=True)
-# candidate layer, 16, 21 ...
-candidate_layer = 16
-
-refusal_dir = final_refusal_dirs[candidate_layer]
-
-for idx in range(len(model.model.layers)):
-    model.model.layers[idx] = AblationDecoderLayer(model.model.layers[idx], refusal_dir)
-
-while True:
-    user_input = input("User: ").strip()
-    if user_input.lower() == "/exit":
-        print("Exiting chat.")
-        break
-    if user_input.lower() == "/clear":
-        messages = []
-        print("Chat history cleared. Starting a new conversation.")
-        continue
-    if user_input.lower() == "/no_think":
-        if enable_thinking:
-            enable_thinking = False
-            print("Thinking = False.")
-        else:
-            enable_thinking = True
-            print("Thinking = True.")
-        continue
-    if user_input.lower() == "/skip_prompt":
-        if skip_prompt:
-            skip_prompt = False
-            print("skip_prompt = False.")
-        else:
-            skip_prompt = True
-            print("skip_prompt = True.")
-        continue
-    if user_input.lower() == "/skip_special_tokens":
-        if skip_special_tokens:
-            skip_special_tokens = False
-            print("skip_special_tokens = False.")
-        else:
-            skip_special_tokens = True
-            print("skip_special_tokens = True.")
-        continue
-    if not user_input:
-        print("Input cannot be empty. Please enter something.")
-        continue
-    messages.append({"role": "user", "content": user_input})
-    response, stop_flag = generate_stream(model, tokenizer, messages, enable_thinking, skip_prompt, skip_special_tokens, 8192)
-    print("", flush=True)
-    if stop_flag:
-        continue
-    messages.append({"role": "assistant", "content": response})
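
The removed script wrapped every decoder layer in an AblationDecoderLayer, whose definition is collapsed out of this diff. A hypothetical reconstruction of such a wrapper, assuming the same cosine-similarity steering that the new AbliteratedDecoderLayer (below) applies:

import torch.nn.functional as F
from torch import nn

# Hypothetical reconstruction -- the committed AblationDecoderLayer body is
# not visible in this diff. It wraps an existing decoder layer and subtracts
# a steering vector from the hidden states before delegating to it.
class AblationDecoderLayer(nn.Module):
    def __init__(self, original_layer, refusal_dir):
        super().__init__()
        self.original_layer = original_layer
        self.refusal_dir = refusal_dir

    def forward(self, hidden_states, *args, **kwargs):
        direction = self.refusal_dir.to(hidden_states.device)
        coeff = F.cosine_similarity(hidden_states, direction.view(1, 1, -1), dim=-1)
        hidden_states = hidden_states - coeff.unsqueeze(-1) * direction
        return self.original_layer(hidden_states, *args, **kwargs)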
 import torch
+from torch import nn
+import torch.nn.functional as F
+import types
 import os
+
+from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeForCausalLM, Qwen3MoeModel
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+# This custom layer contains the core "abliteration" logic: it subtracts a
+# steering vector, built from a per-layer refusal direction, from the hidden
+# states before running the stock decoder layer.
+class AbliteratedDecoderLayer(Qwen3MoeDecoderLayer):
+    def __init__(self, config, layer_idx):
+        super().__init__(config, layer_idx)
+        self.layer_idx = layer_idx  # kept for the refusal-direction lookup below
+
+    def forward(self, hidden_states, refusal_directions=None, *args, **kwargs):
+        if refusal_directions is not None and self.layer_idx in refusal_directions:
+            # Move this layer's refusal direction to the activations' device
+            layer_refusal_direction = refusal_directions[self.layer_idx].to(hidden_states.device)
+
+            # Coefficient of each hidden state along the refusal direction:
+            # cosine similarity, i.e. the projection coefficient normalized
+            # by the hidden-state norm; shape (batch, seq_len)
+            coefficients = F.cosine_similarity(hidden_states, layer_refusal_direction.unsqueeze(0).unsqueeze(0), dim=-1)
+
+            # Build and subtract the steering vector
+            steering_vector = coefficients.unsqueeze(-1) * layer_refusal_direction
+            hidden_states = hidden_states - steering_vector
+
+        # Run the original forward pass of the layer
+        return super().forward(hidden_states, *args, **kwargs)
+
+
+# This custom model class automatically patches itself upon loading.
+class AbliteratedQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.refusal_directions = None
+        try:
+            # In a Hugging Face model repo, config._name_or_path is the repo path
+            refusal_directions_path = os.path.join(config._name_or_path, 'final_refusal_dirs.pt')
+            if os.path.exists(refusal_directions_path):
+                self.refusal_directions = torch.load(refusal_directions_path, map_location="cpu")
+                logger.info("Successfully loaded 'final_refusal_dirs.pt' for model abliteration.")
+            else:
+                logger.warning("'final_refusal_dirs.pt' not found. Model will not be abliterated.")
+                return
+        except Exception as e:
+            logger.error(f"Failed to load 'final_refusal_dirs.pt'. Model will not be abliterated. Error: {e}")
+            return
+
+        # Patch the model by swapping in AbliteratedDecoderLayer copies
+        logger.info("Patching model with AbliteratedDecoderLayer.")
+        for i in range(len(self.model.layers)):
+            old_layer = self.model.layers[i]
+            # Rebuild the layer from the model config and its index ...
+            new_layer = AbliteratedDecoderLayer(config, i)
+            # ... then copy all weights and buffers from the old layer
+            new_layer.load_state_dict(old_layer.state_dict())
+            self.model.layers[i] = new_layer
+        logger.info("Model patching complete.")
 
     def forward(self, *args, **kwargs):
+        # The stock Qwen3MoeModel.forward loops over the decoder layers but
+        # knows nothing about our extra `refusal_directions` argument, so its
+        # forward method is monkey-patched once, on the first call, with a
+        # simplified loop that passes the directions through.
+        if not hasattr(self.model, '_forward_patched'):
+
+            def new_model_forward(model_self, *f_args, **f_kwargs):
+                # Simplified version of the original Qwen3MoeModel.forward:
+                # embed the inputs, run each decoder layer with the refusal
+                # directions injected, then apply the final norm.
+                hidden_states = f_kwargs.get('inputs_embeds')
+                if hidden_states is None:
+                    hidden_states = model_self.embed_tokens(f_kwargs.get('input_ids'))
+
+                for decoder_layer in model_self.layers:
+                    layer_outputs = decoder_layer(
+                        hidden_states,
+                        refusal_directions=self.refusal_directions,
+                        attention_mask=f_kwargs.get('attention_mask'),
+                        position_ids=f_kwargs.get('position_ids'),
+                        past_key_value=f_kwargs.get('past_key_values'),
+                        output_attentions=f_kwargs.get('output_attentions'),
+                        output_router_logits=f_kwargs.get('output_router_logits'),
+                        use_cache=f_kwargs.get('use_cache'),
+                    )
+                    hidden_states = layer_outputs[0]
+
+                hidden_states = model_self.norm(hidden_states)
+                return (hidden_states,)  # tuple, as the causal-LM head expects
+
+            # Bind to the inner model (not `self`) so `model_self` is the
+            # Qwen3MoeModel instance when transformers calls self.model(...)
+            self.model.forward = types.MethodType(new_model_forward, self.model)
+            self.model._forward_patched = True
+
+        return super().forward(*args, **kwargs)
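
A note on the steering math in AbliteratedDecoderLayer: the usual directional-ablation recipe removes the orthogonal projection onto a unit-norm refusal direction r, i.e. h' = h - (h·r)r, whereas this commit uses the cosine similarity as the coefficient, which is the same dot product divided by the hidden-state norm. A minimal sketch contrasting the two (assumes a unit-norm direction; illustration only, not the committed code):

import torch
import torch.nn.functional as F

h = torch.randn(1, 4, 2048)                  # (batch, seq, hidden) activations
r = F.normalize(torch.randn(2048), dim=-1)   # unit-norm refusal direction

# Standard directional ablation: subtract the orthogonal projection onto r
proj_coeff = (h * r).sum(dim=-1, keepdim=True)   # h·r, shape (1, 4, 1)
h_ablated = h - proj_coeff * r

# This commit's variant: cosine similarity as the coefficient, i.e. the same
# dot product divided by the hidden-state norm, giving a norm-scaled edit
cos_coeff = F.cosine_similarity(h, r.view(1, 1, -1), dim=-1).unsqueeze(-1)
h_steered = h - cos_coeff * r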
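For completeness, a sketch of how a self-patching remote-code model like this is usually loaded (assumes the repo's config.json maps AutoModelForCausalLM to AbliteratedQwen3MoeForCausalLM via auto_map; the repo id is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id. trust_remote_code=True lets transformers import
# modeling_qwen3_moe.py from the repo, so __init__ loads
# final_refusal_dirs.pt and swaps in the abliterated layers automatically.
model_id = "vox/qwen3-moe-abliterated"  # hypothetical
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)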