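"""Gradio demo for abaryan/BioXP-0.5B-MedMCQA.

Answers medical questions (multiple-choice or free-text) in a fixed
<answer>/<reasoning> format and can pull random questions from the
MedMCQA validation split.
"""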
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re

SYSTEM_PROMPT = """
You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 2 sentences, without repeating yourself.
Respond in the following format:
<answer>
[correct answer]
</answer>
<reasoning>
[explain why the selected option is correct]
</reasoning>
"""

model_name = "abaryan/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("openlifescienceai/medmcqa")
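# Fields used from each MedMCQA record: 'question', the four options
# 'opa'..'opd', 'cop' (index of the correct option, 0-3), and 'exp'
# (an expert explanation, not present on every row).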
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

def get_random_question():
    """Get a random question from the MedMCQA validation split."""
    index = random.randint(0, len(dataset['validation']) - 1)
    question_data = dataset['validation'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option (0-3)
        question_data.get('exp', None)   # Explanation
    )

def predict(question: str, option_a: str = "", option_b: str = "", option_c: str = "", option_d: str = "",
            correct_option: int = None, explanation: str = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
    # Determine if this is an MCQ by checking if any option is provided
    is_mcq = any(opt.strip() for opt in [option_a, option_b, option_c, option_d])
    system_prompt = SYSTEM_PROMPT
    if is_mcq:
        options = []
        if option_a.strip(): options.append(f"A. {option_a}")
        if option_b.strip(): options.append(f"B. {option_b}")
        if option_c.strip(): options.append(f"C. {option_c}")
        if option_d.strip(): options.append(f"D. {option_d}")
        formatted_question = f"Question: {question}\n\nOptions:\n" + "\n".join(options)
    else:
        # Format regular (free-text) question
        formatted_question = f"Question: {question}"
    prompt = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': formatted_question}
    ]
    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,  # Required; temperature and top_p are ignored under greedy decoding
            temperature=temperature,
            top_p=top_p,
        )
    # Keep only the newly generated tokens (drop the echoed prompt)
    generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # Clean up the response: replace the <answer>/<reasoning> tags with readable labels.
    # The answer pattern uses (.*?) rather than a single letter so free-text answers are cleaned too.
    cleaned_response = model_response
    cleaned_response = re.sub(r'<answer>\s*(.*?)\s*</answer>', r'Answer: \1', cleaned_response, flags=re.IGNORECASE | re.DOTALL)
    cleaned_response = re.sub(r'<reasoning>\s*(.*?)\s*</reasoning>', r'Reasoning:\n\1', cleaned_response, flags=re.IGNORECASE | re.DOTALL)
    output = cleaned_response
    # Optional evaluation against the ground truth (currently disabled).
    # Note the int() cast: the hidden gr.Number component delivers a float.
    # if is_mcq and correct_option is not None:
    #     correct_letter = chr(65 + int(correct_option))
    #     answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE)
    #     model_answer = answer_match.group(1).upper() if answer_match else "Not found"
    #     is_correct = model_answer == correct_letter
    #     output += "\n\n---\nEvaluation:\n"
    #     output += f"Correct Answer: {correct_letter}\n"
    #     output += f"Model's Answer: {model_answer}\n"
    #     output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
    #     if explanation:
    #         output += f"\nExpert Explanation:\n{explanation}"
    return output
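
# Quick local smoke test (hypothetical example question; uncomment to run
# outside the Gradio UI, after the model weights have downloaded):
# print(predict(
#     "Which vitamin deficiency causes scurvy?",
#     "Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D",
# ))
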
with gr.Blocks(
    title="BioXP Medical MCQ Assistant",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="slate",
        radius_size="md",
        font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
    )
) as demo:
    gr.Markdown("""
    # BioXP Medical MCQ Assistant
    A specialized AI assistant for medical multiple-choice questions.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Medical Question",
                placeholder="Enter your medical question here...",
                lines=3,
                interactive=True,
                elem_classes=["mobile-input"]
            )
            with gr.Accordion("Options", open=True):
                option_a = gr.Textbox(
                    label="Option A",
                    placeholder="Enter option A...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
                option_b = gr.Textbox(
                    label="Option B",
                    placeholder="Enter option B...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
                option_c = gr.Textbox(
                    label="Option C",
                    placeholder="Enter option C...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
                option_d = gr.Textbox(
                    label="Option D",
                    placeholder="Enter option D...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
with gr.Accordion("Advanced Settings", open=False): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
temperature = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.6, | |
step=0.1, | |
label="Temperature", | |
info="Higher = more creative, Lower = more focused" | |
) | |
with gr.Column(scale=1): | |
top_p = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.9, | |
step=0.1, | |
label="Top P", | |
info="Controls response diversity" | |
) | |
max_tokens = gr.Slider( | |
minimum=50, | |
maximum=512, | |
value=256, | |
step=32, | |
label="Max Response Length", | |
info="Maximum length of the response" | |
) | |
            # Hidden fields populated by the "Random Question" button
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)
            with gr.Row():
                predict_btn = gr.Button("Get Answer", variant="primary", size="lg", elem_classes=["mobile-button"])
                random_btn = gr.Button("Random Question", variant="secondary", size="lg", elem_classes=["mobile-button"])
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Model's Response",
                lines=12,
                elem_classes=["response-box", "mobile-output"]
            )
    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens
        ],
        outputs=output
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
    )
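    # get_random_question returns seven values, matched by position to the
    # seven output components above (including the two hidden fields).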
    gr.HTML("""
    <style>
    .container {
        max-width: 100%;
        padding: 0.5rem;
    }
    /* Input styling */
    .mobile-input textarea {
        font-size: 1rem;
        padding: 0.75rem;
        border-radius: 0.5rem;
        min-height: 2.5rem;
    }
    /* Button styling */
    .mobile-button {
        width: 100%;
        margin: 0.5rem 0;
        padding: 0.75rem;
        font-size: 1rem;
        font-weight: 500;
    }
    .response-box {
        font-family: 'Inter', sans-serif;
        line-height: 1.6;
    }
    .response-box textarea {
        font-size: 1rem;
        padding: 1rem;
        border-radius: 0.5rem;
    }
    /* Mobile-specific adjustments */
    @media (max-width: 768px) {
        .gr-form {
            padding: 0.75rem;
        }
        .gr-box {
            margin: 0.5rem 0;
        }
        .gr-button {
            min-height: 2.5rem;
        }
        .gr-accordion {
            margin: 0.5rem 0;
        }
        .gr-input {
            margin-bottom: 0.5rem;
        }
    }
    /* Dark mode support */
    @media (prefers-color-scheme: dark) {
        .gr-box {
            background-color: #1a1a1a;
        }
        .mobile-input textarea,
        .response-box textarea {
            background-color: #2a2a2a;
            color: #ffffff;
        }
    }
    </style>
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch(share=False)
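    # Note: on Hugging Face Spaces the app is served automatically; when run
    # locally, share=True would additionally create a temporary public link.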