import copy
import re
import torch
from threading import Thread
from transformers import TextIteratorStreamer
from config import logger, MAX_INPUT_TOKEN_LENGTH
from prompts import PROMPT_FUNCTIONS
from response_parser import ParserState, parse_response, format_response, remove_tags
from utils import merge_conversation

def generate_response(model_handler, history, temperature, top_p, top_k, max_tokens, seed, active_gen, model_id, auto_clear):
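    """Stream a response from the selected model for the latest user turn.

    Progressively yields `history` (a list of [user, assistant] pairs) with the
    last assistant slot filled in as tokens arrive from the model. `active_gen[0]`
    is a mutable stop flag checked while streaming; `auto_clear` prompts with only
    the latest user message instead of the whole conversation.
    """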
    raw_history = copy.deepcopy(history)
    
    # Clean history by removing tags from assistant responses
    history = [[item[0], remove_tags(item[1]) if item[1] else None] for item in history]
    
    try:
        # Validate history
        if not isinstance(history, list) or not history:
            logger.error("History is empty or not a list")
            history = [[None, "Error: Conversation history is empty or invalid"]]
            yield history
            return
        # Validate last history entry
        if not isinstance(history[-1], (list, tuple)) or len(history[-1]) < 1 or not history[-1][0]:
            logger.error("Last history entry is invalid or missing user message")
            history = raw_history
            history[-1][1] = "Error: No valid user message provided"
            yield history
            return
            
        # Load model if necessary
        if model_handler.model is None or model_handler.tokenizer is None or model_id != model_handler.current_model_id:
            status, _ = model_handler.load_model(model_id, history)
            if "Error" in status:
                logger.error(status)
                history[-1][1] = status
                yield history
                return
        
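        # Seed CPU and (if present) CUDA RNGs so sampling is reproducible for a given seed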
        torch.manual_seed(int(seed))
        if torch.cuda.is_available():
            torch.cuda.manual_seed(int(seed))
            torch.cuda.manual_seed_all(int(seed))

        # Validate prompt function
        if model_id not in PROMPT_FUNCTIONS:
            logger.error(f"No prompt function defined for model_id: {model_id}")
            history[-1][1] = f"Error: No prompt function defined for model {model_id}"
            yield history
            return
        prompt_fn = PROMPT_FUNCTIONS[model_id]

        # Handle specific model prompt formatting
        if model_id in [
            "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
            "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
        ]:
            if auto_clear:
                text = prompt_fn(model_handler.tokenizer, history[-1][0])
            else:
                text = prompt_fn(model_handler.tokenizer, merge_conversation(history))
                
            inputs = model_handler.tokenizer(
                [text],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_INPUT_TOKEN_LENGTH
            )
        else:
            # Build conversation for other models
            conversation = []
            for msg in history:
                if msg[0]:
                    conversation.append({"role": "user", "content": msg[0]})
                if msg[1]:
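                    # Drop the "✅ Thought for ..." collapsible headers so only answer text is fed back to the model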
                    clean_text = ' '.join(line for line in msg[1].split('\n') if not line.startswith('✅ Thought for')).strip()
                    conversation.append({"role": "assistant", "content": clean_text})
                elif msg[0] and not msg[1]:
                    conversation.append({"role": "assistant", "content": ""})
            
            # Ensure at least one user message
            if not any(msg["role"] == "user" for msg in conversation):
                logger.error("No valid user messages in conversation history")
                history = raw_history
                history[-1][1] = "Error: No valid user messages in conversation history"
                yield history
                return
            
            # Apply auto_clear logic
            if auto_clear:
                # Keep only the last user message and add an empty assistant response
                user_msgs = [msg for msg in conversation if msg["role"] == "user"]
                if user_msgs:
                    conversation = [{"role": "user", "content": user_msgs[-1]["content"]}, {"role": "assistant", "content": ""}]
                else:
                    logger.error("No user messages found after filtering")
                    history = raw_history
                    history[-1][1] = "Error: No user messages found in conversation history"
                    yield history
                    return
            else:
                # Ensure the conversation ends with an assistant placeholder if the last message is from user
                if conversation and conversation[-1]["role"] == "user":
                    conversation.append({"role": "assistant", "content": ""})

            text = prompt_fn(model_handler.tokenizer, conversation)
            tokenizer_kwargs = {
                "return_tensors": "pt",
                "padding": True,
                "truncation": True,
                "max_length": MAX_INPUT_TOKEN_LENGTH
            }

            inputs = model_handler.tokenizer(text, **tokenizer_kwargs)

        if inputs is None or "input_ids" not in inputs:
            logger.error("Tokenizer returned invalid or None output")
            history = raw_history
            history[-1][1] = "Error: Failed to tokenize input"
            yield history
            return

        input_ids = inputs["input_ids"].to(model_handler.model.device)
        attention_mask = inputs["attention_mask"].to(model_handler.model.device) if "attention_mask" in inputs else None
        
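        # Sampling configuration handed directly to model.generate()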
        generate_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "pad_token_id": model_handler.tokenizer.pad_token_id,
            "eos_token_id": model_handler.tokenizer.eos_token_id,
            "use_cache": True,
            "cache_implementation": "dynamic",
        }

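        # Stream decoded text as it is produced; skip_prompt keeps the input prompt out of the stream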
        streamer = TextIteratorStreamer(model_handler.tokenizer, timeout=360.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs["streamer"] = streamer

        def run_generation():
            try:
                model_handler.model.generate(**generate_kwargs)
            except Exception as e:
                logger.error(f"Generation failed: {str(e)}")
                raise

        thread = Thread(target=run_generation)
        thread.start()

        state = ParserState()
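        # The reasoning LoRA prompts presumably end with an opened <think> tag that
        # skip_prompt strips from the stream, so it is re-prepended here for the parser.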
        if model_id in [
            "Llama-3.2-3B-Reasoning-Vi-Medical-LoRA",
            "Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA"
        ]:
            full_response = "<think>"
        else:
            full_response = ""
        
        for text in streamer:
            if not active_gen[0]:
                logger.info("Generation stopped by user")
                break
                
            if text:
                logger.debug(f"Raw streamer output: {text}")
                text = re.sub(r'<\|\w+\|>', '', text)
                full_response += text
                state, elapsed = parse_response(full_response, state)
                
                collapsible, answer_part = format_response(state, elapsed)
                history = raw_history
                history[-1][1] = "\n\n".join(collapsible + [answer_part])
                yield history
            else:
                logger.debug("Streamer returned empty text")
        
        thread.join()
        thread = None
        state, elapsed = parse_response(full_response, state)
        collapsible, answer_part = format_response(state, elapsed)
        history = raw_history
        history[-1][1] = "\n\n".join(collapsible + [answer_part])
        
        # Note: reasoning models seed full_response with "<think>", so compare against that too
        if full_response in ("", "<think>"):
            logger.warning("No response generated by model")
            history[-1][1] = "No response generated. Please try again or select a different model."
            
        yield history
        
    except Exception as e:
        logger.error(f"Error in generate: {str(e)}")
        history = raw_history
        if not history or not isinstance(history, list):
            history = [[None, f"Error: {str(e)}. Please try again or select a different model."]]
        else:
            history[-1][1] = f"Error: {str(e)}. Please try again or select a different model."
            
        yield history
    finally:
        active_gen[0] = False
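
# Minimal driving sketch (illustrative only: the ModelHandler import and its
# construction below are assumptions, not part of this module):
#
#     from model_handler import ModelHandler
#     handler = ModelHandler()
#     history = [["What are the symptoms of dengue fever?", None]]
#     active_gen = [True]
#     for updated in generate_response(handler, history, temperature=0.7, top_p=0.9,
#                                      top_k=50, max_tokens=512, seed=42,
#                                      active_gen=active_gen,
#                                      model_id="Qwen-3-0.6B-Reasoning-Vi-Medical-LoRA",
#                                      auto_clear=True):
#         print(updated[-1][1])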