# Hugging Face Spaces app: medical-symptom Q&A assistant.
# Serves a fine-tuned DeepSeek-R1-distill (LoRA adapter on a 4-bit base)
# behind a Gradio Blocks interface.
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def load_model():
    """Download and assemble the inference stack.

    Returns:
        A ``(model, tokenizer)`` pair: the 4-bit quantized base model with
        the fine-tuned PEFT/LoRA adapter attached, plus its tokenizer.
    """
    base_repo = "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
    adapter_repo = "LAWSA07/medical_fine_tuned_deepseekR1"

    print("Loading model...")
    # Base checkpoint ships pre-quantized to 4-bit; let HF place it on the
    # available device(s) and pick the dtype automatically.
    base = AutoModelForCausalLM.from_pretrained(
        base_repo,
        device_map="auto",
        torch_dtype="auto",
    )
    # Attach the fine-tuned adapter weights on top of the frozen base.
    tuned = PeftModel.from_pretrained(base, adapter_repo)
    # Tokenizer comes from the same base checkpoint as the weights.
    tok = AutoTokenizer.from_pretrained(base_repo)
    print("Model loaded successfully!")
    return tuned, tok

# Load once at import time so every request reuses the same weights.
model, tokenizer = load_model()
def generate_response(question, max_length=512, temperature=0.7):
    """Generate a response to a medical question.

    Args:
        question: Free-text symptom description or question from the user.
        max_length: Maximum number of NEW tokens to generate (not total
            sequence length — passed as ``max_new_tokens``).
        temperature: Sampling temperature; higher values are more random.

    Returns:
        The generated answer with the prompt stripped, or a short
        instruction string when the input is empty/whitespace.
    """
    if not question.strip():
        return "Please enter a question about your symptoms."

    # Same "Question: ... Answer:" template the adapter was tuned on
    # (presumably — TODO confirm against the fine-tuning script).
    prompt = f"Question: {question}\nAnswer:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model.generate(
            inputs.input_ids,
            # Pass the mask explicitly: with pad_token_id == eos_token_id,
            # generate() cannot reliably infer it and warns.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.95,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Slice off the prompt tokens so only newly generated text is decoded.
    response = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return response.strip()
# Create the Gradio interface: question input + generation controls on the
# left, model response (read-only) on the right.
with gr.Blocks(title="Medical Symptoms Assistant") as app:
    gr.Markdown("""
    # Medical Symptoms Assistant
    Ask questions about your symptoms and get AI-generated responses.
    **Note**: This is for informational purposes only and not a substitute for professional medical advice.
    """)
    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Describe your symptoms or ask a medical question",
                placeholder="I've been experiencing headaches and dizziness for the past week...",
                lines=4
            )
            with gr.Row():
                submit_btn = gr.Button("Get Response", variant="primary")
                clear_btn = gr.Button("Clear")
            with gr.Accordion("Advanced Options", open=False):
                # These feed generate_response(max_length, temperature).
                max_length_slider = gr.Slider(
                    minimum=64, maximum=1024, value=512, step=32,
                    label="Maximum Response Length"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.7, step=0.1,
                    label="Temperature (Creativity)"
                )
        with gr.Column():
            response_output = gr.Textbox(
                label="AI Response",
                lines=12,
                interactive=False
            )
            disclaimer = gr.Markdown("""
            **Disclaimer**: This AI assistant provides information based on its training.
            It is not a substitute for professional medical advice, diagnosis, or treatment.
            Always seek the advice of your physician or other qualified health provider with any
            questions you may have regarding a medical condition.
            """)
    # Set up event handlers: button click and Enter in the textbox both run
    # the same generation call.
    submit_btn.click(
        generate_response,
        inputs=[question_input, max_length_slider, temperature_slider],
        outputs=response_output
    )
    question_input.submit(
        generate_response,
        inputs=[question_input, max_length_slider, temperature_slider],
        outputs=response_output
    )
    # Clear both boxes; use "" (not None) so the output Textbox is
    # explicitly emptied rather than ambiguously unset.
    clear_btn.click(
        lambda: ("", ""),
        inputs=None,
        outputs=[question_input, response_output]
    )
# Launch the app. NOTE: torch is imported at the top of the file — importing
# it only inside this guard would leave `torch` undefined in
# generate_response() whenever this module is imported instead of executed.
if __name__ == "__main__":
    app.launch(share=True)  # set share=False if you don't want a public link