"""Interactive chat interface for a fine-tuned Llama model (full checkpoint or LoRA/PEFT adapter)."""

import os
import json
import warnings
from datetime import datetime

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import PeftModel

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


class LlamaChat:
    def __init__(self, model_path, system_message=None, use_quantization=True, max_memory_gb=8):
        """
        Initialize the chat interface with the fine-tuned Llama model.

        Args:
            model_path: Path to the fine-tuned model directory
            system_message: System message to use for conversations (persona/context)
            use_quantization: Whether to use 4-bit quantization (recommended for an 8GB GPU)
            max_memory_gb: Maximum GPU memory to use
        """
        self.model_path = model_path
        self.use_quantization = use_quantization
        self.max_memory_gb = max_memory_gb

        # Default system message if none provided
        self.system_message = system_message or (
            "You are Alexander Molchevskyi, a senior software engineer with over 20 years "
            "of professional experience across embedded, desktop, and server systems. "
            "Skilled in C++, Rust, Python, AI infrastructure, compilers, WebAssembly, and "
            "developer tooling. You answer interview questions clearly, professionally, and naturally."
        )

        print("🚀 Loading Llama Chat Interface...")
        print(f"Model path: {model_path}")
        print(f"System message: {self.system_message[:100]}{'...' if len(self.system_message) > 100 else ''}")

        # Check CUDA availability
        if torch.cuda.is_available():
            print(f"✅ CUDA available: {torch.cuda.get_device_name()}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
        else:
            print("⚠️ CUDA not available, using CPU (will be slow)")

        self.tokenizer = None
        self.model = None
        self.conversation_history = []

        self._load_model()

    def _setup_quantization_config(self):
        """Set up the 4-bit quantization config for memory efficiency."""
        if not self.use_quantization:
            return None

        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    def _load_model(self):
        """Load the tokenizer and model."""
        try:
            print("📚 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                padding_side="left"  # For generation
            )

            # Add pad token if it doesn't exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            print("🧠 Loading base model...")

            # Set up quantization if requested
            quantization_config = self._setup_quantization_config()

            # Check if this is a PEFT model (has adapter_config.json)
            adapter_config_path = os.path.join(self.model_path, "adapter_config.json")
            is_peft_model = os.path.exists(adapter_config_path)

            if is_peft_model:
                print("🔧 Detected PEFT (LoRA) model, loading base model first...")

                # Load the adapter config to get the base model name
                with open(adapter_config_path, 'r') as f:
                    adapter_config = json.load(f)

                base_model_name = adapter_config.get('base_model_name_or_path', 'llama-3.2-3b')
                print(f"Base model: {base_model_name}")

                # Load base model
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    use_cache=True,  # Enable cache for inference
                )

                # Load PEFT model (LoRA adapter)
                print("🎯 Loading LoRA adapter...")
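                # Wrap the (possibly 4-bit quantized) base model with the trained
                # LoRA adapter weights stored in model_path.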
                self.model = PeftModel.from_pretrained(base_model, self.model_path)
            else:
                # Regular fine-tuned model (not PEFT)
                print("📦 Loading fine-tuned model...")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    use_cache=True,  # Enable cache for inference
                )

            # Set model to evaluation mode
            self.model.eval()
            print("✅ Model loaded successfully!")

            # Print model info
            if hasattr(self.model, 'print_trainable_parameters'):
                self.model.print_trainable_parameters()

        except Exception as e:
            print(f"❌ Error loading model: {str(e)}")
            raise

    def _format_message(self, user_message):
        """Format the user message with the system context using Llama 3's chat template."""
        return (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"{self.system_message}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n"
            f"{user_message}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

    def generate_response(self, user_message, max_new_tokens=200, temperature=0.7,
                          top_p=0.9, repetition_penalty=1.1, do_sample=True):
        """
        Generate a response to the user message.

        Args:
            user_message: The user's input message
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling parameter
            repetition_penalty: Penalty for repeating tokens
            do_sample: Whether to use sampling or greedy decoding
        """
        try:
            # Format the input
            formatted_input = self._format_message(user_message)

            # Tokenize input
            inputs = self.tokenizer(
                formatted_input,
                return_tensors="pt",
                truncation=True,
                max_length=1024  # Increased to match the training max_length
            ).to(self.model.device)

            # Generate response
            print("🤔 Thinking...")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_return_sequences=1,
                )

            # Decode only the newly generated tokens: slicing off the prompt avoids
            # echoing the system and user turns back to the user (splitting on the
            # assistant header would fail here because skip_special_tokens removes it).
            generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            assistant_response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            # Clean up any remaining template tokens
            assistant_response = assistant_response.replace("<|eot_id|>", "").strip()

            return assistant_response

        except Exception as e:
            return f"❌ Error generating response: {str(e)}"

    def chat_loop(self):
        """Main chat loop."""
        print("\n" + "="*60)
        print("🦙 LLAMA FINE-TUNED CHAT INTERFACE")
        print("="*60)
        print("Commands:")
        print(" • Type your message and press Enter")
        print(" • '/help' - Show this help")
        print(" • '/system' - View or change system message")
        print(" • '/settings' - Adjust generation settings")
        print(" • '/history' - Show conversation history")
        print(" • '/clear' - Clear conversation history")
        print(" • '/save' - Save conversation to file")
        print(" • '/quit' or '/exit' - Exit the chat")
        print("="*60)

        # Default generation settings
        settings = {
            'max_new_tokens': 200,
            'temperature': 0.7,
            'top_p': 0.9,
            'repetition_penalty': 1.1,
            'do_sample': True
        }

        while True:
            try:
                # Get user input
                user_input = input("\n👤 You: ").strip()

                if not user_input:
                    continue

                # Handle commands
                if user_input.lower() in ['/quit', '/exit']:
                    print("👋 Goodbye!")
                    break
                elif user_input.lower() == '/help':
                    self._show_help()
                    continue
                elif user_input.lower() == '/system':
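                    # Inspect or replace the persona prompt that is prepended as the
                    # system message on every generation request.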
                    self._manage_system_message()
                    continue
                elif user_input.lower() == '/settings':
                    settings = self._adjust_settings(settings)
                    continue
                elif user_input.lower() == '/history':
                    self._show_history()
                    continue
                elif user_input.lower() == '/clear':
                    self.conversation_history.clear()
                    print("🧹 Conversation history cleared!")
                    continue
                elif user_input.lower() == '/save':
                    self._save_conversation()
                    continue

                # Generate response
                response = self.generate_response(user_input, **settings)

                # Display response
                print(f"\n🦙 Alexander: {response}")

                # Save to history
                self.conversation_history.append({
                    'timestamp': datetime.now().isoformat(),
                    'system': self.system_message,
                    'user': user_input,
                    'assistant': response
                })

            except KeyboardInterrupt:
                print("\n\n👋 Chat interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\n❌ Error: {str(e)}")

    def _manage_system_message(self):
        """Allow the user to view or change the system message."""
        print("\n🤖 SYSTEM MESSAGE MANAGEMENT:")
        print("Current system message:")
        print("-" * 60)
        print(self.system_message)
        print("-" * 60)

        choice = input("\nOptions: [v]iew, [c]hange, or [Enter] to go back: ").strip().lower()

        if choice in ('c', 'change'):
            print("\nEnter new system message (or press Enter to keep current):")
            new_system = input("> ").strip()
            if new_system:
                self.system_message = new_system
                print("✅ System message updated!")
                print("Note: This will affect all future conversations.")
            else:
                print("System message unchanged.")
        elif choice in ('v', 'view'):
            # Already displayed above
            pass

    def _show_help(self):
        """Show help information."""
        print("\n📋 HELP:")
        print("This is a chat interface for your fine-tuned Llama model.")
        print("The model has been trained with system messages to embody Alexander Molchevskyi's")
        print("professional persona and expertise in software engineering.")
        print("\nTips:")
        print("• Ask technical questions about software engineering, AI, or development")
        print("• The model maintains the context of being Alexander throughout conversations")
        print("• Use /system to view or modify the professional persona")
        print("• Use /settings to adjust creativity (temperature) and response length")
        print("• Higher temperature = more creative but less consistent")
        print("• Lower temperature = more focused and consistent")

    def _adjust_settings(self, current_settings):
        """Allow the user to adjust generation settings."""
        print("\n⚙️ GENERATION SETTINGS:")
        print("Current settings:")
        for key, value in current_settings.items():
            print(f"  {key}: {value}")

        new_settings = current_settings.copy()

        try:
            # Max tokens
            max_tokens = input(f"\nMax response length ({current_settings['max_new_tokens']}): ").strip()
            if max_tokens:
                new_settings['max_new_tokens'] = max(1, min(500, int(max_tokens)))

            # Temperature
            temp = input(f"Temperature 0.1-2.0 ({current_settings['temperature']}): ").strip()
            if temp:
                new_settings['temperature'] = max(0.1, min(2.0, float(temp)))

            # Top-p
            top_p = input(f"Top-p 0.1-1.0 ({current_settings['top_p']}): ").strip()
            if top_p:
                new_settings['top_p'] = max(0.1, min(1.0, float(top_p)))

            # Repetition penalty
            rep_penalty = input(f"Repetition penalty 1.0-2.0 ({current_settings['repetition_penalty']}): ").strip()
            if rep_penalty:
                new_settings['repetition_penalty'] = max(1.0, min(2.0, float(rep_penalty)))

            print("✅ Settings updated!")
            return new_settings

        except ValueError:
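            # Non-numeric input lands here; keep the previous settings rather than crashing.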
            print("❌ Invalid input. Settings unchanged.")
            return current_settings

    def _show_history(self):
        """Show conversation history."""
        if not self.conversation_history:
            print("📝 No conversation history yet.")
            return

        print(f"\n📜 CONVERSATION HISTORY ({len(self.conversation_history)} exchanges):")
        print("-" * 50)

        for exchange in self.conversation_history[-5:]:  # Show the last 5 exchanges
            timestamp = exchange['timestamp'].split('T')[1].split('.')[0]  # Just the time
            print(f"\n[{timestamp}]")
            print(f"👤 You: {exchange['user']}")
            print(f"🦙 Alexander: {exchange['assistant'][:100]}{'...' if len(exchange['assistant']) > 100 else ''}")

        if len(self.conversation_history) > 5:
            print(f"\n... and {len(self.conversation_history) - 5} more exchanges")

    def _save_conversation(self):
        """Save the conversation to a JSON file."""
        if not self.conversation_history:
            print("📝 No conversation to save.")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"llama_chat_{timestamp}.json"

        try:
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(self.conversation_history, f, indent=2, ensure_ascii=False)
            print(f"💾 Conversation saved to: {filename}")
        except Exception as e:
            print(f"❌ Error saving conversation: {str(e)}")


def main():
    """Main function to start the chat interface."""
    # Configuration
    MODEL_PATH = "llama-3.2-3b-finetuned"  # Path to your fine-tuned model

    # Default system message (can be customized)
    DEFAULT_SYSTEM_MESSAGE = (
        "You are Alexander Molchevskyi, a senior software engineer with over 20 years "
        "of professional experience across embedded, desktop, and server systems. "
        "Skilled in C++, Rust, Python, AI infrastructure, compilers, WebAssembly, and "
        "developer tooling. You answer interview questions clearly, professionally, and naturally."
    )

    # Check if the model directory exists
    if not os.path.exists(MODEL_PATH):
        print(f"❌ Model directory not found: {MODEL_PATH}")
        print("Please make sure you have run the fine-tuning script first.")
        return

    try:
        # Initialize chat interface
        chat = LlamaChat(
            model_path=MODEL_PATH,
            system_message=DEFAULT_SYSTEM_MESSAGE,
            use_quantization=True,  # Set to False if you have plenty of GPU memory
            max_memory_gb=8
        )

        # Start chat loop
        chat.chat_loop()

    except Exception as e:
        print(f"❌ Failed to initialize chat interface: {str(e)}")
        print("\nTroubleshooting tips:")
        print("1. Make sure the model was trained successfully")
        print("2. Check that all required libraries are installed")
        print("3. Ensure you have sufficient GPU memory")
        print("4. Try setting use_quantization=True to reduce memory usage")


if __name__ == "__main__":
    main()