import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.metrics.pairwise import cosine_similarity
from lime.lime_text import LimeTextExplainer

warnings.filterwarnings('ignore')


class LLMExplainabilityAnalyzer:
    def __init__(self, model_path, tokenizer_path=None):
        """Initialize with model and tokenizer paths"""
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path or model_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load model and tokenizer
        self.load_model()

        # Initialize explanation tools
        self.lime_explainer = LimeTextExplainer(class_names=['Generated Text'])

    def load_model(self):
        """Load the fine-tuned model and tokenizer"""
        try:
            print(f"Loading model from: {self.model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            # Set padding token if it does not exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fall back to the base model
            print("Loading base TinyLlama model...")
            self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            self.model = AutoModelForCausalLM.from_pretrained(
                "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            ).to(self.device)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def extract_attention_weights(self, text, max_length=512):
        """Extract attention weights for visualization"""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)

        attentions = outputs.attentions
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        return attentions, tokens

    def visualize_attention_heads(self, text, layer_idx=0, head_idx=0, max_length=512):
        """Visualize attention patterns for a specific layer and head"""
        attentions, tokens = self.extract_attention_weights(text, max_length)

        # Attention weights for the chosen layer and head: (seq_len, seq_len)
        attention_weights = attentions[layer_idx][0, head_idx].float().cpu().numpy()

        # Create heatmap
        plt.figure(figsize=(12, 8))
        sns.heatmap(
            attention_weights,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Blues',
            cbar=True
        )
        plt.title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

        return attention_weights, tokens

    def attention_rollout(self, text, max_length=512):
        """Compute attention rollout for global attention patterns"""
        attentions, tokens = self.extract_attention_weights(text, max_length)

        # Average over heads for each layer and convert to numpy
        attention_matrices = [att[0].mean(dim=0).float().cpu().numpy() for att in attentions]

        # Roll attention forward by multiplying the per-layer matrices
        rollout = attention_matrices[0]
        for attention_matrix in attention_matrices[1:]:
            rollout = np.matmul(rollout, attention_matrix)

        # Visualize rollout
        plt.figure(figsize=(12, 8))
        sns.heatmap(
            rollout,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Reds',
            cbar=True
        )
        plt.title('Attention Rollout - Global Attention Flow')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

        return rollout, tokens

    def gradient_saliency(self, text, target_token_idx=None, max_length=512):
        """Compute gradient-based saliency maps"""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        # Embed the inputs; detach so the embeddings become a leaf tensor
        # whose .grad is populated by backward()
        embeddings = self.model.get_input_embeddings()
        inputs_embeds = embeddings(inputs['input_ids']).detach()
        inputs_embeds.requires_grad_(True)

        # Forward pass on the embeddings
        outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=inputs['attention_mask'])

        # Target logits (last token if not specified)
        if target_token_idx is None:
            target_token_idx = -1
        target_logits = outputs.logits[0, target_token_idx]
        target_prob = F.softmax(target_logits, dim=-1)

        # Backpropagate the probability of the most likely next token
        target_prob.max().backward()

        # Saliency score = gradient norm per input token
        saliency_scores = inputs_embeds.grad.norm(dim=-1).squeeze().float().cpu().numpy()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Visualize saliency
        plt.figure(figsize=(15, 6))
        colors = plt.cm.Reds(saliency_scores / saliency_scores.max())
        for i, (token, score) in enumerate(zip(tokens, saliency_scores)):
            plt.bar(i, score, color=colors[i])
            plt.text(i, score + 0.001, token, rotation=45, ha='left', va='bottom')
        plt.title('Gradient Saliency Scores')
        plt.xlabel('Token Position')
        plt.ylabel('Saliency Score')
        plt.tight_layout()
        plt.show()

        return saliency_scores, tokens

    def lime_explanation(self, text, num_samples=1000):
        """Generate LIME explanations"""
        def predict_fn(texts):
            """Prediction function for LIME: next-token probability distribution"""
            predictions = []
            for sample in texts:
                try:
                    inputs = self.tokenizer(
                        sample,
                        return_tensors="pt",
                        max_length=512,
                        truncation=True,
                        padding=True
                    ).to(self.device)

                    with torch.no_grad():
                        outputs = self.model(**inputs)

                    logits = outputs.logits[0, -1]
                    probs = F.softmax(logits, dim=-1)
                    predictions.append(probs.float().cpu().numpy())
                except Exception:
                    # Return a uniform distribution on error
                    predictions.append(np.ones(self.tokenizer.vocab_size) / self.tokenizer.vocab_size)
            return np.array(predictions)

        # Generate explanation
        explanation = self.lime_explainer.explain_instance(
            text,
            predict_fn,
            num_features=20,
            num_samples=num_samples
        )

        # Visualize explanation
        explanation.show_in_notebook(text=True)
        return explanation

    def activation_analysis(self, text, layer_indices=None, max_length=512):
        """Analyze hidden layer activations"""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        # Forward hooks to capture per-layer activations
        activations = {}

        def hook_fn(name):
            def hook(module, input, output):
                # Llama decoder layers return a tuple; hidden states are the first element
                hidden = output[0] if isinstance(output, tuple) else output
                activations[name] = hidden.detach()
            return hook

        # Default to the first, middle, and last decoder layers
        if layer_indices is None:
            num_layers = len(self.model.model.layers)
            layer_indices = [0, num_layers // 2, num_layers - 1]

        hooks = []
        for idx in layer_indices:
            if idx < len(self.model.model.layers):
                hook = self.model.model.layers[idx].register_forward_hook(hook_fn(f'layer_{idx}'))
                hooks.append(hook)

        # Forward pass
        with torch.no_grad():
            self.model(**inputs)

        # Remove hooks
        for hook in hooks:
            hook.remove()

        # Analyze activations
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        for layer_name, activation in activations.items():
            activation_np = activation[0].float().cpu().numpy()

            plt.figure(figsize=(12, 6))

            # Heatmap of activations
            plt.subplot(1, 2, 1)
            sns.heatmap(activation_np.T, cmap='viridis', cbar=True)
            plt.title(f'{layer_name} Activations')
            plt.xlabel('Token Position')
            plt.ylabel('Hidden Dimension')

            # Activation magnitude per token
            plt.subplot(1, 2, 2)
            activation_magnitudes = np.linalg.norm(activation_np, axis=1)
            plt.bar(range(len(tokens)), activation_magnitudes)
            plt.title(f'{layer_name} Activation Magnitudes')
            plt.xlabel('Token Position')
            plt.ylabel('Magnitude')
            plt.xticks(range(len(tokens)), tokens, rotation=45)

            plt.tight_layout()
            plt.show()

    def token_importance_analysis(self, text, method='attention', max_length=512):
        """Analyze token importance using different methods"""
        if method == 'attention':
            # Attention-based importance: total attention received per token,
            # averaged across layers and heads
            attentions, tokens = self.extract_attention_weights(text, max_length)
            avg_attention = torch.stack([att.mean(dim=1) for att in attentions]).mean(dim=0)
            importance_scores = avg_attention[0].sum(dim=0).float().cpu().numpy()
        elif method == 'gradient':
            # Gradient-based importance
            importance_scores, tokens = self.gradient_saliency(text, max_length=max_length)
        else:
            raise ValueError(f"Unknown importance method: {method}")

        # Create importance dataframe
        importance_df = pd.DataFrame({
            'token': tokens,
            'importance': importance_scores
        })

        # Sort by importance
        importance_df = importance_df.sort_values('importance', ascending=False)

        # Visualize top important tokens
        plt.figure(figsize=(12, 6))
        top_tokens = importance_df.head(20)
        plt.barh(range(len(top_tokens)), top_tokens['importance'])
        plt.yticks(range(len(top_tokens)), top_tokens['token'])
        plt.title(f'Top 20 Important Tokens ({method.title()} Method)')
        plt.xlabel('Importance Score')
        plt.tight_layout()
        plt.show()

        return importance_df

    def semantic_similarity_analysis(self, texts, max_length=512):
        """Analyze semantic similarity between different texts"""
        embeddings = []
        for text in texts:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
                padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            # Use the last-layer, last-token hidden state as the text embedding
            embedding = outputs.hidden_states[-1][0, -1].float().cpu().numpy()
            embeddings.append(embedding)

        # Compute similarity matrix
        similarity_matrix = cosine_similarity(embeddings)

        # Visualize similarity matrix
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            similarity_matrix,
            annot=True,
            cmap='viridis',
            xticklabels=[f'Text {i+1}' for i in range(len(texts))],
            yticklabels=[f'Text {i+1}' for i in range(len(texts))]
        )
        plt.title('Semantic Similarity Matrix')
        plt.tight_layout()
        plt.show()

        return similarity_matrix

    def generate_explanation_report(self, text, output_file='xai_report.html'):
        """Generate a comprehensive explanation report"""
        print("Generating comprehensive XAI report...")

        # Run all analyses
        print("1. Extracting attention patterns...")
        attention_weights, tokens = self.visualize_attention_heads(text)

        print("2. Computing attention rollout...")
        rollout, _ = self.attention_rollout(text)

        print("3. Calculating gradient saliency...")
        saliency_scores, _ = self.gradient_saliency(text)

        print("4. Analyzing activations...")
        self.activation_analysis(text)

        print("5. Computing token importance...")
        importance_df = self.token_importance_analysis(text)

        # Create summary
        print("\n=== XAI ANALYSIS SUMMARY ===")
        print(f"Input text: {text[:100]}...")
        print(f"Number of tokens: {len(tokens)}")
        print(f"Most important tokens: {importance_df.head(5)['token'].tolist()}")
        attention_entropy = np.mean(-np.sum(attention_weights * np.log(attention_weights + 1e-10), axis=1))
        print(f"Average attention entropy: {attention_entropy:.4f}")

        return {
            'attention_weights': attention_weights,
            'rollout': rollout,
            'saliency_scores': saliency_scores,
            'importance_df': importance_df,
            'tokens': tokens
        }


def main():
    """Main function to run XAI analysis"""
    # Initialize analyzer (adjust model path as needed)
    try:
        analyzer = LLMExplainabilityAnalyzer("./fine_tuned_model")
    except Exception:
        print("Fine-tuned model not found. Using base model for demonstration.")
        analyzer = LLMExplainabilityAnalyzer("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # Sample skin disease text for analysis
    sample_text = """
    Patient presents with erythematous scaly patches on the elbows and knees,
    consistent with psoriasis. The condition appears to be chronic with periods
    of exacerbation. Treatment options include topical corticosteroids and
    phototherapy for mild to moderate cases.
    """

    print("Starting XAI Analysis...")
    print("=" * 50)

    # Generate comprehensive report
    results = analyzer.generate_explanation_report(sample_text)

    # Additional analyses
    print("\n6. Semantic similarity analysis...")
    test_texts = [
        "Psoriasis treatment with topical corticosteroids",
        "Eczema management using moisturizers",
        "Melanoma diagnosis and surgical intervention"
    ]
    analyzer.semantic_similarity_analysis(test_texts)

    print("\n" + "=" * 50)
    print("XAI ANALYSIS COMPLETE")
    print("=" * 50)

    return results


if __name__ == "__main__":
    main()
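

# Optional: attention rollout with residual correction (hedged sketch).
# The rollout in attention_rollout() multiplies raw per-layer attention maps.
# The rollout formulation of Abnar & Zuidema (2020) additionally folds in the
# residual connection by averaging each attention map with the identity matrix
# and re-normalizing rows before multiplying. The standalone helper below is a
# minimal sketch of that variant, assuming it is fed the list of head-averaged
# (seq_len, seq_len) numpy matrices built inside attention_rollout(); the
# function name is illustrative and not part of the class above.
def attention_rollout_with_residual(attention_matrices):
    """Attention rollout with the residual (identity) term folded in."""
    rollout = None
    for attention_matrix in attention_matrices:
        # Account for the residual stream: A' = 0.5 * A + 0.5 * I
        identity = np.eye(attention_matrix.shape[-1])
        augmented = 0.5 * attention_matrix + 0.5 * identity
        # Re-normalize rows so each query's weights still sum to 1
        augmented = augmented / augmented.sum(axis=-1, keepdims=True)
        rollout = augmented if rollout is None else np.matmul(rollout, augmented)
    return rollout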