# finetuned_medic / app.py
# Source: Hugging Face Space "LAWSA07/finetuned_medic" (commit a41d940, "Update app.py").
# The lines below were residue from the HF web file viewer ("raw / history / blame / 4.16 kB")
# and have been commented out so the file is valid Python.
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def load_model():
    """Load the 4-bit quantized base model, attach the fine-tuned adapter,
    and return a ``(model, tokenizer)`` pair ready for generation."""
    print("Loading model...")

    base_checkpoint = "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
    adapter_checkpoint = "LAWSA07/medical_fine_tuned_deepseekR1"

    # The base checkpoint ships pre-quantized (bnb 4-bit); let HF pick
    # device placement and dtype automatically.
    base = AutoModelForCausalLM.from_pretrained(
        base_checkpoint,
        device_map="auto",
        torch_dtype="auto",
    )

    # Overlay the fine-tuned PEFT adapter weights on top of the base model.
    tuned = PeftModel.from_pretrained(base, adapter_checkpoint)

    # Tokenizer comes from the base checkpoint.
    tok = AutoTokenizer.from_pretrained(base_checkpoint)

    print("Model loaded successfully!")
    return tuned, tok
# Load model and tokenizer once at import time so the Gradio event
# handlers below reuse the same instances for every request.
model, tokenizer = load_model()
def generate_response(question, max_length=512, temperature=0.7):
    """Generate a response to a medical question.

    Args:
        question: Free-text symptom description or medical question.
        max_length: Maximum number of NEW tokens to generate (the name is
            historical; it is passed as ``max_new_tokens``).
        temperature: Sampling temperature; higher is more creative.

    Returns:
        The generated answer with the prompt stripped, or a prompt-the-user
        message when ``question`` is empty/whitespace.
    """
    if not question.strip():
        return "Please enter a question about your symptoms."

    # Simple question/answer prompt format used during fine-tuning
    # (presumably — TODO confirm against the training script).
    prompt = f"Question: {question}\nAnswer:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Pass the attention mask explicitly: since pad_token_id is set
            # to eos_token_id below, transformers cannot infer which tokens
            # are padding and emits a warning / may mis-handle the prompt.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.95,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens: slice off the prompt length.
    response = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return response.strip()
# Create Gradio interface.
# Layout: left column = question input, submit/clear buttons, and advanced
# generation sliders; right column = read-only response box + disclaimer.
with gr.Blocks(title="Medical Symptoms Assistant") as app:
    gr.Markdown("""
    # Medical Symptoms Assistant
    Ask questions about your symptoms and get AI-generated responses.
    **Note**: This is for informational purposes only and not a substitute for professional medical advice.
    """)
    with gr.Row():
        with gr.Column():
            # Free-text symptom description entered by the user.
            question_input = gr.Textbox(
                label="Describe your symptoms or ask a medical question",
                placeholder="I've been experiencing headaches and dizziness for the past week...",
                lines=4
            )
            with gr.Row():
                submit_btn = gr.Button("Get Response", variant="primary")
                clear_btn = gr.Button("Clear")
            # Generation knobs, collapsed by default.
            with gr.Accordion("Advanced Options", open=False):
                # Mapped to generate_response's max_length (max_new_tokens).
                max_length_slider = gr.Slider(
                    minimum=64, maximum=1024, value=512, step=32,
                    label="Maximum Response Length"
                )
                # Mapped to generate_response's temperature.
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.7, step=0.1,
                    label="Temperature (Creativity)"
                )
        with gr.Column():
            # Read-only box that receives the model's answer.
            response_output = gr.Textbox(
                label="AI Response",
                lines=12,
                interactive=False
            )
            disclaimer = gr.Markdown("""
            **Disclaimer**: This AI assistant provides information based on its training.
            It is not a substitute for professional medical advice, diagnosis, or treatment.
            Always seek the advice of your physician or other qualified health provider with any
            questions you may have regarding a medical condition.
            """)
    # Set up event handlers: both the button click and pressing Enter in the
    # textbox trigger generation with the same inputs/outputs.
    submit_btn.click(
        generate_response,
        inputs=[question_input, max_length_slider, temperature_slider],
        outputs=response_output
    )
    question_input.submit(
        generate_response,
        inputs=[question_input, max_length_slider, temperature_slider],
        outputs=response_output
    )
    # Clear resets the question to "" and blanks the response (None clears
    # a Textbox in Gradio).
    clear_btn.click(
        lambda: ("", None),
        inputs=None,
        outputs=[question_input, response_output]
    )
# Launch the app when the file is executed directly.
if __name__ == "__main__":
    # NOTE: torch is imported at the top of the file. The previous version
    # imported it only inside this guard, which left generate_response
    # broken (NameError on torch) whenever this module was imported rather
    # than run as a script.
    app.launch(share=True)  # set share=False if you don't want a public link