Frankie-walsh4 committed on
Commit a0ee3bd · 1 Parent(s): fdf2b7f
Files changed (1)
  1. app.py +142 -133
app.py CHANGED
@@ -1,8 +1,8 @@
  import gradio as gr
  import os
  import time
- import torch
- import traceback
  import threading
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
  from peft import PeftModel
@@ -13,118 +13,74 @@ if torch.cuda.is_available():
      print(f"CUDA device: {torch.cuda.get_device_name(0)}")
      print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

- # Global variable to track model loading
  model_loaded = False
  model_loading = False
- loading_error = None
- model = None
- tokenizer = None
- pipe = None
-
- def load_model_in_thread():
-     """Load the model in a separate thread to avoid blocking the UI"""
-     global model_loaded, model_loading, loading_error, model, tokenizer, pipe

-     if model_loading:
-         return  # Already loading

-     model_loading = True
-     print("Starting model loading process...")

      try:
-         # Load base model
-         model_id = "mistralai/Mistral-7B-Instruct-v0.2"
-         adapter_id = "Trinoid/Data_Management_Mistral"
-         print(f"Loading base model {model_id}...")
-
-         # Initialize tokenizer
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         print("Tokenizer loaded successfully")
-
-         # Load the base model in 4-bit
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             torch_dtype=torch.float16,
-             device_map="auto",
-             load_in_4bit=True,
          )
-         print("Base model loaded successfully")

-         # Load and apply the LoRA adapter
-         print(f"Loading adapter {adapter_id}...")
-         model = PeftModel.from_pretrained(model, adapter_id)
-         print("Adapter loaded and applied successfully")
-
-         # Set up pipeline
-         print("Creating text generation pipeline...")
-         pipe = pipeline(
-             "text-generation",
-             model=model,
-             tokenizer=tokenizer,
-             device_map="auto",
-         )
-         print("Pipeline created successfully")

-         model_loaded = True
-         print("Model loading complete! Ready for inference.")
-
      except Exception as e:
-         loading_error = str(e)
-         print(f"Error loading model: {str(e)}")
-         traceback.print_exc()
-
-     finally:
-         model_loading = False
-
- # Start model loading in background thread
- threading.Thread(target=load_model_in_thread, daemon=True).start()
-
- def format_chat_prompt(messages):
-     """Format messages into a prompt that Mistral-7B-Instruct can understand"""
-     prompt = ""
-     for message in messages:
-         if message["role"] == "system":
-             prompt += f"<s>[INST] {message['content']} [/INST]</s>\n"
-         elif message["role"] == "user":
-             prompt += f"<s>[INST] {message['content']} [/INST]"
-         elif message["role"] == "assistant":
-             prompt += f" {message['content']} </s>\n"
-     return prompt
-
- def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.95):
-     """Generate a response from the model"""
-     global model_loaded, loading_error, model, tokenizer, pipe
-
-     if not model_loaded:
-         if loading_error:
-             return f"Error loading model: {loading_error}"
-         return "Model is still loading. Please wait a moment and try again."
-
-     # Format the prompt for Mistral
-     prompt = format_chat_prompt(messages)
-
-     # Set up the streamer for incremental generation
-     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
-
-     # Generate in a separate thread to enable streaming
-     generation_kwargs = {
-         "input_ids": tokenizer.encode(prompt, return_tensors="pt").to("cuda"),
-         "max_new_tokens": max_new_tokens,
-         "temperature": temperature,
-         "top_p": top_p,
-         "do_sample": True,
-         "streamer": streamer,
-     }
-
-     # Start generation in a thread
-     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     # Stream the output
-     generated_text = ""
-     for new_text in streamer:
-         generated_text += new_text
-         yield generated_text

  def respond(
      message,
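Note on the removed local-loading path: it passes the bare `load_in_4bit=True` kwarg to `from_pretrained`, and `format_chat_prompt` gives the system message its own `[INST] ... [/INST]</s>` block, which differs from what the tokenizer's own Mistral chat template produces. If this path is ever restored, a sketch along the following lines may be closer to current `transformers`/`peft` usage. The model and adapter IDs come from the code above; everything else (function and variable names) is illustrative, not part of this commit.

```python
# Sketch only: 4-bit loading via BitsAndBytesConfig plus the tokenizer's chat template.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
adapter_id = "Trinoid/Data_Management_Mistral"

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,  # preferred over passing load_in_4bit=True directly
    device_map="auto",
)
model = PeftModel.from_pretrained(model, adapter_id)

def build_prompt(system_message, user_message):
    # Mistral-Instruct has no system role, so fold the system text into the first
    # user turn and let the chat template emit the [INST] ... [/INST] markup.
    merged = f"{system_message}\n\n{user_message}" if system_message else user_message
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": merged}],
        tokenize=False,
        add_generation_prompt=True,
    )
```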
@@ -135,21 +91,8 @@ def respond(
      top_p,
  ):
      """Respond to user messages"""
-     global model_loaded, model_loading
-
-     # Check if model is loaded
-     if not model_loaded:
-         if model_loading:
-             yield "⌛ The model is still loading. This can take a few minutes on first startup. Please wait or try again later."
-             return
-         else:
-             # Try loading the model if it hasn't started yet
-             if not threading.active_count() > 1:  # No background thread running
-                 threading.Thread(target=load_model_in_thread, daemon=True).start()
-             yield "⌛ Starting model load now. Please wait a moment and try again."
-             return

-     # Create the messages list
      messages = [{"role": "system", "content": system_message}]

      for val in history:
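The removed readiness check relies on `threading.active_count() > 1`, which is unreliable inside a Gradio app because the server itself runs worker threads. If a local-loading variant is ever revisited, a `threading.Event` is a more direct signal. A minimal sketch follows; the names are illustrative and not part of this commit.

```python
import threading

# Sketch: the loader thread sets this once the model and pipeline are ready.
model_ready = threading.Event()

def load_model_in_thread():
    # ... load base model, adapter and pipeline as in the removed code above ...
    model_ready.set()

def respond(message, history, system_message, max_tokens, temperature, top_p):
    if not model_ready.is_set():
        yield "⌛ The model is still loading. Please try again in a moment."
        return
    # ... generate as usual ...
```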
@@ -160,19 +103,85 @@ def respond(

      messages.append({"role": "user", "content": message})

-     # Generate and stream the response
-     try:
-         for response in generate_response(
-             messages,
-             max_new_tokens=max_tokens,
-             temperature=temperature,
-             top_p=top_p
-         ):
-             yield response
-     except Exception as e:
-         print(f"Error generating response: {str(e)}")
-         traceback.print_exc()
-         yield f"An error occurred while generating the response: {str(e)}"

  """
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -192,8 +201,8 @@ demo = gr.ChatInterface(
          ),
      ],
      description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
-     The model is loaded directly on the L40 GPU for optimal performance.
-     First-time loading may take a few minutes."""
  )

 
  import gradio as gr
  import os
  import time
+ import json
+ import requests
  import threading
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
  from peft import PeftModel
 
      print(f"CUDA device: {torch.cuda.get_device_name(0)}")
      print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

+ # Get token from environment
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")
+
+ # Setup API for the Hugging Face Inference API
+ MODEL_ID = "Trinoid/Data_Management_Mistral"
+ API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
+
+ # Check if model exists
+ try:
+     print(f"Checking if model {MODEL_ID} exists...")
+     response = requests.get(API_URL, headers=headers)
+     print(f"Status: {response.status_code}")
+     if response.status_code == 200:
+         print("Model exists and is accessible")
+     else:
+         print(f"Response: {response.text}")
+ except Exception as e:
+     print(f"Error checking model: {str(e)}")
+
+ # Global variable to track model status
  model_loaded = False
  model_loading = False
+ estimated_time = None
+
+ def query_model(messages, parameters=None):
+     """Query the model using the Inference API"""
+     payload = {
+         "inputs": messages,
+     }

+     if parameters:
+         payload["parameters"] = parameters

+     print(f"Sending query to API...")

      try:
+         # Single attempt with longer timeout
+         response = requests.post(
+             API_URL,
+             headers=headers,
+             json=payload,
+             timeout=180  # 3 minute timeout
          )

+         print(f"API response status: {response.status_code}")

+         # If successful, return the response
+         if response.status_code == 200:
+             return response.json()
+
+         # If model is loading, handle it
+         elif response.status_code == 503 and "estimated_time" in response.json():
+             est_time = response.json()["estimated_time"]
+             global estimated_time
+             estimated_time = est_time
+             print(f"Model is loading. Estimated time: {est_time:.2f} seconds")
+             return None
+
+         # For other errors
+         else:
+             print(f"API error: {response.text}")
+             return None
+
      except Exception as e:
+         print(f"Request exception: {str(e)}")
+         return None
 
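For context on the calls above: the serverless Inference API for text-generation models generally expects `inputs` to be a single prompt string (so the chat messages built in `respond` need to be flattened before posting), returns a list like `[{"generated_text": ...}]` on success, and answers 503 with an `estimated_time` field while a cold model is loading; an `options.wait_for_model` flag can be used instead of manual polling. A minimal sketch of one direct call against the same endpoint, with an illustrative prompt and parameter values:

```python
# Sketch: single request to the endpoint defined above, assuming HF_TOKEN is set.
import os
import requests

API_URL = "https://api-inference.huggingface.co/models/Trinoid/Data_Management_Mistral"
HEADERS = {"Authorization": f"Bearer {os.environ.get('HF_TOKEN', '')}"}

payload = {
    "inputs": "<s>[INST] How should retention labels be applied in SharePoint? [/INST]",
    "parameters": {"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.95, "do_sample": True},
    "options": {"wait_for_model": True},  # hold the request open while the model warms up
}

resp = requests.post(API_URL, headers=HEADERS, json=payload, timeout=300)
if resp.status_code == 200:
    print(resp.json()[0]["generated_text"])  # typical success shape: [{"generated_text": "..."}]
elif resp.status_code == 503:
    print("Model loading, retry in ~", resp.json().get("estimated_time"), "s")
else:
    print("API error:", resp.status_code, resp.text)
```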
  def respond(
      message,

      top_p,
  ):
      """Respond to user messages"""

+     # Create the messages list in chat format
      messages = [{"role": "system", "content": system_message}]

      for val in history:

      messages.append({"role": "user", "content": message})

+     # Set up the inference parameters
+     parameters = {
+         "max_new_tokens": max_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "do_sample": True
+     }
+
+     # Initial message about model status
+     global estimated_time
+     if estimated_time:
+         initial_msg = f"⌛ The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
+     else:
+         initial_msg = "⌛ Working on your request..."
+
+     yield initial_msg
+
+     # Try multiple times with increasing waits
+     max_retries = 6
+     for attempt in range(max_retries):
+         # Check if this is a retry
+         if attempt > 0:
+             wait_time = min(60, 10 * attempt)
+             yield f"⌛ Still working on your request... (attempt {attempt+1}/{max_retries})"
+             time.sleep(wait_time)
+
+         try:
+             # Query the model
+             result = query_model(messages, parameters)
+
+             if result:
+                 # Handle different response formats
+
+                 # List format with generated_text
+                 if isinstance(result, list) and len(result) > 0:
+                     if "generated_text" in result[0]:
+                         yield result[0]["generated_text"]
+                         return
+
+                 # Direct message format
+                 if isinstance(result, dict) and "generated_text" in result:
+                     yield result["generated_text"]
+                     return
+
+                 # String format
+                 if isinstance(result, str):
+                     yield result
+                     return
+
+                 # Raw format as fallback
+                 yield str(result)
+                 return
+
+             # If model is still loading, get the latest estimate
+             if estimated_time and attempt < max_retries - 1:
+                 response = requests.get(API_URL, headers=headers)
+                 if response.status_code == 503 and "estimated_time" in response.json():
+                     estimated_time = response.json()["estimated_time"]
+                     print(f"Updated loading time: {estimated_time:.0f} seconds")
+
+         except Exception as e:
+             print(f"Error in attempt {attempt+1}: {str(e)}")
+             if attempt == max_retries - 1:
+                 yield f"""❌ Sorry, I couldn't generate a response after several attempts.
+
+ Error details: {str(e)}
+
+ Please try again later or contact support if this persists."""
+
+     # If all retries failed
+     yield """❌ The model couldn't be accessed after multiple attempts.
+
+ This could be due to:
+ 1. Heavy server load
+ 2. The model being too large for the current hardware
+ 3. Temporary service issues
+
+ Please try again later."""
+

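The retry loop above re-posts the whole request with growing waits and yields status messages in between. A possible alternative, not used in this commit, is `huggingface_hub.InferenceClient`, which wraps the same endpoint and can stream tokens back to Gradio. A rough sketch, assuming the package is installed and the chat messages have already been flattened into a prompt string; the function name is illustrative:

```python
# Sketch: streaming text generation through huggingface_hub's InferenceClient.
import os
from huggingface_hub import InferenceClient

client = InferenceClient(model="Trinoid/Data_Management_Mistral", token=os.environ.get("HF_TOKEN"))

def stream_reply(prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    text = ""
    # With stream=True the client yields generated tokens as plain strings.
    for token in client.text_generation(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        stream=True,
    ):
        text += token
        yield text  # Gradio's ChatInterface renders each partial string as it arrives
```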
  """
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface

          ),
      ],
      description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
+     The model is accessed via the Hugging Face Inference API.
+     First requests may take 2-3 minutes as the model loads."""
  )
