import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()

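# NOTE: chat_current() below calls generate() without an attention_mask or a
# pad_token_id. For a single, unpadded prompt this generally does not change the
# result, but transformers warns that the attention mask and pad token id were
# not set (exact wording varies by version); chat_fixed() passes both explicitly.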
					
						
def chat_current(system_prompt: str, user_prompt: str) -> str:
    """
    Current implementation (same as server.py) - will show warnings.
    """
    print("🔴 Running CURRENT implementation (with warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Without return_dict=True, apply_chat_template returns only the input_ids
    # tensor, so there is no attention mask to forward to generate().
    input_ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    # Decode only the newly generated tokens, i.e. everything after the prompt.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()

						
def chat_fixed(system_prompt: str, user_prompt: str) -> str:
    """
    Fixed implementation - proper attention mask and pad token.
    """
    print("🟢 Running FIXED implementation (no warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # return_dict=True returns both input_ids and attention_mask.
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=tok.eos_token_id,  # explicit pad token silences the warning
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()

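# Equivalent variant, shown only as a sketch (not called by compare_generations):
# because return_dict=True yields a dict-like batch, the tensors can be moved to
# the device and unpacked straight into generate(). The helper name is illustrative.
def chat_fixed_unpacked(system_prompt: str, user_prompt: str) -> str:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(lm.device) for k, v in inputs.items()}

    with torch.inference_mode():
        output_ids = lm.generate(
            **inputs,  # forwards input_ids and attention_mask together
            pad_token_id=tok.eos_token_id,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    answer = tok.decode(
        output_ids[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()
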
					
						
def compare_generations():
    """Compare both implementations on the same prompt."""
    system_prompt = "You are a helpful assistant who tries to help answer the user's question."
    user_prompt = "Create a report on anxiety at work. How do I manage time and stress effectively?"

    print("=" * 60)
    print("COMPARING GENERATION METHODS")
    print("=" * 60)
    print(f"System: {system_prompt}")
    print(f"User: {user_prompt}")
    print("=" * 60)

    print("\n" + "=" * 60)
    current_output = chat_current(system_prompt, user_prompt)
    print(f"CURRENT OUTPUT:\n{current_output}")

    print("\n" + "=" * 60)
    fixed_output = chat_fixed(system_prompt, user_prompt)
    print(f"FIXED OUTPUT:\n{fixed_output}")

    print("\n" + "=" * 60)
    print("COMPARISON:")
    # With do_sample=True and no fixed seed, the two outputs will almost always
    # differ even when both implementations behave correctly.
    print(f"Outputs are identical: {current_output == fixed_output}")
    print(f"Current length: {len(current_output)} chars")
    print(f"Fixed length: {len(fixed_output)} chars")

					
						
if __name__ == "__main__":
    # Llama 3.2 tokenizers define no pad token, so fall back to the EOS token
    # before generating.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    compare_generations()