Kalyangotimothy commited on
Commit
217a100
·
1 Parent(s): d88d781
Files changed (14) hide show
  1. API.py +17 -0
  2. Dockerfile +28 -0
  3. README.md +36 -14
  4. app.py +128 -0
  5. cleaning.py +10 -0
  6. deploy.py +63 -0
  7. eda_analysis.py +381 -0
  8. extraction.py +6 -0
  9. finetune.py +46 -0
  10. finetune_tinyllama.py +39 -0
  11. gradio-app.py +7 -0
  12. llama2_inference.py +13 -0
  13. requirements.txt +10 -0
  14. xai_analysis.py +436 -0
API.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+
4
+ app = FastAPI()
5
+
6
+ model_path = "tinyllama-finetuned-skin"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
8
+ model = AutoModelForCausalLM.from_pretrained(model_path)
9
+
10
+ @app.post("/generate")
11
+ async def generate(request: Request):
12
+ data = await request.json()
13
+ prompt = data["prompt"]
14
+ inputs = tokenizer(prompt, return_tensors="pt")
15
+ outputs = model.generate(**inputs, max_new_tokens=100)
16
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
17
+ return {"response": result}
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && \
7
+ apt-get install -y --no-install-recommends \
8
+ gcc \
9
+ g++ \
10
+ git \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements first (for better caching)
14
+ COPY requirements.txt .
15
+ RUN pip install --no-cache-dir --upgrade pip && \
16
+ pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy application code
19
+ COPY . .
20
+
21
+ # Create model directory
22
+ RUN mkdir -p tinyllama-finetuned-skin
23
+
24
+ # Expose Hugging Face Spaces port
25
+ EXPOSE 7860
26
+
27
+ # Start Gradio app for Hugging Face Spaces
28
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,14 +1,36 @@
1
- ---
2
- title: Skin
3
- emoji: 🐨
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.38.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: skin diseases Llama
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🏥 Skin Disease AI Assistant
2
+
3
+ An AI-powered assistant for skin disease analysis and diagnosis support, built with fine-tuned TinyLlama model.
4
+
5
+ ## Features
6
+
7
+ - 🤖 AI-powered skin disease analysis
8
+ - 🩺 Medical consultation support
9
+ - 📊 Treatment recommendations
10
+ - 🔬 Research-backed responses
11
+
12
+ ## Usage
13
+
14
+ Enter your medical query and get AI-powered insights for skin conditions, symptoms, and treatment options.
15
+
16
+ **Note:** This is for educational purposes only. Always consult with medical professionals for actual diagnosis and treatment.
17
+
18
+ ## Model Information
19
+
20
+ - Base Model: TinyLlama-1.1B
21
+ - Fine-tuned on: Skin disease medical literature
22
+ - Specialization: Dermatology and skin conditions
23
+
24
+ ## Example Queries
25
+
26
+ - "Patient presents with red scaly patches on elbows. What could this be?"
27
+ - "Describe treatment options for psoriasis"
28
+ - "What are the symptoms of eczema?"
29
+
30
+ ## Deployment
31
+
32
+ This app is deployed on Hugging Face Spaces with Gradio interface.
33
+
34
+ ## Disclaimer
35
+
36
+ This AI assistant is for educational and research purposes only. It should not be used as a substitute for professional medical advice, diagnosis, or treatment.
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Hugging Face Spaces deployment file for Skin Disease AI API
4
+ """
5
+ import os
6
+ import gradio as gr
7
+ import requests
8
+ import json
9
+ from threading import Thread
10
+ import time
11
+ import subprocess
12
+ import sys
13
+
14
+ # Start FastAPI in background
15
+ def start_fastapi():
16
+ """Start FastAPI server in background"""
17
+ subprocess.run([
18
+ sys.executable, "-m", "uvicorn", "API:app",
19
+ "--host", "0.0.0.0", "--port", "8000"
20
+ ])
21
+
22
+ # Start FastAPI in a separate thread
23
+ api_thread = Thread(target=start_fastapi, daemon=True)
24
+ api_thread.start()
25
+
26
+ # Wait for API to start
27
+ time.sleep(10)
28
+
29
+ def generate_text(prompt, max_tokens, temperature):
30
+ """Generate text using the API"""
31
+ try:
32
+ url = "http://localhost:8000/generate"
33
+ payload = {
34
+ "prompt": prompt,
35
+ "max_new_tokens": int(max_tokens),
36
+ "temperature": float(temperature)
37
+ }
38
+
39
+ response = requests.post(url, json=payload, timeout=30)
40
+
41
+ if response.status_code == 200:
42
+ result = response.json()
43
+ return result["response"]
44
+ else:
45
+ return f"Error: {response.status_code} - {response.text}"
46
+
47
+ except Exception as e:
48
+ return f"Error: {str(e)}"
49
+
50
+ def check_api_health():
51
+ """Check if API is running"""
52
+ try:
53
+ response = requests.get("http://localhost:8000/health", timeout=5)
54
+ if response.status_code == 200:
55
+ return "✅ API is running"
56
+ else:
57
+ return "❌ API error"
58
+ except:
59
+ return "❌ API not responding"
60
+
61
+ # Create Gradio interface
62
+ with gr.Blocks(title="Skin Disease AI", theme=gr.themes.Soft()) as demo:
63
+ gr.Markdown("# 🏥 Skin Disease AI Assistant")
64
+ gr.Markdown("AI-powered assistant for skin disease analysis and diagnosis support.")
65
+
66
+ with gr.Row():
67
+ with gr.Column(scale=2):
68
+ prompt_input = gr.Textbox(
69
+ label="Enter your medical query",
70
+ placeholder="Patient presents with red scaly patches on elbows. Diagnosis:",
71
+ lines=3
72
+ )
73
+
74
+ with gr.Row():
75
+ max_tokens = gr.Slider(
76
+ minimum=10,
77
+ maximum=200,
78
+ value=100,
79
+ step=10,
80
+ label="Max tokens"
81
+ )
82
+ temperature = gr.Slider(
83
+ minimum=0.1,
84
+ maximum=1.0,
85
+ value=0.7,
86
+ step=0.1,
87
+ label="Temperature"
88
+ )
89
+
90
+ generate_btn = gr.Button("Generate Response", variant="primary")
91
+
92
+ with gr.Column(scale=1):
93
+ api_status = gr.Textbox(label="API Status", value="Starting...", interactive=False)
94
+ check_btn = gr.Button("Check API")
95
+
96
+ output = gr.Textbox(
97
+ label="AI Response",
98
+ lines=8,
99
+ placeholder="AI response will appear here..."
100
+ )
101
+
102
+ # Examples
103
+ gr.Examples(
104
+ examples=[
105
+ ["Patient has red scaly patches on elbows and knees. What could this be?", 80, 0.7],
106
+ ["Describe treatment options for psoriasis", 100, 0.6],
107
+ ["What are the symptoms of eczema?", 120, 0.5],
108
+ ],
109
+ inputs=[prompt_input, max_tokens, temperature],
110
+ )
111
+
112
+ # Event handlers
113
+ generate_btn.click(
114
+ fn=generate_text,
115
+ inputs=[prompt_input, max_tokens, temperature],
116
+ outputs=output
117
+ )
118
+
119
+ check_btn.click(
120
+ fn=check_api_health,
121
+ outputs=api_status
122
+ )
123
+
124
+ # Auto-check API status on load
125
+ demo.load(fn=check_api_health, outputs=api_status)
126
+
127
+ if __name__ == "__main__":
128
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
cleaning.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ input_file = "skin_disease_articles.txt"
2
+ output_file = "skin_disease_articles_clean.txt"
3
+
4
+ with open(input_file, "r", encoding="utf-8") as infile, open(output_file, "w", encoding="utf-8") as outfile:
5
+ for line in infile:
6
+ cleaned = line.strip() # Remove leading/trailing whitespace
7
+ if cleaned: # Skip empty lines
8
+ outfile.write(cleaned + "\n")
9
+
10
+ print(f"Cleaned file saved as {output_file}")
deploy.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple deployment script for Skin Disease AI API
4
+ """
5
+ import os
6
+ import sys
7
+ import subprocess
8
+
9
+ def create_demo_model():
10
+ """Create a demo model for testing"""
11
+ print("Creating demo model...")
12
+ try:
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+
15
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
16
+ print(f"Downloading {model_name}...")
17
+
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
19
+ model = AutoModelForCausalLM.from_pretrained(model_name)
20
+
21
+ if tokenizer.pad_token is None:
22
+ tokenizer.pad_token = tokenizer.eos_token
23
+
24
+ os.makedirs("tinyllama-finetuned-skin", exist_ok=True)
25
+ tokenizer.save_pretrained("tinyllama-finetuned-skin")
26
+ model.save_pretrained("tinyllama-finetuned-skin")
27
+
28
+ print("✅ Demo model created successfully!")
29
+ return True
30
+ except Exception as e:
31
+ print(f"❌ Error creating model: {e}")
32
+ return False
33
+
34
+ def start_server():
35
+ """Start the API server"""
36
+ print("🚀 Starting API server on http://localhost:8000")
37
+ print("Press Ctrl+C to stop")
38
+ try:
39
+ subprocess.run([
40
+ sys.executable, "-m", "uvicorn", "API:app",
41
+ "--host", "0.0.0.0", "--port", "8000", "--reload"
42
+ ])
43
+ except KeyboardInterrupt:
44
+ print("\n🛑 Server stopped.")
45
+
46
+ def main():
47
+ print("🚀 Skin Disease AI Deployment")
48
+ print("=" * 40)
49
+
50
+ # Check if model exists
51
+ if not os.path.exists("tinyllama-finetuned-skin"):
52
+ print("Model not found. Creating demo model...")
53
+ if not create_demo_model():
54
+ print("Failed to create model. Exiting.")
55
+ return
56
+ else:
57
+ print("✅ Model found!")
58
+
59
+ # Start server
60
+ start_server()
61
+
62
+ if __name__ == "__main__":
63
+ main()
eda_analysis.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from collections import Counter
6
+ import re
7
+ from wordcloud import WordCloud
8
+ from textstat import flesch_reading_ease, flesch_kincaid_grade
9
+ import nltk
10
+ from nltk.corpus import stopwords
11
+ from nltk.tokenize import word_tokenize, sent_tokenize
12
+ from sklearn.feature_extraction.text import TfidfVectorizer
13
+ from sklearn.decomposition import LatentDirichletAllocation
14
+ import warnings
15
+ warnings.filterwarnings('ignore')
16
+
17
+ # Download required NLTK data
18
+ try:
19
+ nltk.data.find('tokenizers/punkt')
20
+ except LookupError:
21
+ nltk.download('punkt')
22
+
23
+ try:
24
+ nltk.data.find('corpora/stopwords')
25
+ except LookupError:
26
+ nltk.download('stopwords')
27
+
28
+ class SkinDiseaseEDA:
29
+ def __init__(self, filepath):
30
+ self.filepath = filepath
31
+ self.data = []
32
+ self.articles = []
33
+ self.load_data()
34
+
35
+ def load_data(self):
36
+ """Parse the structured text file into articles"""
37
+ with open(self.filepath, 'r', encoding='utf-8') as file:
38
+ content = file.read()
39
+
40
+ # Split by separator
41
+ articles = content.split('------------------------------------------------------------')
42
+
43
+ for article in articles:
44
+ if not article.strip():
45
+ continue
46
+
47
+ lines = article.strip().split('\n')
48
+ article_data = {
49
+ 'title': '',
50
+ 'journal': '',
51
+ 'authors': '',
52
+ 'abstract': '',
53
+ 'diagnosis': '',
54
+ 'treatment': ''
55
+ }
56
+
57
+ current_section = None
58
+ for line in lines:
59
+ line = line.strip()
60
+ if not line:
61
+ continue
62
+
63
+ if line.startswith('Journal:'):
64
+ current_section = 'journal'
65
+ article_data['journal'] = line.replace('Journal:', '').strip()
66
+ elif line.startswith('Authors:'):
67
+ current_section = 'authors'
68
+ article_data['authors'] = line.replace('Authors:', '').strip()
69
+ elif line.startswith('Abstract:'):
70
+ current_section = 'abstract'
71
+ article_data['abstract'] = line.replace('Abstract:', '').strip()
72
+ elif line == 'Diagnosis':
73
+ current_section = 'diagnosis'
74
+ elif line == 'Treatment Remedies':
75
+ current_section = 'treatment'
76
+ elif current_section == 'abstract' and not line.startswith(('Journal:', 'Authors:', 'Diagnosis', 'Treatment')):
77
+ article_data['abstract'] += ' ' + line
78
+ elif current_section == 'diagnosis' and not line.startswith(('Journal:', 'Authors:', 'Abstract:', 'Treatment')):
79
+ article_data['diagnosis'] += ' ' + line
80
+ elif current_section == 'treatment' and not line.startswith(('Journal:', 'Authors:', 'Abstract:', 'Diagnosis')):
81
+ article_data['treatment'] += ' ' + line
82
+ elif not any(line.startswith(prefix) for prefix in ['Journal:', 'Authors:', 'Abstract:', 'Diagnosis', 'Treatment']) and not current_section:
83
+ article_data['title'] = line
84
+
85
+ # Clean up data
86
+ for key in article_data:
87
+ article_data[key] = article_data[key].strip()
88
+
89
+ if article_data['title']:
90
+ self.articles.append(article_data)
91
+
92
+ def basic_statistics(self):
93
+ """Generate basic statistics about the corpus"""
94
+ print("=== BASIC CORPUS STATISTICS ===")
95
+ print(f"Total articles: {len(self.articles)}")
96
+
97
+ # Text length statistics
98
+ abstract_lengths = [len(article['abstract']) for article in self.articles if article['abstract']]
99
+ title_lengths = [len(article['title']) for article in self.articles if article['title']]
100
+
101
+ print(f"Articles with abstracts: {len(abstract_lengths)}")
102
+ print(f"Average abstract length: {np.mean(abstract_lengths):.1f} characters")
103
+ print(f"Average title length: {np.mean(title_lengths):.1f} characters")
104
+
105
+ # Word counts
106
+ abstract_words = [len(article['abstract'].split()) for article in self.articles if article['abstract']]
107
+ print(f"Average abstract word count: {np.mean(abstract_words):.1f} words")
108
+
109
+ # Diagnosis and treatment availability
110
+ with_diagnosis = sum(1 for article in self.articles if article['diagnosis'] and article['diagnosis'] != 'Not specified.')
111
+ with_treatment = sum(1 for article in self.articles if article['treatment'])
112
+
113
+ print(f"Articles with specific diagnosis: {with_diagnosis} ({with_diagnosis/len(self.articles)*100:.1f}%)")
114
+ print(f"Articles with treatment info: {with_treatment} ({with_treatment/len(self.articles)*100:.1f}%)")
115
+
116
+ return {
117
+ 'total_articles': len(self.articles),
118
+ 'abstract_lengths': abstract_lengths,
119
+ 'title_lengths': title_lengths,
120
+ 'abstract_words': abstract_words,
121
+ 'with_diagnosis': with_diagnosis,
122
+ 'with_treatment': with_treatment
123
+ }
124
+
125
+ def journal_analysis(self):
126
+ """Analyze journal distribution"""
127
+ print("\n=== JOURNAL ANALYSIS ===")
128
+
129
+ journals = [article['journal'] for article in self.articles if article['journal']]
130
+ journal_counts = Counter(journals)
131
+
132
+ print(f"Total unique journals: {len(journal_counts)}")
133
+ print("Top 10 journals:")
134
+ for journal, count in journal_counts.most_common(10):
135
+ print(f" {journal}: {count} articles")
136
+
137
+ # Create visualization
138
+ plt.figure(figsize=(12, 8))
139
+ top_journals = dict(journal_counts.most_common(15))
140
+ plt.barh(list(top_journals.keys()), list(top_journals.values()))
141
+ plt.title('Top 15 Journals by Article Count')
142
+ plt.xlabel('Number of Articles')
143
+ plt.tight_layout()
144
+ plt.show()
145
+
146
+ return journal_counts
147
+
148
+ def author_analysis(self):
149
+ """Analyze author patterns"""
150
+ print("\n=== AUTHOR ANALYSIS ===")
151
+
152
+ all_authors = []
153
+ for article in self.articles:
154
+ if article['authors']:
155
+ # Split authors by comma
156
+ authors = [author.strip() for author in article['authors'].split(',')]
157
+ all_authors.extend(authors)
158
+
159
+ author_counts = Counter(all_authors)
160
+
161
+ print(f"Total unique authors: {len(author_counts)}")
162
+ print(f"Total author instances: {len(all_authors)}")
163
+ print(f"Average authors per article: {len(all_authors)/len(self.articles):.1f}")
164
+
165
+ print("Top 10 most prolific authors:")
166
+ for author, count in author_counts.most_common(10):
167
+ print(f" {author}: {count} articles")
168
+
169
+ # Author collaboration network size
170
+ author_counts_per_article = [len(article['authors'].split(',')) for article in self.articles if article['authors']]
171
+ print(f"Average collaboration size: {np.mean(author_counts_per_article):.1f} authors per article")
172
+
173
+ return author_counts
174
+
175
+ def disease_analysis(self):
176
+ """Analyze disease mentions and patterns"""
177
+ print("\n=== DISEASE AND CONDITION ANALYSIS ===")
178
+
179
+ # Common disease terms
180
+ disease_terms = [
181
+ 'cancer', 'carcinoma', 'melanoma', 'psoriasis', 'dermatitis', 'eczema',
182
+ 'acne', 'rosacea', 'vitiligo', 'lupus', 'scleroderma', 'pemphigus',
183
+ 'bullous', 'urticaria', 'mastocytosis', 'lymphoma', 'sarcoma',
184
+ 'basal cell', 'squamous cell', 'keratosis', 'mycosis', 'fungal',
185
+ 'bacterial', 'viral', 'herpes', 'warts', 'molluscum', 'impetigo'
186
+ ]
187
+
188
+ # Count mentions in titles and abstracts
189
+ disease_counts = Counter()
190
+
191
+ for article in self.articles:
192
+ text = (article['title'] + ' ' + article['abstract']).lower()
193
+ for term in disease_terms:
194
+ if term in text:
195
+ disease_counts[term] += 1
196
+
197
+ print("Top 15 disease/condition mentions:")
198
+ for disease, count in disease_counts.most_common(15):
199
+ print(f" {disease}: {count} mentions")
200
+
201
+ # Create visualization
202
+ plt.figure(figsize=(12, 8))
203
+ top_diseases = dict(disease_counts.most_common(15))
204
+ plt.barh(list(top_diseases.keys()), list(top_diseases.values()))
205
+ plt.title('Top 15 Disease/Condition Mentions')
206
+ plt.xlabel('Number of Mentions')
207
+ plt.tight_layout()
208
+ plt.show()
209
+
210
+ return disease_counts
211
+
212
+ def treatment_analysis(self):
213
+ """Analyze treatment patterns"""
214
+ print("\n=== TREATMENT ANALYSIS ===")
215
+
216
+ # Common treatment terms
217
+ treatment_terms = [
218
+ 'therapy', 'treatment', 'drug', 'medication', 'topical', 'oral',
219
+ 'systemic', 'immunosuppressive', 'corticosteroid', 'antibiotic',
220
+ 'antifungal', 'antiviral', 'chemotherapy', 'radiotherapy',
221
+ 'surgical', 'laser', 'phototherapy', 'immunotherapy', 'biologic',
222
+ 'methotrexate', 'cyclosporine', 'tacrolimus', 'rituximab'
223
+ ]
224
+
225
+ treatment_counts = Counter()
226
+
227
+ for article in self.articles:
228
+ text = (article['treatment'] + ' ' + article['abstract']).lower()
229
+ for term in treatment_terms:
230
+ if term in text:
231
+ treatment_counts[term] += 1
232
+
233
+ print("Top 15 treatment mentions:")
234
+ for treatment, count in treatment_counts.most_common(15):
235
+ print(f" {treatment}: {count} mentions")
236
+
237
+ # Create visualization
238
+ plt.figure(figsize=(12, 8))
239
+ top_treatments = dict(treatment_counts.most_common(15))
240
+ plt.barh(list(top_treatments.keys()), list(top_treatments.values()))
241
+ plt.title('Top 15 Treatment Mentions')
242
+ plt.xlabel('Number of Mentions')
243
+ plt.tight_layout()
244
+ plt.show()
245
+
246
+ return treatment_counts
247
+
248
+ def keyword_analysis(self):
249
+ """Perform keyword analysis using TF-IDF"""
250
+ print("\n=== KEYWORD ANALYSIS ===")
251
+
252
+ # Combine title and abstract for each article
253
+ documents = []
254
+ for article in self.articles:
255
+ doc = article['title'] + ' ' + article['abstract']
256
+ documents.append(doc)
257
+
258
+ # TF-IDF analysis
259
+ stop_words = set(stopwords.words('english'))
260
+ stop_words.update(['study', 'research', 'analysis', 'results', 'conclusion', 'background', 'methods'])
261
+
262
+ vectorizer = TfidfVectorizer(
263
+ max_features=100,
264
+ stop_words=list(stop_words),
265
+ ngram_range=(1, 2),
266
+ min_df=2,
267
+ max_df=0.8
268
+ )
269
+
270
+ tfidf_matrix = vectorizer.fit_transform(documents)
271
+ feature_names = vectorizer.get_feature_names_out()
272
+
273
+ # Get top keywords
274
+ mean_scores = np.mean(tfidf_matrix.toarray(), axis=0)
275
+ top_indices = np.argsort(mean_scores)[::-1][:20]
276
+
277
+ print("Top 20 keywords by TF-IDF score:")
278
+ for i, idx in enumerate(top_indices):
279
+ print(f" {i+1}. {feature_names[idx]}: {mean_scores[idx]:.4f}")
280
+
281
+ # Create word cloud
282
+ all_text = ' '.join(documents)
283
+ wordcloud = WordCloud(
284
+ width=800,
285
+ height=400,
286
+ background_color='white',
287
+ stopwords=stop_words,
288
+ max_words=100
289
+ ).generate(all_text)
290
+
291
+ plt.figure(figsize=(12, 6))
292
+ plt.imshow(wordcloud, interpolation='bilinear')
293
+ plt.axis('off')
294
+ plt.title('Word Cloud of Skin Disease Articles')
295
+ plt.tight_layout()
296
+ plt.show()
297
+
298
+ return feature_names, mean_scores
299
+
300
+ def readability_analysis(self):
301
+ """Analyze text readability"""
302
+ print("\n=== READABILITY ANALYSIS ===")
303
+
304
+ flesch_scores = []
305
+ grade_levels = []
306
+
307
+ for article in self.articles:
308
+ if article['abstract']:
309
+ try:
310
+ flesch_score = flesch_reading_ease(article['abstract'])
311
+ grade_level = flesch_kincaid_grade(article['abstract'])
312
+ flesch_scores.append(flesch_score)
313
+ grade_levels.append(grade_level)
314
+ except:
315
+ continue
316
+
317
+ print(f"Average Flesch Reading Ease Score: {np.mean(flesch_scores):.1f}")
318
+ print(f"Average Grade Level: {np.mean(grade_levels):.1f}")
319
+
320
+ # Interpretation
321
+ avg_flesch = np.mean(flesch_scores)
322
+ if avg_flesch >= 90:
323
+ difficulty = "Very Easy"
324
+ elif avg_flesch >= 80:
325
+ difficulty = "Easy"
326
+ elif avg_flesch >= 70:
327
+ difficulty = "Fairly Easy"
328
+ elif avg_flesch >= 60:
329
+ difficulty = "Standard"
330
+ elif avg_flesch >= 50:
331
+ difficulty = "Fairly Difficult"
332
+ elif avg_flesch >= 30:
333
+ difficulty = "Difficult"
334
+ else:
335
+ difficulty = "Very Difficult"
336
+
337
+ print(f"Reading Difficulty: {difficulty}")
338
+
339
+ return flesch_scores, grade_levels
340
+
341
+ def generate_summary_report(self):
342
+ """Generate a comprehensive summary report"""
343
+ print("\n" + "="*50)
344
+ print("COMPREHENSIVE EDA SUMMARY REPORT")
345
+ print("="*50)
346
+
347
+ # Run all analyses
348
+ basic_stats = self.basic_statistics()
349
+ journal_counts = self.journal_analysis()
350
+ author_counts = self.author_analysis()
351
+ disease_counts = self.disease_analysis()
352
+ treatment_counts = self.treatment_analysis()
353
+ keywords, scores = self.keyword_analysis()
354
+ flesch_scores, grade_levels = self.readability_analysis()
355
+
356
+ # Summary insights
357
+ print("\n=== KEY INSIGHTS ===")
358
+ print(f"1. Corpus contains {basic_stats['total_articles']} articles from {len(journal_counts)} unique journals")
359
+ print(f"2. Most common disease area: {disease_counts.most_common(1)[0][0] if disease_counts else 'N/A'}")
360
+ print(f"3. Most common treatment approach: {treatment_counts.most_common(1)[0][0] if treatment_counts else 'N/A'}")
361
+ print(f"4. Average reading level: Grade {np.mean(grade_levels):.1f}")
362
+ print(f"5. {basic_stats['with_diagnosis']} articles have specific diagnosis information")
363
+ print(f"6. {basic_stats['with_treatment']} articles contain treatment information")
364
+
365
+ def main():
366
+ # Initialize EDA
367
+ eda = SkinDiseaseEDA('skin_disease_articles_clean.txt')
368
+
369
+ # Generate comprehensive report
370
+ eda.generate_summary_report()
371
+
372
+ # Set up plotting style
373
+ plt.style.use('seaborn-v0_8')
374
+ sns.set_palette("husl")
375
+
376
+ print("\n" + "="*50)
377
+ print("EDA ANALYSIS COMPLETE")
378
+ print("="*50)
379
+
380
+ if __name__ == "__main__":
381
+ main()
extraction.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from docx import Document
2
+
3
+ doc = Document("skin_disease_articles.docx")
4
+ with open("skin_disease_articles.txt", "w", encoding="utf-8") as f:
5
+ for para in doc.paragraphs:
6
+ f.write(para.text + "\n")
finetune.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
2
+ from datasets import load_dataset
3
+ import os
4
+
5
+ os.environ["USE_TF"] = "0"
6
+
7
+ model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ if tokenizer.pad_token is None:
10
+ tokenizer.pad_token = tokenizer.eos_token
11
+ model = AutoModelForCausalLM.from_pretrained(model_name)
12
+
13
+ # Load your text file as a dataset
14
+ dataset = load_dataset("text", data_files={"train": "skin_disease_articles_clean.txt"})
15
+
16
+ # Tokenize the dataset
17
+ def tokenize_function(examples):
18
+ return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
19
+
20
+ tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
21
+
22
+ train_dataset = tokenized_datasets["train"]
23
+
24
+ data_collator = DataCollatorForLanguageModeling(
25
+ tokenizer=tokenizer, mlm=False
26
+ )
27
+
28
+ training_args = TrainingArguments(
29
+ output_dir="./tinyllama-finetuned-skin",
30
+ overwrite_output_dir=True,
31
+ num_train_epochs=1,
32
+ per_device_train_batch_size=2,
33
+ save_steps=500,
34
+ save_total_limit=2,
35
+ prediction_loss_only=True,
36
+ fp16=True # Set True if using GPU with float16 support
37
+ )
38
+
39
+ trainer = Trainer(
40
+ model=model,
41
+ args=training_args,
42
+ data_collator=data_collator,
43
+ train_dataset=train_dataset,
44
+ )
45
+
46
+ trainer.train()
finetune_tinyllama.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
2
+
3
+ model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
4
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
5
+ model = AutoModelForCausalLM.from_pretrained(model_name)
6
+
7
+ # Prepare dataset
8
+ def load_dataset(file_path, tokenizer, block_size=128):
9
+ return TextDataset(
10
+ tokenizer=tokenizer,
11
+ file_path=file_path,
12
+ block_size=block_size
13
+ )
14
+
15
+ train_dataset = load_dataset("skin_disease_articles_clean.txt", tokenizer)
16
+
17
+ data_collator = DataCollatorForLanguageModeling(
18
+ tokenizer=tokenizer, mlm=False
19
+ )
20
+
21
+ training_args = TrainingArguments(
22
+ output_dir="./tinyllama-finetuned-skin",
23
+ overwrite_output_dir=True,
24
+ num_train_epochs=1,
25
+ per_device_train_batch_size=2,
26
+ save_steps=500,
27
+ save_total_limit=2,
28
+ prediction_loss_only=True,
29
+ fp16=False # Set True if using GPU with float16 support
30
+ )
31
+
32
+ trainer = Trainer(
33
+ model=model,
34
+ args=training_args,
35
+ data_collator=data_collator,
36
+ train_dataset=train_dataset,
37
+ )
38
+
39
+ trainer.train()
gradio-app.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ def greet(name):
4
+ return "Hello " + name + "!!"
5
+
6
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ demo.launch()
llama2_inference.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+ model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" # Or local path if downloaded
4
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
5
+ model = AutoModelForCausalLM.from_pretrained(model_name)
6
+
7
+ # Example: Use a line from your cleaned file as a prompt
8
+ with open("skin_disease_articles_clean.txt", "r", encoding="utf-8") as f:
9
+ prompt = f.readline().strip()
10
+
11
+ inputs = tokenizer(prompt, return_tensors="pt")
12
+ outputs = model.generate(**inputs, max_new_tokens=100)
13
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ transformers==4.35.2
4
+ torch>=2.0.0
5
+ datasets==2.14.7
6
+ accelerate==0.24.1
7
+ pydantic==2.5.0
8
+ python-multipart==0.0.6
9
+ gradio==4.15.0
10
+ requests>=2.25.0
xai_analysis.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ from transformers import LlamaTokenizer, LlamaForCausalLM
8
+ import pandas as pd
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+ import lime
11
+ from lime.lime_text import LimeTextExplainer
12
+ import shap
13
+ import re
14
+ import warnings
15
+ warnings.filterwarnings('ignore')
16
+
17
+ class LLMExplainabilityAnalyzer:
18
+ def __init__(self, model_path, tokenizer_path=None):
19
+ """Initialize with model and tokenizer paths"""
20
+ self.model_path = model_path
21
+ self.tokenizer_path = tokenizer_path or model_path
22
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+
24
+ # Load model and tokenizer
25
+ self.load_model()
26
+
27
+ # Initialize explanation tools
28
+ self.lime_explainer = LimeTextExplainer(class_names=['Generated Text'])
29
+
30
+ def load_model(self):
31
+ """Load the fine-tuned model and tokenizer"""
32
+ try:
33
+ print(f"Loading model from: {self.model_path}")
34
+ self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
35
+ self.model = AutoModelForCausalLM.from_pretrained(
36
+ self.model_path,
37
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
38
+ device_map="auto" if torch.cuda.is_available() else None
39
+ )
40
+
41
+ # Set padding token if not exists
42
+ if self.tokenizer.pad_token is None:
43
+ self.tokenizer.pad_token = self.tokenizer.eos_token
44
+
45
+ print("Model loaded successfully!")
46
+
47
+ except Exception as e:
48
+ print(f"Error loading model: {e}")
49
+ # Fallback to base model
50
+ print("Loading base TinyLlama model...")
51
+ self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
52
+ self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
53
+ if self.tokenizer.pad_token is None:
54
+ self.tokenizer.pad_token = self.tokenizer.eos_token
55
+
56
+ def extract_attention_weights(self, text, max_length=512):
57
+ """Extract attention weights for visualization"""
58
+ inputs = self.tokenizer(
59
+ text,
60
+ return_tensors="pt",
61
+ max_length=max_length,
62
+ truncation=True,
63
+ padding=True
64
+ ).to(self.device)
65
+
66
+ with torch.no_grad():
67
+ outputs = self.model(**inputs, output_attentions=True)
68
+ attentions = outputs.attentions
69
+
70
+ # Get tokens
71
+ tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
72
+
73
+ return attentions, tokens
74
+
75
+ def visualize_attention_heads(self, text, layer_idx=0, head_idx=0, max_length=512):
76
+ """Visualize attention patterns for specific layer and head"""
77
+ attentions, tokens = self.extract_attention_weights(text, max_length)
78
+
79
+ # Get attention weights for specific layer and head
80
+ attention_weights = attentions[layer_idx][0, head_idx].cpu().numpy()
81
+
82
+ # Create heatmap
83
+ plt.figure(figsize=(12, 8))
84
+ sns.heatmap(
85
+ attention_weights,
86
+ xticklabels=tokens,
87
+ yticklabels=tokens,
88
+ cmap='Blues',
89
+ cbar=True
90
+ )
91
+ plt.title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
92
+ plt.xlabel('Key Tokens')
93
+ plt.ylabel('Query Tokens')
94
+ plt.xticks(rotation=45)
95
+ plt.yticks(rotation=0)
96
+ plt.tight_layout()
97
+ plt.show()
98
+
99
+ return attention_weights, tokens
100
+
101
+ def attention_rollout(self, text, max_length=512):
102
+ """Compute attention rollout for global attention patterns"""
103
+ attentions, tokens = self.extract_attention_weights(text, max_length)
104
+
105
+ # Convert to numpy
106
+ attention_matrices = [att[0].mean(dim=0).cpu().numpy() for att in attentions]
107
+
108
+ # Compute rollout
109
+ rollout = attention_matrices[0]
110
+ for attention_matrix in attention_matrices[1:]:
111
+ rollout = np.matmul(rollout, attention_matrix)
112
+
113
+ # Visualize rollout
114
+ plt.figure(figsize=(12, 8))
115
+ sns.heatmap(
116
+ rollout,
117
+ xticklabels=tokens,
118
+ yticklabels=tokens,
119
+ cmap='Reds',
120
+ cbar=True
121
+ )
122
+ plt.title('Attention Rollout - Global Attention Flow')
123
+ plt.xlabel('Key Tokens')
124
+ plt.ylabel('Query Tokens')
125
+ plt.xticks(rotation=45)
126
+ plt.yticks(rotation=0)
127
+ plt.tight_layout()
128
+ plt.show()
129
+
130
+ return rollout, tokens
131
+
132
+ def gradient_saliency(self, text, target_token_idx=None, max_length=512):
133
+ """Compute gradient-based saliency maps"""
134
+ inputs = self.tokenizer(
135
+ text,
136
+ return_tensors="pt",
137
+ max_length=max_length,
138
+ truncation=True,
139
+ padding=True
140
+ ).to(self.device)
141
+
142
+ # Enable gradients for embeddings
143
+ embeddings = self.model.get_input_embeddings()
144
+ inputs_embeds = embeddings(inputs['input_ids'])
145
+ inputs_embeds.requires_grad_(True)
146
+
147
+ # Forward pass
148
+ outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=inputs['attention_mask'])
149
+
150
+ # Get target logits (last token if not specified)
151
+ if target_token_idx is None:
152
+ target_token_idx = -1
153
+
154
+ target_logits = outputs.logits[0, target_token_idx]
155
+ target_prob = F.softmax(target_logits, dim=-1)
156
+
157
+ # Compute gradients
158
+ target_prob.max().backward()
159
+
160
+ # Get saliency scores
161
+ saliency_scores = inputs_embeds.grad.norm(dim=-1).squeeze().cpu().numpy()
162
+ tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
163
+
164
+ # Visualize saliency
165
+ plt.figure(figsize=(15, 6))
166
+ colors = plt.cm.Reds(saliency_scores / saliency_scores.max())
167
+
168
+ for i, (token, score) in enumerate(zip(tokens, saliency_scores)):
169
+ plt.bar(i, score, color=colors[i])
170
+ plt.text(i, score + 0.001, token, rotation=45, ha='left', va='bottom')
171
+
172
+ plt.title('Gradient Saliency Scores')
173
+ plt.xlabel('Token Position')
174
+ plt.ylabel('Saliency Score')
175
+ plt.tight_layout()
176
+ plt.show()
177
+
178
+ return saliency_scores, tokens
179
+
180
+ def lime_explanation(self, text, num_samples=1000):
181
+ """Generate LIME explanations"""
182
+ def predict_fn(texts):
183
+ """Prediction function for LIME"""
184
+ predictions = []
185
+ for text in texts:
186
+ try:
187
+ inputs = self.tokenizer(
188
+ text,
189
+ return_tensors="pt",
190
+ max_length=512,
191
+ truncation=True,
192
+ padding=True
193
+ ).to(self.device)
194
+
195
+ with torch.no_grad():
196
+ outputs = self.model(**inputs)
197
+ logits = outputs.logits[0, -1]
198
+ probs = F.softmax(logits, dim=-1)
199
+
200
+ # Return probability distribution
201
+ predictions.append(probs.cpu().numpy())
202
+ except:
203
+ # Return uniform distribution if error
204
+ predictions.append(np.ones(self.tokenizer.vocab_size) / self.tokenizer.vocab_size)
205
+
206
+ return np.array(predictions)
207
+
208
+ # Generate explanation
209
+ explanation = self.lime_explainer.explain_instance(
210
+ text,
211
+ predict_fn,
212
+ num_features=20,
213
+ num_samples=num_samples
214
+ )
215
+
216
+ # Visualize explanation
217
+ explanation.show_in_notebook(text=True)
218
+
219
+ return explanation
220
+
221
+ def activation_analysis(self, text, layer_indices=None, max_length=512):
222
+ """Analyze hidden layer activations"""
223
+ inputs = self.tokenizer(
224
+ text,
225
+ return_tensors="pt",
226
+ max_length=max_length,
227
+ truncation=True,
228
+ padding=True
229
+ ).to(self.device)
230
+
231
+ # Hook to capture activations
232
+ activations = {}
233
+
234
+ def hook_fn(name):
235
+ def hook(module, input, output):
236
+ activations[name] = output.detach()
237
+ return hook
238
+
239
+ # Register hooks
240
+ if layer_indices is None:
241
+ layer_indices = [0, len(self.model.model.layers)//2, len(self.model.model.layers)-1]
242
+
243
+ hooks = []
244
+ for idx in layer_indices:
245
+ if idx < len(self.model.model.layers):
246
+ hook = self.model.model.layers[idx].register_forward_hook(hook_fn(f'layer_{idx}'))
247
+ hooks.append(hook)
248
+
249
+ # Forward pass
250
+ with torch.no_grad():
251
+ outputs = self.model(**inputs)
252
+
253
+ # Remove hooks
254
+ for hook in hooks:
255
+ hook.remove()
256
+
257
+ # Analyze activations
258
+ tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
259
+
260
+ for layer_name, activation in activations.items():
261
+ # Get activation statistics
262
+ activation_np = activation[0].cpu().numpy()
263
+
264
+ # Plot activation distribution
265
+ plt.figure(figsize=(12, 6))
266
+
267
+ # Heatmap of activations
268
+ plt.subplot(1, 2, 1)
269
+ sns.heatmap(activation_np.T, cmap='viridis', cbar=True)
270
+ plt.title(f'{layer_name} Activations')
271
+ plt.xlabel('Token Position')
272
+ plt.ylabel('Hidden Dimension')
273
+
274
+ # Activation magnitude per token
275
+ plt.subplot(1, 2, 2)
276
+ activation_magnitudes = np.linalg.norm(activation_np, axis=1)
277
+ plt.bar(range(len(tokens)), activation_magnitudes)
278
+ plt.title(f'{layer_name} Activation Magnitudes')
279
+ plt.xlabel('Token Position')
280
+ plt.ylabel('Magnitude')
281
+ plt.xticks(range(len(tokens)), tokens, rotation=45)
282
+
283
+ plt.tight_layout()
284
+ plt.show()
285
+
286
+ def token_importance_analysis(self, text, method='attention', max_length=512):
287
+ """Analyze token importance using different methods"""
288
+ results = {}
289
+
290
+ if method == 'attention':
291
+ # Attention-based importance
292
+ attentions, tokens = self.extract_attention_weights(text, max_length)
293
+
294
+ # Average attention across layers and heads
295
+ avg_attention = torch.stack([att.mean(dim=1) for att in attentions]).mean(dim=0)
296
+ importance_scores = avg_attention[0].sum(dim=0).cpu().numpy()
297
+
298
+ elif method == 'gradient':
299
+ # Gradient-based importance
300
+ importance_scores, tokens = self.gradient_saliency(text, max_length=max_length)
301
+
302
+ # Create importance dataframe
303
+ importance_df = pd.DataFrame({
304
+ 'token': tokens,
305
+ 'importance': importance_scores
306
+ })
307
+
308
+ # Sort by importance
309
+ importance_df = importance_df.sort_values('importance', ascending=False)
310
+
311
+ # Visualize top important tokens
312
+ plt.figure(figsize=(12, 6))
313
+ top_tokens = importance_df.head(20)
314
+ plt.barh(range(len(top_tokens)), top_tokens['importance'])
315
+ plt.yticks(range(len(top_tokens)), top_tokens['token'])
316
+ plt.title(f'Top 20 Important Tokens ({method.title()} Method)')
317
+ plt.xlabel('Importance Score')
318
+ plt.tight_layout()
319
+ plt.show()
320
+
321
+ return importance_df
322
+
323
+ def semantic_similarity_analysis(self, texts, max_length=512):
324
+ """Analyze semantic similarity between different texts"""
325
+ embeddings = []
326
+
327
+ for text in texts:
328
+ inputs = self.tokenizer(
329
+ text,
330
+ return_tensors="pt",
331
+ max_length=max_length,
332
+ truncation=True,
333
+ padding=True
334
+ ).to(self.device)
335
+
336
+ with torch.no_grad():
337
+ outputs = self.model(**inputs, output_hidden_states=True)
338
+ # Use last layer, last token embedding
339
+ embedding = outputs.hidden_states[-1][0, -1].cpu().numpy()
340
+ embeddings.append(embedding)
341
+
342
+ # Compute similarity matrix
343
+ similarity_matrix = cosine_similarity(embeddings)
344
+
345
+ # Visualize similarity matrix
346
+ plt.figure(figsize=(10, 8))
347
+ sns.heatmap(
348
+ similarity_matrix,
349
+ annot=True,
350
+ cmap='viridis',
351
+ xticklabels=[f'Text {i+1}' for i in range(len(texts))],
352
+ yticklabels=[f'Text {i+1}' for i in range(len(texts))]
353
+ )
354
+ plt.title('Semantic Similarity Matrix')
355
+ plt.tight_layout()
356
+ plt.show()
357
+
358
+ return similarity_matrix
359
+
360
+ def generate_explanation_report(self, text, output_file='xai_report.html'):
361
+ """Generate comprehensive explanation report"""
362
+ print("Generating comprehensive XAI report...")
363
+
364
+ # Run all analyses
365
+ print("1. Extracting attention patterns...")
366
+ attention_weights, tokens = self.visualize_attention_heads(text)
367
+
368
+ print("2. Computing attention rollout...")
369
+ rollout, _ = self.attention_rollout(text)
370
+
371
+ print("3. Calculating gradient saliency...")
372
+ saliency_scores, _ = self.gradient_saliency(text)
373
+
374
+ print("4. Analyzing activations...")
375
+ self.activation_analysis(text)
376
+
377
+ print("5. Computing token importance...")
378
+ importance_df = self.token_importance_analysis(text)
379
+
380
+ # Create summary
381
+ print("\n=== XAI ANALYSIS SUMMARY ===")
382
+ print(f"Input text: {text[:100]}...")
383
+ print(f"Number of tokens: {len(tokens)}")
384
+ print(f"Most important tokens: {importance_df.head(5)['token'].tolist()}")
385
+ print(f"Average attention entropy: {np.mean(-np.sum(attention_weights * np.log(attention_weights + 1e-10), axis=1)):.4f}")
386
+
387
+ return {
388
+ 'attention_weights': attention_weights,
389
+ 'rollout': rollout,
390
+ 'saliency_scores': saliency_scores,
391
+ 'importance_df': importance_df,
392
+ 'tokens': tokens
393
+ }
394
+
395
+ def main():
396
+ """Main function to run XAI analysis"""
397
+
398
+ # Initialize analyzer (adjust model path as needed)
399
+ try:
400
+ analyzer = LLMExplainabilityAnalyzer("./fine_tuned_model")
401
+ except:
402
+ print("Fine-tuned model not found. Using base model for demonstration.")
403
+ analyzer = LLMExplainabilityAnalyzer("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
404
+
405
+ # Sample skin disease text for analysis
406
+ sample_text = """
407
+ Patient presents with erythematous scaly patches on the elbows and knees,
408
+ consistent with psoriasis. The condition appears to be chronic with periods
409
+ of exacerbation. Treatment options include topical corticosteroids and
410
+ phototherapy for mild to moderate cases.
411
+ """
412
+
413
+ print("Starting XAI Analysis...")
414
+ print("=" * 50)
415
+
416
+ # Generate comprehensive report
417
+ results = analyzer.generate_explanation_report(sample_text)
418
+
419
+ # Additional analyses
420
+ print("\n6. Semantic similarity analysis...")
421
+ test_texts = [
422
+ "Psoriasis treatment with topical corticosteroids",
423
+ "Eczema management using moisturizers",
424
+ "Melanoma diagnosis and surgical intervention"
425
+ ]
426
+
427
+ similarity_matrix = analyzer.semantic_similarity_analysis(test_texts)
428
+
429
+ print("\n" + "=" * 50)
430
+ print("XAI ANALYSIS COMPLETE")
431
+ print("=" * 50)
432
+
433
+ return results
434
+
435
+ if __name__ == "__main__":
436
+ main()