vox committed on
Commit 0760dd0 · 1 Parent(s): f4535a1

Add custom modeling code for abliteration

Files changed (1)
  1. modeling_qwen3_moe.py +174 -0
modeling_qwen3_moe.py ADDED
@@ -0,0 +1,174 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer
+ import torch
+ import torch.nn as nn
+ import os
+ import signal
+ from typing import Optional, Tuple
+ import einops
+ import jaxtyping
+
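+ # Cap the BLAS/OpenMP and PyTorch thread pools at half of the available cores so
+ # the CPU-offloaded parts of the quantized model do not oversubscribe the machine.
+ # (Note: MKL/OMP read these environment variables when they initialize, so setting
+ # them after importing torch may not always take effect.)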
+ cpu_count = os.cpu_count()
+ print(f"Number of CPU cores in the system: {cpu_count}")
+ half_cpu_count = cpu_count // 2
+ os.environ["MKL_NUM_THREADS"] = str(half_cpu_count)
+ os.environ["OMP_NUM_THREADS"] = str(half_cpu_count)
+ torch.set_num_threads(half_cpu_count)
+
+ print(f"PyTorch threads: {torch.get_num_threads()}")
+ print(f"MKL threads: {os.getenv('MKL_NUM_THREADS')}")
+ print(f"OMP threads: {os.getenv('OMP_NUM_THREADS')}")
+
+ # Load the model and tokenizer
+ MODEL_ID = "Qwen/Qwen3-30B-A3B"
+ print(f"Load Model {MODEL_ID} ... ")
+ quant_config_4 = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+     llm_int8_enable_fp32_cpu_offload=True,
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     device_map="auto",
+     trust_remote_code=True,
+     quantization_config=quant_config_4,
+     torch_dtype=torch.bfloat16
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ messages = []
+ enable_thinking = True
+ skip_prompt = True
+ skip_special_tokens = True
+
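+ # Project the refusal direction out of an activation:
+ #     ablated = activation - (activation . direction) * direction
+ # This assumes `direction` is (roughly) unit-norm, the usual convention for
+ # refusal directions obtained from mean activation differences.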
+ def direction_ablation_hook(activation: jaxtyping.Float[torch.Tensor, "... d_act"],
+                             direction: jaxtyping.Float[torch.Tensor, "d_act"]):
+     proj = einops.einsum(activation, direction.view(-1, 1), '... d_act, d_act single -> ... single') * direction
+     return activation - proj
+
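+ # Thin wrapper that ablates the refusal direction from the hidden states
+ # before delegating to the original decoder layer.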
+ class AblationDecoderLayer(nn.Module):
+     def __init__(self, original_layer, refusal_dir):
+         super(AblationDecoderLayer, self).__init__()
+         self.original_layer = original_layer
+         self.refusal_dir = refusal_dir
+
+     def forward(self, *args, **kwargs):
+         hidden_states = args[0]
+         ablated = direction_ablation_hook(hidden_states, self.refusal_dir.to(hidden_states.device))
+         args = (ablated,) + args[1:]
+         return self.original_layer.forward(*args, **kwargs)
+
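+ # TextStreamer that accumulates the generated text and can abort generation by
+ # raising StopIteration once stop_generation() has been called (see the SIGINT
+ # handler in generate_stream below).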
+ class CustomTextStreamer(TextStreamer):
+     def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
+         super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
+         self.generated_text = ""
+         self.stop_flag = False
+
+     def on_finalized_text(self, text: str, stream_end: bool = False):
+         self.generated_text += text
+         print(text, end="", flush=True)
+         if self.stop_flag:
+             raise StopIteration
+
+     def stop_generation(self):
+         self.stop_flag = True
+
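+ # Apply the chat template, stream the completion to stdout, and return the text
+ # along with a flag indicating whether the user interrupted it. SIGINT is
+ # temporarily redirected so Ctrl+C stops the current generation instead of
+ # killing the process.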
+ def generate_stream(model, tokenizer, messages, enable_thinking, skip_prompt, skip_special_tokens, max_new_tokens):
+     input_ids = tokenizer.apply_chat_template(
+         messages,
+         tokenize=True,
+         enable_thinking=enable_thinking,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     )
+     attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+     tokens = input_ids.to(model.device)
+     attention_mask = attention_mask.to(model.device)
+
+     streamer = CustomTextStreamer(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
+
+     def signal_handler(sig, frame):
+         streamer.stop_generation()
+         print("\n[Generation stopped by user with Ctrl+C]")
+
+     signal.signal(signal.SIGINT, signal_handler)
+
+     print("Response: ", end="", flush=True)
+     try:
+         generated_ids = model.generate(
+             tokens,
+             attention_mask=attention_mask,
+             use_cache=False,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             streamer=streamer
+         )
+         del generated_ids
+     except StopIteration:
+         print("\n[Stopped by user]")
+
+     del input_ids, attention_mask
+     torch.cuda.empty_cache()
+     signal.signal(signal.SIGINT, signal.SIG_DFL)
+
+     return streamer.generated_text, streamer.stop_flag
+
+
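+ # Load the precomputed refusal directions (indexed by layer) and pick one
+ # layer's direction to ablate with. The path assumes final_refusal_dirs.pt was
+ # saved into a local directory named after MODEL_ID.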
+ final_refusal_dirs = torch.load(MODEL_ID + "/final_refusal_dirs.pt", map_location='cpu', weights_only=True)
+ # Candidate layers to try: 16, 21, ...
+ candidate_layer = 16
+
+ refusal_dir = final_refusal_dirs[candidate_layer]
+
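+ # Wrap every decoder layer so the chosen refusal direction is removed from the
+ # hidden states at each layer during inference.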
+ for idx in range(len(model.model.layers)):
+     model.model.layers[idx] = AblationDecoderLayer(model.model.layers[idx], refusal_dir)
+
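+ # Interactive chat loop. Commands: /exit, /clear, /no_think, /skip_prompt,
+ # /skip_special_tokens.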
+ while True:
+     user_input = input("User: ").strip()
+     if user_input.lower() == "/exit":
+         print("Exiting chat.")
+         break
+     if user_input.lower() == "/clear":
+         messages = []
+         print("Chat history cleared. Starting a new conversation.")
+         continue
+     if user_input.lower() == "/no_think":
+         if enable_thinking:
+             enable_thinking = False
+             print("Thinking = False.")
+         else:
+             enable_thinking = True
+             print("Thinking = True.")
+         continue
+     if user_input.lower() == "/skip_prompt":
+         if skip_prompt:
+             skip_prompt = False
+             print("skip_prompt = False.")
+         else:
+             skip_prompt = True
+             print("skip_prompt = True.")
+         continue
+     if user_input.lower() == "/skip_special_tokens":
+         if skip_special_tokens:
+             skip_special_tokens = False
+             print("skip_special_tokens = False.")
+         else:
+             skip_special_tokens = True
+             print("skip_special_tokens = True.")
+         continue
+     if not user_input:
+         print("Input cannot be empty. Please enter something.")
+         continue
+     messages.append({"role": "user", "content": user_input})
+     response, stop_flag = generate_stream(model, tokenizer, messages, enable_thinking, skip_prompt, skip_special_tokens, 8192)
+     print("", flush=True)
+     if stop_flag:
+         continue
+     messages.append({"role": "assistant", "content": response})