# ==============================================================================
# Smol-MoE 8x135M - "Chat with Your Creation"
# (Final Interactive Inference Script)
# ==============================================================================
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaMLP

# --- 1. Key step: redefine all of your custom modules ---
# These definitions are required to rebuild the custom architecture exactly
# before the trained weights can be loaded into it.
MODEL_PATH = "./SmolMoE-8x135M-Instruct-v1-Trained"

# Read the MoE hyperparameters back from the saved model config
config = AutoConfig.from_pretrained(MODEL_PATH)
NUM_EXPERTS = config.moe_num_experts
TOP_K = config.moe_top_k


class MoERouter(nn.Module):
    """Linear gate that scores each token against every expert."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.layer = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        return self.layer(hidden_states)


class MoEModule(nn.Module):
    """Drop-in replacement for LlamaMLP: routes each token to its top-k experts."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.router = MoERouter(self.hidden_size, self.num_experts)
        self.experts = nn.ModuleList([LlamaMLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states):
        original_shape = hidden_states.shape
        flat_hidden_states = hidden_states.view(-1, self.hidden_size)

        # Score every token against every expert, keep the top-k experts,
        # and renormalize those k weights so they sum to 1 per token.
        router_logits = self.router(flat_hidden_states)
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Accumulate each expert's weighted output back into its tokens' slots
        final_hidden_states = torch.zeros_like(flat_hidden_states)
        for k in range(self.top_k):
            expert_indices_k = selected_experts[:, k]
            routing_weights_k = routing_weights[:, k]
            for i in range(self.num_experts):
                mask = expert_indices_k == i
                if mask.any():
                    expert_output = self.experts[i](flat_hidden_states[mask])
                    final_hidden_states.index_add_(
                        0,
                        torch.where(mask)[0],
                        expert_output * routing_weights_k[mask].unsqueeze(1),
                    )
        return final_hidden_states.view(*original_shape)
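
# ------------------------------------------------------------------------------
# Optional sanity check -- a minimal sketch added for illustration, not part of
# the original pipeline. It pushes a random tensor through a freshly built
# (untrained) MoEModule to verify that top-k routing preserves the input shape
# before any real weights are loaded. The MOE_SANITY_CHECK environment variable
# that guards it is hypothetical.
if os.environ.get("MOE_SANITY_CHECK"):
    _dummy = torch.randn(2, 4, config.hidden_size)  # (batch, seq_len, hidden)
    _block = MoEModule(config)
    assert _block(_dummy).shape == _dummy.shape, "MoE forward altered the shape"
    print("MoEModule sanity check passed.")
# ------------------------------------------------------------------------------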

# --- 2. Main program: load the model and start chatting ---
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Model loading ---
    print(f"Loading tokenizer from '{MODEL_PATH}'...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    print("Manually rebuilding MoE model structure...")
    # Use `from_config` to build a randomly initialized "shell" with the right structure
    moe_model = AutoModelForCausalLM.from_config(config)

    # "Architecture surgery": swap every standard MLP for our MoE module
    for layer in moe_model.model.layers:
        layer.mlp = MoEModule(config)

    print("Loading your trained MoE weights into the correct structure...")
    from safetensors.torch import load_file

    state_dict = load_file(os.path.join(MODEL_PATH, "model.safetensors"), device="cpu")

    # Load flexibly with `strict=False`, then re-tie the shared weights manually
    moe_model.load_state_dict(state_dict, strict=False)
    moe_model.tie_weights()

    moe_model.to(device, dtype=torch.bfloat16)
    moe_model.eval()  # switch to evaluation mode

    print("--- MoE Model is ready for conversation! ---")
    print("Type 'exit' or 'quit' to end the chat.\n")

    # --- Interactive chat loop ---
    messages = []
    while True:
        try:
            user_input = input("You: ")
            if user_input.lower() in ["exit", "quit"]:
                print("Goodbye!")
                break

            # 1. Append the user's message to the conversation history
            messages.append({"role": "user", "content": user_input})

            # 2. Format the full history with the chat template.
            #    `add_generation_prompt=True` appends the assistant's start-of-turn marker.
            prompt_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # 3. Tokenize the prompt and move it to the device
            inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

            # 4. Generate a reply
            with torch.no_grad():
                outputs = moe_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                )

            # 5. Decode only the newly generated tokens.
            #    `outputs[0]` holds prompt + completion, so slice off the prompt
            #    by its token length rather than by fragile string matching.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            model_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            print(f"MoE Model: {model_response}")

            # 6. Append the reply to the history so multi-turn context is preserved
            messages.append({"role": "assistant", "content": model_response})

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break


if __name__ == "__main__":
    main()
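
# ------------------------------------------------------------------------------
# Usage sketch (assuming the trained checkpoint directory exists at MODEL_PATH;
# the script filename below is illustrative):
#
#   $ python chat_with_moe.py
#   You: Hello!
#   MoE Model: ...
#
# The loop keeps the full `messages` history, so earlier turns stay in context
# until you type 'exit' or 'quit' (or press Ctrl-C).
# ------------------------------------------------------------------------------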