import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re
from typing import Optional
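# System prompt instructing the model to wrap its output in tags that the
# cleanup step in predict() converts into readable "Answer:"/"Reasoning:" labels.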
SYSTEM_PROMPT = """
You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 2 sentences without repeating.
Respond in the following format:
[correct answer]
[explain why the selected option is correct]
"""
model_name = "abaryan/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
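# MedMCQA supplies the questions served by the "Random Question" button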
dataset = load_dataset("openlifescienceai/medmcqa")
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
def get_random_question():
"""Get a random question from the dataset"""
index = random.randint(0, len(dataset['validation']) - 1)
question_data = dataset['validation'][index]
return (
question_data['question'],
question_data['opa'],
question_data['opb'],
question_data['opc'],
question_data['opd'],
question_data.get('cop', None), # Correct option (0-3)
question_data.get('exp', None) # Explanation
)
def predict(question: str, option_a: str = "", option_b: str = "", option_c: str = "", option_d: str = "",
            correct_option: Optional[int] = None, explanation: Optional[str] = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
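    """Generate a model response for a medical question.

    When any option is provided, the question is treated as an MCQ. correct_option
    and explanation are accepted for the (currently disabled) evaluation block below.
    """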
# Determine if this is an MCQ by checking if any option is provided
is_mcq = any(opt.strip() for opt in [option_a, option_b, option_c, option_d])
if is_mcq:
options = []
if option_a.strip(): options.append(f"A. {option_a}")
if option_b.strip(): options.append(f"B. {option_b}")
if option_c.strip(): options.append(f"C. {option_c}")
if option_d.strip(): options.append(f"D. {option_d}")
formatted_question = f"Question: {question}\n\nOptions:\n" + "\n".join(options)
system_prompt = SYSTEM_PROMPT
else:
# Format regular question
formatted_question = f"Question: {question}"
system_prompt = SYSTEM_PROMPT
prompt = [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': formatted_question}
]
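    # Render the chat messages into a single prompt string via the model's chat template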
text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=temperature,
            top_p=top_p,
        )
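    # Drop the prompt tokens so only the newly generated text is decoded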
generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # Clean up the response by converting the <answer>/<reasoning> tags
    # requested in SYSTEM_PROMPT into readable labels
    cleaned_response = model_response
    cleaned_response = re.sub(r'<answer>\s*([A-D])\s*</answer>', r'Answer: \1', cleaned_response, flags=re.IGNORECASE)
    cleaned_response = re.sub(r'<reasoning>\s*(.*?)\s*</reasoning>', r'Reasoning:\n\1', cleaned_response, flags=re.IGNORECASE | re.DOTALL)
# Format output with evaluation if available (only for MCQs)
output = cleaned_response
# if is_mcq and correct_option is not None:
# correct_letter = chr(65 + correct_option)
# answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE)
# model_answer = answer_match.group(1).upper() if answer_match else "Not found"
# is_correct = model_answer == correct_letter
# output += f"\n\n---\nEvaluation:\n"
# output += f"Correct Answer: {correct_letter}\n"
# output += f"Model's Answer: {model_answer}\n"
# output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
# if explanation:
# output += f"\nExpert Explanation:\n{explanation}"
return output
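# Build the Gradio interface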
with gr.Blocks(
title="BioXP Medical MCQ Assistant",
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="blue",
neutral_hue="slate",
radius_size="md",
font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
)
) as demo:
gr.Markdown("""
# BioXP Medical MCQ Assistant
A specialized AI assistant for medical multiple-choice questions.
""")
with gr.Row():
with gr.Column(scale=1):
question = gr.Textbox(
label="Medical Question",
placeholder="Enter your medical question here...",
lines=3,
interactive=True,
elem_classes=["mobile-input"]
)
with gr.Accordion("Options", open=True):
option_a = gr.Textbox(
label="Option A",
placeholder="Enter option A...",
interactive=True,
elem_classes=["mobile-input"]
)
option_b = gr.Textbox(
label="Option B",
placeholder="Enter option B...",
interactive=True,
elem_classes=["mobile-input"]
)
option_c = gr.Textbox(
label="Option C",
placeholder="Enter option C...",
interactive=True,
elem_classes=["mobile-input"]
)
option_d = gr.Textbox(
label="Option D",
placeholder="Enter option D...",
interactive=True,
elem_classes=["mobile-input"]
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
with gr.Column(scale=1):
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.6,
step=0.1,
label="Temperature",
info="Higher = more creative, Lower = more focused"
)
with gr.Column(scale=1):
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.1,
label="Top P",
info="Controls response diversity"
)
max_tokens = gr.Slider(
minimum=50,
maximum=512,
value=256,
step=32,
label="Max Response Length",
info="Maximum length of the response"
)
# Hidden fields
correct_option = gr.Number(visible=False)
expert_explanation = gr.Textbox(visible=False)
with gr.Row():
predict_btn = gr.Button("Get Answer", variant="primary", size="lg", elem_classes=["mobile-button"])
random_btn = gr.Button("Random Question", variant="secondary", size="lg", elem_classes=["mobile-button"])
with gr.Column(scale=1):
output = gr.Textbox(
label="Model's Response",
lines=12,
elem_classes=["response-box", "mobile-output"]
)
# Set up button actions
predict_btn.click(
fn=predict,
inputs=[
question, option_a, option_b, option_c, option_d,
correct_option, expert_explanation,
temperature, top_p, max_tokens
],
outputs=output
)
random_btn.click(
fn=get_random_question,
inputs=[],
outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
)
gr.HTML("""
""")
# Launch the app
if __name__ == "__main__":
demo.launch(share=False)