import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re
from typing import Optional
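# System prompt instructing the model to wrap its output in tags that the
# cleanup step in predict() converts into readable "Answer:"/"Reasoning:" labels.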
SYSTEM_PROMPT = """
You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 2 sentences without repeating.
Respond in the following format:
[correct answer]
[explain why the selected option is correct]
"""
model_name = "abaryan/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
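# MedMCQA supplies the questions served by the "Random Question" button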
dataset = load_dataset("openlifescienceai/medmcqa")
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
def get_random_question():
"""Get a random question from the dataset"""
index = random.randint(0, len(dataset['validation']) - 1)
question_data = dataset['validation'][index]
return (
question_data['question'],
question_data['opa'],
question_data['opb'],
question_data['opc'],
question_data['opd'],
question_data.get('cop', None), # Correct option (0-3)
question_data.get('exp', None) # Explanation
)
def predict(question: str, option_a: str = "", option_b: str = "", option_c: str = "", option_d: str = "",
            correct_option: Optional[int] = None, explanation: Optional[str] = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
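    """Generate a model response for a medical question.

    When any option is provided, the question is treated as an MCQ. correct_option
    and explanation are accepted for the (currently disabled) evaluation block below.
    """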
# Determine if this is an MCQ by checking if any option is provided
is_mcq = any(opt.strip() for opt in [option_a, option_b, option_c, option_d])
if is_mcq:
options = []
if option_a.strip(): options.append(f"A. {option_a}")
if option_b.strip(): options.append(f"B. {option_b}")
if option_c.strip(): options.append(f"C. {option_c}")
if option_d.strip(): options.append(f"D. {option_d}")
formatted_question = f"Question: {question}\n\nOptions:\n" + "\n".join(options)
system_prompt = SYSTEM_PROMPT
else:
# Format regular question
formatted_question = f"Question: {question}"
system_prompt = SYSTEM_PROMPT
prompt = [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': formatted_question}
]
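    # Render the chat messages into a single prompt string via the model's chat template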
text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=temperature,
            top_p=top_p,
        )
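    # Drop the prompt tokens so only the newly generated text is decoded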
generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # Clean up the response by converting the <answer>/<reasoning> tags
    # requested in SYSTEM_PROMPT into readable labels
    cleaned_response = model_response
    cleaned_response = re.sub(r'<answer>\s*([A-D])\s*</answer>', r'Answer: \1', cleaned_response, flags=re.IGNORECASE)
    cleaned_response = re.sub(r'<reasoning>\s*(.*?)\s*</reasoning>', r'Reasoning:\n\1', cleaned_response, flags=re.IGNORECASE | re.DOTALL)
# Format output with evaluation if available (only for MCQs)
output = cleaned_response
# if is_mcq and correct_option is not None:
# correct_letter = chr(65 + correct_option)
# answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE)
# model_answer = answer_match.group(1).upper() if answer_match else "Not found"
# is_correct = model_answer == correct_letter
# output += f"\n\n---\nEvaluation:\n"
# output += f"Correct Answer: {correct_letter}\n"
# output += f"Model's Answer: {model_answer}\n"
# output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
# if explanation:
# output += f"\nExpert Explanation:\n{explanation}"
return output
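# Build the Gradio interface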
with gr.Blocks(
title="BioXP Medical MCQ Assistant",
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="blue",
neutral_hue="slate",
radius_size="md",
font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
)
) as demo:
gr.Markdown("""
# BioXP Medical MCQ Assistant
A specialized AI assistant for medical multiple-choice questions.
""")
with gr.Row():
with gr.Column(scale=1):
question = gr.Textbox(
label="Medical Question",
placeholder="Enter your medical question here...",
lines=3,
interactive=True,
elem_classes=["mobile-input"]
)
with gr.Accordion("Options", open=True):
option_a = gr.Textbox(
label="Option A",
placeholder="Enter option A...",
interactive=True,
elem_classes=["mobile-input"]
)
option_b = gr.Textbox(
label="Option B",
placeholder="Enter option B...",
interactive=True,
elem_classes=["mobile-input"]
)
option_c = gr.Textbox(
label="Option C",
placeholder="Enter option C...",
interactive=True,
elem_classes=["mobile-input"]
)
option_d = gr.Textbox(
label="Option D",
placeholder="Enter option D...",
interactive=True,
elem_classes=["mobile-input"]
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
with gr.Column(scale=1):
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.6,
step=0.1,
label="Temperature",
info="Higher = more creative, Lower = more focused"
)
with gr.Column(scale=1):
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.1,
label="Top P",
info="Controls response diversity"
)
max_tokens = gr.Slider(
minimum=50,
maximum=512,
value=256,
step=32,
label="Max Response Length",
info="Maximum length of the response"
)
# Hidden fields
correct_option = gr.Number(visible=False)
expert_explanation = gr.Textbox(visible=False)
with gr.Row():
predict_btn = gr.Button("Get Answer", variant="primary", size="lg", elem_classes=["mobile-button"])
random_btn = gr.Button("Random Question", variant="secondary", size="lg", elem_classes=["mobile-button"])
with gr.Column(scale=1):
output = gr.Textbox(
label="Model's Response",
lines=12,
elem_classes=["response-box", "mobile-output"]
)
# Set up button actions
predict_btn.click(
fn=predict,
inputs=[
question, option_a, option_b, option_c, option_d,
correct_option, expert_explanation,
temperature, top_p, max_tokens
],
outputs=output
)
random_btn.click(
fn=get_random_question,
inputs=[],
outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
)
gr.HTML("""
""")
# Launch the app
if __name__ == "__main__":
demo.launch(share=False)