import gradio as gr
import os
import time
import requests

"""
This app uses the Hugging Face Inference API to generate responses from the
Trinoid/Data_Management_Mistral model.
"""
# Get token from environment
HF_TOKEN = os.environ.get("HF_TOKEN")
print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")

# Setup API for the Hugging Face Inference API
MODEL_ID = "Trinoid/Data_Management_Mistral"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

# Check if model exists
try:
    print(f"Checking if model {MODEL_ID} exists...")
    response = requests.get(API_URL, headers=headers)
    print(f"Status: {response.status_code}")
    if response.status_code == 200:
        print("Model exists and is accessible")
        print(f"Response: {response.text[:200]}...")
    else:
        print(f"Response: {response.text}")
except Exception as e:
    print(f"Error checking model: {str(e)}")

# Global variable to track model status
model_loaded = False
estimated_time = None
use_simple_format = True  # Toggle to use simpler format instead of chat format

def format_prompt(messages):
    """Format chat messages into a text prompt that Mistral models can understand"""
    if use_simple_format:
        # Simple format - just extract the message content
        system = next((m["content"] for m in messages if m["role"] == "system"), "")
        last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
        
        if system:
            return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
        else:
            return f"Question: {last_user_msg}\n\nAnswer:"
    else:
        # Chat format for Mistral models
        formatted = ""
        for msg in messages:
            if msg["role"] == "system":
                formatted += f"<s>[INST] {msg['content']} [/INST]</s>\n"
            elif msg["role"] == "user":
                formatted += f"<s>[INST] {msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                formatted += f" {msg['content']} </s>\n"
        return formatted

def query_model_text_generation(prompt, parameters=None):
    """Query the model using the text generation API endpoint"""
    payload = {
        "inputs": prompt,
    }
    
    if parameters:
        payload["parameters"] = parameters
    
    print(f"Sending text generation query to API...")
    print(f"Prompt: {prompt[:100]}...")
    
    try:
        # Try with longer timeout
        response = requests.post(
            API_URL, 
            headers=headers, 
            json=payload,
            timeout=180  # 3 minute timeout
        )
        
        print(f"API response status: {response.status_code}")
        
        # If successful, return the response
        if response.status_code == 200:
            print(f"Success! Response: {str(response.text)[:200]}...")
            return response.json()
            
        # If model is loading, handle it
        elif response.status_code == 503 and "estimated_time" in response.json():
            global estimated_time
            estimated_time = response.json()["estimated_time"]
            print(f"Model is loading. Estimated time: {estimated_time:.2f} seconds")
            return None
            
        # For other errors
        else:
            print(f"API error: {response.text}")
            return None
            
    except Exception as e:
        print(f"Request exception: {str(e)}")
        return None

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Respond to user messages"""
    
    # Create the messages list
    messages = [{"role": "system", "content": system_message}]
    
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    
    messages.append({"role": "user", "content": message})
    
    # Format the prompt
    prompt = format_prompt(messages)
    
    # Set up the generation parameters
    parameters = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "return_full_text": False  # Only return the generated text, not the prompt
    }
    
    # Initial message about model status
    global estimated_time
    if estimated_time:
        initial_msg = f"βŒ› The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
    else:
        initial_msg = "βŒ› Working on your request..."
    
    yield initial_msg
    
    # Try multiple times with increasing waits
    max_retries = 6
    for attempt in range(max_retries):
        # Check if this is a retry
        if attempt > 0:
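            # Linear back-off between retries: 10s, 20s, ..., capped at 60s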
            wait_time = min(60, 10 * attempt)
            yield f"βŒ› Still working on your request... (attempt {attempt+1}/{max_retries})"
            time.sleep(wait_time)
        
        try:
            # Query the model using text generation
            result = query_model_text_generation(prompt, parameters)
            
            if result:
                # Handle different response formats
                if isinstance(result, list) and len(result) > 0:
                    if isinstance(result[0], dict) and "generated_text" in result[0]:
                        yield result[0]["generated_text"]
                        return
                
                if isinstance(result, dict) and "generated_text" in result:
                    yield result["generated_text"]
                    return
                    
                # String or other format
                yield str(result)
                return
            
            # If model is still loading, get the latest estimate
            if estimated_time and attempt < max_retries - 1:
                try:
                    response = requests.get(API_URL, headers=headers)
                    if response.status_code == 503 and "estimated_time" in response.json():
                        estimated_time = response.json()["estimated_time"]
                        print(f"Updated loading time: {estimated_time:.0f} seconds")
                except Exception:
                    pass
                
        except Exception as e:
            print(f"Error in attempt {attempt+1}: {str(e)}")
            if attempt == max_retries - 1:
                yield f"""❌ Sorry, I couldn't generate a response after multiple attempts.

Error details: {str(e)}

Please try again later or contact support if this persists."""

    # If all retries failed
    yield """❌ The model couldn't be accessed after multiple attempts.

This could be due to:
1. Heavy server load
2. The model being too large for the current hardware
3. Temporary service issues

Please try again later. For best results with large models like Mistral-7B, consider:
- Using a smaller model
- Creating a 4-bit quantized version
- Using Hugging Face Inference Endpoints instead of Spaces"""


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a data management expert specializing in Microsoft 365 services.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
    First requests may take 2-3 minutes as the model loads."""
)


if __name__ == "__main__":
    demo.launch()