Update src/streamlit_app.py

src/streamlit_app.py (CHANGED, +83 -41)
@@ -11,39 +11,50 @@ st.set_page_config(
 )
 
 # --- Model Loading ---
-# Choose your OpenBioLLM model. The 8B parameter model is more manageable for typical Hugging Face Spaces resources.
-# For larger models like 70B, you might need upgraded hardware on Spaces.
 MODEL_NAME = "aaditya/Llama3-OpenBioLLM-8B"
 
 @st.cache_resource
 def load_model_and_tokenizer():
     """Loads the pre-trained model and tokenizer."""
+    st.info(f"Attempting to load model: {MODEL_NAME}")
     try:
         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-        # Load the model with torch_dtype=torch.float16 for potentially faster inference and lower memory,
-        # and device_map='auto' to leverage available hardware (CPU/GPU) efficiently.
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_NAME,
             torch_dtype=torch.float16,
             device_map="auto",
+            trust_remote_code=True  # Crucial for some custom Llama variants or if config.json is minimal
         )
 
-        #
+        # If the model uses a specific chat template (common for Llama3),
+        # ensure the tokenizer has it or set it.
+        # For Llama-3, the template is often pre-configured.
+        # if tokenizer.chat_template is None:
+        #     # This is a generic Llama3 chat template example; the specific model might have its own nuance.
+        #     # However, OpenBioLLM-8B seems to be a fine-tune, so its tokenizer should ideally have this.
+        #     tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}{% elif message['role'] == 'user' %}{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}{% elif message['role'] == 'assistant' %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
+
         qa_pipeline = pipeline(
             "text-generation",
             model=model,
             tokenizer=tokenizer,
             max_new_tokens=512,
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
+            # Explicitly setting pad_token_id if not already set by the tokenizer
+            # Llama models typically use eos_token_id as pad_token_id
+            pad_token_id=tokenizer.eos_token_id
         )
+        st.success("Model and tokenizer loaded successfully!")
         return qa_pipeline
     except Exception as e:
         st.error(f"Error loading model: {e}")
-        st.error("This could be due to model availability, network issues,
-        st.error(f"Attempted to load: {MODEL_NAME}")
-        st.info("
+        st.error("This could be due to model availability, network issues, resource limitations, or a configuration issue with the model on Hugging Face Hub.")
+        st.error(f"Attempted to load: {MODEL_NAME} with 'trust_remote_code=True'.")
+        st.info("Please ensure your Hugging Face Space has enough resources (RAM, CPU). The 8B model is large.")
+        st.info("You might also want to check the 'Files and versions' tab of the model on Hugging Face Hub for any specific loading instructions or issues reported by others.")
         return None
 
 qa_pipeline = load_model_and_tokenizer()
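Note on the loading hunk above: the commit pins torch_dtype=torch.float16 and device_map="auto" regardless of hardware, but float16 inference on a CPU-only Space is often slow or unsupported for some ops, and device_map="auto" needs the accelerate package. A minimal sketch of a hardware-aware variant, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_NAME = "aaditya/Llama3-OpenBioLLM-8B"

    def load_model(model_name: str = MODEL_NAME):
        # Use half precision only when a GPU is present; CPU float16 kernels are limited.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto",  # requires accelerate to be installed
        )
        return model, tokenizer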
@@ -51,7 +62,12 @@ qa_pipeline = load_model_and_tokenizer()
 # --- Application Interface ---
 st.title("⚕️ Medical Question Answering with OpenBioLLM")
 st.markdown("Ask a medical-related question and get an answer from the OpenBioLLM model.")
 
+if qa_pipeline:
+    st.markdown(f"**Model used:** `{MODEL_NAME}` (Loaded)")
+else:
+    st.markdown(f"**Model used:** `{MODEL_NAME}` (Failed to load)")
+
 st.sidebar.header("⚠️ Disclaimer")
 st.sidebar.warning(
@@ -70,49 +86,75 @@ question = st.text_area("Enter your medical question here:", height=100, key="qu
 
 if st.button("Get Answer", key="get_answer_button"):
     if qa_pipeline and question:
-        with st.spinner("Generating answer...
+        with st.spinner("Generating answer... This may take a moment for an 8B model on CPU."):
             try:
-                #
-                #
-                # We adapt this for a direct question.
+                # For Llama 3 style models, the pipeline's tokenizer should handle the template.
+                # If not, you'd apply the template manually using tokenizer.apply_chat_template.
                 messages = [
-                    {"role": "system", "content": "You are a helpful medical information assistant.
+                    {"role": "system", "content": "You are a knowledgeable and helpful medical information assistant. Your goal is to provide clear, accurate, and concise answers to medical questions based on the information you have been trained on. Do not provide medical advice or diagnoses. State that you are an AI assistant if asked about your nature."},
                     {"role": "user", "content": question}
                 ]
 
-                #
-                #
+                # The text-generation pipeline for Llama 3 models often expects the chat formatted as a string,
+                # or it can take the list of messages if the tokenizer is correctly configured with a chat_template.
+                # Try applying the template explicitly in case the pipeline does not do it implicitly.
+                try:
+                    prompt = qa_pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+                except Exception as e:
+                    st.warning(f"Could not apply chat template directly, using a basic prompt structure: {e}")
+                    # Fallback prompt structure - this may be less effective than the proper chat template.
+                    prompt = f"System: You are a helpful medical information assistant. Please answer the user's question based on your knowledge. Provide informative and clear answers.\nUser: {question}\nAssistant:"
 
                 response = qa_pipeline(prompt)
 
-                # The output from the text-generation pipeline is usually a list of dictionaries.
                 if response and isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
+                    generated_answer_full = response[0]["generated_text"]
+
+                    # Extract only the assistant's response after the prompt.
+                    # The prompt built by apply_chat_template typically ends with the assistant turn signal,
+                    # so the answer starts right after it; with the fallback prompt, it starts after "Assistant:".
+                    assistant_signal_templated = "<|start_header_id|>assistant<|end_header_id|>"
+
+                    # If the pipeline echoes the input, generated_answer_full is prompt + actual_answer;
+                    # otherwise the output is only the newly generated text.
+                    if generated_answer_full.startswith(prompt):
+                        answer_text = generated_answer_full[len(prompt):].strip()
                     else:
+                        # The prompt is not exactly at the beginning (e.g., the pipeline added a prefix, or it
+                        # returned only the new tokens, as Llama 3 pipelines typically do for templated input).
+                        answer_text = generated_answer_full  # Assume the pipeline output is only the new text
+
+                    # Clean up the end-of-text token if present.
+                    if qa_pipeline.tokenizer.eos_token and qa_pipeline.tokenizer.eos_token in answer_text:
+                        answer_text = answer_text.split(qa_pipeline.tokenizer.eos_token)[0].strip()
+
+                    # Sometimes other special tokens linger.
+                    answer_text = answer_text.replace("<|eot_id|>", "").strip()
 
+                    st.subheader("📝 Model's Answer:")
+                    st.info(answer_text)
                 else:
-                    st.error("The model did not return a valid response.")
+                    st.error("The model did not return a valid response structure.")
                     st.write("Raw response:", response)
 
             except Exception as e:
                 st.error(f"An error occurred during answer generation: {e}")
                 st.info("This might be due to the complexity of the question, model limitations, or resource constraints.")
+                # More detailed error logging could be added here if needed, e.g. print(traceback.format_exc()).
 
     elif not qa_pipeline:
-        st.error("Model
+        st.error("Model is not loaded. Cannot generate an answer. Please check the error messages above.")
     elif not question:
         st.warning("Please enter a question.")
 
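Note on the answer-extraction logic in the last hunk: the transformers text-generation pipeline also accepts return_full_text=False, which makes it return only the newly generated text, so the startswith(prompt) check and manual slicing become unnecessary. A minimal sketch, assuming the qa_pipeline and messages objects defined in the diff above:

    # Sketch only; qa_pipeline and messages come from the code above.
    prompt = qa_pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    outputs = qa_pipeline(prompt, return_full_text=False)  # do not echo the prompt
    answer_text = outputs[0]["generated_text"].strip()
    # Depending on the tokenizer settings, trailing special tokens such as "<|eot_id|>"
    # may still need to be stripped.
    answer_text = answer_text.replace("<|eot_id|>", "").strip()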
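Similarly, rather than deleting "<|eot_id|>" from the output after the fact, generation can be told to stop at Llama 3's end-of-turn token. A hedged sketch, assuming the tokenizer actually defines "<|eot_id|>" (true for standard Llama 3 tokenizers):

    # Sketch only; qa_pipeline and prompt come from the code above.
    terminators = [
        qa_pipeline.tokenizer.eos_token_id,
        qa_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    response = qa_pipeline(prompt, eos_token_id=terminators, return_full_text=False)
    answer_text = response[0]["generated_text"].strip()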