# app.py — Bilingual (Arabic/English) Medical Visual Question Answering Gradio demo.
# Provenance: Hugging Face Space file by Menna-Ahmed, "Update app.py"
# (commit 2536648 verified, 5.8 kB). Web-page scrape header converted to comments.
import gradio as gr
from transformers import BlipProcessor, BlipForQuestionAnswering, MarianMTModel, MarianTokenizer
from PIL import Image, ImageDraw, ImageFont
import torch, uuid, os
from datetime import datetime
# ✅ Load BLIP model
# Fine-tuned BLIP VQA checkpoint ("sharawy53/diploma"); the processor handles
# both image preprocessing and question tokenization for the same checkpoint.
blip_model = BlipForQuestionAnswering.from_pretrained("sharawy53/diploma")
processor = BlipProcessor.from_pretrained("sharawy53/diploma")
# ✅ Load translation models
# Two MarianMT pairs: Arabic→English (to feed the English-only BLIP model)
# and English→Arabic (to translate questions/answers back for the user).
ar_en_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
ar_en_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
en_ar_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
en_ar_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
# ✅ Manual Arabic medical term dictionary
# Glossary used by translate_answer_medical(): maps common English BLIP answers
# to curated Arabic medical terminology, bypassing machine translation.
# NOTE(review): lookups are done on answer.lower().strip(), so the uppercase
# "CT" and "MRI" keys below can never match — only "ct"/"mri" are reachable.
medical_terms = {
"chest x-ray": "أشعة سينية للصدر",
"x-ray": "أشعة سينية",
"ct scan": "تصوير مقطعي محوسب",
"mri": "تصوير بالرنين المغناطيسي",
"ultrasound": "تصوير بالموجات فوق الصوتية",
"normal": "طبيعي",
"abnormal": "غير طبيعي",
"brain": "الدماغ",
"fracture": "كسر",
"no abnormality detected": "لا توجد شذوذات",
"left lung": "الرئة اليسرى",
"right lung": "الرئة اليمنى",
"x - ray": "أشعة سينية",
"chest x - ray": "أشعة سينية",
"cardiomegaly": "تضخم القلب"
,"ct":"المسح المقطعي",
"CT":"المسح المقطعي",
"MRI": "تصوير بالرنين المغناطيسي"
}
# Corrections for some known questions: hand-written Arabic translations for
# English questions that the en→ar model is known to mistranslate.
question_fixes = {
"what is the unnatural in this image?": "ما الشيء غير الطبيعي في هذه الصورة؟",
"what is abnormal in this image?": "ما الشيء غير الطبيعي في هذه الصورة؟",
"is this image normal?": "هل هذه الصورة طبيعية؟"
}
# ✅ Translation utilities
def translate_ar_to_en(text):
    """Translate Arabic *text* to English with the Helsinki-NLP ar→en Marian model."""
    batch = ar_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    generated = ar_en_model.generate(**batch)
    decoded = ar_en_tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded.strip()
def translate_en_to_ar(text):
    """Translate English *text* to Arabic.

    Questions listed in `question_fixes` are answered from that table to avoid
    known mistranslations; everything else goes through the Helsinki-NLP
    en→ar Marian model.
    """
    # Normalize once instead of recomputing text.lower().strip() up to three times.
    key = text.lower().strip()
    if key in question_fixes:
        return question_fixes[key]
    inputs = en_ar_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = en_ar_model.generate(**inputs)
    translated = en_ar_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Heuristic guard: these tokens were observed in garbage model output;
    # fall back to the fix table (or a generic "question not understood").
    if "القرآن" in translated or "يتفاعل" in translated:
        return question_fixes.get(key, "سؤال غير مفهوم")
    return translated
def translate_answer_medical(answer_en):
    """Return the Arabic rendering of an English BLIP answer.

    Prefers the curated `medical_terms` glossary; falls back to machine
    translation only on a miss. The original used
    `medical_terms.get(key, translate_en_to_ar(answer_en))`, which evaluated
    the fallback eagerly and therefore ran the translation model even when
    the glossary already had the term.
    """
    key = answer_en.lower().strip()
    if key in medical_terms:
        return medical_terms[key]
    return translate_en_to_ar(answer_en)
# ✅ Arabic font helper
def get_font(size=22):
try:
return ImageFont.truetype("Amiri-Regular.ttf", size)
except:
return ImageFont.load_default()
# ✅ Report generation function
def generate_report_image(image, question_ar, question_en, answer_ar, answer_en):
width, height = 1000, 700
background = Image.new("RGB", (width, height), color="white")
draw = ImageDraw.Draw(background)
font = get_font(22)
font_bold = get_font(26)
draw.text((40, 20), "📋 Medical VQA Screenshot Report", font=font_bold, fill="black")
draw.text((40, 60), f"🕓 Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", font=font, fill="gray")
img_resized = image.resize((300, 300))
background.paste(img_resized, (50, 110))
x, y = 380, 110
spacing = 60
lines = [
f" السؤال بالعربية:\n{question_ar}",
f" Question in English:\n{question_en}",
f" الإجابة بالعربية:\n{answer_ar}",
f" Answer in English:\n{answer_en}"
]
for line in lines:
for subline in line.split("\n"):
draw.text((x, y), subline, font=font, fill="black")
y += spacing
file_name = f"report_{uuid.uuid4().hex[:8]}.png"
background.save(file_name)
return file_name
# ✅ Main VQA function
def vqa_multilingual(image, question):
if not image or not question.strip():
return "يرجى رفع صورة وكتابة سؤال.", "", "", "", None
is_arabic = any('\u0600' <= c <= '\u06FF' for c in question)
question_ar = question.strip() if is_arabic else translate_en_to_ar(question)
question_en = translate_ar_to_en(question) if is_arabic else question.strip()
inputs = processor(image, question_en, return_tensors="pt")
with torch.no_grad():
output = blip_model.generate(**inputs)
answer_en = processor.decode(output[0], skip_special_tokens=True).strip()
answer_ar = translate_answer_medical(answer_en)
report_image_path = generate_report_image(image, question_ar, question_en, answer_ar, answer_en)
return (
question_ar,
question_en,
answer_ar,
answer_en,
report_image_path
)
# ✅ Gradio interface
gr.Interface(
fn=vqa_multilingual,
inputs=[
gr.Image(type="pil", label="📷 Upload Medical Image"),
gr.Textbox(label="💬 Your Question (Arabic or English)")
],
outputs=[
gr.Textbox(label="🟠 Arabic Question"),
gr.Textbox(label="🟢 English Question"),
gr.Textbox(label="🟠 Arabic Answer"),
gr.Textbox(label="🟢 English Answer"),
gr.Image(type="filepath", label="📸 Report Screenshot")
],
title="🧠 Bilingual Medical VQA",
description="Upload an X-ray or medical image and ask a question in Arabic or English. Get bilingual answers and an image-based report."
).launch()