kalpanie committed
Commit e4fed54 · verified · 1 Parent(s): 450333d

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +115 -41
src/streamlit_app.py CHANGED
@@ -1,46 +1,120 @@
  import streamlit as st
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- st.set_page_config(page_title="🩺 Medical Q&A with OpenBioLLM", layout="centered")
-
- st.title("🧠 OpenBioLLM Medical Assistant")
- st.markdown("Ask a medical question below and get an expert answer using [OpenBioLLM-8B](https://huggingface.co/aaditya/Llama3-OpenBioLLM-8B)")
-
- @st.cache_resource
- def load_model():
-     model_name = "aaditya/Llama3-OpenBioLLM-8B"
-     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         device_map="auto",
-         torch_dtype=torch.float16,
-         trust_remote_code=True
-     )
-     return tokenizer, model
-
- tokenizer, model = load_model()
-
- question = st.text_input("🩺 Enter your medical question:", placeholder="E.g., What are the symptoms of iron deficiency?")
-
- if question:
-     with st.spinner("Thinking..."):
-         prompt = f"""You are a helpful and accurate medical assistant. Answer the following question precisely:
- Question: {question}
- Answer:"""
-         input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-
-         outputs = model.generate(
-             input_ids=input_ids,
-             max_new_tokens=256,
-             do_sample=False,
-             temperature=0.7,
-             top_k=50,
-             top_p=0.9
          )

-         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         answer = response.split("Answer:")[-1].strip()

-         st.markdown("### ✅ Answer:")
-         st.success(answer)
 
1
  import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
+
+ # --- Page Configuration ---
+ st.set_page_config(
+     page_title="Medical Question Answering with OpenBioLLM",
+     page_icon="⚕️",
+     layout="wide",
+     initial_sidebar_state="expanded",
+ )
+
+ # --- Model Loading ---
+ # Choose your OpenBioLLM model. The 8B parameter model is more manageable for typical Hugging Face Spaces resources.
+ # For larger models like 70B, you might need upgraded hardware on Spaces.
+ MODEL_NAME = "aaditya/Llama3-OpenBioLLM-8B"
+
+ @st.cache_resource  # Caches the model and tokenizer for better performance
+ def load_model_and_tokenizer():
+     """Loads the pre-trained model and tokenizer."""
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+         # Load the model with torch_dtype=torch.float16 for potentially faster inference and lower memory,
+         # and device_map='auto' to leverage available hardware (CPU/GPU) efficiently.
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             torch_dtype=torch.float16,  # Using float16 to reduce memory footprint
+             device_map="auto",  # Automatically uses GPU if available, otherwise CPU
+         )
+         # For models that might not explicitly support the "question-answering" pipeline directly,
+         # we use "text-generation".
+         qa_pipeline = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=512,  # Adjust as needed for answer length
+             do_sample=True,
+             temperature=0.7,  # Controls randomness. Lower for more factual, higher for more creative.
+             top_p=0.9,  # Nucleus sampling
          )
+         return qa_pipeline
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         st.error("This could be due to model availability, network issues, or resource limitations on the Hugging Face Space.")
+         st.error(f"Attempted to load: {MODEL_NAME}")
+         st.info("If you are running this on a free Hugging Face Space, larger models like the 70B version might exceed resource limits. The 8B version is generally more suitable.")
+         return None
+
+ qa_pipeline = load_model_and_tokenizer()
+
+ # --- Application Interface ---
+ st.title("⚕️ Medical Question Answering with OpenBioLLM")
+ st.markdown("Ask a medical-related question and get an answer from the OpenBioLLM model.")
+ st.markdown(f"**Model used:** `{MODEL_NAME}`")
+
+ st.sidebar.header("⚠️ Disclaimer")
+ st.sidebar.warning(
+     "This application is for informational and educational purposes only. "
+     "The answers are generated by an AI model (OpenBioLLM) and may contain inaccuracies or biases. "
+     "**It is NOT a substitute for professional medical advice, diagnosis, or treatment.** "
+     "Always consult with a qualified healthcare professional for any medical concerns."
+ )
+ st.sidebar.info(
+     "The model's performance has not been rigorously evaluated in real-world healthcare environments. "
+     "Do not rely on its outputs for medical decision-making."
+ )
+
+ # --- User Input ---
+ question = st.text_area("Enter your medical question here:", height=100, key="question_input")
+
+ if st.button("Get Answer", key="get_answer_button"):
+     if qa_pipeline and question:
+         with st.spinner("Generating answer... Please wait."):
+             try:
+                 # Construct a prompt for the Llama3-based OpenBioLLM model.
+                 # Llama 3 uses a specific chat template structure.
+                 # We adapt this for a direct question.
+                 messages = [
+                     {"role": "system", "content": "You are a helpful medical information assistant. Please answer the user's question based on your knowledge. Provide informative and clear answers."},
+                     {"role": "user", "content": question}
+                 ]
+                 # The pipeline with a text-generation model expects a string prompt.
+                 # We'll format the messages into a string that Llama3 expects.
+                 # A simpler approach for direct QA might be a direct instruction:
+                 prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful medical information assistant. Please answer the user's question based on your knowledge. Provide informative and clear answers.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+
+                 response = qa_pipeline(prompt)
+
+                 # The output from the text-generation pipeline is usually a list of dictionaries.
+                 if response and isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
+                     generated_answer = response[0]["generated_text"]
+                     # The model will repeat the prompt, so we need to extract only the assistant's response.
+                     assistant_response_start = generated_answer.rfind("<|start_header_id|>assistant<|end_header_id|>")
+                     if assistant_response_start != -1:
+                         answer_text = generated_answer[assistant_response_start + len("<|start_header_id|>assistant<|end_header_id|>"):].strip()
+                         # Further clean up any trailing special tokens if necessary
+                         if "<|eot_id|>" in answer_text:
+                             answer_text = answer_text.split("<|eot_id|>")[0].strip()
+
+                         st.subheader("📝 Model's Answer:")
+                         st.info(answer_text)
+                     else:
+                         st.warning("Could not properly parse the assistant's response from the model output.")
+                         st.text_area("Raw model output:", generated_answer, height=200)
+
+                 else:
+                     st.error("The model did not return a valid response.")
+                     st.write("Raw response:", response)
+
+             except Exception as e:
+                 st.error(f"An error occurred during answer generation: {e}")
+                 st.info("This might be due to the complexity of the question, model limitations, or resource constraints.")

+     elif not qa_pipeline:
+         st.error("Model could not be loaded. Please check the logs for more details.")
+     elif not question:
+         st.warning("Please enter a question.")

+ st.markdown("---")
+ st.markdown("Created with [Streamlit](https://streamlit.io/) and [Hugging Face Transformers](https://huggingface.co/transformers).")
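
Note: the hand-assembled Llama 3 prompt string and the manual `<|start_header_id|>assistant<|end_header_id|>` parsing in the new version could likely be simplified using the tokenizer's built-in chat template together with the text-generation pipeline's `return_full_text=False` option. A minimal sketch of that variant, reusing the `qa_pipeline` and `messages` names defined in the app and assuming the OpenBioLLM-8B tokenizer bundles the standard Llama 3 chat template:

```python
# Sketch only: build the prompt via the chat template instead of hard-coding special tokens.
# Assumes qa_pipeline and messages come from the app above and that the tokenizer
# ships a Llama 3 chat template (typical for Llama-3-based checkpoints).
prompt = qa_pipeline.tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the formatted prompt as a string
    add_generation_prompt=True,  # append the assistant header so generation starts there
)

# return_full_text=False asks the pipeline to return only the newly generated text,
# so the rfind()-based prompt stripping should no longer be needed.
response = qa_pipeline(prompt, return_full_text=False)
answer_text = response[0]["generated_text"].strip()
```

This would keep the app's behaviour the same while removing most of the manual special-token handling; the explicit prompt string in the committed version works as well.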