Update app.py

app.py (CHANGED)
@@ -1,29 +1,28 @@
 import gradio as gr
 from huggingface_hub import hf_hub_download
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
-from transformers.pipelines import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig, pipeline  # <-- Changed import here
 import re
 import os
 import torch
 import threading
 import time
 
 # --- Model Configuration ---
-# Your SmilyAI model ID on Hugging Face Hub
 MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
 N_CTX = 2048
 MAX_TOKENS = 500
 TEMPERATURE = 0.7
 TOP_P = 0.9
 STOP_SEQUENCES = ["USER:", "\n\n"]
 
 # --- Safety Configuration ---
 print("Loading safety model (unitary/toxic-bert)...")
 try:
+    # Using the directly imported pipeline function
     safety_classifier = pipeline(
         "text-classification",
         model="unitary/toxic-bert",
         framework="pt"
     )
     print("Safety model loaded successfully.")
 except Exception as e:
@@ -33,10 +32,6 @@ except Exception as e:
 TOXICITY_THRESHOLD = 0.9
 
 def is_text_safe(text: str) -> tuple[bool, str | None]:
-    """
-    Checks if the given text contains unsafe content using the safety classifier.
-    Returns (True, None) if safe, or (False, detected_label) if unsafe.
-    """
     if not text.strip():
         return True, None
 
@@ -50,14 +45,12 @@ def is_text_safe(text: str) -> tuple[bool, str | None]:
 
     except Exception as e:
         print(f"Error during safety check: {e}")
-        # If the safety check fails, consider it unsafe by default or log and let it pass.
        return False, "safety_check_failed"
 
 
 # --- Main Model Loading (using Transformers) ---
 print(f"Loading tokenizer for {MODEL_REPO_ID}...")
 try:
-    # AutoTokenizer fetches the correct tokenizer for the model
     tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
     print("Tokenizer loaded.")
 except Exception as e:
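The hunk above elides the lines where safety_classifier is actually invoked, so the role of TOXICITY_THRESHOLD is not visible in this diff. Below is a minimal sketch of how such a check is commonly written against a text-classification pipeline's list-of-dicts output; it assumes the safety_classifier and TOXICITY_THRESHOLD names from the code above, and the real body in app.py may differ.

# Sketch only: the diff does not show the real classifier call inside is_text_safe.
# Assumes safety_classifier and TOXICITY_THRESHOLD defined earlier in the file.
def is_text_safe_sketch(text: str) -> tuple[bool, str | None]:
    if not text.strip():
        return True, None
    try:
        # A text-classification pipeline returns a list of {"label": ..., "score": ...} dicts.
        results = safety_classifier(text)
        for result in results:
            if result["score"] >= TOXICITY_THRESHOLD:
                return False, result["label"]
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        return False, "safety_check_failed"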
@@ -67,11 +60,8 @@ except Exception as e:
 
 print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
 try:
-    # AutoModelForCausalLM loads the language model.
-    # device_map="cpu" ensures all model layers are loaded onto the CPU.
-    # torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs.
     model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
     model.eval()
     print("Model loaded successfully.")
 except Exception as e:
     print(f"Error loading model: {e}")
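The deleted comments noted that float32 is the usual choice on CPU and that half precision mainly saves memory rather than time. If memory is the constraint, a bfloat16 load is sometimes workable on CPUs with bf16 support; this is a hedged alternative, not the configuration the commit uses.

# Alternative loading sketch, not the committed configuration.
# bfloat16 roughly halves memory; whether it is faster depends on the CPU's bf16 support.
model_bf16 = AutoModelForCausalLM.from_pretrained(
    MODEL_REPO_ID,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
)
model_bf16.eval()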
@@ -79,18 +69,15 @@ except Exception as e:
     exit(1)
 
 # Configure generation for streaming
-# Use GenerationConfig from the model for default parameters, then override as needed.
 generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
 generation_config.max_new_tokens = MAX_TOKENS
 generation_config.temperature = TEMPERATURE
 generation_config.top_p = TOP_P
 generation_config.do_sample = True
-# Set EOS and PAD token IDs for proper generation stopping and padding
 generation_config.eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
 generation_config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
-# Fallback for pad_token_id if not explicitly set
 if generation_config.pad_token_id == -1:
     generation_config.pad_token_id = 0
 
 # --- Custom Streamer for Gradio and Safety Check ---
 class GradioSafetyStreamer(TextIteratorStreamer):
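The chained conditional expressions that set eos_token_id and pad_token_id above are hard to scan. An equivalent, spelled-out form of the same fallback chain, purely illustrative (the committed code keeps the one-liners):

# Illustrative only: same fallback chain as the two one-liners above.
eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
pad_id = tokenizer.pad_token_id
if pad_id is None:
    pad_id = eos_id   # fall back to EOS, which may itself be -1
if pad_id == -1:
    pad_id = 0        # last-resort padding id used by the script
generation_config.eos_token_id = eos_id
generation_config.pad_token_id = pad_id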
@@ -99,82 +86,70 @@ class GradioSafetyStreamer(TextIteratorStreamer):
         self.safety_checker_fn = safety_checker_fn
         self.toxicity_threshold = toxicity_threshold
         self.current_sentence_buffer = ""
         self.output_queue = []
         self.sentence_regex = re.compile(r'[.!?]\s*')
         self.text_done = threading.Event()
 
     def on_finalized_text(self, text: str, stream_end: bool = False):
-        # This method is called by the superclass when a decoded token chunk is ready.
         self.current_sentence_buffer += text
 
-        # Split buffer into sentences. Keep the last part in buffer if it's incomplete.
         sentences = self.sentence_regex.split(self.current_sentence_buffer)
 
         sentences_to_process = []
         if not stream_end and sentences and self.sentence_regex.search(sentences[-1]) is None:
-            # If not end of stream and last part is not a complete sentence, buffer it for next time
             sentences_to_process = sentences[:-1]
             self.current_sentence_buffer = sentences[-1]
         else:
-            # Otherwise, process all segments and clear buffer
             sentences_to_process = sentences
             self.current_sentence_buffer = ""
 
         for sentence in sentences_to_process:
             if not sentence.strip(): continue
 
             is_safe, detected_label = self.safety_checker_fn(sentence)
             if not is_safe:
                 print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                 self.output_queue.append("[Content removed due to safety guidelines]")
                 self.output_queue.append("__STOP_GENERATION__")
                 return
 
             else:
                 self.output_queue.append(sentence)
 
         if stream_end:
-            # If stream ends and there's leftover text in buffer, process it
             if self.current_sentence_buffer.strip():
                 is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                 if not is_safe:
                     self.output_queue.append("[Content removed due to safety guidelines]")
                 else:
                     self.output_queue.append(self.current_sentence_buffer)
             self.current_sentence_buffer = ""
             self.text_done.set()
 
     def __iter__(self):
-        # This method allows Gradio to iterate over the safety-checked output.
         while True:
             if self.output_queue:
                 item = self.output_queue.pop(0)
                 if item == "__STOP_GENERATION__":
-                    # Signal to the outer Gradio loop to stop yielding.
                     raise StopIteration
                 yield item
             elif self.text_done.is_set():
                 raise StopIteration
             else:
                 time.sleep(0.01)
 
 
 # --- Inference Function with Safety and Streaming ---
 def generate_word_by_word_with_safety(prompt_text: str):
     formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
-    # Encode input on the model's device (CPU)
     input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
 
-    # Initialize the custom streamer
     streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD)
 
-    # Use a separate thread for model generation because model.generate is a blocking call.
-    # This allows the streamer to continuously fill its queue while Gradio yields.
     generate_kwargs = {
         "input_ids": input_ids,
         "streamer": streamer,
         "generation_config": generation_config,
-        # Explicitly pass these for clarity, even if in generation_config
         "do_sample": True,
         "temperature": TEMPERATURE,
         "top_p": TOP_P,
@@ -183,23 +158,21 @@ def generate_word_by_word_with_safety(prompt_text: str):
         "pad_token_id": generation_config.pad_token_id,
     }
 
-    # Start generation in a separate thread
     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
     thread.start()
 
-    # Yield tokens from the streamer's output queue for Gradio to display progressively
     full_generated_text = ""
     try:
         for new_sentence_or_chunk in streamer:
             full_generated_text += new_sentence_or_chunk
             yield full_generated_text
     except StopIteration:
         pass
     except Exception as e:
         print(f"Error during streaming: {e}")
         yield full_generated_text + f"\n\n[Error during streaming: {e}]"
     finally:
         thread.join()
 
 
 # --- Gradio Blocks Interface ---
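A caveat on the custom __iter__ retained by this commit: because the method contains yield it is a generator, and under PEP 479 (Python 3.7+) a raise StopIteration inside a generator surfaces as RuntimeError instead of ending iteration; the except StopIteration in generate_word_by_word_with_safety will also rarely fire, since the for loop consumes that exception internally. A queue-backed variant that ends iteration with return and a sentinel is sketched below; it is an assumption about an alternative design, not the committed code, and it omits the sentence buffering shown above.

import queue

class QueueBackedSafetyStreamer(TextIteratorStreamer):
    """Sketch: same idea as GradioSafetyStreamer, but queue-based and PEP 479 safe."""

    _SENTINEL = object()

    def __init__(self, tokenizer, safety_checker_fn, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.safe_queue = queue.Queue()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Sentence buffering and per-sentence filtering would go here, as in the class above.
        is_safe, _label = self.safety_checker_fn(text)
        self.safe_queue.put(text if is_safe else "[Content removed due to safety guidelines]")
        if stream_end:
            self.safe_queue.put(self._SENTINEL)

    def __iter__(self):
        while True:
            item = self.safe_queue.get()   # blocks, so no sleep-polling branch is needed
            if item is self._SENTINEL:
                return                     # ends the generator cleanly, no StopIteration raise
            yield item

Because Queue.get blocks until an item arrives, the consuming loop in generate_word_by_word_with_safety would no longer need the time.sleep(0.01) polling branch.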
@@ -234,4 +207,3 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 if __name__ == "__main__":
     print("Launching Gradio app...")
     demo.launch(server_name="0.0.0.0", server_port=7860)
-
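The Gradio Blocks section itself sits between the last two hunks and is unchanged, so the diff does not show how the streaming generator is wired into the UI. A hypothetical wiring is sketched below for context; the real component names and layout in app.py may differ.

# Hypothetical wiring; app.py's actual Blocks layout is not shown in this diff.
with gr.Blocks(theme=gr.themes.Soft()) as demo_sketch:
    prompt_box = gr.Textbox(label="Your message")
    output_box = gr.Textbox(label="Sam-reason-S3")
    send_btn = gr.Button("Send")
    # Because generate_word_by_word_with_safety is a generator function,
    # Gradio streams each yielded string into output_box as it arrives.
    send_btn.click(generate_word_by_word_with_safety, inputs=prompt_box, outputs=output_box)

Gradio treats a generator passed to click as a streaming callback, updating the output component on every yield, which is what makes the sentence-by-sentence safety filtering above visible to the user in real time.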