vox committed on
Commit ab817e7 · 1 Parent(s): 0760dd0

abliteration

Files changed (3):
  1. .DS_Store +0 -0
  2. final_refusal_dirs.pt +3 -0
  3. modeling_qwen3_moe.py +120 -170
.DS_Store ADDED
Binary file (6.15 kB)
 
final_refusal_dirs.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1c00500a2055f8669d27bc4a6c26296eb484396d9bd8868f77e31fd8b77196c8
+ size 210282
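The three lines above are a Git LFS pointer: only the hash and size appear in the diff, while the roughly 210 KB tensor payload is fetched through LFS. Judging by how the scripts in this commit consume the file, it holds per-layer refusal directions indexed by decoder-layer id. A minimal inspection sketch, assuming the repo has been downloaded with its LFS objects resolved (the local path and the layer index are illustrative):

import torch

# Hypothetical local path; this needs the resolved LFS object, not the pointer file.
refusal_dirs = torch.load("final_refusal_dirs.pt", map_location="cpu", weights_only=True)

# Both the old and new modeling code index this object by decoder-layer id
# (e.g. layer 16) and treat each entry as a hidden-size direction vector.
print(type(refusal_dirs))
print(refusal_dirs[16].shape)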
modeling_qwen3_moe.py CHANGED
@@ -1,174 +1,124 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer
  import torch
- import torch.nn as nn
  import os
- import signal
- from typing import Optional, Tuple
- import einops
- import jaxtyping
-
- cpu_count = os.cpu_count()
- print(f"Number of CPU cores in the system: {cpu_count}")
- half_cpu_count = cpu_count // 2
- os.environ["MKL_NUM_THREADS"] = str(half_cpu_count)
- os.environ["OMP_NUM_THREADS"] = str(half_cpu_count)
- torch.set_num_threads(half_cpu_count)
-
- print(f"PyTorch threads: {torch.get_num_threads()}")
- print(f"MKL threads: {os.getenv('MKL_NUM_THREADS')}")
- print(f"OMP threads: {os.getenv('OMP_NUM_THREADS')}")
-
- # Load the model and tokenizer
- MODEL_ID = "Qwen/Qwen3-30B-A3B"
- print(f"Load Model {MODEL_ID} ... ")
- quant_config_4 = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_compute_dtype=torch.bfloat16,
-     bnb_4bit_use_double_quant=True,
-     llm_int8_enable_fp32_cpu_offload=True,
- )
-
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_ID,
-     device_map="auto",
-     trust_remote_code=True,
-     quantization_config=quant_config_4,
-     torch_dtype=torch.bfloat16
- )
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
- if tokenizer.pad_token is None:
-     tokenizer.pad_token = tokenizer.eos_token
-     tokenizer.pad_token_id = tokenizer.eos_token_id
-
- messages = []
- enable_thinking = True
- skip_prompt=True
- skip_special_tokens=True
-
- def direction_ablation_hook(activation: jaxtyping.Float[torch.Tensor, "... d_act"],
-                             direction: jaxtyping.Float[torch.Tensor, "d_act"]):
-     proj = einops.einsum(activation, direction.view(-1, 1), '... d_act, d_act single -> ... single') * direction
-     return activation - proj
-
- class AblationDecoderLayer(nn.Module):
-     def __init__(self, original_layer, refusal_dir):
-         super(AblationDecoderLayer, self).__init__()
-         self.original_layer = original_layer
-         self.refusal_dir = refusal_dir

      def forward(self, *args, **kwargs):
-         hidden_states = args[0]
-         ablated = direction_ablation_hook(hidden_states, self.refusal_dir.to(hidden_states.device)).to(hidden_states.device)
-         args = (ablated,) + args[1:]
-         return self.original_layer.forward(*args, **kwargs)
-
- class CustomTextStreamer(TextStreamer):
-     def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
-         super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
-         self.generated_text = ""
-         self.stop_flag = False
-
-     def on_finalized_text(self, text: str, stream_end: bool = False):
-         self.generated_text += text
-         print(text, end="", flush=True)
-         if self.stop_flag:
-             raise StopIteration
-
-     def stop_generation(self):
-         self.stop_flag = True
-
- def generate_stream(model, tokenizer, messages, enable_thinking, skip_prompt, skip_special_tokens, max_new_tokens):
-     input_ids = tokenizer.apply_chat_template(
-         messages,
-         tokenize=True,
-         enable_thinking = enable_thinking,
-         add_generation_prompt=True,
-         return_tensors="pt"
-     )
-     attention_mask = torch.ones_like(input_ids, dtype=torch.long)
-     tokens = input_ids.to(model.device)
-     attention_mask = attention_mask.to(model.device)
-
-     streamer = CustomTextStreamer(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
-
-     def signal_handler(sig, frame):
-         streamer.stop_generation()
-         print("\n[Generation stopped by user with Ctrl+C]")
-
-     signal.signal(signal.SIGINT, signal_handler)
-
-     print("Response: ", end="", flush=True)
-     try:
-         generated_ids = model.generate(
-             tokens,
-             attention_mask=attention_mask,
-             use_cache=False,
-             max_new_tokens=max_new_tokens,
-             do_sample=True,
-             pad_token_id=tokenizer.pad_token_id,
-             streamer=streamer
-         )
-         del generated_ids
-     except StopIteration:
-         print("\n[Stopped by user]")
-
-     del input_ids, attention_mask
-     torch.cuda.empty_cache()
-     signal.signal(signal.SIGINT, signal.SIG_DFL)
-
-     return streamer.generated_text, streamer.stop_flag
-
-
-
- final_refusal_dirs= torch.load(MODEL_ID + "/final_refusal_dirs.pt", map_location='cpu', weights_only=True)
- # candidate layer, 16, 21 ...
- candidate_layer = 16
-
- refusal_dir = final_refusal_dirs[candidate_layer]
-
- for idx in range(len(model.model.layers)):
-     model.model.layers[idx] = AblationDecoderLayer(model.model.layers[idx], refusal_dir)
-
- while True:
-     user_input = input("User: ").strip()
-     if user_input.lower() == "/exit":
-         print("Exiting chat.")
-         break
-     if user_input.lower() == "/clear":
-         messages = []
-         print("Chat history cleared. Starting a new conversation.")
-         continue
-     if user_input.lower() == "/no_think":
-         if enable_thinking:
-             enable_thinking = False
-             print("Thinking = False.")
-         else:
-             enable_thinking = True
-             print("Thinking = True.")
-         continue
-     if user_input.lower() == "/skip_prompt":
-         if skip_prompt:
-             skip_prompt = False
-             print("skip_prompt = False.")
-         else:
-             skip_prompt = True
-             print("skip_prompt = True.")
-         continue
-     if user_input.lower() == "/skip_special_tokens":
-         if skip_special_tokens:
-             skip_special_tokens = False
-             print("skip_special_tokens = False.")
-         else:
-             skip_special_tokens = True
-             print("skip_special_tokens = True.")
-         continue
-     if not user_input:
-         print("Input cannot be empty. Please enter something.")
-         continue
-     messages.append({"role": "user", "content": user_input})
-     response, stop_flag = generate_stream(model, tokenizer, messages, enable_thinking, skip_prompt, skip_special_tokens, 8192)
-     print("", flush=True)
-     if stop_flag:
-         continue
-     messages.append({"role": "assistant", "content": response})
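The removed script implements abliteration as a runtime wrapper: direction_ablation_hook subtracts from each hidden state its projection onto a single refusal direction, and AblationDecoderLayer applies that before every original decoder layer. A minimal sketch of the same projection in plain PyTorch, without the einops/jaxtyping dependencies (the function name is illustrative, and it assumes the stored direction is already unit-normalized, as the removed hook does implicitly):

import torch

def ablate_refusal_direction(activation: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
    # Scalar projection of every activation vector onto the refusal direction.
    coeff = torch.einsum("...d,d->...", activation, direction).unsqueeze(-1)
    # Remove the component along that direction, keeping the orthogonal part.
    return activation - coeff * direction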
 
 
  import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import types
  import os
+
+ from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeForCausalLM, Qwen2MoeModel
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ # This custom layer contains the core "abliterated" logic.
+ # It subtracts a "steering vector" from the hidden states.
+ class AbliteratedDecoderLayer(Qwen2MoeDecoderLayer):
+     def __init__(self, config, layer_idx):
+         super().__init__(config, layer_idx)
+
+     def forward(self, hidden_states, refusal_directions, *args, **kwargs):
+         if refusal_directions is not None and self.layer_idx in refusal_directions:
+             # Move refusal directions to the correct device
+             layer_refusal_directions = refusal_directions[self.layer_idx].to(hidden_states.device)
+
+             # Project hidden states into the direction of the refusal vector
+             projected_hidden_states = F.cosine_similarity(hidden_states, layer_refusal_directions.unsqueeze(0).unsqueeze(0), dim=-1)
+
+             # Get the steering vector
+             steering_vector = projected_hidden_states.unsqueeze(-1) * layer_refusal_directions
+
+             # Apply the steering vector
+             hidden_states = hidden_states - steering_vector
+
+         # Call the original forward pass of the layer
+         return super().forward(hidden_states, *args, **kwargs)
+
+
+ # This custom model class will automatically patch itself upon loading.
+ class AbliteratedQwen3MoeForCausalLM(Qwen2MoeForCausalLM):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.refusal_directions = None
+         try:
+             # In a Hugging Face model repo, config._name_or_path is the repo path
+             refusal_directions_path = os.path.join(config._name_or_path, 'final_refusal_dirs.pt')
+             if os.path.exists(refusal_directions_path):
+                 self.refusal_directions = torch.load(refusal_directions_path, map_location="cpu")
+                 logger.info("Successfully loaded 'final_refusal_dirs.pt' for model abliteration.")
+             else:
+                 logger.warning("'final_refusal_dirs.pt' not found. Model will not be abliterated.")
+                 return
+
+         except Exception as e:
+             logger.error(f"Failed to load 'final_refusal_dirs.pt'. Model will not be abliterated. Error: {e}")
+             return
+
+         # Patch the model by swapping the decoder layers
+         logger.info("Patching model with AbliteratedDecoderLayer.")
+         for i in range(len(self.model.layers)):
+             old_layer = self.model.layers[i]
+             # We need to pass the original config and layer_idx
+             new_layer = AbliteratedDecoderLayer(old_layer.config, old_layer.layer_idx)
+             # Copy all weights and buffers from the old layer
+             new_layer.load_state_dict(old_layer.state_dict())
+             self.model.layers[i] = new_layer
+         logger.info("Model patching complete.")

      def forward(self, *args, **kwargs):
+         # We need to correctly pass the refusal_directions to the layers.
+         # The layers are called inside self.model.forward.
+         # So we patch the forward method of the underlying Qwen2MoeModel instance.
+         original_model_forward = self.model.forward
+
+         def patched_forward(*model_args, **model_kwargs):
+             # The Qwen2MoeModel's forward method does not take our custom arg.
+             # We need a way to pass it down. We can temporarily attach it to the `self.model` object.
+
+             # The layers' forward methods were modified to accept refusal_directions
+             # But the loop in Qwen2MoeModel.forward doesn't know about it.
+             # The easiest way is to modify the loop itself.
+             # Let's override the `self.model.forward` method entirely
+
+             # To avoid re-patching on every call, we can do it once in __init__
+             # Let's move the forward patch to __init__
+             pass
+
+         # Since we replaced the layers, their `forward` methods are now different.
+         # We must modify the calling code in `self.model.forward` to pass the new argument.
+         # The most robust way is to monkey-patch `self.model.forward` once.
+         if not hasattr(self.model, '_forward_patched'):
+             original_forward = self.model.forward
+
+             def new_model_forward(*f_args, **f_kwargs):
+                 # The original `Qwen2MoeModel.forward` iterates through `self.layers`
+                 # and calls each `decoder_layer(...)`.
+                 # We need to inject `refusal_directions` into that call.
+                 # Let's redefine the entire `Qwen2MoeModel.forward` logic here
+                 # to ensure correctness.
+
+                 # This is a simplified version of the original source, modified for our purpose
+                 hidden_states = f_kwargs.get('inputs_embeds')
+                 if hidden_states is None:
+                     hidden_states = self.model.embed_tokens(f_kwargs.get('input_ids'))
+
+                 for decoder_layer in self.model.layers:
+                     layer_outputs = decoder_layer(
+                         hidden_states,
+                         refusal_directions=self.refusal_directions,
+                         attention_mask=f_kwargs.get('attention_mask'),
+                         position_ids=f_kwargs.get('position_ids'),
+                         past_key_value=f_kwargs.get('past_key_values'),
+                         output_attentions=f_kwargs.get('output_attentions'),
+                         output_router_logits=f_kwargs.get('output_router_logits'),
+                         use_cache=f_kwargs.get('use_cache'),
+                     )
+                     hidden_states = layer_outputs[0]
+
+                 hidden_states = self.model.norm(hidden_states)
+                 return (hidden_states,)  # Return in a tuple as the base class expects
+
+             self.model.forward = types.MethodType(new_model_forward, self)
+             self.model._forward_patched = True
+
+         return super().forward(*args, **kwargs)
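Because modeling_qwen3_moe.py ships inside the model repository, the intended entry point is presumably remote-code loading, as in the removed script. A minimal usage sketch, assuming the repo's config auto_map points AutoModelForCausalLM at the class above (the repo id and generation settings are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

REPO_ID = "vox/<this-repo>"  # hypothetical; substitute the actual repository id

# trust_remote_code=True imports modeling_qwen3_moe.py from the repo; on load,
# the class tries to read final_refusal_dirs.pt and patch its decoder layers.
tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    REPO_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

messages = [{"role": "user", "content": "Hello!"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))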