# VHA1 / app.py — Hugging Face Space by lukiod (revision 3522bb9, ~12.7 kB)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import logging
from typing import List, Dict
import gc
import os
# Setup logging
# Module-wide logger shared by HealthAssistant and GradioInterface below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Set environment variables for memory optimization
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — confirm against the pinned version.
os.environ['TRANSFORMERS_CACHE'] = '/home/user/.cache/huggingface/hub'
# Disable tokenizer thread parallelism (avoids fork warnings/overhead).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
class HealthAssistant:
    """CPU-only health chatbot backed by microsoft/Phi-2.

    Keeps small rolling windows (max 5 entries each) of user health metrics
    and medications and folds them into each prompt as context. The model is
    loaded lazily on first use and can be unloaded to reclaim memory between
    exchanges.
    """

    def __init__(self):
        self.model_id = "microsoft/Phi-2"  # small model suitable for CPU-only Spaces
        self.model = None
        self.tokenizer = None
        self.pipe = None
        self.metrics = []      # newest-last, capped at 5 entries
        self.medications = []  # newest-last, capped at 5 entries
        self.device = "cpu"
        self.is_model_loaded = False
        self.max_history_length = 2  # prior chat turns folded back into the prompt

    def initialize_model(self):
        """Load tokenizer, model and text-generation pipeline (idempotent).

        Returns:
            True on success (including when the model is already loaded).

        Raises:
            Exception: re-raised from the underlying loaders after logging.
        """
        try:
            if self.is_model_loaded:
                return True
            logger.info(f"Loading model: {self.model_id}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                model_max_length=256,  # short context keeps memory bounded
                padding_side="left"
            )
            logger.info("Tokenizer loaded")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=torch.float32,  # full precision: CPU has no fp16 kernels
                trust_remote_code=True,
                device_map=None,
                low_cpu_mem_usage=True
            ).to(self.device)
            gc.collect()
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device,
                model_kwargs={"low_cpu_mem_usage": True}
            )
            self.is_model_loaded = True
            logger.info("Model initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Error in model initialization: {str(e)}")
            raise

    def unload_model(self):
        """Drop model, pipeline and tokenizer references and force a GC pass."""
        if hasattr(self, 'model') and self.model is not None:
            del self.model
            self.model = None
        if hasattr(self, 'pipe') and self.pipe is not None:
            del self.pipe
            self.pipe = None
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        self.is_model_loaded = False
        gc.collect()
        logger.info("Model unloaded successfully")

    def generate_response(self, message: str, history: List = None) -> str:
        """Generate one assistant reply for ``message``.

        Args:
            message: User text; truncated to 200 chars to bound prompt size.
            history: Optional prior turns (role/content dicts or (user, bot)
                tuples); only the last ``max_history_length`` entries are used.

        Returns:
            The reply text, or a fixed apology string on any error.
        """
        try:
            if not self.is_model_loaded:
                self.initialize_model()
            message = message[:200]  # truncate long messages
            prompt = self._prepare_prompt(
                message,
                history[-self.max_history_length:] if history else None
            )
            generation_args = {
                "max_new_tokens": 200,
                "return_full_text": False,
                "temperature": 0.7,
                "do_sample": True,
                "top_k": 50,
                "top_p": 0.9,
                "repetition_penalty": 1.1,
                "num_return_sequences": 1,
                "batch_size": 1
            }
            output = self.pipe(prompt, **generation_args)
            response = output[0]['generated_text']
            # FIX: base models like Phi-2 often keep writing a fabricated next
            # "Human:" turn after their answer; keep only the assistant reply.
            response = response.split("Human:")[0]
            gc.collect()
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return "I apologize, but I encountered an error. Please try again."

    def _prepare_prompt(self, message: str, history: List = None) -> str:
        """Assemble the prompt: system line, health context, history, new turn."""
        prompt_parts = [
            "Medical AI assistant. Be professional, include disclaimers.",
            self._get_health_context()
        ]
        if history:
            for h in history:
                if isinstance(h, dict):  # Gradio "messages" format
                    if h['role'] == 'user':
                        prompt_parts.append(f"Human: {h['content'][:100]}")
                    else:
                        prompt_parts.append(f"Assistant: {h['content'][:100]}")
                else:  # legacy (user, assistant) tuple format
                    prompt_parts.extend([
                        f"Human: {h[0][:100]}",
                        f"Assistant: {h[1][:100]}"
                    ])
        prompt_parts.extend([
            f"Human: {message}",
            "Assistant:"
        ])
        return "\n".join(prompt_parts)

    def _get_health_context(self) -> str:
        """Return a compact one-line summary of latest metrics and medications."""
        if not self.metrics and not self.medications:
            return "No health data"
        context = []
        if self.metrics:
            latest = self.metrics[-1]
            context.append(f"Metrics: W:{latest['Weight']}kg S:{latest['Steps']} Sl:{latest['Sleep']}h")
        if self.medications:
            # Only the two most recent medications are surfaced to keep the
            # prompt short.
            meds = [f"{m['Medication']}({m['Dosage']}@{m['Time']})" for m in self.medications[-2:]]
            context.append("Meds: " + ", ".join(meds))
        return " | ".join(context)

    def add_metrics(self, weight: float, steps: int, sleep: float) -> bool:
        """Store one metrics entry, evicting the oldest beyond 5. Returns success."""
        try:
            if len(self.metrics) >= 5:
                self.metrics.pop(0)  # bounded rolling window
            self.metrics.append({
                'Weight': weight,
                'Steps': steps,
                'Sleep': sleep
            })
            return True
        except Exception as e:
            logger.error(f"Error adding metrics: {e}")
            return False

    def add_medication(self, name: str, dosage: str, time: str, notes: str = "") -> bool:
        """Store one medication entry, evicting the oldest beyond 5. Returns success."""
        try:
            if len(self.medications) >= 5:
                self.medications.pop(0)  # bounded rolling window
            self.medications.append({
                'Medication': name,
                'Dosage': dosage,
                'Time': time,
                'Notes': notes
            })
            return True
        except Exception as e:
            logger.error(f"Error adding medication: {e}")
            return False
