import copy
import re
from threading import Thread

import torch
from transformers import TextIteratorStreamer

from config import logger, MAX_INPUT_TOKEN_LENGTH
from prompts import PROMPT_FUNCTIONS
from response_parser import ParserState, parse_response, format_response, remove_tags
from utils import merge_conversation


def generate_response(model_handler, history, temperature, top_p, top_k, max_tokens, seed, active_gen, model_id, auto_clear):
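    """Stream a model response for the latest user turn in `history`.

    Yields progressively updated copies of the chat history so the UI can
    render text as it arrives. `active_gen[0]` acts as a stop flag and is
    reset to False when generation finishes or fails.
    """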
    raw_history = copy.deepcopy(history)
    # Clean history by removing tags from assistant responses
    history = [[item[0], remove_tags(item[1]) if item[1] else None] for item in history]
    try:
        # Validate history
        if not isinstance(history, list) or not history:
            logger.error("History is empty or not a list")
            history = [[None, "Error: Conversation history is empty or invalid"]]
            yield history
            return

        # Validate last history entry
        if not isinstance(history[-1], (list, tuple)) or len(history[-1]) < 1 or not history[-1][0]:
            logger.error("Last history entry is invalid or missing user message")
            history = raw_history
            history[-1][1] = "Error: No valid user message provided"
            yield history
            return
        # Load model if necessary
        if model_handler.model is None or model_handler.tokenizer is None or model_id != model_handler.current_model_id:
            status, _ = model_handler.load_model(model_id, history)
            if "Error" in status:
                logger.error(status)
                history[-1][1] = status
                yield history
                return
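        # Seed CPU and GPU RNGs so sampling is reproducible for a given seed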
        torch.manual_seed(int(seed))
        if torch.cuda.is_available():
            torch.cuda.manual_seed(int(seed))
            torch.cuda.manual_seed_all(int(seed))
        # Validate prompt function
        if model_id not in PROMPT_FUNCTIONS:
            logger.error(f"No prompt function defined for model_id: {model_id}")
            history[-1][1] = f"Error: No prompt function defined for model {model_id}"
            yield history
            return
        prompt_fn = PROMPT_FUNCTIONS[model_id]
        # The reasoning LoRA models are prompted with a plain string; other models get a chat-message list below
        if model_id in [
            "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
            "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
        ]:
            if auto_clear:
                text = prompt_fn(model_handler.tokenizer, history[-1][0])
            else:
                text = prompt_fn(model_handler.tokenizer, merge_conversation(history))
            inputs = model_handler.tokenizer(
                [text],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_INPUT_TOKEN_LENGTH
            )
        else:
            # Build a role/content conversation for chat-template models
            conversation = []
            for msg in history:
                if msg[0]:
                    conversation.append({"role": "user", "content": msg[0]})
                if msg[1]:
                    # Drop the "✅ Thought for ..." status lines before reusing an assistant reply
                    clean_text = ' '.join(line for line in msg[1].split('\n') if not line.startswith('✅ Thought for')).strip()
                    conversation.append({"role": "assistant", "content": clean_text})
                elif msg[0] and not msg[1]:
                    conversation.append({"role": "assistant", "content": ""})
            # Ensure at least one user message
            if not any(msg["role"] == "user" for msg in conversation):
                logger.error("No valid user messages in conversation history")
                history = raw_history
                history[-1][1] = "Error: No valid user messages in conversation history"
                yield history
                return
            # Apply auto_clear logic
            if auto_clear:
                # Keep only the last user message and add an empty assistant response
                user_msgs = [msg for msg in conversation if msg["role"] == "user"]
                if user_msgs:
                    conversation = [{"role": "user", "content": user_msgs[-1]["content"]}, {"role": "assistant", "content": ""}]
                else:
                    logger.error("No user messages found after filtering")
                    history = raw_history
                    history[-1][1] = "Error: No user messages found in conversation history"
                    yield history
                    return
            else:
                # Ensure the conversation ends with an assistant placeholder if the last message is from the user
                if conversation and conversation[-1]["role"] == "user":
                    conversation.append({"role": "assistant", "content": ""})
            text = prompt_fn(model_handler.tokenizer, conversation)
            tokenizer_kwargs = {
                "return_tensors": "pt",
                "padding": True,
                "truncation": True,
                "max_length": MAX_INPUT_TOKEN_LENGTH
            }
            inputs = model_handler.tokenizer(text, **tokenizer_kwargs)
        if inputs is None or "input_ids" not in inputs:
            logger.error("Tokenizer returned invalid or None output")
            history = raw_history
            history[-1][1] = "Error: Failed to tokenize input"
            yield history
            return
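        # Move tokenized inputs onto the model's device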
        input_ids = inputs["input_ids"].to(model_handler.model.device)
        attention_mask = inputs["attention_mask"].to(model_handler.model.device) if "attention_mask" in inputs else None
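        # Sampling configuration; temperature, top_p, top_k and max_tokens come straight from the caller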
        generate_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "pad_token_id": model_handler.tokenizer.pad_token_id,
            "eos_token_id": model_handler.tokenizer.eos_token_id,
            "use_cache": True,
            "cache_implementation": "dynamic",
        }
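        # TextIteratorStreamer yields decoded text chunks as generate() produces them,
        # so the chat view can be updated incrementally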
        streamer = TextIteratorStreamer(model_handler.tokenizer, timeout=360.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs["streamer"] = streamer
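        # Run generation in a background thread; the main thread consumes the streamer below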
        def run_generation():
            try:
                model_handler.model.generate(**generate_kwargs)
            except Exception as e:
                logger.error(f"Generation failed: {str(e)}")
                raise

        thread = Thread(target=run_generation)
        thread.start()
        state = ParserState()
        # The reasoning models' output continues an opening <think> block,
        # so start the accumulated response with that tag
        if model_id in [
            "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
            "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
        ]:
            full_response = "<think>"
        else:
            full_response = ""
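        # Consume streamed chunks, re-parse the accumulated response, and yield the updated history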
        for text in streamer:
            if not active_gen[0]:
                logger.info("Generation stopped by user")
                break
            if text:
                logger.debug(f"Raw streamer output: {text}")
                # Strip any remaining special tokens of the form <|...|>
                text = re.sub(r'<\|\w+\|>', '', text)
                full_response += text
                state, elapsed = parse_response(full_response, state)
                collapsible, answer_part = format_response(state, elapsed)
                history = raw_history
                history[-1][1] = "\n\n".join(collapsible + [answer_part])
                yield history
            else:
                logger.debug("Streamer returned empty text")
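        # Wait for the generation thread to finish, then emit the final parsed response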
        thread.join()
        thread = None

        state, elapsed = parse_response(full_response, state)
        collapsible, answer_part = format_response(state, elapsed)
        history = raw_history
        history[-1][1] = "\n\n".join(collapsible + [answer_part])
        if not full_response:
            logger.warning("No response generated by model")
            history[-1][1] = "No response generated. Please try again or select a different model."
        yield history
    except Exception as e:
        logger.error(f"Error in generate: {str(e)}")
        history = raw_history
        if not history or not isinstance(history, list):
            history = [[None, f"Error: {str(e)}. Please try again or select a different model."]]
        else:
            history[-1][1] = f"Error: {str(e)}. Please try again or select a different model."
        yield history
    finally:
        # Mark generation as no longer active so the UI can start a new request
        active_gen[0] = False