import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import time
import spaces

# Model configuration
MODEL_NAME = "krishna195/medgemma-anatomy-v1.2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Load model and tokenizer
def load_model():
    """Load model with 4-bit quantization for efficiency"""
    print(f"Loading model: {MODEL_NAME}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    if DEVICE == "cuda":
        # 4-bit NF4 quantization for the Spaces GPU
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        # CPU fallback
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )

    model.eval()
    print(f"Model loaded on {DEVICE}")
    return model, tokenizer


# Initialize model
print("Initializing MedGemma...")
model, tokenizer = load_model()


@spaces.GPU(duration=60)
def generate_response(question, max_tokens=512, temperature=0.7, top_p=0.9):
    """
    Generate a medical response for a given question.

    Args:
        question: Medical question
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (0.1-1.0)
        top_p: Nucleus sampling parameter
    """
    try:
        if not question.strip():
            # This function is a generator, so the message must be yielded, not returned
            yield "āš ļø Please enter a medical question."
            return

        # Show processing message
        yield "šŸ”„ **Processing your question...**\n\nGenerating response, please wait..."

        # Format prompt with the Gemma chat-template turn markers
        prompt = f"<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\n"

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=top_p,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
            )
        generation_time = time.time() - start_time

        # Decode response, keeping special tokens so the turn markers can be split on
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract the model's turn
        if "<start_of_turn>model" in full_output:
            response = full_output.split("<start_of_turn>model")[-1]
            response = response.split("<end_of_turn>")[0].strip()
        else:
            response = full_output.strip()

        # Add metadata
        tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1]
        tokens_per_sec = tokens_generated / generation_time if generation_time > 0 else 0
        metadata = (
            f"\n\n---\nāœ… *Generated in {generation_time:.2f}s "
            f"({tokens_per_sec:.1f} tokens/sec) | Device: {DEVICE.upper()}*"
        )

        yield response + metadata

    except Exception as e:
        error_msg = (
            f"āŒ **Error occurred:**\n\n```\n{e}\n```\n\n"
            "Please try again or contact support if the issue persists."
        )
        yield error_msg


# Example questions
examples = [
    ["A 28-year-old athlete presents with shoulder pain after a direct blow. He has severe pain, inability to abduct his arm, and a palpable step deformity over the lateral clavicle. What is your diagnostic approach?"],
    ["A patient presents with weakness in finger abduction and adduction, along with atrophy of the hypothenar eminence. Which nerve is likely injured?"],
    ["What is the anatomical snuffbox and its clinical significance?"],
    ["A 65-year-old woman falls on her outstretched hand and presents with a 'dinner fork' deformity. What is the likely diagnosis and immediate management?"],
    ["Explain the boundaries and contents of the femoral triangle."],
    ["A patient has foot drop after a knee injury. What nerve is most likely damaged and how would you confirm it?"],
]

# Custom CSS
css = """
#warning {background-color: #FFCCCB; padding: 10px; border-radius: 5px; margin-bottom: 10px;}
.generate-btn {background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white;}
footer {visibility: hidden;}
#output-box {min-height: 200px; border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px;}
"""

# Build Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # šŸ„ MedGemma Anatomy Assistant v1.2

        Fine-tuned medical AI assistant specialized in **anatomical and clinical reasoning**.

        ### Focus Areas:
        - Clinical anatomy
        - Orthopedic injuries
        - Neurological assessments
        - Diagnostic approaches
        - Management protocols
        """
    )

    gr.HTML(
        """
        <div id="warning">
āš ļø MEDICAL DISCLAIMER: This AI model is for educational and reference purposes only. It is NOT intended for clinical decision-making, patient diagnosis, or treatment planning without professional medical oversight. Always consult qualified healthcare professionals for medical advice.
""" ) with gr.Row(): with gr.Column(scale=2): question_input = gr.Textbox( label="Medical Question", placeholder="Ask a question about anatomy, injuries, or clinical presentations...", lines=4 ) with gr.Accordion("Advanced Settings", open=False): max_tokens = gr.Slider( minimum=128, maximum=1024, value=512, step=64, label="Max Tokens", info="Maximum length of response" ) temperature = gr.Slider( minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Higher = more creative, Lower = more focused" ) top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P", info="Nucleus sampling parameter" ) generate_btn = gr.Button("šŸš€ Generate Response", variant="primary", elem_classes="generate-btn") clear_btn = gr.ClearButton([question_input], value="šŸ—‘ļø Clear") with gr.Column(scale=3): output = gr.Markdown( label="Response", value="*Your medical answer will appear here...*", elem_id="output-box" ) with gr.Row(): gr.Examples( examples=examples, inputs=question_input, label="šŸ“‹ Example Questions - Click to try" ) # Event handlers generate_btn.click( fn=generate_response, inputs=[question_input, max_tokens, temperature, top_p], outputs=output, show_progress=True ) question_input.submit( fn=generate_response, inputs=[question_input, max_tokens, temperature, top_p], outputs=output, show_progress=True ) gr.Markdown( """ --- ### About This Model **Base Model:** google/medgemma-4b-it **Training Data:** 183 medical Q&A pairs **Method:** LoRA fine-tuning **Hardware:** Google Colab T4 GPU šŸ“š [Model Card](https://huggingface.co/krishna195/medgemma-anatomy-v1.2) | šŸ’» [GitHub](https://github.com/krishna195) | šŸ¤— [Hugging Face](https://huggingface.co/krishna195) Built with ā¤ļø for medical education """ ) # Launch if __name__ == "__main__": demo.queue(max_size=10) demo.launch()