class GradioInterface:
    """Gradio UI wiring around HealthAssistant: chat, metrics, medications."""

    def __init__(self):
        try:
            logger.info("Initializing Health Assistant...")
            self.assistant = HealthAssistant()
            logger.info("Health Assistant initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Health Assistant: {e}")
            raise

    def chat_response(self, message: str, history: List) -> tuple:
        """Handle one chat turn; returns ("", updated history) for Gradio."""
        if not message.strip():
            return "", history
        try:
            response = self.assistant.generate_response(message, history)
            # Append in the "messages" (role/content dict) format.
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": response})
            # Periodically drop the model weights to cap memory; they are
            # lazily reloaded by generate_response on the next request.
            if len(history) % 3 == 0:
                self.assistant.unload_model()
            return "", history
        except Exception as e:
            logger.error(f"Error in chat response: {e}")
            return "", history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": "I apologize, but I encountered an error. Please try again."}
            ]

    def add_health_metrics(self, weight: float, steps: int, sleep: float) -> str:
        """Validate and store one set of daily metrics; returns status markdown."""
        if not all([weight is not None, steps is not None, sleep is not None]):
            return "⚠️ Please fill in all metrics."
        if weight <= 0 or steps < 0 or sleep < 0:
            return "⚠️ Please enter valid positive numbers."
        if self.assistant.add_metrics(weight, steps, sleep):
            # FIX: status/label strings were mojibake (e.g. "βœ…", "β€’");
            # restored the intended UTF-8 symbols throughout this class.
            return f"""✅ Health metrics saved successfully!
• Weight: {weight} kg
• Steps: {steps}
• Sleep: {sleep} hours"""
        return "❌ Error saving metrics."

    def add_medication_info(self, name: str, dosage: str, time: str, notes: str) -> str:
        """Validate and store one medication entry; returns status markdown."""
        if not all([name, dosage, time]):
            return "⚠️ Please fill in all required fields."
        if self.assistant.add_medication(name, dosage, time, notes):
            return f"""✅ Medication added successfully!
• Medication: {name}
• Dosage: {dosage}
• Time: {time}
• Notes: {notes if notes else 'None'}"""
        return "❌ Error adding medication."

    def create_interface(self):
        """Build and return the three-tab Gradio Blocks app."""
        with gr.Blocks(title="Medical Health Assistant") as demo:
            gr.Markdown("""
            # 🏥 Medical Health Assistant
            This AI assistant provides general health information and guidance.
            """)
            with gr.Tabs():
                with gr.Tab("💬 Medical Consultation"):
                    chatbot = gr.Chatbot(
                        value=[],
                        height=400,
                        show_label=False,  # FIX: was label=False (label expects str)
                        type="messages"  # role/content dict format
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            placeholder="Ask your health question...",
                            lines=1,
                            show_label=False,  # FIX: was label=False
                            scale=9
                        )
                        send_btn = gr.Button("Send", scale=1)
                    clear_btn = gr.Button("Clear Chat")
                with gr.Tab("📊 Health Metrics"):
                    gr.Markdown("### Track Your Health Metrics")
                    with gr.Row():
                        weight_input = gr.Number(
                            label="Weight (kg)",
                            minimum=0,
                            maximum=500
                        )
                        steps_input = gr.Number(
                            label="Steps",
                            minimum=0,
                            maximum=100000
                        )
                        sleep_input = gr.Number(
                            label="Hours Slept",
                            minimum=0,
                            maximum=24
                        )
                    metrics_btn = gr.Button("Save Metrics")
                    metrics_status = gr.Markdown()
                with gr.Tab("💊 Medication Manager"):
                    gr.Markdown("### Track Your Medications")
                    med_name = gr.Textbox(
                        label="Medication Name",
                        placeholder="Enter medication name"
                    )
                    with gr.Row():
                        med_dosage = gr.Textbox(
                            label="Dosage",
                            placeholder="e.g., 500mg"
                        )
                        med_time = gr.Textbox(
                            label="Time",
                            placeholder="e.g., 9:00 AM"
                        )
                    med_notes = gr.Textbox(
                        label="Notes (optional)",
                        placeholder="Additional instructions or notes"
                    )
                    med_btn = gr.Button("Add Medication")
                    med_status = gr.Markdown()
            # Event wiring: chat submit/click, clear, and the two save buttons.
            msg.submit(self.chat_response, [msg, chatbot], [msg, chatbot])
            send_btn.click(self.chat_response, [msg, chatbot], [msg, chatbot])
            clear_btn.click(lambda: [], None, chatbot)
            metrics_btn.click(
                self.add_health_metrics,
                inputs=[weight_input, steps_input, sleep_input],
                outputs=[metrics_status]
            )
            med_btn.click(
                self.add_medication_info,
                inputs=[med_name, med_dosage, med_time, med_notes],
                outputs=[med_status]
            )
            gr.Markdown("""
            ### ⚠️ Medical Disclaimer
            This AI assistant provides general health information only. Not a replacement for professional medical advice.
            Always consult healthcare professionals for medical decisions.
            """)
        demo.queue(max_size=5)  # bound concurrent requests on a small CPU box
        return demo
def main():
    """Build the Gradio interface and serve it; re-raises startup failures."""
    try:
        ui = GradioInterface()
        ui.create_interface().launch(
            server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
            show_error=True,
            share=True
        )
    except Exception as e:
        logger.error(f"Error starting application: {e}")
        raise
# Script entry point: launch the app only when executed directly, not on import.
if __name__ == "__main__":
    main()