Menna-Ahmed's picture
Update app.py
2536648 verified
import gradio as gr
from transformers import BlipProcessor, BlipForQuestionAnswering, MarianMTModel, MarianTokenizer
from PIL import Image, ImageDraw, ImageFont
import torch, uuid, os
from datetime import datetime
# ✅ Load BLIP model
blip_model = BlipForQuestionAnswering.from_pretrained("sharawy53/diploma")
processor = BlipProcessor.from_pretrained("sharawy53/diploma")
# ✅ Load translation models
ar_en_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
ar_en_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
en_ar_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
en_ar_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
# ✅ Manual Arabic medical term dictionary
medical_terms = {
"chest x-ray": "أشعة سينية للصدر",
"x-ray": "أشعة سينية",
"ct scan": "تصوير مقطعي محوسب",
"mri": "تصوير بالرنين المغناطيسي",
"ultrasound": "تصوير بالموجات فوق الصوتية",
"normal": "طبيعي",
"abnormal": "غير طبيعي",
"brain": "الدماغ",
"fracture": "كسر",
"no abnormality detected": "لا توجد شذوذات",
"left lung": "الرئة اليسرى",
"right lung": "الرئة اليمنى",
"x - ray": "أشعة سينية",
"chest x - ray": "أشعة سينية",
"cardiomegaly": "تضخم القلب"
,"ct":"المسح المقطعي",
"CT":"المسح المقطعي",
"MRI": "تصوير بالرنين المغناطيسي"
}
# تصحيحات لبعض الأسئلة المعروفة
question_fixes = {
"what is the unnatural in this image?": "ما الشيء غير الطبيعي في هذه الصورة؟",
"what is abnormal in this image?": "ما الشيء غير الطبيعي في هذه الصورة؟",
"is this image normal?": "هل هذه الصورة طبيعية؟"
}
# ✅ Translation utilities
def translate_ar_to_en(text):
inputs = ar_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
outputs = ar_en_model.generate(**inputs)
return ar_en_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
def translate_en_to_ar(text):
if text.lower().strip() in question_fixes:
return question_fixes[text.lower().strip()]
inputs = en_ar_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
outputs = en_ar_model.generate(**inputs)
translated = en_ar_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
if "القرآن" in translated or "يتفاعل" in translated:
return question_fixes.get(text.lower().strip(), "سؤال غير مفهوم")
return translated
def translate_answer_medical(answer_en):
return medical_terms.get(answer_en.lower().strip(), translate_en_to_ar(answer_en))
# ✅ Arabic font helper
def get_font(size=22):
try:
return ImageFont.truetype("Amiri-Regular.ttf", size)
except:
return ImageFont.load_default()
# ✅ Report generation function
def generate_report_image(image, question_ar, question_en, answer_ar, answer_en):
width, height = 1000, 700
background = Image.new("RGB", (width, height), color="white")
draw = ImageDraw.Draw(background)
font = get_font(22)
font_bold = get_font(26)
draw.text((40, 20), "📋 Medical VQA Screenshot Report", font=font_bold, fill="black")
draw.text((40, 60), f"🕓 Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", font=font, fill="gray")
img_resized = image.resize((300, 300))
background.paste(img_resized, (50, 110))
x, y = 380, 110
spacing = 60
lines = [
f" السؤال بالعربية:\n{question_ar}",
f" Question in English:\n{question_en}",
f" الإجابة بالعربية:\n{answer_ar}",
f" Answer in English:\n{answer_en}"
]
for line in lines:
for subline in line.split("\n"):
draw.text((x, y), subline, font=font, fill="black")
y += spacing
file_name = f"report_{uuid.uuid4().hex[:8]}.png"
background.save(file_name)
return file_name
# ✅ Main VQA function
def vqa_multilingual(image, question):
if not image or not question.strip():
return "يرجى رفع صورة وكتابة سؤال.", "", "", "", None
is_arabic = any('\u0600' <= c <= '\u06FF' for c in question)
question_ar = question.strip() if is_arabic else translate_en_to_ar(question)
question_en = translate_ar_to_en(question) if is_arabic else question.strip()
inputs = processor(image, question_en, return_tensors="pt")
with torch.no_grad():
output = blip_model.generate(**inputs)
answer_en = processor.decode(output[0], skip_special_tokens=True).strip()
answer_ar = translate_answer_medical(answer_en)
report_image_path = generate_report_image(image, question_ar, question_en, answer_ar, answer_en)
return (
question_ar,
question_en,
answer_ar,
answer_en,
report_image_path
)
# ✅ Gradio interface
gr.Interface(
fn=vqa_multilingual,
inputs=[
gr.Image(type="pil", label="📷 Upload Medical Image"),
gr.Textbox(label="💬 Your Question (Arabic or English)")
],
outputs=[
gr.Textbox(label="🟠 Arabic Question"),
gr.Textbox(label="🟢 English Question"),
gr.Textbox(label="🟠 Arabic Answer"),
gr.Textbox(label="🟢 English Answer"),
gr.Image(type="filepath", label="📸 Report Screenshot")
],
title="🧠 Bilingual Medical VQA",
description="Upload an X-ray or medical image and ask a question in Arabic or English. Get bilingual answers and an image-based report."
).launch()