drug-interaction-predictor / prediction_service.py
oussamaor's picture
Create prediction_service.py
65ce4be verified
import os
import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Tuple
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
import sqlite3
import gradio as gr
# Initialize FastAPI app
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
# Create model directory if it doesn't exist
os.makedirs('model', exist_ok=True)
# Initialize encoders
def initialize_encoders():
try:
# Try to load existing encoders
drug_encoder = joblib.load('model/drug_encoder.joblib')
interaction_encoder = joblib.load('model/interaction_encoder.joblib')
except FileNotFoundError:
print("Encoders not found. Creating new encoders...")
# Create new encoders
drug_encoder = LabelEncoder()
interaction_encoder = LabelEncoder()
# Connect to SQLite database
conn = sqlite3.connect('event.db')
# Get unique drugs from the drug table
drugs_df = pd.read_sql('SELECT DISTINCT name FROM drug', conn)
drug_names = drugs_df['name'].tolist()
# Get unique interaction types
interaction_types = ['no_interaction', 'mild', 'moderate', 'severe']
# Fit encoders
drug_encoder.fit(drug_names)
interaction_encoder.fit(interaction_types)
# Save encoders
joblib.dump(drug_encoder, 'model/drug_encoder.joblib')
joblib.dump(interaction_encoder, 'model/interaction_encoder.joblib')
print("Encoders created and saved successfully.")
# Close database connection
conn.close()
return drug_encoder, interaction_encoder
# Load the model and encoders with error handling
try:
model = tf.keras.models.load_model('model/ddi_model.h5')
drug_encoder, interaction_encoder = initialize_encoders()
print("Model and encoders loaded successfully")
except Exception as e:
print(f"Error loading model or encoders: {str(e)}")
raise Exception("Failed to load model or encoders. Please ensure all model files are present.")
class DrugPair(BaseModel):
drug1: str
drug2: str
class PredictionResponse(BaseModel):
has_interaction: bool
severity: Optional[str] = None
description: Optional[str] = None
confidence: float
warning: Optional[str] = None
def get_drug_embedding(drug_name: str) -> Tuple[np.ndarray, bool]:
"""Get drug embedding from SQLite database. Returns (embedding, is_known_drug)."""
try:
# Connect to SQLite database
conn = sqlite3.connect('event.db')
# Query the drug table for the drug's features
query = '''
SELECT pathway, target, enzyme
FROM drug
WHERE name = ?
'''
drug_data = pd.read_sql(query, conn, params=(drug_name,))
# Create a single feature vector of size 572
feature_vector = np.zeros(572) # VECTOR_SIZE from final.py
if drug_data.empty:
# Drug not found in database, return zero vector
conn.close()
return feature_vector, False
# Combine all features into a single vector
for feature in ['pathway', 'target', 'enzyme']:
if feature in drug_data.columns:
feature_values = drug_data[feature].iloc[0].split('|')
for i, value in enumerate(feature_values):
if i < 572: # VECTOR_SIZE
feature_vector[i] = 1
# Close database connection
conn.close()
return feature_vector, True
except Exception as e:
print(f"Error getting drug embedding: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error getting drug embedding: {str(e)}")
def get_interaction_description(drug1: str, drug2: str) -> Optional[str]:
"""Get interaction description from SQLite database."""
try:
# Connect to SQLite database
conn = sqlite3.connect('event.db')
# Query the event table
query = '''
SELECT interaction
FROM event
WHERE (name1 = ? AND name2 = ?) OR (name1 = ? AND name2 = ?)
'''
interaction_data = pd.read_sql(query, conn, params=(drug1, drug2, drug2, drug1))
# Close database connection
conn.close()
if not interaction_data.empty:
return interaction_data['interaction'].iloc[0]
return None
except Exception as e:
print(f"Error getting interaction description: {str(e)}")
return None
def get_all_interaction_descriptions() -> List[str]:
"""Get all unique interaction descriptions from the database."""
try:
conn = sqlite3.connect('event.db')
descriptions_df = pd.read_sql('SELECT DISTINCT interaction FROM event WHERE interaction IS NOT NULL', conn)
conn.close()
return descriptions_df['interaction'].tolist()
except Exception as e:
print(f"Error getting interaction descriptions: {str(e)}")
return []
def get_severity_from_description(description: str) -> str:
"""Determine severity based on the interaction description."""
description = description.lower()
if any(word in description for word in ['severe', 'serious', 'life-threatening', 'fatal']):
return 'severe'
elif any(word in description for word in ['moderate', 'significant', 'substantial']):
return 'moderate'
elif any(word in description for word in ['mild', 'minor', 'slight']):
return 'mild'
return 'mild' # default to mild if no clear severity indicators
def predict_interaction(drug1: str, drug2: str) -> Dict:
"""Predict interaction between two drugs."""
try:
# Get embeddings for both drugs
drug1_embedding, drug1_known = get_drug_embedding(drug1)
drug2_embedding, drug2_known = get_drug_embedding(drug2)
# Stack embeddings to create input of shape (1, 572, 2)
combined_embedding = np.stack([drug1_embedding, drug2_embedding], axis=1)
combined_embedding = np.expand_dims(combined_embedding, axis=0) # Add batch dimension
# Make prediction
prediction = model.predict(combined_embedding)
# Get the predicted class and confidence
predicted_class = np.argmax(prediction[0])
confidence = float(prediction[0][predicted_class])
# First try to get the interaction from the database
description = get_interaction_description(drug1, drug2)
# If no description in database, use the model's prediction
if description is None:
# Get all possible descriptions
conn = sqlite3.connect('event.db')
descriptions_df = pd.read_sql('SELECT DISTINCT interaction FROM event WHERE interaction IS NOT NULL', conn)
conn.close()
all_descriptions = descriptions_df['interaction'].tolist()
if predicted_class < len(all_descriptions):
description = all_descriptions[predicted_class]
# Determine severity based on the description
severity = get_severity_from_description(description) if description else None
# Determine if there's an interaction
has_interaction = description is not None and confidence > 0.1
# Generate warning if drugs are not in database
warning = None
if not drug1_known and not drug2_known:
warning = f"Both {drug1} and {drug2} are not in our database. Prediction may be less accurate."
elif not drug1_known:
warning = f"{drug1} is not in our database. Prediction may be less accurate."
elif not drug2_known:
warning = f"{drug2} is not in our database. Prediction may be less accurate."
return {
"has_interaction": has_interaction,
"severity": severity,
"description": description,
"confidence": confidence * 100, # Convert to percentage
"warning": warning
}
except Exception as e:
print(f"Error making prediction: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# FastAPI endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict_endpoint(drug_pair: DrugPair):
return predict_interaction(drug_pair.drug1, drug_pair.drug2)
# Gradio interface
def gradio_predict(drug1: str, drug2: str) -> str:
result = predict_interaction(drug1, drug2)
output = []
if result["warning"]:
output.append(f"⚠️ {result['warning']}")
if result["severity"]:
severity_color = {
"severe": "πŸ”΄",
"moderate": "🟠",
"mild": "🟑"
}
output.append(f"{severity_color[result['severity']]} {result['severity'].upper()} Interaction Detected")
output.append(f"Confidence: {result['confidence']:.1f}%")
if result["description"]:
output.append(f"\nDescription:\n{result['description']}")
return "\n".join(output)
# Create Gradio interface
demo = gr.Interface(
fn=gradio_predict,
inputs=[
gr.Textbox(label="Drug 1"),
gr.Textbox(label="Drug 2")
],
outputs=gr.Textbox(label="Prediction Result"),
title="Drug-Drug Interaction Predictor",
description="Enter two drug names to predict their potential interaction.",
examples=[
["Abemaciclib", "Amiodarone"],
["Aspirin", "Warfarin"],
["Lisinopril", "Ibuprofen"]
]
)
# Mount Gradio app
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)