boning123 committed
Commit 7aef8f2 · verified · 1 Parent(s): 77039e8

Update app.py

Files changed (1):
  1. app.py +26 -54
app.py CHANGED
@@ -1,29 +1,28 @@
 import gradio as gr
-from huggingface_hub import hf_hub_download # Still useful if model is private and needs custom token
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
-from transformers.pipelines import pipeline
+from huggingface_hub import hf_hub_download
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig, pipeline # <-- Changed import here
 import re
 import os
-import torch # Required for transformers models
+import torch
 import threading
-import time # For short sleeps in streamer
+import time
 
 # --- Model Configuration ---
-# Your SmilyAI model ID on Hugging Face Hub
 MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
-N_CTX = 2048 # Context window for the model (applies more to LLMs)
+N_CTX = 2048
 MAX_TOKENS = 500
 TEMPERATURE = 0.7
 TOP_P = 0.9
-STOP_SEQUENCES = ["USER:", "\n\n"] # Model will stop generating when it encounters these
+STOP_SEQUENCES = ["USER:", "\n\n"]
 
 # --- Safety Configuration ---
 print("Loading safety model (unitary/toxic-bert)...")
 try:
+    # Using the directly imported pipeline function
     safety_classifier = pipeline(
         "text-classification",
         model="unitary/toxic-bert",
-        framework="pt" # Use PyTorch backend
+        framework="pt"
     )
     print("Safety model loaded successfully.")
 except Exception as e:
@@ -33,10 +32,6 @@ except Exception as e:
 TOXICITY_THRESHOLD = 0.9
 
 def is_text_safe(text: str) -> tuple[bool, str | None]:
-    """
-    Checks if the given text contains unsafe content using the safety classifier.
-    Returns (True, None) if safe, or (False, detected_label) if unsafe.
-    """
     if not text.strip():
         return True, None
 
@@ -50,14 +45,12 @@ def is_text_safe(text: str) -> tuple[bool, str | None]:
 
     except Exception as e:
         print(f"Error during safety check: {e}")
-        # If the safety check fails, consider it unsafe by default or log and let it pass.
         return False, "safety_check_failed"
 
 
 # --- Main Model Loading (using Transformers) ---
 print(f"Loading tokenizer for {MODEL_REPO_ID}...")
 try:
-    # AutoTokenizer fetches the correct tokenizer for the model
     tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
     print("Tokenizer loaded.")
 except Exception as e:
@@ -67,11 +60,8 @@ except Exception as e:
 
 print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
 try:
-    # AutoModelForCausalLM loads the language model.
-    # device_map="cpu" ensures all model layers are loaded onto the CPU.
-    # torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs.
     model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
-    model.eval() # Set model to evaluation mode for inference
+    model.eval()
     print("Model loaded successfully.")
 except Exception as e:
     print(f"Error loading model: {e}")
@@ -79,18 +69,15 @@ except Exception as e:
     exit(1)
 
 # Configure generation for streaming
-# Use GenerationConfig from the model for default parameters, then override as needed.
 generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
 generation_config.max_new_tokens = MAX_TOKENS
 generation_config.temperature = TEMPERATURE
 generation_config.top_p = TOP_P
-generation_config.do_sample = True # Enable sampling for temperature/top_p
-# Set EOS and PAD token IDs for proper generation stopping and padding
+generation_config.do_sample = True
 generation_config.eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
 generation_config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
-# Fallback for pad_token_id if not explicitly set
 if generation_config.pad_token_id == -1:
-    generation_config.pad_token_id = 0 # Fallback to 0, though not ideal for all models
+    generation_config.pad_token_id = 0
 
 # --- Custom Streamer for Gradio and Safety Check ---
 class GradioSafetyStreamer(TextIteratorStreamer):
@@ -99,82 +86,70 @@ class GradioSafetyStreamer(TextIteratorStreamer):
         self.safety_checker_fn = safety_checker_fn
         self.toxicity_threshold = toxicity_threshold
         self.current_sentence_buffer = ""
-        self.output_queue = [] # Queue to store safety-checked sentences to be yielded by Gradio
-        self.sentence_regex = re.compile(r'[.!?]\s*') # Regex for sentence end, simple version
-        self.text_done = threading.Event() # Event to signal when internal text processing is complete
+        self.output_queue = []
+        self.sentence_regex = re.compile(r'[.!?]\s*')
+        self.text_done = threading.Event()
 
     def on_finalized_text(self, text: str, stream_end: bool = False):
-        # This method is called by the superclass when a decoded token chunk is ready.
         self.current_sentence_buffer += text
 
-        # Split buffer into sentences. Keep the last part in buffer if it's incomplete.
         sentences = self.sentence_regex.split(self.current_sentence_buffer)
 
         sentences_to_process = []
         if not stream_end and sentences and self.sentence_regex.search(sentences[-1]) is None:
-            # If not end of stream and last part is not a complete sentence, buffer it for next time
             sentences_to_process = sentences[:-1]
             self.current_sentence_buffer = sentences[-1]
         else:
-            # Otherwise, process all segments and clear buffer
             sentences_to_process = sentences
             self.current_sentence_buffer = ""
 
         for sentence in sentences_to_process:
-            if not sentence.strip(): continue # Skip empty strings from splitting
+            if not sentence.strip(): continue
 
             is_safe, detected_label = self.safety_checker_fn(sentence)
             if not is_safe:
                 print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                 self.output_queue.append("[Content removed due to safety guidelines]")
-                self.output_queue.append("__STOP_GENERATION__") # Special signal to stop LLM generation
-                return # Stop processing further sentences from this chunk if unsafe
+                self.output_queue.append("__STOP_GENERATION__")
+                return
 
             else:
                 self.output_queue.append(sentence)
 
         if stream_end:
-            # If stream ends and there's leftover text in buffer, process it
            if self.current_sentence_buffer.strip():
                 is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                 if not is_safe:
                     self.output_queue.append("[Content removed due to safety guidelines]")
                 else:
                     self.output_queue.append(self.current_sentence_buffer)
-            self.current_sentence_buffer = "" # Clear after final check
-            self.text_done.set() # Signal that all text processing is complete
+            self.current_sentence_buffer = ""
+            self.text_done.set()
 
     def __iter__(self):
-        # This method allows Gradio to iterate over the safety-checked output.
         while True:
             if self.output_queue:
                 item = self.output_queue.pop(0)
                 if item == "__STOP_GENERATION__":
-                    # Signal to the outer Gradio loop to stop yielding.
                     raise StopIteration
                 yield item
-            elif self.text_done.is_set(): # Check if internal generation and safety processing is truly finished
-                raise StopIteration # End of generation and safety check
+            elif self.text_done.is_set():
+                raise StopIteration
             else:
-                time.sleep(0.01) # Small sleep to prevent busy-waiting while waiting for new tokens
+                time.sleep(0.01)
 
 
 # --- Inference Function with Safety and Streaming ---
 def generate_word_by_word_with_safety(prompt_text: str):
     formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
-    # Encode input on the model's device (CPU)
     input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
 
-    # Initialize the custom streamer
     streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD)
 
-    # Use a separate thread for model generation because model.generate is a blocking call.
-    # This allows the streamer to continuously fill its queue while Gradio yields.
     generate_kwargs = {
         "input_ids": input_ids,
         "streamer": streamer,
         "generation_config": generation_config,
-        # Explicitly pass these for clarity, even if in generation_config
         "do_sample": True,
         "temperature": TEMPERATURE,
         "top_p": TOP_P,
@@ -183,23 +158,21 @@ def generate_word_by_word_with_safety(prompt_text: str):
         "pad_token_id": generation_config.pad_token_id,
     }
 
-    # Start generation in a separate thread
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
-    # Yield tokens from the streamer's output queue for Gradio to display progressively
     full_generated_text = ""
     try:
         for new_sentence_or_chunk in streamer:
             full_generated_text += new_sentence_or_chunk
-            yield full_generated_text # Gradio expects accumulated string for streaming display
+            yield full_generated_text
     except StopIteration:
-        pass # Streamer signaled end
+        pass
     except Exception as e:
         print(f"Error during streaming: {e}")
-        yield full_generated_text + f"\n\n[Error during streaming: {e}]" # Show error in output
+        yield full_generated_text + f"\n\n[Error during streaming: {e}]"
     finally:
-        thread.join() # Ensure the generation thread finishes gracefully
+        thread.join()
 
 
 # --- Gradio Blocks Interface ---
@@ -234,4 +207,3 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 if __name__ == "__main__":
     print("Launching Gradio app...")
     demo.launch(server_name="0.0.0.0", server_port=7860)
-
 
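Note on the change: the updated app.py imports pipeline directly from transformers instead of from transformers.pipelines, and strips most inline comments. A minimal sketch of the new import as used for the safety classifier follows; the task, model name, and framework come from the diff, while the sample sentence and score handling are illustrative and not part of the commit:

    from transformers import pipeline  # direct top-level import, as in the updated app.py

    # Toxicity classifier used by the app's safety check (PyTorch backend).
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt",
    )

    # Illustrative usage: the pipeline returns a list of dicts with "label" and "score".
    result = safety_classifier("This is a harmless test sentence.")[0]
    print(result["label"], result["score"])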