import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from lime.lime_text import LimeTextExplainer
import warnings

warnings.filterwarnings('ignore')
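
# Quick-start sketch (the checkpoint below is only an example; any Hugging Face
# causal LM that can return attentions should work):
#
#   analyzer = LLMExplainabilityAnalyzer("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
#   analyzer.visualize_attention_heads("Psoriasis presents as scaly plaques.")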


class LLMExplainabilityAnalyzer:
    def __init__(self, model_path, tokenizer_path=None):
        """Initialize with model and tokenizer paths."""
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path or model_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.load_model()

        self.lime_explainer = LimeTextExplainer(class_names=['Generated Text'])

    def load_model(self):
        """Load the fine-tuned model and tokenizer, falling back to the base model."""
        try:
            print(f"Loading model from: {self.model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Loading base TinyLlama model...")
            self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            # Unlike the primary path, the fallback is not placed via device_map,
            # so move it to the analyzer's device explicitly.
            self.model = AutoModelForCausalLM.from_pretrained(
                "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            ).to(self.device)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()

    def extract_attention_weights(self, text, max_length=512):
        """Extract per-layer attention weights for visualization."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
            attentions = outputs.attentions

        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        return attentions, tokens
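
    # `attentions` is a tuple with one tensor per layer, each of shape
    # (batch, num_heads, seq_len, seq_len); row q of a head's matrix is the
    # softmax distribution over key positions attended to by query position q.
    # The visualizations below rely on this convention.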

    def visualize_attention_heads(self, text, layer_idx=0, head_idx=0, max_length=512):
        """Visualize attention patterns for a specific layer and head."""
        attentions, tokens = self.extract_attention_weights(text, max_length)

        # Cast to float32 before converting to NumPy in case the model runs in fp16.
        attention_weights = attentions[layer_idx][0, head_idx].float().cpu().numpy()

        plt.figure(figsize=(12, 8))
        sns.heatmap(
            attention_weights,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Blues',
            cbar=True
        )
        plt.title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

        return attention_weights, tokens

    def attention_rollout(self, text, max_length=512):
        """Compute attention rollout for global attention patterns."""
        attentions, tokens = self.extract_attention_weights(text, max_length)

        # Average over heads within each layer: one (seq_len, seq_len) matrix per layer.
        attention_matrices = [att[0].mean(dim=0).float().cpu().numpy() for att in attentions]

        # Fold in the residual connection by mixing each layer's attention with
        # the identity and renormalizing rows, then compose layers bottom-up.
        eye = np.eye(attention_matrices[0].shape[0])
        rollout = None
        for attention_matrix in attention_matrices:
            layer = 0.5 * attention_matrix + 0.5 * eye
            layer = layer / layer.sum(axis=-1, keepdims=True)
            rollout = layer if rollout is None else np.matmul(layer, rollout)

        plt.figure(figsize=(12, 8))
        sns.heatmap(
            rollout,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Reds',
            cbar=True
        )
        plt.title('Attention Rollout - Global Attention Flow')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

        return rollout, tokens
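
    # Rollout follows Abnar & Zuidema (2020): with A_hat = 0.5 * (A + I) as a
    # layer's residual-corrected attention, the cumulative flow after layer l
    # is rollout_l = A_hat_l @ rollout_{l-1}, so entry (i, j) approximates how
    # much information at position i ultimately derives from input token j.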

    def gradient_saliency(self, text, target_token_idx=None, max_length=512):
        """Compute gradient-based saliency maps."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        # Detach the embeddings so inputs_embeds is a leaf tensor; otherwise
        # its .grad attribute is never populated by backward().
        embeddings = self.model.get_input_embeddings()
        inputs_embeds = embeddings(inputs['input_ids']).detach()
        inputs_embeds.requires_grad_(True)

        outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=inputs['attention_mask'])

        # Default to the prediction at the last position.
        if target_token_idx is None:
            target_token_idx = -1

        target_logits = outputs.logits[0, target_token_idx]
        target_prob = F.softmax(target_logits, dim=-1)

        # Backpropagate from the most probable next token to the input embeddings.
        target_prob.max().backward()

        saliency_scores = inputs_embeds.grad.norm(dim=-1).squeeze().float().cpu().numpy()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        plt.figure(figsize=(15, 6))
        colors = plt.cm.Reds(saliency_scores / saliency_scores.max())

        for i, (token, score) in enumerate(zip(tokens, saliency_scores)):
            plt.bar(i, score, color=colors[i])
            plt.text(i, score + 0.001, token, rotation=45, ha='left', va='bottom')

        plt.title('Gradient Saliency Scores')
        plt.xlabel('Token Position')
        plt.ylabel('Saliency Score')
        plt.tight_layout()
        plt.show()

        return saliency_scores, tokens
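
    # The saliency score for token i is the L2 norm of the gradient of the top
    # next-token probability with respect to that token's input embedding,
    # s_i = || d p_max / d e_i ||_2: a first-order estimate of how strongly
    # perturbing token i would change the model's prediction.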

    def lime_explanation(self, text, num_samples=1000):
        """Generate LIME explanations."""
        def predict_fn(texts):
            """Return next-token probability distributions for LIME's perturbed samples."""
            predictions = []
            for sample in texts:
                try:
                    inputs = self.tokenizer(
                        sample,
                        return_tensors="pt",
                        max_length=512,
                        truncation=True,
                        padding=True
                    ).to(self.device)

                    with torch.no_grad():
                        outputs = self.model(**inputs)
                        logits = outputs.logits[0, -1]
                        probs = F.softmax(logits, dim=-1)

                    predictions.append(probs.float().cpu().numpy())
                except Exception:
                    # Fall back to a uniform distribution if a perturbed sample fails.
                    predictions.append(np.ones(self.tokenizer.vocab_size) / self.tokenizer.vocab_size)

            return np.array(predictions)

        explanation = self.lime_explainer.explain_instance(
            text,
            predict_fn,
            num_features=20,
            num_samples=num_samples
        )

        explanation.show_in_notebook(text=True)

        return explanation
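
    # Note: LIME perturbs the input by masking words, queries predict_fn on each
    # variant, and fits a sparse linear surrogate to the change in output. Here
    # the "classes" are the full next-token vocabulary distribution, which makes
    # each call expensive; a smaller num_samples is advisable for quick runs.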

    def activation_analysis(self, text, layer_indices=None, max_length=512):
        """Analyze hidden layer activations."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        activations = {}

        def hook_fn(name):
            def hook(module, input, output):
                # Decoder layers return a tuple; the hidden states come first.
                hidden = output[0] if isinstance(output, tuple) else output
                activations[name] = hidden.detach()
            return hook

        # Default to the first, middle, and last decoder layers.
        if layer_indices is None:
            num_layers = len(self.model.model.layers)
            layer_indices = [0, num_layers // 2, num_layers - 1]

        hooks = []
        for idx in layer_indices:
            if idx < len(self.model.model.layers):
                hook = self.model.model.layers[idx].register_forward_hook(hook_fn(f'layer_{idx}'))
                hooks.append(hook)

        with torch.no_grad():
            outputs = self.model(**inputs)

        for hook in hooks:
            hook.remove()

        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        for layer_name, activation in activations.items():
            activation_np = activation[0].float().cpu().numpy()

            plt.figure(figsize=(12, 6))

            plt.subplot(1, 2, 1)
            sns.heatmap(activation_np.T, cmap='viridis', cbar=True)
            plt.title(f'{layer_name} Activations')
            plt.xlabel('Token Position')
            plt.ylabel('Hidden Dimension')

            plt.subplot(1, 2, 2)
            activation_magnitudes = np.linalg.norm(activation_np, axis=1)
            plt.bar(range(len(tokens)), activation_magnitudes)
            plt.title(f'{layer_name} Activation Magnitudes')
            plt.xlabel('Token Position')
            plt.ylabel('Magnitude')
            plt.xticks(range(len(tokens)), tokens, rotation=45)

            plt.tight_layout()
            plt.show()
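
    # Two complementary importance notions are used below: 'attention' sums the
    # attention each token receives, averaged over heads and layers, while
    # 'gradient' reuses the saliency norms computed by gradient_saliency().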

    def token_importance_analysis(self, text, method='attention', max_length=512):
        """Analyze token importance using different methods."""
        if method == 'attention':
            attentions, tokens = self.extract_attention_weights(text, max_length)

            # Average over heads (dim=1) within each layer, then over layers,
            # and sum each column: total attention received by each key token.
            avg_attention = torch.stack([att.mean(dim=1) for att in attentions]).mean(dim=0)
            importance_scores = avg_attention[0].sum(dim=0).float().cpu().numpy()

        elif method == 'gradient':
            importance_scores, tokens = self.gradient_saliency(text, max_length=max_length)

        else:
            raise ValueError(f"Unknown method: {method!r} (expected 'attention' or 'gradient')")

        importance_df = pd.DataFrame({
            'token': tokens,
            'importance': importance_scores
        })

        importance_df = importance_df.sort_values('importance', ascending=False)

        plt.figure(figsize=(12, 6))
        top_tokens = importance_df.head(20)
        plt.barh(range(len(top_tokens)), top_tokens['importance'])
        plt.yticks(range(len(top_tokens)), top_tokens['token'])
        plt.title(f'Top 20 Important Tokens ({method.title()} Method)')
        plt.xlabel('Importance Score')
        plt.tight_layout()
        plt.show()

        return importance_df

    def semantic_similarity_analysis(self, texts, max_length=512):
        """Analyze semantic similarity between different texts."""
        embeddings = []

        for text in texts:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
                padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            # Use the final hidden state of the last token as the sequence embedding.
            embedding = outputs.hidden_states[-1][0, -1].float().cpu().numpy()
            embeddings.append(embedding)

        similarity_matrix = cosine_similarity(embeddings)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            similarity_matrix,
            annot=True,
            cmap='viridis',
            xticklabels=[f'Text {i+1}' for i in range(len(texts))],
            yticklabels=[f'Text {i+1}' for i in range(len(texts))]
        )
        plt.title('Semantic Similarity Matrix')
        plt.tight_layout()
        plt.show()

        return similarity_matrix
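
    # In a causal LM only the last position attends to every earlier token, so
    # its final hidden state is a reasonable single-vector summary of the whole
    # sequence; mean-pooling over positions is a common alternative.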

    def generate_explanation_report(self, text):
        """Generate a comprehensive explanation report."""
        print("Generating comprehensive XAI report...")

        print("1. Extracting attention patterns...")
        attention_weights, tokens = self.visualize_attention_heads(text)

        print("2. Computing attention rollout...")
        rollout, _ = self.attention_rollout(text)

        print("3. Calculating gradient saliency...")
        saliency_scores, _ = self.gradient_saliency(text)

        print("4. Analyzing activations...")
        self.activation_analysis(text)

        print("5. Computing token importance...")
        importance_df = self.token_importance_analysis(text)

        print("\n=== XAI ANALYSIS SUMMARY ===")
        print(f"Input text: {text[:100]}...")
        print(f"Number of tokens: {len(tokens)}")
        print(f"Most important tokens: {importance_df.head(5)['token'].tolist()}")
        # Mean entropy of the visualized head's attention rows.
        entropy = np.mean(-np.sum(attention_weights * np.log(attention_weights + 1e-10), axis=1))
        print(f"Average attention entropy: {entropy:.4f}")

        return {
            'attention_weights': attention_weights,
            'rollout': rollout,
            'saliency_scores': saliency_scores,
            'importance_df': importance_df,
            'tokens': tokens
        }


def main():
    """Main function to run the XAI analysis."""
    try:
        analyzer = LLMExplainabilityAnalyzer("./fine_tuned_model")
    except Exception:
        print("Fine-tuned model not found. Using base model for demonstration.")
        analyzer = LLMExplainabilityAnalyzer("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    sample_text = """
    Patient presents with erythematous scaly patches on the elbows and knees,
    consistent with psoriasis. The condition appears to be chronic with periods
    of exacerbation. Treatment options include topical corticosteroids and
    phototherapy for mild to moderate cases.
    """

    print("Starting XAI Analysis...")
    print("=" * 50)

    results = analyzer.generate_explanation_report(sample_text)

    print("\n6. Semantic similarity analysis...")
    test_texts = [
        "Psoriasis treatment with topical corticosteroids",
        "Eczema management using moisturizers",
        "Melanoma diagnosis and surgical intervention"
    ]

    similarity_matrix = analyzer.semantic_similarity_analysis(test_texts)

    print("\n" + "=" * 50)
    print("XAI ANALYSIS COMPLETE")
    print("=" * 50)

    return results


if __name__ == "__main__":
    main()