"
logger.info(f"Set pad_token to {tokenizer.pad_token}")
model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
model = PeftModel.from_pretrained(model, lora_adapter_name)
model.eval()
model.config.pad_token_id = tokenizer.pad_token_id
current_model_id = model_id
chatbot_state = []
return f"Successfully loaded model: {model_id} with LoRA adapter {lora_adapter_name}", chatbot_state
except Exception as e:
logger.error(f"Failed to load model or tokenizer: {str(e)}")
return f"Error: Failed to load model {model_id}: {str(e)}", chatbot_state
def format_time(seconds_float):
total_seconds = int(round(seconds_float))
hours = total_seconds // 3600
remaining_seconds = total_seconds % 3600
minutes = remaining_seconds // 60
seconds = remaining_seconds % 60
if hours > 0:
return f"{hours}h {minutes}m {seconds}s"
elif minutes > 0:
return f"{minutes}m {seconds}s"
else:
return f"{seconds}s"
DESCRIPTION = '''
<div class="intro-container">
<h1><span class="intro-icon">⚕️</span>Medical Chatbot with LoRA Models</h1>
<h2>AI-Powered Medical Insights</h2>
<div class="intro-highlight">Explore our advanced models, fine-tuned with LoRA for medical reasoning in Vietnamese.</div>
<div class="intro-disclaimer"><span class="intro-icon">ℹ️</span><strong>Notice:</strong> For research purposes only. AI responses may have limitations due to development, datasets, and architecture. Always consult a medical professional for health advice 🩺.</div>
</div>
'''
CSS = """
.intro-container {
max-width: 800px;
padding: 40px;
background: #ffffff;
border-radius: 15px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
text-align: center;
animation: fadeIn 1s ease-in-out;
}
h1 {
font-size: 1.5em;
color: #007bff;
text-transform: uppercase;
letter-spacing: 1px;
margin-bottom: 20px;
}
h2 {
font-size: 1.3em;
color: #555555;
margin-bottom: 30px;
}
.intro-highlight {
font-size: 1.5em;
color: #333333;
margin: 20px 0;
padding: 20px;
background: #f8f9fa;
border-left: 5px solid #007bff;
border-radius: 10px;
transition: transform 0.3s ease;
}
.intro-highlight:hover {
transform: scale(1.02);
}
.intro-disclaimer {
font-size: 1.3em;
color: #333333;
background: #e9ecef;
padding: 20px;
border-radius: 10px;
border: 1px solid #007bff;
margin-top: 30px;
}
strong {
color: #007bff;
font-weight: bold;
}
.intro-icon {
font-size: 1.4em;
margin-right: 8px;
}
@keyframes fadeIn {
0% { opacity: 0; transform: translateY(-20px); }
100% { opacity: 1; transform: translateY(0); }
}
.spinner {
animation: spin 1s linear infinite;
display: inline-block;
margin-right: 8px;
}
@keyframes spin {
from { transform: rotate(0deg); }
to { transform: rotate(360deg); }
}
.thinking-summary {
cursor: pointer;
padding: 8px;
background: #f5f5f5;
border-radius: 4px;
margin: 4px 0;
}
.thought-content {
padding: 10px;
background: none;
border-radius: 4px;
margin: 5px 0;
}
.thinking-container {
border-left: 3px solid #facc15;
padding-left: 10px;
margin: 8px 0;
background: none;
}
.thinking-container:empty {
background: #e0e0e0;
}
details:not([open]) .thinking-container {
border-left-color: #290c15;
}
details {
border: 1px solid #e0e0e0 !important;
border-radius: 8px !important;
padding: 12px !important;
margin: 8px 0 !important;
transition: border-color 0.2s;
}
.think-section {
background-color: #e6f3ff;
border-left: 4px solid #4a90e2;
padding: 15px;
margin: 10px 0;
border-radius: 6px;
font-size: 14px;
}
.final-answer {
background-color: #f0f4f8;
border-left: 4px solid #2ecc71;
padding: 15px;
margin: 10px 0;
border-radius: 6px;
font-size: 14px;
}
#output-container {
position: relative;
}
.copy-button {
position: absolute;
top: 10px;
right: 10px;
padding: 5px 10px;
background-color: #4a90e2;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
}
.copy-button:hover {
background-color: #357abd;
}
"""
JS_SCRIPTS = """
"""
def user(message, history):
if not isinstance(history, list):
history = []
return "", history + [[message, None]]
class ParserState:
__slots__ = ['answer', 'thought', 'in_think', 'in_answer', 'start_time', 'last_pos', 'total_think_time']
def __init__(self):
self.answer = ""
self.thought = ""
self.in_think = False
self.in_answer = False
self.start_time = 0
self.last_pos = 0
self.total_think_time = 0.0
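# Incrementally split newly streamed text into "thought" (<think>/<reasoning> blocks) and
# "answer" parts, timing how long the model spends inside the thinking tags.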
def parse_response(text, state):
buffer = text[state.last_pos:]
state.last_pos = len(text)
while buffer:
if not state.in_think and not state.in_answer:
think_start = buffer.find('<think>')
reasoning_start = buffer.find('<reasoning>')
answer_start = buffer.find('<answer>')
starts = []
if think_start != -1:
starts.append((think_start, '<think>', 7, 'think'))
if reasoning_start != -1:
starts.append((reasoning_start, '<reasoning>', 11, 'think'))
if answer_start != -1:
starts.append((answer_start, '<answer>', 8, 'answer'))
if not starts:
state.answer += buffer
break
start_pos, start_tag, tag_length, mode = min(starts, key=lambda x: x[0])
state.answer += buffer[:start_pos]
if mode == 'think':
state.in_think = True
state.start_time = time.perf_counter()
else:
state.in_answer = True
buffer = buffer[start_pos + tag_length:]
elif state.in_think:
think_end = buffer.find('</think>')
reasoning_end = buffer.find('</reasoning>')
ends = []
if think_end != -1:
ends.append((think_end, '</think>', 8))
if reasoning_end != -1:
ends.append((reasoning_end, '</reasoning>', 12))
if ends:
end_pos, end_tag, tag_length = min(ends, key=lambda x: x[0])
state.thought += buffer[:end_pos]
duration = time.perf_counter() - state.start_time
state.total_think_time += duration
state.in_think = False
buffer = buffer[end_pos + tag_length:]
if end_tag == '</think>':
state.answer += buffer
break
else:
state.thought += buffer
break
elif state.in_answer:
answer_end = buffer.find('</answer>')
if answer_end != -1:
state.answer += buffer[:answer_end]
state.in_answer = False
buffer = buffer[answer_end + 9:]
else:
state.answer += buffer
break
elapsed = time.perf_counter() - state.start_time if state.in_think else 0
return state, elapsed
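# Turn the parsed state into a collapsible "thinking" block plus the visible answer text.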
def format_response(state, elapsed):
answer_part = state.answer
collapsible = []
collapsed = ""
if state.thought or state.in_think:
if state.in_think:
total_elapsed = state.total_think_time + elapsed
formatted_time = format_time(total_elapsed)
status = f"💭 Thinking for {formatted_time}"
else:
formatted_time = format_time(state.total_think_time)
status = f"✅ Thought for {formatted_time}"
collapsed = ""
collapsible.append(
f"{collapsed}{status}
\n\n\n{state.thought}\n
\n "
)
# print("collapsible: ", collapsible)
# print("answer_part: ", answer_part)
return collapsible, answer_part
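# Remove any HTML/XML tags (e.g. the <details> thinking wrapper) so earlier replies can be
# re-used as plain text in the next prompt.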
def remove_tags(text):
if text is None:
return None
return re.sub(r'<[^>]+>', ' ', text).strip()
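# Main streaming handler: reloads the selected model if needed, builds the prompt, runs
# generation in a worker thread, and yields partial chat-history updates to the UI.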
def generate_response(history, temperature, top_p, top_k, max_tokens, seed, active_gen, model_id, auto_clear):
global model, tokenizer, current_model_id
if auto_clear:
history = [history[-1]]
# Strip leftover HTML (e.g. the <details> thinking blocks) from earlier assistant replies
history = [[item[0], remove_tags(item[1])] for item in history]
try:
if not history or not isinstance(history, list):
logger.error("History is empty or not a list")
history = [[None, "Error: Conversation history is empty or invalid"]]
yield history
return
if not isinstance(history[-1], (list, tuple)) or len(history[-1]) < 1 or not history[-1][0]:
logger.error("Last history entry is invalid or missing user message")
last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) and len(history[-1]) > 0 else None
history = history[:-1] + [[last_user, "Error: No valid user message provided"]]
yield history
return
if model is None or tokenizer is None or model_id != current_model_id:
# Keep the in-flight history; load_model returns a cleared chat state intended for the dropdown handler.
status, _ = load_model(model_id, history)
if "Error" in status:
logger.error(status)
history[-1][1] = status
yield history
return
torch.manual_seed(int(seed))
if torch.cuda.is_available():
torch.cuda.manual_seed(int(seed))
torch.cuda.manual_seed_all(int(seed))
if model_id not in prompt_functions:
logger.error(f"No prompt function defined for model_id: {model_id}")
history[-1][1] = f"Error: No prompt function defined for model {model_id}"
yield history
return
prompt_fn = prompt_functions[model_id]
if model_id in [
"Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
"Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
]:
text = prompt_fn(history[-1][0])
inputs = tokenizer(
[text],
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_INPUT_TOKEN_LENGTH
)
else:
conversation = []
for msg in history:
if msg[0]:
conversation.append({"role": "user", "content": msg[0]})
if msg[1]:
clean_text = ' '.join(line for line in msg[1].split('\n') if not line.startswith('✅ Thought for')).strip()
conversation.append({"role": "assistant", "content": clean_text})
elif msg[0] and not msg[1]:
conversation.append({"role": "assistant", "content": ""})
if not conversation:
logger.error("No valid messages in conversation history")
history[-1][1] = "Error: No valid messages in conversation history"
yield history
return
if model_id in [
"Gemma-3-1B-GRPO-Vi-Medical-LoRA"
]:
conversation = conversation[-2:]
text = prompt_fn(conversation)
tokenizer_kwargs = {
"return_tensors": "pt",
"padding": True,
"truncation": True,
"max_length": MAX_INPUT_TOKEN_LENGTH
}
inputs = tokenizer(text, **tokenizer_kwargs)
if inputs is None or "input_ids" not in inputs:
logger.error("Tokenizer returned invalid or None output")
history[-1][1] = "Error: Failed to tokenize input"
yield history
return
input_ids = inputs["input_ids"].to(model.device)
attention_mask = inputs.get("attention_mask").to(model.device) if "attention_mask" in inputs else None
generate_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"max_new_tokens": max_tokens,
"do_sample": True,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"num_beams": 1,
"repetition_penalty": 1.0,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
"use_cache": True,
"cache_implementation": "dynamic",
}
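# Stream decoded tokens as they are produced; generation itself runs in a background thread below.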
streamer = TextIteratorStreamer(tokenizer, timeout=360.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs["streamer"] = streamer
def run_generation():
try:
model.generate(**generate_kwargs)
except Exception as e:
logger.error(f"Generation failed: {str(e)}")
raise
thread = Thread(target=run_generation)
thread.start()
state = ParserState()
if model_id in [
"Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
"Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
]:
# The reasoning prompts prime the assistant turn with an opening think tag that the streamer
# (skip_prompt=True) never emits, so seed the parser with it here. Assumption: the literal tag
# was stripped from this listing; "<think>" is the likely value.
full_response = "<think>"
else:
full_response = ""
for text in streamer:
if not active_gen[0]:
logger.info("Generation stopped by user")
break
if text:
logger.debug(f"Raw streamer output: {text}")
text = re.sub(r'<\|\w+\|>', '', text)
full_response += text
state, elapsed = parse_response(full_response, state)
collapsible, answer_part = format_response(state, elapsed)
history[-1][1] = "\n\n".join(collapsible + [answer_part])
yield history
else:
logger.debug("Streamer returned empty text")
thread.join()
thread = None
state, elapsed = parse_response(full_response, state)
collapsible, answer_part = format_response(state, elapsed)
history[-1][1] = "\n\n".join(collapsible + [answer_part])
if not full_response:
logger.warning("No response generated by model")
history[-1][1] = "No response generated. Please try again or select a different model."
print("full_response: ", full_response)
yield history
except Exception as e:
logger.error(f"Error in generate: {str(e)}")
if not history or not isinstance(history, list):
history = [[None, f"Error: {str(e)}. Please try again or select a different model."]]
else:
history[-1][1] = f"Error: {str(e)}. Please try again or select a different model."
yield history
finally:
active_gen[0] = False
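# Preload the first model so the UI is responsive at startup, then build the Gradio interface.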
MODEL_IDS = list(lora_configs.keys())
load_model(MODEL_IDS[0], [])
with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
# gr.Markdown(DESCRIPTION)
gr.HTML(DESCRIPTION)
gr.HTML(JS_SCRIPTS)
active_gen = gr.State([False])
chatbot = gr.Chatbot(
elem_id="chatbot",
height=500,
show_label=False,
render_markdown=True
)
with gr.Row():
msg = gr.Textbox(
label="Message",
placeholder="Type your medical query in Vietnamese...",
container=False,
scale=4
)
submit_btn = gr.Button("Send", variant='primary', scale=1)
with gr.Column(scale=2):
with gr.Row():
clear_btn = gr.Button("Clear", variant='secondary')
stop_btn = gr.Button("Stop", variant='stop')
with gr.Accordion("Parameters", open=False):
model_dropdown = gr.Dropdown(
choices=MODEL_IDS,
value=MODEL_IDS[0],
label="Select Model",
interactive=True
)
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Temperature")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
top_k = gr.Slider(minimum=1, maximum=100, value=64, step=1, label="Top-k")
max_tokens = gr.Slider(minimum=128, maximum=4096, value=512, step=32, label="Max Tokens")
seed = gr.Slider(minimum=0, maximum=2 ** 32, value=42, step=1, label="Random Seed")
auto_clear = gr.Checkbox(label="Auto Clear History", value=True, info="Clears internal conversation history after each response but keeps displayed messages.")
gr.Examples(
examples=[
["Khi nghi ngờ bị loét dạ dày tá tràng nên đến khoa nào tại bệnh viện để thăm khám?"],
["Triệu chứng của loét dạ dày tá tràng là gì?"],
["Tôi bị mất ngủ, tôi phải làm gì?"],
["Tôi bị trĩ, tôi có nên mổ không?"]
],
inputs=msg,
label="Example Medical Queries"
)
model_load_output = gr.Textbox(label="Model Load Status")
model_dropdown.change(
fn=load_model,
inputs=[model_dropdown, chatbot],
outputs=[model_load_output, chatbot]
)
submit_event = submit_btn.click(
user, [msg, chatbot], [msg, chatbot], queue=False
).then(
lambda: [True], outputs=active_gen
).then(
generate_response, [chatbot, temperature, top_p, top_k, max_tokens, seed, active_gen, model_dropdown, auto_clear], chatbot
)
msg.submit(
user, [msg, chatbot], [msg, chatbot], queue=False
).then(
lambda: [True], outputs=active_gen
).then(
generate_response, [chatbot, temperature, top_p, top_k, max_tokens, seed, active_gen, model_dropdown, auto_clear], chatbot
)
stop_btn.click(
lambda: [False], None, active_gen, cancels=[submit_event]
)
clear_btn.click(lambda: None, None, chatbot, queue=False)
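# Start the Gradio server when run as a script.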
if __name__ == "__main__":
try:
demo.launch(server_name="0.0.0.0", server_port=7860)
except Exception as e:
logger.error(f"Failed to launch Gradio app: {str(e)}")
raise