Kalyangotimothy committed
Commit · 217a100
Parent(s): d88d781
new
Browse files
- API.py +17 -0
- Dockerfile +28 -0
- README.md +36 -14
- app.py +128 -0
- cleaning.py +10 -0
- deploy.py +63 -0
- eda_analysis.py +381 -0
- extraction.py +6 -0
- finetune.py +46 -0
- finetune_tinyllama.py +39 -0
- gradio-app.py +7 -0
- llama2_inference.py +13 -0
- requirements.txt +10 -0
- xai_analysis.py +436 -0
API.py
ADDED
@@ -0,0 +1,17 @@
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

model_path = "tinyllama-finetuned-skin"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

@app.post("/generate")
async def generate(request: Request):
    data = await request.json()
    prompt = data["prompt"]
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=100)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": result}
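
Note a mismatch with app.py below: the Gradio front end sends max_new_tokens and temperature in the payload and polls a /health route, while API.py as committed accepts only "prompt" and defines no /health endpoint, so those fields are ignored and the health check always reports the API as down. A minimal server-side sketch of what API.py would need (an assumption about intended behavior, not part of this commit):

# Sketch only: /health route plus honoring the optional fields app.py sends.
@app.get("/health")
async def health():
    return {"status": "ok"}

# Inside generate(), read the extra payload fields with defaults:
# max_new_tokens = int(data.get("max_new_tokens", 100))
# temperature = float(data.get("temperature", 0.7))
# outputs = model.generate(**inputs, max_new_tokens=max_new_tokens,
#                          do_sample=True, temperature=temperature)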
Dockerfile
ADDED
@@ -0,0 +1,28 @@
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first (for better caching)
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create model directory
RUN mkdir -p tinyllama-finetuned-skin

# Expose Hugging Face Spaces port
EXPOSE 7860

# Start Gradio app for Hugging Face Spaces
CMD ["python", "app.py"]
README.md
CHANGED
@@ -1,14 +1,36 @@
# 🏥 Skin Disease AI Assistant

An AI-powered assistant for skin disease analysis and diagnosis support, built with a fine-tuned TinyLlama model.

## Features

- 🤖 AI-powered skin disease analysis
- 🩺 Medical consultation support
- 📊 Treatment recommendations
- 🔬 Research-backed responses

## Usage

Enter your medical query and get AI-powered insights for skin conditions, symptoms, and treatment options.

**Note:** This is for educational purposes only. Always consult with medical professionals for actual diagnosis and treatment.

## Model Information

- Base Model: TinyLlama-1.1B
- Fine-tuned on: Skin disease medical literature
- Specialization: Dermatology and skin conditions

## Example Queries

- "Patient presents with red scaly patches on elbows. What could this be?"
- "Describe treatment options for psoriasis"
- "What are the symptoms of eczema?"

## Deployment

This app is deployed on Hugging Face Spaces with a Gradio interface.

## Disclaimer

This AI assistant is for educational and research purposes only. It should not be used as a substitute for professional medical advice, diagnosis, or treatment.
app.py
ADDED
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
"""
Hugging Face Spaces deployment file for Skin Disease AI API
"""
import os
import gradio as gr
import requests
import json
from threading import Thread
import time
import subprocess
import sys

# Start FastAPI in background
def start_fastapi():
    """Start FastAPI server in background"""
    subprocess.run([
        sys.executable, "-m", "uvicorn", "API:app",
        "--host", "0.0.0.0", "--port", "8000"
    ])

# Start FastAPI in a separate thread
api_thread = Thread(target=start_fastapi, daemon=True)
api_thread.start()

# Wait for API to start
time.sleep(10)

def generate_text(prompt, max_tokens, temperature):
    """Generate text using the API"""
    try:
        url = "http://localhost:8000/generate"
        payload = {
            "prompt": prompt,
            "max_new_tokens": int(max_tokens),
            "temperature": float(temperature)
        }

        response = requests.post(url, json=payload, timeout=30)

        if response.status_code == 200:
            result = response.json()
            return result["response"]
        else:
            return f"Error: {response.status_code} - {response.text}"

    except Exception as e:
        return f"Error: {str(e)}"

def check_api_health():
    """Check if API is running"""
    try:
        response = requests.get("http://localhost:8000/health", timeout=5)
        if response.status_code == 200:
            return "✅ API is running"
        else:
            return "❌ API error"
    except requests.exceptions.RequestException:
        return "❌ API not responding"

# Create Gradio interface
with gr.Blocks(title="Skin Disease AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 Skin Disease AI Assistant")
    gr.Markdown("AI-powered assistant for skin disease analysis and diagnosis support.")

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Enter your medical query",
                placeholder="Patient presents with red scaly patches on elbows. Diagnosis:",
                lines=3
            )

            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )

            generate_btn = gr.Button("Generate Response", variant="primary")

        with gr.Column(scale=1):
            api_status = gr.Textbox(label="API Status", value="Starting...", interactive=False)
            check_btn = gr.Button("Check API")

    output = gr.Textbox(
        label="AI Response",
        lines=8,
        placeholder="AI response will appear here..."
    )

    # Examples
    gr.Examples(
        examples=[
            ["Patient has red scaly patches on elbows and knees. What could this be?", 80, 0.7],
            ["Describe treatment options for psoriasis", 100, 0.6],
            ["What are the symptoms of eczema?", 120, 0.5],
        ],
        inputs=[prompt_input, max_tokens, temperature],
    )

    # Event handlers
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens, temperature],
        outputs=output
    )

    check_btn.click(
        fn=check_api_health,
        outputs=api_status
    )

    # Auto-check API status on load
    demo.load(fn=check_api_health, outputs=api_status)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
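
One fragile spot above is the fixed time.sleep(10) before the UI assumes the API is up; loading a 1.1B-parameter model can easily take longer. A more robust sketch (reusing the /health route app.py already polls) waits until the endpoint actually answers:

import requests
import time

def wait_for_api(url="http://localhost:8000/health", timeout=120):
    """Poll the health endpoint until it responds or the timeout expires."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass  # API not up yet; keep polling
        time.sleep(1)
    return False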
cleaning.py
ADDED
@@ -0,0 +1,10 @@
input_file = "skin_disease_articles.txt"
output_file = "skin_disease_articles_clean.txt"

with open(input_file, "r", encoding="utf-8") as infile, open(output_file, "w", encoding="utf-8") as outfile:
    for line in infile:
        cleaned = line.strip()  # Remove leading/trailing whitespace
        if cleaned:  # Skip empty lines
            outfile.write(cleaned + "\n")

print(f"Cleaned file saved as {output_file}")
deploy.py
ADDED
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
Simple deployment script for Skin Disease AI API
"""
import os
import sys
import subprocess

def create_demo_model():
    """Create a demo model for testing"""
    print("Creating demo model...")
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM

        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        print(f"Downloading {model_name}...")

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        os.makedirs("tinyllama-finetuned-skin", exist_ok=True)
        tokenizer.save_pretrained("tinyllama-finetuned-skin")
        model.save_pretrained("tinyllama-finetuned-skin")

        print("✅ Demo model created successfully!")
        return True
    except Exception as e:
        print(f"❌ Error creating model: {e}")
        return False

def start_server():
    """Start the API server"""
    print("🚀 Starting API server on http://localhost:8000")
    print("Press Ctrl+C to stop")
    try:
        subprocess.run([
            sys.executable, "-m", "uvicorn", "API:app",
            "--host", "0.0.0.0", "--port", "8000", "--reload"
        ])
    except KeyboardInterrupt:
        print("\n🛑 Server stopped.")

def main():
    print("🚀 Skin Disease AI Deployment")
    print("=" * 40)

    # Check if model exists
    if not os.path.exists("tinyllama-finetuned-skin"):
        print("Model not found. Creating demo model...")
        if not create_demo_model():
            print("Failed to create model. Exiting.")
            return
    else:
        print("✅ Model found!")

    # Start server
    start_server()

if __name__ == "__main__":
    main()
eda_analysis.py
ADDED
@@ -0,0 +1,381 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import re
from wordcloud import WordCloud
from textstat import flesch_reading_ease, flesch_kincaid_grade
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
import warnings
warnings.filterwarnings('ignore')

# Download required NLTK data
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class SkinDiseaseEDA:
    def __init__(self, filepath):
        self.filepath = filepath
        self.data = []
        self.articles = []
        self.load_data()

    def load_data(self):
        """Parse the structured text file into articles"""
        with open(self.filepath, 'r', encoding='utf-8') as file:
            content = file.read()

        # Split by separator
        articles = content.split('------------------------------------------------------------')

        for article in articles:
            if not article.strip():
                continue

            lines = article.strip().split('\n')
            article_data = {
                'title': '',
                'journal': '',
                'authors': '',
                'abstract': '',
                'diagnosis': '',
                'treatment': ''
            }

            current_section = None
            for line in lines:
                line = line.strip()
                if not line:
                    continue

                if line.startswith('Journal:'):
                    current_section = 'journal'
                    article_data['journal'] = line.replace('Journal:', '').strip()
                elif line.startswith('Authors:'):
                    current_section = 'authors'
                    article_data['authors'] = line.replace('Authors:', '').strip()
                elif line.startswith('Abstract:'):
                    current_section = 'abstract'
                    article_data['abstract'] = line.replace('Abstract:', '').strip()
                elif line == 'Diagnosis':
                    current_section = 'diagnosis'
                elif line == 'Treatment Remedies':
                    current_section = 'treatment'
                elif current_section == 'abstract' and not line.startswith(('Journal:', 'Authors:', 'Diagnosis', 'Treatment')):
                    article_data['abstract'] += ' ' + line
                elif current_section == 'diagnosis' and not line.startswith(('Journal:', 'Authors:', 'Abstract:', 'Treatment')):
                    article_data['diagnosis'] += ' ' + line
                elif current_section == 'treatment' and not line.startswith(('Journal:', 'Authors:', 'Abstract:', 'Diagnosis')):
                    article_data['treatment'] += ' ' + line
                elif not any(line.startswith(prefix) for prefix in ['Journal:', 'Authors:', 'Abstract:', 'Diagnosis', 'Treatment']) and not current_section:
                    article_data['title'] = line

            # Clean up data
            for key in article_data:
                article_data[key] = article_data[key].strip()

            if article_data['title']:
                self.articles.append(article_data)

    def basic_statistics(self):
        """Generate basic statistics about the corpus"""
        print("=== BASIC CORPUS STATISTICS ===")
        print(f"Total articles: {len(self.articles)}")

        # Text length statistics
        abstract_lengths = [len(article['abstract']) for article in self.articles if article['abstract']]
        title_lengths = [len(article['title']) for article in self.articles if article['title']]

        print(f"Articles with abstracts: {len(abstract_lengths)}")
        print(f"Average abstract length: {np.mean(abstract_lengths):.1f} characters")
        print(f"Average title length: {np.mean(title_lengths):.1f} characters")

        # Word counts
        abstract_words = [len(article['abstract'].split()) for article in self.articles if article['abstract']]
        print(f"Average abstract word count: {np.mean(abstract_words):.1f} words")

        # Diagnosis and treatment availability
        with_diagnosis = sum(1 for article in self.articles if article['diagnosis'] and article['diagnosis'] != 'Not specified.')
        with_treatment = sum(1 for article in self.articles if article['treatment'])

        print(f"Articles with specific diagnosis: {with_diagnosis} ({with_diagnosis/len(self.articles)*100:.1f}%)")
        print(f"Articles with treatment info: {with_treatment} ({with_treatment/len(self.articles)*100:.1f}%)")

        return {
            'total_articles': len(self.articles),
            'abstract_lengths': abstract_lengths,
            'title_lengths': title_lengths,
            'abstract_words': abstract_words,
            'with_diagnosis': with_diagnosis,
            'with_treatment': with_treatment
        }

    def journal_analysis(self):
        """Analyze journal distribution"""
        print("\n=== JOURNAL ANALYSIS ===")

        journals = [article['journal'] for article in self.articles if article['journal']]
        journal_counts = Counter(journals)

        print(f"Total unique journals: {len(journal_counts)}")
        print("Top 10 journals:")
        for journal, count in journal_counts.most_common(10):
            print(f"  {journal}: {count} articles")

        # Create visualization
        plt.figure(figsize=(12, 8))
        top_journals = dict(journal_counts.most_common(15))
        plt.barh(list(top_journals.keys()), list(top_journals.values()))
        plt.title('Top 15 Journals by Article Count')
        plt.xlabel('Number of Articles')
        plt.tight_layout()
        plt.show()

        return journal_counts

    def author_analysis(self):
        """Analyze author patterns"""
        print("\n=== AUTHOR ANALYSIS ===")

        all_authors = []
        for article in self.articles:
            if article['authors']:
                # Split authors by comma
                authors = [author.strip() for author in article['authors'].split(',')]
                all_authors.extend(authors)

        author_counts = Counter(all_authors)

        print(f"Total unique authors: {len(author_counts)}")
        print(f"Total author instances: {len(all_authors)}")
        print(f"Average authors per article: {len(all_authors)/len(self.articles):.1f}")

        print("Top 10 most prolific authors:")
        for author, count in author_counts.most_common(10):
            print(f"  {author}: {count} articles")

        # Author collaboration network size
        author_counts_per_article = [len(article['authors'].split(',')) for article in self.articles if article['authors']]
        print(f"Average collaboration size: {np.mean(author_counts_per_article):.1f} authors per article")

        return author_counts

    def disease_analysis(self):
        """Analyze disease mentions and patterns"""
        print("\n=== DISEASE AND CONDITION ANALYSIS ===")

        # Common disease terms
        disease_terms = [
            'cancer', 'carcinoma', 'melanoma', 'psoriasis', 'dermatitis', 'eczema',
            'acne', 'rosacea', 'vitiligo', 'lupus', 'scleroderma', 'pemphigus',
            'bullous', 'urticaria', 'mastocytosis', 'lymphoma', 'sarcoma',
            'basal cell', 'squamous cell', 'keratosis', 'mycosis', 'fungal',
            'bacterial', 'viral', 'herpes', 'warts', 'molluscum', 'impetigo'
        ]

        # Count mentions in titles and abstracts
        disease_counts = Counter()

        for article in self.articles:
            text = (article['title'] + ' ' + article['abstract']).lower()
            for term in disease_terms:
                if term in text:
                    disease_counts[term] += 1

        print("Top 15 disease/condition mentions:")
        for disease, count in disease_counts.most_common(15):
            print(f"  {disease}: {count} mentions")

        # Create visualization
        plt.figure(figsize=(12, 8))
        top_diseases = dict(disease_counts.most_common(15))
        plt.barh(list(top_diseases.keys()), list(top_diseases.values()))
        plt.title('Top 15 Disease/Condition Mentions')
        plt.xlabel('Number of Mentions')
        plt.tight_layout()
        plt.show()

        return disease_counts

    def treatment_analysis(self):
        """Analyze treatment patterns"""
        print("\n=== TREATMENT ANALYSIS ===")

        # Common treatment terms
        treatment_terms = [
            'therapy', 'treatment', 'drug', 'medication', 'topical', 'oral',
            'systemic', 'immunosuppressive', 'corticosteroid', 'antibiotic',
            'antifungal', 'antiviral', 'chemotherapy', 'radiotherapy',
            'surgical', 'laser', 'phototherapy', 'immunotherapy', 'biologic',
            'methotrexate', 'cyclosporine', 'tacrolimus', 'rituximab'
        ]

        treatment_counts = Counter()

        for article in self.articles:
            text = (article['treatment'] + ' ' + article['abstract']).lower()
            for term in treatment_terms:
                if term in text:
                    treatment_counts[term] += 1

        print("Top 15 treatment mentions:")
        for treatment, count in treatment_counts.most_common(15):
            print(f"  {treatment}: {count} mentions")

        # Create visualization
        plt.figure(figsize=(12, 8))
        top_treatments = dict(treatment_counts.most_common(15))
        plt.barh(list(top_treatments.keys()), list(top_treatments.values()))
        plt.title('Top 15 Treatment Mentions')
        plt.xlabel('Number of Mentions')
        plt.tight_layout()
        plt.show()

        return treatment_counts

    def keyword_analysis(self):
        """Perform keyword analysis using TF-IDF"""
        print("\n=== KEYWORD ANALYSIS ===")

        # Combine title and abstract for each article
        documents = []
        for article in self.articles:
            doc = article['title'] + ' ' + article['abstract']
            documents.append(doc)

        # TF-IDF analysis
        stop_words = set(stopwords.words('english'))
        stop_words.update(['study', 'research', 'analysis', 'results', 'conclusion', 'background', 'methods'])

        vectorizer = TfidfVectorizer(
            max_features=100,
            stop_words=list(stop_words),
            ngram_range=(1, 2),
            min_df=2,
            max_df=0.8
        )

        tfidf_matrix = vectorizer.fit_transform(documents)
        feature_names = vectorizer.get_feature_names_out()

        # Get top keywords
        mean_scores = np.mean(tfidf_matrix.toarray(), axis=0)
        top_indices = np.argsort(mean_scores)[::-1][:20]

        print("Top 20 keywords by TF-IDF score:")
        for i, idx in enumerate(top_indices):
            print(f"  {i+1}. {feature_names[idx]}: {mean_scores[idx]:.4f}")

        # Create word cloud
        all_text = ' '.join(documents)
        wordcloud = WordCloud(
            width=800,
            height=400,
            background_color='white',
            stopwords=stop_words,
            max_words=100
        ).generate(all_text)

        plt.figure(figsize=(12, 6))
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.axis('off')
        plt.title('Word Cloud of Skin Disease Articles')
        plt.tight_layout()
        plt.show()

        return feature_names, mean_scores

    def readability_analysis(self):
        """Analyze text readability"""
        print("\n=== READABILITY ANALYSIS ===")

        flesch_scores = []
        grade_levels = []

        for article in self.articles:
            if article['abstract']:
                try:
                    flesch_score = flesch_reading_ease(article['abstract'])
                    grade_level = flesch_kincaid_grade(article['abstract'])
                    flesch_scores.append(flesch_score)
                    grade_levels.append(grade_level)
                except Exception:
                    continue

        print(f"Average Flesch Reading Ease Score: {np.mean(flesch_scores):.1f}")
        print(f"Average Grade Level: {np.mean(grade_levels):.1f}")

        # Interpretation
        avg_flesch = np.mean(flesch_scores)
        if avg_flesch >= 90:
            difficulty = "Very Easy"
        elif avg_flesch >= 80:
            difficulty = "Easy"
        elif avg_flesch >= 70:
            difficulty = "Fairly Easy"
        elif avg_flesch >= 60:
            difficulty = "Standard"
        elif avg_flesch >= 50:
            difficulty = "Fairly Difficult"
        elif avg_flesch >= 30:
            difficulty = "Difficult"
        else:
            difficulty = "Very Difficult"

        print(f"Reading Difficulty: {difficulty}")

        return flesch_scores, grade_levels

    def generate_summary_report(self):
        """Generate a comprehensive summary report"""
        print("\n" + "="*50)
        print("COMPREHENSIVE EDA SUMMARY REPORT")
        print("="*50)

        # Run all analyses
        basic_stats = self.basic_statistics()
        journal_counts = self.journal_analysis()
        author_counts = self.author_analysis()
        disease_counts = self.disease_analysis()
        treatment_counts = self.treatment_analysis()
        keywords, scores = self.keyword_analysis()
        flesch_scores, grade_levels = self.readability_analysis()

        # Summary insights
        print("\n=== KEY INSIGHTS ===")
        print(f"1. Corpus contains {basic_stats['total_articles']} articles from {len(journal_counts)} unique journals")
        print(f"2. Most common disease area: {disease_counts.most_common(1)[0][0] if disease_counts else 'N/A'}")
        print(f"3. Most common treatment approach: {treatment_counts.most_common(1)[0][0] if treatment_counts else 'N/A'}")
        print(f"4. Average reading level: Grade {np.mean(grade_levels):.1f}")
        print(f"5. {basic_stats['with_diagnosis']} articles have specific diagnosis information")
        print(f"6. {basic_stats['with_treatment']} articles contain treatment information")

def main():
    # Set up plotting style before any figures are drawn,
    # otherwise the style never applies to the report's plots
    plt.style.use('seaborn-v0_8')
    sns.set_palette("husl")

    # Initialize EDA
    eda = SkinDiseaseEDA('skin_disease_articles_clean.txt')

    # Generate comprehensive report
    eda.generate_summary_report()

    print("\n" + "="*50)
    print("EDA ANALYSIS COMPLETE")
    print("="*50)

if __name__ == "__main__":
    main()
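
For quicker iteration than the full report, individual analyses can also be run on their own; a minimal usage sketch (assuming the cleaned corpus produced by cleaning.py exists):

# Run selected analyses without generating every figure
eda = SkinDiseaseEDA("skin_disease_articles_clean.txt")
stats = eda.basic_statistics()   # corpus size, abstract lengths, coverage
eda.keyword_analysis()           # TF-IDF keywords plus word cloud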
extraction.py
ADDED
@@ -0,0 +1,6 @@
from docx import Document

doc = Document("skin_disease_articles.docx")
with open("skin_disease_articles.txt", "w", encoding="utf-8") as f:
    for para in doc.paragraphs:
        f.write(para.text + "\n")
finetune.py
ADDED
@@ -0,0 +1,46 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
import os

os.environ["USE_TF"] = "0"

model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Load your text file as a dataset
dataset = load_dataset("text", data_files={"train": "skin_disease_articles_clean.txt"})

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

train_dataset = tokenized_datasets["train"]

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

training_args = TrainingArguments(
    output_dir="./tinyllama-finetuned-skin",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    save_steps=500,
    save_total_limit=2,
    prediction_loss_only=True,
    fp16=True  # Requires a GPU with float16 support; set False when training on CPU
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()
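
One caveat: Trainer only writes periodic checkpoints under output_dir, while API.py loads both the model and the tokenizer from "tinyllama-finetuned-skin". A short follow-up sketch (an assumption about the intended workflow, to be run after trainer.train() completes) saves both explicitly where the serving code expects them:

# Persist the final model and the tokenizer for API.py to load
trainer.save_model("tinyllama-finetuned-skin")
tokenizer.save_pretrained("tinyllama-finetuned-skin")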
finetune_tinyllama.py
ADDED
@@ -0,0 +1,39 @@
# Note: TextDataset is deprecated (and removed in recent transformers releases);
# finetune.py above shows the equivalent datasets-based pipeline.
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling

model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Prepare dataset
def load_dataset(file_path, tokenizer, block_size=128):
    return TextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=block_size
    )

train_dataset = load_dataset("skin_disease_articles_clean.txt", tokenizer)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

training_args = TrainingArguments(
    output_dir="./tinyllama-finetuned-skin",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    save_steps=500,
    save_total_limit=2,
    prediction_loss_only=True,
    fp16=False  # Set True if using a GPU with float16 support
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()
gradio-app.py
ADDED
@@ -0,0 +1,7 @@
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
llama2_inference.py
ADDED
@@ -0,0 +1,13 @@
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"  # Or local path if downloaded
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Example: Use a line from your cleaned file as a prompt
with open("skin_disease_articles_clean.txt", "r", encoding="utf-8") as f:
    prompt = f.readline().strip()

inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
requirements.txt
ADDED
@@ -0,0 +1,10 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
transformers==4.35.2
torch>=2.0.0
datasets==2.14.7
accelerate==0.24.1
pydantic==2.5.0
python-multipart==0.0.6
gradio==4.15.0
requests>=2.25.0
xai_analysis.py
ADDED
@@ -0,0 +1,436 @@
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import LlamaTokenizer, LlamaForCausalLM
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import lime
from lime.lime_text import LimeTextExplainer
import shap
import re
import warnings
warnings.filterwarnings('ignore')

class LLMExplainabilityAnalyzer:
    def __init__(self, model_path, tokenizer_path=None):
        """Initialize with model and tokenizer paths"""
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path or model_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load model and tokenizer
        self.load_model()

        # Initialize explanation tools
        self.lime_explainer = LimeTextExplainer(class_names=['Generated Text'])

    def load_model(self):
        """Load the fine-tuned model and tokenizer"""
        try:
            print(f"Loading model from: {self.model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )

            # Set padding token if not exists
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            # Fallback to base model
            print("Loading base TinyLlama model...")
            self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def extract_attention_weights(self, text, max_length=512):
        """Extract attention weights for visualization"""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
            attentions = outputs.attentions

        # Get tokens
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        return attentions, tokens

    def visualize_attention_heads(self, text, layer_idx=0, head_idx=0, max_length=512):
        """Visualize attention patterns for specific layer and head"""
        attentions, tokens = self.extract_attention_weights(text, max_length)

        # Get attention weights for specific layer and head
        attention_weights = attentions[layer_idx][0, head_idx].cpu().numpy()

        # Create heatmap
        plt.figure(figsize=(12, 8))
        sns.heatmap(
            attention_weights,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Blues',
            cbar=True
        )
        plt.title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

        return attention_weights, tokens

    def attention_rollout(self, text, max_length=512):
        """Compute attention rollout for global attention patterns"""
        attentions, tokens = self.extract_attention_weights(text, max_length)

        # Convert to numpy
        attention_matrices = [att[0].mean(dim=0).cpu().numpy() for att in attentions]

        # Compute rollout
        rollout = attention_matrices[0]
        for attention_matrix in attention_matrices[1:]:
            rollout = np.matmul(rollout, attention_matrix)

        # Visualize rollout
        plt.figure(figsize=(12, 8))
        sns.heatmap(
            rollout,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Reds',
            cbar=True
        )
        plt.title('Attention Rollout - Global Attention Flow')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

        return rollout, tokens

    def gradient_saliency(self, text, target_token_idx=None, max_length=512):
        """Compute gradient-based saliency maps"""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        # Enable gradients for embeddings
        embeddings = self.model.get_input_embeddings()
        inputs_embeds = embeddings(inputs['input_ids'])
        inputs_embeds.requires_grad_(True)

        # Forward pass
        outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=inputs['attention_mask'])

        # Get target logits (last token if not specified)
        if target_token_idx is None:
            target_token_idx = -1

        target_logits = outputs.logits[0, target_token_idx]
        target_prob = F.softmax(target_logits, dim=-1)

        # Compute gradients
        target_prob.max().backward()

        # Get saliency scores
        saliency_scores = inputs_embeds.grad.norm(dim=-1).squeeze().cpu().numpy()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Visualize saliency
        plt.figure(figsize=(15, 6))
        colors = plt.cm.Reds(saliency_scores / saliency_scores.max())

        for i, (token, score) in enumerate(zip(tokens, saliency_scores)):
            plt.bar(i, score, color=colors[i])
            plt.text(i, score + 0.001, token, rotation=45, ha='left', va='bottom')

        plt.title('Gradient Saliency Scores')
        plt.xlabel('Token Position')
        plt.ylabel('Saliency Score')
        plt.tight_layout()
        plt.show()

        return saliency_scores, tokens

    def lime_explanation(self, text, num_samples=1000):
        """Generate LIME explanations"""
        def predict_fn(texts):
            """Prediction function for LIME"""
            predictions = []
            for text in texts:
                try:
                    inputs = self.tokenizer(
                        text,
                        return_tensors="pt",
                        max_length=512,
                        truncation=True,
                        padding=True
                    ).to(self.device)

                    with torch.no_grad():
                        outputs = self.model(**inputs)
                        logits = outputs.logits[0, -1]
                        probs = F.softmax(logits, dim=-1)

                    # Return probability distribution
                    predictions.append(probs.cpu().numpy())
                except Exception:
                    # Return uniform distribution if error
                    predictions.append(np.ones(self.tokenizer.vocab_size) / self.tokenizer.vocab_size)

            return np.array(predictions)

        # Generate explanation
        explanation = self.lime_explainer.explain_instance(
            text,
            predict_fn,
            num_features=20,
            num_samples=num_samples
        )

        # Visualize explanation
        explanation.show_in_notebook(text=True)

        return explanation

    def activation_analysis(self, text, layer_indices=None, max_length=512):
        """Analyze hidden layer activations"""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True
        ).to(self.device)

        # Hook to capture activations
        activations = {}

        def hook_fn(name):
            def hook(module, input, output):
                activations[name] = output.detach()
            return hook

        # Register hooks
        if layer_indices is None:
            layer_indices = [0, len(self.model.model.layers)//2, len(self.model.model.layers)-1]

        hooks = []
        for idx in layer_indices:
            if idx < len(self.model.model.layers):
                hook = self.model.model.layers[idx].register_forward_hook(hook_fn(f'layer_{idx}'))
                hooks.append(hook)

        # Forward pass
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Remove hooks
        for hook in hooks:
            hook.remove()

        # Analyze activations
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        for layer_name, activation in activations.items():
            # Get activation statistics
            activation_np = activation[0].cpu().numpy()

            # Plot activation distribution
            plt.figure(figsize=(12, 6))

            # Heatmap of activations
            plt.subplot(1, 2, 1)
            sns.heatmap(activation_np.T, cmap='viridis', cbar=True)
            plt.title(f'{layer_name} Activations')
            plt.xlabel('Token Position')
            plt.ylabel('Hidden Dimension')

            # Activation magnitude per token
            plt.subplot(1, 2, 2)
            activation_magnitudes = np.linalg.norm(activation_np, axis=1)
            plt.bar(range(len(tokens)), activation_magnitudes)
            plt.title(f'{layer_name} Activation Magnitudes')
            plt.xlabel('Token Position')
            plt.ylabel('Magnitude')
            plt.xticks(range(len(tokens)), tokens, rotation=45)

            plt.tight_layout()
            plt.show()

    def token_importance_analysis(self, text, method='attention', max_length=512):
        """Analyze token importance using different methods"""
        results = {}

        if method == 'attention':
            # Attention-based importance
            attentions, tokens = self.extract_attention_weights(text, max_length)

            # Average attention across layers and heads
            avg_attention = torch.stack([att.mean(dim=1) for att in attentions]).mean(dim=0)
            importance_scores = avg_attention[0].sum(dim=0).cpu().numpy()

        elif method == 'gradient':
            # Gradient-based importance
            importance_scores, tokens = self.gradient_saliency(text, max_length=max_length)

        # Create importance dataframe
        importance_df = pd.DataFrame({
            'token': tokens,
            'importance': importance_scores
        })

        # Sort by importance
        importance_df = importance_df.sort_values('importance', ascending=False)

        # Visualize top important tokens
        plt.figure(figsize=(12, 6))
        top_tokens = importance_df.head(20)
        plt.barh(range(len(top_tokens)), top_tokens['importance'])
        plt.yticks(range(len(top_tokens)), top_tokens['token'])
        plt.title(f'Top 20 Important Tokens ({method.title()} Method)')
        plt.xlabel('Importance Score')
        plt.tight_layout()
        plt.show()

        return importance_df

    def semantic_similarity_analysis(self, texts, max_length=512):
        """Analyze semantic similarity between different texts"""
        embeddings = []

        for text in texts:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
                padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
                # Use last layer, last token embedding
                embedding = outputs.hidden_states[-1][0, -1].cpu().numpy()
                embeddings.append(embedding)

        # Compute similarity matrix
        similarity_matrix = cosine_similarity(embeddings)

        # Visualize similarity matrix
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            similarity_matrix,
            annot=True,
            cmap='viridis',
            xticklabels=[f'Text {i+1}' for i in range(len(texts))],
            yticklabels=[f'Text {i+1}' for i in range(len(texts))]
        )
        plt.title('Semantic Similarity Matrix')
        plt.tight_layout()
        plt.show()

        return similarity_matrix

    def generate_explanation_report(self, text, output_file='xai_report.html'):
        """Generate comprehensive explanation report"""
        print("Generating comprehensive XAI report...")

        # Run all analyses
        print("1. Extracting attention patterns...")
        attention_weights, tokens = self.visualize_attention_heads(text)

        print("2. Computing attention rollout...")
        rollout, _ = self.attention_rollout(text)

        print("3. Calculating gradient saliency...")
        saliency_scores, _ = self.gradient_saliency(text)

        print("4. Analyzing activations...")
        self.activation_analysis(text)

        print("5. Computing token importance...")
        importance_df = self.token_importance_analysis(text)

        # Create summary
        print("\n=== XAI ANALYSIS SUMMARY ===")
        print(f"Input text: {text[:100]}...")
        print(f"Number of tokens: {len(tokens)}")
        print(f"Most important tokens: {importance_df.head(5)['token'].tolist()}")
        print(f"Average attention entropy: {np.mean(-np.sum(attention_weights * np.log(attention_weights + 1e-10), axis=1)):.4f}")

        return {
            'attention_weights': attention_weights,
            'rollout': rollout,
            'saliency_scores': saliency_scores,
            'importance_df': importance_df,
            'tokens': tokens
        }

def main():
    """Main function to run XAI analysis"""

    # Initialize analyzer (adjust model path as needed)
    try:
        analyzer = LLMExplainabilityAnalyzer("./fine_tuned_model")
    except Exception:
        print("Fine-tuned model not found. Using base model for demonstration.")
        analyzer = LLMExplainabilityAnalyzer("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # Sample skin disease text for analysis
    sample_text = """
    Patient presents with erythematous scaly patches on the elbows and knees,
    consistent with psoriasis. The condition appears to be chronic with periods
    of exacerbation. Treatment options include topical corticosteroids and
    phototherapy for mild to moderate cases.
    """

    print("Starting XAI Analysis...")
    print("=" * 50)

    # Generate comprehensive report
    results = analyzer.generate_explanation_report(sample_text)

    # Additional analyses
    print("\n6. Semantic similarity analysis...")
    test_texts = [
        "Psoriasis treatment with topical corticosteroids",
        "Eczema management using moisturizers",
        "Melanoma diagnosis and surgical intervention"
    ]

    similarity_matrix = analyzer.semantic_similarity_analysis(test_texts)

    print("\n" + "=" * 50)
    print("XAI ANALYSIS COMPLETE")
    print("=" * 50)

    return results

if __name__ == "__main__":
    main()
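
A side note on attention_rollout: the implementation above chains raw layer-averaged attention maps, whereas the original rollout formulation (Abnar & Zuidema, 2020) also folds the residual connection into each layer's map before multiplying. A hedged variant, assuming the same attention_matrices list computed inside the method:

import numpy as np

def rollout_with_residual(attention_matrices):
    """Attention rollout with the residual connection folded in:
    each layer's map is averaged with the identity and renormalized
    so rows stay stochastic before being chained across layers."""
    n = attention_matrices[0].shape[-1]
    result = np.eye(n)
    for a in attention_matrices:
        a = 0.5 * a + 0.5 * np.eye(n)          # account for the residual path
        a = a / a.sum(axis=-1, keepdims=True)  # renormalize rows
        result = a @ result
    return result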